CIRCT  19.0.0git
LatencyRetiming.cpp
Go to the documentation of this file.
1 //===- LatencyRetiming.cpp - Implement LatencyRetiming Pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Support/LLVM.h"
12 #include "circt/Support/SymCache.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 #define DEBUG_TYPE "arc-latency-retiming"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_LATENCYRETIMING
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 
29 //===----------------------------------------------------------------------===//
30 // Patterns
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
/// Counters accumulated by the retiming pattern while rewriting and published
/// as pass statistics once the greedy driver finishes.
struct LatencyRetimingStatistics {
  // Number of operations erased outright (passthrough states folded away).
  unsigned numOpsRemoved = 0;
  // Net number of latency units removed from the design.
  unsigned latencyUnitsSaved = 0;
};
38 
39 /// Absorb the latencies from predecessor states to collapse shift registers and
40 /// reduce the overall amount of latency units in the design.
/// Absorb the latencies from predecessor states to collapse shift registers and
/// reduce the overall amount of latency units in the design.
struct LatencyRetimingPattern
    : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
  /// \param symCache Pre-populated symbol cache used to resolve arc symbol
  ///        references to their `DefineOp`s.
  /// \param statistics Shared counters updated on every successful rewrite.
  LatencyRetimingPattern(MLIRContext *context, SymbolCache &symCache,
                         LatencyRetimingStatistics &statistics)
      : OpInterfaceRewritePattern<ClockedOpInterface>(context),
        symCache(symCache), statistics(statistics) {}

  LogicalResult matchAndRewrite(ClockedOpInterface op,
                                PatternRewriter &rewriter) const final;

private:
  // Both references are owned by the pass driving this pattern and must
  // outlive it.
  SymbolCache &symCache;
  LatencyRetimingStatistics &statistics;
};
55 
56 } // namespace
57 
LogicalResult
LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
                                        PatternRewriter &rewriter) const {
  // Smallest latency among all predecessors: the amount of latency we can pull
  // through into `op`. UINT_MAX doubles as "no predecessor visited yet".
  uint32_t minPrevLatency = UINT_MAX;
  // Unique defining ops of `op`'s operands (SetVector keeps deterministic
  // iteration order for the rewrite loop below).
  SetVector<ClockedOpInterface> predecessors;
  // The single clock shared by the involved ops (if any); remembered so it can
  // be attached to `op` when it gains a nonzero latency.
  Value clock;

  // States with an enable or reset are rejected: retiming would require all
  // involved states to share the exact same enable/reset condition.
  auto hasEnableOrReset = [](Operation *op) -> bool {
    if (auto stateOp = dyn_cast<StateOp>(op))
      if (stateOp.getReset() || stateOp.getEnable())
        return true;
    return false;
  };

  // Restrict this pattern to call and state ops only. In the future we could
  // also add support for memory write operations.
  if (!isa<CallOp, StateOp>(op.getOperation()))
    return failure();

  // In principle we could support enables and resets but would have to check
  // that all involved states have the same.
  if (hasEnableOrReset(op))
    return failure();

  assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
         "state and call operations call arcs and thus have to implement the "
         "CallOpInterface");
  auto callOp = cast<mlir::CallOpInterface>(op.getOperation());

  // Inspect every operand's defining op; bail out on any condition under which
  // moving latency from the predecessors into `op` would be unsound.
  for (auto input : callOp.getArgOperands()) {
    auto predOp = input.getDefiningOp<ClockedOpInterface>();

    // Only support call and state ops for the predecessors as well.
    if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
      return failure();

    // Conditions for both StateOp and CallOp. Named ops are skipped —
    // presumably because names make the value user-observable, so its timing
    // must not change (TODO confirm).
    if (predOp->hasAttr("name") || predOp->hasAttr("names"))
      return failure();

    // Check for a use-def cycle since we can be in a graph region.
    if (predOp == op)
      return failure();

    // If both ops carry an explicit clock, the clocks must match.
    if (predOp.getClock() && op.getClock() &&
        predOp.getClock() != op.getClock())
      return failure();

    // Retiming across region boundaries (e.g. into/out of clock domains) is
    // not handled.
    if (predOp->getParentRegion() != op->getParentRegion())
      return failure();

    if (hasEnableOrReset(predOp))
      return failure();

    // Check that the predecessor state does not have another user since then
    // we cannot change its latency attribute without also changing it for the
    // other users. This is not supported yet and thus we just fail.
    if (llvm::any_of(predOp->getUsers(),
                     [&](auto *user) { return user != op; }))
      return failure();

    // We check that all clocks are the same if present. Here we remember that
    // clock. If none of the involved operations have a clock, they must have
    // latency 0 and thus `minPrevLatency = 0` leading to early failure below.
    if (!clock) {
      if (predOp.getClock())
        clock = predOp.getClock();
      if (auto clockDomain = predOp->getParentOfType<ClockDomainOp>())
        clock = clockDomain.getClock();
    }

    predecessors.insert(predOp);
    minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
  }

  // Nothing to do if some predecessor has no latency to give away
  // (minPrevLatency == 0) or `op` had no operands at all (UINT_MAX untouched).
  if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
    return failure();

  // Apply `newLatency` to a state or call op, normalizing the op kind:
  //  * a state dropping to latency 0 becomes a plain call, or is erased
  //    entirely if its arc is a passthrough;
  //  * a call gaining latency > 0 becomes a state.
  auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
    bool validOp = isa<StateOp, CallOp>(op);
    assert(validOp && "must be a state or call op");
    bool isInClockDomain = op->getParentOfType<ClockDomainOp>();

    if (auto stateOp = dyn_cast<StateOp>(op)) {
      if (newLatency == 0) {
        // Zero-latency state of a passthrough arc: just forward its inputs.
        if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
                .isPassthrough()) {
          rewriter.replaceOp(stateOp, stateOp.getInputs());
          ++statistics.numOpsRemoved;
          return;
        }
        // Otherwise demote the now purely combinational state to a call.
        rewriter.setInsertionPoint(op);
        rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
                                            stateOp.getArcAttr(),
                                            stateOp.getInputs());
        return;
      }

      rewriter.modifyOpInPlace(op, [&]() {
        stateOp.setLatency(newLatency);
        // The state may not have had a clock before (e.g. it was clockless);
        // attach the remembered clock unless a clock domain provides it.
        if (!stateOp.getClock() && !isInClockDomain)
          stateOp.getClockMutable().assign(clock);
      });
      return;
    }

    // A call absorbing latency is promoted to a state, clocked explicitly
    // unless the surrounding clock domain supplies the clock.
    if (auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
      rewriter.replaceOpWithNewOp<StateOp>(
          op, callOp.getArcAttr(), callOp->getResultTypes(),
          isInClockDomain ? Value{} : clock, Value{}, newLatency,
          callOp.getInputs());
  };

  // Move `minPrevLatency` units from every predecessor into `op`. Each
  // predecessor saves that many units but `op` gains them back once, hence the
  // final subtraction: the net saving is (#predecessors - 1) * minPrevLatency.
  setLatency(op, op.getLatency() + minPrevLatency, clock);
  for (auto prevOp : predecessors) {
    statistics.latencyUnitsSaved += minPrevLatency;
    auto newLatency = prevOp.getLatency() - minPrevLatency;
    setLatency(prevOp, newLatency, {});
  }
  statistics.latencyUnitsSaved -= minPrevLatency;

  return success();
}
181 
182 //===----------------------------------------------------------------------===//
183 // LatencyRetiming pass
184 //===----------------------------------------------------------------------===//
185 
186 namespace {
/// Pass driver: greedily applies LatencyRetimingPattern to the operation this
/// pass is scheduled on.
struct LatencyRetimingPass
    : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
  void runOnOperation() override;
};
191 } // namespace
192 
193 void LatencyRetimingPass::runOnOperation() {
194  SymbolCache cache;
195  cache.addDefinitions(getOperation());
196 
197  LatencyRetimingStatistics statistics;
198 
199  RewritePatternSet patterns(&getContext());
200  patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
201 
202  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
203  return signalPassFailure();
204 
205  numOpsRemoved = statistics.numOpsRemoved;
206  latencyUnitsSaved = statistics.latencyUnitsSaved;
207 }
208 
209 std::unique_ptr<Pass> arc::createLatencyRetimingPass() {
210  return std::make_unique<LatencyRetimingPass>();
211 }
assert(baseType && "element must be base type")
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
std::unique_ptr< mlir::Pass > createLatencyRetimingPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21