CIRCT 21.0.0git
Loading...
Searching...
No Matches
CombFolds.cpp
Go to the documentation of this file.
//===- CombFolds.cpp - Folds + Canonicalization for Comb operations -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/ADT/SetVector.h"
16#include "llvm/ADT/SmallBitVector.h"
17#include "llvm/ADT/TypeSwitch.h"
18#include "llvm/Support/KnownBits.h"
19
20using namespace mlir;
21using namespace circt;
22using namespace comb;
23using namespace matchers;
24
25/// Create a new instance of a generic operation that only has value operands,
26/// and has a single result value whose type matches the first operand.
27///
28/// This should not be used to create instances of ops with attributes or with
29/// more complicated type signatures.
30static Value createGenericOp(Location loc, OperationName name,
31 ArrayRef<Value> operands, OpBuilder &builder) {
32 OperationState state(loc, name);
33 state.addOperands(operands);
34 state.addTypes(operands[0].getType());
35 return builder.create(state)->getResult(0);
36}
37
38static TypedAttr getIntAttr(const APInt &value, MLIRContext *context) {
39 return IntegerAttr::get(IntegerType::get(context, value.getBitWidth()),
40 value);
41}
42
43/// Flatten concat and mux operands into a vector.
44static void getConcatOperands(Value v, SmallVectorImpl<Value> &result) {
45 if (auto concat = v.getDefiningOp<ConcatOp>()) {
46 for (auto op : concat.getOperands())
47 getConcatOperands(op, result);
48 } else if (auto repl = v.getDefiningOp<ReplicateOp>()) {
49 for (size_t i = 0, e = repl.getMultiple(); i != e; ++i)
50 getConcatOperands(repl.getOperand(), result);
51 } else {
52 result.push_back(v);
53 }
54}
55
56// Return true if the op has SV attributes. Note that we cannot use a helper
57// function `hasSVAttributes` defined under SV dialect because of a cyclic
58// dependency.
59static bool hasSVAttributes(Operation *op) {
60 return op->hasAttr("sv.attributes");
61}
62
63namespace {
64template <typename SubType>
65struct ComplementMatcher {
66 SubType lhs;
67 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
68 bool match(Operation *op) {
69 auto xorOp = dyn_cast<XorOp>(op);
70 return xorOp && xorOp.isBinaryNot() && lhs.match(op->getOperand(0));
71 }
72};
73} // end anonymous namespace
74
75template <typename SubType>
76static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
77 return ComplementMatcher<SubType>(subExpr);
78}
79
80/// Return true if the op will be flattened afterwards. Op will be flattend if
81/// it has a single user which has a same op type. User must be in same block.
82static bool shouldBeFlattened(Operation *op) {
83 assert((isa<AndOp, OrOp, XorOp, AddOp, MulOp>(op) &&
84 "must be commutative operations"));
85 if (op->hasOneUse()) {
86 auto *user = *op->getUsers().begin();
87 return user->getName() == op->getName() &&
88 op->getAttrOfType<UnitAttr>("twoState") ==
89 user->getAttrOfType<UnitAttr>("twoState") &&
90 op->getBlock() == user->getBlock();
91 }
92 return false;
93}
94
95/// Flattens a single input in `op` if `hasOneUse` is true and it can be defined
96/// as an Op. Returns true if successful, and false otherwise.
97///
98/// Example: op(1, 2, op(3, 4), 5) -> op(1, 2, 3, 4, 5) // returns true
99///
100static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter) {
101 // Skip if the operation should be flattened by another operation.
102 if (shouldBeFlattened(op))
103 return false;
104
105 auto inputs = op->getOperands();
106
107 SmallVector<Value, 4> newOperands;
108 SmallVector<Location, 4> newLocations{op->getLoc()};
109 newOperands.reserve(inputs.size());
110 struct Element {
111 decltype(inputs.begin()) current, end;
112 };
113
114 SmallVector<Element> worklist;
115 worklist.push_back({inputs.begin(), inputs.end()});
116 bool binFlag = op->hasAttrOfType<UnitAttr>("twoState");
117 bool changed = false;
118 while (!worklist.empty()) {
119 auto &element = worklist.back(); // Do not pop. Take ref.
120
121 // Pop when we finished traversing the current operand range.
122 if (element.current == element.end) {
123 worklist.pop_back();
124 continue;
125 }
126
127 Value value = *element.current++;
128 auto *flattenOp = value.getDefiningOp();
129 // If not defined by a compatible operation of the same kind and
130 // from the same block, keep this as-is.
131 if (!flattenOp || flattenOp->getName() != op->getName() ||
132 flattenOp == op || binFlag != op->hasAttrOfType<UnitAttr>("twoState") ||
133 flattenOp->getBlock() != op->getBlock()) {
134 newOperands.push_back(value);
135 continue;
136 }
137
138 // Don't duplicate logic when it has multiple uses.
139 if (!value.hasOneUse()) {
140 // We can fold a multi-use binary operation into this one if this allows a
141 // constant to fold though. For example, fold
142 // (or a, b, c, (or d, cst1), cst2) --> (or a, b, c, d, cst1, cst2)
143 // since the constants will both fold and we end up with the equiv cost.
144 //
145 // We don't do this for add/mul because the hardware won't be shared
146 // between the two ops if duplicated.
147 if (flattenOp->getNumOperands() != 2 || !isa<AndOp, OrOp, XorOp>(op) ||
148 !flattenOp->getOperand(1).getDefiningOp<hw::ConstantOp>() ||
149 !inputs.back().getDefiningOp<hw::ConstantOp>()) {
150 newOperands.push_back(value);
151 continue;
152 }
153 }
154
155 changed = true;
156
157 // Otherwise, push operands into worklist.
158 auto flattenOpInputs = flattenOp->getOperands();
159 worklist.push_back({flattenOpInputs.begin(), flattenOpInputs.end()});
160 newLocations.push_back(flattenOp->getLoc());
161 }
162
163 if (!changed)
164 return false;
165
166 Value result = createGenericOp(FusedLoc::get(op->getContext(), newLocations),
167 op->getName(), newOperands, rewriter);
168 if (binFlag)
169 result.getDefiningOp()->setAttr("twoState", rewriter.getUnitAttr());
170
171 replaceOpAndCopyNamehint(rewriter, op, result);
172 return true;
173}
174
175// Given a range of uses of an operation, find the lowest and highest bits
176// inclusive that are ever referenced. The range of uses must not be empty.
177static std::pair<size_t, size_t>
178getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits,
179 size_t originalOpWidth) {
180 auto users = op->getUsers();
181 assert(!users.empty() &&
182 "getLowestBitAndHighestBitRequired cannot operate on "
183 "a empty list of uses.");
184
185 // when we don't want to narrowTrailingBits (namely in arithmetic
186 // operations), forcing lowestBitRequired = 0
187 size_t lowestBitRequired = narrowTrailingBits ? originalOpWidth - 1 : 0;
188 size_t highestBitRequired = 0;
189
190 for (auto *user : users) {
191 if (auto extractOp = dyn_cast<ExtractOp>(user)) {
192 size_t lowBit = extractOp.getLowBit();
193 size_t highBit =
194 cast<IntegerType>(extractOp.getType()).getWidth() + lowBit - 1;
195 highestBitRequired = std::max(highestBitRequired, highBit);
196 lowestBitRequired = std::min(lowestBitRequired, lowBit);
197 continue;
198 }
199
200 highestBitRequired = originalOpWidth - 1;
201 lowestBitRequired = 0;
202 break;
203 }
204
205 return {lowestBitRequired, highestBitRequired};
206}
207
208template <class OpTy>
209static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits,
210 PatternRewriter &rewriter) {
211 IntegerType opType = dyn_cast<IntegerType>(op.getResult().getType());
212 if (!opType)
213 return false;
214
215 auto range = getLowestBitAndHighestBitRequired(op, narrowTrailingBits,
216 opType.getWidth());
217 if (range.second + 1 == opType.getWidth() && range.first == 0)
218 return false;
219
220 SmallVector<Value> args;
221 auto newType = rewriter.getIntegerType(range.second - range.first + 1);
222 for (auto inop : op.getOperands()) {
223 // deal with muxes here
224 if (inop.getType() != op.getType())
225 args.push_back(inop);
226 else
227 args.push_back(rewriter.createOrFold<ExtractOp>(inop.getLoc(), newType,
228 inop, range.first));
229 }
230 auto newop = rewriter.create<OpTy>(op.getLoc(), newType, args);
231 newop->setDialectAttrs(op->getDialectAttrs());
232 if (op.getTwoState())
233 newop.setTwoState(true);
234
235 Value newResult = newop.getResult();
236 if (range.first)
237 newResult = rewriter.createOrFold<ConcatOp>(
238 op.getLoc(), newResult,
239 rewriter.create<hw::ConstantOp>(op.getLoc(),
240 APInt::getZero(range.first)));
241 if (range.second + 1 < opType.getWidth())
242 newResult = rewriter.createOrFold<ConcatOp>(
243 op.getLoc(),
244 rewriter.create<hw::ConstantOp>(
245 op.getLoc(), APInt::getZero(opType.getWidth() - range.second - 1)),
246 newResult);
247 rewriter.replaceOp(op, newResult);
248 return true;
249}
250
//===----------------------------------------------------------------------===//
// Unary Operations
//===----------------------------------------------------------------------===//
254
255OpFoldResult ReplicateOp::fold(FoldAdaptor adaptor) {
256 // Replicate one time -> noop.
257 if (cast<IntegerType>(getType()).getWidth() ==
258 getInput().getType().getIntOrFloatBitWidth())
259 return getInput();
260
261 // Constant fold.
262 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
263 if (input.getValue().getBitWidth() == 1) {
264 if (input.getValue().isZero())
265 return getIntAttr(
266 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
267 getContext());
268 return getIntAttr(
269 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
270 getContext());
271 }
272
273 APInt result = APInt::getZeroWidth();
274 for (auto i = getMultiple(); i != 0; --i)
275 result = result.concat(input.getValue());
276 return getIntAttr(result, getContext());
277 }
278
279 return {};
280}
281
282OpFoldResult ParityOp::fold(FoldAdaptor adaptor) {
283 // Constant fold.
284 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput()))
285 return getIntAttr(APInt(1, input.getValue().popcount() & 1), getContext());
286
287 return {};
288}
289
//===----------------------------------------------------------------------===//
// Binary Operations
//===----------------------------------------------------------------------===//
293
294/// Performs constant folding `calculate` with element-wise behavior on the two
295/// attributes in `operands` and returns the result if possible.
296static Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
297 hw::PEO paramOpcode) {
298 assert(operands.size() == 2 && "binary op takes two operands");
299 if (!operands[0] || !operands[1])
300 return {};
301
302 // Fold constants with ParamExprAttr::get which handles simple constants as
303 // well as parameter expressions.
304 return hw::ParamExprAttr::get(paramOpcode, cast<TypedAttr>(operands[0]),
305 cast<TypedAttr>(operands[1]));
306}
307
308OpFoldResult ShlOp::fold(FoldAdaptor adaptor) {
309 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
310 unsigned shift = rhs.getValue().getZExtValue();
311 unsigned width = getType().getIntOrFloatBitWidth();
312 if (shift == 0)
313 return getOperand(0);
314 if (width <= shift)
315 return getIntAttr(APInt::getZero(width), getContext());
316 }
317
318 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::Shl);
319}
320
321LogicalResult ShlOp::canonicalize(ShlOp op, PatternRewriter &rewriter) {
322 // ShlOp(x, cst) -> Concat(Extract(x), zeros)
323 APInt value;
324 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
325 return failure();
326
327 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
328 unsigned shift = value.getZExtValue();
329
330 // This case is handled by fold.
331 if (width <= shift || shift == 0)
332 return failure();
333
334 auto zeros =
335 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
336
337 // Remove the high bits which would be removed by the Shl.
338 auto extract =
339 rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), 0, width - shift);
340
341 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, extract, zeros);
342 return success();
343}
344
345OpFoldResult ShrUOp::fold(FoldAdaptor adaptor) {
346 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
347 unsigned shift = rhs.getValue().getZExtValue();
348 if (shift == 0)
349 return getOperand(0);
350
351 unsigned width = getType().getIntOrFloatBitWidth();
352 if (width <= shift)
353 return getIntAttr(APInt::getZero(width), getContext());
354 }
355 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrU);
356}
357
358LogicalResult ShrUOp::canonicalize(ShrUOp op, PatternRewriter &rewriter) {
359 // ShrUOp(x, cst) -> Concat(zeros, Extract(x))
360 APInt value;
361 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
362 return failure();
363
364 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
365 unsigned shift = value.getZExtValue();
366
367 // This case is handled by fold.
368 if (width <= shift || shift == 0)
369 return failure();
370
371 auto zeros =
372 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(shift));
373
374 // Remove the low bits which would be removed by the Shr.
375 auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
376 width - shift);
377
378 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, zeros, extract);
379 return success();
380}
381
382OpFoldResult ShrSOp::fold(FoldAdaptor adaptor) {
383 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
384 if (rhs.getValue().getZExtValue() == 0)
385 return getOperand(0);
386 }
387 return constFoldBinaryOp(adaptor.getOperands(), hw::PEO::ShrS);
388}
389
390LogicalResult ShrSOp::canonicalize(ShrSOp op, PatternRewriter &rewriter) {
391 // ShrSOp(x, cst) -> Concat(replicate(extract(x, topbit)),extract(x))
392 APInt value;
393 if (!matchPattern(op.getRhs(), m_ConstantInt(&value)))
394 return failure();
395
396 unsigned width = cast<IntegerType>(op.getLhs().getType()).getWidth();
397 unsigned shift = value.getZExtValue();
398
399 auto topbit =
400 rewriter.createOrFold<ExtractOp>(op.getLoc(), op.getLhs(), width - 1, 1);
401 auto sext = rewriter.createOrFold<ReplicateOp>(op.getLoc(), topbit, shift);
402
403 if (width <= shift) {
404 replaceOpAndCopyNamehint(rewriter, op, {sext});
405 return success();
406 }
407
408 auto extract = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(), shift,
409 width - shift);
410
411 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, sext, extract);
412 return success();
413}
414
//===----------------------------------------------------------------------===//
// Other Operations
//===----------------------------------------------------------------------===//
418
419OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
420 // If we are extracting the entire input, then return it.
421 if (getInput().getType() == getType())
422 return getInput();
423
424 // Constant fold.
425 if (auto input = dyn_cast_or_null<IntegerAttr>(adaptor.getInput())) {
426 unsigned dstWidth = cast<IntegerType>(getType()).getWidth();
427 return getIntAttr(input.getValue().lshr(getLowBit()).trunc(dstWidth),
428 getContext());
429 }
430 return {};
431}
432
433// Transforms extract(lo, cat(a, b, c, d, e)) into
434// cat(extract(lo1, b), c, extract(lo2, d)).
435// innerCat must be the argument of the provided ExtractOp.
437 ConcatOp innerCat,
438 PatternRewriter &rewriter) {
439 auto reversedConcatArgs = llvm::reverse(innerCat.getInputs());
440 size_t beginOfFirstRelevantElement = 0;
441 auto it = reversedConcatArgs.begin();
442 size_t lowBit = op.getLowBit();
443
444 // This loop finds the first concatArg that is covered by the ExtractOp
445 for (; it != reversedConcatArgs.end(); it++) {
446 assert(beginOfFirstRelevantElement <= lowBit &&
447 "incorrectly moved past an element that lowBit has coverage over");
448 auto operand = *it;
449
450 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
451 if (lowBit < beginOfFirstRelevantElement + operandWidth) {
452 // A bit other than the first bit will be used in this element.
453 // ...... ........ ...
454 // ^---lowBit
455 // ^---beginOfFirstRelevantElement
456 //
457 // Edge-case close to the end of the range.
458 // ...... ........ ...
459 // ^---(position + operandWidth)
460 // ^---lowBit
461 // ^---beginOfFirstRelevantElement
462 //
463 // Edge-case close to the beginning of the rang
464 // ...... ........ ...
465 // ^---lowBit
466 // ^---beginOfFirstRelevantElement
467 //
468 break;
469 }
470
471 // extraction discards this element.
472 // ...... ........ ...
473 // | ^---lowBit
474 // ^---beginOfFirstRelevantElement
475 beginOfFirstRelevantElement += operandWidth;
476 }
477 assert(it != reversedConcatArgs.end() &&
478 "incorrectly failed to find an element which contains coverage of "
479 "lowBit");
480
481 SmallVector<Value> reverseConcatArgs;
482 size_t widthRemaining = cast<IntegerType>(op.getType()).getWidth();
483 size_t extractLo = lowBit - beginOfFirstRelevantElement;
484
485 // Transform individual arguments of innerCat(..., a, b, c,) into
486 // [ extract(a), b, extract(c) ], skipping an extract operation where
487 // possible.
488 for (; widthRemaining != 0 && it != reversedConcatArgs.end(); it++) {
489 auto concatArg = *it;
490 size_t operandWidth = concatArg.getType().getIntOrFloatBitWidth();
491 size_t widthToConsume = std::min(widthRemaining, operandWidth - extractLo);
492
493 if (widthToConsume == operandWidth && extractLo == 0) {
494 reverseConcatArgs.push_back(concatArg);
495 } else {
496 auto resultType = IntegerType::get(rewriter.getContext(), widthToConsume);
497 reverseConcatArgs.push_back(
498 rewriter.create<ExtractOp>(op.getLoc(), resultType, *it, extractLo));
499 }
500
501 widthRemaining -= widthToConsume;
502
503 // Beyond the first element, all elements are extracted from position 0.
504 extractLo = 0;
505 }
506
507 if (reverseConcatArgs.size() == 1) {
508 replaceOpAndCopyNamehint(rewriter, op, reverseConcatArgs[0]);
509 } else {
510 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
511 rewriter, op, SmallVector<Value>(llvm::reverse(reverseConcatArgs)));
512 }
513 return success();
514}
515
516// Transforms extract(lo, replicate(a, N)) into replicate(a, N-c).
517static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate,
518 PatternRewriter &rewriter) {
519 auto extractResultWidth = cast<IntegerType>(op.getType()).getWidth();
520 auto replicateEltWidth =
521 replicate.getOperand().getType().getIntOrFloatBitWidth();
522
523 // If the extract starts at the base of an element and is an even multiple,
524 // we can replace the extract with a smaller replicate.
525 if (op.getLowBit() % replicateEltWidth == 0 &&
526 extractResultWidth % replicateEltWidth == 0) {
527 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
528 replicate.getOperand());
529 return true;
530 }
531
532 // If the extract is completely contained in one element, extract from the
533 // element.
534 if (op.getLowBit() % replicateEltWidth + extractResultWidth <=
535 replicateEltWidth) {
536 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
537 rewriter, op, op.getType(), replicate.getOperand(),
538 op.getLowBit() % replicateEltWidth);
539 return true;
540 }
541
542 // We don't currently handle the case of extracting from non-whole elements,
543 // e.g. `extract (replicate 2-bit-thing, N), 1`.
544 return false;
545}
546
547LogicalResult ExtractOp::canonicalize(ExtractOp op, PatternRewriter &rewriter) {
548 auto *inputOp = op.getInput().getDefiningOp();
549
550 // This turns out to be incredibly expensive. Disable until performance is
551 // addressed.
552#if 0
553 // If the extracted bits are all known, then return the result.
554 auto knownBits = computeKnownBits(op.getInput())
555 .extractBits(cast<IntegerType>(op.getType()).getWidth(),
556 op.getLowBit());
557 if (knownBits.isConstant()) {
558 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
559 knownBits.getConstant());
560 return success();
561 }
562#endif
563
564 // extract(olo, extract(ilo, x)) = extract(olo + ilo, x)
565 if (auto innerExtract = dyn_cast_or_null<ExtractOp>(inputOp)) {
566 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
567 rewriter, op, op.getType(), innerExtract.getInput(),
568 innerExtract.getLowBit() + op.getLowBit());
569 return success();
570 }
571
572 // extract(lo, cat(a, b, c, d, e)) = cat(extract(lo1, b), c, extract(lo2, d))
573 if (auto innerCat = dyn_cast_or_null<ConcatOp>(inputOp))
574 return extractConcatToConcatExtract(op, innerCat, rewriter);
575
576 // extract(lo, replicate(a))
577 if (auto replicate = dyn_cast_or_null<ReplicateOp>(inputOp))
578 if (extractFromReplicate(op, replicate, rewriter))
579 return success();
580
581 // `extract(and(a, cst))` -> `extract(a)` when the relevant bits of the
582 // and/or/xor are not modifying the extracted bits.
583 if (inputOp && inputOp->getNumOperands() == 2 &&
584 isa<AndOp, OrOp, XorOp>(inputOp)) {
585 if (auto cstRHS = inputOp->getOperand(1).getDefiningOp<hw::ConstantOp>()) {
586 auto extractedCst = cstRHS.getValue().extractBits(
587 cast<IntegerType>(op.getType()).getWidth(), op.getLowBit());
588 if (isa<OrOp, XorOp>(inputOp) && extractedCst.isZero()) {
589 replaceOpWithNewOpAndCopyNamehint<ExtractOp>(
590 rewriter, op, op.getType(), inputOp->getOperand(0), op.getLowBit());
591 return success();
592 }
593
594 // `extract(and(a, cst))` -> `concat(extract(a), 0)` if we only need one
595 // extract to represent the result. Turning it into a pile of extracts is
596 // always fine by our cost model, but we don't want to explode things into
597 // a ton of bits because it will bloat the IR and generated Verilog.
598 if (isa<AndOp>(inputOp)) {
599 // For our cost model, we only do this if the bit pattern is a
600 // contiguous series of ones.
601 unsigned lz = extractedCst.countLeadingZeros();
602 unsigned tz = extractedCst.countTrailingZeros();
603 unsigned pop = extractedCst.popcount();
604 if (extractedCst.getBitWidth() - lz - tz == pop) {
605 auto resultTy = rewriter.getIntegerType(pop);
606 SmallVector<Value> resultElts;
607 if (lz)
608 resultElts.push_back(rewriter.create<hw::ConstantOp>(
609 op.getLoc(), APInt::getZero(lz)));
610 resultElts.push_back(rewriter.createOrFold<ExtractOp>(
611 op.getLoc(), resultTy, inputOp->getOperand(0),
612 op.getLowBit() + tz));
613 if (tz)
614 resultElts.push_back(rewriter.create<hw::ConstantOp>(
615 op.getLoc(), APInt::getZero(tz)));
616 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
617 return success();
618 }
619 }
620 }
621 }
622
623 // `extract(lowBit, shl(1, x))` -> `x == lowBit` when a single bit is
624 // extracted.
625 if (cast<IntegerType>(op.getType()).getWidth() == 1 && inputOp)
626 if (auto shlOp = dyn_cast<ShlOp>(inputOp)) {
627 // Don't canonicalize if the shift is multiply used.
628 if (shlOp->hasOneUse())
629 if (auto lhsCst = shlOp.getLhs().getDefiningOp<hw::ConstantOp>())
630 if (lhsCst.getValue().isOne()) {
631 auto newCst = rewriter.create<hw::ConstantOp>(
632 shlOp.getLoc(),
633 APInt(lhsCst.getValue().getBitWidth(), op.getLowBit()));
634 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
635 rewriter, op, ICmpPredicate::eq, shlOp->getOperand(1), newCst,
636 false);
637 return success();
638 }
639 }
640
641 return failure();
642}
643
//===----------------------------------------------------------------------===//
// Associative Variadic operations
//===----------------------------------------------------------------------===//
647
648// Reduce all operands to a single value (either integer constant or parameter
649// expression) if all the operands are constants.
650static Attribute constFoldAssociativeOp(ArrayRef<Attribute> operands,
651 hw::PEO paramOpcode) {
652 assert(operands.size() > 1 && "caller should handle one-operand case");
653 // We can only fold anything in the case where all operands are known to be
654 // constants. Check the least common one first for an early out.
655 if (!operands[1] || !operands[0])
656 return {};
657
658 // This will fold to a simple constant if all operands are constant.
659 if (llvm::all_of(operands.drop_front(2),
660 [&](Attribute in) { return !!in; })) {
661 SmallVector<mlir::TypedAttr> typedOperands;
662 typedOperands.reserve(operands.size());
663 for (auto operand : operands) {
664 if (auto typedOperand = dyn_cast<mlir::TypedAttr>(operand))
665 typedOperands.push_back(typedOperand);
666 else
667 break;
668 }
669 if (typedOperands.size() == operands.size())
670 return hw::ParamExprAttr::get(paramOpcode, typedOperands);
671 }
672
673 return {};
674}
675
676/// When we find a logical operation (and, or, xor) with a constant e.g.
677/// `X & 42`, we want to push the constant into the computation of X if it leads
678/// to simplification.
679///
680/// This function handles the case where the logical operation has a concat
681/// operand. We check to see if we can simplify the concat, e.g. when it has
682/// constant operands.
683///
684/// This returns true when a simplification happens.
685static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp,
686 size_t concatIdx, const APInt &cst,
687 PatternRewriter &rewriter) {
688 auto concatOp = logicalOp->getOperand(concatIdx).getDefiningOp<ConcatOp>();
689 assert((isa<AndOp, OrOp, XorOp>(logicalOp) && concatOp));
690
691 // Check to see if any operands can be simplified by pushing the logical op
692 // into all parts of the concat.
693 bool canSimplify =
694 llvm::any_of(concatOp->getOperands(), [&](Value operand) -> bool {
695 auto *operandOp = operand.getDefiningOp();
696 if (!operandOp)
697 return false;
698
699 // If the concat has a constant operand then we can transform this.
700 if (isa<hw::ConstantOp>(operandOp))
701 return true;
702 // If the concat has the same logical operation and that operation has
703 // a constant operation than we can fold it into that suboperation.
704 return operandOp->getName() == logicalOp->getName() &&
705 operandOp->hasOneUse() && operandOp->getNumOperands() != 0 &&
706 operandOp->getOperands().back().getDefiningOp<hw::ConstantOp>();
707 });
708
709 if (!canSimplify)
710 return false;
711
712 // Create a new instance of the logical operation. We have to do this the
713 // hard way since we're generic across a family of different ops.
714 auto createLogicalOp = [&](ArrayRef<Value> operands) -> Value {
715 return createGenericOp(logicalOp->getLoc(), logicalOp->getName(), operands,
716 rewriter);
717 };
718
719 // Ok, let's do the transformation. We do this by slicing up the constant
720 // for each unit of the concat and duplicate the operation into the
721 // sub-operand.
722 SmallVector<Value> newConcatOperands;
723 newConcatOperands.reserve(concatOp->getNumOperands());
724
725 // Work from MSB to LSB.
726 size_t nextOperandBit = concatOp.getType().getIntOrFloatBitWidth();
727 for (Value operand : concatOp->getOperands()) {
728 size_t operandWidth = operand.getType().getIntOrFloatBitWidth();
729 nextOperandBit -= operandWidth;
730 // Take a slice of the constant.
731 auto eltCst = rewriter.create<hw::ConstantOp>(
732 logicalOp->getLoc(), cst.lshr(nextOperandBit).trunc(operandWidth));
733
734 newConcatOperands.push_back(createLogicalOp({operand, eltCst}));
735 }
736
737 // Create the concat, and the rest of the logical op if we need it.
738 Value newResult =
739 rewriter.create<ConcatOp>(concatOp.getLoc(), newConcatOperands);
740
741 // If we had a variadic logical op on the top level, then recreate it with the
742 // new concat and without the constant operand.
743 if (logicalOp->getNumOperands() > 2) {
744 auto origOperands = logicalOp->getOperands();
745 SmallVector<Value> operands;
746 // Take any stuff before the concat.
747 operands.append(origOperands.begin(), origOperands.begin() + concatIdx);
748 // Take any stuff after the concat but before the constant.
749 operands.append(origOperands.begin() + concatIdx + 1,
750 origOperands.begin() + (origOperands.size() - 1));
751 // Include the new concat.
752 operands.push_back(newResult);
753 newResult = createLogicalOp(operands);
754 }
755
756 replaceOpAndCopyNamehint(rewriter, logicalOp, newResult);
757 return true;
758}
759
760// Determines whether the inputs to a logical element are of opposite
761// comparisons and can lowered into a constant.
762static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands) {
763 llvm::SmallDenseSet<std::tuple<ICmpPredicate, Value, Value>> seenPredicates;
764
765 for (auto op : operands) {
766 if (auto icmpOp = op.getDefiningOp<ICmpOp>();
767 icmpOp && icmpOp.getTwoState()) {
768 auto predicate = icmpOp.getPredicate();
769 auto lhs = icmpOp.getLhs();
770 auto rhs = icmpOp.getRhs();
771 if (seenPredicates.contains(
772 {ICmpOp::getNegatedPredicate(predicate), lhs, rhs}))
773 return true;
774
775 seenPredicates.insert({predicate, lhs, rhs});
776 }
777 }
778 return false;
779}
780
781OpFoldResult AndOp::fold(FoldAdaptor adaptor) {
782 APInt value = APInt::getAllOnes(cast<IntegerType>(getType()).getWidth());
783
784 auto inputs = adaptor.getInputs();
785
786 // and(x, 01, 10) -> 00 -- annulment.
787 for (auto operand : inputs) {
788 if (!operand)
789 continue;
790 value &= cast<IntegerAttr>(operand).getValue();
791 if (value.isZero())
792 return getIntAttr(value, getContext());
793 }
794
795 // and(x, -1) -> x.
796 if (inputs.size() == 2 && inputs[1] &&
797 cast<IntegerAttr>(inputs[1]).getValue().isAllOnes())
798 return getInputs()[0];
799
800 // and(x, x, x) -> x. This also handles and(x) -> x.
801 if (llvm::all_of(getInputs(),
802 [&](auto in) { return in == this->getInputs()[0]; }))
803 return getInputs()[0];
804
805 // and(..., x, ..., ~x, ...) -> 0
806 for (Value arg : getInputs()) {
807 Value subExpr;
808 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
809 for (Value arg2 : getInputs())
810 if (arg2 == subExpr)
811 return getIntAttr(
812 APInt::getZero(cast<IntegerType>(getType()).getWidth()),
813 getContext());
814 }
815 }
816
817 // x0 = icmp(pred, x, y)
818 // x1 = icmp(!pred, x, y)
819 // and(x0, x1) -> 0
821 return getIntAttr(APInt::getZero(cast<IntegerType>(getType()).getWidth()),
822 getContext());
823
824 // Constant fold
825 return constFoldAssociativeOp(inputs, hw::PEO::And);
826}
827
828/// Returns a single common operand that all inputs of the operation `op` can
829/// be traced back to, or an empty `Value` if no such operand exists.
830///
831/// For example for `or(a[0], a[1], ..., a[n-1])` this function returns `a`
832/// (assuming the bit-width of `a` is `n`).
833template <typename Op>
834static Value getCommonOperand(Op op) {
835 if (!op.getType().isInteger(1))
836 return Value();
837
838 auto inputs = op.getInputs();
839 size_t size = inputs.size();
840
841 auto sourceOp = inputs[0].template getDefiningOp<ExtractOp>();
842 if (!sourceOp)
843 return Value();
844 Value source = sourceOp.getOperand();
845
846 // Fast path: the input size is not equal to the width of the source.
847 if (size != source.getType().getIntOrFloatBitWidth())
848 return Value();
849
850 // Tracks the bits that were encountered.
851 llvm::BitVector bits(size);
852 bits.set(sourceOp.getLowBit());
853
854 for (size_t i = 1; i != size; ++i) {
855 auto extractOp = inputs[i].template getDefiningOp<ExtractOp>();
856 if (!extractOp || extractOp.getOperand() != source)
857 return Value();
858 bits.set(extractOp.getLowBit());
859 }
860
861 return bits.all() ? source : Value();
862}
863
/// Canonicalize an idempotent operation `op` so that only one input of any kind
/// occurs.
///
/// Example: `and(x, y, x, z)` -> `and(x, y, z)`
///
/// Besides exact duplicates, an input is also dropped when it is an input of a
/// nested operation of the same kind and two-state-ness that is reachable from
/// `op` (within the same block, up to `limit` levels deep), e.g.
/// `and(x, and(x, y))` -> `and(and(x, y))`.
///
/// Returns true after replacing `op` through `rewriter` when at least one input
/// was removed; returns false and leaves the IR untouched otherwise.
template <typename Op>
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter) {
  // Depth limit to search, in operations. Chosen arbitrarily, keep small.
  constexpr unsigned limit = 3;
  auto inputs = op.getInputs();

  // Deduplicated inputs of `op`, kept in first-occurrence order.
  llvm::SmallSetVector<Value, 8> uniqueInputs(inputs.begin(), inputs.end());
  // Ops already enqueued. Seeded with `op` itself so it is never enqueued.
  llvm::SmallDenseSet<Op, 8> checked;
  checked.insert(op);

  struct OpWithDepth {
    Op op;
    unsigned depth;
  };
  llvm::SmallVector<OpWithDepth, 8> worklist;

  auto enqueue = [&worklist, &checked, &op](Value input, unsigned depth) {
    // Add to worklist if within depth limit, is defined in the same block by
    // the same kind of operation, has same two-state-ness, and not enqueued
    // previously.
    if (depth < limit && input.getParentBlock() == op->getBlock()) {
      auto inputOp = input.template getDefiningOp<Op>();
      if (inputOp && inputOp.getTwoState() == op.getTwoState() &&
          checked.insert(inputOp).second)
        worklist.push_back({inputOp, depth + 1});
    }
  };

  for (auto input : uniqueInputs)
    enqueue(input, 0);

  while (!worklist.empty()) {
    auto item = worklist.pop_back_val();

    // Inputs of a nested same-kind op are already contributed through that
    // nested op, so they are redundant as direct inputs of `op`.
    for (auto input : item.op.getInputs()) {
      uniqueInputs.remove(input);
      enqueue(input, item.depth);
    }
  }

  // Only rewrite when something was actually removed.
  if (uniqueInputs.size() < inputs.size()) {
    replaceOpWithNewOpAndCopyNamehint<Op>(rewriter, op, op.getType(),
                                          uniqueInputs.getArrayRef(),
                                          op.getTwoState());
    return true;
  }

  return false;
}
917
/// Canonicalization patterns for `comb.and`.
///
/// Patterns are tried in order. The first one that applies rewrites `op` and
/// returns success. Later patterns rely on earlier ones not having matched
/// (for example, at least two inputs remain after flattening and
/// deduplication).
LogicalResult AndOp::canonicalize(AndOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();

  // and(x, and(...)) -> and(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // and(..., x, ..., x) -> and(..., x, ...) -- idempotent
  // and(..., x, and(..., x, ...)) -> and(..., and(..., x, ...)) -- idempotent
  // Trivial and(x), and(x, x) cases are handled by [AndOp::fold] above.
  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
    return success();

  assert(size > 1 && "expected 2 or more operands, `fold` should handle this");

  // Patterns for and with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // and(..., '1) -> and(...) -- identity
    if (value.isAllOnes()) {
      replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
                                               inputs.drop_back(), false);
      return success();
    }

    // TODO: Combine multiple constants together even if they aren't at the
    // end. and(..., c1, c2) -> and(..., c3) where c3 = c1 & c2 -- constant
    // folding
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value & value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getType(),
                                               newOperands, false);
      return success();
    }

    // Handle 'and' with a single bit constant on the RHS.
    if (size == 2 && value.isPowerOf2()) {
      // If the LHS is a replicate from a single bit, we can 'concat' it
      // into place. e.g.:
      //   `replicate(x) & 4` -> `concat(zeros, x, zeros)`
      // TODO: Generalize this for non-single-bit operands.
      if (auto replicate = inputs[0].getDefiningOp<ReplicateOp>()) {
        auto replicateOperand = replicate.getOperand();
        if (replicateOperand.getType().isInteger(1)) {
          unsigned resultWidth = op.getType().getIntOrFloatBitWidth();
          // The only set bit of the mask is at index `trailingZeros`. `x` goes
          // there, with zero padding above and below.
          auto trailingZeros = value.countTrailingZeros();

          // Don't add zero bit constants unnecessarily.
          SmallVector<Value, 3> concatOperands;
          if (trailingZeros != resultWidth - 1) {
            auto highZeros = rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(resultWidth - trailingZeros - 1));
            concatOperands.push_back(highZeros);
          }
          concatOperands.push_back(replicateOperand);
          if (trailingZeros != 0) {
            auto lowZeros = rewriter.create<hw::ConstantOp>(
                op.getLoc(), APInt::getZero(trailingZeros));
            concatOperands.push_back(lowZeros);
          }
          replaceOpWithNewOpAndCopyNamehint<ConcatOp>(
              rewriter, op, op.getType(), concatOperands);
          return success();
        }
      }
    }

    // If this is an and from an extract op, try shrinking the extract.
    if (auto extractOp = inputs[0].getDefiningOp<ExtractOp>()) {
      if (size == 2 &&
          // We can shrink it if the mask has leading or trailing zeros.
          (value.countLeadingZeros() || value.countTrailingZeros())) {
        unsigned lz = value.countLeadingZeros();
        unsigned tz = value.countTrailingZeros();

        // NOTE(review): this assumes the mask is non-zero (and(x, 0) is
        // presumably folded away by AndOp::fold). For a zero mask, lz and tz
        // both equal the width, and the unsigned width below would wrap.
        // Start by extracting the smaller number of bits.
        auto smallTy = rewriter.getIntegerType(value.getBitWidth() - lz - tz);
        Value smallElt = rewriter.createOrFold<ExtractOp>(
            extractOp.getLoc(), smallTy, extractOp->getOperand(0),
            extractOp.getLowBit() + tz);
        // Apply the 'and' mask if needed.
        APInt smallMask = value.extractBits(smallTy.getWidth(), tz);
        if (!smallMask.isAllOnes()) {
          auto loc = inputs.back().getLoc();
          smallElt = rewriter.createOrFold<AndOp>(
              loc, smallElt, rewriter.create<hw::ConstantOp>(loc, smallMask),
              false);
        }

        // The final replacement will be a concat of the leading/trailing zeros
        // along with the smaller extracted value.
        SmallVector<Value> resultElts;
        if (lz)
          resultElts.push_back(
              rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(lz)));
        resultElts.push_back(smallElt);
        if (tz)
          resultElts.push_back(
              rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(tz)));
        replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, resultElts);
        return success();
      }
    }

    // and(concat(x, cst1), a, b, c, cst2)
    //    ==> and(a, b, c, concat(and(x,cst2'), and(cst1,cst2'')).
    // We do this for even more multi-use concats since they are "just wiring".
    for (size_t i = 0; i < size - 1; ++i) {
      if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();
    }
  }

  // extracts only of and(...) -> and(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // and(a[0], a[1], ..., a[n]) -> icmp eq(a, -1)
  // getCommonOperand only succeeds when `a` is exactly `size` bits wide.
  if (auto source = getCommonOperand(op)) {
    auto cmpAgainst =
        rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getAllOnes(size));
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::eq,
                                              source, cmpAgainst);
    return success();
  }

  /// TODO: and(..., x, not(x)) -> and(..., 0) -- complement
  return failure();
}
1052
1053OpFoldResult OrOp::fold(FoldAdaptor adaptor) {
1054 auto value = APInt::getZero(cast<IntegerType>(getType()).getWidth());
1055 auto inputs = adaptor.getInputs();
1056 // or(x, 10, 01) -> 11
1057 for (auto operand : inputs) {
1058 if (!operand)
1059 continue;
1060 value |= cast<IntegerAttr>(operand).getValue();
1061 if (value.isAllOnes())
1062 return getIntAttr(value, getContext());
1063 }
1064
1065 // or(x, 0) -> x
1066 if (inputs.size() == 2 && inputs[1] &&
1067 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1068 return getInputs()[0];
1069
1070 // or(x, x, x) -> x. This also handles or(x) -> x
1071 if (llvm::all_of(getInputs(),
1072 [&](auto in) { return in == this->getInputs()[0]; }))
1073 return getInputs()[0];
1074
1075 // or(..., x, ..., ~x, ...) -> -1
1076 for (Value arg : getInputs()) {
1077 Value subExpr;
1078 if (matchPattern(arg, m_Complement(m_Any(&subExpr)))) {
1079 for (Value arg2 : getInputs())
1080 if (arg2 == subExpr)
1081 return getIntAttr(
1082 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1083 getContext());
1084 }
1085 }
1086
1087 // x0 = icmp(pred, x, y)
1088 // x1 = icmp(!pred, x, y)
1089 // or(x0, x1) -> 1
1090 if (canCombineOppositeBinCmpIntoConstant(getInputs()))
1091 return getIntAttr(
1092 APInt::getAllOnes(cast<IntegerType>(getType()).getWidth()),
1093 getContext());
1094
1095 // Constant fold
1096 return constFoldAssociativeOp(inputs, hw::PEO::Or);
1097}
1098
/// Canonicalization patterns for `comb.or`.
///
/// Patterns are tried in order. The first one that applies rewrites `op` and
/// returns success.
LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();

  // or(x, or(...)) -> or(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // or(..., x, ..., x, ...) -> or(..., x) -- idempotent
  // or(..., x, or(..., x, ...)) -> or(..., or(..., x, ...)) -- idempotent
  // Trivial or(x), or(x, x) cases are handled by [OrOp::fold].
  if (size > 1 && canonicalizeIdempotentInputs(op, rewriter))
    return success();

  assert(size > 1 && "expected 2 or more operands");

  // Patterns for or with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // or(..., '0) -> or(...) -- identity
    if (value.isZero()) {
      replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
                                              inputs.drop_back());
      return success();
    }

    // or(..., c1, c2) -> or(..., c3) where c3 = c1 | c2 -- constant folding
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value | value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getType(),
                                              newOperands);
      return success();
    }

    // or(concat(x, cst1), a, b, c, cst2)
    //    ==> or(a, b, c, concat(or(x,cst2'), or(cst1,cst2'')).
    // We do this for even more multi-use concats since they are "just wiring".
    for (size_t i = 0; i < size - 1; ++i) {
      if (auto concat = inputs[i].getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();
    }
  }

  // extracts only of or(...) -> or(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // or(a[0], a[1], ..., a[n]) -> icmp ne(a, 0)
  // getCommonOperand only succeeds when `a` is exactly `size` bits wide.
  if (auto source = getCommonOperand(op)) {
    auto cmpAgainst =
        rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(size));
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, ICmpPredicate::ne,
                                              source, cmpAgainst);
    return success();
  }

  // or(mux(c_1, a, 0), mux(c_2, a, 0), ..., mux(c_n, a, 0)) -> mux(or(c_1, c_2,
  // .., c_n), a, 0)
  // Only applied when the or and every mux are two-state.
  if (auto firstMux = op.getOperand(0).getDefiningOp<comb::MuxOp>()) {
    APInt value;
    if (op.getTwoState() && firstMux.getTwoState() &&
        matchPattern(firstMux.getFalseValue(), m_ConstantInt(&value)) &&
        value.isZero()) {
      SmallVector<Value> conditions{firstMux.getCond()};
      // Collects each mux condition before validating that mux. If any
      // operand fails the check, `conditions` is simply discarded.
      auto check = [&](Value v) {
        auto mux = v.getDefiningOp<comb::MuxOp>();
        if (!mux)
          return false;
        conditions.push_back(mux.getCond());
        return mux.getTwoState() &&
               firstMux.getTrueValue() == mux.getTrueValue() &&
               firstMux.getFalseValue() == mux.getFalseValue();
      };
      if (llvm::all_of(op.getOperands().drop_front(), check)) {
        auto cond = rewriter.create<comb::OrOp>(op.getLoc(), conditions, true);
        replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
            rewriter, op, cond, firstMux.getTrueValue(),
            firstMux.getFalseValue(), true);
        return success();
      }
    }
  }

  /// TODO: or(..., x, not(x)) -> or(..., '1) -- complement
  return failure();
}
1189
1190OpFoldResult XorOp::fold(FoldAdaptor adaptor) {
1191 auto size = getInputs().size();
1192 auto inputs = adaptor.getInputs();
1193
1194 // xor(x) -> x -- noop
1195 if (size == 1)
1196 return getInputs()[0];
1197
1198 // xor(x, x) -> 0 -- idempotent
1199 if (size == 2 && getInputs()[0] == getInputs()[1])
1200 return IntegerAttr::get(getType(), 0);
1201
1202 // xor(x, 0) -> x
1203 if (inputs.size() == 2 && inputs[1] &&
1204 cast<IntegerAttr>(inputs[1]).getValue().isZero())
1205 return getInputs()[0];
1206
1207 // xor(xor(x,1),1) -> x
1208 // but not self loop
1209 if (isBinaryNot()) {
1210 Value subExpr;
1211 if (matchPattern(getOperand(0), m_Complement(m_Any(&subExpr))) &&
1212 subExpr != getResult())
1213 return subExpr;
1214 }
1215
1216 // Constant fold
1217 return constFoldAssociativeOp(inputs, hw::PEO::Xor);
1218}
1219
1220// xor(icmp, a, b, 1) -> xor(icmp, a, b) if icmp has one user.
1221static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand,
1222 PatternRewriter &rewriter) {
1223 auto icmp = op.getOperand(icmpOperand).getDefiningOp<ICmpOp>();
1224 auto negatedPred = ICmpOp::getNegatedPredicate(icmp.getPredicate());
1225
1226 Value result =
1227 rewriter.create<ICmpOp>(icmp.getLoc(), negatedPred, icmp.getOperand(0),
1228 icmp.getOperand(1), icmp.getTwoState());
1229
1230 // If the xor had other operands, rebuild it.
1231 if (op.getNumOperands() > 2) {
1232 SmallVector<Value, 4> newOperands(op.getOperands());
1233 newOperands.pop_back();
1234 newOperands.erase(newOperands.begin() + icmpOperand);
1235 newOperands.push_back(result);
1236 result = rewriter.create<XorOp>(op.getLoc(), newOperands, op.getTwoState());
1237 }
1238
1239 replaceOpAndCopyNamehint(rewriter, op, result);
1240}
1241
/// Canonicalization patterns for `comb.xor`.
///
/// Patterns are tried in order. The first one that applies rewrites `op` and
/// returns success.
LogicalResult XorOp::canonicalize(XorOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // xor(..., x, x) -> xor (...) -- self-inverse: x ^ x == 0, so the trailing
  // pair cancels. Only the last two operands are compared here.
  if (inputs[size - 1] == inputs[size - 2]) {
    assert(size > 2 &&
           "expected idempotent case for 2 elements handled already.");
    replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
                                             inputs.drop_back(/*n=*/2), false);
    return success();
  }

  // Patterns for xor with a constant on RHS.
  APInt value;
  if (matchPattern(inputs.back(), m_ConstantInt(&value))) {
    // xor(..., 0) -> xor(...) -- identity
    if (value.isZero()) {
      replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
                                               inputs.drop_back(), false);
      return success();
    }

    // xor(..., c1, c2) -> xor(..., c3) where c3 = c1 ^ c2.
    APInt value2;
    if (matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
      auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value ^ value2);
      SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
      newOperands.push_back(cst);
      replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getType(),
                                               newOperands, false);
      return success();
    }

    // A non-zero 1-bit constant is `1`, i.e. the xor inverts its other inputs.
    bool isSingleBit = value.getBitWidth() == 1;

    // Check for subexpressions that we can simplify.
    for (size_t i = 0; i < size - 1; ++i) {
      Value operand = inputs[i];

      // xor(concat(x, cst1), a, b, c, cst2)
      //    ==> xor(a, b, c, concat(xor(x,cst2'), xor(cst1,cst2'')).
      // We do this for even more multi-use concats since they are "just
      // wiring".
      if (auto concat = operand.getDefiningOp<ConcatOp>())
        if (canonicalizeLogicalCstWithConcat(op, i, value, rewriter))
          return success();

      // xor(icmp, a, b, 1) -> xor(icmp', a, b) where icmp' uses the negated
      // predicate, if icmp has one user.
      if (isSingleBit && operand.hasOneUse()) {
        assert(value == 1 && "single bit constant has to be one if not zero");
        if (auto icmp = operand.getDefiningOp<ICmpOp>())
          return canonicalizeXorIcmpTrue(op, i, rewriter), success();
      }
    }
  }

  // xor(x, xor(...)) -> xor(x, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of xor(...) -> xor(extract()...)
  if (narrowOperationWidth(op, true, rewriter))
    return success();

  // xor(a[0], a[1], ..., a[n]) -> parity(a)
  if (auto source = getCommonOperand(op)) {
    replaceOpWithNewOpAndCopyNamehint<ParityOp>(rewriter, op, source);
    return success();
  }

  return failure();
}
1316
1317OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
1318 // sub(x - x) -> 0
1319 if (getRhs() == getLhs())
1320 return getIntAttr(
1321 APInt::getZero(getLhs().getType().getIntOrFloatBitWidth()),
1322 getContext());
1323
1324 if (adaptor.getRhs()) {
1325 // If both are constants, we can unconditionally fold.
1326 if (adaptor.getLhs()) {
1327 // Constant fold (c1 - c2) => (c1 + -1*c2).
1328 auto negOne = getIntAttr(
1329 APInt::getAllOnes(getLhs().getType().getIntOrFloatBitWidth()),
1330 getContext());
1331 auto rhsNeg = hw::ParamExprAttr::get(
1332 hw::PEO::Mul, cast<TypedAttr>(adaptor.getRhs()), negOne);
1333 return hw::ParamExprAttr::get(hw::PEO::Add,
1334 cast<TypedAttr>(adaptor.getLhs()), rhsNeg);
1335 }
1336
1337 // sub(x - 0) -> x
1338 if (auto rhsC = dyn_cast<IntegerAttr>(adaptor.getRhs())) {
1339 if (rhsC.getValue().isZero())
1340 return getLhs();
1341 }
1342 }
1343
1344 return {};
1345}
1346
1347LogicalResult SubOp::canonicalize(SubOp op, PatternRewriter &rewriter) {
1348 // sub(x, cst) -> add(x, -cst)
1349 APInt value;
1350 if (matchPattern(op.getRhs(), m_ConstantInt(&value))) {
1351 auto negCst = rewriter.create<hw::ConstantOp>(op.getLoc(), -value);
1352 replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getLhs(), negCst,
1353 false);
1354 return success();
1355 }
1356
1357 // extracts only of sub(...) -> sub(extract()...)
1358 if (narrowOperationWidth(op, false, rewriter))
1359 return success();
1360
1361 return failure();
1362}
1363
1364OpFoldResult AddOp::fold(FoldAdaptor adaptor) {
1365 auto size = getInputs().size();
1366
1367 // add(x) -> x -- noop
1368 if (size == 1u)
1369 return getInputs()[0];
1370
1371 // Constant fold constant operands.
1372 return constFoldAssociativeOp(adaptor.getOperands(), hw::PEO::Add);
1373}
1374
/// Canonicalization patterns for `comb.add`.
///
/// Patterns are tried in order. The first one that applies rewrites `op` and
/// returns success. Most patterns only inspect the last one or two operands.
LogicalResult AddOp::canonicalize(AddOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  APInt value, value2;

  // add(..., 0) -> add(...) -- identity
  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             inputs.drop_back(), false);
    return success();
  }

  // add(..., c1, c2) -> add(..., c3) where c3 = c1 + c2 -- constant folding
  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
      matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
    auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(cst);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             newOperands, false);
    return success();
  }

  // add(..., x, x) -> add(..., shl(x, 1))
  // (x + x == 2 * x == x << 1 in modular arithmetic.)
  if (inputs[size - 1] == inputs[size - 2]) {
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));

    auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
    auto shiftLeftOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs.back(), one, false);

    newOperands.push_back(shiftLeftOp);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             newOperands, false);
    return success();
  }

  auto shlOp = inputs[size - 1].getDefiningOp<comb::ShlOp>();
  // add(..., x, shl(x, c)) -> add(..., mul(x, (1 << c) + 1))
  if (shlOp && shlOp.getLhs() == inputs[size - 2] &&
      matchPattern(shlOp.getRhs(), m_ConstantInt(&value))) {

    APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
    auto rhs =
        rewriter.create<hw::ConstantOp>(op.getLoc(), (one << value) + one);

    std::array<Value, 2> factors = {shlOp.getLhs(), rhs};
    auto mulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);

    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(mulOp);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             newOperands, false);
    return success();
  }

  auto mulOp = inputs[size - 1].getDefiningOp<comb::MulOp>();
  // add(..., x, mul(x, c)) -> add(..., mul(x, c + 1))
  // Only binary muls whose constant factor is the second operand match.
  if (mulOp && mulOp.getInputs().size() == 2 &&
      mulOp.getInputs()[0] == inputs[size - 2] &&
      matchPattern(mulOp.getInputs()[1], m_ConstantInt(&value))) {

    APInt one(/*numBits=*/value.getBitWidth(), 1, /*isSigned=*/false);
    auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + one);
    std::array<Value, 2> factors = {mulOp.getInputs()[0], rhs};
    auto newMulOp = rewriter.create<comb::MulOp>(op.getLoc(), factors, false);

    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(newMulOp);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(rewriter, op, op.getType(),
                                             newOperands, false);
    return success();
  }

  // add(a, add(...)) -> add(a, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of add(...) -> add(extract()...)
  if (narrowOperationWidth(op, false, rewriter))
    return success();

  // add(add(x, c1), c2) -> add(x, c1 + c2)
  // Reached only when flattening above did not apply. The result is two-state
  // only if both adds were.
  auto addOp = inputs[0].getDefiningOp<comb::AddOp>();
  if (addOp && addOp.getInputs().size() == 2 &&
      matchPattern(addOp.getInputs()[1], m_ConstantInt(&value2)) &&
      inputs.size() == 2 && matchPattern(inputs[1], m_ConstantInt(&value))) {

    auto rhs = rewriter.create<hw::ConstantOp>(op.getLoc(), value + value2);
    replaceOpWithNewOpAndCopyNamehint<AddOp>(
        rewriter, op, op.getType(), ArrayRef<Value>{addOp.getInputs()[0], rhs},
        /*twoState=*/op.getTwoState() && addOp.getTwoState());
    return success();
  }

  return failure();
}
1474
1475OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
1476 auto size = getInputs().size();
1477 auto inputs = adaptor.getInputs();
1478
1479 // mul(x) -> x -- noop
1480 if (size == 1u)
1481 return getInputs()[0];
1482
1483 auto width = cast<IntegerType>(getType()).getWidth();
1484 if (width == 0)
1485 return getIntAttr(APInt::getZero(0), getContext());
1486
1487 APInt value(/*numBits=*/width, 1, /*isSigned=*/false);
1488
1489 // mul(x, 0, 1) -> 0 -- annulment
1490 for (auto operand : inputs) {
1491 if (!operand)
1492 continue;
1493 value *= cast<IntegerAttr>(operand).getValue();
1494 if (value.isZero())
1495 return getIntAttr(value, getContext());
1496 }
1497
1498 // Constant fold
1499 return constFoldAssociativeOp(inputs, hw::PEO::Mul);
1500}
1501
/// Canonicalization patterns for `comb.mul`.
///
/// Patterns are tried in order. The first one that applies rewrites `op` and
/// returns success.
LogicalResult MulOp::canonicalize(MulOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  APInt value, value2;

  // mul(x, c) -> shl(x, log2(c)), where c is a power of two.
  // The shl is wrapped in a single-operand mul. MulOp::fold turns mul(x) into
  // x, so the wrapper is later reduced to the shl itself.
  if (size == 2 && matchPattern(inputs.back(), m_ConstantInt(&value)) &&
      value.isPowerOf2()) {
    auto shift = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(),
                                                 value.exactLogBase2());
    auto shlOp =
        rewriter.create<comb::ShlOp>(op.getLoc(), inputs[0], shift, false);

    replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
                                             ArrayRef<Value>(shlOp), false);
    return success();
  }

  // mul(..., 1) -> mul(...) -- identity
  if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isOne()) {
    replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
                                             inputs.drop_back());
    return success();
  }

  // mul(..., c1, c2) -> mul(..., c3) where c3 = c1 * c2 -- constant folding
  if (matchPattern(inputs[size - 1], m_ConstantInt(&value)) &&
      matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
    auto cst = rewriter.create<hw::ConstantOp>(op.getLoc(), value * value2);
    SmallVector<Value, 4> newOperands(inputs.drop_back(/*n=*/2));
    newOperands.push_back(cst);
    replaceOpWithNewOpAndCopyNamehint<MulOp>(rewriter, op, op.getType(),
                                             newOperands);
    return success();
  }

  // mul(a, mul(...)) -> mul(a, ...) -- flatten
  if (tryFlatteningOperands(op, rewriter))
    return success();

  // extracts only of mul(...) -> mul(extract()...)
  if (narrowOperationWidth(op, false, rewriter))
    return success();

  return failure();
}
1550
1551template <class Op, bool isSigned>
1552static OpFoldResult foldDiv(Op op, ArrayRef<Attribute> constants) {
1553 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1554 // divu(x, 1) -> x, divs(x, 1) -> x
1555 if (rhsValue.getValue() == 1)
1556 return op.getLhs();
1557
1558 // If the divisor is zero, do not fold for now.
1559 if (rhsValue.getValue().isZero())
1560 return {};
1561 }
1562
1563 return constFoldBinaryOp(constants, isSigned ? hw::PEO::DivS : hw::PEO::DivU);
1564}
1565
1566OpFoldResult DivUOp::fold(FoldAdaptor adaptor) {
1567 return foldDiv<DivUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1568}
1569
1570OpFoldResult DivSOp::fold(FoldAdaptor adaptor) {
1571 return foldDiv<DivSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1572}
1573
1574template <class Op, bool isSigned>
1575static OpFoldResult foldMod(Op op, ArrayRef<Attribute> constants) {
1576 if (auto rhsValue = dyn_cast_or_null<IntegerAttr>(constants[1])) {
1577 // modu(x, 1) -> 0, mods(x, 1) -> 0
1578 if (rhsValue.getValue() == 1)
1579 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1580 op.getContext());
1581
1582 // If the divisor is zero, do not fold for now.
1583 if (rhsValue.getValue().isZero())
1584 return {};
1585 }
1586
1587 if (auto lhsValue = dyn_cast_or_null<IntegerAttr>(constants[0])) {
1588 // modu(0, x) -> 0, mods(0, x) -> 0
1589 if (lhsValue.getValue().isZero())
1590 return getIntAttr(APInt::getZero(op.getType().getIntOrFloatBitWidth()),
1591 op.getContext());
1592 }
1593
1594 return constFoldBinaryOp(constants, isSigned ? hw::PEO::ModS : hw::PEO::ModU);
1595}
1596
1597OpFoldResult ModUOp::fold(FoldAdaptor adaptor) {
1598 return foldMod<ModUOp, /*isSigned=*/false>(*this, adaptor.getOperands());
1599}
1600
1601OpFoldResult ModSOp::fold(FoldAdaptor adaptor) {
1602 return foldMod<ModSOp, /*isSigned=*/true>(*this, adaptor.getOperands());
1603}
1604//===----------------------------------------------------------------------===//
1605// ConcatOp
1606//===----------------------------------------------------------------------===//
1607
1608// Constant folding
1609OpFoldResult ConcatOp::fold(FoldAdaptor adaptor) {
1610 if (getNumOperands() == 1)
1611 return getOperand(0);
1612
1613 // If all the operands are constant, we can fold.
1614 for (auto attr : adaptor.getInputs())
1615 if (!attr || !isa<IntegerAttr>(attr))
1616 return {};
1617
1618 // If we got here, we can constant fold.
1619 unsigned resultWidth = getType().getIntOrFloatBitWidth();
1620 APInt result(resultWidth, 0);
1621
1622 unsigned nextInsertion = resultWidth;
1623 // Insert each chunk into the result.
1624 for (auto attr : adaptor.getInputs()) {
1625 auto chunk = cast<IntegerAttr>(attr).getValue();
1626 nextInsertion -= chunk.getBitWidth();
1627 result.insertBits(chunk, nextInsertion);
1628 }
1629
1630 return getIntAttr(result, getContext());
1631}
1632
/// Canonicalization patterns for `comb.concat`.
///
/// Operands are scanned left to right. The first applicable local rewrite
/// rebuilds the concat through `flattenConcat` and returns. If no local rewrite
/// applies and every operand is the same value, the concat becomes a
/// replicate.
LogicalResult ConcatOp::canonicalize(ConcatOp op, PatternRewriter &rewriter) {
  auto inputs = op.getInputs();
  auto size = inputs.size();
  assert(size > 1 && "expected 2 or more operands");

  // This function is used when we flatten neighboring operands of a
  // (variadic) concat into a new version of the concat. first/last indices
  // are inclusive.
  auto flattenConcat = [&](size_t firstOpIndex, size_t lastOpIndex,
                           ValueRange replacements) -> LogicalResult {
    SmallVector<Value, 4> newOperands;
    newOperands.append(inputs.begin(), inputs.begin() + firstOpIndex);
    newOperands.append(replacements.begin(), replacements.end());
    newOperands.append(inputs.begin() + lastOpIndex + 1, inputs.end());
    // A single remaining operand replaces the concat directly.
    if (newOperands.size() == 1)
      replaceOpAndCopyNamehint(rewriter, op, newOperands[0]);
    else
      replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
                                                  newOperands);
    return success();
  };

  Value commonOperand = inputs[0];
  for (size_t i = 0; i != size; ++i) {
    // Check to see if all operands are the same.
    if (inputs[i] != commonOperand)
      commonOperand = Value();

    // If an operand to the concat is itself a concat, then we can fold them
    // together.
    if (auto subConcat = inputs[i].getDefiningOp<ConcatOp>())
      return flattenConcat(i, i, subConcat->getOperands());

    // Check for canonicalization due to neighboring operands.
    if (i != 0) {
      // Merge neighboring constants. The earlier operand goes into the more
      // significant bits, so prevCst is shifted above cst.
      if (auto cst = inputs[i].getDefiningOp<hw::ConstantOp>()) {
        if (auto prevCst = inputs[i - 1].getDefiningOp<hw::ConstantOp>()) {
          unsigned prevWidth = prevCst.getValue().getBitWidth();
          unsigned thisWidth = cst.getValue().getBitWidth();
          auto resultCst = cst.getValue().zext(prevWidth + thisWidth);
          resultCst |= prevCst.getValue().zext(prevWidth + thisWidth)
                       << thisWidth;
          Value replacement =
              rewriter.create<hw::ConstantOp>(op.getLoc(), resultCst);
          return flattenConcat(i - 1, i, replacement);
        }
      }

      // If the two operands are the same, turn them into a replicate.
      if (inputs[i] == inputs[i - 1]) {
        Value replacement =
            rewriter.createOrFold<ReplicateOp>(op.getLoc(), inputs[i], 2);
        return flattenConcat(i - 1, i, replacement);
      }

      // If this input is a replicate, see if we can fold it with the previous
      // one.
      if (auto repl = inputs[i].getDefiningOp<ReplicateOp>()) {
        // ... x, repl(x, n), ...  ==> ..., repl(x, n+1), ...
        if (repl.getOperand() == inputs[i - 1]) {
          Value replacement = rewriter.createOrFold<ReplicateOp>(
              op.getLoc(), repl.getOperand(), repl.getMultiple() + 1);
          return flattenConcat(i - 1, i, replacement);
        }
        // ... repl(x, n), repl(x, m), ...  ==> ..., repl(x, n+m), ...
        if (auto prevRepl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
          if (prevRepl.getOperand() == repl.getOperand()) {
            Value replacement = rewriter.createOrFold<ReplicateOp>(
                op.getLoc(), repl.getOperand(),
                repl.getMultiple() + prevRepl.getMultiple());
            return flattenConcat(i - 1, i, replacement);
          }
        }
      }

      // ... repl(x, n), x, ...  ==> ..., repl(x, n+1), ...
      if (auto repl = inputs[i - 1].getDefiningOp<ReplicateOp>()) {
        if (repl.getOperand() == inputs[i]) {
          Value replacement = rewriter.createOrFold<ReplicateOp>(
              op.getLoc(), inputs[i], repl.getMultiple() + 1);
          return flattenConcat(i - 1, i, replacement);
        }
      }

      // Merge neighboring extracts of neighboring inputs, e.g.
      // {A[3], A[2]} -> A[3:2]
      // Requires the previous (more significant) extract to start right
      // above this one.
      if (auto extract = inputs[i].getDefiningOp<ExtractOp>()) {
        if (auto prevExtract = inputs[i - 1].getDefiningOp<ExtractOp>()) {
          if (extract.getInput() == prevExtract.getInput()) {
            auto thisWidth = cast<IntegerType>(extract.getType()).getWidth();
            if (prevExtract.getLowBit() == extract.getLowBit() + thisWidth) {
              auto prevWidth = prevExtract.getType().getIntOrFloatBitWidth();
              auto resType = rewriter.getIntegerType(thisWidth + prevWidth);
              Value replacement = rewriter.create<ExtractOp>(
                  op.getLoc(), resType, extract.getInput(),
                  extract.getLowBit());
              return flattenConcat(i - 1, i, replacement);
            }
          }
        }
      }
      // Merge neighboring array extracts of neighboring inputs, e.g.
      // {Array[4], bitcast(Array[3:2])} -> bitcast(A[4:2])

      // This represents a slice of an array.
      // `width` is the number of array elements covered: 1 for an array_get,
      // the slice length for a bitcast array_slice.
      struct ArraySlice {
        Value input;
        Value index;
        size_t width;
        static std::optional<ArraySlice> get(Value value) {
          assert(isa<IntegerType>(value.getType()) && "expected integer type");
          if (auto arrayGet = value.getDefiningOp<hw::ArrayGetOp>())
            return ArraySlice{arrayGet.getInput(), arrayGet.getIndex(), 1};
          // array slice op is wrapped with bitcast.
          if (auto bitcast = value.getDefiningOp<hw::BitcastOp>())
            if (auto arraySlice =
                    bitcast.getInput().getDefiningOp<hw::ArraySliceOp>())
              return ArraySlice{
                  arraySlice.getInput(), arraySlice.getLowIndex(),
                  hw::type_cast<hw::ArrayType>(arraySlice.getType())
                      .getNumElements()};
          return std::nullopt;
        }
      };
      if (auto extractOpt = ArraySlice::get(inputs[i])) {
        if (auto prevExtractOpt = ArraySlice::get(inputs[i - 1])) {
          // Check that two array slices are mergable.
          // NOTE(review): hw::isOffset presumably checks that the previous
          // slice's index equals this slice's index plus its width.
          if (prevExtractOpt->index.getType() == extractOpt->index.getType() &&
              prevExtractOpt->input == extractOpt->input &&
              hw::isOffset(extractOpt->index, prevExtractOpt->index,
                           extractOpt->width)) {
            auto resType = hw::ArrayType::get(
                hw::type_cast<hw::ArrayType>(prevExtractOpt->input.getType())
                    .getElementType(),
                extractOpt->width + prevExtractOpt->width);
            auto resIntType = rewriter.getIntegerType(hw::getBitWidth(resType));
            Value replacement = rewriter.create<hw::BitcastOp>(
                op.getLoc(), resIntType,
                rewriter.create<hw::ArraySliceOp>(op.getLoc(), resType,
                                                  prevExtractOpt->input,
                                                  extractOpt->index));
            return flattenConcat(i - 1, i, replacement);
          }
        }
      }
    }
  }

  // If all operands were the same, then this is a replicate.
  if (commonOperand) {
    replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(rewriter, op, op.getType(),
                                                   commonOperand);
    return success();
  }

  return failure();
}
1791
1792//===----------------------------------------------------------------------===//
1793// MuxOp
1794//===----------------------------------------------------------------------===//
1795
1796OpFoldResult MuxOp::fold(FoldAdaptor adaptor) {
1797 // mux (c, b, b) -> b
1798 if (getTrueValue() == getFalseValue() && getTrueValue() != getResult())
1799 return getTrueValue();
1800 if (auto tv = adaptor.getTrueValue())
1801 if (tv == adaptor.getFalseValue())
1802 return tv;
1803
1804 // mux(0, a, b) -> b
1805 // mux(1, a, b) -> a
1806 if (auto pred = dyn_cast_or_null<IntegerAttr>(adaptor.getCond())) {
1807 if (pred.getValue().isZero())
1808 return getFalseValue();
1809 return getTrueValue();
1810 }
1811
1812 // mux(cond, 1, 0) -> cond
1813 if (auto tv = dyn_cast_or_null<IntegerAttr>(adaptor.getTrueValue()))
1814 if (auto fv = dyn_cast_or_null<IntegerAttr>(adaptor.getFalseValue()))
1815 if (tv.getValue().isOne() && fv.getValue().isZero() &&
1816 hw::getBitWidth(getType()) == 1)
1817 return getCond();
1818
1819 return {};
1820}
1821
1822/// Check to see if the condition to the specified mux is an equality
1823/// comparison `indexValue` and one or more constants. If so, put the
1824/// constants in the constants vector and return true, otherwise return false.
1825///
1826/// This is part of foldMuxChain.
1827///
1828static bool
1829getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted,
1830 std::function<void(hw::ConstantOp)> constantFn) {
1831 // Handle `idx == 42` and `idx != 42`.
1832 if (auto cmp = cond.getDefiningOp<ICmpOp>()) {
1833 // TODO: We could handle things like "x < 2" as two entries.
1834 auto requiredPredicate =
1835 (isInverted ? ICmpPredicate::eq : ICmpPredicate::ne);
1836 if (cmp.getLhs() == indexValue && cmp.getPredicate() == requiredPredicate) {
1837 if (auto cst = cmp.getRhs().getDefiningOp<hw::ConstantOp>()) {
1838 constantFn(cst);
1839 return true;
1840 }
1841 }
1842 return false;
1843 }
1844
1845 // Handle mux(`idx == 1 || idx == 3`, value, muxchain).
1846 if (auto orOp = cond.getDefiningOp<OrOp>()) {
1847 if (!isInverted)
1848 return false;
1849 for (auto operand : orOp.getOperands())
1850 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1851 return false;
1852 return true;
1853 }
1854
1855 // Handle mux(`idx != 1 && idx != 3`, muxchain, value).
1856 if (auto andOp = cond.getDefiningOp<AndOp>()) {
1857 if (isInverted)
1858 return false;
1859 for (auto operand : andOp.getOperands())
1860 if (!getMuxChainCondConstant(operand, indexValue, isInverted, constantFn))
1861 return false;
1862 return true;
1863 }
1864
1865 return false;
1866}
1867
1868/// Given a mux, check to see if the "on true" value (or "on false" value if
1869/// isFalseSide=true) is a mux tree with the same condition. This allows us
1870/// to turn things like `mux(VAL == 0, A, (mux (VAL == 1), B, C))` into
1871/// `array_get (array_create(A, B, C), VAL)` which is far more compact and
1872/// allows synthesis tools to do more interesting optimizations.
1873///
1874/// This returns false if we cannot form the mux tree (or do not want to) and
1875/// returns true if the mux was replaced.
1876static bool foldMuxChain(MuxOp rootMux, bool isFalseSide,
1877 PatternRewriter &rewriter) {
1878 // Get the index value being compared. Later we check to see if it is
1879 // compared to a constant with the right predicate.
1880 auto rootCmp = rootMux.getCond().getDefiningOp<ICmpOp>();
1881 if (!rootCmp)
1882 return false;
1883 Value indexValue = rootCmp.getLhs();
1884
1885 // Return the value to use if the equality match succeeds.
1886 auto getCaseValue = [&](MuxOp mux) -> Value {
1887 return mux.getOperand(1 + unsigned(!isFalseSide));
1888 };
1889
1890 // Return the value to use if the equality match fails. This is the next
1891 // mux in the sequence or the "otherwise" value.
1892 auto getTreeValue = [&](MuxOp mux) -> Value {
1893 return mux.getOperand(1 + unsigned(isFalseSide));
1894 };
1895
1896 // Start scanning the mux tree to see what we've got. Keep track of the
1897 // constant comparison value and the SSA value to use when equal to it.
1898 SmallVector<Location> locationsFound;
1899 SmallVector<std::pair<hw::ConstantOp, Value>, 4> valuesFound;
1900
1901 /// Extract constants and values into `valuesFound` and return true if this is
1902 /// part of the mux tree, otherwise return false.
1903 auto collectConstantValues = [&](MuxOp mux) -> bool {
1905 mux.getCond(), indexValue, isFalseSide, [&](hw::ConstantOp cst) {
1906 valuesFound.push_back({cst, getCaseValue(mux)});
1907 locationsFound.push_back(mux.getCond().getLoc());
1908 locationsFound.push_back(mux->getLoc());
1909 });
1910 };
1911
1912 // Make sure the root is a correct comparison with a constant.
1913 if (!collectConstantValues(rootMux))
1914 return false;
1915
1916 // Make sure that we're not looking at the intermediate node in a mux tree.
1917 if (rootMux->hasOneUse()) {
1918 if (auto userMux = dyn_cast<MuxOp>(*rootMux->user_begin())) {
1919 if (getTreeValue(userMux) == rootMux.getResult() &&
1920 getMuxChainCondConstant(userMux.getCond(), indexValue, isFalseSide,
1921 [&](hw::ConstantOp cst) {}))
1922 return false;
1923 }
1924 }
1925
1926 // Scan up the tree linearly.
1927 auto nextTreeValue = getTreeValue(rootMux);
1928 while (1) {
1929 auto nextMux = nextTreeValue.getDefiningOp<MuxOp>();
1930 if (!nextMux || !nextMux->hasOneUse())
1931 break;
1932 if (!collectConstantValues(nextMux))
1933 break;
1934 nextTreeValue = getTreeValue(nextMux);
1935 }
1936
1937 // We need to have more than three values to create an array. This is an
1938 // arbitrary threshold which is saying that one or two muxes together is ok,
1939 // but three should be folded.
1940 if (valuesFound.size() < 3)
1941 return false;
1942
1943 // If the array is greater that 9 bits, it will take over 512 elements and
1944 // it will be too large for a single expression.
1945 auto indexWidth = cast<IntegerType>(indexValue.getType()).getWidth();
1946 if (indexWidth >= 9)
1947 return false;
1948
1949 // Next we need to see if the values are dense-ish. We don't want to have
1950 // a tremendous number of replicated entries in the array. Some sparsity is
1951 // ok though, so we require the table to be at least 5/8 utilized.
1952 uint64_t tableSize = 1ULL << indexWidth;
1953 if (valuesFound.size() < (tableSize * 5) / 8)
1954 return false; // Not dense enough.
1955
1956 // Ok, we're going to do the transformation, start by building the table
1957 // filled with the "otherwise" value.
1958 SmallVector<Value, 8> table(tableSize, nextTreeValue);
1959
1960 // Fill in entries in the table from the leaf to the root of the expression.
1961 // This ensures that any duplicate matches end up with the ultimate value,
1962 // which is the one closer to the root.
1963 for (auto &elt : llvm::reverse(valuesFound)) {
1964 uint64_t idx = elt.first.getValue().getZExtValue();
1965 assert(idx < table.size() && "constant should be same bitwidth as index");
1966 table[idx] = elt.second;
1967 }
1968
1969 // The hw.array_create operation has the operand list in unintuitive order
1970 // with a[0] stored as the last element, not the first.
1971 std::reverse(table.begin(), table.end());
1972
1973 // Build the array_create and the array_get.
1974 auto fusedLoc = rewriter.getFusedLoc(locationsFound);
1975 auto array = rewriter.create<hw::ArrayCreateOp>(fusedLoc, table);
1976 replaceOpWithNewOpAndCopyNamehint<hw::ArrayGetOp>(rewriter, rootMux, array,
1977 indexValue);
1978 return true;
1979}
1980
1981/// Given a fully associative variadic operation like (a+b+c+d), break the
1982/// expression into two parts, one without the specified operand (e.g.
1983/// `tmp = a+b+d`) and one that combines that into the full expression (e.g.
1984/// `tmp+c`), and return the inner expression.
1985///
1986/// NOTE: This mutates the operation in place if it only has a single user,
1987/// which assumes that user will be removed.
1988///
1989static Value extractOperandFromFullyAssociative(Operation *fullyAssoc,
1990 size_t operandNo,
1991 PatternRewriter &rewriter) {
1992 assert(fullyAssoc->getNumOperands() >= 2 && "cannot split up unary ops");
1993 assert(operandNo < fullyAssoc->getNumOperands() && "Invalid operand #");
1994
1995 // If this expression already has two operands (the common case) no splitting
1996 // is necessary.
1997 if (fullyAssoc->getNumOperands() == 2)
1998 return fullyAssoc->getOperand(operandNo ^ 1);
1999
2000 // If the operation has a single use, mutate it in place.
2001 if (fullyAssoc->hasOneUse()) {
2002 rewriter.modifyOpInPlace(fullyAssoc,
2003 [&]() { fullyAssoc->eraseOperand(operandNo); });
2004 return fullyAssoc->getResult(0);
2005 }
2006
2007 // Form the new operation with the operands that remain.
2008 SmallVector<Value> operands;
2009 operands.append(fullyAssoc->getOperands().begin(),
2010 fullyAssoc->getOperands().begin() + operandNo);
2011 operands.append(fullyAssoc->getOperands().begin() + operandNo + 1,
2012 fullyAssoc->getOperands().end());
2013 Value opWithoutExcluded = createGenericOp(
2014 fullyAssoc->getLoc(), fullyAssoc->getName(), operands, rewriter);
2015 Value excluded = fullyAssoc->getOperand(operandNo);
2016
2017 Value fullResult =
2018 createGenericOp(fullyAssoc->getLoc(), fullyAssoc->getName(),
2019 ArrayRef<Value>{opWithoutExcluded, excluded}, rewriter);
2020 replaceOpAndCopyNamehint(rewriter, fullyAssoc, fullResult);
2021 return opWithoutExcluded;
2022}
2023
2024/// Fold things like `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond)|a` and
2025/// `mux(cond, a, x|y|z|a) -> `(x|y|z)&replicate(~cond) | a` (when isTrueOperand
2026/// is true. Return true on successful transformation, false if not.
2027///
2028/// These are various forms of "predicated ops" that can be handled with a
2029/// replicate/and combination.
2030static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand,
2031 PatternRewriter &rewriter) {
2032 // Check to see the operand in question is an operation. If it is a port,
2033 // we can't simplify it.
2034 Operation *subExpr =
2035 (isTrueOperand ? op.getFalseValue() : op.getTrueValue()).getDefiningOp();
2036 if (!subExpr || subExpr->getNumOperands() < 2)
2037 return false;
2038
2039 // If this isn't an operation we can handle, don't spend energy on it.
2040 if (!isa<AndOp, XorOp, OrOp, MuxOp>(subExpr))
2041 return false;
2042
2043 // Check to see if the common value occurs in the operand list for the
2044 // subexpression op. If so, then we can simplify it.
2045 Value commonValue = isTrueOperand ? op.getTrueValue() : op.getFalseValue();
2046 size_t opNo = 0, e = subExpr->getNumOperands();
2047 while (opNo != e && subExpr->getOperand(opNo) != commonValue)
2048 ++opNo;
2049 if (opNo == e)
2050 return false;
2051
2052 // If we got a hit, then go ahead and simplify it!
2053 Value cond = op.getCond();
2054
2055 // `mux(cond, a, mux(cond2, a, b))` -> `mux(cond|cond2, a, b)`
2056 // `mux(cond, a, mux(cond2, b, a))` -> `mux(cond|~cond2, a, b)`
2057 // `mux(cond, mux(cond2, a, b), a)` -> `mux(~cond|cond2, a, b)`
2058 // `mux(cond, mux(cond2, b, a), a)` -> `mux(~cond|~cond2, a, b)`
2059 if (auto subMux = dyn_cast<MuxOp>(subExpr)) {
2060 if (subMux == op)
2061 return false;
2062
2063 Value otherValue;
2064 Value subCond = subMux.getCond();
2065
2066 // Invert th subCond if needed and dig out the 'b' value.
2067 if (subMux.getTrueValue() == commonValue)
2068 otherValue = subMux.getFalseValue();
2069 else if (subMux.getFalseValue() == commonValue) {
2070 otherValue = subMux.getTrueValue();
2071 subCond = createOrFoldNot(op.getLoc(), subCond, rewriter);
2072 } else {
2073 // We can't fold `mux(cond, a, mux(a, x, y))`.
2074 return false;
2075 }
2076
2077 // Invert the outer cond if needed, and combine the mux conditions.
2078 if (!isTrueOperand)
2079 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2080 cond = rewriter.createOrFold<OrOp>(op.getLoc(), cond, subCond, false);
2081 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, cond, commonValue,
2082 otherValue, op.getTwoState());
2083 return true;
2084 }
2085
2086 // Invert the condition if needed. Or/Xor invert when dealing with
2087 // TrueOperand, And inverts for False operand.
2088 bool isaAndOp = isa<AndOp>(subExpr);
2089 if (isTrueOperand ^ isaAndOp)
2090 cond = createOrFoldNot(op.getLoc(), cond, rewriter);
2091
2092 auto extendedCond =
2093 rewriter.createOrFold<ReplicateOp>(op.getLoc(), op.getType(), cond);
2094
2095 // Cache this information before subExpr is erased by extraction below.
2096 bool isaXorOp = isa<XorOp>(subExpr);
2097 bool isaOrOp = isa<OrOp>(subExpr);
2098
2099 // Handle the fully associative ops, start by pulling out the subexpression
2100 // from a many operand version of the op.
2101 auto restOfAssoc =
2102 extractOperandFromFullyAssociative(subExpr, opNo, rewriter);
2103
2104 // `mux(cond, x|y|z|a, a)` -> `(x|y|z)&replicate(cond) | a`
2105 // `mux(cond, x^y^z^a, a)` -> `(x^y^z)&replicate(cond) ^ a`
2106 if (isaOrOp || isaXorOp) {
2107 auto masked = rewriter.createOrFold<AndOp>(op.getLoc(), extendedCond,
2108 restOfAssoc, false);
2109 if (isaXorOp)
2110 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, masked,
2111 commonValue, false);
2112 else
2113 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, masked, commonValue,
2114 false);
2115 return true;
2116 }
2117
2118 // `mux(cond, a, x&y&z&a)` -> `((x&y&z)|replicate(cond)) & a`
2119 assert(isaAndOp && "unexpected operation here");
2120 auto masked = rewriter.createOrFold<OrOp>(op.getLoc(), extendedCond,
2121 restOfAssoc, false);
2122 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, masked, commonValue,
2123 false);
2124 return true;
2125}
2126
2127/// This function is invoke when we find a mux with true/false operations that
2128/// have the same opcode. Check to see if we can strength reduce the mux by
2129/// applying it to less data by applying this transformation:
2130/// `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2131static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp,
2132 Operation *falseOp,
2133 PatternRewriter &rewriter) {
2134 // Right now we only apply to concat.
2135 // TODO: Generalize this to and, or, xor, icmp(!), which all occur in practice
2136 if (!isa<ConcatOp>(trueOp))
2137 return false;
2138
2139 // Decode the operands, looking through recursive concats and replicates.
2140 SmallVector<Value> trueOperands, falseOperands;
2141 getConcatOperands(trueOp->getResult(0), trueOperands);
2142 getConcatOperands(falseOp->getResult(0), falseOperands);
2143
2144 size_t numTrueOperands = trueOperands.size();
2145 size_t numFalseOperands = falseOperands.size();
2146
2147 if (!numTrueOperands || !numFalseOperands ||
2148 (trueOperands.front() != falseOperands.front() &&
2149 trueOperands.back() != falseOperands.back()))
2150 return false;
2151
2152 // Pull all leading shared operands out into their own op if any are common.
2153 if (trueOperands.front() == falseOperands.front()) {
2154 SmallVector<Value> operands;
2155 size_t i;
2156 for (i = 0; i < numTrueOperands; ++i) {
2157 Value trueOperand = trueOperands[i];
2158 if (trueOperand == falseOperands[i])
2159 operands.push_back(trueOperand);
2160 else
2161 break;
2162 }
2163 if (i == numTrueOperands) {
2164 // Selecting between distinct, but lexically identical, concats.
2165 replaceOpAndCopyNamehint(rewriter, mux, trueOp->getResult(0));
2166 return true;
2167 }
2168
2169 Value sharedMSB;
2170 if (llvm::all_of(operands, [&](Value v) { return v == operands.front(); }))
2171 sharedMSB = rewriter.createOrFold<ReplicateOp>(
2172 mux->getLoc(), operands.front(), operands.size());
2173 else
2174 sharedMSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2175 operands.clear();
2176
2177 // Get a concat of the LSB's on each side.
2178 operands.append(trueOperands.begin() + i, trueOperands.end());
2179 Value trueLSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2180 operands.clear();
2181 operands.append(falseOperands.begin() + i, falseOperands.end());
2182 Value falseLSB =
2183 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2184 // Merge the LSBs with a new mux and concat the MSB with the LSB to be
2185 // done.
2186 Value lsb = rewriter.createOrFold<MuxOp>(
2187 mux->getLoc(), mux.getCond(), trueLSB, falseLSB, mux.getTwoState());
2188 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, sharedMSB, lsb);
2189 return true;
2190 }
2191
2192 // If trailing operands match, try to commonize them.
2193 if (trueOperands.back() == falseOperands.back()) {
2194 SmallVector<Value> operands;
2195 size_t i;
2196 for (i = 0;; ++i) {
2197 Value trueOperand = trueOperands[numTrueOperands - i - 1];
2198 if (trueOperand == falseOperands[numFalseOperands - i - 1])
2199 operands.push_back(trueOperand);
2200 else
2201 break;
2202 }
2203 std::reverse(operands.begin(), operands.end());
2204 Value sharedLSB = rewriter.createOrFold<ConcatOp>(mux->getLoc(), operands);
2205 operands.clear();
2206
2207 // Get a concat of the MSB's on each side.
2208 operands.append(trueOperands.begin(), trueOperands.end() - i);
2209 Value trueMSB = rewriter.createOrFold<ConcatOp>(trueOp->getLoc(), operands);
2210 operands.clear();
2211 operands.append(falseOperands.begin(), falseOperands.end() - i);
2212 Value falseMSB =
2213 rewriter.createOrFold<ConcatOp>(falseOp->getLoc(), operands);
2214 // Merge the MSBs with a new mux and concat the MSB with the LSB to be done.
2215 Value msb = rewriter.createOrFold<MuxOp>(
2216 mux->getLoc(), mux.getCond(), trueMSB, falseMSB, mux.getTwoState());
2217 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, mux, msb, sharedLSB);
2218 return true;
2219 }
2220
2221 return false;
2222}
2223
2224// If both arguments of the mux are arrays with the same elements, sink the
2225// mux and return a uniform array initializing all elements to it.
2226static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter) {
2227 auto trueVec = op.getTrueValue().getDefiningOp<hw::ArrayCreateOp>();
2228 auto falseVec = op.getFalseValue().getDefiningOp<hw::ArrayCreateOp>();
2229 if (!trueVec || !falseVec)
2230 return false;
2231 if (!trueVec.isUniform() || !falseVec.isUniform())
2232 return false;
2233
2234 auto mux = rewriter.create<MuxOp>(
2235 op.getLoc(), op.getCond(), trueVec.getUniformElement(),
2236 falseVec.getUniformElement(), op.getTwoState());
2237
2238 SmallVector<Value> values(trueVec.getInputs().size(), mux);
2239 rewriter.replaceOpWithNewOp<hw::ArrayCreateOp>(op, values);
2240 return true;
2241}
2242
2243namespace {
2244struct MuxRewriter : public mlir::OpRewritePattern<MuxOp> {
2245 using OpRewritePattern::OpRewritePattern;
2246
2247 LogicalResult matchAndRewrite(MuxOp op,
2248 PatternRewriter &rewriter) const override;
2249};
2250
2251LogicalResult MuxRewriter::matchAndRewrite(MuxOp op,
2252 PatternRewriter &rewriter) const {
2253 // If the op has a SV attribute, don't optimize it.
2254 if (hasSVAttributes(op))
2255 return failure();
2256 APInt value;
2257
2258 if (matchPattern(op.getTrueValue(), m_ConstantInt(&value))) {
2259 if (value.getBitWidth() == 1) {
2260 // mux(a, 0, b) -> and(~a, b) for single-bit values.
2261 if (value.isZero()) {
2262 auto notCond = createOrFoldNot(op.getLoc(), op.getCond(), rewriter);
2263 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, notCond,
2264 op.getFalseValue(), false);
2265 return success();
2266 }
2267
2268 // mux(a, 1, b) -> or(a, b) for single-bit values.
2269 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, op.getCond(),
2270 op.getFalseValue(), false);
2271 return success();
2272 }
2273
2274 // Check for mux of two constants. There are many ways to simplify them.
2275 APInt value2;
2276 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value2))) {
2277 // When both inputs are constants and differ by only one bit, we can
2278 // simplify by splitting the mux into up to three contiguous chunks: one
2279 // for the differing bit and up to two for the bits that are the same.
2280 // E.g. mux(a, 3'h2, 0) -> concat(0, mux(a, 1, 0), 0) -> concat(0, a, 0)
2281 APInt xorValue = value ^ value2;
2282 if (xorValue.isPowerOf2()) {
2283 unsigned leadingZeros = xorValue.countLeadingZeros();
2284 unsigned trailingZeros = value.getBitWidth() - leadingZeros - 1;
2285 SmallVector<Value, 3> operands;
2286
2287 // Concat operands go from MSB to LSB, so we handle chunks in reverse
2288 // order of bit indexes.
2289 // For the chunks that are identical (i.e. correspond to 0s in
2290 // xorValue), we can extract directly from either input value, and we
2291 // arbitrarily pick the trueValue().
2292
2293 if (leadingZeros > 0)
2294 operands.push_back(rewriter.createOrFold<ExtractOp>(
2295 op.getLoc(), op.getTrueValue(), trailingZeros + 1, leadingZeros));
2296
2297 // Handle the differing bit, which should simplify into either cond or
2298 // ~cond.
2299 auto v1 = rewriter.createOrFold<ExtractOp>(
2300 op.getLoc(), op.getTrueValue(), trailingZeros, 1);
2301 auto v2 = rewriter.createOrFold<ExtractOp>(
2302 op.getLoc(), op.getFalseValue(), trailingZeros, 1);
2303 operands.push_back(rewriter.createOrFold<MuxOp>(
2304 op.getLoc(), op.getCond(), v1, v2, false));
2305
2306 if (trailingZeros > 0)
2307 operands.push_back(rewriter.createOrFold<ExtractOp>(
2308 op.getLoc(), op.getTrueValue(), 0, trailingZeros));
2309
2310 replaceOpWithNewOpAndCopyNamehint<ConcatOp>(rewriter, op, op.getType(),
2311 operands);
2312 return success();
2313 }
2314
2315 // If the true value is all ones and the false is all zeros then we have a
2316 // replicate pattern.
2317 if (value.isAllOnes() && value2.isZero()) {
2318 replaceOpWithNewOpAndCopyNamehint<ReplicateOp>(
2319 rewriter, op, op.getType(), op.getCond());
2320 return success();
2321 }
2322 }
2323 }
2324
2325 if (matchPattern(op.getFalseValue(), m_ConstantInt(&value)) &&
2326 value.getBitWidth() == 1) {
2327 // mux(a, b, 0) -> and(a, b) for single-bit values.
2328 if (value.isZero()) {
2329 replaceOpWithNewOpAndCopyNamehint<AndOp>(rewriter, op, op.getCond(),
2330 op.getTrueValue(), false);
2331 return success();
2332 }
2333
2334 // mux(a, b, 1) -> or(~a, b) for single-bit values.
2335 // falseValue() is known to be a single-bit 1, which we can use for
2336 // the 1 in the representation of ~ using xor.
2337 auto notCond = rewriter.createOrFold<XorOp>(op.getLoc(), op.getCond(),
2338 op.getFalseValue(), false);
2339 replaceOpWithNewOpAndCopyNamehint<OrOp>(rewriter, op, notCond,
2340 op.getTrueValue(), false);
2341 return success();
2342 }
2343
2344 // mux(!a, b, c) -> mux(a, c, b)
2345 Value subExpr;
2346 Operation *condOp = op.getCond().getDefiningOp();
2347 if (condOp && matchPattern(condOp, m_Complement(m_Any(&subExpr))) &&
2348 op.getTwoState()) {
2349 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, op.getType(),
2350 subExpr, op.getFalseValue(),
2351 op.getTrueValue(), true);
2352 return success();
2353 }
2354
2355 // Same but with Demorgan's law.
2356 // mux(and(~a, ~b, ~c), x, y) -> mux(or(a, b, c), y, x)
2357 // mux(or(~a, ~b, ~c), x, y) -> mux(and(a, b, c), y, x)
2358 if (condOp && condOp->hasOneUse()) {
2359 SmallVector<Value> invertedOperands;
2360
2361 /// Scan all the operands to see if they are complemented. If so, build a
2362 /// vector of them and return true, otherwise return false.
2363 auto getInvertedOperands = [&]() -> bool {
2364 for (Value operand : condOp->getOperands()) {
2365 if (matchPattern(operand, m_Complement(m_Any(&subExpr))))
2366 invertedOperands.push_back(subExpr);
2367 else
2368 return false;
2369 }
2370 return true;
2371 };
2372
2373 if (isa<AndOp>(condOp) && getInvertedOperands()) {
2374 auto newOr =
2375 rewriter.createOrFold<OrOp>(op.getLoc(), invertedOperands, false);
2376 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2377 rewriter, op, newOr, op.getFalseValue(), op.getTrueValue(),
2378 op.getTwoState());
2379 return success();
2380 }
2381 if (isa<OrOp>(condOp) && getInvertedOperands()) {
2382 auto newAnd =
2383 rewriter.createOrFold<AndOp>(op.getLoc(), invertedOperands, false);
2384 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2385 rewriter, op, newAnd, op.getFalseValue(), op.getTrueValue(),
2386 op.getTwoState());
2387 return success();
2388 }
2389 }
2390
2391 if (auto falseMux = op.getFalseValue().getDefiningOp<MuxOp>();
2392 falseMux && falseMux != op) {
2393 // mux(selector, x, mux(selector, y, z) = mux(selector, x, z)
2394 if (op.getCond() == falseMux.getCond()) {
2395 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2396 rewriter, op, op.getCond(), op.getTrueValue(),
2397 falseMux.getFalseValue(), op.getTwoStateAttr());
2398 return success();
2399 }
2400
2401 // Check to see if we can fold a mux tree into an array_create/get pair.
2402 if (foldMuxChain(op, /*isFalse*/ true, rewriter))
2403 return success();
2404 }
2405
2406 if (auto trueMux = op.getTrueValue().getDefiningOp<MuxOp>();
2407 trueMux && trueMux != op) {
2408 // mux(selector, mux(selector, a, b), c) = mux(selector, a, c)
2409 if (op.getCond() == trueMux.getCond()) {
2410 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2411 rewriter, op, op.getCond(), trueMux.getTrueValue(),
2412 op.getFalseValue(), op.getTwoStateAttr());
2413 return success();
2414 }
2415
2416 // Check to see if we can fold a mux tree into an array_create/get pair.
2417 if (foldMuxChain(op, /*isFalseSide*/ false, rewriter))
2418 return success();
2419 }
2420
2421 // mux(c1, mux(c2, a, b), mux(c2, a, c)) -> mux(c2, a, mux(c1, b, c))
2422 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2423 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2424 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2425 trueMux.getTrueValue() == falseMux.getTrueValue() && trueMux != op &&
2426 falseMux != op) {
2427 auto subMux = rewriter.create<MuxOp>(
2428 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2429 op.getCond(), trueMux.getFalseValue(), falseMux.getFalseValue());
2430 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2431 trueMux.getTrueValue(), subMux,
2432 op.getTwoStateAttr());
2433 return success();
2434 }
2435
2436 // mux(c1, mux(c2, a, b), mux(c2, c, b)) -> mux(c2, mux(c1, a, c), b)
2437 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2438 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2439 trueMux && falseMux && trueMux.getCond() == falseMux.getCond() &&
2440 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2441 falseMux != op) {
2442 auto subMux = rewriter.create<MuxOp>(
2443 rewriter.getFusedLoc({trueMux.getLoc(), falseMux.getLoc()}),
2444 op.getCond(), trueMux.getTrueValue(), falseMux.getTrueValue());
2445 replaceOpWithNewOpAndCopyNamehint<MuxOp>(rewriter, op, trueMux.getCond(),
2446 subMux, trueMux.getFalseValue(),
2447 op.getTwoStateAttr());
2448 return success();
2449 }
2450
2451 // mux(c1, mux(c2, a, b), mux(c3, a, b)) -> mux(mux(c1, c2, c3), a, b)
2452 if (auto trueMux = dyn_cast_or_null<MuxOp>(op.getTrueValue().getDefiningOp()),
2453 falseMux = dyn_cast_or_null<MuxOp>(op.getFalseValue().getDefiningOp());
2454 trueMux && falseMux &&
2455 trueMux.getTrueValue() == falseMux.getTrueValue() &&
2456 trueMux.getFalseValue() == falseMux.getFalseValue() && trueMux != op &&
2457 falseMux != op) {
2458 auto subMux = rewriter.create<MuxOp>(
2459 rewriter.getFusedLoc(
2460 {op.getLoc(), trueMux.getLoc(), falseMux.getLoc()}),
2461 op.getCond(), trueMux.getCond(), falseMux.getCond());
2462 replaceOpWithNewOpAndCopyNamehint<MuxOp>(
2463 rewriter, op, subMux, trueMux.getTrueValue(), trueMux.getFalseValue(),
2464 op.getTwoStateAttr());
2465 return success();
2466 }
2467
2468 // mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond) | a
2469 if (foldCommonMuxValue(op, false, rewriter))
2470 return success();
2471 // mux(cond, a, x|y|z|a) -> (x|y|z)&replicate(~cond) | a
2472 if (foldCommonMuxValue(op, true, rewriter))
2473 return success();
2474
2475 // `mux(cond, op(a, b), op(a, c))` -> `op(a, mux(cond, b, c))`
2476 if (Operation *trueOp = op.getTrueValue().getDefiningOp())
2477 if (Operation *falseOp = op.getFalseValue().getDefiningOp())
2478 if (trueOp->getName() == falseOp->getName())
2479 if (foldCommonMuxOperation(op, trueOp, falseOp, rewriter))
2480 return success();
2481
2482 // extracts only of mux(...) -> mux(extract()...)
2483 if (narrowOperationWidth(op, true, rewriter))
2484 return success();
2485
2486 // mux(cond, repl(n, a1), repl(n, a2)) -> repl(n, mux(cond, a1, a2))
2487 if (foldMuxOfUniformArrays(op, rewriter))
2488 return success();
2489
2490 return failure();
2491}
2492
2493static bool foldArrayOfMuxes(hw::ArrayCreateOp op, PatternRewriter &rewriter) {
2494 // Do not fold uniform or singleton arrays to avoid duplicating muxes.
2495 if (op.getInputs().empty() || op.isUniform())
2496 return false;
2497 auto inputs = op.getInputs();
2498 if (inputs.size() <= 1)
2499 return false;
2500
2501 // Check the operands to the array create. Ensure all of them are the
2502 // same op with the same number of operands.
2503 auto first = inputs[0].getDefiningOp<comb::MuxOp>();
2504 if (!first || hasSVAttributes(first))
2505 return false;
2506
2507 // Check whether all operands are muxes with the same condition.
2508 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2509 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2510 if (!input || first.getCond() != input.getCond())
2511 return false;
2512 }
2513
2514 // Collect the true and the false branches into arrays.
2515 SmallVector<Value> trues{first.getTrueValue()};
2516 SmallVector<Value> falses{first.getFalseValue()};
2517 SmallVector<Location> locs{first->getLoc()};
2518 bool isTwoState = true;
2519 for (size_t i = 1, n = inputs.size(); i < n; ++i) {
2520 auto input = inputs[i].getDefiningOp<comb::MuxOp>();
2521 trues.push_back(input.getTrueValue());
2522 falses.push_back(input.getFalseValue());
2523 locs.push_back(input->getLoc());
2524 if (!input.getTwoState())
2525 isTwoState = false;
2526 }
2527
2528 // Define the location of the array create as the aggregate of all muxes.
2529 auto loc = FusedLoc::get(op.getContext(), locs);
2530
2531 // Replace the create with an aggregate operation. Push the create op
2532 // into the operands of the aggregate operation.
2533 auto arrayTy = op.getType();
2534 auto trueValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, trues);
2535 auto falseValues = rewriter.create<hw::ArrayCreateOp>(loc, arrayTy, falses);
2536 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, arrayTy, first.getCond(),
2537 trueValues, falseValues, isTwoState);
2538 return true;
2539}
2540
2541struct ArrayRewriter : public mlir::OpRewritePattern<hw::ArrayCreateOp> {
2542 using OpRewritePattern::OpRewritePattern;
2543
2544 LogicalResult matchAndRewrite(hw::ArrayCreateOp op,
2545 PatternRewriter &rewriter) const override {
2546 if (foldArrayOfMuxes(op, rewriter))
2547 return success();
2548 return failure();
2549 }
2550};
2551
2552} // namespace
2553
2554void MuxOp::getCanonicalizationPatterns(RewritePatternSet &results,
2555 MLIRContext *context) {
2556 results.insert<MuxRewriter, ArrayRewriter>(context);
2557}
2558
2559//===----------------------------------------------------------------------===//
2560// ICmpOp
2561//===----------------------------------------------------------------------===//
2562
2563// Calculate the result of a comparison when the LHS and RHS are both
2564// constants.
2565static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs,
2566 const APInt &rhs) {
2567 switch (predicate) {
2568 case ICmpPredicate::eq:
2569 return lhs.eq(rhs);
2570 case ICmpPredicate::ne:
2571 return lhs.ne(rhs);
2572 case ICmpPredicate::slt:
2573 return lhs.slt(rhs);
2574 case ICmpPredicate::sle:
2575 return lhs.sle(rhs);
2576 case ICmpPredicate::sgt:
2577 return lhs.sgt(rhs);
2578 case ICmpPredicate::sge:
2579 return lhs.sge(rhs);
2580 case ICmpPredicate::ult:
2581 return lhs.ult(rhs);
2582 case ICmpPredicate::ule:
2583 return lhs.ule(rhs);
2584 case ICmpPredicate::ugt:
2585 return lhs.ugt(rhs);
2586 case ICmpPredicate::uge:
2587 return lhs.uge(rhs);
2588 case ICmpPredicate::ceq:
2589 return lhs.eq(rhs);
2590 case ICmpPredicate::cne:
2591 return lhs.ne(rhs);
2592 case ICmpPredicate::weq:
2593 return lhs.eq(rhs);
2594 case ICmpPredicate::wne:
2595 return lhs.ne(rhs);
2596 }
2597 llvm_unreachable("unknown comparison predicate");
2598}
2599
2600// Returns the result of applying the predicate when the LHS and RHS are the
2601// exact same value.
2602static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate) {
2603 switch (predicate) {
2604 case ICmpPredicate::eq:
2605 case ICmpPredicate::sle:
2606 case ICmpPredicate::sge:
2607 case ICmpPredicate::ule:
2608 case ICmpPredicate::uge:
2609 case ICmpPredicate::ceq:
2610 case ICmpPredicate::weq:
2611 return true;
2612 case ICmpPredicate::ne:
2613 case ICmpPredicate::slt:
2614 case ICmpPredicate::sgt:
2615 case ICmpPredicate::ult:
2616 case ICmpPredicate::ugt:
2617 case ICmpPredicate::cne:
2618 case ICmpPredicate::wne:
2619 return false;
2620 }
2621 llvm_unreachable("unknown comparison predicate");
2622}
2623
2624OpFoldResult ICmpOp::fold(FoldAdaptor adaptor) {
2625 // gt a, a -> false
2626 // gte a, a -> true
2627 if (getLhs() == getRhs()) {
2628 auto val = applyCmpPredicateToEqualOperands(getPredicate());
2629 return IntegerAttr::get(getType(), val);
2630 }
2631
2632 // gt 1, 2 -> false
2633 if (auto lhs = dyn_cast_or_null<IntegerAttr>(adaptor.getLhs())) {
2634 if (auto rhs = dyn_cast_or_null<IntegerAttr>(adaptor.getRhs())) {
2635 auto val =
2636 applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue());
2637 return IntegerAttr::get(getType(), val);
2638 }
2639 }
2640 return {};
2641}
2642
/// Count how many leading elements of `a` and `b` are equal, comparing only
/// elements at the same position and stopping at the first mismatch or when
/// either range is exhausted. Passing reversed ranges yields the common
/// suffix length instead.
template <typename Range>
static size_t computeCommonPrefixLength(const Range &a, const Range &b) {
  auto lhsIt = a.begin();
  auto lhsEnd = a.end();
  auto rhsIt = b.begin();
  auto rhsEnd = b.end();

  size_t matched = 0;
  while (lhsIt != lhsEnd && rhsIt != rhsEnd && *lhsIt == *rhsIt) {
    ++lhsIt;
    ++rhsIt;
    ++matched;
  }
  return matched;
}
2659
2660static size_t getTotalWidth(ArrayRef<Value> operands) {
2661 size_t totalWidth = 0;
2662 for (auto operand : operands) {
2663 // getIntOrFloatBitWidth should never raise, since all arguments to
2664 // ConcatOp are integers.
2665 ssize_t width = operand.getType().getIntOrFloatBitWidth();
2666 assert(width >= 0);
2667 totalWidth += width;
2668 }
2669 return totalWidth;
2670}
2671
/// Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise
/// comparison on common prefix and suffixes. Returns success() if a rewriting
/// happens. This handles both concat and replicate.
///
/// Both sides are flattened via getConcatOperands. Elements shared at the
/// front (prefix) and back (suffix) of both operand lists are dropped:
///  - unsigned/equality predicates compare only the differing middle parts;
///  - signed predicates keep a single sign bit taken from the common prefix
///    (if any) so that the signed interpretation of the middle part is
///    unchanged.
/// \p lhs and \p rhs are the ops defining the compare's LHS and RHS.
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs,
                                                  Operation *rhs,
                                                  PatternRewriter &rewriter) {
  // It is safe to assume that [{lhsOperands, rhsOperands}.size() > 0] and
  // all elements have non-zero length. Both these invariants are verified
  // by the ConcatOp verifier.
  SmallVector<Value> lhsOperands, rhsOperands;
  getConcatOperands(lhs->getResult(0), lhsOperands);
  getConcatOperands(rhs->getResult(0), rhsOperands);
  ArrayRef<Value> lhsOperandsRef = lhsOperands, rhsOperandsRef = rhsOperands;

  // Rebuild a value from `operands`: a replicate if every element is the same
  // value, otherwise a concat.
  auto formCatOrReplicate = [&](Location loc,
                                ArrayRef<Value> operands) -> Value {
    assert(!operands.empty());
    Value sameElement = operands[0];
    for (size_t i = 1, e = operands.size(); i != e && sameElement; ++i)
      if (sameElement != operands[i])
        sameElement = Value();
    if (sameElement)
      return rewriter.createOrFold<ReplicateOp>(loc, sameElement,
                                                operands.size());
    return rewriter.createOrFold<ConcatOp>(loc, operands);
  };

  // Replace `op` with a new compare, keeping its two-state flag and namehint.
  auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
                         Value rhs) -> LogicalResult {
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs, rhs,
                                              op.getTwoState());
    return success();
  };

  size_t commonPrefixLength =
      computeCommonPrefixLength(lhsOperands, rhsOperands);
  if (commonPrefixLength == lhsOperands.size()) {
    // Both sides are the same sequence of values, so the result only depends
    // on the predicate, e.g. cat(a, b, c) == cat(a, b, c) -> 1 and
    // cat(a, b, c) != cat(a, b, c) -> 0.
    bool result = applyCmpPredicateToEqualOperands(op.getPredicate());
    replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                      APInt(1, result));
    return success();
  }

  // The suffix is computed as the common prefix of the reversed lists.
  size_t commonSuffixLength = computeCommonPrefixLength(
      llvm::reverse(lhsOperandsRef), llvm::reverse(rhsOperandsRef));

  size_t commonPrefixTotalWidth =
      getTotalWidth(lhsOperandsRef.take_front(commonPrefixLength));
  size_t commonSuffixTotalWidth =
      getTotalWidth(lhsOperandsRef.take_back(commonSuffixLength));
  // The parts of each side that are not shared.
  auto lhsOnly = lhsOperandsRef.drop_front(commonPrefixLength)
                     .drop_back(commonSuffixLength);
  auto rhsOnly = rhsOperandsRef.drop_front(commonPrefixLength)
                     .drop_back(commonSuffixLength);

  // Compare only the differing middle parts.
  auto replaceWithoutReplicatingSignBit = [&]() {
    auto newLhs = formCatOrReplicate(lhs->getLoc(), lhsOnly);
    auto newRhs = formCatOrReplicate(rhs->getLoc(), rhsOnly);
    return replaceWith(op.getPredicate(), newLhs, newRhs);
  };

  // Compare the middle parts, each prefixed with the MSB of the first
  // (common) element so the signed comparison sees the original sign bit.
  // Only reached when the common prefix is non-empty.
  auto replaceWithReplicatingSignBit = [&]() {
    auto firstNonEmptyValue = lhsOperands[0];
    auto firstNonEmptyElemWidth =
        firstNonEmptyValue.getType().getIntOrFloatBitWidth();
    Value signBit = rewriter.createOrFold<ExtractOp>(
        op.getLoc(), firstNonEmptyValue, firstNonEmptyElemWidth - 1, 1);

    auto newLhs = rewriter.create<ConcatOp>(lhs->getLoc(), signBit, lhsOnly);
    auto newRhs = rewriter.create<ConcatOp>(rhs->getLoc(), signBit, rhsOnly);
    return replaceWith(op.getPredicate(), newLhs, newRhs);
  };

  if (ICmpOp::isPredicateSigned(op.getPredicate())) {
    // scmp(cat(..x, b), cat(..y, b)) == scmp(cat(..x), cat(..y))
    if (commonPrefixTotalWidth == 0 && commonSuffixTotalWidth > 0)
      return replaceWithoutReplicatingSignBit();

    // scmp(cat(a, ..x, b), cat(a, ..y, b)) == scmp(cat(sgn(a), ..x),
    // cat(sgn(a), ..y)) Note that we cannot perform this optimization if
    // [width(b) = 0 && width(a) <= 1]. since that common prefix is the sign
    // bit. Doing the rewrite can result in an infinite loop.
    if (commonPrefixTotalWidth > 1 || commonSuffixTotalWidth > 0)
      return replaceWithReplicatingSignBit();

  } else if (commonPrefixTotalWidth > 0 || commonSuffixTotalWidth > 0) {
    // ucmp(cat(a, ..x, b), cat(a, ..y, b)) = ucmp(cat(..x), cat(..y))
    return replaceWithoutReplicatingSignBit();
  }

  // Nothing shared that can be stripped.
  return failure();
}
2765
2766/// Given an equality comparison with a constant value and some operand that has
2767/// known bits, simplify the comparison to check only the unknown bits of the
2768/// input.
2769///
2770/// One simple example of this is that `concat(0, stuff) == 0` can be simplified
2771/// to `stuff == 0`, or `and(x, 3) == 0` can be simplified to
2772/// `extract x[1:0] == 0`
2774 ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst,
2775 PatternRewriter &rewriter) {
2776
2777 // If any of the known bits disagree with any of the comparison bits, then
2778 // we can constant fold this comparison right away.
2779 APInt bitsKnown = bitAnalysis.Zero | bitAnalysis.One;
2780 if ((bitsKnown & rhsCst) != bitAnalysis.One) {
2781 // If we discover a mismatch then we know an "eq" comparison is false
2782 // and a "ne" comparison is true!
2783 bool result = cmpOp.getPredicate() == ICmpPredicate::ne;
2784 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2785 APInt(1, result));
2786 return;
2787 }
2788
2789 // Check to see if we can prove the result entirely of the comparison (in
2790 // which we bail out early), otherwise build a list of values to concat and a
2791 // smaller constant to compare against.
2792 SmallVector<Value> newConcatOperands;
2793 auto newConstant = APInt::getZeroWidth();
2794
2795 // Ok, some (maybe all) bits are known and some others may be unknown.
2796 // Extract out segments of the operand and compare against the
2797 // corresponding bits.
2798 unsigned knownMSB = bitsKnown.countLeadingOnes();
2799
2800 Value operand = cmpOp.getLhs();
2801
2802 // Ok, some bits are known but others are not. Extract out sequences of
2803 // bits that are unknown and compare just those bits. We work from MSB to
2804 // LSB.
2805 while (knownMSB != bitsKnown.getBitWidth()) {
2806 // Drop any high bits that are known.
2807 if (knownMSB)
2808 bitsKnown = bitsKnown.trunc(bitsKnown.getBitWidth() - knownMSB);
2809
2810 // Find the span of unknown bits, and extract it.
2811 unsigned unknownBits = bitsKnown.countLeadingZeros();
2812 unsigned lowBit = bitsKnown.getBitWidth() - unknownBits;
2813 auto spanOperand = rewriter.createOrFold<ExtractOp>(
2814 operand.getLoc(), operand, /*lowBit=*/lowBit,
2815 /*bitWidth=*/unknownBits);
2816 auto spanConstant = rhsCst.lshr(lowBit).trunc(unknownBits);
2817
2818 // Add this info to the concat we're generating.
2819 newConcatOperands.push_back(spanOperand);
2820 // FIXME(llvm merge, cc697fc292b0): concat doesn't work with zero bit values
2821 // newConstant = newConstant.concat(spanConstant);
2822 if (newConstant.getBitWidth() != 0)
2823 newConstant = newConstant.concat(spanConstant);
2824 else
2825 newConstant = spanConstant;
2826
2827 // Drop the unknown bits in prep for the next chunk.
2828 unsigned newWidth = bitsKnown.getBitWidth() - unknownBits;
2829 bitsKnown = bitsKnown.trunc(newWidth);
2830 knownMSB = bitsKnown.countLeadingOnes();
2831 }
2832
2833 // If all the operands to the concat are foldable then we have an identity
2834 // situation where all the sub-elements equal each other. This implies that
2835 // the overall result is foldable.
2836 if (newConcatOperands.empty()) {
2837 bool result = cmpOp.getPredicate() == ICmpPredicate::eq;
2838 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, cmpOp,
2839 APInt(1, result));
2840 return;
2841 }
2842
2843 // If we have a single operand remaining, use it, otherwise form a concat.
2844 Value concatResult =
2845 rewriter.createOrFold<ConcatOp>(operand.getLoc(), newConcatOperands);
2846
2847 // Form the comparison against the smaller constant.
2848 auto newConstantOp = rewriter.create<hw::ConstantOp>(
2849 cmpOp.getOperand(1).getLoc(), newConstant);
2850
2851 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, cmpOp,
2852 cmpOp.getPredicate(), concatResult,
2853 newConstantOp, cmpOp.getTwoState());
2854}
2855
2856// Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b), cst1^cst2).
2857static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp,
2858 const APInt &rhs,
2859 PatternRewriter &rewriter) {
2860 auto ip = rewriter.saveInsertionPoint();
2861 rewriter.setInsertionPoint(xorOp);
2862
2863 auto xorRHS = xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>();
2864 auto newRHS = rewriter.create<hw::ConstantOp>(xorRHS->getLoc(),
2865 xorRHS.getValue() ^ rhs);
2866 Value newLHS;
2867 switch (xorOp.getNumOperands()) {
2868 case 1:
2869 // This isn't common but is defined so we need to handle it.
2870 newLHS = rewriter.create<hw::ConstantOp>(xorOp.getLoc(),
2871 APInt::getZero(rhs.getBitWidth()));
2872 break;
2873 case 2:
2874 // The binary case is the most common.
2875 newLHS = xorOp.getOperand(0);
2876 break;
2877 default:
2878 // The general case forces us to form a new xor with the remaining operands.
2879 SmallVector<Value> newOperands(xorOp.getOperands());
2880 newOperands.pop_back();
2881 newLHS = rewriter.create<XorOp>(xorOp.getLoc(), newOperands, false);
2882 break;
2883 }
2884
2885 bool xorMultipleUses = !xorOp->hasOneUse();
2886
2887 // If the xor has multiple uses (not just the compare, then we need/want to
2888 // replace them as well.
2889 if (xorMultipleUses)
2890 replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, xorOp, newLHS, xorRHS,
2891 false);
2892
2893 // Replace the comparison.
2894 rewriter.restoreInsertionPoint(ip);
2895 replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
2896 rewriter, cmpOp, cmpOp.getPredicate(), newLHS, newRHS, false);
2897}
2898
/// Canonicalize an integer comparison. Rewrites are tried in this order:
///  - a constant LHS is moved to the RHS by flipping the predicate;
///  - with a constant RHS: comparisons against signed/unsigned min/max (and
///    their neighbours) become (in)equalities or i1 constants, non-strict
///    orderings become strict ones, some unsigned orderings are narrowed to an
///    extract compared against a constant, and 1-bit eq/ne become the operand
///    itself or its xor with 1;
///  - for eq/ne against a constant: known-bits narrowing, folding a trailing
///    xor constant into the RHS, and stripping a replicate when the constant
///    is all zeros or all ones;
///  - common prefix/suffix elimination when both sides are concat/replicate.
/// Returns success() if `op` was rewritten.
LogicalResult ICmpOp::canonicalize(ICmpOp op, PatternRewriter &rewriter) {
  APInt lhs, rhs;

  // icmp 1, x -> icmp x, 1
  if (matchPattern(op.getLhs(), m_ConstantInt(&lhs))) {
    // Two constant operands are expected to have been handled by fold().
    assert(!matchPattern(op.getRhs(), m_ConstantInt(&rhs)) &&
           "Should be folded");
    replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
        rewriter, op, ICmpOp::getFlippedPredicate(op.getPredicate()),
        op.getRhs(), op.getLhs(), op.getTwoState());
    return success();
  }

  // Canonicalize with RHS constant
  if (matchPattern(op.getRhs(), m_ConstantInt(&rhs))) {
    // Materialize an hw.constant at the compare's location.
    auto getConstant = [&](APInt constant) -> Value {
      return rewriter.create<hw::ConstantOp>(op.getLoc(), std::move(constant));
    };

    // Replace `op` with a new compare, keeping its two-state flag.
    auto replaceWith = [&](ICmpPredicate predicate, Value lhs,
                           Value rhs) -> LogicalResult {
      replaceOpWithNewOpAndCopyNamehint<ICmpOp>(rewriter, op, predicate, lhs,
                                                rhs, op.getTwoState());
      return success();
    };

    // Replace `op` with an i1 constant.
    auto replaceWithConstantI1 = [&](bool constant) -> LogicalResult {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        APInt(1, constant));
      return success();
    };

    switch (op.getPredicate()) {
    case ICmpPredicate::slt:
      // x < max -> x != max
      if (rhs.isMaxSignedValue())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isMinSignedValue())
        return replaceWithConstantI1(0);
      // x < min+1 -> x == min
      if ((rhs - 1).isMinSignedValue())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs - 1));
      break;
    case ICmpPredicate::sgt:
      // x > min -> x != min
      if (rhs.isMinSignedValue())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isMaxSignedValue())
        return replaceWithConstantI1(0);
      // x > max-1 -> x == max
      if ((rhs + 1).isMaxSignedValue())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs + 1));
      break;
    case ICmpPredicate::ult:
      // x < max -> x != max
      if (rhs.isAllOnes())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x < min -> false
      if (rhs.isZero())
        return replaceWithConstantI1(0);
      // x < min+1 -> x == min
      if ((rhs - 1).isZero())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs - 1));

      // x < 0xE0 -> extract(x, 5..7) != 0b111
      // Applies when rhs is a run of leading ones followed only by zeros: x is
      // below rhs exactly when its top `numOnes` bits are not all ones.
      if (rhs.countLeadingOnes() + rhs.countTrailingZeros() ==
          rhs.getBitWidth()) {
        auto numOnes = rhs.countLeadingOnes();
        auto smaller = rewriter.create<ExtractOp>(
            op.getLoc(), op.getLhs(), rhs.getBitWidth() - numOnes, numOnes);
        return replaceWith(ICmpPredicate::ne, smaller,
                           getConstant(APInt::getAllOnes(numOnes)));
      }

      break;
    case ICmpPredicate::ugt:
      // x > min -> x != min
      if (rhs.isZero())
        return replaceWith(ICmpPredicate::ne, op.getLhs(), op.getRhs());
      // x > max -> false
      if (rhs.isAllOnes())
        return replaceWithConstantI1(0);
      // x > max-1 -> x == max
      if ((rhs + 1).isAllOnes())
        return replaceWith(ICmpPredicate::eq, op.getLhs(),
                           getConstant(rhs + 1));

      // x > 0x07 -> extract(x, 3..7) != 0b00000
      // Applies when rhs is a run of trailing ones (rhs + 1 is a power of
      // two): x is above rhs exactly when any bit above that run is set.
      if ((rhs + 1).isPowerOf2()) {
        auto numOnes = rhs.countTrailingOnes();
        auto newWidth = rhs.getBitWidth() - numOnes;
        auto smaller = rewriter.create<ExtractOp>(op.getLoc(), op.getLhs(),
                                                  numOnes, newWidth);
        return replaceWith(ICmpPredicate::ne, smaller,
                           getConstant(APInt::getZero(newWidth)));
      }

      break;
    case ICmpPredicate::sle:
      // x <= max -> true
      if (rhs.isMaxSignedValue())
        return replaceWithConstantI1(1);
      // x <= c -> x < (c+1)
      return replaceWith(ICmpPredicate::slt, op.getLhs(), getConstant(rhs + 1));
    case ICmpPredicate::sge:
      // x >= min -> true
      if (rhs.isMinSignedValue())
        return replaceWithConstantI1(1);
      // x >= c -> x > (c-1)
      return replaceWith(ICmpPredicate::sgt, op.getLhs(), getConstant(rhs - 1));
    case ICmpPredicate::ule:
      // x <= max -> true
      if (rhs.isAllOnes())
        return replaceWithConstantI1(1);
      // x <= c -> x < (c+1)
      return replaceWith(ICmpPredicate::ult, op.getLhs(), getConstant(rhs + 1));
    case ICmpPredicate::uge:
      // x >= min -> true
      if (rhs.isZero())
        return replaceWithConstantI1(1);
      // x >= c -> x > (c-1)
      return replaceWith(ICmpPredicate::ugt, op.getLhs(), getConstant(rhs - 1));
    case ICmpPredicate::eq:
      if (rhs.getBitWidth() == 1) {
        if (rhs.isZero()) {
          // x == 0 -> x ^ 1
          replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
                                                   getConstant(APInt(1, 1)),
                                                   op.getTwoState());
          return success();
        }
        if (rhs.isAllOnes()) {
          // x == 1 -> x
          replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
          return success();
        }
      }
      break;
    case ICmpPredicate::ne:
      if (rhs.getBitWidth() == 1) {
        if (rhs.isZero()) {
          // x != 0 -> x
          replaceOpAndCopyNamehint(rewriter, op, op.getLhs());
          return success();
        }
        if (rhs.isAllOnes()) {
          // x != 1 -> x ^ 1
          replaceOpWithNewOpAndCopyNamehint<XorOp>(rewriter, op, op.getLhs(),
                                                   getConstant(APInt(1, 1)),
                                                   op.getTwoState());
          return success();
        }
      }
      break;
    case ICmpPredicate::ceq:
    case ICmpPredicate::cne:
    case ICmpPredicate::weq:
    case ICmpPredicate::wne:
      // No constant-RHS rewrites for case/wildcard (in)equality.
      break;
    }

    // We have some specific optimizations for comparison with a constant that
    // are only supported for equality comparisons.
    if (op.getPredicate() == ICmpPredicate::eq ||
        op.getPredicate() == ICmpPredicate::ne) {
      // Simplify `icmp(value_with_known_bits, rhscst)` into some extracts
      // with a smaller constant.  We only support equality comparisons for
      // this.
      auto knownBits = computeKnownBits(op.getLhs());
      if (!knownBits.isUnknown())
        return combineEqualityICmpWithKnownBitsAndConstant(op, knownBits, rhs,
                                                           rewriter),
               success();

      // Simplify icmp eq(xor(a,b,cst1), cst2) -> icmp eq(xor(a,b),
      // cst1^cst2).
      if (auto xorOp = op.getLhs().getDefiningOp<XorOp>())
        if (xorOp.getOperands().back().getDefiningOp<hw::ConstantOp>())
          return combineEqualityICmpWithXorOfConstant(op, xorOp, rhs, rewriter),
                 success();

      // Simplify icmp eq(replicate(v, n), c) -> icmp eq(v, c) if c is zero or
      // all one.
      if (auto replicateOp = op.getLhs().getDefiningOp<ReplicateOp>())
        if (rhs.isAllOnes() || rhs.isZero()) {
          auto width = replicateOp.getInput().getType().getIntOrFloatBitWidth();
          auto cst = rewriter.create<hw::ConstantOp>(
              op.getLoc(), rhs.isAllOnes() ? APInt::getAllOnes(width)
                                           : APInt::getZero(width));
          replaceOpWithNewOpAndCopyNamehint<ICmpOp>(
              rewriter, op, op.getPredicate(), replicateOp.getInput(), cst,
              op.getTwoState());
          return success();
        }
    }
  }

  // icmp(cat(prefix, a, b, suffix), cat(prefix, c, d, suffix)) => icmp(cat(a,
  // b), cat(c, d)). contains special handling for sign bit in signed
  // comparisons.
  if (Operation *opLHS = op.getLhs().getDefiningOp())
    if (Operation *opRHS = op.getRhs().getDefiningOp())
      if (isa<ConcatOp, ReplicateOp>(opLHS) &&
          isa<ConcatOp, ReplicateOp>(opRHS)) {
        if (succeeded(matchAndRewriteCompareConcat(op, opLHS, opRHS, rewriter)))
          return success();
      }

  return failure();
}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static KnownBits computeKnownBits(Value v, unsigned depth)
Given an integer SSA value, check to see if we know anything about the result of the computation.
static bool foldMuxOfUniformArrays(MuxOp op, PatternRewriter &rewriter)
static Attribute constFoldAssociativeOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
static Attribute constFoldBinaryOp(ArrayRef< Attribute > operands, hw::PEO paramOpcode)
Performs constant folding calculate with element-wise behavior on the two attributes in operands and ...
static bool applyCmpPredicateToEqualOperands(ICmpPredicate predicate)
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition CombFolds.cpp:76
static bool canonicalizeLogicalCstWithConcat(Operation *logicalOp, size_t concatIdx, const APInt &cst, PatternRewriter &rewriter)
When we find a logical operation (and, or, xor) with a constant e.g.
static bool narrowOperationWidth(OpTy op, bool narrowTrailingBits, PatternRewriter &rewriter)
static OpFoldResult foldDiv(Op op, ArrayRef< Attribute > constants)
static Value getCommonOperand(Op op)
Returns a single common operand that all inputs of the operation op can be traced back to,...
static bool canCombineOppositeBinCmpIntoConstant(OperandRange operands)
static void getConcatOperands(Value v, SmallVectorImpl< Value > &result)
Flatten concat and mux operands into a vector.
Definition CombFolds.cpp:44
static Value extractOperandFromFullyAssociative(Operation *fullyAssoc, size_t operandNo, PatternRewriter &rewriter)
Given a fully associative variadic operation like (a+b+c+d), break the expression into two parts,...
static bool getMuxChainCondConstant(Value cond, Value indexValue, bool isInverted, std::function< void(hw::ConstantOp)> constantFn)
Check to see if the condition to the specified mux is an equality comparison indexValue and one or mo...
static TypedAttr getIntAttr(const APInt &value, MLIRContext *context)
Definition CombFolds.cpp:38
static bool shouldBeFlattened(Operation *op)
Return true if the op will be flattened afterwards.
Definition CombFolds.cpp:82
static void canonicalizeXorIcmpTrue(XorOp op, unsigned icmpOperand, PatternRewriter &rewriter)
static bool extractFromReplicate(ExtractOp op, ReplicateOp replicate, PatternRewriter &rewriter)
static void combineEqualityICmpWithXorOfConstant(ICmpOp cmpOp, XorOp xorOp, const APInt &rhs, PatternRewriter &rewriter)
static size_t getTotalWidth(ArrayRef< Value > operands)
static bool foldCommonMuxOperation(MuxOp mux, Operation *trueOp, Operation *falseOp, PatternRewriter &rewriter)
This function is invoke when we find a mux with true/false operations that have the same opcode.
static std::pair< size_t, size_t > getLowestBitAndHighestBitRequired(Operation *op, bool narrowTrailingBits, size_t originalOpWidth)
static bool tryFlatteningOperands(Operation *op, PatternRewriter &rewriter)
Flattens a single input in op if hasOneUse is true and it can be defined as an Op.
static bool canonicalizeIdempotentInputs(Op op, PatternRewriter &rewriter)
Canonicalize an idempotent operation op so that only one input of any kind occurs.
static bool applyCmpPredicate(ICmpPredicate predicate, const APInt &lhs, const APInt &rhs)
static void combineEqualityICmpWithKnownBitsAndConstant(ICmpOp cmpOp, const KnownBits &bitAnalysis, const APInt &rhsCst, PatternRewriter &rewriter)
Given an equality comparison with a constant value and some operand that has known bits,...
static bool foldMuxChain(MuxOp rootMux, bool isFalseSide, PatternRewriter &rewriter)
Given a mux, check to see if the "on true" value (or "on false" value if isFalseSide=true) is a mux t...
static bool hasSVAttributes(Operation *op)
Definition CombFolds.cpp:59
static LogicalResult extractConcatToConcatExtract(ExtractOp op, ConcatOp innerCat, PatternRewriter &rewriter)
static OpFoldResult foldMod(Op op, ArrayRef< Attribute > constants)
static size_t computeCommonPrefixLength(const Range &a, const Range &b)
static bool foldCommonMuxValue(MuxOp op, bool isTrueOperand, PatternRewriter &rewriter)
Fold things like mux(cond, x|y|z|a, a) -> (x|y|z)&replicate(cond)|a and mux(cond, a,...
static LogicalResult matchAndRewriteCompareConcat(ICmpOp op, Operation *lhs, Operation *rhs, PatternRewriter &rewriter)
Reduce the strength icmp(concat(...), concat(...)) by doing a element-wise comparison on common prefi...
static Value createGenericOp(Location loc, OperationName name, ArrayRef< Value > operands, OpBuilder &builder)
Create a new instance of a generic operation that only has value operands, and has a single result va...
Definition CombFolds.cpp:30
static TypedAttr getIntAttr(MLIRContext *ctx, Type t, const APInt &value)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(elements)
Definition hw.py:483
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a "Not" gate on a value.
Definition CombOps.cpp:65
KnownBits computeKnownBits(Value value)
Compute "known bits" information about the specified value - the set of bits that are guaranteed to a...
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1