CIRCT  20.0.0git
HWToSMT.cpp
Go to the documentation of this file.
1 //===- HWToSMT.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "circt/Dialect/HW/HWOps.h"
13 #include "mlir/Analysis/TopologicalSortUtils.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 
18 namespace circt {
19 #define GEN_PASS_DEF_CONVERTHWTOSMT
20 #include "circt/Conversion/Passes.h.inc"
21 } // namespace circt
22 
23 using namespace circt;
24 using namespace hw;
25 
26 //===----------------------------------------------------------------------===//
27 // Conversion patterns
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 /// Lower a hw::ConstantOp operation to smt::BVConstantOp
32 struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34 
35  LogicalResult
36  matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37  ConversionPatternRewriter &rewriter) const override {
38  if (adaptor.getValue().getBitWidth() < 1)
39  return rewriter.notifyMatchFailure(op.getLoc(),
40  "0-bit constants not supported");
41  rewriter.replaceOpWithNewOp<smt::BVConstantOp>(op, adaptor.getValue());
42  return success();
43  }
44 };
45 
46 /// Lower a hw::HWModuleOp operation to func::FuncOp.
47 struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
49 
50  LogicalResult
51  matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
52  ConversionPatternRewriter &rewriter) const override {
53  auto funcTy = op.getModuleType().getFuncType();
54  SmallVector<Type> inputTypes, resultTypes;
55  if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
56  return failure();
57  if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
58  return failure();
59  if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
60  return failure();
61  auto funcOp = rewriter.create<mlir::func::FuncOp>(
62  op.getLoc(), adaptor.getSymNameAttr(),
63  rewriter.getFunctionType(inputTypes, resultTypes));
64  rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
65  rewriter.eraseOp(op);
66  return success();
67  }
68 };
69 
70 /// Lower a hw::OutputOp operation to func::ReturnOp.
71 struct OutputOpConversion : OpConversionPattern<OutputOp> {
73 
74  LogicalResult
75  matchAndRewrite(OutputOp op, OpAdaptor adaptor,
76  ConversionPatternRewriter &rewriter) const override {
77  rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
78  return success();
79  }
80 };
81 
82 /// Lower a hw::InstanceOp operation to func::CallOp.
83 struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
85 
86  LogicalResult
87  matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
88  ConversionPatternRewriter &rewriter) const override {
89  SmallVector<Type> resultTypes;
90  if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
91  return failure();
92 
93  rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
94  op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
95  return success();
96  }
97 };
98 
99 /// Lower a hw::ArrayCreateOp operation to smt::DeclareFun and an
100 /// smt::ArrayStoreOp for each operand.
101 struct ArrayCreateOpConversion : OpConversionPattern<ArrayCreateOp> {
103 
104  LogicalResult
105  matchAndRewrite(ArrayCreateOp op, OpAdaptor adaptor,
106  ConversionPatternRewriter &rewriter) const override {
107  Location loc = op.getLoc();
108  Type arrTy = typeConverter->convertType(op.getType());
109  if (!arrTy)
110  return rewriter.notifyMatchFailure(op.getLoc(), "unsupported array type");
111 
112  unsigned width = adaptor.getInputs().size();
113 
114  Value arr = rewriter.create<smt::DeclareFunOp>(loc, arrTy);
115  for (auto [i, el] : llvm::enumerate(adaptor.getInputs())) {
116  Value idx = rewriter.create<smt::BVConstantOp>(loc, width - i - 1,
117  llvm::Log2_64_Ceil(width));
118  arr = rewriter.create<smt::ArrayStoreOp>(loc, arr, idx, el);
119  }
120 
121  rewriter.replaceOp(op, arr);
122  return success();
123  }
124 };
125 
126 /// Lower a hw::ArrayGetOp operation to smt::ArraySelectOp
127 struct ArrayGetOpConversion : OpConversionPattern<ArrayGetOp> {
129 
130  LogicalResult
131  matchAndRewrite(ArrayGetOp op, OpAdaptor adaptor,
132  ConversionPatternRewriter &rewriter) const override {
133  Location loc = op.getLoc();
134  unsigned numElements =
135  cast<hw::ArrayType>(op.getInput().getType()).getNumElements();
136 
137  Type type = typeConverter->convertType(op.getType());
138  if (!type)
139  return rewriter.notifyMatchFailure(op.getLoc(),
140  "unsupported array element type");
141 
142  Value oobVal = rewriter.create<smt::DeclareFunOp>(loc, type);
143  Value numElementsVal = rewriter.create<smt::BVConstantOp>(
144  loc, numElements - 1, llvm::Log2_64_Ceil(numElements));
145  Value inBounds = rewriter.create<smt::BVCmpOp>(
146  loc, smt::BVCmpPredicate::ule, adaptor.getIndex(), numElementsVal);
147  Value indexed = rewriter.create<smt::ArraySelectOp>(loc, adaptor.getInput(),
148  adaptor.getIndex());
149  rewriter.replaceOpWithNewOp<smt::IteOp>(op, inBounds, indexed, oobVal);
150  return success();
151  }
152 };
153 
154 /// Remove redundant (seq::FromClock and seq::ToClock) ops.
155 template <typename OpTy>
156 struct ReplaceWithInput : OpConversionPattern<OpTy> {
158  using OpAdaptor = typename OpTy::Adaptor;
159 
160  LogicalResult
161  matchAndRewrite(OpTy op, OpAdaptor adaptor,
162  ConversionPatternRewriter &rewriter) const override {
163  rewriter.replaceOp(op, adaptor.getOperands());
164  return success();
165  }
166 };
167 
168 } // namespace
169 
170 //===----------------------------------------------------------------------===//
171 // Convert HW to SMT pass
172 //===----------------------------------------------------------------------===//
173 
174 namespace {
175 struct ConvertHWToSMTPass
176  : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
177  void runOnOperation() override;
178 };
179 } // namespace
180 
181 void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
182  // The semantics of the builtin integer at the CIRCT core level is currently
183  // not very well defined. It is used for two-valued, four-valued, and possible
184  // other multi-valued logic. Here, we interpret it as two-valued for now.
185  // From a formal perspective, CIRCT would ideally define its own types for
186  // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
187  // the integer type also carries poison information (which we don't have in
188  // CIRCT?).
189  converter.addConversion([](IntegerType type) -> std::optional<Type> {
190  if (type.getWidth() <= 0)
191  return std::nullopt;
192  return smt::BitVectorType::get(type.getContext(), type.getWidth());
193  });
194  converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
195  return smt::BitVectorType::get(type.getContext(), 1);
196  });
197  converter.addConversion([&](ArrayType type) -> std::optional<Type> {
198  auto rangeType = converter.convertType(type.getElementType());
199  if (!rangeType)
200  return {};
201  auto domainType = smt::BitVectorType::get(
202  type.getContext(), llvm::Log2_64_Ceil(type.getNumElements()));
203  return smt::ArrayType::get(type.getContext(), domainType, rangeType);
204  });
205 
206  // Default target materialization to convert from illegal types to legal
207  // types, e.g., at the boundary of an inlined child block.
208  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
209  ValueRange inputs,
210  Location loc) -> Value {
211  return builder
212  .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
213  ->getResult(0);
214  });
215 
216  // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
217  converter.addTargetMaterialization(
218  [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
219  Location loc) -> Value {
220  if (inputs.size() != 1)
221  return Value();
222 
223  if (!isa<smt::BoolType>(inputs[0].getType()))
224  return Value();
225 
226  unsigned width = resultType.getWidth();
227  Value constZero = builder.create<smt::BVConstantOp>(loc, 0, width);
228  Value constOne = builder.create<smt::BVConstantOp>(loc, 1, width);
229  return builder.create<smt::IteOp>(loc, inputs[0], constOne, constZero);
230  });
231 
232  // Convert an unrealized conversion cast from 'smt.bool' to i1
233  // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
234  converter.addTargetMaterialization(
235  [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
236  Location loc) -> Value {
237  if (inputs.size() != 1 || resultType.getWidth() != 1)
238  return Value();
239 
240  auto intType = dyn_cast<IntegerType>(inputs[0].getType());
241  if (!intType || intType.getWidth() != 1)
242  return Value();
243 
244  auto castOp =
245  inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
246  if (!castOp || castOp.getInputs().size() != 1)
247  return Value();
248 
249  if (!isa<smt::BoolType>(castOp.getInputs()[0].getType()))
250  return Value();
251 
252  Value constZero = builder.create<smt::BVConstantOp>(loc, 0, 1);
253  Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
254  return builder.create<smt::IteOp>(loc, castOp.getInputs()[0], constOne,
255  constZero);
256  });
257 
258  // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
259  converter.addTargetMaterialization(
260  [&](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs,
261  Location loc) -> Value {
262  if (inputs.size() != 1)
263  return Value();
264 
265  auto bvType = dyn_cast<smt::BitVectorType>(inputs[0].getType());
266  if (!bvType || bvType.getWidth() != 1)
267  return Value();
268 
269  Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
270  return builder.create<smt::EqOp>(loc, inputs[0], constOne);
271  });
272 
273  // Default source materialization to convert from illegal types to legal
274  // types, e.g., at the boundary of an inlined child block.
275  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
276  ValueRange inputs,
277  Location loc) -> Value {
278  return builder
279  .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
280  ->getResult(0);
281  });
282 }
283 
284 void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
285  RewritePatternSet &patterns) {
286  patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
287  InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
288  ReplaceWithInput<seq::FromClockOp>, ArrayCreateOpConversion,
289  ArrayGetOpConversion>(converter, patterns.getContext());
290 }
291 
292 void ConvertHWToSMTPass::runOnOperation() {
293  ConversionTarget target(getContext());
294  target.addIllegalDialect<hw::HWDialect>();
295  target.addIllegalOp<seq::FromClockOp>();
296  target.addIllegalOp<seq::ToClockOp>();
297  target.addLegalDialect<smt::SMTDialect>();
298  target.addLegalDialect<mlir::func::FuncDialect>();
299 
300  RewritePatternSet patterns(&getContext());
301  TypeConverter converter;
302  populateHWToSMTTypeConverter(converter);
304 
305  if (failed(mlir::applyPartialConversion(getOperation(), target,
306  std::move(patterns))))
307  return signalPassFailure();
308 
309  // Sort the functions topologically because 'hw.module' has a graph region
310  // while 'func.func' is a regular SSACFG region. Real combinational cycles or
311  // pseudo cycles through module instances are not supported yet.
312  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
313  // Skip functions that are definitely not the result of lowering from
314  // 'hw.module'
315  if (func.getBody().getBlocks().size() != 1)
316  continue;
317 
318  mlir::sortTopologically(&func.getBody().front());
319  }
320 }
MlirType uint64_t numElements
Definition: CHIRRTL.cpp:30
int32_t width
Definition: FIRRTL.cpp:40
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition: HWToSMT.cpp:284
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition: HWToSMT.cpp:181
Definition: hw.py:1