CIRCT  18.0.0git
LatencyRetiming.cpp
Go to the documentation of this file.
1 //===- LatencyRetiming.cpp - Implement LatencyRetiming Pass ---------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Support/LLVM.h"
12 #include "circt/Support/SymCache.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
16 
17 #define DEBUG_TYPE "arc-latency-retiming"
18 
19 namespace circt {
20 namespace arc {
21 #define GEN_PASS_DEF_LATENCYRETIMING
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
23 } // namespace arc
24 } // namespace circt
25 
26 using namespace circt;
27 using namespace arc;
28 
29 //===----------------------------------------------------------------------===//
30 // Patterns
31 //===----------------------------------------------------------------------===//
32 
33 namespace {
/// Counters shared between the retiming pattern and the pass so the pattern
/// can report what it did.
struct LatencyRetimingStatistics {
  /// Number of state operations the pattern replaced by their inputs.
  unsigned numOpsRemoved{0};
  /// Running total of latency units the pattern eliminated.
  unsigned latencyUnitsSaved{0};
};
38 
39 /// Absorb the latencies from predecessor states to collapse shift registers and
40 /// reduce the overall amount of latency units in the design.
41 struct LatencyRetimingPattern : OpRewritePattern<StateOp> {
42  LatencyRetimingPattern(MLIRContext *context, SymbolCache &symCache,
43  LatencyRetimingStatistics &statistics)
44  : OpRewritePattern<StateOp>(context), symCache(symCache),
45  statistics(statistics) {}
46 
47  LogicalResult matchAndRewrite(StateOp op,
48  PatternRewriter &rewriter) const final;
49 
50 private:
51  SymbolCache &symCache;
52  LatencyRetimingStatistics &statistics;
53 };
54 
55 } // namespace
56 
57 LogicalResult
58 LatencyRetimingPattern::matchAndRewrite(StateOp op,
59  PatternRewriter &rewriter) const {
60  unsigned minPrevLatency = UINT_MAX;
61  SetVector<StateOp> predecessors;
62 
63  if (op.getReset() || op.getEnable())
64  return failure();
65 
66  for (auto input : op.getInputs()) {
67  auto predState = input.getDefiningOp<StateOp>();
68  if (!predState)
69  return failure();
70 
71  if (predState->hasAttr("name") || predState->hasAttr("names"))
72  return failure();
73 
74  if (predState == op)
75  return failure();
76 
77  if (predState.getLatency() != 0 && op.getLatency() != 0 &&
78  predState.getClock() != op.getClock())
79  return failure();
80 
81  if (predState.getEnable() || predState.getReset())
82  return failure();
83 
84  if (llvm::any_of(predState->getUsers(),
85  [&](auto *user) { return user != op; }))
86  return failure();
87 
88  predecessors.insert(predState);
89  minPrevLatency = std::min(minPrevLatency, predState.getLatency());
90  }
91 
92  if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
93  return failure();
94 
95  op.setLatency(op.getLatency() + minPrevLatency);
96  for (auto prevStateOp : predecessors) {
97  if (!op.getClock() && !op->getParentOfType<ClockDomainOp>())
98  op.getClockMutable().assign(prevStateOp.getClock());
99 
100  statistics.latencyUnitsSaved += minPrevLatency;
101  auto newLatency = prevStateOp.getLatency() - minPrevLatency;
102  prevStateOp.setLatency(newLatency);
103 
104  if (newLatency > 0)
105  continue;
106 
107  prevStateOp.getClockMutable().clear();
108  if (cast<DefineOp>(symCache.getDefinition(prevStateOp.getArcAttr()))
109  .isPassthrough()) {
110  rewriter.replaceOp(prevStateOp, prevStateOp.getInputs());
111  ++statistics.numOpsRemoved;
112  }
113  }
114  statistics.latencyUnitsSaved -= minPrevLatency;
115 
116  return success();
117 }
118 
119 //===----------------------------------------------------------------------===//
120 // LatencyRetiming pass
121 //===----------------------------------------------------------------------===//
122 
123 namespace {
124 struct LatencyRetimingPass
125  : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
126  void runOnOperation() override;
127 };
128 } // namespace
129 
130 void LatencyRetimingPass::runOnOperation() {
131  SymbolCache cache;
132  cache.addDefinitions(getOperation());
133 
134  LatencyRetimingStatistics statistics;
135 
136  RewritePatternSet patterns(&getContext());
137  patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
138 
139  if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
140  return signalPassFailure();
141 
142  numOpsRemoved = statistics.numOpsRemoved;
143  latencyUnitsSaved = statistics.latencyUnitsSaved;
144 }
145 
146 std::unique_ptr<Pass> arc::createLatencyRetimingPass() {
147  return std::make_unique<LatencyRetimingPass>();
148 }
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations from names (StringAttr's) to mlir::Operation's.
Definition: SymCache.h:85
std::unique_ptr< mlir::Pass > createLatencyRetimingPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21