CIRCT 23.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace firrtl;
29
30// Drop writes to old and pass through passthrough to make patterns easier to
31// write.
32static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33 Value passthrough) {
34 SmallPtrSet<Operation *, 8> users;
35 for (auto *user : old.getUsers())
36 users.insert(user);
37 for (Operation *user : users)
38 if (auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
41 return passthrough;
42}
43
44// Return true if it is OK to propagate the name to the operation.
45// Non-pure operations such as instances, registers, and memories are not
46// allowed to update names for name stabilities and LEC reasons.
47static bool isOkToPropagateName(Operation *op) {
48 // Conservatively disallow operations with regions to prevent performance
49 // regression due to recursive calls to sub-regions in mlir::isPure.
50 if (op->getNumRegions() != 0)
51 return false;
52 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
53}
54
55// Move a name hint from a soon to be deleted operation to a new operation.
56// Pass through the new operation to make patterns easier to write. This cannot
57// move a name to a port (block argument), doing so would require rewriting all
58// instance sites as well as the module.
59static Value moveNameHint(OpResult old, Value passthrough) {
60 Operation *op = passthrough.getDefiningOp();
61 // This should handle ports, but it isn't clear we can change those in
62 // canonicalizers.
63 assert(op && "passthrough must be an operation");
64 Operation *oldOp = old.getOwner();
65 auto name = oldOp->getAttrOfType<StringAttr>("name");
66 if (name && !name.getValue().empty() && isOkToPropagateName(op))
67 op->setAttr("name", name);
68 return passthrough;
69}
70
71// Declarative canonicalization patterns
72namespace circt {
73namespace firrtl {
74namespace patterns {
75#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
76} // namespace patterns
77} // namespace firrtl
78} // namespace circt
79
80/// Return true if this operation's operands and results all have a known width.
81/// This only works for integer types.
82static bool hasKnownWidthIntTypes(Operation *op) {
83 auto resultType = type_cast<IntType>(op->getResult(0).getType());
84 if (!resultType.hasWidth())
85 return false;
86 for (Value operand : op->getOperands())
87 if (!type_cast<IntType>(operand.getType()).hasWidth())
88 return false;
89 return true;
90}
91
92/// Return true if this value is 1 bit UInt.
93static bool isUInt1(Type type) {
94 auto t = type_dyn_cast<UIntType>(type);
95 if (!t || !t.hasWidth() || t.getWidth() != 1)
96 return false;
97 return true;
98}
99
100/// Set the name of an op based on the best of two names: The current name, and
101/// the name passed in.
102static void updateName(PatternRewriter &rewriter, Operation *op,
103 StringAttr name) {
104 // Should never rename InstanceOp
105 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
106 return;
107 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
108 auto newName = name.getValue(); // old name is interesting
109 auto newOpName = op->getAttrOfType<StringAttr>("name");
110 // new name might not be interesting
111 if (newOpName)
112 newName = chooseName(newOpName.getValue(), name.getValue());
113 // Only update if needed
114 if (!newOpName || newOpName.getValue() != newName)
115 rewriter.modifyOpInPlace(
116 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
117}
118
119/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
120/// If a replaced op has a "name" attribute, this function propagates the name
121/// to the new value.
122static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
123 Value newValue) {
124 if (auto *newOp = newValue.getDefiningOp()) {
125 auto name = op->getAttrOfType<StringAttr>("name");
126 updateName(rewriter, newOp, name);
127 }
128 rewriter.replaceOp(op, newValue);
129}
130
131/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
132/// attribute. If a replaced op has a "name" attribute, this function propagates
133/// the name to the new value.
134template <typename OpTy, typename... Args>
135static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
136 Operation *op, Args &&...args) {
137 auto name = op->getAttrOfType<StringAttr>("name");
138 auto newOp =
139 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
140 updateName(rewriter, newOp, name);
141 return newOp;
142}
143
144/// Return true if the name is droppable. Note that this is different from
145/// `isUselessName` because non-useless names may be also droppable.
147 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
148 return namableOp.hasDroppableName();
149 return false;
150}
151
152/// Implicitly replace the operand to a constant folding operation with a const
153/// 0 in case the operand is non-constant but has a bit width 0, or if the
154/// operand is an invalid value.
155///
156/// This makes constant folding significantly easier, as we can simply pass the
157/// operands to an operation through this function to appropriately replace any
158/// zero-width dynamic values or invalid values with a constant of value 0.
159static std::optional<APSInt>
160getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
161 assert(type_cast<IntType>(operand.getType()) &&
162 "getExtendedConstant is limited to integer types");
163
164 // We never support constant folding to unknown width values.
165 if (destWidth < 0)
166 return {};
167
168 // Extension signedness follows the operand sign.
169 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
170 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
171
172 // If the operand is zero bits, then we can return a zero of the result
173 // type.
174 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
175 return APSInt(destWidth,
176 type_cast<IntType>(operand.getType()).isUnsigned());
177 return {};
178}
179
180/// Determine the value of a constant operand for the sake of constant folding.
181static std::optional<APSInt> getConstant(Attribute operand) {
182 if (!operand)
183 return {};
184 if (auto attr = dyn_cast<BoolAttr>(operand))
185 return APSInt(APInt(1, attr.getValue()));
186 if (auto attr = dyn_cast<IntegerAttr>(operand))
187 return attr.getAPSInt();
188 return {};
189}
190
191/// Determine whether a constant operand is a zero value for the sake of
192/// constant folding. This considers `invalidvalue` to be zero.
193static bool isConstantZero(Attribute operand) {
194 if (auto cst = getConstant(operand))
195 return cst->isZero();
196 return false;
197}
198
199/// Determine whether a constant operand is a one value for the sake of constant
200/// folding.
201static bool isConstantOne(Attribute operand) {
202 if (auto cst = getConstant(operand))
203 return cst->isOne();
204 return false;
205}
206
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
enum class BinOpKind {
  Normal,        ///< Operands are extended to the result width.
  Compare,       ///< Operands are extended to the widest operand; result is i1.
  DivideOrShift, ///< Computation may be wider than the result; truncate after.
};
214
215/// Applies the constant folding function `calculate` to the given operands.
216///
217/// Sign or zero extends the operands appropriately to the bitwidth of the
218/// result type if \p useDstWidth is true, else to the larger of the two operand
219/// bit widths and depending on whether the operation is to be performed on
220/// signed or unsigned operands.
221static Attribute constFoldFIRRTLBinaryOp(
222 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
223 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
224 assert(operands.size() == 2 && "binary op takes two operands");
225
226 // We cannot fold something to an unknown width.
227 auto resultType = type_cast<IntType>(op->getResult(0).getType());
228 if (resultType.getWidthOrSentinel() < 0)
229 return {};
230
231 // Any binary op returning i0 is 0.
232 if (resultType.getWidthOrSentinel() == 0)
233 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
234
235 // Determine the operand widths. This is either dictated by the operand type,
236 // or if that type is an unsized integer, by the actual bits necessary to
237 // represent the constant value.
238 auto lhsWidth =
239 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
240 auto rhsWidth =
241 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
242 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
243 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
244 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
245 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
246
247 // Compares extend the operands to the widest of the operand types, not to the
248 // result type.
249 int32_t operandWidth;
250 switch (opKind) {
252 operandWidth = resultType.getWidthOrSentinel();
253 break;
255 // Compares compute with the widest operand, not at the destination type
256 // (which is always i1).
257 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
258 break;
260 operandWidth =
261 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
262 break;
263 }
264
265 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
266 if (!lhs)
267 return {};
268 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
269 if (!rhs)
270 return {};
271
272 APInt resultValue = calculate(*lhs, *rhs);
273
274 // If the result type is smaller than the computation then we need to
275 // narrow the constant after the calculation.
276 if (opKind == BinOpKind::DivideOrShift)
277 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
278
279 assert((unsigned)resultType.getWidthOrSentinel() ==
280 resultValue.getBitWidth());
281 return getIntAttr(resultType, resultValue);
282}
283
284/// Applies the canonicalization function `canonicalize` to the given operation.
285///
286/// Determines which (if any) of the operation's operands are constants, and
287/// provides them as arguments to the callback function. Any `invalidvalue` in
288/// the input is mapped to a constant zero. The value returned from the callback
289/// is used as the replacement for `op`, and an additional pad operation is
290/// inserted if necessary. Does nothing if the result of `op` is of unknown
291/// width, in which case the necessity of a pad cannot be determined.
292static LogicalResult canonicalizePrimOp(
293 Operation *op, PatternRewriter &rewriter,
294 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
295 // Can only operate on FIRRTL primitive operations.
296 if (op->getNumResults() != 1)
297 return failure();
298 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
299 if (!type)
300 return failure();
301
302 // Can only operate on operations with a known result width.
303 auto width = type.getBitWidthOrSentinel();
304 if (width < 0)
305 return failure();
306
307 // Determine which of the operands are constants.
308 SmallVector<Attribute, 3> constOperands;
309 constOperands.reserve(op->getNumOperands());
310 for (auto operand : op->getOperands()) {
311 Attribute attr;
312 if (auto *defOp = operand.getDefiningOp())
313 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
314 [&](auto op) { attr = op.getValueAttr(); });
315 constOperands.push_back(attr);
316 }
317
318 // Perform the canonicalization and materialize the result if it is a
319 // constant.
320 auto result = canonicalize(constOperands);
321 if (!result)
322 return failure();
323 Value resultValue;
324 if (auto cst = dyn_cast<Attribute>(result))
325 resultValue = op->getDialect()
326 ->materializeConstant(rewriter, cst, type, op->getLoc())
327 ->getResult(0);
328 else
329 resultValue = cast<Value>(result);
330
331 // Insert a pad if the type widths disagree.
332 if (width !=
333 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
334 resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);
335
336 // Insert a cast if this is a uint vs. sint or vice versa.
337 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
338 resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
339 else if (type_isa<UIntType>(type) &&
340 type_isa<SIntType>(resultValue.getType()))
341 resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);
342
343 assert(type == resultValue.getType() && "canonicalization changed type");
344 replaceOpAndCopyName(rewriter, op, resultValue);
345 return success();
346}
347
348/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
349/// value if `bitWidth` is 0.
350static APInt getMaxUnsignedValue(unsigned bitWidth) {
351 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
352}
353
354/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
355/// value if `bitWidth` is 0.
356static APInt getMinSignedValue(unsigned bitWidth) {
357 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
358}
359
360/// Get the largest signed value of a given bit width. Returns a 1-bit zero
361/// value if `bitWidth` is 0.
362static APInt getMaxSignedValue(unsigned bitWidth) {
363 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
364}
365
366//===----------------------------------------------------------------------===//
367// Fold Hooks
368//===----------------------------------------------------------------------===//
369
370OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
371 assert(adaptor.getOperands().empty() && "constant has no operands");
372 return getValueAttr();
373}
374
375OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
376 assert(adaptor.getOperands().empty() && "constant has no operands");
377 return getValueAttr();
378}
379
380OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
381 assert(adaptor.getOperands().empty() && "constant has no operands");
382 return getFieldsAttr();
383}
384
385OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
386 assert(adaptor.getOperands().empty() && "constant has no operands");
387 return getValueAttr();
388}
389
390OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
391 assert(adaptor.getOperands().empty() && "constant has no operands");
392 return getValueAttr();
393}
394
395OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
396 assert(adaptor.getOperands().empty() && "constant has no operands");
397 return getValueAttr();
398}
399
400OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
401 assert(adaptor.getOperands().empty() && "constant has no operands");
402 return getValueAttr();
403}
404
405//===----------------------------------------------------------------------===//
406// Binary Operators
407//===----------------------------------------------------------------------===//
408
409OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
411 *this, adaptor.getOperands(), BinOpKind::Normal,
412 [=](const APSInt &a, const APSInt &b) { return a + b; });
413}
414
415void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
416 MLIRContext *context) {
417 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
418 patterns::AddOfSelf, patterns::AddOfPad>(context);
419}
420
421OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
423 *this, adaptor.getOperands(), BinOpKind::Normal,
424 [=](const APSInt &a, const APSInt &b) { return a - b; });
425}
426
427void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
428 MLIRContext *context) {
429 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
430 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
431 patterns::SubOfPadL, patterns::SubOfPadR>(context);
432}
433
434OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 // mul(x, 0) -> 0
436 //
437 // This is legal because it aligns with the Scala FIRRTL Compiler
438 // interpretation of lowering invalid to constant zero before constant
439 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
440 // multiplication this way and will emit "x * 0".
441 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
442 return getIntZerosAttr(getType());
443
445 *this, adaptor.getOperands(), BinOpKind::Normal,
446 [=](const APSInt &a, const APSInt &b) { return a * b; });
447}
448
449OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
450 /// div(x, x) -> 1
451 ///
452 /// Division by zero is undefined in the FIRRTL specification. This fold
453 /// exploits that fact to optimize self division to one. Note: this should
454 /// supersede any division with invalid or zero. Division of invalid by
455 /// invalid should be one.
456 if (getLhs() == getRhs()) {
457 auto width = getType().base().getWidthOrSentinel();
458 if (width == -1)
459 width = 2;
460 // Only fold if we have at least 1 bit of width to represent the `1` value.
461 if (width != 0)
462 return getIntAttr(getType(), APInt(width, 1));
463 }
464
465 // div(0, x) -> 0
466 //
467 // This is legal because it aligns with the Scala FIRRTL Compiler
468 // interpretation of lowering invalid to constant zero before constant
469 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
470 // division this way and will emit "0 / x".
471 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
472 return getIntZerosAttr(getType());
473
474 /// div(x, 1) -> x : (uint, uint) -> uint
475 ///
476 /// UInt division by one returns the numerator. SInt division can't
477 /// be folded here because it increases the return type bitwidth by
478 /// one and requires sign extension (a new op).
479 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
480 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
481 return getLhs();
482
484 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
485 [=](const APSInt &a, const APSInt &b) -> APInt {
486 if (!!b)
487 return a / b;
488 return APInt(a.getBitWidth(), 0);
489 });
490}
491
492OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
493 // rem(x, x) -> 0
494 //
495 // Division by zero is undefined in the FIRRTL specification. This fold
496 // exploits that fact to optimize self division remainder to zero. Note:
497 // this should supersede any division with invalid or zero. Remainder of
498 // division of invalid by invalid should be zero.
499 if (getLhs() == getRhs())
500 return getIntZerosAttr(getType());
501
502 // rem(0, x) -> 0
503 //
504 // This is legal because it aligns with the Scala FIRRTL Compiler
505 // interpretation of lowering invalid to constant zero before constant
506 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
507 // division this way and will emit "0 % x".
508 if (isConstantZero(adaptor.getLhs()))
509 return getIntZerosAttr(getType());
510
512 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
513 [=](const APSInt &a, const APSInt &b) -> APInt {
514 if (!!b)
515 return a % b;
516 return APInt(a.getBitWidth(), 0);
517 });
518}
519
520OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
522 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
523 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
524}
525
526OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
528 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
529 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
530}
531
532OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
534 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
535 [=](const APSInt &a, const APSInt &b) -> APInt {
536 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
537 : a.ashr(b);
538 });
539}
540
541// TODO: Move to DRR.
542OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
543 if (auto rhsCst = getConstant(adaptor.getRhs())) {
544 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
545 if (rhsCst->isZero())
546 return getIntZerosAttr(getType());
547
548 /// and(x, -1) -> x
549 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
551 return getLhs();
552 }
553
554 if (auto lhsCst = getConstant(adaptor.getLhs())) {
555 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
556 if (lhsCst->isZero())
557 return getIntZerosAttr(getType());
558
559 /// and(-1, x) -> x
560 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
561 getRhs().getType() == getType())
562 return getRhs();
563 }
564
565 /// and(x, x) -> x
566 if (getLhs() == getRhs() && getRhs().getType() == getType())
567 return getRhs();
568
570 *this, adaptor.getOperands(), BinOpKind::Normal,
571 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
572}
573
574void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
575 MLIRContext *context) {
576 results
577 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
578 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
579 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
580}
581
582OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
583 if (auto rhsCst = getConstant(adaptor.getRhs())) {
584 /// or(x, 0) -> x
585 if (rhsCst->isZero() && getLhs().getType() == getType())
586 return getLhs();
587
588 /// or(x, -1) -> -1
589 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
590 getLhs().getType() == getType())
591 return getRhs();
592 }
593
594 if (auto lhsCst = getConstant(adaptor.getLhs())) {
595 /// or(0, x) -> x
596 if (lhsCst->isZero() && getRhs().getType() == getType())
597 return getRhs();
598
599 /// or(-1, x) -> -1
600 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
601 getRhs().getType() == getType())
602 return getLhs();
603 }
604
605 /// or(x, x) -> x
606 if (getLhs() == getRhs() && getRhs().getType() == getType())
607 return getRhs();
608
610 *this, adaptor.getOperands(), BinOpKind::Normal,
611 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
612}
613
614void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
615 MLIRContext *context) {
616 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
617 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
618 patterns::OrOrr>(context);
619}
620
621OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
622 /// xor(x, 0) -> x
623 if (auto rhsCst = getConstant(adaptor.getRhs()))
624 if (rhsCst->isZero() &&
625 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
626 return getLhs();
627
628 /// xor(x, 0) -> x
629 if (auto lhsCst = getConstant(adaptor.getLhs()))
630 if (lhsCst->isZero() &&
631 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
632 return getRhs();
633
634 /// xor(x, x) -> 0
635 if (getLhs() == getRhs())
636 return getIntAttr(
637 getType(),
638 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
639
641 *this, adaptor.getOperands(), BinOpKind::Normal,
642 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
643}
644
645void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
646 MLIRContext *context) {
647 results.insert<patterns::extendXor, patterns::moveConstXor,
648 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
649 context);
650}
651
652void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
653 MLIRContext *context) {
654 results.insert<patterns::LEQWithConstLHS>(context);
655}
656
657OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
658 bool isUnsigned = getLhs().getType().base().isUnsigned();
659
660 // leq(x, x) -> 1
661 if (getLhs() == getRhs())
662 return getIntAttr(getType(), APInt(1, 1));
663
664 // Comparison against constant outside type bounds.
665 if (auto width = getLhs().getType().base().getWidth()) {
666 if (auto rhsCst = getConstant(adaptor.getRhs())) {
667 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
668 commonWidth = std::max(commonWidth, 1);
669
670 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
671 // This can never occur since const is unsigned and cannot be less than 0.
672
673 // leq(x, const) -> 0 where const < minValue of the signed type of x
674 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
675 .slt(getMinSignedValue(*width).sext(commonWidth)))
676 return getIntAttr(getType(), APInt(1, 0));
677
678 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
679 if (isUnsigned && rhsCst->zext(commonWidth)
680 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
681 return getIntAttr(getType(), APInt(1, 1));
682
683 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
684 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
685 .sge(getMaxSignedValue(*width).sext(commonWidth)))
686 return getIntAttr(getType(), APInt(1, 1));
687 }
688 }
689
691 *this, adaptor.getOperands(), BinOpKind::Compare,
692 [=](const APSInt &a, const APSInt &b) -> APInt {
693 return APInt(1, a <= b);
694 });
695}
696
697void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
698 MLIRContext *context) {
699 results.insert<patterns::LTWithConstLHS>(context);
700}
701
702OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
703 IntType lhsType = getLhs().getType();
704 bool isUnsigned = lhsType.isUnsigned();
705
706 // lt(x, x) -> 0
707 if (getLhs() == getRhs())
708 return getIntAttr(getType(), APInt(1, 0));
709
710 // lt(x, 0) -> 0 when x is unsigned
711 if (auto rhsCst = getConstant(adaptor.getRhs())) {
712 if (rhsCst->isZero() && lhsType.isUnsigned())
713 return getIntAttr(getType(), APInt(1, 0));
714 }
715
716 // Comparison against constant outside type bounds.
717 if (auto width = lhsType.getWidth()) {
718 if (auto rhsCst = getConstant(adaptor.getRhs())) {
719 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
720 commonWidth = std::max(commonWidth, 1);
721
722 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
723 // Handled explicitly above.
724
725 // lt(x, const) -> 0 where const <= minValue of the signed type of x
726 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
727 .sle(getMinSignedValue(*width).sext(commonWidth)))
728 return getIntAttr(getType(), APInt(1, 0));
729
730 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
731 if (isUnsigned && rhsCst->zext(commonWidth)
732 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
733 return getIntAttr(getType(), APInt(1, 1));
734
735 // lt(x, const) -> 1 where const > maxValue of the signed type of x
736 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
737 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
738 return getIntAttr(getType(), APInt(1, 1));
739 }
740 }
741
743 *this, adaptor.getOperands(), BinOpKind::Compare,
744 [=](const APSInt &a, const APSInt &b) -> APInt {
745 return APInt(1, a < b);
746 });
747}
748
749void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
750 MLIRContext *context) {
751 results.insert<patterns::GEQWithConstLHS>(context);
752}
753
754OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
755 IntType lhsType = getLhs().getType();
756 bool isUnsigned = lhsType.isUnsigned();
757
758 // geq(x, x) -> 1
759 if (getLhs() == getRhs())
760 return getIntAttr(getType(), APInt(1, 1));
761
762 // geq(x, 0) -> 1 when x is unsigned
763 if (auto rhsCst = getConstant(adaptor.getRhs())) {
764 if (rhsCst->isZero() && isUnsigned)
765 return getIntAttr(getType(), APInt(1, 1));
766 }
767
768 // Comparison against constant outside type bounds.
769 if (auto width = lhsType.getWidth()) {
770 if (auto rhsCst = getConstant(adaptor.getRhs())) {
771 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
772 commonWidth = std::max(commonWidth, 1);
773
774 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
775 if (isUnsigned && rhsCst->zext(commonWidth)
776 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
777 return getIntAttr(getType(), APInt(1, 0));
778
779 // geq(x, const) -> 0 where const > maxValue of the signed type of x
780 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
781 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
782 return getIntAttr(getType(), APInt(1, 0));
783
784 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
785 // Handled explicitly above.
786
787 // geq(x, const) -> 1 where const <= minValue of the signed type of x
788 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
789 .sle(getMinSignedValue(*width).sext(commonWidth)))
790 return getIntAttr(getType(), APInt(1, 1));
791 }
792 }
793
795 *this, adaptor.getOperands(), BinOpKind::Compare,
796 [=](const APSInt &a, const APSInt &b) -> APInt {
797 return APInt(1, a >= b);
798 });
799}
800
801void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
802 MLIRContext *context) {
803 results.insert<patterns::GTWithConstLHS>(context);
804}
805
806OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
807 IntType lhsType = getLhs().getType();
808 bool isUnsigned = lhsType.isUnsigned();
809
810 // gt(x, x) -> 0
811 if (getLhs() == getRhs())
812 return getIntAttr(getType(), APInt(1, 0));
813
814 // Comparison against constant outside type bounds.
815 if (auto width = lhsType.getWidth()) {
816 if (auto rhsCst = getConstant(adaptor.getRhs())) {
817 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
818 commonWidth = std::max(commonWidth, 1);
819
820 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
821 if (isUnsigned && rhsCst->zext(commonWidth)
822 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
823 return getIntAttr(getType(), APInt(1, 0));
824
825 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
826 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
827 .sge(getMaxSignedValue(*width).sext(commonWidth)))
828 return getIntAttr(getType(), APInt(1, 0));
829
830 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
831 // This can never occur since const is unsigned and cannot be less than 0.
832
833 // gt(x, const) -> 1 where const < minValue of the signed type of x
834 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
835 .slt(getMinSignedValue(*width).sext(commonWidth)))
836 return getIntAttr(getType(), APInt(1, 1));
837 }
838 }
839
841 *this, adaptor.getOperands(), BinOpKind::Compare,
842 [=](const APSInt &a, const APSInt &b) -> APInt {
843 return APInt(1, a > b);
844 });
845}
846
847OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
848 // eq(x, x) -> 1
849 if (getLhs() == getRhs())
850 return getIntAttr(getType(), APInt(1, 1));
851
852 if (auto rhsCst = getConstant(adaptor.getRhs())) {
853 /// eq(x, 1) -> x when x is 1 bit.
854 /// TODO: Support SInt<1> on the LHS etc.
855 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
856 getRhs().getType() == getType())
857 return getLhs();
858 }
859
861 *this, adaptor.getOperands(), BinOpKind::Compare,
862 [=](const APSInt &a, const APSInt &b) -> APInt {
863 return APInt(1, a == b);
864 });
865}
866
// Rewrite equality comparisons against constants into cheaper unary forms
// (not / orr / andr) where the operand widths permit.
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        // No pattern applied.
        return {};
      });
}
896
897OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
898 // neq(x, x) -> 0
899 if (getLhs() == getRhs())
900 return getIntAttr(getType(), APInt(1, 0));
901
902 if (auto rhsCst = getConstant(adaptor.getRhs())) {
903 /// neq(x, 0) -> x when x is 1 bit.
904 /// TODO: Support SInt<1> on the LHS etc.
905 if (rhsCst->isZero() && getLhs().getType() == getType() &&
906 getRhs().getType() == getType())
907 return getLhs();
908 }
909
911 *this, adaptor.getOperands(), BinOpKind::Compare,
912 [=](const APSInt &a, const APSInt &b) -> APInt {
913 return APInt(1, a != b);
914 });
915}
916
// Rewrite inequality comparisons against constants into cheaper unary forms
// (not / orr / andr) where the operand widths permit.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x))) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp =
                AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
          }
        }

        // No pattern applied.
        return {};
      });
}
948
949OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
950 // TODO: implement constant folding, etc.
951 // Tracked in https://github.com/llvm/circt/issues/6696.
952 return {};
953}
954
955OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
956 // TODO: implement constant folding, etc.
957 // Tracked in https://github.com/llvm/circt/issues/6724.
958 return {};
959}
960
961OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
962 if (auto rhsCst = getConstant(adaptor.getRhs())) {
963 if (auto lhsCst = getConstant(adaptor.getLhs())) {
964
965 return IntegerAttr::get(IntegerType::get(getContext(),
966 lhsCst->getBitWidth(),
967 IntegerType::Signed),
968 lhsCst->ashr(*rhsCst));
969 }
970
971 if (rhsCst->isZero())
972 return getLhs();
973 }
974
975 return {};
976}
977
978OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
979 if (auto rhsCst = getConstant(adaptor.getRhs())) {
980 // Constant folding
981 if (auto lhsCst = getConstant(adaptor.getLhs()))
982
983 return IntegerAttr::get(IntegerType::get(getContext(),
984 lhsCst->getBitWidth(),
985 IntegerType::Signed),
986 lhsCst->shl(*rhsCst));
987
988 // integer.shl(x, 0) -> x
989 if (rhsCst->isZero())
990 return getLhs();
991 }
992
993 return {};
994}
995
996//===----------------------------------------------------------------------===//
997// Unary Operators
998//===----------------------------------------------------------------------===//
999
1000OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
1001 auto base = getInput().getType();
1002 auto w = getBitWidth(base);
1003 if (w)
1004 return getIntAttr(getType(), APInt(32, *w));
1005 return {};
1006}
1007
1008OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1009 // No constant can be 'x' by definition.
1010 if (auto cst = getConstant(adaptor.getArg()))
1011 return getIntAttr(getType(), APInt(1, 0));
1012 return {};
1013}
1014
1015OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1016 // No effect.
1017 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1018 return getInput();
1019
1020 // Be careful to only fold the cast into the constant if the size is known.
1021 // Otherwise width inference may produce differently-sized constants if the
1022 // sign changes.
1023 if (getType().base().hasWidth())
1024 if (auto cst = getConstant(adaptor.getInput()))
1025 return getIntAttr(getType(), *cst);
1026
1027 return {};
1028}
1029
1030void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1031 MLIRContext *context) {
1032 results.insert<patterns::StoUtoS>(context);
1033}
1034
1035OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1036 // No effect.
1037 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1038 return getInput();
1039
1040 // Be careful to only fold the cast into the constant if the size is known.
1041 // Otherwise width inference may produce differently-sized constants if the
1042 // sign changes.
1043 if (getType().base().hasWidth())
1044 if (auto cst = getConstant(adaptor.getInput()))
1045 return getIntAttr(getType(), *cst);
1046
1047 return {};
1048}
1049
1050void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1051 MLIRContext *context) {
1052 results.insert<patterns::UtoStoU>(context);
1053}
1054
1055OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1056 // No effect.
1057 if (getInput().getType() == getType())
1058 return getInput();
1059
1060 // Constant fold.
1061 if (auto cst = getConstant(adaptor.getInput()))
1062 return BoolAttr::get(getContext(), cst->getBoolValue());
1063
1064 return {};
1065}
1066
1067OpFoldResult AsResetPrimOp::fold(FoldAdaptor adaptor) {
1068 if (auto cst = getConstant(adaptor.getInput()))
1069 return BoolAttr::get(getContext(), cst->getBoolValue());
1070 return {};
1071}
1072
1073OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1074 // No effect.
1075 if (getInput().getType() == getType())
1076 return getInput();
1077
1078 // Constant fold.
1079 if (auto cst = getConstant(adaptor.getInput()))
1080 return BoolAttr::get(getContext(), cst->getBoolValue());
1081
1082 return {};
1083}
1084
1085OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1086 if (!hasKnownWidthIntTypes(*this))
1087 return {};
1088
1089 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1090 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1091 getType().base().getWidthOrSentinel()))
1092 return getIntAttr(getType(), *cst);
1093
1094 return {};
1095}
1096
1097void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1098 MLIRContext *context) {
1099 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1100}
1101
1102OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1103 if (!hasKnownWidthIntTypes(*this))
1104 return {};
1105
1106 // FIRRTL negate always adds a bit.
1107 // -x ---> 0-sext(x) or 0-zext(x)
1108 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1109 getType().base().getWidthOrSentinel()))
1110 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1111
1112 return {};
1113}
1114
1115OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1116 if (!hasKnownWidthIntTypes(*this))
1117 return {};
1118
1119 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1120 getType().base().getWidthOrSentinel()))
1121 return getIntAttr(getType(), ~*cst);
1122
1123 return {};
1124}
1125
1126void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1127 MLIRContext *context) {
1128 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1129 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1130 patterns::NotGt>(context);
1131}
1132
/// Base pattern for simplifying a bitwise reduction (orr/andr/xorr) whose
/// operand is a `cat`: constant cat operands either short-circuit the whole
/// reduction or can be dropped, shrinking the cat.
class ReductionCat : public mlir::RewritePattern {
public:
  ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
      : RewritePattern(opName, 0, context) {}

  /// Handle a constant operand in the cat operation.
  /// Returns true if the entire reduction can be replaced with a constant.
  /// May add non-zero constants to the remaining operands list.
  virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
                              ConstantOp constantOp,
                              SmallVectorImpl<Value> &remaining) const = 0;

  /// Return the unit value for this reduction operation:
  /// the result of reducing an empty bit sequence.
  virtual bool getIdentityValue() const = 0;

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Check if the operand is a cat operation
    auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
    if (!catOp)
      return failure();

    SmallVector<Value> nonConstantOperands;

    // Process each operand of the cat operation
    for (auto operand : catOp.getInputs()) {
      if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
        // Handle constant operands - may short-circuit the entire operation
        if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
          return success();
      } else {
        // Keep non-constant operands for further processing
        nonConstantOperands.push_back(operand);
      }
    }

    // If no operands remain, replace with identity value
    if (nonConstantOperands.empty()) {
      replaceOpWithNewOpAndCopyName<ConstantOp>(
          rewriter, op, cast<IntType>(op->getResult(0).getType()),
          APInt(1, getIdentityValue()));
      return success();
    }

    // If only one operand remains, apply reduction directly to it
    if (nonConstantOperands.size() == 1) {
      rewriter.modifyOpInPlace(
          op, [&] { op->setOperand(0, nonConstantOperands.front()); });
      return success();
    }

    // Multiple operands remain - optimize only when cat has a single use.
    // (Rewriting a multiply-used cat would disturb its other users.)
    if (catOp->hasOneUse() &&
        nonConstantOperands.size() < catOp->getNumOperands()) {
      replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
                                               nonConstantOperands);
      return success();
    }
    return failure();
  }
};
1195
1196class OrRCat : public ReductionCat {
1197public:
1198 OrRCat(MLIRContext *context)
1199 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1200 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1201 ConstantOp value,
1202 SmallVectorImpl<Value> &remaining) const override {
1203 if (value.getValue().isZero())
1204 return false;
1205
1206 replaceOpWithNewOpAndCopyName<ConstantOp>(
1207 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1208 llvm::APInt(1, 1));
1209 return true;
1210 }
1211 bool getIdentityValue() const override { return false; }
1212};
1213
1214class AndRCat : public ReductionCat {
1215public:
1216 AndRCat(MLIRContext *context)
1217 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1218 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1219 ConstantOp value,
1220 SmallVectorImpl<Value> &remaining) const override {
1221 if (value.getValue().isAllOnes())
1222 return false;
1223
1224 replaceOpWithNewOpAndCopyName<ConstantOp>(
1225 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1226 llvm::APInt(1, 0));
1227 return true;
1228 }
1229 bool getIdentityValue() const override { return true; }
1230};
1231
1232class XorRCat : public ReductionCat {
1233public:
1234 XorRCat(MLIRContext *context)
1235 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1236 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1237 ConstantOp value,
1238 SmallVectorImpl<Value> &remaining) const override {
1239 if (value.getValue().isZero())
1240 return false;
1241 remaining.push_back(value);
1242 return false;
1243 }
1244 bool getIdentityValue() const override { return false; }
1245};
1246
1247OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1248 if (!hasKnownWidthIntTypes(*this))
1249 return {};
1250
1251 if (getInput().getType().getBitWidthOrSentinel() == 0)
1252 return getIntAttr(getType(), APInt(1, 1));
1253
1254 // x == -1
1255 if (auto cst = getConstant(adaptor.getInput()))
1256 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1257
1258 // one bit is identity. Only applies to UInt since we can't make a cast
1259 // here.
1260 if (isUInt1(getInput().getType()))
1261 return getInput();
1262
1263 return {};
1264}
1265
1266void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1267 MLIRContext *context) {
1268 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1269 patterns::AndRPadS, patterns::AndRCatAndR_left,
1270 patterns::AndRCatAndR_right, AndRCat>(context);
1271}
1272
1273OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1274 if (!hasKnownWidthIntTypes(*this))
1275 return {};
1276
1277 if (getInput().getType().getBitWidthOrSentinel() == 0)
1278 return getIntAttr(getType(), APInt(1, 0));
1279
1280 // x != 0
1281 if (auto cst = getConstant(adaptor.getInput()))
1282 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1283
1284 // one bit is identity. Only applies to UInt since we can't make a cast
1285 // here.
1286 if (isUInt1(getInput().getType()))
1287 return getInput();
1288
1289 return {};
1290}
1291
1292void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1293 MLIRContext *context) {
1294 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1295 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
1296 context);
1297}
1298
1299OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1300 if (!hasKnownWidthIntTypes(*this))
1301 return {};
1302
1303 if (getInput().getType().getBitWidthOrSentinel() == 0)
1304 return getIntAttr(getType(), APInt(1, 0));
1305
1306 // popcount(x) & 1
1307 if (auto cst = getConstant(adaptor.getInput()))
1308 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1309
1310 // one bit is identity. Only applies to UInt since we can't make a cast here.
1311 if (isUInt1(getInput().getType()))
1312 return getInput();
1313
1314 return {};
1315}
1316
1317void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1318 MLIRContext *context) {
1319 results
1320 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1321 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
1322 context);
1323}
1324
1325//===----------------------------------------------------------------------===//
1326// Other Operators
1327//===----------------------------------------------------------------------===//
1328
1329OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1330 auto inputs = getInputs();
1331 auto inputAdaptors = adaptor.getInputs();
1332
1333 // If no inputs, return 0-bit value
1334 if (inputs.empty())
1335 return getIntZerosAttr(getType());
1336
1337 // If single input and same type, return it
1338 if (inputs.size() == 1 && inputs[0].getType() == getType())
1339 return inputs[0];
1340
1341 // Make sure it's safe to fold.
1342 if (!hasKnownWidthIntTypes(*this))
1343 return {};
1344
1345 // Filter out zero-width operands
1346 SmallVector<Value> nonZeroInputs;
1347 SmallVector<Attribute> nonZeroAttributes;
1348 bool allConstant = true;
1349 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1350 auto inputType = type_cast<IntType>(input.getType());
1351 if (inputType.getBitWidthOrSentinel() != 0) {
1352 nonZeroInputs.push_back(input);
1353 if (!attr)
1354 allConstant = false;
1355 if (nonZeroInputs.size() > 1 && !allConstant)
1356 return {};
1357 }
1358 }
1359
1360 // If all inputs were zero-width, return 0-bit value
1361 if (nonZeroInputs.empty())
1362 return getIntZerosAttr(getType());
1363
1364 // If only one non-zero input and it has the same type as result, return it
1365 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1366 return nonZeroInputs[0];
1367
1368 if (!hasKnownWidthIntTypes(*this))
1369 return {};
1370
1371 // Constant fold cat - concatenate all constant operands
1372 SmallVector<APInt> constants;
1373 for (auto inputAdaptor : inputAdaptors) {
1374 if (auto cst = getConstant(inputAdaptor))
1375 constants.push_back(*cst);
1376 else
1377 return {}; // Not all operands are constant
1378 }
1379
1380 assert(!constants.empty());
1381 // Concatenate all constants from left to right
1382 APInt result = constants[0];
1383 for (size_t i = 1; i < constants.size(); ++i)
1384 result = result.concat(constants[i]);
1385
1386 return getIntAttr(getType(), result);
1387}
1388
1389void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1390 MLIRContext *context) {
1391 results.insert<patterns::DShlOfConstant>(context);
1392}
1393
1394void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1395 MLIRContext *context) {
1396 results.insert<patterns::DShrOfConstant>(context);
1397}
1398
1399namespace {
1400
1401/// Canonicalize a cat tree into single cat operation.
/// Canonicalize a cat tree into single cat operation.
class FlattenCat : public mlir::RewritePattern {
public:
  FlattenCat(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    // All widths must be known, and a zero-width result has nothing to do.
    if (!hasKnownWidthIntTypes(cat) ||
        cat.getType().getBitWidthOrSentinel() == 0)
      return failure();

    // If this is not a root of concat tree, skip.
    if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
      return failure();

    // Try flattening the cat tree.
    SmallVector<Value> operands;
    SmallVector<Value> worklist;
    // Push inputs in reverse so popping the worklist visits them in order.
    auto pushOperands = [&worklist](CatPrimOp op) {
      for (auto operand : llvm::reverse(op.getInputs()))
        worklist.push_back(operand);
    };
    pushOperands(cat);
    bool hasSigned = false, hasUnsigned = false;
    while (!worklist.empty()) {
      auto value = worklist.pop_back_val();
      auto catOp = value.getDefiningOp<CatPrimOp>();
      if (!catOp) {
        // Leaf operand: keep it, recording signedness for the cast below.
        operands.push_back(value);
        (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
        continue;
      }

      // Nested cat: expand its inputs instead of keeping it.
      pushOperands(catOp);
    }

    // Helper function to cast signed values to unsigned. CatPrimOp converts
    // signed values to unsigned so we need to do it explicitly here.
    auto castToUIntIfSigned = [&](Value value) -> Value {
      if (type_isa<UIntType>(value.getType()))
        return value;
      return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
    };

    assert(operands.size() >= 1 && "zero width cast must be rejected");

    if (operands.size() == 1) {
      rewriter.replaceOp(op, castToUIntIfSigned(operands[0]));
      return success();
    }

    // Same operand count means nothing was flattened.
    if (operands.size() == cat->getNumOperands())
      return failure();

    // If types are mixed, cast all operands to unsigned.
    if (hasSigned && hasUnsigned)
      for (auto &operand : operands)
        operand = castToUIntIfSigned(operand);

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);
    return success();
  }
};
1468
1469// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
class CatOfConstant : public mlir::RewritePattern {
public:
  CatOfConstant(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    if (!hasKnownWidthIntTypes(cat))
      return failure();

    SmallVector<Value> operands;

    // Scan the operand list, merging each maximal run of adjacent constants
    // into a single wider constant.
    for (size_t i = 0; i < cat->getNumOperands(); ++i) {
      auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
      if (!cst) {
        operands.push_back(cat.getInputs()[i]);
        continue;
      }
      // Concatenate the constant run starting at position i.
      APSInt value = cst.getValue();
      size_t j = i + 1;
      for (; j < cat->getNumOperands(); ++j) {
        auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
        if (!nextCst)
          break;
        value = value.concat(nextCst.getValue());
      }

      if (j == i + 1) {
        // Not folded.
        operands.push_back(cst);
      } else {
        // Folded.
        operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
      }

      // Resume scanning after the run just consumed.
      i = j - 1;
    }

    // No adjacent constants were merged.
    if (operands.size() == cat->getNumOperands())
      return failure();

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);

    return success();
  }
};
1519
1520} // namespace
1521
1522void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1523 MLIRContext *context) {
1524 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1525 patterns::CatCast, FlattenCat, CatOfConstant>(context);
1526}
1527
1528OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1529 auto op = (*this);
1530 // BitCast is redundant if input and result types are same.
1531 if (op.getType() == op.getInput().getType())
1532 return op.getInput();
1533
1534 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1535 // final result type.
1536 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1537 if (op.getType() == in.getInput().getType())
1538 return in.getInput();
1539
1540 return {};
1541}
1542
1543OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1544 IntType inputType = getInput().getType();
1545 IntType resultType = getType();
1546 // If we are extracting the entire input, then return it.
1547 if (inputType == getType() && resultType.hasWidth())
1548 return getInput();
1549
1550 // Constant fold.
1551 if (hasKnownWidthIntTypes(*this))
1552 if (auto cst = getConstant(adaptor.getInput()))
1553 return getIntAttr(resultType,
1554 cst->extractBits(getHi() - getLo() + 1, getLo()));
1555
1556 return {};
1557}
1558
/// bits(cat(...)) -> bits(operand) when the extracted range falls entirely
/// within a single operand of the cat.
struct BitsOfCat : public mlir::RewritePattern {
  BitsOfCat(MLIRContext *context)
      : RewritePattern(BitsPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto bits = cast<BitsPrimOp>(op);
    auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
    if (!cat)
      return failure();
    // Bit offset of the extraction's low end, relative to the current
    // operand's LSB as we walk.
    int32_t bitPos = bits.getLo();
    auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
    if (resultWidth < 0)
      return failure();
    // Walk operands from the least-significant end (cat lists MSB first).
    for (auto operand : llvm::reverse(cat.getInputs())) {
      auto operandWidth =
          type_cast<IntType>(operand.getType()).getWidthOrSentinel();
      if (operandWidth < 0)
        return failure();
      if (bitPos < operandWidth) {
        // The range starts in this operand; it must also end here.
        if (bitPos + resultWidth <= operandWidth) {
          auto newBits = rewriter.createOrFold<BitsPrimOp>(
              op->getLoc(), operand, bitPos + resultWidth - 1, bitPos);
          replaceOpAndCopyName(rewriter, op, newBits);
          return success();
        }
        // The range straddles an operand boundary; give up.
        return failure();
      }
      bitPos -= operandWidth;
    }
    return failure();
  }
};
1593
1594void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1595 MLIRContext *context) {
1596 results
1597 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1598 patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
1599}
1600
1601/// Replace the specified operation with a 'bits' op from the specified hi/lo
1602/// bits. Insert a cast to handle the case where the original operation
1603/// returned a signed integer.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only extract when the widths actually differ.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);

  // Reconcile signedness with the replaced op's result type (bits always
  // yields an unsigned value).
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1618
/// Shared folder for MuxPrimOp and the 2-way mux cell intrinsic. Handles
/// zero-width results, identical arms, constant selects, and constant arms.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1666
1667OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1668 return foldMux(*this, adaptor);
1669}
1670
1671OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1672 return foldMux(*this, adaptor);
1673}
1674
1675OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1676
1677namespace {
1678
1679// If the mux has a known output width, pad the operands up to this width.
1680// Most folds on mux require that folded operands are of the same width as
1681// the mux itself.
// If the mux has a known output width, pad the operands up to this width.
// Most folds on mux require that folded operands are of the same width as
// the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Pad an operand up to the mux width; returns it unchanged when its own
    // width is unknown or already matches.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
                               width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Nothing changed; fail so the driver does not loop forever.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1716
1717// Find muxes which have conditions dominated by other muxes with the same
1718// condition.
// Find muxes which have conditions dominated by other muxes with the same
// condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when walking nested mux trees.
  static const int depthLimit = 5;

  // Rewrite `mux`'s arms in place when allowed, otherwise clone a fresh mux
  // after it (leaving the original for its other users) and return the clone.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
                             ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same select, known true: this mux collapses to its high arm.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // In-place updates are only safe while every mux on the path is
    // single-use.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same select, known false: this mux collapses to its low arm.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Inside the high arm the select is known true; inside the low arm it is
    // known false. Simplify nested muxes sharing the same select accordingly.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1808} // namespace
1809
1810void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1811 MLIRContext *context) {
1812 results
1813 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1814 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1815 patterns::MuxSameTrue, patterns::MuxSameFalse,
1816 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1817 context);
1818}
1819
1820void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1821 RewritePatternSet &results, MLIRContext *context) {
1822 results.add<patterns::Mux2PadSel>(context);
1823}
1824
// Register canonicalization patterns for the 4-input mux cell intrinsic.
void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux4PadSel>(context);
}
1829
1830OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1831 auto input = this->getInput();
1832
1833 // pad(x) -> x if the width doesn't change.
1834 if (input.getType() == getType())
1835 return input;
1836
1837 // Need to know the input width.
1838 auto inputType = input.getType().base();
1839 int32_t width = inputType.getWidthOrSentinel();
1840 if (width == -1)
1841 return {};
1842
1843 // Constant fold.
1844 if (auto cst = getConstant(adaptor.getInput())) {
1845 auto destWidth = getType().base().getWidthOrSentinel();
1846 if (destWidth == -1)
1847 return {};
1848
1849 if (inputType.isSigned() && cst->getBitWidth())
1850 return getIntAttr(getType(), cst->sext(destWidth));
1851 return getIntAttr(getType(), cst->zext(destWidth));
1852 }
1853
1854 return {};
1855}
1856
1857OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1858 auto input = this->getInput();
1859 IntType inputType = input.getType();
1860 int shiftAmount = getAmount();
1861
1862 // shl(x, 0) -> x
1863 if (shiftAmount == 0)
1864 return input;
1865
1866 // Constant fold.
1867 if (auto cst = getConstant(adaptor.getInput())) {
1868 auto inputWidth = inputType.getWidthOrSentinel();
1869 if (inputWidth != -1) {
1870 auto resultWidth = inputWidth + shiftAmount;
1871 shiftAmount = std::min(shiftAmount, resultWidth);
1872 return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1873 }
1874 }
1875 return {};
1876}
1877
/// Fold right shifts.  Handles the identity shift, zero-width inputs,
/// unsigned over-shifts, and constant inputs.  Note that the result of shr
/// is never narrower than one bit, and signed shifts keep the sign bit.
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Cannot fold further without a known input width.
  if (inputWidth == -1)
    return {};
  // A zero-width input shifts to a zero result.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // Signed shifts are clamped so the sign bit survives; unsigned shifts
    // are clamped to the input width.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // Result width is at least one bit.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1912
1913LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1914 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1915 if (inputWidth <= 0)
1916 return failure();
1917
1918 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1919 unsigned shiftAmount = op.getAmount();
1920 if (int(shiftAmount) >= inputWidth) {
1921 // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
1922 if (op.getType().base().isUnsigned())
1923 return failure();
1924
1925 // Shifting a signed value by the full width is actually taking the
1926 // sign bit. If the shift amount is greater than the input width, it
1927 // is equivalent to shifting by the input width.
1928 shiftAmount = inputWidth - 1;
1929 }
1930
1931 replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1932 return success();
1933}
1934
1935LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1936 PatternRewriter &rewriter) {
1937 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1938 if (inputWidth <= 0)
1939 return failure();
1940
1941 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1942 unsigned keepAmount = op.getAmount();
1943 if (keepAmount)
1944 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1945 rewriter);
1946 return success();
1947}
1948
1949OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1950 if (hasKnownWidthIntTypes(*this))
1951 if (auto cst = getConstant(adaptor.getInput())) {
1952 int shiftAmount =
1953 getInput().getType().base().getWidthOrSentinel() - getAmount();
1954 return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1955 }
1956
1957 return {};
1958}
1959
1960OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1961 if (hasKnownWidthIntTypes(*this))
1962 if (auto cst = getConstant(adaptor.getInput()))
1963 return getIntAttr(getType(),
1964 cst->trunc(getType().base().getWidthOrSentinel()));
1965 return {};
1966}
1967
1968LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1969 PatternRewriter &rewriter) {
1970 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1971 if (inputWidth <= 0)
1972 return failure();
1973
1974 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1975 unsigned dropAmount = op.getAmount();
1976 if (dropAmount != unsigned(inputWidth))
1977 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1978 rewriter);
1979 return success();
1980}
1981
// Register canonicalization patterns for SubaccessOp.
void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<patterns::SubaccessOfConstant>(context);
}
1986
1987OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1988 // If there is only one input, just return it.
1989 if (adaptor.getInputs().size() == 1)
1990 return getOperand(1);
1991
1992 if (auto constIndex = getConstant(adaptor.getIndex())) {
1993 auto index = constIndex->getZExtValue();
1994 if (index < getInputs().size())
1995 return getInputs()[getInputs().size() - 1 - index];
1996 }
1997
1998 return {};
1999}
2000
/// Canonicalize multibit muxes: collapse all-equal inputs, drop inputs the
/// index can never select, rewrite vector-indexing patterns into a
/// SubaccessOp, and lower two-input muxes to an ordinary mux.
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it is costly to look through all inputs
  // so it is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements (they can never be selected since inputs are listed
  // highest-index first).
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multibit_mux idx, a[n-1], a[n-2],
  // ..., a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Each operand k must be a subindex of the same vector at position
    // n-1-k, matching the highest-index-first operand order.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2052
2053//===----------------------------------------------------------------------===//
2054// Declarations
2055//===----------------------------------------------------------------------===//
2056
2057/// Scan all the uses of the specified value, checking to see if there is
2058/// exactly one connect that has the value as its destination. This returns the
2059/// operation if found and if all the other users are "reads" from the value.
2060/// Returns null if there are no connects, or multiple connects to the value, or
2061/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2062///
2063/// Note that this will simply return the connect, which is located *anywhere*
2064/// after the definition of the value. Users of this function are likely
2065/// interested in the source side of the returned connect, the definition of
2066/// which does likely not dominate the original value.
2067MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
2068 MatchingConnectOp connect;
2069 for (Operation *user : value.getUsers()) {
2070 // If we see an attach or aggregate sublements, just conservatively fail.
2071 if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
2072 return {};
2073
2074 if (auto aConnect = dyn_cast<FConnectLike>(user))
2075 if (aConnect.getDest() == value) {
2076 auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
2077 // If this is not a matching connect, a second matching connect or in a
2078 // different block, fail.
2079 if (!matchingConnect || (connect && connect != matchingConnect) ||
2080 matchingConnect->getBlock() != value.getParentBlock())
2081 return {};
2082 connect = matchingConnect;
2083 }
2084 }
2085 return connect;
2086}
2087
2088// Forward simple values through wire's and reg's.
// Forward simple values through wire's and reg's.
//
// If a wire/register with a droppable name and no annotations or symbols is
// written exactly once by a constant (or, for wires, a port), replace every
// read of the declaration with the connected value and delete both the
// declaration and the connect.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Names, symbols, annotations, and force handles must all be preserved.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use (the single use would be the
  // connect itself, i.e. nothing actually reads the declaration).
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
2146
// Register canonicalization patterns for ConnectOp.
void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
      context);
}
2152
2153LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2154 PatternRewriter &rewriter) {
2155 // TODO: Canonicalize towards explicit extensions and flips here.
2156
2157 // If there is a simple value connected to a foldable decl like a wire or reg,
2158 // see if we can eliminate the decl.
2159 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2160 return success();
2161 return failure();
2162}
2163
2164//===----------------------------------------------------------------------===//
2165// Statements
2166//===----------------------------------------------------------------------===//
2167
/// If the specified value has another AttachOp user that appears strictly
/// before `dominatedAttach` in the same block, return that attach.
2170static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
2171 for (auto *user : value.getUsers()) {
2172 auto attach = dyn_cast<AttachOp>(user);
2173 if (!attach || attach == dominatedAttach)
2174 continue;
2175 if (attach->isBeforeInBlock(dominatedAttach))
2176 return attach;
2177 }
2178 return {};
2179}
2180
/// Canonicalize attaches: erase trivial single-operand attaches, merge
/// attaches that share an operand, and drop wires used only by an attach.
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2226
2227/// Replaces the given op with the contents of the given single-block region.
/// Replaces the given op with the contents of the given single-block region.
/// The block's contents are spliced in immediately before `op`; `op` itself
/// is not erased here (the caller does that).
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
2233
/// Canonicalize `when`: inline the live region for constant conditions,
/// erase empty else blocks, and delete fully-empty whens.
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // Constant condition: splice in the taken region and drop the op.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2266
2267namespace {
2268// Remove private nodes. If they have an interesting names, move the name to
2269// the source expression.
2270struct FoldNodeName : public mlir::RewritePattern {
2271 FoldNodeName(MLIRContext *context)
2272 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
2273 LogicalResult matchAndRewrite(Operation *op,
2274 PatternRewriter &rewriter) const override {
2275 auto node = cast<NodeOp>(op);
2276 auto name = node.getNameAttr();
2277 if (!node.hasDroppableName() || node.getInnerSym() ||
2278 !AnnotationSet(node).empty() || node.isForceable())
2279 return failure();
2280 auto *newOp = node.getInput().getDefiningOp();
2281 if (newOp)
2282 updateName(rewriter, newOp, name);
2283 rewriter.replaceOp(node, node.getInput());
2284 return success();
2285 }
2286};
2287
2288// Bypass nodes.
2289struct NodeBypass : public mlir::RewritePattern {
2290 NodeBypass(MLIRContext *context)
2291 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
2292 LogicalResult matchAndRewrite(Operation *op,
2293 PatternRewriter &rewriter) const override {
2294 auto node = cast<NodeOp>(op);
2295 if (node.getInnerSym() || !AnnotationSet(node).empty() ||
2296 node.use_empty() || node.isForceable())
2297 return failure();
2298 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2299 return success();
2300 }
2301};
2302
2303} // namespace
2304
/// If the op is forceable but its ref result has no uses, rebuild it as a
/// non-forceable op.  Shared by node/wire/regreset canonicalizers.
template <typename OpTy>
static LogicalResult demoteForceableIfUnused(OpTy op,
                                             PatternRewriter &rewriter) {
  if (!op.isForceable() || !op.getDataRef().use_empty())
    return failure();

  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
  return success();
}
2314
2315// Interesting names and symbols and don't touch force nodes to stick around.
2316LogicalResult NodeOp::fold(FoldAdaptor adaptor,
2317 SmallVectorImpl<OpFoldResult> &results) {
2318 if (!hasDroppableName())
2319 return failure();
2320 if (hasDontTouch(getResult())) // handles inner symbols
2321 return failure();
2322 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
2323 return failure();
2324 if (isForceable())
2325 return failure();
2326 if (!adaptor.getInput())
2327 return failure();
2328
2329 results.push_back(adaptor.getInput());
2330 return success();
2331}
2332
// Register canonicalization patterns for NodeOp.
void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<FoldNodeName>(context);
  results.add(demoteForceableIfUnused<NodeOp>);
}
2338
2339namespace {
2340// For a lhs, find all the writers of fields of the aggregate type. If there
2341// is one writer for each field, merge the writes
2342struct AggOneShot : public mlir::RewritePattern {
2343 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2344 : RewritePattern(name, 0, context) {}
2345
2346 SmallVector<Value> getCompleteWrite(Operation *lhs) const {
2347 auto lhsTy = lhs->getResult(0).getType();
2348 if (!type_isa<BundleType, FVectorType>(lhsTy))
2349 return {};
2350
2351 DenseMap<uint32_t, Value> fields;
2352 for (Operation *user : lhs->getResult(0).getUsers()) {
2353 if (user->getParentOp() != lhs->getParentOp())
2354 return {};
2355 if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2356 if (aConnect.getDest() == lhs->getResult(0))
2357 return {};
2358 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2359 for (Operation *subuser : subField.getResult().getUsers()) {
2360 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2361 if (aConnect.getDest() == subField) {
2362 if (subuser->getParentOp() != lhs->getParentOp())
2363 return {};
2364 if (fields.count(subField.getFieldIndex())) // duplicate write
2365 return {};
2366 fields[subField.getFieldIndex()] = aConnect.getSrc();
2367 }
2368 continue;
2369 }
2370 return {};
2371 }
2372 } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2373 for (Operation *subuser : subIndex.getResult().getUsers()) {
2374 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2375 if (aConnect.getDest() == subIndex) {
2376 if (subuser->getParentOp() != lhs->getParentOp())
2377 return {};
2378 if (fields.count(subIndex.getIndex())) // duplicate write
2379 return {};
2380 fields[subIndex.getIndex()] = aConnect.getSrc();
2381 }
2382 continue;
2383 }
2384 return {};
2385 }
2386 } else {
2387 return {};
2388 }
2389 }
2390
2391 SmallVector<Value> values;
2392 uint32_t total = type_isa<BundleType>(lhsTy)
2393 ? type_cast<BundleType>(lhsTy).getNumElements()
2394 : type_cast<FVectorType>(lhsTy).getNumElements();
2395 for (uint32_t i = 0; i < total; ++i) {
2396 if (!fields.count(i))
2397 return {};
2398 values.push_back(fields[i]);
2399 }
2400 return values;
2401 }
2402
2403 LogicalResult matchAndRewrite(Operation *op,
2404 PatternRewriter &rewriter) const override {
2405 auto values = getCompleteWrite(op);
2406 if (values.empty())
2407 return failure();
2408 rewriter.setInsertionPointToEnd(op->getBlock());
2409 auto dest = op->getResult(0);
2410 auto destType = dest.getType();
2411
2412 // If not passive, cannot matchingconnect.
2413 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2414 return failure();
2415
2416 Value newVal = type_isa<BundleType>(destType)
2417 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2418 destType, values)
2419 : rewriter.createOrFold<VectorCreateOp>(
2420 op->getLoc(), destType, values);
2421 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2422 for (Operation *user : dest.getUsers()) {
2423 if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2424 for (Operation *subuser :
2425 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2426 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2427 if (aConnect.getDest() == subIndex)
2428 rewriter.eraseOp(aConnect);
2429 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2430 for (Operation *subuser :
2431 llvm::make_early_inc_range(subField.getResult().getUsers()))
2432 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2433 if (aConnect.getDest() == subField)
2434 rewriter.eraseOp(aConnect);
2435 }
2436 }
2437 return success();
2438 }
2439};
2440
// AggOneShot rooted at wires.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
// AggOneShot rooted at vector subindex accesses.
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
// AggOneShot rooted at bundle subfield accesses.
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2453} // namespace
2454
// Register canonicalization patterns for WireOp.
void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<WireAggOneShot>(context);
  results.add(demoteForceableIfUnused<WireOp>);
}
2460
// Register canonicalization patterns for SubindexOp.
void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubindexAggOneShot>(context);
}
2465
2466OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2467 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2468 if (!attr)
2469 return {};
2470 return attr[getIndex()];
2471}
2472
2473OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2474 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2475 if (!attr)
2476 return {};
2477 auto index = getFieldIndex();
2478 return attr[index];
2479}
2480
// Register canonicalization patterns for SubfieldOp.
void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubfieldAggOneShot>(context);
}
2485
2486static Attribute collectFields(MLIRContext *context,
2487 ArrayRef<Attribute> operands) {
2488 for (auto operand : operands)
2489 if (!operand)
2490 return {};
2491 return ArrayAttr::get(context, operands);
2492}
2493
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>, i.e. every operand is a subfield of the same
  // bundle in field order and the full bundle type matches.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate if all fields are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2512
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>, i.e. every operand is a subindex of the same vector in
  // element order and the full vector type matches.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate if all elements are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2530
2531OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2532 if (getOperand().getType() == getType())
2533 return getOperand();
2534 return {};
2535}
2536
2537namespace {
2538// A register with constant reset and all connection to either itself or the
2539// same constant, must be replaced by the constant.
// A register with constant reset and all connection to either itself or the
// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    // Bail out if the reset value isn't constant or the register must be
    // preserved for symbols/annotations/force handles.
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The driver must be a mux whose arms are the register itself and the
    // same constant as the reset value (in either position).
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2592} // namespace
2593
2594static bool isDefinedByOneConstantOp(Value v) {
2595 if (auto c = v.getDefiningOp<ConstantOp>())
2596 return c.getValue().isOne();
2597 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2598 return sc.getValue();
2599 return false;
2600}
2601
/// A regreset whose reset signal is the constant one is always in reset, so
/// replace it with a node of its reset value (dropping any writes to the
/// register).
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Types must match exactly for the node replacement to be valid.
  auto resetValue = reg.getResetValue();
  if (reg.getType(0) != resetValue.getType())
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  // Preserve the register's name/annotations/symbol on the new node.
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2618
// Register canonicalization patterns for RegResetOp.
void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
  results.add(canonicalizeRegResetWithOneReset);
  results.add(demoteForceableIfUnused<RegResetOp>);
}
2625
2626// Returns the value connected to a port, if there is only one.
2627static Value getPortFieldValue(Value port, StringRef name) {
2628 auto portTy = type_cast<BundleType>(port.getType());
2629 auto fieldIndex = portTy.getElementIndex(name);
2630 assert(fieldIndex && "missing field on memory port");
2631
2632 Value value = {};
2633 for (auto *op : port.getUsers()) {
2634 auto portAccess = cast<SubfieldOp>(op);
2635 if (fieldIndex != portAccess.getFieldIndex())
2636 continue;
2637 auto conn = getSingleConnectUserOf(portAccess);
2638 if (!conn || value)
2639 return {};
2640 value = conn.getSrc();
2641 }
2642 return value;
2643}
2644
2645// Returns true if the enable field of a port is set to false.
2646static bool isPortDisabled(Value port) {
2647 auto value = getPortFieldValue(port, "en");
2648 if (!value)
2649 return false;
2650 auto portConst = value.getDefiningOp<ConstantOp>();
2651 if (!portConst)
2652 return false;
2653 return portConst.getValue().isZero();
2654}
2655
2656// Returns true if the data output is unused.
2657static bool isPortUnused(Value port, StringRef data) {
2658 auto portTy = type_cast<BundleType>(port.getType());
2659 auto fieldIndex = portTy.getElementIndex(data);
2660 assert(fieldIndex && "missing enable flag on memory port");
2661
2662 for (auto *op : port.getUsers()) {
2663 auto portAccess = cast<SubfieldOp>(op);
2664 if (fieldIndex != portAccess.getFieldIndex())
2665 continue;
2666 if (!portAccess.use_empty())
2667 return false;
2668 }
2669
2670 return true;
2671}
2672
// Replaces all accesses to the named field of a port with the given value.
2674static void replacePortField(PatternRewriter &rewriter, Value port,
2675 StringRef name, Value value) {
2676 auto portTy = type_cast<BundleType>(port.getType());
2677 auto fieldIndex = portTy.getElementIndex(name);
2678 assert(fieldIndex && "missing field on memory port");
2679
2680 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2681 auto portAccess = cast<SubfieldOp>(op);
2682 if (fieldIndex != portAccess.getFieldIndex())
2683 continue;
2684 rewriter.replaceAllUsesWith(portAccess, value);
2685 rewriter.eraseOp(portAccess);
2686 }
2687}
2688
// Remove all accesses to a port that is being deleted, replacing any read
// values with dummy never-written registers.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = SpecialConstantOp::create(rewriter, port.getLoc(),
                                        ClockType::get(rewriter.getContext()),
                                        false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire.  If the port
  // is used in its entirety, replace it with a wire.  Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      // Whole-port use: substitute a never-written register of the port's
      // type, which later canonicalization can clean up.
      auto ty = port.getType();
      auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2741
2742namespace {
// If memory has known, but zero width, eliminate it. Zero-width data carries
// no information, so the memory can be replaced by wires on its subfields.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only applies to memories with a known zero-width integer data type.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace: only subfield accesses.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type. Replace each subfield access with
    // a free-standing wire instead.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports: drive the wire with a zero-width
          // constant so it is not left undriven.
          auto zero = firrtl::ConstantOp::create(
              rewriter, wire.getLoc(),
              firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
          MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2785
// If memory has no write ports and no file initialization, eliminate it.
// A write-only memory is also removable, since its contents are never read.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Classify the memory: bail as soon as both a read and a write port are
    // seen, or any debug/read-write port, which counts as both.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    // Detach every port (reads become dummy registers) and drop the memory.
    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2829
// Eliminate the dead ports of memories: ports that are disabled or read
// ports whose data is never used. The memory is rebuilt with only the live
// ports.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the memory keeping only the live ports' types, names, and
    // annotations.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port is dead, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = MemOp::create(
          rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Erase the dead ports (reads are detached onto dummy registers by
    // erasePort) and wire live ports to the new memory's results in order.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2899
// Rewrite write-only read-write ports to write ports. A read-write port
// whose rdata field is never used behaves exactly like a write port.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Build the replacement port types: demoted ports get a write-port
    // bundle, all others keep their current type.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = MemOp::create(
        rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
        mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
        rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
        mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field (fromName) to accesses to a
        // corresponding field of the new port (toName).
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          // NOTE(review): assert message looks stale — it fires for any
          // missing field, not just the enable flag.
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
                                            newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Shared fields map one-to-one; wdata/wmask become data/mask on the
        // write port.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2991
// Compact memories by eliminating data bits that are never read. Reads must
// go exclusively through bit-select ops and writes through matching
// connects; the memory is then rebuilt with a narrower data type and all
// reads/writes are remapped onto the compacted bit positions.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // `usedBits` marks data bits observed by some read; `mapping` maps the
    // low bit of each read slice to its offset in the compacted word
    // (filled in once the compaction layout is known).
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If all bits are used there is nothing to compact; bail early.
          if (usedBits.all())
            return failure();

          // Placeholder entry; the real offset is computed below.
          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Bail out on debug ports; they are not handled here.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones. `ranges`
    // records each contiguous run of used bits as an inclusive [start, end]
    // pair; `mapping` receives the new offset of every read slice's low bit.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type: route all accesses of the
    // named field through a single fresh SubfieldOp on the new port.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          SubfieldOp::create(rewriter, port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        // Skip the access we just created to avoid replacing it with itself.
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Slice out each used range; reversed so the concatenation places the
      // highest range in the most significant position.
      SmallVector<Value> slices;
      for (auto &[start, end] : llvm::reverse(ranges)) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        slices.push_back(slice);
      }

      Value catOfSlices =
          rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
3214
// Rewrite single-address memories to a firrtl register. A depth-1 memory is
// a single storage element: reads become reads of the register, and write
// ports become a chain of muxes arbitrating the register's next value.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    // Along the way, collect every port field access and its connects so
    // they can be erased once the rewrite is complete.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect accesses to the listed fields. Each access's users must all
      // be connects, otherwise the port cannot be folded.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        // Read ports skip the clock check below: their data field is
        // rewritten later and their clock is irrelevant to the register.
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All (read-)write ports must be driven by one common clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = WireOp::create(rewriter, loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    MatchingConnectOp::create(rewriter, loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
                                 rewriter.getStringAttr(regName))
                       .getResult();
        MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      // Pipeline the value driven into the named field by writeStages stages.
      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // Writes only happen when the port is enabled AND in write mode.
        auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      SmallVector<Value> chunks;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
        chunks.push_back(chunk);
      }

      // Chunks were built LSB-first; reverse so the concatenation puts the
      // most significant chunk first.
      std::reverse(chunks.begin(), chunks.end());
      masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
      next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
    }
    // The cat of chunks is unsigned; bitcast back to the data type.
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3409} // namespace
3410
// Register all memory canonicalization patterns defined above.
void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results
      .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
              FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
          context);
}
3418
3419//===----------------------------------------------------------------------===//
3420// Declarations
3421//===----------------------------------------------------------------------===//
3422
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The register must be driven by a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Promote the register to a RegResetOp driven by the mux select and the
    // constant, carrying over name, annotations, and dialect attributes.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Rewrite the connect to bypass the mux: drive the constant directly for
  // a constant register, or the non-reset operand otherwise.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3493
3494LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3495 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3496 succeeded(foldHiddenReset(op, rewriter)))
3497 return success();
3498
3499 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3500 return success();
3501
3502 return failure();
3503}
3504
3505//===----------------------------------------------------------------------===//
3506// Verification Ops.
3507//===----------------------------------------------------------------------===//
3508
3509static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3510 Value enable,
3511 PatternRewriter &rewriter,
3512 bool eraseIfZero) {
3513 // If the verification op is never enabled, delete it.
3514 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3515 if (constant.getValue().isZero()) {
3516 rewriter.eraseOp(op);
3517 return success();
3518 }
3519 }
3520
3521 // If the verification op is never triggered, delete it.
3522 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3523 if (constant.getValue().isZero() == eraseIfZero) {
3524 rewriter.eraseOp(op);
3525 return success();
3526 }
3527 }
3528
3529 return failure();
3530}
3531
// Shared canonicalizer for immediate verification ops (assert, assume,
// cover): delegates to eraseIfZeroOrNotZero with the op's predicate and
// enable. `EraseIfZero` selects the predicate polarity that makes the op
// removable.
template <class Op, bool EraseIfZero = false>
static LogicalResult canonicalizeImmediateVerifOp(Op op,
                                                  PatternRewriter &rewriter) {
  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
                              EraseIfZero);
}
3538
// Erase asserts that can never fire, plus the generated AssertXWhenX pattern.
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
  results.add<patterns::AssertXWhenX>(context);
}
3544
// Erase assumes that can never fire, plus the generated AssumeXWhenX pattern.
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
  results.add<patterns::AssumeXWhenX>(context);
}
3550
// Same canonicalizations as AssumeOp, for the unclocked assume intrinsic.
void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
  results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
}
3556
// Covers use the inverse polarity: a constant-false predicate (never
// covered) makes the op removable.
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3561
3562//===----------------------------------------------------------------------===//
3563// InvalidValueOp
3564//===----------------------------------------------------------------------===//
3565
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  //
  // CvtPrimOp is only included for signed operands, and reductions only for
  // operands of known non-zero width, so the result is a pure renaming of
  // the invalid rather than a computed value.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of the user's result type, then
    // erase both the user and this op.
    auto *modop = *op->user_begin();
    auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
                                      modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3596
3597OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3598 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3599 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3600 return {};
3601}
3602
3603//===----------------------------------------------------------------------===//
3604// ClockGateIntrinsicOp
3605//===----------------------------------------------------------------------===//
3606
3607OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3608 // Forward the clock if one of the enables is always true.
3609 if (isConstantOne(adaptor.getEnable()) ||
3610 isConstantOne(adaptor.getTestEnable()))
3611 return getInput();
3612
3613 // Fold to a constant zero clock if the enables are always false.
3614 if (isConstantZero(adaptor.getEnable()) &&
3615 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3616 return BoolAttr::get(getContext(), false);
3617
3618 // Forward constant zero clocks.
3619 if (isConstantZero(adaptor.getInput()))
3620 return BoolAttr::get(getContext(), false);
3621
3622 return {};
3623}
3624
3625LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3626 PatternRewriter &rewriter) {
3627 // Remove constant false test enable.
3628 if (auto testEnable = op.getTestEnable()) {
3629 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3630 if (constOp.getValue().isZero()) {
3631 rewriter.modifyOpInPlace(op,
3632 [&] { op.getTestEnableMutable().clear(); });
3633 return success();
3634 }
3635 }
3636 }
3637
3638 return failure();
3639}
3640
3641//===----------------------------------------------------------------------===//
3642// Reference Ops.
3643//===----------------------------------------------------------------------===//
3644
3645// refresolve(forceable.ref) -> forceable.data
3646static LogicalResult
3647canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3648 auto forceable = op.getRef().getDefiningOp<Forceable>();
3649 if (!forceable || !forceable.isForceable() ||
3650 op.getRef() != forceable.getDataRef() ||
3651 op.getType() != forceable.getDataType())
3652 return failure();
3653 rewriter.replaceAllUsesWith(op, forceable.getData());
3654 return success();
3655}
3656
3657void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3658 MLIRContext *context) {
3659 results.insert<patterns::RefResolveOfRefSend>(context);
3660 results.insert(canonicalizeRefResolveOfForceable);
3661}
3662
3663OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3664 // RefCast is unnecessary if types match.
3665 if (getInput().getType() == getType())
3666 return getInput();
3667 return {};
3668}
3669
3670static bool isConstantZero(Value operand) {
3671 auto constOp = operand.getDefiningOp<ConstantOp>();
3672 return constOp && constOp.getValue().isZero();
3673}
3674
3675template <typename Op>
3676static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3677 if (isConstantZero(op.getPredicate())) {
3678 rewriter.eraseOp(op);
3679 return success();
3680 }
3681 return failure();
3682}
3683
3684void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3685 MLIRContext *context) {
3686 results.add(eraseIfPredFalse<RefForceOp>);
3687}
3688void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3689 MLIRContext *context) {
3690 results.add(eraseIfPredFalse<RefForceInitialOp>);
3691}
3692void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3693 MLIRContext *context) {
3694 results.add(eraseIfPredFalse<RefReleaseOp>);
3695}
3696void RefReleaseInitialOp::getCanonicalizationPatterns(
3697 RewritePatternSet &results, MLIRContext *context) {
3698 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3699}
3700
3701//===----------------------------------------------------------------------===//
3702// HasBeenResetIntrinsicOp
3703//===----------------------------------------------------------------------===//
3704
3705OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3706 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3707
3708 // Fold to zero if the reset is a constant. In this case the op is either
3709 // permanently in reset or never resets. Both mean that the reset never
3710 // finishes, so this op never returns true.
3711 if (adaptor.getReset())
3712 return getIntZerosAttr(UIntType::get(getContext(), 1));
3713
3714 // Fold to zero if the clock is a constant and the reset is synchronous. In
3715 // that case the reset will never be started.
3716 if (isUInt1(getReset().getType()) && adaptor.getClock())
3717 return getIntZerosAttr(UIntType::get(getContext(), 1));
3718
3719 return {};
3720}
3721
3722//===----------------------------------------------------------------------===//
3723// FPGAProbeIntrinsicOp
3724//===----------------------------------------------------------------------===//
3725
3726static bool isTypeEmpty(FIRRTLType type) {
3728 .Case<FVectorType>(
3729 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3730 .Case<BundleType>([&](auto ty) -> bool {
3731 for (auto elem : ty.getElements())
3732 if (!isTypeEmpty(elem.type))
3733 return false;
3734 return true;
3735 })
3736 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3737 .Default([](auto) -> bool { return false; });
3738}
3739
3740LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3741 PatternRewriter &rewriter) {
3742 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3743 if (!firrtlTy)
3744 return failure();
3745
3746 if (!isTypeEmpty(firrtlTy))
3747 return failure();
3748
3749 rewriter.eraseOp(op);
3750 return success();
3751}
3752
3753//===----------------------------------------------------------------------===//
3754// Layer Block Op
3755//===----------------------------------------------------------------------===//
3756
3757LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3758 PatternRewriter &rewriter) {
3759
3760 // If the layerblock is empty, erase it.
3761 if (op.getBody()->empty()) {
3762 rewriter.eraseOp(op);
3763 return success();
3764 }
3765
3766 return failure();
3767}
3768
3769//===----------------------------------------------------------------------===//
3770// Domain-related Ops
3771//===----------------------------------------------------------------------===//
3772
3773OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3774 // If no domains are specified, then forward the input to the result.
3775 if (getDomains().empty())
3776 return getInput();
3777
3778 return {};
3779}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
BitsOfCat(MLIRContext *context)