CIRCT  20.0.0git
CombToSMT.cpp
Go to the documentation of this file.
1 //===- CombToSMT.cpp ------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/Dialect/Func/IR/FuncOps.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/DialectConversion.h"
16 
17 namespace circt {
18 #define GEN_PASS_DEF_CONVERTCOMBTOSMT
19 #include "circt/Conversion/Passes.h.inc"
20 } // namespace circt
21 
22 using namespace circt;
23 using namespace comb;
24 
25 //===----------------------------------------------------------------------===//
26 // Conversion patterns
27 //===----------------------------------------------------------------------===//
28 
29 namespace {
30 /// Lower a comb::ReplicateOp operation to smt::RepeatOp
31 struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
33 
34  LogicalResult
35  matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
36  ConversionPatternRewriter &rewriter) const override {
37  rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
38  adaptor.getInput());
39  return success();
40  }
41 };
42 
43 /// Lower a comb::ICmpOp operation to a smt::BVCmpOp, smt::EqOp or
44 /// smt::DistinctOp
45 struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
47 
48  LogicalResult
49  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
50  ConversionPatternRewriter &rewriter) const override {
51  if (adaptor.getPredicate() == ICmpPredicate::weq ||
52  adaptor.getPredicate() == ICmpPredicate::ceq ||
53  adaptor.getPredicate() == ICmpPredicate::wne ||
54  adaptor.getPredicate() == ICmpPredicate::cne)
55  return rewriter.notifyMatchFailure(op,
56  "comparison predicate not supported");
57 
58  if (adaptor.getPredicate() == ICmpPredicate::eq) {
59  rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
60  adaptor.getRhs());
61  return success();
62  }
63 
64  if (adaptor.getPredicate() == ICmpPredicate::ne) {
65  rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
66  adaptor.getRhs());
67  return success();
68  }
69 
70  smt::BVCmpPredicate pred;
71  switch (adaptor.getPredicate()) {
72  case ICmpPredicate::sge:
73  pred = smt::BVCmpPredicate::sge;
74  break;
75  case ICmpPredicate::sgt:
76  pred = smt::BVCmpPredicate::sgt;
77  break;
78  case ICmpPredicate::sle:
79  pred = smt::BVCmpPredicate::sle;
80  break;
81  case ICmpPredicate::slt:
82  pred = smt::BVCmpPredicate::slt;
83  break;
84  case ICmpPredicate::uge:
85  pred = smt::BVCmpPredicate::uge;
86  break;
87  case ICmpPredicate::ugt:
88  pred = smt::BVCmpPredicate::ugt;
89  break;
90  case ICmpPredicate::ule:
91  pred = smt::BVCmpPredicate::ule;
92  break;
93  case ICmpPredicate::ult:
94  pred = smt::BVCmpPredicate::ult;
95  break;
96  default:
97  llvm_unreachable("all cases handled above");
98  }
99 
100  rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
101  adaptor.getRhs());
102  return success();
103  }
104 };
105 
106 /// Lower a comb::ExtractOp operation to an smt::ExtractOp
107 struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
109 
110  LogicalResult
111  matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
112  ConversionPatternRewriter &rewriter) const override {
113 
114  rewriter.replaceOpWithNewOp<smt::ExtractOp>(
115  op, typeConverter->convertType(op.getResult().getType()),
116  adaptor.getLowBitAttr(), adaptor.getInput());
117  return success();
118  }
119 };
120 
121 /// Lower a comb::MuxOp operation to an smt::IteOp
122 struct MuxOpConversion : OpConversionPattern<MuxOp> {
124 
125  LogicalResult
126  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
127  ConversionPatternRewriter &rewriter) const override {
128  Value condition = typeConverter->materializeTargetConversion(
129  rewriter, op.getLoc(), smt::BoolType::get(getContext()),
130  adaptor.getCond());
131  rewriter.replaceOpWithNewOp<smt::IteOp>(
132  op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
133  return success();
134  }
135 };
136 
137 /// Lower a comb::SubOp operation to an smt::BVNegOp + smt::BVAddOp
138 struct SubOpConversion : OpConversionPattern<SubOp> {
140 
141  LogicalResult
142  matchAndRewrite(SubOp op, OpAdaptor adaptor,
143  ConversionPatternRewriter &rewriter) const override {
144  Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
145  rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
146  return success();
147  }
148 };
149 
150 /// Lower a comb::ParityOp operation to a chain of smt::Extract + XOr ops
151 struct ParityOpConversion : OpConversionPattern<ParityOp> {
153 
154  LogicalResult
155  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
156  ConversionPatternRewriter &rewriter) const override {
157  Location loc = op.getLoc();
158  unsigned bitwidth =
159  cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
160 
161  // Note: the SMT bitvector type does not support 0 bitwidth vectors and thus
162  // the type conversion should already fail.
163  Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
164  Value runner =
165  rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
166  for (unsigned i = 1; i < bitwidth; ++i) {
167  Value ext =
168  rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
169  runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
170  }
171 
172  rewriter.replaceOp(op, runner);
173  return success();
174  }
175 };
176 
177 /// Lower the SourceOp to the TargetOp one-to-one.
178 template <typename SourceOp, typename TargetOp>
179 struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
181  using OpAdaptor = typename SourceOp::Adaptor;
182 
183  LogicalResult
184  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const override {
186 
187  rewriter.replaceOpWithNewOp<TargetOp>(
188  op,
190  op.getResult().getType()),
191  adaptor.getOperands());
192  return success();
193  }
194 };
195 
196 /// Lower the SourceOp to the TargetOp special-casing if the second operand is
197 /// zero to return a new symbolic value.
198 template <typename SourceOp, typename TargetOp>
199 struct DivisionOpConversion : OpConversionPattern<SourceOp> {
201  using OpAdaptor = typename SourceOp::Adaptor;
202 
203  LogicalResult
204  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
205  ConversionPatternRewriter &rewriter) const override {
206  Location loc = op.getLoc();
207  auto type = dyn_cast<smt::BitVectorType>(adaptor.getRhs().getType());
208  if (!type)
209  return failure();
210 
211  auto resultType = OpConversionPattern<SourceOp>::typeConverter->convertType(
212  op.getResult().getType());
213  Value zero =
214  rewriter.create<smt::BVConstantOp>(loc, APInt(type.getWidth(), 0));
215  Value isZero = rewriter.create<smt::EqOp>(loc, adaptor.getRhs(), zero);
216  Value symbolicVal = rewriter.create<smt::DeclareFunOp>(loc, resultType);
217  Value division =
218  rewriter.create<TargetOp>(loc, resultType, adaptor.getOperands());
219  rewriter.replaceOpWithNewOp<smt::IteOp>(op, isZero, symbolicVal, division);
220  return success();
221  }
222 };
223 
224 /// Converts an operation with a variadic number of operands to a chain of
225 /// binary operations assuming left-associativity of the operation.
226 template <typename SourceOp, typename TargetOp>
227 struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
229  using OpAdaptor = typename SourceOp::Adaptor;
230 
231  LogicalResult
232  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
233  ConversionPatternRewriter &rewriter) const override {
234 
235  ValueRange operands = adaptor.getOperands();
236  if (operands.size() < 2)
237  return failure();
238 
239  Value runner = operands[0];
240  for (Value operand : operands.drop_front())
241  runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
242 
243  rewriter.replaceOp(op, runner);
244  return success();
245  }
246 };
247 
248 } // namespace
249 
250 //===----------------------------------------------------------------------===//
251 // Convert Comb to SMT pass
252 //===----------------------------------------------------------------------===//
253 
254 namespace {
255 struct ConvertCombToSMTPass
256  : public impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
257  void runOnOperation() override;
258 };
259 } // namespace
260 
261 void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
262  RewritePatternSet &patterns) {
263  patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
264  SubOpConversion, MuxOpConversion, ParityOpConversion,
265  OneToOneOpConversion<ShlOp, smt::BVShlOp>,
266  OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
267  OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
268  DivisionOpConversion<DivSOp, smt::BVSDivOp>,
269  DivisionOpConversion<DivUOp, smt::BVUDivOp>,
270  DivisionOpConversion<ModSOp, smt::BVSRemOp>,
271  DivisionOpConversion<ModUOp, smt::BVURemOp>,
272  VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
273  VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
274  VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
275  VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
276  VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
277  VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
278  converter, patterns.getContext());
279 
280  // TODO: there are two unsupported operations in the comb dialect: 'parity'
281  // and 'truth_table'.
282 }
283 
284 void ConvertCombToSMTPass::runOnOperation() {
285  ConversionTarget target(getContext());
286  target.addIllegalDialect<comb::CombDialect>();
287  target.addLegalDialect<smt::SMTDialect>();
288 
289  RewritePatternSet patterns(&getContext());
290  TypeConverter converter;
291  populateHWToSMTTypeConverter(converter);
292  // Also add HW patterns because some 'comb' canonicalizers produce constant
293  // operations, i.e., even if there is absolutely no HW operation present
294  // initially, we might have to convert one.
297 
298  if (failed(mlir::applyPartialConversion(getOperation(), target,
299  std::move(patterns))))
300  return signalPassFailure();
301 }
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition: HWToSMT.cpp:221
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
Definition: CombToSMT.cpp:261
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition: HWToSMT.cpp:126
Definition: comb.py:1