CIRCT  18.0.0git
CombFolds.cpp
Go to the documentation of this file.
1 //===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
12 #include "mlir/IR/Matchers.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "llvm/ADT/SetVector.h"
15 #include "llvm/ADT/SmallBitVector.h"
16 #include "llvm/ADT/TypeSwitch.h"
17 #include "llvm/Support/KnownBits.h"
18 
19 using namespace mlir;
20 using namespace circt;
21 using namespace comb;
22 using namespace matchers;
23 
24 /// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
25 /// such, cross-block dataflow is interpreted as a canonicalization barrier.
26 /// This is a conservative approach which:
27 /// 1. still allows for efficient canonicalization for the common CIRCT usecase
28 /// of comb (comb logic nested inside single-block hw.module's)
29 /// 2. allows comb operations to be used in non-HW container ops - that may use
30 /// MLIR blocks and regions to represent various forms of hierarchical
31 /// abstractions, thus allowing comb to compose with other dialects.
32 static bool hasOperandsOutsideOfBlock(Operation *op) {
33  Block *thisBlock = op->getBlock();
34  return llvm::any_of(op->getOperands(), [&](Value operand) {
35  return operand.getParentBlock() != thisBlock;
36  });
37 }
38 
39 /// Create a new instance of a generic operation that only has value operands,
40 /// and has a single result value whose type matches the first operand.
41 ///
42 /// This should not be used to create instances of ops with attributes or with
43 /// more complicated type signatures.
44 static Value createGenericOp(Location loc, OperationName name,
45  ArrayRef<Value> operands, OpBuilder &builder) {
46  OperationState state(loc, name);
47  state.addOperands(operands);
48  state.addTypes(operands[0].getType());
49  return builder.create(state)->getResult(0);
50 }
51 
52 static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
53  return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
54  value);
55 }
56 
57 /// Flatten concat and mux operands into a vector.
58 static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
59  if (auto concat = v.getDefiningOp<ConcatOp>()) {
60  for (auto op : concat.getOperands())
61  getConcatOperands(op, result);
62  } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
63  for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
64  getConcatOperands(repl.getOperand(), result);
65  } else {
66  result.push_back(v);
67  }
68 }
69 
70 /// A wrapper of `PatternRewriter::replaceOp` to propagate "sv.namehint"
71 /// attribute. If a replaced op has a "sv.namehint" attribute, this function
72 /// propagates the name to the new value.
73 static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
74  Value newValue) {
75  if (auto *newOp = newValue.getDefiningOp()) {
76  auto name = op->getAttrOfType<StringAttr>("sv.namehint");
77  if (name && !newOp->hasAttr("sv.namehint"))
78  rewriter.updateRootInPlace(newOp,
79  [&] { newOp->setAttr("sv.namehint", name); });
80  }
81  rewriter.replaceOp(op, newValue);
82 }
83 
84 /// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate
85 /// "sv.namehint" attribute. If a replaced op has a "sv.namehint" attribute,
86 /// this function propagates the name to the new value.
87 template <typename OpTy, typename... Args>
88 static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
89  Operation *op, Args &&...args) {
90  auto name = op->getAttrOfType<StringAttr>("sv.namehint");
91  auto newOp =
92  rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93  if (name && !newOp->hasAttr("sv.namehint"))
94  rewriter.updateRootInPlace(newOp,
95  [&] { newOp->setAttr("sv.namehint", name); });
96 
97  return newOp;
98 }
99 
100 // Return true if the op has SV attributes. Note that we cannot use a helper
101 // function `hasSVAttributes` defined under SV dialect because of a cyclic
102 // dependency.
103 static bool hasSVAttributes(Operation *op) {
104  return op->hasAttr("sv.attributes");
105 }
106 
107 namespace {
108 template <typename SubType>
109 struct ComplementMatcher {
110  SubType lhs;
111  ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112  bool match(Operation *op) {
113  auto xorOp = dyn_cast<XorOp>(op);
114  return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
115  }
116 };
117 } // end anonymous namespace
118 
119 template <typename SubType>
120 static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
121  return ComplementMatcher<SubType>(subExpr);
122 }
123 
124 /// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
125 /// as an Op. Returns true if successful, and false otherwise.
126 ///
127 /// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
128 ///
129 static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
130  auto inputs = op->getOperands();
131 
132  for (size_t i = 0, size = inputs.size(); i != size; ++i) {
133  Operation *flattenOp = inputs[i].getDefiningOp();
134  if (!flattenOp || flattenOp->getName() != op->getName())
135  continue;
136 
137  // Check for loops
138  if (flattenOp == op)
139  continue;
140 
141  // Don't duplicate logic when it has multiple uses.
142  if (!inputs[i].hasOneUse()) {
143  // We can fold a multi-use binary operation into this one if this allows a
144  // constant to fold though. For example, fold
145  // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
146  // since the constants will both fold and we end up with the equiv cost.
147  //
148  // We don't do this for add/mul because the hardware won't be shared
149  // between the two ops if duplicated.
150  if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
151  !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
152  !inputs.back().getDefiningOp<hw::ConstantOp>())
153  continue;
154  }
155 
156  // Otherwise, flatten away.
157  auto flattenOpInputs = flattenOp->getOperands();
158 
159  SmallVector<Value, 4> newOperands;
160  newOperands.reserve(size + flattenOpInputs.size());
161 
162  auto flattenOpIndex = inputs.begin() + i;
163  newOperands.append(inputs.begin(), flattenOpIndex);
164  newOperands.append(flattenOpInputs.begin(), flattenOpInputs.end());
165  newOperands.append(flattenOpIndex + 1, inputs.end());
166 
167  Value result =
168  createGenericOp(op->getLoc(), op->getName(), newOperands, rewriter);
169 
170  // If the original operation and flatten operand have bin flags, propagte
171  // the flag to new one.
172  if (op->hasAttrOfType<UnitAttr>("twoState") &&
173  flattenOp->hasAttrOfType<UnitAttr>("twoState"))
174  result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
175 
176  replaceOpAndCopyName(rewriter, op, result);
177  return true;
178  }
179  return false;
180 }
181 
182 // Given a range of uses of an operation, find the lowest and highest bits
183 // inclusive that are ever referenced. The range of uses must not be empty.
184 static std::pair<size_t, size_t>
185 getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
186  size_t originalOpWidth) {
187  auto users = op->getUsers();
188  assert(!users.empty() &&
189  "getLowestBitAndHighestBitRequired cannot operate on "
190  "a empty list of uses.");
191 
192  // when we don't want to narrowTrailingBits (namely in arithmetic
193  // operations), forcing lowestBitRequired = 0
194  size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
195  size_t highestBitRequired = 0;
196 
197  for (auto *user : users) {
198  if (auto extractOp = dyn_cast<ExtractOp>(user)) {
199  size_t lowBit = extractOp.getLowBit();
200  size_t highBit =
201  extractOp.getType().cast<IntegerType>().getWidth() + lowBit - 1;
202  highestBitRequired = std::max(highestBitRequired, highBit);
203  lowestBitRequired = std::min(lowestBitRequired, lowBit);
204  continue;
205  }
206 
207  highestBitRequired = originalOpWidth - 1;
208  lowestBitRequired = 0;
209  break;
210  }
211 
212  return {lowestBitRequired, highestBitRequired};
213 }
214 
215 template <class OpTy>
216 static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
217  PatternRewriter &rewriter) {
218  IntegerType opType =
219  op.getResult().getType().template dyn_cast<IntegerType>();
220  if (!opType)
221  return false;
222 
223  auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
224  opType.getWidth());
225  if (range.second + 1 == opType.getWidth() && range.first == 0)
226  return false;
227 
228  SmallVector<Value> args;
229  auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230  for (auto inop : op.getOperands()) {
231  // deal with muxes here
232  if (inop.getType() != op.getType())
233  args.push_back(inop);
234  else
235  args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
236  inop, range.first));
237  }
238  Value newop = rewriter.createOrFold<OpTy>(op.getLoc(), newType, args);
239  newop.getDefiningOp()->setDialectAttrs(op->getDialectAttrs());
240  if (range.first)
241  newop = rewriter.createOrFold<ConcatOp>(
242  op.getLoc(), newop,
243  rewriter.create<hw::ConstantOp>(op.getLoc(),
244  APInt::getZero(range.first)));
245  if (range.second + 1 < opType.getWidth())
246  newop = rewriter.createOrFold<ConcatOp>(
247  op.getLoc(),
248  rewriter.create<hw::ConstantOp>(
249  op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
250  newop);
251  rewriter.replaceOp(op, newop);
252  return true;
253 }
254 
255 //===----------------------------------------------------------------------===//
256 // Unary Operations
257 //===----------------------------------------------------------------------===//
258 
259 OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
260  if (hasOperandsOutsideOfBlock(getOperation()))
261  return {};
262 
263  // Replicate one time -> noop.
264  if (getType().cast<IntegerType>().getWidth() ==
265  getInput().getType().getIntOrFloatBitWidth())
266  return getInput();
267 
268  // Constant fold.
269  if (auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>()) {
270  if (input.getValue().getBitWidth() == 1) {
271  if (input.getValue().isZero())
272  return getIntAttr(
273  APInt::getZero(getType().cast<IntegerType>().getWidth()),
274  getContext());
275  return getIntAttr(
276  APInt::getAllOnes(getType().cast<IntegerType>().getWidth()),
277  getContext());
278  }
279 
280  APInt result = APInt::getZeroWidth();
281  for (auto i = getMultiple(); i != 0; --i)
282  result = result.concat(input.getValue());
283  return getIntAttr(result, getContext());
284  }
285 
286  return {};
287 }
288 
289 OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
290  if (hasOperandsOutsideOfBlock(getOperation()))
291  return {};
292 
293  // Constant fold.
294  if (auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>())
295  return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
296 
297  return {};
298 }
299 
300 //===----------------------------------------------------------------------===//
301 // Binary Operations
302 //===----------------------------------------------------------------------===//
303 
304 /// Performs constant folding `calculate` with element-wise behavior on the two
305 /// attributes in `operands` and returns the result if possible.
306 static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
307  hw::PEO paramOpcode) {
308  assert(operands.size() == 2 && "binary op takes two operands");
309  if (!operands[0] || !operands[1])
310  return {};
311 
312  // Fold constants with ParamExprAttr::get which handles simple constants as
313  // well as parameter expressions.
314  return hw::ParamExprAttr::get(paramOpcode, operands[0].cast<TypedAttr>(),
315  operands[1].cast<TypedAttr>());
316 }
317 
318 OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
319  if (hasOperandsOutsideOfBlock(getOperation()))
320  return {};
321 
322  if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
323  unsigned shift = rhs.getValue().getZExtValue();
324  unsigned width = getType().getIntOrFloatBitWidth();
325  if (shift == 0)
326  return getOperand(0);
327  if (width <= shift)
328  return getIntAttr(APInt::getZero(width), getContext());
329  }
330 
331  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
332 }
333 
334 LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
335  if (hasOperandsOutsideOfBlock(&*op))
336  return failure();
337 
338  // ShlOp(x, cst) -> Concat(Extract(x), zeros)
339  APInt value;
340  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
341  return failure();
342 
343  unsigned width = op.getLhs().getType().cast<IntegerType>().getWidth();
344  unsigned shift = value.getZExtValue();
345 
346  // This case is handled by fold.
347  if (width <= shift || shift == 0)
348  return failure();
349 
350  auto zeros =
351  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
352 
353  // Remove the high bits which would be removed by the Shl.
354  auto extract =
355  rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
356 
357  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
358  return success();
359 }
360 
361 OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
362  if (hasOperandsOutsideOfBlock(getOperation()))
363  return {};
364 
365  if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
366  unsigned shift = rhs.getValue().getZExtValue();
367  if (shift == 0)
368  return getOperand(0);
369 
370  unsigned width = getType().getIntOrFloatBitWidth();
371  if (width <= shift)
372  return getIntAttr(APInt::getZero(width), getContext());
373  }
374  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
375 }
376 
377 LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
378  if (hasOperandsOutsideOfBlock(&*op))
379  return failure();
380 
381  // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
382  APInt value;
383  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
384  return failure();
385 
386  unsigned width = op.getLhs().getType().cast<IntegerType>().getWidth();
387  unsigned shift = value.getZExtValue();
388 
389  // This case is handled by fold.
390  if (width <= shift || shift == 0)
391  return failure();
392 
393  auto zeros =
394  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
395 
396  // Remove the low bits which would be removed by the Shr.
397  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
398  width - shift);
399 
400  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
401  return success();
402 }
403 
404 OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
405  if (hasOperandsOutsideOfBlock(getOperation()))
406  return {};
407 
408  if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
409  if (rhs.getValue().getZExtValue() == 0)
410  return getOperand(0);
411  }
412  return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
413 }
414 
415 LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
416  if (hasOperandsOutsideOfBlock(&*op))
417  return failure();
418 
419  // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
420  APInt value;
421  if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
422  return failure();
423 
424  unsigned width = op.getLhs().getType().cast<IntegerType>().getWidth();
425  unsigned shift = value.getZExtValue();
426 
427  auto topbit =
428  rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
429  auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
430 
431  if (width <= shift) {
432  replaceOpAndCopyName(rewriter, op, {sext});
433  return success();
434  }
435 
436  auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
437  width - shift);
438 
439  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
440  return success();
441 }
442 
443 //===----------------------------------------------------------------------===//
444 // Other Operations
445 //===----------------------------------------------------------------------===//
446 
447 OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
448  if (hasOperandsOutsideOfBlock(getOperation()))
449  return {};
450 
451  // If we are extracting the entire input, then return it.
452  if (getInput().getType() == getType())
453  return getInput();
454 
455  // Constant fold.
456  if (auto input = adaptor.getInput().dyn_cast_or_null<IntegerAttr>()) {
457  unsigned dstWidth = getType().cast<IntegerType>().getWidth();
458  return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
459  getContext());
460  }
461  return {};
462 }
463 
464 // Rewrites extract(lo, cat(a, b, c, d, e)) into
465 // cat(extract(lo1, b), c, extract(lo2, d)): only the concat operands that
466 // overlap the extracted bit range survive. innerCat must be the argument of
// the provided ExtractOp. Concat operands are listed MSB first, so both loops
// below walk them in reverse, starting from bit 0. `op` is always rewritten
// and the function always returns success.
467 static LogicalResult extractConcatToConcatExtract(ExtractOp op,
468  ConcatOp innerCat,
469  PatternRewriter &rewriter) {
470  auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
471  size_t beginOfFirstRelevantElement = 0;
472  auto it = reversedConcatArgs.begin();
473  size_t lowBit = op.getLowBit();
474 
475  // Find the first (lowest) concat operand that contains lowBit.
// beginOfFirstRelevantElement is the bit offset at which the operand that
// `it` points to begins.
476  for (; it != reversedConcatArgs.end(); it++) {
477  assert(beginOfFirstRelevantElement <= lowBit &&
478  "incorrectly moved past an element that lowBit has coverage over");
479  auto operand = *it;
480 
481  size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
482  if (lowBit < beginOfFirstRelevantElement + operandWidth) {
483  // This operand contains lowBit, so it is the first one the extraction
484  // keeps. Writing begin for beginOfFirstRelevantElement, every one of
485  // the following cases lands here:
486  //
487  //   - lowBit strictly inside the operand:
488  //       begin < lowBit < begin + operandWidth - 1
489  //
490  //   - edge case close to the end of the operand:
491  //       lowBit == begin + operandWidth - 1
492  //
493  //   - edge case close to the beginning of the operand:
494  //       lowBit == begin
495  //
496  // Stop with `it` pointing at this operand; the second loop resumes
497  // from here, starting lowBit - begin bits into it.
498  //
499  break;
500  }
501 
502  // This operand lies entirely below lowBit (begin + operandWidth <=
503  // lowBit), so the extraction discards it.
504  //
505  // Advance the running offset past it.
506  beginOfFirstRelevantElement += operandWidth;
507  }
508  assert(it != reversedConcatArgs.end() &&
509  "incorrectly failed to find an element which contains coverage of "
510  "lowBit");
511 
512  SmallVector<Value> reverseConcatArgs;
513  size_t widthRemaining = op.getType().cast<IntegerType>().getWidth();
514  size_t extractLo = lowBit - beginOfFirstRelevantElement;
515 
516  // Walk the remaining operands LSB first until every extracted bit is
517  // covered. innerCat(..., a, b, c) becomes [extract(c), b, extract(a)]
518  // (collected in reverse); an ExtractOp is only created for partial slices.
519  for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
520  auto concatArg = *it;
521  size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
522  size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
523 
524  if (widthToConsume == operandWidth && extractLo == 0) {
525  reverseConcatArgs.push_back(concatArg);
526  } else {
527  auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
528  reverseConcatArgs.push_back(
529  rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
530  }
531 
532  widthRemaining -= widthToConsume;
533 
534  // Only the first kept operand can start at a non-zero offset; every
// later operand is consumed from its bit 0.
535  extractLo = 0;
536  }
537 
// A single surviving piece replaces `op` directly; otherwise rebuild the
// concat in MSB-first order.
538  if (reverseConcatArgs.size() == 1) {
539  replaceOpAndCopyName(rewriter, op, reverseConcatArgs[0]);
540  } else {
541  replaceOpWithNewOpAndCopyName<ConcatOp>(
542  rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
543  }
544  return success();
545 }
546 
547 // Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
548 static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
549  PatternRewriter &rewriter) {
550  auto extractResultWidth = op.getType().cast<IntegerType>().getWidth();
551  auto replicateEltWidth =
552  replicate.getOperand().getType().getIntOrFloatBitWidth();
553 
554  // If the extract starts at the base of an element and is an even multiple,
555  // we can replace the extract with a smaller replicate.
556  if (op.getLowBit() % replicateEltWidth == 0 &&
557  extractResultWidth % replicateEltWidth == 0) {
558  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
559  replicate.getOperand());
560  return true;
561  }
562 
563  // If the extract is completely contained in one element, extract from the
564  // element.
565  if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
566  replicateEltWidth) {
567  replaceOpWithNewOpAndCopyName<ExtractOp>(
568  rewriter, op, op.getType(), replicate.getOperand(),
569  op.getLowBit() % replicateEltWidth);
570  return true;
571  }
572 
573  // We don't currently handle the case of extracting from non-whole elements,
574  // e.g. `extract (replicate 2-bit-thing, N), 1`.
575  return false;
576 }
577 
/// Canonicalization for ExtractOp. The patterns below are tried in order and
/// the first one that applies rewrites `op`:
///  - extract of extract -> a single extract with the offsets added;
///  - extract of concat -> concat of only the overlapping pieces
///    (extractConcatToConcatExtract, which always succeeds);
///  - extract of replicate -> smaller replicate or single-element extract,
///    when the range lines up with the elements (extractFromReplicate);
///  - extract of a two-operand and/or/xor with a constant RHS -> simpler
///    forms when the relevant slice of the constant is all zeros (or/xor) or
///    one contiguous run of ones (and);
///  - single-bit extract of shl(1, x) -> icmp eq x, lowBit.
/// Ops with operands defined in another block are left untouched.
578 LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
579  if (hasOperandsOutsideOfBlock(&*op))
580  return failure();
581 
582  auto *inputOp = op.getInput().getDefiningOp();
583 
584  // Computing known bits turns out to be incredibly expensive, so this
585  // pattern is disabled until its performance is addressed.
586 #if 0
587  // If the extracted bits are all known, then return the result.
588  auto knownBits = computeKnownBits(op.getInput())
589  .extractBits(op.getType().cast<IntegerType>().getWidth(),
590  op.getLowBit());
591  if (knownBits.isConstant()) {
592  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
593  knownBits.getConstant());
594  return success();
595  }
596 #endif
597 
598  // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
599  if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
600  replaceOpWithNewOpAndCopyName<ExtractOp>(
601  rewriter, op, op.getType(), innerExtract.getInput(),
602  innerExtract.getLowBit() + op.getLowBit());
603  return success();
604  }
605 
606  // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
607  if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
608  return extractConcatToConcatExtract(op, innerCat, rewriter);
609 
610  // extract(lo, replicate(a)); falls through when the range does not line up.
611  if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
612  if (extractFromReplicate(op, replicate, rewriter))
613  return success();
614 
615  // extract(or/xor(a, cst)) -> extract(a) when the slice of cst covering the
616  // extracted bits is all zeros, so the op cannot modify those bits.
617  if (inputOp && inputOp->getNumOperands() == 2 &&
618  isa<AndOp, OrOp, XorOp>(inputOp)) {
619  if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
620  auto extractedCst = cstRHS.getValue().extractBits(
621  op.getType().cast<IntegerType>().getWidth(), op.getLowBit());
622  if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
623  replaceOpWithNewOpAndCopyName<ExtractOp>(
624  rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
625  return success();
626  }
627 
628  // extract(and(a, cst)) -> concat(0s, extract(a), 0s) when a single
629  // extract suffices to represent the result. Splitting into a pile of
630  // extracts would be fine by our cost model, but exploding it into many
631  // single bits would bloat the IR and the generated Verilog.
632  if (isa<AndOp>(inputOp)) {
633  // Only done when the constant slice is one contiguous run of ones:
634  // zeros above (lz) and below (tz) it become constant padding.
// NOTE(review): an all-zero slice never matches here, since
// countLeadingZeros/countTrailingZeros both return the bit width for zero
// and the unsigned subtraction below then wraps.
635  unsigned lz = extractedCst.countLeadingZeros();
636  unsigned tz = extractedCst.countTrailingZeros();
637  unsigned pop = extractedCst.popcount();
638  if (extractedCst.getBitWidth() - lz - tz == pop) {
639  auto resultTy = rewriter.getIntegerType(pop);
640  SmallVector<Value> resultElts;
641  if (lz)
642  resultElts.push_back(rewriter.create<hw::ConstantOp>(
643  op.getLoc(), APInt::getZero(lz)));
644  resultElts.push_back(rewriter.createOrFold<ExtractOp>(
645  op.getLoc(), resultTy, inputOp->getOperand(0),
646  op.getLowBit() + tz));
647  if (tz)
648  resultElts.push_back(rewriter.create<hw::ConstantOp>(
649  op.getLoc(), APInt::getZero(tz)));
650  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
651  return success();
652  }
653  }
654  }
655  }
656 
657  // extract(lowBit, shl(1, x)) -> icmp eq x, lowBit for a single-bit
658  // extract: bit lowBit of (1 << x) is set exactly when x == lowBit.
// The comparison constant takes the width of the shl's LHS constant;
// presumably x has that same width (TODO confirm against ShlOp's type
// constraints). The trailing `false` is presumably ICmpOp's twoState flag.
659  if (op.getType().cast<IntegerType>().getWidth() == 1 && inputOp)
660  if (auto shlOp = dyn_cast<ShlOp>(inputOp))
661  if (auto lhsCst = shlOp.getOperand(0).getDefiningOp<hw::ConstantOp>())
662  if (lhsCst.getValue().isOne()) {
663  auto newCst = rewriter.create<hw::ConstantOp>(
664  shlOp.getLoc(),
665  APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
666  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
667  shlOp->getOperand(1), newCst,
668  false);
669  return success();
670  }
671 
672  return failure();
673 }
674 
675 //===----------------------------------------------------------------------===//
676 // Associative Variadic operations
677 //===----------------------------------------------------------------------===//
678 
679 // Reduce all operands to a single value (either integer constant or parameter
680 // expression) if all the operands are constants.
681 static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
682  hw::PEO paramOpcode) {
683  assert(operands.size() > 1 && "caller should handle one-operand case");
684  // We can only fold anything in the case where all operands are known to be
685  // constants. Check the least common one first for an early out.
686  if (!operands[1] || !operands[0])
687  return {};
688 
689  // This will fold to a simple constant if all operands are constant.
690  if (llvm::all_of(operands.drop_front(2),
691  [&](Attribute in) { return !!in; })) {
692  SmallVector<mlir::TypedAttr> typedOperands;
693  typedOperands.reserve(operands.size());
694  for (auto operand : operands) {
695  if (auto typedOperand = operand.dyn_cast<mlir::TypedAttr>())
696  typedOperands.push_back(typedOperand);
697  else
698  break;
699  }
700  if (typedOperands.size() == operands.size())
701  return hw::ParamExprAttr::get(paramOpcode, typedOperands);
702  }
703 
704  return {};
705 }
706 
707 /// When we find a logical operation (and, or, xor) with a constant e.g.
708 /// `X & 42`, we want to push the constant into the computation of X if it leads
709 /// to simplification.
710 ///
711 /// This function handles the case where the logical operation has a concat
712 /// operand. We check to see if we can simplify the concat, e.g. when it has
713 /// constant operands.
714 ///
715 /// This returns true when a simplification happens.
716 static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
717  size_t concatIdx, const APInt &cst,
718  PatternRewriter &rewriter) {
719  auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
720  assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
721 
722  // Check to see if any operands can be simplified by pushing the logical op
723  // into all parts of the concat.
724  bool canSimplify =
725  llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
726  auto *operandOp = operand.getDefiningOp();
727  if (!operandOp)
728  return false;
729 
730  // If the concat has a constant operand then we can transform this.
731  if (isa<hw::ConstantOp>(operandOp))
732  return true;
733  // If the concat has the same logical operation and that operation has
734  // a constant operation than we can fold it into that suboperation.
735  return operandOp->getName() == logicalOp->getName() &&
736  operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
737  operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
738  });
739 
740  if (!canSimplify)
741  return false;
742 
743  // Create a new instance of the logical operation. We have to do this the
744  // hard way since we're generic across a family of different ops.
745  auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
746  return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
747  rewriter);
748  };
749 
750  // Ok, let's do the transformation. We do this by slicing up the constant
751  // for each unit of the concat and duplicate the operation into the
752  // sub-operand.
753  SmallVector<Value> newConcatOperands;
754  newConcatOperands.reserve(concatOp->getNumOperands());
755 
756  // Work from MSB to LSB.
757  size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
758  for (Value operand : concatOp->getOperands()) {
759  size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
760  nextOperandBit -= operandWidth;
761  // Take a slice of the constant.
762  auto eltCst = rewriter.create<hw::ConstantOp>(
763  logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
764 
765  newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
766  }
767 
768  // Create the concat, and the rest of the logical op if we need it.
769  Value newResult =
770  rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);
771 
772  // If we had a variadic logical op on the top level, then recreate it with the
773  // new concat and without the constant operand.
774  if (logicalOp->getNumOperands() > 2) {
775  auto origOperands = logicalOp->getOperands();
776  SmallVector<Value> operands;
777  // Take any stuff before the concat.
778  operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
779  // Take any stuff after the concat but before the constant.
780  operands.append(origOperands.begin() + concatIdx + 1,
781  origOperands.begin() + (origOperands.size() - 1));
782  // Include the new concat.
783  operands.push_back(newResult);
784  newResult = createLogicalOp(operands);
785  }
786 
787  replaceOpAndCopyName(rewriter, logicalOp, newResult);
788  return true;
789 }
790 
791 OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
792  if (hasOperandsOutsideOfBlock(getOperation()))
793  return {};
794 
795  APInt value = APInt::getAllOnes(getType().cast<IntegerType>().getWidth());
796 
797  auto inputs = adaptor.getInputs();
798 
799  // and(x, 01, 10) -> 00 -- annulment.
800  for (auto operand : inputs) {
801  if (!operand)
802  continue;
803  value &= operand.cast<IntegerAttr>().getValue();
804  if (value.isZero())
805  return getIntAttr(value, getContext());
806  }
807 
808  // and(x, -1) -> x.
809  if (inputs.size() == 2 && inputs[1] &&
810  inputs[1].cast<IntegerAttr>().getValue().isAllOnes())
811  return getInputs()[0];
812 
813  // and(x, x, x) -> x. This also handles and(x) -> x.
814  if (llvm::all_of(getInputs(),
815  [&](auto in) { return in == this->getInputs()[0]; }))
816  return getInputs()[0];
817 
818  // and(..., x, ..., ~x, ...) -> 0
819  for (Value arg : getInputs()) {
820  Value subExpr;
821  if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
822  for (Value arg2 : getInputs())
823  if (arg2 == subExpr)
824  return getIntAttr(
825  APInt::getZero(getType().cast<IntegerType>().getWidth()),
826  getContext());
827  }
828  }
829 
830  // Constant fold
831  return constFoldAssociativeOp(inputs, hw::PEO::And);
832 }
833 
834 /// Returns a single common operand that all inputs of the operation `op` can
835 /// be traced back to, or an empty `Value` if no such operand exists.
836 ///
837 /// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
838 /// (assuming the bit-width of `a` is `n`).
839 template <typename Op>
840 static Value getCommonOperand(Op op) {
841  if (!op.getType().isInteger(1))
842  return Value();
843 
844  auto inputs = op.getInputs();
845  size_t size = inputs.size();
846 
847  auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
848  if (!sourceOp)
849  return Value();
850  Value source = sourceOp.getOperand();
851 
852  // Fast path: the input size is not equal to the width of the source.
853  if (size != source.getType().getIntOrFloatBitWidth())
854  return Value();
855 
856  // Tracks the bits that were encountered.
857  llvm::BitVector bits(size);
858  bits.set(sourceOp.getLowBit());
859 
860  for (size_t i = 1; i != size; ++i) {
861  auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
862  if (!extractOp || extractOp.getOperand() != source)
863  return Value();
864  bits.set(extractOp.getLowBit());
865  }
866 
867  return bits.all() ? source : Value();
868 }
869 
870 /// Canonicalize an idempotent operation `op` so that only one input of any kind
871 /// occurs.
872 ///
873 /// Example: `and(x, y, x, z)` -> `and(x, y, z)`
874 template <typename Op>
875 static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
876  auto inputs = op.getInputs();
877  llvm::SmallSetVector<Value, 8> uniqueInputs;
878 
879  for (const auto input : inputs)
880  uniqueInputs.insert(input);
881 
882  if (uniqueInputs.size() < inputs.size()) {
883  replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
884  uniqueInputs.getArrayRef());
885  return true;
886  }
887 
888  return false;
889 }
890 
891 LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
892  if (hasOperandsOutsideOfBlock(&*op))
893  return failure();
894 
895  auto inputs = op.getInputs();
896  auto size = inputs.size();
897  assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
898 
899  // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
900  // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
901  if (size > 2 && canonicalizeIdempotentInputs(op, rewriter))
902  return success();
903 
904  // Patterns for and with a constant on RHS.
905  APInt value;
906  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
907  // and(..., '1) -> and(...) -- identity
908  if (value.isAllOnes()) {
909  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
910  inputs.drop_back(), false);
911  return success();
912  }
913 
914  // TODO: Combine multiple constants together even if they aren't at the
915  // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
916  // folding
917  APInt value2;
918  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
919  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
920  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
921  newOperands.push_back(cst);
922  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
923  newOperands, false);
924  return success();
925  }
926 
927  // Handle 'and' with a single bit constant on the RHS.
928  if (size == 2 && value.isPowerOf2()) {
929  // If the LHS is a replicate from a single bit, we can 'concat' it
930  // into place. e.g.:
931  // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
932  // TODO: Generalize this for non-single-bit operands.
933  if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
934  auto replicateOperand = replicate.getOperand();
935  if (replicateOperand.getType().isInteger(1)) {
936  unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
937  auto trailingZeros = value.countTrailingZeros();
938 
939  // Don't add zero bit constants unnecessarily.
940  SmallVector<Value, 3> concatOperands;
941  if (trailingZeros != resultWidth - 1) {
942  auto highZeros = rewriter.create<hw::ConstantOp>(
943  op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
944  concatOperands.push_back(highZeros);
945  }
946  concatOperands.push_back(replicateOperand);
947  if (trailingZeros != 0) {
948  auto lowZeros = rewriter.create<hw::ConstantOp>(
949  op.getLoc(), APInt::getZero(trailingZeros));
950  concatOperands.push_back(lowZeros);
951  }
952  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
953  concatOperands);
954  return success();
955  }
956  }
957  }
958 
959  // If this is an and from an extract op, try shrinking the extract.
960  if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
961  if (size == 2 &&
962  // We can shrink it if the mask has leading or trailing zeros.
963  (value.countLeadingZeros() || value.countTrailingZeros())) {
964  unsigned lz = value.countLeadingZeros();
965  unsigned tz = value.countTrailingZeros();
966 
967  // Start by extracting the smaller number of bits.
968  auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
969  Value smallElt = rewriter.createOrFold<ExtractOp>(
970  extractOp.getLoc(), smallTy, extractOp->getOperand(0),
971  extractOp.getLowBit() + tz);
972  // Apply the 'and' mask if needed.
973  APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
974  if (!smallMask.isAllOnes()) {
975  auto loc = inputs.back().getLoc();
976  smallElt = rewriter.createOrFold<AndOp>(
977  loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
978  false);
979  }
980 
981  // The final replacement will be a concat of the leading/trailing zeros
982  // along with the smaller extracted value.
983  SmallVector<Value> resultElts;
984  if (lz)
985  resultElts.push_back(
986  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
987  resultElts.push_back(smallElt);
988  if (tz)
989  resultElts.push_back(
990  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
991  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
992  return success();
993  }
994  }
995 
996  // and(concat(x, cst1), a, b, c, cst2)
997  // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
998  // We do this for even more multi-use concats since they are "just wiring".
999  for (size_t i = 0; i < size - 1; ++i) {
1000  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1001  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1002  return success();
1003  }
1004  }
1005 
1006  // and(x, and(...)) -> and(x, ...) -- flatten
1007  if (tryFlatteningOperands(op, rewriter))
1008  return success();
1009 
1010  // extracts only of and(...) -> and(extract()...)
1011  if (narrowOperationWidth(op, true, rewriter))
1012  return success();
1013 
1014  // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1015  if (auto source = getCommonOperand(op)) {
1016  auto cmpAgainst =
1017  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1018  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1019  source, cmpAgainst);
1020  return success();
1021  }
1022 
1023  /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1024  return failure();
1025 }
1026 
1027 OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1028  if (hasOperandsOutsideOfBlock(getOperation()))
1029  return {};
1030 
1031  auto value = APInt::getZero(getType().cast<IntegerType>().getWidth());
1032  auto inputs = adaptor.getInputs();
1033  // or(x, 10, 01) -> 11
1034  for (auto operand : inputs) {
1035  if (!operand)
1036  continue;
1037  value |= operand.cast<IntegerAttr>().getValue();
1038  if (value.isAllOnes())
1039  return getIntAttr(value, getContext());
1040  }
1041 
1042  // or(x, 0) -> x
1043  if (inputs.size() == 2 && inputs[1] &&
1044  inputs[1].cast<IntegerAttr>().getValue().isZero())
1045  return getInputs()[0];
1046 
1047  // or(x, x, x) -> x. This also handles or(x) -> x
1048  if (llvm::all_of(getInputs(),
1049  [&](auto in) { return in == this->getInputs()[0]; }))
1050  return getInputs()[0];
1051 
1052  // or(..., x, ..., ~x, ...) -> -1
1053  for (Value arg : getInputs()) {
1054  Value subExpr;
1055  if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1056  for (Value arg2 : getInputs())
1057  if (arg2 == subExpr)
1058  return getIntAttr(
1059  APInt::getAllOnes(getType().cast<IntegerType>().getWidth()),
1060  getContext());
1061  }
1062  }
1063 
1064  // Constant fold
1065  return constFoldAssociativeOp(inputs, hw::PEO::Or);
1066 }
1067 
/// Simplify concat ops in an or op when a constant operand is present in either
/// concat.
///
/// This will invert an or(concat, concat) into concat(or, or, ...), which can
/// often be further simplified due to the smaller or ops being easier to fold.
/// The two concats are consumed from MSB to LSB in lock-step. Each emitted
/// `or` covers the largest bit range that lies within a single operand of
/// both concats.
///
/// For example:
///
/// or(..., concat(x, 0), concat(0, y))
///    ==> or(..., concat(x, 0, y)), when x and y don't overlap.
///
/// or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1))
///    ==> or(..., concat(or(x: i2,               extract(cst2, 4..3)),
///                       or(extract(cst1, 3..1), extract(cst2, 2..0)),
///                       or(extract(cst1, 0..0), y: i1))
///
/// \param op          the or op being canonicalized.
/// \param concatIdx1  index (into op's inputs) of the first concat operand.
/// \param concatIdx2  index of the second concat operand; must be greater
///                    than concatIdx1.
/// \returns true if `op` was replaced. Returns false without touching the IR
///          when neither concat has an hw.constant operand.
///
/// NOTE(review): the lock-step walk assumes both concats have the same total
/// width (presumably guaranteed because both are operands of the same or op).
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1,
                                                   size_t concatIdx2,
                                                   PatternRewriter &rewriter) {
  assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2");

  auto inputs = op.getInputs();
  auto concat1 = inputs[concatIdx1].getDefiningOp<ConcatOp>();
  auto concat2 = inputs[concatIdx2].getDefiningOp<ConcatOp>();

  assert(concat1 && concat2 && "expected indexes to point to ConcatOps");

  // We can simplify as long as a constant is present in either concat.
  // (concat2 is only scanned when concat1 has no constant operand.)
  bool hasConstantOp1 =
      llvm::any_of(concat1->getOperands(), [&](Value operand) -> bool {
        return operand.getDefiningOp<hw::ConstantOp>();
      });
  if (!hasConstantOp1) {
    bool hasConstantOp2 =
        llvm::any_of(concat2->getOperands(), [&](Value operand) -> bool {
          return operand.getDefiningOp<hw::ConstantOp>();
        });
    if (!hasConstantOp2)
      return false;
  }

  SmallVector<Value> newConcatOperands;

  // Simultaneously iterate over the operands of both concat ops, from MSB to
  // LSB, pushing out or's of overlapping ranges of the operands. When operands
  // span different bit ranges, we extract only the maximum overlap.
  auto operands1 = concat1->getOperands();
  auto operands2 = concat2->getOperands();
  // Number of bits already consumed from operands 1 and 2, respectively.
  unsigned consumedWidth1 = 0;
  unsigned consumedWidth2 = 0;
  for (auto it1 = operands1.begin(), end1 = operands1.end(),
            it2 = operands2.begin(), end2 = operands2.end();
       it1 != end1 && it2 != end2;) {
    auto operand1 = *it1;
    auto operand2 = *it2;

    // Bits of the current operand that have not been consumed yet. Because
    // consumption proceeds from the MSB, these are the low
    // `remainingWidth` bits of the operand.
    unsigned remainingWidth1 =
        hw::getBitWidth(operand1.getType()) - consumedWidth1;
    unsigned remainingWidth2 =
        hw::getBitWidth(operand2.getType()) - consumedWidth2;
    unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2);
    auto narrowedType = rewriter.getIntegerType(widthToConsume);

    // Take the top `widthToConsume` bits of each operand's unconsumed part;
    // their low bit is therefore `remainingWidth - widthToConsume`.
    auto extract1 = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume);
    auto extract2 = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume);

    newConcatOperands.push_back(
        rewriter.createOrFold<OrOp>(op.getLoc(), extract1, extract2, false));

    consumedWidth1 += widthToConsume;
    consumedWidth2 += widthToConsume;

    // Advance whichever operand(s) have been fully consumed.
    if (widthToConsume == remainingWidth1) {
      ++it1;
      consumedWidth1 = 0;
    }
    if (widthToConsume == remainingWidth2) {
      ++it2;
      consumedWidth2 = 0;
    }
  }

  ConcatOp newOp = rewriter.create<ConcatOp>(op.getLoc(), newConcatOperands);

  // Copy the old operands except for concatIdx1 and concatIdx2, and append the
  // new ConcatOp to the end.
  SmallVector<Value> newOrOperands;
  newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1);
  newOrOperands.append(inputs.begin() + concatIdx1 + 1,
                       inputs.begin() + concatIdx2);
  newOrOperands.append(inputs.begin() + concatIdx2 + 1,
                       inputs.begin() + inputs.size());
  newOrOperands.push_back(newOp);

  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
                                      newOrOperands);
  return true;
}
1168 
1169 LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1170  if (hasOperandsOutsideOfBlock(&*op))
1171  return failure();
1172 
1173  auto inputs = op.getInputs();
1174  auto size = inputs.size();
1175  assert(size > 1 && "expected 2 or more operands");
1176 
1177  // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1178  // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1179  if (size > 2 && canonicalizeIdempotentInputs(op, rewriter))
1180  return success();
1181 
1182  // Patterns for and with a constant on RHS.
1183  APInt value;
1184  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1185  // or(..., '0) -> or(...) -- identity
1186  if (value.isZero()) {
1187  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1188  inputs.drop_back());
1189  return success();
1190  }
1191 
1192  // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1193  APInt value2;
1194  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1195  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
1196  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1197  newOperands.push_back(cst);
1198  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
1199  newOperands);
1200  return success();
1201  }
1202 
1203  // or(concat(x, cst1), a, b, c, cst2)
1204  // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1205  // We do this for even more multi-use concats since they are "just wiring".
1206  for (size_t i = 0; i < size - 1; ++i) {
1207  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1208  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1209  return success();
1210  }
1211  }
1212 
1213  // or(x, or(...)) -> or(x, ...) -- flatten
1214  if (tryFlatteningOperands(op, rewriter))
1215  return success();
1216 
1217  // or(..., concat(x, cst1), concat(cst2, y)
1218  // ==> or(..., concat(x, cst3, y)), when x and y don't overlap.
1219  for (size_t i = 0; i < size - 1; ++i) {
1220  if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1221  for (size_t j = i + 1; j < size; ++j)
1222  if (auto concat = inputs[j].getDefiningOp<ConcatOp>())
1223  if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter))
1224  return success();
1225  }
1226 
1227  // extracts only of or(...) -> or(extract()...)
1228  if (narrowOperationWidth(op, true, rewriter))
1229  return success();
1230 
1231  // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1232  if (auto source = getCommonOperand(op)) {
1233  auto cmpAgainst =
1234  rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1235  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1236  source, cmpAgainst);
1237  return success();
1238  }
1239 
1240  // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1241  // .., c_n), a, 0)
1242  if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1243  APInt value;
1244  if (op.getTwoState() && firstMux.getTwoState() &&
1245  matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1246  value.isZero()) {
1247  SmallVector<Value> conditions{firstMux.getCond()};
1248  auto check = [&](Value v) {
1249  auto mux = v.getDefiningOp<comb::MuxOp>();
1250  if (!mux)
1251  return false;
1252  conditions.push_back(mux.getCond());
1253  return mux.getTwoState() &&
1254  firstMux.getTrueValue() == mux.getTrueValue() &&
1255  firstMux.getFalseValue() == mux.getFalseValue();
1256  };
1257  if (llvm::all_of(op.getOperands().drop_front(), check)) {
1258  auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
1259  replaceOpWithNewOpAndCopyName<comb::MuxOp>(
1260  rewriter, op, cond, firstMux.getTrueValue(),
1261  firstMux.getFalseValue(), true);
1262  return success();
1263  }
1264  }
1265  }
1266 
1267  /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1268  return failure();
1269 }
1270 
1271 OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1272  if (hasOperandsOutsideOfBlock(getOperation()))
1273  return {};
1274 
1275  auto size = getInputs().size();
1276  auto inputs = adaptor.getInputs();
1277 
1278  // xor(x) -> x -- noop
1279  if (size == 1)
1280  return getInputs()[0];
1281 
1282  // xor(x, x) -> 0 -- idempotent
1283  if (size == 2 && getInputs()[0] == getInputs()[1])
1284  return IntegerAttr::get(getType(), 0);
1285 
1286  // xor(x, 0) -> x
1287  if (inputs.size() == 2 && inputs[1] &&
1288  inputs[1].cast<IntegerAttr>().getValue().isZero())
1289  return getInputs()[0];
1290 
1291  // xor(xor(x,1),1) -> x
1292  // but not self loop
1293  if (isBinaryNot()) {
1294  Value subExpr;
1295  if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1296  subExpr != getResult())
1297  return subExpr;
1298  }
1299 
1300  // Constant fold
1301  return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1302 }
1303 
1304 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1305 static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1306  PatternRewriter &rewriter) {
1307  auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1308  auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1309 
1310  Value result =
1311  rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1312  icmp.getOperand(1), icmp.getTwoState());
1313 
1314  // If the xor had other operands, rebuild it.
1315  if (op.getNumOperands() > 2) {
1316  SmallVector<Value, 4> newOperands(op.getOperands());
1317  newOperands.pop_back();
1318  newOperands.erase(newOperands.begin() + icmpOperand);
1319  newOperands.push_back(result);
1320  result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
1321  }
1322 
1323  replaceOpAndCopyName(rewriter, op, result);
1324 }
1325 
1326 LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1327  if (hasOperandsOutsideOfBlock(&*op))
1328  return failure();
1329 
1330  auto inputs = op.getInputs();
1331  auto size = inputs.size();
1332  assert(size > 1 && "expected 2 or more operands");
1333 
1334  // xor(..., x, x) -> xor (...) -- idempotent
1335  if (inputs[size - 1] == inputs[size - 2]) {
1336  assert(size > 2 &&
1337  "expected idempotent case for 2 elements handled already.");
1338  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1339  inputs.drop_back(/*n=*/2), false);
1340  return success();
1341  }
1342 
1343  // Patterns for xor with a constant on RHS.
1344  APInt value;
1345  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1346  // xor(..., 0) -> xor(...) -- identity
1347  if (value.isZero()) {
1348  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1349  inputs.drop_back(), false);
1350  return success();
1351  }
1352 
1353  // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1354  APInt value2;
1355  if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1356  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
1357  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1358  newOperands.push_back(cst);
1359  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
1360  newOperands, false);
1361  return success();
1362  }
1363 
1364  bool isSingleBit = value.getBitWidth() == 1;
1365 
1366  // Check for subexpressions that we can simplify.
1367  for (size_t i = 0; i < size - 1; ++i) {
1368  Value operand = inputs[i];
1369 
1370  // xor(concat(x, cst1), a, b, c, cst2)
1371  // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1372  // We do this for even more multi-use concats since they are "just
1373  // wiring".
1374  if (auto concat = operand.getDefiningOp<ConcatOp>())
1375  if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1376  return success();
1377 
1378  // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1379  if (isSingleBit && operand.hasOneUse()) {
1380  assert(value == 1 && "single bit constant has to be one if not zero");
1381  if (auto icmp = operand.getDefiningOp<ICmpOp>())
1382  return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1383  }
1384  }
1385  }
1386 
1387  // xor(x, xor(...)) -> xor(x, ...) -- flatten
1388  if (tryFlatteningOperands(op, rewriter))
1389  return success();
1390 
1391  // extracts only of xor(...) -> xor(extract()...)
1392  if (narrowOperationWidth(op, true, rewriter))
1393  return success();
1394 
1395  // xor(a[0], a[1], ..., a[n]) -> parity(a)
1396  if (auto source = getCommonOperand(op)) {
1397  replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
1398  return success();
1399  }
1400 
1401  return failure();
1402 }
1403 
1404 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1405  if (hasOperandsOutsideOfBlock(getOperation()))
1406  return {};
1407 
1408  // sub(x - x) -> 0
1409  if (getRhs() == getLhs())
1410  return getIntAttr(
1411  APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1412  getContext());
1413 
1414  if (adaptor.getRhs()) {
1415  // If both are constants, we can unconditionally fold.
1416  if (adaptor.getLhs()) {
1417  // Constant fold (c1 - c2) => (c1 + -1*c2).
1418  auto negOne = getIntAttr(
1419  APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1420  getContext());
1421  auto rhsNeg = hw::ParamExprAttr::get(
1422  hw::PEO::Mul, adaptor.getRhs().cast<TypedAttr>(), negOne);
1423  return hw::ParamExprAttr::get(hw::PEO::Add,
1424  adaptor.getLhs().cast<TypedAttr>(), rhsNeg);
1425  }
1426 
1427  // sub(x - 0) -> x
1428  if (auto rhsC = adaptor.getRhs().dyn_cast<IntegerAttr>()) {
1429  if (rhsC.getValue().isZero())
1430  return getLhs();
1431  }
1432  }
1433 
1434  return {};
1435 }
1436 
1437 LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1438  if (hasOperandsOutsideOfBlock(&*op))
1439  return failure();
1440 
1441  // sub(x, cst) -> add(x, -cst)
1442  APInt value;
1443  if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1444  auto negCst = rewriter.create<hw::ConstantOp>(op.getLoc(), -value);
1445  replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1446  false);
1447  return success();
1448  }
1449 
1450  // extracts only of sub(...) -> sub(extract()...)
1451  if (narrowOperationWidth(op, false, rewriter))
1452  return success();
1453 
1454  return failure();
1455 }
1456 
1457 OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1458  if (hasOperandsOutsideOfBlock(getOperation()))
1459  return {};
1460 
1461  auto size = getInputs().size();
1462 
1463  // add(x) -> x -- noop
1464  if (size == 1u)
1465  return getInputs()[0];
1466 
1467  // Constant fold constant operands.
1468  return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1469 }
1470 
/// Canonicalization patterns for `comb.add`. Patterns are tried in order and
/// the first one that rewrites `op` wins:
///   1. drop a trailing zero constant;
///   2. merge two trailing constants;
///   3. x + x -> shl(x, 1) for two identical trailing operands;
///   4. x + shl(x, c) -> mul(x, (1 << c) + 1);
///   5. x + mul(x, c) -> mul(x, c + 1);
///   6. flattening nested adds;
///   7. narrowing when only extracts of the result are used;
///   8. add(add(x, c1), c2) -> add(x, c1 + c2) for two-operand adds.
/// Patterns 3-5 only inspect the last two operands.
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // Scratch values reused by several patterns below.
  APInt value, value2;

  // add(..., 0) -> add(...) -- identity
  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         inputs.drop_back(), false);
    return success();
  }

  // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
      matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
    auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(cst);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  // add(..., x, x) -> add(..., shl(x, 1))
  if (inputs[size - 1] == inputs[size - 2]) {
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));

    // The shift amount is a constant 1 of the add's own type.
    auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
    auto shiftLeftOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs.back(), one, false);

    newOperands.push_back(shiftLeftOp);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
  // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
  if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
      matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {

    // If c >= the bit width, APInt's shift yields 0 and the multiplier
    // becomes 1. NOTE(review): confirm this matches comb.shl semantics for
    // oversized shift amounts.
    APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
    auto rhs =
        rewriter.create<hw::ConstantOp>(op.getLoc(), (one << value) + one);

    std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
    auto mulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);

    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(mulOp);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
  // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
  // Only a two-operand mul with its constant as the second operand matches.
  if (mulOp && mulOp.getInputs().size() == 2 &&
      mulOp.getInputs()[0] == inputs[size - 2] &&
      matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {

    APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
    auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + one);
    std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
    auto newMulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);

    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(newMulOp);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  // add(x, add(...)) -> add(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of add(...) -> add(extract()...)
  if (narrowOperationWidth(op, false, rewriter))
    return success();

  // add(add(x, c1), c2) -> add(x, c1 + c2)
  // The new add is two-state only if both original adds were.
  auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
  if (addOp && addOp.getInputs().size() == 2 &&
      matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
      inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {

    auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
    replaceOpWithNewOpAndCopyName<AddOp>(
        rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
        /*twoState=*/op.getTwoState() && addOp.getTwoState());
    return success();
  }

  return failure();
}
1573 
1574 OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1575  if (hasOperandsOutsideOfBlock(getOperation()))
1576  return {};
1577 
1578  auto size = getInputs().size();
1579  auto inputs = adaptor.getInputs();
1580 
1581  // mul(x) -> x -- noop
1582  if (size == 1u)
1583  return getInputs()[0];
1584 
1585  auto width = getType().cast<IntegerType>().getWidth();
1586  APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1587 
1588  // mul(x, 0, 1) -> 0 -- annulment
1589  for (auto operand : inputs) {
1590  if (!operand)
1591  continue;
1592  value *= operand.cast<IntegerAttr>().getValue();
1593  if (value.isZero())
1594  return getIntAttr(value, getContext());
1595  }
1596 
1597  // Constant fold
1598  return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1599 }
1600 
1601 LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1602  if (hasOperandsOutsideOfBlock(&*op))
1603  return failure();
1604 
1605  auto inputs = op.getInputs();
1606  auto size = inputs.size();
1607  assert(size > 1 && "expected 2 or more operands");
1608 
1609  APInt value, value2;
1610 
1611  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1612  if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1613  value.isPowerOf2()) {
1614  auto shift = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(),
1615  value.exactLogBase2());
1616  auto shlOp =
1617  rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift, false);
1618 
1619  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1620  ArrayRef<Value>(shlOp), false);
1621  return success();
1622  }
1623 
1624  // mul(..., 1) -> mul(...) -- identity
1625  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1626  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1627  inputs.drop_back());
1628  return success();
1629  }
1630 
1631  // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1632  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1633  matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1634  auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value * value2);
1635  SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1636  newOperands.push_back(cst);
1637  replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1638  newOperands);
1639  return success();
1640  }
1641 
1642  // mul(a, mul(...)) -> mul(a, ...) -- flatten
1643  if (tryFlatteningOperands(op, rewriter))
1644  return success();
1645 
1646  // extracts only of mul(...) -> mul(extract()...)
1647  if (narrowOperationWidth(op, false, rewriter))
1648  return success();
1649 
1650  return failure();
1651 }
1652 
1653 template <class Op, bool isSigned>
1654 static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1655  if (auto rhsValue = constants[1].dyn_cast_or_null<IntegerAttr>()) {
1656  // divu(x, 1) -> x, divs(x, 1) -> x
1657  if (rhsValue.getValue() == 1)
1658  return op.getLhs();
1659 
1660  // If the divisor is zero, do not fold for now.
1661  if (rhsValue.getValue().isZero())
1662  return {};
1663  }
1664 
1665  return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1666 }
1667 
1668 OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1669  if (hasOperandsOutsideOfBlock(getOperation()))
1670  return {};
1671 
1672  return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1673 }
1674 
1675 OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1676  if (hasOperandsOutsideOfBlock(getOperation()))
1677  return {};
1678 
1679  return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1680 }
1681 
1682 template <class Op, bool isSigned>
1683 static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1684  if (auto rhsValue = constants[1].dyn_cast_or_null<IntegerAttr>()) {
1685  // modu(x, 1) -> 0, mods(x, 1) -> 0
1686  if (rhsValue.getValue() == 1)
1687  return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1688  op.getContext());
1689 
1690  // If the divisor is zero, do not fold for now.
1691  if (rhsValue.getValue().isZero())
1692  return {};
1693  }
1694 
1695  if (auto lhsValue = constants[0].dyn_cast_or_null<IntegerAttr>()) {
1696  // modu(0, x) -> 0, mods(0, x) -> 0
1697  if (lhsValue.getValue().isZero())
1698  return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1699  op.getContext());
1700  }
1701 
1702  return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1703 }
1704 
1705 OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1706  if (hasOperandsOutsideOfBlock(getOperation()))
1707  return {};
1708 
1709  return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1710 }
1711 
1712 OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1713  if (hasOperandsOutsideOfBlock(getOperation()))
1714  return {};
1715 
1716  return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1717 }
1718 //===----------------------------------------------------------------------===//
1719 // ConcatOp
1720 //===----------------------------------------------------------------------===//
1721 
1722 // Constant folding
1723 OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1724  if (hasOperandsOutsideOfBlock(getOperation()))
1725  return {};
1726 
1727  if (getNumOperands() == 1)
1728  return getOperand(0);
1729 
1730  // If all the operands are constant, we can fold.
1731  for (auto attr : adaptor.getInputs())
1732  if (!attr || !attr.isa<IntegerAttr>())
1733  return {};
1734 
1735  // If we got here, we can constant fold.
1736  unsigned resultWidth = getType().getIntOrFloatBitWidth();
1737  APInt result(resultWidth, 0);
1738 
1739  unsigned nextInsertion = resultWidth;
1740  // Insert each chunk into the result.
1741  for (auto attr : adaptor.getInputs()) {
1742  auto chunk = attr.cast<IntegerAttr>().getValue();
1743  nextInsertion -= chunk.getBitWidth();
1744  result.insertBits(chunk, nextInsertion);
1745  }
1746 
1747  return getIntAttr(result, getContext());
1748 }
1749 
/// Canonicalize a concat by searching its operands, left to right, for one
/// local simplification. The first one found replaces `op` and returns
/// success immediately, so at most one rewrite is applied per call.
///
/// Rewrites tried, in order, for each operand position `i`:
///  - an operand that is itself a concat is spliced in (flattened);
///  - two adjacent hw.constants are merged into one wider constant;
///  - two adjacent identical operands become `replicate(x, 2)`;
///  - `x, repl(x, n)`, `repl(x, n), repl(x, m)` and `repl(x, n), x` are merged
///    into a single replicate;
///  - adjacent extracts of contiguous bit ranges of the same input are merged
///    into one extract;
///  - adjacent array_get / bitcast(array_slice) of contiguous ranges of the
///    same array are merged into one bitcast(array_slice).
/// If none applies and every operand is the same value, the concat becomes a
/// replicate of that value.
LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // This function is used when we flatten neighboring operands of a
  // (variadic) concat into a new version of the concat. first/last indices
  // are inclusive. If only one operand would remain, the concat is replaced
  // by that operand directly.
  auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
                           ValueRange replacements) -> LogicalResult {
    SmallVector<Value, 4> newOperands;
    newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
    newOperands.append(replacements.begin(), replacements.end());
    newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
    if (newOperands.size() == 1)
      replaceOpAndCopyName(rewriter, op, newOperands[0]);
    else
      replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
                                              newOperands);
    return success();
  };

  // Tracks whether every operand seen so far is the same value; cleared to a
  // null Value on the first mismatch.
  Value commonOperand = inputs[0];
  for (size_t i = 0; i != size; ++i) {
    // Check to see if all operands are the same.
    if (inputs[i] != commonOperand)
      commonOperand = Value();

    // If an operand to the concat is itself a concat, then we can fold them
    // together.
    if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
      return flattenConcat(i, i, subConcat->getOperands());

    // Check for canonicalization due to neighboring operands. inputs[i - 1]
    // is the more significant of the pair.
    if (i != 0) {
      // Merge neighboring constants; the previous constant supplies the high
      // bits of the merged value.
      if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
        if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
          unsigned prevWidth = prevCst.getValue().getBitWidth();
          unsigned thisWidth = cst.getValue().getBitWidth();
          auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
          resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
                       << thisWidth;
          Value replacement =
              rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
          return flattenConcat(i - 1, i, replacement);
        }
      }

      // If the two operands are the same, turn them into a replicate.
      if (inputs[i] == inputs[i - 1]) {
        Value replacement =
            rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
        return flattenConcat(i - 1, i, replacement);
      }

      // If this input is a replicate, see if we can fold it with the previous
      // one.
      if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
        // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
        if (repl.getOperand() == inputs[i - 1]) {
          Value replacement = rewriter.createOrFold<ReplicateOp>(
              op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
          return flattenConcat(i - 1, i, replacement);
        }
        // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
        if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
          if (prevRepl.getOperand() == repl.getOperand()) {
            Value replacement = rewriter.createOrFold<ReplicateOp>(
                op.getLoc(), repl.getOperand(),
                repl.getMultiple() + prevRepl.getMultiple());
            return flattenConcat(i - 1, i, replacement);
          }
        }
      }

      // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
      if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
        if (repl.getOperand() == inputs[i]) {
          Value replacement = rewriter.createOrFold<ReplicateOp>(
              op.getLoc(), inputs[i], repl.getMultiple() + 1);
          return flattenConcat(i - 1, i, replacement);
        }
      }

      // Merge neighboring extracts of neighboring inputs, e.g.
      // {A[3], A[2]} -> A[3:2]
      // The previous (more significant) extract must start exactly where this
      // one ends for the two ranges to be contiguous.
      if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
        if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
          if (extract.getInput() == prevExtract.getInput()) {
            auto thisWidth = extract.getType().cast<IntegerType>().getWidth();
            if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
              auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
              auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
              Value replacement = rewriter.create<ExtractOp>(
                  op.getLoc(), resType, extract.getInput(),
                  extract.getLowBit());
              return flattenConcat(i - 1, i, replacement);
            }
          }
        }
      }
      // Merge neighboring array extracts of neighboring inputs, e.g.
      // {Array[4], bitcast(Array[3:2])} -> bitcast(Array[4:2])

      // This represents a slice of an array: `width` elements of `input`
      // starting at element `index`. A single array_get is a slice of width 1.
      struct ArraySlice {
        Value input;
        Value index;
        size_t width;
        static std::optional<ArraySlice> get(Value value) {
          assert(value.getType().isa<IntegerType>() && "expected integer type");
          if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
            return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
          // array slice op is wrapped with bitcast.
          if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
            if (auto arraySlice =
                    bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
              return ArraySlice{
                  arraySlice.getInput(), arraySlice.getLowIndex(),
                  hw::type_cast<hw::ArrayType>(arraySlice.getType())
                      .getNumElements()};
          return std::nullopt;
        }
      };
      if (auto extractOpt = ArraySlice::get(inputs[i])) {
        if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
          // Check that two array slices are mergeable: same array, same index
          // type, and the previous slice's index is offset from this one's by
          // this slice's width (checked via hw::isOffset).
          if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
              prevExtractOpt->input == extractOpt->input &&
              hw::isOffset(extractOpt->index, prevExtractOpt->index,
                           extractOpt->width)) {
            auto resType = hw::ArrayType::get(
                hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
                    .getElementType(),
                extractOpt->width + prevExtractOpt->width);
            auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
            Value replacement = rewriter.create<hw::BitcastOp>(
                op.getLoc(), resIntType,
                rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
                                                  prevExtractOpt->input,
                                                  extractOpt->index));
            return flattenConcat(i - 1, i, replacement);
          }
        }
      }
    }
  }

  // If all operands were the same, then this is a replicate.
  if (commonOperand) {
    replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
                                               commonOperand);
    return success();
  }

  return failure();
}
1911 
1912 //===----------------------------------------------------------------------===//
1913 // MuxOp
1914 //===----------------------------------------------------------------------===//
1915 
1916 OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1917  if (hasOperandsOutsideOfBlock(getOperation()))
1918  return {};
1919 
1920  // mux (c, b, b) -> b
1921  if (getTrueValue() == getFalseValue())
1922  return getTrueValue();
1923 
1924  // mux(0, a, b) -> b
1925  // mux(1, a, b) -> a
1926  if (auto pred = adaptor.getCond().dyn_cast_or_null<IntegerAttr>()) {
1927  if (pred.getValue().isZero())
1928  return getFalseValue();
1929  return getTrueValue();
1930  }
1931 
1932  // mux(cond, 1, 0) -> cond
1933  if (auto tv = adaptor.getTrueValue().dyn_cast_or_null<IntegerAttr>())
1934  if (auto fv = adaptor.getFalseValue().dyn_cast_or_null<IntegerAttr>())
1935  if (tv.getValue().isOne() && fv.getValue().isZero() &&
1936  hw::getBitWidth(getType()) == 1)
1937  return getCond();
1938 
1939  return {};
1940 }
1941 
1942 /// Check to see if the condition to the specified mux is an equality
1943 /// comparison `indexValue` and one or more constants. If so, put the
1944 /// constants in the constants vector and return true, otherwise return false.
1945 ///
1946 /// This is part of foldMuxChain.
1947 ///
1948 static bool
1949 getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1950  std::function<void(hw::ConstantOp)> constantFn) {
1951  // Handle `idx == 42` and `idx != 42`.
1952  if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1953  // TODO: We could handle things like "x < 2" as two entries.
1954  auto requiredPredicate =
1955  (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1956  if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1957  if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1958  constantFn(cst);
1959  return true;
1960  }
1961  }
1962  return false;
1963  }
1964 
1965  // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1966  if (auto orOp = cond.getDefiningOp<OrOp>()) {
1967  if (!isInverted)
1968  return false;
1969  for (auto operand : orOp.getOperands())
1970  if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1971  return false;
1972  return true;
1973  }
1974 
1975  // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1976  if (auto andOp = cond.getDefiningOp<AndOp>()) {
1977  if (isInverted)
1978  return false;
1979  for (auto operand : andOp.getOperands())
1980  if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1981  return false;
1982  return true;
1983  }
1984 
1985  return false;
1986 }
1987 
/// Given a mux, check to see if the "on true" value (or "on false" value if
/// isFalseSide=true) is a mux tree with the same condition. This allows us
/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
/// allows synthesis tools to do more interesting optimizations.
///
/// When isFalseSide is true, the chain continues through false operands and
/// each case value is the true operand (conditions must be `idx == C`). When
/// it is false, the chain continues through true operands and each case value
/// is the false operand (conditions must be `idx != C`).
///
/// This returns false if we cannot form the mux tree (or do not want to) and
/// returns true if the mux was replaced.
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
                         PatternRewriter &rewriter) {
  // Get the index value being compared. Later we check to see if it is
  // compared to a constant with the right predicate.
  auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
  if (!rootCmp)
    return false;
  Value indexValue = rootCmp.getLhs();

  // Return the value to use if the equality match succeeds.
  auto getCaseValue = [&](MuxOp mux) -> Value {
    return mux.getOperand(1 + unsigned(!isFalseSide));
  };

  // Return the value to use if the equality match fails. This is the next
  // mux in the sequence or the "otherwise" value.
  auto getTreeValue = [&](MuxOp mux) -> Value {
    return mux.getOperand(1 + unsigned(isFalseSide));
  };

  // Start scanning the mux tree to see what we've got. Keep track of the
  // constant comparison value and the SSA value to use when equal to it.
  SmallVector<Location> locationsFound;
  SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;

  /// Extract constants and values into `valuesFound` and return true if this is
  /// part of the mux tree, otherwise return false.
  auto collectConstantValues = [&](MuxOp mux) -> bool {
    return getMuxChainCondConstant(
        mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
          valuesFound.push_back({cst, getCaseValue(mux)});
          locationsFound.push_back(mux.getCond().getLoc());
          locationsFound.push_back(mux->getLoc());
        });
  };

  // Make sure the root is a correct comparison with a constant.
  if (!collectConstantValues(rootMux))
    return false;

  // Make sure that we're not looking at the intermediate node in a mux tree:
  // if our only user is a matching mux that chains through us, let the fold
  // start from that user instead.
  if (rootMux->hasOneUse()) {
    if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
      if (getTreeValue(userMux) == rootMux.getResult() &&
          getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
                                  [&](hw::ConstantOp cst) {}))
        return false;
    }
  }

  // Walk the chain from the root toward the leaf, absorbing each
  // single-use mux whose condition matches. The first value that is not such
  // a mux becomes the "otherwise" value.
  auto nextTreeValue = getTreeValue(rootMux);
  while (1) {
    auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
    if (!nextMux || !nextMux->hasOneUse())
      break;
    if (!collectConstantValues(nextMux))
      break;
    nextTreeValue = getTreeValue(nextMux);
  }

  // We need at least three values to create an array. This is an arbitrary
  // threshold which is saying that one or two muxes together is ok, but three
  // should be folded.
  if (valuesFound.size() < 3)
    return false;

  // If the index is 9 bits or wider, the array would have 512 or more
  // elements, which is too large for a single expression.
  auto indexWidth = indexValue.getType().cast<IntegerType>().getWidth();
  if (indexWidth >= 9)
    return false;

  // Next we need to see if the values are dense-ish. We don't want to have
  // a tremendous number of replicated entries in the array. Some sparsity is
  // ok though, so we require the table to be at least 5/8 utilized.
  uint64_t tableSize = 1ULL << indexWidth;
  if (valuesFound.size() < (tableSize * 5) / 8)
    return false; // Not dense enough.

  // Ok, we're going to do the transformation, start by building the table
  // filled with the "otherwise" value.
  SmallVector<Value, 8> table(tableSize, nextTreeValue);

  // Fill in entries in the table from the leaf to the root of the expression.
  // This ensures that any duplicate matches end up with the ultimate value,
  // which is the one closer to the root.
  for (auto &elt : llvm::reverse(valuesFound)) {
    uint64_t idx = elt.first.getValue().getZExtValue();
    assert(idx < table.size() && "constant should be same bitwidth as index");
    table[idx] = elt.second;
  }

  // The hw.array_create operation has the operand list in unintuitive order
  // with a[0] stored as the last element, not the first.
  std::reverse(table.begin(), table.end());

  // Build the array_create and the array_get.
  auto fusedLoc = rewriter.getFusedLoc(locationsFound);
  auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
  replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
                                                indexValue);
  return true;
}
2100 
2101 /// Given a fully associative variadic operation like (a+b+c+d), break the
2102 /// expression into two parts, one without the specified operand (e.g.
2103 /// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2104 /// `tmp+c`), and return the inner expression.
2105 ///
2106 /// NOTE: This mutates the operation in place if it only has a single user,
2107 /// which assumes that user will be removed.
2108 ///
2109 static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2110  size_t operandNo,
2111  PatternRewriter &rewriter) {
2112  assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2113  assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2114 
2115  // If this expression already has two operands (the common case) no splitting
2116  // is necessary.
2117  if (fullyAssoc->getNumOperands() == 2)
2118  return fullyAssoc->getOperand(operandNo ^ 1);
2119 
2120  // If the operation has a single use, mutate it in place.
2121  if (fullyAssoc->hasOneUse()) {
2122  fullyAssoc->eraseOperand(operandNo);
2123  return fullyAssoc->getResult(0);
2124  }
2125 
2126  // Form the new operation with the operands that remain.
2127  SmallVector<Value> operands;
2128  operands.append(fullyAssoc->getOperands().begin(),
2129  fullyAssoc->getOperands().begin() + operandNo);
2130  operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2131  fullyAssoc->getOperands().end());
2132  Value opWithoutExcluded = createGenericOp(
2133  fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2134  Value excluded = fullyAssoc->getOperand(operandNo);
2135 
2136  Value fullResult =
2137  createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2138  ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2139  replaceOpAndCopyName(rewriter, fullyAssoc, fullResult);
2140  return opWithoutExcluded;
2141 }
2142 
/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
/// and `mux(cond, a, x|y|z|a)` -> `(x|y|z)&replicate(~cond) | a`.
///
/// `isTrueOperand` says which arm holds the common value `a`. When true, `a`
/// is the mux's true value and the false value is the subexpression searched
/// for `a`. When false, the roles are swapped. Returns true on a successful
/// transformation, false if not.
///
/// These are various forms of "predicated ops" that can be handled with a
/// replicate/and combination. A nested mux sharing the common value is also
/// handled, by merging the two conditions.
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
                               PatternRewriter &rewriter) {
  // Check to see the operand in question is an operation. If it is a port,
  // we can't simplify it.
  Operation *subExpr =
      (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
  if (!subExpr || subExpr->getNumOperands() < 2)
    return false;

  // If this isn't an operation we can handle, don't spend energy on it.
  if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
    return false;

  // Check to see if the common value occurs in the operand list for the
  // subexpression op. If so, then we can simplify it.
  Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
  size_t opNo = 0, e = subExpr->getNumOperands();
  while (opNo != e && subExpr->getOperand(opNo) != commonValue)
    ++opNo;
  if (opNo == e)
    return false;

  // If we got a hit, then go ahead and simplify it!
  Value cond = op.getCond();

  // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
  // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
  // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
  // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
  if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
    Value otherValue;
    Value subCond = subMux.getCond();

    // Invert the subCond if needed and dig out the 'b' value.
    if (subMux.getTrueValue() == commonValue)
      otherValue = subMux.getFalseValue();
    else if (subMux.getFalseValue() == commonValue) {
      otherValue = subMux.getTrueValue();
      subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
    } else {
      // We can't fold `mux(cond, a, mux(a, x, y))`.
      return false;
    }

    // Invert the outer cond if needed, and combine the mux conditions.
    if (!isTrueOperand)
      cond = createOrFoldNot(op.getLoc(), cond, rewriter);
    cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
    replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
                                         otherValue, op.getTwoState());
    return true;
  }

  // Invert the condition if needed. Or/Xor invert when dealing with
  // TrueOperand, And inverts for False operand.
  bool isaAndOp = isa<AndOp>(subExpr);
  if (isTrueOperand ^ isaAndOp)
    cond = createOrFoldNot(op.getLoc(), cond, rewriter);

  // Widen the one-bit condition to a mask of the mux's result type.
  auto extendedCond =
      rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);

  // Cache this information before subExpr may be mutated or erased by the
  // extraction below.
  bool isaXorOp = isa<XorOp>(subExpr);
  bool isaOrOp = isa<OrOp>(subExpr);

  // Handle the fully associative ops, start by pulling out the subexpression
  // from a many operand version of the op.
  auto restOfAssoc =
      extractOperandFromFullyAssociative(subExpr, opNo, rewriter);

  // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
  // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
  if (isaOrOp || isaXorOp) {
    auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
                                               restOfAssoc, false);
    if (isaXorOp)
      replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
                                           false);
    else
      replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
                                          false);
    return true;
  }

  // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
  assert(isaAndOp && "unexpected operation here");
  auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
                                            restOfAssoc, false);
  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
                                       false);
  return true;
}
2242 
2243 /// This function is invoke when we find a mux with true/false operations that
2244 /// have the same opcode. Check to see if we can strength reduce the mux by
2245 /// applying it to less data by applying this transformation:
2246 /// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2247 static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2248  Operation *falseOp,
2249  PatternRewriter &rewriter) {
2250  // Right now we only apply to concat.
2251  // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2252  if (!isa<ConcatOp>(trueOp))
2253  return false;
2254 
2255  // Decode the operands, looking through recursive concats and replicates.
2256  SmallVector<Value> trueOperands, falseOperands;
2257  getConcatOperands(trueOp->getResult(0), trueOperands);
2258  getConcatOperands(falseOp->getResult(0), falseOperands);
2259 
2260  size_t numTrueOperands = trueOperands.size();
2261  size_t numFalseOperands = falseOperands.size();
2262 
2263  if (!numTrueOperands || !numFalseOperands ||
2264  (trueOperands.front() != falseOperands.front() &&
2265  trueOperands.back() != falseOperands.back()))
2266  return false;
2267 
2268  // Pull all leading shared operands out into their own op if any are common.
2269  if (trueOperands.front() == falseOperands.front()) {
2270  SmallVector<Value> operands;
2271  size_t i;
2272  for (i = 0; i < numTrueOperands; ++i) {
2273  Value trueOperand = trueOperands[i];
2274  if (trueOperand == falseOperands[i])
2275  operands.push_back(trueOperand);
2276  else
2277  break;
2278  }
2279  if (i == numTrueOperands) {
2280  // Selecting between distinct, but lexically identical, concats.
2281  replaceOpAndCopyName(rewriter, mux, trueOp->getResult(0));
2282  return true;
2283  }
2284 
2285  Value sharedMSB;
2286  if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2287  sharedMSB = rewriter.createOrFold<ReplicateOp>(
2288  mux->getLoc(), operands.front(), operands.size());
2289  else
2290  sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2291  operands.clear();
2292 
2293  // Get a concat of the LSB's on each side.
2294  operands.append(trueOperands.begin() + i, trueOperands.end());
2295  Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2296  operands.clear();
2297  operands.append(falseOperands.begin() + i, falseOperands.end());
2298  Value falseLSB =
2299  rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2300  // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2301  // done.
2302  Value lsb = rewriter.createOrFold<MuxOp>(
2303  mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2304  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2305  return true;
2306  }
2307 
2308  // If trailing operands match, try to commonize them.
2309  if (trueOperands.back() == falseOperands.back()) {
2310  SmallVector<Value> operands;
2311  size_t i;
2312  for (i = 0;; ++i) {
2313  Value trueOperand = trueOperands[numTrueOperands - i - 1];
2314  if (trueOperand == falseOperands[numFalseOperands - i - 1])
2315  operands.push_back(trueOperand);
2316  else
2317  break;
2318  }
2319  std::reverse(operands.begin(), operands.end());
2320  Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2321  operands.clear();
2322 
2323  // Get a concat of the MSB's on each side.
2324  operands.append(trueOperands.begin(), trueOperands.end() - i);
2325  Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2326  operands.clear();
2327  operands.append(falseOperands.begin(), falseOperands.end() - i);
2328  Value falseMSB =
2329  rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2330  // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2331  Value msb = rewriter.createOrFold<MuxOp>(
2332  mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2333  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2334  return true;
2335  }
2336 
2337  return false;
2338 }
2339 
2340 // If both arguments of the mux are arrays with the same elements, sink the
2341 // mux and return a uniform array initializing all elements to it.
2342 static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2343  auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2344  auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2345  if (!trueVec || !falseVec)
2346  return false;
2347  if (!trueVec.isUniform() || !falseVec.isUniform())
2348  return false;
2349 
2350  auto mux = rewriter.create<MuxOp>(
2351  op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2352  falseVec.getUniformElement(), op.getTwoState());
2353 
2354  SmallVector<Value> values(trueVec.getInputs().size(), mux);
2355  rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2356  return true;
2357 }
2358 
2359 namespace {
2360 struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2361  using OpRewritePattern::OpRewritePattern;
2362 
2363  LogicalResult matchAndRewrite(MuxOp op,
2364  PatternRewriter &rewriter) const override;
2365 };
2366 
2367 LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2368  PatternRewriter &rewriter) const {
2369  if (hasOperandsOutsideOfBlock(&*op))
2370  return failure();
2371 
2372  // If the op has a SV attribute, don't optimize it.
2373  if (hasSVAttributes(op))
2374  return failure();
2375  APInt value;
2376 
2377  if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2378  if (value.getBitWidth() == 1) {
2379  // mux(a, 0, b) -> and(~a, b) for single-bit values.
2380  if (value.isZero()) {
2381  auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2382  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2383  op.getFalseValue(), false);
2384  return success();
2385  }
2386 
2387  // mux(a, 1, b) -> or(a, b) for single-bit values.
2388  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2389  op.getFalseValue(), false);
2390  return success();
2391  }
2392 
2393  // Check for mux of two constants. There are many ways to simplify them.
2394  APInt value2;
2395  if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2396  // When both inputs are constants and differ by only one bit, we can
2397  // simplify by splitting the mux into up to three contiguous chunks: one
2398  // for the differing bit and up to two for the bits that are the same.
2399  // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2400  APInt xorValue = value ^ value2;
2401  if (xorValue.isPowerOf2()) {
2402  unsigned leadingZeros = xorValue.countLeadingZeros();
2403  unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2404  SmallVector<Value, 3> operands;
2405 
2406  // Concat operands go from MSB to LSB, so we handle chunks in reverse
2407  // order of bit indexes.
2408  // For the chunks that are identical (i.e. correspond to 0s in
2409  // xorValue), we can extract directly from either input value, and we
2410  // arbitrarily pick the trueValue().
2411 
2412  if (leadingZeros > 0)
2413  operands.push_back(rewriter.createOrFold<ExtractOp>(
2414  op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2415 
2416  // Handle the differing bit, which should simplify into either cond or
2417  // ~cond.
2418  auto v1 = rewriter.createOrFold<ExtractOp>(
2419  op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2420  auto v2 = rewriter.createOrFold<ExtractOp>(
2421  op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2422  operands.push_back(rewriter.createOrFold<MuxOp>(
2423  op.getLoc(), op.getCond(), v1, v2, false));
2424 
2425  if (trailingZeros > 0)
2426  operands.push_back(rewriter.createOrFold<ExtractOp>(
2427  op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2428 
2429  replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2430  operands);
2431  return success();
2432  }
2433 
2434  // If the true value is all ones and the false is all zeros then we have a
2435  // replicate pattern.
2436  if (value.isAllOnes() && value2.isZero()) {
2437  replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2438  op.getCond());
2439  return success();
2440  }
2441  }
2442  }
2443 
2444  if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2445  value.getBitWidth() == 1) {
2446  // mux(a, b, 0) -> and(a, b) for single-bit values.
2447  if (value.isZero()) {
2448  replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2449  op.getTrueValue(), false);
2450  return success();
2451  }
2452 
2453  // mux(a, b, 1) -> or(~a, b) for single-bit values.
2454  // falseValue() is known to be a single-bit 1, which we can use for
2455  // the 1 in the representation of ~ using xor.
2456  auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2457  op.getFalseValue(), false);
2458  replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2459  op.getTrueValue(), false);
2460  return success();
2461  }
2462 
2463  // mux(!a, b, c) -> mux(a, c, b)
2464  Value subExpr;
2465  Operation *condOp = op.getCond().getDefiningOp();
2466  if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2467  op.getTwoState()) {
2468  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2469  op.getFalseValue(), op.getTrueValue(),
2470  true);
2471  return success();
2472  }
2473 
2474  // Same but with Demorgan's law.
2475  // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2476  // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2477  if (condOp && condOp->hasOneUse()) {
2478  SmallVector<Value> invertedOperands;
2479 
2480  /// Scan all the operands to see if they are complemented. If so, build a
2481  /// vector of them and return true, otherwise return false.
2482  auto getInvertedOperands = [&]() -> bool {
2483  for (Value operand : condOp->getOperands()) {
2484  if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2485  invertedOperands.push_back(subExpr);
2486  else
2487  return false;
2488  }
2489  return true;
2490  };
2491 
2492  if (isa<AndOp>(condOp) && getInvertedOperands()) {
2493  auto newOr =
2494  rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2495  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2496  op.getFalseValue(),
2497  op.getTrueValue(), op.getTwoState());
2498  return success();
2499  }
2500  if (isa<OrOp>(condOp) && getInvertedOperands()) {
2501  auto newAnd =
2502  rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2503  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2504  op.getFalseValue(),
2505  op.getTrueValue(), op.getTwoState());
2506  return success();
2507  }
2508  }
2509 
2510  if (auto falseMux =
2511  dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp())) {
2512  // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2513  if (op.getCond() == falseMux.getCond()) {
2514  replaceOpWithNewOpAndCopyName<MuxOp>(
2515  rewriter, op, op.getCond(), op.getTrueValue(),
2516  falseMux.getFalseValue(), op.getTwoStateAttr());
2517  return success();
2518  }
2519 
2520  // Check to see if we can fold a mux tree into an array_create/get pair.
2521  if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2522  return success();
2523  }
2524 
2525  if (auto trueMux =
2526  dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp())) {
2527  // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2528  if (op.getCond() == trueMux.getCond()) {
2529  replaceOpWithNewOpAndCopyName<MuxOp>(
2530  rewriter, op, op.getCond(), trueMux.getTrueValue(),
2531  op.getFalseValue(), op.getTwoStateAttr());
2532  return success();
2533  }
2534 
2535  // Check to see if we can fold a mux tree into an array_create/get pair.
2536  if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2537  return success();
2538  }
2539 
2540  // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2541  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2542  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2543  trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2544  trueMux.getTrueValue() == falseMux.getTrueValue()) {
2545  auto subMux = rewriter.create<MuxOp>(
2546  rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2547  op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2548  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2549  trueMux.getTrueValue(), subMux,
2550  op.getTwoStateAttr());
2551  return success();
2552  }
2553 
2554  // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2555  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2556  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2557  trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2558  trueMux.getFalseValue() == falseMux.getFalseValue()) {
2559  auto subMux = rewriter.create<MuxOp>(
2560  rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2561  op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2562  replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2563  subMux, trueMux.getFalseValue(),
2564  op.getTwoStateAttr());
2565  return success();
2566  }
2567 
2568  // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2569  if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2570  falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2571  trueMux && falseMux &&
2572  trueMux.getTrueValue() == falseMux.getTrueValue() &&
2573  trueMux.getFalseValue() == falseMux.getFalseValue()) {
2574  auto subMux = rewriter.create<MuxOp>(
2575  rewriter.getFusedLoc(
2576  {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2577  op.getCond(), trueMux.getCond(), falseMux.getCond());
2578  replaceOpWithNewOpAndCopyName<MuxOp>(
2579  rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2580  op.getTwoStateAttr());
2581  return success();
2582  }
2583 
2584  // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2585  if (foldCommonMuxValue(op, false, rewriter))
2586  return success();
2587  // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2588  if (foldCommonMuxValue(op, true, rewriter))
2589  return success();
2590 
2591  // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2592  if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2593  if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2594  if (trueOp->getName() == falseOp->getName())
2595  if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2596  return success();
2597 
2598  // extracts only of mux(...) -> mux(extract()...)
2599  if (narrowOperationWidth(op, true, rewriter))
2600  return success();
2601 
2602  // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2603  if (foldMuxOfUniformArrays(op, rewriter))
2604  return success();
2605 
2606  return failure();
2607 }
2608 
2609 static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2610  // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2611  if (op.getInputs().empty() || op.isUniform())
2612  return false;
2613  auto inputs = op.getInputs();
2614  if (inputs.size() <= 1)
2615  return false;
2616 
2617  // Check the operands to the array create. Ensure all of them are the
2618  // same op with the same number of operands.
2619  auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2620  if (!first || hasSVAttributes(first))
2621  return false;
2622 
2623  // Check whether all operands are muxes with the same condition.
2624  for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2625  auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2626  if (!input || first.getCond() != input.getCond())
2627  return false;
2628  }
2629 
2630  // Collect the true and the false branches into arrays.
2631  SmallVector<Value> trues{first.getTrueValue()};
2632  SmallVector<Value> falses{first.getFalseValue()};
2633  SmallVector<Location> locs{first->getLoc()};
2634  bool isTwoState = true;
2635  for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2636  auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2637  trues.push_back(input.getTrueValue());
2638  falses.push_back(input.getFalseValue());
2639  locs.push_back(input->getLoc());
2640  if (!input.getTwoState())
2641  isTwoState = false;
2642  }
2643 
2644  // Define the location of the array create as the aggregate of all muxes.
2645  auto loc = FusedLoc::get(op.getContext(), locs);
2646 
2647  // Replace the create with an aggregate operation. Push the create op
2648  // into the operands of the aggregate operation.
2649  auto arrayTy = op.getType();
2650  auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
2651  auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
2652  rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2653  trueValues, falseValues, isTwoState);
2654  return true;
2655 }
2656 
2657 struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2658  using OpRewritePattern::OpRewritePattern;
2659 
2660  LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2661  PatternRewriter &rewriter) const override {
2662  if (hasOperandsOutsideOfBlock(&*op))
2663  return failure();
2664 
2665  if (foldArrayOfMuxes(op, rewriter))
2666  return success();
2667  return failure();
2668  }
2669 };
2670 
2671 } // namespace
2672 
2673 void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2674  MLIRContext *context) {
2675  results.insert<MuxRewriter, ArrayRewriter>(context);
2676 }
2677 
2678 //===----------------------------------------------------------------------===//
2679 // ICmpOp
2680 //===----------------------------------------------------------------------===//
2681 
2682 // Calculate the result of a comparison when the LHS and RHS are both
2683 // constants.
2684 static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2685  const APInt &rhs) {
2686  switch (predicate) {
2687  case ICmpPredicate::eq:
2688  return lhs.eq(rhs);
2689  case ICmpPredicate::ne:
2690  return lhs.ne(rhs);
2691  case ICmpPredicate::slt:
2692  return lhs.slt(rhs);
2693  case ICmpPredicate::sle:
2694  return lhs.sle(rhs);
2695  case ICmpPredicate::sgt:
2696  return lhs.sgt(rhs);
2697  case ICmpPredicate::sge:
2698  return lhs.sge(rhs);
2699  case ICmpPredicate::ult:
2700  return lhs.ult(rhs);
2701  case ICmpPredicate::ule:
2702  return lhs.ule(rhs);
2703  case ICmpPredicate::ugt:
2704  return lhs.ugt(rhs);
2705  case ICmpPredicate::uge:
2706  return lhs.uge(rhs);
2707  case ICmpPredicate::ceq:
2708  return lhs.eq(rhs);
2709  case ICmpPredicate::cne:
2710  return lhs.ne(rhs);
2711  case ICmpPredicate::weq:
2712  return lhs.eq(rhs);
2713  case ICmpPredicate::wne:
2714  return lhs.ne(rhs);
2715  }
2716  llvm_unreachable("unknown comparison predicate");
2717 }
2718 
2719 // Returns the result of applying the predicate when the LHS and RHS are the
2720 // exact same value.
2721 static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2722  switch (predicate) {
2723  case ICmpPredicate::eq:
2724  case ICmpPredicate::sle:
2725  case ICmpPredicate::sge:
2726  case ICmpPredicate::ule:
2727  case ICmpPredicate::uge:
2728  case ICmpPredicate::ceq:
2729  case ICmpPredicate::weq:
2730  return true;
2731  case ICmpPredicate::ne:
2732  case ICmpPredicate::slt:
2733  case ICmpPredicate::sgt:
2734  case ICmpPredicate::ult:
2735  case ICmpPredicate::ugt:
2736  case ICmpPredicate::cne:
2737  case ICmpPredicate::wne:
2738  return false;
2739  }
2740  llvm_unreachable("unknown comparison predicate");
2741 }
2742 
2743 OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2744  if (hasOperandsOutsideOfBlock(getOperation()))
2745  return {};
2746 
2747  // gt a, a -> false
2748  // gte a, a -> true
2749  if (getLhs() == getRhs()) {
2750  auto val = applyCmpPredicateToEqualOperands(getPredicate());
2751  return IntegerAttr::get(getType(), val);
2752  }
2753 
2754  // gt 1, 2 -> false
2755  if (auto lhs = adaptor.getLhs().dyn_cast_or_null<IntegerAttr>()) {
2756  if (auto rhs = adaptor.getRhs().dyn_cast_or_null<IntegerAttr>()) {
2757  auto val =
2758  applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2759  return IntegerAttr::get(getType(), val);
2760  }
2761  }
2762  return {};
2763 }
2764 
/// Count the leading elements on which `a` and `b` agree. The result is the
/// length of the longest common prefix. Elements are compared only
/// position-by-position, never across positions. A common suffix can be
/// measured by passing reversed ranges.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto lhsIt = a.begin();
  auto lhsEnd = a.end();
  auto rhsIt = b.begin();
  auto rhsEnd = b.end();

  size_t matched = 0;
  while (lhsIt != lhsEnd && rhsIt != rhsEnd && *lhsIt == *rhsIt) {
    ++lhsIt;
    ++rhsIt;
    ++matched;
  }
  return matched;
}
2781 
2782 static size_t getTotalWidth(ArrayRef<Value> operands) {
2783  size_t totalWidth = 0;
2784  for (auto operand : operands) {
2785  // getIntOrFloatBitWidth should never raise, since all arguments to
2786  // ConcatOp are integers.
2787  ssize_t width = operand.getType().getIntOrFloatBitWidth();
2788  assert(width >= 0);
2789  totalWidth += width;
2790  }
2791  return totalWidth;
2792 }
2793 
2794 /// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
2795 /// comparison on common prefix and suffixes. Returns success() if a rewriting
2796 /// happens. This handles both concat and replicate.
2797 static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2798  Operation *rhs,
2799  PatternRewriter &rewriter) {
2800  // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2801  // all elements have non-zero length. Both these invariants are verified
2802  // by the ConcatOp verifier.
2803  SmallVector<Value> lhsOperands, rhsOperands;
2804  getConcatOperands(lhs->getResult(0), lhsOperands);
2805  getConcatOperands(rhs->getResult(0), rhsOperands);
2806  ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2807 
2808  auto formCatOrReplicate = [&](Location loc,
2809  ArrayRef<Value> operands) -> Value {
2810  assert(!operands.empty());
2811  Value sameElement = operands[0];
2812  for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2813  if (sameElement != operands[i])
2814  sameElement = Value();
2815  if (sameElement)
2816  return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2817  operands.size());
2818  return rewriter.createOrFold<ConcatOp>(loc, operands);
2819  };
2820 
2821  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2822  Value rhs) -> LogicalResult {
2823  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2824  op.getTwoState());
2825  return success();
2826  };
2827 
2828  size_t commonPrefixLength =
2829  computeCommonPrefixLength(lhsOperands, rhsOperands);
2830  if (commonPrefixLength == lhsOperands.size()) {
2831  // cat(a, b, c) == cat(a, b, c) -> 1
2832  bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2833  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
2834  APInt(1, result));
2835  return success();
2836  }
2837 
2838  size_t commonSuffixLength = computeCommonPrefixLength(
2839  llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2840 
2841  size_t commonPrefixTotalWidth =
2842  getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2843  size_t commonSuffixTotalWidth =
2844  getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2845  auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2846  .drop_back(commonSuffixLength);
2847  auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2848  .drop_back(commonSuffixLength);
2849 
2850  auto replaceWithoutReplicatingSignBit = [&]() {
2851  auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2852  auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2853  return replaceWith(op.getPredicate(), newLhs, newRhs);
2854  };
2855 
2856  auto replaceWithReplicatingSignBit = [&]() {
2857  auto firstNonEmptyValue = lhsOperands[0];
2858  auto firstNonEmptyElemWidth =
2859  firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2860  Value signBit = rewriter.createOrFold<ExtractOp>(
2861  op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2862 
2863  auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2864  auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2865  return replaceWith(op.getPredicate(), newLhs, newRhs);
2866  };
2867 
2868  if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2869  // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2870  if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2871  return replaceWithoutReplicatingSignBit();
2872 
2873  // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2874  // cat(sgn(b), ..y)) Note that we cannot perform this optimization if
2875  // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
2876  // bit. Doing the rewrite can result in an infinite loop.
2877  if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2878  return replaceWithReplicatingSignBit();
2879 
2880  } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2881  // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
2882  return replaceWithoutReplicatingSignBit();
2883  }
2884 
2885  return failure();
2886 }
2887 
2888 /// Given an equality comparison with a constant value and some operand that has
2889 /// known bits, simplify the comparison to check only the unknown bits of the
2890 /// input.
2891 ///
2892 /// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2893 /// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2894 /// `extract x[1:0] == 0`
2896  ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2897  PatternRewriter &rewriter) {
2898 
2899  // If any of the known bits disagree with any of the comparison bits, then
2900  // we can constant fold this comparison right away.
2901  APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2902  if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2903  // If we discover a mismatch then we know an "eq" comparison is false
2904  // and a "ne" comparison is true!
2905  bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2906  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2907  APInt(1, result));
2908  return;
2909  }
2910 
2911  // Check to see if we can prove the result entirely of the comparison (in
2912  // which we bail out early), otherwise build a list of values to concat and a
2913  // smaller constant to compare against.
2914  SmallVector<Value> newConcatOperands;
2915  auto newConstant = APInt::getZeroWidth();
2916 
2917  // Ok, some (maybe all) bits are known and some others may be unknown.
2918  // Extract out segments of the operand and compare against the
2919  // corresponding bits.
2920  unsigned knownMSB = bitsKnown.countLeadingOnes();
2921 
2922  Value operand = cmpOp.getLhs();
2923 
2924  // Ok, some bits are known but others are not. Extract out sequences of
2925  // bits that are unknown and compare just those bits. We work from MSB to
2926  // LSB.
2927  while (knownMSB != bitsKnown.getBitWidth()) {
2928  // Drop any high bits that are known.
2929  if (knownMSB)
2930  bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2931 
2932  // Find the span of unknown bits, and extract it.
2933  unsigned unknownBits = bitsKnown.countLeadingZeros();
2934  unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2935  auto spanOperand = rewriter.createOrFold<ExtractOp>(
2936  operand.getLoc(), operand, /*lowBit=*/lowBit,
2937  /*bitWidth=*/unknownBits);
2938  auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2939 
2940  // Add this info to the concat we're generating.
2941  newConcatOperands.push_back(spanOperand);
2942  // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
2943  // newConstant = newConstant.concat(spanConstant);
2944  if (newConstant.getBitWidth() != 0)
2945  newConstant = newConstant.concat(spanConstant);
2946  else
2947  newConstant = spanConstant;
2948 
2949  // Drop the unknown bits in prep for the next chunk.
2950  unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2951  bitsKnown = bitsKnown.trunc(newWidth);
2952  knownMSB = bitsKnown.countLeadingOnes();
2953  }
2954 
2955  // If all the operands to the concat are foldable then we have an identity
2956  // situation where all the sub-elements equal each other. This implies that
2957  // the overall result is foldable.
2958  if (newConcatOperands.empty()) {
2959  bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2960  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2961  APInt(1, result));
2962  return;
2963  }
2964 
2965  // If we have a single operand remaining, use it, otherwise form a concat.
2966  Value concatResult =
2967  rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
2968 
2969  // Form the comparison against the smaller constant.
2970  auto newConstantOp = rewriter.create<hw::ConstantOp>(
2971  cmpOp.getOperand(1).getLoc(), newConstant);
2972 
2973  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
2974  concatResult, newConstantOp,
2975  cmpOp.getTwoState());
2976 }
2977 
2978 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
2979 static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
2980  const APInt &rhs,
2981  PatternRewriter &rewriter) {
2982  auto ip = rewriter.saveInsertionPoint();
2983  rewriter.setInsertionPoint(xorOp);
2984 
2985  auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
2986  auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
2987  xorRHS.getValue() ^ rhs);
2988  Value newLHS;
2989  switch (xorOp.getNumOperands()) {
2990  case 1:
2991  // This isn't common but is defined so we need to handle it.
2992  newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
2993  APInt::getZero(rhs.getBitWidth()));
2994  break;
2995  case 2:
2996  // The binary case is the most common.
2997  newLHS = xorOp.getOperand(0);
2998  break;
2999  default:
3000  // The general case forces us to form a new xor with the remaining operands.
3001  SmallVector<Value> newOperands(xorOp.getOperands());
3002  newOperands.pop_back();
3003  newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
3004  break;
3005  }
3006 
3007  bool xorMultipleUses = !xorOp->hasOneUse();
3008 
3009  // If the xor has multiple uses (not just the compare, then we need/want to
3010  // replace them as well.
3011  if (xorMultipleUses)
3012  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3013  false);
3014 
3015  // Replace the comparison.
3016  rewriter.restoreInsertionPoint(ip);
3017  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3018  newLHS, newRHS, false);
3019 }
3020 
3021 LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3022  if (hasOperandsOutsideOfBlock(&*op))
3023  return failure();
3024 
3025  APInt lhs, rhs;
3026 
3027  // icmp 1, x -> icmp x, 1
3028  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3029  assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3030  "Should be folded");
3031  replaceOpWithNewOpAndCopyName<ICmpOp>(
3032  rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3033  op.getRhs(), op.getLhs(), op.getTwoState());
3034  return success();
3035  }
3036 
3037  // Canonicalize with RHS constant
3038  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3039  auto getConstant = [&](APInt constant) -> Value {
3040  return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(constant));
3041  };
3042 
3043  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3044  Value rhs) -> LogicalResult {
3045  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
3046  op.getTwoState());
3047  return success();
3048  };
3049 
3050  auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3051  replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
3052  APInt(1, constant));
3053  return success();
3054  };
3055 
3056  switch (op.getPredicate()) {
3057  case ICmpPredicate::slt:
3058  // x < max -> x != max
3059  if (rhs.isMaxSignedValue())
3060  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3061  // x < min -> false
3062  if (rhs.isMinSignedValue())
3063  return replaceWithConstantI1(0);
3064  // x < min+1 -> x == min
3065  if ((rhs - 1).isMinSignedValue())
3066  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3067  getConstant(rhs - 1));
3068  break;
3069  case ICmpPredicate::sgt:
3070  // x > min -> x != min
3071  if (rhs.isMinSignedValue())
3072  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3073  // x > max -> false
3074  if (rhs.isMaxSignedValue())
3075  return replaceWithConstantI1(0);
3076  // x > max-1 -> x == max
3077  if ((rhs + 1).isMaxSignedValue())
3078  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3079  getConstant(rhs + 1));
3080  break;
3081  case ICmpPredicate::ult:
3082  // x < max -> x != max
3083  if (rhs.isAllOnes())
3084  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3085  // x < min -> false
3086  if (rhs.isZero())
3087  return replaceWithConstantI1(0);
3088  // x < min+1 -> x == min
3089  if ((rhs - 1).isZero())
3090  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3091  getConstant(rhs - 1));
3092 
3093  // x < 0xE0 -> extract(x, 5..7) != 0b111
3094  if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3095  rhs.getBitWidth()) {
3096  auto numOnes = rhs.countLeadingOnes();
3097  auto smaller = rewriter.create<ExtractOp>(
3098  op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3099  return replaceWith(ICmpPredicate::ne, smaller,
3100  getConstant(APInt::getAllOnes(numOnes)));
3101  }
3102 
3103  break;
3104  case ICmpPredicate::ugt:
3105  // x > min -> x != min
3106  if (rhs.isZero())
3107  return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3108  // x > max -> false
3109  if (rhs.isAllOnes())
3110  return replaceWithConstantI1(0);
3111  // x > max-1 -> x == max
3112  if ((rhs + 1).isAllOnes())
3113  return replaceWith(ICmpPredicate::eq, op.getLhs(),
3114  getConstant(rhs + 1));
3115 
3116  // x > 0x07 -> extract(x, 3..7) != 0b00000
3117  if ((rhs + 1).isPowerOf2()) {
3118  auto numOnes = rhs.countTrailingOnes();
3119  auto newWidth = rhs.getBitWidth() - numOnes;
3120  auto smaller = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
3121  numOnes, newWidth);
3122  return replaceWith(ICmpPredicate::ne, smaller,
3123  getConstant(APInt::getZero(newWidth)));
3124  }
3125 
3126  break;
3127  case ICmpPredicate::sle:
3128  // x <= max -> true
3129  if (rhs.isMaxSignedValue())
3130  return replaceWithConstantI1(1);
3131  // x <= c -> x < (c+1)
3132  return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3133  case ICmpPredicate::sge:
3134  // x >= min -> true
3135  if (rhs.isMinSignedValue())
3136  return replaceWithConstantI1(1);
3137  // x >= c -> x > (c-1)
3138  return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3139  case ICmpPredicate::ule:
3140  // x <= max -> true
3141  if (rhs.isAllOnes())
3142  return replaceWithConstantI1(1);
3143  // x <= c -> x < (c+1)
3144  return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3145  case ICmpPredicate::uge:
3146  // x >= min -> true
3147  if (rhs.isZero())
3148  return replaceWithConstantI1(1);
3149  // x >= c -> x > (c-1)
3150  return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3151  case ICmpPredicate::eq:
3152  if (rhs.getBitWidth() == 1) {
3153  if (rhs.isZero()) {
3154  // x == 0 -> x ^ 1
3155  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3156  getConstant(APInt(1, 1)),
3157  op.getTwoState());
3158  return success();
3159  }
3160  if (rhs.isAllOnes()) {
3161  // x == 1 -> x
3162  replaceOpAndCopyName(rewriter, op, op.getLhs());
3163  return success();
3164  }
3165  }
3166  break;
3167  case ICmpPredicate::ne:
3168  if (rhs.getBitWidth() == 1) {
3169  if (rhs.isZero()) {
3170  // x != 0 -> x
3171  replaceOpAndCopyName(rewriter, op, op.getLhs());
3172  return success();
3173  }
3174  if (rhs.isAllOnes()) {
3175  // x != 1 -> x ^ 1
3176  replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
3177  getConstant(APInt(1, 1)),
3178  op.getTwoState());
3179  return success();
3180  }
3181  }
3182  break;
3183  case ICmpPredicate::ceq:
3184  case ICmpPredicate::cne:
3185  case ICmpPredicate::weq:
3186  case ICmpPredicate::wne:
3187  break;
3188  }
3189 
3190  // We have some specific optimizations for comparison with a constant that
3191  // are only supported for equality comparisons.
3192  if (op.getPredicate() == ICmpPredicate::eq ||
3193  op.getPredicate() == ICmpPredicate::ne) {
3194  // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3195  // with a smaller constant. We only support equality comparisons for
3196  // this.
3197  auto knownBits = computeKnownBits(op.getLhs());
3198  if (!knownBits.isUnknown())
3199  return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3200  rewriter),
3201  success();
3202 
3203  // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3204  // cst1^cst2).
3205  if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3206  if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3207  return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3208  success();
3209 
3210  // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3211  // all one.
3212  if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3213  if (rhs.isAllOnes() || rhs.isZero()) {
3214  auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3215  auto cst = rewriter.create<hw::ConstantOp>(
3216  op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3217  : APInt::getZero(width));
3218  replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
3219  replicateOp.getInput(), cst,
3220  op.getTwoState());
3221  return success();
3222  }
3223  }
3224  }
3225 
3226  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3227  // b), cat(c, d)). contains special handling for sign bit in signed
3228  // compressions.
3229  if (Operation *opLHS = op.getLhs().getDefiningOp())
3230  if (Operation *opRHS = op.getRhs().getDefiningOp())
3231  if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3232  isa<ConcatOp, ReplicateOp>(opRHS)) {
3233  if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3234  return success();
3235  }
3236 
3237  return failure();
3238 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition: CalyxOps.cpp:538
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
Definition: CombFolds.cpp:2342
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Definition: CombFolds.cpp:681
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
Definition: CombFolds.cpp:306
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
Definition: CombFolds.cpp:2721
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
Definition: CombFolds.cpp:716
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
Definition: CombFolds.cpp:32
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
Definition: CombFolds.cpp:216
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
Definition: CombFolds.cpp:1654
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
Definition: CombFolds.cpp:840
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition: CombFolds.cpp:58
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
Definition: CombFolds.cpp:88
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
Definition: CombFolds.cpp:2109
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
Definition: CombFolds.cpp:1949
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition: CombFolds.cpp:52
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
Definition: CombFolds.cpp:185
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
Definition: CombFolds.cpp:1305
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
Definition: CombFolds.cpp:548
static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, size_t concatIdx2, PatternRewriter &rewriter)
Simplify concat ops in an or op when a constant operand is present in either concat.
Definition: CombFolds.cpp:1083
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
Definition: CombFolds.cpp:2979
static size_t getTotalWidth(ArrayRef< Value > operands)
Definition: CombFolds.cpp:2782
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
Definition: CombFolds.cpp:2247
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
Definition: CombFolds.cpp:129
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
Definition: CombFolds.cpp:875
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
Definition: CombFolds.cpp:2684
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
Definition: CombFolds.cpp:2895
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
Definition: CombFolds.cpp:1996
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition: CombFolds.cpp:120
static bool hasSVAttributes(Operation *op)
Definition: CombFolds.cpp:103
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
Definition: CombFolds.cpp:467
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
Definition: CombFolds.cpp:1683
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
Definition: CombFolds.cpp:2768
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
Definition: CombFolds.cpp:2149
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
Definition: CombFolds.cpp:2797
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition: CombFolds.cpp:44
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition: CombFolds.cpp:73
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
int32_t width
Definition: FIRRTL.cpp:27
llvm::SmallVector< StringAttr > inputs
Builder builder
def create(low_bit, result_type, input=None)
Definition: comb.py:187
def create(elements)
Definition: hw.py:447
def create(data_type, value)
Definition: hw.py:397
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
Definition: CombOps.cpp:48
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
bool isOffset(Value base, Value index, uint64_t offset)
Definition: HWOps.cpp:1747
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
Definition: comb.py:1