13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #define DEBUG_TYPE "arc-latency-retiming"
21 #define GEN_PASS_DEF_LATENCYRETIMING
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 using namespace circt;
// Counters accumulated by LatencyRetimingPattern while it rewrites. They are
// copied into the pass statistics at the end of
// LatencyRetimingPass::runOnOperation.
34 struct LatencyRetimingStatistics {
// Incremented when a StateOp whose latency drops to 0 is replaced directly by
// its inputs (rather than being turned into a CallOp).
35 unsigned numOpsRemoved = 0;
// Net latency units eliminated. Each successful rewrite adds minPrevLatency
// once per predecessor, then subtracts one minPrevLatency for the latency that
// was moved onto the consuming op.
36 unsigned latencyUnitsSaved = 0;
// Rewrite pattern over ClockedOpInterface ops. It moves latency from an op's
// StateOp/CallOp predecessors onto the op itself, reducing the total number of
// latency units. The legality checks are in matchAndRewrite.
41 struct LatencyRetimingPattern
42 : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
// `statistics` is stored by reference and updated during rewriting, so the
// caller's object must outlive the pattern.
// NOTE(review): symCache is presumably also held by reference; its member
// declaration is not visible here, so confirm.
43 LatencyRetimingPattern(MLIRContext *context,
SymbolCache &symCache,
44 LatencyRetimingStatistics &statistics)
45 : OpInterfaceRewritePattern<ClockedOpInterface>(context),
46 symCache(symCache), statistics(statistics) {}
// Entry point invoked by the rewrite driver for each ClockedOpInterface op.
48 LogicalResult matchAndRewrite(ClockedOpInterface op,
49 PatternRewriter &rewriter)
const final;
// Reference member: lets the const matchAndRewrite bump the caller-owned
// counters.
53 LatencyRetimingStatistics &statistics;
// Shifts latency from the predecessors of `op` onto `op`.
//
// The rewrite is only attempted when all of the following hold:
// - `op` is a StateOp or CallOp without enable/reset.
// - Every call argument of `op` is produced by a StateOp or CallOp that:
//   - carries no "name"/"names" attribute,
//   - has no clock conflicting with `op`'s clock,
//   - lives in the same region as `op`,
//   - has no enable/reset,
//   - has `op` as its only user.
//
// The smallest predecessor latency (minPrevLatency) is added to `op` and
// subtracted from every predecessor.
// NOTE(review): the statements following each guard are not visible here;
// presumably they `return failure()`, so confirm.
59 LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
60 PatternRewriter &rewriter)
const {
// UINT_MAX is a sentinel meaning "no predecessor recorded yet".
61 uint32_t minPrevLatency = UINT_MAX;
62 SetVector<ClockedOpInterface> predecessors;
// Reports whether `op` is a StateOp with a reset or an enable operand.
// (Presumably false for any other op.)
65 auto hasEnableOrReset = [](Operation *op) ->
bool {
66 if (
auto stateOp = dyn_cast<StateOp>(op))
67 if (stateOp.getReset() || stateOp.getEnable())
// Only StateOp and CallOp consumers are handled.
74 if (!isa<CallOp, StateOp>(op.getOperation()))
// Enables/resets are not supported on the consumer.
79 if (hasEnableOrReset(op))
// Both supported op kinds call an arc, so their operands are reachable through
// CallOpInterface::getArgOperands().
82 assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
83 "state and call operations call arcs and thus have to implement the "
85 auto callOp = cast<mlir::CallOpInterface>(op.getOperation());
// Every argument must come from a qualifying predecessor.
87 for (
auto input : callOp.getArgOperands()) {
88 auto predOp = input.getDefiningOp<ClockedOpInterface>();
// Predecessor must exist and be a StateOp or CallOp.
91 if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
// Named predecessors are not touched.
95 if (predOp->hasAttr(
"name") || predOp->hasAttr(
"names"))
// Reject only when both ops have a clock and the clocks differ.
102 if (predOp.getClock() && op.getClock() &&
103 predOp.getClock() != op.getClock())
106 if (predOp->getParentRegion() != op->getParentRegion())
109 if (hasEnableOrReset(predOp))
// `op` must be the predecessor's only user, otherwise other users would
// observe the changed latency.
115 if (llvm::any_of(predOp->getUsers(),
116 [&](
auto *user) { return user != op; }))
// Record a clock to give `op`: the predecessor's own clock if it has one,
// overridden by the clock of an enclosing ClockDomainOp.
// NOTE(review): the declaration of `clock` is not visible here.
123 if (predOp.getClock())
124 clock = predOp.getClock();
125 if (
auto clockDomain = predOp->getParentOfType<ClockDomainOp>())
126 clock = clockDomain.getClock();
129 predecessors.insert(predOp);
130 minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
// Nothing to move in two cases:
// - some predecessor already has latency 0, or
// - no predecessor was recorded (the sentinel is still UINT_MAX).
133 if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
// Sets the latency of a StateOp/CallOp to `newLatency`, converting the op
// between the two kinds when needed. `clock` is only used when a clock has to
// be attached to an op outside a ClockDomainOp.
136 auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
137 assert((isa<StateOp, CallOp>(op)) &&
"must be a state or call op");
138 bool isInClockDomain = op->getParentOfType<ClockDomainOp>();
140 if (
auto stateOp = dyn_cast<StateOp>(op)) {
141 if (newLatency == 0) {
// A state with no latency left can be replaced by its inputs, depending on a
// property of the called DefineOp.
// NOTE(review): the DefineOp predicate tested here is not visible; presumably
// it checks that the arc is a passthrough, so confirm.
142 if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
144 rewriter.replaceOp(stateOp, stateOp.getInputs());
145 ++statistics.numOpsRemoved;
// Other 0-latency case: rebuild the state as a CallOp with the same arc,
// inputs and result types.
148 rewriter.setInsertionPoint(op);
149 rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
150 stateOp.getArcAttr(),
151 stateOp.getInputs());
// Remaining StateOp case (presumably newLatency != 0): update in place.
// Attach `clock` only if the state has no clock and is not inside a
// ClockDomainOp.
155 rewriter.modifyOpInPlace(op, [&]() {
156 stateOp.setLatency(newLatency);
157 if (!stateOp.getClock() && !isInClockDomain)
158 stateOp.getClockMutable().assign(clock);
// A CallOp that gains latency becomes a StateOp with that latency. It is
// clocked by `clock` unless it sits in a ClockDomainOp, in which case a null
// clock is passed. A CallOp with newLatency == 0 does not satisfy this
// condition.
163 if (
auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
164 rewriter.replaceOpWithNewOp<StateOp>(
165 op, callOp.getArcAttr(), callOp->getResultTypes(),
166 isInClockDomain ? Value{} : clock, Value{}, newLatency,
// Move minPrevLatency onto the consumer, using the clock collected above.
170 setLatency(op, op.getLatency() + minPrevLatency, clock);
// Take minPrevLatency off every predecessor; no new clock is supplied for
// them.
171 for (
auto prevOp : predecessors) {
172 statistics.latencyUnitsSaved += minPrevLatency;
173 auto newLatency = prevOp.getLatency() - minPrevLatency;
174 setLatency(prevOp, newLatency, {});
// Account for the minPrevLatency added to `op` above. The net saving is
// (number of predecessors - 1) * minPrevLatency.
176 statistics.latencyUnitsSaved -= minPrevLatency;
// Pass wrapper around LatencyRetimingPattern. LatencyRetimingBase is the
// generated base class made available via GEN_PASS_DEF_LATENCYRETIMING and
// ArcPasses.h.inc.
186 struct LatencyRetimingPass
187 : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
188 void runOnOperation()
override;
// Applies LatencyRetimingPattern (together with folding) greedily to the
// operation this pass runs on. The pass fails if the driver reports failure.
// Otherwise the accumulated counters are copied into the pass statistics.
192 void LatencyRetimingPass::runOnOperation() {
// The pattern keeps a reference to this object, so it must stay alive until
// the driver below returns.
196 LatencyRetimingStatistics statistics;
// NOTE(review): `cache` is declared on lines not visible here; presumably it
// is a SymbolCache populated via addDefinitions(getOperation()), so confirm.
198 RewritePatternSet
patterns(&getContext());
199 patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
201 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns))))
202 return signalPassFailure();
// Publish the results. These are presumably the Statistic members declared by
// the generated base class.
204 numOpsRemoved = statistics.numOpsRemoved;
205 latencyUnitsSaved = statistics.latencyUnitsSaved;
// Factory body: hands a newly allocated LatencyRetimingPass to the caller.
// NOTE(review): the enclosing signature is not visible here; presumably it is
// std::unique_ptr<mlir::Pass> createLatencyRetimingPass().
209 return std::make_unique<LatencyRetimingPass>();
assert(baseType &&"element must be base type")
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations from names (StringAttrs) to mlir::Operation objects.
std::unique_ptr< mlir::Pass > createLatencyRetimingPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.