CIRCT 22.0.0git
Loading...
Searching...
No Matches
DatapathToSMT.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "circt/Conversion/DatapathToSMT.h"
#include "circt/Conversion/HWToSMT.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "mlir/Dialect/SMT/IR/SMTOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
15
16namespace circt {
17#define GEN_PASS_DEF_CONVERTDATAPATHTOSMT
18#include "circt/Conversion/Passes.h.inc"
19} // namespace circt
20
21using namespace mlir;
22using namespace circt;
23using namespace datapath;
24
25//===----------------------------------------------------------------------===//
26// Conversion patterns
27//===----------------------------------------------------------------------===//
28
29namespace {
30
31// Lower to an SMT assertion that summing the results is equivalent to summing
32// the compress inputs
33// d:2 = compress(a, b, c) ->
34// assert(d#0 + d#1 == a + b + c)
35struct CompressOpConversion : OpConversionPattern<CompressOp> {
37
38 LogicalResult
39 matchAndRewrite(CompressOp op, OpAdaptor adaptor,
40 ConversionPatternRewriter &rewriter) const override {
41
42 ValueRange operands = adaptor.getOperands();
43 ValueRange results = op.getResults();
44
45 // Sum operands
46 Value operandRunner = operands[0];
47 for (Value operand : operands.drop_front())
48 operandRunner =
49 smt::BVAddOp::create(rewriter, op.getLoc(), operandRunner, operand);
50
51 // Create free variables
52 SmallVector<Value, 2> newResults;
53 newResults.reserve(results.size());
54 for (Value result : results) {
55 auto declareFunOp = smt::DeclareFunOp::create(
56 rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
57 newResults.push_back(declareFunOp.getResult());
58 }
59
60 // Sum the free variables
61 Value resultRunner = newResults.front();
62 for (auto freeVar : llvm::drop_begin(newResults, 1))
63 resultRunner =
64 smt::BVAddOp::create(rewriter, op.getLoc(), resultRunner, freeVar);
65
66 // Assert sum operands == sum results (free variables)
67 auto premise =
68 smt::EqOp::create(rewriter, op.getLoc(), operandRunner, resultRunner);
69 // Encode via an assertion (could be relaxed to an assumption).
70 smt::AssertOp::create(rewriter, op.getLoc(), premise);
71
72 if (newResults.size() != results.size())
73 return rewriter.notifyMatchFailure(op, "expected same number of results");
74
75 rewriter.replaceOp(op, newResults);
76 return success();
77 }
78};
79
80// Lower to an SMT assertion that summing the results is equivalent to the
81// product of the partial_product inputs
82// c:<N> = partial_product(a, b) ->
83// assert(c#0 + ... + c#<N-1> == a * b)
84struct PartialProductOpConversion : OpConversionPattern<PartialProductOp> {
85 using OpConversionPattern<PartialProductOp>::OpConversionPattern;
86
87 LogicalResult
88 matchAndRewrite(PartialProductOp op, OpAdaptor adaptor,
89 ConversionPatternRewriter &rewriter) const override {
90
91 ValueRange operands = adaptor.getOperands();
92 ValueRange results = op.getResults();
93
94 // Multiply the operands
95 auto mulResult =
96 smt::BVMulOp::create(rewriter, op.getLoc(), operands[0], operands[1]);
97
98 // Create free variables
99 SmallVector<Value, 2> newResults;
100 newResults.reserve(results.size());
101 for (Value result : results) {
102 auto declareFunOp = smt::DeclareFunOp::create(
103 rewriter, op.getLoc(), typeConverter->convertType(result.getType()));
104 newResults.push_back(declareFunOp.getResult());
105 }
106
107 // Sum the free variables
108 Value resultRunner = newResults.front();
109 for (auto freeVar : llvm::drop_begin(newResults, 1))
110 resultRunner =
111 smt::BVAddOp::create(rewriter, op.getLoc(), resultRunner, freeVar);
112
113 // Assert product of operands == sum results (free variables)
114 auto premise =
115 smt::EqOp::create(rewriter, op.getLoc(), mulResult, resultRunner);
116 // Encode via an assertion (could be relaxed to an assumption).
117 smt::AssertOp::create(rewriter, op.getLoc(), premise);
118
119 if (newResults.size() != results.size())
120 return rewriter.notifyMatchFailure(op, "expected same number of results");
121
122 rewriter.replaceOp(op, newResults);
123 return success();
124 }
125};
126} // namespace
127
128//===----------------------------------------------------------------------===//
129// Convert Datapath to SMT pass
130//===----------------------------------------------------------------------===//
131
132namespace {
133struct ConvertDatapathToSMTPass
134 : public circt::impl::ConvertDatapathToSMTBase<ConvertDatapathToSMTPass> {
135 void runOnOperation() override;
136};
137} // namespace
138
140 TypeConverter &converter, RewritePatternSet &patterns) {
141 patterns.add<CompressOpConversion, PartialProductOpConversion>(
142 converter, patterns.getContext());
143}
144
145void ConvertDatapathToSMTPass::runOnOperation() {
146 ConversionTarget target(getContext());
147 target.addIllegalDialect<datapath::DatapathDialect>();
148 target.addLegalDialect<smt::SMTDialect>();
149
150 RewritePatternSet patterns(&getContext());
151 TypeConverter converter;
154
155 if (failed(mlir::applyPartialConversion(getOperation(), target,
156 std::move(patterns))))
157 return signalPassFailure();
158}
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateDatapathToSMTConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
Get the Datapath to SMT conversion patterns.
void populateHWToSMTTypeConverter(TypeConverter &converter)
Get the HW to SMT type conversions.
Definition HWToSMT.cpp:218