Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CombFolds.cpp
Go to the documentation of this file.
1//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/KnownBits.h"
19
20using namespace mlir;
21using namespace circt;
22using namespace comb;
23using namespace matchers;
24
25// Returns true if the op has one of its own results as an operand.
26static bool isOpTriviallyRecursive(Operation *op) {
27 return llvm::any_of(op->getOperands(), [op](auto operand) {
28 return operand.getDefiningOp() == op;
29 });
30}
31
32/// Create a new instance of a generic operation that only has value operands,
33/// and has a single result value whose type matches the first operand.
34///
35/// This should not be used to create instances of ops with attributes or with
36/// more complicated type signatures.
37static Value createGenericOp(Location loc, OperationName name,
38 ArrayRef<Value> operands, OpBuilder &builder) {
39 OperationState state(loc, name);
40 state.addOperands(operands);
41 state.addTypes(operands[0].getType());
42 return builder.create(state)->getResult(0);
43}
44
45static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
46 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
47 value);
48}
49
50/// Flatten concat and mux operands into a vector.
51static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
52 if (auto concat = v.getDefiningOp<ConcatOp>()) {
53 for (auto op : concat.getOperands())
54 getConcatOperands(op, result);
55 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
56 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
57 getConcatOperands(repl.getOperand(), result);
58 } else {
59 result.push_back(v);
60 }
61}
62
63// Return true if the op has SV attributes. Note that we cannot use a helper
64// function `hasSVAttributes` defined under SV dialect because of a cyclic
65// dependency.
66static bool hasSVAttributes(Operation *op) {
67 return op->hasAttr("sv.attributes");
68}
69
70namespace {
71template <typename SubType>
72struct ComplementMatcher {
73 SubType lhs;
74 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
75 bool match(Operation *op) {
76 auto xorOp = dyn_cast<XorOp>(op);
77 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
78 }
79};
80} // end anonymous namespace
81
82template <typename SubType>
83static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
84 return ComplementMatcher<SubType>(subExpr);
85}
86
87/// Return true if the op will be flattened afterwards. Op will be flattend if
88/// it has a single user which has a same op type. User must be in same block.
89static bool shouldBeFlattened(Operation *op) {
90 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
91 "must be commutative operations"));
92 if (op->hasOneUse()) {
93 auto *user = *op->getUsers().begin();
94 return user->getName() == op->getName() &&
95 op->getAttrOfType<UnitAttr>("twoState") ==
96 user->getAttrOfType<UnitAttr>("twoState") &&
97 op->getBlock() == user->getBlock();
98 }
99 return false;
100}
101
102/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
103/// as an Op. Returns true if successful, and false otherwise.
104///
105/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
106///
107static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
108 // Skip if the operation should be flattened by another operation.
109 if (shouldBeFlattened(op))
110 return false;
111
112 auto inputs = op->getOperands();
113
114 SmallVector<Value, 4> newOperands;
115 SmallVector<Location, 4> newLocations{op->getLoc()};
116 newOperands.reserve(inputs.size());
117 struct Element {
118 decltype(inputs.begin()) current, end;
119 };
120
121 SmallVector<Element> worklist;
122 worklist.push_back({inputs.begin(), inputs.end()});
123 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
124 bool changed = false;
125 while (!worklist.empty()) {
126 auto &element = worklist.back(); // Do not pop. Take ref.
127
128 // Pop when we finished traversing the current operand range.
129 if (element.current == element.end) {
130 worklist.pop_back();
131 continue;
132 }
133
134 Value value = *element.current++;
135 auto *flattenOp = value.getDefiningOp();
136 // If not defined by a compatible operation of the same kind and
137 // from the same block, keep this as-is.
138 if (!flattenOp || flattenOp->getName() != op->getName() ||
139 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
140 flattenOp->getBlock() != op->getBlock()) {
141 newOperands.push_back(value);
142 continue;
143 }
144
145 // Don't duplicate logic when it has multiple uses.
146 if (!value.hasOneUse()) {
147 // We can fold a multi-use binary operation into this one if this allows a
148 // constant to fold though. For example, fold
149 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
150 // since the constants will both fold and we end up with the equiv cost.
151 //
152 // We don't do this for add/mul because the hardware won't be shared
153 // between the two ops if duplicated.
154 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
155 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
156 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
157 newOperands.push_back(value);
158 continue;
159 }
160 }
161
162 changed = true;
163
164 // Otherwise, push operands into worklist.
165 auto flattenOpInputs = flattenOp->getOperands();
166 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
167 newLocations.push_back(flattenOp->getLoc());
168 }
169
170 if (!changed)
171 return false;
172
173 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
174 op->getName(), newOperands, rewriter);
175 if (binFlag)
176 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
177
178 replaceOpAndCopyNamehint(rewriter, op, result);
179 return true;
180}
181
182// Given a range of uses of an operation, find the lowest and highest bits
183// inclusive that are ever referenced. The range of uses must not be empty.
184static std::pair<size_t, size_t>
185getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
186 size_t originalOpWidth) {
187 auto users = op->getUsers();
188 assert(!users.empty() &&
189 "getLowestBitAndHighestBitRequired cannot operate on "
190 "a empty list of uses.");
191
192 // when we don't want to narrowTrailingBits (namely in arithmetic
193 // operations), forcing lowestBitRequired = 0
194 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
195 size_t highestBitRequired = 0;
196
197 for (auto *user : users) {
198 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
199 size_t lowBit = extractOp.getLowBit();
200 size_t highBit =
201 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
202 highestBitRequired = std::max(highestBitRequired, highBit);
203 lowestBitRequired = std::min(lowestBitRequired, lowBit);
204 continue;
205 }
206
207 highestBitRequired = originalOpWidth - 1;
208 lowestBitRequired = 0;
209 break;
210 }
211
212 return {lowestBitRequired, highestBitRequired};
213}
214
215template <class OpTy>
216static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
217 PatternRewriter &rewriter) {
218 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
219 if (!opType)
220 return false;
221
222 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
223 opType.getWidth());
224 if (range.second + 1 == opType.getWidth() && range.first == 0)
225 return false;
226
227 SmallVector<Value> args;
228 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
229 for (auto inop : op.getOperands()) {
230 // deal with muxes here
231 if (inop.getType() != op.getType())
232 args.push_back(inop);
233 else
234 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
235 inop, range.first));
236 }
237 auto newop = OpTy::create(rewriter, op.getLoc(), newType, args);
238 newop->setDialectAttrs(op->getDialectAttrs());
239 if (op.getTwoState())
240 newop.setTwoState(true);
241
242 Value newResult = newop.getResult();
243 if (range.first)
244 newResult = rewriter.createOrFold<ConcatOp>(
245 op.getLoc(), newResult,
246 hw::ConstantOp::create(rewriter, op.getLoc(),
247 APInt::getZero(range.first)));
248 if (range.second + 1 < opType.getWidth())
249 newResult = rewriter.createOrFold<ConcatOp>(
250 op.getLoc(),
252 rewriter, op.getLoc(),
253 APInt::getZero(opType.getWidth() - range.second - 1)),
254 newResult);
255 rewriter.replaceOp(op, newResult);
256 return true;
257}
258
259//===----------------------------------------------------------------------===//
260// Unary Operations
261//===----------------------------------------------------------------------===//
262
263OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
264 if (isOpTriviallyRecursive(*this))
265 return {};
266
267 // Replicate one time -> noop.
268 if (cast<IntegerType>(getType()).getWidth() ==
269 getInput().getType().getIntOrFloatBitWidth())
270 return getInput();
271
272 // Constant fold.
273 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
274 if (input.getValue().getBitWidth() == 1) {
275 if (input.getValue().isZero())
276 return getIntAttr(
277 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
278 getContext());
279 return getIntAttr(
280 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
281 getContext());
282 }
283
284 APInt result = APInt::getZeroWidth();
285 for (auto i = getMultiple(); i != 0; --i)
286 result = result.concat(input.getValue());
287 return getIntAttr(result, getContext());
288 }
289
290 return {};
291}
292
293OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
294 if (isOpTriviallyRecursive(*this))
295 return {};
296
297 // Constant fold.
298 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
299 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
300
301 return {};
302}
303
304//===----------------------------------------------------------------------===//
305// Binary Operations
306//===----------------------------------------------------------------------===//
307
308/// Performs constant folding `calculate` with element-wise behavior on the two
309/// attributes in `operands` and returns the result if possible.
310static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
311 hw::PEO paramOpcode) {
312 assert(operands.size() == 2 && "binary op takes two operands");
313 if (!operands[0] || !operands[1])
314 return {};
315
316 // Fold constants with ParamExprAttr::get which handles simple constants as
317 // well as parameter expressions.
318 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
319 cast<TypedAttr>(operands[1]));
320}
321
322OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
323 if (isOpTriviallyRecursive(*this))
324 return {};
325
326 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
327 if (rhs.getValue().isZero())
328 return getOperand(0);
329
330 unsigned width = getType().getIntOrFloatBitWidth();
331 if (rhs.getValue().uge(width))
332 return getIntAttr(APInt::getZero(width), getContext());
333 }
334 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
335}
336
337LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
339 return failure();
340
341 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
342 APInt value;
343 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
344 return failure();
345
346 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
347 if (value.ugt(width))
348 value = width;
349 unsigned shift = value.getZExtValue();
350
351 // This case is handled by fold.
352 if (width <= shift || shift == 0)
353 return failure();
354
355 auto zeros =
356 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
357
358 // Remove the high bits which would be removed by the Shl.
359 auto extract =
360 ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), 0, width - shift);
361
362 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
363 return success();
364}
365
366OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
367 if (isOpTriviallyRecursive(*this))
368 return {};
369
370 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
371 if (rhs.getValue().isZero())
372 return getOperand(0);
373
374 unsigned width = getType().getIntOrFloatBitWidth();
375 if (rhs.getValue().uge(width))
376 return getIntAttr(APInt::getZero(width), getContext());
377 }
378 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
379}
380
381LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
383 return failure();
384
385 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
386 APInt value;
387 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
388 return failure();
389
390 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
391 if (value.ugt(width))
392 value = width;
393 unsigned shift = value.getZExtValue();
394
395 // This case is handled by fold.
396 if (width <= shift || shift == 0)
397 return failure();
398
399 auto zeros =
400 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(shift));
401
402 // Remove the low bits which would be removed by the Shr.
403 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
404 width - shift);
405
406 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
407 return success();
408}
409
410OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
411 if (isOpTriviallyRecursive(*this))
412 return {};
413
414 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs()))
415 if (rhs.getValue().isZero())
416 return getOperand(0);
417 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
418}
419
420LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
422 return failure();
423
424 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
425 APInt value;
426 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
427 return failure();
428
429 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
430 if (value.ugt(width))
431 value = width;
432 unsigned shift = value.getZExtValue();
433
434 auto topbit =
435 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
436 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
437
438 if (width == shift) {
439 replaceOpAndCopyNamehint(rewriter, op, {sext});
440 return success();
441 }
442
443 auto extract = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(), shift,
444 width - shift);
445
446 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
447 return success();
448}
449
450//===----------------------------------------------------------------------===//
451// Other Operations
452//===----------------------------------------------------------------------===//
453
454OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
455 if (isOpTriviallyRecursive(*this))
456 return {};
457
458 // If we are extracting the entire input, then return it.
459 if (getInput().getType() == getType())
460 return getInput();
461
462 // Constant fold.
463 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
464 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
465 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
466 getContext());
467 }
468 return {};
469}
470
471// Transforms extract(lo, cat(a, b, c, d, e)) into
472// cat(extract(lo1, b), c, extract(lo2, d)).
473// innerCat must be the argument of the provided ExtractOp.
475 ConcatOp innerCat,
476 PatternRewriter &rewriter) {
477 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
478 size_t beginOfFirstRelevantElement = 0;
479 auto it = reversedConcatArgs.begin();
480 size_t lowBit = op.getLowBit();
481
482 // This loop finds the first concatArg that is covered by the ExtractOp
483 for (; it != reversedConcatArgs.end(); it++) {
484 assert(beginOfFirstRelevantElement <= lowBit &&
485 "incorrectly moved past an element that lowBit has coverage over");
486 auto operand = *it;
487
488 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
489 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
490 // A bit other than the first bit will be used in this element.
491 // ...... ........ ...
492 // ^---lowBit
493 // ^---beginOfFirstRelevantElement
494 //
495 // Edge-case close to the end of the range.
496 // ...... ........ ...
497 // ^---(position + operandWidth)
498 // ^---lowBit
499 // ^---beginOfFirstRelevantElement
500 //
501 // Edge-case close to the beginning of the rang
502 // ...... ........ ...
503 // ^---lowBit
504 // ^---beginOfFirstRelevantElement
505 //
506 break;
507 }
508
509 // extraction discards this element.
510 // ...... ........ ...
511 // | ^---lowBit
512 // ^---beginOfFirstRelevantElement
513 beginOfFirstRelevantElement += operandWidth;
514 }
515 assert(it != reversedConcatArgs.end() &&
516 "incorrectly failed to find an element which contains coverage of "
517 "lowBit");
518
519 SmallVector<Value> reverseConcatArgs;
520 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
521 size_t extractLo = lowBit - beginOfFirstRelevantElement;
522
523 // Transform individual arguments of innerCat(..., a, b, c,) into
524 // [ extract(a), b, extract(c) ], skipping an extract operation where
525 // possible.
526 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
527 auto concatArg = *it;
528 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
529 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
530
531 if (widthToConsume == operandWidth && extractLo == 0) {
532 reverseConcatArgs.push_back(concatArg);
533 } else {
534 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
535 reverseConcatArgs.push_back(
536 ExtractOp::create(rewriter, op.getLoc(), resultType, *it, extractLo));
537 }
538
539 widthRemaining -= widthToConsume;
540
541 // Beyond the first element, all elements are extracted from position 0.
542 extractLo = 0;
543 }
544
545 if (reverseConcatArgs.size() == 1) {
546 replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
547 } else {
548 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
549 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
550 }
551 return success();
552}
553
554// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
555static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
556 PatternRewriter &rewriter) {
557 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
558 auto replicateEltWidth =
559 replicate.getOperand().getType().getIntOrFloatBitWidth();
560
561 // If the extract starts at the base of an element and is an even multiple,
562 // we can replace the extract with a smaller replicate.
563 if (op.getLowBit() % replicateEltWidth == 0 &&
564 extractResultWidth % replicateEltWidth == 0) {
565 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
566 replicate.getOperand());
567 return true;
568 }
569
570 // If the extract is completely contained in one element, extract from the
571 // element.
572 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
573 replicateEltWidth) {
574 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
575 rewriter, op, op.getType(), replicate.getOperand(),
576 op.getLowBit() % replicateEltWidth);
577 return true;
578 }
579
580 // We don't currently handle the case of extracting from non-whole elements,
581 // e.g. `extract (replicate 2-bit-thing, N), 1`.
582 return false;
583}
584
585LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
587 return failure();
588 auto *inputOp = op.getInput().getDefiningOp();
589
590 // This turns out to be incredibly expensive. Disable until performance is
591 // addressed.
592#if 0
593 // If the extracted bits are all known, then return the result.
594 auto knownBits = computeKnownBits(op.getInput())
595 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
596 op.getLowBit());
597 if (knownBits.isConstant()) {
598 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
599 knownBits.getConstant());
600 return success();
601 }
602#endif
603
604 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
605 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
606 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
607 rewriter, op, op.getType(), innerExtract.getInput(),
608 innerExtract.getLowBit() + op.getLowBit());
609 return success();
610 }
611
612 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
613 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
614 return extractConcatToConcatExtract(op, innerCat, rewriter);
615
616 // extract(lo, replicate(a))
617 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
618 if (extractFromReplicate(op, replicate, rewriter))
619 return success();
620
621 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
622 // and/or/xor are not modifying the extracted bits.
623 if (inputOp && inputOp->getNumOperands() == 2 &&
624 isa<AndOp, OrOp, XorOp>(inputOp)) {
625 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
626 auto extractedCst = cstRHS.getValue().extractBits(
627 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
628 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
629 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
630 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
631 return success();
632 }
633
634 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
635 // extract to represent the result. Turning it into a pile of extracts is
636 // always fine by our cost model, but we don't want to explode things into
637 // a ton of bits because it will bloat the IR and generated Verilog.
638 if (isa<AndOp>(inputOp)) {
639 // For our cost model, we only do this if the bit pattern is a
640 // contiguous series of ones.
641 unsigned lz = extractedCst.countLeadingZeros();
642 unsigned tz = extractedCst.countTrailingZeros();
643 unsigned pop = extractedCst.popcount();
644 if (extractedCst.getBitWidth() - lz - tz == pop) {
645 auto resultTy = rewriter.getIntegerType(pop);
646 SmallVector<Value> resultElts;
647 if (lz)
648 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
649 APInt::getZero(lz)));
650 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
651 op.getLoc(), resultTy, inputOp->getOperand(0),
652 op.getLowBit() + tz));
653 if (tz)
654 resultElts.push_back(hw::ConstantOp::create(rewriter, op.getLoc(),
655 APInt::getZero(tz)));
656 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
657 return success();
658 }
659 }
660 }
661 }
662
663 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
664 // extracted.
665 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
666 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
667 // Don't canonicalize if the shift is multiply used.
668 if (shlOp->hasOneUse())
669 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
670 if (lhsCst.getValue().isOne()) {
671 auto newCst = hw::ConstantOp::create(
672 rewriter, shlOp.getLoc(),
673 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
674 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
675 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
676 false);
677 return success();
678 }
679 }
680
681 return failure();
682}
683
684//===----------------------------------------------------------------------===//
685// Associative Variadic operations
686//===----------------------------------------------------------------------===//
687
688// Reduce all operands to a single value (either integer constant or parameter
689// expression) if all the operands are constants.
690static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
691 hw::PEO paramOpcode) {
692 assert(operands.size() > 1 && "caller should handle one-operand case");
693 // We can only fold anything in the case where all operands are known to be
694 // constants. Check the least common one first for an early out.
695 if (!operands[1] || !operands[0])
696 return {};
697
698 // This will fold to a simple constant if all operands are constant.
699 if (llvm::all_of(operands.drop_front(2),
700 [&](Attribute in) { return !!in; })) {
701 SmallVector<mlir::TypedAttr> typedOperands;
702 typedOperands.reserve(operands.size());
703 for (auto operand : operands) {
704 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
705 typedOperands.push_back(typedOperand);
706 else
707 break;
708 }
709 if (typedOperands.size() == operands.size())
710 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
711 }
712
713 return {};
714}
715
716/// When we find a logical operation (and, or, xor) with a constant e.g.
717/// `X & 42`, we want to push the constant into the computation of X if it leads
718/// to simplification.
719///
720/// This function handles the case where the logical operation has a concat
721/// operand. We check to see if we can simplify the concat, e.g. when it has
722/// constant operands.
723///
724/// This returns true when a simplification happens.
725static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
726 size_t concatIdx, const APInt &cst,
727 PatternRewriter &rewriter) {
728 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
729 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
730
731 // Check to see if any operands can be simplified by pushing the logical op
732 // into all parts of the concat.
733 bool canSimplify =
734 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
735 auto *operandOp = operand.getDefiningOp();
736 if (!operandOp)
737 return false;
738
739 // If the concat has a constant operand then we can transform this.
740 if (isa<hw::ConstantOp>(operandOp))
741 return true;
742 // If the concat has the same logical operation and that operation has
743 // a constant operation than we can fold it into that suboperation.
744 return operandOp->getName() == logicalOp->getName() &&
745 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
746 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
747 });
748
749 if (!canSimplify)
750 return false;
751
752 // Create a new instance of the logical operation. We have to do this the
753 // hard way since we're generic across a family of different ops.
754 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
755 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
756 rewriter);
757 };
758
759 // Ok, let's do the transformation. We do this by slicing up the constant
760 // for each unit of the concat and duplicate the operation into the
761 // sub-operand.
762 SmallVector<Value> newConcatOperands;
763 newConcatOperands.reserve(concatOp->getNumOperands());
764
765 // Work from MSB to LSB.
766 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
767 for (Value operand : concatOp->getOperands()) {
768 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
769 nextOperandBit -= operandWidth;
770 // Take a slice of the constant.
771 auto eltCst =
772 hw::ConstantOp::create(rewriter, logicalOp->getLoc(),
773 cst.lshr(nextOperandBit).trunc(operandWidth));
774
775 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
776 }
777
778 // Create the concat, and the rest of the logical op if we need it.
779 Value newResult =
780 ConcatOp::create(rewriter, concatOp.getLoc(), newConcatOperands);
781
782 // If we had a variadic logical op on the top level, then recreate it with the
783 // new concat and without the constant operand.
784 if (logicalOp->getNumOperands() > 2) {
785 auto origOperands = logicalOp->getOperands();
786 SmallVector<Value> operands;
787 // Take any stuff before the concat.
788 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
789 // Take any stuff after the concat but before the constant.
790 operands.append(origOperands.begin() + concatIdx + 1,
791 origOperands.begin() + (origOperands.size() - 1));
792 // Include the new concat.
793 operands.push_back(newResult);
794 newResult = createLogicalOp(operands);
795 }
796
797 replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
798 return true;
799}
800
801// Determines whether the inputs to a logical element are of opposite
802// comparisons and can lowered into a constant.
803static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
804 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
805
806 for (auto op : operands) {
807 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
808 icmpOp && icmpOp.getTwoState()) {
809 auto predicate = icmpOp.getPredicate();
810 auto lhs = icmpOp.getLhs();
811 auto rhs = icmpOp.getRhs();
812 if (seenPredicates.contains(
813 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
814 return true;
815
816 seenPredicates.insert({predicate, lhs, rhs});
817 }
818 }
819 return false;
820}
821
822OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
823 if (isOpTriviallyRecursive(*this))
824 return {};
825
826 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
827
828 auto inputs = adaptor.getInputs();
829
830 // and(x, 01, 10) -> 00 -- annulment.
831 for (auto operand : inputs) {
832 if (!operand)
833 continue;
834 value &= cast<IntegerAttr>(operand).getValue();
835 if (value.isZero())
836 return getIntAttr(value, getContext());
837 }
838
839 // and(x, -1) -> x.
840 if (inputs.size() == 2 && inputs[1] &&
841 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
842 return getInputs()[0];
843
844 // and(x, x, x) -> x. This also handles and(x) -> x.
845 if (llvm::all_of(getInputs(),
846 [&](auto in) { return in == this->getInputs()[0]; }))
847 return getInputs()[0];
848
849 // and(..., x, ..., ~x, ...) -> 0
850 for (Value arg : getInputs()) {
851 Value subExpr;
852 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
853 for (Value arg2 : getInputs())
854 if (arg2 == subExpr)
855 return getIntAttr(
856 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
857 getContext());
858 }
859 }
860
861 // x0 = icmp(pred, x, y)
862 // x1 = icmp(!pred, x, y)
863 // and(x0, x1) -> 0
865 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
866 getContext());
867
868 // Constant fold
869 return constFoldAssociativeOp(inputs, hw::PEO::And);
870}
871
872/// Returns a single common operand that all inputs of the operation `op` can
873/// be traced back to, or an empty `Value` if no such operand exists.
874///
875/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
876/// (assuming the bit-width of `a` is `n`).
877template <typename Op>
878static Value getCommonOperand(Op op) {
879 if (!op.getType().isInteger(1))
880 return Value();
881
882 auto inputs = op.getInputs();
883 size_t size = inputs.size();
884
885 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
886 if (!sourceOp)
887 return Value();
888 Value source = sourceOp.getOperand();
889
890 // Fast path: the input size is not equal to the width of the source.
891 if (size != source.getType().getIntOrFloatBitWidth())
892 return Value();
893
894 // Tracks the bits that were encountered.
895 llvm::BitVector bits(size);
896 bits.set(sourceOp.getLowBit());
897
898 for (size_t i = 1; i != size; ++i) {
899 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
900 if (!extractOp || extractOp.getOperand() != source)
901 return Value();
902 bits.set(extractOp.getLowBit());
903 }
904
905 return bits.all() ? source : Value();
906}
907
908/// Canonicalize an idempotent operation `op` so that only one input of any kind
909/// occurs.
910///
911/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
912template <typename Op>
913static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
914 // Depth limit to search, in operations. Chosen arbitrarily, keep small.
915 constexpr unsigned limit = 3;
916 auto inputs = op.getInputs();
917
918 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
919 llvm::SmallDenseSet<Op, 8> checked;
920 checked.insert(op);
921
922 struct OpWithDepth {
923 Op op;
924 unsigned depth;
925 };
926 llvm::SmallVector<OpWithDepth, 8> worklist;
927
928 auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
929 // Add to worklist if within depth limit, is defined in the same block by
930 // the same kind of operation, has same two-state-ness, and not enqueued
931 // previously.
932 if (depth < limit && input.getParentBlock() == op->getBlock()) {
933 auto inputOp = input.template getDefiningOp<Op>();
934 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
935 checked.insert(inputOp).second)
936 worklist.push_back({inputOp, depth + 1});
937 }
938 };
939
940 for (auto input : uniqueInputs)
941 enqueue(input, 0);
942
943 while (!worklist.empty()) {
944 auto item = worklist.pop_back_val();
945
946 for (auto input : item.op.getInputs()) {
947 uniqueInputs.remove(input);
948 enqueue(input, item.depth);
949 }
950 }
951
952 if (uniqueInputs.size() < inputs.size()) {
953 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
954 uniqueInputs.getArrayRef(),
955 op.getTwoState());
956 return true;
957 }
958
959 return false;
960}
961
962LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
964 return failure();
965
966 auto inputs = op.getInputs();
967 auto size = inputs.size();
968
969 // and(x, and(...)) -> and(x, ...) -- flatten
970 if (tryFlatteningOperands(op, rewriter))
971 return success();
972
973 // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
974 // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
975 // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
976 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
977 return success();
978
979 assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
980
981 // Patterns for and with a constant on RHS.
982 APInt value;
983 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
984 // and(..., '1) -> and(...) -- identity
985 if (value.isAllOnes()) {
986 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
987 inputs.drop_back(), false);
988 return success();
989 }
990
991 // TODO: Combine multiple constants together even if they aren't at the
992 // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
993 // folding
994 APInt value2;
995 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
996 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value & value2);
997 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
998 newOperands.push_back(cst);
999 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1000 newOperands, false);
1001 return success();
1002 }
1003
1004 // Handle 'and' with a single bit constant on the RHS.
1005 if (size == 2 && value.isPowerOf2()) {
1006 // If the LHS is a replicate from a single bit, we can 'concat' it
1007 // into place. e.g.:
1008 // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1009 // TODO: Generalize this for non-single-bit operands.
1010 if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1011 auto replicateOperand = replicate.getOperand();
1012 if (replicateOperand.getType().isInteger(1)) {
1013 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1014 auto trailingZeros = value.countTrailingZeros();
1015
1016 // Don't add zero bit constants unnecessarily.
1017 SmallVector<Value, 3> concatOperands;
1018 if (trailingZeros != resultWidth - 1) {
1019 auto highZeros = hw::ConstantOp::create(
1020 rewriter, op.getLoc(),
1021 APInt::getZero(resultWidth - trailingZeros - 1));
1022 concatOperands.push_back(highZeros);
1023 }
1024 concatOperands.push_back(replicateOperand);
1025 if (trailingZeros != 0) {
1026 auto lowZeros = hw::ConstantOp::create(
1027 rewriter, op.getLoc(), APInt::getZero(trailingZeros));
1028 concatOperands.push_back(lowZeros);
1029 }
1030 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1031 rewriter, op, op.getType(), concatOperands);
1032 return success();
1033 }
1034 }
1035 }
1036
1037 // Narrow the op if the constant has leading or trailing zeros.
1038 //
1039 // and(a, 0b00101100) -> concat(0b00, and(extract(a), 0b1011), 0b00)
1040 unsigned leadingZeros = value.countLeadingZeros();
1041 unsigned trailingZeros = value.countTrailingZeros();
1042 if (leadingZeros > 0 || trailingZeros > 0) {
1043 unsigned maskLength = value.getBitWidth() - leadingZeros - trailingZeros;
1044
1045 // Extract the non-zero regions of the operands. Look through extracts.
1046 SmallVector<Value> operands;
1047 for (auto input : inputs.drop_back()) {
1048 unsigned offset = trailingZeros;
1049 while (auto extractOp = input.getDefiningOp<ExtractOp>()) {
1050 input = extractOp.getInput();
1051 offset += extractOp.getLowBit();
1052 }
1053 operands.push_back(ExtractOp::create(rewriter, op.getLoc(), input,
1054 offset, maskLength));
1055 }
1056
1057 // Add the narrowed mask if needed.
1058 auto narrowMask = value.extractBits(maskLength, trailingZeros);
1059 if (!narrowMask.isAllOnes())
1060 operands.push_back(hw::ConstantOp::create(
1061 rewriter, inputs.back().getLoc(), narrowMask));
1062
1063 // Create the narrow and op.
1064 Value narrowValue = operands.back();
1065 if (operands.size() > 1)
1066 narrowValue =
1067 AndOp::create(rewriter, op.getLoc(), operands, op.getTwoState());
1068 operands.clear();
1069
1070 // Concatenate the narrow and with the leading and trailing zeros.
1071 if (leadingZeros > 0)
1072 operands.push_back(hw::ConstantOp::create(
1073 rewriter, op.getLoc(), APInt::getZero(leadingZeros)));
1074 operands.push_back(narrowValue);
1075 if (trailingZeros > 0)
1076 operands.push_back(hw::ConstantOp::create(
1077 rewriter, op.getLoc(), APInt::getZero(trailingZeros)));
1078 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, operands);
1079 return success();
1080 }
1081
1082 // and(concat(x, cst1), a, b, c, cst2)
1083 // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1084 // We do this for even more multi-use concats since they are "just wiring".
1085 for (size_t i = 0; i < size - 1; ++i) {
1086 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1087 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1088 return success();
1089 }
1090 }
1091
1092 // extracts only of and(...) -> and(extract()...)
1093 if (narrowOperationWidth(op, true, rewriter))
1094 return success();
1095
1096 // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1097 if (auto source = getCommonOperand(op)) {
1098 auto cmpAgainst =
1099 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getAllOnes(size));
1100 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1101 source, cmpAgainst);
1102 return success();
1103 }
1104
1105 /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1106 return failure();
1107}
1108
1109OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1110 if (isOpTriviallyRecursive(*this))
1111 return {};
1112
1113 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1114 auto inputs = adaptor.getInputs();
1115 // or(x, 10, 01) -> 11
1116 for (auto operand : inputs) {
1117 if (!operand)
1118 continue;
1119 value |= cast<IntegerAttr>(operand).getValue();
1120 if (value.isAllOnes())
1121 return getIntAttr(value, getContext());
1122 }
1123
1124 // or(x, 0) -> x
1125 if (inputs.size() == 2 && inputs[1] &&
1126 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1127 return getInputs()[0];
1128
1129 // or(x, x, x) -> x. This also handles or(x) -> x
1130 if (llvm::all_of(getInputs(),
1131 [&](auto in) { return in == this->getInputs()[0]; }))
1132 return getInputs()[0];
1133
1134 // or(..., x, ..., ~x, ...) -> -1
1135 for (Value arg : getInputs()) {
1136 Value subExpr;
1137 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1138 for (Value arg2 : getInputs())
1139 if (arg2 == subExpr)
1140 return getIntAttr(
1141 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1142 getContext());
1143 }
1144 }
1145
1146 // x0 = icmp(pred, x, y)
1147 // x1 = icmp(!pred, x, y)
1148 // or(x0, x1) -> 1
1149 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1150 return getIntAttr(
1151 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1152 getContext());
1153
1154 // Constant fold
1155 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1156}
1157
1158LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1159 if (isOpTriviallyRecursive(op))
1160 return failure();
1161
1162 auto inputs = op.getInputs();
1163 auto size = inputs.size();
1164
1165 // or(x, or(...)) -> or(x, ...) -- flatten
1166 if (tryFlatteningOperands(op, rewriter))
1167 return success();
1168
1169 // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1170 // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1171 // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1172 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1173 return success();
1174
1175 assert(size > 1 && "expected 2 or more operands");
1176
1177 // Patterns for and with a constant on RHS.
1178 APInt value;
1179 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1180 // or(..., '0) -> or(...) -- identity
1181 if (value.isZero()) {
1182 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1183 inputs.drop_back());
1184 return success();
1185 }
1186
1187 // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1188 APInt value2;
1189 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1190 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value | value2);
1191 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1192 newOperands.push_back(cst);
1193 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1194 newOperands);
1195 return success();
1196 }
1197
1198 // or(concat(x, cst1), a, b, c, cst2)
1199 // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1200 // We do this for even more multi-use concats since they are "just wiring".
1201 for (size_t i = 0; i < size - 1; ++i) {
1202 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1203 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1204 return success();
1205 }
1206 }
1207
1208 // extracts only of or(...) -> or(extract()...)
1209 if (narrowOperationWidth(op, true, rewriter))
1210 return success();
1211
1212 // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1213 if (auto source = getCommonOperand(op)) {
1214 auto cmpAgainst =
1215 hw::ConstantOp::create(rewriter, op.getLoc(), APInt::getZero(size));
1216 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1217 source, cmpAgainst);
1218 return success();
1219 }
1220
1221 // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1222 // .., c_n), a, 0)
1223 if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1224 APInt value;
1225 if (op.getTwoState() && firstMux.getTwoState() &&
1226 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1227 value.isZero()) {
1228 SmallVector<Value> conditions{firstMux.getCond()};
1229 auto check = [&](Value v) {
1230 auto mux = v.getDefiningOp<comb::MuxOp>();
1231 if (!mux)
1232 return false;
1233 conditions.push_back(mux.getCond());
1234 return mux.getTwoState() &&
1235 firstMux.getTrueValue() == mux.getTrueValue() &&
1236 firstMux.getFalseValue() == mux.getFalseValue();
1237 };
1238 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1239 auto cond = comb::OrOp::create(rewriter, op.getLoc(), conditions, true);
1240 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1241 rewriter, op, cond, firstMux.getTrueValue(),
1242 firstMux.getFalseValue(), true);
1243 return success();
1244 }
1245 }
1246 }
1247
1248 /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1249 return failure();
1250}
1251
1252OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1253 if (isOpTriviallyRecursive(*this))
1254 return {};
1255
1256 auto size = getInputs().size();
1257 auto inputs = adaptor.getInputs();
1258
1259 // xor(x) -> x -- noop
1260 if (size == 1)
1261 return getInputs()[0];
1262
1263 // xor(x, x) -> 0 -- idempotent
1264 if (size == 2 && getInputs()[0] == getInputs()[1])
1265 return IntegerAttr::get(getType(), 0);
1266
1267 // xor(x, 0) -> x
1268 if (inputs.size() == 2 && inputs[1] &&
1269 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1270 return getInputs()[0];
1271
1272 // xor(xor(x,1),1) -> x
1273 // but not self loop
1274 if (isBinaryNot()) {
1275 Value subExpr;
1276 if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1277 subExpr != getResult())
1278 return subExpr;
1279 }
1280
1281 // Constant fold
1282 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1283}
1284
1285// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1286static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1287 PatternRewriter &rewriter) {
1288 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1289 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1290
1291 Value result =
1292 ICmpOp::create(rewriter, icmp.getLoc(), negatedPred, icmp.getOperand(0),
1293 icmp.getOperand(1), icmp.getTwoState());
1294
1295 // If the xor had other operands, rebuild it.
1296 if (op.getNumOperands() > 2) {
1297 SmallVector<Value, 4> newOperands(op.getOperands());
1298 newOperands.pop_back();
1299 newOperands.erase(newOperands.begin() + icmpOperand);
1300 newOperands.push_back(result);
1301 result =
1302 XorOp::create(rewriter, op.getLoc(), newOperands, op.getTwoState());
1303 }
1304
1305 replaceOpAndCopyNamehint(rewriter, op, result);
1306}
1307
1308LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1309 if (isOpTriviallyRecursive(op))
1310 return failure();
1311
1312 auto inputs = op.getInputs();
1313 auto size = inputs.size();
1314 assert(size > 1 && "expected 2 or more operands");
1315
1316 // xor(..., x, x) -> xor (...) -- idempotent
1317 if (inputs[size - 1] == inputs[size - 2]) {
1318 assert(size > 2 &&
1319 "expected idempotent case for 2 elements handled already.");
1320 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1321 inputs.drop_back(/*n=*/2), false);
1322 return success();
1323 }
1324
1325 // Patterns for xor with a constant on RHS.
1326 APInt value;
1327 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1328 // xor(..., 0) -> xor(...) -- identity
1329 if (value.isZero()) {
1330 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1331 inputs.drop_back(), false);
1332 return success();
1333 }
1334
1335 // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1336 APInt value2;
1337 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1338 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value ^ value2);
1339 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1340 newOperands.push_back(cst);
1341 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1342 newOperands, false);
1343 return success();
1344 }
1345
1346 bool isSingleBit = value.getBitWidth() == 1;
1347
1348 // Check for subexpressions that we can simplify.
1349 for (size_t i = 0; i < size - 1; ++i) {
1350 Value operand = inputs[i];
1351
1352 // xor(concat(x, cst1), a, b, c, cst2)
1353 // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1354 // We do this for even more multi-use concats since they are "just
1355 // wiring".
1356 if (auto concat = operand.getDefiningOp<ConcatOp>())
1357 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1358 return success();
1359
1360 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1361 if (isSingleBit && operand.hasOneUse()) {
1362 assert(value == 1 && "single bit constant has to be one if not zero");
1363 if (auto icmp = operand.getDefiningOp<ICmpOp>())
1364 return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1365 }
1366 }
1367 }
1368
1369 // xor(x, xor(...)) -> xor(x, ...) -- flatten
1370 if (tryFlatteningOperands(op, rewriter))
1371 return success();
1372
1373 // extracts only of xor(...) -> xor(extract()...)
1374 if (narrowOperationWidth(op, true, rewriter))
1375 return success();
1376
1377 // xor(a[0], a[1], ..., a[n]) -> parity(a)
1378 if (auto source = getCommonOperand(op)) {
1379 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1380 return success();
1381 }
1382
1383 return failure();
1384}
1385
1386OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1387 if (isOpTriviallyRecursive(*this))
1388 return {};
1389
1390 // sub(x - x) -> 0
1391 if (getRhs() == getLhs())
1392 return getIntAttr(
1393 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1394 getContext());
1395
1396 if (adaptor.getRhs()) {
1397 // If both are constants, we can unconditionally fold.
1398 if (adaptor.getLhs()) {
1399 // Constant fold (c1 - c2) => (c1 + -1*c2).
1400 auto negOne = getIntAttr(
1401 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1402 getContext());
1403 auto rhsNeg = hw::ParamExprAttr::get(
1404 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1405 return hw::ParamExprAttr::get(hw::PEO::Add,
1406 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1407 }
1408
1409 // sub(x - 0) -> x
1410 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1411 if (rhsC.getValue().isZero())
1412 return getLhs();
1413 }
1414 }
1415
1416 return {};
1417}
1418
1419LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1420 if (isOpTriviallyRecursive(op))
1421 return failure();
1422
1423 // sub(x, cst) -> add(x, -cst)
1424 APInt value;
1425 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1426 auto negCst = hw::ConstantOp::create(rewriter, op.getLoc(), -value);
1427 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1428 false);
1429 return success();
1430 }
1431
1432 // extracts only of sub(...) -> sub(extract()...)
1433 if (narrowOperationWidth(op, false, rewriter))
1434 return success();
1435
1436 return failure();
1437}
1438
1439OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1440 if (isOpTriviallyRecursive(*this))
1441 return {};
1442
1443 auto size = getInputs().size();
1444
1445 // add(x) -> x -- noop
1446 if (size == 1u)
1447 return getInputs()[0];
1448
1449 // Constant fold constant operands.
1450 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1451}
1452
1453LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1454 if (isOpTriviallyRecursive(op))
1455 return failure();
1456
1457 auto inputs = op.getInputs();
1458 auto size = inputs.size();
1459 assert(size > 1 && "expected 2 or more operands");
1460
1461 APInt value, value2;
1462
1463 // add(..., 0) -> add(...) -- identity
1464 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1465 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1466 inputs.drop_back(), false);
1467 return success();
1468 }
1469
1470 // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1471 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1472 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1473 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1474 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1475 newOperands.push_back(cst);
1476 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1477 newOperands, false);
1478 return success();
1479 }
1480
1481 // add(..., x, x) -> add(..., shl(x, 1))
1482 if (inputs[size - 1] == inputs[size - 2]) {
1483 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1484
1485 auto one = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), 1);
1486 auto shiftLeftOp =
1487 comb::ShlOp::create(rewriter, op.getLoc(), inputs.back(), one, false);
1488
1489 newOperands.push_back(shiftLeftOp);
1490 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1491 newOperands, false);
1492 return success();
1493 }
1494
1495 auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1496 // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1497 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1498 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1499
1500 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1501 auto rhs =
1502 hw::ConstantOp::create(rewriter, op.getLoc(), (one << value) + one);
1503
1504 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1505 auto mulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1506
1507 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1508 newOperands.push_back(mulOp);
1509 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1510 newOperands, false);
1511 return success();
1512 }
1513
1514 auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1515 // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1516 if (mulOp && mulOp.getInputs().size() == 2 &&
1517 mulOp.getInputs()[0] == inputs[size - 2] &&
1518 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1519
1520 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1521 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + one);
1522 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1523 auto newMulOp = comb::MulOp::create(rewriter, op.getLoc(), factors, false);
1524
1525 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1526 newOperands.push_back(newMulOp);
1527 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1528 newOperands, false);
1529 return success();
1530 }
1531
1532 // add(a, add(...)) -> add(a, ...) -- flatten
1533 if (tryFlatteningOperands(op, rewriter))
1534 return success();
1535
1536 // extracts only of add(...) -> add(extract()...)
1537 if (narrowOperationWidth(op, false, rewriter))
1538 return success();
1539
1540 // add(add(x, c1), c2) -> add(x, c1 + c2)
1541 auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1542 if (addOp && addOp.getInputs().size() == 2 &&
1543 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1544 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1545
1546 auto rhs = hw::ConstantOp::create(rewriter, op.getLoc(), value + value2);
1547 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1548 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1549 /*twoState=*/op.getTwoState() && addOp.getTwoState());
1550 return success();
1551 }
1552
1553 return failure();
1554}
1555
1556OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1557 if (isOpTriviallyRecursive(*this))
1558 return {};
1559
1560 auto size = getInputs().size();
1561 auto inputs = adaptor.getInputs();
1562
1563 // mul(x) -> x -- noop
1564 if (size == 1u)
1565 return getInputs()[0];
1566
1567 auto width = cast<IntegerType>(getType()).getWidth();
1568 if (width == 0)
1569 return getIntAttr(APInt::getZero(0), getContext());
1570
1571 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1572
1573 // mul(x, 0, 1) -> 0 -- annulment
1574 for (auto operand : inputs) {
1575 if (!operand)
1576 continue;
1577 value *= cast<IntegerAttr>(operand).getValue();
1578 if (value.isZero())
1579 return getIntAttr(value, getContext());
1580 }
1581
1582 // Constant fold
1583 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1584}
1585
1586LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1587 if (isOpTriviallyRecursive(op))
1588 return failure();
1589
1590 auto inputs = op.getInputs();
1591 auto size = inputs.size();
1592 assert(size > 1 && "expected 2 or more operands");
1593
1594 APInt value, value2;
1595
1596 // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1597 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1598 value.isPowerOf2()) {
1599 auto shift = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(),
1600 value.exactLogBase2());
1601 auto shlOp =
1602 comb::ShlOp::create(rewriter, op.getLoc(), inputs[0], shift, false);
1603
1604 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1605 ArrayRef<Value>(shlOp), false);
1606 return success();
1607 }
1608
1609 // mul(..., 1) -> mul(...) -- identity
1610 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1611 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1612 inputs.drop_back());
1613 return success();
1614 }
1615
1616 // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1617 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1618 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1619 auto cst = hw::ConstantOp::create(rewriter, op.getLoc(), value * value2);
1620 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1621 newOperands.push_back(cst);
1622 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1623 newOperands);
1624 return success();
1625 }
1626
1627 // mul(a, mul(...)) -> mul(a, ...) -- flatten
1628 if (tryFlatteningOperands(op, rewriter))
1629 return success();
1630
1631 // extracts only of mul(...) -> mul(extract()...)
1632 if (narrowOperationWidth(op, false, rewriter))
1633 return success();
1634
1635 return failure();
1636}
1637
1638template <class Op, bool isSigned>
1639static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1640 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1641 // divu(x, 1) -> x, divs(x, 1) -> x
1642 if (rhsValue.getValue() == 1)
1643 return op.getLhs();
1644
1645 // If the divisor is zero, do not fold for now.
1646 if (rhsValue.getValue().isZero())
1647 return {};
1648 }
1649
1650 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1651}
1652
1653OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1654 if (isOpTriviallyRecursive(*this))
1655 return {};
1656 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1657}
1658
1659OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1660 if (isOpTriviallyRecursive(*this))
1661 return {};
1662 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1663}
1664
1665template <class Op, bool isSigned>
1666static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1667 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1668 // modu(x, 1) -> 0, mods(x, 1) -> 0
1669 if (rhsValue.getValue() == 1)
1670 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1671 op.getContext());
1672
1673 // If the divisor is zero, do not fold for now.
1674 if (rhsValue.getValue().isZero())
1675 return {};
1676 }
1677
1678 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1679 // modu(0, x) -> 0, mods(0, x) -> 0
1680 if (lhsValue.getValue().isZero())
1681 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1682 op.getContext());
1683 }
1684
1685 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1686}
1687
1688OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1689 if (isOpTriviallyRecursive(*this))
1690 return {};
1691 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1692}
1693
1694OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1695 if (isOpTriviallyRecursive(*this))
1696 return {};
1697 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1698}
1699//===----------------------------------------------------------------------===//
1700// ConcatOp
1701//===----------------------------------------------------------------------===//
1702
1703// Constant folding
1704OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1705 if (isOpTriviallyRecursive(*this))
1706 return {};
1707
1708 if (getNumOperands() == 1)
1709 return getOperand(0);
1710
1711 // If all the operands are constant, we can fold.
1712 for (auto attr : adaptor.getInputs())
1713 if (!attr || !isa<IntegerAttr>(attr))
1714 return {};
1715
1716 // If we got here, we can constant fold.
1717 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1718 APInt result(resultWidth, 0);
1719
1720 unsigned nextInsertion = resultWidth;
1721 // Insert each chunk into the result.
1722 for (auto attr : adaptor.getInputs()) {
1723 auto chunk = cast<IntegerAttr>(attr).getValue();
1724 nextInsertion -= chunk.getBitWidth();
1725 result.insertBits(chunk, nextInsertion);
1726 }
1727
1728 return getIntAttr(result, getContext());
1729}
1730
1731LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1732 if (isOpTriviallyRecursive(op))
1733 return failure();
1734
1735 auto inputs = op.getInputs();
1736 auto size = inputs.size();
1737 assert(size > 1 && "expected 2 or more operands");
1738
1739 // This function is used when we flatten neighboring operands of a
1740 // (variadic) concat into a new vesion of the concat. first/last indices
1741 // are inclusive.
1742 auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1743 ValueRange replacements) -> LogicalResult {
1744 SmallVector<Value, 4> newOperands;
1745 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1746 newOperands.append(replacements.begin(), replacements.end());
1747 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1748 if (newOperands.size() == 1)
1749 replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
1750 else
1751 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1752 newOperands);
1753 return success();
1754 };
1755
1756 Value commonOperand = inputs[0];
1757 for (size_t i = 0; i != size; ++i) {
1758 // Check to see if all operands are the same.
1759 if (inputs[i] != commonOperand)
1760 commonOperand = Value();
1761
1762 // If an operand to the concat is itself a concat, then we can fold them
1763 // together.
1764 if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1765 return flattenConcat(i, i, subConcat->getOperands());
1766
1767 // Check for canonicalization due to neighboring operands.
1768 if (i != 0) {
1769 // Merge neighboring constants.
1770 if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1771 if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1772 unsigned prevWidth = prevCst.getValue().getBitWidth();
1773 unsigned thisWidth = cst.getValue().getBitWidth();
1774 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1775 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1776 << thisWidth;
1777 Value replacement =
1778 hw::ConstantOp::create(rewriter, op.getLoc(), resultCst);
1779 return flattenConcat(i - 1, i, replacement);
1780 }
1781 }
1782
1783 // If the two operands are the same, turn them into a replicate.
1784 if (inputs[i] == inputs[i - 1]) {
1785 Value replacement =
1786 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1787 return flattenConcat(i - 1, i, replacement);
1788 }
1789
1790 // If this input is a replicate, see if we can fold it with the previous
1791 // one.
1792 if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1793 // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1794 if (repl.getOperand() == inputs[i - 1]) {
1795 Value replacement = rewriter.createOrFold<ReplicateOp>(
1796 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1797 return flattenConcat(i - 1, i, replacement);
1798 }
1799 // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1800 if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1801 if (prevRepl.getOperand() == repl.getOperand()) {
1802 Value replacement = rewriter.createOrFold<ReplicateOp>(
1803 op.getLoc(), repl.getOperand(),
1804 repl.getMultiple() + prevRepl.getMultiple());
1805 return flattenConcat(i - 1, i, replacement);
1806 }
1807 }
1808 }
1809
1810 // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1811 if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1812 if (repl.getOperand() == inputs[i]) {
1813 Value replacement = rewriter.createOrFold<ReplicateOp>(
1814 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1815 return flattenConcat(i - 1, i, replacement);
1816 }
1817 }
1818
1819 // Merge neighboring extracts of neighboring inputs, e.g.
1820 // {A[3], A[2]} -> A[3:2]
1821 if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1822 if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1823 if (extract.getInput() == prevExtract.getInput()) {
1824 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1825 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1826 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1827 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1828 Value replacement =
1829 ExtractOp::create(rewriter, op.getLoc(), resType,
1830 extract.getInput(), extract.getLowBit());
1831 return flattenConcat(i - 1, i, replacement);
1832 }
1833 }
1834 }
1835 }
1836 // Merge neighboring array extracts of neighboring inputs, e.g.
1837 // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
1838
1839 // This represents a slice of an array.
1840 struct ArraySlice {
1841 Value input;
1842 Value index;
1843 size_t width;
1844 static std::optional<ArraySlice> get(Value value) {
1845 assert(isa<IntegerType>(value.getType()) && "expected integer type");
1846 if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1847 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1848 // array slice op is wrapped with bitcast.
1849 if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1850 if (auto arraySlice =
1851 bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1852 return ArraySlice{
1853 arraySlice.getInput(), arraySlice.getLowIndex(),
1854 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1855 .getNumElements()};
1856 return std::nullopt;
1857 }
1858 };
1859 if (auto extractOpt = ArraySlice::get(inputs[i])) {
1860 if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1861 // Check that two array slices are mergable.
1862 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1863 prevExtractOpt->input == extractOpt->input &&
1864 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1865 extractOpt->width)) {
1866 auto resType = hw::ArrayType::get(
1867 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1868 .getElementType(),
1869 extractOpt->width + prevExtractOpt->width);
1870 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1871 Value replacement = hw::BitcastOp::create(
1872 rewriter, op.getLoc(), resIntType,
1873 hw::ArraySliceOp::create(rewriter, op.getLoc(), resType,
1874 prevExtractOpt->input,
1875 extractOpt->index));
1876 return flattenConcat(i - 1, i, replacement);
1877 }
1878 }
1879 }
1880 }
1881 }
1882
1883 // If all operands were the same, then this is a replicate.
1884 if (commonOperand) {
1885 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1886 commonOperand);
1887 return success();
1888 }
1889
1890 return failure();
1891}
1892
1893//===----------------------------------------------------------------------===//
1894// MuxOp
1895//===----------------------------------------------------------------------===//
1896
1897OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1898 if (isOpTriviallyRecursive(*this))
1899 return {};
1900
1901 // mux (c, b, b) -> b
1902 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1903 return getTrueValue();
1904 if (auto tv = adaptor.getTrueValue())
1905 if (tv == adaptor.getFalseValue())
1906 return tv;
1907
1908 // mux(0, a, b) -> b
1909 // mux(1, a, b) -> a
1910 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1911 if (pred.getValue().isZero() && getFalseValue() != getResult())
1912 return getFalseValue();
1913 if (pred.getValue().isOne() && getTrueValue() != getResult())
1914 return getTrueValue();
1915 }
1916
1917 // mux(cond, 1, 0) -> cond
1918 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1919 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1920 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1921 hw::getBitWidth(getType()) == 1 && getCond() != getResult())
1922 return getCond();
1923
1924 return {};
1925}
1926
1927/// Check to see if the condition to the specified mux is an equality
1928/// comparison `indexValue` and one or more constants. If so, put the
1929/// constants in the constants vector and return true, otherwise return false.
1930///
1931/// This is part of foldMuxChain.
1932///
1933static bool
1934getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1935 std::function<void(hw::ConstantOp)> constantFn) {
1936 // Handle `idx == 42` and `idx != 42`.
1937 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1938 // TODO: We could handle things like "x < 2" as two entries.
1939 auto requiredPredicate =
1940 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1941 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1942 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1943 constantFn(cst);
1944 return true;
1945 }
1946 }
1947 return false;
1948 }
1949
1950 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1951 if (auto orOp = cond.getDefiningOp<OrOp>()) {
1952 if (!isInverted)
1953 return false;
1954 for (auto operand : orOp.getOperands())
1955 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1956 return false;
1957 return true;
1958 }
1959
1960 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1961 if (auto andOp = cond.getDefiningOp<AndOp>()) {
1962 if (isInverted)
1963 return false;
1964 for (auto operand : andOp.getOperands())
1965 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1966 return false;
1967 return true;
1968 }
1969
1970 return false;
1971}
1972
1973/// Given a mux, check to see if the "on true" value (or "on false" value if
1974/// isFalseSide=true) is a mux tree with the same condition. This allows us
1975/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
1976/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
1977/// allows synthesis tools to do more interesting optimizations.
1978///
1979/// This returns false if we cannot form the mux tree (or do not want to) and
1980/// returns true if the mux was replaced.
1981static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
1982 PatternRewriter &rewriter) {
1983 // Get the index value being compared. Later we check to see if it is
1984 // compared to a constant with the right predicate.
1985 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1986 if (!rootCmp)
1987 return false;
1988 Value indexValue = rootCmp.getLhs();
1989
1990 // Return the value to use if the equality match succeeds.
1991 auto getCaseValue = [&](MuxOp mux) -> Value {
1992 return mux.getOperand(1 + unsigned(!isFalseSide));
1993 };
1994
1995 // Return the value to use if the equality match fails. This is the next
1996 // mux in the sequence or the "otherwise" value.
1997 auto getTreeValue = [&](MuxOp mux) -> Value {
1998 return mux.getOperand(1 + unsigned(isFalseSide));
1999 };
2000
2001 // Start scanning the mux tree to see what we've got. Keep track of the
2002 // constant comparison value and the SSA value to use when equal to it.
2003 SmallVector<Location> locationsFound;
2004 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2005
2006 /// Extract constants and values into `valuesFound` and return true if this is
2007 /// part of the mux tree, otherwise return false.
2008 auto collectConstantValues = [&](MuxOp mux) -> bool {
2010 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2011 valuesFound.push_back({cst, getCaseValue(mux)});
2012 locationsFound.push_back(mux.getCond().getLoc());
2013 locationsFound.push_back(mux->getLoc());
2014 });
2015 };
2016
2017 // Make sure the root is a correct comparison with a constant.
2018 if (!collectConstantValues(rootMux))
2019 return false;
2020
2021 // Make sure that we're not looking at the intermediate node in a mux tree.
2022 if (rootMux->hasOneUse()) {
2023 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2024 if (getTreeValue(userMux) == rootMux.getResult() &&
2025 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2026 [&](hw::ConstantOp cst) {}))
2027 return false;
2028 }
2029 }
2030
2031 // Scan up the tree linearly.
2032 auto nextTreeValue = getTreeValue(rootMux);
2033 while (1) {
2034 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2035 if (!nextMux || !nextMux->hasOneUse())
2036 break;
2037 if (!collectConstantValues(nextMux))
2038 break;
2039 nextTreeValue = getTreeValue(nextMux);
2040 }
2041
2042 // We need to have more than three values to create an array. This is an
2043 // arbitrary threshold which is saying that one or two muxes together is ok,
2044 // but three should be folded.
2045 if (valuesFound.size() < 3)
2046 return false;
2047
2048 // If the array is greater that 9 bits, it will take over 512 elements and
2049 // it will be too large for a single expression.
2050 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2051 if (indexWidth >= 9)
2052 return false;
2053
2054 // Next we need to see if the values are dense-ish. We don't want to have
2055 // a tremendous number of replicated entries in the array. Some sparsity is
2056 // ok though, so we require the table to be at least 5/8 utilized.
2057 uint64_t tableSize = 1ULL << indexWidth;
2058 if (valuesFound.size() < (tableSize * 5) / 8)
2059 return false; // Not dense enough.
2060
2061 // Ok, we're going to do the transformation, start by building the table
2062 // filled with the "otherwise" value.
2063 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2064
2065 // Fill in entries in the table from the leaf to the root of the expression.
2066 // This ensures that any duplicate matches end up with the ultimate value,
2067 // which is the one closer to the root.
2068 for (auto &elt : llvm::reverse(valuesFound)) {
2069 uint64_t idx = elt.first.getValue().getZExtValue();
2070 assert(idx < table.size() && "constant should be same bitwidth as index");
2071 table[idx] = elt.second;
2072 }
2073
2074 // The hw.array_create operation has the operand list in unintuitive order
2075 // with a[0] stored as the last element, not the first.
2076 std::reverse(table.begin(), table.end());
2077
2078 // Build the array_create and the array_get.
2079 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2080 auto array = hw::ArrayCreateOp::create(rewriter, fusedLoc, table);
2081 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2082 indexValue);
2083 return true;
2084}
2085
2086/// Given a fully associative variadic operation like (a+b+c+d), break the
2087/// expression into two parts, one without the specified operand (e.g.
2088/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2089/// `tmp+c`), and return the inner expression.
2090///
2091/// NOTE: This mutates the operation in place if it only has a single user,
2092/// which assumes that user will be removed.
2093///
2094static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2095 size_t operandNo,
2096 PatternRewriter &rewriter) {
2097 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2098 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2099
2100 // If this expression already has two operands (the common case) no splitting
2101 // is necessary.
2102 if (fullyAssoc->getNumOperands() == 2)
2103 return fullyAssoc->getOperand(operandNo ^ 1);
2104
2105 // If the operation has a single use, mutate it in place.
2106 if (fullyAssoc->hasOneUse()) {
2107 rewriter.modifyOpInPlace(fullyAssoc,
2108 [&]() { fullyAssoc->eraseOperand(operandNo); });
2109 return fullyAssoc->getResult(0);
2110 }
2111
2112 // Form the new operation with the operands that remain.
2113 SmallVector<Value> operands;
2114 operands.append(fullyAssoc->getOperands().begin(),
2115 fullyAssoc->getOperands().begin() + operandNo);
2116 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2117 fullyAssoc->getOperands().end());
2118 Value opWithoutExcluded = createGenericOp(
2119 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2120 Value excluded = fullyAssoc->getOperand(operandNo);
2121
2122 Value fullResult =
2123 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2124 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2125 replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
2126 return opWithoutExcluded;
2127}
2128
2129/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2130/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2131/// is true. Return true on successful transformation, false if not.
2132///
2133/// These are various forms of "predicated ops" that can be handled with a
2134/// replicate/and combination.
2135static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2136 PatternRewriter &rewriter) {
2137 // Check to see the operand in question is an operation. If it is a port,
2138 // we can't simplify it.
2139 Operation *subExpr =
2140 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2141 if (!subExpr || subExpr->getNumOperands() < 2)
2142 return false;
2143
2144 // If this isn't an operation we can handle, don't spend energy on it.
2145 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2146 return false;
2147
2148 // Check to see if the common value occurs in the operand list for the
2149 // subexpression op. If so, then we can simplify it.
2150 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2151 size_t opNo = 0, e = subExpr->getNumOperands();
2152 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2153 ++opNo;
2154 if (opNo == e)
2155 return false;
2156
2157 // If we got a hit, then go ahead and simplify it!
2158 Value cond = op.getCond();
2159
2160 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2161 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2162 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2163 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2164 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2165 if (subMux == op)
2166 return false;
2167
2168 Value otherValue;
2169 Value subCond = subMux.getCond();
2170
2171 // Invert th subCond if needed and dig out the 'b' value.
2172 if (subMux.getTrueValue() == commonValue)
2173 otherValue = subMux.getFalseValue();
2174 else if (subMux.getFalseValue() == commonValue) {
2175 otherValue = subMux.getTrueValue();
2176 subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2177 } else {
2178 // We can't fold `mux(cond, a, mux(a, x, y))`.
2179 return false;
2180 }
2181
2182 // Invert the outer cond if needed, and combine the mux conditions.
2183 if (!isTrueOperand)
2184 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2185 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2186 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2187 otherValue, op.getTwoState());
2188 return true;
2189 }
2190
2191 // Invert the condition if needed. Or/Xor invert when dealing with
2192 // TrueOperand, And inverts for False operand.
2193 bool isaAndOp = isa<AndOp>(subExpr);
2194 if (isTrueOperand ^ isaAndOp)
2195 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2196
2197 auto extendedCond =
2198 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2199
2200 // Cache this information before subExpr is erased by extraction below.
2201 bool isaXorOp = isa<XorOp>(subExpr);
2202 bool isaOrOp = isa<OrOp>(subExpr);
2203
2204 // Handle the fully associative ops, start by pulling out the subexpression
2205 // from a many operand version of the op.
2206 auto restOfAssoc =
2207 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2208
2209 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2210 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2211 if (isaOrOp || isaXorOp) {
2212 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2213 restOfAssoc, false);
2214 if (isaXorOp)
2215 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2216 commonValue, false);
2217 else
2218 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2219 false);
2220 return true;
2221 }
2222
2223 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2224 assert(isaAndOp && "unexpected operation here");
2225 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2226 restOfAssoc, false);
2227 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2228 false);
2229 return true;
2230}
2231
2232/// This function is invoke when we find a mux with true/false operations that
2233/// have the same opcode. Check to see if we can strength reduce the mux by
2234/// applying it to less data by applying this transformation:
2235/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2236static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2237 Operation *falseOp,
2238 PatternRewriter &rewriter) {
2239 // Right now we only apply to concat.
2240 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2241 if (!isa<ConcatOp>(trueOp))
2242 return false;
2243
2244 // Decode the operands, looking through recursive concats and replicates.
2245 SmallVector<Value> trueOperands, falseOperands;
2246 getConcatOperands(trueOp->getResult(0), trueOperands);
2247 getConcatOperands(falseOp->getResult(0), falseOperands);
2248
2249 size_t numTrueOperands = trueOperands.size();
2250 size_t numFalseOperands = falseOperands.size();
2251
2252 if (!numTrueOperands || !numFalseOperands ||
2253 (trueOperands.front() != falseOperands.front() &&
2254 trueOperands.back() != falseOperands.back()))
2255 return false;
2256
2257 // Pull all leading shared operands out into their own op if any are common.
2258 if (trueOperands.front() == falseOperands.front()) {
2259 SmallVector<Value> operands;
2260 size_t i;
2261 for (i = 0; i < numTrueOperands; ++i) {
2262 Value trueOperand = trueOperands[i];
2263 if (trueOperand == falseOperands[i])
2264 operands.push_back(trueOperand);
2265 else
2266 break;
2267 }
2268 if (i == numTrueOperands) {
2269 // Selecting between distinct, but lexically identical, concats.
2270 replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
2271 return true;
2272 }
2273
2274 Value sharedMSB;
2275 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2276 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2277 mux->getLoc(), operands.front(), operands.size());
2278 else
2279 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2280 operands.clear();
2281
2282 // Get a concat of the LSB's on each side.
2283 operands.append(trueOperands.begin() + i, trueOperands.end());
2284 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2285 operands.clear();
2286 operands.append(falseOperands.begin() + i, falseOperands.end());
2287 Value falseLSB =
2288 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2289 // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2290 // done.
2291 Value lsb = rewriter.createOrFold<MuxOp>(
2292 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2293 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2294 return true;
2295 }
2296
2297 // If trailing operands match, try to commonize them.
2298 if (trueOperands.back() == falseOperands.back()) {
2299 SmallVector<Value> operands;
2300 size_t i;
2301 for (i = 0;; ++i) {
2302 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2303 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2304 operands.push_back(trueOperand);
2305 else
2306 break;
2307 }
2308 std::reverse(operands.begin(), operands.end());
2309 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2310 operands.clear();
2311
2312 // Get a concat of the MSB's on each side.
2313 operands.append(trueOperands.begin(), trueOperands.end() - i);
2314 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2315 operands.clear();
2316 operands.append(falseOperands.begin(), falseOperands.end() - i);
2317 Value falseMSB =
2318 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2319 // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2320 Value msb = rewriter.createOrFold<MuxOp>(
2321 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2322 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2323 return true;
2324 }
2325
2326 return false;
2327}
2328
2329// If both arguments of the mux are arrays with the same elements, sink the
2330// mux and return a uniform array initializing all elements to it.
2331static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2332 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2333 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2334 if (!trueVec || !falseVec)
2335 return false;
2336 if (!trueVec.isUniform() || !falseVec.isUniform())
2337 return false;
2338
2339 auto mux = MuxOp::create(rewriter, op.getLoc(), op.getCond(),
2340 trueVec.getUniformElement(),
2341 falseVec.getUniformElement(), op.getTwoState());
2342
2343 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2344 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2345 return true;
2346}
2347
2348/// If the mux condition is an operand to the op defining its true or false
2349/// value, replace the condition with 1 or 0.
2350static bool assumeMuxCondInOperand(Value muxCond, Value muxValue,
2351 bool constCond, PatternRewriter &rewriter) {
2352 if (!muxValue.hasOneUse())
2353 return false;
2354 auto *op = muxValue.getDefiningOp();
2355 if (!op || !isa_and_nonnull<CombDialect>(op->getDialect()))
2356 return false;
2357 if (!llvm::is_contained(op->getOperands(), muxCond))
2358 return false;
2359 OpBuilder::InsertionGuard guard(rewriter);
2360 rewriter.setInsertionPoint(op);
2361 auto condValue =
2362 hw::ConstantOp::create(rewriter, muxCond.getLoc(), APInt(1, constCond));
2363 rewriter.modifyOpInPlace(op, [&] {
2364 for (auto &use : op->getOpOperands())
2365 if (use.get() == muxCond)
2366 use.set(condValue);
2367 });
2368 return true;
2369}
2370
2371namespace {
2372struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2373 using OpRewritePattern::OpRewritePattern;
2374
2375 LogicalResult matchAndRewrite(MuxOp op,
2376 PatternRewriter &rewriter) const override;
2377};
2378
2379LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2380 PatternRewriter &rewriter) const {
2381 if (isOpTriviallyRecursive(op))
2382 return failure();
2383
2384 // If the op has a SV attribute, don't optimize it.
2385 if (hasSVAttributes(op))
2386 return failure();
2387 APInt value;
2388
2389 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2390 if (value.getBitWidth() == 1) {
2391 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2392 if (value.isZero()) {
2393 auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2394 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2395 op.getFalseValue(), false);
2396 return success();
2397 }
2398
2399 // mux(a, 1, b) -> or(a, b) for single-bit values.
2400 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2401 op.getFalseValue(), false);
2402 return success();
2403 }
2404
2405 // Check for mux of two constants. There are many ways to simplify them.
2406 APInt value2;
2407 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2408 // When both inputs are constants and differ by only one bit, we can
2409 // simplify by splitting the mux into up to three contiguous chunks: one
2410 // for the differing bit and up to two for the bits that are the same.
2411 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2412 APInt xorValue = value ^ value2;
2413 if (xorValue.isPowerOf2()) {
2414 unsigned leadingZeros = xorValue.countLeadingZeros();
2415 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2416 SmallVector<Value, 3> operands;
2417
2418 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2419 // order of bit indexes.
2420 // For the chunks that are identical (i.e. correspond to 0s in
2421 // xorValue), we can extract directly from either input value, and we
2422 // arbitrarily pick the trueValue().
2423
2424 if (leadingZeros > 0)
2425 operands.push_back(rewriter.createOrFold<ExtractOp>(
2426 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2427
2428 // Handle the differing bit, which should simplify into either cond or
2429 // ~cond.
2430 auto v1 = rewriter.createOrFold<ExtractOp>(
2431 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2432 auto v2 = rewriter.createOrFold<ExtractOp>(
2433 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2434 operands.push_back(rewriter.createOrFold<MuxOp>(
2435 op.getLoc(), op.getCond(), v1, v2, false));
2436
2437 if (trailingZeros > 0)
2438 operands.push_back(rewriter.createOrFold<ExtractOp>(
2439 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2440
2441 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2442 operands);
2443 return success();
2444 }
2445
2446 // If the true value is all ones and the false is all zeros then we have a
2447 // replicate pattern.
2448 if (value.isAllOnes() && value2.isZero()) {
2449 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2450 rewriter, op, op.getType(), op.getCond());
2451 return success();
2452 }
2453 }
2454 }
2455
2456 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2457 value.getBitWidth() == 1) {
2458 // mux(a, b, 0) -> and(a, b) for single-bit values.
2459 if (value.isZero()) {
2460 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2461 op.getTrueValue(), false);
2462 return success();
2463 }
2464
2465 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2466 // falseValue() is known to be a single-bit 1, which we can use for
2467 // the 1 in the representation of ~ using xor.
2468 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2469 op.getFalseValue(), false);
2470 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2471 op.getTrueValue(), false);
2472 return success();
2473 }
2474
2475 // mux(!a, b, c) -> mux(a, c, b)
2476 Value subExpr;
2477 Operation *condOp = op.getCond().getDefiningOp();
2478 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2479 op.getTwoState()) {
2480 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2481 subExpr, op.getFalseValue(),
2482 op.getTrueValue(), true);
2483 return success();
2484 }
2485
2486 // Same but with Demorgan's law.
2487 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2488 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2489 if (condOp && condOp->hasOneUse()) {
2490 SmallVector<Value> invertedOperands;
2491
2492 /// Scan all the operands to see if they are complemented. If so, build a
2493 /// vector of them and return true, otherwise return false.
2494 auto getInvertedOperands = [&]() -> bool {
2495 for (Value operand : condOp->getOperands()) {
2496 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2497 invertedOperands.push_back(subExpr);
2498 else
2499 return false;
2500 }
2501 return true;
2502 };
2503
2504 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2505 auto newOr =
2506 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2507 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2508 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2509 op.getTwoState());
2510 return success();
2511 }
2512 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2513 auto newAnd =
2514 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2515 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2516 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2517 op.getTwoState());
2518 return success();
2519 }
2520 }
2521
2522 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2523 falseMux && falseMux != op) {
2524 // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2525 if (op.getCond() == falseMux.getCond()) {
2526 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2527 rewriter, op, op.getCond(), op.getTrueValue(),
2528 falseMux.getFalseValue(), op.getTwoStateAttr());
2529 return success();
2530 }
2531
2532 // Check to see if we can fold a mux tree into an array_create/get pair.
2533 if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2534 return success();
2535 }
2536
2537 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2538 trueMux && trueMux != op) {
2539 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2540 if (op.getCond() == trueMux.getCond()) {
2541 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2542 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2543 op.getFalseValue(), op.getTwoStateAttr());
2544 return success();
2545 }
2546
2547 // Check to see if we can fold a mux tree into an array_create/get pair.
2548 if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2549 return success();
2550 }
2551
2552 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2553 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2554 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2555 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2556 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2557 falseMux != op) {
2558 auto subMux = MuxOp::create(
2559 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2560 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2561 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2562 trueMux.getTrueValue(), subMux,
2563 op.getTwoStateAttr());
2564 return success();
2565 }
2566
2567 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2568 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2569 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2570 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2571 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2572 falseMux != op) {
2573 auto subMux = MuxOp::create(
2574 rewriter, rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2575 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2576 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2577 subMux, trueMux.getFalseValue(),
2578 op.getTwoStateAttr());
2579 return success();
2580 }
2581
2582 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2583 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2584 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2585 trueMux && falseMux &&
2586 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2587 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2588 falseMux != op) {
2589 auto subMux =
2590 MuxOp::create(rewriter,
2591 rewriter.getFusedLoc(
2592 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2593 op.getCond(), trueMux.getCond(), falseMux.getCond());
2594 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2595 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2596 op.getTwoStateAttr());
2597 return success();
2598 }
2599
2600 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2601 if (foldCommonMuxValue(op, false, rewriter))
2602 return success();
2603 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2604 if (foldCommonMuxValue(op, true, rewriter))
2605 return success();
2606
2607 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2608 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2609 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2610 if (trueOp->getName() == falseOp->getName())
2611 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2612 return success();
2613
2614 // extracts only of mux(...) -> mux(extract()...)
2615 if (narrowOperationWidth(op, true, rewriter))
2616 return success();
2617
2618 // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2619 if (foldMuxOfUniformArrays(op, rewriter))
2620 return success();
2621
2622 // mux(cond, opA(cond), opB(cond)) -> mux(cond, opA(1), opB(1))
2623 if (op.getTrueValue().getDefiningOp() &&
2624 op.getTrueValue().getDefiningOp() != op)
2625 if (assumeMuxCondInOperand(op.getCond(), op.getTrueValue(), true, rewriter))
2626 return success();
2627 if (op.getFalseValue().getDefiningOp() &&
2628 op.getFalseValue().getDefiningOp() != op)
2629
2630 if (assumeMuxCondInOperand(op.getCond(), op.getFalseValue(), false,
2631 rewriter))
2632 return success();
2633
2634 return failure();
2635}
2636
2637static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2638 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2639 if (op.getInputs().empty() || op.isUniform())
2640 return false;
2641 auto inputs = op.getInputs();
2642 if (inputs.size() <= 1)
2643 return false;
2644
2645 // Check the operands to the array create. Ensure all of them are the
2646 // same op with the same number of operands.
2647 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2648 if (!first || hasSVAttributes(first))
2649 return false;
2650
2651 // Check whether all operands are muxes with the same condition.
2652 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2653 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2654 if (!input || first.getCond() != input.getCond())
2655 return false;
2656 }
2657
2658 // Collect the true and the false branches into arrays.
2659 SmallVector<Value> trues{first.getTrueValue()};
2660 SmallVector<Value> falses{first.getFalseValue()};
2661 SmallVector<Location> locs{first->getLoc()};
2662 bool isTwoState = true;
2663 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2664 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2665 trues.push_back(input.getTrueValue());
2666 falses.push_back(input.getFalseValue());
2667 locs.push_back(input->getLoc());
2668 if (!input.getTwoState())
2669 isTwoState = false;
2670 }
2671
2672 // Define the location of the array create as the aggregate of all muxes.
2673 auto loc = FusedLoc::get(op.getContext(), locs);
2674
2675 // Replace the create with an aggregate operation. Push the create op
2676 // into the operands of the aggregate operation.
2677 auto arrayTy = op.getType();
2678 auto trueValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, trues);
2679 auto falseValues = hw::ArrayCreateOp::create(rewriter, loc, arrayTy, falses);
2680 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2681 trueValues, falseValues, isTwoState);
2682 return true;
2683}
2684
2685struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2686 using OpRewritePattern::OpRewritePattern;
2687
2688 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2689 PatternRewriter &rewriter) const override {
2690 if (foldArrayOfMuxes(op, rewriter))
2691 return success();
2692 return failure();
2693 }
2694};
2695
2696} // namespace
2697
2698void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2699 MLIRContext *context) {
2700 results.insert<MuxRewriter, ArrayRewriter>(context);
2701}
2702
2703//===----------------------------------------------------------------------===//
2704// ICmpOp
2705//===----------------------------------------------------------------------===//
2706
2707// Calculate the result of a comparison when the LHS and RHS are both
2708// constants.
2709static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2710 const APInt &rhs) {
2711 switch (predicate) {
2712 case ICmpPredicate::eq:
2713 return lhs.eq(rhs);
2714 case ICmpPredicate::ne:
2715 return lhs.ne(rhs);
2716 case ICmpPredicate::slt:
2717 return lhs.slt(rhs);
2718 case ICmpPredicate::sle:
2719 return lhs.sle(rhs);
2720 case ICmpPredicate::sgt:
2721 return lhs.sgt(rhs);
2722 case ICmpPredicate::sge:
2723 return lhs.sge(rhs);
2724 case ICmpPredicate::ult:
2725 return lhs.ult(rhs);
2726 case ICmpPredicate::ule:
2727 return lhs.ule(rhs);
2728 case ICmpPredicate::ugt:
2729 return lhs.ugt(rhs);
2730 case ICmpPredicate::uge:
2731 return lhs.uge(rhs);
2732 case ICmpPredicate::ceq:
2733 return lhs.eq(rhs);
2734 case ICmpPredicate::cne:
2735 return lhs.ne(rhs);
2736 case ICmpPredicate::weq:
2737 return lhs.eq(rhs);
2738 case ICmpPredicate::wne:
2739 return lhs.ne(rhs);
2740 }
2741 llvm_unreachable("unknown comparison predicate");
2742}
2743
2744// Returns the result of applying the predicate when the LHS and RHS are the
2745// exact same value.
2746static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2747 switch (predicate) {
2748 case ICmpPredicate::eq:
2749 case ICmpPredicate::sle:
2750 case ICmpPredicate::sge:
2751 case ICmpPredicate::ule:
2752 case ICmpPredicate::uge:
2753 case ICmpPredicate::ceq:
2754 case ICmpPredicate::weq:
2755 return true;
2756 case ICmpPredicate::ne:
2757 case ICmpPredicate::slt:
2758 case ICmpPredicate::sgt:
2759 case ICmpPredicate::ult:
2760 case ICmpPredicate::ugt:
2761 case ICmpPredicate::cne:
2762 case ICmpPredicate::wne:
2763 return false;
2764 }
2765 llvm_unreachable("unknown comparison predicate");
2766}
2767
2768OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2769 // gt a, a -> false
2770 // gte a, a -> true
2771 if (getLhs() == getRhs()) {
2772 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2773 return IntegerAttr::get(getType(), val);
2774 }
2775
2776 // gt 1, 2 -> false
2777 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2778 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2779 auto val =
2780 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2781 return IntegerAttr::get(getType(), val);
2782 }
2783 }
2784 return {};
2785}
2786
// Counts how many leading elements of `a` and `b` compare equal position by
// position, stopping at the first mismatch or at the end of the shorter
// range. Only elements at the same index are compared. Callers obtain a
// common-suffix length by passing reversed ranges.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto itA = a.begin();
  auto itB = b.begin();
  size_t matched = 0;

  while (itA != a.end() && itB != b.end() && !(*itA != *itB)) {
    ++itA;
    ++itB;
    ++matched;
  }

  return matched;
}
2803
2804static size_t getTotalWidth(ArrayRef<Value> operands) {
2805 size_t totalWidth = 0;
2806 for (auto operand : operands) {
2807 // getIntOrFloatBitWidth should never raise, since all arguments to
2808 // ConcatOp are integers.
2809 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2810 assert(width >= 0);
2811 totalWidth += width;
2812 }
2813 return totalWidth;
2814}
2815
2816/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
2817/// comparison on common prefix and suffixes. Returns success() if a rewriting
2818/// happens. This handles both concat and replicate.
2819static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2820 Operation *rhs,
2821 PatternRewriter &rewriter) {
2822 // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2823 // all elements have non-zero length. Both these invariants are verified
2824 // by the ConcatOp verifier.
2825 SmallVector<Value> lhsOperands, rhsOperands;
2826 getConcatOperands(lhs->getResult(0), lhsOperands);
2827 getConcatOperands(rhs->getResult(0), rhsOperands);
2828 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2829
2830 auto formCatOrReplicate = [&](Location loc,
2831 ArrayRef<Value> operands) -> Value {
2832 assert(!operands.empty());
2833 Value sameElement = operands[0];
2834 for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2835 if (sameElement != operands[i])
2836 sameElement = Value();
2837 if (sameElement)
2838 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2839 operands.size());
2840 return rewriter.createOrFold<ConcatOp>(loc, operands);
2841 };
2842
2843 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2844 Value rhs) -> LogicalResult {
2845 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2846 op.getTwoState());
2847 return success();
2848 };
2849
2850 size_t commonPrefixLength =
2851 computeCommonPrefixLength(lhsOperands, rhsOperands);
2852 if (commonPrefixLength == lhsOperands.size()) {
2853 // cat(a, b, c) == cat(a, b, c) -> 1
2854 bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2855 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2856 APInt(1, result));
2857 return success();
2858 }
2859
2860 size_t commonSuffixLength = computeCommonPrefixLength(
2861 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2862
2863 size_t commonPrefixTotalWidth =
2864 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2865 size_t commonSuffixTotalWidth =
2866 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2867 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2868 .drop_back(commonSuffixLength);
2869 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2870 .drop_back(commonSuffixLength);
2871
2872 auto replaceWithoutReplicatingSignBit = [&]() {
2873 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2874 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2875 return replaceWith(op.getPredicate(), newLhs, newRhs);
2876 };
2877
2878 auto replaceWithReplicatingSignBit = [&]() {
2879 auto firstNonEmptyValue = lhsOperands[0];
2880 auto firstNonEmptyElemWidth =
2881 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2882 Value signBit = rewriter.createOrFold<ExtractOp>(
2883 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2884
2885 auto newLhs = ConcatOp::create(rewriter, lhs->getLoc(), signBit, lhsOnly);
2886 auto newRhs = ConcatOp::create(rewriter, rhs->getLoc(), signBit, rhsOnly);
2887 return replaceWith(op.getPredicate(), newLhs, newRhs);
2888 };
2889
2890 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2891 // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2892 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2893 return replaceWithoutReplicatingSignBit();
2894
2895 // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2896 // cat(sgn(b), ..y)) Note that we cannot perform this optimization if
2897 // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
2898 // bit. Doing the rewrite can result in an infinite loop.
2899 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2900 return replaceWithReplicatingSignBit();
2901
2902 } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2903 // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
2904 return replaceWithoutReplicatingSignBit();
2905 }
2906
2907 return failure();
2908}
2909
2910/// Given an equality comparison with a constant value and some operand that has
2911/// known bits, simplify the comparison to check only the unknown bits of the
2912/// input.
2913///
2914/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2915/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2916/// `extract x[1:0] == 0`
2918 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2919 PatternRewriter &rewriter) {
2920
2921 // If any of the known bits disagree with any of the comparison bits, then
2922 // we can constant fold this comparison right away.
2923 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2924 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2925 // If we discover a mismatch then we know an "eq" comparison is false
2926 // and a "ne" comparison is true!
2927 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2928 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2929 APInt(1, result));
2930 return;
2931 }
2932
2933 // Check to see if we can prove the result entirely of the comparison (in
2934 // which we bail out early), otherwise build a list of values to concat and a
2935 // smaller constant to compare against.
2936 SmallVector<Value> newConcatOperands;
2937 auto newConstant = APInt::getZeroWidth();
2938
2939 // Ok, some (maybe all) bits are known and some others may be unknown.
2940 // Extract out segments of the operand and compare against the
2941 // corresponding bits.
2942 unsigned knownMSB = bitsKnown.countLeadingOnes();
2943
2944 Value operand = cmpOp.getLhs();
2945
2946 // Ok, some bits are known but others are not. Extract out sequences of
2947 // bits that are unknown and compare just those bits. We work from MSB to
2948 // LSB.
2949 while (knownMSB != bitsKnown.getBitWidth()) {
2950 // Drop any high bits that are known.
2951 if (knownMSB)
2952 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2953
2954 // Find the span of unknown bits, and extract it.
2955 unsigned unknownBits = bitsKnown.countLeadingZeros();
2956 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2957 auto spanOperand = rewriter.createOrFold<ExtractOp>(
2958 operand.getLoc(), operand, /*lowBit=*/lowBit,
2959 /*bitWidth=*/unknownBits);
2960 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2961
2962 // Add this info to the concat we're generating.
2963 newConcatOperands.push_back(spanOperand);
2964 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
2965 // newConstant = newConstant.concat(spanConstant);
2966 if (newConstant.getBitWidth() != 0)
2967 newConstant = newConstant.concat(spanConstant);
2968 else
2969 newConstant = spanConstant;
2970
2971 // Drop the unknown bits in prep for the next chunk.
2972 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2973 bitsKnown = bitsKnown.trunc(newWidth);
2974 knownMSB = bitsKnown.countLeadingOnes();
2975 }
2976
2977 // If all the operands to the concat are foldable then we have an identity
2978 // situation where all the sub-elements equal each other. This implies that
2979 // the overall result is foldable.
2980 if (newConcatOperands.empty()) {
2981 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2982 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2983 APInt(1, result));
2984 return;
2985 }
2986
2987 // If we have a single operand remaining, use it, otherwise form a concat.
2988 Value concatResult =
2989 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
2990
2991 // Form the comparison against the smaller constant.
2992 auto newConstantOp = hw::ConstantOp::create(
2993 rewriter, cmpOp.getOperand(1).getLoc(), newConstant);
2994
2995 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
2996 cmpOp.getPredicate(), concatResult,
2997 newConstantOp, cmpOp.getTwoState());
2998}
2999
3000// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
3001static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
3002 const APInt &rhs,
3003 PatternRewriter &rewriter) {
3004 auto ip = rewriter.saveInsertionPoint();
3005 rewriter.setInsertionPoint(xorOp);
3006
3007 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
3008 auto newRHS = hw::ConstantOp::create(rewriter, xorRHS->getLoc(),
3009 xorRHS.getValue() ^ rhs);
3010 Value newLHS;
3011 switch (xorOp.getNumOperands()) {
3012 case 1:
3013 // This isn't common but is defined so we need to handle it.
3014 newLHS = hw::ConstantOp::create(rewriter, xorOp.getLoc(),
3015 APInt::getZero(rhs.getBitWidth()));
3016 break;
3017 case 2:
3018 // The binary case is the most common.
3019 newLHS = xorOp.getOperand(0);
3020 break;
3021 default:
3022 // The general case forces us to form a new xor with the remaining operands.
3023 SmallVector<Value> newOperands(xorOp.getOperands());
3024 newOperands.pop_back();
3025 newLHS = XorOp::create(rewriter, xorOp.getLoc(), newOperands, false);
3026 break;
3027 }
3028
3029 bool xorMultipleUses = !xorOp->hasOneUse();
3030
3031 // If the xor has multiple uses (not just the compare, then we need/want to
3032 // replace them as well.
3033 if (xorMultipleUses)
3034 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3035 false);
3036
3037 // Replace the comparison.
3038 rewriter.restoreInsertionPoint(ip);
3039 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3040 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
3041}
3042
3043LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3044 if (isOpTriviallyRecursive(op))
3045 return failure();
3046 APInt lhs, rhs;
3047
3048 // icmp 1, x -> icmp x, 1
3049 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3050 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3051 "Should be folded");
3052 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3053 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3054 op.getRhs(), op.getLhs(), op.getTwoState());
3055 return success();
3056 }
3057
3058 // Canonicalize with RHS constant
3059 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3060 auto getConstant = [&](APInt constant) -> Value {
3061 return hw::ConstantOp::create(rewriter, op.getLoc(), std::move(constant));
3062 };
3063
3064 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3065 Value rhs) -> LogicalResult {
3066 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3067 rhs, op.getTwoState());
3068 return success();
3069 };
3070
3071 auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3072 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3073 APInt(1, constant));
3074 return success();
3075 };
3076
3077 switch (op.getPredicate()) {
3078 case ICmpPredicate::slt:
3079 // x < max -> x != max
3080 if (rhs.isMaxSignedValue())
3081 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3082 // x < min -> false
3083 if (rhs.isMinSignedValue())
3084 return replaceWithConstantI1(0);
3085 // x < min+1 -> x == min
3086 if ((rhs - 1).isMinSignedValue())
3087 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3088 getConstant(rhs - 1));
3089 break;
3090 case ICmpPredicate::sgt:
3091 // x > min -> x != min
3092 if (rhs.isMinSignedValue())
3093 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3094 // x > max -> false
3095 if (rhs.isMaxSignedValue())
3096 return replaceWithConstantI1(0);
3097 // x > max-1 -> x == max
3098 if ((rhs + 1).isMaxSignedValue())
3099 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3100 getConstant(rhs + 1));
3101 break;
3102 case ICmpPredicate::ult:
3103 // x < max -> x != max
3104 if (rhs.isAllOnes())
3105 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3106 // x < min -> false
3107 if (rhs.isZero())
3108 return replaceWithConstantI1(0);
3109 // x < min+1 -> x == min
3110 if ((rhs - 1).isZero())
3111 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3112 getConstant(rhs - 1));
3113
3114 // x < 0xE0 -> extract(x, 5..7) != 0b111
3115 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3116 rhs.getBitWidth()) {
3117 auto numOnes = rhs.countLeadingOnes();
3118 auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
3119 rhs.getBitWidth() - numOnes, numOnes);
3120 return replaceWith(ICmpPredicate::ne, smaller,
3121 getConstant(APInt::getAllOnes(numOnes)));
3122 }
3123
3124 break;
3125 case ICmpPredicate::ugt:
3126 // x > min -> x != min
3127 if (rhs.isZero())
3128 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3129 // x > max -> false
3130 if (rhs.isAllOnes())
3131 return replaceWithConstantI1(0);
3132 // x > max-1 -> x == max
3133 if ((rhs + 1).isAllOnes())
3134 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3135 getConstant(rhs + 1));
3136
3137 // x > 0x07 -> extract(x, 3..7) != 0b00000
3138 if ((rhs + 1).isPowerOf2()) {
3139 auto numOnes = rhs.countTrailingOnes();
3140 auto newWidth = rhs.getBitWidth() - numOnes;
3141 auto smaller = ExtractOp::create(rewriter, op.getLoc(), op.getLhs(),
3142 numOnes, newWidth);
3143 return replaceWith(ICmpPredicate::ne, smaller,
3144 getConstant(APInt::getZero(newWidth)));
3145 }
3146
3147 break;
3148 case ICmpPredicate::sle:
3149 // x <= max -> true
3150 if (rhs.isMaxSignedValue())
3151 return replaceWithConstantI1(1);
3152 // x <= c -> x < (c+1)
3153 return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3154 case ICmpPredicate::sge:
3155 // x >= min -> true
3156 if (rhs.isMinSignedValue())
3157 return replaceWithConstantI1(1);
3158 // x >= c -> x > (c-1)
3159 return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3160 case ICmpPredicate::ule:
3161 // x <= max -> true
3162 if (rhs.isAllOnes())
3163 return replaceWithConstantI1(1);
3164 // x <= c -> x < (c+1)
3165 return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3166 case ICmpPredicate::uge:
3167 // x >= min -> true
3168 if (rhs.isZero())
3169 return replaceWithConstantI1(1);
3170 // x >= c -> x > (c-1)
3171 return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3172 case ICmpPredicate::eq:
3173 if (rhs.getBitWidth() == 1) {
3174 if (rhs.isZero()) {
3175 // x == 0 -> x ^ 1
3176 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3177 getConstant(APInt(1, 1)),
3178 op.getTwoState());
3179 return success();
3180 }
3181 if (rhs.isAllOnes()) {
3182 // x == 1 -> x
3183 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3184 return success();
3185 }
3186 }
3187 break;
3188 case ICmpPredicate::ne:
3189 if (rhs.getBitWidth() == 1) {
3190 if (rhs.isZero()) {
3191 // x != 0 -> x
3192 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3193 return success();
3194 }
3195 if (rhs.isAllOnes()) {
3196 // x != 1 -> x ^ 1
3197 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3198 getConstant(APInt(1, 1)),
3199 op.getTwoState());
3200 return success();
3201 }
3202 }
3203 break;
3204 case ICmpPredicate::ceq:
3205 case ICmpPredicate::cne:
3206 case ICmpPredicate::weq:
3207 case ICmpPredicate::wne:
3208 break;
3209 }
3210
3211 // We have some specific optimizations for comparison with a constant that
3212 // are only supported for equality comparisons.
3213 if (op.getPredicate() == ICmpPredicate::eq ||
3214 op.getPredicate() == ICmpPredicate::ne) {
3215 // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3216 // with a smaller constant. We only support equality comparisons for
3217 // this.
3218 auto knownBits = computeKnownBits(op.getLhs());
3219 if (!knownBits.isUnknown())
3220 return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3221 rewriter),
3222 success();
3223
3224 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3225 // cst1^cst2).
3226 if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3227 if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3228 return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3229 success();
3230
3231 // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3232 // all one.
3233 if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3234 if (rhs.isAllOnes() || rhs.isZero()) {
3235 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3236 auto cst =
3237 hw::ConstantOp::create(rewriter, op.getLoc(),
3238 rhs.isAllOnes() ? APInt::getAllOnes(width)
3239 : APInt::getZero(width));
3240 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3241 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3242 op.getTwoState());
3243 return success();
3244 }
3245 }
3246 }
3247
3248 // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3249 // b), cat(c, d)). contains special handling for sign bit in signed
3250 // compressions.
3251 if (Operation *opLHS = op.getLhs().getDefiningOp())
3252 if (Operation *opRHS = op.getRhs().getDefiningOp())
3253 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3254 isa<ConcatOp, ReplicateOp>(opRHS)) {
3255 if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3256 return success();
3257 }
3258
3259 return failure();
3260}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition CombFolds.cpp:83
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition CombFolds.cpp:51
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition CombFolds.cpp:45
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition CombFolds.cpp:89
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool assumeMuxCondInOperand(Value muxCond, Value muxValue, bool constCond, PatternRewriter &rewriter)
If the mux condition is an operand to the op defining its true or false value, replace the condition ...
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoked when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool isOpTriviallyRecursive(Operation *op)
Definition CombFolds.cpp:26
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
Definition CombFolds.cpp:66
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition CombFolds.cpp:37
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements)
Definition hw.py:483
create(array_value, low_index, ret_type)
Definition hw.py:466
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
Definition CombOps.cpp:65
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1