CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace firrtl;
29
30// Drop writes to old and pass through passthrough to make patterns easier to
31// write.
32static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33 Value passthrough) {
34 SmallPtrSet<Operation *, 8> users;
35 for (auto *user : old.getUsers())
36 users.insert(user);
37 for (Operation *user : users)
38 if (auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
41 return passthrough;
42}
43
44// Return true if it is OK to propagate the name to the operation.
45// Non-pure operations such as instances, registers, and memories are not
46// allowed to update names for name stabilities and LEC reasons.
47static bool isOkToPropagateName(Operation *op) {
48 // Conservatively disallow operations with regions to prevent performance
49 // regression due to recursive calls to sub-regions in mlir::isPure.
50 if (op->getNumRegions() != 0)
51 return false;
52 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
53}
54
55// Move a name hint from a soon to be deleted operation to a new operation.
56// Pass through the new operation to make patterns easier to write. This cannot
57// move a name to a port (block argument), doing so would require rewriting all
58// instance sites as well as the module.
59static Value moveNameHint(OpResult old, Value passthrough) {
60 Operation *op = passthrough.getDefiningOp();
61 // This should handle ports, but it isn't clear we can change those in
62 // canonicalizers.
63 assert(op && "passthrough must be an operation");
64 Operation *oldOp = old.getOwner();
65 auto name = oldOp->getAttrOfType<StringAttr>("name");
66 if (name && !name.getValue().empty() && isOkToPropagateName(op))
67 op->setAttr("name", name);
68 return passthrough;
69}
70
71// Declarative canonicalization patterns
72namespace circt {
73namespace firrtl {
74namespace patterns {
75#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
76} // namespace patterns
77} // namespace firrtl
78} // namespace circt
79
80/// Return true if this operation's operands and results all have a known width.
81/// This only works for integer types.
82static bool hasKnownWidthIntTypes(Operation *op) {
83 auto resultType = type_cast<IntType>(op->getResult(0).getType());
84 if (!resultType.hasWidth())
85 return false;
86 for (Value operand : op->getOperands())
87 if (!type_cast<IntType>(operand.getType()).hasWidth())
88 return false;
89 return true;
90}
91
92/// Return true if this value is 1 bit UInt.
93static bool isUInt1(Type type) {
94 auto t = type_dyn_cast<UIntType>(type);
95 if (!t || !t.hasWidth() || t.getWidth() != 1)
96 return false;
97 return true;
98}
99
100/// Set the name of an op based on the best of two names: The current name, and
101/// the name passed in.
102static void updateName(PatternRewriter &rewriter, Operation *op,
103 StringAttr name) {
104 // Should never rename InstanceOp
105 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
106 return;
107 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
108 auto newName = name.getValue(); // old name is interesting
109 auto newOpName = op->getAttrOfType<StringAttr>("name");
110 // new name might not be interesting
111 if (newOpName)
112 newName = chooseName(newOpName.getValue(), name.getValue());
113 // Only update if needed
114 if (!newOpName || newOpName.getValue() != newName)
115 rewriter.modifyOpInPlace(
116 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
117}
118
119/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
120/// If a replaced op has a "name" attribute, this function propagates the name
121/// to the new value.
122static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
123 Value newValue) {
124 if (auto *newOp = newValue.getDefiningOp()) {
125 auto name = op->getAttrOfType<StringAttr>("name");
126 updateName(rewriter, newOp, name);
127 }
128 rewriter.replaceOp(op, newValue);
129}
130
131/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
132/// attribute. If a replaced op has a "name" attribute, this function propagates
133/// the name to the new value.
134template <typename OpTy, typename... Args>
135static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
136 Operation *op, Args &&...args) {
137 auto name = op->getAttrOfType<StringAttr>("name");
138 auto newOp =
139 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
140 updateName(rewriter, newOp, name);
141 return newOp;
142}
143
144/// Return true if the name is droppable. Note that this is different from
145/// `isUselessName` because non-useless names may be also droppable.
147 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
148 return namableOp.hasDroppableName();
149 return false;
150}
151
152/// Implicitly replace the operand to a constant folding operation with a const
153/// 0 in case the operand is non-constant but has a bit width 0, or if the
154/// operand is an invalid value.
155///
156/// This makes constant folding significantly easier, as we can simply pass the
157/// operands to an operation through this function to appropriately replace any
158/// zero-width dynamic values or invalid values with a constant of value 0.
159static std::optional<APSInt>
160getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
161 assert(type_cast<IntType>(operand.getType()) &&
162 "getExtendedConstant is limited to integer types");
163
164 // We never support constant folding to unknown width values.
165 if (destWidth < 0)
166 return {};
167
168 // Extension signedness follows the operand sign.
169 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
170 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
171
172 // If the operand is zero bits, then we can return a zero of the result
173 // type.
174 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
175 return APSInt(destWidth,
176 type_cast<IntType>(operand.getType()).isUnsigned());
177 return {};
178}
179
180/// Determine the value of a constant operand for the sake of constant folding.
181static std::optional<APSInt> getConstant(Attribute operand) {
182 if (!operand)
183 return {};
184 if (auto attr = dyn_cast<BoolAttr>(operand))
185 return APSInt(APInt(1, attr.getValue()));
186 if (auto attr = dyn_cast<IntegerAttr>(operand))
187 return attr.getAPSInt();
188 return {};
189}
190
191/// Determine whether a constant operand is a zero value for the sake of
192/// constant folding. This considers `invalidvalue` to be zero.
193static bool isConstantZero(Attribute operand) {
194 if (auto cst = getConstant(operand))
195 return cst->isZero();
196 return false;
197}
198
199/// Determine whether a constant operand is a one value for the sake of constant
200/// folding.
201static bool isConstantOne(Attribute operand) {
202 if (auto cst = getConstant(operand))
203 return cst->isOne();
204 return false;
205}
206
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
///
/// NOTE: the `DivideOrShift` enumerator line was lost in extraction; restored
/// here — it is referenced below (e.g. in constFoldFIRRTLBinaryOp and the
/// Div/Rem/DShl/DShr folds).
enum class BinOpKind {
  Normal,
  Compare,
  DivideOrShift,
};
214
215/// Applies the constant folding function `calculate` to the given operands.
216///
217/// Sign or zero extends the operands appropriately to the bitwidth of the
218/// result type if \p useDstWidth is true, else to the larger of the two operand
219/// bit widths and depending on whether the operation is to be performed on
220/// signed or unsigned operands.
221static Attribute constFoldFIRRTLBinaryOp(
222 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
223 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
224 assert(operands.size() == 2 && "binary op takes two operands");
225
226 // We cannot fold something to an unknown width.
227 auto resultType = type_cast<IntType>(op->getResult(0).getType());
228 if (resultType.getWidthOrSentinel() < 0)
229 return {};
230
231 // Any binary op returning i0 is 0.
232 if (resultType.getWidthOrSentinel() == 0)
233 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
234
235 // Determine the operand widths. This is either dictated by the operand type,
236 // or if that type is an unsized integer, by the actual bits necessary to
237 // represent the constant value.
238 auto lhsWidth =
239 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
240 auto rhsWidth =
241 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
242 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
243 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
244 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
245 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
246
247 // Compares extend the operands to the widest of the operand types, not to the
248 // result type.
249 int32_t operandWidth;
250 switch (opKind) {
252 operandWidth = resultType.getWidthOrSentinel();
253 break;
255 // Compares compute with the widest operand, not at the destination type
256 // (which is always i1).
257 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
258 break;
260 operandWidth =
261 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
262 break;
263 }
264
265 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
266 if (!lhs)
267 return {};
268 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
269 if (!rhs)
270 return {};
271
272 APInt resultValue = calculate(*lhs, *rhs);
273
274 // If the result type is smaller than the computation then we need to
275 // narrow the constant after the calculation.
276 if (opKind == BinOpKind::DivideOrShift)
277 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
278
279 assert((unsigned)resultType.getWidthOrSentinel() ==
280 resultValue.getBitWidth());
281 return getIntAttr(resultType, resultValue);
282}
283
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function. Any `invalidvalue` in
/// the input is mapped to a constant zero. The value returned from the callback
/// is used as the replacement for `op`, and an additional pad operation is
/// inserted if necessary. Does nothing if the result of `op` is of unknown
/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants.
  // Non-constant operands keep a null Attribute entry so the callback can
  // distinguish them positionally.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant. A null OpFoldResult means the callback declined to rewrite.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  if (auto cst = dyn_cast<Attribute>(result))
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = cast<Value>(result);

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);

  // After padding/casting the replacement must be type-identical to `op`.
  assert(type == resultValue.getType() && "canonicalization changed type");
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
347
348/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
349/// value if `bitWidth` is 0.
350static APInt getMaxUnsignedValue(unsigned bitWidth) {
351 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
352}
353
354/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
355/// value if `bitWidth` is 0.
356static APInt getMinSignedValue(unsigned bitWidth) {
357 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
358}
359
360/// Get the largest signed value of a given bit width. Returns a 1-bit zero
361/// value if `bitWidth` is 0.
362static APInt getMaxSignedValue(unsigned bitWidth) {
363 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
364}
365
366//===----------------------------------------------------------------------===//
367// Fold Hooks
368//===----------------------------------------------------------------------===//
369
/// A constant op folds to its value attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
374
/// A special constant (clock/reset) folds to its value attribute.
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
379
/// An aggregate constant folds to its fields attribute.
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getFieldsAttr();
}
384
/// A string constant folds to its value attribute.
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
389
/// An arbitrary-precision integer constant folds to its value attribute.
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
394
/// A boolean constant folds to its value attribute.
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
399
/// A double constant folds to its value attribute.
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
404
405//===----------------------------------------------------------------------===//
406// Binary Operators
407//===----------------------------------------------------------------------===//
408
409OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
411 *this, adaptor.getOperands(), BinOpKind::Normal,
412 [=](const APSInt &a, const APSInt &b) { return a + b; });
413}
414
415void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
416 MLIRContext *context) {
417 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
418 patterns::AddOfSelf, patterns::AddOfPad>(context);
419}
420
421OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
423 *this, adaptor.getOperands(), BinOpKind::Normal,
424 [=](const APSInt &a, const APSInt &b) { return a - b; });
425}
426
427void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
428 MLIRContext *context) {
429 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
430 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
431 patterns::SubOfPadL, patterns::SubOfPadR>(context);
432}
433
434OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 // mul(x, 0) -> 0
436 //
437 // This is legal because it aligns with the Scala FIRRTL Compiler
438 // interpretation of lowering invalid to constant zero before constant
439 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
440 // multiplication this way and will emit "x * 0".
441 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
442 return getIntZerosAttr(getType());
443
445 *this, adaptor.getOperands(), BinOpKind::Normal,
446 [=](const APSInt &a, const APSInt &b) { return a * b; });
447}
448
449OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
450 /// div(x, x) -> 1
451 ///
452 /// Division by zero is undefined in the FIRRTL specification. This fold
453 /// exploits that fact to optimize self division to one. Note: this should
454 /// supersede any division with invalid or zero. Division of invalid by
455 /// invalid should be one.
456 if (getLhs() == getRhs()) {
457 auto width = getType().base().getWidthOrSentinel();
458 if (width == -1)
459 width = 2;
460 // Only fold if we have at least 1 bit of width to represent the `1` value.
461 if (width != 0)
462 return getIntAttr(getType(), APInt(width, 1));
463 }
464
465 // div(0, x) -> 0
466 //
467 // This is legal because it aligns with the Scala FIRRTL Compiler
468 // interpretation of lowering invalid to constant zero before constant
469 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
470 // division this way and will emit "0 / x".
471 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
472 return getIntZerosAttr(getType());
473
474 /// div(x, 1) -> x : (uint, uint) -> uint
475 ///
476 /// UInt division by one returns the numerator. SInt division can't
477 /// be folded here because it increases the return type bitwidth by
478 /// one and requires sign extension (a new op).
479 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
480 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
481 return getLhs();
482
484 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
485 [=](const APSInt &a, const APSInt &b) -> APInt {
486 if (!!b)
487 return a / b;
488 return APInt(a.getBitWidth(), 0);
489 });
490}
491
492OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
493 // rem(x, x) -> 0
494 //
495 // Division by zero is undefined in the FIRRTL specification. This fold
496 // exploits that fact to optimize self division remainder to zero. Note:
497 // this should supersede any division with invalid or zero. Remainder of
498 // division of invalid by invalid should be zero.
499 if (getLhs() == getRhs())
500 return getIntZerosAttr(getType());
501
502 // rem(0, x) -> 0
503 //
504 // This is legal because it aligns with the Scala FIRRTL Compiler
505 // interpretation of lowering invalid to constant zero before constant
506 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
507 // division this way and will emit "0 % x".
508 if (isConstantZero(adaptor.getLhs()))
509 return getIntZerosAttr(getType());
510
512 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
513 [=](const APSInt &a, const APSInt &b) -> APInt {
514 if (!!b)
515 return a % b;
516 return APInt(a.getBitWidth(), 0);
517 });
518}
519
520OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
522 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
523 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
524}
525
526OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
528 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
529 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
530}
531
532OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
534 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
535 [=](const APSInt &a, const APSInt &b) -> APInt {
536 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
537 : a.ashr(b);
538 });
539}
540
541// TODO: Move to DRR.
542OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
543 if (auto rhsCst = getConstant(adaptor.getRhs())) {
544 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
545 if (rhsCst->isZero())
546 return getIntZerosAttr(getType());
547
548 /// and(x, -1) -> x
549 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
551 return getLhs();
552 }
553
554 if (auto lhsCst = getConstant(adaptor.getLhs())) {
555 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
556 if (lhsCst->isZero())
557 return getIntZerosAttr(getType());
558
559 /// and(-1, x) -> x
560 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
561 getRhs().getType() == getType())
562 return getRhs();
563 }
564
565 /// and(x, x) -> x
566 if (getLhs() == getRhs() && getRhs().getType() == getType())
567 return getRhs();
568
570 *this, adaptor.getOperands(), BinOpKind::Normal,
571 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
572}
573
574void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
575 MLIRContext *context) {
576 results
577 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
578 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
579 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
580}
581
582OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
583 if (auto rhsCst = getConstant(adaptor.getRhs())) {
584 /// or(x, 0) -> x
585 if (rhsCst->isZero() && getLhs().getType() == getType())
586 return getLhs();
587
588 /// or(x, -1) -> -1
589 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
590 getLhs().getType() == getType())
591 return getRhs();
592 }
593
594 if (auto lhsCst = getConstant(adaptor.getLhs())) {
595 /// or(0, x) -> x
596 if (lhsCst->isZero() && getRhs().getType() == getType())
597 return getRhs();
598
599 /// or(-1, x) -> -1
600 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
601 getRhs().getType() == getType())
602 return getLhs();
603 }
604
605 /// or(x, x) -> x
606 if (getLhs() == getRhs() && getRhs().getType() == getType())
607 return getRhs();
608
610 *this, adaptor.getOperands(), BinOpKind::Normal,
611 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
612}
613
614void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
615 MLIRContext *context) {
616 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
617 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
618 patterns::OrOrr>(context);
619}
620
621OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
622 /// xor(x, 0) -> x
623 if (auto rhsCst = getConstant(adaptor.getRhs()))
624 if (rhsCst->isZero() &&
625 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
626 return getLhs();
627
628 /// xor(x, 0) -> x
629 if (auto lhsCst = getConstant(adaptor.getLhs()))
630 if (lhsCst->isZero() &&
631 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
632 return getRhs();
633
634 /// xor(x, x) -> 0
635 if (getLhs() == getRhs())
636 return getIntAttr(
637 getType(),
638 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
639
641 *this, adaptor.getOperands(), BinOpKind::Normal,
642 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
643}
644
645void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
646 MLIRContext *context) {
647 results.insert<patterns::extendXor, patterns::moveConstXor,
648 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
649 context);
650}
651
652void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
653 MLIRContext *context) {
654 results.insert<patterns::LEQWithConstLHS>(context);
655}
656
657OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
658 bool isUnsigned = getLhs().getType().base().isUnsigned();
659
660 // leq(x, x) -> 1
661 if (getLhs() == getRhs())
662 return getIntAttr(getType(), APInt(1, 1));
663
664 // Comparison against constant outside type bounds.
665 if (auto width = getLhs().getType().base().getWidth()) {
666 if (auto rhsCst = getConstant(adaptor.getRhs())) {
667 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
668 commonWidth = std::max(commonWidth, 1);
669
670 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
671 // This can never occur since const is unsigned and cannot be less than 0.
672
673 // leq(x, const) -> 0 where const < minValue of the signed type of x
674 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
675 .slt(getMinSignedValue(*width).sext(commonWidth)))
676 return getIntAttr(getType(), APInt(1, 0));
677
678 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
679 if (isUnsigned && rhsCst->zext(commonWidth)
680 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
681 return getIntAttr(getType(), APInt(1, 1));
682
683 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
684 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
685 .sge(getMaxSignedValue(*width).sext(commonWidth)))
686 return getIntAttr(getType(), APInt(1, 1));
687 }
688 }
689
691 *this, adaptor.getOperands(), BinOpKind::Compare,
692 [=](const APSInt &a, const APSInt &b) -> APInt {
693 return APInt(1, a <= b);
694 });
695}
696
697void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
698 MLIRContext *context) {
699 results.insert<patterns::LTWithConstLHS>(context);
700}
701
702OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
703 IntType lhsType = getLhs().getType();
704 bool isUnsigned = lhsType.isUnsigned();
705
706 // lt(x, x) -> 0
707 if (getLhs() == getRhs())
708 return getIntAttr(getType(), APInt(1, 0));
709
710 // lt(x, 0) -> 0 when x is unsigned
711 if (auto rhsCst = getConstant(adaptor.getRhs())) {
712 if (rhsCst->isZero() && lhsType.isUnsigned())
713 return getIntAttr(getType(), APInt(1, 0));
714 }
715
716 // Comparison against constant outside type bounds.
717 if (auto width = lhsType.getWidth()) {
718 if (auto rhsCst = getConstant(adaptor.getRhs())) {
719 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
720 commonWidth = std::max(commonWidth, 1);
721
722 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
723 // Handled explicitly above.
724
725 // lt(x, const) -> 0 where const <= minValue of the signed type of x
726 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
727 .sle(getMinSignedValue(*width).sext(commonWidth)))
728 return getIntAttr(getType(), APInt(1, 0));
729
730 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
731 if (isUnsigned && rhsCst->zext(commonWidth)
732 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
733 return getIntAttr(getType(), APInt(1, 1));
734
735 // lt(x, const) -> 1 where const > maxValue of the signed type of x
736 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
737 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
738 return getIntAttr(getType(), APInt(1, 1));
739 }
740 }
741
743 *this, adaptor.getOperands(), BinOpKind::Compare,
744 [=](const APSInt &a, const APSInt &b) -> APInt {
745 return APInt(1, a < b);
746 });
747}
748
749void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
750 MLIRContext *context) {
751 results.insert<patterns::GEQWithConstLHS>(context);
752}
753
754OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
755 IntType lhsType = getLhs().getType();
756 bool isUnsigned = lhsType.isUnsigned();
757
758 // geq(x, x) -> 1
759 if (getLhs() == getRhs())
760 return getIntAttr(getType(), APInt(1, 1));
761
762 // geq(x, 0) -> 1 when x is unsigned
763 if (auto rhsCst = getConstant(adaptor.getRhs())) {
764 if (rhsCst->isZero() && isUnsigned)
765 return getIntAttr(getType(), APInt(1, 1));
766 }
767
768 // Comparison against constant outside type bounds.
769 if (auto width = lhsType.getWidth()) {
770 if (auto rhsCst = getConstant(adaptor.getRhs())) {
771 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
772 commonWidth = std::max(commonWidth, 1);
773
774 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
775 if (isUnsigned && rhsCst->zext(commonWidth)
776 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
777 return getIntAttr(getType(), APInt(1, 0));
778
779 // geq(x, const) -> 0 where const > maxValue of the signed type of x
780 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
781 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
782 return getIntAttr(getType(), APInt(1, 0));
783
784 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
785 // Handled explicitly above.
786
787 // geq(x, const) -> 1 where const <= minValue of the signed type of x
788 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
789 .sle(getMinSignedValue(*width).sext(commonWidth)))
790 return getIntAttr(getType(), APInt(1, 1));
791 }
792 }
793
795 *this, adaptor.getOperands(), BinOpKind::Compare,
796 [=](const APSInt &a, const APSInt &b) -> APInt {
797 return APInt(1, a >= b);
798 });
799}
800
801void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
802 MLIRContext *context) {
803 results.insert<patterns::GTWithConstLHS>(context);
804}
805
806OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
807 IntType lhsType = getLhs().getType();
808 bool isUnsigned = lhsType.isUnsigned();
809
810 // gt(x, x) -> 0
811 if (getLhs() == getRhs())
812 return getIntAttr(getType(), APInt(1, 0));
813
814 // Comparison against constant outside type bounds.
815 if (auto width = lhsType.getWidth()) {
816 if (auto rhsCst = getConstant(adaptor.getRhs())) {
817 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
818 commonWidth = std::max(commonWidth, 1);
819
820 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
821 if (isUnsigned && rhsCst->zext(commonWidth)
822 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
823 return getIntAttr(getType(), APInt(1, 0));
824
825 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
826 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
827 .sge(getMaxSignedValue(*width).sext(commonWidth)))
828 return getIntAttr(getType(), APInt(1, 0));
829
830 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
831 // This can never occur since const is unsigned and cannot be less than 0.
832
833 // gt(x, const) -> 1 where const < minValue of the signed type of x
834 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
835 .slt(getMinSignedValue(*width).sext(commonWidth)))
836 return getIntAttr(getType(), APInt(1, 1));
837 }
838 }
839
841 *this, adaptor.getOperands(), BinOpKind::Compare,
842 [=](const APSInt &a, const APSInt &b) -> APInt {
843 return APInt(1, a > b);
844 });
845}
846
OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
  // eq(x, x) -> 1
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 1));

  if (auto rhsCst = getConstant(adaptor.getRhs())) {
    /// eq(x, 1) -> x when x is 1 bit.
    /// TODO: Support SInt<1> on the LHS etc.
    if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
        getRhs().getType() == getType())
      return getLhs();
  }

  // NOTE(review): the constant-fold helper call that heads the argument list
  // below is not visible in this excerpt -- confirm against the full file.
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [=](const APSInt &a, const APSInt &b) -> APInt {
        return APInt(1, a == b);
      });
}
866
867LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
868 return canonicalizePrimOp(
869 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
870 if (auto rhsCst = getConstant(operands[1])) {
871 auto width = op.getLhs().getType().getBitWidthOrSentinel();
872
873 // eq(x, 0) -> not(x) when x is 1 bit.
874 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
875 op.getRhs().getType() == op.getType()) {
876 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
877 .getResult();
878 }
879
880 // eq(x, 0) -> not(orr(x)) when x is >1 bit
881 if (rhsCst->isZero() && width > 1) {
882 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
883 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
884 }
885
886 // eq(x, ~0) -> andr(x) when x is >1 bit
887 if (rhsCst->isAllOnes() && width > 1 &&
888 op.getLhs().getType() == op.getRhs().getType()) {
889 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
890 .getResult();
891 }
892 }
893 return {};
894 });
895}
896
OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
  // neq(x, x) -> 0
  if (getLhs() == getRhs())
    return getIntAttr(getType(), APInt(1, 0));

  if (auto rhsCst = getConstant(adaptor.getRhs())) {
    /// neq(x, 0) -> x when x is 1 bit.
    /// TODO: Support SInt<1> on the LHS etc.
    if (rhsCst->isZero() && getLhs().getType() == getType() &&
        getRhs().getType() == getType())
      return getLhs();
  }

  // NOTE(review): the constant-fold helper call that heads the argument list
  // below is not visible in this excerpt -- confirm against the full file.
      *this, adaptor.getOperands(), BinOpKind::Compare,
      [=](const APSInt &a, const APSInt &b) -> APInt {
        return APInt(1, a != b);
      });
}
916
917LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
918 return canonicalizePrimOp(
919 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
920 if (auto rhsCst = getConstant(operands[1])) {
921 auto width = op.getLhs().getType().getBitWidthOrSentinel();
922
923 // neq(x, 1) -> not(x) when x is 1 bit
924 if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
925 op.getRhs().getType() == op.getType()) {
926 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
927 .getResult();
928 }
929
930 // neq(x, 0) -> orr(x) when x is >1 bit
931 if (rhsCst->isZero() && width > 1) {
932 return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
933 .getResult();
934 }
935
936 // neq(x, ~0) -> not(andr(x))) when x is >1 bit
937 if (rhsCst->isAllOnes() && width > 1 &&
938 op.getLhs().getType() == op.getRhs().getType()) {
939 auto andrOp =
940 AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
941 return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
942 }
943 }
944
945 return {};
946 });
947}
948
949OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
950 // TODO: implement constant folding, etc.
951 // Tracked in https://github.com/llvm/circt/issues/6696.
952 return {};
953}
954
955OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
956 // TODO: implement constant folding, etc.
957 // Tracked in https://github.com/llvm/circt/issues/6724.
958 return {};
959}
960
961OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
962 // TODO: implement constant folding, etc.
963 // Tracked in https://github.com/llvm/circt/issues/6725.
964 return {};
965}
966
967OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
968 if (auto rhsCst = getConstant(adaptor.getRhs())) {
969 // Constant folding
970 if (auto lhsCst = getConstant(adaptor.getLhs()))
971
972 return IntegerAttr::get(
973 IntegerType::get(getContext(), lhsCst->getBitWidth()),
974 lhsCst->shl(*rhsCst));
975
976 // integer.shl(x, 0) -> x
977 if (rhsCst->isZero())
978 return getLhs();
979 }
980
981 return {};
982}
983
984//===----------------------------------------------------------------------===//
985// Unary Operators
986//===----------------------------------------------------------------------===//
987
988OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
989 auto base = getInput().getType();
990 auto w = getBitWidth(base);
991 if (w)
992 return getIntAttr(getType(), APInt(32, *w));
993 return {};
994}
995
996OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
997 // No constant can be 'x' by definition.
998 if (auto cst = getConstant(adaptor.getArg()))
999 return getIntAttr(getType(), APInt(1, 0));
1000 return {};
1001}
1002
1003OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1004 // No effect.
1005 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1006 return getInput();
1007
1008 // Be careful to only fold the cast into the constant if the size is known.
1009 // Otherwise width inference may produce differently-sized constants if the
1010 // sign changes.
1011 if (getType().base().hasWidth())
1012 if (auto cst = getConstant(adaptor.getInput()))
1013 return getIntAttr(getType(), *cst);
1014
1015 return {};
1016}
1017
1018void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1019 MLIRContext *context) {
1020 results.insert<patterns::StoUtoS>(context);
1021}
1022
1023OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1024 // No effect.
1025 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1026 return getInput();
1027
1028 // Be careful to only fold the cast into the constant if the size is known.
1029 // Otherwise width inference may produce differently-sized constants if the
1030 // sign changes.
1031 if (getType().base().hasWidth())
1032 if (auto cst = getConstant(adaptor.getInput()))
1033 return getIntAttr(getType(), *cst);
1034
1035 return {};
1036}
1037
1038void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1039 MLIRContext *context) {
1040 results.insert<patterns::UtoStoU>(context);
1041}
1042
1043OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1044 // No effect.
1045 if (getInput().getType() == getType())
1046 return getInput();
1047
1048 // Constant fold.
1049 if (auto cst = getConstant(adaptor.getInput()))
1050 return BoolAttr::get(getContext(), cst->getBoolValue());
1051
1052 return {};
1053}
1054
1055OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1056 // No effect.
1057 if (getInput().getType() == getType())
1058 return getInput();
1059
1060 // Constant fold.
1061 if (auto cst = getConstant(adaptor.getInput()))
1062 return BoolAttr::get(getContext(), cst->getBoolValue());
1063
1064 return {};
1065}
1066
1067OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1068 if (!hasKnownWidthIntTypes(*this))
1069 return {};
1070
1071 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1072 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1073 getType().base().getWidthOrSentinel()))
1074 return getIntAttr(getType(), *cst);
1075
1076 return {};
1077}
1078
1079void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1080 MLIRContext *context) {
1081 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1082}
1083
1084OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1085 if (!hasKnownWidthIntTypes(*this))
1086 return {};
1087
1088 // FIRRTL negate always adds a bit.
1089 // -x ---> 0-sext(x) or 0-zext(x)
1090 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1091 getType().base().getWidthOrSentinel()))
1092 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1093
1094 return {};
1095}
1096
1097OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1098 if (!hasKnownWidthIntTypes(*this))
1099 return {};
1100
1101 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1102 getType().base().getWidthOrSentinel()))
1103 return getIntAttr(getType(), ~*cst);
1104
1105 return {};
1106}
1107
1108void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1109 MLIRContext *context) {
1110 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1111 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1112 patterns::NotGt>(context);
1113}
1114
1115class ReductionCat : public mlir::RewritePattern {
1116public:
1117 ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
1118 : RewritePattern(opName, 0, context) {}
1119
1120 /// Handle a constant operand in the cat operation.
1121 /// Returns true if the entire reduction can be replaced with a constant.
1122 /// May add non-zero constants to the remaining operands list.
1123 virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1124 ConstantOp constantOp,
1125 SmallVectorImpl<Value> &remaining) const = 0;
1126
1127 /// Return the unit value for this reduction operation:
1128 virtual bool getIdentityValue() const = 0;
1129
1130 LogicalResult
1131 matchAndRewrite(Operation *op,
1132 mlir::PatternRewriter &rewriter) const override {
1133 // Check if the operand is a cat operation
1134 auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
1135 if (!catOp)
1136 return failure();
1137
1138 SmallVector<Value> nonConstantOperands;
1139
1140 // Process each operand of the cat operation
1141 for (auto operand : catOp.getInputs()) {
1142 if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
1143 // Handle constant operands - may short-circuit the entire operation
1144 if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
1145 return success();
1146 } else {
1147 // Keep non-constant operands for further processing
1148 nonConstantOperands.push_back(operand);
1149 }
1150 }
1151
1152 // If no operands remain, replace with identity value
1153 if (nonConstantOperands.empty()) {
1154 replaceOpWithNewOpAndCopyName<ConstantOp>(
1155 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1156 APInt(1, getIdentityValue()));
1157 return success();
1158 }
1159
1160 // If only one operand remains, apply reduction directly to it
1161 if (nonConstantOperands.size() == 1) {
1162 rewriter.modifyOpInPlace(
1163 op, [&] { op->setOperand(0, nonConstantOperands.front()); });
1164 return success();
1165 }
1166
1167 // Multiple operands remain - optimize only when cat has a single use.
1168 if (catOp->hasOneUse() &&
1169 nonConstantOperands.size() < catOp->getNumOperands()) {
1170 replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
1171 nonConstantOperands);
1172 return success();
1173 }
1174 return failure();
1175 }
1176};
1177
1178class OrRCat : public ReductionCat {
1179public:
1180 OrRCat(MLIRContext *context)
1181 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1182 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1183 ConstantOp value,
1184 SmallVectorImpl<Value> &remaining) const override {
1185 if (value.getValue().isZero())
1186 return false;
1187
1188 replaceOpWithNewOpAndCopyName<ConstantOp>(
1189 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1190 llvm::APInt(1, 1));
1191 return true;
1192 }
1193 bool getIdentityValue() const override { return false; }
1194};
1195
1196class AndRCat : public ReductionCat {
1197public:
1198 AndRCat(MLIRContext *context)
1199 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1200 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1201 ConstantOp value,
1202 SmallVectorImpl<Value> &remaining) const override {
1203 if (value.getValue().isAllOnes())
1204 return false;
1205
1206 replaceOpWithNewOpAndCopyName<ConstantOp>(
1207 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1208 llvm::APInt(1, 0));
1209 return true;
1210 }
1211 bool getIdentityValue() const override { return true; }
1212};
1213
1214class XorRCat : public ReductionCat {
1215public:
1216 XorRCat(MLIRContext *context)
1217 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1218 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1219 ConstantOp value,
1220 SmallVectorImpl<Value> &remaining) const override {
1221 if (value.getValue().isZero())
1222 return false;
1223 remaining.push_back(value);
1224 return false;
1225 }
1226 bool getIdentityValue() const override { return false; }
1227};
1228
1229OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1230 if (!hasKnownWidthIntTypes(*this))
1231 return {};
1232
1233 if (getInput().getType().getBitWidthOrSentinel() == 0)
1234 return getIntAttr(getType(), APInt(1, 1));
1235
1236 // x == -1
1237 if (auto cst = getConstant(adaptor.getInput()))
1238 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1239
1240 // one bit is identity. Only applies to UInt since we can't make a cast
1241 // here.
1242 if (isUInt1(getInput().getType()))
1243 return getInput();
1244
1245 return {};
1246}
1247
1248void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1249 MLIRContext *context) {
1250 results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1251 patterns::AndRPadS, patterns::AndRCatAndR_left,
1252 patterns::AndRCatAndR_right, AndRCat>(context);
1253}
1254
1255OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1256 if (!hasKnownWidthIntTypes(*this))
1257 return {};
1258
1259 if (getInput().getType().getBitWidthOrSentinel() == 0)
1260 return getIntAttr(getType(), APInt(1, 0));
1261
1262 // x != 0
1263 if (auto cst = getConstant(adaptor.getInput()))
1264 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1265
1266 // one bit is identity. Only applies to UInt since we can't make a cast
1267 // here.
1268 if (isUInt1(getInput().getType()))
1269 return getInput();
1270
1271 return {};
1272}
1273
1274void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1275 MLIRContext *context) {
1276 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1277 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
1278 context);
1279}
1280
1281OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1282 if (!hasKnownWidthIntTypes(*this))
1283 return {};
1284
1285 if (getInput().getType().getBitWidthOrSentinel() == 0)
1286 return getIntAttr(getType(), APInt(1, 0));
1287
1288 // popcount(x) & 1
1289 if (auto cst = getConstant(adaptor.getInput()))
1290 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1291
1292 // one bit is identity. Only applies to UInt since we can't make a cast here.
1293 if (isUInt1(getInput().getType()))
1294 return getInput();
1295
1296 return {};
1297}
1298
1299void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1300 MLIRContext *context) {
1301 results
1302 .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1303 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
1304 context);
1305}
1306
1307//===----------------------------------------------------------------------===//
1308// Other Operators
1309//===----------------------------------------------------------------------===//
1310
1311OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1312 auto inputs = getInputs();
1313 auto inputAdaptors = adaptor.getInputs();
1314
1315 // If no inputs, return 0-bit value
1316 if (inputs.empty())
1317 return getIntZerosAttr(getType());
1318
1319 // If single input and same type, return it
1320 if (inputs.size() == 1 && inputs[0].getType() == getType())
1321 return inputs[0];
1322
1323 // Make sure it's safe to fold.
1324 if (!hasKnownWidthIntTypes(*this))
1325 return {};
1326
1327 // Filter out zero-width operands
1328 SmallVector<Value> nonZeroInputs;
1329 SmallVector<Attribute> nonZeroAttributes;
1330 bool allConstant = true;
1331 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1332 auto inputType = type_cast<IntType>(input.getType());
1333 if (inputType.getBitWidthOrSentinel() != 0) {
1334 nonZeroInputs.push_back(input);
1335 if (!attr)
1336 allConstant = false;
1337 if (nonZeroInputs.size() > 1 && !allConstant)
1338 return {};
1339 }
1340 }
1341
1342 // If all inputs were zero-width, return 0-bit value
1343 if (nonZeroInputs.empty())
1344 return getIntZerosAttr(getType());
1345
1346 // If only one non-zero input and it has the same type as result, return it
1347 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1348 return nonZeroInputs[0];
1349
1350 if (!hasKnownWidthIntTypes(*this))
1351 return {};
1352
1353 // Constant fold cat - concatenate all constant operands
1354 SmallVector<APInt> constants;
1355 for (auto inputAdaptor : inputAdaptors) {
1356 if (auto cst = getConstant(inputAdaptor))
1357 constants.push_back(*cst);
1358 else
1359 return {}; // Not all operands are constant
1360 }
1361
1362 assert(!constants.empty());
1363 // Concatenate all constants from left to right
1364 APInt result = constants[0];
1365 for (size_t i = 1; i < constants.size(); ++i)
1366 result = result.concat(constants[i]);
1367
1368 return getIntAttr(getType(), result);
1369}
1370
1371void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1372 MLIRContext *context) {
1373 results.insert<patterns::DShlOfConstant>(context);
1374}
1375
1376void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1377 MLIRContext *context) {
1378 results.insert<patterns::DShrOfConstant>(context);
1379}
1380
1381namespace {
1382
1383/// Canonicalize a cat tree into single cat operation.
class FlattenCat : public mlir::RewritePattern {
public:
  FlattenCat(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    // All widths must be known, and a zero-width result has nothing to
    // flatten.
    if (!hasKnownWidthIntTypes(cat) ||
        cat.getType().getBitWidthOrSentinel() == 0)
      return failure();

    // If this is not a root of concat tree, skip. (A cat whose only user is
    // another cat will be rewritten when that outer cat is visited.)
    if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
      return failure();

    // Try flattening the cat tree. Operands are pushed in reverse so the
    // pop-based walk visits them in source order.
    SmallVector<Value> operands;
    SmallVector<Value> worklist;
    auto pushOperands = [&worklist](CatPrimOp op) {
      for (auto operand : llvm::reverse(op.getInputs()))
        worklist.push_back(operand);
    };
    pushOperands(cat);
    // Track the signedness of the leaves so mixed trees can be normalized.
    bool hasSigned = false, hasUnsigned = false;
    while (!worklist.empty()) {
      auto value = worklist.pop_back_val();
      auto catOp = value.getDefiningOp<CatPrimOp>();
      if (!catOp) {
        // A leaf operand: record it and note its signedness.
        operands.push_back(value);
        (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
        continue;
      }

      // A nested cat: expand its operands instead of keeping it as a leaf.
      pushOperands(catOp);
    }

    // Helper function to cast signed values to unsigned. CatPrimOp converts
    // signed values to unsigned so we need to do it explicitly here.
    auto castToUIntIfSigned = [&](Value value) -> Value {
      if (type_isa<UIntType>(value.getType()))
        return value;
      return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
    };

    assert(operands.size() >= 1 && "zero width cast must be rejected");

    // A single leaf: the whole tree collapses to that leaf (cast if signed).
    if (operands.size() == 1) {
      rewriter.replaceOp(op, castToUIntIfSigned(operands[0]));
      return success();
    }

    // Same operand count means no nested cat was expanded; nothing to do.
    if (operands.size() == cat->getNumOperands())
      return failure();

    // If types are mixed, cast all operands to unsigned.
    if (hasSigned && hasUnsigned)
      for (auto &operand : operands)
        operand = castToUIntIfSigned(operand);

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);
    return success();
  }
};
1450
1451// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
class CatOfConstant : public mlir::RewritePattern {
public:
  CatOfConstant(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    if (!hasKnownWidthIntTypes(cat))
      return failure();

    // Rebuilt operand list with runs of adjacent constants merged.
    SmallVector<Value> operands;

    for (size_t i = 0; i < cat->getNumOperands(); ++i) {
      auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
      if (!cst) {
        // Non-constant operands pass through unchanged.
        operands.push_back(cat.getInputs()[i]);
        continue;
      }
      // Greedily concatenate the maximal run of constants starting at i.
      APSInt value = cst.getValue();
      size_t j = i + 1;
      for (; j < cat->getNumOperands(); ++j) {
        auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
        if (!nextCst)
          break;
        value = value.concat(nextCst.getValue());
      }

      if (j == i + 1) {
        // Not folded.
        operands.push_back(cst);
      } else {
        // Folded.
        operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
      }

      // Resume scanning after the run (the loop increment advances past j-1).
      i = j - 1;
    }

    // No adjacent constants were merged; leave the cat alone.
    if (operands.size() == cat->getNumOperands())
      return failure();

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);

    return success();
  }
};
1501
1502} // namespace
1503
1504void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1505 MLIRContext *context) {
1506 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1507 patterns::CatCast, FlattenCat, CatOfConstant>(context);
1508}
1509
1510OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1511 auto op = (*this);
1512 // BitCast is redundant if input and result types are same.
1513 if (op.getType() == op.getInput().getType())
1514 return op.getInput();
1515
1516 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1517 // final result type.
1518 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1519 if (op.getType() == in.getInput().getType())
1520 return in.getInput();
1521
1522 return {};
1523}
1524
1525OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1526 IntType inputType = getInput().getType();
1527 IntType resultType = getType();
1528 // If we are extracting the entire input, then return it.
1529 if (inputType == getType() && resultType.hasWidth())
1530 return getInput();
1531
1532 // Constant fold.
1533 if (hasKnownWidthIntTypes(*this))
1534 if (auto cst = getConstant(adaptor.getInput()))
1535 return getIntAttr(resultType,
1536 cst->extractBits(getHi() - getLo() + 1, getLo()));
1537
1538 return {};
1539}
1540
1541struct BitsOfCat : public mlir::RewritePattern {
1542 BitsOfCat(MLIRContext *context)
1543 : RewritePattern(BitsPrimOp::getOperationName(), 0, context) {}
1544
1545 LogicalResult
1546 matchAndRewrite(Operation *op,
1547 mlir::PatternRewriter &rewriter) const override {
1548 auto bits = cast<BitsPrimOp>(op);
1549 auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
1550 if (!cat)
1551 return failure();
1552 int32_t bitPos = bits.getLo();
1553 auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
1554 if (resultWidth < 0)
1555 return failure();
1556 for (auto operand : llvm::reverse(cat.getInputs())) {
1557 auto operandWidth =
1558 type_cast<IntType>(operand.getType()).getWidthOrSentinel();
1559 if (operandWidth < 0)
1560 return failure();
1561 if (bitPos < operandWidth) {
1562 if (bitPos + resultWidth <= operandWidth) {
1563 auto newBits = rewriter.createOrFold<BitsPrimOp>(
1564 op->getLoc(), operand, bitPos + resultWidth - 1, bitPos);
1565 replaceOpAndCopyName(rewriter, op, newBits);
1566 return success();
1567 }
1568 return failure();
1569 }
1570 bitPos -= operandWidth;
1571 }
1572 return failure();
1573 }
1574};
1575
1576void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1577 MLIRContext *context) {
1578 results
1579 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1580 patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
1581}
1582
1583/// Replace the specified operation with a 'bits' op from the specified hi/lo
1584/// bits. Insert a cast to handle the case where the original operation
1585/// returned a signed integer.
1586static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1587 unsigned loBit, PatternRewriter &rewriter) {
1588 auto resType = type_cast<IntType>(op->getResult(0).getType());
1589 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1590 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1591
1592 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1593 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1594 } else if (resType.isUnsigned() &&
1595 !type_cast<IntType>(value.getType()).isUnsigned()) {
1596 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1597 }
1598 rewriter.replaceOp(op, value);
1599}
1600
1601template <typename OpTy>
1602static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
1603 // mux : UInt<0> -> 0
1604 if (op.getType().getBitWidthOrSentinel() == 0)
1605 return getIntAttr(op.getType(),
1606 APInt(0, 0, op.getType().isSignedInteger()));
1607
1608 // mux(cond, x, x) -> x
1609 if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
1610 return op.getHigh();
1611
1612 // The following folds require that the result has a known width. Otherwise
1613 // the mux requires an additional padding operation to be inserted, which is
1614 // not possible in a fold.
1615 if (op.getType().getBitWidthOrSentinel() < 0)
1616 return {};
1617
1618 // mux(0/1, x, y) -> x or y
1619 if (auto cond = getConstant(adaptor.getSel())) {
1620 if (cond->isZero() && op.getLow().getType() == op.getType())
1621 return op.getLow();
1622 if (!cond->isZero() && op.getHigh().getType() == op.getType())
1623 return op.getHigh();
1624 }
1625
1626 // mux(cond, x, cst)
1627 if (auto lowCst = getConstant(adaptor.getLow())) {
1628 // mux(cond, c1, c2)
1629 if (auto highCst = getConstant(adaptor.getHigh())) {
1630 // mux(cond, cst, cst) -> cst
1631 if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1632 *highCst == *lowCst)
1633 return getIntAttr(op.getType(), *highCst);
1634 // mux(cond, 1, 0) -> cond
1635 if (highCst->isOne() && lowCst->isZero() &&
1636 op.getType() == op.getSel().getType())
1637 return op.getSel();
1638
1639 // TODO: x ? ~0 : 0 -> sext(x)
1640 // TODO: "x ? c1 : c2" -> many tricks
1641 }
1642 // TODO: "x ? a : 0" -> sext(x) & a
1643 }
1644
1645 // TODO: "x ? c1 : y" -> "~x ? y : c1"
1646 return {};
1647}
1648
1649OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1650 return foldMux(*this, adaptor);
1651}
1652
1653OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1654 return foldMux(*this, adaptor);
1655}
1656
1657OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1658
1659namespace {
1660
1661// If the mux has a known output width, pad the operands up to this width.
1662// Most folds on mux require that folded operands are of the same width as
1663// the mux itself.
1664class MuxPad : public mlir::RewritePattern {
1665public:
1666 MuxPad(MLIRContext *context)
1667 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1668
1669 LogicalResult
1670 matchAndRewrite(Operation *op,
1671 mlir::PatternRewriter &rewriter) const override {
1672 auto mux = cast<MuxPrimOp>(op);
1673 auto width = mux.getType().getBitWidthOrSentinel();
1674 if (width < 0)
1675 return failure();
1676
1677 auto pad = [&](Value input) -> Value {
1678 auto inputWidth =
1679 type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1680 if (inputWidth < 0 || width == inputWidth)
1681 return input;
1682 return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
1683 width)
1684 .getResult();
1685 };
1686
1687 auto newHigh = pad(mux.getHigh());
1688 auto newLow = pad(mux.getLow());
1689 if (newHigh == mux.getHigh() && newLow == mux.getLow())
1690 return failure();
1691
1692 replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1693 rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1694 mux->getAttrs());
1695 return success();
1696 }
1697};
1698
1699// Find muxes which have conditions dominated by other muxes with the same
1700// condition.
1701class MuxSharedCond : public mlir::RewritePattern {
1702public:
1703 MuxSharedCond(MLIRContext *context)
1704 : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1705
1706 static const int depthLimit = 5;
1707
1708 Value updateOrClone(MuxPrimOp mux, Value high, Value low,
1709 mlir::PatternRewriter &rewriter,
1710 bool updateInPlace) const {
1711 if (updateInPlace) {
1712 rewriter.modifyOpInPlace(mux, [&] {
1713 mux.setOperand(1, high);
1714 mux.setOperand(2, low);
1715 });
1716 return {};
1717 }
1718 rewriter.setInsertionPointAfter(mux);
1719 return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
1720 ValueRange{mux.getSel(), high, low})
1721 .getResult();
1722 }
1723
1724 // Walk a dependent mux tree assuming the condition cond is true.
1725 Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
1726 bool updateInPlace, int limit) const {
1727 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1728 if (!mux)
1729 return {};
1730 if (mux.getSel() == cond)
1731 return mux.getHigh();
1732 if (limit > depthLimit)
1733 return {};
1734 updateInPlace &= mux->hasOneUse();
1735
1736 if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
1737 limit + 1))
1738 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1739
1740 if (Value v =
1741 tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
1742 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1743 return {};
1744 }
1745
1746 // Walk a dependent mux tree assuming the condition cond is false.
1747 Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
1748 bool updateInPlace, int limit) const {
1749 MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
1750 if (!mux)
1751 return {};
1752 if (mux.getSel() == cond)
1753 return mux.getLow();
1754 if (limit > depthLimit)
1755 return {};
1756 updateInPlace &= mux->hasOneUse();
1757
1758 if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
1759 limit + 1))
1760 return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);
1761
1762 if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
1763 limit + 1))
1764 return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
1765
1766 return {};
1767 }
1768
1769 LogicalResult
1770 matchAndRewrite(Operation *op,
1771 mlir::PatternRewriter &rewriter) const override {
1772 auto mux = cast<MuxPrimOp>(op);
1773 auto width = mux.getType().getBitWidthOrSentinel();
1774 if (width < 0)
1775 return failure();
1776
1777 if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
1778 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
1779 return success();
1780 }
1781
1782 if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
1783 rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
1784 return success();
1785 }
1786
1787 return failure();
1788 }
1789};
1790} // namespace
1791
1792void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1793 MLIRContext *context) {
1794 results
1795 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1796 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1797 patterns::MuxSameTrue, patterns::MuxSameFalse,
1798 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1799 context);
1800}
1801
1802void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1803 RewritePatternSet &results, MLIRContext *context) {
1804 results.add<patterns::Mux2PadSel>(context);
1805}
1806
1807void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1808 RewritePatternSet &results, MLIRContext *context) {
1809 results.add<patterns::Mux4PadSel>(context);
1810}
1811
/// Fold pads that don't change the width, and constant-fold pads of constants
/// (sign-extending for signed inputs, zero-extending otherwise).
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // Sign-extend signed constants; a zero-width constant has no sign bit to
    // copy, so it falls through to the zero-extension case below.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1838
/// Fold trivial left-shifts and constant-fold shifts of constants. The result
/// width of shl is inputWidth + shiftAmount, so no bits are ever discarded.
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      auto resultWidth = inputWidth + shiftAmount;
      // NOTE(review): with inputWidth >= 0 this clamp never binds
      // (resultWidth >= shiftAmount); it appears to be defensive only.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
1859
/// Fold right-shifts: identity shifts, shifts of zero-width values, shifts
/// that discard all bits of an unsigned value, and constant inputs. For
/// signed inputs the shift saturates at the sign bit (arithmetic shift).
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Unknown input width: nothing more we can do.
  if (inputWidth == -1)
    return {};
  // A zero-width input shifts to all zeros.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    // Zero-width zero result; the result type carries the actual width.
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    if (inputType.isSigned())
      // Arithmetic shift; clamp so at least the sign bit survives.
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // Result width is inputWidth - shiftAmount, but never less than 1 bit.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1894
/// Canonicalize shr into a bit-extract once the input width is known.
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().base().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  // shr(x, n) == bits(x, inputWidth-1, n).
  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
1916
1917LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1918 PatternRewriter &rewriter) {
1919 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1920 if (inputWidth <= 0)
1921 return failure();
1922
1923 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1924 unsigned keepAmount = op.getAmount();
1925 if (keepAmount)
1926 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1927 rewriter);
1928 return success();
1929}
1930
/// Constant-fold head(cst, n): keep the top n bits of the constant.
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput())) {
      // Shift the top `amount` bits down, then truncate to `amount` bits.
      int shiftAmount =
          getInput().getType().base().getWidthOrSentinel() - getAmount();
      return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
    }

  return {};
}
1941
/// Constant-fold tail(cst, n): truncate the constant to the result width
/// (dropping the top n bits).
OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput()))
      return getIntAttr(getType(),
                        cst->trunc(getType().base().getWidthOrSentinel()));
  return {};
}
1949
1950LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1951 PatternRewriter &rewriter) {
1952 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1953 if (inputWidth <= 0)
1954 return failure();
1955
1956 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1957 unsigned dropAmount = op.getAmount();
1958 if (dropAmount != unsigned(inputWidth))
1959 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1960 rewriter);
1961 return success();
1962}
1963
1964void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1965 MLIRContext *context) {
1966 results.add<patterns::SubaccessOfConstant>(context);
1967}
1968
/// Fold single-input muxes and muxes with a constant index. Note that the
/// inputs are stored in reverse order: inputs[size-1-i] is selected by index
/// value i (operand 0 of the op is the index itself).
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it.
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    // Out-of-range constant indices are left alone.
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
1982
/// Canonicalizations for multibit muxes: collapse all-equal inputs, drop
/// unreachable inputs, recognize whole-vector indexing, and lower the 2-input
/// case to a regular mux.
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it is costly to look through all inputs
  // so it is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements. (Inputs are stored in reverse order, so the leading inputs are
  // the ones a narrow index can never select.)
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every input must be a subindex of the same vector, with indices counting
    // down from n-1 to 0 in operand order.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2034
2035//===----------------------------------------------------------------------===//
2036// Declarations
2037//===----------------------------------------------------------------------===//
2038
2039/// Scan all the uses of the specified value, checking to see if there is
2040/// exactly one connect that has the value as its destination. This returns the
2041/// operation if found and if all the other users are "reads" from the value.
2042/// Returns null if there are no connects, or multiple connects to the value, or
2043/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2044///
2045/// Note that this will simply return the connect, which is located *anywhere*
2046/// after the definition of the value. Users of this function are likely
2047/// interested in the source side of the returned connect, the definition of
2048/// which does likely not dominate the original value.
MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
  MatchingConnectOp connect;
  for (Operation *user : value.getUsers()) {
    // If we see an attach or aggregate subelements, just conservatively fail.
    if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
      return {};

    if (auto aConnect = dyn_cast<FConnectLike>(user))
      if (aConnect.getDest() == value) {
        auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
        // If this is not a matching connect, a second matching connect or in a
        // different block, fail.
        if (!matchingConnect || (connect && connect != matchingConnect) ||
            matchingConnect->getBlock() != value.getParentBlock())
          return {};
        connect = matchingConnect;
      }
  }
  // May be null if no matching connect was found at all.
  return connect;
}
2069
// Forward simple values through wires and regs. If a wire/reg has exactly one
// (matching) connect driving it with a constant or (for wires) a port, replace
// all uses of the declaration with that value and delete both the declaration
// and the connect.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Skip declarations the user wants to keep: don't-touch, annotations,
  // non-droppable names, or force targets.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
2128
2129void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2130 MLIRContext *context) {
2131 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2132 context);
2133}
2134
2135LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2136 PatternRewriter &rewriter) {
2137 // TODO: Canonicalize towards explicit extensions and flips here.
2138
2139 // If there is a simple value connected to a foldable decl like a wire or reg,
2140 // see if we can eliminate the decl.
2141 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2142 return success();
2143 return failure();
2144}
2145
2146//===----------------------------------------------------------------------===//
2147// Statements
2148//===----------------------------------------------------------------------===//
2149
/// If the specified value has an AttachOp user strictly dominating
/// `dominatedAttach` (i.e. appearing earlier in the same block), return it.
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
  for (auto *user : value.getUsers()) {
    auto attach = dyn_cast<AttachOp>(user);
    if (!attach || attach == dominatedAttach)
      continue;
    if (attach->isBeforeInBlock(dominatedAttach))
      return attach;
  }
  // No earlier attach found.
  return {};
}
2162
/// Canonicalizations for attach: drop no-op single-operand attaches, merge
/// attaches that share an operand, and remove wires that exist only to be
/// attached.
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      // Both original attaches are subsumed by the merged one.
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the wire, then drop both.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2208
/// Replaces the given op with the contents of the given single-block region.
/// The block's operations are spliced in before `op`; the op itself is not
/// erased here (callers erase it afterwards).
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
2215
/// Canonicalizations for when: inline the taken branch of a constant-condition
/// when, drop empty else blocks, and erase fully-empty whens.
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    // Inline whichever region the constant condition selects; a false
    // condition with no (or an empty) else region inlines nothing.
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2248
2249namespace {
2250// Remove private nodes. If they have an interesting names, move the name to
2251// the source expression.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    // Keep nodes with non-droppable names, symbols, annotations, or force
    // targets.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).empty() || node.isForceable())
      return failure();
    // Move the (possibly interesting) name onto the input's defining op, then
    // bypass the node entirely.
    auto *newOp = node.getInput().getDefiningOp();
    if (newOp)
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
2269
2270// Bypass nodes.
2271struct NodeBypass : public mlir::RewritePattern {
2272 NodeBypass(MLIRContext *context)
2273 : RewritePattern(NodeOp::getOperationName(), 0, context) {}
2274 LogicalResult matchAndRewrite(Operation *op,
2275 PatternRewriter &rewriter) const override {
2276 auto node = cast<NodeOp>(op);
2277 if (node.getInnerSym() || !AnnotationSet(node).empty() ||
2278 node.use_empty() || node.isForceable())
2279 return failure();
2280 rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
2281 return success();
2282 }
2283};
2284
2285} // namespace
2286
/// Demote a forceable declaration to a non-forceable one when its ref result
/// has no uses. Shared by node/wire/reg canonicalizers.
template <typename OpTy>
static LogicalResult demoteForceableIfUnused(OpTy op,
                                             PatternRewriter &rewriter) {
  if (!op.isForceable() || !op.getDataRef().use_empty())
    return failure();

  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
  return success();
}
2296
2297// Interesting names and symbols and don't touch force nodes to stick around.
// Interesting names and symbols and don't touch force nodes to stick around.
// Otherwise a node folds to its input once the input is available.
LogicalResult NodeOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  if (!hasDroppableName())
    return failure();
  if (hasDontTouch(getResult())) // handles inner symbols
    return failure();
  if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
    return failure();
  if (isForceable())
    return failure();
  // The adaptor input may be null if the operand is not a known value yet.
  if (!adaptor.getInput())
    return failure();

  results.push_back(adaptor.getInput());
  return success();
}
2314
2315void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2316 MLIRContext *context) {
2317 results.insert<FoldNodeName>(context);
2318 results.add(demoteForceableIfUnused<NodeOp>);
2319}
2320
2321namespace {
2322// For a lhs, find all the writers of fields of the aggregate type. If there
2323// is one writer for each field, merge the writes
2324struct AggOneShot : public mlir::RewritePattern {
2325 AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2326 : RewritePattern(name, 0, context) {}
2327
2328 SmallVector<Value> getCompleteWrite(Operation *lhs) const {
2329 auto lhsTy = lhs->getResult(0).getType();
2330 if (!type_isa<BundleType, FVectorType>(lhsTy))
2331 return {};
2332
2333 DenseMap<uint32_t, Value> fields;
2334 for (Operation *user : lhs->getResult(0).getUsers()) {
2335 if (user->getParentOp() != lhs->getParentOp())
2336 return {};
2337 if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2338 if (aConnect.getDest() == lhs->getResult(0))
2339 return {};
2340 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2341 for (Operation *subuser : subField.getResult().getUsers()) {
2342 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2343 if (aConnect.getDest() == subField) {
2344 if (subuser->getParentOp() != lhs->getParentOp())
2345 return {};
2346 if (fields.count(subField.getFieldIndex())) // duplicate write
2347 return {};
2348 fields[subField.getFieldIndex()] = aConnect.getSrc();
2349 }
2350 continue;
2351 }
2352 return {};
2353 }
2354 } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2355 for (Operation *subuser : subIndex.getResult().getUsers()) {
2356 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2357 if (aConnect.getDest() == subIndex) {
2358 if (subuser->getParentOp() != lhs->getParentOp())
2359 return {};
2360 if (fields.count(subIndex.getIndex())) // duplicate write
2361 return {};
2362 fields[subIndex.getIndex()] = aConnect.getSrc();
2363 }
2364 continue;
2365 }
2366 return {};
2367 }
2368 } else {
2369 return {};
2370 }
2371 }
2372
2373 SmallVector<Value> values;
2374 uint32_t total = type_isa<BundleType>(lhsTy)
2375 ? type_cast<BundleType>(lhsTy).getNumElements()
2376 : type_cast<FVectorType>(lhsTy).getNumElements();
2377 for (uint32_t i = 0; i < total; ++i) {
2378 if (!fields.count(i))
2379 return {};
2380 values.push_back(fields[i]);
2381 }
2382 return values;
2383 }
2384
2385 LogicalResult matchAndRewrite(Operation *op,
2386 PatternRewriter &rewriter) const override {
2387 auto values = getCompleteWrite(op);
2388 if (values.empty())
2389 return failure();
2390 rewriter.setInsertionPointToEnd(op->getBlock());
2391 auto dest = op->getResult(0);
2392 auto destType = dest.getType();
2393
2394 // If not passive, cannot matchingconnect.
2395 if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2396 return failure();
2397
2398 Value newVal = type_isa<BundleType>(destType)
2399 ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2400 destType, values)
2401 : rewriter.createOrFold<VectorCreateOp>(
2402 op->getLoc(), destType, values);
2403 rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2404 for (Operation *user : dest.getUsers()) {
2405 if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2406 for (Operation *subuser :
2407 llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2408 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2409 if (aConnect.getDest() == subIndex)
2410 rewriter.eraseOp(aConnect);
2411 } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2412 for (Operation *subuser :
2413 llvm::make_early_inc_range(subField.getResult().getUsers()))
2414 if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2415 if (aConnect.getDest() == subField)
2416 rewriter.eraseOp(aConnect);
2417 }
2418 }
2419 return success();
2420 }
2421};
2422
// Instantiations of the aggregate one-shot write merge for the ops whose
// results can be written field-by-field.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2435} // namespace
2436
2437void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2438 MLIRContext *context) {
2439 results.insert<WireAggOneShot>(context);
2440 results.add(demoteForceableIfUnused<WireOp>);
2441}
2442
2443void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2444 MLIRContext *context) {
2445 results.insert<SubindexAggOneShot>(context);
2446}
2447
2448OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2449 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2450 if (!attr)
2451 return {};
2452 return attr[getIndex()];
2453}
2454
2455OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2456 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2457 if (!attr)
2458 return {};
2459 auto index = getFieldIndex();
2460 return attr[index];
2461}
2462
2463void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2464 MLIRContext *context) {
2465 results.insert<SubfieldAggOneShot>(context);
2466}
2467
2468static Attribute collectFields(MLIRContext *context,
2469 ArrayRef<Attribute> operands) {
2470 for (auto operand : operands)
2471 if (!operand)
2472 return {};
2473 return ArrayAttr::get(context, operands);
2474}
2475
/// Fold bundle_create back to its source bundle when it merely reassembles
/// all fields of one value, and constant-fold when all fields are constants.
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      // Operand i must be field i of the same source value, and the source
      // type must match the result type exactly.
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  return collectFields(getContext(), adaptor.getOperands());
}
2494
/// Fold vector_create back to its source vector when it merely reassembles
/// all elements of one value, and constant-fold when all elements are
/// constants.
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      // Operand i must be element i of the same source value, and the source
      // type must match the result type exactly.
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  return collectFields(getContext(), adaptor.getOperands());
}
2512
2513OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2514 if (getOperand().getType() == getType())
2515 return getOperand();
2516 return {};
2517}
2518
2519namespace {
2520// A register with constant reset and all connection to either itself or the
2521// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    // The reset value must be a constant, and the register must be safe to
    // remove (no don't-touch, annotations, or force targets).
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The register must be driven by a mux whose branches are the register
    // itself (self-feedback) and a constant.
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    // Either high is the constant and low is the register, or vice versa.
    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must be identical to the reset constant: the register
    // then always holds the reset value.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2574} // namespace
2575
2576static bool isDefinedByOneConstantOp(Value v) {
2577 if (auto c = v.getDefiningOp<ConstantOp>())
2578 return c.getValue().isOne();
2579 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2580 return sc.getValue();
2581 return false;
2582}
2583
/// A regreset whose reset signal is constantly asserted always holds its
/// reset value; replace it with a node of that value, dropping any writes.
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Types must match exactly to substitute the reset value for the register.
  auto resetValue = reg.getResetValue();
  if (reg.getType(0) != resetValue.getType())
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2600
2601void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2602 MLIRContext *context) {
2603 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2605 results.add(demoteForceableIfUnused<RegResetOp>);
2606}
2607
2608// Returns the value connected to a port, if there is only one.
static Value getPortFieldValue(Value port, StringRef name) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  Value value = {};
  for (auto *op : port.getUsers()) {
    // Memory port users are always subfield accesses.
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    // Require exactly one matching connect across all accesses of the field;
    // bail if a second driver is found.
    auto conn = getSingleConnectUserOf(portAccess);
    if (!conn || value)
      return {};
    value = conn.getSrc();
  }
  return value;
}
2626
2627// Returns true if the enable field of a port is set to false.
2628static bool isPortDisabled(Value port) {
2629 auto value = getPortFieldValue(port, "en");
2630 if (!value)
2631 return false;
2632 auto portConst = value.getDefiningOp<ConstantOp>();
2633 if (!portConst)
2634 return false;
2635 return portConst.getValue().isZero();
2636}
2637
2638// Returns true if the data output is unused.
2639static bool isPortUnused(Value port, StringRef data) {
2640 auto portTy = type_cast<BundleType>(port.getType());
2641 auto fieldIndex = portTy.getElementIndex(data);
2642 assert(fieldIndex && "missing enable flag on memory port");
2643
2644 for (auto *op : port.getUsers()) {
2645 auto portAccess = cast<SubfieldOp>(op);
2646 if (fieldIndex != portAccess.getFieldIndex())
2647 continue;
2648 if (!portAccess.use_empty())
2649 return false;
2650 }
2651
2652 return true;
2653}
2654
// Replaces all reads of the named port field with `value`, erasing the
// subfield accesses.
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    // Memory port users are always subfield accesses.
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2670
// Remove all accesses to a port which is no longer used.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = SpecialConstantOp::create(rewriter, port.getLoc(),
                                        ClockType::get(rewriter.getContext()),
                                        false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire.  If the port
  // is used in its entirety, replace it with a register (never written, so it
  // behaves as an undriven value). Otherwise, eliminate individual subfields
  // and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      auto ty = port.getType();
      auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    // Drop every connect that drives this field.
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2723
2724namespace {
// If memory has a known, but zero-width, data type, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    // Don't-touch memories must be preserved as-is.
    if (hasDontTouch(mem))
      return failure();

    // Only integer-typed memories with width exactly zero qualify.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are subfield accesses and thus safe to replace.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type. Replace each subfield access with
    // a wire instead.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports; drive them with a zero-width
          // constant so the replacement wires are not left undriven.
          auto zero = firrtl::ConstantOp::create(
              rewriter, wire.getLoc(),
              firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
          MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2767
// If memory has no write ports (or only write ports) and no observable file
// initialization, eliminate it.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Scan the ports; bail as soon as both a read and a write are observed,
    // or a port kind that both reads and writes appears.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    // Detach all port accesses, then drop the memory itself.
    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2811
// Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports, keeping the names and
    // annotations of the surviving ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port was dead, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = MemOp::create(
          rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Replace the dead ports with dummy wires, and remap surviving ports to
    // the corresponding results of the rebuilt memory.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2881
2882// Rewrite write-only read-write ports to write ports.
2883struct FoldReadWritePorts : public mlir::RewritePattern {
2884 FoldReadWritePorts(MLIRContext *context)
2885 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2886 LogicalResult matchAndRewrite(Operation *op,
2887 PatternRewriter &rewriter) const override {
2888 MemOp mem = cast<MemOp>(op);
2889 if (hasDontTouch(mem))
2890 return failure();
2891
2892 // Identify read-write ports whose read end is unused.
2893 llvm::SmallBitVector deadReads(mem.getNumResults());
2894 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2895 if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2896 continue;
2897 if (!mem.getPortAnnotation(i).empty())
2898 continue;
2899 if (isPortUnused(port, "rdata")) {
2900 deadReads.set(i);
2901 continue;
2902 }
2903 }
2904 if (deadReads.none())
2905 return failure();
2906
2907 SmallVector<Type> resultTypes;
2908 SmallVector<StringRef> portNames;
2909 SmallVector<Attribute> portAnnotations;
2910 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2911 if (deadReads[i])
2912 resultTypes.push_back(
2913 MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2914 MemOp::PortKind::Write, mem.getMaskBits()));
2915 else
2916 resultTypes.push_back(port.getType());
2917
2918 portNames.push_back(mem.getPortName(i));
2919 portAnnotations.push_back(mem.getPortAnnotation(i));
2920 }
2921
2922 auto newOp = MemOp::create(
2923 rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
2924 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2925 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2926 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2927 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2928
2929 for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2930 auto result = mem.getResult(i);
2931 auto newResult = newOp.getResult(i);
2932 if (deadReads[i]) {
2933 auto resultPortTy = type_cast<BundleType>(result.getType());
2934
2935 // Rewrite accesses to the old port field to accesses to a
2936 // corresponding field of the new port.
2937 auto replace = [&](StringRef toName, StringRef fromName) {
2938 auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2939 assert(fromFieldIndex && "missing enable flag on memory port");
2940
2941 auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
2942 newResult, toName);
2943 for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
2944 auto fromField = cast<SubfieldOp>(op);
2945 if (fromFieldIndex != fromField.getFieldIndex())
2946 continue;
2947 rewriter.replaceOp(fromField, toField.getResult());
2948 }
2949 };
2950
2951 replace("addr", "addr");
2952 replace("en", "en");
2953 replace("clk", "clk");
2954 replace("data", "wdata");
2955 replace("mask", "wmask");
2956
2957 // Remove the wmode field, replacing it with dummy wires.
2958 auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
2959 for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
2960 auto wmodeField = cast<SubfieldOp>(op);
2961 if (wmodeFieldIndex != wmodeField.getFieldIndex())
2962 continue;
2963 rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2964 }
2965 } else {
2966 rewriter.replaceAllUsesWith(result, newResult);
2967 }
2968 }
2969 rewriter.eraseOp(op);
2970 return success();
2971 }
2972};
2973
// Narrow a memory's data type to the bits that are actually read, compacting
// the used bit ranges and rewriting all port accesses accordingly.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // Requires an integer data type with a known, positive width.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // Mask of data bits observed by some read; `mapping` records each bit
    // select's low bit, later filled in with its compacted position.
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If every bit ends up used there is nothing to narrow.
          if (usedBits.all())
            return failure();

          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones, and collect the
    // contiguous used ranges so writes can later be sliced along them.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          SubfieldOp::create(rewriter, port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Slice the original value along the used ranges, high range first.
      SmallVector<Value> slices;
      for (auto &[start, end] : llvm::reverse(ranges)) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        slices.push_back(slice);
      }

      Value catOfSlices =
          rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
3196
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only depth-1 memories can be modelled by a single register.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    // Along the way, collect every port field access and the connects driving
    // them so they can all be erased at the end.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect the accesses of the listed fields; fail if any access has a
      // user that is not a connect.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All ports must agree on a single, known clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = WireOp::create(rewriter, loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    MatchingConnectOp::create(rewriter, loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        // Derive a stable name for each pipeline stage register.
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
                                 rewriter.getStringAttr(regName))
                       .getResult();
        MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    // Writes take effect after writeLatency cycles; one register stage is the
    // memory register itself, so pipeline the inputs by writeLatency - 1.
    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // The write side is active only when enabled AND in write mode.
        auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      SmallVector<Value> chunks;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
        chunks.push_back(chunk);
      }

      // Chunks were built low-to-high; reverse so the high chunk leads the
      // concatenation, then gate the whole write behind the enable.
      std::reverse(chunks.begin(), chunks.end());
      masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
      next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
    }
    // Cast the concatenated result back to the memory's data type and drive
    // the register with it.
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3391} // namespace
3392
3393void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3394 MLIRContext *context) {
3395 results
3396 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3397 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3398 context);
3399}
3400
3401//===----------------------------------------------------------------------===//
3402// Declarations
3403//===----------------------------------------------------------------------===//
3404
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The driver must be a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Preserve any dialect attributes across the replacement.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Rewrite the connect to drive either the constant (const-prop case) or the
  // mux's non-reset operand (hidden-reset case).
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3475
3476LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3477 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3478 succeeded(foldHiddenReset(op, rewriter)))
3479 return success();
3480
3481 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3482 return success();
3483
3484 return failure();
3485}
3486
3487//===----------------------------------------------------------------------===//
3488// Verification Ops.
3489//===----------------------------------------------------------------------===//
3490
3491static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3492 Value enable,
3493 PatternRewriter &rewriter,
3494 bool eraseIfZero) {
3495 // If the verification op is never enabled, delete it.
3496 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3497 if (constant.getValue().isZero()) {
3498 rewriter.eraseOp(op);
3499 return success();
3500 }
3501 }
3502
3503 // If the verification op is never triggered, delete it.
3504 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3505 if (constant.getValue().isZero() == eraseIfZero) {
3506 rewriter.eraseOp(op);
3507 return success();
3508 }
3509 }
3510
3511 return failure();
3512}
3513
3514template <class Op, bool EraseIfZero = false>
3515static LogicalResult canonicalizeImmediateVerifOp(Op op,
3516 PatternRewriter &rewriter) {
3517 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3518 EraseIfZero);
3519}
3520
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase asserts whose enable or predicate makes them statically inactive.
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
  // Declaratively defined pattern (generated `patterns` namespace).
  results.add<patterns::AssertXWhenX>(context);
}
3526
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase assumes whose enable or predicate makes them statically inactive.
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
  // Declaratively defined pattern (generated `patterns` namespace).
  results.add<patterns::AssumeXWhenX>(context);
}
3532
void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Erase intrinsics whose enable or predicate is statically inactive.
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
  // Declaratively defined pattern (generated `patterns` namespace).
  results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
}
3538
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Covers are erased when the predicate is constant *non*-zero, hence the
  // inverted EraseIfZero polarity compared to asserts/assumes.
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3543
3544//===----------------------------------------------------------------------===//
3545// InvalidValueOp
3546//===----------------------------------------------------------------------===//
3547
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       // Cvt is only absorbed when its operand is signed.
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       // Reductions are only absorbed when the operand width is known and
       // nonzero.
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the single user with a fresh invalid of its result type, then
    // delete both the user and this op.
    auto *modop = *op->user_begin();
    auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
                                      modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3578
3579OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3580 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3581 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3582 return {};
3583}
3584
3585//===----------------------------------------------------------------------===//
3586// ClockGateIntrinsicOp
3587//===----------------------------------------------------------------------===//
3588
3589OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3590 // Forward the clock if one of the enables is always true.
3591 if (isConstantOne(adaptor.getEnable()) ||
3592 isConstantOne(adaptor.getTestEnable()))
3593 return getInput();
3594
3595 // Fold to a constant zero clock if the enables are always false.
3596 if (isConstantZero(adaptor.getEnable()) &&
3597 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3598 return BoolAttr::get(getContext(), false);
3599
3600 // Forward constant zero clocks.
3601 if (isConstantZero(adaptor.getInput()))
3602 return BoolAttr::get(getContext(), false);
3603
3604 return {};
3605}
3606
3607LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3608 PatternRewriter &rewriter) {
3609 // Remove constant false test enable.
3610 if (auto testEnable = op.getTestEnable()) {
3611 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3612 if (constOp.getValue().isZero()) {
3613 rewriter.modifyOpInPlace(op,
3614 [&] { op.getTestEnableMutable().clear(); });
3615 return success();
3616 }
3617 }
3618 }
3619
3620 return failure();
3621}
3622
3623//===----------------------------------------------------------------------===//
3624// Reference Ops.
3625//===----------------------------------------------------------------------===//
3626
3627// refresolve(forceable.ref) -> forceable.data
3628static LogicalResult
3629canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3630 auto forceable = op.getRef().getDefiningOp<Forceable>();
3631 if (!forceable || !forceable.isForceable() ||
3632 op.getRef() != forceable.getDataRef() ||
3633 op.getType() != forceable.getDataType())
3634 return failure();
3635 rewriter.replaceAllUsesWith(op, forceable.getData());
3636 return success();
3637}
3638
3639void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3640 MLIRContext *context) {
3641 results.insert<patterns::RefResolveOfRefSend>(context);
3642 results.insert(canonicalizeRefResolveOfForceable);
3643}
3644
3645OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3646 // RefCast is unnecessary if types match.
3647 if (getInput().getType() == getType())
3648 return getInput();
3649 return {};
3650}
3651
3652static bool isConstantZero(Value operand) {
3653 auto constOp = operand.getDefiningOp<ConstantOp>();
3654 return constOp && constOp.getValue().isZero();
3655}
3656
3657template <typename Op>
3658static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3659 if (isConstantZero(op.getPredicate())) {
3660 rewriter.eraseOp(op);
3661 return success();
3662 }
3663 return failure();
3664}
3665
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // A force whose predicate is constant false never fires; erase it.
  results.add(eraseIfPredFalse<RefForceOp>);
}
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  // An initial force whose predicate is constant false never fires; erase it.
  results.add(eraseIfPredFalse<RefForceInitialOp>);
}
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // A release whose predicate is constant false never fires; erase it.
  results.add(eraseIfPredFalse<RefReleaseOp>);
}
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // An initial release whose predicate is constant false never fires;
  // erase it.
  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
}
3682
3683//===----------------------------------------------------------------------===//
3684// HasBeenResetIntrinsicOp
3685//===----------------------------------------------------------------------===//
3686
3687OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3688 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3689
3690 // Fold to zero if the reset is a constant. In this case the op is either
3691 // permanently in reset or never resets. Both mean that the reset never
3692 // finishes, so this op never returns true.
3693 if (adaptor.getReset())
3694 return getIntZerosAttr(UIntType::get(getContext(), 1));
3695
3696 // Fold to zero if the clock is a constant and the reset is synchronous. In
3697 // that case the reset will never be started.
3698 if (isUInt1(getReset().getType()) && adaptor.getClock())
3699 return getIntZerosAttr(UIntType::get(getContext(), 1));
3700
3701 return {};
3702}
3703
3704//===----------------------------------------------------------------------===//
3705// FPGAProbeIntrinsicOp
3706//===----------------------------------------------------------------------===//
3707
3708static bool isTypeEmpty(FIRRTLType type) {
3710 .Case<FVectorType>(
3711 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3712 .Case<BundleType>([&](auto ty) -> bool {
3713 for (auto elem : ty.getElements())
3714 if (!isTypeEmpty(elem.type))
3715 return false;
3716 return true;
3717 })
3718 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3719 .Default([](auto) -> bool { return false; });
3720}
3721
3722LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3723 PatternRewriter &rewriter) {
3724 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3725 if (!firrtlTy)
3726 return failure();
3727
3728 if (!isTypeEmpty(firrtlTy))
3729 return failure();
3730
3731 rewriter.eraseOp(op);
3732 return success();
3733}
3734
3735//===----------------------------------------------------------------------===//
3736// Layer Block Op
3737//===----------------------------------------------------------------------===//
3738
3739LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3740 PatternRewriter &rewriter) {
3741
3742 // If the layerblock is empty, erase it.
3743 if (op.getBody()->empty()) {
3744 rewriter.eraseOp(op);
3745 return success();
3746 }
3747
3748 return failure();
3749}
3750
3751//===----------------------------------------------------------------------===//
3752// Domain-related Ops
3753//===----------------------------------------------------------------------===//
3754
3755OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3756 // If no domains are specified, then forward the input to the result.
3757 if (getDomains().empty())
3758 return getInput();
3759
3760 return {};
3761}
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
BitsOfCat(MLIRContext *context)