CIRCT  19.0.0git
FIRRTLFolds.cpp
Go to the documentation of this file.
1 //===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implement the folding and canonicalizations for FIRRTL ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "circt/Dialect/FIRRTL/FIRRTLAttributes.h"
#include "circt/Dialect/FIRRTL/FIRRTLOps.h"
#include "circt/Dialect/FIRRTL/FIRRTLTypes.h"
#include "circt/Dialect/FIRRTL/FIRRTLUtils.h"
#include "circt/Support/APInt.h"
#include "circt/Support/LLVM.h"
#include "circt/Support/Naming.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/APSInt.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/TypeSwitch.h"
26 
27 using namespace circt;
28 using namespace firrtl;
29 
30 // Drop writes to old and pass through passthrough to make patterns easier to
31 // write.
32 static Value dropWrite(PatternRewriter &rewriter, OpResult old,
33  Value passthrough) {
34  SmallPtrSet<Operation *, 8> users;
35  for (auto *user : old.getUsers())
36  users.insert(user);
37  for (Operation *user : users)
38  if (auto connect = dyn_cast<FConnectLike>(user))
39  if (connect.getDest() == old)
40  rewriter.eraseOp(user);
41  return passthrough;
42 }
43 
44 // Move a name hint from a soon to be deleted operation to a new operation.
45 // Pass through the new operation to make patterns easier to write. This cannot
46 // move a name to a port (block argument), doing so would require rewriting all
47 // instance sites as well as the module.
48 static Value moveNameHint(OpResult old, Value passthrough) {
49  Operation *op = passthrough.getDefiningOp();
50  // This should handle ports, but it isn't clear we can change those in
51  // canonicalizers.
52  assert(op && "passthrough must be an operation");
53  Operation *oldOp = old.getOwner();
54  auto name = oldOp->getAttrOfType<StringAttr>("name");
55  if (name && !name.getValue().empty())
56  op->setAttr("name", name);
57  return passthrough;
58 }
59 
60 // Declarative canonicalization patterns
61 namespace circt {
62 namespace firrtl {
63 namespace patterns {
64 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
65 } // namespace patterns
66 } // namespace firrtl
67 } // namespace circt
68 
69 /// Return true if this operation's operands and results all have a known width.
70 /// This only works for integer types.
71 static bool hasKnownWidthIntTypes(Operation *op) {
72  auto resultType = type_cast<IntType>(op->getResult(0).getType());
73  if (!resultType.hasWidth())
74  return false;
75  for (Value operand : op->getOperands())
76  if (!type_cast<IntType>(operand.getType()).hasWidth())
77  return false;
78  return true;
79 }
80 
81 /// Return true if this value is 1 bit UInt.
82 static bool isUInt1(Type type) {
83  auto t = type_dyn_cast<UIntType>(type);
84  if (!t || !t.hasWidth() || t.getWidth() != 1)
85  return false;
86  return true;
87 }
88 
89 /// Set the name of an op based on the best of two names: The current name, and
90 /// the name passed in.
91 static void updateName(PatternRewriter &rewriter, Operation *op,
92  StringAttr name) {
93  // Should never rename InstanceOp
94  assert(!isa<InstanceOp>(op));
95  if (!name || name.getValue().empty())
96  return;
97  auto newName = name.getValue(); // old name is interesting
98  auto newOpName = op->getAttrOfType<StringAttr>("name");
99  // new name might not be interesting
100  if (newOpName)
101  newName = chooseName(newOpName.getValue(), name.getValue());
102  // Only update if needed
103  if (!newOpName || newOpName.getValue() != newName)
104  rewriter.modifyOpInPlace(
105  op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
106 }
107 
108 /// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
109 /// If a replaced op has a "name" attribute, this function propagates the name
110 /// to the new value.
111 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
112  Value newValue) {
113  if (auto *newOp = newValue.getDefiningOp()) {
114  auto name = op->getAttrOfType<StringAttr>("name");
115  updateName(rewriter, newOp, name);
116  }
117  rewriter.replaceOp(op, newValue);
118 }
119 
/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
/// attribute. If a replaced op has a "name" attribute, this function propagates
/// the name to the new value.
template <typename OpTy, typename... Args>
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
                                          Operation *op, Args &&...args) {
  // Read the name attribute first: `op` is erased by the replacement below.
  auto name = op->getAttrOfType<StringAttr>("name");
  auto newOp =
      rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
  updateName(rewriter, newOp, name);
  return newOp;
}
132 
133 /// Return true if the name is droppable. Note that this is different from
134 /// `isUselessName` because non-useless names may be also droppable.
135 bool circt::firrtl::hasDroppableName(Operation *op) {
136  if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
137  return namableOp.hasDroppableName();
138  return false;
139 }
140 
/// Implicitly replace the operand to a constant folding operation with a const
/// 0 in case the operand is non-constant but has a bit width 0, or if the
/// operand is an invalid value.
///
/// This makes constant folding significantly easier, as we can simply pass the
/// operands to an operation through this function to appropriately replace any
/// zero-width dynamic values or invalid values with a constant of value 0.
///
/// Returns std::nullopt when no constant value can be derived.
static std::optional<APSInt>
getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
  assert(type_cast<IntType>(operand.getType()) &&
         "getExtendedConstant is limited to integer types");

  // We never support constant folding to unknown width values.
  if (destWidth < 0)
    return {};

  // Extension signedness follows the operand sign.
  if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
    return extOrTruncZeroWidth(result.getAPSInt(), destWidth);

  // If the operand is zero bits, then we can return a zero of the result
  // type.
  if (type_cast<IntType>(operand.getType()).getWidth() == 0)
    return APSInt(destWidth,
                  type_cast<IntType>(operand.getType()).isUnsigned());
  return {};
}
168 
169 /// Determine the value of a constant operand for the sake of constant folding.
170 static std::optional<APSInt> getConstant(Attribute operand) {
171  if (!operand)
172  return {};
173  if (auto attr = dyn_cast<BoolAttr>(operand))
174  return APSInt(APInt(1, attr.getValue()));
175  if (auto attr = dyn_cast<IntegerAttr>(operand))
176  return attr.getAPSInt();
177  return {};
178 }
179 
180 /// Determine whether a constant operand is a zero value for the sake of
181 /// constant folding. This considers `invalidvalue` to be zero.
182 static bool isConstantZero(Attribute operand) {
183  if (auto cst = getConstant(operand))
184  return cst->isZero();
185  return false;
186 }
187 
188 /// Determine whether a constant operand is a one value for the sake of constant
189 /// folding.
190 static bool isConstantOne(Attribute operand) {
191  if (auto cst = getConstant(operand))
192  return cst->isOne();
193  return false;
194 }
195 
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
enum class BinOpKind {
  Normal,
  Compare,
  // Divide and shift ops compute at the widest of operand/result widths and
  // truncate afterwards; referenced as BinOpKind::DivideOrShift below.
  DivideOrShift,
};
203 
204 /// Applies the constant folding function `calculate` to the given operands.
205 ///
206 /// Sign or zero extends the operands appropriately to the bitwidth of the
207 /// result type if \p useDstWidth is true, else to the larger of the two operand
208 /// bit widths and depending on whether the operation is to be performed on
209 /// signed or unsigned operands.
210 static Attribute constFoldFIRRTLBinaryOp(
211  Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
212  const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
213  assert(operands.size() == 2 && "binary op takes two operands");
214 
215  // We cannot fold something to an unknown width.
216  auto resultType = type_cast<IntType>(op->getResult(0).getType());
217  if (resultType.getWidthOrSentinel() < 0)
218  return {};
219 
220  // Any binary op returning i0 is 0.
221  if (resultType.getWidthOrSentinel() == 0)
222  return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
223 
224  // Determine the operand widths. This is either dictated by the operand type,
225  // or if that type is an unsized integer, by the actual bits necessary to
226  // represent the constant value.
227  auto lhsWidth =
228  type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
229  auto rhsWidth =
230  type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
231  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
232  lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
233  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
234  rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
235 
236  // Compares extend the operands to the widest of the operand types, not to the
237  // result type.
238  int32_t operandWidth;
239  switch (opKind) {
240  case BinOpKind::Normal:
241  operandWidth = resultType.getWidthOrSentinel();
242  break;
243  case BinOpKind::Compare:
244  // Compares compute with the widest operand, not at the destination type
245  // (which is always i1).
246  operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
247  break;
249  operandWidth =
250  std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
251  break;
252  }
253 
254  auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
255  if (!lhs)
256  return {};
257  auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
258  if (!rhs)
259  return {};
260 
261  APInt resultValue = calculate(*lhs, *rhs);
262 
263  // If the result type is smaller than the computation then we need to
264  // narrow the constant after the calculation.
265  if (opKind == BinOpKind::DivideOrShift)
266  resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
267 
268  assert((unsigned)resultType.getWidthOrSentinel() ==
269  resultValue.getBitWidth());
270  return getIntAttr(resultType, resultValue);
271 }
272 
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function. Any `invalidvalue` in
/// the input is mapped to a constant zero. The value returned from the callback
/// is used as the replacement for `op`, and an additional pad operation is
/// inserted if necessary. Does nothing if the result of `op` is of unknown
/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants. Non-constant operands get
  // a null Attribute entry, keeping positions aligned with op operands.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  if (auto cst = dyn_cast<Attribute>(result))
    // Let the dialect turn the folded attribute back into a constant op.
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = result.get<Value>();

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);

  assert(type == resultValue.getType() && "canonicalization changed type");
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
336 
337 /// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
338 /// value if `bitWidth` is 0.
339 static APInt getMaxUnsignedValue(unsigned bitWidth) {
340  return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
341 }
342 
343 /// Get the smallest signed value of a given bit width. Returns a 1-bit zero
344 /// value if `bitWidth` is 0.
345 static APInt getMinSignedValue(unsigned bitWidth) {
346  return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
347 }
348 
349 /// Get the largest signed value of a given bit width. Returns a 1-bit zero
350 /// value if `bitWidth` is 0.
351 static APInt getMaxSignedValue(unsigned bitWidth) {
352  return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
353 }
354 
355 //===----------------------------------------------------------------------===//
356 // Fold Hooks
357 //===----------------------------------------------------------------------===//
358 
// All constant-like operations fold to the attribute holding their value; the
// asserts document that these ops take no operands.

OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getFieldsAttr();
}

OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}

OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
393 
394 //===----------------------------------------------------------------------===//
395 // Binary Operators
396 //===----------------------------------------------------------------------===//
397 
398 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
400  *this, adaptor.getOperands(), BinOpKind::Normal,
401  [=](const APSInt &a, const APSInt &b) { return a + b; });
402 }
403 
void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
                 patterns::AddOfSelf, patterns::AddOfPad>(context);
}
409 
410 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
412  *this, adaptor.getOperands(), BinOpKind::Normal,
413  [=](const APSInt &a, const APSInt &b) { return a - b; });
414 }
415 
void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
                 patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
                 patterns::SubOfPadL, patterns::SubOfPadR>(context);
}
422 
423 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
424  // mul(x, 0) -> 0
425  //
426  // This is legal because it aligns with the Scala FIRRTL Compiler
427  // interpretation of lowering invalid to constant zero before constant
428  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
429  // multiplication this way and will emit "x * 0".
430  if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
431  return getIntZerosAttr(getType());
432 
434  *this, adaptor.getOperands(), BinOpKind::Normal,
435  [=](const APSInt &a, const APSInt &b) { return a * b; });
436 }
437 
438 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
439  /// div(x, x) -> 1
440  ///
441  /// Division by zero is undefined in the FIRRTL specification. This fold
442  /// exploits that fact to optimize self division to one. Note: this should
443  /// supersede any division with invalid or zero. Division of invalid by
444  /// invalid should be one.
445  if (getLhs() == getRhs()) {
446  auto width = getType().base().getWidthOrSentinel();
447  if (width == -1)
448  width = 2;
449  // Only fold if we have at least 1 bit of width to represent the `1` value.
450  if (width != 0)
451  return getIntAttr(getType(), APInt(width, 1));
452  }
453 
454  // div(0, x) -> 0
455  //
456  // This is legal because it aligns with the Scala FIRRTL Compiler
457  // interpretation of lowering invalid to constant zero before constant
458  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
459  // division this way and will emit "0 / x".
460  if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
461  return getIntZerosAttr(getType());
462 
463  /// div(x, 1) -> x : (uint, uint) -> uint
464  ///
465  /// UInt division by one returns the numerator. SInt division can't
466  /// be folded here because it increases the return type bitwidth by
467  /// one and requires sign extension (a new op).
468  if (auto rhsCst = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
469  if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
470  return getLhs();
471 
473  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
474  [=](const APSInt &a, const APSInt &b) -> APInt {
475  if (!!b)
476  return a / b;
477  return APInt(a.getBitWidth(), 0);
478  });
479 }
480 
481 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
482  // rem(x, x) -> 0
483  //
484  // Division by zero is undefined in the FIRRTL specification. This fold
485  // exploits that fact to optimize self division remainder to zero. Note:
486  // this should supersede any division with invalid or zero. Remainder of
487  // division of invalid by invalid should be zero.
488  if (getLhs() == getRhs())
489  return getIntZerosAttr(getType());
490 
491  // rem(0, x) -> 0
492  //
493  // This is legal because it aligns with the Scala FIRRTL Compiler
494  // interpretation of lowering invalid to constant zero before constant
495  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
496  // division this way and will emit "0 % x".
497  if (isConstantZero(adaptor.getLhs()))
498  return getIntZerosAttr(getType());
499 
501  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
502  [=](const APSInt &a, const APSInt &b) -> APInt {
503  if (!!b)
504  return a % b;
505  return APInt(a.getBitWidth(), 0);
506  });
507 }
508 
509 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
511  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
512  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
513 }
514 
515 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
517  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
518  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
519 }
520 
521 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
523  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
524  [=](const APSInt &a, const APSInt &b) -> APInt {
525  return getType().base().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
526  : a.ashr(b);
527  });
528 }
529 
530 // TODO: Move to DRR.
531 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
532  if (auto rhsCst = getConstant(adaptor.getRhs())) {
533  /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
534  if (rhsCst->isZero())
535  return getIntZerosAttr(getType());
536 
537  /// and(x, -1) -> x
538  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
539  getRhs().getType() == getType())
540  return getLhs();
541  }
542 
543  if (auto lhsCst = getConstant(adaptor.getLhs())) {
544  /// and(0, x) -> 0, 0 is largest or is implicit zero extended
545  if (lhsCst->isZero())
546  return getIntZerosAttr(getType());
547 
548  /// and(-1, x) -> x
549  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
550  getRhs().getType() == getType())
551  return getRhs();
552  }
553 
554  /// and(x, x) -> x
555  if (getLhs() == getRhs() && getRhs().getType() == getType())
556  return getRhs();
557 
559  *this, adaptor.getOperands(), BinOpKind::Normal,
560  [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
561 }
562 
void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results
      .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
              patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
              patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
}
570 
571 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
572  if (auto rhsCst = getConstant(adaptor.getRhs())) {
573  /// or(x, 0) -> x
574  if (rhsCst->isZero() && getLhs().getType() == getType())
575  return getLhs();
576 
577  /// or(x, -1) -> -1
578  if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
579  getLhs().getType() == getType())
580  return getRhs();
581  }
582 
583  if (auto lhsCst = getConstant(adaptor.getLhs())) {
584  /// or(0, x) -> x
585  if (lhsCst->isZero() && getRhs().getType() == getType())
586  return getRhs();
587 
588  /// or(-1, x) -> -1
589  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
590  getRhs().getType() == getType())
591  return getLhs();
592  }
593 
594  /// or(x, x) -> x
595  if (getLhs() == getRhs() && getRhs().getType() == getType())
596  return getRhs();
597 
599  *this, adaptor.getOperands(), BinOpKind::Normal,
600  [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
601 }
602 
void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
                 patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad>(
      context);
}
609 
610 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
611  /// xor(x, 0) -> x
612  if (auto rhsCst = getConstant(adaptor.getRhs()))
613  if (rhsCst->isZero() &&
614  firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
615  return getLhs();
616 
617  /// xor(x, 0) -> x
618  if (auto lhsCst = getConstant(adaptor.getLhs()))
619  if (lhsCst->isZero() &&
620  firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
621  return getRhs();
622 
623  /// xor(x, x) -> 0
624  if (getLhs() == getRhs())
625  return getIntAttr(
626  getType(),
627  APInt(std::max(getType().base().getWidthOrSentinel(), 0), 0));
628 
630  *this, adaptor.getOperands(), BinOpKind::Normal,
631  [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
632 }
633 
void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::extendXor, patterns::moveConstXor,
                 patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
      context);
}
640 
void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::LEQWithConstLHS>(context);
}
645 
646 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
647  bool isUnsigned = getLhs().getType().base().isUnsigned();
648 
649  // leq(x, x) -> 1
650  if (getLhs() == getRhs())
651  return getIntAttr(getType(), APInt(1, 1));
652 
653  // Comparison against constant outside type bounds.
654  if (auto width = getLhs().getType().base().getWidth()) {
655  if (auto rhsCst = getConstant(adaptor.getRhs())) {
656  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
657  commonWidth = std::max(commonWidth, 1);
658 
659  // leq(x, const) -> 0 where const < minValue of the unsigned type of x
660  // This can never occur since const is unsigned and cannot be less than 0.
661 
662  // leq(x, const) -> 0 where const < minValue of the signed type of x
663  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
664  .slt(getMinSignedValue(*width).sext(commonWidth)))
665  return getIntAttr(getType(), APInt(1, 0));
666 
667  // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
668  if (isUnsigned && rhsCst->zext(commonWidth)
669  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
670  return getIntAttr(getType(), APInt(1, 1));
671 
672  // leq(x, const) -> 1 where const >= maxValue of the signed type of x
673  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
674  .sge(getMaxSignedValue(*width).sext(commonWidth)))
675  return getIntAttr(getType(), APInt(1, 1));
676  }
677  }
678 
680  *this, adaptor.getOperands(), BinOpKind::Compare,
681  [=](const APSInt &a, const APSInt &b) -> APInt {
682  return APInt(1, a <= b);
683  });
684 }
685 
void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::LTWithConstLHS>(context);
}
690 
691 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
692  IntType lhsType = getLhs().getType();
693  bool isUnsigned = lhsType.isUnsigned();
694 
695  // lt(x, x) -> 0
696  if (getLhs() == getRhs())
697  return getIntAttr(getType(), APInt(1, 0));
698 
699  // lt(x, 0) -> 0 when x is unsigned
700  if (auto rhsCst = getConstant(adaptor.getRhs())) {
701  if (rhsCst->isZero() && lhsType.isUnsigned())
702  return getIntAttr(getType(), APInt(1, 0));
703  }
704 
705  // Comparison against constant outside type bounds.
706  if (auto width = lhsType.getWidth()) {
707  if (auto rhsCst = getConstant(adaptor.getRhs())) {
708  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
709  commonWidth = std::max(commonWidth, 1);
710 
711  // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
712  // Handled explicitly above.
713 
714  // lt(x, const) -> 0 where const <= minValue of the signed type of x
715  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
716  .sle(getMinSignedValue(*width).sext(commonWidth)))
717  return getIntAttr(getType(), APInt(1, 0));
718 
719  // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
720  if (isUnsigned && rhsCst->zext(commonWidth)
721  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
722  return getIntAttr(getType(), APInt(1, 1));
723 
724  // lt(x, const) -> 1 where const > maxValue of the signed type of x
725  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
726  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
727  return getIntAttr(getType(), APInt(1, 1));
728  }
729  }
730 
732  *this, adaptor.getOperands(), BinOpKind::Compare,
733  [=](const APSInt &a, const APSInt &b) -> APInt {
734  return APInt(1, a < b);
735  });
736 }
737 
void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::GEQWithConstLHS>(context);
}
742 
743 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
744  IntType lhsType = getLhs().getType();
745  bool isUnsigned = lhsType.isUnsigned();
746 
747  // geq(x, x) -> 1
748  if (getLhs() == getRhs())
749  return getIntAttr(getType(), APInt(1, 1));
750 
751  // geq(x, 0) -> 1 when x is unsigned
752  if (auto rhsCst = getConstant(adaptor.getRhs())) {
753  if (rhsCst->isZero() && isUnsigned)
754  return getIntAttr(getType(), APInt(1, 1));
755  }
756 
757  // Comparison against constant outside type bounds.
758  if (auto width = lhsType.getWidth()) {
759  if (auto rhsCst = getConstant(adaptor.getRhs())) {
760  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
761  commonWidth = std::max(commonWidth, 1);
762 
763  // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
764  if (isUnsigned && rhsCst->zext(commonWidth)
765  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
766  return getIntAttr(getType(), APInt(1, 0));
767 
768  // geq(x, const) -> 0 where const > maxValue of the signed type of x
769  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
770  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
771  return getIntAttr(getType(), APInt(1, 0));
772 
773  // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
774  // Handled explicitly above.
775 
776  // geq(x, const) -> 1 where const <= minValue of the signed type of x
777  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
778  .sle(getMinSignedValue(*width).sext(commonWidth)))
779  return getIntAttr(getType(), APInt(1, 1));
780  }
781  }
782 
784  *this, adaptor.getOperands(), BinOpKind::Compare,
785  [=](const APSInt &a, const APSInt &b) -> APInt {
786  return APInt(1, a >= b);
787  });
788 }
789 
void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Register the declarative (DRR-generated) canonicalization patterns.
  results.insert<patterns::GTWithConstLHS>(context);
}
794 
795 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
796  IntType lhsType = getLhs().getType();
797  bool isUnsigned = lhsType.isUnsigned();
798 
799  // gt(x, x) -> 0
800  if (getLhs() == getRhs())
801  return getIntAttr(getType(), APInt(1, 0));
802 
803  // Comparison against constant outside type bounds.
804  if (auto width = lhsType.getWidth()) {
805  if (auto rhsCst = getConstant(adaptor.getRhs())) {
806  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
807  commonWidth = std::max(commonWidth, 1);
808 
809  // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
810  if (isUnsigned && rhsCst->zext(commonWidth)
811  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
812  return getIntAttr(getType(), APInt(1, 0));
813 
814  // gt(x, const) -> 0 where const >= maxValue of the signed type of x
815  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
816  .sge(getMaxSignedValue(*width).sext(commonWidth)))
817  return getIntAttr(getType(), APInt(1, 0));
818 
819  // gt(x, const) -> 1 where const < minValue of the unsigned type of x
820  // This can never occur since const is unsigned and cannot be less than 0.
821 
822  // gt(x, const) -> 1 where const < minValue of the signed type of x
823  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
824  .slt(getMinSignedValue(*width).sext(commonWidth)))
825  return getIntAttr(getType(), APInt(1, 1));
826  }
827  }
828 
830  *this, adaptor.getOperands(), BinOpKind::Compare,
831  [=](const APSInt &a, const APSInt &b) -> APInt {
832  return APInt(1, a > b);
833  });
834 }
835 
836 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
837  // eq(x, x) -> 1
838  if (getLhs() == getRhs())
839  return getIntAttr(getType(), APInt(1, 1));
840 
841  if (auto rhsCst = getConstant(adaptor.getRhs())) {
842  /// eq(x, 1) -> x when x is 1 bit.
843  /// TODO: Support SInt<1> on the LHS etc.
844  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
845  getRhs().getType() == getType())
846  return getLhs();
847  }
848 
850  *this, adaptor.getOperands(), BinOpKind::Compare,
851  [=](const APSInt &a, const APSInt &b) -> APInt {
852  return APInt(1, a == b);
853  });
854 }
855 
856 LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
857  return canonicalizePrimOp(
858  op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
859  if (auto rhsCst = getConstant(operands[1])) {
860  auto width = op.getLhs().getType().getBitWidthOrSentinel();
861 
862  // eq(x, 0) -> not(x) when x is 1 bit.
863  if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
864  op.getRhs().getType() == op.getType()) {
865  return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
866  .getResult();
867  }
868 
869  // eq(x, 0) -> not(orr(x)) when x is >1 bit
870  if (rhsCst->isZero() && width > 1) {
871  auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
872  return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
873  }
874 
875  // eq(x, ~0) -> andr(x) when x is >1 bit
876  if (rhsCst->isAllOnes() && width > 1 &&
877  op.getLhs().getType() == op.getRhs().getType()) {
878  return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
879  .getResult();
880  }
881  }
882 
883  return {};
884  });
885 }
886 
887 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
888  // neq(x, x) -> 0
889  if (getLhs() == getRhs())
890  return getIntAttr(getType(), APInt(1, 0));
891 
892  if (auto rhsCst = getConstant(adaptor.getRhs())) {
893  /// neq(x, 0) -> x when x is 1 bit.
894  /// TODO: Support SInt<1> on the LHS etc.
895  if (rhsCst->isZero() && getLhs().getType() == getType() &&
896  getRhs().getType() == getType())
897  return getLhs();
898  }
899 
901  *this, adaptor.getOperands(), BinOpKind::Compare,
902  [=](const APSInt &a, const APSInt &b) -> APInt {
903  return APInt(1, a != b);
904  });
905 }
906 
907 LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
908  return canonicalizePrimOp(
909  op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
910  if (auto rhsCst = getConstant(operands[1])) {
911  auto width = op.getLhs().getType().getBitWidthOrSentinel();
912 
913  // neq(x, 1) -> not(x) when x is 1 bit
914  if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
915  op.getRhs().getType() == op.getType()) {
916  return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
917  .getResult();
918  }
919 
920  // neq(x, 0) -> orr(x) when x is >1 bit
921  if (rhsCst->isZero() && width > 1) {
922  return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
923  .getResult();
924  }
925 
926  // neq(x, ~0) -> not(andr(x))) when x is >1 bit
927  if (rhsCst->isAllOnes() && width > 1 &&
928  op.getLhs().getType() == op.getRhs().getType()) {
929  auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
930  return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
931  }
932  }
933 
934  return {};
935  });
936 }
937 
// Arbitrary-precision integer add: no folds implemented yet.
OpFoldResult IntegerAddOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6696.
  return {};
}
943 
// Arbitrary-precision integer multiply: no folds implemented yet.
OpFoldResult IntegerMulOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6724.
  return {};
}
949 
// Arbitrary-precision integer shift-right: no folds implemented yet.
OpFoldResult IntegerShrOp::fold(FoldAdaptor adaptor) {
  // TODO: implement constant folding, etc.
  // Tracked in https://github.com/llvm/circt/issues/6725.
  return {};
}
955 
956 //===----------------------------------------------------------------------===//
957 // Unary Operators
958 //===----------------------------------------------------------------------===//
959 
960 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
961  auto base = getInput().getType();
962  auto w = base.getBitWidthOrSentinel();
963  if (w >= 0)
964  return getIntAttr(getType(), APInt(32, w));
965  return {};
966 }
967 
968 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
969  // No constant can be 'x' by definition.
970  if (auto cst = getConstant(adaptor.getArg()))
971  return getIntAttr(getType(), APInt(1, 0));
972  return {};
973 }
974 
975 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
976  // No effect.
977  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
978  return getInput();
979 
980  // Be careful to only fold the cast into the constant if the size is known.
981  // Otherwise width inference may produce differently-sized constants if the
982  // sign changes.
983  if (getType().base().hasWidth())
984  if (auto cst = getConstant(adaptor.getInput()))
985  return getIntAttr(getType(), *cst);
986 
987  return {};
988 }
989 
990 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
991  MLIRContext *context) {
992  results.insert<patterns::StoUtoS>(context);
993 }
994 
995 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
996  // No effect.
997  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
998  return getInput();
999 
1000  // Be careful to only fold the cast into the constant if the size is known.
1001  // Otherwise width inference may produce differently-sized constants if the
1002  // sign changes.
1003  if (getType().base().hasWidth())
1004  if (auto cst = getConstant(adaptor.getInput()))
1005  return getIntAttr(getType(), *cst);
1006 
1007  return {};
1008 }
1009 
1010 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1011  MLIRContext *context) {
1012  results.insert<patterns::UtoStoU>(context);
1013 }
1014 
1015 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1016  // No effect.
1017  if (getInput().getType() == getType())
1018  return getInput();
1019 
1020  // Constant fold.
1021  if (auto cst = getConstant(adaptor.getInput()))
1022  return BoolAttr::get(getContext(), cst->getBoolValue());
1023 
1024  return {};
1025 }
1026 
1027 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1028  // No effect.
1029  if (getInput().getType() == getType())
1030  return getInput();
1031 
1032  // Constant fold.
1033  if (auto cst = getConstant(adaptor.getInput()))
1034  return BoolAttr::get(getContext(), cst->getBoolValue());
1035 
1036  return {};
1037 }
1038 
1039 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1040  if (!hasKnownWidthIntTypes(*this))
1041  return {};
1042 
1043  // Signed to signed is a noop, unsigned operands prepend a zero bit.
1044  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1045  getType().base().getWidthOrSentinel()))
1046  return getIntAttr(getType(), *cst);
1047 
1048  return {};
1049 }
1050 
1051 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1052  MLIRContext *context) {
1053  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1054 }
1055 
1056 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1057  if (!hasKnownWidthIntTypes(*this))
1058  return {};
1059 
1060  // FIRRTL negate always adds a bit.
1061  // -x ---> 0-sext(x) or 0-zext(x)
1062  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1063  getType().base().getWidthOrSentinel()))
1064  return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1065 
1066  return {};
1067 }
1068 
1069 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1070  if (!hasKnownWidthIntTypes(*this))
1071  return {};
1072 
1073  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1074  getType().base().getWidthOrSentinel()))
1075  return getIntAttr(getType(), ~*cst);
1076 
1077  return {};
1078 }
1079 
1080 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1081  MLIRContext *context) {
1082  results.insert<patterns::NotNot>(context);
1083 }
1084 
1085 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1086  if (!hasKnownWidthIntTypes(*this))
1087  return {};
1088 
1089  if (getInput().getType().getBitWidthOrSentinel() == 0)
1090  return getIntAttr(getType(), APInt(1, 1));
1091 
1092  // x == -1
1093  if (auto cst = getConstant(adaptor.getInput()))
1094  return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1095 
1096  // one bit is identity. Only applies to UInt since we can't make a cast
1097  // here.
1098  if (isUInt1(getInput().getType()))
1099  return getInput();
1100 
1101  return {};
1102 }
1103 
1104 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1105  MLIRContext *context) {
1106  results
1107  .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1108  patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1109  patterns::AndRCatZeroL, patterns::AndRCatZeroR>(context);
1110 }
1111 
1112 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1113  if (!hasKnownWidthIntTypes(*this))
1114  return {};
1115 
1116  if (getInput().getType().getBitWidthOrSentinel() == 0)
1117  return getIntAttr(getType(), APInt(1, 0));
1118 
1119  // x != 0
1120  if (auto cst = getConstant(adaptor.getInput()))
1121  return getIntAttr(getType(), APInt(1, !cst->isZero()));
1122 
1123  // one bit is identity. Only applies to UInt since we can't make a cast
1124  // here.
1125  if (isUInt1(getInput().getType()))
1126  return getInput();
1127 
1128  return {};
1129 }
1130 
1131 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1132  MLIRContext *context) {
1133  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1134  patterns::OrRCatZeroH, patterns::OrRCatZeroL>(context);
1135 }
1136 
1137 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1138  if (!hasKnownWidthIntTypes(*this))
1139  return {};
1140 
1141  if (getInput().getType().getBitWidthOrSentinel() == 0)
1142  return getIntAttr(getType(), APInt(1, 0));
1143 
1144  // popcount(x) & 1
1145  if (auto cst = getConstant(adaptor.getInput()))
1146  return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1147 
1148  // one bit is identity. Only applies to UInt since we can't make a cast here.
1149  if (isUInt1(getInput().getType()))
1150  return getInput();
1151 
1152  return {};
1153 }
1154 
1155 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1156  MLIRContext *context) {
1157  results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1158  patterns::XorRCatZeroH, patterns::XorRCatZeroL>(context);
1159 }
1160 
1161 //===----------------------------------------------------------------------===//
1162 // Other Operators
1163 //===----------------------------------------------------------------------===//
1164 
1165 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1166  // cat(x, 0-width) -> x
1167  // cat(0-width, x) -> x
1168  // Limit to unsigned (result type), as cannot insert cast here.
1169  IntType lhsType = getLhs().getType();
1170  IntType rhsType = getRhs().getType();
1171  if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1172  return getRhs();
1173  if (rhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1174  return getLhs();
1175 
1176  if (!hasKnownWidthIntTypes(*this))
1177  return {};
1178 
1179  // Constant fold cat.
1180  if (auto lhs = getConstant(adaptor.getLhs()))
1181  if (auto rhs = getConstant(adaptor.getRhs()))
1182  return getIntAttr(getType(), lhs->concat(*rhs));
1183 
1184  return {};
1185 }
1186 
1187 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1188  MLIRContext *context) {
1189  results.insert<patterns::DShlOfConstant>(context);
1190 }
1191 
1192 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1193  MLIRContext *context) {
1194  results.insert<patterns::DShrOfConstant>(context);
1195 }
1196 
1197 namespace {
1198 // cat(bits(x, ...), bits(x, ...)) -> bits(x ...) when the two ...'s are
1199 // consequtive in the input.
1200 struct CatBitsBits : public mlir::RewritePattern {
1201  CatBitsBits(MLIRContext *context)
1202  : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}
1203  LogicalResult matchAndRewrite(Operation *op,
1204  PatternRewriter &rewriter) const override {
1205  auto cat = cast<CatPrimOp>(op);
1206  if (auto lhsBits =
1207  dyn_cast_or_null<BitsPrimOp>(cat.getLhs().getDefiningOp())) {
1208  if (auto rhsBits =
1209  dyn_cast_or_null<BitsPrimOp>(cat.getRhs().getDefiningOp())) {
1210  if (lhsBits.getInput() == rhsBits.getInput() &&
1211  lhsBits.getLo() - 1 == rhsBits.getHi()) {
1212  replaceOpWithNewOpAndCopyName<BitsPrimOp>(
1213  rewriter, cat, cat.getType(), lhsBits.getInput(), lhsBits.getHi(),
1214  rhsBits.getLo());
1215  return success();
1216  }
1217  }
1218  }
1219  return failure();
1220  }
1221 };
1222 } // namespace
1223 
1224 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1225  MLIRContext *context) {
1226  results.insert<CatBitsBits, patterns::CatDoubleConst>(context);
1227 }
1228 
1229 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1230  auto op = (*this);
1231  // BitCast is redundant if input and result types are same.
1232  if (op.getType() == op.getInput().getType())
1233  return op.getInput();
1234 
1235  // Two consecutive BitCasts are redundant if first bitcast type is same as the
1236  // final result type.
1237  if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1238  if (op.getType() == in.getInput().getType())
1239  return in.getInput();
1240 
1241  return {};
1242 }
1243 
1244 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1245  IntType inputType = getInput().getType();
1246  IntType resultType = getType();
1247  // If we are extracting the entire input, then return it.
1248  if (inputType == getType() && resultType.hasWidth())
1249  return getInput();
1250 
1251  // Constant fold.
1252  if (hasKnownWidthIntTypes(*this))
1253  if (auto cst = getConstant(adaptor.getInput()))
1254  return getIntAttr(resultType,
1255  cst->extractBits(getHi() - getLo() + 1, getLo()));
1256 
1257  return {};
1258 }
1259 
1260 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1261  MLIRContext *context) {
1262  results
1263  .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1264  patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1265 }
1266 
1267 /// Replace the specified operation with a 'bits' op from the specified hi/lo
1268 /// bits. Insert a cast to handle the case where the original operation
1269 /// returned a signed integer.
1270 static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1271  unsigned loBit, PatternRewriter &rewriter) {
1272  auto resType = type_cast<IntType>(op->getResult(0).getType());
1273  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1274  value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1275 
1276  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1277  value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1278  } else if (resType.isUnsigned() &&
1279  !type_cast<IntType>(value.getType()).isUnsigned()) {
1280  value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1281  }
1282  rewriter.replaceOp(op, value);
1283 }
1284 
1285 template <typename OpTy>
1286 static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
1287  // mux : UInt<0> -> 0
1288  if (op.getType().getBitWidthOrSentinel() == 0)
1289  return getIntAttr(op.getType(),
1290  APInt(0, 0, op.getType().isSignedInteger()));
1291 
1292  // mux(cond, x, x) -> x
1293  if (op.getHigh() == op.getLow())
1294  return op.getHigh();
1295 
1296  // The following folds require that the result has a known width. Otherwise
1297  // the mux requires an additional padding operation to be inserted, which is
1298  // not possible in a fold.
1299  if (op.getType().getBitWidthOrSentinel() < 0)
1300  return {};
1301 
1302  // mux(0/1, x, y) -> x or y
1303  if (auto cond = getConstant(adaptor.getSel())) {
1304  if (cond->isZero() && op.getLow().getType() == op.getType())
1305  return op.getLow();
1306  if (!cond->isZero() && op.getHigh().getType() == op.getType())
1307  return op.getHigh();
1308  }
1309 
1310  // mux(cond, x, cst)
1311  if (auto lowCst = getConstant(adaptor.getLow())) {
1312  // mux(cond, c1, c2)
1313  if (auto highCst = getConstant(adaptor.getHigh())) {
1314  // mux(cond, cst, cst) -> cst
1315  if (highCst->getBitWidth() == lowCst->getBitWidth() &&
1316  *highCst == *lowCst)
1317  return getIntAttr(op.getType(), *highCst);
1318  // mux(cond, 1, 0) -> cond
1319  if (highCst->isOne() && lowCst->isZero() &&
1320  op.getType() == op.getSel().getType())
1321  return op.getSel();
1322 
1323  // TODO: x ? ~0 : 0 -> sext(x)
1324  // TODO: "x ? c1 : c2" -> many tricks
1325  }
1326  // TODO: "x ? a : 0" -> sext(x) & a
1327  }
1328 
1329  // TODO: "x ? c1 : y" -> "~x ? y : c1"
1330  return {};
1331 }
1332 
// Delegates to the shared mux fold logic in foldMux.
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1336 
// Delegates to the shared mux fold logic in foldMux.
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  return foldMux(*this, adaptor);
}
1340 
// No folds implemented for 4-input mux cells.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1342 
1343 namespace {
1344 
1345 // If the mux has a known output width, pad the operands up to this width.
1346 // Most folds on mux require that folded operands are of the same width as
1347 // the mux itself.
1348 class MuxPad : public mlir::RewritePattern {
1349 public:
1350  MuxPad(MLIRContext *context)
1351  : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1352 
1353  LogicalResult
1354  matchAndRewrite(Operation *op,
1355  mlir::PatternRewriter &rewriter) const override {
1356  auto mux = cast<MuxPrimOp>(op);
1357  auto width = mux.getType().getBitWidthOrSentinel();
1358  if (width < 0)
1359  return failure();
1360 
1361  auto pad = [&](Value input) -> Value {
1362  auto inputWidth =
1363  type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1364  if (inputWidth < 0 || width == inputWidth)
1365  return input;
1366  return rewriter
1367  .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
1368  .getResult();
1369  };
1370 
1371  auto newHigh = pad(mux.getHigh());
1372  auto newLow = pad(mux.getLow());
1373  if (newHigh == mux.getHigh() && newLow == mux.getLow())
1374  return failure();
1375 
1376  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1377  rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1378  mux->getAttrs());
1379  return success();
1380  }
1381 };
1382 
1383 // Find muxes which have conditions dominated by other muxes with the same
1384 // condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when walking nested mux trees.
  static const int depthLimit = 5;

  // Replace `mux`'s high/low operands: in place when allowed (returns a null
  // Value), otherwise by cloning the mux with the new operands and returning
  // the clone's result.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.modifyOpInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return rewriter
        .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
                           ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector: under the "cond is true" assumption, yields its high arm.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // In-place updating is only safe while every mux on the path has one use.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector: under the "cond is false" assumption, yields its low arm.
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Inside the high arm the selector is known true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Inside the low arm the selector is known false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.modifyOpInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1475 } // namespace
1476 
// Registers the local MuxPad/MuxSharedCond patterns plus the generated ones.
void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<MuxPad, MuxSharedCond, patterns::MuxEQOperands,
              patterns::MuxEQOperandsSwapped, patterns::MuxNEQ,
              patterns::MuxNot, patterns::MuxSameTrue, patterns::MuxSameFalse,
              patterns::NarrowMuxLHS, patterns::NarrowMuxRHS>(context);
}
1484 
1485 OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1486  auto input = this->getInput();
1487 
1488  // pad(x) -> x if the width doesn't change.
1489  if (input.getType() == getType())
1490  return input;
1491 
1492  // Need to know the input width.
1493  auto inputType = input.getType().base();
1494  int32_t width = inputType.getWidthOrSentinel();
1495  if (width == -1)
1496  return {};
1497 
1498  // Constant fold.
1499  if (auto cst = getConstant(adaptor.getInput())) {
1500  auto destWidth = getType().base().getWidthOrSentinel();
1501  if (destWidth == -1)
1502  return {};
1503 
1504  if (inputType.isSigned() && cst->getBitWidth())
1505  return getIntAttr(getType(), cst->sext(destWidth));
1506  return getIntAttr(getType(), cst->zext(destWidth));
1507  }
1508 
1509  return {};
1510 }
1511 
1512 OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1513  auto input = this->getInput();
1514  IntType inputType = input.getType();
1515  int shiftAmount = getAmount();
1516 
1517  // shl(x, 0) -> x
1518  if (shiftAmount == 0)
1519  return input;
1520 
1521  // Constant fold.
1522  if (auto cst = getConstant(adaptor.getInput())) {
1523  auto inputWidth = inputType.getWidthOrSentinel();
1524  if (inputWidth != -1) {
1525  auto resultWidth = inputWidth + shiftAmount;
1526  shiftAmount = std::min(shiftAmount, resultWidth);
1527  return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1528  }
1529  }
1530  return {};
1531 }
1532 
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();
  auto inputWidth = inputType.getWidthOrSentinel();

  // shr(x, 0) -> x
  // Once the shr width changes, do this: shiftAmount == 0 &&
  // (!inputType.isSigned() || inputWidth > 0)
  if (shiftAmount == 0 && inputWidth > 0)
    return input;

  // Cannot fold further without a known input width.
  if (inputWidth == -1)
    return {};
  // A zero-width input shifts to zero.
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(0, 0, false));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    // Signed shifts replicate the sign bit (clamped so at least it remains);
    // unsigned shifts fill with zeros.
    if (inputType.isSigned())
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // The result keeps at least one bit.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1567 
1568 LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
1569  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1570  if (inputWidth <= 0)
1571  return failure();
1572 
1573  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1574  unsigned shiftAmount = op.getAmount();
1575  if (int(shiftAmount) >= inputWidth) {
1576  // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
1577  if (op.getType().base().isUnsigned())
1578  return failure();
1579 
1580  // Shifting a signed value by the full width is actually taking the
1581  // sign bit. If the shift amount is greater than the input width, it
1582  // is equivalent to shifting by the input width.
1583  shiftAmount = inputWidth - 1;
1584  }
1585 
1586  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
1587  return success();
1588 }
1589 
1590 LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1591  PatternRewriter &rewriter) {
1592  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1593  if (inputWidth <= 0)
1594  return failure();
1595 
1596  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1597  unsigned keepAmount = op.getAmount();
1598  if (keepAmount)
1599  replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1600  rewriter);
1601  return success();
1602 }
1603 
1604 OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1605  if (hasKnownWidthIntTypes(*this))
1606  if (auto cst = getConstant(adaptor.getInput())) {
1607  int shiftAmount =
1608  getInput().getType().base().getWidthOrSentinel() - getAmount();
1609  return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1610  }
1611 
1612  return {};
1613 }
1614 
1615 OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1616  if (hasKnownWidthIntTypes(*this))
1617  if (auto cst = getConstant(adaptor.getInput()))
1618  return getIntAttr(getType(),
1619  cst->trunc(getType().base().getWidthOrSentinel()));
1620  return {};
1621 }
1622 
1623 LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1624  PatternRewriter &rewriter) {
1625  auto inputWidth = op.getInput().getType().base().getWidthOrSentinel();
1626  if (inputWidth <= 0)
1627  return failure();
1628 
1629  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1630  unsigned dropAmount = op.getAmount();
1631  if (dropAmount != unsigned(inputWidth))
1632  replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1633  rewriter);
1634  return success();
1635 }
1636 
// A subaccess with a constant index becomes a static subindex.
void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<patterns::SubaccessOfConstant>(context);
}
1641 
1642 OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1643  // If there is only one input, just return it.
1644  if (adaptor.getInputs().size() == 1)
1645  return getOperand(1);
1646 
1647  if (auto constIndex = getConstant(adaptor.getIndex())) {
1648  auto index = constIndex->getZExtValue();
1649  if (index < getInputs().size())
1650  return getInputs()[getInputs().size() - 1 - index];
1651  }
1652 
1653  return {};
1654 }
1655 
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Each input must be a subindex of the same vector, and input position i
    // must carry subindex n-1-i (inputs are in descending index order).
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
1696 
1697 //===----------------------------------------------------------------------===//
1698 // Declarations
1699 //===----------------------------------------------------------------------===//
1700 
1701 /// Scan all the uses of the specified value, checking to see if there is
1702 /// exactly one connect that has the value as its destination. This returns the
1703 /// operation if found and if all the other users are "reads" from the value.
1704 /// Returns null if there are no connects, or multiple connects to the value, or
1705 /// if the value is involved in an `AttachOp`, or if the connect isn't strict.
1706 ///
1707 /// Note that this will simply return the connect, which is located *anywhere*
1708 /// after the definition of the value. Users of this function are likely
1709 /// interested in the source side of the returned connect, the definition of
1710 /// which does likely not dominate the original value.
1711 StrictConnectOp firrtl::getSingleConnectUserOf(Value value) {
1712  StrictConnectOp connect;
1713  for (Operation *user : value.getUsers()) {
1714  // If we see an attach or aggregate sublements, just conservatively fail.
1715  if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1716  return {};
1717 
1718  if (auto aConnect = dyn_cast<FConnectLike>(user))
1719  if (aConnect.getDest() == value) {
1720  auto strictConnect = dyn_cast<StrictConnectOp>(*aConnect);
1721  // If this is not a strict connect, a second strict connect or in a
1722  // different block, fail.
1723  if (!strictConnect || (connect && connect != strictConnect) ||
1724  strictConnect->getBlock() != value.getParentBlock())
1725  return {};
1726  else
1727  connect = strictConnect;
1728  }
1729  }
1730  return connect;
1731 }
1732 
// Forward simple values through wire's and reg's.
//
// If a wire/reg has exactly one (strict) write whose source is a constant or
// (for wires) a port, replace all reads of the declaration with that source
// and delete both the declaration and the connect.
static LogicalResult canonicalizeSingleSetConnect(StrictConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Declarations whose name, annotations, symbol, or forceability must be
  // preserved cannot be deleted.
  if (hasDontTouch(connectedDecl) ||
      !AnnotationSet(connectedDecl).canBeDeleted() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  // This will be replaced with the constant source. First, make sure the
  // constant dominates all users.
  if (srcValueOp && srcValueOp != &declBlock->front())
    srcValueOp->moveBefore(&declBlock->front());

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
1792 
1793 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1794  MLIRContext *context) {
1795  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1796  context);
1797 }
1798 
1799 LogicalResult StrictConnectOp::canonicalize(StrictConnectOp op,
1800  PatternRewriter &rewriter) {
1801  // TODO: Canonicalize towards explicit extensions and flips here.
1802 
1803  // If there is a simple value connected to a foldable decl like a wire or reg,
1804  // see if we can eliminate the decl.
1805  if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
1806  return success();
1807  return failure();
1808 }
1809 
1810 //===----------------------------------------------------------------------===//
1811 // Statements
1812 //===----------------------------------------------------------------------===//
1813 
1814 /// If the specified value has an AttachOp user strictly dominating by
1815 /// "dominatingAttach" then return it.
1816 static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
1817  for (auto *user : value.getUsers()) {
1818  auto attach = dyn_cast<AttachOp>(user);
1819  if (!attach || attach == dominatedAttach)
1820  continue;
1821  if (attach->isBeforeInBlock(dominatedAttach))
1822  return attach;
1823  }
1824  return {};
1825 }
1826 
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      // Union of both operand lists, de-duplicating the shared operand.
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      rewriter.create<AttachOp>(op->getLoc(), newOperands);
      // Both original attaches are subsumed by the merged one.
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the dead wire.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire.
            newOperands.push_back(newOperand);

        rewriter.create<AttachOp>(op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
1872 
1873 /// Replaces the given op with the contents of the given single-block region.
1874 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1875  Region &region) {
1876  assert(llvm::hasSingleElement(region) && "expected single-region block");
1877  rewriter.inlineBlockBefore(&region.front(), op, {});
1878 }
1879 
LogicalResult WhenOp::canonicalize(WhenOp op, PatternRewriter &rewriter) {
  // Constant condition: inline the statically-taken region (if any), then
  // erase the when.
  if (auto constant = op.getCondition().getDefiningOp<firrtl::ConstantOp>()) {
    if (constant.getValue().isAllOnes())
      replaceOpWithRegion(rewriter, op, op.getThenRegion());
    else if (op.hasElseRegion() && !op.getElseRegion().empty())
      replaceOpWithRegion(rewriter, op, op.getElseRegion());

    rewriter.eraseOp(op);

    return success();
  }

  // Erase empty if-else block.
  if (!op.getThenBlock().empty() && op.hasElseRegion() &&
      op.getElseBlock().empty()) {
    rewriter.eraseBlock(&op.getElseBlock());
    return success();
  }

  // Erase empty whens.

  // If there is stuff in the then block, leave this operation alone.
  if (!op.getThenBlock().empty())
    return failure();

  // If not and there is no else, then this operation is just useless.
  if (!op.hasElseRegion() || op.getElseBlock().empty()) {
    rewriter.eraseOp(op);
    return success();
  }
  // Empty then with a non-empty else is left alone here (it would require
  // inverting the condition to simplify).
  return failure();
}
1912 
namespace {
// Remove private nodes. If they have interesting names, move the name to
// the source expression.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    // Keep nodes whose name, symbol, annotations, or forceability must be
    // preserved.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).canBeDeleted() || node.isForceable())
      return failure();
    auto *newOp = node.getInput().getDefiningOp();
    // Best effort, do not rename InstanceOp
    if (newOp && !isa<InstanceOp>(newOp))
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};

// Bypass nodes: forward all reads of a node to its input while keeping the
// node itself (and its name) around.
// NOTE(review): this pattern is not registered in any visible
// getCanonicalizationPatterns — confirm whether it is intentionally unused.
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    // Nodes with symbols or undeletable annotations must keep their uses; a
    // use-less node has nothing to bypass.
    if (node.getInnerSym() || !AnnotationSet(node).canBeDeleted() ||
        node.use_empty() || node.isForceable())
      return failure();
    rewriter.startOpModification(node);
    node.getResult().replaceAllUsesWith(node.getInput());
    rewriter.finalizeOpModification(node);
    return success();
  }
};

} // namespace
1953 
1954 template <typename OpTy>
1955 static LogicalResult demoteForceableIfUnused(OpTy op,
1956  PatternRewriter &rewriter) {
1957  if (!op.isForceable() || !op.getDataRef().use_empty())
1958  return failure();
1959 
1960  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
1961  return success();
1962 }
1963 
1964 // Interesting names and symbols and don't touch force nodes to stick around.
1965 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1966  SmallVectorImpl<OpFoldResult> &results) {
1967  if (!hasDroppableName())
1968  return failure();
1969  if (hasDontTouch(getResult())) // handles inner symbols
1970  return failure();
1971  if (getAnnotationsAttr() &&
1972  !AnnotationSet(getAnnotationsAttr()).canBeDeleted())
1973  return failure();
1974  if (isForceable())
1975  return failure();
1976  if (!adaptor.getInput())
1977  return failure();
1978 
1979  results.push_back(adaptor.getInput());
1980  return success();
1981 }
1982 
1983 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1984  MLIRContext *context) {
1985  results.insert<FoldNodeName>(context);
1986  results.add(demoteForceableIfUnused<NodeOp>);
1987 }
1988 
namespace {
// For a lhs, find all the writers of fields of the aggregate type. If there
// is one writer for each field, merge the writes
struct AggOneShot : public mlir::RewritePattern {
  // NOTE(review): `weight` is accepted but not forwarded as the pattern
  // benefit (0 is passed instead) — confirm whether this is intentional.
  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
      : RewritePattern(name, 0, context) {}

  // Collect the per-field written values of `lhs` if *every* field has
  // exactly one subfield/subindex write in the same scope; returns the values
  // in field order, or an empty vector on any ambiguity.
  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
    auto lhsTy = lhs->getResult(0).getType();
    if (!type_isa<BundleType, FVectorType>(lhsTy))
      return {};

    // Map from field index to the single value written to that field.
    DenseMap<uint32_t, Value> fields;
    for (Operation *user : lhs->getResult(0).getUsers()) {
      // All users must be in the same scope as the aggregate itself.
      if (user->getParentOp() != lhs->getParentOp())
        return {};
      if (auto aConnect = dyn_cast<StrictConnectOp>(user)) {
        // A write to the whole aggregate defeats the per-field merge.
        if (aConnect.getDest() == lhs->getResult(0))
          return {};
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser : subField.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
            if (aConnect.getDest() == subField) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subField.getFieldIndex())) // duplicate write
                return {};
              fields[subField.getFieldIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subfield blocks the rewrite.
          return {};
        }
      } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser : subIndex.getResult().getUsers()) {
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
            if (aConnect.getDest() == subIndex) {
              if (subuser->getParentOp() != lhs->getParentOp())
                return {};
              if (fields.count(subIndex.getIndex())) // duplicate write
                return {};
              fields[subIndex.getIndex()] = aConnect.getSrc();
            }
            continue;
          }
          // Any non-connect use of the subindex blocks the rewrite.
          return {};
        }
      } else {
        return {};
      }
    }

    // Require a written value for every element of the aggregate.
    SmallVector<Value> values;
    uint32_t total = type_isa<BundleType>(lhsTy)
                         ? type_cast<BundleType>(lhsTy).getNumElements()
                         : type_cast<FVectorType>(lhsTy).getNumElements();
    for (uint32_t i = 0; i < total; ++i) {
      if (!fields.count(i))
        return {};
      values.push_back(fields[i]);
    }
    return values;
  }

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto values = getCompleteWrite(op);
    if (values.empty())
      return failure();
    // Insert the merged write at the end of the block, after all the
    // individual field writes it replaces.
    rewriter.setInsertionPointToEnd(op->getBlock());
    auto dest = op->getResult(0);
    auto destType = dest.getType();

    // If not passive, cannot strictconnect.
    if (!type_cast<FIRRTLBaseType>(destType).isPassive())
      return failure();

    // Build the whole-aggregate value and write it in one strict connect.
    Value newVal = type_isa<BundleType>(destType)
                       ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
                                                               destType, values)
                       : rewriter.createOrFold<VectorCreateOp>(
                             op->getLoc(), destType, values);
    rewriter.createOrFold<StrictConnectOp>(op->getLoc(), dest, newVal);
    // Erase the now-redundant per-field connects.
    for (Operation *user : dest.getUsers()) {
      if (auto subIndex = dyn_cast<SubindexOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subIndex.getResult().getUsers()))
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser))
            if (aConnect.getDest() == subIndex)
              rewriter.eraseOp(aConnect);
      } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
        for (Operation *subuser :
             llvm::make_early_inc_range(subField.getResult().getUsers()))
          if (auto aConnect = dyn_cast<StrictConnectOp>(subuser))
            if (aConnect.getDest() == subField)
              rewriter.eraseOp(aConnect);
      }
    }
    return success();
  }
};

// Instantiations of the merge pattern rooted at each op kind whose result
// may be written field-by-field.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
} // namespace
2104 
2105 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2106  MLIRContext *context) {
2107  results.insert<WireAggOneShot>(context);
2108  results.add(demoteForceableIfUnused<WireOp>);
2109 }
2110 
2111 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2112  MLIRContext *context) {
2113  results.insert<SubindexAggOneShot>(context);
2114 }
2115 
2116 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2117  auto attr = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2118  if (!attr)
2119  return {};
2120  return attr[getIndex()];
2121 }
2122 
2123 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2124  auto attr = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2125  if (!attr)
2126  return {};
2127  auto index = getFieldIndex();
2128  return attr[index];
2129 }
2130 
2131 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2132  MLIRContext *context) {
2133  results.insert<SubfieldAggOneShot>(context);
2134 }
2135 
2136 static Attribute collectFields(MLIRContext *context,
2137  ArrayRef<Attribute> operands) {
2138  for (auto operand : operands)
2139  if (!operand)
2140  return {};
2141  return ArrayAttr::get(context, operands);
2142 }
2143 
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  // i.e. if every operand is a subfield of the same bundle, in field order,
  // and that bundle has exactly this type, the create is the identity.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Operand i must be subfield i of the same source bundle.
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate if all operands folded.
  return collectFields(getContext(), adaptor.getOperands());
}
2162 
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  // i.e. if every operand is a subindex of the same vector, in index order,
  // and that vector has exactly this type, the create is the identity.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Operand i must be subindex i of the same source vector.
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate if all operands folded.
  return collectFields(getContext(), adaptor.getOperands());
}
2180 
2181 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2182  if (getOperand().getType() == getType())
2183  return getOperand();
2184  return {};
2185 }
2186 
namespace {
// A register with constant reset and all connection to either itself or the
// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    // The reset value must be a constant, and the register must be deletable.
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).canBeDeleted() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The single write must be a mux of (constant, reg) in either order,
    // i.e. the register either keeps its value or takes the constant.
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    // If the high side is the constant, the low side must be the reg itself.
    if (constOp && low != reg)
      return failure();
    // Or the low side is the constant and the high side the reg.
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    // The mux constant must be exactly the reset constant.
    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
} // namespace
2243 
2244 static bool isDefinedByOneConstantOp(Value v) {
2245  if (auto c = v.getDefiningOp<ConstantOp>())
2246  return c.getValue().isOne();
2247  if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2248  return sc.getValue();
2249  return false;
2250 }
2251 
// If the reset signal is the constant one, the register is permanently held
// at its reset value: drop the register's write and replace it with a node
// of the reset value (keeping name/annotations/symbol/forceability).
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2264 
2265 void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2266  MLIRContext *context) {
2267  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2268  results.add(canonicalizeRegResetWithOneReset);
2269  results.add(demoteForceableIfUnused<RegResetOp>);
2270 }
2271 
2272 // Returns the value connected to a port, if there is only one.
2273 static Value getPortFieldValue(Value port, StringRef name) {
2274  auto portTy = type_cast<BundleType>(port.getType());
2275  auto fieldIndex = portTy.getElementIndex(name);
2276  assert(fieldIndex && "missing field on memory port");
2277 
2278  Value value = {};
2279  for (auto *op : port.getUsers()) {
2280  auto portAccess = cast<SubfieldOp>(op);
2281  if (fieldIndex != portAccess.getFieldIndex())
2282  continue;
2283  auto conn = getSingleConnectUserOf(portAccess);
2284  if (!conn || value)
2285  return {};
2286  value = conn.getSrc();
2287  }
2288  return value;
2289 }
2290 
2291 // Returns true if the enable field of a port is set to false.
2292 static bool isPortDisabled(Value port) {
2293  auto value = getPortFieldValue(port, "en");
2294  if (!value)
2295  return false;
2296  auto portConst = value.getDefiningOp<ConstantOp>();
2297  if (!portConst)
2298  return false;
2299  return portConst.getValue().isZero();
2300 }
2301 
2302 // Returns true if the data output is unused.
2303 static bool isPortUnused(Value port, StringRef data) {
2304  auto portTy = type_cast<BundleType>(port.getType());
2305  auto fieldIndex = portTy.getElementIndex(data);
2306  assert(fieldIndex && "missing enable flag on memory port");
2307 
2308  for (auto *op : port.getUsers()) {
2309  auto portAccess = cast<SubfieldOp>(op);
2310  if (fieldIndex != portAccess.getFieldIndex())
2311  continue;
2312  if (!portAccess.use_empty())
2313  return false;
2314  }
2315 
2316  return true;
2317 }
2318 
// Replaces all uses of the named field of a port with the given value,
// erasing the subfield accesses. (The previous comment here was a copy-paste
// of getPortFieldValue's description.)
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2334 
// Remove accesses to a port which is being erased. Reads that remain are
// redirected to dummy never-written registers.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = rewriter.create<SpecialConstantOp>(
          port.getLoc(), ClockType::get(rewriter.getContext()), false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole value. If the port
  // is used in its entirety, replace it with a dummy register. Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      auto ty = port.getType();
      auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
      port.replaceAllUsesWith(reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    // Erase every write targeting this field access.
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2386 
2387 namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only applies to memories with an integer data type of width zero.
    if (!firrtl::type_isa<IntType>(mem.getDataType()) ||
        mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        SubfieldOp sfop = cast<SubfieldOp>(user);
        StringRef fieldName = sfop.getFieldName();
        // Each field access becomes a free-standing wire of the same type.
        auto wire = replaceOpWithNewOpAndCopyName<WireOp>(
                        rewriter, sfop, sfop.getResult().getType())
                        .getResult();
        if (fieldName.ends_with("data")) {
          // Make sure to write data ports.
          auto zero = rewriter.create<firrtl::ConstantOp>(
              wire.getLoc(), firrtl::type_cast<IntType>(wire.getType()),
              APInt::getZero(0));
          rewriter.create<StrictConnectOp>(wire.getLoc(), wire, zero);
        }
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2430 
// If memory has no write ports and no file initialization, eliminate it.
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // The memory must be exclusively read or exclusively written; any mix,
    // or any debug/read-write port, blocks the rewrite.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2474 
// Eliminate the dead ports of memories.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port was dead, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = rewriter.create<MemOp>(
          mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Erase the dead ports (see erasePort: remaining reads get dummy
    // registers) and reroute the live ports to the new memory.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        port.replaceAllUsesWith(newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2544 
// Rewrite write-only read-write ports to write ports.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Rebuild the memory, demoting each write-only read-write port to a
    // plain write port; all other ports keep their types.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = rewriter.create<MemOp>(
        mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
        mem.getName(), mem.getNameKind(), mem.getAnnotations(),
        rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
        mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
                                                     newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        // Map the read-write field names onto the write port's field names.
        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        result.replaceAllUsesWith(newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2636 
// Narrow the data width of memories whose bits are only partially used.
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  // Narrows a memory's data type to only the bit ranges that are actually
  // observed by its readers, compacting the used ranges together. Reads are
  // rewritten to select from the compacted layout; writes are rewritten to
  // concatenate only the surviving slices of the original data.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // Bit i is set when some reader observes bit i of the stored data.
    llvm::SmallBitVector usedBits(width);
    // Maps the low bit of each read-side BitsPrimOp to its position in the
    // compacted layout. Filled with placeholders here, finalized below.
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits) {
            // Any non-bit-select user forces all bits to be considered live.
            usedBits.set();
            continue;
          }

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // Placeholder; the compacted offset is computed later.
          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of strict connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<StrictConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        findReadUsers(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        findReadUsers(port, "rdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Perform the transformation is there are some bits missing. Unused
    // memories are handled in a different canonicalizer.
    if (usedBits.all() || usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones.
    // `ranges` records each maximal run [start, end] of used bits in the
    // original layout; `newWidth` accumulates the compacted width.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        // Finalize the placeholder offsets recorded by findReadUsers.
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      // Advance to the first used bit after this run (if any remains).
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          rewriter.create<SubfieldOp>(port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        // Skip the access we just created and accesses to other fields.
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      // Same extraction length, rebased at the compacted offset.
      rewriter.replaceOpWithNewOp<BitsPrimOp>(
          readOp, readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
    }

    // Rewrite the writes into a concatenation of slices.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      // Ranges are visited low-to-high; each new slice is prepended as the
      // most-significant part of the concatenation.
      Value catOfSlices;
      for (auto &[start, end] : ranges) {
        Value slice =
            rewriter.create<BitsPrimOp>(writeOp.getLoc(), source, end, start);
        if (catOfSlices) {
          catOfSlices =
              rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
        } else {
          catOfSlices = slice;
        }
      }
      rewriter.replaceOpWithNewOp<StrictConnectOp>(writeOp, writeOp.getDest(),
                                                   catOfSlices);
    }

    return success();
  }
};
2844 
2845 // Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  // Replaces a depth-1 memory with a single register: reads become direct
  // uses of the register, writes become pipelined, mask-aware muxes feeding
  // the register's next value.
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto memModule = mem->getParentOfType<FModuleOp>();

    // Find the clock of the register-to-be, all write ports should share it.
    Value clock;
    // Connects and subfield accesses on the memory ports, collected here so
    // they can be erased once the register replaces the memory.
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Record accesses to the listed fields; bail out if any field is used
      // by something other than a connect (we cannot reason about it).
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All ports that can write must drive the same clock as the register.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }

    // Create a new register to store the data.
    auto ty = mem.getDataType();
    rewriter.setInsertionPointAfterValue(clock);
    auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
                   .getResult();

    // Helper to insert a given number of pipeline stages through registers.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }

        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<StrictConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    // A write latency of N requires N-1 extra register stages on the inputs.
    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        rewriter.setInsertionPointAfterValue(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "data", reg);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "rdata", reg);

        // Create a write enable and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        // The effective write enable of a read-write port is en && wmode.
        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");
        rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder.
    rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
    // Fold the writes into a mux chain: later ports take priority over
    // earlier ones, and the register feeds back when no write is enabled.
    Value next = reg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        // Chunks are built low-to-high; concatenate the new chunk on top.
        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    rewriter.create<StrictConnectOp>(reg.getLoc(), reg, next);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3035 } // namespace
3036 
3037 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3038  MLIRContext *context) {
3039  results
3040  .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3041  FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3042  context);
3043 }
3044 
3045 //===----------------------------------------------------------------------===//
3046 // Declarations
3047 //===----------------------------------------------------------------------===//
3048 
3049 // Turn synchronous reset looking register updates to registers with resets.
3050 // Also, const prop registers that are driven by a mux tree containing only
3051 // instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The register must be driven by a mux for either rewrite to apply.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Turn the register into a RegReset, using the mux select as the reset
    // signal and the constant as the reset value. Preserve dialect attrs.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Rewire the connect: a constant register is driven by the constant
  // directly, otherwise the new RegReset takes the mux's non-reset input.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3119 
3120 LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3121  if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3122  succeeded(foldHiddenReset(op, rewriter)))
3123  return success();
3124 
3125  if (succeeded(demoteForceableIfUnused(op, rewriter)))
3126  return success();
3127 
3128  return failure();
3129 }
3130 
3131 //===----------------------------------------------------------------------===//
3132 // Verification Ops.
3133 //===----------------------------------------------------------------------===//
3134 
3135 static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3136  Value enable,
3137  PatternRewriter &rewriter,
3138  bool eraseIfZero) {
3139  // If the verification op is never enabled, delete it.
3140  if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3141  if (constant.getValue().isZero()) {
3142  rewriter.eraseOp(op);
3143  return success();
3144  }
3145  }
3146 
3147  // If the verification op is never triggered, delete it.
3148  if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3149  if (constant.getValue().isZero() == eraseIfZero) {
3150  rewriter.eraseOp(op);
3151  return success();
3152  }
3153  }
3154 
3155  return failure();
3156 }
3157 
3158 template <class Op, bool EraseIfZero = false>
3159 static LogicalResult canonicalizeImmediateVerifOp(Op op,
3160  PatternRewriter &rewriter) {
3161  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3162  EraseIfZero);
3163 }
3164 
3165 void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3166  MLIRContext *context) {
3167  results.add(canonicalizeImmediateVerifOp<AssertOp>);
3168 }
3169 
3170 void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3171  MLIRContext *context) {
3172  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3173 }
3174 
3175 void UnclockedAssumeIntrinsicOp::getCanonicalizationPatterns(
3176  RewritePatternSet &results, MLIRContext *context) {
3177  results.add(canonicalizeImmediateVerifOp<UnclockedAssumeIntrinsicOp>);
3178 }
3179 
3180 void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3181  MLIRContext *context) {
3182  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3183 }
3184 
3185 //===----------------------------------------------------------------------===//
3186 // InvalidValueOp
3187 //===----------------------------------------------------------------------===//
3188 
3189 LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3190  PatternRewriter &rewriter) {
3191  // Remove `InvalidValueOp`s with no uses.
3192  if (op.use_empty()) {
3193  rewriter.eraseOp(op);
3194  return success();
3195  }
3196  return failure();
3197 }
3198 
3199 //===----------------------------------------------------------------------===//
3200 // ClockGateIntrinsicOp
3201 //===----------------------------------------------------------------------===//
3202 
3203 OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
3204  // Forward the clock if one of the enables is always true.
3205  if (isConstantOne(adaptor.getEnable()) ||
3206  isConstantOne(adaptor.getTestEnable()))
3207  return getInput();
3208 
3209  // Fold to a constant zero clock if the enables are always false.
3210  if (isConstantZero(adaptor.getEnable()) &&
3211  (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
3212  return BoolAttr::get(getContext(), false);
3213 
3214  // Forward constant zero clocks.
3215  if (isConstantZero(adaptor.getInput()))
3216  return BoolAttr::get(getContext(), false);
3217 
3218  return {};
3219 }
3220 
3221 LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3222  PatternRewriter &rewriter) {
3223  // Remove constant false test enable.
3224  if (auto testEnable = op.getTestEnable()) {
3225  if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3226  if (constOp.getValue().isZero()) {
3227  rewriter.modifyOpInPlace(op,
3228  [&] { op.getTestEnableMutable().clear(); });
3229  return success();
3230  }
3231  }
3232  }
3233 
3234  return failure();
3235 }
3236 
3237 //===----------------------------------------------------------------------===//
3238 // Reference Ops.
3239 //===----------------------------------------------------------------------===//
3240 
3241 // refresolve(forceable.ref) -> forceable.data
3242 static LogicalResult
3243 canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3244  auto forceable = op.getRef().getDefiningOp<Forceable>();
3245  if (!forceable || !forceable.isForceable() ||
3246  op.getRef() != forceable.getDataRef() ||
3247  op.getType() != forceable.getDataType())
3248  return failure();
3249  rewriter.replaceAllUsesWith(op, forceable.getData());
3250  return success();
3251 }
3252 
3253 void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3254  MLIRContext *context) {
3255  results.insert<patterns::RefResolveOfRefSend>(context);
3256  results.insert(canonicalizeRefResolveOfForceable);
3257 }
3258 
3259 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3260  // RefCast is unnecessary if types match.
3261  if (getInput().getType() == getType())
3262  return getInput();
3263  return {};
3264 }
3265 
3266 static bool isConstantZero(Value operand) {
3267  auto constOp = operand.getDefiningOp<ConstantOp>();
3268  return constOp && constOp.getValue().isZero();
3269 }
3270 
3271 template <typename Op>
3272 static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3273  if (isConstantZero(op.getPredicate())) {
3274  rewriter.eraseOp(op);
3275  return success();
3276  }
3277  return failure();
3278 }
3279 
3280 void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
3281  MLIRContext *context) {
3282  results.add(eraseIfPredFalse<RefForceOp>);
3283 }
3284 void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
3285  MLIRContext *context) {
3286  results.add(eraseIfPredFalse<RefForceInitialOp>);
3287 }
3288 void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
3289  MLIRContext *context) {
3290  results.add(eraseIfPredFalse<RefReleaseOp>);
3291 }
3292 void RefReleaseInitialOp::getCanonicalizationPatterns(
3293  RewritePatternSet &results, MLIRContext *context) {
3294  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
3295 }
3296 
3297 //===----------------------------------------------------------------------===//
3298 // HasBeenResetIntrinsicOp
3299 //===----------------------------------------------------------------------===//
3300 
3301 OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
3302  // The folds in here should reflect the ones for `verif::HasBeenResetOp`.
3303 
3304  // Fold to zero if the reset is a constant. In this case the op is either
3305  // permanently in reset or never resets. Both mean that the reset never
3306  // finishes, so this op never returns true.
3307  if (adaptor.getReset())
3308  return getIntZerosAttr(UIntType::get(getContext(), 1));
3309 
3310  // Fold to zero if the clock is a constant and the reset is synchronous. In
3311  // that case the reset will never be started.
3312  if (isUInt1(getReset().getType()) && adaptor.getClock())
3313  return getIntZerosAttr(UIntType::get(getContext(), 1));
3314 
3315  return {};
3316 }
3317 
3318 //===----------------------------------------------------------------------===//
3319 // FPGAProbeIntrinsicOp
3320 //===----------------------------------------------------------------------===//
3321 
3322 static bool isTypeEmpty(FIRRTLType type) {
3324  .Case<FVectorType>(
3325  [&](auto ty) -> bool { return isTypeEmpty(ty.getElementType()); })
3326  .Case<BundleType>([&](auto ty) -> bool {
3327  for (auto elem : ty.getElements())
3328  if (!isTypeEmpty(elem.type))
3329  return false;
3330  return true;
3331  })
3332  .Case<IntType>([&](auto ty) { return ty.getWidth() == 0; })
3333  .Default([](auto) -> bool { return false; });
3334 }
3335 
3336 LogicalResult FPGAProbeIntrinsicOp::canonicalize(FPGAProbeIntrinsicOp op,
3337  PatternRewriter &rewriter) {
3338  auto firrtlTy = type_dyn_cast<FIRRTLType>(op.getInput().getType());
3339  if (!firrtlTy)
3340  return failure();
3341 
3342  if (!isTypeEmpty(firrtlTy))
3343  return failure();
3344 
3345  rewriter.eraseOp(op);
3346  return success();
3347 }
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
Definition: FIRRTLFolds.cpp:71
static LogicalResult canonicalizeSingleSetConnect(StrictConnectOp op, PatternRewriter &rewriter)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static void erasePort(PatternRewriter &rewriter, Value port)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region.
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
Definition: FIRRTLFolds.cpp:91
static bool isTypeEmpty(FIRRTLType type)
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
Definition: FIRRTLFolds.cpp:82
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:32
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:48
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
int32_t width
Definition: FIRRTL.cpp:36
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:518
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:528
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:294
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:273
bool hasWidth() const
Return true if this integer type has a known width.
Definition: FIRRTLTypes.h:281
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:287
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:18
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
Definition: Naming.cpp:47
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition: FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition: FoldUtils.h:34
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:20