CIRCT 23.0.0git
Loading...
Searching...
No Matches
LatencyRetiming.cpp
Go to the documentation of this file.
1//===- LatencyRetiming.cpp - Implement LatencyRetiming Pass ---------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "circt/Support/LLVM.h"
13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16
17#define DEBUG_TYPE "arc-latency-retiming"
18
19namespace circt {
20namespace arc {
21#define GEN_PASS_DEF_LATENCYRETIMING
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
23} // namespace arc
24} // namespace circt
25
26using namespace circt;
27using namespace arc;
28
29//===----------------------------------------------------------------------===//
30// Patterns
31//===----------------------------------------------------------------------===//
32
33namespace {
/// Counters filled in by the retiming pattern and copied into the pass
/// statistics once the greedy rewrite has finished.
struct LatencyRetimingStatistics {
  /// Incremented each time a passthrough state is replaced by its inputs.
  unsigned numOpsRemoved{0};
  /// Net number of latency units removed across all successful rewrites.
  unsigned latencyUnitsSaved{0};
};
38
39/// Absorb the latencies from predecessor states to collapse shift registers and
40/// reduce the overall amount of latency units in the design.
41struct LatencyRetimingPattern
42 : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
43 LatencyRetimingPattern(MLIRContext *context, SymbolCache &symCache,
44 LatencyRetimingStatistics &statistics)
45 : OpInterfaceRewritePattern<ClockedOpInterface>(context),
46 symCache(symCache), statistics(statistics) {}
47
48 LogicalResult matchAndRewrite(ClockedOpInterface op,
49 PatternRewriter &rewriter) const final;
50
51private:
52 SymbolCache &symCache;
53 LatencyRetimingStatistics &statistics;
54};
55
56} // namespace
57
58LogicalResult
59LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
60 PatternRewriter &rewriter) const {
61 uint32_t minPrevLatency = UINT_MAX;
62 SetVector<ClockedOpInterface> predecessors;
63 Value clock;
64
65 auto hasEnableOrReset = [](Operation *op) -> bool {
66 if (auto stateOp = dyn_cast<StateOp>(op))
67 if (stateOp.getReset() || stateOp.getEnable())
68 return true;
69 return false;
70 };
71
72 // Restrict this pattern to call and state ops only. In the future we could
73 // also add support for memory write operations.
74 if (!isa<CallOp, StateOp>(op.getOperation()))
75 return failure();
76
77 // In principle we could support enables and resets but would have to check
78 // that all involved states have the same.
79 if (hasEnableOrReset(op))
80 return failure();
81
82 assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
83 "state and call operations call arcs and thus have to implement the "
84 "CallOpInterface");
85 auto callOp = cast<mlir::CallOpInterface>(op.getOperation());
86
87 for (auto input : callOp.getArgOperands()) {
88 auto predOp = input.getDefiningOp<ClockedOpInterface>();
89
90 // Only support call and state ops for the predecessors as well.
91 if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
92 return failure();
93
94 // Conditions for both StateOp and CallOp
95 if (predOp->hasAttr("name") || predOp->hasAttr("names"))
96 return failure();
97
98 // Check for a use-def cycle since we can be in a graph region.
99 if (predOp == op)
100 return failure();
101
102 if (predOp.getClock() && op.getClock() &&
103 predOp.getClock() != op.getClock())
104 return failure();
105
106 if (predOp->getParentRegion() != op->getParentRegion())
107 return failure();
108
109 if (hasEnableOrReset(predOp))
110 return failure();
111
112 // Check that the predecessor state does not have another user since then
113 // we cannot change its latency attribute without also changing it for the
114 // other users. This is not supported yet and thus we just fail.
115 if (llvm::any_of(predOp->getUsers(),
116 [&](auto *user) { return user != op; }))
117 return failure();
118
119 // We check that all clocks are the same if present. Here we remember that
120 // clock. If none of the involved operations have a clock, they must have
121 // latency 0 and thus `minPrevLatency = 0` leading to early failure below.
122 if (!clock && predOp.getClock())
123 clock = predOp.getClock();
124
125 predecessors.insert(predOp);
126 minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
127 }
128
129 if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
130 return failure();
131
132 auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
133 assert((isa<StateOp, CallOp>(op)) && "must be a state or call op");
134
135 if (auto stateOp = dyn_cast<StateOp>(op)) {
136 if (newLatency == 0) {
137 if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
138 .isPassthrough()) {
139 rewriter.replaceOp(stateOp, stateOp.getInputs());
140 ++statistics.numOpsRemoved;
141 return;
142 }
143 rewriter.setInsertionPoint(op);
144 rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
145 stateOp.getArcAttr(),
146 stateOp.getInputs());
147 return;
148 }
149
150 rewriter.modifyOpInPlace(op, [&]() {
151 stateOp.setLatency(newLatency);
152 if (!stateOp.getClock())
153 stateOp.getClockMutable().assign(clock);
154 });
155 return;
156 }
157
158 if (auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
159 rewriter.replaceOpWithNewOp<StateOp>(
160 op, callOp.getArcAttr(), callOp->getResultTypes(), clock, Value{},
161 newLatency, callOp.getInputs());
162 };
163
164 setLatency(op, op.getLatency() + minPrevLatency, clock);
165 for (auto prevOp : predecessors) {
166 statistics.latencyUnitsSaved += minPrevLatency;
167 auto newLatency = prevOp.getLatency() - minPrevLatency;
168 setLatency(prevOp, newLatency, {});
169 }
170 statistics.latencyUnitsSaved -= minPrevLatency;
171
172 return success();
173}
174
175//===----------------------------------------------------------------------===//
176// LatencyRetiming pass
177//===----------------------------------------------------------------------===//
178
179namespace {
180struct LatencyRetimingPass
181 : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
182 void runOnOperation() override;
183};
184} // namespace
185
186void LatencyRetimingPass::runOnOperation() {
187 SymbolCache cache;
188 cache.addDefinitions(getOperation());
189
190 LatencyRetimingStatistics statistics;
191
192 RewritePatternSet patterns(&getContext());
193 patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
194
195 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
196 return signalPassFailure();
197
198 numOpsRemoved = statistics.numOpsRemoved;
199 latencyUnitsSaved = statistics.latencyUnitsSaved;
200}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.