// [extraction artifact] The following text is Doxygen page chrome captured by
// the HTML-to-text conversion, not part of the original source file:
//   CIRCT 23.0.0git | Loading... | Searching... | No Matches
//   DatapathFolds.cpp -- "Go to the documentation of this file."
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "mlir/IR/Matchers.h"
14#include "mlir/IR/PatternMatch.h"
15#include "llvm/Support/Casting.h"
16#include "llvm/Support/KnownBits.h"
17#include <algorithm>
18
19using namespace mlir;
20using namespace circt;
21using namespace datapath;
22using namespace matchers;
23
24//===----------------------------------------------------------------------===//
25// Utility Functions
26//===----------------------------------------------------------------------===//
27static FailureOr<size_t> calculateNonZeroBits(Value operand,
28 size_t numResults) {
29 // If the extracted bits are all known, then return the result.
30 auto knownBits = comb::computeKnownBits(operand);
31 if (knownBits.isUnknown())
32 return failure(); // Skip if we don't know anything about the bits
33
34 size_t nonZeroBits = operand.getType().getIntOrFloatBitWidth() -
35 knownBits.Zero.countLeadingOnes();
36
37 // If all bits non-zero we will not reduce the number of results
38 if (nonZeroBits == numResults)
39 return failure();
40
41 return nonZeroBits;
42}
43
44// Check if the operand is sext() and return the unextended operand:
45// signBit = comb.extract(baseValue, width-1, 1)
46// ext = comb.replicate(signBit, width-baseWidth)
47// sext = comb.concat(ext, baseValue)
48static FailureOr<Value> isSext(Value operand) {
49 // Check if operand is a concat operation
50 auto concatOp = operand.getDefiningOp<comb::ConcatOp>();
51 if (!concatOp)
52 return failure();
53
54 auto operands = concatOp.getOperands();
55 // ConcatOp must have exactly 2 operands: (sign_bits, original_value)
56 if (operands.size() != 2)
57 return failure();
58
59 Value signBits = operands[0];
60 Value originalValue = operands[1];
61 auto originalWidth = originalValue.getType().getIntOrFloatBitWidth();
62
63 // Check if signBits is a replicate operation
64 auto replicateOp = signBits.getDefiningOp<comb::ReplicateOp>();
65 if (!replicateOp)
66 return failure();
67
68 Value signBit = replicateOp.getInput();
69
70 // Check if signBit is the msb of originalValue
71 auto extractOp = signBit.getDefiningOp<comb::ExtractOp>();
72 if (!extractOp)
73 return failure();
74
75 if ((extractOp.getInput() != originalValue) ||
76 (extractOp.getLowBit() != originalWidth - 1) ||
77 (extractOp.getType().getIntOrFloatBitWidth() != 1))
78 return failure();
79
80 // Return the original unextended value
81 return originalValue;
82}
83
84// zext(input<<trailingZeros) to targetWidth
85static Value zeroPad(PatternRewriter &rewriter, Location loc, Value input,
86 size_t targetWidth, size_t trailingZeros) {
87 assert(trailingZeros > 0 && "zeroPad called with zero trailing zeros");
88 auto trailingZerosValue =
89 hw::ConstantOp::create(rewriter, loc, APInt::getZero(trailingZeros));
90 auto padTrailing = comb::ConcatOp::create(
91 rewriter, loc, ValueRange{input, trailingZerosValue});
92 return comb::createZExt(rewriter, loc, padTrailing, targetWidth);
93}
94
95//===----------------------------------------------------------------------===//
96// Compress Operation
97//===----------------------------------------------------------------------===//
98// Check that all compressor results are included in this list of operands
99// If not we must take care as manipulating compressor results independently
100// could easily introduce a non-equivalent representation.
101static bool areAllCompressorResultsSummed(ValueRange compressResults,
102 ValueRange operands) {
103 for (auto result : compressResults) {
104 if (!llvm::is_contained(operands, result))
105 return false;
106 }
107 return true;
108}
109
111 : public OpRewritePattern<datapath::CompressOp> {
112 using OpRewritePattern::OpRewritePattern;
113
114 // compress(compress(a,b,c), add(e,f)) -> compress(a,b,c,e,f)
115 LogicalResult matchAndRewrite(datapath::CompressOp compOp,
116 PatternRewriter &rewriter) const override {
117 auto operands = compOp.getOperands();
118 llvm::SmallSetVector<Value, 8> processedCompressorResults;
119 SmallVector<Value, 8> newCompressOperands;
120
121 for (Value operand : operands) {
122
123 // Skip if already processed this compressor
124 if (processedCompressorResults.contains(operand))
125 continue;
126
127 // If the operand has multiple uses, we do not fold it into a compress
128 // operation, so we treat it as a regular operand to maintain sharing.
129 if (!operand.hasOneUse()) {
130 newCompressOperands.push_back(operand);
131 continue;
132 }
133
134 // Found a compress op - add its operands to our new list
135 if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {
136
137 // Check that all results of the compressor are summed in this add
138 if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
139 return failure();
140
141 llvm::append_range(newCompressOperands, compressOp.getOperands());
142 // Only process each compressor once as multiple operands will point
143 // to the same defining operation
144 processedCompressorResults.insert(compressOp.getResults().begin(),
145 compressOp.getResults().end());
146 continue;
147 }
148
149 if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
150 llvm::append_range(newCompressOperands, addOp.getOperands());
151 continue;
152 }
153
154 // Regular operand - just add it to our list
155 newCompressOperands.push_back(operand);
156 }
157
158 // If unable to collect more operands then this pattern doesn't apply
159 if (newCompressOperands.size() <= compOp.getNumOperands())
160 return failure();
161
162 // Create a new CompressOp with all collected operands
163 rewriter.replaceOpWithNewOp<datapath::CompressOp>(
164 compOp, newCompressOperands, compOp.getNumResults());
165 return success();
166 }
167};
168
// Folds add operands that are single-use compressor results or nested adds
// into one wider compressor feeding a 2-row add:
// add(compress(a,b,c), d, e) -> add(compress(a,b,c,d,e)).
struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  // add(compress(a,b,c),d) -> add(compress(a,b,c,d))
  // FIXME: This should be implemented as a canonicalization pattern for
  // compress op. Currently `hasDatapathOperand` flag prevents introducing
  // datapath operations from comb operations.
  LogicalResult matchAndRewrite(comb::AddOp addOp,
                                PatternRewriter &rewriter) const override {
    // comb.add canonicalization patterns handle folding add operations
    if (addOp.getNumOperands() <= 2)
      return failure();

    // Get operands of the AddOp
    auto operands = addOp.getOperands();
    // Results of compressors already inlined, so each defining compressor is
    // expanded only once.
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;
    // Only construct compressor if can form a larger compressor than what
    // is currently an input of this add. Also check that there is at least
    // one datapath operand.
    bool shouldFold = false, hasDatapathOperand = false;

    for (Value operand : operands) {

      // Skip if already processed this compressor
      if (processedCompressorResults.contains(operand))
        continue;

      // Track whether any operand is produced by a datapath op - without
      // one, folding would introduce datapath ops into pure comb IR.
      if (auto *op = operand.getDefiningOp())
        if (isa_and_nonnull<datapath::DatapathDialect>(op->getDialect()))
          hasDatapathOperand = true;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand.
      if (!operand.hasOneUse()) {
        // shouldFold becomes true once a second source joins the compressor,
        // i.e. the rewrite would actually merge something.
        shouldFold |= !newCompressOperands.empty();
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {

        // Check that all results of the compressor are summed in this add
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        // If we've already added one operand it should be folded
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      // Nested adds are summations too - inline their addends.
      if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, addOp.getOperands());
        continue;
      }

      // Regular operand - just add it to our list
      shouldFold |= !newCompressOperands.empty();
      newCompressOperands.push_back(operand);
    }

    // Only fold if we have constructed a larger compressor than what was
    // already there
    if (!shouldFold || !hasDatapathOperand)
      return failure();

    // Create a new CompressOp with all collected operands, reduced to the
    // canonical two-row redundant form.
    auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
                                                      newCompressOperands, 2);

    // Replace the original AddOp with a new add(compress(inputs))
    rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
                                             true);
    return success();
  }
};
251
252struct ConstantFoldCompress : public OpRewritePattern<CompressOp> {
253 using OpRewritePattern::OpRewritePattern;
254
255 LogicalResult matchAndRewrite(CompressOp op,
256 PatternRewriter &rewriter) const override {
257 auto inputs = op.getInputs();
258 auto size = inputs.size();
259
260 APInt value;
261
262 // compress(..., 0) -> compress(...) -- identity
263 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
264
265 // If only reducing by one row and contains zero - pass through operands
266 if (size - 1 == op.getNumResults()) {
267 rewriter.replaceOp(op, inputs.drop_back());
268 return success();
269 }
270
271 // Default create a compressor with fewer arguments
272 rewriter.replaceOpWithNewOp<CompressOp>(op, inputs.drop_back(),
273 op.getNumResults());
274 return success();
275 }
276
277 return failure();
278 }
279};
280
281void CompressOp::getCanonicalizationPatterns(RewritePatternSet &results,
282 MLIRContext *context) {
283 results
285 context);
286}
287
288//===----------------------------------------------------------------------===//
289// Partial Product Operation
290//===----------------------------------------------------------------------===//
291struct ReduceNumPartialProducts : public OpRewritePattern<PartialProductOp> {
292 using OpRewritePattern::OpRewritePattern;
293
294 // pp(concat(0,a), concat(0,b)) -> reduced number of results
295 LogicalResult matchAndRewrite(PartialProductOp op,
296 PatternRewriter &rewriter) const override {
297 auto operands = op.getOperands();
298 unsigned inputWidth = operands[0].getType().getIntOrFloatBitWidth();
299
300 // TODO: implement a constant multiplication for the PartialProductOp
301
302 auto op0NonZeroBits = calculateNonZeroBits(operands[0], op.getNumResults());
303 auto op1NonZeroBits = calculateNonZeroBits(operands[1], op.getNumResults());
304
305 if (failed(op0NonZeroBits) || failed(op1NonZeroBits))
306 return failure();
307
308 // Need the +1 for the carry-out
309 size_t maxNonZeroBits = std::max(*op0NonZeroBits, *op1NonZeroBits);
310
311 auto newPP = datapath::PartialProductOp::create(
312 rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);
313
314 auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
315 APInt::getZero(inputWidth));
316
317 // Collect newPP results and pad with zeros if needed
318 SmallVector<Value> newResults(newPP.getResults().begin(),
319 newPP.getResults().end());
320
321 newResults.append(op.getNumResults() - newResults.size(), zero);
322
323 rewriter.replaceOp(op, newResults);
324 return success();
325 }
326};
327
// Rewrites a partial product of two sign-extended inputs so the unsigned
// partial-product array operates on the narrower, unextended magnitudes.
struct SignedPartialProducts : public OpRewritePattern<PartialProductOp> {
  using OpRewritePattern::OpRewritePattern;

  // Based on the classical Baugh-Wooley algorithm for signed multiplication.
  // Paper: A Two's Complement Parallel Array Multiplication Algorithm
  //
  // Consider a p-bit by q-bit signed multiplier - producing a (p+q)-bit result:
  // a_sign = a[p-1], a_mag = a[p-2:0],
  // b_sign = b[q-1], b_mag = b[q-2:0]
  // sext(a) * sext(b) = a_mag * b_mag              [unsigned product]
  //                   - 2^(p-1) * a_sign * b_mag   [sign correction]
  //                   - 2^(q-1) * b_sign * a_mag   [sign correction]
  //                   + 2^(p+q-2) * a_sign * b_sign [sign * sign]
  //
  // We implement optimizations to turn the subtractions into bitwise negations
  // with constant corrections that can be folded together.
  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {
    auto inputWidth = op.getOperand(0).getType().getIntOrFloatBitWidth();
    // Both operands must be sign-extensions; recover the unextended values.
    auto lhs = isSext(op.getOperand(0));
    auto rhs = isSext(op.getOperand(1));
    if (failed(lhs) || failed(rhs))
      return failure();

    size_t lhsWidth = (*lhs).getType().getIntOrFloatBitWidth();
    size_t rhsWidth = (*rhs).getType().getIntOrFloatBitWidth();
    // Subtract 1 as will handle sign-bit separately
    size_t maxRows = std::max(lhsWidth, rhsWidth) - 1;

    // TODO: add support for different width inputs
    // Need to have a sign bit in both inputs
    if (lhsWidth != rhsWidth || lhsWidth <= 1 || rhsWidth <= 1)
      return failure();

    // No further reduction possible
    if (maxRows >= op.getNumResults())
      return failure();

    // Pull off the sign bits
    auto lhsBaseWidth = lhsWidth - 1;
    auto rhsBaseWidth = rhsWidth - 1;
    auto lhsSignBit =
        comb::ExtractOp::create(rewriter, op.getLoc(), *lhs, lhsBaseWidth, 1);
    auto rhsSignBit =
        comb::ExtractOp::create(rewriter, op.getLoc(), *rhs, rhsBaseWidth, 1);
    auto lhsBase =
        comb::ExtractOp::create(rewriter, op.getLoc(), *lhs, 0, lhsBaseWidth);
    auto rhsBase =
        comb::ExtractOp::create(rewriter, op.getLoc(), *rhs, 0, rhsBaseWidth);

    // Create the unsigned partial product of the unextended inputs
    auto lhsBaseZext =
        comb::createZExt(rewriter, op.getLoc(), lhsBase, inputWidth);
    auto rhsBaseZext =
        comb::createZExt(rewriter, op.getLoc(), rhsBase, inputWidth);
    auto newPP = datapath::PartialProductOp::create(
        rewriter, op.getLoc(), ValueRange{lhsBaseZext, rhsBaseZext}, maxRows);

    // Optimization (similar for second sign correction), ext to (p+q)-bits:
    // -2^(p-1)*sign(lhs)*rhsBase = ~((sign(lhs) * rhsBase) << (p-1)) + 1
    //   = (~(replicate(sign(lhs)) & rhsBase)) << (p-1)
    //     + (-1) << (p+q-2)       [msb correction]
    //     + (1<<(p-1)) - 1 + 1    [lsb correction]

    // Create ~(replicate(sign(lhs)) & rhsBase)
    auto lhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
                                                      lhsSignBit, rhsBaseWidth);
    auto lhsSignAndRhs =
        comb::AndOp::create(rewriter, op.getLoc(), lhsSignReplicate, rhsBase);
    auto lhsSignCorrection =
        comb::createOrFoldNot(op.getLoc(), lhsSignAndRhs, rewriter, true);

    // zext({lhsSignCorrection, lhsBaseWidth{1'b0}})
    auto alignLhsSignCorrection = zeroPad(
        rewriter, op.getLoc(), lhsSignCorrection, inputWidth, lhsBaseWidth);

    // Create ~(replicate(sign(rhs)) & lhsBase)
    auto rhsSignReplicate = comb::ReplicateOp::create(rewriter, op.getLoc(),
                                                      rhsSignBit, lhsBaseWidth);
    auto rhsSignAndLhs =
        comb::AndOp::create(rewriter, op.getLoc(), rhsSignReplicate, lhsBase);
    auto rhsSignCorrection =
        comb::createOrFoldNot(op.getLoc(), rhsSignAndLhs, rewriter, true);

    // zext({rhsSignCorrection, rhsBaseWidth{1'b0}})
    auto alignRhsSignCorrection = zeroPad(
        rewriter, op.getLoc(), rhsSignCorrection, inputWidth, rhsBaseWidth);

    // 2^(p+q-2) * sign(lhs) * sign(rhs) = (sign(lhs) & sign(rhs)) << (p+q-2)
    // Create sign(lhs) & sign(rhs)
    auto signAnd =
        comb::AndOp::create(rewriter, op.getLoc(), lhsSignBit, rhsSignBit);
    // zext({sign(lhs) & sign(rhs), lhsBaseWidth+rhsBaseWidth{1'b0}})
    auto alignSignAndZext = zeroPad(rewriter, op.getLoc(), signAnd, inputWidth,
                                    lhsBaseWidth + rhsBaseWidth);

    // Gather constant corrections together (once for each sign correction):
    // (-1) << (p+q-2) + (1<<(p-1)) - 1 + 1
    auto ones = APInt::getAllOnes(inputWidth);
    auto lowerLhs = APInt::getOneBitSet(inputWidth, lhsBaseWidth);
    auto lowerRhs = APInt::getOneBitSet(inputWidth, rhsBaseWidth);
    // msbCorrection doubled below because the msb term occurs once per sign
    // correction (two corrections in total).
    auto msbCorrection = ones << (lhsBaseWidth + rhsBaseWidth);
    auto correction = lowerLhs + lowerRhs + 2 * msbCorrection;

    auto constantCorrection =
        hw::ConstantOp::create(rewriter, op.getLoc(), correction);

    auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                       APInt::getZero(inputWidth));
    // Collect newPP results and pad with zeros if needed
    // NOTE(review): if maxRows + 4 exceeds getNumResults(), the append count
    // below underflows - verify the op's result-count invariants upstream.
    SmallVector<Value> newResults(newPP.getResults().begin(),
                                  newPP.getResults().end());

    // ~(replicate(sign(lhs)) & rhsBase) * 2^(p-1)
    newResults.push_back(alignLhsSignCorrection);
    // ~(replicate(sign(rhs)) & lhsBase) * 2^(q-1)
    newResults.push_back(alignRhsSignCorrection);
    // sign(lhs)*sign(rhs) * 2^(p+q-2)
    newResults.push_back(alignSignAndZext);
    // Constant correction
    newResults.push_back(constantCorrection);
    // Zero pad if necessary
    newResults.append(op.getNumResults() - newResults.size(), zero);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};
456
457struct PosPartialProducts : public OpRewritePattern<PartialProductOp> {
458 using OpRewritePattern::OpRewritePattern;
459
460 // pp(add(a,b),c) -> pos_pp(a,b,c)
461 LogicalResult matchAndRewrite(PartialProductOp op,
462 PatternRewriter &rewriter) const override {
463 auto width = op.getType(0).getIntOrFloatBitWidth();
464
465 assert(op.getNumOperands() == 2);
466
467 // Detect if any input is an AddOp
468 auto lhsAdder = op.getOperand(0).getDefiningOp<comb::AddOp>();
469 auto rhsAdder = op.getOperand(1).getDefiningOp<comb::AddOp>();
470 if ((lhsAdder && rhsAdder) || !(lhsAdder || rhsAdder))
471 return failure();
472 auto addInput = lhsAdder ? lhsAdder : rhsAdder;
473 auto otherInput = lhsAdder ? op.getOperand(1) : op.getOperand(0);
474
475 if (addInput->getNumOperands() != 2)
476 return failure();
477
478 Value addend0 = addInput->getOperand(0);
479 Value addend1 = addInput->getOperand(1);
480
481 rewriter.replaceOpWithNewOp<PosPartialProductOp>(
482 op, ValueRange{addend0, addend1, otherInput}, width);
483 return success();
484 }
485};
486
487void PartialProductOp::getCanonicalizationPatterns(RewritePatternSet &results,
488 MLIRContext *context) {
489 results
491 context);
492}
493
494//===----------------------------------------------------------------------===//
495// Pos Partial Product Operation
496//===----------------------------------------------------------------------===//
498 : public OpRewritePattern<PosPartialProductOp> {
499 using OpRewritePattern::OpRewritePattern;
500
501 // pos_pp(concat(0,a), concat(0,b), c) -> reduced number of results
502 LogicalResult matchAndRewrite(PosPartialProductOp op,
503 PatternRewriter &rewriter) const override {
504 unsigned inputWidth = op.getAddend0().getType().getIntOrFloatBitWidth();
505 auto addend0NonZero =
506 calculateNonZeroBits(op.getAddend0(), op.getNumResults());
507 auto addend1NonZero =
508 calculateNonZeroBits(op.getAddend1(), op.getNumResults());
509
510 if (failed(addend0NonZero) || failed(addend1NonZero))
511 return failure();
512
513 // Need the +1 for the carry-out
514 size_t maxNonZeroBits = std::max(*addend0NonZero, *addend1NonZero) + 1;
515
516 if (maxNonZeroBits >= op.getNumResults())
517 return failure();
518
519 auto newPP = datapath::PosPartialProductOp::create(
520 rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);
521
522 auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
523 APInt::getZero(inputWidth));
524
525 // Collect newPP results and pad with zeros if needed
526 SmallVector<Value> newResults(newPP.getResults().begin(),
527 newPP.getResults().end());
528
529 newResults.append(op.getNumResults() - newResults.size(), zero);
530
531 rewriter.replaceOp(op, newResults);
532 return success();
533 }
534};
535
536void PosPartialProductOp::getCanonicalizationPatterns(
537 RewritePatternSet &results, MLIRContext *context) {
539}
// [extraction artifact] The following lines are Doxygen tooltip text captured
// by the HTML-to-text conversion; they are not part of the original source:
//   assert(baseType &&"element must be base type")
//   static FailureOr< Value > isSext(Value operand)
//   static bool areAllCompressorResultsSummed(ValueRange compressResults, ValueRange operands)
//   static FailureOr< size_t > calculateNonZeroBits(Value operand, size_t numResults)
//   static Value zeroPad(PatternRewriter &rewriter, Location loc, Value input, size_t targetWidth, size_t trailingZeros)
//   static std::unique_ptr< Context > context
//   create(low_bit, result_type, input=None)  [Definition comb.py:187]
//   create(data_type, value)                  [Definition hw.py:433]
//   The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
//   LogicalResult matchAndRewrite(CompressOp op, PatternRewriter &rewriter) const override
//   LogicalResult matchAndRewrite(comb::AddOp addOp, PatternRewriter &rewriter) const override
//   LogicalResult matchAndRewrite(datapath::CompressOp compOp, PatternRewriter &rewriter) const override
//   LogicalResult matchAndRewrite(PartialProductOp op, PatternRewriter &rewriter) const override
//   LogicalResult matchAndRewrite(PosPartialProductOp op, PatternRewriter &rewriter) const override