// CIRCT 18.0.0git — FIRRTLFolds.cpp (source listing)
1 //===- FIRRTLFolds.cpp - Implement folds and canonicalizations for ops ----===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implement the folding and canonicalizations for FIRRTL ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "circt/Support/APInt.h"
18 #include "circt/Support/LLVM.h"
19 #include "mlir/IR/Matchers.h"
20 #include "mlir/IR/PatternMatch.h"
21 #include "llvm/ADT/APSInt.h"
22 #include "llvm/ADT/SmallPtrSet.h"
23 #include "llvm/ADT/StringExtras.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 
26 using namespace circt;
27 using namespace firrtl;
28 
29 // Drop writes to old and pass through passthrough to make patterns easier to
30 // write.
31 static Value dropWrite(PatternRewriter &rewriter, OpResult old,
32  Value passthrough) {
33  SmallPtrSet<Operation *, 8> users;
34  for (auto *user : old.getUsers())
35  users.insert(user);
36  for (Operation *user : users)
37  if (auto connect = dyn_cast<FConnectLike>(user))
38  if (connect.getDest() == old)
39  rewriter.eraseOp(user);
40  return passthrough;
41 }
42 
43 // Move a name hint from a soon to be deleted operation to a new operation.
44 // Pass through the new operation to make patterns easier to write. This cannot
45 // move a name to a port (block argument), doing so would require rewriting all
46 // instance sites as well as the module.
47 static Value moveNameHint(OpResult old, Value passthrough) {
48  Operation *op = passthrough.getDefiningOp();
49  // This should handle ports, but it isn't clear we can change those in
50  // canonicalizers.
51  assert(op && "passthrough must be an operation");
52  Operation *oldOp = old.getOwner();
53  auto name = oldOp->getAttrOfType<StringAttr>("name");
54  if (name && !name.getValue().empty())
55  op->setAttr("name", name);
56  return passthrough;
57 }
58 
59 // Declarative canonicalization patterns
60 namespace circt {
61 namespace firrtl {
62 namespace patterns {
63 #include "circt/Dialect/FIRRTL/FIRRTLCanonicalization.h.inc"
64 } // namespace patterns
65 } // namespace firrtl
66 } // namespace circt
67 
68 /// Return true if this operation's operands and results all have a known width.
69 /// This only works for integer types.
70 static bool hasKnownWidthIntTypes(Operation *op) {
71  auto resultType = type_cast<IntType>(op->getResult(0).getType());
72  if (!resultType.hasWidth())
73  return false;
74  for (Value operand : op->getOperands())
75  if (!type_cast<IntType>(operand.getType()).hasWidth())
76  return false;
77  return true;
78 }
79 
80 /// Return true if this value is 1 bit UInt.
81 static bool isUInt1(Type type) {
82  auto t = type_dyn_cast<UIntType>(type);
83  if (!t || !t.hasWidth() || t.getWidth() != 1)
84  return false;
85  return true;
86 }
87 
88 // Heuristic to pick the best name.
89 // Good names are not useless, don't start with an underscore, minimize
90 // underscores in them, and are short. This function deterministically favors
91 // the second name on ties.
92 static StringRef chooseName(StringRef a, StringRef b) {
93  if (a.empty())
94  return b;
95  if (b.empty())
96  return a;
97  if (isUselessName(a))
98  return b;
99  if (isUselessName(b))
100  return a;
101  if (a.starts_with("_"))
102  return b;
103  if (b.starts_with("_"))
104  return a;
105  if (b.count('_') < a.count('_'))
106  return b;
107  if (b.count('_') > a.count('_'))
108  return a;
109  if (a.size() > b.size())
110  return b;
111  return a;
112 }
113 
114 /// Set the name of an op based on the best of two names: The current name, and
115 /// the name passed in.
116 static void updateName(PatternRewriter &rewriter, Operation *op,
117  StringAttr name) {
118  // Should never rename InstanceOp
119  assert(!isa<InstanceOp>(op));
120  if (!name || name.getValue().empty())
121  return;
122  auto newName = name.getValue(); // old name is interesting
123  auto newOpName = op->getAttrOfType<StringAttr>("name");
124  // new name might not be interesting
125  if (newOpName)
126  newName = chooseName(newOpName.getValue(), name.getValue());
127  // Only update if needed
128  if (!newOpName || newOpName.getValue() != newName)
129  rewriter.updateRootInPlace(
130  op, [&] { op->setAttr("name", rewriter.getStringAttr(newName)); });
131 }
132 
133 /// A wrapper of `PatternRewriter::replaceOp` to propagate "name" attribute.
134 /// If a replaced op has a "name" attribute, this function propagates the name
135 /// to the new value.
136 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
137  Value newValue) {
138  if (auto *newOp = newValue.getDefiningOp()) {
139  auto name = op->getAttrOfType<StringAttr>("name");
140  updateName(rewriter, newOp, name);
141  }
142  rewriter.replaceOp(op, newValue);
143 }
144 
145 /// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate "name"
146 /// attribute. If a replaced op has a "name" attribute, this function propagates
147 /// the name to the new value.
148 template <typename OpTy, typename... Args>
149 static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
150  Operation *op, Args &&...args) {
151  auto name = op->getAttrOfType<StringAttr>("name");
152  auto newOp =
153  rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
154  updateName(rewriter, newOp, name);
155  return newOp;
156 }
157 
158 /// Return true if this is a useless temporary name produced by FIRRTL. We
159 /// drop these as they don't convey semantic meaning.
160 bool circt::firrtl::isUselessName(StringRef name) {
161  if (name.empty())
162  return true;
163  // Ignore _.*
164  return name.startswith("_T") || name.startswith("_WIRE");
165 }
166 
167 /// Return true if the name is droppable. Note that this is different from
168 /// `isUselessName` because non-useless names may be also droppable.
169 bool circt::firrtl::hasDroppableName(Operation *op) {
170  if (auto namableOp = dyn_cast<firrtl::FNamableOp>(op))
171  return namableOp.hasDroppableName();
172  return false;
173 }
174 
175 /// Implicitly replace the operand to a constant folding operation with a const
176 /// 0 in case the operand is non-constant but has a bit width 0, or if the
177 /// operand is an invalid value.
178 ///
179 /// This makes constant folding significantly easier, as we can simply pass the
180 /// operands to an operation through this function to appropriately replace any
181 /// zero-width dynamic values or invalid values with a constant of value 0.
182 static std::optional<APSInt>
183 getExtendedConstant(Value operand, Attribute constant, int32_t destWidth) {
184  assert(type_cast<IntType>(operand.getType()) &&
185  "getExtendedConstant is limited to integer types");
186 
187  // We never support constant folding to unknown width values.
188  if (destWidth < 0)
189  return {};
190 
191  // Extension signedness follows the operand sign.
192  if (IntegerAttr result = dyn_cast_or_null<IntegerAttr>(constant))
193  return extOrTruncZeroWidth(result.getAPSInt(), destWidth);
194 
195  // If the operand is zero bits, then we can return a zero of the result
196  // type.
197  if (type_cast<IntType>(operand.getType()).getWidth() == 0)
198  return APSInt(destWidth,
199  type_cast<IntType>(operand.getType()).isUnsigned());
200  return {};
201 }
202 
203 /// Determine the value of a constant operand for the sake of constant folding.
204 static std::optional<APSInt> getConstant(Attribute operand) {
205  if (!operand)
206  return {};
207  if (auto attr = dyn_cast<BoolAttr>(operand))
208  return APSInt(APInt(1, attr.getValue()));
209  if (auto attr = dyn_cast<IntegerAttr>(operand))
210  return attr.getAPSInt();
211  return {};
212 }
213 
214 /// Determine whether a constant operand is a zero value for the sake of
215 /// constant folding. This considers `invalidvalue` to be zero.
216 static bool isConstantZero(Attribute operand) {
217  if (auto cst = getConstant(operand))
218  return cst->isZero();
219  return false;
220 }
221 
222 /// Determine whether a constant operand is a one value for the sake of constant
223 /// folding.
224 static bool isConstantOne(Attribute operand) {
225  if (auto cst = getConstant(operand))
226  return cst->isOne();
227  return false;
228 }
229 
/// This is the policy for folding, which depends on the sort of operator we're
/// processing.
///
/// Note: the `DivideOrShift` enumerator is referenced throughout the fold
/// hooks below (e.g. in constFoldFIRRTLBinaryOp's truncation step) and must
/// be present here.
enum class BinOpKind {
  Normal,        ///< Operands are extended to the result width.
  Compare,       ///< Operands are extended to the widest operand; result is i1.
  DivideOrShift, ///< Compute at a wide width, then truncate to the result.
};
237 
238 /// Applies the constant folding function `calculate` to the given operands.
239 ///
240 /// Sign or zero extends the operands appropriately to the bitwidth of the
241 /// result type if \p useDstWidth is true, else to the larger of the two operand
242 /// bit widths and depending on whether the operation is to be performed on
243 /// signed or unsigned operands.
244 static Attribute constFoldFIRRTLBinaryOp(
245  Operation *op, ArrayRef<Attribute> operands, BinOpKind opKind,
246  const function_ref<APInt(const APSInt &, const APSInt &)> &calculate) {
247  assert(operands.size() == 2 && "binary op takes two operands");
248 
249  // We cannot fold something to an unknown width.
250  auto resultType = type_cast<IntType>(op->getResult(0).getType());
251  if (resultType.getWidthOrSentinel() < 0)
252  return {};
253 
254  // Any binary op returning i0 is 0.
255  if (resultType.getWidthOrSentinel() == 0)
256  return getIntAttr(resultType, APInt(0, 0, resultType.isSigned()));
257 
258  // Determine the operand widths. This is either dictated by the operand type,
259  // or if that type is an unsized integer, by the actual bits necessary to
260  // represent the constant value.
261  auto lhsWidth =
262  type_cast<IntType>(op->getOperand(0).getType()).getWidthOrSentinel();
263  auto rhsWidth =
264  type_cast<IntType>(op->getOperand(1).getType()).getWidthOrSentinel();
265  if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>())
266  lhsWidth = std::max<int32_t>(lhsWidth, lhs.getValue().getBitWidth());
267  if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
268  rhsWidth = std::max<int32_t>(rhsWidth, rhs.getValue().getBitWidth());
269 
270  // Compares extend the operands to the widest of the operand types, not to the
271  // result type.
272  int32_t operandWidth;
273  switch (opKind) {
274  case BinOpKind::Normal:
275  operandWidth = resultType.getWidthOrSentinel();
276  break;
277  case BinOpKind::Compare:
278  // Compares compute with the widest operand, not at the destination type
279  // (which is always i1).
280  operandWidth = std::max(1, std::max(lhsWidth, rhsWidth));
281  break;
283  operandWidth =
284  std::max(std::max(lhsWidth, rhsWidth), resultType.getWidthOrSentinel());
285  break;
286  }
287 
288  auto lhs = getExtendedConstant(op->getOperand(0), operands[0], operandWidth);
289  if (!lhs)
290  return {};
291  auto rhs = getExtendedConstant(op->getOperand(1), operands[1], operandWidth);
292  if (!rhs)
293  return {};
294 
295  APInt resultValue = calculate(*lhs, *rhs);
296 
297  // If the result type is smaller than the computation then we need to
298  // narrow the constant after the calculation.
299  if (opKind == BinOpKind::DivideOrShift)
300  resultValue = resultValue.trunc(resultType.getWidthOrSentinel());
301 
302  assert((unsigned)resultType.getWidthOrSentinel() ==
303  resultValue.getBitWidth());
304  return getIntAttr(resultType, resultValue);
305 }
306 
/// Applies the canonicalization function `canonicalize` to the given operation.
///
/// Determines which (if any) of the operation's operands are constants, and
/// provides them as arguments to the callback function. Any `invalidvalue` in
/// the input is mapped to a constant zero. The value returned from the callback
/// is used as the replacement for `op`, and an additional pad operation is
/// inserted if necessary. Does nothing if the result of `op` is of unknown
/// width, in which case the necessity of a pad cannot be determined.
static LogicalResult canonicalizePrimOp(
    Operation *op, PatternRewriter &rewriter,
    const function_ref<OpFoldResult(ArrayRef<Attribute>)> &canonicalize) {
  // Can only operate on FIRRTL primitive operations.
  if (op->getNumResults() != 1)
    return failure();
  auto type = type_dyn_cast<FIRRTLBaseType>(op->getResult(0).getType());
  if (!type)
    return failure();

  // Can only operate on operations with a known result width.
  auto width = type.getBitWidthOrSentinel();
  if (width < 0)
    return failure();

  // Determine which of the operands are constants.
  // Non-constant operands are passed to the callback as null attributes.
  SmallVector<Attribute, 3> constOperands;
  constOperands.reserve(op->getNumOperands());
  for (auto operand : op->getOperands()) {
    Attribute attr;
    // Only ConstantOp / SpecialConstantOp definers contribute a value.
    if (auto *defOp = operand.getDefiningOp())
      TypeSwitch<Operation *>(defOp).Case<ConstantOp, SpecialConstantOp>(
          [&](auto op) { attr = op.getValueAttr(); });
    constOperands.push_back(attr);
  }

  // Perform the canonicalization and materialize the result if it is a
  // constant.
  auto result = canonicalize(constOperands);
  if (!result)
    return failure();
  Value resultValue;
  // An attribute result is materialized through the dialect's constant
  // materializer; a value result is used directly.
  if (auto cst = dyn_cast<Attribute>(result))
    resultValue = op->getDialect()
                      ->materializeConstant(rewriter, cst, type, op->getLoc())
                      ->getResult(0);
  else
    resultValue = result.get<Value>();

  // Insert a pad if the type widths disagree.
  if (width !=
      type_cast<FIRRTLBaseType>(resultValue.getType()).getBitWidthOrSentinel())
    resultValue = rewriter.create<PadPrimOp>(op->getLoc(), resultValue, width);

  // Insert a cast if this is a uint vs. sint or vice versa.
  if (type_isa<SIntType>(type) && type_isa<UIntType>(resultValue.getType()))
    resultValue = rewriter.create<AsSIntPrimOp>(op->getLoc(), resultValue);
  else if (type_isa<UIntType>(type) &&
           type_isa<SIntType>(resultValue.getType()))
    resultValue = rewriter.create<AsUIntPrimOp>(op->getLoc(), resultValue);

  // After pad/cast fixups the types must agree; also propagate any name hint.
  assert(type == resultValue.getType() && "canonicalization changed type");
  replaceOpAndCopyName(rewriter, op, resultValue);
  return success();
}
370 
371 /// Get the largest unsigned value of a given bit width. Returns a 1-bit zero
372 /// value if `bitWidth` is 0.
373 static APInt getMaxUnsignedValue(unsigned bitWidth) {
374  return bitWidth > 0 ? APInt::getMaxValue(bitWidth) : APInt();
375 }
376 
377 /// Get the smallest signed value of a given bit width. Returns a 1-bit zero
378 /// value if `bitWidth` is 0.
379 static APInt getMinSignedValue(unsigned bitWidth) {
380  return bitWidth > 0 ? APInt::getSignedMinValue(bitWidth) : APInt();
381 }
382 
383 /// Get the largest signed value of a given bit width. Returns a 1-bit zero
384 /// value if `bitWidth` is 0.
385 static APInt getMaxSignedValue(unsigned bitWidth) {
386  return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
387 }
388 
389 //===----------------------------------------------------------------------===//
390 // Fold Hooks
391 //===----------------------------------------------------------------------===//
392 
// A constant folds to its own value attribute.
OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
397 
// A special constant (clock/reset) folds to its own value attribute.
OpFoldResult SpecialConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
402 
// An aggregate constant folds to its fields attribute.
OpFoldResult AggregateConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getFieldsAttr();
}
407 
// A string constant folds to its own value attribute.
OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
412 
// A FIRRTL integer constant folds to its own value attribute.
OpFoldResult FIntegerConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
417 
// A boolean constant folds to its own value attribute.
OpFoldResult BoolConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
422 
// A double constant folds to its own value attribute.
OpFoldResult DoubleConstantOp::fold(FoldAdaptor adaptor) {
  assert(adaptor.getOperands().empty() && "constant has no operands");
  return getValueAttr();
}
427 
428 //===----------------------------------------------------------------------===//
429 // Binary Operators
430 //===----------------------------------------------------------------------===//
431 
432 OpFoldResult AddPrimOp::fold(FoldAdaptor adaptor) {
434  *this, adaptor.getOperands(), BinOpKind::Normal,
435  [=](const APSInt &a, const APSInt &b) { return a + b; });
436 }
437 
438 void AddPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
439  MLIRContext *context) {
440  results.insert<patterns::moveConstAdd, patterns::AddOfZero,
441  patterns::AddOfSelf, patterns::AddOfPad>(context);
442 }
443 
444 OpFoldResult SubPrimOp::fold(FoldAdaptor adaptor) {
446  *this, adaptor.getOperands(), BinOpKind::Normal,
447  [=](const APSInt &a, const APSInt &b) { return a - b; });
448 }
449 
450 void SubPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
451  MLIRContext *context) {
452  results.insert<patterns::SubOfZero, patterns::SubFromZeroSigned,
453  patterns::SubFromZeroUnsigned, patterns::SubOfSelf,
454  patterns::SubOfPadL, patterns::SubOfPadR>(context);
455 }
456 
457 OpFoldResult MulPrimOp::fold(FoldAdaptor adaptor) {
458  // mul(x, 0) -> 0
459  //
460  // This is legal because it aligns with the Scala FIRRTL Compiler
461  // interpretation of lowering invalid to constant zero before constant
462  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
463  // multiplication this way and will emit "x * 0".
464  if (isConstantZero(adaptor.getRhs()) || isConstantZero(adaptor.getLhs()))
465  return getIntZerosAttr(getType());
466 
468  *this, adaptor.getOperands(), BinOpKind::Normal,
469  [=](const APSInt &a, const APSInt &b) { return a * b; });
470 }
471 
472 OpFoldResult DivPrimOp::fold(FoldAdaptor adaptor) {
473  /// div(x, x) -> 1
474  ///
475  /// Division by zero is undefined in the FIRRTL specification. This fold
476  /// exploits that fact to optimize self division to one. Note: this should
477  /// supersede any division with invalid or zero. Division of invalid by
478  /// invalid should be one.
479  if (getLhs() == getRhs()) {
480  auto width = getType().get().getWidthOrSentinel();
481  if (width == -1)
482  width = 2;
483  // Only fold if we have at least 1 bit of width to represent the `1` value.
484  if (width != 0)
485  return getIntAttr(getType(), APInt(width, 1));
486  }
487 
488  // div(0, x) -> 0
489  //
490  // This is legal because it aligns with the Scala FIRRTL Compiler
491  // interpretation of lowering invalid to constant zero before constant
492  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
493  // division this way and will emit "0 / x".
494  if (isConstantZero(adaptor.getLhs()) && !isConstantZero(adaptor.getRhs()))
495  return getIntZerosAttr(getType());
496 
497  /// div(x, 1) -> x : (uint, uint) -> uint
498  ///
499  /// UInt division by one returns the numerator. SInt division can't
500  /// be folded here because it increases the return type bitwidth by
501  /// one and requires sign extension (a new op).
502  if (auto rhsCst = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>())
503  if (rhsCst.getValue().isOne() && getLhs().getType() == getType())
504  return getLhs();
505 
507  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
508  [=](const APSInt &a, const APSInt &b) -> APInt {
509  if (!!b)
510  return a / b;
511  return APInt(a.getBitWidth(), 0);
512  });
513 }
514 
515 OpFoldResult RemPrimOp::fold(FoldAdaptor adaptor) {
516  // rem(x, x) -> 0
517  //
518  // Division by zero is undefined in the FIRRTL specification. This fold
519  // exploits that fact to optimize self division remainder to zero. Note:
520  // this should supersede any division with invalid or zero. Remainder of
521  // division of invalid by invalid should be zero.
522  if (getLhs() == getRhs())
523  return getIntZerosAttr(getType());
524 
525  // rem(0, x) -> 0
526  //
527  // This is legal because it aligns with the Scala FIRRTL Compiler
528  // interpretation of lowering invalid to constant zero before constant
529  // propagation. Note: the Scala FIRRTL Compiler does NOT currently optimize
530  // division this way and will emit "0 % x".
531  if (isConstantZero(adaptor.getLhs()))
532  return getIntZerosAttr(getType());
533 
535  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
536  [=](const APSInt &a, const APSInt &b) -> APInt {
537  if (!!b)
538  return a % b;
539  return APInt(a.getBitWidth(), 0);
540  });
541 }
542 
543 OpFoldResult DShlPrimOp::fold(FoldAdaptor adaptor) {
545  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
546  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
547 }
548 
549 OpFoldResult DShlwPrimOp::fold(FoldAdaptor adaptor) {
551  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
552  [=](const APSInt &a, const APSInt &b) -> APInt { return a.shl(b); });
553 }
554 
555 OpFoldResult DShrPrimOp::fold(FoldAdaptor adaptor) {
557  *this, adaptor.getOperands(), BinOpKind::DivideOrShift,
558  [=](const APSInt &a, const APSInt &b) -> APInt {
559  return getType().get().isUnsigned() || !a.getBitWidth() ? a.lshr(b)
560  : a.ashr(b);
561  });
562 }
563 
564 // TODO: Move to DRR.
565 OpFoldResult AndPrimOp::fold(FoldAdaptor adaptor) {
566  if (auto rhsCst = getConstant(adaptor.getRhs())) {
567  /// and(x, 0) -> 0, 0 is largest or is implicit zero extended
568  if (rhsCst->isZero())
569  return getIntZerosAttr(getType());
570 
571  /// and(x, -1) -> x
572  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
573  getRhs().getType() == getType())
574  return getLhs();
575  }
576 
577  if (auto lhsCst = getConstant(adaptor.getLhs())) {
578  /// and(0, x) -> 0, 0 is largest or is implicit zero extended
579  if (lhsCst->isZero())
580  return getIntZerosAttr(getType());
581 
582  /// and(-1, x) -> x
583  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
584  getRhs().getType() == getType())
585  return getRhs();
586  }
587 
588  /// and(x, x) -> x
589  if (getLhs() == getRhs() && getRhs().getType() == getType())
590  return getRhs();
591 
593  *this, adaptor.getOperands(), BinOpKind::Normal,
594  [](const APSInt &a, const APSInt &b) -> APInt { return a & b; });
595 }
596 
597 void AndPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
598  MLIRContext *context) {
599  results
600  .insert<patterns::extendAnd, patterns::moveConstAnd, patterns::AndOfZero,
601  patterns::AndOfAllOne, patterns::AndOfSelf, patterns::AndOfPad,
602  patterns::AndOfAsSIntL, patterns::AndOfAsSIntR>(context);
603 }
604 
605 OpFoldResult OrPrimOp::fold(FoldAdaptor adaptor) {
606  if (auto rhsCst = getConstant(adaptor.getRhs())) {
607  /// or(x, 0) -> x
608  if (rhsCst->isZero() && getLhs().getType() == getType())
609  return getLhs();
610 
611  /// or(x, -1) -> -1
612  if (rhsCst->isAllOnes() && getRhs().getType() == getType() &&
613  getLhs().getType() == getType())
614  return getRhs();
615  }
616 
617  if (auto lhsCst = getConstant(adaptor.getLhs())) {
618  /// or(0, x) -> x
619  if (lhsCst->isZero() && getRhs().getType() == getType())
620  return getRhs();
621 
622  /// or(-1, x) -> -1
623  if (lhsCst->isAllOnes() && getLhs().getType() == getType() &&
624  getRhs().getType() == getType())
625  return getLhs();
626  }
627 
628  /// or(x, x) -> x
629  if (getLhs() == getRhs() && getRhs().getType() == getType())
630  return getRhs();
631 
633  *this, adaptor.getOperands(), BinOpKind::Normal,
634  [](const APSInt &a, const APSInt &b) -> APInt { return a | b; });
635 }
636 
637 void OrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
638  MLIRContext *context) {
639  results.insert<patterns::extendOr, patterns::moveConstOr, patterns::OrOfZero,
640  patterns::OrOfAllOne, patterns::OrOfSelf, patterns::OrOfPad>(
641  context);
642 }
643 
644 OpFoldResult XorPrimOp::fold(FoldAdaptor adaptor) {
645  /// xor(x, 0) -> x
646  if (auto rhsCst = getConstant(adaptor.getRhs()))
647  if (rhsCst->isZero() &&
648  firrtl::areAnonymousTypesEquivalent(getLhs().getType(), getType()))
649  return getLhs();
650 
651  /// xor(x, 0) -> x
652  if (auto lhsCst = getConstant(adaptor.getLhs()))
653  if (lhsCst->isZero() &&
654  firrtl::areAnonymousTypesEquivalent(getRhs().getType(), getType()))
655  return getRhs();
656 
657  /// xor(x, x) -> 0
658  if (getLhs() == getRhs())
659  return getIntAttr(
660  getType(), APInt(std::max(getType().get().getWidthOrSentinel(), 0), 0));
661 
663  *this, adaptor.getOperands(), BinOpKind::Normal,
664  [](const APSInt &a, const APSInt &b) -> APInt { return a ^ b; });
665 }
666 
667 void XorPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
668  MLIRContext *context) {
669  results.insert<patterns::extendXor, patterns::moveConstXor,
670  patterns::XorOfZero, patterns::XorOfSelf, patterns::XorOfPad>(
671  context);
672 }
673 
674 void LEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
675  MLIRContext *context) {
676  results.insert<patterns::LEQWithConstLHS>(context);
677 }
678 
679 OpFoldResult LEQPrimOp::fold(FoldAdaptor adaptor) {
680  bool isUnsigned = getLhs().getType().get().isUnsigned();
681 
682  // leq(x, x) -> 1
683  if (getLhs() == getRhs())
684  return getIntAttr(getType(), APInt(1, 1));
685 
686  // Comparison against constant outside type bounds.
687  if (auto width = getLhs().getType().get().getWidth()) {
688  if (auto rhsCst = getConstant(adaptor.getRhs())) {
689  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
690  commonWidth = std::max(commonWidth, 1);
691 
692  // leq(x, const) -> 0 where const < minValue of the unsigned type of x
693  // This can never occur since const is unsigned and cannot be less than 0.
694 
695  // leq(x, const) -> 0 where const < minValue of the signed type of x
696  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
697  .slt(getMinSignedValue(*width).sext(commonWidth)))
698  return getIntAttr(getType(), APInt(1, 0));
699 
700  // leq(x, const) -> 1 where const >= maxValue of the unsigned type of x
701  if (isUnsigned && rhsCst->zext(commonWidth)
702  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
703  return getIntAttr(getType(), APInt(1, 1));
704 
705  // leq(x, const) -> 1 where const >= maxValue of the signed type of x
706  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
707  .sge(getMaxSignedValue(*width).sext(commonWidth)))
708  return getIntAttr(getType(), APInt(1, 1));
709  }
710  }
711 
713  *this, adaptor.getOperands(), BinOpKind::Compare,
714  [=](const APSInt &a, const APSInt &b) -> APInt {
715  return APInt(1, a <= b);
716  });
717 }
718 
719 void LTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
720  MLIRContext *context) {
721  results.insert<patterns::LTWithConstLHS>(context);
722 }
723 
724 OpFoldResult LTPrimOp::fold(FoldAdaptor adaptor) {
725  IntType lhsType = getLhs().getType();
726  bool isUnsigned = lhsType.isUnsigned();
727 
728  // lt(x, x) -> 0
729  if (getLhs() == getRhs())
730  return getIntAttr(getType(), APInt(1, 0));
731 
732  // lt(x, 0) -> 0 when x is unsigned
733  if (auto rhsCst = getConstant(adaptor.getRhs())) {
734  if (rhsCst->isZero() && lhsType.isUnsigned())
735  return getIntAttr(getType(), APInt(1, 0));
736  }
737 
738  // Comparison against constant outside type bounds.
739  if (auto width = lhsType.getWidth()) {
740  if (auto rhsCst = getConstant(adaptor.getRhs())) {
741  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
742  commonWidth = std::max(commonWidth, 1);
743 
744  // lt(x, const) -> 0 where const <= minValue of the unsigned type of x
745  // Handled explicitly above.
746 
747  // lt(x, const) -> 0 where const <= minValue of the signed type of x
748  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
749  .sle(getMinSignedValue(*width).sext(commonWidth)))
750  return getIntAttr(getType(), APInt(1, 0));
751 
752  // lt(x, const) -> 1 where const > maxValue of the unsigned type of x
753  if (isUnsigned && rhsCst->zext(commonWidth)
754  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
755  return getIntAttr(getType(), APInt(1, 1));
756 
757  // lt(x, const) -> 1 where const > maxValue of the signed type of x
758  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
759  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
760  return getIntAttr(getType(), APInt(1, 1));
761  }
762  }
763 
765  *this, adaptor.getOperands(), BinOpKind::Compare,
766  [=](const APSInt &a, const APSInt &b) -> APInt {
767  return APInt(1, a < b);
768  });
769 }
770 
771 void GEQPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
772  MLIRContext *context) {
773  results.insert<patterns::GEQWithConstLHS>(context);
774 }
775 
776 OpFoldResult GEQPrimOp::fold(FoldAdaptor adaptor) {
777  IntType lhsType = getLhs().getType();
778  bool isUnsigned = lhsType.isUnsigned();
779 
780  // geq(x, x) -> 1
781  if (getLhs() == getRhs())
782  return getIntAttr(getType(), APInt(1, 1));
783 
784  // geq(x, 0) -> 1 when x is unsigned
785  if (auto rhsCst = getConstant(adaptor.getRhs())) {
786  if (rhsCst->isZero() && isUnsigned)
787  return getIntAttr(getType(), APInt(1, 1));
788  }
789 
790  // Comparison against constant outside type bounds.
791  if (auto width = lhsType.getWidth()) {
792  if (auto rhsCst = getConstant(adaptor.getRhs())) {
793  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
794  commonWidth = std::max(commonWidth, 1);
795 
796  // geq(x, const) -> 0 where const > maxValue of the unsigned type of x
797  if (isUnsigned && rhsCst->zext(commonWidth)
798  .ugt(getMaxUnsignedValue(*width).zext(commonWidth)))
799  return getIntAttr(getType(), APInt(1, 0));
800 
801  // geq(x, const) -> 0 where const > maxValue of the signed type of x
802  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
803  .sgt(getMaxSignedValue(*width).sext(commonWidth)))
804  return getIntAttr(getType(), APInt(1, 0));
805 
806  // geq(x, const) -> 1 where const <= minValue of the unsigned type of x
807  // Handled explicitly above.
808 
809  // geq(x, const) -> 1 where const <= minValue of the signed type of x
810  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
811  .sle(getMinSignedValue(*width).sext(commonWidth)))
812  return getIntAttr(getType(), APInt(1, 1));
813  }
814  }
815 
817  *this, adaptor.getOperands(), BinOpKind::Compare,
818  [=](const APSInt &a, const APSInt &b) -> APInt {
819  return APInt(1, a >= b);
820  });
821 }
822 
823 void GTPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
824  MLIRContext *context) {
825  results.insert<patterns::GTWithConstLHS>(context);
826 }
827 
828 OpFoldResult GTPrimOp::fold(FoldAdaptor adaptor) {
829  IntType lhsType = getLhs().getType();
830  bool isUnsigned = lhsType.isUnsigned();
831 
832  // gt(x, x) -> 0
833  if (getLhs() == getRhs())
834  return getIntAttr(getType(), APInt(1, 0));
835 
836  // Comparison against constant outside type bounds.
837  if (auto width = lhsType.getWidth()) {
838  if (auto rhsCst = getConstant(adaptor.getRhs())) {
839  auto commonWidth = std::max<int32_t>(*width, rhsCst->getBitWidth());
840  commonWidth = std::max(commonWidth, 1);
841 
842  // gt(x, const) -> 0 where const >= maxValue of the unsigned type of x
843  if (isUnsigned && rhsCst->zext(commonWidth)
844  .uge(getMaxUnsignedValue(*width).zext(commonWidth)))
845  return getIntAttr(getType(), APInt(1, 0));
846 
847  // gt(x, const) -> 0 where const >= maxValue of the signed type of x
848  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
849  .sge(getMaxSignedValue(*width).sext(commonWidth)))
850  return getIntAttr(getType(), APInt(1, 0));
851 
852  // gt(x, const) -> 1 where const < minValue of the unsigned type of x
853  // This can never occur since const is unsigned and cannot be less than 0.
854 
855  // gt(x, const) -> 1 where const < minValue of the signed type of x
856  if (!isUnsigned && sextZeroWidth(*rhsCst, commonWidth)
857  .slt(getMinSignedValue(*width).sext(commonWidth)))
858  return getIntAttr(getType(), APInt(1, 1));
859  }
860  }
861 
863  *this, adaptor.getOperands(), BinOpKind::Compare,
864  [=](const APSInt &a, const APSInt &b) -> APInt {
865  return APInt(1, a > b);
866  });
867 }
868 
869 OpFoldResult EQPrimOp::fold(FoldAdaptor adaptor) {
870  // eq(x, x) -> 1
871  if (getLhs() == getRhs())
872  return getIntAttr(getType(), APInt(1, 1));
873 
874  if (auto rhsCst = getConstant(adaptor.getRhs())) {
875  /// eq(x, 1) -> x when x is 1 bit.
876  /// TODO: Support SInt<1> on the LHS etc.
877  if (rhsCst->isAllOnes() && getLhs().getType() == getType() &&
878  getRhs().getType() == getType())
879  return getLhs();
880  }
881 
883  *this, adaptor.getOperands(), BinOpKind::Compare,
884  [=](const APSInt &a, const APSInt &b) -> APInt {
885  return APInt(1, a == b);
886  });
887 }
888 
889 LogicalResult EQPrimOp::canonicalize(EQPrimOp op, PatternRewriter &rewriter) {
890  return canonicalizePrimOp(
891  op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
892  if (auto rhsCst = getConstant(operands[1])) {
893  auto width = op.getLhs().getType().getBitWidthOrSentinel();
894 
895  // eq(x, 0) -> not(x) when x is 1 bit.
896  if (rhsCst->isZero() && op.getLhs().getType() == op.getType() &&
897  op.getRhs().getType() == op.getType()) {
898  return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
899  .getResult();
900  }
901 
902  // eq(x, 0) -> not(orr(x)) when x is >1 bit
903  if (rhsCst->isZero() && width > 1) {
904  auto orrOp = rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs());
905  return rewriter.create<NotPrimOp>(op.getLoc(), orrOp).getResult();
906  }
907 
908  // eq(x, ~0) -> andr(x) when x is >1 bit
909  if (rhsCst->isAllOnes() && width > 1 &&
910  op.getLhs().getType() == op.getRhs().getType()) {
911  return rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs())
912  .getResult();
913  }
914  }
915 
916  return {};
917  });
918 }
919 
920 OpFoldResult NEQPrimOp::fold(FoldAdaptor adaptor) {
921  // neq(x, x) -> 0
922  if (getLhs() == getRhs())
923  return getIntAttr(getType(), APInt(1, 0));
924 
925  if (auto rhsCst = getConstant(adaptor.getRhs())) {
926  /// neq(x, 0) -> x when x is 1 bit.
927  /// TODO: Support SInt<1> on the LHS etc.
928  if (rhsCst->isZero() && getLhs().getType() == getType() &&
929  getRhs().getType() == getType())
930  return getLhs();
931  }
932 
934  *this, adaptor.getOperands(), BinOpKind::Compare,
935  [=](const APSInt &a, const APSInt &b) -> APInt {
936  return APInt(1, a != b);
937  });
938 }
939 
940 LogicalResult NEQPrimOp::canonicalize(NEQPrimOp op, PatternRewriter &rewriter) {
941  return canonicalizePrimOp(
942  op, rewriter, [&](ArrayRef<Attribute> operands) -> OpFoldResult {
943  if (auto rhsCst = getConstant(operands[1])) {
944  auto width = op.getLhs().getType().getBitWidthOrSentinel();
945 
946  // neq(x, 1) -> not(x) when x is 1 bit
947  if (rhsCst->isAllOnes() && op.getLhs().getType() == op.getType() &&
948  op.getRhs().getType() == op.getType()) {
949  return rewriter.create<NotPrimOp>(op.getLoc(), op.getLhs())
950  .getResult();
951  }
952 
953  // neq(x, 0) -> orr(x) when x is >1 bit
954  if (rhsCst->isZero() && width > 1) {
955  return rewriter.create<OrRPrimOp>(op.getLoc(), op.getLhs())
956  .getResult();
957  }
958 
959  // neq(x, ~0) -> not(andr(x))) when x is >1 bit
960  if (rhsCst->isAllOnes() && width > 1 &&
961  op.getLhs().getType() == op.getRhs().getType()) {
962  auto andrOp = rewriter.create<AndRPrimOp>(op.getLoc(), op.getLhs());
963  return rewriter.create<NotPrimOp>(op.getLoc(), andrOp).getResult();
964  }
965  }
966 
967  return {};
968  });
969 }
970 
971 //===----------------------------------------------------------------------===//
972 // Unary Operators
973 //===----------------------------------------------------------------------===//
974 
975 OpFoldResult SizeOfIntrinsicOp::fold(FoldAdaptor) {
976  auto base = getInput().getType();
977  auto w = base.getBitWidthOrSentinel();
978  if (w >= 0)
979  return getIntAttr(getType(), APInt(32, w));
980  return {};
981 }
982 
983 OpFoldResult IsXIntrinsicOp::fold(FoldAdaptor adaptor) {
984  // No constant can be 'x' by definition.
985  if (auto cst = getConstant(adaptor.getArg()))
986  return getIntAttr(getType(), APInt(1, 0));
987  return {};
988 }
989 
990 OpFoldResult AsSIntPrimOp::fold(FoldAdaptor adaptor) {
991  // No effect.
992  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
993  return getInput();
994 
995  // Be careful to only fold the cast into the constant if the size is known.
996  // Otherwise width inference may produce differently-sized constants if the
997  // sign changes.
998  if (getType().get().hasWidth())
999  if (auto cst = getConstant(adaptor.getInput()))
1000  return getIntAttr(getType(), *cst);
1001 
1002  return {};
1003 }
1004 
1005 void AsSIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1006  MLIRContext *context) {
1007  results.insert<patterns::StoUtoS>(context);
1008 }
1009 
1010 OpFoldResult AsUIntPrimOp::fold(FoldAdaptor adaptor) {
1011  // No effect.
1012  if (areAnonymousTypesEquivalent(getInput().getType(), getType()))
1013  return getInput();
1014 
1015  // Be careful to only fold the cast into the constant if the size is known.
1016  // Otherwise width inference may produce differently-sized constants if the
1017  // sign changes.
1018  if (getType().get().hasWidth())
1019  if (auto cst = getConstant(adaptor.getInput()))
1020  return getIntAttr(getType(), *cst);
1021 
1022  return {};
1023 }
1024 
1025 void AsUIntPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1026  MLIRContext *context) {
1027  results.insert<patterns::UtoStoU>(context);
1028 }
1029 
1030 OpFoldResult AsAsyncResetPrimOp::fold(FoldAdaptor adaptor) {
1031  // No effect.
1032  if (getInput().getType() == getType())
1033  return getInput();
1034 
1035  // Constant fold.
1036  if (auto cst = getConstant(adaptor.getInput()))
1037  return BoolAttr::get(getContext(), cst->getBoolValue());
1038 
1039  return {};
1040 }
1041 
1042 OpFoldResult AsClockPrimOp::fold(FoldAdaptor adaptor) {
1043  // No effect.
1044  if (getInput().getType() == getType())
1045  return getInput();
1046 
1047  // Constant fold.
1048  if (auto cst = getConstant(adaptor.getInput()))
1049  return BoolAttr::get(getContext(), cst->getBoolValue());
1050 
1051  return {};
1052 }
1053 
1054 OpFoldResult CvtPrimOp::fold(FoldAdaptor adaptor) {
1055  if (!hasKnownWidthIntTypes(*this))
1056  return {};
1057 
1058  // Signed to signed is a noop, unsigned operands prepend a zero bit.
1059  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1060  getType().get().getWidthOrSentinel()))
1061  return getIntAttr(getType(), *cst);
1062 
1063  return {};
1064 }
1065 
1066 void CvtPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1067  MLIRContext *context) {
1068  results.insert<patterns::CVTSigned, patterns::CVTUnSigned>(context);
1069 }
1070 
1071 OpFoldResult NegPrimOp::fold(FoldAdaptor adaptor) {
1072  if (!hasKnownWidthIntTypes(*this))
1073  return {};
1074 
1075  // FIRRTL negate always adds a bit.
1076  // -x ---> 0-sext(x) or 0-zext(x)
1077  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1078  getType().get().getWidthOrSentinel()))
1079  return getIntAttr(getType(), APInt((*cst).getBitWidth(), 0) - *cst);
1080 
1081  return {};
1082 }
1083 
1084 OpFoldResult NotPrimOp::fold(FoldAdaptor adaptor) {
1085  if (!hasKnownWidthIntTypes(*this))
1086  return {};
1087 
1088  if (auto cst = getExtendedConstant(getOperand(), adaptor.getInput(),
1089  getType().get().getWidthOrSentinel()))
1090  return getIntAttr(getType(), ~*cst);
1091 
1092  return {};
1093 }
1094 
1095 void NotPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1096  MLIRContext *context) {
1097  results.insert<patterns::NotNot>(context);
1098 }
1099 
1100 OpFoldResult AndRPrimOp::fold(FoldAdaptor adaptor) {
1101  if (!hasKnownWidthIntTypes(*this))
1102  return {};
1103 
1104  if (getInput().getType().getBitWidthOrSentinel() == 0)
1105  return getIntAttr(getType(), APInt(1, 1));
1106 
1107  // x == -1
1108  if (auto cst = getConstant(adaptor.getInput()))
1109  return getIntAttr(getType(), APInt(1, cst->isAllOnes()));
1110 
1111  // one bit is identity. Only applies to UInt since we can't make a cast
1112  // here.
1113  if (isUInt1(getInput().getType()))
1114  return getInput();
1115 
1116  return {};
1117 }
1118 
1119 void AndRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1120  MLIRContext *context) {
1121  results
1122  .insert<patterns::AndRasSInt, patterns::AndRasUInt, patterns::AndRPadU,
1123  patterns::AndRPadS, patterns::AndRCatOneL, patterns::AndRCatOneR,
1124  patterns::AndRCatZeroL, patterns::AndRCatZeroR>(context);
1125 }
1126 
1127 OpFoldResult OrRPrimOp::fold(FoldAdaptor adaptor) {
1128  if (!hasKnownWidthIntTypes(*this))
1129  return {};
1130 
1131  if (getInput().getType().getBitWidthOrSentinel() == 0)
1132  return getIntAttr(getType(), APInt(1, 0));
1133 
1134  // x != 0
1135  if (auto cst = getConstant(adaptor.getInput()))
1136  return getIntAttr(getType(), APInt(1, !cst->isZero()));
1137 
1138  // one bit is identity. Only applies to UInt since we can't make a cast
1139  // here.
1140  if (isUInt1(getInput().getType()))
1141  return getInput();
1142 
1143  return {};
1144 }
1145 
1146 void OrRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1147  MLIRContext *context) {
1148  results.insert<patterns::OrRasSInt, patterns::OrRasUInt, patterns::OrRPadU,
1149  patterns::OrRCatZeroH, patterns::OrRCatZeroL>(context);
1150 }
1151 
1152 OpFoldResult XorRPrimOp::fold(FoldAdaptor adaptor) {
1153  if (!hasKnownWidthIntTypes(*this))
1154  return {};
1155 
1156  if (getInput().getType().getBitWidthOrSentinel() == 0)
1157  return getIntAttr(getType(), APInt(1, 0));
1158 
1159  // popcount(x) & 1
1160  if (auto cst = getConstant(adaptor.getInput()))
1161  return getIntAttr(getType(), APInt(1, cst->popcount() & 1));
1162 
1163  // one bit is identity. Only applies to UInt since we can't make a cast here.
1164  if (isUInt1(getInput().getType()))
1165  return getInput();
1166 
1167  return {};
1168 }
1169 
1170 void XorRPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1171  MLIRContext *context) {
1172  results.insert<patterns::XorRasSInt, patterns::XorRasUInt, patterns::XorRPadU,
1173  patterns::XorRCatZeroH, patterns::XorRCatZeroL>(context);
1174 }
1175 
1176 //===----------------------------------------------------------------------===//
1177 // Other Operators
1178 //===----------------------------------------------------------------------===//
1179 
1180 OpFoldResult CatPrimOp::fold(FoldAdaptor adaptor) {
1181  // cat(x, 0-width) -> x
1182  // cat(0-width, x) -> x
1183  // Limit to unsigned (result type), as cannot insert cast here.
1184  IntType lhsType = getLhs().getType();
1185  IntType rhsType = getRhs().getType();
1186  if (lhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1187  return getRhs();
1188  if (rhsType.getBitWidthOrSentinel() == 0 && rhsType.isUnsigned())
1189  return getLhs();
1190 
1191  if (!hasKnownWidthIntTypes(*this))
1192  return {};
1193 
1194  // Constant fold cat.
1195  if (auto lhs = getConstant(adaptor.getLhs()))
1196  if (auto rhs = getConstant(adaptor.getRhs()))
1197  return getIntAttr(getType(), lhs->concat(*rhs));
1198 
1199  return {};
1200 }
1201 
1202 void DShlPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1203  MLIRContext *context) {
1204  results.insert<patterns::DShlOfConstant>(context);
1205 }
1206 
1207 void DShrPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1208  MLIRContext *context) {
1209  results.insert<patterns::DShrOfConstant>(context);
1210 }
1211 
1212 namespace {
1213 // cat(bits(x, ...), bits(x, ...)) -> bits(x ...) when the two ...'s are
1214 // consequtive in the input.
1215 struct CatBitsBits : public mlir::RewritePattern {
1216  CatBitsBits(MLIRContext *context)
1217  : RewritePattern(CatPrimOp::getOperationName(), 0, context) {}
1218  LogicalResult matchAndRewrite(Operation *op,
1219  PatternRewriter &rewriter) const override {
1220  auto cat = cast<CatPrimOp>(op);
1221  if (auto lhsBits =
1222  dyn_cast_or_null<BitsPrimOp>(cat.getLhs().getDefiningOp())) {
1223  if (auto rhsBits =
1224  dyn_cast_or_null<BitsPrimOp>(cat.getRhs().getDefiningOp())) {
1225  if (lhsBits.getInput() == rhsBits.getInput() &&
1226  lhsBits.getLo() - 1 == rhsBits.getHi()) {
1227  replaceOpWithNewOpAndCopyName<BitsPrimOp>(
1228  rewriter, cat, cat.getType(), lhsBits.getInput(), lhsBits.getHi(),
1229  rhsBits.getLo());
1230  return success();
1231  }
1232  }
1233  }
1234  return failure();
1235  }
1236 };
1237 } // namespace
1238 
1239 void CatPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1240  MLIRContext *context) {
1241  results.insert<CatBitsBits, patterns::CatDoubleConst>(context);
1242 }
1243 
1244 OpFoldResult BitCastOp::fold(FoldAdaptor adaptor) {
1245  auto op = (*this);
1246  // BitCast is redundant if input and result types are same.
1247  if (op.getType() == op.getInput().getType())
1248  return op.getInput();
1249 
1250  // Two consecutive BitCasts are redundant if first bitcast type is same as the
1251  // final result type.
1252  if (BitCastOp in = dyn_cast_or_null<BitCastOp>(op.getInput().getDefiningOp()))
1253  if (op.getType() == in.getInput().getType())
1254  return in.getInput();
1255 
1256  return {};
1257 }
1258 
1259 OpFoldResult BitsPrimOp::fold(FoldAdaptor adaptor) {
1260  IntType inputType = getInput().getType();
1261  IntType resultType = getType();
1262  // If we are extracting the entire input, then return it.
1263  if (inputType == getType() && resultType.hasWidth())
1264  return getInput();
1265 
1266  // Constant fold.
1267  if (hasKnownWidthIntTypes(*this))
1268  if (auto cst = getConstant(adaptor.getInput()))
1269  return getIntAttr(resultType,
1270  cst->extractBits(getHi() - getLo() + 1, getLo()));
1271 
1272  return {};
1273 }
1274 
1275 void BitsPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1276  MLIRContext *context) {
1277  results
1278  .insert<patterns::BitsOfBits, patterns::BitsOfMux, patterns::BitsOfAsUInt,
1279  patterns::BitsOfAnd, patterns::BitsOfPad>(context);
1280 }
1281 
1282 /// Replace the specified operation with a 'bits' op from the specified hi/lo
1283 /// bits. Insert a cast to handle the case where the original operation
1284 /// returned a signed integer.
1285 static void replaceWithBits(Operation *op, Value value, unsigned hiBit,
1286  unsigned loBit, PatternRewriter &rewriter) {
1287  auto resType = type_cast<IntType>(op->getResult(0).getType());
1288  if (type_cast<IntType>(value.getType()).getWidth() != resType.getWidth())
1289  value = rewriter.create<BitsPrimOp>(op->getLoc(), value, hiBit, loBit);
1290 
1291  if (resType.isSigned() && !type_cast<IntType>(value.getType()).isSigned()) {
1292  value = rewriter.createOrFold<AsSIntPrimOp>(op->getLoc(), resType, value);
1293  } else if (resType.isUnsigned() &&
1294  !type_cast<IntType>(value.getType()).isUnsigned()) {
1295  value = rewriter.createOrFold<AsUIntPrimOp>(op->getLoc(), resType, value);
1296  }
1297  rewriter.replaceOp(op, value);
1298 }
1299 
/// Shared folder for two-way mux-like operations (MuxPrimOp and
/// Mux2CellIntrinsicOp). Returns the folded value or null if no fold applies.
template <typename OpTy>
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor) {
  // mux : UInt<0> -> 0
  if (op.getType().getBitWidthOrSentinel() == 0)
    return getIntAttr(op.getType(),
                      APInt(0, 0, op.getType().isSignedInteger()));

  // mux(cond, x, x) -> x
  if (op.getHigh() == op.getLow())
    return op.getHigh();

  // The following folds require that the result has a known width. Otherwise
  // the mux requires an additional padding operation to be inserted, which is
  // not possible in a fold.
  if (op.getType().getBitWidthOrSentinel() < 0)
    return {};

  // mux(0/1, x, y) -> x or y
  // Only fires when the selected operand's type already matches the result
  // type, since a fold cannot insert a pad/cast.
  if (auto cond = getConstant(adaptor.getSel())) {
    if (cond->isZero() && op.getLow().getType() == op.getType())
      return op.getLow();
    if (!cond->isZero() && op.getHigh().getType() == op.getType())
      return op.getHigh();
  }

  // mux(cond, x, cst)
  if (auto lowCst = getConstant(adaptor.getLow())) {
    // mux(cond, c1, c2)
    if (auto highCst = getConstant(adaptor.getHigh())) {
      // mux(cond, cst, cst) -> cst  (widths must agree for APInt ==)
      if (highCst->getBitWidth() == lowCst->getBitWidth() &&
          *highCst == *lowCst)
        return getIntAttr(op.getType(), *highCst);
      // mux(cond, 1, 0) -> cond
      if (highCst->isOne() && lowCst->isZero() &&
          op.getType() == op.getSel().getType())
        return op.getSel();

      // TODO: x ? ~0 : 0 -> sext(x)
      // TODO: "x ? c1 : c2" -> many tricks
    }
    // TODO: "x ? a : 0" -> sext(x) & a
  }

  // TODO: "x ? c1 : y" -> "~x ? y : c1"
  return {};
}
1347 
OpFoldResult MuxPrimOp::fold(FoldAdaptor adaptor) {
  // Delegate to the shared two-way mux folding logic.
  return foldMux(*this, adaptor);
}
1351 
OpFoldResult Mux2CellIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Delegate to the shared two-way mux folding logic.
  return foldMux(*this, adaptor);
}
1355 
// No folds are implemented for 4-way mux intrinsics.
OpFoldResult Mux4CellIntrinsicOp::fold(FoldAdaptor adaptor) { return {}; }
1357 
1358 namespace {
1359 
1360 // If the mux has a known output width, pad the operands up to this width.
1361 // Most folds on mux require that folded operands are of the same width as
1362 // the mux itself.
1363 class MuxPad : public mlir::RewritePattern {
1364 public:
1365  MuxPad(MLIRContext *context)
1366  : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}
1367 
1368  LogicalResult
1369  matchAndRewrite(Operation *op,
1370  mlir::PatternRewriter &rewriter) const override {
1371  auto mux = cast<MuxPrimOp>(op);
1372  auto width = mux.getType().getBitWidthOrSentinel();
1373  if (width < 0)
1374  return failure();
1375 
1376  auto pad = [&](Value input) -> Value {
1377  auto inputWidth =
1378  type_cast<FIRRTLBaseType>(input.getType()).getBitWidthOrSentinel();
1379  if (inputWidth < 0 || width == inputWidth)
1380  return input;
1381  return rewriter
1382  .create<PadPrimOp>(mux.getLoc(), mux.getType(), input, width)
1383  .getResult();
1384  };
1385 
1386  auto newHigh = pad(mux.getHigh());
1387  auto newLow = pad(mux.getLow());
1388  if (newHigh == mux.getHigh() && newLow == mux.getLow())
1389  return failure();
1390 
1391  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
1392  rewriter, op, mux.getType(), ValueRange{mux.getSel(), newHigh, newLow},
1393  mux->getAttrs());
1394  return success();
1395  }
1396 };
1397 
1398 // Find muxes which have conditions dominated by other muxes with the same
1399 // condition.
class MuxSharedCond : public mlir::RewritePattern {
public:
  MuxSharedCond(MLIRContext *context)
      : RewritePattern(MuxPrimOp::getOperationName(), 0, context) {}

  // Maximum recursion depth when searching nested mux trees.
  static const int depthLimit = 5;

  // Replace mux's true/false operands with high/low. If updateInPlace is
  // set, the existing mux is mutated and no value is returned; otherwise a
  // fresh mux is created right after the original and returned.
  Value updateOrClone(MuxPrimOp mux, Value high, Value low,
                      mlir::PatternRewriter &rewriter,
                      bool updateInPlace) const {
    if (updateInPlace) {
      rewriter.updateRootInPlace(mux, [&] {
        mux.setOperand(1, high);
        mux.setOperand(2, low);
      });
      return {};
    }
    rewriter.setInsertionPointAfter(mux);
    return rewriter
        .create<MuxPrimOp>(mux.getLoc(), mux.getType(),
                           ValueRange{mux.getSel(), high, low})
        .getResult();
  }

  // Walk a dependent mux tree assuming the condition cond is true.
  // Returns a replacement value for op, or null if nothing simplified.
  Value tryCondTrue(Value op, Value cond, mlir::PatternRewriter &rewriter,
                    bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    // Same selector: under the "cond is true" assumption, this mux always
    // takes its high operand.
    if (mux.getSel() == cond)
      return mux.getHigh();
    if (limit > depthLimit)
      return {};
    // Only mutate a nested mux in place when this tree is its sole user.
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondTrue(mux.getHigh(), cond, rewriter, updateInPlace,
                              limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v =
            tryCondTrue(mux.getLow(), cond, rewriter, updateInPlace, limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);
    return {};
  }

  // Walk a dependent mux tree assuming the condition cond is false.
  // Mirror image of tryCondTrue.
  Value tryCondFalse(Value op, Value cond, mlir::PatternRewriter &rewriter,
                     bool updateInPlace, int limit) const {
    MuxPrimOp mux = op.getDefiningOp<MuxPrimOp>();
    if (!mux)
      return {};
    if (mux.getSel() == cond)
      return mux.getLow();
    if (limit > depthLimit)
      return {};
    updateInPlace &= mux->hasOneUse();

    if (Value v = tryCondFalse(mux.getHigh(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, v, mux.getLow(), rewriter, updateInPlace);

    if (Value v = tryCondFalse(mux.getLow(), cond, rewriter, updateInPlace,
                               limit + 1))
      return updateOrClone(mux, mux.getHigh(), v, rewriter, updateInPlace);

    return {};
  }

  LogicalResult
  matchAndRewrite(Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto mux = cast<MuxPrimOp>(op);
    auto width = mux.getType().getBitWidthOrSentinel();
    if (width < 0)
      return failure();

    // Within the high operand, the selector is known true.
    if (Value v = tryCondTrue(mux.getHigh(), mux.getSel(), rewriter, true, 0)) {
      rewriter.updateRootInPlace(mux, [&] { mux.setOperand(1, v); });
      return success();
    }

    // Within the low operand, the selector is known false.
    if (Value v = tryCondFalse(mux.getLow(), mux.getSel(), rewriter, true, 0)) {
      rewriter.updateRootInPlace(mux, [&] { mux.setOperand(2, v); });
      return success();
    }

    return failure();
  }
};
1490 } // namespace
1491 
1492 void MuxPrimOp::getCanonicalizationPatterns(RewritePatternSet &results,
1493  MLIRContext *context) {
1494  results.add<MuxPad, MuxSharedCond, patterns::MuxNot, patterns::MuxSameTrue,
1495  patterns::MuxSameFalse, patterns::NarrowMuxLHS,
1496  patterns::NarrowMuxRHS>(context);
1497 }
1498 
1499 OpFoldResult PadPrimOp::fold(FoldAdaptor adaptor) {
1500  auto input = this->getInput();
1501 
1502  // pad(x) -> x if the width doesn't change.
1503  if (input.getType() == getType())
1504  return input;
1505 
1506  // Need to know the input width.
1507  auto inputType = input.getType().get();
1508  int32_t width = inputType.getWidthOrSentinel();
1509  if (width == -1)
1510  return {};
1511 
1512  // Constant fold.
1513  if (auto cst = getConstant(adaptor.getInput())) {
1514  auto destWidth = getType().get().getWidthOrSentinel();
1515  if (destWidth == -1)
1516  return {};
1517 
1518  if (inputType.isSigned() && cst->getBitWidth())
1519  return getIntAttr(getType(), cst->sext(destWidth));
1520  return getIntAttr(getType(), cst->zext(destWidth));
1521  }
1522 
1523  return {};
1524 }
1525 
1526 OpFoldResult ShlPrimOp::fold(FoldAdaptor adaptor) {
1527  auto input = this->getInput();
1528  IntType inputType = input.getType();
1529  int shiftAmount = getAmount();
1530 
1531  // shl(x, 0) -> x
1532  if (shiftAmount == 0)
1533  return input;
1534 
1535  // Constant fold.
1536  if (auto cst = getConstant(adaptor.getInput())) {
1537  auto inputWidth = inputType.getWidthOrSentinel();
1538  if (inputWidth != -1) {
1539  auto resultWidth = inputWidth + shiftAmount;
1540  shiftAmount = std::min(shiftAmount, resultWidth);
1541  return getIntAttr(getType(), cst->zext(resultWidth).shl(shiftAmount));
1542  }
1543  }
1544  return {};
1545 }
1546 
OpFoldResult ShrPrimOp::fold(FoldAdaptor adaptor) {
  auto input = this->getInput();
  IntType inputType = input.getType();
  int shiftAmount = getAmount();

  // shr(x, 0) -> x
  if (shiftAmount == 0)
    return input;

  // All remaining folds need a known input width.
  auto inputWidth = inputType.getWidthOrSentinel();
  if (inputWidth == -1)
    return {};
  if (inputWidth == 0)
    return getIntZerosAttr(getType());

  // shr(x, cst) where cst is all of x's bits and x is unsigned is 0.
  // If x is signed, it is the sign bit.
  if (shiftAmount >= inputWidth && inputType.isUnsigned())
    return getIntAttr(getType(), APInt(1, 0));

  // Constant fold.
  if (auto cst = getConstant(adaptor.getInput())) {
    APInt value;
    if (inputType.isSigned())
      // Arithmetic shift, clamped so at least the sign bit survives.
      value = cst->ashr(std::min(shiftAmount, inputWidth - 1));
    else
      value = cst->lshr(std::min(shiftAmount, inputWidth));
    // The FIRRTL result of shr is always at least one bit wide.
    auto resultWidth = std::max(inputWidth - shiftAmount, 1);
    return getIntAttr(getType(), value.trunc(resultWidth));
  }
  return {};
}
1579 
LogicalResult ShrPrimOp::canonicalize(ShrPrimOp op, PatternRewriter &rewriter) {
  auto inputWidth = op.getInput().getType().get().getWidthOrSentinel();
  if (inputWidth <= 0)
    return failure();

  // If we know the input width, we can canonicalize this into a BitsPrimOp.
  unsigned shiftAmount = op.getAmount();
  if (int(shiftAmount) >= inputWidth) {
    // shift(x, 32) => 0 when x has 32 bits. This is handled by fold().
    if (op.getType().get().isUnsigned())
      return failure();

    // Shifting a signed value by the full width is actually taking the
    // sign bit. If the shift amount is greater than the input width, it
    // is equivalent to shifting by the input width.
    shiftAmount = inputWidth - 1;
  }

  // bits(x, inputWidth-1, shiftAmount) keeps exactly the surviving bits.
  replaceWithBits(op, op.getInput(), inputWidth - 1, shiftAmount, rewriter);
  return success();
}
1601 
1602 LogicalResult HeadPrimOp::canonicalize(HeadPrimOp op,
1603  PatternRewriter &rewriter) {
1604  auto inputWidth = op.getInput().getType().get().getWidthOrSentinel();
1605  if (inputWidth <= 0)
1606  return failure();
1607 
1608  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1609  unsigned keepAmount = op.getAmount();
1610  if (keepAmount)
1611  replaceWithBits(op, op.getInput(), inputWidth - 1, inputWidth - keepAmount,
1612  rewriter);
1613  return success();
1614 }
1615 
1616 OpFoldResult HeadPrimOp::fold(FoldAdaptor adaptor) {
1617  if (hasKnownWidthIntTypes(*this))
1618  if (auto cst = getConstant(adaptor.getInput())) {
1619  int shiftAmount =
1620  getInput().getType().get().getWidthOrSentinel() - getAmount();
1621  return getIntAttr(getType(), cst->lshr(shiftAmount).trunc(getAmount()));
1622  }
1623 
1624  return {};
1625 }
1626 
1627 OpFoldResult TailPrimOp::fold(FoldAdaptor adaptor) {
1628  if (hasKnownWidthIntTypes(*this))
1629  if (auto cst = getConstant(adaptor.getInput()))
1630  return getIntAttr(getType(),
1631  cst->trunc(getType().get().getWidthOrSentinel()));
1632  return {};
1633 }
1634 
1635 LogicalResult TailPrimOp::canonicalize(TailPrimOp op,
1636  PatternRewriter &rewriter) {
1637  auto inputWidth = op.getInput().getType().get().getWidthOrSentinel();
1638  if (inputWidth <= 0)
1639  return failure();
1640 
1641  // If we know the input width, we can canonicalize this into a BitsPrimOp.
1642  unsigned dropAmount = op.getAmount();
1643  if (dropAmount != unsigned(inputWidth))
1644  replaceWithBits(op, op.getInput(), inputWidth - dropAmount - 1, 0,
1645  rewriter);
1646  return success();
1647 }
1648 
1649 void SubaccessOp::getCanonicalizationPatterns(RewritePatternSet &results,
1650  MLIRContext *context) {
1651  results.add<patterns::SubaccessOfConstant>(context);
1652 }
1653 
1654 OpFoldResult MultibitMuxOp::fold(FoldAdaptor adaptor) {
1655  // If there is only one input, just return it.
1656  if (adaptor.getInputs().size() == 1)
1657  return getOperand(1);
1658 
1659  if (auto constIndex = getConstant(adaptor.getIndex())) {
1660  auto index = constIndex->getZExtValue();
1661  if (index < getInputs().size())
1662  return getInputs()[getInputs().size() - 1 - index];
1663  }
1664 
1665  return {};
1666 }
1667 
LogicalResult MultibitMuxOp::canonicalize(MultibitMuxOp op,
                                          PatternRewriter &rewriter) {
  // If all operands are equal, just canonicalize to it. We can add this
  // canonicalization as a folder but it costly to look through all inputs so it
  // is added here.
  if (llvm::all_of(op.getInputs().drop_front(), [&](auto input) {
        return input == op.getInputs().front();
      })) {
    replaceOpAndCopyName(rewriter, op, op.getInputs().front());
    return success();
  }

  // If the op is a vector indexing (e.g. `multbit_mux idx, a[n-1], a[n-2], ...,
  // a[0]`), we can fold the op into subaccess op `a[idx]`.
  if (auto lastSubindex = op.getInputs().back().getDefiningOp<SubindexOp>()) {
    // Every input i must be a subindex of the same vector at position
    // size-1-i, i.e. the inputs enumerate the whole vector MSB-first.
    if (llvm::all_of(llvm::enumerate(op.getInputs()), [&](auto e) {
          auto subindex = e.value().template getDefiningOp<SubindexOp>();
          return subindex && lastSubindex.getInput() == subindex.getInput() &&
                 subindex.getIndex() + e.index() + 1 == op.getInputs().size();
        })) {
      replaceOpWithNewOpAndCopyName<SubaccessOp>(
          rewriter, op, lastSubindex.getInput(), op.getIndex());
      return success();
    }
  }

  // If the size is 2, canonicalize into a normal mux to introduce more folds.
  if (op.getInputs().size() != 2)
    return failure();

  // TODO: Handle even when `index` doesn't have uint<1>.
  auto uintType = op.getIndex().getType();
  if (uintType.getBitWidthOrSentinel() != 1)
    return failure();

  // multibit_mux(index, {lhs, rhs}) -> mux(index, lhs, rhs)
  replaceOpWithNewOpAndCopyName<MuxPrimOp>(
      rewriter, op, op.getIndex(), op.getInputs()[0], op.getInputs()[1]);
  return success();
}
1708 
1709 //===----------------------------------------------------------------------===//
1710 // Declarations
1711 //===----------------------------------------------------------------------===//
1712 
1713 /// Scan all the uses of the specified value, checking to see if there is
1714 /// exactly one connect that has the value as its destination. This returns the
1715 /// operation if found and if all the other users are "reads" from the value.
1716 /// Returns null if there are no connects, or multiple connects to the value, or
1717 /// if the value is involved in an `AttachOp`, or if the connect isn't strict.
1718 ///
1719 /// Note that this will simply return the connect, which is located *anywhere*
1720 /// after the definition of the value. Users of this function are likely
1721 /// interested in the source side of the returned connect, the definition of
1722 /// which does likely not dominate the original value.
1723 StrictConnectOp firrtl::getSingleConnectUserOf(Value value) {
1724  StrictConnectOp connect;
1725  for (Operation *user : value.getUsers()) {
1726  // If we see an attach or aggregate sublements, just conservatively fail.
1727  if (isa<AttachOp, SubfieldOp, SubaccessOp, SubindexOp>(user))
1728  return {};
1729 
1730  if (auto aConnect = dyn_cast<FConnectLike>(user))
1731  if (aConnect.getDest() == value) {
1732  auto strictConnect = dyn_cast<StrictConnectOp>(*aConnect);
1733  // If this is not a strict connect, a second strict connect or in a
1734  // different block, fail.
1735  if (!strictConnect || (connect && connect != strictConnect) ||
1736  strictConnect->getBlock() != value.getParentBlock())
1737  return {};
1738  else
1739  connect = strictConnect;
1740  }
1741  }
1742  return connect;
1743 }
1744 
// Forward simple values through wire's and reg's.
//
// If a wire or register is written exactly once by a "simple" source (a
// constant, an invalid, or — for wires — a port), replace all uses of the
// declaration with that source and delete both the declaration and the
// connect. Returns success iff the rewrite was performed.
static LogicalResult canonicalizeSingleSetConnect(StrictConnectOp op,
                                                  PatternRewriter &rewriter) {
  // While we can do this for nearly all wires, we currently limit it to simple
  // things.
  Operation *connectedDecl = op.getDest().getDefiningOp();
  if (!connectedDecl)
    return failure();

  // Only support wire and reg for now.
  if (!isa<WireOp>(connectedDecl) && !isa<RegOp>(connectedDecl))
    return failure();
  // Declarations observable by name, annotation, or force must be preserved.
  if (hasDontTouch(connectedDecl) ||
      !AnnotationSet(connectedDecl).canBeDeleted() ||
      !hasDroppableName(connectedDecl) ||
      cast<Forceable>(connectedDecl).isForceable())
    return failure();

  // Only forward if the types exactly match and there is one connect.
  if (getSingleConnectUserOf(op.getDest()) != op)
    return failure();

  // Only forward if there is more than one use
  if (connectedDecl->hasOneUse())
    return failure();

  // Only do this if the connectee and the declaration are in the same block.
  auto *declBlock = connectedDecl->getBlock();
  auto *srcValueOp = op.getSrc().getDefiningOp();
  if (!srcValueOp) {
    // Ports are ok for wires but not registers.
    if (!isa<WireOp>(connectedDecl))
      return failure();

  } else {
    // Constants/invalids in the same block are ok to forward, even through
    // reg's since the clocking doesn't matter for constants.
    if (!isa<ConstantOp>(srcValueOp) && !isa<InvalidValueOp>(srcValueOp))
      return failure();
    if (srcValueOp->getBlock() != declBlock)
      return failure();
  }

  // Ok, we know we are doing the transformation.

  auto replacement = op.getSrc();
  if (srcValueOp) {
    // Replace with constant zero.
    if (isa<InvalidValueOp>(srcValueOp)) {
      // No scalar zero exists for aggregates; leave invalids on them alone.
      if (isa<BundleType, FVectorType>(op.getDest().getType()))
        return failure();
      if (isa<ClockType, AsyncResetType, ResetType>(op.getDest().getType()))
        replacement = rewriter.create<SpecialConstantOp>(
            op.getSrc().getLoc(), op.getDest().getType(),
            rewriter.getBoolAttr(false));
      else
        replacement = rewriter.create<ConstantOp>(
            op.getSrc().getLoc(), op.getDest().getType(),
            getIntZerosAttr(op.getDest().getType()));
    }
    // This will be replaced with the constant source. First, make sure the
    // constant dominates all users.
    else if (srcValueOp != &declBlock->front()) {
      srcValueOp->moveBefore(&declBlock->front());
    }
  }

  // Replace all things *using* the decl with the constant/port, and
  // remove the declaration.
  replaceOpAndCopyName(rewriter, connectedDecl, replacement);

  // Remove the connect
  rewriter.eraseOp(op);
  return success();
}
1820 
1821 void ConnectOp::getCanonicalizationPatterns(RewritePatternSet &results,
1822  MLIRContext *context) {
1823  results.insert<patterns::ConnectExtension, patterns::ConnectSameType>(
1824  context);
1825 }
1826 
1827 LogicalResult StrictConnectOp::canonicalize(StrictConnectOp op,
1828  PatternRewriter &rewriter) {
1829  // TODO: Canonicalize towards explicit extensions and flips here.
1830 
1831  // If there is a simple value connected to a foldable decl like a wire or reg,
1832  // see if we can eliminate the decl.
1833  if (succeeded(canonicalizeSingleSetConnect(op, rewriter)))
1834  return success();
1835  return failure();
1836 }
1837 
1838 //===----------------------------------------------------------------------===//
1839 // Statements
1840 //===----------------------------------------------------------------------===//
1841 
/// If the specified value has an AttachOp user strictly dominating (i.e.
/// located earlier in the same block than) "dominatedAttach", return it.
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach) {
  for (auto *user : value.getUsers()) {
    auto attach = dyn_cast<AttachOp>(user);
    // Skip non-attach users and the attach we are comparing against.
    if (!attach || attach == dominatedAttach)
      continue;
    // Block order stands in for dominance here; both ops share a block.
    if (attach->isBeforeInBlock(dominatedAttach))
      return attach;
  }
  return {};
}
1854 
LogicalResult AttachOp::canonicalize(AttachOp op, PatternRewriter &rewriter) {
  // Single operand attaches are a noop.
  if (op.getNumOperands() <= 1) {
    rewriter.eraseOp(op);
    return success();
  }

  for (auto operand : op.getOperands()) {
    // Check to see if any of our operands has other attaches to it:
    //    attach x, y
    //      ...
    //    attach x, z
    // If so, we can merge these into "attach x, y, z".
    if (auto attach = getDominatingAttachUser(operand, op)) {
      SmallVector<Value> newOperands(op.getOperands());
      for (auto newOperand : attach.getOperands())
        if (newOperand != operand) // Don't add operand twice.
          newOperands.push_back(newOperand);
      // Build the merged attach, then drop both originals.
      rewriter.create<AttachOp>(op->getLoc(), newOperands);
      rewriter.eraseOp(attach);
      rewriter.eraseOp(op);
      return success();
    }

    // If this wire is *only* used by an attach then we can just delete
    // it.
    // TODO: May need to be sensitive to "don't touch" or other
    // annotations.
    if (auto wire = dyn_cast_or_null<WireOp>(operand.getDefiningOp())) {
      if (!hasDontTouch(wire.getOperation()) && wire->hasOneUse() &&
          !wire.isForceable()) {
        // Rebuild the attach without the dead wire, then erase wire and op.
        SmallVector<Value> newOperands;
        for (auto newOperand : op.getOperands())
          if (newOperand != operand) // Don't add the wire.
            newOperands.push_back(newOperand);

        rewriter.create<AttachOp>(op->getLoc(), newOperands);
        rewriter.eraseOp(op);
        rewriter.eraseOp(wire);
        return success();
      }
    }
  }
  return failure();
}
1900 
1901 namespace {
// Remove private nodes. If they have interesting names, move the name to
// the source expression.
struct FoldNodeName : public mlir::RewritePattern {
  FoldNodeName(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    auto name = node.getNameAttr();
    // Nodes with meaningful names, inner symbols, undeletable annotations, or
    // forceable refs must stick around.
    if (!node.hasDroppableName() || node.getInnerSym() ||
        !AnnotationSet(node).canBeDeleted() || node.isForceable())
      return failure();
    auto *newOp = node.getInput().getDefiningOp();
    // Best effort, do not rename InstanceOp
    if (newOp && !isa<InstanceOp>(newOp))
      updateName(rewriter, newOp, name);
    rewriter.replaceOp(node, node.getInput());
    return success();
  }
};
1922 
// Bypass nodes: redirect all users of the node's result to the node's input,
// leaving the (now unused) node in place for other patterns to clean up.
struct NodeBypass : public mlir::RewritePattern {
  NodeBypass(MLIRContext *context)
      : RewritePattern(NodeOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto node = cast<NodeOp>(op);
    // Symbols, undeletable annotations, and forceable refs pin the node; a
    // node with no uses has nothing to bypass.
    if (node.getInnerSym() || !AnnotationSet(node).canBeDeleted() ||
        node.use_empty() || node.isForceable())
      return failure();
    // The node is modified in place (not erased), so notify the rewriter via
    // a root update instead of replaceOp.
    rewriter.startRootUpdate(node);
    node.getResult().replaceAllUsesWith(node.getInput());
    rewriter.finalizeRootUpdate(node);
    return success();
  }
};
1939 
1940 } // namespace
1941 
1942 template <typename OpTy>
1943 static LogicalResult demoteForceableIfUnused(OpTy op,
1944  PatternRewriter &rewriter) {
1945  if (!op.isForceable() || !op.getDataRef().use_empty())
1946  return failure();
1947 
1948  firrtl::detail::replaceWithNewForceability(op, false, &rewriter);
1949  return success();
1950 }
1951 
1952 // Interesting names and symbols and don't touch force nodes to stick around.
1953 LogicalResult NodeOp::fold(FoldAdaptor adaptor,
1954  SmallVectorImpl<OpFoldResult> &results) {
1955  if (!hasDroppableName())
1956  return failure();
1957  if (hasDontTouch(getResult())) // handles inner symbols
1958  return failure();
1959  if (getAnnotationsAttr() &&
1960  !AnnotationSet(getAnnotationsAttr()).canBeDeleted())
1961  return failure();
1962  if (isForceable())
1963  return failure();
1964  if (!adaptor.getInput())
1965  return failure();
1966 
1967  results.push_back(adaptor.getInput());
1968  return success();
1969 }
1970 
1971 void NodeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1972  MLIRContext *context) {
1973  results.insert<FoldNodeName>(context);
1974  results.add(demoteForceableIfUnused<NodeOp>);
1975 }
1976 
1977 namespace {
1978 // For a lhs, find all the writers of fields of the aggregate type. If there
1979 // is one writer for each field, merge the writes
1980 struct AggOneShot : public mlir::RewritePattern {
1981  AggOneShot(StringRef name, uint32_t weight, MLIRContext *context)
1982  : RewritePattern(name, 0, context) {}
1983 
1984  SmallVector<Value> getCompleteWrite(Operation *lhs) const {
1985  auto lhsTy = lhs->getResult(0).getType();
1986  if (!type_isa<BundleType, FVectorType>(lhsTy))
1987  return {};
1988 
1989  DenseMap<uint32_t, Value> fields;
1990  for (Operation *user : lhs->getResult(0).getUsers()) {
1991  if (user->getParentOp() != lhs->getParentOp())
1992  return {};
1993  if (auto aConnect = dyn_cast<StrictConnectOp>(user)) {
1994  if (aConnect.getDest() == lhs->getResult(0))
1995  return {};
1996  } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
1997  for (Operation *subuser : subField.getResult().getUsers()) {
1998  if (auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
1999  if (aConnect.getDest() == subField) {
2000  if (subuser->getParentOp() != lhs->getParentOp())
2001  return {};
2002  if (fields.count(subField.getFieldIndex())) // duplicate write
2003  return {};
2004  fields[subField.getFieldIndex()] = aConnect.getSrc();
2005  }
2006  continue;
2007  }
2008  return {};
2009  }
2010  } else if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2011  for (Operation *subuser : subIndex.getResult().getUsers()) {
2012  if (auto aConnect = dyn_cast<StrictConnectOp>(subuser)) {
2013  if (aConnect.getDest() == subIndex) {
2014  if (subuser->getParentOp() != lhs->getParentOp())
2015  return {};
2016  if (fields.count(subIndex.getIndex())) // duplicate write
2017  return {};
2018  fields[subIndex.getIndex()] = aConnect.getSrc();
2019  }
2020  continue;
2021  }
2022  return {};
2023  }
2024  } else {
2025  return {};
2026  }
2027  }
2028 
2029  SmallVector<Value> values;
2030  uint32_t total = type_isa<BundleType>(lhsTy)
2031  ? type_cast<BundleType>(lhsTy).getNumElements()
2032  : type_cast<FVectorType>(lhsTy).getNumElements();
2033  for (uint32_t i = 0; i < total; ++i) {
2034  if (!fields.count(i))
2035  return {};
2036  values.push_back(fields[i]);
2037  }
2038  return values;
2039  }
2040 
2041  LogicalResult matchAndRewrite(Operation *op,
2042  PatternRewriter &rewriter) const override {
2043  auto values = getCompleteWrite(op);
2044  if (values.empty())
2045  return failure();
2046  rewriter.setInsertionPointToEnd(op->getBlock());
2047  auto dest = op->getResult(0);
2048  auto destType = dest.getType();
2049 
2050  // If not passive, cannot strictconnect.
2051  if (!type_cast<FIRRTLBaseType>(destType).isPassive())
2052  return failure();
2053 
2054  Value newVal = type_isa<BundleType>(destType)
2055  ? rewriter.createOrFold<BundleCreateOp>(op->getLoc(),
2056  destType, values)
2057  : rewriter.createOrFold<VectorCreateOp>(
2058  op->getLoc(), destType, values);
2059  rewriter.createOrFold<StrictConnectOp>(op->getLoc(), dest, newVal);
2060  for (Operation *user : dest.getUsers()) {
2061  if (auto subIndex = dyn_cast<SubindexOp>(user)) {
2062  for (Operation *subuser :
2063  llvm::make_early_inc_range(subIndex.getResult().getUsers()))
2064  if (auto aConnect = dyn_cast<StrictConnectOp>(subuser))
2065  if (aConnect.getDest() == subIndex)
2066  rewriter.eraseOp(aConnect);
2067  } else if (auto subField = dyn_cast<SubfieldOp>(user)) {
2068  for (Operation *subuser :
2069  llvm::make_early_inc_range(subField.getResult().getUsers()))
2070  if (auto aConnect = dyn_cast<StrictConnectOp>(subuser))
2071  if (aConnect.getDest() == subField)
2072  rewriter.eraseOp(aConnect);
2073  }
2074  }
2075  return success();
2076  }
2077 };
2078 
// Merge per-field writes to a wire into one aggregate write.
struct WireAggOneShot : public AggOneShot {
  WireAggOneShot(MLIRContext *context)
      : AggOneShot(WireOp::getOperationName(), 0, context) {}
};
// Merge per-field writes to a subindex result into one aggregate write.
struct SubindexAggOneShot : public AggOneShot {
  SubindexAggOneShot(MLIRContext *context)
      : AggOneShot(SubindexOp::getOperationName(), 0, context) {}
};
// Merge per-field writes to a subfield result into one aggregate write.
struct SubfieldAggOneShot : public AggOneShot {
  SubfieldAggOneShot(MLIRContext *context)
      : AggOneShot(SubfieldOp::getOperationName(), 0, context) {}
};
2091 } // namespace
2092 
2093 void WireOp::getCanonicalizationPatterns(RewritePatternSet &results,
2094  MLIRContext *context) {
2095  results.insert<WireAggOneShot>(context);
2096  results.add(demoteForceableIfUnused<WireOp>);
2097 }
2098 
2099 void SubindexOp::getCanonicalizationPatterns(RewritePatternSet &results,
2100  MLIRContext *context) {
2101  results.insert<SubindexAggOneShot>(context);
2102 }
2103 
2104 OpFoldResult SubindexOp::fold(FoldAdaptor adaptor) {
2105  auto attr = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2106  if (!attr)
2107  return {};
2108  return attr[getIndex()];
2109 }
2110 
2111 OpFoldResult SubfieldOp::fold(FoldAdaptor adaptor) {
2112  auto attr = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2113  if (!attr)
2114  return {};
2115  auto index = getFieldIndex();
2116  return attr[index];
2117 }
2118 
2119 void SubfieldOp::getCanonicalizationPatterns(RewritePatternSet &results,
2120  MLIRContext *context) {
2121  results.insert<SubfieldAggOneShot>(context);
2122 }
2123 
2124 static Attribute collectFields(MLIRContext *context,
2125  ArrayRef<Attribute> operands) {
2126  for (auto operand : operands)
2127  if (!operand)
2128  return {};
2129  return ArrayAttr::get(context, operands);
2130 }
2131 
OpFoldResult BundleCreateOp::fold(FoldAdaptor adaptor) {
  // bundle_create(%foo["a"], %foo["b"]) -> %foo when the type of %foo is
  // bundle<a:..., b:...>.
  if (getNumOperands() > 0)
    if (SubfieldOp first = getOperand(0).getDefiningOp<SubfieldOp>())
      if (first.getFieldIndex() == 0 &&
          first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Each remaining operand must read the matching field, in
                // order, of the same source aggregate.
                auto subindex =
                    elem.value().template getDefiningOp<SubfieldOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getFieldIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all operands are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2150 
OpFoldResult VectorCreateOp::fold(FoldAdaptor adaptor) {
  // vector_create(%foo[0], %foo[1]) -> %foo when the type of %foo is
  // vector<..., 2>.
  if (getNumOperands() > 0)
    if (SubindexOp first = getOperand(0).getDefiningOp<SubindexOp>())
      if (first.getIndex() == 0 && first.getInput().getType() == getType() &&
          llvm::all_of(
              llvm::drop_begin(llvm::enumerate(getOperands())), [&](auto elem) {
                // Each remaining operand must read the matching index, in
                // order, of the same source vector.
                auto subindex =
                    elem.value().template getDefiningOp<SubindexOp>();
                return subindex && subindex.getInput() == first.getInput() &&
                       subindex.getIndex() == elem.index();
              }))
        return first.getInput();

  // Otherwise fold to a constant aggregate when all operands are constant.
  return collectFields(getContext(), adaptor.getOperands());
}
2168 
2169 OpFoldResult UninferredResetCastOp::fold(FoldAdaptor adaptor) {
2170  if (getOperand().getType() == getType())
2171  return getOperand();
2172  return {};
2173 }
2174 
2175 namespace {
// A register with constant reset and all connection to either itself or the
// same constant, must be replaced by the constant.
struct FoldResetMux : public mlir::RewritePattern {
  FoldResetMux(MLIRContext *context)
      : RewritePattern(RegResetOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    auto reg = cast<RegResetOp>(op);
    auto reset =
        dyn_cast_or_null<ConstantOp>(reg.getResetValue().getDefiningOp());
    // Registers observable by annotation or force must stick around.
    if (!reset || hasDontTouch(reg.getOperation()) ||
        !AnnotationSet(reg).canBeDeleted() || reg.isForceable())
      return failure();
    // Find the one true connect, or bail
    auto con = getSingleConnectUserOf(reg.getResult());
    if (!con)
      return failure();

    // The connect source must be a mux whose arms are the register itself and
    // a constant equal to the reset value.
    auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
    if (!mux)
      return failure();
    auto *high = mux.getHigh().getDefiningOp();
    auto *low = mux.getLow().getDefiningOp();
    auto constOp = dyn_cast_or_null<ConstantOp>(high);

    if (constOp && low != reg)
      return failure();
    if (dyn_cast_or_null<ConstantOp>(low) && high == reg)
      constOp = dyn_cast<ConstantOp>(low);

    if (!constOp || constOp.getType() != reset.getType() ||
        constOp.getValue() != reset.getValue())
      return failure();

    // Check all types should be typed by now
    auto regTy = reg.getResult().getType();
    if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
        mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
        regTy.getBitWidthOrSentinel() < 0)
      return failure();

    // Ok, we know we are doing the transformation.

    // Make sure the constant dominates all users.
    if (constOp != &con->getBlock()->front())
      constOp->moveBefore(&con->getBlock()->front());

    // Replace the register with the constant.
    replaceOpAndCopyName(rewriter, reg, constOp.getResult());
    // Remove the connect.
    rewriter.eraseOp(con);
    return success();
  }
};
2230 } // namespace
2231 
2232 static bool isDefinedByOneConstantOp(Value v) {
2233  if (auto c = v.getDefiningOp<ConstantOp>())
2234  return c.getValue().isOne();
2235  if (auto sc = v.getDefiningOp<SpecialConstantOp>())
2236  return sc.getValue();
2237  return false;
2238 }
2239 
// A regreset whose reset signal is constant-true always holds its reset
// value; replace it with a node of the reset value (keeping name/annotations).
static LogicalResult
canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
  if (!isDefinedByOneConstantOp(reg.getResetSignal()))
    return failure();

  // Ignore 'passthrough'.
  (void)dropWrite(rewriter, reg->getResult(0), {});
  replaceOpWithNewOpAndCopyName<NodeOp>(
      rewriter, reg, reg.getResetValue(), reg.getNameAttr(), reg.getNameKind(),
      reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
  return success();
}
2252 
2253 void RegResetOp::getCanonicalizationPatterns(RewritePatternSet &results,
2254  MLIRContext *context) {
2255  results.add<patterns::RegResetWithZeroReset, FoldResetMux>(context);
2256  results.add(canonicalizeRegResetWithOneReset);
2257  results.add(demoteForceableIfUnused<RegResetOp>);
2258 }
2259 
2260 // Returns the value connected to a port, if there is only one.
2261 static Value getPortFieldValue(Value port, StringRef name) {
2262  auto portTy = type_cast<BundleType>(port.getType());
2263  auto fieldIndex = portTy.getElementIndex(name);
2264  assert(fieldIndex && "missing field on memory port");
2265 
2266  Value value = {};
2267  for (auto *op : port.getUsers()) {
2268  auto portAccess = cast<SubfieldOp>(op);
2269  if (fieldIndex != portAccess.getFieldIndex())
2270  continue;
2271  auto conn = getSingleConnectUserOf(portAccess);
2272  if (!conn || value)
2273  return {};
2274  value = conn.getSrc();
2275  }
2276  return value;
2277 }
2278 
2279 // Returns true if the enable field of a port is set to false.
2280 static bool isPortDisabled(Value port) {
2281  auto value = getPortFieldValue(port, "en");
2282  if (!value)
2283  return false;
2284  auto portConst = value.getDefiningOp<ConstantOp>();
2285  if (!portConst)
2286  return false;
2287  return portConst.getValue().isZero();
2288 }
2289 
2290 // Returns true if the data output is unused.
2291 static bool isPortUnused(Value port, StringRef data) {
2292  auto portTy = type_cast<BundleType>(port.getType());
2293  auto fieldIndex = portTy.getElementIndex(data);
2294  assert(fieldIndex && "missing enable flag on memory port");
2295 
2296  for (auto *op : port.getUsers()) {
2297  auto portAccess = cast<SubfieldOp>(op);
2298  if (fieldIndex != portAccess.getFieldIndex())
2299  continue;
2300  if (!portAccess.use_empty())
2301  return false;
2302  }
2303 
2304  return true;
2305 }
2306 
// Replaces all accesses to the named field of a port with the given value,
// erasing the subfield accesses. (The old header comment was a copy-paste of
// getPortFieldValue's.)
static void replacePortField(PatternRewriter &rewriter, Value port,
                             StringRef name, Value value) {
  auto portTy = type_cast<BundleType>(port.getType());
  auto fieldIndex = portTy.getElementIndex(name);
  assert(fieldIndex && "missing field on memory port");

  // early_inc_range: we erase users while iterating.
  for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
    auto portAccess = cast<SubfieldOp>(op);
    if (fieldIndex != portAccess.getFieldIndex())
      continue;
    rewriter.replaceAllUsesWith(portAccess, value);
    rewriter.eraseOp(portAccess);
  }
}
2322 
// Remove all accesses to a port that is being eliminated, so the port ends up
// with no remaining uses. Reads are redirected to a never-written dummy
// register so downstream canonicalizers can clean them up.
static void erasePort(PatternRewriter &rewriter, Value port) {
  // Helper to create a dummy 0 clock for the dummy registers.
  Value clock;
  auto getClock = [&] {
    if (!clock)
      clock = rewriter.create<SpecialConstantOp>(
          port.getLoc(), ClockType::get(rewriter.getContext()), false);
    return clock;
  };

  // Find the clock field of the port and determine whether the port is
  // accessed only through its subfields or as a whole wire.  If the port
  // is used in its entirety, replace it with a wire.  Otherwise,
  // eliminate individual subfields and replace with reasonable defaults.
  for (auto *op : port.getUsers()) {
    auto subfield = dyn_cast<SubfieldOp>(op);
    if (!subfield) {
      // Whole-port use: a never-written register stands in for the port.
      auto ty = port.getType();
      auto reg = rewriter.create<RegOp>(port.getLoc(), ty, getClock());
      port.replaceAllUsesWith(reg.getResult());
      return;
    }
  }

  // Remove all connects to field accesses as they are no longer relevant.
  // If field values are used anywhere, which should happen solely for read
  // ports, a dummy register is introduced which replicates the behaviour of
  // memory that is never written, but might be read.
  for (auto *accessOp : llvm::make_early_inc_range(port.getUsers())) {
    auto access = cast<SubfieldOp>(accessOp);
    for (auto *user : llvm::make_early_inc_range(access->getUsers())) {
      auto connect = dyn_cast<FConnectLike>(user);
      if (connect && connect.getDest() == access) {
        rewriter.eraseOp(user);
        continue;
      }
    }
    if (access.use_empty()) {
      rewriter.eraseOp(access);
      continue;
    }

    // Replace read values with a register that is never written, handing off
    // the canonicalization of such a register to another canonicalizer.
    auto ty = access.getType();
    auto reg = rewriter.create<RegOp>(access.getLoc(), ty, getClock());
    rewriter.replaceOp(access, reg.getResult());
  }
  assert(port.use_empty() && "port should have no remaining uses");
}
2374 
2375 namespace {
// If memory has known, but zero width, eliminate it.
struct FoldZeroWidthMemory : public mlir::RewritePattern {
  FoldZeroWidthMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Only memories with a known zero-width data type qualify.
    if (mem.getDataType().getBitWidthOrSentinel() != 0)
      return failure();

    // Make sure all users are safe to replace.
    for (auto port : mem.getResults())
      for (auto *user : port.getUsers())
        if (!isa<SubfieldOp>(user))
          return failure();

    // Annoyingly, there isn't a good replacement for the port as a whole,
    // since they have an outer flip type.
    for (auto port : op->getResults()) {
      for (auto *user : llvm::make_early_inc_range(port.getUsers())) {
        // Each field access becomes a dangling wire of the same type.
        SubfieldOp sfop = cast<SubfieldOp>(user);
        replaceOpWithNewOpAndCopyName<WireOp>(rewriter, sfop,
                                              sfop.getResult().getType());
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2408 
// If memory has no write ports and no file initialization, eliminate it.
// (A memory that is only written, or only read without initialization,
// can never produce an observable value.)
struct FoldReadOrWriteOnlyMemory : public mlir::RewritePattern {
  FoldReadOrWriteOnlyMemory(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // The memory must be exclusively read or exclusively written; any mix
    // (or a debug/read-write port) disqualifies it.
    bool isRead = false, isWritten = false;
    for (unsigned i = 0; i < mem.getNumResults(); ++i) {
      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Read:
        isRead = true;
        if (isWritten)
          return failure();
        continue;
      case MemOp::PortKind::Write:
        isWritten = true;
        if (isRead)
          return failure();
        continue;
      case MemOp::PortKind::Debug:
      case MemOp::PortKind::ReadWrite:
        return failure();
      }
      llvm_unreachable("unknown port kind");
    }
    assert((!isWritten || !isRead) && "memory is in use");

    // If the memory is read only, but has a file initialization, then we can't
    // remove it. A write only memory with file initialization is okay to
    // remove.
    if (isRead && mem.getInit())
      return failure();

    for (auto port : mem.getResults())
      erasePort(rewriter, port);

    rewriter.eraseOp(op);
    return success();
  }
};
2452 
// Eliminate the dead ports of memories: ports that are disabled, or read
// ports whose data output is never used. The memory is rebuilt without them.
struct FoldUnusedPorts : public mlir::RewritePattern {
  FoldUnusedPorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();
    // Identify the dead and changed ports.
    llvm::SmallBitVector deadPorts(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Skip debug ports.
      auto kind = mem.getPortKind(i);
      if (kind == MemOp::PortKind::Debug)
        continue;

      // If a port is disabled, always eliminate it.
      if (isPortDisabled(port)) {
        deadPorts.set(i);
        continue;
      }
      // Eliminate read ports whose outputs are not used.
      if (kind == MemOp::PortKind::Read && isPortUnused(port, "data")) {
        deadPorts.set(i);
        continue;
      }
    }
    if (deadPorts.none())
      return failure();

    // Rebuild the new memory with the altered ports.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        continue;
      resultTypes.push_back(port.getType());
      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    // If every port died, no replacement memory is created at all.
    MemOp newOp;
    if (!resultTypes.empty())
      newOp = rewriter.create<MemOp>(
          mem.getLoc(), resultTypes, mem.getReadLatency(),
          mem.getWriteLatency(), mem.getDepth(), mem.getRuw(),
          rewriter.getStrArrayAttr(portNames), mem.getName(), mem.getNameKind(),
          mem.getAnnotations(), rewriter.getArrayAttr(portAnnotations),
          mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Replace the dead ports with dummy wires.
    unsigned nextPort = 0;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadPorts[i])
        erasePort(rewriter, port);
      else
        port.replaceAllUsesWith(newOp.getResult(nextPort++));
    }

    rewriter.eraseOp(op);
    return success();
  }
};
2522 
// Rewrite write-only read-write ports to write ports: if a read-write port's
// "rdata" output is never used, rebuild the memory with a plain write port in
// its place and re-wire the field accesses.
struct FoldReadWritePorts : public mlir::RewritePattern {
  FoldReadWritePorts(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    if (hasDontTouch(mem))
      return failure();

    // Identify read-write ports whose read end is unused.
    llvm::SmallBitVector deadReads(mem.getNumResults());
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (mem.getPortKind(i) != MemOp::PortKind::ReadWrite)
        continue;
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        continue;
      if (isPortUnused(port, "rdata")) {
        deadReads.set(i);
        continue;
      }
    }
    if (deadReads.none())
      return failure();

    // Rebuild the memory, demoting each dead read-write port to a write port.
    SmallVector<Type> resultTypes;
    SmallVector<StringRef> portNames;
    SmallVector<Attribute> portAnnotations;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (deadReads[i])
        resultTypes.push_back(
            MemOp::getTypeForPort(mem.getDepth(), mem.getDataType(),
                                  MemOp::PortKind::Write, mem.getMaskBits()));
      else
        resultTypes.push_back(port.getType());

      portNames.push_back(mem.getPortName(i));
      portAnnotations.push_back(mem.getPortAnnotation(i));
    }

    auto newOp = rewriter.create<MemOp>(
        mem.getLoc(), resultTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), rewriter.getStrArrayAttr(portNames),
        mem.getName(), mem.getNameKind(), mem.getAnnotations(),
        rewriter.getArrayAttr(portAnnotations), mem.getInnerSymAttr(),
        mem.getInitAttr(), mem.getPrefixAttr());

    for (unsigned i = 0, n = mem.getNumResults(); i < n; ++i) {
      auto result = mem.getResult(i);
      auto newResult = newOp.getResult(i);
      if (deadReads[i]) {
        auto resultPortTy = type_cast<BundleType>(result.getType());

        // Rewrite accesses to the old port field to accesses to a
        // corresponding field of the new port.
        auto replace = [&](StringRef toName, StringRef fromName) {
          auto fromFieldIndex = resultPortTy.getElementIndex(fromName);
          // NOTE(review): message is a copy-paste; this asserts the presence
          // of `fromName`, not the enable flag.
          assert(fromFieldIndex && "missing enable flag on memory port");

          auto toField = rewriter.create<SubfieldOp>(newResult.getLoc(),
                                                     newResult, toName);
          for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
            auto fromField = cast<SubfieldOp>(op);
            if (fromFieldIndex != fromField.getFieldIndex())
              continue;
            rewriter.replaceOp(fromField, toField.getResult());
          }
        };

        replace("addr", "addr");
        replace("en", "en");
        replace("clk", "clk");
        replace("data", "wdata");
        replace("mask", "wmask");

        // Remove the wmode field, replacing it with dummy wires.
        auto wmodeFieldIndex = resultPortTy.getElementIndex("wmode");
        for (auto *op : llvm::make_early_inc_range(result.getUsers())) {
          auto wmodeField = cast<SubfieldOp>(op);
          if (wmodeFieldIndex != wmodeField.getFieldIndex())
            continue;
          rewriter.replaceOpWithNewOp<WireOp>(wmodeField, wmodeField.getType());
        }
      } else {
        result.replaceAllUsesWith(newResult);
      }
    }
    rewriter.eraseOp(op);
    return success();
  }
};
2614 
// Narrow a memory's data type to the bits that are actually observed. Read
// data used only through `bits` ops contributes just the selected ranges;
// unused bit ranges are compacted away and reads/writes are remapped.
// (NOTE(review): the previous comment said "dead ports" — that describes
// FoldUnusedPorts, not this pattern.)
struct FoldUnusedBits : public mlir::RewritePattern {
  FoldUnusedBits(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}

  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    // Respect user intent: annotated-as-dont-touch memories must be kept.
    if (hasDontTouch(mem))
      return failure();

    // Only apply the transformation if the memory is not sequential.
    const auto &summary = mem.getSummary();
    if (summary.isMasked || summary.isSeqMem())
      return failure();

    // The data type must be an integer with a known, positive width; other
    // types cannot be sliced with `bits` ops.
    auto type = type_dyn_cast<IntType>(mem.getDataType());
    if (!type)
      return failure();
    auto width = type.getBitWidthOrSentinel();
    if (width <= 0)
      return failure();

    // `usedBits[i]` is set when data bit i is observed by some read.
    // `mapping` will translate old bit positions (the `lo` of each read
    // slice) to positions in the compacted word.
    llvm::SmallBitVector usedBits(width);
    DenseMap<unsigned, unsigned> mapping;

    // Find which bits are used out of the users of a read port. This detects
    // ports whose data/rdata field is used only through bit select ops. The
    // bit selects are then used to build a bit-mask. The ops are collected.
    SmallVector<BitsPrimOp> readOps;
    auto findReadUsers = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        for (auto *user : op->getUsers()) {
          auto bits = dyn_cast<BitsPrimOp>(user);
          if (!bits) {
            // A non-`bits` user may observe any bit: mark the whole word
            // used, which makes the `usedBits.all()` bail-out below fire.
            usedBits.set();
            continue;
          }

          usedBits.set(bits.getLo(), bits.getHi() + 1);
          // Placeholder; the real compacted offset is filled in later.
          mapping[bits.getLo()] = 0;
          readOps.push_back(bits);
        }
      }
    };

    // Finds the users of write ports. This expects all the data/wdata fields
    // of the ports to be used solely as the destination of strict connects.
    // If a memory has ports with other uses, it is excluded from optimisation.
    SmallVector<StrictConnectOp> writeOps;
    auto findWriteUsers = [&](Value port, StringRef field) -> LogicalResult {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      for (auto *op : port.getUsers()) {
        auto portAccess = cast<SubfieldOp>(op);
        if (fieldIndex != portAccess.getFieldIndex())
          continue;

        auto conn = getSingleConnectUserOf(portAccess);
        if (!conn)
          return failure();

        writeOps.push_back(conn);
      }
      return success();
    };

    // Traverse all ports and find the read and used data fields.
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      // Do not simplify annotated ports.
      if (!mem.getPortAnnotation(i).empty())
        return failure();

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Skip debug ports.
        return failure();
      case MemOp::PortKind::Write:
        if (failed(findWriteUsers(port, "data")))
          return failure();
        continue;
      case MemOp::PortKind::Read:
        findReadUsers(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        if (failed(findWriteUsers(port, "wdata")))
          return failure();
        findReadUsers(port, "rdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Perform the transformation if there are some bits missing. Unused
    // memories are handled in a different canonicalizer.
    if (usedBits.all() || usedBits.none())
      return failure();

    // Build a mapping of existing indices to compacted ones. `ranges` keeps
    // the [start, end] (inclusive) runs of used bits in the old word.
    SmallVector<std::pair<unsigned, unsigned>> ranges;
    unsigned newWidth = 0;
    for (int i = usedBits.find_first(); 0 <= i && i < width;) {
      int e = usedBits.find_next_unset(i);
      if (e < 0)
        e = width;
      // Assign each bit of the used run its compacted position.
      for (int idx = i; idx < e; ++idx, ++newWidth) {
        if (auto it = mapping.find(idx); it != mapping.end()) {
          it->second = newWidth;
        }
      }
      ranges.emplace_back(i, e - 1);
      i = e != width ? usedBits.find_next(e) : e;
    }

    // Create the new op with the new port types.
    auto newType = IntType::get(op->getContext(), type.isSigned(), newWidth);
    SmallVector<Type> portTypes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      portTypes.push_back(
          MemOp::getTypeForPort(mem.getDepth(), newType, mem.getPortKind(i)));
    }
    auto newMem = rewriter.replaceOpWithNewOp<MemOp>(
        mem, portTypes, mem.getReadLatency(), mem.getWriteLatency(),
        mem.getDepth(), mem.getRuw(), mem.getPortNames(), mem.getName(),
        mem.getNameKind(), mem.getAnnotations(), mem.getPortAnnotations(),
        mem.getInnerSymAttr(), mem.getInitAttr(), mem.getPrefixAttr());

    // Rewrite bundle users to the new data type: create a fresh subfield on
    // the new port and redirect all old accesses of the same field to it.
    auto rewriteSubfield = [&](Value port, StringRef field) {
      auto portTy = type_cast<BundleType>(port.getType());
      auto fieldIndex = portTy.getElementIndex(field);
      assert(fieldIndex && "missing data port");

      rewriter.setInsertionPointAfter(newMem);
      auto newPortAccess =
          rewriter.create<SubfieldOp>(port.getLoc(), port, field);

      for (auto *op : llvm::make_early_inc_range(port.getUsers())) {
        auto portAccess = cast<SubfieldOp>(op);
        // Skip the access we just created to avoid replacing it by itself.
        if (op == newPortAccess || fieldIndex != portAccess.getFieldIndex())
          continue;
        rewriter.replaceOp(portAccess, newPortAccess.getResult());
      }
    };

    // Rewrite the field accesses.
    for (auto [i, port] : llvm::enumerate(newMem.getResults())) {
      switch (newMem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Debug ports made the pattern bail above; cannot appear here.
        llvm_unreachable("cannot rewrite debug port");
      case MemOp::PortKind::Write:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::Read:
        rewriteSubfield(port, "data");
        continue;
      case MemOp::PortKind::ReadWrite:
        rewriteSubfield(port, "rdata");
        rewriteSubfield(port, "wdata");
        continue;
      }
      llvm_unreachable("unknown port kind");
    }

    // Rewrite the reads to the new ranges, compacting them. The new slice
    // keeps the old slice's length but starts at the compacted offset.
    for (auto readOp : readOps) {
      rewriter.setInsertionPointAfter(readOp);
      auto it = mapping.find(readOp.getLo());
      assert(it != mapping.end() && "bit op mapping not found");
      rewriter.replaceOpWithNewOp<BitsPrimOp>(
          readOp, readOp.getInput(),
          readOp.getHi() - readOp.getLo() + it->second, it->second);
    }

    // Rewrite the writes into a concatenation of slices: only the used
    // ranges of the source value are written, in ascending bit order.
    for (auto writeOp : writeOps) {
      Value source = writeOp.getSrc();
      rewriter.setInsertionPoint(writeOp);

      Value catOfSlices;
      for (auto &[start, end] : ranges) {
        Value slice =
            rewriter.create<BitsPrimOp>(writeOp.getLoc(), source, end, start);
        if (catOfSlices) {
          // Later (higher) slices go on the most-significant side.
          catOfSlices =
              rewriter.create<CatPrimOp>(writeOp.getLoc(), slice, catOfSlices);
        } else {
          catOfSlices = slice;
        }
      }
      rewriter.replaceOpWithNewOp<StrictConnectOp>(writeOp, writeOp.getDest(),
                                                   catOfSlices);
    }

    return success();
  }
};
2822 
// Rewrite single-address memories to a firrtl register.
struct FoldRegMems : public mlir::RewritePattern {
  FoldRegMems(MLIRContext *context)
      : RewritePattern(MemOp::getOperationName(), 0, context) {}
  LogicalResult matchAndRewrite(Operation *op,
                                PatternRewriter &rewriter) const override {
    MemOp mem = cast<MemOp>(op);
    const FirMemory &info = mem.getSummary();
    // Only depth-1 memories can be modelled by a single register.
    if (hasDontTouch(mem) || info.depth != 1)
      return failure();

    auto memModule = mem->getParentOfType<FModuleOp>();

    // Find the clock of the register-to-be, all write ports should share it.
    // NOTE(review): ports with annotations are skipped here rather than
    // blocking the rewrite — confirm that is intended.
    Value clock;
    SmallPtrSet<Operation *, 8> connects;
    SmallVector<SubfieldOp> portAccesses;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      if (!mem.getPortAnnotation(i).empty())
        continue;

      // Record the subfield accesses of the given fields and the connects
      // driving them (both erased at the end). Fails if a field has a
      // non-connect user, which would make the port unanalyzable.
      auto collect = [&, port = port](ArrayRef<StringRef> fields) {
        auto portTy = type_cast<BundleType>(port.getType());
        for (auto field : fields) {
          auto fieldIndex = portTy.getElementIndex(field);
          assert(fieldIndex && "missing field on memory port");

          for (auto *op : port.getUsers()) {
            auto portAccess = cast<SubfieldOp>(op);
            if (fieldIndex != portAccess.getFieldIndex())
              continue;
            portAccesses.push_back(portAccess);
            for (auto *user : portAccess->getUsers()) {
              auto conn = dyn_cast<FConnectLike>(user);
              if (!conn)
                return failure();
              connects.insert(conn);
            }
          }
        }
        return success();
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        return failure();
      case MemOp::PortKind::Read:
        if (failed(collect({"clk", "en", "addr"})))
          return failure();
        // Read ports do not constrain the register clock (the register is
        // read directly below), hence the `continue` past the clock check.
        continue;
      case MemOp::PortKind::Write:
        if (failed(collect({"clk", "en", "addr", "data", "mask"})))
          return failure();
        break;
      case MemOp::PortKind::ReadWrite:
        if (failed(collect({"clk", "en", "addr", "wmode", "wdata", "wmask"})))
          return failure();
        break;
      }

      // All writing ports must agree on a single, known clock.
      Value portClock = getPortFieldValue(port, "clk");
      if (!portClock || (clock && portClock != clock))
        return failure();
      clock = portClock;
    }

    // Create a new register to store the data.
    auto ty = mem.getDataType();
    rewriter.setInsertionPointAfterValue(clock);
    auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
                   .getResult();

    // Helper to insert a given number of pipeline stages through registers.
    // Each stage is a register named "<mem>_<name>_<i>" fed by the previous
    // stage; returns the last stage's value.
    auto pipeline = [&](Value value, Value clock, const Twine &name,
                        unsigned latency) {
      for (unsigned i = 0; i < latency; ++i) {
        std::string regName;
        {
          llvm::raw_string_ostream os(regName);
          os << mem.getName() << "_" << name << "_" << i;
        }

        auto reg = rewriter
                       .create<RegOp>(mem.getLoc(), value.getType(), clock,
                                      rewriter.getStringAttr(regName))
                       .getResult();
        rewriter.create<StrictConnectOp>(value.getLoc(), reg, value);
        value = reg;
      }
      return value;
    };

    // A write latency of N requires N-1 pipeline registers before the data
    // reaches the storage register.
    const unsigned writeStages = info.writeLatency - 1;

    // Traverse each port. Replace reads with the pipelined register, discarding
    // the enable flag and reading unconditionally. Pipeline the mask, enable
    // and data bits of all write ports to be arbitrated and wired to the reg.
    SmallVector<std::tuple<Value, Value, Value>> writes;
    for (auto [i, port] : llvm::enumerate(mem.getResults())) {
      Value portClock = getPortFieldValue(port, "clk");
      StringRef name = mem.getPortName(i);

      auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
        Value value = getPortFieldValue(port, field);
        assert(value);
        rewriter.setInsertionPointAfterValue(value);
        return pipeline(value, portClock, name + "_" + field, stages);
      };

      switch (mem.getPortKind(i)) {
      case MemOp::PortKind::Debug:
        // Debug ports bailed out in the first traversal; cannot occur here.
        llvm_unreachable("unknown port kind");
      case MemOp::PortKind::Read: {
        // Read ports pipeline the addr and enable signals. However, the
        // address must be 0 for single-address memories and the enable signal
        // is ignored, always reading out the register. Under these constraints,
        // the read port can be replaced with the value from the register.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "data", reg);
        break;
      }
      case MemOp::PortKind::Write: {
        auto data = portPipeline("data", writeStages);
        auto en = portPipeline("en", writeStages);
        auto mask = portPipeline("mask", writeStages);
        writes.emplace_back(data, en, mask);
        break;
      }
      case MemOp::PortKind::ReadWrite: {
        // Always read the register into the read end.
        rewriter.setInsertionPointAfterValue(reg);
        replacePortField(rewriter, port, "rdata", reg);

        // Create a write enable (en && wmode) and pipeline stages.
        auto wdata = portPipeline("wdata", writeStages);
        auto wmask = portPipeline("wmask", writeStages);

        Value en = getPortFieldValue(port, "en");
        Value wmode = getPortFieldValue(port, "wmode");
        rewriter.setInsertionPointToEnd(memModule.getBodyBlock());

        auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
        auto wenPipelined =
            pipeline(wen, portClock, name + "_wen", writeStages);
        writes.emplace_back(wdata, wenPipelined, wmask);
        break;
      }
      }
    }

    // Regardless of `writeUnderWrite`, always implement PortOrder: later
    // ports in the chain below take precedence over earlier ones.
    rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
    Value next = reg;
    for (auto &[data, en, mask] : writes) {
      Value masked;

      // If a mask bit is used, emit muxes to select the input from the
      // register (no mask) or the input (mask bit set).
      Location loc = mem.getLoc();
      unsigned maskGran = info.dataWidth / info.maskBits;
      for (unsigned i = 0; i < info.maskBits; ++i) {
        // The bit range of the data word governed by mask bit i.
        unsigned hi = (i + 1) * maskGran - 1;
        unsigned lo = i * maskGran;

        auto dataPart = rewriter.createOrFold<BitsPrimOp>(loc, data, hi, lo);
        auto nextPart = rewriter.createOrFold<BitsPrimOp>(loc, next, hi, lo);
        auto bit = rewriter.createOrFold<BitsPrimOp>(loc, mask, i, i);
        auto chunk = rewriter.create<MuxPrimOp>(loc, bit, dataPart, nextPart);

        if (masked) {
          masked = rewriter.create<CatPrimOp>(loc, chunk, masked);
        } else {
          masked = chunk;
        }
      }

      next = rewriter.create<MuxPrimOp>(next.getLoc(), en, masked, next);
    }
    rewriter.create<StrictConnectOp>(reg.getLoc(), reg, next);

    // Delete the fields and their associated connects.
    for (Operation *conn : connects)
      rewriter.eraseOp(conn);
    for (auto portAccess : portAccesses)
      rewriter.eraseOp(portAccess);
    rewriter.eraseOp(mem);

    return success();
  }
};
3013 } // namespace
3014 
3015 void MemOp::getCanonicalizationPatterns(RewritePatternSet &results,
3016  MLIRContext *context) {
3017  results
3018  .insert<FoldZeroWidthMemory, FoldReadOrWriteOnlyMemory,
3019  FoldReadWritePorts, FoldUnusedPorts, FoldUnusedBits, FoldRegMems>(
3020  context);
3021 }
3022 
3023 //===----------------------------------------------------------------------===//
3024 // Declarations
3025 //===----------------------------------------------------------------------===//
3026 
// Turn synchronous reset looking register updates to registers with resets.
// Also, const prop registers that are driven by a mux tree containing only
// instances of one constant or self-assigns.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter) {
  // reg ; connect(reg, mux(port, const, val)) ->
  // reg.reset(port, const); connect(reg, val)

  // Find the one true connect, or bail
  auto con = getSingleConnectUserOf(reg.getResult());
  if (!con)
    return failure();

  // The register must be driven by a mux whose operands we can inspect.
  auto mux = dyn_cast_or_null<MuxPrimOp>(con.getSrc().getDefiningOp());
  if (!mux)
    return failure();
  auto *high = mux.getHigh().getDefiningOp();
  auto *low = mux.getLow().getDefiningOp();
  // Reset value must be constant
  auto constOp = dyn_cast_or_null<ConstantOp>(high);

  // Detect the case if a register only has two possible drivers:
  // (1) itself/uninit and (2) constant.
  // The mux can then be replaced with the constant.
  // r = mux(cond, r, 3) --> r = 3
  // r = mux(cond, 3, r) --> r = 3
  bool constReg = false;

  if (constOp && low == reg)
    constReg = true;
  else if (dyn_cast_or_null<ConstantOp>(low) && high == reg) {
    constReg = true;
    constOp = dyn_cast<ConstantOp>(low);
  }
  if (!constOp)
    return failure();

  // For a non-constant register, reset should be a module port (heuristic to
  // limit to intended reset lines). Replace the register anyway if constant.
  if (!isa<BlockArgument>(mux.getSel()) && !constReg)
    return failure();

  // Check all types should be typed by now
  auto regTy = reg.getResult().getType();
  if (con.getDest().getType() != regTy || con.getSrc().getType() != regTy ||
      mux.getHigh().getType() != regTy || mux.getLow().getType() != regTy ||
      regTy.getBitWidthOrSentinel() < 0)
    return failure();

  // Ok, we know we are doing the transformation.

  // Make sure the constant dominates all users.
  if (constOp != &con->getBlock()->front())
    constOp->moveBefore(&con->getBlock()->front());

  if (!constReg) {
    // Promote the register to a RegResetOp triggered by the mux selector,
    // preserving any dialect attributes attached to the original register.
    SmallVector<NamedAttribute, 2> attrs(reg->getDialectAttrs());
    auto newReg = replaceOpWithNewOpAndCopyName<RegResetOp>(
        rewriter, reg, reg.getResult().getType(), reg.getClockVal(),
        mux.getSel(), mux.getHigh(), reg.getNameAttr(), reg.getNameKindAttr(),
        reg.getAnnotationsAttr(), reg.getInnerSymAttr(),
        reg.getForceableAttr());
    newReg->setDialectAttrs(attrs);
  }
  // Re-point the connect at the non-reset driver (or the constant when the
  // register folds entirely), restoring the insertion point afterwards.
  auto pt = rewriter.saveInsertionPoint();
  rewriter.setInsertionPoint(con);
  auto v = constReg ? (Value)constOp.getResult() : (Value)mux.getLow();
  replaceOpWithNewOpAndCopyName<ConnectOp>(rewriter, con, con.getDest(), v);
  rewriter.restoreInsertionPoint(pt);
  return success();
}
3097 
3098 LogicalResult RegOp::canonicalize(RegOp op, PatternRewriter &rewriter) {
3099  if (!hasDontTouch(op.getOperation()) && !op.isForceable() &&
3100  succeeded(foldHiddenReset(op, rewriter)))
3101  return success();
3102 
3103  if (succeeded(demoteForceableIfUnused(op, rewriter)))
3104  return success();
3105 
3106  return failure();
3107 }
3108 
3109 //===----------------------------------------------------------------------===//
3110 // Verification Ops.
3111 //===----------------------------------------------------------------------===//
3112 
3113 static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate,
3114  Value enable,
3115  PatternRewriter &rewriter,
3116  bool eraseIfZero) {
3117  // If the verification op is never enabled, delete it.
3118  if (auto constant = enable.getDefiningOp<firrtl::ConstantOp>()) {
3119  if (constant.getValue().isZero()) {
3120  rewriter.eraseOp(op);
3121  return success();
3122  }
3123  }
3124 
3125  // If the verification op is never triggered, delete it.
3126  if (auto constant = predicate.getDefiningOp<firrtl::ConstantOp>()) {
3127  if (constant.getValue().isZero() == eraseIfZero) {
3128  rewriter.eraseOp(op);
3129  return success();
3130  }
3131  }
3132 
3133  return failure();
3134 }
3135 
3136 template <class Op, bool EraseIfZero = false>
3137 static LogicalResult canonicalizeImmediateVerifOp(Op op,
3138  PatternRewriter &rewriter) {
3139  return eraseIfZeroOrNotZero(op, op.getPredicate(), op.getEnable(), rewriter,
3140  EraseIfZero);
3141 }
3142 
3143 void AssertOp::getCanonicalizationPatterns(RewritePatternSet &results,
3144  MLIRContext *context) {
3145  results.add(canonicalizeImmediateVerifOp<AssertOp>);
3146 }
3147 
3148 void AssumeOp::getCanonicalizationPatterns(RewritePatternSet &results,
3149  MLIRContext *context) {
3150  results.add(canonicalizeImmediateVerifOp<AssumeOp>);
3151 }
3152 
3153 void CoverOp::getCanonicalizationPatterns(RewritePatternSet &results,
3154  MLIRContext *context) {
3155  results.add(canonicalizeImmediateVerifOp<CoverOp, /* EraseIfZero = */ true>);
3156 }
3157 
3158 //===----------------------------------------------------------------------===//
3159 // InvalidValueOp
3160 //===----------------------------------------------------------------------===//
3161 
3162 LogicalResult InvalidValueOp::canonicalize(InvalidValueOp op,
3163  PatternRewriter &rewriter) {
3164  // Remove `InvalidValueOp`s with no uses.
3165  if (op.use_empty()) {
3166  rewriter.eraseOp(op);
3167  return success();
3168  }
3169  return failure();
3170 }
3171 
3172 //===----------------------------------------------------------------------===//
3173 // ClockGateIntrinsicOp
3174 //===----------------------------------------------------------------------===//
3175 
// Fold clock gates with constant enables or a constant input clock. The
// false BoolAttr returned below represents a constant low clock; the order
// of the checks determines which fold wins when several apply.
OpFoldResult ClockGateIntrinsicOp::fold(FoldAdaptor adaptor) {
  // Forward the clock if one of the enables is always true.
  if (isConstantOne(adaptor.getEnable()) ||
      isConstantOne(adaptor.getTestEnable()))
    return getInput();

  // Fold to a constant zero clock if the enables are always false.
  if (isConstantZero(adaptor.getEnable()) &&
      (!getTestEnable() || isConstantZero(adaptor.getTestEnable())))
    return BoolAttr::get(getContext(), false);

  // Forward constant zero clocks.
  if (isConstantZero(adaptor.getInput()))
    return BoolAttr::get(getContext(), false);

  return {};
}
3193 
3194 LogicalResult ClockGateIntrinsicOp::canonicalize(ClockGateIntrinsicOp op,
3195  PatternRewriter &rewriter) {
3196  // Remove constant false test enable.
3197  if (auto testEnable = op.getTestEnable()) {
3198  if (auto constOp = testEnable.getDefiningOp<ConstantOp>()) {
3199  if (constOp.getValue().isZero()) {
3200  rewriter.updateRootInPlace(op,
3201  [&] { op.getTestEnableMutable().clear(); });
3202  return success();
3203  }
3204  }
3205  }
3206 
3207  return failure();
3208 }
3209 
3210 //===----------------------------------------------------------------------===//
3211 // Reference Ops.
3212 //===----------------------------------------------------------------------===//
3213 
3214 // refresolve(forceable.ref) -> forceable.data
3215 static LogicalResult
3216 canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter) {
3217  auto forceable = op.getRef().getDefiningOp<Forceable>();
3218  if (!forceable || !forceable.isForceable() ||
3219  op.getRef() != forceable.getDataRef() ||
3220  op.getType() != forceable.getDataType())
3221  return failure();
3222  rewriter.replaceAllUsesWith(op, forceable.getData());
3223  return success();
3224 }
3225 
3226 void RefResolveOp::getCanonicalizationPatterns(RewritePatternSet &results,
3227  MLIRContext *context) {
3228  results.insert<patterns::RefResolveOfRefSend>(context);
3229  results.insert(canonicalizeRefResolveOfForceable);
3230 }
3231 
3232 OpFoldResult RefCastOp::fold(FoldAdaptor adaptor) {
3233  // RefCast is unnecessary if types match.
3234  if (getInput().getType() == getType())
3235  return getInput();
3236  return {};
3237 }
3238 
3239 static bool isConstantZero(Value operand) {
3240  auto constOp = operand.getDefiningOp<ConstantOp>();
3241  return constOp && constOp.getValue().isZero();
3242 }
3243 
3244 template <typename Op>
3245 static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter) {
3246  if (isConstantZero(op.getPredicate())) {
3247  rewriter.eraseOp(op);
3248  return success();
3249  }
3250  return failure();
3251 }
3252 
// Drop forces whose predicate is the constant false.
void RefForceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {
  results.add(eraseIfPredFalse<RefForceOp>);
}
// Drop initial forces whose predicate is the constant false.
void RefForceInitialOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                    MLIRContext *context) {
  results.add(eraseIfPredFalse<RefForceInitialOp>);
}
// Drop releases whose predicate is the constant false.
void RefReleaseOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add(eraseIfPredFalse<RefReleaseOp>);
}
// Drop initial releases whose predicate is the constant false.
void RefReleaseInitialOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add(eraseIfPredFalse<RefReleaseInitialOp>);
}
3269 
3270 //===----------------------------------------------------------------------===//
3271 // HasBeenResetIntrinsicOp
3272 //===----------------------------------------------------------------------===//
3273 
OpFoldResult HasBeenResetIntrinsicOp::fold(FoldAdaptor adaptor) {
  // The folds in here should reflect the ones for `verif::HasBeenResetOp`.

  // Fold to zero if the reset is a constant. In this case the op is either
  // permanently in reset or never resets. Both mean that the reset never
  // finishes, so this op never returns true.
  if (adaptor.getReset())
    return getIntZerosAttr(UIntType::get(getContext(), 1));

  // Fold to zero if the clock is a constant and the reset is synchronous. In
  // that case the reset will never be started. (A uint<1> reset type marks a
  // synchronous reset here.)
  if (isUInt1(getReset().getType()) && adaptor.getClock())
    return getIntZerosAttr(UIntType::get(getContext(), 1));

  // Otherwise nothing can be folded.
  return {};
}
assert(baseType &&"element must be base type")
static bool hasKnownWidthIntTypes(Operation *op)
Return true if this operation's operands and results all have a known width.
Definition: FIRRTLFolds.cpp:70
static LogicalResult canonicalizeSingleSetConnect(StrictConnectOp op, PatternRewriter &rewriter)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static LogicalResult canonicalizeImmediateVerifOp(Op op, PatternRewriter &rewriter)
static StringRef chooseName(StringRef a, StringRef b)
Definition: FIRRTLFolds.cpp:92
static bool isDefinedByOneConstantOp(Value v)
static Attribute collectFields(MLIRContext *context, ArrayRef< Attribute > operands)
static void erasePort(PatternRewriter &rewriter, Value port)
static Value getPortFieldValue(Value port, StringRef name)
static AttachOp getDominatingAttachUser(Value value, AttachOp dominatedAttach)
If the specified value has an AttachOp user strictly dominating by "dominatingAttach" then return it.
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "name" attribute.
static void updateName(PatternRewriter &rewriter, Operation *op, StringAttr name)
Set the name of an op based on the best of two names: The current name, and the name passed in.
static bool isUInt1(Type type)
Return true if this value is 1 bit UInt.
Definition: FIRRTLFolds.cpp:81
static LogicalResult demoteForceableIfUnused(OpTy op, PatternRewriter &rewriter)
static bool isPortDisabled(Value port)
static LogicalResult eraseIfZeroOrNotZero(Operation *op, Value predicate, Value enable, PatternRewriter &rewriter, bool eraseIfZero)
static APInt getMaxSignedValue(unsigned bitWidth)
Get the largest signed value of a given bit width.
static Value dropWrite(PatternRewriter &rewriter, OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:31
static LogicalResult canonicalizePrimOp(Operation *op, PatternRewriter &rewriter, const function_ref< OpFoldResult(ArrayRef< Attribute >)> &canonicalize)
Applies the canonicalization function canonicalize to the given operation.
static void replaceWithBits(Operation *op, Value value, unsigned hiBit, unsigned loBit, PatternRewriter &rewriter)
Replace the specified operation with a 'bits' op from the specified hi/lo bits.
static LogicalResult canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter)
static LogicalResult eraseIfPredFalse(Op op, PatternRewriter &rewriter)
static OpFoldResult foldMux(OpTy op, typename OpTy::FoldAdaptor adaptor)
static APInt getMaxUnsignedValue(unsigned bitWidth)
Get the largest unsigned value of a given bit width.
static void replacePortField(PatternRewriter &rewriter, Value port, StringRef name, Value value)
BinOpKind
This is the policy for folding, which depends on the sort of operator we're processing.
static bool isPortUnused(Value port, StringRef data)
static LogicalResult canonicalizeRefResolveOfForceable(RefResolveOp op, PatternRewriter &rewriter)
static Attribute constFoldFIRRTLBinaryOp(Operation *op, ArrayRef< Attribute > operands, BinOpKind opKind, const function_ref< APInt(const APSInt &, const APSInt &)> &calculate)
Applies the constant folding function calculate to the given operands.
static APInt getMinSignedValue(unsigned bitWidth)
Get the smallest signed value of a given bit width.
static LogicalResult foldHiddenReset(RegOp reg, PatternRewriter &rewriter)
static std::optional< APSInt > getExtendedConstant(Value operand, Attribute constant, int32_t destWidth)
Implicitly replace the operand to a constant folding operation with a const 0 in case the operand is ...
static Value moveNameHint(OpResult old, Value passthrough)
Definition: FIRRTLFolds.cpp:47
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "name" attribute.
int32_t width
Definition: FIRRTL.cpp:27
static size_t bits(::capnp::schema::Type::Reader type)
Return the number of bits used by a Capnp type.
Definition: Schema.cpp:121
static int64_t size(hw::ArrayType mType, capnp::schema::Field::Reader cField)
Returns the expected size of an array (capnp list) in 64-bit words.
Definition: Schema.cpp:193
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:291
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:270
bool hasWidth() const
Return true if this integer type has a known width.
Definition: FIRRTLTypes.h:278
def connect(destination, source)
Definition: support.py:37
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
Forceable replaceWithNewForceability(Forceable op, bool forceable, ::mlir::PatternRewriter *rewriter=nullptr)
Replace a Forceable op with equivalent, changing whether forceable.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
IntegerAttr getIntAttr(Type type, const APInt &value)
Utiility for generating a constant attribute.
bool hasDontTouch(Value value)
Check whether a block argument ("port") or the operation defining a value has a DontTouch annotation,...
Definition: FIRRTLOps.cpp:287
bool hasDroppableName(Operation *op)
Return true if the name is droppable.
StrictConnectOp getSingleConnectUserOf(Value value)
Scan all the uses of the specified value, checking to see if there is exactly one connect that has th...
bool isUselessName(circt::StringRef name)
Return true if this is a useless temporary name produced by FIRRTL.
IntegerAttr getIntZerosAttr(Type type)
Utility for generating a constant zero attribute.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
APSInt extOrTruncZeroWidth(APSInt value, unsigned width)
A safe version of APSInt::extOrTrunc that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:22
APInt sextZeroWidth(APInt value, unsigned width)
A safe version of APInt::sext that will NOT assert on zero-width signed APSInts.
Definition: APInt.cpp:18
static bool isConstantZero(Attribute operand)
Determine whether a constant operand is a zero value.
Definition: FoldUtils.h:27
static bool isConstantOne(Attribute operand)
Determine whether a constant operand is a one value.
Definition: FoldUtils.h:34
def reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition: seq.py:16