// FIRRTLFolds.cpp (CIRCT 23.0.0git) — Doxygen page navigation artifacts
// ("Loading...", "Searching...", "No Matches") removed from this scrape.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28using namespace circt;
29using namespace firrtl;
30
31// Drop writes to old and pass through passthrough to make patterns easier to
32// write.
33static Value dropWrite(PatternRewriter &rewriter, OpResult old,
34 Value passthrough) {
35 SmallPtrSet<Operation *, 8> users;
36 for (auto *user : old.getUsers())
37 users.insert(user);
38 for (Operation *user : users)
39 if (auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
42 return passthrough;
43}
44
45// Return true if it is OK to propagate the name to the operation.
46// Non-pure operations such as instances, registers, and memories are not
47// allowed to update names for name stabilities and LEC reasons.
48static bool isOkToPropagateName(Operation *op) {
49 // Conservatively disallow operations with regions to prevent performance
50 // regression due to recursive calls to sub-regions in mlir::isPure.
51 if (op->getNumRegions() != 0)
52 return false;
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
54}
55
56// Move a name hint from a soon to be deleted operation to a new operation.
57// Pass through the new operation to make patterns easier to write. This cannot
58// move a name to a port (block argument), doing so would require rewriting all
59// instance sites as well as the module.
60static Value moveNameHint(OpResult old, Value passthrough) {
61 Operation *op = passthrough.getDefiningOp();
62 // This should handle ports, but it isn't clear we can change those in
63 // canonicalizers.
64 assert(op && "passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>("name");
67 if (name && !name.getValue().empty() && isOkToPropagateName(op))
68 op->setAttr("name", name);
69 return passthrough;
70}
71
72// Declarative canonicalization patterns
73namespace circt {
74namespace firrtl {
75namespace patterns {
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
77} // namespace patterns
78} // namespace firrtl
79} // namespace circt
80
81/// Return true if this operation's operands and results all have a known width.
82/// This only works for integer types.
83static bool hasKnownWidthIntTypes(Operation *op) {
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
86 return false;
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
89 return false;
90 return true;
91}
92
93/// Return true if this value is 1 bit UInt.
94static bool isUInt1(Type type) {
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
97 return false;
98 return true;
99}
100
101/// Set the name of an op based on the best of two names: The current name, and
102/// the name passed in.
103static void updateName(PatternRewriter &rewriter, Operation *op,
104 StringAttr name) {
105 // Should never rename InstanceOp
106 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
107 return;
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
109 auto newName = name.getValue(); // old name is interesting
110 auto newOpName = op->getAttrOfType<StringAttr>("name");
111 // new name might not be interesting
112 if (newOpName)
113 newName = chooseName(newOpName.getValue(), name.getValue());
114 // Only update if needed
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
118}
119
120/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
121/// If a replaced op has a "name" attribute, this function propagates the name
122/// to the new value.
123static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
124 Value newValue) {
125 if (auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>("name");
127 updateName(rewriter, newOp, name);
128 }
129 rewriter.replaceOp(op, newValue);
130}
131
132/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
133/// attribute. If a replaced op has a "name" attribute, this function propagates
134/// the name to the new value.
135template <typename OpTy, typename... Args>
136static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>("name");
139 auto newOp =
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
141 updateName(rewriter, newOp, name);
142 return newOp;
143}
144
145/// Return true if the name is droppable. Note that this is different from
146/// `isUselessName` because non-useless names may be also droppable.
148 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
150 return false;
151}
152
153/// Implicitly replace the operand to a constant folding operation with a const
154/// 0 in case the operand is non-constant but has a bit width 0, or if the
155/// operand is an invalid value.
156///
157/// This makes constant folding significantly easier, as we can simply pass the
158/// operands to an operation through this function to appropriately replace any
159/// zero-width dynamic values or invalid values with a constant of value 0.
160static std::optional<APSInt>
161getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
164
165 // We never support constant folding to unknown width values.
166 if (destWidth < 0)
167 return {};
168
169 // Extension signedness follows the operand sign.
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
171 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
172
173 // If the operand is zero bits, then we can return a zero of the result
174 // type.
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
178 return {};
179}
180
181/// Determine the value of a constant operand for the sake of constant folding.
182static std::optional<APSInt> getConstant(Attribute operand) {
183 if (!operand)
184 return {};
185 if (auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
189 return {};
190}
191
192/// Determine whether a constant operand is a zero value for the sake of
193/// constant folding. This considers `invalidvalue` to be zero.
194static bool isConstantZero(Attribute operand) {
195 if (auto cst = getConstant(operand))
196 return cst->isZero();
197 return false;
198}
199
200/// Determine whether a constant operand is a one value for the sake of constant
201/// folding.
202static bool isConstantOne(Attribute operand) {
203 if (auto cst = getConstant(operand))
204 return cst->isOne();
205 return false;
206}
207
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
enum class BinOpKind {
  Normal,
  Compare,
  // Restored: this enumerator was dropped by the page scrape; it is referenced
  // by constFoldFIRRTLBinaryOp and several fold hooks below.
  DivideOrShift,
};
215
216/// Applies the constant folding function `calculate` to the given operands.
217///
218/// Sign or zero extends the operands appropriately to the bitwidth of the
219/// result type if \p useDstWidth is true, else to the larger of the two operand
220/// bit widths and depending on whether the operation is to be performed on
221/// signed or unsigned operands.
222static Attribute constFoldFIRRTLBinaryOp(
223 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
224 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
225 assert(operands.size() == 2 && "binary op takes two operands");
226
227 // We cannot fold something to an unknown width.
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
230 return {};
231
232 // Any binary op returning i0 is 0.
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
235
236 // Determine the operand widths. This is either dictated by the operand type,
237 // or if that type is an unsized integer, by the actual bits necessary to
238 // represent the constant value.
239 auto lhsWidth =
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
241 auto rhsWidth =
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
247
248 // Compares extend the operands to the widest of the operand types, not to the
249 // result type.
250 int32_t operandWidth;
251 switch (opKind) {
253 operandWidth = resultType.getWidthOrSentinel();
254 break;
256 // Compares compute with the widest operand, not at the destination type
257 // (which is always i1).
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
259 break;
261 operandWidth =
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
263 break;
264 }
265
266 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
267 if (!lhs)
268 return {};
269 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
270 if (!rhs)
271 return {};
272
273 APInt resultValue = calculate(*lhs, *rhs);
274
275 // If the result type is smaller than the computation then we need to
276 // narrow the constant after the calculation.
277 if (opKind == BinOpKind::DivideOrShift)
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
279
280 assert((unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
282 return getIntAttr(resultType, resultValue);
283}
284
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function. Any `invalidvalue` in
/// the input is mapped to a constant zero. The value returned from the callback
/// is used as the replacement for `op`, and an additional pad operation is
/// inserted if necessary. Does nothing if the result of `op` is of unknown
/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants.
  // Entries stay null for operands that are not ConstantOp/SpecialConstantOp.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  if (auto cst = dyn_cast<Attribute>(result))
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = cast<Value>(result);

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);

  // Replace while propagating any "name" attribute onto the replacement.
  assert(type == resultValue.getType() && "canonicalization changed type");
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
348
349/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
350/// value if `bitWidth` is 0.
351static APInt getMaxUnsignedValue(unsigned bitWidth) {
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
353}
354
355/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
356/// value if `bitWidth` is 0.
357static APInt getMinSignedValue(unsigned bitWidth) {
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
359}
360
361/// Get the largest signed value of a given bit width. Returns a 1-bit zero
362/// value if `bitWidth` is 0.
363static APInt getMaxSignedValue(unsigned bitWidth) {
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
365}
366
367//===----------------------------------------------------------------------===//
368// Fold Hooks
369//===----------------------------------------------------------------------===//
370
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant always folds to its own value attribute.
  return getValueAttr();
}
375
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A special constant (clock/reset) folds to its value attribute.
  return getValueAttr();
}
380
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // An aggregate constant folds to the attribute holding its field values.
  return getFieldsAttr();
}
385
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A string constant folds to its value attribute.
  return getValueAttr();
}
390
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A FIRRTL integer constant folds to its value attribute.
  return getValueAttr();
}
395
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A boolean constant folds to its value attribute.
  return getValueAttr();
}
400
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A double constant folds to its value attribute.
  return getValueAttr();
}
405
406//===----------------------------------------------------------------------===//
407// Binary Operators
408//===----------------------------------------------------------------------===//
409
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
412 *this, adaptor.getOperands(), BinOpKind::Normal,
413 [=](const APSInt &a, const APSInt &b) { return a + b; });
414}
415
void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
                 patterns::AddOfSelf, patterns::AddOfPad>(context);
}
421
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
424 *this, adaptor.getOperands(), BinOpKind::Normal,
425 [=](const APSInt &a, const APSInt &b) { return a - b; });
426}
427
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
                 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
                 patterns::SubOfPadL, patterns::SubOfPadR>(context);
}
434
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
436 // mul(x, 0) -> 0
437 //
438 // This is legal because it aligns with the Scala FIRRTL Compiler
439 // interpretation of lowering invalid to constant zero before constant
440 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
441 // multiplication this way and will emit "x * 0".
442 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
443 return getIntZerosAttr(getType());
444
446 *this, adaptor.getOperands(), BinOpKind::Normal,
447 [=](const APSInt &a, const APSInt &b) { return a * b; });
448}
449
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
451 /// div(x, x) -> 1
452 ///
453 /// Division by zero is undefined in the FIRRTL specification. This fold
454 /// exploits that fact to optimize self division to one. Note: this should
455 /// supersede any division with invalid or zero. Division of invalid by
456 /// invalid should be one.
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
459 if (width == -1)
460 width = 2;
461 // Only fold if we have at least 1 bit of width to represent the `1` value.
462 if (width != 0)
463 return getIntAttr(getType(), APInt(width, 1));
464 }
465
466 // div(0, x) -> 0
467 //
468 // This is legal because it aligns with the Scala FIRRTL Compiler
469 // interpretation of lowering invalid to constant zero before constant
470 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
471 // division this way and will emit "0 / x".
472 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
473 return getIntZerosAttr(getType());
474
475 /// div(x, 1) -> x : (uint, uint) -> uint
476 ///
477 /// UInt division by one returns the numerator. SInt division can't
478 /// be folded here because it increases the return type bitwidth by
479 /// one and requires sign extension (a new op).
480 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
482 return getLhs();
483
485 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
486 [=](const APSInt &a, const APSInt &b) -> APInt {
487 if (!!b)
488 return a / b;
489 return APInt(a.getBitWidth(), 0);
490 });
491}
492
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
494 // rem(x, x) -> 0
495 //
496 // Division by zero is undefined in the FIRRTL specification. This fold
497 // exploits that fact to optimize self division remainder to zero. Note:
498 // this should supersede any division with invalid or zero. Remainder of
499 // division of invalid by invalid should be zero.
500 if (getLhs() == getRhs())
501 return getIntZerosAttr(getType());
502
503 // rem(0, x) -> 0
504 //
505 // This is legal because it aligns with the Scala FIRRTL Compiler
506 // interpretation of lowering invalid to constant zero before constant
507 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
508 // division this way and will emit "0 % x".
509 if (isConstantZero(adaptor.getLhs()))
510 return getIntZerosAttr(getType());
511
513 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
514 [=](const APSInt &a, const APSInt &b) -> APInt {
515 if (!!b)
516 return a % b;
517 return APInt(a.getBitWidth(), 0);
518 });
519}
520
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
523 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
525}
526
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
529 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
530 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
531}
532
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
535 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
536 [=](const APSInt &a, const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
538 : a.ashr(b);
539 });
540}
541
542// TODO: Move to DRR.
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
544 if (auto rhsCst = getConstant(adaptor.getRhs())) {
545 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
546 if (rhsCst->isZero())
547 return getIntZerosAttr(getType());
548
549 /// and(x, -1) -> x
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
552 return getLhs();
553 }
554
555 if (auto lhsCst = getConstant(adaptor.getLhs())) {
556 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
557 if (lhsCst->isZero())
558 return getIntZerosAttr(getType());
559
560 /// and(-1, x) -> x
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
563 return getRhs();
564 }
565
566 /// and(x, x) -> x
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
568 return getRhs();
569
571 *this, adaptor.getOperands(), BinOpKind::Normal,
572 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
573}
574
void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results
      .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
              patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
              patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
}
582
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
584 if (auto rhsCst = getConstant(adaptor.getRhs())) {
585 /// or(x, 0) -> x
586 if (rhsCst->isZero() && getLhs().getType() == getType())
587 return getLhs();
588
589 /// or(x, -1) -> -1
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
592 return getRhs();
593 }
594
595 if (auto lhsCst = getConstant(adaptor.getLhs())) {
596 /// or(0, x) -> x
597 if (lhsCst->isZero() && getRhs().getType() == getType())
598 return getRhs();
599
600 /// or(-1, x) -> -1
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
603 return getLhs();
604 }
605
606 /// or(x, x) -> x
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
608 return getRhs();
609
611 *this, adaptor.getOperands(), BinOpKind::Normal,
612 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
613}
614
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
                 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
                 patterns::OrOrr>(context);
}
621
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
623 /// xor(x, 0) -> x
624 if (auto rhsCst = getConstant(adaptor.getRhs()))
625 if (rhsCst->isZero() &&
626 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
627 return getLhs();
628
629 /// xor(x, 0) -> x
630 if (auto lhsCst = getConstant(adaptor.getLhs()))
631 if (lhsCst->isZero() &&
632 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
633 return getRhs();
634
635 /// xor(x, x) -> 0
636 if (getLhs() == getRhs())
637 return getIntAttr(
638 getType(),
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
640
642 *this, adaptor.getOperands(), BinOpKind::Normal,
643 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
644}
645
void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::extendXor, patterns::moveConstXor,
                 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
      context);
}
652
void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::LEQWithConstLHS>(context);
}
657
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
660
661 // leq(x, x) -> 1
662 if (getLhs() == getRhs())
663 return getIntAttr(getType(), APInt(1, 1));
664
665 // Comparison against constant outside type bounds.
666 if (auto width = getLhs().getType().base().getWidth()) {
667 if (auto rhsCst = getConstant(adaptor.getRhs())) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
670
671 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
672 // This can never occur since const is unsigned and cannot be less than 0.
673
674 // leq(x, const) -> 0 where const < minValue of the signed type of x
675 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
676 .slt(getMinSignedValue(*width).sext(commonWidth)))
677 return getIntAttr(getType(), APInt(1, 0));
678
679 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
680 if (isUnsigned && rhsCst->zext(commonWidth)
681 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
682 return getIntAttr(getType(), APInt(1, 1));
683
684 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
685 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
686 .sge(getMaxSignedValue(*width).sext(commonWidth)))
687 return getIntAttr(getType(), APInt(1, 1));
688 }
689 }
690
692 *this, adaptor.getOperands(), BinOpKind::Compare,
693 [=](const APSInt &a, const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
695 });
696}
697
void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::LTWithConstLHS>(context);
}
702
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
705 bool isUnsigned = lhsType.isUnsigned();
706
707 // lt(x, x) -> 0
708 if (getLhs() == getRhs())
709 return getIntAttr(getType(), APInt(1, 0));
710
711 // lt(x, 0) -> 0 when x is unsigned
712 if (auto rhsCst = getConstant(adaptor.getRhs())) {
713 if (rhsCst->isZero() && lhsType.isUnsigned())
714 return getIntAttr(getType(), APInt(1, 0));
715 }
716
717 // Comparison against constant outside type bounds.
718 if (auto width = lhsType.getWidth()) {
719 if (auto rhsCst = getConstant(adaptor.getRhs())) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
722
723 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
724 // Handled explicitly above.
725
726 // lt(x, const) -> 0 where const <= minValue of the signed type of x
727 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
728 .sle(getMinSignedValue(*width).sext(commonWidth)))
729 return getIntAttr(getType(), APInt(1, 0));
730
731 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
732 if (isUnsigned && rhsCst->zext(commonWidth)
733 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
734 return getIntAttr(getType(), APInt(1, 1));
735
736 // lt(x, const) -> 1 where const > maxValue of the signed type of x
737 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
738 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
739 return getIntAttr(getType(), APInt(1, 1));
740 }
741 }
742
744 *this, adaptor.getOperands(), BinOpKind::Compare,
745 [=](const APSInt &a, const APSInt &b) -> APInt {
746 return APInt(1, a < b);
747 });
748}
749
void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::GEQWithConstLHS>(context);
}
754
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
757 bool isUnsigned = lhsType.isUnsigned();
758
759 // geq(x, x) -> 1
760 if (getLhs() == getRhs())
761 return getIntAttr(getType(), APInt(1, 1));
762
763 // geq(x, 0) -> 1 when x is unsigned
764 if (auto rhsCst = getConstant(adaptor.getRhs())) {
765 if (rhsCst->isZero() && isUnsigned)
766 return getIntAttr(getType(), APInt(1, 1));
767 }
768
769 // Comparison against constant outside type bounds.
770 if (auto width = lhsType.getWidth()) {
771 if (auto rhsCst = getConstant(adaptor.getRhs())) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
774
775 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
776 if (isUnsigned && rhsCst->zext(commonWidth)
777 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
778 return getIntAttr(getType(), APInt(1, 0));
779
780 // geq(x, const) -> 0 where const > maxValue of the signed type of x
781 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
782 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
783 return getIntAttr(getType(), APInt(1, 0));
784
785 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
786 // Handled explicitly above.
787
788 // geq(x, const) -> 1 where const <= minValue of the signed type of x
789 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
790 .sle(getMinSignedValue(*width).sext(commonWidth)))
791 return getIntAttr(getType(), APInt(1, 1));
792 }
793 }
794
796 *this, adaptor.getOperands(), BinOpKind::Compare,
797 [=](const APSInt &a, const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
799 });
800}
801
void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Declarative patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::GTWithConstLHS>(context);
}
806
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
809 bool isUnsigned = lhsType.isUnsigned();
810
811 // gt(x, x) -> 0
812 if (getLhs() == getRhs())
813 return getIntAttr(getType(), APInt(1, 0));
814
815 // Comparison against constant outside type bounds.
816 if (auto width = lhsType.getWidth()) {
817 if (auto rhsCst = getConstant(adaptor.getRhs())) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
820
821 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
822 if (isUnsigned && rhsCst->zext(commonWidth)
823 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
824 return getIntAttr(getType(), APInt(1, 0));
825
826 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
827 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
828 .sge(getMaxSignedValue(*width).sext(commonWidth)))
829 return getIntAttr(getType(), APInt(1, 0));
830
831 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
832 // This can never occur since const is unsigned and cannot be less than 0.
833
834 // gt(x, const) -> 1 where const < minValue of the signed type of x
835 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
836 .slt(getMinSignedValue(*width).sext(commonWidth)))
837 return getIntAttr(getType(), APInt(1, 1));
838 }
839 }
840
842 *this, adaptor.getOperands(), BinOpKind::Compare,
843 [=](const APSInt &a, const APSInt &b) -> APInt {
844 return APInt(1, a > b);
845 });
846}
847
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
849 // eq(x, x) -> 1
850 if (getLhs() == getRhs())
851 return getIntAttr(getType(), APInt(1, 1));
852
853 if (auto rhsCst = getConstant(adaptor.getRhs())) {
854 /// eq(x, 1) -> x when x is 1 bit.
855 /// TODO: Support SInt<1> on the LHS etc.
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
858 return getLhs();
859 }
860
862 *this, adaptor.getOperands(), BinOpKind::Compare,
863 [=](const APSInt &a, const APSInt &b) -> APInt {
864 return APInt(1, a == b);
865 });
866}
867
// Rewrite eq against a constant RHS into cheaper single-bit forms (not /
// not(orr) / andr) when the operand widths allow it.
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          // (Both operand types must already equal the 1-bit result type.)
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        // No applicable pattern; leave the op alone.
        return {};
      });
}
897
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
899 // neq(x, x) -> 0
900 if (getLhs() == getRhs())
901 return getIntAttr(getType(), APInt(1, 0));
902
903 if (auto rhsCst = getConstant(adaptor.getRhs())) {
904 /// neq(x, 0) -> x when x is 1 bit.
905 /// TODO: Support SInt<1> on the LHS etc.
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
908 return getLhs();
909 }
910
912 *this, adaptor.getOperands(), BinOpKind::Compare,
913 [=](const APSInt &a, const APSInt &b) -> APInt {
914 return APInt(1, a != b);
915 });
916}
917
// Rewrite neq against a constant RHS into cheaper single-bit forms (not /
// orr / not(andr)) when the operand widths allow it.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          // (Both operand types must already equal the 1-bit result type.)
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x))) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp =
                AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
          }
        }

        // No applicable pattern; leave the op alone.
        return {};
      });
}
949
// No folds are implemented for arbitrary-precision integer addition yet.
OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6696.
  return {};
}
955
// No folds are implemented for arbitrary-precision integer multiplication yet.
OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6724.
  return {};
}
961
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
963 if (auto rhsCst = getConstant(adaptor.getRhs())) {
964 if (auto lhsCst = getConstant(adaptor.getLhs())) {
965
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
970 }
971
972 if (rhsCst->isZero())
973 return getLhs();
974 }
975
976 return {};
977}
978
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
980 if (auto rhsCst = getConstant(adaptor.getRhs())) {
981 // Constant folding
982 if (auto lhsCst = getConstant(adaptor.getLhs()))
983
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
988
989 // integer.shl(x, 0) -> x
990 if (rhsCst->isZero())
991 return getLhs();
992 }
993
994 return {};
995}
996
997//===----------------------------------------------------------------------===//
998// Unary Operators
999//===----------------------------------------------------------------------===//
1000
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1003 auto w = getBitWidth(base);
1004 if (w)
1005 return getIntAttr(getType(), APInt(32, *w));
1006 return {};
1007}
1008
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1010 // No constant can be 'x' by definition.
1011 if (auto cst = getConstant(adaptor.getArg()))
1012 return getIntAttr(getType(), APInt(1, 0));
1013 return {};
1014}
1015
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1017 // No effect.
1018 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1019 return getInput();
1020
1021 // Be careful to only fold the cast into the constant if the size is known.
1022 // Otherwise width inference may produce differently-sized constants if the
1023 // sign changes.
1024 if (getType().base().hasWidth())
1025 if (auto cst = getConstant(adaptor.getInput()))
1026 return getIntAttr(getType(), *cst);
1027
1028 return {};
1029}
1030
void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Declarative (TableGen-generated) cast-chain pattern.
  results.insert<patterns::StoUtoS>(context);
}
1035
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1037 // No effect.
1038 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1039 return getInput();
1040
1041 // Be careful to only fold the cast into the constant if the size is known.
1042 // Otherwise width inference may produce differently-sized constants if the
1043 // sign changes.
1044 if (getType().base().hasWidth())
1045 if (auto cst = getConstant(adaptor.getInput()))
1046 return getIntAttr(getType(), *cst);
1047
1048 return {};
1049}
1050
void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Declarative (TableGen-generated) cast-chain pattern.
  results.insert<patterns::UtoStoU>(context);
}
1055
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1057 // No effect.
1058 if (getInput().getType() == getType())
1059 return getInput();
1060
1061 // Constant fold.
1062 if (auto cst = getConstant(adaptor.getInput()))
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1064
1065 return {};
1066}
1067
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1069 if (auto cst = getConstant(adaptor.getInput()))
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1071 return {};
1072}
1073
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1075 // No effect.
1076 if (getInput().getType() == getType())
1077 return getInput();
1078
1079 // Constant fold.
1080 if (auto cst = getConstant(adaptor.getInput()))
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1082
1083 return {};
1084}
1085
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1087 if (!hasKnownWidthIntTypes(*this))
1088 return {};
1089
1090 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1091 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1092 getType().base().getWidthOrSentinel()))
1093 return getIntAttr(getType(), *cst);
1094
1095 return {};
1096}
1097
void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (TableGen-generated) patterns for signed/unsigned cvt.
  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
}
1102
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1104 if (!hasKnownWidthIntTypes(*this))
1105 return {};
1106
1107 // FIRRTL negate always adds a bit.
1108 // -x ---> 0-sext(x) or 0-zext(x)
1109 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1112
1113 return {};
1114}
1115
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1117 if (!hasKnownWidthIntTypes(*this))
1118 return {};
1119
1120 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1121 getType().base().getWidthOrSentinel()))
1122 return getIntAttr(getType(), ~*cst);
1123
1124 return {};
1125}
1126
void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (TableGen-generated) patterns: double negation and
  // not-of-comparison rewrites.
  results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
                 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
                 patterns::NotGt>(context);
}
1133
/// Common base pattern for reduction ops (andr/orr/xorr) applied to a cat.
/// Subclasses decide, per constant cat operand, whether the whole reduction
/// collapses to a constant or whether the operand can be dropped/kept; this
/// base class then rebuilds the cat (or the reduction input) accordingly.
class ReductionCat : public mlir::RewritePattern {
public:
  ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
      : RewritePattern(opName, 0, context) {}

  /// Handle a constant operand in the cat operation.
  /// Returns true if the entire reduction can be replaced with a constant.
  /// May add non-zero constants to the remaining operands list.
  virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
                              ConstantOp constantOp,
                              SmallVectorImpl<Value> &remaining) const = 0;

  /// Return the unit value for this reduction operation:
  virtual bool getIdentityValue() const = 0;

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Check if the operand is a cat operation
    auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
    if (!catOp)
      return failure();

    SmallVector<Value> nonConstantOperands;

    // Process each operand of the cat operation
    for (auto operand : catOp.getInputs()) {
      if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
        // Handle constant operands - may short-circuit the entire operation
        if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
          return success();
      } else {
        // Keep non-constant operands for further processing
        nonConstantOperands.push_back(operand);
      }
    }

    // If no operands remain, replace with identity value
    if (nonConstantOperands.empty()) {
      replaceOpWithNewOpAndCopyName<ConstantOp>(
          rewriter, op, cast<IntType>(op->getResult(0).getType()),
          APInt(1, getIdentityValue()));
      return success();
    }

    // If only one operand remains, apply reduction directly to it
    if (nonConstantOperands.size() == 1) {
      rewriter.modifyOpInPlace(
          op, [&] { op->setOperand(0, nonConstantOperands.front()); });
      return success();
    }

    // Multiple operands remain - optimize only when cat has a single use.
    // (Otherwise shrinking the cat would pessimize its other users.)
    if (catOp->hasOneUse() &&
        nonConstantOperands.size() < catOp->getNumOperands()) {
      replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
                                               nonConstantOperands);
      return success();
    }
    return failure();
  }
};
1196
1197class OrRCat : public ReductionCat {
1198public:
1199 OrRCat(MLIRContext *context)
1200 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1201 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1202 ConstantOp value,
1203 SmallVectorImpl<Value> &remaining) const override {
1204 if (value.getValue().isZero())
1205 return false;
1206
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1209 llvm::APInt(1, 1));
1210 return true;
1211 }
1212 bool getIdentityValue() const override { return false; }
1213};
1214
1215class AndRCat : public ReductionCat {
1216public:
1217 AndRCat(MLIRContext *context)
1218 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1219 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1220 ConstantOp value,
1221 SmallVectorImpl<Value> &remaining) const override {
1222 if (value.getValue().isAllOnes())
1223 return false;
1224
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1227 llvm::APInt(1, 0));
1228 return true;
1229 }
1230 bool getIdentityValue() const override { return true; }
1231};
1232
1233class XorRCat : public ReductionCat {
1234public:
1235 XorRCat(MLIRContext *context)
1236 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1237 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1238 ConstantOp value,
1239 SmallVectorImpl<Value> &remaining) const override {
1240 if (value.getValue().isZero())
1241 return false;
1242 remaining.push_back(value);
1243 return false;
1244 }
1245 bool getIdentityValue() const override { return false; }
1246};
1247
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1249 if (!hasKnownWidthIntTypes(*this))
1250 return {};
1251
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1253 return getIntAttr(getType(), APInt(1, 1));
1254
1255 // x == -1
1256 if (auto cst = getConstant(adaptor.getInput()))
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1258
1259 // one bit is identity. Only applies to UInt since we can't make a cast
1260 // here.
1261 if (isUInt1(getInput().getType()))
1262 return getInput();
1263
1264 return {};
1265}
1266
void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative (TableGen-generated) patterns plus the AndRCat pattern
  // defined above.
  results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
                 patterns::AndRPadS, patterns::AndRCatAndR_left,
                 patterns::AndRCatAndR_right, AndRCat>(context);
}
1273
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1275 if (!hasKnownWidthIntTypes(*this))
1276 return {};
1277
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1279 return getIntAttr(getType(), APInt(1, 0));
1280
1281 // x != 0
1282 if (auto cst = getConstant(adaptor.getInput()))
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1284
1285 // one bit is identity. Only applies to UInt since we can't make a cast
1286 // here.
1287 if (isUInt1(getInput().getType()))
1288 return getInput();
1289
1290 return {};
1291}
1292
void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (TableGen-generated) patterns plus the OrRCat pattern
  // defined above.
  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
                 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
      context);
}
1299
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1301 if (!hasKnownWidthIntTypes(*this))
1302 return {};
1303
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1305 return getIntAttr(getType(), APInt(1, 0));
1306
1307 // popcount(x) & 1
1308 if (auto cst = getConstant(adaptor.getInput()))
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1310
1311 // one bit is identity. Only applies to UInt since we can't make a cast here.
1312 if (isUInt1(getInput().getType()))
1313 return getInput();
1314
1315 return {};
1316}
1317
void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative (TableGen-generated) patterns plus the XorRCat pattern
  // defined above.
  results
      .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
              patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
          context);
}
1325
1326//===----------------------------------------------------------------------===//
1327// Other Operators
1328//===----------------------------------------------------------------------===//
1329
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1333
1334 // If no inputs, return 0-bit value
1335 if (inputs.empty())
1336 return getIntZerosAttr(getType());
1337
1338 // If single input and same type, return it
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1340 return inputs[0];
1341
1342 // Make sure it's safe to fold.
1343 if (!hasKnownWidthIntTypes(*this))
1344 return {};
1345
1346 // Filter out zero-width operands
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant = true;
1350 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1354 if (!attr)
1355 allConstant = false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1357 return {};
1358 }
1359 }
1360
1361 // If all inputs were zero-width, return 0-bit value
1362 if (nonZeroInputs.empty())
1363 return getIntZerosAttr(getType());
1364
1365 // If only one non-zero input and it has the same type as result, return it
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1368
1369 if (!hasKnownWidthIntTypes(*this))
1370 return {};
1371
1372 // Constant fold cat - concatenate all constant operands
1373 SmallVector<APInt> constants;
1374 for (auto inputAdaptor : inputAdaptors) {
1375 if (auto cst = getConstant(inputAdaptor))
1376 constants.push_back(*cst);
1377 else
1378 return {}; // Not all operands are constant
1379 }
1380
1381 assert(!constants.empty());
1382 // Concatenate all constants from left to right
1383 APInt result = constants[0];
1384 for (size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
1386
1387 return getIntAttr(getType(), result);
1388}
1389
void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative (TableGen-generated) pattern for shifts by a constant.
  results.insert<patterns::DShlOfConstant>(context);
}
1394
void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative (TableGen-generated) pattern for shifts by a constant.
  results.insert<patterns::DShrOfConstant>(context);
}
1399
1400namespace {
1401
/// Canonicalize a cat tree into single cat operation.
/// Walks nested CatPrimOps from the root, collects their leaf operands in
/// order, and rebuilds a single flat cat (casting signed leaves to unsigned
/// when the leaf types are mixed).
class FlattenCat : public mlir::OpRewritePattern<CatPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(CatPrimOp cat,
                  mlir::PatternRewriter &rewriter) const override {
    // Require known, non-zero result width.
    if (!hasKnownWidthIntTypes(cat) ||
        cat.getType().getBitWidthOrSentinel() == 0)
      return failure();

    // If this is not a root of concat tree, skip.
    // (A cat whose only user is another cat will be flattened from above.)
    if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
      return failure();

    // Try flattening the cat tree.
    // The worklist is filled in reverse so popping yields operands in their
    // original left-to-right order.
    SmallVector<Value> operands;
    SmallVector<Value> worklist;
    auto pushOperands = [&worklist](CatPrimOp op) {
      for (auto operand : llvm::reverse(op.getInputs()))
        worklist.push_back(operand);
    };
    pushOperands(cat);
    bool hasSigned = false, hasUnsigned = false;
    while (!worklist.empty()) {
      auto value = worklist.pop_back_val();
      auto catOp = value.getDefiningOp<CatPrimOp>();
      if (!catOp) {
        // Leaf operand: record it and note its signedness.
        operands.push_back(value);
        (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
        continue;
      }

      // Nested cat: expand it in place.
      pushOperands(catOp);
    }

    // Helper function to cast signed values to unsigned. CatPrimOp converts
    // signed values to unsigned so we need to do it explicitly here.
    auto castToUIntIfSigned = [&](Value value) -> Value {
      if (type_isa<UIntType>(value.getType()))
        return value;
      return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
    };

    assert(operands.size() >= 1 && "zero width cast must be rejected");

    if (operands.size() == 1) {
      rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
      return success();
    }

    // Nothing was flattened if the operand count did not change.
    if (operands.size() == cat->getNumOperands())
      return failure();

    // If types are mixed, cast all operands to unsigned.
    if (hasSigned && hasUnsigned)
      for (auto &operand : operands)
        operand = castToUIntIfSigned(operand);

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
                                             operands);
    return success();
  }
};
1467
// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
// Runs of adjacent ConstantOp operands are merged into a single wider
// constant; non-constant operands are kept as-is.
class CatOfConstant : public mlir::OpRewritePattern<CatPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(CatPrimOp cat,
                  mlir::PatternRewriter &rewriter) const override {
    if (!hasKnownWidthIntTypes(cat))
      return failure();

    SmallVector<Value> operands;

    for (size_t i = 0; i < cat->getNumOperands(); ++i) {
      auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
      if (!cst) {
        operands.push_back(cat.getInputs()[i]);
        continue;
      }
      // Greedily absorb the following constants into `value`; j ends at the
      // first non-constant operand (or the end).
      APSInt value = cst.getValue();
      size_t j = i + 1;
      for (; j < cat->getNumOperands(); ++j) {
        auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
        if (!nextCst)
          break;
        value = value.concat(nextCst.getValue());
      }

      if (j == i + 1) {
        // Not folded.
        operands.push_back(cst);
      } else {
        // Folded.
        operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
      }

      // Resume scanning after the absorbed run (loop increment adds 1).
      i = j - 1;
    }

    // No run of constants was merged.
    if (operands.size() == cat->getNumOperands())
      return failure();

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
                                             operands);

    return success();
  }
};
1516
1517} // namespace
1518
void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (TableGen-generated) patterns plus the locally defined
  // FlattenCat and CatOfConstant patterns.
  results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
                 patterns::CatCast, FlattenCat, CatOfConstant>(context);
}
1524
1525//===----------------------------------------------------------------------===//
1526// StringConcatOp
1527//===----------------------------------------------------------------------===//
1528
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1530 // Fold single-operand concat to just the operand.
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1533
1534 // Check if all operands are constant strings before accumulating.
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1537 }))
1538 return {};
1539
1540 // All operands are constant strings, concatenate them.
1541 SmallString<64> result;
1542 for (auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1544
1545 return StringAttr::get(getContext(), result);
1546}
1547
1548namespace {
/// Flatten nested string.concat operations into a single concat.
/// string.concat(a, string.concat(b, c), d) -> string.concat(a, b, c, d)
class FlattenStringConcat : public mlir::OpRewritePattern<StringConcatOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(StringConcatOp concat,
                  mlir::PatternRewriter &rewriter) const override {

    // Check if any operands are nested concats with a single use. Only inline
    // single-use nested concats to avoid fighting with DCE.
    bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
      auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
      return nestedConcat && operand.hasOneUse();
    });

    if (!hasNestedConcat)
      return failure();

    // Flatten nested concats that have a single use.
    // Operand order is preserved: nested operands are spliced in place.
    SmallVector<Value> flatOperands;
    for (auto input : concat.getInputs()) {
      if (auto nestedConcat = input.getDefiningOp<StringConcatOp>();
          nestedConcat && input.hasOneUse())
        llvm::append_range(flatOperands, nestedConcat.getInputs());
      else
        flatOperands.push_back(input);
    }

    // Update the operand list in place; the inlined concats become dead and
    // are cleaned up separately.
    rewriter.modifyOpInPlace(concat,
                             [&]() { concat->setOperands(flatOperands); });
    return success();
  }
};
1584
/// Merge consecutive constant strings in a concat and remove empty strings.
/// string.concat("a", "b", x, "", "c", "d") -> string.concat("ab", x, "cd")
class MergeAdjacentStringConstants
    : public mlir::OpRewritePattern<StringConcatOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(StringConcatOp concat,
                  mlir::PatternRewriter &rewriter) const override {

    SmallVector<Value> newOperands;
    // Text and ops of the current run of adjacent literals, flushed whenever
    // a non-literal operand is encountered.
    SmallString<64> accumulatedLit;
    SmallVector<StringConstantOp> accumulatedOps;
    bool changed = false;

    auto flushLiterals = [&]() {
      if (accumulatedOps.empty())
        return;

      // If only one literal, reuse it.
      if (accumulatedOps.size() == 1) {
        newOperands.push_back(accumulatedOps[0]);
      } else {
        // Multiple literals - merge them.
        auto newLit = rewriter.createOrFold<StringConstantOp>(
            concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
        newOperands.push_back(newLit);
        changed = true;
      }
      accumulatedLit.clear();
      accumulatedOps.clear();
    };

    for (auto operand : concat.getInputs()) {
      if (auto litOp = operand.getDefiningOp<StringConstantOp>()) {
        // Skip empty strings.
        if (litOp.getValue().empty()) {
          changed = true;
          continue;
        }
        accumulatedLit += litOp.getValue();
        accumulatedOps.push_back(litOp);
      } else {
        // Non-literal ends the current run of literals.
        flushLiterals();
        newOperands.push_back(operand);
      }
    }

    // Flush any remaining literals.
    flushLiterals();

    if (!changed)
      return failure();

    // If no operands remain, replace with empty string.
    if (newOperands.empty())
      return rewriter.replaceOpWithNewOp<StringConstantOp>(
                 concat, StringAttr::get(getContext(), "")),
             success();

    // Single-operand case is handled by the folder.
    rewriter.modifyOpInPlace(concat,
                             [&]() { concat->setOperands(newOperands); });
    return success();
  }
};
1652
1653} // namespace
1654
void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Locally defined patterns: flatten nested concats, then merge adjacent
  // constant strings.
  results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(context);
}
1659
1660OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1661 auto op = (*this);
1662 // BitCast is redundant if input and result types are same.
1663 if (op.getType() == op.getInput().getType())
1664 return op.getInput();
1665
1666 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1667 // final result type.
1668 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1669 if (op.getType() == in.getInput().getType())
1670 return in.getInput();
1671
1672 return {};
1673}
1674
1675OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1676 IntType inputType = getInput().getType();
1677 IntType resultType = getType();
1678 // If we are extracting the entire input, then return it.
1679 if (inputType == getType() && resultType.hasWidth())
1680 return getInput();
1681
1682 // Constant fold.
1683 if (hasKnownWidthIntTypes(*this))
1684 if (auto cst = getConstant(adaptor.getInput()))
1685 return getIntAttr(resultType,
1686 cst->extractBits(getHi() - getLo() + 1, getLo()));
1687
1688 return {};
1689}
1690
/// bits(cat(...)) -> bits(operand) when the extracted range lies entirely
/// inside a single cat operand.
struct BitsOfCat : public mlir::OpRewritePattern<BitsPrimOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(BitsPrimOp bits,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
    if (!cat)
      return failure();
    // Walk cat operands from least-significant (rightmost) upward, tracking
    // the extraction's low bit relative to the current operand.
    int32_t bitPos = bits.getLo();
    auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
    if (resultWidth < 0)
      return failure();
    for (auto operand : llvm::reverse(cat.getInputs())) {
      auto operandWidth =
          type_cast<IntType>(operand.getType()).getWidthOrSentinel();
      if (operandWidth < 0)
        return failure();
      if (bitPos < operandWidth) {
        // The range starts in this operand; only rewrite if it also ends
        // here, otherwise the extraction straddles operands.
        if (bitPos + resultWidth <= operandWidth) {
          auto newBits = rewriter.createOrFold<BitsPrimOp>(
              bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
          replaceOpAndCopyName(rewriter, bits, newBits);
          return success();
        }
        return failure();
      }
      bitPos -= operandWidth;
    }
    return failure();
  }
};
1723
void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative (TableGen-generated) patterns plus the BitsOfCat pattern
  // defined above.
  results
      .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
              patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
}
1730
1731/// Replace the specified operation with a 'bits' op from the specified hi/lo
1732/// bits. Insert a cast to handle the case where the original operation
1733/// returned a signed integer.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only insert the bits op when the widths actually differ.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);

  // Reconcile signedness: bits always yields unsigned, so cast back to the
  // original result's signedness if needed.
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1748
/// Shared fold logic for 2-way mux-like ops (MuxPrimOp, Mux2CellIntrinsicOp).
/// OpTy must expose getSel()/getHigh()/getLow() and an IntType-like result.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1796
// Delegates to the shared foldMux logic above.
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1800
// Delegates to the shared foldMux logic above.
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1804
// No folds are implemented for the 4-way mux intrinsic.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1806
1807namespace {
1808
// If the mux has a known output width, pad the operands up to this width.
// Most folds on mux require that folded operands are of the same width as
// the mux itself.
class MuxPad : public mlir::OpRewritePattern<MuxPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(MuxPrimOp mux,
                  mlir::PatternRewriter &rewriter) const override {
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Returns the input unchanged when its width is unknown or already
    // matches; otherwise wraps it in a pad to the mux's width.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
                               width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // No operand needed padding; nothing to do.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1844
1845// Find muxes which have conditions dominated by other muxes with the same
1846// condition.
class MuxSharedCond : public mlir::OpRewritePattern<MuxPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Maximum recursion depth when walking nested mux trees, to bound work.
  static const int depthLimit = 5;

  // Either rewrite `mux` in place with the new high/low operands, or clone a
  // fresh mux after it (when the original has other users and must be kept).
  // Returns the replacement value, or null when updated in place.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
                             ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  // Returns a simplified replacement for `op` under that assumption, or null.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same condition and it is known true: the mux always selects `high`.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // Only safe to mutate in place if every mux on the path has one user.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  // Mirror image of tryCondTrue: a matching selector always picks `low`.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(MuxPrimOp mux,
                  mlir::PatternRewriter &rewriter) const override {
    // Require a known result width before restructuring.
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // In the high branch the selector is known true; in the low branch it is
    // known false. Simplify nested muxes sharing the same selector.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1934} // namespace
1935
/// Register the C++ patterns above plus the tablegen-generated mux patterns.
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
           patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
           patterns::MuxSameTrue, patterns::MuxSameFalse,
           patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
          context);
}
1945
/// Only selector padding is canonicalized for the 2-input mux cell intrinsic.
void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux2PadSel>(context);
}
1950
/// Only selector padding is canonicalized for the 4-input mux cell intrinsic.
void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux4PadSel>(context);
}
1955
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // Sign-extend signed constants; a zero-width constant has no sign bit,
    // so it falls through to zero extension.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1982
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold: the result width is inputWidth + shiftAmount, so extend
  // the constant to that width before shifting.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      auto resultWidth = inputWidth + shiftAmount;
      // Clamp defensively so the shift never exceeds the result width.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
2003
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  // Width of the input; -1 means unknown, 0 means a zero-width value.
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  if (inputWidth == -1)
    return {};
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // Signed shifts replicate the sign bit (arithmetic shift), and are
    // clamped so at least the sign bit survives; unsigned shifts fill zeros.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // The FIRRTL result width of shr never drops below 1 bit here.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
2038
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  // Require a known, non-zero input width.
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().base().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  // Extract bits [inputWidth-1 : shiftAmount], i.e. the surviving high bits.
  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
2060
2061LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2062 PatternRewriter &rewriter) {
2063 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2064 if (inputWidth <= 0)
2065 return failure();
2066
2067 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2068 unsigned keepAmount = op.getAmount();
2069 if (keepAmount)
2070 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2071 rewriter);
2072 return success();
2073}
2074
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  // Constant fold: head(cst, n) keeps the top n bits, i.e. shift right by
  // (inputWidth - n) and truncate to n bits.
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput())) {
      int shiftAmount =
          getInput().getType().base().getWidthOrSentinel() - getAmount();
      return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
    }

  return {};
}
2085
OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
  // Constant fold: tail drops high bits, so truncating the constant to the
  // result width is sufficient.
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput()))
      return getIntAttr(getType(),
                        cst->trunc(getType().base().getWidthOrSentinel()));
  return {};
}
2093
2094LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2095 PatternRewriter &rewriter) {
2096 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2097 if (inputWidth <= 0)
2098 return failure();
2099
2100 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2101 unsigned dropAmount = op.getAmount();
2102 if (dropAmount != unsigned(inputWidth))
2103 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
2104 rewriter);
2105 return success();
2106}
2107
/// A subaccess with a constant index canonicalizes to a static subindex.
void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<patterns::SubaccessOfConstant>(context);
}
2112
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it. Operand 0 is the index, so
  // the single data input is operand 1.
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  // Constant index: select statically. Inputs are stored with the highest
  // index first, hence the reversed subscript.
  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
2126
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements. A w-bit index can only address 2^w inputs, and inputs are
  // stored highest-index first, so the unreachable ones are at the front.
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Check that every input is a subindex of the same vector and that the
    // subindex positions count down as the operand position counts up.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2178
2179//===----------------------------------------------------------------------===//
2180// Declarations
2181//===----------------------------------------------------------------------===//
2182
2183/// Scan all the uses of the specified value, checking to see if there is
2184/// exactly one connect that has the value as its destination. This returns the
2185/// operation if found and if all the other users are "reads" from the value.
2186/// Returns null if there are no connects, or multiple connects to the value, or
2187/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2188///
2189/// Note that this will simply return the connect, which is located *anywhere*
2190/// after the definition of the value. Users of this function are likely
2191/// interested in the source side of the returned connect, the definition of
2192/// which does likely not dominate the original value.
MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
  MatchingConnectOp connect;
  for (Operation *user : value.getUsers()) {
    // If we see an attach or aggregate subelements, just conservatively fail.
    if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
      return {};

    if (auto aConnect = dyn_cast<FConnectLike>(user))
      if (aConnect.getDest() == value) {
        auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
        // If this is not a matching connect, a second matching connect or in a
        // different block, fail.
        if (!matchingConnect || (connect && connect != matchingConnect) ||
            matchingConnect->getBlock() != value.getParentBlock())
          return {};
        connect = matchingConnect;
      }
  }
  // Null if no connect was found; otherwise the unique matching connect.
  return connect;
}
2213
2214// Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Bail on declarations whose identity must be preserved: don't-touch,
  // annotated, non-droppable names, or forceable.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
2272
/// Register the tablegen-generated connect patterns (extension folding and
/// same-type conversion).
void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
      context);
}
2278
2279LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2280 PatternRewriter &rewriter) {
2281 // TODO: Canonicalize towards explicit extensions and flips here.
2282
2283 // If there is a simple value connected to a foldable decl like a wire or reg,
2284 // see if we can eliminate the decl.
2285 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2286 return success();
2287 return failure();
2288}
2289
2290//===----------------------------------------------------------------------===//
2291// Statements
2292//===----------------------------------------------------------------------===//
2293
2294/// If the specified value has an AttachOp user strictly dominating by
2295/// "dominatingAttach" then return it.
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
  // Scan all users for a *different* attach earlier in the same block.
  for (auto *user : value.getUsers()) {
    auto attach = dyn_cast<AttachOp>(user);
    if (!attach || attach == dominatedAttach)
      continue;
    if (attach->isBeforeInBlock(dominatedAttach))
      return attach;
  }
  return {};
}
2306
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the wire, then drop both ops.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire itself.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2352
2353/// Replaces the given op with the contents of the given single-block region.
/// Replaces the given op with the contents of the given single-block region.
/// The block's operations are spliced in directly before `op`; the caller is
/// responsible for erasing `op` itself.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
2359
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // Constant condition: inline the branch that is taken and delete the when.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2392
2393namespace {
2394// Remove private nodes. If they have an interesting names, move the name to
2395// the source expression.
struct FoldNodeName : public mlir::OpRewritePattern<NodeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NodeOp node,
                                PatternRewriter &rewriter) const override {
    auto name = node.getNameAttr();
    // Keep nodes that carry meaning: named, symboled, annotated, or forceable.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).empty() || node.isForceable())
      return failure();
    // Transfer the name hint to the defining op of the input, if any.
    auto *newOp = node.getInput().getDefiningOp();
    if (newOp)
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
2411
2412// Bypass nodes.
// Bypass nodes.
struct NodeBypass : public mlir::OpRewritePattern<NodeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NodeOp node,
                                PatternRewriter &rewriter) const override {
    // Skip nodes with symbols/annotations, already-dead nodes, and forceable
    // nodes.
    if (node.getInnerSym() || !AnnotationSet(node).empty() ||
        node.use_empty() || node.isForceable())
      return failure();
    // Forward the input to all users. The node op itself is left in place
    // (presumably cleaned up by the NodeOp fold/DCE once unused — verify).
    rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
    return success();
  }
};
2424
2425} // namespace
2426
/// If the forceable ref result of `op` is never used, rebuild the op without
/// forceability so later patterns (which bail on forceable ops) can fire.
template <typename OpTy>
static LogicalResult demoteForceableIfUnused(OpTy op,
                                             PatternRewriter &rewriter) {
  if (!op.isForceable() || !op.getDataRef().use_empty())
    return failure();

  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
  return success();
}
2436
2437// Interesting names and symbols and don't touch force nodes to stick around.
// Interesting names and symbols and don't touch force nodes to stick around.
LogicalResult NodeOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  // Keep nodes whose name carries meaning.
  if (!hasDroppableName())
    return failure();
  if (hasDontTouch(getResult())) // handles inner symbols
    return failure();
  // Keep annotated nodes.
  if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
    return failure();
  // Keep forceable nodes: their ref result may be in use.
  if (isForceable())
    return failure();
  // Can only fold when the input folded to a constant/value.
  if (!adaptor.getInput())
    return failure();

  results.push_back(adaptor.getInput());
  return success();
}
2454
/// Register node canonicalizations: name folding and forceability demotion.
void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<FoldNodeName>(context);
  results.add(demoteForceableIfUnused<NodeOp>);
}
2460
2461namespace {
2462// For a lhs, find all the writers of fields of the aggregate type. If there
2463// is one writer for each field, merge the writes
struct AggOneShot : public mlir::RewritePattern {
  // Note: `weight` is accepted but the pattern is registered with benefit 0.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Collect the per-field sources written to `lhs`, one connect per field,
  // all in the same region as `lhs`. Returns the sources ordered by field
  // index, or an empty vector if the writes are not a complete one-shot set.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
        // A whole-aggregate write defeats per-field merging.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        // Bundle: each subfield must be written exactly once, by a matching
        // connect in the same region.
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        // Vector: same rule per element index.
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          return {};
        }
      } else {
        return {};
      }
    }

    // Require every field/element to have exactly one recorded write.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    // Insert the merged write at the end of the block, after all field
    // sources.
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot matchingconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Build the whole aggregate value and connect it in one shot.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
2562
// Merge per-field writes to a wire into a single aggregate connect.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
// Merge per-field writes to a subindex result into one aggregate connect.
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
// Merge per-field writes to a subfield result into one aggregate connect.
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2575} // namespace
2576
/// Register wire canonicalizations: one-shot aggregate write merging and
/// forceability demotion.
void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<WireAggOneShot>(context);
  results.add(demoteForceableIfUnused<WireOp>);
}
2582
/// Register subindex canonicalizations (one-shot aggregate write merging).
void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubindexAggOneShot>(context);
}
2587
2588OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2589 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2590 if (!attr)
2591 return {};
2592 return attr[getIndex()];
2593}
2594
2595OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2596 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2597 if (!attr)
2598 return {};
2599 auto index = getFieldIndex();
2600 return attr[index];
2601}
2602
/// Register subfield canonicalizations (one-shot aggregate write merging).
void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubfieldAggOneShot>(context);
}
2607
2608static Attribute collectFields(MLIRContext *context,
2609 ArrayRef<Attribute> operands) {
2610 for (auto operand : operands)
2611 if (!operand)
2612 return {};
2613 return ArrayAttr::get(context, operands);
2614}
2615
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  // i.e. reassembling every field of one bundle, in order, is the identity.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all operands are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2634
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  // i.e. reassembling every element of one vector, in order, is the identity.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all operands are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2652
2653OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2654 if (getOperand().getType() == getType())
2655 return getOperand();
2656 return {};
2657}
2658
2659namespace {
2660// A register with constant reset and all connection to either itself or the
2661// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::OpRewritePattern<RegResetOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(RegResetOp reg,
                                PatternRewriter &rewriter) const override {
    // The reset value must be a constant, and the register must be plain
    // (no don't-touch, annotations, or forceability).
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The single driver must be a mux of (constant, self) in either order.
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must be the very same value as the reset constant,
    // otherwise the register can take two distinct values.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2712} // namespace
2713
2714static bool isDefinedByOneConstantOp(Value v) {
2715 if (auto c = v.getDefiningOp<ConstantOp>())
2716 return c.getValue().isOne();
2717 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2718 return sc.getValue();
2719 return false;
2720}
2721
/// A regreset whose reset signal is the constant 1 always holds its reset
/// value, so replace it with a node of that value.
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Only rewrite when the register and reset value types match exactly.
  auto resetValue = reg.getResetValue();
  if (reg.getType(0) != resetValue.getType())
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  // Preserve the register's name, annotations, symbol, and forceability on
  // the replacement node.
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2738
2739void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2740 MLIRContext *context) {
2741 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2743 results.add(demoteForceableIfUnused<RegResetOp>);
2744}
2745
2746// Returns the value connected to a port, if there is only one.
2747static Value getPortFieldValue(Value port, StringRef name) {
2748 auto portTy = type_cast<BundleType>(port.getType());
2749 auto fieldIndex = portTy.getElementIndex(name);
2750 assert(fieldIndex && "missing field on memory port");
2751
2752 Value value = {};
2753 for (auto *op : port.getUsers()) {
2754 auto portAccess = cast<SubfieldOp>(op);
2755 if (fieldIndex != portAccess.getFieldIndex())
2756 continue;
2757 auto conn = getSingleConnectUserOf(portAccess);
2758 if (!conn || value)
2759 return {};
2760 value = conn.getSrc();
2761 }
2762 return value;
2763}
2764
2765// Returns true if the enable field of a port is set to false.
2766static bool isPortDisabled(Value port) {
2767 auto value = getPortFieldValue(port, "en");
2768 if (!value)
2769 return false;
2770 auto portConst = value.getDefiningOp<ConstantOp>();
2771 if (!portConst)
2772 return false;
2773 return portConst.getValue().isZero();
2774}
2775
2776// Returns true if the data output is unused.
2777static bool isPortUnused(Value port, StringRef data) {
2778 auto portTy = type_cast<BundleType>(port.getType());
2779 auto fieldIndex = portTy.getElementIndex(data);
2780 assert(fieldIndex && "missing enable flag on memory port");
2781
2782 for (auto *op : port.getUsers()) {
2783 auto portAccess = cast<SubfieldOp>(op);
2784 if (fieldIndex != portAccess.getFieldIndex())
2785 continue;
2786 if (!portAccess.use_empty())
2787 return false;
2788 }
2789
2790 return true;
2791}
2792
2793// Returns the value connected to a port, if there is only one.
2794static void replacePortField(PatternRewriter &rewriter, Value port,
2795 StringRef name, Value value) {
2796 auto portTy = type_cast<BundleType>(port.getType());
2797 auto fieldIndex = portTy.getElementIndex(name);
2798 assert(fieldIndex && "missing field on memory port");
2799
2800 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2801 auto portAccess = cast<SubfieldOp>(op);
2802 if (fieldIndex != portAccess.getFieldIndex())
2803 continue;
2804 rewriter.replaceAllUsesWith(portAccess, value);
2805 rewriter.eraseOp(portAccess);
2806 }
2807}
2808
// Remove all accesses to a dead memory port, replacing any still-read values
// with dummy registers so the port itself ends up unused.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = SpecialConstantOp::create(rewriter, port.getLoc(),
                                        ClockType::get(rewriter.getContext()),
                                        false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire. If the port
  // is used in its entirety, replace it with a wire. Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      // Whole-port use: substitute a never-written register of the port's
      // bundle type and stop — the subfield cleanup below is not needed.
      auto ty = port.getType();
      auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2861
2862namespace {
// If memory has known, but zero width, eliminate it. A zero-bit memory holds
// no information: its read data is the (only) zero-width value and writes are
// no-ops, so all field accesses can become plain wires.
struct FoldZeroWidthMemory : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    // Never touch memories the user explicitly asked to preserve.
    if (hasDontTouch(mem))
      return failure();

    // Only integer-typed memories with a known width of exactly zero bits.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace: every port use must be a
    // subfield access (whole-bundle uses cannot be rewritten to wires here).
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : mem.getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        // Each field access becomes a free-standing wire of the same type.
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports. The zero-width constant keeps the
          // wire fully initialized.
          auto zero = firrtl::ConstantOp::create(
              rewriter, wire.getLoc(),
              firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
          MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(mem);
    return success();
  }
};
2903
// If memory has no write ports and no file initialization, eliminate it.
// Symmetrically, a write-only memory (whose contents can never be observed)
// is also eliminated.
struct FoldReadOrWriteOnlyMemory : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();
    // Classify the ports: the memory must be exclusively read or exclusively
    // written. Debug and read-write ports disqualify it outright.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    // Detach every port (replacing reads with dummy registers) and then erase
    // the memory itself.
    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(mem);
    return success();
  }
};
2945
// Eliminate the dead ports of memories: ports that are constantly disabled,
// and read ports whose data output is never used. The memory is rebuilt with
// only the surviving ports.
struct FoldUnusedPorts : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port died, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = MemOp::create(
          rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Replace the dead ports with dummy wires.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(mem);
    return success();
  }
};
3013
// Rewrite write-only read-write ports to write ports. A read-write port whose
// rdata field is never used behaves exactly like a write port, so the memory
// is rebuilt with the simpler port kind.
struct FoldReadWritePorts : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Rebuild the result types: demoted ports get a write-port bundle type,
    // all others keep their original type.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = MemOp::create(
        rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
        mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
        rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
        mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
                                            newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Map the read-write field names onto the write-port field names.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(mem);
    return success();
  }
};
3103
// Eliminate the unused bits of a memory's data. If every read of the memory
// goes through bit-select ops, only the selected bit ranges need to be
// stored: the memory is narrowed, reads are remapped to the compacted
// ranges, and writes become a concatenation of the surviving slices.
struct FoldUnusedBits : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // The data must be an integer of known, positive width.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // `usedBits` accumulates which bit positions are read anywhere;
    // `mapping` maps the low bit of each read to its compacted position.
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          // Any non-bit-select user defeats the optimization.
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If the whole width is read, there is nothing to narrow.
          if (usedBits.all())
            return failure();

          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones. Contiguous runs
    // of used bits become (start, end) ranges; every read's low bit is
    // remapped to its position in the packed layout.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(mem->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          SubfieldOp::create(rewriter, port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Slices are gathered high-to-low so the concatenation below lays them
      // out in the same order as the compacted storage.
      SmallVector<Value> slices;
      for (auto &[start, end] : llvm::reverse(ranges)) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        slices.push_back(slice);
      }

      Value catOfSlices =
          rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
3324
// Rewrite single-address memories to a firrtl register. A depth-1 memory is
// just one storage word: reads become reads of a register, and each write
// port becomes a (pipelined, masked) mux feeding the register's next value.
struct FoldRegMems : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    const FirMemory &info = mem.getSummary();
    // Only memories with a single storage location behave like a register.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Record the subfield accesses of the listed fields and the connects
      // driving them; bail out if a field is used by anything but connects.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        // Read ports `continue`: they do not constrain the register clock.
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All writing ports must agree on a single, known clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = WireOp::create(rewriter, loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    MatchingConnectOp::create(rewriter, loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
                                 rewriter.getStringAttr(regName))
                       .getResult();
        MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    // Write latency 1 needs no extra stages; each additional cycle adds one.
    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // The effective write enable is en AND wmode.
        auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      SmallVector<Value> chunks;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
        chunks.push_back(chunk);
      }

      // Chunks were built low-to-high; cat expects high-to-low.
      std::reverse(chunks.begin(), chunks.end());
      masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
      next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
    }
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3517} // namespace
3518
3519void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3520 MLIRContext *context) {
3521 results
3522 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3523 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3524 context);
3525}
3526
3527//===----------------------------------------------------------------------===//
3528// Declarations
3529//===----------------------------------------------------------------------===//
3530
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The driver must be a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Promote the register to a RegResetOp driven by the mux condition and
    // the constant, preserving name, annotations, symbol, and any dialect
    // attributes attached to the original op.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Re-point the connect: a constant register is driven by the constant,
  // otherwise the new reset register is driven by the mux's low input.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3601
3602LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3603 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3604 succeeded(foldHiddenReset(op, rewriter)))
3605 return success();
3606
3607 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3608 return success();
3609
3610 return failure();
3611}
3612
3613//===----------------------------------------------------------------------===//
3614// Verification Ops.
3615//===----------------------------------------------------------------------===//
3616
3617static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3618 Value enable,
3619 PatternRewriter &rewriter,
3620 bool eraseIfZero) {
3621 // If the verification op is never enabled, delete it.
3622 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3623 if (constant.getValue().isZero()) {
3624 rewriter.eraseOp(op);
3625 return success();
3626 }
3627 }
3628
3629 // If the verification op is never triggered, delete it.
3630 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3631 if (constant.getValue().isZero() == eraseIfZero) {
3632 rewriter.eraseOp(op);
3633 return success();
3634 }
3635 }
3636
3637 return failure();
3638}
3639
// Shared canonicalizer for immediate verification ops. A constant-false
// enable always erases the op. With EraseIfZero = false (assert/assume) the
// op is erased when the predicate is a constant nonzero (the check trivially
// holds); with EraseIfZero = true (cover) it is erased when the predicate is
// a constant zero (the point can never be hit).
template <class Op, bool EraseIfZero = false>
static LogicalResult canonicalizeImmediateVerifOp(Op op,
                                                  PatternRewriter &rewriter) {
  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
                              EraseIfZero);
}
3646
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase asserts that are never enabled or whose predicate is constant true.
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
  // Declaratively generated pattern (see patterns:: namespace).
  results.add<patterns::AssertXWhenX>(context);
}
3652
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase assumes that are never enabled or whose predicate is constant true.
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
  // Declaratively generated pattern (see patterns:: namespace).
  results.add<patterns::AssumeXWhenX>(context);
}
3658
void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Erase unclocked assumes that are never enabled or trivially true.
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
  // Declaratively generated pattern (see patterns:: namespace).
  results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
}
3664
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Erase covers that are never enabled, or whose predicate is constant zero
  // (EraseIfZero = true): the cover point can never be hit.
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3669
3670//===----------------------------------------------------------------------===//
3671// InvalidValueOp
3672//===----------------------------------------------------------------------===//
3673
3674LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3675 PatternRewriter &rewriter) {
3676 // Remove `InvalidValueOp`s with no uses.
3677 if (op.use_empty()) {
3678 rewriter.eraseOp(op);
3679 return success();
3680 }
3681 // Propagate invalids through a single use which is a unary op. You cannot
3682 // propagate through multiple uses as that breaks invalid semantics. Nor
3683 // can you propagate through binary ops or generally any op which computes.
3684 // Not is an exception as it is a pure, all-bits inverse.
3685 if (op->hasOneUse() &&
3686 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3687 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3688 *op->user_begin()) ||
3689 (isa<CvtPrimOp>(*op->user_begin()) &&
3690 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3691 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3692 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3693 .getBitWidthOrSentinel() > 0))) {
3694 auto *modop = *op->user_begin();
3695 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3696 modop->getResult(0).getType());
3697 rewriter.replaceAllOpUsesWith(modop, inv);
3698 rewriter.eraseOp(modop);
3699 rewriter.eraseOp(op);
3700 return success();
3701 }
3702 return failure();
3703}
3704
3705OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3706 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3707 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3708 return {};
3709}
3710
3711//===----------------------------------------------------------------------===//
3712// ClockGateIntrinsicOp
3713//===----------------------------------------------------------------------===//
3714
3715OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3716 // Forward the clock if one of the enables is always true.
3717 if (isConstantOne(adaptor.getEnable()) ||
3718 isConstantOne(adaptor.getTestEnable()))
3719 return getInput();
3720
3721 // Fold to a constant zero clock if the enables are always false.
3722 if (isConstantZero(adaptor.getEnable()) &&
3723 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3724 return BoolAttr::get(getContext(), false);
3725
3726 // Forward constant zero clocks.
3727 if (isConstantZero(adaptor.getInput()))
3728 return BoolAttr::get(getContext(), false);
3729
3730 return {};
3731}
3732
3733LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3734 PatternRewriter &rewriter) {
3735 // Remove constant false test enable.
3736 if (auto testEnable = op.getTestEnable()) {
3737 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3738 if (constOp.getValue().isZero()) {
3739 rewriter.modifyOpInPlace(op,
3740 [&] { op.getTestEnableMutable().clear(); });
3741 return success();
3742 }
3743 }
3744 }
3745
3746 return failure();
3747}
3748
3749//===----------------------------------------------------------------------===//
3750// Reference Ops.
3751//===----------------------------------------------------------------------===//
3752
3753// refresolve(forceable.ref) -> forceable.data
3754static LogicalResult
3755canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3756 auto forceable = op.getRef().getDefiningOp<Forceable>();
3757 if (!forceable || !forceable.isForceable() ||
3758 op.getRef() != forceable.getDataRef() ||
3759 op.getType() != forceable.getDataType())
3760 return failure();
3761 rewriter.replaceAllUsesWith(op, forceable.getData());
3762 return success();
3763}
3764
void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Declaratively generated pattern; presumably folds a resolve of a refsend
  // back to the sent value — see the generated patterns:: namespace.
  results.insert<patterns::RefResolveOfRefSend>(context);
  // refresolve(forceable.ref) -> forceable.data.
  results.insert(canonicalizeRefResolveOfForceable);
}
3770
3771OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3772 // RefCast is unnecessary if types match.
3773 if (getInput().getType() == getType())
3774 return getInput();
3775 return {};
3776}
3777
3778static bool isConstantZero(Value operand) {
3779 auto constOp = operand.getDefiningOp<ConstantOp>();
3780 return constOp && constOp.getValue().isZero();
3781}
3782
3783template <typename Op>
3784static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3785 if (isConstantZero(op.getPredicate())) {
3786 rewriter.eraseOp(op);
3787 return success();
3788 }
3789 return failure();
3790}
3791
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // A force with a constant-false predicate can never fire; erase it.
  results.add(eraseIfPredFalse<RefForceOp>);
}
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  // A force_initial with a constant-false predicate can never fire; erase it.
  results.add(eraseIfPredFalse<RefForceInitialOp>);
}
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // A release with a constant-false predicate can never fire; erase it.
  results.add(eraseIfPredFalse<RefReleaseOp>);
}
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // A release_initial with a constant-false predicate can never fire; erase
  // it.
  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
}
3808
3809//===----------------------------------------------------------------------===//
3810// HasBeenResetIntrinsicOp
3811//===----------------------------------------------------------------------===//
3812
3813OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3814 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3815
3816 // Fold to zero if the reset is a constant. In this case the op is either
3817 // permanently in reset or never resets. Both mean that the reset never
3818 // finishes, so this op never returns true.
3819 if (adaptor.getReset())
3820 return getIntZerosAttr(UIntType::get(getContext(), 1));
3821
3822 // Fold to zero if the clock is a constant and the reset is synchronous. In
3823 // that case the reset will never be started.
3824 if (isUInt1(getReset().getType()) && adaptor.getClock())
3825 return getIntZerosAttr(UIntType::get(getContext(), 1));
3826
3827 return {};
3828}
3829
3830//===----------------------------------------------------------------------===//
3831// FPGAProbeIntrinsicOp
3832//===----------------------------------------------------------------------===//
3833
3834static bool isTypeEmpty(FIRRTLType type) {
3836 .Case<FVectorType>(
3837 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3838 .Case<BundleType>([&](auto ty) -> bool {
3839 for (auto elem : ty.getElements())
3840 if (!isTypeEmpty(elem.type))
3841 return false;
3842 return true;
3843 })
3844 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3845 .Default([](auto) -> bool { return false; });
3846}
3847
3848LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3849 PatternRewriter &rewriter) {
3850 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3851 if (!firrtlTy)
3852 return failure();
3853
3854 if (!isTypeEmpty(firrtlTy))
3855 return failure();
3856
3857 rewriter.eraseOp(op);
3858 return success();
3859}
3860
3861//===----------------------------------------------------------------------===//
3862// Layer Block Op
3863//===----------------------------------------------------------------------===//
3864
3865LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3866 PatternRewriter &rewriter) {
3867
3868 // If the layerblock is empty, erase it.
3869 if (op.getBody()->empty()) {
3870 rewriter.eraseOp(op);
3871 return success();
3872 }
3873
3874 return failure();
3875}
3876
3877//===----------------------------------------------------------------------===//
3878// Domain-related Ops
3879//===----------------------------------------------------------------------===//
3880
3881OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3882 // If no domains are specified, then forward the input to the result.
3883 if (getDomains().empty())
3884 return getInput();
3885
3886 return {};
3887}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override