CIRCT 22.0.0git
Loading...
Searching...
No Matches
CombFolds.cpp
Go to the documentation of this file.
1//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
20
21using namespace mlir;
22using namespace circt;
23using namespace comb;
24using namespace matchers;
25
26// Returns true if the op has one of its own results as an operand.
27static bool isOpTriviallyRecursive(Operation *op) {
28 return llvm::any_of(op->getOperands(), [op](auto operand) {
29 return operand.getDefiningOp() == op;
30 });
31}
32
33/// Create a new instance of a generic operation that only has value operands,
34/// and has a single result value whose type matches the first operand.
35///
36/// This should not be used to create instances of ops with attributes or with
37/// more complicated type signatures.
38static Value createGenericOp(Location loc, OperationName name,
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
44}
45
46static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
47 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
48 value);
49}
50
51/// Flatten concat and mux operands into a vector.
52static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
53 if (auto concat = v.getDefiningOp<ConcatOp>()) {
54 for (auto op : concat.getOperands())
55 getConcatOperands(op, result);
56 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
58 getConcatOperands(repl.getOperand(), result);
59 } else {
60 result.push_back(v);
61 }
62}
63
64// Return true if the op has SV attributes. Note that we cannot use a helper
65// function `hasSVAttributes` defined under SV dialect because of a cyclic
66// dependency.
67static bool hasSVAttributes(Operation *op) {
68 return op->hasAttr("sv.attributes");
69}
70
71namespace {
72template <typename SubType>
73struct ComplementMatcher {
74 SubType lhs;
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
79 }
80};
81} // end anonymous namespace
82
83template <typename SubType>
84static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
85 return ComplementMatcher<SubType>(subExpr);
86}
87
88/// Return true if the op will be flattened afterwards. Op will be flattend if
89/// it has a single user which has a same op type. User must be in same block.
90static bool shouldBeFlattened(Operation *op) {
91 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
92 "must be commutative operations"));
93 if (op->hasOneUse()) {
94 auto *user = *op->getUsers().begin();
95 return user->getName() == op->getName() &&
96 op->getAttrOfType<UnitAttr>("twoState") ==
97 user->getAttrOfType<UnitAttr>("twoState") &&
98 op->getBlock() == user->getBlock();
99 }
100 return false;
101}
102
103/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
104/// as an Op. Returns true if successful, and false otherwise.
105///
106/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
107///
108static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
109 // Skip if the operation should be flattened by another operation.
110 if (shouldBeFlattened(op))
111 return false;
112
113 auto inputs = op->getOperands();
114
115 SmallVector<Value, 4> newOperands;
116 SmallVector<Location, 4> newLocations{op->getLoc()};
117 newOperands.reserve(inputs.size());
118 struct Element {
119 decltype(inputs.begin()) current, end;
120 };
121
122 SmallVector<Element> worklist;
123 worklist.push_back({inputs.begin(), inputs.end()});
124 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
125 bool changed = false;
126 while (!worklist.empty()) {
127 auto &element = worklist.back(); // Do not pop. Take ref.
128
129 // Pop when we finished traversing the current operand range.
130 if (element.current == element.end) {
131 worklist.pop_back();
132 continue;
133 }
134
135 Value value = *element.current++;
136 auto *flattenOp = value.getDefiningOp();
137 // If not defined by a compatible operation of the same kind and
138 // from the same block, keep this as-is.
139 if (!flattenOp || flattenOp->getName() != op->getName() ||
140 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
141 flattenOp->getBlock() != op->getBlock()) {
142 newOperands.push_back(value);
143 continue;
144 }
145
146 // Don't duplicate logic when it has multiple uses.
147 if (!value.hasOneUse()) {
148 // We can fold a multi-use binary operation into this one if this allows a
149 // constant to fold though. For example, fold
150 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
151 // since the constants will both fold and we end up with the equiv cost.
152 //
153 // We don't do this for add/mul because the hardware won't be shared
154 // between the two ops if duplicated.
155 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
156 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
157 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
158 newOperands.push_back(value);
159 continue;
160 }
161 }
162
163 changed = true;
164
165 // Otherwise, push operands into worklist.
166 auto flattenOpInputs = flattenOp->getOperands();
167 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
168 newLocations.push_back(flattenOp->getLoc());
169 }
170
171 if (!changed)
172 return false;
173
174 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
175 op->getName(), newOperands, rewriter);
176 if (binFlag)
177 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
178
179 replaceOpAndCopyNamehint(rewriter, op, result);
180 return true;
181}
182
183// Given a range of uses of an operation, find the lowest and highest bits
184// inclusive that are ever referenced. The range of uses must not be empty.
185static std::pair<size_t, size_t>
186getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
187 size_t originalOpWidth) {
188 auto users = op->getUsers();
189 assert(!users.empty() &&
190 "getLowestBitAndHighestBitRequired cannot operate on "
191 "a empty list of uses.");
192
193 // when we don't want to narrowTrailingBits (namely in arithmetic
194 // operations), forcing lowestBitRequired = 0
195 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
196 size_t highestBitRequired = 0;
197
198 for (auto *user : users) {
199 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
200 size_t lowBit = extractOp.getLowBit();
201 size_t highBit =
202 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
203 highestBitRequired = std::max(highestBitRequired, highBit);
204 lowestBitRequired = std::min(lowestBitRequired, lowBit);
205 continue;
206 }
207
208 highestBitRequired = originalOpWidth - 1;
209 lowestBitRequired = 0;
210 break;
211 }
212
213 return {lowestBitRequired, highestBitRequired};
214}
215
216template <class OpTy>
217static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
218 PatternRewriter &rewriter) {
219 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
220 if (!opType)
221 return false;
222
223 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
224 opType.getWidth());
225 if (range.second + 1 == opType.getWidth() && range.first == 0)
226 return false;
227
228 SmallVector<Value> args;
229 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
230 for (auto inop : op.getOperands()) {
231 // deal with muxes here
232 if (inop.getType() != op.getType())
233 args.push_back(inop);
234 else
235 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
236 inop, range.first));
237 }
238 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
239 newop->setDialectAttrs(op->getDialectAttrs());
240 if (op.getTwoState())
241 newop.setTwoState(true);
242
243 Value newResult = newop.getResult();
244 if (range.first)
245 newResult = rewriter.createOrFold<ConcatOp>(
246 op.getLoc(), newResult,
247 hw::ConstantOp::create(rewriter, op.getLoc(),
248 APInt::getZero(range.first)));
249 if (range.second + 1 < opType.getWidth())
250 newResult = rewriter.createOrFold<ConcatOp>(
251 op.getLoc(),
253 rewriter, op.getLoc(),
254 APInt::getZero(opType.getWidth() - range.second - 1)),
255 newResult);
256 rewriter.replaceOp(op, newResult);
257 return true;
258}
259
260//===----------------------------------------------------------------------===//
261// Unary Operations
262//===----------------------------------------------------------------------===//
263
264OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
265 if (isOpTriviallyRecursive(*this))
266 return {};
267
268 // Replicate one time -> noop.
269 if (cast<IntegerType>(getType()).getWidth() ==
270 getInput().getType().getIntOrFloatBitWidth())
271 return getInput();
272
273 // Constant fold.
274 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
275 if (input.getValue().getBitWidth() == 1) {
276 if (input.getValue().isZero())
277 return getIntAttr(
278 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
279 getContext());
280 return getIntAttr(
281 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
282 getContext());
283 }
284
285 APInt result = APInt::getZeroWidth();
286 for (auto i = getMultiple(); i != 0; --i)
287 result = result.concat(input.getValue());
288 return getIntAttr(result, getContext());
289 }
290
291 return {};
292}
293
294OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
295 if (isOpTriviallyRecursive(*this))
296 return {};
297
298 // Constant fold.
299 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
300 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
301
302 return {};
303}
304
305//===----------------------------------------------------------------------===//
306// Binary Operations
307//===----------------------------------------------------------------------===//
308
309/// Performs constant folding `calculate` with element-wise behavior on the two
310/// attributes in `operands` and returns the result if possible.
311static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
312 hw::PEO paramOpcode) {
313 assert(operands.size() == 2 && "binary op takes two operands");
314 if (!operands[0] || !operands[1])
315 return {};
316
317 // Fold constants with ParamExprAttr::get which handles simple constants as
318 // well as parameter expressions.
319 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
320 cast<TypedAttr>(operands[1]));
321}
322
323OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
324 if (isOpTriviallyRecursive(*this))
325 return {};
326
327 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
328 if (rhs.getValue().isZero())
329 return getOperand(0);
330
331 unsigned width = getType().getIntOrFloatBitWidth();
332 if (rhs.getValue().uge(width))
333 return getIntAttr(APInt::getZero(width), getContext());
334 }
335 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
336}
337
338LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
340 return failure();
341
342 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
343 APInt value;
344 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
345 return failure();
346
347 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
348 if (value.ugt(width))
349 value = width;
350 unsigned shift = value.getZExtValue();
351
352 // This case is handled by fold.
353 if (width <= shift || shift == 0)
354 return failure();
355
356 auto zeros =
357 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
358
359 // Remove the high bits which would be removed by the Shl.
360 auto extract =
361 ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), 0, width - shift);
362
363 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
364 return success();
365}
366
367OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
368 if (isOpTriviallyRecursive(*this))
369 return {};
370
371 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
372 if (rhs.getValue().isZero())
373 return getOperand(0);
374
375 unsigned width = getType().getIntOrFloatBitWidth();
376 if (rhs.getValue().uge(width))
377 return getIntAttr(APInt::getZero(width), getContext());
378 }
379 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
380}
381
382LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
384 return failure();
385
386 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
387 APInt value;
388 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
389 return failure();
390
391 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
392 if (value.ugt(width))
393 value = width;
394 unsigned shift = value.getZExtValue();
395
396 // This case is handled by fold.
397 if (width <= shift || shift == 0)
398 return failure();
399
400 auto zeros =
401 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
402
403 // Remove the low bits which would be removed by the Shr.
404 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
405 width - shift);
406
407 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
408 return success();
409}
410
411OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
412 if (isOpTriviallyRecursive(*this))
413 return {};
414
415 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
416 if (rhs.getValue().isZero())
417 return getOperand(0);
418 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
419}
420
421LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
423 return failure();
424
425 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
426 APInt value;
427 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
428 return failure();
429
430 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
431 if (value.ugt(width))
432 value = width;
433 unsigned shift = value.getZExtValue();
434
435 auto topbit =
436 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
437 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
438
439 if (width == shift) {
440 replaceOpAndCopyNamehint(rewriter, op, {sext});
441 return success();
442 }
443
444 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
445 width - shift);
446
447 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
448 return success();
449}
450
451//===----------------------------------------------------------------------===//
452// Other Operations
453//===----------------------------------------------------------------------===//
454
455OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
456 if (isOpTriviallyRecursive(*this))
457 return {};
458
459 // If we are extracting the entire input, then return it.
460 if (getInput().getType() == getType())
461 return getInput();
462
463 // Constant fold.
464 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
465 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
466 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
467 getContext());
468 }
469 return {};
470}
471
472// Transforms extract(lo, cat(a, b, c, d, e)) into
473// cat(extract(lo1, b), c, extract(lo2, d)).
474// innerCat must be the argument of the provided ExtractOp.
476 ConcatOp innerCat,
477 PatternRewriter &rewriter) {
478 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
479 size_t beginOfFirstRelevantElement = 0;
480 auto it = reversedConcatArgs.begin();
481 size_t lowBit = op.getLowBit();
482
483 // This loop finds the first concatArg that is covered by the ExtractOp
484 for (; it != reversedConcatArgs.end(); it++) {
485 assert(beginOfFirstRelevantElement <= lowBit &&
486 "incorrectly moved past an element that lowBit has coverage over");
487 auto operand = *it;
488
489 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
490 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
491 // A bit other than the first bit will be used in this element.
492 // ...... ........ ...
493 // ^---lowBit
494 // ^---beginOfFirstRelevantElement
495 //
496 // Edge-case close to the end of the range.
497 // ...... ........ ...
498 // ^---(position + operandWidth)
499 // ^---lowBit
500 // ^---beginOfFirstRelevantElement
501 //
502 // Edge-case close to the beginning of the rang
503 // ...... ........ ...
504 // ^---lowBit
505 // ^---beginOfFirstRelevantElement
506 //
507 break;
508 }
509
510 // extraction discards this element.
511 // ...... ........ ...
512 // | ^---lowBit
513 // ^---beginOfFirstRelevantElement
514 beginOfFirstRelevantElement += operandWidth;
515 }
516 assert(it != reversedConcatArgs.end() &&
517 "incorrectly failed to find an element which contains coverage of "
518 "lowBit");
519
520 SmallVector<Value> reverseConcatArgs;
521 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
522 size_t extractLo = lowBit - beginOfFirstRelevantElement;
523
524 // Transform individual arguments of innerCat(..., a, b, c,) into
525 // [ extract(a), b, extract(c) ], skipping an extract operation where
526 // possible.
527 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
528 auto concatArg = *it;
529 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
530 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
531
532 if (widthToConsume == operandWidth && extractLo == 0) {
533 reverseConcatArgs.push_back(concatArg);
534 } else {
535 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
536 reverseConcatArgs.push_back(
537 ExtractOp::create(rewriter, op.getLoc(), resultType, *it, extractLo));
538 }
539
540 widthRemaining -= widthToConsume;
541
542 // Beyond the first element, all elements are extracted from position 0.
543 extractLo = 0;
544 }
545
546 if (reverseConcatArgs.size() == 1) {
547 replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
548 } else {
549 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
550 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
551 }
552 return success();
553}
554
555// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
556static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
557 PatternRewriter &rewriter) {
558 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
559 auto replicateEltWidth =
560 replicate.getOperand().getType().getIntOrFloatBitWidth();
561
562 // If the extract starts at the base of an element and is an even multiple,
563 // we can replace the extract with a smaller replicate.
564 if (op.getLowBit() % replicateEltWidth == 0 &&
565 extractResultWidth % replicateEltWidth == 0) {
566 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
567 replicate.getOperand());
568 return true;
569 }
570
571 // If the extract is completely contained in one element, extract from the
572 // element.
573 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
574 replicateEltWidth) {
575 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
576 rewriter, op, op.getType(), replicate.getOperand(),
577 op.getLowBit() % replicateEltWidth);
578 return true;
579 }
580
581 // We don't currently handle the case of extracting from non-whole elements,
582 // e.g. `extract (replicate 2-bit-thing, N), 1`.
583 return false;
584}
585
586LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
588 return failure();
589 auto *inputOp = op.getInput().getDefiningOp();
590
591 // This turns out to be incredibly expensive. Disable until performance is
592 // addressed.
593#if 0
594 // If the extracted bits are all known, then return the result.
595 auto knownBits = computeKnownBits(op.getInput())
596 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
597 op.getLowBit());
598 if (knownBits.isConstant()) {
599 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
600 knownBits.getConstant());
601 return success();
602 }
603#endif
604
605 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
606 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
607 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
608 rewriter, op, op.getType(), innerExtract.getInput(),
609 innerExtract.getLowBit() + op.getLowBit());
610 return success();
611 }
612
613 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
614 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
615 return extractConcatToConcatExtract(op, innerCat, rewriter);
616
617 // extract(lo, replicate(a))
618 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
619 if (extractFromReplicate(op, replicate, rewriter))
620 return success();
621
622 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
623 // and/or/xor are not modifying the extracted bits.
624 if (inputOp && inputOp->getNumOperands() == 2 &&
625 isa<AndOp, OrOp, XorOp>(inputOp)) {
626 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
627 auto extractedCst = cstRHS.getValue().extractBits(
628 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
629 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
630 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
631 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
632 return success();
633 }
634
635 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
636 // extract to represent the result. Turning it into a pile of extracts is
637 // always fine by our cost model, but we don't want to explode things into
638 // a ton of bits because it will bloat the IR and generated Verilog.
639 if (isa<AndOp>(inputOp)) {
640 // For our cost model, we only do this if the bit pattern is a
641 // contiguous series of ones.
642 unsigned lz = extractedCst.countLeadingZeros();
643 unsigned tz = extractedCst.countTrailingZeros();
644 unsigned pop = extractedCst.popcount();
645 if (extractedCst.getBitWidth() - lz - tz == pop) {
646 auto resultTy = rewriter.getIntegerType(pop);
647 SmallVector<Value> resultElts;
648 if (lz)
649 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
650 APInt::getZero(lz)));
651 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
652 op.getLoc(), resultTy, inputOp->getOperand(0),
653 op.getLowBit() + tz));
654 if (tz)
655 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
656 APInt::getZero(tz)));
657 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
658 return success();
659 }
660 }
661 }
662 }
663
664 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
665 // extracted.
666 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
667 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
668 // Don't canonicalize if the shift is multiply used.
669 if (shlOp->hasOneUse())
670 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
671 if (lhsCst.getValue().isOne()) {
672 auto newCst = hw::ConstantOp::create(
673 rewriter, shlOp.getLoc(),
674 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
675 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
676 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
677 false);
678 return success();
679 }
680 }
681
682 return failure();
683}
684
685//===----------------------------------------------------------------------===//
686// Associative Variadic operations
687//===----------------------------------------------------------------------===//
688
689// Reduce all operands to a single value (either integer constant or parameter
690// expression) if all the operands are constants.
691static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
692 hw::PEO paramOpcode) {
693 assert(operands.size() > 1 && "caller should handle one-operand case");
694 // We can only fold anything in the case where all operands are known to be
695 // constants. Check the least common one first for an early out.
696 if (!operands[1] || !operands[0])
697 return {};
698
699 // This will fold to a simple constant if all operands are constant.
700 if (llvm::all_of(operands.drop_front(2),
701 [&](Attribute in) { return !!in; })) {
702 SmallVector<mlir::TypedAttr> typedOperands;
703 typedOperands.reserve(operands.size());
704 for (auto operand : operands) {
705 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
706 typedOperands.push_back(typedOperand);
707 else
708 break;
709 }
710 if (typedOperands.size() == operands.size())
711 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
712 }
713
714 return {};
715}
716
717/// When we find a logical operation (and, or, xor) with a constant e.g.
718/// `X & 42`, we want to push the constant into the computation of X if it leads
719/// to simplification.
720///
721/// This function handles the case where the logical operation has a concat
722/// operand. We check to see if we can simplify the concat, e.g. when it has
723/// constant operands.
724///
725/// This returns true when a simplification happens.
726static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
727 size_t concatIdx, const APInt &cst,
728 PatternRewriter &rewriter) {
729 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
730 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
731
732 // Check to see if any operands can be simplified by pushing the logical op
733 // into all parts of the concat.
734 bool canSimplify =
735 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
736 auto *operandOp = operand.getDefiningOp();
737 if (!operandOp)
738 return false;
739
740 // If the concat has a constant operand then we can transform this.
741 if (isa<hw::ConstantOp>(operandOp))
742 return true;
743 // If the concat has the same logical operation and that operation has
744 // a constant operation than we can fold it into that suboperation.
745 return operandOp->getName() == logicalOp->getName() &&
746 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
747 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
748 });
749
750 if (!canSimplify)
751 return false;
752
753 // Create a new instance of the logical operation. We have to do this the
754 // hard way since we're generic across a family of different ops.
755 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
756 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
757 rewriter);
758 };
759
760 // Ok, let's do the transformation. We do this by slicing up the constant
761 // for each unit of the concat and duplicate the operation into the
762 // sub-operand.
763 SmallVector<Value> newConcatOperands;
764 newConcatOperands.reserve(concatOp->getNumOperands());
765
766 // Work from MSB to LSB.
767 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
768 for (Value operand : concatOp->getOperands()) {
769 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
770 nextOperandBit -= operandWidth;
771 // Take a slice of the constant.
772 auto eltCst =
773 hw::ConstantOp::create(rewriter, logicalOp->getLoc(),
774 cst.lshr(nextOperandBit).trunc(operandWidth));
775
776 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
777 }
778
779 // Create the concat, and the rest of the logical op if we need it.
780 Value newResult =
781 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
782
783 // If we had a variadic logical op on the top level, then recreate it with the
784 // new concat and without the constant operand.
785 if (logicalOp->getNumOperands() > 2) {
786 auto origOperands = logicalOp->getOperands();
787 SmallVector<Value> operands;
788 // Take any stuff before the concat.
789 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
790 // Take any stuff after the concat but before the constant.
791 operands.append(origOperands.begin() + concatIdx + 1,
792 origOperands.begin() + (origOperands.size() - 1));
793 // Include the new concat.
794 operands.push_back(newResult);
795 newResult = createLogicalOp(operands);
796 }
797
798 replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
799 return true;
800}
801
802// Determines whether the inputs to a logical element are of opposite
803// comparisons and can lowered into a constant.
804static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
805 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
806
807 for (auto op : operands) {
808 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
809 icmpOp && icmpOp.getTwoState()) {
810 auto predicate = icmpOp.getPredicate();
811 auto lhs = icmpOp.getLhs();
812 auto rhs = icmpOp.getRhs();
813 if (seenPredicates.contains(
814 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
815 return true;
816
817 seenPredicates.insert({predicate, lhs, rhs});
818 }
819 }
820 return false;
821}
822
823OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
824 if (isOpTriviallyRecursive(*this))
825 return {};
826
827 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
828
829 auto inputs = adaptor.getInputs();
830
831 // and(x, 01, 10) -> 00 -- annulment.
832 for (auto operand : inputs) {
833 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
834 if (!attr)
835 continue;
836 value &= attr.getValue();
837 if (value.isZero())
838 return getIntAttr(value, getContext());
839 }
840
841 // and(x, -1) -> x.
842 if (inputs.size() == 2)
843 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
844 if (intAttr.getValue().isAllOnes())
845 return getInputs()[0];
846
847 // and(x, x, x) -> x. This also handles and(x) -> x.
848 if (llvm::all_of(getInputs(),
849 [&](auto in) { return in == this->getInputs()[0]; }))
850 return getInputs()[0];
851
852 // and(..., x, ..., ~x, ...) -> 0
853 for (Value arg : getInputs()) {
854 Value subExpr;
855 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
856 for (Value arg2 : getInputs())
857 if (arg2 == subExpr)
858 return getIntAttr(
859 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
860 getContext());
861 }
862 }
863
864 // x0 = icmp(pred, x, y)
865 // x1 = icmp(!pred, x, y)
866 // and(x0, x1) -> 0
868 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
869 getContext());
870
871 // Constant fold
872 return constFoldAssociativeOp(inputs, hw::PEO::And);
873}
874
875/// Returns a single common operand that all inputs of the operation `op` can
876/// be traced back to, or an empty `Value` if no such operand exists.
877///
878/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
879/// (assuming the bit-width of `a` is `n`).
880template <typename Op>
881static Value getCommonOperand(Op op) {
882 if (!op.getType().isInteger(1))
883 return Value();
884
885 auto inputs = op.getInputs();
886 size_t size = inputs.size();
887
888 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
889 if (!sourceOp)
890 return Value();
891 Value source = sourceOp.getOperand();
892
893 // Fast path: the input size is not equal to the width of the source.
894 if (size != source.getType().getIntOrFloatBitWidth())
895 return Value();
896
897 // Tracks the bits that were encountered.
898 llvm::BitVector bits(size);
899 bits.set(sourceOp.getLowBit());
900
901 for (size_t i = 1; i != size; ++i) {
902 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
903 if (!extractOp || extractOp.getOperand() != source)
904 return Value();
905 bits.set(extractOp.getLowBit());
906 }
907
908 return bits.all() ? source : Value();
909}
910
911/// Canonicalize an idempotent operation `op` so that only one input of any kind
912/// occurs.
913///
914/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
915template <typename Op>
916static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
917 // Depth limit to search, in operations. Chosen arbitrarily, keep small.
918 constexpr unsigned limit = 3;
919 auto inputs = op.getInputs();
920
921 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
922 llvm::SmallDenseSet<Op, 8> checked;
923 checked.insert(op);
924
925 struct OpWithDepth {
926 Op op;
927 unsigned depth;
928 };
929 llvm::SmallVector<OpWithDepth, 8> worklist;
930
931 auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
932 // Add to worklist if within depth limit, is defined in the same block by
933 // the same kind of operation, has same two-state-ness, and not enqueued
934 // previously.
935 if (depth < limit && input.getParentBlock() == op->getBlock()) {
936 auto inputOp = input.template getDefiningOp<Op>();
937 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
938 checked.insert(inputOp).second)
939 worklist.push_back({inputOp, depth + 1});
940 }
941 };
942
943 for (auto input : uniqueInputs)
944 enqueue(input, 0);
945
946 while (!worklist.empty()) {
947 auto item = worklist.pop_back_val();
948
949 for (auto input : item.op.getInputs()) {
950 uniqueInputs.remove(input);
951 enqueue(input, item.depth);
952 }
953 }
954
955 if (uniqueInputs.size() < inputs.size()) {
956 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
957 uniqueInputs.getArrayRef(),
958 op.getTwoState());
959 return true;
960 }
961
962 return false;
963}
964
965LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
967 return failure();
968
969 auto inputs = op.getInputs();
970 auto size = inputs.size();
971
972 // and(x, and(...)) -> and(x, ...) -- flatten
973 if (tryFlatteningOperands(op, rewriter))
974 return success();
975
976 // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
977 // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
978 // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
979 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
980 return success();
981
982 assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
983
984 // Patterns for and with a constant on RHS.
985 APInt value;
986 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
987 // and(..., '1) -> and(...) -- identity
988 if (value.isAllOnes()) {
989 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
990 inputs.drop_back(), false);
991 return success();
992 }
993
994 // TODO: Combine multiple constants together even if they aren't at the
995 // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
996 // folding
997 APInt value2;
998 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
999 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value & value2);
1000 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1001 newOperands.push_back(cst);
1002 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1003 newOperands, false);
1004 return success();
1005 }
1006
1007 // Handle 'and' with a single bit constant on the RHS.
1008 if (size == 2 && value.isPowerOf2()) {
1009 // If the LHS is a replicate from a single bit, we can 'concat' it
1010 // into place. e.g.:
1011 // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1012 // TODO: Generalize this for non-single-bit operands.
1013 if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1014 auto replicateOperand = replicate.getOperand();
1015 if (replicateOperand.getType().isInteger(1)) {
1016 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1017 auto trailingZeros = value.countTrailingZeros();
1018
1019 // Don't add zero bit constants unnecessarily.
1020 SmallVector<Value, 3> concatOperands;
1021 if (trailingZeros != resultWidth - 1) {
1022 auto highZeros = hw::ConstantOp::create(
1023 rewriter, op.getLoc(),
1024 APInt::getZero(resultWidth - trailingZeros - 1));
1025 concatOperands.push_back(highZeros);
1026 }
1027 concatOperands.push_back(replicateOperand);
1028 if (trailingZeros != 0) {
1029 auto lowZeros = hw::ConstantOp::create(
1030 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1031 concatOperands.push_back(lowZeros);
1032 }
1033 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1034 rewriter, op, op.getType(), concatOperands);
1035 return success();
1036 }
1037 }
1038 }
1039
1040 // Narrow the op if the constant has leading or trailing zeros.
1041 //
1042 // and(a, 0b00101100) -> concat(0b00, and(extract(a), 0b1011), 0b00)
1043 unsigned leadingZeros = value.countLeadingZeros();
1044 unsigned trailingZeros = value.countTrailingZeros();
1045 if (leadingZeros > 0 || trailingZeros > 0) {
1046 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1047
1048 // Extract the non-zero regions of the operands. Look through extracts.
1049 SmallVector<Value> operands;
1050 for (auto input : inputs.drop_back()) {
1051 unsigned offset = trailingZeros;
1052 while (auto extractOp = input.getDefiningOp<ExtractOp>()) {
1053 input = extractOp.getInput();
1054 offset += extractOp.getLowBit();
1055 }
1056 operands.push_back(ExtractOp::create(rewriter, op.getLoc(), input,
1057 offset, maskLength));
1058 }
1059
1060 // Add the narrowed mask if needed.
1061 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1062 if (!narrowMask.isAllOnes())
1063 operands.push_back(hw::ConstantOp::create(
1064 rewriter, inputs.back().getLoc(), narrowMask));
1065
1066 // Create the narrow and op.
1067 Value narrowValue = operands.back();
1068 if (operands.size() > 1)
1069 narrowValue =
1070 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1071 operands.clear();
1072
1073 // Concatenate the narrow and with the leading and trailing zeros.
1074 if (leadingZeros > 0)
1075 operands.push_back(hw::ConstantOp::create(
1076 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1077 operands.push_back(narrowValue);
1078 if (trailingZeros > 0)
1079 operands.push_back(hw::ConstantOp::create(
1080 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1081 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1082 return success();
1083 }
1084
1085 // and(concat(x, cst1), a, b, c, cst2)
1086 // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1087 // We do this for even more multi-use concats since they are "just wiring".
1088 for (size_t i = 0; i < size - 1; ++i) {
1089 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1090 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1091 return success();
1092 }
1093 }
1094
1095 // extracts only of and(...) -> and(extract()...)
1096 if (narrowOperationWidth(op, true, rewriter))
1097 return success();
1098
1099 // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1100 if (auto source = getCommonOperand(op)) {
1101 auto cmpAgainst =
1102 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getAllOnes(size));
1103 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1104 source, cmpAgainst);
1105 return success();
1106 }
1107
1108 /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1109 return failure();
1110}
1111
1112OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1113 if (isOpTriviallyRecursive(*this))
1114 return {};
1115
1116 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1117 auto inputs = adaptor.getInputs();
1118 // or(x, 10, 01) -> 11
1119 for (auto operand : inputs) {
1120 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1121 if (!attr)
1122 continue;
1123 value |= attr.getValue();
1124 if (value.isAllOnes())
1125 return getIntAttr(value, getContext());
1126 }
1127
1128 // or(x, 0) -> x
1129 if (inputs.size() == 2)
1130 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1131 if (intAttr.getValue().isZero())
1132 return getInputs()[0];
1133
1134 // or(x, x, x) -> x. This also handles or(x) -> x
1135 if (llvm::all_of(getInputs(),
1136 [&](auto in) { return in == this->getInputs()[0]; }))
1137 return getInputs()[0];
1138
1139 // or(..., x, ..., ~x, ...) -> -1
1140 for (Value arg : getInputs()) {
1141 Value subExpr;
1142 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1143 for (Value arg2 : getInputs())
1144 if (arg2 == subExpr)
1145 return getIntAttr(
1146 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1147 getContext());
1148 }
1149 }
1150
1151 // x0 = icmp(pred, x, y)
1152 // x1 = icmp(!pred, x, y)
1153 // or(x0, x1) -> 1
1154 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1155 return getIntAttr(
1156 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1157 getContext());
1158
1159 // Constant fold
1160 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1161}
1162
1163LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1164 if (isOpTriviallyRecursive(op))
1165 return failure();
1166
1167 auto inputs = op.getInputs();
1168 auto size = inputs.size();
1169
1170 // or(x, or(...)) -> or(x, ...) -- flatten
1171 if (tryFlatteningOperands(op, rewriter))
1172 return success();
1173
1174 // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1175 // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1176 // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1177 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1178 return success();
1179
1180 assert(size > 1 && "expected 2 or more operands");
1181
1182 // Patterns for and with a constant on RHS.
1183 APInt value;
1184 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1185 // or(..., '0) -> or(...) -- identity
1186 if (value.isZero()) {
1187 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1188 inputs.drop_back());
1189 return success();
1190 }
1191
1192 // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1193 APInt value2;
1194 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1195 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value | value2);
1196 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1197 newOperands.push_back(cst);
1198 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1199 newOperands);
1200 return success();
1201 }
1202
1203 // or(concat(x, cst1), a, b, c, cst2)
1204 // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1205 // We do this for even more multi-use concats since they are "just wiring".
1206 for (size_t i = 0; i < size - 1; ++i) {
1207 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1208 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1209 return success();
1210 }
1211 }
1212
1213 // extracts only of or(...) -> or(extract()...)
1214 if (narrowOperationWidth(op, true, rewriter))
1215 return success();
1216
1217 // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1218 if (auto source = getCommonOperand(op)) {
1219 auto cmpAgainst =
1220 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(size));
1221 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1222 source, cmpAgainst);
1223 return success();
1224 }
1225
1226 // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1227 // .., c_n), a, 0)
1228 if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1229 APInt value;
1230 if (op.getTwoState() && firstMux.getTwoState() &&
1231 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1232 value.isZero()) {
1233 SmallVector<Value> conditions{firstMux.getCond()};
1234 auto check = [&](Value v) {
1235 auto mux = v.getDefiningOp<comb::MuxOp>();
1236 if (!mux)
1237 return false;
1238 conditions.push_back(mux.getCond());
1239 return mux.getTwoState() &&
1240 firstMux.getTrueValue() == mux.getTrueValue() &&
1241 firstMux.getFalseValue() == mux.getFalseValue();
1242 };
1243 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1244 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions, true);
1245 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1246 rewriter, op, cond, firstMux.getTrueValue(),
1247 firstMux.getFalseValue(), true);
1248 return success();
1249 }
1250 }
1251 }
1252
1253 /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1254 return failure();
1255}
1256
1257OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1258 if (isOpTriviallyRecursive(*this))
1259 return {};
1260
1261 auto size = getInputs().size();
1262 auto inputs = adaptor.getInputs();
1263
1264 // xor(x) -> x -- noop
1265 if (size == 1)
1266 return getInputs()[0];
1267
1268 // xor(x, x) -> 0 -- idempotent
1269 if (size == 2 && getInputs()[0] == getInputs()[1])
1270 return IntegerAttr::get(getType(), 0);
1271
1272 // xor(x, 0) -> x
1273 if (inputs.size() == 2)
1274 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1275 if (intAttr.getValue().isZero())
1276 return getInputs()[0];
1277
1278 // xor(xor(x,1),1) -> x
1279 // but not self loop
1280 if (isBinaryNot()) {
1281 Value subExpr;
1282 if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1283 subExpr != getResult())
1284 return subExpr;
1285 }
1286
1287 // Constant fold
1288 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1289}
1290
1291// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1292static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1293 PatternRewriter &rewriter) {
1294 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1295 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1296
1297 Value result =
1298 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1299 icmp.getOperand(1), icmp.getTwoState());
1300
1301 // If the xor had other operands, rebuild it.
1302 if (op.getNumOperands() > 2) {
1303 SmallVector<Value, 4> newOperands(op.getOperands());
1304 newOperands.pop_back();
1305 newOperands.erase(newOperands.begin() + icmpOperand);
1306 newOperands.push_back(result);
1307 result =
1308 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1309 }
1310
1311 replaceOpAndCopyNamehint(rewriter, op, result);
1312}
1313
1314LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1315 if (isOpTriviallyRecursive(op))
1316 return failure();
1317
1318 auto inputs = op.getInputs();
1319 auto size = inputs.size();
1320 assert(size > 1 && "expected 2 or more operands");
1321
1322 // xor(..., x, x) -> xor (...) -- idempotent
1323 if (inputs[size - 1] == inputs[size - 2]) {
1324 assert(size > 2 &&
1325 "expected idempotent case for 2 elements handled already.");
1326 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1327 inputs.drop_back(/*n=*/2), false);
1328 return success();
1329 }
1330
1331 // Patterns for xor with a constant on RHS.
1332 APInt value;
1333 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1334 // xor(..., 0) -> xor(...) -- identity
1335 if (value.isZero()) {
1336 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1337 inputs.drop_back(), false);
1338 return success();
1339 }
1340
1341 // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1342 APInt value2;
1343 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1344 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value ^ value2);
1345 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1346 newOperands.push_back(cst);
1347 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1348 newOperands, false);
1349 return success();
1350 }
1351
1352 bool isSingleBit = value.getBitWidth() == 1;
1353
1354 // Check for subexpressions that we can simplify.
1355 for (size_t i = 0; i < size - 1; ++i) {
1356 Value operand = inputs[i];
1357
1358 // xor(concat(x, cst1), a, b, c, cst2)
1359 // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1360 // We do this for even more multi-use concats since they are "just
1361 // wiring".
1362 if (auto concat = operand.getDefiningOp<ConcatOp>())
1363 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1364 return success();
1365
1366 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1367 if (isSingleBit && operand.hasOneUse()) {
1368 assert(value == 1 && "single bit constant has to be one if not zero");
1369 if (auto icmp = operand.getDefiningOp<ICmpOp>())
1370 return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1371 }
1372 }
1373 }
1374
1375 // xor(x, xor(...)) -> xor(x, ...) -- flatten
1376 if (tryFlatteningOperands(op, rewriter))
1377 return success();
1378
1379 // extracts only of xor(...) -> xor(extract()...)
1380 if (narrowOperationWidth(op, true, rewriter))
1381 return success();
1382
1383 // xor(a[0], a[1], ..., a[n]) -> parity(a)
1384 if (auto source = getCommonOperand(op)) {
1385 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1386 return success();
1387 }
1388
1389 return failure();
1390}
1391
1392OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1393 if (isOpTriviallyRecursive(*this))
1394 return {};
1395
1396 // sub(x - x) -> 0
1397 if (getRhs() == getLhs())
1398 return getIntAttr(
1399 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1400 getContext());
1401
1402 if (adaptor.getRhs()) {
1403 // If both are constants, we can unconditionally fold.
1404 if (adaptor.getLhs()) {
1405 // Constant fold (c1 - c2) => (c1 + -1*c2).
1406 auto negOne = getIntAttr(
1407 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1408 getContext());
1409 auto rhsNeg = hw::ParamExprAttr::get(
1410 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1411 return hw::ParamExprAttr::get(hw::PEO::Add,
1412 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1413 }
1414
1415 // sub(x - 0) -> x
1416 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1417 if (rhsC.getValue().isZero())
1418 return getLhs();
1419 }
1420 }
1421
1422 return {};
1423}
1424
1425LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1426 if (isOpTriviallyRecursive(op))
1427 return failure();
1428
1429 // sub(x, cst) -> add(x, -cst)
1430 APInt value;
1431 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1432 auto negCst = hw::ConstantOp::create(rewriter, op.getLoc(), -value);
1433 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1434 false);
1435 return success();
1436 }
1437
1438 // extracts only of sub(...) -> sub(extract()...)
1439 if (narrowOperationWidth(op, false, rewriter))
1440 return success();
1441
1442 return failure();
1443}
1444
1445OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1446 if (isOpTriviallyRecursive(*this))
1447 return {};
1448
1449 auto size = getInputs().size();
1450
1451 // add(x) -> x -- noop
1452 if (size == 1u)
1453 return getInputs()[0];
1454
1455 // Constant fold constant operands.
1456 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1457}
1458
1459LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1460 if (isOpTriviallyRecursive(op))
1461 return failure();
1462
1463 auto inputs = op.getInputs();
1464 auto size = inputs.size();
1465 assert(size > 1 && "expected 2 or more operands");
1466
1467 APInt value, value2;
1468
1469 // add(..., 0) -> add(...) -- identity
1470 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1471 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1472 inputs.drop_back(), false);
1473 return success();
1474 }
1475
1476 // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1477 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1478 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1479 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1480 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1481 newOperands.push_back(cst);
1482 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1483 newOperands, false);
1484 return success();
1485 }
1486
1487 // add(..., x, x) -> add(..., shl(x, 1))
1488 if (inputs[size - 1] == inputs[size - 2]) {
1489 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1490
1491 auto one = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), 1);
1492 auto shiftLeftOp =
1493 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one, false);
1494
1495 newOperands.push_back(shiftLeftOp);
1496 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1497 newOperands, false);
1498 return success();
1499 }
1500
1501 auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1502 // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1503 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1504 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1505
1506 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1507 auto rhs =
1508 hw::ConstantOp::create(rewriter, op.getLoc(), (one << value) + one);
1509
1510 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1511 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1512
1513 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1514 newOperands.push_back(mulOp);
1515 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1516 newOperands, false);
1517 return success();
1518 }
1519
1520 auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1521 // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1522 if (mulOp && mulOp.getInputs().size() == 2 &&
1523 mulOp.getInputs()[0] == inputs[size - 2] &&
1524 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1525
1526 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1527 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + one);
1528 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1529 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1530
1531 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1532 newOperands.push_back(newMulOp);
1533 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1534 newOperands, false);
1535 return success();
1536 }
1537
1538 // add(a, add(...)) -> add(a, ...) -- flatten
1539 if (tryFlatteningOperands(op, rewriter))
1540 return success();
1541
1542 // extracts only of add(...) -> add(extract()...)
1543 if (narrowOperationWidth(op, false, rewriter))
1544 return success();
1545
1546 // add(add(x, c1), c2) -> add(x, c1 + c2)
1547 auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1548 if (addOp && addOp.getInputs().size() == 2 &&
1549 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1550 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1551
1552 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1553 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1554 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1555 /*twoState=*/op.getTwoState() && addOp.getTwoState());
1556 return success();
1557 }
1558
1559 return failure();
1560}
1561
1562OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1563 if (isOpTriviallyRecursive(*this))
1564 return {};
1565
1566 auto size = getInputs().size();
1567 auto inputs = adaptor.getInputs();
1568
1569 // mul(x) -> x -- noop
1570 if (size == 1u)
1571 return getInputs()[0];
1572
1573 auto width = cast<IntegerType>(getType()).getWidth();
1574 if (width == 0)
1575 return getIntAttr(APInt::getZero(0), getContext());
1576
1577 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1578
1579 // mul(x, 0, 1) -> 0 -- annulment
1580 for (auto operand : inputs) {
1581 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1582 if (!attr)
1583 continue;
1584 value *= attr.getValue();
1585 if (value.isZero())
1586 return getIntAttr(value, getContext());
1587 }
1588
1589 // Constant fold
1590 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1591}
1592
1593LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1594 if (isOpTriviallyRecursive(op))
1595 return failure();
1596
1597 auto inputs = op.getInputs();
1598 auto size = inputs.size();
1599 assert(size > 1 && "expected 2 or more operands");
1600
1601 APInt value, value2;
1602
1603 // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1604 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1605 value.isPowerOf2()) {
1606 auto shift = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(),
1607 value.exactLogBase2());
1608 auto shlOp =
1609 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift, false);
1610
1611 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1612 ArrayRef<Value>(shlOp), false);
1613 return success();
1614 }
1615
1616 // mul(..., 1) -> mul(...) -- identity
1617 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1618 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1619 inputs.drop_back());
1620 return success();
1621 }
1622
1623 // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1624 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1625 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1626 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value * value2);
1627 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1628 newOperands.push_back(cst);
1629 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1630 newOperands);
1631 return success();
1632 }
1633
1634 // mul(a, mul(...)) -> mul(a, ...) -- flatten
1635 if (tryFlatteningOperands(op, rewriter))
1636 return success();
1637
1638 // extracts only of mul(...) -> mul(extract()...)
1639 if (narrowOperationWidth(op, false, rewriter))
1640 return success();
1641
1642 return failure();
1643}
1644
1645template <class Op, bool isSigned>
1646static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1647 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1648 // divu(x, 1) -> x, divs(x, 1) -> x
1649 if (rhsValue.getValue() == 1)
1650 return op.getLhs();
1651
1652 // If the divisor is zero, do not fold for now.
1653 if (rhsValue.getValue().isZero())
1654 return {};
1655 }
1656
1657 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1658}
1659
1660OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1661 if (isOpTriviallyRecursive(*this))
1662 return {};
1663 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1664}
1665
1666OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1667 if (isOpTriviallyRecursive(*this))
1668 return {};
1669 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1670}
1671
1672template <class Op, bool isSigned>
1673static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1674 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1675 // modu(x, 1) -> 0, mods(x, 1) -> 0
1676 if (rhsValue.getValue() == 1)
1677 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1678 op.getContext());
1679
1680 // If the divisor is zero, do not fold for now.
1681 if (rhsValue.getValue().isZero())
1682 return {};
1683 }
1684
1685 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1686 // modu(0, x) -> 0, mods(0, x) -> 0
1687 if (lhsValue.getValue().isZero())
1688 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1689 op.getContext());
1690 }
1691
1692 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1693}
1694
1695OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1696 if (isOpTriviallyRecursive(*this))
1697 return {};
1698 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1699}
1700
1701OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1702 if (isOpTriviallyRecursive(*this))
1703 return {};
1704 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1705}
1706
1707LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) {
1708 if (isOpTriviallyRecursive(op) || !op.getTwoState())
1709 return failure();
1710 return convertDivUByPowerOfTwo(op, rewriter);
1711}
1712
1713LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) {
1714 if (isOpTriviallyRecursive(op) || !op.getTwoState())
1715 return failure();
1716
1717 return convertModUByPowerOfTwo(op, rewriter);
1718}
1719
1720//===----------------------------------------------------------------------===//
1721// ConcatOp
1722//===----------------------------------------------------------------------===//
1723
1724// Constant folding
1725OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1726 if (isOpTriviallyRecursive(*this))
1727 return {};
1728
1729 if (getNumOperands() == 1)
1730 return getOperand(0);
1731
1732 // If all the operands are constant, we can fold.
1733 for (auto attr : adaptor.getInputs())
1734 if (!attr || !isa<IntegerAttr>(attr))
1735 return {};
1736
1737 // If we got here, we can constant fold.
1738 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1739 APInt result(resultWidth, 0);
1740
1741 unsigned nextInsertion = resultWidth;
1742 // Insert each chunk into the result.
1743 for (auto attr : adaptor.getInputs()) {
1744 auto chunk = cast<IntegerAttr>(attr).getValue();
1745 nextInsertion -= chunk.getBitWidth();
1746 result.insertBits(chunk, nextInsertion);
1747 }
1748
1749 return getIntAttr(result, getContext());
1750}
1751
1752LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1753 if (isOpTriviallyRecursive(op))
1754 return failure();
1755
1756 auto inputs = op.getInputs();
1757 auto size = inputs.size();
1758 assert(size > 1 && "expected 2 or more operands");
1759
1760 // This function is used when we flatten neighboring operands of a
1761 // (variadic) concat into a new vesion of the concat. first/last indices
1762 // are inclusive.
1763 auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1764 ValueRange replacements) -> LogicalResult {
1765 SmallVector<Value, 4> newOperands;
1766 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1767 newOperands.append(replacements.begin(), replacements.end());
1768 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1769 if (newOperands.size() == 1)
1770 replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
1771 else
1772 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1773 newOperands);
1774 return success();
1775 };
1776
1777 Value commonOperand = inputs[0];
1778 for (size_t i = 0; i != size; ++i) {
1779 // Check to see if all operands are the same.
1780 if (inputs[i] != commonOperand)
1781 commonOperand = Value();
1782
1783 // If an operand to the concat is itself a concat, then we can fold them
1784 // together.
1785 if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1786 return flattenConcat(i, i, subConcat->getOperands());
1787
1788 // Check for canonicalization due to neighboring operands.
1789 if (i != 0) {
1790 // Merge neighboring constants.
1791 if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1792 if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1793 unsigned prevWidth = prevCst.getValue().getBitWidth();
1794 unsigned thisWidth = cst.getValue().getBitWidth();
1795 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1796 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1797 << thisWidth;
1798 Value replacement =
1799 hw::ConstantOp::create(rewriter, op.getLoc(), resultCst);
1800 return flattenConcat(i - 1, i, replacement);
1801 }
1802 }
1803
1804 // If the two operands are the same, turn them into a replicate.
1805 if (inputs[i] == inputs[i - 1]) {
1806 Value replacement =
1807 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1808 return flattenConcat(i - 1, i, replacement);
1809 }
1810
1811 // If this input is a replicate, see if we can fold it with the previous
1812 // one.
1813 if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1814 // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1815 if (repl.getOperand() == inputs[i - 1]) {
1816 Value replacement = rewriter.createOrFold<ReplicateOp>(
1817 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1818 return flattenConcat(i - 1, i, replacement);
1819 }
1820 // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1821 if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1822 if (prevRepl.getOperand() == repl.getOperand()) {
1823 Value replacement = rewriter.createOrFold<ReplicateOp>(
1824 op.getLoc(), repl.getOperand(),
1825 repl.getMultiple() + prevRepl.getMultiple());
1826 return flattenConcat(i - 1, i, replacement);
1827 }
1828 }
1829 }
1830
1831 // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1832 if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1833 if (repl.getOperand() == inputs[i]) {
1834 Value replacement = rewriter.createOrFold<ReplicateOp>(
1835 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1836 return flattenConcat(i - 1, i, replacement);
1837 }
1838 }
1839
1840 // Merge neighboring extracts of neighboring inputs, e.g.
1841 // {A[3], A[2]} -> A[3:2]
1842 if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1843 if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1844 if (extract.getInput() == prevExtract.getInput()) {
1845 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1846 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1847 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1848 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1849 Value replacement =
1850 ExtractOp::create(rewriter, op.getLoc(), resType,
1851 extract.getInput(), extract.getLowBit());
1852 return flattenConcat(i - 1, i, replacement);
1853 }
1854 }
1855 }
1856 }
1857 // Merge neighboring array extracts of neighboring inputs, e.g.
1858 // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
1859
1860 // This represents a slice of an array.
1861 struct ArraySlice {
1862 Value input;
1863 Value index;
1864 size_t width;
1865 static std::optional<ArraySlice> get(Value value) {
1866 assert(isa<IntegerType>(value.getType()) && "expected integer type");
1867 if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1868 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1869 // array slice op is wrapped with bitcast.
1870 if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1871 if (auto arraySlice =
1872 bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1873 return ArraySlice{
1874 arraySlice.getInput(), arraySlice.getLowIndex(),
1875 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1876 .getNumElements()};
1877 return std::nullopt;
1878 }
1879 };
1880 if (auto extractOpt = ArraySlice::get(inputs[i])) {
1881 if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1882 // Check that two array slices are mergable.
1883 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1884 prevExtractOpt->input == extractOpt->input &&
1885 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1886 extractOpt->width)) {
1887 auto resType = hw::ArrayType::get(
1888 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1889 .getElementType(),
1890 extractOpt->width + prevExtractOpt->width);
1891 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1892 Value replacement = hw::BitcastOp::create(
1893 rewriter, op.getLoc(), resIntType,
1894 hw::ArraySliceOp::create(rewriter, op.getLoc(), resType,
1895 prevExtractOpt->input,
1896 extractOpt->index));
1897 return flattenConcat(i - 1, i, replacement);
1898 }
1899 }
1900 }
1901 }
1902 }
1903
1904 // If all operands were the same, then this is a replicate.
1905 if (commonOperand) {
1906 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1907 commonOperand);
1908 return success();
1909 }
1910
1911 return failure();
1912}
1913
1914//===----------------------------------------------------------------------===//
1915// MuxOp
1916//===----------------------------------------------------------------------===//
1917
1918OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1919 if (isOpTriviallyRecursive(*this))
1920 return {};
1921
1922 // mux (c, b, b) -> b
1923 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1924 return getTrueValue();
1925 if (auto tv = adaptor.getTrueValue())
1926 if (tv == adaptor.getFalseValue())
1927 return tv;
1928
1929 // mux(0, a, b) -> b
1930 // mux(1, a, b) -> a
1931 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1932 if (pred.getValue().isZero() && getFalseValue() != getResult())
1933 return getFalseValue();
1934 if (pred.getValue().isOne() && getTrueValue() != getResult())
1935 return getTrueValue();
1936 }
1937
1938 // mux(cond, 1, 0) -> cond
1939 if (getCond().getType() == getTrueValue().getType())
1940 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1941 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1942 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1943 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1944 return getCond();
1945
1946 return {};
1947}
1948
1949/// Check to see if the condition to the specified mux is an equality
1950/// comparison `indexValue` and one or more constants. If so, put the
1951/// constants in the constants vector and return true, otherwise return false.
1952///
1953/// This is part of foldMuxChain.
1954///
1955static bool
1956getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1957 std::function<void(hw::ConstantOp)> constantFn) {
1958 // Handle `idx == 42` and `idx != 42`.
1959 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1960 // TODO: We could handle things like "x < 2" as two entries.
1961 auto requiredPredicate =
1962 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1963 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1964 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1965 constantFn(cst);
1966 return true;
1967 }
1968 }
1969 return false;
1970 }
1971
1972 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1973 if (auto orOp = cond.getDefiningOp<OrOp>()) {
1974 if (!isInverted)
1975 return false;
1976 for (auto operand : orOp.getOperands())
1977 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1978 return false;
1979 return true;
1980 }
1981
1982 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1983 if (auto andOp = cond.getDefiningOp<AndOp>()) {
1984 if (isInverted)
1985 return false;
1986 for (auto operand : andOp.getOperands())
1987 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1988 return false;
1989 return true;
1990 }
1991
1992 return false;
1993}
1994
1995/// Given a mux, check to see if the "on true" value (or "on false" value if
1996/// isFalseSide=true) is a mux tree with the same condition. This allows us
1997/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
1998/// `array_get (array_create(A, B, C), VAL)` or a balanced mux tree which is far
1999/// more compact and allows synthesis tools to do more interesting
2000/// optimizations.
2001///
2002/// This returns false if we cannot form the mux tree (or do not want to) and
2003/// returns true if the mux was replaced.
2005 PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide,
2006 llvm::function_ref<MuxChainWithComparisonFoldingStyle(size_t indexWidth,
2007 size_t numEntries)>
2008 styleFn) {
2009 // Get the index value being compared. Later we check to see if it is
2010 // compared to a constant with the right predicate.
2011 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2012 if (!rootCmp)
2013 return false;
2014 Value indexValue = rootCmp.getLhs();
2015
2016 // Return the value to use if the equality match succeeds.
2017 auto getCaseValue = [&](MuxOp mux) -> Value {
2018 return mux.getOperand(1 + unsigned(!isFalseSide));
2019 };
2020
2021 // Return the value to use if the equality match fails. This is the next
2022 // mux in the sequence or the "otherwise" value.
2023 auto getTreeValue = [&](MuxOp mux) -> Value {
2024 return mux.getOperand(1 + unsigned(isFalseSide));
2025 };
2026
2027 // Start scanning the mux tree to see what we've got. Keep track of the
2028 // constant comparison value and the SSA value to use when equal to it.
2029 SmallVector<Location> locationsFound;
2030 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2031
2032 /// Extract constants and values into `valuesFound` and return true if this is
2033 /// part of the mux tree, otherwise return false.
2034 auto collectConstantValues = [&](MuxOp mux) -> bool {
2036 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2037 valuesFound.push_back({cst, getCaseValue(mux)});
2038 locationsFound.push_back(mux.getCond().getLoc());
2039 locationsFound.push_back(mux->getLoc());
2040 });
2041 };
2042
2043 // Make sure the root is a correct comparison with a constant.
2044 if (!collectConstantValues(rootMux))
2045 return false;
2046
2047 // Make sure that we're not looking at the intermediate node in a mux tree.
2048 if (rootMux->hasOneUse()) {
2049 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2050 if (getTreeValue(userMux) == rootMux.getResult() &&
2051 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2052 [&](hw::ConstantOp cst) {}))
2053 return false;
2054 }
2055 }
2056
2057 // Scan up the tree linearly.
2058 auto nextTreeValue = getTreeValue(rootMux);
2059 while (1) {
2060 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2061 if (!nextMux || !nextMux->hasOneUse())
2062 break;
2063 if (!collectConstantValues(nextMux))
2064 break;
2065 nextTreeValue = getTreeValue(nextMux);
2066 }
2067
2068 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2069
2070 if (indexWidth > 20)
2071 return false; // Too big to make a table.
2072
2073 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2074 if (foldingStyle == MuxChainWithComparisonFoldingStyle::None)
2075 return false;
2076
2077 uint64_t tableSize = 1ULL << indexWidth;
2078
2079 // Ok, we're going to do the transformation, start by building the table
2080 // filled with the "otherwise" value.
2081 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2082
2083 // Fill in entries in the table from the leaf to the root of the expression.
2084 // This ensures that any duplicate matches end up with the ultimate value,
2085 // which is the one closer to the root.
2086 for (auto &elt : llvm::reverse(valuesFound)) {
2087 uint64_t idx = elt.first.getValue().getZExtValue();
2088 assert(idx < table.size() && "constant should be same bitwidth as index");
2089 table[idx] = elt.second;
2090 }
2091
2093 SmallVector<Value> bits;
2094 comb::extractBits(rewriter, indexValue, bits);
2095 auto result = constructMuxTree(rewriter, rootMux->getLoc(), bits, table,
2096 nextTreeValue);
2097 replaceOpAndCopyNamehint(rewriter, rootMux, result);
2098 return true;
2099 }
2100
2102 "unknown folding style");
2103
2104 // The hw.array_create operation has the operand list in unintuitive order
2105 // with a[0] stored as the last element, not the first.
2106 std::reverse(table.begin(), table.end());
2107
2108 // Build the array_create and the array_get.
2109 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2110 auto array = hw::ArrayCreateOp::create(rewriter, fusedLoc, table);
2111 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2112 indexValue);
2113 return true;
2114}
2115
2116/// Given a fully associative variadic operation like (a+b+c+d), break the
2117/// expression into two parts, one without the specified operand (e.g.
2118/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2119/// `tmp+c`), and return the inner expression.
2120///
2121/// NOTE: This mutates the operation in place if it only has a single user,
2122/// which assumes that user will be removed.
2123///
2124static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2125 size_t operandNo,
2126 PatternRewriter &rewriter) {
2127 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2128 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2129
2130 // If this expression already has two operands (the common case) no splitting
2131 // is necessary.
2132 if (fullyAssoc->getNumOperands() == 2)
2133 return fullyAssoc->getOperand(operandNo ^ 1);
2134
2135 // If the operation has a single use, mutate it in place.
2136 if (fullyAssoc->hasOneUse()) {
2137 rewriter.modifyOpInPlace(fullyAssoc,
2138 [&]() { fullyAssoc->eraseOperand(operandNo); });
2139 return fullyAssoc->getResult(0);
2140 }
2141
2142 // Form the new operation with the operands that remain.
2143 SmallVector<Value> operands;
2144 operands.append(fullyAssoc->getOperands().begin(),
2145 fullyAssoc->getOperands().begin() + operandNo);
2146 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2147 fullyAssoc->getOperands().end());
2148 Value opWithoutExcluded = createGenericOp(
2149 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2150 Value excluded = fullyAssoc->getOperand(operandNo);
2151
2152 Value fullResult =
2153 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2154 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2155 replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
2156 return opWithoutExcluded;
2157}
2158
2159/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2160/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2161/// is true. Return true on successful transformation, false if not.
2162///
2163/// These are various forms of "predicated ops" that can be handled with a
2164/// replicate/and combination.
2165static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2166 PatternRewriter &rewriter) {
2167 // Check to see the operand in question is an operation. If it is a port,
2168 // we can't simplify it.
2169 Operation *subExpr =
2170 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2171 if (!subExpr || subExpr->getNumOperands() < 2)
2172 return false;
2173
2174 // If this isn't an operation we can handle, don't spend energy on it.
2175 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2176 return false;
2177
2178 // Check to see if the common value occurs in the operand list for the
2179 // subexpression op. If so, then we can simplify it.
2180 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2181 size_t opNo = 0, e = subExpr->getNumOperands();
2182 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2183 ++opNo;
2184 if (opNo == e)
2185 return false;
2186
2187 // If we got a hit, then go ahead and simplify it!
2188 Value cond = op.getCond();
2189
2190 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2191 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2192 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2193 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2194 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2195 if (subMux == op)
2196 return false;
2197
2198 Value otherValue;
2199 Value subCond = subMux.getCond();
2200
2201 // Invert th subCond if needed and dig out the 'b' value.
2202 if (subMux.getTrueValue() == commonValue)
2203 otherValue = subMux.getFalseValue();
2204 else if (subMux.getFalseValue() == commonValue) {
2205 otherValue = subMux.getTrueValue();
2206 subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2207 } else {
2208 // We can't fold `mux(cond, a, mux(a, x, y))`.
2209 return false;
2210 }
2211
2212 // Invert the outer cond if needed, and combine the mux conditions.
2213 if (!isTrueOperand)
2214 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2215 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2216 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2217 otherValue, op.getTwoState());
2218 return true;
2219 }
2220
2221 // Invert the condition if needed. Or/Xor invert when dealing with
2222 // TrueOperand, And inverts for False operand.
2223 bool isaAndOp = isa<AndOp>(subExpr);
2224 if (isTrueOperand ^ isaAndOp)
2225 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2226
2227 auto extendedCond =
2228 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2229
2230 // Cache this information before subExpr is erased by extraction below.
2231 bool isaXorOp = isa<XorOp>(subExpr);
2232 bool isaOrOp = isa<OrOp>(subExpr);
2233
2234 // Handle the fully associative ops, start by pulling out the subexpression
2235 // from a many operand version of the op.
2236 auto restOfAssoc =
2237 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2238
2239 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2240 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2241 if (isaOrOp || isaXorOp) {
2242 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2243 restOfAssoc, false);
2244 if (isaXorOp)
2245 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2246 commonValue, false);
2247 else
2248 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2249 false);
2250 return true;
2251 }
2252
2253 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2254 assert(isaAndOp && "unexpected operation here");
2255 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2256 restOfAssoc, false);
2257 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2258 false);
2259 return true;
2260}
2261
2262/// This function is invoke when we find a mux with true/false operations that
2263/// have the same opcode. Check to see if we can strength reduce the mux by
2264/// applying it to less data by applying this transformation:
2265/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2266static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2267 Operation *falseOp,
2268 PatternRewriter &rewriter) {
2269 // Right now we only apply to concat.
2270 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2271 if (!isa<ConcatOp>(trueOp))
2272 return false;
2273
2274 // Decode the operands, looking through recursive concats and replicates.
2275 SmallVector<Value> trueOperands, falseOperands;
2276 getConcatOperands(trueOp->getResult(0), trueOperands);
2277 getConcatOperands(falseOp->getResult(0), falseOperands);
2278
2279 size_t numTrueOperands = trueOperands.size();
2280 size_t numFalseOperands = falseOperands.size();
2281
2282 if (!numTrueOperands || !numFalseOperands ||
2283 (trueOperands.front() != falseOperands.front() &&
2284 trueOperands.back() != falseOperands.back()))
2285 return false;
2286
2287 // Pull all leading shared operands out into their own op if any are common.
2288 if (trueOperands.front() == falseOperands.front()) {
2289 SmallVector<Value> operands;
2290 size_t i;
2291 for (i = 0; i < numTrueOperands; ++i) {
2292 Value trueOperand = trueOperands[i];
2293 if (trueOperand == falseOperands[i])
2294 operands.push_back(trueOperand);
2295 else
2296 break;
2297 }
2298 if (i == numTrueOperands) {
2299 // Selecting between distinct, but lexically identical, concats.
2300 replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
2301 return true;
2302 }
2303
2304 Value sharedMSB;
2305 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2306 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2307 mux->getLoc(), operands.front(), operands.size());
2308 else
2309 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2310 operands.clear();
2311
2312 // Get a concat of the LSB's on each side.
2313 operands.append(trueOperands.begin() + i, trueOperands.end());
2314 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2315 operands.clear();
2316 operands.append(falseOperands.begin() + i, falseOperands.end());
2317 Value falseLSB =
2318 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2319 // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2320 // done.
2321 Value lsb = rewriter.createOrFold<MuxOp>(
2322 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2323 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2324 return true;
2325 }
2326
2327 // If trailing operands match, try to commonize them.
2328 if (trueOperands.back() == falseOperands.back()) {
2329 SmallVector<Value> operands;
2330 size_t i;
2331 for (i = 0;; ++i) {
2332 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2333 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2334 operands.push_back(trueOperand);
2335 else
2336 break;
2337 }
2338 std::reverse(operands.begin(), operands.end());
2339 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2340 operands.clear();
2341
2342 // Get a concat of the MSB's on each side.
2343 operands.append(trueOperands.begin(), trueOperands.end() - i);
2344 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2345 operands.clear();
2346 operands.append(falseOperands.begin(), falseOperands.end() - i);
2347 Value falseMSB =
2348 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2349 // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2350 Value msb = rewriter.createOrFold<MuxOp>(
2351 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2352 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2353 return true;
2354 }
2355
2356 return false;
2357}
2358
2359// If both arguments of the mux are arrays with the same elements, sink the
2360// mux and return a uniform array initializing all elements to it.
2361static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2362 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2363 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2364 if (!trueVec || !falseVec)
2365 return false;
2366 if (!trueVec.isUniform() || !falseVec.isUniform())
2367 return false;
2368
2369 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2370 trueVec.getUniformElement(),
2371 falseVec.getUniformElement(), op.getTwoState());
2372
2373 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2374 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2375 return true;
2376}
2377
2378/// If the mux condition is an operand to the op defining its true or false
2379/// value, replace the condition with 1 or 0.
2380static bool assumeMuxCondInOperand(Value muxCond, Value muxValue,
2381 bool constCond, PatternRewriter &rewriter) {
2382 if (!muxValue.hasOneUse())
2383 return false;
2384 auto *op = muxValue.getDefiningOp();
2385 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2386 return false;
2387 if (!llvm::is_contained(op->getOperands(), muxCond))
2388 return false;
2389 OpBuilder::InsertionGuard guard(rewriter);
2390 rewriter.setInsertionPoint(op);
2391 auto condValue =
2392 hw::ConstantOp::create(rewriter, muxCond.getLoc(), APInt(1, constCond));
2393 rewriter.modifyOpInPlace(op, [&] {
2394 for (auto &use : op->getOpOperands())
2395 if (use.get() == muxCond)
2396 use.set(condValue);
2397 });
2398 return true;
2399}
2400
2401namespace {
2402struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2403 using OpRewritePattern::OpRewritePattern;
2404
2405 LogicalResult matchAndRewrite(MuxOp op,
2406 PatternRewriter &rewriter) const override;
2407};
2408
2410foldToArrayCreateOnlyWhenDense(size_t indexWidth, size_t numEntries) {
2411 // If the array is greater that 9 bits, it will take over 512 elements and
2412 // it will be too large for a single expression.
2413 if (indexWidth >= 9 || numEntries < 3)
2415
2416 // Next we need to see if the values are dense-ish. We don't want to have
2417 // a tremendous number of replicated entries in the array. Some sparsity is
2418 // ok though, so we require the table to be at least 5/8 utilized.
2419 uint64_t tableSize = 1ULL << indexWidth;
2420 if (numEntries >= tableSize * 5 / 8)
2423}
2424
2425LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2426 PatternRewriter &rewriter) const {
2427 if (isOpTriviallyRecursive(op))
2428 return failure();
2429
2430 bool isSignlessInt = false;
2431 if (auto intType = dyn_cast<IntegerType>(op.getType()))
2432 isSignlessInt = intType.isSignless();
2433
2434 // If the op has a SV attribute, don't optimize it.
2435 if (hasSVAttributes(op))
2436 return failure();
2437 APInt value;
2438
2439 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2440 if (value.getBitWidth() == 1) {
2441 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2442 if (value.isZero()) {
2443 auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2444 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2445 op.getFalseValue(), false);
2446 return success();
2447 }
2448
2449 // mux(a, 1, b) -> or(a, b) for single-bit values.
2450 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2451 op.getFalseValue(), false);
2452 return success();
2453 }
2454
2455 // Check for mux of two constants. There are many ways to simplify them.
2456 APInt value2;
2457 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2458 // When both inputs are constants and differ by only one bit, we can
2459 // simplify by splitting the mux into up to three contiguous chunks: one
2460 // for the differing bit and up to two for the bits that are the same.
2461 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2462 APInt xorValue = value ^ value2;
2463 if (xorValue.isPowerOf2()) {
2464 unsigned leadingZeros = xorValue.countLeadingZeros();
2465 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2466 SmallVector<Value, 3> operands;
2467
2468 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2469 // order of bit indexes.
2470 // For the chunks that are identical (i.e. correspond to 0s in
2471 // xorValue), we can extract directly from either input value, and we
2472 // arbitrarily pick the trueValue().
2473
2474 if (leadingZeros > 0)
2475 operands.push_back(rewriter.createOrFold<ExtractOp>(
2476 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2477
2478 // Handle the differing bit, which should simplify into either cond or
2479 // ~cond.
2480 auto v1 = rewriter.createOrFold<ExtractOp>(
2481 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2482 auto v2 = rewriter.createOrFold<ExtractOp>(
2483 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2484 operands.push_back(rewriter.createOrFold<MuxOp>(
2485 op.getLoc(), op.getCond(), v1, v2, false));
2486
2487 if (trailingZeros > 0)
2488 operands.push_back(rewriter.createOrFold<ExtractOp>(
2489 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2490
2491 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2492 operands);
2493 return success();
2494 }
2495
2496 // If the true value is all ones and the false is all zeros then we have a
2497 // replicate pattern.
2498 if (value.isAllOnes() && value2.isZero()) {
2499 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2500 rewriter, op, op.getType(), op.getCond());
2501 return success();
2502 }
2503 }
2504 }
2505
2506 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2507 isSignlessInt && value.getBitWidth() == 1) {
2508 // mux(a, b, 0) -> and(a, b) for single-bit values.
2509 if (value.isZero()) {
2510 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2511 op.getTrueValue(), false);
2512 return success();
2513 }
2514
2515 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2516 // falseValue() is known to be a single-bit 1, which we can use for
2517 // the 1 in the representation of ~ using xor.
2518 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2519 op.getFalseValue(), false);
2520 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2521 op.getTrueValue(), false);
2522 return success();
2523 }
2524
2525 // mux(!a, b, c) -> mux(a, c, b)
2526 Value subExpr;
2527 Operation *condOp = op.getCond().getDefiningOp();
2528 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2529 op.getTwoState()) {
2530 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2531 subExpr, op.getFalseValue(),
2532 op.getTrueValue(), true);
2533 return success();
2534 }
2535
2536 // Same but with Demorgan's law.
2537 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2538 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2539 if (condOp && condOp->hasOneUse()) {
2540 SmallVector<Value> invertedOperands;
2541
2542 /// Scan all the operands to see if they are complemented. If so, build a
2543 /// vector of them and return true, otherwise return false.
2544 auto getInvertedOperands = [&]() -> bool {
2545 for (Value operand : condOp->getOperands()) {
2546 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2547 invertedOperands.push_back(subExpr);
2548 else
2549 return false;
2550 }
2551 return true;
2552 };
2553
2554 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2555 auto newOr =
2556 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2557 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2558 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2559 op.getTwoState());
2560 return success();
2561 }
2562 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2563 auto newAnd =
2564 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2565 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2566 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2567 op.getTwoState());
2568 return success();
2569 }
2570 }
2571
2572 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2573 falseMux && falseMux != op) {
2574 // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2575 if (op.getCond() == falseMux.getCond() &&
2576 falseMux.getFalseValue() != falseMux) {
2577 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2578 rewriter, op, op.getCond(), op.getTrueValue(),
2579 falseMux.getFalseValue(), op.getTwoStateAttr());
2580 return success();
2581 }
2582
2583 // Check to see if we can fold a mux tree into an array_create/get pair.
2584 if (foldMuxChainWithComparison(rewriter, op, /*isFalse*/ true,
2585 foldToArrayCreateOnlyWhenDense))
2586 return success();
2587 }
2588
2589 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2590 trueMux && trueMux != op) {
2591 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2592 if (op.getCond() == trueMux.getCond()) {
2593 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2594 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2595 op.getFalseValue(), op.getTwoStateAttr());
2596 return success();
2597 }
2598
2599 // Check to see if we can fold a mux tree into an array_create/get pair.
2600 if (foldMuxChainWithComparison(rewriter, op, /*isFalseSide*/ false,
2601 foldToArrayCreateOnlyWhenDense))
2602 return success();
2603 }
2604
2605 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2606 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2607 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2608 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2609 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2610 falseMux != op) {
2611 auto subMux = MuxOp::create(
2612 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2613 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2614 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2615 trueMux.getTrueValue(), subMux,
2616 op.getTwoStateAttr());
2617 return success();
2618 }
2619
2620 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2621 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2622 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2623 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2624 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2625 falseMux != op) {
2626 auto subMux = MuxOp::create(
2627 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2628 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2629 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2630 subMux, trueMux.getFalseValue(),
2631 op.getTwoStateAttr());
2632 return success();
2633 }
2634
2635 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2636 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2637 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2638 trueMux && falseMux &&
2639 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2640 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2641 falseMux != op) {
2642 auto subMux =
2643 MuxOp::create(rewriter,
2644 rewriter.getFusedLoc(
2645 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2646 op.getCond(), trueMux.getCond(), falseMux.getCond());
2647 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2648 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2649 op.getTwoStateAttr());
2650 return success();
2651 }
2652
2653 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2654 if (foldCommonMuxValue(op, false, rewriter))
2655 return success();
2656 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2657 if (foldCommonMuxValue(op, true, rewriter))
2658 return success();
2659
2660 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2661 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2662 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2663 if (trueOp->getName() == falseOp->getName())
2664 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2665 return success();
2666
2667 // extracts only of mux(...) -> mux(extract()...)
2668 if (narrowOperationWidth(op, true, rewriter))
2669 return success();
2670
2671 // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2672 if (foldMuxOfUniformArrays(op, rewriter))
2673 return success();
2674
2675 // mux(cond, opA(cond), opB(cond)) -> mux(cond, opA(1), opB(0))
2676 if (op.getTrueValue().getDefiningOp() &&
2677 op.getTrueValue().getDefiningOp() != op)
2678 if (assumeMuxCondInOperand(op.getCond(), op.getTrueValue(), true, rewriter))
2679 return success();
2680 if (op.getFalseValue().getDefiningOp() &&
2681 op.getFalseValue().getDefiningOp() != op)
2682
2683 if (assumeMuxCondInOperand(op.getCond(), op.getFalseValue(), false,
2684 rewriter))
2685 return success();
2686
2687 return failure();
2688}
2689
2690static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2691 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2692 if (op.getInputs().empty() || op.isUniform())
2693 return false;
2694 auto inputs = op.getInputs();
2695 if (inputs.size() <= 1)
2696 return false;
2697
2698 // Check the operands to the array create. Ensure all of them are the
2699 // same op with the same number of operands.
2700 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2701 if (!first || hasSVAttributes(first))
2702 return false;
2703
2704 // Check whether all operands are muxes with the same condition.
2705 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2706 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2707 if (!input || first.getCond() != input.getCond())
2708 return false;
2709 }
2710
2711 // Collect the true and the false branches into arrays.
2712 SmallVector<Value> trues{first.getTrueValue()};
2713 SmallVector<Value> falses{first.getFalseValue()};
2714 SmallVector<Location> locs{first->getLoc()};
2715 bool isTwoState = true;
2716 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2717 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2718 trues.push_back(input.getTrueValue());
2719 falses.push_back(input.getFalseValue());
2720 locs.push_back(input->getLoc());
2721 if (!input.getTwoState())
2722 isTwoState = false;
2723 }
2724
2725 // Define the location of the array create as the aggregate of all muxes.
2726 auto loc = FusedLoc::get(op.getContext(), locs);
2727
2728 // Replace the create with an aggregate operation. Push the create op
2729 // into the operands of the aggregate operation.
2730 auto arrayTy = op.getType();
2731 auto trueValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, trues);
2732 auto falseValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, falses);
2733 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2734 trueValues, falseValues, isTwoState);
2735 return true;
2736}
2737
2738struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2739 using OpRewritePattern::OpRewritePattern;
2740
2741 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2742 PatternRewriter &rewriter) const override {
2743 if (foldArrayOfMuxes(op, rewriter))
2744 return success();
2745 return failure();
2746 }
2747};
2748
2749} // namespace
2750
2751void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2752 MLIRContext *context) {
2753 results.insert<MuxRewriter, ArrayRewriter>(context);
2754}
2755
2756//===----------------------------------------------------------------------===//
2757// ICmpOp
2758//===----------------------------------------------------------------------===//
2759
2760// Calculate the result of a comparison when the LHS and RHS are both
2761// constants.
2762static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2763 const APInt &rhs) {
2764 switch (predicate) {
2765 case ICmpPredicate::eq:
2766 return lhs.eq(rhs);
2767 case ICmpPredicate::ne:
2768 return lhs.ne(rhs);
2769 case ICmpPredicate::slt:
2770 return lhs.slt(rhs);
2771 case ICmpPredicate::sle:
2772 return lhs.sle(rhs);
2773 case ICmpPredicate::sgt:
2774 return lhs.sgt(rhs);
2775 case ICmpPredicate::sge:
2776 return lhs.sge(rhs);
2777 case ICmpPredicate::ult:
2778 return lhs.ult(rhs);
2779 case ICmpPredicate::ule:
2780 return lhs.ule(rhs);
2781 case ICmpPredicate::ugt:
2782 return lhs.ugt(rhs);
2783 case ICmpPredicate::uge:
2784 return lhs.uge(rhs);
2785 case ICmpPredicate::ceq:
2786 return lhs.eq(rhs);
2787 case ICmpPredicate::cne:
2788 return lhs.ne(rhs);
2789 case ICmpPredicate::weq:
2790 return lhs.eq(rhs);
2791 case ICmpPredicate::wne:
2792 return lhs.ne(rhs);
2793 }
2794 llvm_unreachable("unknown comparison predicate");
2795}
2796
2797// Returns the result of applying the predicate when the LHS and RHS are the
2798// exact same value.
2799static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2800 switch (predicate) {
2801 case ICmpPredicate::eq:
2802 case ICmpPredicate::sle:
2803 case ICmpPredicate::sge:
2804 case ICmpPredicate::ule:
2805 case ICmpPredicate::uge:
2806 case ICmpPredicate::ceq:
2807 case ICmpPredicate::weq:
2808 return true;
2809 case ICmpPredicate::ne:
2810 case ICmpPredicate::slt:
2811 case ICmpPredicate::sgt:
2812 case ICmpPredicate::ult:
2813 case ICmpPredicate::ugt:
2814 case ICmpPredicate::cne:
2815 case ICmpPredicate::wne:
2816 return false;
2817 }
2818 llvm_unreachable("unknown comparison predicate");
2819}
2820
2821OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2822 // gt a, a -> false
2823 // gte a, a -> true
2824 if (getLhs() == getRhs()) {
2825 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2826 return IntegerAttr::get(getType(), val);
2827 }
2828
2829 // gt 1, 2 -> false
2830 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2831 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2832 auto val =
2833 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2834 return IntegerAttr::get(getType(), val);
2835 }
2836 }
2837 return {};
2838}
2839
/// Return how many leading elements `a` and `b` have in common. Elements are
/// compared position by position (no cross-element matching), stopping at the
/// first mismatch or at the end of the shorter range. Passing reversed ranges
/// yields the length of the common suffix instead.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto aIt = a.begin();
  auto bIt = b.begin();
  size_t matched = 0;
  while (aIt != a.end() && bIt != b.end() && *aIt == *bIt) {
    ++aIt;
    ++bIt;
    ++matched;
  }
  return matched;
}
2856
2857static size_t getTotalWidth(ArrayRef<Value> operands) {
2858 size_t totalWidth = 0;
2859 for (auto operand : operands) {
2860 // getIntOrFloatBitWidth should never raise, since all arguments to
2861 // ConcatOp are integers.
2862 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2863 assert(width >= 0);
2864 totalWidth += width;
2865 }
2866 return totalWidth;
2867}
2868
/// Reduce the strength of icmp(concat(...), concat(...)) by doing an
/// element-wise comparison that strips the operands both sides share as a
/// common prefix and/or common suffix. Returns success() if a rewrite
/// happened. This handles both concat and replicate, whose operands are
/// flattened via getConcatOperands.
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
                                                  Operation *rhs,
                                                  PatternRewriter &rewriter) {
  // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
  // all elements have non-zero length. Both these invariants are verified
  // by the ConcatOp verifier.
  SmallVector<Value> lhsOperands, rhsOperands;
  getConcatOperands(lhs->getResult(0), lhsOperands);
  getConcatOperands(rhs->getResult(0), rhsOperands);
  ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;

  // Build a value from `operands`: a replicate when every operand is the same
  // value, otherwise a concat (createOrFold may simplify either).
  auto formCatOrReplicate = [&](Location loc,
                                ArrayRef<Value> operands) -> Value {
    assert(!operands.empty());
    Value sameElement = operands[0];
    for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
      if (sameElement != operands[i])
        sameElement = Value();
    if (sameElement)
      return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
                                                operands.size());
    return rewriter.createOrFold<ConcatOp>(loc, operands);
  };

  // Replace `op` with a new icmp using the given predicate and operands,
  // keeping the original two-state flag and namehint.
  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
                         Value rhs) -> LogicalResult {
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
                                              op.getTwoState());
    return success();
  };

  size_t commonPrefixLength =
      computeCommonPrefixLength(lhsOperands, rhsOperands);
  if (commonPrefixLength == lhsOperands.size()) {
    // cat(a, b, c) op cat(a, b, c): both sides are identical, so the result
    // is the predicate applied to equal operands (e.g. eq -> 1, ne -> 0).
    bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
    replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                      APInt(1, result));
    return success();
  }

  // The common suffix is the common prefix of the reversed operand lists.
  size_t commonSuffixLength = computeCommonPrefixLength(
      llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));

  size_t commonPrefixTotalWidth =
      getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
  size_t commonSuffixTotalWidth =
      getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
  // The operands in between the common prefix and suffix, per side.
  auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
                     .drop_back(commonSuffixLength);
  auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
                     .drop_back(commonSuffixLength);

  // Compare only the differing middle operands.
  auto replaceWithoutReplicatingSignBit = [&]() {
    auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
    auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
    return replaceWith(op.getPredicate(), newLhs, newRhs);
  };

  // Compare the differing middle operands, each prefixed with the sign bit
  // (MSB) of the first operand. This is only reached when the common prefix
  // is non-empty, so lhsOperands[0] is shared by both sides.
  auto replaceWithReplicatingSignBit = [&]() {
    auto firstNonEmptyValue = lhsOperands[0];
    auto firstNonEmptyElemWidth =
        firstNonEmptyValue.getType().getIntOrFloatBitWidth();
    // Extract the single bit at index width-1, i.e. the MSB.
    Value signBit = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);

    auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
    auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
    return replaceWith(op.getPredicate(), newLhs, newRhs);
  };

  if (ICmpOp::isPredicateSigned(op.getPredicate())) {
    // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
    if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
      return replaceWithoutReplicatingSignBit();

    // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
    // cat(sgn(a), ..y)). Note that we cannot perform this optimization if
    // [width(b) = 0 && width(a) <= 1], since the common prefix is then just
    // the sign bit itself: the rewrite would reproduce the same comparison
    // and could loop forever.
    if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
      return replaceWithReplicatingSignBit();

  } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
    // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
    return replaceWithoutReplicatingSignBit();
  }

  return failure();
}
2962
2963/// Given an equality comparison with a constant value and some operand that has
2964/// known bits, simplify the comparison to check only the unknown bits of the
2965/// input.
2966///
2967/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2968/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2969/// `extract x[1:0] == 0`
2971 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2972 PatternRewriter &rewriter) {
2973
2974 // If any of the known bits disagree with any of the comparison bits, then
2975 // we can constant fold this comparison right away.
2976 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2977 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2978 // If we discover a mismatch then we know an "eq" comparison is false
2979 // and a "ne" comparison is true!
2980 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2981 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2982 APInt(1, result));
2983 return;
2984 }
2985
2986 // Check to see if we can prove the result entirely of the comparison (in
2987 // which we bail out early), otherwise build a list of values to concat and a
2988 // smaller constant to compare against.
2989 SmallVector<Value> newConcatOperands;
2990 auto newConstant = APInt::getZeroWidth();
2991
2992 // Ok, some (maybe all) bits are known and some others may be unknown.
2993 // Extract out segments of the operand and compare against the
2994 // corresponding bits.
2995 unsigned knownMSB = bitsKnown.countLeadingOnes();
2996
2997 Value operand = cmpOp.getLhs();
2998
2999 // Ok, some bits are known but others are not. Extract out sequences of
3000 // bits that are unknown and compare just those bits. We work from MSB to
3001 // LSB.
3002 while (knownMSB != bitsKnown.getBitWidth()) {
3003 // Drop any high bits that are known.
3004 if (knownMSB)
3005 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3006
3007 // Find the span of unknown bits, and extract it.
3008 unsigned unknownBits = bitsKnown.countLeadingZeros();
3009 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3010 auto spanOperand = rewriter.createOrFold<ExtractOp>(
3011 operand.getLoc(), operand, /*lowBit=*/lowBit,
3012 /*bitWidth=*/unknownBits);
3013 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3014
3015 // Add this info to the concat we're generating.
3016 newConcatOperands.push_back(spanOperand);
3017 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
3018 // newConstant = newConstant.concat(spanConstant);
3019 if (newConstant.getBitWidth() != 0)
3020 newConstant = newConstant.concat(spanConstant);
3021 else
3022 newConstant = spanConstant;
3023
3024 // Drop the unknown bits in prep for the next chunk.
3025 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3026 bitsKnown = bitsKnown.trunc(newWidth);
3027 knownMSB = bitsKnown.countLeadingOnes();
3028 }
3029
3030 // If all the operands to the concat are foldable then we have an identity
3031 // situation where all the sub-elements equal each other. This implies that
3032 // the overall result is foldable.
3033 if (newConcatOperands.empty()) {
3034 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3035 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3036 APInt(1, result));
3037 return;
3038 }
3039
3040 // If we have a single operand remaining, use it, otherwise form a concat.
3041 Value concatResult =
3042 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
3043
3044 // Form the comparison against the smaller constant.
3045 auto newConstantOp = hw::ConstantOp::create(
3046 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3047
3048 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3049 cmpOp.getPredicate(), concatResult,
3050 newConstantOp, cmpOp.getTwoState());
3051}
3052
3053// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
3054static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
3055 const APInt &rhs,
3056 PatternRewriter &rewriter) {
3057 auto ip = rewriter.saveInsertionPoint();
3058 rewriter.setInsertionPoint(xorOp);
3059
3060 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
3061 auto newRHS = hw::ConstantOp::create(rewriter, xorRHS->getLoc(),
3062 xorRHS.getValue() ^ rhs);
3063 Value newLHS;
3064 switch (xorOp.getNumOperands()) {
3065 case 1:
3066 // This isn't common but is defined so we need to handle it.
3067 newLHS = hw::ConstantOp::create(rewriter, xorOp.getLoc(),
3068 APInt::getZero(rhs.getBitWidth()));
3069 break;
3070 case 2:
3071 // The binary case is the most common.
3072 newLHS = xorOp.getOperand(0);
3073 break;
3074 default:
3075 // The general case forces us to form a new xor with the remaining operands.
3076 SmallVector<Value> newOperands(xorOp.getOperands());
3077 newOperands.pop_back();
3078 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands, false);
3079 break;
3080 }
3081
3082 bool xorMultipleUses = !xorOp->hasOneUse();
3083
3084 // If the xor has multiple uses (not just the compare, then we need/want to
3085 // replace them as well.
3086 if (xorMultipleUses)
3087 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3088 false);
3089
3090 // Replace the comparison.
3091 rewriter.restoreInsertionPoint(ip);
3092 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3093 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
3094}
3095
/// Canonicalize an icmp:
///  - a constant LHS is moved to the RHS, flipping the predicate;
///  - comparisons against a constant RHS are simplified per predicate:
///    trivially true/false results, non-strict -> strict forms, eq/ne on i1,
///    and comparisons that only need an extract of the high bits;
///  - eq/ne against a constant are narrowed using known bits, a trailing xor
///    constant, or a replicated input;
///  - comparisons of two concats/replicates drop their common prefix/suffix.
/// Returns failure() if no rewrite applies.
LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
  // Bail out on ops that take one of their own results as an operand.
  if (isOpTriviallyRecursive(op))
    return failure();
  APInt lhs, rhs;

  // icmp 1, x -> icmp x, 1
  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
    // Two constant operands are expected to have been handled by fold().
    assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
           "Should be folded");
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
        rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
        op.getRhs(), op.getLhs(), op.getTwoState());
    return success();
  }

  // Canonicalize with RHS constant
  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
    // Materialize `constant` as a new hw.constant at the icmp's location.
    auto getConstant = [&](APInt constant) -> Value {
      return hw::ConstantOp::create(rewriter, op.getLoc(), std::move(constant));
    };

    // Replace the icmp with a new icmp, keeping its two-state flag.
    auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
                           Value rhs) -> LogicalResult {
      replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
                                                rhs, op.getTwoState());
      return success();
    };

    // Replace the icmp with a constant i1.
    auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        APInt(1, constant));
      return success();
    };

    // NOTE: within each case the checks are ordered; later patterns rely on
    // earlier ones having already handled the boundary constants.
    switch (op.getPredicate()) {
    case ICmpPredicate::slt:
      // x < max -> x != max
      if (rhs.isMaxSignedValue())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isMinSignedValue())
        return replaceWithConstantI1(0);
      // x < min+1 -> x == min
      if ((rhs - 1).isMinSignedValue())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs - 1));
      break;
    case ICmpPredicate::sgt:
      // x > min -> x != min
      if (rhs.isMinSignedValue())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isMaxSignedValue())
        return replaceWithConstantI1(0);
      // x > max-1 -> x == max
      if ((rhs + 1).isMaxSignedValue())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs + 1));
      break;
    case ICmpPredicate::ult:
      // x < max -> x != max
      if (rhs.isAllOnes())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isZero())
        return replaceWithConstantI1(0);
      // x < min+1 -> x == min
      if ((rhs - 1).isZero())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs - 1));

      // x < 0xE0 -> extract(x, 5..7) != 0b111
      // rhs is a run of ones followed by a run of zeros. Zero and all-ones
      // were handled above, so numOnes is between 1 and width-1, and x < rhs
      // exactly when the top numOnes bits of x are not all ones.
      if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
          rhs.getBitWidth()) {
        auto numOnes = rhs.countLeadingOnes();
        auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
                                         rhs.getBitWidth() - numOnes, numOnes);
        return replaceWith(ICmpPredicate::ne, smaller,
                           getConstant(APInt::getAllOnes(numOnes)));
      }

      break;
    case ICmpPredicate::ugt:
      // x > min -> x != min
      if (rhs.isZero())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isAllOnes())
        return replaceWithConstantI1(0);
      // x > max-1 -> x == max
      if ((rhs + 1).isAllOnes())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs + 1));

      // x > 0x07 -> extract(x, 3..7) != 0b00000
      // rhs is 2^k - 1 here (zero and all-ones were handled above), so
      // x > rhs exactly when some bit of x at position k or above is set.
      if ((rhs + 1).isPowerOf2()) {
        auto numOnes = rhs.countTrailingOnes();
        auto newWidth = rhs.getBitWidth() - numOnes;
        auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
                                         numOnes, newWidth);
        return replaceWith(ICmpPredicate::ne, smaller,
                           getConstant(APInt::getZero(newWidth)));
      }

      break;
    case ICmpPredicate::sle:
      // x <= max -> true
      if (rhs.isMaxSignedValue())
        return replaceWithConstantI1(1);
      // x <= c -> x < (c+1); c+1 cannot overflow since c != max here.
      return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
    case ICmpPredicate::sge:
      // x >= min -> true
      if (rhs.isMinSignedValue())
        return replaceWithConstantI1(1);
      // x >= c -> x > (c-1); c-1 cannot underflow since c != min here.
      return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
    case ICmpPredicate::ule:
      // x <= max -> true
      if (rhs.isAllOnes())
        return replaceWithConstantI1(1);
      // x <= c -> x < (c+1); c+1 cannot wrap since c != max here.
      return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
    case ICmpPredicate::uge:
      // x >= min -> true
      if (rhs.isZero())
        return replaceWithConstantI1(1);
      // x >= c -> x > (c-1); c-1 cannot wrap since c != 0 here.
      return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
    case ICmpPredicate::eq:
      if (rhs.getBitWidth() == 1) {
        if (rhs.isZero()) {
          // x == 0 -> x ^ 1
          replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
                                                   getConstant(APInt(1, 1)),
                                                   op.getTwoState());
          return success();
        }
        if (rhs.isAllOnes()) {
          // x == 1 -> x
          replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
          return success();
        }
      }
      break;
    case ICmpPredicate::ne:
      if (rhs.getBitWidth() == 1) {
        if (rhs.isZero()) {
          // x != 0 -> x
          replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
          return success();
        }
        if (rhs.isAllOnes()) {
          // x != 1 -> x ^ 1
          replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
                                                   getConstant(APInt(1, 1)),
                                                   op.getTwoState());
          return success();
        }
      }
      break;
    case ICmpPredicate::ceq:
    case ICmpPredicate::cne:
    case ICmpPredicate::weq:
    case ICmpPredicate::wne:
      break;
    }

    // We have some specific optimizations for comparison with a constant that
    // are only supported for equality comparisons.
    if (op.getPredicate() == ICmpPredicate::eq ||
        op.getPredicate() == ICmpPredicate::ne) {
      // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
      // with a smaller constant. We only support equality comparisons for
      // this.
      // (The comma operator runs the void rewrite helper, then reports
      // success.)
      auto knownBits = computeKnownBits(op.getLhs());
      if (!knownBits.isUnknown())
        return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
                                                           rewriter),
               success();

      // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
      // cst1^cst2).
      if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
        if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
          return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
                 success();

      // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
      // all one.
      if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
        if (rhs.isAllOnes() || rhs.isZero()) {
          auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
          auto cst =
              hw::ConstantOp::create(rewriter, op.getLoc(),
                                     rhs.isAllOnes() ? APInt::getAllOnes(width)
                                                     : APInt::getZero(width));
          replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
              rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
              op.getTwoState());
          return success();
        }
    }
  }

  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
  // b), cat(c, d)). Contains special handling for the sign bit in signed
  // comparisons.
  if (Operation *opLHS = op.getLhs().getDefiningOp())
    if (Operation *opRHS = op.getRhs().getDefiningOp())
      if (isa<ConcatOp, ReplicateOp>(opLHS) &&
          isa<ConcatOp, ReplicateOp>(opRHS)) {
        if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
          return success();
      }

  return failure();
}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition CombFolds.cpp:84
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition CombFolds.cpp:52
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition CombFolds.cpp:46
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition CombFolds.cpp:90
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
Definition CombFolds.cpp:27
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
Definition CombFolds.cpp:67
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition CombFolds.cpp:38
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements)
Definition hw.py:483
create(array_value, low_index, ret_type)
Definition hw.py:466
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
Definition CombOps.cpp:78
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
Definition CombOps.cpp:66
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
Definition CombOps.h:103
@ BalancedMuxTree
Definition CombOps.h:103
LogicalResult convertModUByPowerOfTwo(ModUOp modOp, mlir::PatternRewriter &rewriter)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
Definition CombOps.cpp:105
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, mlir::PatternRewriter &rewriter)
Convert unsigned division or modulo by a power of two.
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1