17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
21#define GEN_PASS_DEF_CONVERTCOMBTOAIG
22#include "circt/Conversion/Passes.h.inc"
33static SmallVector<Value>
extractBits(OpBuilder &builder, Value val) {
34 SmallVector<Value> bits;
35 comb::extractBits(builder, val, bits);
46template <
bool isLeftShift>
48 Value shiftAmount, int64_t maxShiftAmount,
49 llvm::function_ref<Value(int64_t)> getPadding,
50 llvm::function_ref<Value(int64_t)> getExtract) {
55 SmallVector<Value> nodes;
56 nodes.reserve(maxShiftAmount);
57 for (int64_t i = 0; i < maxShiftAmount; ++i) {
58 Value extract = getExtract(i);
59 Value padding = getPadding(i);
62 nodes.push_back(extract);
76 auto outOfBoundsValue = getPadding(maxShiftAmount);
77 assert(outOfBoundsValue &&
"outOfBoundsValue must be valid");
81 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
84 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
85 loc, ICmpPredicate::ult, shiftAmount,
89 return rewriter.createOrFold<
comb::MuxOp>(loc, inBound, result,
104 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
106 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
107 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
118 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter)
const override {
121 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
122 auto andOp = rewriter.create<aig::AndInverterOp>(
123 op.getLoc(), adaptor.getInputs(), allInverts);
124 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
135 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter)
const override {
137 if (op.getNumOperands() != 2)
143 auto inputs = adaptor.getInputs();
144 SmallVector<bool> allInverts(inputs.size(),
true);
145 SmallVector<bool> allNotInverts(inputs.size(),
false);
148 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
150 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
152 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
159template <
typename OpTy>
164 matchAndRewrite(OpTy op, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter)
const override {
167 rewriter.replaceOp(op, result);
172 ConversionPatternRewriter &rewriter) {
174 switch (operands.size()) {
176 assert(
false &&
"cannot be called with empty operand range");
183 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
185 auto firstHalf = operands.size() / 2;
190 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
200 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter)
const override {
204 Value cond = op.getCond();
205 auto trueVal = op.getTrueValue();
206 auto falseVal = op.getFalseValue();
208 if (!op.getType().isInteger()) {
210 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
212 rewriter.create<
hw::BitcastOp>(op->getLoc(), widthType, trueVal);
218 if (!trueVal.getType().isInteger(1))
219 cond = rewriter.
create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
223 auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
224 auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
227 Value result = rewriter.create<
comb::OrOp>(op.getLoc(), lhs, rhs);
229 if (result.getType() != op.getType())
231 rewriter.create<
hw::BitcastOp>(op.getLoc(), op.getType(), result);
232 rewriter.replaceOp(op, result);
240 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter)
const override {
242 auto inputs = adaptor.getInputs();
245 if (inputs.size() != 2)
248 auto width = op.getType().getIntOrFloatBitWidth();
260 SmallVector<Value> results;
261 results.resize(width);
262 for (int64_t i = 0; i < width; ++i) {
263 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
265 xorOperands.push_back(carry);
269 results[width - i - 1] =
270 rewriter.create<
comb::XorOp>(op.getLoc(), xorOperands,
true);
273 if (i == width - 1) {
279 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
287 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
289 op.getLoc(), ValueRange{carry, aXnorB},
true);
290 carry = rewriter.create<
comb::OrOp>(op.getLoc(),
291 ValueRange{andOp, nextCarry},
true);
302 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const override {
304 auto lhs = op.getLhs();
305 auto rhs = op.getRhs();
309 auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
311 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
312 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, ValueRange{lhs, notRhs, one},
322 matchAndRewrite(
MulOp op, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter)
const override {
324 if (adaptor.getInputs().size() != 2)
333 int64_t width = op.getType().getIntOrFloatBitWidth();
334 auto aBits =
extractBits(rewriter, adaptor.getInputs()[0]);
335 SmallVector<Value> results;
336 auto rhs = op.getInputs()[1];
338 llvm::APInt::getZero(width));
339 for (int64_t i = 0; i < width; ++i) {
340 auto aBit = aBits[i];
342 rewriter.createOrFold<
comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
344 op.getLoc(), andBit, 0, width - i);
346 results.push_back(upperBits);
354 op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
355 results.push_back(shifted);
358 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, results,
true);
365 static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
366 ArrayRef<Value> bBits,
bool isLess,
368 ConversionPatternRewriter &rewriter) {
375 rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);
377 for (
auto [aBit, bBit] :
llvm::zip(aBits, bBits)) {
379 rewriter.createOrFold<
comb::XorOp>(op.getLoc(), aBit, bBit,
true);
380 auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
381 op.getLoc(), aBitXorBBit,
true);
382 auto pred = rewriter.createOrFold<aig::AndInverterOp>(
383 op.getLoc(), aBit, bBit, isLess, !isLess);
385 auto aBitAndBBit = rewriter.createOrFold<
comb::AndOp>(
386 op.getLoc(), ValueRange{aEqualB, acc},
true);
387 acc = rewriter.createOrFold<
comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
394 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
395 ConversionPatternRewriter &rewriter)
const override {
396 auto lhs = adaptor.getLhs();
397 auto rhs = adaptor.getRhs();
399 switch (op.getPredicate()) {
403 case ICmpPredicate::eq:
404 case ICmpPredicate::ceq: {
406 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
408 SmallVector<bool> allInverts(xorBits.size(),
true);
409 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
413 case ICmpPredicate::ne:
414 case ICmpPredicate::cne: {
416 auto xorOp = rewriter.createOrFold<
comb::XorOp>(op.getLoc(), lhs, rhs);
422 case ICmpPredicate::uge:
423 case ICmpPredicate::ugt:
424 case ICmpPredicate::ule:
425 case ICmpPredicate::ult: {
426 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
427 op.getPredicate() == ICmpPredicate::ule;
428 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
429 op.getPredicate() == ICmpPredicate::ule;
432 rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
433 includeEq, rewriter));
436 case ICmpPredicate::slt:
437 case ICmpPredicate::sle:
438 case ICmpPredicate::sgt:
439 case ICmpPredicate::sge: {
440 if (lhs.getType().getIntOrFloatBitWidth() == 0)
441 return rewriter.notifyMatchFailure(
442 op.getLoc(),
"i0 signed comparison is unsupported");
443 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
444 op.getPredicate() == ICmpPredicate::sle;
445 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
446 op.getPredicate() == ICmpPredicate::sle;
452 auto signA = aBits.back();
453 auto signB = bBits.back();
456 auto sameSignResult = constructUnsignedCompare(
457 op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
458 includeEq, rewriter);
462 rewriter.create<
comb::XorOp>(op.getLoc(), signA, signB);
465 Value diffSignResult = isLess ? signA : signB;
468 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, signsDiffer, diffSignResult,
480 matchAndRewrite(
ParityOp op, OpAdaptor adaptor,
481 ConversionPatternRewriter &rewriter)
const override {
484 op,
extractBits(rewriter, adaptor.getInput()),
true);
494 ConversionPatternRewriter &rewriter)
const override {
495 auto width = op.getType().getIntOrFloatBitWidth();
496 auto lhs = adaptor.getLhs();
498 rewriter, op.getLoc(), adaptor.getRhs(), width,
506 op.getLoc(), rewriter.getIntegerType(index), 0);
510 assert(index < width &&
"index out of bounds");
516 rewriter.replaceOp(op, result);
526 ConversionPatternRewriter &rewriter)
const override {
527 auto width = op.getType().getIntOrFloatBitWidth();
528 auto lhs = adaptor.getLhs();
530 rewriter, op.getLoc(), adaptor.getRhs(), width,
538 op.getLoc(), rewriter.getIntegerType(index), 0);
542 assert(index < width &&
"index out of bounds");
548 rewriter.replaceOp(op, result);
558 ConversionPatternRewriter &rewriter)
const override {
559 auto width = op.getType().getIntOrFloatBitWidth();
561 return rewriter.notifyMatchFailure(op.getLoc(),
562 "i0 signed shift is unsupported");
563 auto lhs = adaptor.getLhs();
566 rewriter.createOrFold<
comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
571 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
574 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
583 rewriter.replaceOp(op, result);
595struct ConvertCombToAIGPass
596 :
public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
597 void runOnOperation()
override;
598 using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
599 using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
606 CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
607 CombMuxOpConversion, CombParityOpConversion,
609 CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
610 CombICmpOpConversion,
612 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
614 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
615 CombLowerVariadicOp<MulOp>>(
patterns.getContext());
618void ConvertCombToAIGPass::runOnOperation() {
619 ConversionTarget target(getContext());
622 target.addIllegalDialect<comb::CombDialect>();
632 hw::AggregateConstantOp>();
635 target.addLegalDialect<aig::AIGDialect>();
638 if (!additionalLegalOps.empty())
639 for (
const auto &opName : additionalLegalOps)
640 target.addLegalOp(OperationName(opName, &getContext()));
642 RewritePatternSet
patterns(&getContext());
645 if (failed(mlir::applyPartialConversion(getOperation(), target,
647 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns)
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.