CIRCT 20.0.0git
Loading...
Searching...
No Matches
CombFolds.cpp
Go to the documentation of this file.
1//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Matchers.h"
13#include "mlir/IR/PatternMatch.h"
14#include "llvm/ADT/SetVector.h"
15#include "llvm/ADT/SmallBitVector.h"
16#include "llvm/ADT/TypeSwitch.h"
17#include "llvm/Support/KnownBits.h"
18
19using namespace mlir;
20using namespace circt;
21using namespace comb;
22using namespace matchers;
23
24/// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
25/// such, cross-block dataflow is interpreted as a canonicalization barrier.
26/// This is a conservative approach which:
27/// 1. still allows for efficient canonicalization for the common CIRCT usecase
28/// of comb (comb logic nested inside single-block hw.module's)
29/// 2. allows comb operations to be used in non-HW container ops - that may use
30/// MLIR blocks and regions to represent various forms of hierarchical
31/// abstractions, thus allowing comb to compose with other dialects.
32static bool hasOperandsOutsideOfBlock(Operation *op) {
33 Block *thisBlock = op->getBlock();
34 return llvm::any_of(op->getOperands(), [&](Value operand) {
35 return operand.getParentBlock() != thisBlock;
36 });
37}
38
39/// Create a new instance of a generic operation that only has value operands,
40/// and has a single result value whose type matches the first operand.
41///
42/// This should not be used to create instances of ops with attributes or with
43/// more complicated type signatures.
44static Value createGenericOp(Location loc, OperationName name,
45 ArrayRef<Value> operands, OpBuilder &builder) {
46 OperationState state(loc, name);
47 state.addOperands(operands);
48 state.addTypes(operands[0].getType());
49 return builder.create(state)->getResult(0);
50}
51
52static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
53 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
54 value);
55}
56
57/// Flatten concat and mux operands into a vector.
58static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
59 if (auto concat = v.getDefiningOp<ConcatOp>()) {
60 for (auto op : concat.getOperands())
61 getConcatOperands(op, result);
62 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
63 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
64 getConcatOperands(repl.getOperand(), result);
65 } else {
66 result.push_back(v);
67 }
68}
69
70/// A wrapper of `PatternRewriter::replaceOp` to propagate "sv.namehint"
71/// attribute. If a replaced op has a "sv.namehint" attribute, this function
72/// propagates the name to the new value.
73static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op,
74 Value newValue) {
75 if (auto *newOp = newValue.getDefiningOp()) {
76 auto name = op->getAttrOfType<StringAttr>("sv.namehint");
77 if (name && !newOp->hasAttr("sv.namehint"))
78 rewriter.modifyOpInPlace(newOp,
79 [&] { newOp->setAttr("sv.namehint", name); });
80 }
81 rewriter.replaceOp(op, newValue);
82}
83
84/// A wrapper of `PatternRewriter::replaceOpWithNewOp` to propagate
85/// "sv.namehint" attribute. If a replaced op has a "sv.namehint" attribute,
86/// this function propagates the name to the new value.
87template <typename OpTy, typename... Args>
88static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter,
89 Operation *op, Args &&...args) {
90 auto name = op->getAttrOfType<StringAttr>("sv.namehint");
91 auto newOp =
92 rewriter.replaceOpWithNewOp<OpTy>(op, std::forward<Args>(args)...);
93 if (name && !newOp->hasAttr("sv.namehint"))
94 rewriter.modifyOpInPlace(newOp,
95 [&] { newOp->setAttr("sv.namehint", name); });
96
97 return newOp;
98}
99
100// Return true if the op has SV attributes. Note that we cannot use a helper
101// function `hasSVAttributes` defined under SV dialect because of a cyclic
102// dependency.
103static bool hasSVAttributes(Operation *op) {
104 return op->hasAttr("sv.attributes");
105}
106
107namespace {
108template <typename SubType>
109struct ComplementMatcher {
110 SubType lhs;
111 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
112 bool match(Operation *op) {
113 auto xorOp = dyn_cast<XorOp>(op);
114 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
115 }
116};
117} // end anonymous namespace
118
119template <typename SubType>
120static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
121 return ComplementMatcher<SubType>(subExpr);
122}
123
124/// Return true if the op will be flattened afterwards. Op will be flattend if
125/// it has a single user which has a same op type. User must be in same block.
126static bool shouldBeFlattened(Operation *op) {
127 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
128 "must be commutative operations"));
129 if (op->hasOneUse()) {
130 auto *user = *op->getUsers().begin();
131 return user->getName() == op->getName() &&
132 op->getAttrOfType<UnitAttr>("twoState") ==
133 user->getAttrOfType<UnitAttr>("twoState") &&
134 op->getBlock() == user->getBlock();
135 }
136 return false;
137}
138
139/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
140/// as an Op. Returns true if successful, and false otherwise.
141///
142/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
143///
144static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
145 // Skip if the operation should be flattened by another operation.
146 if (shouldBeFlattened(op))
147 return false;
148
149 auto inputs = op->getOperands();
150
151 SmallVector<Value, 4> newOperands;
152 SmallVector<Location, 4> newLocations{op->getLoc()};
153 newOperands.reserve(inputs.size());
154 struct Element {
155 decltype(inputs.begin()) current, end;
156 };
157
158 SmallVector<Element> worklist;
159 worklist.push_back({inputs.begin(), inputs.end()});
160 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
161 bool changed = false;
162 while (!worklist.empty()) {
163 auto &element = worklist.back(); // Do not pop. Take ref.
164
165 // Pop when we finished traversing the current operand range.
166 if (element.current == element.end) {
167 worklist.pop_back();
168 continue;
169 }
170
171 Value value = *element.current++;
172 auto *flattenOp = value.getDefiningOp();
173 // If not defined by a compatible operation of the same kind and
174 // from the same block, keep this as-is.
175 if (!flattenOp || flattenOp->getName() != op->getName() ||
176 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
177 flattenOp->getBlock() != op->getBlock()) {
178 newOperands.push_back(value);
179 continue;
180 }
181
182 // Don't duplicate logic when it has multiple uses.
183 if (!value.hasOneUse()) {
184 // We can fold a multi-use binary operation into this one if this allows a
185 // constant to fold though. For example, fold
186 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
187 // since the constants will both fold and we end up with the equiv cost.
188 //
189 // We don't do this for add/mul because the hardware won't be shared
190 // between the two ops if duplicated.
191 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
192 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
193 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
194 newOperands.push_back(value);
195 continue;
196 }
197 }
198
199 changed = true;
200
201 // Otherwise, push operands into worklist.
202 auto flattenOpInputs = flattenOp->getOperands();
203 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
204 newLocations.push_back(flattenOp->getLoc());
205 }
206
207 if (!changed)
208 return false;
209
210 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
211 op->getName(), newOperands, rewriter);
212 if (binFlag)
213 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
214
215 replaceOpAndCopyName(rewriter, op, result);
216 return true;
217}
218
219// Given a range of uses of an operation, find the lowest and highest bits
220// inclusive that are ever referenced. The range of uses must not be empty.
221static std::pair<size_t, size_t>
222getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
223 size_t originalOpWidth) {
224 auto users = op->getUsers();
225 assert(!users.empty() &&
226 "getLowestBitAndHighestBitRequired cannot operate on "
227 "a empty list of uses.");
228
229 // when we don't want to narrowTrailingBits (namely in arithmetic
230 // operations), forcing lowestBitRequired = 0
231 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
232 size_t highestBitRequired = 0;
233
234 for (auto *user : users) {
235 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
236 size_t lowBit = extractOp.getLowBit();
237 size_t highBit =
238 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
239 highestBitRequired = std::max(highestBitRequired, highBit);
240 lowestBitRequired = std::min(lowestBitRequired, lowBit);
241 continue;
242 }
243
244 highestBitRequired = originalOpWidth - 1;
245 lowestBitRequired = 0;
246 break;
247 }
248
249 return {lowestBitRequired, highestBitRequired};
250}
251
252template <class OpTy>
253static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
254 PatternRewriter &rewriter) {
255 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
256 if (!opType)
257 return false;
258
259 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
260 opType.getWidth());
261 if (range.second + 1 == opType.getWidth() && range.first == 0)
262 return false;
263
264 SmallVector<Value> args;
265 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
266 for (auto inop : op.getOperands()) {
267 // deal with muxes here
268 if (inop.getType() != op.getType())
269 args.push_back(inop);
270 else
271 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
272 inop, range.first));
273 }
274 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
275 newop->setDialectAttrs(op->getDialectAttrs());
276 if (op.getTwoState())
277 newop.setTwoState(true);
278
279 Value newResult = newop.getResult();
280 if (range.first)
281 newResult = rewriter.createOrFold<ConcatOp>(
282 op.getLoc(), newResult,
283 rewriter.create<hw::ConstantOp>(op.getLoc(),
284 APInt::getZero(range.first)));
285 if (range.second + 1 < opType.getWidth())
286 newResult = rewriter.createOrFold<ConcatOp>(
287 op.getLoc(),
288 rewriter.create<hw::ConstantOp>(
289 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
290 newResult);
291 rewriter.replaceOp(op, newResult);
292 return true;
293}
294
295//===----------------------------------------------------------------------===//
296// Unary Operations
297//===----------------------------------------------------------------------===//
298
299OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
300 if (hasOperandsOutsideOfBlock(getOperation()))
301 return {};
302
303 // Replicate one time -> noop.
304 if (cast<IntegerType>(getType()).getWidth() ==
305 getInput().getType().getIntOrFloatBitWidth())
306 return getInput();
307
308 // Constant fold.
309 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
310 if (input.getValue().getBitWidth() == 1) {
311 if (input.getValue().isZero())
312 return getIntAttr(
313 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
314 getContext());
315 return getIntAttr(
316 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
317 getContext());
318 }
319
320 APInt result = APInt::getZeroWidth();
321 for (auto i = getMultiple(); i != 0; --i)
322 result = result.concat(input.getValue());
323 return getIntAttr(result, getContext());
324 }
325
326 return {};
327}
328
329OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
330 if (hasOperandsOutsideOfBlock(getOperation()))
331 return {};
332
333 // Constant fold.
334 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
335 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
336
337 return {};
338}
339
340//===----------------------------------------------------------------------===//
341// Binary Operations
342//===----------------------------------------------------------------------===//
343
344/// Performs constant folding `calculate` with element-wise behavior on the two
345/// attributes in `operands` and returns the result if possible.
346static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
347 hw::PEO paramOpcode) {
348 assert(operands.size() == 2 && "binary op takes two operands");
349 if (!operands[0] || !operands[1])
350 return {};
351
352 // Fold constants with ParamExprAttr::get which handles simple constants as
353 // well as parameter expressions.
354 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
355 cast<TypedAttr>(operands[1]));
356}
357
358OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
359 if (hasOperandsOutsideOfBlock(getOperation()))
360 return {};
361
362 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
363 unsigned shift = rhs.getValue().getZExtValue();
364 unsigned width = getType().getIntOrFloatBitWidth();
365 if (shift == 0)
366 return getOperand(0);
367 if (width <= shift)
368 return getIntAttr(APInt::getZero(width), getContext());
369 }
370
371 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
372}
373
374LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
376 return failure();
377
378 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
379 APInt value;
380 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
381 return failure();
382
383 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
384 unsigned shift = value.getZExtValue();
385
386 // This case is handled by fold.
387 if (width <= shift || shift == 0)
388 return failure();
389
390 auto zeros =
391 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
392
393 // Remove the high bits which would be removed by the Shl.
394 auto extract =
395 rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
396
397 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, extract, zeros);
398 return success();
399}
400
401OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
402 if (hasOperandsOutsideOfBlock(getOperation()))
403 return {};
404
405 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
406 unsigned shift = rhs.getValue().getZExtValue();
407 if (shift == 0)
408 return getOperand(0);
409
410 unsigned width = getType().getIntOrFloatBitWidth();
411 if (width <= shift)
412 return getIntAttr(APInt::getZero(width), getContext());
413 }
414 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
415}
416
417LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
419 return failure();
420
421 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
422 APInt value;
423 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
424 return failure();
425
426 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
427 unsigned shift = value.getZExtValue();
428
429 // This case is handled by fold.
430 if (width <= shift || shift == 0)
431 return failure();
432
433 auto zeros =
434 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
435
436 // Remove the low bits which would be removed by the Shr.
437 auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
438 width - shift);
439
440 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, zeros, extract);
441 return success();
442}
443
444OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
445 if (hasOperandsOutsideOfBlock(getOperation()))
446 return {};
447
448 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
449 if (rhs.getValue().getZExtValue() == 0)
450 return getOperand(0);
451 }
452 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
453}
454
455LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
457 return failure();
458
459 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
460 APInt value;
461 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
462 return failure();
463
464 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
465 unsigned shift = value.getZExtValue();
466
467 auto topbit =
468 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
469 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
470
471 if (width <= shift) {
472 replaceOpAndCopyName(rewriter, op, {sext});
473 return success();
474 }
475
476 auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
477 width - shift);
478
479 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, sext, extract);
480 return success();
481}
482
483//===----------------------------------------------------------------------===//
484// Other Operations
485//===----------------------------------------------------------------------===//
486
487OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
488 if (hasOperandsOutsideOfBlock(getOperation()))
489 return {};
490
491 // If we are extracting the entire input, then return it.
492 if (getInput().getType() == getType())
493 return getInput();
494
495 // Constant fold.
496 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
497 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
498 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
499 getContext());
500 }
501 return {};
502}
503
504// Transforms extract(lo, cat(a, b, c, d, e)) into
505// cat(extract(lo1, b), c, extract(lo2, d)).
506// innerCat must be the argument of the provided ExtractOp.
508 ConcatOp innerCat,
509 PatternRewriter &rewriter) {
510 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
511 size_t beginOfFirstRelevantElement = 0;
512 auto it = reversedConcatArgs.begin();
513 size_t lowBit = op.getLowBit();
514
515 // This loop finds the first concatArg that is covered by the ExtractOp
516 for (; it != reversedConcatArgs.end(); it++) {
517 assert(beginOfFirstRelevantElement <= lowBit &&
518 "incorrectly moved past an element that lowBit has coverage over");
519 auto operand = *it;
520
521 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
522 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
523 // A bit other than the first bit will be used in this element.
524 // ...... ........ ...
525 // ^---lowBit
526 // ^---beginOfFirstRelevantElement
527 //
528 // Edge-case close to the end of the range.
529 // ...... ........ ...
530 // ^---(position + operandWidth)
531 // ^---lowBit
532 // ^---beginOfFirstRelevantElement
533 //
534 // Edge-case close to the beginning of the rang
535 // ...... ........ ...
536 // ^---lowBit
537 // ^---beginOfFirstRelevantElement
538 //
539 break;
540 }
541
542 // extraction discards this element.
543 // ...... ........ ...
544 // | ^---lowBit
545 // ^---beginOfFirstRelevantElement
546 beginOfFirstRelevantElement += operandWidth;
547 }
548 assert(it != reversedConcatArgs.end() &&
549 "incorrectly failed to find an element which contains coverage of "
550 "lowBit");
551
552 SmallVector<Value> reverseConcatArgs;
553 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
554 size_t extractLo = lowBit - beginOfFirstRelevantElement;
555
556 // Transform individual arguments of innerCat(..., a, b, c,) into
557 // [ extract(a), b, extract(c) ], skipping an extract operation where
558 // possible.
559 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
560 auto concatArg = *it;
561 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
562 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
563
564 if (widthToConsume == operandWidth && extractLo == 0) {
565 reverseConcatArgs.push_back(concatArg);
566 } else {
567 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
568 reverseConcatArgs.push_back(
569 rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
570 }
571
572 widthRemaining -= widthToConsume;
573
574 // Beyond the first element, all elements are extracted from position 0.
575 extractLo = 0;
576 }
577
578 if (reverseConcatArgs.size() == 1) {
579 replaceOpAndCopyName(rewriter, op, reverseConcatArgs[0]);
580 } else {
581 replaceOpWithNewOpAndCopyName<ConcatOp>(
582 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
583 }
584 return success();
585}
586
587// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
588static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
589 PatternRewriter &rewriter) {
590 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
591 auto replicateEltWidth =
592 replicate.getOperand().getType().getIntOrFloatBitWidth();
593
594 // If the extract starts at the base of an element and is an even multiple,
595 // we can replace the extract with a smaller replicate.
596 if (op.getLowBit() % replicateEltWidth == 0 &&
597 extractResultWidth % replicateEltWidth == 0) {
598 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
599 replicate.getOperand());
600 return true;
601 }
602
603 // If the extract is completely contained in one element, extract from the
604 // element.
605 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
606 replicateEltWidth) {
607 replaceOpWithNewOpAndCopyName<ExtractOp>(
608 rewriter, op, op.getType(), replicate.getOperand(),
609 op.getLowBit() % replicateEltWidth);
610 return true;
611 }
612
613 // We don't currently handle the case of extracting from non-whole elements,
614 // e.g. `extract (replicate 2-bit-thing, N), 1`.
615 return false;
616}
617
618LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
620 return failure();
621
622 auto *inputOp = op.getInput().getDefiningOp();
623
624 // This turns out to be incredibly expensive. Disable until performance is
625 // addressed.
626#if 0
627 // If the extracted bits are all known, then return the result.
628 auto knownBits = computeKnownBits(op.getInput())
629 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
630 op.getLowBit());
631 if (knownBits.isConstant()) {
632 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
633 knownBits.getConstant());
634 return success();
635 }
636#endif
637
638 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
639 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
640 replaceOpWithNewOpAndCopyName<ExtractOp>(
641 rewriter, op, op.getType(), innerExtract.getInput(),
642 innerExtract.getLowBit() + op.getLowBit());
643 return success();
644 }
645
646 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
647 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
648 return extractConcatToConcatExtract(op, innerCat, rewriter);
649
650 // extract(lo, replicate(a))
651 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
652 if (extractFromReplicate(op, replicate, rewriter))
653 return success();
654
655 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
656 // and/or/xor are not modifying the extracted bits.
657 if (inputOp && inputOp->getNumOperands() == 2 &&
658 isa<AndOp, OrOp, XorOp>(inputOp)) {
659 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
660 auto extractedCst = cstRHS.getValue().extractBits(
661 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
662 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
663 replaceOpWithNewOpAndCopyName<ExtractOp>(
664 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
665 return success();
666 }
667
668 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
669 // extract to represent the result. Turning it into a pile of extracts is
670 // always fine by our cost model, but we don't want to explode things into
671 // a ton of bits because it will bloat the IR and generated Verilog.
672 if (isa<AndOp>(inputOp)) {
673 // For our cost model, we only do this if the bit pattern is a
674 // contiguous series of ones.
675 unsigned lz = extractedCst.countLeadingZeros();
676 unsigned tz = extractedCst.countTrailingZeros();
677 unsigned pop = extractedCst.popcount();
678 if (extractedCst.getBitWidth() - lz - tz == pop) {
679 auto resultTy = rewriter.getIntegerType(pop);
680 SmallVector<Value> resultElts;
681 if (lz)
682 resultElts.push_back(rewriter.create<hw::ConstantOp>(
683 op.getLoc(), APInt::getZero(lz)));
684 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
685 op.getLoc(), resultTy, inputOp->getOperand(0),
686 op.getLowBit() + tz));
687 if (tz)
688 resultElts.push_back(rewriter.create<hw::ConstantOp>(
689 op.getLoc(), APInt::getZero(tz)));
690 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
691 return success();
692 }
693 }
694 }
695 }
696
697 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
698 // extracted.
699 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
700 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
701 // Don't canonicalize if the shift is multiply used.
702 if (shlOp->hasOneUse())
703 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
704 if (lhsCst.getValue().isOne()) {
705 auto newCst = rewriter.create<hw::ConstantOp>(
706 shlOp.getLoc(),
707 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
708 replaceOpWithNewOpAndCopyName<ICmpOp>(
709 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
710 false);
711 return success();
712 }
713 }
714
715 return failure();
716}
717
718//===----------------------------------------------------------------------===//
719// Associative Variadic operations
720//===----------------------------------------------------------------------===//
721
722// Reduce all operands to a single value (either integer constant or parameter
723// expression) if all the operands are constants.
724static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
725 hw::PEO paramOpcode) {
726 assert(operands.size() > 1 && "caller should handle one-operand case");
727 // We can only fold anything in the case where all operands are known to be
728 // constants. Check the least common one first for an early out.
729 if (!operands[1] || !operands[0])
730 return {};
731
732 // This will fold to a simple constant if all operands are constant.
733 if (llvm::all_of(operands.drop_front(2),
734 [&](Attribute in) { return !!in; })) {
735 SmallVector<mlir::TypedAttr> typedOperands;
736 typedOperands.reserve(operands.size());
737 for (auto operand : operands) {
738 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
739 typedOperands.push_back(typedOperand);
740 else
741 break;
742 }
743 if (typedOperands.size() == operands.size())
744 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
745 }
746
747 return {};
748}
749
750/// When we find a logical operation (and, or, xor) with a constant e.g.
751/// `X & 42`, we want to push the constant into the computation of X if it leads
752/// to simplification.
753///
754/// This function handles the case where the logical operation has a concat
755/// operand. We check to see if we can simplify the concat, e.g. when it has
756/// constant operands.
757///
758/// This returns true when a simplification happens.
759static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
760 size_t concatIdx, const APInt &cst,
761 PatternRewriter &rewriter) {
762 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
763 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
764
765 // Check to see if any operands can be simplified by pushing the logical op
766 // into all parts of the concat.
767 bool canSimplify =
768 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
769 auto *operandOp = operand.getDefiningOp();
770 if (!operandOp)
771 return false;
772
773 // If the concat has a constant operand then we can transform this.
774 if (isa<hw::ConstantOp>(operandOp))
775 return true;
776 // If the concat has the same logical operation and that operation has
777 // a constant operation than we can fold it into that suboperation.
778 return operandOp->getName() == logicalOp->getName() &&
779 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
780 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
781 });
782
783 if (!canSimplify)
784 return false;
785
786 // Create a new instance of the logical operation. We have to do this the
787 // hard way since we're generic across a family of different ops.
788 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
789 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
790 rewriter);
791 };
792
793 // Ok, let's do the transformation. We do this by slicing up the constant
794 // for each unit of the concat and duplicate the operation into the
795 // sub-operand.
796 SmallVector<Value> newConcatOperands;
797 newConcatOperands.reserve(concatOp->getNumOperands());
798
799 // Work from MSB to LSB.
800 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
801 for (Value operand : concatOp->getOperands()) {
802 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
803 nextOperandBit -= operandWidth;
804 // Take a slice of the constant.
805 auto eltCst = rewriter.create<hw::ConstantOp>(
806 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
807
808 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
809 }
810
811 // Create the concat, and the rest of the logical op if we need it.
812 Value newResult =
813 rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);
814
815 // If we had a variadic logical op on the top level, then recreate it with the
816 // new concat and without the constant operand.
817 if (logicalOp->getNumOperands() > 2) {
818 auto origOperands = logicalOp->getOperands();
819 SmallVector<Value> operands;
820 // Take any stuff before the concat.
821 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
822 // Take any stuff after the concat but before the constant.
823 operands.append(origOperands.begin() + concatIdx + 1,
824 origOperands.begin() + (origOperands.size() - 1));
825 // Include the new concat.
826 operands.push_back(newResult);
827 newResult = createLogicalOp(operands);
828 }
829
830 replaceOpAndCopyName(rewriter, logicalOp, newResult);
831 return true;
832}
833
834// Determines whether the inputs to a logical element are of opposite
835// comparisons and can lowered into a constant.
836static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
837 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
838
839 for (auto op : operands) {
840 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
841 icmpOp && icmpOp.getTwoState()) {
842 auto predicate = icmpOp.getPredicate();
843 auto lhs = icmpOp.getLhs();
844 auto rhs = icmpOp.getRhs();
845 if (seenPredicates.contains(
846 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
847 return true;
848
849 seenPredicates.insert({predicate, lhs, rhs});
850 }
851 }
852 return false;
853}
854
855OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
856 if (hasOperandsOutsideOfBlock(getOperation()))
857 return {};
858
859 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
860
861 auto inputs = adaptor.getInputs();
862
863 // and(x, 01, 10) -> 00 -- annulment.
864 for (auto operand : inputs) {
865 if (!operand)
866 continue;
867 value &= cast<IntegerAttr>(operand).getValue();
868 if (value.isZero())
869 return getIntAttr(value, getContext());
870 }
871
872 // and(x, -1) -> x.
873 if (inputs.size() == 2 && inputs[1] &&
874 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
875 return getInputs()[0];
876
877 // and(x, x, x) -> x. This also handles and(x) -> x.
878 if (llvm::all_of(getInputs(),
879 [&](auto in) { return in == this->getInputs()[0]; }))
880 return getInputs()[0];
881
882 // and(..., x, ..., ~x, ...) -> 0
883 for (Value arg : getInputs()) {
884 Value subExpr;
885 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
886 for (Value arg2 : getInputs())
887 if (arg2 == subExpr)
888 return getIntAttr(
889 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
890 getContext());
891 }
892 }
893
894 // x0 = icmp(pred, x, y)
895 // x1 = icmp(!pred, x, y)
896 // and(x0, x1) -> 0
898 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
899 getContext());
900
901 // Constant fold
902 return constFoldAssociativeOp(inputs, hw::PEO::And);
903}
904
905/// Returns a single common operand that all inputs of the operation `op` can
906/// be traced back to, or an empty `Value` if no such operand exists.
907///
908/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
909/// (assuming the bit-width of `a` is `n`).
910template <typename Op>
911static Value getCommonOperand(Op op) {
912 if (!op.getType().isInteger(1))
913 return Value();
914
915 auto inputs = op.getInputs();
916 size_t size = inputs.size();
917
918 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
919 if (!sourceOp)
920 return Value();
921 Value source = sourceOp.getOperand();
922
923 // Fast path: the input size is not equal to the width of the source.
924 if (size != source.getType().getIntOrFloatBitWidth())
925 return Value();
926
927 // Tracks the bits that were encountered.
928 llvm::BitVector bits(size);
929 bits.set(sourceOp.getLowBit());
930
931 for (size_t i = 1; i != size; ++i) {
932 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
933 if (!extractOp || extractOp.getOperand() != source)
934 return Value();
935 bits.set(extractOp.getLowBit());
936 }
937
938 return bits.all() ? source : Value();
939}
940
/// Canonicalize an idempotent operation `op` so that only one input of any kind
/// occurs.
///
/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
///
/// Besides exact duplicates, a top-level input is also dropped when it appears
/// as an input of a nested op of the same kind and same two-state flag. That
/// nested op must be defined in the same block and reachable from `op`'s
/// inputs within `limit` levels. Example: `and(x, and(x, y))` ->
/// `and(and(x, y))`.
///
/// Returns true if `op` was replaced through `rewriter`, false otherwise.
template <typename Op>
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
  // Depth limit to search, in operations. Chosen arbitrarily, keep small.
  constexpr unsigned limit = 3;
  auto inputs = op.getInputs();

  // Deduplicated inputs, kept in first-occurrence order.
  llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
  // Ops already enqueued. Seeded with `op` itself so it is never explored.
  llvm::SmallDenseSet<Op, 8> checked;
  checked.insert(op);

  struct OpWithDepth {
    Op op;
    unsigned depth;
  };
  llvm::SmallVector<OpWithDepth, 8> worklist;

  auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
    // Add to worklist if within depth limit, is defined in the same block by
    // the same kind of operation, has same two-state-ness, and not enqueued
    // previously.
    if (depth < limit && input.getParentBlock() == op->getBlock()) {
      auto inputOp = input.template getDefiningOp<Op>();
      if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
          checked.insert(inputOp).second)
        worklist.push_back({inputOp, depth + 1});
    }
  };

  // Seed the search with the direct (deduplicated) inputs.
  for (auto input : uniqueInputs)
    enqueue(input, 0);

  // Any value that feeds a nested op of the same kind is redundant at the top
  // level, because the nested op already contributes it. Remove it, then keep
  // exploring through that input.
  while (!worklist.empty()) {
    auto item = worklist.pop_back_val();

    for (auto input : item.op.getInputs()) {
      uniqueInputs.remove(input);
      enqueue(input, item.depth);
    }
  }

  // Only rewrite if at least one input was removed.
  if (uniqueInputs.size() < inputs.size()) {
    replaceOpWithNewOpAndCopyName<Op>(rewriter, op, op.getType(),
                                      uniqueInputs.getArrayRef(),
                                      op.getTwoState());
    return true;
  }

  return false;
}
994
/// Canonicalization patterns for `comb.and`. The patterns are tried in order,
/// and the first one that rewrites `op` returns success:
///   1. flatten nested `and` operands,
///   2. remove duplicate inputs (including ones reachable via nested ands),
///   3. simplifications driven by a constant last operand,
///   4. `narrowOperationWidth` (extract-driven narrowing, per its comment),
///   5. and of every bit of a value -> `icmp eq(value, -1)`.
///
/// NOTE(review): several rewrites below pass `false` as the trailing bool
/// (the two-state flag, judging by the other call sites in this file). That
/// drops `op.getTwoState()`; confirm this is intended.
LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();

  // and(x, and(...)) -> and(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
  // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
  // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
    return success();

  // Cross-block operands are a canonicalization barrier for the remaining
  // patterns.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();
  assert(size > 1 && "expected 2 or more operands, `fold` should handle this");

  // Patterns for and with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // and(..., '1) -> and(...) -- identity
    if (value.isAllOnes()) {
      replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
                                           inputs.drop_back(), false);
      return success();
    }

    // TODO: Combine multiple constants together even if they aren't at the
    // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
    // folding
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getType(),
                                           newOperands, false);
      return success();
    }

    // Handle 'and' with a single bit constant on the RHS.
    if (size == 2 && value.isPowerOf2()) {
      // If the LHS is a replicate from a single bit, we can 'concat' it
      // into place.  e.g.:
      //   `replicate(x) & 4` -> `concat(zeros, x, zeros)`
      // TODO: Generalize this for non-single-bit operands.
      if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
        auto replicateOperand = replicate.getOperand();
        if (replicateOperand.getType().isInteger(1)) {
          unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
          // Position of the single set bit in the mask.
          auto trailingZeros = value.countTrailingZeros();

          // Don't add zero bit constants unnecessarily.
          SmallVector<Value, 3> concatOperands;
          if (trailingZeros != resultWidth - 1) {
            auto highZeros = rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
            concatOperands.push_back(highZeros);
          }
          concatOperands.push_back(replicateOperand);
          if (trailingZeros != 0) {
            auto lowZeros = rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(trailingZeros));
            concatOperands.push_back(lowZeros);
          }
          replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
                                                  concatOperands);
          return success();
        }
      }
    }

    // If this is an and from an extract op, try shrinking the extract.
    // NOTE(review): this assumes `value` is non-zero. For a zero mask, lz and
    // tz both equal the bit width, and the width computation below would
    // underflow. A zero mask is expected to be folded away by AndOp::fold's
    // annulment before this point; confirm.
    if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
      if (size == 2 &&
          // We can shrink it if the mask has leading or trailing zeros.
          (value.countLeadingZeros() || value.countTrailingZeros())) {
        unsigned lz = value.countLeadingZeros();
        unsigned tz = value.countTrailingZeros();

        // Start by extracting the smaller number of bits.
        auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
        Value smallElt = rewriter.createOrFold<ExtractOp>(
            extractOp.getLoc(), smallTy, extractOp->getOperand(0),
            extractOp.getLowBit() + tz);
        // Apply the 'and' mask if needed.
        APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
        if (!smallMask.isAllOnes()) {
          auto loc = inputs.back().getLoc();
          smallElt = rewriter.createOrFold<AndOp>(
              loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
              false);
        }

        // The final replacement will be a concat of the leading/trailing zeros
        // along with the smaller extracted value.
        SmallVector<Value> resultElts;
        if (lz)
          resultElts.push_back(
              rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
        resultElts.push_back(smallElt);
        if (tz)
          resultElts.push_back(
              rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
        replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, resultElts);
        return success();
      }
    }

    // and(concat(x, cst1), a, b, c, cst2)
    //    ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
    // We do this for even more multi-use concats since they are "just wiring".
    for (size_t i = 0; i < size - 1; ++i) {
      if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();
    }
  }

  // extracts only of and(...) -> and(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
  if (auto source = getCommonOperand(op)) {
    auto cmpAgainst =
        rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
    replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::eq,
                                          source, cmpAgainst);
    return success();
  }

  /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
  /// (AndOp::fold already folds and(..., x, ..., ~x, ...) to 0.)
  return failure();
}
1131
1132OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1133 if (hasOperandsOutsideOfBlock(getOperation()))
1134 return {};
1135
1136 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1137 auto inputs = adaptor.getInputs();
1138 // or(x, 10, 01) -> 11
1139 for (auto operand : inputs) {
1140 if (!operand)
1141 continue;
1142 value |= cast<IntegerAttr>(operand).getValue();
1143 if (value.isAllOnes())
1144 return getIntAttr(value, getContext());
1145 }
1146
1147 // or(x, 0) -> x
1148 if (inputs.size() == 2 && inputs[1] &&
1149 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1150 return getInputs()[0];
1151
1152 // or(x, x, x) -> x. This also handles or(x) -> x
1153 if (llvm::all_of(getInputs(),
1154 [&](auto in) { return in == this->getInputs()[0]; }))
1155 return getInputs()[0];
1156
1157 // or(..., x, ..., ~x, ...) -> -1
1158 for (Value arg : getInputs()) {
1159 Value subExpr;
1160 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1161 for (Value arg2 : getInputs())
1162 if (arg2 == subExpr)
1163 return getIntAttr(
1164 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1165 getContext());
1166 }
1167 }
1168
1169 // x0 = icmp(pred, x, y)
1170 // x1 = icmp(!pred, x, y)
1171 // or(x0, x1) -> 1
1172 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1173 return getIntAttr(
1174 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1175 getContext());
1176
1177 // Constant fold
1178 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1179}
1180
/// Canonicalization patterns for `comb.or`. The patterns are tried in order,
/// and the first one that rewrites `op` returns success.
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();

  // or(x, or(...)) -> or(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
  // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
  // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
    return success();

  // Cross-block operands are a canonicalization barrier for the remaining
  // patterns.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();
  assert(size > 1 && "expected 2 or more operands");

  // Patterns for or with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // or(..., '0) -> or(...) -- identity
    if (value.isZero()) {
      replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
                                          inputs.drop_back());
      return success();
    }

    // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getType(),
                                          newOperands);
      return success();
    }

    // or(concat(x, cst1), a, b, c, cst2)
    //    ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
    // We do this for even more multi-use concats since they are "just wiring".
    for (size_t i = 0; i < size - 1; ++i) {
      if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();
    }
  }

  // extracts only of or(...) -> or(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
  if (auto source = getCommonOperand(op)) {
    auto cmpAgainst =
        rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
    replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, ICmpPredicate::ne,
                                          source, cmpAgainst);
    return success();
  }

  // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
  // .., c_n), a, 0)
  // Applies only when the `or` and every mux are two-state, and every mux
  // shares the first mux's true value and its zero false value.
  if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
    // Shadows the outer `value`; it holds the first mux's false value here.
    APInt value;
    if (op.getTwoState() && firstMux.getTwoState() &&
        matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
        value.isZero()) {
      SmallVector<Value> conditions{firstMux.getCond()};
      // Collects each mux condition. A condition may be pushed for a
      // non-matching mux, but `conditions` is only used when every operand
      // matched.
      auto check = [&](Value v) {
        auto mux = v.getDefiningOp<comb::MuxOp>();
        if (!mux)
          return false;
        conditions.push_back(mux.getCond());
        return mux.getTwoState() &&
               firstMux.getTrueValue() == mux.getTrueValue() &&
               firstMux.getFalseValue() == mux.getFalseValue();
      };
      if (llvm::all_of(op.getOperands().drop_front(), check)) {
        auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
        replaceOpWithNewOpAndCopyName<comb::MuxOp>(
            rewriter, op, cond, firstMux.getTrueValue(),
            firstMux.getFalseValue(), true);
        return success();
      }
    }
  }

  /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
  /// (OrOp::fold already folds or(..., x, ..., ~x, ...) to all ones.)
  return failure();
}
1273
1274OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1275 if (hasOperandsOutsideOfBlock(getOperation()))
1276 return {};
1277
1278 auto size = getInputs().size();
1279 auto inputs = adaptor.getInputs();
1280
1281 // xor(x) -> x -- noop
1282 if (size == 1)
1283 return getInputs()[0];
1284
1285 // xor(x, x) -> 0 -- idempotent
1286 if (size == 2 && getInputs()[0] == getInputs()[1])
1287 return IntegerAttr::get(getType(), 0);
1288
1289 // xor(x, 0) -> x
1290 if (inputs.size() == 2 && inputs[1] &&
1291 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1292 return getInputs()[0];
1293
1294 // xor(xor(x,1),1) -> x
1295 // but not self loop
1296 if (isBinaryNot()) {
1297 Value subExpr;
1298 if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1299 subExpr != getResult())
1300 return subExpr;
1301 }
1302
1303 // Constant fold
1304 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1305}
1306
1307// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1308static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1309 PatternRewriter &rewriter) {
1310 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1311 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1312
1313 Value result =
1314 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1315 icmp.getOperand(1), icmp.getTwoState());
1316
1317 // If the xor had other operands, rebuild it.
1318 if (op.getNumOperands() > 2) {
1319 SmallVector<Value, 4> newOperands(op.getOperands());
1320 newOperands.pop_back();
1321 newOperands.erase(newOperands.begin() + icmpOperand);
1322 newOperands.push_back(result);
1323 result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
1324 }
1325
1326 replaceOpAndCopyName(rewriter, op, result);
1327}
1328
/// Canonicalization patterns for `comb.xor`. The patterns are tried in order,
/// and the first one that rewrites `op` returns success.
LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
  // Cross-block operands are a canonicalization barrier.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // xor(..., x, x) -> xor (...) -- idempotent
  // Only the last two operands are compared; x ^ x cancels to zero.
  if (inputs[size - 1] == inputs[size - 2]) {
    assert(size > 2 &&
           "expected idempotent case for 2 elements handled already.");
    replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
                                         inputs.drop_back(/*n=*/2), false);
    return success();
  }

  // Patterns for xor with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // xor(..., 0) -> xor(...) -- identity
    if (value.isZero()) {
      replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
                                           inputs.drop_back(), false);
      return success();
    }

    // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getType(),
                                           newOperands, false);
      return success();
    }

    bool isSingleBit = value.getBitWidth() == 1;

    // Check for subexpressions that we can simplify.
    for (size_t i = 0; i < size - 1; ++i) {
      Value operand = inputs[i];

      // xor(concat(x, cst1), a, b, c, cst2)
      //    ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
      // We do this for even more multi-use concats since they are "just
      // wiring".
      if (auto concat = operand.getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();

      // xor(icmp(p, x, y), a, b, 1) -> xor(a, b, icmp(!p, x, y)), but only if
      // the icmp has a single use, so the original icmp is not duplicated.
      if (isSingleBit && operand.hasOneUse()) {
        assert(value == 1 && "single bit constant has to be one if not zero");
        if (auto icmp = operand.getDefiningOp<ICmpOp>())
          return canonicalizeXorIcmpTrue(op, i, rewriter), success();
      }
    }
  }

  // xor(x, xor(...)) -> xor(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of xor(...) -> xor(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // xor(a[0], a[1], ..., a[n]) -> parity(a)
  if (auto source = getCommonOperand(op)) {
    replaceOpWithNewOpAndCopyName<ParityOp>(rewriter, op, source);
    return success();
  }

  return failure();
}
1406
1407OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1408 if (hasOperandsOutsideOfBlock(getOperation()))
1409 return {};
1410
1411 // sub(x - x) -> 0
1412 if (getRhs() == getLhs())
1413 return getIntAttr(
1414 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1415 getContext());
1416
1417 if (adaptor.getRhs()) {
1418 // If both are constants, we can unconditionally fold.
1419 if (adaptor.getLhs()) {
1420 // Constant fold (c1 - c2) => (c1 + -1*c2).
1421 auto negOne = getIntAttr(
1422 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1423 getContext());
1424 auto rhsNeg = hw::ParamExprAttr::get(
1425 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1426 return hw::ParamExprAttr::get(hw::PEO::Add,
1427 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1428 }
1429
1430 // sub(x - 0) -> x
1431 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1432 if (rhsC.getValue().isZero())
1433 return getLhs();
1434 }
1435 }
1436
1437 return {};
1438}
1439
1440LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1441 if (hasOperandsOutsideOfBlock(&*op))
1442 return failure();
1443
1444 // sub(x, cst) -> add(x, -cst)
1445 APInt value;
1446 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1447 auto negCst = rewriter.create<hw::ConstantOp>(op.getLoc(), -value);
1448 replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getLhs(), negCst,
1449 false);
1450 return success();
1451 }
1452
1453 // extracts only of sub(...) -> sub(extract()...)
1454 if (narrowOperationWidth(op, false, rewriter))
1455 return success();
1456
1457 return failure();
1458}
1459
1460OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1461 if (hasOperandsOutsideOfBlock(getOperation()))
1462 return {};
1463
1464 auto size = getInputs().size();
1465
1466 // add(x) -> x -- noop
1467 if (size == 1u)
1468 return getInputs()[0];
1469
1470 // Constant fold constant operands.
1471 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1472}
1473
/// Canonicalization patterns for `comb.add`. The patterns are tried in order,
/// and the first one that rewrites `op` returns success. Most patterns only
/// inspect the last one or two operands.
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
  // Cross-block operands are a canonicalization barrier.
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  APInt value, value2;

  // add(..., 0) -> add(...) -- identity
  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         inputs.drop_back(), false);
    return success();
  }

  // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
      matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
    auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(cst);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  // add(..., x, x) -> add(..., shl(x, 1))
  if (inputs[size - 1] == inputs[size - 2]) {
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));

    auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
    auto shiftLeftOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs.back(), one, false);

    newOperands.push_back(shiftLeftOp);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
  // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
  if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
      matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {

    APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
    auto rhs =
        rewriter.create<hw::ConstantOp>(op.getLoc(), (one << value) + one);

    std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
    auto mulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);

    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(mulOp);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
  // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
  // Only a two-operand mul whose first factor is `x` and whose second is a
  // constant is matched.
  if (mulOp && mulOp.getInputs().size() == 2 &&
      mulOp.getInputs()[0] == inputs[size - 2] &&
      matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {

    APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
    auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + one);
    std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
    auto newMulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);

    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(newMulOp);
    replaceOpWithNewOpAndCopyName<AddOp>(rewriter, op, op.getType(),
                                         newOperands, false);
    return success();
  }

  // add(a, add(...)) -> add(a, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of add(...) -> add(extract()...)
  if (narrowOperationWidth(op, false, rewriter))
    return success();

  // add(add(x, c1), c2) -> add(x, c1 + c2)
  // The result is two-state only if both adds were.
  auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
  if (addOp && addOp.getInputs().size() == 2 &&
      matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
      inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {

    auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
    replaceOpWithNewOpAndCopyName<AddOp>(
        rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
        /*twoState=*/op.getTwoState() && addOp.getTwoState());
    return success();
  }

  return failure();
}
1576
1577OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1578 if (hasOperandsOutsideOfBlock(getOperation()))
1579 return {};
1580
1581 auto size = getInputs().size();
1582 auto inputs = adaptor.getInputs();
1583
1584 // mul(x) -> x -- noop
1585 if (size == 1u)
1586 return getInputs()[0];
1587
1588 auto width = cast<IntegerType>(getType()).getWidth();
1589 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1590
1591 // mul(x, 0, 1) -> 0 -- annulment
1592 for (auto operand : inputs) {
1593 if (!operand)
1594 continue;
1595 value *= cast<IntegerAttr>(operand).getValue();
1596 if (value.isZero())
1597 return getIntAttr(value, getContext());
1598 }
1599
1600 // Constant fold
1601 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1602}
1603
1604LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1605 if (hasOperandsOutsideOfBlock(&*op))
1606 return failure();
1607
1608 auto inputs = op.getInputs();
1609 auto size = inputs.size();
1610 assert(size > 1 && "expected 2 or more operands");
1611
1612 APInt value, value2;
1613
1614 // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1615 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1616 value.isPowerOf2()) {
1617 auto shift = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(),
1618 value.exactLogBase2());
1619 auto shlOp =
1620 rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift, false);
1621
1622 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1623 ArrayRef<Value>(shlOp), false);
1624 return success();
1625 }
1626
1627 // mul(..., 1) -> mul(...) -- identity
1628 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1629 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1630 inputs.drop_back());
1631 return success();
1632 }
1633
1634 // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1635 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1636 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1637 auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value * value2);
1638 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1639 newOperands.push_back(cst);
1640 replaceOpWithNewOpAndCopyName<MulOp>(rewriter, op, op.getType(),
1641 newOperands);
1642 return success();
1643 }
1644
1645 // mul(a, mul(...)) -> mul(a, ...) -- flatten
1646 if (tryFlatteningOperands(op, rewriter))
1647 return success();
1648
1649 // extracts only of mul(...) -> mul(extract()...)
1650 if (narrowOperationWidth(op, false, rewriter))
1651 return success();
1652
1653 return failure();
1654}
1655
1656template <class Op, bool isSigned>
1657static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1658 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1659 // divu(x, 1) -> x, divs(x, 1) -> x
1660 if (rhsValue.getValue() == 1)
1661 return op.getLhs();
1662
1663 // If the divisor is zero, do not fold for now.
1664 if (rhsValue.getValue().isZero())
1665 return {};
1666 }
1667
1668 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1669}
1670
1671OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1672 if (hasOperandsOutsideOfBlock(getOperation()))
1673 return {};
1674
1675 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1676}
1677
1678OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1679 if (hasOperandsOutsideOfBlock(getOperation()))
1680 return {};
1681
1682 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1683}
1684
1685template <class Op, bool isSigned>
1686static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1687 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1688 // modu(x, 1) -> 0, mods(x, 1) -> 0
1689 if (rhsValue.getValue() == 1)
1690 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1691 op.getContext());
1692
1693 // If the divisor is zero, do not fold for now.
1694 if (rhsValue.getValue().isZero())
1695 return {};
1696 }
1697
1698 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1699 // modu(0, x) -> 0, mods(0, x) -> 0
1700 if (lhsValue.getValue().isZero())
1701 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1702 op.getContext());
1703 }
1704
1705 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1706}
1707
1708OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1709 if (hasOperandsOutsideOfBlock(getOperation()))
1710 return {};
1711
1712 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1713}
1714
1715OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1716 if (hasOperandsOutsideOfBlock(getOperation()))
1717 return {};
1718
1719 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1720}
1721//===----------------------------------------------------------------------===//
1722// ConcatOp
1723//===----------------------------------------------------------------------===//
1724
1725// Constant folding
1726OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1727 if (hasOperandsOutsideOfBlock(getOperation()))
1728 return {};
1729
1730 if (getNumOperands() == 1)
1731 return getOperand(0);
1732
1733 // If all the operands are constant, we can fold.
1734 for (auto attr : adaptor.getInputs())
1735 if (!attr || !isa<IntegerAttr>(attr))
1736 return {};
1737
1738 // If we got here, we can constant fold.
1739 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1740 APInt result(resultWidth, 0);
1741
1742 unsigned nextInsertion = resultWidth;
1743 // Insert each chunk into the result.
1744 for (auto attr : adaptor.getInputs()) {
1745 auto chunk = cast<IntegerAttr>(attr).getValue();
1746 nextInsertion -= chunk.getBitWidth();
1747 result.insertBits(chunk, nextInsertion);
1748 }
1749
1750 return getIntAttr(result, getContext());
1751}
1752
1753LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1754 if (hasOperandsOutsideOfBlock(&*op))
1755 return failure();
1756
1757 auto inputs = op.getInputs();
1758 auto size = inputs.size();
1759 assert(size > 1 && "expected 2 or more operands");
1760
1761 // This function is used when we flatten neighboring operands of a
1762 // (variadic) concat into a new vesion of the concat. first/last indices
1763 // are inclusive.
1764 auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1765 ValueRange replacements) -> LogicalResult {
1766 SmallVector<Value, 4> newOperands;
1767 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1768 newOperands.append(replacements.begin(), replacements.end());
1769 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1770 if (newOperands.size() == 1)
1771 replaceOpAndCopyName(rewriter, op, newOperands[0]);
1772 else
1773 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
1774 newOperands);
1775 return success();
1776 };
1777
1778 Value commonOperand = inputs[0];
1779 for (size_t i = 0; i != size; ++i) {
1780 // Check to see if all operands are the same.
1781 if (inputs[i] != commonOperand)
1782 commonOperand = Value();
1783
1784 // If an operand to the concat is itself a concat, then we can fold them
1785 // together.
1786 if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1787 return flattenConcat(i, i, subConcat->getOperands());
1788
1789 // Check for canonicalization due to neighboring operands.
1790 if (i != 0) {
1791 // Merge neighboring constants.
1792 if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1793 if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1794 unsigned prevWidth = prevCst.getValue().getBitWidth();
1795 unsigned thisWidth = cst.getValue().getBitWidth();
1796 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1797 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1798 << thisWidth;
1799 Value replacement =
1800 rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
1801 return flattenConcat(i - 1, i, replacement);
1802 }
1803 }
1804
1805 // If the two operands are the same, turn them into a replicate.
1806 if (inputs[i] == inputs[i - 1]) {
1807 Value replacement =
1808 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1809 return flattenConcat(i - 1, i, replacement);
1810 }
1811
1812 // If this input is a replicate, see if we can fold it with the previous
1813 // one.
1814 if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1815 // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1816 if (repl.getOperand() == inputs[i - 1]) {
1817 Value replacement = rewriter.createOrFold<ReplicateOp>(
1818 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1819 return flattenConcat(i - 1, i, replacement);
1820 }
1821 // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1822 if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1823 if (prevRepl.getOperand() == repl.getOperand()) {
1824 Value replacement = rewriter.createOrFold<ReplicateOp>(
1825 op.getLoc(), repl.getOperand(),
1826 repl.getMultiple() + prevRepl.getMultiple());
1827 return flattenConcat(i - 1, i, replacement);
1828 }
1829 }
1830 }
1831
1832 // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1833 if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1834 if (repl.getOperand() == inputs[i]) {
1835 Value replacement = rewriter.createOrFold<ReplicateOp>(
1836 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1837 return flattenConcat(i - 1, i, replacement);
1838 }
1839 }
1840
1841 // Merge neighboring extracts of neighboring inputs, e.g.
1842 // {A[3], A[2]} -> A[3:2]
1843 if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1844 if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1845 if (extract.getInput() == prevExtract.getInput()) {
1846 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1847 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1848 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1849 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1850 Value replacement = rewriter.create<ExtractOp>(
1851 op.getLoc(), resType, extract.getInput(),
1852 extract.getLowBit());
1853 return flattenConcat(i - 1, i, replacement);
1854 }
1855 }
1856 }
1857 }
1858 // Merge neighboring array extracts of neighboring inputs, e.g.
1859 // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
1860
1861 // This represents a slice of an array.
1862 struct ArraySlice {
1863 Value input;
1864 Value index;
1865 size_t width;
1866 static std::optional<ArraySlice> get(Value value) {
1867 assert(isa<IntegerType>(value.getType()) && "expected integer type");
1868 if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1869 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1870 // array slice op is wrapped with bitcast.
1871 if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1872 if (auto arraySlice =
1873 bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1874 return ArraySlice{
1875 arraySlice.getInput(), arraySlice.getLowIndex(),
1876 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1877 .getNumElements()};
1878 return std::nullopt;
1879 }
1880 };
1881 if (auto extractOpt = ArraySlice::get(inputs[i])) {
1882 if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1883 // Check that two array slices are mergable.
1884 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1885 prevExtractOpt->input == extractOpt->input &&
1886 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1887 extractOpt->width)) {
1888 auto resType = hw::ArrayType::get(
1889 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1890 .getElementType(),
1891 extractOpt->width + prevExtractOpt->width);
1892 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1893 Value replacement = rewriter.create<hw::BitcastOp>(
1894 op.getLoc(), resIntType,
1895 rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
1896 prevExtractOpt->input,
1897 extractOpt->index));
1898 return flattenConcat(i - 1, i, replacement);
1899 }
1900 }
1901 }
1902 }
1903 }
1904
1905 // If all operands were the same, then this is a replicate.
1906 if (commonOperand) {
1907 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
1908 commonOperand);
1909 return success();
1910 }
1911
1912 return failure();
1913}
1914
1915//===----------------------------------------------------------------------===//
1916// MuxOp
1917//===----------------------------------------------------------------------===//
1918
1919OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1920 if (hasOperandsOutsideOfBlock(getOperation()))
1921 return {};
1922
1923 // mux (c, b, b) -> b
1924 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1925 return getTrueValue();
1926 if (auto tv = adaptor.getTrueValue())
1927 if (tv == adaptor.getFalseValue())
1928 return tv;
1929
1930 // mux(0, a, b) -> b
1931 // mux(1, a, b) -> a
1932 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1933 if (pred.getValue().isZero())
1934 return getFalseValue();
1935 return getTrueValue();
1936 }
1937
1938 // mux(cond, 1, 0) -> cond
1939 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1940 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1941 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1942 hw::getBitWidth(getType()) == 1)
1943 return getCond();
1944
1945 return {};
1946}
1947
1948/// Check to see if the condition to the specified mux is an equality
1949/// comparison `indexValue` and one or more constants. If so, put the
1950/// constants in the constants vector and return true, otherwise return false.
1951///
1952/// This is part of foldMuxChain.
1953///
1954static bool
1955getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1956 std::function<void(hw::ConstantOp)> constantFn) {
1957 // Handle `idx == 42` and `idx != 42`.
1958 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1959 // TODO: We could handle things like "x < 2" as two entries.
1960 auto requiredPredicate =
1961 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1962 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1963 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1964 constantFn(cst);
1965 return true;
1966 }
1967 }
1968 return false;
1969 }
1970
1971 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1972 if (auto orOp = cond.getDefiningOp<OrOp>()) {
1973 if (!isInverted)
1974 return false;
1975 for (auto operand : orOp.getOperands())
1976 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1977 return false;
1978 return true;
1979 }
1980
1981 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1982 if (auto andOp = cond.getDefiningOp<AndOp>()) {
1983 if (isInverted)
1984 return false;
1985 for (auto operand : andOp.getOperands())
1986 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1987 return false;
1988 return true;
1989 }
1990
1991 return false;
1992}
1993
1994/// Given a mux, check to see if the "on true" value (or "on false" value if
1995/// isFalseSide=true) is a mux tree with the same condition. This allows us
1996/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
1997/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
1998/// allows synthesis tools to do more interesting optimizations.
1999///
2000/// This returns false if we cannot form the mux tree (or do not want to) and
2001/// returns true if the mux was replaced.
2002static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
2003 PatternRewriter &rewriter) {
2004 // Get the index value being compared. Later we check to see if it is
2005 // compared to a constant with the right predicate.
2006 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2007 if (!rootCmp)
2008 return false;
2009 Value indexValue = rootCmp.getLhs();
2010
2011 // Return the value to use if the equality match succeeds.
2012 auto getCaseValue = [&](MuxOp mux) -> Value {
2013 return mux.getOperand(1 + unsigned(!isFalseSide));
2014 };
2015
2016 // Return the value to use if the equality match fails. This is the next
2017 // mux in the sequence or the "otherwise" value.
2018 auto getTreeValue = [&](MuxOp mux) -> Value {
2019 return mux.getOperand(1 + unsigned(isFalseSide));
2020 };
2021
2022 // Start scanning the mux tree to see what we've got. Keep track of the
2023 // constant comparison value and the SSA value to use when equal to it.
2024 SmallVector<Location> locationsFound;
2025 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2026
2027 /// Extract constants and values into `valuesFound` and return true if this is
2028 /// part of the mux tree, otherwise return false.
2029 auto collectConstantValues = [&](MuxOp mux) -> bool {
2031 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2032 valuesFound.push_back({cst, getCaseValue(mux)});
2033 locationsFound.push_back(mux.getCond().getLoc());
2034 locationsFound.push_back(mux->getLoc());
2035 });
2036 };
2037
2038 // Make sure the root is a correct comparison with a constant.
2039 if (!collectConstantValues(rootMux))
2040 return false;
2041
2042 // Make sure that we're not looking at the intermediate node in a mux tree.
2043 if (rootMux->hasOneUse()) {
2044 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2045 if (getTreeValue(userMux) == rootMux.getResult() &&
2046 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2047 [&](hw::ConstantOp cst) {}))
2048 return false;
2049 }
2050 }
2051
2052 // Scan up the tree linearly.
2053 auto nextTreeValue = getTreeValue(rootMux);
2054 while (1) {
2055 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2056 if (!nextMux || !nextMux->hasOneUse())
2057 break;
2058 if (!collectConstantValues(nextMux))
2059 break;
2060 nextTreeValue = getTreeValue(nextMux);
2061 }
2062
2063 // We need to have more than three values to create an array. This is an
2064 // arbitrary threshold which is saying that one or two muxes together is ok,
2065 // but three should be folded.
2066 if (valuesFound.size() < 3)
2067 return false;
2068
2069 // If the array is greater that 9 bits, it will take over 512 elements and
2070 // it will be too large for a single expression.
2071 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2072 if (indexWidth >= 9)
2073 return false;
2074
2075 // Next we need to see if the values are dense-ish. We don't want to have
2076 // a tremendous number of replicated entries in the array. Some sparsity is
2077 // ok though, so we require the table to be at least 5/8 utilized.
2078 uint64_t tableSize = 1ULL << indexWidth;
2079 if (valuesFound.size() < (tableSize * 5) / 8)
2080 return false; // Not dense enough.
2081
2082 // Ok, we're going to do the transformation, start by building the table
2083 // filled with the "otherwise" value.
2084 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2085
2086 // Fill in entries in the table from the leaf to the root of the expression.
2087 // This ensures that any duplicate matches end up with the ultimate value,
2088 // which is the one closer to the root.
2089 for (auto &elt : llvm::reverse(valuesFound)) {
2090 uint64_t idx = elt.first.getValue().getZExtValue();
2091 assert(idx < table.size() && "constant should be same bitwidth as index");
2092 table[idx] = elt.second;
2093 }
2094
2095 // The hw.array_create operation has the operand list in unintuitive order
2096 // with a[0] stored as the last element, not the first.
2097 std::reverse(table.begin(), table.end());
2098
2099 // Build the array_create and the array_get.
2100 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2101 auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
2102 replaceOpWithNewOpAndCopyName<hw::ArrayGetOp>(rewriter, rootMux, array,
2103 indexValue);
2104 return true;
2105}
2106
2107/// Given a fully associative variadic operation like (a+b+c+d), break the
2108/// expression into two parts, one without the specified operand (e.g.
2109/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2110/// `tmp+c`), and return the inner expression.
2111///
2112/// NOTE: This mutates the operation in place if it only has a single user,
2113/// which assumes that user will be removed.
2114///
2115static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2116 size_t operandNo,
2117 PatternRewriter &rewriter) {
2118 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2119 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2120
2121 // If this expression already has two operands (the common case) no splitting
2122 // is necessary.
2123 if (fullyAssoc->getNumOperands() == 2)
2124 return fullyAssoc->getOperand(operandNo ^ 1);
2125
2126 // If the operation has a single use, mutate it in place.
2127 if (fullyAssoc->hasOneUse()) {
2128 rewriter.modifyOpInPlace(fullyAssoc,
2129 [&]() { fullyAssoc->eraseOperand(operandNo); });
2130 return fullyAssoc->getResult(0);
2131 }
2132
2133 // Form the new operation with the operands that remain.
2134 SmallVector<Value> operands;
2135 operands.append(fullyAssoc->getOperands().begin(),
2136 fullyAssoc->getOperands().begin() + operandNo);
2137 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2138 fullyAssoc->getOperands().end());
2139 Value opWithoutExcluded = createGenericOp(
2140 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2141 Value excluded = fullyAssoc->getOperand(operandNo);
2142
2143 Value fullResult =
2144 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2145 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2146 replaceOpAndCopyName(rewriter, fullyAssoc, fullResult);
2147 return opWithoutExcluded;
2148}
2149
2150/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2151/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2152/// is true. Return true on successful transformation, false if not.
2153///
2154/// These are various forms of "predicated ops" that can be handled with a
2155/// replicate/and combination.
2156static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2157 PatternRewriter &rewriter) {
2158 // Check to see the operand in question is an operation. If it is a port,
2159 // we can't simplify it.
2160 Operation *subExpr =
2161 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2162 if (!subExpr || subExpr->getNumOperands() < 2)
2163 return false;
2164
2165 // If this isn't an operation we can handle, don't spend energy on it.
2166 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2167 return false;
2168
2169 // Check to see if the common value occurs in the operand list for the
2170 // subexpression op. If so, then we can simplify it.
2171 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2172 size_t opNo = 0, e = subExpr->getNumOperands();
2173 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2174 ++opNo;
2175 if (opNo == e)
2176 return false;
2177
2178 // If we got a hit, then go ahead and simplify it!
2179 Value cond = op.getCond();
2180
2181 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2182 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2183 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2184 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2185 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2186 if (subMux == op)
2187 return false;
2188
2189 Value otherValue;
2190 Value subCond = subMux.getCond();
2191
2192 // Invert th subCond if needed and dig out the 'b' value.
2193 if (subMux.getTrueValue() == commonValue)
2194 otherValue = subMux.getFalseValue();
2195 else if (subMux.getFalseValue() == commonValue) {
2196 otherValue = subMux.getTrueValue();
2197 subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2198 } else {
2199 // We can't fold `mux(cond, a, mux(a, x, y))`.
2200 return false;
2201 }
2202
2203 // Invert the outer cond if needed, and combine the mux conditions.
2204 if (!isTrueOperand)
2205 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2206 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2207 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, cond, commonValue,
2208 otherValue, op.getTwoState());
2209 return true;
2210 }
2211
2212 // Invert the condition if needed. Or/Xor invert when dealing with
2213 // TrueOperand, And inverts for False operand.
2214 bool isaAndOp = isa<AndOp>(subExpr);
2215 if (isTrueOperand ^ isaAndOp)
2216 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2217
2218 auto extendedCond =
2219 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2220
2221 // Cache this information before subExpr is erased by extraction below.
2222 bool isaXorOp = isa<XorOp>(subExpr);
2223 bool isaOrOp = isa<OrOp>(subExpr);
2224
2225 // Handle the fully associative ops, start by pulling out the subexpression
2226 // from a many operand version of the op.
2227 auto restOfAssoc =
2228 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2229
2230 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2231 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2232 if (isaOrOp || isaXorOp) {
2233 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2234 restOfAssoc, false);
2235 if (isaXorOp)
2236 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, masked, commonValue,
2237 false);
2238 else
2239 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, masked, commonValue,
2240 false);
2241 return true;
2242 }
2243
2244 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2245 assert(isaAndOp && "unexpected operation here");
2246 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2247 restOfAssoc, false);
2248 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, masked, commonValue,
2249 false);
2250 return true;
2251}
2252
2253/// This function is invoke when we find a mux with true/false operations that
2254/// have the same opcode. Check to see if we can strength reduce the mux by
2255/// applying it to less data by applying this transformation:
2256/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2257static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2258 Operation *falseOp,
2259 PatternRewriter &rewriter) {
2260 // Right now we only apply to concat.
2261 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2262 if (!isa<ConcatOp>(trueOp))
2263 return false;
2264
2265 // Decode the operands, looking through recursive concats and replicates.
2266 SmallVector<Value> trueOperands, falseOperands;
2267 getConcatOperands(trueOp->getResult(0), trueOperands);
2268 getConcatOperands(falseOp->getResult(0), falseOperands);
2269
2270 size_t numTrueOperands = trueOperands.size();
2271 size_t numFalseOperands = falseOperands.size();
2272
2273 if (!numTrueOperands || !numFalseOperands ||
2274 (trueOperands.front() != falseOperands.front() &&
2275 trueOperands.back() != falseOperands.back()))
2276 return false;
2277
2278 // Pull all leading shared operands out into their own op if any are common.
2279 if (trueOperands.front() == falseOperands.front()) {
2280 SmallVector<Value> operands;
2281 size_t i;
2282 for (i = 0; i < numTrueOperands; ++i) {
2283 Value trueOperand = trueOperands[i];
2284 if (trueOperand == falseOperands[i])
2285 operands.push_back(trueOperand);
2286 else
2287 break;
2288 }
2289 if (i == numTrueOperands) {
2290 // Selecting between distinct, but lexically identical, concats.
2291 replaceOpAndCopyName(rewriter, mux, trueOp->getResult(0));
2292 return true;
2293 }
2294
2295 Value sharedMSB;
2296 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2297 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2298 mux->getLoc(), operands.front(), operands.size());
2299 else
2300 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2301 operands.clear();
2302
2303 // Get a concat of the LSB's on each side.
2304 operands.append(trueOperands.begin() + i, trueOperands.end());
2305 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2306 operands.clear();
2307 operands.append(falseOperands.begin() + i, falseOperands.end());
2308 Value falseLSB =
2309 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2310 // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2311 // done.
2312 Value lsb = rewriter.createOrFold<MuxOp>(
2313 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2314 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2315 return true;
2316 }
2317
2318 // If trailing operands match, try to commonize them.
2319 if (trueOperands.back() == falseOperands.back()) {
2320 SmallVector<Value> operands;
2321 size_t i;
2322 for (i = 0;; ++i) {
2323 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2324 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2325 operands.push_back(trueOperand);
2326 else
2327 break;
2328 }
2329 std::reverse(operands.begin(), operands.end());
2330 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2331 operands.clear();
2332
2333 // Get a concat of the MSB's on each side.
2334 operands.append(trueOperands.begin(), trueOperands.end() - i);
2335 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2336 operands.clear();
2337 operands.append(falseOperands.begin(), falseOperands.end() - i);
2338 Value falseMSB =
2339 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2340 // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2341 Value msb = rewriter.createOrFold<MuxOp>(
2342 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2343 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, mux, msb, sharedLSB);
2344 return true;
2345 }
2346
2347 return false;
2348}
2349
2350// If both arguments of the mux are arrays with the same elements, sink the
2351// mux and return a uniform array initializing all elements to it.
2352static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2353 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2354 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2355 if (!trueVec || !falseVec)
2356 return false;
2357 if (!trueVec.isUniform() || !falseVec.isUniform())
2358 return false;
2359
2360 auto mux = rewriter.create<MuxOp>(
2361 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2362 falseVec.getUniformElement(), op.getTwoState());
2363
2364 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2365 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2366 return true;
2367}
2368
2369namespace {
2370struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2371 using OpRewritePattern::OpRewritePattern;
2372
2373 LogicalResult matchAndRewrite(MuxOp op,
2374 PatternRewriter &rewriter) const override;
2375};
2376
2377LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2378 PatternRewriter &rewriter) const {
2379 if (hasOperandsOutsideOfBlock(&*op))
2380 return failure();
2381
2382 // If the op has a SV attribute, don't optimize it.
2383 if (hasSVAttributes(op))
2384 return failure();
2385 APInt value;
2386
2387 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2388 if (value.getBitWidth() == 1) {
2389 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2390 if (value.isZero()) {
2391 auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2392 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, notCond,
2393 op.getFalseValue(), false);
2394 return success();
2395 }
2396
2397 // mux(a, 1, b) -> or(a, b) for single-bit values.
2398 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, op.getCond(),
2399 op.getFalseValue(), false);
2400 return success();
2401 }
2402
2403 // Check for mux of two constants. There are many ways to simplify them.
2404 APInt value2;
2405 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2406 // When both inputs are constants and differ by only one bit, we can
2407 // simplify by splitting the mux into up to three contiguous chunks: one
2408 // for the differing bit and up to two for the bits that are the same.
2409 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2410 APInt xorValue = value ^ value2;
2411 if (xorValue.isPowerOf2()) {
2412 unsigned leadingZeros = xorValue.countLeadingZeros();
2413 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2414 SmallVector<Value, 3> operands;
2415
2416 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2417 // order of bit indexes.
2418 // For the chunks that are identical (i.e. correspond to 0s in
2419 // xorValue), we can extract directly from either input value, and we
2420 // arbitrarily pick the trueValue().
2421
2422 if (leadingZeros > 0)
2423 operands.push_back(rewriter.createOrFold<ExtractOp>(
2424 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2425
2426 // Handle the differing bit, which should simplify into either cond or
2427 // ~cond.
2428 auto v1 = rewriter.createOrFold<ExtractOp>(
2429 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2430 auto v2 = rewriter.createOrFold<ExtractOp>(
2431 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2432 operands.push_back(rewriter.createOrFold<MuxOp>(
2433 op.getLoc(), op.getCond(), v1, v2, false));
2434
2435 if (trailingZeros > 0)
2436 operands.push_back(rewriter.createOrFold<ExtractOp>(
2437 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2438
2439 replaceOpWithNewOpAndCopyName<ConcatOp>(rewriter, op, op.getType(),
2440 operands);
2441 return success();
2442 }
2443
2444 // If the true value is all ones and the false is all zeros then we have a
2445 // replicate pattern.
2446 if (value.isAllOnes() && value2.isZero()) {
2447 replaceOpWithNewOpAndCopyName<ReplicateOp>(rewriter, op, op.getType(),
2448 op.getCond());
2449 return success();
2450 }
2451 }
2452 }
2453
2454 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2455 value.getBitWidth() == 1) {
2456 // mux(a, b, 0) -> and(a, b) for single-bit values.
2457 if (value.isZero()) {
2458 replaceOpWithNewOpAndCopyName<AndOp>(rewriter, op, op.getCond(),
2459 op.getTrueValue(), false);
2460 return success();
2461 }
2462
2463 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2464 // falseValue() is known to be a single-bit 1, which we can use for
2465 // the 1 in the representation of ~ using xor.
2466 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2467 op.getFalseValue(), false);
2468 replaceOpWithNewOpAndCopyName<OrOp>(rewriter, op, notCond,
2469 op.getTrueValue(), false);
2470 return success();
2471 }
2472
2473 // mux(!a, b, c) -> mux(a, c, b)
2474 Value subExpr;
2475 Operation *condOp = op.getCond().getDefiningOp();
2476 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2477 op.getTwoState()) {
2478 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, op.getType(), subExpr,
2479 op.getFalseValue(), op.getTrueValue(),
2480 true);
2481 return success();
2482 }
2483
2484 // Same but with Demorgan's law.
2485 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2486 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2487 if (condOp && condOp->hasOneUse()) {
2488 SmallVector<Value> invertedOperands;
2489
2490 /// Scan all the operands to see if they are complemented. If so, build a
2491 /// vector of them and return true, otherwise return false.
2492 auto getInvertedOperands = [&]() -> bool {
2493 for (Value operand : condOp->getOperands()) {
2494 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2495 invertedOperands.push_back(subExpr);
2496 else
2497 return false;
2498 }
2499 return true;
2500 };
2501
2502 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2503 auto newOr =
2504 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2505 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newOr,
2506 op.getFalseValue(),
2507 op.getTrueValue(), op.getTwoState());
2508 return success();
2509 }
2510 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2511 auto newAnd =
2512 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2513 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, newAnd,
2514 op.getFalseValue(),
2515 op.getTrueValue(), op.getTwoState());
2516 return success();
2517 }
2518 }
2519
2520 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2521 falseMux && falseMux != op) {
2522 // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2523 if (op.getCond() == falseMux.getCond()) {
2524 replaceOpWithNewOpAndCopyName<MuxOp>(
2525 rewriter, op, op.getCond(), op.getTrueValue(),
2526 falseMux.getFalseValue(), op.getTwoStateAttr());
2527 return success();
2528 }
2529
2530 // Check to see if we can fold a mux tree into an array_create/get pair.
2531 if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2532 return success();
2533 }
2534
2535 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2536 trueMux && trueMux != op) {
2537 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2538 if (op.getCond() == trueMux.getCond()) {
2539 replaceOpWithNewOpAndCopyName<MuxOp>(
2540 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2541 op.getFalseValue(), op.getTwoStateAttr());
2542 return success();
2543 }
2544
2545 // Check to see if we can fold a mux tree into an array_create/get pair.
2546 if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2547 return success();
2548 }
2549
2550 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2551 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2552 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2553 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2554 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2555 falseMux != op) {
2556 auto subMux = rewriter.create<MuxOp>(
2557 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2558 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2559 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2560 trueMux.getTrueValue(), subMux,
2561 op.getTwoStateAttr());
2562 return success();
2563 }
2564
2565 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2566 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2567 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2568 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2569 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2570 falseMux != op) {
2571 auto subMux = rewriter.create<MuxOp>(
2572 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2573 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2574 replaceOpWithNewOpAndCopyName<MuxOp>(rewriter, op, trueMux.getCond(),
2575 subMux, trueMux.getFalseValue(),
2576 op.getTwoStateAttr());
2577 return success();
2578 }
2579
2580 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2581 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2582 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2583 trueMux && falseMux &&
2584 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2585 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2586 falseMux != op) {
2587 auto subMux = rewriter.create<MuxOp>(
2588 rewriter.getFusedLoc(
2589 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2590 op.getCond(), trueMux.getCond(), falseMux.getCond());
2591 replaceOpWithNewOpAndCopyName<MuxOp>(
2592 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2593 op.getTwoStateAttr());
2594 return success();
2595 }
2596
2597 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2598 if (foldCommonMuxValue(op, false, rewriter))
2599 return success();
2600 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2601 if (foldCommonMuxValue(op, true, rewriter))
2602 return success();
2603
2604 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2605 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2606 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2607 if (trueOp->getName() == falseOp->getName())
2608 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2609 return success();
2610
2611 // extracts only of mux(...) -> mux(extract()...)
2612 if (narrowOperationWidth(op, true, rewriter))
2613 return success();
2614
2615 // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2616 if (foldMuxOfUniformArrays(op, rewriter))
2617 return success();
2618
2619 return failure();
2620}
2621
2622static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2623 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2624 if (op.getInputs().empty() || op.isUniform())
2625 return false;
2626 auto inputs = op.getInputs();
2627 if (inputs.size() <= 1)
2628 return false;
2629
2630 // Check the operands to the array create. Ensure all of them are the
2631 // same op with the same number of operands.
2632 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2633 if (!first || hasSVAttributes(first))
2634 return false;
2635
2636 // Check whether all operands are muxes with the same condition.
2637 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2638 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2639 if (!input || first.getCond() != input.getCond())
2640 return false;
2641 }
2642
2643 // Collect the true and the false branches into arrays.
2644 SmallVector<Value> trues{first.getTrueValue()};
2645 SmallVector<Value> falses{first.getFalseValue()};
2646 SmallVector<Location> locs{first->getLoc()};
2647 bool isTwoState = true;
2648 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2649 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2650 trues.push_back(input.getTrueValue());
2651 falses.push_back(input.getFalseValue());
2652 locs.push_back(input->getLoc());
2653 if (!input.getTwoState())
2654 isTwoState = false;
2655 }
2656
2657 // Define the location of the array create as the aggregate of all muxes.
2658 auto loc = FusedLoc::get(op.getContext(), locs);
2659
2660 // Replace the create with an aggregate operation. Push the create op
2661 // into the operands of the aggregate operation.
2662 auto arrayTy = op.getType();
2663 auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
2664 auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
2665 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2666 trueValues, falseValues, isTwoState);
2667 return true;
2668}
2669
2670struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2671 using OpRewritePattern::OpRewritePattern;
2672
2673 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2674 PatternRewriter &rewriter) const override {
2675 if (hasOperandsOutsideOfBlock(&*op))
2676 return failure();
2677
2678 if (foldArrayOfMuxes(op, rewriter))
2679 return success();
2680 return failure();
2681 }
2682};
2683
2684} // namespace
2685
2686void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2687 MLIRContext *context) {
2688 results.insert<MuxRewriter, ArrayRewriter>(context);
2689}
2690
2691//===----------------------------------------------------------------------===//
2692// ICmpOp
2693//===----------------------------------------------------------------------===//
2694
2695// Calculate the result of a comparison when the LHS and RHS are both
2696// constants.
2697static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2698 const APInt &rhs) {
2699 switch (predicate) {
2700 case ICmpPredicate::eq:
2701 return lhs.eq(rhs);
2702 case ICmpPredicate::ne:
2703 return lhs.ne(rhs);
2704 case ICmpPredicate::slt:
2705 return lhs.slt(rhs);
2706 case ICmpPredicate::sle:
2707 return lhs.sle(rhs);
2708 case ICmpPredicate::sgt:
2709 return lhs.sgt(rhs);
2710 case ICmpPredicate::sge:
2711 return lhs.sge(rhs);
2712 case ICmpPredicate::ult:
2713 return lhs.ult(rhs);
2714 case ICmpPredicate::ule:
2715 return lhs.ule(rhs);
2716 case ICmpPredicate::ugt:
2717 return lhs.ugt(rhs);
2718 case ICmpPredicate::uge:
2719 return lhs.uge(rhs);
2720 case ICmpPredicate::ceq:
2721 return lhs.eq(rhs);
2722 case ICmpPredicate::cne:
2723 return lhs.ne(rhs);
2724 case ICmpPredicate::weq:
2725 return lhs.eq(rhs);
2726 case ICmpPredicate::wne:
2727 return lhs.ne(rhs);
2728 }
2729 llvm_unreachable("unknown comparison predicate");
2730}
2731
2732// Returns the result of applying the predicate when the LHS and RHS are the
2733// exact same value.
2734static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2735 switch (predicate) {
2736 case ICmpPredicate::eq:
2737 case ICmpPredicate::sle:
2738 case ICmpPredicate::sge:
2739 case ICmpPredicate::ule:
2740 case ICmpPredicate::uge:
2741 case ICmpPredicate::ceq:
2742 case ICmpPredicate::weq:
2743 return true;
2744 case ICmpPredicate::ne:
2745 case ICmpPredicate::slt:
2746 case ICmpPredicate::sgt:
2747 case ICmpPredicate::ult:
2748 case ICmpPredicate::ugt:
2749 case ICmpPredicate::cne:
2750 case ICmpPredicate::wne:
2751 return false;
2752 }
2753 llvm_unreachable("unknown comparison predicate");
2754}
2755
2756OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2757 if (hasOperandsOutsideOfBlock(getOperation()))
2758 return {};
2759
2760 // gt a, a -> false
2761 // gte a, a -> true
2762 if (getLhs() == getRhs()) {
2763 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2764 return IntegerAttr::get(getType(), val);
2765 }
2766
2767 // gt 1, 2 -> false
2768 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2769 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2770 auto val =
2771 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2772 return IntegerAttr::get(getType(), val);
2773 }
2774 }
2775 return {};
2776}
2777
/// Count how many leading elements `a` and `b` have in common.
///
/// Elements are compared position by position (no cross-element matching);
/// counting stops at the first mismatch or when either range runs out. Passing
/// reversed ranges therefore yields the length of the common suffix.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto aIt = a.begin(), aEnd = a.end();
  auto bIt = b.begin(), bEnd = b.end();
  size_t length = 0;
  while (aIt != aEnd && bIt != bEnd && *aIt == *bIt) {
    ++aIt;
    ++bIt;
    ++length;
  }
  return length;
}
2794
2795static size_t getTotalWidth(ArrayRef<Value> operands) {
2796 size_t totalWidth = 0;
2797 for (auto operand : operands) {
2798 // getIntOrFloatBitWidth should never raise, since all arguments to
2799 // ConcatOp are integers.
2800 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2801 assert(width >= 0);
2802 totalWidth += width;
2803 }
2804 return totalWidth;
2805}
2806
/// Reduce the strength icmp(concat(...), concat(...)) by doing an element-wise
/// comparison on common prefix and suffixes. Returns success() if a rewriting
/// happens. This handles both concat and replicate.
///
/// Both sides are flattened with getConcatOperands. Elements that match
/// position-by-position at the front (prefix) and at the back (suffix) are
/// dropped. For signed predicates that have a common prefix, the sign bit of
/// that prefix is kept in front of the remaining operands so the signed
/// ordering is unchanged.
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
                                                  Operation *rhs,
                                                  PatternRewriter &rewriter) {
  // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
  // all elements have non-zero length. Both these invariants are verified
  // by the ConcatOp verifier.
  SmallVector<Value> lhsOperands, rhsOperands;
  getConcatOperands(lhs->getResult(0), lhsOperands);
  getConcatOperands(rhs->getResult(0), rhsOperands);
  ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;

  // Rebuild a value from a non-empty operand list: a replicate when every
  // element is the same value, a concat otherwise.
  auto formCatOrReplicate = [&](Location loc,
                                ArrayRef<Value> operands) -> Value {
    assert(!operands.empty());
    Value sameElement = operands[0];
    for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
      if (sameElement != operands[i])
        sameElement = Value();
    if (sameElement)
      return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
                                                operands.size());
    return rewriter.createOrFold<ConcatOp>(loc, operands);
  };

  // Replace `op` with a comparison of new operands, keeping its two-state
  // flag.
  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
                         Value rhs) -> LogicalResult {
    replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
                                          op.getTwoState());
    return success();
  };

  size_t commonPrefixLength =
      computeCommonPrefixLength(lhsOperands, rhsOperands);
  if (commonPrefixLength == lhsOperands.size()) {
    // Identical operand lists, e.g. cat(a, b, c) == cat(a, b, c) -> 1. The
    // result depends only on whether the predicate holds for equal operands.
    bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
    replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
                                                  APInt(1, result));
    return success();
  }

  size_t commonSuffixLength = computeCommonPrefixLength(
      llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));

  size_t commonPrefixTotalWidth =
      getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
  size_t commonSuffixTotalWidth =
      getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
  // The operands that remain once the common prefix and suffix are removed.
  auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
                     .drop_back(commonSuffixLength);
  auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
                     .drop_back(commonSuffixLength);

  // Compare just the differing middle parts.
  auto replaceWithoutReplicatingSignBit = [&]() {
    auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
    auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
    return replaceWith(op.getPredicate(), newLhs, newRhs);
  };

  // Compare the differing middle parts, each prefixed with the sign bit of the
  // common prefix. This is only called when the common prefix is non-empty,
  // so lhsOperands[0] is also rhsOperands[0].
  auto replaceWithReplicatingSignBit = [&]() {
    auto firstNonEmptyValue = lhsOperands[0];
    auto firstNonEmptyElemWidth =
        firstNonEmptyValue.getType().getIntOrFloatBitWidth();
    Value signBit = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);

    auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
    auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
    return replaceWith(op.getPredicate(), newLhs, newRhs);
  };

  if (ICmpOp::isPredicateSigned(op.getPredicate())) {
    // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
    if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
      return replaceWithoutReplicatingSignBit();

    // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
    // cat(sgn(a), ..y)) Note that we cannot perform this optimization if
    // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
    // bit. Doing the rewrite can result in an infinite loop.
    if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
      return replaceWithReplicatingSignBit();

  } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
    // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
    return replaceWithoutReplicatingSignBit();
  }

  return failure();
}
2900
2901/// Given an equality comparison with a constant value and some operand that has
2902/// known bits, simplify the comparison to check only the unknown bits of the
2903/// input.
2904///
2905/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2906/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2907/// `extract x[1:0] == 0`
2909 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2910 PatternRewriter &rewriter) {
2911
2912 // If any of the known bits disagree with any of the comparison bits, then
2913 // we can constant fold this comparison right away.
2914 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2915 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2916 // If we discover a mismatch then we know an "eq" comparison is false
2917 // and a "ne" comparison is true!
2918 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2919 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2920 APInt(1, result));
2921 return;
2922 }
2923
2924 // Check to see if we can prove the result entirely of the comparison (in
2925 // which we bail out early), otherwise build a list of values to concat and a
2926 // smaller constant to compare against.
2927 SmallVector<Value> newConcatOperands;
2928 auto newConstant = APInt::getZeroWidth();
2929
2930 // Ok, some (maybe all) bits are known and some others may be unknown.
2931 // Extract out segments of the operand and compare against the
2932 // corresponding bits.
2933 unsigned knownMSB = bitsKnown.countLeadingOnes();
2934
2935 Value operand = cmpOp.getLhs();
2936
2937 // Ok, some bits are known but others are not. Extract out sequences of
2938 // bits that are unknown and compare just those bits. We work from MSB to
2939 // LSB.
2940 while (knownMSB != bitsKnown.getBitWidth()) {
2941 // Drop any high bits that are known.
2942 if (knownMSB)
2943 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2944
2945 // Find the span of unknown bits, and extract it.
2946 unsigned unknownBits = bitsKnown.countLeadingZeros();
2947 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2948 auto spanOperand = rewriter.createOrFold<ExtractOp>(
2949 operand.getLoc(), operand, /*lowBit=*/lowBit,
2950 /*bitWidth=*/unknownBits);
2951 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2952
2953 // Add this info to the concat we're generating.
2954 newConcatOperands.push_back(spanOperand);
2955 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
2956 // newConstant = newConstant.concat(spanConstant);
2957 if (newConstant.getBitWidth() != 0)
2958 newConstant = newConstant.concat(spanConstant);
2959 else
2960 newConstant = spanConstant;
2961
2962 // Drop the unknown bits in prep for the next chunk.
2963 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2964 bitsKnown = bitsKnown.trunc(newWidth);
2965 knownMSB = bitsKnown.countLeadingOnes();
2966 }
2967
2968 // If all the operands to the concat are foldable then we have an identity
2969 // situation where all the sub-elements equal each other. This implies that
2970 // the overall result is foldable.
2971 if (newConcatOperands.empty()) {
2972 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2973 replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, cmpOp,
2974 APInt(1, result));
2975 return;
2976 }
2977
2978 // If we have a single operand remaining, use it, otherwise form a concat.
2979 Value concatResult =
2980 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
2981
2982 // Form the comparison against the smaller constant.
2983 auto newConstantOp = rewriter.create<hw::ConstantOp>(
2984 cmpOp.getOperand(1).getLoc(), newConstant);
2985
2986 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
2987 concatResult, newConstantOp,
2988 cmpOp.getTwoState());
2989}
2990
2991// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
2992static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
2993 const APInt &rhs,
2994 PatternRewriter &rewriter) {
2995 auto ip = rewriter.saveInsertionPoint();
2996 rewriter.setInsertionPoint(xorOp);
2997
2998 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
2999 auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
3000 xorRHS.getValue() ^ rhs);
3001 Value newLHS;
3002 switch (xorOp.getNumOperands()) {
3003 case 1:
3004 // This isn't common but is defined so we need to handle it.
3005 newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
3006 APInt::getZero(rhs.getBitWidth()));
3007 break;
3008 case 2:
3009 // The binary case is the most common.
3010 newLHS = xorOp.getOperand(0);
3011 break;
3012 default:
3013 // The general case forces us to form a new xor with the remaining operands.
3014 SmallVector<Value> newOperands(xorOp.getOperands());
3015 newOperands.pop_back();
3016 newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
3017 break;
3018 }
3019
3020 bool xorMultipleUses = !xorOp->hasOneUse();
3021
3022 // If the xor has multiple uses (not just the compare, then we need/want to
3023 // replace them as well.
3024 if (xorMultipleUses)
3025 replaceOpWithNewOpAndCopyName<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3026 false);
3027
3028 // Replace the comparison.
3029 rewriter.restoreInsertionPoint(ip);
3030 replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, cmpOp, cmpOp.getPredicate(),
3031 newLHS, newRHS, false);
3032}
3033
/// Canonicalize an integer comparison. The rewrites are tried in this order:
///  - a constant LHS is moved to the RHS, flipping the predicate;
///  - with a constant RHS, comparisons against a range boundary (the min/max
///    value or a neighbour of it) become eq/ne or a constant, non-strict
///    predicates become strict ones, and some unsigned comparisons become a
///    test on an extracted high-bit range;
///  - for eq/ne against a constant, known bits of the LHS, an xor whose last
///    operand is a constant, or a replicate are used to compare less data;
///  - when both sides are concat/replicate, common prefix and suffix elements
///    are stripped (see matchAndRewriteCompareConcat).
/// Comparisons with operands defined in another block are left untouched.
LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
  if (hasOperandsOutsideOfBlock(&*op))
    return failure();

  APInt lhs, rhs;

  // icmp 1, x -> icmp x, 1
  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
    assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
           "Should be folded");
    replaceOpWithNewOpAndCopyName<ICmpOp>(
        rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
        op.getRhs(), op.getLhs(), op.getTwoState());
    return success();
  }

  // Canonicalize with RHS constant
  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
    // Materialize a new constant at the location of the comparison.
    auto getConstant = [&](APInt constant) -> Value {
      return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(constant));
    };

    // Replace the comparison with a new one, keeping its two-state flag.
    auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
                           Value rhs) -> LogicalResult {
      replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, predicate, lhs, rhs,
                                            op.getTwoState());
      return success();
    };

    // Replace the comparison with a constant i1.
    auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
      replaceOpWithNewOpAndCopyName<hw::ConstantOp>(rewriter, op,
                                                    APInt(1, constant));
      return success();
    };

    switch (op.getPredicate()) {
    case ICmpPredicate::slt:
      // x < max -> x != max
      if (rhs.isMaxSignedValue())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isMinSignedValue())
        return replaceWithConstantI1(0);
      // x < min+1 -> x == min
      if ((rhs - 1).isMinSignedValue())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs - 1));
      break;
    case ICmpPredicate::sgt:
      // x > min -> x != min
      if (rhs.isMinSignedValue())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isMaxSignedValue())
        return replaceWithConstantI1(0);
      // x > max-1 -> x == max
      if ((rhs + 1).isMaxSignedValue())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs + 1));
      break;
    case ICmpPredicate::ult:
      // x < max -> x != max
      if (rhs.isAllOnes())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isZero())
        return replaceWithConstantI1(0);
      // x < min+1 -> x == min
      if ((rhs - 1).isZero())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs - 1));

      // x < 0xE0 -> extract(x, 5..7) != 0b111
      // rhs is a run of high ones followed only by zeros, so x < rhs exactly
      // when those top `numOnes` bits of x are not all ones.
      if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
          rhs.getBitWidth()) {
        auto numOnes = rhs.countLeadingOnes();
        auto smaller = rewriter.create<ExtractOp>(
            op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
        return replaceWith(ICmpPredicate::ne, smaller,
                           getConstant(APInt::getAllOnes(numOnes)));
      }

      break;
    case ICmpPredicate::ugt:
      // x > min -> x != min
      if (rhs.isZero())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isAllOnes())
        return replaceWithConstantI1(0);
      // x > max-1 -> x == max
      if ((rhs + 1).isAllOnes())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs + 1));

      // x > 0x07 -> extract(x, 3..7) != 0b00000
      // rhs + 1 being a power of two means rhs is a run of low ones, so
      // x > rhs exactly when some bit above that run is set.
      if ((rhs + 1).isPowerOf2()) {
        auto numOnes = rhs.countTrailingOnes();
        auto newWidth = rhs.getBitWidth() - numOnes;
        auto smaller = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
                                                  numOnes, newWidth);
        return replaceWith(ICmpPredicate::ne, smaller,
                           getConstant(APInt::getZero(newWidth)));
      }

      break;
    case ICmpPredicate::sle:
      // x <= max -> true
      if (rhs.isMaxSignedValue())
        return replaceWithConstantI1(1);
      // x <= c -> x < (c+1)
      return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
    case ICmpPredicate::sge:
      // x >= min -> true
      if (rhs.isMinSignedValue())
        return replaceWithConstantI1(1);
      // x >= c -> x > (c-1)
      return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
    case ICmpPredicate::ule:
      // x <= max -> true
      if (rhs.isAllOnes())
        return replaceWithConstantI1(1);
      // x <= c -> x < (c+1)
      return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
    case ICmpPredicate::uge:
      // x >= min -> true
      if (rhs.isZero())
        return replaceWithConstantI1(1);
      // x >= c -> x > (c-1)
      return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
    case ICmpPredicate::eq:
      // For i1 operands, equality with a constant is either x or !x.
      if (rhs.getBitWidth() == 1) {
        if (rhs.isZero()) {
          // x == 0 -> x ^ 1
          replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
                                               getConstant(APInt(1, 1)),
                                               op.getTwoState());
          return success();
        }
        if (rhs.isAllOnes()) {
          // x == 1 -> x
          replaceOpAndCopyName(rewriter, op, op.getLhs());
          return success();
        }
      }
      break;
    case ICmpPredicate::ne:
      // For i1 operands, inequality with a constant is either x or !x.
      if (rhs.getBitWidth() == 1) {
        if (rhs.isZero()) {
          // x != 0 -> x
          replaceOpAndCopyName(rewriter, op, op.getLhs());
          return success();
        }
        if (rhs.isAllOnes()) {
          // x != 1 -> x ^ 1
          replaceOpWithNewOpAndCopyName<XorOp>(rewriter, op, op.getLhs(),
                                               getConstant(APInt(1, 1)),
                                               op.getTwoState());
          return success();
        }
      }
      break;
    case ICmpPredicate::ceq:
    case ICmpPredicate::cne:
    case ICmpPredicate::weq:
    case ICmpPredicate::wne:
      break;
    }

    // We have some specific optimizations for comparison with a constant that
    // are only supported for equality comparisons.
    if (op.getPredicate() == ICmpPredicate::eq ||
        op.getPredicate() == ICmpPredicate::ne) {
      // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
      // with a smaller constant. We only support equality comparisons for
      // this.
      auto knownBits = computeKnownBits(op.getLhs());
      if (!knownBits.isUnknown())
        return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
                                                           rewriter),
               success();

      // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
      // cst1^cst2).
      if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
        if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
          return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
                 success();

      // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
      // all one.
      if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
        if (rhs.isAllOnes() || rhs.isZero()) {
          auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
          auto cst = rewriter.create<hw::ConstantOp>(
              op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
                                           : APInt::getZero(width));
          replaceOpWithNewOpAndCopyName<ICmpOp>(rewriter, op, op.getPredicate(),
                                                replicateOp.getInput(), cst,
                                                op.getTwoState());
          return success();
        }
    }
  }

  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
  // b), cat(c, d)). contains special handling for sign bit in signed
  // comparisons.
  if (Operation *opLHS = op.getLhs().getDefiningOp())
    if (Operation *opRHS = op.getRhs().getDefiningOp())
      if (isa<ConcatOp, ReplicateOp>(opLHS) &&
          isa<ConcatOp, ReplicateOp>(opRHS)) {
        if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
          return success();
      }

  return failure();
}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
Definition CombFolds.cpp:32
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition CombFolds.cpp:58
static OpTy replaceOpWithNewOpAndCopyName(PatternRewriter &rewriter, Operation *op, Args &&...args)
A wrapper of PatternRewriter::replaceOpWithNewOp to propagate "sv.namehint" attribute.
Definition CombFolds.cpp:88
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition CombFolds.cpp:52
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition CombFolds.cpp:44
static void replaceOpAndCopyName(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition CombFolds.cpp:73
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements)
Definition hw.py:483
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
Definition CombOps.cpp:48
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1