CIRCT 23.0.0git
Loading...
Searching...
No Matches
CombFolds.cpp
Go to the documentation of this file.
1//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Diagnostics.h"
14#include "mlir/IR/Matchers.h"
15#include "mlir/IR/PatternMatch.h"
16#include "llvm/ADT/SetVector.h"
17#include "llvm/ADT/SmallBitVector.h"
18#include "llvm/ADT/TypeSwitch.h"
19#include "llvm/Support/KnownBits.h"
20
21using namespace mlir;
22using namespace circt;
23using namespace comb;
24using namespace matchers;
25
26// Returns true if the op has one of its own results as an operand.
27static bool isOpTriviallyRecursive(Operation *op) {
28 return llvm::any_of(op->getOperands(), [op](auto operand) {
29 return operand.getDefiningOp() == op;
30 });
31}
32
33/// Create a new instance of a generic operation that only has value operands,
34/// and has a single result value whose type matches the first operand.
35///
36/// This should not be used to create instances of ops with attributes or with
37/// more complicated type signatures.
38static Value createGenericOp(Location loc, OperationName name,
39 ArrayRef<Value> operands, OpBuilder &builder) {
40 OperationState state(loc, name);
41 state.addOperands(operands);
42 state.addTypes(operands[0].getType());
43 return builder.create(state)->getResult(0);
44}
45
46static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
47 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
48 value);
49}
50
51/// Flatten concat and mux operands into a vector.
52static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
53 if (auto concat = v.getDefiningOp<ConcatOp>()) {
54 for (auto op : concat.getOperands())
55 getConcatOperands(op, result);
56 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
57 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
58 getConcatOperands(repl.getOperand(), result);
59 } else {
60 result.push_back(v);
61 }
62}
63
64// Return true if the op has SV attributes. Note that we cannot use a helper
65// function `hasSVAttributes` defined under SV dialect because of a cyclic
66// dependency.
67static bool hasSVAttributes(Operation *op) {
68 return op->hasAttr("sv.attributes");
69}
70
71namespace {
72template <typename SubType>
73struct ComplementMatcher {
74 SubType lhs;
75 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
76 bool match(Operation *op) {
77 auto xorOp = dyn_cast<XorOp>(op);
78 return xorOp && xorOp.isBinaryNot() &&
79 mlir::detail::matchOperandOrValueAtIndex(op, 0, lhs);
80 }
81};
82} // end anonymous namespace
83
84template <typename SubType>
85static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
86 return ComplementMatcher<SubType>(subExpr);
87}
88
89/// Return true if the op will be flattened afterwards. Op will be flattend if
90/// it has a single user which has a same op type. User must be in same block.
91static bool shouldBeFlattened(Operation *op) {
92 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
93 "must be commutative operations"));
94 if (op->hasOneUse()) {
95 auto *user = *op->getUsers().begin();
96 return user->getName() == op->getName() &&
97 op->getAttrOfType<UnitAttr>("twoState") ==
98 user->getAttrOfType<UnitAttr>("twoState") &&
99 op->getBlock() == user->getBlock();
100 }
101 return false;
102}
103
104/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
105/// as an Op. Returns true if successful, and false otherwise.
106///
107/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
108///
109static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
110 // Skip if the operation should be flattened by another operation.
111 if (shouldBeFlattened(op))
112 return false;
113
114 auto inputs = op->getOperands();
115
116 SmallVector<Value, 4> newOperands;
117 SmallVector<Location, 4> newLocations{op->getLoc()};
118 newOperands.reserve(inputs.size());
119 struct Element {
120 decltype(inputs.begin()) current, end;
121 };
122
123 SmallVector<Element> worklist;
124 worklist.push_back({inputs.begin(), inputs.end()});
125 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
126 bool changed = false;
127 while (!worklist.empty()) {
128 auto &element = worklist.back(); // Do not pop. Take ref.
129
130 // Pop when we finished traversing the current operand range.
131 if (element.current == element.end) {
132 worklist.pop_back();
133 continue;
134 }
135
136 Value value = *element.current++;
137 auto *flattenOp = value.getDefiningOp();
138 // If not defined by a compatible operation of the same kind and
139 // from the same block, keep this as-is.
140 if (!flattenOp || flattenOp->getName() != op->getName() ||
141 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
142 flattenOp->getBlock() != op->getBlock()) {
143 newOperands.push_back(value);
144 continue;
145 }
146
147 // Don't duplicate logic when it has multiple uses.
148 if (!value.hasOneUse()) {
149 // We can fold a multi-use binary operation into this one if this allows a
150 // constant to fold though. For example, fold
151 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
152 // since the constants will both fold and we end up with the equiv cost.
153 //
154 // We don't do this for add/mul because the hardware won't be shared
155 // between the two ops if duplicated.
156 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
157 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
158 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
159 newOperands.push_back(value);
160 continue;
161 }
162 }
163
164 changed = true;
165
166 // Otherwise, push operands into worklist.
167 auto flattenOpInputs = flattenOp->getOperands();
168 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
169 newLocations.push_back(flattenOp->getLoc());
170 }
171
172 if (!changed)
173 return false;
174
175 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
176 op->getName(), newOperands, rewriter);
177 if (binFlag)
178 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
179
180 replaceOpAndCopyNamehint(rewriter, op, result);
181 return true;
182}
183
184// Given a range of uses of an operation, find the lowest and highest bits
185// inclusive that are ever referenced. The range of uses must not be empty.
186static std::pair<size_t, size_t>
187getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
188 size_t originalOpWidth) {
189 auto users = op->getUsers();
190 assert(!users.empty() &&
191 "getLowestBitAndHighestBitRequired cannot operate on "
192 "a empty list of uses.");
193
194 // when we don't want to narrowTrailingBits (namely in arithmetic
195 // operations), forcing lowestBitRequired = 0
196 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
197 size_t highestBitRequired = 0;
198
199 for (auto *user : users) {
200 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
201 size_t lowBit = extractOp.getLowBit();
202 size_t highBit =
203 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
204 highestBitRequired = std::max(highestBitRequired, highBit);
205 lowestBitRequired = std::min(lowestBitRequired, lowBit);
206 continue;
207 }
208
209 highestBitRequired = originalOpWidth - 1;
210 lowestBitRequired = 0;
211 break;
212 }
213
214 return {lowestBitRequired, highestBitRequired};
215}
216
217template <class OpTy>
218static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
219 PatternRewriter &rewriter) {
220 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
221 if (!opType)
222 return false;
223
224 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
225 opType.getWidth());
226 if (range.second + 1 == opType.getWidth() && range.first == 0)
227 return false;
228
229 SmallVector<Value> args;
230 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
231 for (auto inop : op.getOperands()) {
232 // deal with muxes here
233 if (inop.getType() != op.getType())
234 args.push_back(inop);
235 else
236 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
237 inop, range.first));
238 }
239 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
240 newop->setDialectAttrs(op->getDialectAttrs());
241 if (op.getTwoState())
242 newop.setTwoState(true);
243
244 Value newResult = newop.getResult();
245 if (range.first)
246 newResult = rewriter.createOrFold<ConcatOp>(
247 op.getLoc(), newResult,
248 hw::ConstantOp::create(rewriter, op.getLoc(),
249 APInt::getZero(range.first)));
250 if (range.second + 1 < opType.getWidth())
251 newResult = rewriter.createOrFold<ConcatOp>(
252 op.getLoc(),
254 rewriter, op.getLoc(),
255 APInt::getZero(opType.getWidth() - range.second - 1)),
256 newResult);
257 rewriter.replaceOp(op, newResult);
258 return true;
259}
260
261//===----------------------------------------------------------------------===//
262// Unary Operations
263//===----------------------------------------------------------------------===//
264
265OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
266 if (isOpTriviallyRecursive(*this))
267 return {};
268
269 // Replicate one time -> noop.
270 if (cast<IntegerType>(getType()).getWidth() ==
271 getInput().getType().getIntOrFloatBitWidth())
272 return getInput();
273
274 // Constant fold.
275 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
276 if (input.getValue().getBitWidth() == 1) {
277 if (input.getValue().isZero())
278 return getIntAttr(
279 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
280 getContext());
281 return getIntAttr(
282 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
283 getContext());
284 }
285
286 APInt result = APInt::getZeroWidth();
287 for (auto i = getMultiple(); i != 0; --i)
288 result = result.concat(input.getValue());
289 return getIntAttr(result, getContext());
290 }
291
292 return {};
293}
294
295OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
296 if (isOpTriviallyRecursive(*this))
297 return {};
298
299 // Constant fold.
300 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
301 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
302
303 return {};
304}
305
306//===----------------------------------------------------------------------===//
307// Binary Operations
308//===----------------------------------------------------------------------===//
309
310/// Performs constant folding `calculate` with element-wise behavior on the two
311/// attributes in `operands` and returns the result if possible.
312static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
313 hw::PEO paramOpcode) {
314 assert(operands.size() == 2 && "binary op takes two operands");
315 if (!operands[0] || !operands[1])
316 return {};
317
318 // Fold constants with ParamExprAttr::get which handles simple constants as
319 // well as parameter expressions.
320 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
321 cast<TypedAttr>(operands[1]));
322}
323
324OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
325 if (isOpTriviallyRecursive(*this))
326 return {};
327
328 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
329 if (rhs.getValue().isZero())
330 return getOperand(0);
331
332 unsigned width = getType().getIntOrFloatBitWidth();
333 if (rhs.getValue().uge(width))
334 return getIntAttr(APInt::getZero(width), getContext());
335 }
336 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
337}
338
339LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
341 return failure();
342
343 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
344 APInt value;
345 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
346 return failure();
347
348 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
349 if (value.ugt(width))
350 value = width;
351 unsigned shift = value.getZExtValue();
352
353 // This case is handled by fold.
354 if (width <= shift || shift == 0)
355 return failure();
356
357 auto zeros =
358 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
359
360 // Remove the high bits which would be removed by the Shl.
361 auto extract =
362 ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), 0, width - shift);
363
364 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
365 return success();
366}
367
368OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
369 if (isOpTriviallyRecursive(*this))
370 return {};
371
372 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
373 if (rhs.getValue().isZero())
374 return getOperand(0);
375
376 unsigned width = getType().getIntOrFloatBitWidth();
377 if (rhs.getValue().uge(width))
378 return getIntAttr(APInt::getZero(width), getContext());
379 }
380 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
381}
382
383LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
385 return failure();
386
387 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
388 APInt value;
389 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
390 return failure();
391
392 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
393 if (value.ugt(width))
394 value = width;
395 unsigned shift = value.getZExtValue();
396
397 // This case is handled by fold.
398 if (width <= shift || shift == 0)
399 return failure();
400
401 auto zeros =
402 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
403
404 // Remove the low bits which would be removed by the Shr.
405 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
406 width - shift);
407
408 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
409 return success();
410}
411
412OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
413 if (isOpTriviallyRecursive(*this))
414 return {};
415
416 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
417 if (rhs.getValue().isZero())
418 return getOperand(0);
419 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
420}
421
422LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
424 return failure();
425
426 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
427 APInt value;
428 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
429 return failure();
430
431 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
432 if (value.ugt(width))
433 value = width;
434 unsigned shift = value.getZExtValue();
435
436 auto topbit =
437 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
438 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
439
440 if (width == shift) {
441 replaceOpAndCopyNamehint(rewriter, op, {sext});
442 return success();
443 }
444
445 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
446 width - shift);
447
448 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
449 return success();
450}
451
452//===----------------------------------------------------------------------===//
453// Other Operations
454//===----------------------------------------------------------------------===//
455
456OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
457 if (isOpTriviallyRecursive(*this))
458 return {};
459
460 // If we are extracting the entire input, then return it.
461 if (getInput().getType() == getType())
462 return getInput();
463
464 // Constant fold.
465 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
466 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
467 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
468 getContext());
469 }
470 return {};
471}
472
473// Transforms extract(lo, cat(a, b, c, d, e)) into
474// cat(extract(lo1, b), c, extract(lo2, d)).
475// innerCat must be the argument of the provided ExtractOp.
477 ConcatOp innerCat,
478 PatternRewriter &rewriter) {
479 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
480 size_t beginOfFirstRelevantElement = 0;
481 auto it = reversedConcatArgs.begin();
482 size_t lowBit = op.getLowBit();
483
484 // This loop finds the first concatArg that is covered by the ExtractOp
485 for (; it != reversedConcatArgs.end(); it++) {
486 assert(beginOfFirstRelevantElement <= lowBit &&
487 "incorrectly moved past an element that lowBit has coverage over");
488 auto operand = *it;
489
490 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
491 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
492 // A bit other than the first bit will be used in this element.
493 // ...... ........ ...
494 // ^---lowBit
495 // ^---beginOfFirstRelevantElement
496 //
497 // Edge-case close to the end of the range.
498 // ...... ........ ...
499 // ^---(position + operandWidth)
500 // ^---lowBit
501 // ^---beginOfFirstRelevantElement
502 //
503 // Edge-case close to the beginning of the rang
504 // ...... ........ ...
505 // ^---lowBit
506 // ^---beginOfFirstRelevantElement
507 //
508 break;
509 }
510
511 // extraction discards this element.
512 // ...... ........ ...
513 // | ^---lowBit
514 // ^---beginOfFirstRelevantElement
515 beginOfFirstRelevantElement += operandWidth;
516 }
517 assert(it != reversedConcatArgs.end() &&
518 "incorrectly failed to find an element which contains coverage of "
519 "lowBit");
520
521 SmallVector<Value> reverseConcatArgs;
522 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
523 size_t extractLo = lowBit - beginOfFirstRelevantElement;
524
525 // Transform individual arguments of innerCat(..., a, b, c,) into
526 // [ extract(a), b, extract(c) ], skipping an extract operation where
527 // possible.
528 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
529 auto concatArg = *it;
530 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
531 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
532
533 if (widthToConsume == operandWidth && extractLo == 0) {
534 reverseConcatArgs.push_back(concatArg);
535 } else {
536 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
537 reverseConcatArgs.push_back(
538 ExtractOp::create(rewriter, op.getLoc(), resultType, *it, extractLo));
539 }
540
541 widthRemaining -= widthToConsume;
542
543 // Beyond the first element, all elements are extracted from position 0.
544 extractLo = 0;
545 }
546
547 if (reverseConcatArgs.size() == 1) {
548 replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
549 } else {
550 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
551 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
552 }
553 return success();
554}
555
556// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
557static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
558 PatternRewriter &rewriter) {
559 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
560 auto replicateEltWidth =
561 replicate.getOperand().getType().getIntOrFloatBitWidth();
562
563 // If the extract starts at the base of an element and is an even multiple,
564 // we can replace the extract with a smaller replicate.
565 if (op.getLowBit() % replicateEltWidth == 0 &&
566 extractResultWidth % replicateEltWidth == 0) {
567 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
568 replicate.getOperand());
569 return true;
570 }
571
572 // If the extract is completely contained in one element, extract from the
573 // element.
574 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
575 replicateEltWidth) {
576 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
577 rewriter, op, op.getType(), replicate.getOperand(),
578 op.getLowBit() % replicateEltWidth);
579 return true;
580 }
581
582 // We don't currently handle the case of extracting from non-whole elements,
583 // e.g. `extract (replicate 2-bit-thing, N), 1`.
584 return false;
585}
586
587LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
589 return failure();
590 auto *inputOp = op.getInput().getDefiningOp();
591
592 // This turns out to be incredibly expensive. Disable until performance is
593 // addressed.
594#if 0
595 // If the extracted bits are all known, then return the result.
596 auto knownBits = computeKnownBits(op.getInput())
597 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
598 op.getLowBit());
599 if (knownBits.isConstant()) {
600 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
601 knownBits.getConstant());
602 return success();
603 }
604#endif
605
606 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
607 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
608 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
609 rewriter, op, op.getType(), innerExtract.getInput(),
610 innerExtract.getLowBit() + op.getLowBit());
611 return success();
612 }
613
614 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
615 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
616 return extractConcatToConcatExtract(op, innerCat, rewriter);
617
618 // extract(lo, replicate(a))
619 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
620 if (extractFromReplicate(op, replicate, rewriter))
621 return success();
622
623 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
624 // and/or/xor are not modifying the extracted bits.
625 if (inputOp && inputOp->getNumOperands() == 2 &&
626 isa<AndOp, OrOp, XorOp>(inputOp)) {
627 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
628 auto extractedCst = cstRHS.getValue().extractBits(
629 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
630 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
631 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
632 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
633 return success();
634 }
635
636 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
637 // extract to represent the result. Turning it into a pile of extracts is
638 // always fine by our cost model, but we don't want to explode things into
639 // a ton of bits because it will bloat the IR and generated Verilog.
640 if (isa<AndOp>(inputOp)) {
641 // For our cost model, we only do this if the bit pattern is a
642 // contiguous series of ones.
643 unsigned lz = extractedCst.countLeadingZeros();
644 unsigned tz = extractedCst.countTrailingZeros();
645 unsigned pop = extractedCst.popcount();
646 if (extractedCst.getBitWidth() - lz - tz == pop) {
647 auto resultTy = rewriter.getIntegerType(pop);
648 SmallVector<Value> resultElts;
649 if (lz)
650 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
651 APInt::getZero(lz)));
652 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
653 op.getLoc(), resultTy, inputOp->getOperand(0),
654 op.getLowBit() + tz));
655 if (tz)
656 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
657 APInt::getZero(tz)));
658 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
659 return success();
660 }
661 }
662 }
663 }
664
665 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
666 // extracted.
667 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
668 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
669 // Don't canonicalize if the shift is multiply used.
670 if (shlOp->hasOneUse())
671 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
672 if (lhsCst.getValue().isOne()) {
673 auto newCst = hw::ConstantOp::create(
674 rewriter, shlOp.getLoc(),
675 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
676 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
677 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
678 false);
679 return success();
680 }
681 }
682
683 return failure();
684}
685
686//===----------------------------------------------------------------------===//
687// Associative Variadic operations
688//===----------------------------------------------------------------------===//
689
690// Reduce all operands to a single value (either integer constant or parameter
691// expression) if all the operands are constants.
692static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
693 hw::PEO paramOpcode) {
694 assert(operands.size() > 1 && "caller should handle one-operand case");
695 // We can only fold anything in the case where all operands are known to be
696 // constants. Check the least common one first for an early out.
697 if (!operands[1] || !operands[0])
698 return {};
699
700 // This will fold to a simple constant if all operands are constant.
701 if (llvm::all_of(operands.drop_front(2),
702 [&](Attribute in) { return !!in; })) {
703 SmallVector<mlir::TypedAttr> typedOperands;
704 typedOperands.reserve(operands.size());
705 for (auto operand : operands) {
706 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
707 typedOperands.push_back(typedOperand);
708 else
709 break;
710 }
711 if (typedOperands.size() == operands.size())
712 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
713 }
714
715 return {};
716}
717
718/// When we find a logical operation (and, or, xor) with a constant e.g.
719/// `X & 42`, we want to push the constant into the computation of X if it leads
720/// to simplification.
721///
722/// This function handles the case where the logical operation has a concat
723/// operand. We check to see if we can simplify the concat, e.g. when it has
724/// constant operands.
725///
726/// This returns true when a simplification happens.
727static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
728 size_t concatIdx, const APInt &cst,
729 PatternRewriter &rewriter) {
730 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
731 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
732
733 // Check to see if any operands can be simplified by pushing the logical op
734 // into all parts of the concat.
735 bool canSimplify =
736 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
737 auto *operandOp = operand.getDefiningOp();
738 if (!operandOp)
739 return false;
740
741 // If the concat has a constant operand then we can transform this.
742 if (isa<hw::ConstantOp>(operandOp))
743 return true;
744 // If the concat has the same logical operation and that operation has
745 // a constant operation than we can fold it into that suboperation.
746 return operandOp->getName() == logicalOp->getName() &&
747 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
748 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
749 });
750
751 if (!canSimplify)
752 return false;
753
754 // Create a new instance of the logical operation. We have to do this the
755 // hard way since we're generic across a family of different ops.
756 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
757 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
758 rewriter);
759 };
760
761 // Ok, let's do the transformation. We do this by slicing up the constant
762 // for each unit of the concat and duplicate the operation into the
763 // sub-operand.
764 SmallVector<Value> newConcatOperands;
765 newConcatOperands.reserve(concatOp->getNumOperands());
766
767 // Work from MSB to LSB.
768 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
769 for (Value operand : concatOp->getOperands()) {
770 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
771 nextOperandBit -= operandWidth;
772 // Take a slice of the constant.
773 auto eltCst =
774 hw::ConstantOp::create(rewriter, logicalOp->getLoc(),
775 cst.lshr(nextOperandBit).trunc(operandWidth));
776
777 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
778 }
779
780 // Create the concat, and the rest of the logical op if we need it.
781 Value newResult =
782 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
783
784 // If we had a variadic logical op on the top level, then recreate it with the
785 // new concat and without the constant operand.
786 if (logicalOp->getNumOperands() > 2) {
787 auto origOperands = logicalOp->getOperands();
788 SmallVector<Value> operands;
789 // Take any stuff before the concat.
790 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
791 // Take any stuff after the concat but before the constant.
792 operands.append(origOperands.begin() + concatIdx + 1,
793 origOperands.begin() + (origOperands.size() - 1));
794 // Include the new concat.
795 operands.push_back(newResult);
796 newResult = createLogicalOp(operands);
797 }
798
799 replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
800 return true;
801}
802
803// Determines whether the inputs to a logical element are of opposite
804// comparisons and can lowered into a constant.
805static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
806 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
807
808 for (auto op : operands) {
809 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
810 icmpOp && icmpOp.getTwoState()) {
811 auto predicate = icmpOp.getPredicate();
812 auto lhs = icmpOp.getLhs();
813 auto rhs = icmpOp.getRhs();
814 if (seenPredicates.contains(
815 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
816 return true;
817
818 seenPredicates.insert({predicate, lhs, rhs});
819 }
820 }
821 return false;
822}
823
824OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
825 if (isOpTriviallyRecursive(*this))
826 return {};
827
828 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
829
830 auto inputs = adaptor.getInputs();
831
832 // and(x, 01, 10) -> 00 -- annulment.
833 for (auto operand : inputs) {
834 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
835 if (!attr)
836 continue;
837 value &= attr.getValue();
838 if (value.isZero())
839 return getIntAttr(value, getContext());
840 }
841
842 // and(x, -1) -> x.
843 if (inputs.size() == 2)
844 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
845 if (intAttr.getValue().isAllOnes())
846 return getInputs()[0];
847
848 // and(x, x, x) -> x. This also handles and(x) -> x.
849 if (llvm::all_of(getInputs(),
850 [&](auto in) { return in == this->getInputs()[0]; }))
851 return getInputs()[0];
852
853 // and(..., x, ..., ~x, ...) -> 0
854 for (Value arg : getInputs()) {
855 Value subExpr;
856 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
857 for (Value arg2 : getInputs())
858 if (arg2 == subExpr)
859 return getIntAttr(
860 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
861 getContext());
862 }
863 }
864
865 // x0 = icmp(pred, x, y)
866 // x1 = icmp(!pred, x, y)
867 // and(x0, x1) -> 0
869 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
870 getContext());
871
872 // Constant fold
873 return constFoldAssociativeOp(inputs, hw::PEO::And);
874}
875
876/// Returns a single common operand that all inputs of the operation `op` can
877/// be traced back to, or an empty `Value` if no such operand exists.
878///
879/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
880/// (assuming the bit-width of `a` is `n`).
881template <typename Op>
882static Value getCommonOperand(Op op) {
883 if (!op.getType().isInteger(1))
884 return Value();
885
886 auto inputs = op.getInputs();
887 size_t size = inputs.size();
888
889 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
890 if (!sourceOp)
891 return Value();
892 Value source = sourceOp.getOperand();
893
894 // Fast path: the input size is not equal to the width of the source.
895 if (size != source.getType().getIntOrFloatBitWidth())
896 return Value();
897
898 // Tracks the bits that were encountered.
899 llvm::BitVector bits(size);
900 bits.set(sourceOp.getLowBit());
901
902 for (size_t i = 1; i != size; ++i) {
903 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
904 if (!extractOp || extractOp.getOperand() != source)
905 return Value();
906 bits.set(extractOp.getLowBit());
907 }
908
909 return bits.all() ? source : Value();
910}
911
912/// Canonicalize an idempotent operation `op` so that only one input of any kind
913/// occurs.
914///
915/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
916template <typename Op>
917static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
918 // Depth limit to search, in operations. Chosen arbitrarily, keep small.
919 constexpr unsigned limit = 3;
920 auto inputs = op.getInputs();
921
922 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
923 llvm::SmallDenseSet<Op, 8> checked;
924 checked.insert(op);
925
926 struct OpWithDepth {
927 Op op;
928 unsigned depth;
929 };
930 llvm::SmallVector<OpWithDepth, 8> worklist;
931
932 auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
933 // Add to worklist if within depth limit, is defined in the same block by
934 // the same kind of operation, has same two-state-ness, and not enqueued
935 // previously.
936 if (depth < limit && input.getParentBlock() == op->getBlock()) {
937 auto inputOp = input.template getDefiningOp<Op>();
938 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
939 checked.insert(inputOp).second)
940 worklist.push_back({inputOp, depth + 1});
941 }
942 };
943
944 for (auto input : uniqueInputs)
945 enqueue(input, 0);
946
947 while (!worklist.empty()) {
948 auto item = worklist.pop_back_val();
949
950 for (auto input : item.op.getInputs()) {
951 uniqueInputs.remove(input);
952 enqueue(input, item.depth);
953 }
954 }
955
956 if (uniqueInputs.size() < inputs.size()) {
957 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
958 uniqueInputs.getArrayRef(),
959 op.getTwoState());
960 return true;
961 }
962
963 return false;
964}
965
966LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
968 return failure();
969
970 auto inputs = op.getInputs();
971 auto size = inputs.size();
972
973 // and(x, and(...)) -> and(x, ...) -- flatten
974 if (tryFlatteningOperands(op, rewriter))
975 return success();
976
977 // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
978 // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
979 // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
980 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
981 return success();
982
983 assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
984
985 // Patterns for and with a constant on RHS.
986 APInt value;
987 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
988 // and(..., '1) -> and(...) -- identity
989 if (value.isAllOnes()) {
990 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
991 inputs.drop_back(), false);
992 return success();
993 }
994
995 // TODO: Combine multiple constants together even if they aren't at the
996 // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
997 // folding
998 APInt value2;
999 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1000 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value & value2);
1001 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1002 newOperands.push_back(cst);
1003 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1004 newOperands, false);
1005 return success();
1006 }
1007
1008 // Handle 'and' with a single bit constant on the RHS.
1009 if (size == 2 && value.isPowerOf2()) {
1010 // If the LHS is a replicate from a single bit, we can 'concat' it
1011 // into place. e.g.:
1012 // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1013 // TODO: Generalize this for non-single-bit operands.
1014 if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1015 auto replicateOperand = replicate.getOperand();
1016 if (replicateOperand.getType().isInteger(1)) {
1017 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1018 auto trailingZeros = value.countTrailingZeros();
1019
1020 // Don't add zero bit constants unnecessarily.
1021 SmallVector<Value, 3> concatOperands;
1022 if (trailingZeros != resultWidth - 1) {
1023 auto highZeros = hw::ConstantOp::create(
1024 rewriter, op.getLoc(),
1025 APInt::getZero(resultWidth - trailingZeros - 1));
1026 concatOperands.push_back(highZeros);
1027 }
1028 concatOperands.push_back(replicateOperand);
1029 if (trailingZeros != 0) {
1030 auto lowZeros = hw::ConstantOp::create(
1031 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1032 concatOperands.push_back(lowZeros);
1033 }
1034 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1035 rewriter, op, op.getType(), concatOperands);
1036 return success();
1037 }
1038 }
1039 }
1040
1041 // Narrow the op if the constant has leading or trailing zeros.
1042 //
1043 // and(a, 0b00101100) -> concat(0b00, and(extract(a), 0b1011), 0b00)
1044 unsigned leadingZeros = value.countLeadingZeros();
1045 unsigned trailingZeros = value.countTrailingZeros();
1046 if (leadingZeros > 0 || trailingZeros > 0) {
1047 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1048
1049 // Extract the non-zero regions of the operands. Look through extracts.
1050 SmallVector<Value> operands;
1051 for (auto input : inputs.drop_back()) {
1052 unsigned offset = trailingZeros;
1053 while (auto extractOp = input.getDefiningOp<ExtractOp>()) {
1054 input = extractOp.getInput();
1055 offset += extractOp.getLowBit();
1056 }
1057 operands.push_back(ExtractOp::create(rewriter, op.getLoc(), input,
1058 offset, maskLength));
1059 }
1060
1061 // Add the narrowed mask if needed.
1062 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1063 if (!narrowMask.isAllOnes())
1064 operands.push_back(hw::ConstantOp::create(
1065 rewriter, inputs.back().getLoc(), narrowMask));
1066
1067 // Create the narrow and op.
1068 Value narrowValue = operands.back();
1069 if (operands.size() > 1)
1070 narrowValue =
1071 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1072 operands.clear();
1073
1074 // Concatenate the narrow and with the leading and trailing zeros.
1075 if (leadingZeros > 0)
1076 operands.push_back(hw::ConstantOp::create(
1077 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1078 operands.push_back(narrowValue);
1079 if (trailingZeros > 0)
1080 operands.push_back(hw::ConstantOp::create(
1081 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1082 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1083 return success();
1084 }
1085
1086 // and(concat(x, cst1), a, b, c, cst2)
1087 // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1088 // We do this for even more multi-use concats since they are "just wiring".
1089 for (size_t i = 0; i < size - 1; ++i) {
1090 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1091 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1092 return success();
1093 }
1094 }
1095
1096 // extracts only of and(...) -> and(extract()...)
1097 if (narrowOperationWidth(op, true, rewriter))
1098 return success();
1099
1100 // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1101 if (auto source = getCommonOperand(op)) {
1102 auto cmpAgainst =
1103 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getAllOnes(size));
1104 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1105 source, cmpAgainst);
1106 return success();
1107 }
1108
1109 // and(x, replicate(p : i1)) -> mux(p, x, zero)
1110 if (op.getTwoState() && op.getNumOperands() == 2) {
1111 auto isReplicateOfI1 = [](Value v) {
1112 auto rep = v.getDefiningOp<ReplicateOp>();
1113 if (!rep)
1114 return false;
1115 return rep.getOperand().getType().isInteger(1);
1116 };
1117 Value x = op.getOperand(0);
1118 Value y = op.getOperand(1);
1119 if (isReplicateOfI1(x))
1120 std::swap(x, y);
1121 if (isReplicateOfI1(y)) {
1122 Value p = y.getDefiningOp<ReplicateOp>().getInput();
1123 Value zero = hw::ConstantOp::create(
1124 rewriter, op.getLoc(), rewriter.getIntegerAttr(op.getType(), 0));
1125 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, p, x, zero,
1126 /*isTwoState=*/true);
1127 return success();
1128 }
1129 }
1130
1131 /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1132 return failure();
1133}
1134
1135OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1136 if (isOpTriviallyRecursive(*this))
1137 return {};
1138
1139 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1140 auto inputs = adaptor.getInputs();
1141 // or(x, 10, 01) -> 11
1142 for (auto operand : inputs) {
1143 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1144 if (!attr)
1145 continue;
1146 value |= attr.getValue();
1147 if (value.isAllOnes())
1148 return getIntAttr(value, getContext());
1149 }
1150
1151 // or(x, 0) -> x
1152 if (inputs.size() == 2)
1153 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1154 if (intAttr.getValue().isZero())
1155 return getInputs()[0];
1156
1157 // or(x, x, x) -> x. This also handles or(x) -> x
1158 if (llvm::all_of(getInputs(),
1159 [&](auto in) { return in == this->getInputs()[0]; }))
1160 return getInputs()[0];
1161
1162 // or(..., x, ..., ~x, ...) -> -1
1163 for (Value arg : getInputs()) {
1164 Value subExpr;
1165 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1166 for (Value arg2 : getInputs())
1167 if (arg2 == subExpr)
1168 return getIntAttr(
1169 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1170 getContext());
1171 }
1172 }
1173
1174 // x0 = icmp(pred, x, y)
1175 // x1 = icmp(!pred, x, y)
1176 // or(x0, x1) -> 1
1177 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1178 return getIntAttr(
1179 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1180 getContext());
1181
1182 // Constant fold
1183 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1184}
1185
1186LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1187 if (isOpTriviallyRecursive(op))
1188 return failure();
1189
1190 auto inputs = op.getInputs();
1191 auto size = inputs.size();
1192
1193 // or(x, or(...)) -> or(x, ...) -- flatten
1194 if (tryFlatteningOperands(op, rewriter))
1195 return success();
1196
1197 // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1198 // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1199 // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1200 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1201 return success();
1202
1203 assert(size > 1 && "expected 2 or more operands");
1204
1205 // Patterns for and with a constant on RHS.
1206 APInt value;
1207 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1208 // or(..., '0) -> or(...) -- identity
1209 if (value.isZero()) {
1210 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1211 inputs.drop_back());
1212 return success();
1213 }
1214
1215 // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1216 APInt value2;
1217 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1218 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value | value2);
1219 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1220 newOperands.push_back(cst);
1221 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1222 newOperands);
1223 return success();
1224 }
1225
1226 // or(concat(x, cst1), a, b, c, cst2)
1227 // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1228 // We do this for even more multi-use concats since they are "just wiring".
1229 for (size_t i = 0; i < size - 1; ++i) {
1230 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1231 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1232 return success();
1233 }
1234 }
1235
1236 // extracts only of or(...) -> or(extract()...)
1237 if (narrowOperationWidth(op, true, rewriter))
1238 return success();
1239
1240 // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1241 if (auto source = getCommonOperand(op)) {
1242 auto cmpAgainst =
1243 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(size));
1244 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1245 source, cmpAgainst);
1246 return success();
1247 }
1248
1249 // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1250 // .., c_n), a, 0)
1251 if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1252 APInt value;
1253 if (op.getTwoState() && firstMux.getTwoState() &&
1254 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1255 value.isZero()) {
1256 SmallVector<Value> conditions{firstMux.getCond()};
1257 auto check = [&](Value v) {
1258 auto mux = v.getDefiningOp<comb::MuxOp>();
1259 if (!mux)
1260 return false;
1261 conditions.push_back(mux.getCond());
1262 return mux.getTwoState() &&
1263 firstMux.getTrueValue() == mux.getTrueValue() &&
1264 firstMux.getFalseValue() == mux.getFalseValue();
1265 };
1266 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1267 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions, true);
1268 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1269 rewriter, op, cond, firstMux.getTrueValue(),
1270 firstMux.getFalseValue(), true);
1271 return success();
1272 }
1273 }
1274 }
1275
1276 /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1277 return failure();
1278}
1279
1280OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1281 if (isOpTriviallyRecursive(*this))
1282 return {};
1283
1284 auto size = getInputs().size();
1285 auto inputs = adaptor.getInputs();
1286
1287 // xor(x) -> x -- noop
1288 if (size == 1)
1289 return getInputs()[0];
1290
1291 // xor(x, x) -> 0 -- idempotent
1292 if (size == 2 && getInputs()[0] == getInputs()[1])
1293 return IntegerAttr::get(getType(), 0);
1294
1295 // xor(x, 0) -> x
1296 if (inputs.size() == 2)
1297 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1]))
1298 if (intAttr.getValue().isZero())
1299 return getInputs()[0];
1300
1301 // xor(xor(x,1),1) -> x
1302 // but not self loop
1303 Value subExpr;
1304 if (matchPattern(getResult(), m_Complement(m_Complement(m_Any(&subExpr)))) &&
1305 subExpr != getResult())
1306 return subExpr;
1307
1308 // Constant fold
1309 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1310}
1311
1312// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1313static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1314 PatternRewriter &rewriter) {
1315 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1316 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1317
1318 Value result =
1319 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1320 icmp.getOperand(1), icmp.getTwoState());
1321
1322 // If the xor had other operands, rebuild it.
1323 if (op.getNumOperands() > 2) {
1324 SmallVector<Value, 4> newOperands(op.getOperands());
1325 newOperands.pop_back();
1326 newOperands.erase(newOperands.begin() + icmpOperand);
1327 newOperands.push_back(result);
1328 result =
1329 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1330 }
1331
1332 replaceOpAndCopyNamehint(rewriter, op, result);
1333}
1334
1335LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1336 if (isOpTriviallyRecursive(op))
1337 return failure();
1338
1339 auto inputs = op.getInputs();
1340 auto size = inputs.size();
1341 assert(size > 1 && "expected 2 or more operands");
1342
1343 // xor(..., x, x) -> xor (...) -- idempotent
1344 if (inputs[size - 1] == inputs[size - 2]) {
1345 assert(size > 2 &&
1346 "expected idempotent case for 2 elements handled already.");
1347 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1348 inputs.drop_back(/*n=*/2), false);
1349 return success();
1350 }
1351
1352 // Patterns for xor with a constant on RHS.
1353 APInt value;
1354 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1355 // xor(..., 0) -> xor(...) -- identity
1356 if (value.isZero()) {
1357 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1358 inputs.drop_back(), false);
1359 return success();
1360 }
1361
1362 // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1363 APInt value2;
1364 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1365 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value ^ value2);
1366 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1367 newOperands.push_back(cst);
1368 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1369 newOperands, false);
1370 return success();
1371 }
1372
1373 bool isSingleBit = value.getBitWidth() == 1;
1374
1375 // Check for subexpressions that we can simplify.
1376 for (size_t i = 0; i < size - 1; ++i) {
1377 Value operand = inputs[i];
1378
1379 // xor(concat(x, cst1), a, b, c, cst2)
1380 // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1381 // We do this for even more multi-use concats since they are "just
1382 // wiring".
1383 if (auto concat = operand.getDefiningOp<ConcatOp>())
1384 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1385 return success();
1386
1387 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1388 if (isSingleBit && operand.hasOneUse()) {
1389 assert(value == 1 && "single bit constant has to be one if not zero");
1390 if (auto icmp = operand.getDefiningOp<ICmpOp>())
1391 return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1392 }
1393 }
1394 }
1395
1396 // xor(sext(x), -1) -> sext(xor(x,-1))
1397 // More concisely: ~sext(x) = sext(~x)
1398 Value base;
1399 // Check for sext of the inverted value
1400 if (matchPattern(op.getResult(), m_Complement(m_Sext(m_Any(&base))))) {
1401 // Create negated sext: ~sext(x) = sext(~x)
1402 auto negBase = createOrFoldNot(rewriter, op.getLoc(), base, true);
1403 auto sextNegBase =
1404 createOrFoldSExt(rewriter, op.getLoc(), negBase, op.getType());
1405 replaceOpAndCopyNamehint(rewriter, op, sextNegBase);
1406 return success();
1407 }
1408
1409 // xor(x, xor(...)) -> xor(x, ...) -- flatten
1410 if (tryFlatteningOperands(op, rewriter))
1411 return success();
1412
1413 // extracts only of xor(...) -> xor(extract()...)
1414 if (narrowOperationWidth(op, true, rewriter))
1415 return success();
1416
1417 // xor(a[0], a[1], ..., a[n]) -> parity(a)
1418 if (auto source = getCommonOperand(op)) {
1419 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1420 return success();
1421 }
1422
1423 return failure();
1424}
1425
1426OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1427 if (isOpTriviallyRecursive(*this))
1428 return {};
1429
1430 // sub(x - x) -> 0
1431 if (getRhs() == getLhs())
1432 return getIntAttr(
1433 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1434 getContext());
1435
1436 if (adaptor.getRhs()) {
1437 // If both are constants, we can unconditionally fold.
1438 if (adaptor.getLhs()) {
1439 // Constant fold (c1 - c2) => (c1 + -1*c2).
1440 auto negOne = getIntAttr(
1441 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1442 getContext());
1443 auto rhsNeg = hw::ParamExprAttr::get(
1444 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1445 return hw::ParamExprAttr::get(hw::PEO::Add,
1446 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1447 }
1448
1449 // sub(x - 0) -> x
1450 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1451 if (rhsC.getValue().isZero())
1452 return getLhs();
1453 }
1454 }
1455
1456 return {};
1457}
1458
1459LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1460 if (isOpTriviallyRecursive(op))
1461 return failure();
1462
1463 // sub(x, cst) -> add(x, -cst)
1464 APInt value;
1465 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1466 auto negCst = hw::ConstantOp::create(rewriter, op.getLoc(), -value);
1467 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1468 false);
1469 return success();
1470 }
1471
1472 // extracts only of sub(...) -> sub(extract()...)
1473 if (narrowOperationWidth(op, false, rewriter))
1474 return success();
1475
1476 return failure();
1477}
1478
1479OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1480 if (isOpTriviallyRecursive(*this))
1481 return {};
1482
1483 auto size = getInputs().size();
1484
1485 // add(x) -> x -- noop
1486 if (size == 1u)
1487 return getInputs()[0];
1488
1489 // Constant fold constant operands.
1490 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1491}
1492
1493LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1494 if (isOpTriviallyRecursive(op))
1495 return failure();
1496
1497 auto inputs = op.getInputs();
1498 auto size = inputs.size();
1499 assert(size > 1 && "expected 2 or more operands");
1500
1501 APInt value, value2;
1502
1503 // add(..., 0) -> add(...) -- identity
1504 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1505 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1506 inputs.drop_back(), false);
1507 return success();
1508 }
1509
1510 // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1511 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1512 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1513 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1514 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1515 newOperands.push_back(cst);
1516 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1517 newOperands, false);
1518 return success();
1519 }
1520
1521 // add(..., x, x) -> add(..., shl(x, 1))
1522 if (inputs[size - 1] == inputs[size - 2]) {
1523 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1524
1525 auto one = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), 1);
1526 auto shiftLeftOp =
1527 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one, false);
1528
1529 newOperands.push_back(shiftLeftOp);
1530 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1531 newOperands, false);
1532 return success();
1533 }
1534
1535 auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1536 // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1537 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1538 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1539
1540 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1541 auto rhs =
1542 hw::ConstantOp::create(rewriter, op.getLoc(), (one << value) + one);
1543
1544 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1545 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1546
1547 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1548 newOperands.push_back(mulOp);
1549 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1550 newOperands, false);
1551 return success();
1552 }
1553
1554 auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1555 // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1556 if (mulOp && mulOp.getInputs().size() == 2 &&
1557 mulOp.getInputs()[0] == inputs[size - 2] &&
1558 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1559
1560 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1561 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + one);
1562 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1563 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1564
1565 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1566 newOperands.push_back(newMulOp);
1567 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1568 newOperands, false);
1569 return success();
1570 }
1571
1572 // add(a, add(...)) -> add(a, ...) -- flatten
1573 if (tryFlatteningOperands(op, rewriter))
1574 return success();
1575
1576 // extracts only of add(...) -> add(extract()...)
1577 if (narrowOperationWidth(op, false, rewriter))
1578 return success();
1579
1580 // add(add(x, c1), c2) -> add(x, c1 + c2)
1581 auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1582 if (addOp && addOp.getInputs().size() == 2 &&
1583 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1584 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1585
1586 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1587 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1588 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1589 /*twoState=*/op.getTwoState() && addOp.getTwoState());
1590 return success();
1591 }
1592
1593 return failure();
1594}
1595
1596OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1597 if (isOpTriviallyRecursive(*this))
1598 return {};
1599
1600 auto size = getInputs().size();
1601 auto inputs = adaptor.getInputs();
1602
1603 // mul(x) -> x -- noop
1604 if (size == 1u)
1605 return getInputs()[0];
1606
1607 auto width = cast<IntegerType>(getType()).getWidth();
1608 if (width == 0)
1609 return getIntAttr(APInt::getZero(0), getContext());
1610
1611 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1612
1613 // mul(x, 0, 1) -> 0 -- annulment
1614 for (auto operand : inputs) {
1615 auto attr = dyn_cast_or_null<IntegerAttr>(operand);
1616 if (!attr)
1617 continue;
1618 value *= attr.getValue();
1619 if (value.isZero())
1620 return getIntAttr(value, getContext());
1621 }
1622
1623 // Constant fold
1624 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1625}
1626
1627LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1628 if (isOpTriviallyRecursive(op))
1629 return failure();
1630
1631 auto inputs = op.getInputs();
1632 auto size = inputs.size();
1633 assert(size > 1 && "expected 2 or more operands");
1634
1635 APInt value, value2;
1636
1637 // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1638 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1639 value.isPowerOf2()) {
1640 auto shift = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(),
1641 value.exactLogBase2());
1642 auto shlOp =
1643 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift, false);
1644
1645 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1646 ArrayRef<Value>(shlOp), false);
1647 return success();
1648 }
1649
1650 // mul(..., 1) -> mul(...) -- identity
1651 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1652 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1653 inputs.drop_back());
1654 return success();
1655 }
1656
1657 // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1658 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1659 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1660 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value * value2);
1661 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1662 newOperands.push_back(cst);
1663 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1664 newOperands);
1665 return success();
1666 }
1667
1668 // mul(a, mul(...)) -> mul(a, ...) -- flatten
1669 if (tryFlatteningOperands(op, rewriter))
1670 return success();
1671
1672 // extracts only of mul(...) -> mul(extract()...)
1673 if (narrowOperationWidth(op, false, rewriter))
1674 return success();
1675
1676 return failure();
1677}
1678
1679template <class Op, bool isSigned>
1680static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1681 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1682 // divu(x, 1) -> x, divs(x, 1) -> x
1683 if (rhsValue.getValue() == 1)
1684 return op.getLhs();
1685
1686 // If the divisor is zero, do not fold for now.
1687 if (rhsValue.getValue().isZero())
1688 return {};
1689 }
1690
1691 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1692}
1693
1694OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1695 if (isOpTriviallyRecursive(*this))
1696 return {};
1697 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1698}
1699
1700OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1701 if (isOpTriviallyRecursive(*this))
1702 return {};
1703 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1704}
1705
1706template <class Op, bool isSigned>
1707static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1708 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1709 // modu(x, 1) -> 0, mods(x, 1) -> 0
1710 if (rhsValue.getValue() == 1)
1711 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1712 op.getContext());
1713
1714 // If the divisor is zero, do not fold for now.
1715 if (rhsValue.getValue().isZero())
1716 return {};
1717 }
1718
1719 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1720 // modu(0, x) -> 0, mods(0, x) -> 0
1721 if (lhsValue.getValue().isZero())
1722 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1723 op.getContext());
1724 }
1725
1726 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1727}
1728
1729OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1730 if (isOpTriviallyRecursive(*this))
1731 return {};
1732 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1733}
1734
1735OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1736 if (isOpTriviallyRecursive(*this))
1737 return {};
1738 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1739}
1740
1741LogicalResult DivUOp::canonicalize(DivUOp op, PatternRewriter &rewriter) {
1742 if (isOpTriviallyRecursive(op) || !op.getTwoState())
1743 return failure();
1744 return convertDivUByPowerOfTwo(op, rewriter);
1745}
1746
1747LogicalResult ModUOp::canonicalize(ModUOp op, PatternRewriter &rewriter) {
1748 if (isOpTriviallyRecursive(op) || !op.getTwoState())
1749 return failure();
1750
1751 return convertModUByPowerOfTwo(op, rewriter);
1752}
1753
1754//===----------------------------------------------------------------------===//
1755// ConcatOp
1756//===----------------------------------------------------------------------===//
1757
1758// Constant folding
1759OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1760 if (isOpTriviallyRecursive(*this))
1761 return {};
1762
1763 if (getNumOperands() == 1)
1764 return getOperand(0);
1765
1766 // If all the operands are constant, we can fold.
1767 for (auto attr : adaptor.getInputs())
1768 if (!attr || !isa<IntegerAttr>(attr))
1769 return {};
1770
1771 // If we got here, we can constant fold.
1772 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1773 APInt result(resultWidth, 0);
1774
1775 unsigned nextInsertion = resultWidth;
1776 // Insert each chunk into the result.
1777 for (auto attr : adaptor.getInputs()) {
1778 auto chunk = cast<IntegerAttr>(attr).getValue();
1779 nextInsertion -= chunk.getBitWidth();
1780 result.insertBits(chunk, nextInsertion);
1781 }
1782
1783 return getIntAttr(result, getContext());
1784}
1785
1786LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1787 if (isOpTriviallyRecursive(op))
1788 return failure();
1789
1790 auto inputs = op.getInputs();
1791 auto size = inputs.size();
1792 assert(size > 1 && "expected 2 or more operands");
1793
1794 // This function is used when we flatten neighboring operands of a
1795 // (variadic) concat into a new vesion of the concat. first/last indices
1796 // are inclusive.
1797 auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1798 ValueRange replacements) -> LogicalResult {
1799 SmallVector<Value, 4> newOperands;
1800 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1801 newOperands.append(replacements.begin(), replacements.end());
1802 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1803 if (newOperands.size() == 1)
1804 replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
1805 else
1806 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1807 newOperands);
1808 return success();
1809 };
1810
1811 Value commonOperand = inputs[0];
1812 for (size_t i = 0; i != size; ++i) {
1813 // Check to see if all operands are the same.
1814 if (inputs[i] != commonOperand)
1815 commonOperand = Value();
1816
1817 // If an operand to the concat is itself a concat, then we can fold them
1818 // together.
1819 if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1820 return flattenConcat(i, i, subConcat->getOperands());
1821
1822 // Check for canonicalization due to neighboring operands.
1823 if (i != 0) {
1824 // Merge neighboring constants.
1825 if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1826 if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1827 unsigned prevWidth = prevCst.getValue().getBitWidth();
1828 unsigned thisWidth = cst.getValue().getBitWidth();
1829 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1830 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1831 << thisWidth;
1832 Value replacement =
1833 hw::ConstantOp::create(rewriter, op.getLoc(), resultCst);
1834 return flattenConcat(i - 1, i, replacement);
1835 }
1836 }
1837
1838 // If the two operands are the same, turn them into a replicate.
1839 if (inputs[i] == inputs[i - 1]) {
1840 Value replacement =
1841 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1842 return flattenConcat(i - 1, i, replacement);
1843 }
1844
1845 // If this input is a replicate, see if we can fold it with the previous
1846 // one.
1847 if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1848 // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1849 if (repl.getOperand() == inputs[i - 1]) {
1850 Value replacement = rewriter.createOrFold<ReplicateOp>(
1851 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1852 return flattenConcat(i - 1, i, replacement);
1853 }
1854 // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1855 if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1856 if (prevRepl.getOperand() == repl.getOperand()) {
1857 Value replacement = rewriter.createOrFold<ReplicateOp>(
1858 op.getLoc(), repl.getOperand(),
1859 repl.getMultiple() + prevRepl.getMultiple());
1860 return flattenConcat(i - 1, i, replacement);
1861 }
1862 }
1863 }
1864
1865 // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1866 if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1867 if (repl.getOperand() == inputs[i]) {
1868 Value replacement = rewriter.createOrFold<ReplicateOp>(
1869 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1870 return flattenConcat(i - 1, i, replacement);
1871 }
1872 }
1873
1874 // Merge neighboring extracts of neighboring inputs, e.g.
1875 // {A[3], A[2]} -> A[3:2]
1876 if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1877 if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1878 if (extract.getInput() == prevExtract.getInput()) {
1879 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1880 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1881 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1882 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1883 Value replacement =
1884 ExtractOp::create(rewriter, op.getLoc(), resType,
1885 extract.getInput(), extract.getLowBit());
1886 return flattenConcat(i - 1, i, replacement);
1887 }
1888 }
1889 }
1890 }
1891 // Merge neighboring array extracts of neighboring inputs, e.g.
1892 // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
1893
1894 // This represents a slice of an array.
1895 struct ArraySlice {
1896 Value input;
1897 Value index;
1898 size_t width;
1899 static std::optional<ArraySlice> get(Value value) {
1900 assert(isa<IntegerType>(value.getType()) && "expected integer type");
1901 if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1902 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1903 // array slice op is wrapped with bitcast.
1904 if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1905 if (auto arraySlice =
1906 bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1907 return ArraySlice{
1908 arraySlice.getInput(), arraySlice.getLowIndex(),
1909 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1910 .getNumElements()};
1911 return std::nullopt;
1912 }
1913 };
1914 if (auto extractOpt = ArraySlice::get(inputs[i])) {
1915 if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1916 // Check that two array slices are mergable.
1917 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1918 prevExtractOpt->input == extractOpt->input &&
1919 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1920 extractOpt->width)) {
1921 auto resType = hw::ArrayType::get(
1922 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1923 .getElementType(),
1924 extractOpt->width + prevExtractOpt->width);
1925 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1926 Value replacement = hw::BitcastOp::create(
1927 rewriter, op.getLoc(), resIntType,
1928 hw::ArraySliceOp::create(rewriter, op.getLoc(), resType,
1929 prevExtractOpt->input,
1930 extractOpt->index));
1931 return flattenConcat(i - 1, i, replacement);
1932 }
1933 }
1934 }
1935 }
1936 }
1937
1938 // If all operands were the same, then this is a replicate.
1939 if (commonOperand) {
1940 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1941 commonOperand);
1942 return success();
1943 }
1944
1945 return failure();
1946}
1947
1948//===----------------------------------------------------------------------===//
1949// MuxOp
1950//===----------------------------------------------------------------------===//
1951
1952OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1953 if (isOpTriviallyRecursive(*this))
1954 return {};
1955
1956 // mux (c, b, b) -> b
1957 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1958 return getTrueValue();
1959 if (auto tv = adaptor.getTrueValue())
1960 if (tv == adaptor.getFalseValue())
1961 return tv;
1962
1963 // mux(0, a, b) -> b
1964 // mux(1, a, b) -> a
1965 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1966 if (pred.getValue().isZero() && getFalseValue() != getResult())
1967 return getFalseValue();
1968 if (pred.getValue().isOne() && getTrueValue() != getResult())
1969 return getTrueValue();
1970 }
1971
1972 // mux(cond, 1, 0) -> cond
1973 if (getCond().getType() == getTrueValue().getType())
1974 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1975 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1976 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1977 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1978 return getCond();
1979
1980 return {};
1981}
1982
1983/// Check to see if the condition to the specified mux is an equality
1984/// comparison `indexValue` and one or more constants. If so, put the
1985/// constants in the constants vector and return true, otherwise return false.
1986///
1987/// This is part of foldMuxChain.
1988///
1989static bool
1990getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1991 std::function<void(hw::ConstantOp)> constantFn) {
1992 // Handle `idx == 42` and `idx != 42`.
1993 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1994 // TODO: We could handle things like "x < 2" as two entries.
1995 auto requiredPredicate =
1996 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1997 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1998 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1999 constantFn(cst);
2000 return true;
2001 }
2002 }
2003 return false;
2004 }
2005
2006 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
2007 if (auto orOp = cond.getDefiningOp<OrOp>()) {
2008 if (!isInverted)
2009 return false;
2010 for (auto operand : orOp.getOperands())
2011 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2012 return false;
2013 return true;
2014 }
2015
2016 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
2017 if (auto andOp = cond.getDefiningOp<AndOp>()) {
2018 if (isInverted)
2019 return false;
2020 for (auto operand : andOp.getOperands())
2021 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
2022 return false;
2023 return true;
2024 }
2025
2026 return false;
2027}
2028
2029/// Given a mux, check to see if the "on true" value (or "on false" value if
2030/// isFalseSide=true) is a mux tree with the same condition. This allows us
2031/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
2032/// `array_get (array_create(A, B, C), VAL)` or a balanced mux tree which is far
2033/// more compact and allows synthesis tools to do more interesting
2034/// optimizations.
2035///
2036/// This returns false if we cannot form the mux tree (or do not want to) and
2037/// returns true if the mux was replaced.
2039 PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide,
2040 llvm::function_ref<MuxChainWithComparisonFoldingStyle(size_t indexWidth,
2041 size_t numEntries)>
2042 styleFn) {
2043 // Get the index value being compared. Later we check to see if it is
2044 // compared to a constant with the right predicate.
2045 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
2046 if (!rootCmp)
2047 return false;
2048 Value indexValue = rootCmp.getLhs();
2049
2050 // Return the value to use if the equality match succeeds.
2051 auto getCaseValue = [&](MuxOp mux) -> Value {
2052 return mux.getOperand(1 + unsigned(!isFalseSide));
2053 };
2054
2055 // Return the value to use if the equality match fails. This is the next
2056 // mux in the sequence or the "otherwise" value.
2057 auto getTreeValue = [&](MuxOp mux) -> Value {
2058 return mux.getOperand(1 + unsigned(isFalseSide));
2059 };
2060
2061 // Start scanning the mux tree to see what we've got. Keep track of the
2062 // constant comparison value and the SSA value to use when equal to it.
2063 SmallVector<Location> locationsFound;
2064 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2065
2066 /// Extract constants and values into `valuesFound` and return true if this is
2067 /// part of the mux tree, otherwise return false.
2068 auto collectConstantValues = [&](MuxOp mux) -> bool {
2070 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2071 valuesFound.push_back({cst, getCaseValue(mux)});
2072 locationsFound.push_back(mux.getCond().getLoc());
2073 locationsFound.push_back(mux->getLoc());
2074 });
2075 };
2076
2077 // Make sure the root is a correct comparison with a constant.
2078 if (!collectConstantValues(rootMux))
2079 return false;
2080
2081 // Make sure that we're not looking at the intermediate node in a mux tree.
2082 if (rootMux->hasOneUse()) {
2083 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2084 if (getTreeValue(userMux) == rootMux.getResult() &&
2085 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2086 [&](hw::ConstantOp cst) {}))
2087 return false;
2088 }
2089 }
2090
2091 // Scan up the tree linearly.
2092 auto nextTreeValue = getTreeValue(rootMux);
2093 while (1) {
2094 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2095 if (!nextMux || !nextMux->hasOneUse())
2096 break;
2097 if (!collectConstantValues(nextMux))
2098 break;
2099 nextTreeValue = getTreeValue(nextMux);
2100 }
2101
2102 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2103
2104 if (indexWidth > 20)
2105 return false; // Too big to make a table.
2106
2107 auto foldingStyle = styleFn(indexWidth, valuesFound.size());
2108 if (foldingStyle == MuxChainWithComparisonFoldingStyle::None)
2109 return false;
2110
2111 uint64_t tableSize = 1ULL << indexWidth;
2112
2113 // Ok, we're going to do the transformation, start by building the table
2114 // filled with the "otherwise" value.
2115 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2116
2117 // Fill in entries in the table from the leaf to the root of the expression.
2118 // This ensures that any duplicate matches end up with the ultimate value,
2119 // which is the one closer to the root.
2120 for (auto &elt : llvm::reverse(valuesFound)) {
2121 uint64_t idx = elt.first.getValue().getZExtValue();
2122 assert(idx < table.size() && "constant should be same bitwidth as index");
2123 table[idx] = elt.second;
2124 }
2125
2127 SmallVector<Value> bits;
2128 comb::extractBits(rewriter, indexValue, bits);
2129 auto result = constructMuxTree(rewriter, rootMux->getLoc(), bits, table,
2130 nextTreeValue);
2131 replaceOpAndCopyNamehint(rewriter, rootMux, result);
2132 return true;
2133 }
2134
2136 "unknown folding style");
2137
2138 // The hw.array_create operation has the operand list in unintuitive order
2139 // with a[0] stored as the last element, not the first.
2140 std::reverse(table.begin(), table.end());
2141
2142 // Build the array_create and the array_get.
2143 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2144 auto array = hw::ArrayCreateOp::create(rewriter, fusedLoc, table);
2145 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2146 indexValue);
2147 return true;
2148}
2149
2150/// Given a fully associative variadic operation like (a+b+c+d), break the
2151/// expression into two parts, one without the specified operand (e.g.
2152/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2153/// `tmp+c`), and return the inner expression.
2154///
2155/// NOTE: This mutates the operation in place if it only has a single user,
2156/// which assumes that user will be removed.
2157///
2158static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2159 size_t operandNo,
2160 PatternRewriter &rewriter) {
2161 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2162 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2163
2164 // If this expression already has two operands (the common case) no splitting
2165 // is necessary.
2166 if (fullyAssoc->getNumOperands() == 2)
2167 return fullyAssoc->getOperand(operandNo ^ 1);
2168
2169 // If the operation has a single use, mutate it in place.
2170 if (fullyAssoc->hasOneUse()) {
2171 rewriter.modifyOpInPlace(fullyAssoc,
2172 [&]() { fullyAssoc->eraseOperand(operandNo); });
2173 return fullyAssoc->getResult(0);
2174 }
2175
2176 // Form the new operation with the operands that remain.
2177 SmallVector<Value> operands;
2178 operands.append(fullyAssoc->getOperands().begin(),
2179 fullyAssoc->getOperands().begin() + operandNo);
2180 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2181 fullyAssoc->getOperands().end());
2182 Value opWithoutExcluded = createGenericOp(
2183 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2184 Value excluded = fullyAssoc->getOperand(operandNo);
2185
2186 Value fullResult =
2187 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2188 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2189 replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
2190 return opWithoutExcluded;
2191}
2192
2193/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2194/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2195/// is true. Return true on successful transformation, false if not.
2196///
2197/// These are various forms of "predicated ops" that can be handled with a
2198/// replicate/and combination.
2199static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2200 PatternRewriter &rewriter) {
2201 // Check to see the operand in question is an operation. If it is a port,
2202 // we can't simplify it.
2203 Operation *subExpr =
2204 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2205 if (!subExpr || subExpr->getNumOperands() < 2)
2206 return false;
2207
2208 // If this isn't an operation we can handle, don't spend energy on it.
2209 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2210 return false;
2211
2212 // Check to see if the common value occurs in the operand list for the
2213 // subexpression op. If so, then we can simplify it.
2214 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2215 size_t opNo = 0, e = subExpr->getNumOperands();
2216 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2217 ++opNo;
2218 if (opNo == e)
2219 return false;
2220
2221 // If we got a hit, then go ahead and simplify it!
2222 Value cond = op.getCond();
2223
2224 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2225 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2226 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2227 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2228 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2229 if (subMux == op)
2230 return false;
2231
2232 Value otherValue;
2233 Value subCond = subMux.getCond();
2234
2235 // Invert th subCond if needed and dig out the 'b' value.
2236 if (subMux.getTrueValue() == commonValue)
2237 otherValue = subMux.getFalseValue();
2238 else if (subMux.getFalseValue() == commonValue) {
2239 otherValue = subMux.getTrueValue();
2240 subCond = createOrFoldNot(rewriter, op.getLoc(), subCond);
2241 } else {
2242 // We can't fold `mux(cond, a, mux(a, x, y))`.
2243 return false;
2244 }
2245
2246 // Invert the outer cond if needed, and combine the mux conditions.
2247 if (!isTrueOperand)
2248 cond = createOrFoldNot(rewriter, op.getLoc(), cond);
2249 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2250 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2251 otherValue, op.getTwoState());
2252 return true;
2253 }
2254
2255 // Invert the condition if needed. Or/Xor invert when dealing with
2256 // TrueOperand, And inverts for False operand.
2257 bool isaAndOp = isa<AndOp>(subExpr);
2258 if (isTrueOperand ^ isaAndOp)
2259 cond = createOrFoldNot(rewriter, op.getLoc(), cond);
2260
2261 auto extendedCond =
2262 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2263
2264 // Cache this information before subExpr is erased by extraction below.
2265 bool isaXorOp = isa<XorOp>(subExpr);
2266 bool isaOrOp = isa<OrOp>(subExpr);
2267
2268 // Handle the fully associative ops, start by pulling out the subexpression
2269 // from a many operand version of the op.
2270 auto restOfAssoc =
2271 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2272
2273 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2274 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2275 if (isaOrOp || isaXorOp) {
2276 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2277 restOfAssoc, false);
2278 if (isaXorOp)
2279 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2280 commonValue, false);
2281 else
2282 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2283 false);
2284 return true;
2285 }
2286
2287 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2288 assert(isaAndOp && "unexpected operation here");
2289 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2290 restOfAssoc, false);
2291 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2292 false);
2293 return true;
2294}
2295
2296/// This function is invoke when we find a mux with true/false operations that
2297/// have the same opcode. Check to see if we can strength reduce the mux by
2298/// applying it to less data by applying this transformation:
2299/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2300static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2301 Operation *falseOp,
2302 PatternRewriter &rewriter) {
2303 // Right now we only apply to concat.
2304 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2305 if (!isa<ConcatOp>(trueOp))
2306 return false;
2307
2308 // Decode the operands, looking through recursive concats and replicates.
2309 SmallVector<Value> trueOperands, falseOperands;
2310 getConcatOperands(trueOp->getResult(0), trueOperands);
2311 getConcatOperands(falseOp->getResult(0), falseOperands);
2312
2313 size_t numTrueOperands = trueOperands.size();
2314 size_t numFalseOperands = falseOperands.size();
2315
2316 if (!numTrueOperands || !numFalseOperands ||
2317 (trueOperands.front() != falseOperands.front() &&
2318 trueOperands.back() != falseOperands.back()))
2319 return false;
2320
2321 // Pull all leading shared operands out into their own op if any are common.
2322 if (trueOperands.front() == falseOperands.front()) {
2323 SmallVector<Value> operands;
2324 size_t i;
2325 for (i = 0; i < numTrueOperands; ++i) {
2326 Value trueOperand = trueOperands[i];
2327 if (trueOperand == falseOperands[i])
2328 operands.push_back(trueOperand);
2329 else
2330 break;
2331 }
2332 if (i == numTrueOperands) {
2333 // Selecting between distinct, but lexically identical, concats.
2334 replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
2335 return true;
2336 }
2337
2338 Value sharedMSB;
2339 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2340 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2341 mux->getLoc(), operands.front(), operands.size());
2342 else
2343 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2344 operands.clear();
2345
2346 // Get a concat of the LSB's on each side.
2347 operands.append(trueOperands.begin() + i, trueOperands.end());
2348 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2349 operands.clear();
2350 operands.append(falseOperands.begin() + i, falseOperands.end());
2351 Value falseLSB =
2352 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2353 // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2354 // done.
2355 Value lsb = rewriter.createOrFold<MuxOp>(
2356 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2357 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2358 return true;
2359 }
2360
2361 // If trailing operands match, try to commonize them.
2362 if (trueOperands.back() == falseOperands.back()) {
2363 SmallVector<Value> operands;
2364 size_t i;
2365 for (i = 0;; ++i) {
2366 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2367 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2368 operands.push_back(trueOperand);
2369 else
2370 break;
2371 }
2372 std::reverse(operands.begin(), operands.end());
2373 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2374 operands.clear();
2375
2376 // Get a concat of the MSB's on each side.
2377 operands.append(trueOperands.begin(), trueOperands.end() - i);
2378 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2379 operands.clear();
2380 operands.append(falseOperands.begin(), falseOperands.end() - i);
2381 Value falseMSB =
2382 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2383 // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2384 Value msb = rewriter.createOrFold<MuxOp>(
2385 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2386 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2387 return true;
2388 }
2389
2390 return false;
2391}
2392
2393// If both arguments of the mux are arrays with the same elements, sink the
2394// mux and return a uniform array initializing all elements to it.
2395static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2396 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2397 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2398 if (!trueVec || !falseVec)
2399 return false;
2400 if (!trueVec.isUniform() || !falseVec.isUniform())
2401 return false;
2402
2403 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2404 trueVec.getUniformElement(),
2405 falseVec.getUniformElement(), op.getTwoState());
2406
2407 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2408 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2409 return true;
2410}
2411
2412/// If the mux condition is an operand to the op defining its true or false
2413/// value, replace the condition with 1 or 0.
2414static bool assumeMuxCondInOperand(Value muxCond, Value muxValue,
2415 bool constCond, PatternRewriter &rewriter) {
2416 if (!muxValue.hasOneUse())
2417 return false;
2418 auto *op = muxValue.getDefiningOp();
2419 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2420 return false;
2421 if (!llvm::is_contained(op->getOperands(), muxCond))
2422 return false;
2423 OpBuilder::InsertionGuard guard(rewriter);
2424 rewriter.setInsertionPoint(op);
2425 auto condValue =
2426 hw::ConstantOp::create(rewriter, muxCond.getLoc(), APInt(1, constCond));
2427 rewriter.modifyOpInPlace(op, [&] {
2428 for (auto &use : op->getOpOperands())
2429 if (use.get() == muxCond)
2430 use.set(condValue);
2431 });
2432 return true;
2433}
2434
2435namespace {
2436struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2437 using OpRewritePattern::OpRewritePattern;
2438
2439 LogicalResult matchAndRewrite(MuxOp op,
2440 PatternRewriter &rewriter) const override;
2441};
2442
2444foldToArrayCreateOnlyWhenDense(size_t indexWidth, size_t numEntries) {
2445 // If the array is greater that 9 bits, it will take over 512 elements and
2446 // it will be too large for a single expression.
2447 if (indexWidth >= 9 || numEntries < 3)
2449
2450 // Next we need to see if the values are dense-ish. We don't want to have
2451 // a tremendous number of replicated entries in the array. Some sparsity is
2452 // ok though, so we require the table to be at least 5/8 utilized.
2453 uint64_t tableSize = 1ULL << indexWidth;
2454 if (numEntries >= tableSize * 5 / 8)
2457}
2458
2459LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2460 PatternRewriter &rewriter) const {
2461 if (isOpTriviallyRecursive(op))
2462 return failure();
2463
2464 bool isSignlessInt = false;
2465 if (auto intType = dyn_cast<IntegerType>(op.getType()))
2466 isSignlessInt = intType.isSignless();
2467
2468 // If the op has a SV attribute, don't optimize it.
2469 if (hasSVAttributes(op))
2470 return failure();
2471 APInt value;
2472
2473 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value)) && isSignlessInt) {
2474 if (value.getBitWidth() == 1) {
2475 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2476 if (value.isZero()) {
2477 auto notCond = createOrFoldNot(rewriter, op.getLoc(), op.getCond());
2478 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2479 op.getFalseValue(), false);
2480 return success();
2481 }
2482
2483 // mux(a, 1, b) -> or(a, b) for single-bit values.
2484 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2485 op.getFalseValue(), false);
2486 return success();
2487 }
2488
2489 // Check for mux of two constants. There are many ways to simplify them.
2490 APInt value2;
2491 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2492 // When both inputs are constants and differ by only one bit, we can
2493 // simplify by splitting the mux into up to three contiguous chunks: one
2494 // for the differing bit and up to two for the bits that are the same.
2495 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2496 APInt xorValue = value ^ value2;
2497 if (xorValue.isPowerOf2()) {
2498 unsigned leadingZeros = xorValue.countLeadingZeros();
2499 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2500 SmallVector<Value, 3> operands;
2501
2502 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2503 // order of bit indexes.
2504 // For the chunks that are identical (i.e. correspond to 0s in
2505 // xorValue), we can extract directly from either input value, and we
2506 // arbitrarily pick the trueValue().
2507
2508 if (leadingZeros > 0)
2509 operands.push_back(rewriter.createOrFold<ExtractOp>(
2510 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2511
2512 // Handle the differing bit, which should simplify into either cond or
2513 // ~cond.
2514 auto v1 = rewriter.createOrFold<ExtractOp>(
2515 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2516 auto v2 = rewriter.createOrFold<ExtractOp>(
2517 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2518 operands.push_back(rewriter.createOrFold<MuxOp>(
2519 op.getLoc(), op.getCond(), v1, v2, false));
2520
2521 if (trailingZeros > 0)
2522 operands.push_back(rewriter.createOrFold<ExtractOp>(
2523 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2524
2525 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2526 operands);
2527 return success();
2528 }
2529
2530 // If the true value is all ones and the false is all zeros then we have a
2531 // replicate pattern.
2532 if (value.isAllOnes() && value2.isZero()) {
2533 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2534 rewriter, op, op.getType(), op.getCond());
2535 return success();
2536 }
2537 }
2538 }
2539
2540 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2541 isSignlessInt && value.getBitWidth() == 1) {
2542 // mux(a, b, 0) -> and(a, b) for single-bit values.
2543 if (value.isZero()) {
2544 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2545 op.getTrueValue(), false);
2546 return success();
2547 }
2548
2549 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2550 // falseValue() is known to be a single-bit 1, which we can use for
2551 // the 1 in the representation of ~ using xor.
2552 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2553 op.getFalseValue(), false);
2554 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2555 op.getTrueValue(), false);
2556 return success();
2557 }
2558
2559 // mux(!a, b, c) -> mux(a, c, b)
2560 Value subExpr;
2561 Operation *condOp = op.getCond().getDefiningOp();
2562 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2563 op.getTwoState()) {
2564 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2565 subExpr, op.getFalseValue(),
2566 op.getTrueValue(), true);
2567 return success();
2568 }
2569
2570 // Same but with Demorgan's law.
2571 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2572 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2573 if (condOp && condOp->hasOneUse()) {
2574 SmallVector<Value> invertedOperands;
2575
2576 /// Scan all the operands to see if they are complemented. If so, build a
2577 /// vector of them and return true, otherwise return false.
2578 auto getInvertedOperands = [&]() -> bool {
2579 for (Value operand : condOp->getOperands()) {
2580 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2581 invertedOperands.push_back(subExpr);
2582 else
2583 return false;
2584 }
2585 return true;
2586 };
2587
2588 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2589 auto newOr =
2590 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2591 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2592 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2593 op.getTwoState());
2594 return success();
2595 }
2596 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2597 auto newAnd =
2598 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2599 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2600 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2601 op.getTwoState());
2602 return success();
2603 }
2604 }
2605
2606 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2607 falseMux && falseMux != op) {
2608 // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2609 if (op.getCond() == falseMux.getCond() &&
2610 falseMux.getFalseValue() != falseMux) {
2611 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2612 rewriter, op, op.getCond(), op.getTrueValue(),
2613 falseMux.getFalseValue(), op.getTwoStateAttr());
2614 return success();
2615 }
2616
2617 // Check to see if we can fold a mux tree into an array_create/get pair.
2618 if (foldMuxChainWithComparison(rewriter, op, /*isFalse*/ true,
2619 foldToArrayCreateOnlyWhenDense))
2620 return success();
2621 }
2622
2623 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2624 trueMux && trueMux != op) {
2625 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2626 if (op.getCond() == trueMux.getCond()) {
2627 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2628 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2629 op.getFalseValue(), op.getTwoStateAttr());
2630 return success();
2631 }
2632
2633 // Check to see if we can fold a mux tree into an array_create/get pair.
2634 if (foldMuxChainWithComparison(rewriter, op, /*isFalseSide*/ false,
2635 foldToArrayCreateOnlyWhenDense))
2636 return success();
2637 }
2638
2639 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2640 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2641 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2642 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2643 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2644 falseMux != op) {
2645 auto subMux = MuxOp::create(
2646 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2647 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2648 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2649 trueMux.getTrueValue(), subMux,
2650 op.getTwoStateAttr());
2651 return success();
2652 }
2653
2654 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2655 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2656 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2657 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2658 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2659 falseMux != op) {
2660 auto subMux = MuxOp::create(
2661 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2662 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2663 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2664 subMux, trueMux.getFalseValue(),
2665 op.getTwoStateAttr());
2666 return success();
2667 }
2668
2669 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2670 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2671 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2672 trueMux && falseMux &&
2673 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2674 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2675 falseMux != op) {
2676 auto subMux =
2677 MuxOp::create(rewriter,
2678 rewriter.getFusedLoc(
2679 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2680 op.getCond(), trueMux.getCond(), falseMux.getCond());
2681 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2682 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2683 op.getTwoStateAttr());
2684 return success();
2685 }
2686
2687 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2688 if (foldCommonMuxValue(op, false, rewriter))
2689 return success();
2690 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2691 if (foldCommonMuxValue(op, true, rewriter))
2692 return success();
2693
2694 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2695 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2696 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2697 if (trueOp->getName() == falseOp->getName())
2698 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2699 return success();
2700
2701 // extracts only of mux(...) -> mux(extract()...)
2702 if (narrowOperationWidth(op, true, rewriter))
2703 return success();
2704
2705 // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2706 if (foldMuxOfUniformArrays(op, rewriter))
2707 return success();
2708
2709 // mux(cond, opA(cond), opB(cond)) -> mux(cond, opA(1), opB(0))
2710 if (op.getTrueValue().getDefiningOp() &&
2711 op.getTrueValue().getDefiningOp() != op)
2712 if (assumeMuxCondInOperand(op.getCond(), op.getTrueValue(), true, rewriter))
2713 return success();
2714 if (op.getFalseValue().getDefiningOp() &&
2715 op.getFalseValue().getDefiningOp() != op)
2716
2717 if (assumeMuxCondInOperand(op.getCond(), op.getFalseValue(), false,
2718 rewriter))
2719 return success();
2720
2721 return failure();
2722}
2723
2724static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2725 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2726 if (op.getInputs().empty() || op.isUniform())
2727 return false;
2728 auto inputs = op.getInputs();
2729 if (inputs.size() <= 1)
2730 return false;
2731
2732 // Check the operands to the array create. Ensure all of them are the
2733 // same op with the same number of operands.
2734 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2735 if (!first || hasSVAttributes(first))
2736 return false;
2737
2738 // Check whether all operands are muxes with the same condition.
2739 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2740 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2741 if (!input || first.getCond() != input.getCond())
2742 return false;
2743 }
2744
2745 // Collect the true and the false branches into arrays.
2746 SmallVector<Value> trues{first.getTrueValue()};
2747 SmallVector<Value> falses{first.getFalseValue()};
2748 SmallVector<Location> locs{first->getLoc()};
2749 bool isTwoState = true;
2750 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2751 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2752 trues.push_back(input.getTrueValue());
2753 falses.push_back(input.getFalseValue());
2754 locs.push_back(input->getLoc());
2755 if (!input.getTwoState())
2756 isTwoState = false;
2757 }
2758
2759 // Define the location of the array create as the aggregate of all muxes.
2760 auto loc = FusedLoc::get(op.getContext(), locs);
2761
2762 // Replace the create with an aggregate operation. Push the create op
2763 // into the operands of the aggregate operation.
2764 auto arrayTy = op.getType();
2765 auto trueValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, trues);
2766 auto falseValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, falses);
2767 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2768 trueValues, falseValues, isTwoState);
2769 return true;
2770}
2771
2772struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2773 using OpRewritePattern::OpRewritePattern;
2774
2775 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2776 PatternRewriter &rewriter) const override {
2777 if (foldArrayOfMuxes(op, rewriter))
2778 return success();
2779 return failure();
2780 }
2781};
2782
2783} // namespace
2784
2785void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2786 MLIRContext *context) {
2787 results.insert<MuxRewriter, ArrayRewriter>(context);
2788}
2789
2790//===----------------------------------------------------------------------===//
2791// ICmpOp
2792//===----------------------------------------------------------------------===//
2793
2794// Calculate the result of a comparison when the LHS and RHS are both
2795// constants.
2796static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2797 const APInt &rhs) {
2798 switch (predicate) {
2799 case ICmpPredicate::eq:
2800 return lhs.eq(rhs);
2801 case ICmpPredicate::ne:
2802 return lhs.ne(rhs);
2803 case ICmpPredicate::slt:
2804 return lhs.slt(rhs);
2805 case ICmpPredicate::sle:
2806 return lhs.sle(rhs);
2807 case ICmpPredicate::sgt:
2808 return lhs.sgt(rhs);
2809 case ICmpPredicate::sge:
2810 return lhs.sge(rhs);
2811 case ICmpPredicate::ult:
2812 return lhs.ult(rhs);
2813 case ICmpPredicate::ule:
2814 return lhs.ule(rhs);
2815 case ICmpPredicate::ugt:
2816 return lhs.ugt(rhs);
2817 case ICmpPredicate::uge:
2818 return lhs.uge(rhs);
2819 case ICmpPredicate::ceq:
2820 return lhs.eq(rhs);
2821 case ICmpPredicate::cne:
2822 return lhs.ne(rhs);
2823 case ICmpPredicate::weq:
2824 return lhs.eq(rhs);
2825 case ICmpPredicate::wne:
2826 return lhs.ne(rhs);
2827 }
2828 llvm_unreachable("unknown comparison predicate");
2829}
2830
2831// Returns the result of applying the predicate when the LHS and RHS are the
2832// exact same value.
2833static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2834 switch (predicate) {
2835 case ICmpPredicate::eq:
2836 case ICmpPredicate::sle:
2837 case ICmpPredicate::sge:
2838 case ICmpPredicate::ule:
2839 case ICmpPredicate::uge:
2840 case ICmpPredicate::ceq:
2841 case ICmpPredicate::weq:
2842 return true;
2843 case ICmpPredicate::ne:
2844 case ICmpPredicate::slt:
2845 case ICmpPredicate::sgt:
2846 case ICmpPredicate::ult:
2847 case ICmpPredicate::ugt:
2848 case ICmpPredicate::cne:
2849 case ICmpPredicate::wne:
2850 return false;
2851 }
2852 llvm_unreachable("unknown comparison predicate");
2853}
2854
2855OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2856 // gt a, a -> false
2857 // gte a, a -> true
2858 if (getLhs() == getRhs()) {
2859 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2860 return IntegerAttr::get(getType(), val);
2861 }
2862
2863 // gt 1, 2 -> false
2864 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2865 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2866 auto val =
2867 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2868 return IntegerAttr::get(getType(), val);
2869 }
2870 }
2871 return {};
2872}
2873
// Count how many leading elements of `a` and `b` are pairwise equal, stopping
// at the first mismatch or when either range is exhausted. Elements are only
// compared position by position (no cross-element matching). Passing reversed
// ranges yields the length of the common suffix instead.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  size_t matched = 0;
  auto lhsIt = a.begin();
  auto rhsIt = b.begin();

  while (lhsIt != a.end() && rhsIt != b.end()) {
    if (*lhsIt != *rhsIt)
      return matched;
    ++lhsIt;
    ++rhsIt;
    ++matched;
  }

  return matched;
}
2890
2891static size_t getTotalWidth(ArrayRef<Value> operands) {
2892 size_t totalWidth = 0;
2893 for (auto operand : operands) {
2894 // getIntOrFloatBitWidth should never raise, since all arguments to
2895 // ConcatOp are integers.
2896 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2897 assert(width >= 0);
2898 totalWidth += width;
2899 }
2900 return totalWidth;
2901}
2902
2903/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
2904/// comparison on common prefix and suffixes. Returns success() if a rewriting
2905/// happens. This handles both concat and replicate.
2906static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2907 Operation *rhs,
2908 PatternRewriter &rewriter) {
2909 // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2910 // all elements have non-zero length. Both these invariants are verified
2911 // by the ConcatOp verifier.
2912 SmallVector<Value> lhsOperands, rhsOperands;
2913 getConcatOperands(lhs->getResult(0), lhsOperands);
2914 getConcatOperands(rhs->getResult(0), rhsOperands);
2915 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2916
2917 auto formCatOrReplicate = [&](Location loc,
2918 ArrayRef<Value> operands) -> Value {
2919 assert(!operands.empty());
2920 Value sameElement = operands[0];
2921 for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2922 if (sameElement != operands[i])
2923 sameElement = Value();
2924 if (sameElement)
2925 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2926 operands.size());
2927 return rewriter.createOrFold<ConcatOp>(loc, operands);
2928 };
2929
2930 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2931 Value rhs) -> LogicalResult {
2932 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2933 op.getTwoState());
2934 return success();
2935 };
2936
2937 size_t commonPrefixLength =
2938 computeCommonPrefixLength(lhsOperands, rhsOperands);
2939 if (commonPrefixLength == lhsOperands.size()) {
2940 // cat(a, b, c) == cat(a, b, c) -> 1
2941 bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2942 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2943 APInt(1, result));
2944 return success();
2945 }
2946
2947 size_t commonSuffixLength = computeCommonPrefixLength(
2948 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2949
2950 size_t commonPrefixTotalWidth =
2951 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2952 size_t commonSuffixTotalWidth =
2953 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2954 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2955 .drop_back(commonSuffixLength);
2956 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2957 .drop_back(commonSuffixLength);
2958
2959 auto replaceWithoutReplicatingSignBit = [&]() {
2960 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2961 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2962 return replaceWith(op.getPredicate(), newLhs, newRhs);
2963 };
2964
2965 auto replaceWithReplicatingSignBit = [&]() {
2966 auto firstNonEmptyValue = lhsOperands[0];
2967 auto firstNonEmptyElemWidth =
2968 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2969 Value signBit = rewriter.createOrFold<ExtractOp>(
2970 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2971
2972 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2973 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2974 return replaceWith(op.getPredicate(), newLhs, newRhs);
2975 };
2976
2977 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2978 // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2979 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2980 return replaceWithoutReplicatingSignBit();
2981
2982 // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2983 // cat(sgn(b), ..y)) Note that we cannot perform this optimization if
2984 // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
2985 // bit. Doing the rewrite can result in an infinite loop.
2986 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2987 return replaceWithReplicatingSignBit();
2988
2989 } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2990 // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
2991 return replaceWithoutReplicatingSignBit();
2992 }
2993
2994 return failure();
2995}
2996
2997/// Given an equality comparison with a constant value and some operand that has
2998/// known bits, simplify the comparison to check only the unknown bits of the
2999/// input.
3000///
3001/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
3002/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
3003/// `extract x[1:0] == 0`
3005 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
3006 PatternRewriter &rewriter) {
3007
3008 // If any of the known bits disagree with any of the comparison bits, then
3009 // we can constant fold this comparison right away.
3010 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
3011 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
3012 // If we discover a mismatch then we know an "eq" comparison is false
3013 // and a "ne" comparison is true!
3014 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
3015 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3016 APInt(1, result));
3017 return;
3018 }
3019
3020 // Check to see if we can prove the result entirely of the comparison (in
3021 // which we bail out early), otherwise build a list of values to concat and a
3022 // smaller constant to compare against.
3023 SmallVector<Value> newConcatOperands;
3024 auto newConstant = APInt::getZeroWidth();
3025
3026 // Ok, some (maybe all) bits are known and some others may be unknown.
3027 // Extract out segments of the operand and compare against the
3028 // corresponding bits.
3029 unsigned knownMSB = bitsKnown.countLeadingOnes();
3030
3031 Value operand = cmpOp.getLhs();
3032
3033 // Ok, some bits are known but others are not. Extract out sequences of
3034 // bits that are unknown and compare just those bits. We work from MSB to
3035 // LSB.
3036 while (knownMSB != bitsKnown.getBitWidth()) {
3037 // Drop any high bits that are known.
3038 if (knownMSB)
3039 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
3040
3041 // Find the span of unknown bits, and extract it.
3042 unsigned unknownBits = bitsKnown.countLeadingZeros();
3043 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
3044 auto spanOperand = rewriter.createOrFold<ExtractOp>(
3045 operand.getLoc(), operand, /*lowBit=*/lowBit,
3046 /*bitWidth=*/unknownBits);
3047 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
3048
3049 // Add this info to the concat we're generating.
3050 newConcatOperands.push_back(spanOperand);
3051 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
3052 // newConstant = newConstant.concat(spanConstant);
3053 if (newConstant.getBitWidth() != 0)
3054 newConstant = newConstant.concat(spanConstant);
3055 else
3056 newConstant = spanConstant;
3057
3058 // Drop the unknown bits in prep for the next chunk.
3059 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
3060 bitsKnown = bitsKnown.trunc(newWidth);
3061 knownMSB = bitsKnown.countLeadingOnes();
3062 }
3063
3064 // If all the operands to the concat are foldable then we have an identity
3065 // situation where all the sub-elements equal each other. This implies that
3066 // the overall result is foldable.
3067 if (newConcatOperands.empty()) {
3068 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
3069 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
3070 APInt(1, result));
3071 return;
3072 }
3073
3074 // If we have a single operand remaining, use it, otherwise form a concat.
3075 Value concatResult =
3076 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
3077
3078 // Form the comparison against the smaller constant.
3079 auto newConstantOp = hw::ConstantOp::create(
3080 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
3081
3082 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
3083 cmpOp.getPredicate(), concatResult,
3084 newConstantOp, cmpOp.getTwoState());
3085}
3086
3087// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
3088static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
3089 const APInt &rhs,
3090 PatternRewriter &rewriter) {
3091 auto ip = rewriter.saveInsertionPoint();
3092 rewriter.setInsertionPoint(xorOp);
3093
3094 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
3095 auto newRHS = hw::ConstantOp::create(rewriter, xorRHS->getLoc(),
3096 xorRHS.getValue() ^ rhs);
3097 Value newLHS;
3098 switch (xorOp.getNumOperands()) {
3099 case 1:
3100 // This isn't common but is defined so we need to handle it.
3101 newLHS = hw::ConstantOp::create(rewriter, xorOp.getLoc(),
3102 APInt::getZero(rhs.getBitWidth()));
3103 break;
3104 case 2:
3105 // The binary case is the most common.
3106 newLHS = xorOp.getOperand(0);
3107 break;
3108 default:
3109 // The general case forces us to form a new xor with the remaining operands.
3110 SmallVector<Value> newOperands(xorOp.getOperands());
3111 newOperands.pop_back();
3112 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands, false);
3113 break;
3114 }
3115
3116 bool xorMultipleUses = !xorOp->hasOneUse();
3117
3118 // If the xor has multiple uses (not just the compare, then we need/want to
3119 // replace them as well.
3120 if (xorMultipleUses)
3121 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3122 false);
3123
3124 // Replace the comparison.
3125 rewriter.restoreInsertionPoint(ip);
3126 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3127 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
3128}
3129
3130LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3131 if (isOpTriviallyRecursive(op))
3132 return failure();
3133 APInt lhs, rhs;
3134
3135 // icmp 1, x -> icmp x, 1
3136 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3137 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3138 "Should be folded");
3139 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3140 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3141 op.getRhs(), op.getLhs(), op.getTwoState());
3142 return success();
3143 }
3144
3145 // Canonicalize with RHS constant
3146 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3147 auto getConstant = [&](APInt constant) -> Value {
3148 return hw::ConstantOp::create(rewriter, op.getLoc(), std::move(constant));
3149 };
3150
3151 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3152 Value rhs) -> LogicalResult {
3153 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3154 rhs, op.getTwoState());
3155 return success();
3156 };
3157
3158 auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3159 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3160 APInt(1, constant));
3161 return success();
3162 };
3163
3164 switch (op.getPredicate()) {
3165 case ICmpPredicate::slt:
3166 // x < max -> x != max
3167 if (rhs.isMaxSignedValue())
3168 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3169 // x < min -> false
3170 if (rhs.isMinSignedValue())
3171 return replaceWithConstantI1(0);
3172 // x < min+1 -> x == min
3173 if ((rhs - 1).isMinSignedValue())
3174 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3175 getConstant(rhs - 1));
3176 break;
3177 case ICmpPredicate::sgt:
3178 // x > min -> x != min
3179 if (rhs.isMinSignedValue())
3180 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3181 // x > max -> false
3182 if (rhs.isMaxSignedValue())
3183 return replaceWithConstantI1(0);
3184 // x > max-1 -> x == max
3185 if ((rhs + 1).isMaxSignedValue())
3186 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3187 getConstant(rhs + 1));
3188 break;
3189 case ICmpPredicate::ult:
3190 // x < max -> x != max
3191 if (rhs.isAllOnes())
3192 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3193 // x < min -> false
3194 if (rhs.isZero())
3195 return replaceWithConstantI1(0);
3196 // x < min+1 -> x == min
3197 if ((rhs - 1).isZero())
3198 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3199 getConstant(rhs - 1));
3200
3201 // x < 0xE0 -> extract(x, 5..7) != 0b111
3202 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3203 rhs.getBitWidth()) {
3204 auto numOnes = rhs.countLeadingOnes();
3205 auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
3206 rhs.getBitWidth() - numOnes, numOnes);
3207 return replaceWith(ICmpPredicate::ne, smaller,
3208 getConstant(APInt::getAllOnes(numOnes)));
3209 }
3210
3211 break;
3212 case ICmpPredicate::ugt:
3213 // x > min -> x != min
3214 if (rhs.isZero())
3215 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3216 // x > max -> false
3217 if (rhs.isAllOnes())
3218 return replaceWithConstantI1(0);
3219 // x > max-1 -> x == max
3220 if ((rhs + 1).isAllOnes())
3221 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3222 getConstant(rhs + 1));
3223
3224 // x > 0x07 -> extract(x, 3..7) != 0b00000
3225 if ((rhs + 1).isPowerOf2()) {
3226 auto numOnes = rhs.countTrailingOnes();
3227 auto newWidth = rhs.getBitWidth() - numOnes;
3228 auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
3229 numOnes, newWidth);
3230 return replaceWith(ICmpPredicate::ne, smaller,
3231 getConstant(APInt::getZero(newWidth)));
3232 }
3233
3234 break;
3235 case ICmpPredicate::sle:
3236 // x <= max -> true
3237 if (rhs.isMaxSignedValue())
3238 return replaceWithConstantI1(1);
3239 // x <= c -> x < (c+1)
3240 return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3241 case ICmpPredicate::sge:
3242 // x >= min -> true
3243 if (rhs.isMinSignedValue())
3244 return replaceWithConstantI1(1);
3245 // x >= c -> x > (c-1)
3246 return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3247 case ICmpPredicate::ule:
3248 // x <= max -> true
3249 if (rhs.isAllOnes())
3250 return replaceWithConstantI1(1);
3251 // x <= c -> x < (c+1)
3252 return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3253 case ICmpPredicate::uge:
3254 // x >= min -> true
3255 if (rhs.isZero())
3256 return replaceWithConstantI1(1);
3257 // x >= c -> x > (c-1)
3258 return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3259 case ICmpPredicate::eq:
3260 if (rhs.getBitWidth() == 1) {
3261 if (rhs.isZero()) {
3262 // x == 0 -> x ^ 1
3263 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3264 getConstant(APInt(1, 1)),
3265 op.getTwoState());
3266 return success();
3267 }
3268 if (rhs.isAllOnes()) {
3269 // x == 1 -> x
3270 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3271 return success();
3272 }
3273 }
3274 break;
3275 case ICmpPredicate::ne:
3276 if (rhs.getBitWidth() == 1) {
3277 if (rhs.isZero()) {
3278 // x != 0 -> x
3279 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3280 return success();
3281 }
3282 if (rhs.isAllOnes()) {
3283 // x != 1 -> x ^ 1
3284 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3285 getConstant(APInt(1, 1)),
3286 op.getTwoState());
3287 return success();
3288 }
3289 }
3290 break;
3291 case ICmpPredicate::ceq:
3292 case ICmpPredicate::cne:
3293 case ICmpPredicate::weq:
3294 case ICmpPredicate::wne:
3295 break;
3296 }
3297
3298 // We have some specific optimizations for comparison with a constant that
3299 // are only supported for equality comparisons.
3300 if (op.getPredicate() == ICmpPredicate::eq ||
3301 op.getPredicate() == ICmpPredicate::ne) {
3302 // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3303 // with a smaller constant. We only support equality comparisons for
3304 // this.
3305 auto knownBits = computeKnownBits(op.getLhs());
3306 if (!knownBits.isUnknown())
3307 return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3308 rewriter),
3309 success();
3310
3311 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3312 // cst1^cst2).
3313 if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3314 if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3315 return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3316 success();
3317
3318 // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3319 // all one.
3320 if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3321 if (rhs.isAllOnes() || rhs.isZero()) {
3322 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3323 auto cst =
3324 hw::ConstantOp::create(rewriter, op.getLoc(),
3325 rhs.isAllOnes() ? APInt::getAllOnes(width)
3326 : APInt::getZero(width));
3327 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3328 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3329 op.getTwoState());
3330 return success();
3331 }
3332 }
3333 }
3334
3335 // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3336 // b), cat(c, d)). contains special handling for sign bit in signed
3337 // compressions.
3338 if (Operation *opLHS = op.getLhs().getDefiningOp())
3339 if (Operation *opRHS = op.getRhs().getDefiningOp())
3340 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3341 isa<ConcatOp, ReplicateOp>(opRHS)) {
3342 if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3343 return success();
3344 }
3345
3346 return failure();
3347}
assert(baseType &&"element must be base type")
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition CombFolds.cpp:85
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition CombFolds.cpp:52
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition CombFolds.cpp:46
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition CombFolds.cpp:91
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoke when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
Definition CombFolds.cpp:27
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool hasSVAttributes(Operation *op)
Definition CombFolds.cpp:67
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength of icmp(concat(...), concat(...)) by doing an element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition CombFolds.cpp:38
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements, Type result_type=None)
Definition hw.py:483
create(array_value, low_index, ret_type)
Definition hw.py:466
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:56
void extractBits(OpBuilder &builder, Value val, SmallVectorImpl< Value > &bits)
Extract bits from a value.
Definition CombOps.cpp:79
bool foldMuxChainWithComparison(PatternRewriter &rewriter, MuxOp rootMux, bool isFalseSide, llvm::function_ref< MuxChainWithComparisonFoldingStyle(size_t indexWidth, size_t numEntries)> styleFn)
Mux chain folding that converts chains of muxes with index comparisons into array operations or balan...
Value createOrFoldNot(OpBuilder &builder, Location loc, Value value, bool twoState=false)
Create a "Not" gate on a value.
Definition CombOps.cpp:67
MuxChainWithComparisonFoldingStyle
Enum for mux chain folding styles.
Definition CombOps.h:104
@ BalancedMuxTree
Definition CombOps.h:104
LogicalResult convertModUByPowerOfTwo(ModUOp modOp, mlir::PatternRewriter &rewriter)
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
Value constructMuxTree(OpBuilder &builder, Location loc, ArrayRef< Value > selectors, ArrayRef< Value > leafNodes, Value outOfBoundsValue)
Construct a mux tree for given leaf nodes.
Definition CombOps.cpp:106
Value createOrFoldSExt(OpBuilder &builder, Location loc, Value value, Type destTy)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition CombOps.cpp:44
LogicalResult convertDivUByPowerOfTwo(DivUOp divOp, mlir::PatternRewriter &rewriter)
Convert unsigned division or modulo by a power of two.
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1