CIRCT 22.0.0git
Loading...
Searching...
No Matches
HWToSMT.cpp
Go to the documentation of this file.
1//===- HWToSMT.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Analysis/TopologicalSortUtils.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTHWTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace circt;
24using namespace hw;
25
26//===----------------------------------------------------------------------===//
27// Conversion patterns
28//===----------------------------------------------------------------------===//
29
30namespace {
31/// Lower a hw::ConstantOp operation to smt::BVConstantOp
32struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34
35 LogicalResult
36 matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<mlir::smt::BVConstantOp>(op,
42 adaptor.getValue());
43 return success();
44 }
45};
46
47/// Lower a hw::HWModuleOp operation to func::FuncOp.
48struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
50
51 LogicalResult
52 matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 auto funcTy = op.getModuleType().getFuncType();
55 SmallVector<Type> inputTypes, resultTypes;
56 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
57 return failure();
58 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
59 return failure();
60 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
61 return failure();
62 auto funcOp = mlir::func::FuncOp::create(
63 rewriter, op.getLoc(), adaptor.getSymNameAttr(),
64 rewriter.getFunctionType(inputTypes, resultTypes));
65 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
66 rewriter.eraseOp(op);
67 return success();
68 }
69};
70
71/// Lower a hw::OutputOp operation to func::ReturnOp.
72struct OutputOpConversion : OpConversionPattern<OutputOp> {
74
75 OutputOpConversion(TypeConverter &converter, MLIRContext *context,
76 bool assertModuleOutputs)
77 : OpConversionPattern<OutputOp>::OpConversionPattern(converter, context),
78 assertModuleOutputs(assertModuleOutputs) {}
79
80 LogicalResult
81 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
82 ConversionPatternRewriter &rewriter) const override {
83 if (assertModuleOutputs) {
84 Location loc = op.getLoc();
85 for (auto output : adaptor.getOutputs()) {
86 Value constOutput =
87 mlir::smt::DeclareFunOp::create(rewriter, loc, output.getType());
88 Value eq = mlir::smt::EqOp::create(rewriter, loc, output, constOutput);
89 mlir::smt::AssertOp::create(rewriter, loc, eq);
90 }
91 }
92 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
93 return success();
94 }
95
96 bool assertModuleOutputs;
97};
98
99/// Lower a hw::InstanceOp operation to func::CallOp.
100struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
102
103 LogicalResult
104 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter) const override {
106 SmallVector<Type> resultTypes;
107 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
108 return failure();
109
110 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
111 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
112 return success();
113 }
114};
115
116/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
117/// smt::ArrayStoreOp for each operand.
118struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
120
121 LogicalResult
122 matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
123 ConversionPatternRewriter &rewriter) const override {
124 Location loc = op.getLoc();
125 Type arrTy = typeConverter->convertType(op.getType());
126 if (!arrTy)
127 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
128
129 unsigned width = adaptor.getInputs().size();
130
131 Value arr = mlir::smt::DeclareFunOp::create(rewriter, loc, arrTy);
132 for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
133 Value idx = mlir::smt::BVConstantOp::create(rewriter, loc, width - i - 1,
134 llvm::Log2_64_Ceil(width));
135 arr = mlir::smt::ArrayStoreOp::create(rewriter, loc, arr, idx, el);
136 }
137
138 rewriter.replaceOp(op, arr);
139 return success();
140 }
141};
142
143/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
144struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
146
147 LogicalResult
148 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
149 ConversionPatternRewriter &rewriter) const override {
150 Location loc = op.getLoc();
151 unsigned numElements =
152 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
153
154 Type type = typeConverter->convertType(op.getType());
155 if (!type)
156 return rewriter.notifyMatchFailure(op.getLoc(),
157 "unsupported array element type");
158
159 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, type);
160 Value numElementsVal = mlir::smt::BVConstantOp::create(
161 rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
162 Value inBounds = mlir::smt::BVCmpOp::create(
163 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
164 numElementsVal);
165 Value indexed = mlir::smt::ArraySelectOp::create(
166 rewriter, loc, adaptor.getInput(), adaptor.getIndex());
167 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, indexed,
168 oobVal);
169 return success();
170 }
171};
172
173/// Lower a hw::ArrayInjectOp operation to smt::ArrayStoreOp.
174struct ArrayInjectOpConversion : OpConversionPattern<ArrayInjectOp> {
175 using OpConversionPattern<ArrayInjectOp>::OpConversionPattern;
176
177 LogicalResult
178 matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override {
180 Location loc = op.getLoc();
181 unsigned numElements =
182 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
183
184 Type arrType = typeConverter->convertType(op.getType());
185 if (!arrType)
186 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
187
188 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, arrType);
189 // Check if the index is within bounds
190 Value numElementsVal = mlir::smt::BVConstantOp::create(
191 rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
192 Value inBounds = mlir::smt::BVCmpOp::create(
193 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
194 numElementsVal);
195
196 // Store the element at the given index
197 Value stored = mlir::smt::ArrayStoreOp::create(
198 rewriter, loc, adaptor.getInput(), adaptor.getIndex(),
199 adaptor.getElement());
200
201 // Return unbounded array if out of bounds
202 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, stored, oobVal);
203 return success();
204 }
205};
206
207/// Remove redundant (seq::FromClock and seq::ToClock) ops.
208template <typename OpTy>
209struct ReplaceWithInput : OpConversionPattern<OpTy> {
211 using OpAdaptor = typename OpTy::Adaptor;
212
213 LogicalResult
214 matchAndRewrite(OpTy op, OpAdaptor adaptor,
215 ConversionPatternRewriter &rewriter) const override {
216 rewriter.replaceOp(op, adaptor.getOperands());
217 return success();
218 }
219};
220
221} // namespace
222
223//===----------------------------------------------------------------------===//
224// Convert HW to SMT pass
225//===----------------------------------------------------------------------===//
226
227namespace {
228struct ConvertHWToSMTPass
229 : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
230 using Base::Base;
231 void runOnOperation() override;
232};
233} // namespace
234
235void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
236 // The semantics of the builtin integer at the CIRCT core level is currently
237 // not very well defined. It is used for two-valued, four-valued, and possible
238 // other multi-valued logic. Here, we interpret it as two-valued for now.
239 // From a formal perspective, CIRCT would ideally define its own types for
240 // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
241 // the integer type also carries poison information (which we don't have in
242 // CIRCT?).
243 converter.addConversion([](IntegerType type) -> std::optional<Type> {
244 if (type.getWidth() <= 0)
245 return std::nullopt;
246 return mlir::smt::BitVectorType::get(type.getContext(), type.getWidth());
247 });
248 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
249 return mlir::smt::BitVectorType::get(type.getContext(), 1);
250 });
251 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
252 auto rangeType = converter.convertType(type.getElementType());
253 if (!rangeType)
254 return {};
255 auto domainType = mlir::smt::BitVectorType::get(
256 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
257 return mlir::smt::ArrayType::get(type.getContext(), domainType, rangeType);
258 });
259
260 // Default target materialization to convert from illegal types to legal
261 // types, e.g., at the boundary of an inlined child block.
262 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
263 ValueRange inputs,
264 Location loc) -> Value {
265 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
266 inputs)
267 ->getResult(0);
268 });
269
270 // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
271 converter.addTargetMaterialization([&](OpBuilder &builder,
272 mlir::smt::BitVectorType resultType,
273 ValueRange inputs,
274 Location loc) -> Value {
275 if (inputs.size() != 1)
276 return Value();
277
278 if (!isa<mlir::smt::BoolType>(inputs[0].getType()))
279 return Value();
280
281 unsigned width = resultType.getWidth();
282 Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, width);
283 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, width);
284 return mlir::smt::IteOp::create(builder, loc, inputs[0], constOne,
285 constZero);
286 });
287
288 // Convert an unrealized conversion cast from 'smt.bool' to i1
289 // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
290 converter.addTargetMaterialization(
291 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
292 ValueRange inputs, Location loc) -> Value {
293 if (inputs.size() != 1 || resultType.getWidth() != 1)
294 return Value();
295
296 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
297 if (!intType || intType.getWidth() != 1)
298 return Value();
299
300 auto castOp =
301 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
302 if (!castOp || castOp.getInputs().size() != 1)
303 return Value();
304
305 if (!isa<mlir::smt::BoolType>(castOp.getInputs()[0].getType()))
306 return Value();
307
308 Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, 1);
309 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
310 return mlir::smt::IteOp::create(builder, loc, castOp.getInputs()[0],
311 constOne, constZero);
312 });
313
314 // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
315 converter.addTargetMaterialization(
316 [&](OpBuilder &builder, mlir::smt::BoolType resultType, ValueRange inputs,
317 Location loc) -> Value {
318 if (inputs.size() != 1)
319 return Value();
320
321 auto bvType = dyn_cast<mlir::smt::BitVectorType>(inputs[0].getType());
322 if (!bvType || bvType.getWidth() != 1)
323 return Value();
324
325 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
326 return mlir::smt::EqOp::create(builder, loc, inputs[0], constOne);
327 });
328
329 // Default source materialization to convert from illegal types to legal
330 // types, e.g., at the boundary of an inlined child block.
331 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
332 ValueRange inputs,
333 Location loc) -> Value {
334 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
335 inputs)
336 ->getResult(0);
337 });
338}
339
340void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
341 RewritePatternSet &patterns,
342 bool assertModuleOutputs) {
343 patterns.add<HWConstantOpConversion, HWModuleOpConversion,
344 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
345 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
346 ArrayGetOpConversion, ArrayInjectOpConversion>(
347 converter, patterns.getContext());
348 patterns.add<OutputOpConversion>(converter, patterns.getContext(),
349 assertModuleOutputs);
350}
351
352void ConvertHWToSMTPass::runOnOperation() {
353 ConversionTarget target(getContext());
354 target.addIllegalDialect<hw::HWDialect>();
355 target.addIllegalOp<seq::FromClockOp>();
356 target.addIllegalOp<seq::ToClockOp>();
357 target.addLegalDialect<mlir::smt::SMTDialect>();
358 target.addLegalDialect<mlir::func::FuncDialect>();
359
360 RewritePatternSet patterns(&getContext());
361 TypeConverter converter;
363 populateHWToSMTConversionPatterns(converter, patterns, assertModuleOutputs);
364
365 if (failed(mlir::applyPartialConversion(getOperation(), target,
366 std::move(patterns))))
367 return signalPassFailure();
368
369 // Sort the functions topologically because 'hw.module' has a graph region
370 // while 'func.func' is a regular SSACFG region. Real combinational cycles or
371 // pseudo cycles through module instances are not supported yet.
372 for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
373 // Skip functions that are definitely not the result of lowering from
374 // 'hw.module'
375 if (func.getBody().getBlocks().size() != 1)
376 continue;
377
378 mlir::sortTopologically(&func.getBody().front());
379 }
380}
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:235
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, bool assertModuleOutputs)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:340
Definition hw.py:1