DatapathFolds.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/KnownBits.h"
#include <algorithm>

using namespace mlir;
using namespace circt;
using namespace datapath;
using namespace matchers;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//
static FailureOr<size_t> calculateNonZeroBits(Value operand,
                                              size_t numResults) {
  // Use the known-bits analysis to determine how many bits can be non-zero.
  auto knownBits = comb::computeKnownBits(operand);
  if (knownBits.isUnknown())
    return failure(); // Skip if we don't know anything about the bits.

  size_t nonZeroBits = operand.getType().getIntOrFloatBitWidth() -
                       knownBits.Zero.countLeadingOnes();

  // If all bits are non-zero, we will not reduce the number of results.
  if (nonZeroBits == numResults)
    return failure();

  return nonZeroBits;
}
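
// Example (illustrative): for a 16-bit operand built as concat(0 : i8, a : i8)
// the top eight bits are known zero, so calculateNonZeroBits returns 8 when
// numResults is 16, and fails (no reduction possible) when numResults is
// already 8.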

//===----------------------------------------------------------------------===//
// Compress Operation
//===----------------------------------------------------------------------===//
// Check that all compressor results are included in this list of operands.
// If not, we must take care, as manipulating compressor results independently
// could easily introduce a non-equivalent representation.
static bool areAllCompressorResultsSummed(ValueRange compressResults,
                                          ValueRange operands) {
  for (auto result : compressResults) {
    if (!llvm::is_contained(operands, result))
      return false;
  }
  return true;
}
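
// Illustrative reasoning: a compressor is only equivalent to its operands when
// all of its results are summed together, e.g. for r0, r1 = compress(a, b, c)
// the invariant is r0 + r1 == a + b + c, so folding r0 without also folding r1
// would change the computed value.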

struct FoldCompressIntoCompress
    : public OpRewritePattern<datapath::CompressOp> {
  using OpRewritePattern::OpRewritePattern;

  // compress(compress(a,b,c), add(e,f)) -> compress(a,b,c,e,f)
  LogicalResult matchAndRewrite(datapath::CompressOp compOp,
                                PatternRewriter &rewriter) const override {
    auto operands = compOp.getOperands();
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;

    for (Value operand : operands) {

      // Skip if we have already processed this compressor.
      if (processedCompressorResults.contains(operand))
        continue;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand to maintain sharing.
      if (!operand.hasOneUse()) {
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list.
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {

        // Check that all results of the compressor are summed in this add.
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once, as multiple operands will point
        // to the same defining operation.
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
        llvm::append_range(newCompressOperands, addOp.getOperands());
        continue;
      }

      // Regular operand - just add it to our list.
      newCompressOperands.push_back(operand);
    }

    // If we were unable to collect more operands, this pattern doesn't apply.
    if (newCompressOperands.size() <= compOp.getNumOperands())
      return failure();

    // Create a new CompressOp with all collected operands.
    rewriter.replaceOpWithNewOp<datapath::CompressOp>(
        compOp, newCompressOperands, compOp.getNumResults());
    return success();
  }
};
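
// Illustrative effect of the pattern above (operand notation, not IR syntax):
//   r0, r1 = compress(a, b, c)   (results only used by the outer op)
//   s0, s1 = compress(r0, r1, d)
// becomes
//   s0, s1 = compress(a, b, c, d)
// i.e. nested compressors (and comb.add operands) are flattened into a single
// wider reduction.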

struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
  using OpRewritePattern::OpRewritePattern;

  // add(compress(a,b,c),d) -> add(compress(a,b,c,d))
  // FIXME: This should be implemented as a canonicalization pattern for the
  // compress op. Currently the `hasDatapathOperand` flag prevents introducing
  // datapath operations from comb operations.
  LogicalResult matchAndRewrite(comb::AddOp addOp,
                                PatternRewriter &rewriter) const override {
    // comb.add canonicalization patterns handle folding two-operand adds.
    if (addOp.getNumOperands() <= 2)
      return failure();

    // Get the operands of the AddOp.
    auto operands = addOp.getOperands();
    llvm::SmallSetVector<Value, 8> processedCompressorResults;
    SmallVector<Value, 8> newCompressOperands;
    // Only construct a compressor if we can form a larger compressor than
    // what is currently an input of this add. Also check that there is at
    // least one datapath operand.
    bool shouldFold = false, hasDatapathOperand = false;

    for (Value operand : operands) {

      // Skip if we have already processed this compressor.
      if (processedCompressorResults.contains(operand))
        continue;

      if (auto *op = operand.getDefiningOp())
        if (isa_and_nonnull<datapath::DatapathDialect>(op->getDialect()))
          hasDatapathOperand = true;

      // If the operand has multiple uses, we do not fold it into a compress
      // operation, so we treat it as a regular operand.
      if (!operand.hasOneUse()) {
        shouldFold |= !newCompressOperands.empty();
        newCompressOperands.push_back(operand);
        continue;
      }

      // Found a compress op - add its operands to our new list.
      if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {

        // Check that all results of the compressor are summed in this add.
        if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
          return failure();

        // If we've already added one operand, it should be folded.
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, compressOp.getOperands());
        // Only process each compressor once.
        processedCompressorResults.insert(compressOp.getResults().begin(),
                                          compressOp.getResults().end());
        continue;
      }

      if (auto addOperand = operand.getDefiningOp<comb::AddOp>()) {
        shouldFold |= !newCompressOperands.empty();
        llvm::append_range(newCompressOperands, addOperand.getOperands());
        continue;
      }

      // Regular operand - just add it to our list.
      shouldFold |= !newCompressOperands.empty();
      newCompressOperands.push_back(operand);
    }

    // Only fold if we have constructed a larger compressor than what was
    // already there.
    if (!shouldFold || !hasDatapathOperand)
      return failure();

    // Create a new CompressOp with all collected operands.
    auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
                                                      newCompressOperands, 2);

    // Replace the original AddOp with a new add(compress(inputs)).
    rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
                                             true);
    return success();
  }
};
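
// Illustrative effect of FoldAddIntoCompress (operand notation, not IR
// syntax): with r0, r1 = compress(a, b, c),
//   add(r0, r1, d)  ->  add(s0, s1)   where s0, s1 = compress(a, b, c, d)
// so all addends are reduced in carry-save form and only one carry-propagate
// addition remains.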

struct ConstantFoldCompress : public OpRewritePattern<CompressOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(CompressOp op,
                                PatternRewriter &rewriter) const override {
    auto inputs = op.getInputs();
    auto size = inputs.size();

    APInt value;

    // compress(..., 0) -> compress(...) -- identity
    if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {

      // If we are only reducing by one row and that row is zero, pass the
      // remaining operands straight through.
      if (size - 1 == op.getNumResults()) {
        rewriter.replaceOp(op, inputs.drop_back());
        return success();
      }

      // Otherwise create a compressor with one fewer argument.
      rewriter.replaceOpWithNewOp<CompressOp>(op, inputs.drop_back(),
                                              op.getNumResults());
      return success();
    }

    return failure();
  }
};
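
// Illustrative effect of ConstantFoldCompress (operand notation):
//   compress(a, b, c, 0) with 2 results  ->  compress(a, b, c) with 2 results
//   compress(a, b, 0)    with 2 results  ->  a, b (results replaced directly)
// Note that only a constant zero in the last operand position is matched here.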

void CompressOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                             MLIRContext *context) {

  results.add<FoldCompressIntoCompress, FoldAddIntoCompress,
              ConstantFoldCompress>(context);
}

//===----------------------------------------------------------------------===//
// Partial Product Operation
//===----------------------------------------------------------------------===//
struct ReduceNumPartialProducts : public OpRewritePattern<PartialProductOp> {
  using OpRewritePattern::OpRewritePattern;

  // pp(concat(0,a), concat(0,b)) -> reduced number of results
  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {
    auto operands = op.getOperands();
    unsigned inputWidth = operands[0].getType().getIntOrFloatBitWidth();

    // TODO: implement a constant multiplication for the PartialProductOp.

    auto op0NonZeroBits = calculateNonZeroBits(operands[0], op.getNumResults());
    auto op1NonZeroBits = calculateNonZeroBits(operands[1], op.getNumResults());

    if (failed(op0NonZeroBits) || failed(op1NonZeroBits))
      return failure();

    // The wider of the two non-zero ranges bounds the number of non-zero
    // partial-product rows.
    size_t maxNonZeroBits = std::max(*op0NonZeroBits, *op1NonZeroBits);

    auto newPP = datapath::PartialProductOp::create(
        rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);

    auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                       APInt::getZero(inputWidth));

    // Collect the newPP results and pad with zeros if needed.
    SmallVector<Value> newResults(newPP.getResults().begin(),
                                  newPP.getResults().end());

    newResults.append(op.getNumResults() - newResults.size(), zero);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};
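
// Illustrative effect of ReduceNumPartialProducts (operand notation): for
// pp(concat(0 : i8, a : i8), concat(0 : i8, b : i8)) with 16 results, at most
// the low 8 rows can be non-zero, so the op is rebuilt with 8 results and the
// remaining 8 results are replaced by constant zeros.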

struct PosPartialProducts : public OpRewritePattern<PartialProductOp> {
  using OpRewritePattern::OpRewritePattern;

  // pp(add(a,b),c) -> pos_pp(a,b,c)
  LogicalResult matchAndRewrite(PartialProductOp op,
                                PatternRewriter &rewriter) const override {
    auto width = op.getType(0).getIntOrFloatBitWidth();

    assert(op.getNumOperands() == 2);

    // Detect whether exactly one input is an AddOp.
    auto lhsAdder = op.getOperand(0).getDefiningOp<comb::AddOp>();
    auto rhsAdder = op.getOperand(1).getDefiningOp<comb::AddOp>();
    if ((lhsAdder && rhsAdder) || !(lhsAdder || rhsAdder))
      return failure();
    auto addInput = lhsAdder ? lhsAdder : rhsAdder;
    auto otherInput = lhsAdder ? op.getOperand(1) : op.getOperand(0);

    if (addInput->getNumOperands() != 2)
      return failure();

    Value addend0 = addInput->getOperand(0);
    Value addend1 = addInput->getOperand(1);

    rewriter.replaceOpWithNewOp<PosPartialProductOp>(
        op, ValueRange{addend0, addend1, otherInput}, width);
    return success();
  }
};
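
// Illustrative effect of PosPartialProducts (operand notation): exactly one
// multiplier input must be a two-input comb.add, e.g.
//   pp(add(a, b), c)  ->  pos_pp(a, b, c)
// while pp(add(a, b), add(c, d)) and pp(a, b) are left untouched, so the
// addition can be absorbed into the partial-product generation.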

void PartialProductOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                   MLIRContext *context) {

  results.add<ReduceNumPartialProducts, PosPartialProducts>(context);
}

//===----------------------------------------------------------------------===//
// Pos Partial Product Operation
//===----------------------------------------------------------------------===//
struct ReduceNumPosPartialProducts
    : public OpRewritePattern<PosPartialProductOp> {
  using OpRewritePattern::OpRewritePattern;

  // pos_pp(concat(0,a), concat(0,b), c) -> reduced number of results
  LogicalResult matchAndRewrite(PosPartialProductOp op,
                                PatternRewriter &rewriter) const override {
    unsigned inputWidth = op.getAddend0().getType().getIntOrFloatBitWidth();
    auto addend0NonZero =
        calculateNonZeroBits(op.getAddend0(), op.getNumResults());
    auto addend1NonZero =
        calculateNonZeroBits(op.getAddend1(), op.getNumResults());

    if (failed(addend0NonZero) || failed(addend1NonZero))
      return failure();

    // Need the +1 for the carry-out of the addend sum.
    size_t maxNonZeroBits = std::max(*addend0NonZero, *addend1NonZero) + 1;

    if (maxNonZeroBits >= op.getNumResults())
      return failure();

    auto newPP = datapath::PosPartialProductOp::create(
        rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);

    auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                       APInt::getZero(inputWidth));

    // Collect the newPP results and pad with zeros if needed.
    SmallVector<Value> newResults(newPP.getResults().begin(),
                                  newPP.getResults().end());

    newResults.append(op.getNumResults() - newResults.size(), zero);

    rewriter.replaceOp(op, newResults);
    return success();
  }
};
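
// Illustrative effect of ReduceNumPosPartialProducts (operand notation): for
// pos_pp(concat(0 : i8, a : i8), concat(0 : i8, b : i8), c) with 16 results,
// the addend sum fits in 9 bits (8 non-zero bits plus one carry), so the op is
// rebuilt with 9 results and the remaining 7 results become constant zeros.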

void PosPartialProductOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {

  results.add<ReduceNumPosPartialProducts>(context);
}