CIRCT 22.0.0git
Loading...
Searching...
No Matches
HWToSMT.cpp
Go to the documentation of this file.
1//===- HWToSMT.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Analysis/TopologicalSortUtils.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTHWTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace circt;
24using namespace hw;
25
26//===----------------------------------------------------------------------===//
27// Conversion patterns
28//===----------------------------------------------------------------------===//
29
30namespace {
31/// Lower a hw::ConstantOp operation to smt::BVConstantOp
32struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34
35 LogicalResult
36 matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<mlir::smt::BVConstantOp>(op,
42 adaptor.getValue());
43 return success();
44 }
45};
46
47/// Lower a hw::HWModuleOp operation to func::FuncOp.
48struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
50
51 LogicalResult
52 matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 auto funcTy = op.getModuleType().getFuncType();
55 SmallVector<Type> inputTypes, resultTypes;
56 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
57 return failure();
58 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
59 return failure();
60 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
61 return failure();
62 auto funcOp = mlir::func::FuncOp::create(
63 rewriter, op.getLoc(), adaptor.getSymNameAttr(),
64 rewriter.getFunctionType(inputTypes, resultTypes));
65 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
66 rewriter.eraseOp(op);
67 return success();
68 }
69};
70
71/// Lower a hw::OutputOp operation to func::ReturnOp.
72struct OutputOpConversion : OpConversionPattern<OutputOp> {
74
75 LogicalResult
76 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const override {
78 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
79 return success();
80 }
81};
82
83/// Lower a hw::InstanceOp operation to func::CallOp.
84struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
86
87 LogicalResult
88 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter) const override {
90 SmallVector<Type> resultTypes;
91 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
92 return failure();
93
94 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
95 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
96 return success();
97 }
98};
99
100/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
101/// smt::ArrayStoreOp for each operand.
102struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
104
105 LogicalResult
106 matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter) const override {
108 Location loc = op.getLoc();
109 Type arrTy = typeConverter->convertType(op.getType());
110 if (!arrTy)
111 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
112
113 unsigned width = adaptor.getInputs().size();
114
115 Value arr = mlir::smt::DeclareFunOp::create(rewriter, loc, arrTy);
116 for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
117 Value idx = mlir::smt::BVConstantOp::create(rewriter, loc, width - i - 1,
118 llvm::Log2_64_Ceil(width));
119 arr = mlir::smt::ArrayStoreOp::create(rewriter, loc, arr, idx, el);
120 }
121
122 rewriter.replaceOp(op, arr);
123 return success();
124 }
125};
126
127/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
128struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
130
131 LogicalResult
132 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const override {
134 Location loc = op.getLoc();
135 unsigned numElements =
136 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
137
138 Type type = typeConverter->convertType(op.getType());
139 if (!type)
140 return rewriter.notifyMatchFailure(op.getLoc(),
141 "unsupported array element type");
142
143 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, type);
144 Value numElementsVal = mlir::smt::BVConstantOp::create(
145 rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
146 Value inBounds = mlir::smt::BVCmpOp::create(
147 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
148 numElementsVal);
149 Value indexed = mlir::smt::ArraySelectOp::create(
150 rewriter, loc, adaptor.getInput(), adaptor.getIndex());
151 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, indexed,
152 oobVal);
153 return success();
154 }
155};
156
157/// Lower a hw::ArrayInjectOp operation to smt::ArrayStoreOp.
158struct ArrayInjectOpConversion : OpConversionPattern<ArrayInjectOp> {
159 using OpConversionPattern<ArrayInjectOp>::OpConversionPattern;
160
161 LogicalResult
162 matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
163 ConversionPatternRewriter &rewriter) const override {
164 Location loc = op.getLoc();
165 unsigned numElements =
166 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
167
168 Type arrType = typeConverter->convertType(op.getType());
169 if (!arrType)
170 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
171
172 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, arrType);
173 // Check if the index is within bounds
174 Value numElementsVal = mlir::smt::BVConstantOp::create(
175 rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
176 Value inBounds = mlir::smt::BVCmpOp::create(
177 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
178 numElementsVal);
179
180 // Store the element at the given index
181 Value stored = mlir::smt::ArrayStoreOp::create(
182 rewriter, loc, adaptor.getInput(), adaptor.getIndex(),
183 adaptor.getElement());
184
185 // Return unbounded array if out of bounds
186 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, stored, oobVal);
187 return success();
188 }
189};
190
191/// Remove redundant (seq::FromClock and seq::ToClock) ops.
192template <typename OpTy>
193struct ReplaceWithInput : OpConversionPattern<OpTy> {
195 using OpAdaptor = typename OpTy::Adaptor;
196
197 LogicalResult
198 matchAndRewrite(OpTy op, OpAdaptor adaptor,
199 ConversionPatternRewriter &rewriter) const override {
200 rewriter.replaceOp(op, adaptor.getOperands());
201 return success();
202 }
203};
204
205} // namespace
206
207//===----------------------------------------------------------------------===//
208// Convert HW to SMT pass
209//===----------------------------------------------------------------------===//
210
211namespace {
212struct ConvertHWToSMTPass
213 : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
214 void runOnOperation() override;
215};
216} // namespace
217
218void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
219 // The semantics of the builtin integer at the CIRCT core level is currently
220 // not very well defined. It is used for two-valued, four-valued, and possible
221 // other multi-valued logic. Here, we interpret it as two-valued for now.
222 // From a formal perspective, CIRCT would ideally define its own types for
223 // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
224 // the integer type also carries poison information (which we don't have in
225 // CIRCT?).
226 converter.addConversion([](IntegerType type) -> std::optional<Type> {
227 if (type.getWidth() <= 0)
228 return std::nullopt;
229 return mlir::smt::BitVectorType::get(type.getContext(), type.getWidth());
230 });
231 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
232 return mlir::smt::BitVectorType::get(type.getContext(), 1);
233 });
234 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
235 auto rangeType = converter.convertType(type.getElementType());
236 if (!rangeType)
237 return {};
238 auto domainType = mlir::smt::BitVectorType::get(
239 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
240 return mlir::smt::ArrayType::get(type.getContext(), domainType, rangeType);
241 });
242
243 // Default target materialization to convert from illegal types to legal
244 // types, e.g., at the boundary of an inlined child block.
245 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
246 ValueRange inputs,
247 Location loc) -> Value {
248 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
249 inputs)
250 ->getResult(0);
251 });
252
253 // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
254 converter.addTargetMaterialization(
255 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
256 ValueRange inputs, Location loc) -> Value {
257 if (inputs.size() != 1)
258 return Value();
259
260 if (!isa<mlir::smt::BoolType>(inputs[0].getType()))
261 return Value();
262
263 unsigned width = resultType.getWidth();
264 Value constZero =
265 mlir::smt::BVConstantOp::create(builder, loc, 0, width);
266 Value constOne =
267 mlir::smt::BVConstantOp::create(builder, loc, 1, width);
268 return mlir::smt::IteOp::create(builder, loc, inputs[0], constOne,
269 constZero);
270 });
271
272 // Convert an unrealized conversion cast from 'smt.bool' to i1
273 // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
274 converter.addTargetMaterialization(
275 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
276 ValueRange inputs, Location loc) -> Value {
277 if (inputs.size() != 1 || resultType.getWidth() != 1)
278 return Value();
279
280 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
281 if (!intType || intType.getWidth() != 1)
282 return Value();
283
284 auto castOp =
285 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
286 if (!castOp || castOp.getInputs().size() != 1)
287 return Value();
288
289 if (!isa<mlir::smt::BoolType>(castOp.getInputs()[0].getType()))
290 return Value();
291
292 Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, 1);
293 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
294 return mlir::smt::IteOp::create(builder, loc, castOp.getInputs()[0],
295 constOne, constZero);
296 });
297
298 // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
299 converter.addTargetMaterialization(
300 [&](OpBuilder &builder, mlir::smt::BoolType resultType, ValueRange inputs,
301 Location loc) -> Value {
302 if (inputs.size() != 1)
303 return Value();
304
305 auto bvType = dyn_cast<mlir::smt::BitVectorType>(inputs[0].getType());
306 if (!bvType || bvType.getWidth() != 1)
307 return Value();
308
309 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
310 return mlir::smt::EqOp::create(builder, loc, inputs[0], constOne);
311 });
312
313 // Default source materialization to convert from illegal types to legal
314 // types, e.g., at the boundary of an inlined child block.
315 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
316 ValueRange inputs,
317 Location loc) -> Value {
318 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
319 inputs)
320 ->getResult(0);
321 });
322}
323
324void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
325 RewritePatternSet &patterns) {
326 patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
327 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
328 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
329 ArrayGetOpConversion, ArrayInjectOpConversion>(
330 converter, patterns.getContext());
331}
332
333void ConvertHWToSMTPass::runOnOperation() {
334 ConversionTarget target(getContext());
335 target.addIllegalDialect<hw::HWDialect>();
336 target.addIllegalOp<seq::FromClockOp>();
337 target.addIllegalOp<seq::ToClockOp>();
338 target.addLegalDialect<mlir::smt::SMTDialect>();
339 target.addLegalDialect<mlir::func::FuncDialect>();
340
341 RewritePatternSet patterns(&getContext());
342 TypeConverter converter;
345
346 if (failed(mlir::applyPartialConversion(getOperation(), target,
347 std::move(patterns))))
348 return signalPassFailure();
349
350 // Sort the functions topologically because 'hw.module' has a graph region
351 // while 'func.func' is a regular SSACFG region. Real combinational cycles or
352 // pseudo cycles through module instances are not supported yet.
353 for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
354 // Skip functions that are definitely not the result of lowering from
355 // 'hw.module'
356 if (func.getBody().getBlocks().size() != 1)
357 continue;
358
359 mlir::sortTopologically(&func.getBody().front());
360 }
361}
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:324
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:218
Definition hw.py:1