13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
17 #define DEBUG_TYPE "arc-latency-retiming"
21 #define GEN_PASS_DEF_LATENCYRETIMING
22 #include "circt/Dialect/Arc/ArcPasses.h.inc"
26 using namespace circt;
/// Counters shared between the retiming pattern and the pass.
/// Both start at zero. The pattern updates them while rewriting, and
/// `LatencyRetimingPass::runOnOperation` copies them out after the rewrite
/// driver has finished.
34 struct LatencyRetimingStatistics {
// Incremented each time a state op is replaced directly by its inputs.
35 unsigned numOpsRemoved = 0;
// Adjusted by the minimum predecessor latency on every successful rewrite.
36 unsigned latencyUnitsSaved = 0;
/// Rewrite pattern anchored on operations implementing `ClockedOpInterface`.
/// It carries a symbol cache (used in `matchAndRewrite` to look up the arc
/// definition referenced by a state op) and a reference to the shared
/// statistics counters.
41 struct LatencyRetimingPattern
42 : mlir::OpInterfaceRewritePattern<ClockedOpInterface> {
// Forwards `context` to the base pattern and stores `symCache` and
// `statistics`. `statistics` is held by reference (see the member below), so
// it must outlive the pattern. NOTE(review): the `symCache` member declaration
// is not visible in this listing; presumably it is a reference too - confirm.
43 LatencyRetimingPattern(MLIRContext *context,
SymbolCache &symCache,
44 LatencyRetimingStatistics &statistics)
45 : OpInterfaceRewritePattern<ClockedOpInterface>(context),
46 symCache(symCache), statistics(statistics) {}
// Entry point invoked by the rewrite driver; defined out of line.
48 LogicalResult matchAndRewrite(ClockedOpInterface op,
49 PatternRewriter &rewriter)
const final;
// Shared counters, updated from the const `matchAndRewrite` through this
// reference.
53 LatencyRetimingStatistics &statistics;
// Retiming rewrite. From the visible code: it inspects the operations that
// define the call arguments of `op`, tracks the smallest latency among them,
// then adds that amount to the latency of `op` and subtracts it from every
// recorded predecessor.
//
// NOTE(review): this listing elides several lines of the function, including
// the statements guarded by the `if` checks below, the declaration of `clock`,
// and the final return. Comments on the guards describe only the condition
// being tested; presumably each guard makes the pattern fail to match -
// confirm against the full file.
59 LatencyRetimingPattern::matchAndRewrite(ClockedOpInterface op,
60 PatternRewriter &rewriter)
const {
// Running minimum over predecessor latencies; UINT_MAX is the "nothing
// recorded yet" sentinel tested after the loop.
61 uint32_t minPrevLatency = UINT_MAX;
// Predecessors that will have their latency reduced. SetVector keeps each op
// once even if it feeds several arguments.
62 SetVector<ClockedOpInterface> predecessors;
// Helper: checks for a StateOp whose getReset() or getEnable() is set.
65 auto hasEnableOrReset = [](Operation *op) ->
bool {
66 if (
auto stateOp = dyn_cast<StateOp>(op))
67 if (stateOp.getReset() || stateOp.getEnable())
// Guard: `op` is neither a CallOp nor a StateOp.
74 if (!isa<CallOp, StateOp>(op.getOperation()))
// Guard: `op` itself has an enable or a reset.
79 if (hasEnableOrReset(op))
// Both accepted op kinds are expected to implement CallOpInterface; this is
// asserted before the unconditional cast.
82 assert(isa<mlir::CallOpInterface>(op.getOperation()) &&
83 "state and call operations call arcs and thus have to implement the "
85 auto callOp = cast<mlir::CallOpInterface>(op.getOperation());
// Examine the operation defining each call argument.
87 for (
auto input : callOp.getArgOperands()) {
// Null when the argument is not the result of an op implementing
// ClockedOpInterface.
88 auto predOp = input.getDefiningOp<ClockedOpInterface>();
// Guard: no clocked defining op, or it is not a CallOp/StateOp.
91 if (!predOp || !isa<CallOp, StateOp>(predOp.getOperation()))
// Guard: the predecessor carries a "name" or "names" attribute.
95 if (predOp->hasAttr(
"name") || predOp->hasAttr(
"names"))
// Guard: both ops have an explicit clock and the two clocks differ.
102 if (predOp.getClock() && op.getClock() &&
103 predOp.getClock() != op.getClock())
// Guard: the predecessor is in a different region than `op`.
106 if (predOp->getParentRegion() != op->getParentRegion())
// Guard: the predecessor has an enable or a reset.
109 if (hasEnableOrReset(predOp))
// Guard: the predecessor has at least one user other than `op`.
115 if (llvm::any_of(predOp->getUsers(),
116 [&](
auto *user) { return user != op; }))
// Remember a clock for the later rewrite: the predecessor's own clock if it
// has one, overridden by the clock of an enclosing ClockDomainOp if present
// (the second `if` is not an `else`). `clock` is declared on a line not shown
// in this listing.
123 if (predOp.getClock())
124 clock = predOp.getClock();
125 if (
auto clockDomain = predOp->getParentOfType<ClockDomainOp>())
126 clock = clockDomain.getClock();
// Record the predecessor and fold its latency into the running minimum.
129 predecessors.insert(predOp);
130 minPrevLatency = std::min(minPrevLatency, predOp.getLatency());
// Guard: the minimum is zero, or still at its UINT_MAX initial value.
133 if (minPrevLatency == 0 || minPrevLatency == UINT_MAX)
// Helper that gives `op` the latency `newLatency`, switching between the
// StateOp and CallOp forms as the visible branches show. `clock` is only
// consulted where a clock operand may have to be attached.
136 auto setLatency = [&](Operation *op, uint64_t newLatency, Value clock) {
137 bool validOp = isa<StateOp, CallOp>(op);
138 assert(validOp &&
"must be a state or call op");
// True when some ancestor of `op` is a ClockDomainOp.
139 bool isInClockDomain = op->getParentOfType<ClockDomainOp>();
141 if (
auto stateOp = dyn_cast<StateOp>(op)) {
// A state op whose latency drops to zero is replaced rather than updated.
142 if (newLatency == 0) {
// Looks up the arc definition the state refers to through the symbol cache.
// NOTE(review): the remainder of this condition is elided from the listing.
143 if (cast<DefineOp>(symCache.getDefinition(stateOp.getArcAttr()))
// Forward the state's inputs in place of its results and count the removal.
145 rewriter.replaceOp(stateOp, stateOp.getInputs());
146 ++statistics.numOpsRemoved;
// Replace the state with a CallOp to the same arc, with the same inputs and
// result types. Presumably reached only when the check above did not apply -
// the intervening control flow is elided.
149 rewriter.setInsertionPoint(op);
150 rewriter.replaceOpWithNewOp<CallOp>(op, stateOp.getOutputs().getTypes(),
151 stateOp.getArcAttr(),
152 stateOp.getInputs());
// Update the latency in place; attach `clock` only if the state has no clock
// of its own and is not nested in a clock domain.
156 rewriter.modifyOpInPlace(op, [&]() {
157 stateOp.setLatency(newLatency);
158 if (!stateOp.getClock() && !isInClockDomain)
159 stateOp.getClockMutable().assign(clock);
// A CallOp that receives a non-zero latency is replaced by a StateOp. Inside
// a clock domain an empty Value is passed instead of `clock`.
// NOTE(review): the trailing argument(s) of this call are elided.
164 if (
auto callOp = dyn_cast<CallOp>(op); callOp && newLatency > 0)
165 rewriter.replaceOpWithNewOp<StateOp>(
166 op, callOp.getArcAttr(), callOp->getResultTypes(),
167 isInClockDomain ? Value{} : clock, Value{}, newLatency,
// Add the minimum predecessor latency to `op`, using the remembered clock.
171 setLatency(op, op.getLatency() + minPrevLatency, clock);
172 for (
auto prevOp : predecessors) {
// The counter goes up by the minimum once per predecessor ...
173 statistics.latencyUnitsSaved += minPrevLatency;
174 auto newLatency = prevOp.getLatency() - minPrevLatency;
// Predecessors are updated with an empty clock value.
175 setLatency(prevOp, newLatency, {});
// ... and down once after the loop, so the net change per rewrite is
// minPrevLatency * (number of predecessors - 1).
177 statistics.latencyUnitsSaved -= minPrevLatency;
/// Pass class derived from `arc::impl::LatencyRetimingBase`, the base enabled
/// by `GEN_PASS_DEF_LATENCYRETIMING` together with the `ArcPasses.h.inc`
/// include near the top of the file. It overrides `runOnOperation`, which is
/// defined out of line.
187 struct LatencyRetimingPass
188 : arc::impl::LatencyRetimingBase<LatencyRetimingPass> {
189 void runOnOperation()
override;
// Applies the latency retiming pattern greedily to the operation the pass
// runs on, then publishes the collected counters.
193 void LatencyRetimingPass::runOnOperation() {
// Local counters; the pattern receives them by reference.
197 LatencyRetimingStatistics statistics;
199 RewritePatternSet
patterns(&getContext());
// NOTE(review): `cache` is declared and populated on lines elided from this
// listing.
200 patterns.add<LatencyRetimingPattern>(&getContext(), cache, statistics);
// Mark the pass as failed if the greedy driver reports failure.
202 if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(
patterns))))
203 return signalPassFailure();
// Copy the counters into `numOpsRemoved` / `latencyUnitsSaved`. These are not
// declared in this excerpt; presumably they are statistics members of the
// generated pass base - confirm.
205 numOpsRemoved = statistics.numOpsRemoved;
206 latencyUnitsSaved = statistics.latencyUnitsSaved;
// Returns a newly constructed LatencyRetimingPass owned by a std::unique_ptr.
// NOTE(review): the enclosing function signature is elided from this listing.
210 return std::make_unique<LatencyRetimingPass>();
assert(baseType &&"element must be base type")
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) and mlir::Operation's.
std::unique_ptr< mlir::Pass > createLatencyRetimingPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.