Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HWToSMT.cpp
Go to the documentation of this file.
1//===- HWToSMT.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Analysis/TopologicalSortUtils.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTHWTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace circt;
24using namespace hw;
25
26//===----------------------------------------------------------------------===//
27// Conversion patterns
28//===----------------------------------------------------------------------===//
29
30namespace {
31/// Lower a hw::ConstantOp operation to smt::BVConstantOp
32struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34
35 LogicalResult
36 matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<mlir::smt::BVConstantOp>(op,
42 adaptor.getValue());
43 return success();
44 }
45};
46
47/// Lower a hw::HWModuleOp operation to func::FuncOp.
48struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
50
51 LogicalResult
52 matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 auto funcTy = op.getModuleType().getFuncType();
55 SmallVector<Type> inputTypes, resultTypes;
56 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
57 return failure();
58 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
59 return failure();
60 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
61 return failure();
62 auto funcOp = rewriter.create<mlir::func::FuncOp>(
63 op.getLoc(), adaptor.getSymNameAttr(),
64 rewriter.getFunctionType(inputTypes, resultTypes));
65 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
66 rewriter.eraseOp(op);
67 return success();
68 }
69};
70
71/// Lower a hw::OutputOp operation to func::ReturnOp.
72struct OutputOpConversion : OpConversionPattern<OutputOp> {
74
75 LogicalResult
76 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
77 ConversionPatternRewriter &rewriter) const override {
78 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
79 return success();
80 }
81};
82
83/// Lower a hw::InstanceOp operation to func::CallOp.
84struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
86
87 LogicalResult
88 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter) const override {
90 SmallVector<Type> resultTypes;
91 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
92 return failure();
93
94 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
95 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
96 return success();
97 }
98};
99
100/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
101/// smt::ArrayStoreOp for each operand.
102struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
104
105 LogicalResult
106 matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
107 ConversionPatternRewriter &rewriter) const override {
108 Location loc = op.getLoc();
109 Type arrTy = typeConverter->convertType(op.getType());
110 if (!arrTy)
111 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
112
113 unsigned width = adaptor.getInputs().size();
114
115 Value arr = rewriter.create<mlir::smt::DeclareFunOp>(loc, arrTy);
116 for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
117 Value idx = rewriter.create<mlir::smt::BVConstantOp>(
118 loc, width - i - 1, llvm::Log2_64_Ceil(width));
119 arr = rewriter.create<mlir::smt::ArrayStoreOp>(loc, arr, idx, el);
120 }
121
122 rewriter.replaceOp(op, arr);
123 return success();
124 }
125};
126
127/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
128struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
130
131 LogicalResult
132 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
133 ConversionPatternRewriter &rewriter) const override {
134 Location loc = op.getLoc();
135 unsigned numElements =
136 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
137
138 Type type = typeConverter->convertType(op.getType());
139 if (!type)
140 return rewriter.notifyMatchFailure(op.getLoc(),
141 "unsupported array element type");
142
143 Value oobVal = rewriter.create<mlir::smt::DeclareFunOp>(loc, type);
144 Value numElementsVal = rewriter.create<mlir::smt::BVConstantOp>(
145 loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
146 Value inBounds =
147 rewriter.create<mlir::smt::BVCmpOp>(loc, mlir::smt::BVCmpPredicate::ule,
148 adaptor.getIndex(), numElementsVal);
149 Value indexed = rewriter.create<mlir::smt::ArraySelectOp>(
150 loc, adaptor.getInput(), adaptor.getIndex());
151 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, indexed,
152 oobVal);
153 return success();
154 }
155};
156
157/// Remove redundant (seq::FromClock and seq::ToClock) ops.
158template <typename OpTy>
159struct ReplaceWithInput : OpConversionPattern<OpTy> {
161 using OpAdaptor = typename OpTy::Adaptor;
162
163 LogicalResult
164 matchAndRewrite(OpTy op, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter) const override {
166 rewriter.replaceOp(op, adaptor.getOperands());
167 return success();
168 }
169};
170
171} // namespace
172
173//===----------------------------------------------------------------------===//
174// Convert HW to SMT pass
175//===----------------------------------------------------------------------===//
176
177namespace {
178struct ConvertHWToSMTPass
179 : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
180 void runOnOperation() override;
181};
182} // namespace
183
184void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
185 // The semantics of the builtin integer at the CIRCT core level is currently
186 // not very well defined. It is used for two-valued, four-valued, and possible
187 // other multi-valued logic. Here, we interpret it as two-valued for now.
188 // From a formal perspective, CIRCT would ideally define its own types for
189 // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
190 // the integer type also carries poison information (which we don't have in
191 // CIRCT?).
192 converter.addConversion([](IntegerType type) -> std::optional<Type> {
193 if (type.getWidth() <= 0)
194 return std::nullopt;
195 return mlir::smt::BitVectorType::get(type.getContext(), type.getWidth());
196 });
197 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
198 return mlir::smt::BitVectorType::get(type.getContext(), 1);
199 });
200 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
201 auto rangeType = converter.convertType(type.getElementType());
202 if (!rangeType)
203 return {};
204 auto domainType = mlir::smt::BitVectorType::get(
205 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
206 return mlir::smt::ArrayType::get(type.getContext(), domainType, rangeType);
207 });
208
209 // Default target materialization to convert from illegal types to legal
210 // types, e.g., at the boundary of an inlined child block.
211 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
212 ValueRange inputs,
213 Location loc) -> Value {
214 return builder
215 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
216 ->getResult(0);
217 });
218
219 // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
220 converter.addTargetMaterialization(
221 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
222 ValueRange inputs, Location loc) -> Value {
223 if (inputs.size() != 1)
224 return Value();
225
226 if (!isa<mlir::smt::BoolType>(inputs[0].getType()))
227 return Value();
228
229 unsigned width = resultType.getWidth();
230 Value constZero =
231 builder.create<mlir::smt::BVConstantOp>(loc, 0, width);
232 Value constOne = builder.create<mlir::smt::BVConstantOp>(loc, 1, width);
233 return builder.create<mlir::smt::IteOp>(loc, inputs[0], constOne,
234 constZero);
235 });
236
237 // Convert an unrealized conversion cast from 'smt.bool' to i1
238 // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
239 converter.addTargetMaterialization(
240 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
241 ValueRange inputs, Location loc) -> Value {
242 if (inputs.size() != 1 || resultType.getWidth() != 1)
243 return Value();
244
245 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
246 if (!intType || intType.getWidth() != 1)
247 return Value();
248
249 auto castOp =
250 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
251 if (!castOp || castOp.getInputs().size() != 1)
252 return Value();
253
254 if (!isa<mlir::smt::BoolType>(castOp.getInputs()[0].getType()))
255 return Value();
256
257 Value constZero = builder.create<mlir::smt::BVConstantOp>(loc, 0, 1);
258 Value constOne = builder.create<mlir::smt::BVConstantOp>(loc, 1, 1);
259 return builder.create<mlir::smt::IteOp>(loc, castOp.getInputs()[0],
260 constOne, constZero);
261 });
262
263 // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
264 converter.addTargetMaterialization(
265 [&](OpBuilder &builder, mlir::smt::BoolType resultType, ValueRange inputs,
266 Location loc) -> Value {
267 if (inputs.size() != 1)
268 return Value();
269
270 auto bvType = dyn_cast<mlir::smt::BitVectorType>(inputs[0].getType());
271 if (!bvType || bvType.getWidth() != 1)
272 return Value();
273
274 Value constOne = builder.create<mlir::smt::BVConstantOp>(loc, 1, 1);
275 return builder.create<mlir::smt::EqOp>(loc, inputs[0], constOne);
276 });
277
278 // Default source materialization to convert from illegal types to legal
279 // types, e.g., at the boundary of an inlined child block.
280 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
281 ValueRange inputs,
282 Location loc) -> Value {
283 return builder
284 .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
285 ->getResult(0);
286 });
287}
288
289void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
290 RewritePatternSet &patterns) {
291 patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
292 InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
293 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
294 ArrayGetOpConversion>(converter, patterns.getContext());
295}
296
297void ConvertHWToSMTPass::runOnOperation() {
298 ConversionTarget target(getContext());
299 target.addIllegalDialect<hw::HWDialect>();
300 target.addIllegalOp<seq::FromClockOp>();
301 target.addIllegalOp<seq::ToClockOp>();
302 target.addLegalDialect<mlir::smt::SMTDialect>();
303 target.addLegalDialect<mlir::func::FuncDialect>();
304
305 RewritePatternSet patterns(&getContext());
306 TypeConverter converter;
309
310 if (failed(mlir::applyPartialConversion(getOperation(), target,
311 std::move(patterns))))
312 return signalPassFailure();
313
314 // Sort the functions topologically because 'hw.module' has a graph region
315 // while 'func.func' is a regular SSACFG region. Real combinational cycles or
316 // pseudo cycles through module instances are not supported yet.
317 for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
318 // Skip functions that are definitely not the result of lowering from
319 // 'hw.module'
320 if (func.getBody().getBlocks().size() != 1)
321 continue;
322
323 mlir::sortTopologically(&func.getBody().front());
324 }
325}
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
create(elements)
Definition hw.py:483
create(array_value, idx)
Definition hw.py:450
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:289
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:184
Definition hw.py:1