CIRCT 22.0.0git
Loading...
Searching...
No Matches
ReduceDelay.cpp
Go to the documentation of this file.
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
15
16namespace circt {
17namespace datapath {
18#define GEN_PASS_DEF_DATAPATHREDUCEDELAY
19#include "circt/Dialect/Datapath/DatapathPasses.h.inc"
20} // namespace datapath
21} // namespace circt
22
23using namespace circt;
24using namespace datapath;
25using namespace mlir;
26
27namespace {
28
29// Fold add operations even if used multiple times incurring area overhead as
30// transformation reduces shared logic - but reduces delay
31// add = a + b;
32// out1 = add + c;
33// out2 = add << d;
34// -->
35// add = a + b;
36// comp1 = compress(a, b, c);
37// out1 = comp1[0] + comp1[1];
38// out2 = add << d;
39struct FoldAddReplicate : public OpRewritePattern<comb::AddOp> {
40 using OpRewritePattern::OpRewritePattern;
41
42 LogicalResult matchAndRewrite(comb::AddOp addOp,
43 PatternRewriter &rewriter) const override {
44
45 SmallVector<Value, 8> newCompressOperands;
46 // Check if any operand is an AddOp that has not be folded by comb folds
47 for (Value operand : addOp.getOperands()) {
48 if (comb::AddOp nestedAddOp = operand.getDefiningOp<comb::AddOp>()) {
49 llvm::append_range(newCompressOperands, nestedAddOp.getOperands());
50 } else {
51 newCompressOperands.push_back(operand);
52 }
53 }
54
55 // Nothing to be folded
56 if (newCompressOperands.size() <= addOp.getNumOperands())
57 return failure();
58
59 // Create a new CompressOp with all collected operands
60 auto newCompressOp = rewriter.create<datapath::CompressOp>(
61 addOp.getLoc(), newCompressOperands, 2);
62
63 // Add the results of the CompressOp
64 rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
65 true);
66 return success();
67 }
68};
69
70// (a ? b + c : d + e) + f
71// -->
72// (a ? b : d) + (a ? c : e) + f
73struct FoldMuxAdd : public OpRewritePattern<comb::AddOp> {
74 using OpRewritePattern::OpRewritePattern;
75
76 // When used in conjunction with datapath canonicalization will only replicate
77 // two input adders.
78 LogicalResult matchAndRewrite(comb::AddOp addOp,
79 PatternRewriter &rewriter) const override {
80
81 SmallVector<Value, 8> newCompressOperands;
82 for (Value operand : addOp.getOperands()) {
83 // Detect a mux operand - then check if it contains add operations
84 comb::MuxOp nestedMuxOp = operand.getDefiningOp<comb::MuxOp>();
85
86 // If not matched just add the operand without modification
87 if (!nestedMuxOp) {
88 newCompressOperands.push_back(operand);
89 continue;
90 }
91
92 SmallVector<Value> trueValOperands = {nestedMuxOp.getTrueValue()};
93 SmallVector<Value> falseValOperands = {nestedMuxOp.getFalseValue()};
94 // match a ? b + c : xx
95 if (comb::AddOp trueVal =
96 nestedMuxOp.getTrueValue().getDefiningOp<comb::AddOp>())
97 trueValOperands = trueVal.getOperands();
98
99 // match a ? xx : c + d
100 if (comb::AddOp falseVal =
101 nestedMuxOp.getFalseValue().getDefiningOp<comb::AddOp>())
102 falseValOperands = falseVal.getOperands();
103
104 auto maxOperands =
105 std::max(trueValOperands.size(), falseValOperands.size());
106
107 // No nested additions
108 if (maxOperands <= 1) {
109 newCompressOperands.push_back(operand);
110 continue;
111 }
112
113 // Pad with zeros to match number of operands
114 // a ? b + c : d -> (a ? b : d) + (a ? c : 0)
115 auto zero = rewriter.create<hw::ConstantOp>(
116 addOp.getLoc(), rewriter.getIntegerAttr(addOp.getType(), 0));
117 for (size_t i = 0; i < maxOperands; ++i) {
118 auto tOp = i < trueValOperands.size() ? trueValOperands[i] : zero;
119 auto fOp = i < falseValOperands.size() ? falseValOperands[i] : zero;
120 auto newMux = rewriter.create<comb::MuxOp>(
121 addOp.getLoc(), nestedMuxOp.getCond(), tOp, fOp);
122 newCompressOperands.push_back(newMux.getResult());
123 }
124 }
125
126 // Nothing to be folded
127 if (newCompressOperands.size() <= addOp.getNumOperands())
128 return failure();
129
130 // Create a new CompressOp with all collected operands
131 auto newCompressOp = rewriter.create<datapath::CompressOp>(
132 addOp.getLoc(), newCompressOperands, 2);
133
134 // Add the results of the CompressOp
135 rewriter.replaceOpWithNewOp<comb::AddOp>(addOp, newCompressOp.getResults(),
136 true);
137 return success();
138 }
139};
140} // namespace
141
142namespace {
143struct DatapathReduceDelayPass
144 : public circt::datapath::impl::DatapathReduceDelayBase<
145 DatapathReduceDelayPass> {
146
147 void runOnOperation() override {
148 Operation *op = getOperation();
149 MLIRContext *ctx = op->getContext();
150
151 RewritePatternSet patterns(ctx);
152 patterns.add<FoldAddReplicate, FoldMuxAdd>(ctx);
153
154 if (failed(applyPatternsGreedily(op, std::move(patterns))))
155 signalPassFailure();
156 };
157};
158} // namespace
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.