13 #include "mlir/Analysis/TopologicalSortUtils.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
19 #define GEN_PASS_DEF_CONVERTHWTOSMT
20 #include "circt/Conversion/Passes.h.inc"
23 using namespace circt;
36 matchAndRewrite(
ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter)
const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<smt::BVConstantOp>(op, adaptor.getValue());
51 matchAndRewrite(
HWModuleOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter)
const override {
53 auto funcTy = op.getModuleType().getFuncType();
54 SmallVector<Type> inputTypes, resultTypes;
55 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
57 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
59 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
61 auto funcOp = rewriter.create<mlir::func::FuncOp>(
62 op.getLoc(), adaptor.getSymNameAttr(),
63 rewriter.getFunctionType(inputTypes, resultTypes));
64 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
75 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
76 ConversionPatternRewriter &rewriter)
const override {
77 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
87 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter)
const override {
89 SmallVector<Type> resultTypes;
90 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
93 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
94 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
100 template <
typename OpTy>
103 using OpAdaptor =
typename OpTy::Adaptor;
106 matchAndRewrite(OpTy op, OpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter)
const override {
108 rewriter.replaceOp(op, adaptor.getOperands());
120 struct ConvertHWToSMTPass
121 :
public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
122 void runOnOperation()
override;
134 converter.addConversion([](IntegerType type) -> std::optional<Type> {
135 if (type.getWidth() <= 0)
139 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
145 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
147 Location loc) -> std::optional<Value> {
149 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
154 converter.addTargetMaterialization(
155 [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
156 Location loc) -> std::optional<Value> {
157 if (inputs.size() != 1)
160 if (!isa<smt::BoolType>(inputs[0].getType()))
163 unsigned width = resultType.getWidth();
164 Value constZero = builder.create<smt::BVConstantOp>(loc, 0,
width);
165 Value constOne = builder.create<smt::BVConstantOp>(loc, 1,
width);
166 return builder.create<smt::IteOp>(loc, inputs[0], constOne, constZero);
171 converter.addTargetMaterialization(
172 [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
173 Location loc) -> std::optional<Value> {
174 if (inputs.size() != 1 || resultType.getWidth() != 1)
177 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
178 if (!intType || intType.getWidth() != 1)
182 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
183 if (!castOp || castOp.getInputs().size() != 1)
186 if (!isa<smt::BoolType>(castOp.getInputs()[0].getType()))
189 Value constZero = builder.create<smt::BVConstantOp>(loc, 0, 1);
190 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
191 return builder.create<smt::IteOp>(loc, castOp.getInputs()[0], constOne,
196 converter.addTargetMaterialization(
197 [&](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs,
198 Location loc) -> std::optional<Value> {
199 if (inputs.size() != 1)
202 auto bvType = dyn_cast<smt::BitVectorType>(inputs[0].getType());
203 if (!bvType || bvType.getWidth() != 1)
206 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
207 return builder.create<smt::EqOp>(loc, inputs[0], constOne);
212 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
214 Location loc) -> std::optional<Value> {
216 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
223 patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
224 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
225 ReplaceWithInput<seq::FromClockOp>>(converter,
229 void ConvertHWToSMTPass::runOnOperation() {
230 ConversionTarget target(getContext());
231 target.addIllegalDialect<hw::HWDialect>();
232 target.addIllegalOp<seq::FromClockOp>();
233 target.addIllegalOp<seq::ToClockOp>();
234 target.addLegalDialect<smt::SMTDialect>();
235 target.addLegalDialect<mlir::func::FuncDialect>();
237 RewritePatternSet
patterns(&getContext());
238 TypeConverter converter;
242 if (failed(mlir::applyPartialConversion(getOperation(), target,
244 return signalPassFailure();
249 for (
auto func : getOperation().getOps<mlir::func::FuncOp>()) {
252 if (func.getBody().getBlocks().size() != 1)
255 mlir::sortTopologically(&func.getBody().front());
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.