CIRCT  20.0.0git
CombFolds.cpp
Go to the documentation of this file.
1 //===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
18 
19 using namespace mlir;
20 using namespace circt;
21 using namespace comb;
22 using namespace matchers;
23 
24 /// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
25 /// such, cross-block dataflow is interpreted as a canonicalization barrier.
26 /// This is a conservative approach which:
27 /// 1. still allows for efficient canonicalization for the common CIRCT usecase
28 /// of comb (comb logic nested inside single-block hw.module's)
29 /// 2. allows comb operations to be used in non-HW container ops - that may use
30 /// MLIR blocks and regions to represent various forms of hierarchical
31 /// abstractions, thus allowing comb to compose with other dialects.
32 static bool hasOperandsOutsideOfBlock(Operation *op) {
33  Block *thisBlock = op->getBlock();
34  return llvm::any_of(op->getOperands(), [&](Value operand) {
35  return operand.getParentBlock() != thisBlock;
36  });
37 }
38 
39 /// Create a new instance of a generic operation that only has value operands,
40 /// and has a single result value whose type matches the first operand.
41 ///
42 /// This should not be used to create instances of ops with attributes or with
43 /// more complicated type signatures.
44 static Value createGenericOp(Location loc, OperationName name,
45  ArrayRef<Value> operands, OpBuilder &builder) {
46  OperationState state(loc, name);
47  state.addOperands(operands);
48  state.addTypes(operands[0].getType());
49  return builder.create(state)->getResult(0);
50 }
51 
52 static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
53  return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
54  value);
55 }
56 
57 /// Flatten concat and mux operands into a vector.
58 static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
59  if (auto concat = v.getDefiningOp<ConcatOp>()) {
60  for (auto op : concat.getOperands())
61  getConcatOperands(op, result);
62  } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
63  for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
64  getConcatOperands(repl.getOperand(), result);
65  } else {
66  result.push_back(v);
67  }
68 }
69 
70 /// A wrapper of `PatternRewriter::replaceOp` to propagate "sv.namehint"
71 /// attribute. If a replaced op has a "sv.namehint" attribute, this function
72 /// propagates the name to the new value.
73 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
74  Value newValue) {
75  if (auto *newOp = newValue.getDefiningOp()) {
76  auto name = op->getAttrOfType<StringAttr>("sv.namehint");
77  if (name && !newOp->hasAttr("sv.namehint"))
78  rewriter.modifyOpInPlace(newOp,
79  [&] { newOp->setAttr("sv.namehint", name); });
80  }
81  rewriter.replaceOp(op, newValue);
82 }
83 
84 /// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate
85 /// "sv.namehint" attribute. If a replaced op has a "sv.namehint" attribute,
86 /// this function propagates the name to the new value.
87 template <typename OpTy, typename... Args>
88 static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
89  Operation *op, Args &&...args) {
90  auto name = op->getAttrOfType<StringAttr>("sv.namehint");
91  auto newOp =
92  rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93  if (name && !newOp->hasAttr("sv.namehint"))
94  rewriter.modifyOpInPlace(newOp,
95  [&] { newOp->setAttr("sv.namehint", name); });
96 
97  return newOp;
98 }
99 
100 // Return true if the op has SV attributes. Note that we cannot use a helper
101 // function `hasSVAttributes` defined under SV dialect because of a cyclic
102 // dependency.
103 static bool hasSVAttributes(Operation *op) {
104  return op->hasAttr("sv.attributes");
105 }
106 
107 namespace {
108 template <typename SubType>
109 struct ComplementMatcher {
110  SubType lhs;
111  ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112  bool match(Operation *op) {
113  auto xorOp = dyn_cast<XorOp>(op);
114  return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
115  }
116 };
117 } // end anonymous namespace
118 
119 template <typename SubType>
120 static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
121  return ComplementMatcher<SubType>(subExpr);
122 }
123 
124 /// Return true if the op will be flattened afterwards. Op will be flattend if
125 /// it has a single user which has a same op type. User must be in same block.
126 static bool shouldBeFlattened(Operation *op) {
127  assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128  "must be commutative operations"));
129  if (op->hasOneUse()) {
130  auto *user = *op->getUsers().begin();
131  return user->getName() == op->getName() &&
132  op->getAttrOfType<UnitAttr>("twoState") ==
133  user->getAttrOfType<UnitAttr>("twoState") &&
134  op->getBlock() == user->getBlock();
135  }
136  return false;
137 }
138 
139 /// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
140 /// as an Op. Returns true if successful, and false otherwise.
141 ///
142 /// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
143 ///
144 static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
145  // Skip if the operation should be flattened by another operation.
146  if (shouldBeFlattened(op))
147  return false;
148 
149  auto inputs = op->getOperands();
150 
151  SmallVector<Value, 4> newOperands;
152  SmallVector<Location, 4> newLocations{op->getLoc()};
153  newOperands.reserve(inputs.size());
154  struct Element {
155  decltype(inputs.begin()) current, end;
156  };
157 
158  SmallVector<Element> worklist;
159  worklist.push_back({inputs.begin(), inputs.end()});
160  bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
161  bool changed = false;
162  while (!worklist.empty()) {
163  auto &element = worklist.back(); // Do not pop. Take ref.
164 
165  // Pop when we finished traversing the current operand range.
166  if (element.current == element.end) {
167  worklist.pop_back();
168  continue;
169  }
170 
171  Value value = *element.current++;
172  auto *flattenOp = value.getDefiningOp();
173  // If not defined by a compatible operation of the same kind and
174  // from the same block, keep this as-is.
175  if (!flattenOp || flattenOp->getName() != op->getName() ||
176  flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
177  flattenOp->getBlock() != op->getBlock()) {
178  newOperands.push_back(value);
179  continue;
180  }
181 
182  // Don't duplicate logic when it has multiple uses.
183  if (!value.hasOneUse()) {
184  // We can fold a multi-use binary operation into this one if this allows a
185  // constant to fold though. For example, fold
186  // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
187  // since the constants will both fold and we end up with the equiv cost.
188  //
189  // We don't do this for add/mul because the hardware won't be shared
190  // between the two ops if duplicated.
191  if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
192  !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
193  !inputs.back().getDefiningOp<hw::ConstantOp>()) {
194  newOperands.push_back(value);
195  continue;
196  }
197  }
198 
199  changed = true;
200 
201  // Otherwise, push operands into worklist.
202  auto flattenOpInputs = flattenOp->getOperands();
203  worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
204  newLocations.push_back(flattenOp->getLoc());
205  }
206 
207  if (!changed)
208  return false;
209 
210  Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
211  op->getName(), newOperands, rewriter);
212  if (binFlag)
213  result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
214 
215  replaceOpAndCopyName(rewriter, op, result);
216  return true;
217 }
218 
219 // Given a range of uses of an operation, find the lowest and highest bits
220 // inclusive that are ever referenced. The range of uses must not be empty.
221 static std::pair<size_t, size_t>
222 getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
223  size_t originalOpWidth) {
224  auto users = op->getUsers();
225  assert(!users.empty() &&
226  "getLowestBitAndHighestBitRequired cannot operate on "
227  "a empty list of uses.");
228 
229  // when we don't want to narrowTrailingBits (namely in arithmetic
230  // operations), forcing lowestBitRequired = 0
231  size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
232  size_t highestBitRequired = 0;
233 
234  for (auto *user : users) {
235  if (auto extractOp = dyn_cast<ExtractOp>(user)) {
236  size_t lowBit = extractOp.getLowBit();
237  size_t highBit =
238  cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
239  highestBitRequired = std::max(highestBitRequired, highBit);
240  lowestBitRequired = std::min(lowestBitRequired, lowBit);
241  continue;
242  }
243 
244  highestBitRequired = originalOpWidth - 1;
245  lowestBitRequired = 0;
246  break;
247  }
248 
249  return {lowestBitRequired, highestBitRequired};
250 }
251 
252 template <class OpTy>
253 static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
254  PatternRewriter &rewriter) {
255  IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
256  if (!opType)
257  return false;
258 
259  auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
260  opType.getWidth());
261  if (range.second + 1 == opType.getWidth() && range.first == 0)
262  return false;
263 
264  SmallVector<Value> args;
265  auto newType = rewriter.getIntegerType(range.second - range.first + 1);
266  for (auto inop : op.getOperands()) {
267  // deal with muxes here
268  if (inop.getType() != op.getType())
269  args.push_back(inop);
270  else
271  args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
272  inop, range.first));
273  }
274  auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
275  newop->setDialectAttrs(op->getDialectAttrs());
276  if (op.getTwoState())
277  newop.setTwoState(true);
278 
279  Value newResult = newop.getResult();
280  if (range.first)
281  newResult = rewriter.createOrFold<ConcatOp>(
282  op.getLoc(), newResult,
283  rewriter.create<hw::ConstantOp>(op.getLoc(),
284  APInt::getZero(range.first)));
285  if (range.second + 1 < opType.getWidth())
286  newResult = rewriter.createOrFold<ConcatOp>(
287  op.getLoc(),
288  rewriter.create<hw::ConstantOp>(
289  op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
290  newResult);
291  rewriter.replaceOp(op, newResult);
292  return true;
293 }
294 
295 //===----------------------------------------------------------------------===//
296 // Unary Operations
297 //===----------------------------------------------------------------------===//
298 
299 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
300  if (hasOperandsOutsideOfBlock(getOperation()))
301  return {};
302 
303  // Replicate one time -> noop.
304  if (cast<IntegerType>(getType()).getWidth() ==
305  getInput().getType().getIntOrFloatBitWidth())
306  return getInput();
307 
308  // Constant fold.
309  if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
310  if (input.getValue().getBitWidth() == 1) {
311  if (input.getValue().isZero())
312  return getIntAttr(
313  APInt::getZero(cast<IntegerType>(getType()).getWidth()),
314  getContext());
315  return getIntAttr(
316  APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
317  getContext());
318  }
319 
320  APInt result = APInt::getZeroWidth();
321  for (auto i = getMultiple(); i != 0; --i)
322  result = result.concat(input.getValue());
323  return getIntAttr(result, getContext());
324  }
325 
326  return {};
327 }
328 
329 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
330  if (hasOperandsOutsideOfBlock(getOperation()))
331  return {};
332 
333  // Constant fold.
334  if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
335  return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
336 
337  return {};
338 }
339 
340 //===----------------------------------------------------------------------===//
341 // Binary Operations
342 //===----------------------------------------------------------------------===//
343 
344 /// Performs constant folding `calculate` with element-wise behavior on the two
345 /// attributes in `operands` and returns the result if possible.
346 static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
347  hw::PEO paramOpcode) {
348  assert(operands.size() == 2 && "binary op takes two operands");
349  if (!operands[0] || !operands[1])
350  return {};
351 
352  // Fold constants with ParamExprAttr::get which handles simple constants as
353  // well as parameter expressions.
354  return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
355  cast<TypedAttr>(operands[1]));
356 }
357 
358 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
359  if (hasOperandsOutsideOfBlock(getOperation()))
360  return {};
361 
362  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
363  unsigned shift = rhs.getValue().getZExtValue();
364  unsigned width = getType().getIntOrFloatBitWidth();
365  if (shift == 0)
366  return getOperand(0);
367  if (width <= shift)
368  return getIntAttr(APInt::getZero(width), getContext());
369  }
370 
371  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
372 }
373 
374 LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
375  if (hasOperandsOutsideOfBlock(&*op))
376  return failure();
377 
378  // ShlOp(x, cst) -> Concat(Extract(x), zeros)
379  APInt value;
380  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
381  return failure();
382 
383  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
384  unsigned shift = value.getZExtValue();
385 
386  // This case is handled by fold.
387  if (width <= shift || shift == 0)
388  return failure();
389 
390  auto zeros =
391  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
392 
393  // Remove the high bits which would be removed by the Shl.
394  auto extract =
395  rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
396 
397  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
398  return success();
399 }
400 
401 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
402  if (hasOperandsOutsideOfBlock(getOperation()))
403  return {};
404 
405  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
406  unsigned shift = rhs.getValue().getZExtValue();
407  if (shift == 0)
408  return getOperand(0);
409 
410  unsigned width = getType().getIntOrFloatBitWidth();
411  if (width <= shift)
412  return getIntAttr(APInt::getZero(width), getContext());
413  }
414  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
415 }
416 
417 LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
418  if (hasOperandsOutsideOfBlock(&*op))
419  return failure();
420 
421  // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
422  APInt value;
423  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
424  return failure();
425 
426  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
427  unsigned shift = value.getZExtValue();
428 
429  // This case is handled by fold.
430  if (width <= shift || shift == 0)
431  return failure();
432 
433  auto zeros =
434  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
435 
436  // Remove the low bits which would be removed by the Shr.
437  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
438  width - shift);
439 
440  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
441  return success();
442 }
443 
444 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
445  if (hasOperandsOutsideOfBlock(getOperation()))
446  return {};
447 
448  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
449  if (rhs.getValue().getZExtValue() == 0)
450  return getOperand(0);
451  }
452  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
453 }
454 
455 LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
456  if (hasOperandsOutsideOfBlock(&*op))
457  return failure();
458 
459  // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
460  APInt value;
461  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
462  return failure();
463 
464  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
465  unsigned shift = value.getZExtValue();
466 
467  auto topbit =
468  rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
469  auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
470 
471  if (width <= shift) {
472  replaceOpAndCopyName(rewriter, op, {sext});
473  return success();
474  }
475 
476  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
477  width - shift);
478 
479  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
480  return success();
481 }
482 
483 //===----------------------------------------------------------------------===//
484 // Other Operations
485 //===----------------------------------------------------------------------===//
486 
487 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
488  if (hasOperandsOutsideOfBlock(getOperation()))
489  return {};
490 
491  // If we are extracting the entire input, then return it.
492  if (getInput().getType() == getType())
493  return getInput();
494 
495  // Constant fold.
496  if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
497  unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
498  return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
499  getContext());
500  }
501  return {};
502 }
503 
504 // Transforms extract(lo, cat(a, b, c, d, e)) into
505 // cat(extract(lo1, b), c, extract(lo2, d)).
506 // innerCat must be the argument of the provided ExtractOp.
507 static LogicalResult extractConcatToConcatExtract(ExtractOp op,
508  ConcatOp innerCat,
509  PatternRewriter &rewriter) {
510  auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
511  size_t beginOfFirstRelevantElement = 0;
512  auto it = reversedConcatArgs.begin();
513  size_t lowBit = op.getLowBit();
514 
515  // This loop finds the first concatArg that is covered by the ExtractOp
516  for (; it != reversedConcatArgs.end(); it++) {
517  assert(beginOfFirstRelevantElement <= lowBit &&
518  "incorrectly moved past an element that lowBit has coverage over");
519  auto operand = *it;
520 
521  size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
522  if (lowBit < beginOfFirstRelevantElement + operandWidth) {
523  // A bit other than the first bit will be used in this element.
524  // ...... ........ ...
525  // ^---lowBit
526  // ^---beginOfFirstRelevantElement
527  //
528  // Edge-case close to the end of the range.
529  // ...... ........ ...
530  // ^---(position + operandWidth)
531  // ^---lowBit
532  // ^---beginOfFirstRelevantElement
533  //
534  // Edge-case close to the beginning of the rang
535  // ...... ........ ...
536  // ^---lowBit
537  // ^---beginOfFirstRelevantElement
538  //
539  break;
540  }
541 
542  // extraction discards this element.
543  // ...... ........ ...
544  // | ^---lowBit
545  // ^---beginOfFirstRelevantElement
546  beginOfFirstRelevantElement += operandWidth;
547  }
548  assert(it != reversedConcatArgs.end() &&
549  "incorrectly failed to find an element which contains coverage of "
550  "lowBit");
551 
552  SmallVector<Value> reverseConcatArgs;
553  size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
554  size_t extractLo = lowBit - beginOfFirstRelevantElement;
555 
556  // Transform individual arguments of innerCat(..., a, b, c,) into
557  // [ extract(a), b, extract(c) ], skipping an extract operation where
558  // possible.
559  for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
560  auto concatArg = *it;
561  size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
562  size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
563 
564  if (widthToConsume == operandWidth && extractLo == 0) {
565  reverseConcatArgs.push_back(concatArg);
566  } else {
567  auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
568  reverseConcatArgs.push_back(
569  rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
570  }
571 
572  widthRemaining -= widthToConsume;
573 
574  // Beyond the first element, all elements are extracted from position 0.
575  extractLo = 0;
576  }
577 
578  if (reverseConcatArgs.size() == 1) {
579  replaceOpAndCopyName(rewriter, op, reverseConcatArgs[0]);
580  } else {
581  replaceOpWithNewOpAndCopyName<ConcatOp>(
582  rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
583  }
584  return success();
585 }
586 
587 // Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
588 static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
589  PatternRewriter &rewriter) {
590  auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
591  auto replicateEltWidth =
592  replicate.getOperand().getType().getIntOrFloatBitWidth();
593 
594  // If the extract starts at the base of an element and is an even multiple,
595  // we can replace the extract with a smaller replicate.
596  if (op.getLowBit() % replicateEltWidth == 0 &&
597  extractResultWidth % replicateEltWidth == 0) {
598  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
599  replicate.getOperand());
600  return true;
601  }
602 
603  // If the extract is completely contained in one element, extract from the
604  // element.
605  if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
606  replicateEltWidth) {
607  replaceOpWithNewOpAndCopyName<ExtractOp>(
608  rewriter, op, op.getType(), replicate.getOperand(),
609  op.getLowBit() % replicateEltWidth);
610  return true;
611  }
612 
613  // We don't currently handle the case of extracting from non-whole elements,
614  // e.g. `extract (replicate 2-bit-thing, N), 1`.
615  return false;
616 }
617 
618 LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
619  if (hasOperandsOutsideOfBlock(&*op))
620  return failure();
621 
622  auto *inputOp = op.getInput().getDefiningOp();
623 
624  // This turns out to be incredibly expensive. Disable until performance is
625  // addressed.
626 #if 0
627  // If the extracted bits are all known, then return the result.
628  auto knownBits = computeKnownBits(op.getInput())
629  .extractBits(cast<IntegerType>(op.getType()).getWidth(),
630  op.getLowBit());
631  if (knownBits.isConstant()) {
632  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
633  knownBits.getConstant());
634  return success();
635  }
636 #endif
637 
638  // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
639  if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
640  replaceOpWithNewOpAndCopyName<ExtractOp>(
641  rewriter, op, op.getType(), innerExtract.getInput(),
642  innerExtract.getLowBit() + op.getLowBit());
643  return success();
644  }
645 
646  // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
647  if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
648  return extractConcatToConcatExtract(op, innerCat, rewriter);
649 
650  // extract(lo, replicate(a))
651  if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
652  if (extractFromReplicate(op, replicate, rewriter))
653  return success();
654 
655  // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
656  // and/or/xor are not modifying the extracted bits.
657  if (inputOp && inputOp->getNumOperands() == 2 &&
658  isa<AndOp, OrOp, XorOp>(inputOp)) {
659  if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
660  auto extractedCst = cstRHS.getValue().extractBits(
661  cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
662  if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
663  replaceOpWithNewOpAndCopyName<ExtractOp>(
664  rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
665  return success();
666  }
667 
668  // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
669  // extract to represent the result. Turning it into a pile of extracts is
670  // always fine by our cost model, but we don't want to explode things into
671  // a ton of bits because it will bloat the IR and generated Verilog.
672  if (isa<AndOp>(inputOp)) {
673  // For our cost model, we only do this if the bit pattern is a
674  // contiguous series of ones.
675  unsigned lz = extractedCst.countLeadingZeros();
676  unsigned tz = extractedCst.countTrailingZeros();
677  unsigned pop = extractedCst.popcount();
678  if (extractedCst.getBitWidth() - lz - tz == pop) {
679  auto resultTy = rewriter.getIntegerType(pop);
680  SmallVector<Value> resultElts;
681  if (lz)
682  resultElts.push_back(rewriter.create<hw::ConstantOp>(
683  op.getLoc(), APInt::getZero(lz)));
684  resultElts.push_back(rewriter.createOrFold<ExtractOp>(
685  op.getLoc(), resultTy, inputOp->getOperand(0),
686  op.getLowBit() + tz));
687  if (tz)
688  resultElts.push_back(rewriter.create<hw::ConstantOp>(
689  op.getLoc(), APInt::getZero(tz)));
690  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
691  return success();
692  }
693  }
694  }
695  }
696 
697  // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
698  // extracted.
699  if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
700  if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
701  // Don't canonicalize if the shift is multiply used.
702  if (shlOp->hasOneUse())
703  if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
704  if (lhsCst.getValue().isOne()) {
705  auto newCst = rewriter.create<hw::ConstantOp>(
706  shlOp.getLoc(),
707  APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
708  replaceOpWithNewOpAndCopyName<ICmpOp>(
709  rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
710  false);
711  return success();
712  }
713  }
714 
715  return failure();
716 }
717 
718 //===----------------------------------------------------------------------===//
719 // Associative Variadic operations
720 //===----------------------------------------------------------------------===//
721 
722 // Reduce all operands to a single value (either integer constant or parameter
723 // expression) if all the operands are constants.
724 static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
725  hw::PEO paramOpcode) {
726  assert(operands.size() > 1 && "caller should handle one-operand case");
727  // We can only fold anything in the case where all operands are known to be
728  // constants. Check the least common one first for an early out.
729  if (!operands[1] || !operands[0])
730  return {};
731 
732  // This will fold to a simple constant if all operands are constant.
733  if (llvm::all_of(operands.drop_front(2),
734  [&](Attribute in) { return !!in; })) {
735  SmallVector<mlir::TypedAttr> typedOperands;
736  typedOperands.reserve(operands.size());
737  for (auto operand : operands) {
738  if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
739  typedOperands.push_back(typedOperand);
740  else
741  break;
742  }
743  if (typedOperands.size() == operands.size())
744  return hw::ParamExprAttr::get(paramOpcode, typedOperands);
745  }
746 
747  return {};
748 }
749 
750 /// When we find a logical operation (and, or, xor) with a constant e.g.
751 /// `X & 42`, we want to push the constant into the computation of X if it leads
752 /// to simplification.
753 ///
754 /// This function handles the case where the logical operation has a concat
755 /// operand. We check to see if we can simplify the concat, e.g. when it has
756 /// constant operands.
757 ///
758 /// This returns true when a simplification happens.
759 static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
760  size_t concatIdx, const APInt &cst,
761  PatternRewriter &rewriter) {
762  auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
763  assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
764 
765  // Check to see if any operands can be simplified by pushing the logical op
766  // into all parts of the concat.
767  bool canSimplify =
768  llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
769  auto *operandOp = operand.getDefiningOp();
770  if (!operandOp)
771  return false;
772 
773  // If the concat has a constant operand then we can transform this.
774  if (isa<hw::ConstantOp>(operandOp))
775  return true;
776  // If the concat has the same logical operation and that operation has
777  // a constant operation than we can fold it into that suboperation.
778  return operandOp->getName() == logicalOp->getName() &&
779  operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
780  operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
781  });
782 
783  if (!canSimplify)
784  return false;
785 
786  // Create a new instance of the logical operation. We have to do this the
787  // hard way since we're generic across a family of different ops.
788  auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
789  return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
790  rewriter);
791  };
792 
793  // Ok, let's do the transformation. We do this by slicing up the constant
794  // for each unit of the concat and duplicate the operation into the
795  // sub-operand.
796  SmallVector<Value> newConcatOperands;
797  newConcatOperands.reserve(concatOp->getNumOperands());
798 
799  // Work from MSB to LSB.
800  size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
801  for (Value operand : concatOp->getOperands()) {
802  size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
803  nextOperandBit -= operandWidth;
804  // Take a slice of the constant.
805  auto eltCst = rewriter.create<hw::ConstantOp>(
806  logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
807 
808  newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
809  }
810 
811  // Create the concat, and the rest of the logical op if we need it.
812  Value newResult =
813  rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);
814 
815  // If we had a variadic logical op on the top level, then recreate it with the
816  // new concat and without the constant operand.
817  if (logicalOp->getNumOperands() > 2) {
818  auto origOperands = logicalOp->getOperands();
819  SmallVector<Value> operands;
820  // Take any stuff before the concat.
821  operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
822  // Take any stuff after the concat but before the constant.
823  operands.append(origOperands.begin() + concatIdx + 1,
824  origOperands.begin() + (origOperands.size() - 1));
825  // Include the new concat.
826  operands.push_back(newResult);
827  newResult = createLogicalOp(operands);
828  }
829 
830  replaceOpAndCopyName(rewriter, logicalOp, newResult);
831  return true;
832 }
833 
834 // Determines whether the inputs to a logical element are of opposite
835 // comparisons and can lowered into a constant.
836 static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
837  llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
838 
839  for (auto op : operands) {
840  if (auto icmpOp = op.getDefiningOp<ICmpOp>();
841  icmpOp && icmpOp.getTwoState()) {
842  auto predicate = icmpOp.getPredicate();
843  auto lhs = icmpOp.getLhs();
844  auto rhs = icmpOp.getRhs();
845  if (seenPredicates.contains(
846  {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
847  return true;
848 
849  seenPredicates.insert({predicate, lhs, rhs});
850  }
851  }
852  return false;
853 }
854 
855 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
856  if (hasOperandsOutsideOfBlock(getOperation()))
857  return {};
858 
859  APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
860 
861  auto inputs = adaptor.getInputs();
862 
863  // and(x, 01, 10) -> 00 -- annulment.
864  for (auto operand : inputs) {
865  if (!operand)
866  continue;
867  value &= cast<IntegerAttr>(operand).getValue();
868  if (value.isZero())
869  return getIntAttr(value, getContext());
870  }
871 
872  // and(x, -1) -> x.
873  if (inputs.size() == 2 && inputs[1] &&
874  cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
875  return getInputs()[0];
876 
877  // and(x, x, x) -> x. This also handles and(x) -> x.
878  if (llvm::all_of(getInputs(),
879  [&](auto in) { return in == this->getInputs()[0]; }))
880  return getInputs()[0];
881 
882  // and(..., x, ..., ~x, ...) -> 0
883  for (Value arg : getInputs()) {
884  Value subExpr;
885  if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
886  for (Value arg2 : getInputs())
887  if (arg2 == subExpr)
888  return getIntAttr(
889  APInt::getZero(cast<IntegerType>(getType()).getWidth()),
890  getContext());
891  }
892  }
893 
894  // x0 = icmp(pred, x, y)
895  // x1 = icmp(!pred, x, y)
896  // and(x0, x1) -> 0
897  if (canCombineOppositeBinCmpIntoConstant(getInputs()))
898  return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
899  getContext());
900 
901  // Constant fold
902  return constFoldAssociativeOp(inputs, hw::PEO::And);
903 }
904 
905 /// Returns a single common operand that all inputs of the operation `op` can
906 /// be traced back to, or an empty `Value` if no such operand exists.
907 ///
908 /// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
909 /// (assuming the bit-width of `a` is `n`).
910 template <typename Op>
911 static Value getCommonOperand(Op op) {
912  if (!op.getType().isInteger(1))
913  return Value();
914 
915  auto inputs = op.getInputs();
916  size_t size = inputs.size();
917 
918  auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
919  if (!sourceOp)
920  return Value();
921  Value source = sourceOp.getOperand();
922 
923  // Fast path: the input size is not equal to the width of the source.
924  if (size != source.getType().getIntOrFloatBitWidth())
925  return Value();
926 
927  // Tracks the bits that were encountered.
928  llvm::BitVector bits(size);
929  bits.set(sourceOp.getLowBit());
930 
931  for (size_t i = 1; i != size; ++i) {
932  auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
933  if (!extractOp || extractOp.getOperand() != source)
934  return Value();
935  bits.set(extractOp.getLowBit());
936  }
937 
938  return bits.all() ? source : Value();
939 }
940 
941 /// Canonicalize an idempotent operation `op` so that only one input of any kind
942 /// occurs.
943 ///
944 /// Example: `and(x, y, x, z)` -> `and(x, y, z)`
945 template <typename Op>
946 static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
947  // Depth limit to search, in operations. Chosen arbitrarily, keep small.
948  constexpr unsigned limit = 3;
949  auto inputs = op.getInputs();
950 
951  llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
952  llvm::SmallDenseSet<Op, 8> checked;
953  checked.insert(op);
954 
955  struct OpWithDepth {
956  Op op;
957  unsigned depth;
958  };
959  llvm::SmallVector<OpWithDepth, 8> worklist;
960 
961  auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
962  // Add to worklist if within depth limit, is defined in the same block by
963  // the same kind of operation, has same two-state-ness, and not enqueued
964  // previously.
965  if (depth < limit && input.getParentBlock() == op->getBlock()) {
966  auto inputOp = input.template getDefiningOp<Op>();
967  if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
968  checked.insert(inputOp).second)
969  worklist.push_back({inputOp, depth + 1});
970  }
971  };
972 
973  for (auto input : uniqueInputs)
974  enqueue(input, 0);
975 
976  while (!worklist.empty()) {
977  auto item = worklist.pop_back_val();
978 
979  for (auto input : item.op.getInputs()) {
980  uniqueInputs.remove(input);
981  enqueue(input, item.depth);
982  }
983  }
984 
985  if (uniqueInputs.size() < inputs.size()) {
986  replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
987  uniqueInputs.getArrayRef(),
988  op.getTwoState());
989  return true;
990  }
991 
992  return false;
993 }
994 
995 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
996  auto inputs = op.getInputs();
997  auto size = inputs.size();
998 
999  // and(x, and(...)) -> and(x, ...) -- flatten
1000  if (tryFlatteningOperands(op, rewriter))
1001  return success();
1002 
1003  // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
1004  // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
1005  // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
1006  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1007  return success();
1008 
1009  if (hasOperandsOutsideOfBlock(&*op))
1010  return failure();
1011  assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
1012 
1013  // Patterns for and with a constant on RHS.
1014  APInt value;
1015  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1016  // and(..., '1) -> and(...) -- identity
1017  if (value.isAllOnes()) {
1018  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1019  inputs.drop_back(), false);
1020  return success();
1021  }
1022 
1023  // TODO: Combine multiple constants together even if they aren't at the
1024  // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
1025  // folding
1026  APInt value2;
1027  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1028  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
1029  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1030  newOperands.push_back(cst);
1031  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1032  newOperands, false);
1033  return success();
1034  }
1035 
1036  // Handle 'and' with a single bit constant on the RHS.
1037  if (size == 2 && value.isPowerOf2()) {
1038  // If the LHS is a replicate from a single bit, we can 'concat' it
1039  // into place. e.g.:
1040  // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1041  // TODO: Generalize this for non-single-bit operands.
1042  if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1043  auto replicateOperand = replicate.getOperand();
1044  if (replicateOperand.getType().isInteger(1)) {
1045  unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1046  auto trailingZeros = value.countTrailingZeros();
1047 
1048  // Don't add zero bit constants unnecessarily.
1049  SmallVector<Value, 3> concatOperands;
1050  if (trailingZeros != resultWidth - 1) {
1051  auto highZeros = rewriter.create<hw::ConstantOp>(
1052  op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1053  concatOperands.push_back(highZeros);
1054  }
1055  concatOperands.push_back(replicateOperand);
1056  if (trailingZeros != 0) {
1057  auto lowZeros = rewriter.create<hw::ConstantOp>(
1058  op.getLoc(), APInt::getZero(trailingZeros));
1059  concatOperands.push_back(lowZeros);
1060  }
1061  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1062  concatOperands);
1063  return success();
1064  }
1065  }
1066  }
1067 
1068  // If this is an and from an extract op, try shrinking the extract.
1069  if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1070  if (size == 2 &&
1071  // We can shrink it if the mask has leading or trailing zeros.
1072  (value.countLeadingZeros() || value.countTrailingZeros())) {
1073  unsigned lz = value.countLeadingZeros();
1074  unsigned tz = value.countTrailingZeros();
1075 
1076  // Start by extracting the smaller number of bits.
1077  auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1078  Value smallElt = rewriter.createOrFold<ExtractOp>(
1079  extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1080  extractOp.getLowBit() + tz);
1081  // Apply the 'and' mask if needed.
1082  APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1083  if (!smallMask.isAllOnes()) {
1084  auto loc = inputs.back().getLoc();
1085  smallElt = rewriter.createOrFold<AndOp>(
1086  loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
1087  false);
1088  }
1089 
1090  // The final replacement will be a concat of the leading/trailing zeros
1091  // along with the smaller extracted value.
1092  SmallVector<Value> resultElts;
1093  if (lz)
1094  resultElts.push_back(
1095  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1096  resultElts.push_back(smallElt);
1097  if (tz)
1098  resultElts.push_back(
1099  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1100  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1101  return success();
1102  }
1103  }
1104 
1105  // and(concat(x, cst1), a, b, c, cst2)
1106  // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1107  // We do this for even more multi-use concats since they are "just wiring".
1108  for (size_t i = 0; i < size - 1; ++i) {
1109  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1110  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1111  return success();
1112  }
1113  }
1114 
1115  // extracts only of and(...) -> and(extract()...)
1116  if (narrowOperationWidth(op, true, rewriter))
1117  return success();
1118 
1119  // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1120  if (auto source = getCommonOperand(op)) {
1121  auto cmpAgainst =
1122  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1123  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1124  source, cmpAgainst);
1125  return success();
1126  }
1127 
1128  /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1129  return failure();
1130 }
1131 
1132 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1133  if (hasOperandsOutsideOfBlock(getOperation()))
1134  return {};
1135 
1136  auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1137  auto inputs = adaptor.getInputs();
1138  // or(x, 10, 01) -> 11
1139  for (auto operand : inputs) {
1140  if (!operand)
1141  continue;
1142  value |= cast<IntegerAttr>(operand).getValue();
1143  if (value.isAllOnes())
1144  return getIntAttr(value, getContext());
1145  }
1146 
1147  // or(x, 0) -> x
1148  if (inputs.size() == 2 && inputs[1] &&
1149  cast<IntegerAttr>(inputs[1]).getValue().isZero())
1150  return getInputs()[0];
1151 
1152  // or(x, x, x) -> x. This also handles or(x) -> x
1153  if (llvm::all_of(getInputs(),
1154  [&](auto in) { return in == this->getInputs()[0]; }))
1155  return getInputs()[0];
1156 
1157  // or(..., x, ..., ~x, ...) -> -1
1158  for (Value arg : getInputs()) {
1159  Value subExpr;
1160  if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1161  for (Value arg2 : getInputs())
1162  if (arg2 == subExpr)
1163  return getIntAttr(
1164  APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1165  getContext());
1166  }
1167  }
1168 
1169  // x0 = icmp(pred, x, y)
1170  // x1 = icmp(!pred, x, y)
1171  // or(x0, x1) -> 1
1172  if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1173  return getIntAttr(
1174  APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1175  getContext());
1176 
1177  // Constant fold
1178  return constFoldAssociativeOp(inputs, hw::PEO::Or);
1179 }
1180 
1181 /// Simplify concat ops in an or op when a constant operand is present in either
1182 /// concat.
1183 ///
1184 /// This will invert an or(concat, concat) into concat(or, or, ...), which can
1185 /// often be further simplified due to the smaller or ops being easier to fold.
1186 ///
1187 /// For example:
1188 ///
1189 /// or(..., concat(x, 0), concat(0, y))
1190 /// ==> or(..., concat(x, 0, y)), when x and y don't overlap.
1191 ///
1192 /// or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1))
1193 /// ==> or(..., concat(or(x: i2, extract(cst2, 4..3)),
1194 /// or(extract(cst1, 3..1), extract(cst2, 2..0)),
1195 /// or(extract(cst1, 0..0), y: i1))
1196 static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
1197  size_t concatIdx2,
1198  PatternRewriter &rewriter) {
1199  assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2");
1200 
1201  auto inputs = op.getInputs();
1202  auto concat1 = inputs[concatIdx1].getDefiningOp<ConcatOp>();
1203  auto concat2 = inputs[concatIdx2].getDefiningOp<ConcatOp>();
1204 
1205  assert(concat1 && concat2 && "expected indexes to point to ConcatOps");
1206 
1207  // We can simplify as long as a constant is present in either concat.
1208  bool hasConstantOp1 =
1209  llvm::any_of(concat1->getOperands(), [&](Value operand) -> bool {
1210  return operand.getDefiningOp<hw::ConstantOp>();
1211  });
1212  if (!hasConstantOp1) {
1213  bool hasConstantOp2 =
1214  llvm::any_of(concat2->getOperands(), [&](Value operand) -> bool {
1215  return operand.getDefiningOp<hw::ConstantOp>();
1216  });
1217  if (!hasConstantOp2)
1218  return false;
1219  }
1220 
1221  SmallVector<Value> newConcatOperands;
1222 
1223  // Simultaneously iterate over the operands of both concat ops, from MSB to
1224  // LSB, pushing out or's of overlapping ranges of the operands. When operands
1225  // span different bit ranges, we extract only the maximum overlap.
1226  auto operands1 = concat1->getOperands();
1227  auto operands2 = concat2->getOperands();
1228  // Number of bits already consumed from operands 1 and 2, respectively.
1229  unsigned consumedWidth1 = 0;
1230  unsigned consumedWidth2 = 0;
1231  for (auto it1 = operands1.begin(), end1 = operands1.end(),
1232  it2 = operands2.begin(), end2 = operands2.end();
1233  it1 != end1 && it2 != end2;) {
1234  auto operand1 = *it1;
1235  auto operand2 = *it2;
1236 
1237  unsigned remainingWidth1 =
1238  hw::getBitWidth(operand1.getType()) - consumedWidth1;
1239  unsigned remainingWidth2 =
1240  hw::getBitWidth(operand2.getType()) - consumedWidth2;
1241  unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
1242  auto narrowedType = rewriter.getIntegerType(widthToConsume);
1243 
1244  auto extract1 = rewriter.createOrFold<ExtractOp>(
1245  op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
1246  auto extract2 = rewriter.createOrFold<ExtractOp>(
1247  op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);
1248 
1249  newConcatOperands.push_back(
1250  rewriter.createOrFold<OrOp>(op.getLoc(), extract1, extract2, false));
1251 
1252  consumedWidth1 += widthToConsume;
1253  consumedWidth2 += widthToConsume;
1254 
1255  if (widthToConsume == remainingWidth1) {
1256  ++it1;
1257  consumedWidth1 = 0;
1258  }
1259  if (widthToConsume == remainingWidth2) {
1260  ++it2;
1261  consumedWidth2 = 0;
1262  }
1263  }
1264 
1265  ConcatOp newOp = rewriter.create<ConcatOp>(op.getLoc(), newConcatOperands);
1266 
1267  // Copy the old operands except for concatIdx1 and concatIdx2, and append the
1268  // new ConcatOp to the end.
1269  SmallVector<Value> newOrOperands;
1270  newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1);
1271  newOrOperands.append(inputs.begin() + concatIdx1 + 1,
1272  inputs.begin() + concatIdx2);
1273  newOrOperands.append(inputs.begin() + concatIdx2 + 1,
1274  inputs.begin() + inputs.size());
1275  newOrOperands.push_back(newOp);
1276 
1277  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1278  newOrOperands);
1279  return true;
1280 }
1281 
1282 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1283  auto inputs = op.getInputs();
1284  auto size = inputs.size();
1285 
1286  // or(x, or(...)) -> or(x, ...) -- flatten
1287  if (tryFlatteningOperands(op, rewriter))
1288  return success();
1289 
1290  // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1291  // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1292  // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1293  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1294  return success();
1295 
1296  if (hasOperandsOutsideOfBlock(&*op))
1297  return failure();
1298  assert(size > 1 && "expected 2 or more operands");
1299 
1300  // Patterns for and with a constant on RHS.
1301  APInt value;
1302  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1303  // or(..., '0) -> or(...) -- identity
1304  if (value.isZero()) {
1305  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1306  inputs.drop_back());
1307  return success();
1308  }
1309 
1310  // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1311  APInt value2;
1312  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1313  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
1314  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1315  newOperands.push_back(cst);
1316  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1317  newOperands);
1318  return success();
1319  }
1320 
1321  // or(concat(x, cst1), a, b, c, cst2)
1322  // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1323  // We do this for even more multi-use concats since they are "just wiring".
1324  for (size_t i = 0; i < size - 1; ++i) {
1325  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1326  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1327  return success();
1328  }
1329  }
1330 
1331  // or(..., concat(x, cst1), concat(cst2, y)
1332  // ==> or(..., concat(x, cst3, y)), when x and y don't overlap.
1333  for (size_t i = 0; i < size - 1; ++i) {
1334  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1335  for (size_t j = i + 1; j < size; ++j)
1336  if (auto concat = inputs[j].getDefiningOp<ConcatOp>())
1337  if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter))
1338  return success();
1339  }
1340 
1341  // extracts only of or(...) -> or(extract()...)
1342  if (narrowOperationWidth(op, true, rewriter))
1343  return success();
1344 
1345  // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1346  if (auto source = getCommonOperand(op)) {
1347  auto cmpAgainst =
1348  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1349  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1350  source, cmpAgainst);
1351  return success();
1352  }
1353 
1354  // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1355  // .., c_n), a, 0)
1356  if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1357  APInt value;
1358  if (op.getTwoState() && firstMux.getTwoState() &&
1359  matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1360  value.isZero()) {
1361  SmallVector<Value> conditions{firstMux.getCond()};
1362  auto check = [&](Value v) {
1363  auto mux = v.getDefiningOp<comb::MuxOp>();
1364  if (!mux)
1365  return false;
1366  conditions.push_back(mux.getCond());
1367  return mux.getTwoState() &&
1368  firstMux.getTrueValue() == mux.getTrueValue() &&
1369  firstMux.getFalseValue() == mux.getFalseValue();
1370  };
1371  if (llvm::all_of(op.getOperands().drop_front(), check)) {
1372  auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
1373  replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1374  rewriter, op, cond, firstMux.getTrueValue(),
1375  firstMux.getFalseValue(), true);
1376  return success();
1377  }
1378  }
1379  }
1380 
1381  /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1382  return failure();
1383 }
1384 
1385 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1386  if (hasOperandsOutsideOfBlock(getOperation()))
1387  return {};
1388 
1389  auto size = getInputs().size();
1390  auto inputs = adaptor.getInputs();
1391 
1392  // xor(x) -> x -- noop
1393  if (size == 1)
1394  return getInputs()[0];
1395 
1396  // xor(x, x) -> 0 -- idempotent
1397  if (size == 2 && getInputs()[0] == getInputs()[1])
1398  return IntegerAttr::get(getType(), 0);
1399 
1400  // xor(x, 0) -> x
1401  if (inputs.size() == 2 && inputs[1] &&
1402  cast<IntegerAttr>(inputs[1]).getValue().isZero())
1403  return getInputs()[0];
1404 
1405  // xor(xor(x,1),1) -> x
1406  // but not self loop
1407  if (isBinaryNot()) {
1408  Value subExpr;
1409  if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1410  subExpr != getResult())
1411  return subExpr;
1412  }
1413 
1414  // Constant fold
1415  return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1416 }
1417 
1418 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1419 static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1420  PatternRewriter &rewriter) {
1421  auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1422  auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1423 
1424  Value result =
1425  rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1426  icmp.getOperand(1), icmp.getTwoState());
1427 
1428  // If the xor had other operands, rebuild it.
1429  if (op.getNumOperands() > 2) {
1430  SmallVector<Value, 4> newOperands(op.getOperands());
1431  newOperands.pop_back();
1432  newOperands.erase(newOperands.begin() + icmpOperand);
1433  newOperands.push_back(result);
1434  result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
1435  }
1436 
1437  replaceOpAndCopyName(rewriter, op, result);
1438 }
1439 
1440 LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1441  if (hasOperandsOutsideOfBlock(&*op))
1442  return failure();
1443 
1444  auto inputs = op.getInputs();
1445  auto size = inputs.size();
1446  assert(size > 1 && "expected 2 or more operands");
1447 
1448  // xor(..., x, x) -> xor (...) -- idempotent
1449  if (inputs[size - 1] == inputs[size - 2]) {
1450  assert(size > 2 &&
1451  "expected idempotent case for 2 elements handled already.");
1452  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1453  inputs.drop_back(/*n=*/2), false);
1454  return success();
1455  }
1456 
1457  // Patterns for xor with a constant on RHS.
1458  APInt value;
1459  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1460  // xor(..., 0) -> xor(...) -- identity
1461  if (value.isZero()) {
1462  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1463  inputs.drop_back(), false);
1464  return success();
1465  }
1466 
1467  // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1468  APInt value2;
1469  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1470  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
1471  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1472  newOperands.push_back(cst);
1473  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1474  newOperands, false);
1475  return success();
1476  }
1477 
1478  bool isSingleBit = value.getBitWidth() == 1;
1479 
1480  // Check for subexpressions that we can simplify.
1481  for (size_t i = 0; i < size - 1; ++i) {
1482  Value operand = inputs[i];
1483 
1484  // xor(concat(x, cst1), a, b, c, cst2)
1485  // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1486  // We do this for even more multi-use concats since they are "just
1487  // wiring".
1488  if (auto concat = operand.getDefiningOp<ConcatOp>())
1489  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1490  return success();
1491 
1492  // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1493  if (isSingleBit && operand.hasOneUse()) {
1494  assert(value == 1 && "single bit constant has to be one if not zero");
1495  if (auto icmp = operand.getDefiningOp<ICmpOp>())
1496  return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1497  }
1498  }
1499  }
1500 
1501  // xor(x, xor(...)) -> xor(x, ...) -- flatten
1502  if (tryFlatteningOperands(op, rewriter))
1503  return success();
1504 
1505  // extracts only of xor(...) -> xor(extract()...)
1506  if (narrowOperationWidth(op, true, rewriter))
1507  return success();
1508 
1509  // xor(a[0], a[1], ..., a[n]) -> parity(a)
1510  if (auto source = getCommonOperand(op)) {
1511  replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1512  return success();
1513  }
1514 
1515  return failure();
1516 }
1517 
1518 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1519  if (hasOperandsOutsideOfBlock(getOperation()))
1520  return {};
1521 
1522  // sub(x - x) -> 0
1523  if (getRhs() == getLhs())
1524  return getIntAttr(
1525  APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1526  getContext());
1527 
1528  if (adaptor.getRhs()) {
1529  // If both are constants, we can unconditionally fold.
1530  if (adaptor.getLhs()) {
1531  // Constant fold (c1 - c2) => (c1 + -1*c2).
1532  auto negOne = getIntAttr(
1533  APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1534  getContext());
1535  auto rhsNeg = hw::ParamExprAttr::get(
1536  hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1537  return hw::ParamExprAttr::get(hw::PEO::Add,
1538  cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1539  }
1540 
1541  // sub(x - 0) -> x
1542  if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1543  if (rhsC.getValue().isZero())
1544  return getLhs();
1545  }
1546  }
1547 
1548  return {};
1549 }
1550 
1551 LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1552  if (hasOperandsOutsideOfBlock(&*op))
1553  return failure();
1554 
1555  // sub(x, cst) -> add(x, -cst)
1556  APInt value;
1557  if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1558  auto negCst = rewriter.create<hw::ConstantOp>(op.getLoc(), -value);
1559  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1560  false);
1561  return success();
1562  }
1563 
1564  // extracts only of sub(...) -> sub(extract()...)
1565  if (narrowOperationWidth(op, false, rewriter))
1566  return success();
1567 
1568  return failure();
1569 }
1570 
1571 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1572  if (hasOperandsOutsideOfBlock(getOperation()))
1573  return {};
1574 
1575  auto size = getInputs().size();
1576 
1577  // add(x) -> x -- noop
1578  if (size == 1u)
1579  return getInputs()[0];
1580 
1581  // Constant fold constant operands.
1582  return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1583 }
1584 
1585 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1586  if (hasOperandsOutsideOfBlock(&*op))
1587  return failure();
1588 
1589  auto inputs = op.getInputs();
1590  auto size = inputs.size();
1591  assert(size > 1 && "expected 2 or more operands");
1592 
1593  APInt value, value2;
1594 
1595  // add(..., 0) -> add(...) -- identity
1596  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1597  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1598  inputs.drop_back(), false);
1599  return success();
1600  }
1601 
1602  // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1603  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1604  matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1605  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
1606  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1607  newOperands.push_back(cst);
1608  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1609  newOperands, false);
1610  return success();
1611  }
1612 
1613  // add(..., x, x) -> add(..., shl(x, 1))
1614  if (inputs[size - 1] == inputs[size - 2]) {
1615  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1616 
1617  auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1618  auto shiftLeftOp =
1619  rewriter.create<comb::ShlOp>(op.getLoc(), inputs.back(), one, false);
1620 
1621  newOperands.push_back(shiftLeftOp);
1622  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1623  newOperands, false);
1624  return success();
1625  }
1626 
1627  auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1628  // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1629  if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1630  matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1631 
1632  APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1633  auto rhs =
1634  rewriter.create<hw::ConstantOp>(op.getLoc(), (one << value) + one);
1635 
1636  std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1637  auto mulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
1638 
1639  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1640  newOperands.push_back(mulOp);
1641  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1642  newOperands, false);
1643  return success();
1644  }
1645 
1646  auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1647  // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1648  if (mulOp && mulOp.getInputs().size() == 2 &&
1649  mulOp.getInputs()[0] == inputs[size - 2] &&
1650  matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1651 
1652  APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1653  auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + one);
1654  std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1655  auto newMulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
1656 
1657  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1658  newOperands.push_back(newMulOp);
1659  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1660  newOperands, false);
1661  return success();
1662  }
1663 
1664  // add(a, add(...)) -> add(a, ...) -- flatten
1665  if (tryFlatteningOperands(op, rewriter))
1666  return success();
1667 
1668  // extracts only of add(...) -> add(extract()...)
1669  if (narrowOperationWidth(op, false, rewriter))
1670  return success();
1671 
1672  // add(add(x, c1), c2) -> add(x, c1 + c2)
1673  auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1674  if (addOp && addOp.getInputs().size() == 2 &&
1675  matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1676  inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1677 
1678  auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
1679  replaceOpWithNewOpAndCopyName<AddOp>(
1680  rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1681  /*twoState=*/op.getTwoState() && addOp.getTwoState());
1682  return success();
1683  }
1684 
1685  return failure();
1686 }
1687 
1688 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1689  if (hasOperandsOutsideOfBlock(getOperation()))
1690  return {};
1691 
1692  auto size = getInputs().size();
1693  auto inputs = adaptor.getInputs();
1694 
1695  // mul(x) -> x -- noop
1696  if (size == 1u)
1697  return getInputs()[0];
1698 
1699  auto width = cast<IntegerType>(getType()).getWidth();
1700  APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1701 
1702  // mul(x, 0, 1) -> 0 -- annulment
1703  for (auto operand : inputs) {
1704  if (!operand)
1705  continue;
1706  value *= cast<IntegerAttr>(operand).getValue();
1707  if (value.isZero())
1708  return getIntAttr(value, getContext());
1709  }
1710 
1711  // Constant fold
1712  return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1713 }
1714 
1715 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1716  if (hasOperandsOutsideOfBlock(&*op))
1717  return failure();
1718 
1719  auto inputs = op.getInputs();
1720  auto size = inputs.size();
1721  assert(size > 1 && "expected 2 or more operands");
1722 
1723  APInt value, value2;
1724 
1725  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1726  if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1727  value.isPowerOf2()) {
1728  auto shift = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(),
1729  value.exactLogBase2());
1730  auto shlOp =
1731  rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift, false);
1732 
1733  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1734  ArrayRef<Value>(shlOp), false);
1735  return success();
1736  }
1737 
1738  // mul(..., 1) -> mul(...) -- identity
1739  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1740  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1741  inputs.drop_back());
1742  return success();
1743  }
1744 
1745  // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1746  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1747  matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1748  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value * value2);
1749  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1750  newOperands.push_back(cst);
1751  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1752  newOperands);
1753  return success();
1754  }
1755 
1756  // mul(a, mul(...)) -> mul(a, ...) -- flatten
1757  if (tryFlatteningOperands(op, rewriter))
1758  return success();
1759 
1760  // extracts only of mul(...) -> mul(extract()...)
1761  if (narrowOperationWidth(op, false, rewriter))
1762  return success();
1763 
1764  return failure();
1765 }
1766 
1767 template <class Op, bool isSigned>
1768 static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1769  if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1770  // divu(x, 1) -> x, divs(x, 1) -> x
1771  if (rhsValue.getValue() == 1)
1772  return op.getLhs();
1773 
1774  // If the divisor is zero, do not fold for now.
1775  if (rhsValue.getValue().isZero())
1776  return {};
1777  }
1778 
1779  return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1780 }
1781 
1782 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1783  if (hasOperandsOutsideOfBlock(getOperation()))
1784  return {};
1785 
1786  return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1787 }
1788 
1789 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1790  if (hasOperandsOutsideOfBlock(getOperation()))
1791  return {};
1792 
1793  return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1794 }
1795 
1796 template <class Op, bool isSigned>
1797 static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1798  if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1799  // modu(x, 1) -> 0, mods(x, 1) -> 0
1800  if (rhsValue.getValue() == 1)
1801  return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1802  op.getContext());
1803 
1804  // If the divisor is zero, do not fold for now.
1805  if (rhsValue.getValue().isZero())
1806  return {};
1807  }
1808 
1809  if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1810  // modu(0, x) -> 0, mods(0, x) -> 0
1811  if (lhsValue.getValue().isZero())
1812  return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1813  op.getContext());
1814  }
1815 
1816  return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1817 }
1818 
1819 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1820  if (hasOperandsOutsideOfBlock(getOperation()))
1821  return {};
1822 
1823  return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1824 }
1825 
1826 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1827  if (hasOperandsOutsideOfBlock(getOperation()))
1828  return {};
1829 
1830  return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1831 }
1832 //===----------------------------------------------------------------------===//
1833 // ConcatOp
1834 //===----------------------------------------------------------------------===//
1835 
1836 // Constant folding
1837 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1838  if (hasOperandsOutsideOfBlock(getOperation()))
1839  return {};
1840 
1841  if (getNumOperands() == 1)
1842  return getOperand(0);
1843 
1844  // If all the operands are constant, we can fold.
1845  for (auto attr : adaptor.getInputs())
1846  if (!attr || !isa<IntegerAttr>(attr))
1847  return {};
1848 
1849  // If we got here, we can constant fold.
1850  unsigned resultWidth = getType().getIntOrFloatBitWidth();
1851  APInt result(resultWidth, 0);
1852 
1853  unsigned nextInsertion = resultWidth;
1854  // Insert each chunk into the result.
1855  for (auto attr : adaptor.getInputs()) {
1856  auto chunk = cast<IntegerAttr>(attr).getValue();
1857  nextInsertion -= chunk.getBitWidth();
1858  result.insertBits(chunk, nextInsertion);
1859  }
1860 
1861  return getIntAttr(result, getContext());
1862 }
1863 
/// Canonicalize a concat by looking at each operand and its left neighbor.
///
/// The scan flattens nested concats, merges adjacent constants, turns
/// repeated operands into replicates, and merges adjacent extracts or array
/// slices that read contiguous pieces of the same input. The first rewrite
/// that applies is performed and the function returns immediately. If no
/// rewrite applied and every operand is the same value, the concat becomes a
/// replicate.
1864 LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1865  if (hasOperandsOutsideOfBlock(&*op))
1866  return failure();
1867 
1868  auto inputs = op.getInputs();
1869  auto size = inputs.size();
1870  assert(size > 1 && "expected 2 or more operands");
1871 
1872  // This function is used when we flatten neighboring operands of a
1873  // (variadic) concat into a new version of the concat. first/last indices
1874  // are inclusive. If only one operand remains, the concat is replaced by
  // that operand directly instead of by a new concat.
1875  auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1876  ValueRange replacements) -> LogicalResult {
1877  SmallVector<Value, 4> newOperands;
1878  newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1879  newOperands.append(replacements.begin(), replacements.end());
1880  newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1881  if (newOperands.size() == 1)
1882  replaceOpAndCopyName(rewriter, op, newOperands[0]);
1883  else
1884  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1885  newOperands);
1886  return success();
1887  };
1888 
  // Stays non-null only while every operand seen so far equals inputs[0].
1889  Value commonOperand = inputs[0];
1890  for (size_t i = 0; i != size; ++i) {
1891  // Check to see if all operands are the same.
1892  if (inputs[i] != commonOperand)
1893  commonOperand = Value();
1894 
1895  // If an operand to the concat is itself a concat, then we can fold them
1896  // together.
1897  if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1898  return flattenConcat(i, i, subConcat->getOperands());
1899 
1900  // Check for canonicalization due to neighboring operands.
1901  if (i != 0) {
1902  // Merge neighboring constants. The previous (more significant) constant
  // is shifted above the current one.
1903  if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1904  if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1905  unsigned prevWidth = prevCst.getValue().getBitWidth();
1906  unsigned thisWidth = cst.getValue().getBitWidth();
1907  auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1908  resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1909  << thisWidth;
1910  Value replacement =
1911  rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
1912  return flattenConcat(i - 1, i, replacement);
1913  }
1914  }
1915 
1916  // If the two operands are the same, turn them into a replicate.
1917  if (inputs[i] == inputs[i - 1]) {
1918  Value replacement =
1919  rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1920  return flattenConcat(i - 1, i, replacement);
1921  }
1922 
1923  // If this input is a replicate, see if we can fold it with the previous
1924  // one.
1925  if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1926  // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1927  if (repl.getOperand() == inputs[i - 1]) {
1928  Value replacement = rewriter.createOrFold<ReplicateOp>(
1929  op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1930  return flattenConcat(i - 1, i, replacement);
1931  }
1932  // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1933  if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1934  if (prevRepl.getOperand() == repl.getOperand()) {
1935  Value replacement = rewriter.createOrFold<ReplicateOp>(
1936  op.getLoc(), repl.getOperand(),
1937  repl.getMultiple() + prevRepl.getMultiple());
1938  return flattenConcat(i - 1, i, replacement);
1939  }
1940  }
1941  }
1942 
1943  // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1944  if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1945  if (repl.getOperand() == inputs[i]) {
1946  Value replacement = rewriter.createOrFold<ReplicateOp>(
1947  op.getLoc(), inputs[i], repl.getMultiple() + 1);
1948  return flattenConcat(i - 1, i, replacement);
1949  }
1950  }
1951 
1952  // Merge neighboring extracts of neighboring inputs, e.g.
1953  // {A[3], A[2]} -> A[3:2]
1954  if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1955  if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1956  if (extract.getInput() == prevExtract.getInput()) {
1957  auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
  // The previous extract must start exactly where this one ends.
1958  if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1959  auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1960  auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1961  Value replacement = rewriter.create<ExtractOp>(
1962  op.getLoc(), resType, extract.getInput(),
1963  extract.getLowBit());
1964  return flattenConcat(i - 1, i, replacement);
1965  }
1966  }
1967  }
1968  }
1969  // Merge neighboring array extracts of neighboring inputs, e.g.
1970  // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
1971 
1972  // This represents a slice of an array: `width` elements of `input`
  // starting at `index`. A single array_get is a slice of width 1.
1973  struct ArraySlice {
1974  Value input;
1975  Value index;
1976  size_t width;
1977  static std::optional<ArraySlice> get(Value value) {
1978  assert(isa<IntegerType>(value.getType()) && "expected integer type");
1979  if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1980  return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1981  // array slice op is wrapped with bitcast.
1982  if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1983  if (auto arraySlice =
1984  bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1985  return ArraySlice{
1986  arraySlice.getInput(), arraySlice.getLowIndex(),
1987  hw::type_cast<hw::ArrayType>(arraySlice.getType())
1988  .getNumElements()};
1989  return std::nullopt;
1990  }
1991  };
1992  if (auto extractOpt = ArraySlice::get(inputs[i])) {
1993  if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1994  // Check that two array slices are mergable.
1995  if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1996  prevExtractOpt->input == extractOpt->input &&
1997  hw::isOffset(extractOpt->index, prevExtractOpt->index,
1998  extractOpt->width)) {
1999  auto resType = hw::ArrayType::get(
2000  hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
2001  .getElementType(),
2002  extractOpt->width + prevExtractOpt->width);
2003  auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
2004  Value replacement = rewriter.create<hw::BitcastOp>(
2005  op.getLoc(), resIntType,
2006  rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
2007  prevExtractOpt->input,
2008  extractOpt->index));
2009  return flattenConcat(i - 1, i, replacement);
2010  }
2011  }
2012  }
2013  }
2014  }
2015 
2016  // If all operands were the same, then this is a replicate.
2017  if (commonOperand) {
2018  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2019  commonOperand);
2020  return success();
2021  }
2022 
2023  return failure();
2024 }
2025 
2026 //===----------------------------------------------------------------------===//
2027 // MuxOp
2028 //===----------------------------------------------------------------------===//
2029 
2030 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
2031  if (hasOperandsOutsideOfBlock(getOperation()))
2032  return {};
2033 
2034  // mux (c, b, b) -> b
2035  if (getTrueValue() == getFalseValue())
2036  return getTrueValue();
2037  if (auto tv = adaptor.getTrueValue())
2038  if (tv == adaptor.getFalseValue())
2039  return tv;
2040 
2041  // mux(0, a, b) -> b
2042  // mux(1, a, b) -> a
2043  if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
2044  if (pred.getValue().isZero())
2045  return getFalseValue();
2046  return getTrueValue();
2047  }
2048 
2049  // mux(cond, 1, 0) -> cond
2050  if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
2051  if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
2052  if (tv.getValue().isOne() && fv.getValue().isZero() &&
2053  hw::getBitWidth(getType()) == 1)
2054  return getCond();
2055 
2056  return {};
2057 }
2058 
2059 /// Check to see if the condition to the specified mux is an equality
2060 /// comparison `indexValue` and one or more constants. If so, put the
2061 /// constants in the constants vector and return true, otherwise return false.
2062 ///
2063 /// This is part of foldMuxChain.
2064 ///
2065 static bool
2066 getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
2067  std::function<void(hw::ConstantOp)> constantFn) {
2068  // Handle `idx == 42` and `idx != 42`.
2069  if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
2070  // TODO: We could handle things like "x < 2" as two entries.
2071  auto requiredPredicate =
2072  (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
2073  if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
2074  if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
2075  constantFn(cst);
2076  return true;
2077  }
2078  }
2079  return false;
2080  }
2081 
2082  // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
2083  if (auto orOp = cond.getDefiningOp<OrOp>()) {
2084  if (!isInverted)
2085  return false;
2086  for (auto operand : orOp.getOperands())
2087  if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2088  return false;
2089  return true;
2090  }
2091 
2092  // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
2093  if (auto andOp = cond.getDefiningOp<AndOp>()) {
2094  if (isInverted)
2095  return false;
2096  for (auto operand : andOp.getOperands())
2097  if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2098  return false;
2099  return true;
2100  }
2101 
2102  return false;
2103 }
2104 
2105 /// Given a mux, check to see if the "on true" value (or "on false" value if
2106 /// isFalseSide=true) is a mux tree with the same condition. This allows us
2107 /// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
2108 /// `array_get (array_create(A, B, C), VAL)` which is far more compact and
2109 /// allows synthesis tools to do more interesting optimizations.
2110 ///
2111 /// This returns false if we cannot form the mux tree (or do not want to) and
2112 /// returns true if the mux was replaced.
2113 static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
2114  PatternRewriter &rewriter) {
2115  // Get the index value being compared. Later we check to see if it is
2116  // compared to a constant with the right predicate.
2117  auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2118  if (!rootCmp)
2119  return false;
2120  Value indexValue = rootCmp.getLhs();
2121 
2122  // Return the value to use if the equality match succeeds. This is mux
  // operand 1 when isFalseSide is set and operand 2 otherwise.
2123  auto getCaseValue = [&](MuxOp mux) -> Value {
2124  return mux.getOperand(1 + unsigned(!isFalseSide));
2125  };
2126 
2127  // Return the value to use if the equality match fails. This is the next
2128  // mux in the sequence or the "otherwise" value.
2129  auto getTreeValue = [&](MuxOp mux) -> Value {
2130  return mux.getOperand(1 + unsigned(isFalseSide));
2131  };
2132 
2133  // Start scanning the mux tree to see what we've got. Keep track of the
2134  // constant comparison value and the SSA value to use when equal to it.
2135  SmallVector<Location> locationsFound;
2136  SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2137 
2138  /// Extract constants and values into `valuesFound` and return true if this is
2139  /// part of the mux tree, otherwise return false.
2140  auto collectConstantValues = [&](MuxOp mux) -> bool {
2141  return getMuxChainCondConstant(
2142  mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2143  valuesFound.push_back({cst, getCaseValue(mux)});
2144  locationsFound.push_back(mux.getCond().getLoc());
2145  locationsFound.push_back(mux->getLoc());
2146  });
2147  };
2148 
2149  // Make sure the root is a correct comparison with a constant.
2150  if (!collectConstantValues(rootMux))
2151  return false;
2152 
2153  // Make sure that we're not looking at the intermediate node in a mux tree.
  // If the sole user is a mux that continues the same chain through this
  // mux, bail out here; the callback is a no-op because only the match
  // result matters.
2154  if (rootMux->hasOneUse()) {
2155  if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2156  if (getTreeValue(userMux) == rootMux.getResult() &&
2157  getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2158  [&](hw::ConstantOp cst) {}))
2159  return false;
2160  }
2161  }
2162 
2163  // Scan up the tree linearly. Stop at the first value that is not a
  // single-use mux with a matching condition; that value becomes the
  // "otherwise" entry.
2164  auto nextTreeValue = getTreeValue(rootMux);
2165  while (1) {
2166  auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2167  if (!nextMux || !nextMux->hasOneUse())
2168  break;
2169  if (!collectConstantValues(nextMux))
2170  break;
2171  nextTreeValue = getTreeValue(nextMux);
2172  }
2173 
2174  // We need at least three values to create an array. This is an
2175  // arbitrary threshold which is saying that one or two muxes together is ok,
2176  // but three should be folded.
2177  if (valuesFound.size() < 3)
2178  return false;
2179 
2180  // If the index is 9 bits or wider, the table would need 512 or more
2181  // elements and it will be too large for a single expression.
2182  auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2183  if (indexWidth >= 9)
2184  return false;
2185 
2186  // Next we need to see if the values are dense-ish. We don't want to have
2187  // a tremendous number of replicated entries in the array. Some sparsity is
2188  // ok though, so we require the table to be at least 5/8 utilized.
2189  uint64_t tableSize = 1ULL << indexWidth;
2190  if (valuesFound.size() < (tableSize * 5) / 8)
2191  return false; // Not dense enough.
2192 
2193  // Ok, we're going to do the transformation, start by building the table
2194  // filled with the "otherwise" value.
2195  SmallVector<Value, 8> table(tableSize, nextTreeValue);
2196 
2197  // Fill in entries in the table from the leaf to the root of the expression.
2198  // This ensures that any duplicate matches end up with the ultimate value,
2199  // which is the one closer to the root.
2200  for (auto &elt : llvm::reverse(valuesFound)) {
2201  uint64_t idx = elt.first.getValue().getZExtValue();
2202  assert(idx < table.size() && "constant should be same bitwidth as index");
2203  table[idx] = elt.second;
2204  }
2205 
2206  // The hw.array_create operation has the operand list in unintuitive order
2207  // with a[0] stored as the last element, not the first.
2208  std::reverse(table.begin(), table.end());
2209 
2210  // Build the array_create and the array_get.
2211  auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2212  auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
2213  replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2214  indexValue);
2215  return true;
2216 }
2217 
2218 /// Given a fully associative variadic operation like (a+b+c+d), break the
2219 /// expression into two parts, one without the specified operand (e.g.
2220 /// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2221 /// `tmp+c`), and return the inner expression.
2222 ///
2223 /// NOTE: This mutates the operation in place if it only has a single user,
2224 /// which assumes that user will be removed.
2225 ///
2226 static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2227  size_t operandNo,
2228  PatternRewriter &rewriter) {
2229  assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2230  assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2231 
2232  // If this expression already has two operands (the common case) no splitting
2233  // is necessary.
2234  if (fullyAssoc->getNumOperands() == 2)
2235  return fullyAssoc->getOperand(operandNo ^ 1);
2236 
2237  // If the operation has a single use, mutate it in place.
2238  if (fullyAssoc->hasOneUse()) {
2239  rewriter.modifyOpInPlace(fullyAssoc,
2240  [&]() { fullyAssoc->eraseOperand(operandNo); });
2241  return fullyAssoc->getResult(0);
2242  }
2243 
2244  // Form the new operation with the operands that remain.
2245  SmallVector<Value> operands;
2246  operands.append(fullyAssoc->getOperands().begin(),
2247  fullyAssoc->getOperands().begin() + operandNo);
2248  operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2249  fullyAssoc->getOperands().end());
2250  Value opWithoutExcluded = createGenericOp(
2251  fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2252  Value excluded = fullyAssoc->getOperand(operandNo);
2253 
2254  Value fullResult =
2255  createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2256  ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2257  replaceOpAndCopyName(rewriter, fullyAssoc, fullResult);
2258  return opWithoutExcluded;
2259 }
2260 
2261 /// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2262 /// `mux(cond, a, x|y|z|a)` -> `(x|y|z)&replicate(~cond) | a` (the latter when
2263 /// isTrueOperand is true). Return true on successful transformation, false if not.
2264 ///
2265 /// These are various forms of "predicated ops" that can be handled with a
2266 /// replicate/and combination.
2267 static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2268  PatternRewriter &rewriter) {
2269  // Check to see the operand in question is an operation. If it is a port,
2270  // we can't simplify it. The sub-expression is the mux arm opposite to the
  // one that holds the common value.
2271  Operation *subExpr =
2272  (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2273  if (!subExpr || subExpr->getNumOperands() < 2)
2274  return false;
2275 
2276  // If this isn't an operation we can handle, don't spend energy on it.
2277  if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2278  return false;
2279 
2280  // Check to see if the common value occurs in the operand list for the
2281  // subexpression op. If so, then we can simplify it.
2282  Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2283  size_t opNo = 0, e = subExpr->getNumOperands();
2284  while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2285  ++opNo;
2286  if (opNo == e)
2287  return false;
2288 
2289  // If we got a hit, then go ahead and simplify it!
2290  Value cond = op.getCond();
2291 
2292  // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2293  // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2294  // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2295  // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2296  if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2297  Value otherValue;
2298  Value subCond = subMux.getCond();
2299 
2300  // Invert the subCond if needed and dig out the 'b' value.
2301  if (subMux.getTrueValue() == commonValue)
2302  otherValue = subMux.getFalseValue();
2303  else if (subMux.getFalseValue() == commonValue) {
2304  otherValue = subMux.getTrueValue();
2305  subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2306  } else {
2307  // We can't fold `mux(cond, a, mux(a, x, y))`.
2308  return false;
2309  }
2310 
2311  // Invert the outer cond if needed, and combine the mux conditions.
2312  if (!isTrueOperand)
2313  cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2314  cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2315  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2316  otherValue, op.getTwoState());
2317  return true;
2318  }
2319 
2320  // Invert the condition if needed. Or/Xor invert when dealing with
2321  // TrueOperand, And inverts for False operand.
2322  bool isaAndOp = isa<AndOp>(subExpr);
2323  if (isTrueOperand ^ isaAndOp)
2324  cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2325 
  // Widen the condition to the mux result type so it can be used as a mask.
2326  auto extendedCond =
2327  rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2328 
2329  // Cache this information before subExpr is erased by extraction below.
2330  bool isaXorOp = isa<XorOp>(subExpr);
2331  bool isaOrOp = isa<OrOp>(subExpr);
2332 
2333  // Handle the fully associative ops, start by pulling out the subexpression
2334  // from a many operand version of the op.
2335  auto restOfAssoc =
2336  extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2337 
2338  // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2339  // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2340  if (isaOrOp || isaXorOp) {
2341  auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2342  restOfAssoc, false);
2343  if (isaXorOp)
2344  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2345  false);
2346  else
2347  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2348  false);
2349  return true;
2350  }
2351 
2352  // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2353  assert(isaAndOp && "unexpected operation here");
2354  auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2355  restOfAssoc, false);
2356  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2357  false);
2358  return true;
2359 }
2360 
2361 /// This function is invoke when we find a mux with true/false operations that
2362 /// have the same opcode. Check to see if we can strength reduce the mux by
2363 /// applying it to less data by applying this transformation:
2364 /// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2365 static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2366  Operation *falseOp,
2367  PatternRewriter &rewriter) {
2368  // Right now we only apply to concat.
2369  // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2370  if (!isa<ConcatOp>(trueOp))
2371  return false;
2372 
2373  // Decode the operands, looking through recursive concats and replicates.
2374  SmallVector<Value> trueOperands, falseOperands;
2375  getConcatOperands(trueOp->getResult(0), trueOperands);
2376  getConcatOperands(falseOp->getResult(0), falseOperands);
2377 
2378  size_t numTrueOperands = trueOperands.size();
2379  size_t numFalseOperands = falseOperands.size();
2380 
2381  if (!numTrueOperands || !numFalseOperands ||
2382  (trueOperands.front() != falseOperands.front() &&
2383  trueOperands.back() != falseOperands.back()))
2384  return false;
2385 
2386  // Pull all leading shared operands out into their own op if any are common.
2387  if (trueOperands.front() == falseOperands.front()) {
2388  SmallVector<Value> operands;
2389  size_t i;
2390  for (i = 0; i < numTrueOperands; ++i) {
2391  Value trueOperand = trueOperands[i];
2392  if (trueOperand == falseOperands[i])
2393  operands.push_back(trueOperand);
2394  else
2395  break;
2396  }
2397  if (i == numTrueOperands) {
2398  // Selecting between distinct, but lexically identical, concats.
2399  replaceOpAndCopyName(rewriter, mux, trueOp->getResult(0));
2400  return true;
2401  }
2402 
2403  Value sharedMSB;
2404  if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2405  sharedMSB = rewriter.createOrFold<ReplicateOp>(
2406  mux->getLoc(), operands.front(), operands.size());
2407  else
2408  sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2409  operands.clear();
2410 
2411  // Get a concat of the LSB's on each side.
2412  operands.append(trueOperands.begin() + i, trueOperands.end());
2413  Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2414  operands.clear();
2415  operands.append(falseOperands.begin() + i, falseOperands.end());
2416  Value falseLSB =
2417  rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2418  // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2419  // done.
2420  Value lsb = rewriter.createOrFold<MuxOp>(
2421  mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2422  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2423  return true;
2424  }
2425 
2426  // If trailing operands match, try to commonize them.
2427  if (trueOperands.back() == falseOperands.back()) {
2428  SmallVector<Value> operands;
2429  size_t i;
2430  for (i = 0;; ++i) {
2431  Value trueOperand = trueOperands[numTrueOperands - i - 1];
2432  if (trueOperand == falseOperands[numFalseOperands - i - 1])
2433  operands.push_back(trueOperand);
2434  else
2435  break;
2436  }
2437  std::reverse(operands.begin(), operands.end());
2438  Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2439  operands.clear();
2440 
2441  // Get a concat of the MSB's on each side.
2442  operands.append(trueOperands.begin(), trueOperands.end() - i);
2443  Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2444  operands.clear();
2445  operands.append(falseOperands.begin(), falseOperands.end() - i);
2446  Value falseMSB =
2447  rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2448  // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2449  Value msb = rewriter.createOrFold<MuxOp>(
2450  mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2451  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2452  return true;
2453  }
2454 
2455  return false;
2456 }
2457 
2458 // If both arguments of the mux are arrays with the same elements, sink the
2459 // mux and return a uniform array initializing all elements to it.
2460 static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2461  auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2462  auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2463  if (!trueVec || !falseVec)
2464  return false;
2465  if (!trueVec.isUniform() || !falseVec.isUniform())
2466  return false;
2467 
2468  auto mux = rewriter.create<MuxOp>(
2469  op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2470  falseVec.getUniformElement(), op.getTwoState());
2471 
2472  SmallVector<Value> values(trueVec.getInputs().size(), mux);
2473  rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2474  return true;
2475 }
2476 
2477 namespace {
2478 struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2479  using OpRewritePattern::OpRewritePattern;
2480 
2481  LogicalResult matchAndRewrite(MuxOp op,
2482  PatternRewriter &rewriter) const override;
2483 };
2484 
2485 LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2486  PatternRewriter &rewriter) const {
2487  if (hasOperandsOutsideOfBlock(&*op))
2488  return failure();
2489 
2490  // If the op has a SV attribute, don't optimize it.
2491  if (hasSVAttributes(op))
2492  return failure();
2493  APInt value;
2494 
2495  if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2496  if (value.getBitWidth() == 1) {
2497  // mux(a, 0, b) -> and(~a, b) for single-bit values.
2498  if (value.isZero()) {
2499  auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2500  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2501  op.getFalseValue(), false);
2502  return success();
2503  }
2504 
2505  // mux(a, 1, b) -> or(a, b) for single-bit values.
2506  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2507  op.getFalseValue(), false);
2508  return success();
2509  }
2510 
2511  // Check for mux of two constants. There are many ways to simplify them.
2512  APInt value2;
2513  if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2514  // When both inputs are constants and differ by only one bit, we can
2515  // simplify by splitting the mux into up to three contiguous chunks: one
2516  // for the differing bit and up to two for the bits that are the same.
2517  // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2518  APInt xorValue = value ^ value2;
2519  if (xorValue.isPowerOf2()) {
2520  unsigned leadingZeros = xorValue.countLeadingZeros();
2521  unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2522  SmallVector<Value, 3> operands;
2523 
2524  // Concat operands go from MSB to LSB, so we handle chunks in reverse
2525  // order of bit indexes.
2526  // For the chunks that are identical (i.e. correspond to 0s in
2527  // xorValue), we can extract directly from either input value, and we
2528  // arbitrarily pick the trueValue().
2529 
2530  if (leadingZeros > 0)
2531  operands.push_back(rewriter.createOrFold<ExtractOp>(
2532  op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2533 
2534  // Handle the differing bit, which should simplify into either cond or
2535  // ~cond.
2536  auto v1 = rewriter.createOrFold<ExtractOp>(
2537  op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2538  auto v2 = rewriter.createOrFold<ExtractOp>(
2539  op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2540  operands.push_back(rewriter.createOrFold<MuxOp>(
2541  op.getLoc(), op.getCond(), v1, v2, false));
2542 
2543  if (trailingZeros > 0)
2544  operands.push_back(rewriter.createOrFold<ExtractOp>(
2545  op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2546 
2547  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2548  operands);
2549  return success();
2550  }
2551 
2552  // If the true value is all ones and the false is all zeros then we have a
2553  // replicate pattern.
2554  if (value.isAllOnes() && value2.isZero()) {
2555  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2556  op.getCond());
2557  return success();
2558  }
2559  }
2560  }
2561 
2562  if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2563  value.getBitWidth() == 1) {
2564  // mux(a, b, 0) -> and(a, b) for single-bit values.
2565  if (value.isZero()) {
2566  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2567  op.getTrueValue(), false);
2568  return success();
2569  }
2570 
2571  // mux(a, b, 1) -> or(~a, b) for single-bit values.
2572  // falseValue() is known to be a single-bit 1, which we can use for
2573  // the 1 in the representation of ~ using xor.
2574  auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2575  op.getFalseValue(), false);
2576  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2577  op.getTrueValue(), false);
2578  return success();
2579  }
2580 
2581  // mux(!a, b, c) -> mux(a, c, b)
2582  Value subExpr;
2583  Operation *condOp = op.getCond().getDefiningOp();
2584  if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2585  op.getTwoState()) {
2586  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2587  op.getFalseValue(), op.getTrueValue(),
2588  true);
2589  return success();
2590  }
2591 
2592  // Same but with Demorgan's law.
2593  // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2594  // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2595  if (condOp && condOp->hasOneUse()) {
2596  SmallVector<Value> invertedOperands;
2597 
2598  /// Scan all the operands to see if they are complemented. If so, build a
2599  /// vector of them and return true, otherwise return false.
2600  auto getInvertedOperands = [&]() -> bool {
2601  for (Value operand : condOp->getOperands()) {
2602  if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2603  invertedOperands.push_back(subExpr);
2604  else
2605  return false;
2606  }
2607  return true;
2608  };
2609 
2610  if (isa<AndOp>(condOp) && getInvertedOperands()) {
2611  auto newOr =
2612  rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2613  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2614  op.getFalseValue(),
2615  op.getTrueValue(), op.getTwoState());
2616  return success();
2617  }
2618  if (isa<OrOp>(condOp) && getInvertedOperands()) {
2619  auto newAnd =
2620  rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2621  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2622  op.getFalseValue(),
2623  op.getTrueValue(), op.getTwoState());
2624  return success();
2625  }
2626  }
2627 
2628  if (auto falseMux =
2629  dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2630  // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2631  if (op.getCond() == falseMux.getCond()) {
2632  replaceOpWithNewOpAndCopyName<MuxOp>(
2633  rewriter, op, op.getCond(), op.getTrueValue(),
2634  falseMux.getFalseValue(), op.getTwoStateAttr());
2635  return success();
2636  }
2637 
2638  // Check to see if we can fold a mux tree into an array_create/get pair.
2639  if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2640  return success();
2641  }
2642 
2643  if (auto trueMux =
2644  dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2645  // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2646  if (op.getCond() == trueMux.getCond()) {
2647  replaceOpWithNewOpAndCopyName<MuxOp>(
2648  rewriter, op, op.getCond(), trueMux.getTrueValue(),
2649  op.getFalseValue(), op.getTwoStateAttr());
2650  return success();
2651  }
2652 
2653  // Check to see if we can fold a mux tree into an array_create/get pair.
2654  if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2655  return success();
2656  }
2657 
2658  // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2659  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2660  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2661  trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2662  trueMux.getTrueValue() == falseMux.getTrueValue()) {
2663  auto subMux = rewriter.create<MuxOp>(
2664  rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2665  op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2666  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2667  trueMux.getTrueValue(), subMux,
2668  op.getTwoStateAttr());
2669  return success();
2670  }
2671 
2672  // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2673  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2674  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2675  trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2676  trueMux.getFalseValue() == falseMux.getFalseValue()) {
2677  auto subMux = rewriter.create<MuxOp>(
2678  rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2679  op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2680  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2681  subMux, trueMux.getFalseValue(),
2682  op.getTwoStateAttr());
2683  return success();
2684  }
2685 
2686  // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2687  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2688  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2689  trueMux && falseMux &&
2690  trueMux.getTrueValue() == falseMux.getTrueValue() &&
2691  trueMux.getFalseValue() == falseMux.getFalseValue()) {
2692  auto subMux = rewriter.create<MuxOp>(
2693  rewriter.getFusedLoc(
2694  {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2695  op.getCond(), trueMux.getCond(), falseMux.getCond());
2696  replaceOpWithNewOpAndCopyName<MuxOp>(
2697  rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2698  op.getTwoStateAttr());
2699  return success();
2700  }
2701 
2702  // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2703  if (foldCommonMuxValue(op, false, rewriter))
2704  return success();
2705  // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2706  if (foldCommonMuxValue(op, true, rewriter))
2707  return success();
2708 
2709  // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2710  if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2711  if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2712  if (trueOp->getName() == falseOp->getName())
2713  if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2714  return success();
2715 
2716  // extracts only of mux(...) -> mux(extract()...)
2717  if (narrowOperationWidth(op, true, rewriter))
2718  return success();
2719 
2720  // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2721  if (foldMuxOfUniformArrays(op, rewriter))
2722  return success();
2723 
2724  return failure();
2725 }
2726 
2727 static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2728  // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2729  if (op.getInputs().empty() || op.isUniform())
2730  return false;
2731  auto inputs = op.getInputs();
2732  if (inputs.size() <= 1)
2733  return false;
2734 
2735  // Check the operands to the array create. Ensure all of them are the
2736  // same op with the same number of operands.
2737  auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2738  if (!first || hasSVAttributes(first))
2739  return false;
2740 
2741  // Check whether all operands are muxes with the same condition.
2742  for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2743  auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2744  if (!input || first.getCond() != input.getCond())
2745  return false;
2746  }
2747 
2748  // Collect the true and the false branches into arrays.
2749  SmallVector<Value> trues{first.getTrueValue()};
2750  SmallVector<Value> falses{first.getFalseValue()};
2751  SmallVector<Location> locs{first->getLoc()};
2752  bool isTwoState = true;
2753  for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2754  auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2755  trues.push_back(input.getTrueValue());
2756  falses.push_back(input.getFalseValue());
2757  locs.push_back(input->getLoc());
2758  if (!input.getTwoState())
2759  isTwoState = false;
2760  }
2761 
2762  // Define the location of the array create as the aggregate of all muxes.
2763  auto loc = FusedLoc::get(op.getContext(), locs);
2764 
2765  // Replace the create with an aggregate operation. Push the create op
2766  // into the operands of the aggregate operation.
2767  auto arrayTy = op.getType();
2768  auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
2769  auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
2770  rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2771  trueValues, falseValues, isTwoState);
2772  return true;
2773 }
2774 
2775 struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2776  using OpRewritePattern::OpRewritePattern;
2777 
2778  LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2779  PatternRewriter &rewriter) const override {
2780  if (hasOperandsOutsideOfBlock(&*op))
2781  return failure();
2782 
2783  if (foldArrayOfMuxes(op, rewriter))
2784  return success();
2785  return failure();
2786  }
2787 };
2788 
2789 } // namespace
2790 
2791 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2792  MLIRContext *context) {
2793  results.insert<MuxRewriter, ArrayRewriter>(context);
2794 }
2795 
2796 //===----------------------------------------------------------------------===//
2797 // ICmpOp
2798 //===----------------------------------------------------------------------===//
2799 
2800 // Calculate the result of a comparison when the LHS and RHS are both
2801 // constants.
2802 static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2803  const APInt &rhs) {
2804  switch (predicate) {
2805  case ICmpPredicate::eq:
2806  return lhs.eq(rhs);
2807  case ICmpPredicate::ne:
2808  return lhs.ne(rhs);
2809  case ICmpPredicate::slt:
2810  return lhs.slt(rhs);
2811  case ICmpPredicate::sle:
2812  return lhs.sle(rhs);
2813  case ICmpPredicate::sgt:
2814  return lhs.sgt(rhs);
2815  case ICmpPredicate::sge:
2816  return lhs.sge(rhs);
2817  case ICmpPredicate::ult:
2818  return lhs.ult(rhs);
2819  case ICmpPredicate::ule:
2820  return lhs.ule(rhs);
2821  case ICmpPredicate::ugt:
2822  return lhs.ugt(rhs);
2823  case ICmpPredicate::uge:
2824  return lhs.uge(rhs);
2825  case ICmpPredicate::ceq:
2826  return lhs.eq(rhs);
2827  case ICmpPredicate::cne:
2828  return lhs.ne(rhs);
2829  case ICmpPredicate::weq:
2830  return lhs.eq(rhs);
2831  case ICmpPredicate::wne:
2832  return lhs.ne(rhs);
2833  }
2834  llvm_unreachable("unknown comparison predicate");
2835 }
2836 
2837 // Returns the result of applying the predicate when the LHS and RHS are the
2838 // exact same value.
2839 static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2840  switch (predicate) {
2841  case ICmpPredicate::eq:
2842  case ICmpPredicate::sle:
2843  case ICmpPredicate::sge:
2844  case ICmpPredicate::ule:
2845  case ICmpPredicate::uge:
2846  case ICmpPredicate::ceq:
2847  case ICmpPredicate::weq:
2848  return true;
2849  case ICmpPredicate::ne:
2850  case ICmpPredicate::slt:
2851  case ICmpPredicate::sgt:
2852  case ICmpPredicate::ult:
2853  case ICmpPredicate::ugt:
2854  case ICmpPredicate::cne:
2855  case ICmpPredicate::wne:
2856  return false;
2857  }
2858  llvm_unreachable("unknown comparison predicate");
2859 }
2860 
2861 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2862  if (hasOperandsOutsideOfBlock(getOperation()))
2863  return {};
2864 
2865  // gt a, a -> false
2866  // gte a, a -> true
2867  if (getLhs() == getRhs()) {
2868  auto val = applyCmpPredicateToEqualOperands(getPredicate());
2869  return IntegerAttr::get(getType(), val);
2870  }
2871 
2872  // gt 1, 2 -> false
2873  if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2874  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2875  auto val =
2876  applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2877  return IntegerAttr::get(getType(), val);
2878  }
2879  }
2880  return {};
2881 }
2882 
// Given two ranges, returns the number of leading elements that compare equal
// position by position. Counting stops at the first mismatch or when either
// range is exhausted. This does not perform cross-element matching. (To count
// a common suffix, pass reversed ranges.)
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto itA = a.begin();
  auto itB = b.begin();
  const auto endA = a.end();
  const auto endB = b.end();

  size_t matched = 0;
  while (itA != endA && itB != endB && !(*itA != *itB)) {
    ++itA;
    ++itB;
    ++matched;
  }
  return matched;
}
2899 
2900 static size_t getTotalWidth(ArrayRef<Value> operands) {
2901  size_t totalWidth = 0;
2902  for (auto operand : operands) {
2903  // getIntOrFloatBitWidth should never raise, since all arguments to
2904  // ConcatOp are integers.
2905  ssize_t width = operand.getType().getIntOrFloatBitWidth();
2906  assert(width >= 0);
2907  totalWidth += width;
2908  }
2909  return totalWidth;
2910 }
2911 
2912 /// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
2913 /// comparison on common prefix and suffixes. Returns success() if a rewriting
2914 /// happens. This handles both concat and replicate.
2915 static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2916  Operation *rhs,
2917  PatternRewriter &rewriter) {
2918  // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2919  // all elements have non-zero length. Both these invariants are verified
2920  // by the ConcatOp verifier.
2921  SmallVector<Value> lhsOperands, rhsOperands;
2922  getConcatOperands(lhs->getResult(0), lhsOperands);
2923  getConcatOperands(rhs->getResult(0), rhsOperands);
2924  ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2925 
2926  auto formCatOrReplicate = [&](Location loc,
2927  ArrayRef<Value> operands) -> Value {
2928  assert(!operands.empty());
2929  Value sameElement = operands[0];
2930  for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2931  if (sameElement != operands[i])
2932  sameElement = Value();
2933  if (sameElement)
2934  return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2935  operands.size());
2936  return rewriter.createOrFold<ConcatOp>(loc, operands);
2937  };
2938 
2939  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2940  Value rhs) -> LogicalResult {
2941  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2942  op.getTwoState());
2943  return success();
2944  };
2945 
2946  size_t commonPrefixLength =
2947  computeCommonPrefixLength(lhsOperands, rhsOperands);
2948  if (commonPrefixLength == lhsOperands.size()) {
2949  // cat(a, b, c) == cat(a, b, c) -> 1
2950  bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2951  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2952  APInt(1, result));
2953  return success();
2954  }
2955 
2956  size_t commonSuffixLength = computeCommonPrefixLength(
2957  llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2958 
2959  size_t commonPrefixTotalWidth =
2960  getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2961  size_t commonSuffixTotalWidth =
2962  getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2963  auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2964  .drop_back(commonSuffixLength);
2965  auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2966  .drop_back(commonSuffixLength);
2967 
2968  auto replaceWithoutReplicatingSignBit = [&]() {
2969  auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2970  auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2971  return replaceWith(op.getPredicate(), newLhs, newRhs);
2972  };
2973 
2974  auto replaceWithReplicatingSignBit = [&]() {
2975  auto firstNonEmptyValue = lhsOperands[0];
2976  auto firstNonEmptyElemWidth =
2977  firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2978  Value signBit = rewriter.createOrFold<ExtractOp>(
2979  op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2980 
2981  auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2982  auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2983  return replaceWith(op.getPredicate(), newLhs, newRhs);
2984  };
2985 
2986  if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2987  // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2988  if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2989  return replaceWithoutReplicatingSignBit();
2990 
2991  // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2992  // cat(sgn(b), ..y)) Note that we cannot perform this optimization if
2993  // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
2994  // bit. Doing the rewrite can result in an infinite loop.
2995  if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2996  return replaceWithReplicatingSignBit();
2997 
2998  } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2999  // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
3000  return replaceWithoutReplicatingSignBit();
3001  }
3002 
3003  return failure();
3004 }
3005 
3006 /// Given an equality comparison with a constant value and some operand that has
3007 /// known bits, simplify the comparison to check only the unknown bits of the
3008 /// input.
3009 ///
3010 /// One simple example of this is that `concat(0, stuff) == 0` can be simplified
3011 /// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
3012 /// `extract x[1:0] == 0`
3014  ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
3015  PatternRewriter &rewriter) {
3016 
3017  // If any of the known bits disagree with any of the comparison bits, then
3018  // we can constant fold this comparison right away.
3019  APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
3020  if ((bitsKnown & rhsCst) != bitAnalysis.One) {
3021  // If we discover a mismatch then we know an "eq" comparison is false
3022  // and a "ne" comparison is true!
3023  bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
3024  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3025  APInt(1, result));
3026  return;
3027  }
3028 
3029  // Check to see if we can prove the result entirely of the comparison (in
3030  // which we bail out early), otherwise build a list of values to concat and a
3031  // smaller constant to compare against.
3032  SmallVector<Value> newConcatOperands;
3033  auto newConstant = APInt::getZeroWidth();
3034 
3035  // Ok, some (maybe all) bits are known and some others may be unknown.
3036  // Extract out segments of the operand and compare against the
3037  // corresponding bits.
3038  unsigned knownMSB = bitsKnown.countLeadingOnes();
3039 
3040  Value operand = cmpOp.getLhs();
3041 
3042  // Ok, some bits are known but others are not. Extract out sequences of
3043  // bits that are unknown and compare just those bits. We work from MSB to
3044  // LSB.
3045  while (knownMSB != bitsKnown.getBitWidth()) {
3046  // Drop any high bits that are known.
3047  if (knownMSB)
3048  bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3049 
3050  // Find the span of unknown bits, and extract it.
3051  unsigned unknownBits = bitsKnown.countLeadingZeros();
3052  unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3053  auto spanOperand = rewriter.createOrFold<ExtractOp>(
3054  operand.getLoc(), operand, /*lowBit=*/lowBit,
3055  /*bitWidth=*/unknownBits);
3056  auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3057 
3058  // Add this info to the concat we're generating.
3059  newConcatOperands.push_back(spanOperand);
3060  // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
3061  // newConstant = newConstant.concat(spanConstant);
3062  if (newConstant.getBitWidth() != 0)
3063  newConstant = newConstant.concat(spanConstant);
3064  else
3065  newConstant = spanConstant;
3066 
3067  // Drop the unknown bits in prep for the next chunk.
3068  unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3069  bitsKnown = bitsKnown.trunc(newWidth);
3070  knownMSB = bitsKnown.countLeadingOnes();
3071  }
3072 
3073  // If all the operands to the concat are foldable then we have an identity
3074  // situation where all the sub-elements equal each other. This implies that
3075  // the overall result is foldable.
3076  if (newConcatOperands.empty()) {
3077  bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3078  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3079  APInt(1, result));
3080  return;
3081  }
3082 
3083  // If we have a single operand remaining, use it, otherwise form a concat.
3084  Value concatResult =
3085  rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
3086 
3087  // Form the comparison against the smaller constant.
3088  auto newConstantOp = rewriter.create<hw::ConstantOp>(
3089  cmpOp.getOperand(1).getLoc(), newConstant);
3090 
3091  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3092  concatResult, newConstantOp,
3093  cmpOp.getTwoState());
3094 }
3095 
3096 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
3097 static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
3098  const APInt &rhs,
3099  PatternRewriter &rewriter) {
3100  auto ip = rewriter.saveInsertionPoint();
3101  rewriter.setInsertionPoint(xorOp);
3102 
3103  auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
3104  auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
3105  xorRHS.getValue() ^ rhs);
3106  Value newLHS;
3107  switch (xorOp.getNumOperands()) {
3108  case 1:
3109  // This isn't common but is defined so we need to handle it.
3110  newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
3111  APInt::getZero(rhs.getBitWidth()));
3112  break;
3113  case 2:
3114  // The binary case is the most common.
3115  newLHS = xorOp.getOperand(0);
3116  break;
3117  default:
3118  // The general case forces us to form a new xor with the remaining operands.
3119  SmallVector<Value> newOperands(xorOp.getOperands());
3120  newOperands.pop_back();
3121  newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
3122  break;
3123  }
3124 
3125  bool xorMultipleUses = !xorOp->hasOneUse();
3126 
3127  // If the xor has multiple uses (not just the compare, then we need/want to
3128  // replace them as well.
3129  if (xorMultipleUses)
3130  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3131  false);
3132 
3133  // Replace the comparison.
3134  rewriter.restoreInsertionPoint(ip);
3135  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3136  newLHS, newRHS, false);
3137 }
3138 
3139 LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3140  if (hasOperandsOutsideOfBlock(&*op))
3141  return failure();
3142 
3143  APInt lhs, rhs;
3144 
3145  // icmp 1, x -> icmp x, 1
3146  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3147  assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3148  "Should be folded");
3149  replaceOpWithNewOpAndCopyName<ICmpOp>(
3150  rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3151  op.getRhs(), op.getLhs(), op.getTwoState());
3152  return success();
3153  }
3154 
3155  // Canonicalize with RHS constant
3156  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3157  auto getConstant = [&](APInt constant) -> Value {
3158  return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(constant));
3159  };
3160 
3161  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3162  Value rhs) -> LogicalResult {
3163  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3164  op.getTwoState());
3165  return success();
3166  };
3167 
3168  auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3169  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3170  APInt(1, constant));
3171  return success();
3172  };
3173 
3174  switch (op.getPredicate()) {
3175  case ICmpPredicate::slt:
3176  // x < max -> x != max
3177  if (rhs.isMaxSignedValue())
3178  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3179  // x < min -> false
3180  if (rhs.isMinSignedValue())
3181  return replaceWithConstantI1(0);
3182  // x < min+1 -> x == min
3183  if ((rhs - 1).isMinSignedValue())
3184  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3185  getConstant(rhs - 1));
3186  break;
3187  case ICmpPredicate::sgt:
3188  // x > min -> x != min
3189  if (rhs.isMinSignedValue())
3190  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3191  // x > max -> false
3192  if (rhs.isMaxSignedValue())
3193  return replaceWithConstantI1(0);
3194  // x > max-1 -> x == max
3195  if ((rhs + 1).isMaxSignedValue())
3196  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3197  getConstant(rhs + 1));
3198  break;
3199  case ICmpPredicate::ult:
3200  // x < max -> x != max
3201  if (rhs.isAllOnes())
3202  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3203  // x < min -> false
3204  if (rhs.isZero())
3205  return replaceWithConstantI1(0);
3206  // x < min+1 -> x == min
3207  if ((rhs - 1).isZero())
3208  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3209  getConstant(rhs - 1));
3210 
3211  // x < 0xE0 -> extract(x, 5..7) != 0b111
3212  if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3213  rhs.getBitWidth()) {
3214  auto numOnes = rhs.countLeadingOnes();
3215  auto smaller = rewriter.create<ExtractOp>(
3216  op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3217  return replaceWith(ICmpPredicate::ne, smaller,
3218  getConstant(APInt::getAllOnes(numOnes)));
3219  }
3220 
3221  break;
3222  case ICmpPredicate::ugt:
3223  // x > min -> x != min
3224  if (rhs.isZero())
3225  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3226  // x > max -> false
3227  if (rhs.isAllOnes())
3228  return replaceWithConstantI1(0);
3229  // x > max-1 -> x == max
3230  if ((rhs + 1).isAllOnes())
3231  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3232  getConstant(rhs + 1));
3233 
3234  // x > 0x07 -> extract(x, 3..7) != 0b00000
3235  if ((rhs + 1).isPowerOf2()) {
3236  auto numOnes = rhs.countTrailingOnes();
3237  auto newWidth = rhs.getBitWidth() - numOnes;
3238  auto smaller = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
3239  numOnes, newWidth);
3240  return replaceWith(ICmpPredicate::ne, smaller,
3241  getConstant(APInt::getZero(newWidth)));
3242  }
3243 
3244  break;
3245  case ICmpPredicate::sle:
3246  // x <= max -> true
3247  if (rhs.isMaxSignedValue())
3248  return replaceWithConstantI1(1);
3249  // x <= c -> x < (c+1)
3250  return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3251  case ICmpPredicate::sge:
3252  // x >= min -> true
3253  if (rhs.isMinSignedValue())
3254  return replaceWithConstantI1(1);
3255  // x >= c -> x > (c-1)
3256  return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3257  case ICmpPredicate::ule:
3258  // x <= max -> true
3259  if (rhs.isAllOnes())
3260  return replaceWithConstantI1(1);
3261  // x <= c -> x < (c+1)
3262  return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3263  case ICmpPredicate::uge:
3264  // x >= min -> true
3265  if (rhs.isZero())
3266  return replaceWithConstantI1(1);
3267  // x >= c -> x > (c-1)
3268  return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3269  case ICmpPredicate::eq:
3270  if (rhs.getBitWidth() == 1) {
3271  if (rhs.isZero()) {
3272  // x == 0 -> x ^ 1
3273  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3274  getConstant(APInt(1, 1)),
3275  op.getTwoState());
3276  return success();
3277  }
3278  if (rhs.isAllOnes()) {
3279  // x == 1 -> x
3280  replaceOpAndCopyName(rewriter, op, op.getLhs());
3281  return success();
3282  }
3283  }
3284  break;
3285  case ICmpPredicate::ne:
3286  if (rhs.getBitWidth() == 1) {
3287  if (rhs.isZero()) {
3288  // x != 0 -> x
3289  replaceOpAndCopyName(rewriter, op, op.getLhs());
3290  return success();
3291  }
3292  if (rhs.isAllOnes()) {
3293  // x != 1 -> x ^ 1
3294  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3295  getConstant(APInt(1, 1)),
3296  op.getTwoState());
3297  return success();
3298  }
3299  }
3300  break;
3301  case ICmpPredicate::ceq:
3302  case ICmpPredicate::cne:
3303  case ICmpPredicate::weq:
3304  case ICmpPredicate::wne:
3305  break;
3306  }
3307 
3308  // We have some specific optimizations for comparison with a constant that
3309  // are only supported for equality comparisons.
3310  if (op.getPredicate() == ICmpPredicate::eq ||
3311  op.getPredicate() == ICmpPredicate::ne) {
3312  // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3313  // with a smaller constant. We only support equality comparisons for
3314  // this.
3315  auto knownBits = computeKnownBits(op.getLhs());
3316  if (!knownBits.isUnknown())
3317  return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3318  rewriter),
3319  success();
3320 
3321  // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3322  // cst1^cst2).
3323  if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3324  if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3325  return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3326  success();
3327 
3328  // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3329  // all one.
3330  if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3331  if (rhs.isAllOnes() || rhs.isZero()) {
3332  auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3333  auto cst = rewriter.create<hw::ConstantOp>(
3334  op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3335  : APInt::getZero(width));
3336  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3337  replicateOp.getInput(), cst,
3338  op.getTwoState());
3339  return success();
3340  }
3341  }
3342  }
3343 
3344  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3345  // b), cat(c, d)). contains special handling for sign bit in signed
3346  // compressions.
3347  if (Operation *opLHS = op.getLhs().getDefiningOp())
3348  if (Operation *opRHS = op.getRhs().getDefiningOp())
3349  if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3350  isa<ConcatOp, ReplicateOp>(opRHS)) {
3351  if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3352  return success();
3353  }
3354 
3355  return failure();
3356 }
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
Definition: CombFolds.cpp:2460
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Definition: CombFolds.cpp:724
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Definition: CombFolds.cpp:346
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
Definition: CombFolds.cpp:2839
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
Definition: CombFolds.cpp:759
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
Definition: CombFolds.cpp:32
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
Definition: CombFolds.cpp:253
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
Definition: CombFolds.cpp:1768
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
Definition: CombFolds.cpp:911
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
Definition: CombFolds.cpp:836
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition: CombFolds.cpp:58
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
Definition: CombFolds.cpp:88
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
Definition: CombFolds.cpp:2226
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
Definition: CombFolds.cpp:2066
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition: CombFolds.cpp:52
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition: CombFolds.cpp:126
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
Definition: CombFolds.cpp:222
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
Definition: CombFolds.cpp:1419
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
Definition: CombFolds.cpp:588
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
Definition: CombFolds.cpp:1196
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
Definition: CombFolds.cpp:3097
static size_t getTotalWidth(ArrayRef< Value > operands)
Definition: CombFolds.cpp:2900
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
Definition: CombFolds.cpp:2365
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
Definition: CombFolds.cpp:144
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
Definition: CombFolds.cpp:946
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
Definition: CombFolds.cpp:2802
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
Definition: CombFolds.cpp:3013
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
Definition: CombFolds.cpp:2113
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition: CombFolds.cpp:120
static bool hasSVAttributes(Operation *op)
Definition: CombFolds.cpp:103
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
Definition: CombFolds.cpp:507
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
Definition: CombFolds.cpp:1797
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
Definition: CombFolds.cpp:2886
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
Definition: CombFolds.cpp:2267
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
Definition: CombFolds.cpp:2915
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition: CombFolds.cpp:44
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition: CombFolds.cpp:73
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
int32_t width
Definition: FIRRTL.cpp:36
def create(low_bit, result_type, input=None)
Definition: comb.py:187
def create(elements)
Definition: hw.py:443
def create(data_type, value)
Definition: hw.py:393
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
Definition: CombOps.cpp:48
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1832
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: comb.py:1