13 #include "mlir/Analysis/TopologicalSortUtils.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
19 #define GEN_PASS_DEF_CONVERTHWTOSMT
20 #include "circt/Conversion/Passes.h.inc"
23 using namespace circt;
// Lowers a ConstantOp to an smt::BVConstantOp carrying the same value
// (adaptor.getValue()). Zero-width constants are refused with a match failure
// instead of being converted.
// NOTE(review): this is a rendered-listing extract. The leading numbers
// ("36", "37", ...) are original source line numbers, not code. The listing
// stops at line 41, so the trailing `return success();` and the closing brace
// are not visible here. TODO restore from the upstream file.
36 matchAndRewrite(
ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter)
const override {
// Reject constants whose value has no bits; no bit-vector is emitted for them.
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
// Replace the op in place with a bit-vector constant of the same value.
41 rewriter.replaceOpWithNewOp<smt::BVConstantOp>(op, adaptor.getValue());
// Lowers an HWModuleOp to a func::FuncOp with the same symbol name.
// - The module's input and result port types are taken from its function-type
//   view and run through the pattern's type converter.
// - The block signatures of the module body are converted to the new types.
// - The body region is moved into the newly created function.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. Lines 56, 58 and 60 are not shown; these are presumably the early
// exits taken when each failed(...) check fires. Everything after line 64 is
// also missing, e.g. how the original module op is disposed of.
// TODO restore from the upstream file.
51 matchAndRewrite(
HWModuleOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter)
const override {
53 auto funcTy = op.getModuleType().getFuncType();
54 SmallVector<Type> inputTypes, resultTypes;
// Convert input port types (the failure exit on line 56 is not shown).
55 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
// Convert result port types (the failure exit on line 58 is not shown).
57 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
// Rewrite the body's block-argument types to the converted types.
59 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
// Create the replacement function, reusing the module's symbol name.
61 auto funcOp = rewriter.create<mlir::func::FuncOp>(
62 op.getLoc(), adaptor.getSymNameAttr(),
63 rewriter.getFunctionType(inputTypes, resultTypes));
// Move (not clone) the module body to the end of the new function's body.
64 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
// Lowers an OutputOp terminator to func::ReturnOp. The op returns the
// adaptor's outputs, i.e. the already type-converted operand values.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. The trailing `return success();` and closing brace are not visible.
75 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
76 ConversionPatternRewriter &rewriter)
const override {
77 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
// Lowers an InstanceOp to a func::CallOp on the symbol named by the instance's
// module-name attribute. The call's result types are the type-converted
// instance result types, and its arguments are the adaptor's (converted) inputs.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. Lines 91-92 are not shown; they are presumably the failure exit for
// the type-conversion check. The trailing `return success();` and closing brace
// are not visible either.
87 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter)
const override {
89 SmallVector<Type> resultTypes;
90 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
93 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
94 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
// Lowers an array-create op to a chain of smt::ArrayStoreOp.
// - Starts from a fresh array symbol declared with smt::DeclareFunOp.
// - Stores each input operand into that array.
// - The result is the last store in the chain.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. The head of the signature (line 105) is missing. The guard on
// line 109 that selects the failure path is also missing; presumably it checks
// that `arrTy` converted successfully. Lines 111, 113, 119-120 and everything
// after 121 are not shown. TODO restore from the upstream file.
106 ConversionPatternRewriter &rewriter)
const override {
107 Location loc = op.getLoc();
108 Type arrTy = typeConverter->convertType(op.getType());
110 return rewriter.notifyMatchFailure(op.getLoc(),
"unsupported array type");
// `width` is the number of array elements, not a bit width.
112 unsigned width = adaptor.getInputs().size();
114 Value arr = rewriter.
create<smt::DeclareFunOp>(loc, arrTy);
// Operand i is stored at index (width - i - 1), so the first operand lands at
// the highest index. Index constants are ceil(log2(width)) bits wide.
// NOTE(review): any index value >= width is never written by this loop and
// keeps whatever value the declared array symbol has. This happens whenever
// width is not a power of two.
// NOTE(review): llvm::Log2_64_Ceil(1) is 0, so a one-element array would get a
// zero-width index constant. TODO confirm this case is rejected or handled
// elsewhere.
115 for (
auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
116 Value idx = rewriter.create<smt::BVConstantOp>(loc, width - i - 1,
117 llvm::Log2_64_Ceil(width));
118 arr = rewriter.create<smt::ArrayStoreOp>(loc, arr, idx, el);
// Replace the op with the fully populated array value.
121 rewriter.replaceOp(op, arr);
// Lowers an ArrayGetOp to smt::IteOp(inBounds, select(array, index), oobVal).
// An in-bounds index reads the array element. An out-of-bounds index yields
// oobVal, a fresh symbol declared with smt::DeclareFunOp of the result type.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. Line 134 is missing; it presumably declares the `numElements` local
// that line 135 initializes. Lines 136, 138 (the guard for the failure return),
// 141, 144 and 148 are also missing, as is everything after line 149.
// TODO restore from the upstream file.
131 matchAndRewrite(
ArrayGetOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter)
const override {
133 Location loc = op.getLoc();
// The element count comes from the original (unconverted) hw::ArrayType.
135 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
137 Type type = typeConverter->convertType(op.getType());
139 return rewriter.notifyMatchFailure(op.getLoc(),
140 "unsupported array element type");
// Fresh value returned for out-of-bounds reads.
142 Value oobVal = rewriter.create<smt::DeclareFunOp>(loc, type);
// NOTE(review): the arguments of this constant are on line 144, which is not
// shown. The `ule` comparison below is only a correct bounds check if the
// constant is numElements - 1, not numElements. Confirm against the real file.
143 Value numElementsVal = rewriter.create<smt::BVConstantOp>(
145 Value inBounds = rewriter.create<smt::BVCmpOp>(
146 loc, smt::BVCmpPredicate::ule, adaptor.getIndex(), numElementsVal);
147 Value indexed = rewriter.create<smt::ArraySelectOp>(loc, adaptor.getInput(),
149 rewriter.replaceOpWithNewOp<smt::IteOp>(op, inBounds, indexed, oobVal);
// Generic pattern that erases OpTy by replacing its results with its
// (type-converted) operands. This makes the op a pass-through after conversion.
// It is instantiated for the seq clock conversion ops.
// NOTE(review): replaceOp requires one replacement value per result. This
// assumes every OpTy instantiation has as many operands as results; TODO
// confirm for each use.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. The struct declaration and base class (lines 156-157), lines
// 159-160, and the trailing `return success();` / closing braces are not
// visible.
155 template <
typename OpTy>
158 using OpAdaptor =
typename OpTy::Adaptor;
161 matchAndRewrite(OpTy op, OpAdaptor adaptor,
162 ConversionPatternRewriter &rewriter)
const override {
163 rewriter.replaceOp(op, adaptor.getOperands());
// Pass driver for the HW-to-SMT conversion. The base class is produced by the
// pass TableGen definitions pulled in through GEN_PASS_DEF_CONVERTHWTOSMT.
// runOnOperation is defined out of line.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. The closing `};` of the struct is not visible.
175 struct ConvertHWToSMTPass
176 :
public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
177 void runOnOperation()
override;
// Body of the HW-to-SMT type-converter setup. It registers type conversions
// (integer, clock, array) and materializations that bridge smt.bool,
// single-bit bit-vectors and i1 values.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. The function header and many lines are not shown. This includes
// most conversion results and every early `return` inside the materialization
// lambdas, so adjacent `if` lines below are not actually nested.
// TODO restore from the upstream file.
// Integers: zero-width types are rejected. getWidth() is unsigned, so `<= 0`
// means `== 0`. The successful result (lines 191-192) is not shown.
189 converter.addConversion([](IntegerType type) -> std::optional<Type> {
190 if (type.getWidth() <= 0)
// Clock type: the converted type (line 195) is not shown.
194 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
// Arrays: the element type is converted to become the SMT array range type.
// The domain is a bit-vector of ceil(log2(numElements)) bits.
197 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
198 auto rangeType = converter.convertType(type.getElementType());
202 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
// Generic fallback target materialization: emits an
// UnrealizedConversionCastOp to the requested type.
// NOTE(review): MLIR tries materializations most-recently-registered first.
// That would make this first-registered cast the last resort. TODO confirm
// for the MLIR version in use.
208 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
210 Location loc) -> Value {
212 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
// smt.bool -> bit-vector of any width: ite(b, 1, 0) at the result width.
// Applies only to a single input of smt::BoolType.
217 converter.addTargetMaterialization(
218 [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
219 Location loc) -> Value {
220 if (inputs.size() != 1)
223 if (!isa<smt::BoolType>(inputs[0].getType()))
226 unsigned width = resultType.getWidth();
227 Value constZero = builder.create<smt::BVConstantOp>(loc, 0, width);
228 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, width);
229 return builder.create<smt::IteOp>(loc, inputs[0], constOne, constZero);
// i1 -> 1-bit bit-vector, where the i1 is itself produced by an
// UnrealizedConversionCastOp from a single smt.bool. This looks through the
// cast and rebuilds ite(b, 1, 0) directly on the underlying bool.
234 converter.addTargetMaterialization(
235 [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
236 Location loc) -> Value {
237 if (inputs.size() != 1 || resultType.getWidth() != 1)
240 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
241 if (!intType || intType.getWidth() != 1)
245 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
246 if (!castOp || castOp.getInputs().size() != 1)
249 if (!isa<smt::BoolType>(castOp.getInputs()[0].getType()))
252 Value constZero = builder.create<smt::BVConstantOp>(loc, 0, 1);
253 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
254 return builder.create<smt::IteOp>(loc, castOp.getInputs()[0], constOne,
// 1-bit bit-vector -> smt.bool: the result is (input == 1).
259 converter.addTargetMaterialization(
260 [&](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs,
261 Location loc) -> Value {
262 if (inputs.size() != 1)
265 auto bvType = dyn_cast<smt::BitVectorType>(inputs[0].getType());
266 if (!bvType || bvType.getWidth() != 1)
269 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
270 return builder.create<smt::EqOp>(loc, inputs[0], constOne);
// Source materialization: converted values are mapped back to the original
// type with an UnrealizedConversionCastOp.
275 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
277 Location loc) -> Value {
279 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
// Registers every HW-to-SMT rewrite pattern. Each pattern is constructed with
// the shared type converter and the pattern set's context. seq::ToClockOp and
// seq::FromClockOp reuse the generic operand-forwarding pattern.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. The function header and closing brace are not visible.
286 patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
287 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
288 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
289 ArrayGetOpConversion>(converter,
patterns.getContext());
// Runs the HW-to-SMT conversion on the pass's root operation.
// - The whole HW dialect and the seq clock conversion ops are marked illegal.
// - The SMT and func dialects are marked legal.
// - A partial conversion is applied; the pass fails if it fails.
// - The single-block body of each top-level func::FuncOp is then sorted
//   topologically.
// NOTE(review): rendered-listing extract; leading numbers are source line
// numbers. Lines 299, 302-304 (presumably filling `converter` and `patterns`),
// 306, 308-311, 313-314 and 316-317 are not shown, nor is anything after 318.
// TODO restore from the upstream file.
292 void ConvertHWToSMTPass::runOnOperation() {
293 ConversionTarget target(getContext());
294 target.addIllegalDialect<hw::HWDialect>();
295 target.addIllegalOp<seq::FromClockOp>();
296 target.addIllegalOp<seq::ToClockOp>();
297 target.addLegalDialect<smt::SMTDialect>();
298 target.addLegalDialect<mlir::func::FuncDialect>();
300 RewritePatternSet
patterns(&getContext());
301 TypeConverter converter;
// Partial conversion: ops not marked illegal may survive untouched.
305 if (failed(mlir::applyPartialConversion(getOperation(), target,
307 return signalPassFailure();
// Only funcs directly nested in the root op are visited (getOps). Presumably
// the sort is needed because the moved module bodies may not be in def-before-
// use order, which func.func requires. TODO confirm.
312 for (
auto func : getOperation().getOps<mlir::func::FuncOp>()) {
// Multi-block bodies take the path on lines 316-317, which is not shown.
315 if (func.getBody().getBlocks().size() != 1)
// NOTE(review): sortTopologically's bool result (whether every op could be
// ordered) is ignored on this visible line. A cyclic body would therefore
// pass through partially unsorted. Confirm whether later lines handle this.
318 mlir::sortTopologically(&func.getBody().front());
MlirType uint64_t numElements
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.