DatapathToComb.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/TopologicalSortUtils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>

#define DEBUG_TYPE "datapath-to-comb"

namespace circt {
#define GEN_PASS_DEF_CONVERTDATAPATHTOCOMB
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace datapath;

// A wrapper for comb::extractBits that returns a SmallVector<Value>.
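// The returned vector is ordered LSB-first, i.e. bits[i] is bit i of val
// (this is how the callers below index into the result).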
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {
// Replace a compressor by an adder of the inputs and zeros for the other
// results:
//   compress(a,b,c,d) -> {0, ..., 0, a+b+c+d}
// This facilitates the use of downstream compression algorithms, e.g. Yosys.
struct DatapathCompressOpAddConversion : mlir::OpRewritePattern<CompressOp> {
  using mlir::OpRewritePattern<CompressOp>::OpRewritePattern;

  LogicalResult
  matchAndRewrite(CompressOp op,
                  mlir::PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto inputs = op.getOperands();
    unsigned width = inputs[0].getType().getIntOrFloatBitWidth();
    // Sum all the inputs - this becomes the final result value.
    auto addOp = comb::AddOp::create(rewriter, loc, inputs, true);
    // Replace the remaining results with zeros.
    auto zeroOp = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));
    SmallVector<Value> results(op.getNumResults() - 1, zeroOp);
    results.push_back(addOp);
    rewriter.replaceOp(op, results);
    return success();
  }
};

// Replace a compressor by a Wallace tree of full-adders.
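// A full-adder is a 3:2 compressor: three bits of equal weight x, y, z are
// reduced to sum = x ^ y ^ z (same weight) and carry = majority(x, y, z)
// (next weight up). Applied column-wise this reduces N addend rows to roughly
// ceil(2N/3) rows per stage; repeating until the requested height is reached
// yields the compressor (Wallace-style) tree built below.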
struct DatapathCompressOpConversion : mlir::OpRewritePattern<CompressOp> {
  DatapathCompressOpConversion(MLIRContext *context,
                               synth::IncrementalLongestPathAnalysis *analysis)
      : mlir::OpRewritePattern<CompressOp>(context), analysis(analysis) {}

  LogicalResult
  matchAndRewrite(CompressOp op,
                  mlir::PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto inputs = op.getOperands();

    SmallVector<SmallVector<Value>> addends;
    for (auto input : inputs) {
      addends.push_back(
          extractBits(rewriter, input)); // Extract bits from each input
    }

    // Compressor tree reduction
    auto width = inputs[0].getType().getIntOrFloatBitWidth();
    auto targetAddends = op.getNumResults();
    datapath::CompressorTree comp(width, addends, loc);

    if (analysis) {
      // Update delay information with arrival times
      if (failed(comp.withInputDelays(
              [&](Value v) { return analysis->getMaxDelay(v, 0); })))
        return failure();
    }

    rewriter.replaceOp(op, comp.compressToHeight(rewriter, targetAddends));
    return success();
  }

private:
  synth::IncrementalLongestPathAnalysis *analysis = nullptr;
};

struct DatapathPartialProductOpConversion : OpRewritePattern<PartialProductOp> {
  using OpRewritePattern<PartialProductOp>::OpRewritePattern;

  DatapathPartialProductOpConversion(MLIRContext *context, bool forceBooth)
      : OpRewritePattern<PartialProductOp>(context), forceBooth(forceBooth) {}

  const bool forceBooth;

  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {

    Value a = op.getLhs();
    Value b = op.getRhs();
    unsigned width = a.getType().getIntOrFloatBitWidth();

    // Skip a zero width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(0), 0);
      return success();
    }

    // A square partial product array can be reduced to an upper triangular
    // array. For example, the AND array for a 4-bit squarer:
    //      0    0    0 a0a3 a0a2 a0a1 a0a0
    //      0    0 a1a3 a1a2 a1a1 a1a0    0
    //      0 a2a3 a2a2 a2a1 a2a0    0    0
    //   a3a3 a3a2 a3a1 a3a0    0    0    0
    //
    // Can be reduced to:
    //      0    0 a0a3 a0a2 a0a1    0   a0
    //      0 a1a3 a1a2    0   a1    0    0
    //   a2a3    0   a2    0    0    0    0
    //     a3    0    0    0    0    0    0
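    // This works because a_i*a_j + a_j*a_i = 2*a_i*a_j (the mirrored term is
    // folded in by shifting the surviving term up one column), and the
    // diagonal terms satisfy a_i*a_i = a_i (placed at column 2*i).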
    if (a == b)
      return lowerSqrAndArray(rewriter, a, op, width);

    // Use the number of result rows as a heuristic to guide the partial
    // product implementation.
    if (op.getNumResults() > 16 || forceBooth)
      return lowerBoothArray(rewriter, a, b, op, width);
    else
      return lowerAndArray(rewriter, a, b, op, width);
  }

private:
  static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
                                     Value b, PartialProductOp op,
                                     unsigned width) {

    Location loc = op.getLoc();
    // Keep a as a bitvector - multiply it by each bit of b.
    SmallVector<Value> bBits = extractBits(rewriter, b);

    auto rowWidth = width;
    auto knownBitsA = comb::computeKnownBits(a);
    if (!knownBitsA.Zero.isZero()) {
      if (knownBitsA.Zero.countLeadingOnes() > 1) {
        rowWidth -= knownBitsA.Zero.countLeadingOnes();
        a = rewriter.createOrFold<comb::ExtractOp>(loc, a, 0, rowWidth);
      }
    }

    SmallVector<Value> partialProducts;
    partialProducts.reserve(width);
    // AND Array Construction:
    // partialProducts[i] = ({b[i],..., b[i]} & a) << i
    assert(op.getNumResults() <= width &&
           "Cannot return more results than the operator width");

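    // For example, with width = 4, a = 0b0101 and b = 0b0011:
    //   partialProducts[0] = 0b0101   (b[0] = 1 selects a)
    //   partialProducts[1] = 0b1010   (b[1] = 1 selects a, shifted left by 1)
    //   partialProducts[2] = 0b0000   (b[2] = 0)
    //   partialProducts[3] = 0b0000   (b[3] = 0)
    // and the rows sum to 0b1111 = 15 = 5 * 3 (modulo 2^4).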
    for (unsigned i = 0; i < op.getNumResults(); ++i) {
      auto repl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, bBits[i], rowWidth);
      auto ppRow = rewriter.createOrFold<comb::AndOp>(loc, repl, a);
      if (rowWidth < width) {
        auto padding = width - rowWidth;
        auto zeroPad = hw::ConstantOp::create(rewriter, loc, APInt(padding, 0));
        ppRow = rewriter.createOrFold<comb::ConcatOp>(
            loc, ValueRange{zeroPad, ppRow}); // Pad to full width
      }

      if (i == 0) {
        partialProducts.push_back(ppRow);
        continue;
      }
      auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(i, 0));
      auto ppAlign =
          comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
      auto ppAlignTrunc = rewriter.createOrFold<comb::ExtractOp>(
          loc, ppAlign, 0, width); // Truncate back to width bits
      partialProducts.push_back(ppAlignTrunc);
    }

    rewriter.replaceOp(op, partialProducts);
    return success();
  }

  static LogicalResult lowerSqrAndArray(PatternRewriter &rewriter, Value a,
                                        PartialProductOp op, unsigned width) {

    Location loc = op.getLoc();
    SmallVector<Value> aBits = extractBits(rewriter, a);

    SmallVector<Value> partialProducts;
    partialProducts.reserve(width);
    // AND Array Construction - reducing to upper triangle:
    // partialProducts[i] = ({a[i],..., a[i]} & a) << i
    // optimised to: {a[i] & a[n-1], ..., a[i] & a[i+1], 0, a[i], 0, ..., 0}
    assert(op.getNumResults() <= width &&
           "Cannot return more results than the operator width");
    auto zeroFalse = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
    for (unsigned i = 0; i < op.getNumResults(); ++i) {
      SmallVector<Value> row;
      row.reserve(width);

      if (2 * i >= width) {
        // Pad the remaining rows with zeros
        auto zeroWidth = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));
        partialProducts.push_back(zeroWidth);
        continue;
      }

      if (i > 0) {
        auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(2 * i, 0));
        row.push_back(shiftBy);
      }
      row.push_back(aBits[i]);

      // Track width of constructed row
      unsigned rowWidth = 2 * i + 1;
      if (rowWidth < width) {
        row.push_back(zeroFalse);
        ++rowWidth;
      }

      for (unsigned j = i + 1; j < width; ++j) {
        // Stop when we reach the required width
        if (rowWidth == width)
          break;

        // Otherwise pad with zeros or partial product bits
        ++rowWidth;
        // Number of results indicates number of non-zero bits in input
        if (j >= op.getNumResults()) {
          row.push_back(zeroFalse);
          continue;
        }

        auto ppBit =
            rewriter.createOrFold<comb::AndOp>(loc, aBits[i], aBits[j]);
        row.push_back(ppBit);
      }
      std::reverse(row.begin(), row.end());
      auto ppRow = comb::ConcatOp::create(rewriter, loc, row);
      partialProducts.push_back(ppRow);
    }

    rewriter.replaceOp(op, partialProducts);
    return success();
  }

  static LogicalResult lowerBoothArray(PatternRewriter &rewriter, Value a,
                                       Value b, PartialProductOp op,
                                       unsigned width) {
    Location loc = op.getLoc();
    auto zeroFalse = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
    auto zeroWidth = hw::ConstantOp::create(rewriter, loc, APInt(width, 0));

    // Detect leading zeros in the multiplicand due to zero-extension
    // and truncate to reduce partial product bits:
    // {'0, a} * {'0, b}
    auto rowWidth = width;
    auto knownBitsA = comb::computeKnownBits(a);
    if (!knownBitsA.Zero.isZero()) {
      if (knownBitsA.Zero.countLeadingOnes() > 1) {
        // Retain one leading zero to represent 2*{1'b0, a} = {a, 1'b0}
        // {'0, a} -> {1'b0, a}
        rowWidth -= knownBitsA.Zero.countLeadingOnes() - 1;
        a = rewriter.createOrFold<comb::ExtractOp>(loc, a, 0, rowWidth);
      }
    }

    // TODO: Replace with a concatenation to aid longest path analysis.
    auto oneRowWidth =
        hw::ConstantOp::create(rewriter, loc, APInt(rowWidth, 1));
    // Booth encoding will select each row from {-2a, -1a, 0, 1a, 2a}.
    Value twoA = rewriter.createOrFold<comb::ShlOp>(loc, a, oneRowWidth);

    // Encode based on the bits of b.
    // TODO: Sort a and b based on non-zero bits to encode the smaller input.
    SmallVector<Value> bBits = extractBits(rewriter, b);

    // Identify zero bits of b to reduce the height of the partial product
    // array.
    auto knownBitsB = comb::computeKnownBits(b);
    if (!knownBitsB.Zero.isZero()) {
      for (unsigned i = 0; i < width; ++i)
        if (knownBitsB.Zero[i])
          bBits[i] = zeroFalse;
    }

    SmallVector<Value> partialProducts;
    partialProducts.reserve(width);

    // Booth encoding halves the array height by grouping three bits at a time:
    // partialProducts[i] = a * (-2*b[2*i+1] + b[2*i] + b[2*i-1]) << 2*i
    // encNeg \approx (-2*b[2*i+1] + b[2*i] + b[2*i-1]) <= 0
    // encOne  = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 1
    // encTwo  = (-2*b[2*i+1] + b[2*i] + b[2*i-1]) == +/- 2
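    // Radix-4 Booth digit for the bit triple (b[2*i+1], b[2*i], b[2*i-1]):
    //   000 -> 0, 001 -> +1, 010 -> +1, 011 -> +2,
    //   100 -> -2, 101 -> -1, 110 -> -1, 111 -> 0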
    Value encNegPrev;

    // For even width - additional row contains the final sign correction
    for (unsigned i = 0; i <= width; i += 2) {
      // Get Booth bits: b[i+1], b[i], b[i-1] (b[-1] = 0)
      Value bim1 = (i == 0) ? zeroFalse : bBits[i - 1];
      Value bi = (i < width) ? bBits[i] : zeroFalse;
      Value bip1 = (i + 1 < width) ? bBits[i + 1] : zeroFalse;

      // Is the encoding zero or negative (an approximation)
      Value encNeg = bip1;
      // Is the encoding one = b[i] xor b[i-1]
      Value encOne = rewriter.createOrFold<comb::XorOp>(loc, bi, bim1, true);
      // Is the encoding two = (bip1 & ~bi & ~bim1) | (~bip1 & bi & bim1)
      Value constOne = hw::ConstantOp::create(rewriter, loc, APInt(1, 1));
      Value biInv = rewriter.createOrFold<comb::XorOp>(loc, bi, constOne, true);
      Value bip1Inv =
          rewriter.createOrFold<comb::XorOp>(loc, bip1, constOne, true);
      Value bim1Inv =
          rewriter.createOrFold<comb::XorOp>(loc, bim1, constOne, true);

      Value andLeft = rewriter.createOrFold<comb::AndOp>(
          loc, ValueRange{bip1Inv, bi, bim1}, true);
      Value andRight = rewriter.createOrFold<comb::AndOp>(
          loc, ValueRange{bip1, biInv, bim1Inv}, true);
      Value encTwo =
          rewriter.createOrFold<comb::OrOp>(loc, andLeft, andRight, true);

      Value encNegRepl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, rowWidth);
      Value encOneRepl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, encOne, rowWidth);
      Value encTwoRepl =
          rewriter.createOrFold<comb::ReplicateOp>(loc, encTwo, rowWidth);

      // Select between 2*a or 1*a or 0*a
      Value selTwoA = rewriter.createOrFold<comb::AndOp>(loc, encTwoRepl, twoA);
      Value selOneA = rewriter.createOrFold<comb::AndOp>(loc, encOneRepl, a);
      Value magA =
          rewriter.createOrFold<comb::OrOp>(loc, selTwoA, selOneA, true);

      // Conditionally invert the row
      Value ppRow =
          rewriter.createOrFold<comb::XorOp>(loc, magA, encNegRepl, true);
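      // A negative digit is realised here as a bitwise complement only; the
      // missing +1 of the two's complement is added later as the
      // sign-correction bit (encNegPrev) at this row's LSB weight.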

      // Sign-extension Optimisation:
      // See Section 7.2.2 of "Application Specific Arithmetic" by Dinechin &
      // Kumm. Handle sign-extension and padding to full width, where
      // s = encNeg (the sign bit):
      //   {s, s, s, s, s, pp} = {1, 1, 1, 1, 1, pp}
      //                       + {0, 0, 0, 0,!s, '0}
      // Applying this to every row creates an upper-triangle of 1s that
      // can be optimised away since they will not affect the final sum:
      //   {!s3,  0, !s2,  0, !s1,  0}
      //   {  1,  1,   1,  1,   1, p1}
      //   {  1,  1,   1, p2         }
      //   {  1, p3                  }
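      // The identity holds modulo the result width: if s = 1 both sides read
      // {1, 1, 1, 1, 1, pp}; if s = 0 the !s bit carries through the string
      // of 1s and the carry-out beyond the result width is discarded, leaving
      // {0, 0, 0, 0, 0, pp}.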
      if (rowWidth < width) {
        auto padding = width - rowWidth;
        auto encNegInv = bip1Inv;

        // Sign-extension trick not worth it for padding < 3
        if (padding < 3) {
          Value encNegPad =
              rewriter.createOrFold<comb::ReplicateOp>(loc, encNeg, padding);
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{encNegPad, ppRow}); // Pad to full width
        } else if (i == 0) {
          // First row = {!encNeg, encNeg, encNeg, ppRow}
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{encNegInv, encNeg, encNeg, ppRow});
        } else {
          // Remaining rows = {1, !encNeg, ppRow}
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{constOne, encNegInv, ppRow});
        }

        // Zero pad to full width
        auto rowWidth = ppRow.getType().getIntOrFloatBitWidth();
        if (rowWidth < width) {
          auto zeroPad =
              hw::ConstantOp::create(rewriter, loc, APInt(width - rowWidth, 0));
          ppRow = rewriter.createOrFold<comb::ConcatOp>(
              loc, ValueRange{zeroPad, ppRow});
        }
      }

      // No sign-correction in the first row
      if (i == 0) {
        partialProducts.push_back(ppRow);
        encNegPrev = encNeg;
        continue;
      }

      if (i == 2) {
        Value withSignCorrection = rewriter.createOrFold<comb::ConcatOp>(
            loc, ValueRange{ppRow, zeroFalse, encNegPrev});
        Value ppAlign = rewriter.createOrFold<comb::ExtractOp>(
            loc, withSignCorrection, 0, width);
        partialProducts.push_back(ppAlign);
        encNegPrev = encNeg;
        continue;
      }

      // Insert a sign-correction from the previous row:
      // {ppRow, 0, encNegPrev} << 2*(i-1)
      Value shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(i - 2, 0));
      Value withSignCorrection = rewriter.createOrFold<comb::ConcatOp>(
          loc, ValueRange{ppRow, zeroFalse, encNegPrev, shiftBy});
      Value ppAlign = rewriter.createOrFold<comb::ExtractOp>(
          loc, withSignCorrection, 0, width);

      partialProducts.push_back(ppAlign);
      encNegPrev = encNeg;

      if (partialProducts.size() == op.getNumResults())
        break;
    }

    // Append zero rows to match the required number of results
    while (partialProducts.size() < op.getNumResults())
      partialProducts.push_back(zeroWidth);

    assert(partialProducts.size() == op.getNumResults() &&
           "Expected number of booth partial products to match results");

    rewriter.replaceOp(op, partialProducts);
    return success();
  }
};

struct DatapathPosPartialProductOpConversion
    : OpRewritePattern<PosPartialProductOp> {
  using OpRewritePattern<PosPartialProductOp>::OpRewritePattern;

  DatapathPosPartialProductOpConversion(MLIRContext *context, bool forceBooth)
      : OpRewritePattern<PosPartialProductOp>(context),
        forceBooth(forceBooth) {}

  const bool forceBooth;

  LogicalResult matchAndRewrite(PosPartialProductOp op,
                                PatternRewriter &rewriter) const override {

    Value a = op.getAddend0();
    Value b = op.getAddend1();
    Value c = op.getMultiplicand();
    unsigned width = a.getType().getIntOrFloatBitWidth();

    // Skip a zero width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(0), 0);
      return success();
    }

    // TODO: Implement Booth lowering.
    return lowerAndArray(rewriter, a, b, c, op, width);
  }

private:
  static LogicalResult lowerAndArray(PatternRewriter &rewriter, Value a,
                                     Value b, Value c, PosPartialProductOp op,
                                     unsigned width) {

    Location loc = op.getLoc();
    // Encode (a+b) by implementing a half-adder, then note the following
    // fact: carry[i] & save[i] == false.
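    // Bitwise, a + b == (a ^ b) + ((a & b) << 1), so
    // (a + b) * c == save * c + carry * (c << 1). Since carry[i] and save[i]
    // are never both set, each row below can OR its two selections together
    // instead of adding them.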
    auto carry = rewriter.createOrFold<comb::AndOp>(loc, a, b);
    auto save = rewriter.createOrFold<comb::XorOp>(loc, a, b);

    SmallVector<Value> carryBits = extractBits(rewriter, carry);
    SmallVector<Value> saveBits = extractBits(rewriter, save);

    // Reduce c width based on leading zeros
    auto rowWidth = width;
    auto knownBitsC = comb::computeKnownBits(c);
    if (!knownBitsC.Zero.isZero()) {
      if (knownBitsC.Zero.countLeadingOnes() > 1) {
        // Retain one leading zero to represent 2*{1'b0, c} = {c, 1'b0}
        // {'0, c} -> {1'b0, c}
        rowWidth -= knownBitsC.Zero.countLeadingOnes() - 1;
        c = rewriter.createOrFold<comb::ExtractOp>(loc, c, 0, rowWidth);
      }
    }

    // Compute 2*c for use in array construction
    Value zero = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
    Value twoCWider =
        comb::ConcatOp::create(rewriter, loc, ValueRange{c, zero});
    Value twoC = comb::ExtractOp::create(rewriter, loc, twoCWider, 0, rowWidth);

    // AND Array Construction:
    // pp[i] = ( (carry[i] * (c<<1)) | (save[i] * c) ) << i
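    // For example, with a = b = 0b0001 and c = 0b0011:
    //   carry = 0b0001, save = 0b0000
    //   pp[0] = carry[0] * (c << 1) = 0b0110 = 6 = (a + b) * c.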
    SmallVector<Value> partialProducts;
    partialProducts.reserve(width);

    assert(op.getNumResults() <= width &&
           "Cannot return more results than the operator width");

    for (unsigned i = 0; i < op.getNumResults(); ++i) {
      auto replSave =
          rewriter.createOrFold<comb::ReplicateOp>(loc, saveBits[i], rowWidth);
      auto replCarry =
          rewriter.createOrFold<comb::ReplicateOp>(loc, carryBits[i], rowWidth);

      auto ppRowSave = rewriter.createOrFold<comb::AndOp>(loc, replSave, c);
      auto ppRowCarry =
          rewriter.createOrFold<comb::AndOp>(loc, replCarry, twoC);
      auto ppRow =
          rewriter.createOrFold<comb::OrOp>(loc, ppRowSave, ppRowCarry);
      auto ppAlign = ppRow;
      if (i > 0) {
        auto shiftBy = hw::ConstantOp::create(rewriter, loc, APInt(i, 0));
        ppAlign =
            comb::ConcatOp::create(rewriter, loc, ValueRange{ppRow, shiftBy});
      }

      // May need to truncate the shifted value
      if (rowWidth + i > width) {
        auto ppAlignTrunc =
            rewriter.createOrFold<comb::ExtractOp>(loc, ppAlign, 0, width);
        partialProducts.push_back(ppAlignTrunc);
        continue;
      }
      // May need to zero pad to the appropriate width
      if (rowWidth + i < width) {
        auto padding = width - rowWidth - i;
        Value zeroPad =
            hw::ConstantOp::create(rewriter, loc, APInt(padding, 0));
        partialProducts.push_back(rewriter.createOrFold<comb::ConcatOp>(
            loc, ValueRange{zeroPad, ppAlign})); // Pad to full width
        continue;
      }

      partialProducts.push_back(ppAlign);
    }

    rewriter.replaceOp(op, partialProducts);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Convert Datapath to Comb pass
//===----------------------------------------------------------------------===//

namespace {
struct ConvertDatapathToCombPass
    : public impl::ConvertDatapathToCombBase<ConvertDatapathToCombPass> {
  void runOnOperation() override;
  using ConvertDatapathToCombBase<
      ConvertDatapathToCombPass>::ConvertDatapathToCombBase;
};
} // namespace
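// A sketch of a typical invocation from circt-opt (the pass flag and option
// spellings here are assumed from the tablegen pass/option names and should
// be checked against the generated pass documentation):
//   circt-opt --convert-datapath-to-comb input.mlir
//   circt-opt --convert-datapath-to-comb="lower-compress-to-add=true" in.mlir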

static LogicalResult applyPatternsGreedilyWithTimingInfo(
    Operation *op, RewritePatternSet &&patterns,
    synth::IncrementalLongestPathAnalysis *analysis) {
  // TODO: Topologically sort the operations in the module to ensure that all
  // dependencies are processed before their users.
  mlir::GreedyRewriteConfig config;
  // Set the listener to update timing information.
  // HACK: Set max iterations to 2 so that the patterns are effectively
  // one-shot, ensuring the targeted datapath operations are replaced.
  config.setMaxIterations(2).setListener(analysis).setUseTopDownTraversal(true);

  // Apply the patterns greedily
  if (failed(mlir::applyPatternsGreedily(op, std::move(patterns), config)))
    return failure();

  return success();
}

void ConvertDatapathToCombPass::runOnOperation() {
  RewritePatternSet patterns(&getContext());

  patterns.add<DatapathPartialProductOpConversion,
               DatapathPosPartialProductOpConversion>(patterns.getContext(),
                                                      forceBooth);
  synth::IncrementalLongestPathAnalysis *analysis = nullptr;
  if (timingAware)
    analysis = &getAnalysis<synth::IncrementalLongestPathAnalysis>();
  if (lowerCompressToAdd)
    // Lower compressors to simple add operations for downstream optimisations.
    patterns.add<DatapathCompressOpAddConversion>(patterns.getContext());
  else
    // Lower compressors to a complete gate-level implementation.
    patterns.add<DatapathCompressOpConversion>(patterns.getContext(), analysis);

  if (failed(applyPatternsGreedilyWithTimingInfo(
          getOperation(), std::move(patterns), analysis)))
    return signalPassFailure();

  // Verify that all Datapath operations have been successfully converted.
  // Walk the operation and check for any remaining Datapath dialect
  // operations.
  auto result = getOperation()->walk([&](Operation *op) {
    if (llvm::isa_and_nonnull<datapath::DatapathDialect>(op->getDialect())) {
      op->emitError("Datapath operation not converted: ") << *op;
      return WalkResult::interrupt();
    }
    return WalkResult::advance();
  });
  if (result.wasInterrupted())
    return signalPassFailure();
}