CIRCT 20.0.0git
Loading...
Searching...
No Matches
LatencyRetiming.cpp
Go to the documentation of this file.
1//===- LatencyRetiming.cpp - Implement LatencyRetiming Pass ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "circt/Support/LLVM.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17#define DEBUG_TYPE "arc-latency-retiming"
18
19namespace circt {
20namespace arc {
21#define GEN_PASS_DEF_LATENCYRETIMING
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
23} // namespace arc
24} // namespace circt
25
26using namespace circt;
27using namespace arc;
28
29//===----------------------------------------------------------------------===//
30// Patterns
31//===----------------------------------------------------------------------===//
32
33namespace {
34struct LatencyRetimingStatistics {
35 unsigned numOpsRemoved = 0;
36 unsigned latencyUnitsSaved = 0;
37};
38
39/// Absorb the latencies from predecessor states to collapse shift registers and
40/// reduce the overall amount of latency units in the design.
41struct LatencyRetimingPattern
42 : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
43 LatencyRetimingPattern(MLIRContext *context, SymbolCache &symCache,
44 LatencyRetimingStatistics &statistics)
45 : OpInterfaceRewritePattern<ClockedOpInterface>(context),
46 symCache(symCache), statistics(statistics) {}
47
48 LogicalResult matchAndRewrite(ClockedOpInterface op,
49 PatternRewriter &rewriter) const final;
50
51private:
52 SymbolCache &symCache;
53 LatencyRetimingStatistics &statistics;
54};
55
56} // namespace
57
58LogicalResult
59LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
60 PatternRewriter &rewriter) const {
61 uint32_t minPrevLatency = UINT_MAX;
62 SetVector<ClockedOpInterface> predecessors;
63 Value clock;
64
65 auto hasEnableOrReset = [](Operation *op) -> bool {
66 if (auto stateOp = dyn_cast<StateOp>(op))
67 if (stateOp.getReset() || stateOp.getEnable())
68 return true;
69 return false;
70 };
71
72 // Restrict this pattern to call and state ops only. In the future we could
73 // also add support for memory write operations.
74 if (!isa<CallOp, StateOp>(op.getOperation()))
75 return failure();
76
77 // In principle we could support enables and resets but would have to check
78 // that all involved states have the same.
79 if (hasEnableOrReset(op))
80 return failure();
81
82 assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
83 "state and call operations call arcs and thus have to implement the "
84 "CallOpInterface");
85 auto callOp = cast<mlir::CallOpInterface>(op.getOperation());
86
87 for (auto input : callOp.getArgOperands()) {
88 auto predOp = input.getDefiningOp<ClockedOpInterface>();
89
90 // Only support call and state ops for the predecessors as well.
91 if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
92 return failure();
93
94 // Conditions for both StateOp and CallOp
95 if (predOp->hasAttr("name") || predOp->hasAttr("names"))
96 return failure();
97
98 // Check for a use-def cycle since we can be in a graph region.
99 if (predOp == op)
100 return failure();
101
102 if (predOp.getClock() && op.getClock() &&
103 predOp.getClock() != op.getClock())
104 return failure();
105
106 if (predOp->getParentRegion() != op->getParentRegion())
107 return failure();
108
109 if (hasEnableOrReset(predOp))
110 return failure();
111
112 // Check that the predecessor state does not have another user since then
113 // we cannot change its latency attribute without also changing it for the
114 // other users. This is not supported yet and thus we just fail.
115 if (llvm::any_of(predOp->getUsers(),
116 [&](auto *user) { return user != op; }))
117 return failure();
118
119 // We check that all clocks are the same if present. Here we remember that
120 // clock. If none of the involved operations have a clock, they must have
121 // latency 0 and thus `minPrevLatency = 0` leading to early failure below.
122 if (!clock) {
123 if (predOp.getClock())
124 clock = predOp.getClock();
125 if (auto clockDomain = predOp->getParentOfType<ClockDomainOp>())
126 clock = clockDomain.getClock();
127 }
128
129 predecessors.insert(predOp);
130 minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
131 }
132
133 if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
134 return failure();
135
136 auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
137 assert((isa<StateOp, CallOp>(op)) && "must be a state or call op");
138 bool isInClockDomain = op->getParentOfType<ClockDomainOp>();
139
140 if (auto stateOp = dyn_cast<StateOp>(op)) {
141 if (newLatency == 0) {
142 if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
143 .isPassthrough()) {
144 rewriter.replaceOp(stateOp, stateOp.getInputs());
145 ++statistics.numOpsRemoved;
146 return;
147 }
148 rewriter.setInsertionPoint(op);
149 rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
150 stateOp.getArcAttr(),
151 stateOp.getInputs());
152 return;
153 }
154
155 rewriter.modifyOpInPlace(op, [&]() {
156 stateOp.setLatency(newLatency);
157 if (!stateOp.getClock() && !isInClockDomain)
158 stateOp.getClockMutable().assign(clock);
159 });
160 return;
161 }
162
163 if (auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
164 rewriter.replaceOpWithNewOp<StateOp>(
165 op, callOp.getArcAttr(), callOp->getResultTypes(),
166 isInClockDomain ? Value{} : clock, Value{}, newLatency,
167 callOp.getInputs());
168 };
169
170 setLatency(op, op.getLatency() + minPrevLatency, clock);
171 for (auto prevOp : predecessors) {
172 statistics.latencyUnitsSaved += minPrevLatency;
173 auto newLatency = prevOp.getLatency() - minPrevLatency;
174 setLatency(prevOp, newLatency, {});
175 }
176 statistics.latencyUnitsSaved -= minPrevLatency;
177
178 return success();
179}
180
181//===----------------------------------------------------------------------===//
182// LatencyRetiming pass
183//===----------------------------------------------------------------------===//
184
185namespace {
186struct LatencyRetimingPass
187 : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
188 void runOnOperation() override;
189};
190} // namespace
191
192void LatencyRetimingPass::runOnOperation() {
193 SymbolCache cache;
194 cache.addDefinitions(getOperation());
195
196 LatencyRetimingStatistics statistics;
197
198 RewritePatternSet patterns(&getContext());
199 patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
200
201 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
202 return signalPassFailure();
203
204 numOpsRemoved = statistics.numOpsRemoved;
205 latencyUnitsSaved = statistics.latencyUnitsSaved;
206}
207
208std::unique_ptr<Pass> arc::createLatencyRetimingPass() {
209 return std::make_unique<LatencyRetimingPass>();
210}
assert(baseType &&"element must be base type")
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
std::unique_ptr< mlir::Pass > createLatencyRetimingPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.