// FIRRTLFolds.cpp (CIRCT 19.0.0git) — recovered from a Doxygen "documentation
// of this file" export; scraper navigation chrome converted to this comment.
//===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the folding and canonicalizations for FIRRTL ops.
//
//===----------------------------------------------------------------------===//

// NOTE(review): the scraped source dropped original include lines 13-16
// (the FIRRTL dialect headers, e.g. FIRRTLOps.h / FIRRTLTypes.h / FIRRTLUtils.h)
// — restore them from upstream CIRCT before building.
#include "circt/Support/APInt.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include <algorithm>
#include <optional>
26 
27 using namespace circt;
28 using namespace firrtl;
29 
30 // Drop writes to old and pass through passthrough to make patterns easier to
31 // write.
32 static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33  Value passthrough) {
34  SmallPtrSet<Operation *, 8> users;
35  for (auto *user : old.getUsers())
36  users.insert(user);
37  for (Operation *user : users)
38  if (auto connect = dyn_cast<FConnectLike>(user))
39  if (connect.getDest() == old)
40  rewriter.eraseOp(user);
41  return passthrough;
42 }
43 
44 // Move a name hint from a soon to be deleted operation to a new operation.
45 // Pass through the new operation to make patterns easier to write. This cannot
46 // move a name to a port (block argument), doing so would require rewriting all
47 // instance sites as well as the module.
48 static Value moveNameHint(OpResult old, Value passthrough) {
49  Operation *op = passthrough.getDefiningOp();
50  // This should handle ports, but it isn't clear we can change those in
51  // canonicalizers.
52  assert(op && "passthrough must be an operation");
53  Operation *oldOp = old.getOwner();
54  auto name = oldOp->getAttrOfType<StringAttr>("name");
55  if (name && !name.getValue().empty())
56  op->setAttr("name", name);
57  return passthrough;
58 }
59 
60 // Declarative canonicalization patterns
61 namespace circt {
62 namespace firrtl {
63 namespace patterns {
64 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
65 } // namespace patterns
66 } // namespace firrtl
67 } // namespace circt
68 
69 /// Return true if this operation's operands and results all have a known width.
70 /// This only works for integer types.
71 static bool hasKnownWidthIntTypes(Operation *op) {
72  auto resultType = type_cast<IntType>(op->getResult(0).getType());
73  if (!resultType.hasWidth())
74  return false;
75  for (Value operand : op->getOperands())
76  if (!type_cast<IntType>(operand.getType()).hasWidth())
77  return false;
78  return true;
79 }
80 
81 /// Return true if this value is 1 bit UInt.
82 static bool isUInt1(Type type) {
83  auto t = type_dyn_cast<UIntType>(type);
84  if (!t || !t.hasWidth() || t.getWidth() != 1)
85  return false;
86  return true;
87 }
88 
89 /// Set the name of an op based on the best of two names: The current name, and
90 /// the name passed in.
91 static void updateName(PatternRewriter &rewriter, Operation *op,
92  StringAttr name) {
93  // Should never rename InstanceOp
94  assert(!isa<InstanceOp>(op));
95  if (!name || name.getValue().empty())
96  return;
97  auto newName = name.getValue(); // old name is interesting
98  auto newOpName = op->getAttrOfType<StringAttr>("name");
99  // new name might not be interesting
100  if (newOpName)
101  newName = chooseName(newOpName.getValue(), name.getValue());
102  // Only update if needed
103  if (!newOpName || newOpName.getValue() != newName)
104  rewriter.modifyOpInPlace(
105  op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
106 }
107 
108 /// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
109 /// If a replaced op has a "name" attribute, this function propagates the name
110 /// to the new value.
111 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
112  Value newValue) {
113  if (auto *newOp = newValue.getDefiningOp()) {
114  auto name = op->getAttrOfType<StringAttr>("name");
115  updateName(rewriter, newOp, name);
116  }
117  rewriter.replaceOp(op, newValue);
118 }
119 
120 /// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
121 /// attribute. If a replaced op has a "name" attribute, this function propagates
122 /// the name to the new value.
123 template <typename OpTy, typename... Args>
124 static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
125  Operation *op, Args &&...args) {
126  auto name = op->getAttrOfType<StringAttr>("name");
127  auto newOp =
128  rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
129  updateName(rewriter, newOp, name);
130  return newOp;
131 }
132 
133 /// Return true if the name is droppable. Note that this is different from
134 /// `isUselessName` because non-useless names may be also droppable.
135 bool circt::firrtl::hasDroppableName(Operation *op) {
136  if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137  return namableOp.hasDroppableName();
138  return false;
139 }
140 
141 /// Implicitly replace the operand to a constant folding operation with a const
142 /// 0 in case the operand is non-constant but has a bit width 0, or if the
143 /// operand is an invalid value.
144 ///
145 /// This makes constant folding significantly easier, as we can simply pass the
146 /// operands to an operation through this function to appropriately replace any
147 /// zero-width dynamic values or invalid values with a constant of value 0.
148 static std::optional<APSInt>
149 getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
150  assert(type_cast<IntType>(operand.getType()) &&
151  "getExtendedConstant is limited to integer types");
152 
153  // We never support constant folding to unknown width values.
154  if (destWidth < 0)
155  return {};
156 
157  // Extension signedness follows the operand sign.
158  if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
159  return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
160 
161  // If the operand is zero bits, then we can return a zero of the result
162  // type.
163  if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164  return APSInt(destWidth,
165  type_cast<IntType>(operand.getType()).isUnsigned());
166  return {};
167 }
168 
169 /// Determine the value of a constant operand for the sake of constant folding.
170 static std::optional<APSInt> getConstant(Attribute operand) {
171  if (!operand)
172  return {};
173  if (auto attr = dyn_cast<BoolAttr>(operand))
174  return APSInt(APInt(1, attr.getValue()));
175  if (auto attr = dyn_cast<IntegerAttr>(operand))
176  return attr.getAPSInt();
177  return {};
178 }
179 
180 /// Determine whether a constant operand is a zero value for the sake of
181 /// constant folding. This considers `invalidvalue` to be zero.
182 static bool isConstantZero(Attribute operand) {
183  if (auto cst = getConstant(operand))
184  return cst->isZero();
185  return false;
186 }
187 
188 /// Determine whether a constant operand is a one value for the sake of constant
189 /// folding.
190 static bool isConstantOne(Attribute operand) {
191  if (auto cst = getConstant(operand))
192  return cst->isOne();
193  return false;
194 }
195 
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
enum class BinOpKind {
  Normal,
  Compare,
  DivideOrShift, // Restored: this enumerator was dropped by the scrape
                 // (original line 201); it is referenced throughout the
                 // fold hooks below.
};
203 
204 /// Applies the constant folding function `calculate` to the given operands.
205 ///
206 /// Sign or zero extends the operands appropriately to the bitwidth of the
207 /// result type if \p useDstWidth is true, else to the larger of the two operand
208 /// bit widths and depending on whether the operation is to be performed on
209 /// signed or unsigned operands.
210 static Attribute constFoldFIRRTLBinaryOp(
211  Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
212  const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
213  assert(operands.size() == 2 && "binary op takes two operands");
214 
215  // We cannot fold something to an unknown width.
216  auto resultType = type_cast<IntType>(op->getResult(0).getType());
217  if (resultType.getWidthOrSentinel() < 0)
218  return {};
219 
220  // Any binary op returning i0 is 0.
221  if (resultType.getWidthOrSentinel() == 0)
222  return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
223 
224  // Determine the operand widths. This is either dictated by the operand type,
225  // or if that type is an unsized integer, by the actual bits necessary to
226  // represent the constant value.
227  auto lhsWidth =
228  type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
229  auto rhsWidth =
230  type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231  if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232  lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233  if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234  rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
235 
236  // Compares extend the operands to the widest of the operand types, not to the
237  // result type.
238  int32_t operandWidth;
239  switch (opKind) {
240  case BinOpKind::Normal:
241  operandWidth = resultType.getWidthOrSentinel();
242  break;
243  case BinOpKind::Compare:
244  // Compares compute with the widest operand, not at the destination type
245  // (which is always i1).
246  operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
247  break;
249  operandWidth =
250  std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
251  break;
252  }
253 
254  auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
255  if (!lhs)
256  return {};
257  auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
258  if (!rhs)
259  return {};
260 
261  APInt resultValue = calculate(*lhs, *rhs);
262 
263  // If the result type is smaller than the computation then we need to
264  // narrow the constant after the calculation.
265  if (opKind == BinOpKind::DivideOrShift)
266  resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
267 
268  assert((unsigned)resultType.getWidthOrSentinel() ==
269  resultValue.getBitWidth());
270  return getIntAttr(resultType, resultValue);
271 }
272 
273 /// Applies the canonicalization function `canonicalize` to the given operation.
274 ///
275 /// Determines which (if any) of the operation's operands are constants, and
276 /// provides them as arguments to the callback function. Any `invalidvalue` in
277 /// the input is mapped to a constant zero. The value returned from the callback
278 /// is used as the replacement for `op`, and an additional pad operation is
279 /// inserted if necessary. Does nothing if the result of `op` is of unknown
280 /// width, in which case the necessity of a pad cannot be determined.
281 static LogicalResult canonicalizePrimOp(
282  Operation *op, PatternRewriter &rewriter,
283  const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
284  // Can only operate on FIRRTL primitive operations.
285  if (op->getNumResults() != 1)
286  return failure();
287  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
288  if (!type)
289  return failure();
290 
291  // Can only operate on operations with a known result width.
292  auto width = type.getBitWidthOrSentinel();
293  if (width < 0)
294  return failure();
295 
296  // Determine which of the operands are constants.
297  SmallVector<Attribute, 3> constOperands;
298  constOperands.reserve(op->getNumOperands());
299  for (auto operand : op->getOperands()) {
300  Attribute attr;
301  if (auto *defOp = operand.getDefiningOp())
302  TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303  [&](auto op) { attr = op.getValueAttr(); });
304  constOperands.push_back(attr);
305  }
306 
307  // Perform the canonicalization and materialize the result if it is a
308  // constant.
309  auto result = canonicalize(constOperands);
310  if (!result)
311  return failure();
312  Value resultValue;
313  if (auto cst = dyn_cast<Attribute>(result))
314  resultValue = op->getDialect()
315  ->materializeConstant(rewriter, cst, type, op->getLoc())
316  ->getResult(0);
317  else
318  resultValue = result.get<Value>();
319 
320  // Insert a pad if the type widths disagree.
321  if (width !=
322  type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323  resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);
324 
325  // Insert a cast if this is a uint vs. sint or vice versa.
326  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327  resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328  else if (type_isa<UIntType>(type) &&
329  type_isa<SIntType>(resultValue.getType()))
330  resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
331 
332  assert(type == resultValue.getType() && "canonicalization changed type");
333  replaceOpAndCopyName(rewriter, op, resultValue);
334  return success();
335 }
336 
337 /// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
338 /// value if `bitWidth` is 0.
339 static APInt getMaxUnsignedValue(unsigned bitWidth) {
340  return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
341 }
342 
343 /// Get the smallest signed value of a given bit width. Returns a 1-bit zero
344 /// value if `bitWidth` is 0.
345 static APInt getMinSignedValue(unsigned bitWidth) {
346  return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
347 }
348 
349 /// Get the largest signed value of a given bit width. Returns a 1-bit zero
350 /// value if `bitWidth` is 0.
351 static APInt getMaxSignedValue(unsigned bitWidth) {
352  return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
353 }
354 
//===----------------------------------------------------------------------===//
// Fold Hooks
//===----------------------------------------------------------------------===//
358 
359 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360  assert(adaptor.getOperands().empty() && "constant has no operands");
361  return getValueAttr();
362 }
363 
364 OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365  assert(adaptor.getOperands().empty() && "constant has no operands");
366  return getValueAttr();
367 }
368 
369 OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370  assert(adaptor.getOperands().empty() && "constant has no operands");
371  return getFieldsAttr();
372 }
373 
374 OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375  assert(adaptor.getOperands().empty() && "constant has no operands");
376  return getValueAttr();
377 }
378 
379 OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380  assert(adaptor.getOperands().empty() && "constant has no operands");
381  return getValueAttr();
382 }
383 
384 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385  assert(adaptor.getOperands().empty() && "constant has no operands");
386  return getValueAttr();
387 }
388 
389 OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390  assert(adaptor.getOperands().empty() && "constant has no operands");
391  return getValueAttr();
392 }
393 
//===----------------------------------------------------------------------===//
// Binary Operators
//===----------------------------------------------------------------------===//
397 
398 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
400  *this, adaptor.getOperands(), BinOpKind::Normal,
401  [=](const APSInt &a, const APSInt &b) { return a + b; });
402 }
403 
404 void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405  MLIRContext *context) {
406  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407  patterns::AddOfSelf, patterns::AddOfPad>(context);
408 }
409 
410 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
412  *this, adaptor.getOperands(), BinOpKind::Normal,
413  [=](const APSInt &a, const APSInt &b) { return a - b; });
414 }
415 
416 void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417  MLIRContext *context) {
418  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419  patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420  patterns::SubOfPadL, patterns::SubOfPadR>(context);
421 }
422 
423 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
424  // mul(x, 0) -> 0
425  //
426  // This is legal because it aligns with the Scala FIRRTL Compiler
427  // interpretation of lowering invalid to constant zero before constant
428  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
429  // multiplication this way and will emit "x * 0".
430  if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
431  return getIntZerosAttr(getType());
432 
434  *this, adaptor.getOperands(), BinOpKind::Normal,
435  [=](const APSInt &a, const APSInt &b) { return a * b; });
436 }
437 
438 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
439  /// div(x, x) -> 1
440  ///
441  /// Division by zero is undefined in the FIRRTL specification. This fold
442  /// exploits that fact to optimize self division to one. Note: this should
443  /// supersede any division with invalid or zero. Division of invalid by
444  /// invalid should be one.
445  if (getLhs() == getRhs()) {
446  auto width = getType().base().getWidthOrSentinel();
447  if (width == -1)
448  width = 2;
449  // Only fold if we have at least 1 bit of width to represent the `1` value.
450  if (width != 0)
451  return getIntAttr(getType(), APInt(width, 1));
452  }
453 
454  // div(0, x) -> 0
455  //
456  // This is legal because it aligns with the Scala FIRRTL Compiler
457  // interpretation of lowering invalid to constant zero before constant
458  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
459  // division this way and will emit "0 / x".
460  if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
461  return getIntZerosAttr(getType());
462 
463  /// div(x, 1) -> x : (uint, uint) -> uint
464  ///
465  /// UInt division by one returns the numerator. SInt division can't
466  /// be folded here because it increases the return type bitwidth by
467  /// one and requires sign extension (a new op).
468  if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469  if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
470  return getLhs();
471 
473  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
474  [=](const APSInt &a, const APSInt &b) -> APInt {
475  if (!!b)
476  return a / b;
477  return APInt(a.getBitWidth(), 0);
478  });
479 }
480 
481 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
482  // rem(x, x) -> 0
483  //
484  // Division by zero is undefined in the FIRRTL specification. This fold
485  // exploits that fact to optimize self division remainder to zero. Note:
486  // this should supersede any division with invalid or zero. Remainder of
487  // division of invalid by invalid should be zero.
488  if (getLhs() == getRhs())
489  return getIntZerosAttr(getType());
490 
491  // rem(0, x) -> 0
492  //
493  // This is legal because it aligns with the Scala FIRRTL Compiler
494  // interpretation of lowering invalid to constant zero before constant
495  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
496  // division this way and will emit "0 % x".
497  if (isConstantZero(adaptor.getLhs()))
498  return getIntZerosAttr(getType());
499 
501  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
502  [=](const APSInt &a, const APSInt &b) -> APInt {
503  if (!!b)
504  return a % b;
505  return APInt(a.getBitWidth(), 0);
506  });
507 }
508 
509 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
511  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
512  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
513 }
514 
515 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
517  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
518  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
519 }
520 
521 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
523  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524  [=](const APSInt &a, const APSInt &b) -> APInt {
525  return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
526  : a.ashr(b);
527  });
528 }
529 
530 // TODO: Move to DRR.
531 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
532  if (auto rhsCst = getConstant(adaptor.getRhs())) {
533  /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
534  if (rhsCst->isZero())
535  return getIntZerosAttr(getType());
536 
537  /// and(x, -1) -> x
538  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539  getRhs().getType() == getType())
540  return getLhs();
541  }
542 
543  if (auto lhsCst = getConstant(adaptor.getLhs())) {
544  /// and(0, x) -> 0, 0 is largest or is implicit zero extended
545  if (lhsCst->isZero())
546  return getIntZerosAttr(getType());
547 
548  /// and(-1, x) -> x
549  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550  getRhs().getType() == getType())
551  return getRhs();
552  }
553 
554  /// and(x, x) -> x
555  if (getLhs() == getRhs() && getRhs().getType() == getType())
556  return getRhs();
557 
559  *this, adaptor.getOperands(), BinOpKind::Normal,
560  [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
561 }
562 
563 void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564  MLIRContext *context) {
565  results
566  .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567  patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568  patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
569 }
570 
571 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
572  if (auto rhsCst = getConstant(adaptor.getRhs())) {
573  /// or(x, 0) -> x
574  if (rhsCst->isZero() && getLhs().getType() == getType())
575  return getLhs();
576 
577  /// or(x, -1) -> -1
578  if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579  getLhs().getType() == getType())
580  return getRhs();
581  }
582 
583  if (auto lhsCst = getConstant(adaptor.getLhs())) {
584  /// or(0, x) -> x
585  if (lhsCst->isZero() && getRhs().getType() == getType())
586  return getRhs();
587 
588  /// or(-1, x) -> -1
589  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590  getRhs().getType() == getType())
591  return getLhs();
592  }
593 
594  /// or(x, x) -> x
595  if (getLhs() == getRhs() && getRhs().getType() == getType())
596  return getRhs();
597 
599  *this, adaptor.getOperands(), BinOpKind::Normal,
600  [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
601 }
602 
603 void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604  MLIRContext *context) {
605  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606  patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607  patterns::OrOrr>(context);
608 }
609 
610 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
611  /// xor(x, 0) -> x
612  if (auto rhsCst = getConstant(adaptor.getRhs()))
613  if (rhsCst->isZero() &&
614  firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
615  return getLhs();
616 
617  /// xor(x, 0) -> x
618  if (auto lhsCst = getConstant(adaptor.getLhs()))
619  if (lhsCst->isZero() &&
620  firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
621  return getRhs();
622 
623  /// xor(x, x) -> 0
624  if (getLhs() == getRhs())
625  return getIntAttr(
626  getType(),
627  APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
628 
630  *this, adaptor.getOperands(), BinOpKind::Normal,
631  [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
632 }
633 
634 void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635  MLIRContext *context) {
636  results.insert<patterns::extendXor, patterns::moveConstXor,
637  patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
638  context);
639 }
640 
641 void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642  MLIRContext *context) {
643  results.insert<patterns::LEQWithConstLHS>(context);
644 }
645 
646 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647  bool isUnsigned = getLhs().getType().base().isUnsigned();
648 
649  // leq(x, x) -> 1
650  if (getLhs() == getRhs())
651  return getIntAttr(getType(), APInt(1, 1));
652 
653  // Comparison against constant outside type bounds.
654  if (auto width = getLhs().getType().base().getWidth()) {
655  if (auto rhsCst = getConstant(adaptor.getRhs())) {
656  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657  commonWidth = std::max(commonWidth, 1);
658 
659  // leq(x, const) -> 0 where const < minValue of the unsigned type of x
660  // This can never occur since const is unsigned and cannot be less than 0.
661 
662  // leq(x, const) -> 0 where const < minValue of the signed type of x
663  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
664  .slt(getMinSignedValue(*width).sext(commonWidth)))
665  return getIntAttr(getType(), APInt(1, 0));
666 
667  // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
668  if (isUnsigned && rhsCst->zext(commonWidth)
669  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
670  return getIntAttr(getType(), APInt(1, 1));
671 
672  // leq(x, const) -> 1 where const >= maxValue of the signed type of x
673  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
674  .sge(getMaxSignedValue(*width).sext(commonWidth)))
675  return getIntAttr(getType(), APInt(1, 1));
676  }
677  }
678 
680  *this, adaptor.getOperands(), BinOpKind::Compare,
681  [=](const APSInt &a, const APSInt &b) -> APInt {
682  return APInt(1, a <= b);
683  });
684 }
685 
686 void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687  MLIRContext *context) {
688  results.insert<patterns::LTWithConstLHS>(context);
689 }
690 
691 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692  IntType lhsType = getLhs().getType();
693  bool isUnsigned = lhsType.isUnsigned();
694 
695  // lt(x, x) -> 0
696  if (getLhs() == getRhs())
697  return getIntAttr(getType(), APInt(1, 0));
698 
699  // lt(x, 0) -> 0 when x is unsigned
700  if (auto rhsCst = getConstant(adaptor.getRhs())) {
701  if (rhsCst->isZero() && lhsType.isUnsigned())
702  return getIntAttr(getType(), APInt(1, 0));
703  }
704 
705  // Comparison against constant outside type bounds.
706  if (auto width = lhsType.getWidth()) {
707  if (auto rhsCst = getConstant(adaptor.getRhs())) {
708  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709  commonWidth = std::max(commonWidth, 1);
710 
711  // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
712  // Handled explicitly above.
713 
714  // lt(x, const) -> 0 where const <= minValue of the signed type of x
715  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
716  .sle(getMinSignedValue(*width).sext(commonWidth)))
717  return getIntAttr(getType(), APInt(1, 0));
718 
719  // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
720  if (isUnsigned && rhsCst->zext(commonWidth)
721  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
722  return getIntAttr(getType(), APInt(1, 1));
723 
724  // lt(x, const) -> 1 where const > maxValue of the signed type of x
725  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
726  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
727  return getIntAttr(getType(), APInt(1, 1));
728  }
729  }
730 
732  *this, adaptor.getOperands(), BinOpKind::Compare,
733  [=](const APSInt &a, const APSInt &b) -> APInt {
734  return APInt(1, a < b);
735  });
736 }
737 
738 void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739  MLIRContext *context) {
740  results.insert<patterns::GEQWithConstLHS>(context);
741 }
742 
743 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744  IntType lhsType = getLhs().getType();
745  bool isUnsigned = lhsType.isUnsigned();
746 
747  // geq(x, x) -> 1
748  if (getLhs() == getRhs())
749  return getIntAttr(getType(), APInt(1, 1));
750 
751  // geq(x, 0) -> 1 when x is unsigned
752  if (auto rhsCst = getConstant(adaptor.getRhs())) {
753  if (rhsCst->isZero() && isUnsigned)
754  return getIntAttr(getType(), APInt(1, 1));
755  }
756 
757  // Comparison against constant outside type bounds.
758  if (auto width = lhsType.getWidth()) {
759  if (auto rhsCst = getConstant(adaptor.getRhs())) {
760  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761  commonWidth = std::max(commonWidth, 1);
762 
763  // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
764  if (isUnsigned && rhsCst->zext(commonWidth)
765  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
766  return getIntAttr(getType(), APInt(1, 0));
767 
768  // geq(x, const) -> 0 where const > maxValue of the signed type of x
769  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
770  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
771  return getIntAttr(getType(), APInt(1, 0));
772 
773  // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
774  // Handled explicitly above.
775 
776  // geq(x, const) -> 1 where const <= minValue of the signed type of x
777  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
778  .sle(getMinSignedValue(*width).sext(commonWidth)))
779  return getIntAttr(getType(), APInt(1, 1));
780  }
781  }
782 
784  *this, adaptor.getOperands(), BinOpKind::Compare,
785  [=](const APSInt &a, const APSInt &b) -> APInt {
786  return APInt(1, a >= b);
787  });
788 }
789 
790 void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791  MLIRContext *context) {
792  results.insert<patterns::GTWithConstLHS>(context);
793 }
794 
795 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796  IntType lhsType = getLhs().getType();
797  bool isUnsigned = lhsType.isUnsigned();
798 
799  // gt(x, x) -> 0
800  if (getLhs() == getRhs())
801  return getIntAttr(getType(), APInt(1, 0));
802 
803  // Comparison against constant outside type bounds.
804  if (auto width = lhsType.getWidth()) {
805  if (auto rhsCst = getConstant(adaptor.getRhs())) {
806  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807  commonWidth = std::max(commonWidth, 1);
808 
809  // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
810  if (isUnsigned && rhsCst->zext(commonWidth)
811  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
812  return getIntAttr(getType(), APInt(1, 0));
813 
814  // gt(x, const) -> 0 where const >= maxValue of the signed type of x
815  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
816  .sge(getMaxSignedValue(*width).sext(commonWidth)))
817  return getIntAttr(getType(), APInt(1, 0));
818 
819  // gt(x, const) -> 1 where const < minValue of the unsigned type of x
820  // This can never occur since const is unsigned and cannot be less than 0.
821 
822  // gt(x, const) -> 1 where const < minValue of the signed type of x
823  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
824  .slt(getMinSignedValue(*width).sext(commonWidth)))
825  return getIntAttr(getType(), APInt(1, 1));
826  }
827  }
828 
830  *this, adaptor.getOperands(), BinOpKind::Compare,
831  [=](const APSInt &a, const APSInt &b) -> APInt {
832  return APInt(1, a > b);
833  });
834 }
835 
836 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
837  // eq(x, x) -> 1
838  if (getLhs() == getRhs())
839  return getIntAttr(getType(), APInt(1, 1));
840 
841  if (auto rhsCst = getConstant(adaptor.getRhs())) {
842  /// eq(x, 1) -> x when x is 1 bit.
843  /// TODO: Support SInt<1> on the LHS etc.
844  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845  getRhs().getType() == getType())
846  return getLhs();
847  }
848 
850  *this, adaptor.getOperands(), BinOpKind::Compare,
851  [=](const APSInt &a, const APSInt &b) -> APInt {
852  return APInt(1, a == b);
853  });
854 }
855 
/// Canonicalizations of `eq` against a constant RHS that must build new
/// operations and therefore cannot be expressed as folds.
LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // eq(x, 0) -> not(x) when x is 1 bit.
          if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // eq(x, 0) -> not(orr(x)) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
            return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
          }

          // eq(x, ~0) -> andr(x) when x is >1 bit
          // (only when both operand types agree, so ~0 has the same width)
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }
        }
        // No applicable pattern.
        return {};
      });
}
885 
886 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
887  // neq(x, x) -> 0
888  if (getLhs() == getRhs())
889  return getIntAttr(getType(), APInt(1, 0));
890 
891  if (auto rhsCst = getConstant(adaptor.getRhs())) {
892  /// neq(x, 0) -> x when x is 1 bit.
893  /// TODO: Support SInt<1> on the LHS etc.
894  if (rhsCst->isZero() && getLhs().getType() == getType() &&
895  getRhs().getType() == getType())
896  return getLhs();
897  }
898 
900  *this, adaptor.getOperands(), BinOpKind::Compare,
901  [=](const APSInt &a, const APSInt &b) -> APInt {
902  return APInt(1, a != b);
903  });
904 }
905 
/// Canonicalizations of `neq` against a constant RHS that must build new
/// operations and therefore cannot be expressed as folds.
LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
  return canonicalizePrimOp(
      op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
        if (auto rhsCst = getConstant(operands[1])) {
          auto width = op.getLhs().getType().getBitWidthOrSentinel();

          // neq(x, 1) -> not(x) when x is 1 bit
          if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
              op.getRhs().getType() == op.getType()) {
            return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, 0) -> orr(x) when x is >1 bit
          if (rhsCst->isZero() && width > 1) {
            return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
                .getResult();
          }

          // neq(x, ~0) -> not(andr(x))) when x is >1 bit
          // (only when both operand types agree, so ~0 has the same width)
          if (rhsCst->isAllOnes() && width > 1 &&
              op.getLhs().getType() == op.getRhs().getType()) {
            auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
            return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
          }
        }

        // No applicable pattern.
        return {};
      });
}
936 
/// No folds are implemented for the arbitrary-precision integer add yet.
OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6696.
  return {};
}
942 
/// No folds are implemented for the arbitrary-precision integer multiply yet.
OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6724.
  return {};
}
948 
/// No folds are implemented for the arbitrary-precision integer shift yet.
OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6725.
  return {};
}
954 
955 //===----------------------------------------------------------------------===//
956 // Unary Operators
957 //===----------------------------------------------------------------------===//
958 
959 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
960  auto base = getInput().getType();
961  auto w = getBitWidth(base);
962  if (w)
963  return getIntAttr(getType(), APInt(32, *w));
964  return {};
965 }
966 
967 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
968  // No constant can be 'x' by definition.
969  if (auto cst = getConstant(adaptor.getArg()))
970  return getIntAttr(getType(), APInt(1, 0));
971  return {};
972 }
973 
974 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
975  // No effect.
976  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
977  return getInput();
978 
979  // Be careful to only fold the cast into the constant if the size is known.
980  // Otherwise width inference may produce differently-sized constants if the
981  // sign changes.
982  if (getType().base().hasWidth())
983  if (auto cst = getConstant(adaptor.getInput()))
984  return getIntAttr(getType(), *cst);
985 
986  return {};
987 }
988 
989 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
990  MLIRContext *context) {
991  results.insert<patterns::StoUtoS>(context);
992 }
993 
994 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
995  // No effect.
996  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
997  return getInput();
998 
999  // Be careful to only fold the cast into the constant if the size is known.
1000  // Otherwise width inference may produce differently-sized constants if the
1001  // sign changes.
1002  if (getType().base().hasWidth())
1003  if (auto cst = getConstant(adaptor.getInput()))
1004  return getIntAttr(getType(), *cst);
1005 
1006  return {};
1007 }
1008 
1009 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1010  MLIRContext *context) {
1011  results.insert<patterns::UtoStoU>(context);
1012 }
1013 
1014 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1015  // No effect.
1016  if (getInput().getType() == getType())
1017  return getInput();
1018 
1019  // Constant fold.
1020  if (auto cst = getConstant(adaptor.getInput()))
1021  return BoolAttr::get(getContext(), cst->getBoolValue());
1022 
1023  return {};
1024 }
1025 
1026 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1027  // No effect.
1028  if (getInput().getType() == getType())
1029  return getInput();
1030 
1031  // Constant fold.
1032  if (auto cst = getConstant(adaptor.getInput()))
1033  return BoolAttr::get(getContext(), cst->getBoolValue());
1034 
1035  return {};
1036 }
1037 
1038 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1039  if (!hasKnownWidthIntTypes(*this))
1040  return {};
1041 
1042  // Signed to signed is a noop, unsigned operands prepend a zero bit.
1043  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1044  getType().base().getWidthOrSentinel()))
1045  return getIntAttr(getType(), *cst);
1046 
1047  return {};
1048 }
1049 
1050 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1051  MLIRContext *context) {
1052  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1053 }
1054 
1055 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1056  if (!hasKnownWidthIntTypes(*this))
1057  return {};
1058 
1059  // FIRRTL negate always adds a bit.
1060  // -x ---> 0-sext(x) or 0-zext(x)
1061  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1062  getType().base().getWidthOrSentinel()))
1063  return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1064 
1065  return {};
1066 }
1067 
1068 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1069  if (!hasKnownWidthIntTypes(*this))
1070  return {};
1071 
1072  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1073  getType().base().getWidthOrSentinel()))
1074  return getIntAttr(getType(), ~*cst);
1075 
1076  return {};
1077 }
1078 
1079 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1080  MLIRContext *context) {
1081  results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1082  patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1083  patterns::NotGt>(context);
1084 }
1085 
1086 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1087  if (!hasKnownWidthIntTypes(*this))
1088  return {};
1089 
1090  if (getInput().getType().getBitWidthOrSentinel() == 0)
1091  return getIntAttr(getType(), APInt(1, 1));
1092 
1093  // x == -1
1094  if (auto cst = getConstant(adaptor.getInput()))
1095  return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1096 
1097  // one bit is identity. Only applies to UInt since we can't make a cast
1098  // here.
1099  if (isUInt1(getInput().getType()))
1100  return getInput();
1101 
1102  return {};
1103 }
1104 
1105 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1106  MLIRContext *context) {
1107  results
1108  .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1109  patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1110  patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1111  patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1112 }
1113 
1114 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1115  if (!hasKnownWidthIntTypes(*this))
1116  return {};
1117 
1118  if (getInput().getType().getBitWidthOrSentinel() == 0)
1119  return getIntAttr(getType(), APInt(1, 0));
1120 
1121  // x != 0
1122  if (auto cst = getConstant(adaptor.getInput()))
1123  return getIntAttr(getType(), APInt(1, !cst->isZero()));
1124 
1125  // one bit is identity. Only applies to UInt since we can't make a cast
1126  // here.
1127  if (isUInt1(getInput().getType()))
1128  return getInput();
1129 
1130  return {};
1131 }
1132 
1133 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1134  MLIRContext *context) {
1135  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1136  patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1137  patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1138 }
1139 
1140 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1141  if (!hasKnownWidthIntTypes(*this))
1142  return {};
1143 
1144  if (getInput().getType().getBitWidthOrSentinel() == 0)
1145  return getIntAttr(getType(), APInt(1, 0));
1146 
1147  // popcount(x) & 1
1148  if (auto cst = getConstant(adaptor.getInput()))
1149  return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1150 
1151  // one bit is identity. Only applies to UInt since we can't make a cast here.
1152  if (isUInt1(getInput().getType()))
1153  return getInput();
1154 
1155  return {};
1156 }
1157 
1158 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1159  MLIRContext *context) {
1160  results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1161  patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1162  patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
1163  context);
1164 }
1165 
1166 //===----------------------------------------------------------------------===//
1167 // Other Operators
1168 //===----------------------------------------------------------------------===//
1169 
1170 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1171  // cat(x, 0-width) -> x
1172  // cat(0-width, x) -> x
1173  // Limit to unsigned (result type), as cannot insert cast here.
1174  IntType lhsType = getLhs().getType();
1175  IntType rhsType = getRhs().getType();
1176  if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1177  return getRhs();
1178  if (rhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1179  return getLhs();
1180 
1181  if (!hasKnownWidthIntTypes(*this))
1182  return {};
1183 
1184  // Constant fold cat.
1185  if (auto lhs = getConstant(adaptor.getLhs()))
1186  if (auto rhs = getConstant(adaptor.getRhs()))
1187  return getIntAttr(getType(), lhs->concat(*rhs));
1188 
1189  return {};
1190 }
1191 
1192 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1193  MLIRContext *context) {
1194  results.insert<patterns::DShlOfConstant>(context);
1195 }
1196 
1197 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1198  MLIRContext *context) {
1199  results.insert<patterns::DShrOfConstant>(context);
1200 }
1201 
1202 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1203  MLIRContext *context) {
1204  results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1205  patterns::CatCast>(context);
1206 }
1207 
1208 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1209  auto op = (*this);
1210  // BitCast is redundant if input and result types are same.
1211  if (op.getType() == op.getInput().getType())
1212  return op.getInput();
1213 
1214  // Two consecutive BitCasts are redundant if first bitcast type is same as the
1215  // final result type.
1216  if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1217  if (op.getType() == in.getInput().getType())
1218  return in.getInput();
1219 
1220  return {};
1221 }
1222 
1223 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1224  IntType inputType = getInput().getType();
1225  IntType resultType = getType();
1226  // If we are extracting the entire input, then return it.
1227  if (inputType == getType() && resultType.hasWidth())
1228  return getInput();
1229 
1230  // Constant fold.
1231  if (hasKnownWidthIntTypes(*this))
1232  if (auto cst = getConstant(adaptor.getInput()))
1233  return getIntAttr(resultType,
1234  cst->extractBits(getHi() - getLo() + 1, getLo()));
1235 
1236  return {};
1237 }
1238 
1239 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1240  MLIRContext *context) {
1241  results
1242  .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1243  patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1244 }
1245 
/// Replace the specified operation with a 'bits' op from the specified hi/lo
/// bits. Insert a cast to handle the case where the original operation
/// returned a signed integer.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
                            unsigned loBit, PatternRewriter &rewriter) {
  auto resType = type_cast<IntType>(op->getResult(0).getType());
  // Only materialize a bits op when the slice actually changes the width.
  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
    value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);

  // Reconcile the signedness of the replacement with the original result.
  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
    value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
  } else if (resType.isUnsigned() &&
             !type_cast<IntType>(value.getType()).isUnsigned()) {
    value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
  }
  rewriter.replaceOp(op, value);
}
1263 
/// Shared fold logic for MuxPrimOp and Mux2CellIntrinsicOp.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  // (only when the selected operand already has the result type)
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1311 
/// Folding for mux is shared with the mux intrinsics; see foldMux.
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1315 
/// Folding for the 2-input mux intrinsic is shared with mux; see foldMux.
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1319 
// No folds are implemented for the 4-input mux intrinsic.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1321 
1322 namespace {
1323 
// If the mux has a known output width, pad the operands up to this width.
// Most folds on mux require that folded operands are of the same width as
// the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    // Nothing to do when the mux result width is unknown.
    if (width < 0)
      return failure();

    // Pad `input` up to the mux width when its own width is known and
    // smaller; otherwise return it unchanged.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return rewriter
          .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Fail (no rewrite) unless a pad was actually inserted.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    // Rebuild the mux with padded operands, preserving attributes and the
    // name hint.
    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1361 
// Find muxes which have conditions dominated by other muxes with the same
// condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when searching a mux tree for a shared condition.
  static const int depthLimit = 5;

  // Rewrite mux(sel, high, low): in place when permitted, otherwise clone a
  // new mux right after it. Returns the new value, or null when the update
  // was done in place.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return rewriter
        .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
                           ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // With cond known true, mux(cond, high, low) simplifies to high.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // Only keep rewriting in place while every mux on the path is single-use.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // With cond known false, mux(cond, high, low) simplifies to low.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Within the high operand, the select is known to be true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Within the low operand, the select is known to be false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1454 } // namespace
1455 
1456 void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1457  MLIRContext *context) {
1458  results
1459  .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1460  patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1461  patterns::MuxSameTrue, patterns::MuxSameFalse,
1462  patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1463  context);
1464 }
1465 
1466 void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1467  RewritePatternSet &results, MLIRContext *context) {
1468  results.add<patterns::Mux2PadSel>(context);
1469 }
1470 
1471 void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1472  RewritePatternSet &results, MLIRContext *context) {
1473  results.add<patterns::Mux4PadSel>(context);
1474 }
1475 
/// Fold pad.
OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();

  // pad(x) -> x if the width doesn't change.
  if (input.getType() == getType())
    return input;

  // Need to know the input width.
  auto inputType = input.getType().base();
  int32_t width = inputType.getWidthOrSentinel();
  if (width == -1)
    return {};

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto destWidth = getType().base().getWidthOrSentinel();
    if (destWidth == -1)
      return {};

    // Sign-extend signed inputs; a zero-width constant has no sign bit, so it
    // takes the zero-extension path below.
    if (inputType.isSigned() && cst->getBitWidth())
      return getIntAttr(getType(), cst->sext(destWidth));
    return getIntAttr(getType(), cst->zext(destWidth));
  }

  return {};
}
1502 
/// Fold shl.
OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shl(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // Constant fold: zero-extend the constant to the result width
  // (input width + shift amount), then shift left.
  if (auto cst = getConstant(adaptor.getInput())) {
    auto inputWidth = inputType.getWidthOrSentinel();
    if (inputWidth != -1) {
      auto resultWidth = inputWidth + shiftAmount;
      // Defensive clamp of the shift against the result width.
      shiftAmount = std::min(shiftAmount, resultWidth);
      return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
    }
  }
  return {};
}
1523 
/// Fold shr.
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // The remaining folds all need a known input width.
  if (inputWidth == -1)
    return {};
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // A signed shift replicates the sign bit; the shift amount is clamped so
    // at least the sign bit survives.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // The result keeps at least one bit.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1558 
/// Canonicalize shr into a bit extraction when the input width is known.
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().base().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
1580 
/// Canonicalize head into a bit extraction of the high-order bits when the
/// input width is known.
LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp
  // spanning [inputWidth-1, inputWidth-keepAmount].
  unsigned keepAmount = op.getAmount();
  if (keepAmount)
    replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
                    rewriter);
  return success();
}
1594 
/// Fold head of a constant: keep the getAmount() high-order bits.
OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput())) {
      // Shift the low-order bits out, then truncate to the kept width.
      int shiftAmount =
          getInput().getType().base().getWidthOrSentinel() - getAmount();
      return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
    }

  return {};
}
1605 
/// Fold tail of a constant: truncate down to the result width.
OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
  if (hasKnownWidthIntTypes(*this))
    if (auto cst = getConstant(adaptor.getInput()))
      return getIntAttr(getType(),
                        cst->trunc(getType().base().getWidthOrSentinel()));
  return {};
}
1613 
/// Canonicalize tail into a bit extraction of the low-order bits when the
/// input width is known.
LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
                                       PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp
  // spanning [inputWidth-dropAmount-1, 0] (unless everything is dropped).
  unsigned dropAmount = op.getAmount();
  if (dropAmount != unsigned(inputWidth))
    replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
                    rewriter);
  return success();
}
1627 
1628 void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1629  MLIRContext *context) {
1630  results.add<patterns::SubaccessOfConstant>(context);
1631 }
1632 
/// Fold multibit_mux.
OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
  // If there is only one input, just return it. (Operand 0 is the index, so
  // the sole data input is operand 1.)
  if (adaptor.getInputs().size() == 1)
    return getOperand(1);

  // A constant in-range index selects a single input. Inputs are stored
  // highest-index first (see canonicalize below), hence the reversed
  // subscript.
  if (auto constIndex = getConstant(adaptor.getIndex())) {
    auto index = constIndex->getZExtValue();
    if (index < getInputs().size())
      return getInputs()[getInputs().size() - 1 - index];
  }

  return {};
}
1646 
/// Canonicalizations of multibit_mux that require inspecting all inputs.
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We could add this
  // canonicalization as a folder but it is costly to look through all inputs
  // so it is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the op is a vector indexing (e.g. `multibit_mux idx, a[n-1], a[n-2],
  // ..., a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Check that every input is a subindex of the same vector, with indices
    // descending from n-1 down to 0.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
1687 
1688 //===----------------------------------------------------------------------===//
1689 // Declarations
1690 //===----------------------------------------------------------------------===//
1691 
1692 /// Scan all the uses of the specified value, checking to see if there is
1693 /// exactly one connect that has the value as its destination. This returns the
1694 /// operation if found and if all the other users are "reads" from the value.
1695 /// Returns null if there are no connects, or multiple connects to the value, or
1696 /// if the value is involved in an `AttachOp`, or if the connect isn't matching.
1697 ///
1698 /// Note that this will simply return the connect, which is located *anywhere*
1699 /// after the definition of the value. Users of this function are likely
1700 /// interested in the source side of the returned connect, the definition of
1701 /// which does likely not dominate the original value.
1702 MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
1703  MatchingConnectOp connect;
1704  for (Operation *user : value.getUsers()) {
1705  // If we see an attach or aggregate sublements, just conservatively fail.
1706  if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1707  return {};
1708 
1709  if (auto aConnect = dyn_cast<FConnectLike>(user))
1710  if (aConnect.getDest() == value) {
1711  auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
1712  // If this is not a matching connect, a second matching connect or in a
1713  // different block, fail.
1714  if (!matchingConnect || (connect && connect != matchingConnect) ||
1715  matchingConnect->getBlock() != value.getParentBlock())
1716  return {};
1717  else
1718  connect = matchingConnect;
1719  }
1720  }
1721  return connect;
1722 }
1723 
// Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Respect dont-touch, non-deletable annotations, meaningful names, and
  // forceable declarations — none of those may be erased.
  if (hasDontTouch(connectedDecl) ||
      !AnnotationSet(connectedDecl).canBeDeleted() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
1783 
1784 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1785  MLIRContext *context) {
1786  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1787  context);
1788 }
1789 
1790 LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
1791  PatternRewriter &rewriter) {
1792  // TODO: Canonicalize towards explicit extensions and flips here.
1793 
1794  // If there is a simple value connected to a foldable decl like a wire or reg,
1795  // see if we can eliminate the decl.
1796  if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
1797  return success();
1798  return failure();
1799 }
1800 
1801 //===----------------------------------------------------------------------===//
1802 // Statements
1803 //===----------------------------------------------------------------------===//
1804 
1805 /// If the specified value has an AttachOp user strictly dominating by
1806 /// "dominatingAttach" then return it.
1807 static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
1808  for (auto *user : value.getUsers()) {
1809  auto attach = dyn_cast<AttachOp>(user);
1810  if (!attach || attach == dominatedAttach)
1811  continue;
1812  if (attach->isBeforeInBlock(dominatedAttach))
1813  return attach;
1814  }
1815  return {};
1816 }
1817 
/// Canonicalize attaches: erase trivial single-operand attaches, merge
/// overlapping attaches into one, and drop wires that exist solely to be
/// attached.
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add the shared operand twice.
          newOperands.push_back(newOperand);
      // Build the merged attach, then erase both originals.
      rewriter.create<AttachOp>(op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire itself.
            newOperands.push_back(newOperand);

        // Re-create the attach without the wire, then drop both.
        rewriter.create<AttachOp>(op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
1863 
1864 /// Replaces the given op with the contents of the given single-block region.
1865 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1866  Region &region) {
1867  assert(llvm::hasSingleElement(region) && "expected single-region block");
1868  rewriter.inlineBlockBefore(&region.front(), op, {});
1869 }
1870 
1871 LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
1872  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
1873  if (constant.getValue().isAllOnes())
1874  replaceOpWithRegion(rewriter, op, op.getThenRegion());
1875  else if (op.hasElseRegion() && !op.getElseRegion().empty())
1876  replaceOpWithRegion(rewriter, op, op.getElseRegion());
1877 
1878  rewriter.eraseOp(op);
1879 
1880  return success();
1881  }
1882 
1883  // Erase empty if-else block.
1884  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
1885  op.getElseBlock().empty()) {
1886  rewriter.eraseBlock(&op.getElseBlock());
1887  return success();
1888  }
1889 
1890  // Erase empty whens.
1891 
1892  // If there is stuff in the then block, leave this operation alone.
1893  if (!op.getThenBlock().empty())
1894  return failure();
1895 
1896  // If not and there is no else, then this operation is just useless.
1897  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
1898  rewriter.eraseOp(op);
1899  return success();
1900  }
1901  return failure();
1902 }
1903 
1904 namespace {
1905 // Remove private nodes. If they have an interesting names, move the name to
1906 // the source expression.
1907 struct FoldNodeName : public mlir::RewritePattern {
1908  FoldNodeName(MLIRContext *context)
1909  : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1910  LogicalResult matchAndRewrite(Operation *op,
1911  PatternRewriter &rewriter) const override {
1912  auto node = cast<NodeOp>(op);
1913  auto name = node.getNameAttr();
1914  if (!node.hasDroppableName() || node.getInnerSym() ||
1915  !AnnotationSet(node).canBeDeleted() || node.isForceable())
1916  return failure();
1917  auto *newOp = node.getInput().getDefiningOp();
1918  // Best effort, do not rename InstanceOp
1919  if (newOp && !isa<InstanceOp>(newOp))
1920  updateName(rewriter, newOp, name);
1921  rewriter.replaceOp(node, node.getInput());
1922  return success();
1923  }
1924 };
1925 
1926 // Bypass nodes.
1927 struct NodeBypass : public mlir::RewritePattern {
1928  NodeBypass(MLIRContext *context)
1929  : RewritePattern(NodeOp::getOperationName(), 0, context) {}
1930  LogicalResult matchAndRewrite(Operation *op,
1931  PatternRewriter &rewriter) const override {
1932  auto node = cast<NodeOp>(op);
1933  if (node.getInnerSym() || !AnnotationSet(node).canBeDeleted() ||
1934  node.use_empty() || node.isForceable())
1935  return failure();
1936  rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
1937  return success();
1938  }
1939 };
1940 
1941 } // namespace
1942 
1943 template <typename OpTy>
1944 static LogicalResult demoteForceableIfUnused(OpTy op,
1945  PatternRewriter &rewriter) {
1946  if (!op.isForceable() || !op.getDataRef().use_empty())
1947  return failure();
1948 
1949  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
1950  return success();
1951 }
1952 
1953 // Interesting names and symbols and don't touch force nodes to stick around.
1954 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1955  SmallVectorImpl<OpFoldResult> &results) {
1956  if (!hasDroppableName())
1957  return failure();
1958  if (hasDontTouch(getResult())) // handles inner symbols
1959  return failure();
1960  if (getAnnotationsAttr() &&
1961  !AnnotationSet(getAnnotationsAttr()).canBeDeleted())
1962  return failure();
1963  if (isForceable())
1964  return failure();
1965  if (!adaptor.getInput())
1966  return failure();
1967 
1968  results.push_back(adaptor.getInput());
1969  return success();
1970 }
1971 
1972 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1973  MLIRContext *context) {
1974  results.insert<FoldNodeName>(context);
1975  results.add(demoteForceableIfUnused<NodeOp>);
1976 }
1977 
1978 namespace {
1979 // For a lhs, find all the writers of fields of the aggregate type. If there
1980 // is one writer for each field, merge the writes
1981 struct AggOneShot : public mlir::RewritePattern {
1982  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
1983  : RewritePattern(name, 0, context) {}
1984 
1985  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
1986  auto lhsTy = lhs->getResult(0).getType();
1987  if (!type_isa<BundleType, FVectorType>(lhsTy))
1988  return {};
1989 
1990  DenseMap<uint32_t, Value> fields;
1991  for (Operation *user : lhs->getResult(0).getUsers()) {
1992  if (user->getParentOp() != lhs->getParentOp())
1993  return {};
1994  if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
1995  if (aConnect.getDest() == lhs->getResult(0))
1996  return {};
1997  } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
1998  for (Operation *subuser : subField.getResult().getUsers()) {
1999  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2000  if (aConnect.getDest() == subField) {
2001  if (subuser->getParentOp() != lhs->getParentOp())
2002  return {};
2003  if (fields.count(subField.getFieldIndex())) // duplicate write
2004  return {};
2005  fields[subField.getFieldIndex()] = aConnect.getSrc();
2006  }
2007  continue;
2008  }
2009  return {};
2010  }
2011  } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2012  for (Operation *subuser : subIndex.getResult().getUsers()) {
2013  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2014  if (aConnect.getDest() == subIndex) {
2015  if (subuser->getParentOp() != lhs->getParentOp())
2016  return {};
2017  if (fields.count(subIndex.getIndex())) // duplicate write
2018  return {};
2019  fields[subIndex.getIndex()] = aConnect.getSrc();
2020  }
2021  continue;
2022  }
2023  return {};
2024  }
2025  } else {
2026  return {};
2027  }
2028  }
2029 
2030  SmallVector<Value> values;
2031  uint32_t total = type_isa<BundleType>(lhsTy)
2032  ? type_cast<BundleType>(lhsTy).getNumElements()
2033  : type_cast<FVectorType>(lhsTy).getNumElements();
2034  for (uint32_t i = 0; i < total; ++i) {
2035  if (!fields.count(i))
2036  return {};
2037  values.push_back(fields[i]);
2038  }
2039  return values;
2040  }
2041 
2042  LogicalResult matchAndRewrite(Operation *op,
2043  PatternRewriter &rewriter) const override {
2044  auto values = getCompleteWrite(op);
2045  if (values.empty())
2046  return failure();
2047  rewriter.setInsertionPointToEnd(op->getBlock());
2048  auto dest = op->getResult(0);
2049  auto destType = dest.getType();
2050 
2051  // If not passive, cannot matchingconnect.
2052  if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2053  return failure();
2054 
2055  Value newVal = type_isa<BundleType>(destType)
2056  ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2057  destType, values)
2058  : rewriter.createOrFold<VectorCreateOp>(
2059  op->getLoc(), destType, values);
2060  rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2061  for (Operation *user : dest.getUsers()) {
2062  if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2063  for (Operation *subuser :
2064  llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2065  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2066  if (aConnect.getDest() == subIndex)
2067  rewriter.eraseOp(aConnect);
2068  } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2069  for (Operation *subuser :
2070  llvm::make_early_inc_range(subField.getResult().getUsers()))
2071  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2072  if (aConnect.getDest() == subField)
2073  rewriter.eraseOp(aConnect);
2074  }
2075  }
2076  return success();
2077  }
2078 };
2079 
// Instantiations of AggOneShot for each root op that may anchor a complete
// per-field write: wires, subindexes, and subfields.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2092 } // namespace
2093 
2094 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2095  MLIRContext *context) {
2096  results.insert<WireAggOneShot>(context);
2097  results.add(demoteForceableIfUnused<WireOp>);
2098 }
2099 
2100 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2101  MLIRContext *context) {
2102  results.insert<SubindexAggOneShot>(context);
2103 }
2104 
2105 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2106  auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2107  if (!attr)
2108  return {};
2109  return attr[getIndex()];
2110 }
2111 
2112 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2113  auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2114  if (!attr)
2115  return {};
2116  auto index = getFieldIndex();
2117  return attr[index];
2118 }
2119 
2120 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2121  MLIRContext *context) {
2122  results.insert<SubfieldAggOneShot>(context);
2123 }
2124 
2125 static Attribute collectFields(MLIRContext *context,
2126  ArrayRef<Attribute> operands) {
2127  for (auto operand : operands)
2128  if (!operand)
2129  return {};
2130  return ArrayAttr::get(context, operands);
2131 }
2132 
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  //
  // Identity fold: every operand must be the i-th subfield of one common
  // input whose type equals this op's result type.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate when every field is constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2151 
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  //
  // Identity fold: every operand must be the i-th subindex of one common
  // input whose type equals this op's result type.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise, fold to a constant aggregate when every element is constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2169 
2170 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2171  if (getOperand().getType() == getType())
2172  return getOperand();
2173  return {};
2174 }
2175 
2176 namespace {
// A register with constant reset and all connection to either itself or the
// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    // Only registers with a constant reset value that may be deleted.
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).canBeDeleted() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The register must be driven by a mux ...
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    // ... whose two arms are the register itself and a constant, in either
    // order.
    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must be identical to the reset constant.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2231 } // namespace
2232 
2233 static bool isDefinedByOneConstantOp(Value v) {
2234  if (auto c = v.getDefiningOp<ConstantOp>())
2235  return c.getValue().isOne();
2236  if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2237  return sc.getValue();
2238  return false;
2239 }
2240 
// A regreset whose reset signal is constant-1 always holds its reset value,
// so replace it with a node of that value.
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  // Preserve name, annotations, symbol, and forceability on the new node.
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2253 
void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // Declarative patterns and the reset-mux fold, plus the two function-backed
  // canonicalizations defined above.
  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
  results.add(canonicalizeRegResetWithOneReset);
  results.add(demoteForceableIfUnused<RegResetOp>);
}
2260 
2261 // Returns the value connected to a port, if there is only one.
2262 static Value getPortFieldValue(Value port, StringRef name) {
2263  auto portTy = type_cast<BundleType>(port.getType());
2264  auto fieldIndex = portTy.getElementIndex(name);
2265  assert(fieldIndex && "missing field on memory port");
2266 
2267  Value value = {};
2268  for (auto *op : port.getUsers()) {
2269  auto portAccess = cast<SubfieldOp>(op);
2270  if (fieldIndex != portAccess.getFieldIndex())
2271  continue;
2272  auto conn = getSingleConnectUserOf(portAccess);
2273  if (!conn || value)
2274  return {};
2275  value = conn.getSrc();
2276  }
2277  return value;
2278 }
2279 
2280 // Returns true if the enable field of a port is set to false.
2281 static bool isPortDisabled(Value port) {
2282  auto value = getPortFieldValue(port, "en");
2283  if (!value)
2284  return false;
2285  auto portConst = value.getDefiningOp<ConstantOp>();
2286  if (!portConst)
2287  return false;
2288  return portConst.getValue().isZero();
2289 }
2290 
2291 // Returns true if the data output is unused.
2292 static bool isPortUnused(Value port, StringRef data) {
2293  auto portTy = type_cast<BundleType>(port.getType());
2294  auto fieldIndex = portTy.getElementIndex(data);
2295  assert(fieldIndex && "missing enable flag on memory port");
2296 
2297  for (auto *op : port.getUsers()) {
2298  auto portAccess = cast<SubfieldOp>(op);
2299  if (fieldIndex != portAccess.getFieldIndex())
2300  continue;
2301  if (!portAccess.use_empty())
2302  return false;
2303  }
2304 
2305  return true;
2306 }
2307 
2308 // Returns the value connected to a port, if there is only one.
2309 static void replacePortField(PatternRewriter &rewriter, Value port,
2310  StringRef name, Value value) {
2311  auto portTy = type_cast<BundleType>(port.getType());
2312  auto fieldIndex = portTy.getElementIndex(name);
2313  assert(fieldIndex && "missing field on memory port");
2314 
2315  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2316  auto portAccess = cast<SubfieldOp>(op);
2317  if (fieldIndex != portAccess.getFieldIndex())
2318  continue;
2319  rewriter.replaceAllUsesWith(portAccess, value);
2320  rewriter.eraseOp(portAccess);
2321  }
2322 }
2323 
// Remove all accesses to a port that is no longer needed, replacing any read
// values with dummy registers.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = rewriter.create<SpecialConstantOp>(
          port.getLoc(), ClockType::get(rewriter.getContext()), false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire.  If the port
  // is used in its entirety, replace it with a wire.  Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      // Whole-port use: hand all uses a never-written register and stop.
      auto ty = port.getType();
      auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2375 
2376 namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only integer memories with an exactly-zero data width qualify.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are subfield accesses we can safely replace.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        // Each field access becomes a free-standing wire.
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports.
          auto zero = rewriter.create<firrtl::ConstantOp>(
              wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
              APInt::getZero(0));
          rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2419 
// If memory has no write ports and no file initialization, eliminate it.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // The memory must be all-read or all-write; any mixed usage, debug port,
    // or read-write port bails out.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2463 
// Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port died, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = rewriter.create<MemOp>(
          mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Replace the dead ports with dummy wires.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2533 
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Rebuild the memory, demoting each dead read-write port to a write port.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = rewriter.create<MemOp>(
        mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
        mem.getName(), mem.getNameKind(), mem.getAnnotations(),
        rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
        mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          // NOTE(review): message looks stale — the field looked up here is
          // `fromName`, not necessarily the enable flag.
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
                                                     newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        rewriter.replaceAllUsesWith(result, newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2625 
// Eliminate memory data bits that are never read. If every read of the data
// word goes through `bits` select ops that only observe a subset of the bits,
// the memory is narrowed to that subset and all reads and writes are rewritten
// against the compacted data word.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    // Symbols/don't-touch imply external observers; leave the memory alone.
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation to combinational, unmasked memories.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // The data type must be an integer with a known, non-zero width.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // `usedBits` marks every data bit observed by some read. `mapping` records
    // each read's low bit index; the value is filled in later with the bit's
    // position in the compacted word.
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits) {
            // A non-bit-select user observes the entire word; setting all
            // bits makes the `usedBits.all()` check below bail out.
            usedBits.set();
            continue;
          }

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        findReadUsers(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        findReadUsers(port, "rdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Perform the transformation only if some bits are unused. Fully unused
    // memories are handled in a different canonicalizer.
    if (usedBits.all() || usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones. Walk maximal
    // runs of used bits; each run becomes one [start, end] range and its
    // bits get consecutive positions in the compacted word.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type: a single fresh subfield op
    // per port replaces all accesses to the given field.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          rewriter.create<SubfieldOp>(port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them. Each old
    // `bits(lo..hi)` becomes `bits(mapped(lo) .. mapped(lo)+width-1)`.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      rewriter.replaceOpWithNewOp<BitsPrimOp>(
          readOp, readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
    }

    // Rewrite the writes into a concatenation of slices: only the used bit
    // ranges of the original source are kept, concatenated in order.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      Value catOfSlices;
      for (auto &[start, end] : ranges) {
        Value slice =
            rewriter.create<BitsPrimOp>(writeOp.getLoc(), source, end, start);
        if (catOfSlices) {
          catOfSlices =
              rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
        } else {
          catOfSlices = slice;
        }
      }
      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
2833 
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only depth-1 memories without don't-touch can become a register.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto memModule = mem->getParentOfType<FModuleOp>();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect the subfield accesses of the named fields and all connects
      // driving them; any non-connect user of a field aborts the rewrite.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All ports that reach here (write/read-write) must share one clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }

    // Create a new register to store the data.
    auto ty = mem.getDataType();
    rewriter.setInsertionPointAfterValue(clock);
    auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
                   .getResult();

    // Helper to insert a given number of pipeline stages through registers.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }

        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<MatchingConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    // A write latency of N requires N-1 pipeline registers on the write path.
    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        rewriter.setInsertionPointAfterValue(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "data", reg);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "rdata", reg);

        // Create a write enable (en & wmode) and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");
        rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder: later
    // ports in the chain of muxes below take priority over earlier ones.
    rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
    Value next = reg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      // The write only takes effect when its (pipelined) enable is set.
      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    rewriter.create<MatchingConnectOp>(reg.getLoc(), reg, next);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3024 } // namespace
3025 
3026 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3027  MLIRContext *context) {
3028  results
3029  .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3030  FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3031  context);
3032 }
3033 
3034 //===----------------------------------------------------------------------===//
3035 // Declarations
3036 //===----------------------------------------------------------------------===//
3037 
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The single driver must be a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Replace the plain register with a reg-with-reset, preserving any
    // dialect attributes attached to the original op.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Re-drive the register: with the constant if it folded to one, otherwise
  // with the non-reset branch of the mux.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3108 
3109 LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3110  if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3111  succeeded(foldHiddenReset(op, rewriter)))
3112  return success();
3113 
3114  if (succeeded(demoteForceableIfUnused(op, rewriter)))
3115  return success();
3116 
3117  return failure();
3118 }
3119 
3120 //===----------------------------------------------------------------------===//
3121 // Verification Ops.
3122 //===----------------------------------------------------------------------===//
3123 
3124 static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3125  Value enable,
3126  PatternRewriter &rewriter,
3127  bool eraseIfZero) {
3128  // If the verification op is never enabled, delete it.
3129  if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3130  if (constant.getValue().isZero()) {
3131  rewriter.eraseOp(op);
3132  return success();
3133  }
3134  }
3135 
3136  // If the verification op is never triggered, delete it.
3137  if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3138  if (constant.getValue().isZero() == eraseIfZero) {
3139  rewriter.eraseOp(op);
3140  return success();
3141  }
3142  }
3143 
3144  return failure();
3145 }
3146 
/// Shared canonicalizer for the immediate verification ops (assert, assume,
/// cover): delete the op when a constant enable or predicate proves it can
/// never have an effect. `EraseIfZero` selects the predicate polarity that
/// makes the op dead: false (assert/assume) erases on an always-true
/// predicate, true (cover) erases on a never-true predicate.
template <class Op, bool EraseIfZero = false>
static LogicalResult canonicalizeImmediateVerifOp(Op op,
                                                  PatternRewriter &rewriter) {
  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
                              EraseIfZero);
}
3153 
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Delete asserts whose enable is constant false or predicate is always true.
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
}
3158 
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Delete assumes whose enable is constant false or predicate is always true.
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
}
3163 
void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // Same dead-op elimination as clocked assumes.
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
}
3168 
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  // For covers, a constant-zero predicate (never coverable) makes the op dead.
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3173 
3174 //===----------------------------------------------------------------------===//
3175 // InvalidValueOp
3176 //===----------------------------------------------------------------------===//
3177 
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  // Cvt is only safe on signed inputs (where it is the identity); reductions
  // only on inputs with a known positive width.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of the user's result type.
    auto *modop = *op->user_begin();
    auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
                                               modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3208 
3209 OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3210  if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3211  return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3212  return {};
3213 }
3214 
3215 //===----------------------------------------------------------------------===//
3216 // ClockGateIntrinsicOp
3217 //===----------------------------------------------------------------------===//
3218 
OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Forward the clock if one of the enables is always true.
  if (isConstantOne(adaptor.getEnable()) ||
      isConstantOne(adaptor.getTestEnable()))
    return getInput();

  // Fold to a constant zero clock if the enables are always false.
  if (isConstantZero(adaptor.getEnable()) &&
      (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
    return BoolAttr::get(getContext(), false);

  // Forward constant zero clocks; gating a constant-low clock is a no-op.
  if (isConstantZero(adaptor.getInput()))
    return BoolAttr::get(getContext(), false);

  return {};
}
3236 
3237 LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3238  PatternRewriter &rewriter) {
3239  // Remove constant false test enable.
3240  if (auto testEnable = op.getTestEnable()) {
3241  if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3242  if (constOp.getValue().isZero()) {
3243  rewriter.modifyOpInPlace(op,
3244  [&] { op.getTestEnableMutable().clear(); });
3245  return success();
3246  }
3247  }
3248  }
3249 
3250  return failure();
3251 }
3252 
3253 //===----------------------------------------------------------------------===//
3254 // Reference Ops.
3255 //===----------------------------------------------------------------------===//
3256 
3257 // refresolve(forceable.ref) -> forceable.data
3258 static LogicalResult
3259 canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3260  auto forceable = op.getRef().getDefiningOp<Forceable>();
3261  if (!forceable || !forceable.isForceable() ||
3262  op.getRef() != forceable.getDataRef() ||
3263  op.getType() != forceable.getDataType())
3264  return failure();
3265  rewriter.replaceAllUsesWith(op, forceable.getData());
3266  return success();
3267 }
3268 
void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // Fold resolve-of-send, and resolve of a forceable declaration's ref.
  results.insert<patterns::RefResolveOfRefSend>(context);
  results.insert(canonicalizeRefResolveOfForceable);
}
3274 
3275 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3276  // RefCast is unnecessary if types match.
3277  if (getInput().getType() == getType())
3278  return getInput();
3279  return {};
3280 }
3281 
3282 static bool isConstantZero(Value operand) {
3283  auto constOp = operand.getDefiningOp<ConstantOp>();
3284  return constOp && constOp.getValue().isZero();
3285 }
3286 
3287 template <typename Op>
3288 static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3289  if (isConstantZero(op.getPredicate())) {
3290  rewriter.eraseOp(op);
3291  return success();
3292  }
3293  return failure();
3294 }
3295 
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  // A force with a constant-false predicate never fires; delete it.
  results.add(eraseIfPredFalse<RefForceOp>);
}
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  // An initial force with a constant-false predicate never fires; delete it.
  results.add(eraseIfPredFalse<RefForceInitialOp>);
}
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // A release with a constant-false predicate never fires; delete it.
  results.add(eraseIfPredFalse<RefReleaseOp>);
}
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  // An initial release with a constant-false predicate never fires; delete it.
  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
}
3312 
3313 //===----------------------------------------------------------------------===//
3314 // HasBeenResetIntrinsicOp
3315 //===----------------------------------------------------------------------===//
3316 
OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
  // The folds in here should reflect the ones for `verif::HasBeenResetOp`.

  // Fold to zero if the reset is a constant. In this case the op is either
  // permanently in reset or never resets. Both mean that the reset never
  // finishes, so this op never returns true.
  if (adaptor.getReset())
    return getIntZerosAttr(UIntType::get(getContext(), 1));

  // Fold to zero if the clock is a constant and the reset is synchronous
  // (a 1-bit UInt reset). In that case the reset will never be started.
  if (isUInt1(getReset().getType()) && adaptor.getClock())
    return getIntZerosAttr(UIntType::get(getContext(), 1));

  return {};
}
3333 
3334 //===----------------------------------------------------------------------===//
3335 // FPGAProbeIntrinsicOp
3336 //===----------------------------------------------------------------------===//
3337 
3338 static bool isTypeEmpty(FIRRTLType type) {
3340  .Case<FVectorType>(
3341  [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3342  .Case<BundleType>([&](auto ty) -> bool {
3343  for (auto elem : ty.getElements())
3344  if (!isTypeEmpty(elem.type))
3345  return false;
3346  return true;
3347  })
3348  .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3349  .Default([](auto) -> bool { return false; });
3350 }
3351 
3352 LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3353  PatternRewriter &rewriter) {
3354  auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3355  if (!firrtlTy)
3356  return failure();
3357 
3358  if (!isTypeEmpty(firrtlTy))
3359  return failure();
3360 
3361  rewriter.eraseOp(op);
3362  return success();
3363 }
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
Definition: FIRRTLFolds.cpp:71
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
Definition: FIRRTLFolds.cpp:91
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
Definition: FIRRTLFolds.cpp:82
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:32
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:48
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
int32_t width
Definition: FIRRTL.cpp:36
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the bitwidth.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast instead of dyn_cast.
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:296
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:275
bool hasWidth() const
Return true if this integer type has a known width.
Definition: FIRRTLTypes.h:283
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:314
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has the value as its destination.
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition: Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition: FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition: FoldUtils.h:34
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20