12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/IRMapping.h"
14 #include "mlir/Pass/Pass.h"
15 #include "mlir/Transforms/InliningUtils.h"
16 #include "llvm/Support/Debug.h"
18 #define DEBUG_TYPE "arc-inline"
22 #define GEN_PASS_DEF_INLINEARCS
23 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 using namespace circt;
/// Counters gathered while inlining arcs. InlineArcsPass::runOnOperation
/// copies each field into the pass counter of the same name when it finishes.
/// NOTE(review): analyze() is reachable from both inlineCallsInRegion and
/// removeUnusedArcs, and nothing visible resets these fields; if both run in
/// one pass invocation, numTrivialArcs/numSingleUseArcs may count an arc twice.
35 struct InlineArcsStatistics {
// Incremented once per call inlined (notifyInlinedCallInto).
36 size_t numInlinedArcs = 0;
// Incremented once per arc handed to notifyArcRemoved.
37 size_t numRemovedArcs = 0;
// Arcs whose counted op total is <= maxNonTrivialOpsInBody (set in analyze).
38 size_t numTrivialArcs = 0;
// Arcs with exactly one recorded use at the end of analyze.
39 size_t numSingleUseArcs = 0;
/// Per-arc bookkeeping used to decide which calls to inline and which arcs
/// became unused. Every table is keyed by the arc's symbol name.
45 class InlineArcsAnalysis {
47 InlineArcsAnalysis(InlineArcsStatistics &statistics,
48 const InlineArcsOptions &options)
49 : statistics(statistics), options(options) {}
/// Populate the tables from `arcDefinitions` (op counts, callee lists) and
/// from the call sites found in `regionsWithCalls` (use counts).
55 void analyze(ArrayRef<Region *> regionsWithCalls,
56 ArrayRef<DefineOp> arcDefinitions);
/// Update the tables after `callOp` has been inlined into `region`.
60 void notifyInlinedCallInto(mlir::CallOpInterface callOp, Region *region);
/// Update the tables for an arc that is being removed.
63 void notifyArcRemoved(DefineOp arc);
/// Whether `callOp` should be inlined, given the options and current tables.
67 bool shouldInline(mlir::CallOpInterface callOp)
const;
/// The DefineOp that `callOp` calls; the callee must be an analyzed arc.
70 DefineOp getArc(mlir::CallOpInterface callOp)
const;
/// Current number of recorded uses of the arc named `arcName`.
74 size_t getNumArcUses(StringAttr arcName)
const;
/// Arc name -> leaf names of the callees of every call op in its body, one
/// entry per call. Entries are not filtered, so they may name non-arc callees.
77 DenseMap<StringAttr, SmallVector<StringAttr>> callsInArcBody;
/// Arc name -> number of ops in the arc that are neither ConstantLike nor
/// OutputOp; adjusted as calls are inlined into the arc.
78 DenseMap<StringAttr, size_t> numOpsInArc;
/// Arc name -> number of call sites currently referring to it.
79 DenseMap<StringAttr, size_t> usersPerArc;
/// Arc name -> its definition.
80 DenseMap<StringAttr, DefineOp> arcMap;
/// Counters updated by this analysis (held by reference, not owned).
82 InlineArcsStatistics &statistics;
/// Pass options consulted here: intoArcsOnly, maxNonTrivialOpsInBody.
83 const InlineArcsOptions &options;
/// Rewrites the IR, consulting and updating the InlineArcsAnalysis passed to
/// the constructor.
90 explicit ArcInliner(InlineArcsAnalysis &analysis) : analysis(analysis) {}
/// Analyze, then inline approved calls in each of `regionsWithCalls`; when
/// `removeUnusedArcs` is set, afterwards process arcs left with zero uses.
96 void inlineCallsInRegion(ArrayRef<Region *> regionsWithCalls,
97 ArrayRef<DefineOp> arcDefinitions,
98 bool removeUnusedArcs =
false);
/// Re-analyze with `unusedIn` as the only call-containing region, then
/// process the arcs in `arcs` that have zero uses.
102 void removeUnusedArcs(Region *unusedIn, ArrayRef<DefineOp> arcs);
/// Inline approved calls in `region` and in regions nested inside its ops.
105 void inlineCallsInRegion(Region *region);
/// Report each arc in `arcs` with zero recorded uses via notifyArcRemoved.
106 void removeUnusedArcsInternal(ArrayRef<DefineOp> arcs);
/// Held by reference; must outlive this inliner.
108 InlineArcsAnalysis &analysis;
/// Pass wrapper around ArcInliner. Its options and statistic counters come
/// from the tablegen-generated InlineArcsBase (GEN_PASS_DEF_INLINEARCS).
113 struct InlineArcsPass :
public arc::impl::InlineArcsBase<InlineArcsPass> {
114 using InlineArcsBase::InlineArcsBase;
116 void runOnOperation()
override;
/// Inline every call in `region` that analysis.shouldInline approves, and
/// recurse into the regions nested under the ops of `region`.
121 void ArcInliner::inlineCallsInRegion(Region *region) {
122 for (
auto &block : region->getBlocks()) {
123 for (
auto iter = block.begin(); iter != block.end(); ++iter) {
124 Operation &op = *iter;
125 if (
auto callOp = dyn_cast<mlir::CallOpInterface>(op);
126 callOp && analysis.shouldInline(callOp)) {
127 DefineOp arc = analysis.getArc(callOp);
128 auto args = arc.getBodyBlock().getArguments();
// Map the arc body's block arguments to the call's argument operands.
130 IRMapping localMapping;
131 for (
auto [arg, operand] : llvm::zip(args, callOp.getArgOperands()))
132 localMapping.map(arg, operand);
// Clone the arc body, minus its terminator, directly after the call.
// NOTE(review): this loop's `op` shadows the enclosing `Operation &op`.
134 OpBuilder builder(callOp);
135 builder.setInsertionPointAfter(callOp);
136 for (
auto &op : arc.getBodyBlock().without_terminator())
137 builder.clone(op, localMapping);
// Redirect users of each call result to the cloned value that the arc's
// terminator returns in the same position.
139 for (
auto [returnVal, result] :
140 llvm::zip(arc.getBodyBlock().getTerminator()->getOperands(),
141 callOp->getResults()))
142 result.replaceAllUsesWith(localMapping.lookup(returnVal));
// Keep use counts and op counts in the analysis in sync with the new IR.
144 analysis.notifyInlinedCallInto(callOp, region);
// Descend into the regions nested under `op`.
153 for (Region ®ion : op.getRegions())
154 inlineCallsInRegion(®ion);
/// For each arc in `arcs` whose recorded use count is zero, tell the analysis
/// that the arc is being removed.
159 void ArcInliner::removeUnusedArcsInternal(ArrayRef<DefineOp> arcs) {
160 for (
auto arc : llvm::make_early_inc_range(arcs)) {
161 if (analysis.getNumArcUses(arc.getSymNameAttr()) == 0) {
162 analysis.notifyArcRemoved(arc);
/// Recompute use counts with `unusedIn` as the only call-containing region,
/// then process the arcs in `arcs` that are left with zero uses.
168 void ArcInliner::removeUnusedArcs(Region *unusedIn, ArrayRef<DefineOp> arcs) {
169 analysis.analyze({unusedIn}, arcs);
170 removeUnusedArcsInternal(arcs);
/// Build the per-arc tables.
/// - From the arc bodies: op counts and callee lists.
/// - From the call sites in `regionsWithCalls`: use counts.
/// Also bumps numTrivialArcs and numSingleUseArcs.
173 void InlineArcsAnalysis::analyze(ArrayRef<Region *> regionsWithCalls,
174 ArrayRef<DefineOp> arcDefinitions) {
175 callsInArcBody.clear();
// Record each arc's definition, its op count, and the callees called in it.
182 for (
auto arc : arcDefinitions) {
183 auto arcName = arc.getSymNameAttr();
184 arcMap[arcName] = arc;
185 numOpsInArc[arcName] = 0;
// Count every op that is neither ConstantLike nor an OutputOp.
// NOTE(review): Operation::walk also visits `arc` itself, so the DefineOp
// appears to be included in the count; confirm this is intended.
186 arc->walk([&](Operation *op) {
187 if (!op->hasTrait<OpTrait::ConstantLike>() && !isa<OutputOp>(op))
188 ++numOpsInArc[arcName];
// NOTE(review): unlike the call-site walk below, this path does not check
// is<SymbolRefAttr>() before get<>(), so an indirect (Value) callee would
// trip an assertion. Every callee name is recorded, arc or not.
189 if (isa<mlir::CallOpInterface>(op))
191 callsInArcBody[arcName].push_back(cast<mlir::CallOpInterface>(op)
192 .getCallableForCallee()
193 .get<mlir::SymbolRefAttr>()
194 .getLeafReference());
196 if (numOpsInArc[arcName] <= options.maxNonTrivialOpsInBody)
197 ++statistics.numTrivialArcs;
199 LLVM_DEBUG(llvm::dbgs() <<
"Arc " << arc.getSymName() <<
" has "
200 << numOpsInArc[arcName] <<
" non-trivial ops\n");
// Every arc starts with a use count of zero.
205 usersPerArc[arc.getSymNameAttr()] = 0;
// Count call sites per arc. The two guards below presumably skip calls with a
// non-symbol callee and calls to names not among `arcDefinitions`.
208 for (
auto *regionWithCalls : regionsWithCalls) {
209 regionWithCalls->walk([&](mlir::CallOpInterface op) {
210 if (!op.getCallableForCallee().is<SymbolRefAttr>())
214 op.getCallableForCallee().get<SymbolRefAttr>().getLeafReference();
215 if (!usersPerArc.contains(arcName))
218 ++usersPerArc[arcName];
// Arcs with exactly one use contribute to numSingleUseArcs.
224 for (
auto arc : arcDefinitions)
225 if (usersPerArc[arc.getSymNameAttr()] == 1)
226 ++statistics.numSingleUseArcs;
/// Decide whether `callOp` should be inlined. The checks include, in order:
/// - the callee is a SymbolRefAttr;
/// - when intoArcsOnly is set, the call is nested inside a DefineOp;
/// - the callee is an analyzed arc (has an entry in numOpsInArc);
/// - the call's dialect exposes a DialectInlinerInterface that allows
///   inlining the arc;
/// - the arc is trivial (op count <= maxNonTrivialOpsInBody).
/// If none of these exits early, the call is inlined only when the arc has
/// exactly one use.
/// NOTE(review): the early-exit results are presumably false for the first
/// four checks and true for the trivial-arc check; confirm.
229 bool InlineArcsAnalysis::shouldInline(mlir::CallOpInterface callOp)
const {
231 if (!callOp.getCallableForCallee().is<SymbolRefAttr>())
234 if (!callOp->getParentOfType<DefineOp>() && options.intoArcsOnly)
241 callOp.getCallableForCallee().get<SymbolRefAttr>().getLeafReference();
242 if (!numOpsInArc.contains(arcName))
// Let the dialect veto inlining of this particular callee.
248 auto *inlinerInterface =
249 dyn_cast<mlir::DialectInlinerInterface>(callOp->getDialect());
250 if (!inlinerInterface ||
251 !inlinerInterface->isLegalToInline(callOp, getArc(callOp),
true))
254 if (numOpsInArc.at(arcName) <= options.maxNonTrivialOpsInBody)
257 return usersPerArc.at(arcName) == 1;
/// Return the DefineOp named by the leaf reference of `callOp`'s callee.
/// The callee must be a SymbolRefAttr naming an analyzed arc; the get<>()
/// and DenseMap::at calls both assert if it is not.
260 DefineOp InlineArcsAnalysis::getArc(mlir::CallOpInterface callOp)
const {
262 callOp.getCallableForCallee().get<SymbolRefAttr>().getLeafReference();
263 return arcMap.at(arcName);
/// Analyze `regionsWithCalls` against `arcDefinitions`, then inline approved
/// calls in each region. When `removeUnusedArcs` is set, afterwards process
/// the arcs left with zero uses.
266 void ArcInliner::inlineCallsInRegion(ArrayRef<Region *> regionsWithCalls,
267 ArrayRef<DefineOp> arcDefinitions,
268 bool removeUnusedArcs) {
269 analysis.analyze(regionsWithCalls, arcDefinitions);
270 for (
auto *regionWithCalls : regionsWithCalls)
271 inlineCallsInRegion(regionWithCalls);
273 if (removeUnusedArcs)
274 removeUnusedArcsInternal(arcDefinitions);
/// Current use count of `arcName`. The name must already be in usersPerArc;
/// DenseMap::at asserts otherwise.
277 size_t InlineArcsAnalysis::getNumArcUses(StringAttr arcName)
const {
278 return usersPerArc.at(arcName);
/// Update the tables after `callOp` has been inlined into `region`.
/// - The callee loses one use, and numInlinedArcs is bumped.
/// - If `region`'s parent op is a DefineOp (presumably the function returns
///   early otherwise), that arc's op count grows by the callee's count minus
///   one for the removed call.
/// - One matching entry for the callee is looked up in the arc's call list,
///   presumably to be erased.
/// - Every call recorded in the callee's body is appended to the arc's call
///   list and, when the callee name is tracked, adds one use.
/// NOTE(review): if arcName == calledArcName, the push_back below grows the
/// very vector the range-for is iterating, which can invalidate it.
281 void InlineArcsAnalysis::notifyInlinedCallInto(mlir::CallOpInterface callOp,
283 StringAttr calledArcName = callOp.getCallableForCallee()
284 .get<mlir::SymbolRefAttr>()
286 --usersPerArc[calledArcName];
287 ++statistics.numInlinedArcs;
289 auto arc = dyn_cast<DefineOp>(region->getParentOp());
293 StringAttr arcName = arc.getSymNameAttr();
// The call op is gone and the callee's ops now live in the caller.
295 numOpsInArc[arcName] += numOpsInArc[calledArcName] - 1;
296 auto &calls = callsInArcBody[arcName];
297 auto *iter = llvm::find(calls, calledArcName);
298 if (iter != calls.end())
// The caller now contains copies of the callee's calls.
301 for (
auto calleeName : callsInArcBody[calledArcName]) {
302 if (!usersPerArc.contains(calleeName))
305 ++usersPerArc[calleeName];
306 callsInArcBody[arcName].push_back(calleeName);
/// Bookkeeping for an arc being removed. Each call recorded in its body
/// releases one use of its callee, its call list is cleared, and
/// numRemovedArcs is bumped.
/// NOTE(review): callsInArcBody can hold names of non-arc callees, because
/// analyze records every call op. operator[] here would insert such a name
/// with count 0, and the size_t decrement would wrap to SIZE_MAX.
/// notifyInlinedCallInto guards with usersPerArc.contains(); consider the
/// same guard here.
310 void InlineArcsAnalysis::notifyArcRemoved(DefineOp arc) {
311 for (
auto calleeName : callsInArcBody[arc.getSymNameAttr()])
312 --usersPerArc[calleeName];
314 callsInArcBody[arc.getSymNameAttr()].clear();
315 ++statistics.numRemovedArcs;
/// Pass entry point.
/// 1. Copy the pass options into an InlineArcsOptions.
/// 2. Collect every DefineOp in the top-level body, plus the regions that may
///    contain calls: arc bodies and region 0 of HWModuleOp / func::FuncOp /
///    ModelOp ops.
/// 3. Run the inliner over them.
/// 4. Publish the gathered statistics.
318 void InlineArcsPass::runOnOperation() {
321 InlineArcsOptions options;
322 options.intoArcsOnly = intoArcsOnly;
323 options.maxNonTrivialOpsInBody = maxNonTrivialOpsInBody;
324 InlineArcsStatistics statistics;
325 InlineArcsAnalysis analysis(statistics, options);
326 ArcInliner inliner(analysis);
// Gather arc definitions and the regions whose calls are candidates.
332 SmallVector<DefineOp> arcDefinitions;
333 SmallVector<Region *> regions;
334 for (Operation &op : *getOperation().getBody()) {
335 if (
auto arc = dyn_cast<DefineOp>(&op)) {
336 arcDefinitions.emplace_back(arc);
337 regions.push_back(&arc.getBody());
341 if (isa<hw::HWModuleOp, mlir::func::FuncOp, ModelOp>(&op))
342 regions.push_back(&op.getRegion(0));
// Inline over all collected regions.
348 inliner.inlineCallsInRegion(regions, arcDefinitions,
// Copy the local counters into the pass statistics.
353 numInlinedArcs = statistics.numInlinedArcs;
354 numRemovedArcs = statistics.numRemovedArcs;
355 numSingleUseArcs = statistics.numSingleUseArcs;
356 numTrivialArcs = statistics.numTrivialArcs;
/// Factory for InlineArcsPass; the pass is constructed with no arguments.
360 return std::make_unique<InlineArcsPass>();
std::unique_ptr< mlir::Pass > createInlineArcsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.