CIRCT  19.0.0git
CombToSMT.cpp
Go to the documentation of this file.
1 //===- CombToSMT.cpp ------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 namespace circt {
17 #define GEN_PASS_DEF_CONVERTCOMBTOSMT
18 #include "circt/Conversion/Passes.h.inc"
19 } // namespace circt
20 
21 using namespace circt;
22 using namespace comb;
23 
24 //===----------------------------------------------------------------------===//
25 // Conversion patterns
26 //===----------------------------------------------------------------------===//
27 
28 namespace {
29 /// Lower a comb::ReplicateOp operation to smt::RepeatOp
30 struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
32 
33  LogicalResult
34  matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
35  ConversionPatternRewriter &rewriter) const override {
36  rewriter.replaceOpWithNewOp<smt::RepeatOp>(op, op.getMultiple(),
37  adaptor.getInput());
38  return success();
39  }
40 };
41 
42 /// Lower a comb::ICmpOp operation to a smt::BVCmpOp, smt::EqOp or
43 /// smt::DistinctOp
44 struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
46 
47  LogicalResult
48  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
49  ConversionPatternRewriter &rewriter) const override {
50  if (adaptor.getPredicate() == ICmpPredicate::weq ||
51  adaptor.getPredicate() == ICmpPredicate::ceq ||
52  adaptor.getPredicate() == ICmpPredicate::wne ||
53  adaptor.getPredicate() == ICmpPredicate::cne)
54  return rewriter.notifyMatchFailure(op,
55  "comparison predicate not supported");
56 
57  if (adaptor.getPredicate() == ICmpPredicate::eq) {
58  rewriter.replaceOpWithNewOp<smt::EqOp>(op, adaptor.getLhs(),
59  adaptor.getRhs());
60  return success();
61  }
62 
63  if (adaptor.getPredicate() == ICmpPredicate::ne) {
64  rewriter.replaceOpWithNewOp<smt::DistinctOp>(op, adaptor.getLhs(),
65  adaptor.getRhs());
66  return success();
67  }
68 
69  smt::BVCmpPredicate pred;
70  switch (adaptor.getPredicate()) {
71  case ICmpPredicate::sge:
72  pred = smt::BVCmpPredicate::sge;
73  break;
74  case ICmpPredicate::sgt:
75  pred = smt::BVCmpPredicate::sgt;
76  break;
77  case ICmpPredicate::sle:
78  pred = smt::BVCmpPredicate::sle;
79  break;
80  case ICmpPredicate::slt:
81  pred = smt::BVCmpPredicate::slt;
82  break;
83  case ICmpPredicate::uge:
84  pred = smt::BVCmpPredicate::uge;
85  break;
86  case ICmpPredicate::ugt:
87  pred = smt::BVCmpPredicate::ugt;
88  break;
89  case ICmpPredicate::ule:
90  pred = smt::BVCmpPredicate::ule;
91  break;
92  case ICmpPredicate::ult:
93  pred = smt::BVCmpPredicate::ult;
94  break;
95  default:
96  llvm_unreachable("all cases handled above");
97  }
98 
99  rewriter.replaceOpWithNewOp<smt::BVCmpOp>(op, pred, adaptor.getLhs(),
100  adaptor.getRhs());
101  return success();
102  }
103 };
104 
105 /// Lower a comb::ExtractOp operation to an smt::ExtractOp
106 struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
108 
109  LogicalResult
110  matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
111  ConversionPatternRewriter &rewriter) const override {
112 
113  rewriter.replaceOpWithNewOp<smt::ExtractOp>(
114  op, typeConverter->convertType(op.getResult().getType()),
115  adaptor.getLowBitAttr(), adaptor.getInput());
116  return success();
117  }
118 };
119 
120 /// Lower a comb::MuxOp operation to an smt::IteOp
121 struct MuxOpConversion : OpConversionPattern<MuxOp> {
123 
124  LogicalResult
125  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
126  ConversionPatternRewriter &rewriter) const override {
127  Value condition = typeConverter->materializeTargetConversion(
128  rewriter, op.getLoc(), smt::BoolType::get(getContext()),
129  adaptor.getCond());
130  rewriter.replaceOpWithNewOp<smt::IteOp>(
131  op, condition, adaptor.getTrueValue(), adaptor.getFalseValue());
132  return success();
133  }
134 };
135 
136 /// Lower a comb::SubOp operation to an smt::BVNegOp + smt::BVAddOp
137 struct SubOpConversion : OpConversionPattern<SubOp> {
139 
140  LogicalResult
141  matchAndRewrite(SubOp op, OpAdaptor adaptor,
142  ConversionPatternRewriter &rewriter) const override {
143  Value negRhs = rewriter.create<smt::BVNegOp>(op.getLoc(), adaptor.getRhs());
144  rewriter.replaceOpWithNewOp<smt::BVAddOp>(op, adaptor.getLhs(), negRhs);
145  return success();
146  }
147 };
148 
149 /// Lower a comb::ParityOp operation to a chain of smt::Extract + XOr ops
150 struct ParityOpConversion : OpConversionPattern<ParityOp> {
152 
153  LogicalResult
154  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
155  ConversionPatternRewriter &rewriter) const override {
156  Location loc = op.getLoc();
157  unsigned bitwidth =
158  cast<smt::BitVectorType>(adaptor.getInput().getType()).getWidth();
159 
160  // Note: the SMT bitvector type does not support 0 bitwidth vectors and thus
161  // the type conversion should already fail.
162  Type oneBitTy = smt::BitVectorType::get(getContext(), 1);
163  Value runner =
164  rewriter.create<smt::ExtractOp>(loc, oneBitTy, 0, adaptor.getInput());
165  for (unsigned i = 1; i < bitwidth; ++i) {
166  Value ext =
167  rewriter.create<smt::ExtractOp>(loc, oneBitTy, i, adaptor.getInput());
168  runner = rewriter.create<smt::BVXOrOp>(loc, runner, ext);
169  }
170 
171  rewriter.replaceOp(op, runner);
172  return success();
173  }
174 };
175 
176 /// Lower the SourceOp to the TargetOp one-to-one.
177 template <typename SourceOp, typename TargetOp>
178 struct OneToOneOpConversion : OpConversionPattern<SourceOp> {
180  using OpAdaptor = typename SourceOp::Adaptor;
181 
182  LogicalResult
183  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
184  ConversionPatternRewriter &rewriter) const override {
185 
186  rewriter.replaceOpWithNewOp<TargetOp>(
187  op,
189  op.getResult().getType()),
190  adaptor.getOperands());
191  return success();
192  }
193 };
194 
195 /// Converts an operation with a variadic number of operands to a chain of
196 /// binary operations assuming left-associativity of the operation.
197 template <typename SourceOp, typename TargetOp>
198 struct VariadicToBinaryOpConversion : OpConversionPattern<SourceOp> {
200  using OpAdaptor = typename SourceOp::Adaptor;
201 
202  LogicalResult
203  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
204  ConversionPatternRewriter &rewriter) const override {
205 
206  ValueRange operands = adaptor.getOperands();
207  if (operands.size() < 2)
208  return failure();
209 
210  Value runner = operands[0];
211  for (Value operand : operands.drop_front())
212  runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
213 
214  rewriter.replaceOp(op, runner);
215  return success();
216  }
217 };
218 
219 } // namespace
220 
221 //===----------------------------------------------------------------------===//
222 // Convert Comb to SMT pass
223 //===----------------------------------------------------------------------===//
224 
225 namespace {
226 struct ConvertCombToSMTPass
227  : public impl::ConvertCombToSMTBase<ConvertCombToSMTPass> {
228  void runOnOperation() override;
229 };
230 } // namespace
231 
232 void circt::populateCombToSMTConversionPatterns(TypeConverter &converter,
233  RewritePatternSet &patterns) {
234  patterns.add<CombReplicateOpConversion, IcmpOpConversion, ExtractOpConversion,
235  SubOpConversion, MuxOpConversion, ParityOpConversion,
236  OneToOneOpConversion<ShlOp, smt::BVShlOp>,
237  OneToOneOpConversion<ShrUOp, smt::BVLShrOp>,
238  OneToOneOpConversion<ShrSOp, smt::BVAShrOp>,
239  OneToOneOpConversion<DivSOp, smt::BVSDivOp>,
240  OneToOneOpConversion<DivUOp, smt::BVUDivOp>,
241  OneToOneOpConversion<ModSOp, smt::BVSRemOp>,
242  OneToOneOpConversion<ModUOp, smt::BVURemOp>,
243  VariadicToBinaryOpConversion<ConcatOp, smt::ConcatOp>,
244  VariadicToBinaryOpConversion<AddOp, smt::BVAddOp>,
245  VariadicToBinaryOpConversion<MulOp, smt::BVMulOp>,
246  VariadicToBinaryOpConversion<AndOp, smt::BVAndOp>,
247  VariadicToBinaryOpConversion<OrOp, smt::BVOrOp>,
248  VariadicToBinaryOpConversion<XorOp, smt::BVXOrOp>>(
249  converter, patterns.getContext());
250 
251  // TODO: there are two unsupported operations in the comb dialect: 'parity'
252  // and 'truth_table'.
253 }
254 
255 void ConvertCombToSMTPass::runOnOperation() {
256  ConversionTarget target(getContext());
257  target.addIllegalDialect<comb::CombDialect>();
258  target.addLegalDialect<smt::SMTDialect>();
259 
260  RewritePatternSet patterns(&getContext());
261  TypeConverter converter;
262  populateHWToSMTTypeConverter(converter);
263  // Also add HW patterns because some 'comb' canonicalizers produce constant
264  // operations, i.e., even if there is absolutely no HW operation present
265  // initially, we might have to convert one.
268 
269  if (failed(mlir::applyPartialConversion(getOperation(), target,
270  std::move(patterns))))
271  return signalPassFailure();
272 }
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateHWToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the HW to SMT conversion patterns.
Definition: HWToSMT.cpp:174
void populateCombToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Comb to SMT conversion patterns.
Definition: CombToSMT.cpp:232
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition: HWToSMT.cpp:108
Definition: comb.py:1