CIRCT 23.0.0git
Loading...
Searching...
No Matches
DatapathFolds.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/Support/Casting.h"
16#include "llvm/Support/KnownBits.h"
17#include <algorithm>
18
19using namespace mlir;
20using namespace circt;
21using namespace datapath;
22using namespace matchers;
23
24//===----------------------------------------------------------------------===//
25// Utility Functions
26//===----------------------------------------------------------------------===//
27static FailureOr<size_t> calculateNonZeroBits(Value operand,
28 size_t numResults) {
29 // If the extracted bits are all known, then return the result.
30 auto knownBits = comb::computeKnownBits(operand);
31 if (knownBits.isUnknown())
32 return failure(); // Skip if we don't know anything about the bits
33
34 size_t nonZeroBits = operand.getType().getIntOrFloatBitWidth() -
35 knownBits.Zero.countLeadingOnes();
36
37 // If all bits non-zero we will not reduce the number of results
38 if (nonZeroBits == numResults)
39 return failure();
40
41 return nonZeroBits;
42}
43
44// This pattern commonly arrises when inverting zext: ~zext(x) = {1,...1, ~x}
45// Check if the operand is {ones, base} and return the unextended operand:
46static FailureOr<Value> isOneExt(Value operand) {
47 // Check if operand is a concat operation
48 auto concatOp = operand.getDefiningOp<comb::ConcatOp>();
49 if (!concatOp)
50 return failure();
51
52 auto operands = concatOp.getOperands();
53 // ConcatOp must have exactly 2 operands
54 if (operands.size() != 2)
55 return failure();
56
57 APInt value;
58 if (matchPattern(operands[0], m_ConstantInt(&value)) && value.isAllOnes())
59 // Return the base unextended value
60 return success(operands[1]);
61
62 return failure();
63}
64
65// zext(input<<trailingZeros) to targetWidth
66static Value zeroPad(PatternRewriter &rewriter, Location loc, Value input,
67 size_t targetWidth, size_t trailingZeros) {
68 assert(trailingZeros > 0 && "zeroPad called with zero trailing zeros");
69 auto trailingZerosValue =
70 hw::ConstantOp::create(rewriter, loc, APInt::getZero(trailingZeros));
71 auto padTrailing = comb::ConcatOp::create(
72 rewriter, loc, ValueRange{input, trailingZerosValue});
73 return comb::createZExt(rewriter, loc, padTrailing, targetWidth);
74}
75
76//===----------------------------------------------------------------------===//
77// Compress Operation
78//===----------------------------------------------------------------------===//
79// Check that all compressor results are included in this list of operands
80// If not we must take care as manipulating compressor results independently
81// could easily introduce a non-equivalent representation.
82static bool areAllCompressorResultsSummed(ValueRange compressResults,
83 ValueRange operands) {
84 for (auto result : compressResults) {
85 if (!llvm::is_contained(operands, result))
86 return false;
87 }
88 return true;
89}
90
92 : public OpRewritePattern<datapath::CompressOp> {
93 using OpRewritePattern::OpRewritePattern;
94
  // compress(compress(a,b,c), add(e,f)) -> compress(a,b,c,e,f)
  //
  // Flattens nested compress/add trees into a single wider compressor so the
  // reduction array can be optimized as one unit.
  LogicalResult matchAndRewrite(datapath::CompressOp compOp,
                                PatternRewriter &rewriter) const override {
    auto operands = compOp.getOperands();
    // Results of compressors already folded in - avoids revisiting the same
    // defining op once per result.
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;

    for (Value operand : operands) {

      // Skip if already processed this compressor
      if (processedCompressorResults.contains(operand))
        continue;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand to maintain sharing.
      if (!operand.hasOneUse()) {
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {

        // Check that all results of the compressor are summed in this add;
        // folding only some results would change the computed sum.
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once as multiple operands will point
        // to the same defining operation
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      // An add feeding the compressor is absorbed operand-by-operand.
      if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
        llvm::append_range(newCompressOperands, addOp.getOperands());
        continue;
      }

      // Regular operand - just add it to our list
      newCompressOperands.push_back(operand);
    }

    // If unable to collect more operands then this pattern doesn't apply
    if (newCompressOperands.size() <= compOp.getNumOperands())
      return failure();

    // Create a new CompressOp with all collected operands, keeping the
    // original number of results.
    rewriter.replaceOpWithNewOp<datapath::CompressOp>(
        compOp, newCompressOperands, compOp.getNumResults());
    return success();
  }
148};
149
struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  // add(compress(a,b,c),d) -> add(compress(a,b,c,d))
  // FIXME: This should be implemented as a canonicalization pattern for
  // compress op. Currently `hasDatapathOperand` flag prevents introducing
  // datapath operations from comb operations.
  LogicalResult matchAndRewrite(comb::AddOp addOp,
                                PatternRewriter &rewriter) const override {
    // comb.add canonicalization patterns handle folding add operations
    if (addOp.getNumOperands() <= 2)
      return failure();

    // Get operands of the AddOp
    auto operands = addOp.getOperands();
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;
    // Only construct compressor if can form a larger compressor than what
    // is currently an input of this add. Also check that there is at least
    // one datapath operand.
    bool shouldFold = false, hasDatapathOperand = false;

    for (Value operand : operands) {

      // Skip if already processed this compressor
      if (processedCompressorResults.contains(operand))
        continue;

      // Record whether any operand originates from the datapath dialect.
      if (auto *op = operand.getDefiningOp())
        if (isa_and_nonnull<datapath::DatapathDialect>(op->getDialect()))
          hasDatapathOperand = true;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand.
      if (!operand.hasOneUse()) {
        // Folding pays off once at least one other operand precedes this one.
        shouldFold |= !newCompressOperands.empty();
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {

        // Check that all results of the compressor are summed in this add
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        // If we've already added one operand it should be folded
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      // A nested add is absorbed operand-by-operand into the compressor.
      if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, addOp.getOperands());
        continue;
      }

      // Regular operand - just add it to our list
      shouldFold |= !newCompressOperands.empty();
      newCompressOperands.push_back(operand);
    }

    // Only fold if we have constructed a larger compressor than what was
    // already there
    if (!shouldFold || !hasDatapathOperand)
      return failure();

    // Create a new CompressOp with all collected operands, compressing the
    // whole array down to two rows.
    auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
                                                      newCompressOperands, 2);

    // Replace the original AddOp with a new add(compress(inputs))
    // NOTE(review): the trailing `true` appears to select the two-state add
    // builder - confirm against the comb::AddOp builder signature.
    rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
                                             true);
    return success();
  }
};
232
233// compress(..., sext(x),...) ->
234// compress(..., zext({~x[p-1], x[p-2:0]}), (-1) << (width(x)-1), ...)
235// Justification:
236// sext(x) = {x[p-1], x[p-1], ..., x[p-1], x[p-2], ..., x[0]} =
237// = { 0, 0, ..., ~x[p-1], x[p-2], ..., x[0]} +
238// { 1, 1, ..., 1, 0, ..., 0} =
239// = zext({~x[p-1], x[p-2], ..., x[0]}) + ((-1) << (width(x)-1))
240//
241// Note that we are adding arguments to the compressor, but we are reducing the
242// number of unknown bits in the compressor array
243struct SextCompress : public OpRewritePattern<CompressOp> {
244 using OpRewritePattern::OpRewritePattern;
245
246 LogicalResult matchAndRewrite(CompressOp op,
247 PatternRewriter &rewriter) const override {
248 auto inputs = op.getInputs();
249 auto opSize = inputs[0].getType().getIntOrFloatBitWidth();
250 auto size = inputs.size();
251
252 APInt value;
253 SmallVector<Value> newInputs;
254 for (auto input : inputs) {
255 Value sextInput;
256 // Check for sext of the inverted value
257 if (!matchPattern(input, comb::m_Sext(m_Any(&sextInput)))) {
258 newInputs.push_back(input);
259 continue;
260 }
261
262 auto baseWidth = sextInput.getType().getIntOrFloatBitWidth();
263 // Need a separate sign-bit that gets extended by at least two bits to
264 // be beneficial
265 if (baseWidth <= 1 || (opSize - baseWidth) <= 1) {
266 newInputs.push_back(input);
267 continue;
268 }
269
270 // x[p-2:0]
271 auto base = comb::ExtractOp::create(rewriter, op.getLoc(), sextInput, 0,
272 baseWidth - 1);
273 // x[p-1]
274 auto signBit = comb::ExtractOp::create(rewriter, op.getLoc(), sextInput,
275 baseWidth - 1, 1);
276 auto invSign =
277 comb::createOrFoldNot(op.getLoc(), signBit, rewriter, true);
278 // {~x[p-1], x[p-2:0]}
279 auto newOp = comb::ConcatOp::create(rewriter, op.getLoc(),
280 ValueRange{invSign, base});
281 auto newOpZExt = comb::createZExt(rewriter, op.getLoc(), newOp, opSize);
282
283 newInputs.push_back(newOpZExt);
284
285 // (-1) << (width(x)-1)
286 auto ones = APInt::getAllOnes(opSize);
287 auto correction = hw::ConstantOp::create(rewriter, op.getLoc(),
288 ones << (baseWidth - 1));
289
290 newInputs.push_back(correction);
291 }
292
293 // If no sext inputs have not updated any arguments
294 if (newInputs.size() == size)
295 return failure();
296
297 auto newCompress = datapath::CompressOp::create(
298 rewriter, op.getLoc(), newInputs, op.getNumResults());
299 rewriter.replaceOp(op, newCompress.getResults());
300 return success();
301 }
302};
303
304// compress(..., oneExt(x),...) ->
305// compress(..., zext(x), (-1) << (width(x)-1), ...)
306// Justification:
307// {1, 1, ..., 1, x}
308// = zext(x) + ((-1) << (width(x)-1))
309//
310// Note that we are adding arguments to the compressor, but these can be
311// constant folded should other constants arise
312//
313// A pattern encountered when we convert subtraction to addition:
314// zext(a)-zext(b) = zext(a) + ~zext(b) + 1
315// = zext(a) + oneExt(~b) + 1
316// TODO: use knownBits to extract all constant ones
317struct OnesExtCompress : public OpRewritePattern<CompressOp> {
318 using OpRewritePattern::OpRewritePattern;
319
320 LogicalResult matchAndRewrite(CompressOp op,
321 PatternRewriter &rewriter) const override {
322 auto inputs = op.getInputs();
323 auto opType = inputs[0].getType();
324 auto opSize = opType.getIntOrFloatBitWidth();
325
326 SmallVector<Value> newInputs;
327 for (auto input : inputs) {
328 // Check for replication of ones leading
329 auto baseInput = isOneExt(input);
330 if (failed(baseInput)) {
331 newInputs.push_back(input);
332 continue;
333 }
334
335 // Separate {ones, x} -> zext(x) + (ones << baseWidth)
336 auto newOp = comb::createZExt(rewriter, op.getLoc(), *baseInput, opSize);
337 newInputs.push_back(newOp);
338
339 APInt ones = APInt::getAllOnes(opSize);
340 auto baseWidth = baseInput->getType().getIntOrFloatBitWidth();
341 auto correction =
342 hw::ConstantOp::create(rewriter, op.getLoc(), ones << baseWidth);
343 newInputs.push_back(correction);
344 }
345
346 if (newInputs.size() == inputs.size())
347 return failure();
348
349 auto newCompress = datapath::CompressOp::create(
350 rewriter, op.getLoc(), newInputs, op.getNumResults());
351 rewriter.replaceOp(op, newCompress.getResults());
352 return success();
353 }
354};
355
356struct ConstantFoldCompress : public OpRewritePattern<CompressOp> {
357 using OpRewritePattern::OpRewritePattern;
358
359 LogicalResult matchAndRewrite(CompressOp op,
360 PatternRewriter &rewriter) const override {
361 auto inputs = op.getInputs();
362 auto size = inputs.size();
363
364 APInt value;
365
366 // compress(..., 0) -> compress(...) -- identity
367 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
368
369 // If only reducing by one row and contains zero - pass through operands
370 if (size - 1 == op.getNumResults()) {
371 rewriter.replaceOp(op, inputs.drop_back());
372 return success();
373 }
374
375 // Default create a compressor with fewer arguments
376 rewriter.replaceOpWithNewOp<CompressOp>(op, inputs.drop_back(),
377 op.getNumResults());
378 return success();
379 }
380
381 APInt value1, value2;
382 // compress(...c1, c2) -> compress(..., c1+c2)
383 assert(size >= 3 &&
384 "compress op has 3 or more operands ensured by a verifier");
385 if (matchPattern(inputs.back(), m_ConstantInt(&value1)) &&
386 matchPattern(inputs[size - 2], m_ConstantInt(&value2))) {
387
388 SmallVector<Value> newInputs(inputs.drop_back(2));
389 auto summedValue = value1 + value2;
390 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), summedValue);
391 newInputs.push_back(constOp);
392 // If reducing by one row and constant folding - pass through operands
393 if (size - 1 == op.getNumResults()) {
394 rewriter.replaceOp(op, newInputs);
395 return success();
396 }
397
398 // Default create a compressor with fewer arguments
399 rewriter.replaceOpWithNewOp<CompressOp>(op, newInputs,
400 op.getNumResults());
401 return success();
402 }
403
404 return failure();
405 }
406};
407
408void CompressOp::getCanonicalizationPatterns(RewritePatternSet &results,
409 MLIRContext *context) {
412}
413
414//===----------------------------------------------------------------------===//
415// Partial Product Operation
416//===----------------------------------------------------------------------===//
417struct ReduceNumPartialProducts : public OpRewritePattern<PartialProductOp> {
418 using OpRewritePattern::OpRewritePattern;
419
420 // pp(concat(0,a), concat(0,b)) -> reduced number of results
421 LogicalResult matchAndRewrite(PartialProductOp op,
422 PatternRewriter &rewriter) const override {
423 auto operands = op.getOperands();
424 unsigned inputWidth = operands[0].getType().getIntOrFloatBitWidth();
425
426 // TODO: implement a constant multiplication for the PartialProductOp
427
428 auto op0NonZeroBits = calculateNonZeroBits(operands[0], op.getNumResults());
429 auto op1NonZeroBits = calculateNonZeroBits(operands[1], op.getNumResults());
430
431 if (failed(op0NonZeroBits) || failed(op1NonZeroBits))
432 return failure();
433
434 // Need the +1 for the carry-out
435 size_t maxNonZeroBits = std::max(*op0NonZeroBits, *op1NonZeroBits);
436
437 auto newPP = datapath::PartialProductOp::create(
438 rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);
439
440 auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
441 APInt::getZero(inputWidth));
442
443 // Collect newPP results and pad with zeros if needed
444 SmallVector<Value> newResults(newPP.getResults().begin(),
445 newPP.getResults().end());
446
447 newResults.append(op.getNumResults() - newResults.size(), zero);
448
449 rewriter.replaceOp(op, newResults);
450 return success();
451 }
452};
453
struct SignedPartialProducts : public OpRewritePattern<PartialProductOp> {
  using OpRewritePattern::OpRewritePattern;

  // Based on the classical Baugh-Wooley algorithm for signed multiplication.
  // Paper: A Two's Complement Parallel Array Multiplication Algorithm
  //
  // Consider a p-bit by q-bit signed multiplier - producing a (p+q)-bit result:
  //   a_sign = a[p-1], a_mag = a[p-2:0],
  //   b_sign = b[q-1], b_mag = b[q-2:0]
  //   sext(a) * sext(b) = a_mag * b_mag              [unsigned product]
  //                     - 2^(p-1) * a_sign * b_mag   [sign correction]
  //                     - 2^(q-1) * b_sign * a_mag   [sign correction]
  //                     + 2^(p+q-2) * a_sign * b_sign [sign * sign]
  //
  // We implement optimizations to turn the subtractions into bitwise
  // negations with constant corrections that can be folded together.
  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {
    auto inputWidth = op.getOperand(0).getType().getIntOrFloatBitWidth();
    // Both operands must be sign-extensions for Baugh-Wooley to apply.
    Value lhs;
    Value rhs;
    if (!matchPattern(op.getOperand(0), comb::m_Sext(m_Any(&lhs))) ||
        !matchPattern(op.getOperand(1), comb::m_Sext(m_Any(&rhs))))
      return failure();

    size_t lhsWidth = lhs.getType().getIntOrFloatBitWidth();
    size_t rhsWidth = rhs.getType().getIntOrFloatBitWidth();
    // Subtract 1 as will handle sign-bit separately
    size_t maxRows = std::max(lhsWidth, rhsWidth) - 1;

    // TODO: add support for different width inputs
    // Need to have a sign bit in both inputs
    if (lhsWidth != rhsWidth || lhsWidth <= 1 || rhsWidth <= 1)
      return failure();

    // No further reduction possible
    if (maxRows >= op.getNumResults())
      return failure();

    // Pull off the sign bits and the magnitude (base) slices of each input.
    auto lhsBaseWidth = lhsWidth - 1;
    auto rhsBaseWidth = rhsWidth - 1;
    auto lhsSignBit =
        comb::ExtractOp::create(rewriter, op.getLoc(), lhs, lhsBaseWidth, 1);
    auto rhsSignBit =
        comb::ExtractOp::create(rewriter, op.getLoc(), rhs, rhsBaseWidth, 1);
    auto lhsBase =
        comb::ExtractOp::create(rewriter, op.getLoc(), lhs, 0, lhsBaseWidth);
    auto rhsBase =
        comb::ExtractOp::create(rewriter, op.getLoc(), rhs, 0, rhsBaseWidth);

    // Create the unsigned partial product of the unextended inputs
    auto lhsBaseZext =
        comb::createZExt(rewriter, op.getLoc(), lhsBase, inputWidth);
    auto rhsBaseZext =
        comb::createZExt(rewriter, op.getLoc(), rhsBase, inputWidth);
    auto newPP = datapath::PartialProductOp::create(
        rewriter, op.getLoc(), ValueRange{lhsBaseZext, rhsBaseZext}, maxRows);

    // Optimization (similar for second sign correction), ext to (p+q)-bits:
    // -2^(p-1)*sign(lhs)*rhsBase = ~((sign(lhs) * rhsBase) << (p-1)) + 1
    //                            = (~(replicate(sign(lhs)) & rhsBase)) << (p-1)
    //                              + (-1) << (p+q-2)     [msb correction]
    //                              + (1<<(p-1)) - 1 + 1  [lsb correction]

    // Create ~(replicate(sign(lhs)) & rhsBase)
    auto lhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
                                                      lhsSignBit, rhsBaseWidth);
    auto lhsSignAndRhs =
        comb::AndOp::create(rewriter, op.getLoc(), lhsSignReplicate, rhsBase);
    auto lhsSignCorrection =
        comb::createOrFoldNot(op.getLoc(), lhsSignAndRhs, rewriter, true);

    // zext({lhsSignCorrection, lhsBaseWidth{1'b0}})
    auto alignLhsSignCorrection = zeroPad(
        rewriter, op.getLoc(), lhsSignCorrection, inputWidth, lhsBaseWidth);

    // Create ~(replicate(sign(rhs)) & lhsBase)
    auto rhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
                                                      rhsSignBit, lhsBaseWidth);
    auto rhsSignAndLhs =
        comb::AndOp::create(rewriter, op.getLoc(), rhsSignReplicate, lhsBase);
    auto rhsSignCorrection =
        comb::createOrFoldNot(op.getLoc(), rhsSignAndLhs, rewriter, true);

    // zext({rhsSignCorrection, rhsBaseWidth{1'b0}})
    auto alignRhsSignCorrection = zeroPad(
        rewriter, op.getLoc(), rhsSignCorrection, inputWidth, rhsBaseWidth);

    // 2^(p+q-2) * sign(lhs) * sign(rhs) = (sign(lhs) & sign(rhs)) << (p+q-2)
    // Create sign(lhs) & sign(rhs)
    auto signAnd =
        comb::AndOp::create(rewriter, op.getLoc(), lhsSignBit, rhsSignBit);
    // zext({sign(lhs) & sign(rhs), lhsBaseWidth+rhsBaseWidth{1'b0}})
    auto alignSignAndZext = zeroPad(rewriter, op.getLoc(), signAnd, inputWidth,
                                    lhsBaseWidth + rhsBaseWidth);

    // Gather constant corrections together (once for each sign correction):
    //   (-1) << (p+q-2) + (1<<(p-1)) - 1 + 1
    // The msb correction appears twice (once per sign correction), hence the
    // factor of 2 below; wrap-around arithmetic on APInt makes this exact.
    auto ones = APInt::getAllOnes(inputWidth);
    auto lowerLhs = APInt::getOneBitSet(inputWidth, lhsBaseWidth);
    auto lowerRhs = APInt::getOneBitSet(inputWidth, rhsBaseWidth);
    auto msbCorrection = ones << (lhsBaseWidth + rhsBaseWidth);
    auto correction = lowerLhs + lowerRhs + 2 * msbCorrection;

    auto constantCorrection =
        hw::ConstantOp::create(rewriter, op.getLoc(), correction);

    auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                       APInt::getZero(inputWidth));
    // Collect newPP results and pad with zeros if needed
    SmallVector<Value> newResults(newPP.getResults().begin(),
                                  newPP.getResults().end());

    // ~(replicate(sign(lhs)) & rhsBase) * 2^(p-1)
    newResults.push_back(alignLhsSignCorrection);
    // ~(replicate(sign(rhs)) & lhsBase) * 2^(q-1)
    newResults.push_back(alignRhsSignCorrection);
    // sign(lhs)*sign(rhs) * 2^(p+q-2)
    newResults.push_back(alignSignAndZext);
    // Constant correction
    newResults.push_back(constantCorrection);
    // Zero pad if necessary
    newResults.append(op.getNumResults() - newResults.size(), zero);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};
583
584struct PosPartialProducts : public OpRewritePattern<PartialProductOp> {
585 using OpRewritePattern::OpRewritePattern;
586
587 // pp(add(a,b),c) -> pos_pp(a,b,c)
588 LogicalResult matchAndRewrite(PartialProductOp op,
589 PatternRewriter &rewriter) const override {
590 auto width = op.getType(0).getIntOrFloatBitWidth();
591
592 assert(op.getNumOperands() == 2);
593
594 // Detect if any input is an AddOp
595 auto lhsAdder = op.getOperand(0).getDefiningOp<comb::AddOp>();
596 auto rhsAdder = op.getOperand(1).getDefiningOp<comb::AddOp>();
597 if ((lhsAdder && rhsAdder) || !(lhsAdder || rhsAdder))
598 return failure();
599 auto addInput = lhsAdder ? lhsAdder : rhsAdder;
600 auto otherInput = lhsAdder ? op.getOperand(1) : op.getOperand(0);
601
602 if (addInput->getNumOperands() != 2)
603 return failure();
604
605 Value addend0 = addInput->getOperand(0);
606 Value addend1 = addInput->getOperand(1);
607
608 rewriter.replaceOpWithNewOp<PosPartialProductOp>(
609 op, ValueRange{addend0, addend1, otherInput}, width);
610 return success();
611 }
612};
613
614void PartialProductOp::getCanonicalizationPatterns(RewritePatternSet &results,
615 MLIRContext *context) {
616 results
618 context);
619}
620
621//===----------------------------------------------------------------------===//
622// Pos Partial Product Operation
623//===----------------------------------------------------------------------===//
625 : public OpRewritePattern<PosPartialProductOp> {
626 using OpRewritePattern::OpRewritePattern;
627
  // pos_pp(concat(0,a), concat(0,b), c) -> reduced number of results
  //
  // When known-bits analysis proves leading zeros on both addends, fewer
  // result rows are needed; the dropped rows become zero constants.
  LogicalResult matchAndRewrite(PosPartialProductOp op,
                                PatternRewriter &rewriter) const override {
    unsigned inputWidth = op.getAddend0().getType().getIntOrFloatBitWidth();
    auto addend0NonZero =
        calculateNonZeroBits(op.getAddend0(), op.getNumResults());
    auto addend1NonZero =
        calculateNonZeroBits(op.getAddend1(), op.getNumResults());

    if (failed(addend0NonZero) || failed(addend1NonZero))
      return failure();

    // Need the +1 for the carry-out of summing the two addends
    size_t maxNonZeroBits = std::max(*addend0NonZero, *addend1NonZero) + 1;

    // No reduction possible - leave the op unchanged.
    if (maxNonZeroBits >= op.getNumResults())
      return failure();

    auto newPP = datapath::PosPartialProductOp::create(
        rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);

    auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                       APInt::getZero(inputWidth));

    // Collect newPP results and pad with zeros if needed
    SmallVector<Value> newResults(newPP.getResults().begin(),
                                  newPP.getResults().end());

    newResults.append(op.getNumResults() - newResults.size(), zero);

    rewriter.replaceOp(op, newResults);
    return success();
  }
661};
662
663void PosPartialProductOp::getCanonicalizationPatterns(
664 RewritePatternSet &results, MLIRContext *context) {
666}
assert(baseType &&"element must be base type")
static bool areAllCompressorResultsSummed(ValueRange compressResults, ValueRange operands)
static FailureOr< Value > isOneExt(Value operand)
static FailureOr< size_t > calculateNonZeroBits(Value operand, size_t numResults)
static Value zeroPad(PatternRewriter &rewriter, Location loc, Value input, size_t targetWidth, size_t trailingZeros)
static std::unique_ptr< Context > context
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult matchAndRewrite(CompressOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(comb::AddOp addOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(datapath::CompressOp compOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CompressOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PartialProductOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PartialProductOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PosPartialProductOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(CompressOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PartialProductOp op, PatternRewriter &rewriter) const override