CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
FIRRTLFolds.cpp
Go to the documentation of this file.
1//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implement the folding and canonicalizations for FIRRTL ops.
10//
11//===----------------------------------------------------------------------===//
12
#include "circt/Dialect/FIRRTL/FIRRTLAttributes.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Support/APInt.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
26
27using namespace circt;
28using namespace firrtl;
29
30// Drop writes to old and pass through passthrough to make patterns easier to
31// write.
32static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33 Value passthrough) {
34 SmallPtrSet<Operation *, 8> users;
35 for (auto *user : old.getUsers())
36 users.insert(user);
37 for (Operation *user : users)
38 if (auto connect = dyn_cast<FConnectLike>(user))
39 if (connect.getDest() == old)
40 rewriter.eraseOp(user);
41 return passthrough;
42}
43
44// Move a name hint from a soon to be deleted operation to a new operation.
45// Pass through the new operation to make patterns easier to write. This cannot
46// move a name to a port (block argument), doing so would require rewriting all
47// instance sites as well as the module.
48static Value moveNameHint(OpResult old, Value passthrough) {
49 Operation *op = passthrough.getDefiningOp();
50 // This should handle ports, but it isn't clear we can change those in
51 // canonicalizers.
52 assert(op && "passthrough must be an operation");
53 Operation *oldOp = old.getOwner();
54 auto name = oldOp->getAttrOfType<StringAttr>("name");
55 if (name && !name.getValue().empty())
56 op->setAttr("name", name);
57 return passthrough;
58}
59
60// Declarative canonicalization patterns
61namespace circt {
62namespace firrtl {
63namespace patterns {
64#include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
65} // namespace patterns
66} // namespace firrtl
67} // namespace circt
68
69/// Return true if this operation's operands and results all have a known width.
70/// This only works for integer types.
71static bool hasKnownWidthIntTypes(Operation *op) {
72 auto resultType = type_cast<IntType>(op->getResult(0).getType());
73 if (!resultType.hasWidth())
74 return false;
75 for (Value operand : op->getOperands())
76 if (!type_cast<IntType>(operand.getType()).hasWidth())
77 return false;
78 return true;
79}
80
81/// Return true if this value is 1 bit UInt.
82static bool isUInt1(Type type) {
83 auto t = type_dyn_cast<UIntType>(type);
84 if (!t || !t.hasWidth() || t.getWidth() != 1)
85 return false;
86 return true;
87}
88
89/// Set the name of an op based on the best of two names: The current name, and
90/// the name passed in.
91static void updateName(PatternRewriter &rewriter, Operation *op,
92 StringAttr name) {
93 // Should never rename InstanceOp
94 assert(!isa<InstanceOp>(op));
95 if (!name || name.getValue().empty())
96 return;
97 auto newName = name.getValue(); // old name is interesting
98 auto newOpName = op->getAttrOfType<StringAttr>("name");
99 // new name might not be interesting
100 if (newOpName)
101 newName = chooseName(newOpName.getValue(), name.getValue());
102 // Only update if needed
103 if (!newOpName || newOpName.getValue() != newName)
104 rewriter.modifyOpInPlace(
105 op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
106}
107
108/// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
109/// If a replaced op has a "name" attribute, this function propagates the name
110/// to the new value.
111static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
112 Value newValue) {
113 if (auto *newOp = newValue.getDefiningOp()) {
114 auto name = op->getAttrOfType<StringAttr>("name");
115 updateName(rewriter, newOp, name);
116 }
117 rewriter.replaceOp(op, newValue);
118}
119
120/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
121/// attribute. If a replaced op has a "name" attribute, this function propagates
122/// the name to the new value.
template <typename OpTy, typename... Args>
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
                                          Operation *op, Args &&...args) {
  // Read the name hint first: `replaceOpWithNewOp` erases `op`, after which
  // its attributes are no longer accessible.
  auto name = op->getAttrOfType<StringAttr>("name");
  auto newOp =
      rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
  // Propagate the better of the old and new names onto the replacement.
  updateName(rewriter, newOp, name);
  return newOp;
}
132
133/// Return true if the name is droppable. Note that this is different from
134/// `isUselessName` because non-useless names may be also droppable.
136 if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137 return namableOp.hasDroppableName();
138 return false;
139}
140
141/// Implicitly replace the operand to a constant folding operation with a const
142/// 0 in case the operand is non-constant but has a bit width 0, or if the
143/// operand is an invalid value.
144///
145/// This makes constant folding significantly easier, as we can simply pass the
146/// operands to an operation through this function to appropriately replace any
147/// zero-width dynamic values or invalid values with a constant of value 0.
148static std::optional<APSInt>
149getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
150 assert(type_cast<IntType>(operand.getType()) &&
151 "getExtendedConstant is limited to integer types");
152
153 // We never support constant folding to unknown width values.
154 if (destWidth < 0)
155 return {};
156
157 // Extension signedness follows the operand sign.
158 if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
159 return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
160
161 // If the operand is zero bits, then we can return a zero of the result
162 // type.
163 if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164 return APSInt(destWidth,
165 type_cast<IntType>(operand.getType()).isUnsigned());
166 return {};
167}
168
169/// Determine the value of a constant operand for the sake of constant folding.
170static std::optional<APSInt> getConstant(Attribute operand) {
171 if (!operand)
172 return {};
173 if (auto attr = dyn_cast<BoolAttr>(operand))
174 return APSInt(APInt(1, attr.getValue()));
175 if (auto attr = dyn_cast<IntegerAttr>(operand))
176 return attr.getAPSInt();
177 return {};
178}
179
180/// Determine whether a constant operand is a zero value for the sake of
181/// constant folding. This considers `invalidvalue` to be zero.
182static bool isConstantZero(Attribute operand) {
183 if (auto cst = getConstant(operand))
184 return cst->isZero();
185 return false;
186}
187
188/// Determine whether a constant operand is a one value for the sake of constant
189/// folding.
190static bool isConstantOne(Attribute operand) {
191 if (auto cst = getConstant(operand))
192 return cst->isOne();
193 return false;
194}
195
196/// This is the policy for folding, which depends on the sort of operator we're
197/// processing.
enum class BinOpKind {
  Normal,        ///< Operands are extended to the result width.
  Compare,       ///< Operands are extended to the widest operand; result is i1.
  DivideOrShift, ///< Compute at the widest width, then truncate to the result.
};
203
204/// Applies the constant folding function `calculate` to the given operands.
205///
206/// Sign or zero extends the operands appropriately to the bitwidth of the
207/// result type if \p useDstWidth is true, else to the larger of the two operand
208/// bit widths and depending on whether the operation is to be performed on
209/// signed or unsigned operands.
210static Attribute constFoldFIRRTLBinaryOp(
211 Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
212 const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
213 assert(operands.size() == 2 && "binary op takes two operands");
214
215 // We cannot fold something to an unknown width.
216 auto resultType = type_cast<IntType>(op->getResult(0).getType());
217 if (resultType.getWidthOrSentinel() < 0)
218 return {};
219
220 // Any binary op returning i0 is 0.
221 if (resultType.getWidthOrSentinel() == 0)
222 return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
223
224 // Determine the operand widths. This is either dictated by the operand type,
225 // or if that type is an unsized integer, by the actual bits necessary to
226 // represent the constant value.
227 auto lhsWidth =
228 type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
229 auto rhsWidth =
230 type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231 if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232 lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233 if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234 rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
235
236 // Compares extend the operands to the widest of the operand types, not to the
237 // result type.
238 int32_t operandWidth;
239 switch (opKind) {
241 operandWidth = resultType.getWidthOrSentinel();
242 break;
244 // Compares compute with the widest operand, not at the destination type
245 // (which is always i1).
246 operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
247 break;
249 operandWidth =
250 std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
251 break;
252 }
253
254 auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
255 if (!lhs)
256 return {};
257 auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
258 if (!rhs)
259 return {};
260
261 APInt resultValue = calculate(*lhs, *rhs);
262
263 // If the result type is smaller than the computation then we need to
264 // narrow the constant after the calculation.
265 if (opKind == BinOpKind::DivideOrShift)
266 resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
267
268 assert((unsigned)resultType.getWidthOrSentinel() ==
269 resultValue.getBitWidth());
270 return getIntAttr(resultType, resultValue);
271}
272
273/// Applies the canonicalization function `canonicalize` to the given operation.
274///
275/// Determines which (if any) of the operation's operands are constants, and
276/// provides them as arguments to the callback function. Any `invalidvalue` in
277/// the input is mapped to a constant zero. The value returned from the callback
278/// is used as the replacement for `op`, and an additional pad operation is
279/// inserted if necessary. Does nothing if the result of `op` is of unknown
280/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants.  A null entry in
  // `constOperands` marks a non-constant operand.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant.  A null OpFoldResult means the callback declined to rewrite.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  if (auto cst = dyn_cast<Attribute>(result))
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = cast<Value>(result);

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);

  // After padding/casting the replacement must match the op's type exactly.
  assert(type == resultValue.getType() && "canonicalization changed type");
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
336
337/// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
338/// value if `bitWidth` is 0.
339static APInt getMaxUnsignedValue(unsigned bitWidth) {
340 return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
341}
342
343/// Get the smallest signed value of a given bit width. Returns a 1-bit zero
344/// value if `bitWidth` is 0.
345static APInt getMinSignedValue(unsigned bitWidth) {
346 return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
347}
348
349/// Get the largest signed value of a given bit width. Returns a 1-bit zero
350/// value if `bitWidth` is 0.
351static APInt getMaxSignedValue(unsigned bitWidth) {
352 return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
353}
354
355//===----------------------------------------------------------------------===//
356// Fold Hooks
357//===----------------------------------------------------------------------===//
358
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant folds to its own value attribute.
  return getValueAttr();
}
363
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant folds to its own value attribute.
  return getValueAttr();
}
368
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // An aggregate constant folds to its per-field values attribute.
  return getFieldsAttr();
}
373
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant folds to its own value attribute.
  return getValueAttr();
}
378
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant folds to its own value attribute.
  return getValueAttr();
}
383
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant folds to its own value attribute.
  return getValueAttr();
}
388
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  // A constant folds to its own value attribute.
  return getValueAttr();
}
393
394//===----------------------------------------------------------------------===//
395// Binary Operators
396//===----------------------------------------------------------------------===//
397
398OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
400 *this, adaptor.getOperands(), BinOpKind::Normal,
401 [=](const APSInt &a, const APSInt &b) { return a + b; });
402}
403
void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (DRR) patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
                 patterns::AddOfSelf, patterns::AddOfPad>(context);
}
409
410OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
412 *this, adaptor.getOperands(), BinOpKind::Normal,
413 [=](const APSInt &a, const APSInt &b) { return a - b; });
414}
415
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (DRR) patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
                 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
                 patterns::SubOfPadL, patterns::SubOfPadR>(context);
}
422
423OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
424 // mul(x, 0) -> 0
425 //
426 // This is legal because it aligns with the Scala FIRRTL Compiler
427 // interpretation of lowering invalid to constant zero before constant
428 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
429 // multiplication this way and will emit "x * 0".
430 if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
431 return getIntZerosAttr(getType());
432
434 *this, adaptor.getOperands(), BinOpKind::Normal,
435 [=](const APSInt &a, const APSInt &b) { return a * b; });
436}
437
438OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
439 /// div(x, x) -> 1
440 ///
441 /// Division by zero is undefined in the FIRRTL specification. This fold
442 /// exploits that fact to optimize self division to one. Note: this should
443 /// supersede any division with invalid or zero. Division of invalid by
444 /// invalid should be one.
445 if (getLhs() == getRhs()) {
446 auto width = getType().base().getWidthOrSentinel();
447 if (width == -1)
448 width = 2;
449 // Only fold if we have at least 1 bit of width to represent the `1` value.
450 if (width != 0)
451 return getIntAttr(getType(), APInt(width, 1));
452 }
453
454 // div(0, x) -> 0
455 //
456 // This is legal because it aligns with the Scala FIRRTL Compiler
457 // interpretation of lowering invalid to constant zero before constant
458 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
459 // division this way and will emit "0 / x".
460 if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
461 return getIntZerosAttr(getType());
462
463 /// div(x, 1) -> x : (uint, uint) -> uint
464 ///
465 /// UInt division by one returns the numerator. SInt division can't
466 /// be folded here because it increases the return type bitwidth by
467 /// one and requires sign extension (a new op).
468 if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469 if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
470 return getLhs();
471
473 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
474 [=](const APSInt &a, const APSInt &b) -> APInt {
475 if (!!b)
476 return a / b;
477 return APInt(a.getBitWidth(), 0);
478 });
479}
480
481OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
482 // rem(x, x) -> 0
483 //
484 // Division by zero is undefined in the FIRRTL specification. This fold
485 // exploits that fact to optimize self division remainder to zero. Note:
486 // this should supersede any division with invalid or zero. Remainder of
487 // division of invalid by invalid should be zero.
488 if (getLhs() == getRhs())
489 return getIntZerosAttr(getType());
490
491 // rem(0, x) -> 0
492 //
493 // This is legal because it aligns with the Scala FIRRTL Compiler
494 // interpretation of lowering invalid to constant zero before constant
495 // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
496 // division this way and will emit "0 % x".
497 if (isConstantZero(adaptor.getLhs()))
498 return getIntZerosAttr(getType());
499
501 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
502 [=](const APSInt &a, const APSInt &b) -> APInt {
503 if (!!b)
504 return a % b;
505 return APInt(a.getBitWidth(), 0);
506 });
507}
508
509OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
511 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
512 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
513}
514
515OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
517 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
518 [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
519}
520
521OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
523 *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524 [=](const APSInt &a, const APSInt &b) -> APInt {
525 return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
526 : a.ashr(b);
527 });
528}
529
530// TODO: Move to DRR.
531OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
532 if (auto rhsCst = getConstant(adaptor.getRhs())) {
533 /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
534 if (rhsCst->isZero())
535 return getIntZerosAttr(getType());
536
537 /// and(x, -1) -> x
538 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539 getRhs().getType() == getType())
540 return getLhs();
541 }
542
543 if (auto lhsCst = getConstant(adaptor.getLhs())) {
544 /// and(0, x) -> 0, 0 is largest or is implicit zero extended
545 if (lhsCst->isZero())
546 return getIntZerosAttr(getType());
547
548 /// and(-1, x) -> x
549 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550 getRhs().getType() == getType())
551 return getRhs();
552 }
553
554 /// and(x, x) -> x
555 if (getLhs() == getRhs() && getRhs().getType() == getType())
556 return getRhs();
557
559 *this, adaptor.getOperands(), BinOpKind::Normal,
560 [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
561}
562
void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (DRR) patterns generated from FIRRTLCanonicalization.td.
  results
      .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
              patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
              patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
}
570
571OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
572 if (auto rhsCst = getConstant(adaptor.getRhs())) {
573 /// or(x, 0) -> x
574 if (rhsCst->isZero() && getLhs().getType() == getType())
575 return getLhs();
576
577 /// or(x, -1) -> -1
578 if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579 getLhs().getType() == getType())
580 return getRhs();
581 }
582
583 if (auto lhsCst = getConstant(adaptor.getLhs())) {
584 /// or(0, x) -> x
585 if (lhsCst->isZero() && getRhs().getType() == getType())
586 return getRhs();
587
588 /// or(-1, x) -> -1
589 if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590 getRhs().getType() == getType())
591 return getLhs();
592 }
593
594 /// or(x, x) -> x
595 if (getLhs() == getRhs() && getRhs().getType() == getType())
596 return getRhs();
597
599 *this, adaptor.getOperands(), BinOpKind::Normal,
600 [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
601}
602
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Declarative (DRR) patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
                 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
                 patterns::OrOrr>(context);
}
609
610OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
611 /// xor(x, 0) -> x
612 if (auto rhsCst = getConstant(adaptor.getRhs()))
613 if (rhsCst->isZero() &&
614 firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
615 return getLhs();
616
617 /// xor(x, 0) -> x
618 if (auto lhsCst = getConstant(adaptor.getLhs()))
619 if (lhsCst->isZero() &&
620 firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
621 return getRhs();
622
623 /// xor(x, x) -> 0
624 if (getLhs() == getRhs())
625 return getIntAttr(
626 getType(),
627 APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
628
630 *this, adaptor.getOperands(), BinOpKind::Normal,
631 [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
632}
633
void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Declarative (DRR) patterns generated from FIRRTLCanonicalization.td.
  results.insert<patterns::extendXor, patterns::moveConstXor,
                 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
      context);
}
640
void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Normalize constant-on-the-left comparisons to constant-on-the-right.
  results.insert<patterns::LEQWithConstLHS>(context);
}
645
646OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647 bool isUnsigned = getLhs().getType().base().isUnsigned();
648
649 // leq(x, x) -> 1
650 if (getLhs() == getRhs())
651 return getIntAttr(getType(), APInt(1, 1));
652
653 // Comparison against constant outside type bounds.
654 if (auto width = getLhs().getType().base().getWidth()) {
655 if (auto rhsCst = getConstant(adaptor.getRhs())) {
656 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657 commonWidth = std::max(commonWidth, 1);
658
659 // leq(x, const) -> 0 where const < minValue of the unsigned type of x
660 // This can never occur since const is unsigned and cannot be less than 0.
661
662 // leq(x, const) -> 0 where const < minValue of the signed type of x
663 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
664 .slt(getMinSignedValue(*width).sext(commonWidth)))
665 return getIntAttr(getType(), APInt(1, 0));
666
667 // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
668 if (isUnsigned && rhsCst->zext(commonWidth)
669 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
670 return getIntAttr(getType(), APInt(1, 1));
671
672 // leq(x, const) -> 1 where const >= maxValue of the signed type of x
673 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
674 .sge(getMaxSignedValue(*width).sext(commonWidth)))
675 return getIntAttr(getType(), APInt(1, 1));
676 }
677 }
678
680 *this, adaptor.getOperands(), BinOpKind::Compare,
681 [=](const APSInt &a, const APSInt &b) -> APInt {
682 return APInt(1, a <= b);
683 });
684}
685
void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Normalize constant-on-the-left comparisons to constant-on-the-right.
  results.insert<patterns::LTWithConstLHS>(context);
}
690
691OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692 IntType lhsType = getLhs().getType();
693 bool isUnsigned = lhsType.isUnsigned();
694
695 // lt(x, x) -> 0
696 if (getLhs() == getRhs())
697 return getIntAttr(getType(), APInt(1, 0));
698
699 // lt(x, 0) -> 0 when x is unsigned
700 if (auto rhsCst = getConstant(adaptor.getRhs())) {
701 if (rhsCst->isZero() && lhsType.isUnsigned())
702 return getIntAttr(getType(), APInt(1, 0));
703 }
704
705 // Comparison against constant outside type bounds.
706 if (auto width = lhsType.getWidth()) {
707 if (auto rhsCst = getConstant(adaptor.getRhs())) {
708 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709 commonWidth = std::max(commonWidth, 1);
710
711 // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
712 // Handled explicitly above.
713
714 // lt(x, const) -> 0 where const <= minValue of the signed type of x
715 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
716 .sle(getMinSignedValue(*width).sext(commonWidth)))
717 return getIntAttr(getType(), APInt(1, 0));
718
719 // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
720 if (isUnsigned && rhsCst->zext(commonWidth)
721 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
722 return getIntAttr(getType(), APInt(1, 1));
723
724 // lt(x, const) -> 1 where const > maxValue of the signed type of x
725 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
726 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
727 return getIntAttr(getType(), APInt(1, 1));
728 }
729 }
730
732 *this, adaptor.getOperands(), BinOpKind::Compare,
733 [=](const APSInt &a, const APSInt &b) -> APInt {
734 return APInt(1, a < b);
735 });
736}
737
void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Normalize constant-on-the-left comparisons to constant-on-the-right.
  results.insert<patterns::GEQWithConstLHS>(context);
}
742
743OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744 IntType lhsType = getLhs().getType();
745 bool isUnsigned = lhsType.isUnsigned();
746
747 // geq(x, x) -> 1
748 if (getLhs() == getRhs())
749 return getIntAttr(getType(), APInt(1, 1));
750
751 // geq(x, 0) -> 1 when x is unsigned
752 if (auto rhsCst = getConstant(adaptor.getRhs())) {
753 if (rhsCst->isZero() && isUnsigned)
754 return getIntAttr(getType(), APInt(1, 1));
755 }
756
757 // Comparison against constant outside type bounds.
758 if (auto width = lhsType.getWidth()) {
759 if (auto rhsCst = getConstant(adaptor.getRhs())) {
760 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761 commonWidth = std::max(commonWidth, 1);
762
763 // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
764 if (isUnsigned && rhsCst->zext(commonWidth)
765 .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
766 return getIntAttr(getType(), APInt(1, 0));
767
768 // geq(x, const) -> 0 where const > maxValue of the signed type of x
769 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
770 .sgt(getMaxSignedValue(*width).sext(commonWidth)))
771 return getIntAttr(getType(), APInt(1, 0));
772
773 // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
774 // Handled explicitly above.
775
776 // geq(x, const) -> 1 where const <= minValue of the signed type of x
777 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
778 .sle(getMinSignedValue(*width).sext(commonWidth)))
779 return getIntAttr(getType(), APInt(1, 1));
780 }
781 }
782
784 *this, adaptor.getOperands(), BinOpKind::Compare,
785 [=](const APSInt &a, const APSInt &b) -> APInt {
786 return APInt(1, a >= b);
787 });
788}
789
void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Normalize constant-on-the-left comparisons to constant-on-the-right.
  results.insert<patterns::GTWithConstLHS>(context);
}
794
795OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796 IntType lhsType = getLhs().getType();
797 bool isUnsigned = lhsType.isUnsigned();
798
799 // gt(x, x) -> 0
800 if (getLhs() == getRhs())
801 return getIntAttr(getType(), APInt(1, 0));
802
803 // Comparison against constant outside type bounds.
804 if (auto width = lhsType.getWidth()) {
805 if (auto rhsCst = getConstant(adaptor.getRhs())) {
806 auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807 commonWidth = std::max(commonWidth, 1);
808
809 // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
810 if (isUnsigned && rhsCst->zext(commonWidth)
811 .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
812 return getIntAttr(getType(), APInt(1, 0));
813
814 // gt(x, const) -> 0 where const >= maxValue of the signed type of x
815 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
816 .sge(getMaxSignedValue(*width).sext(commonWidth)))
817 return getIntAttr(getType(), APInt(1, 0));
818
819 // gt(x, const) -> 1 where const < minValue of the unsigned type of x
820 // This can never occur since const is unsigned and cannot be less than 0.
821
822 // gt(x, const) -> 1 where const < minValue of the signed type of x
823 if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
824 .slt(getMinSignedValue(*width).sext(commonWidth)))
825 return getIntAttr(getType(), APInt(1, 1));
826 }
827 }
828
830 *this, adaptor.getOperands(), BinOpKind::Compare,
831 [=](const APSInt &a, const APSInt &b) -> APInt {
832 return APInt(1, a > b);
833 });
834}
835
836OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
837 // eq(x, x) -> 1
838 if (getLhs() == getRhs())
839 return getIntAttr(getType(), APInt(1, 1));
840
841 if (auto rhsCst = getConstant(adaptor.getRhs())) {
842 /// eq(x, 1) -> x when x is 1 bit.
843 /// TODO: Support SInt<1> on the LHS etc.
844 if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845 getRhs().getType() == getType())
846 return getLhs();
847 }
848
850 *this, adaptor.getOperands(), BinOpKind::Compare,
851 [=](const APSInt &a, const APSInt &b) -> APInt {
852 return APInt(1, a == b);
853 });
854}
855
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  // Rewrite comparisons against all-zero / all-one constants into cheaper
  // single-operand reductions (not/orr/andr).
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
            return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        // No constant RHS (or no applicable pattern): decline to rewrite.
        return {};
      });
}
885
886OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
887 // neq(x, x) -> 0
888 if (getLhs() == getRhs())
889 return getIntAttr(getType(), APInt(1, 0));
890
891 if (auto rhsCst = getConstant(adaptor.getRhs())) {
892 /// neq(x, 0) -> x when x is 1 bit.
893 /// TODO: Support SInt<1> on the LHS etc.
894 if (rhsCst->isZero() && getLhs().getType() == getType() &&
895 getRhs().getType() == getType())
896 return getLhs();
897 }
898
900 *this, adaptor.getOperands(), BinOpKind::Compare,
901 [=](const APSInt &a, const APSInt &b) -> APInt {
902 return APInt(1, a != b);
903 });
904}
905
/// Canonicalize `neq` against a constant RHS into cheaper unary forms.
/// Mirrors EQPrimOp::canonicalize with the polarity inverted.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          // Width of the LHS; -1 if unknown (sentinel).
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x))) when x is >1 bit
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
            return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
          }
        }

        return {};
      });
}
936
/// No folds implemented for arbitrary-precision integer add yet.
OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6696.
  return {};
}
942
/// No folds implemented for arbitrary-precision integer multiply yet.
OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6724.
  return {};
}
948
/// No folds implemented for arbitrary-precision integer shift-right yet.
OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6725.
  return {};
}
954
955OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
956 if (auto rhsCst = getConstant(adaptor.getRhs())) {
957 // Constant folding
958 if (auto lhsCst = getConstant(adaptor.getLhs()))
959
960 return IntegerAttr::get(
961 IntegerType::get(getContext(), lhsCst->getBitWidth()),
962 lhsCst->shl(*rhsCst));
963
964 // integer.shl(x, 0) -> x
965 if (rhsCst->isZero())
966 return getLhs();
967 }
968
969 return {};
970}
971
972//===----------------------------------------------------------------------===//
973// Unary Operators
974//===----------------------------------------------------------------------===//
975
976OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
977 auto base = getInput().getType();
978 auto w = getBitWidth(base);
979 if (w)
980 return getIntAttr(getType(), APInt(32, *w));
981 return {};
982}
983
984OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
985 // No constant can be 'x' by definition.
986 if (auto cst = getConstant(adaptor.getArg()))
987 return getIntAttr(getType(), APInt(1, 0));
988 return {};
989}
990
991OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
992 // No effect.
993 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
994 return getInput();
995
996 // Be careful to only fold the cast into the constant if the size is known.
997 // Otherwise width inference may produce differently-sized constants if the
998 // sign changes.
999 if (getType().base().hasWidth())
1000 if (auto cst = getConstant(adaptor.getInput()))
1001 return getIntAttr(getType(), *cst);
1002
1003 return {};
1004}
1005
1006void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1007 MLIRContext *context) {
1008 results.insert<patterns::StoUtoS>(context);
1009}
1010
1011OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1012 // No effect.
1013 if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1014 return getInput();
1015
1016 // Be careful to only fold the cast into the constant if the size is known.
1017 // Otherwise width inference may produce differently-sized constants if the
1018 // sign changes.
1019 if (getType().base().hasWidth())
1020 if (auto cst = getConstant(adaptor.getInput()))
1021 return getIntAttr(getType(), *cst);
1022
1023 return {};
1024}
1025
1026void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027 MLIRContext *context) {
1028 results.insert<patterns::UtoStoU>(context);
1029}
1030
1031OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1032 // No effect.
1033 if (getInput().getType() == getType())
1034 return getInput();
1035
1036 // Constant fold.
1037 if (auto cst = getConstant(adaptor.getInput()))
1038 return BoolAttr::get(getContext(), cst->getBoolValue());
1039
1040 return {};
1041}
1042
1043OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1044 // No effect.
1045 if (getInput().getType() == getType())
1046 return getInput();
1047
1048 // Constant fold.
1049 if (auto cst = getConstant(adaptor.getInput()))
1050 return BoolAttr::get(getContext(), cst->getBoolValue());
1051
1052 return {};
1053}
1054
/// Fold cvt (unsigned-to-signed conversion) of a constant.
OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // Signed to signed is a noop, unsigned operands prepend a zero bit.
  // getExtendedConstant handles the sign-appropriate extension to the
  // result width.
  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
                                     getType().base().getWidthOrSentinel()))
    return getIntAttr(getType(), *cst);

  return {};
}
1066
1067void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068 MLIRContext *context) {
1069 results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1070}
1071
1072OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1073 if (!hasKnownWidthIntTypes(*this))
1074 return {};
1075
1076 // FIRRTL negate always adds a bit.
1077 // -x ---> 0-sext(x) or 0-zext(x)
1078 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1079 getType().base().getWidthOrSentinel()))
1080 return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1081
1082 return {};
1083}
1084
1085OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1086 if (!hasKnownWidthIntTypes(*this))
1087 return {};
1088
1089 if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1090 getType().base().getWidthOrSentinel()))
1091 return getIntAttr(getType(), ~*cst);
1092
1093 return {};
1094}
1095
1096void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097 MLIRContext *context) {
1098 results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1099 patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1100 patterns::NotGt>(context);
1101}
1102
/// Fold the andr (AND-reduction) operator.
OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // andr of a zero-width value is vacuously true (empty conjunction).
  if (getInput().getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(getType(), APInt(1, 1));

  // x == -1
  if (auto cst = getConstant(adaptor.getInput()))
    return getIntAttr(getType(), APInt(1, cst->isAllOnes()));

  // one bit is identity. Only applies to UInt since we can't make a cast
  // here.
  if (isUInt1(getInput().getType()))
    return getInput();

  return {};
}
1121
1122void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123 MLIRContext *context) {
1124 results
1125 .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1126 patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1127 patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1128 patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1129}
1130
/// Fold the orr (OR-reduction) operator.
OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // orr of a zero-width value is vacuously false (empty disjunction).
  if (getInput().getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(getType(), APInt(1, 0));

  // x != 0
  if (auto cst = getConstant(adaptor.getInput()))
    return getIntAttr(getType(), APInt(1, !cst->isZero()));

  // one bit is identity. Only applies to UInt since we can't make a cast
  // here.
  if (isUInt1(getInput().getType()))
    return getInput();

  return {};
}
1149
1150void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151 MLIRContext *context) {
1152 results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1153 patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1154 patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1155}
1156
/// Fold the xorr (XOR-reduction) operator.
OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
  if (!hasKnownWidthIntTypes(*this))
    return {};

  // xorr of a zero-width value is 0 (empty parity).
  if (getInput().getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(getType(), APInt(1, 0));

  // popcount(x) & 1
  if (auto cst = getConstant(adaptor.getInput()))
    return getIntAttr(getType(), APInt(1, cst->popcount() & 1));

  // one bit is identity. Only applies to UInt since we can't make a cast here.
  if (isUInt1(getInput().getType()))
    return getInput();

  return {};
}
1174
1175void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1176 MLIRContext *context) {
1177 results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1178 patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1179 patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
1180 context);
1181}
1182
1183//===----------------------------------------------------------------------===//
1184// Other Operators
1185//===----------------------------------------------------------------------===//
1186
1187OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1188 // cat(x, 0-width) -> x
1189 // cat(0-width, x) -> x
1190 // Limit to unsigned (result type), as cannot insert cast here.
1191 IntType lhsType = getLhs().getType();
1192 IntType rhsType = getRhs().getType();
1193 if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1194 return getRhs();
1195 if (rhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1196 return getLhs();
1197
1198 if (!hasKnownWidthIntTypes(*this))
1199 return {};
1200
1201 // Constant fold cat.
1202 if (auto lhs = getConstant(adaptor.getLhs()))
1203 if (auto rhs = getConstant(adaptor.getRhs()))
1204 return getIntAttr(getType(), lhs->concat(*rhs));
1205
1206 return {};
1207}
1208
1209void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1210 MLIRContext *context) {
1211 results.insert<patterns::DShlOfConstant>(context);
1212}
1213
1214void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215 MLIRContext *context) {
1216 results.insert<patterns::DShrOfConstant>(context);
1217}
1218
1219void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1220 MLIRContext *context) {
1221 results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1222 patterns::CatCast>(context);
1223}
1224
1225OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1226 auto op = (*this);
1227 // BitCast is redundant if input and result types are same.
1228 if (op.getType() == op.getInput().getType())
1229 return op.getInput();
1230
1231 // Two consecutive BitCasts are redundant if first bitcast type is same as the
1232 // final result type.
1233 if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1234 if (op.getType() == in.getInput().getType())
1235 return in.getInput();
1236
1237 return {};
1238}
1239
1240OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1241 IntType inputType = getInput().getType();
1242 IntType resultType = getType();
1243 // If we are extracting the entire input, then return it.
1244 if (inputType == getType() && resultType.hasWidth())
1245 return getInput();
1246
1247 // Constant fold.
1248 if (hasKnownWidthIntTypes(*this))
1249 if (auto cst = getConstant(adaptor.getInput()))
1250 return getIntAttr(resultType,
1251 cst->extractBits(getHi() - getLo() + 1, getLo()));
1252
1253 return {};
1254}
1255
1256void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257 MLIRContext *context) {
1258 results
1259 .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1260 patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1261}
1262
1263/// Replace the specified operation with a 'bits' op from the specified hi/lo
1264/// bits. Insert a cast to handle the case where the original operation
1265/// returned a signed integer.
/// Replace the specified operation with a 'bits' op from the specified hi/lo
/// bits. Insert a cast to handle the case where the original operation
/// returned a signed integer.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only extract when the widths differ; otherwise 'value' already covers
  // the requested range.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);

  // BitsPrimOp produces an unsigned result, so reconcile the signedness of
  // the replacement with the original result type via a cast.
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1280
/// Shared fold logic for MuxPrimOp and Mux2CellIntrinsicOp.
/// Folds only fire when no new op or padding would be required, since a
/// fold() cannot create operations.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  // Type checks guard against returning an operand narrower than the result.
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1328
/// Delegate mux folding to the shared helper.
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1332
/// Delegate 2-input mux-cell folding to the shared helper.
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1336
// 4-input mux cells have no folds implemented.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1338
1339namespace {
1340
1341// If the mux has a known output width, pad the operands up to this width.
1342// Most folds on mux require that folded operands are of the same width as
1343// the mux itself.
// If the mux has a known output width, pad the operands up to this width.
// Most folds on mux require that folded operands are of the same width as
// the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    // Without a known result width there is nothing to pad to.
    if (width < 0)
      return failure();

    // Pad one data operand up to the mux result width; returns the operand
    // unchanged when its width is unknown or already matches.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return rewriter
          .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Fail (no change) if neither operand needed padding, so the driver
    // doesn't loop on this pattern.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1378
1379// Find muxes which have conditions dominated by other muxes with the same
1380// condition.
// Find muxes which have conditions dominated by other muxes with the same
// condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when walking nested mux trees.
  static const int depthLimit = 5;

  // Rewrite `mux` to use the given high/low operands.  When `updateInPlace`
  // is set the operands are mutated directly and an empty Value is returned;
  // otherwise a fresh mux is created after the original and returned.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return rewriter
        .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
                           ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector and cond is known true: this mux resolves to its high
    // operand.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // Only mutate in place if no other user observes this intermediate mux.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector and cond is known false: this mux resolves to its low
    // operand.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Inside the high branch the selector is known true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Inside the low branch the selector is known false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1471} // namespace
1472
1473void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1474 MLIRContext *context) {
1475 results
1476 .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1477 patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1478 patterns::MuxSameTrue, patterns::MuxSameFalse,
1479 patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1480 context);
1481}
1482
1483void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1484 RewritePatternSet &results, MLIRContext *context) {
1485 results.add<patterns::Mux2PadSel>(context);
1486}
1487
1488void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1489 RewritePatternSet &results, MLIRContext *context) {
1490 results.add<patterns::Mux4PadSel>(context);
1491}
1492
/// Fold pad (width extension).
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold: sign- or zero-extend the constant to the result width.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // The getBitWidth() guard avoids sign-extending a zero-width constant;
    // zero-width values fall through to zext below.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1519
/// Fold static shift-left.  The FIRRTL shl result is wider than the input by
/// the shift amount, so constant folding extends before shifting.
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      auto resultWidth = inputWidth + shiftAmount;
      // Clamp so APInt::shl never shifts by more than the result width.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
1540
/// Fold static shift-right.  The result keeps at least one bit; a signed
/// input shifted past its width yields its sign bit.
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Unknown input width: nothing further can be folded.
  if (inputWidth == -1)
    return {};
  // Zero-width input shifts to a zero-valued result.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // Arithmetic shift for signed (clamped to keep the sign bit), logical
    // shift for unsigned.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // Result is at least one bit wide, even when fully shifted out.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1575
/// Canonicalize shr into a bits extraction when the input width is known.
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().base().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  // shr(x, n) == bits(x, inputWidth-1, n) (with a sign cast if needed).
  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
1597
/// Canonicalize head (keep the top `amount` bits) into a bits extraction.
LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  // head(x, n) == bits(x, inputWidth-1, inputWidth-n).
  unsigned keepAmount = op.getAmount();
  // NOTE(review): when keepAmount == 0 this returns success() without
  // changing the IR — presumably amount == 0 cannot reach here; verify
  // against the op's verifier.
  if (keepAmount)
    replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
                    rewriter);
  return success();
}
1611
/// Constant fold head: keep the top `amount` bits of the constant.
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput())) {
      // Drop the low (inputWidth - amount) bits, then truncate to amount.
      int shiftAmount =
          getInput().getType().base().getWidthOrSentinel() - getAmount();
      return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
    }

  return {};
}
1622
1623OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1624 if (hasKnownWidthIntTypes(*this))
1625 if (auto cst = getConstant(adaptor.getInput()))
1626 return getIntAttr(getType(),
1627 cst->trunc(getType().base().getWidthOrSentinel()));
1628 return {};
1629}
1630
/// Canonicalize tail (drop the top `amount` bits) into a bits extraction.
LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  // tail(x, n) == bits(x, inputWidth-n-1, 0); skip when every bit is
  // dropped (zero-width result).
  unsigned dropAmount = op.getAmount();
  if (dropAmount != unsigned(inputWidth))
    replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
                    rewriter);
  return success();
}
1644
1645void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1646 MLIRContext *context) {
1647 results.add<patterns::SubaccessOfConstant>(context);
1648}
1649
/// Fold a multibit mux with a trivial input list or a constant index.
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it.
  // (Operand 0 is the index; operand 1 is the sole data input.)
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    // Inputs are stored highest-index first, hence the reversed subscript.
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
1663
/// Canonicalize multibit muxes: collapse identical inputs, trim unreachable
/// inputs, recognize vector indexing, and lower the 2-input case to a mux.
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements.
  // (An index of width w can only address 2^w inputs; since inputs are
  // stored highest-index first, the unreachable ones are at the front.)
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
1715
1716//===----------------------------------------------------------------------===//
1717// Declarations
1718//===----------------------------------------------------------------------===//
1719
1720/// Scan all the uses of the specified value, checking to see if there is
1721/// exactly one connect that has the value as its destination. This returns the
1722/// operation if found and if all the other users are "reads" from the value.
1723/// Returns null if there are no connects, or multiple connects to the value, or
1724/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
1725///
1726/// Note that this will simply return the connect, which is located *anywhere*
1727/// after the definition of the value. Users of this function are likely
1728/// interested in the source side of the returned connect, the definition of
1729/// which does likely not dominate the original value.
/// Scan all the uses of the specified value, checking to see if there is
/// exactly one connect that has the value as its destination. This returns the
/// operation if found and if all the other users are "reads" from the value.
/// Returns null if there are no connects, or multiple connects to the value, or
/// if the value is involved in an `AttachOp`, or if the connect isn't matching.
///
/// Note that this will simply return the connect, which is located *anywhere*
/// after the definition of the value. Users of this function are likely
/// interested in the source side of the returned connect, the definition of
/// which does likely not dominate the original value.
MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
  MatchingConnectOp connect;
  for (Operation *user : value.getUsers()) {
    // If we see an attach or aggregate subelements, just conservatively fail.
    if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
      return {};

    if (auto aConnect = dyn_cast<FConnectLike>(user))
      if (aConnect.getDest() == value) {
        auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
        // If this is not a matching connect, a second matching connect or in a
        // different block, fail.
        if (!matchingConnect || (connect && connect != matchingConnect) ||
            matchingConnect->getBlock() != value.getParentBlock())
          return {};
        else
          connect = matchingConnect;
      }
  }
  return connect;
}
1751
1752// Forward simple values through wire's and reg's.
// Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Don't erase declarations that are observable: don't-touch, annotated,
  // named (non-droppable), or forceable.
  if (hasDontTouch(connectedDecl) || !AnnotationSet(connectedDecl).empty() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
1810
1811void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1812 MLIRContext *context) {
1813 results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1814 context);
1815}
1816
1817LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
1818 PatternRewriter &rewriter) {
1819 // TODO: Canonicalize towards explicit extensions and flips here.
1820
1821 // If there is a simple value connected to a foldable decl like a wire or reg,
1822 // see if we can eliminate the decl.
1823 if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
1824 return success();
1825 return failure();
1826}
1827
1828//===----------------------------------------------------------------------===//
1829// Statements
1830//===----------------------------------------------------------------------===//
1831
1832/// If the specified value has an AttachOp user strictly dominating by
1833/// "dominatingAttach" then return it.
1834static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
1835 for (auto *user : value.getUsers()) {
1836 auto attach = dyn_cast<AttachOp>(user);
1837 if (!attach || attach == dominatedAttach)
1838 continue;
1839 if (attach->isBeforeInBlock(dominatedAttach))
1840 return attach;
1841 }
1842 return {};
1843}
1844
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      // Build the merged attach, then erase both originals.
      rewriter.create<AttachOp>(op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the dead wire.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire.
            newOperands.push_back(newOperand);

        rewriter.create<AttachOp>(op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
1890
/// Replaces the given op with the contents of the given single-block region.
/// The block's operations are spliced in immediately before `op`; the caller
/// is responsible for erasing `op` itself.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
                                Region &region) {
  assert(llvm::hasSingleElement(region) && "expected single-region block");
  rewriter.inlineBlockBefore(&region.front(), op, {});
}
1897
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // Constant condition: inline the taken region (if any) and drop the when.
  // Note: a false condition with no else region inlines nothing.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
1930
1931namespace {
1932// Remove private nodes. If they have an interesting names, move the name to
1933// the source expression.
// Remove private (droppable-name, unannotated, non-forceable) nodes. If the
// node has an interesting name, best-effort move the name to the defining op
// of the node's input before replacing the node with that input.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    // Keep nodes that must be preserved: symbols, annotations, forceable.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).empty() || node.isForceable())
      return failure();
    auto *newOp = node.getInput().getDefiningOp();
    // Best effort, do not rename InstanceOp
    if (newOp && !isa<InstanceOp>(newOp))
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
1952
1953// Bypass nodes.
// Bypass nodes: forward all uses of a node's result to its input. The node
// op itself is left in place (it becomes use-free and will not re-match,
// since use_empty nodes are rejected below).
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    // Skip nodes that carry symbols, annotations, or forceability, and
    // nodes that are already dead.
    if (node.getInnerSym() || !AnnotationSet(node).empty() ||
        node.use_empty() || node.isForceable())
      return failure();
    rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
    return success();
  }
};
1967
1968} // namespace
1969
1970template <typename OpTy>
1971static LogicalResult demoteForceableIfUnused(OpTy op,
1972 PatternRewriter &rewriter) {
1973 if (!op.isForceable() || !op.getDataRef().use_empty())
1974 return failure();
1975
1976 firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
1977 return success();
1978}
1979
1980// Interesting names and symbols and don't touch force nodes to stick around.
1981LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1982 SmallVectorImpl<OpFoldResult> &results) {
1983 if (!hasDroppableName())
1984 return failure();
1985 if (hasDontTouch(getResult())) // handles inner symbols
1986 return failure();
1987 if (getAnnotationsAttr() && !AnnotationSet(getAnnotationsAttr()).empty())
1988 return failure();
1989 if (isForceable())
1990 return failure();
1991 if (!adaptor.getInput())
1992 return failure();
1993
1994 results.push_back(adaptor.getInput());
1995 return success();
1996}
1997
1998void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1999 MLIRContext *context) {
2000 results.insert<FoldNodeName>(context);
2001 results.add(demoteForceableIfUnused<NodeOp>);
2002}
2003
2004namespace {
2005// For a lhs, find all the writers of fields of the aggregate type. If there
2006// is one writer for each field, merge the writes
// For a lhs aggregate (bundle or vector), find all the writers of its fields.
// If there is exactly one writer per field and no other interfering uses,
// merge the per-field writes into a single aggregate-create + connect.
struct AggOneShot : public mlir::RewritePattern {
  // NOTE(review): the `weight` parameter is accepted but unused; the pattern
  // benefit is hard-coded to 0.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Returns the value written to each field of `lhs`, indexed in field
  // order, iff every field is written exactly once via subfield/subindex
  // connects in the same region. Returns an empty vector otherwise.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Map from field/element index to the single value written to it.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      // All accesses must live in the same region as the declaration.
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
        // A whole-aggregate write defeats the per-field merge.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subfield blocks the rewrite.
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subindex blocks the rewrite.
          return {};
        }
      } else {
        return {};
      }
    }

    // Every field must have been written exactly once.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    // Emit the merged write at the end of the block, after all field writes.
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot matchingconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Build a single aggregate value from the per-field sources.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};
2105
// Root-op specializations of AggOneShot: merge complete per-field writes
// into a single aggregate write for wires, subindexes, and subfields.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2118} // namespace
2119
2120void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2121 MLIRContext *context) {
2122 results.insert<WireAggOneShot>(context);
2123 results.add(demoteForceableIfUnused<WireOp>);
2124}
2125
2126void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2127 MLIRContext *context) {
2128 results.insert<SubindexAggOneShot>(context);
2129}
2130
2131OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2132 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2133 if (!attr)
2134 return {};
2135 return attr[getIndex()];
2136}
2137
2138OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2139 auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2140 if (!attr)
2141 return {};
2142 auto index = getFieldIndex();
2143 return attr[index];
2144}
2145
2146void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2147 MLIRContext *context) {
2148 results.insert<SubfieldAggOneShot>(context);
2149}
2150
2151static Attribute collectFields(MLIRContext *context,
2152 ArrayRef<Attribute> operands) {
2153 for (auto operand : operands)
2154 if (!operand)
2155 return {};
2156 return ArrayAttr::get(context, operands);
2157}
2158
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>, i.e. every operand is the corresponding subfield of
  // the same value and the reconstructed type matches exactly.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Each remaining operand must be subfield #i of that input.
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate if all operands folded.
  return collectFields(getContext(), adaptor.getOperands());
}
2177
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>, i.e. every operand is the corresponding subindex of the
  // same value and the reconstructed type matches exactly.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Each remaining operand must be element #i of that input.
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate if all operands folded.
  return collectFields(getContext(), adaptor.getOperands());
}
2195
2196OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2197 if (getOperand().getType() == getType())
2198 return getOperand();
2199 return {};
2200}
2201
2202namespace {
2203// A register with constant reset and all connection to either itself or the
2204// same constant, must be replaced by the constant.
// A register with constant reset whose only driver is a mux between itself
// and that same constant must always hold the constant; replace the register
// with the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    // Require a constant reset value and a register we are free to remove.
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).empty() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The single driver must be a mux ...
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    // ... whose arms are the register itself and a constant (either order).
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must equal the reset constant (same type and value).
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2257} // namespace
2258
2259static bool isDefinedByOneConstantOp(Value v) {
2260 if (auto c = v.getDefiningOp<ConstantOp>())
2261 return c.getValue().isOne();
2262 if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2263 return sc.getValue();
2264 return false;
2265}
2266
// A register whose reset signal is constant-true is permanently held at its
// reset value; replace it with a node on that value, dropping any writes.
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // The reset value must exactly match the register's type.
  auto resetValue = reg.getResetValue();
  if (reg.getType(0) != resetValue.getType())
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2283
void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Constant-reset simplifications plus demotion of unused forceables.
  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
  results.add(demoteForceableIfUnused<RegResetOp>);
}
2290
2291// Returns the value connected to a port, if there is only one.
2292static Value getPortFieldValue(Value port, StringRef name) {
2293 auto portTy = type_cast<BundleType>(port.getType());
2294 auto fieldIndex = portTy.getElementIndex(name);
2295 assert(fieldIndex && "missing field on memory port");
2296
2297 Value value = {};
2298 for (auto *op : port.getUsers()) {
2299 auto portAccess = cast<SubfieldOp>(op);
2300 if (fieldIndex != portAccess.getFieldIndex())
2301 continue;
2302 auto conn = getSingleConnectUserOf(portAccess);
2303 if (!conn || value)
2304 return {};
2305 value = conn.getSrc();
2306 }
2307 return value;
2308}
2309
2310// Returns true if the enable field of a port is set to false.
2311static bool isPortDisabled(Value port) {
2312 auto value = getPortFieldValue(port, "en");
2313 if (!value)
2314 return false;
2315 auto portConst = value.getDefiningOp<ConstantOp>();
2316 if (!portConst)
2317 return false;
2318 return portConst.getValue().isZero();
2319}
2320
2321// Returns true if the data output is unused.
2322static bool isPortUnused(Value port, StringRef data) {
2323 auto portTy = type_cast<BundleType>(port.getType());
2324 auto fieldIndex = portTy.getElementIndex(data);
2325 assert(fieldIndex && "missing enable flag on memory port");
2326
2327 for (auto *op : port.getUsers()) {
2328 auto portAccess = cast<SubfieldOp>(op);
2329 if (fieldIndex != portAccess.getFieldIndex())
2330 continue;
2331 if (!portAccess.use_empty())
2332 return false;
2333 }
2334
2335 return true;
2336}
2337
// Replaces every access to the named field of a port with `value`, erasing
// the subfield accesses. (Previous comment was a copy-paste from
// getPortFieldValue; this function returns nothing.)
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2353
// Remove all accesses to a memory port so it can be deleted. Writes are
// simply erased; any remaining reads are replaced with never-written dummy
// registers (which a later canonicalization can clean up).
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = rewriter.create<SpecialConstantOp>(
          port.getLoc(), ClockType::get(rewriter.getContext()), false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire.  If the port
  // is used in its entirety, replace it with a wire.  Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      // Whole-port use: substitute a dummy register for the entire port.
      auto ty = port.getType();
      auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2405
2406namespace {
2407// If memory has known, but zero width, eliminate it.
// If memory has known, but zero width, eliminate it: replace each port field
// access with a wire, driving data fields with a zero-width constant.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only integer memories with an exactly-zero-width data type qualify.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports.
          auto zero = rewriter.create<firrtl::ConstantOp>(
              wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
              APInt::getZero(0));
          rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2449
2450// If memory has no write ports and no file initialization, eliminate it.
// If memory has no write ports and no file initialization, eliminate it.
// Write-only memories are removed as well (their contents are unobservable).
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // The memory must be exclusively read or exclusively written; any mix,
    // debug port, or read-write port disqualifies it.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it.  A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2493
2494// Eliminate the dead ports of memories.
// Eliminate the dead ports of memories: disabled ports and read ports whose
// data output is never used. The memory is rebuilt without those ports.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port died, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = rewriter.create<MemOp>(
          mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Erase accesses to dead ports; remap live ports to the new memory.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2563
2564// Rewrite write-only read-write ports to write ports.
// Rewrite write-only read-write ports to write ports: when a read-write
// port's read data is never used, the port is demoted to a plain write port
// and its field accesses are remapped onto the new port type.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Rebuild the memory, demoting each dead read-write port's type to a
    // write-port bundle; all other ports keep their types.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = rewriter.create<MemOp>(
        mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
        mem.getName(), mem.getNameKind(), mem.getAnnotations(),
        rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
        mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          // NOTE(review): assert text looks copy-pasted from the "en" case;
          // it actually checks whichever field `fromName` names.
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
                                                     newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2655
// Narrow the data of memories by eliminating bits that are never read.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  /// Compact a memory's integer data type down to only the bit ranges that
  /// are observed by reads. Reads must go exclusively through `bits` ops and
  /// writes through single matching connects; otherwise the pattern bails.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    // Memories marked don't-touch must be left intact.
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // Only integer data of known, nonzero width can be compacted.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // `usedBits` marks every data bit observed by some read. `mapping` is
    // keyed by the low bit of each read's bit-select; the value is filled in
    // later with the corresponding offset in the compacted data.
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          // Any non-bit-select user defeats the optimisation.
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits)
            return failure();

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // If every bit ends up read, there is nothing to compact.
          if (usedBits.all())
            return failure();

          // Placeholder entry; the compacted offset is computed later.
          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }

      return success();
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        if (failed(findReadUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        if (failed(findReadUsers(port, "rdata")))
          return failure();
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Unused memories are handled in a different canonicalizer.
    if (usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones.
    // `ranges` records each maximal run [start, end] of used bits and
    // `newWidth` accumulates the width of the compacted data.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          rewriter.create<SubfieldOp>(port.getLoc(), port, field);

      // Redirect all accesses of the old field to the single new subfield.
      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Create a new bit selection from the compressed memory. The new op may
      // be folded if we are selecting the entire compressed memory.
      auto newReadValue = rewriter.createOrFold<BitsPrimOp>(
          readOp.getLoc(), readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
      rewriter.replaceAllUsesWith(readOp, newReadValue);
      rewriter.eraseOp(readOp);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Concatenate the slices of the written value that correspond to the
      // surviving bit ranges, most-significant range first.
      Value catOfSlices;
      for (auto &[start, end] : ranges) {
        Value slice = rewriter.createOrFold<BitsPrimOp>(writeOp.getLoc(),
                                                        source, end, start);
        if (catOfSlices) {
          catOfSlices = rewriter.createOrFold<CatPrimOp>(writeOp.getLoc(),
                                                         slice, catOfSlices);
        } else {
          catOfSlices = slice;
        }
      }

      // If the original memory held a signed integer, then the compressed
      // memory will be signed too. Since the catOfSlices is always unsigned,
      // cast the data to a signed integer if needed before connecting back to
      // the memory.
      if (type.isSigned())
        catOfSlices =
            rewriter.createOrFold<AsSIntPrimOp>(writeOp.getLoc(), catOfSlices);

      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
2880
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  /// Replace a depth-1 memory with a register plus a mux chain that arbitrates
  /// between the write ports, preserving the memory's write latency through
  /// pipeline registers.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only depth-1 memories that are not marked don't-touch qualify.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto ty = mem.getDataType();
    auto loc = mem.getLoc();
    auto *block = mem->getBlock();

    // Find the clock of the register-to-be, all write ports should share it.
    // Along the way, collect every port subfield access and every connect
    // driving them so they can be erased after the rewrite.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Annotated ports are skipped here (their accesses are not collected).
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect the subfield accesses of the listed fields; each access must
      // be used exclusively by connect-like ops.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        // Note: `continue` means read ports do not constrain the register
        // clock checked below.
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All writing ports must share one and the same clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }
    // Create a new wire where the memory used to be. This wire will dominate
    // all readers of the memory. Reads should be made through this wire.
    rewriter.setInsertionPointAfter(mem);
    auto memWire = rewriter.create<WireOp>(loc, ty).getResult();

    // The memory is replaced by a register, which we place at the end of the
    // block, so that any value driven to the original memory will dominate the
    // new register (including the clock). All other ops will be placed
    // after the register.
    rewriter.setInsertionPointToEnd(block);
    auto memReg =
        rewriter.create<RegOp>(loc, ty, clock, mem.getName()).getResult();

    // Connect the output of the register to the wire.
    rewriter.create<MatchingConnectOp>(loc, memWire, memReg);

    // Helper to insert a given number of pipeline stages through registers.
    // The pipelines are placed at the end of the block.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        // Stage registers are named <mem>_<name>_<i>.
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }
        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<MatchingConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      // Pipeline one port field through `stages` extra registers.
      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        replacePortField(rewriter, port, "data", memWire);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        replacePortField(rewriter, port, "rdata", memWire);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");

        // The effective write enable of a read-write port is `en && wmode`.
        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    // Build the next-value mux chain: later write ports take priority.
    Value next = memReg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    // The concatenated chunks are raw bits; cast back to the register's data
    // type before the final connect.
    Value typedNext = rewriter.createOrFold<BitCastOp>(next.getLoc(), ty, next);
    rewriter.create<MatchingConnectOp>(memReg.getLoc(), memReg, typedNext);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3078} // namespace
3079
3080void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3081 MLIRContext *context) {
3082 results
3083 .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3084 FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3085 context);
3086}
3087
3088//===----------------------------------------------------------------------===//
3089// Declarations
3090//===----------------------------------------------------------------------===//
3091
3092// Turn synchronous reset looking register updates to registers with resets.
3093// Also, const prop registers that are driven by a mux tree containing only
3094// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The register's driver must be a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    // Symmetric case: constant on the low side, self-assignment on high.
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Turn the plain register into a reg-with-reset, carrying over the name,
    // annotations, inner symbol, and any dialect attributes.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Re-drive the register: with the constant itself in the const-reg case,
  // otherwise with the mux's non-reset operand.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3162
3163LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3164 if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3165 succeeded(foldHiddenReset(op, rewriter)))
3166 return success();
3167
3168 if (succeeded(demoteForceableIfUnused(op, rewriter)))
3169 return success();
3170
3171 return failure();
3172}
3173
3174//===----------------------------------------------------------------------===//
3175// Verification Ops.
3176//===----------------------------------------------------------------------===//
3177
3178static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3179 Value enable,
3180 PatternRewriter &rewriter,
3181 bool eraseIfZero) {
3182 // If the verification op is never enabled, delete it.
3183 if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3184 if (constant.getValue().isZero()) {
3185 rewriter.eraseOp(op);
3186 return success();
3187 }
3188 }
3189
3190 // If the verification op is never triggered, delete it.
3191 if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3192 if (constant.getValue().isZero() == eraseIfZero) {
3193 rewriter.eraseOp(op);
3194 return success();
3195 }
3196 }
3197
3198 return failure();
3199}
3200
3201template <class Op, bool EraseIfZero = false>
3202static LogicalResult canonicalizeImmediateVerifOp(Op op,
3203 PatternRewriter &rewriter) {
3204 return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3205 EraseIfZero);
3206}
3207
3208void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3209 MLIRContext *context) {
3210 results.add(canonicalizeImmediateVerifOp<AssertOp>);
3211 results.add<patterns::AssertXWhenX>(context);
3212}
3213
3214void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3215 MLIRContext *context) {
3216 results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3217 results.add<patterns::AssumeXWhenX>(context);
3218}
3219
3220void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3221 RewritePatternSet &results, MLIRContext *context) {
3222 results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3223 results.add<patterns::UnclockedAssumeIntrinsicXWhenX>(context);
3224}
3225
3226void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3227 MLIRContext *context) {
3228 results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3229}
3230
3231//===----------------------------------------------------------------------===//
3232// InvalidValueOp
3233//===----------------------------------------------------------------------===//
3234
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  // Additionally allowed single users: cvt of a signed operand, and the
  // bitwise reductions (andr/xorr/orr) of a known, nonzero-width operand.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the single user with a fresh invalid of the user's result type,
    // then delete both the user and this op.
    auto *modop = *op->user_begin();
    auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
                                               modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3265
3266OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3267 if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3268 return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3269 return {};
3270}
3271
3272//===----------------------------------------------------------------------===//
3273// ClockGateIntrinsicOp
3274//===----------------------------------------------------------------------===//
3275
3276OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3277 // Forward the clock if one of the enables is always true.
3278 if (isConstantOne(adaptor.getEnable()) ||
3279 isConstantOne(adaptor.getTestEnable()))
3280 return getInput();
3281
3282 // Fold to a constant zero clock if the enables are always false.
3283 if (isConstantZero(adaptor.getEnable()) &&
3284 (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3285 return BoolAttr::get(getContext(), false);
3286
3287 // Forward constant zero clocks.
3288 if (isConstantZero(adaptor.getInput()))
3289 return BoolAttr::get(getContext(), false);
3290
3291 return {};
3292}
3293
3294LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3295 PatternRewriter &rewriter) {
3296 // Remove constant false test enable.
3297 if (auto testEnable = op.getTestEnable()) {
3298 if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3299 if (constOp.getValue().isZero()) {
3300 rewriter.modifyOpInPlace(op,
3301 [&] { op.getTestEnableMutable().clear(); });
3302 return success();
3303 }
3304 }
3305 }
3306
3307 return failure();
3308}
3309
3310//===----------------------------------------------------------------------===//
3311// Reference Ops.
3312//===----------------------------------------------------------------------===//
3313
3314// refresolve(forceable.ref) -> forceable.data
3315static LogicalResult
3316canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3317 auto forceable = op.getRef().getDefiningOp<Forceable>();
3318 if (!forceable || !forceable.isForceable() ||
3319 op.getRef() != forceable.getDataRef() ||
3320 op.getType() != forceable.getDataType())
3321 return failure();
3322 rewriter.replaceAllUsesWith(op, forceable.getData());
3323 return success();
3324}
3325
3326void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3327 MLIRContext *context) {
3328 results.insert<patterns::RefResolveOfRefSend>(context);
3329 results.insert(canonicalizeRefResolveOfForceable);
3330}
3331
3332OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3333 // RefCast is unnecessary if types match.
3334 if (getInput().getType() == getType())
3335 return getInput();
3336 return {};
3337}
3338
3339static bool isConstantZero(Value operand) {
3340 auto constOp = operand.getDefiningOp<ConstantOp>();
3341 return constOp && constOp.getValue().isZero();
3342}
3343
3344template <typename Op>
3345static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3346 if (isConstantZero(op.getPredicate())) {
3347 rewriter.eraseOp(op);
3348 return success();
3349 }
3350 return failure();
3351}
3352
3353void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3354 MLIRContext *context) {
3355 results.add(eraseIfPredFalse<RefForceOp>);
3356}
3357void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3358 MLIRContext *context) {
3359 results.add(eraseIfPredFalse<RefForceInitialOp>);
3360}
3361void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3362 MLIRContext *context) {
3363 results.add(eraseIfPredFalse<RefReleaseOp>);
3364}
3365void RefReleaseInitialOp::getCanonicalizationPatterns(
3366 RewritePatternSet &results, MLIRContext *context) {
3367 results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3368}
3369
3370//===----------------------------------------------------------------------===//
3371// HasBeenResetIntrinsicOp
3372//===----------------------------------------------------------------------===//
3373
3374OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3375 // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3376
3377 // Fold to zero if the reset is a constant. In this case the op is either
3378 // permanently in reset or never resets. Both mean that the reset never
3379 // finishes, so this op never returns true.
3380 if (adaptor.getReset())
3381 return getIntZerosAttr(UIntType::get(getContext(), 1));
3382
3383 // Fold to zero if the clock is a constant and the reset is synchronous. In
3384 // that case the reset will never be started.
3385 if (isUInt1(getReset().getType()) && adaptor.getClock())
3386 return getIntZerosAttr(UIntType::get(getContext(), 1));
3387
3388 return {};
3389}
3390
3391//===----------------------------------------------------------------------===//
3392// FPGAProbeIntrinsicOp
3393//===----------------------------------------------------------------------===//
3394
3395static bool isTypeEmpty(FIRRTLType type) {
3397 .Case<FVectorType>(
3398 [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3399 .Case<BundleType>([&](auto ty) -> bool {
3400 for (auto elem : ty.getElements())
3401 if (!isTypeEmpty(elem.type))
3402 return false;
3403 return true;
3404 })
3405 .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3406 .Default([](auto) -> bool { return false; });
3407}
3408
3409LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3410 PatternRewriter &rewriter) {
3411 auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3412 if (!firrtlTy)
3413 return failure();
3414
3415 if (!isTypeEmpty(firrtlTy))
3416 return failure();
3417
3418 rewriter.eraseOp(op);
3419 return success();
3420}
3421
3422//===----------------------------------------------------------------------===//
3423// Layer Block Op
3424//===----------------------------------------------------------------------===//
3425
3426LogicalResult LayerBlockOp::canonicalize(LayerBlockOp op,
3427 PatternRewriter &rewriter) {
3428
3429 // If the layerblock is empty, erase it.
3430 if (op.getBody()->empty()) {
3431 rewriter.eraseOp(op);
3432 return success();
3433 }
3434
3435 return failure();
3436}
assert(baseType && "element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static Value moveNameHint(OpResult old, Value passthrough)
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
static InstancePath empty
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:212
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the bitwidth; returns a sentinel value if the width is unknown.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast instead of llvm::dyn_cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
bool hasWidth() const
Return true if this integer type has a known width.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation, or a symbol that should prevent certain canonicalizations.
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has the value as its destination.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition FoldUtils.h:34
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21