CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
#include "circt/Dialect/FIRRTL/FIRRTLAttributes.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Support/APInt.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
27
28using namespace circt;
29using namespace firrtl;
30
31// Drop writes to old and pass through passthrough to make patterns easier to
32// write.
33static Value dropWrite(PatternRewriter &rewriter, OpResult old,
34 Value passthrough) {
35 SmallPtrSet<Operation *, 8> users;
36 for (auto *user : old.getUsers())
37 users.insert(user);
38 for (Operation *user : users)
39 if (auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
42 return passthrough;
43}
44
45// Return true if it is OK to propagate the name to the operation.
46// Non-pure operations such as instances, registers, and memories are not
47// allowed to update names for name stabilities and LEC reasons.
48static bool isOkToPropagateName(Operation *op) {
49 // Conservatively disallow operations with regions to prevent performance
50 // regression due to recursive calls to sub-regions in mlir::isPure.
51 if (op->getNumRegions() != 0)
52 return false;
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
54}
55
56// Move a name hint from a soon to be deleted operation to a new operation.
57// Pass through the new operation to make patterns easier to write. This cannot
58// move a name to a port (block argument), doing so would require rewriting all
59// instance sites as well as the module.
60static Value moveNameHint(OpResult old, Value passthrough) {
61 Operation *op = passthrough.getDefiningOp();
62 // This should handle ports, but it isn't clear we can change those in
63 // canonicalizers.
64 assert(op && "passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>("name");
67 if (name && !name.getValue().empty() && isOkToPropagateName(op))
68 op->setAttr("name", name);
69 return passthrough;
70}
71
72// Declarative canonicalization patterns
73namespace circt {
74namespace firrtl {
75namespace patterns {
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
77} // namespace patterns
78} // namespace firrtl
79} // namespace circt
80
81/// Return true if this operation's operands and results all have a known width.
82/// This only works for integer types.
83static bool hasKnownWidthIntTypes(Operation *op) {
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
86 return false;
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
89 return false;
90 return true;
91}
92
93/// Return true if this value is 1 bit UInt.
94static bool isUInt1(Type type) {
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
97 return false;
98 return true;
99}
100
101/// Set the name of an op based on the best of two names: The current name, and
102/// the name passed in.
103static void updateName(PatternRewriter &rewriter, Operation *op,
104 StringAttr name) {
105 // Should never rename InstanceOp
106 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
107 return;
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
109 auto newName = name.getValue(); // old name is interesting
110 auto newOpName = op->getAttrOfType<StringAttr>("name");
111 // new name might not be interesting
112 if (newOpName)
113 newName = chooseName(newOpName.getValue(), name.getValue());
114 // Only update if needed
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
118}
119
120/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
121/// If a replaced op has a "name" attribute, this function propagates the name
122/// to the new value.
123static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
124 Value newValue) {
125 if (auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>("name");
127 updateName(rewriter, newOp, name);
128 }
129 rewriter.replaceOp(op, newValue);
130}
131
132/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
133/// attribute. If a replaced op has a "name" attribute, this function propagates
134/// the name to the new value.
135template <typename OpTy, typename... Args>
136static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>("name");
139 auto newOp =
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
141 updateName(rewriter, newOp, name);
142 return newOp;
143}
144
145/// Return true if the name is droppable. Note that this is different from
146/// `isUselessName` because non-useless names may be also droppable.
148 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
150 return false;
151}
152
153/// Implicitly replace the operand to a constant folding operation with a const
154/// 0 in case the operand is non-constant but has a bit width 0, or if the
155/// operand is an invalid value.
156///
157/// This makes constant folding significantly easier, as we can simply pass the
158/// operands to an operation through this function to appropriately replace any
159/// zero-width dynamic values or invalid values with a constant of value 0.
160static std::optional<APSInt>
161getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
164
165 // We never support constant folding to unknown width values.
166 if (destWidth < 0)
167 return {};
168
169 // Extension signedness follows the operand sign.
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
171 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
172
173 // If the operand is zero bits, then we can return a zero of the result
174 // type.
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
178 return {};
179}
180
181/// Determine the value of a constant operand for the sake of constant folding.
182static std::optional<APSInt> getConstant(Attribute operand) {
183 if (!operand)
184 return {};
185 if (auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
189 return {};
190}
191
192/// Determine whether a constant operand is a zero value for the sake of
193/// constant folding. This considers `invalidvalue` to be zero.
194static bool isConstantZero(Attribute operand) {
195 if (auto cst = getConstant(operand))
196 return cst->isZero();
197 return false;
198}
199
200/// Determine whether a constant operand is a one value for the sake of constant
201/// folding.
202static bool isConstantOne(Attribute operand) {
203 if (auto cst = getConstant(operand))
204 return cst->isOne();
205 return false;
206}
207
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
enum class BinOpKind {
  Normal,        // extend operands to the result width
  Compare,       // extend operands to the widest operand width (result is i1)
  DivideOrShift, // compute wide, then truncate to the result width
};
216/// Applies the constant folding function `calculate` to the given operands.
217///
218/// Sign or zero extends the operands appropriately to the bitwidth of the
219/// result type if \p useDstWidth is true, else to the larger of the two operand
220/// bit widths and depending on whether the operation is to be performed on
221/// signed or unsigned operands.
222static Attribute constFoldFIRRTLBinaryOp(
223 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
224 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
225 assert(operands.size() == 2 && "binary op takes two operands");
226
227 // We cannot fold something to an unknown width.
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
230 return {};
231
232 // Any binary op returning i0 is 0.
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
235
236 // Determine the operand widths. This is either dictated by the operand type,
237 // or if that type is an unsized integer, by the actual bits necessary to
238 // represent the constant value.
239 auto lhsWidth =
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
241 auto rhsWidth =
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
247
248 // Compares extend the operands to the widest of the operand types, not to the
249 // result type.
250 int32_t operandWidth;
251 switch (opKind) {
253 operandWidth = resultType.getWidthOrSentinel();
254 break;
256 // Compares compute with the widest operand, not at the destination type
257 // (which is always i1).
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
259 break;
261 operandWidth =
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
263 break;
264 }
265
266 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
267 if (!lhs)
268 return {};
269 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
270 if (!rhs)
271 return {};
272
273 APInt resultValue = calculate(*lhs, *rhs);
274
275 // If the result type is smaller than the computation then we need to
276 // narrow the constant after the calculation.
277 if (opKind == BinOpKind::DivideOrShift)
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
279
280 assert((unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
282 return getIntAttr(resultType, resultValue);
283}
284
285/// Applies the canonicalization function `canonicalize` to the given operation.
286///
287/// Determines which (if any) of the operation's operands are constants, and
288/// provides them as arguments to the callback function. Any `invalidvalue` in
289/// the input is mapped to a constant zero. The value returned from the callback
290/// is used as the replacement for `op`, and an additional pad operation is
291/// inserted if necessary. Does nothing if the result of `op` is of unknown
292/// width, in which case the necessity of a pad cannot be determined.
293static LogicalResult canonicalizePrimOp(
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
296 // Can only operate on FIRRTL primitive operations.
297 if (op->getNumResults() != 1)
298 return failure();
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
300 if (!type)
301 return failure();
302
303 // Can only operate on operations with a known result width.
304 auto width = type.getBitWidthOrSentinel();
305 if (width < 0)
306 return failure();
307
308 // Determine which of the operands are constants.
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (auto operand : op->getOperands()) {
312 Attribute attr;
313 if (auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
317 }
318
319 // Perform the canonicalization and materialize the result if it is a
320 // constant.
321 auto result = canonicalize(constOperands);
322 if (!result)
323 return failure();
324 Value resultValue;
325 if (auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
328 ->getResult(0);
329 else
330 resultValue = cast<Value>(result);
331
332 // Insert a pad if the type widths disagree.
333 if (width !=
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
336
337 // Insert a cast if this is a uint vs. sint or vice versa.
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
343
344 assert(type == resultValue.getType() && "canonicalization changed type");
345 replaceOpAndCopyName(rewriter, op, resultValue);
346 return success();
347}
348
349/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
350/// value if `bitWidth` is 0.
351static APInt getMaxUnsignedValue(unsigned bitWidth) {
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
353}
354
355/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
356/// value if `bitWidth` is 0.
357static APInt getMinSignedValue(unsigned bitWidth) {
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
359}
360
361/// Get the largest signed value of a given bit width. Returns a 1-bit zero
362/// value if `bitWidth` is 0.
363static APInt getMaxSignedValue(unsigned bitWidth) {
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
365}
366
367//===----------------------------------------------------------------------===//
368// Fold Hooks
369//===----------------------------------------------------------------------===//
370
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() && "constant has no operands");
373 return getValueAttr();
374}
375
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() && "constant has no operands");
378 return getValueAttr();
379}
380
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() && "constant has no operands");
383 return getFieldsAttr();
384}
385
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() && "constant has no operands");
388 return getValueAttr();
389}
390
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() && "constant has no operands");
393 return getValueAttr();
394}
395
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() && "constant has no operands");
398 return getValueAttr();
399}
400
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() && "constant has no operands");
403 return getValueAttr();
404}
405
406//===----------------------------------------------------------------------===//
407// Binary Operators
408//===----------------------------------------------------------------------===//
409
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
412 *this, adaptor.getOperands(), BinOpKind::Normal,
413 [=](const APSInt &a, const APSInt &b) { return a + b; });
414}
415
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(context);
420}
421
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
424 *this, adaptor.getOperands(), BinOpKind::Normal,
425 [=](const APSInt &a, const APSInt &b) { return a - b; });
426}
427
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 MLIRContext *context) {
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(context);
433}
434
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
436 // mul(x, 0) -> 0
437 //
438 // This is legal because it aligns with the Scala FIRRTL Compiler
439 // interpretation of lowering invalid to constant zero before constant
440 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
441 // multiplication this way and will emit "x * 0".
442 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
443 return getIntZerosAttr(getType());
444
446 *this, adaptor.getOperands(), BinOpKind::Normal,
447 [=](const APSInt &a, const APSInt &b) { return a * b; });
448}
449
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
451 /// div(x, x) -> 1
452 ///
453 /// Division by zero is undefined in the FIRRTL specification. This fold
454 /// exploits that fact to optimize self division to one. Note: this should
455 /// supersede any division with invalid or zero. Division of invalid by
456 /// invalid should be one.
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
459 if (width == -1)
460 width = 2;
461 // Only fold if we have at least 1 bit of width to represent the `1` value.
462 if (width != 0)
463 return getIntAttr(getType(), APInt(width, 1));
464 }
465
466 // div(0, x) -> 0
467 //
468 // This is legal because it aligns with the Scala FIRRTL Compiler
469 // interpretation of lowering invalid to constant zero before constant
470 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
471 // division this way and will emit "0 / x".
472 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
473 return getIntZerosAttr(getType());
474
475 /// div(x, 1) -> x : (uint, uint) -> uint
476 ///
477 /// UInt division by one returns the numerator. SInt division can't
478 /// be folded here because it increases the return type bitwidth by
479 /// one and requires sign extension (a new op).
480 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
482 return getLhs();
483
485 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
486 [=](const APSInt &a, const APSInt &b) -> APInt {
487 if (!!b)
488 return a / b;
489 return APInt(a.getBitWidth(), 0);
490 });
491}
492
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
494 // rem(x, x) -> 0
495 //
496 // Division by zero is undefined in the FIRRTL specification. This fold
497 // exploits that fact to optimize self division remainder to zero. Note:
498 // this should supersede any division with invalid or zero. Remainder of
499 // division of invalid by invalid should be zero.
500 if (getLhs() == getRhs())
501 return getIntZerosAttr(getType());
502
503 // rem(0, x) -> 0
504 //
505 // This is legal because it aligns with the Scala FIRRTL Compiler
506 // interpretation of lowering invalid to constant zero before constant
507 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
508 // division this way and will emit "0 % x".
509 if (isConstantZero(adaptor.getLhs()))
510 return getIntZerosAttr(getType());
511
513 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
514 [=](const APSInt &a, const APSInt &b) -> APInt {
515 if (!!b)
516 return a % b;
517 return APInt(a.getBitWidth(), 0);
518 });
519}
520
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
523 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
525}
526
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
529 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
530 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
531}
532
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
535 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
536 [=](const APSInt &a, const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
538 : a.ashr(b);
539 });
540}
541
542// TODO: Move to DRR.
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
544 if (auto rhsCst = getConstant(adaptor.getRhs())) {
545 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
546 if (rhsCst->isZero())
547 return getIntZerosAttr(getType());
548
549 /// and(x, -1) -> x
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
552 return getLhs();
553 }
554
555 if (auto lhsCst = getConstant(adaptor.getLhs())) {
556 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
557 if (lhsCst->isZero())
558 return getIntZerosAttr(getType());
559
560 /// and(-1, x) -> x
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
563 return getRhs();
564 }
565
566 /// and(x, x) -> x
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
568 return getRhs();
569
571 *this, adaptor.getOperands(), BinOpKind::Normal,
572 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
573}
574
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
576 MLIRContext *context) {
577 results
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
581}
582
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
584 if (auto rhsCst = getConstant(adaptor.getRhs())) {
585 /// or(x, 0) -> x
586 if (rhsCst->isZero() && getLhs().getType() == getType())
587 return getLhs();
588
589 /// or(x, -1) -> -1
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
592 return getRhs();
593 }
594
595 if (auto lhsCst = getConstant(adaptor.getLhs())) {
596 /// or(0, x) -> x
597 if (lhsCst->isZero() && getRhs().getType() == getType())
598 return getRhs();
599
600 /// or(-1, x) -> -1
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
603 return getLhs();
604 }
605
606 /// or(x, x) -> x
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
608 return getRhs();
609
611 *this, adaptor.getOperands(), BinOpKind::Normal,
612 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
613}
614
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
616 MLIRContext *context) {
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
619 patterns::OrOrr>(context);
620}
621
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
623 /// xor(x, 0) -> x
624 if (auto rhsCst = getConstant(adaptor.getRhs()))
625 if (rhsCst->isZero() &&
626 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
627 return getLhs();
628
629 /// xor(x, 0) -> x
630 if (auto lhsCst = getConstant(adaptor.getLhs()))
631 if (lhsCst->isZero() &&
632 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
633 return getRhs();
634
635 /// xor(x, x) -> 0
636 if (getLhs() == getRhs())
637 return getIntAttr(
638 getType(),
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
640
642 *this, adaptor.getOperands(), BinOpKind::Normal,
643 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
644}
645
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
647 MLIRContext *context) {
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
650 context);
651}
652
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
654 MLIRContext *context) {
655 results.insert<patterns::LEQWithConstLHS>(context);
656}
657
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
660
661 // leq(x, x) -> 1
662 if (getLhs() == getRhs())
663 return getIntAttr(getType(), APInt(1, 1));
664
665 // Comparison against constant outside type bounds.
666 if (auto width = getLhs().getType().base().getWidth()) {
667 if (auto rhsCst = getConstant(adaptor.getRhs())) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
670
671 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
672 // This can never occur since const is unsigned and cannot be less than 0.
673
674 // leq(x, const) -> 0 where const < minValue of the signed type of x
675 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
676 .slt(getMinSignedValue(*width).sext(commonWidth)))
677 return getIntAttr(getType(), APInt(1, 0));
678
679 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
680 if (isUnsigned && rhsCst->zext(commonWidth)
681 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
682 return getIntAttr(getType(), APInt(1, 1));
683
684 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
685 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
686 .sge(getMaxSignedValue(*width).sext(commonWidth)))
687 return getIntAttr(getType(), APInt(1, 1));
688 }
689 }
690
692 *this, adaptor.getOperands(), BinOpKind::Compare,
693 [=](const APSInt &a, const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
695 });
696}
697
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
699 MLIRContext *context) {
700 results.insert<patterns::LTWithConstLHS>(context);
701}
702
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
705 bool isUnsigned = lhsType.isUnsigned();
706
707 // lt(x, x) -> 0
708 if (getLhs() == getRhs())
709 return getIntAttr(getType(), APInt(1, 0));
710
711 // lt(x, 0) -> 0 when x is unsigned
712 if (auto rhsCst = getConstant(adaptor.getRhs())) {
713 if (rhsCst->isZero() && lhsType.isUnsigned())
714 return getIntAttr(getType(), APInt(1, 0));
715 }
716
717 // Comparison against constant outside type bounds.
718 if (auto width = lhsType.getWidth()) {
719 if (auto rhsCst = getConstant(adaptor.getRhs())) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
722
723 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
724 // Handled explicitly above.
725
726 // lt(x, const) -> 0 where const <= minValue of the signed type of x
727 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
728 .sle(getMinSignedValue(*width).sext(commonWidth)))
729 return getIntAttr(getType(), APInt(1, 0));
730
731 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
732 if (isUnsigned && rhsCst->zext(commonWidth)
733 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
734 return getIntAttr(getType(), APInt(1, 1));
735
736 // lt(x, const) -> 1 where const > maxValue of the signed type of x
737 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
738 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
739 return getIntAttr(getType(), APInt(1, 1));
740 }
741 }
742
744 *this, adaptor.getOperands(), BinOpKind::Compare,
745 [=](const APSInt &a, const APSInt &b) -> APInt {
746 return APInt(1, a < b);
747 });
748}
749
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
751 MLIRContext *context) {
752 results.insert<patterns::GEQWithConstLHS>(context);
753}
754
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
757 bool isUnsigned = lhsType.isUnsigned();
758
759 // geq(x, x) -> 1
760 if (getLhs() == getRhs())
761 return getIntAttr(getType(), APInt(1, 1));
762
763 // geq(x, 0) -> 1 when x is unsigned
764 if (auto rhsCst = getConstant(adaptor.getRhs())) {
765 if (rhsCst->isZero() && isUnsigned)
766 return getIntAttr(getType(), APInt(1, 1));
767 }
768
769 // Comparison against constant outside type bounds.
770 if (auto width = lhsType.getWidth()) {
771 if (auto rhsCst = getConstant(adaptor.getRhs())) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
774
775 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
776 if (isUnsigned && rhsCst->zext(commonWidth)
777 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
778 return getIntAttr(getType(), APInt(1, 0));
779
780 // geq(x, const) -> 0 where const > maxValue of the signed type of x
781 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
782 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
783 return getIntAttr(getType(), APInt(1, 0));
784
785 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
786 // Handled explicitly above.
787
788 // geq(x, const) -> 1 where const <= minValue of the signed type of x
789 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
790 .sle(getMinSignedValue(*width).sext(commonWidth)))
791 return getIntAttr(getType(), APInt(1, 1));
792 }
793 }
794
796 *this, adaptor.getOperands(), BinOpKind::Compare,
797 [=](const APSInt &a, const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
799 });
800}
801
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
803 MLIRContext *context) {
804 results.insert<patterns::GTWithConstLHS>(context);
805}
806
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
809 bool isUnsigned = lhsType.isUnsigned();
810
811 // gt(x, x) -> 0
812 if (getLhs() == getRhs())
813 return getIntAttr(getType(), APInt(1, 0));
814
815 // Comparison against constant outside type bounds.
816 if (auto width = lhsType.getWidth()) {
817 if (auto rhsCst = getConstant(adaptor.getRhs())) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
820
821 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
822 if (isUnsigned && rhsCst->zext(commonWidth)
823 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
824 return getIntAttr(getType(), APInt(1, 0));
825
826 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
827 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
828 .sge(getMaxSignedValue(*width).sext(commonWidth)))
829 return getIntAttr(getType(), APInt(1, 0));
830
831 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
832 // This can never occur since const is unsigned and cannot be less than 0.
833
834 // gt(x, const) -> 1 where const < minValue of the signed type of x
835 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
836 .slt(getMinSignedValue(*width).sext(commonWidth)))
837 return getIntAttr(getType(), APInt(1, 1));
838 }
839 }
840
842 *this, adaptor.getOperands(), BinOpKind::Compare,
843 [=](const APSInt &a, const APSInt &b) -> APInt {
844 return APInt(1, a > b);
845 });
846}
847
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
849 // eq(x, x) -> 1
850 if (getLhs() == getRhs())
851 return getIntAttr(getType(), APInt(1, 1));
852
853 if (auto rhsCst = getConstant(adaptor.getRhs())) {
854 /// eq(x, 1) -> x when x is 1 bit.
855 /// TODO: Support SInt<1> on the LHS etc.
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
858 return getLhs();
859 }
860
862 *this, adaptor.getOperands(), BinOpKind::Compare,
863 [=](const APSInt &a, const APSInt &b) -> APInt {
864 return APInt(1, a == b);
865 });
866}
867
/// Canonicalize eq against constant zero / all-ones into cheaper unary forms
/// (not / not(orr) / andr). The 1-bit case is checked first so a 1-bit input
/// gets the plain `not` rewrite.
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          // Sentinel -1 means the LHS width is unknown.
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        // No rewrite applied.
        return {};
      });
}
897
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
899 // neq(x, x) -> 0
900 if (getLhs() == getRhs())
901 return getIntAttr(getType(), APInt(1, 0));
902
903 if (auto rhsCst = getConstant(adaptor.getRhs())) {
904 /// neq(x, 0) -> x when x is 1 bit.
905 /// TODO: Support SInt<1> on the LHS etc.
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
908 return getLhs();
909 }
910
912 *this, adaptor.getOperands(), BinOpKind::Compare,
913 [=](const APSInt &a, const APSInt &b) -> APInt {
914 return APInt(1, a != b);
915 });
916}
917
/// Canonicalize neq against constant zero / all-ones into cheaper unary forms
/// (not / orr / not(andr)).
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          // Sentinel -1 means the LHS width is unknown.
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x)) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp =
                AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
          }
        }

        // No rewrite applied.
        return {};
      });
}
949
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
951 // TODO: implement constant folding, etc.
952 // Tracked in https://github.com/llvm/circt/issues/6696.
953 return {};
954}
955
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
957 // TODO: implement constant folding, etc.
958 // Tracked in https://github.com/llvm/circt/issues/6724.
959 return {};
960}
961
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
963 if (auto rhsCst = getConstant(adaptor.getRhs())) {
964 if (auto lhsCst = getConstant(adaptor.getLhs())) {
965
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
970 }
971
972 if (rhsCst->isZero())
973 return getLhs();
974 }
975
976 return {};
977}
978
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
980 if (auto rhsCst = getConstant(adaptor.getRhs())) {
981 // Constant folding
982 if (auto lhsCst = getConstant(adaptor.getLhs()))
983
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
988
989 // integer.shl(x, 0) -> x
990 if (rhsCst->isZero())
991 return getLhs();
992 }
993
994 return {};
995}
996
997//===----------------------------------------------------------------------===//
998// Unary Operators
999//===----------------------------------------------------------------------===//
1000
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1003 auto w = getBitWidth(base);
1004 if (w)
1005 return getIntAttr(getType(), APInt(32, *w));
1006 return {};
1007}
1008
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1010 // No constant can be 'x' by definition.
1011 if (auto cst = getConstant(adaptor.getArg()))
1012 return getIntAttr(getType(), APInt(1, 0));
1013 return {};
1014}
1015
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1017 // No effect.
1018 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1019 return getInput();
1020
1021 // Be careful to only fold the cast into the constant if the size is known.
1022 // Otherwise width inference may produce differently-sized constants if the
1023 // sign changes.
1024 if (getType().base().hasWidth())
1025 if (auto cst = getConstant(adaptor.getInput()))
1026 return getIntAttr(getType(), *cst);
1027
1028 return {};
1029}
1030
void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Register the declaratively-defined StoUtoS cast-chain pattern.
  results.insert<patterns::StoUtoS>(context);
}
1035
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1037 // No effect.
1038 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1039 return getInput();
1040
1041 // Be careful to only fold the cast into the constant if the size is known.
1042 // Otherwise width inference may produce differently-sized constants if the
1043 // sign changes.
1044 if (getType().base().hasWidth())
1045 if (auto cst = getConstant(adaptor.getInput()))
1046 return getIntAttr(getType(), *cst);
1047
1048 return {};
1049}
1050
void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Register the declaratively-defined UtoStoU cast-chain pattern.
  results.insert<patterns::UtoStoU>(context);
}
1055
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1057 // No effect.
1058 if (getInput().getType() == getType())
1059 return getInput();
1060
1061 // Constant fold.
1062 if (auto cst = getConstant(adaptor.getInput()))
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1064
1065 return {};
1066}
1067
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1069 if (auto cst = getConstant(adaptor.getInput()))
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1071 return {};
1072}
1073
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1075 // No effect.
1076 if (getInput().getType() == getType())
1077 return getInput();
1078
1079 // Constant fold.
1080 if (auto cst = getConstant(adaptor.getInput()))
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1082
1083 return {};
1084}
1085
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1087 if (!hasKnownWidthIntTypes(*this))
1088 return {};
1089
1090 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1091 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1092 getType().base().getWidthOrSentinel()))
1093 return getIntAttr(getType(), *cst);
1094
1095 return {};
1096}
1097
void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register declaratively-defined cvt patterns for signed/unsigned inputs.
  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
}
1102
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1104 if (!hasKnownWidthIntTypes(*this))
1105 return {};
1106
1107 // FIRRTL negate always adds a bit.
1108 // -x ---> 0-sext(x) or 0-zext(x)
1109 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1112
1113 return {};
1114}
1115
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1117 if (!hasKnownWidthIntTypes(*this))
1118 return {};
1119
1120 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1121 getType().base().getWidthOrSentinel()))
1122 return getIntAttr(getType(), ~*cst);
1123
1124 return {};
1125}
1126
void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register declarative patterns: not(not(x)) and not of each comparison
  // (which flips the comparison's direction).
  results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
                 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
                 patterns::NotGt>(context);
}
1133
/// Base pattern for simplifying a bitwise reduction (orr/andr/xorr) whose
/// operand is a cat. Subclasses decide, per constant cat operand, whether the
/// entire reduction short-circuits to a constant, whether the operand can be
/// dropped, or whether it must be kept.
class ReductionCat : public mlir::RewritePattern {
public:
  ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
      : RewritePattern(opName, 0, context) {}

  /// Handle a constant operand in the cat operation.
  /// Returns true if the entire reduction can be replaced with a constant.
  /// May add non-zero constants to the remaining operands list.
  virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
                              ConstantOp constantOp,
                              SmallVectorImpl<Value> &remaining) const = 0;

  /// Return the unit value for this reduction operation: the result when no
  /// operands remain after constant handling.
  virtual bool getIdentityValue() const = 0;

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Check if the operand is a cat operation
    auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
    if (!catOp)
      return failure();

    // NOTE: despite the name, this may also contain constants that a
    // subclass elected to keep via handleConstant (see XorRCat).
    SmallVector<Value> nonConstantOperands;

    // Process each operand of the cat operation
    for (auto operand : catOp.getInputs()) {
      if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
        // Handle constant operands - may short-circuit the entire operation
        if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
          return success();
      } else {
        // Keep non-constant operands for further processing
        nonConstantOperands.push_back(operand);
      }
    }

    // If no operands remain, replace with identity value
    if (nonConstantOperands.empty()) {
      replaceOpWithNewOpAndCopyName<ConstantOp>(
          rewriter, op, cast<IntType>(op->getResult(0).getType()),
          APInt(1, getIdentityValue()));
      return success();
    }

    // If only one operand remains, apply reduction directly to it
    if (nonConstantOperands.size() == 1) {
      rewriter.modifyOpInPlace(
          op, [&] { op->setOperand(0, nonConstantOperands.front()); });
      return success();
    }

    // Multiple operands remain - optimize only when cat has a single use.
    // Replacing the cat in place narrows its result, which is only safe when
    // the reduction is the cat's sole user.
    if (catOp->hasOneUse() &&
        nonConstantOperands.size() < catOp->getNumOperands()) {
      replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
                                               nonConstantOperands);
      return success();
    }
    return failure();
  }
};
1196
1197class OrRCat : public ReductionCat {
1198public:
1199 OrRCat(MLIRContext *context)
1200 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1201 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1202 ConstantOp value,
1203 SmallVectorImpl<Value> &remaining) const override {
1204 if (value.getValue().isZero())
1205 return false;
1206
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1209 llvm::APInt(1, 1));
1210 return true;
1211 }
1212 bool getIdentityValue() const override { return false; }
1213};
1214
1215class AndRCat : public ReductionCat {
1216public:
1217 AndRCat(MLIRContext *context)
1218 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1219 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1220 ConstantOp value,
1221 SmallVectorImpl<Value> &remaining) const override {
1222 if (value.getValue().isAllOnes())
1223 return false;
1224
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1227 llvm::APInt(1, 0));
1228 return true;
1229 }
1230 bool getIdentityValue() const override { return true; }
1231};
1232
1233class XorRCat : public ReductionCat {
1234public:
1235 XorRCat(MLIRContext *context)
1236 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1237 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1238 ConstantOp value,
1239 SmallVectorImpl<Value> &remaining) const override {
1240 if (value.getValue().isZero())
1241 return false;
1242 remaining.push_back(value);
1243 return false;
1244 }
1245 bool getIdentityValue() const override { return false; }
1246};
1247
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1249 if (!hasKnownWidthIntTypes(*this))
1250 return {};
1251
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1253 return getIntAttr(getType(), APInt(1, 1));
1254
1255 // x == -1
1256 if (auto cst = getConstant(adaptor.getInput()))
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1258
1259 // one bit is identity. Only applies to UInt since we can't make a cast
1260 // here.
1261 if (isUInt1(getInput().getType()))
1262 return getInput();
1263
1264 return {};
1265}
1266
void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative cast/pad/cat patterns plus the AndRCat constant absorber.
  results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
                 patterns::AndRPadS, patterns::AndRCatAndR_left,
                 patterns::AndRCatAndR_right, AndRCat>(context);
}
1273
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1275 if (!hasKnownWidthIntTypes(*this))
1276 return {};
1277
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1279 return getIntAttr(getType(), APInt(1, 0));
1280
1281 // x != 0
1282 if (auto cst = getConstant(adaptor.getInput()))
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1284
1285 // one bit is identity. Only applies to UInt since we can't make a cast
1286 // here.
1287 if (isUInt1(getInput().getType()))
1288 return getInput();
1289
1290 return {};
1291}
1292
void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative cast/pad/cat patterns plus the OrRCat constant absorber.
  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
                 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
      context);
}
1299
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1301 if (!hasKnownWidthIntTypes(*this))
1302 return {};
1303
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1305 return getIntAttr(getType(), APInt(1, 0));
1306
1307 // popcount(x) & 1
1308 if (auto cst = getConstant(adaptor.getInput()))
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1310
1311 // one bit is identity. Only applies to UInt since we can't make a cast here.
1312 if (isUInt1(getInput().getType()))
1313 return getInput();
1314
1315 return {};
1316}
1317
void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative cast/pad/cat patterns plus the XorRCat constant absorber.
  results
      .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
              patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
          context);
}
1325
1326//===----------------------------------------------------------------------===//
1327// Other Operators
1328//===----------------------------------------------------------------------===//
1329
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1333
1334 // If no inputs, return 0-bit value
1335 if (inputs.empty())
1336 return getIntZerosAttr(getType());
1337
1338 // If single input and same type, return it
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1340 return inputs[0];
1341
1342 // Make sure it's safe to fold.
1343 if (!hasKnownWidthIntTypes(*this))
1344 return {};
1345
1346 // Filter out zero-width operands
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant = true;
1350 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1354 if (!attr)
1355 allConstant = false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1357 return {};
1358 }
1359 }
1360
1361 // If all inputs were zero-width, return 0-bit value
1362 if (nonZeroInputs.empty())
1363 return getIntZerosAttr(getType());
1364
1365 // If only one non-zero input and it has the same type as result, return it
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1368
1369 if (!hasKnownWidthIntTypes(*this))
1370 return {};
1371
1372 // Constant fold cat - concatenate all constant operands
1373 SmallVector<APInt> constants;
1374 for (auto inputAdaptor : inputAdaptors) {
1375 if (auto cst = getConstant(inputAdaptor))
1376 constants.push_back(*cst);
1377 else
1378 return {}; // Not all operands are constant
1379 }
1380
1381 assert(!constants.empty());
1382 // Concatenate all constants from left to right
1383 APInt result = constants[0];
1384 for (size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
1386
1387 return getIntAttr(getType(), result);
1388}
1389
void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Register the declaratively-defined constant-shift-amount pattern.
  results.insert<patterns::DShlOfConstant>(context);
}
1394
void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Register the declaratively-defined constant-shift-amount pattern.
  results.insert<patterns::DShrOfConstant>(context);
}
1399
1400namespace {
1401
/// Canonicalize a cat tree into single cat operation.
class FlattenCat : public mlir::OpRewritePattern<CatPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(CatPrimOp cat,
                  mlir::PatternRewriter &rewriter) const override {
    // Require known widths and a non-zero-width result.
    if (!hasKnownWidthIntTypes(cat) ||
        cat.getType().getBitWidthOrSentinel() == 0)
      return failure();

    // If this is not a root of concat tree, skip. The root's rewrite will
    // flatten the entire tree in one shot.
    if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
      return failure();

    // Try flattening the cat tree. Operands are pushed in reverse order so
    // that popping the worklist visits them left-to-right.
    SmallVector<Value> operands;
    SmallVector<Value> worklist;
    auto pushOperands = [&worklist](CatPrimOp op) {
      for (auto operand : llvm::reverse(op.getInputs()))
        worklist.push_back(operand);
    };
    pushOperands(cat);
    bool hasSigned = false, hasUnsigned = false;
    while (!worklist.empty()) {
      auto value = worklist.pop_back_val();
      auto catOp = value.getDefiningOp<CatPrimOp>();
      if (!catOp) {
        // Leaf operand: record it and note its signedness for the
        // mixed-sign handling below.
        operands.push_back(value);
        (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
        continue;
      }

      pushOperands(catOp);
    }

    // Helper function to cast signed values to unsigned. CatPrimOp converts
    // signed values to unsigned so we need to do it explicitly here.
    auto castToUIntIfSigned = [&](Value value) -> Value {
      if (type_isa<UIntType>(value.getType()))
        return value;
      return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
    };

    assert(operands.size() >= 1 && "zero width cast must be rejected");

    if (operands.size() == 1) {
      rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
      return success();
    }

    // Same operand count means no nested cat was inlined; nothing to do.
    if (operands.size() == cat->getNumOperands())
      return failure();

    // If types are mixed, cast all operands to unsigned.
    if (hasSigned && hasUnsigned)
      for (auto &operand : operands)
        operand = castToUIntIfSigned(operand);

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
                                             operands);
    return success();
  }
};
1467
// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
class CatOfConstant : public mlir::OpRewritePattern<CatPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(CatPrimOp cat,
                  mlir::PatternRewriter &rewriter) const override {
    if (!hasKnownWidthIntTypes(cat))
      return failure();

    SmallVector<Value> operands;

    for (size_t i = 0; i < cat->getNumOperands(); ++i) {
      auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
      if (!cst) {
        // Non-constant operands are kept as-is.
        operands.push_back(cat.getInputs()[i]);
        continue;
      }
      // Greedily absorb the run of constant operands starting at index i.
      APSInt value = cst.getValue();
      size_t j = i + 1;
      for (; j < cat->getNumOperands(); ++j) {
        auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
        if (!nextCst)
          break;
        value = value.concat(nextCst.getValue());
      }

      if (j == i + 1) {
        // Not folded: the run was a single constant; reuse it.
        operands.push_back(cst);
      } else {
        // Folded: materialize the concatenated constant.
        operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
      }

      // Resume scanning after the constant run (loop ++i lands on j).
      i = j - 1;
    }

    // Fail if nothing merged, so the driver doesn't loop forever.
    if (operands.size() == cat->getNumOperands())
      return failure();

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
                                             operands);

    return success();
  }
};
1516
1517} // namespace
1518
void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative bits/const/cast patterns plus the C++ tree-flattening and
  // constant-run-merging patterns defined above.
  results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
                 patterns::CatCast, FlattenCat, CatOfConstant>(context);
}
1524
1525//===----------------------------------------------------------------------===//
1526// StringConcatOp
1527//===----------------------------------------------------------------------===//
1528
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1530 // Fold single-operand concat to just the operand.
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1533
1534 // Check if all operands are constant strings before accumulating.
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1537 }))
1538 return {};
1539
1540 // All operands are constant strings, concatenate them.
1541 SmallString<64> result;
1542 for (auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1544
1545 return StringAttr::get(getContext(), result);
1546}
1547
1548namespace {
1549/// Flatten nested string.concat operations into a single concat.
1550/// string.concat(a, string.concat(b, c), d) -> string.concat(a, b, c, d)
1551class FlattenStringConcat : public mlir::OpRewritePattern<StringConcatOp> {
1552public:
1553 using OpRewritePattern::OpRewritePattern;
1554
1555 LogicalResult
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter) const override {
1558
1559 // Check if any operands are nested concats with a single use. Only inline
1560 // single-use nested concats to avoid fighting with DCE.
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1564 });
1565
1566 if (!hasNestedConcat)
1567 return failure();
1568
1569 // Flatten nested concats that have a single use.
1570 SmallVector<Value> flatOperands;
1571 for (auto input : concat.getInputs()) {
1572 if (auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1575 else
1576 flatOperands.push_back(input);
1577 }
1578
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
1581 return success();
1582 }
1583};
1584
/// Merge consecutive constant strings in a concat and remove empty strings.
/// string.concat("a", "b", x, "", "c", "d") -> string.concat("ab", x, "cd")
class MergeAdjacentStringConstants
    : public mlir::OpRewritePattern<StringConcatOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(StringConcatOp concat,
                  mlir::PatternRewriter &rewriter) const override {

    SmallVector<Value> newOperands;
    // Text of the pending run of adjacent constant operands.
    SmallString<64> accumulatedLit;
    // The constant ops making up that pending run.
    SmallVector<StringConstantOp> accumulatedOps;
    bool changed = false;

    // Emit the pending run of constants as a single operand and reset state.
    auto flushLiterals = [&]() {
      if (accumulatedOps.empty())
        return;

      // If only one literal, reuse it.
      if (accumulatedOps.size() == 1) {
        newOperands.push_back(accumulatedOps[0]);
      } else {
        // Multiple literals - merge them.
        auto newLit = rewriter.createOrFold<StringConstantOp>(
            concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
        newOperands.push_back(newLit);
        changed = true;
      }
      accumulatedLit.clear();
      accumulatedOps.clear();
    };

    for (auto operand : concat.getInputs()) {
      if (auto litOp = operand.getDefiningOp<StringConstantOp>()) {
        // Skip empty strings.
        if (litOp.getValue().empty()) {
          changed = true;
          continue;
        }
        accumulatedLit += litOp.getValue();
        accumulatedOps.push_back(litOp);
      } else {
        // A non-constant operand terminates the current constant run.
        flushLiterals();
        newOperands.push_back(operand);
      }
    }

    // Flush any remaining literals.
    flushLiterals();

    if (!changed)
      return failure();

    // If no operands remain, replace with empty string.
    if (newOperands.empty())
      return rewriter.replaceOpWithNewOp<StringConstantOp>(
                 concat, StringAttr::get(getContext(), "")),
             success();

    // Single-operand case is handled by the folder.
    rewriter.modifyOpInPlace(concat,
                             [&]() { concat->setOperands(newOperands); });
    return success();
  }
};
1652
1653} // namespace
1654
void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                 MLIRContext *context) {
  // Register the two C++ patterns defined above.
  results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(context);
}
1659
1660//===----------------------------------------------------------------------===//
1661// PropEqOp
1662//===----------------------------------------------------------------------===//
1663
1664OpFoldResult PropEqOp::fold(FoldAdaptor adaptor) {
1665 auto lhsAttr = adaptor.getLhs();
1666 auto rhsAttr = adaptor.getRhs();
1667 if (!lhsAttr || !rhsAttr)
1668 return {};
1669
1670 return BoolAttr::get(getContext(), lhsAttr == rhsAttr);
1671}
1672
1673OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1674 auto op = (*this);
1675 // BitCast is redundant if input and result types are same.
1676 if (op.getType() == op.getInput().getType())
1677 return op.getInput();
1678
1679 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1680 // final result type.
1681 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1682 if (op.getType() == in.getInput().getType())
1683 return in.getInput();
1684
1685 return {};
1686}
1687
1688OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1689 IntType inputType = getInput().getType();
1690 IntType resultType = getType();
1691 // If we are extracting the entire input, then return it.
1692 if (inputType == getType() && resultType.hasWidth())
1693 return getInput();
1694
1695 // Constant fold.
1696 if (hasKnownWidthIntTypes(*this))
1697 if (auto cst = getConstant(adaptor.getInput()))
1698 return getIntAttr(resultType,
1699 cst->extractBits(getHi() - getLo() + 1, getLo()));
1700
1701 return {};
1702}
1703
/// Rewrite bits(cat(...)) to a bits() of the single cat operand that fully
/// contains the extracted range.
struct BitsOfCat : public mlir::OpRewritePattern<BitsPrimOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(BitsPrimOp bits,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
    if (!cat)
      return failure();
    // Low bit of the extract, tracked relative to the operand being scanned.
    int32_t bitPos = bits.getLo();
    auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
    if (resultWidth < 0)
      return failure();
    // The cat's last operand holds the least-significant bits, so walk the
    // operands in reverse.
    for (auto operand : llvm::reverse(cat.getInputs())) {
      auto operandWidth =
          type_cast<IntType>(operand.getType()).getWidthOrSentinel();
      if (operandWidth < 0)
        return failure();
      if (bitPos < operandWidth) {
        // The range starts in this operand; rewrite only if it also ends
        // here (ranges spanning operands are not handled).
        if (bitPos + resultWidth <= operandWidth) {
          auto newBits = rewriter.createOrFold<BitsPrimOp>(
              bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
          replaceOpAndCopyName(rewriter, bits, newBits);
          return success();
        }
        return failure();
      }
      // Skip past this operand's bits.
      bitPos -= operandWidth;
    }
    return failure();
  }
};
1736
void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative bits-of-X patterns plus the C++ BitsOfCat pattern above.
  results
      .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
              patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
}
1743
/// Replace the specified operation with a 'bits' op from the specified hi/lo
/// bits. Insert a cast to handle the case where the original operation
/// returned a signed integer.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only slice when the widths actually differ.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);

  // Restore the result's signedness if the (possibly sliced) value differs.
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1761
/// Shared fold logic for two-input muxes (used by MuxPrimOp and
/// Mux2CellIntrinsicOp below). Each fold checks that operand and result types
/// match exactly, since a fold cannot insert padding operations.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1809
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  // Delegate to the shared two-input mux folding logic.
  return foldMux(*this, adaptor);
}
1813
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Same folds as MuxPrimOp via the shared helper.
  return foldMux(*this, adaptor);
}
1817
// No folds are implemented for the 4-way mux intrinsic.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1819
1820namespace {
1821
// If the mux has a known output width, pad the operands up to this width.
// Most folds on mux require that folded operands are of the same width as
// the mux itself.
class MuxPad : public mlir::OpRewritePattern<MuxPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult
  matchAndRewrite(MuxPrimOp mux,
                  mlir::PatternRewriter &rewriter) const override {
    // Need a known result width to pad to.
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Pad an operand to the mux width; returns it unchanged if its width is
    // unknown or already matches.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
                               width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Nothing changed; avoid an infinite rewrite loop.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1857
1858// Find muxes which have conditions dominated by other muxes with the same
1859// condition.
class MuxSharedCond : public mlir::OpRewritePattern<MuxPrimOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  // Maximum recursion depth when walking a chain of dependent muxes.
  static const int depthLimit = 5;

  // Rewrite `mux` to use the given high/low operands. If `updateInPlace` is
  // set, the operation is mutated directly and a null Value is returned;
  // otherwise a fresh mux is created right after the original and returned.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
                             ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  // Returns a replacement value for `op`, or null if nothing simplified.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same condition, known true: the mux always produces its high operand.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // Only mutate in place while every mux on the path has a single user;
    // otherwise other users would observe the rewritten operands.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  // Mirror image of tryCondTrue: a mux guarded by the same (false) condition
  // always produces its low operand.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(MuxPrimOp mux,
                  mlir::PatternRewriter &rewriter) const override {
    // Only handle muxes with a known result width.
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Within the high arm the select is known true; within the low arm it is
    // known false. Simplify nested muxes that share the same select.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1947} // namespace
1948
1949void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1950 MLIRContext *context) {
1951 results
1952 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1953 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1954 patterns::MuxSameTrue, patterns::MuxSameFalse,
1955 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1956 context);
1957}
1958
1959void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1960 RewritePatternSet &results, MLIRContext *context) {
1961 results.add<patterns::Mux2PadSel>(context);
1962}
1963
1964void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1965 RewritePatternSet &results, MLIRContext *context) {
1966 results.add<patterns::Mux4PadSel>(context);
1967}
1968
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    // The destination width must also be known to materialize the constant.
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // Signed constants sign-extend. The extra getBitWidth() check routes a
    // zero-width constant (which has no sign bit to copy) to zext instead.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    // Unsigned (or zero-width) constants zero-extend.
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1995
1996OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1997 auto input = this->getInput();
1998 IntType inputType = input.getType();
1999 int shiftAmount = getAmount();
2000
2001 // shl(x, 0) -> x
2002 if (shiftAmount == 0)
2003 return input;
2004
2005 // Constant fold.
2006 if (auto cst = getConstant(adaptor.getInput())) {
2007 auto inputWidth = inputType.getWidthOrSentinel();
2008 if (inputWidth != -1) {
2009 auto resultWidth = inputWidth + shiftAmount;
2010 shiftAmount = std::min(shiftAmount, resultWidth);
2011 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
2012 }
2013 }
2014 return {};
2015}
2016
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // All remaining folds need a known input width.
  if (inputWidth == -1)
    return {};
  // A zero-width input always shifts to zero.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // Arithmetic shift for signed values (clamped so the sign bit survives),
    // logical shift for unsigned values.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // Result width never drops below one bit.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
2051
2052LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2053 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2054 if (inputWidth <= 0)
2055 return failure();
2056
2057 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2058 unsigned shiftAmount = op.getAmount();
2059 if (int(shiftAmount) >= inputWidth) {
2060 // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
2061 if (op.getType().base().isUnsigned())
2062 return failure();
2063
2064 // Shifting a signed value by the full width is actually taking the
2065 // sign bit. If the shift amount is greater than the input width, it
2066 // is equivalent to shifting by the input width.
2067 shiftAmount = inputWidth - 1;
2068 }
2069
2070 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
2071 return success();
2072}
2073
2074LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2075 PatternRewriter &rewriter) {
2076 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2077 if (inputWidth <= 0)
2078 return failure();
2079
2080 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2081 unsigned keepAmount = op.getAmount();
2082 if (keepAmount)
2083 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2084 rewriter);
2085 return success();
2086}
2087
2088OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2089 if (hasKnownWidthIntTypes(*this))
2090 if (auto cst = getConstant(adaptor.getInput())) {
2091 int shiftAmount =
2092 getInput().getType().base().getWidthOrSentinel() - getAmount();
2093 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
2094 }
2095
2096 return {};
2097}
2098
2099OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2100 if (hasKnownWidthIntTypes(*this))
2101 if (auto cst = getConstant(adaptor.getInput()))
2102 return getIntAttr(getType(),
2103 cst->trunc(getType().base().getWidthOrSentinel()));
2104 return {};
2105}
2106
2107LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2108 PatternRewriter &rewriter) {
2109 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2110 if (inputWidth <= 0)
2111 return failure();
2112
2113 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2114 unsigned dropAmount = op.getAmount();
2115 if (dropAmount != unsigned(inputWidth))
2116 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
2117 rewriter);
2118 return success();
2119}
2120
2121void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
2122 MLIRContext *context) {
2123 results.add<patterns::SubaccessOfConstant>(context);
2124}
2125
2126OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2127 // If there is only one input, just return it.
2128 if (adaptor.getInputs().size() == 1)
2129 return getOperand(1);
2130
2131 if (auto constIndex = getConstant(adaptor.getIndex())) {
2132 auto index = constIndex->getZExtValue();
2133 if (index < getInputs().size())
2134 return getInputs()[getInputs().size() - 1 - index];
2135 }
2136
2137 return {};
2138}
2139
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it is costly to look through all inputs
  // so it is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements (the index can never select them).
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multibit_mux idx, a[n-1], a[n-2],
  // ..., a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every input must be a subindex of the same vector, in exactly the
    // descending order the mux expects.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2191
2192//===----------------------------------------------------------------------===//
2193// Declarations
2194//===----------------------------------------------------------------------===//
2195
2196/// Scan all the uses of the specified value, checking to see if there is
2197/// exactly one connect that has the value as its destination. This returns the
2198/// operation if found and if all the other users are "reads" from the value.
2199/// Returns null if there are no connects, or multiple connects to the value, or
2200/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2201///
2202/// Note that this will simply return the connect, which is located *anywhere*
2203/// after the definition of the value. Users of this function are likely
2204/// interested in the source side of the returned connect, the definition of
2205/// which does likely not dominate the original value.
2206MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
2207 MatchingConnectOp connect;
2208 for (Operation *user : value.getUsers()) {
2209 // If we see an attach or aggregate sublements, just conservatively fail.
2210 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2211 return {};
2212
2213 if (auto aConnect = dyn_cast<FConnectLike>(user))
2214 if (aConnect.getDest() == value) {
2215 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2216 // If this is not a matching connect, a second matching connect or in a
2217 // different block, fail.
2218 if (!matchingConnect || (connect && connect != matchingConnect) ||
2219 matchingConnect->getBlock() != value.getParentBlock())
2220 return {};
2221 connect = matchingConnect;
2222 }
2223 }
2224 return connect;
2225}
2226
2227// Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // The declaration must not be externally observable: no symbols, no
  // annotations, a droppable name, and not forceable.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
2285
2286void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2287 MLIRContext *context) {
2288 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2289 context);
2290}
2291
2292LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2293 PatternRewriter &rewriter) {
2294 // TODO: Canonicalize towards explicit extensions and flips here.
2295
2296 // If there is a simple value connected to a foldable decl like a wire or reg,
2297 // see if we can eliminate the decl.
2298 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2299 return success();
2300 return failure();
2301}
2302
2303//===----------------------------------------------------------------------===//
2304// Statements
2305//===----------------------------------------------------------------------===//
2306
/// If the specified value has an AttachOp user that strictly dominates
/// `dominatedAttach` (i.e. appears earlier in the same block), return it.
2309static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
2310 for (auto *user : value.getUsers()) {
2311 auto attach = dyn_cast<AttachOp>(user);
2312 if (!attach || attach == dominatedAttach)
2313 continue;
2314 if (attach->isBeforeInBlock(dominatedAttach))
2315 return attach;
2316 }
2317 return {};
2318}
2319
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //   attach x, y
    //   ...
    //   attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the wire, then delete both.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire itself.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2365
2366/// Replaces the given op with the contents of the given single-block region.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  // Splice the block's contents directly before `op`; the caller remains
  // responsible for erasing `op` itself.
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
2372
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // A constant condition selects exactly one region: inline that region (if
  // it exists and is non-empty) and erase the when.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2405
2406namespace {
2407// Remove private nodes. If they have an interesting names, move the name to
2408// the source expression.
struct FoldNodeName : public mlir::OpRewritePattern<NodeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NodeOp node,
                                PatternRewriter &rewriter) const override {
    auto name = node.getNameAttr();
    // Only remove nodes with no observable state: droppable name, no inner
    // symbol, no annotations, and not forceable.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).empty() || node.isForceable())
      return failure();
    // Try to move the node's name onto the operation producing its input.
    auto *newOp = node.getInput().getDefiningOp();
    if (newOp)
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
2424
2425// Bypass nodes.
struct NodeBypass : public mlir::OpRewritePattern<NodeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(NodeOp node,
                                PatternRewriter &rewriter) const override {
    // Skip nodes carrying symbols or annotations, dead nodes, and forceable
    // nodes.
    if (node.getInnerSym() || !AnnotationSet(node).empty() ||
        node.use_empty() || node.isForceable())
      return failure();
    // Route all readers directly to the node's input. NOTE(review): the node
    // itself is left in place, now unused — presumably kept for its name and
    // cleaned up by other folds; confirm before relying on its removal.
    rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
    return success();
  }
};
2437
2438} // namespace
2439
// Demote a forceable declaration to a plain one when its data reference
// result has no uses. Works for any op type exposing the forceable API.
template <typename OpTy>
static LogicalResult demoteForceableIfUnused(OpTy op,
                                             PatternRewriter &rewriter) {
  if (!op.isForceable() || !op.getDataRef().use_empty())
    return failure();

  // Recreate the op with forceability stripped.
  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
  return success();
}
2449
2450// Interesting names and symbols and don't touch force nodes to stick around.
2451LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2452 SmallVectorImpl<OpFoldResult> &results) {
2453 if (!hasDroppableName())
2454 return failure();
2455 if (hasDontTouch(getResult())) // handles inner symbols
2456 return failure();
2457 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
2458 return failure();
2459 if (isForceable())
2460 return failure();
2461 if (!adaptor.getInput())
2462 return failure();
2463
2464 results.push_back(adaptor.getInput());
2465 return success();
2466}
2467
2468void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2469 MLIRContext *context) {
2470 results.insert<FoldNodeName>(context);
2471 results.add(demoteForceableIfUnused<NodeOp>);
2472}
2473
2474namespace {
2475// For a lhs, find all the writers of fields of the aggregate type. If there
2476// is one writer for each field, merge the writes
struct AggOneShot : public mlir::RewritePattern {
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Collect the value written to each field of `lhs`'s aggregate result.
  // Returns one value per field, in field order, if every field is written
  // exactly once by a matching connect in the same region; otherwise returns
  // an empty vector.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    // Only bundle and vector results have fields to merge.
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Map from field/element index to the value written to it.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      // All accesses must happen in the same region as the declaration.
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
        // A whole-aggregate write conflicts with the per-field writes.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subfield defeats the transform.
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subindex defeats the transform.
          return {};
        }
      } else {
        // Any other user defeats the transform.
        return {};
      }
    }

    // Require a write for every field and return them in field order.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    // Build the merged write at the end of the block so every source value
    // defined in the block dominates it.
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot matchingconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Create the aggregate value from the per-field values and write it with
    // a single connect.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
2575
// Concrete AggOneShot instantiations, one per op kind whose aggregate result
// may be written field-by-field.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2588} // namespace
2589
2590void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2591 MLIRContext *context) {
2592 results.insert<WireAggOneShot>(context);
2593 results.add(demoteForceableIfUnused<WireOp>);
2594}
2595
2596void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2597 MLIRContext *context) {
2598 results.insert<SubindexAggOneShot>(context);
2599}
2600
2601OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2602 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2603 if (!attr)
2604 return {};
2605 return attr[getIndex()];
2606}
2607
2608OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2609 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2610 if (!attr)
2611 return {};
2612 auto index = getFieldIndex();
2613 return attr[index];
2614}
2615
2616void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2617 MLIRContext *context) {
2618 results.insert<SubfieldAggOneShot>(context);
2619}
2620
2621static Attribute collectFields(MLIRContext *context,
2622 ArrayRef<Attribute> operands) {
2623 for (auto operand : operands)
2624 if (!operand)
2625 return {};
2626 return ArrayAttr::get(context, operands);
2627}
2628
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  // I.e. if every operand is the corresponding field of a single bundle of
  // the same type, in order, the create is the identity.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all operands are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2647
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  // I.e. if every operand is the corresponding element of a single vector of
  // the same type, in order, the create is the identity.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all operands are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2665
2666OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2667 if (getOperand().getType() == getType())
2668 return getOperand();
2669 return {};
2670}
2671
2672namespace {
2673// A register with constant reset and all connection to either itself or the
2674// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::OpRewritePattern<RegResetOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(RegResetOp reg,
                                PatternRewriter &rewriter) const override {
    // The reset value must be a constant, and the register must carry no
    // symbols, annotations, or forceability so it can be deleted.
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The connected value must be mux(cond, constant, reg) or
    // mux(cond, reg, constant).
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    // If the constant is in the high arm, the low arm must be the register
    // itself (and vice versa).
    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must be identical to the reset value.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2725} // namespace
2726
2727static bool isDefinedByOneConstantOp(Value v) {
2728 if (auto c = v.getDefiningOp<ConstantOp>())
2729 return c.getValue().isOne();
2730 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2731 return sc.getValue();
2732 return false;
2733}
2734
2735static LogicalResult
2736canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
2737 if (!isDefinedByOneConstantOp(reg.getResetSignal()))
2738 return failure();
2739
2740 auto resetValue = reg.getResetValue();
2741 if (reg.getType(0) != resetValue.getType())
2742 return failure();
2743
2744 // Ignore 'passthrough'.
2745 (void)dropWrite(rewriter, reg->getResult(0), {});
2746 replaceOpWithNewOpAndCopyName<NodeOp>(
2747 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2748 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2749 return success();
2750}
2751
2752void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2753 MLIRContext *context) {
2754 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2756 results.add(demoteForceableIfUnused<RegResetOp>);
2757}
2758
2759// Returns the value connected to a port, if there is only one.
2760static Value getPortFieldValue(Value port, StringRef name) {
2761 auto portTy = type_cast<BundleType>(port.getType());
2762 auto fieldIndex = portTy.getElementIndex(name);
2763 assert(fieldIndex && "missing field on memory port");
2764
2765 Value value = {};
2766 for (auto *op : port.getUsers()) {
2767 auto portAccess = cast<SubfieldOp>(op);
2768 if (fieldIndex != portAccess.getFieldIndex())
2769 continue;
2770 auto conn = getSingleConnectUserOf(portAccess);
2771 if (!conn || value)
2772 return {};
2773 value = conn.getSrc();
2774 }
2775 return value;
2776}
2777
2778// Returns true if the enable field of a port is set to false.
2779static bool isPortDisabled(Value port) {
2780 auto value = getPortFieldValue(port, "en");
2781 if (!value)
2782 return false;
2783 auto portConst = value.getDefiningOp<ConstantOp>();
2784 if (!portConst)
2785 return false;
2786 return portConst.getValue().isZero();
2787}
2788
2789// Returns true if the data output is unused.
2790static bool isPortUnused(Value port, StringRef data) {
2791 auto portTy = type_cast<BundleType>(port.getType());
2792 auto fieldIndex = portTy.getElementIndex(data);
2793 assert(fieldIndex && "missing enable flag on memory port");
2794
2795 for (auto *op : port.getUsers()) {
2796 auto portAccess = cast<SubfieldOp>(op);
2797 if (fieldIndex != portAccess.getFieldIndex())
2798 continue;
2799 if (!portAccess.use_empty())
2800 return false;
2801 }
2802
2803 return true;
2804}
2805
// Replaces all accesses to the named field of a port with the given value,
// erasing the subfield accesses in the process.
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  // Early-increment range: accesses are erased while iterating the users.
  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2821
// Remove accesses to a port which is no longer used, replacing any remaining
// reads with never-written dummy registers.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers. Created lazily
  // so nothing is emitted unless a register is actually needed.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = SpecialConstantOp::create(rewriter, port.getLoc(),
                                        ClockType::get(rewriter.getContext()),
                                        false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire. If the port
  // is used in its entirety, replace it with a wire. Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      auto ty = port.getType();
      auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2874
namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();

    // Only memories with a known, zero-bit integer data type apply.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : mem.getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        // Each field access becomes a free-standing wire of the same type.
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports. The constant is a zero-width
          // APInt to match the zero-bit data type.
          auto zero = firrtl::ConstantOp::create(
              rewriter, wire.getLoc(),
              firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
          MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(mem);
    return success();
  }
};
2916
// If memory has no write ports and no file initialization, eliminate it.
// (A write-only memory is likewise removable, since its contents are never
// observed.)
struct FoldReadOrWriteOnlyMemory : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();
    // Classify the memory: bail as soon as it has both read and write ports,
    // or any debug/read-write port.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(mem);
    return success();
  }
};
2958
// Eliminate the dead ports of memories: ports that are disabled, or read
// ports whose data output is never used. The memory is rebuilt without them.
struct FoldUnusedPorts : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port is dead, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = MemOp::create(
          rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Replace the dead ports with dummy wires.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(mem);
    return success();
  }
};
3026
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Build the new result types: demote write-only read-write ports to
    // plain write ports, keep everything else unchanged.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = MemOp::create(
        rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
        mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
        rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
        mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
                                            newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Map the read-write port fields onto the write port fields.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(mem);
    return success();
  }
};
3116
// Eliminate the unused bits of memories: when all reads go through bit
// selects, compact the data type down to only the bit ranges actually used,
// rewriting reads and writes accordingly.
struct FoldUnusedBits : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    llvm::SmallBitVector usedBits(width);
    // Maps the low bit of each read-side bit select to its position in the
    // compacted data; entries are reserved here and computed further below.
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If every bit ends up used, there is nothing to compact.
          if (usedBits.all())
            return failure();

          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      // Record the contiguous [i, e-1] range of used bits.
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(mem->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          SubfieldOp::create(rewriter, port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Slice out each used range of the original value, MSB-first.
      SmallVector<Value> slices;
      for (auto &[start, end] : llvm::reverse(ranges)) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        slices.push_back(slice);
      }

      Value catOfSlices =
          rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
3337
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::OpRewritePattern<MemOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(MemOp mem,
                                PatternRewriter &rewriter) const override {
    // Only depth-1 memories without don't-touch qualify.
    const FirMemory &info = mem.getSummary();
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect the subfield accesses of the listed fields and the connects
      // driving them; any non-connect user disqualifies the memory.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All write-capable ports must share a single, uniquely driven clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = WireOp::create(rewriter, loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    MatchingConnectOp::create(rewriter, loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
                                 rewriter.getStringAttr(regName))
                       .getResult();
        MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // The effective write enable is en & wmode.
        auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      SmallVector<Value> chunks;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
        chunks.push_back(chunk);
      }

      // Chunks were built LSB-first; concatenation expects MSB-first.
      std::reverse(chunks.begin(), chunks.end());
      masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
      next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
    }
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
} // namespace
3531
3532void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3533 MLIRContext *context) {
3534 results
3535 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3536 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3537 context);
3538}
3539
3540//===----------------------------------------------------------------------===//
3541// Declarations
3542//===----------------------------------------------------------------------===//
3543
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  // Non-constant case: demote the register to a regreset whose reset signal
  // is the mux select and whose reset value is the constant.
  if (!constReg) {
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Rewrite the connect to bypass the mux: drive the constant for a constant
  // register, otherwise the non-reset (low) value.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3614
3615LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3616 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3617 succeeded(foldHiddenReset(op, rewriter)))
3618 return success();
3619
3620 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3621 return success();
3622
3623 return failure();
3624}
3625
3626//===----------------------------------------------------------------------===//
3627// Verification Ops.
3628//===----------------------------------------------------------------------===//
3629
3630static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3631 Value enable,
3632 PatternRewriter &rewriter,
3633 bool eraseIfZero) {
3634 // If the verification op is never enabled, delete it.
3635 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3636 if (constant.getValue().isZero()) {
3637 rewriter.eraseOp(op);
3638 return success();
3639 }
3640 }
3641
3642 // If the verification op is never triggered, delete it.
3643 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3644 if (constant.getValue().isZero() == eraseIfZero) {
3645 rewriter.eraseOp(op);
3646 return success();
3647 }
3648 }
3649
3650 return failure();
3651}
3652
3653template <class Op, bool EraseIfZero = false>
3654static LogicalResult canonicalizeImmediateVerifOp(Op op,
3655 PatternRewriter &rewriter) {
3656 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3657 EraseIfZero);
3658}
3659
3660void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3661 MLIRContext *context) {
3662 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3663 results.add<patterns::AssertXWhenX>(context);
3664}
3665
3666void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3667 MLIRContext *context) {
3668 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3669 results.add<patterns::AssumeXWhenX>(context);
3670}
3671
3672void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3673 RewritePatternSet &results, MLIRContext *context) {
3674 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3675 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
3676}
3677
3678void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3679 MLIRContext *context) {
3680 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3681}
3682
3683//===----------------------------------------------------------------------===//
3684// InvalidValueOp
3685//===----------------------------------------------------------------------===//
3686
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  //
  // Three user shapes are accepted:
  //  1. Bit-selection/extraction, casts, subaccess, not, and bitcast.
  //  2. Cvt, but only when its operand is already signed (the unsigned form
  //     prepends a zero bit, which would compute on the invalid).
  //  3. And/Xor/Or reductions, but only on operands with a known positive
  //     width (zero-/unknown-width reductions are handled by folds instead).
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
            .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of the user's result type, then
    // erase both the user and the original invalid.
    auto *modop = *op->user_begin();
    auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
                                      modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3717
3718OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3719 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3720 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3721 return {};
3722}
3723
3724//===----------------------------------------------------------------------===//
3725// ClockGateIntrinsicOp
3726//===----------------------------------------------------------------------===//
3727
3728OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3729 // Forward the clock if one of the enables is always true.
3730 if (isConstantOne(adaptor.getEnable()) ||
3731 isConstantOne(adaptor.getTestEnable()))
3732 return getInput();
3733
3734 // Fold to a constant zero clock if the enables are always false.
3735 if (isConstantZero(adaptor.getEnable()) &&
3736 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3737 return BoolAttr::get(getContext(), false);
3738
3739 // Forward constant zero clocks.
3740 if (isConstantZero(adaptor.getInput()))
3741 return BoolAttr::get(getContext(), false);
3742
3743 return {};
3744}
3745
3746LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3747 PatternRewriter &rewriter) {
3748 // Remove constant false test enable.
3749 if (auto testEnable = op.getTestEnable()) {
3750 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3751 if (constOp.getValue().isZero()) {
3752 rewriter.modifyOpInPlace(op,
3753 [&] { op.getTestEnableMutable().clear(); });
3754 return success();
3755 }
3756 }
3757 }
3758
3759 return failure();
3760}
3761
3762//===----------------------------------------------------------------------===//
3763// Reference Ops.
3764//===----------------------------------------------------------------------===//
3765
3766// refresolve(forceable.ref) -> forceable.data
3767static LogicalResult
3768canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3769 auto forceable = op.getRef().getDefiningOp<Forceable>();
3770 if (!forceable || !forceable.isForceable() ||
3771 op.getRef() != forceable.getDataRef() ||
3772 op.getType() != forceable.getDataType())
3773 return failure();
3774 rewriter.replaceAllUsesWith(op, forceable.getData());
3775 return success();
3776}
3777
3778void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3779 MLIRContext *context) {
3780 results.insert<patterns::RefResolveOfRefSend>(context);
3781 results.insert(canonicalizeRefResolveOfForceable);
3782}
3783
3784OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3785 // RefCast is unnecessary if types match.
3786 if (getInput().getType() == getType())
3787 return getInput();
3788 return {};
3789}
3790
3791static bool isConstantZero(Value operand) {
3792 auto constOp = operand.getDefiningOp<ConstantOp>();
3793 return constOp && constOp.getValue().isZero();
3794}
3795
3796template <typename Op>
3797static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3798 if (isConstantZero(op.getPredicate())) {
3799 rewriter.eraseOp(op);
3800 return success();
3801 }
3802 return failure();
3803}
3804
3805void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3806 MLIRContext *context) {
3807 results.add(eraseIfPredFalse<RefForceOp>);
3808}
3809void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3810 MLIRContext *context) {
3811 results.add(eraseIfPredFalse<RefForceInitialOp>);
3812}
3813void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3814 MLIRContext *context) {
3815 results.add(eraseIfPredFalse<RefReleaseOp>);
3816}
3817void RefReleaseInitialOp::getCanonicalizationPatterns(
3818 RewritePatternSet &results, MLIRContext *context) {
3819 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3820}
3821
3822//===----------------------------------------------------------------------===//
3823// HasBeenResetIntrinsicOp
3824//===----------------------------------------------------------------------===//
3825
3826OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3827 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3828
3829 // Fold to zero if the reset is a constant. In this case the op is either
3830 // permanently in reset or never resets. Both mean that the reset never
3831 // finishes, so this op never returns true.
3832 if (adaptor.getReset())
3833 return getIntZerosAttr(UIntType::get(getContext(), 1));
3834
3835 // Fold to zero if the clock is a constant and the reset is synchronous. In
3836 // that case the reset will never be started.
3837 if (isUInt1(getReset().getType()) && adaptor.getClock())
3838 return getIntZerosAttr(UIntType::get(getContext(), 1));
3839
3840 return {};
3841}
3842
3843//===----------------------------------------------------------------------===//
3844// FPGAProbeIntrinsicOp
3845//===----------------------------------------------------------------------===//
3846
3847static bool isTypeEmpty(FIRRTLType type) {
3849 .Case<FVectorType>(
3850 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3851 .Case<BundleType>([&](auto ty) -> bool {
3852 for (auto elem : ty.getElements())
3853 if (!isTypeEmpty(elem.type))
3854 return false;
3855 return true;
3856 })
3857 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3858 .Default([](auto) -> bool { return false; });
3859}
3860
3861LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3862 PatternRewriter &rewriter) {
3863 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3864 if (!firrtlTy)
3865 return failure();
3866
3867 if (!isTypeEmpty(firrtlTy))
3868 return failure();
3869
3870 rewriter.eraseOp(op);
3871 return success();
3872}
3873
3874//===----------------------------------------------------------------------===//
3875// Layer Block Op
3876//===----------------------------------------------------------------------===//
3877
3878LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3879 PatternRewriter &rewriter) {
3880
3881 // If the layerblock is empty, erase it.
3882 if (op.getBody()->empty()) {
3883 rewriter.eraseOp(op);
3884 return success();
3885 }
3886
3887 return failure();
3888}
3889
3890//===----------------------------------------------------------------------===//
3891// Domain-related Ops
3892//===----------------------------------------------------------------------===//
3893
3894OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3895 // If no domains are specified, then forward the input to the result.
3896 if (getDomains().empty())
3897 return getInput();
3898
3899 return {};
3900}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override