CIRCT 22.0.0git
Loading...
Searching...
No Matches
HWToSMT.cpp
Go to the documentation of this file.
1//===- HWToSMT.cpp --------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/Analysis/TopologicalSortUtils.h"
13#include "mlir/Dialect/Func/IR/FuncOps.h"
14#include "mlir/Dialect/SMT/IR/SMTOps.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/DialectConversion.h"
17
18namespace circt {
19#define GEN_PASS_DEF_CONVERTHWTOSMT
20#include "circt/Conversion/Passes.h.inc"
21} // namespace circt
22
23using namespace circt;
24using namespace hw;
25
26//===----------------------------------------------------------------------===//
27// Conversion patterns
28//===----------------------------------------------------------------------===//
29
30namespace {
31/// Lower a hw::ConstantOp operation to smt::BVConstantOp
32struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34
35 LogicalResult
36 matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37 ConversionPatternRewriter &rewriter) const override {
38 if (adaptor.getValue().getBitWidth() < 1)
39 return rewriter.notifyMatchFailure(op.getLoc(),
40 "0-bit constants not supported");
41 rewriter.replaceOpWithNewOp<mlir::smt::BVConstantOp>(op,
42 adaptor.getValue());
43 return success();
44 }
45};
46
47/// Lower a hw::HWModuleOp operation to func::FuncOp.
48struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
50
51 HWModuleOpConversion(TypeConverter &converter, MLIRContext *context,
52 bool replaceWithSolver)
54 context),
55 replaceWithSolver(replaceWithSolver) {}
56
57 LogicalResult
58 matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
59 ConversionPatternRewriter &rewriter) const override {
60 auto funcTy = op.getModuleType().getFuncType();
61 SmallVector<Type> inputTypes, resultTypes;
62 if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
63 return failure();
64 if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
65 return failure();
66 if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
67 return failure();
68 auto loc = op.getLoc();
69 if (replaceWithSolver) {
70 // If we're exporting to SMTLIB we need to move the module into an
71 // smt.solver op (pre-pattern checks make sure we only have one module).
72 auto solverOp = mlir::smt::SolverOp::create(rewriter, loc, {}, {});
73 auto *solverBlock = rewriter.createBlock(&solverOp.getBodyRegion());
74 // Create a new symbolic value to replace each input
75 rewriter.setInsertionPointToStart(solverBlock);
76 SmallVector<Value> symVals;
77 for (auto inputType : inputTypes) {
78 auto symVal = mlir::smt::DeclareFunOp::create(rewriter, loc, inputType);
79 symVals.push_back(symVal);
80 }
81 // Inline module body into solver op and replace args with new symbolic
82 // values
83 rewriter.inlineBlockBefore(op.getBodyBlock(), solverBlock,
84 solverBlock->end(), symVals);
85 rewriter.eraseOp(op);
86 return success();
87 }
88 auto funcOp = mlir::func::FuncOp::create(
89 rewriter, loc, adaptor.getSymNameAttr(),
90 rewriter.getFunctionType(inputTypes, resultTypes));
91 rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
92 rewriter.eraseOp(op);
93 return success();
94 }
95
96 bool replaceWithSolver;
97};
98
99/// Lower a hw::OutputOp operation to func::ReturnOp.
100struct OutputOpConversion : OpConversionPattern<OutputOp> {
102
103 OutputOpConversion(TypeConverter &converter, MLIRContext *context,
104 bool assertModuleOutputs)
105 : OpConversionPattern<OutputOp>::OpConversionPattern(converter, context),
106 assertModuleOutputs(assertModuleOutputs) {}
107
108 LogicalResult
109 matchAndRewrite(OutputOp op, OpAdaptor adaptor,
110 ConversionPatternRewriter &rewriter) const override {
111 if (assertModuleOutputs) {
112 Location loc = op.getLoc();
113 for (auto output : adaptor.getOutputs()) {
114 Value constOutput =
115 mlir::smt::DeclareFunOp::create(rewriter, loc, output.getType());
116 Value eq = mlir::smt::EqOp::create(rewriter, loc, output, constOutput);
117 mlir::smt::AssertOp::create(rewriter, loc, eq);
118 }
119 rewriter.replaceOpWithNewOp<mlir::smt::YieldOp>(op);
120 return success();
121 }
122 rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
123 return success();
124 }
125
126 bool assertModuleOutputs;
127};
128
129/// Lower a hw::InstanceOp operation to func::CallOp.
130struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
132
133 LogicalResult
134 matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
135 ConversionPatternRewriter &rewriter) const override {
136 SmallVector<Type> resultTypes;
137 if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
138 return failure();
139
140 rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
141 op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
142 return success();
143 }
144};
145
146/// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
147/// smt::ArrayStoreOp for each operand.
148struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
150
151 LogicalResult
152 matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
153 ConversionPatternRewriter &rewriter) const override {
154 Location loc = op.getLoc();
155 Type arrTy = typeConverter->convertType(op.getType());
156 if (!arrTy)
157 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
158
159 unsigned width = adaptor.getInputs().size();
160
161 Value arr = mlir::smt::DeclareFunOp::create(rewriter, loc, arrTy);
162 for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
163 Value idx = mlir::smt::BVConstantOp::create(rewriter, loc, width - i - 1,
164 llvm::Log2_64_Ceil(width));
165 arr = mlir::smt::ArrayStoreOp::create(rewriter, loc, arr, idx, el);
166 }
167
168 rewriter.replaceOp(op, arr);
169 return success();
170 }
171};
172
173/// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
174struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
176
177 LogicalResult
178 matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
179 ConversionPatternRewriter &rewriter) const override {
180 Location loc = op.getLoc();
181 unsigned numElements =
182 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
183
184 Type type = typeConverter->convertType(op.getType());
185 if (!type)
186 return rewriter.notifyMatchFailure(op.getLoc(),
187 "unsupported array element type");
188
189 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, type);
190 Value numElementsVal = mlir::smt::BVConstantOp::create(
191 rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
192 Value inBounds = mlir::smt::BVCmpOp::create(
193 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
194 numElementsVal);
195 Value indexed = mlir::smt::ArraySelectOp::create(
196 rewriter, loc, adaptor.getInput(), adaptor.getIndex());
197 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, indexed,
198 oobVal);
199 return success();
200 }
201};
202
203/// Lower a hw::ArrayInjectOp operation to smt::ArrayStoreOp.
204struct ArrayInjectOpConversion : OpConversionPattern<ArrayInjectOp> {
205 using OpConversionPattern<ArrayInjectOp>::OpConversionPattern;
206
207 LogicalResult
208 matchAndRewrite(ArrayInjectOp op, OpAdaptor adaptor,
209 ConversionPatternRewriter &rewriter) const override {
210 Location loc = op.getLoc();
211 unsigned numElements =
212 cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
213
214 Type arrType = typeConverter->convertType(op.getType());
215 if (!arrType)
216 return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
217
218 Value oobVal = mlir::smt::DeclareFunOp::create(rewriter, loc, arrType);
219 // Check if the index is within bounds
220 Value numElementsVal = mlir::smt::BVConstantOp::create(
221 rewriter, loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
222 Value inBounds = mlir::smt::BVCmpOp::create(
223 rewriter, loc, mlir::smt::BVCmpPredicate::ule, adaptor.getIndex(),
224 numElementsVal);
225
226 // Store the element at the given index
227 Value stored = mlir::smt::ArrayStoreOp::create(
228 rewriter, loc, adaptor.getInput(), adaptor.getIndex(),
229 adaptor.getElement());
230
231 // Return unbounded array if out of bounds
232 rewriter.replaceOpWithNewOp<mlir::smt::IteOp>(op, inBounds, stored, oobVal);
233 return success();
234 }
235};
236
237/// Remove redundant (seq::FromClock and seq::ToClock) ops.
238template <typename OpTy>
239struct ReplaceWithInput : OpConversionPattern<OpTy> {
241 using OpAdaptor = typename OpTy::Adaptor;
242
243 LogicalResult
244 matchAndRewrite(OpTy op, OpAdaptor adaptor,
245 ConversionPatternRewriter &rewriter) const override {
246 rewriter.replaceOp(op, adaptor.getOperands());
247 return success();
248 }
249};
250
251} // namespace
252
253//===----------------------------------------------------------------------===//
254// Convert HW to SMT pass
255//===----------------------------------------------------------------------===//
256
257namespace {
258struct ConvertHWToSMTPass
259 : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
260 using Base::Base;
261 void runOnOperation() override;
262};
263} // namespace
264
265void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
266 // The semantics of the builtin integer at the CIRCT core level is currently
267 // not very well defined. It is used for two-valued, four-valued, and possible
268 // other multi-valued logic. Here, we interpret it as two-valued for now.
269 // From a formal perspective, CIRCT would ideally define its own types for
270 // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
271 // the integer type also carries poison information (which we don't have in
272 // CIRCT?).
273 converter.addConversion([](IntegerType type) -> std::optional<Type> {
274 if (type.getWidth() <= 0)
275 return std::nullopt;
276 return mlir::smt::BitVectorType::get(type.getContext(), type.getWidth());
277 });
278 converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
279 return mlir::smt::BitVectorType::get(type.getContext(), 1);
280 });
281 converter.addConversion([&](ArrayType type) -> std::optional<Type> {
282 auto rangeType = converter.convertType(type.getElementType());
283 if (!rangeType)
284 return {};
285 auto domainType = mlir::smt::BitVectorType::get(
286 type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
287 return mlir::smt::ArrayType::get(type.getContext(), domainType, rangeType);
288 });
289
290 // Default target materialization to convert from illegal types to legal
291 // types, e.g., at the boundary of an inlined child block.
292 converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
293 ValueRange inputs,
294 Location loc) -> Value {
295 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
296 inputs)
297 ->getResult(0);
298 });
299
300 // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
301 converter.addTargetMaterialization([&](OpBuilder &builder,
302 mlir::smt::BitVectorType resultType,
303 ValueRange inputs,
304 Location loc) -> Value {
305 if (inputs.size() != 1)
306 return Value();
307
308 if (!isa<mlir::smt::BoolType>(inputs[0].getType()))
309 return Value();
310
311 unsigned width = resultType.getWidth();
312 Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, width);
313 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, width);
314 return mlir::smt::IteOp::create(builder, loc, inputs[0], constOne,
315 constZero);
316 });
317
318 // Convert an unrealized conversion cast from 'smt.bool' to i1
319 // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
320 converter.addTargetMaterialization(
321 [&](OpBuilder &builder, mlir::smt::BitVectorType resultType,
322 ValueRange inputs, Location loc) -> Value {
323 if (inputs.size() != 1 || resultType.getWidth() != 1)
324 return Value();
325
326 auto intType = dyn_cast<IntegerType>(inputs[0].getType());
327 if (!intType || intType.getWidth() != 1)
328 return Value();
329
330 auto castOp =
331 inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
332 if (!castOp || castOp.getInputs().size() != 1)
333 return Value();
334
335 if (!isa<mlir::smt::BoolType>(castOp.getInputs()[0].getType()))
336 return Value();
337
338 Value constZero = mlir::smt::BVConstantOp::create(builder, loc, 0, 1);
339 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
340 return mlir::smt::IteOp::create(builder, loc, castOp.getInputs()[0],
341 constOne, constZero);
342 });
343
344 // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
345 converter.addTargetMaterialization(
346 [&](OpBuilder &builder, mlir::smt::BoolType resultType, ValueRange inputs,
347 Location loc) -> Value {
348 if (inputs.size() != 1)
349 return Value();
350
351 auto bvType = dyn_cast<mlir::smt::BitVectorType>(inputs[0].getType());
352 if (!bvType || bvType.getWidth() != 1)
353 return Value();
354
355 Value constOne = mlir::smt::BVConstantOp::create(builder, loc, 1, 1);
356 return mlir::smt::EqOp::create(builder, loc, inputs[0], constOne);
357 });
358
359 // Default source materialization to convert from illegal types to legal
360 // types, e.g., at the boundary of an inlined child block.
361 converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
362 ValueRange inputs,
363 Location loc) -> Value {
364 return mlir::UnrealizedConversionCastOp::create(builder, loc, resultType,
365 inputs)
366 ->getResult(0);
367 });
368}
369
370void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
371 RewritePatternSet &patterns,
372 bool forSMTLIBExport) {
373 patterns.add<HWConstantOpConversion, InstanceOpConversion,
374 ReplaceWithInput<seq::ToClockOp>,
375 ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
376 ArrayGetOpConversion, ArrayInjectOpConversion>(
377 converter, patterns.getContext());
378 patterns.add<OutputOpConversion, HWModuleOpConversion>(
379 converter, patterns.getContext(), forSMTLIBExport);
380}
381
382void ConvertHWToSMTPass::runOnOperation() {
383 if (forSMTLIBExport) {
384 auto numModules = 0;
385 auto numInstances = 0;
386 getOperation().walk([&](Operation *op) {
387 if (isa<hw::HWModuleOp>(op))
388 numModules++;
389 if (isa<hw::InstanceOp>(op))
390 numInstances++;
391 });
392 // Error out if there is any module hierarchy or multiple modules
393 // Currently there's no need as this flag is intended for SMTLIB export and
394 // we can just flatten modules
395 if (numModules > 1) {
396 getOperation()->emitError("multiple hw.module operations are not "
397 "supported with for-smtlib-export");
398 return;
399 }
400 if (numInstances > 0) {
401 getOperation()->emitError("hw.instance operations are not supported "
402 "with for-smtlib-export");
403 return;
404 }
405 }
406
407 ConversionTarget target(getContext());
408 target.addIllegalDialect<hw::HWDialect>();
409 target.addIllegalOp<seq::FromClockOp>();
410 target.addIllegalOp<seq::ToClockOp>();
411 target.addLegalDialect<mlir::smt::SMTDialect>();
412 target.addLegalDialect<mlir::func::FuncDialect>();
413
414 RewritePatternSet patterns(&getContext());
415 TypeConverter converter;
417 populateHWToSMTConversionPatterns(converter, patterns, forSMTLIBExport);
418
419 if (failed(mlir::applyPartialConversion(getOperation(), target,
420 std::move(patterns))))
421 return signalPassFailure();
422
423 // Sort the functions topologically because 'hw.module' has a graph region
424 // while 'func.func' is a regular SSACFG region. Real combinational cycles or
425 // pseudo cycles through module instances are not supported yet.
426 for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
427 // Skip functions that are definitely not the result of lowering from
428 // 'hw.module'
429 if (func.getBody().getBlocks().size() != 1)
430 continue;
431
432 mlir::sortTopologically(&func.getBody().front());
433 }
434}
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns, bool forSMTLIBExport)
Get the HW to SMT conversion patterns.
Definition HWToSMT.cpp:370
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:265
Definition hw.py:1