CIRCT  20.0.0git
HWToSMT.cpp
Go to the documentation of this file.
1 //===- HWToSMT.cpp --------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "circt/Dialect/HW/HWOps.h"
13 #include "mlir/Analysis/TopologicalSortUtils.h"
14 #include "mlir/Dialect/Func/IR/FuncOps.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/DialectConversion.h"
17 
18 namespace circt {
19 #define GEN_PASS_DEF_CONVERTHWTOSMT
20 #include "circt/Conversion/Passes.h.inc"
21 } // namespace circt
22 
23 using namespace circt;
24 using namespace hw;
25 
26 //===----------------------------------------------------------------------===//
27 // Conversion patterns
28 //===----------------------------------------------------------------------===//
29 
30 namespace {
31 /// Lower a hw::ConstantOp operation to smt::BVConstantOp
32 struct HWConstantOpConversion : OpConversionPattern<ConstantOp> {
34 
35  LogicalResult
36  matchAndRewrite(ConstantOp op, OpAdaptor adaptor,
37  ConversionPatternRewriter &rewriter) const override {
38  if (adaptor.getValue().getBitWidth() < 1)
39  return rewriter.notifyMatchFailure(op.getLoc(),
40  "0-bit constants not supported");
41  rewriter.replaceOpWithNewOp<smt::BVConstantOp>(op, adaptor.getValue());
42  return success();
43  }
44 };
45 
46 /// Lower a hw::HWModuleOp operation to func::FuncOp.
47 struct HWModuleOpConversion : OpConversionPattern<HWModuleOp> {
49 
50  LogicalResult
51  matchAndRewrite(HWModuleOp op, OpAdaptor adaptor,
52  ConversionPatternRewriter &rewriter) const override {
53  auto funcTy = op.getModuleType().getFuncType();
54  SmallVector<Type> inputTypes, resultTypes;
55  if (failed(typeConverter->convertTypes(funcTy.getInputs(), inputTypes)))
56  return failure();
57  if (failed(typeConverter->convertTypes(funcTy.getResults(), resultTypes)))
58  return failure();
59  if (failed(rewriter.convertRegionTypes(&op.getBody(), *typeConverter)))
60  return failure();
61  auto funcOp = rewriter.create<mlir::func::FuncOp>(
62  op.getLoc(), adaptor.getSymNameAttr(),
63  rewriter.getFunctionType(inputTypes, resultTypes));
64  rewriter.inlineRegionBefore(op.getBody(), funcOp.getBody(), funcOp.end());
65  rewriter.eraseOp(op);
66  return success();
67  }
68 };
69 
70 /// Lower a hw::OutputOp operation to func::ReturnOp.
71 struct OutputOpConversion : OpConversionPattern<OutputOp> {
73 
74  LogicalResult
75  matchAndRewrite(OutputOp op, OpAdaptor adaptor,
76  ConversionPatternRewriter &rewriter) const override {
77  rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(op, adaptor.getOutputs());
78  return success();
79  }
80 };
81 
82 /// Lower a hw::InstanceOp operation to func::CallOp.
83 struct InstanceOpConversion : OpConversionPattern<InstanceOp> {
85 
86  LogicalResult
87  matchAndRewrite(InstanceOp op, OpAdaptor adaptor,
88  ConversionPatternRewriter &rewriter) const override {
89  SmallVector<Type> resultTypes;
90  if (failed(typeConverter->convertTypes(op->getResultTypes(), resultTypes)))
91  return failure();
92 
93  rewriter.replaceOpWithNewOp<mlir::func::CallOp>(
94  op, adaptor.getModuleNameAttr(), resultTypes, adaptor.getInputs());
95  return success();
96  }
97 };
98 
99 /// Remove redundant (seq::FromClock and seq::ToClock) ops.
100 template <typename OpTy>
101 struct ReplaceWithInput : OpConversionPattern<OpTy> {
103  using OpAdaptor = typename OpTy::Adaptor;
104 
105  LogicalResult
106  matchAndRewrite(OpTy op, OpAdaptor adaptor,
107  ConversionPatternRewriter &rewriter) const override {
108  rewriter.replaceOp(op, adaptor.getOperands());
109  return success();
110  }
111 };
112 
113 } // namespace
114 
115 //===----------------------------------------------------------------------===//
116 // Convert HW to SMT pass
117 //===----------------------------------------------------------------------===//
118 
119 namespace {
120 struct ConvertHWToSMTPass
121  : public impl::ConvertHWToSMTBase<ConvertHWToSMTPass> {
122  void runOnOperation() override;
123 };
124 } // namespace
125 
126 void circt::populateHWToSMTTypeConverter(TypeConverter &converter) {
127  // The semantics of the builtin integer at the CIRCT core level is currently
128  // not very well defined. It is used for two-valued, four-valued, and possible
129  // other multi-valued logic. Here, we interpret it as two-valued for now.
130  // From a formal perspective, CIRCT would ideally define its own types for
131  // two-valued, four-valued, nine-valued (etc.) logic each. In MLIR upstream
132  // the integer type also carries poison information (which we don't have in
133  // CIRCT?).
134  converter.addConversion([](IntegerType type) -> std::optional<Type> {
135  if (type.getWidth() <= 0)
136  return std::nullopt;
137  return smt::BitVectorType::get(type.getContext(), type.getWidth());
138  });
139  converter.addConversion([](seq::ClockType type) -> std::optional<Type> {
140  return smt::BitVectorType::get(type.getContext(), 1);
141  });
142 
143  // Default target materialization to convert from illegal types to legal
144  // types, e.g., at the boundary of an inlined child block.
145  converter.addTargetMaterialization([&](OpBuilder &builder, Type resultType,
146  ValueRange inputs,
147  Location loc) -> std::optional<Value> {
148  return builder
149  .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
150  ->getResult(0);
151  });
152 
153  // Convert a 'smt.bool'-typed value to a 'smt.bv<N>'-typed value
154  converter.addTargetMaterialization(
155  [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
156  Location loc) -> std::optional<Value> {
157  if (inputs.size() != 1)
158  return std::nullopt;
159 
160  if (!isa<smt::BoolType>(inputs[0].getType()))
161  return std::nullopt;
162 
163  unsigned width = resultType.getWidth();
164  Value constZero = builder.create<smt::BVConstantOp>(loc, 0, width);
165  Value constOne = builder.create<smt::BVConstantOp>(loc, 1, width);
166  return builder.create<smt::IteOp>(loc, inputs[0], constOne, constZero);
167  });
168 
169  // Convert an unrealized conversion cast from 'smt.bool' to i1
170  // into a direct conversion from 'smt.bool' to 'smt.bv<1>'.
171  converter.addTargetMaterialization(
172  [&](OpBuilder &builder, smt::BitVectorType resultType, ValueRange inputs,
173  Location loc) -> std::optional<Value> {
174  if (inputs.size() != 1 || resultType.getWidth() != 1)
175  return std::nullopt;
176 
177  auto intType = dyn_cast<IntegerType>(inputs[0].getType());
178  if (!intType || intType.getWidth() != 1)
179  return std::nullopt;
180 
181  auto castOp =
182  inputs[0].getDefiningOp<mlir::UnrealizedConversionCastOp>();
183  if (!castOp || castOp.getInputs().size() != 1)
184  return std::nullopt;
185 
186  if (!isa<smt::BoolType>(castOp.getInputs()[0].getType()))
187  return std::nullopt;
188 
189  Value constZero = builder.create<smt::BVConstantOp>(loc, 0, 1);
190  Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
191  return builder.create<smt::IteOp>(loc, castOp.getInputs()[0], constOne,
192  constZero);
193  });
194 
195  // Convert a 'smt.bv<1>'-typed value to a 'smt.bool'-typed value
196  converter.addTargetMaterialization(
197  [&](OpBuilder &builder, smt::BoolType resultType, ValueRange inputs,
198  Location loc) -> std::optional<Value> {
199  if (inputs.size() != 1)
200  return std::nullopt;
201 
202  auto bvType = dyn_cast<smt::BitVectorType>(inputs[0].getType());
203  if (!bvType || bvType.getWidth() != 1)
204  return std::nullopt;
205 
206  Value constOne = builder.create<smt::BVConstantOp>(loc, 1, 1);
207  return builder.create<smt::EqOp>(loc, inputs[0], constOne);
208  });
209 
210  // Default source materialization to convert from illegal types to legal
211  // types, e.g., at the boundary of an inlined child block.
212  converter.addSourceMaterialization([&](OpBuilder &builder, Type resultType,
213  ValueRange inputs,
214  Location loc) -> std::optional<Value> {
215  return builder
216  .create<mlir::UnrealizedConversionCastOp>(loc, resultType, inputs)
217  ->getResult(0);
218  });
219 }
220 
221 void circt::populateHWToSMTConversionPatterns(TypeConverter &converter,
222  RewritePatternSet &patterns) {
223  patterns.add<HWConstantOpConversion, HWModuleOpConversion, OutputOpConversion,
224  InstanceOpConversion, ReplaceWithInput<seq::ToClockOp>,
225  ReplaceWithInput<seq::FromClockOp>>(converter,
226  patterns.getContext());
227 }
228 
229 void ConvertHWToSMTPass::runOnOperation() {
230  ConversionTarget target(getContext());
231  target.addIllegalDialect<hw::HWDialect>();
232  target.addIllegalOp<seq::FromClockOp>();
233  target.addIllegalOp<seq::ToClockOp>();
234  target.addLegalDialect<smt::SMTDialect>();
235  target.addLegalDialect<mlir::func::FuncDialect>();
236 
237  RewritePatternSet patterns(&getContext());
238  TypeConverter converter;
239  populateHWToSMTTypeConverter(converter);
241 
242  if (failed(mlir::applyPartialConversion(getOperation(), target,
243  std::move(patterns))))
244  return signalPassFailure();
245 
246  // Sort the functions topologically because 'hw.module' has a graph region
247  // while 'func.func' is a regular SSACFG region. Real combinational cycles or
248  // pseudo cycles through module instances are not supported yet.
249  for (auto func : getOperation().getOps<mlir::func::FuncOp>()) {
250  // Skip functions that are definitely not the result of lowering from
251  // 'hw.module'
252  if (func.getBody().getBlocks().size() != 1)
253  continue;
254 
255  mlir::sortTopologically(&func.getBody().front());
256  }
257 }
int32_t width
Definition: FIRRTL.cpp:36
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition: HWToSMT.cpp:221
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition: HWToSMT.cpp:126
Definition: hw.py:1