CIRCT 20.0.0git
Loading...
Searching...
No Matches
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/APInt.h"
18#include "circt/Support/LLVM.h"
20#include "mlir/IR/Matchers.h"
21#include "mlir/IR/PatternMatch.h"
22#include "llvm/ADT/APSInt.h"
23#include "llvm/ADT/SmallPtrSet.h"
24#include "llvm/ADT/StringExtras.h"
25#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace firrtl;
29
30// Drop writes to old and pass through passthrough to make patterns easier to
31// write.
32static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33 Value passthrough) {
34 SmallPtrSet<Operation *, 8> users;
35 for (auto *user : old.getUsers())
36 users.insert(user);
37 for (Operation *user : users)
38 if (auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
41 return passthrough;
42}
43
44// Move a name hint from a soon to be deleted operation to a new operation.
45// Pass through the new operation to make patterns easier to write. This cannot
46// move a name to a port (block argument), doing so would require rewriting all
47// instance sites as well as the module.
48static Value moveNameHint(OpResult old, Value passthrough) {
49 Operation *op = passthrough.getDefiningOp();
50 // This should handle ports, but it isn't clear we can change those in
51 // canonicalizers.
52 assert(op && "passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>("name");
55 if (name && !name.getValue().empty())
56 op->setAttr("name", name);
57 return passthrough;
58}
59
60// Declarative canonicalization patterns
61namespace circt {
62namespace firrtl {
63namespace patterns {
64#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
65} // namespace patterns
66} // namespace firrtl
67} // namespace circt
68
69/// Return true if this operation's operands and results all have a known width.
70/// This only works for integer types.
71static bool hasKnownWidthIntTypes(Operation *op) {
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
74 return false;
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
77 return false;
78 return true;
79}
80
81/// Return true if this value is 1 bit UInt.
82static bool isUInt1(Type type) {
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
85 return false;
86 return true;
87}
88
89/// Set the name of an op based on the best of two names: The current name, and
90/// the name passed in.
91static void updateName(PatternRewriter &rewriter, Operation *op,
92 StringAttr name) {
93 // Should never rename InstanceOp
94 assert(!isa<InstanceOp>(op));
95 if (!name || name.getValue().empty())
96 return;
97 auto newName = name.getValue(); // old name is interesting
98 auto newOpName = op->getAttrOfType<StringAttr>("name");
99 // new name might not be interesting
100 if (newOpName)
101 newName = chooseName(newOpName.getValue(), name.getValue());
102 // Only update if needed
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
106}
107
108/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
109/// If a replaced op has a "name" attribute, this function propagates the name
110/// to the new value.
111static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
112 Value newValue) {
113 if (auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>("name");
115 updateName(rewriter, newOp, name);
116 }
117 rewriter.replaceOp(op, newValue);
118}
119
120/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
121/// attribute. If a replaced op has a "name" attribute, this function propagates
122/// the name to the new value.
123template <typename OpTy, typename... Args>
124static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
125 Operation *op, Args &&...args) {
126 auto name = op->getAttrOfType<StringAttr>("name");
127 auto newOp =
128 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
129 updateName(rewriter, newOp, name);
130 return newOp;
131}
132
133/// Return true if the name is droppable. Note that this is different from
134/// `isUselessName` because non-useless names may be also droppable.
136 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
138 return false;
139}
140
141/// Implicitly replace the operand to a constant folding operation with a const
142/// 0 in case the operand is non-constant but has a bit width 0, or if the
143/// operand is an invalid value.
144///
145/// This makes constant folding significantly easier, as we can simply pass the
146/// operands to an operation through this function to appropriately replace any
147/// zero-width dynamic values or invalid values with a constant of value 0.
148static std::optional<APSInt>
149getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
152
153 // We never support constant folding to unknown width values.
154 if (destWidth < 0)
155 return {};
156
157 // Extension signedness follows the operand sign.
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
159 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
160
161 // If the operand is zero bits, then we can return a zero of the result
162 // type.
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
166 return {};
167}
168
169/// Determine the value of a constant operand for the sake of constant folding.
170static std::optional<APSInt> getConstant(Attribute operand) {
171 if (!operand)
172 return {};
173 if (auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
177 return {};
178}
179
180/// Determine whether a constant operand is a zero value for the sake of
181/// constant folding. This considers `invalidvalue` to be zero.
182static bool isConstantZero(Attribute operand) {
183 if (auto cst = getConstant(operand))
184 return cst->isZero();
185 return false;
186}
187
188/// Determine whether a constant operand is a one value for the sake of constant
189/// folding.
190static bool isConstantOne(Attribute operand) {
191 if (auto cst = getConstant(operand))
192 return cst->isOne();
193 return false;
194}
195
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
enum class BinOpKind {
  Normal,        // Operands extend to the result width.
  Compare,       // Operands extend to the widest operand; result is i1.
  DivideOrShift, // Result may need truncation back to the result width.
};
203
204/// Applies the constant folding function `calculate` to the given operands.
205///
206/// Sign or zero extends the operands appropriately to the bitwidth of the
207/// result type if \p useDstWidth is true, else to the larger of the two operand
208/// bit widths and depending on whether the operation is to be performed on
209/// signed or unsigned operands.
210static Attribute constFoldFIRRTLBinaryOp(
211 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
212 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
213 assert(operands.size() == 2 && "binary op takes two operands");
214
215 // We cannot fold something to an unknown width.
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
217 if (resultType.getWidthOrSentinel() < 0)
218 return {};
219
220 // Any binary op returning i0 is 0.
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
223
224 // Determine the operand widths. This is either dictated by the operand type,
225 // or if that type is an unsized integer, by the actual bits necessary to
226 // represent the constant value.
227 auto lhsWidth =
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
229 auto rhsWidth =
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
235
236 // Compares extend the operands to the widest of the operand types, not to the
237 // result type.
238 int32_t operandWidth;
239 switch (opKind) {
241 operandWidth = resultType.getWidthOrSentinel();
242 break;
244 // Compares compute with the widest operand, not at the destination type
245 // (which is always i1).
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
247 break;
249 operandWidth =
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
251 break;
252 }
253
254 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
255 if (!lhs)
256 return {};
257 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
258 if (!rhs)
259 return {};
260
261 APInt resultValue = calculate(*lhs, *rhs);
262
263 // If the result type is smaller than the computation then we need to
264 // narrow the constant after the calculation.
265 if (opKind == BinOpKind::DivideOrShift)
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
267
268 assert((unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
270 return getIntAttr(resultType, resultValue);
271}
272
273/// Applies the canonicalization function `canonicalize` to the given operation.
274///
275/// Determines which (if any) of the operation's operands are constants, and
276/// provides them as arguments to the callback function. Any `invalidvalue` in
277/// the input is mapped to a constant zero. The value returned from the callback
278/// is used as the replacement for `op`, and an additional pad operation is
279/// inserted if necessary. Does nothing if the result of `op` is of unknown
280/// width, in which case the necessity of a pad cannot be determined.
281static LogicalResult canonicalizePrimOp(
282 Operation *op, PatternRewriter &rewriter,
283 const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
284 // Can only operate on FIRRTL primitive operations.
285 if (op->getNumResults() != 1)
286 return failure();
287 auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
288 if (!type)
289 return failure();
290
291 // Can only operate on operations with a known result width.
292 auto width = type.getBitWidthOrSentinel();
293 if (width < 0)
294 return failure();
295
296 // Determine which of the operands are constants.
297 SmallVector<Attribute, 3> constOperands;
298 constOperands.reserve(op->getNumOperands());
299 for (auto operand : op->getOperands()) {
300 Attribute attr;
301 if (auto *defOp = operand.getDefiningOp())
302 TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303 [&](auto op) { attr = op.getValueAttr(); });
304 constOperands.push_back(attr);
305 }
306
307 // Perform the canonicalization and materialize the result if it is a
308 // constant.
309 auto result = canonicalize(constOperands);
310 if (!result)
311 return failure();
312 Value resultValue;
313 if (auto cst = dyn_cast<Attribute>(result))
314 resultValue = op->getDialect()
315 ->materializeConstant(rewriter, cst, type, op->getLoc())
316 ->getResult(0);
317 else
318 resultValue = cast<Value>(result);
319
320 // Insert a pad if the type widths disagree.
321 if (width !=
322 type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323 resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);
324
325 // Insert a cast if this is a uint vs. sint or vice versa.
326 if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327 resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328 else if (type_isa<UIntType>(type) &&
329 type_isa<SIntType>(resultValue.getType()))
330 resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
331
332 assert(type == resultValue.getType() && "canonicalization changed type");
333 replaceOpAndCopyName(rewriter, op, resultValue);
334 return success();
335}
336
337/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
338/// value if `bitWidth` is 0.
339static APInt getMaxUnsignedValue(unsigned bitWidth) {
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
341}
342
343/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
344/// value if `bitWidth` is 0.
345static APInt getMinSignedValue(unsigned bitWidth) {
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
347}
348
349/// Get the largest signed value of a given bit width. Returns a 1-bit zero
350/// value if `bitWidth` is 0.
351static APInt getMaxSignedValue(unsigned bitWidth) {
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
353}
354
355//===----------------------------------------------------------------------===//
356// Fold Hooks
357//===----------------------------------------------------------------------===//
358
359OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360 assert(adaptor.getOperands().empty() && "constant has no operands");
361 return getValueAttr();
362}
363
364OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365 assert(adaptor.getOperands().empty() && "constant has no operands");
366 return getValueAttr();
367}
368
369OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370 assert(adaptor.getOperands().empty() && "constant has no operands");
371 return getFieldsAttr();
372}
373
374OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375 assert(adaptor.getOperands().empty() && "constant has no operands");
376 return getValueAttr();
377}
378
379OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380 assert(adaptor.getOperands().empty() && "constant has no operands");
381 return getValueAttr();
382}
383
384OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385 assert(adaptor.getOperands().empty() && "constant has no operands");
386 return getValueAttr();
387}
388
389OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390 assert(adaptor.getOperands().empty() && "constant has no operands");
391 return getValueAttr();
392}
393
394//===----------------------------------------------------------------------===//
395// Binary Operators
396//===----------------------------------------------------------------------===//
397
398OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
400 *this, adaptor.getOperands(), BinOpKind::Normal,
401 [=](const APSInt &a, const APSInt &b) { return a + b; });
402}
403
404void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405 MLIRContext *context) {
406 results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407 patterns::AddOfSelf, patterns::AddOfPad>(context);
408}
409
410OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
412 *this, adaptor.getOperands(), BinOpKind::Normal,
413 [=](const APSInt &a, const APSInt &b) { return a - b; });
414}
415
416void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417 MLIRContext *context) {
418 results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420 patterns::SubOfPadL, patterns::SubOfPadR>(context);
421}
422
423OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
424 // mul(x, 0) -> 0
425 //
426 // This is legal because it aligns with the Scala FIRRTL Compiler
427 // interpretation of lowering invalid to constant zero before constant
428 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
429 // multiplication this way and will emit "x * 0".
430 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
431 return getIntZerosAttr(getType());
432
434 *this, adaptor.getOperands(), BinOpKind::Normal,
435 [=](const APSInt &a, const APSInt &b) { return a * b; });
436}
437
438OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
439 /// div(x, x) -> 1
440 ///
441 /// Division by zero is undefined in the FIRRTL specification. This fold
442 /// exploits that fact to optimize self division to one. Note: this should
443 /// supersede any division with invalid or zero. Division of invalid by
444 /// invalid should be one.
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
447 if (width == -1)
448 width = 2;
449 // Only fold if we have at least 1 bit of width to represent the `1` value.
450 if (width != 0)
451 return getIntAttr(getType(), APInt(width, 1));
452 }
453
454 // div(0, x) -> 0
455 //
456 // This is legal because it aligns with the Scala FIRRTL Compiler
457 // interpretation of lowering invalid to constant zero before constant
458 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
459 // division this way and will emit "0 / x".
460 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
461 return getIntZerosAttr(getType());
462
463 /// div(x, 1) -> x : (uint, uint) -> uint
464 ///
465 /// UInt division by one returns the numerator. SInt division can't
466 /// be folded here because it increases the return type bitwidth by
467 /// one and requires sign extension (a new op).
468 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
470 return getLhs();
471
473 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
474 [=](const APSInt &a, const APSInt &b) -> APInt {
475 if (!!b)
476 return a / b;
477 return APInt(a.getBitWidth(), 0);
478 });
479}
480
481OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
482 // rem(x, x) -> 0
483 //
484 // Division by zero is undefined in the FIRRTL specification. This fold
485 // exploits that fact to optimize self division remainder to zero. Note:
486 // this should supersede any division with invalid or zero. Remainder of
487 // division of invalid by invalid should be zero.
488 if (getLhs() == getRhs())
489 return getIntZerosAttr(getType());
490
491 // rem(0, x) -> 0
492 //
493 // This is legal because it aligns with the Scala FIRRTL Compiler
494 // interpretation of lowering invalid to constant zero before constant
495 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
496 // division this way and will emit "0 % x".
497 if (isConstantZero(adaptor.getLhs()))
498 return getIntZerosAttr(getType());
499
501 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
502 [=](const APSInt &a, const APSInt &b) -> APInt {
503 if (!!b)
504 return a % b;
505 return APInt(a.getBitWidth(), 0);
506 });
507}
508
509OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
511 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
512 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
513}
514
515OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
517 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
518 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
519}
520
521OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
523 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524 [=](const APSInt &a, const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
526 : a.ashr(b);
527 });
528}
529
530// TODO: Move to DRR.
531OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
532 if (auto rhsCst = getConstant(adaptor.getRhs())) {
533 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
534 if (rhsCst->isZero())
535 return getIntZerosAttr(getType());
536
537 /// and(x, -1) -> x
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
540 return getLhs();
541 }
542
543 if (auto lhsCst = getConstant(adaptor.getLhs())) {
544 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
545 if (lhsCst->isZero())
546 return getIntZerosAttr(getType());
547
548 /// and(-1, x) -> x
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
551 return getRhs();
552 }
553
554 /// and(x, x) -> x
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
556 return getRhs();
557
559 *this, adaptor.getOperands(), BinOpKind::Normal,
560 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
561}
562
563void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564 MLIRContext *context) {
565 results
566 .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567 patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568 patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
569}
570
571OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
572 if (auto rhsCst = getConstant(adaptor.getRhs())) {
573 /// or(x, 0) -> x
574 if (rhsCst->isZero() && getLhs().getType() == getType())
575 return getLhs();
576
577 /// or(x, -1) -> -1
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
580 return getRhs();
581 }
582
583 if (auto lhsCst = getConstant(adaptor.getLhs())) {
584 /// or(0, x) -> x
585 if (lhsCst->isZero() && getRhs().getType() == getType())
586 return getRhs();
587
588 /// or(-1, x) -> -1
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
591 return getLhs();
592 }
593
594 /// or(x, x) -> x
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
596 return getRhs();
597
599 *this, adaptor.getOperands(), BinOpKind::Normal,
600 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
601}
602
603void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604 MLIRContext *context) {
605 results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607 patterns::OrOrr>(context);
608}
609
610OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
611 /// xor(x, 0) -> x
612 if (auto rhsCst = getConstant(adaptor.getRhs()))
613 if (rhsCst->isZero() &&
614 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
615 return getLhs();
616
617 /// xor(x, 0) -> x
618 if (auto lhsCst = getConstant(adaptor.getLhs()))
619 if (lhsCst->isZero() &&
620 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
621 return getRhs();
622
623 /// xor(x, x) -> 0
624 if (getLhs() == getRhs())
625 return getIntAttr(
626 getType(),
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
628
630 *this, adaptor.getOperands(), BinOpKind::Normal,
631 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
632}
633
634void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635 MLIRContext *context) {
636 results.insert<patterns::extendXor, patterns::moveConstXor,
637 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
638 context);
639}
640
641void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642 MLIRContext *context) {
643 results.insert<patterns::LEQWithConstLHS>(context);
644}
645
646OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
648
649 // leq(x, x) -> 1
650 if (getLhs() == getRhs())
651 return getIntAttr(getType(), APInt(1, 1));
652
653 // Comparison against constant outside type bounds.
654 if (auto width = getLhs().getType().base().getWidth()) {
655 if (auto rhsCst = getConstant(adaptor.getRhs())) {
656 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657 commonWidth = std::max(commonWidth, 1);
658
659 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
660 // This can never occur since const is unsigned and cannot be less than 0.
661
662 // leq(x, const) -> 0 where const < minValue of the signed type of x
663 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
664 .slt(getMinSignedValue(*width).sext(commonWidth)))
665 return getIntAttr(getType(), APInt(1, 0));
666
667 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
668 if (isUnsigned && rhsCst->zext(commonWidth)
669 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
670 return getIntAttr(getType(), APInt(1, 1));
671
672 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
673 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
674 .sge(getMaxSignedValue(*width).sext(commonWidth)))
675 return getIntAttr(getType(), APInt(1, 1));
676 }
677 }
678
680 *this, adaptor.getOperands(), BinOpKind::Compare,
681 [=](const APSInt &a, const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
683 });
684}
685
686void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687 MLIRContext *context) {
688 results.insert<patterns::LTWithConstLHS>(context);
689}
690
691OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
693 bool isUnsigned = lhsType.isUnsigned();
694
695 // lt(x, x) -> 0
696 if (getLhs() == getRhs())
697 return getIntAttr(getType(), APInt(1, 0));
698
699 // lt(x, 0) -> 0 when x is unsigned
700 if (auto rhsCst = getConstant(adaptor.getRhs())) {
701 if (rhsCst->isZero() && lhsType.isUnsigned())
702 return getIntAttr(getType(), APInt(1, 0));
703 }
704
705 // Comparison against constant outside type bounds.
706 if (auto width = lhsType.getWidth()) {
707 if (auto rhsCst = getConstant(adaptor.getRhs())) {
708 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
710
711 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
712 // Handled explicitly above.
713
714 // lt(x, const) -> 0 where const <= minValue of the signed type of x
715 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
716 .sle(getMinSignedValue(*width).sext(commonWidth)))
717 return getIntAttr(getType(), APInt(1, 0));
718
719 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
720 if (isUnsigned && rhsCst->zext(commonWidth)
721 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
722 return getIntAttr(getType(), APInt(1, 1));
723
724 // lt(x, const) -> 1 where const > maxValue of the signed type of x
725 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
726 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
727 return getIntAttr(getType(), APInt(1, 1));
728 }
729 }
730
732 *this, adaptor.getOperands(), BinOpKind::Compare,
733 [=](const APSInt &a, const APSInt &b) -> APInt {
734 return APInt(1, a < b);
735 });
736}
737
738void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739 MLIRContext *context) {
740 results.insert<patterns::GEQWithConstLHS>(context);
741}
742
743OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
745 bool isUnsigned = lhsType.isUnsigned();
746
747 // geq(x, x) -> 1
748 if (getLhs() == getRhs())
749 return getIntAttr(getType(), APInt(1, 1));
750
751 // geq(x, 0) -> 1 when x is unsigned
752 if (auto rhsCst = getConstant(adaptor.getRhs())) {
753 if (rhsCst->isZero() && isUnsigned)
754 return getIntAttr(getType(), APInt(1, 1));
755 }
756
757 // Comparison against constant outside type bounds.
758 if (auto width = lhsType.getWidth()) {
759 if (auto rhsCst = getConstant(adaptor.getRhs())) {
760 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
762
763 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
764 if (isUnsigned && rhsCst->zext(commonWidth)
765 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
766 return getIntAttr(getType(), APInt(1, 0));
767
768 // geq(x, const) -> 0 where const > maxValue of the signed type of x
769 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
770 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
771 return getIntAttr(getType(), APInt(1, 0));
772
773 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
774 // Handled explicitly above.
775
776 // geq(x, const) -> 1 where const <= minValue of the signed type of x
777 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
778 .sle(getMinSignedValue(*width).sext(commonWidth)))
779 return getIntAttr(getType(), APInt(1, 1));
780 }
781 }
782
784 *this, adaptor.getOperands(), BinOpKind::Compare,
785 [=](const APSInt &a, const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
787 });
788}
789
790void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791 MLIRContext *context) {
792 results.insert<patterns::GTWithConstLHS>(context);
793}
794
795OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
797 bool isUnsigned = lhsType.isUnsigned();
798
799 // gt(x, x) -> 0
800 if (getLhs() == getRhs())
801 return getIntAttr(getType(), APInt(1, 0));
802
803 // Comparison against constant outside type bounds.
804 if (auto width = lhsType.getWidth()) {
805 if (auto rhsCst = getConstant(adaptor.getRhs())) {
806 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
808
809 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
810 if (isUnsigned && rhsCst->zext(commonWidth)
811 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
812 return getIntAttr(getType(), APInt(1, 0));
813
814 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
815 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
816 .sge(getMaxSignedValue(*width).sext(commonWidth)))
817 return getIntAttr(getType(), APInt(1, 0));
818
819 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
820 // This can never occur since const is unsigned and cannot be less than 0.
821
822 // gt(x, const) -> 1 where const < minValue of the signed type of x
823 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
824 .slt(getMinSignedValue(*width).sext(commonWidth)))
825 return getIntAttr(getType(), APInt(1, 1));
826 }
827 }
828
830 *this, adaptor.getOperands(), BinOpKind::Compare,
831 [=](const APSInt &a, const APSInt &b) -> APInt {
832 return APInt(1, a > b);
833 });
834}
835
836OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
837 // eq(x, x) -> 1
838 if (getLhs() == getRhs())
839 return getIntAttr(getType(), APInt(1, 1));
840
841 if (auto rhsCst = getConstant(adaptor.getRhs())) {
842 /// eq(x, 1) -> x when x is 1 bit.
843 /// TODO: Support SInt<1> on the LHS etc.
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
846 return getLhs();
847 }
848
850 *this, adaptor.getOperands(), BinOpKind::Compare,
851 [=](const APSInt &a, const APSInt &b) -> APInt {
852 return APInt(1, a == b);
853 });
854}
855
// Rewrite eq-against-constant into cheaper reduction idioms. These produce
// new ops, so they live in a canonicalizer rather than in fold().
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        // operands[1] is the (possibly constant) RHS as an attribute.
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
            return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        // No applicable rewrite.
        return {};
      });
}
885
886OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
887 // neq(x, x) -> 0
888 if (getLhs() == getRhs())
889 return getIntAttr(getType(), APInt(1, 0));
890
891 if (auto rhsCst = getConstant(adaptor.getRhs())) {
892 /// neq(x, 0) -> x when x is 1 bit.
893 /// TODO: Support SInt<1> on the LHS etc.
894 if (rhsCst->isZero() && getLhs().getType() == getType() &&
895 getRhs().getType() == getType())
896 return getLhs();
897 }
898
900 *this, adaptor.getOperands(), BinOpKind::Compare,
901 [=](const APSInt &a, const APSInt &b) -> APInt {
902 return APInt(1, a != b);
903 });
904}
905
// Rewrite neq-against-constant into cheaper reduction idioms. These produce
// new ops, so they live in a canonicalizer rather than in fold().
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        // operands[1] is the (possibly constant) RHS as an attribute.
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x))) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
            return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
          }
        }

        // No applicable rewrite.
        return {};
      });
}
936
937OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
938 // TODO: implement constant folding, etc.
939 // Tracked in https://github.com/llvm/circt/issues/6696.
940 return {};
941}
942
943OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
944 // TODO: implement constant folding, etc.
945 // Tracked in https://github.com/llvm/circt/issues/6724.
946 return {};
947}
948
949OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
950 // TODO: implement constant folding, etc.
951 // Tracked in https://github.com/llvm/circt/issues/6725.
952 return {};
953}
954
955OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
956 if (auto rhsCst = getConstant(adaptor.getRhs())) {
957 // Constant folding
958 if (auto lhsCst = getConstant(adaptor.getLhs()))
959
960 return IntegerAttr::get(
961 IntegerType::get(getContext(), lhsCst->getBitWidth()),
962 lhsCst->shl(*rhsCst));
963
964 // integer.shl(x, 0) -> x
965 if (rhsCst->isZero())
966 return getLhs();
967 }
968
969 return {};
970}
971
972//===----------------------------------------------------------------------===//
973// Unary Operators
974//===----------------------------------------------------------------------===//
975
976OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
977 auto base = getInput().getType();
978 auto w = getBitWidth(base);
979 if (w)
980 return getIntAttr(getType(), APInt(32, *w));
981 return {};
982}
983
984OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
985 // No constant can be 'x' by definition.
986 if (auto cst = getConstant(adaptor.getArg()))
987 return getIntAttr(getType(), APInt(1, 0));
988 return {};
989}
990
991OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
992 // No effect.
993 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
994 return getInput();
995
996 // Be careful to only fold the cast into the constant if the size is known.
997 // Otherwise width inference may produce differently-sized constants if the
998 // sign changes.
999 if (getType().base().hasWidth())
1000 if (auto cst = getConstant(adaptor.getInput()))
1001 return getIntAttr(getType(), *cst);
1002
1003 return {};
1004}
1005
1006void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1007 MLIRContext *context) {
1008 results.insert<patterns::StoUtoS>(context);
1009}
1010
1011OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1012 // No effect.
1013 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1014 return getInput();
1015
1016 // Be careful to only fold the cast into the constant if the size is known.
1017 // Otherwise width inference may produce differently-sized constants if the
1018 // sign changes.
1019 if (getType().base().hasWidth())
1020 if (auto cst = getConstant(adaptor.getInput()))
1021 return getIntAttr(getType(), *cst);
1022
1023 return {};
1024}
1025
1026void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027 MLIRContext *context) {
1028 results.insert<patterns::UtoStoU>(context);
1029}
1030
1031OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1032 // No effect.
1033 if (getInput().getType() == getType())
1034 return getInput();
1035
1036 // Constant fold.
1037 if (auto cst = getConstant(adaptor.getInput()))
1038 return BoolAttr::get(getContext(), cst->getBoolValue());
1039
1040 return {};
1041}
1042
1043OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1044 // No effect.
1045 if (getInput().getType() == getType())
1046 return getInput();
1047
1048 // Constant fold.
1049 if (auto cst = getConstant(adaptor.getInput()))
1050 return BoolAttr::get(getContext(), cst->getBoolValue());
1051
1052 return {};
1053}
1054
1055OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1056 if (!hasKnownWidthIntTypes(*this))
1057 return {};
1058
1059 // Signed to signed is a noop, unsigned operands prepend a zero bit.
1060 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1061 getType().base().getWidthOrSentinel()))
1062 return getIntAttr(getType(), *cst);
1063
1064 return {};
1065}
1066
1067void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068 MLIRContext *context) {
1069 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1070}
1071
1072OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1073 if (!hasKnownWidthIntTypes(*this))
1074 return {};
1075
1076 // FIRRTL negate always adds a bit.
1077 // -x ---> 0-sext(x) or 0-zext(x)
1078 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1079 getType().base().getWidthOrSentinel()))
1080 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1081
1082 return {};
1083}
1084
1085OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1086 if (!hasKnownWidthIntTypes(*this))
1087 return {};
1088
1089 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1090 getType().base().getWidthOrSentinel()))
1091 return getIntAttr(getType(), ~*cst);
1092
1093 return {};
1094}
1095
1096void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097 MLIRContext *context) {
1098 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1099 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1100 patterns::NotGt>(context);
1101}
1102
1103OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1104 if (!hasKnownWidthIntTypes(*this))
1105 return {};
1106
1107 if (getInput().getType().getBitWidthOrSentinel() == 0)
1108 return getIntAttr(getType(), APInt(1, 1));
1109
1110 // x == -1
1111 if (auto cst = getConstant(adaptor.getInput()))
1112 return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1113
1114 // one bit is identity. Only applies to UInt since we can't make a cast
1115 // here.
1116 if (isUInt1(getInput().getType()))
1117 return getInput();
1118
1119 return {};
1120}
1121
1122void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123 MLIRContext *context) {
1124 results
1125 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1126 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1127 patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1128 patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1129}
1130
1131OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1132 if (!hasKnownWidthIntTypes(*this))
1133 return {};
1134
1135 if (getInput().getType().getBitWidthOrSentinel() == 0)
1136 return getIntAttr(getType(), APInt(1, 0));
1137
1138 // x != 0
1139 if (auto cst = getConstant(adaptor.getInput()))
1140 return getIntAttr(getType(), APInt(1, !cst->isZero()));
1141
1142 // one bit is identity. Only applies to UInt since we can't make a cast
1143 // here.
1144 if (isUInt1(getInput().getType()))
1145 return getInput();
1146
1147 return {};
1148}
1149
1150void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151 MLIRContext *context) {
1152 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1153 patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1154 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1155}
1156
1157OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1158 if (!hasKnownWidthIntTypes(*this))
1159 return {};
1160
1161 if (getInput().getType().getBitWidthOrSentinel() == 0)
1162 return getIntAttr(getType(), APInt(1, 0));
1163
1164 // popcount(x) & 1
1165 if (auto cst = getConstant(adaptor.getInput()))
1166 return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1167
1168 // one bit is identity. Only applies to UInt since we can't make a cast here.
1169 if (isUInt1(getInput().getType()))
1170 return getInput();
1171
1172 return {};
1173}
1174
1175void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1176 MLIRContext *context) {
1177 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1178 patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1179 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
1180 context);
1181}
1182
1183//===----------------------------------------------------------------------===//
1184// Other Operators
1185//===----------------------------------------------------------------------===//
1186
1187OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1188 // cat(x, 0-width) -> x
1189 // cat(0-width, x) -> x
1190 // Limit to unsigned (result type), as cannot insert cast here.
1191 IntType lhsType = getLhs().getType();
1192 IntType rhsType = getRhs().getType();
1193 if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1194 return getRhs();
1195 if (rhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1196 return getLhs();
1197
1198 if (!hasKnownWidthIntTypes(*this))
1199 return {};
1200
1201 // Constant fold cat.
1202 if (auto lhs = getConstant(adaptor.getLhs()))
1203 if (auto rhs = getConstant(adaptor.getRhs()))
1204 return getIntAttr(getType(), lhs->concat(*rhs));
1205
1206 return {};
1207}
1208
1209void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1210 MLIRContext *context) {
1211 results.insert<patterns::DShlOfConstant>(context);
1212}
1213
1214void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215 MLIRContext *context) {
1216 results.insert<patterns::DShrOfConstant>(context);
1217}
1218
1219void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1220 MLIRContext *context) {
1221 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1222 patterns::CatCast>(context);
1223}
1224
1225OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1226 auto op = (*this);
1227 // BitCast is redundant if input and result types are same.
1228 if (op.getType() == op.getInput().getType())
1229 return op.getInput();
1230
1231 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1232 // final result type.
1233 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1234 if (op.getType() == in.getInput().getType())
1235 return in.getInput();
1236
1237 return {};
1238}
1239
1240OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1241 IntType inputType = getInput().getType();
1242 IntType resultType = getType();
1243 // If we are extracting the entire input, then return it.
1244 if (inputType == getType() && resultType.hasWidth())
1245 return getInput();
1246
1247 // Constant fold.
1248 if (hasKnownWidthIntTypes(*this))
1249 if (auto cst = getConstant(adaptor.getInput()))
1250 return getIntAttr(resultType,
1251 cst->extractBits(getHi() - getLo() + 1, getLo()));
1252
1253 return {};
1254}
1255
1256void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257 MLIRContext *context) {
1258 results
1259 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1260 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1261}
1262
1263/// Replace the specified operation with a 'bits' op from the specified hi/lo
1264/// bits. Insert a cast to handle the case where the original operation
1265/// returned a signed integer.
/// Replace the specified operation with a 'bits' op from the specified hi/lo
/// bits. Insert a cast to handle the case where the original operation
/// returned a signed integer.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only emit the extraction if it actually narrows/widens the value;
  // otherwise reuse `value` directly.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);

  // BitsPrimOp always yields UInt; reconcile the signedness with the type
  // the original op produced so the replacement is type-correct.
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1280
// Shared folding logic for MuxPrimOp and Mux2CellIntrinsicOp (both have
// sel/high/low operands). Returns an attribute or value, or {} when no fold
// applies.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  // (Only legal when the selected operand already has the result type.)
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1328
1329OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1330 return foldMux(*this, adaptor);
1331}
1332
1333OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1334 return foldMux(*this, adaptor);
1335}
1336
1337OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1338
1339namespace {
1340
1341// If the mux has a known output width, pad the operands up to this width.
1342// Most folds on mux require that folded operands are of the same width as
1343// the mux itself.
// If the mux has a known output width, pad the operands up to this width.
// Most folds on mux require that folded operands are of the same width as
// the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    // Without a known result width there is nothing to pad to.
    if (width < 0)
      return failure();

    // Returns the input unchanged when its width is unknown or already
    // matches; otherwise wraps it in a pad to the mux's result type.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return rewriter
          .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Fail (no rewrite) if neither operand needed padding, so the driver
    // doesn't loop on this pattern.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    // Preserve the name hint while swapping in the padded operands.
    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1378
1379// Find muxes which have conditions dominated by other muxes with the same
1380// condition.
// Find muxes which have conditions dominated by other muxes with the same
// condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when walking a chain of dependent muxes.
  static const int depthLimit = 5;

  // Either rewrite `mux`'s high/low operands in place (returning a null
  // Value) or, when in-place update is not safe (the mux has other users),
  // clone it with the new operands and return the clone's result.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return rewriter
        .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
                           ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  // Returns the simplified value for `op` under that assumption, or null.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same condition, assumed true: the mux always yields its high operand.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // In-place update is only safe while every mux on the path is singly
    // used; otherwise we must clone.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  // Mirror image of tryCondTrue: a matching mux yields its low operand.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Inside the high operand's tree, this mux's condition is known true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Inside the low operand's tree, this mux's condition is known false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1471} // namespace
1472
1473void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1474 MLIRContext *context) {
1475 results
1476 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1477 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1478 patterns::MuxSameTrue, patterns::MuxSameFalse,
1479 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1480 context);
1481}
1482
1483void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1484 RewritePatternSet &results, MLIRContext *context) {
1485 results.add<patterns::Mux2PadSel>(context);
1486}
1487
1488void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1489 RewritePatternSet &results, MLIRContext *context) {
1490 results.add<patterns::Mux4PadSel>(context);
1491}
1492
1493OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1494 auto input = this->getInput();
1495
1496 // pad(x) -> x if the width doesn't change.
1497 if (input.getType() == getType())
1498 return input;
1499
1500 // Need to know the input width.
1501 auto inputType = input.getType().base();
1502 int32_t width = inputType.getWidthOrSentinel();
1503 if (width == -1)
1504 return {};
1505
1506 // Constant fold.
1507 if (auto cst = getConstant(adaptor.getInput())) {
1508 auto destWidth = getType().base().getWidthOrSentinel();
1509 if (destWidth == -1)
1510 return {};
1511
1512 if (inputType.isSigned() && cst->getBitWidth())
1513 return getIntAttr(getType(), cst->sext(destWidth));
1514 return getIntAttr(getType(), cst->zext(destWidth));
1515 }
1516
1517 return {};
1518}
1519
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold: shl widens the result by the shift amount, so extend the
  // constant to the result width before shifting.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      auto resultWidth = inputWidth + shiftAmount;
      // Defensive clamp; with inputWidth >= 0 the shift amount never exceeds
      // the result width, so this is normally a no-op.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
1540
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Unknown input width: can't reason about the result.
  if (inputWidth == -1)
    return {};
  // Zero-width input shifts to a zero-width zero.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold: arithmetic shift for signed (clamped so the sign bit
  // survives), logical shift for unsigned, then truncate to the result width.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1575
// Canonicalize shr into a bit-extraction once the input width is known.
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().base().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  // shr(x, n) keeps bits [inputWidth-1 : n].
  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
1597
// Canonicalize head into a bit-extraction once the input width is known.
LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  // head(x, n) keeps the top n bits: [inputWidth-1 : inputWidth-n].
  // (keepAmount == 0 would produce a zero-width extract, so skip it.)
  unsigned keepAmount = op.getAmount();
  if (keepAmount)
    replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
                    rewriter);
  return success();
}
1611
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  // Constant fold: head(x, n) keeps the top n bits, i.e. shift right by
  // (width - n) and truncate to n bits.
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput())) {
      int shiftAmount =
          getInput().getType().base().getWidthOrSentinel() - getAmount();
      return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
    }

  return {};
}
1622
1623OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1624 if (hasKnownWidthIntTypes(*this))
1625 if (auto cst = getConstant(adaptor.getInput()))
1626 return getIntAttr(getType(),
1627 cst->trunc(getType().base().getWidthOrSentinel()));
1628 return {};
1629}
1630
// Canonicalize tail into a bit-extraction once the input width is known.
LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  // tail(x, n) drops the top n bits, keeping [inputWidth-n-1 : 0].
  // (Dropping all bits would produce a zero-width extract, so skip it.)
  unsigned dropAmount = op.getAmount();
  if (dropAmount != unsigned(inputWidth))
    replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
                    rewriter);
  return success();
}
1644
1645void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1646 MLIRContext *context) {
1647 results.add<patterns::SubaccessOfConstant>(context);
1648}
1649
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it.
  // (Operand 0 is the index; the data inputs start at operand 1.)
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  // A constant index selects exactly one input. Inputs are stored in
  // descending index order, so input i is at position (size - 1 - i).
  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
1663
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements. (An index of width w can only address the last 2^w inputs,
  // since inputs are stored in descending index order.)
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every input must be a subindex of the same vector, with indices
    // decreasing from n-1 down to 0 across the operand list.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
1715
1716//===----------------------------------------------------------------------===//
1717// Declarations
1718//===----------------------------------------------------------------------===//
1719
1720/// Scan all the uses of the specified value, checking to see if there is
1721/// exactly one connect that has the value as its destination. This returns the
1722/// operation if found and if all the other users are "reads" from the value.
1723/// Returns null if there are no connects, or multiple connects to the value, or
1724/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
1725///
1726/// Note that this will simply return the connect, which is located *anywhere*
1727/// after the definition of the value. Users of this function are likely
1728/// interested in the source side of the returned connect, the definition of
1729/// which does likely not dominate the original value.
MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
  MatchingConnectOp connect;
  for (Operation *user : value.getUsers()) {
    // If we see an attach or aggregate subelement access, just conservatively
    // fail: the value may be written through an alias we can't track here.
    if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
      return {};

    if (auto aConnect = dyn_cast<FConnectLike>(user))
      if (aConnect.getDest() == value) {
        auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
        // If this is not a matching connect, a second matching connect or in a
        // different block, fail.
        if (!matchingConnect || (connect && connect != matchingConnect) ||
            matchingConnect->getBlock() != value.getParentBlock())
          return {};
        else
          connect = matchingConnect;
      }
  }
  // Null if no matching connect was found.
  return connect;
}
1751
1752// Forward simple values through wire's and reg's.
// Forward simple values through wire's and reg's.
// Given `connect decl, src` where `decl` is a droppable wire/reg written
// exactly once, replace all uses of `decl` with `src` and erase the connect.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Bail on anything the user may be relying on by name or annotation.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  // (with one use, the only user is this connect and there is nothing to
  // forward to).
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
1810
// Register the declaratively-specified (TableGen-generated) canonicalization
// patterns for ConnectOp.
void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
      context);
}
1816
1817LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
1818 PatternRewriter &rewriter) {
1819 // TODO: Canonicalize towards explicit extensions and flips here.
1820
1821 // If there is a simple value connected to a foldable decl like a wire or reg,
1822 // see if we can eliminate the decl.
1823 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
1824 return success();
1825 return failure();
1826}
1827
1828//===----------------------------------------------------------------------===//
1829// Statements
1830//===----------------------------------------------------------------------===//
1831
1832/// If the specified value has an AttachOp user strictly dominating by
1833/// "dominatingAttach" then return it.
1834static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
1835 for (auto *user : value.getUsers()) {
1836 auto attach = dyn_cast<AttachOp>(user);
1837 if (!attach || attach == dominatedAttach)
1838 continue;
1839 if (attach->isBeforeInBlock(dominatedAttach))
1840 return attach;
1841 }
1842 return {};
1843}
1844
// Canonicalize attaches: drop trivial ones, merge dominating attaches on a
// shared operand, and remove wires that exist solely to be attached.
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      // Union the operand lists of the two attaches, deduplicating the
      // shared operand, and replace both with one combined attach.
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add the shared operand twice.
          newOperands.push_back(newOperand);
      rewriter.create<AttachOp>(op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the dead wire, then drop both.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire itself.
            newOperands.push_back(newOperand);

        rewriter.create<AttachOp>(op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
1890
/// Replaces the given op with the contents of the given single-block region.
/// The region's block is inlined immediately before `op`; erasing `op` itself
/// is left to the caller.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
1897
1898LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
1899 if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1900 if (constant.getValue().isAllOnes())
1901 replaceOpWithRegion(rewriter, op, op.getThenRegion());
1902 else if (op.hasElseRegion() && !op.getElseRegion().empty())
1903 replaceOpWithRegion(rewriter, op, op.getElseRegion());
1904
1905 rewriter.eraseOp(op);
1906
1907 return success();
1908 }
1909
1910 // Erase empty if-else block.
1911 if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1912 op.getElseBlock().empty()) {
1913 rewriter.eraseBlock(&op.getElseBlock());
1914 return success();
1915 }
1916
1917 // Erase empty whens.
1918
1919 // If there is stuff in the then block, leave this operation alone.
1920 if (!op.getThenBlock().empty())
1921 return failure();
1922
1923 // If not and there is no else, then this operation is just useless.
1924 if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1925 rewriter.eraseOp(op);
1926 return success();
1927 }
1928 return failure();
1929}
1930
1931namespace {
1932// Remove private nodes. If they have an interesting names, move the name to
1933// the source expression.
// Remove private nodes. If they have an interesting name, move the name to
// the source expression.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    // Nodes with meaningful names, symbols, annotations, or force handles
    // must be preserved.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).empty() || node.isForceable())
      return failure();
    auto *newOp = node.getInput().getDefiningOp();
    // Best effort, do not rename InstanceOp
    if (newOp && !isa<InstanceOp>(newOp))
      updateName(rewriter, newOp, name);
    // Forward all users of the node to its input and delete the node.
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
1952
1953// Bypass nodes.
// Bypass nodes. Redirects all users of a node to the node's input, leaving
// the (now unused) node in place.
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    // Keep nodes with symbols, annotations, or force handles. A node that is
    // already unused has nothing to bypass (this also stops re-matching after
    // a successful rewrite).
    if (node.getInnerSym() || !AnnotationSet(node).empty() ||
        node.use_empty() || node.isForceable())
      return failure();
    rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
    return success();
  }
};
1967
1968} // namespace
1969
// If a forceable declaration's ref result is never used, demote it to the
// equivalent non-forceable form.
template <typename OpTy>
static LogicalResult demoteForceableIfUnused(OpTy op,
                                             PatternRewriter &rewriter) {
  if (!op.isForceable() || !op.getDataRef().use_empty())
    return failure();

  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
  return success();
}
1979
1980// Interesting names and symbols and don't touch force nodes to stick around.
1981LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1982 SmallVectorImpl<OpFoldResult> &results) {
1983 if (!hasDroppableName())
1984 return failure();
1985 if (hasDontTouch(getResult())) // handles inner symbols
1986 return failure();
1987 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
1988 return failure();
1989 if (isForceable())
1990 return failure();
1991 if (!adaptor.getInput())
1992 return failure();
1993
1994 results.push_back(adaptor.getInput());
1995 return success();
1996}
1997
// Register NodeOp canonicalizations: name folding and forceability demotion.
void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<FoldNodeName>(context);
  results.add(demoteForceableIfUnused<NodeOp>);
}
2003
2004namespace {
2005// For a lhs, find all the writers of fields of the aggregate type. If there
2006// is one writer for each field, merge the writes
// For a lhs, find all the writers of fields of the aggregate type. If there
// is one writer for each field, merge the writes into a single aggregate
// connect built from a BundleCreateOp/VectorCreateOp.
struct AggOneShot : public mlir::RewritePattern {
  // NOTE(review): the `weight` parameter is accepted but not forwarded; the
  // pattern is always registered with benefit 0.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Collect, field by field, the single value written to each element of the
  // aggregate `lhs`. Returns the values in field order, or an empty vector if
  // any field is unwritten, written twice, written as a whole, written from a
  // different region, or used in a way this analysis does not understand.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Map from field/element index to the value connected to it.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      // All uses must be in the same region as the declaration.
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
        // A write to the whole aggregate defeats per-field merging.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        // Each subfield may only be the destination of matching connects.
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        // Same handling for vector elements.
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          return {};
        }
      } else {
        // Any other kind of user is conservatively rejected.
        return {};
      }
    }

    // Every field must have exactly one write; return them in index order.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    // Insert the merged connect at the end of the block so it dominates
    // nothing it shouldn't and follows all the per-field sources.
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot matchingconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Build the aggregate value and connect it as a whole.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
2105
// Instantiations of AggOneShot for the three ops whose aggregate results can
// have their per-field writes merged.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2118} // namespace
2119
// Register WireOp canonicalizations: aggregate write merging and
// forceability demotion.
void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.insert<WireAggOneShot>(context);
  results.add(demoteForceableIfUnused<WireOp>);
}
2125
// Register SubindexOp canonicalizations: aggregate write merging.
void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubindexAggOneShot>(context);
}
2130
2131OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2132 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2133 if (!attr)
2134 return {};
2135 return attr[getIndex()];
2136}
2137
2138OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2139 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2140 if (!attr)
2141 return {};
2142 auto index = getFieldIndex();
2143 return attr[index];
2144}
2145
// Register SubfieldOp canonicalizations: aggregate write merging.
void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.insert<SubfieldAggOneShot>(context);
}
2150
2151static Attribute collectFields(MLIRContext *context,
2152 ArrayRef<Attribute> operands) {
2153 for (auto operand : operands)
2154 if (!operand)
2155 return {};
2156 return ArrayAttr::get(context, operands);
2157}
2158
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  //
  // i.e., if every operand is a subfield of the same source value, in field
  // order, and the source has exactly this bundle type, the create is an
  // identity and folds to that source.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Each remaining operand must extract field `i` of the same
                // input.
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all fields are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2177
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  //
  // i.e., if every operand is a subindex of the same source value, in index
  // order, and the source has exactly this vector type, the create is an
  // identity and folds to that source.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Each remaining operand must extract element `i` of the same
                // input.
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all elements are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2195
2196OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2197 if (getOperand().getType() == getType())
2198 return getOperand();
2199 return {};
2200}
2201
2202namespace {
2203// A register with constant reset and all connection to either itself or the
2204// same constant, must be replaced by the constant.
// A register with constant reset and all connection to either itself or the
// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    // The reset value must be a constant, and the register must have no
    // observable properties (dont-touch, annotations, force handle).
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The register must be driven by mux(sel, const, reg) or
    // mux(sel, reg, const) where the constant equals the reset value.
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    // mux(sel, const, other-than-reg) does not qualify.
    if (constOp && low != reg)
      return failure();
    // Accept the mirrored form mux(sel, reg, const).
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must match the reset value exactly.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2257} // namespace
2258
2259static bool isDefinedByOneConstantOp(Value v) {
2260 if (auto c = v.getDefiningOp<ConstantOp>())
2261 return c.getValue().isOne();
2262 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2263 return sc.getValue();
2264 return false;
2265}
2266
// A register whose reset signal is constantly asserted always holds its reset
// value, so replace it with a node on that value (dropping any writes to it).
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  // Preserve the name, annotations, symbol, and forceability on the new node.
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2279
// Register RegResetOp canonicalizations: zero-reset simplification,
// constant-reset mux folding, and forceability demotion.
void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
  results.add(demoteForceableIfUnused<RegResetOp>);
}
2286
2287// Returns the value connected to a port, if there is only one.
2288static Value getPortFieldValue(Value port, StringRef name) {
2289 auto portTy = type_cast<BundleType>(port.getType());
2290 auto fieldIndex = portTy.getElementIndex(name);
2291 assert(fieldIndex && "missing field on memory port");
2292
2293 Value value = {};
2294 for (auto *op : port.getUsers()) {
2295 auto portAccess = cast<SubfieldOp>(op);
2296 if (fieldIndex != portAccess.getFieldIndex())
2297 continue;
2298 auto conn = getSingleConnectUserOf(portAccess);
2299 if (!conn || value)
2300 return {};
2301 value = conn.getSrc();
2302 }
2303 return value;
2304}
2305
2306// Returns true if the enable field of a port is set to false.
2307static bool isPortDisabled(Value port) {
2308 auto value = getPortFieldValue(port, "en");
2309 if (!value)
2310 return false;
2311 auto portConst = value.getDefiningOp<ConstantOp>();
2312 if (!portConst)
2313 return false;
2314 return portConst.getValue().isZero();
2315}
2316
2317// Returns true if the data output is unused.
2318static bool isPortUnused(Value port, StringRef data) {
2319 auto portTy = type_cast<BundleType>(port.getType());
2320 auto fieldIndex = portTy.getElementIndex(data);
2321 assert(fieldIndex && "missing enable flag on memory port");
2322
2323 for (auto *op : port.getUsers()) {
2324 auto portAccess = cast<SubfieldOp>(op);
2325 if (fieldIndex != portAccess.getFieldIndex())
2326 continue;
2327 if (!portAccess.use_empty())
2328 return false;
2329 }
2330
2331 return true;
2332}
2333
2334// Returns the value connected to a port, if there is only one.
2335static void replacePortField(PatternRewriter &rewriter, Value port,
2336 StringRef name, Value value) {
2337 auto portTy = type_cast<BundleType>(port.getType());
2338 auto fieldIndex = portTy.getElementIndex(name);
2339 assert(fieldIndex && "missing field on memory port");
2340
2341 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2342 auto portAccess = cast<SubfieldOp>(op);
2343 if (fieldIndex != portAccess.getFieldIndex())
2344 continue;
2345 rewriter.replaceAllUsesWith(portAccess, value);
2346 rewriter.eraseOp(portAccess);
2347 }
2348}
2349
2350// Remove accesses to a port which is used.
2351static void erasePort(PatternRewriter &rewriter, Value port) {
2352 // Helper to create a dummy 0 clock for the dummy registers.
2353 Value clock;
2354 auto getClock = [&] {
2355 if (!clock)
2356 clock = rewriter.create<SpecialConstantOp>(
2357 port.getLoc(), ClockType::get(rewriter.getContext()), false);
2358 return clock;
2359 };
2360
2361 // Find the clock field of the port and determine whether the port is
2362 // accessed only through its subfields or as a whole wire. If the port
2363 // is used in its entirety, replace it with a wire. Otherwise,
2364 // eliminate individual subfields and replace with reasonable defaults.
2365 for (auto *op : port.getUsers()) {
2366 auto subfield = dyn_cast<SubfieldOp>(op);
2367 if (!subfield) {
2368 auto ty = port.getType();
2369 auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
2370 rewriter.replaceAllUsesWith(port, reg.getResult());
2371 return;
2372 }
2373 }
2374
2375 // Remove all connects to field accesses as they are no longer relevant.
2376 // If field values are used anywhere, which should happen solely for read
2377 // ports, a dummy register is introduced which replicates the behaviour of
2378 // memory that is never written, but might be read.
2379 for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
2380 auto access = cast<SubfieldOp>(accessOp);
2381 for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
2382 auto connect = dyn_cast<FConnectLike>(user);
2383 if (connect && connect.getDest() == access) {
2384 rewriter.eraseOp(user);
2385 continue;
2386 }
2387 }
2388 if (access.use_empty()) {
2389 rewriter.eraseOp(access);
2390 continue;
2391 }
2392
2393 // Replace read values with a register that is never written, handing off
2394 // the canonicalization of such a register to another canonicalizer.
2395 auto ty = access.getType();
2396 auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
2397 rewriter.replaceOp(access, reg.getResult());
2398 }
2399 assert(port.use_empty() && "port should have no remaining uses");
2400}
2401
2402namespace {
2403// If memory has known, but zero width, eliminate it.
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only integer memories with exactly zero-width data qualify.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.  Replace each field access with a
    // wire instead.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports with a zero-width constant.
          auto zero = rewriter.create<firrtl::ConstantOp>(
              wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
              APInt::getZero(0));
          rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2445
2446// If memory has no write ports and no file initialization, eliminate it.
// If memory has no write ports and no file initialization, eliminate it.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // The memory must be read-only or write-only: the presence of both read
    // and write ports (or any debug/read-write port) disqualifies it.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    // Detach all port users, then drop the memory itself.
    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2489
2490// Eliminate the dead ports of memories.
2491struct FoldUnusedPorts : public mlir::RewritePattern {
2492 FoldUnusedPorts(MLIRContext *context)
2493 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2494 LogicalResult matchAndRewrite(Operation *op,
2495 PatternRewriter &rewriter) const override {
2496 MemOp mem = cast<MemOp>(op);
2497 if (hasDontTouch(mem))
2498 return failure();
2499 // Identify the dead and changed ports.
2500 llvm::SmallBitVector deadPorts(mem.getNumResults());
2501 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2502 // Do not simplify annotated ports.
2503 if (!mem.getPortAnnotation(i).empty())
2504 continue;
2505
2506 // Skip debug ports.
2507 auto kind = mem.getPortKind(i);
2508 if (kind == MemOp::PortKind::Debug)
2509 continue;
2510
2511 // If a port is disabled, always eliminate it.
2512 if (isPortDisabled(port)) {
2513 deadPorts.set(i);
2514 continue;
2515 }
2516 // Eliminate read ports whose outputs are not used.
2517 if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
2518 deadPorts.set(i);
2519 continue;
2520 }
2521 }
2522 if (deadPorts.none())
2523 return failure();
2524
2525 // Rebuild the new memory with the altered ports.
2526 SmallVector<Type> resultTypes;
2527 SmallVector<StringRef> portNames;
2528 SmallVector<Attribute> portAnnotations;
2529 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2530 if (deadPorts[i])
2531 continue;
2532 resultTypes.push_back(port.getType());
2533 portNames.push_back(mem.getPortName(i));
2534 portAnnotations.push_back(mem.getPortAnnotation(i));
2535 }
2536
2537 MemOp newOp;
2538 if (!resultTypes.empty())
2539 newOp = rewriter.create<MemOp>(
2540 mem.getLoc(), resultTypes, mem.getReadLatency(),
2541 mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
2542 rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
2543 mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
2544 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2545
2546 // Replace the dead ports with dummy wires.
2547 unsigned nextPort = 0;
2548 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2549 if (deadPorts[i])
2550 erasePort(rewriter, port);
2551 else
2552 rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
2553 }
2554
2555 rewriter.eraseOp(op);
2556 return success();
2557 }
2558};
2559
2560// Rewrite write-only read-write ports to write ports.
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Rebuild the memory, demoting each dead read-write port to a write port.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = rewriter.create<MemOp>(
        mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
        mem.getName(), mem.getNameKind(), mem.getAnnotations(),
        rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
        mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          // NOTE(review): this assert message is stale; it guards the lookup
          // of whichever field `fromName` names, not the enable flag.
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
                                                     newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Map read-write fields onto the corresponding write-port fields.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        // Unchanged port: forward users directly to the new memory.
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2651
// Eliminate memory data bits that are never used.
2653struct FoldUnusedBits : public mlir::RewritePattern {
2654 FoldUnusedBits(MLIRContext *context)
2655 : RewritePattern(MemOp::getOperationName(), 0, context) {}
2656
2657 LogicalResult matchAndRewrite(Operation *op,
2658 PatternRewriter &rewriter) const override {
2659 MemOp mem = cast<MemOp>(op);
2660 if (hasDontTouch(mem))
2661 return failure();
2662
2663 // Only apply the transformation if the memory is not sequential.
2664 const auto &summary = mem.getSummary();
2665 if (summary.isMasked || summary.isSeqMem())
2666 return failure();
2667
2668 auto type = type_dyn_cast<IntType>(mem.getDataType());
2669 if (!type)
2670 return failure();
2671 auto width = type.getBitWidthOrSentinel();
2672 if (width <= 0)
2673 return failure();
2674
2675 llvm::SmallBitVector usedBits(width);
2676 DenseMap<unsigned, unsigned> mapping;
2677
2678 // Find which bits are used out of the users of a read port. This detects
2679 // ports whose data/rdata field is used only through bit select ops. The
2680 // bit selects are then used to build a bit-mask. The ops are collected.
2681 SmallVector<BitsPrimOp> readOps;
2682 auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
2683 auto portTy = type_cast<BundleType>(port.getType());
2684 auto fieldIndex = portTy.getElementIndex(field);
2685 assert(fieldIndex && "missing data port");
2686
2687 for (auto *op : port.getUsers()) {
2688 auto portAccess = cast<SubfieldOp>(op);
2689 if (fieldIndex != portAccess.getFieldIndex())
2690 continue;
2691
2692 for (auto *user : op->getUsers()) {
2693 auto bits = dyn_cast<BitsPrimOp>(user);
2694 if (!bits)
2695 return failure();
2696
2697 usedBits.set(bits.getLo(), bits.getHi() + 1);
2698 if (usedBits.all())
2699 return failure();
2700
2701 mapping[bits.getLo()] = 0;
2702 readOps.push_back(bits);
2703 }
2704 }
2705
2706 return success();
2707 };
2708
2709 // Finds the users of write ports. This expects all the data/wdata fields
2710 // of the ports to be used solely as the destination of matching connects.
2711 // If a memory has ports with other uses, it is excluded from optimisation.
2712 SmallVector<MatchingConnectOp> writeOps;
2713 auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
2714 auto portTy = type_cast<BundleType>(port.getType());
2715 auto fieldIndex = portTy.getElementIndex(field);
2716 assert(fieldIndex && "missing data port");
2717
2718 for (auto *op : port.getUsers()) {
2719 auto portAccess = cast<SubfieldOp>(op);
2720 if (fieldIndex != portAccess.getFieldIndex())
2721 continue;
2722
2723 auto conn = getSingleConnectUserOf(portAccess);
2724 if (!conn)
2725 return failure();
2726
2727 writeOps.push_back(conn);
2728 }
2729 return success();
2730 };
2731
2732 // Traverse all ports and find the read and used data fields.
2733 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2734 // Do not simplify annotated ports.
2735 if (!mem.getPortAnnotation(i).empty())
2736 return failure();
2737
2738 switch (mem.getPortKind(i)) {
2739 case MemOp::PortKind::Debug:
2740 // Skip debug ports.
2741 return failure();
2742 case MemOp::PortKind::Write:
2743 if (failed(findWriteUsers(port, "data")))
2744 return failure();
2745 continue;
2746 case MemOp::PortKind::Read:
2747 if (failed(findReadUsers(port, "data")))
2748 return failure();
2749 continue;
2750 case MemOp::PortKind::ReadWrite:
2751 if (failed(findWriteUsers(port, "wdata")))
2752 return failure();
2753 if (failed(findReadUsers(port, "rdata")))
2754 return failure();
2755 continue;
2756 }
2757 llvm_unreachable("unknown port kind");
2758 }
2759
2760 // Unused memories are handled in a different canonicalizer.
2761 if (usedBits.none())
2762 return failure();
2763
2764 // Build a mapping of existing indices to compacted ones.
2765 SmallVector<std::pair<unsigned, unsigned>> ranges;
2766 unsigned newWidth = 0;
2767 for (int i = usedBits.find_first(); 0 <= i && i < width;) {
2768 int e = usedBits.find_next_unset(i);
2769 if (e < 0)
2770 e = width;
2771 for (int idx = i; idx < e; ++idx, ++newWidth) {
2772 if (auto it = mapping.find(idx); it != mapping.end()) {
2773 it->second = newWidth;
2774 }
2775 }
2776 ranges.emplace_back(i, e - 1);
2777 i = e != width ? usedBits.find_next(e) : e;
2778 }
2779
2780 // Create the new op with the new port types.
2781 auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
2782 SmallVector<Type> portTypes;
2783 for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2784 portTypes.push_back(
2785 MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
2786 }
2787 auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
2788 mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
2789 mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
2790 mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
2791 mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());
2792
2793 // Rewrite bundle users to the new data type.
2794 auto rewriteSubfield = [&](Value port, StringRef field) {
2795 auto portTy = type_cast<BundleType>(port.getType());
2796 auto fieldIndex = portTy.getElementIndex(field);
2797 assert(fieldIndex && "missing data port");
2798
2799 rewriter.setInsertionPointAfter(newMem);
2800 auto newPortAccess =
2801 rewriter.create<SubfieldOp>(port.getLoc(), port, field);
2802
2803 for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2804 auto portAccess = cast<SubfieldOp>(op);
2805 if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
2806 continue;
2807 rewriter.replaceOp(portAccess, newPortAccess.getResult());
2808 }
2809 };
2810
2811 // Rewrite the field accesses.
2812 for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
2813 switch (newMem.getPortKind(i)) {
2814 case MemOp::PortKind::Debug:
2815 llvm_unreachable("cannot rewrite debug port");
2816 case MemOp::PortKind::Write:
2817 rewriteSubfield(port, "data");
2818 continue;
2819 case MemOp::PortKind::Read:
2820 rewriteSubfield(port, "data");
2821 continue;
2822 case MemOp::PortKind::ReadWrite:
2823 rewriteSubfield(port, "rdata");
2824 rewriteSubfield(port, "wdata");
2825 continue;
2826 }
2827 llvm_unreachable("unknown port kind");
2828 }
2829
2830 // Rewrite the reads to the new ranges, compacting them.
2831 for (auto readOp : readOps) {
2832 rewriter.setInsertionPointAfter(readOp);
2833 auto it = mapping.find(readOp.getLo());
2834 assert(it != mapping.end() && "bit op mapping not found");
2835 // Create a new bit selection from the compressed memory. The new op may
2836 // be folded if we are selecting the entire compressed memory.
2837 auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
2838 readOp.getLoc(), readOp.getInput(),
2839 readOp.getHi() - readOp.getLo() + it->second, it->second);
2840 rewriter.replaceAllUsesWith(readOp, newReadValue);
2841 rewriter.eraseOp(readOp);
2842 }
2843
2844 // Rewrite the writes into a concatenation of slices.
2845 for (auto writeOp : writeOps) {
2846 Value source = writeOp.getSrc();
2847 rewriter.setInsertionPoint(writeOp);
2848
2849 Value catOfSlices;
2850 for (auto &[start, end] : ranges) {
2851 Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
2852 source, end, start);
2853 if (catOfSlices) {
2854 catOfSlices = rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(),
2855 slice, catOfSlices);
2856 } else {
2857 catOfSlices = slice;
2858 }
2859 }
2860
2861 // If the original memory held a signed integer, then the compressed
2862 // memory will be signed too. Since the catOfSlices is always unsigned,
2863 // cast the data to a signed integer if needed before connecting back to
2864 // the memory.
2865 if (type.isSigned())
2866 catOfSlices =
2867 rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);
2868
2869 rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
2870 catOfSlices);
2871 }
2872
2873 return success();
2874 }
2875};
2876
// Rewrite single-address memories to a firrtl register. A depth-1 memory is
// just a register with extra pipeline stages; the address is necessarily 0,
// so reads become unconditional and writes are arbitrated in port order.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only untouched, single-entry memories are eligible.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Record the subfield accesses of the listed fields and the connects
      // driving them; fail if a field has any non-connect user, since those
      // uses could not be reconstructed after the rewrite.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        // Read ports never contribute a clock requirement; `continue` skips
        // the shared-clock check below.
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All writing ports must share one known clock value.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = rewriter.create<WireOp>(loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        rewriter.create<RegOp>(loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    rewriter.create<MatchingConnectOp>(loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<MatchingConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        // The effective write enable is en & wmode.
        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        // Chunks are concatenated most-significant-first.
        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    // The concatenation above is unsigned; bitcast back to the data type.
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    rewriter.create<MatchingConnectOp>(memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3074} // namespace
3075
3076void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3077 MLIRContext *context) {
3078 results
3079 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3080 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3081 context);
3082}
3083
3084//===----------------------------------------------------------------------===//
3085// Declarations
3086//===----------------------------------------------------------------------===//
3087
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The register must be driven by a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Preserve dialect attributes (e.g. debug info) across the replacement.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Re-drive the register: with the constant if it is effectively constant,
  // otherwise with the non-reset operand of the mux.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3158
3159LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3160 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3161 succeeded(foldHiddenReset(op, rewriter)))
3162 return success();
3163
3164 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3165 return success();
3166
3167 return failure();
3168}
3169
3170//===----------------------------------------------------------------------===//
3171// Verification Ops.
3172//===----------------------------------------------------------------------===//
3173
3174static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3175 Value enable,
3176 PatternRewriter &rewriter,
3177 bool eraseIfZero) {
3178 // If the verification op is never enabled, delete it.
3179 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3180 if (constant.getValue().isZero()) {
3181 rewriter.eraseOp(op);
3182 return success();
3183 }
3184 }
3185
3186 // If the verification op is never triggered, delete it.
3187 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3188 if (constant.getValue().isZero() == eraseIfZero) {
3189 rewriter.eraseOp(op);
3190 return success();
3191 }
3192 }
3193
3194 return failure();
3195}
3196
3197template <class Op, bool EraseIfZero = false>
3198static LogicalResult canonicalizeImmediateVerifOp(Op op,
3199 PatternRewriter &rewriter) {
3200 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3201 EraseIfZero);
3202}
3203
3204void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3205 MLIRContext *context) {
3206 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3207}
3208
3209void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3210 MLIRContext *context) {
3211 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3212}
3213
3214void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3215 RewritePatternSet &results, MLIRContext *context) {
3216 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3217}
3218
3219void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3220 MLIRContext *context) {
3221 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3222}
3223
3224//===----------------------------------------------------------------------===//
3225// InvalidValueOp
3226//===----------------------------------------------------------------------===//
3227
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  // Cvt is only safe for signed operands (presumably because unsigned cvt
  // changes the width — TODO confirm), and reductions only when the operand
  // has a known positive width.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of the user's result type, then
    // erase both the user and this op.
    auto *modop = *op->user_begin();
    auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
                                               modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3258
3259OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3260 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3261 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3262 return {};
3263}
3264
3265//===----------------------------------------------------------------------===//
3266// ClockGateIntrinsicOp
3267//===----------------------------------------------------------------------===//
3268
3269OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3270 // Forward the clock if one of the enables is always true.
3271 if (isConstantOne(adaptor.getEnable()) ||
3272 isConstantOne(adaptor.getTestEnable()))
3273 return getInput();
3274
3275 // Fold to a constant zero clock if the enables are always false.
3276 if (isConstantZero(adaptor.getEnable()) &&
3277 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3278 return BoolAttr::get(getContext(), false);
3279
3280 // Forward constant zero clocks.
3281 if (isConstantZero(adaptor.getInput()))
3282 return BoolAttr::get(getContext(), false);
3283
3284 return {};
3285}
3286
3287LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3288 PatternRewriter &rewriter) {
3289 // Remove constant false test enable.
3290 if (auto testEnable = op.getTestEnable()) {
3291 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3292 if (constOp.getValue().isZero()) {
3293 rewriter.modifyOpInPlace(op,
3294 [&] { op.getTestEnableMutable().clear(); });
3295 return success();
3296 }
3297 }
3298 }
3299
3300 return failure();
3301}
3302
3303//===----------------------------------------------------------------------===//
3304// Reference Ops.
3305//===----------------------------------------------------------------------===//
3306
3307// refresolve(forceable.ref) -> forceable.data
3308static LogicalResult
3309canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3310 auto forceable = op.getRef().getDefiningOp<Forceable>();
3311 if (!forceable || !forceable.isForceable() ||
3312 op.getRef() != forceable.getDataRef() ||
3313 op.getType() != forceable.getDataType())
3314 return failure();
3315 rewriter.replaceAllUsesWith(op, forceable.getData());
3316 return success();
3317}
3318
3319void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3320 MLIRContext *context) {
3321 results.insert<patterns::RefResolveOfRefSend>(context);
3322 results.insert(canonicalizeRefResolveOfForceable);
3323}
3324
3325OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3326 // RefCast is unnecessary if types match.
3327 if (getInput().getType() == getType())
3328 return getInput();
3329 return {};
3330}
3331
3332static bool isConstantZero(Value operand) {
3333 auto constOp = operand.getDefiningOp<ConstantOp>();
3334 return constOp && constOp.getValue().isZero();
3335}
3336
3337template <typename Op>
3338static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3339 if (isConstantZero(op.getPredicate())) {
3340 rewriter.eraseOp(op);
3341 return success();
3342 }
3343 return failure();
3344}
3345
3346void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3347 MLIRContext *context) {
3348 results.add(eraseIfPredFalse<RefForceOp>);
3349}
3350void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3351 MLIRContext *context) {
3352 results.add(eraseIfPredFalse<RefForceInitialOp>);
3353}
3354void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3355 MLIRContext *context) {
3356 results.add(eraseIfPredFalse<RefReleaseOp>);
3357}
3358void RefReleaseInitialOp::getCanonicalizationPatterns(
3359 RewritePatternSet &results, MLIRContext *context) {
3360 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3361}
3362
3363//===----------------------------------------------------------------------===//
3364// HasBeenResetIntrinsicOp
3365//===----------------------------------------------------------------------===//
3366
3367OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3368 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3369
3370 // Fold to zero if the reset is a constant. In this case the op is either
3371 // permanently in reset or never resets. Both mean that the reset never
3372 // finishes, so this op never returns true.
3373 if (adaptor.getReset())
3374 return getIntZerosAttr(UIntType::get(getContext(), 1));
3375
3376 // Fold to zero if the clock is a constant and the reset is synchronous. In
3377 // that case the reset will never be started.
3378 if (isUInt1(getReset().getType()) && adaptor.getClock())
3379 return getIntZerosAttr(UIntType::get(getContext(), 1));
3380
3381 return {};
3382}
3383
3384//===----------------------------------------------------------------------===//
3385// FPGAProbeIntrinsicOp
3386//===----------------------------------------------------------------------===//
3387
3388static bool isTypeEmpty(FIRRTLType type) {
3390 .Case<FVectorType>(
3391 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3392 .Case<BundleType>([&](auto ty) -> bool {
3393 for (auto elem : ty.getElements())
3394 if (!isTypeEmpty(elem.type))
3395 return false;
3396 return true;
3397 })
3398 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3399 .Default([](auto) -> bool { return false; });
3400}
3401
3402LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3403 PatternRewriter &rewriter) {
3404 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3405 if (!firrtlTy)
3406 return failure();
3407
3408 if (!isTypeEmpty(firrtlTy))
3409 return failure();
3410
3411 rewriter.eraseOp(op);
3412 return success();
3413}
3414
3415//===----------------------------------------------------------------------===//
3416// Layer Block Op
3417//===----------------------------------------------------------------------===//
3418
3419LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3420 PatternRewriter &rewriter) {
3421
3422 // If the layerblock is empty, erase it.
3423 if (op.getBody()->empty()) {
3424 rewriter.eraseOp(op);
3425 return success();
3426 }
3427
3428 return failure();
3429}
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:34
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21