CIRCT 22.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// This file implements the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace firrtl;
29
30// Drop writes to old and pass through passthrough to make patterns easier to
31// write.
32static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33 Value passthrough) {
34 SmallPtrSet<Operation *, 8> users;
35 for (auto *user : old.getUsers())
36 users.insert(user);
37 for (Operation *user : users)
38 if (auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
41 return passthrough;
42}
43
44// Return true if it is OK to propagate the name to the operation.
45// Non-pure operations such as instances, registers, and memories are not
46// allowed to update names for name stabilities and LEC reasons.
47static bool isOkToPropagateName(Operation *op) {
48 // Conservatively disallow operations with regions to prevent performance
49 // regression due to recursive calls to sub-regions in mlir::isPure.
50 if (op->getNumRegions() != 0)
51 return false;
52 return mlir::isPure(op) || isa<NodeOp, WireOp>(op);
53}
54
55// Move a name hint from a soon to be deleted operation to a new operation.
56// Pass through the new operation to make patterns easier to write. This cannot
57// move a name to a port (block argument), doing so would require rewriting all
58// instance sites as well as the module.
59static Value moveNameHint(OpResult old, Value passthrough) {
60 Operation *op = passthrough.getDefiningOp();
61 // This should handle ports, but it isn't clear we can change those in
62 // canonicalizers.
63 assert(op && "passthrough must be an operation");
64 Operation *oldOp = old.getOwner();
65 auto name = oldOp->getAttrOfType<StringAttr>("name");
66 if (name && !name.getValue().empty() && isOkToPropagateName(op))
67 op->setAttr("name", name);
68 return passthrough;
69}
70
71// Declarative canonicalization patterns
72namespace circt {
73namespace firrtl {
74namespace patterns {
75#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
76} // namespace patterns
77} // namespace firrtl
78} // namespace circt
79
80/// Return true if this operation's operands and results all have a known width.
81/// This only works for integer types.
82static bool hasKnownWidthIntTypes(Operation *op) {
83 auto resultType = type_cast<IntType>(op->getResult(0).getType());
84 if (!resultType.hasWidth())
85 return false;
86 for (Value operand : op->getOperands())
87 if (!type_cast<IntType>(operand.getType()).hasWidth())
88 return false;
89 return true;
90}
91
92/// Return true if this value is 1 bit UInt.
93static bool isUInt1(Type type) {
94 auto t = type_dyn_cast<UIntType>(type);
95 if (!t || !t.hasWidth() || t.getWidth() != 1)
96 return false;
97 return true;
98}
99
100/// Set the name of an op based on the best of two names: The current name, and
101/// the name passed in.
102static void updateName(PatternRewriter &rewriter, Operation *op,
103 StringAttr name) {
104 // Should never rename InstanceOp
105 if (!name || name.getValue().empty() || !isOkToPropagateName(op))
106 return;
107 assert((!isa<InstanceOp, RegOp, RegResetOp>(op)) && "Should never rename");
108 auto newName = name.getValue(); // old name is interesting
109 auto newOpName = op->getAttrOfType<StringAttr>("name");
110 // new name might not be interesting
111 if (newOpName)
112 newName = chooseName(newOpName.getValue(), name.getValue());
113 // Only update if needed
114 if (!newOpName || newOpName.getValue() != newName)
115 rewriter.modifyOpInPlace(
116 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
117}
118
119/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
120/// If a replaced op has a "name" attribute, this function propagates the name
121/// to the new value.
122static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
123 Value newValue) {
124 if (auto *newOp = newValue.getDefiningOp()) {
125 auto name = op->getAttrOfType<StringAttr>("name");
126 updateName(rewriter, newOp, name);
127 }
128 rewriter.replaceOp(op, newValue);
129}
130
131/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
132/// attribute. If a replaced op has a "name" attribute, this function propagates
133/// the name to the new value.
134template <typename OpTy, typename... Args>
135static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
136 Operation *op, Args &&...args) {
137 auto name = op->getAttrOfType<StringAttr>("name");
138 auto newOp =
139 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
140 updateName(rewriter, newOp, name);
141 return newOp;
142}
143
144/// Return true if the name is droppable. Note that this is different from
145/// `isUselessName` because non-useless names may be also droppable.
147 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
148 return namableOp.hasDroppableName();
149 return false;
150}
151
152/// Implicitly replace the operand to a constant folding operation with a const
153/// 0 in case the operand is non-constant but has a bit width 0, or if the
154/// operand is an invalid value.
155///
156/// This makes constant folding significantly easier, as we can simply pass the
157/// operands to an operation through this function to appropriately replace any
158/// zero-width dynamic values or invalid values with a constant of value 0.
159static std::optional<APSInt>
160getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
161 assert(type_cast<IntType>(operand.getType()) &&
162 "getExtendedConstant is limited to integer types");
163
164 // We never support constant folding to unknown width values.
165 if (destWidth < 0)
166 return {};
167
168 // Extension signedness follows the operand sign.
169 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
170 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
171
172 // If the operand is zero bits, then we can return a zero of the result
173 // type.
174 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
175 return APSInt(destWidth,
176 type_cast<IntType>(operand.getType()).isUnsigned());
177 return {};
178}
179
180/// Determine the value of a constant operand for the sake of constant folding.
181static std::optional<APSInt> getConstant(Attribute operand) {
182 if (!operand)
183 return {};
184 if (auto attr = dyn_cast<BoolAttr>(operand))
185 return APSInt(APInt(1, attr.getValue()));
186 if (auto attr = dyn_cast<IntegerAttr>(operand))
187 return attr.getAPSInt();
188 return {};
189}
190
191/// Determine whether a constant operand is a zero value for the sake of
192/// constant folding. This considers `invalidvalue` to be zero.
193static bool isConstantZero(Attribute operand) {
194 if (auto cst = getConstant(operand))
195 return cst->isZero();
196 return false;
197}
198
199/// Determine whether a constant operand is a one value for the sake of constant
200/// folding.
201static bool isConstantOne(Attribute operand) {
202 if (auto cst = getConstant(operand))
203 return cst->isOne();
204 return false;
205}
206
/// This is the policy for folding, which depends on the sort of operator we're
/// processing. (The `DivideOrShift` enumerator was dropped during extraction;
/// it is referenced by `constFoldFIRRTLBinaryOp` and the div/rem/shift folds
/// below, so it is restored here.)
enum class BinOpKind {
  Normal,        ///< Operands are extended to the result width.
  Compare,       ///< Operands are extended to the widest operand; result is i1.
  DivideOrShift, ///< Compute at the widest width, then truncate to the result.
};
214
215/// Applies the constant folding function `calculate` to the given operands.
216///
217/// Sign or zero extends the operands appropriately to the bitwidth of the
218/// result type if \p useDstWidth is true, else to the larger of the two operand
219/// bit widths and depending on whether the operation is to be performed on
220/// signed or unsigned operands.
221static Attribute constFoldFIRRTLBinaryOp(
222 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
223 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
224 assert(operands.size() == 2 && "binary op takes two operands");
225
226 // We cannot fold something to an unknown width.
227 auto resultType = type_cast<IntType>(op->getResult(0).getType());
228 if (resultType.getWidthOrSentinel() < 0)
229 return {};
230
231 // Any binary op returning i0 is 0.
232 if (resultType.getWidthOrSentinel() == 0)
233 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
234
235 // Determine the operand widths. This is either dictated by the operand type,
236 // or if that type is an unsized integer, by the actual bits necessary to
237 // represent the constant value.
238 auto lhsWidth =
239 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
240 auto rhsWidth =
241 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
242 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
243 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
244 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
245 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
246
247 // Compares extend the operands to the widest of the operand types, not to the
248 // result type.
249 int32_t operandWidth;
250 switch (opKind) {
252 operandWidth = resultType.getWidthOrSentinel();
253 break;
255 // Compares compute with the widest operand, not at the destination type
256 // (which is always i1).
257 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
258 break;
260 operandWidth =
261 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
262 break;
263 }
264
265 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
266 if (!lhs)
267 return {};
268 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
269 if (!rhs)
270 return {};
271
272 APInt resultValue = calculate(*lhs, *rhs);
273
274 // If the result type is smaller than the computation then we need to
275 // narrow the constant after the calculation.
276 if (opKind == BinOpKind::DivideOrShift)
277 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
278
279 assert((unsigned)resultType.getWidthOrSentinel() ==
280 resultValue.getBitWidth());
281 return getIntAttr(resultType, resultValue);
282}
283
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function. Any `invalidvalue` in
/// the input is mapped to a constant zero. The value returned from the callback
/// is used as the replacement for `op`, and an additional pad operation is
/// inserted if necessary. Does nothing if the result of `op` is of unknown
/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants. Non-constant operands are
  // represented by a null attribute in `constOperands`.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  if (auto cst = dyn_cast<Attribute>(result))
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = cast<Value>(result);

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = PadPrimOp::create(rewriter, op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = AsSIntPrimOp::create(rewriter, op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = AsUIntPrimOp::create(rewriter, op->getLoc(), resultValue);

  assert(type == resultValue.getType() && "canonicalization changed type");
  // Replace the op, propagating any "name" hint onto the replacement.
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
347
348/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
349/// value if `bitWidth` is 0.
350static APInt getMaxUnsignedValue(unsigned bitWidth) {
351 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
352}
353
354/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
355/// value if `bitWidth` is 0.
356static APInt getMinSignedValue(unsigned bitWidth) {
357 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
358}
359
360/// Get the largest signed value of a given bit width. Returns a 1-bit zero
361/// value if `bitWidth` is 0.
362static APInt getMaxSignedValue(unsigned bitWidth) {
363 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
364}
365
366//===----------------------------------------------------------------------===//
367// Fold Hooks
368//===----------------------------------------------------------------------===//
369
/// Constants fold to their value attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
374
/// Special constants (clock/reset) fold to their value attribute.
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
379
/// Aggregate constants fold to their fields attribute.
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getFieldsAttr();
}
384
/// String constants fold to their value attribute.
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
389
/// FIRRTL property integer constants fold to their value attribute.
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
394
/// Boolean property constants fold to their value attribute.
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
399
/// Double property constants fold to their value attribute.
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
404
405//===----------------------------------------------------------------------===//
406// Binary Operators
407//===----------------------------------------------------------------------===//
408
409OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
411 *this, adaptor.getOperands(), BinOpKind::Normal,
412 [=](const APSInt &a, const APSInt &b) { return a + b; });
413}
414
/// Register the declarative (TableGen-generated) canonicalization patterns.
void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
                 patterns::AddOfSelf, patterns::AddOfPad>(context);
}
420
421OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
423 *this, adaptor.getOperands(), BinOpKind::Normal,
424 [=](const APSInt &a, const APSInt &b) { return a - b; });
425}
426
/// Register the declarative (TableGen-generated) canonicalization patterns.
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
                 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
                 patterns::SubOfPadL, patterns::SubOfPadR>(context);
}
433
434OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
435 // mul(x, 0) -> 0
436 //
437 // This is legal because it aligns with the Scala FIRRTL Compiler
438 // interpretation of lowering invalid to constant zero before constant
439 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
440 // multiplication this way and will emit "x * 0".
441 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
442 return getIntZerosAttr(getType());
443
445 *this, adaptor.getOperands(), BinOpKind::Normal,
446 [=](const APSInt &a, const APSInt &b) { return a * b; });
447}
448
449OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
450 /// div(x, x) -> 1
451 ///
452 /// Division by zero is undefined in the FIRRTL specification. This fold
453 /// exploits that fact to optimize self division to one. Note: this should
454 /// supersede any division with invalid or zero. Division of invalid by
455 /// invalid should be one.
456 if (getLhs() == getRhs()) {
457 auto width = getType().base().getWidthOrSentinel();
458 if (width == -1)
459 width = 2;
460 // Only fold if we have at least 1 bit of width to represent the `1` value.
461 if (width != 0)
462 return getIntAttr(getType(), APInt(width, 1));
463 }
464
465 // div(0, x) -> 0
466 //
467 // This is legal because it aligns with the Scala FIRRTL Compiler
468 // interpretation of lowering invalid to constant zero before constant
469 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
470 // division this way and will emit "0 / x".
471 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
472 return getIntZerosAttr(getType());
473
474 /// div(x, 1) -> x : (uint, uint) -> uint
475 ///
476 /// UInt division by one returns the numerator. SInt division can't
477 /// be folded here because it increases the return type bitwidth by
478 /// one and requires sign extension (a new op).
479 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
480 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
481 return getLhs();
482
484 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
485 [=](const APSInt &a, const APSInt &b) -> APInt {
486 if (!!b)
487 return a / b;
488 return APInt(a.getBitWidth(), 0);
489 });
490}
491
492OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
493 // rem(x, x) -> 0
494 //
495 // Division by zero is undefined in the FIRRTL specification. This fold
496 // exploits that fact to optimize self division remainder to zero. Note:
497 // this should supersede any division with invalid or zero. Remainder of
498 // division of invalid by invalid should be zero.
499 if (getLhs() == getRhs())
500 return getIntZerosAttr(getType());
501
502 // rem(0, x) -> 0
503 //
504 // This is legal because it aligns with the Scala FIRRTL Compiler
505 // interpretation of lowering invalid to constant zero before constant
506 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
507 // division this way and will emit "0 % x".
508 if (isConstantZero(adaptor.getLhs()))
509 return getIntZerosAttr(getType());
510
512 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
513 [=](const APSInt &a, const APSInt &b) -> APInt {
514 if (!!b)
515 return a % b;
516 return APInt(a.getBitWidth(), 0);
517 });
518}
519
520OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
522 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
523 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
524}
525
526OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
528 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
529 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
530}
531
532OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
534 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
535 [=](const APSInt &a, const APSInt &b) -> APInt {
536 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
537 : a.ashr(b);
538 });
539}
540
541// TODO: Move to DRR.
542OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
543 if (auto rhsCst = getConstant(adaptor.getRhs())) {
544 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
545 if (rhsCst->isZero())
546 return getIntZerosAttr(getType());
547
548 /// and(x, -1) -> x
549 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
551 return getLhs();
552 }
553
554 if (auto lhsCst = getConstant(adaptor.getLhs())) {
555 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
556 if (lhsCst->isZero())
557 return getIntZerosAttr(getType());
558
559 /// and(-1, x) -> x
560 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
561 getRhs().getType() == getType())
562 return getRhs();
563 }
564
565 /// and(x, x) -> x
566 if (getLhs() == getRhs() && getRhs().getType() == getType())
567 return getRhs();
568
570 *this, adaptor.getOperands(), BinOpKind::Normal,
571 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
572}
573
/// Register the declarative (TableGen-generated) canonicalization patterns.
void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
              patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
              patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
}
581
582OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
583 if (auto rhsCst = getConstant(adaptor.getRhs())) {
584 /// or(x, 0) -> x
585 if (rhsCst->isZero() && getLhs().getType() == getType())
586 return getLhs();
587
588 /// or(x, -1) -> -1
589 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
590 getLhs().getType() == getType())
591 return getRhs();
592 }
593
594 if (auto lhsCst = getConstant(adaptor.getLhs())) {
595 /// or(0, x) -> x
596 if (lhsCst->isZero() && getRhs().getType() == getType())
597 return getRhs();
598
599 /// or(-1, x) -> -1
600 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
601 getRhs().getType() == getType())
602 return getLhs();
603 }
604
605 /// or(x, x) -> x
606 if (getLhs() == getRhs() && getRhs().getType() == getType())
607 return getRhs();
608
610 *this, adaptor.getOperands(), BinOpKind::Normal,
611 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
612}
613
/// Register the declarative (TableGen-generated) canonicalization patterns.
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
                 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
                 patterns::OrOrr>(context);
}
620
621OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
622 /// xor(x, 0) -> x
623 if (auto rhsCst = getConstant(adaptor.getRhs()))
624 if (rhsCst->isZero() &&
625 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
626 return getLhs();
627
628 /// xor(x, 0) -> x
629 if (auto lhsCst = getConstant(adaptor.getLhs()))
630 if (lhsCst->isZero() &&
631 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
632 return getRhs();
633
634 /// xor(x, x) -> 0
635 if (getLhs() == getRhs())
636 return getIntAttr(
637 getType(),
638 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
639
641 *this, adaptor.getOperands(), BinOpKind::Normal,
642 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
643}
644
/// Register the declarative (TableGen-generated) canonicalization patterns.
void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::extendXor, patterns::moveConstXor,
                 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
      context);
}
651
/// Register the declarative (TableGen-generated) canonicalization patterns.
void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::LEQWithConstLHS>(context);
}
656
657OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
658 bool isUnsigned = getLhs().getType().base().isUnsigned();
659
660 // leq(x, x) -> 1
661 if (getLhs() == getRhs())
662 return getIntAttr(getType(), APInt(1, 1));
663
664 // Comparison against constant outside type bounds.
665 if (auto width = getLhs().getType().base().getWidth()) {
666 if (auto rhsCst = getConstant(adaptor.getRhs())) {
667 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
668 commonWidth = std::max(commonWidth, 1);
669
670 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
671 // This can never occur since const is unsigned and cannot be less than 0.
672
673 // leq(x, const) -> 0 where const < minValue of the signed type of x
674 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
675 .slt(getMinSignedValue(*width).sext(commonWidth)))
676 return getIntAttr(getType(), APInt(1, 0));
677
678 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
679 if (isUnsigned && rhsCst->zext(commonWidth)
680 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
681 return getIntAttr(getType(), APInt(1, 1));
682
683 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
684 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
685 .sge(getMaxSignedValue(*width).sext(commonWidth)))
686 return getIntAttr(getType(), APInt(1, 1));
687 }
688 }
689
691 *this, adaptor.getOperands(), BinOpKind::Compare,
692 [=](const APSInt &a, const APSInt &b) -> APInt {
693 return APInt(1, a <= b);
694 });
695}
696
/// Register the declarative (TableGen-generated) canonicalization patterns.
void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::LTWithConstLHS>(context);
}
701
702OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
703 IntType lhsType = getLhs().getType();
704 bool isUnsigned = lhsType.isUnsigned();
705
706 // lt(x, x) -> 0
707 if (getLhs() == getRhs())
708 return getIntAttr(getType(), APInt(1, 0));
709
710 // lt(x, 0) -> 0 when x is unsigned
711 if (auto rhsCst = getConstant(adaptor.getRhs())) {
712 if (rhsCst->isZero() && lhsType.isUnsigned())
713 return getIntAttr(getType(), APInt(1, 0));
714 }
715
716 // Comparison against constant outside type bounds.
717 if (auto width = lhsType.getWidth()) {
718 if (auto rhsCst = getConstant(adaptor.getRhs())) {
719 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
720 commonWidth = std::max(commonWidth, 1);
721
722 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
723 // Handled explicitly above.
724
725 // lt(x, const) -> 0 where const <= minValue of the signed type of x
726 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
727 .sle(getMinSignedValue(*width).sext(commonWidth)))
728 return getIntAttr(getType(), APInt(1, 0));
729
730 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
731 if (isUnsigned && rhsCst->zext(commonWidth)
732 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
733 return getIntAttr(getType(), APInt(1, 1));
734
735 // lt(x, const) -> 1 where const > maxValue of the signed type of x
736 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
737 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
738 return getIntAttr(getType(), APInt(1, 1));
739 }
740 }
741
743 *this, adaptor.getOperands(), BinOpKind::Compare,
744 [=](const APSInt &a, const APSInt &b) -> APInt {
745 return APInt(1, a < b);
746 });
747}
748
/// Register the declarative (TableGen-generated) canonicalization patterns.
void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::GEQWithConstLHS>(context);
}
753
754OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
755 IntType lhsType = getLhs().getType();
756 bool isUnsigned = lhsType.isUnsigned();
757
758 // geq(x, x) -> 1
759 if (getLhs() == getRhs())
760 return getIntAttr(getType(), APInt(1, 1));
761
762 // geq(x, 0) -> 1 when x is unsigned
763 if (auto rhsCst = getConstant(adaptor.getRhs())) {
764 if (rhsCst->isZero() && isUnsigned)
765 return getIntAttr(getType(), APInt(1, 1));
766 }
767
768 // Comparison against constant outside type bounds.
769 if (auto width = lhsType.getWidth()) {
770 if (auto rhsCst = getConstant(adaptor.getRhs())) {
771 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
772 commonWidth = std::max(commonWidth, 1);
773
774 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
775 if (isUnsigned && rhsCst->zext(commonWidth)
776 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
777 return getIntAttr(getType(), APInt(1, 0));
778
779 // geq(x, const) -> 0 where const > maxValue of the signed type of x
780 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
781 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
782 return getIntAttr(getType(), APInt(1, 0));
783
784 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
785 // Handled explicitly above.
786
787 // geq(x, const) -> 1 where const <= minValue of the signed type of x
788 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
789 .sle(getMinSignedValue(*width).sext(commonWidth)))
790 return getIntAttr(getType(), APInt(1, 1));
791 }
792 }
793
795 *this, adaptor.getOperands(), BinOpKind::Compare,
796 [=](const APSInt &a, const APSInt &b) -> APInt {
797 return APInt(1, a >= b);
798 });
799}
800
/// Register the declarative (TableGen-generated) canonicalization patterns.
void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.insert<patterns::GTWithConstLHS>(context);
}
805
806OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
807 IntType lhsType = getLhs().getType();
808 bool isUnsigned = lhsType.isUnsigned();
809
810 // gt(x, x) -> 0
811 if (getLhs() == getRhs())
812 return getIntAttr(getType(), APInt(1, 0));
813
814 // Comparison against constant outside type bounds.
815 if (auto width = lhsType.getWidth()) {
816 if (auto rhsCst = getConstant(adaptor.getRhs())) {
817 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
818 commonWidth = std::max(commonWidth, 1);
819
820 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
821 if (isUnsigned && rhsCst->zext(commonWidth)
822 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
823 return getIntAttr(getType(), APInt(1, 0));
824
825 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
826 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
827 .sge(getMaxSignedValue(*width).sext(commonWidth)))
828 return getIntAttr(getType(), APInt(1, 0));
829
830 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
831 // This can never occur since const is unsigned and cannot be less than 0.
832
833 // gt(x, const) -> 1 where const < minValue of the signed type of x
834 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
835 .slt(getMinSignedValue(*width).sext(commonWidth)))
836 return getIntAttr(getType(), APInt(1, 1));
837 }
838 }
839
841 *this, adaptor.getOperands(), BinOpKind::Compare,
842 [=](const APSInt &a, const APSInt &b) -> APInt {
843 return APInt(1, a > b);
844 });
845}
846
847OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
848 // eq(x, x) -> 1
849 if (getLhs() == getRhs())
850 return getIntAttr(getType(), APInt(1, 1));
851
852 if (auto rhsCst = getConstant(adaptor.getRhs())) {
853 /// eq(x, 1) -> x when x is 1 bit.
854 /// TODO: Support SInt<1> on the LHS etc.
855 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
856 getRhs().getType() == getType())
857 return getLhs();
858 }
859
861 *this, adaptor.getOperands(), BinOpKind::Compare,
862 [=](const APSInt &a, const APSInt &b) -> APInt {
863 return APInt(1, a == b);
864 });
865}
866
867LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
868 return canonicalizePrimOp(
869 op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
870 if (auto rhsCst = getConstant(operands[1])) {
871 auto width = op.getLhs().getType().getBitWidthOrSentinel();
872
873 // eq(x, 0) -> not(x) when x is 1 bit.
874 if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
875 op.getRhs().getType() == op.getType()) {
876 return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
877 .getResult();
878 }
879
880 // eq(x, 0) -> not(orr(x)) when x is >1 bit
881 if (rhsCst->isZero() && width > 1) {
882 auto orrOp = OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
883 return NotPrimOp::create(rewriter, op.getLoc(), orrOp).getResult();
884 }
885
886 // eq(x, ~0) -> andr(x) when x is >1 bit
887 if (rhsCst->isAllOnes() && width > 1 &&
888 op.getLhs().getType() == op.getRhs().getType()) {
889 return AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
890 .getResult();
891 }
892 }
893 return {};
894 });
895}
896
897OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
898 // neq(x, x) -> 0
899 if (getLhs() == getRhs())
900 return getIntAttr(getType(), APInt(1, 0));
901
902 if (auto rhsCst = getConstant(adaptor.getRhs())) {
903 /// neq(x, 0) -> x when x is 1 bit.
904 /// TODO: Support SInt<1> on the LHS etc.
905 if (rhsCst->isZero() && getLhs().getType() == getType() &&
906 getRhs().getType() == getType())
907 return getLhs();
908 }
909
911 *this, adaptor.getOperands(), BinOpKind::Compare,
912 [=](const APSInt &a, const APSInt &b) -> APInt {
913 return APInt(1, a != b);
914 });
915}
916
/// Canonicalizations of neq against constant RHS values that require creating
/// new operations (not, orr, andr) and therefore cannot be plain folds.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return NotPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return OrRPrimOp::create(rewriter, op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x))) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp =
                AndRPrimOp::create(rewriter, op.getLoc(), op.getLhs());
            return NotPrimOp::create(rewriter, op.getLoc(), andrOp).getResult();
          }
        }

        return {};
      });
}
948
949OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
950 // TODO: implement constant folding, etc.
951 // Tracked in https://github.com/llvm/circt/issues/6696.
952 return {};
953}
954
955OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
956 // TODO: implement constant folding, etc.
957 // Tracked in https://github.com/llvm/circt/issues/6724.
958 return {};
959}
960
961OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
962 if (auto rhsCst = getConstant(adaptor.getRhs())) {
963 if (auto lhsCst = getConstant(adaptor.getLhs())) {
964
965 return IntegerAttr::get(
966 IntegerType::get(getContext(), lhsCst->getBitWidth()),
967 lhsCst->ashr(*rhsCst));
968 }
969
970 if (rhsCst->isZero())
971 return getLhs();
972 }
973
974 return {};
975}
976
977OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
978 if (auto rhsCst = getConstant(adaptor.getRhs())) {
979 // Constant folding
980 if (auto lhsCst = getConstant(adaptor.getLhs()))
981
982 return IntegerAttr::get(
983 IntegerType::get(getContext(), lhsCst->getBitWidth()),
984 lhsCst->shl(*rhsCst));
985
986 // integer.shl(x, 0) -> x
987 if (rhsCst->isZero())
988 return getLhs();
989 }
990
991 return {};
992}
993
994//===----------------------------------------------------------------------===//
995// Unary Operators
996//===----------------------------------------------------------------------===//
997
998OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
999 auto base = getInput().getType();
1000 auto w = getBitWidth(base);
1001 if (w)
1002 return getIntAttr(getType(), APInt(32, *w));
1003 return {};
1004}
1005
1006OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
1007 // No constant can be 'x' by definition.
1008 if (auto cst = getConstant(adaptor.getArg()))
1009 return getIntAttr(getType(), APInt(1, 0));
1010 return {};
1011}
1012
1013OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
1014 // No effect.
1015 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1016 return getInput();
1017
1018 // Be careful to only fold the cast into the constant if the size is known.
1019 // Otherwise width inference may produce differently-sized constants if the
1020 // sign changes.
1021 if (getType().base().hasWidth())
1022 if (auto cst = getConstant(adaptor.getInput()))
1023 return getIntAttr(getType(), *cst);
1024
1025 return {};
1026}
1027
/// Register DRR canonicalization patterns for asSInt.
void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.insert<patterns::StoUtoS>(context);
}
1032
1033OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1034 // No effect.
1035 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1036 return getInput();
1037
1038 // Be careful to only fold the cast into the constant if the size is known.
1039 // Otherwise width inference may produce differently-sized constants if the
1040 // sign changes.
1041 if (getType().base().hasWidth())
1042 if (auto cst = getConstant(adaptor.getInput()))
1043 return getIntAttr(getType(), *cst);
1044
1045 return {};
1046}
1047
/// Register DRR canonicalization patterns for asUInt.
void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.insert<patterns::UtoStoU>(context);
}
1052
1053OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1054 // No effect.
1055 if (getInput().getType() == getType())
1056 return getInput();
1057
1058 // Constant fold.
1059 if (auto cst = getConstant(adaptor.getInput()))
1060 return BoolAttr::get(getContext(), cst->getBoolValue());
1061
1062 return {};
1063}
1064
1065OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1066 // No effect.
1067 if (getInput().getType() == getType())
1068 return getInput();
1069
1070 // Constant fold.
1071 if (auto cst = getConstant(adaptor.getInput()))
1072 return BoolAttr::get(getContext(), cst->getBoolValue());
1073
1074 return {};
1075}
1076
1077OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1078 if (!hasKnownWidthIntTypes(*this))
1079 return {};
1080
1081 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1082 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1083 getType().base().getWidthOrSentinel()))
1084 return getIntAttr(getType(), *cst);
1085
1086 return {};
1087}
1088
/// Register DRR canonicalization patterns for cvt.
void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
}
1093
1094OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1095 if (!hasKnownWidthIntTypes(*this))
1096 return {};
1097
1098 // FIRRTL negate always adds a bit.
1099 // -x ---> 0-sext(x) or 0-zext(x)
1100 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1101 getType().base().getWidthOrSentinel()))
1102 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1103
1104 return {};
1105}
1106
1107OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1108 if (!hasKnownWidthIntTypes(*this))
1109 return {};
1110
1111 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1112 getType().base().getWidthOrSentinel()))
1113 return getIntAttr(getType(), ~*cst);
1114
1115 return {};
1116}
1117
/// Register DRR canonicalization patterns for not, including pushing the
/// negation into comparisons (not(eq) -> neq, etc.).
void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
                 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
                 patterns::NotGt>(context);
}
1124
/// Base rewrite pattern for folding a bitwise reduction (andr/orr/xorr) whose
/// operand is a `cat`. Constant cat operands are delegated to the subclass,
/// which may short-circuit the whole reduction or keep the constant as one of
/// the remaining operands.
class ReductionCat : public mlir::RewritePattern {
public:
  ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
      : RewritePattern(opName, 0, context) {}

  /// Handle a constant operand in the cat operation.
  /// Returns true if the entire reduction can be replaced with a constant.
  /// May add non-zero constants to the remaining operands list.
  virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
                              ConstantOp constantOp,
                              SmallVectorImpl<Value> &remaining) const = 0;

  /// Return the unit value for this reduction operation, i.e. the result of
  /// reducing an empty operand list.
  virtual bool getIdentityValue() const = 0;

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    // Check if the operand is a cat operation
    auto catOp = op->getOperand(0).getDefiningOp<CatPrimOp>();
    if (!catOp)
      return failure();

    SmallVector<Value> nonConstantOperands;

    // Process each operand of the cat operation
    for (auto operand : catOp.getInputs()) {
      if (auto constantOp = operand.getDefiningOp<ConstantOp>()) {
        // Handle constant operands - may short-circuit the entire operation
        if (handleConstant(rewriter, op, constantOp, nonConstantOperands))
          return success();
      } else {
        // Keep non-constant operands for further processing
        nonConstantOperands.push_back(operand);
      }
    }

    // If no operands remain, replace with identity value
    if (nonConstantOperands.empty()) {
      replaceOpWithNewOpAndCopyName<ConstantOp>(
          rewriter, op, cast<IntType>(op->getResult(0).getType()),
          APInt(1, getIdentityValue()));
      return success();
    }

    // If only one operand remains, apply reduction directly to it
    if (nonConstantOperands.size() == 1) {
      rewriter.modifyOpInPlace(
          op, [&] { op->setOperand(0, nonConstantOperands.front()); });
      return success();
    }

    // Multiple operands remain - optimize only when cat has a single use.
    // Rewriting a multiply-used cat would pessimize its other users.
    if (catOp->hasOneUse() &&
        nonConstantOperands.size() < catOp->getNumOperands()) {
      replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, catOp,
                                               nonConstantOperands);
      return success();
    }
    return failure();
  }
};
1187
1188class OrRCat : public ReductionCat {
1189public:
1190 OrRCat(MLIRContext *context)
1191 : ReductionCat(context, OrRPrimOp::getOperationName()) {}
1192 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1193 ConstantOp value,
1194 SmallVectorImpl<Value> &remaining) const override {
1195 if (value.getValue().isZero())
1196 return false;
1197
1198 replaceOpWithNewOpAndCopyName<ConstantOp>(
1199 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1200 llvm::APInt(1, 1));
1201 return true;
1202 }
1203 bool getIdentityValue() const override { return false; }
1204};
1205
1206class AndRCat : public ReductionCat {
1207public:
1208 AndRCat(MLIRContext *context)
1209 : ReductionCat(context, AndRPrimOp::getOperationName()) {}
1210 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1211 ConstantOp value,
1212 SmallVectorImpl<Value> &remaining) const override {
1213 if (value.getValue().isAllOnes())
1214 return false;
1215
1216 replaceOpWithNewOpAndCopyName<ConstantOp>(
1217 rewriter, op, cast<IntType>(op->getResult(0).getType()),
1218 llvm::APInt(1, 0));
1219 return true;
1220 }
1221 bool getIdentityValue() const override { return true; }
1222};
1223
1224class XorRCat : public ReductionCat {
1225public:
1226 XorRCat(MLIRContext *context)
1227 : ReductionCat(context, XorRPrimOp::getOperationName()) {}
1228 bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op,
1229 ConstantOp value,
1230 SmallVectorImpl<Value> &remaining) const override {
1231 if (value.getValue().isZero())
1232 return false;
1233 remaining.push_back(value);
1234 return false;
1235 }
1236 bool getIdentityValue() const override { return false; }
1237};
1238
1239OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1240 if (!hasKnownWidthIntTypes(*this))
1241 return {};
1242
1243 if (getInput().getType().getBitWidthOrSentinel() == 0)
1244 return getIntAttr(getType(), APInt(1, 1));
1245
1246 // x == -1
1247 if (auto cst = getConstant(adaptor.getInput()))
1248 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1249
1250 // one bit is identity. Only applies to UInt since we can't make a cast
1251 // here.
1252 if (isUInt1(getInput().getType()))
1253 return getInput();
1254
1255 return {};
1256}
1257
/// Register canonicalization patterns for andr, including the cat-folding
/// pattern AndRCat defined above.
void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
                 patterns::AndRPadS, patterns::AndRCatAndR_left,
                 patterns::AndRCatAndR_right, AndRCat>(context);
}
1264
1265OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1266 if (!hasKnownWidthIntTypes(*this))
1267 return {};
1268
1269 if (getInput().getType().getBitWidthOrSentinel() == 0)
1270 return getIntAttr(getType(), APInt(1, 0));
1271
1272 // x != 0
1273 if (auto cst = getConstant(adaptor.getInput()))
1274 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1275
1276 // one bit is identity. Only applies to UInt since we can't make a cast
1277 // here.
1278 if (isUInt1(getInput().getType()))
1279 return getInput();
1280
1281 return {};
1282}
1283
/// Register canonicalization patterns for orr, including the cat-folding
/// pattern OrRCat defined above.
void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
                 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right, OrRCat>(
      context);
}
1290
1291OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1292 if (!hasKnownWidthIntTypes(*this))
1293 return {};
1294
1295 if (getInput().getType().getBitWidthOrSentinel() == 0)
1296 return getIntAttr(getType(), APInt(1, 0));
1297
1298 // popcount(x) & 1
1299 if (auto cst = getConstant(adaptor.getInput()))
1300 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1301
1302 // one bit is identity. Only applies to UInt since we can't make a cast here.
1303 if (isUInt1(getInput().getType()))
1304 return getInput();
1305
1306 return {};
1307}
1308
/// Register canonicalization patterns for xorr, including the cat-folding
/// pattern XorRCat defined above.
void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results
      .insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
              patterns::XorRCatXorR_left, patterns::XorRCatXorR_right, XorRCat>(
          context);
}
1316
1317//===----------------------------------------------------------------------===//
1318// Other Operators
1319//===----------------------------------------------------------------------===//
1320
1321OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1322 auto inputs = getInputs();
1323 auto inputAdaptors = adaptor.getInputs();
1324
1325 // If no inputs, return 0-bit value
1326 if (inputs.empty())
1327 return getIntZerosAttr(getType());
1328
1329 // If single input and same type, return it
1330 if (inputs.size() == 1 && inputs[0].getType() == getType())
1331 return inputs[0];
1332
1333 // Make sure it's safe to fold.
1334 if (!hasKnownWidthIntTypes(*this))
1335 return {};
1336
1337 // Filter out zero-width operands
1338 SmallVector<Value> nonZeroInputs;
1339 SmallVector<Attribute> nonZeroAttributes;
1340 bool allConstant = true;
1341 for (auto [input, attr] : llvm::zip(inputs, inputAdaptors)) {
1342 auto inputType = type_cast<IntType>(input.getType());
1343 if (inputType.getBitWidthOrSentinel() != 0) {
1344 nonZeroInputs.push_back(input);
1345 if (!attr)
1346 allConstant = false;
1347 if (nonZeroInputs.size() > 1 && !allConstant)
1348 return {};
1349 }
1350 }
1351
1352 // If all inputs were zero-width, return 0-bit value
1353 if (nonZeroInputs.empty())
1354 return getIntZerosAttr(getType());
1355
1356 // If only one non-zero input and it has the same type as result, return it
1357 if (nonZeroInputs.size() == 1 && nonZeroInputs[0].getType() == getType())
1358 return nonZeroInputs[0];
1359
1360 if (!hasKnownWidthIntTypes(*this))
1361 return {};
1362
1363 // Constant fold cat - concatenate all constant operands
1364 SmallVector<APInt> constants;
1365 for (auto inputAdaptor : inputAdaptors) {
1366 if (auto cst = getConstant(inputAdaptor))
1367 constants.push_back(*cst);
1368 else
1369 return {}; // Not all operands are constant
1370 }
1371
1372 assert(!constants.empty());
1373 // Concatenate all constants from left to right
1374 APInt result = constants[0];
1375 for (size_t i = 1; i < constants.size(); ++i)
1376 result = result.concat(constants[i]);
1377
1378 return getIntAttr(getType(), result);
1379}
1380
/// Register DRR canonicalization patterns for dshl.
void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<patterns::DShlOfConstant>(context);
}
1385
/// Register DRR canonicalization patterns for dshr.
void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<patterns::DShrOfConstant>(context);
}
1390
1391namespace {
1392
1393/// Canonicalize a cat tree into single cat operation.
class FlattenCat : public mlir::RewritePattern {
public:
  FlattenCat(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    // Only flatten when all widths are known and the result is non-zero-width.
    if (!hasKnownWidthIntTypes(cat) ||
        cat.getType().getBitWidthOrSentinel() == 0)
      return failure();

    // If this is not a root of concat tree, skip. The root will flatten the
    // entire tree in one rewrite.
    if (cat->hasOneUse() && isa<CatPrimOp>(*cat->getUsers().begin()))
      return failure();

    // Try flattening the cat tree with a depth-first walk. Operands are
    // pushed in reverse so they pop in source (left-to-right) order.
    SmallVector<Value> operands;
    SmallVector<Value> worklist;
    auto pushOperands = [&worklist](CatPrimOp op) {
      for (auto operand : llvm::reverse(op.getInputs()))
        worklist.push_back(operand);
    };
    pushOperands(cat);
    // Track whether both signednesses appear among the leaves.
    bool hasSigned = false, hasUnsigned = false;
    while (!worklist.empty()) {
      auto value = worklist.pop_back_val();
      auto catOp = value.getDefiningOp<CatPrimOp>();
      if (!catOp) {
        // Leaf value: record it and note its signedness.
        operands.push_back(value);
        (type_isa<UIntType>(value.getType()) ? hasUnsigned : hasSigned) = true;
        continue;
      }

      // Nested cat: expand it instead of keeping it as an operand.
      pushOperands(catOp);
    }

    // Helper function to cast signed values to unsigned. CatPrimOp converts
    // signed values to unsigned so we need to do it explicitly here.
    auto castToUIntIfSigned = [&](Value value) -> Value {
      if (type_isa<UIntType>(value.getType()))
        return value;
      return AsUIntPrimOp::create(rewriter, value.getLoc(), value);
    };

    assert(operands.size() >= 1 && "zero width cast must be rejected");

    if (operands.size() == 1) {
      rewriter.replaceOp(op, castToUIntIfSigned(operands[0]));
      return success();
    }

    // No nested cat was expanded; nothing changed.
    if (operands.size() == cat->getNumOperands())
      return failure();

    // If types are mixed, cast all operands to unsigned.
    if (hasSigned && hasUnsigned)
      for (auto &operand : operands)
        operand = castToUIntIfSigned(operand);

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);
    return success();
  }
};
1460
1461// Fold successive constants cat(x, y, 1, 10, z) -> cat(x, y, 110, z)
class CatOfConstant : public mlir::RewritePattern {
public:
  CatOfConstant(MLIRContext *context)
      : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cat = cast<CatPrimOp>(op);
    if (!hasKnownWidthIntTypes(cat))
      return failure();

    // Rebuilt operand list with adjacent constant runs merged.
    SmallVector<Value> operands;

    for (size_t i = 0; i < cat->getNumOperands(); ++i) {
      auto cst = cat.getInputs()[i].getDefiningOp<ConstantOp>();
      if (!cst) {
        // Non-constant operands pass through unchanged.
        operands.push_back(cat.getInputs()[i]);
        continue;
      }
      // Greedily absorb the run of constants starting at index i.
      APSInt value = cst.getValue();
      size_t j = i + 1;
      for (; j < cat->getNumOperands(); ++j) {
        auto nextCst = cat.getInputs()[j].getDefiningOp<ConstantOp>();
        if (!nextCst)
          break;
        value = value.concat(nextCst.getValue());
      }

      if (j == i + 1) {
        // Not folded - the run had a single constant; keep it as-is.
        operands.push_back(cst);
      } else {
        // Folded - materialize the merged constant.
        operands.push_back(ConstantOp::create(rewriter, cat.getLoc(), value));
      }

      // Skip past the constants consumed by the inner loop.
      i = j - 1;
    }

    // No run was merged; nothing changed.
    if (operands.size() == cat->getNumOperands())
      return failure();

    replaceOpWithNewOpAndCopyName<CatPrimOp>(rewriter, op, cat.getType(),
                                             operands);

    return success();
  }
};
1511
1512} // namespace
1513
/// Register canonicalization patterns for cat, including the tree-flattening
/// and constant-merging patterns defined above.
void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
                 patterns::CatCast, FlattenCat, CatOfConstant>(context);
}
1519
1520OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1521 auto op = (*this);
1522 // BitCast is redundant if input and result types are same.
1523 if (op.getType() == op.getInput().getType())
1524 return op.getInput();
1525
1526 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1527 // final result type.
1528 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1529 if (op.getType() == in.getInput().getType())
1530 return in.getInput();
1531
1532 return {};
1533}
1534
1535OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1536 IntType inputType = getInput().getType();
1537 IntType resultType = getType();
1538 // If we are extracting the entire input, then return it.
1539 if (inputType == getType() && resultType.hasWidth())
1540 return getInput();
1541
1542 // Constant fold.
1543 if (hasKnownWidthIntTypes(*this))
1544 if (auto cst = getConstant(adaptor.getInput()))
1545 return getIntAttr(resultType,
1546 cst->extractBits(getHi() - getLo() + 1, getLo()));
1547
1548 return {};
1549}
1550
/// bits(cat(...)) -> bits(operand) when the extracted range falls entirely
/// within a single cat operand.
struct BitsOfCat : public mlir::RewritePattern {
  BitsOfCat(MLIRContext *context)
      : RewritePattern(BitsPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto bits = cast<BitsPrimOp>(op);
    auto cat = bits.getInput().getDefiningOp<CatPrimOp>();
    if (!cat)
      return failure();
    // Position of the low extracted bit, relative to the current operand.
    int32_t bitPos = bits.getLo();
    auto resultWidth = type_cast<UIntType>(bits.getType()).getWidthOrSentinel();
    if (resultWidth < 0)
      return failure();
    // Walk operands from least- to most-significant (cat lists MSB first).
    for (auto operand : llvm::reverse(cat.getInputs())) {
      auto operandWidth =
          type_cast<IntType>(operand.getType()).getWidthOrSentinel();
      if (operandWidth < 0)
        return failure();
      if (bitPos < operandWidth) {
        // The range starts inside this operand; it must also end inside it.
        if (bitPos + resultWidth <= operandWidth) {
          auto newBits = rewriter.createOrFold<BitsPrimOp>(
              op->getLoc(), operand, bitPos + resultWidth - 1, bitPos);
          replaceOpAndCopyName(rewriter, op, newBits);
          return success();
        }
        // Range straddles an operand boundary; no rewrite.
        return failure();
      }
      bitPos -= operandWidth;
    }
    return failure();
  }
};
1585
/// Register canonicalization patterns for bits, including the cat-narrowing
/// pattern BitsOfCat defined above.
void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results
      .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
              patterns::BitsOfAnd, patterns::BitsOfPad, BitsOfCat>(context);
}
1592
1593/// Replace the specified operation with a 'bits' op from the specified hi/lo
1594/// bits. Insert a cast to handle the case where the original operation
1595/// returned a signed integer.
1596static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1597 unsigned loBit, PatternRewriter &rewriter) {
1598 auto resType = type_cast<IntType>(op->getResult(0).getType());
1599 if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1600 value = BitsPrimOp::create(rewriter, op->getLoc(), value, hiBit, loBit);
1601
1602 if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1603 value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1604 } else if (resType.isUnsigned() &&
1605 !type_cast<IntType>(value.getType()).isUnsigned()) {
1606 value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1607 }
1608 rewriter.replaceOp(op, value);
1609}
1610
/// Shared fold logic for mux-like ops (mux, mux2cell). Folds the zero-width
/// case, identical arms, constant selectors, and constant arms.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1658
/// Delegate to the shared mux folding logic above.
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1662
/// Delegate to the shared mux folding logic above.
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1666
// Mux4 has no folds; see its getCanonicalizationPatterns for rewrites.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1668
1669namespace {
1670
1671// If the mux has a known output width, pad the operands up to this width.
1672// Most folds on mux require that folded operands are of the same width as
1673// the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    // The mux result width must be known to pad toward it.
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Pad an operand up to the mux width; returns it unchanged if its width
    // is unknown or already matches.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return PadPrimOp::create(rewriter, mux.getLoc(), mux.getType(), input,
                               width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Nothing needed padding; leave the op alone.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1708
1709// Find muxes which have conditions dominated by other muxes with the same
1710// condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when searching the mux tree for a shared
  // condition.
  static const int depthLimit = 5;

  /// Rewrite `mux` to select between `high` and `low`: mutate it in place
  /// when allowed (returns null), otherwise clone it after the original and
  /// return the new value.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return MuxPrimOp::create(rewriter, mux.getLoc(), mux.getType(),
                             ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  // Returns a simplified replacement value, or null if none was found.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector and cond is known true: take the high arm directly.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // In-place updates are only safe while every mux on the path has a
    // single user.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  // Returns a simplified replacement value, or null if none was found.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector and cond is known false: take the low arm directly.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Inside the high arm the selector is known true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Inside the low arm the selector is known false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1800} // namespace
1801
/// Register canonicalization patterns for mux, including the padding and
/// shared-condition patterns defined above.
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results
      .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
           patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
           patterns::MuxSameTrue, patterns::MuxSameFalse,
           patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
          context);
}
1811
/// Register DRR canonicalization patterns for mux2cell.
void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux2PadSel>(context);
}
1816
/// Register DRR canonicalization patterns for mux4cell.
void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<patterns::Mux4PadSel>(context);
}
1821
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold: extend the constant to the destination width.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // Sign-extend signed inputs; a zero-width constant has no sign bit, so
    // it falls through to zero-extension.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1848
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold.  The result type is `shiftAmount` bits wider than the
  // input, so no bits are ever dropped.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      auto resultWidth = inputWidth + shiftAmount;
      // Defensive clamp; with inputWidth >= 0, resultWidth >= shiftAmount
      // already holds.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
1869
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Can't reason further without a known input width.
  if (inputWidth == -1)
    return {};
  // A zero-width input shifts down to zero.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold: arithmetic shift for signed inputs (clamped so the sign
  // bit survives), logical shift for unsigned inputs.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // The result is always at least one bit wide here.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1904
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().base().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  // shr(x, n) == bits(x, inputWidth - 1, n): keep the bits above the shift.
  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
1926
1927LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1928 PatternRewriter &rewriter) {
1929 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1930 if (inputWidth <= 0)
1931 return failure();
1932
1933 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1934 unsigned keepAmount = op.getAmount();
1935 if (keepAmount)
1936 replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1937 rewriter);
1938 return success();
1939}
1940
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  // Constant fold: keep the top `getAmount()` bits of the constant by
  // shifting off the low bits and truncating.
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput())) {
      int shiftAmount =
          getInput().getType().base().getWidthOrSentinel() - getAmount();
      return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
    }

  return {};
}
1951
OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
  // Constant fold: truncate the constant to the (narrower) result width,
  // dropping the top `getAmount()` bits.
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput()))
      return getIntAttr(getType(),
                        cst->trunc(getType().base().getWidthOrSentinel()));
  return {};
}
1959
1960LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1961 PatternRewriter &rewriter) {
1962 auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1963 if (inputWidth <= 0)
1964 return failure();
1965
1966 // If we know the input width, we can canonicalize this into a BitsPrimOp.
1967 unsigned dropAmount = op.getAmount();
1968 if (dropAmount != unsigned(inputWidth))
1969 replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1970 rewriter);
1971 return success();
1972}
1973
1974void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1975 MLIRContext *context) {
1976 results.add<patterns::SubaccessOfConstant>(context);
1977}
1978
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it.
  // (Operand 0 is the index; operand 1 is the sole data input.)
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  // With a constant index, select the corresponding input.  Inputs are
  // stored in descending index order, hence the reversed subscript.
  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
1992
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We could add this
  // canonicalization as a folder but it is costly to look through all inputs
  // so it is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements.  Inputs are in descending index order, so the entries the index
  // can never reach (index >= 2^indexWidth) sit at the front of the list.
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every operand must be a subindex of the same vector, with indices laid
    // out in descending order.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
2044
2045//===----------------------------------------------------------------------===//
2046// Declarations
2047//===----------------------------------------------------------------------===//
2048
2049/// Scan all the uses of the specified value, checking to see if there is
2050/// exactly one connect that has the value as its destination. This returns the
2051/// operation if found and if all the other users are "reads" from the value.
2052/// Returns null if there are no connects, or multiple connects to the value, or
2053/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
2054///
2055/// Note that this will simply return the connect, which is located *anywhere*
2056/// after the definition of the value. Users of this function are likely
2057/// interested in the source side of the returned connect, the definition of
2058/// which does likely not dominate the original value.
MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
  MatchingConnectOp connect;
  for (Operation *user : value.getUsers()) {
    // If we see an attach or aggregate subelement access, just conservatively
    // fail: the value may be written through another path.
    if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
      return {};

    if (auto aConnect = dyn_cast<FConnectLike>(user))
      if (aConnect.getDest() == value) {
        auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
        // If this is not a matching connect, a second matching connect or in a
        // different block, fail.
        if (!matchingConnect || (connect && connect != matchingConnect) ||
            matchingConnect->getBlock() != value.getParentBlock())
          return {};
        connect = matchingConnect;
      }
  }
  return connect;
}
2079
2080// Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Externally observable declarations (don't-touch, annotations, interesting
  // names, force handles) must be preserved.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
2138
2139void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
2140 MLIRContext *context) {
2141 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
2142 context);
2143}
2144
2145LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
2146 PatternRewriter &rewriter) {
2147 // TODO: Canonicalize towards explicit extensions and flips here.
2148
2149 // If there is a simple value connected to a foldable decl like a wire or reg,
2150 // see if we can eliminate the decl.
2151 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
2152 return success();
2153 return failure();
2154}
2155
2156//===----------------------------------------------------------------------===//
2157// Statements
2158//===----------------------------------------------------------------------===//
2159
2160/// If the specified value has an AttachOp user strictly dominating by
2161/// "dominatingAttach" then return it.
2162static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
2163 for (auto *user : value.getUsers()) {
2164 auto attach = dyn_cast<AttachOp>(user);
2165 if (!attach || attach == dominatedAttach)
2166 continue;
2167 if (attach->isBeforeInBlock(dominatedAttach))
2168 return attach;
2169 }
2170 return {};
2171}
2172
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //   attach x, y
    //   ...
    //   attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      AttachOp::create(rewriter, op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the dead wire.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire itself.
            newOperands.push_back(newOperand);

        AttachOp::create(rewriter, op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
2218
2219/// Replaces the given op with the contents of the given single-block region.
2220static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
2221 Region &region) {
2222 assert(llvm::hasSingleElement(region) && "expected single-region block");
2223 rewriter.inlineBlockBefore(&region.front(), op, {});
2224}
2225
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // Constant condition: inline whichever branch is taken, then delete the
  // when.  (A false condition with no else region simply drops the body.)
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
2258
2259namespace {
2260// Remove private nodes. If they have an interesting names, move the name to
2261// the source expression.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    // Keep nodes that are externally observable: non-droppable names, inner
    // symbols, annotations, or force handles.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).empty() || node.isForceable())
      return failure();
    // If the input is produced by an op, try to move the node's name onto it
    // before forwarding the input to all users.
    auto *newOp = node.getInput().getDefiningOp();
    if (newOp)
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
2279
2280// Bypass nodes.
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    // Skip nodes that must be preserved (symbols, annotations, force
    // handles) and nodes that are already dead.
    if (node.getInnerSym() || !AnnotationSet(node).empty() ||
        node.use_empty() || node.isForceable())
      return failure();
    // Forward the input to all users.  The node itself is left in place (now
    // unused) — presumably so its name survives; other patterns may clean it
    // up.  The use_empty() check above keeps this from re-firing.
    rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
    return success();
  }
};
2294
2295} // namespace
2296
2297template <typename OpTy>
2298static LogicalResult demoteForceableIfUnused(OpTy op,
2299 PatternRewriter &rewriter) {
2300 if (!op.isForceable() || !op.getDataRef().use_empty())
2301 return failure();
2302
2303 firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
2304 return success();
2305}
2306
2307// Interesting names and symbols and don't touch force nodes to stick around.
LogicalResult NodeOp::fold(FoldAdaptor adaptor,
                           SmallVectorImpl<OpFoldResult> &results) {
  // Keep nodes whose name must survive.
  if (!hasDroppableName())
    return failure();
  if (hasDontTouch(getResult())) // handles inner symbols
    return failure();
  // Keep nodes carrying annotations.
  if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
    return failure();
  // Keep nodes that can be forced.
  if (isForceable())
    return failure();
  // Only fold when the input has folded to a constant attribute.
  if (!adaptor.getInput())
    return failure();

  results.push_back(adaptor.getInput());
  return success();
}
2324
2325void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2326 MLIRContext *context) {
2327 results.insert<FoldNodeName>(context);
2328 results.add(demoteForceableIfUnused<NodeOp>);
2329}
2330
2331namespace {
2332// For a lhs, find all the writers of fields of the aggregate type. If there
2333// is one writer for each field, merge the writes
struct AggOneShot : public mlir::RewritePattern {
  // NOTE(review): `weight` is accepted but ignored — the pattern benefit is
  // hard-coded to 0.  All current instantiations pass 0; confirm whether the
  // parameter was meant to be forwarded.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Collect the value written to each field/element of `lhs`.  Returns the
  // per-field sources in index order, or an empty vector unless every field
  // has exactly one write, all writes are in the same block as `lhs`, and the
  // aggregate is never written as a whole.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Map from field/element index to its single connected source.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
        // A write to the whole aggregate conflicts with per-field writes.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          return {};
        }
      } else {
        return {};
      }
    }

    // Require a write for every field/element of the aggregate.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    // Create the aggregate at the end of the block so all sources dominate it.
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot matchingconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Build the whole aggregate value and write it with a single connect.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
2432
// Concrete AggOneShot instantiations, one per op kind that can anchor a
// complete per-field aggregate write.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2445} // namespace
2446
2447void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2448 MLIRContext *context) {
2449 results.insert<WireAggOneShot>(context);
2450 results.add(demoteForceableIfUnused<WireOp>);
2451}
2452
2453void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2454 MLIRContext *context) {
2455 results.insert<SubindexAggOneShot>(context);
2456}
2457
2458OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2459 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2460 if (!attr)
2461 return {};
2462 return attr[getIndex()];
2463}
2464
2465OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2466 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2467 if (!attr)
2468 return {};
2469 auto index = getFieldIndex();
2470 return attr[index];
2471}
2472
2473void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2474 MLIRContext *context) {
2475 results.insert<SubfieldAggOneShot>(context);
2476}
2477
2478static Attribute collectFields(MLIRContext *context,
2479 ArrayRef<Attribute> operands) {
2480 for (auto operand : operands)
2481 if (!operand)
2482 return {};
2483 return ArrayAttr::get(context, operands);
2484}
2485
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>, i.e. the op merely reassembles an existing bundle.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate if every operand is constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2504
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>, i.e. the op merely reassembles an existing vector.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate if every operand is constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2522
2523OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2524 if (getOperand().getType() == getType())
2525 return getOperand();
2526 return {};
2527}
2528
2529namespace {
2530// A register with constant reset and all connection to either itself or the
2531// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    // Externally observable registers must be preserved.
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The register must be driven by a mux ...
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    // ... whose two arms are the register itself and a constant, in either
    // order.
    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must equal the reset constant, so the register can
    // only ever hold that one value.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2584} // namespace
2585
2586static bool isDefinedByOneConstantOp(Value v) {
2587 if (auto c = v.getDefiningOp<ConstantOp>())
2588 return c.getValue().isOne();
2589 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2590 return sc.getValue();
2591 return false;
2592}
2593
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  // Only applies when the reset signal is constant 1: the register is held
  // in reset forever and always carries the reset value.
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Types must match exactly to forward the reset value.
  auto resetValue = reg.getResetValue();
  if (reg.getType(0) != resetValue.getType())
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  // Replace the register with a node of the reset value, preserving name,
  // name kind, annotations, inner symbol, and forceability.
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2610
2611void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2612 MLIRContext *context) {
2613 results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2615 results.add(demoteForceableIfUnused<RegResetOp>);
2616}
2617
2618// Returns the value connected to a port, if there is only one.
static Value getPortFieldValue(Value port, StringRef name) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  Value value = {};
  for (auto *op : port.getUsers()) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    // Require exactly one single-driver connect across all accesses of this
    // field; bail out on multiple or ambiguous drivers.
    auto conn = getSingleConnectUserOf(portAccess);
    if (!conn || value)
      return {};
    value = conn.getSrc();
  }
  return value;
}
2636
2637// Returns true if the enable field of a port is set to false.
2638static bool isPortDisabled(Value port) {
2639 auto value = getPortFieldValue(port, "en");
2640 if (!value)
2641 return false;
2642 auto portConst = value.getDefiningOp<ConstantOp>();
2643 if (!portConst)
2644 return false;
2645 return portConst.getValue().isZero();
2646}
2647
2648// Returns true if the data output is unused.
2649static bool isPortUnused(Value port, StringRef data) {
2650 auto portTy = type_cast<BundleType>(port.getType());
2651 auto fieldIndex = portTy.getElementIndex(data);
2652 assert(fieldIndex && "missing enable flag on memory port");
2653
2654 for (auto *op : port.getUsers()) {
2655 auto portAccess = cast<SubfieldOp>(op);
2656 if (fieldIndex != portAccess.getFieldIndex())
2657 continue;
2658 if (!portAccess.use_empty())
2659 return false;
2660 }
2661
2662 return true;
2663}
2664
// Replaces all accesses to the named field of a port with the given value.
// (Previous comment was a copy-paste from getPortFieldValue.)
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2680
// Remove all accesses to a port that is no longer needed.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = SpecialConstantOp::create(rewriter, port.getLoc(),
                                        ClockType::get(rewriter.getContext()),
                                        false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire. If the port
  // is used in its entirety, replace it with a wire. Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      // Whole-port use: substitute a never-written register of the same type.
      auto ty = port.getType();
      auto reg = RegOp::create(rewriter, port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = RegOp::create(rewriter, access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2733
2734namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only applies to integer memories whose data width is exactly zero.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are subfield accesses that are safe to replace.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type. Replace each subfield access with
    // a wire; data fields additionally get a zero-width constant driver.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports.
          auto zero = firrtl::ConstantOp::create(
              rewriter, wire.getLoc(),
              firrtl::type_cast<IntType>(wire.getType()), APInt::getZero(0));
          MatchingConnectOp::create(rewriter, wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2777
2778// If memory has no write ports and no file initialization, eliminate it.
2779struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
2780 FoldReadOrWriteOnlyMemory(MLIRContext *context)
2781 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2782 LogicalResult matchAndRewrite(Operation *op,
2783 PatternRewriter &rewriter) const override {
2784 MemOp mem = cast<MemOp>(op);
2785 if (hasDontTouch(mem))
2786 return failure();
2787 bool isRead = false, isWritten = false;
2788 for (unsigned i = 0; i < mem.getNumResults(); ++i) {
2789 switch (mem.getPortKind(i)) {
2790 case MemOp::PortKind::Read:
2791 isRead = true;
2792 if (isWritten)
2793 return failure();
2794 continue;
2795 case MemOp::PortKind::Write:
2796 isWritten = true;
2797 if (isRead)
2798 return failure();
2799 continue;
2800 case MemOp::PortKind::Debug:
2801 case MemOp::PortKind::ReadWrite:
2802 return failure();
2803 }
2804 llvm_unreachable("unknown port kind");
2805 }
2806 assert((!isWritten || !isRead) && "memory is in use");
2807
2808 // If the memory is read only, but has a file initialization, then we can't
2809 // remove it. A write only memory with file initialization is okay to
2810 // remove.
2811 if (isRead && mem.getInit())
2812 return failure();
2813
2814 for (auto port : mem.getResults())
2815 erasePort(rewriter, port);
2816
2817 rewriter.eraseOp(op);
2818 return success();
2819 }
2820};
2821
// Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the memory with only the surviving ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port is dead, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = MemOp::create(
          rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Detach the dead ports (erasePort replaces their reads with dummy
    // registers) and wire surviving ports to the new memory's results.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2891
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Compute the new port types: demoted ports get a write-port bundle,
    // everything else keeps its original type.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = MemOp::create(
        rewriter, mem.getLoc(), resultTypes, mem.getReadLatency(),
        mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
        rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
        mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field (`fromName`) to accesses to
        // the corresponding field (`toName`) of the new write port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = SubfieldOp::create(rewriter, newResult.getLoc(),
                                            newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Map the read-write bundle's fields onto the write port's fields.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2983
// Shrink the data width of a memory to the bits which are actually read,
// compacting the unused bits away and rewriting all reads and writes.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // Bit-mask of the data bits read somewhere, and a map from the low bit
    // of each read slice to its position in the compacted layout (the values
    // are placeholders until the layout is computed below).
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If the entire width is used there is nothing to shrink.
          if (usedBits.all())
            return failure();

          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones. Each maximal
    // run of used bits becomes one contiguous range of the new layout.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type: all accesses to the given
    // field are funnelled through one fresh subfield op on the new port.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          SubfieldOp::create(rewriter, port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Slice out each kept range; reversed so the MSB range comes first.
      SmallVector<Value> slices;
      for (auto &[start, end] : llvm::reverse(ranges)) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        slices.push_back(slice);
      }

      Value catOfSlices =
          rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(), slices);

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
3206
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    // Along the way, collect every subfield access and connect that must be
    // deleted once the rewrite is complete.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect accesses to the given fields of this port. Every user of such
      // an access must be a connect, otherwise the pattern bails out.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        // Read ports do not constrain the clock; skip the clock check below.
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All ports with a write end must share a single, known clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = WireOp::create(rewriter, loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        RegOp::create(rewriter, loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    MatchingConnectOp::create(rewriter, loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = RegOp::create(rewriter, mem.getLoc(), value.getType(), clock,
                                 rewriter.getStringAttr(regName))
                       .getResult();
        MatchingConnectOp::create(rewriter, value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // The write end is active only when the port is enabled in write mode.
        auto wen = AndPrimOp::create(rewriter, port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      SmallVector<Value> chunks;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = MuxPrimOp::create(rewriter, loc, bit, dataPart, nextPart);
        chunks.push_back(chunk);
      }

      // Concatenate the chunks back together (MSB first) and gate the whole
      // write on the pipelined enable.
      std::reverse(chunks.begin(), chunks.end());
      masked = rewriter.createOrFold<CatPrimOp>(loc, chunks);
      next = MuxPrimOp::create(rewriter, next.getLoc(), en, masked, next);
    }
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    MatchingConnectOp::create(rewriter, memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3401} // namespace
3402
3403void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3404 MLIRContext *context) {
3405 results
3406 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3407 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3408 context);
3409}
3410
3411//===----------------------------------------------------------------------===//
3412// Declarations
3413//===----------------------------------------------------------------------===//
3414
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
//
// Returns success only when a rewrite was performed.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The driver must be a two-way mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Promote the register to a RegResetOp driven by the mux's select and
    // high operand, preserving name, annotations, and dialect attributes.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Rewrite the connect to drive either the constant (constant register
  // case) or the mux's non-reset operand (hidden reset case).
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3485
3486LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3487 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3488 succeeded(foldHiddenReset(op, rewriter)))
3489 return success();
3490
3491 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3492 return success();
3493
3494 return failure();
3495}
3496
3497//===----------------------------------------------------------------------===//
3498// Verification Ops.
3499//===----------------------------------------------------------------------===//
3500
3501static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3502 Value enable,
3503 PatternRewriter &rewriter,
3504 bool eraseIfZero) {
3505 // If the verification op is never enabled, delete it.
3506 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3507 if (constant.getValue().isZero()) {
3508 rewriter.eraseOp(op);
3509 return success();
3510 }
3511 }
3512
3513 // If the verification op is never triggered, delete it.
3514 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3515 if (constant.getValue().isZero() == eraseIfZero) {
3516 rewriter.eraseOp(op);
3517 return success();
3518 }
3519 }
3520
3521 return failure();
3522}
3523
3524template <class Op, bool EraseIfZero = false>
3525static LogicalResult canonicalizeImmediateVerifOp(Op op,
3526 PatternRewriter &rewriter) {
3527 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3528 EraseIfZero);
3529}
3530
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase asserts that are statically disabled or trivially satisfied, and
  // fold the `assert(x) when x` form.
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
  results.add<patterns::AssertXWhenX>(context);
}
3536
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Erase assumes that are statically disabled or trivially satisfied, and
  // fold the `assume(x) when x` form.
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
  results.add<patterns::AssumeXWhenX>(context);
}
3542
void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Same simplifications as AssumeOp, for the unclocked intrinsic variant.
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
  results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
}
3548
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // Covers are erased when disabled or when the predicate is constant zero
  // (never hit), hence EraseIfZero = true.
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3553
3554//===----------------------------------------------------------------------===//
3555// InvalidValueOp
3556//===----------------------------------------------------------------------===//
3557
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  // Cvt of a signed value and non-empty reductions are also safe to absorb.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of the user's result type.
    auto *modop = *op->user_begin();
    auto inv = InvalidValueOp::create(rewriter, op.getLoc(),
                                      modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3588
3589OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3590 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3591 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3592 return {};
3593}
3594
3595//===----------------------------------------------------------------------===//
3596// ClockGateIntrinsicOp
3597//===----------------------------------------------------------------------===//
3598
3599OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3600 // Forward the clock if one of the enables is always true.
3601 if (isConstantOne(adaptor.getEnable()) ||
3602 isConstantOne(adaptor.getTestEnable()))
3603 return getInput();
3604
3605 // Fold to a constant zero clock if the enables are always false.
3606 if (isConstantZero(adaptor.getEnable()) &&
3607 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3608 return BoolAttr::get(getContext(), false);
3609
3610 // Forward constant zero clocks.
3611 if (isConstantZero(adaptor.getInput()))
3612 return BoolAttr::get(getContext(), false);
3613
3614 return {};
3615}
3616
3617LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3618 PatternRewriter &rewriter) {
3619 // Remove constant false test enable.
3620 if (auto testEnable = op.getTestEnable()) {
3621 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3622 if (constOp.getValue().isZero()) {
3623 rewriter.modifyOpInPlace(op,
3624 [&] { op.getTestEnableMutable().clear(); });
3625 return success();
3626 }
3627 }
3628 }
3629
3630 return failure();
3631}
3632
3633//===----------------------------------------------------------------------===//
3634// Reference Ops.
3635//===----------------------------------------------------------------------===//
3636
3637// refresolve(forceable.ref) -> forceable.data
3638static LogicalResult
3639canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3640 auto forceable = op.getRef().getDefiningOp<Forceable>();
3641 if (!forceable || !forceable.isForceable() ||
3642 op.getRef() != forceable.getDataRef() ||
3643 op.getType() != forceable.getDataType())
3644 return failure();
3645 rewriter.replaceAllUsesWith(op, forceable.getData());
3646 return success();
3647}
3648
3649void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3650 MLIRContext *context) {
3651 results.insert<patterns::RefResolveOfRefSend>(context);
3652 results.insert(canonicalizeRefResolveOfForceable);
3653}
3654
3655OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3656 // RefCast is unnecessary if types match.
3657 if (getInput().getType() == getType())
3658 return getInput();
3659 return {};
3660}
3661
3662static bool isConstantZero(Value operand) {
3663 auto constOp = operand.getDefiningOp<ConstantOp>();
3664 return constOp && constOp.getValue().isZero();
3665}
3666
3667template <typename Op>
3668static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3669 if (isConstantZero(op.getPredicate())) {
3670 rewriter.eraseOp(op);
3671 return success();
3672 }
3673 return failure();
3674}
3675
// Force/release operations do nothing when their predicate is a constant
// false; register the shared `eraseIfPredFalse` canonicalization for each.
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add(eraseIfPredFalse<RefForceOp>);
}
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add(eraseIfPredFalse<RefForceInitialOp>);
}
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add(eraseIfPredFalse<RefReleaseOp>);
}
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
}
3692
3693//===----------------------------------------------------------------------===//
3694// HasBeenResetIntrinsicOp
3695//===----------------------------------------------------------------------===//
3696
3697OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3698 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3699
3700 // Fold to zero if the reset is a constant. In this case the op is either
3701 // permanently in reset or never resets. Both mean that the reset never
3702 // finishes, so this op never returns true.
3703 if (adaptor.getReset())
3704 return getIntZerosAttr(UIntType::get(getContext(), 1));
3705
3706 // Fold to zero if the clock is a constant and the reset is synchronous. In
3707 // that case the reset will never be started.
3708 if (isUInt1(getReset().getType()) && adaptor.getClock())
3709 return getIntZerosAttr(UIntType::get(getContext(), 1));
3710
3711 return {};
3712}
3713
3714//===----------------------------------------------------------------------===//
3715// FPGAProbeIntrinsicOp
3716//===----------------------------------------------------------------------===//
3717
3718static bool isTypeEmpty(FIRRTLType type) {
3720 .Case<FVectorType>(
3721 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3722 .Case<BundleType>([&](auto ty) -> bool {
3723 for (auto elem : ty.getElements())
3724 if (!isTypeEmpty(elem.type))
3725 return false;
3726 return true;
3727 })
3728 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3729 .Default([](auto) -> bool { return false; });
3730}
3731
3732LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3733 PatternRewriter &rewriter) {
3734 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3735 if (!firrtlTy)
3736 return failure();
3737
3738 if (!isTypeEmpty(firrtlTy))
3739 return failure();
3740
3741 rewriter.eraseOp(op);
3742 return success();
3743}
3744
3745//===----------------------------------------------------------------------===//
3746// Layer Block Op
3747//===----------------------------------------------------------------------===//
3748
3749LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3750 PatternRewriter &rewriter) {
3751
3752 // If the layerblock is empty, erase it.
3753 if (op.getBody()->empty()) {
3754 rewriter.eraseOp(op);
3755 return success();
3756 }
3757
3758 return failure();
3759}
3760
3761//===----------------------------------------------------------------------===//
3762// Domain-related Ops
3763//===----------------------------------------------------------------------===//
3764
3765OpFoldResult UnsafeDomainCastOp::fold(FoldAdaptor adaptor) {
3766 // If no domains are specified, then forward the input to the result.
3767 if (getDomains().empty())
3768 return getInput();
3769
3770 return {};
3771}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static bool isOkToPropagateName(Operation *op)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static InstancePath empty
AndRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
OrRCat(MLIRContext *context)
bool getIdentityValue() const override
Return the unit value for this reduction operation:
virtual bool getIdentityValue() const =0
Return the unit value for this reduction operation:
virtual bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp constantOp, SmallVectorImpl< Value > &remaining) const =0
Handle a constant operand in the cat operation.
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
ReductionCat(MLIRContext *context, llvm::StringLiteral opName)
XorRCat(MLIRContext *context)
bool handleConstant(mlir::PatternRewriter &rewriter, Operation *op, ConstantOp value, SmallVectorImpl< Value > &remaining) const override
Handle a constant operand in the cat operation.
bool getIdentityValue() const override
Return the unit value for this reduction operation:
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:28
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:35
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21
LogicalResult matchAndRewrite(Operation *op, mlir::PatternRewriter &rewriter) const override
BitsOfCat(MLIRContext *context)