17 #include "mlir/Pass/Pass.h"
18 #include "mlir/Transforms/DialectConversion.h"
21 #define GEN_PASS_DEF_CONVERTCOMBTOAIG
22 #include "circt/Conversion/Passes.h.inc"
25 using namespace circt;
33 static SmallVector<Value>
extractBits(ConversionPatternRewriter &rewriter,
35 assert(val.getType().isInteger() &&
"expected integer");
36 auto width = val.getType().getIntOrFloatBitWidth();
37 SmallVector<Value> bits;
42 if (
concat.getNumOperands() == width &&
43 llvm::all_of(
concat.getOperandTypes(), [](Type type) {
44 return type.getIntOrFloatBitWidth() == 1;
47 bits.append(std::make_reverse_iterator(
concat.getOperands().end()),
48 std::make_reverse_iterator(
concat.getOperands().begin()));
54 for (int64_t i = 0; i < width; ++i)
72 matchAndRewrite(
AndOp op, OpAdaptor adaptor,
73 ConversionPatternRewriter &rewriter)
const override {
74 SmallVector<bool> nonInverts(adaptor.getInputs().size(),
false);
75 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
86 matchAndRewrite(
OrOp op, OpAdaptor adaptor,
87 ConversionPatternRewriter &rewriter)
const override {
89 SmallVector<bool> allInverts(adaptor.getInputs().size(),
true);
90 auto andOp = rewriter.create<aig::AndInverterOp>(
91 op.getLoc(), adaptor.getInputs(), allInverts);
92 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
103 matchAndRewrite(
XorOp op, OpAdaptor adaptor,
104 ConversionPatternRewriter &rewriter)
const override {
105 if (op.getNumOperands() != 2)
111 auto inputs = adaptor.getInputs();
112 SmallVector<bool> allInverts(inputs.size(),
true);
113 SmallVector<bool> allNotInverts(inputs.size(),
false);
116 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
118 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
120 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
127 template <
typename OpTy>
132 matchAndRewrite(OpTy op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
135 rewriter.replaceOp(op, result);
140 ConversionPatternRewriter &rewriter) {
142 switch (operands.size()) {
144 assert(
false &&
"cannot be called with empty operand range");
151 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
153 auto firstHalf = operands.size() / 2;
158 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs},
true);
168 matchAndRewrite(
MuxOp op, OpAdaptor adaptor,
169 ConversionPatternRewriter &rewriter)
const override {
172 Value cond = op.getCond();
173 auto trueVal = op.getTrueValue();
174 auto falseVal = op.getFalseValue();
176 if (!op.getType().isInteger()) {
178 auto widthType = rewriter.getIntegerType(
hw::getBitWidth(op.getType()));
180 rewriter.create<
hw::BitcastOp>(op->getLoc(), widthType, trueVal);
186 if (!trueVal.getType().isInteger(1))
187 cond = rewriter.
create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
191 auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
192 auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
195 Value result = rewriter.create<
comb::OrOp>(op.getLoc(), lhs, rhs);
197 if (result.getType() != op.getType())
199 rewriter.create<
hw::BitcastOp>(op.getLoc(), op.getType(), result);
200 rewriter.replaceOp(op, result);
208 matchAndRewrite(
AddOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter)
const override {
210 auto inputs = adaptor.getInputs();
213 if (inputs.size() != 2)
216 auto width = op.getType().getIntOrFloatBitWidth();
228 SmallVector<Value> results;
229 results.resize(width);
230 for (int64_t i = 0; i < width; ++i) {
231 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
233 xorOperands.push_back(carry);
237 results[width - i - 1] =
238 rewriter.create<
comb::XorOp>(op.getLoc(), xorOperands,
true);
241 if (i == width - 1) {
247 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
255 op.getLoc(), ValueRange{aBits[i], bBits[i]},
true);
257 op.getLoc(), ValueRange{carry, aXnorB},
true);
258 carry = rewriter.create<
comb::OrOp>(op.getLoc(),
259 ValueRange{andOp, nextCarry},
true);
270 matchAndRewrite(
SubOp op, OpAdaptor adaptor,
271 ConversionPatternRewriter &rewriter)
const override {
272 auto lhs = op.getLhs();
273 auto rhs = op.getRhs();
277 auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
279 auto one = rewriter.create<
hw::ConstantOp>(op.getLoc(), op.getType(), 1);
280 rewriter.replaceOpWithNewOp<
comb::AddOp>(op, ValueRange{lhs, notRhs, one},
293 struct ConvertCombToAIGPass
294 :
public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
295 void runOnOperation()
override;
296 using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
297 using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
304 CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
307 CombAddOpConversion, CombSubOpConversion,
309 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>>(
313 void ConvertCombToAIGPass::runOnOperation() {
314 ConversionTarget target(getContext());
315 target.addIllegalDialect<comb::CombDialect>();
319 target.addLegalDialect<aig::AIGDialect>();
322 if (!additionalLegalOps.empty())
323 for (
const auto &opName : additionalLegalOps)
324 target.addLegalOp(OperationName(opName, &getContext()));
326 RewritePatternSet
patterns(&getContext());
329 if (failed(mlir::applyPartialConversion(getOperation(), target,
331 return signalPassFailure();
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns)
static SmallVector< Value > extractBits(ConversionPatternRewriter &rewriter, Value val)
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
def create(data_type, value)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.