13#include "mlir/IR/PatternMatch.h"
14#include "mlir/Pass/Pass.h"
15#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17#define DEBUG_TYPE "arc-latency-retiming"
21#define GEN_PASS_DEF_LATENCYRETIMING
22#include "circt/Dialect/Arc/ArcPasses.h.inc"
namespace {
/// Counters accumulated by LatencyRetimingPattern while the rewrite driver
/// runs. LatencyRetimingPass copies them into its pass statistics afterwards.
struct LatencyRetimingStatistics {
  /// Number of state ops erased outright (replaced by their inputs) after
  /// their latency dropped to zero.
  unsigned numOpsRemoved = 0;
  /// Net number of latency units eliminated. Each retiming step adds the moved
  /// latency once per predecessor and subtracts it once for the successor.
  unsigned latencyUnitsSaved = 0;
};
} // namespace
41struct LatencyRetimingPattern
42 : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
44 LatencyRetimingStatistics &statistics)
45 : OpInterfaceRewritePattern<ClockedOpInterface>(
context),
46 symCache(symCache), statistics(statistics) {}
48 LogicalResult matchAndRewrite(ClockedOpInterface op,
49 PatternRewriter &rewriter)
const final;
53 LatencyRetimingStatistics &statistics;
59LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
60 PatternRewriter &rewriter)
const {
61 uint32_t minPrevLatency = UINT_MAX;
62 SetVector<ClockedOpInterface> predecessors;
65 auto hasEnableOrReset = [](Operation *op) ->
bool {
66 if (
auto stateOp = dyn_cast<StateOp>(op))
67 if (stateOp.getReset() || stateOp.getEnable())
74 if (!isa<CallOp, StateOp>(op.getOperation()))
79 if (hasEnableOrReset(op))
82 assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
83 "state and call operations call arcs and thus have to implement the "
85 auto callOp = cast<mlir::CallOpInterface>(op.getOperation());
87 for (
auto input : callOp.getArgOperands()) {
88 auto predOp = input.getDefiningOp<ClockedOpInterface>();
91 if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
95 if (predOp->hasAttr(
"name") || predOp->hasAttr(
"names"))
102 if (predOp.getClock() && op.getClock() &&
103 predOp.getClock() != op.getClock())
106 if (predOp->getParentRegion() != op->getParentRegion())
109 if (hasEnableOrReset(predOp))
115 if (llvm::any_of(predOp->getUsers(),
116 [&](
auto *user) { return user != op; }))
122 if (!clock && predOp.getClock())
123 clock = predOp.getClock();
125 predecessors.insert(predOp);
126 minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
129 if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
132 auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
133 assert((isa<StateOp, CallOp>(op)) &&
"must be a state or call op");
135 if (
auto stateOp = dyn_cast<StateOp>(op)) {
136 if (newLatency == 0) {
137 if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
139 rewriter.replaceOp(stateOp, stateOp.getInputs());
140 ++statistics.numOpsRemoved;
143 rewriter.setInsertionPoint(op);
144 rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
145 stateOp.getArcAttr(),
146 stateOp.getInputs());
150 rewriter.modifyOpInPlace(op, [&]() {
151 stateOp.setLatency(newLatency);
152 if (!stateOp.getClock())
153 stateOp.getClockMutable().assign(clock);
158 if (
auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
159 rewriter.replaceOpWithNewOp<StateOp>(
160 op, callOp.getArcAttr(), callOp->getResultTypes(), clock, Value{},
161 newLatency, callOp.getInputs());
164 setLatency(op, op.getLatency() + minPrevLatency, clock);
165 for (
auto prevOp : predecessors) {
166 statistics.latencyUnitsSaved += minPrevLatency;
167 auto newLatency = prevOp.getLatency() - minPrevLatency;
168 setLatency(prevOp, newLatency, {});
170 statistics.latencyUnitsSaved -= minPrevLatency;
180struct LatencyRetimingPass
181 : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
182 void runOnOperation()
override;
186void LatencyRetimingPass::runOnOperation() {
190 LatencyRetimingStatistics statistics;
192 RewritePatternSet
patterns(&getContext());
193 patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
195 if (failed(applyPatternsGreedily(getOperation(), std::move(
patterns))))
196 return signalPassFailure();
198 numOpsRemoved = statistics.numOpsRemoved;
199 latencyUnitsSaved = statistics.latencyUnitsSaved;
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) and mlir::Operation's.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.