CIRCT 23.0.0git
CombFolds.cpp
1//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "circt/Dialect/Comb/CombOps.h"
10#include "circt/Dialect/HW/HWAttributes.h"
11#include "circt/Dialect/HW/HWOps.h"
12#include "circt/Support/Naming.h"
13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
20
21using namespace mlir;
22using namespace circt;
23using namespace comb;
24using namespace matchers;
25
26// Returns true if the op has one of its own results as an operand.
27static bool isOpTriviallyRecursive(Operation *op) {
28 return llvm::any_of(op->getOperands(), [op](auto operand) {
29 return operand.getDefiningOp() == op;
30 });
31}
32
33/// Create a new instance of a generic operation that only has value operands,
34/// and has a single result value whose type matches the first operand.
35///
36/// This should not be used to create instances of ops with attributes or with
37/// more complicated type signatures.
38static Value createGenericOp(Location loc, OperationName name,
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
44}
45
46static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
47 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
48 value);
49}
50
51/// Flatten concat and replicate operands into a vector.
52static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
53 if (auto concat = v.getDefiningOp<ConcatOp>()) {
54 for (auto op : concat.getOperands())
55 getConcatOperands(op, result);
56 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
58 getConcatOperands(repl.getOperand(), result);
59 } else {
60 result.push_back(v);
61 }
62}
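
// Illustrative sketch (not part of the original source): for a value built as
//   %r = comb.replicate %d : (i4) -> i8
//   %v = comb.concat %a, %b, %r : i4, i4, i8
// getConcatOperands(%v, result) appends {%a, %b, %d, %d} to `result`.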
63
64// Return true if the op has SV attributes. Note that we cannot use a helper
65// function `hasSVAttributes` defined under SV dialect because of a cyclic
66// dependency.
67static bool hasSVAttributes(Operation *op) {
68 return op->hasAttr("sv.attributes");
69}
70
71namespace {
72template <typename SubType>
73struct ComplementMatcher {
74 SubType lhs;
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() &&
79 mlir::detail::matchOperandOrValueAtIndex(op, 0, lhs);
80 }
81};
82} // end anonymous namespace
83
84template <typename SubType>
85static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
86 return ComplementMatcher<SubType>(subExpr);
87}
88
89/// Return true if the op will be flattened afterwards. The op will be flattened
90/// if it has a single user of the same op type, in the same block, with a matching two-state flag.
91static bool shouldBeFlattened(Operation *op) {
92 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
93 "must be commutative operations"));
94 if (op->hasOneUse()) {
95 auto *user = *op->getUsers().begin();
96 return user->getName() == op->getName() &&
97 op->getAttrOfType<UnitAttr>("twoState") ==
98 user->getAttrOfType<UnitAttr>("twoState") &&
99 op->getBlock() == user->getBlock();
100 }
101 return false;
102}
103
104/// Flatten operands of `op` that are defined by the same kind of operation in
105/// the same block. Returns true if the op was rewritten, and false otherwise.
106///
107/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
108///
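/// A minimal IR-level sketch (assumed syntax, not taken from this file): with
/// all values of type i8 and %0 having exactly one use,
///   %0 = comb.and %b, %c : i8
///   %1 = comb.and %a, %0, %d : i8
/// is rewritten into the single variadic op
///   %1 = comb.and %a, %b, %c, %d : i8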
109static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
110 // Skip if the operation should be flattened by another operation.
111 if (shouldBeFlattened(op))
112 return false;
113
114 auto inputs = op->getOperands();
115
116 SmallVector<Value, 4> newOperands;
117 SmallVector<Location, 4> newLocations{op->getLoc()};
118 newOperands.reserve(inputs.size());
119 struct Element {
120 decltype(inputs.begin()) current, end;
121 };
122
123 SmallVector<Element> worklist;
124 worklist.push_back({inputs.begin(), inputs.end()});
125 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
126 bool changed = false;
127 while (!worklist.empty()) {
128 auto &element = worklist.back(); // Do not pop. Take ref.
129
130 // Pop when we finished traversing the current operand range.
131 if (element.current == element.end) {
132 worklist.pop_back();
133 continue;
134 }
135
136 Value value = *element.current++;
137 auto *flattenOp = value.getDefiningOp();
138 // If not defined by a compatible operation of the same kind and
139 // from the same block, keep this as-is.
140 if (!flattenOp || flattenOp->getName() != op->getName() ||
141 flattenOp == op || binFlag != flattenOp->hasAttrOfType<UnitAttr>("twoState") ||
142 flattenOp->getBlock() != op->getBlock()) {
143 newOperands.push_back(value);
144 continue;
145 }
146
147 // Don't duplicate logic when it has multiple uses.
148 if (!value.hasOneUse()) {
149 // We can fold a multi-use binary operation into this one if doing so lets a
150 // constant fold through. For example, fold
151 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
152 // since the two constants fold together and we end up with equivalent cost.
153 //
154 // We don't do this for add/mul because the hardware won't be shared
155 // between the two ops if duplicated.
156 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
157 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
158 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
159 newOperands.push_back(value);
160 continue;
161 }
162 }
163
164 changed = true;
165
166 // Otherwise, push operands into worklist.
167 auto flattenOpInputs = flattenOp->getOperands();
168 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
169 newLocations.push_back(flattenOp->getLoc());
170 }
171
172 if (!changed)
173 return false;
174
175 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
176 op->getName(), newOperands, rewriter);
177 if (binFlag)
178 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
179
180 replaceOpAndCopyNamehint(rewriter, op, result);
181 return true;
182}
183
184// Given a range of uses of an operation, find the lowest and highest bits
185// inclusive that are ever referenced. The range of uses must not be empty.
186static std::pair<size_t, size_t>
187getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
188 size_t originalOpWidth) {
189 auto users = op->getUsers();
190 assert(!users.empty() &&
191 "getLowestBitAndHighestBitRequired cannot operate on "
192 "a empty list of uses.");
193
194 // When we don't want to narrow trailing bits (namely for arithmetic
195 // operations), force lowestBitRequired to 0.
196 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
197 size_t highestBitRequired = 0;
198
199 for (auto *user : users) {
200 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
201 size_t lowBit = extractOp.getLowBit();
202 size_t highBit =
203 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
204 highestBitRequired = std::max(highestBitRequired, highBit);
205 lowestBitRequired = std::min(lowestBitRequired, lowBit);
206 continue;
207 }
208
209 highestBitRequired = originalOpWidth - 1;
210 lowestBitRequired = 0;
211 break;
212 }
213
214 return {lowestBitRequired, highestBitRequired};
215}
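
// Worked example (illustrative, not from the source): for an i8 op whose only
// users are
//   comb.extract %r from 2 : (i8) -> i3   // bits [4:2]
//   comb.extract %r from 4 : (i8) -> i2   // bits [5:4]
// this returns {2, 5} when narrowTrailingBits is true, and {0, 5} otherwise.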
216
217template <class OpTy>
218static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
219 PatternRewriter &rewriter) {
220 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
221 if (!opType)
222 return false;
223
224 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
225 opType.getWidth());
226 if (range.second + 1 == opType.getWidth() && range.first == 0)
227 return false;
228
229 SmallVector<Value> args;
230 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
231 for (auto inop : op.getOperands()) {
232 // Pass through operands whose type differs from the result (e.g. a mux condition).
233 if (inop.getType() != op.getType())
234 args.push_back(inop);
235 else
236 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
237 inop, range.first));
238 }
239 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
240 newop->setDialectAttrs(op->getDialectAttrs());
241 if (op.getTwoState())
242 newop.setTwoState(true);
243
244 Value newResult = newop.getResult();
245 if (range.first)
246 newResult = rewriter.createOrFold<ConcatOp>(
247 op.getLoc(), newResult,
248 hw::ConstantOp::create(rewriter, op.getLoc(),
249 APInt::getZero(range.first)));
250 if (range.second + 1 < opType.getWidth())
251 newResult = rewriter.createOrFold<ConcatOp>(
252 op.getLoc(),
253 hw::ConstantOp::create(
254 rewriter, op.getLoc(),
255 APInt::getZero(opType.getWidth() - range.second - 1)),
256 newResult);
257 rewriter.replaceOp(op, newResult);
258 return true;
259}
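
// Illustrative sketch (assumed IR, not from this file): if only bits [5:2] of
// an i8 `comb.and` are ever extracted, the op is rebuilt at width 4 and padded
// back to i8 with constant zeros, so existing users still see an i8 value:
//   %n = comb.and %a4, %b4 : i4                        // %a4, %b4 = extracts of the operands
//   %r = comb.concat %c0_i2, %n, %c0_i2 : i2, i4, i2   // zero-pad both sides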
260
261//===----------------------------------------------------------------------===//
262// Unary Operations
263//===----------------------------------------------------------------------===//
264
265OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
266 if (isOpTriviallyRecursive(*this))
267 return {};
268
269 // Replicate one time -> noop.
270 if (cast<IntegerType>(getType()).getWidth() ==
271 getInput().getType().getIntOrFloatBitWidth())
272 return getInput();
273
274 // Constant fold.
275 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
276 if (input.getValue().getBitWidth() == 1) {
277 if (input.getValue().isZero())
278 return getIntAttr(
279 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
280 getContext());
281 return getIntAttr(
282 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
283 getContext());
284 }
285
286 APInt result = APInt::getZeroWidth();
287 for (auto i = getMultiple(); i != 0; --i)
288 result = result.concat(input.getValue());
289 return getIntAttr(result, getContext());
290 }
291
292 return {};
293}
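
// Illustrative fold (assumed IR): replicating a constant just repeats its bits,
// e.g. with %c = hw.constant 1 : i2 (0b01),
//   comb.replicate %c : (i2) -> i6
// folds to hw.constant 21 : i6 (0b010101).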
294
295OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
296 if (isOpTriviallyRecursive(*this))
297 return {};
298
299 // Constant fold.
300 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
301 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
302
303 return {};
304}
305
306//===----------------------------------------------------------------------===//
307// Binary Operations
308//===----------------------------------------------------------------------===//
309
310/// Performs constant folding of the binary operation `paramOpcode` on the two
311/// attributes in `operands` and returns the result if possible.
312static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
313 hw::PEO paramOpcode) {
314 assert(operands.size() == 2 && "binary op takes two operands");
315 if (!operands[0] || !operands[1])
316 return {};
317
318 // Fold constants with ParamExprAttr::get which handles simple constants as
319 // well as parameter expressions.
320 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
321 cast<TypedAttr>(operands[1]));
322}
323
324OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
325 if (isOpTriviallyRecursive(*this))
326 return {};
327
328 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
329 if (rhs.getValue().isZero())
330 return getOperand(0);
331
332 unsigned width = getType().getIntOrFloatBitWidth();
333 if (rhs.getValue().uge(width))
334 return getIntAttr(APInt::getZero(width), getContext());
335 }
336 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
337}
338
339LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
340 if (isOpTriviallyRecursive(op))
341 return failure();
342
343 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
344 APInt value;
345 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
346 return failure();
347
348 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
349 if (value.ugt(width))
350 value = width;
351 unsigned shift = value.getZExtValue();
352
353 // This case is handled by fold.
354 if (width <= shift || shift == 0)
355 return failure();
356
357 auto zeros =
358 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
359
360 // Remove the high bits which would be removed by the Shl.
361 auto extract =
362 ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), 0, width - shift);
363
364 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
365 return success();
366}
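
// Illustrative rewrite (assumed IR): a left shift by a constant becomes a
// concat of the surviving low bits with zeros, e.g. for an i8 value and a
// constant shift of 3,
//   comb.shl %a, %c3_i8 : i8
// becomes
//   comb.concat %low, %c0_i3 : i5, i3   // %low = comb.extract %a from 0 : (i8) -> i5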
367
368OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
369 if (isOpTriviallyRecursive(*this))
370 return {};
371
372 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
373 if (rhs.getValue().isZero())
374 return getOperand(0);
375
376 unsigned width = getType().getIntOrFloatBitWidth();
377 if (rhs.getValue().uge(width))
378 return getIntAttr(APInt::getZero(width), getContext());
379 }
380 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
381}
382
383LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
384 if (isOpTriviallyRecursive(op))
385 return failure();
386
387 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
388 APInt value;
389 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
390 return failure();
391
392 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
393 if (value.ugt(width))
394 value = width;
395 unsigned shift = value.getZExtValue();
396
397 // This case is handled by fold.
398 if (width <= shift || shift == 0)
399 return failure();
400
401 auto zeros =
402 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
403
404 // Remove the low bits which would be removed by the Shr.
405 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
406 width - shift);
407
408 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
409 return success();
410}
411
412OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
413 if (isOpTriviallyRecursive(*this))
414 return {};
415
416 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
417 if (rhs.getValue().isZero())
418 return getOperand(0);
419 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
420}
421
422LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
423 if (isOpTriviallyRecursive(op))
424 return failure();
425
426 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
427 APInt value;
428 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
429 return failure();
430
431 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
432 if (value.ugt(width))
433 value = width;
434 unsigned shift = value.getZExtValue();
435
436 auto topbit =
437 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
438 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
439
440 if (width == shift) {
441 replaceOpAndCopyNamehint(rewriter, op, {sext});
442 return success();
443 }
444
445 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
446 width - shift);
447
448 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
449 return success();
450}
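
// Illustrative rewrite (assumed IR): an arithmetic shift right by a constant
// becomes sign-bit replication concatenated with the surviving high bits,
// e.g. for an i8 value shifted right by 3:
//   %sign = comb.extract %a from 7 : (i8) -> i1
//   %sext = comb.replicate %sign : (i1) -> i3
//   %hi   = comb.extract %a from 3 : (i8) -> i5
//   %r    = comb.concat %sext, %hi : i3, i5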
451
452//===----------------------------------------------------------------------===//
453// Other Operations
454//===----------------------------------------------------------------------===//
455
456OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
457 if (isOpTriviallyRecursive(*this))
458 return {};
459
460 // If we are extracting the entire input, then return it.
461 if (getInput().getType() == getType())
462 return getInput();
463
464 // Constant fold.
465 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
466 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
467 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
468 getContext());
469 }
470 return {};
471}
472
473// Transforms extract(lo, cat(a, b, c, d, e)) into
474// cat(extract(lo1, b), c, extract(lo2, d)).
475// innerCat must be the argument of the provided ExtractOp.
476static LogicalResult extractConcatToConcatExtract(ExtractOp op,
477 ConcatOp innerCat,
478 PatternRewriter &rewriter) {
479 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
480 size_t beginOfFirstRelevantElement = 0;
481 auto it = reversedConcatArgs.begin();
482 size_t lowBit = op.getLowBit();
483
484 // This loop finds the first concatArg that is covered by the ExtractOp
485 for (; it != reversedConcatArgs.end(); it++) {
486 assert(beginOfFirstRelevantElement <= lowBit &&
487 "incorrectly moved past an element that lowBit has coverage over");
488 auto operand = *it;
489
490 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
491 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
492 // A bit other than the first bit will be used in this element.
493 // ...... ........ ...
494 // ^---lowBit
495 // ^---beginOfFirstRelevantElement
496 //
497 // Edge-case close to the end of the range.
498 // ...... ........ ...
499 // ^---(position + operandWidth)
500 // ^---lowBit
501 // ^---beginOfFirstRelevantElement
502 //
503 // Edge-case close to the beginning of the range.
504 // ...... ........ ...
505 // ^---lowBit
506 // ^---beginOfFirstRelevantElement
507 //
508 break;
509 }
510
511 // extraction discards this element.
512 // ...... ........ ...
513 // | ^---lowBit
514 // ^---beginOfFirstRelevantElement
515 beginOfFirstRelevantElement += operandWidth;
516 }
517 assert(it != reversedConcatArgs.end() &&
518 "incorrectly failed to find an element which contains coverage of "
519 "lowBit");
520
521 SmallVector<Value> reverseConcatArgs;
522 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
523 size_t extractLo = lowBit - beginOfFirstRelevantElement;
524
525 // Transform individual arguments of innerCat(..., a, b, c,) into
526 // [ extract(a), b, extract(c) ], skipping an extract operation where
527 // possible.
528 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
529 auto concatArg = *it;
530 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
531 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
532
533 if (widthToConsume == operandWidth && extractLo == 0) {
534 reverseConcatArgs.push_back(concatArg);
535 } else {
536 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
537 reverseConcatArgs.push_back(
538 ExtractOp::create(rewriter, op.getLoc(), resultType, *it, extractLo));
539 }
540
541 widthRemaining -= widthToConsume;
542
543 // Beyond the first element, all elements are extracted from position 0.
544 extractLo = 0;
545 }
546
547 if (reverseConcatArgs.size() == 1) {
548 replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
549 } else {
550 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
551 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
552 }
553 return success();
554}
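
// Illustrative example (assumed IR): extracting bits [9:2] of an i12 concat
//   %v = comb.concat %a, %b, %c : i4, i4, i4    // %a is the MSB operand
//   %e = comb.extract %v from 2 : (i12) -> i8
// becomes a concat of narrower extracts that skips the unused bits:
//   %e = comb.concat %a2, %b, %c2 : i2, i4, i2
// where %a2 = comb.extract %a from 0 : (i4) -> i2 and
//       %c2 = comb.extract %c from 2 : (i4) -> i2.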
555
556// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
557static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
558 PatternRewriter &rewriter) {
559 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
560 auto replicateEltWidth =
561 replicate.getOperand().getType().getIntOrFloatBitWidth();
562
563 // If the extract starts at the base of an element and its width is a whole
564 // multiple of the element width, we can replace it with a smaller replicate.
565 if (op.getLowBit() % replicateEltWidth == 0 &&
566 extractResultWidth % replicateEltWidth == 0) {
567 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
568 replicate.getOperand());
569 return true;
570 }
571
572 // If the extract is completely contained in one element, extract from the
573 // element.
574 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
575 replicateEltWidth) {
576 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
577 rewriter, op, op.getType(), replicate.getOperand(),
578 op.getLowBit() % replicateEltWidth);
579 return true;
580 }
581
582 // We don't currently handle the case of extracting from non-whole elements,
583 // e.g. `extract (replicate 2-bit-thing, N), 1`.
584 return false;
585}
586
587LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
588 if (isOpTriviallyRecursive(op))
589 return failure();
590 auto *inputOp = op.getInput().getDefiningOp();
591
592 // This turns out to be incredibly expensive. Disable until performance is
593 // addressed.
594#if 0
595 // If the extracted bits are all known, then return the result.
596 auto knownBits = computeKnownBits(op.getInput())
597 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
598 op.getLowBit());
599 if (knownBits.isConstant()) {
600 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
601 knownBits.getConstant());
602 return success();
603 }
604#endif
605
606 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
607 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
608 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
609 rewriter, op, op.getType(), innerExtract.getInput(),
610 innerExtract.getLowBit() + op.getLowBit());
611 return success();
612 }
613
614 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
615 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
616 return extractConcatToConcatExtract(op, innerCat, rewriter);
617
618 // extract(lo, replicate(a))
619 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
620 if (extractFromReplicate(op, replicate, rewriter))
621 return success();
622
623 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
624 // and/or/xor are not modifying the extracted bits.
625 if (inputOp && inputOp->getNumOperands() == 2 &&
626 isa<AndOp, OrOp, XorOp>(inputOp)) {
627 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
628 auto extractedCst = cstRHS.getValue().extractBits(
629 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
630 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
631 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
632 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
633 return success();
634 }
635
636 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
637 // extract to represent the result. Turning it into a pile of extracts is
638 // always fine by our cost model, but we don't want to explode things into
639 // a ton of bits because it will bloat the IR and generated Verilog.
640 if (isa<AndOp>(inputOp)) {
641 // For our cost model, we only do this if the bit pattern is a
642 // contiguous series of ones.
643 unsigned lz = extractedCst.countLeadingZeros();
644 unsigned tz = extractedCst.countTrailingZeros();
645 unsigned pop = extractedCst.popcount();
646 if (extractedCst.getBitWidth() - lz - tz == pop) {
647 auto resultTy = rewriter.getIntegerType(pop);
648 SmallVector<Value> resultElts;
649 if (lz)
650 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
651 APInt::getZero(lz)));
652 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
653 op.getLoc(), resultTy, inputOp->getOperand(0),
654 op.getLowBit() + tz));
655 if (tz)
656 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
657 APInt::getZero(tz)));
658 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
659 return success();
660 }
661 }
662 }
663 }
664
665 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
666 // extracted.
667 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
668 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
669 // Don't canonicalize if the shift is multiply used.
670 if (shlOp->hasOneUse())
671 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
672 if (lhsCst.getValue().isOne()) {
673 auto newCst = hw::ConstantOp::create(
674 rewriter, shlOp.getLoc(),
675 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
676 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
677 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
678 false);
679 return success();
680 }
681 }
682
683 return failure();
684}
685
686//===----------------------------------------------------------------------===//
687// Associative Variadic operations
688//===----------------------------------------------------------------------===//
689
690// Reduce all operands to a single value (either integer constant or parameter
691// expression) if all the operands are constants.
692static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
693 hw::PEO paramOpcode) {
694 assert(operands.size() > 1 && "caller should handle one-operand case");
695 // We can only fold anything in the case where all operands are known to be
696 // constants. Check the least common one first for an early out.
697 if (!operands[1] || !operands[0])
698 return {};
699
700 // This will fold to a simple constant if all operands are constant.
701 if (llvm::all_of(operands.drop_front(2),
702 [&](Attribute in) { return !!in; })) {
703 SmallVector<mlir::TypedAttr> typedOperands;
704 typedOperands.reserve(operands.size());
705 for (auto operand : operands) {
706 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
707 typedOperands.push_back(typedOperand);
708 else
709 break;
710 }
711 if (typedOperands.size() == operands.size())
712 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
713 }
714
715 return {};
716}
717
718/// When we find a logical operation (and, or, xor) with a constant e.g.
719/// `X & 42`, we want to push the constant into the computation of X if it leads
720/// to simplification.
721///
722/// This function handles the case where the logical operation has a concat
723/// operand. We check to see if we can simplify the concat, e.g. when it has
724/// constant operands.
725///
726/// This returns true when a simplification happens.
727static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
728 size_t concatIdx, const APInt &cst,
729 PatternRewriter &rewriter) {
730 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
731 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
732
733 // Check to see if any operands can be simplified by pushing the logical op
734 // into all parts of the concat.
735 bool canSimplify =
736 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
737 auto *operandOp = operand.getDefiningOp();
738 if (!operandOp)
739 return false;
740
741 // If the concat has a constant operand then we can transform this.
742 if (isa<hw::ConstantOp>(operandOp))
743 return true;
744 // If the concat has the same logical operation and that operation has
745 // a constant operand, then we can fold it into that sub-operation.
746 return operandOp->getName() == logicalOp->getName() &&
747 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
748 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
749 });
750
751 if (!canSimplify)
752 return false;
753
754 // Create a new instance of the logical operation. We have to do this the
755 // hard way since we're generic across a family of different ops.
756 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
757 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
758 rewriter);
759 };
760
761 // Ok, let's do the transformation. We do this by slicing up the constant
762 // for each unit of the concat and duplicating the operation onto each
763 // sub-operand.
764 SmallVector<Value> newConcatOperands;
765 newConcatOperands.reserve(concatOp->getNumOperands());
766
767 // Work from MSB to LSB.
768 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
769 for (Value operand : concatOp->getOperands()) {
770 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
771 nextOperandBit -= operandWidth;
772 // Take a slice of the constant.
773 auto eltCst =
774 hw::ConstantOp::create(rewriter, logicalOp->getLoc(),
775 cst.lshr(nextOperandBit).trunc(operandWidth));
776
777 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
778 }
779
780 // Create the concat, and the rest of the logical op if we need it.
781 Value newResult =
782 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
783
784 // If we had a variadic logical op on the top level, then recreate it with the
785 // new concat and without the constant operand.
786 if (logicalOp->getNumOperands() > 2) {
787 auto origOperands = logicalOp->getOperands();
788 SmallVector<Value> operands;
789 // Take any stuff before the concat.
790 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
791 // Take any stuff after the concat but before the constant.
792 operands.append(origOperands.begin() + concatIdx + 1,
793 origOperands.begin() + (origOperands.size() - 1));
794 // Include the new concat.
795 operands.push_back(newResult);
796 newResult = createLogicalOp(operands);
797 }
798
799 replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
800 return true;
801}
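
// Illustrative example (assumed IR): with %v = comb.concat %a, %c3_i4 : i4, i4
// (the low operand is a constant) and the mask 0x0F,
//   comb.and %v, %c15_i8 : i8
// is rewritten by slicing the mask per concat operand:
//   comb.concat %hi, %lo : i4, i4
// where %hi = comb.and %a, %c0_i4 : i4        // folds to 0
// and   %lo = comb.and %c3_i4, %c15_i4 : i4   // folds to 3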
802
803// Determines whether the inputs to a logical element are of opposite
804// comparisons and can be lowered into a constant.
805static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
806 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
807
808 for (auto op : operands) {
809 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
810 icmpOp && icmpOp.getTwoState()) {
811 auto predicate = icmpOp.getPredicate();
812 auto lhs = icmpOp.getLhs();
813 auto rhs = icmpOp.getRhs();
814 if (seenPredicates.contains(
815 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
816 return true;
817
818 seenPredicates.insert({predicate, lhs, rhs});
819 }
820 }
821 return false;
822}
823
824OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
825 if (isOpTriviallyRecursive(*this))
826 return {};
827
828 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
829
830 auto inputs = adaptor.getInputs();
831
832 // and(x, 01, 10) -> 00 -- annulment.
833 for (auto operand : inputs) {
834 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
835 if (!attr)
836 continue;
837 value &= attr.getValue();
838 if (value.isZero())
839 return getIntAttr(value, getContext());
840 }
841
842 // and(x, -1) -> x.
843 if (inputs.size() == 2)
844 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
845 if (intAttr.getValue().isAllOnes())
846 return getInputs()[0];
847
848 // and(x, x, x) -> x. This also handles and(x) -> x.
849 if (llvm::all_of(getInputs(),
850 [&](auto in) { return in == this->getInputs()[0]; }))
851 return getInputs()[0];
852
853 // and(..., x, ..., ~x, ...) -> 0
854 for (Value arg : getInputs()) {
855 Value subExpr;
856 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
857 for (Value arg2 : getInputs())
858 if (arg2 == subExpr)
859 return getIntAttr(
860 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
861 getContext());
862 }
863 }
864
865 // x0 = icmp(pred, x, y)
866 // x1 = icmp(!pred, x, y)
867 // and(x0, x1) -> 0
868 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
869 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
870 getContext());
871
872 // Constant fold
873 return constFoldAssociativeOp(inputs, hw::PEO::And);
874}
875
876/// Returns a single common operand that all inputs of the operation `op` can
877/// be traced back to, or an empty `Value` if no such operand exists.
878///
879/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
880/// (assuming the bit-width of `a` is `n`).
881template <typename Op>
882static Value getCommonOperand(Op op) {
883 if (!op.getType().isInteger(1))
884 return Value();
885
886 auto inputs = op.getInputs();
887 size_t size = inputs.size();
888
889 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
890 if (!sourceOp)
891 return Value();
892 Value source = sourceOp.getOperand();
893
894 // Bail out if the number of inputs differs from the bit width of the source.
895 if (size != source.getType().getIntOrFloatBitWidth())
896 return Value();
897
898 // Tracks the bits that were encountered.
899 llvm::BitVector bits(size);
900 bits.set(sourceOp.getLowBit());
901
902 for (size_t i = 1; i != size; ++i) {
903 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
904 if (!extractOp || extractOp.getOperand() != source)
905 return Value();
906 bits.set(extractOp.getLowBit());
907 }
908
909 return bits.all() ? source : Value();
910}
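
// Illustrative example (assumed IR): for
//   %b0 = comb.extract %a from 0 : (i3) -> i1
//   %b1 = comb.extract %a from 1 : (i3) -> i1
//   %b2 = comb.extract %a from 2 : (i3) -> i1
//   %r  = comb.or %b0, %b1, %b2 : i1
// getCommonOperand(%r) returns %a, enabling e.g. or-of-all-bits -> icmp ne 0.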
911
912/// Canonicalize an idempotent operation `op` so that only one input of any kind
913/// occurs.
914///
915/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
916template <typename Op>
917static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
918 // Depth limit to search, in operations. Chosen arbitrarily, keep small.
919 constexpr unsigned limit = 3;
920 auto inputs = op.getInputs();
921
922 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
923 llvm::SmallDenseSet<Op, 8> checked;
924 checked.insert(op);
925
926 struct OpWithDepth {
927 Op op;
928 unsigned depth;
929 };
930 llvm::SmallVector<OpWithDepth, 8> worklist;
931
932 auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
933 // Add to worklist if within depth limit, is defined in the same block by
934 // the same kind of operation, has same two-state-ness, and not enqueued
935 // previously.
936 if (depth < limit && input.getParentBlock() == op->getBlock()) {
937 auto inputOp = input.template getDefiningOp<Op>();
938 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
939 checked.insert(inputOp).second)
940 worklist.push_back({inputOp, depth + 1});
941 }
942 };
943
944 for (auto input : uniqueInputs)
945 enqueue(input, 0);
946
947 while (!worklist.empty()) {
948 auto item = worklist.pop_back_val();
949
950 for (auto input : item.op.getInputs()) {
951 uniqueInputs.remove(input);
952 enqueue(input, item.depth);
953 }
954 }
955
956 if (uniqueInputs.size() < inputs.size()) {
957 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
958 uniqueInputs.getArrayRef(),
959 op.getTwoState());
960 return true;
961 }
962
963 return false;
964}
965
966LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
967 if (isOpTriviallyRecursive(op))
968 return failure();
969
970 auto inputs = op.getInputs();
971 auto size = inputs.size();
972
973 // and(x, and(...)) -> and(x, ...) -- flatten
974 if (tryFlatteningOperands(op, rewriter))
975 return success();
976
977 // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
978 // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
979 // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
980 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
981 return success();
982
983 assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
984
985 // Patterns for and with a constant on RHS.
986 APInt value;
987 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
988 // and(..., '1) -> and(...) -- identity
989 if (value.isAllOnes()) {
990 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
991 inputs.drop_back(), false);
992 return success();
993 }
994
995 // TODO: Combine multiple constants together even if they aren't at the
996 // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
997 // folding
998 APInt value2;
999 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1000 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value & value2);
1001 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1002 newOperands.push_back(cst);
1003 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1004 newOperands, false);
1005 return success();
1006 }
1007
1008 // Handle 'and' with a single bit constant on the RHS.
1009 if (size == 2 && value.isPowerOf2()) {
1010 // If the LHS is a replicate from a single bit, we can 'concat' it
1011 // into place. e.g.:
1012 // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1013 // TODO: Generalize this for non-single-bit operands.
1014 if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1015 auto replicateOperand = replicate.getOperand();
1016 if (replicateOperand.getType().isInteger(1)) {
1017 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1018 auto trailingZeros = value.countTrailingZeros();
1019
1020 // Don't add zero bit constants unnecessarily.
1021 SmallVector<Value, 3> concatOperands;
1022 if (trailingZeros != resultWidth - 1) {
1023 auto highZeros = hw::ConstantOp::create(
1024 rewriter, op.getLoc(),
1025 APInt::getZero(resultWidth - trailingZeros - 1));
1026 concatOperands.push_back(highZeros);
1027 }
1028 concatOperands.push_back(replicateOperand);
1029 if (trailingZeros != 0) {
1030 auto lowZeros = hw::ConstantOp::create(
1031 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1032 concatOperands.push_back(lowZeros);
1033 }
1034 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1035 rewriter, op, op.getType(), concatOperands);
1036 return success();
1037 }
1038 }
1039 }
1040
1041 // Narrow the op if the constant has leading or trailing zeros.
1042 //
1043 // and(a, 0b00101100) -> concat(0b00, and(extract(a), 0b1011), 0b00)
1044 unsigned leadingZeros = value.countLeadingZeros();
1045 unsigned trailingZeros = value.countTrailingZeros();
1046 if (leadingZeros > 0 || trailingZeros > 0) {
1047 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1048
1049 // Extract the non-zero regions of the operands. Look through extracts.
1050 SmallVector<Value> operands;
1051 for (auto input : inputs.drop_back()) {
1052 unsigned offset = trailingZeros;
1053 while (auto extractOp = input.getDefiningOp<ExtractOp>()) {
1054 input = extractOp.getInput();
1055 offset += extractOp.getLowBit();
1056 }
1057 operands.push_back(ExtractOp::create(rewriter, op.getLoc(), input,
1058 offset, maskLength));
1059 }
1060
1061 // Add the narrowed mask if needed.
1062 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1063 if (!narrowMask.isAllOnes())
1064 operands.push_back(hw::ConstantOp::create(
1065 rewriter, inputs.back().getLoc(), narrowMask));
1066
1067 // Create the narrow and op.
1068 Value narrowValue = operands.back();
1069 if (operands.size() > 1)
1070 narrowValue =
1071 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1072 operands.clear();
1073
1074 // Concatenate the narrow and with the leading and trailing zeros.
1075 if (leadingZeros > 0)
1076 operands.push_back(hw::ConstantOp::create(
1077 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1078 operands.push_back(narrowValue);
1079 if (trailingZeros > 0)
1080 operands.push_back(hw::ConstantOp::create(
1081 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1082 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1083 return success();
1084 }
1085
1086 // and(concat(x, cst1), a, b, c, cst2)
1087 // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1088 // We do this even for multi-use concats since they are "just wiring".
1089 for (size_t i = 0; i < size - 1; ++i) {
1090 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1091 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1092 return success();
1093 }
1094 }
1095
1096 // extracts only of and(...) -> and(extract()...)
1097 if (narrowOperationWidth(op, true, rewriter))
1098 return success();
1099
1100 // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1101 if (auto source = getCommonOperand(op)) {
1102 auto cmpAgainst =
1103 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getAllOnes(size));
1104 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1105 source, cmpAgainst);
1106 return success();
1107 }
1108
1109 /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1110 return failure();
1111}
1112
1113OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1114 if (isOpTriviallyRecursive(*this))
1115 return {};
1116
1117 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1118 auto inputs = adaptor.getInputs();
1119 // or(x, 10, 01) -> 11
1120 for (auto operand : inputs) {
1121 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1122 if (!attr)
1123 continue;
1124 value |= attr.getValue();
1125 if (value.isAllOnes())
1126 return getIntAttr(value, getContext());
1127 }
1128
1129 // or(x, 0) -> x
1130 if (inputs.size() == 2)
1131 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1132 if (intAttr.getValue().isZero())
1133 return getInputs()[0];
1134
1135 // or(x, x, x) -> x. This also handles or(x) -> x
1136 if (llvm::all_of(getInputs(),
1137 [&](auto in) { return in == this->getInputs()[0]; }))
1138 return getInputs()[0];
1139
1140 // or(..., x, ..., ~x, ...) -> -1
1141 for (Value arg : getInputs()) {
1142 Value subExpr;
1143 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1144 for (Value arg2 : getInputs())
1145 if (arg2 == subExpr)
1146 return getIntAttr(
1147 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1148 getContext());
1149 }
1150 }
1151
1152 // x0 = icmp(pred, x, y)
1153 // x1 = icmp(!pred, x, y)
1154 // or(x0, x1) -> 1
1155 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1156 return getIntAttr(
1157 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1158 getContext());
1159
1160 // Constant fold
1161 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1162}
1163
1164LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1165 if (isOpTriviallyRecursive(op))
1166 return failure();
1167
1168 auto inputs = op.getInputs();
1169 auto size = inputs.size();
1170
1171 // or(x, or(...)) -> or(x, ...) -- flatten
1172 if (tryFlatteningOperands(op, rewriter))
1173 return success();
1174
1175 // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1176 // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1177 // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1178 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1179 return success();
1180
1181 assert(size > 1 && "expected 2 or more operands");
1182
1183 // Patterns for or with a constant on RHS.
1184 APInt value;
1185 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1186 // or(..., '0) -> or(...) -- identity
1187 if (value.isZero()) {
1188 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1189 inputs.drop_back());
1190 return success();
1191 }
1192
1193 // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1194 APInt value2;
1195 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1196 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value | value2);
1197 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1198 newOperands.push_back(cst);
1199 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1200 newOperands);
1201 return success();
1202 }
1203
1204 // or(concat(x, cst1), a, b, c, cst2)
1205 // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1206 // We do this even for multi-use concats since they are "just wiring".
1207 for (size_t i = 0; i < size - 1; ++i) {
1208 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1209 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1210 return success();
1211 }
1212 }
1213
1214 // extracts only of or(...) -> or(extract()...)
1215 if (narrowOperationWidth(op, true, rewriter))
1216 return success();
1217
1218 // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1219 if (auto source = getCommonOperand(op)) {
1220 auto cmpAgainst =
1221 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(size));
1222 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1223 source, cmpAgainst);
1224 return success();
1225 }
1226
1227 // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1228 // .., c_n), a, 0)
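  // Illustrative example (assumed IR): with all muxes two-state and sharing
  // the same %a / zero arms,
  //   %m1 = comb.mux bin %c1, %a, %c0_i8 : i8
  //   %m2 = comb.mux bin %c2, %a, %c0_i8 : i8
  //   %r  = comb.or bin %m1, %m2 : i8
  // becomes
  //   %cond = comb.or bin %c1, %c2 : i1
  //   %r    = comb.mux bin %cond, %a, %c0_i8 : i8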
1229 if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1230 APInt value;
1231 if (op.getTwoState() && firstMux.getTwoState() &&
1232 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1233 value.isZero()) {
1234 SmallVector<Value> conditions{firstMux.getCond()};
1235 auto check = [&](Value v) {
1236 auto mux = v.getDefiningOp<comb::MuxOp>();
1237 if (!mux)
1238 return false;
1239 conditions.push_back(mux.getCond());
1240 return mux.getTwoState() &&
1241 firstMux.getTrueValue() == mux.getTrueValue() &&
1242 firstMux.getFalseValue() == mux.getFalseValue();
1243 };
1244 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1245 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions, true);
1246 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1247 rewriter, op, cond, firstMux.getTrueValue(),
1248 firstMux.getFalseValue(), true);
1249 return success();
1250 }
1251 }
1252 }
1253
1254 /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1255 return failure();
1256}
1257
1258OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1259 if (isOpTriviallyRecursive(*this))
1260 return {};
1261
1262 auto size = getInputs().size();
1263 auto inputs = adaptor.getInputs();
1264
1265 // xor(x) -> x -- noop
1266 if (size == 1)
1267 return getInputs()[0];
1268
1269 // xor(x, x) -> 0 -- idempotent
1270 if (size == 2 && getInputs()[0] == getInputs()[1])
1271 return IntegerAttr::get(getType(), 0);
1272
1273 // xor(x, 0) -> x
1274 if (inputs.size() == 2)
1275 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1276 if (intAttr.getValue().isZero())
1277 return getInputs()[0];
1278
1279 // xor(xor(x,1),1) -> x
1280 // but not self loop
1281 Value subExpr;
1282 if (matchPattern(getResult(), m_Complement(m_Complement(m_Any(&subExpr)))) &&
1283 subExpr != getResult())
1284 return subExpr;
1285
1286 // Constant fold
1287 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1288}
1289
1290// xor(icmp, a, b, 1) -> xor(!icmp, a, b) if icmp has one user.
1291static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1292 PatternRewriter &rewriter) {
1293 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1294 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1295
1296 Value result =
1297 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1298 icmp.getOperand(1), icmp.getTwoState());
1299
1300 // If the xor had other operands, rebuild it.
1301 if (op.getNumOperands() > 2) {
1302 SmallVector<Value, 4> newOperands(op.getOperands());
1303 newOperands.pop_back();
1304 newOperands.erase(newOperands.begin() + icmpOperand);
1305 newOperands.push_back(result);
1306 result =
1307 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1308 }
1309
1310 replaceOpAndCopyNamehint(rewriter, op, result);
1311}
1312
1313LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1314 if (isOpTriviallyRecursive(op))
1315 return failure();
1316
1317 auto inputs = op.getInputs();
1318 auto size = inputs.size();
1319 assert(size > 1 && "expected 2 or more operands");
1320
1321 // xor(..., x, x) -> xor (...) -- idempotent
1322 if (inputs[size - 1] == inputs[size - 2]) {
1323 assert(size > 2 &&
1324 "expected idempotent case for 2 elements handled already.");
1325 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1326 inputs.drop_back(/*n=*/2), false);
1327 return success();
1328 }
1329
1330 // Patterns for xor with a constant on RHS.
1331 APInt value;
1332 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1333 // xor(..., 0) -> xor(...) -- identity
1334 if (value.isZero()) {
1335 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1336 inputs.drop_back(), false);
1337 return success();
1338 }
1339
1340 // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1341 APInt value2;
1342 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1343 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value ^ value2);
1344 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1345 newOperands.push_back(cst);
1346 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1347 newOperands, false);
1348 return success();
1349 }
1350
1351 bool isSingleBit = value.getBitWidth() == 1;
1352
1353 // Check for subexpressions that we can simplify.
1354 for (size_t i = 0; i < size - 1; ++i) {
1355 Value operand = inputs[i];
1356
1357 // xor(concat(x, cst1), a, b, c, cst2)
1358 // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1359 // We do this even for multi-use concats since they are "just
1360 // wiring".
1361 if (auto concat = operand.getDefiningOp<ConcatOp>())
1362 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1363 return success();
1364
1365 // xor(icmp, a, b, 1) -> xor(!icmp, a, b) if icmp has one user.
1366 if (isSingleBit && operand.hasOneUse()) {
1367 assert(value == 1 && "single bit constant has to be one if not zero");
1368 if (auto icmp = operand.getDefiningOp<ICmpOp>())
1369 return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1370 }
1371 }
1372 }
1373
1374 // xor(sext(x), -1) -> sext(xor(x,-1))
1375 // More concisely: ~sext(x) = sext(~x)
1376 Value base;
1377 // Check for sext of the inverted value
1378 if (matchPattern(op.getResult(), m_Complement(m_Sext(m_Any(&base))))) {
1379 // Create negated sext: ~sext(x) = sext(~x)
1380 auto negBase = createOrFoldNot(op.getLoc(), base, rewriter, true);
1381 auto sextNegBase =
1382 createOrFoldSExt(op.getLoc(), negBase, op.getType(), rewriter);
1383 replaceOpAndCopyNamehint(rewriter, op, sextNegBase);
1384 return success();
1385 }
1386
1387 // xor(x, xor(...)) -> xor(x, ...) -- flatten
1388 if (tryFlatteningOperands(op, rewriter))
1389 return success();
1390
1391 // extracts only of xor(...) -> xor(extract()...)
1392 if (narrowOperationWidth(op, true, rewriter))
1393 return success();
1394
1395 // xor(a[0], a[1], ..., a[n]) -> parity(a)
1396 if (auto source = getCommonOperand(op)) {
1397 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1398 return success();
1399 }
1400
1401 return failure();
1402}
1403
1404OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1405 if (isOpTriviallyRecursive(*this))
1406 return {};
1407
1408 // sub(x - x) -> 0
1409 if (getRhs() == getLhs())
1410 return getIntAttr(
1411 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1412 getContext());
1413
1414 if (adaptor.getRhs()) {
1415 // If both are constants, we can unconditionally fold.
1416 if (adaptor.getLhs()) {
1417 // Constant fold (c1 - c2) => (c1 + -1*c2).
1418 auto negOne = getIntAttr(
1419 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1420 getContext());
1421 auto rhsNeg = hw::ParamExprAttr::get(
1422 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1423 return hw::ParamExprAttr::get(hw::PEO::Add,
1424 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1425 }
1426
1427 // sub(x - 0) -> x
1428 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1429 if (rhsC.getValue().isZero())
1430 return getLhs();
1431 }
1432 }
1433
1434 return {};
1435}
1436
1437LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1438 if (isOpTriviallyRecursive(op))
1439 return failure();
1440
1441 // sub(x, cst) -> add(x, -cst)
1442 APInt value;
1443 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1444 auto negCst = hw::ConstantOp::create(rewriter, op.getLoc(), -value);
1445 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1446 false);
1447 return success();
1448 }
1449
1450 // extracts only of sub(...) -> sub(extract()...)
1451 if (narrowOperationWidth(op, false, rewriter))
1452 return success();
1453
1454 return failure();
1455}
1456
1457OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1458 if (isOpTriviallyRecursive(*this))
1459 return {};
1460
1461 auto size = getInputs().size();
1462
1463 // add(x) -> x -- noop
1464 if (size == 1u)
1465 return getInputs()[0];
1466
1467 // Constant fold constant operands.
1468 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1469}
1470
1471LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1472 if (isOpTriviallyRecursive(op))
1473 return failure();
1474
1475 auto inputs = op.getInputs();
1476 auto size = inputs.size();
1477 assert(size > 1 && "expected 2 or more operands");
1478
1479 APInt value, value2;
1480
1481 // add(..., 0) -> add(...) -- identity
1482 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1483 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1484 inputs.drop_back(), false);
1485 return success();
1486 }
1487
1488 // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1489 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1490 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1491 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1492 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1493 newOperands.push_back(cst);
1494 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1495 newOperands, false);
1496 return success();
1497 }
1498
1499 // add(..., x, x) -> add(..., shl(x, 1))
1500 if (inputs[size - 1] == inputs[size - 2]) {
1501 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1502
1503 auto one = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), 1);
1504 auto shiftLeftOp =
1505 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one, false);
1506
1507 newOperands.push_back(shiftLeftOp);
1508 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1509 newOperands, false);
1510 return success();
1511 }
1512
1513 auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1514 // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
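  // Worked example (illustrative): with c = 2, x + (x << 2) equals x * 5
  // because (1 << 2) + 1 = 5.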
1515 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1516 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1517
1518 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1519 auto rhs =
1520 hw::ConstantOp::create(rewriter, op.getLoc(), (one << value) + one);
1521
1522 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1523 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1524
1525 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1526 newOperands.push_back(mulOp);
1527 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1528 newOperands, false);
1529 return success();
1530 }
1531
1532 auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1533 // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1534 if (mulOp && mulOp.getInputs().size() == 2 &&
1535 mulOp.getInputs()[0] == inputs[size - 2] &&
1536 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1537
1538 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1539 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + one);
1540 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1541 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1542
1543 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1544 newOperands.push_back(newMulOp);
1545 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1546 newOperands, false);
1547 return success();
1548 }
1549
1550 // add(a, add(...)) -> add(a, ...) -- flatten
1551 if (tryFlatteningOperands(op, rewriter))
1552 return success();
1553
1554 // extracts only of add(...) -> add(extract()...)
1555 if (narrowOperationWidth(op, false, rewriter))
1556 return success();
1557
1558 // add(add(x, c1), c2) -> add(x, c1 + c2)
1559 auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1560 if (addOp && addOp.getInputs().size() == 2 &&
1561 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1562 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1563
1564 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1565 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1566 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1567 /*twoState=*/op.getTwoState() && addOp.getTwoState());
1568 return success();
1569 }
1570
1571 return failure();
1572}
1573
1574OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1575 if (isOpTriviallyRecursive(*this))
1576 return {};
1577
1578 auto size = getInputs().size();
1579 auto inputs = adaptor.getInputs();
1580
1581 // mul(x) -> x -- noop
1582 if (size == 1u)
1583 return getInputs()[0];
1584
1585 auto width = cast<IntegerType>(getType()).getWidth();
1586 if (width == 0)
1587 return getIntAttr(APInt::getZero(0), getContext());
1588
1589 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1590
1591 // mul(x, 0, 1) -> 0 -- annulment
1592 for (auto operand : inputs) {
1593 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1594 if (!attr)
1595 continue;
1596 value *= attr.getValue();
1597 if (value.isZero())
1598 return getIntAttr(value, getContext());
1599 }
1600
1601 // Constant fold
1602 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1603}
1604
1605LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1606 if (isOpTriviallyRecursive(op))
1607 return failure();
1608
1609 auto inputs = op.getInputs();
1610 auto size = inputs.size();
1611 assert(size > 1 && "expected 2 or more operands");
1612
1613 APInt value, value2;
1614
1615 // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
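// For example, mul(x, 8) becomes shl(x, 3); the result is emitted as a
// single-operand mul wrapping the shl, which the mul(x) -> x fold then removes.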
1616 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1617 value.isPowerOf2()) {
1618 auto shift = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(),
1619 value.exactLogBase2());
1620 auto shlOp =
1621 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift, false);
1622
1623 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1624 ArrayRef<Value>(shlOp), false);
1625 return success();
1626 }
1627
1628 // mul(..., 1) -> mul(...) -- identity
1629 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1630 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1631 inputs.drop_back());
1632 return success();
1633 }
1634
1635 // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1636 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1637 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1638 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value * value2);
1639 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1640 newOperands.push_back(cst);
1641 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1642 newOperands);
1643 return success();
1644 }
1645
1646 // mul(a, mul(...)) -> mul(a, ...) -- flatten
1647 if (tryFlatteningOperands(op, rewriter))
1648 return success();
1649
1650 // extracts only of mul(...) -> mul(extract()...)
1651 if (narrowOperationWidth(op, false, rewriter))
1652 return success();
1653
1654 return failure();
1655}
1656
1657template <class Op, bool isSigned>
1658static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1659 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1660 // divu(x, 1) -> x, divs(x, 1) -> x
1661 if (rhsValue.getValue() == 1)
1662 return op.getLhs();
1663
1664 // If the divisor is zero, do not fold for now.
1665 if (rhsValue.getValue().isZero())
1666 return {};
1667 }
1668
1669 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1670}
1671
1672OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1673 if (isOpTriviallyRecursive(*this))
1674 return {};
1675 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1676}
1677
1678OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1679 if (isOpTriviallyRecursive(*this))
1680 return {};
1681 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1682}
1683
1684template <class Op, bool isSigned>
1685static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1686 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1687 // modu(x, 1) -> 0, mods(x, 1) -> 0
1688 if (rhsValue.getValue() == 1)
1689 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1690 op.getContext());
1691
1692 // If the divisor is zero, do not fold for now.
1693 if (rhsValue.getValue().isZero())
1694 return {};
1695 }
1696
1697 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1698 // modu(0, x) -> 0, mods(0, x) -> 0
1699 if (lhsValue.getValue().isZero())
1700 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1701 op.getContext());
1702 }
1703
1704 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1705}
1706
1707OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1708 if (isOpTriviallyRecursive(*this))
1709 return {};
1710 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1711}
1712
1713OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1714 if (isOpTriviallyRecursive(*this))
1715 return {};
1716 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1717}
1718
1719LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) {
1720 if (isOpTriviallyRecursive(op) || !op.getTwoState())
1721 return failure();
1722 return convertDivUByPowerOfTwo(op, rewriter);
1723}
1724
1725LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) {
1726 if (isOpTriviallyRecursive(op) || !op.getTwoState())
1727 return failure();
1728
1729 return convertModUByPowerOfTwo(op, rewriter);
1730}
1731
1732//===----------------------------------------------------------------------===//
1733// ConcatOp
1734//===----------------------------------------------------------------------===//
1735
1736// Constant folding
1737OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1738 if (isOpTriviallyRecursive(*this))
1739 return {};
1740
1741 if (getNumOperands() == 1)
1742 return getOperand(0);
1743
1744 // If all the operands are constant, we can fold.
1745 for (auto attr : adaptor.getInputs())
1746 if (!attr || !isa<IntegerAttr>(attr))
1747 return {};
1748
1749 // If we got here, we can constant fold.
1750 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1751 APInt result(resultWidth, 0);
1752
1753 unsigned nextInsertion = resultWidth;
1754 // Insert each chunk into the result.
1755 for (auto attr : adaptor.getInputs()) {
1756 auto chunk = cast<IntegerAttr>(attr).getValue();
1757 nextInsertion -= chunk.getBitWidth();
1758 result.insertBits(chunk, nextInsertion);
1759 }
1760
1761 return getIntAttr(result, getContext());
1762}
1763
1764LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1765 if (isOpTriviallyRecursive(op))
1766 return failure();
1767
1768 auto inputs = op.getInputs();
1769 auto size = inputs.size();
1770 assert(size > 1 && "expected 2 or more operands");
1771
1772 // This function is used when we flatten neighboring operands of a
1773 // (variadic) concat into a new version of the concat. The first/last
1774 // indices are inclusive.
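// For example, with inputs (a, b, c, d), flattenConcat(1, 2, {x}) rebuilds
// the op as concat(a, x, d).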
1775 auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1776 ValueRange replacements) -> LogicalResult {
1777 SmallVector<Value, 4> newOperands;
1778 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1779 newOperands.append(replacements.begin(), replacements.end());
1780 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1781 if (newOperands.size() == 1)
1782 replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
1783 else
1784 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1785 newOperands);
1786 return success();
1787 };
1788
1789 Value commonOperand = inputs[0];
1790 for (size_t i = 0; i != size; ++i) {
1791 // Check to see if all operands are the same.
1792 if (inputs[i] != commonOperand)
1793 commonOperand = Value();
1794
1795 // If an operand to the concat is itself a concat, then we can fold them
1796 // together.
1797 if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1798 return flattenConcat(i, i, subConcat->getOperands());
1799
1800 // Check for canonicalization due to neighboring operands.
1801 if (i != 0) {
1802 // Merge neighboring constants.
1803 if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1804 if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1805 unsigned prevWidth = prevCst.getValue().getBitWidth();
1806 unsigned thisWidth = cst.getValue().getBitWidth();
1807 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1808 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1809 << thisWidth;
1810 Value replacement =
1811 hw::ConstantOp::create(rewriter, op.getLoc(), resultCst);
1812 return flattenConcat(i - 1, i, replacement);
1813 }
1814 }
1815
1816 // If the two operands are the same, turn them into a replicate.
1817 if (inputs[i] == inputs[i - 1]) {
1818 Value replacement =
1819 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1820 return flattenConcat(i - 1, i, replacement);
1821 }
1822
1823 // If this input is a replicate, see if we can fold it with the previous
1824 // one.
1825 if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1826 // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1827 if (repl.getOperand() == inputs[i - 1]) {
1828 Value replacement = rewriter.createOrFold<ReplicateOp>(
1829 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1830 return flattenConcat(i - 1, i, replacement);
1831 }
1832 // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1833 if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1834 if (prevRepl.getOperand() == repl.getOperand()) {
1835 Value replacement = rewriter.createOrFold<ReplicateOp>(
1836 op.getLoc(), repl.getOperand(),
1837 repl.getMultiple() + prevRepl.getMultiple());
1838 return flattenConcat(i - 1, i, replacement);
1839 }
1840 }
1841 }
1842
1843 // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1844 if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1845 if (repl.getOperand() == inputs[i]) {
1846 Value replacement = rewriter.createOrFold<ReplicateOp>(
1847 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1848 return flattenConcat(i - 1, i, replacement);
1849 }
1850 }
1851
1852 // Merge neighboring extracts of neighboring inputs, e.g.
1853 // {A[3], A[2]} -> A[3:2]
1854 if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1855 if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1856 if (extract.getInput() == prevExtract.getInput()) {
1857 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1858 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1859 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1860 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1861 Value replacement =
1862 ExtractOp::create(rewriter, op.getLoc(), resType,
1863 extract.getInput(), extract.getLowBit());
1864 return flattenConcat(i - 1, i, replacement);
1865 }
1866 }
1867 }
1868 }
1869 // Merge neighboring array extracts of neighboring inputs, e.g.
1870 // {Array[4], bitcast(Array[3:2])} -> bitcast(Array[4:2])
1871
1872 // This represents a slice of an array.
1873 struct ArraySlice {
1874 Value input;
1875 Value index;
1876 size_t width;
1877 static std::optional<ArraySlice> get(Value value) {
1878 assert(isa<IntegerType>(value.getType()) && "expected integer type");
1879 if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1880 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1881 // array slice op is wrapped with bitcast.
1882 if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1883 if (auto arraySlice =
1884 bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1885 return ArraySlice{
1886 arraySlice.getInput(), arraySlice.getLowIndex(),
1887 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1888 .getNumElements()};
1889 return std::nullopt;
1890 }
1891 };
1892 if (auto extractOpt = ArraySlice::get(inputs[i])) {
1893 if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1894 // Check that two array slices are mergeable.
1895 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1896 prevExtractOpt->input == extractOpt->input &&
1897 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1898 extractOpt->width)) {
1899 auto resType = hw::ArrayType::get(
1900 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1901 .getElementType(),
1902 extractOpt->width + prevExtractOpt->width);
1903 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1904 Value replacement = hw::BitcastOp::create(
1905 rewriter, op.getLoc(), resIntType,
1906 hw::ArraySliceOp::create(rewriter, op.getLoc(), resType,
1907 prevExtractOpt->input,
1908 extractOpt->index));
1909 return flattenConcat(i - 1, i, replacement);
1910 }
1911 }
1912 }
1913 }
1914 }
1915
1916 // If all operands were the same, then this is a replicate.
1917 if (commonOperand) {
1918 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1919 commonOperand);
1920 return success();
1921 }
1922
1923 return failure();
1924}
1925
1926//===----------------------------------------------------------------------===//
1927// MuxOp
1928//===----------------------------------------------------------------------===//
1929
1930OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1931 if (isOpTriviallyRecursive(*this))
1932 return {};
1933
1934 // mux (c, b, b) -> b
1935 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1936 return getTrueValue();
1937 if (auto tv = adaptor.getTrueValue())
1938 if (tv == adaptor.getFalseValue())
1939 return tv;
1940
1941 // mux(0, a, b) -> b
1942 // mux(1, a, b) -> a
1943 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1944 if (pred.getValue().isZero() && getFalseValue() != getResult())
1945 return getFalseValue();
1946 if (pred.getValue().isOne() && getTrueValue() != getResult())
1947 return getTrueValue();
1948 }
1949
1950 // mux(cond, 1, 0) -> cond
1951 if (getCond().getType() == getTrueValue().getType())
1952 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1953 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1954 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1955 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1956 return getCond();
1957
1958 return {};
1959}
1960
1961/// Check to see if the condition to the specified mux is an equality
1962 /// comparison between `indexValue` and one or more constants. If so, invoke
1963 /// `constantFn` on each constant and return true, otherwise return false.
1964///
1965/// This is part of foldMuxChain.
1966///
1967static bool
1968getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1969 std::function<void(hw::ConstantOp)> constantFn) {
1970 // Handle `idx == 42` and `idx != 42`.
1971 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1972 // TODO: We could handle things like "x < 2" as two entries.
1973 auto requiredPredicate =
1974 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1975 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1976 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1977 constantFn(cst);
1978 return true;
1979 }
1980 }
1981 return false;
1982 }
1983
1984 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1985 if (auto orOp = cond.getDefiningOp<OrOp>()) {
1986 if (!isInverted)
1987 return false;
1988 for (auto operand : orOp.getOperands())
1989 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1990 return false;
1991 return true;
1992 }
1993
1994 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1995 if (auto andOp = cond.getDefiningOp<AndOp>()) {
1996 if (isInverted)
1997 return false;
1998 for (auto operand : andOp.getOperands())
1999 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2000 return false;
2001 return true;
2002 }
2003
2004 return false;
2005}
2006
2007/// Given a mux, check to see if the "on true" value (or "on false" value if
2008/// isFalseSide=true) is a mux tree with the same condition. This allows us
2009 /// to turn things like `mux(VAL == 0, A, mux(VAL == 1, B, C))` into
2010/// `array_get (array_create(A, B, C), VAL)` or a balanced mux tree which is far
2011/// more compact and allows synthesis tools to do more interesting
2012/// optimizations.
2013///
2014/// This returns false if we cannot form the mux tree (or do not want to) and
2015/// returns true if the mux was replaced.
2017 PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide,
2018 llvm::function_ref<MuxChainWithComparisonFoldingStyle(size_t indexWidth,
2019 size_t numEntries)>
2020 styleFn) {
2021 // Get the index value being compared. Later we check to see if it is
2022 // compared to a constant with the right predicate.
2023 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2024 if (!rootCmp)
2025 return false;
2026 Value indexValue = rootCmp.getLhs();
2027
2028 // Return the value to use if the equality match succeeds.
2029 auto getCaseValue = [&](MuxOp mux) -> Value {
2030 return mux.getOperand(1 + unsigned(!isFalseSide));
2031 };
2032
2033 // Return the value to use if the equality match fails. This is the next
2034 // mux in the sequence or the "otherwise" value.
2035 auto getTreeValue = [&](MuxOp mux) -> Value {
2036 return mux.getOperand(1 + unsigned(isFalseSide));
2037 };
2038
2039 // Start scanning the mux tree to see what we've got. Keep track of the
2040 // constant comparison value and the SSA value to use when equal to it.
2041 SmallVector<Location> locationsFound;
2042 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2043
2044 /// Extract constants and values into `valuesFound` and return true if this is
2045 /// part of the mux tree, otherwise return false.
2046 auto collectConstantValues = [&](MuxOp mux) -> bool {
2047 return getMuxChainCondConstant(
2048 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2049 valuesFound.push_back({cst, getCaseValue(mux)});
2050 locationsFound.push_back(mux.getCond().getLoc());
2051 locationsFound.push_back(mux->getLoc());
2052 });
2053 };
2054
2055 // Make sure the root is a correct comparison with a constant.
2056 if (!collectConstantValues(rootMux))
2057 return false;
2058
2059 // Make sure that we're not looking at the intermediate node in a mux tree.
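// If our single user is a mux that extends the same chain, bail out here and
// let the canonicalization of that outer mux fold the whole chain at once.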
2060 if (rootMux->hasOneUse()) {
2061 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2062 if (getTreeValue(userMux) == rootMux.getResult() &&
2063 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2064 [&](hw::ConstantOp cst) {}))
2065 return false;
2066 }
2067 }
2068
2069 // Scan up the tree linearly.
2070 auto nextTreeValue = getTreeValue(rootMux);
2071 while (1) {
2072 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2073 if (!nextMux || !nextMux->hasOneUse())
2074 break;
2075 if (!collectConstantValues(nextMux))
2076 break;
2077 nextTreeValue = getTreeValue(nextMux);
2078 }
2079
2080 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2081
2082 if (indexWidth > 20)
2083 return false; // Too big to make a table.
2084
2085 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2086 if (foldingStyle == MuxChainWithComparisonFoldingStyle::None)
2087 return false;
2088
2089 uint64_t tableSize = 1ULL << indexWidth;
2090
2091 // Ok, we're going to do the transformation, start by building the table
2092 // filled with the "otherwise" value.
2093 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2094
2095 // Fill in entries in the table from the leaf to the root of the expression.
2096 // This ensures that any duplicate matches end up with the ultimate value,
2097 // which is the one closer to the root.
2098 for (auto &elt : llvm::reverse(valuesFound)) {
2099 uint64_t idx = elt.first.getValue().getZExtValue();
2100 assert(idx < table.size() && "constant should be same bitwidth as index");
2101 table[idx] = elt.second;
2102 }
2103
2105 SmallVector<Value> bits;
2106 comb::extractBits(rewriter, indexValue, bits);
2107 auto result = constructMuxTree(rewriter, rootMux->getLoc(), bits, table,
2108 nextTreeValue);
2109 replaceOpAndCopyNamehint(rewriter, rootMux, result);
2110 return true;
2111 }
2112
2114 "unknown folding style");
2115
2116 // The hw.array_create operation has the operand list in unintuitive order
2117 // with a[0] stored as the last element, not the first.
2118 std::reverse(table.begin(), table.end());
2119
2120 // Build the array_create and the array_get.
2121 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2122 auto array = hw::ArrayCreateOp::create(rewriter, fusedLoc, table);
2123 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2124 indexValue);
2125 return true;
2126}
2127
2128/// Given a fully associative variadic operation like (a+b+c+d), break the
2129/// expression into two parts, one without the specified operand (e.g.
2130/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2131/// `tmp+c`), and return the inner expression.
2132///
2133/// NOTE: This mutates the operation in place if it only has a single user,
2134/// which assumes that user will be removed.
2135///
2136static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2137 size_t operandNo,
2138 PatternRewriter &rewriter) {
2139 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2140 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2141
2142 // If this expression already has two operands (the common case) no splitting
2143 // is necessary.
2144 if (fullyAssoc->getNumOperands() == 2)
2145 return fullyAssoc->getOperand(operandNo ^ 1);
2146
2147 // If the operation has a single use, mutate it in place.
2148 if (fullyAssoc->hasOneUse()) {
2149 rewriter.modifyOpInPlace(fullyAssoc,
2150 [&]() { fullyAssoc->eraseOperand(operandNo); });
2151 return fullyAssoc->getResult(0);
2152 }
2153
2154 // Form the new operation with the operands that remain.
2155 SmallVector<Value> operands;
2156 operands.append(fullyAssoc->getOperands().begin(),
2157 fullyAssoc->getOperands().begin() + operandNo);
2158 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2159 fullyAssoc->getOperands().end());
2160 Value opWithoutExcluded = createGenericOp(
2161 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2162 Value excluded = fullyAssoc->getOperand(operandNo);
2163
2164 Value fullResult =
2165 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2166 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2167 replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
2168 return opWithoutExcluded;
2169}
2170
2171/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2172/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2173/// is true. Return true on successful transformation, false if not.
2174///
2175/// These are various forms of "predicated ops" that can be handled with a
2176/// replicate/and combination.
2177static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2178 PatternRewriter &rewriter) {
2179 // Check to see if the operand in question is defined by an operation. If it
2180 // is a port, we can't simplify it.
2181 Operation *subExpr =
2182 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2183 if (!subExpr || subExpr->getNumOperands() < 2)
2184 return false;
2185
2186 // If this isn't an operation we can handle, don't spend energy on it.
2187 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2188 return false;
2189
2190 // Check to see if the common value occurs in the operand list for the
2191 // subexpression op. If so, then we can simplify it.
2192 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2193 size_t opNo = 0, e = subExpr->getNumOperands();
2194 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2195 ++opNo;
2196 if (opNo == e)
2197 return false;
2198
2199 // If we got a hit, then go ahead and simplify it!
2200 Value cond = op.getCond();
2201
2202 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2203 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2204 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2205 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2206 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2207 if (subMux == op)
2208 return false;
2209
2210 Value otherValue;
2211 Value subCond = subMux.getCond();
2212
2213 // Invert the subCond if needed and dig out the 'b' value.
2214 if (subMux.getTrueValue() == commonValue)
2215 otherValue = subMux.getFalseValue();
2216 else if (subMux.getFalseValue() == commonValue) {
2217 otherValue = subMux.getTrueValue();
2218 subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2219 } else {
2220 // We can't fold `mux(cond, a, mux(a, x, y))`.
2221 return false;
2222 }
2223
2224 // Invert the outer cond if needed, and combine the mux conditions.
2225 if (!isTrueOperand)
2226 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2227 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2228 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2229 otherValue, op.getTwoState());
2230 return true;
2231 }
2232
2233 // Invert the condition if needed. Or/Xor invert when dealing with
2234 // TrueOperand, And inverts for False operand.
2235 bool isaAndOp = isa<AndOp>(subExpr);
2236 if (isTrueOperand ^ isaAndOp)
2237 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2238
2239 auto extendedCond =
2240 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2241
2242 // Cache this information before subExpr is erased by extraction below.
2243 bool isaXorOp = isa<XorOp>(subExpr);
2244 bool isaOrOp = isa<OrOp>(subExpr);
2245
2246 // Handle the fully associative ops, start by pulling out the subexpression
2247 // from a many operand version of the op.
2248 auto restOfAssoc =
2249 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2250
2251 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2252 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2253 if (isaOrOp || isaXorOp) {
2254 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2255 restOfAssoc, false);
2256 if (isaXorOp)
2257 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2258 commonValue, false);
2259 else
2260 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2261 false);
2262 return true;
2263 }
2264
2265 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2266 assert(isaAndOp && "unexpected operation here");
2267 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2268 restOfAssoc, false);
2269 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2270 false);
2271 return true;
2272}
2273
2274 /// This function is invoked when we find a mux with true/false operations that
2275 /// have the same opcode. Check to see if we can strength-reduce the mux by
2276 /// applying it to less data via this transformation:
2277/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
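///
/// For concat, the only op handled so far, this means for example
///   mux(cond, concat(a, x), concat(a, y)) -> concat(a, mux(cond, x, y))
/// so the new mux operates on fewer bits.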
2278static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2279 Operation *falseOp,
2280 PatternRewriter &rewriter) {
2281 // Right now we only apply to concat.
2282 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2283 if (!isa<ConcatOp>(trueOp))
2284 return false;
2285
2286 // Decode the operands, looking through recursive concats and replicates.
2287 SmallVector<Value> trueOperands, falseOperands;
2288 getConcatOperands(trueOp->getResult(0), trueOperands);
2289 getConcatOperands(falseOp->getResult(0), falseOperands);
2290
2291 size_t numTrueOperands = trueOperands.size();
2292 size_t numFalseOperands = falseOperands.size();
2293
2294 if (!numTrueOperands || !numFalseOperands ||
2295 (trueOperands.front() != falseOperands.front() &&
2296 trueOperands.back() != falseOperands.back()))
2297 return false;
2298
2299 // Pull all leading shared operands out into their own op if any are common.
2300 if (trueOperands.front() == falseOperands.front()) {
2301 SmallVector<Value> operands;
2302 size_t i;
2303 for (i = 0; i < numTrueOperands; ++i) {
2304 Value trueOperand = trueOperands[i];
2305 if (trueOperand == falseOperands[i])
2306 operands.push_back(trueOperand);
2307 else
2308 break;
2309 }
2310 if (i == numTrueOperands) {
2311 // Selecting between distinct, but lexically identical, concats.
2312 replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
2313 return true;
2314 }
2315
2316 Value sharedMSB;
2317 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2318 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2319 mux->getLoc(), operands.front(), operands.size());
2320 else
2321 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2322 operands.clear();
2323
2324 // Get a concat of the LSB's on each side.
2325 operands.append(trueOperands.begin() + i, trueOperands.end());
2326 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2327 operands.clear();
2328 operands.append(falseOperands.begin() + i, falseOperands.end());
2329 Value falseLSB =
2330 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2331 // Merge the LSBs with a new mux, then concat the shared MSB with the new
2332 // LSB to finish.
2333 Value lsb = rewriter.createOrFold<MuxOp>(
2334 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2335 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2336 return true;
2337 }
2338
2339 // If trailing operands match, try to commonize them.
2340 if (trueOperands.back() == falseOperands.back()) {
2341 SmallVector<Value> operands;
2342 size_t i;
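// Note: the loop needs no explicit bound. Both concats have the same total
// width and their leading operands differ (that case returned above), so a
// mismatch is guaranteed before either operand list runs out.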
2343 for (i = 0;; ++i) {
2344 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2345 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2346 operands.push_back(trueOperand);
2347 else
2348 break;
2349 }
2350 std::reverse(operands.begin(), operands.end());
2351 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2352 operands.clear();
2353
2354 // Get a concat of the MSB's on each side.
2355 operands.append(trueOperands.begin(), trueOperands.end() - i);
2356 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2357 operands.clear();
2358 operands.append(falseOperands.begin(), falseOperands.end() - i);
2359 Value falseMSB =
2360 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2361 // Merge the MSBs with a new mux, then concat the new MSB with the shared LSB to finish.
2362 Value msb = rewriter.createOrFold<MuxOp>(
2363 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2364 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2365 return true;
2366 }
2367
2368 return false;
2369}
2370
2371// If both arguments of the mux are arrays with the same elements, sink the
2372// mux and return a uniform array initializing all elements to it.
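// For example, mux(c, array_create(x, x, x), array_create(y, y, y)) becomes
// array_create(m, m, m) with m = mux(c, x, y).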
2373static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2374 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2375 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2376 if (!trueVec || !falseVec)
2377 return false;
2378 if (!trueVec.isUniform() || !falseVec.isUniform())
2379 return false;
2380
2381 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2382 trueVec.getUniformElement(),
2383 falseVec.getUniformElement(), op.getTwoState());
2384
2385 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2386 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2387 return true;
2388}
2389
2390/// If the mux condition is an operand to the op defining its true or false
2391/// value, replace the condition with 1 or 0.
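///
/// For example, in mux(c, and(c, x), y) the `c` feeding the and must be 1
/// whenever that branch is selected, so it is rewritten to mux(c, and(1, x), y).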
2392static bool assumeMuxCondInOperand(Value muxCond, Value muxValue,
2393 bool constCond, PatternRewriter &rewriter) {
2394 if (!muxValue.hasOneUse())
2395 return false;
2396 auto *op = muxValue.getDefiningOp();
2397 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2398 return false;
2399 if (!llvm::is_contained(op->getOperands(), muxCond))
2400 return false;
2401 OpBuilder::InsertionGuard guard(rewriter);
2402 rewriter.setInsertionPoint(op);
2403 auto condValue =
2404 hw::ConstantOp::create(rewriter, muxCond.getLoc(), APInt(1, constCond));
2405 rewriter.modifyOpInPlace(op, [&] {
2406 for (auto &use : op->getOpOperands())
2407 if (use.get() == muxCond)
2408 use.set(condValue);
2409 });
2410 return true;
2411}
2412
2413namespace {
2414struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2415 using OpRewritePattern::OpRewritePattern;
2416
2417 LogicalResult matchAndRewrite(MuxOp op,
2418 PatternRewriter &rewriter) const override;
2419};
2420
2421static MuxChainWithComparisonFoldingStyle
2422 foldToArrayCreateOnlyWhenDense(size_t indexWidth, size_t numEntries) {
2423 // If the index is 9 or more bits, the table will take 512 or more elements
2424 // and will be too large for a single expression.
2425 if (indexWidth >= 9 || numEntries < 3)
2426 return MuxChainWithComparisonFoldingStyle::None;
2427
2428 // Next we need to see if the values are dense-ish. We don't want to have
2429 // a tremendous number of replicated entries in the array. Some sparsity is
2430 // ok though, so we require the table to be at least 5/8 utilized.
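// For example, with a 4-bit index the table has 16 slots, so at least 10 of
// them must come from the mux chain for it to count as dense.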
2431 uint64_t tableSize = 1ULL << indexWidth;
2432 if (numEntries >= tableSize * 5 / 8)
2435}
2436
2437LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2438 PatternRewriter &rewriter) const {
2439 if (isOpTriviallyRecursive(op))
2440 return failure();
2441
2442 bool isSignlessInt = false;
2443 if (auto intType = dyn_cast<IntegerType>(op.getType()))
2444 isSignlessInt = intType.isSignless();
2445
2446 // If the op has a SV attribute, don't optimize it.
2447 if (hasSVAttributes(op))
2448 return failure();
2449 APInt value;
2450
2451 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2452 if (value.getBitWidth() == 1) {
2453 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2454 if (value.isZero()) {
2455 auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2456 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2457 op.getFalseValue(), false);
2458 return success();
2459 }
2460
2461 // mux(a, 1, b) -> or(a, b) for single-bit values.
2462 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2463 op.getFalseValue(), false);
2464 return success();
2465 }
2466
2467 // Check for mux of two constants. There are many ways to simplify them.
2468 APInt value2;
2469 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2470 // When both inputs are constants and differ by only one bit, we can
2471 // simplify by splitting the mux into up to three contiguous chunks: one
2472 // for the differing bit and up to two for the bits that are the same.
2473 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2474 APInt xorValue = value ^ value2;
2475 if (xorValue.isPowerOf2()) {
2476 unsigned leadingZeros = xorValue.countLeadingZeros();
2477 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2478 SmallVector<Value, 3> operands;
2479
2480 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2481 // order of bit indexes.
2482 // For the chunks that are identical (i.e. correspond to 0s in
2483 // xorValue), we can extract directly from either input value, and we
2484 // arbitrarily pick the trueValue().
2485
2486 if (leadingZeros > 0)
2487 operands.push_back(rewriter.createOrFold<ExtractOp>(
2488 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2489
2490 // Handle the differing bit, which should simplify into either cond or
2491 // ~cond.
2492 auto v1 = rewriter.createOrFold<ExtractOp>(
2493 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2494 auto v2 = rewriter.createOrFold<ExtractOp>(
2495 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2496 operands.push_back(rewriter.createOrFold<MuxOp>(
2497 op.getLoc(), op.getCond(), v1, v2, false));
2498
2499 if (trailingZeros > 0)
2500 operands.push_back(rewriter.createOrFold<ExtractOp>(
2501 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2502
2503 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2504 operands);
2505 return success();
2506 }
2507
2508 // If the true value is all ones and the false is all zeros then we have a
2509 // replicate pattern.
2510 if (value.isAllOnes() && value2.isZero()) {
2511 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2512 rewriter, op, op.getType(), op.getCond());
2513 return success();
2514 }
2515 }
2516 }
2517
2518 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2519 isSignlessInt && value.getBitWidth() == 1) {
2520 // mux(a, b, 0) -> and(a, b) for single-bit values.
2521 if (value.isZero()) {
2522 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2523 op.getTrueValue(), false);
2524 return success();
2525 }
2526
2527 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2528 // falseValue() is known to be a single-bit 1, which we can use for
2529 // the 1 in the representation of ~ using xor.
2530 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2531 op.getFalseValue(), false);
2532 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2533 op.getTrueValue(), false);
2534 return success();
2535 }
2536
2537 // mux(!a, b, c) -> mux(a, c, b)
2538 Value subExpr;
2539 Operation *condOp = op.getCond().getDefiningOp();
2540 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2541 op.getTwoState()) {
2542 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2543 subExpr, op.getFalseValue(),
2544 op.getTrueValue(), true);
2545 return success();
2546 }
2547
2548 // Same but with De Morgan's law.
2549 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2550 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2551 if (condOp && condOp->hasOneUse()) {
2552 SmallVector<Value> invertedOperands;
2553
2554 /// Scan all the operands to see if they are complemented. If so, build a
2555 /// vector of them and return true, otherwise return false.
2556 auto getInvertedOperands = [&]() -> bool {
2557 for (Value operand : condOp->getOperands()) {
2558 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2559 invertedOperands.push_back(subExpr);
2560 else
2561 return false;
2562 }
2563 return true;
2564 };
2565
2566 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2567 auto newOr =
2568 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2569 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2570 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2571 op.getTwoState());
2572 return success();
2573 }
2574 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2575 auto newAnd =
2576 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2577 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2578 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2579 op.getTwoState());
2580 return success();
2581 }
2582 }
2583
2584 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2585 falseMux && falseMux != op) {
2586 // mux(selector, x, mux(selector, y, z)) = mux(selector, x, z)
2587 if (op.getCond() == falseMux.getCond() &&
2588 falseMux.getFalseValue() != falseMux) {
2589 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2590 rewriter, op, op.getCond(), op.getTrueValue(),
2591 falseMux.getFalseValue(), op.getTwoStateAttr());
2592 return success();
2593 }
2594
2595 // Check to see if we can fold a mux tree into an array_create/get pair.
2596 if (foldMuxChainWithComparison(rewriter, op, /*isFalse*/ true,
2597 foldToArrayCreateOnlyWhenDense))
2598 return success();
2599 }
2600
2601 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2602 trueMux && trueMux != op) {
2603 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2604 if (op.getCond() == trueMux.getCond()) {
2605 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2606 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2607 op.getFalseValue(), op.getTwoStateAttr());
2608 return success();
2609 }
2610
2611 // Check to see if we can fold a mux tree into an array_create/get pair.
2612 if (foldMuxChainWithComparison(rewriter, op, /*isFalseSide*/ false,
2613 foldToArrayCreateOnlyWhenDense))
2614 return success();
2615 }
2616
2617 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2618 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2619 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2620 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2621 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2622 falseMux != op) {
2623 auto subMux = MuxOp::create(
2624 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2625 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2626 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2627 trueMux.getTrueValue(), subMux,
2628 op.getTwoStateAttr());
2629 return success();
2630 }
2631
2632 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2633 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2634 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2635 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2636 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2637 falseMux != op) {
2638 auto subMux = MuxOp::create(
2639 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2640 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2641 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2642 subMux, trueMux.getFalseValue(),
2643 op.getTwoStateAttr());
2644 return success();
2645 }
2646
2647 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2648 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2649 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2650 trueMux && falseMux &&
2651 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2652 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2653 falseMux != op) {
2654 auto subMux =
2655 MuxOp::create(rewriter,
2656 rewriter.getFusedLoc(
2657 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2658 op.getCond(), trueMux.getCond(), falseMux.getCond());
2659 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2660 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2661 op.getTwoStateAttr());
2662 return success();
2663 }
2664
2665 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2666 if (foldCommonMuxValue(op, false, rewriter))
2667 return success();
2668 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2669 if (foldCommonMuxValue(op, true, rewriter))
2670 return success();
2671
2672 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2673 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2674 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2675 if (trueOp->getName() == falseOp->getName())
2676 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2677 return success();
2678
2679 // extracts only of mux(...) -> mux(extract()...)
2680 if (narrowOperationWidth(op, true, rewriter))
2681 return success();
2682
2683 // mux(cond, array(a1, ..., a1), array(a2, ..., a2)) -> array(mux(cond, a1, a2), ...)
2684 if (foldMuxOfUniformArrays(op, rewriter))
2685 return success();
2686
2687 // mux(cond, opA(cond), opB(cond)) -> mux(cond, opA(1), opB(0))
2688 if (op.getTrueValue().getDefiningOp() &&
2689 op.getTrueValue().getDefiningOp() != op)
2690 if (assumeMuxCondInOperand(op.getCond(), op.getTrueValue(), true, rewriter))
2691 return success();
2692 if (op.getFalseValue().getDefiningOp() &&
2693 op.getFalseValue().getDefiningOp() != op)
2694
2695 if (assumeMuxCondInOperand(op.getCond(), op.getFalseValue(), false,
2696 rewriter))
2697 return success();
2698
2699 return failure();
2700}
2701
2702static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2703 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2704 if (op.getInputs().empty() || op.isUniform())
2705 return false;
2706 auto inputs = op.getInputs();
2707 if (inputs.size() <= 1)
2708 return false;
2709
2710 // Check the operands to the array create. Ensure all of them are the
2711 // same op with the same number of operands.
2712 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2713 if (!first || hasSVAttributes(first))
2714 return false;
2715
2716 // Check whether all operands are muxes with the same condition.
2717 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2718 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2719 if (!input || first.getCond() != input.getCond())
2720 return false;
2721 }
2722
2723 // Collect the true and the false branches into arrays.
2724 SmallVector<Value> trues{first.getTrueValue()};
2725 SmallVector<Value> falses{first.getFalseValue()};
2726 SmallVector<Location> locs{first->getLoc()};
2727 bool isTwoState = true;
2728 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2729 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2730 trues.push_back(input.getTrueValue());
2731 falses.push_back(input.getFalseValue());
2732 locs.push_back(input->getLoc());
2733 if (!input.getTwoState())
2734 isTwoState = false;
2735 }
2736
2737 // Define the location of the array create as the aggregate of all muxes.
2738 auto loc = FusedLoc::get(op.getContext(), locs);
2739
2740 // Replace the create with an aggregate operation. Push the create op
2741 // into the operands of the aggregate operation.
2742 auto arrayTy = op.getType();
2743 auto trueValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, trues);
2744 auto falseValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, falses);
2745 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2746 trueValues, falseValues, isTwoState);
2747 return true;
2748}
2749
2750struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2751 using OpRewritePattern::OpRewritePattern;
2752
2753 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2754 PatternRewriter &rewriter) const override {
2755 if (foldArrayOfMuxes(op, rewriter))
2756 return success();
2757 return failure();
2758 }
2759};
2760
2761} // namespace
2762
2763void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2764 MLIRContext *context) {
2765 results.insert<MuxRewriter, ArrayRewriter>(context);
2766}
2767
2768//===----------------------------------------------------------------------===//
2769// ICmpOp
2770//===----------------------------------------------------------------------===//
2771
2772// Calculate the result of a comparison when the LHS and RHS are both
2773// constants.
2774static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2775 const APInt &rhs) {
2776 switch (predicate) {
2777 case ICmpPredicate::eq:
2778 return lhs.eq(rhs);
2779 case ICmpPredicate::ne:
2780 return lhs.ne(rhs);
2781 case ICmpPredicate::slt:
2782 return lhs.slt(rhs);
2783 case ICmpPredicate::sle:
2784 return lhs.sle(rhs);
2785 case ICmpPredicate::sgt:
2786 return lhs.sgt(rhs);
2787 case ICmpPredicate::sge:
2788 return lhs.sge(rhs);
2789 case ICmpPredicate::ult:
2790 return lhs.ult(rhs);
2791 case ICmpPredicate::ule:
2792 return lhs.ule(rhs);
2793 case ICmpPredicate::ugt:
2794 return lhs.ugt(rhs);
2795 case ICmpPredicate::uge:
2796 return lhs.uge(rhs);
2797 case ICmpPredicate::ceq:
2798 return lhs.eq(rhs);
2799 case ICmpPredicate::cne:
2800 return lhs.ne(rhs);
2801 case ICmpPredicate::weq:
2802 return lhs.eq(rhs);
2803 case ICmpPredicate::wne:
2804 return lhs.ne(rhs);
2805 }
2806 llvm_unreachable("unknown comparison predicate");
2807}
2808
2809// Returns the result of applying the predicate when the LHS and RHS are the
2810// exact same value.
2811static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2812 switch (predicate) {
2813 case ICmpPredicate::eq:
2814 case ICmpPredicate::sle:
2815 case ICmpPredicate::sge:
2816 case ICmpPredicate::ule:
2817 case ICmpPredicate::uge:
2818 case ICmpPredicate::ceq:
2819 case ICmpPredicate::weq:
2820 return true;
2821 case ICmpPredicate::ne:
2822 case ICmpPredicate::slt:
2823 case ICmpPredicate::sgt:
2824 case ICmpPredicate::ult:
2825 case ICmpPredicate::ugt:
2826 case ICmpPredicate::cne:
2827 case ICmpPredicate::wne:
2828 return false;
2829 }
2830 llvm_unreachable("unknown comparison predicate");
2831}
2832
2833OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2834 // gt a, a -> false
2835 // gte a, a -> true
2836 if (getLhs() == getRhs()) {
2837 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2838 return IntegerAttr::get(getType(), val);
2839 }
2840
2841 // gt 1, 2 -> false
2842 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2843 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2844 auto val =
2845 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2846 return IntegerAttr::get(getType(), val);
2847 }
2848 }
2849 return {};
2850}
2851
2852 // Given two ranges of operands, computes the number of matching leading
2853 // elements; pass reversed ranges to count a common suffix. No cross-element matching is performed.
2854template <typename Range>
2855static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
2856 size_t commonPrefixLength = 0;
2857 auto ia = a.begin();
2858 auto ib = b.begin();
2859
2860 for (; ia != a.end() && ib != b.end(); ia++, ib++, commonPrefixLength++) {
2861 if (*ia != *ib) {
2862 break;
2863 }
2864 }
2865
2866 return commonPrefixLength;
2867}
2868
2869static size_t getTotalWidth(ArrayRef<Value> operands) {
2870 size_t totalWidth = 0;
2871 for (auto operand : operands) {
2872 // getIntOrFloatBitWidth should never raise, since all arguments to
2873 // ConcatOp are integers.
2874 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2875 assert(width >= 0);
2876 totalWidth += width;
2877 }
2878 return totalWidth;
2879}
2880
2881 /// Reduce the strength of icmp(concat(...), concat(...)) by doing an
2882 /// element-wise comparison on the common prefix and suffix. Returns success()
2883 /// if a rewrite happens. This handles both concat and replicate.
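///
/// For example (illustrative), icmp eq(concat(a, x, b), concat(a, y, b))
/// reduces to icmp eq(x, y); for signed predicates the shared sign bit is
/// re-prepended so the comparison keeps its signedness.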
2884static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2885 Operation *rhs,
2886 PatternRewriter &rewriter) {
2887 // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2888 // all elements have non-zero length. Both these invariants are verified
2889 // by the ConcatOp verifier.
2890 SmallVector<Value> lhsOperands, rhsOperands;
2891 getConcatOperands(lhs->getResult(0), lhsOperands);
2892 getConcatOperands(rhs->getResult(0), rhsOperands);
2893 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2894
2895 auto formCatOrReplicate = [&](Location loc,
2896 ArrayRef<Value> operands) -> Value {
2897 assert(!operands.empty());
2898 Value sameElement = operands[0];
2899 for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2900 if (sameElement != operands[i])
2901 sameElement = Value();
2902 if (sameElement)
2903 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2904 operands.size());
2905 return rewriter.createOrFold<ConcatOp>(loc, operands);
2906 };
2907
2908 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2909 Value rhs) -> LogicalResult {
2910 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2911 op.getTwoState());
2912 return success();
2913 };
2914
2915 size_t commonPrefixLength =
2916 computeCommonPrefixLength(lhsOperands, rhsOperands);
2917 if (commonPrefixLength == lhsOperands.size()) {
2918 // cat(a, b, c) == cat(a, b, c) -> 1
2919 bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2920 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2921 APInt(1, result));
2922 return success();
2923 }
2924
2925 size_t commonSuffixLength = computeCommonPrefixLength(
2926 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2927
2928 size_t commonPrefixTotalWidth =
2929 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2930 size_t commonSuffixTotalWidth =
2931 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2932 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2933 .drop_back(commonSuffixLength);
2934 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2935 .drop_back(commonSuffixLength);
2936
2937 auto replaceWithoutReplicatingSignBit = [&]() {
2938 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2939 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2940 return replaceWith(op.getPredicate(), newLhs, newRhs);
2941 };
2942
2943 auto replaceWithReplicatingSignBit = [&]() {
2944 auto firstNonEmptyValue = lhsOperands[0];
2945 auto firstNonEmptyElemWidth =
2946 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2947 Value signBit = rewriter.createOrFold<ExtractOp>(
2948 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2949
2950 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2951 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2952 return replaceWith(op.getPredicate(), newLhs, newRhs);
2953 };
2954
2955 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2956 // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2957 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2958 return replaceWithoutReplicatingSignBit();
2959
2960 // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2961 // cat(sgn(a), ..y)). Note that we cannot perform this optimization if
2962 // [width(b) == 0 && width(a) <= 1], since then the common prefix is the
2963 // sign bit itself and doing the rewrite can result in an infinite loop.
2964 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2965 return replaceWithReplicatingSignBit();
2966
2967 } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2968 // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
2969 return replaceWithoutReplicatingSignBit();
2970 }
2971
2972 return failure();
2973}
2974
2975/// Given an equality comparison with a constant value and some operand that has
2976/// known bits, simplify the comparison to check only the unknown bits of the
2977/// input.
2978///
2979/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2980/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2981/// `extract x[1:0] == 0`
2982static void combineEqualityICmpWithKnownBitsAndConstant(
2983 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2984 PatternRewriter &rewriter) {
2985
2986 // If any of the known bits disagree with any of the comparison bits, then
2987 // we can constant fold this comparison right away.
2988 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2989 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2990 // If we discover a mismatch then we know an "eq" comparison is false
2991 // and a "ne" comparison is true!
2992 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2993 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2994 APInt(1, result));
2995 return;
2996 }
2997
2998 // Check to see if we can prove the entire result of the comparison (in
2999 // which case we bail out early); otherwise build a list of values to concat
3000 // and a smaller constant to compare against.
3001 SmallVector<Value> newConcatOperands;
3002 auto newConstant = APInt::getZeroWidth();
3003
3004 // Ok, some (maybe all) bits are known and some others may be unknown.
3005 // Extract out segments of the operand and compare against the
3006 // corresponding bits.
3007 unsigned knownMSB = bitsKnown.countLeadingOnes();
3008
3009 Value operand = cmpOp.getLhs();
3010
3011 // Ok, some bits are known but others are not. Extract out sequences of
3012 // bits that are unknown and compare just those bits. We work from MSB to
3013 // LSB.
3014 while (knownMSB != bitsKnown.getBitWidth()) {
3015 // Drop any high bits that are known.
3016 if (knownMSB)
3017 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3018
3019 // Find the span of unknown bits, and extract it.
3020 unsigned unknownBits = bitsKnown.countLeadingZeros();
3021 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3022 auto spanOperand = rewriter.createOrFold<ExtractOp>(
3023 operand.getLoc(), operand, /*lowBit=*/lowBit,
3024 /*bitWidth=*/unknownBits);
3025 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3026
3027 // Add this info to the concat we're generating.
3028 newConcatOperands.push_back(spanOperand);
3029 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
3030 // newConstant = newConstant.concat(spanConstant);
3031 if (newConstant.getBitWidth() != 0)
3032 newConstant = newConstant.concat(spanConstant);
3033 else
3034 newConstant = spanConstant;
3035
3036 // Drop the unknown bits in prep for the next chunk.
3037 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3038 bitsKnown = bitsKnown.trunc(newWidth);
3039 knownMSB = bitsKnown.countLeadingOnes();
3040 }
3041
3042 // If no unknown spans remain, then every bit of the operand is known and
3043 // agrees with the constant, so the overall comparison folds to a constant
3044 // ("eq" is true, "ne" is false).
3045 if (newConcatOperands.empty()) {
3046 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3047 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3048 APInt(1, result));
3049 return;
3050 }
3051
3052 // If we have a single operand remaining, use it, otherwise form a concat.
3053 Value concatResult =
3054 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
3055
3056 // Form the comparison against the smaller constant.
3057 auto newConstantOp = hw::ConstantOp::create(
3058 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3059
3060 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3061 cmpOp.getPredicate(), concatResult,
3062 newConstantOp, cmpOp.getTwoState());
3063}
3064
3065// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
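// For illustration (hypothetical operands, chosen arbitrarily):
//   icmp eq(xor(a, b, 0x0F), 0x3C) -> icmp eq(xor(a, b), 0x33)
// since xor(a, b) ^ 0x0F == 0x3C holds exactly when xor(a, b) == 0x3C ^ 0x0F.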
3066static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
3067 const APInt &rhs,
3068 PatternRewriter &rewriter) {
3069 auto ip = rewriter.saveInsertionPoint();
3070 rewriter.setInsertionPoint(xorOp);
3071
3072 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
3073 auto newRHS = hw::ConstantOp::create(rewriter, xorRHS->getLoc(),
3074 xorRHS.getValue() ^ rhs);
3075 Value newLHS;
3076 switch (xorOp.getNumOperands()) {
3077 case 1:
3078 // This isn't common but is defined so we need to handle it.
3079 newLHS = hw::ConstantOp::create(rewriter, xorOp.getLoc(),
3080 APInt::getZero(rhs.getBitWidth()));
3081 break;
3082 case 2:
3083 // The binary case is the most common.
3084 newLHS = xorOp.getOperand(0);
3085 break;
3086 default:
3087 // The general case forces us to form a new xor with the remaining operands.
3088 SmallVector<Value> newOperands(xorOp.getOperands());
3089 newOperands.pop_back();
3090 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands, false);
3091 break;
3092 }
3093
3094 bool xorMultipleUses = !xorOp->hasOneUse();
3095
3096 // If the xor has multiple uses (not just the compare), then we need to
3097 // replace them as well.
3098 if (xorMultipleUses)
3099 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3100 false);
3101
3102 // Replace the comparison.
3103 rewriter.restoreInsertionPoint(ip);
3104 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3105 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
3106}
3107
3108LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3109 if (isOpTriviallyRecursive(op))
3110 return failure();
3111 APInt lhs, rhs;
3112
3113 // icmp 1, x -> icmp x, 1
3114 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3115 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3116 "Should be folded");
3117 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3118 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3119 op.getRhs(), op.getLhs(), op.getTwoState());
3120 return success();
3121 }
3122
3123 // Canonicalize with RHS constant
3124 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3125 auto getConstant = [&](APInt constant) -> Value {
3126 return hw::ConstantOp::create(rewriter, op.getLoc(), std::move(constant));
3127 };
3128
3129 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3130 Value rhs) -> LogicalResult {
3131 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3132 rhs, op.getTwoState());
3133 return success();
3134 };
3135
3136 auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3137 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3138 APInt(1, constant));
3139 return success();
3140 };
3141
3142 switch (op.getPredicate()) {
3143 case ICmpPredicate::slt:
3144 // x < max -> x != max
3145 if (rhs.isMaxSignedValue())
3146 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3147 // x < min -> false
3148 if (rhs.isMinSignedValue())
3149 return replaceWithConstantI1(0);
3150 // x < min+1 -> x == min
3151 if ((rhs - 1).isMinSignedValue())
3152 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3153 getConstant(rhs - 1));
3154 break;
3155 case ICmpPredicate::sgt:
3156 // x > min -> x != min
3157 if (rhs.isMinSignedValue())
3158 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3159 // x > max -> false
3160 if (rhs.isMaxSignedValue())
3161 return replaceWithConstantI1(0);
3162 // x > max-1 -> x == max
3163 if ((rhs + 1).isMaxSignedValue())
3164 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3165 getConstant(rhs + 1));
3166 break;
3167 case ICmpPredicate::ult:
3168 // x < max -> x != max
3169 if (rhs.isAllOnes())
3170 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3171 // x < min -> false
3172 if (rhs.isZero())
3173 return replaceWithConstantI1(0);
3174 // x < min+1 -> x == min
3175 if ((rhs - 1).isZero())
3176 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3177 getConstant(rhs - 1));
3178
3179 // x < 0xE0 -> extract(x, 5..7) != 0b111
3180 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3181 rhs.getBitWidth()) {
3182 auto numOnes = rhs.countLeadingOnes();
3183 auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
3184 rhs.getBitWidth() - numOnes, numOnes);
3185 return replaceWith(ICmpPredicate::ne, smaller,
3186 getConstant(APInt::getAllOnes(numOnes)));
3187 }
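// (For illustration: the guard above matches constants of the form
// 1...10...0, e.g. 0xE0 = 0b11100000; x >= 0xE0 exactly when
// x[7:5] == 0b111, so x < 0xE0 becomes extract(x, 5..7) != 0b111.)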
3188
3189 break;
3190 case ICmpPredicate::ugt:
3191 // x > min -> x != min
3192 if (rhs.isZero())
3193 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3194 // x > max -> false
3195 if (rhs.isAllOnes())
3196 return replaceWithConstantI1(0);
3197 // x > max-1 -> x == max
3198 if ((rhs + 1).isAllOnes())
3199 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3200 getConstant(rhs + 1));
3201
3202 // x > 0x07 -> extract(x, 3..7) != 0b00000
3203 if ((rhs + 1).isPowerOf2()) {
3204 auto numOnes = rhs.countTrailingOnes();
3205 auto newWidth = rhs.getBitWidth() - numOnes;
3206 auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
3207 numOnes, newWidth);
3208 return replaceWith(ICmpPredicate::ne, smaller,
3209 getConstant(APInt::getZero(newWidth)));
3210 }
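// (For illustration: the guard above matches constants of the form
// 0...01...1, e.g. 0x07; x > 0x07 exactly when some bit above bit 2 is set,
// i.e. extract(x, 3..7) != 0b00000.)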
3211
3212 break;
3213 case ICmpPredicate::sle:
3214 // x <= max -> true
3215 if (rhs.isMaxSignedValue())
3216 return replaceWithConstantI1(1);
3217 // x <= c -> x < (c+1)
3218 return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3219 case ICmpPredicate::sge:
3220 // x >= min -> true
3221 if (rhs.isMinSignedValue())
3222 return replaceWithConstantI1(1);
3223 // x >= c -> x > (c-1)
3224 return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3225 case ICmpPredicate::ule:
3226 // x <= max -> true
3227 if (rhs.isAllOnes())
3228 return replaceWithConstantI1(1);
3229 // x <= c -> x < (c+1)
3230 return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3231 case ICmpPredicate::uge:
3232 // x >= min -> true
3233 if (rhs.isZero())
3234 return replaceWithConstantI1(1);
3235 // x >= c -> x > (c-1)
3236 return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3237 case ICmpPredicate::eq:
3238 if (rhs.getBitWidth() == 1) {
3239 if (rhs.isZero()) {
3240 // x == 0 -> x ^ 1
3241 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3242 getConstant(APInt(1, 1)),
3243 op.getTwoState());
3244 return success();
3245 }
3246 if (rhs.isAllOnes()) {
3247 // x == 1 -> x
3248 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3249 return success();
3250 }
3251 }
3252 break;
3253 case ICmpPredicate::ne:
3254 if (rhs.getBitWidth() == 1) {
3255 if (rhs.isZero()) {
3256 // x != 0 -> x
3257 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3258 return success();
3259 }
3260 if (rhs.isAllOnes()) {
3261 // x != 1 -> x ^ 1
3262 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3263 getConstant(APInt(1, 1)),
3264 op.getTwoState());
3265 return success();
3266 }
3267 }
3268 break;
3269 case ICmpPredicate::ceq:
3270 case ICmpPredicate::cne:
3271 case ICmpPredicate::weq:
3272 case ICmpPredicate::wne:
3273 break;
3274 }
3275
3276 // We have some specific optimizations for comparison with a constant that
3277 // are only supported for equality comparisons.
3278 if (op.getPredicate() == ICmpPredicate::eq ||
3279 op.getPredicate() == ICmpPredicate::ne) {
3280 // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3281 // with a smaller constant. We only support equality comparisons for
3282 // this.
3283 auto knownBits = computeKnownBits(op.getLhs());
3284 if (!knownBits.isUnknown())
3285 return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3286 rewriter),
3287 success();
3288
3289 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3290 // cst1^cst2).
3291 if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3292 if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3293 return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3294 success();
3295
3296 // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3297 // all one.
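// For illustration (hypothetical 2-bit input, chosen arbitrarily):
//   icmp eq(replicate(v, 3), 6'b111111) -> icmp eq(v, 2'b11)
// since the replication is all ones exactly when every copy of v is.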
3298 if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3299 if (rhs.isAllOnes() || rhs.isZero()) {
3300 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3301 auto cst =
3302 hw::ConstantOp::create(rewriter, op.getLoc(),
3303 rhs.isAllOnes() ? APInt::getAllOnes(width)
3304 : APInt::getZero(width));
3305 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3306 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3307 op.getTwoState());
3308 return success();
3309 }
3310 }
3311 }
3312
3313 // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3314 // b), cat(c, d)). Contains special handling for the sign bit in signed
3315 // comparisons.
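// For illustration (hypothetical operands of matching widths):
//   icmp ult(concat(a, x, b), concat(a, y, b)) -> icmp ult(x, y)
// since the shared prefix and suffix cannot change the unsigned ordering.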
3316 if (Operation *opLHS = op.getLhs().getDefiningOp())
3317 if (Operation *opRHS = op.getRhs().getDefiningOp())
3318 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3319 isa<ConcatOp, ReplicateOp>(opRHS)) {
3320 if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3321 return success();
3322 }
3323
3324 return failure();
3325}