CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/STLExtras.h"
24#include "llvm/ADT/SmallPtrSet.h"
25#include "llvm/ADT/StringExtras.h"
26#include "llvm/ADT/TypeSwitch.h"
27
28using namespace circt;
29using namespace firrtl;
30
31// Drop writes to old and pass through passthrough to make patterns easier to
32// write.
33static Value dropWrite(PatternRewriter &rewriter, OpResult old,
34 Value passthrough) {
35 SmallPtrSet<Operation *, 8> users;
36 for (auto *user : old.getUsers())
37 users.insert(user);
38 for (Operation *user : users)
39 if (auto connect = dyn_cast<FConnectLike>(user))
40 if (connect.getDest() == old)
41 rewriter.eraseOp(user);
42 return passthrough;
43}
44
45// Return true if it is OK to propagate the name to the operation.
46// Non-pure operations such as instances, registers, and memories are not
47// allowed to update names for name stabilities and LEC reasons.
48static bool isOkToPropagateName(Operation *op) {
49 // Conservatively disallow operations with regions to prevent performance
50 // regression due to recursive calls to sub-regions in mlir::isPure.
51 if (op->getNumRegions() != 0)
52 return false;
53 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
54}
55
56// Move a name hint from a soon to be deleted operation to a new operation.
57// Pass through the new operation to make patterns easier to write. This cannot
58// move a name to a port (block argument), doing so would require rewriting all
59// instance sites as well as the module.
60static Value moveNameHint(OpResult old, Value passthrough) {
61 Operation *op = passthrough.getDefiningOp();
62 // This should handle ports, but it isn't clear we can change those in
63 // canonicalizers.
64 assert(op && "passthrough must be an operation");
65 Operation *oldOp = old.getOwner();
66 auto name = oldOp->getAttrOfType<StringAttr>("name");
67 if (name && !name.getValue().empty() && isOkToPropagateName(op))
68 op->setAttr("name", name);
69 return passthrough;
70}
71
72// Declarative canonicalization patterns
73namespace circt {
74namespace firrtl {
75namespace patterns {
76#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
77} // namespace patterns
78} // namespace firrtl
79} // namespace circt
80
81/// Return true if this operation's operands and results all have a known width.
82/// This only works for integer types.
83static bool hasKnownWidthIntTypes(Operation *op) {
84 auto resultType = type_cast<IntType>(op->getResult(0).getType());
85 if (!resultType.hasWidth())
86 return false;
87 for (Value operand : op->getOperands())
88 if (!type_cast<IntType>(operand.getType()).hasWidth())
89 return false;
90 return true;
91}
92
93/// Return true if this value is 1 bit UInt.
94static bool isUInt1(Type type) {
95 auto t = type_dyn_cast<UIntType>(type);
96 if (!t || !t.hasWidth() || t.getWidth() != 1)
97 return false;
98 return true;
99}
100
101/// Set the name of an op based on the best of two names: The current name, and
102/// the name passed in.
103static void updateName(PatternRewriter &rewriter, Operation *op,
104 StringAttr name) {
105 // Should never rename InstanceOp
106 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
107 return;
108 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
109 auto newName = name.getValue(); // old name is interesting
110 auto newOpName = op->getAttrOfType<StringAttr>("name");
111 // new name might not be interesting
112 if (newOpName)
113 newName = chooseName(newOpName.getValue(), name.getValue());
114 // Only update if needed
115 if (!newOpName || newOpName.getValue() != newName)
116 rewriter.modifyOpInPlace(
117 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
118}
119
120/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
121/// If a replaced op has a "name" attribute, this function propagates the name
122/// to the new value.
123static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
124 Value newValue) {
125 if (auto *newOp = newValue.getDefiningOp()) {
126 auto name = op->getAttrOfType<StringAttr>("name");
127 updateName(rewriter, newOp, name);
128 }
129 rewriter.replaceOp(op, newValue);
130}
131
132/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
133/// attribute. If a replaced op has a "name" attribute, this function propagates
134/// the name to the new value.
135template <typename OpTy, typename... Args>
136static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
137 Operation *op, Args &&...args) {
138 auto name = op->getAttrOfType<StringAttr>("name");
139 auto newOp =
140 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
141 updateName(rewriter, newOp, name);
142 return newOp;
143}
144
145/// Return true if the name is droppable. Note that this is different from
146/// `isUselessName` because non-useless names may be also droppable.
148 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
149 return namableOp.hasDroppableName();
150 return false;
151}
152
153/// Implicitly replace the operand to a constant folding operation with a const
154/// 0 in case the operand is non-constant but has a bit width 0, or if the
155/// operand is an invalid value.
156///
157/// This makes constant folding significantly easier, as we can simply pass the
158/// operands to an operation through this function to appropriately replace any
159/// zero-width dynamic values or invalid values with a constant of value 0.
160static std::optional<APSInt>
161getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
162 assert(type_cast<IntType>(operand.getType()) &&
163 "getExtendedConstant is limited to integer types");
164
165 // We never support constant folding to unknown width values.
166 if (destWidth < 0)
167 return {};
168
169 // Extension signedness follows the operand sign.
170 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
171 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
172
173 // If the operand is zero bits, then we can return a zero of the result
174 // type.
175 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
176 return APSInt(destWidth,
177 type_cast<IntType>(operand.getType()).isUnsigned());
178 return {};
179}
180
181/// Determine the value of a constant operand for the sake of constant folding.
182static std::optional<APSInt> getConstant(Attribute operand) {
183 if (!operand)
184 return {};
185 if (auto attr = dyn_cast<BoolAttr>(operand))
186 return APSInt(APInt(1, attr.getValue()));
187 if (auto attr = dyn_cast<IntegerAttr>(operand))
188 return attr.getAPSInt();
189 return {};
190}
191
192/// Determine whether a constant operand is a zero value for the sake of
193/// constant folding. This considers `invalidvalue` to be zero.
194static bool isConstantZero(Attribute operand) {
195 if (auto cst = getConstant(operand))
196 return cst->isZero();
197 return false;
198}
199
200/// Determine whether a constant operand is a one value for the sake of constant
201/// folding.
202static bool isConstantOne(Attribute operand) {
203 if (auto cst = getConstant(operand))
204 return cst->isOne();
205 return false;
206}
207
/// This is the policy for folding, which depends on the sort of operator we're
/// processing. It selects the width at which `constFoldFIRRTLBinaryOp`
/// evaluates the operands.
enum class BinOpKind {
  /// Operands are extended to the result width.
  Normal,
  /// Operands are extended to the widest operand (at least one bit).
  Compare,
  /// Operands are extended to the widest of operands and result; the computed
  /// value is truncated to the result width afterwards.
  DivideOrShift,
};
215
216/// Applies the constant folding function `calculate` to the given operands.
217///
218/// Sign or zero extends the operands appropriately to the bitwidth of the
219/// result type if \p useDstWidth is true, else to the larger of the two operand
220/// bit widths and depending on whether the operation is to be performed on
221/// signed or unsigned operands.
222static Attribute constFoldFIRRTLBinaryOp(
223 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
224 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
225 assert(operands.size() == 2 && "binary op takes two operands");
226
227 // We cannot fold something to an unknown width.
228 auto resultType = type_cast<IntType>(op->getResult(0).getType());
229 if (resultType.getWidthOrSentinel() < 0)
230 return {};
231
232 // Any binary op returning i0 is 0.
233 if (resultType.getWidthOrSentinel() == 0)
234 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
235
236 // Determine the operand widths. This is either dictated by the operand type,
237 // or if that type is an unsized integer, by the actual bits necessary to
238 // represent the constant value.
239 auto lhsWidth =
240 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
241 auto rhsWidth =
242 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
243 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
244 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
245 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
246 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
247
248 // Compares extend the operands to the widest of the operand types, not to the
249 // result type.
250 int32_t operandWidth;
251 switch (opKind) {
253 operandWidth = resultType.getWidthOrSentinel();
254 break;
256 // Compares compute with the widest operand, not at the destination type
257 // (which is always i1).
258 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
259 break;
261 operandWidth =
262 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
263 break;
264 }
265
266 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
267 if (!lhs)
268 return {};
269 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
270 if (!rhs)
271 return {};
272
273 APInt resultValue = calculate(*lhs, *rhs);
274
275 // If the result type is smaller than the computation then we need to
276 // narrow the constant after the calculation.
277 if (opKind == BinOpKind::DivideOrShift)
278 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
279
280 assert((unsigned)resultType.getWidthOrSentinel() ==
281 resultValue.getBitWidth());
282 return getIntAttr(resultType, resultValue);
283}
284
285/// Applies the canonicalization function `canonicalize` to the given operation.
286///
287/// Determines which (if any) of the operation's operands are constants, and
288/// provides them as arguments to the callback function. Any `invalidvalue` in
289/// the input is mapped to a constant zero. The value returned from the callback
290/// is used as the replacement for `op`, and an additional pad operation is
291/// inserted if necessary. Does nothing if the result of `op` is of unknown
292/// width, in which case the necessity of a pad cannot be determined.
293static LogicalResult canonicalizePrimOp(
294 Operation *op, PatternRewriter &rewriter,
295 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
296 // Can only operate on FIRRTL primitive operations.
297 if (op->getNumResults() != 1)
298 return failure();
299 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
300 if (!type)
301 return failure();
302
303 // Can only operate on operations with a known result width.
304 auto width = type.getBitWidthOrSentinel();
305 if (width < 0)
306 return failure();
307
308 // Determine which of the operands are constants.
309 SmallVector<Attribute, 3> constOperands;
310 constOperands.reserve(op->getNumOperands());
311 for (auto operand : op->getOperands()) {
312 Attribute attr;
313 if (auto *defOp = operand.getDefiningOp())
314 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
315 [&](auto op) { attr = op.getValueAttr(); });
316 constOperands.push_back(attr);
317 }
318
319 // Perform the canonicalization and materialize the result if it is a
320 // constant.
321 auto result = canonicalize(constOperands);
322 if (!result)
323 return failure();
324 Value resultValue;
325 if (auto cst = dyn_cast<Attribute>(result))
326 resultValue = op->getDialect()
327 ->materializeConstant(rewriter, cst, type, op->getLoc())
328 ->getResult(0);
329 else
330 resultValue = cast<Value>(result);
331
332 // Insert a pad if the type widths disagree.
333 if (width !=
334 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
335 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
336
337 // Insert a cast if this is a uint vs. sint or vice versa.
338 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
339 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
340 else if (type_isa<UIntType>(type) &&
341 type_isa<SIntType>(resultValue.getType()))
342 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
343
344 assert(type == resultValue.getType() && "canonicalization changed type");
345 replaceOpAndCopyName(rewriter, op, resultValue);
346 return success();
347}
348
349/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
350/// value if `bitWidth` is 0.
351static APInt getMaxUnsignedValue(unsigned bitWidth) {
352 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
353}
354
355/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
356/// value if `bitWidth` is 0.
357static APInt getMinSignedValue(unsigned bitWidth) {
358 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
359}
360
361/// Get the largest signed value of a given bit width. Returns a 1-bit zero
362/// value if `bitWidth` is 0.
363static APInt getMaxSignedValue(unsigned bitWidth) {
364 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
365}
366
367//===----------------------------------------------------------------------===//
368// Fold Hooks
369//===----------------------------------------------------------------------===//
370
371OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
372 assert(adaptor.getOperands().empty() && "constant has no operands");
373 return getValueAttr();
374}
375
376OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
377 assert(adaptor.getOperands().empty() && "constant has no operands");
378 return getValueAttr();
379}
380
381OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
382 assert(adaptor.getOperands().empty() && "constant has no operands");
383 return getFieldsAttr();
384}
385
386OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
387 assert(adaptor.getOperands().empty() && "constant has no operands");
388 return getValueAttr();
389}
390
391OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
392 assert(adaptor.getOperands().empty() && "constant has no operands");
393 return getValueAttr();
394}
395
396OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
397 assert(adaptor.getOperands().empty() && "constant has no operands");
398 return getValueAttr();
399}
400
401OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
402 assert(adaptor.getOperands().empty() && "constant has no operands");
403 return getValueAttr();
404}
405
406//===----------------------------------------------------------------------===//
407// Binary Operators
408//===----------------------------------------------------------------------===//
409
410OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
412 *this, adaptor.getOperands(), BinOpKind::Normal,
413 [=](const APSInt &a, const APSInt &b) { return a + b; });
414}
415
416void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
419 patterns::AddOfSelf, patterns::AddOfPad>(context);
420}
421
422OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
424 *this, adaptor.getOperands(), BinOpKind::Normal,
425 [=](const APSInt &a, const APSInt &b) { return a - b; });
426}
427
428void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
429 MLIRContext *context) {
430 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
431 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
432 patterns::SubOfPadL, patterns::SubOfPadR>(context);
433}
434
435OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
436 // mul(x, 0) -> 0
437 //
438 // This is legal because it aligns with the Scala FIRRTL Compiler
439 // interpretation of lowering invalid to constant zero before constant
440 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
441 // multiplication this way and will emit "x * 0".
442 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
443 return getIntZerosAttr(getType());
444
446 *this, adaptor.getOperands(), BinOpKind::Normal,
447 [=](const APSInt &a, const APSInt &b) { return a * b; });
448}
449
450OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
451 /// div(x, x) -> 1
452 ///
453 /// Division by zero is undefined in the FIRRTL specification. This fold
454 /// exploits that fact to optimize self division to one. Note: this should
455 /// supersede any division with invalid or zero. Division of invalid by
456 /// invalid should be one.
457 if (getLhs() == getRhs()) {
458 auto width = getType().base().getWidthOrSentinel();
459 if (width == -1)
460 width = 2;
461 // Only fold if we have at least 1 bit of width to represent the `1` value.
462 if (width != 0)
463 return getIntAttr(getType(), APInt(width, 1));
464 }
465
466 // div(0, x) -> 0
467 //
468 // This is legal because it aligns with the Scala FIRRTL Compiler
469 // interpretation of lowering invalid to constant zero before constant
470 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
471 // division this way and will emit "0 / x".
472 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
473 return getIntZerosAttr(getType());
474
475 /// div(x, 1) -> x : (uint, uint) -> uint
476 ///
477 /// UInt division by one returns the numerator. SInt division can't
478 /// be folded here because it increases the return type bitwidth by
479 /// one and requires sign extension (a new op).
480 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
481 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
482 return getLhs();
483
485 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
486 [=](const APSInt &a, const APSInt &b) -> APInt {
487 if (!!b)
488 return a / b;
489 return APInt(a.getBitWidth(), 0);
490 });
491}
492
493OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
494 // rem(x, x) -> 0
495 //
496 // Division by zero is undefined in the FIRRTL specification. This fold
497 // exploits that fact to optimize self division remainder to zero. Note:
498 // this should supersede any division with invalid or zero. Remainder of
499 // division of invalid by invalid should be zero.
500 if (getLhs() == getRhs())
501 return getIntZerosAttr(getType());
502
503 // rem(0, x) -> 0
504 //
505 // This is legal because it aligns with the Scala FIRRTL Compiler
506 // interpretation of lowering invalid to constant zero before constant
507 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
508 // division this way and will emit "0 % x".
509 if (isConstantZero(adaptor.getLhs()))
510 return getIntZerosAttr(getType());
511
513 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
514 [=](const APSInt &a, const APSInt &b) -> APInt {
515 if (!!b)
516 return a % b;
517 return APInt(a.getBitWidth(), 0);
518 });
519}
520
521OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
523 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
525}
526
527OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
529 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
530 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
531}
532
533OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
535 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
536 [=](const APSInt &a, const APSInt &b) -> APInt {
537 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
538 : a.ashr(b);
539 });
540}
541
542// TODO: Move to DRR.
543OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
544 if (auto rhsCst = getConstant(adaptor.getRhs())) {
545 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
546 if (rhsCst->isZero())
547 return getIntZerosAttr(getType());
548
549 /// and(x, -1) -> x
550 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
551 getRhs().getType() == getType())
552 return getLhs();
553 }
554
555 if (auto lhsCst = getConstant(adaptor.getLhs())) {
556 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
557 if (lhsCst->isZero())
558 return getIntZerosAttr(getType());
559
560 /// and(-1, x) -> x
561 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
562 getRhs().getType() == getType())
563 return getRhs();
564 }
565
566 /// and(x, x) -> x
567 if (getLhs() == getRhs() && getRhs().getType() == getType())
568 return getRhs();
569
571 *this, adaptor.getOperands(), BinOpKind::Normal,
572 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
573}
574
575void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
576 MLIRContext *context) {
577 results
578 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
579 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
580 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
581}
582
583OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
584 if (auto rhsCst = getConstant(adaptor.getRhs())) {
585 /// or(x, 0) -> x
586 if (rhsCst->isZero() && getLhs().getType() == getType())
587 return getLhs();
588
589 /// or(x, -1) -> -1
590 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
591 getLhs().getType() == getType())
592 return getRhs();
593 }
594
595 if (auto lhsCst = getConstant(adaptor.getLhs())) {
596 /// or(0, x) -> x
597 if (lhsCst->isZero() && getRhs().getType() == getType())
598 return getRhs();
599
600 /// or(-1, x) -> -1
601 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
602 getRhs().getType() == getType())
603 return getLhs();
604 }
605
606 /// or(x, x) -> x
607 if (getLhs() == getRhs() && getRhs().getType() == getType())
608 return getRhs();
609
611 *this, adaptor.getOperands(), BinOpKind::Normal,
612 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
613}
614
615void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
616 MLIRContext *context) {
617 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
618 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
619 patterns::OrOrr>(context);
620}
621
622OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
623 /// xor(x, 0) -> x
624 if (auto rhsCst = getConstant(adaptor.getRhs()))
625 if (rhsCst->isZero() &&
626 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
627 return getLhs();
628
629 /// xor(x, 0) -> x
630 if (auto lhsCst = getConstant(adaptor.getLhs()))
631 if (lhsCst->isZero() &&
632 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
633 return getRhs();
634
635 /// xor(x, x) -> 0
636 if (getLhs() == getRhs())
637 return getIntAttr(
638 getType(),
639 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
640
642 *this, adaptor.getOperands(), BinOpKind::Normal,
643 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
644}
645
646void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
647 MLIRContext *context) {
648 results.insert<patterns::extendXor, patterns::moveConstXor,
649 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
650 context);
651}
652
653void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
654 MLIRContext *context) {
655 results.insert<patterns::LEQWithConstLHS>(context);
656}
657
658OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
659 bool isUnsigned = getLhs().getType().base().isUnsigned();
660
661 // leq(x, x) -> 1
662 if (getLhs() == getRhs())
663 return getIntAttr(getType(), APInt(1, 1));
664
665 // Comparison against constant outside type bounds.
666 if (auto width = getLhs().getType().base().getWidth()) {
667 if (auto rhsCst = getConstant(adaptor.getRhs())) {
668 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
669 commonWidth = std::max(commonWidth, 1);
670
671 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
672 // This can never occur since const is unsigned and cannot be less than 0.
673
674 // leq(x, const) -> 0 where const < minValue of the signed type of x
675 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
676 .slt(getMinSignedValue(*width).sext(commonWidth)))
677 return getIntAttr(getType(), APInt(1, 0));
678
679 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
680 if (isUnsigned && rhsCst->zext(commonWidth)
681 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
682 return getIntAttr(getType(), APInt(1, 1));
683
684 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
685 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
686 .sge(getMaxSignedValue(*width).sext(commonWidth)))
687 return getIntAttr(getType(), APInt(1, 1));
688 }
689 }
690
692 *this, adaptor.getOperands(), BinOpKind::Compare,
693 [=](const APSInt &a, const APSInt &b) -> APInt {
694 return APInt(1, a <= b);
695 });
696}
697
698void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
699 MLIRContext *context) {
700 results.insert<patterns::LTWithConstLHS>(context);
701}
702
703OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
704 IntType lhsType = getLhs().getType();
705 bool isUnsigned = lhsType.isUnsigned();
706
707 // lt(x, x) -> 0
708 if (getLhs() == getRhs())
709 return getIntAttr(getType(), APInt(1, 0));
710
711 // lt(x, 0) -> 0 when x is unsigned
712 if (auto rhsCst = getConstant(adaptor.getRhs())) {
713 if (rhsCst->isZero() && lhsType.isUnsigned())
714 return getIntAttr(getType(), APInt(1, 0));
715 }
716
717 // Comparison against constant outside type bounds.
718 if (auto width = lhsType.getWidth()) {
719 if (auto rhsCst = getConstant(adaptor.getRhs())) {
720 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
721 commonWidth = std::max(commonWidth, 1);
722
723 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
724 // Handled explicitly above.
725
726 // lt(x, const) -> 0 where const <= minValue of the signed type of x
727 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
728 .sle(getMinSignedValue(*width).sext(commonWidth)))
729 return getIntAttr(getType(), APInt(1, 0));
730
731 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
732 if (isUnsigned && rhsCst->zext(commonWidth)
733 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
734 return getIntAttr(getType(), APInt(1, 1));
735
736 // lt(x, const) -> 1 where const > maxValue of the signed type of x
737 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
738 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
739 return getIntAttr(getType(), APInt(1, 1));
740 }
741 }
742
744 *this, adaptor.getOperands(), BinOpKind::Compare,
745 [=](const APSInt &a, const APSInt &b) -> APInt {
746 return APInt(1, a < b);
747 });
748}
749
750void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
751 MLIRContext *context) {
752 results.insert<patterns::GEQWithConstLHS>(context);
753}
754
755OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
756 IntType lhsType = getLhs().getType();
757 bool isUnsigned = lhsType.isUnsigned();
758
759 // geq(x, x) -> 1
760 if (getLhs() == getRhs())
761 return getIntAttr(getType(), APInt(1, 1));
762
763 // geq(x, 0) -> 1 when x is unsigned
764 if (auto rhsCst = getConstant(adaptor.getRhs())) {
765 if (rhsCst->isZero() && isUnsigned)
766 return getIntAttr(getType(), APInt(1, 1));
767 }
768
769 // Comparison against constant outside type bounds.
770 if (auto width = lhsType.getWidth()) {
771 if (auto rhsCst = getConstant(adaptor.getRhs())) {
772 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
773 commonWidth = std::max(commonWidth, 1);
774
775 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
776 if (isUnsigned && rhsCst->zext(commonWidth)
777 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
778 return getIntAttr(getType(), APInt(1, 0));
779
780 // geq(x, const) -> 0 where const > maxValue of the signed type of x
781 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
782 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
783 return getIntAttr(getType(), APInt(1, 0));
784
785 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
786 // Handled explicitly above.
787
788 // geq(x, const) -> 1 where const <= minValue of the signed type of x
789 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
790 .sle(getMinSignedValue(*width).sext(commonWidth)))
791 return getIntAttr(getType(), APInt(1, 1));
792 }
793 }
794
796 *this, adaptor.getOperands(), BinOpKind::Compare,
797 [=](const APSInt &a, const APSInt &b) -> APInt {
798 return APInt(1, a >= b);
799 });
800}
801
802void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
803 MLIRContext *context) {
804 results.insert<patterns::GTWithConstLHS>(context);
805}
806
807OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
808 IntType lhsType = getLhs().getType();
809 bool isUnsigned = lhsType.isUnsigned();
810
811 // gt(x, x) -> 0
812 if (getLhs() == getRhs())
813 return getIntAttr(getType(), APInt(1, 0));
814
815 // Comparison against constant outside type bounds.
816 if (auto width = lhsType.getWidth()) {
817 if (auto rhsCst = getConstant(adaptor.getRhs())) {
818 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
819 commonWidth = std::max(commonWidth, 1);
820
821 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
822 if (isUnsigned && rhsCst->zext(commonWidth)
823 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
824 return getIntAttr(getType(), APInt(1, 0));
825
826 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
827 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
828 .sge(getMaxSignedValue(*width).sext(commonWidth)))
829 return getIntAttr(getType(), APInt(1, 0));
830
831 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
832 // This can never occur since const is unsigned and cannot be less than 0.
833
834 // gt(x, const) -> 1 where const < minValue of the signed type of x
835 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
836 .slt(getMinSignedValue(*width).sext(commonWidth)))
837 return getIntAttr(getType(), APInt(1, 1));
838 }
839 }
840
842 *this, adaptor.getOperands(), BinOpKind::Compare,
843 [=](const APSInt &a, const APSInt &b) -> APInt {
844 return APInt(1, a > b);
845 });
846}
847
848OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
849 // eq(x, x) -> 1
850 if (getLhs() == getRhs())
851 return getIntAttr(getType(), APInt(1, 1));
852
853 if (auto rhsCst = getConstant(adaptor.getRhs())) {
854 /// eq(x, 1) -> x when x is 1 bit.
855 /// TODO: Support SInt<1> on the LHS etc.
856 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
857 getRhs().getType() == getType())
858 return getLhs();
859 }
860
862 *this, adaptor.getOperands(), BinOpKind::Compare,
863 [=](const APSInt &a, const APSInt &b) -> APInt {
864 return APInt(1, a == b);
865 });
866}
867
868LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
869 return canonicalizePrimOp(
870 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
871 if (auto rhsCst = getConstant(operands[1])) {
872 auto width = op.getLhs().getType().getBitWidthOrSentinel();
873
874 // eq(x, 0) -> not(x) when x is 1 bit.
875 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
876 op.getRhs().getType() == op.getType()) {
877 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
878 .getResult();
879 }
880
881 // eq(x, 0) -> not(orr(x)) when x is >1 bit
882 if (rhsCst->isZero() && width > 1) {
883 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
884 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
885 }
886
887 // eq(x, ~0) -> andr(x) when x is >1 bit
888 if (rhsCst->isAllOnes() && width > 1 &&
889 op.getLhs().getType() == op.getRhs().getType()) {
890 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
891 .getResult();
892 }
893 }
894 return {};
895 });
896}
897
898OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
899 // neq(x, x) -> 0
900 if (getLhs() == getRhs())
901 return getIntAttr(getType(), APInt(1, 0));
902
903 if (auto rhsCst = getConstant(adaptor.getRhs())) {
904 /// neq(x, 0) -> x when x is 1 bit.
905 /// TODO: Support SInt<1> on the LHS etc.
906 if (rhsCst->isZero() && getLhs().getType() == getType() &&
907 getRhs().getType() == getType())
908 return getLhs();
909 }
910
912 *this, adaptor.getOperands(), BinOpKind::Compare,
913 [=](const APSInt &a, const APSInt &b) -> APInt {
914 return APInt(1, a != b);
915 });
916}
917
918LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
919 return canonicalizePrimOp(
920 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
921 if (auto rhsCst = getConstant(operands[1])) {
922 auto width = op.getLhs().getType().getBitWidthOrSentinel();
923
924 // neq(x, 1) -> not(x) when x is 1 bit
925 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
926 op.getRhs().getType() == op.getType()) {
927 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
928 .getResult();
929 }
930
931 // neq(x, 0) -> orr(x) when x is >1 bit
932 if (rhsCst->isZero() && width > 1) {
933 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
934 .getResult();
935 }
936
937 // neq(x, ~0) -> not(andr(x))) when x is >1 bit
938 if (rhsCst->isAllOnes() && width > 1 &&
939 op.getLhs().getType() == op.getRhs().getType()) {
940 auto andrOp =
941 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
942 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
943 }
944 }
945
946 return {};
947 });
948}
949
950OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
951 // TODO: implement constant folding, etc.
952 // Tracked in https://github.com/llvm/circt/issues/6696.
953 return {};
954}
955
956OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
957 // TODO: implement constant folding, etc.
958 // Tracked in https://github.com/llvm/circt/issues/6724.
959 return {};
960}
961
962OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
963 if (auto rhsCst = getConstant(adaptor.getRhs())) {
964 if (auto lhsCst = getConstant(adaptor.getLhs())) {
965
966 return IntegerAttr::get(IntegerType::get(getContext(),
967 lhsCst->getBitWidth(),
968 IntegerType::Signed),
969 lhsCst->ashr(*rhsCst));
970 }
971
972 if (rhsCst->isZero())
973 return getLhs();
974 }
975
976 return {};
977}
978
979OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
980 if (auto rhsCst = getConstant(adaptor.getRhs())) {
981 // Constant folding
982 if (auto lhsCst = getConstant(adaptor.getLhs()))
983
984 return IntegerAttr::get(IntegerType::get(getContext(),
985 lhsCst->getBitWidth(),
986 IntegerType::Signed),
987 lhsCst->shl(*rhsCst));
988
989 // integer.shl(x, 0) -> x
990 if (rhsCst->isZero())
991 return getLhs();
992 }
993
994 return {};
995}
996
997//===----------------------------------------------------------------------===//
998// Unary Operators
999//===----------------------------------------------------------------------===//
1000
1001OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1002 auto base = getInput().getType();
1003 auto w = getBitWidth(base);
1004 if (w)
1005 return getIntAttr(getType(), APInt(32, *w));
1006 return {};
1007}
1008
1009OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1010 // No constant can be 'x' by definition.
1011 if (auto cst = getConstant(adaptor.getArg()))
1012 return getIntAttr(getType(), APInt(1, 0));
1013 return {};
1014}
1015
1016OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1017 // No effect.
1018 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1019 return getInput();
1020
1021 // Be careful to only fold the cast into the constant if the size is known.
1022 // Otherwise width inference may produce differently-sized constants if the
1023 // sign changes.
1024 if (getType().base().hasWidth())
1025 if (auto cst = getConstant(adaptor.getInput()))
1026 return getIntAttr(getType(), *cst);
1027
1028 return {};
1029}
1030
1031void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1032 MLIRContext *context) {
1033 results.insert<patterns::StoUtoS>(context);
1034}
1035
1036OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1037 // No effect.
1038 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1039 return getInput();
1040
1041 // Be careful to only fold the cast into the constant if the size is known.
1042 // Otherwise width inference may produce differently-sized constants if the
1043 // sign changes.
1044 if (getType().base().hasWidth())
1045 if (auto cst = getConstant(adaptor.getInput()))
1046 return getIntAttr(getType(), *cst);
1047
1048 return {};
1049}
1050
1051void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1052 MLIRContext *context) {
1053 results.insert<patterns::UtoStoU>(context);
1054}
1055
1056OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1057 // No effect.
1058 if (getInput().getType() == getType())
1059 return getInput();
1060
1061 // Constant fold.
1062 if (auto cst = getConstant(adaptor.getInput()))
1063 return BoolAttr::get(getContext(), cst->getBoolValue());
1064
1065 return {};
1066}
1067
1068OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1069 if (auto cst = getConstant(adaptor.getInput()))
1070 return BoolAttr::get(getContext(), cst->getBoolValue());
1071 return {};
1072}
1073
1074OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1075 // No effect.
1076 if (getInput().getType() == getType())
1077 return getInput();
1078
1079 // Constant fold.
1080 if (auto cst = getConstant(adaptor.getInput()))
1081 return BoolAttr::get(getContext(), cst->getBoolValue());
1082
1083 return {};
1084}
1085
1086OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1087 if (!hasKnownWidthIntTypes(*this))
1088 return {};
1089
1090 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1091 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1092 getType().base().getWidthOrSentinel()))
1093 return getIntAttr(getType(), *cst);
1094
1095 return {};
1096}
1097
1098void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1099 MLIRContext *context) {
1100 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1101}
1102
1103OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1104 if (!hasKnownWidthIntTypes(*this))
1105 return {};
1106
1107 // FIRRTL negate always adds a bit.
1108 // -x ---> 0-sext(x) or 0-zext(x)
1109 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1110 getType().base().getWidthOrSentinel()))
1111 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1112
1113 return {};
1114}
1115
1116OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1117 if (!hasKnownWidthIntTypes(*this))
1118 return {};
1119
1120 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1121 getType().base().getWidthOrSentinel()))
1122 return getIntAttr(getType(), ~*cst);
1123
1124 return {};
1125}
1126
1127void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1128 MLIRContext *context) {
1129 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1130 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1131 patterns::NotGt>(context);
1132}
1133
1134class ReductionCat : public mlir::RewritePattern {
1135public:
1136 ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
1137 : RewritePattern(opName, 0, context) {}
1138
1139 /// Handle a constant operand in the cat operation.
1140 /// Returns true if the entire reduction can be replaced with a constant.
1141 /// May add non-zero constants to the remaining operands list.
1142 virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1143 ConstantOp constantOp,
1144 SmallVectorImpl<Value> &remaining) const = 0;
1145
1146 /// Return the unit value for this reduction operation:
1147 virtual bool getIdentityValue() const = 0;
1148
1149 LogicalResult
1150 matchAndRewrite(Operation *op,
1151 mlir::PatternRewriter &rewriter) const override {
1152 // Check if the operand is a cat operation
1153 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1154 if (!catOp)
1155 return failure();
1156
1157 SmallVector<Value> nonConstantOperands;
1158
1159 // Process each operand of the cat operation
1160 for (auto operand : catOp.getInputs()) {
1161 if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1162 // Handle constant operands - may short-circuit the entire operation
1163 if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
1164 return success();
1165 } else {
1166 // Keep non-constant operands for further processing
1167 nonConstantOperands.push_back(operand);
1168 }
1169 }
1170
1171 // If no operands remain, replace with identity value
1172 if (nonConstantOperands.empty()) {
1173 replaceOpWithNewOpAndCopyName<ConstantOp>(
1174 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1175 APInt(1, getIdentityValue()));
1176 return success();
1177 }
1178
1179 // If only one operand remains, apply reduction directly to it
1180 if (nonConstantOperands.size() == 1) {
1181 rewriter.modifyOpInPlace(
1182 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1183 return success();
1184 }
1185
1186 // Multiple operands remain - optimize only when cat has a single use.
1187 if (catOp->hasOneUse() &&
1188 nonConstantOperands.size() < catOp->getNumOperands()) {
1189 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1190 nonConstantOperands);
1191 return success();
1192 }
1193 return failure();
1194 }
1195};
1196
1197class OrRCat : public ReductionCat {
1198public:
1199 OrRCat(MLIRContext *context)
1200 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1201 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1202 ConstantOp value,
1203 SmallVectorImpl<Value> &remaining) const override {
1204 if (value.getValue().isZero())
1205 return false;
1206
1207 replaceOpWithNewOpAndCopyName<ConstantOp>(
1208 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1209 llvm::APInt(1, 1));
1210 return true;
1211 }
1212 bool getIdentityValue() const override { return false; }
1213};
1214
1215class AndRCat : public ReductionCat {
1216public:
1217 AndRCat(MLIRContext *context)
1218 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1219 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1220 ConstantOp value,
1221 SmallVectorImpl<Value> &remaining) const override {
1222 if (value.getValue().isAllOnes())
1223 return false;
1224
1225 replaceOpWithNewOpAndCopyName<ConstantOp>(
1226 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1227 llvm::APInt(1, 0));
1228 return true;
1229 }
1230 bool getIdentityValue() const override { return true; }
1231};
1232
1233class XorRCat : public ReductionCat {
1234public:
1235 XorRCat(MLIRContext *context)
1236 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1237 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1238 ConstantOp value,
1239 SmallVectorImpl<Value> &remaining) const override {
1240 if (value.getValue().isZero())
1241 return false;
1242 remaining.push_back(value);
1243 return false;
1244 }
1245 bool getIdentityValue() const override { return false; }
1246};
1247
1248OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1249 if (!hasKnownWidthIntTypes(*this))
1250 return {};
1251
1252 if (getInput().getType().getBitWidthOrSentinel() == 0)
1253 return getIntAttr(getType(), APInt(1, 1));
1254
1255 // x == -1
1256 if (auto cst = getConstant(adaptor.getInput()))
1257 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1258
1259 // one bit is identity. Only applies to UInt since we can't make a cast
1260 // here.
1261 if (isUInt1(getInput().getType()))
1262 return getInput();
1263
1264 return {};
1265}
1266
1267void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1268 MLIRContext *context) {
1269 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1270 patterns::AndRPadS, patterns::AndRCatAndR_left,
1271 patterns::AndRCatAndR_right, AndRCat>(context);
1272}
1273
1274OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1275 if (!hasKnownWidthIntTypes(*this))
1276 return {};
1277
1278 if (getInput().getType().getBitWidthOrSentinel() == 0)
1279 return getIntAttr(getType(), APInt(1, 0));
1280
1281 // x != 0
1282 if (auto cst = getConstant(adaptor.getInput()))
1283 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1284
1285 // one bit is identity. Only applies to UInt since we can't make a cast
1286 // here.
1287 if (isUInt1(getInput().getType()))
1288 return getInput();
1289
1290 return {};
1291}
1292
1293void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1294 MLIRContext *context) {
1295 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1296 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
1297 context);
1298}
1299
1300OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1301 if (!hasKnownWidthIntTypes(*this))
1302 return {};
1303
1304 if (getInput().getType().getBitWidthOrSentinel() == 0)
1305 return getIntAttr(getType(), APInt(1, 0));
1306
1307 // popcount(x) & 1
1308 if (auto cst = getConstant(adaptor.getInput()))
1309 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1310
1311 // one bit is identity. Only applies to UInt since we can't make a cast here.
1312 if (isUInt1(getInput().getType()))
1313 return getInput();
1314
1315 return {};
1316}
1317
1318void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1319 MLIRContext *context) {
1320 results
1321 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1322 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
1323 context);
1324}
1325
1326//===----------------------------------------------------------------------===//
1327// Other Operators
1328//===----------------------------------------------------------------------===//
1329
1330OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1331 auto inputs = getInputs();
1332 auto inputAdaptors = adaptor.getInputs();
1333
1334 // If no inputs, return 0-bit value
1335 if (inputs.empty())
1336 return getIntZerosAttr(getType());
1337
1338 // If single input and same type, return it
1339 if (inputs.size() == 1 && inputs[0].getType() == getType())
1340 return inputs[0];
1341
1342 // Make sure it's safe to fold.
1343 if (!hasKnownWidthIntTypes(*this))
1344 return {};
1345
1346 // Filter out zero-width operands
1347 SmallVector<Value> nonZeroInputs;
1348 SmallVector<Attribute> nonZeroAttributes;
1349 bool allConstant = true;
1350 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1351 auto inputType = type_cast<IntType>(input.getType());
1352 if (inputType.getBitWidthOrSentinel() != 0) {
1353 nonZeroInputs.push_back(input);
1354 if (!attr)
1355 allConstant = false;
1356 if (nonZeroInputs.size() > 1 && !allConstant)
1357 return {};
1358 }
1359 }
1360
1361 // If all inputs were zero-width, return 0-bit value
1362 if (nonZeroInputs.empty())
1363 return getIntZerosAttr(getType());
1364
1365 // If only one non-zero input and it has the same type as result, return it
1366 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1367 return nonZeroInputs[0];
1368
1369 if (!hasKnownWidthIntTypes(*this))
1370 return {};
1371
1372 // Constant fold cat - concatenate all constant operands
1373 SmallVector<APInt> constants;
1374 for (auto inputAdaptor : inputAdaptors) {
1375 if (auto cst = getConstant(inputAdaptor))
1376 constants.push_back(*cst);
1377 else
1378 return {}; // Not all operands are constant
1379 }
1380
1381 assert(!constants.empty());
1382 // Concatenate all constants from left to right
1383 APInt result = constants[0];
1384 for (size_t i = 1; i < constants.size(); ++i)
1385 result = result.concat(constants[i]);
1386
1387 return getIntAttr(getType(), result);
1388}
1389
1390void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1391 MLIRContext *context) {
1392 results.insert<patterns::DShlOfConstant>(context);
1393}
1394
1395void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1396 MLIRContext *context) {
1397 results.insert<patterns::DShrOfConstant>(context);
1398}
1399
1400namespace {
1401
1402/// Canonicalize a cat tree into single cat operation.
1403class FlattenCat : public mlir::OpRewritePattern<CatPrimOp> {
1404public:
1405 using OpRewritePattern::OpRewritePattern;
1406
1407 LogicalResult
1408 matchAndRewrite(CatPrimOp cat,
1409 mlir::PatternRewriter &rewriter) const override {
1410 if (!hasKnownWidthIntTypes(cat) ||
1411 cat.getType().getBitWidthOrSentinel() == 0)
1412 return failure();
1413
1414 // If this is not a root of concat tree, skip.
1415 if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
1416 return failure();
1417
1418 // Try flattening the cat tree.
1419 SmallVector<Value> operands;
1420 SmallVector<Value> worklist;
1421 auto pushOperands = [&worklist](CatPrimOp op) {
1422 for (auto operand : llvm::reverse(op.getInputs()))
1423 worklist.push_back(operand);
1424 };
1425 pushOperands(cat);
1426 bool hasSigned = false, hasUnsigned = false;
1427 while (!worklist.empty()) {
1428 auto value = worklist.pop_back_val();
1429 auto catOp = value.getDefiningOp<CatPrimOp>();
1430 if (!catOp) {
1431 operands.push_back(value);
1432 (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
1433 continue;
1434 }
1435
1436 pushOperands(catOp);
1437 }
1438
1439 // Helper function to cast signed values to unsigned. CatPrimOp converts
1440 // signed values to unsigned so we need to do it explicitly here.
1441 auto castToUIntIfSigned = [&](Value value) -> Value {
1442 if (type_isa<UIntType>(value.getType()))
1443 return value;
1444 return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
1445 };
1446
1447 assert(operands.size() >= 1 && "zero width cast must be rejected");
1448
1449 if (operands.size() == 1) {
1450 rewriter.replaceOp(cat, castToUIntIfSigned(operands[0]));
1451 return success();
1452 }
1453
1454 if (operands.size() == cat->getNumOperands())
1455 return failure();
1456
1457 // If types are mixed, cast all operands to unsigned.
1458 if (hasSigned && hasUnsigned)
1459 for (auto &operand : operands)
1460 operand = castToUIntIfSigned(operand);
1461
1462 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1463 operands);
1464 return success();
1465 }
1466};
1467
1468// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
1469class CatOfConstant : public mlir::OpRewritePattern<CatPrimOp> {
1470public:
1471 using OpRewritePattern::OpRewritePattern;
1472
1473 LogicalResult
1474 matchAndRewrite(CatPrimOp cat,
1475 mlir::PatternRewriter &rewriter) const override {
1476 if (!hasKnownWidthIntTypes(cat))
1477 return failure();
1478
1479 SmallVector<Value> operands;
1480
1481 for (size_t i = 0; i < cat->getNumOperands(); ++i) {
1482 auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
1483 if (!cst) {
1484 operands.push_back(cat.getInputs()[i]);
1485 continue;
1486 }
1487 APSInt value = cst.getValue();
1488 size_t j = i + 1;
1489 for (; j < cat->getNumOperands(); ++j) {
1490 auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
1491 if (!nextCst)
1492 break;
1493 value = value.concat(nextCst.getValue());
1494 }
1495
1496 if (j == i + 1) {
1497 // Not folded.
1498 operands.push_back(cst);
1499 } else {
1500 // Folded.
1501 operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
1502 }
1503
1504 i = j - 1;
1505 }
1506
1507 if (operands.size() == cat->getNumOperands())
1508 return failure();
1509
1510 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, cat, cat.getType(),
1511 operands);
1512
1513 return success();
1514 }
1515};
1516
1517} // namespace
1518
1519void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1520 MLIRContext *context) {
1521 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1522 patterns::CatCast, FlattenCat, CatOfConstant>(context);
1523}
1524
1525//===----------------------------------------------------------------------===//
1526// StringConcatOp
1527//===----------------------------------------------------------------------===//
1528
1529OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
1530 // Fold single-operand concat to just the operand.
1531 if (getInputs().size() == 1)
1532 return getInputs()[0];
1533
1534 // Check if all operands are constant strings before accumulating.
1535 if (!llvm::all_of(adaptor.getInputs(), [](Attribute operand) {
1536 return isa_and_nonnull<StringAttr>(operand);
1537 }))
1538 return {};
1539
1540 // All operands are constant strings, concatenate them.
1541 SmallString<64> result;
1542 for (auto operand : adaptor.getInputs())
1543 result += cast<StringAttr>(operand).getValue();
1544
1545 return StringAttr::get(getContext(), result);
1546}
1547
1548namespace {
1549/// Flatten nested string.concat operations into a single concat.
1550/// string.concat(a, string.concat(b, c), d) -> string.concat(a, b, c, d)
1551class FlattenStringConcat : public mlir::OpRewritePattern<StringConcatOp> {
1552public:
1553 using OpRewritePattern::OpRewritePattern;
1554
1555 LogicalResult
1556 matchAndRewrite(StringConcatOp concat,
1557 mlir::PatternRewriter &rewriter) const override {
1558
1559 // Check if any operands are nested concats with a single use. Only inline
1560 // single-use nested concats to avoid fighting with DCE.
1561 bool hasNestedConcat = llvm::any_of(concat.getInputs(), [](Value operand) {
1562 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
1563 return nestedConcat && operand.hasOneUse();
1564 });
1565
1566 if (!hasNestedConcat)
1567 return failure();
1568
1569 // Flatten nested concats that have a single use.
1570 SmallVector<Value> flatOperands;
1571 for (auto input : concat.getInputs()) {
1572 if (auto nestedConcat = input.getDefiningOp<StringConcatOp>();
1573 nestedConcat && input.hasOneUse())
1574 llvm::append_range(flatOperands, nestedConcat.getInputs());
1575 else
1576 flatOperands.push_back(input);
1577 }
1578
1579 rewriter.modifyOpInPlace(concat,
1580 [&]() { concat->setOperands(flatOperands); });
1581 return success();
1582 }
1583};
1584
1585/// Merge consecutive constant strings in a concat and remove empty strings.
1586/// string.concat("a", "b", x, "", "c", "d") -> string.concat("ab", x, "cd")
1587class MergeAdjacentStringConstants
1588 : public mlir::OpRewritePattern<StringConcatOp> {
1589public:
1590 using OpRewritePattern::OpRewritePattern;
1591
1592 LogicalResult
1593 matchAndRewrite(StringConcatOp concat,
1594 mlir::PatternRewriter &rewriter) const override {
1595
1596 SmallVector<Value> newOperands;
1597 SmallString<64> accumulatedLit;
1598 SmallVector<StringConstantOp> accumulatedOps;
1599 bool changed = false;
1600
1601 auto flushLiterals = [&]() {
1602 if (accumulatedOps.empty())
1603 return;
1604
1605 // If only one literal, reuse it.
1606 if (accumulatedOps.size() == 1) {
1607 newOperands.push_back(accumulatedOps[0]);
1608 } else {
1609 // Multiple literals - merge them.
1610 auto newLit = rewriter.createOrFold<StringConstantOp>(
1611 concat.getLoc(), StringAttr::get(getContext(), accumulatedLit));
1612 newOperands.push_back(newLit);
1613 changed = true;
1614 }
1615 accumulatedLit.clear();
1616 accumulatedOps.clear();
1617 };
1618
1619 for (auto operand : concat.getInputs()) {
1620 if (auto litOp = operand.getDefiningOp<StringConstantOp>()) {
1621 // Skip empty strings.
1622 if (litOp.getValue().empty()) {
1623 changed = true;
1624 continue;
1625 }
1626 accumulatedLit += litOp.getValue();
1627 accumulatedOps.push_back(litOp);
1628 } else {
1629 flushLiterals();
1630 newOperands.push_back(operand);
1631 }
1632 }
1633
1634 // Flush any remaining literals.
1635 flushLiterals();
1636
1637 if (!changed)
1638 return failure();
1639
1640 // If no operands remain, replace with empty string.
1641 if (newOperands.empty())
1642 return rewriter.replaceOpWithNewOp<StringConstantOp>(
1643 concat, StringAttr::get(getContext(), "")),
1644 success();
1645
1646 // Single-operand case is handled by the folder.
1647 rewriter.modifyOpInPlace(concat,
1648 [&]() { concat->setOperands(newOperands); });
1649 return success();
1650 }
1651};
1652
1653} // namespace
1654
1655void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
1656 MLIRContext *context) {
1657 results.insert<FlattenStringConcat, MergeAdjacentStringConstants>(context);
1658}
1659
1660//===----------------------------------------------------------------------===//
1661// PropEqOp
1662//===----------------------------------------------------------------------===//
1663
1664OpFoldResult PropEqOp::fold(FoldAdaptor adaptor) {
1665 auto lhsAttr = adaptor.getLhs();
1666 auto rhsAttr = adaptor.getRhs();
1667 if (!lhsAttr || !rhsAttr)
1668 return {};
1669
1670 return BoolAttr::get(getContext(), lhsAttr == rhsAttr);
1671}
1672
1673//===----------------------------------------------------------------------===//
1674// Boolean Property Ops
1675//===----------------------------------------------------------------------===//
1676
1677// Helper to extract the bool value from a BoolAttr.
1678static std::optional<bool> getBoolValue(Attribute attr) {
1679 if (auto boolAttr = dyn_cast_or_null<BoolAttr>(attr))
1680 return boolAttr.getValue();
1681 return std::nullopt;
1682}
1683
1684OpFoldResult BoolAndOp::fold(FoldAdaptor adaptor) {
1685 auto lhs = getBoolValue(adaptor.getLhs());
1686 auto rhs = getBoolValue(adaptor.getRhs());
1687 if (lhs && rhs)
1688 return BoolAttr::get(getContext(), *lhs && *rhs);
1689 // AND with false is always false.
1690 if ((lhs && !*lhs) || (rhs && !*rhs))
1691 return BoolAttr::get(getContext(), false);
1692 // AND with true is identity.
1693 if (lhs && *lhs)
1694 return getRhs();
1695 if (rhs && *rhs)
1696 return getLhs();
1697 return {};
1698}
1699
1700OpFoldResult BoolOrOp::fold(FoldAdaptor adaptor) {
1701 auto lhs = getBoolValue(adaptor.getLhs());
1702 auto rhs = getBoolValue(adaptor.getRhs());
1703 if (lhs && rhs)
1704 return BoolAttr::get(getContext(), *lhs || *rhs);
1705 // OR with true is always true.
1706 if ((lhs && *lhs) || (rhs && *rhs))
1707 return BoolAttr::get(getContext(), true);
1708 // OR with false is identity.
1709 if (lhs && !*lhs)
1710 return getRhs();
1711 if (rhs && !*rhs)
1712 return getLhs();
1713 return {};
1714}
1715
1716OpFoldResult BoolXorOp::fold(FoldAdaptor adaptor) {
1717 auto lhs = getBoolValue(adaptor.getLhs());
1718 auto rhs = getBoolValue(adaptor.getRhs());
1719 if (lhs && rhs)
1720 return BoolAttr::get(getContext(), *lhs ^ *rhs);
1721 // XOR with false is identity.
1722 if (lhs && !*lhs)
1723 return getRhs();
1724 if (rhs && !*rhs)
1725 return getLhs();
1726 return {};
1727}
1728
1729OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1730 auto op = (*this);
1731 // BitCast is redundant if input and result types are same.
1732 if (op.getType() == op.getInput().getType())
1733 return op.getInput();
1734
1735 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1736 // final result type.
1737 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1738 if (op.getType() == in.getInput().getType())
1739 return in.getInput();
1740
1741 return {};
1742}
1743
1744OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1745 IntType inputType = getInput().getType();
1746 IntType resultType = getType();
1747 // If we are extracting the entire input, then return it.
1748 if (inputType == getType() && resultType.hasWidth())
1749 return getInput();
1750
1751 // Constant fold.
1752 if (hasKnownWidthIntTypes(*this))
1753 if (auto cst = getConstant(adaptor.getInput()))
1754 return getIntAttr(resultType,
1755 cst->extractBits(getHi() - getLo() + 1, getLo()));
1756
1757 return {};
1758}
1759
1760struct BitsOfCat : public mlir::OpRewritePattern<BitsPrimOp> {
1761 using OpRewritePattern::OpRewritePattern;
1762
1763 LogicalResult
1764 matchAndRewrite(BitsPrimOp bits,
1765 mlir::PatternRewriter &rewriter) const override {
1766 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1767 if (!cat)
1768 return failure();
1769 int32_t bitPos = bits.getLo();
1770 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1771 if (resultWidth < 0)
1772 return failure();
1773 for (auto operand : llvm::reverse(cat.getInputs())) {
1774 auto operandWidth =
1775 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1776 if (operandWidth < 0)
1777 return failure();
1778 if (bitPos < operandWidth) {
1779 if (bitPos + resultWidth <= operandWidth) {
1780 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1781 bits.getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1782 replaceOpAndCopyName(rewriter, bits, newBits);
1783 return success();
1784 }
1785 return failure();
1786 }
1787 bitPos -= operandWidth;
1788 }
1789 return failure();
1790 }
1791};
1792
1793void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1794 MLIRContext *context) {
1795 results
1796 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1797 patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
1798}
1799
1800/// Replace the specified operation with a 'bits' op from the specified hi/lo
1801/// bits. Insert a cast to handle the case where the original operation
1802/// returned a signed integer.
1803static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1804 unsigned loBit, PatternRewriter &rewriter) {
1805 auto resType = type_cast<IntType>(op->getResult(0).getType());
1806 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1807 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1808
1809 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1810 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1811 } else if (resType.isUnsigned() &&
1812 !type_cast<IntType>(value.getType()).isUnsigned()) {
1813 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1814 }
1815 rewriter.replaceOp(op, value);
1816}
1817
1818template <typename OpTy>
1819static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
1820 // mux : UInt<0> -> 0
1821 if (op.getType().getBitWidthOrSentinel() == 0)
1822 return getIntAttr(op.getType(),
1823 APInt(0, 0, op.getType().isSignedInteger()));
1824
1825 // mux(cond, x, x) -> x
1826 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1827 return op.getHigh();
1828
1829 // The following folds require that the result has a known width. Otherwise
1830 // the mux requires an additional padding operation to be inserted, which is
1831 // not possible in a fold.
1832 if (op.getType().getBitWidthOrSentinel() < 0)
1833 return {};
1834
1835 // mux(0/1, x, y) -> x or y
1836 if (auto cond = getConstant(adaptor.getSel())) {
1837 if (cond->isZero() && op.getLow().getType() == op.getType())
1838 return op.getLow();
1839 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1840 return op.getHigh();
1841 }
1842
1843 // mux(cond, x, cst)
1844 if (auto lowCst = getConstant(adaptor.getLow())) {
1845 // mux(cond, c1, c2)
1846 if (auto highCst = getConstant(adaptor.getHigh())) {
1847 // mux(cond, cst, cst) -> cst
1848 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1849 *highCst == *lowCst)
1850 // Ensure that this has an integer representation. This specifically
1851 // avoids problems with clocks.
1852 if (auto intType = type_dyn_cast<IntType>(op.getType()))
1853 if (intType.hasWidth() &&
1854 (unsigned)intType.getWidthOrSentinel() == highCst->getBitWidth())
1855 return getIntAttr(op.getType(), *highCst);
1856 // mux(cond, 1, 0) -> cond
1857 if (highCst->isOne() && lowCst->isZero() &&
1858 op.getType() == op.getSel().getType())
1859 return op.getSel();
1860
1861 // TODO: x ? ~0 : 0 -> sext(x)
1862 // TODO: "x ? c1 : c2" -> many tricks
1863 }
1864 // TODO: "x ? a : 0" -> sext(x) & a
1865 }
1866
1867 // TODO: "x ? c1 : y" -> "~x ? y : c1"
1868 return {};
1869}
1870
1871OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1872 return foldMux(*this, adaptor);
1873}
1874
1875OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1876 return foldMux(*this, adaptor);
1877}
1878
1879OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1880
1881namespace {
1882
1883// If the mux has a known output width, pad the operands up to this width.
1884// Most folds on mux require that folded operands are of the same width as
1885// the mux itself.
1886class MuxPad : public mlir::OpRewritePattern<MuxPrimOp> {
1887public:
1888 using OpRewritePattern::OpRewritePattern;
1889
1890 LogicalResult
1891 matchAndRewrite(MuxPrimOp mux,
1892 mlir::PatternRewriter &rewriter) const override {
1893 auto width = mux.getType().getBitWidthOrSentinel();
1894 if (width < 0)
1895 return failure();
1896
1897 auto pad = [&](Value input) -> Value {
1898 auto inputWidth =
1899 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1900 if (inputWidth < 0 || width == inputWidth)
1901 return input;
1902 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1903 width)
1904 .getResult();
1905 };
1906
1907 auto newHigh = pad(mux.getHigh());
1908 auto newLow = pad(mux.getLow());
1909 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1910 return failure();
1911
1912 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1913 rewriter, mux, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1914 mux->getAttrs());
1915 return success();
1916 }
1917};
1918
1919// Find muxes which have conditions dominated by other muxes with the same
1920// condition.
1921class MuxSharedCond : public mlir::OpRewritePattern<MuxPrimOp> {
1922public:
1923 using OpRewritePattern::OpRewritePattern;
1924
1925 static const int depthLimit = 5;
1926
1927 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1928 mlir::PatternRewriter &rewriter,
1929 bool updateInPlace) const {
1930 if (updateInPlace) {
1931 rewriter.modifyOpInPlace(mux, [&] {
1932 mux.setOperand(1, high);
1933 mux.setOperand(2, low);
1934 });
1935 return {};
1936 }
1937 rewriter.setInsertionPointAfter(mux);
1938 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1939 ValueRange{mux.getSel(), high, low})
1940 .getResult();
1941 }
1942
1943 // Walk a dependent mux tree assuming the condition cond is true.
1944 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1945 bool updateInPlace, int limit) const {
1946 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1947 if (!mux)
1948 return {};
1949 if (mux.getSel() == cond)
1950 return mux.getHigh();
1951 if (limit > depthLimit)
1952 return {};
1953 updateInPlace &= mux->hasOneUse();
1954
1955 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1956 limit + 1))
1957 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1958
1959 if (Value v =
1960 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1961 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1962 return {};
1963 }
1964
1965 // Walk a dependent mux tree assuming the condition cond is false.
1966 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1967 bool updateInPlace, int limit) const {
1968 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1969 if (!mux)
1970 return {};
1971 if (mux.getSel() == cond)
1972 return mux.getLow();
1973 if (limit > depthLimit)
1974 return {};
1975 updateInPlace &= mux->hasOneUse();
1976
1977 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1978 limit + 1))
1979 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1980
1981 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1982 limit + 1))
1983 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1984
1985 return {};
1986 }
1987
1988 LogicalResult
1989 matchAndRewrite(MuxPrimOp mux,
1990 mlir::PatternRewriter &rewriter) const override {
1991 auto width = mux.getType().getBitWidthOrSentinel();
1992 if (width < 0)
1993 return failure();
1994
1995 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
1996 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1997 return success();
1998 }
1999
2000 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
2001 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
2002 return success();
2003 }
2004
2005 return failure();
2006 }
2007};
2008} // namespace
2009
2010void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
2011 MLIRContext *context) {
2012 results
2013 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
2014 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
2015 patterns::MuxSameTrue, patterns::MuxSameFalse,
2016 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
2017 context);
2018}
2019
2020void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
2021 RewritePatternSet &results, MLIRContext *context) {
2022 results.add<patterns::Mux2PadSel>(context);
2023}
2024
2025void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
2026 RewritePatternSet &results, MLIRContext *context) {
2027 results.add<patterns::Mux4PadSel>(context);
2028}
2029
2030OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
2031 auto input = this->getInput();
2032
2033 // pad(x) -> x if the width doesn't change.
2034 if (input.getType() == getType())
2035 return input;
2036
2037 // Need to know the input width.
2038 auto inputType = input.getType().base();
2039 int32_t width = inputType.getWidthOrSentinel();
2040 if (width == -1)
2041 return {};
2042
2043 // Constant fold.
2044 if (auto cst = getConstant(adaptor.getInput())) {
2045 auto destWidth = getType().base().getWidthOrSentinel();
2046 if (destWidth == -1)
2047 return {};
2048
2049 if (inputType.isSigned() && cst->getBitWidth())
2050 return getIntAttr(getType(), cst->sext(destWidth));
2051 return getIntAttr(getType(), cst->zext(destWidth));
2052 }
2053
2054 return {};
2055}
2056
2057OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
2058 auto input = this->getInput();
2059 IntType inputType = input.getType();
2060 int shiftAmount = getAmount();
2061
2062 // shl(x, 0) -> x
2063 if (shiftAmount == 0)
2064 return input;
2065
2066 // Constant fold.
2067 if (auto cst = getConstant(adaptor.getInput())) {
2068 auto inputWidth = inputType.getWidthOrSentinel();
2069 if (inputWidth != -1) {
2070 auto resultWidth = inputWidth + shiftAmount;
2071 shiftAmount = std::min(shiftAmount, resultWidth);
2072 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
2073 }
2074 }
2075 return {};
2076}
2077
2078OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
2079 auto input = this->getInput();
2080 IntType inputType = input.getType();
2081 int shiftAmount = getAmount();
2082 auto inputWidth = inputType.getWidthOrSentinel();
2083
2084 // shr(x, 0) -> x
2085 // Once the shr width changes, do this: shiftAmount == 0 &&
2086 // (!inputType.isSigned() || inputWidth > 0)
2087 if (shiftAmount == 0 && inputWidth > 0)
2088 return input;
2089
2090 if (inputWidth == -1)
2091 return {};
2092 if (inputWidth == 0)
2093 return getIntZerosAttr(getType());
2094
2095 // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
2096 // If x is signed, it is the sign bit.
2097 if (shiftAmount >= inputWidth && inputType.isUnsigned())
2098 return getIntAttr(getType(), APInt(0, 0, false));
2099
2100 // Constant fold.
2101 if (auto cst = getConstant(adaptor.getInput())) {
2102 APInt value;
2103 if (inputType.isSigned())
2104 value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
2105 else
2106 value = cst->lshr(std::min(shiftAmount, inputWidth));
2107 auto resultWidth = std::max(inputWidth - shiftAmount, 1);
2108 return getIntAttr(getType(), value.trunc(resultWidth));
2109 }
2110 return {};
2111}
2112
2113LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
2114 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2115 if (inputWidth <= 0)
2116 return failure();
2117
2118 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2119 unsigned shiftAmount = op.getAmount();
2120 if (int(shiftAmount) >= inputWidth) {
2121 // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
2122 if (op.getType().base().isUnsigned())
2123 return failure();
2124
2125 // Shifting a signed value by the full width is actually taking the
2126 // sign bit. If the shift amount is greater than the input width, it
2127 // is equivalent to shifting by the input width.
2128 shiftAmount = inputWidth - 1;
2129 }
2130
2131 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
2132 return success();
2133}
2134
2135LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
2136 PatternRewriter &rewriter) {
2137 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2138 if (inputWidth <= 0)
2139 return failure();
2140
2141 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2142 unsigned keepAmount = op.getAmount();
2143 if (keepAmount)
2144 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
2145 rewriter);
2146 return success();
2147}
2148
2149OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
2150 if (hasKnownWidthIntTypes(*this))
2151 if (auto cst = getConstant(adaptor.getInput())) {
2152 int shiftAmount =
2153 getInput().getType().base().getWidthOrSentinel() - getAmount();
2154 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
2155 }
2156
2157 return {};
2158}
2159
2160OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
2161 if (hasKnownWidthIntTypes(*this))
2162 if (auto cst = getConstant(adaptor.getInput()))
2163 return getIntAttr(getType(),
2164 cst->trunc(getType().base().getWidthOrSentinel()));
2165 return {};
2166}
2167
2168LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
2169 PatternRewriter &rewriter) {
2170 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
2171 if (inputWidth <= 0)
2172 return failure();
2173
2174 // If we know the input width, we can canonicalize this into a BitsPrimOp.
2175 unsigned dropAmount = op.getAmount();
2176 if (dropAmount != unsigned(inputWidth))
2177 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
2178 rewriter);
2179 return success();
2180}
2181
2182void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
2183 MLIRContext *context) {
2184 results.add<patterns::SubaccessOfConstant>(context);
2185}
2186
2187OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
2188 // If there is only one input, just return it.
2189 if (adaptor.getInputs().size() == 1)
2190 return getOperand(1);
2191
2192 if (auto constIndex = getConstant(adaptor.getIndex())) {
2193 auto index = constIndex->getZExtValue();
2194 if (index < getInputs().size())
2195 return getInputs()[getInputs().size() - 1 - index];
2196 }
2197
2198 return {};
2199}
2200
2201LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
2202 PatternRewriter &rewriter) {
2203 // If all operands are equal, just canonicalize to it. We can add this
2204 // canonicalization as a folder but it costly to look through all inputs so it
2205 // is added here.
2206 if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
2207 return input == op.getInputs().front();
2208 })) {
2209 replaceOpAndCopyName(rewriter, op, op.getInputs().front());
2210 return success();
2211 }
2212
2213 // If the index width is narrower than the size of inputs, drop front
2214 // elements.
2215 auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
2216 uint64_t inputSize = op.getInputs().size();
2217 if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
2218 rewriter.modifyOpInPlace(op, [&]() {
2219 op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
2220 });
2221 return success();
2222 }
2223
2224 // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
2225 // a[0]`), we can fold the op into subaccess op `a[idx]`.
2226 if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
2227 if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
2228 auto subindex = e.value().template getDefiningOp<SubindexOp>();
2229 return subindex && lastSubindex.getInput() == subindex.getInput() &&
2230 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
2231 })) {
2232 replaceOpWithNewOpAndCopyName<SubaccessOp>(
2233 rewriter, op, lastSubindex.getInput(), op.getIndex());
2234 return success();
2235 }
2236 }
2237
2238 // If the size is 2, canonicalize into a normal mux to introduce more folds.
2239 if (op.getInputs().size() != 2)
2240 return failure();
2241
2242 // TODO: Handle even when `index` doesn't have uint<1>.
2243 auto uintType = op.getIndex().getType();
2244 if (uintType.getBitWidthOrSentinel() != 1)
2245 return failure();
2246
2247 // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
2248 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
2249 rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
2250 return success();
2251}
2252
2253//===----------------------------------------------------------------------===//
2254// Declarations
2255//===----------------------------------------------------------------------===//
2256
2257/// Scan all the uses of the specified value, checking to see if there is
2258/// exactly one connect that has the value as its destination. This returns the
2259/// operation if found and if all the other users are "reads" from the value.
2260/// Returns null if there are no connects, or multiple connects to the value, or
2261/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2262///
2263/// Note that this will simply return the connect, which is located *anywhere*
2264/// after the definition of the value. Users of this function are likely
2265/// interested in the source side of the returned connect, the definition of
2266/// which does likely not dominate the original value.
2267MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
2268 MatchingConnectOp connect;
2269 for (Operation *user : value.getUsers()) {
2270 // If we see an attach or aggregate sublements, just conservatively fail.
2271 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2272 return {};
2273
2274 if (auto aConnect = dyn_cast<FConnectLike>(user))
2275 if (aConnect.getDest() == value) {
2276 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2277 // If this is not a matching connect, a second matching connect or in a
2278 // different block, fail.
2279 if (!matchingConnect || (connect && connect != matchingConnect) ||
2280 matchingConnect->getBlock() != value.getParentBlock())
2281 return {};
2282 connect = matchingConnect;
2283 }
2284 }
2285 return connect;
2286}
2287
2288// Forward simple values through wire's and reg's.
2289static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
2290 PatternRewriter &rewriter) {
2291 // While we can do this for nearly all wires, we currently limit it to simple
2292 // things.
2293 Operation *connectedDecl = op.getDest().getDefiningOp();
2294 if (!connectedDecl)
2295 return failure();
2296
2297 // Only support wire and reg for now.
2298 if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
2299 return failure();
2300 if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
2301 !hasDroppableName(connectedDecl) ||
2302 cast<Forceable>(connectedDecl).isForceable())
2303 return failure();
2304
2305 // Only forward if the types exactly match and there is one connect.
2306 if (getSingleConnectUserOf(op.getDest()) != op)
2307 return failure();
2308
2309 // Only forward if there is more than one use
2310 if (connectedDecl->hasOneUse())
2311 return failure();
2312
2313 // Only do this if the connectee and the declaration are in the same block.
2314 auto *declBlock = connectedDecl->getBlock();
2315 auto *srcValueOp = op.getSrc().getDefiningOp();
2316 if (!srcValueOp) {
2317 // Ports are ok for wires but not registers.
2318 if (!isa<WireOp>(connectedDecl))
2319 return failure();
2320
2321 } else {
2322 // Constants/invalids in the same block are ok to forward, even through
2323 // reg's since the clocking doesn't matter for constants.
2324 if (!isa<ConstantOp>(srcValueOp))
2325 return failure();
2326 if (srcValueOp->getBlock() != declBlock)
2327 return failure();
2328 }
2329
2330 // Ok, we know we are doing the transformation.
2331
2332 auto replacement = op.getSrc();
2333 // This will be replaced with the constant source. First, make sure the
2334 // constant dominates all users.
2335 if (srcValueOp && srcValueOp != &declBlock->front())
2336 srcValueOp->moveBefore(&declBlock->front());
2337
2338 // Replace all things *using* the decl with the constant/port, and
2339 // remove the declaration.
2340 replaceOpAndCopyName(rewriter, connectedDecl, replacement);
2341
2342 // Remove the connect
2343 rewriter.eraseOp(op);
2344 return success();
2345}
2346
2347void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2348 MLIRContext *context) {
2349 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2350 context);
2351}
2352
2353LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2354 PatternRewriter &rewriter) {
2355 // TODO: Canonicalize towards explicit extensions and flips here.
2356
2357 // If there is a simple value connected to a foldable decl like a wire or reg,
2358 // see if we can eliminate the decl.
2359 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2360 return success();
2361 return failure();
2362}
2363
2364//===----------------------------------------------------------------------===//
2365// Statements
2366//===----------------------------------------------------------------------===//
2367
2368/// If the specified value has an AttachOp user strictly dominating by
2369/// "dominatingAttach" then return it.
2370static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
2371 for (auto *user : value.getUsers()) {
2372 auto attach = dyn_cast<AttachOp>(user);
2373 if (!attach || attach == dominatedAttach)
2374 continue;
2375 if (attach->isBeforeInBlock(dominatedAttach))
2376 return attach;
2377 }
2378 return {};
2379}
2380
2381LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
2382 // Single operand attaches are a noop.
2383 if (op.getNumOperands() <= 1) {
2384 rewriter.eraseOp(op);
2385 return success();
2386 }
2387
2388 for (auto operand : op.getOperands()) {
2389 // Check to see if any of our operands has other attaches to it:
2390 // attach x, y
2391 // ...
2392 // attach x, z
2393 // If so, we can merge these into "attach x, y, z".
2394 if (auto attach = getDominatingAttachUser(operand, op)) {
2395 SmallVector<Value> newOperands(op.getOperands());
2396 for (auto newOperand : attach.getOperands())
2397 if (newOperand != operand) // Don't add operand twice.
2398 newOperands.push_back(newOperand);
2399 AttachOp::create(rewriter, op->getLoc(), newOperands);
2400 rewriter.eraseOp(attach);
2401 rewriter.eraseOp(op);
2402 return success();
2403 }
2404
2405 // If this wire is *only* used by an attach then we can just delete
2406 // it.
2407 // TODO: May need to be sensitive to "don't touch" or other
2408 // annotations.
2409 if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
2410 if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
2411 !wire.isForceable()) {
2412 SmallVector<Value> newOperands;
2413 for (auto newOperand : op.getOperands())
2414 if (newOperand != operand) // Don't the add wire.
2415 newOperands.push_back(newOperand);
2416
2417 AttachOp::create(rewriter, op->getLoc(), newOperands);
2418 rewriter.eraseOp(op);
2419 rewriter.eraseOp(wire);
2420 return success();
2421 }
2422 }
2423 }
2424 return failure();
2425}
2426
2427/// Replaces the given op with the contents of the given single-block region.
2428static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
2429 Region &region) {
2430 assert(llvm::hasSingleElement(region) && "expected single-region block");
2431 rewriter.inlineBlockBefore(&region.front(), op, {});
2432}
2433
2434LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
2435 if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
2436 if (constant.getValue().isAllOnes())
2437 replaceOpWithRegion(rewriter, op, op.getThenRegion());
2438 else if (op.hasElseRegion() && !op.getElseRegion().empty())
2439 replaceOpWithRegion(rewriter, op, op.getElseRegion());
2440
2441 rewriter.eraseOp(op);
2442
2443 return success();
2444 }
2445
2446 // Erase empty if-else block.
2447 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
2448 op.getElseBlock().empty()) {
2449 rewriter.eraseBlock(&op.getElseBlock());
2450 return success();
2451 }
2452
2453 // Erase empty whens.
2454
2455 // If there is stuff in the then block, leave this operation alone.
2456 if (!op.getThenBlock().empty())
2457 return failure();
2458
2459 // If not and there is no else, then this operation is just useless.
2460 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
2461 rewriter.eraseOp(op);
2462 return success();
2463 }
2464 return failure();
2465}
2466
2467namespace {
2468// Remove private nodes. If they have an interesting names, move the name to
2469// the source expression.
2470struct FoldNodeName : public mlir::OpRewritePattern<NodeOp> {
2471 using OpRewritePattern::OpRewritePattern;
2472 LogicalResult matchAndRewrite(NodeOp node,
2473 PatternRewriter &rewriter) const override {
2474 auto name = node.getNameAttr();
2475 if (!node.hasDroppableName() || node.getInnerSym() ||
2476 !AnnotationSet(node).empty() || node.isForceable())
2477 return failure();
2478 auto *newOp = node.getInput().getDefiningOp();
2479 if (newOp)
2480 updateName(rewriter, newOp, name);
2481 rewriter.replaceOp(node, node.getInput());
2482 return success();
2483 }
2484};
2485
2486// Bypass nodes.
2487struct NodeBypass : public mlir::OpRewritePattern<NodeOp> {
2488 using OpRewritePattern::OpRewritePattern;
2489 LogicalResult matchAndRewrite(NodeOp node,
2490 PatternRewriter &rewriter) const override {
2491 if (node.getInnerSym() || !AnnotationSet(node).empty() ||
2492 node.use_empty() || node.isForceable())
2493 return failure();
2494 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2495 return success();
2496 }
2497};
2498
2499} // namespace
2500
2501template <typename OpTy>
2502static LogicalResult demoteForceableIfUnused(OpTy op,
2503 PatternRewriter &rewriter) {
2504 if (!op.isForceable() || !op.getDataRef().use_empty())
2505 return failure();
2506
2507 firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
2508 return success();
2509}
2510
2511// Interesting names and symbols and don't touch force nodes to stick around.
2512LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2513 SmallVectorImpl<OpFoldResult> &results) {
2514 if (!hasDroppableName())
2515 return failure();
2516 if (hasDontTouch(getResult())) // handles inner symbols
2517 return failure();
2518 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
2519 return failure();
2520 if (isForceable())
2521 return failure();
2522 if (!adaptor.getInput())
2523 return failure();
2524
2525 results.push_back(adaptor.getInput());
2526 return success();
2527}
2528
2529void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2530 MLIRContext *context) {
2531 results.insert<FoldNodeName>(context);
2532 results.add(demoteForceableIfUnused<NodeOp>);
2533}
2534
2535namespace {
2536// For a lhs, find all the writers of fields of the aggregate type. If there
2537// is one writer for each field, merge the writes
2538struct AggOneShot : public mlir::RewritePattern {
2539 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2540 : RewritePattern(name, 0, context) {}
2541
2542 SmallVector<Value> getCompleteWrite(Operation *lhs) const {
2543 auto lhsTy = lhs->getResult(0).getType();
2544 if (!type_isa<BundleType, FVectorType>(lhsTy))
2545 return {};
2546
2547 DenseMap<uint32_t, Value> fields;
2548 for (Operation *user : lhs->getResult(0).getUsers()) {
2549 if (user->getParentOp() != lhs->getParentOp())
2550 return {};
2551 if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2552 if (aConnect.getDest() == lhs->getResult(0))
2553 return {};
2554 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2555 for (Operation *subuser : subField.getResult().getUsers()) {
2556 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2557 if (aConnect.getDest() == subField) {
2558 if (subuser->getParentOp() != lhs->getParentOp())
2559 return {};
2560 if (fields.count(subField.getFieldIndex())) // duplicate write
2561 return {};
2562 fields[subField.getFieldIndex()] = aConnect.getSrc();
2563 }
2564 continue;
2565 }
2566 return {};
2567 }
2568 } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2569 for (Operation *subuser : subIndex.getResult().getUsers()) {
2570 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2571 if (aConnect.getDest() == subIndex) {
2572 if (subuser->getParentOp() != lhs->getParentOp())
2573 return {};
2574 if (fields.count(subIndex.getIndex())) // duplicate write
2575 return {};
2576 fields[subIndex.getIndex()] = aConnect.getSrc();
2577 }
2578 continue;
2579 }
2580 return {};
2581 }
2582 } else {
2583 return {};
2584 }
2585 }
2586
2587 SmallVector<Value> values;
2588 uint32_t total = type_isa<BundleType>(lhsTy)
2589 ? type_cast<BundleType>(lhsTy).getNumElements()
2590 : type_cast<FVectorType>(lhsTy).getNumElements();
2591 for (uint32_t i = 0; i < total; ++i) {
2592 if (!fields.count(i))
2593 return {};
2594 values.push_back(fields[i]);
2595 }
2596 return values;
2597 }
2598
2599 LogicalResult matchAndRewrite(Operation *op,
2600 PatternRewriter &rewriter) const override {
2601 auto values = getCompleteWrite(op);
2602 if (values.empty())
2603 return failure();
2604 rewriter.setInsertionPointToEnd(op->getBlock());
2605 auto dest = op->getResult(0);
2606 auto destType = dest.getType();
2607
2608 // If not passive, cannot matchingconnect.
2609 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2610 return failure();
2611
2612 Value newVal = type_isa<BundleType>(destType)
2613 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2614 destType, values)
2615 : rewriter.createOrFold<VectorCreateOp>(
2616 op->getLoc(), destType, values);
2617 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2618 for (Operation *user : dest.getUsers()) {
2619 if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2620 for (Operation *subuser :
2621 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2622 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2623 if (aConnect.getDest() == subIndex)
2624 rewriter.eraseOp(aConnect);
2625 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2626 for (Operation *subuser :
2627 llvm::make_early_inc_range(subField.getResult().getUsers()))
2628 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2629 if (aConnect.getDest() == subField)
2630 rewriter.eraseOp(aConnect);
2631 }
2632 }
2633 return success();
2634 }
2635};
2636
2637struct WireAggOneShot : public AggOneShot {
2638 WireAggOneShot(MLIRContext *context)
2639 : AggOneShot(WireOp::getOperationName(), 0, context) {}
2640};
2641struct SubindexAggOneShot : public AggOneShot {
2642 SubindexAggOneShot(MLIRContext *context)
2643 : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
2644};
2645struct SubfieldAggOneShot : public AggOneShot {
2646 SubfieldAggOneShot(MLIRContext *context)
2647 : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
2648};
2649} // namespace
2650
2651void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2652 MLIRContext *context) {
2653 results.insert<WireAggOneShot>(context);
2654 results.add(demoteForceableIfUnused<WireOp>);
2655}
2656
2657void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2658 MLIRContext *context) {
2659 results.insert<SubindexAggOneShot>(context);
2660}
2661
2662OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2663 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2664 if (!attr)
2665 return {};
2666 return attr[getIndex()];
2667}
2668
2669OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2670 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2671 if (!attr)
2672 return {};
2673 auto index = getFieldIndex();
2674 return attr[index];
2675}
2676
2677void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2678 MLIRContext *context) {
2679 results.insert<SubfieldAggOneShot>(context);
2680}
2681
2682static Attribute collectFields(MLIRContext *context,
2683 ArrayRef<Attribute> operands) {
2684 for (auto operand : operands)
2685 if (!operand)
2686 return {};
2687 return ArrayAttr::get(context, operands);
2688}
2689
2690OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2691 // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
2692 // bundle<a:..., b:...>.
2693 if (getNumOperands() > 0)
2694 if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2695 if (first.getFieldIndex() == 0 &&
2696 first.getInput().getType() == getType() &&
2697 llvm::all_of(
2698 llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
2699 auto subindex =
2700 elem.value().template getDefiningOp<SubfieldOp>();
2701 return subindex && subindex.getInput() == first.getInput() &&
2702 subindex.getFieldIndex() == elem.index();
2703 }))
2704 return first.getInput();
2705
2706 return collectFields(getContext(), adaptor.getOperands());
2707}
2708
2709OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2710 // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
2711 // vector<..., 2>.
2712 if (getNumOperands() > 0)
2713 if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2714 if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2715 llvm::all_of(
2716 llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
2717 auto subindex =
2718 elem.value().template getDefiningOp<SubindexOp>();
2719 return subindex && subindex.getInput() == first.getInput() &&
2720 subindex.getIndex() == elem.index();
2721 }))
2722 return first.getInput();
2723
2724 return collectFields(getContext(), adaptor.getOperands());
2725}
2726
2727OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2728 if (getOperand().getType() == getType())
2729 return getOperand();
2730 return {};
2731}
2732
2733namespace {
2734// A register with constant reset and all connection to either itself or the
2735// same constant, must be replaced by the constant.
2736struct FoldResetMux : public mlir::OpRewritePattern<RegResetOp> {
2737 using OpRewritePattern::OpRewritePattern;
2738 LogicalResult matchAndRewrite(RegResetOp reg,
2739 PatternRewriter &rewriter) const override {
2740 auto reset =
2741 dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
2742 if (!reset || hasDontTouch(reg.getOperation()) ||
2743 !AnnotationSet(reg).empty() || reg.isForceable())
2744 return failure();
2745 // Find the one true connect, or bail
2746 auto con = getSingleConnectUserOf(reg.getResult());
2747 if (!con)
2748 return failure();
2749
2750 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
2751 if (!mux)
2752 return failure();
2753 auto *high = mux.getHigh().getDefiningOp();
2754 auto *low = mux.getLow().getDefiningOp();
2755 auto constOp = dyn_cast_or_null<ConstantOp>(high);
2756
2757 if (constOp && low != reg)
2758 return failure();
2759 if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
2760 constOp = dyn_cast<ConstantOp>(low);
2761
2762 if (!constOp || constOp.getType() != reset.getType() ||
2763 constOp.getValue() != reset.getValue())
2764 return failure();
2765
2766 // Check all types should be typed by now
2767 auto regTy = reg.getResult().getType();
2768 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
2769 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
2770 regTy.getBitWidthOrSentinel() < 0)
2771 return failure();
2772
2773 // Ok, we know we are doing the transformation.
2774
2775 // Make sure the constant dominates all users.
2776 if (constOp != &con->getBlock()->front())
2777 constOp->moveBefore(&con->getBlock()->front());
2778
2779 // Replace the register with the constant.
2780 replaceOpAndCopyName(rewriter, reg, constOp.getResult());
2781 // Remove the connect.
2782 rewriter.eraseOp(con);
2783 return success();
2784 }
2785};
2786} // namespace
2787
2788static bool isDefinedByOneConstantOp(Value v) {
2789 if (auto c = v.getDefiningOp<ConstantOp>())
2790 return c.getValue().isOne();
2791 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2792 return sc.getValue();
2793 return false;
2794}
2795
2796static LogicalResult
2797canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
2798 if (!isDefinedByOneConstantOp(reg.getResetSignal()))
2799 return failure();
2800
2801 auto resetValue = reg.getResetValue();
2802 if (reg.getType(0) != resetValue.getType())
2803 return failure();
2804
2805 // Ignore 'passthrough'.
2806 (void)dropWrite(rewriter, reg->getResult(0), {});
2807 replaceOpWithNewOpAndCopyName<NodeOp>(
2808 rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
2809 reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2810 return success();
2811}
2812
2813void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2814 MLIRContext *context) {
2815 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2817 results.add(demoteForceableIfUnused<RegResetOp>);
2818}
2819
2820// Returns the value connected to a port, if there is only one.
2821static Value getPortFieldValue(Value port, StringRef name) {
2822 auto portTy = type_cast<BundleType>(port.getType());
2823 auto fieldIndex = portTy.getElementIndex(name);
2824 assert(fieldIndex && "missing field on memory port");
2825
2826 Value value = {};
2827 for (auto *op : port.getUsers()) {
2828 auto portAccess = cast<SubfieldOp>(op);
2829 if (fieldIndex != portAccess.getFieldIndex())
2830 continue;
2831 auto conn = getSingleConnectUserOf(portAccess);
2832 if (!conn || value)
2833 return {};
2834 value = conn.getSrc();
2835 }
2836 return value;
2837}
2838
2839// Returns true if the enable field of a port is set to false.
2840static bool isPortDisabled(Value port) {
2841 auto value = getPortFieldValue(port, "en");
2842 if (!value)
2843 return false;
2844 auto portConst = value.getDefiningOp<ConstantOp>();
2845 if (!portConst)
2846 return false;
2847 return portConst.getValue().isZero();
2848}
2849
2850// Returns true if the data output is unused.
2851static bool isPortUnused(Value port, StringRef data) {
2852 auto portTy = type_cast<BundleType>(port.getType());
2853 auto fieldIndex = portTy.getElementIndex(data);
2854 assert(fieldIndex && "missing enable flag on memory port");
2855
2856 for (auto *op : port.getUsers()) {
2857 auto portAccess = cast<SubfieldOp>(op);
2858 if (fieldIndex != portAccess.getFieldIndex())
2859 continue;
2860 if (!portAccess.use_empty())
2861 return false;
2862 }
2863
2864 return true;
2865}
2866
2867// Returns the value connected to a port, if there is only one.
2868static void replacePortField(PatternRewriter &rewriter, Value port,
2869 StringRef name, Value value) {
2870 auto portTy = type_cast<BundleType>(port.getType());
2871 auto fieldIndex = portTy.getElementIndex(name);
2872 assert(fieldIndex && "missing field on memory port");
2873
2874 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2875 auto portAccess = cast<SubfieldOp>(op);
2876 if (fieldIndex != portAccess.getFieldIndex())
2877 continue;
2878 rewriter.replaceAllUsesWith(portAccess, value);
2879 rewriter.eraseOp(portAccess);
2880 }
2881}
2882
2883// Remove accesses to a port which is used.
2884static void erasePort(PatternRewriter &rewriter, Value port) {
2885 // Helper to create a dummy 0 clock for the dummy registers.
2886 Value clock;
2887 auto getClock = [&] {
2888 if (!clock)
2889 clock = SpecialConstantOp::create(rewriter, port.getLoc(),
2890 ClockType::get(rewriter.getContext()),
2891 false);
2892 return clock;
2893 };
2894
2895 // Find the clock field of the port and determine whether the port is
2896 // accessed only through its subfields or as a whole wire. If the port
2897 // is used in its entirety, replace it with a wire. Otherwise,
2898 // eliminate individual subfields and replace with reasonable defaults.
2899 for (auto *op : port.getUsers()) {
2900 auto subfield = dyn_cast<SubfieldOp>(op);
2901 if (!subfield) {
2902 auto ty = port.getType();
2903 auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
2904 rewriter.replaceAllUsesWith(port, reg.getResult());
2905 return;
2906 }
2907 }
2908
2909 // Remove all connects to field accesses as they are no longer relevant.
2910 // If field values are used anywhere, which should happen solely for read
2911 // ports, a dummy register is introduced which replicates the behaviour of
2912 // memory that is never written, but might be read.
2913 for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2914 auto access = cast<SubfieldOp>(accessOp);
2915 for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
2916 auto connect = dyn_cast<FConnectLike>(user);
2917 if (connect && connect.getDest() == access) {
2918 rewriter.eraseOp(user);
2919 continue;
2920 }
2921 }
2922 if (access.use_empty()) {
2923 rewriter.eraseOp(access);
2924 continue;
2925 }
2926
2927 // Replace read values with a register that is never written, handing off
2928 // the canonicalization of such a register to another canonicalizer.
2929 auto ty = access.getType();
2930 auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
2931 rewriter.replaceOp(access, reg.getResult());
2932 }
2933 assert(port.use_empty() && "port should have no remaining uses");
2934}
2935
2936namespace {
2937// If memory has known, but zero width, eliminate it.
2938struct FoldZeroWidthMemory : public mlir::OpRewritePattern<MemOp> {
2939 using OpRewritePattern::OpRewritePattern;
2940 LogicalResult matchAndRewrite(MemOp mem,
2941 PatternRewriter &rewriter) const override {
2942 if (hasDontTouch(mem))
2943 return failure();
2944
2945 if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
2946 mem.getDataType().getBitWidthOrSentinel() != 0)
2947 return failure();
2948
2949 // Make sure are users are safe to replace
2950 for (auto port : mem.getResults())
2951 for (auto *user : port.getUsers())
2952 if (!isa<SubfieldOp>(user))
2953 return failure();
2954
2955 // Annoyingly, there isn't a good replacement for the port as a whole,
2956 // since they have an outer flip type.
2957 for (auto port : mem.getResults()) {
2958 for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
2959 SubfieldOp sfop = cast<SubfieldOp>(user);
2960 StringRef fieldName = sfop.getFieldName();
2961 auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
2962 rewriter, sfop, sfop.getResult().getType())
2963 .getResult();
2964 if (fieldName.ends_with("data")) {
2965 // Make sure to write data ports.
2966 auto zero = firrtl::ConstantOp::create(
2967 rewriter, wire.getLoc(),
2968 firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
2969 MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
2970 }
2971 }
2972 }
2973 rewriter.eraseOp(mem);
2974 return success();
2975 }
2976};
2977
2978// If memory has no write ports and no file initialization, eliminate it.
2979struct FoldReadOrWriteOnlyMemory : public mlir::OpRewritePattern<MemOp> {
2980 using OpRewritePattern::OpRewritePattern;
2981 LogicalResult matchAndRewrite(MemOp mem,
2982 PatternRewriter &rewriter) const override {
2983 if (hasDontTouch(mem))
2984 return failure();
2985 bool isRead = false, isWritten = false;
2986 for (unsigned i = 0; i < mem.getNumResults(); ++i) {
2987 switch (mem.getPortKind(i)) {
2988 case MemOp::PortKind::Read:
2989 isRead = true;
2990 if (isWritten)
2991 return failure();
2992 continue;
2993 case MemOp::PortKind::Write:
2994 isWritten = true;
2995 if (isRead)
2996 return failure();
2997 continue;
2998 case MemOp::PortKind::Debug:
2999 case MemOp::PortKind::ReadWrite:
3000 return failure();
3001 }
3002 llvm_unreachable("unknown port kind");
3003 }
3004 assert((!isWritten || !isRead) && "memory is in use");
3005
3006 // If the memory is read only, but has a file initialization, then we can't
3007 // remove it. A write only memory with file initialization is okay to
3008 // remove.
3009 if (isRead && mem.getInit())
3010 return failure();
3011
3012 for (auto port : mem.getResults())
3013 erasePort(rewriter, port);
3014
3015 rewriter.eraseOp(mem);
3016 return success();
3017 }
3018};
3019
3020// Eliminate the dead ports of memories.
3021struct FoldUnusedPorts : public mlir::OpRewritePattern<MemOp> {
3022 using OpRewritePattern::OpRewritePattern;
3023 LogicalResult matchAndRewrite(MemOp mem,
3024 PatternRewriter &rewriter) const override {
3025 if (hasDontTouch(mem))
3026 return failure();
3027 // Identify the dead and changed ports.
3028 llvm::SmallBitVector deadPorts(mem.getNumResults());
3029 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3030 // Do not simplify annotated ports.
3031 if (!mem.getPortAnnotation(i).empty())
3032 continue;
3033
3034 // Skip debug ports.
3035 auto kind = mem.getPortKind(i);
3036 if (kind == MemOp::PortKind::Debug)
3037 continue;
3038
3039 // If a port is disabled, always eliminate it.
3040 if (isPortDisabled(port)) {
3041 deadPorts.set(i);
3042 continue;
3043 }
3044 // Eliminate read ports whose outputs are not used.
3045 if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
3046 deadPorts.set(i);
3047 continue;
3048 }
3049 }
3050 if (deadPorts.none())
3051 return failure();
3052
3053 // Rebuild the new memory with the altered ports.
3054 SmallVector<Type> resultTypes;
3055 SmallVector<StringRef> portNames;
3056 SmallVector<Attribute> portAnnotations;
3057 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3058 if (deadPorts[i])
3059 continue;
3060 resultTypes.push_back(port.getType());
3061 portNames.push_back(mem.getPortName(i));
3062 portAnnotations.push_back(mem.getPortAnnotation(i));
3063 }
3064
3065 MemOp newOp;
3066 if (!resultTypes.empty())
3067 newOp = MemOp::create(
3068 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3069 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3070 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3071 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3072 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3073
3074 // Replace the dead ports with dummy wires.
3075 unsigned nextPort = 0;
3076 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3077 if (deadPorts[i])
3078 erasePort(rewriter, port);
3079 else
3080 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
3081 }
3082
3083 rewriter.eraseOp(mem);
3084 return success();
3085 }
3086};
3087
3088// Rewrite write-only read-write ports to write ports.
3089struct FoldReadWritePorts : public mlir::OpRewritePattern<MemOp> {
3090 using OpRewritePattern::OpRewritePattern;
3091 LogicalResult matchAndRewrite(MemOp mem,
3092 PatternRewriter &rewriter) const override {
3093 if (hasDontTouch(mem))
3094 return failure();
3095
3096 // Identify read-write ports whose read end is unused.
3097 llvm::SmallBitVector deadReads(mem.getNumResults());
3098 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3099 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
3100 continue;
3101 if (!mem.getPortAnnotation(i).empty())
3102 continue;
3103 if (isPortUnused(port, "rdata")) {
3104 deadReads.set(i);
3105 continue;
3106 }
3107 }
3108 if (deadReads.none())
3109 return failure();
3110
3111 SmallVector<Type> resultTypes;
3112 SmallVector<StringRef> portNames;
3113 SmallVector<Attribute> portAnnotations;
3114 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3115 if (deadReads[i])
3116 resultTypes.push_back(
3117 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
3118 MemOp::PortKind::Write, mem.getMaskBits()));
3119 else
3120 resultTypes.push_back(port.getType());
3121
3122 portNames.push_back(mem.getPortName(i));
3123 portAnnotations.push_back(mem.getPortAnnotation(i));
3124 }
3125
3126 auto newOp = MemOp::create(
3127 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
3128 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
3129 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
3130 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
3131 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3132
3133 for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
3134 auto result = mem.getResult(i);
3135 auto newResult = newOp.getResult(i);
3136 if (deadReads[i]) {
3137 auto resultPortTy = type_cast<BundleType>(result.getType());
3138
3139 // Rewrite accesses to the old port field to accesses to a
3140 // corresponding field of the new port.
3141 auto replace = [&](StringRef toName, StringRef fromName) {
3142 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
3143 assert(fromFieldIndex && "missing enable flag on memory port");
3144
3145 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
3146 newResult, toName);
3147 for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
3148 auto fromField = cast<SubfieldOp>(op);
3149 if (fromFieldIndex != fromField.getFieldIndex())
3150 continue;
3151 rewriter.replaceOp(fromField, toField.getResult());
3152 }
3153 };
3154
3155 replace("addr", "addr");
3156 replace("en", "en");
3157 replace("clk", "clk");
3158 replace("data", "wdata");
3159 replace("mask", "wmask");
3160
3161 // Remove the wmode field, replacing it with dummy wires.
3162 auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
3163 for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
3164 auto wmodeField = cast<SubfieldOp>(op);
3165 if (wmodeFieldIndex != wmodeField.getFieldIndex())
3166 continue;
3167 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
3168 }
3169 } else {
3170 rewriter.replaceAllUsesWith(result, newResult);
3171 }
3172 }
3173 rewriter.eraseOp(mem);
3174 return success();
3175 }
3176};
3177
3178// Eliminate the unused data bits of memories. If the read data of a memory
// is only ever consumed through bit selects that together do not cover the
// full data width, the memory is rebuilt with a narrower data type holding
// just the selected bits. Reads are re-indexed into the compacted layout and
// writes are reduced to the slices that are kept.
3179struct FoldUnusedBits : public mlir::OpRewritePattern<MemOp> {
3180 using OpRewritePattern::OpRewritePattern;
3181
3182 LogicalResult matchAndRewrite(MemOp mem,
3183 PatternRewriter &rewriter) const override {
3184 if (hasDontTouch(mem))
3185 return failure();
3186
3187 // Only apply the transformation if the memory is neither masked nor
// sequential.
3188 const auto &summary = mem.getSummary();
3189 if (summary.isMasked || summary.isSeqMem())
3190 return failure();
3191
// The data type has to be an integer of known, non-zero width.
3192 auto type = type_dyn_cast<IntType>(mem.getDataType());
3193 if (!type)
3194 return failure();
3195 auto width = type.getBitWidthOrSentinel();
3196 if (width <= 0)
3197 return failure();
3198
// `usedBits` has one flag per original data bit. `mapping` is keyed by the
// low bit index of each collected bit select; its values are filled in with
// the bit's position in the compacted data type further below.
3199 llvm::SmallBitVector usedBits(width);
3200 DenseMap<unsigned, unsigned> mapping;
3201
3202 // Find which bits are used out of the users of a read port. This detects
3203 // ports whose data/rdata field is used only through bit select ops. The
3204 // bit selects are then used to build a bit-mask. The ops are collected.
3205 SmallVector<BitsPrimOp> readOps;
3206 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
3207 auto portTy = type_cast<BundleType>(port.getType());
3208 auto fieldIndex = portTy.getElementIndex(field);
3209 assert(fieldIndex && "missing data port");
3210
3211 for (auto *op : port.getUsers()) {
3212 auto portAccess = cast<SubfieldOp>(op);
3213 if (fieldIndex != portAccess.getFieldIndex())
3214 continue;
3215
3216 for (auto *user : op->getUsers()) {
3217 auto bits = dyn_cast<BitsPrimOp>(user);
3218 if (!bits)
3219 return failure();
3220
// Once every bit is marked as used there is nothing left to remove.
3221 usedBits.set(bits.getLo(), bits.getHi() + 1);
3222 if (usedBits.all())
3223 return failure();
3224
// Placeholder entry; the compacted position is computed later.
3225 mapping[bits.getLo()] = 0;
3226 readOps.push_back(bits);
3227 }
3228 }
3229
3230 return success();
3231 };
3232
3233 // Finds the users of write ports. This expects all the data/wdata fields
3234 // of the ports to be used solely as the destination of matching connects.
3235 // If a memory has ports with other uses, it is excluded from optimisation.
3236 SmallVector<MatchingConnectOp> writeOps;
3237 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
3238 auto portTy = type_cast<BundleType>(port.getType());
3239 auto fieldIndex = portTy.getElementIndex(field);
3240 assert(fieldIndex && "missing data port");
3241
3242 for (auto *op : port.getUsers()) {
3243 auto portAccess = cast<SubfieldOp>(op);
3244 if (fieldIndex != portAccess.getFieldIndex())
3245 continue;
3246
3247 auto conn = getSingleConnectUserOf(portAccess);
3248 if (!conn)
3249 return failure();
3250
3251 writeOps.push_back(conn);
3252 }
3253 return success();
3254 };
3255
3256 // Traverse all ports and find the read and used data fields.
// NOTE(review): both helpers above use cast<SubfieldOp> on the port users,
// so they presumably rely on ports only being accessed through subfields —
// confirm.
3257 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3258 // Do not simplify annotated ports.
3259 if (!mem.getPortAnnotation(i).empty())
3260 return failure();
3261
3262 switch (mem.getPortKind(i)) {
3263 case MemOp::PortKind::Debug:
3264 // Memories with debug ports are not optimized.
3265 return failure();
3266 case MemOp::PortKind::Write:
3267 if (failed(findWriteUsers(port, "data")))
3268 return failure();
3269 continue;
3270 case MemOp::PortKind::Read:
3271 if (failed(findReadUsers(port, "data")))
3272 return failure();
3273 continue;
3274 case MemOp::PortKind::ReadWrite:
3275 if (failed(findWriteUsers(port, "wdata")))
3276 return failure();
3277 if (failed(findReadUsers(port, "rdata")))
3278 return failure();
3279 continue;
3280 }
3281 llvm_unreachable("unknown port kind");
3282 }
3283
3284 // Unused memories are handled in a different canonicalizer.
3285 if (usedBits.none())
3286 return failure();
3287
3288 // Build a mapping of existing indices to compacted ones.
// Each iteration of the outer loop handles one maximal run [i, e) of used
// bits: the run is recorded in `ranges` as an inclusive (low, high) pair and
// `newWidth` advances by the length of the run.
3289 SmallVector<std::pair<unsigned, unsigned>> ranges;
3290 unsigned newWidth = 0;
3291 for (int i = usedBits.find_first(); 0 <= i && i < width;) {
3292 int e = usedBits.find_next_unset(i);
3293 if (e < 0)
3294 e = width;
3295 for (int idx = i; idx < e; ++idx, ++newWidth) {
3296 if (auto it = mapping.find(idx); it != mapping.end()) {
3297 it->second = newWidth;
3298 }
3299 }
3300 ranges.emplace_back(i, e - 1);
3301 i = e != width ? usedBits.find_next(e) : e;
3302 }
3303
3304 // Create the new op with the new port types.
3305 auto newType = IntType::get(mem->getContext(), type.isSigned(), newWidth);
3306 SmallVector<Type> portTypes;
3307 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3308 portTypes.push_back(
3309 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
3310 }
3311 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
3312 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
3313 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
3314 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
3315 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
3316
3317 // Rewrite bundle users to the new data type.
// A fresh subfield op is created on the new port right after the new memory
// and replaces every other access to the same field.
3318 auto rewriteSubfield = [&](Value port, StringRef field) {
3319 auto portTy = type_cast<BundleType>(port.getType());
3320 auto fieldIndex = portTy.getElementIndex(field);
3321 assert(fieldIndex && "missing data port");
3322
3323 rewriter.setInsertionPointAfter(newMem);
3324 auto newPortAccess =
3325 SubfieldOp::create(rewriter, port.getLoc(), port, field);
3326
3327 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
3328 auto portAccess = cast<SubfieldOp>(op);
3329 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
3330 continue;
3331 rewriter.replaceOp(portAccess, newPortAccess.getResult());
3332 }
3333 };
3334
3335 // Rewrite the field accesses.
3336 for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
3337 switch (newMem.getPortKind(i)) {
3338 case MemOp::PortKind::Debug:
3339 llvm_unreachable("cannot rewrite debug port");
3340 case MemOp::PortKind::Write:
3341 rewriteSubfield(port, "data");
3342 continue;
3343 case MemOp::PortKind::Read:
3344 rewriteSubfield(port, "data");
3345 continue;
3346 case MemOp::PortKind::ReadWrite:
3347 rewriteSubfield(port, "rdata");
3348 rewriteSubfield(port, "wdata");
3349 continue;
3350 }
3351 llvm_unreachable("unknown port kind");
3352 }
3353
3354 // Rewrite the reads to the new ranges, compacting them.
3355 for (auto readOp : readOps) {
3356 rewriter.setInsertionPointAfter(readOp);
3357 auto it = mapping.find(readOp.getLo());
3358 assert(it != mapping.end() && "bit op mapping not found");
3359 // Create a new bit selection from the compressed memory. The new op may
3360 // be folded if we are selecting the entire compressed memory.
// The selection keeps its length and starts at the compacted low position.
3361 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
3362 readOp.getLoc(), readOp.getInput(),
3363 readOp.getHi() - readOp.getLo() + it->second, it->second);
3364 rewriter.replaceAllUsesWith(readOp, newReadValue);
3365 rewriter.eraseOp(readOp);
3366 }
3367
3368 // Rewrite the writes into a concatenation of slices.
3369 for (auto writeOp : writeOps) {
3370 Value source = writeOp.getSrc();
3371 rewriter.setInsertionPoint(writeOp);
3372
// The ranges are visited from the highest to the lowest, presumably because
// the first operand of the concatenation ends up in the most significant
// bits — confirm against CatPrimOp.
3373 SmallVector<Value> slices;
3374 for (auto &[start, end] : llvm::reverse(ranges)) {
3375 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
3376 source, end, start);
3377 slices.push_back(slice);
3378 }
3379
3380 Value catOfSlices =
3381 rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);
3382
3383 // If the original memory held a signed integer, then the compressed
3384 // memory will be signed too. Since the catOfSlices is always unsigned,
3385 // cast the data to a signed integer if needed before connecting back to
3386 // the memory.
3387 if (type.isSigned())
3388 catOfSlices =
3389 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
3390
3391 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
3392 catOfSlices);
3393 }
3394
3395 return success();
3396 }
3397};
3398
3399// Rewrite single-address memories to a firrtl register. Reads are served
// from a wire connected to the register; the data, enable and mask of every
// write are pipelined through additional registers and then muxed into the
// register's next value.
3400struct FoldRegMems : public mlir::OpRewritePattern<MemOp> {
3401 using OpRewritePattern::OpRewritePattern;
3402 LogicalResult matchAndRewrite(MemOp mem,
3403 PatternRewriter &rewriter) const override {
// Only memories with a depth of one are rewritten.
3404 const FirMemory &info = mem.getSummary();
3405 if (hasDontTouch(mem) || info.depth != 1)
3406 return failure();
3407
3408 auto ty = mem.getDataType();
3409 auto loc = mem.getLoc();
3410 auto *block = mem->getBlock();
3411
3412 // Find the clock of the register-to-be, all write ports should share it.
// Along the way, `portAccesses` collects the subfield ops of the fields
// listed per port kind below and `connects` collects their users, so that
// both can be erased once the rewrite is done.
3413 Value clock;
3414 SmallPtrSet<Operation *, 8> connects;
3415 SmallVector<SubfieldOp> portAccesses;
3416 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
// NOTE(review): annotated ports are skipped here, but the rewrite loop
// further below still visits every port — confirm this is intended.
3417 if (!mem.getPortAnnotation(i).empty())
3418 continue;
3419
// Collects the accesses to the given fields of this port and their users.
// Fails if any such user is not connect-like.
3420 auto collect = [&, port = port](ArrayRef<StringRef> fields) {
3421 auto portTy = type_cast<BundleType>(port.getType());
3422 for (auto field : fields) {
3423 auto fieldIndex = portTy.getElementIndex(field);
3424 assert(fieldIndex && "missing field on memory port");
3425
3426 for (auto *op : port.getUsers()) {
3427 auto portAccess = cast<SubfieldOp>(op);
3428 if (fieldIndex != portAccess.getFieldIndex())
3429 continue;
3430 portAccesses.push_back(portAccess);
3431 for (auto *user : portAccess->getUsers()) {
3432 auto conn = dyn_cast<FConnectLike>(user);
3433 if (!conn)
3434 return failure();
3435 connects.insert(conn);
3436 }
3437 }
3438 }
3439 return success();
3440 };
3441
// Read ports `continue` and therefore skip the clock check after the switch;
// write and read-write ports `break` and fall through to it.
3442 switch (mem.getPortKind(i)) {
3443 case MemOp::PortKind::Debug:
3444 return failure();
3445 case MemOp::PortKind::Read:
3446 if (failed(collect({"clk", "en", "addr"})))
3447 return failure();
3448 continue;
3449 case MemOp::PortKind::Write:
3450 if (failed(collect({"clk", "en", "addr", "data", "mask"})))
3451 return failure();
3452 break;
3453 case MemOp::PortKind::ReadWrite:
3454 if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
3455 return failure();
3456 break;
3457 }
3458
3459 Value portClock = getPortFieldValue(port, "clk");
3460 if (!portClock || (clock && portClock != clock))
3461 return failure();
3462 clock = portClock;
3463 }
// NOTE(review): `clock` is still null at this point if no unannotated write
// or read-write port was visited — confirm such memories cannot get here.
3464 // Create a new wire where the memory used to be. This wire will dominate
3465 // all readers of the memory. Reads should be made through this wire.
3466 rewriter.setInsertionPointAfter(mem);
3467 auto memWire = WireOp::create(rewriter, loc, ty).getResult();
3468
3469 // The memory is replaced by a register, which we place at the end of the
3470 // block, so that any value driven to the original memory will dominate the
3471 // new register (including the clock). All other ops will be placed
3472 // after the register.
3473 rewriter.setInsertionPointToEnd(block);
3474 auto memReg =
3475 RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();
3476
3477 // Connect the output of the register to the wire.
3478 MatchingConnectOp::create(rewriter, loc, memWire, memReg);
3479
3480 // Helper to insert a given number of pipeline stages through registers.
3481 // The pipelines are placed at the end of the block.
// Stage i is a register named "<memory name>_<name>_<i>"; the output of the
// last stage is returned, or `value` itself for a latency of zero.
3482 auto pipeline = [&](Value value, Value clock, const Twine &name,
3483 unsigned latency) {
3484 for (unsigned i = 0; i < latency; ++i) {
3485 std::string regName;
3486 {
3487 llvm::raw_string_ostream os(regName);
3488 os << mem.getName() << "_" << name << "_" << i;
3489 }
3490 auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
3491 rewriter.getStringAttr(regName))
3492 .getResult();
3493 MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
3494 value = reg;
3495 }
3496 return value;
3497 };
3498
// NOTE(review): this assumes a write latency of at least one; a latency of
// zero would wrap around in the unsigned subtraction — confirm it is
// rejected elsewhere.
3499 const unsigned writeStages = info.writeLatency - 1;
3500
3501 // Traverse each port. Replace reads with the pipelined register, discarding
3502 // the enable flag and reading unconditionally. Pipeline the mask, enable
3503 // and data bits of all write ports to be arbitrated and wired to the reg.
// Each entry of `writes` is a (data, enable, mask) triple.
3504 SmallVector<std::tuple<Value, Value, Value>> writes;
3505 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
3506 Value portClock = getPortFieldValue(port, "clk");
3507 StringRef name = mem.getPortName(i);
3508
// Pipelines the single value connected to `field` of this port.
3509 auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
3510 Value value = getPortFieldValue(port, field);
3511 assert(value);
3512 return pipeline(value, portClock, name + "_" + field, stages);
3513 };
3514
3515 switch (mem.getPortKind(i)) {
3516 case MemOp::PortKind::Debug:
3517 llvm_unreachable("unknown port kind");
3518 case MemOp::PortKind::Read: {
3519 // Read ports pipeline the addr and enable signals. However, the
3520 // address must be 0 for single-address memories and the enable signal
3521 // is ignored, always reading out the register. Under these constraints,
3522 // the read port can be replaced with the value from the register.
3523 replacePortField(rewriter, port, "data", memWire);
3524 break;
3525 }
3526 case MemOp::PortKind::Write: {
3527 auto data = portPipeline("data", writeStages);
3528 auto en = portPipeline("en", writeStages);
3529 auto mask = portPipeline("mask", writeStages);
3530 writes.emplace_back(data, en, mask);
3531 break;
3532 }
3533 case MemOp::PortKind::ReadWrite: {
3534 // Always read the register into the read end.
3535 replacePortField(rewriter, port, "rdata", memWire);
3536
3537 // Create a write enable and pipeline stages.
3538 auto wdata = portPipeline("wdata", writeStages);
3539 auto wmask = portPipeline("wmask", writeStages);
3540
3541 Value en = getPortFieldValue(port, "en");
3542 Value wmode = getPortFieldValue(port, "wmode");
3543
// The port only writes when it is enabled and in write mode.
3544 auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
3545 auto wenPipelined =
3546 pipeline(wen, portClock, name + "_wen", writeStages);
3547 writes.emplace_back(wdata, wenPipelined, wmask);
3548 break;
3549 }
3550 }
3551 }
3552
3553 // Regardless of `writeUnderWrite`, always implement PortOrder.
// The writes are chained in port order: each write muxes its masked data
// over the value produced by the previous ones, so a later port overrides
// an earlier one when both are enabled.
3554 Value next = memReg;
3555 for (auto &[data, en, mask] : writes) {
3556 Value masked;
3557
3558 // If a mask bit is used, emit muxes to select the input from the
3559 // register (no mask) or the input (mask bit set).
3560 Location loc = mem.getLoc();
// NOTE(review): assumes info.maskBits is non-zero and presumably divides
// info.dataWidth evenly — confirm.
3561 unsigned maskGran = info.dataWidth / info.maskBits;
3562 SmallVector<Value> chunks;
3563 for (unsigned i = 0; i < info.maskBits; ++i) {
3564 unsigned hi = (i + 1) * maskGran - 1;
3565 unsigned lo = i * maskGran;
3566
3567 auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
3568 auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
3569 auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
3570 auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
3571 chunks.push_back(chunk);
3572 }
3573
// The chunks were produced from the lowest bits upwards; reverse them before
// concatenating.
3574 std::reverse(chunks.begin(), chunks.end());
3575 masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
3576 next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
3577 }
// Cast the computed value back to the data type of the memory before
// driving the register with it.
3578 Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
3579 MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);
3580
3581 // Delete the fields and their associated connects.
3582 for (Operation *conn : connects)
3583 rewriter.eraseOp(conn);
3584 for (auto portAccess : portAccesses)
3585 rewriter.eraseOp(portAccess);
3586 rewriter.eraseOp(mem);
3587
3588 return success();
3589 }
3590};
3591} // namespace
3592
3593void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3594 MLIRContext *context) {
3595 results
3596 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3597 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3598 context);
3599}
3600
3601//===----------------------------------------------------------------------===//
3602// Declarations
3603//===----------------------------------------------------------------------===//
3604
3605// Turn synchronous reset looking register updates to registers with resets.
3606// Also, const prop registers that are driven by a mux tree containing only
3607// instances of one constant or self-assigns.
3608static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
3609 // reg ; connect(reg, mux(port, const, val)) ->
3610 // reg.reset(port, const); connect(reg, val)
3611
3612 // Find the one true connect, or bail
3613 auto con = getSingleConnectUserOf(reg.getResult());
3614 if (!con)
3615 return failure();
3616
3617 auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
3618 if (!mux)
3619 return failure();
3620 auto *high = mux.getHigh().getDefiningOp();
3621 auto *low = mux.getLow().getDefiningOp();
3622 // Reset value must be constant
3623 auto constOp = dyn_cast_or_null<ConstantOp>(high);
3624
3625 // Detect the case if a register only has two possible drivers:
3626 // (1) itself/uninit and (2) constant.
3627 // The mux can then be replaced with the constant.
3628 // r = mux(cond, r, 3) --> r = 3
3629 // r = mux(cond, 3, r) --> r = 3
3630 bool constReg = false;
3631
3632 if (constOp && low == reg)
3633 constReg = true;
3634 else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
3635 constReg = true;
3636 constOp = dyn_cast<ConstantOp>(low);
3637 }
3638 if (!constOp)
3639 return failure();
3640
3641 // For a non-constant register, reset should be a module port (heuristic to
3642 // limit to intended reset lines). Replace the register anyway if constant.
3643 if (!isa<BlockArgument>(mux.getSel()) && !constReg)
3644 return failure();
3645
3646 // Check all types should be typed by now
3647 auto regTy = reg.getResult().getType();
3648 if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
3649 mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
3650 regTy.getBitWidthOrSentinel() < 0)
3651 return failure();
3652
3653 // Ok, we know we are doing the transformation.
3654
3655 // Make sure the constant dominates all users.
3656 if (constOp != &con->getBlock()->front())
3657 constOp->moveBefore(&con->getBlock()->front());
3658
3659 if (!constReg) {
3660 SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
3661 auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
3662 rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
3663 mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
3664 reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
3665 reg.getForceableAttr());
3666 newReg->setDialectAttrs(attrs);
3667 }
3668 auto pt = rewriter.saveInsertionPoint();
3669 rewriter.setInsertionPoint(con);
3670 auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
3671 replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
3672 rewriter.restoreInsertionPoint(pt);
3673 return success();
3674}
3675
3676LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3677 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3678 succeeded(foldHiddenReset(op, rewriter)))
3679 return success();
3680
3681 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3682 return success();
3683
3684 return failure();
3685}
3686
3687//===----------------------------------------------------------------------===//
3688// Verification Ops.
3689//===----------------------------------------------------------------------===//
3690
3691static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3692 Value enable,
3693 PatternRewriter &rewriter,
3694 bool eraseIfZero) {
3695 // If the verification op is never enabled, delete it.
3696 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3697 if (constant.getValue().isZero()) {
3698 rewriter.eraseOp(op);
3699 return success();
3700 }
3701 }
3702
3703 // If the verification op is never triggered, delete it.
3704 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3705 if (constant.getValue().isZero() == eraseIfZero) {
3706 rewriter.eraseOp(op);
3707 return success();
3708 }
3709 }
3710
3711 return failure();
3712}
3713
3714template <class Op, bool EraseIfZero = false>
3715static LogicalResult canonicalizeImmediateVerifOp(Op op,
3716 PatternRewriter &rewriter) {
3717 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3718 EraseIfZero);
3719}
3720
3721void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3722 MLIRContext *context) {
3723 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3724 results.add<patterns::AssertXWhenX>(context);
3725}
3726
3727void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3728 MLIRContext *context) {
3729 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3730 results.add<patterns::AssumeXWhenX>(context);
3731}
3732
3733void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3734 RewritePatternSet &results, MLIRContext *context) {
3735 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3736 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
3737}
3738
3739void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3740 MLIRContext *context) {
3741 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3742}
3743
3744//===----------------------------------------------------------------------===//
3745// InvalidValueOp
3746//===----------------------------------------------------------------------===//
3747
3748LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3749 PatternRewriter &rewriter) {
3750 // Remove `InvalidValueOp`s with no uses.
3751 if (op.use_empty()) {
3752 rewriter.eraseOp(op);
3753 return success();
3754 }
3755 // Propagate invalids through a single use which is a unary op. You cannot
3756 // propagate through multiple uses as that breaks invalid semantics. Nor
3757 // can you propagate through binary ops or generally any op which computes.
3758 // Not is an exception as it is a pure, all-bits inverse.
3759 if (op->hasOneUse() &&
3760 (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
3761 SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
3762 *op->user_begin()) ||
3763 (isa<CvtPrimOp>(*op->user_begin()) &&
3764 type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
3765 (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
3766 type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
3767 .getBitWidthOrSentinel() > 0))) {
3768 auto *modop = *op->user_begin();
3769 auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
3770 modop->getResult(0).getType());
3771 rewriter.replaceAllOpUsesWith(modop, inv);
3772 rewriter.eraseOp(modop);
3773 rewriter.eraseOp(op);
3774 return success();
3775 }
3776 return failure();
3777}
3778
3779OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3780 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3781 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3782 return {};
3783}
3784
3785//===----------------------------------------------------------------------===//
3786// ClockGateIntrinsicOp
3787//===----------------------------------------------------------------------===//
3788
3789OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3790 // Forward the clock if one of the enables is always true.
3791 if (isConstantOne(adaptor.getEnable()) ||
3792 isConstantOne(adaptor.getTestEnable()))
3793 return getInput();
3794
3795 // Fold to a constant zero clock if the enables are always false.
3796 if (isConstantZero(adaptor.getEnable()) &&
3797 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3798 return BoolAttr::get(getContext(), false);
3799
3800 // Forward constant zero clocks.
3801 if (isConstantZero(adaptor.getInput()))
3802 return BoolAttr::get(getContext(), false);
3803
3804 return {};
3805}
3806
3807LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3808 PatternRewriter &rewriter) {
3809 // Remove constant false test enable.
3810 if (auto testEnable = op.getTestEnable()) {
3811 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3812 if (constOp.getValue().isZero()) {
3813 rewriter.modifyOpInPlace(op,
3814 [&] { op.getTestEnableMutable().clear(); });
3815 return success();
3816 }
3817 }
3818 }
3819
3820 return failure();
3821}
3822
3823//===----------------------------------------------------------------------===//
3824// Reference Ops.
3825//===----------------------------------------------------------------------===//
3826
3827// refresolve(forceable.ref) -> forceable.data
3828static LogicalResult
3829canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3830 auto forceable = op.getRef().getDefiningOp<Forceable>();
3831 if (!forceable || !forceable.isForceable() ||
3832 op.getRef() != forceable.getDataRef() ||
3833 op.getType() != forceable.getDataType())
3834 return failure();
3835 rewriter.replaceAllUsesWith(op, forceable.getData());
3836 return success();
3837}
3838
3839void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3840 MLIRContext *context) {
3841 results.insert<patterns::RefResolveOfRefSend>(context);
3842 results.insert(canonicalizeRefResolveOfForceable);
3843}
3844
3845OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3846 // RefCast is unnecessary if types match.
3847 if (getInput().getType() == getType())
3848 return getInput();
3849 return {};
3850}
3851
3852static bool isConstantZero(Value operand) {
3853 auto constOp = operand.getDefiningOp<ConstantOp>();
3854 return constOp && constOp.getValue().isZero();
3855}
3856
3857template <typename Op>
3858static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3859 if (isConstantZero(op.getPredicate())) {
3860 rewriter.eraseOp(op);
3861 return success();
3862 }
3863 return failure();
3864}
3865
3866void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3867 MLIRContext *context) {
3868 results.add(eraseIfPredFalse<RefForceOp>);
3869}
3870void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3871 MLIRContext *context) {
3872 results.add(eraseIfPredFalse<RefForceInitialOp>);
3873}
3874void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3875 MLIRContext *context) {
3876 results.add(eraseIfPredFalse<RefReleaseOp>);
3877}
3878void RefReleaseInitialOp::getCanonicalizationPatterns(
3879 RewritePatternSet &results, MLIRContext *context) {
3880 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3881}
3882
3883//===----------------------------------------------------------------------===//
3884// HasBeenResetIntrinsicOp
3885//===----------------------------------------------------------------------===//
3886
3887OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3888 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3889
3890 // Fold to zero if the reset is a constant. In this case the op is either
3891 // permanently in reset or never resets. Both mean that the reset never
3892 // finishes, so this op never returns true.
3893 if (adaptor.getReset())
3894 return getIntZerosAttr(UIntType::get(getContext(), 1));
3895
3896 // Fold to zero if the clock is a constant and the reset is synchronous. In
3897 // that case the reset will never be started.
3898 if (isUInt1(getReset().getType()) && adaptor.getClock())
3899 return getIntZerosAttr(UIntType::get(getContext(), 1));
3900
3901 return {};
3902}
3903
3904//===----------------------------------------------------------------------===//
3905// FPGAProbeIntrinsicOp
3906//===----------------------------------------------------------------------===//
3907
3908static bool isTypeEmpty(FIRRTLType type) {
3910 .Case<FVectorType>(
3911 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3912 .Case<BundleType>([&](auto ty) -> bool {
3913 for (auto elem : ty.getElements())
3914 if (!isTypeEmpty(elem.type))
3915 return false;
3916 return true;
3917 })
3918 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3919 .Default([](auto) -> bool { return false; });
3920}
3921
3922LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3923 PatternRewriter &rewriter) {
3924 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3925 if (!firrtlTy)
3926 return failure();
3927
3928 if (!isTypeEmpty(firrtlTy))
3929 return failure();
3930
3931 rewriter.eraseOp(op);
3932 return success();
3933}
3934
3935//===----------------------------------------------------------------------===//
3936// Layer Block Op
3937//===----------------------------------------------------------------------===//
3938
3939LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3940 PatternRewriter &rewriter) {
3941
3942 // If the layerblock is empty, erase it.
3943 if (op.getBody()->empty()) {
3944 rewriter.eraseOp(op);
3945 return success();
3946 }
3947
3948 return failure();
3949}
3950
3951//===----------------------------------------------------------------------===//
3952// Domain-related Ops
3953//===----------------------------------------------------------------------===//
3954
3955OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3956 // If no domains are specified, then forward the input to the result.
3957 if (getDomains().empty())
3958 return getInput();
3959
3960 return {};
3961}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static std::optional< bool > getBoolValue(Attribute attr)
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:218
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(BitsPrimOp bits, mlir::PatternRewriter &rewriter) const override