CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28using namespace circt;
29using namespace firrtl;
30
31// Drop writes to old and pass through passthrough to make patterns easier to
32// write.
33static Value dropWrite(PatternRewriter &rewriter, OpResult old,
34 Value passthrough) {
35 SmallPtrSet<Operation *, 8> users;
36 for (auto *user : old.getUsers())
37 users.insert(user);
38 for (Operation *user : users)
39 if (auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
42 return passthrough;
43}
44
45// Return true if it is OK to propagate the name to the operation.
46// Non-pure operations such as instances, registers, and memories are not
47// allowed to update names for name stabilities and LEC reasons.
48static bool isOkToPropagateName(Operation *op) {
49 // Conservatively disallow operations with regions to prevent performance
50 // regression due to recursive calls to sub-regions in mlir::isPure.
51 if (op->getNumRegions() != 0)
52 return false;
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
54}
55
56// Move a name hint from a soon to be deleted operation to a new operation.
57// Pass through the new operation to make patterns easier to write. This cannot
58// move a name to a port (block argument), doing so would require rewriting all
59// instance sites as well as the module.
60static Value moveNameHint(OpResult old, Value passthrough) {
61 Operation *op = passthrough.getDefiningOp();
62 // This should handle ports, but it isn't clear we can change those in
63 // canonicalizers.
64 assert(op && "passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>("name");
67 if (name && !name.getValue().empty() && isOkToPropagateName(op))
68 op->setAttr("name", name);
69 return passthrough;
70}
71
72// Declarative canonicalization patterns
73namespace circt {
74namespace firrtl {
75namespace patterns {
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
77} // namespace patterns
78} // namespace firrtl
79} // namespace circt
80
81/// Return true if this operation's operands and results all have a known width.
82/// This only works for integer types.
83static bool hasKnownWidthIntTypes(Operation *op) {
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
86 return false;
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
89 return false;
90 return true;
91}
92
93/// Return true if this value is 1 bit UInt.
94static bool isUInt1(Type type) {
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
97 return false;
98 return true;
99}
100
101/// Set the name of an op based on the best of two names: The current name, and
102/// the name passed in.
103static void updateName(PatternRewriter &rewriter, Operation *op,
104 StringAttr name) {
105 // Should never rename InstanceOp
106 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
107 return;
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
109 auto newName = name.getValue(); // old name is interesting
110 auto newOpName = op->getAttrOfType<StringAttr>("name");
111 // new name might not be interesting
112 if (newOpName)
113 newName = chooseName(newOpName.getValue(), name.getValue());
114 // Only update if needed
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
118}
119
120/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
121/// If a replaced op has a "name" attribute, this function propagates the name
122/// to the new value.
123static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
124 Value newValue) {
125 if (auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>("name");
127 updateName(rewriter, newOp, name);
128 }
129 rewriter.replaceOp(op, newValue);
130}
131
132/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
133/// attribute. If a replaced op has a "name" attribute, this function propagates
134/// the name to the new value.
135template <typename OpTy, typename... Args>
136static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>("name");
139 auto newOp =
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
141 updateName(rewriter, newOp, name);
142 return newOp;
143}
144
145/// Return true if the name is droppable. Note that this is different from
146/// `isUselessName` because non-useless names may be also droppable.
148 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
150 return false;
151}
152
153/// Implicitly replace the operand to a constant folding operation with a const
154/// 0 in case the operand is non-constant but has a bit width 0, or if the
155/// operand is an invalid value.
156///
157/// This makes constant folding significantly easier, as we can simply pass the
158/// operands to an operation through this function to appropriately replace any
159/// zero-width dynamic values or invalid values with a constant of value 0.
160static std::optional<APSInt>
161getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
164
165 // We never support constant folding to unknown width values.
166 if (destWidth < 0)
167 return {};
168
169 // Extension signedness follows the operand sign.
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
171 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
172
173 // If the operand is zero bits, then we can return a zero of the result
174 // type.
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
178 return {};
179}
180
181/// Determine the value of a constant operand for the sake of constant folding.
182static std::optional<APSInt> getConstant(Attribute operand) {
183 if (!operand)
184 return {};
185 if (auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
189 return {};
190}
191
192/// Determine whether a constant operand is a zero value for the sake of
193/// constant folding. This considers `invalidvalue` to be zero.
194static bool isConstantZero(Attribute operand) {
195 if (auto cst = getConstant(operand))
196 return cst->isZero();
197 return false;
198}
199
200/// Determine whether a constant operand is a one value for the sake of constant
201/// folding.
202static bool isConstantOne(Attribute operand) {
203 if (auto cst = getConstant(operand))
204 return cst->isOne();
205 return false;
206}
207
/// This is the policy for folding, which depends on the sort of operator we're
/// processing. It selects the width at which constFoldFIRRTLBinaryOp evaluates
/// its operands.
enum class BinOpKind {
  /// Operands are extended (or truncated) to the result width.
  Normal,
  /// Operands are extended to the widest operand width (at least one bit); the
  /// destination type is always one bit.
  Compare,
  /// Operands are extended to the widest of both operand widths and the result
  /// width. The computed value is then truncated to the result width.
  DivideOrShift,
};
215
216/// Applies the constant folding function `calculate` to the given operands.
217///
218/// Sign or zero extends the operands appropriately to the bitwidth of the
219/// result type if \p useDstWidth is true, else to the larger of the two operand
220/// bit widths and depending on whether the operation is to be performed on
221/// signed or unsigned operands.
222static Attribute constFoldFIRRTLBinaryOp(
223 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
224 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
225 assert(operands.size() == 2 && "binary op takes two operands");
226
227 // We cannot fold something to an unknown width.
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
230 return {};
231
232 // Any binary op returning i0 is 0.
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
235
236 // Determine the operand widths. This is either dictated by the operand type,
237 // or if that type is an unsized integer, by the actual bits necessary to
238 // represent the constant value.
239 auto lhsWidth =
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
241 auto rhsWidth =
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
247
248 // Compares extend the operands to the widest of the operand types, not to the
249 // result type.
250 int32_t operandWidth;
251 switch (opKind) {
253 operandWidth = resultType.getWidthOrSentinel();
254 break;
256 // Compares compute with the widest operand, not at the destination type
257 // (which is always i1).
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
259 break;
261 operandWidth =
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
263 break;
264 }
265
266 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
267 if (!lhs)
268 return {};
269 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
270 if (!rhs)
271 return {};
272
273 APInt resultValue = calculate(*lhs, *rhs);
274
275 // If the result type is smaller than the computation then we need to
276 // narrow the constant after the calculation.
277 if (opKind == BinOpKind::DivideOrShift)
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
279
280 assert((unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
282 return getIntAttr(resultType, resultValue);
283}
284
285/// Applies the canonicalization function `canonicalize` to the given operation.
286///
287/// Determines which (if any) of the operation's operands are constants, and
288/// provides them as arguments to the callback function. Any `invalidvalue` in
289/// the input is mapped to a constant zero. The value returned from the callback
290/// is used as the replacement for `op`, and an additional pad operation is
291/// inserted if necessary. Does nothing if the result of `op` is of unknown
292/// width, in which case the necessity of a pad cannot be determined.
293static LogicalResult canonicalizePrimOp(
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
296 // Can only operate on FIRRTL primitive operations.
297 if (op->getNumResults() != 1)
298 return failure();
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
300 if (!type)
301 return failure();
302
303 // Can only operate on operations with a known result width.
304 auto width = type.getBitWidthOrSentinel();
305 if (width < 0)
306 return failure();
307
308 // Determine which of the operands are constants.
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (auto operand : op->getOperands()) {
312 Attribute attr;
313 if (auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
317 }
318
319 // Perform the canonicalization and materialize the result if it is a
320 // constant.
321 auto result = canonicalize(constOperands);
322 if (!result)
323 return failure();
324 Value resultValue;
325 if (auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
328 ->getResult(0);
329 else
330 resultValue = cast<Value>(result);
331
332 // Insert a pad if the type widths disagree.
333 if (width !=
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
336
337 // Insert a cast if this is a uint vs. sint or vice versa.
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
343
344 assert(type == resultValue.getType() && "canonicalization changed type");
345 replaceOpAndCopyName(rewriter, op, resultValue);
346 return success();
347}
348
349/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
350/// value if `bitWidth` is 0.
351static APInt getMaxUnsignedValue(unsigned bitWidth) {
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
353}
354
355/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
356/// value if `bitWidth` is 0.
357static APInt getMinSignedValue(unsigned bitWidth) {
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
359}
360
361/// Get the largest signed value of a given bit width. Returns a 1-bit zero
362/// value if `bitWidth` is 0.
363static APInt getMaxSignedValue(unsigned bitWidth) {
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
365}
366
367//===----------------------------------------------------------------------===//
368// Fold Hooks
369//===----------------------------------------------------------------------===//
370
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() && "constant has no operands");
373 return getValueAttr();
374}
375
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() && "constant has no operands");
378 return getValueAttr();
379}
380
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() && "constant has no operands");
383 return getFieldsAttr();
384}
385
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() && "constant has no operands");
388 return getValueAttr();
389}
390
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() && "constant has no operands");
393 return getValueAttr();
394}
395
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() && "constant has no operands");
398 return getValueAttr();
399}
400
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() && "constant has no operands");
403 return getValueAttr();
404}
405
406//===----------------------------------------------------------------------===//
407// Binary Operators
408//===----------------------------------------------------------------------===//
409
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
412 *this, adaptor.getOperands(), BinOpKind::Normal,
413 [=](const APSInt &a, const APSInt &b) { return a + b; });
414}
415
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(context);
420}
421
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
424 *this, adaptor.getOperands(), BinOpKind::Normal,
425 [=](const APSInt &a, const APSInt &b) { return a - b; });
426}
427
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 MLIRContext *context) {
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(context);
433}
434
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
436 // mul(x, 0) -> 0
437 //
438 // This is legal because it aligns with the Scala FIRRTL Compiler
439 // interpretation of lowering invalid to constant zero before constant
440 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
441 // multiplication this way and will emit "x * 0".
442 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
443 return getIntZerosAttr(getType());
444
446 *this, adaptor.getOperands(), BinOpKind::Normal,
447 [=](const APSInt &a, const APSInt &b) { return a * b; });
448}
449
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
451 /// div(x, x) -> 1
452 ///
453 /// Division by zero is undefined in the FIRRTL specification. This fold
454 /// exploits that fact to optimize self division to one. Note: this should
455 /// supersede any division with invalid or zero. Division of invalid by
456 /// invalid should be one.
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
459 if (width == -1)
460 width = 2;
461 // Only fold if we have at least 1 bit of width to represent the `1` value.
462 if (width != 0)
463 return getIntAttr(getType(), APInt(width, 1));
464 }
465
466 // div(0, x) -> 0
467 //
468 // This is legal because it aligns with the Scala FIRRTL Compiler
469 // interpretation of lowering invalid to constant zero before constant
470 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
471 // division this way and will emit "0 / x".
472 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
473 return getIntZerosAttr(getType());
474
475 /// div(x, 1) -> x : (uint, uint) -> uint
476 ///
477 /// UInt division by one returns the numerator. SInt division can't
478 /// be folded here because it increases the return type bitwidth by
479 /// one and requires sign extension (a new op).
480 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
482 return getLhs();
483
485 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
486 [=](const APSInt &a, const APSInt &b) -> APInt {
487 if (!!b)
488 return a / b;
489 return APInt(a.getBitWidth(), 0);
490 });
491}
492
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
494 // rem(x, x) -> 0
495 //
496 // Division by zero is undefined in the FIRRTL specification. This fold
497 // exploits that fact to optimize self division remainder to zero. Note:
498 // this should supersede any division with invalid or zero. Remainder of
499 // division of invalid by invalid should be zero.
500 if (getLhs() == getRhs())
501 return getIntZerosAttr(getType());
502
503 // rem(0, x) -> 0
504 //
505 // This is legal because it aligns with the Scala FIRRTL Compiler
506 // interpretation of lowering invalid to constant zero before constant
507 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
508 // division this way and will emit "0 % x".
509 if (isConstantZero(adaptor.getLhs()))
510 return getIntZerosAttr(getType());
511
513 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
514 [=](const APSInt &a, const APSInt &b) -> APInt {
515 if (!!b)
516 return a % b;
517 return APInt(a.getBitWidth(), 0);
518 });
519}
520
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
523 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
525}
526
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
529 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
530 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
531}
532
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
535 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
536 [=](const APSInt &a, const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
538 : a.ashr(b);
539 });
540}
541
542// TODO: Move to DRR.
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
544 if (auto rhsCst = getConstant(adaptor.getRhs())) {
545 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
546 if (rhsCst->isZero())
547 return getIntZerosAttr(getType());
548
549 /// and(x, -1) -> x
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
552 return getLhs();
553 }
554
555 if (auto lhsCst = getConstant(adaptor.getLhs())) {
556 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
557 if (lhsCst->isZero())
558 return getIntZerosAttr(getType());
559
560 /// and(-1, x) -> x
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
563 return getRhs();
564 }
565
566 /// and(x, x) -> x
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
568 return getRhs();
569
571 *this, adaptor.getOperands(), BinOpKind::Normal,
572 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
573}
574
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
576 MLIRContext *context) {
577 results
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
581}
582
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
584 if (auto rhsCst = getConstant(adaptor.getRhs())) {
585 /// or(x, 0) -> x
586 if (rhsCst->isZero() && getLhs().getType() == getType())
587 return getLhs();
588
589 /// or(x, -1) -> -1
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
592 return getRhs();
593 }
594
595 if (auto lhsCst = getConstant(adaptor.getLhs())) {
596 /// or(0, x) -> x
597 if (lhsCst->isZero() && getRhs().getType() == getType())
598 return getRhs();
599
600 /// or(-1, x) -> -1
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
603 return getLhs();
604 }
605
606 /// or(x, x) -> x
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
608 return getRhs();
609
611 *this, adaptor.getOperands(), BinOpKind::Normal,
612 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
613}
614
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
616 MLIRContext *context) {
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
619 patterns::OrOrr>(context);
620}
621
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
623 /// xor(x, 0) -> x
624 if (auto rhsCst = getConstant(adaptor.getRhs()))
625 if (rhsCst->isZero() &&
626 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
627 return getLhs();
628
629 /// xor(x, 0) -> x
630 if (auto lhsCst = getConstant(adaptor.getLhs()))
631 if (lhsCst->isZero() &&
632 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
633 return getRhs();
634
635 /// xor(x, x) -> 0
636 if (getLhs() == getRhs())
637 return getIntAttr(
638 getType(),
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
640
642 *this, adaptor.getOperands(), BinOpKind::Normal,
643 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
644}
645
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
647 MLIRContext *context) {
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
650 context);
651}
652
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
654 MLIRContext *context) {
655 results.insert<patterns::LEQWithConstLHS>(context);
656}
657
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
660
661 // leq(x, x) -> 1
662 if (getLhs() == getRhs())
663 return getIntAttr(getType(), APInt(1, 1));
664
665 // Comparison against constant outside type bounds.
666 if (auto width = getLhs().getType().base().getWidth()) {
667 if (auto rhsCst = getConstant(adaptor.getRhs())) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
670
671 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
672 // This can never occur since const is unsigned and cannot be less than 0.
673
674 // leq(x, const) -> 0 where const < minValue of the signed type of x
675 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
676 .slt(getMinSignedValue(*width).sext(commonWidth)))
677 return getIntAttr(getType(), APInt(1, 0));
678
679 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
680 if (isUnsigned && rhsCst->zext(commonWidth)
681 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
682 return getIntAttr(getType(), APInt(1, 1));
683
684 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
685 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
686 .sge(getMaxSignedValue(*width).sext(commonWidth)))
687 return getIntAttr(getType(), APInt(1, 1));
688 }
689 }
690
692 *this, adaptor.getOperands(), BinOpKind::Compare,
693 [=](const APSInt &a, const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
695 });
696}
697
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
699 MLIRContext *context) {
700 results.insert<patterns::LTWithConstLHS>(context);
701}
702
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
705 bool isUnsigned = lhsType.isUnsigned();
706
707 // lt(x, x) -> 0
708 if (getLhs() == getRhs())
709 return getIntAttr(getType(), APInt(1, 0));
710
711 // lt(x, 0) -> 0 when x is unsigned
712 if (auto rhsCst = getConstant(adaptor.getRhs())) {
713 if (rhsCst->isZero() && lhsType.isUnsigned())
714 return getIntAttr(getType(), APInt(1, 0));
715 }
716
717 // Comparison against constant outside type bounds.
718 if (auto width = lhsType.getWidth()) {
719 if (auto rhsCst = getConstant(adaptor.getRhs())) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
722
723 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
724 // Handled explicitly above.
725
726 // lt(x, const) -> 0 where const <= minValue of the signed type of x
727 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
728 .sle(getMinSignedValue(*width).sext(commonWidth)))
729 return getIntAttr(getType(), APInt(1, 0));
730
731 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
732 if (isUnsigned && rhsCst->zext(commonWidth)
733 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
734 return getIntAttr(getType(), APInt(1, 1));
735
736 // lt(x, const) -> 1 where const > maxValue of the signed type of x
737 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
738 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
739 return getIntAttr(getType(), APInt(1, 1));
740 }
741 }
742
744 *this, adaptor.getOperands(), BinOpKind::Compare,
745 [=](const APSInt &a, const APSInt &b) -> APInt {
746 return APInt(1, a < b);
747 });
748}
749
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
751 MLIRContext *context) {
752 results.insert<patterns::GEQWithConstLHS>(context);
753}
754
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
757 bool isUnsigned = lhsType.isUnsigned();
758
759 // geq(x, x) -> 1
760 if (getLhs() == getRhs())
761 return getIntAttr(getType(), APInt(1, 1));
762
763 // geq(x, 0) -> 1 when x is unsigned
764 if (auto rhsCst = getConstant(adaptor.getRhs())) {
765 if (rhsCst->isZero() && isUnsigned)
766 return getIntAttr(getType(), APInt(1, 1));
767 }
768
769 // Comparison against constant outside type bounds.
770 if (auto width = lhsType.getWidth()) {
771 if (auto rhsCst = getConstant(adaptor.getRhs())) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
774
775 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
776 if (isUnsigned && rhsCst->zext(commonWidth)
777 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
778 return getIntAttr(getType(), APInt(1, 0));
779
780 // geq(x, const) -> 0 where const > maxValue of the signed type of x
781 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
782 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
783 return getIntAttr(getType(), APInt(1, 0));
784
785 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
786 // Handled explicitly above.
787
788 // geq(x, const) -> 1 where const <= minValue of the signed type of x
789 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
790 .sle(getMinSignedValue(*width).sext(commonWidth)))
791 return getIntAttr(getType(), APInt(1, 1));
792 }
793 }
794
796 *this, adaptor.getOperands(), BinOpKind::Compare,
797 [=](const APSInt &a, const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
799 });
800}
801
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
803 MLIRContext *context) {
804 results.insert<patterns::GTWithConstLHS>(context);
805}
806
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
809 bool isUnsigned = lhsType.isUnsigned();
810
811 // gt(x, x) -> 0
812 if (getLhs() == getRhs())
813 return getIntAttr(getType(), APInt(1, 0));
814
815 // Comparison against constant outside type bounds.
816 if (auto width = lhsType.getWidth()) {
817 if (auto rhsCst = getConstant(adaptor.getRhs())) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
820
821 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
822 if (isUnsigned && rhsCst->zext(commonWidth)
823 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
824 return getIntAttr(getType(), APInt(1, 0));
825
826 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
827 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
828 .sge(getMaxSignedValue(*width).sext(commonWidth)))
829 return getIntAttr(getType(), APInt(1, 0));
830
831 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
832 // This can never occur since const is unsigned and cannot be less than 0.
833
834 // gt(x, const) -> 1 where const < minValue of the signed type of x
835 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
836 .slt(getMinSignedValue(*width).sext(commonWidth)))
837 return getIntAttr(getType(), APInt(1, 1));
838 }
839 }
840
842 *this, adaptor.getOperands(), BinOpKind::Compare,
843 [=](const APSInt &a, const APSInt &b) -> APInt {
844 return APInt(1, a > b);
845 });
846}
847
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
849 // eq(x, x) -> 1
850 if (getLhs() == getRhs())
851 return getIntAttr(getType(), APInt(1, 1));
852
853 if (auto rhsCst = getConstant(adaptor.getRhs())) {
854 /// eq(x, 1) -> x when x is 1 bit.
855 /// TODO: Support SInt<1> on the LHS etc.
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
858 return getLhs();
859 }
860
862 *this, adaptor.getOperands(), BinOpKind::Compare,
863 [=](const APSInt &a, const APSInt &b) -> APInt {
864 return APInt(1, a == b);
865 });
866}
867
868LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
869 return canonicalizePrimOp(
870 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
871 if (auto rhsCst = getConstant(operands[1])) {
872 auto width = op.getLhs().getType().getBitWidthOrSentinel();
873
874 // eq(x, 0) -> not(x) when x is 1 bit.
875 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
876 op.getRhs().getType() == op.getType()) {
877 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
878 .getResult();
879 }
880
881 // eq(x, 0) -> not(orr(x)) when x is >1 bit
882 if (rhsCst->isZero() && width > 1) {
883 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
884 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
885 }
886
887 // eq(x, ~0) -> andr(x) when x is >1 bit
888 if (rhsCst->isAllOnes() && width > 1 &&
889 op.getLhs().getType() == op.getRhs().getType()) {
890 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
891 .getResult();
892 }
893 }
894 return {};
895 });
896}
897
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
899 // neq(x, x) -> 0
900 if (getLhs() == getRhs())
901 return getIntAttr(getType(), APInt(1, 0));
902
903 if (auto rhsCst = getConstant(adaptor.getRhs())) {
904 /// neq(x, 0) -> x when x is 1 bit.
905 /// TODO: Support SInt<1> on the LHS etc.
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
908 return getLhs();
909 }
910
912 *this, adaptor.getOperands(), BinOpKind::Compare,
913 [=](const APSInt &a, const APSInt &b) -> APInt {
914 return APInt(1, a != b);
915 });
916}
917
918LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
919 return canonicalizePrimOp(
920 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
921 if (auto rhsCst = getConstant(operands[1])) {
922 auto width = op.getLhs().getType().getBitWidthOrSentinel();
923
924 // neq(x, 1) -> not(x) when x is 1 bit
925 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
926 op.getRhs().getType() == op.getType()) {
927 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
928 .getResult();
929 }
930
931 // neq(x, 0) -> orr(x) when x is >1 bit
932 if (rhsCst->isZero() && width > 1) {
933 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
934 .getResult();
935 }
936
937 // neq(x, ~0) -> not(andr(x))) when x is >1 bit
938 if (rhsCst->isAllOnes() && width > 1 &&
939 op.getLhs().getType() == op.getRhs().getType()) {
940 auto andrOp =
941 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
942 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
943 }
944 }
945
946 return {};
947 });
948}
949
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
951 // TODO: implement constant folding, etc.
952 // Tracked in https://github.com/llvm/circt/issues/6696.
953 return {};
954}
955
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
957 // TODO: implement constant folding, etc.
958 // Tracked in https://github.com/llvm/circt/issues/6724.
959 return {};
960}
961
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
963 if (auto rhsCst = getConstant(adaptor.getRhs())) {
964 if (auto lhsCst = getConstant(adaptor.getLhs())) {
965
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
970 }
971
972 if (rhsCst->isZero())
973 return getLhs();
974 }
975
976 return {};
977}
978
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
980 if (auto rhsCst = getConstant(adaptor.getRhs())) {
981 // Constant folding
982 if (auto lhsCst = getConstant(adaptor.getLhs()))
983
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
988
989 // integer.shl(x, 0) -> x
990 if (rhsCst->isZero())
991 return getLhs();
992 }
993
994 return {};
995}
996
997//===----------------------------------------------------------------------===//
998// Unary Operators
999//===----------------------------------------------------------------------===//
1000
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1003 auto w = getBitWidth(base);
1004 if (w)
1005 return getIntAttr(getType(), APInt(32, *w));
1006 return {};
1007}
1008
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1010 // No constant can be 'x' by definition.
1011 if (auto cst = getConstant(adaptor.getArg()))
1012 return getIntAttr(getType(), APInt(1, 0));
1013 return {};
1014}
1015
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1017 // No effect.
1018 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1019 return getInput();
1020
1021 // Be careful to only fold the cast into the constant if the size is known.
1022 // Otherwise width inference may produce differently-sized constants if the
1023 // sign changes.
1024 if (getType().base().hasWidth())
1025 if (auto cst = getConstant(adaptor.getInput()))
1026 return getIntAttr(getType(), *cst);
1027
1028 return {};
1029}
1030
1031void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1032 MLIRContext *context) {
1033 results.insert<patterns::StoUtoS>(context);
1034}
1035
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1037 // No effect.
1038 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1039 return getInput();
1040
1041 // Be careful to only fold the cast into the constant if the size is known.
1042 // Otherwise width inference may produce differently-sized constants if the
1043 // sign changes.
1044 if (getType().base().hasWidth())
1045 if (auto cst = getConstant(adaptor.getInput()))
1046 return getIntAttr(getType(), *cst);
1047
1048 return {};
1049}
1050
1051void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1052 MLIRContext *context) {
1053 results.insert<patterns::UtoStoU>(context);
1054}
1055
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1057 // No effect.
1058 if (getInput().getType() == getType())
1059 return getInput();
1060
1061 // Constant fold.
1062 if (auto cst = getConstant(adaptor.getInput()))
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1064
1065 return {};
1066}
1067
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1069 if (auto cst = getConstant(adaptor.getInput()))
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1071 return {};
1072}
1073
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1075 // No effect.
1076 if (getInput().getType() == getType())
1077 return getInput();
1078
1079 // Constant fold.
1080 if (auto cst = getConstant(adaptor.getInput()))
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1082
1083 return {};
1084}
1085
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1087 if (!hasKnownWidthIntTypes(*this))
1088 return {};
1089
1090 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1091 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1092 getType().base().getWidthOrSentinel()))
1093 return getIntAttr(getType(), *cst);
1094
1095 return {};
1096}
1097
1098void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1099 MLIRContext *context) {
1100 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1101}
1102
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1104 if (!hasKnownWidthIntTypes(*this))
1105 return {};
1106
1107 // FIRRTL negate always adds a bit.
1108 // -x ---> 0-sext(x) or 0-zext(x)
1109 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1112
1113 return {};
1114}
1115
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1117 if (!hasKnownWidthIntTypes(*this))
1118 return {};
1119
1120 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1121 getType().base().getWidthOrSentinel()))
1122 return getIntAttr(getType(), ~*cst);
1123
1124 return {};
1125}
1126
1127void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1128 MLIRContext *context) {
1129 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1130 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1131 patterns::NotGt>(context);
1132}
1133
1134class ReductionCat : public mlir::RewritePattern {
1135public:
1136 ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
1137 : RewritePattern(opName, 0, context) {}
1138
1139 /// Handle a constant operand in the cat operation.
1140 /// Returns true if the entire reduction can be replaced with a constant.
1141 /// May add non-zero constants to the remaining operands list.
1142 virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1143 ConstantOp constantOp,
1144 SmallVectorImpl<Value> &remaining) const = 0;
1145
1146 /// Return the unit value for this reduction operation:
1147 virtual bool getIdentityValue() const = 0;
1148
1149 LogicalResult
1150 matchAndRewrite(Operation *op,
1151 mlir::PatternRewriter &rewriter) const override {
1152 // Check if the operand is a cat operation
1153 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1154 if (!catOp)
1155 return failure();
1156
1157 SmallVector<Value> nonConstantOperands;
1158
1159 // Process each operand of the cat operation
1160 for (auto operand : catOp.getInputs()) {
1161 if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1162 // Handle constant operands - may short-circuit the entire operation
1163 if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
1164 return success();
1165 } else {
1166 // Keep non-constant operands for further processing
1167 nonConstantOperands.push_back(operand);
1168 }
1169 }
1170
1171 // If no operands remain, replace with identity value
1172 if (nonConstantOperands.empty()) {
1173 replaceOpWithNewOpAndCopyName<ConstantOp>(
1174 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1175 APInt(1, getIdentityValue()));
1176 return success();
1177 }
1178
1179 // If only one operand remains, apply reduction directly to it
1180 if (nonConstantOperands.size() == 1) {
1181 rewriter.modifyOpInPlace(
1182 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1183 return success();
1184 }
1185
1186 // Multiple operands remain - optimize only when cat has a single use.
1187 if (catOp->hasOneUse() &&
1188 nonConstantOperands.size() < catOp->getNumOperands()) {
1189 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1190 nonConstantOperands);
1191 return success();
1192 }
1193 return failure();
1194 }
1195};
1196
1197class OrRCat : public ReductionCat {
1198public:
1199 OrRCat(MLIRContext *context)
1200 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1201 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1202 ConstantOp value,
1203 SmallVectorImpl<Value> &remaining) const override {
1204 if (value.getValue().isZero())
1205 return false;
1206
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1209 llvm::APInt(1, 1));
1210 return true;
1211 }
1212 bool getIdentityValue() const override { return false; }
1213};
1214
1215class AndRCat : public ReductionCat {
1216public:
1217 AndRCat(MLIRContext *context)
1218 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1219 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1220 ConstantOp value,
1221 SmallVectorImpl<Value> &remaining) const override {
1222 if (value.getValue().isAllOnes())
1223 return false;
1224
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1227 llvm::APInt(1, 0));
1228 return true;
1229 }
1230 bool getIdentityValue() const override { return true; }
1231};
1232
1233class XorRCat : public ReductionCat {
1234public:
1235 XorRCat(MLIRContext *context)
1236 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1237 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1238 ConstantOp value,
1239 SmallVectorImpl<Value> &remaining) const override {
1240 if (value.getValue().isZero())
1241 return false;
1242 remaining.push_back(value);
1243 return false;
1244 }
1245 bool getIdentityValue() const override { return false; }
1246};
1247
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1249 if (!hasKnownWidthIntTypes(*this))
1250 return {};
1251
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1253 return getIntAttr(getType(), APInt(1, 1));
1254
1255 // x == -1
1256 if (auto cst = getConstant(adaptor.getInput()))
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1258
1259 // one bit is identity. Only applies to UInt since we can't make a cast
1260 // here.
1261 if (isUInt1(getInput().getType()))
1262 return getInput();
1263
1264 return {};
1265}
1266
1267void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1268 MLIRContext *context) {
1269 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1270 patterns::AndRPadS, patterns::AndRCatAndR_left,
1271 patterns::AndRCatAndR_right, AndRCat>(context);
1272}
1273
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1275 if (!hasKnownWidthIntTypes(*this))
1276 return {};
1277
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1279 return getIntAttr(getType(), APInt(1, 0));
1280
1281 // x != 0
1282 if (auto cst = getConstant(adaptor.getInput()))
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1284
1285 // one bit is identity. Only applies to UInt since we can't make a cast
1286 // here.
1287 if (isUInt1(getInput().getType()))
1288 return getInput();
1289
1290 return {};
1291}
1292
1293void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1294 MLIRContext *context) {
1295 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1296 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
1297 context);
1298}
1299
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1301 if (!hasKnownWidthIntTypes(*this))
1302 return {};
1303
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1305 return getIntAttr(getType(), APInt(1, 0));
1306
1307 // popcount(x) & 1
1308 if (auto cst = getConstant(adaptor.getInput()))
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1310
1311 // one bit is identity. Only applies to UInt since we can't make a cast here.
1312 if (isUInt1(getInput().getType()))
1313 return getInput();
1314
1315 return {};
1316}
1317
1318void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1319 MLIRContext *context) {
1320 results
1321 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1322 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
1323 context);
1324}
1325
1326//===----------------------------------------------------------------------===//
1327// Other Operators
1328//===----------------------------------------------------------------------===//
1329
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1333
1334 // If no inputs, return 0-bit value
1335 if (inputs.empty())
1336 return getIntZerosAttr(getType());
1337
1338 // If single input and same type, return it
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1340 return inputs[0];
1341
1342 // Make sure it's safe to fold.
1343 if (!hasKnownWidthIntTypes(*this))
1344 return {};
1345
1346 // Filter out zero-width operands
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant = true;
1350 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1354 if (!attr)
1355 allConstant = false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1357 return {};
1358 }
1359 }
1360
1361 // If all inputs were zero-width, return 0-bit value
1362 if (nonZeroInputs.empty())
1363 return getIntZerosAttr(getType());
1364
1365 // If only one non-zero input and it has the same type as result, return it
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1368
1369 if (!hasKnownWidthIntTypes(*this))
1370 return {};
1371
1372 // Constant fold cat - concatenate all constant operands
1373 SmallVector<APInt> constants;
1374 for (auto inputAdaptor : inputAdaptors) {
1375 if (auto cst = getConstant(inputAdaptor))
1376 constants.push_back(*cst);
1377 else
1378 return {}; // Not all operands are constant
1379 }
1380
1381 assert(!constants.empty());
1382 // Concatenate all constants from left to right
1383 APInt result = constants[0];
1384 for (size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
1386
1387 return getIntAttr(getType(), result);
1388}
1389
1390void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1391 MLIRContext *context) {
1392 results.insert<patterns::DShlOfConstant>(context);
1393}
1394
1395void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1396 MLIRContext *context) {
1397 results.insert<patterns::DShrOfConstant>(context);
1398}
1399
1400namespace {
1401
1402/// Canonicalize a cat tree into single cat operation.
1403class FlattenCat : public mlir::OpRewritePattern<CatPrimOp> {
1404public:
1405 using OpRewritePattern::OpRewritePattern;
1406
1407 LogicalResult
1408 matchAndRewrite(CatPrimOp cat,
1409 mlir::PatternRewriter &rewriter) const override {
1410 if (!hasKnownWidthIntTypes(cat) ||
1411 cat.getType().getBitWidthOrSentinel() == 0)
1412 return failure();
1413
1414 // If this is not a root of concat tree, skip.
1415 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1416 return failure();
1417
1418 // Try flattening the cat tree.
1419 SmallVector<Value> operands;
1420 SmallVector<Value> worklist;
1421 auto pushOperands = [&worklist](CatPrimOp op) {
1422 for (auto operand : llvm::reverse(op.getInputs()))
1423 worklist.push_back(operand);
1424 };
1425 pushOperands(cat);
1426 bool hasSigned = false, hasUnsigned = false;
1427 while (!worklist.empty()) {
1428 auto value = worklist.pop_back_val();
1429 auto catOp = value.getDefiningOp<CatPrimOp>();
1430 if (!catOp) {
1431 operands.push_back(value);
1432 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
1433 continue;
1434 }
1435
1436 pushOperands(catOp);
1437 }
1438
1439 // Helper function to cast signed values to unsigned. CatPrimOp converts
1440 // signed values to unsigned so we need to do it explicitly here.
1441 auto castToUIntIfSigned = [&](Value value) -> Value {
1442 if (type_isa<UIntType>(value.getType()))
1443 return value;
1444 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1445 };
1446
1447 assert(operands.size() >= 1 && "zero width cast must be rejected");
1448
1449 if (operands.size() == 1) {
1450 rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
1451 return success();
1452 }
1453
1454 if (operands.size() == cat->getNumOperands())
1455 return failure();
1456
1457 // If types are mixed, cast all operands to unsigned.
1458 if (hasSigned && hasUnsigned)
1459 for (auto &operand : operands)
1460 operand = castToUIntIfSigned(operand);
1461
1462 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1463 operands);
1464 return success();
1465 }
1466};
1467
1468// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
1469class CatOfConstant : public mlir::OpRewritePattern<CatPrimOp> {
1470public:
1471 using OpRewritePattern::OpRewritePattern;
1472
1473 LogicalResult
1474 matchAndRewrite(CatPrimOp cat,
1475 mlir::PatternRewriter &rewriter) const override {
1476 if (!hasKnownWidthIntTypes(cat))
1477 return failure();
1478
1479 SmallVector<Value> operands;
1480
1481 for (size_t i = 0; i < cat->getNumOperands(); ++i) {
1482 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1483 if (!cst) {
1484 operands.push_back(cat.getInputs()[i]);
1485 continue;
1486 }
1487 APSInt value = cst.getValue();
1488 size_t j = i + 1;
1489 for (; j < cat->getNumOperands(); ++j) {
1490 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1491 if (!nextCst)
1492 break;
1493 value = value.concat(nextCst.getValue());
1494 }
1495
1496 if (j == i + 1) {
1497 // Not folded.
1498 operands.push_back(cst);
1499 } else {
1500 // Folded.
1501 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1502 }
1503
1504 i = j - 1;
1505 }
1506
1507 if (operands.size() == cat->getNumOperands())
1508 return failure();
1509
1510 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1511 operands);
1512
1513 return success();
1514 }
1515};
1516
1517} // namespace
1518
1519void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1520 MLIRContext *context) {
1521 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1522 patterns::CatCast, FlattenCat, CatOfConstant>(context);
1523}
1524
1525//===----------------------------------------------------------------------===//
1526// StringConcatOp
1527//===----------------------------------------------------------------------===//
1528
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1530 // Fold single-operand concat to just the operand.
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1533
1534 // Check if all operands are constant strings before accumulating.
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1537 }))
1538 return {};
1539
1540 // All operands are constant strings, concatenate them.
1541 SmallString<64> result;
1542 for (auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1544
1545 return StringAttr::get(getContext(), result);
1546}
1547
1548namespace {
1549/// Flatten nested string.concat operations into a single concat.
1550/// string.concat(a, string.concat(b, c), d) -> string.concat(a, b, c, d)
1551class FlattenStringConcat : public mlir::OpRewritePattern<StringConcatOp> {
1552public:
1553 using OpRewritePattern::OpRewritePattern;
1554
1555 LogicalResult
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter) const override {
1558
1559 // Check if any operands are nested concats with a single use. Only inline
1560 // single-use nested concats to avoid fighting with DCE.
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1564 });
1565
1566 if (!hasNestedConcat)
1567 return failure();
1568
1569 // Flatten nested concats that have a single use.
1570 SmallVector<Value> flatOperands;
1571 for (auto input : concat.getInputs()) {
1572 if (auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1575 else
1576 flatOperands.push_back(input);
1577 }
1578
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
1581 return success();
1582 }
1583};
1584
1585/// Merge consecutive constant strings in a concat and remove empty strings.
1586/// string.concat("a", "b", x, "", "c", "d") -> string.concat("ab", x, "cd")
1587class MergeAdjacentStringConstants
1588 : public mlir::OpRewritePattern<StringConcatOp> {
1589public:
1590 using OpRewritePattern::OpRewritePattern;
1591
1592 LogicalResult
1593 matchAndRewrite(StringConcatOp concat,
1594 mlir::PatternRewriter &rewriter) const override {
1595
1596 SmallVector<Value> newOperands;
1597 SmallString<64> accumulatedLit;
1598 SmallVector<StringConstantOp> accumulatedOps;
1599 bool changed = false;
1600
1601 auto flushLiterals = [&]() {
1602 if (accumulatedOps.empty())
1603 return;
1604
1605 // If only one literal, reuse it.
1606 if (accumulatedOps.size() == 1) {
1607 newOperands.push_back(accumulatedOps[0]);
1608 } else {
1609 // Multiple literals - merge them.
1610 auto newLit = rewriter.createOrFold<StringConstantOp>(
1611 concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
1612 newOperands.push_back(newLit);
1613 changed = true;
1614 }
1615 accumulatedLit.clear();
1616 accumulatedOps.clear();
1617 };
1618
1619 for (auto operand : concat.getInputs()) {
1620 if (auto litOp = operand.getDefiningOp<StringConstantOp>()) {
1621 // Skip empty strings.
1622 if (litOp.getValue().empty()) {
1623 changed = true;
1624 continue;
1625 }
1626 accumulatedLit += litOp.getValue();
1627 accumulatedOps.push_back(litOp);
1628 } else {
1629 flushLiterals();
1630 newOperands.push_back(operand);
1631 }
1632 }
1633
1634 // Flush any remaining literals.
1635 flushLiterals();
1636
1637 if (!changed)
1638 return failure();
1639
1640 // If no operands remain, replace with empty string.
1641 if (newOperands.empty())
1642 return rewriter.replaceOpWithNewOp<StringConstantOp>(
1643 concat, StringAttr::get(getContext(), "")),
1644 success();
1645
1646 // Single-operand case is handled by the folder.
1647 rewriter.modifyOpInPlace(concat,
1648 [&]() { concat->setOperands(newOperands); });
1649 return success();
1650 }
1651};
1652
1653} // namespace
1654
1655void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1656 MLIRContext *context) {
1657 results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(context);
1658}
1659
1660//===----------------------------------------------------------------------===//
1661// PropEqOp
1662//===----------------------------------------------------------------------===//
1663
1664OpFoldResult PropEqOp::fold(FoldAdaptor adaptor) {
1665 auto lhsAttr = adaptor.getLhs();
1666 auto rhsAttr = adaptor.getRhs();
1667 if (!lhsAttr || !rhsAttr)
1668 return {};
1669
1670 return BoolAttr::get(getContext(), lhsAttr == rhsAttr);
1671}
1672
1673//===----------------------------------------------------------------------===//
1674// Boolean Property Ops
1675//===----------------------------------------------------------------------===//
1676
1677// Helper to extract the bool value from a BoolAttr.
1678static std::optional<bool> getBoolValue(Attribute attr) {
1679 if (auto boolAttr = dyn_cast_or_null<BoolAttr>(attr))
1680 return boolAttr.getValue();
1681 return std::nullopt;
1682}
1683
1684OpFoldResult BoolAndOp::fold(FoldAdaptor adaptor) {
1685 auto lhs = getBoolValue(adaptor.getLhs());
1686 auto rhs = getBoolValue(adaptor.getRhs());
1687 if (lhs && rhs)
1688 return BoolAttr::get(getContext(), *lhs && *rhs);
1689 // AND with false is always false.
1690 if ((lhs && !*lhs) || (rhs && !*rhs))
1691 return BoolAttr::get(getContext(), false);
1692 // AND with true is identity.
1693 if (lhs && *lhs)
1694 return getRhs();
1695 if (rhs && *rhs)
1696 return getLhs();
1697 return {};
1698}
1699
1700OpFoldResult BoolOrOp::fold(FoldAdaptor adaptor) {
1701 auto lhs = getBoolValue(adaptor.getLhs());
1702 auto rhs = getBoolValue(adaptor.getRhs());
1703 if (lhs && rhs)
1704 return BoolAttr::get(getContext(), *lhs || *rhs);
1705 // OR with true is always true.
1706 if ((lhs && *lhs) || (rhs && *rhs))
1707 return BoolAttr::get(getContext(), true);
1708 // OR with false is identity.
1709 if (lhs && !*lhs)
1710 return getRhs();
1711 if (rhs && !*rhs)
1712 return getLhs();
1713 return {};
1714}
1715
1716OpFoldResult BoolXorOp::fold(FoldAdaptor adaptor) {
1717 auto lhs = getBoolValue(adaptor.getLhs());
1718 auto rhs = getBoolValue(adaptor.getRhs());
1719 if (lhs && rhs)
1720 return BoolAttr::get(getContext(), *lhs ^ *rhs);
1721 // XOR with false is identity.
1722 if (lhs && !*lhs)
1723 return getRhs();
1724 if (rhs && !*rhs)
1725 return getLhs();
1726 return {};
1727}
1728
1729OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1730 auto op = (*this);
1731 // BitCast is redundant if input and result types are same.
1732 if (op.getType() == op.getInput().getType())
1733 return op.getInput();
1734
1735 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1736 // final result type.
1737 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1738 if (op.getType() == in.getInput().getType())
1739 return in.getInput();
1740
1741 return {};
1742}
1743
1744OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1745 IntType inputType = getInput().getType();
1746 IntType resultType = getType();
1747 // If we are extracting the entire input, then return it.
1748 if (inputType == getType() && resultType.hasWidth())
1749 return getInput();
1750
1751 // Constant fold.
1752 if (hasKnownWidthIntTypes(*this))
1753 if (auto cst = getConstant(adaptor.getInput()))
1754 return getIntAttr(resultType,
1755 cst->extractBits(getHi() - getLo() + 1, getLo()));
1756
1757 return {};
1758}
1759
1760struct BitsOfCat : public mlir::OpRewritePattern<BitsPrimOp> {
1761 using OpRewritePattern::OpRewritePattern;
1762
1763 LogicalResult
1764 matchAndRewrite(BitsPrimOp bits,
1765 mlir::PatternRewriter &rewriter) const override {
1766 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1767 if (!cat)
1768 return failure();
1769 int32_t bitPos = bits.getLo();
1770 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1771 if (resultWidth < 0)
1772 return failure();
1773 for (auto operand : llvm::reverse(cat.getInputs())) {
1774 auto operandWidth =
1775 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1776 if (operandWidth < 0)
1777 return failure();
1778 if (bitPos < operandWidth) {
1779 if (bitPos + resultWidth <= operandWidth) {
1780 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1781 bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1782 replaceOpAndCopyName(rewriter, bits, newBits);
1783 return success();
1784 }
1785 return failure();
1786 }
1787 bitPos -= operandWidth;
1788 }
1789 return failure();
1790 }
1791};
1792
1793void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1794 MLIRContext *context) {
1795 results
1796 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1797 patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
1798}
1799
1800/// Replace the specified operation with a 'bits' op from the specified hi/lo
1801/// bits. Insert a cast to handle the case where the original operation
1802/// returned a signed integer.
1803static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1804 unsigned loBit, PatternRewriter &rewriter) {
1805 auto resType = type_cast<IntType>(op->getResult(0).getType());
1806 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1807 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1808
1809 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1810 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1811 } else if (resType.isUnsigned() &&
1812 !type_cast<IntType>(value.getType()).isUnsigned()) {
1813 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1814 }
1815 rewriter.replaceOp(op, value);
1816}
1817
/// Shared folder for the two-input mux ops (`MuxPrimOp` and
/// `Mux2CellIntrinsicOp`). Returns an existing SSA value or a constant
/// attribute when the mux can be folded away, or a null result otherwise.
///
/// A fold cannot insert casts or pads, so every fold that returns an existing
/// value first checks that the value's type exactly matches the mux's result
/// type.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst. The width check guards the APInt
      // equality comparison, which requires equal bit widths.
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1865
1866OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1867 return foldMux(*this, adaptor);
1868}
1869
1870OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1871 return foldMux(*this, adaptor);
1872}
1873
/// The four-way mux intrinsic has no folds yet; always returns a null result.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1875
1876namespace {
1877
1878// If the mux has a known output width, pad the operands up to this width.
1879// Most folds on mux require that folded operands are of the same width as
1880// the mux itself.
1881class MuxPad : public mlir::OpRewritePattern<MuxPrimOp> {
1882public:
1883 using OpRewritePattern::OpRewritePattern;
1884
1885 LogicalResult
1886 matchAndRewrite(MuxPrimOp mux,
1887 mlir::PatternRewriter &rewriter) const override {
1888 auto width = mux.getType().getBitWidthOrSentinel();
1889 if (width < 0)
1890 return failure();
1891
1892 auto pad = [&](Value input) -> Value {
1893 auto inputWidth =
1894 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1895 if (inputWidth < 0 || width == inputWidth)
1896 return input;
1897 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1898 width)
1899 .getResult();
1900 };
1901
1902 auto newHigh = pad(mux.getHigh());
1903 auto newLow = pad(mux.getLow());
1904 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1905 return failure();
1906
1907 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1908 rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1909 mux->getAttrs());
1910 return success();
1911 }
1912};
1913
1914// Find muxes which have conditions dominated by other muxes with the same
1915// condition.
1916class MuxSharedCond : public mlir::OpRewritePattern<MuxPrimOp> {
1917public:
1918 using OpRewritePattern::OpRewritePattern;
1919
1920 static const int depthLimit = 5;
1921
1922 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1923 mlir::PatternRewriter &rewriter,
1924 bool updateInPlace) const {
1925 if (updateInPlace) {
1926 rewriter.modifyOpInPlace(mux, [&] {
1927 mux.setOperand(1, high);
1928 mux.setOperand(2, low);
1929 });
1930 return {};
1931 }
1932 rewriter.setInsertionPointAfter(mux);
1933 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1934 ValueRange{mux.getSel(), high, low})
1935 .getResult();
1936 }
1937
1938 // Walk a dependent mux tree assuming the condition cond is true.
1939 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1940 bool updateInPlace, int limit) const {
1941 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1942 if (!mux)
1943 return {};
1944 if (mux.getSel() == cond)
1945 return mux.getHigh();
1946 if (limit > depthLimit)
1947 return {};
1948 updateInPlace &= mux->hasOneUse();
1949
1950 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1951 limit + 1))
1952 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1953
1954 if (Value v =
1955 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1956 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1957 return {};
1958 }
1959
1960 // Walk a dependent mux tree assuming the condition cond is false.
1961 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1962 bool updateInPlace, int limit) const {
1963 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1964 if (!mux)
1965 return {};
1966 if (mux.getSel() == cond)
1967 return mux.getLow();
1968 if (limit > depthLimit)
1969 return {};
1970 updateInPlace &= mux->hasOneUse();
1971
1972 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1973 limit + 1))
1974 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1975
1976 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1977 limit + 1))
1978 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1979
1980 return {};
1981 }
1982
1983 LogicalResult
1984 matchAndRewrite(MuxPrimOp mux,
1985 mlir::PatternRewriter &rewriter) const override {
1986 auto width = mux.getType().getBitWidthOrSentinel();
1987 if (width < 0)
1988 return failure();
1989
1990 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
1991 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1992 return success();
1993 }
1994
1995 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
1996 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1997 return success();
1998 }
1999
2000 return failure();
2001 }
2002};
2003} // namespace
2004
/// Register the `mux` canonicalizations: the locally defined `MuxPad` and
/// `MuxSharedCond` patterns plus several patterns from the `patterns`
/// namespace.
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
           patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
           patterns::MuxSameTrue, patterns::MuxSameFalse,
           patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
          context);
}
2014
/// Register the canonicalizations for the two-way mux intrinsic
/// (`patterns::Mux2PadSel` only).
void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux2PadSel>(context);
}
2019
/// Register the canonicalizations for the four-way mux intrinsic
/// (`patterns::Mux4PadSel` only).
void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux4PadSel>(context);
}
2024
2025OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
2026 auto input = this->getInput();
2027
2028 // pad(x) -> x if the width doesn't change.
2029 if (input.getType() == getType())
2030 return input;
2031
2032 // Need to know the input width.
2033 auto inputType = input.getType().base();
2034 int32_t width = inputType.getWidthOrSentinel();
2035 if (width == -1)
2036 return {};
2037
2038 // Constant fold.
2039 if (auto cst = getConstant(adaptor.getInput())) {
2040 auto destWidth = getType().base().getWidthOrSentinel();
2041 if (destWidth == -1)
2042 return {};
2043
2044 if (inputType.isSigned() && cst->getBitWidth())
2045 return getIntAttr(getType(), cst->sext(destWidth));
2046 return getIntAttr(getType(), cst->zext(destWidth));
2047 }
2048
2049 return {};
2050}
2051
2052OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
2053 auto input = this->getInput();
2054 IntType inputType = input.getType();
2055 int shiftAmount = getAmount();
2056
2057 // shl(x, 0) -> x
2058 if (shiftAmount == 0)
2059 return input;
2060
2061 // Constant fold.
2062 if (auto cst = getConstant(adaptor.getInput())) {
2063 auto inputWidth = inputType.getWidthOrSentinel();
2064 if (inputWidth != -1) {
2065 auto resultWidth = inputWidth + shiftAmount;
2066 shiftAmount = std::min(shiftAmount, resultWidth);
2067 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
2068 }
2069 }
2070 return {};
2071}
2072
/// Fold `shr(x, n)`. Cases handled:
/// - identity for n == 0 on an input of positive known width;
/// - zero for a zero-width input;
/// - zero when an unsigned input is shifted by at least its width;
/// - a constant result when `x` is a constant of known width.
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  if (inputWidth == -1)
    return {};
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold. The signed shift is clamped to width-1 so the sign bit is
  // replicated, and the truncated result keeps at least one bit.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
2107
2108LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2109 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2110 if (inputWidth <= 0)
2111 return failure();
2112
2113 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2114 unsigned shiftAmount = op.getAmount();
2115 if (int(shiftAmount) >= inputWidth) {
2116 // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
2117 if (op.getType().base().isUnsigned())
2118 return failure();
2119
2120 // Shifting a signed value by the full width is actually taking the
2121 // sign bit. If the shift amount is greater than the input width, it
2122 // is equivalent to shifting by the input width.
2123 shiftAmount = inputWidth - 1;
2124 }
2125
2126 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
2127 return success();
2128}
2129
2130LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2131 PatternRewriter &rewriter) {
2132 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2133 if (inputWidth <= 0)
2134 return failure();
2135
2136 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2137 unsigned keepAmount = op.getAmount();
2138 if (keepAmount)
2139 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2140 rewriter);
2141 return success();
2142}
2143
2144OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2145 if (hasKnownWidthIntTypes(*this))
2146 if (auto cst = getConstant(adaptor.getInput())) {
2147 int shiftAmount =
2148 getInput().getType().base().getWidthOrSentinel() - getAmount();
2149 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
2150 }
2151
2152 return {};
2153}
2154
2155OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2156 if (hasKnownWidthIntTypes(*this))
2157 if (auto cst = getConstant(adaptor.getInput()))
2158 return getIntAttr(getType(),
2159 cst->trunc(getType().base().getWidthOrSentinel()));
2160 return {};
2161}
2162
2163LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2164 PatternRewriter &rewriter) {
2165 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2166 if (inputWidth <= 0)
2167 return failure();
2168
2169 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2170 unsigned dropAmount = op.getAmount();
2171 if (dropAmount != unsigned(inputWidth))
2172 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
2173 rewriter);
2174 return success();
2175}
2176
/// Register the canonicalizations for `subaccess`
/// (`patterns::SubaccessOfConstant` only).
void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<patterns::SubaccessOfConstant>(context);
}
2181
2182OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2183 // If there is only one input, just return it.
2184 if (adaptor.getInputs().size() == 1)
2185 return getOperand(1);
2186
2187 if (auto constIndex = getConstant(adaptor.getIndex())) {
2188 auto index = constIndex->getZExtValue();
2189 if (index < getInputs().size())
2190 return getInputs()[getInputs().size() - 1 - index];
2191 }
2192
2193 return {};
2194}
2195
/// Canonicalize `multibit_mux`. Inputs are listed from the highest index down
/// to index 0, i.e. `inputs[size - 1 - i]` is selected when `index == i`.
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We could add this
  // canonicalization as a folder, but it is costly to look through all inputs,
  // so it is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements: an index of width w can only reach the last 2^w inputs. The
  // `indexWidth < 64` guard keeps the shift well-defined.
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multibit_mux idx, a[n-1], a[n-2],
  // ..., a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  // (inputs[0] is selected by index 1, which matches the mux's high operand.)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2247
2248//===----------------------------------------------------------------------===//
2249// Declarations
2250//===----------------------------------------------------------------------===//
2251
2252/// Scan all the uses of the specified value, checking to see if there is
2253/// exactly one connect that has the value as its destination. This returns the
2254/// operation if found and if all the other users are "reads" from the value.
2255/// Returns null if there are no connects, or multiple connects to the value, or
2256/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2257///
2258/// Note that this will simply return the connect, which is located *anywhere*
2259/// after the definition of the value. Users of this function are likely
2260/// interested in the source side of the returned connect, the definition of
2261/// which does likely not dominate the original value.
2262MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
2263 MatchingConnectOp connect;
2264 for (Operation *user : value.getUsers()) {
2265 // If we see an attach or aggregate sublements, just conservatively fail.
2266 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2267 return {};
2268
2269 if (auto aConnect = dyn_cast<FConnectLike>(user))
2270 if (aConnect.getDest() == value) {
2271 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2272 // If this is not a matching connect, a second matching connect or in a
2273 // different block, fail.
2274 if (!matchingConnect || (connect && connect != matchingConnect) ||
2275 matchingConnect->getBlock() != value.getParentBlock())
2276 return {};
2277 connect = matchingConnect;
2278 }
2279 }
2280 return connect;
2281}
2282
2283// Forward simple values through wire's and reg's.
2284static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
2285 PatternRewriter &rewriter) {
2286 // While we can do this for nearly all wires, we currently limit it to simple
2287 // things.
2288 Operation *connectedDecl = op.getDest().getDefiningOp();
2289 if (!connectedDecl)
2290 return failure();
2291
2292 // Only support wire and reg for now.
2293 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2294 return failure();
2295 if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
2296 !hasDroppableName(connectedDecl) ||
2297 cast<Forceable>(connectedDecl).isForceable())
2298 return failure();
2299
2300 // Only forward if the types exactly match and there is one connect.
2301 if (getSingleConnectUserOf(op.getDest()) != op)
2302 return failure();
2303
2304 // Only forward if there is more than one use
2305 if (connectedDecl->hasOneUse())
2306 return failure();
2307
2308 // Only do this if the connectee and the declaration are in the same block.
2309 auto *declBlock = connectedDecl->getBlock();
2310 auto *srcValueOp = op.getSrc().getDefiningOp();
2311 if (!srcValueOp) {
2312 // Ports are ok for wires but not registers.
2313 if (!isa<WireOp>(connectedDecl))
2314 return failure();
2315
2316 } else {
2317 // Constants/invalids in the same block are ok to forward, even through
2318 // reg's since the clocking doesn't matter for constants.
2319 if (!isa<ConstantOp>(srcValueOp))
2320 return failure();
2321 if (srcValueOp->getBlock() != declBlock)
2322 return failure();
2323 }
2324
2325 // Ok, we know we are doing the transformation.
2326
2327 auto replacement = op.getSrc();
2328 // This will be replaced with the constant source. First, make sure the
2329 // constant dominates all users.
2330 if (srcValueOp && srcValueOp != &declBlock->front())
2331 srcValueOp->moveBefore(&declBlock->front());
2332
2333 // Replace all things *using* the decl with the constant/port, and
2334 // remove the declaration.
2335 replaceOpAndCopyName(rewriter, connectedDecl, replacement);
2336
2337 // Remove the connect
2338 rewriter.eraseOp(op);
2339 return success();
2340}
2341
/// Register the canonicalizations for `connect`: `patterns::ConnectExtension`
/// and `patterns::ConnectSameType`.
void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
      context);
}
2347
2348LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2349 PatternRewriter &rewriter) {
2350 // TODO: Canonicalize towards explicit extensions and flips here.
2351
2352 // If there is a simple value connected to a foldable decl like a wire or reg,
2353 // see if we can eliminate the decl.
2354 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2355 return success();
2356 return failure();
2357}
2358
2359//===----------------------------------------------------------------------===//
2360// Statements
2361//===----------------------------------------------------------------------===//
2362
2363/// If the specified value has an AttachOp user strictly dominating by
2364/// "dominatingAttach" then return it.
2365static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
2366 for (auto *user : value.getUsers()) {
2367 auto attach = dyn_cast<AttachOp>(user);
2368 if (!attach || attach == dominatedAttach)
2369 continue;
2370 if (attach->isBeforeInBlock(dominatedAttach))
2371 return attach;
2372 }
2373 return {};
2374}
2375
/// Canonicalize `attach`. In order:
/// - erase attaches with at most one operand;
/// - merge with an earlier attach (found by `getDominatingAttachUser`) that
///   shares an operand;
/// - drop a wire operand whose only use is this attach.
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    // attach x, y
    // ...
    // attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2421
2422/// Replaces the given op with the contents of the given single-block region.
2423static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
2424 Region &region) {
2425 assert(llvm::hasSingleElement(region) && "expected single-region block");
2426 rewriter.inlineBlockBefore(&region.front(), op, {});
2427}
2428
/// Simplify `when`. A constant condition inlines the taken branch. An empty
/// else block next to a non-empty then block is removed. A `when` whose
/// blocks are all empty is erased.
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    // An all-ones condition keeps the then-region. Otherwise the else-region
    // is kept, if there is one. The untaken branch is erased along with the op.
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase an empty else block when the then block still has contents.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2461
2462namespace {
2463// Remove private nodes. If they have an interesting names, move the name to
2464// the source expression.
2465struct FoldNodeName : public mlir::OpRewritePattern<NodeOp> {
2466 using OpRewritePattern::OpRewritePattern;
2467 LogicalResult matchAndRewrite(NodeOp node,
2468 PatternRewriter &rewriter) const override {
2469 auto name = node.getNameAttr();
2470 if (!node.hasDroppableName() || node.getInnerSym() ||
2471 !AnnotationSet(node).empty() || node.isForceable())
2472 return failure();
2473 auto *newOp = node.getInput().getDefiningOp();
2474 if (newOp)
2475 updateName(rewriter, newOp, name);
2476 rewriter.replaceOp(node, node.getInput());
2477 return success();
2478 }
2479};
2480
2481// Bypass nodes.
2482struct NodeBypass : public mlir::OpRewritePattern<NodeOp> {
2483 using OpRewritePattern::OpRewritePattern;
2484 LogicalResult matchAndRewrite(NodeOp node,
2485 PatternRewriter &rewriter) const override {
2486 if (node.getInnerSym() || !AnnotationSet(node).empty() ||
2487 node.use_empty() || node.isForceable())
2488 return failure();
2489 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2490 return success();
2491 }
2492};
2493
2494} // namespace
2495
2496template <typename OpTy>
2497static LogicalResult demoteForceableIfUnused(OpTy op,
2498 PatternRewriter &rewriter) {
2499 if (!op.isForceable() || !op.getDataRef().use_empty())
2500 return failure();
2501
2502 firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
2503 return success();
2504}
2505
2506// Interesting names and symbols and don't touch force nodes to stick around.
2507LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2508 SmallVectorImpl<OpFoldResult> &results) {
2509 if (!hasDroppableName())
2510 return failure();
2511 if (hasDontTouch(getResult())) // handles inner symbols
2512 return failure();
2513 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
2514 return failure();
2515 if (isForceable())
2516 return failure();
2517 if (!adaptor.getInput())
2518 return failure();
2519
2520 results.push_back(adaptor.getInput());
2521 return success();
2522}
2523
/// Register the node canonicalizations: `FoldNodeName` and the demotion of
/// unused forceable nodes. (`NodeBypass` is not registered here.)
void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<FoldNodeName>(context);
  results.add(demoteForceableIfUnused<NodeOp>);
}
2529
2530namespace {
2531// For a lhs, find all the writers of fields of the aggregate type. If there
2532// is one writer for each field, merge the writes
2533struct AggOneShot : public mlir::RewritePattern {
2534 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2535 : RewritePattern(name, 0, context) {}
2536
2537 SmallVector<Value> getCompleteWrite(Operation *lhs) const {
2538 auto lhsTy = lhs->getResult(0).getType();
2539 if (!type_isa<BundleType, FVectorType>(lhsTy))
2540 return {};
2541
2542 DenseMap<uint32_t, Value> fields;
2543 for (Operation *user : lhs->getResult(0).getUsers()) {
2544 if (user->getParentOp() != lhs->getParentOp())
2545 return {};
2546 if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2547 if (aConnect.getDest() == lhs->getResult(0))
2548 return {};
2549 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2550 for (Operation *subuser : subField.getResult().getUsers()) {
2551 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2552 if (aConnect.getDest() == subField) {
2553 if (subuser->getParentOp() != lhs->getParentOp())
2554 return {};
2555 if (fields.count(subField.getFieldIndex())) // duplicate write
2556 return {};
2557 fields[subField.getFieldIndex()] = aConnect.getSrc();
2558 }
2559 continue;
2560 }
2561 return {};
2562 }
2563 } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2564 for (Operation *subuser : subIndex.getResult().getUsers()) {
2565 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2566 if (aConnect.getDest() == subIndex) {
2567 if (subuser->getParentOp() != lhs->getParentOp())
2568 return {};
2569 if (fields.count(subIndex.getIndex())) // duplicate write
2570 return {};
2571 fields[subIndex.getIndex()] = aConnect.getSrc();
2572 }
2573 continue;
2574 }
2575 return {};
2576 }
2577 } else {
2578 return {};
2579 }
2580 }
2581
2582 SmallVector<Value> values;
2583 uint32_t total = type_isa<BundleType>(lhsTy)
2584 ? type_cast<BundleType>(lhsTy).getNumElements()
2585 : type_cast<FVectorType>(lhsTy).getNumElements();
2586 for (uint32_t i = 0; i < total; ++i) {
2587 if (!fields.count(i))
2588 return {};
2589 values.push_back(fields[i]);
2590 }
2591 return values;
2592 }
2593
2594 LogicalResult matchAndRewrite(Operation *op,
2595 PatternRewriter &rewriter) const override {
2596 auto values = getCompleteWrite(op);
2597 if (values.empty())
2598 return failure();
2599 rewriter.setInsertionPointToEnd(op->getBlock());
2600 auto dest = op->getResult(0);
2601 auto destType = dest.getType();
2602
2603 // If not passive, cannot matchingconnect.
2604 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2605 return failure();
2606
2607 Value newVal = type_isa<BundleType>(destType)
2608 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2609 destType, values)
2610 : rewriter.createOrFold<VectorCreateOp>(
2611 op->getLoc(), destType, values);
2612 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2613 for (Operation *user : dest.getUsers()) {
2614 if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2615 for (Operation *subuser :
2616 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2617 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2618 if (aConnect.getDest() == subIndex)
2619 rewriter.eraseOp(aConnect);
2620 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2621 for (Operation *subuser :
2622 llvm::make_early_inc_range(subField.getResult().getUsers()))
2623 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2624 if (aConnect.getDest() == subField)
2625 rewriter.eraseOp(aConnect);
2626 }
2627 }
2628 return success();
2629 }
2630};
2631
/// `AggOneShot` rooted at `wire` operations (benefit 0).
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
/// `AggOneShot` rooted at `subindex` operations (benefit 0).
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
/// `AggOneShot` rooted at `subfield` operations (benefit 0).
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2644} // namespace
2645
/// Register the wire canonicalizations: merging per-field writes
/// (`WireAggOneShot`) and demoting unused forceable wires.
void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<WireAggOneShot>(context);
  results.add(demoteForceableIfUnused<WireOp>);
}
2651
/// Register the `subindex` canonicalizations (`SubindexAggOneShot` only).
void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubindexAggOneShot>(context);
}
2656
2657OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2658 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2659 if (!attr)
2660 return {};
2661 return attr[getIndex()];
2662}
2663
2664OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2665 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2666 if (!attr)
2667 return {};
2668 auto index = getFieldIndex();
2669 return attr[index];
2670}
2671
/// Register the `subfield` canonicalizations (`SubfieldAggOneShot` only).
void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubfieldAggOneShot>(context);
}
2676
2677static Attribute collectFields(MLIRContext *context,
2678 ArrayRef<Attribute> operands) {
2679 for (auto operand : operands)
2680 if (!operand)
2681 return {};
2682 return ArrayAttr::get(context, operands);
2683}
2684
2685OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2686 // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
2687 // bundle<a:..., b:...>.
2688 if (getNumOperands() > 0)
2689 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2690 if (first.getFieldIndex() == 0 &&
2691 first.getInput().getType() == getType() &&
2692 llvm::all_of(
2693 llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
2694 auto subindex =
2695 elem.value().template getDefiningOp<SubfieldOp>();
2696 return subindex && subindex.getInput() == first.getInput() &&
2697 subindex.getFieldIndex() == elem.index();
2698 }))
2699 return first.getInput();
2700
2701 return collectFields(getContext(), adaptor.getOperands());
2702}
2703
2704OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2705 // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
2706 // vector<..., 2>.
2707 if (getNumOperands() > 0)
2708 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2709 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2710 llvm::all_of(
2711 llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
2712 auto subindex =
2713 elem.value().template getDefiningOp<SubindexOp>();
2714 return subindex && subindex.getInput() == first.getInput() &&
2715 subindex.getIndex() == elem.index();
2716 }))
2717 return first.getInput();
2718
2719 return collectFields(getContext(), adaptor.getOperands());
2720}
2721
2722OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2723 if (getOperand().getType() == getType())
2724 return getOperand();
2725 return {};
2726}
2727
2728namespace {
2729// A register with constant reset and all connection to either itself or the
2730// same constant, must be replaced by the constant.
2731struct FoldResetMux : public mlir::OpRewritePattern<RegResetOp> {
2732 using OpRewritePattern::OpRewritePattern;
2733 LogicalResult matchAndRewrite(RegResetOp reg,
2734 PatternRewriter &rewriter) const override {
2735 auto reset =
2736 dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
2737 if (!reset || hasDontTouch(reg.getOperation()) ||
2738 !AnnotationSet(reg).empty() || reg.isForceable())
2739 return failure();
2740 // Find the one true connect, or bail
2741 auto con = getSingleConnectUserOf(reg.getResult());
2742 if (!con)
2743 return failure();
2744
2745 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2746 if (!mux)
2747 return failure();
2748 auto *high = mux.getHigh().getDefiningOp();
2749 auto *low = mux.getLow().getDefiningOp();
2750 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2751
2752 if (constOp && low != reg)
2753 return failure();
2754 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2755 constOp = dyn_cast<ConstantOp>(low);
2756
2757 if (!constOp || constOp.getType() != reset.getType() ||
2758 constOp.getValue() != reset.getValue())
2759 return failure();
2760
2761 // Check all types should be typed by now
2762 auto regTy = reg.getResult().getType();
2763 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2764 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2765 regTy.getBitWidthOrSentinel() < 0)
2766 return failure();
2767
2768 // Ok, we know we are doing the transformation.
2769
2770 // Make sure the constant dominates all users.
2771 if (constOp != &con->getBlock()->front())
2772 constOp->moveBefore(&con->getBlock()->front());
2773
2774 // Replace the register with the constant.
2775 replaceOpAndCopyName(rewriter, reg, constOp.getResult());
2776 // Remove the connect.
2777 rewriter.eraseOp(con);
2778 return success();
2779 }
2780};
2781} // namespace
2782
2783static bool isDefinedByOneConstantOp(Value v) {
2784 if (auto c = v.getDefiningOp<ConstantOp>())
2785 return c.getValue().isOne();
2786 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2787 return sc.getValue();
2788 return false;
2789}
2790
2791static LogicalResult
2792canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
2793 if (!isDefinedByOneConstantOp(reg.getResetSignal()))
2794 return failure();
2795
2796 auto resetValue = reg.getResetValue();
2797 if (reg.getType(0) != resetValue.getType())
2798 return failure();
2799
2800 // Ignore 'passthrough'.
2801 (void)dropWrite(rewriter, reg->getResult(0), {});
2802 replaceOpWithNewOpAndCopyName<NodeOp>(
2803 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2804 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2805 return success();
2806}
2807
2808void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2809 MLIRContext *context) {
2810 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2812 results.add(demoteForceableIfUnused<RegResetOp>);
2813}
2814
2815// Returns the value connected to a port, if there is only one.
2816static Value getPortFieldValue(Value port, StringRef name) {
2817 auto portTy = type_cast<BundleType>(port.getType());
2818 auto fieldIndex = portTy.getElementIndex(name);
2819 assert(fieldIndex && "missing field on memory port");
2820
2821 Value value = {};
2822 for (auto *op : port.getUsers()) {
2823 auto portAccess = cast<SubfieldOp>(op);
2824 if (fieldIndex != portAccess.getFieldIndex())
2825 continue;
2826 auto conn = getSingleConnectUserOf(portAccess);
2827 if (!conn || value)
2828 return {};
2829 value = conn.getSrc();
2830 }
2831 return value;
2832}
2833
2834// Returns true if the enable field of a port is set to false.
2835static bool isPortDisabled(Value port) {
2836 auto value = getPortFieldValue(port, "en");
2837 if (!value)
2838 return false;
2839 auto portConst = value.getDefiningOp<ConstantOp>();
2840 if (!portConst)
2841 return false;
2842 return portConst.getValue().isZero();
2843}
2844
2845// Returns true if the data output is unused.
2846static bool isPortUnused(Value port, StringRef data) {
2847 auto portTy = type_cast<BundleType>(port.getType());
2848 auto fieldIndex = portTy.getElementIndex(data);
2849 assert(fieldIndex && "missing enable flag on memory port");
2850
2851 for (auto *op : port.getUsers()) {
2852 auto portAccess = cast<SubfieldOp>(op);
2853 if (fieldIndex != portAccess.getFieldIndex())
2854 continue;
2855 if (!portAccess.use_empty())
2856 return false;
2857 }
2858
2859 return true;
2860}
2861
2862// Returns the value connected to a port, if there is only one.
2863static void replacePortField(PatternRewriter &rewriter, Value port,
2864 StringRef name, Value value) {
2865 auto portTy = type_cast<BundleType>(port.getType());
2866 auto fieldIndex = portTy.getElementIndex(name);
2867 assert(fieldIndex && "missing field on memory port");
2868
2869 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2870 auto portAccess = cast<SubfieldOp>(op);
2871 if (fieldIndex != portAccess.getFieldIndex())
2872 continue;
2873 rewriter.replaceAllUsesWith(portAccess, value);
2874 rewriter.eraseOp(portAccess);
2875 }
2876}
2877
2878// Remove accesses to a port which is used.
2879static void erasePort(PatternRewriter &rewriter, Value port) {
2880 // Helper to create a dummy 0 clock for the dummy registers.
2881 Value clock;
2882 auto getClock = [&] {
2883 if (!clock)
2884 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2885 ClockType::get(rewriter.getContext()),
2886 false);
2887 return clock;
2888 };
2889
2890 // Find the clock field of the port and determine whether the port is
2891 // accessed only through its subfields or as a whole wire. If the port
2892 // is used in its entirety, replace it with a wire. Otherwise,
2893 // eliminate individual subfields and replace with reasonable defaults.
2894 for (auto *op : port.getUsers()) {
2895 auto subfield = dyn_cast<SubfieldOp>(op);
2896 if (!subfield) {
2897 auto ty = port.getType();
2898 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2899 rewriter.replaceAllUsesWith(port, reg.getResult());
2900 return;
2901 }
2902 }
2903
2904 // Remove all connects to field accesses as they are no longer relevant.
2905 // If field values are used anywhere, which should happen solely for read
2906 // ports, a dummy register is introduced which replicates the behaviour of
2907 // memory that is never written, but might be read.
2908 for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2909 auto access = cast<SubfieldOp>(accessOp);
2910 for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
2911 auto connect = dyn_cast<FConnectLike>(user);
2912 if (connect && connect.getDest() == access) {
2913 rewriter.eraseOp(user);
2914 continue;
2915 }
2916 }
2917 if (access.use_empty()) {
2918 rewriter.eraseOp(access);
2919 continue;
2920 }
2921
2922 // Replace read values with a register that is never written, handing off
2923 // the canonicalization of such a register to another canonicalizer.
2924 auto ty = access.getType();
2925 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2926 rewriter.replaceOp(access, reg.getResult());
2927 }
2928 assert(port.use_empty() && "port should have no remaining uses");
2929}
2930
2931namespace {
2932// If memory has known, but zero width, eliminate it.
2933struct FoldZeroWidthMemory : public mlir::OpRewritePattern<MemOp> {
2934 using OpRewritePattern::OpRewritePattern;
2935 LogicalResult matchAndRewrite(MemOp mem,
2936 PatternRewriter &rewriter) const override {
2937 if (hasDontTouch(mem))
2938 return failure();
2939
2940 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2941 mem.getDataType().getBitWidthOrSentinel() != 0)
2942 return failure();
2943
2944 // Make sure are users are safe to replace
2945 for (auto port : mem.getResults())
2946 for (auto *user : port.getUsers())
2947 if (!isa<SubfieldOp>(user))
2948 return failure();
2949
2950 // Annoyingly, there isn't a good replacement for the port as a whole,
2951 // since they have an outer flip type.
2952 for (auto port : mem.getResults()) {
2953 for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
2954 SubfieldOp sfop = cast<SubfieldOp>(user);
2955 StringRef fieldName = sfop.getFieldName();
2956 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2957 rewriter, sfop, sfop.getResult().getType())
2958 .getResult();
2959 if (fieldName.ends_with("data")) {
2960 // Make sure to write data ports.
2961 auto zero = firrtl::ConstantOp::create(
2962 rewriter, wire.getLoc(),
2963 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2964 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2965 }
2966 }
2967 }
2968 rewriter.eraseOp(mem);
2969 return success();
2970 }
2971};
2972
2973// If memory has no write ports and no file initialization, eliminate it.
2974struct FoldReadOrWriteOnlyMemory : public mlir::OpRewritePattern<MemOp> {
2975 using OpRewritePattern::OpRewritePattern;
2976 LogicalResult matchAndRewrite(MemOp mem,
2977 PatternRewriter &rewriter) const override {
2978 if (hasDontTouch(mem))
2979 return failure();
2980 bool isRead = false, isWritten = false;
2981 for (unsigned i = 0; i < mem.getNumResults(); ++i) {
2982 switch (mem.getPortKind(i)) {
2983 case MemOp::PortKind::Read:
2984 isRead = true;
2985 if (isWritten)
2986 return failure();
2987 continue;
2988 case MemOp::PortKind::Write:
2989 isWritten = true;
2990 if (isRead)
2991 return failure();
2992 continue;
2993 case MemOp::PortKind::Debug:
2994 case MemOp::PortKind::ReadWrite:
2995 return failure();
2996 }
2997 llvm_unreachable("unknown port kind");
2998 }
2999 assert((!isWritten || !isRead) && "memory is in use");
3000
3001 // If the memory is read only, but has a file initialization, then we can't
3002 // remove it. A write only memory with file initialization is okay to
3003 // remove.
3004 if (isRead && mem.getInit())
3005 return failure();
3006
3007 for (auto port : mem.getResults())
3008 erasePort(rewriter, port);
3009
3010 rewriter.eraseOp(mem);
3011 return success();
3012 }
3013};
3014
3015// Eliminate the dead ports of memories.
3016struct FoldUnusedPorts : public mlir::OpRewritePattern<MemOp> {
3017 using OpRewritePattern::OpRewritePattern;
3018 LogicalResult matchAndRewrite(MemOp mem,
3019 PatternRewriter &rewriter) const override {
3020 if (hasDontTouch(mem))
3021 return failure();
3022 // Identify the dead and changed ports.
3023 llvm::SmallBitVector deadPorts(mem.getNumResults());
3024 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3025 // Do not simplify annotated ports.
3026 if (!mem.getPortAnnotation(i).empty())
3027 continue;
3028
3029 // Skip debug ports.
3030 auto kind = mem.getPortKind(i);
3031 if (kind == MemOp::PortKind::Debug)
3032 continue;
3033
3034 // If a port is disabled, always eliminate it.
3035 if (isPortDisabled(port)) {
3036 deadPorts.set(i);
3037 continue;
3038 }
3039 // Eliminate read ports whose outputs are not used.
3040 if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
3041 deadPorts.set(i);
3042 continue;
3043 }
3044 }
3045 if (deadPorts.none())
3046 return failure();
3047
3048 // Rebuild the new memory with the altered ports.
3049 SmallVector<Type> resultTypes;
3050 SmallVector<StringRef> portNames;
3051 SmallVector<Attribute> portAnnotations;
3052 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3053 if (deadPorts[i])
3054 continue;
3055 resultTypes.push_back(port.getType());
3056 portNames.push_back(mem.getPortName(i));
3057 portAnnotations.push_back(mem.getPortAnnotation(i));
3058 }
3059
3060 MemOp newOp;
3061 if (!resultTypes.empty())
3062 newOp = MemOp::create(
3063 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3064 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3065 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3066 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3067 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3068
3069 // Replace the dead ports with dummy wires.
3070 unsigned nextPort = 0;
3071 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3072 if (deadPorts[i])
3073 erasePort(rewriter, port);
3074 else
3075 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
3076 }
3077
3078 rewriter.eraseOp(mem);
3079 return success();
3080 }
3081};
3082
3083// Rewrite write-only read-write ports to write ports.
3084struct FoldReadWritePorts : public mlir::OpRewritePattern<MemOp> {
3085 using OpRewritePattern::OpRewritePattern;
3086 LogicalResult matchAndRewrite(MemOp mem,
3087 PatternRewriter &rewriter) const override {
3088 if (hasDontTouch(mem))
3089 return failure();
3090
3091 // Identify read-write ports whose read end is unused.
3092 llvm::SmallBitVector deadReads(mem.getNumResults());
3093 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3094 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
3095 continue;
3096 if (!mem.getPortAnnotation(i).empty())
3097 continue;
3098 if (isPortUnused(port, "rdata")) {
3099 deadReads.set(i);
3100 continue;
3101 }
3102 }
3103 if (deadReads.none())
3104 return failure();
3105
3106 SmallVector<Type> resultTypes;
3107 SmallVector<StringRef> portNames;
3108 SmallVector<Attribute> portAnnotations;
3109 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3110 if (deadReads[i])
3111 resultTypes.push_back(
3112 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
3113 MemOp::PortKind::Write, mem.getMaskBits()));
3114 else
3115 resultTypes.push_back(port.getType());
3116
3117 portNames.push_back(mem.getPortName(i));
3118 portAnnotations.push_back(mem.getPortAnnotation(i));
3119 }
3120
3121 auto newOp = MemOp::create(
3122 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3123 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3124 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3125 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3126 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3127
3128 for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
3129 auto result = mem.getResult(i);
3130 auto newResult = newOp.getResult(i);
3131 if (deadReads[i]) {
3132 auto resultPortTy = type_cast<BundleType>(result.getType());
3133
3134 // Rewrite accesses to the old port field to accesses to a
3135 // corresponding field of the new port.
3136 auto replace = [&](StringRef toName, StringRef fromName) {
3137 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
3138 assert(fromFieldIndex && "missing enable flag on memory port");
3139
3140 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
3141 newResult, toName);
3142 for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
3143 auto fromField = cast<SubfieldOp>(op);
3144 if (fromFieldIndex != fromField.getFieldIndex())
3145 continue;
3146 rewriter.replaceOp(fromField, toField.getResult());
3147 }
3148 };
3149
3150 replace("addr", "addr");
3151 replace("en", "en");
3152 replace("clk", "clk");
3153 replace("data", "wdata");
3154 replace("mask", "wmask");
3155
3156 // Remove the wmode field, replacing it with dummy wires.
3157 auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
3158 for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
3159 auto wmodeField = cast<SubfieldOp>(op);
3160 if (wmodeFieldIndex != wmodeField.getFieldIndex())
3161 continue;
3162 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
3163 }
3164 } else {
3165 rewriter.replaceAllUsesWith(result, newResult);
3166 }
3167 }
3168 rewriter.eraseOp(mem);
3169 return success();
3170 }
3171};
3172
3173// Eliminate the dead ports of memories.
3174struct FoldUnusedBits : public mlir::OpRewritePattern<MemOp> {
3175 using OpRewritePattern::OpRewritePattern;
3176
3177 LogicalResult matchAndRewrite(MemOp mem,
3178 PatternRewriter &rewriter) const override {
3179 if (hasDontTouch(mem))
3180 return failure();
3181
3182 // Only apply the transformation if the memory is not sequential.
3183 const auto &summary = mem.getSummary();
3184 if (summary.isMasked || summary.isSeqMem())
3185 return failure();
3186
3187 auto type = type_dyn_cast<IntType>(mem.getDataType());
3188 if (!type)
3189 return failure();
3190 auto width = type.getBitWidthOrSentinel();
3191 if (width <= 0)
3192 return failure();
3193
3194 llvm::SmallBitVector usedBits(width);
3195 DenseMap<unsigned, unsigned> mapping;
3196
3197 // Find which bits are used out of the users of a read port. This detects
3198 // ports whose data/rdata field is used only through bit select ops. The
3199 // bit selects are then used to build a bit-mask. The ops are collected.
3200 SmallVector<BitsPrimOp> readOps;
3201 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3202 auto portTy = type_cast<BundleType>(port.getType());
3203 auto fieldIndex = portTy.getElementIndex(field);
3204 assert(fieldIndex && "missing data port");
3205
3206 for (auto *op : port.getUsers()) {
3207 auto portAccess = cast<SubfieldOp>(op);
3208 if (fieldIndex != portAccess.getFieldIndex())
3209 continue;
3210
3211 for (auto *user : op->getUsers()) {
3212 auto bits = dyn_cast<BitsPrimOp>(user);
3213 if (!bits)
3214 return failure();
3215
3216 usedBits.set(bits.getLo(), bits.getHi() + 1);
3217 if (usedBits.all())
3218 return failure();
3219
3220 mapping[bits.getLo()] = 0;
3221 readOps.push_back(bits);
3222 }
3223 }
3224
3225 return success();
3226 };
3227
3228 // Finds the users of write ports. This expects all the data/wdata fields
3229 // of the ports to be used solely as the destination of matching connects.
3230 // If a memory has ports with other uses, it is excluded from optimisation.
3231 SmallVector<MatchingConnectOp> writeOps;
3232 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3233 auto portTy = type_cast<BundleType>(port.getType());
3234 auto fieldIndex = portTy.getElementIndex(field);
3235 assert(fieldIndex && "missing data port");
3236
3237 for (auto *op : port.getUsers()) {
3238 auto portAccess = cast<SubfieldOp>(op);
3239 if (fieldIndex != portAccess.getFieldIndex())
3240 continue;
3241
3242 auto conn = getSingleConnectUserOf(portAccess);
3243 if (!conn)
3244 return failure();
3245
3246 writeOps.push_back(conn);
3247 }
3248 return success();
3249 };
3250
3251 // Traverse all ports and find the read and used data fields.
3252 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3253 // Do not simplify annotated ports.
3254 if (!mem.getPortAnnotation(i).empty())
3255 return failure();
3256
3257 switch (mem.getPortKind(i)) {
3258 case MemOp::PortKind::Debug:
3259 // Skip debug ports.
3260 return failure();
3261 case MemOp::PortKind::Write:
3262 if (failed(findWriteUsers(port, "data")))
3263 return failure();
3264 continue;
3265 case MemOp::PortKind::Read:
3266 if (failed(findReadUsers(port, "data")))
3267 return failure();
3268 continue;
3269 case MemOp::PortKind::ReadWrite:
3270 if (failed(findWriteUsers(port, "wdata")))
3271 return failure();
3272 if (failed(findReadUsers(port, "rdata")))
3273 return failure();
3274 continue;
3275 }
3276 llvm_unreachable("unknown port kind");
3277 }
3278
3279 // Unused memories are handled in a different canonicalizer.
3280 if (usedBits.none())
3281 return failure();
3282
3283 // Build a mapping of existing indices to compacted ones.
3284 SmallVector<std::pair<unsigned, unsigned>> ranges;
3285 unsigned newWidth = 0;
3286 for (int i = usedBits.find_first(); 0 <= i && i < width;) {
3287 int e = usedBits.find_next_unset(i);
3288 if (e < 0)
3289 e = width;
3290 for (int idx = i; idx < e; ++idx, ++newWidth) {
3291 if (auto it = mapping.find(idx); it != mapping.end()) {
3292 it->second = newWidth;
3293 }
3294 }
3295 ranges.emplace_back(i, e - 1);
3296 i = e != width ? usedBits.find_next(e) : e;
3297 }
3298
3299 // Create the new op with the new port types.
3300 auto newType = IntType::get(mem->getContext(), type.isSigned(), newWidth);
3301 SmallVector<Type> portTypes;
3302 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3303 portTypes.push_back(
3304 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3305 }
3306 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3307 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3308 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3309 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3310 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3311
3312 // Rewrite bundle users to the new data type.
3313 auto rewriteSubfield = [&](Value port, StringRef field) {
3314 auto portTy = type_cast<BundleType>(port.getType());
3315 auto fieldIndex = portTy.getElementIndex(field);
3316 assert(fieldIndex && "missing data port");
3317
3318 rewriter.setInsertionPointAfter(newMem);
3319 auto newPortAccess =
3320 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3321
3322 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
3323 auto portAccess = cast<SubfieldOp>(op);
3324 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3325 continue;
3326 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3327 }
3328 };
3329
3330 // Rewrite the field accesses.
3331 for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
3332 switch (newMem.getPortKind(i)) {
3333 case MemOp::PortKind::Debug:
3334 llvm_unreachable("cannot rewrite debug port");
3335 case MemOp::PortKind::Write:
3336 rewriteSubfield(port, "data");
3337 continue;
3338 case MemOp::PortKind::Read:
3339 rewriteSubfield(port, "data");
3340 continue;
3341 case MemOp::PortKind::ReadWrite:
3342 rewriteSubfield(port, "rdata");
3343 rewriteSubfield(port, "wdata");
3344 continue;
3345 }
3346 llvm_unreachable("unknown port kind");
3347 }
3348
3349 // Rewrite the reads to the new ranges, compacting them.
3350 for (auto readOp : readOps) {
3351 rewriter.setInsertionPointAfter(readOp);
3352 auto it = mapping.find(readOp.getLo());
3353 assert(it != mapping.end() && "bit op mapping not found");
3354 // Create a new bit selection from the compressed memory. The new op may
3355 // be folded if we are selecting the entire compressed memory.
3356 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3357 readOp.getLoc(), readOp.getInput(),
3358 readOp.getHi() - readOp.getLo() + it->second, it->second);
3359 rewriter.replaceAllUsesWith(readOp, newReadValue);
3360 rewriter.eraseOp(readOp);
3361 }
3362
3363 // Rewrite the writes into a concatenation of slices.
3364 for (auto writeOp : writeOps) {
3365 Value source = writeOp.getSrc();
3366 rewriter.setInsertionPoint(writeOp);
3367
3368 SmallVector<Value> slices;
3369 for (auto &[start, end] : llvm::reverse(ranges)) {
3370 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3371 source, end, start);
3372 slices.push_back(slice);
3373 }
3374
3375 Value catOfSlices =
3376 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
3377
3378 // If the original memory held a signed integer, then the compressed
3379 // memory will be signed too. Since the catOfSlices is always unsigned,
3380 // cast the data to a signed integer if needed before connecting back to
3381 // the memory.
3382 if (type.isSigned())
3383 catOfSlices =
3384 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3385
3386 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3387 catOfSlices);
3388 }
3389
3390 return success();
3391 }
3392};
3393
3394// Rewrite single-address memories to a firrtl register.
3395struct FoldRegMems : public mlir::OpRewritePattern<MemOp> {
3396 using OpRewritePattern::OpRewritePattern;
3397 LogicalResult matchAndRewrite(MemOp mem,
3398 PatternRewriter &rewriter) const override {
3399 const FirMemory &info = mem.getSummary();
3400 if (hasDontTouch(mem) || info.depth != 1)
3401 return failure();
3402
3403 auto ty = mem.getDataType();
3404 auto loc = mem.getLoc();
3405 auto *block = mem->getBlock();
3406
3407 // Find the clock of the register-to-be, all write ports should share it.
3408 Value clock;
3409 SmallPtrSet<Operation *, 8> connects;
3410 SmallVector<SubfieldOp> portAccesses;
3411 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3412 if (!mem.getPortAnnotation(i).empty())
3413 continue;
3414
3415 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3416 auto portTy = type_cast<BundleType>(port.getType());
3417 for (auto field : fields) {
3418 auto fieldIndex = portTy.getElementIndex(field);
3419 assert(fieldIndex && "missing field on memory port");
3420
3421 for (auto *op : port.getUsers()) {
3422 auto portAccess = cast<SubfieldOp>(op);
3423 if (fieldIndex != portAccess.getFieldIndex())
3424 continue;
3425 portAccesses.push_back(portAccess);
3426 for (auto *user : portAccess->getUsers()) {
3427 auto conn = dyn_cast<FConnectLike>(user);
3428 if (!conn)
3429 return failure();
3430 connects.insert(conn);
3431 }
3432 }
3433 }
3434 return success();
3435 };
3436
3437 switch (mem.getPortKind(i)) {
3438 case MemOp::PortKind::Debug:
3439 return failure();
3440 case MemOp::PortKind::Read:
3441 if (failed(collect({"clk", "en", "addr"})))
3442 return failure();
3443 continue;
3444 case MemOp::PortKind::Write:
3445 if (failed(collect({"clk", "en", "addr", "data", "mask"})))
3446 return failure();
3447 break;
3448 case MemOp::PortKind::ReadWrite:
3449 if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
3450 return failure();
3451 break;
3452 }
3453
3454 Value portClock = getPortFieldValue(port, "clk");
3455 if (!portClock || (clock && portClock != clock))
3456 return failure();
3457 clock = portClock;
3458 }
3459 // Create a new wire where the memory used to be. This wire will dominate
3460 // all readers of the memory. Reads should be made through this wire.
3461 rewriter.setInsertionPointAfter(mem);
3462 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3463
3464 // The memory is replaced by a register, which we place at the end of the
3465 // block, so that any value driven to the original memory will dominate the
3466 // new register (including the clock). All other ops will be placed
3467 // after the register.
3468 rewriter.setInsertionPointToEnd(block);
3469 auto memReg =
3470 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3471
3472 // Connect the output of the register to the wire.
3473 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3474
3475 // Helper to insert a given number of pipeline stages through registers.
3476 // The pipelines are placed at the end of the block.
3477 auto pipeline = [&](Value value, Value clock, const Twine &name,
3478 unsigned latency) {
3479 for (unsigned i = 0; i < latency; ++i) {
3480 std::string regName;
3481 {
3482 llvm::raw_string_ostream os(regName);
3483 os << mem.getName() << "_" << name << "_" << i;
3484 }
3485 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3486 rewriter.getStringAttr(regName))
3487 .getResult();
3488 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3489 value = reg;
3490 }
3491 return value;
3492 };
3493
3494 const unsigned writeStages = info.writeLatency - 1;
3495
3496 // Traverse each port. Replace reads with the pipelined register, discarding
3497 // the enable flag and reading unconditionally. Pipeline the mask, enable
3498 // and data bits of all write ports to be arbitrated and wired to the reg.
3499 SmallVector<std::tuple<Value, Value, Value>> writes;
3500 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3501 Value portClock = getPortFieldValue(port, "clk");
3502 StringRef name = mem.getPortName(i);
3503
3504 auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
3505 Value value = getPortFieldValue(port, field);
3506 assert(value);
3507 return pipeline(value, portClock, name + "_" + field, stages);
3508 };
3509
3510 switch (mem.getPortKind(i)) {
3511 case MemOp::PortKind::Debug:
3512 llvm_unreachable("unknown port kind");
3513 case MemOp::PortKind::Read: {
3514 // Read ports pipeline the addr and enable signals. However, the
3515 // address must be 0 for single-address memories and the enable signal
3516 // is ignored, always reading out the register. Under these constraints,
3517 // the read port can be replaced with the value from the register.
3518 replacePortField(rewriter, port, "data", memWire);
3519 break;
3520 }
3521 case MemOp::PortKind::Write: {
3522 auto data = portPipeline("data", writeStages);
3523 auto en = portPipeline("en", writeStages);
3524 auto mask = portPipeline("mask", writeStages);
3525 writes.emplace_back(data, en, mask);
3526 break;
3527 }
3528 case MemOp::PortKind::ReadWrite: {
3529 // Always read the register into the read end.
3530 replacePortField(rewriter, port, "rdata", memWire);
3531
3532 // Create a write enable and pipeline stages.
3533 auto wdata = portPipeline("wdata", writeStages);
3534 auto wmask = portPipeline("wmask", writeStages);
3535
3536 Value en = getPortFieldValue(port, "en");
3537 Value wmode = getPortFieldValue(port, "wmode");
3538
3539 auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
3540 auto wenPipelined =
3541 pipeline(wen, portClock, name + "_wen", writeStages);
3542 writes.emplace_back(wdata, wenPipelined, wmask);
3543 break;
3544 }
3545 }
3546 }
3547
3548 // Regardless of `writeUnderWrite`, always implement PortOrder.
3549 Value next = memReg;
3550 for (auto &[data, en, mask] : writes) {
3551 Value masked;
3552
3553 // If a mask bit is used, emit muxes to select the input from the
3554 // register (no mask) or the input (mask bit set).
3555 Location loc = mem.getLoc();
3556 unsigned maskGran = info.dataWidth / info.maskBits;
3557 SmallVector<Value> chunks;
3558 for (unsigned i = 0; i < info.maskBits; ++i) {
3559 unsigned hi = (i + 1) * maskGran - 1;
3560 unsigned lo = i * maskGran;
3561
3562 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
3563 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3564 auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
3565 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3566 chunks.push_back(chunk);
3567 }
3568
3569 std::reverse(chunks.begin(), chunks.end());
3570 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3571 next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
3572 }
3573 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3574 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3575
3576 // Delete the fields and their associated connects.
3577 for (Operation *conn : connects)
3578 rewriter.eraseOp(conn);
3579 for (auto portAccess : portAccesses)
3580 rewriter.eraseOp(portAccess);
3581 rewriter.eraseOp(mem);
3582
3583 return success();
3584 }
3585};
3586} // namespace
3587
3588void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3589 MLIRContext *context) {
3590 results
3591 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3592 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3593 context);
3594}
3595
3596//===----------------------------------------------------------------------===//
3597// Declarations
3598//===----------------------------------------------------------------------===//
3599
3600// Turn synchronous reset looking register updates to registers with resets.
3601// Also, const prop registers that are driven by a mux tree containing only
3602// instances of one constant or self-assigns.
3603static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
3604 // reg ; connect(reg, mux(port, const, val)) ->
3605 // reg.reset(port, const); connect(reg, val)
3606
3607 // Find the one true connect, or bail
3608 auto con = getSingleConnectUserOf(reg.getResult());
3609 if (!con)
3610 return failure();
3611
3612 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3613 if (!mux)
3614 return failure();
3615 auto *high = mux.getHigh().getDefiningOp();
3616 auto *low = mux.getLow().getDefiningOp();
3617 // Reset value must be constant
3618 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3619
3620 // Detect the case if a register only has two possible drivers:
3621 // (1) itself/uninit and (2) constant.
3622 // The mux can then be replaced with the constant.
3623 // r = mux(cond, r, 3) --> r = 3
3624 // r = mux(cond, 3, r) --> r = 3
3625 bool constReg = false;
3626
3627 if (constOp && low == reg)
3628 constReg = true;
3629 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3630 constReg = true;
3631 constOp = dyn_cast<ConstantOp>(low);
3632 }
3633 if (!constOp)
3634 return failure();
3635
3636 // For a non-constant register, reset should be a module port (heuristic to
3637 // limit to intended reset lines). Replace the register anyway if constant.
3638 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3639 return failure();
3640
3641 // Check all types should be typed by now
3642 auto regTy = reg.getResult().getType();
3643 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3644 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3645 regTy.getBitWidthOrSentinel() < 0)
3646 return failure();
3647
3648 // Ok, we know we are doing the transformation.
3649
3650 // Make sure the constant dominates all users.
3651 if (constOp != &con->getBlock()->front())
3652 constOp->moveBefore(&con->getBlock()->front());
3653
3654 if (!constReg) {
3655 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3656 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3657 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3658 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3659 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3660 reg.getForceableAttr());
3661 newReg->setDialectAttrs(attrs);
3662 }
3663 auto pt = rewriter.saveInsertionPoint();
3664 rewriter.setInsertionPoint(con);
3665 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3666 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3667 rewriter.restoreInsertionPoint(pt);
3668 return success();
3669}
3670
3671LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3672 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3673 succeeded(foldHiddenReset(op, rewriter)))
3674 return success();
3675
3676 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3677 return success();
3678
3679 return failure();
3680}
3681
3682//===----------------------------------------------------------------------===//
3683// Verification Ops.
3684//===----------------------------------------------------------------------===//
3685
3686static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3687 Value enable,
3688 PatternRewriter &rewriter,
3689 bool eraseIfZero) {
3690 // If the verification op is never enabled, delete it.
3691 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3692 if (constant.getValue().isZero()) {
3693 rewriter.eraseOp(op);
3694 return success();
3695 }
3696 }
3697
3698 // If the verification op is never triggered, delete it.
3699 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3700 if (constant.getValue().isZero() == eraseIfZero) {
3701 rewriter.eraseOp(op);
3702 return success();
3703 }
3704 }
3705
3706 return failure();
3707}
3708
3709template <class Op, bool EraseIfZero = false>
3710static LogicalResult canonicalizeImmediateVerifOp(Op op,
3711 PatternRewriter &rewriter) {
3712 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3713 EraseIfZero);
3714}
3715
3716void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3717 MLIRContext *context) {
3718 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3719 results.add<patterns::AssertXWhenX>(context);
3720}
3721
3722void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3723 MLIRContext *context) {
3724 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3725 results.add<patterns::AssumeXWhenX>(context);
3726}
3727
3728void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3729 RewritePatternSet &results, MLIRContext *context) {
3730 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3731 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
3732}
3733
3734void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3735 MLIRContext *context) {
3736 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3737}
3738
3739//===----------------------------------------------------------------------===//
3740// InvalidValueOp
3741//===----------------------------------------------------------------------===//
3742
3743LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3744 PatternRewriter &rewriter) {
3745 // Remove `InvalidValueOp`s with no uses.
3746 if (op.use_empty()) {
3747 rewriter.eraseOp(op);
3748 return success();
3749 }
3750 // Propagate invalids through a single use which is a unary op. You cannot
3751 // propagate through multiple uses as that breaks invalid semantics. Nor
3752 // can you propagate through binary ops or generally any op which computes.
3753 // Not is an exception as it is a pure, all-bits inverse.
3754 if (op->hasOneUse() &&
3755 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3756 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3757 *op->user_begin()) ||
3758 (isa<CvtPrimOp>(*op->user_begin()) &&
3759 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3760 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3761 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3762 .getBitWidthOrSentinel() > 0))) {
3763 auto *modop = *op->user_begin();
3764 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3765 modop->getResult(0).getType());
3766 rewriter.replaceAllOpUsesWith(modop, inv);
3767 rewriter.eraseOp(modop);
3768 rewriter.eraseOp(op);
3769 return success();
3770 }
3771 return failure();
3772}
3773
3774OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3775 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3776 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3777 return {};
3778}
3779
3780//===----------------------------------------------------------------------===//
3781// ClockGateIntrinsicOp
3782//===----------------------------------------------------------------------===//
3783
3784OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3785 // Forward the clock if one of the enables is always true.
3786 if (isConstantOne(adaptor.getEnable()) ||
3787 isConstantOne(adaptor.getTestEnable()))
3788 return getInput();
3789
3790 // Fold to a constant zero clock if the enables are always false.
3791 if (isConstantZero(adaptor.getEnable()) &&
3792 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3793 return BoolAttr::get(getContext(), false);
3794
3795 // Forward constant zero clocks.
3796 if (isConstantZero(adaptor.getInput()))
3797 return BoolAttr::get(getContext(), false);
3798
3799 return {};
3800}
3801
3802LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3803 PatternRewriter &rewriter) {
3804 // Remove constant false test enable.
3805 if (auto testEnable = op.getTestEnable()) {
3806 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3807 if (constOp.getValue().isZero()) {
3808 rewriter.modifyOpInPlace(op,
3809 [&] { op.getTestEnableMutable().clear(); });
3810 return success();
3811 }
3812 }
3813 }
3814
3815 return failure();
3816}
3817
3818//===----------------------------------------------------------------------===//
3819// Reference Ops.
3820//===----------------------------------------------------------------------===//
3821
3822// refresolve(forceable.ref) -> forceable.data
3823static LogicalResult
3824canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3825 auto forceable = op.getRef().getDefiningOp<Forceable>();
3826 if (!forceable || !forceable.isForceable() ||
3827 op.getRef() != forceable.getDataRef() ||
3828 op.getType() != forceable.getDataType())
3829 return failure();
3830 rewriter.replaceAllUsesWith(op, forceable.getData());
3831 return success();
3832}
3833
3834void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3835 MLIRContext *context) {
3836 results.insert<patterns::RefResolveOfRefSend>(context);
3837 results.insert(canonicalizeRefResolveOfForceable);
3838}
3839
3840OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3841 // RefCast is unnecessary if types match.
3842 if (getInput().getType() == getType())
3843 return getInput();
3844 return {};
3845}
3846
3847static bool isConstantZero(Value operand) {
3848 auto constOp = operand.getDefiningOp<ConstantOp>();
3849 return constOp && constOp.getValue().isZero();
3850}
3851
3852template <typename Op>
3853static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3854 if (isConstantZero(op.getPredicate())) {
3855 rewriter.eraseOp(op);
3856 return success();
3857 }
3858 return failure();
3859}
3860
3861void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3862 MLIRContext *context) {
3863 results.add(eraseIfPredFalse<RefForceOp>);
3864}
3865void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3866 MLIRContext *context) {
3867 results.add(eraseIfPredFalse<RefForceInitialOp>);
3868}
3869void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3870 MLIRContext *context) {
3871 results.add(eraseIfPredFalse<RefReleaseOp>);
3872}
3873void RefReleaseInitialOp::getCanonicalizationPatterns(
3874 RewritePatternSet &results, MLIRContext *context) {
3875 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3876}
3877
3878//===----------------------------------------------------------------------===//
3879// HasBeenResetIntrinsicOp
3880//===----------------------------------------------------------------------===//
3881
3882OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3883 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3884
3885 // Fold to zero if the reset is a constant. In this case the op is either
3886 // permanently in reset or never resets. Both mean that the reset never
3887 // finishes, so this op never returns true.
3888 if (adaptor.getReset())
3889 return getIntZerosAttr(UIntType::get(getContext(), 1));
3890
3891 // Fold to zero if the clock is a constant and the reset is synchronous. In
3892 // that case the reset will never be started.
3893 if (isUInt1(getReset().getType()) && adaptor.getClock())
3894 return getIntZerosAttr(UIntType::get(getContext(), 1));
3895
3896 return {};
3897}
3898
3899//===----------------------------------------------------------------------===//
3900// FPGAProbeIntrinsicOp
3901//===----------------------------------------------------------------------===//
3902
3903static bool isTypeEmpty(FIRRTLType type) {
3905 .Case<FVectorType>(
3906 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3907 .Case<BundleType>([&](auto ty) -> bool {
3908 for (auto elem : ty.getElements())
3909 if (!isTypeEmpty(elem.type))
3910 return false;
3911 return true;
3912 })
3913 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3914 .Default([](auto) -> bool { return false; });
3915}
3916
3917LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3918 PatternRewriter &rewriter) {
3919 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3920 if (!firrtlTy)
3921 return failure();
3922
3923 if (!isTypeEmpty(firrtlTy))
3924 return failure();
3925
3926 rewriter.eraseOp(op);
3927 return success();
3928}
3929
3930//===----------------------------------------------------------------------===//
3931// Layer Block Op
3932//===----------------------------------------------------------------------===//
3933
3934LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3935 PatternRewriter &rewriter) {
3936
3937 // If the layerblock is empty, erase it.
3938 if (op.getBody()->empty()) {
3939 rewriter.eraseOp(op);
3940 return success();
3941 }
3942
3943 return failure();
3944}
3945
3946//===----------------------------------------------------------------------===//
3947// Domain-related Ops
3948//===----------------------------------------------------------------------===//
3949
3950OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3951 // If no domains are specified, then forward the input to the result.
3952 if (getDomains().empty())
3953 return getInput();
3954
3955 return {};
3956}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static std::optional< bool > getBoolValue(Attribute attr)
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:218
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override