CIRCT  20.0.0git
FIRRTLFolds.cpp
Go to the documentation of this file.
1 //===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implement the folding and canonicalizations for FIRRTL ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "circt/Dialect/FIRRTL/FIRRTLAttributes.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Support/APInt.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
26 
27 using namespace circt;
28 using namespace firrtl;
29 
30 // Drop writes to old and pass through passthrough to make patterns easier to
31 // write.
32 static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33  Value passthrough) {
34  SmallPtrSet<Operation *, 8> users;
35  for (auto *user : old.getUsers())
36  users.insert(user);
37  for (Operation *user : users)
38  if (auto connect = dyn_cast<FConnectLike>(user))
39  if (connect.getDest() == old)
40  rewriter.eraseOp(user);
41  return passthrough;
42 }
43 
44 // Move a name hint from a soon to be deleted operation to a new operation.
45 // Pass through the new operation to make patterns easier to write. This cannot
46 // move a name to a port (block argument), doing so would require rewriting all
47 // instance sites as well as the module.
48 static Value moveNameHint(OpResult old, Value passthrough) {
49  Operation *op = passthrough.getDefiningOp();
50  // This should handle ports, but it isn't clear we can change those in
51  // canonicalizers.
52  assert(op && "passthrough must be an operation");
53  Operation *oldOp = old.getOwner();
54  auto name = oldOp->getAttrOfType<StringAttr>("name");
55  if (name && !name.getValue().empty())
56  op->setAttr("name", name);
57  return passthrough;
58 }
59 
60 // Declarative canonicalization patterns
61 namespace circt {
62 namespace firrtl {
63 namespace patterns {
64 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
65 } // namespace patterns
66 } // namespace firrtl
67 } // namespace circt
68 
69 /// Return true if this operation's operands and results all have a known width.
70 /// This only works for integer types.
71 static bool hasKnownWidthIntTypes(Operation *op) {
72  auto resultType = type_cast<IntType>(op->getResult(0).getType());
73  if (!resultType.hasWidth())
74  return false;
75  for (Value operand : op->getOperands())
76  if (!type_cast<IntType>(operand.getType()).hasWidth())
77  return false;
78  return true;
79 }
80 
81 /// Return true if this value is 1 bit UInt.
82 static bool isUInt1(Type type) {
83  auto t = type_dyn_cast<UIntType>(type);
84  if (!t || !t.hasWidth() || t.getWidth() != 1)
85  return false;
86  return true;
87 }
88 
89 /// Set the name of an op based on the best of two names: The current name, and
90 /// the name passed in.
91 static void updateName(PatternRewriter &rewriter, Operation *op,
92  StringAttr name) {
93  // Should never rename InstanceOp
94  assert(!isa<InstanceOp>(op));
95  if (!name || name.getValue().empty())
96  return;
97  auto newName = name.getValue(); // old name is interesting
98  auto newOpName = op->getAttrOfType<StringAttr>("name");
99  // new name might not be interesting
100  if (newOpName)
101  newName = chooseName(newOpName.getValue(), name.getValue());
102  // Only update if needed
103  if (!newOpName || newOpName.getValue() != newName)
104  rewriter.modifyOpInPlace(
105  op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
106 }
107 
108 /// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
109 /// If a replaced op has a "name" attribute, this function propagates the name
110 /// to the new value.
111 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
112  Value newValue) {
113  if (auto *newOp = newValue.getDefiningOp()) {
114  auto name = op->getAttrOfType<StringAttr>("name");
115  updateName(rewriter, newOp, name);
116  }
117  rewriter.replaceOp(op, newValue);
118 }
119 
120 /// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
121 /// attribute. If a replaced op has a "name" attribute, this function propagates
122 /// the name to the new value.
123 template <typename OpTy, typename... Args>
124 static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
125  Operation *op, Args &&...args) {
126  auto name = op->getAttrOfType<StringAttr>("name");
127  auto newOp =
128  rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
129  updateName(rewriter, newOp, name);
130  return newOp;
131 }
132 
133 /// Return true if the name is droppable. Note that this is different from
134 /// `isUselessName` because non-useless names may be also droppable.
135 bool circt::firrtl::hasDroppableName(Operation *op) {
136  if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137  return namableOp.hasDroppableName();
138  return false;
139 }
140 
141 /// Implicitly replace the operand to a constant folding operation with a const
142 /// 0 in case the operand is non-constant but has a bit width 0, or if the
143 /// operand is an invalid value.
144 ///
145 /// This makes constant folding significantly easier, as we can simply pass the
146 /// operands to an operation through this function to appropriately replace any
147 /// zero-width dynamic values or invalid values with a constant of value 0.
148 static std::optional<APSInt>
149 getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
150  assert(type_cast<IntType>(operand.getType()) &&
151  "getExtendedConstant is limited to integer types");
152 
153  // We never support constant folding to unknown width values.
154  if (destWidth < 0)
155  return {};
156 
157  // Extension signedness follows the operand sign.
158  if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
159  return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
160 
161  // If the operand is zero bits, then we can return a zero of the result
162  // type.
163  if (type_cast<IntType>(operand.getType()).getWidth() == 0)
164  return APSInt(destWidth,
165  type_cast<IntType>(operand.getType()).isUnsigned());
166  return {};
167 }
168 
169 /// Determine the value of a constant operand for the sake of constant folding.
170 static std::optional<APSInt> getConstant(Attribute operand) {
171  if (!operand)
172  return {};
173  if (auto attr = dyn_cast<BoolAttr>(operand))
174  return APSInt(APInt(1, attr.getValue()));
175  if (auto attr = dyn_cast<IntegerAttr>(operand))
176  return attr.getAPSInt();
177  return {};
178 }
179 
180 /// Determine whether a constant operand is a zero value for the sake of
181 /// constant folding. This considers `invalidvalue` to be zero.
182 static bool isConstantZero(Attribute operand) {
183  if (auto cst = getConstant(operand))
184  return cst->isZero();
185  return false;
186 }
187 
188 /// Determine whether a constant operand is a one value for the sake of constant
189 /// folding.
190 static bool isConstantOne(Attribute operand) {
191  if (auto cst = getConstant(operand))
192  return cst->isOne();
193  return false;
194 }
195 
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
///
/// Note: the `DivideOrShift` enumerator was dropped from this declaration by
/// the extraction of this file; it is referenced throughout the fold hooks
/// below (div, rem, dshl, dshlw, dshr) and in `constFoldFIRRTLBinaryOp`, so it
/// is restored here.
enum class BinOpKind {
  Normal,        ///< Operands are extended to the result width.
  Compare,       ///< Operands are extended to the widest operand; result is i1.
  DivideOrShift, ///< Compute at a widened width, then truncate to the result.
};
203 
204 /// Applies the constant folding function `calculate` to the given operands.
205 ///
206 /// Sign or zero extends the operands appropriately to the bitwidth of the
207 /// result type if \p useDstWidth is true, else to the larger of the two operand
208 /// bit widths and depending on whether the operation is to be performed on
209 /// signed or unsigned operands.
210 static Attribute constFoldFIRRTLBinaryOp(
211  Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
212  const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
213  assert(operands.size() == 2 && "binary op takes two operands");
214 
215  // We cannot fold something to an unknown width.
216  auto resultType = type_cast<IntType>(op->getResult(0).getType());
217  if (resultType.getWidthOrSentinel() < 0)
218  return {};
219 
220  // Any binary op returning i0 is 0.
221  if (resultType.getWidthOrSentinel() == 0)
222  return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
223 
224  // Determine the operand widths. This is either dictated by the operand type,
225  // or if that type is an unsized integer, by the actual bits necessary to
226  // represent the constant value.
227  auto lhsWidth =
228  type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
229  auto rhsWidth =
230  type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231  if (auto lhs = dyn_cast_or_null<IntegerAttr>(operands[0]))
232  lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233  if (auto rhs = dyn_cast_or_null<IntegerAttr>(operands[1]))
234  rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
235 
236  // Compares extend the operands to the widest of the operand types, not to the
237  // result type.
238  int32_t operandWidth;
239  switch (opKind) {
240  case BinOpKind::Normal:
241  operandWidth = resultType.getWidthOrSentinel();
242  break;
243  case BinOpKind::Compare:
244  // Compares compute with the widest operand, not at the destination type
245  // (which is always i1).
246  operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
247  break;
249  operandWidth =
250  std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
251  break;
252  }
253 
254  auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
255  if (!lhs)
256  return {};
257  auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
258  if (!rhs)
259  return {};
260 
261  APInt resultValue = calculate(*lhs, *rhs);
262 
263  // If the result type is smaller than the computation then we need to
264  // narrow the constant after the calculation.
265  if (opKind == BinOpKind::DivideOrShift)
266  resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
267 
268  assert((unsigned)resultType.getWidthOrSentinel() ==
269  resultValue.getBitWidth());
270  return getIntAttr(resultType, resultValue);
271 }
272 
273 /// Applies the canonicalization function `canonicalize` to the given operation.
274 ///
275 /// Determines which (if any) of the operation's operands are constants, and
276 /// provides them as arguments to the callback function. Any `invalidvalue` in
277 /// the input is mapped to a constant zero. The value returned from the callback
278 /// is used as the replacement for `op`, and an additional pad operation is
279 /// inserted if necessary. Does nothing if the result of `op` is of unknown
280 /// width, in which case the necessity of a pad cannot be determined.
281 static LogicalResult canonicalizePrimOp(
282  Operation *op, PatternRewriter &rewriter,
283  const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
284  // Can only operate on FIRRTL primitive operations.
285  if (op->getNumResults() != 1)
286  return failure();
287  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
288  if (!type)
289  return failure();
290 
291  // Can only operate on operations with a known result width.
292  auto width = type.getBitWidthOrSentinel();
293  if (width < 0)
294  return failure();
295 
296  // Determine which of the operands are constants.
297  SmallVector<Attribute, 3> constOperands;
298  constOperands.reserve(op->getNumOperands());
299  for (auto operand : op->getOperands()) {
300  Attribute attr;
301  if (auto *defOp = operand.getDefiningOp())
302  TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
303  [&](auto op) { attr = op.getValueAttr(); });
304  constOperands.push_back(attr);
305  }
306 
307  // Perform the canonicalization and materialize the result if it is a
308  // constant.
309  auto result = canonicalize(constOperands);
310  if (!result)
311  return failure();
312  Value resultValue;
313  if (auto cst = dyn_cast<Attribute>(result))
314  resultValue = op->getDialect()
315  ->materializeConstant(rewriter, cst, type, op->getLoc())
316  ->getResult(0);
317  else
318  resultValue = result.get<Value>();
319 
320  // Insert a pad if the type widths disagree.
321  if (width !=
322  type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
323  resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);
324 
325  // Insert a cast if this is a uint vs. sint or vice versa.
326  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
327  resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
328  else if (type_isa<UIntType>(type) &&
329  type_isa<SIntType>(resultValue.getType()))
330  resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);
331 
332  assert(type == resultValue.getType() && "canonicalization changed type");
333  replaceOpAndCopyName(rewriter, op, resultValue);
334  return success();
335 }
336 
337 /// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
338 /// value if `bitWidth` is 0.
339 static APInt getMaxUnsignedValue(unsigned bitWidth) {
340  return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
341 }
342 
343 /// Get the smallest signed value of a given bit width. Returns a 1-bit zero
344 /// value if `bitWidth` is 0.
345 static APInt getMinSignedValue(unsigned bitWidth) {
346  return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
347 }
348 
349 /// Get the largest signed value of a given bit width. Returns a 1-bit zero
350 /// value if `bitWidth` is 0.
351 static APInt getMaxSignedValue(unsigned bitWidth) {
352  return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // Fold Hooks
357 //===----------------------------------------------------------------------===//
358 
359 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
360  assert(adaptor.getOperands().empty() && "constant has no operands");
361  return getValueAttr();
362 }
363 
364 OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
365  assert(adaptor.getOperands().empty() && "constant has no operands");
366  return getValueAttr();
367 }
368 
369 OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
370  assert(adaptor.getOperands().empty() && "constant has no operands");
371  return getFieldsAttr();
372 }
373 
374 OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
375  assert(adaptor.getOperands().empty() && "constant has no operands");
376  return getValueAttr();
377 }
378 
379 OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
380  assert(adaptor.getOperands().empty() && "constant has no operands");
381  return getValueAttr();
382 }
383 
384 OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
385  assert(adaptor.getOperands().empty() && "constant has no operands");
386  return getValueAttr();
387 }
388 
389 OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
390  assert(adaptor.getOperands().empty() && "constant has no operands");
391  return getValueAttr();
392 }
393 
394 //===----------------------------------------------------------------------===//
395 // Binary Operators
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
400  *this, adaptor.getOperands(), BinOpKind::Normal,
401  [=](const APSInt &a, const APSInt &b) { return a + b; });
402 }
403 
404 void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
405  MLIRContext *context) {
406  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
407  patterns::AddOfSelf, patterns::AddOfPad>(context);
408 }
409 
410 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
412  *this, adaptor.getOperands(), BinOpKind::Normal,
413  [=](const APSInt &a, const APSInt &b) { return a - b; });
414 }
415 
416 void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
417  MLIRContext *context) {
418  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
419  patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
420  patterns::SubOfPadL, patterns::SubOfPadR>(context);
421 }
422 
423 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
424  // mul(x, 0) -> 0
425  //
426  // This is legal because it aligns with the Scala FIRRTL Compiler
427  // interpretation of lowering invalid to constant zero before constant
428  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
429  // multiplication this way and will emit "x * 0".
430  if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
431  return getIntZerosAttr(getType());
432 
434  *this, adaptor.getOperands(), BinOpKind::Normal,
435  [=](const APSInt &a, const APSInt &b) { return a * b; });
436 }
437 
438 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
439  /// div(x, x) -> 1
440  ///
441  /// Division by zero is undefined in the FIRRTL specification. This fold
442  /// exploits that fact to optimize self division to one. Note: this should
443  /// supersede any division with invalid or zero. Division of invalid by
444  /// invalid should be one.
445  if (getLhs() == getRhs()) {
446  auto width = getType().base().getWidthOrSentinel();
447  if (width == -1)
448  width = 2;
449  // Only fold if we have at least 1 bit of width to represent the `1` value.
450  if (width != 0)
451  return getIntAttr(getType(), APInt(width, 1));
452  }
453 
454  // div(0, x) -> 0
455  //
456  // This is legal because it aligns with the Scala FIRRTL Compiler
457  // interpretation of lowering invalid to constant zero before constant
458  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
459  // division this way and will emit "0 / x".
460  if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
461  return getIntZerosAttr(getType());
462 
463  /// div(x, 1) -> x : (uint, uint) -> uint
464  ///
465  /// UInt division by one returns the numerator. SInt division can't
466  /// be folded here because it increases the return type bitwidth by
467  /// one and requires sign extension (a new op).
468  if (auto rhsCst = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
469  if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
470  return getLhs();
471 
473  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
474  [=](const APSInt &a, const APSInt &b) -> APInt {
475  if (!!b)
476  return a / b;
477  return APInt(a.getBitWidth(), 0);
478  });
479 }
480 
481 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
482  // rem(x, x) -> 0
483  //
484  // Division by zero is undefined in the FIRRTL specification. This fold
485  // exploits that fact to optimize self division remainder to zero. Note:
486  // this should supersede any division with invalid or zero. Remainder of
487  // division of invalid by invalid should be zero.
488  if (getLhs() == getRhs())
489  return getIntZerosAttr(getType());
490 
491  // rem(0, x) -> 0
492  //
493  // This is legal because it aligns with the Scala FIRRTL Compiler
494  // interpretation of lowering invalid to constant zero before constant
495  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
496  // division this way and will emit "0 % x".
497  if (isConstantZero(adaptor.getLhs()))
498  return getIntZerosAttr(getType());
499 
501  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
502  [=](const APSInt &a, const APSInt &b) -> APInt {
503  if (!!b)
504  return a % b;
505  return APInt(a.getBitWidth(), 0);
506  });
507 }
508 
509 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
511  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
512  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
513 }
514 
515 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
517  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
518  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
519 }
520 
521 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
523  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524  [=](const APSInt &a, const APSInt &b) -> APInt {
525  return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
526  : a.ashr(b);
527  });
528 }
529 
530 // TODO: Move to DRR.
531 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
532  if (auto rhsCst = getConstant(adaptor.getRhs())) {
533  /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
534  if (rhsCst->isZero())
535  return getIntZerosAttr(getType());
536 
537  /// and(x, -1) -> x
538  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539  getRhs().getType() == getType())
540  return getLhs();
541  }
542 
543  if (auto lhsCst = getConstant(adaptor.getLhs())) {
544  /// and(0, x) -> 0, 0 is largest or is implicit zero extended
545  if (lhsCst->isZero())
546  return getIntZerosAttr(getType());
547 
548  /// and(-1, x) -> x
549  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550  getRhs().getType() == getType())
551  return getRhs();
552  }
553 
554  /// and(x, x) -> x
555  if (getLhs() == getRhs() && getRhs().getType() == getType())
556  return getRhs();
557 
559  *this, adaptor.getOperands(), BinOpKind::Normal,
560  [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
561 }
562 
563 void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
564  MLIRContext *context) {
565  results
566  .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
567  patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
568  patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
569 }
570 
571 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
572  if (auto rhsCst = getConstant(adaptor.getRhs())) {
573  /// or(x, 0) -> x
574  if (rhsCst->isZero() && getLhs().getType() == getType())
575  return getLhs();
576 
577  /// or(x, -1) -> -1
578  if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579  getLhs().getType() == getType())
580  return getRhs();
581  }
582 
583  if (auto lhsCst = getConstant(adaptor.getLhs())) {
584  /// or(0, x) -> x
585  if (lhsCst->isZero() && getRhs().getType() == getType())
586  return getRhs();
587 
588  /// or(-1, x) -> -1
589  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590  getRhs().getType() == getType())
591  return getLhs();
592  }
593 
594  /// or(x, x) -> x
595  if (getLhs() == getRhs() && getRhs().getType() == getType())
596  return getRhs();
597 
599  *this, adaptor.getOperands(), BinOpKind::Normal,
600  [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
601 }
602 
603 void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
604  MLIRContext *context) {
605  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
606  patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad,
607  patterns::OrOrr>(context);
608 }
609 
610 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
611  /// xor(x, 0) -> x
612  if (auto rhsCst = getConstant(adaptor.getRhs()))
613  if (rhsCst->isZero() &&
614  firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
615  return getLhs();
616 
617  /// xor(x, 0) -> x
618  if (auto lhsCst = getConstant(adaptor.getLhs()))
619  if (lhsCst->isZero() &&
620  firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
621  return getRhs();
622 
623  /// xor(x, x) -> 0
624  if (getLhs() == getRhs())
625  return getIntAttr(
626  getType(),
627  APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
628 
630  *this, adaptor.getOperands(), BinOpKind::Normal,
631  [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
632 }
633 
634 void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
635  MLIRContext *context) {
636  results.insert<patterns::extendXor, patterns::moveConstXor,
637  patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
638  context);
639 }
640 
641 void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
642  MLIRContext *context) {
643  results.insert<patterns::LEQWithConstLHS>(context);
644 }
645 
646 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647  bool isUnsigned = getLhs().getType().base().isUnsigned();
648 
649  // leq(x, x) -> 1
650  if (getLhs() == getRhs())
651  return getIntAttr(getType(), APInt(1, 1));
652 
653  // Comparison against constant outside type bounds.
654  if (auto width = getLhs().getType().base().getWidth()) {
655  if (auto rhsCst = getConstant(adaptor.getRhs())) {
656  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657  commonWidth = std::max(commonWidth, 1);
658 
659  // leq(x, const) -> 0 where const < minValue of the unsigned type of x
660  // This can never occur since const is unsigned and cannot be less than 0.
661 
662  // leq(x, const) -> 0 where const < minValue of the signed type of x
663  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
664  .slt(getMinSignedValue(*width).sext(commonWidth)))
665  return getIntAttr(getType(), APInt(1, 0));
666 
667  // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
668  if (isUnsigned && rhsCst->zext(commonWidth)
669  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
670  return getIntAttr(getType(), APInt(1, 1));
671 
672  // leq(x, const) -> 1 where const >= maxValue of the signed type of x
673  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
674  .sge(getMaxSignedValue(*width).sext(commonWidth)))
675  return getIntAttr(getType(), APInt(1, 1));
676  }
677  }
678 
680  *this, adaptor.getOperands(), BinOpKind::Compare,
681  [=](const APSInt &a, const APSInt &b) -> APInt {
682  return APInt(1, a <= b);
683  });
684 }
685 
686 void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
687  MLIRContext *context) {
688  results.insert<patterns::LTWithConstLHS>(context);
689 }
690 
691 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692  IntType lhsType = getLhs().getType();
693  bool isUnsigned = lhsType.isUnsigned();
694 
695  // lt(x, x) -> 0
696  if (getLhs() == getRhs())
697  return getIntAttr(getType(), APInt(1, 0));
698 
699  // lt(x, 0) -> 0 when x is unsigned
700  if (auto rhsCst = getConstant(adaptor.getRhs())) {
701  if (rhsCst->isZero() && lhsType.isUnsigned())
702  return getIntAttr(getType(), APInt(1, 0));
703  }
704 
705  // Comparison against constant outside type bounds.
706  if (auto width = lhsType.getWidth()) {
707  if (auto rhsCst = getConstant(adaptor.getRhs())) {
708  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709  commonWidth = std::max(commonWidth, 1);
710 
711  // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
712  // Handled explicitly above.
713 
714  // lt(x, const) -> 0 where const <= minValue of the signed type of x
715  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
716  .sle(getMinSignedValue(*width).sext(commonWidth)))
717  return getIntAttr(getType(), APInt(1, 0));
718 
719  // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
720  if (isUnsigned && rhsCst->zext(commonWidth)
721  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
722  return getIntAttr(getType(), APInt(1, 1));
723 
724  // lt(x, const) -> 1 where const > maxValue of the signed type of x
725  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
726  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
727  return getIntAttr(getType(), APInt(1, 1));
728  }
729  }
730 
732  *this, adaptor.getOperands(), BinOpKind::Compare,
733  [=](const APSInt &a, const APSInt &b) -> APInt {
734  return APInt(1, a < b);
735  });
736 }
737 
738 void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
739  MLIRContext *context) {
740  results.insert<patterns::GEQWithConstLHS>(context);
741 }
742 
743 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744  IntType lhsType = getLhs().getType();
745  bool isUnsigned = lhsType.isUnsigned();
746 
747  // geq(x, x) -> 1
748  if (getLhs() == getRhs())
749  return getIntAttr(getType(), APInt(1, 1));
750 
751  // geq(x, 0) -> 1 when x is unsigned
752  if (auto rhsCst = getConstant(adaptor.getRhs())) {
753  if (rhsCst->isZero() && isUnsigned)
754  return getIntAttr(getType(), APInt(1, 1));
755  }
756 
757  // Comparison against constant outside type bounds.
758  if (auto width = lhsType.getWidth()) {
759  if (auto rhsCst = getConstant(adaptor.getRhs())) {
760  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761  commonWidth = std::max(commonWidth, 1);
762 
763  // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
764  if (isUnsigned && rhsCst->zext(commonWidth)
765  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
766  return getIntAttr(getType(), APInt(1, 0));
767 
768  // geq(x, const) -> 0 where const > maxValue of the signed type of x
769  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
770  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
771  return getIntAttr(getType(), APInt(1, 0));
772 
773  // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
774  // Handled explicitly above.
775 
776  // geq(x, const) -> 1 where const <= minValue of the signed type of x
777  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
778  .sle(getMinSignedValue(*width).sext(commonWidth)))
779  return getIntAttr(getType(), APInt(1, 1));
780  }
781  }
782 
784  *this, adaptor.getOperands(), BinOpKind::Compare,
785  [=](const APSInt &a, const APSInt &b) -> APInt {
786  return APInt(1, a >= b);
787  });
788 }
789 
790 void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
791  MLIRContext *context) {
792  results.insert<patterns::GTWithConstLHS>(context);
793 }
794 
795 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796  IntType lhsType = getLhs().getType();
797  bool isUnsigned = lhsType.isUnsigned();
798 
799  // gt(x, x) -> 0
800  if (getLhs() == getRhs())
801  return getIntAttr(getType(), APInt(1, 0));
802 
803  // Comparison against constant outside type bounds.
804  if (auto width = lhsType.getWidth()) {
805  if (auto rhsCst = getConstant(adaptor.getRhs())) {
806  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807  commonWidth = std::max(commonWidth, 1);
808 
809  // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
810  if (isUnsigned && rhsCst->zext(commonWidth)
811  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
812  return getIntAttr(getType(), APInt(1, 0));
813 
814  // gt(x, const) -> 0 where const >= maxValue of the signed type of x
815  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
816  .sge(getMaxSignedValue(*width).sext(commonWidth)))
817  return getIntAttr(getType(), APInt(1, 0));
818 
819  // gt(x, const) -> 1 where const < minValue of the unsigned type of x
820  // This can never occur since const is unsigned and cannot be less than 0.
821 
822  // gt(x, const) -> 1 where const < minValue of the signed type of x
823  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
824  .slt(getMinSignedValue(*width).sext(commonWidth)))
825  return getIntAttr(getType(), APInt(1, 1));
826  }
827  }
828 
830  *this, adaptor.getOperands(), BinOpKind::Compare,
831  [=](const APSInt &a, const APSInt &b) -> APInt {
832  return APInt(1, a > b);
833  });
834 }
835 
836 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
837  // eq(x, x) -> 1
838  if (getLhs() == getRhs())
839  return getIntAttr(getType(), APInt(1, 1));
840 
841  if (auto rhsCst = getConstant(adaptor.getRhs())) {
842  /// eq(x, 1) -> x when x is 1 bit.
843  /// TODO: Support SInt<1> on the LHS etc.
844  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845  getRhs().getType() == getType())
846  return getLhs();
847  }
848 
850  *this, adaptor.getOperands(), BinOpKind::Compare,
851  [=](const APSInt &a, const APSInt &b) -> APInt {
852  return APInt(1, a == b);
853  });
854 }
855 
856 LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
857  return canonicalizePrimOp(
858  op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
859  if (auto rhsCst = getConstant(operands[1])) {
860  auto width = op.getLhs().getType().getBitWidthOrSentinel();
861 
862  // eq(x, 0) -> not(x) when x is 1 bit.
863  if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864  op.getRhs().getType() == op.getType()) {
865  return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
866  .getResult();
867  }
868 
869  // eq(x, 0) -> not(orr(x)) when x is >1 bit
870  if (rhsCst->isZero() && width > 1) {
871  auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872  return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
873  }
874 
875  // eq(x, ~0) -> andr(x) when x is >1 bit
876  if (rhsCst->isAllOnes() && width > 1 &&
877  op.getLhs().getType() == op.getRhs().getType()) {
878  return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
879  .getResult();
880  }
881  }
882  return {};
883  });
884 }
885 
886 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
887  // neq(x, x) -> 0
888  if (getLhs() == getRhs())
889  return getIntAttr(getType(), APInt(1, 0));
890 
891  if (auto rhsCst = getConstant(adaptor.getRhs())) {
892  /// neq(x, 0) -> x when x is 1 bit.
893  /// TODO: Support SInt<1> on the LHS etc.
894  if (rhsCst->isZero() && getLhs().getType() == getType() &&
895  getRhs().getType() == getType())
896  return getLhs();
897  }
898 
900  *this, adaptor.getOperands(), BinOpKind::Compare,
901  [=](const APSInt &a, const APSInt &b) -> APInt {
902  return APInt(1, a != b);
903  });
904 }
905 
906 LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
907  return canonicalizePrimOp(
908  op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
909  if (auto rhsCst = getConstant(operands[1])) {
910  auto width = op.getLhs().getType().getBitWidthOrSentinel();
911 
912  // neq(x, 1) -> not(x) when x is 1 bit
913  if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
914  op.getRhs().getType() == op.getType()) {
915  return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
916  .getResult();
917  }
918 
919  // neq(x, 0) -> orr(x) when x is >1 bit
920  if (rhsCst->isZero() && width > 1) {
921  return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
922  .getResult();
923  }
924 
925  // neq(x, ~0) -> not(andr(x))) when x is >1 bit
926  if (rhsCst->isAllOnes() && width > 1 &&
927  op.getLhs().getType() == op.getRhs().getType()) {
928  auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
929  return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
930  }
931  }
932 
933  return {};
934  });
935 }
936 
937 OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
938  // TODO: implement constant folding, etc.
939  // Tracked in https://github.com/llvm/circt/issues/6696.
940  return {};
941 }
942 
943 OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
944  // TODO: implement constant folding, etc.
945  // Tracked in https://github.com/llvm/circt/issues/6724.
946  return {};
947 }
948 
949 OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
950  // TODO: implement constant folding, etc.
951  // Tracked in https://github.com/llvm/circt/issues/6725.
952  return {};
953 }
954 
955 OpFoldResult IntegerShlOp::fold(FoldAdaptor adaptor) {
956  if (auto rhsCst = getConstant(adaptor.getRhs())) {
957  // Constant folding
958  if (auto lhsCst = getConstant(adaptor.getLhs()))
959 
960  return IntegerAttr::get(
961  IntegerType::get(getContext(), lhsCst->getBitWidth()),
962  lhsCst->shl(*rhsCst));
963 
964  // integer.shl(x, 0) -> x
965  if (rhsCst->isZero())
966  return getLhs();
967  }
968 
969  return {};
970 }
971 
972 //===----------------------------------------------------------------------===//
973 // Unary Operators
974 //===----------------------------------------------------------------------===//
975 
976 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
977  auto base = getInput().getType();
978  auto w = getBitWidth(base);
979  if (w)
980  return getIntAttr(getType(), APInt(32, *w));
981  return {};
982 }
983 
984 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
985  // No constant can be 'x' by definition.
986  if (auto cst = getConstant(adaptor.getArg()))
987  return getIntAttr(getType(), APInt(1, 0));
988  return {};
989 }
990 
991 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
992  // No effect.
993  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
994  return getInput();
995 
996  // Be careful to only fold the cast into the constant if the size is known.
997  // Otherwise width inference may produce differently-sized constants if the
998  // sign changes.
999  if (getType().base().hasWidth())
1000  if (auto cst = getConstant(adaptor.getInput()))
1001  return getIntAttr(getType(), *cst);
1002 
1003  return {};
1004 }
1005 
1006 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1007  MLIRContext *context) {
1008  results.insert<patterns::StoUtoS>(context);
1009 }
1010 
1011 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1012  // No effect.
1013  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1014  return getInput();
1015 
1016  // Be careful to only fold the cast into the constant if the size is known.
1017  // Otherwise width inference may produce differently-sized constants if the
1018  // sign changes.
1019  if (getType().base().hasWidth())
1020  if (auto cst = getConstant(adaptor.getInput()))
1021  return getIntAttr(getType(), *cst);
1022 
1023  return {};
1024 }
1025 
1026 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1027  MLIRContext *context) {
1028  results.insert<patterns::UtoStoU>(context);
1029 }
1030 
1031 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1032  // No effect.
1033  if (getInput().getType() == getType())
1034  return getInput();
1035 
1036  // Constant fold.
1037  if (auto cst = getConstant(adaptor.getInput()))
1038  return BoolAttr::get(getContext(), cst->getBoolValue());
1039 
1040  return {};
1041 }
1042 
1043 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1044  // No effect.
1045  if (getInput().getType() == getType())
1046  return getInput();
1047 
1048  // Constant fold.
1049  if (auto cst = getConstant(adaptor.getInput()))
1050  return BoolAttr::get(getContext(), cst->getBoolValue());
1051 
1052  return {};
1053 }
1054 
1055 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1056  if (!hasKnownWidthIntTypes(*this))
1057  return {};
1058 
1059  // Signed to signed is a noop, unsigned operands prepend a zero bit.
1060  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1061  getType().base().getWidthOrSentinel()))
1062  return getIntAttr(getType(), *cst);
1063 
1064  return {};
1065 }
1066 
1067 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1068  MLIRContext *context) {
1069  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1070 }
1071 
1072 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1073  if (!hasKnownWidthIntTypes(*this))
1074  return {};
1075 
1076  // FIRRTL negate always adds a bit.
1077  // -x ---> 0-sext(x) or 0-zext(x)
1078  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1079  getType().base().getWidthOrSentinel()))
1080  return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1081 
1082  return {};
1083 }
1084 
1085 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1086  if (!hasKnownWidthIntTypes(*this))
1087  return {};
1088 
1089  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1090  getType().base().getWidthOrSentinel()))
1091  return getIntAttr(getType(), ~*cst);
1092 
1093  return {};
1094 }
1095 
1096 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1097  MLIRContext *context) {
1098  results.insert<patterns::NotNot, patterns::NotEq, patterns::NotNeq,
1099  patterns::NotLeq, patterns::NotLt, patterns::NotGeq,
1100  patterns::NotGt>(context);
1101 }
1102 
1103 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1104  if (!hasKnownWidthIntTypes(*this))
1105  return {};
1106 
1107  if (getInput().getType().getBitWidthOrSentinel() == 0)
1108  return getIntAttr(getType(), APInt(1, 1));
1109 
1110  // x == -1
1111  if (auto cst = getConstant(adaptor.getInput()))
1112  return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1113 
1114  // one bit is identity. Only applies to UInt since we can't make a cast
1115  // here.
1116  if (isUInt1(getInput().getType()))
1117  return getInput();
1118 
1119  return {};
1120 }
1121 
1122 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1123  MLIRContext *context) {
1124  results
1125  .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1126  patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1127  patterns::AndRCatZeroL, patterns::AndRCatZeroR,
1128  patterns::AndRCatAndR_left, patterns::AndRCatAndR_right>(context);
1129 }
1130 
1131 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1132  if (!hasKnownWidthIntTypes(*this))
1133  return {};
1134 
1135  if (getInput().getType().getBitWidthOrSentinel() == 0)
1136  return getIntAttr(getType(), APInt(1, 0));
1137 
1138  // x != 0
1139  if (auto cst = getConstant(adaptor.getInput()))
1140  return getIntAttr(getType(), APInt(1, !cst->isZero()));
1141 
1142  // one bit is identity. Only applies to UInt since we can't make a cast
1143  // here.
1144  if (isUInt1(getInput().getType()))
1145  return getInput();
1146 
1147  return {};
1148 }
1149 
1150 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1151  MLIRContext *context) {
1152  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1153  patterns::OrRCatZeroH, patterns::OrRCatZeroL,
1154  patterns::OrRCatOrR_left, patterns::OrRCatOrR_right>(context);
1155 }
1156 
1157 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1158  if (!hasKnownWidthIntTypes(*this))
1159  return {};
1160 
1161  if (getInput().getType().getBitWidthOrSentinel() == 0)
1162  return getIntAttr(getType(), APInt(1, 0));
1163 
1164  // popcount(x) & 1
1165  if (auto cst = getConstant(adaptor.getInput()))
1166  return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1167 
1168  // one bit is identity. Only applies to UInt since we can't make a cast here.
1169  if (isUInt1(getInput().getType()))
1170  return getInput();
1171 
1172  return {};
1173 }
1174 
1175 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1176  MLIRContext *context) {
1177  results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1178  patterns::XorRCatZeroH, patterns::XorRCatZeroL,
1179  patterns::XorRCatXorR_left, patterns::XorRCatXorR_right>(
1180  context);
1181 }
1182 
1183 //===----------------------------------------------------------------------===//
1184 // Other Operators
1185 //===----------------------------------------------------------------------===//
1186 
1187 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1188  // cat(x, 0-width) -> x
1189  // cat(0-width, x) -> x
1190  // Limit to unsigned (result type), as cannot insert cast here.
1191  IntType lhsType = getLhs().getType();
1192  IntType rhsType = getRhs().getType();
1193  if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1194  return getRhs();
1195  if (rhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1196  return getLhs();
1197 
1198  if (!hasKnownWidthIntTypes(*this))
1199  return {};
1200 
1201  // Constant fold cat.
1202  if (auto lhs = getConstant(adaptor.getLhs()))
1203  if (auto rhs = getConstant(adaptor.getRhs()))
1204  return getIntAttr(getType(), lhs->concat(*rhs));
1205 
1206  return {};
1207 }
1208 
1209 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1210  MLIRContext *context) {
1211  results.insert<patterns::DShlOfConstant>(context);
1212 }
1213 
1214 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1215  MLIRContext *context) {
1216  results.insert<patterns::DShrOfConstant>(context);
1217 }
1218 
1219 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1220  MLIRContext *context) {
1221  results.insert<patterns::CatBitsBits, patterns::CatDoubleConst,
1222  patterns::CatCast>(context);
1223 }
1224 
1225 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1226  auto op = (*this);
1227  // BitCast is redundant if input and result types are same.
1228  if (op.getType() == op.getInput().getType())
1229  return op.getInput();
1230 
1231  // Two consecutive BitCasts are redundant if first bitcast type is same as the
1232  // final result type.
1233  if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1234  if (op.getType() == in.getInput().getType())
1235  return in.getInput();
1236 
1237  return {};
1238 }
1239 
1240 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1241  IntType inputType = getInput().getType();
1242  IntType resultType = getType();
1243  // If we are extracting the entire input, then return it.
1244  if (inputType == getType() && resultType.hasWidth())
1245  return getInput();
1246 
1247  // Constant fold.
1248  if (hasKnownWidthIntTypes(*this))
1249  if (auto cst = getConstant(adaptor.getInput()))
1250  return getIntAttr(resultType,
1251  cst->extractBits(getHi() - getLo() + 1, getLo()));
1252 
1253  return {};
1254 }
1255 
1256 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1257  MLIRContext *context) {
1258  results
1259  .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1260  patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1261 }
1262 
1263 /// Replace the specified operation with a 'bits' op from the specified hi/lo
1264 /// bits. Insert a cast to handle the case where the original operation
1265 /// returned a signed integer.
1266 static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1267  unsigned loBit, PatternRewriter &rewriter) {
1268  auto resType = type_cast<IntType>(op->getResult(0).getType());
1269  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1270  value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1271 
1272  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1273  value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1274  } else if (resType.isUnsigned() &&
1275  !type_cast<IntType>(value.getType()).isUnsigned()) {
1276  value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1277  }
1278  rewriter.replaceOp(op, value);
1279 }
1280 
// Shared folding logic for MuxPrimOp and the mux-like intrinsics. Only folds
// that keep the result type unchanged are legal here; anything needing a pad
// or cast must be a canonicalization pattern instead.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  // The type check guards against operands that would still need padding.
  if (op.getHigh() == op.getLow() && op.getHigh().getType() == op.getType())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1328 
1329 OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
1330  return foldMux(*this, adaptor);
1331 }
1332 
1333 OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
1334  return foldMux(*this, adaptor);
1335 }
1336 
1337 OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1338 
1339 namespace {
1340 
1341 // If the mux has a known output width, pad the operands up to this width.
1342 // Most folds on mux require that folded operands are of the same width as
1343 // the mux itself.
class MuxPad : public mlir::RewritePattern {
public:
  MuxPad(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    // Bail out when the mux result width is unknown (negative sentinel).
    if (width < 0)
      return failure();

    // Pad `input` up to the mux result width; returns the input unchanged
    // when its width is unknown or already matches.
    auto pad = [&](Value input) -> Value {
      auto inputWidth =
          type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
      if (inputWidth < 0 || width == inputWidth)
        return input;
      return rewriter
          .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
          .getResult();
    };

    auto newHigh = pad(mux.getHigh());
    auto newLow = pad(mux.getLow());
    // Nothing to do if both operands already had the right width.
    if (newHigh == mux.getHigh() && newLow == mux.getLow())
      return failure();

    // Rebuild the mux with padded operands, preserving attributes and any
    // name hint.
    replaceOpWithNewOpAndCopyName<MuxPrimOp>(
        rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
        mux->getAttrs());
    return success();
  }
};
1378 
1379 // Find muxes which have conditions dominated by other muxes with the same
1380 // condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when walking nested mux trees.
  static const int depthLimit = 5;

  // Replace the high/low operands of `mux`: either in place (returning a null
  // Value) or, when in-place update is not allowed, by cloning a new mux after
  // it and returning the clone's result.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return rewriter
        .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
                           ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Under the assumption `cond` is true, this mux selects its high side.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // In-place update is only legal while every mux on the path has a single
    // user; otherwise other users would observe the change.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Under the assumption `cond` is false, this mux selects its low side.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Simplify the high operand assuming the select is true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Simplify the low operand assuming the select is false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1471 } // namespace
1472 
1473 void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1474  MLIRContext *context) {
1475  results
1476  .add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
1477  patterns::MuxEQOperandsSwapped, patterns::MuxNEQ, patterns::MuxNot,
1478  patterns::MuxSameTrue, patterns::MuxSameFalse,
1479  patterns::NarrowMuxLHS, patterns::NarrowMuxRHS, patterns::MuxPadSel>(
1480  context);
1481 }
1482 
1483 void Mux2CellIntrinsicOp::getCanonicalizationPatterns(
1484  RewritePatternSet &results, MLIRContext *context) {
1485  results.add<patterns::Mux2PadSel>(context);
1486 }
1487 
1488 void Mux4CellIntrinsicOp::getCanonicalizationPatterns(
1489  RewritePatternSet &results, MLIRContext *context) {
1490  results.add<patterns::Mux4PadSel>(context);
1491 }
1492 
1493 OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1494  auto input = this->getInput();
1495 
1496  // pad(x) -> x if the width doesn't change.
1497  if (input.getType() == getType())
1498  return input;
1499 
1500  // Need to know the input width.
1501  auto inputType = input.getType().base();
1502  int32_t width = inputType.getWidthOrSentinel();
1503  if (width == -1)
1504  return {};
1505 
1506  // Constant fold.
1507  if (auto cst = getConstant(adaptor.getInput())) {
1508  auto destWidth = getType().base().getWidthOrSentinel();
1509  if (destWidth == -1)
1510  return {};
1511 
1512  if (inputType.isSigned() && cst->getBitWidth())
1513  return getIntAttr(getType(), cst->sext(destWidth));
1514  return getIntAttr(getType(), cst->zext(destWidth));
1515  }
1516 
1517  return {};
1518 }
1519 
1520 OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1521  auto input = this->getInput();
1522  IntType inputType = input.getType();
1523  int shiftAmount = getAmount();
1524 
1525  // shl(x, 0) -> x
1526  if (shiftAmount == 0)
1527  return input;
1528 
1529  // Constant fold.
1530  if (auto cst = getConstant(adaptor.getInput())) {
1531  auto inputWidth = inputType.getWidthOrSentinel();
1532  if (inputWidth != -1) {
1533  auto resultWidth = inputWidth + shiftAmount;
1534  shiftAmount = std::min(shiftAmount, resultWidth);
1535  return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1536  }
1537  }
1538  return {};
1539 }
1540 
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Unknown width (sentinel -1): cannot fold further.
  if (inputWidth == -1)
    return {};
  // A zero-width input shifts to an all-zero result.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    if (inputType.isSigned())
      // Arithmetic shift; clamp so the sign bit is preserved on over-shifts.
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      // Logical shift; clamp to the full input width.
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // The result is always at least one bit wide.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1575 
1576 LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1577  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1578  if (inputWidth <= 0)
1579  return failure();
1580 
1581  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1582  unsigned shiftAmount = op.getAmount();
1583  if (int(shiftAmount) >= inputWidth) {
1584  // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
1585  if (op.getType().base().isUnsigned())
1586  return failure();
1587 
1588  // Shifting a signed value by the full width is actually taking the
1589  // sign bit. If the shift amount is greater than the input width, it
1590  // is equivalent to shifting by the input width.
1591  shiftAmount = inputWidth - 1;
1592  }
1593 
1594  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1595  return success();
1596 }
1597 
1598 LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1599  PatternRewriter &rewriter) {
1600  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1601  if (inputWidth <= 0)
1602  return failure();
1603 
1604  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1605  unsigned keepAmount = op.getAmount();
1606  if (keepAmount)
1607  replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1608  rewriter);
1609  return success();
1610 }
1611 
1612 OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1613  if (hasKnownWidthIntTypes(*this))
1614  if (auto cst = getConstant(adaptor.getInput())) {
1615  int shiftAmount =
1616  getInput().getType().base().getWidthOrSentinel() - getAmount();
1617  return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1618  }
1619 
1620  return {};
1621 }
1622 
1623 OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1624  if (hasKnownWidthIntTypes(*this))
1625  if (auto cst = getConstant(adaptor.getInput()))
1626  return getIntAttr(getType(),
1627  cst->trunc(getType().base().getWidthOrSentinel()));
1628  return {};
1629 }
1630 
1631 LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1632  PatternRewriter &rewriter) {
1633  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1634  if (inputWidth <= 0)
1635  return failure();
1636 
1637  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1638  unsigned dropAmount = op.getAmount();
1639  if (dropAmount != unsigned(inputWidth))
1640  replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1641  rewriter);
1642  return success();
1643 }
1644 
1645 void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1646  MLIRContext *context) {
1647  results.add<patterns::SubaccessOfConstant>(context);
1648 }
1649 
1650 OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1651  // If there is only one input, just return it.
1652  if (adaptor.getInputs().size() == 1)
1653  return getOperand(1);
1654 
1655  if (auto constIndex = getConstant(adaptor.getIndex())) {
1656  auto index = constIndex->getZExtValue();
1657  if (index < getInputs().size())
1658  return getInputs()[getInputs().size() - 1 - index];
1659  }
1660 
1661  return {};
1662 }
1663 
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the index width is narrower than the size of inputs, drop front
  // elements.
  auto indexWidth = op.getIndex().getType().getBitWidthOrSentinel();
  uint64_t inputSize = op.getInputs().size();
  // Inputs beyond 2^indexWidth can never be selected; since inputs are stored
  // highest-index first, the unreachable ones are at the front.
  if (indexWidth >= 0 && indexWidth < 64 && 1ull << indexWidth < inputSize) {
    rewriter.modifyOpInPlace(op, [&]() {
      op.getInputsMutable().erase(0, inputSize - (1ull << indexWidth));
    });
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every input must be a subindex of the same vector, with indices in
    // descending order matching the operand positions.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
1715 
1716 //===----------------------------------------------------------------------===//
1717 // Declarations
1718 //===----------------------------------------------------------------------===//
1719 
1720 /// Scan all the uses of the specified value, checking to see if there is
1721 /// exactly one connect that has the value as its destination. This returns the
1722 /// operation if found and if all the other users are "reads" from the value.
1723 /// Returns null if there are no connects, or multiple connects to the value, or
1724 /// if the value is involved in an `AttachOp`, or if the connect isn't matching.
1725 ///
1726 /// Note that this will simply return the connect, which is located *anywhere*
1727 /// after the definition of the value. Users of this function are likely
1728 /// interested in the source side of the returned connect, the definition of
1729 /// which does likely not dominate the original value.
1730 MatchingConnectOp firrtl::getSingleConnectUserOf(Value value) {
1731  MatchingConnectOp connect;
1732  for (Operation *user : value.getUsers()) {
1733  // If we see an attach or aggregate sublements, just conservatively fail.
1734  if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1735  return {};
1736 
1737  if (auto aConnect = dyn_cast<FConnectLike>(user))
1738  if (aConnect.getDest() == value) {
1739  auto matchingConnect = dyn_cast<MatchingConnectOp>(*aConnect);
1740  // If this is not a matching connect, a second matching connect or in a
1741  // different block, fail.
1742  if (!matchingConnect || (connect && connect != matchingConnect) ||
1743  matchingConnect->getBlock() != value.getParentBlock())
1744  return {};
1745  else
1746  connect = matchingConnect;
1747  }
1748  }
1749  return connect;
1750 }
1751 
1752 // Forward simple values through wire's and reg's.
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Do not touch declarations that must be preserved: don't-touch, live
  // annotations, meaningful names, or a forceable ref result.
  if (hasDontTouch(connectedDecl) ||
      !AnnotationSet(connectedDecl).canBeDeleted() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
1811 
1812 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1813  MLIRContext *context) {
1814  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1815  context);
1816 }
1817 
1818 LogicalResult MatchingConnectOp::canonicalize(MatchingConnectOp op,
1819  PatternRewriter &rewriter) {
1820  // TODO: Canonicalize towards explicit extensions and flips here.
1821 
1822  // If there is a simple value connected to a foldable decl like a wire or reg,
1823  // see if we can eliminate the decl.
1824  if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
1825  return success();
1826  return failure();
1827 }
1828 
1829 //===----------------------------------------------------------------------===//
1830 // Statements
1831 //===----------------------------------------------------------------------===//
1832 
/// If the specified value has an AttachOp user strictly dominating the given
/// "dominatedAttach" operation, then return it.
1835 static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
1836  for (auto *user : value.getUsers()) {
1837  auto attach = dyn_cast<AttachOp>(user);
1838  if (!attach || attach == dominatedAttach)
1839  continue;
1840  if (attach->isBeforeInBlock(dominatedAttach))
1841  return attach;
1842  }
1843  return {};
1844 }
1845 
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //   attach x, y
    //   ...
    //   attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      // Merge both operand lists, deduplicating the shared operand, then
      // emit the combined attach and erase the two originals.
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      rewriter.create<AttachOp>(op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the dead wire, then drop both ops.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire back.
            newOperands.push_back(newOperand);

        rewriter.create<AttachOp>(op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
1891 
1892 /// Replaces the given op with the contents of the given single-block region.
1893 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1894  Region &region) {
1895  assert(llvm::hasSingleElement(region) && "expected single-region block");
1896  rewriter.inlineBlockBefore(&region.front(), op, {});
1897 }
1898 
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // A constant condition selects exactly one region: inline the "then" region
  // for an all-ones condition, otherwise the "else" region (when present and
  // non-empty), then erase the when itself.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
1931 
1932 namespace {
// Remove private nodes. If they have interesting names, move the name to
// the source expression.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  // Replace the node with its input. Nodes with non-droppable names, inner
  // symbols, non-deletable annotations, or a forceable ref are kept.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).canBeDeleted() || node.isForceable())
      return failure();
    auto *newOp = node.getInput().getDefiningOp();
    // Best effort, do not rename InstanceOp
    if (newOp && !isa<InstanceOp>(newOp))
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
1953 
1954 // Bypass nodes.
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  // Forward the node's input to all users of its result. The node op itself
  // is left in place (now use-free) — presumably for later cleanup to remove;
  // verify against the registered pattern set.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    // Skip nodes that carry a symbol or live annotations, are forceable, or
    // are already unused (the last also guarantees pattern termination).
    if (node.getInnerSym() || !AnnotationSet(node).canBeDeleted() ||
        node.use_empty() || node.isForceable())
      return failure();
    rewriter.replaceAllUsesWith(node.getResult(), node.getInput());
    return success();
  }
};
1968 
1969 } // namespace
1970 
1971 template <typename OpTy>
1972 static LogicalResult demoteForceableIfUnused(OpTy op,
1973  PatternRewriter &rewriter) {
1974  if (!op.isForceable() || !op.getDataRef().use_empty())
1975  return failure();
1976 
1977  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
1978  return success();
1979 }
1980 
1981 // Interesting names and symbols and don't touch force nodes to stick around.
1982 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1983  SmallVectorImpl<OpFoldResult> &results) {
1984  if (!hasDroppableName())
1985  return failure();
1986  if (hasDontTouch(getResult())) // handles inner symbols
1987  return failure();
1988  if (getAnnotationsAttr() &&
1989  !AnnotationSet(getAnnotationsAttr()).canBeDeleted())
1990  return failure();
1991  if (isForceable())
1992  return failure();
1993  if (!adaptor.getInput())
1994  return failure();
1995 
1996  results.push_back(adaptor.getInput());
1997  return success();
1998 }
1999 
2000 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
2001  MLIRContext *context) {
2002  results.insert<FoldNodeName>(context);
2003  results.add(demoteForceableIfUnused<NodeOp>);
2004 }
2005 
2006 namespace {
2007 // For a lhs, find all the writers of fields of the aggregate type. If there
2008 // is one writer for each field, merge the writes
2009 struct AggOneShot : public mlir::RewritePattern {
2010  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
2011  : RewritePattern(name, 0, context) {}
2012 
2013  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
2014  auto lhsTy = lhs->getResult(0).getType();
2015  if (!type_isa<BundleType, FVectorType>(lhsTy))
2016  return {};
2017 
2018  DenseMap<uint32_t, Value> fields;
2019  for (Operation *user : lhs->getResult(0).getUsers()) {
2020  if (user->getParentOp() != lhs->getParentOp())
2021  return {};
2022  if (auto aConnect = dyn_cast<MatchingConnectOp>(user)) {
2023  if (aConnect.getDest() == lhs->getResult(0))
2024  return {};
2025  } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2026  for (Operation *subuser : subField.getResult().getUsers()) {
2027  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2028  if (aConnect.getDest() == subField) {
2029  if (subuser->getParentOp() != lhs->getParentOp())
2030  return {};
2031  if (fields.count(subField.getFieldIndex())) // duplicate write
2032  return {};
2033  fields[subField.getFieldIndex()] = aConnect.getSrc();
2034  }
2035  continue;
2036  }
2037  return {};
2038  }
2039  } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2040  for (Operation *subuser : subIndex.getResult().getUsers()) {
2041  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser)) {
2042  if (aConnect.getDest() == subIndex) {
2043  if (subuser->getParentOp() != lhs->getParentOp())
2044  return {};
2045  if (fields.count(subIndex.getIndex())) // duplicate write
2046  return {};
2047  fields[subIndex.getIndex()] = aConnect.getSrc();
2048  }
2049  continue;
2050  }
2051  return {};
2052  }
2053  } else {
2054  return {};
2055  }
2056  }
2057 
2058  SmallVector<Value> values;
2059  uint32_t total = type_isa<BundleType>(lhsTy)
2060  ? type_cast<BundleType>(lhsTy).getNumElements()
2061  : type_cast<FVectorType>(lhsTy).getNumElements();
2062  for (uint32_t i = 0; i < total; ++i) {
2063  if (!fields.count(i))
2064  return {};
2065  values.push_back(fields[i]);
2066  }
2067  return values;
2068  }
2069 
2070  LogicalResult matchAndRewrite(Operation *op,
2071  PatternRewriter &rewriter) const override {
2072  auto values = getCompleteWrite(op);
2073  if (values.empty())
2074  return failure();
2075  rewriter.setInsertionPointToEnd(op->getBlock());
2076  auto dest = op->getResult(0);
2077  auto destType = dest.getType();
2078 
2079  // If not passive, cannot matchingconnect.
2080  if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2081  return failure();
2082 
2083  Value newVal = type_isa<BundleType>(destType)
2084  ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2085  destType, values)
2086  : rewriter.createOrFold<VectorCreateOp>(
2087  op->getLoc(), destType, values);
2088  rewriter.createOrFold<MatchingConnectOp>(op->getLoc(), dest, newVal);
2089  for (Operation *user : dest.getUsers()) {
2090  if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2091  for (Operation *subuser :
2092  llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2093  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2094  if (aConnect.getDest() == subIndex)
2095  rewriter.eraseOp(aConnect);
2096  } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2097  for (Operation *subuser :
2098  llvm::make_early_inc_range(subField.getResult().getUsers()))
2099  if (auto aConnect = dyn_cast<MatchingConnectOp>(subuser))
2100  if (aConnect.getDest() == subField)
2101  rewriter.eraseOp(aConnect);
2102  }
2103  }
2104  return success();
2105  }
2106 };
2107 
// Merge per-field writes to a wire into one whole-aggregate write.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
// Merge per-field writes rooted at a subindex into one aggregate write.
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
// Merge per-field writes rooted at a subfield into one aggregate write.
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2120 } // namespace
2121 
2122 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2123  MLIRContext *context) {
2124  results.insert<WireAggOneShot>(context);
2125  results.add(demoteForceableIfUnused<WireOp>);
2126 }
2127 
2128 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2129  MLIRContext *context) {
2130  results.insert<SubindexAggOneShot>(context);
2131 }
2132 
2133 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2134  auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2135  if (!attr)
2136  return {};
2137  return attr[getIndex()];
2138 }
2139 
2140 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2141  auto attr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2142  if (!attr)
2143  return {};
2144  auto index = getFieldIndex();
2145  return attr[index];
2146 }
2147 
2148 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2149  MLIRContext *context) {
2150  results.insert<SubfieldAggOneShot>(context);
2151 }
2152 
2153 static Attribute collectFields(MLIRContext *context,
2154  ArrayRef<Attribute> operands) {
2155  for (auto operand : operands)
2156  if (!operand)
2157  return {};
2158  return ArrayAttr::get(context, operands);
2159 }
2160 
2161 OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
2162  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
2163  // bundle<a:..., b:...>.
2164  if (getNumOperands() > 0)
2165  if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
2166  if (first.getFieldIndex() == 0 &&
2167  first.getInput().getType() == getType() &&
2168  llvm::all_of(
2169  llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
2170  auto subindex =
2171  elem.value().template getDefiningOp<SubfieldOp>();
2172  return subindex && subindex.getInput() == first.getInput() &&
2173  subindex.getFieldIndex() == elem.index();
2174  }))
2175  return first.getInput();
2176 
2177  return collectFields(getContext(), adaptor.getOperands());
2178 }
2179 
2180 OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
2181  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
2182  // vector<..., 2>.
2183  if (getNumOperands() > 0)
2184  if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
2185  if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
2186  llvm::all_of(
2187  llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
2188  auto subindex =
2189  elem.value().template getDefiningOp<SubindexOp>();
2190  return subindex && subindex.getInput() == first.getInput() &&
2191  subindex.getIndex() == elem.index();
2192  }))
2193  return first.getInput();
2194 
2195  return collectFields(getContext(), adaptor.getOperands());
2196 }
2197 
2198 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2199  if (getOperand().getType() == getType())
2200  return getOperand();
2201  return {};
2202 }
2203 
2204 namespace {
2205 // A register with constant reset and all connection to either itself or the
2206 // same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    // The reset value must be a constant, and the register must be deletable.
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).canBeDeleted() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The register must be fed from a mux whose other arm is the register
    // itself, i.e. mux(p, const, reg) or mux(p, reg, const).
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must be identical to the reset constant, so the
    // register can only ever hold that value.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2259 } // namespace
2260 
2261 static bool isDefinedByOneConstantOp(Value v) {
2262  if (auto c = v.getDefiningOp<ConstantOp>())
2263  return c.getValue().isOne();
2264  if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2265  return sc.getValue();
2266  return false;
2267 }
2268 
2269 static LogicalResult
2270 canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
2271  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
2272  return failure();
2273 
2274  // Ignore 'passthrough'.
2275  (void)dropWrite(rewriter, reg->getResult(0), {});
2276  replaceOpWithNewOpAndCopyName<NodeOp>(
2277  rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
2278  reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
2279  return success();
2280 }
2281 
2282 void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2283  MLIRContext *context) {
2284  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2285  results.add(canonicalizeRegResetWithOneReset);
2286  results.add(demoteForceableIfUnused<RegResetOp>);
2287 }
2288 
2289 // Returns the value connected to a port, if there is only one.
2290 static Value getPortFieldValue(Value port, StringRef name) {
2291  auto portTy = type_cast<BundleType>(port.getType());
2292  auto fieldIndex = portTy.getElementIndex(name);
2293  assert(fieldIndex && "missing field on memory port");
2294 
2295  Value value = {};
2296  for (auto *op : port.getUsers()) {
2297  auto portAccess = cast<SubfieldOp>(op);
2298  if (fieldIndex != portAccess.getFieldIndex())
2299  continue;
2300  auto conn = getSingleConnectUserOf(portAccess);
2301  if (!conn || value)
2302  return {};
2303  value = conn.getSrc();
2304  }
2305  return value;
2306 }
2307 
2308 // Returns true if the enable field of a port is set to false.
2309 static bool isPortDisabled(Value port) {
2310  auto value = getPortFieldValue(port, "en");
2311  if (!value)
2312  return false;
2313  auto portConst = value.getDefiningOp<ConstantOp>();
2314  if (!portConst)
2315  return false;
2316  return portConst.getValue().isZero();
2317 }
2318 
2319 // Returns true if the data output is unused.
2320 static bool isPortUnused(Value port, StringRef data) {
2321  auto portTy = type_cast<BundleType>(port.getType());
2322  auto fieldIndex = portTy.getElementIndex(data);
2323  assert(fieldIndex && "missing enable flag on memory port");
2324 
2325  for (auto *op : port.getUsers()) {
2326  auto portAccess = cast<SubfieldOp>(op);
2327  if (fieldIndex != portAccess.getFieldIndex())
2328  continue;
2329  if (!portAccess.use_empty())
2330  return false;
2331  }
2332 
2333  return true;
2334 }
2335 
// Replaces all uses of the named field of a port with the given value.
2337 static void replacePortField(PatternRewriter &rewriter, Value port,
2338  StringRef name, Value value) {
2339  auto portTy = type_cast<BundleType>(port.getType());
2340  auto fieldIndex = portTy.getElementIndex(name);
2341  assert(fieldIndex && "missing field on memory port");
2342 
2343  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
2344  auto portAccess = cast<SubfieldOp>(op);
2345  if (fieldIndex != portAccess.getFieldIndex())
2346  continue;
2347  rewriter.replaceAllUsesWith(portAccess, value);
2348  rewriter.eraseOp(portAccess);
2349  }
2350 }
2351 
// Remove all accesses to a port, substituting never-written dummy registers
// for any values that are still read from it.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = rewriter.create<SpecialConstantOp>(
          port.getLoc(), ClockType::get(rewriter.getContext()), false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire. If the port
  // is used in its entirety, replace it with a register. Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      auto ty = port.getType();
      auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
      rewriter.replaceAllUsesWith(port, reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    // Accesses left with no users can simply be dropped.
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2403 
2404 namespace {
2405 // If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only applies to memories with an integer data type of known zero width.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        // Each field access becomes a free-standing wire of the same type.
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports. The constant is zero-width,
          // matching the memory's data type.
          auto zero = rewriter.create<firrtl::ConstantOp>(
              wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
              APInt::getZero(0));
          rewriter.create<MatchingConnectOp>(wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2447 
2448 // If memory has no write ports and no file initialization, eliminate it.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // The memory must be exclusively read or exclusively written; any mix,
    // or any debug/read-write port, disqualifies it.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    // Detach every port, then drop the memory itself.
    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2491 
2492 // Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  // Rebuild the memory without ports that are disabled or whose read data is
  // never used, rerouting the surviving ports to the new memory.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port died, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = rewriter.create<MemOp>(
          mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Erase the dead ports (erasePort substitutes dummy registers for any
    // still-read values) and reroute the live ones to the new memory.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        rewriter.replaceAllUsesWith(port, newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2561 
2562 // Rewrite write-only read-write ports to write ports.
2563 struct FoldReadWritePorts : public mlir::RewritePattern {
2564  FoldReadWritePorts(MLIRContext *context)
2565  : RewritePattern(MemOp::getOperationName(), 0, context) {}
2566  LogicalResult matchAndRewrite(Operation *op,
2567  PatternRewriter &rewriter) const override {
2568  MemOp mem = cast<MemOp>(op);
2569  if (hasDontTouch(mem))
2570  return failure();
2571 
2572  // Identify read-write ports whose read end is unused.
2573  llvm::SmallBitVector deadReads(mem.getNumResults());
2574  for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2575  if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
2576  continue;
2577  if (!mem.getPortAnnotation(i).empty())
2578  continue;
2579  if (isPortUnused(port, "rdata")) {
2580  deadReads.set(i);
2581  continue;
2582  }
2583  }
2584  if (deadReads.none())
2585  return failure();
2586 
2587  SmallVector<Type> resultTypes;
2588  SmallVector<StringRef> portNames;
2589  SmallVector<Attribute> portAnnotations;
2590  for (auto [i, port] : llvm::enumerate(mem.getResults())) {
2591  if (deadReads[i])
2592  resultTypes.push_back(
2593  MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
2594  MemOp::PortKind::Write, mem.getMaskBits()));
2595  else
2596  resultTypes.push_back(port.getType());
2597 
2598  portNames.push_back(mem.getPortName(i));
2599  portAnnotations.push_back(mem.getPortAnnotation(i));
2600  }
2601 
2602  auto newOp = rewriter.create<MemOp>(
2603  mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
2604  mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
2605  mem.getName(), mem.getNameKind(), mem.getAnnotations(),
2606  rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
2607  mem.getInitAttr(), mem.getPrefixAttr());
2608 
2609  for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
2610  auto result = mem.getResult(i);
2611  auto newResult = newOp.getResult(i);
2612  if (deadReads[i]) {
2613  auto resultPortTy = type_cast<BundleType>(result.getType());
2614 
2615  // Rewrite accesses to the old port field to accesses to a
2616  // corresponding field of the new port.
2617  auto replace = [&](StringRef toName, StringRef fromName) {
2618  auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
2619  assert(fromFieldIndex && "missing enable flag on memory port");
2620 
2621  auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
2622  newResult, toName);
2623  for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
2624  auto fromField = cast<SubfieldOp>(op);
2625  if (fromFieldIndex != fromField.getFieldIndex())
2626  continue;
2627  rewriter.replaceOp(fromField, toField.getResult());
2628  }
2629  };
2630 
2631  replace("addr", "addr");
2632  replace("en", "en");
2633  replace("clk", "clk");
2634  replace("data", "wdata");
2635  replace("mask", "wmask");
2636 
2637  // Remove the wmode field, replacing it with dummy wires.
2638  auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
2639  for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
2640  auto wmodeField = cast<SubfieldOp>(op);
2641  if (wmodeFieldIndex != wmodeField.getFieldIndex())
2642  continue;
2643  rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
2644  }
2645  } else {
2646  rewriter.replaceAllUsesWith(result, newResult);
2647  }
2648  }
2649  rewriter.eraseOp(op);
2650  return success();
2651  }
2652 };
2653 
// Eliminate the unused data bits of memories: narrow the data type to the
// bits actually observed by readers, compacting them into contiguous ranges
// and rewriting all reads and writes accordingly.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential; masked
    // memories are skipped as well.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // The data type must be an integer of known, nonzero width.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // `usedBits` marks each data bit observed by some reader; `mapping` maps
    // the low bit of each collected bit-select to its post-compaction offset
    // (filled in with the real offsets later).
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits) {
            // Any non-bit-select use forces all bits to be considered used.
            usedBits.set();
            continue;
          }

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of matching connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<MatchingConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        findReadUsers(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        findReadUsers(port, "rdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Perform the transformation only if some bits are missing. Unused
    // memories are handled in a different canonicalizer.
    if (usedBits.all() || usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones.
    // Each iteration covers one contiguous run [i, e) of used bits.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      // Record the run as an inclusive [start, end] range for write slicing.
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type: create a single subfield
    // access on the new port and funnel all old accesses through it.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          rewriter.create<SubfieldOp>(port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    // The slice keeps its width but is shifted to its compacted offset.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      rewriter.replaceOpWithNewOp<BitsPrimOp>(
          readOp, readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
    }

    // Rewrite the writes into a concatenation of slices.
    // Ranges are visited low-to-high; each new slice becomes the MSB end of
    // the concatenation, preserving bit order in the narrowed value.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      Value catOfSlices;
      for (auto &[start, end] : ranges) {
        Value slice =
            rewriter.create<BitsPrimOp>(writeOp.getLoc(), source, end, start);
        if (catOfSlices) {
          catOfSlices =
              rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
        } else {
          catOfSlices = slice;
        }
      }
      rewriter.replaceOpWithNewOp<MatchingConnectOp>(writeOp, writeOp.getDest(),
                                                     catOfSlices);
    }

    return success();
  }
};
2861 
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only depth-1 memories without don't-touch qualify.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto memModule = mem->getParentOfType<FModuleOp>();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Collect the subfield accesses of the listed port fields along with
      // the connects driving them; bail out on any non-connect user since
      // the accesses are erased wholesale at the end.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // Write-capable ports (Read ports `continue` above) must all share a
      // single, known clock, which becomes the register's clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }

    // Create a new register to store the data.
    auto ty = mem.getDataType();
    rewriter.setInsertionPointAfterValue(clock);
    auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
                   .getResult();

    // Helper to insert a given number of pipeline stages through registers.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        // Name each stage "<mem>_<name>_<i>" for debuggability.
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }

        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<MatchingConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    // A write latency of N requires N-1 pipeline registers before the store.
    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      // Pipeline the given field of this port by `stages` registers.
      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        rewriter.setInsertionPointAfterValue(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Debug ports were rejected during the collection phase above.
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "data", reg);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "rdata", reg);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        // The effective write enable is `en && wmode`.
        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");
        rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    // Fold the writes into a mux chain: later ports take priority.
    rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
    Value next = reg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        // [lo, hi] is the data chunk guarded by mask bit i.
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      // Gate the whole update on the (pipelined) port enable.
      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    rewriter.create<MatchingConnectOp>(reg.getLoc(), reg, next);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3052 } // namespace
3053 
3054 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3055  MLIRContext *context) {
3056  results
3057  .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3058  FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3059  context);
3060 }
3061 
3062 //===----------------------------------------------------------------------===//
3063 // Declarations
3064 //===----------------------------------------------------------------------===//
3065 
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The register must be driven by a mux.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    // Symmetric case: the constant sits on the low side of the mux.
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  // (all mux/connect types must match the register and have a known width).
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Promote the register to a RegResetOp driven by the mux condition and
    // the constant reset value; preserve dialect attributes.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Re-emit the connect without the mux: drive the constant (constReg case)
  // or the non-reset value.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3136 
3137 LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3138  if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3139  succeeded(foldHiddenReset(op, rewriter)))
3140  return success();
3141 
3142  if (succeeded(demoteForceableIfUnused(op, rewriter)))
3143  return success();
3144 
3145  return failure();
3146 }
3147 
3148 //===----------------------------------------------------------------------===//
3149 // Verification Ops.
3150 //===----------------------------------------------------------------------===//
3151 
3152 static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3153  Value enable,
3154  PatternRewriter &rewriter,
3155  bool eraseIfZero) {
3156  // If the verification op is never enabled, delete it.
3157  if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3158  if (constant.getValue().isZero()) {
3159  rewriter.eraseOp(op);
3160  return success();
3161  }
3162  }
3163 
3164  // If the verification op is never triggered, delete it.
3165  if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3166  if (constant.getValue().isZero() == eraseIfZero) {
3167  rewriter.eraseOp(op);
3168  return success();
3169  }
3170  }
3171 
3172  return failure();
3173 }
3174 
/// Canonicalize an immediate verification op (assert, assume, cover) by
/// erasing it when a constant enable or predicate makes it a no-op.
/// `EraseIfZero` selects which constant predicate value is trivial (false for
/// cover-style ops, true otherwise).
template <class Op, bool EraseIfZero = false>
static LogicalResult canonicalizeImmediateVerifOp(Op op,
                                                  PatternRewriter &rewriter) {
  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
                              EraseIfZero);
}
3181 
// An assert never fires when its predicate is constant true; erase it then.
void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<AssertOp>);
}

// An assume is vacuous when its predicate is constant true; erase it then.
void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
}

void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
}

// A cover with a constant-false predicate can never be hit; erase it then.
void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
}
3201 
3202 //===----------------------------------------------------------------------===//
3203 // InvalidValueOp
3204 //===----------------------------------------------------------------------===//
3205 
LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
                                           PatternRewriter &rewriter) {
  // Remove `InvalidValueOp`s with no uses.
  if (op.use_empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Propagate invalids through a single use which is a unary op. You cannot
  // propagate through multiple uses as that breaks invalid semantics. Nor
  // can you propagate through binary ops or generally any op which computes.
  // Not is an exception as it is a pure, all-bits inverse.
  // Cvt propagates only for signed operands, and bit reductions only when the
  // operand's width is known to be positive.
  if (op->hasOneUse() &&
      (isa<BitsPrimOp, HeadPrimOp, ShrPrimOp, TailPrimOp, SubfieldOp,
           SubindexOp, AsSIntPrimOp, AsUIntPrimOp, NotPrimOp, BitCastOp>(
           *op->user_begin()) ||
       (isa<CvtPrimOp>(*op->user_begin()) &&
        type_isa<SIntType>(op->user_begin()->getOperand(0).getType())) ||
       (isa<AndRPrimOp, XorRPrimOp, OrRPrimOp>(*op->user_begin()) &&
        type_cast<FIRRTLBaseType>(op->user_begin()->getOperand(0).getType())
                .getBitWidthOrSentinel() > 0))) {
    // Replace the user with a fresh invalid of the user's result type, then
    // delete both the user and the original invalid.
    auto *modop = *op->user_begin();
    auto inv = rewriter.create<InvalidValueOp>(op.getLoc(),
                                               modop->getResult(0).getType());
    rewriter.replaceAllOpUsesWith(modop, inv);
    rewriter.eraseOp(modop);
    rewriter.eraseOp(op);
    return success();
  }
  return failure();
}
3236 
3237 OpFoldResult InvalidValueOp::fold(FoldAdaptor adaptor) {
3238  if (getType().getBitWidthOrSentinel() == 0 && isa<IntType>(getType()))
3239  return getIntAttr(getType(), APInt(0, 0, isa<SIntType>(getType())));
3240  return {};
3241 }
3242 
3243 //===----------------------------------------------------------------------===//
3244 // ClockGateIntrinsicOp
3245 //===----------------------------------------------------------------------===//
3246 
/// Fold clock gates with constant enables or a constant input clock. Note
/// the priority: a constant-true enable forwards the input clock before the
/// constant-zero checks apply.
OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Forward the clock if one of the enables is always true.
  if (isConstantOne(adaptor.getEnable()) ||
      isConstantOne(adaptor.getTestEnable()))
    return getInput();

  // Fold to a constant zero clock if the enables are always false.
  if (isConstantZero(adaptor.getEnable()) &&
      (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
    return BoolAttr::get(getContext(), false);

  // Forward constant zero clocks.
  if (isConstantZero(adaptor.getInput()))
    return BoolAttr::get(getContext(), false);

  return {};
}
3264 
3265 LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3266  PatternRewriter &rewriter) {
3267  // Remove constant false test enable.
3268  if (auto testEnable = op.getTestEnable()) {
3269  if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3270  if (constOp.getValue().isZero()) {
3271  rewriter.modifyOpInPlace(op,
3272  [&] { op.getTestEnableMutable().clear(); });
3273  return success();
3274  }
3275  }
3276  }
3277 
3278  return failure();
3279 }
3280 
3281 //===----------------------------------------------------------------------===//
3282 // Reference Ops.
3283 //===----------------------------------------------------------------------===//
3284 
3285 // refresolve(forceable.ref) -> forceable.data
3286 static LogicalResult
3287 canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3288  auto forceable = op.getRef().getDefiningOp<Forceable>();
3289  if (!forceable || !forceable.isForceable() ||
3290  op.getRef() != forceable.getDataRef() ||
3291  op.getType() != forceable.getDataType())
3292  return failure();
3293  rewriter.replaceAllUsesWith(op, forceable.getData());
3294  return success();
3295 }
3296 
void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  // refresolve(refsend(x)) -> x (declarative pattern), plus the
  // resolve-of-forceable simplification above.
  results.insert<patterns::RefResolveOfRefSend>(context);
  results.insert(canonicalizeRefResolveOfForceable);
}
3302 
3303 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3304  // RefCast is unnecessary if types match.
3305  if (getInput().getType() == getType())
3306  return getInput();
3307  return {};
3308 }
3309 
3310 static bool isConstantZero(Value operand) {
3311  auto constOp = operand.getDefiningOp<ConstantOp>();
3312  return constOp && constOp.getValue().isZero();
3313 }
3314 
3315 template <typename Op>
3316 static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3317  if (isConstantZero(op.getPredicate())) {
3318  rewriter.eraseOp(op);
3319  return success();
3320  }
3321  return failure();
3322 }
3323 
// Force/release ops with a constant-false predicate never fire; erase them.
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add(eraseIfPredFalse<RefForceOp>);
}
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add(eraseIfPredFalse<RefForceInitialOp>);
}
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add(eraseIfPredFalse<RefReleaseOp>);
}
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
}
3340 
3341 //===----------------------------------------------------------------------===//
3342 // HasBeenResetIntrinsicOp
3343 //===----------------------------------------------------------------------===//
3344 
OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
  // The folds in here should reflect the ones for `verif::HasBeenResetOp`.

  // Fold to zero if the reset is a constant. In this case the op is either
  // permanently in reset or never resets. Both mean that the reset never
  // finishes, so this op never returns true.
  // (A non-null adaptor operand indicates a constant-folded value.)
  if (adaptor.getReset())
    return getIntZerosAttr(UIntType::get(getContext(), 1));

  // Fold to zero if the clock is a constant and the reset is synchronous
  // (a synchronous reset is modeled as a 1-bit UInt). In that case the reset
  // will never be started.
  if (isUInt1(getReset().getType()) && adaptor.getClock())
    return getIntZerosAttr(UIntType::get(getContext(), 1));

  return {};
}
3361 
3362 //===----------------------------------------------------------------------===//
3363 // FPGAProbeIntrinsicOp
3364 //===----------------------------------------------------------------------===//
3365 
3366 static bool isTypeEmpty(FIRRTLType type) {
3368  .Case<FVectorType>(
3369  [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3370  .Case<BundleType>([&](auto ty) -> bool {
3371  for (auto elem : ty.getElements())
3372  if (!isTypeEmpty(elem.type))
3373  return false;
3374  return true;
3375  })
3376  .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3377  .Default([](auto) -> bool { return false; });
3378 }
3379 
3380 LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3381  PatternRewriter &rewriter) {
3382  auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3383  if (!firrtlTy)
3384  return failure();
3385 
3386  if (!isTypeEmpty(firrtlTy))
3387  return failure();
3388 
3389  rewriter.eraseOp(op);
3390  return success();
3391 }
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
Definition: FIRRTLFolds.cpp:71
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static LogicalResult canonicalizeSingleSetConnect(MatchingConnectOp op, PatternRewriter &rewriter)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
Definition: FIRRTLFolds.cpp:91
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
Definition: FIRRTLFolds.cpp:82
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:32
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:48
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
int32_t width
Definition: FIRRTL.cpp:36
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:296
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:275
bool hasWidth() const
Return true if this integer type has a known width.
Definition: FIRRTLTypes.h:283
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
def connect(destination, source)
Definition: support.py:39
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:317
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
MatchingConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APInts.
Definition: APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition: Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition: FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition: FoldUtils.h:34
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:21