CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace firrtl;
29
30// Drop writes to old and pass through passthrough to make patterns easier to
31// write.
32static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33 Value passthrough) {
34 SmallPtrSet<Operation *, 8> users;
35 for (auto *user : old.getUsers())
36 users.insert(user);
37 for (Operation *user : users)
38 if (auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
41 return passthrough;
42}
43
44// Return true if it is OK to propagate the name to the operation.
45// Non-pure operations such as instances, registers, and memories are not
46// allowed to update names for name stabilities and LEC reasons.
47static bool isOkToPropagateName(Operation *op) {
48 // Conservatively disallow operations with regions to prevent performance
49 // regression due to recursive calls to sub-regions in mlir::isPure.
50 if (op->getNumRegions() != 0)
51 return false;
52 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
53}
54
55// Move a name hint from a soon to be deleted operation to a new operation.
56// Pass through the new operation to make patterns easier to write. This cannot
57// move a name to a port (block argument), doing so would require rewriting all
58// instance sites as well as the module.
59static Value moveNameHint(OpResult old, Value passthrough) {
60 Operation *op = passthrough.getDefiningOp();
61 // This should handle ports, but it isn't clear we can change those in
62 // canonicalizers.
63 assert(op && "passthrough must be an operation");
64 Operation *oldOp = old.getOwner();
65 auto name = oldOp->getAttrOfType<StringAttr>("name");
66 if (name && !name.getValue().empty() && isOkToPropagateName(op))
67 op->setAttr("name", name);
68 return passthrough;
69}
70
71// Declarative canonicalization patterns
72namespace circt {
73namespace firrtl {
74namespace patterns {
75#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
76} // namespace patterns
77} // namespace firrtl
78} // namespace circt
79
80/// Return true if this operation's operands and results all have a known width.
81/// This only works for integer types.
82static bool hasKnownWidthIntTypes(Operation *op) {
83 auto resultType = type_cast<IntType>(op->getResult(0).getType());
84 if (!resultType.hasWidth())
85 return false;
86 for (Value operand : op->getOperands())
87 if (!type_cast<IntType>(operand.getType()).hasWidth())
88 return false;
89 return true;
90}
91
92/// Return true if this value is 1 bit UInt.
93static bool isUInt1(Type type) {
94 auto t = type_dyn_cast<UIntType>(type);
95 if (!t || !t.hasWidth() || t.getWidth() != 1)
96 return false;
97 return true;
98}
99
100/// Set the name of an op based on the best of two names: The current name, and
101/// the name passed in.
102static void updateName(PatternRewriter &rewriter, Operation *op,
103 StringAttr name) {
104 // Should never rename InstanceOp
105 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
106 return;
107 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
108 auto newName = name.getValue(); // old name is interesting
109 auto newOpName = op->getAttrOfType<StringAttr>("name");
110 // new name might not be interesting
111 if (newOpName)
112 newName = chooseName(newOpName.getValue(), name.getValue());
113 // Only update if needed
114 if (!newOpName || newOpName.getValue() != newName)
115 rewriter.modifyOpInPlace(
116 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
117}
118
119/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
120/// If a replaced op has a "name" attribute, this function propagates the name
121/// to the new value.
122static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
123 Value newValue) {
124 if (auto *newOp = newValue.getDefiningOp()) {
125 auto name = op->getAttrOfType<StringAttr>("name");
126 updateName(rewriter, newOp, name);
127 }
128 rewriter.replaceOp(op, newValue);
129}
130
131/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
132/// attribute. If a replaced op has a "name" attribute, this function propagates
133/// the name to the new value.
134template <typename OpTy, typename... Args>
135static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
136 Operation *op, Args &&...args) {
137 auto name = op->getAttrOfType<StringAttr>("name");
138 auto newOp =
139 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
140 updateName(rewriter, newOp, name);
141 return newOp;
142}
143
144/// Return true if the name is droppable. Note that this is different from
145/// `isUselessName` because non-useless names may be also droppable.
147 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
148 return namableOp.hasDroppableName();
149 return false;
150}
151
152/// Implicitly replace the operand to a constant folding operation with a const
153/// 0 in case the operand is non-constant but has a bit width 0, or if the
154/// operand is an invalid value.
155///
156/// This makes constant folding significantly easier, as we can simply pass the
157/// operands to an operation through this function to appropriately replace any
158/// zero-width dynamic values or invalid values with a constant of value 0.
159static std::optional<APSInt>
160getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
161 assert(type_cast<IntType>(operand.getType()) &&
162 "getExtendedConstant is limited to integer types");
163
164 // We never support constant folding to unknown width values.
165 if (destWidth < 0)
166 return {};
167
168 // Extension signedness follows the operand sign.
169 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
170 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
171
172 // If the operand is zero bits, then we can return a zero of the result
173 // type.
174 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
175 return APSInt(destWidth,
176 type_cast<IntType>(operand.getType()).isUnsigned());
177 return {};
178}
179
180/// Determine the value of a constant operand for the sake of constant folding.
181static std::optional<APSInt> getConstant(Attribute operand) {
182 if (!operand)
183 return {};
184 if (auto attr = dyn_cast<BoolAttr>(operand))
185 return APSInt(APInt(1, attr.getValue()));
186 if (auto attr = dyn_cast<IntegerAttr>(operand))
187 return attr.getAPSInt();
188 return {};
189}
190
191/// Determine whether a constant operand is a zero value for the sake of
192/// constant folding. This considers `invalidvalue` to be zero.
193static bool isConstantZero(Attribute operand) {
194 if (auto cst = getConstant(operand))
195 return cst->isZero();
196 return false;
197}
198
199/// Determine whether a constant operand is a one value for the sake of constant
200/// folding.
201static bool isConstantOne(Attribute operand) {
202 if (auto cst = getConstant(operand))
203 return cst->isOne();
204 return false;
205}
206
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
///
/// NOTE: the `DivideOrShift` enumerator was missing from this copy of the file
/// (extraction artifact) and has been restored; it is referenced throughout
/// the fold hooks below (e.g. in constFoldFIRRTLBinaryOp's truncation step).
enum class BinOpKind {
  Normal,        ///< Compute in the result type's width.
  Compare,       ///< Compute in the widest operand width; result is i1.
  DivideOrShift, ///< Compute wide, then truncate to the result width.
};
214
215/// Applies the constant folding function `calculate` to the given operands.
216///
217/// Sign or zero extends the operands appropriately to the bitwidth of the
218/// result type if \p useDstWidth is true, else to the larger of the two operand
219/// bit widths and depending on whether the operation is to be performed on
220/// signed or unsigned operands.
221static Attribute constFoldFIRRTLBinaryOp(
222 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
223 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
224 assert(operands.size() == 2 && "binary op takes two operands");
225
226 // We cannot fold something to an unknown width.
227 auto resultType = type_cast<IntType>(op->getResult(0).getType());
228 if (resultType.getWidthOrSentinel() < 0)
229 return {};
230
231 // Any binary op returning i0 is 0.
232 if (resultType.getWidthOrSentinel() == 0)
233 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
234
235 // Determine the operand widths. This is either dictated by the operand type,
236 // or if that type is an unsized integer, by the actual bits necessary to
237 // represent the constant value.
238 auto lhsWidth =
239 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
240 auto rhsWidth =
241 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
242 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
243 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
244 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
245 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
246
247 // Compares extend the operands to the widest of the operand types, not to the
248 // result type.
249 int32_t operandWidth;
250 switch (opKind) {
252 operandWidth = resultType.getWidthOrSentinel();
253 break;
255 // Compares compute with the widest operand, not at the destination type
256 // (which is always i1).
257 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
258 break;
260 operandWidth =
261 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
262 break;
263 }
264
265 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
266 if (!lhs)
267 return {};
268 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
269 if (!rhs)
270 return {};
271
272 APInt resultValue = calculate(*lhs, *rhs);
273
274 // If the result type is smaller than the computation then we need to
275 // narrow the constant after the calculation.
276 if (opKind == BinOpKind::DivideOrShift)
277 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
278
279 assert((unsigned)resultType.getWidthOrSentinel() ==
280 resultValue.getBitWidth());
281 return getIntAttr(resultType, resultValue);
282}
283
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function. Any `invalidvalue` in
/// the input is mapped to a constant zero. The value returned from the callback
/// is used as the replacement for `op`, and an additional pad operation is
/// inserted if necessary. Does nothing if the result of `op` is of unknown
/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants.
  // Only ConstantOp / SpecialConstantOp defining ops contribute a value
  // attribute; every other operand leaves a null slot in `constOperands`.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant. A null OpFoldResult means the callback declined to rewrite.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  if (auto cst = dyn_cast<Attribute>(result))
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = cast<Value>(result);

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);

  // Replace the op, propagating any "name" attribute onto the replacement.
  assert(type == resultValue.getType() && "canonicalization changed type");
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
347
348/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
349/// value if `bitWidth` is 0.
350static APInt getMaxUnsignedValue(unsigned bitWidth) {
351 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
352}
353
354/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
355/// value if `bitWidth` is 0.
356static APInt getMinSignedValue(unsigned bitWidth) {
357 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
358}
359
360/// Get the largest signed value of a given bit width. Returns a 1-bit zero
361/// value if `bitWidth` is 0.
362static APInt getMaxSignedValue(unsigned bitWidth) {
363 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
364}
365
366//===----------------------------------------------------------------------===//
367// Fold Hooks
368//===----------------------------------------------------------------------===//
369
/// A ConstantOp folds to its value attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

/// A SpecialConstantOp folds to its value attribute.
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

/// An AggregateConstantOp folds to its fields attribute.
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getFieldsAttr();
}

/// A StringConstantOp folds to its value attribute.
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

/// An FIntegerConstantOp folds to its value attribute.
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

/// A BoolConstantOp folds to its value attribute.
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

/// A DoubleConstantOp folds to its value attribute.
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
404
405//===----------------------------------------------------------------------===//
406// Binary Operators
407//===----------------------------------------------------------------------===//
408
409OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
411 *this, adaptor.getOperands(), BinOpKind::Normal,
412 [=](const APSInt &a, const APSInt &b) { return a + b; });
413}
414
/// Register the DRR-generated canonicalization patterns for `add`.
void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
                 patterns::AddOfSelf, patterns::AddOfPad>(context);
}
420
421OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
423 *this, adaptor.getOperands(), BinOpKind::Normal,
424 [=](const APSInt &a, const APSInt &b) { return a - b; });
425}
426
/// Register the DRR-generated canonicalization patterns for `sub`.
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
                 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
                 patterns::SubOfPadL, patterns::SubOfPadR>(context);
}
433
434OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 // mul(x, 0) -> 0
436 //
437 // This is legal because it aligns with the Scala FIRRTL Compiler
438 // interpretation of lowering invalid to constant zero before constant
439 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
440 // multiplication this way and will emit "x * 0".
441 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
442 return getIntZerosAttr(getType());
443
445 *this, adaptor.getOperands(), BinOpKind::Normal,
446 [=](const APSInt &a, const APSInt &b) { return a * b; });
447}
448
449OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
450 /// div(x, x) -> 1
451 ///
452 /// Division by zero is undefined in the FIRRTL specification. This fold
453 /// exploits that fact to optimize self division to one. Note: this should
454 /// supersede any division with invalid or zero. Division of invalid by
455 /// invalid should be one.
456 if (getLhs() == getRhs()) {
457 auto width = getType().base().getWidthOrSentinel();
458 if (width == -1)
459 width = 2;
460 // Only fold if we have at least 1 bit of width to represent the `1` value.
461 if (width != 0)
462 return getIntAttr(getType(), APInt(width, 1));
463 }
464
465 // div(0, x) -> 0
466 //
467 // This is legal because it aligns with the Scala FIRRTL Compiler
468 // interpretation of lowering invalid to constant zero before constant
469 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
470 // division this way and will emit "0 / x".
471 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
472 return getIntZerosAttr(getType());
473
474 /// div(x, 1) -> x : (uint, uint) -> uint
475 ///
476 /// UInt division by one returns the numerator. SInt division can't
477 /// be folded here because it increases the return type bitwidth by
478 /// one and requires sign extension (a new op).
479 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
480 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
481 return getLhs();
482
484 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
485 [=](const APSInt &a, const APSInt &b) -> APInt {
486 if (!!b)
487 return a / b;
488 return APInt(a.getBitWidth(), 0);
489 });
490}
491
492OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
493 // rem(x, x) -> 0
494 //
495 // Division by zero is undefined in the FIRRTL specification. This fold
496 // exploits that fact to optimize self division remainder to zero. Note:
497 // this should supersede any division with invalid or zero. Remainder of
498 // division of invalid by invalid should be zero.
499 if (getLhs() == getRhs())
500 return getIntZerosAttr(getType());
501
502 // rem(0, x) -> 0
503 //
504 // This is legal because it aligns with the Scala FIRRTL Compiler
505 // interpretation of lowering invalid to constant zero before constant
506 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
507 // division this way and will emit "0 % x".
508 if (isConstantZero(adaptor.getLhs()))
509 return getIntZerosAttr(getType());
510
512 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
513 [=](const APSInt &a, const APSInt &b) -> APInt {
514 if (!!b)
515 return a % b;
516 return APInt(a.getBitWidth(), 0);
517 });
518}
519
520OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
522 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
523 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
524}
525
526OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
528 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
529 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
530}
531
532OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
534 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
535 [=](const APSInt &a, const APSInt &b) -> APInt {
536 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
537 : a.ashr(b);
538 });
539}
540
541// TODO: Move to DRR.
542OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
543 if (auto rhsCst = getConstant(adaptor.getRhs())) {
544 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
545 if (rhsCst->isZero())
546 return getIntZerosAttr(getType());
547
548 /// and(x, -1) -> x
549 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
551 return getLhs();
552 }
553
554 if (auto lhsCst = getConstant(adaptor.getLhs())) {
555 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
556 if (lhsCst->isZero())
557 return getIntZerosAttr(getType());
558
559 /// and(-1, x) -> x
560 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
561 getRhs().getType() == getType())
562 return getRhs();
563 }
564
565 /// and(x, x) -> x
566 if (getLhs() == getRhs() && getRhs().getType() == getType())
567 return getRhs();
568
570 *this, adaptor.getOperands(), BinOpKind::Normal,
571 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
572}
573
/// Register the DRR-generated canonicalization patterns for `and`.
void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
              patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
              patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
}
581
582OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
583 if (auto rhsCst = getConstant(adaptor.getRhs())) {
584 /// or(x, 0) -> x
585 if (rhsCst->isZero() && getLhs().getType() == getType())
586 return getLhs();
587
588 /// or(x, -1) -> -1
589 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
590 getLhs().getType() == getType())
591 return getRhs();
592 }
593
594 if (auto lhsCst = getConstant(adaptor.getLhs())) {
595 /// or(0, x) -> x
596 if (lhsCst->isZero() && getRhs().getType() == getType())
597 return getRhs();
598
599 /// or(-1, x) -> -1
600 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
601 getRhs().getType() == getType())
602 return getLhs();
603 }
604
605 /// or(x, x) -> x
606 if (getLhs() == getRhs() && getRhs().getType() == getType())
607 return getRhs();
608
610 *this, adaptor.getOperands(), BinOpKind::Normal,
611 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
612}
613
/// Register the DRR-generated canonicalization patterns for `or`.
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
                 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
                 patterns::OrOrr>(context);
}
620
621OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
622 /// xor(x, 0) -> x
623 if (auto rhsCst = getConstant(adaptor.getRhs()))
624 if (rhsCst->isZero() &&
625 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
626 return getLhs();
627
628 /// xor(x, 0) -> x
629 if (auto lhsCst = getConstant(adaptor.getLhs()))
630 if (lhsCst->isZero() &&
631 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
632 return getRhs();
633
634 /// xor(x, x) -> 0
635 if (getLhs() == getRhs())
636 return getIntAttr(
637 getType(),
638 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
639
641 *this, adaptor.getOperands(), BinOpKind::Normal,
642 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
643}
644
/// Register the DRR-generated canonicalization patterns for `xor`.
void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::extendXor, patterns::moveConstXor,
                 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
      context);
}
651
/// Register the DRR-generated canonicalization pattern for `leq`.
void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::LEQWithConstLHS>(context);
}
656
657OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
658 bool isUnsigned = getLhs().getType().base().isUnsigned();
659
660 // leq(x, x) -> 1
661 if (getLhs() == getRhs())
662 return getIntAttr(getType(), APInt(1, 1));
663
664 // Comparison against constant outside type bounds.
665 if (auto width = getLhs().getType().base().getWidth()) {
666 if (auto rhsCst = getConstant(adaptor.getRhs())) {
667 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
668 commonWidth = std::max(commonWidth, 1);
669
670 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
671 // This can never occur since const is unsigned and cannot be less than 0.
672
673 // leq(x, const) -> 0 where const < minValue of the signed type of x
674 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
675 .slt(getMinSignedValue(*width).sext(commonWidth)))
676 return getIntAttr(getType(), APInt(1, 0));
677
678 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
679 if (isUnsigned && rhsCst->zext(commonWidth)
680 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
681 return getIntAttr(getType(), APInt(1, 1));
682
683 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
684 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
685 .sge(getMaxSignedValue(*width).sext(commonWidth)))
686 return getIntAttr(getType(), APInt(1, 1));
687 }
688 }
689
691 *this, adaptor.getOperands(), BinOpKind::Compare,
692 [=](const APSInt &a, const APSInt &b) -> APInt {
693 return APInt(1, a <= b);
694 });
695}
696
/// Register the DRR-generated canonicalization pattern for `lt`.
void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::LTWithConstLHS>(context);
}
701
702OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
703 IntType lhsType = getLhs().getType();
704 bool isUnsigned = lhsType.isUnsigned();
705
706 // lt(x, x) -> 0
707 if (getLhs() == getRhs())
708 return getIntAttr(getType(), APInt(1, 0));
709
710 // lt(x, 0) -> 0 when x is unsigned
711 if (auto rhsCst = getConstant(adaptor.getRhs())) {
712 if (rhsCst->isZero() && lhsType.isUnsigned())
713 return getIntAttr(getType(), APInt(1, 0));
714 }
715
716 // Comparison against constant outside type bounds.
717 if (auto width = lhsType.getWidth()) {
718 if (auto rhsCst = getConstant(adaptor.getRhs())) {
719 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
720 commonWidth = std::max(commonWidth, 1);
721
722 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
723 // Handled explicitly above.
724
725 // lt(x, const) -> 0 where const <= minValue of the signed type of x
726 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
727 .sle(getMinSignedValue(*width).sext(commonWidth)))
728 return getIntAttr(getType(), APInt(1, 0));
729
730 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
731 if (isUnsigned && rhsCst->zext(commonWidth)
732 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
733 return getIntAttr(getType(), APInt(1, 1));
734
735 // lt(x, const) -> 1 where const > maxValue of the signed type of x
736 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
737 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
738 return getIntAttr(getType(), APInt(1, 1));
739 }
740 }
741
743 *this, adaptor.getOperands(), BinOpKind::Compare,
744 [=](const APSInt &a, const APSInt &b) -> APInt {
745 return APInt(1, a < b);
746 });
747}
748
/// Register the DRR-generated canonicalization pattern for `geq`.
void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::GEQWithConstLHS>(context);
}
753
754OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
755 IntType lhsType = getLhs().getType();
756 bool isUnsigned = lhsType.isUnsigned();
757
758 // geq(x, x) -> 1
759 if (getLhs() == getRhs())
760 return getIntAttr(getType(), APInt(1, 1));
761
762 // geq(x, 0) -> 1 when x is unsigned
763 if (auto rhsCst = getConstant(adaptor.getRhs())) {
764 if (rhsCst->isZero() && isUnsigned)
765 return getIntAttr(getType(), APInt(1, 1));
766 }
767
768 // Comparison against constant outside type bounds.
769 if (auto width = lhsType.getWidth()) {
770 if (auto rhsCst = getConstant(adaptor.getRhs())) {
771 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
772 commonWidth = std::max(commonWidth, 1);
773
774 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
775 if (isUnsigned && rhsCst->zext(commonWidth)
776 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
777 return getIntAttr(getType(), APInt(1, 0));
778
779 // geq(x, const) -> 0 where const > maxValue of the signed type of x
780 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
781 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
782 return getIntAttr(getType(), APInt(1, 0));
783
784 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
785 // Handled explicitly above.
786
787 // geq(x, const) -> 1 where const <= minValue of the signed type of x
788 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
789 .sle(getMinSignedValue(*width).sext(commonWidth)))
790 return getIntAttr(getType(), APInt(1, 1));
791 }
792 }
793
795 *this, adaptor.getOperands(), BinOpKind::Compare,
796 [=](const APSInt &a, const APSInt &b) -> APInt {
797 return APInt(1, a >= b);
798 });
799}
800
/// Register the DRR-generated canonicalization pattern for `gt`.
void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::GTWithConstLHS>(context);
}
805
806OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
807 IntType lhsType = getLhs().getType();
808 bool isUnsigned = lhsType.isUnsigned();
809
810 // gt(x, x) -> 0
811 if (getLhs() == getRhs())
812 return getIntAttr(getType(), APInt(1, 0));
813
814 // Comparison against constant outside type bounds.
815 if (auto width = lhsType.getWidth()) {
816 if (auto rhsCst = getConstant(adaptor.getRhs())) {
817 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
818 commonWidth = std::max(commonWidth, 1);
819
820 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
821 if (isUnsigned && rhsCst->zext(commonWidth)
822 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
823 return getIntAttr(getType(), APInt(1, 0));
824
825 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
826 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
827 .sge(getMaxSignedValue(*width).sext(commonWidth)))
828 return getIntAttr(getType(), APInt(1, 0));
829
830 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
831 // This can never occur since const is unsigned and cannot be less than 0.
832
833 // gt(x, const) -> 1 where const < minValue of the signed type of x
834 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
835 .slt(getMinSignedValue(*width).sext(commonWidth)))
836 return getIntAttr(getType(), APInt(1, 1));
837 }
838 }
839
841 *this, adaptor.getOperands(), BinOpKind::Compare,
842 [=](const APSInt &a, const APSInt &b) -> APInt {
843 return APInt(1, a > b);
844 });
845}
846
847OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
848 // eq(x, x) -> 1
849 if (getLhs() == getRhs())
850 return getIntAttr(getType(), APInt(1, 1));
851
852 if (auto rhsCst = getConstant(adaptor.getRhs())) {
853 /// eq(x, 1) -> x when x is 1 bit.
854 /// TODO: Support SInt<1> on the LHS etc.
855 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
856 getRhs().getType() == getType())
857 return getLhs();
858 }
859
861 *this, adaptor.getOperands(), BinOpKind::Compare,
862 [=](const APSInt &a, const APSInt &b) -> APInt {
863 return APInt(1, a == b);
864 });
865}
866
/// Canonicalize eq against constant RHS values into cheaper not/orr/andr
/// forms. Runs through canonicalizePrimOp, which supplies the operands as
/// attributes when they are known constants.
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          // Sentinel (< 0) when the LHS width is unknown; the width > 1
          // checks below are then false and those rewrites are skipped.
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        return {};
      });
}
896
897OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
898 // neq(x, x) -> 0
899 if (getLhs() == getRhs())
900 return getIntAttr(getType(), APInt(1, 0));
901
902 if (auto rhsCst = getConstant(adaptor.getRhs())) {
903 /// neq(x, 0) -> x when x is 1 bit.
904 /// TODO: Support SInt<1> on the LHS etc.
905 if (rhsCst->isZero() && getLhs().getType() == getType() &&
906 getRhs().getType() == getType())
907 return getLhs();
908 }
909
911 *this, adaptor.getOperands(), BinOpKind::Compare,
912 [=](const APSInt &a, const APSInt &b) -> APInt {
913 return APInt(1, a != b);
914 });
915}
916
/// Canonicalize neq against constant RHS values into cheaper not/orr/andr
/// forms. Mirror image of EQPrimOp::canonicalize.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          // Sentinel (< 0) when the LHS width is unknown; the width > 1
          // checks below are then false and those rewrites are skipped.
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x)) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp =
                AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
          }
        }

        return {};
      });
}
948
949OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
950 // TODO: implement constant folding, etc.
951 // Tracked in https://github.com/llvm/circt/issues/6696.
952 return {};
953}
954
955OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
956 // TODO: implement constant folding, etc.
957 // Tracked in https://github.com/llvm/circt/issues/6724.
958 return {};
959}
960
961OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
962 // TODO: implement constant folding, etc.
963 // Tracked in https://github.com/llvm/circt/issues/6725.
964 return {};
965}
966
967OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
968 if (auto rhsCst = getConstant(adaptor.getRhs())) {
969 // Constant folding
970 if (auto lhsCst = getConstant(adaptor.getLhs()))
971
972 return IntegerAttr::get(
973 IntegerType::get(getContext(), lhsCst->getBitWidth()),
974 lhsCst->shl(*rhsCst));
975
976 // integer.shl(x, 0) -> x
977 if (rhsCst->isZero())
978 return getLhs();
979 }
980
981 return {};
982}
983
984//===----------------------------------------------------------------------===//
985// Unary Operators
986//===----------------------------------------------------------------------===//
987
988OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
989 auto base = getInput().getType();
990 auto w = getBitWidth(base);
991 if (w)
992 return getIntAttr(getType(), APInt(32, *w));
993 return {};
994}
995
996OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
997 // No constant can be 'x' by definition.
998 if (auto cst = getConstant(adaptor.getArg()))
999 return getIntAttr(getType(), APInt(1, 0));
1000 return {};
1001}
1002
1003OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1004 // No effect.
1005 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1006 return getInput();
1007
1008 // Be careful to only fold the cast into the constant if the size is known.
1009 // Otherwise width inference may produce differently-sized constants if the
1010 // sign changes.
1011 if (getType().base().hasWidth())
1012 if (auto cst = getConstant(adaptor.getInput()))
1013 return getIntAttr(getType(), *cst);
1014
1015 return {};
1016}
1017
1018void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1019 MLIRContext *context) {
1020 results.insert<patterns::StoUtoS>(context);
1021}
1022
1023OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1024 // No effect.
1025 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1026 return getInput();
1027
1028 // Be careful to only fold the cast into the constant if the size is known.
1029 // Otherwise width inference may produce differently-sized constants if the
1030 // sign changes.
1031 if (getType().base().hasWidth())
1032 if (auto cst = getConstant(adaptor.getInput()))
1033 return getIntAttr(getType(), *cst);
1034
1035 return {};
1036}
1037
1038void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1039 MLIRContext *context) {
1040 results.insert<patterns::UtoStoU>(context);
1041}
1042
1043OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1044 // No effect.
1045 if (getInput().getType() == getType())
1046 return getInput();
1047
1048 // Constant fold.
1049 if (auto cst = getConstant(adaptor.getInput()))
1050 return BoolAttr::get(getContext(), cst->getBoolValue());
1051
1052 return {};
1053}
1054
1055OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1056 // No effect.
1057 if (getInput().getType() == getType())
1058 return getInput();
1059
1060 // Constant fold.
1061 if (auto cst = getConstant(adaptor.getInput()))
1062 return BoolAttr::get(getContext(), cst->getBoolValue());
1063
1064 return {};
1065}
1066
/// Fold cvt of a constant input when all widths are known.
OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
  // Cannot constant-fold without known operand/result widths.
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // Signed to signed is a noop, unsigned operands prepend a zero bit.
  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
                                     getType().base().getWidthOrSentinel()))
    return getIntAttr(getType(), *cst);

  return {};
}
1078
1079void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1080 MLIRContext *context) {
1081 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1082}
1083
1084OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1085 if (!hasKnownWidthIntTypes(*this))
1086 return {};
1087
1088 // FIRRTL negate always adds a bit.
1089 // -x ---> 0-sext(x) or 0-zext(x)
1090 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1091 getType().base().getWidthOrSentinel()))
1092 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1093
1094 return {};
1095}
1096
1097OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1098 if (!hasKnownWidthIntTypes(*this))
1099 return {};
1100
1101 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1102 getType().base().getWidthOrSentinel()))
1103 return getIntAttr(getType(), ~*cst);
1104
1105 return {};
1106}
1107
1108void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1109 MLIRContext *context) {
1110 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1111 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1112 patterns::NotGt>(context);
1113}
1114
/// Base pattern for simplifying a bitwise reduction (andr/orr/xorr) whose
/// single operand is a `cat`. Constant cat operands may short-circuit the
/// whole reduction to a constant, or be dropped from the cat entirely;
/// subclasses decide via handleConstant/getIdentityValue.
class ReductionCat : public mlir::RewritePattern {
public:
  ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
      : RewritePattern(opName, 0, context) {}

  /// Handle a constant operand in the cat operation.
  /// Returns true if the entire reduction can be replaced with a constant.
  /// May add non-zero constants to the remaining operands list.
  virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
                              ConstantOp constantOp,
                              SmallVectorImpl<Value> &remaining) const = 0;

  /// Return the unit value for this reduction operation.
  virtual bool getIdentityValue() const = 0;

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Check if the operand is a cat operation
    auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
    if (!catOp)
      return failure();

    SmallVector<Value> nonConstantOperands;

    // Process each operand of the cat operation
    for (auto operand : catOp.getInputs()) {
      if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
        // Handle constant operands - may short-circuit the entire operation
        if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
          return success();
      } else {
        // Keep non-constant operands for further processing
        nonConstantOperands.push_back(operand);
      }
    }

    // If no operands remain, replace with identity value
    if (nonConstantOperands.empty()) {
      replaceOpWithNewOpAndCopyName<ConstantOp>(
          rewriter, op, cast<IntType>(op->getResult(0).getType()),
          APInt(1, getIdentityValue()));
      return success();
    }

    // If only one operand remains, apply reduction directly to it
    if (nonConstantOperands.size() == 1) {
      rewriter.modifyOpInPlace(
          op, [&] { op->setOperand(0, nonConstantOperands.front()); });
      return success();
    }

    // Multiple operands remain - optimize only when cat has a single use.
    // NOTE(review): this branch replaces catOp but still returns failure();
    // confirm the pattern driver tolerates IR mutation on a failure return —
    // the PatternRewriter contract normally requires success() after any
    // modification.
    if (catOp->hasOneUse() &&
        nonConstantOperands.size() < catOp->getNumOperands()) {
      replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
                                               nonConstantOperands);
    }
    return failure();
  }
};
1176
1177class OrRCat : public ReductionCat {
1178public:
1179 OrRCat(MLIRContext *context)
1180 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1181 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1182 ConstantOp value,
1183 SmallVectorImpl<Value> &remaining) const override {
1184 if (value.getValue().isZero())
1185 return false;
1186
1187 replaceOpWithNewOpAndCopyName<ConstantOp>(
1188 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1189 llvm::APInt(1, 1));
1190 return true;
1191 }
1192 bool getIdentityValue() const override { return false; }
1193};
1194
1195class AndRCat : public ReductionCat {
1196public:
1197 AndRCat(MLIRContext *context)
1198 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1199 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1200 ConstantOp value,
1201 SmallVectorImpl<Value> &remaining) const override {
1202 if (value.getValue().isAllOnes())
1203 return false;
1204
1205 replaceOpWithNewOpAndCopyName<ConstantOp>(
1206 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1207 llvm::APInt(1, 0));
1208 return true;
1209 }
1210 bool getIdentityValue() const override { return true; }
1211};
1212
1213class XorRCat : public ReductionCat {
1214public:
1215 XorRCat(MLIRContext *context)
1216 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1217 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1218 ConstantOp value,
1219 SmallVectorImpl<Value> &remaining) const override {
1220 if (value.getValue().isZero())
1221 return false;
1222 remaining.push_back(value);
1223 return false;
1224 }
1225 bool getIdentityValue() const override { return false; }
1226};
1227
1228OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1229 if (!hasKnownWidthIntTypes(*this))
1230 return {};
1231
1232 if (getInput().getType().getBitWidthOrSentinel() == 0)
1233 return getIntAttr(getType(), APInt(1, 1));
1234
1235 // x == -1
1236 if (auto cst = getConstant(adaptor.getInput()))
1237 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1238
1239 // one bit is identity. Only applies to UInt since we can't make a cast
1240 // here.
1241 if (isUInt1(getInput().getType()))
1242 return getInput();
1243
1244 return {};
1245}
1246
1247void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1248 MLIRContext *context) {
1249 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1250 patterns::AndRPadS, patterns::AndRCatAndR_left,
1251 patterns::AndRCatAndR_right, AndRCat>(context);
1252}
1253
1254OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1255 if (!hasKnownWidthIntTypes(*this))
1256 return {};
1257
1258 if (getInput().getType().getBitWidthOrSentinel() == 0)
1259 return getIntAttr(getType(), APInt(1, 0));
1260
1261 // x != 0
1262 if (auto cst = getConstant(adaptor.getInput()))
1263 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1264
1265 // one bit is identity. Only applies to UInt since we can't make a cast
1266 // here.
1267 if (isUInt1(getInput().getType()))
1268 return getInput();
1269
1270 return {};
1271}
1272
1273void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1274 MLIRContext *context) {
1275 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1276 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
1277 context);
1278}
1279
1280OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1281 if (!hasKnownWidthIntTypes(*this))
1282 return {};
1283
1284 if (getInput().getType().getBitWidthOrSentinel() == 0)
1285 return getIntAttr(getType(), APInt(1, 0));
1286
1287 // popcount(x) & 1
1288 if (auto cst = getConstant(adaptor.getInput()))
1289 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1290
1291 // one bit is identity. Only applies to UInt since we can't make a cast here.
1292 if (isUInt1(getInput().getType()))
1293 return getInput();
1294
1295 return {};
1296}
1297
1298void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1299 MLIRContext *context) {
1300 results
1301 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1302 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
1303 context);
1304}
1305
1306//===----------------------------------------------------------------------===//
1307// Other Operators
1308//===----------------------------------------------------------------------===//
1309
1310OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1311 auto inputs = getInputs();
1312 auto inputAdaptors = adaptor.getInputs();
1313
1314 // If no inputs, return 0-bit value
1315 if (inputs.empty())
1316 return getIntZerosAttr(getType());
1317
1318 // If single input and same type, return it
1319 if (inputs.size() == 1 && inputs[0].getType() == getType())
1320 return inputs[0];
1321
1322 // Make sure it's safe to fold.
1323 if (!hasKnownWidthIntTypes(*this))
1324 return {};
1325
1326 // Filter out zero-width operands
1327 SmallVector<Value> nonZeroInputs;
1328 SmallVector<Attribute> nonZeroAttributes;
1329 bool allConstant = true;
1330 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1331 auto inputType = type_cast<IntType>(input.getType());
1332 if (inputType.getBitWidthOrSentinel() != 0) {
1333 nonZeroInputs.push_back(input);
1334 if (!attr)
1335 allConstant = false;
1336 if (nonZeroInputs.size() > 1 && !allConstant)
1337 return {};
1338 }
1339 }
1340
1341 // If all inputs were zero-width, return 0-bit value
1342 if (nonZeroInputs.empty())
1343 return getIntZerosAttr(getType());
1344
1345 // If only one non-zero input and it has the same type as result, return it
1346 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1347 return nonZeroInputs[0];
1348
1349 if (!hasKnownWidthIntTypes(*this))
1350 return {};
1351
1352 // Constant fold cat - concatenate all constant operands
1353 SmallVector<APInt> constants;
1354 for (auto inputAdaptor : inputAdaptors) {
1355 if (auto cst = getConstant(inputAdaptor))
1356 constants.push_back(*cst);
1357 else
1358 return {}; // Not all operands are constant
1359 }
1360
1361 assert(!constants.empty());
1362 // Concatenate all constants from left to right
1363 APInt result = constants[0];
1364 for (size_t i = 1; i < constants.size(); ++i)
1365 result = result.concat(constants[i]);
1366
1367 return getIntAttr(getType(), result);
1368}
1369
1370void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1371 MLIRContext *context) {
1372 results.insert<patterns::DShlOfConstant>(context);
1373}
1374
1375void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1376 MLIRContext *context) {
1377 results.insert<patterns::DShrOfConstant>(context);
1378}
1379
1380namespace {
1381
1382/// Canonicalize a cat tree into single cat operation.
class FlattenCat : public mlir::RewritePattern {
public:
  FlattenCat(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    // Require known, non-zero result width so the flattened cat is
    // well-formed.
    if (!hasKnownWidthIntTypes(cat) ||
        cat.getType().getBitWidthOrSentinel() == 0)
      return failure();

    // If this is not a root of concat tree, skip.
    if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
      return failure();

    // Try flattening the cat tree.
    SmallVector<Value> operands;
    SmallVector<Value> worklist;
    // Operands are pushed in reverse so popping the worklist visits them in
    // original (left-to-right) order.
    auto pushOperands = [&worklist](CatPrimOp op) {
      for (auto operand : llvm::reverse(op.getInputs()))
        worklist.push_back(operand);
    };
    pushOperands(cat);
    // Track operand signedness; a mix requires explicit casts below.
    bool hasSigned = false, hasUnsigned = false;
    while (!worklist.empty()) {
      auto value = worklist.pop_back_val();
      auto catOp = value.getDefiningOp<CatPrimOp>();
      if (!catOp) {
        // Leaf operand: record it and note its signedness.
        operands.push_back(value);
        (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
        continue;
      }

      // Nested cat: expand it in place.
      pushOperands(catOp);
    }

    // Helper function to cast signed values to unsigned. CatPrimOp converts
    // signed values to unsigned so we need to do it explicitly here.
    auto castToUIntIfSigned = [&](Value value) -> Value {
      if (type_isa<UIntType>(value.getType()))
        return value;
      return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
    };

    assert(operands.size() >= 1 && "zero width cast must be rejected");

    // A single flattened operand replaces the cat entirely.
    if (operands.size() == 1) {
      rewriter.replaceOp(op, castToUIntIfSigned(operands[0]));
      return success();
    }

    // Nothing was flattened; avoid rewriting the op into itself.
    if (operands.size() == cat->getNumOperands())
      return failure();

    // If types are mixed, cast all operands to unsigned.
    if (hasSigned && hasUnsigned)
      for (auto &operand : operands)
        operand = castToUIntIfSigned(operand);

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);
    return success();
  }
};
1449
1450// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
class CatOfConstant : public mlir::RewritePattern {
public:
  CatOfConstant(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    if (!hasKnownWidthIntTypes(cat))
      return failure();

    SmallVector<Value> operands;

    // Scan the operand list; each run of adjacent constants collapses into a
    // single wider constant.
    for (size_t i = 0; i < cat->getNumOperands(); ++i) {
      auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
      if (!cst) {
        // Non-constant operands pass through unchanged.
        operands.push_back(cat.getInputs()[i]);
        continue;
      }
      // Greedily concatenate the constants following position i.
      APSInt value = cst.getValue();
      size_t j = i + 1;
      for (; j < cat->getNumOperands(); ++j) {
        auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
        if (!nextCst)
          break;
        value = value.concat(nextCst.getValue());
      }

      if (j == i + 1) {
        // Not folded.
        operands.push_back(cst);
      } else {
        // Folded.
        operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
      }

      // Resume scanning after the merged run (loop ++ brings us to j).
      i = j - 1;
    }

    // No run was merged; report failure so the driver does not loop.
    if (operands.size() == cat->getNumOperands())
      return failure();

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);

    return success();
  }
};
1500
1501} // namespace
1502
1503void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1504 MLIRContext *context) {
1505 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1506 patterns::CatCast, FlattenCat, CatOfConstant>(context);
1507}
1508
1509OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1510 auto op = (*this);
1511 // BitCast is redundant if input and result types are same.
1512 if (op.getType() == op.getInput().getType())
1513 return op.getInput();
1514
1515 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1516 // final result type.
1517 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1518 if (op.getType() == in.getInput().getType())
1519 return in.getInput();
1520
1521 return {};
1522}
1523
1524OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1525 IntType inputType = getInput().getType();
1526 IntType resultType = getType();
1527 // If we are extracting the entire input, then return it.
1528 if (inputType == getType() && resultType.hasWidth())
1529 return getInput();
1530
1531 // Constant fold.
1532 if (hasKnownWidthIntTypes(*this))
1533 if (auto cst = getConstant(adaptor.getInput()))
1534 return getIntAttr(resultType,
1535 cst->extractBits(getHi() - getLo() + 1, getLo()));
1536
1537 return {};
1538}
1539
/// Rewrite bits(cat(...)) to an extract of a single cat operand when the
/// requested bit range lies entirely within that operand.
struct BitsOfCat : public mlir::RewritePattern {
  BitsOfCat(MLIRContext *context)
      : RewritePattern(BitsPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto bits = cast<BitsPrimOp>(op);
    auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
    if (!cat)
      return failure();
    // Low bit of the extract, counted from bit 0 of the cat result (which
    // comes from the last cat operand).
    int32_t bitPos = bits.getLo();
    auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
    if (resultWidth < 0)
      return failure();
    // Walk operands from least-significant (last) to most-significant.
    for (auto operand : llvm::reverse(cat.getInputs())) {
      auto operandWidth =
          type_cast<IntType>(operand.getType()).getWidthOrSentinel();
      if (operandWidth < 0)
        return failure();
      if (bitPos < operandWidth) {
        // The range starts inside this operand; it must also end here,
        // otherwise it straddles two operands and cannot be rewritten.
        if (bitPos + resultWidth <= operandWidth) {
          auto newBits = rewriter.createOrFold<BitsPrimOp>(
              op->getLoc(), operand, bitPos + resultWidth - 1, bitPos);
          replaceOpAndCopyName(rewriter, op, newBits);
          return success();
        }
        return failure();
      }
      // Skip past this operand's bits.
      bitPos -= operandWidth;
    }
    return failure();
  }
};
1574
1575void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1576 MLIRContext *context) {
1577 results
1578 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1579 patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
1580}
1581
1582/// Replace the specified operation with a 'bits' op from the specified hi/lo
1583/// bits. Insert a cast to handle the case where the original operation
1584/// returned a signed integer.
/// Replace `op`'s single result with bits [hiBit, loBit] of `value`,
/// inserting a signedness cast when the replaced op's result type and the
/// extracted value's type disagree (a bits extract yields an unsigned value).
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only extract when the widths differ; otherwise 'value' already spans the
  // requested range.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);

  // Reconcile signedness with the replaced op's result type.
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1599
/// Shared folding logic for mux-like operations (used by MuxPrimOp::fold and
/// Mux2CellIntrinsicOp::fold).
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  // The type checks guard against operands narrower than the result, which
  // would need padding.
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1647
1648OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1649 return foldMux(*this, adaptor);
1650}
1651
1652OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1653 return foldMux(*this, adaptor);
1654}
1655
1656OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1657
1658namespace {
1659
1660// If the mux has a known output width, pad the operands up to this width.
1661// Most folds on mux require that folded operands are of the same width as
1662// the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    // Sentinel (< 0) means the result width is unknown; there is no target
    // width to pad towards.
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Pad an operand up to the mux result width. Operands with unknown or
    // already-matching width are returned unchanged.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
                               width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Nothing changed; report failure so the driver does not loop forever.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1697
1698// Find muxes which have conditions dominated by other muxes with the same
1699// condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when searching the mux tree for a shared
  // condition.
  static const int depthLimit = 5;

  /// Rewrite `mux` to use the given high/low operands. When updateInPlace is
  /// set, mutate the existing op and return a null Value; otherwise create a
  /// fresh mux after it and return the new result.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
                             ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  // Returns a replacement value for `op` if a mux selected by `cond` was
  // found within depthLimit levels, else null.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // cond is known true here, so this mux always yields its high operand.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // Only mutate in place while every mux on the path has a single user.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  // Mirror image of tryCondTrue.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // cond is known false here, so this mux always yields its low operand.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // In the high subtree the select is known true; in the low subtree it is
    // known false.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1789} // namespace
1790
1791void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1792 MLIRContext *context) {
1793 results
1794 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1795 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1796 patterns::MuxSameTrue, patterns::MuxSameFalse,
1797 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1798 context);
1799}
1800
1801void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1802 RewritePatternSet &results, MLIRContext *context) {
1803 results.add<patterns::Mux2PadSel>(context);
1804}
1805
1806void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1807 RewritePatternSet &results, MLIRContext *context) {
1808 results.add<patterns::Mux4PadSel>(context);
1809}
1810
/// Fold pads: identity pads and pads of constants with known widths.
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // Signed constants sign-extend; everything else zero-extends. The extra
    // getBitWidth() check routes zero-width constants through zext —
    // presumably because sext of a zero-width APInt is invalid; confirm
    // against the APInt API.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1837
// Fold shl by zero, and constant-fold shl of a constant into a wider constant
// with `amount` zero bits appended at the bottom.
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      // The result is `shiftAmount` bits wider than the input, so the shift
      // never discards bits.
      auto resultWidth = inputWidth + shiftAmount;
      // Defensive clamp; resultWidth >= shiftAmount always holds here.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
1858
// Fold shr identities and constant-fold shr of constants. Result width is
// inputWidth - amount, but at least 1 (signed results keep the sign bit).
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Cannot reason further without a known input width.
  if (inputWidth == -1)
    return {};
  // A zero-width input always shifts to zero.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // Arithmetic shift for signed (clamped so the sign bit survives),
    // logical shift for unsigned.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1893
1894LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1895 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1896 if (inputWidth <= 0)
1897 return failure();
1898
1899 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1900 unsigned shiftAmount = op.getAmount();
1901 if (int(shiftAmount) >= inputWidth) {
1902 // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
1903 if (op.getType().base().isUnsigned())
1904 return failure();
1905
1906 // Shifting a signed value by the full width is actually taking the
1907 // sign bit. If the shift amount is greater than the input width, it
1908 // is equivalent to shifting by the input width.
1909 shiftAmount = inputWidth - 1;
1910 }
1911
1912 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1913 return success();
1914}
1915
1916LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1917 PatternRewriter &rewriter) {
1918 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1919 if (inputWidth <= 0)
1920 return failure();
1921
1922 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1923 unsigned keepAmount = op.getAmount();
1924 if (keepAmount)
1925 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1926 rewriter);
1927 return success();
1928}
1929
1930OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1931 if (hasKnownWidthIntTypes(*this))
1932 if (auto cst = getConstant(adaptor.getInput())) {
1933 int shiftAmount =
1934 getInput().getType().base().getWidthOrSentinel() - getAmount();
1935 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1936 }
1937
1938 return {};
1939}
1940
1941OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1942 if (hasKnownWidthIntTypes(*this))
1943 if (auto cst = getConstant(adaptor.getInput()))
1944 return getIntAttr(getType(),
1945 cst->trunc(getType().base().getWidthOrSentinel()));
1946 return {};
1947}
1948
1949LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1950 PatternRewriter &rewriter) {
1951 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1952 if (inputWidth <= 0)
1953 return failure();
1954
1955 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1956 unsigned dropAmount = op.getAmount();
1957 if (dropAmount != unsigned(inputWidth))
1958 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1959 rewriter);
1960 return success();
1961}
1962
1963void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1964 MLIRContext *context) {
1965 results.add<patterns::SubaccessOfConstant>(context);
1966}
1967
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it. Note that operand 0 is the
  // index, so the sole input is operand 1.
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  // A constant in-range index selects a single input. The inputs are stored
  // with the highest index first, hence the reversed subscript.
  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
1981
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements. (Inputs are stored highest index first, so the unreachable
  // entries are at the front.)
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every input must be a subindex of the same vector, in strictly
    // descending index order.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2033
2034//===----------------------------------------------------------------------===//
2035// Declarations
2036//===----------------------------------------------------------------------===//
2037
2038/// Scan all the uses of the specified value, checking to see if there is
2039/// exactly one connect that has the value as its destination. This returns the
2040/// operation if found and if all the other users are "reads" from the value.
2041/// Returns null if there are no connects, or multiple connects to the value, or
2042/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2043///
2044/// Note that this will simply return the connect, which is located *anywhere*
2045/// after the definition of the value. Users of this function are likely
2046/// interested in the source side of the returned connect, the definition of
2047/// which does likely not dominate the original value.
2048MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
2049 MatchingConnectOp connect;
2050 for (Operation *user : value.getUsers()) {
2051 // If we see an attach or aggregate sublements, just conservatively fail.
2052 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2053 return {};
2054
2055 if (auto aConnect = dyn_cast<FConnectLike>(user))
2056 if (aConnect.getDest() == value) {
2057 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2058 // If this is not a matching connect, a second matching connect or in a
2059 // different block, fail.
2060 if (!matchingConnect || (connect && connect != matchingConnect) ||
2061 matchingConnect->getBlock() != value.getParentBlock())
2062 return {};
2063 connect = matchingConnect;
2064 }
2065 }
2066 return connect;
2067}
2068
2069// Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Skip declarations whose identity must be preserved: dont-touch,
  // annotations, a meaningful name, or forceability.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
2127
2128void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2129 MLIRContext *context) {
2130 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2131 context);
2132}
2133
2134LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2135 PatternRewriter &rewriter) {
2136 // TODO: Canonicalize towards explicit extensions and flips here.
2137
2138 // If there is a simple value connected to a foldable decl like a wire or reg,
2139 // see if we can eliminate the decl.
2140 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2141 return success();
2142 return failure();
2143}
2144
2145//===----------------------------------------------------------------------===//
2146// Statements
2147//===----------------------------------------------------------------------===//
2148
2149/// If the specified value has an AttachOp user strictly dominating by
2150/// "dominatingAttach" then return it.
2151static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
2152 for (auto *user : value.getUsers()) {
2153 auto attach = dyn_cast<AttachOp>(user);
2154 if (!attach || attach == dominatedAttach)
2155 continue;
2156 if (attach->isBeforeInBlock(dominatedAttach))
2157 return attach;
2158 }
2159 return {};
2160}
2161
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //   attach x, y
    //   ...
    //   attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      // Erase both originals; the merged attach replaces them.
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire itself.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        // Erase the attach before the wire so the wire has no uses left.
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2207
2208/// Replaces the given op with the contents of the given single-block region.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  // Splice the block's operations in front of `op`; the op itself is left in
  // place for the caller to erase.
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
2214
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // A constant condition selects one region statically: inline the taken
  // region (if any) and erase the when.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2247
2248namespace {
2249// Remove private nodes. If they have an interesting names, move the name to
2250// the source expression.
2251struct FoldNodeName : public mlir::RewritePattern {
2252 FoldNodeName(MLIRContext *context)
2253 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
2254 LogicalResult matchAndRewrite(Operation *op,
2255 PatternRewriter &rewriter) const override {
2256 auto node = cast<NodeOp>(op);
2257 auto name = node.getNameAttr();
2258 if (!node.hasDroppableName() || node.getInnerSym() ||
2259 !AnnotationSet(node).empty() || node.isForceable())
2260 return failure();
2261 auto *newOp = node.getInput().getDefiningOp();
2262 if (newOp)
2263 updateName(rewriter, newOp, name);
2264 rewriter.replaceOp(node, node.getInput());
2265 return success();
2266 }
2267};
2268
2269// Bypass nodes.
2270struct NodeBypass : public mlir::RewritePattern {
2271 NodeBypass(MLIRContext *context)
2272 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
2273 LogicalResult matchAndRewrite(Operation *op,
2274 PatternRewriter &rewriter) const override {
2275 auto node = cast<NodeOp>(op);
2276 if (node.getInnerSym() || !AnnotationSet(node).empty() ||
2277 node.use_empty() || node.isForceable())
2278 return failure();
2279 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2280 return success();
2281 }
2282};
2283
2284} // namespace
2285
2286template <typename OpTy>
2287static LogicalResult demoteForceableIfUnused(OpTy op,
2288 PatternRewriter &rewriter) {
2289 if (!op.isForceable() || !op.getDataRef().use_empty())
2290 return failure();
2291
2292 firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
2293 return success();
2294}
2295
2296// Interesting names and symbols and don't touch force nodes to stick around.
2297LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2298 SmallVectorImpl<OpFoldResult> &results) {
2299 if (!hasDroppableName())
2300 return failure();
2301 if (hasDontTouch(getResult())) // handles inner symbols
2302 return failure();
2303 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
2304 return failure();
2305 if (isForceable())
2306 return failure();
2307 if (!adaptor.getInput())
2308 return failure();
2309
2310 results.push_back(adaptor.getInput());
2311 return success();
2312}
2313
2314void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2315 MLIRContext *context) {
2316 results.insert<FoldNodeName>(context);
2317 results.add(demoteForceableIfUnused<NodeOp>);
2318}
2319
2320namespace {
2321// For a lhs, find all the writers of fields of the aggregate type. If there
2322// is one writer for each field, merge the writes
struct AggOneShot : public mlir::RewritePattern {
  // NOTE(review): `weight` is unused in the visible code; the benefit passed
  // to RewritePattern is hard-coded to 0.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Collect the value written to each field/element of `lhs`'s aggregate
  // result. Returns one value per field in field order iff every field has
  // exactly one writer in the same parent op as `lhs`; otherwise returns an
  // empty vector.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Map from field/element index to the single value connected to it.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
        // A write to the whole aggregate defeats the per-field merge.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subfield blocks the rewrite.
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subindex blocks the rewrite.
          return {};
        }
      } else {
        return {};
      }
    }

    // Every field must have been written exactly once.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot matchingconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Build the aggregate value and connect it with a single write.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
2421
// AggOneShot applied to writes of aggregate wires.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
// AggOneShot applied to writes through subindex results.
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
// AggOneShot applied to writes through subfield results.
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2434} // namespace
2435
2436void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2437 MLIRContext *context) {
2438 results.insert<WireAggOneShot>(context);
2439 results.add(demoteForceableIfUnused<WireOp>);
2440}
2441
2442void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2443 MLIRContext *context) {
2444 results.insert<SubindexAggOneShot>(context);
2445}
2446
2447OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2448 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2449 if (!attr)
2450 return {};
2451 return attr[getIndex()];
2452}
2453
2454OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2455 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2456 if (!attr)
2457 return {};
2458 auto index = getFieldIndex();
2459 return attr[index];
2460}
2461
2462void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2463 MLIRContext *context) {
2464 results.insert<SubfieldAggOneShot>(context);
2465}
2466
2467static Attribute collectFields(MLIRContext *context,
2468 ArrayRef<Attribute> operands) {
2469 for (auto operand : operands)
2470 if (!operand)
2471 return {};
2472 return ArrayAttr::get(context, operands);
2473}
2474
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  // i.e. every operand is a subfield of the same bundle, in field order, and
  // the reassembled bundle has exactly the source's type.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate when every field is constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2493
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  // i.e. every operand is a subindex of the same vector, in element order,
  // and the reassembled vector has exactly the source's type.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate when every element is constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2511
2512OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2513 if (getOperand().getType() == getType())
2514 return getOperand();
2515 return {};
2516}
2517
2518namespace {
2519// A register with constant reset and all connection to either itself or the
2520// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    // The reset value must be a constant, and the register must carry no
    // dont-touch, annotations, or forceability.
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The register must be driven by mux(?, constant, reg) or
    // mux(?, reg, constant), where the constant equals the reset value.
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2573} // namespace
2574
2575static bool isDefinedByOneConstantOp(Value v) {
2576 if (auto c = v.getDefiningOp<ConstantOp>())
2577 return c.getValue().isOne();
2578 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2579 return sc.getValue();
2580 return false;
2581}
2582
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  // The reset signal must be a constant true.
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // The reset value must have exactly the register's type.
  auto resetValue = reg.getResetValue();
  if (reg.getType(0) != resetValue.getType())
    return failure();

  // The register permanently holds its reset value: drop writes to it and
  // replace it with a node of the reset value, preserving name/symbol/etc.
  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2599
2600void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2601 MLIRContext *context) {
2602 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2604 results.add(demoteForceableIfUnused<RegResetOp>);
2605}
2606
2607// Returns the value connected to a port, if there is only one.
2608static Value getPortFieldValue(Value port, StringRef name) {
2609 auto portTy = type_cast<BundleType>(port.getType());
2610 auto fieldIndex = portTy.getElementIndex(name);
2611 assert(fieldIndex && "missing field on memory port");
2612
2613 Value value = {};
2614 for (auto *op : port.getUsers()) {
2615 auto portAccess = cast<SubfieldOp>(op);
2616 if (fieldIndex != portAccess.getFieldIndex())
2617 continue;
2618 auto conn = getSingleConnectUserOf(portAccess);
2619 if (!conn || value)
2620 return {};
2621 value = conn.getSrc();
2622 }
2623 return value;
2624}
2625
2626// Returns true if the enable field of a port is set to false.
2627static bool isPortDisabled(Value port) {
2628 auto value = getPortFieldValue(port, "en");
2629 if (!value)
2630 return false;
2631 auto portConst = value.getDefiningOp<ConstantOp>();
2632 if (!portConst)
2633 return false;
2634 return portConst.getValue().isZero();
2635}
2636
2637// Returns true if the data output is unused.
2638static bool isPortUnused(Value port, StringRef data) {
2639 auto portTy = type_cast<BundleType>(port.getType());
2640 auto fieldIndex = portTy.getElementIndex(data);
2641 assert(fieldIndex && "missing enable flag on memory port");
2642
2643 for (auto *op : port.getUsers()) {
2644 auto portAccess = cast<SubfieldOp>(op);
2645 if (fieldIndex != portAccess.getFieldIndex())
2646 continue;
2647 if (!portAccess.use_empty())
2648 return false;
2649 }
2650
2651 return true;
2652}
2653
// Replaces every access to the named field of a memory port with `value`,
// erasing the subfield accesses. (The previous comment here was a copy-paste
// of getPortFieldValue's; this function returns nothing.)
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2669
2670// Remove accesses to a port which is used.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = SpecialConstantOp::create(rewriter, port.getLoc(),
                                        ClockType::get(rewriter.getContext()),
                                        false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire. If the port
  // is used in its entirety, replace it with a wire. Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  // NOTE(review): the whole-port case below substitutes a register, not a
  // wire as this comment says.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      auto ty = port.getType();
      auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    // First erase writes into this field...
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    // ...then the access itself if nothing reads it.
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2722
2723namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only applies to memories with a known zero-width integer data type.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace: only subfield accesses.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type. Replace each subfield access with
    // a wire of the same type instead.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports so the replacement wires stay
          // driven (with the unique zero-width constant).
          auto zero = firrtl::ConstantOp::create(
              rewriter, wire.getLoc(),
              firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
          MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2766
// If memory has no write ports and no file initialization, eliminate it.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Scan the port kinds; bail out as soon as the memory is seen to be both
    // read and written, or has debug/read-write ports we cannot reason about.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    // Detach all port accesses, then delete the memory itself.
    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2810
// Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the memory with only the surviving ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port is dead, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = MemOp::create(
          rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Erase accesses to the dead ports and reroute live ports to the
    // replacement memory.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2880
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Compute the new port types: demoted ports get a write-port bundle,
    // every other port keeps its original type.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = MemOp::create(
        rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
        mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
        rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
        mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
                                            newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Shared fields keep their names; the write data/mask fields of the
        // read-write bundle map onto the plain write-port field names.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // The write port has no wmode field; detach any accesses to it by
        // replacing them with dangling wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2972
// Narrow the memory data type by compacting away bits that are never read.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // Requires an integer data type with a known positive width.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // Bit-mask of data bits observed by some reader, and a mapping from the
    // low bit of each read to its position in the compacted data word
    // (filled in later, once the compacted layout is known).
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If all bits are used there is nothing to compact; give up early.
          if (usedBits.all())
            return failure();

          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones: walk each
    // contiguous run of used bits, assigning consecutive positions in the
    // narrowed word and recording the run as an (first, last) range.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite accesses to the given data field to a single access on the new
    // port; all old accesses are folded into one fresh SubfieldOp.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          SubfieldOp::create(rewriter, port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        // Skip the access we just created ourselves.
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Slice out each used range of the original value, MSB-first so the
      // concatenation below reproduces the compacted layout.
      SmallVector<Value> slices;
      for (auto &[start, end] : llvm::reverse(ranges)) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        slices.push_back(slice);
      }

      Value catOfSlices =
          rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
3195
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only depth-1 memories can be modeled by a single register.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    // While scanning, collect every relevant port subfield access and the
    // connects driving them, for deletion at the end.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect the accesses to the named fields of this port; each access
      // must only be used by connects for the rewrite to be safe.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All write-capable ports must agree on a single clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = WireOp::create(rewriter, loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    MatchingConnectOp::create(rewriter, loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        // Derive a stable name for each pipeline stage from the memory,
        // field, and stage index.
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
                                 rewriter.getStringAttr(regName))
                       .getResult();
        MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // The port only writes when both enabled and in write mode.
        auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      SmallVector<Value> chunks;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
        chunks.push_back(chunk);
      }

      // Chunks were built LSB-first; reverse so the concatenation is MSB-first.
      std::reverse(chunks.begin(), chunks.end());
      masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
      next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
    }
    // Cast back to the register type and drive the register with the
    // arbitrated next value.
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3390} // namespace
3391
3392void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3393 MLIRContext *context) {
3394 results
3395 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3396 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3397 context);
3398}
3399
3400//===----------------------------------------------------------------------===//
3401// Declarations
3402//===----------------------------------------------------------------------===//
3403
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The connected value must come from a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  // For the hidden-reset case, swap the register for a RegResetOp driven by
  // the mux select and the constant, preserving all attributes.
  if (!constReg) {
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Rewrite the connect to bypass the mux: drive the register with the
  // constant (constReg case) or the non-reset value.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3474
3475LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3476 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3477 succeeded(foldHiddenReset(op, rewriter)))
3478 return success();
3479
3480 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3481 return success();
3482
3483 return failure();
3484}
3485
3486//===----------------------------------------------------------------------===//
3487// Verification Ops.
3488//===----------------------------------------------------------------------===//
3489
3490static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3491 Value enable,
3492 PatternRewriter &rewriter,
3493 bool eraseIfZero) {
3494 // If the verification op is never enabled, delete it.
3495 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3496 if (constant.getValue().isZero()) {
3497 rewriter.eraseOp(op);
3498 return success();
3499 }
3500 }
3501
3502 // If the verification op is never triggered, delete it.
3503 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3504 if (constant.getValue().isZero() == eraseIfZero) {
3505 rewriter.eraseOp(op);
3506 return success();
3507 }
3508 }
3509
3510 return failure();
3511}
3512
// Shared canonicalization for immediate verification ops: erase the op when
// a constant enable disables it, or when a constant predicate makes it
// trivial (predicate constant-zero for EraseIfZero ops such as cover,
// constant non-zero otherwise).
template <class Op, bool EraseIfZero = false>
static LogicalResult canonicalizeImmediateVerifOp(Op op,
                                                  PatternRewriter &rewriter) {
  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
                              EraseIfZero);
}
3519
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase asserts that are statically disabled or trivially satisfied.
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
  // Declarative (tablegen) pattern: assert of x when x.
  results.add<patterns::AssertXWhenX>(context);
}
3525
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase assumes that are statically disabled or trivially satisfied.
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
  // Declarative (tablegen) pattern: assume of x when x.
  results.add<patterns::AssumeXWhenX>(context);
}
3531
void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Erase unclocked assumes that are statically disabled or trivial.
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
  // Declarative (tablegen) pattern: assume of x when x.
  results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
}
3537
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // A cover with a constant-false predicate can never fire; erase it.
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3542
3543//===----------------------------------------------------------------------===//
3544// InvalidValueOp
3545//===----------------------------------------------------------------------===//
3546
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  // Cvt only qualifies on signed inputs, and the reductions only on inputs
  // with a known positive width.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of its result type, then delete
    // both the user and the original invalid.
    auto *modop = *op->user_begin();
    auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
                                      modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3577
3578OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3579 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3580 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3581 return {};
3582}
3583
3584//===----------------------------------------------------------------------===//
3585// ClockGateIntrinsicOp
3586//===----------------------------------------------------------------------===//
3587
3588OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3589 // Forward the clock if one of the enables is always true.
3590 if (isConstantOne(adaptor.getEnable()) ||
3591 isConstantOne(adaptor.getTestEnable()))
3592 return getInput();
3593
3594 // Fold to a constant zero clock if the enables are always false.
3595 if (isConstantZero(adaptor.getEnable()) &&
3596 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3597 return BoolAttr::get(getContext(), false);
3598
3599 // Forward constant zero clocks.
3600 if (isConstantZero(adaptor.getInput()))
3601 return BoolAttr::get(getContext(), false);
3602
3603 return {};
3604}
3605
3606LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3607 PatternRewriter &rewriter) {
3608 // Remove constant false test enable.
3609 if (auto testEnable = op.getTestEnable()) {
3610 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3611 if (constOp.getValue().isZero()) {
3612 rewriter.modifyOpInPlace(op,
3613 [&] { op.getTestEnableMutable().clear(); });
3614 return success();
3615 }
3616 }
3617 }
3618
3619 return failure();
3620}
3621
3622//===----------------------------------------------------------------------===//
3623// Reference Ops.
3624//===----------------------------------------------------------------------===//
3625
3626// refresolve(forceable.ref) -> forceable.data
3627static LogicalResult
3628canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3629 auto forceable = op.getRef().getDefiningOp<Forceable>();
3630 if (!forceable || !forceable.isForceable() ||
3631 op.getRef() != forceable.getDataRef() ||
3632 op.getType() != forceable.getDataType())
3633 return failure();
3634 rewriter.replaceAllUsesWith(op, forceable.getData());
3635 return success();
3636}
3637
3638void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3639 MLIRContext *context) {
3640 results.insert<patterns::RefResolveOfRefSend>(context);
3641 results.insert(canonicalizeRefResolveOfForceable);
3642}
3643
3644OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3645 // RefCast is unnecessary if types match.
3646 if (getInput().getType() == getType())
3647 return getInput();
3648 return {};
3649}
3650
3651static bool isConstantZero(Value operand) {
3652 auto constOp = operand.getDefiningOp<ConstantOp>();
3653 return constOp && constOp.getValue().isZero();
3654}
3655
3656template <typename Op>
3657static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3658 if (isConstantZero(op.getPredicate())) {
3659 rewriter.eraseOp(op);
3660 return success();
3661 }
3662 return failure();
3663}
3664
3665void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3666 MLIRContext *context) {
3667 results.add(eraseIfPredFalse<RefForceOp>);
3668}
3669void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3670 MLIRContext *context) {
3671 results.add(eraseIfPredFalse<RefForceInitialOp>);
3672}
3673void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3674 MLIRContext *context) {
3675 results.add(eraseIfPredFalse<RefReleaseOp>);
3676}
3677void RefReleaseInitialOp::getCanonicalizationPatterns(
3678 RewritePatternSet &results, MLIRContext *context) {
3679 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3680}
3681
3682//===----------------------------------------------------------------------===//
3683// HasBeenResetIntrinsicOp
3684//===----------------------------------------------------------------------===//
3685
3686OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3687 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3688
3689 // Fold to zero if the reset is a constant. In this case the op is either
3690 // permanently in reset or never resets. Both mean that the reset never
3691 // finishes, so this op never returns true.
3692 if (adaptor.getReset())
3693 return getIntZerosAttr(UIntType::get(getContext(), 1));
3694
3695 // Fold to zero if the clock is a constant and the reset is synchronous. In
3696 // that case the reset will never be started.
3697 if (isUInt1(getReset().getType()) && adaptor.getClock())
3698 return getIntZerosAttr(UIntType::get(getContext(), 1));
3699
3700 return {};
3701}
3702
3703//===----------------------------------------------------------------------===//
3704// FPGAProbeIntrinsicOp
3705//===----------------------------------------------------------------------===//
3706
3707static bool isTypeEmpty(FIRRTLType type) {
3709 .Case<FVectorType>(
3710 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3711 .Case<BundleType>([&](auto ty) -> bool {
3712 for (auto elem : ty.getElements())
3713 if (!isTypeEmpty(elem.type))
3714 return false;
3715 return true;
3716 })
3717 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3718 .Default([](auto) -> bool { return false; });
3719}
3720
3721LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3722 PatternRewriter &rewriter) {
3723 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3724 if (!firrtlTy)
3725 return failure();
3726
3727 if (!isTypeEmpty(firrtlTy))
3728 return failure();
3729
3730 rewriter.eraseOp(op);
3731 return success();
3732}
3733
3734//===----------------------------------------------------------------------===//
3735// Layer Block Op
3736//===----------------------------------------------------------------------===//
3737
3738LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3739 PatternRewriter &rewriter) {
3740
3741 // If the layerblock is empty, erase it.
3742 if (op.getBody()->empty()) {
3743 rewriter.eraseOp(op);
3744 return success();
3745 }
3746
3747 return failure();
3748}
3749
3750//===----------------------------------------------------------------------===//
3751// Domain-related Ops
3752//===----------------------------------------------------------------------===//
3753
3754OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3755 // If no domains are specified, then forward the input to the result.
3756 if (getDomains().empty())
3757 return getInput();
3758
3759 return {};
3760}
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
BitsOfCat(MLIRContext *context)