CIRCT 22.0.0git
Loading...
Searching...
No Matches
ReduceDelay.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16namespace circt {
17namespace datapath {
18#define GEN_PASS_DEF_DATAPATHREDUCEDELAY
19#include "circt/Dialect/Datapath/DatapathPasses.h.inc"
20} // namespace datapath
21} // namespace circt
22
23using namespace circt;
24using namespace datapath;
25using namespace mlir;
26
27namespace {
28
29// Fold add operations even if used multiple times incurring area overhead as
30// transformation reduces shared logic - but reduces delay
31// add = a + b;
32// out1 = add + c;
33// out2 = add << d;
34// -->
35// add = a + b;
36// comp1 = compress(a, b, c);
37// out1 = comp1[0] + comp1[1];
38// out2 = add << d;
39struct FoldAddReplicate : public OpRewritePattern<comb::AddOp> {
40 using OpRewritePattern::OpRewritePattern;
41
42 LogicalResult matchAndRewrite(comb::AddOp addOp,
43 PatternRewriter &rewriter) const override {
44
45 SmallVector<Value, 8> newCompressOperands;
46 // Check if any operand is an AddOp that has not be folded by comb folds
47 for (Value operand : addOp.getOperands()) {
48 if (comb::AddOp nestedAddOp = operand.getDefiningOp<comb::AddOp>()) {
49 llvm::append_range(newCompressOperands, nestedAddOp.getOperands());
50 } else {
51 newCompressOperands.push_back(operand);
52 }
53 }
54
55 // Nothing to be folded
56 if (newCompressOperands.size() <= addOp.getNumOperands())
57 return failure();
58
59 // Create a new CompressOp with all collected operands
60 auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
61 newCompressOperands, 2);
62
63 // Add the results of the CompressOp
64 rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
65 true);
66 return success();
67 }
68};
69
// (a ? b + c : d + e) + f
// -->
// (a ? b : d) + (a ? c : e) + f
struct FoldMuxAdd : public OpRewritePattern<comb::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  // When used in conjunction with datapath canonicalization will only replicate
  // two input adders.
  LogicalResult matchAndRewrite(comb::AddOp addOp,
                                PatternRewriter &rewriter) const override {

    // Operand list for the compressor; grows past addOp's operand count only
    // when a mux-of-adds is actually rewritten.
    SmallVector<Value, 8> newCompressOperands;
    for (Value operand : addOp.getOperands()) {
      // Detect a mux operand - then check if it contains add operations
      comb::MuxOp nestedMuxOp = operand.getDefiningOp<comb::MuxOp>();

      // If not matched just add the operand without modification
      if (!nestedMuxOp) {
        newCompressOperands.push_back(operand);
        continue;
      }

      // Default to the mux arms themselves (one "addend" per side).
      SmallVector<Value> trueValOperands = {nestedMuxOp.getTrueValue()};
      SmallVector<Value> falseValOperands = {nestedMuxOp.getFalseValue()};
      // match a ? b + c : xx
      if (comb::AddOp trueVal =
              nestedMuxOp.getTrueValue().getDefiningOp<comb::AddOp>())
        trueValOperands = trueVal.getOperands();

      // match a ? xx : c + d
      if (comb::AddOp falseVal =
              nestedMuxOp.getFalseValue().getDefiningOp<comb::AddOp>())
        falseValOperands = falseVal.getOperands();

      auto maxOperands =
          std::max(trueValOperands.size(), falseValOperands.size());

      // No nested additions
      if (maxOperands <= 1) {
        newCompressOperands.push_back(operand);
        continue;
      }

      // Pad with zeros to match number of operands
      // a ? b + c : d -> (a ? b : d) + (a ? c : 0)
      // NOTE: ops are only created on this path, where maxOperands >= 2
      // guarantees the final size check below succeeds - so a failure()
      // return never leaves newly created ops behind.
      auto zero =
          hw::ConstantOp::create(rewriter, addOp.getLoc(),
                                 rewriter.getIntegerAttr(addOp.getType(), 0));
      // Build one mux per addend position, selecting between the (padded)
      // true/false addends with the original mux condition.
      for (size_t i = 0; i < maxOperands; ++i) {
        auto tOp = i < trueValOperands.size() ? trueValOperands[i] : zero;
        auto fOp = i < falseValOperands.size() ? falseValOperands[i] : zero;
        auto newMux = comb::MuxOp::create(rewriter, addOp.getLoc(),
                                          nestedMuxOp.getCond(), tOp, fOp);
        newCompressOperands.push_back(newMux.getResult());
      }
    }

    // Nothing to be folded
    if (newCompressOperands.size() <= addOp.getNumOperands())
      return failure();

    // Create a new CompressOp with all collected operands
    auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
                                                      newCompressOperands, 2);

    // Add the results of the CompressOp
    rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
                                             true);
    return success();
  }
};
141
struct ConvertCmpToAdd : public OpRewritePattern<comb::ICmpOp> {
  using OpRewritePattern::OpRewritePattern;

  // Applicable to unsigned comparisons without overflow:
  // a + b < c + d
  // -->
  // msb( {0,a} + {0,b} - {0,c} - {0,d} )
  LogicalResult matchAndRewrite(comb::ICmpOp op,
                                PatternRewriter &rewriter) const override {
    Value lhs = op.getLhs();
    Value rhs = op.getRhs();
    auto width = lhs.getType().getIntOrFloatBitWidth();

    // Only unsigned comparisons
    if (op.getPredicate() != comb::ICmpPredicate::ult &&
        op.getPredicate() != comb::ICmpPredicate::ule &&
        op.getPredicate() != comb::ICmpPredicate::ugt &&
        op.getPredicate() != comb::ICmpPredicate::uge)
      return failure();

    //                                             lhsMinusRhs   invertOut
    //---------------------------------------------------------------------
    // ult: a < b  -> a - b < 0                    true          false
    // ugt: a > b  -> b - a < 0                    false         false
    // uge: a >= b -> !(a < b) -> !(a - b < 0)     true          true
    // ule: a <= b -> !(a > b) -> !(b - a < 0)     false         true
    bool lhsMinusRhs = op.getPredicate() == comb::ICmpPredicate::ult ||
                       op.getPredicate() == comb::ICmpPredicate::uge;

    bool invertOut = op.getPredicate() == comb::ICmpPredicate::uge ||
                     op.getPredicate() == comb::ICmpPredicate::ule;

    // Compute rhs - lhs
    if (!lhsMinusRhs)
      std::swap(lhs, rhs);
    SmallVector<Value> lhsAddends = {lhs};
    // Detect adder inputs to either side of the comparison and detect overflow
    if (comb::AddOp lhsAdd = lhs.getDefiningOp<comb::AddOp>()) {
      // Check for no unsigned wrap (i.e. no overflow bits get truncated)
      if (lhsAdd->getAttrOfType<UnitAttr>("comb.nuw"))
        lhsAddends = lhsAdd.getOperands();
    }

    SmallVector<Value> rhsAddends = {rhs};
    // Detect adder inputs to either side of the comparison and detect overflow
    if (comb::AddOp rhsAdd = rhs.getDefiningOp<comb::AddOp>()) {
      // Check for no unsigned wrap (i.e. no overflow bits get truncated)
      if (rhsAdd->getAttrOfType<UnitAttr>("comb.nuw"))
        rhsAddends = rhsAdd.getOperands();
    }

    // No benefit to folding into a single addition - more expensive than
    // the original comparison
    if (lhsAddends.size() + rhsAddends.size() < 3)
      return failure();

    // Zero-extend every positive addend by one bit so the final sum carries
    // a sign bit that cannot be corrupted by overflow.
    SmallVector<Value> lhsExtend;
    for (auto addend : lhsAddends) {
      auto ext = comb::createZExt(rewriter, op.getLoc(), addend, width + 1);
      lhsExtend.push_back(ext);
    }

    // Subtract each rhs addend in two's complement: -x = ~x + 1, realised as
    // a bitwise NOT here plus a "+1 per addend" constant appended below.
    SmallVector<Value> rhsExtend;
    for (auto addend : rhsAddends) {
      auto ext = comb::createZExt(rewriter, op.getLoc(), addend, width + 1);
      auto negatedAddend = comb::createOrFoldNot(op.getLoc(), ext, rewriter);
      rhsExtend.push_back(negatedAddend);
    }

    // One "+1" for each negated addend completes the two's complement.
    rhsExtend.push_back(hw::ConstantOp::create(
        rewriter, op.getLoc(), APInt(width + 1, rhsExtend.size())));

    SmallVector<Value> allAddends = std::move(lhsExtend);
    llvm::append_range(allAddends, rhsExtend);
    auto add = comb::AddOp::create(rewriter, op.getLoc(), allAddends, false);
    // The MSB of the (width+1)-bit sum is the sign of lhs - rhs.
    auto msb = rewriter.createOrFold<comb::ExtractOp>(
        op.getLoc(), add.getResult(), width, 1);

    if (!invertOut) {
      rewriter.replaceOp(op, msb);
      return success();
    }

    auto notOp = comb::createOrFoldNot(op.getLoc(), msb, rewriter);
    rewriter.replaceOp(op, notOp);
    return success();
  }
};
230
231} // namespace
232
233namespace {
234struct DatapathReduceDelayPass
235 : public circt::datapath::impl::DatapathReduceDelayBase<
236 DatapathReduceDelayPass> {
237
238 void runOnOperation() override {
239 Operation *op = getOperation();
240 MLIRContext *ctx = op->getContext();
241
242 RewritePatternSet patterns(ctx);
243 patterns.add<FoldAddReplicate, FoldMuxAdd, ConvertCmpToAdd>(ctx);
244
245 if (failed(applyPatternsGreedily(op, std::move(patterns))))
246 signalPassFailure();
247 };
248};
249} // namespace
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.