CIRCT 22.0.0git
Loading...
Searching...
No Matches
DatapathToComb.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/Analysis/TopologicalSortUtils.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/PatternMatch.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19#include "llvm/Support/Debug.h"
20#include "llvm/Support/KnownBits.h"
21
22#define DEBUG_TYPE "datapath-to-comb"
23
24namespace circt {
25#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB
26#include "circt/Conversion/Passes.h.inc"
27} // namespace circt
28
29using namespace circt;
30using namespace datapath;
31
32// A wrapper for comb::extractBits that returns a SmallVector<Value>.
33static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
34 SmallVector<Value> bits;
35 comb::extractBits(builder, val, bits);
36 return bits;
37}
38
39//===----------------------------------------------------------------------===//
40// Conversion patterns
41//===----------------------------------------------------------------------===//
42
43namespace {
44// Replace compressor by an adder of the inputs and zero for the other results:
45// compress(a,b,c,d) -> {a+b+c+d, 0}
46// Facilitates use of downstream compression algorithms e.g. Yosys
47struct DatapathCompressOpAddConversion : mlir::OpRewritePattern<CompressOp> {
49 LogicalResult
50 matchAndRewrite(CompressOp op,
51 mlir::PatternRewriter &rewriter) const override {
52 Location loc = op.getLoc();
53 auto inputs = op.getOperands();
54 unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
55 // Sum all the inputs - set that to result value 0
56 auto addOp = comb::AddOp::create(rewriter, loc, inputs, true);
57 // Replace remaining results with zeros
58 auto zeroOp = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));
59 SmallVector<Value> results(op.getNumResults() - 1, zeroOp);
60 results.push_back(addOp);
61 rewriter.replaceOp(op, results);
62 return success();
63 }
64};
65
66// Replace compressor by a wallace tree of full-adders
67struct DatapathCompressOpConversion : mlir::OpRewritePattern<CompressOp> {
68 DatapathCompressOpConversion(MLIRContext *context,
70 : mlir::OpRewritePattern<CompressOp>(context), analysis(analysis) {}
71
72 LogicalResult
73 matchAndRewrite(CompressOp op,
74 mlir::PatternRewriter &rewriter) const override {
75 Location loc = op.getLoc();
76 auto inputs = op.getOperands();
77
78 SmallVector<SmallVector<Value>> addends;
79 for (auto input : inputs) {
80 addends.push_back(
81 extractBits(rewriter, input)); // Extract bits from each input
82 }
83
84 // Compressor tree reduction
85 auto width = inputs[0].getType().getIntOrFloatBitWidth();
86 auto targetAddends = op.getNumResults();
87 datapath::CompressorTree comp(width, addends, loc);
88
89 if (analysis) {
90 // Update delay information with arrival times
91 if (failed(comp.withInputDelays(
92 [&](Value v) { return analysis->getMaxDelay(v, 0); })))
93 return failure();
94 }
95
96 rewriter.replaceOp(op, comp.compressToHeight(rewriter, targetAddends));
97 return success();
98 }
99
100private:
101 aig::IncrementalLongestPathAnalysis *analysis = nullptr;
102};
103
/// Lower datapath.partial_product to an array of partial-product rows, either
/// a simple AND array or a radix-4 Booth array (chosen by a width heuristic or
/// forced via the `forceBooth` flag).
struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
  using OpRewritePattern<PartialProductOp>::OpRewritePattern;

  DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth)
      : OpRewritePattern<PartialProductOp>(context), forceBooth(forceBooth){};

  // When true, always use the Booth array lowering regardless of width.
  const bool forceBooth;

  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {

    Value a = op.getLhs();
    Value b = op.getRhs();
    unsigned width = a.getType().getIntOrFloatBitWidth();

    // Skip a zero width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(0), 0);
      return success();
    }

    // Use result rows as a heuristic to guide partial product
    // implementation: Booth halves the number of rows, which pays off for
    // larger arrays (> 16 rows).
    if (op.getNumResults() > 16 || forceBooth)
      return lowerBoothArray(rewriter, a, b, op, width);
    else
      return lowerAndArray(rewriter, a, b, op, width);
  }

private:
  /// AND-array lowering: row i is `a` masked by bit b[i], shifted left by i
  /// and truncated back to `width` bits. Produces op.getNumResults() rows.
  static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
                                     Value b, PartialProductOp op,
                                     unsigned width) {

    Location loc = op.getLoc();
    // Keep a as a bitvector - multiply by each digit of b
    SmallVector<Value> bBits = extractBits(rewriter, b);

    SmallVector<Value> partialProducts;
    partialProducts.reserve(width);
    // AND Array Construction:
    // partialProducts[i] = ({b[i],..., b[i]} & a) << i
    assert(op.getNumResults() <= width &&
           "Cannot return more results than the operator width");

    for (unsigned i = 0; i < op.getNumResults(); ++i) {
      // Replicate b[i] across the full width and AND with a.
      auto repl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, bBits[i], width);
      auto ppRow = rewriter.createOrFold<comb::AndOp>(loc, repl, a);
      if (i == 0) {
        // Row 0 needs no shift.
        partialProducts.push_back(ppRow);
        continue;
      }
      // Shift left by i via concatenation with an i-bit zero constant.
      auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(i, 0));
      auto ppAlign =
          comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
      auto ppAlignTrunc = rewriter.createOrFold<comb::ExtractOp>(
          loc, ppAlign, 0, width); // Truncate to width+i bits
      partialProducts.push_back(ppAlignTrunc);
    }

    rewriter.replaceOp(op, partialProducts);
    return success();
  }

  /// Radix-4 Booth-array lowering: encodes overlapping triples of b's bits to
  /// select each row from {-2a, -a, 0, +a, +2a}, roughly halving the number
  /// of rows relative to the AND array. Uses known-bits information to trim
  /// leading zeros and applies the standard sign-extension optimisation.
  static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a,
                                       Value b, PartialProductOp op,
                                       unsigned width) {
    Location loc = op.getLoc();
    auto zeroFalse = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
    auto zeroWidth = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));

    // Detect leading zeros in multiplicand due to zero-extension
    // and truncate to reduce partial product bits
    // {'0, a} * {'0, b}
    auto rowWidth = width;
    auto knownBitsA = comb::computeKnownBits(a);
    if (!knownBitsA.Zero.isZero()) {
      if (knownBitsA.Zero.countLeadingOnes() > 1) {
        // Retain one leading zero to represent 2*{1'b0, a} = {a, 1'b0}
        // {'0, a} -> {1'b0, a}
        rowWidth -= knownBitsA.Zero.countLeadingOnes() - 1;
        a = rewriter.createOrFold<comb::ExtractOp>(loc, a, 0, rowWidth);
      }
    }
    auto oneRowWidth =
        hw::ConstantOp::create(rewriter, loc, APInt(rowWidth, 1));
    // Booth encoding will select each row from {-2a, -1a, 0, 1a, 2a}
    Value twoA = rewriter.createOrFold<comb::ShlOp>(loc, a, oneRowWidth);

    // Encode based on the bits of b
    // TODO: sort a and b based on non-zero bits to encode the smaller input
    SmallVector<Value> bBits = extractBits(rewriter, b);

    // Identify zero bits of b to reduce height of partial product array
    auto knownBitsB = comb::computeKnownBits(b);
    if (!knownBitsB.Zero.isZero()) {
      for (unsigned i = 0; i < width; ++i)
        if (knownBitsB.Zero[i])
          bBits[i] = zeroFalse;
    }

    SmallVector<Value> partialProducts;
    partialProducts.reserve(width);

    // Booth encoding halves array height by grouping three bits at a time:
    // partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1]) << 2*i
    // encNeg \approx (-2*b[2*i+1] + b[2*i] + b[2*i-1]) <= 0
    // encOne = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 1
    // encTwo = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 2
    Value encNegPrev;

    // For even width - additional row contains the final sign correction
    for (unsigned i = 0; i <= width; i += 2) {
      // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0)
      Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1];
      Value bi = (i < width) ? bBits[i] : zeroFalse;
      Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse;

      // Is the encoding zero or negative (an approximation)
      Value encNeg = bip1;
      // Is the encoding one = b[i] xor b[i-1]
      Value encOne = rewriter.createOrFold<comb::XorOp>(loc, bi, bim1, true);
      // Is the encoding two = (bip1 & ~bi & ~bim1) | (~bip1 & bi & bim1)
      Value constOne = hw::ConstantOp::create(rewriter, loc, APInt(1, 1));
      Value biInv = rewriter.createOrFold<comb::XorOp>(loc, bi, constOne, true);
      Value bip1Inv =
          rewriter.createOrFold<comb::XorOp>(loc, bip1, constOne, true);
      Value bim1Inv =
          rewriter.createOrFold<comb::XorOp>(loc, bim1, constOne, true);

      Value andLeft = rewriter.createOrFold<comb::AndOp>(
          loc, ValueRange{bip1Inv, bi, bim1}, true);
      Value andRight = rewriter.createOrFold<comb::AndOp>(
          loc, ValueRange{bip1, biInv, bim1Inv}, true);
      Value encTwo =
          rewriter.createOrFold<comb::OrOp>(loc, andLeft, andRight, true);

      // Replicate the one-bit encodings across the row width for masking.
      Value encNegRepl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, rowWidth);
      Value encOneRepl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, encOne, rowWidth);
      Value encTwoRepl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, encTwo, rowWidth);

      // Select between 2*a or 1*a or 0*a
      Value selTwoA = rewriter.createOrFold<comb::AndOp>(loc, encTwoRepl, twoA);
      Value selOneA = rewriter.createOrFold<comb::AndOp>(loc, encOneRepl, a);
      Value magA =
          rewriter.createOrFold<comb::OrOp>(loc, selTwoA, selOneA, true);

      // Conditionally invert the row (one's complement; the +1 for two's
      // complement is added as the sign correction in the following row)
      Value ppRow =
          rewriter.createOrFold<comb::XorOp>(loc, magA, encNegRepl, true);

      // Sign-extension Optimisation:
      // Section 7.2.2 of "Application Specific Arithmetic" by Dinechin &
      // Kumm Handle sign-extension and padding to full width s = encNeg
      // (sign-bit) {s, s, s, s, s, pp} = {1, 1, 1, 1, 1, pp}
      //                                + {0, 0, 0, 0,!s, '0}
      // Applying this to every row we create an upper-triangle of 1s that
      // can be optimised away since they will not affect the final sum.
      // {!s3,  0,!s2,  0,!s1,  0}
      // {  1,  1,  1,  1,  1, p1}
      // {  1,  1,  1, p2       }
      // {  1, p3               }
      if (rowWidth < width) {
        auto padding = width - rowWidth;
        auto encNegInv = bip1Inv;

        // Sign-extension trick not worth it for padding < 3
        if (padding < 3) {
          Value encNegPad =
              rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, padding);
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{encNegPad, ppRow}); // Pad to full width
        } else if (i == 0) {
          // First row = {!encNeg, encNeg, encNeg, ppRow}
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{encNegInv, encNeg, encNeg, ppRow});
        } else {
          // Remaining rows = {1, !encNeg, ppRow}
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{constOne, encNegInv, ppRow});
        }

        // Zero pad to full width
        // (this local rowWidth intentionally shadows the outer one)
        auto rowWidth = ppRow.getType().getIntOrFloatBitWidth();
        if (rowWidth < width) {
          auto zeroPad =
              hw::ConstantOp::create(rewriter, loc, APInt(width - rowWidth, 0));
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{zeroPad, ppRow});
        }
      }

      // No sign-correction in the first row
      if (i == 0) {
        partialProducts.push_back(ppRow);
        encNegPrev = encNeg;
        continue;
      }

      // Insert a sign-correction from the previous row
      assert(i >= 2 && "Expected i to be at least 2 for sign correction");
      // {ppRow, 0, encNegPrev} << 2*(i-1)
      Value withSignCorrection = rewriter.createOrFold<comb::ConcatOp>(
          loc, ValueRange{ppRow, zeroFalse, encNegPrev});
      Value ppAlignPre = rewriter.createOrFold<comb::ExtractOp>(
          loc, withSignCorrection, 0, width);
      Value shiftBy =
          hw::ConstantOp::create(rewriter, loc, APInt(width, i - 2));
      Value ppAlign =
          rewriter.createOrFold<comb::ShlOp>(loc, ppAlignPre, shiftBy);
      partialProducts.push_back(ppAlign);
      encNegPrev = encNeg;

      // Stop once we have produced the requested number of rows.
      if (partialProducts.size() == op.getNumResults())
        break;
    }

    // Zero-pad to match the required output width
    while (partialProducts.size() < op.getNumResults())
      partialProducts.push_back(zeroWidth);

    assert(partialProducts.size() == op.getNumResults() &&
           "Expected number of booth partial products to match results");

    rewriter.replaceOp(op, partialProducts);
    return success();
  }
};
336} // namespace
337
338//===----------------------------------------------------------------------===//
339// Convert Datapath to Comb pass
340//===----------------------------------------------------------------------===//
341
342namespace {
/// Pass that lowers all Datapath-dialect operations to Comb/HW operations
/// using the conversion patterns defined above. Options (forceBooth,
/// timingAware, lowerCompressToAdd) come from the generated base class.
struct ConvertDatapathToCombPass
    : public impl::ConvertDatapathToCombBase<ConvertDatapathToCombPass> {
  void runOnOperation() override;
  using ConvertDatapathToCombBase<
      ConvertDatapathToCombPass>::ConvertDatapathToCombBase;
};
349} // namespace
350
352 Operation *op, RewritePatternSet &&patterns,
354 // TODO: Topologically sort the operations in the module to ensure that all
355 // dependencies are processed before their users.
356 mlir::GreedyRewriteConfig config;
357 // Set the listener to update timing information
358 // HACK: Setting max iterations to 2 to ensure that the patterns are
359 // one-shot, making sure target operations are datapath operations are
360 // replaced.
361 config.setMaxIterations(2).setListener(analysis).setUseTopDownTraversal(true);
362
363 // Apply the patterns greedily
364 if (failed(mlir::applyPatternsGreedily(op, std::move(patterns), config)))
365 return failure();
366
367 return success();
368}
369
370void ConvertDatapathToCombPass::runOnOperation() {
371 RewritePatternSet patterns(&getContext());
372
373 patterns.add<DatapathPartialProductOpConversion>(patterns.getContext(),
374 forceBooth);
375 aig::IncrementalLongestPathAnalysis *analysis = nullptr;
376 if (timingAware)
377 analysis = &getAnalysis<aig::IncrementalLongestPathAnalysis>();
378 if (lowerCompressToAdd)
379 // Lower compressors to simple add operations for downstream optimisations
380 patterns.add<DatapathCompressOpAddConversion>(patterns.getContext());
381 else
382 // Lower compressors to a complete gate-level implementation
383 patterns.add<DatapathCompressOpConversion>(patterns.getContext(), analysis);
384
386 getOperation(), std::move(patterns), analysis)))
387 return signalPassFailure();
388
389 // Verify that all Datapath operations have been successfully converted.
390 // Walk the operation and check for any remaining Datapath dialect
391 // operations.
392 auto result = getOperation()->walk([&](Operation *op) {
393 if (llvm::isa_and_nonnull<datapath::DatapathDialect>(op->getDialect())) {
394 op->emitError("Datapath operation not converted: ") << *op;
395 return WalkResult::interrupt();
396 }
397 return WalkResult::advance();
398 });
399 if (result.wasInterrupted())
400 return signalPassFailure();
401}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
Definition CombToAIG.cpp:57
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static LogicalResult applyPatternsGreedilyWithTimingInfo(Operation *op, RewritePatternSet &&patterns, aig::IncrementalLongestPathAnalysis *analysis)
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.