12#include "mlir/Analysis/TopologicalSortUtils.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
19#define GEN_PASS_DEF_CONVERTHWTOSMT
20#include "circt/Conversion/Passes.h.inc"
// Replaces a ConstantOp with an mlir::smt::BVConstantOp.
// Constants narrower than one bit are rejected with a match failure; this is
// consistent with the IntegerType conversion, which also special-cases
// widths <= 0.
36 matchAndRewrite(
ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter)
const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<mlir::smt::BVConstantOp>(op,
// Converts an HWModuleOp into a func::FuncOp carrying the same symbol name.
// Port types are mapped through the type converter. The body region's block
// argument types are converted, and the module body is then moved into the
// new function.
52 matchAndRewrite(
HWModuleOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override {
54 auto funcTy = op.getModuleType().getFuncType();
55 SmallVector<Type> inputTypes, resultTypes;
// Each conversion below is guarded by failed().
// NOTE(review): the failure branches are missing from this listing; confirm
// they return failure().
56 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
58 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
60 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
62 auto funcOp = mlir::func::FuncOp::create(
63 rewriter, op.getLoc(), adaptor.getSymNameAttr(),
64 rewriter.getFunctionType(inputTypes, resultTypes));
// Move (rather than clone) the module's blocks to the end of the new
// function's body.
65 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
// Replaces the module terminator (OutputOp) with a func.return of the
// converted output operands.
76 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter)
const override {
78 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
// Replaces an InstanceOp with a func.call to the symbol named by the
// instance's module-name attribute. It presumably relies on the referenced
// module having been turned into a func.func of the same name by the module
// conversion pattern.
88 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter)
const override {
90 SmallVector<Type> resultTypes;
// Every result type must be convertible.
// NOTE(review): the failure branch is missing from this listing.
91 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
94 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
95 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
107 ConversionPatternRewriter &rewriter)
const override {
// Builds an SMT array value for an array-create op. It starts from a fresh,
// unconstrained array (DeclareFunOp) and stores every input element into it.
108 Location loc = op.getLoc();
109 Type arrTy = typeConverter->convertType(op.getType());
111 return rewriter.notifyMatchFailure(op.getLoc(),
"unsupported array type");
// One array element per operand. This count also fixes the index bit width.
113 unsigned width = adaptor.getInputs().size();
115 Value arr = mlir::smt::DeclareFunOp::create(rewriter, loc, arrTy);
// Operand i is stored at index (width - i - 1), so the first operand ends up
// at the highest index.
// The index width, Log2_64_Ceil(width), matches the domain width that the
// ArrayType conversion chooses from the element count (assuming width equals
// that count).
// NOTE(review): for a single-element array Log2_64_Ceil(1) == 0, i.e. a
// 0-bit index. Confirm this case is rejected or handled elsewhere.
116 for (
auto [i, el] :
llvm::enumerate(adaptor.getInputs())) {
117 Value idx = mlir::smt::BVConstantOp::create(rewriter, loc, width - i - 1,
118 llvm::Log2_64_Ceil(width));
119 arr = mlir::smt::ArrayStoreOp::create(rewriter, loc, arr, idx, el);
122 rewriter.replaceOp(op, arr);
// Lowers an array read to ite(inBounds, select(array, index), ...).
// oobVal is a fresh unconstrained value of the element type, intended for
// the out-of-bounds case. The else operand is cut off in this listing;
// presumably it is oobVal.
132 matchAndRewrite(
ArrayGetOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter)
const override {
134 Location loc = op.getLoc();
136 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
138 Type type = typeConverter->convertType(op.getType());
140 return rewriter.notifyMatchFailure(op.getLoc(),
141 "unsupported array element type");
143 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, type);
// Bounds check: the index is compared unsigned (`ule`) against a bit-vector
// constant whose value is on a line missing from this listing.
// NOTE(review): with `ule` that constant must be numElements - 1 for the
// check to be exact; confirm.
144 Value numElementsVal = mlir::smt::BVConstantOp::create(
146 Value inBounds = mlir::smt::BVCmpOp::create(
147 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
149 Value indexed = mlir::smt::ArraySelectOp::create(
150 rewriter, loc, adaptor.getInput(), adaptor.getIndex());
151 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, indexed,
// Lowers an array write to ite(inBounds, store(array, index, element), oobVal).
// oobVal is a fresh, unconstrained array of the converted array type, so an
// out-of-bounds inject yields an arbitrary array.
162 matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
163 ConversionPatternRewriter &rewriter)
const override {
164 Location loc = op.getLoc();
166 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
168 Type arrType = typeConverter->convertType(op.getType());
170 return rewriter.notifyMatchFailure(op.getLoc(),
"unsupported array type");
172 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, arrType);
// Bounds check: the index is compared unsigned (`ule`) against a bit-vector
// constant whose operands are missing from this listing.
// NOTE(review): presumably that constant is numElements - 1; confirm.
174 Value numElementsVal = mlir::smt::BVConstantOp::create(
176 Value inBounds = mlir::smt::BVCmpOp::create(
177 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
181 Value stored = mlir::smt::ArrayStoreOp::create(
182 rewriter, loc, adaptor.getInput(), adaptor.getIndex(),
183 adaptor.getElement());
186 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, stored, oobVal);
// Generic pattern that replaces an op with its converted operands, unchanged.
// It is registered for seq::ToClockOp and seq::FromClockOp. Both sides of
// those casts convert to a 1-bit bit-vector: i1 via the IntegerType
// conversion, and !seq.clock via the ClockType conversion. Forwarding the
// operand is therefore type-correct.
192template <
typename OpTy>
195 using OpAdaptor =
typename OpTy::Adaptor;
198 matchAndRewrite(OpTy op, OpAdaptor adaptor,
199 ConversionPatternRewriter &rewriter)
const override {
200 rewriter.replaceOp(op, adaptor.getOperands());
// Pass that lowers HW dialect ops (plus the seq clock casts) to the SMT and
// func dialects. The base class comes from the tablegen-generated pass
// definitions enabled by GEN_PASS_DEF_CONVERTHWTOSMT.
212struct ConvertHWToSMTPass
213 :
public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
214 void runOnOperation()
override;
// IntegerType of width W maps to !smt.bv<W>.
// Widths <= 0 are special-cased (the branch body is missing from this
// listing), since SMT-LIB bit-vectors must have a positive width.
226 converter.addConversion([](IntegerType type) -> std::optional<Type> {
227 if (type.getWidth() <= 0)
229 return mlir::smt::BitVectorType::get(type.getContext(), type.getWidth());
// A clock is modelled as a single bit.
231 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
232 return mlir::smt::BitVectorType::get(type.getContext(), 1);
// An HW array of N elements maps to an SMT array. Its domain is
// !smt.bv<Log2_64_Ceil(N)> and its range is the converted element type.
// The check on the element conversion result is on lines missing here.
// NOTE(review): N == 1 yields a 0-bit domain; confirm this is handled.
234 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
235 auto rangeType = converter.convertType(type.getElementType());
238 auto domainType = mlir::smt::BitVectorType::get(
239 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
240 return mlir::smt::ArrayType::get(type.getContext(), domainType, rangeType);
// Generic target materialization: bridge any mismatch with an
// unrealized_conversion_cast.
// NOTE: MLIR's TypeConverter tries the most recently added materializations
// first, so this acts as the fallback for the specific ones added below.
245 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
247 Location loc) -> Value {
248 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
// Converts !smt.bool to !smt.bv<W>. It requires exactly one input of
// !smt.bool type, builds the constants 0 and 1 at the requested width, and
// selects between them with ite on the boolean. Other inputs are declined
// (the early-return bodies are missing from this listing).
254 converter.addTargetMaterialization(
255 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
256 ValueRange inputs, Location loc) -> Value {
257 if (inputs.size() != 1)
260 if (!isa<mlir::smt::BoolType>(inputs[0].getType()))
263 unsigned width = resultType.getWidth();
265 mlir::smt::BVConstantOp::create(builder, loc, 0, width);
267 mlir::smt::BVConstantOp::create(builder, loc, 1, width);
268 return mlir::smt::IteOp::create(builder, loc, inputs[0], constOne,
// Converts an i1 to !smt.bv<1> when that i1 is produced by an
// unrealized_conversion_cast of a single !smt.bool. It looks through the cast
// and emits ite(bool, 1, 0) directly on the original boolean.
274 converter.addTargetMaterialization(
275 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
276 ValueRange inputs, Location loc) -> Value {
277 if (inputs.size() != 1 || resultType.getWidth() != 1)
280 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
281 if (!intType || intType.getWidth() != 1)
285 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
286 if (!castOp || castOp.getInputs().size() != 1)
289 if (!isa<mlir::smt::BoolType>(castOp.getInputs()[0].getType()))
292 Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, 1);
293 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
294 return mlir::smt::IteOp::create(builder, loc, castOp.getInputs()[0],
295 constOne, constZero);
// Converts !smt.bv<1> to !smt.bool as (x == 1). Inputs of any other width
// are declined.
299 converter.addTargetMaterialization(
300 [&](OpBuilder &builder, mlir::smt::BoolType resultType, ValueRange inputs,
301 Location loc) -> Value {
302 if (inputs.size() != 1)
305 auto bvType = dyn_cast<mlir::smt::BitVectorType>(inputs[0].getType());
306 if (!bvType || bvType.getWidth() != 1)
309 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
310 return mlir::smt::EqOp::create(builder, loc, inputs[0], constOne);
// Source materialization: cast back to the original type with an
// unrealized_conversion_cast.
315 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
317 Location loc) -> Value {
318 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
// Register every HW-to-SMT rewrite pattern. This includes the pass-through
// ReplaceWithInput patterns for the seq clock casts.
326 patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
327 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
328 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
329 ArrayGetOpConversion, ArrayInjectOpConversion>(
// Pass entry point: runs in two phases.
// 1. A partial conversion in which all HW dialect ops and the seq clock
//    casts are illegal, while SMT and func ops are legal.
// 2. Each converted function is reordered so definitions precede uses.
333void ConvertHWToSMTPass::runOnOperation() {
334 ConversionTarget target(getContext());
335 target.addIllegalDialect<hw::HWDialect>();
336 target.addIllegalOp<seq::FromClockOp>();
337 target.addIllegalOp<seq::ToClockOp>();
338 target.addLegalDialect<mlir::smt::SMTDialect>();
339 target.addLegalDialect<mlir::func::FuncDialect>();
341 RewritePatternSet
patterns(&getContext());
342 TypeConverter converter;
346 if (failed(mlir::applyPartialConversion(getOperation(), target,
348 return signalPassFailure();
// Functions whose body does not have exactly one block are filtered out. The
// branch body is missing from this listing; presumably it is `continue`.
// sortTopologically then reorders the single block's ops so that every value
// is defined before it is used.
// NOTE(review): presumably needed because the conversion can leave ops out of
// def-use order; confirm.
353 for (
auto func : getOperation().getOps<
mlir::func::FuncOp>()) {
356 if (func.getBody().getBlocks().size() != 1)
359 mlir::sortTopologically(&func.getBody().front());
MlirType uint64_t numElements
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.