CIRCT  19.0.0git
CombFolds.cpp
Go to the documentation of this file.
1 //===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
18 
19 using namespace mlir;
20 using namespace circt;
21 using namespace comb;
22 using namespace matchers;
23 
24 /// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
25 /// such, cross-block dataflow is interpreted as a canonicalization barrier.
26 /// This is a conservative approach which:
27 /// 1. still allows for efficient canonicalization for the common CIRCT usecase
28 /// of comb (comb logic nested inside single-block hw.module's)
29 /// 2. allows comb operations to be used in non-HW container ops - that may use
30 /// MLIR blocks and regions to represent various forms of hierarchical
31 /// abstractions, thus allowing comb to compose with other dialects.
32 static bool hasOperandsOutsideOfBlock(Operation *op) {
33  Block *thisBlock = op->getBlock();
34  return llvm::any_of(op->getOperands(), [&](Value operand) {
35  return operand.getParentBlock() != thisBlock;
36  });
37 }
38 
39 /// Create a new instance of a generic operation that only has value operands,
40 /// and has a single result value whose type matches the first operand.
41 ///
42 /// This should not be used to create instances of ops with attributes or with
43 /// more complicated type signatures.
44 static Value createGenericOp(Location loc, OperationName name,
45  ArrayRef<Value> operands, OpBuilder &builder) {
46  OperationState state(loc, name);
47  state.addOperands(operands);
48  state.addTypes(operands[0].getType());
49  return builder.create(state)->getResult(0);
50 }
51 
52 static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
53  return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
54  value);
55 }
56 
57 /// Flatten concat and mux operands into a vector.
58 static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
59  if (auto concat = v.getDefiningOp<ConcatOp>()) {
60  for (auto op : concat.getOperands())
61  getConcatOperands(op, result);
62  } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
63  for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
64  getConcatOperands(repl.getOperand(), result);
65  } else {
66  result.push_back(v);
67  }
68 }
69 
70 /// A wrapper of `PatternRewriter::replaceOp` to propagate "sv.namehint"
71 /// attribute. If a replaced op has a "sv.namehint" attribute, this function
72 /// propagates the name to the new value.
73 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
74  Value newValue) {
75  if (auto *newOp = newValue.getDefiningOp()) {
76  auto name = op->getAttrOfType<StringAttr>("sv.namehint");
77  if (name && !newOp->hasAttr("sv.namehint"))
78  rewriter.modifyOpInPlace(newOp,
79  [&] { newOp->setAttr("sv.namehint", name); });
80  }
81  rewriter.replaceOp(op, newValue);
82 }
83 
84 /// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate
85 /// "sv.namehint" attribute. If a replaced op has a "sv.namehint" attribute,
86 /// this function propagates the name to the new value.
87 template <typename OpTy, typename... Args>
88 static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
89  Operation *op, Args &&...args) {
90  auto name = op->getAttrOfType<StringAttr>("sv.namehint");
91  auto newOp =
92  rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93  if (name && !newOp->hasAttr("sv.namehint"))
94  rewriter.modifyOpInPlace(newOp,
95  [&] { newOp->setAttr("sv.namehint", name); });
96 
97  return newOp;
98 }
99 
100 // Return true if the op has SV attributes. Note that we cannot use a helper
101 // function `hasSVAttributes` defined under SV dialect because of a cyclic
102 // dependency.
103 static bool hasSVAttributes(Operation *op) {
104  return op->hasAttr("sv.attributes");
105 }
106 
107 namespace {
108 template <typename SubType>
109 struct ComplementMatcher {
110  SubType lhs;
111  ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112  bool match(Operation *op) {
113  auto xorOp = dyn_cast<XorOp>(op);
114  return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
115  }
116 };
117 } // end anonymous namespace
118 
119 template <typename SubType>
120 static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
121  return ComplementMatcher<SubType>(subExpr);
122 }
123 
124 /// Return true if the op will be flattened afterwards. Op will be flattend if
125 /// it has a single user which has a same op type.
126 static bool shouldBeFlattened(Operation *op) {
127  assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128  "must be commutative operations"));
129  if (op->hasOneUse()) {
130  auto *user = *op->getUsers().begin();
131  return user->getName() == op->getName() &&
132  op->getAttrOfType<UnitAttr>("twoState") ==
133  user->getAttrOfType<UnitAttr>("twoState");
134  }
135  return false;
136 }
137 
138 /// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
139 /// as an Op. Returns true if successful, and false otherwise.
140 ///
141 /// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
142 ///
143 static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
144  // Skip if the operation should be flattened by another operation.
145  if (shouldBeFlattened(op))
146  return false;
147 
148  auto inputs = op->getOperands();
149 
150  SmallVector<Value, 4> newOperands;
151  SmallVector<Location, 4> newLocations{op->getLoc()};
152  newOperands.reserve(inputs.size());
153  struct Element {
154  decltype(inputs.begin()) current, end;
155  };
156 
157  SmallVector<Element> worklist;
158  worklist.push_back({inputs.begin(), inputs.end()});
159  bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
160  bool changed = false;
161  while (!worklist.empty()) {
162  auto &element = worklist.back(); // Do not pop. Take ref.
163 
164  // Pop when we finished traversing the current operand range.
165  if (element.current == element.end) {
166  worklist.pop_back();
167  continue;
168  }
169 
170  Value value = *element.current++;
171  auto *flattenOp = value.getDefiningOp();
172  if (!flattenOp || flattenOp->getName() != op->getName() ||
173  flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState")) {
174  newOperands.push_back(value);
175  continue;
176  }
177 
178  // Don't duplicate logic when it has multiple uses.
179  if (!value.hasOneUse()) {
180  // We can fold a multi-use binary operation into this one if this allows a
181  // constant to fold though. For example, fold
182  // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
183  // since the constants will both fold and we end up with the equiv cost.
184  //
185  // We don't do this for add/mul because the hardware won't be shared
186  // between the two ops if duplicated.
187  if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
188  !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
189  !inputs.back().getDefiningOp<hw::ConstantOp>()) {
190  newOperands.push_back(value);
191  continue;
192  }
193  }
194 
195  changed = true;
196 
197  // Otherwise, push operands into worklist.
198  auto flattenOpInputs = flattenOp->getOperands();
199  worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
200  newLocations.push_back(flattenOp->getLoc());
201  }
202 
203  if (!changed)
204  return false;
205 
206  Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
207  op->getName(), newOperands, rewriter);
208  if (binFlag)
209  result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
210 
211  replaceOpAndCopyName(rewriter, op, result);
212  return true;
213 }
214 
215 // Given a range of uses of an operation, find the lowest and highest bits
216 // inclusive that are ever referenced. The range of uses must not be empty.
217 static std::pair<size_t, size_t>
218 getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
219  size_t originalOpWidth) {
220  auto users = op->getUsers();
221  assert(!users.empty() &&
222  "getLowestBitAndHighestBitRequired cannot operate on "
223  "a empty list of uses.");
224 
225  // when we don't want to narrowTrailingBits (namely in arithmetic
226  // operations), forcing lowestBitRequired = 0
227  size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
228  size_t highestBitRequired = 0;
229 
230  for (auto *user : users) {
231  if (auto extractOp = dyn_cast<ExtractOp>(user)) {
232  size_t lowBit = extractOp.getLowBit();
233  size_t highBit =
234  cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
235  highestBitRequired = std::max(highestBitRequired, highBit);
236  lowestBitRequired = std::min(lowestBitRequired, lowBit);
237  continue;
238  }
239 
240  highestBitRequired = originalOpWidth - 1;
241  lowestBitRequired = 0;
242  break;
243  }
244 
245  return {lowestBitRequired, highestBitRequired};
246 }
247 
248 template <class OpTy>
249 static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
250  PatternRewriter &rewriter) {
251  IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
252  if (!opType)
253  return false;
254 
255  auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
256  opType.getWidth());
257  if (range.second + 1 == opType.getWidth() && range.first == 0)
258  return false;
259 
260  SmallVector<Value> args;
261  auto newType = rewriter.getIntegerType(range.second - range.first + 1);
262  for (auto inop : op.getOperands()) {
263  // deal with muxes here
264  if (inop.getType() != op.getType())
265  args.push_back(inop);
266  else
267  args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
268  inop, range.first));
269  }
270  Value newop = rewriter.createOrFold<OpTy>(op.getLoc(), newType, args);
271  newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
272  if (range.first)
273  newop = rewriter.createOrFold<ConcatOp>(
274  op.getLoc(), newop,
275  rewriter.create<hw::ConstantOp>(op.getLoc(),
276  APInt::getZero(range.first)));
277  if (range.second + 1 < opType.getWidth())
278  newop = rewriter.createOrFold<ConcatOp>(
279  op.getLoc(),
280  rewriter.create<hw::ConstantOp>(
281  op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
282  newop);
283  rewriter.replaceOp(op, newop);
284  return true;
285 }
286 
287 //===----------------------------------------------------------------------===//
288 // Unary Operations
289 //===----------------------------------------------------------------------===//
290 
291 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
292  if (hasOperandsOutsideOfBlock(getOperation()))
293  return {};
294 
295  // Replicate one time -> noop.
296  if (cast<IntegerType>(getType()).getWidth() ==
297  getInput().getType().getIntOrFloatBitWidth())
298  return getInput();
299 
300  // Constant fold.
301  if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
302  if (input.getValue().getBitWidth() == 1) {
303  if (input.getValue().isZero())
304  return getIntAttr(
305  APInt::getZero(cast<IntegerType>(getType()).getWidth()),
306  getContext());
307  return getIntAttr(
308  APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
309  getContext());
310  }
311 
312  APInt result = APInt::getZeroWidth();
313  for (auto i = getMultiple(); i != 0; --i)
314  result = result.concat(input.getValue());
315  return getIntAttr(result, getContext());
316  }
317 
318  return {};
319 }
320 
321 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
322  if (hasOperandsOutsideOfBlock(getOperation()))
323  return {};
324 
325  // Constant fold.
326  if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
327  return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
328 
329  return {};
330 }
331 
332 //===----------------------------------------------------------------------===//
333 // Binary Operations
334 //===----------------------------------------------------------------------===//
335 
336 /// Performs constant folding `calculate` with element-wise behavior on the two
337 /// attributes in `operands` and returns the result if possible.
338 static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
339  hw::PEO paramOpcode) {
340  assert(operands.size() == 2 && "binary op takes two operands");
341  if (!operands[0] || !operands[1])
342  return {};
343 
344  // Fold constants with ParamExprAttr::get which handles simple constants as
345  // well as parameter expressions.
346  return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
347  cast<TypedAttr>(operands[1]));
348 }
349 
350 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
351  if (hasOperandsOutsideOfBlock(getOperation()))
352  return {};
353 
354  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
355  unsigned shift = rhs.getValue().getZExtValue();
356  unsigned width = getType().getIntOrFloatBitWidth();
357  if (shift == 0)
358  return getOperand(0);
359  if (width <= shift)
360  return getIntAttr(APInt::getZero(width), getContext());
361  }
362 
363  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
364 }
365 
366 LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
367  if (hasOperandsOutsideOfBlock(&*op))
368  return failure();
369 
370  // ShlOp(x, cst) -> Concat(Extract(x), zeros)
371  APInt value;
372  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
373  return failure();
374 
375  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
376  unsigned shift = value.getZExtValue();
377 
378  // This case is handled by fold.
379  if (width <= shift || shift == 0)
380  return failure();
381 
382  auto zeros =
383  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
384 
385  // Remove the high bits which would be removed by the Shl.
386  auto extract =
387  rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
388 
389  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
390  return success();
391 }
392 
393 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
394  if (hasOperandsOutsideOfBlock(getOperation()))
395  return {};
396 
397  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
398  unsigned shift = rhs.getValue().getZExtValue();
399  if (shift == 0)
400  return getOperand(0);
401 
402  unsigned width = getType().getIntOrFloatBitWidth();
403  if (width <= shift)
404  return getIntAttr(APInt::getZero(width), getContext());
405  }
406  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
407 }
408 
409 LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
410  if (hasOperandsOutsideOfBlock(&*op))
411  return failure();
412 
413  // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
414  APInt value;
415  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
416  return failure();
417 
418  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
419  unsigned shift = value.getZExtValue();
420 
421  // This case is handled by fold.
422  if (width <= shift || shift == 0)
423  return failure();
424 
425  auto zeros =
426  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
427 
428  // Remove the low bits which would be removed by the Shr.
429  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
430  width - shift);
431 
432  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
433  return success();
434 }
435 
436 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
437  if (hasOperandsOutsideOfBlock(getOperation()))
438  return {};
439 
440  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
441  if (rhs.getValue().getZExtValue() == 0)
442  return getOperand(0);
443  }
444  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
445 }
446 
447 LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
448  if (hasOperandsOutsideOfBlock(&*op))
449  return failure();
450 
451  // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
452  APInt value;
453  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
454  return failure();
455 
456  unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
457  unsigned shift = value.getZExtValue();
458 
459  auto topbit =
460  rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
461  auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
462 
463  if (width <= shift) {
464  replaceOpAndCopyName(rewriter, op, {sext});
465  return success();
466  }
467 
468  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
469  width - shift);
470 
471  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
472  return success();
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // Other Operations
477 //===----------------------------------------------------------------------===//
478 
479 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
480  if (hasOperandsOutsideOfBlock(getOperation()))
481  return {};
482 
483  // If we are extracting the entire input, then return it.
484  if (getInput().getType() == getType())
485  return getInput();
486 
487  // Constant fold.
488  if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
489  unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
490  return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
491  getContext());
492  }
493  return {};
494 }
495 
// Transforms extract(lo, cat(a, b, c, d, e)) into
// cat(extract(lo1, b), c, extract(lo2, d)).
// innerCat must be the argument of the provided ExtractOp.
//
// Concat operands are listed most-significant first, so they are walked in
// reverse (starting from bit 0):
//  1. Skip whole operands that lie entirely below `lowBit`.
//  2. From the first operand that contains `lowBit`, take bits starting at
//     offset `lowBit - beginOfFirstRelevantElement`.
//  3. Take each following operand whole, or truncated to the remaining width,
//     until the extract's result width is consumed.
// A piece that covers its whole operand is used directly, with no ExtractOp.
// If only one piece results, it replaces `op` directly; otherwise the pieces
// are re-concatenated in MSB-first order. Always returns success().
static LogicalResult extractConcatToConcatExtract(ExtractOp op,
                                                  ConcatOp innerCat,
                                                  PatternRewriter &rewriter) {
  auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
  size_t beginOfFirstRelevantElement = 0;
  auto it = reversedConcatArgs.begin();
  size_t lowBit = op.getLowBit();

  // This loop finds the first concatArg that is covered by the ExtractOp
  for (; it != reversedConcatArgs.end(); it++) {
    assert(beginOfFirstRelevantElement <= lowBit &&
           "incorrectly moved past an element that lowBit has coverage over");
    auto operand = *it;

    size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
    if (lowBit < beginOfFirstRelevantElement + operandWidth) {
      // This operand spans [beginOfFirstRelevantElement,
      // beginOfFirstRelevantElement + operandWidth) and contains lowBit. This
      // includes the edge cases where lowBit is the operand's first bit or its
      // last bit. Stop here; extraction starts inside this operand.
      break;
    }

    // The extraction discards this operand entirely: all of its bits are
    // below lowBit. Advance the running bit position past it.
    beginOfFirstRelevantElement += operandWidth;
  }
  assert(it != reversedConcatArgs.end() &&
         "incorrectly failed to find an element which contains coverage of "
         "lowBit");

  SmallVector<Value> reverseConcatArgs;
  size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
  size_t extractLo = lowBit - beginOfFirstRelevantElement;

  // Transform individual arguments of innerCat(..., a, b, c,) into
  // [ extract(a), b, extract(c) ], skipping an extract operation where
  // possible.
  for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
    auto concatArg = *it;
    size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
    size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);

    if (widthToConsume == operandWidth && extractLo == 0) {
      // The whole operand is needed; reuse it as-is.
      reverseConcatArgs.push_back(concatArg);
    } else {
      auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
      reverseConcatArgs.push_back(
          rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
    }

    widthRemaining -= widthToConsume;

    // Beyond the first element, all elements are extracted from position 0.
    extractLo = 0;
  }

  if (reverseConcatArgs.size() == 1) {
    replaceOpAndCopyName(rewriter, op, reverseConcatArgs[0]);
  } else {
    // Pieces were collected LSB-first; concat expects MSB-first.
    replaceOpWithNewOpAndCopyName<ConcatOp>(
        rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
  }
  return success();
}
578 
579 // Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
580 static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
581  PatternRewriter &rewriter) {
582  auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
583  auto replicateEltWidth =
584  replicate.getOperand().getType().getIntOrFloatBitWidth();
585 
586  // If the extract starts at the base of an element and is an even multiple,
587  // we can replace the extract with a smaller replicate.
588  if (op.getLowBit() % replicateEltWidth == 0 &&
589  extractResultWidth % replicateEltWidth == 0) {
590  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
591  replicate.getOperand());
592  return true;
593  }
594 
595  // If the extract is completely contained in one element, extract from the
596  // element.
597  if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
598  replicateEltWidth) {
599  replaceOpWithNewOpAndCopyName<ExtractOp>(
600  rewriter, op, op.getType(), replicate.getOperand(),
601  op.getLowBit() % replicateEltWidth);
602  return true;
603  }
604 
605  // We don't currently handle the case of extracting from non-whole elements,
606  // e.g. `extract (replicate 2-bit-thing, N), 1`.
607  return false;
608 }
609 
/// Canonicalization patterns for ExtractOp. The rewrites below are tried in
/// order and the first one that applies replaces `op`:
///  - extract(extract(x)) is merged into one extract of `x`;
///  - extract(concat(...)) is pushed into the relevant concat operands;
///  - extract(replicate(...)) is simplified when extractFromReplicate applies;
///  - extract(or/xor(a, cst)) drops the logical op when the extracted slice of
///    `cst` is zero, and extract(and(a, cst)) becomes a concat of zeros and one
///    extract of `a` when that slice is a contiguous run of ones;
///  - a 1-bit extract of shl(1, x) becomes `icmp eq x, lowBit`.
/// Returns failure() if nothing applies or `op` has operands from another
/// block.
LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto *inputOp = op.getInput().getDefiningOp();

  // This turns out to be incredibly expensive.  Disable until performance is
  // addressed.
#if 0
  // If the extracted bits are all known, then return the result.
  auto knownBits = computeKnownBits(op.getInput())
                       .extractBits(cast<IntegerType>(op.getType()).getWidth(),
                                    op.getLowBit());
  if (knownBits.isConstant()) {
    replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
                                                  knownBits.getConstant());
    return success();
  }
#endif

  // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
  if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
    replaceOpWithNewOpAndCopyName<ExtractOp>(
        rewriter, op, op.getType(), innerExtract.getInput(),
        innerExtract.getLowBit() + op.getLowBit());
    return success();
  }

  // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
  if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
    return extractConcatToConcatExtract(op, innerCat, rewriter);

  // extract(lo, replicate(a))
  if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
    if (extractFromReplicate(op, replicate, rewriter))
      return success();

  // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
  // and/or/xor are not modifying the extracted bits.
  // Only binary ops whose second operand is a constant are considered.
  if (inputOp && inputOp->getNumOperands() == 2 &&
      isa<AndOp, OrOp, XorOp>(inputOp)) {
    if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
      // The slice of the constant that lines up with the extracted bits.
      auto extractedCst = cstRHS.getValue().extractBits(
          cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
      // or/xor with zero leaves the extracted bits unchanged.
      if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
        replaceOpWithNewOpAndCopyName<ExtractOp>(
            rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
        return success();
      }

      // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
      // extract to represent the result.  Turning it into a pile of extracts is
      // always fine by our cost model, but we don't want to explode things into
      // a ton of bits because it will bloat the IR and generated Verilog.
      if (isa<AndOp>(inputOp)) {
        // For our cost model, we only do this if the bit pattern is a
        // contiguous series of ones.
        // NOTE(review): for an all-zero slice, lz and tz both equal the width,
        // so the unsigned check below is not satisfied for that case.
        unsigned lz = extractedCst.countLeadingZeros();
        unsigned tz = extractedCst.countTrailingZeros();
        unsigned pop = extractedCst.popcount();
        if (extractedCst.getBitWidth() - lz - tz == pop) {
          auto resultTy = rewriter.getIntegerType(pop);
          SmallVector<Value> resultElts;
          if (lz)
            resultElts.push_back(rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(lz)));
          resultElts.push_back(rewriter.createOrFold<ExtractOp>(
              op.getLoc(), resultTy, inputOp->getOperand(0),
              op.getLowBit() + tz));
          if (tz)
            resultElts.push_back(rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(tz)));
          replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
          return success();
        }
      }
    }
  }

  // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
  // extracted.
  if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
    if (auto shlOp = dyn_cast<ShlOp>(inputOp))
      if (auto lhsCst = shlOp.getOperand(0).getDefiningOp<hw::ConstantOp>())
        if (lhsCst.getValue().isOne()) {
          // The comparison constant uses the shl operand's bit width.
          auto newCst = rewriter.create<hw::ConstantOp>(
              shlOp.getLoc(),
              APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
          // NOTE(review): the trailing `false` presumably is ICmpOp's
          // twoState flag; confirm against the builder.
          replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
                                                shlOp->getOperand(1), newCst,
                                                false);
          return success();
        }

  return failure();
}
706 
707 //===----------------------------------------------------------------------===//
708 // Associative Variadic operations
709 //===----------------------------------------------------------------------===//
710 
711 // Reduce all operands to a single value (either integer constant or parameter
712 // expression) if all the operands are constants.
713 static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
714  hw::PEO paramOpcode) {
715  assert(operands.size() > 1 && "caller should handle one-operand case");
716  // We can only fold anything in the case where all operands are known to be
717  // constants. Check the least common one first for an early out.
718  if (!operands[1] || !operands[0])
719  return {};
720 
721  // This will fold to a simple constant if all operands are constant.
722  if (llvm::all_of(operands.drop_front(2),
723  [&](Attribute in) { return !!in; })) {
724  SmallVector<mlir::TypedAttr> typedOperands;
725  typedOperands.reserve(operands.size());
726  for (auto operand : operands) {
727  if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
728  typedOperands.push_back(typedOperand);
729  else
730  break;
731  }
732  if (typedOperands.size() == operands.size())
733  return hw::ParamExprAttr::get(paramOpcode, typedOperands);
734  }
735 
736  return {};
737 }
738 
/// When we find a logical operation (and, or, xor) with a constant e.g.
/// `X & 42`, we want to push the constant into the computation of X if it leads
/// to simplification.
///
/// This function handles the case where the logical operation has a concat
/// operand. We check to see if we can simplify the concat, e.g. when it has
/// constant operands.
///
/// `logicalOp` is treated as having its constant as the LAST operand, with
/// `cst` as that constant's value (the variadic path below drops the last
/// operand). `concatIdx` is the index of the concat operand. The constant is
/// sliced per concat operand (MSB first), and each slice is applied to the
/// matching operand with a new instance of the same logical op. The concat
/// rebuilt from those results then replaces `logicalOp`. Any other operands of
/// a variadic `logicalOp` are kept, combined with the new concat in one more
/// logical op.
/// NOTE(review): ops built through createGenericOp here do not carry a
/// "twoState" attribute over from `logicalOp`.
///
/// This returns true when a simplification happens.
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
                                             size_t concatIdx, const APInt &cst,
                                             PatternRewriter &rewriter) {
  auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
  assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));

  // Check to see if any operands can be simplified by pushing the logical op
  // into all parts of the concat.
  bool canSimplify =
      llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
        auto *operandOp = operand.getDefiningOp();
        if (!operandOp)
          return false;

        // If the concat has a constant operand then we can transform this.
        if (isa<hw::ConstantOp>(operandOp))
          return true;
        // If the concat has the same logical operation and that operation has
        // a constant operand, then we can fold it into that suboperation.
        return operandOp->getName() == logicalOp->getName() &&
               operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
               operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
      });

  if (!canSimplify)
    return false;

  // Create a new instance of the logical operation.  We have to do this the
  // hard way since we're generic across a family of different ops.
  auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
    return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
                           rewriter);
  };

  // Ok, let's do the transformation.  We do this by slicing up the constant
  // for each unit of the concat and duplicate the operation into the
  // sub-operand.
  SmallVector<Value> newConcatOperands;
  newConcatOperands.reserve(concatOp->getNumOperands());

  // Work from MSB to LSB.
  size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
  for (Value operand : concatOp->getOperands()) {
    size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
    nextOperandBit -= operandWidth;
    // Take a slice of the constant.
    auto eltCst = rewriter.create<hw::ConstantOp>(
        logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));

    newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
  }

  // Create the concat, and the rest of the logical op if we need it.
  Value newResult =
      rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);

  // If we had a variadic logical op on the top level, then recreate it with the
  // new concat and without the constant operand.
  if (logicalOp->getNumOperands() > 2) {
    auto origOperands = logicalOp->getOperands();
    SmallVector<Value> operands;
    // Take any stuff before the concat.
    operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
    // Take any stuff after the concat but before the constant.
    operands.append(origOperands.begin() + concatIdx + 1,
                    origOperands.begin() + (origOperands.size() - 1));
    // Include the new concat.
    operands.push_back(newResult);
    newResult = createLogicalOp(operands);
  }

  replaceOpAndCopyName(rewriter, logicalOp, newResult);
  return true;
}
822 
823 // Determines whether the inputs to a logical element are of opposite
824 // comparisons and can lowered into a constant.
825 static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
826  llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
827 
828  for (auto op : operands) {
829  if (auto icmpOp = op.getDefiningOp<ICmpOp>();
830  icmpOp && icmpOp.getTwoState()) {
831  auto predicate = icmpOp.getPredicate();
832  auto lhs = icmpOp.getLhs();
833  auto rhs = icmpOp.getRhs();
834  if (seenPredicates.contains(
835  {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
836  return true;
837 
838  seenPredicates.insert({predicate, lhs, rhs});
839  }
840  }
841  return false;
842 }
843 
844 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
845  if (hasOperandsOutsideOfBlock(getOperation()))
846  return {};
847 
848  APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
849 
850  auto inputs = adaptor.getInputs();
851 
852  // and(x, 01, 10) -> 00 -- annulment.
853  for (auto operand : inputs) {
854  if (!operand)
855  continue;
856  value &= cast<IntegerAttr>(operand).getValue();
857  if (value.isZero())
858  return getIntAttr(value, getContext());
859  }
860 
861  // and(x, -1) -> x.
862  if (inputs.size() == 2 && inputs[1] &&
863  cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
864  return getInputs()[0];
865 
866  // and(x, x, x) -> x. This also handles and(x) -> x.
867  if (llvm::all_of(getInputs(),
868  [&](auto in) { return in == this->getInputs()[0]; }))
869  return getInputs()[0];
870 
871  // and(..., x, ..., ~x, ...) -> 0
872  for (Value arg : getInputs()) {
873  Value subExpr;
874  if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
875  for (Value arg2 : getInputs())
876  if (arg2 == subExpr)
877  return getIntAttr(
878  APInt::getZero(cast<IntegerType>(getType()).getWidth()),
879  getContext());
880  }
881  }
882 
883  // x0 = icmp(pred, x, y)
884  // x1 = icmp(!pred, x, y)
885  // and(x0, x1) -> 0
886  if (canCombineOppositeBinCmpIntoConstant(getInputs()))
887  return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
888  getContext());
889 
890  // Constant fold
891  return constFoldAssociativeOp(inputs, hw::PEO::And);
892 }
893 
894 /// Returns a single common operand that all inputs of the operation `op` can
895 /// be traced back to, or an empty `Value` if no such operand exists.
896 ///
897 /// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
898 /// (assuming the bit-width of `a` is `n`).
899 template <typename Op>
900 static Value getCommonOperand(Op op) {
901  if (!op.getType().isInteger(1))
902  return Value();
903 
904  auto inputs = op.getInputs();
905  size_t size = inputs.size();
906 
907  auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
908  if (!sourceOp)
909  return Value();
910  Value source = sourceOp.getOperand();
911 
912  // Fast path: the input size is not equal to the width of the source.
913  if (size != source.getType().getIntOrFloatBitWidth())
914  return Value();
915 
916  // Tracks the bits that were encountered.
917  llvm::BitVector bits(size);
918  bits.set(sourceOp.getLowBit());
919 
920  for (size_t i = 1; i != size; ++i) {
921  auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
922  if (!extractOp || extractOp.getOperand() != source)
923  return Value();
924  bits.set(extractOp.getLowBit());
925  }
926 
927  return bits.all() ? source : Value();
928 }
929 
930 /// Canonicalize an idempotent operation `op` so that only one input of any kind
931 /// occurs.
932 ///
933 /// Example: `and(x, y, x, z)` -> `and(x, y, z)`
934 template <typename Op>
935 static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
936  auto inputs = op.getInputs();
937 
938  llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
939  llvm::SmallDenseSet<Value, 8> checked;
940  checked.insert(op);
941 
942  llvm::SmallVector<Value, 8> worklist;
943  for (auto input : inputs) {
944  if (input != op)
945  worklist.push_back(input);
946  }
947 
948  while (!worklist.empty()) {
949  auto element = worklist.pop_back_val();
950 
951  if (auto idempotentOp = element.getDefiningOp<Op>()) {
952  for (auto input : idempotentOp.getInputs()) {
953  uniqueInputs.remove(input);
954 
955  if (checked.insert(input).second)
956  worklist.push_back(input);
957  }
958  }
959  }
960 
961  if (uniqueInputs.size() < inputs.size()) {
962  replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
963  uniqueInputs.getArrayRef());
964  return true;
965  }
966 
967  return false;
968 }
969 
970 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
971  if (hasOperandsOutsideOfBlock(&*op))
972  return failure();
973 
974  auto inputs = op.getInputs();
975  auto size = inputs.size();
976  assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
977 
978  // and(x, and(...)) -> and(x, ...) -- flatten
979  if (tryFlatteningOperands(op, rewriter))
980  return success();
981 
982  // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
983  // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
984  // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
985  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
986  return success();
987 
988  // Patterns for and with a constant on RHS.
989  APInt value;
990  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
991  // and(..., '1) -> and(...) -- identity
992  if (value.isAllOnes()) {
993  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
994  inputs.drop_back(), false);
995  return success();
996  }
997 
998  // TODO: Combine multiple constants together even if they aren't at the
999  // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
1000  // folding
1001  APInt value2;
1002  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1003  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
1004  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1005  newOperands.push_back(cst);
1006  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
1007  newOperands, false);
1008  return success();
1009  }
1010 
1011  // Handle 'and' with a single bit constant on the RHS.
1012  if (size == 2 && value.isPowerOf2()) {
1013  // If the LHS is a replicate from a single bit, we can 'concat' it
1014  // into place. e.g.:
1015  // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1016  // TODO: Generalize this for non-single-bit operands.
1017  if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1018  auto replicateOperand = replicate.getOperand();
1019  if (replicateOperand.getType().isInteger(1)) {
1020  unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1021  auto trailingZeros = value.countTrailingZeros();
1022 
1023  // Don't add zero bit constants unnecessarily.
1024  SmallVector<Value, 3> concatOperands;
1025  if (trailingZeros != resultWidth - 1) {
1026  auto highZeros = rewriter.create<hw::ConstantOp>(
1027  op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1028  concatOperands.push_back(highZeros);
1029  }
1030  concatOperands.push_back(replicateOperand);
1031  if (trailingZeros != 0) {
1032  auto lowZeros = rewriter.create<hw::ConstantOp>(
1033  op.getLoc(), APInt::getZero(trailingZeros));
1034  concatOperands.push_back(lowZeros);
1035  }
1036  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1037  concatOperands);
1038  return success();
1039  }
1040  }
1041  }
1042 
1043  // If this is an and from an extract op, try shrinking the extract.
1044  if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1045  if (size == 2 &&
1046  // We can shrink it if the mask has leading or trailing zeros.
1047  (value.countLeadingZeros() || value.countTrailingZeros())) {
1048  unsigned lz = value.countLeadingZeros();
1049  unsigned tz = value.countTrailingZeros();
1050 
1051  // Start by extracting the smaller number of bits.
1052  auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1053  Value smallElt = rewriter.createOrFold<ExtractOp>(
1054  extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1055  extractOp.getLowBit() + tz);
1056  // Apply the 'and' mask if needed.
1057  APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1058  if (!smallMask.isAllOnes()) {
1059  auto loc = inputs.back().getLoc();
1060  smallElt = rewriter.createOrFold<AndOp>(
1061  loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
1062  false);
1063  }
1064 
1065  // The final replacement will be a concat of the leading/trailing zeros
1066  // along with the smaller extracted value.
1067  SmallVector<Value> resultElts;
1068  if (lz)
1069  resultElts.push_back(
1070  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1071  resultElts.push_back(smallElt);
1072  if (tz)
1073  resultElts.push_back(
1074  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1075  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
1076  return success();
1077  }
1078  }
1079 
1080  // and(concat(x, cst1), a, b, c, cst2)
1081  // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1082  // We do this for even more multi-use concats since they are "just wiring".
1083  for (size_t i = 0; i < size - 1; ++i) {
1084  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1085  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1086  return success();
1087  }
1088  }
1089 
1090  // extracts only of and(...) -> and(extract()...)
1091  if (narrowOperationWidth(op, true, rewriter))
1092  return success();
1093 
1094  // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1095  if (auto source = getCommonOperand(op)) {
1096  auto cmpAgainst =
1097  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1098  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1099  source, cmpAgainst);
1100  return success();
1101  }
1102 
1103  /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1104  return failure();
1105 }
1106 
1107 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1108  if (hasOperandsOutsideOfBlock(getOperation()))
1109  return {};
1110 
1111  auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1112  auto inputs = adaptor.getInputs();
1113  // or(x, 10, 01) -> 11
1114  for (auto operand : inputs) {
1115  if (!operand)
1116  continue;
1117  value |= cast<IntegerAttr>(operand).getValue();
1118  if (value.isAllOnes())
1119  return getIntAttr(value, getContext());
1120  }
1121 
1122  // or(x, 0) -> x
1123  if (inputs.size() == 2 && inputs[1] &&
1124  cast<IntegerAttr>(inputs[1]).getValue().isZero())
1125  return getInputs()[0];
1126 
1127  // or(x, x, x) -> x. This also handles or(x) -> x
1128  if (llvm::all_of(getInputs(),
1129  [&](auto in) { return in == this->getInputs()[0]; }))
1130  return getInputs()[0];
1131 
1132  // or(..., x, ..., ~x, ...) -> -1
1133  for (Value arg : getInputs()) {
1134  Value subExpr;
1135  if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1136  for (Value arg2 : getInputs())
1137  if (arg2 == subExpr)
1138  return getIntAttr(
1139  APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1140  getContext());
1141  }
1142  }
1143 
1144  // x0 = icmp(pred, x, y)
1145  // x1 = icmp(!pred, x, y)
1146  // or(x0, x1) -> 1
1147  if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1148  return getIntAttr(
1149  APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1150  getContext());
1151 
1152  // Constant fold
1153  return constFoldAssociativeOp(inputs, hw::PEO::Or);
1154 }
1155 
1156 /// Simplify concat ops in an or op when a constant operand is present in either
1157 /// concat.
1158 ///
1159 /// This will invert an or(concat, concat) into concat(or, or, ...), which can
1160 /// often be further simplified due to the smaller or ops being easier to fold.
1161 ///
1162 /// For example:
1163 ///
1164 /// or(..., concat(x, 0), concat(0, y))
1165 /// ==> or(..., concat(x, 0, y)), when x and y don't overlap.
1166 ///
1167 /// or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1))
1168 /// ==> or(..., concat(or(x: i2, extract(cst2, 4..3)),
1169 /// or(extract(cst1, 3..1), extract(cst2, 2..0)),
1170 /// or(extract(cst1, 0..0), y: i1))
1171 static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
1172  size_t concatIdx2,
1173  PatternRewriter &rewriter) {
1174  assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2");
1175 
1176  auto inputs = op.getInputs();
1177  auto concat1 = inputs[concatIdx1].getDefiningOp<ConcatOp>();
1178  auto concat2 = inputs[concatIdx2].getDefiningOp<ConcatOp>();
1179 
1180  assert(concat1 && concat2 && "expected indexes to point to ConcatOps");
1181 
1182  // We can simplify as long as a constant is present in either concat.
1183  bool hasConstantOp1 =
1184  llvm::any_of(concat1->getOperands(), [&](Value operand) -> bool {
1185  return operand.getDefiningOp<hw::ConstantOp>();
1186  });
1187  if (!hasConstantOp1) {
1188  bool hasConstantOp2 =
1189  llvm::any_of(concat2->getOperands(), [&](Value operand) -> bool {
1190  return operand.getDefiningOp<hw::ConstantOp>();
1191  });
1192  if (!hasConstantOp2)
1193  return false;
1194  }
1195 
1196  SmallVector<Value> newConcatOperands;
1197 
1198  // Simultaneously iterate over the operands of both concat ops, from MSB to
1199  // LSB, pushing out or's of overlapping ranges of the operands. When operands
1200  // span different bit ranges, we extract only the maximum overlap.
1201  auto operands1 = concat1->getOperands();
1202  auto operands2 = concat2->getOperands();
1203  // Number of bits already consumed from operands 1 and 2, respectively.
1204  unsigned consumedWidth1 = 0;
1205  unsigned consumedWidth2 = 0;
1206  for (auto it1 = operands1.begin(), end1 = operands1.end(),
1207  it2 = operands2.begin(), end2 = operands2.end();
1208  it1 != end1 && it2 != end2;) {
1209  auto operand1 = *it1;
1210  auto operand2 = *it2;
1211 
1212  unsigned remainingWidth1 =
1213  hw::getBitWidth(operand1.getType()) - consumedWidth1;
1214  unsigned remainingWidth2 =
1215  hw::getBitWidth(operand2.getType()) - consumedWidth2;
1216  unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
1217  auto narrowedType = rewriter.getIntegerType(widthToConsume);
1218 
1219  auto extract1 = rewriter.createOrFold<ExtractOp>(
1220  op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
1221  auto extract2 = rewriter.createOrFold<ExtractOp>(
1222  op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);
1223 
1224  newConcatOperands.push_back(
1225  rewriter.createOrFold<OrOp>(op.getLoc(), extract1, extract2, false));
1226 
1227  consumedWidth1 += widthToConsume;
1228  consumedWidth2 += widthToConsume;
1229 
1230  if (widthToConsume == remainingWidth1) {
1231  ++it1;
1232  consumedWidth1 = 0;
1233  }
1234  if (widthToConsume == remainingWidth2) {
1235  ++it2;
1236  consumedWidth2 = 0;
1237  }
1238  }
1239 
1240  ConcatOp newOp = rewriter.create<ConcatOp>(op.getLoc(), newConcatOperands);
1241 
1242  // Copy the old operands except for concatIdx1 and concatIdx2, and append the
1243  // new ConcatOp to the end.
1244  SmallVector<Value> newOrOperands;
1245  newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1);
1246  newOrOperands.append(inputs.begin() + concatIdx1 + 1,
1247  inputs.begin() + concatIdx2);
1248  newOrOperands.append(inputs.begin() + concatIdx2 + 1,
1249  inputs.begin() + inputs.size());
1250  newOrOperands.push_back(newOp);
1251 
1252  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1253  newOrOperands);
1254  return true;
1255 }
1256 
1257 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1258  if (hasOperandsOutsideOfBlock(&*op))
1259  return failure();
1260 
1261  auto inputs = op.getInputs();
1262  auto size = inputs.size();
1263  assert(size > 1 && "expected 2 or more operands");
1264 
1265  // or(x, or(...)) -> or(x, ...) -- flatten
1266  if (tryFlatteningOperands(op, rewriter))
1267  return success();
1268 
1269  // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1270  // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1271  // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1272  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1273  return success();
1274 
1275  // Patterns for and with a constant on RHS.
1276  APInt value;
1277  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1278  // or(..., '0) -> or(...) -- identity
1279  if (value.isZero()) {
1280  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1281  inputs.drop_back());
1282  return success();
1283  }
1284 
1285  // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1286  APInt value2;
1287  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1288  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
1289  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1290  newOperands.push_back(cst);
1291  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1292  newOperands);
1293  return success();
1294  }
1295 
1296  // or(concat(x, cst1), a, b, c, cst2)
1297  // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1298  // We do this for even more multi-use concats since they are "just wiring".
1299  for (size_t i = 0; i < size - 1; ++i) {
1300  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1301  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1302  return success();
1303  }
1304  }
1305 
1306  // or(..., concat(x, cst1), concat(cst2, y)
1307  // ==> or(..., concat(x, cst3, y)), when x and y don't overlap.
1308  for (size_t i = 0; i < size - 1; ++i) {
1309  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1310  for (size_t j = i + 1; j < size; ++j)
1311  if (auto concat = inputs[j].getDefiningOp<ConcatOp>())
1312  if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter))
1313  return success();
1314  }
1315 
1316  // extracts only of or(...) -> or(extract()...)
1317  if (narrowOperationWidth(op, true, rewriter))
1318  return success();
1319 
1320  // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1321  if (auto source = getCommonOperand(op)) {
1322  auto cmpAgainst =
1323  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1324  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1325  source, cmpAgainst);
1326  return success();
1327  }
1328 
1329  // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1330  // .., c_n), a, 0)
1331  if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1332  APInt value;
1333  if (op.getTwoState() && firstMux.getTwoState() &&
1334  matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1335  value.isZero()) {
1336  SmallVector<Value> conditions{firstMux.getCond()};
1337  auto check = [&](Value v) {
1338  auto mux = v.getDefiningOp<comb::MuxOp>();
1339  if (!mux)
1340  return false;
1341  conditions.push_back(mux.getCond());
1342  return mux.getTwoState() &&
1343  firstMux.getTrueValue() == mux.getTrueValue() &&
1344  firstMux.getFalseValue() == mux.getFalseValue();
1345  };
1346  if (llvm::all_of(op.getOperands().drop_front(), check)) {
1347  auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
1348  replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1349  rewriter, op, cond, firstMux.getTrueValue(),
1350  firstMux.getFalseValue(), true);
1351  return success();
1352  }
1353  }
1354  }
1355 
1356  /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1357  return failure();
1358 }
1359 
1360 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1361  if (hasOperandsOutsideOfBlock(getOperation()))
1362  return {};
1363 
1364  auto size = getInputs().size();
1365  auto inputs = adaptor.getInputs();
1366 
1367  // xor(x) -> x -- noop
1368  if (size == 1)
1369  return getInputs()[0];
1370 
1371  // xor(x, x) -> 0 -- idempotent
1372  if (size == 2 && getInputs()[0] == getInputs()[1])
1373  return IntegerAttr::get(getType(), 0);
1374 
1375  // xor(x, 0) -> x
1376  if (inputs.size() == 2 && inputs[1] &&
1377  cast<IntegerAttr>(inputs[1]).getValue().isZero())
1378  return getInputs()[0];
1379 
1380  // xor(xor(x,1),1) -> x
1381  // but not self loop
1382  if (isBinaryNot()) {
1383  Value subExpr;
1384  if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1385  subExpr != getResult())
1386  return subExpr;
1387  }
1388 
1389  // Constant fold
1390  return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1391 }
1392 
1393 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1394 static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1395  PatternRewriter &rewriter) {
1396  auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1397  auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1398 
1399  Value result =
1400  rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1401  icmp.getOperand(1), icmp.getTwoState());
1402 
1403  // If the xor had other operands, rebuild it.
1404  if (op.getNumOperands() > 2) {
1405  SmallVector<Value, 4> newOperands(op.getOperands());
1406  newOperands.pop_back();
1407  newOperands.erase(newOperands.begin() + icmpOperand);
1408  newOperands.push_back(result);
1409  result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
1410  }
1411 
1412  replaceOpAndCopyName(rewriter, op, result);
1413 }
1414 
1415 LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1416  if (hasOperandsOutsideOfBlock(&*op))
1417  return failure();
1418 
1419  auto inputs = op.getInputs();
1420  auto size = inputs.size();
1421  assert(size > 1 && "expected 2 or more operands");
1422 
1423  // xor(..., x, x) -> xor (...) -- idempotent
1424  if (inputs[size - 1] == inputs[size - 2]) {
1425  assert(size > 2 &&
1426  "expected idempotent case for 2 elements handled already.");
1427  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1428  inputs.drop_back(/*n=*/2), false);
1429  return success();
1430  }
1431 
1432  // Patterns for xor with a constant on RHS.
1433  APInt value;
1434  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1435  // xor(..., 0) -> xor(...) -- identity
1436  if (value.isZero()) {
1437  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1438  inputs.drop_back(), false);
1439  return success();
1440  }
1441 
1442  // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1443  APInt value2;
1444  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1445  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
1446  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1447  newOperands.push_back(cst);
1448  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1449  newOperands, false);
1450  return success();
1451  }
1452 
1453  bool isSingleBit = value.getBitWidth() == 1;
1454 
1455  // Check for subexpressions that we can simplify.
1456  for (size_t i = 0; i < size - 1; ++i) {
1457  Value operand = inputs[i];
1458 
1459  // xor(concat(x, cst1), a, b, c, cst2)
1460  // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1461  // We do this for even more multi-use concats since they are "just
1462  // wiring".
1463  if (auto concat = operand.getDefiningOp<ConcatOp>())
1464  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1465  return success();
1466 
1467  // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1468  if (isSingleBit && operand.hasOneUse()) {
1469  assert(value == 1 && "single bit constant has to be one if not zero");
1470  if (auto icmp = operand.getDefiningOp<ICmpOp>())
1471  return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1472  }
1473  }
1474  }
1475 
1476  // xor(x, xor(...)) -> xor(x, ...) -- flatten
1477  if (tryFlatteningOperands(op, rewriter))
1478  return success();
1479 
1480  // extracts only of xor(...) -> xor(extract()...)
1481  if (narrowOperationWidth(op, true, rewriter))
1482  return success();
1483 
1484  // xor(a[0], a[1], ..., a[n]) -> parity(a)
1485  if (auto source = getCommonOperand(op)) {
1486  replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1487  return success();
1488  }
1489 
1490  return failure();
1491 }
1492 
1493 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1494  if (hasOperandsOutsideOfBlock(getOperation()))
1495  return {};
1496 
1497  // sub(x - x) -> 0
1498  if (getRhs() == getLhs())
1499  return getIntAttr(
1500  APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1501  getContext());
1502 
1503  if (adaptor.getRhs()) {
1504  // If both are constants, we can unconditionally fold.
1505  if (adaptor.getLhs()) {
1506  // Constant fold (c1 - c2) => (c1 + -1*c2).
1507  auto negOne = getIntAttr(
1508  APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1509  getContext());
1510  auto rhsNeg = hw::ParamExprAttr::get(
1511  hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1512  return hw::ParamExprAttr::get(hw::PEO::Add,
1513  cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1514  }
1515 
1516  // sub(x - 0) -> x
1517  if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1518  if (rhsC.getValue().isZero())
1519  return getLhs();
1520  }
1521  }
1522 
1523  return {};
1524 }
1525 
1526 LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1527  if (hasOperandsOutsideOfBlock(&*op))
1528  return failure();
1529 
1530  // sub(x, cst) -> add(x, -cst)
1531  APInt value;
1532  if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1533  auto negCst = rewriter.create<hw::ConstantOp>(op.getLoc(), -value);
1534  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1535  false);
1536  return success();
1537  }
1538 
1539  // extracts only of sub(...) -> sub(extract()...)
1540  if (narrowOperationWidth(op, false, rewriter))
1541  return success();
1542 
1543  return failure();
1544 }
1545 
1546 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1547  if (hasOperandsOutsideOfBlock(getOperation()))
1548  return {};
1549 
1550  auto size = getInputs().size();
1551 
1552  // add(x) -> x -- noop
1553  if (size == 1u)
1554  return getInputs()[0];
1555 
1556  // Constant fold constant operands.
1557  return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1558 }
1559 
1560 LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1561  if (hasOperandsOutsideOfBlock(&*op))
1562  return failure();
1563 
1564  auto inputs = op.getInputs();
1565  auto size = inputs.size();
1566  assert(size > 1 && "expected 2 or more operands");
1567 
1568  APInt value, value2;
1569 
1570  // add(..., 0) -> add(...) -- identity
1571  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1572  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1573  inputs.drop_back(), false);
1574  return success();
1575  }
1576 
1577  // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1578  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1579  matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1580  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
1581  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1582  newOperands.push_back(cst);
1583  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1584  newOperands, false);
1585  return success();
1586  }
1587 
1588  // add(..., x, x) -> add(..., shl(x, 1))
1589  if (inputs[size - 1] == inputs[size - 2]) {
1590  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1591 
1592  auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1593  auto shiftLeftOp =
1594  rewriter.create<comb::ShlOp>(op.getLoc(), inputs.back(), one, false);
1595 
1596  newOperands.push_back(shiftLeftOp);
1597  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1598  newOperands, false);
1599  return success();
1600  }
1601 
1602  auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1603  // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1604  if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1605  matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1606 
1607  APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1608  auto rhs =
1609  rewriter.create<hw::ConstantOp>(op.getLoc(), (one << value) + one);
1610 
1611  std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1612  auto mulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
1613 
1614  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1615  newOperands.push_back(mulOp);
1616  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1617  newOperands, false);
1618  return success();
1619  }
1620 
1621  auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1622  // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1623  if (mulOp && mulOp.getInputs().size() == 2 &&
1624  mulOp.getInputs()[0] == inputs[size - 2] &&
1625  matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1626 
1627  APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1628  auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + one);
1629  std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1630  auto newMulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
1631 
1632  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1633  newOperands.push_back(newMulOp);
1634  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
1635  newOperands, false);
1636  return success();
1637  }
1638 
1639  // add(a, add(...)) -> add(a, ...) -- flatten
1640  if (tryFlatteningOperands(op, rewriter))
1641  return success();
1642 
1643  // extracts only of add(...) -> add(extract()...)
1644  if (narrowOperationWidth(op, false, rewriter))
1645  return success();
1646 
1647  // add(add(x, c1), c2) -> add(x, c1 + c2)
1648  auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1649  if (addOp && addOp.getInputs().size() == 2 &&
1650  matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1651  inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1652 
1653  auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
1654  replaceOpWithNewOpAndCopyName<AddOp>(
1655  rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1656  /*twoState=*/op.getTwoState() && addOp.getTwoState());
1657  return success();
1658  }
1659 
1660  return failure();
1661 }
1662 
1663 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1664  if (hasOperandsOutsideOfBlock(getOperation()))
1665  return {};
1666 
1667  auto size = getInputs().size();
1668  auto inputs = adaptor.getInputs();
1669 
1670  // mul(x) -> x -- noop
1671  if (size == 1u)
1672  return getInputs()[0];
1673 
1674  auto width = cast<IntegerType>(getType()).getWidth();
1675  APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1676 
1677  // mul(x, 0, 1) -> 0 -- annulment
1678  for (auto operand : inputs) {
1679  if (!operand)
1680  continue;
1681  value *= cast<IntegerAttr>(operand).getValue();
1682  if (value.isZero())
1683  return getIntAttr(value, getContext());
1684  }
1685 
1686  // Constant fold
1687  return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1688 }
1689 
1690 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1691  if (hasOperandsOutsideOfBlock(&*op))
1692  return failure();
1693 
1694  auto inputs = op.getInputs();
1695  auto size = inputs.size();
1696  assert(size > 1 && "expected 2 or more operands");
1697 
1698  APInt value, value2;
1699 
1700  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1701  if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1702  value.isPowerOf2()) {
1703  auto shift = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(),
1704  value.exactLogBase2());
1705  auto shlOp =
1706  rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift, false);
1707 
1708  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1709  ArrayRef<Value>(shlOp), false);
1710  return success();
1711  }
1712 
1713  // mul(..., 1) -> mul(...) -- identity
1714  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1715  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1716  inputs.drop_back());
1717  return success();
1718  }
1719 
1720  // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1721  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1722  matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1723  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value * value2);
1724  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1725  newOperands.push_back(cst);
1726  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1727  newOperands);
1728  return success();
1729  }
1730 
1731  // mul(a, mul(...)) -> mul(a, ...) -- flatten
1732  if (tryFlatteningOperands(op, rewriter))
1733  return success();
1734 
1735  // extracts only of mul(...) -> mul(extract()...)
1736  if (narrowOperationWidth(op, false, rewriter))
1737  return success();
1738 
1739  return failure();
1740 }
1741 
1742 template <class Op, bool isSigned>
1743 static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1744  if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1745  // divu(x, 1) -> x, divs(x, 1) -> x
1746  if (rhsValue.getValue() == 1)
1747  return op.getLhs();
1748 
1749  // If the divisor is zero, do not fold for now.
1750  if (rhsValue.getValue().isZero())
1751  return {};
1752  }
1753 
1754  return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1755 }
1756 
1757 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1758  if (hasOperandsOutsideOfBlock(getOperation()))
1759  return {};
1760 
1761  return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1762 }
1763 
1764 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1765  if (hasOperandsOutsideOfBlock(getOperation()))
1766  return {};
1767 
1768  return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1769 }
1770 
1771 template <class Op, bool isSigned>
1772 static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1773  if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1774  // modu(x, 1) -> 0, mods(x, 1) -> 0
1775  if (rhsValue.getValue() == 1)
1776  return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1777  op.getContext());
1778 
1779  // If the divisor is zero, do not fold for now.
1780  if (rhsValue.getValue().isZero())
1781  return {};
1782  }
1783 
1784  if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1785  // modu(0, x) -> 0, mods(0, x) -> 0
1786  if (lhsValue.getValue().isZero())
1787  return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1788  op.getContext());
1789  }
1790 
1791  return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1792 }
1793 
1794 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1795  if (hasOperandsOutsideOfBlock(getOperation()))
1796  return {};
1797 
1798  return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1799 }
1800 
1801 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1802  if (hasOperandsOutsideOfBlock(getOperation()))
1803  return {};
1804 
1805  return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1806 }
1807 //===----------------------------------------------------------------------===//
1808 // ConcatOp
1809 //===----------------------------------------------------------------------===//
1810 
1811 // Constant folding
1812 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1813  if (hasOperandsOutsideOfBlock(getOperation()))
1814  return {};
1815 
1816  if (getNumOperands() == 1)
1817  return getOperand(0);
1818 
1819  // If all the operands are constant, we can fold.
1820  for (auto attr : adaptor.getInputs())
1821  if (!attr || !isa<IntegerAttr>(attr))
1822  return {};
1823 
1824  // If we got here, we can constant fold.
1825  unsigned resultWidth = getType().getIntOrFloatBitWidth();
1826  APInt result(resultWidth, 0);
1827 
1828  unsigned nextInsertion = resultWidth;
1829  // Insert each chunk into the result.
1830  for (auto attr : adaptor.getInputs()) {
1831  auto chunk = cast<IntegerAttr>(attr).getValue();
1832  nextInsertion -= chunk.getBitWidth();
1833  result.insertBits(chunk, nextInsertion);
1834  }
1835 
1836  return getIntAttr(result, getContext());
1837 }
1838 
/// Canonicalize a concat with two or more operands.
///
/// Operands are scanned left to right (most-significant first). The first
/// applicable simplification is applied and the function returns:
///  - an operand that is itself a concat is spliced into this one;
///  - two adjacent constants are merged into one wider constant;
///  - two adjacent identical operands, or adjacent `x`/replicate(x, n) or
///    replicate(x, n)/replicate(x, m) pairs, become a single replicate;
///  - two adjacent ExtractOps of contiguous bit ranges of the same input become
///    one wider extract;
///  - two adjacent array reads (hw.array_get, or hw.array_slice wrapped in
///    hw.bitcast) of the same array, when hw::isOffset accepts their indices,
///    become one bitcast array slice.
/// Returns failure if no simplification applies.
LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // This function is used when we flatten neighboring operands of a
  // (variadic) concat into a new version of the concat. first/last indices
  // are inclusive. Operands [firstOpIndex, lastOpIndex] are replaced by
  // `replacements`; if only one operand remains, the concat itself is replaced
  // by that value.
  auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
                           ValueRange replacements) -> LogicalResult {
    SmallVector<Value, 4> newOperands;
    newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
    newOperands.append(replacements.begin(), replacements.end());
    newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
    if (newOperands.size() == 1)
      replaceOpAndCopyName(rewriter, op, newOperands[0]);
    else
      replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
                                              newOperands);
    return success();
  };

  // Tracks whether every operand seen so far equals inputs[0]; reset to a
  // null Value on the first mismatch.
  Value commonOperand = inputs[0];
  for (size_t i = 0; i != size; ++i) {
    // Check to see if all operands are the same.
    if (inputs[i] != commonOperand)
      commonOperand = Value();

    // If an operand to the concat is itself a concat, then we can fold them
    // together.
    if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
      return flattenConcat(i, i, subConcat->getOperands());

    // Check for canonicalization due to neighboring operands. inputs[i - 1]
    // is the more-significant neighbor of inputs[i].
    if (i != 0) {
      // Merge neighboring constants.
      if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
        if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
          // The previous (more-significant) constant lands above this one.
          unsigned prevWidth = prevCst.getValue().getBitWidth();
          unsigned thisWidth = cst.getValue().getBitWidth();
          auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
          resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
                       << thisWidth;
          Value replacement =
              rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
          return flattenConcat(i - 1, i, replacement);
        }
      }

      // If the two operands are the same, turn them into a replicate.
      if (inputs[i] == inputs[i - 1]) {
        Value replacement =
            rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
        return flattenConcat(i - 1, i, replacement);
      }

      // If this input is a replicate, see if we can fold it with the previous
      // one.
      if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
        // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
        if (repl.getOperand() == inputs[i - 1]) {
          Value replacement = rewriter.createOrFold<ReplicateOp>(
              op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
          return flattenConcat(i - 1, i, replacement);
        }
        // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
        if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
          if (prevRepl.getOperand() == repl.getOperand()) {
            Value replacement = rewriter.createOrFold<ReplicateOp>(
                op.getLoc(), repl.getOperand(),
                repl.getMultiple() + prevRepl.getMultiple());
            return flattenConcat(i - 1, i, replacement);
          }
        }
      }

      // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
      if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
        if (repl.getOperand() == inputs[i]) {
          Value replacement = rewriter.createOrFold<ReplicateOp>(
              op.getLoc(), inputs[i], repl.getMultiple() + 1);
          return flattenConcat(i - 1, i, replacement);
        }
      }

      // Merge neighboring extracts of neighboring inputs, e.g.
      // {A[3], A[2]} -> A[3:2]
      // This requires the previous extract to start exactly one bit past the
      // end of this one; the merged extract starts at this extract's low bit.
      if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
        if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
          if (extract.getInput() == prevExtract.getInput()) {
            auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
            if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
              auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
              auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
              Value replacement = rewriter.create<ExtractOp>(
                  op.getLoc(), resType, extract.getInput(),
                  extract.getLowBit());
              return flattenConcat(i - 1, i, replacement);
            }
          }
        }
      }
      // Merge neighboring array extracts of neighboring inputs, e.g.
      // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])

      // This represents a slice of an array: `input` is the array, `index` is
      // the (low) element index, and `width` is the number of elements (1 for
      // a single hw.array_get).
      struct ArraySlice {
        Value input;
        Value index;
        size_t width;
        // Recognize hw.array_get, or hw.array_slice behind an hw.bitcast;
        // returns std::nullopt for anything else.
        static std::optional<ArraySlice> get(Value value) {
          assert(isa<IntegerType>(value.getType()) && "expected integer type");
          if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
            return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
          // array slice op is wrapped with bitcast.
          if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
            if (auto arraySlice =
                    bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
              return ArraySlice{
                  arraySlice.getInput(), arraySlice.getLowIndex(),
                  hw::type_cast<hw::ArrayType>(arraySlice.getType())
                      .getNumElements()};
          return std::nullopt;
        }
      };
      if (auto extractOpt = ArraySlice::get(inputs[i])) {
        if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
          // Check that two array slices are mergable.
          // NOTE(review): presumably hw::isOffset(a, b, n) checks that index
          // `b` equals `a + n`, i.e. the previous slice starts right after
          // this one -- confirm against its definition.
          if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
              prevExtractOpt->input == extractOpt->input &&
              hw::isOffset(extractOpt->index, prevExtractOpt->index,
                           extractOpt->width)) {
            auto resType = hw::ArrayType::get(
                hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
                    .getElementType(),
                extractOpt->width + prevExtractOpt->width);
            auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
            Value replacement = rewriter.create<hw::BitcastOp>(
                op.getLoc(), resIntType,
                rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
                                                  prevExtractOpt->input,
                                                  extractOpt->index));
            return flattenConcat(i - 1, i, replacement);
          }
        }
      }
    }
  }

  // If all operands were the same, then this is a replicate.
  // NOTE(review): with all-identical operands the "two operands are the same"
  // rule above already fires at i == 1 (or an earlier rule does), so this path
  // looks unreachable -- confirm before relying on it.
  if (commonOperand) {
    replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
                                               commonOperand);
    return success();
  }

  return failure();
}
2000 
2001 //===----------------------------------------------------------------------===//
2002 // MuxOp
2003 //===----------------------------------------------------------------------===//
2004 
2005 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
2006  if (hasOperandsOutsideOfBlock(getOperation()))
2007  return {};
2008 
2009  // mux (c, b, b) -> b
2010  if (getTrueValue() == getFalseValue())
2011  return getTrueValue();
2012  if (auto tv = adaptor.getTrueValue())
2013  if (tv == adaptor.getFalseValue())
2014  return tv;
2015 
2016  // mux(0, a, b) -> b
2017  // mux(1, a, b) -> a
2018  if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
2019  if (pred.getValue().isZero())
2020  return getFalseValue();
2021  return getTrueValue();
2022  }
2023 
2024  // mux(cond, 1, 0) -> cond
2025  if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
2026  if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
2027  if (tv.getValue().isOne() && fv.getValue().isZero() &&
2028  hw::getBitWidth(getType()) == 1)
2029  return getCond();
2030 
2031  return {};
2032 }
2033 
2034 /// Check to see if the condition to the specified mux is an equality
2035 /// comparison `indexValue` and one or more constants. If so, put the
2036 /// constants in the constants vector and return true, otherwise return false.
2037 ///
2038 /// This is part of foldMuxChain.
2039 ///
2040 static bool
2041 getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
2042  std::function<void(hw::ConstantOp)> constantFn) {
2043  // Handle `idx == 42` and `idx != 42`.
2044  if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
2045  // TODO: We could handle things like "x < 2" as two entries.
2046  auto requiredPredicate =
2047  (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
2048  if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
2049  if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
2050  constantFn(cst);
2051  return true;
2052  }
2053  }
2054  return false;
2055  }
2056 
2057  // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
2058  if (auto orOp = cond.getDefiningOp<OrOp>()) {
2059  if (!isInverted)
2060  return false;
2061  for (auto operand : orOp.getOperands())
2062  if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2063  return false;
2064  return true;
2065  }
2066 
2067  // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
2068  if (auto andOp = cond.getDefiningOp<AndOp>()) {
2069  if (isInverted)
2070  return false;
2071  for (auto operand : andOp.getOperands())
2072  if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2073  return false;
2074  return true;
2075  }
2076 
2077  return false;
2078 }
2079 
2080 /// Given a mux, check to see if the "on true" value (or "on false" value if
2081 /// isFalseSide=true) is a mux tree with the same condition. This allows us
2082 /// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
2083 /// `array_get (array_create(A, B, C), VAL)` which is far more compact and
2084 /// allows synthesis tools to do more interesting optimizations.
2085 ///
2086 /// This returns false if we cannot form the mux tree (or do not want to) and
2087 /// returns true if the mux was replaced.
2088 static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
2089  PatternRewriter &rewriter) {
2090  // Get the index value being compared. Later we check to see if it is
2091  // compared to a constant with the right predicate.
2092  auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2093  if (!rootCmp)
2094  return false;
2095  Value indexValue = rootCmp.getLhs();
2096 
2097  // Return the value to use if the equality match succeeds.
2098  auto getCaseValue = [&](MuxOp mux) -> Value {
2099  return mux.getOperand(1 + unsigned(!isFalseSide));
2100  };
2101 
2102  // Return the value to use if the equality match fails. This is the next
2103  // mux in the sequence or the "otherwise" value.
2104  auto getTreeValue = [&](MuxOp mux) -> Value {
2105  return mux.getOperand(1 + unsigned(isFalseSide));
2106  };
2107 
2108  // Start scanning the mux tree to see what we've got. Keep track of the
2109  // constant comparison value and the SSA value to use when equal to it.
2110  SmallVector<Location> locationsFound;
2111  SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2112 
2113  /// Extract constants and values into `valuesFound` and return true if this is
2114  /// part of the mux tree, otherwise return false.
2115  auto collectConstantValues = [&](MuxOp mux) -> bool {
2116  return getMuxChainCondConstant(
2117  mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2118  valuesFound.push_back({cst, getCaseValue(mux)});
2119  locationsFound.push_back(mux.getCond().getLoc());
2120  locationsFound.push_back(mux->getLoc());
2121  });
2122  };
2123 
2124  // Make sure the root is a correct comparison with a constant.
2125  if (!collectConstantValues(rootMux))
2126  return false;
2127 
2128  // Make sure that we're not looking at the intermediate node in a mux tree.
2129  if (rootMux->hasOneUse()) {
2130  if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2131  if (getTreeValue(userMux) == rootMux.getResult() &&
2132  getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2133  [&](hw::ConstantOp cst) {}))
2134  return false;
2135  }
2136  }
2137 
2138  // Scan up the tree linearly.
2139  auto nextTreeValue = getTreeValue(rootMux);
2140  while (1) {
2141  auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2142  if (!nextMux || !nextMux->hasOneUse())
2143  break;
2144  if (!collectConstantValues(nextMux))
2145  break;
2146  nextTreeValue = getTreeValue(nextMux);
2147  }
2148 
2149  // We need to have more than three values to create an array. This is an
2150  // arbitrary threshold which is saying that one or two muxes together is ok,
2151  // but three should be folded.
2152  if (valuesFound.size() < 3)
2153  return false;
2154 
2155  // If the array is greater that 9 bits, it will take over 512 elements and
2156  // it will be too large for a single expression.
2157  auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2158  if (indexWidth >= 9)
2159  return false;
2160 
2161  // Next we need to see if the values are dense-ish. We don't want to have
2162  // a tremendous number of replicated entries in the array. Some sparsity is
2163  // ok though, so we require the table to be at least 5/8 utilized.
2164  uint64_t tableSize = 1ULL << indexWidth;
2165  if (valuesFound.size() < (tableSize * 5) / 8)
2166  return false; // Not dense enough.
2167 
2168  // Ok, we're going to do the transformation, start by building the table
2169  // filled with the "otherwise" value.
2170  SmallVector<Value, 8> table(tableSize, nextTreeValue);
2171 
2172  // Fill in entries in the table from the leaf to the root of the expression.
2173  // This ensures that any duplicate matches end up with the ultimate value,
2174  // which is the one closer to the root.
2175  for (auto &elt : llvm::reverse(valuesFound)) {
2176  uint64_t idx = elt.first.getValue().getZExtValue();
2177  assert(idx < table.size() && "constant should be same bitwidth as index");
2178  table[idx] = elt.second;
2179  }
2180 
2181  // The hw.array_create operation has the operand list in unintuitive order
2182  // with a[0] stored as the last element, not the first.
2183  std::reverse(table.begin(), table.end());
2184 
2185  // Build the array_create and the array_get.
2186  auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2187  auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
2188  replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2189  indexValue);
2190  return true;
2191 }
2192 
2193 /// Given a fully associative variadic operation like (a+b+c+d), break the
2194 /// expression into two parts, one without the specified operand (e.g.
2195 /// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2196 /// `tmp+c`), and return the inner expression.
2197 ///
2198 /// NOTE: This mutates the operation in place if it only has a single user,
2199 /// which assumes that user will be removed.
2200 ///
2201 static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2202  size_t operandNo,
2203  PatternRewriter &rewriter) {
2204  assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2205  assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2206 
2207  // If this expression already has two operands (the common case) no splitting
2208  // is necessary.
2209  if (fullyAssoc->getNumOperands() == 2)
2210  return fullyAssoc->getOperand(operandNo ^ 1);
2211 
2212  // If the operation has a single use, mutate it in place.
2213  if (fullyAssoc->hasOneUse()) {
2214  fullyAssoc->eraseOperand(operandNo);
2215  return fullyAssoc->getResult(0);
2216  }
2217 
2218  // Form the new operation with the operands that remain.
2219  SmallVector<Value> operands;
2220  operands.append(fullyAssoc->getOperands().begin(),
2221  fullyAssoc->getOperands().begin() + operandNo);
2222  operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2223  fullyAssoc->getOperands().end());
2224  Value opWithoutExcluded = createGenericOp(
2225  fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2226  Value excluded = fullyAssoc->getOperand(operandNo);
2227 
2228  Value fullResult =
2229  createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2230  ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2231  replaceOpAndCopyName(rewriter, fullyAssoc, fullResult);
2232  return opWithoutExcluded;
2233 }
2234 
2235 /// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2236 /// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2237 /// is true. Return true on successful transformation, false if not.
2238 ///
2239 /// These are various forms of "predicated ops" that can be handled with a
2240 /// replicate/and combination.
2241 static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2242  PatternRewriter &rewriter) {
2243  // Check to see the operand in question is an operation. If it is a port,
2244  // we can't simplify it.
2245  Operation *subExpr =
2246  (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2247  if (!subExpr || subExpr->getNumOperands() < 2)
2248  return false;
2249 
2250  // If this isn't an operation we can handle, don't spend energy on it.
2251  if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2252  return false;
2253 
2254  // Check to see if the common value occurs in the operand list for the
2255  // subexpression op. If so, then we can simplify it.
2256  Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2257  size_t opNo = 0, e = subExpr->getNumOperands();
2258  while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2259  ++opNo;
2260  if (opNo == e)
2261  return false;
2262 
2263  // If we got a hit, then go ahead and simplify it!
2264  Value cond = op.getCond();
2265 
2266  // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2267  // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2268  // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2269  // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2270  if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2271  Value otherValue;
2272  Value subCond = subMux.getCond();
2273 
2274  // Invert th subCond if needed and dig out the 'b' value.
2275  if (subMux.getTrueValue() == commonValue)
2276  otherValue = subMux.getFalseValue();
2277  else if (subMux.getFalseValue() == commonValue) {
2278  otherValue = subMux.getTrueValue();
2279  subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2280  } else {
2281  // We can't fold `mux(cond, a, mux(a, x, y))`.
2282  return false;
2283  }
2284 
2285  // Invert the outer cond if needed, and combine the mux conditions.
2286  if (!isTrueOperand)
2287  cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2288  cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2289  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2290  otherValue, op.getTwoState());
2291  return true;
2292  }
2293 
2294  // Invert the condition if needed. Or/Xor invert when dealing with
2295  // TrueOperand, And inverts for False operand.
2296  bool isaAndOp = isa<AndOp>(subExpr);
2297  if (isTrueOperand ^ isaAndOp)
2298  cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2299 
2300  auto extendedCond =
2301  rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2302 
2303  // Cache this information before subExpr is erased by extraction below.
2304  bool isaXorOp = isa<XorOp>(subExpr);
2305  bool isaOrOp = isa<OrOp>(subExpr);
2306 
2307  // Handle the fully associative ops, start by pulling out the subexpression
2308  // from a many operand version of the op.
2309  auto restOfAssoc =
2310  extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2311 
2312  // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2313  // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2314  if (isaOrOp || isaXorOp) {
2315  auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2316  restOfAssoc, false);
2317  if (isaXorOp)
2318  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2319  false);
2320  else
2321  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2322  false);
2323  return true;
2324  }
2325 
2326  // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2327  assert(isaAndOp && "unexpected operation here");
2328  auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2329  restOfAssoc, false);
2330  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2331  false);
2332  return true;
2333 }
2334 
2335 /// This function is invoke when we find a mux with true/false operations that
2336 /// have the same opcode. Check to see if we can strength reduce the mux by
2337 /// applying it to less data by applying this transformation:
2338 /// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2339 static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2340  Operation *falseOp,
2341  PatternRewriter &rewriter) {
2342  // Right now we only apply to concat.
2343  // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2344  if (!isa<ConcatOp>(trueOp))
2345  return false;
2346 
2347  // Decode the operands, looking through recursive concats and replicates.
2348  SmallVector<Value> trueOperands, falseOperands;
2349  getConcatOperands(trueOp->getResult(0), trueOperands);
2350  getConcatOperands(falseOp->getResult(0), falseOperands);
2351 
2352  size_t numTrueOperands = trueOperands.size();
2353  size_t numFalseOperands = falseOperands.size();
2354 
2355  if (!numTrueOperands || !numFalseOperands ||
2356  (trueOperands.front() != falseOperands.front() &&
2357  trueOperands.back() != falseOperands.back()))
2358  return false;
2359 
2360  // Pull all leading shared operands out into their own op if any are common.
2361  if (trueOperands.front() == falseOperands.front()) {
2362  SmallVector<Value> operands;
2363  size_t i;
2364  for (i = 0; i < numTrueOperands; ++i) {
2365  Value trueOperand = trueOperands[i];
2366  if (trueOperand == falseOperands[i])
2367  operands.push_back(trueOperand);
2368  else
2369  break;
2370  }
2371  if (i == numTrueOperands) {
2372  // Selecting between distinct, but lexically identical, concats.
2373  replaceOpAndCopyName(rewriter, mux, trueOp->getResult(0));
2374  return true;
2375  }
2376 
2377  Value sharedMSB;
2378  if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2379  sharedMSB = rewriter.createOrFold<ReplicateOp>(
2380  mux->getLoc(), operands.front(), operands.size());
2381  else
2382  sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2383  operands.clear();
2384 
2385  // Get a concat of the LSB's on each side.
2386  operands.append(trueOperands.begin() + i, trueOperands.end());
2387  Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2388  operands.clear();
2389  operands.append(falseOperands.begin() + i, falseOperands.end());
2390  Value falseLSB =
2391  rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2392  // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2393  // done.
2394  Value lsb = rewriter.createOrFold<MuxOp>(
2395  mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2396  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2397  return true;
2398  }
2399 
2400  // If trailing operands match, try to commonize them.
2401  if (trueOperands.back() == falseOperands.back()) {
2402  SmallVector<Value> operands;
2403  size_t i;
2404  for (i = 0;; ++i) {
2405  Value trueOperand = trueOperands[numTrueOperands - i - 1];
2406  if (trueOperand == falseOperands[numFalseOperands - i - 1])
2407  operands.push_back(trueOperand);
2408  else
2409  break;
2410  }
2411  std::reverse(operands.begin(), operands.end());
2412  Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2413  operands.clear();
2414 
2415  // Get a concat of the MSB's on each side.
2416  operands.append(trueOperands.begin(), trueOperands.end() - i);
2417  Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2418  operands.clear();
2419  operands.append(falseOperands.begin(), falseOperands.end() - i);
2420  Value falseMSB =
2421  rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2422  // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2423  Value msb = rewriter.createOrFold<MuxOp>(
2424  mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2425  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2426  return true;
2427  }
2428 
2429  return false;
2430 }
2431 
2432 // If both arguments of the mux are arrays with the same elements, sink the
2433 // mux and return a uniform array initializing all elements to it.
2434 static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2435  auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2436  auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2437  if (!trueVec || !falseVec)
2438  return false;
2439  if (!trueVec.isUniform() || !falseVec.isUniform())
2440  return false;
2441 
2442  auto mux = rewriter.create<MuxOp>(
2443  op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2444  falseVec.getUniformElement(), op.getTwoState());
2445 
2446  SmallVector<Value> values(trueVec.getInputs().size(), mux);
2447  rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2448  return true;
2449 }
2450 
2451 namespace {
2452 struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2453  using OpRewritePattern::OpRewritePattern;
2454 
2455  LogicalResult matchAndRewrite(MuxOp op,
2456  PatternRewriter &rewriter) const override;
2457 };
2458 
2459 LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2460  PatternRewriter &rewriter) const {
2461  if (hasOperandsOutsideOfBlock(&*op))
2462  return failure();
2463 
2464  // If the op has a SV attribute, don't optimize it.
2465  if (hasSVAttributes(op))
2466  return failure();
2467  APInt value;
2468 
2469  if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2470  if (value.getBitWidth() == 1) {
2471  // mux(a, 0, b) -> and(~a, b) for single-bit values.
2472  if (value.isZero()) {
2473  auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2474  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2475  op.getFalseValue(), false);
2476  return success();
2477  }
2478 
2479  // mux(a, 1, b) -> or(a, b) for single-bit values.
2480  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2481  op.getFalseValue(), false);
2482  return success();
2483  }
2484 
2485  // Check for mux of two constants. There are many ways to simplify them.
2486  APInt value2;
2487  if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2488  // When both inputs are constants and differ by only one bit, we can
2489  // simplify by splitting the mux into up to three contiguous chunks: one
2490  // for the differing bit and up to two for the bits that are the same.
2491  // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2492  APInt xorValue = value ^ value2;
2493  if (xorValue.isPowerOf2()) {
2494  unsigned leadingZeros = xorValue.countLeadingZeros();
2495  unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2496  SmallVector<Value, 3> operands;
2497 
2498  // Concat operands go from MSB to LSB, so we handle chunks in reverse
2499  // order of bit indexes.
2500  // For the chunks that are identical (i.e. correspond to 0s in
2501  // xorValue), we can extract directly from either input value, and we
2502  // arbitrarily pick the trueValue().
2503 
2504  if (leadingZeros > 0)
2505  operands.push_back(rewriter.createOrFold<ExtractOp>(
2506  op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2507 
2508  // Handle the differing bit, which should simplify into either cond or
2509  // ~cond.
2510  auto v1 = rewriter.createOrFold<ExtractOp>(
2511  op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2512  auto v2 = rewriter.createOrFold<ExtractOp>(
2513  op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2514  operands.push_back(rewriter.createOrFold<MuxOp>(
2515  op.getLoc(), op.getCond(), v1, v2, false));
2516 
2517  if (trailingZeros > 0)
2518  operands.push_back(rewriter.createOrFold<ExtractOp>(
2519  op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2520 
2521  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2522  operands);
2523  return success();
2524  }
2525 
2526  // If the true value is all ones and the false is all zeros then we have a
2527  // replicate pattern.
2528  if (value.isAllOnes() && value2.isZero()) {
2529  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2530  op.getCond());
2531  return success();
2532  }
2533  }
2534  }
2535 
2536  if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2537  value.getBitWidth() == 1) {
2538  // mux(a, b, 0) -> and(a, b) for single-bit values.
2539  if (value.isZero()) {
2540  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2541  op.getTrueValue(), false);
2542  return success();
2543  }
2544 
2545  // mux(a, b, 1) -> or(~a, b) for single-bit values.
2546  // falseValue() is known to be a single-bit 1, which we can use for
2547  // the 1 in the representation of ~ using xor.
2548  auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2549  op.getFalseValue(), false);
2550  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2551  op.getTrueValue(), false);
2552  return success();
2553  }
2554 
2555  // mux(!a, b, c) -> mux(a, c, b)
2556  Value subExpr;
2557  Operation *condOp = op.getCond().getDefiningOp();
2558  if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2559  op.getTwoState()) {
2560  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2561  op.getFalseValue(), op.getTrueValue(),
2562  true);
2563  return success();
2564  }
2565 
2566  // Same but with Demorgan's law.
2567  // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2568  // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2569  if (condOp && condOp->hasOneUse()) {
2570  SmallVector<Value> invertedOperands;
2571 
2572  /// Scan all the operands to see if they are complemented. If so, build a
2573  /// vector of them and return true, otherwise return false.
2574  auto getInvertedOperands = [&]() -> bool {
2575  for (Value operand : condOp->getOperands()) {
2576  if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2577  invertedOperands.push_back(subExpr);
2578  else
2579  return false;
2580  }
2581  return true;
2582  };
2583 
2584  if (isa<AndOp>(condOp) && getInvertedOperands()) {
2585  auto newOr =
2586  rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2587  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2588  op.getFalseValue(),
2589  op.getTrueValue(), op.getTwoState());
2590  return success();
2591  }
2592  if (isa<OrOp>(condOp) && getInvertedOperands()) {
2593  auto newAnd =
2594  rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2595  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2596  op.getFalseValue(),
2597  op.getTrueValue(), op.getTwoState());
2598  return success();
2599  }
2600  }
2601 
2602  if (auto falseMux =
2603  dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2604  // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2605  if (op.getCond() == falseMux.getCond()) {
2606  replaceOpWithNewOpAndCopyName<MuxOp>(
2607  rewriter, op, op.getCond(), op.getTrueValue(),
2608  falseMux.getFalseValue(), op.getTwoStateAttr());
2609  return success();
2610  }
2611 
2612  // Check to see if we can fold a mux tree into an array_create/get pair.
2613  if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2614  return success();
2615  }
2616 
2617  if (auto trueMux =
2618  dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2619  // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2620  if (op.getCond() == trueMux.getCond()) {
2621  replaceOpWithNewOpAndCopyName<MuxOp>(
2622  rewriter, op, op.getCond(), trueMux.getTrueValue(),
2623  op.getFalseValue(), op.getTwoStateAttr());
2624  return success();
2625  }
2626 
2627  // Check to see if we can fold a mux tree into an array_create/get pair.
2628  if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2629  return success();
2630  }
2631 
2632  // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2633  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2634  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2635  trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2636  trueMux.getTrueValue() == falseMux.getTrueValue()) {
2637  auto subMux = rewriter.create<MuxOp>(
2638  rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2639  op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2640  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2641  trueMux.getTrueValue(), subMux,
2642  op.getTwoStateAttr());
2643  return success();
2644  }
2645 
2646  // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2647  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2648  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2649  trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2650  trueMux.getFalseValue() == falseMux.getFalseValue()) {
2651  auto subMux = rewriter.create<MuxOp>(
2652  rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2653  op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2654  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2655  subMux, trueMux.getFalseValue(),
2656  op.getTwoStateAttr());
2657  return success();
2658  }
2659 
2660  // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2661  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2662  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2663  trueMux && falseMux &&
2664  trueMux.getTrueValue() == falseMux.getTrueValue() &&
2665  trueMux.getFalseValue() == falseMux.getFalseValue()) {
2666  auto subMux = rewriter.create<MuxOp>(
2667  rewriter.getFusedLoc(
2668  {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2669  op.getCond(), trueMux.getCond(), falseMux.getCond());
2670  replaceOpWithNewOpAndCopyName<MuxOp>(
2671  rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2672  op.getTwoStateAttr());
2673  return success();
2674  }
2675 
2676  // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2677  if (foldCommonMuxValue(op, false, rewriter))
2678  return success();
2679  // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2680  if (foldCommonMuxValue(op, true, rewriter))
2681  return success();
2682 
2683  // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2684  if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2685  if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2686  if (trueOp->getName() == falseOp->getName())
2687  if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2688  return success();
2689 
2690  // extracts only of mux(...) -> mux(extract()...)
2691  if (narrowOperationWidth(op, true, rewriter))
2692  return success();
2693 
2694  // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2695  if (foldMuxOfUniformArrays(op, rewriter))
2696  return success();
2697 
2698  return failure();
2699 }
2700 
2701 static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2702  // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2703  if (op.getInputs().empty() || op.isUniform())
2704  return false;
2705  auto inputs = op.getInputs();
2706  if (inputs.size() <= 1)
2707  return false;
2708 
2709  // Check the operands to the array create. Ensure all of them are the
2710  // same op with the same number of operands.
2711  auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2712  if (!first || hasSVAttributes(first))
2713  return false;
2714 
2715  // Check whether all operands are muxes with the same condition.
2716  for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2717  auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2718  if (!input || first.getCond() != input.getCond())
2719  return false;
2720  }
2721 
2722  // Collect the true and the false branches into arrays.
2723  SmallVector<Value> trues{first.getTrueValue()};
2724  SmallVector<Value> falses{first.getFalseValue()};
2725  SmallVector<Location> locs{first->getLoc()};
2726  bool isTwoState = true;
2727  for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2728  auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2729  trues.push_back(input.getTrueValue());
2730  falses.push_back(input.getFalseValue());
2731  locs.push_back(input->getLoc());
2732  if (!input.getTwoState())
2733  isTwoState = false;
2734  }
2735 
2736  // Define the location of the array create as the aggregate of all muxes.
2737  auto loc = FusedLoc::get(op.getContext(), locs);
2738 
2739  // Replace the create with an aggregate operation. Push the create op
2740  // into the operands of the aggregate operation.
2741  auto arrayTy = op.getType();
2742  auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
2743  auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
2744  rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2745  trueValues, falseValues, isTwoState);
2746  return true;
2747 }
2748 
2749 struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2750  using OpRewritePattern::OpRewritePattern;
2751 
2752  LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2753  PatternRewriter &rewriter) const override {
2754  if (hasOperandsOutsideOfBlock(&*op))
2755  return failure();
2756 
2757  if (foldArrayOfMuxes(op, rewriter))
2758  return success();
2759  return failure();
2760  }
2761 };
2762 
2763 } // namespace
2764 
2765 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2766  MLIRContext *context) {
2767  results.insert<MuxRewriter, ArrayRewriter>(context);
2768 }
2769 
2770 //===----------------------------------------------------------------------===//
2771 // ICmpOp
2772 //===----------------------------------------------------------------------===//
2773 
2774 // Calculate the result of a comparison when the LHS and RHS are both
2775 // constants.
2776 static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2777  const APInt &rhs) {
2778  switch (predicate) {
2779  case ICmpPredicate::eq:
2780  return lhs.eq(rhs);
2781  case ICmpPredicate::ne:
2782  return lhs.ne(rhs);
2783  case ICmpPredicate::slt:
2784  return lhs.slt(rhs);
2785  case ICmpPredicate::sle:
2786  return lhs.sle(rhs);
2787  case ICmpPredicate::sgt:
2788  return lhs.sgt(rhs);
2789  case ICmpPredicate::sge:
2790  return lhs.sge(rhs);
2791  case ICmpPredicate::ult:
2792  return lhs.ult(rhs);
2793  case ICmpPredicate::ule:
2794  return lhs.ule(rhs);
2795  case ICmpPredicate::ugt:
2796  return lhs.ugt(rhs);
2797  case ICmpPredicate::uge:
2798  return lhs.uge(rhs);
2799  case ICmpPredicate::ceq:
2800  return lhs.eq(rhs);
2801  case ICmpPredicate::cne:
2802  return lhs.ne(rhs);
2803  case ICmpPredicate::weq:
2804  return lhs.eq(rhs);
2805  case ICmpPredicate::wne:
2806  return lhs.ne(rhs);
2807  }
2808  llvm_unreachable("unknown comparison predicate");
2809 }
2810 
2811 // Returns the result of applying the predicate when the LHS and RHS are the
2812 // exact same value.
2813 static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2814  switch (predicate) {
2815  case ICmpPredicate::eq:
2816  case ICmpPredicate::sle:
2817  case ICmpPredicate::sge:
2818  case ICmpPredicate::ule:
2819  case ICmpPredicate::uge:
2820  case ICmpPredicate::ceq:
2821  case ICmpPredicate::weq:
2822  return true;
2823  case ICmpPredicate::ne:
2824  case ICmpPredicate::slt:
2825  case ICmpPredicate::sgt:
2826  case ICmpPredicate::ult:
2827  case ICmpPredicate::ugt:
2828  case ICmpPredicate::cne:
2829  case ICmpPredicate::wne:
2830  return false;
2831  }
2832  llvm_unreachable("unknown comparison predicate");
2833 }
2834 
2835 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2836  if (hasOperandsOutsideOfBlock(getOperation()))
2837  return {};
2838 
2839  // gt a, a -> false
2840  // gte a, a -> true
2841  if (getLhs() == getRhs()) {
2842  auto val = applyCmpPredicateToEqualOperands(getPredicate());
2843  return IntegerAttr::get(getType(), val);
2844  }
2845 
2846  // gt 1, 2 -> false
2847  if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2848  if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2849  auto val =
2850  applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2851  return IntegerAttr::get(getType(), val);
2852  }
2853  }
2854  return {};
2855 }
2856 
2857 // Given a range of operands, computes the number of matching prefix and
2858 // suffix elements. This does not perform cross-element matching.
2859 template <typename Range>
2860 static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
2861  size_t commonPrefixLength = 0;
2862  auto ia = a.begin();
2863  auto ib = b.begin();
2864 
2865  for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2866  if (*ia != *ib) {
2867  break;
2868  }
2869  }
2870 
2871  return commonPrefixLength;
2872 }
2873 
2874 static size_t getTotalWidth(ArrayRef<Value> operands) {
2875  size_t totalWidth = 0;
2876  for (auto operand : operands) {
2877  // getIntOrFloatBitWidth should never raise, since all arguments to
2878  // ConcatOp are integers.
2879  ssize_t width = operand.getType().getIntOrFloatBitWidth();
2880  assert(width >= 0);
2881  totalWidth += width;
2882  }
2883  return totalWidth;
2884 }
2885 
2886 /// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
2887 /// comparison on common prefix and suffixes. Returns success() if a rewriting
2888 /// happens. This handles both concat and replicate.
2889 static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2890  Operation *rhs,
2891  PatternRewriter &rewriter) {
2892  // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2893  // all elements have non-zero length. Both these invariants are verified
2894  // by the ConcatOp verifier.
2895  SmallVector<Value> lhsOperands, rhsOperands;
2896  getConcatOperands(lhs->getResult(0), lhsOperands);
2897  getConcatOperands(rhs->getResult(0), rhsOperands);
2898  ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2899 
2900  auto formCatOrReplicate = [&](Location loc,
2901  ArrayRef<Value> operands) -> Value {
2902  assert(!operands.empty());
2903  Value sameElement = operands[0];
2904  for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2905  if (sameElement != operands[i])
2906  sameElement = Value();
2907  if (sameElement)
2908  return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2909  operands.size());
2910  return rewriter.createOrFold<ConcatOp>(loc, operands);
2911  };
2912 
2913  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2914  Value rhs) -> LogicalResult {
2915  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2916  op.getTwoState());
2917  return success();
2918  };
2919 
2920  size_t commonPrefixLength =
2921  computeCommonPrefixLength(lhsOperands, rhsOperands);
2922  if (commonPrefixLength == lhsOperands.size()) {
2923  // cat(a, b, c) == cat(a, b, c) -> 1
2924  bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2925  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2926  APInt(1, result));
2927  return success();
2928  }
2929 
2930  size_t commonSuffixLength = computeCommonPrefixLength(
2931  llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2932 
2933  size_t commonPrefixTotalWidth =
2934  getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2935  size_t commonSuffixTotalWidth =
2936  getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2937  auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2938  .drop_back(commonSuffixLength);
2939  auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2940  .drop_back(commonSuffixLength);
2941 
2942  auto replaceWithoutReplicatingSignBit = [&]() {
2943  auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2944  auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2945  return replaceWith(op.getPredicate(), newLhs, newRhs);
2946  };
2947 
2948  auto replaceWithReplicatingSignBit = [&]() {
2949  auto firstNonEmptyValue = lhsOperands[0];
2950  auto firstNonEmptyElemWidth =
2951  firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2952  Value signBit = rewriter.createOrFold<ExtractOp>(
2953  op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2954 
2955  auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2956  auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2957  return replaceWith(op.getPredicate(), newLhs, newRhs);
2958  };
2959 
2960  if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2961  // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2962  if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2963  return replaceWithoutReplicatingSignBit();
2964 
2965  // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2966  // cat(sgn(b), ..y)) Note that we cannot perform this optimization if
2967  // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
2968  // bit. Doing the rewrite can result in an infinite loop.
2969  if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2970  return replaceWithReplicatingSignBit();
2971 
2972  } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2973  // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
2974  return replaceWithoutReplicatingSignBit();
2975  }
2976 
2977  return failure();
2978 }
2979 
2980 /// Given an equality comparison with a constant value and some operand that has
2981 /// known bits, simplify the comparison to check only the unknown bits of the
2982 /// input.
2983 ///
2984 /// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2985 /// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2986 /// `extract x[1:0] == 0`
2988  ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2989  PatternRewriter &rewriter) {
2990 
2991  // If any of the known bits disagree with any of the comparison bits, then
2992  // we can constant fold this comparison right away.
2993  APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2994  if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2995  // If we discover a mismatch then we know an "eq" comparison is false
2996  // and a "ne" comparison is true!
2997  bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2998  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2999  APInt(1, result));
3000  return;
3001  }
3002 
3003  // Check to see if we can prove the result entirely of the comparison (in
3004  // which we bail out early), otherwise build a list of values to concat and a
3005  // smaller constant to compare against.
3006  SmallVector<Value> newConcatOperands;
3007  auto newConstant = APInt::getZeroWidth();
3008 
3009  // Ok, some (maybe all) bits are known and some others may be unknown.
3010  // Extract out segments of the operand and compare against the
3011  // corresponding bits.
3012  unsigned knownMSB = bitsKnown.countLeadingOnes();
3013 
3014  Value operand = cmpOp.getLhs();
3015 
3016  // Ok, some bits are known but others are not. Extract out sequences of
3017  // bits that are unknown and compare just those bits. We work from MSB to
3018  // LSB.
3019  while (knownMSB != bitsKnown.getBitWidth()) {
3020  // Drop any high bits that are known.
3021  if (knownMSB)
3022  bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3023 
3024  // Find the span of unknown bits, and extract it.
3025  unsigned unknownBits = bitsKnown.countLeadingZeros();
3026  unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3027  auto spanOperand = rewriter.createOrFold<ExtractOp>(
3028  operand.getLoc(), operand, /*lowBit=*/lowBit,
3029  /*bitWidth=*/unknownBits);
3030  auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3031 
3032  // Add this info to the concat we're generating.
3033  newConcatOperands.push_back(spanOperand);
3034  // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
3035  // newConstant = newConstant.concat(spanConstant);
3036  if (newConstant.getBitWidth() != 0)
3037  newConstant = newConstant.concat(spanConstant);
3038  else
3039  newConstant = spanConstant;
3040 
3041  // Drop the unknown bits in prep for the next chunk.
3042  unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3043  bitsKnown = bitsKnown.trunc(newWidth);
3044  knownMSB = bitsKnown.countLeadingOnes();
3045  }
3046 
3047  // If all the operands to the concat are foldable then we have an identity
3048  // situation where all the sub-elements equal each other. This implies that
3049  // the overall result is foldable.
3050  if (newConcatOperands.empty()) {
3051  bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3052  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
3053  APInt(1, result));
3054  return;
3055  }
3056 
3057  // If we have a single operand remaining, use it, otherwise form a concat.
3058  Value concatResult =
3059  rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
3060 
3061  // Form the comparison against the smaller constant.
3062  auto newConstantOp = rewriter.create<hw::ConstantOp>(
3063  cmpOp.getOperand(1).getLoc(), newConstant);
3064 
3065  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3066  concatResult, newConstantOp,
3067  cmpOp.getTwoState());
3068 }
3069 
3070 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
3071 static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
3072  const APInt &rhs,
3073  PatternRewriter &rewriter) {
3074  auto ip = rewriter.saveInsertionPoint();
3075  rewriter.setInsertionPoint(xorOp);
3076 
3077  auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
3078  auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
3079  xorRHS.getValue() ^ rhs);
3080  Value newLHS;
3081  switch (xorOp.getNumOperands()) {
3082  case 1:
3083  // This isn't common but is defined so we need to handle it.
3084  newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
3085  APInt::getZero(rhs.getBitWidth()));
3086  break;
3087  case 2:
3088  // The binary case is the most common.
3089  newLHS = xorOp.getOperand(0);
3090  break;
3091  default:
3092  // The general case forces us to form a new xor with the remaining operands.
3093  SmallVector<Value> newOperands(xorOp.getOperands());
3094  newOperands.pop_back();
3095  newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
3096  break;
3097  }
3098 
3099  bool xorMultipleUses = !xorOp->hasOneUse();
3100 
3101  // If the xor has multiple uses (not just the compare, then we need/want to
3102  // replace them as well.
3103  if (xorMultipleUses)
3104  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3105  false);
3106 
3107  // Replace the comparison.
3108  rewriter.restoreInsertionPoint(ip);
3109  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3110  newLHS, newRHS, false);
3111 }
3112 
3113 LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3114  if (hasOperandsOutsideOfBlock(&*op))
3115  return failure();
3116 
3117  APInt lhs, rhs;
3118 
3119  // icmp 1, x -> icmp x, 1
3120  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3121  assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3122  "Should be folded");
3123  replaceOpWithNewOpAndCopyName<ICmpOp>(
3124  rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3125  op.getRhs(), op.getLhs(), op.getTwoState());
3126  return success();
3127  }
3128 
3129  // Canonicalize with RHS constant
3130  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3131  auto getConstant = [&](APInt constant) -> Value {
3132  return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(constant));
3133  };
3134 
3135  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3136  Value rhs) -> LogicalResult {
3137  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3138  op.getTwoState());
3139  return success();
3140  };
3141 
3142  auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3143  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3144  APInt(1, constant));
3145  return success();
3146  };
3147 
3148  switch (op.getPredicate()) {
3149  case ICmpPredicate::slt:
3150  // x < max -> x != max
3151  if (rhs.isMaxSignedValue())
3152  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3153  // x < min -> false
3154  if (rhs.isMinSignedValue())
3155  return replaceWithConstantI1(0);
3156  // x < min+1 -> x == min
3157  if ((rhs - 1).isMinSignedValue())
3158  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3159  getConstant(rhs - 1));
3160  break;
3161  case ICmpPredicate::sgt:
3162  // x > min -> x != min
3163  if (rhs.isMinSignedValue())
3164  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3165  // x > max -> false
3166  if (rhs.isMaxSignedValue())
3167  return replaceWithConstantI1(0);
3168  // x > max-1 -> x == max
3169  if ((rhs + 1).isMaxSignedValue())
3170  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3171  getConstant(rhs + 1));
3172  break;
3173  case ICmpPredicate::ult:
3174  // x < max -> x != max
3175  if (rhs.isAllOnes())
3176  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3177  // x < min -> false
3178  if (rhs.isZero())
3179  return replaceWithConstantI1(0);
3180  // x < min+1 -> x == min
3181  if ((rhs - 1).isZero())
3182  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3183  getConstant(rhs - 1));
3184 
3185  // x < 0xE0 -> extract(x, 5..7) != 0b111
3186  if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3187  rhs.getBitWidth()) {
3188  auto numOnes = rhs.countLeadingOnes();
3189  auto smaller = rewriter.create<ExtractOp>(
3190  op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3191  return replaceWith(ICmpPredicate::ne, smaller,
3192  getConstant(APInt::getAllOnes(numOnes)));
3193  }
3194 
3195  break;
3196  case ICmpPredicate::ugt:
3197  // x > min -> x != min
3198  if (rhs.isZero())
3199  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3200  // x > max -> false
3201  if (rhs.isAllOnes())
3202  return replaceWithConstantI1(0);
3203  // x > max-1 -> x == max
3204  if ((rhs + 1).isAllOnes())
3205  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3206  getConstant(rhs + 1));
3207 
3208  // x > 0x07 -> extract(x, 3..7) != 0b00000
3209  if ((rhs + 1).isPowerOf2()) {
3210  auto numOnes = rhs.countTrailingOnes();
3211  auto newWidth = rhs.getBitWidth() - numOnes;
3212  auto smaller = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
3213  numOnes, newWidth);
3214  return replaceWith(ICmpPredicate::ne, smaller,
3215  getConstant(APInt::getZero(newWidth)));
3216  }
3217 
3218  break;
3219  case ICmpPredicate::sle:
3220  // x <= max -> true
3221  if (rhs.isMaxSignedValue())
3222  return replaceWithConstantI1(1);
3223  // x <= c -> x < (c+1)
3224  return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3225  case ICmpPredicate::sge:
3226  // x >= min -> true
3227  if (rhs.isMinSignedValue())
3228  return replaceWithConstantI1(1);
3229  // x >= c -> x > (c-1)
3230  return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3231  case ICmpPredicate::ule:
3232  // x <= max -> true
3233  if (rhs.isAllOnes())
3234  return replaceWithConstantI1(1);
3235  // x <= c -> x < (c+1)
3236  return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3237  case ICmpPredicate::uge:
3238  // x >= min -> true
3239  if (rhs.isZero())
3240  return replaceWithConstantI1(1);
3241  // x >= c -> x > (c-1)
3242  return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3243  case ICmpPredicate::eq:
3244  if (rhs.getBitWidth() == 1) {
3245  if (rhs.isZero()) {
3246  // x == 0 -> x ^ 1
3247  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3248  getConstant(APInt(1, 1)),
3249  op.getTwoState());
3250  return success();
3251  }
3252  if (rhs.isAllOnes()) {
3253  // x == 1 -> x
3254  replaceOpAndCopyName(rewriter, op, op.getLhs());
3255  return success();
3256  }
3257  }
3258  break;
3259  case ICmpPredicate::ne:
3260  if (rhs.getBitWidth() == 1) {
3261  if (rhs.isZero()) {
3262  // x != 0 -> x
3263  replaceOpAndCopyName(rewriter, op, op.getLhs());
3264  return success();
3265  }
3266  if (rhs.isAllOnes()) {
3267  // x != 1 -> x ^ 1
3268  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3269  getConstant(APInt(1, 1)),
3270  op.getTwoState());
3271  return success();
3272  }
3273  }
3274  break;
3275  case ICmpPredicate::ceq:
3276  case ICmpPredicate::cne:
3277  case ICmpPredicate::weq:
3278  case ICmpPredicate::wne:
3279  break;
3280  }
3281 
3282  // We have some specific optimizations for comparison with a constant that
3283  // are only supported for equality comparisons.
3284  if (op.getPredicate() == ICmpPredicate::eq ||
3285  op.getPredicate() == ICmpPredicate::ne) {
3286  // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3287  // with a smaller constant. We only support equality comparisons for
3288  // this.
3289  auto knownBits = computeKnownBits(op.getLhs());
3290  if (!knownBits.isUnknown())
3291  return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3292  rewriter),
3293  success();
3294 
3295  // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3296  // cst1^cst2).
3297  if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3298  if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3299  return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3300  success();
3301 
3302  // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3303  // all one.
3304  if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3305  if (rhs.isAllOnes() || rhs.isZero()) {
3306  auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3307  auto cst = rewriter.create<hw::ConstantOp>(
3308  op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3309  : APInt::getZero(width));
3310  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3311  replicateOp.getInput(), cst,
3312  op.getTwoState());
3313  return success();
3314  }
3315  }
3316  }
3317 
3318  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3319  // b), cat(c, d)). contains special handling for sign bit in signed
3320  // compressions.
3321  if (Operation *opLHS = op.getLhs().getDefiningOp())
3322  if (Operation *opRHS = op.getRhs().getDefiningOp())
3323  if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3324  isa<ConcatOp, ReplicateOp>(opRHS)) {
3325  if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3326  return success();
3327  }
3328 
3329  return failure();
3330 }
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:539
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
Definition: CombFolds.cpp:2434
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Definition: CombFolds.cpp:713
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Definition: CombFolds.cpp:338
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
Definition: CombFolds.cpp:2813
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
Definition: CombFolds.cpp:748
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
Definition: CombFolds.cpp:32
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
Definition: CombFolds.cpp:249
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
Definition: CombFolds.cpp:1743
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
Definition: CombFolds.cpp:900
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
Definition: CombFolds.cpp:825
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition: CombFolds.cpp:58
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
Definition: CombFolds.cpp:88
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
Definition: CombFolds.cpp:2201
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
Definition: CombFolds.cpp:2041
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition: CombFolds.cpp:52
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition: CombFolds.cpp:126
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
Definition: CombFolds.cpp:218
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
Definition: CombFolds.cpp:1394
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
Definition: CombFolds.cpp:580
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
Definition: CombFolds.cpp:1171
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
Definition: CombFolds.cpp:3071
static size_t getTotalWidth(ArrayRef< Value > operands)
Definition: CombFolds.cpp:2874
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
Definition: CombFolds.cpp:2339
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
Definition: CombFolds.cpp:143
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
Definition: CombFolds.cpp:935
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
Definition: CombFolds.cpp:2776
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
Definition: CombFolds.cpp:2987
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
Definition: CombFolds.cpp:2088
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition: CombFolds.cpp:120
static bool hasSVAttributes(Operation *op)
Definition: CombFolds.cpp:103
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
Definition: CombFolds.cpp:499
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
Definition: CombFolds.cpp:1772
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
Definition: CombFolds.cpp:2860
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
Definition: CombFolds.cpp:2241
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
Definition: CombFolds.cpp:2889
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition: CombFolds.cpp:44
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition: CombFolds.cpp:73
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
int32_t width
Definition: FIRRTL.cpp:36
llvm::SmallVector< StringAttr > inputs
Builder builder
def create(low_bit, result_type, input=None)
Definition: comb.py:187
def create(elements)
Definition: hw.py:443
def create(data_type, value)
Definition: hw.py:393
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
Definition: CombOps.cpp:48
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1828
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: comb.py:1