CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWToSMT.cpp
Go to the documentation of this file.
1//===- HWToSMT.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Analysis/TopologicalSortUtils.h"
14#include "mlir/Dialect/Func/IR/FuncOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTHWTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace circt;
24using namespace hw;
25
26//===----------------------------------------------------------------------===//
27// Conversion patterns
28//===----------------------------------------------------------------------===//
29
30namespace {
31/// Lower a hw::ConstantOp operation to smt::BVConstantOp
32struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34
35 LogicalResult
36 matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<smt::BVConstantOp>(op, adaptor.getValue());
42 return success();
43 }
44};
45
46/// Lower a hw::HWModuleOp operation to func::FuncOp.
47struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
49
50 LogicalResult
51 matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
52 ConversionPatternRewriter &rewriter) const override {
53 auto funcTy = op.getModuleType().getFuncType();
54 SmallVector<Type> inputTypes, resultTypes;
55 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
56 return failure();
57 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
58 return failure();
59 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
60 return failure();
61 auto funcOp = rewriter.create<mlir::func::FuncOp>(
62 op.getLoc(), adaptor.getSymNameAttr(),
63 rewriter.getFunctionType(inputTypes, resultTypes));
64 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
65 rewriter.eraseOp(op);
66 return success();
67 }
68};
69
70/// Lower a hw::OutputOp operation to func::ReturnOp.
71struct OutputOpConversion : OpConversionPattern<OutputOp> {
73
74 LogicalResult
75 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
76 ConversionPatternRewriter &rewriter) const override {
77 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
78 return success();
79 }
80};
81
82/// Lower a hw::InstanceOp operation to func::CallOp.
83struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
85
86 LogicalResult
87 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
88 ConversionPatternRewriter &rewriter) const override {
89 SmallVector<Type> resultTypes;
90 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
91 return failure();
92
93 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
94 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
95 return success();
96 }
97};
98
99/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
100/// smt::ArrayStoreOp for each operand.
101struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
103
104 LogicalResult
105 matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
106 ConversionPatternRewriter &rewriter) const override {
107 Location loc = op.getLoc();
108 Type arrTy = typeConverter->convertType(op.getType());
109 if (!arrTy)
110 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
111
112 unsigned width = adaptor.getInputs().size();
113
114 Value arr = rewriter.create<smt::DeclareFunOp>(loc, arrTy);
115 for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
116 Value idx = rewriter.create<smt::BVConstantOp>(loc, width - i - 1,
117 llvm::Log2_64_Ceil(width));
118 arr = rewriter.create<smt::ArrayStoreOp>(loc, arr, idx, el);
119 }
120
121 rewriter.replaceOp(op, arr);
122 return success();
123 }
124};
125
126/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
127struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
129
130 LogicalResult
131 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
132 ConversionPatternRewriter &rewriter) const override {
133 Location loc = op.getLoc();
134 unsigned numElements =
135 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
136
137 Type type = typeConverter->convertType(op.getType());
138 if (!type)
139 return rewriter.notifyMatchFailure(op.getLoc(),
140 "unsupported array element type");
141
142 Value oobVal = rewriter.create<smt::DeclareFunOp>(loc, type);
143 Value numElementsVal = rewriter.create<smt::BVConstantOp>(
144 loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
145 Value inBounds = rewriter.create<smt::BVCmpOp>(
146 loc, smt::BVCmpPredicate::ule, adaptor.getIndex(), numElementsVal);
147 Value indexed = rewriter.create<smt::ArraySelectOp>(loc, adaptor.getInput(),
148 adaptor.getIndex());
149 rewriter.replaceOpWithNewOp<smt::IteOp>(op, inBounds, indexed, oobVal);
150 return success();
151 }
152};
153
154/// Remove redundant (seq::FromClock and seq::ToClock) ops.
155template <typename OpTy>
156struct ReplaceWithInput : OpConversionPattern<OpTy> {
158 using OpAdaptor = typename OpTy::Adaptor;
159
160 LogicalResult
161 matchAndRewrite(OpTy op, OpAdaptor adaptor,
162 ConversionPatternRewriter &rewriter) const override {
163 rewriter.replaceOp(op, adaptor.getOperands());
164 return success();
165 }
166};
167
168} // namespace
169
170//===----------------------------------------------------------------------===//
171// Convert HW to SMT pass
172//===----------------------------------------------------------------------===//
173
174namespace {
175struct ConvertHWToSMTPass
176 : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
177 void runOnOperation() override;
178};
179} // namespace
180
181void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
182 // The semantics of the builtin integer at the CIRCT core level is currently
183 // not very well defined. It is used for two-valued, four-valued, and possible
184 // other multi-valued logic. Here, we interpret it as two-valued for now.
185 // From a formal perspective, CIRCT would ideally define its own types for
186 // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
187 // the integer type also carries poison information (which we don't have in
188 // CIRCT?).
189 converter.addConversion([](IntegerType type) -> std::optional<Type> {
190 if (type.getWidth() <= 0)
191 return std::nullopt;
192 return smt::BitVectorType::get(type.getContext(), type.getWidth());
193 });
194 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
195 return smt::BitVectorType::get(type.getContext(), 1);
196 });
197 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
198 auto rangeType = converter.convertType(type.getElementType());
199 if (!rangeType)
200 return {};
201 auto domainType = smt::BitVectorType::get(
202 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
203 return smt::ArrayType::get(type.getContext(), domainType, rangeType);
204 });
205
206 // Default target materialization to convert from illegal types to legal
207 // types, e.g., at the boundary of an inlined child block.
208 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
209 ValueRange inputs,
210 Location loc) -> Value {
211 return builder
212 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
213 ->getResult(0);
214 });
215
216 // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
217 converter.addTargetMaterialization(
218 [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
219 Location loc) -> Value {
220 if (inputs.size() != 1)
221 return Value();
222
223 if (!isa<smt::BoolType>(inputs[0].getType()))
224 return Value();
225
226 unsigned width = resultType.getWidth();
227 Value constZero = builder.create<smt::BVConstantOp>(loc, 0, width);
228 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, width);
229 return builder.create<smt::IteOp>(loc, inputs[0], constOne, constZero);
230 });
231
232 // Convert an unrealized conversion cast from 'smt.bool' to i1
233 // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
234 converter.addTargetMaterialization(
235 [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
236 Location loc) -> Value {
237 if (inputs.size() != 1 || resultType.getWidth() != 1)
238 return Value();
239
240 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
241 if (!intType || intType.getWidth() != 1)
242 return Value();
243
244 auto castOp =
245 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
246 if (!castOp || castOp.getInputs().size() != 1)
247 return Value();
248
249 if (!isa<smt::BoolType>(castOp.getInputs()[0].getType()))
250 return Value();
251
252 Value constZero = builder.create<smt::BVConstantOp>(loc, 0, 1);
253 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
254 return builder.create<smt::IteOp>(loc, castOp.getInputs()[0], constOne,
255 constZero);
256 });
257
258 // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
259 converter.addTargetMaterialization(
260 [&](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs,
261 Location loc) -> Value {
262 if (inputs.size() != 1)
263 return Value();
264
265 auto bvType = dyn_cast<smt::BitVectorType>(inputs[0].getType());
266 if (!bvType || bvType.getWidth() != 1)
267 return Value();
268
269 Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
270 return builder.create<smt::EqOp>(loc, inputs[0], constOne);
271 });
272
273 // Default source materialization to convert from illegal types to legal
274 // types, e.g., at the boundary of an inlined child block.
275 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
276 ValueRange inputs,
277 Location loc) -> Value {
278 return builder
279 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
280 ->getResult(0);
281 });
282}
283
284void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
285 RewritePatternSet &patterns) {
286 patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
287 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
288 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
289 ArrayGetOpConversion>(converter, patterns.getContext());
290}
291
292void ConvertHWToSMTPass::runOnOperation() {
293 ConversionTarget target(getContext());
294 target.addIllegalDialect<hw::HWDialect>();
295 target.addIllegalOp<seq::FromClockOp>();
296 target.addIllegalOp<seq::ToClockOp>();
297 target.addLegalDialect<smt::SMTDialect>();
298 target.addLegalDialect<mlir::func::FuncDialect>();
299
300 RewritePatternSet patterns(&getContext());
301 TypeConverter converter;
304
305 if (failed(mlir::applyPartialConversion(getOperation(), target,
306 std::move(patterns))))
307 return signalPassFailure();
308
309 // Sort the functions topologically because 'hw.module' has a graph region
310 // while 'func.func' is a regular SSACFG region. Real combinational cycles or
311 // pseudo cycles through module instances are not supported yet.
312 for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
313 // Skip functions that are definitely not the result of lowering from
314 // 'hw.module'
315 if (func.getBody().getBlocks().size() != 1)
316 continue;
317
318 mlir::sortTopologically(&func.getBody().front());
319 }
320}
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
create(elements)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:284
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:181
Definition hw.py:1