17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "llvm/ADT/PointerUnion.h"
22#define GEN_PASS_DEF_CONVERTCOMBTOAIG
23#include "circt/Conversion/Passes.h.inc"
34static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
35 SmallVector<Value> bits;
36 comb::extractBits(builder, val, bits);
47template <
bool isLeftShift>
49 Value shiftAmount, int64_t maxShiftAmount,
50 llvm::function_ref<Value(int64_t)> getPadding,
51 llvm::function_ref<Value(int64_t)> getExtract) {
56 SmallVector<Value> nodes;
57 nodes.reserve(maxShiftAmount);
58 for (int64_t i = 0; i < maxShiftAmount; ++i) {
59 Value extract = getExtract(i);
60 Value padding = getPadding(i);
63 nodes.push_back(extract);
77 auto outOfBoundsValue = getPadding(maxShiftAmount);
78 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
82 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
85 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
86 loc, ICmpPredicate::ult, shiftAmount,
90 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
96using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
101 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
103 if (value.getType().isInteger(0))
108 int64_t totalUnknownBits = 0;
109 for (
auto concatInput : llvm::reverse(
concat.getInputs())) {
114 totalUnknownBits += unknownBits;
116 return totalUnknownBits;
121 values.push_back(constant.getValueAttr());
127 values.push_back(value);
128 return hw::getBitWidth(value.getType());
134 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
136 uint32_t bitPos = 0, unknownPos = 0;
137 APInt result(width, 0);
138 for (
auto constantOrValue : constantOrValues) {
140 if (
auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
141 elemWidth = constant.getValue().getBitWidth();
142 result.insertBits(constant.getValue(), bitPos);
144 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
145 assert(elemWidth >= 0 &&
"unknown bit width");
146 assert(elemWidth + unknownPos < 32 &&
"unknown bit width too large");
148 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
149 result.insertBits(APInt(elemWidth, usedBits), bitPos);
150 unknownPos += elemWidth;
162 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
164 llvm::function_ref<APInt(
const APInt &,
const APInt &)> emulate) {
165 SmallVector<ConstantOrValue> lhsValues, rhsValues;
167 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
168 "op must be a single result binary operation");
170 auto lhs = op->getOperand(0);
171 auto rhs = op->getOperand(1);
172 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
173 auto loc = op->getLoc();
178 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
181 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
182 if (totalUnknownBits > maxEmulationUnknownBits)
185 SmallVector<Value> emulatedResults;
186 emulatedResults.reserve(1 << totalUnknownBits);
189 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
191 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
192 auto it = constantPool.find(attr);
193 if (it != constantPool.end())
196 constantPool[attr] = constant;
200 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
201 lhsMask < lhsMaskEnd; ++lhsMask) {
203 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
204 rhsMask < rhsMaskEnd; ++rhsMask) {
207 emulatedResults.push_back(
getConstant(emulate(lhsValue, rhsValue)));
212 SmallVector<Value> selectors;
213 selectors.reserve(totalUnknownBits);
214 for (
auto &concatedValues : {rhsValues, lhsValues})
215 for (
auto valueOrConstant : concatedValues) {
216 auto value = dyn_cast<Value>(valueOrConstant);
222 assert(totalUnknownBits ==
static_cast<int64_t
>(selectors.size()) &&
223 "number of selectors must match");
224 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
227 rewriter.replaceOp(op, muxed);
242 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
243 ConversionPatternRewriter &rewriter)
const override {
244 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
245 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
256 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
257 ConversionPatternRewriter &rewriter)
const override {
259 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
260 auto andOp = rewriter.create<aig::AndInverterOp>(
261 op.getLoc(), adaptor.getInputs(), allInverts);
262 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
273 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter)
const override {
275 if (op.getNumOperands() != 2)
281 auto inputs = adaptor.getInputs();
282 SmallVector<bool> allInverts(inputs.size(),
true);
283 SmallVector<bool> allNotInverts(inputs.size(),
false);
286 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
288 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
290 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
297template <
typename OpTy>
302 matchAndRewrite(OpTy op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const override {
305 rewriter.replaceOp(op, result);
310 ConversionPatternRewriter &rewriter) {
312 switch (operands.size()) {
314 assert(
false &&
"cannot be called with empty operand range");
321 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
323 auto firstHalf = operands.size() / 2;
328 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
338 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter)
const override {
342 Value cond = op.getCond();
343 auto trueVal = op.getTrueValue();
344 auto falseVal = op.getFalseValue();
346 if (!op.getType().isInteger()) {
348 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
350 rewriter.create<
hw::BitcastOp>(op->getLoc(), widthType, trueVal);
356 if (!trueVal.getType().isInteger(1))
357 cond = rewriter.
create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
361 auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
362 auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
365 Value result = rewriter.create<
comb::OrOp>(op.getLoc(), lhs, rhs);
367 if (result.getType() != op.getType())
369 rewriter.create<
hw::BitcastOp>(op.getLoc(), op.getType(), result);
370 rewriter.replaceOp(op, result);
378 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
379 ConversionPatternRewriter &rewriter)
const override {
380 auto inputs = adaptor.getInputs();
383 if (inputs.size() != 2)
386 auto width = op.getType().getIntOrFloatBitWidth();
398 SmallVector<Value> results;
399 results.resize(width);
400 for (int64_t i = 0; i < width; ++i) {
401 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
403 xorOperands.push_back(carry);
407 results[width - i - 1] =
408 rewriter.create<
comb::XorOp>(op.getLoc(), xorOperands,
true);
411 if (i == width - 1) {
417 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
425 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
427 op.getLoc(), ValueRange{carry, aXnorB},
true);
428 carry = rewriter.create<
comb::OrOp>(op.getLoc(),
429 ValueRange{andOp, nextCarry},
true);
440 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
441 ConversionPatternRewriter &rewriter)
const override {
442 auto lhs = op.getLhs();
443 auto rhs = op.getRhs();
447 auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
449 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
450 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, ValueRange{lhs, notRhs, one},
460 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
461 ConversionPatternRewriter &rewriter)
const override {
462 if (adaptor.getInputs().size() != 2)
471 int64_t width = op.getType().getIntOrFloatBitWidth();
472 auto aBits =
extractBits(rewriter, adaptor.getInputs()[0]);
473 SmallVector<Value> results;
474 auto rhs = op.getInputs()[1];
476 llvm::APInt::getZero(width));
477 for (int64_t i = 0; i < width; ++i) {
478 auto aBit = aBits[i];
480 rewriter.createOrFold<
comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
482 op.getLoc(), andBit, 0, width - i);
484 results.push_back(upperBits);
492 op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
493 results.push_back(shifted);
496 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, results,
true);
501template <
typename OpTy>
503 DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
505 maxEmulationUnknownBits(maxEmulationUnknownBits) {
506 assert(maxEmulationUnknownBits < 32 &&
507 "maxEmulationUnknownBits must be less than 32");
509 const int64_t maxEmulationUnknownBits;
512struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
513 using DivModOpConversionBase<
DivUOp>::DivModOpConversionBase;
515 matchAndRewrite(
DivUOp op, OpAdaptor adaptor,
516 ConversionPatternRewriter &rewriter)
const override {
518 if (
auto rhsConstantOp = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>())
519 if (rhsConstantOp.getValue().isPowerOf2()) {
521 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
522 size_t width = op.getType().getIntOrFloatBitWidth();
524 op.getLoc(), adaptor.getLhs(), extractAmount,
525 width - extractAmount);
527 op.getLoc(), APInt::getZero(extractAmount));
529 op, op.getType(), ArrayRef<Value>{constZero, upperBits});
536 rewriter, maxEmulationUnknownBits, op,
537 [](
const APInt &lhs,
const APInt &rhs) {
540 return APInt::getZero(rhs.getBitWidth());
541 return lhs.udiv(rhs);
546struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
547 using DivModOpConversionBase<
ModUOp>::DivModOpConversionBase;
549 matchAndRewrite(
ModUOp op, OpAdaptor adaptor,
550 ConversionPatternRewriter &rewriter)
const override {
552 if (
auto rhsConstantOp = adaptor.getRhs().getDefiningOp<
hw::ConstantOp>())
553 if (rhsConstantOp.getValue().isPowerOf2()) {
555 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
556 size_t width = op.getType().getIntOrFloatBitWidth();
558 op.getLoc(), adaptor.getLhs(), 0, extractAmount);
560 op.getLoc(), APInt::getZero(width - extractAmount));
562 op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
569 rewriter, maxEmulationUnknownBits, op,
570 [](
const APInt &lhs,
const APInt &rhs) {
573 return APInt::getZero(rhs.getBitWidth());
574 return lhs.urem(rhs);
579struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
580 using DivModOpConversionBase<
DivSOp>::DivModOpConversionBase;
583 matchAndRewrite(
DivSOp op, OpAdaptor adaptor,
584 ConversionPatternRewriter &rewriter)
const override {
588 rewriter, maxEmulationUnknownBits, op,
589 [](
const APInt &lhs,
const APInt &rhs) {
592 return APInt::getZero(rhs.getBitWidth());
593 return lhs.sdiv(rhs);
598struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
599 using DivModOpConversionBase<
ModSOp>::DivModOpConversionBase;
601 matchAndRewrite(
ModSOp op, OpAdaptor adaptor,
602 ConversionPatternRewriter &rewriter)
const override {
606 rewriter, maxEmulationUnknownBits, op,
607 [](
const APInt &lhs,
const APInt &rhs) {
610 return APInt::getZero(rhs.getBitWidth());
611 return lhs.srem(rhs);
618 static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
619 ArrayRef<Value> bBits,
bool isLess,
621 ConversionPatternRewriter &rewriter) {
628 rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);
630 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
632 rewriter.createOrFold<
comb::XorOp>(op.getLoc(), aBit, bBit,
true);
633 auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
634 op.getLoc(), aBitXorBBit,
true);
635 auto pred = rewriter.createOrFold<aig::AndInverterOp>(
636 op.getLoc(), aBit, bBit, isLess, !isLess);
638 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
639 op.getLoc(), ValueRange{aEqualB, acc},
true);
640 acc = rewriter.createOrFold<
comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
647 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
648 ConversionPatternRewriter &rewriter)
const override {
649 auto lhs = adaptor.getLhs();
650 auto rhs = adaptor.getRhs();
652 switch (op.getPredicate()) {
656 case ICmpPredicate::eq:
657 case ICmpPredicate::ceq: {
659 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
661 SmallVector<bool> allInverts(xorBits.size(),
true);
662 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
666 case ICmpPredicate::ne:
667 case ICmpPredicate::cne: {
669 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
675 case ICmpPredicate::uge:
676 case ICmpPredicate::ugt:
677 case ICmpPredicate::ule:
678 case ICmpPredicate::ult: {
679 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
680 op.getPredicate() == ICmpPredicate::ule;
681 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
682 op.getPredicate() == ICmpPredicate::ule;
685 rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
686 includeEq, rewriter));
689 case ICmpPredicate::slt:
690 case ICmpPredicate::sle:
691 case ICmpPredicate::sgt:
692 case ICmpPredicate::sge: {
693 if (lhs.getType().getIntOrFloatBitWidth() == 0)
694 return rewriter.notifyMatchFailure(
695 op.getLoc(),
"i0 signed comparison is unsupported");
696 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
697 op.getPredicate() == ICmpPredicate::sle;
698 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
699 op.getPredicate() == ICmpPredicate::sle;
705 auto signA = aBits.back();
706 auto signB = bBits.back();
709 auto sameSignResult = constructUnsignedCompare(
710 op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
711 includeEq, rewriter);
715 rewriter.create<
comb::XorOp>(op.getLoc(), signA, signB);
718 Value diffSignResult = isLess ? signA : signB;
721 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, signsDiffer, diffSignResult,
733 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
734 ConversionPatternRewriter &rewriter)
const override {
737 op,
extractBits(rewriter, adaptor.getInput()),
true);
747 ConversionPatternRewriter &rewriter)
const override {
748 auto width = op.getType().getIntOrFloatBitWidth();
749 auto lhs = adaptor.getLhs();
751 rewriter, op.getLoc(), adaptor.getRhs(), width,
759 op.getLoc(), rewriter.getIntegerType(index), 0);
763 assert(index < width &&
"index out of bounds");
769 rewriter.replaceOp(op, result);
779 ConversionPatternRewriter &rewriter)
const override {
780 auto width = op.getType().getIntOrFloatBitWidth();
781 auto lhs = adaptor.getLhs();
783 rewriter, op.getLoc(), adaptor.getRhs(), width,
791 op.getLoc(), rewriter.getIntegerType(index), 0);
795 assert(index < width &&
"index out of bounds");
801 rewriter.replaceOp(op, result);
811 ConversionPatternRewriter &rewriter)
const override {
812 auto width = op.getType().getIntOrFloatBitWidth();
814 return rewriter.notifyMatchFailure(op.getLoc(),
815 "i0 signed shift is unsupported");
816 auto lhs = adaptor.getLhs();
819 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
824 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
827 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
836 rewriter.replaceOp(op, result);
848struct ConvertCombToAIGPass
849 :
public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
850 void runOnOperation()
override;
851 using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
852 using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
853 using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
859 uint32_t maxEmulationUnknownBits) {
862 CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
863 CombMuxOpConversion, CombParityOpConversion,
865 CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
866 CombICmpOpConversion,
868 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
870 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
871 CombLowerVariadicOp<MulOp>>(
patterns.getContext());
874 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
875 CombModSOpConversion>(
patterns.getContext(),
876 maxEmulationUnknownBits);
879void ConvertCombToAIGPass::runOnOperation() {
880 ConversionTarget target(getContext());
883 target.addIllegalDialect<comb::CombDialect>();
893 hw::AggregateConstantOp>();
896 target.addLegalDialect<aig::AIGDialect>();
899 if (!additionalLegalOps.empty())
900 for (
const auto &opName : additionalLegalOps)
901 target.addLegalOp(OperationName(opName, &getContext()));
903 RewritePatternSet
patterns(&getContext());
906 if (failed(mlir::applyPartialConversion(getOperation(), target,
908 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.