CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
CombFolds.cpp
Go to the documentation of this file.
1//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/KnownBits.h"
19
20using namespace mlir;
21using namespace circt;
22using namespace comb;
23using namespace matchers;
24
25/// In comb, we assume no knowledge of the semantics of cross-block dataflow. As
26/// such, cross-block dataflow is interpreted as a canonicalization barrier.
27/// This is a conservative approach which:
28/// 1. still allows for efficient canonicalization for the common CIRCT usecase
29/// of comb (comb logic nested inside single-block hw.module's)
30/// 2. allows comb operations to be used in non-HW container ops - that may use
31/// MLIR blocks and regions to represent various forms of hierarchical
32/// abstractions, thus allowing comb to compose with other dialects.
33static bool hasOperandsOutsideOfBlock(Operation *op) {
34 Block *thisBlock = op->getBlock();
35 return llvm::any_of(op->getOperands(), [&](Value operand) {
36 return operand.getParentBlock() != thisBlock;
37 });
38}
39
40/// Create a new instance of a generic operation that only has value operands,
41/// and has a single result value whose type matches the first operand.
42///
43/// This should not be used to create instances of ops with attributes or with
44/// more complicated type signatures.
45static Value createGenericOp(Location loc, OperationName name,
46 ArrayRef<Value> operands, OpBuilder &builder) {
47 OperationState state(loc, name);
48 state.addOperands(operands);
49 state.addTypes(operands[0].getType());
50 return builder.create(state)->getResult(0);
51}
52
53static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
54 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
55 value);
56}
57
58/// Flatten concat and mux operands into a vector.
59static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
60 if (auto concat = v.getDefiningOp<ConcatOp>()) {
61 for (auto op : concat.getOperands())
62 getConcatOperands(op, result);
63 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
64 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
65 getConcatOperands(repl.getOperand(), result);
66 } else {
67 result.push_back(v);
68 }
69}
70
71// Return true if the op has SV attributes. Note that we cannot use a helper
72// function `hasSVAttributes` defined under SV dialect because of a cyclic
73// dependency.
74static bool hasSVAttributes(Operation *op) {
75 return op->hasAttr("sv.attributes");
76}
77
78namespace {
79template <typename SubType>
80struct ComplementMatcher {
81 SubType lhs;
82 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
83 bool match(Operation *op) {
84 auto xorOp = dyn_cast<XorOp>(op);
85 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
86 }
87};
88} // end anonymous namespace
89
90template <typename SubType>
91static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
92 return ComplementMatcher<SubType>(subExpr);
93}
94
95/// Return true if the op will be flattened afterwards. Op will be flattend if
96/// it has a single user which has a same op type. User must be in same block.
97static bool shouldBeFlattened(Operation *op) {
98 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
99 "must be commutative operations"));
100 if (op->hasOneUse()) {
101 auto *user = *op->getUsers().begin();
102 return user->getName() == op->getName() &&
103 op->getAttrOfType<UnitAttr>("twoState") ==
104 user->getAttrOfType<UnitAttr>("twoState") &&
105 op->getBlock() == user->getBlock();
106 }
107 return false;
108}
109
110/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
111/// as an Op. Returns true if successful, and false otherwise.
112///
113/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
114///
115static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
116 // Skip if the operation should be flattened by another operation.
117 if (shouldBeFlattened(op))
118 return false;
119
120 auto inputs = op->getOperands();
121
122 SmallVector<Value, 4> newOperands;
123 SmallVector<Location, 4> newLocations{op->getLoc()};
124 newOperands.reserve(inputs.size());
125 struct Element {
126 decltype(inputs.begin()) current, end;
127 };
128
129 SmallVector<Element> worklist;
130 worklist.push_back({inputs.begin(), inputs.end()});
131 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
132 bool changed = false;
133 while (!worklist.empty()) {
134 auto &element = worklist.back(); // Do not pop. Take ref.
135
136 // Pop when we finished traversing the current operand range.
137 if (element.current == element.end) {
138 worklist.pop_back();
139 continue;
140 }
141
142 Value value = *element.current++;
143 auto *flattenOp = value.getDefiningOp();
144 // If not defined by a compatible operation of the same kind and
145 // from the same block, keep this as-is.
146 if (!flattenOp || flattenOp->getName() != op->getName() ||
147 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
148 flattenOp->getBlock() != op->getBlock()) {
149 newOperands.push_back(value);
150 continue;
151 }
152
153 // Don't duplicate logic when it has multiple uses.
154 if (!value.hasOneUse()) {
155 // We can fold a multi-use binary operation into this one if this allows a
156 // constant to fold though. For example, fold
157 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
158 // since the constants will both fold and we end up with the equiv cost.
159 //
160 // We don't do this for add/mul because the hardware won't be shared
161 // between the two ops if duplicated.
162 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
163 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
164 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
165 newOperands.push_back(value);
166 continue;
167 }
168 }
169
170 changed = true;
171
172 // Otherwise, push operands into worklist.
173 auto flattenOpInputs = flattenOp->getOperands();
174 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
175 newLocations.push_back(flattenOp->getLoc());
176 }
177
178 if (!changed)
179 return false;
180
181 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
182 op->getName(), newOperands, rewriter);
183 if (binFlag)
184 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
185
186 replaceOpAndCopyNamehint(rewriter, op, result);
187 return true;
188}
189
190// Given a range of uses of an operation, find the lowest and highest bits
191// inclusive that are ever referenced. The range of uses must not be empty.
192static std::pair<size_t, size_t>
193getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
194 size_t originalOpWidth) {
195 auto users = op->getUsers();
196 assert(!users.empty() &&
197 "getLowestBitAndHighestBitRequired cannot operate on "
198 "a empty list of uses.");
199
200 // when we don't want to narrowTrailingBits (namely in arithmetic
201 // operations), forcing lowestBitRequired = 0
202 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
203 size_t highestBitRequired = 0;
204
205 for (auto *user : users) {
206 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
207 size_t lowBit = extractOp.getLowBit();
208 size_t highBit =
209 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
210 highestBitRequired = std::max(highestBitRequired, highBit);
211 lowestBitRequired = std::min(lowestBitRequired, lowBit);
212 continue;
213 }
214
215 highestBitRequired = originalOpWidth - 1;
216 lowestBitRequired = 0;
217 break;
218 }
219
220 return {lowestBitRequired, highestBitRequired};
221}
222
223template <class OpTy>
224static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
225 PatternRewriter &rewriter) {
226 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
227 if (!opType)
228 return false;
229
230 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
231 opType.getWidth());
232 if (range.second + 1 == opType.getWidth() && range.first == 0)
233 return false;
234
235 SmallVector<Value> args;
236 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
237 for (auto inop : op.getOperands()) {
238 // deal with muxes here
239 if (inop.getType() != op.getType())
240 args.push_back(inop);
241 else
242 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
243 inop, range.first));
244 }
245 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
246 newop->setDialectAttrs(op->getDialectAttrs());
247 if (op.getTwoState())
248 newop.setTwoState(true);
249
250 Value newResult = newop.getResult();
251 if (range.first)
252 newResult = rewriter.createOrFold<ConcatOp>(
253 op.getLoc(), newResult,
254 rewriter.create<hw::ConstantOp>(op.getLoc(),
255 APInt::getZero(range.first)));
256 if (range.second + 1 < opType.getWidth())
257 newResult = rewriter.createOrFold<ConcatOp>(
258 op.getLoc(),
259 rewriter.create<hw::ConstantOp>(
260 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
261 newResult);
262 rewriter.replaceOp(op, newResult);
263 return true;
264}
265
266//===----------------------------------------------------------------------===//
267// Unary Operations
268//===----------------------------------------------------------------------===//
269
270OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
271 if (hasOperandsOutsideOfBlock(getOperation()))
272 return {};
273
274 // Replicate one time -> noop.
275 if (cast<IntegerType>(getType()).getWidth() ==
276 getInput().getType().getIntOrFloatBitWidth())
277 return getInput();
278
279 // Constant fold.
280 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
281 if (input.getValue().getBitWidth() == 1) {
282 if (input.getValue().isZero())
283 return getIntAttr(
284 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
285 getContext());
286 return getIntAttr(
287 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
288 getContext());
289 }
290
291 APInt result = APInt::getZeroWidth();
292 for (auto i = getMultiple(); i != 0; --i)
293 result = result.concat(input.getValue());
294 return getIntAttr(result, getContext());
295 }
296
297 return {};
298}
299
300OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
301 if (hasOperandsOutsideOfBlock(getOperation()))
302 return {};
303
304 // Constant fold.
305 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
306 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
307
308 return {};
309}
310
311//===----------------------------------------------------------------------===//
312// Binary Operations
313//===----------------------------------------------------------------------===//
314
315/// Performs constant folding `calculate` with element-wise behavior on the two
316/// attributes in `operands` and returns the result if possible.
317static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
318 hw::PEO paramOpcode) {
319 assert(operands.size() == 2 && "binary op takes two operands");
320 if (!operands[0] || !operands[1])
321 return {};
322
323 // Fold constants with ParamExprAttr::get which handles simple constants as
324 // well as parameter expressions.
325 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
326 cast<TypedAttr>(operands[1]));
327}
328
329OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
330 if (hasOperandsOutsideOfBlock(getOperation()))
331 return {};
332
333 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
334 unsigned shift = rhs.getValue().getZExtValue();
335 unsigned width = getType().getIntOrFloatBitWidth();
336 if (shift == 0)
337 return getOperand(0);
338 if (width <= shift)
339 return getIntAttr(APInt::getZero(width), getContext());
340 }
341
342 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
343}
344
345LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
347 return failure();
348
349 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
350 APInt value;
351 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
352 return failure();
353
354 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
355 unsigned shift = value.getZExtValue();
356
357 // This case is handled by fold.
358 if (width <= shift || shift == 0)
359 return failure();
360
361 auto zeros =
362 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
363
364 // Remove the high bits which would be removed by the Shl.
365 auto extract =
366 rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
367
368 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
369 return success();
370}
371
372OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
373 if (hasOperandsOutsideOfBlock(getOperation()))
374 return {};
375
376 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
377 unsigned shift = rhs.getValue().getZExtValue();
378 if (shift == 0)
379 return getOperand(0);
380
381 unsigned width = getType().getIntOrFloatBitWidth();
382 if (width <= shift)
383 return getIntAttr(APInt::getZero(width), getContext());
384 }
385 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
386}
387
388LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
390 return failure();
391
392 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
393 APInt value;
394 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
395 return failure();
396
397 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
398 unsigned shift = value.getZExtValue();
399
400 // This case is handled by fold.
401 if (width <= shift || shift == 0)
402 return failure();
403
404 auto zeros =
405 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
406
407 // Remove the low bits which would be removed by the Shr.
408 auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
409 width - shift);
410
411 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
412 return success();
413}
414
415OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
416 if (hasOperandsOutsideOfBlock(getOperation()))
417 return {};
418
419 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
420 if (rhs.getValue().getZExtValue() == 0)
421 return getOperand(0);
422 }
423 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
424}
425
426LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
428 return failure();
429
430 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
431 APInt value;
432 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
433 return failure();
434
435 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
436 unsigned shift = value.getZExtValue();
437
438 auto topbit =
439 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
440 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
441
442 if (width <= shift) {
443 replaceOpAndCopyNamehint(rewriter, op, {sext});
444 return success();
445 }
446
447 auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
448 width - shift);
449
450 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
451 return success();
452}
453
454//===----------------------------------------------------------------------===//
455// Other Operations
456//===----------------------------------------------------------------------===//
457
458OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
459 if (hasOperandsOutsideOfBlock(getOperation()))
460 return {};
461
462 // If we are extracting the entire input, then return it.
463 if (getInput().getType() == getType())
464 return getInput();
465
466 // Constant fold.
467 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
468 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
469 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
470 getContext());
471 }
472 return {};
473}
474
475// Transforms extract(lo, cat(a, b, c, d, e)) into
476// cat(extract(lo1, b), c, extract(lo2, d)).
477// innerCat must be the argument of the provided ExtractOp.
479 ConcatOp innerCat,
480 PatternRewriter &rewriter) {
481 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
482 size_t beginOfFirstRelevantElement = 0;
483 auto it = reversedConcatArgs.begin();
484 size_t lowBit = op.getLowBit();
485
486 // This loop finds the first concatArg that is covered by the ExtractOp
487 for (; it != reversedConcatArgs.end(); it++) {
488 assert(beginOfFirstRelevantElement <= lowBit &&
489 "incorrectly moved past an element that lowBit has coverage over");
490 auto operand = *it;
491
492 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
493 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
494 // A bit other than the first bit will be used in this element.
495 // ...... ........ ...
496 // ^---lowBit
497 // ^---beginOfFirstRelevantElement
498 //
499 // Edge-case close to the end of the range.
500 // ...... ........ ...
501 // ^---(position + operandWidth)
502 // ^---lowBit
503 // ^---beginOfFirstRelevantElement
504 //
505 // Edge-case close to the beginning of the rang
506 // ...... ........ ...
507 // ^---lowBit
508 // ^---beginOfFirstRelevantElement
509 //
510 break;
511 }
512
513 // extraction discards this element.
514 // ...... ........ ...
515 // | ^---lowBit
516 // ^---beginOfFirstRelevantElement
517 beginOfFirstRelevantElement += operandWidth;
518 }
519 assert(it != reversedConcatArgs.end() &&
520 "incorrectly failed to find an element which contains coverage of "
521 "lowBit");
522
523 SmallVector<Value> reverseConcatArgs;
524 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
525 size_t extractLo = lowBit - beginOfFirstRelevantElement;
526
527 // Transform individual arguments of innerCat(..., a, b, c,) into
528 // [ extract(a), b, extract(c) ], skipping an extract operation where
529 // possible.
530 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
531 auto concatArg = *it;
532 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
533 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
534
535 if (widthToConsume == operandWidth && extractLo == 0) {
536 reverseConcatArgs.push_back(concatArg);
537 } else {
538 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
539 reverseConcatArgs.push_back(
540 rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
541 }
542
543 widthRemaining -= widthToConsume;
544
545 // Beyond the first element, all elements are extracted from position 0.
546 extractLo = 0;
547 }
548
549 if (reverseConcatArgs.size() == 1) {
550 replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
551 } else {
552 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
553 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
554 }
555 return success();
556}
557
558// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
559static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
560 PatternRewriter &rewriter) {
561 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
562 auto replicateEltWidth =
563 replicate.getOperand().getType().getIntOrFloatBitWidth();
564
565 // If the extract starts at the base of an element and is an even multiple,
566 // we can replace the extract with a smaller replicate.
567 if (op.getLowBit() % replicateEltWidth == 0 &&
568 extractResultWidth % replicateEltWidth == 0) {
569 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
570 replicate.getOperand());
571 return true;
572 }
573
574 // If the extract is completely contained in one element, extract from the
575 // element.
576 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
577 replicateEltWidth) {
578 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
579 rewriter, op, op.getType(), replicate.getOperand(),
580 op.getLowBit() % replicateEltWidth);
581 return true;
582 }
583
584 // We don't currently handle the case of extracting from non-whole elements,
585 // e.g. `extract (replicate 2-bit-thing, N), 1`.
586 return false;
587}
588
589LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
591 return failure();
592
593 auto *inputOp = op.getInput().getDefiningOp();
594
595 // This turns out to be incredibly expensive. Disable until performance is
596 // addressed.
597#if 0
598 // If the extracted bits are all known, then return the result.
599 auto knownBits = computeKnownBits(op.getInput())
600 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
601 op.getLowBit());
602 if (knownBits.isConstant()) {
603 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
604 knownBits.getConstant());
605 return success();
606 }
607#endif
608
609 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
610 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
611 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
612 rewriter, op, op.getType(), innerExtract.getInput(),
613 innerExtract.getLowBit() + op.getLowBit());
614 return success();
615 }
616
617 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
618 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
619 return extractConcatToConcatExtract(op, innerCat, rewriter);
620
621 // extract(lo, replicate(a))
622 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
623 if (extractFromReplicate(op, replicate, rewriter))
624 return success();
625
626 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
627 // and/or/xor are not modifying the extracted bits.
628 if (inputOp && inputOp->getNumOperands() == 2 &&
629 isa<AndOp, OrOp, XorOp>(inputOp)) {
630 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
631 auto extractedCst = cstRHS.getValue().extractBits(
632 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
633 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
634 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
635 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
636 return success();
637 }
638
639 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
640 // extract to represent the result. Turning it into a pile of extracts is
641 // always fine by our cost model, but we don't want to explode things into
642 // a ton of bits because it will bloat the IR and generated Verilog.
643 if (isa<AndOp>(inputOp)) {
644 // For our cost model, we only do this if the bit pattern is a
645 // contiguous series of ones.
646 unsigned lz = extractedCst.countLeadingZeros();
647 unsigned tz = extractedCst.countTrailingZeros();
648 unsigned pop = extractedCst.popcount();
649 if (extractedCst.getBitWidth() - lz - tz == pop) {
650 auto resultTy = rewriter.getIntegerType(pop);
651 SmallVector<Value> resultElts;
652 if (lz)
653 resultElts.push_back(rewriter.create<hw::ConstantOp>(
654 op.getLoc(), APInt::getZero(lz)));
655 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
656 op.getLoc(), resultTy, inputOp->getOperand(0),
657 op.getLowBit() + tz));
658 if (tz)
659 resultElts.push_back(rewriter.create<hw::ConstantOp>(
660 op.getLoc(), APInt::getZero(tz)));
661 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
662 return success();
663 }
664 }
665 }
666 }
667
668 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
669 // extracted.
670 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
671 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
672 // Don't canonicalize if the shift is multiply used.
673 if (shlOp->hasOneUse())
674 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
675 if (lhsCst.getValue().isOne()) {
676 auto newCst = rewriter.create<hw::ConstantOp>(
677 shlOp.getLoc(),
678 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
679 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
680 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
681 false);
682 return success();
683 }
684 }
685
686 return failure();
687}
688
689//===----------------------------------------------------------------------===//
690// Associative Variadic operations
691//===----------------------------------------------------------------------===//
692
693// Reduce all operands to a single value (either integer constant or parameter
694// expression) if all the operands are constants.
695static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
696 hw::PEO paramOpcode) {
697 assert(operands.size() > 1 && "caller should handle one-operand case");
698 // We can only fold anything in the case where all operands are known to be
699 // constants. Check the least common one first for an early out.
700 if (!operands[1] || !operands[0])
701 return {};
702
703 // This will fold to a simple constant if all operands are constant.
704 if (llvm::all_of(operands.drop_front(2),
705 [&](Attribute in) { return !!in; })) {
706 SmallVector<mlir::TypedAttr> typedOperands;
707 typedOperands.reserve(operands.size());
708 for (auto operand : operands) {
709 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
710 typedOperands.push_back(typedOperand);
711 else
712 break;
713 }
714 if (typedOperands.size() == operands.size())
715 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
716 }
717
718 return {};
719}
720
721/// When we find a logical operation (and, or, xor) with a constant e.g.
722/// `X & 42`, we want to push the constant into the computation of X if it leads
723/// to simplification.
724///
725/// This function handles the case where the logical operation has a concat
726/// operand. We check to see if we can simplify the concat, e.g. when it has
727/// constant operands.
728///
729/// This returns true when a simplification happens.
730static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
731 size_t concatIdx, const APInt &cst,
732 PatternRewriter &rewriter) {
733 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
734 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
735
736 // Check to see if any operands can be simplified by pushing the logical op
737 // into all parts of the concat.
738 bool canSimplify =
739 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
740 auto *operandOp = operand.getDefiningOp();
741 if (!operandOp)
742 return false;
743
744 // If the concat has a constant operand then we can transform this.
745 if (isa<hw::ConstantOp>(operandOp))
746 return true;
747 // If the concat has the same logical operation and that operation has
748 // a constant operation than we can fold it into that suboperation.
749 return operandOp->getName() == logicalOp->getName() &&
750 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
751 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
752 });
753
754 if (!canSimplify)
755 return false;
756
757 // Create a new instance of the logical operation. We have to do this the
758 // hard way since we're generic across a family of different ops.
759 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
760 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
761 rewriter);
762 };
763
764 // Ok, let's do the transformation. We do this by slicing up the constant
765 // for each unit of the concat and duplicate the operation into the
766 // sub-operand.
767 SmallVector<Value> newConcatOperands;
768 newConcatOperands.reserve(concatOp->getNumOperands());
769
770 // Work from MSB to LSB.
771 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
772 for (Value operand : concatOp->getOperands()) {
773 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
774 nextOperandBit -= operandWidth;
775 // Take a slice of the constant.
776 auto eltCst = rewriter.create<hw::ConstantOp>(
777 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
778
779 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
780 }
781
782 // Create the concat, and the rest of the logical op if we need it.
783 Value newResult =
784 rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);
785
786 // If we had a variadic logical op on the top level, then recreate it with the
787 // new concat and without the constant operand.
788 if (logicalOp->getNumOperands() > 2) {
789 auto origOperands = logicalOp->getOperands();
790 SmallVector<Value> operands;
791 // Take any stuff before the concat.
792 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
793 // Take any stuff after the concat but before the constant.
794 operands.append(origOperands.begin() + concatIdx + 1,
795 origOperands.begin() + (origOperands.size() - 1));
796 // Include the new concat.
797 operands.push_back(newResult);
798 newResult = createLogicalOp(operands);
799 }
800
801 replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
802 return true;
803}
804
805// Determines whether the inputs to a logical element are of opposite
806// comparisons and can lowered into a constant.
807static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
808 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
809
810 for (auto op : operands) {
811 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
812 icmpOp && icmpOp.getTwoState()) {
813 auto predicate = icmpOp.getPredicate();
814 auto lhs = icmpOp.getLhs();
815 auto rhs = icmpOp.getRhs();
816 if (seenPredicates.contains(
817 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
818 return true;
819
820 seenPredicates.insert({predicate, lhs, rhs});
821 }
822 }
823 return false;
824}
825
826OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
827 if (hasOperandsOutsideOfBlock(getOperation()))
828 return {};
829
830 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
831
832 auto inputs = adaptor.getInputs();
833
834 // and(x, 01, 10) -> 00 -- annulment.
835 for (auto operand : inputs) {
836 if (!operand)
837 continue;
838 value &= cast<IntegerAttr>(operand).getValue();
839 if (value.isZero())
840 return getIntAttr(value, getContext());
841 }
842
843 // and(x, -1) -> x.
844 if (inputs.size() == 2 && inputs[1] &&
845 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
846 return getInputs()[0];
847
848 // and(x, x, x) -> x. This also handles and(x) -> x.
849 if (llvm::all_of(getInputs(),
850 [&](auto in) { return in == this->getInputs()[0]; }))
851 return getInputs()[0];
852
853 // and(..., x, ..., ~x, ...) -> 0
854 for (Value arg : getInputs()) {
855 Value subExpr;
856 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
857 for (Value arg2 : getInputs())
858 if (arg2 == subExpr)
859 return getIntAttr(
860 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
861 getContext());
862 }
863 }
864
865 // x0 = icmp(pred, x, y)
866 // x1 = icmp(!pred, x, y)
867 // and(x0, x1) -> 0
869 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
870 getContext());
871
872 // Constant fold
873 return constFoldAssociativeOp(inputs, hw::PEO::And);
874}
875
876/// Returns a single common operand that all inputs of the operation `op` can
877/// be traced back to, or an empty `Value` if no such operand exists.
878///
879/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
880/// (assuming the bit-width of `a` is `n`).
881template <typename Op>
882static Value getCommonOperand(Op op) {
883 if (!op.getType().isInteger(1))
884 return Value();
885
886 auto inputs = op.getInputs();
887 size_t size = inputs.size();
888
889 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
890 if (!sourceOp)
891 return Value();
892 Value source = sourceOp.getOperand();
893
894 // Fast path: the input size is not equal to the width of the source.
895 if (size != source.getType().getIntOrFloatBitWidth())
896 return Value();
897
898 // Tracks the bits that were encountered.
899 llvm::BitVector bits(size);
900 bits.set(sourceOp.getLowBit());
901
902 for (size_t i = 1; i != size; ++i) {
903 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
904 if (!extractOp || extractOp.getOperand() != source)
905 return Value();
906 bits.set(extractOp.getLowBit());
907 }
908
909 return bits.all() ? source : Value();
910}
911
912/// Canonicalize an idempotent operation `op` so that only one input of any kind
913/// occurs.
914///
915/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
916template <typename Op>
917static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
918 // Depth limit to search, in operations. Chosen arbitrarily, keep small.
919 constexpr unsigned limit = 3;
920 auto inputs = op.getInputs();
921
922 llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
923 llvm::SmallDenseSet<Op, 8> checked;
924 checked.insert(op);
925
926 struct OpWithDepth {
927 Op op;
928 unsigned depth;
929 };
930 llvm::SmallVector<OpWithDepth, 8> worklist;
931
932 auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
933 // Add to worklist if within depth limit, is defined in the same block by
934 // the same kind of operation, has same two-state-ness, and not enqueued
935 // previously.
936 if (depth < limit && input.getParentBlock() == op->getBlock()) {
937 auto inputOp = input.template getDefiningOp<Op>();
938 if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
939 checked.insert(inputOp).second)
940 worklist.push_back({inputOp, depth + 1});
941 }
942 };
943
944 for (auto input : uniqueInputs)
945 enqueue(input, 0);
946
947 while (!worklist.empty()) {
948 auto item = worklist.pop_back_val();
949
950 for (auto input : item.op.getInputs()) {
951 uniqueInputs.remove(input);
952 enqueue(input, item.depth);
953 }
954 }
955
956 if (uniqueInputs.size() < inputs.size()) {
957 replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
958 uniqueInputs.getArrayRef(),
959 op.getTwoState());
960 return true;
961 }
962
963 return false;
964}
965
966LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
967 auto inputs = op.getInputs();
968 auto size = inputs.size();
969
970 // and(x, and(...)) -> and(x, ...) -- flatten
971 if (tryFlatteningOperands(op, rewriter))
972 return success();
973
974 // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
975 // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
976 // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
977 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
978 return success();
979
981 return failure();
982 assert(size > 1 && "expected 2 or more operands, `fold` should handle this");
983
984 // Patterns for and with a constant on RHS.
985 APInt value;
986 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
987 // and(..., '1) -> and(...) -- identity
988 if (value.isAllOnes()) {
989 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
990 inputs.drop_back(), false);
991 return success();
992 }
993
994 // TODO: Combine multiple constants together even if they aren't at the
995 // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
996 // folding
997 APInt value2;
998 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
999 auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
1000 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1001 newOperands.push_back(cst);
1002 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
1003 newOperands, false);
1004 return success();
1005 }
1006
1007 // Handle 'and' with a single bit constant on the RHS.
1008 if (size == 2 && value.isPowerOf2()) {
1009 // If the LHS is a replicate from a single bit, we can 'concat' it
1010 // into place. e.g.:
1011 // `replicate(x) & 4` -> `concat(zeros, x, zeros)`
1012 // TODO: Generalize this for non-single-bit operands.
1013 if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
1014 auto replicateOperand = replicate.getOperand();
1015 if (replicateOperand.getType().isInteger(1)) {
1016 unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
1017 auto trailingZeros = value.countTrailingZeros();
1018
1019 // Don't add zero bit constants unnecessarily.
1020 SmallVector<Value, 3> concatOperands;
1021 if (trailingZeros != resultWidth - 1) {
1022 auto highZeros = rewriter.create<hw::ConstantOp>(
1023 op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
1024 concatOperands.push_back(highZeros);
1025 }
1026 concatOperands.push_back(replicateOperand);
1027 if (trailingZeros != 0) {
1028 auto lowZeros = rewriter.create<hw::ConstantOp>(
1029 op.getLoc(), APInt::getZero(trailingZeros));
1030 concatOperands.push_back(lowZeros);
1031 }
1032 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
1033 rewriter, op, op.getType(), concatOperands);
1034 return success();
1035 }
1036 }
1037 }
1038
1039 // If this is an and from an extract op, try shrinking the extract.
1040 if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
1041 if (size == 2 &&
1042 // We can shrink it if the mask has leading or trailing zeros.
1043 (value.countLeadingZeros() || value.countTrailingZeros())) {
1044 unsigned lz = value.countLeadingZeros();
1045 unsigned tz = value.countTrailingZeros();
1046
1047 // Start by extracting the smaller number of bits.
1048 auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
1049 Value smallElt = rewriter.createOrFold<ExtractOp>(
1050 extractOp.getLoc(), smallTy, extractOp->getOperand(0),
1051 extractOp.getLowBit() + tz);
1052 // Apply the 'and' mask if needed.
1053 APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
1054 if (!smallMask.isAllOnes()) {
1055 auto loc = inputs.back().getLoc();
1056 smallElt = rewriter.createOrFold<AndOp>(
1057 loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
1058 false);
1059 }
1060
1061 // The final replacement will be a concat of the leading/trailing zeros
1062 // along with the smaller extracted value.
1063 SmallVector<Value> resultElts;
1064 if (lz)
1065 resultElts.push_back(
1066 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
1067 resultElts.push_back(smallElt);
1068 if (tz)
1069 resultElts.push_back(
1070 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
1071 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
1072 return success();
1073 }
1074 }
1075
1076 // and(concat(x, cst1), a, b, c, cst2)
1077 // ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
1078 // We do this for even more multi-use concats since they are "just wiring".
1079 for (size_t i = 0; i < size - 1; ++i) {
1080 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1081 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1082 return success();
1083 }
1084 }
1085
1086 // extracts only of and(...) -> and(extract()...)
1087 if (narrowOperationWidth(op, true, rewriter))
1088 return success();
1089
1090 // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
1091 if (auto source = getCommonOperand(op)) {
1092 auto cmpAgainst =
1093 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
1094 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
1095 source, cmpAgainst);
1096 return success();
1097 }
1098
1099 /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
1100 return failure();
1101}
1102
1103OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1104 if (hasOperandsOutsideOfBlock(getOperation()))
1105 return {};
1106
1107 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1108 auto inputs = adaptor.getInputs();
1109 // or(x, 10, 01) -> 11
1110 for (auto operand : inputs) {
1111 if (!operand)
1112 continue;
1113 value |= cast<IntegerAttr>(operand).getValue();
1114 if (value.isAllOnes())
1115 return getIntAttr(value, getContext());
1116 }
1117
1118 // or(x, 0) -> x
1119 if (inputs.size() == 2 && inputs[1] &&
1120 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1121 return getInputs()[0];
1122
1123 // or(x, x, x) -> x. This also handles or(x) -> x
1124 if (llvm::all_of(getInputs(),
1125 [&](auto in) { return in == this->getInputs()[0]; }))
1126 return getInputs()[0];
1127
1128 // or(..., x, ..., ~x, ...) -> -1
1129 for (Value arg : getInputs()) {
1130 Value subExpr;
1131 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1132 for (Value arg2 : getInputs())
1133 if (arg2 == subExpr)
1134 return getIntAttr(
1135 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1136 getContext());
1137 }
1138 }
1139
1140 // x0 = icmp(pred, x, y)
1141 // x1 = icmp(!pred, x, y)
1142 // or(x0, x1) -> 1
1143 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1144 return getIntAttr(
1145 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1146 getContext());
1147
1148 // Constant fold
1149 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1150}
1151
1152LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
1153 auto inputs = op.getInputs();
1154 auto size = inputs.size();
1155
1156 // or(x, or(...)) -> or(x, ...) -- flatten
1157 if (tryFlatteningOperands(op, rewriter))
1158 return success();
1159
1160 // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
1161 // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
1162 // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
1163 if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
1164 return success();
1165
1166 if (hasOperandsOutsideOfBlock(&*op))
1167 return failure();
1168 assert(size > 1 && "expected 2 or more operands");
1169
1170 // Patterns for and with a constant on RHS.
1171 APInt value;
1172 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1173 // or(..., '0) -> or(...) -- identity
1174 if (value.isZero()) {
1175 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1176 inputs.drop_back());
1177 return success();
1178 }
1179
1180 // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
1181 APInt value2;
1182 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1183 auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
1184 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1185 newOperands.push_back(cst);
1186 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
1187 newOperands);
1188 return success();
1189 }
1190
1191 // or(concat(x, cst1), a, b, c, cst2)
1192 // ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
1193 // We do this for even more multi-use concats since they are "just wiring".
1194 for (size_t i = 0; i < size - 1; ++i) {
1195 if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
1196 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1197 return success();
1198 }
1199 }
1200
1201 // extracts only of or(...) -> or(extract()...)
1202 if (narrowOperationWidth(op, true, rewriter))
1203 return success();
1204
1205 // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
1206 if (auto source = getCommonOperand(op)) {
1207 auto cmpAgainst =
1208 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
1209 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
1210 source, cmpAgainst);
1211 return success();
1212 }
1213
1214 // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
1215 // .., c_n), a, 0)
1216 if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
1217 APInt value;
1218 if (op.getTwoState() && firstMux.getTwoState() &&
1219 matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
1220 value.isZero()) {
1221 SmallVector<Value> conditions{firstMux.getCond()};
1222 auto check = [&](Value v) {
1223 auto mux = v.getDefiningOp<comb::MuxOp>();
1224 if (!mux)
1225 return false;
1226 conditions.push_back(mux.getCond());
1227 return mux.getTwoState() &&
1228 firstMux.getTrueValue() == mux.getTrueValue() &&
1229 firstMux.getFalseValue() == mux.getFalseValue();
1230 };
1231 if (llvm::all_of(op.getOperands().drop_front(), check)) {
1232 auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
1233 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1234 rewriter, op, cond, firstMux.getTrueValue(),
1235 firstMux.getFalseValue(), true);
1236 return success();
1237 }
1238 }
1239 }
1240
1241 /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
1242 return failure();
1243}
1244
1245OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1246 if (hasOperandsOutsideOfBlock(getOperation()))
1247 return {};
1248
1249 auto size = getInputs().size();
1250 auto inputs = adaptor.getInputs();
1251
1252 // xor(x) -> x -- noop
1253 if (size == 1)
1254 return getInputs()[0];
1255
1256 // xor(x, x) -> 0 -- idempotent
1257 if (size == 2 && getInputs()[0] == getInputs()[1])
1258 return IntegerAttr::get(getType(), 0);
1259
1260 // xor(x, 0) -> x
1261 if (inputs.size() == 2 && inputs[1] &&
1262 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1263 return getInputs()[0];
1264
1265 // xor(xor(x,1),1) -> x
1266 // but not self loop
1267 if (isBinaryNot()) {
1268 Value subExpr;
1269 if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1270 subExpr != getResult())
1271 return subExpr;
1272 }
1273
1274 // Constant fold
1275 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1276}
1277
1278// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1279static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1280 PatternRewriter &rewriter) {
1281 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1282 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1283
1284 Value result =
1285 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1286 icmp.getOperand(1), icmp.getTwoState());
1287
1288 // If the xor had other operands, rebuild it.
1289 if (op.getNumOperands() > 2) {
1290 SmallVector<Value, 4> newOperands(op.getOperands());
1291 newOperands.pop_back();
1292 newOperands.erase(newOperands.begin() + icmpOperand);
1293 newOperands.push_back(result);
1294 result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
1295 }
1296
1297 replaceOpAndCopyNamehint(rewriter, op, result);
1298}
1299
1300LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
1301 if (hasOperandsOutsideOfBlock(&*op))
1302 return failure();
1303
1304 auto inputs = op.getInputs();
1305 auto size = inputs.size();
1306 assert(size > 1 && "expected 2 or more operands");
1307
1308 // xor(..., x, x) -> xor (...) -- idempotent
1309 if (inputs[size - 1] == inputs[size - 2]) {
1310 assert(size > 2 &&
1311 "expected idempotent case for 2 elements handled already.");
1312 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1313 inputs.drop_back(/*n=*/2), false);
1314 return success();
1315 }
1316
1317 // Patterns for xor with a constant on RHS.
1318 APInt value;
1319 if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
1320 // xor(..., 0) -> xor(...) -- identity
1321 if (value.isZero()) {
1322 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1323 inputs.drop_back(), false);
1324 return success();
1325 }
1326
1327 // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
1328 APInt value2;
1329 if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1330 auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
1331 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1332 newOperands.push_back(cst);
1333 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
1334 newOperands, false);
1335 return success();
1336 }
1337
1338 bool isSingleBit = value.getBitWidth() == 1;
1339
1340 // Check for subexpressions that we can simplify.
1341 for (size_t i = 0; i < size - 1; ++i) {
1342 Value operand = inputs[i];
1343
1344 // xor(concat(x, cst1), a, b, c, cst2)
1345 // ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
1346 // We do this for even more multi-use concats since they are "just
1347 // wiring".
1348 if (auto concat = operand.getDefiningOp<ConcatOp>())
1349 if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
1350 return success();
1351
1352 // xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1353 if (isSingleBit && operand.hasOneUse()) {
1354 assert(value == 1 && "single bit constant has to be one if not zero");
1355 if (auto icmp = operand.getDefiningOp<ICmpOp>())
1356 return canonicalizeXorIcmpTrue(op, i, rewriter), success();
1357 }
1358 }
1359 }
1360
1361 // xor(x, xor(...)) -> xor(x, ...) -- flatten
1362 if (tryFlatteningOperands(op, rewriter))
1363 return success();
1364
1365 // extracts only of xor(...) -> xor(extract()...)
1366 if (narrowOperationWidth(op, true, rewriter))
1367 return success();
1368
1369 // xor(a[0], a[1], ..., a[n]) -> parity(a)
1370 if (auto source = getCommonOperand(op)) {
1371 replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
1372 return success();
1373 }
1374
1375 return failure();
1376}
1377
1378OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1379 if (hasOperandsOutsideOfBlock(getOperation()))
1380 return {};
1381
1382 // sub(x - x) -> 0
1383 if (getRhs() == getLhs())
1384 return getIntAttr(
1385 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1386 getContext());
1387
1388 if (adaptor.getRhs()) {
1389 // If both are constants, we can unconditionally fold.
1390 if (adaptor.getLhs()) {
1391 // Constant fold (c1 - c2) => (c1 + -1*c2).
1392 auto negOne = getIntAttr(
1393 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1394 getContext());
1395 auto rhsNeg = hw::ParamExprAttr::get(
1396 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1397 return hw::ParamExprAttr::get(hw::PEO::Add,
1398 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1399 }
1400
1401 // sub(x - 0) -> x
1402 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1403 if (rhsC.getValue().isZero())
1404 return getLhs();
1405 }
1406 }
1407
1408 return {};
1409}
1410
1411LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1412 if (hasOperandsOutsideOfBlock(&*op))
1413 return failure();
1414
1415 // sub(x, cst) -> add(x, -cst)
1416 APInt value;
1417 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1418 auto negCst = rewriter.create<hw::ConstantOp>(op.getLoc(), -value);
1419 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1420 false);
1421 return success();
1422 }
1423
1424 // extracts only of sub(...) -> sub(extract()...)
1425 if (narrowOperationWidth(op, false, rewriter))
1426 return success();
1427
1428 return failure();
1429}
1430
1431OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1432 if (hasOperandsOutsideOfBlock(getOperation()))
1433 return {};
1434
1435 auto size = getInputs().size();
1436
1437 // add(x) -> x -- noop
1438 if (size == 1u)
1439 return getInputs()[0];
1440
1441 // Constant fold constant operands.
1442 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1443}
1444
1445LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
1446 if (hasOperandsOutsideOfBlock(&*op))
1447 return failure();
1448
1449 auto inputs = op.getInputs();
1450 auto size = inputs.size();
1451 assert(size > 1 && "expected 2 or more operands");
1452
1453 APInt value, value2;
1454
1455 // add(..., 0) -> add(...) -- identity
1456 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
1457 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1458 inputs.drop_back(), false);
1459 return success();
1460 }
1461
1462 // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
1463 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1464 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1465 auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
1466 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1467 newOperands.push_back(cst);
1468 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1469 newOperands, false);
1470 return success();
1471 }
1472
1473 // add(..., x, x) -> add(..., shl(x, 1))
1474 if (inputs[size - 1] == inputs[size - 2]) {
1475 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1476
1477 auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
1478 auto shiftLeftOp =
1479 rewriter.create<comb::ShlOp>(op.getLoc(), inputs.back(), one, false);
1480
1481 newOperands.push_back(shiftLeftOp);
1482 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1483 newOperands, false);
1484 return success();
1485 }
1486
1487 auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
1488 // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
1489 if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
1490 matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {
1491
1492 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1493 auto rhs =
1494 rewriter.create<hw::ConstantOp>(op.getLoc(), (one << value) + one);
1495
1496 std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
1497 auto mulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
1498
1499 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1500 newOperands.push_back(mulOp);
1501 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1502 newOperands, false);
1503 return success();
1504 }
1505
1506 auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
1507 // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
1508 if (mulOp && mulOp.getInputs().size() == 2 &&
1509 mulOp.getInputs()[0] == inputs[size - 2] &&
1510 matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {
1511
1512 APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
1513 auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + one);
1514 std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
1515 auto newMulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);
1516
1517 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1518 newOperands.push_back(newMulOp);
1519 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
1520 newOperands, false);
1521 return success();
1522 }
1523
1524 // add(a, add(...)) -> add(a, ...) -- flatten
1525 if (tryFlatteningOperands(op, rewriter))
1526 return success();
1527
1528 // extracts only of add(...) -> add(extract()...)
1529 if (narrowOperationWidth(op, false, rewriter))
1530 return success();
1531
1532 // add(add(x, c1), c2) -> add(x, c1 + c2)
1533 auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
1534 if (addOp && addOp.getInputs().size() == 2 &&
1535 matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
1536 inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {
1537
1538 auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
1539 replaceOpWithNewOpAndCopyNamehint<AddOp>(
1540 rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
1541 /*twoState=*/op.getTwoState() && addOp.getTwoState());
1542 return success();
1543 }
1544
1545 return failure();
1546}
1547
1548OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1549 if (hasOperandsOutsideOfBlock(getOperation()))
1550 return {};
1551
1552 auto size = getInputs().size();
1553 auto inputs = adaptor.getInputs();
1554
1555 // mul(x) -> x -- noop
1556 if (size == 1u)
1557 return getInputs()[0];
1558
1559 auto width = cast<IntegerType>(getType()).getWidth();
1560 if (width == 0)
1561 return getIntAttr(APInt::getZero(0), getContext());
1562
1563 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1564
1565 // mul(x, 0, 1) -> 0 -- annulment
1566 for (auto operand : inputs) {
1567 if (!operand)
1568 continue;
1569 value *= cast<IntegerAttr>(operand).getValue();
1570 if (value.isZero())
1571 return getIntAttr(value, getContext());
1572 }
1573
1574 // Constant fold
1575 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1576}
1577
1578LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
1579 if (hasOperandsOutsideOfBlock(&*op))
1580 return failure();
1581
1582 auto inputs = op.getInputs();
1583 auto size = inputs.size();
1584 assert(size > 1 && "expected 2 or more operands");
1585
1586 APInt value, value2;
1587
1588 // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
1589 if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
1590 value.isPowerOf2()) {
1591 auto shift = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(),
1592 value.exactLogBase2());
1593 auto shlOp =
1594 rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift, false);
1595
1596 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1597 ArrayRef<Value>(shlOp), false);
1598 return success();
1599 }
1600
1601 // mul(..., 1) -> mul(...) -- identity
1602 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
1603 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1604 inputs.drop_back());
1605 return success();
1606 }
1607
1608 // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
1609 if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
1610 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
1611 auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value * value2);
1612 SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
1613 newOperands.push_back(cst);
1614 replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
1615 newOperands);
1616 return success();
1617 }
1618
1619 // mul(a, mul(...)) -> mul(a, ...) -- flatten
1620 if (tryFlatteningOperands(op, rewriter))
1621 return success();
1622
1623 // extracts only of mul(...) -> mul(extract()...)
1624 if (narrowOperationWidth(op, false, rewriter))
1625 return success();
1626
1627 return failure();
1628}
1629
1630template <class Op, bool isSigned>
1631static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1632 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1633 // divu(x, 1) -> x, divs(x, 1) -> x
1634 if (rhsValue.getValue() == 1)
1635 return op.getLhs();
1636
1637 // If the divisor is zero, do not fold for now.
1638 if (rhsValue.getValue().isZero())
1639 return {};
1640 }
1641
1642 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1643}
1644
1645OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1646 if (hasOperandsOutsideOfBlock(getOperation()))
1647 return {};
1648
1649 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1650}
1651
1652OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1653 if (hasOperandsOutsideOfBlock(getOperation()))
1654 return {};
1655
1656 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1657}
1658
1659template <class Op, bool isSigned>
1660static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1661 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1662 // modu(x, 1) -> 0, mods(x, 1) -> 0
1663 if (rhsValue.getValue() == 1)
1664 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1665 op.getContext());
1666
1667 // If the divisor is zero, do not fold for now.
1668 if (rhsValue.getValue().isZero())
1669 return {};
1670 }
1671
1672 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1673 // modu(0, x) -> 0, mods(0, x) -> 0
1674 if (lhsValue.getValue().isZero())
1675 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1676 op.getContext());
1677 }
1678
1679 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1680}
1681
1682OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1683 if (hasOperandsOutsideOfBlock(getOperation()))
1684 return {};
1685
1686 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1687}
1688
1689OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1690 if (hasOperandsOutsideOfBlock(getOperation()))
1691 return {};
1692
1693 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1694}
1695//===----------------------------------------------------------------------===//
1696// ConcatOp
1697//===----------------------------------------------------------------------===//
1698
1699// Constant folding
1700OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1701 if (hasOperandsOutsideOfBlock(getOperation()))
1702 return {};
1703
1704 if (getNumOperands() == 1)
1705 return getOperand(0);
1706
1707 // If all the operands are constant, we can fold.
1708 for (auto attr : adaptor.getInputs())
1709 if (!attr || !isa<IntegerAttr>(attr))
1710 return {};
1711
1712 // If we got here, we can constant fold.
1713 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1714 APInt result(resultWidth, 0);
1715
1716 unsigned nextInsertion = resultWidth;
1717 // Insert each chunk into the result.
1718 for (auto attr : adaptor.getInputs()) {
1719 auto chunk = cast<IntegerAttr>(attr).getValue();
1720 nextInsertion -= chunk.getBitWidth();
1721 result.insertBits(chunk, nextInsertion);
1722 }
1723
1724 return getIntAttr(result, getContext());
1725}
1726
1727LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
1728 if (hasOperandsOutsideOfBlock(&*op))
1729 return failure();
1730
1731 auto inputs = op.getInputs();
1732 auto size = inputs.size();
1733 assert(size > 1 && "expected 2 or more operands");
1734
1735 // This function is used when we flatten neighboring operands of a
1736 // (variadic) concat into a new vesion of the concat. first/last indices
1737 // are inclusive.
1738 auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
1739 ValueRange replacements) -> LogicalResult {
1740 SmallVector<Value, 4> newOperands;
1741 newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
1742 newOperands.append(replacements.begin(), replacements.end());
1743 newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
1744 if (newOperands.size() == 1)
1745 replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
1746 else
1747 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
1748 newOperands);
1749 return success();
1750 };
1751
1752 Value commonOperand = inputs[0];
1753 for (size_t i = 0; i != size; ++i) {
1754 // Check to see if all operands are the same.
1755 if (inputs[i] != commonOperand)
1756 commonOperand = Value();
1757
1758 // If an operand to the concat is itself a concat, then we can fold them
1759 // together.
1760 if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
1761 return flattenConcat(i, i, subConcat->getOperands());
1762
1763 // Check for canonicalization due to neighboring operands.
1764 if (i != 0) {
1765 // Merge neighboring constants.
1766 if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
1767 if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
1768 unsigned prevWidth = prevCst.getValue().getBitWidth();
1769 unsigned thisWidth = cst.getValue().getBitWidth();
1770 auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
1771 resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
1772 << thisWidth;
1773 Value replacement =
1774 rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
1775 return flattenConcat(i - 1, i, replacement);
1776 }
1777 }
1778
1779 // If the two operands are the same, turn them into a replicate.
1780 if (inputs[i] == inputs[i - 1]) {
1781 Value replacement =
1782 rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
1783 return flattenConcat(i - 1, i, replacement);
1784 }
1785
1786 // If this input is a replicate, see if we can fold it with the previous
1787 // one.
1788 if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
1789 // ... x, repl(x, n), ... ==> ..., repl(x, n+1), ...
1790 if (repl.getOperand() == inputs[i - 1]) {
1791 Value replacement = rewriter.createOrFold<ReplicateOp>(
1792 op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
1793 return flattenConcat(i - 1, i, replacement);
1794 }
1795 // ... repl(x, n), repl(x, m), ... ==> ..., repl(x, n+m), ...
1796 if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1797 if (prevRepl.getOperand() == repl.getOperand()) {
1798 Value replacement = rewriter.createOrFold<ReplicateOp>(
1799 op.getLoc(), repl.getOperand(),
1800 repl.getMultiple() + prevRepl.getMultiple());
1801 return flattenConcat(i - 1, i, replacement);
1802 }
1803 }
1804 }
1805
1806 // ... repl(x, n), x, ... ==> ..., repl(x, n+1), ...
1807 if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
1808 if (repl.getOperand() == inputs[i]) {
1809 Value replacement = rewriter.createOrFold<ReplicateOp>(
1810 op.getLoc(), inputs[i], repl.getMultiple() + 1);
1811 return flattenConcat(i - 1, i, replacement);
1812 }
1813 }
1814
1815 // Merge neighboring extracts of neighboring inputs, e.g.
1816 // {A[3], A[2]} -> A[3:2]
1817 if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
1818 if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
1819 if (extract.getInput() == prevExtract.getInput()) {
1820 auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
1821 if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
1822 auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
1823 auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
1824 Value replacement = rewriter.create<ExtractOp>(
1825 op.getLoc(), resType, extract.getInput(),
1826 extract.getLowBit());
1827 return flattenConcat(i - 1, i, replacement);
1828 }
1829 }
1830 }
1831 }
1832 // Merge neighboring array extracts of neighboring inputs, e.g.
1833 // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])
1834
1835 // This represents a slice of an array.
1836 struct ArraySlice {
1837 Value input;
1838 Value index;
1839 size_t width;
1840 static std::optional<ArraySlice> get(Value value) {
1841 assert(isa<IntegerType>(value.getType()) && "expected integer type");
1842 if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
1843 return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
1844 // array slice op is wrapped with bitcast.
1845 if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
1846 if (auto arraySlice =
1847 bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
1848 return ArraySlice{
1849 arraySlice.getInput(), arraySlice.getLowIndex(),
1850 hw::type_cast<hw::ArrayType>(arraySlice.getType())
1851 .getNumElements()};
1852 return std::nullopt;
1853 }
1854 };
1855 if (auto extractOpt = ArraySlice::get(inputs[i])) {
1856 if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
1857 // Check that two array slices are mergable.
1858 if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
1859 prevExtractOpt->input == extractOpt->input &&
1860 hw::isOffset(extractOpt->index, prevExtractOpt->index,
1861 extractOpt->width)) {
1862 auto resType = hw::ArrayType::get(
1863 hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
1864 .getElementType(),
1865 extractOpt->width + prevExtractOpt->width);
1866 auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
1867 Value replacement = rewriter.create<hw::BitcastOp>(
1868 op.getLoc(), resIntType,
1869 rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
1870 prevExtractOpt->input,
1871 extractOpt->index));
1872 return flattenConcat(i - 1, i, replacement);
1873 }
1874 }
1875 }
1876 }
1877 }
1878
1879 // If all operands were the same, then this is a replicate.
1880 if (commonOperand) {
1881 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
1882 commonOperand);
1883 return success();
1884 }
1885
1886 return failure();
1887}
1888
1889//===----------------------------------------------------------------------===//
1890// MuxOp
1891//===----------------------------------------------------------------------===//
1892
1893OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1894 if (hasOperandsOutsideOfBlock(getOperation()))
1895 return {};
1896
1897 // mux (c, b, b) -> b
1898 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1899 return getTrueValue();
1900 if (auto tv = adaptor.getTrueValue())
1901 if (tv == adaptor.getFalseValue())
1902 return tv;
1903
1904 // mux(0, a, b) -> b
1905 // mux(1, a, b) -> a
1906 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1907 if (pred.getValue().isZero())
1908 return getFalseValue();
1909 return getTrueValue();
1910 }
1911
1912 // mux(cond, 1, 0) -> cond
1913 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1914 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1915 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1916 hw::getBitWidth(getType()) == 1)
1917 return getCond();
1918
1919 return {};
1920}
1921
1922/// Check to see if the condition to the specified mux is an equality
1923/// comparison `indexValue` and one or more constants. If so, put the
1924/// constants in the constants vector and return true, otherwise return false.
1925///
1926/// This is part of foldMuxChain.
1927///
1928static bool
1929getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1930 std::function<void(hw::ConstantOp)> constantFn) {
1931 // Handle `idx == 42` and `idx != 42`.
1932 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1933 // TODO: We could handle things like "x < 2" as two entries.
1934 auto requiredPredicate =
1935 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1936 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1937 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1938 constantFn(cst);
1939 return true;
1940 }
1941 }
1942 return false;
1943 }
1944
1945 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1946 if (auto orOp = cond.getDefiningOp<OrOp>()) {
1947 if (!isInverted)
1948 return false;
1949 for (auto operand : orOp.getOperands())
1950 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1951 return false;
1952 return true;
1953 }
1954
1955 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1956 if (auto andOp = cond.getDefiningOp<AndOp>()) {
1957 if (isInverted)
1958 return false;
1959 for (auto operand : andOp.getOperands())
1960 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1961 return false;
1962 return true;
1963 }
1964
1965 return false;
1966}
1967
1968/// Given a mux, check to see if the "on true" value (or "on false" value if
1969/// isFalseSide=true) is a mux tree with the same condition. This allows us
1970/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
1971/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
1972/// allows synthesis tools to do more interesting optimizations.
1973///
1974/// This returns false if we cannot form the mux tree (or do not want to) and
1975/// returns true if the mux was replaced.
1976static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
1977 PatternRewriter &rewriter) {
1978 // Get the index value being compared. Later we check to see if it is
1979 // compared to a constant with the right predicate.
1980 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1981 if (!rootCmp)
1982 return false;
1983 Value indexValue = rootCmp.getLhs();
1984
1985 // Return the value to use if the equality match succeeds.
1986 auto getCaseValue = [&](MuxOp mux) -> Value {
1987 return mux.getOperand(1 + unsigned(!isFalseSide));
1988 };
1989
1990 // Return the value to use if the equality match fails. This is the next
1991 // mux in the sequence or the "otherwise" value.
1992 auto getTreeValue = [&](MuxOp mux) -> Value {
1993 return mux.getOperand(1 + unsigned(isFalseSide));
1994 };
1995
1996 // Start scanning the mux tree to see what we've got. Keep track of the
1997 // constant comparison value and the SSA value to use when equal to it.
1998 SmallVector<Location> locationsFound;
1999 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
2000
2001 /// Extract constants and values into `valuesFound` and return true if this is
2002 /// part of the mux tree, otherwise return false.
2003 auto collectConstantValues = [&](MuxOp mux) -> bool {
2005 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
2006 valuesFound.push_back({cst, getCaseValue(mux)});
2007 locationsFound.push_back(mux.getCond().getLoc());
2008 locationsFound.push_back(mux->getLoc());
2009 });
2010 };
2011
2012 // Make sure the root is a correct comparison with a constant.
2013 if (!collectConstantValues(rootMux))
2014 return false;
2015
2016 // Make sure that we're not looking at the intermediate node in a mux tree.
2017 if (rootMux->hasOneUse()) {
2018 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
2019 if (getTreeValue(userMux) == rootMux.getResult() &&
2020 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
2021 [&](hw::ConstantOp cst) {}))
2022 return false;
2023 }
2024 }
2025
2026 // Scan up the tree linearly.
2027 auto nextTreeValue = getTreeValue(rootMux);
2028 while (1) {
2029 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
2030 if (!nextMux || !nextMux->hasOneUse())
2031 break;
2032 if (!collectConstantValues(nextMux))
2033 break;
2034 nextTreeValue = getTreeValue(nextMux);
2035 }
2036
2037 // We need to have more than three values to create an array. This is an
2038 // arbitrary threshold which is saying that one or two muxes together is ok,
2039 // but three should be folded.
2040 if (valuesFound.size() < 3)
2041 return false;
2042
2043 // If the array is greater that 9 bits, it will take over 512 elements and
2044 // it will be too large for a single expression.
2045 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
2046 if (indexWidth >= 9)
2047 return false;
2048
2049 // Next we need to see if the values are dense-ish. We don't want to have
2050 // a tremendous number of replicated entries in the array. Some sparsity is
2051 // ok though, so we require the table to be at least 5/8 utilized.
2052 uint64_t tableSize = 1ULL << indexWidth;
2053 if (valuesFound.size() < (tableSize * 5) / 8)
2054 return false; // Not dense enough.
2055
2056 // Ok, we're going to do the transformation, start by building the table
2057 // filled with the "otherwise" value.
2058 SmallVector<Value, 8> table(tableSize, nextTreeValue);
2059
2060 // Fill in entries in the table from the leaf to the root of the expression.
2061 // This ensures that any duplicate matches end up with the ultimate value,
2062 // which is the one closer to the root.
2063 for (auto &elt : llvm::reverse(valuesFound)) {
2064 uint64_t idx = elt.first.getValue().getZExtValue();
2065 assert(idx < table.size() && "constant should be same bitwidth as index");
2066 table[idx] = elt.second;
2067 }
2068
2069 // The hw.array_create operation has the operand list in unintuitive order
2070 // with a[0] stored as the last element, not the first.
2071 std::reverse(table.begin(), table.end());
2072
2073 // Build the array_create and the array_get.
2074 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
2075 auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
2076 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
2077 indexValue);
2078 return true;
2079}
2080
2081/// Given a fully associative variadic operation like (a+b+c+d), break the
2082/// expression into two parts, one without the specified operand (e.g.
2083/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
2084/// `tmp+c`), and return the inner expression.
2085///
2086/// NOTE: This mutates the operation in place if it only has a single user,
2087/// which assumes that user will be removed.
2088///
2089static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
2090 size_t operandNo,
2091 PatternRewriter &rewriter) {
2092 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
2093 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
2094
2095 // If this expression already has two operands (the common case) no splitting
2096 // is necessary.
2097 if (fullyAssoc->getNumOperands() == 2)
2098 return fullyAssoc->getOperand(operandNo ^ 1);
2099
2100 // If the operation has a single use, mutate it in place.
2101 if (fullyAssoc->hasOneUse()) {
2102 rewriter.modifyOpInPlace(fullyAssoc,
2103 [&]() { fullyAssoc->eraseOperand(operandNo); });
2104 return fullyAssoc->getResult(0);
2105 }
2106
2107 // Form the new operation with the operands that remain.
2108 SmallVector<Value> operands;
2109 operands.append(fullyAssoc->getOperands().begin(),
2110 fullyAssoc->getOperands().begin() + operandNo);
2111 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2112 fullyAssoc->getOperands().end());
2113 Value opWithoutExcluded = createGenericOp(
2114 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2115 Value excluded = fullyAssoc->getOperand(operandNo);
2116
2117 Value fullResult =
2118 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2119 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2120 replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
2121 return opWithoutExcluded;
2122}
2123
2124/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2125/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2126/// is true. Return true on successful transformation, false if not.
2127///
2128/// These are various forms of "predicated ops" that can be handled with a
2129/// replicate/and combination.
2130static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2131 PatternRewriter &rewriter) {
2132 // Check to see the operand in question is an operation. If it is a port,
2133 // we can't simplify it.
2134 Operation *subExpr =
2135 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2136 if (!subExpr || subExpr->getNumOperands() < 2)
2137 return false;
2138
2139 // If this isn't an operation we can handle, don't spend energy on it.
2140 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2141 return false;
2142
2143 // Check to see if the common value occurs in the operand list for the
2144 // subexpression op. If so, then we can simplify it.
2145 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2146 size_t opNo = 0, e = subExpr->getNumOperands();
2147 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2148 ++opNo;
2149 if (opNo == e)
2150 return false;
2151
2152 // If we got a hit, then go ahead and simplify it!
2153 Value cond = op.getCond();
2154
2155 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2156 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2157 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2158 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2159 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2160 if (subMux == op)
2161 return false;
2162
2163 Value otherValue;
2164 Value subCond = subMux.getCond();
2165
2166 // Invert th subCond if needed and dig out the 'b' value.
2167 if (subMux.getTrueValue() == commonValue)
2168 otherValue = subMux.getFalseValue();
2169 else if (subMux.getFalseValue() == commonValue) {
2170 otherValue = subMux.getTrueValue();
2171 subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2172 } else {
2173 // We can't fold `mux(cond, a, mux(a, x, y))`.
2174 return false;
2175 }
2176
2177 // Invert the outer cond if needed, and combine the mux conditions.
2178 if (!isTrueOperand)
2179 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2180 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2181 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2182 otherValue, op.getTwoState());
2183 return true;
2184 }
2185
2186 // Invert the condition if needed. Or/Xor invert when dealing with
2187 // TrueOperand, And inverts for False operand.
2188 bool isaAndOp = isa<AndOp>(subExpr);
2189 if (isTrueOperand ^ isaAndOp)
2190 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2191
2192 auto extendedCond =
2193 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2194
2195 // Cache this information before subExpr is erased by extraction below.
2196 bool isaXorOp = isa<XorOp>(subExpr);
2197 bool isaOrOp = isa<OrOp>(subExpr);
2198
2199 // Handle the fully associative ops, start by pulling out the subexpression
2200 // from a many operand version of the op.
2201 auto restOfAssoc =
2202 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2203
2204 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2205 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2206 if (isaOrOp || isaXorOp) {
2207 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2208 restOfAssoc, false);
2209 if (isaXorOp)
2210 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2211 commonValue, false);
2212 else
2213 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2214 false);
2215 return true;
2216 }
2217
2218 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2219 assert(isaAndOp && "unexpected operation here");
2220 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2221 restOfAssoc, false);
2222 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2223 false);
2224 return true;
2225}
2226
2227/// This function is invoke when we find a mux with true/false operations that
2228/// have the same opcode. Check to see if we can strength reduce the mux by
2229/// applying it to less data by applying this transformation:
2230/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2231static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2232 Operation *falseOp,
2233 PatternRewriter &rewriter) {
2234 // Right now we only apply to concat.
2235 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2236 if (!isa<ConcatOp>(trueOp))
2237 return false;
2238
2239 // Decode the operands, looking through recursive concats and replicates.
2240 SmallVector<Value> trueOperands, falseOperands;
2241 getConcatOperands(trueOp->getResult(0), trueOperands);
2242 getConcatOperands(falseOp->getResult(0), falseOperands);
2243
2244 size_t numTrueOperands = trueOperands.size();
2245 size_t numFalseOperands = falseOperands.size();
2246
2247 if (!numTrueOperands || !numFalseOperands ||
2248 (trueOperands.front() != falseOperands.front() &&
2249 trueOperands.back() != falseOperands.back()))
2250 return false;
2251
2252 // Pull all leading shared operands out into their own op if any are common.
2253 if (trueOperands.front() == falseOperands.front()) {
2254 SmallVector<Value> operands;
2255 size_t i;
2256 for (i = 0; i < numTrueOperands; ++i) {
2257 Value trueOperand = trueOperands[i];
2258 if (trueOperand == falseOperands[i])
2259 operands.push_back(trueOperand);
2260 else
2261 break;
2262 }
2263 if (i == numTrueOperands) {
2264 // Selecting between distinct, but lexically identical, concats.
2265 replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
2266 return true;
2267 }
2268
2269 Value sharedMSB;
2270 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2271 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2272 mux->getLoc(), operands.front(), operands.size());
2273 else
2274 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2275 operands.clear();
2276
2277 // Get a concat of the LSB's on each side.
2278 operands.append(trueOperands.begin() + i, trueOperands.end());
2279 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2280 operands.clear();
2281 operands.append(falseOperands.begin() + i, falseOperands.end());
2282 Value falseLSB =
2283 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2284 // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2285 // done.
2286 Value lsb = rewriter.createOrFold<MuxOp>(
2287 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2288 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2289 return true;
2290 }
2291
2292 // If trailing operands match, try to commonize them.
2293 if (trueOperands.back() == falseOperands.back()) {
2294 SmallVector<Value> operands;
2295 size_t i;
2296 for (i = 0;; ++i) {
2297 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2298 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2299 operands.push_back(trueOperand);
2300 else
2301 break;
2302 }
2303 std::reverse(operands.begin(), operands.end());
2304 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2305 operands.clear();
2306
2307 // Get a concat of the MSB's on each side.
2308 operands.append(trueOperands.begin(), trueOperands.end() - i);
2309 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2310 operands.clear();
2311 operands.append(falseOperands.begin(), falseOperands.end() - i);
2312 Value falseMSB =
2313 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2314 // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2315 Value msb = rewriter.createOrFold<MuxOp>(
2316 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2317 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2318 return true;
2319 }
2320
2321 return false;
2322}
2323
2324// If both arguments of the mux are arrays with the same elements, sink the
2325// mux and return a uniform array initializing all elements to it.
2326static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2327 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2328 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2329 if (!trueVec || !falseVec)
2330 return false;
2331 if (!trueVec.isUniform() || !falseVec.isUniform())
2332 return false;
2333
2334 auto mux = rewriter.create<MuxOp>(
2335 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2336 falseVec.getUniformElement(), op.getTwoState());
2337
2338 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2339 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2340 return true;
2341}
2342
2343namespace {
2344struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2345 using OpRewritePattern::OpRewritePattern;
2346
2347 LogicalResult matchAndRewrite(MuxOp op,
2348 PatternRewriter &rewriter) const override;
2349};
2350
2351LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2352 PatternRewriter &rewriter) const {
2353 if (hasOperandsOutsideOfBlock(&*op))
2354 return failure();
2355
2356 // If the op has a SV attribute, don't optimize it.
2357 if (hasSVAttributes(op))
2358 return failure();
2359 APInt value;
2360
2361 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2362 if (value.getBitWidth() == 1) {
2363 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2364 if (value.isZero()) {
2365 auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2366 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2367 op.getFalseValue(), false);
2368 return success();
2369 }
2370
2371 // mux(a, 1, b) -> or(a, b) for single-bit values.
2372 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2373 op.getFalseValue(), false);
2374 return success();
2375 }
2376
2377 // Check for mux of two constants. There are many ways to simplify them.
2378 APInt value2;
2379 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2380 // When both inputs are constants and differ by only one bit, we can
2381 // simplify by splitting the mux into up to three contiguous chunks: one
2382 // for the differing bit and up to two for the bits that are the same.
2383 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2384 APInt xorValue = value ^ value2;
2385 if (xorValue.isPowerOf2()) {
2386 unsigned leadingZeros = xorValue.countLeadingZeros();
2387 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2388 SmallVector<Value, 3> operands;
2389
2390 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2391 // order of bit indexes.
2392 // For the chunks that are identical (i.e. correspond to 0s in
2393 // xorValue), we can extract directly from either input value, and we
2394 // arbitrarily pick the trueValue().
2395
2396 if (leadingZeros > 0)
2397 operands.push_back(rewriter.createOrFold<ExtractOp>(
2398 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2399
2400 // Handle the differing bit, which should simplify into either cond or
2401 // ~cond.
2402 auto v1 = rewriter.createOrFold<ExtractOp>(
2403 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2404 auto v2 = rewriter.createOrFold<ExtractOp>(
2405 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2406 operands.push_back(rewriter.createOrFold<MuxOp>(
2407 op.getLoc(), op.getCond(), v1, v2, false));
2408
2409 if (trailingZeros > 0)
2410 operands.push_back(rewriter.createOrFold<ExtractOp>(
2411 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2412
2413 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2414 operands);
2415 return success();
2416 }
2417
2418 // If the true value is all ones and the false is all zeros then we have a
2419 // replicate pattern.
2420 if (value.isAllOnes() && value2.isZero()) {
2421 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2422 rewriter, op, op.getType(), op.getCond());
2423 return success();
2424 }
2425 }
2426 }
2427
2428 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2429 value.getBitWidth() == 1) {
2430 // mux(a, b, 0) -> and(a, b) for single-bit values.
2431 if (value.isZero()) {
2432 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2433 op.getTrueValue(), false);
2434 return success();
2435 }
2436
2437 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2438 // falseValue() is known to be a single-bit 1, which we can use for
2439 // the 1 in the representation of ~ using xor.
2440 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2441 op.getFalseValue(), false);
2442 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2443 op.getTrueValue(), false);
2444 return success();
2445 }
2446
2447 // mux(!a, b, c) -> mux(a, c, b)
2448 Value subExpr;
2449 Operation *condOp = op.getCond().getDefiningOp();
2450 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2451 op.getTwoState()) {
2452 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2453 subExpr, op.getFalseValue(),
2454 op.getTrueValue(), true);
2455 return success();
2456 }
2457
2458 // Same but with Demorgan's law.
2459 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2460 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2461 if (condOp && condOp->hasOneUse()) {
2462 SmallVector<Value> invertedOperands;
2463
2464 /// Scan all the operands to see if they are complemented. If so, build a
2465 /// vector of them and return true, otherwise return false.
2466 auto getInvertedOperands = [&]() -> bool {
2467 for (Value operand : condOp->getOperands()) {
2468 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2469 invertedOperands.push_back(subExpr);
2470 else
2471 return false;
2472 }
2473 return true;
2474 };
2475
2476 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2477 auto newOr =
2478 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2479 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2480 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2481 op.getTwoState());
2482 return success();
2483 }
2484 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2485 auto newAnd =
2486 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2487 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2488 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2489 op.getTwoState());
2490 return success();
2491 }
2492 }
2493
2494 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2495 falseMux && falseMux != op) {
2496 // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2497 if (op.getCond() == falseMux.getCond()) {
2498 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2499 rewriter, op, op.getCond(), op.getTrueValue(),
2500 falseMux.getFalseValue(), op.getTwoStateAttr());
2501 return success();
2502 }
2503
2504 // Check to see if we can fold a mux tree into an array_create/get pair.
2505 if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2506 return success();
2507 }
2508
2509 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2510 trueMux && trueMux != op) {
2511 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2512 if (op.getCond() == trueMux.getCond()) {
2513 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2514 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2515 op.getFalseValue(), op.getTwoStateAttr());
2516 return success();
2517 }
2518
2519 // Check to see if we can fold a mux tree into an array_create/get pair.
2520 if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2521 return success();
2522 }
2523
2524 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2525 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2526 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2527 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2528 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2529 falseMux != op) {
2530 auto subMux = rewriter.create<MuxOp>(
2531 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2532 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2533 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2534 trueMux.getTrueValue(), subMux,
2535 op.getTwoStateAttr());
2536 return success();
2537 }
2538
2539 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2540 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2541 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2542 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2543 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2544 falseMux != op) {
2545 auto subMux = rewriter.create<MuxOp>(
2546 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2547 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2548 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2549 subMux, trueMux.getFalseValue(),
2550 op.getTwoStateAttr());
2551 return success();
2552 }
2553
2554 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2555 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2556 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2557 trueMux && falseMux &&
2558 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2559 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2560 falseMux != op) {
2561 auto subMux = rewriter.create<MuxOp>(
2562 rewriter.getFusedLoc(
2563 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2564 op.getCond(), trueMux.getCond(), falseMux.getCond());
2565 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2566 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2567 op.getTwoStateAttr());
2568 return success();
2569 }
2570
2571 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2572 if (foldCommonMuxValue(op, false, rewriter))
2573 return success();
2574 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2575 if (foldCommonMuxValue(op, true, rewriter))
2576 return success();
2577
2578 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2579 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2580 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2581 if (trueOp->getName() == falseOp->getName())
2582 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2583 return success();
2584
2585 // extracts only of mux(...) -> mux(extract()...)
2586 if (narrowOperationWidth(op, true, rewriter))
2587 return success();
2588
2589 // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2590 if (foldMuxOfUniformArrays(op, rewriter))
2591 return success();
2592
2593 return failure();
2594}
2595
2596static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2597 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2598 if (op.getInputs().empty() || op.isUniform())
2599 return false;
2600 auto inputs = op.getInputs();
2601 if (inputs.size() <= 1)
2602 return false;
2603
2604 // Check the operands to the array create. Ensure all of them are the
2605 // same op with the same number of operands.
2606 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2607 if (!first || hasSVAttributes(first))
2608 return false;
2609
2610 // Check whether all operands are muxes with the same condition.
2611 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2612 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2613 if (!input || first.getCond() != input.getCond())
2614 return false;
2615 }
2616
2617 // Collect the true and the false branches into arrays.
2618 SmallVector<Value> trues{first.getTrueValue()};
2619 SmallVector<Value> falses{first.getFalseValue()};
2620 SmallVector<Location> locs{first->getLoc()};
2621 bool isTwoState = true;
2622 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2623 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2624 trues.push_back(input.getTrueValue());
2625 falses.push_back(input.getFalseValue());
2626 locs.push_back(input->getLoc());
2627 if (!input.getTwoState())
2628 isTwoState = false;
2629 }
2630
2631 // Define the location of the array create as the aggregate of all muxes.
2632 auto loc = FusedLoc::get(op.getContext(), locs);
2633
2634 // Replace the create with an aggregate operation. Push the create op
2635 // into the operands of the aggregate operation.
2636 auto arrayTy = op.getType();
2637 auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
2638 auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
2639 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2640 trueValues, falseValues, isTwoState);
2641 return true;
2642}
2643
2644struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2645 using OpRewritePattern::OpRewritePattern;
2646
2647 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2648 PatternRewriter &rewriter) const override {
2649 if (hasOperandsOutsideOfBlock(&*op))
2650 return failure();
2651
2652 if (foldArrayOfMuxes(op, rewriter))
2653 return success();
2654 return failure();
2655 }
2656};
2657
2658} // namespace
2659
2660void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2661 MLIRContext *context) {
2662 results.insert<MuxRewriter, ArrayRewriter>(context);
2663}
2664
2665//===----------------------------------------------------------------------===//
2666// ICmpOp
2667//===----------------------------------------------------------------------===//
2668
2669// Calculate the result of a comparison when the LHS and RHS are both
2670// constants.
2671static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2672 const APInt &rhs) {
2673 switch (predicate) {
2674 case ICmpPredicate::eq:
2675 return lhs.eq(rhs);
2676 case ICmpPredicate::ne:
2677 return lhs.ne(rhs);
2678 case ICmpPredicate::slt:
2679 return lhs.slt(rhs);
2680 case ICmpPredicate::sle:
2681 return lhs.sle(rhs);
2682 case ICmpPredicate::sgt:
2683 return lhs.sgt(rhs);
2684 case ICmpPredicate::sge:
2685 return lhs.sge(rhs);
2686 case ICmpPredicate::ult:
2687 return lhs.ult(rhs);
2688 case ICmpPredicate::ule:
2689 return lhs.ule(rhs);
2690 case ICmpPredicate::ugt:
2691 return lhs.ugt(rhs);
2692 case ICmpPredicate::uge:
2693 return lhs.uge(rhs);
2694 case ICmpPredicate::ceq:
2695 return lhs.eq(rhs);
2696 case ICmpPredicate::cne:
2697 return lhs.ne(rhs);
2698 case ICmpPredicate::weq:
2699 return lhs.eq(rhs);
2700 case ICmpPredicate::wne:
2701 return lhs.ne(rhs);
2702 }
2703 llvm_unreachable("unknown comparison predicate");
2704}
2705
2706// Returns the result of applying the predicate when the LHS and RHS are the
2707// exact same value.
2708static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2709 switch (predicate) {
2710 case ICmpPredicate::eq:
2711 case ICmpPredicate::sle:
2712 case ICmpPredicate::sge:
2713 case ICmpPredicate::ule:
2714 case ICmpPredicate::uge:
2715 case ICmpPredicate::ceq:
2716 case ICmpPredicate::weq:
2717 return true;
2718 case ICmpPredicate::ne:
2719 case ICmpPredicate::slt:
2720 case ICmpPredicate::sgt:
2721 case ICmpPredicate::ult:
2722 case ICmpPredicate::ugt:
2723 case ICmpPredicate::cne:
2724 case ICmpPredicate::wne:
2725 return false;
2726 }
2727 llvm_unreachable("unknown comparison predicate");
2728}
2729
2730OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2731 if (hasOperandsOutsideOfBlock(getOperation()))
2732 return {};
2733
2734 // gt a, a -> false
2735 // gte a, a -> true
2736 if (getLhs() == getRhs()) {
2737 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2738 return IntegerAttr::get(getType(), val);
2739 }
2740
2741 // gt 1, 2 -> false
2742 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2743 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2744 auto val =
2745 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2746 return IntegerAttr::get(getType(), val);
2747 }
2748 }
2749 return {};
2750}
2751
// Count how many leading elements of `a` and `b` are pairwise equal. The
// walk stops at the first position where the elements differ or where either
// range ends. Elements are only compared position by position; no
// cross-element matching is performed. Passing reversed ranges yields the
// length of the common suffix.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  size_t matched = 0;
  auto itA = a.begin();
  auto itB = b.begin();

  while (itA != a.end() && itB != b.end()) {
    if (*itA != *itB)
      return matched;
    ++itA;
    ++itB;
    ++matched;
  }

  return matched;
}
2768
2769static size_t getTotalWidth(ArrayRef<Value> operands) {
2770 size_t totalWidth = 0;
2771 for (auto operand : operands) {
2772 // getIntOrFloatBitWidth should never raise, since all arguments to
2773 // ConcatOp are integers.
2774 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2775 assert(width >= 0);
2776 totalWidth += width;
2777 }
2778 return totalWidth;
2779}
2780
2781/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
2782/// comparison on common prefix and suffixes. Returns success() if a rewriting
2783/// happens. This handles both concat and replicate.
2784static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
2785 Operation *rhs,
2786 PatternRewriter &rewriter) {
2787 // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
2788 // all elements have non-zero length. Both these invariants are verified
2789 // by the ConcatOp verifier.
2790 SmallVector<Value> lhsOperands, rhsOperands;
2791 getConcatOperands(lhs->getResult(0), lhsOperands);
2792 getConcatOperands(rhs->getResult(0), rhsOperands);
2793 ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;
2794
2795 auto formCatOrReplicate = [&](Location loc,
2796 ArrayRef<Value> operands) -> Value {
2797 assert(!operands.empty());
2798 Value sameElement = operands[0];
2799 for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
2800 if (sameElement != operands[i])
2801 sameElement = Value();
2802 if (sameElement)
2803 return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
2804 operands.size());
2805 return rewriter.createOrFold<ConcatOp>(loc, operands);
2806 };
2807
2808 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
2809 Value rhs) -> LogicalResult {
2810 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
2811 op.getTwoState());
2812 return success();
2813 };
2814
2815 size_t commonPrefixLength =
2816 computeCommonPrefixLength(lhsOperands, rhsOperands);
2817 if (commonPrefixLength == lhsOperands.size()) {
2818 // cat(a, b, c) == cat(a, b, c) -> 1
2819 bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
2820 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
2821 APInt(1, result));
2822 return success();
2823 }
2824
2825 size_t commonSuffixLength = computeCommonPrefixLength(
2826 llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));
2827
2828 size_t commonPrefixTotalWidth =
2829 getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
2830 size_t commonSuffixTotalWidth =
2831 getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
2832 auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
2833 .drop_back(commonSuffixLength);
2834 auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
2835 .drop_back(commonSuffixLength);
2836
2837 auto replaceWithoutReplicatingSignBit = [&]() {
2838 auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
2839 auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
2840 return replaceWith(op.getPredicate(), newLhs, newRhs);
2841 };
2842
2843 auto replaceWithReplicatingSignBit = [&]() {
2844 auto firstNonEmptyValue = lhsOperands[0];
2845 auto firstNonEmptyElemWidth =
2846 firstNonEmptyValue.getType().getIntOrFloatBitWidth();
2847 Value signBit = rewriter.createOrFold<ExtractOp>(
2848 op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);
2849
2850 auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
2851 auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
2852 return replaceWith(op.getPredicate(), newLhs, newRhs);
2853 };
2854
2855 if (ICmpOp::isPredicateSigned(op.getPredicate())) {
2856 // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
2857 if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
2858 return replaceWithoutReplicatingSignBit();
2859
2860 // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
2861 // cat(sgn(b), ..y)) Note that we cannot perform this optimization if
2862 // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
2863 // bit. Doing the rewrite can result in an infinite loop.
2864 if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
2865 return replaceWithReplicatingSignBit();
2866
2867 } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
2868 // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
2869 return replaceWithoutReplicatingSignBit();
2870 }
2871
2872 return failure();
2873}
2874
2875/// Given an equality comparison with a constant value and some operand that has
2876/// known bits, simplify the comparison to check only the unknown bits of the
2877/// input.
2878///
2879/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2880/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2881/// `extract x[1:0] == 0`
2883 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2884 PatternRewriter &rewriter) {
2885
2886 // If any of the known bits disagree with any of the comparison bits, then
2887 // we can constant fold this comparison right away.
2888 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2889 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2890 // If we discover a mismatch then we know an "eq" comparison is false
2891 // and a "ne" comparison is true!
2892 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2893 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2894 APInt(1, result));
2895 return;
2896 }
2897
2898 // Check to see if we can prove the result entirely of the comparison (in
2899 // which we bail out early), otherwise build a list of values to concat and a
2900 // smaller constant to compare against.
2901 SmallVector<Value> newConcatOperands;
2902 auto newConstant = APInt::getZeroWidth();
2903
2904 // Ok, some (maybe all) bits are known and some others may be unknown.
2905 // Extract out segments of the operand and compare against the
2906 // corresponding bits.
2907 unsigned knownMSB = bitsKnown.countLeadingOnes();
2908
2909 Value operand = cmpOp.getLhs();
2910
2911 // Ok, some bits are known but others are not. Extract out sequences of
2912 // bits that are unknown and compare just those bits. We work from MSB to
2913 // LSB.
2914 while (knownMSB != bitsKnown.getBitWidth()) {
2915 // Drop any high bits that are known.
2916 if (knownMSB)
2917 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2918
2919 // Find the span of unknown bits, and extract it.
2920 unsigned unknownBits = bitsKnown.countLeadingZeros();
2921 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2922 auto spanOperand = rewriter.createOrFold<ExtractOp>(
2923 operand.getLoc(), operand, /*lowBit=*/lowBit,
2924 /*bitWidth=*/unknownBits);
2925 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2926
2927 // Add this info to the concat we're generating.
2928 newConcatOperands.push_back(spanOperand);
2929 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
2930 // newConstant = newConstant.concat(spanConstant);
2931 if (newConstant.getBitWidth() != 0)
2932 newConstant = newConstant.concat(spanConstant);
2933 else
2934 newConstant = spanConstant;
2935
2936 // Drop the unknown bits in prep for the next chunk.
2937 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2938 bitsKnown = bitsKnown.trunc(newWidth);
2939 knownMSB = bitsKnown.countLeadingOnes();
2940 }
2941
2942 // If all the operands to the concat are foldable then we have an identity
2943 // situation where all the sub-elements equal each other. This implies that
2944 // the overall result is foldable.
2945 if (newConcatOperands.empty()) {
2946 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2947 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2948 APInt(1, result));
2949 return;
2950 }
2951
2952 // If we have a single operand remaining, use it, otherwise form a concat.
2953 Value concatResult =
2954 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
2955
2956 // Form the comparison against the smaller constant.
2957 auto newConstantOp = rewriter.create<hw::ConstantOp>(
2958 cmpOp.getOperand(1).getLoc(), newConstant);
2959
2960 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
2961 cmpOp.getPredicate(), concatResult,
2962 newConstantOp, cmpOp.getTwoState());
2963}
2964
2965// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
2966static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
2967 const APInt &rhs,
2968 PatternRewriter &rewriter) {
2969 auto ip = rewriter.saveInsertionPoint();
2970 rewriter.setInsertionPoint(xorOp);
2971
2972 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
2973 auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
2974 xorRHS.getValue() ^ rhs);
2975 Value newLHS;
2976 switch (xorOp.getNumOperands()) {
2977 case 1:
2978 // This isn't common but is defined so we need to handle it.
2979 newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
2980 APInt::getZero(rhs.getBitWidth()));
2981 break;
2982 case 2:
2983 // The binary case is the most common.
2984 newLHS = xorOp.getOperand(0);
2985 break;
2986 default:
2987 // The general case forces us to form a new xor with the remaining operands.
2988 SmallVector<Value> newOperands(xorOp.getOperands());
2989 newOperands.pop_back();
2990 newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
2991 break;
2992 }
2993
2994 bool xorMultipleUses = !xorOp->hasOneUse();
2995
2996 // If the xor has multiple uses (not just the compare, then we need/want to
2997 // replace them as well.
2998 if (xorMultipleUses)
2999 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
3000 false);
3001
3002 // Replace the comparison.
3003 rewriter.restoreInsertionPoint(ip);
3004 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3005 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
3006}
3007
3008LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
3009 if (hasOperandsOutsideOfBlock(&*op))
3010 return failure();
3011
3012 APInt lhs, rhs;
3013
3014 // icmp 1, x -> icmp x, 1
3015 if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
3016 assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
3017 "Should be folded");
3018 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3019 rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
3020 op.getRhs(), op.getLhs(), op.getTwoState());
3021 return success();
3022 }
3023
3024 // Canonicalize with RHS constant
3025 if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
3026 auto getConstant = [&](APInt constant) -> Value {
3027 return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(constant));
3028 };
3029
3030 auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
3031 Value rhs) -> LogicalResult {
3032 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
3033 rhs, op.getTwoState());
3034 return success();
3035 };
3036
3037 auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
3038 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
3039 APInt(1, constant));
3040 return success();
3041 };
3042
3043 switch (op.getPredicate()) {
3044 case ICmpPredicate::slt:
3045 // x < max -> x != max
3046 if (rhs.isMaxSignedValue())
3047 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3048 // x < min -> false
3049 if (rhs.isMinSignedValue())
3050 return replaceWithConstantI1(0);
3051 // x < min+1 -> x == min
3052 if ((rhs - 1).isMinSignedValue())
3053 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3054 getConstant(rhs - 1));
3055 break;
3056 case ICmpPredicate::sgt:
3057 // x > min -> x != min
3058 if (rhs.isMinSignedValue())
3059 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3060 // x > max -> false
3061 if (rhs.isMaxSignedValue())
3062 return replaceWithConstantI1(0);
3063 // x > max-1 -> x == max
3064 if ((rhs + 1).isMaxSignedValue())
3065 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3066 getConstant(rhs + 1));
3067 break;
3068 case ICmpPredicate::ult:
3069 // x < max -> x != max
3070 if (rhs.isAllOnes())
3071 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3072 // x < min -> false
3073 if (rhs.isZero())
3074 return replaceWithConstantI1(0);
3075 // x < min+1 -> x == min
3076 if ((rhs - 1).isZero())
3077 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3078 getConstant(rhs - 1));
3079
3080 // x < 0xE0 -> extract(x, 5..7) != 0b111
3081 if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
3082 rhs.getBitWidth()) {
3083 auto numOnes = rhs.countLeadingOnes();
3084 auto smaller = rewriter.create<ExtractOp>(
3085 op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
3086 return replaceWith(ICmpPredicate::ne, smaller,
3087 getConstant(APInt::getAllOnes(numOnes)));
3088 }
3089
3090 break;
3091 case ICmpPredicate::ugt:
3092 // x > min -> x != min
3093 if (rhs.isZero())
3094 return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
3095 // x > max -> false
3096 if (rhs.isAllOnes())
3097 return replaceWithConstantI1(0);
3098 // x > max-1 -> x == max
3099 if ((rhs + 1).isAllOnes())
3100 return replaceWith(ICmpPredicate::eq, op.getLhs(),
3101 getConstant(rhs + 1));
3102
3103 // x > 0x07 -> extract(x, 3..7) != 0b00000
3104 if ((rhs + 1).isPowerOf2()) {
3105 auto numOnes = rhs.countTrailingOnes();
3106 auto newWidth = rhs.getBitWidth() - numOnes;
3107 auto smaller = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
3108 numOnes, newWidth);
3109 return replaceWith(ICmpPredicate::ne, smaller,
3110 getConstant(APInt::getZero(newWidth)));
3111 }
3112
3113 break;
3114 case ICmpPredicate::sle:
3115 // x <= max -> true
3116 if (rhs.isMaxSignedValue())
3117 return replaceWithConstantI1(1);
3118 // x <= c -> x < (c+1)
3119 return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
3120 case ICmpPredicate::sge:
3121 // x >= min -> true
3122 if (rhs.isMinSignedValue())
3123 return replaceWithConstantI1(1);
3124 // x >= c -> x > (c-1)
3125 return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
3126 case ICmpPredicate::ule:
3127 // x <= max -> true
3128 if (rhs.isAllOnes())
3129 return replaceWithConstantI1(1);
3130 // x <= c -> x < (c+1)
3131 return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
3132 case ICmpPredicate::uge:
3133 // x >= min -> true
3134 if (rhs.isZero())
3135 return replaceWithConstantI1(1);
3136 // x >= c -> x > (c-1)
3137 return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
3138 case ICmpPredicate::eq:
3139 if (rhs.getBitWidth() == 1) {
3140 if (rhs.isZero()) {
3141 // x == 0 -> x ^ 1
3142 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3143 getConstant(APInt(1, 1)),
3144 op.getTwoState());
3145 return success();
3146 }
3147 if (rhs.isAllOnes()) {
3148 // x == 1 -> x
3149 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3150 return success();
3151 }
3152 }
3153 break;
3154 case ICmpPredicate::ne:
3155 if (rhs.getBitWidth() == 1) {
3156 if (rhs.isZero()) {
3157 // x != 0 -> x
3158 replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
3159 return success();
3160 }
3161 if (rhs.isAllOnes()) {
3162 // x != 1 -> x ^ 1
3163 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
3164 getConstant(APInt(1, 1)),
3165 op.getTwoState());
3166 return success();
3167 }
3168 }
3169 break;
3170 case ICmpPredicate::ceq:
3171 case ICmpPredicate::cne:
3172 case ICmpPredicate::weq:
3173 case ICmpPredicate::wne:
3174 break;
3175 }
3176
3177 // We have some specific optimizations for comparison with a constant that
3178 // are only supported for equality comparisons.
3179 if (op.getPredicate() == ICmpPredicate::eq ||
3180 op.getPredicate() == ICmpPredicate::ne) {
3181 // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
3182 // with a smaller constant. We only support equality comparisons for
3183 // this.
3184 auto knownBits = computeKnownBits(op.getLhs());
3185 if (!knownBits.isUnknown())
3186 return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
3187 rewriter),
3188 success();
3189
3190 // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
3191 // cst1^cst2).
3192 if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
3193 if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
3194 return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
3195 success();
3196
3197 // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
3198 // all one.
3199 if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
3200 if (rhs.isAllOnes() || rhs.isZero()) {
3201 auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
3202 auto cst = rewriter.create<hw::ConstantOp>(
3203 op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
3204 : APInt::getZero(width));
3205 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
3206 rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
3207 op.getTwoState());
3208 return success();
3209 }
3210 }
3211 }
3212
3213 // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
3214 // b), cat(c, d)). contains special handling for sign bit in signed
3215 // compressions.
3216 if (Operation *opLHS = op.getLhs().getDefiningOp())
3217 if (Operation *opRHS = op.getRhs().getDefiningOp())
3218 if (isa<ConcatOp, ReplicateOp>(opLHS) &&
3219 isa<ConcatOp, ReplicateOp>(opRHS)) {
3220 if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
3221 return success();
3222 }
3223
3224 return failure();
3225}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition CombFolds.cpp:91
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool hasOperandsOutsideOfBlock(Operation *op)
In comb, we assume no knowledge of the semantics of cross-block dataflow.
Definition CombFolds.cpp:33
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition CombFolds.cpp:59
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition CombFolds.cpp:53
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition CombFolds.cpp:97
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoke when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
Definition CombFolds.cpp:74
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition CombFolds.cpp:45
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements)
Definition hw.py:483
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a 'Not' gate on a value.
Definition CombOps.cpp:48
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1