CIRCT 22.0.0git
Loading...
Searching...
No Matches
DatapathFolds.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/Matchers.h"
13#include "mlir/IR/PatternMatch.h"
14#include "llvm/Support/KnownBits.h"
15#include <algorithm>
16
17using namespace mlir;
18using namespace circt;
19using namespace datapath;
20using namespace matchers;
21
22//===----------------------------------------------------------------------===//
23// Utility Functions
24//===----------------------------------------------------------------------===//
25static FailureOr<size_t> calculateNonZeroBits(Value operand,
26 size_t numResults) {
27 // If the extracted bits are all known, then return the result.
28 auto knownBits = comb::computeKnownBits(operand);
29 if (knownBits.isUnknown())
30 return failure(); // Skip if we don't know anything about the bits
31
32 size_t nonZeroBits = operand.getType().getIntOrFloatBitWidth() -
33 knownBits.Zero.countLeadingOnes();
34
35 // If all bits non-zero we will not reduce the number of results
36 if (nonZeroBits == numResults)
37 return failure();
38
39 return nonZeroBits;
40}
41
42//===----------------------------------------------------------------------===//
43// Compress Operation
44//===----------------------------------------------------------------------===//
45// Check that all compressor results are included in this list of operands
46// If not we must take care as manipulating compressor results independently
47// could easily introduce a non-equivalent representation.
48static bool areAllCompressorResultsSummed(ValueRange compressResults,
49 ValueRange operands) {
50 for (auto result : compressResults) {
51 if (!llvm::is_contained(operands, result))
52 return false;
53 }
54 return true;
55}
56
58 : public OpRewritePattern<datapath::CompressOp> {
59 using OpRewritePattern::OpRewritePattern;
60
61 // compress(compress(a,b,c), add(e,f)) -> compress(a,b,c,e,f)
62 LogicalResult matchAndRewrite(datapath::CompressOp compOp,
63 PatternRewriter &rewriter) const override {
64 auto operands = compOp.getOperands();
65 llvm::SmallSetVector<Value, 8> processedCompressorResults;
66 SmallVector<Value, 8> newCompressOperands;
67
68 for (Value operand : operands) {
69
70 // Skip if already processed this compressor
71 if (processedCompressorResults.contains(operand))
72 continue;
73
74 // If the operand has multiple uses, we do not fold it into a compress
75 // operation, so we treat it as a regular operand to maintain sharing.
76 if (!operand.hasOneUse()) {
77 newCompressOperands.push_back(operand);
78 continue;
79 }
80
81 // Found a compress op - add its operands to our new list
82 if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {
83
84 // Check that all results of the compressor are summed in this add
85 if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
86 return failure();
87
88 llvm::append_range(newCompressOperands, compressOp.getOperands());
89 // Only process each compressor once as multiple operands will point
90 // to the same defining operation
91 processedCompressorResults.insert(compressOp.getResults().begin(),
92 compressOp.getResults().end());
93 continue;
94 }
95
96 if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
97 llvm::append_range(newCompressOperands, addOp.getOperands());
98 continue;
99 }
100
101 // Regular operand - just add it to our list
102 newCompressOperands.push_back(operand);
103 }
104
105 // If unable to collect more operands then this pattern doesn't apply
106 if (newCompressOperands.size() <= compOp.getNumOperands())
107 return failure();
108
109 // Create a new CompressOp with all collected operands
110 rewriter.replaceOpWithNewOp<datapath::CompressOp>(
111 compOp, newCompressOperands, compOp.getNumResults());
112 return success();
113 }
114};
115
116struct FoldAddIntoCompress : public OpRewritePattern<comb::AddOp> {
117 using OpRewritePattern::OpRewritePattern;
118
119 // add(compress(a,b,c),d) -> add(compress(a,b,c,d))
120 LogicalResult matchAndRewrite(comb::AddOp addOp,
121 PatternRewriter &rewriter) const override {
122 // comb.add canonicalization patterns handle folding add operations
123 if (addOp.getNumOperands() <= 2)
124 return failure();
125
126 // Get operands of the AddOp
127 auto operands = addOp.getOperands();
128 llvm::SmallSetVector<Value, 8> processedCompressorResults;
129 SmallVector<Value, 8> newCompressOperands;
130 // Only construct compressor if can form a larger compressor than what
131 // is currently an input of this add
132 bool shouldFold = false;
133
134 for (Value operand : operands) {
135
136 // Skip if already processed this compressor
137 if (processedCompressorResults.contains(operand))
138 continue;
139
140 // If the operand has multiple uses, we do not fold it into a compress
141 // operation, so we treat it as a regular operand.
142 if (!operand.hasOneUse()) {
143 shouldFold |= !newCompressOperands.empty();
144 newCompressOperands.push_back(operand);
145 continue;
146 }
147
148 // Found a compress op - add its operands to our new list
149 if (auto compressOp = operand.getDefiningOp<datapath::CompressOp>()) {
150
151 // Check that all results of the compressor are summed in this add
152 if (!areAllCompressorResultsSummed(compressOp.getResults(), operands))
153 return failure();
154
155 // If we've already added one operand it should be folded
156 shouldFold |= !newCompressOperands.empty();
157 llvm::append_range(newCompressOperands, compressOp.getOperands());
158 // Only process each compressor once
159 processedCompressorResults.insert(compressOp.getResults().begin(),
160 compressOp.getResults().end());
161 continue;
162 }
163
164 if (auto addOp = operand.getDefiningOp<comb::AddOp>()) {
165 shouldFold |= !newCompressOperands.empty();
166 llvm::append_range(newCompressOperands, addOp.getOperands());
167 continue;
168 }
169
170 // Regular operand - just add it to our list
171 shouldFold |= !newCompressOperands.empty();
172 newCompressOperands.push_back(operand);
173 }
174
175 // Only fold if we have constructed a larger compressor than what was
176 // already there
177 if (!shouldFold)
178 return failure();
179
180 // Create a new CompressOp with all collected operands
181 auto newCompressOp = datapath::CompressOp::create(rewriter, addOp.getLoc(),
182 newCompressOperands, 2);
183
184 // Replace the original AddOp with a new add(compress(inputs))
185 rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
186 true);
187 return success();
188 }
189};
190
191struct ConstantFoldCompress : public OpRewritePattern<CompressOp> {
192 using OpRewritePattern::OpRewritePattern;
193
194 LogicalResult matchAndRewrite(CompressOp op,
195 PatternRewriter &rewriter) const override {
196 auto inputs = op.getInputs();
197 auto size = inputs.size();
198
199 APInt value;
200
201 // compress(..., 0) -> compress(...) -- identity
202 if (matchPattern(inputs.back(), m_ConstantInt(&value)) && value.isZero()) {
203
204 // If only reducing by one row and contains zero - pass through operands
205 if (size - 1 == op.getNumResults()) {
206 rewriter.replaceOp(op, inputs.drop_back());
207 return success();
208 }
209
210 // Default create a compressor with fewer arguments
211 rewriter.replaceOpWithNewOp<CompressOp>(op, inputs.drop_back(),
212 op.getNumResults());
213 return success();
214 }
215
216 return failure();
217 }
218};
219
// Registers the datapath.compress canonicalization patterns into `results`.
// NOTE(review): the `.add<...>(` line that belongs between `results` and
// `context);` is missing from this view; presumably it lists the CompressOp
// patterns defined above (e.g. FoldAddIntoCompress, ConstantFoldCompress) —
// verify against the real source.
220void CompressOp::getCanonicalizationPatterns(RewritePatternSet &results,
221 MLIRContext *context) {
222
223 results
225 context);
226}
227
228//===----------------------------------------------------------------------===//
229// Partial Product Operation
230//===----------------------------------------------------------------------===//
231struct ReduceNumPartialProducts : public OpRewritePattern<PartialProductOp> {
232 using OpRewritePattern::OpRewritePattern;
233
234 // pp(concat(0,a), concat(0,b)) -> reduced number of results
235 LogicalResult matchAndRewrite(PartialProductOp op,
236 PatternRewriter &rewriter) const override {
237 auto operands = op.getOperands();
238 unsigned inputWidth = operands[0].getType().getIntOrFloatBitWidth();
239
240 // TODO: implement a constant multiplication for the PartialProductOp
241
242 auto op0NonZeroBits = calculateNonZeroBits(operands[0], op.getNumResults());
243 auto op1NonZeroBits = calculateNonZeroBits(operands[1], op.getNumResults());
244
245 if (failed(op0NonZeroBits) || failed(op1NonZeroBits))
246 return failure();
247
248 // Need the +1 for the carry-out
249 size_t maxNonZeroBits = std::max(*op0NonZeroBits, *op1NonZeroBits);
250
251 auto newPP = datapath::PartialProductOp::create(
252 rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);
253
254 auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
255 APInt::getZero(inputWidth));
256
257 // Collect newPP results and pad with zeros if needed
258 SmallVector<Value> newResults(newPP.getResults().begin(),
259 newPP.getResults().end());
260
261 newResults.append(op.getNumResults() - newResults.size(), zero);
262
263 rewriter.replaceOp(op, newResults);
264 return success();
265 }
266};
267
268struct PosPartialProducts : public OpRewritePattern<PartialProductOp> {
269 using OpRewritePattern::OpRewritePattern;
270
271 // pp(add(a,b),c) -> pos_pp(a,b,c)
272 LogicalResult matchAndRewrite(PartialProductOp op,
273 PatternRewriter &rewriter) const override {
274 auto width = op.getType(0).getIntOrFloatBitWidth();
275
276 assert(op.getNumOperands() == 2);
277
278 // Detect if any input is an AddOp
279 auto lhsAdder = op.getOperand(0).getDefiningOp<comb::AddOp>();
280 auto rhsAdder = op.getOperand(1).getDefiningOp<comb::AddOp>();
281 if ((lhsAdder && rhsAdder) | !(lhsAdder || rhsAdder))
282 return failure();
283 auto addInput = lhsAdder ? lhsAdder : rhsAdder;
284 auto otherInput = lhsAdder ? op.getOperand(1) : op.getOperand(0);
285
286 if (addInput->getNumOperands() != 2)
287 return failure();
288
289 Value addend0 = addInput->getOperand(0);
290 Value addend1 = addInput->getOperand(1);
291
292 rewriter.replaceOpWithNewOp<PosPartialProductOp>(
293 op, ValueRange{addend0, addend1, otherInput}, width);
294 return success();
295 }
296};
297
298void PartialProductOp::getCanonicalizationPatterns(RewritePatternSet &results,
299 MLIRContext *context) {
300
301 results.add<ReduceNumPartialProducts, PosPartialProducts>(context);
302}
303
304//===----------------------------------------------------------------------===//
305// Pos Partial Product Operation
306//===----------------------------------------------------------------------===//
308 : public OpRewritePattern<PosPartialProductOp> {
309 using OpRewritePattern::OpRewritePattern;
310
311 // pos_pp(concat(0,a), concat(0,b), c) -> reduced number of results
312 LogicalResult matchAndRewrite(PosPartialProductOp op,
313 PatternRewriter &rewriter) const override {
314 unsigned inputWidth = op.getAddend0().getType().getIntOrFloatBitWidth();
315 auto addend0NonZero =
316 calculateNonZeroBits(op.getAddend0(), op.getNumResults());
317 auto addend1NonZero =
318 calculateNonZeroBits(op.getAddend1(), op.getNumResults());
319
320 if (failed(addend0NonZero) || failed(addend1NonZero))
321 return failure();
322
323 // Need the +1 for the carry-out
324 size_t maxNonZeroBits = std::max(*addend0NonZero, *addend1NonZero) + 1;
325
326 if (maxNonZeroBits >= op.getNumResults())
327 return failure();
328
329 auto newPP = datapath::PosPartialProductOp::create(
330 rewriter, op.getLoc(), op.getOperands(), maxNonZeroBits);
331
332 auto zero = hw::ConstantOp::create(rewriter, op.getLoc(),
333 APInt::getZero(inputWidth));
334
335 // Collect newPP results and pad with zeros if needed
336 SmallVector<Value> newResults(newPP.getResults().begin(),
337 newPP.getResults().end());
338
339 newResults.append(op.getNumResults() - newResults.size(), zero);
340
341 rewriter.replaceOp(op, newResults);
342 return success();
343 }
344};
345
346void PosPartialProductOp::getCanonicalizationPatterns(
347 RewritePatternSet &results, MLIRContext *context) {
348
349 results.add<ReduceNumPosPartialProducts>(context);
350}
assert(baseType &&"element must be base type")
static bool areAllCompressorResultsSummed(ValueRange compressResults, ValueRange operands)
static FailureOr< size_t > calculateNonZeroBits(Value operand, size_t numResults)
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
LogicalResult matchAndRewrite(CompressOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(comb::AddOp addOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(datapath::CompressOp compOp, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PartialProductOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PartialProductOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(PosPartialProductOp op, PatternRewriter &rewriter) const override