CIRCT  20.0.0git
LatencyRetiming.cpp
Go to the documentation of this file.
1 //===- LatencyRetiming.cpp - Implement LatencyRetiming Pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Support/LLVM.h"
12 #include "circt/Support/SymCache.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 #define DEBUG_TYPE "arc-latency-retiming"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_LATENCYRETIMING
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 
29 //===----------------------------------------------------------------------===//
30 // Patterns
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
/// Counters accumulated across all applications of the retiming pattern;
/// copied into the pass statistics once the greedy driver finishes.
struct LatencyRetimingStatistics {
  // Number of operations erased (passthrough states folded away).
  unsigned numOpsRemoved = 0;
  // Net number of latency units removed from the design.
  unsigned latencyUnitsSaved = 0;
};
38 
39 /// Absorb the latencies from predecessor states to collapse shift registers and
40 /// reduce the overall amount of latency units in the design.
41 struct LatencyRetimingPattern
42  : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
43  LatencyRetimingPattern(MLIRContext *context, SymbolCache &symCache,
44  LatencyRetimingStatistics &statistics)
45  : OpInterfaceRewritePattern<ClockedOpInterface>(context),
46  symCache(symCache), statistics(statistics) {}
47 
48  LogicalResult matchAndRewrite(ClockedOpInterface op,
49  PatternRewriter &rewriter) const final;
50 
51 private:
52  SymbolCache &symCache;
53  LatencyRetimingStatistics &statistics;
54 };
55 
56 } // namespace
57 
/// Fold latency from predecessor state/call ops into `op`: the smallest
/// predecessor latency is subtracted from every predecessor and added to
/// `op`, shrinking the total number of latency units in the design.
LogicalResult
LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
                                        PatternRewriter &rewriter) const {
  uint32_t minPrevLatency = UINT_MAX;
  SetVector<ClockedOpInterface> predecessors;
  Value clock;

  // Enables/resets would have to match across all involved ops; reject them.
  auto hasEnableOrReset = [](Operation *op) -> bool {
    if (auto stateOp = dyn_cast<StateOp>(op))
      if (stateOp.getReset() || stateOp.getEnable())
        return true;
    return false;
  };

  // Restrict this pattern to call and state ops only. In the future we could
  // also add support for memory write operations.
  if (!isa<CallOp, StateOp>(op.getOperation()))
    return failure();

  // In principle we could support enables and resets but would have to check
  // that all involved states have the same.
  if (hasEnableOrReset(op))
    return failure();

  assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
         "state and call operations call arcs and thus have to implement the "
         "CallOpInterface");
  auto callOp = cast<mlir::CallOpInterface>(op.getOperation());

  // Validate every operand's defining op; any violation aborts the rewrite.
  for (auto input : callOp.getArgOperands()) {
    auto predOp = input.getDefiningOp<ClockedOpInterface>();

    // Only support call and state ops for the predecessors as well.
    if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
      return failure();

    // Conditions for both StateOp and CallOp: named ops are observable and
    // must not be restructured.
    if (predOp->hasAttr("name") || predOp->hasAttr("names"))
      return failure();

    // Check for a use-def cycle since we can be in a graph region.
    if (predOp == op)
      return failure();

    // Mixed clocks cannot be retimed across.
    if (predOp.getClock() && op.getClock() &&
        predOp.getClock() != op.getClock())
      return failure();

    if (predOp->getParentRegion() != op->getParentRegion())
      return failure();

    if (hasEnableOrReset(predOp))
      return failure();

    // Check that the predecessor state does not have another user since then
    // we cannot change its latency attribute without also changing it for the
    // other users. This is not supported yet and thus we just fail.
    if (llvm::any_of(predOp->getUsers(),
                     [&](auto *user) { return user != op; }))
      return failure();

    // We check that all clocks are the same if present. Here we remember that
    // clock. If none of the involved operations have a clock, they must have
    // latency 0 and thus `minPrevLatency = 0` leading to early failure below.
    if (!clock) {
      if (predOp.getClock())
        clock = predOp.getClock();
      if (auto clockDomain = predOp->getParentOfType<ClockDomainOp>())
        clock = clockDomain.getClock();
    }

    predecessors.insert(predOp);
    minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
  }

  // Nothing to retime if some predecessor already has latency 0, or there were
  // no qualifying operands at all (minPrevLatency stayed at UINT_MAX).
  if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
    return failure();

  // Rewrite `op` to carry `newLatency`, converting between StateOp and CallOp
  // as needed (latency 0 => combinational call, latency > 0 => state).
  auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
    assert((isa<StateOp, CallOp>(op)) && "must be a state or call op");
    bool isInClockDomain = op->getParentOfType<ClockDomainOp>();

    if (auto stateOp = dyn_cast<StateOp>(op)) {
      if (newLatency == 0) {
        // A zero-latency state calling a passthrough arc is a no-op; forward
        // its inputs directly to its users.
        if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
                .isPassthrough()) {
          rewriter.replaceOp(stateOp, stateOp.getInputs());
          ++statistics.numOpsRemoved;
          return;
        }
        // Latency dropped to zero: the state becomes a combinational call.
        rewriter.setInsertionPoint(op);
        rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
                                            stateOp.getArcAttr(),
                                            stateOp.getInputs());
        return;
      }

      rewriter.modifyOpInPlace(op, [&]() {
        stateOp.setLatency(newLatency);
        // Attach the common clock if the state had none and is not implicitly
        // clocked by an enclosing clock domain.
        if (!stateOp.getClock() && !isInClockDomain)
          stateOp.getClockMutable().assign(clock);
      });
      return;
    }

    // A call that gains latency has to become a state op.
    if (auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
      rewriter.replaceOpWithNewOp<StateOp>(
          op, callOp.getArcAttr(), callOp->getResultTypes(),
          isInClockDomain ? Value{} : clock, Value{}, newLatency,
          callOp.getInputs());
  };

  // Add the absorbed latency units to `op` and subtract them from every
  // predecessor.
  setLatency(op, op.getLatency() + minPrevLatency, clock);
  for (auto prevOp : predecessors) {
    statistics.latencyUnitsSaved += minPrevLatency;
    auto newLatency = prevOp.getLatency() - minPrevLatency;
    setLatency(prevOp, newLatency, {});
  }
  // The units added to `op` itself were moved, not saved; compensate once.
  statistics.latencyUnitsSaved -= minPrevLatency;

  return success();
}
180 
181 //===----------------------------------------------------------------------===//
182 // LatencyRetiming pass
183 //===----------------------------------------------------------------------===//
184 
185 namespace {
/// Pass wrapper that greedily applies the latency retiming pattern to the
/// operation it runs on.
struct LatencyRetimingPass
    : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
  void runOnOperation() override;
};
190 } // namespace
191 
192 void LatencyRetimingPass::runOnOperation() {
193  SymbolCache cache;
194  cache.addDefinitions(getOperation());
195 
196  LatencyRetimingStatistics statistics;
197 
198  RewritePatternSet patterns(&getContext());
199  patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
200 
201  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
202  return signalPassFailure();
203 
204  numOpsRemoved = statistics.numOpsRemoved;
205  latencyUnitsSaved = statistics.latencyUnitsSaved;
206 }
207 
208 std::unique_ptr<Pass> arc::createLatencyRetimingPass() {
209  return std::make_unique<LatencyRetimingPass>();
210 }
assert(baseType &&"element must be base type")
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
std::unique_ptr< mlir::Pass > createLatencyRetimingPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21