12#include "mlir/Dialect/Func/IR/FuncOps.h" 
   13#include "mlir/IR/IRMapping.h" 
   14#include "mlir/Pass/Pass.h" 
   15#include "mlir/Transforms/InliningUtils.h" 
   16#include "llvm/Support/Debug.h" 
   18#define DEBUG_TYPE "arc-inline" 
   22#define GEN_PASS_DEF_INLINEARCS 
   23#include "circt/Dialect/Arc/ArcPasses.h.inc" 
   35struct InlineArcsStatistics {
 
   36  size_t numInlinedArcs = 0;
 
   37  size_t numRemovedArcs = 0;
 
   38  size_t numTrivialArcs = 0;
 
   39  size_t numSingleUseArcs = 0;
 
   45class InlineArcsAnalysis {
 
   47  InlineArcsAnalysis(InlineArcsStatistics &statistics,
 
   48                     const InlineArcsOptions &options)
 
   49      : statistics(statistics), options(options) {}
 
   55  void analyze(ArrayRef<Region *> regionsWithCalls,
 
   56               ArrayRef<DefineOp> arcDefinitions);
 
   60  void notifyInlinedCallInto(mlir::CallOpInterface callOp, Region *region);
 
   63  void notifyArcRemoved(DefineOp arc);
 
   67  bool shouldInline(mlir::CallOpInterface callOp) 
const;
 
   70  DefineOp getArc(mlir::CallOpInterface callOp) 
const;
 
   74  size_t getNumArcUses(StringAttr arcName) 
const;
 
   77  DenseMap<StringAttr, SmallVector<StringAttr>> callsInArcBody;
 
   78  DenseMap<StringAttr, size_t> numOpsInArc;
 
   79  DenseMap<StringAttr, size_t> usersPerArc;
 
   80  DenseMap<StringAttr, DefineOp> arcMap;
 
   82  InlineArcsStatistics &statistics;
 
   83  const InlineArcsOptions &options;
 
   90  explicit ArcInliner(InlineArcsAnalysis &analysis) : analysis(analysis) {}
 
   96  void inlineCallsInRegion(ArrayRef<Region *> regionsWithCalls,
 
   97                           ArrayRef<DefineOp> arcDefinitions,
 
   98                           bool removeUnusedArcs = 
false);
 
  102  void removeUnusedArcs(Region *unusedIn, ArrayRef<DefineOp> arcs);
 
  105  void inlineCallsInRegion(Region *region);
 
  106  void removeUnusedArcsInternal(ArrayRef<DefineOp> arcs);
 
  108  InlineArcsAnalysis &analysis;
 
  113struct InlineArcsPass : 
public arc::impl::InlineArcsBase<InlineArcsPass> {
 
  114  using InlineArcsBase::InlineArcsBase;
 
  116  void runOnOperation() 
override;
 
  121void ArcInliner::inlineCallsInRegion(Region *region) {
 
  122  for (
auto &block : region->getBlocks()) {
 
  123    for (
auto iter = block.begin(); iter != block.end(); ++iter) {
 
  124      Operation &op = *iter;
 
  125      if (
auto callOp = dyn_cast<mlir::CallOpInterface>(op);
 
  126          callOp && analysis.shouldInline(callOp)) {
 
  127        DefineOp arc = analysis.getArc(callOp);
 
  128        auto args = arc.getBodyBlock().getArguments();
 
  130        IRMapping localMapping;
 
  131        for (
auto [arg, operand] : 
llvm::zip(args, callOp.getArgOperands()))
 
  132          localMapping.map(arg, operand);
 
  134        OpBuilder builder(callOp);
 
  135        builder.setInsertionPointAfter(callOp);
 
  136        for (
auto &op : arc.
getBodyBlock().without_terminator())
 
  137          builder.clone(op, localMapping);
 
  139        for (
auto [returnVal, result] :
 
  141                       callOp->getResults()))
 
  142          result.replaceAllUsesWith(localMapping.lookup(returnVal));
 
  144        analysis.notifyInlinedCallInto(callOp, region);
 
  153      for (Region ®ion : op.getRegions())
 
  154        inlineCallsInRegion(®ion);
 
  159void ArcInliner::removeUnusedArcsInternal(ArrayRef<DefineOp> arcs) {
 
  160  for (
auto arc : 
llvm::make_early_inc_range(arcs)) {
 
  161    if (analysis.getNumArcUses(arc.getSymNameAttr()) == 0) {
 
  162      analysis.notifyArcRemoved(arc);
 
  168void ArcInliner::removeUnusedArcs(Region *unusedIn, ArrayRef<DefineOp> arcs) {
 
  169  analysis.analyze({unusedIn}, arcs);
 
  170  removeUnusedArcsInternal(arcs);
 
  173void InlineArcsAnalysis::analyze(ArrayRef<Region *> regionsWithCalls,
 
  174                                 ArrayRef<DefineOp> arcDefinitions) {
 
  175  callsInArcBody.clear();
 
  182  for (
auto arc : arcDefinitions) {
 
  183    auto arcName = arc.getSymNameAttr();
 
  184    arcMap[arcName] = arc;
 
  185    numOpsInArc[arcName] = 0;
 
  186    arc->walk([&](Operation *op) {
 
  187      if (!op->hasTrait<OpTrait::ConstantLike>() && !isa<OutputOp>(op))
 
  188        ++numOpsInArc[arcName];
 
  189      if (isa<mlir::CallOpInterface>(op))
 
  191        callsInArcBody[arcName].push_back(
 
  192            cast<mlir::SymbolRefAttr>(
 
  193                cast<mlir::CallOpInterface>(op).getCallableForCallee())
 
  194                .getLeafReference());
 
  196    if (numOpsInArc[arcName] <= options.maxNonTrivialOpsInBody)
 
  197      ++statistics.numTrivialArcs;
 
  199    LLVM_DEBUG(llvm::dbgs() << 
"Arc " << arc.getSymName() << 
" has " 
  200                            << numOpsInArc[arcName] << 
" non-trivial ops\n");
 
  205    usersPerArc[arc.getSymNameAttr()] = 0;
 
  208  for (
auto *regionWithCalls : regionsWithCalls) {
 
  209    regionWithCalls->walk([&](mlir::CallOpInterface op) {
 
  210      if (!isa<SymbolRefAttr>(op.getCallableForCallee()))
 
  214          cast<SymbolRefAttr>(op.getCallableForCallee()).getLeafReference();
 
  215      if (!usersPerArc.contains(arcName))
 
  218      ++usersPerArc[arcName];
 
  224  for (
auto arc : arcDefinitions)
 
  225    if (usersPerArc[arc.getSymNameAttr()] == 1)
 
  226      ++statistics.numSingleUseArcs;
 
  229bool InlineArcsAnalysis::shouldInline(mlir::CallOpInterface callOp)
 const {
 
  231  if (!isa<SymbolRefAttr>(callOp.getCallableForCallee()))
 
  234  if (!callOp->getParentOfType<DefineOp>() && options.intoArcsOnly)
 
  240  StringAttr arcName = llvm::cast<SymbolRefAttr>(callOp.getCallableForCallee())
 
  242  if (!numOpsInArc.contains(arcName))
 
  248  auto *inlinerInterface =
 
  249      dyn_cast<mlir::DialectInlinerInterface>(callOp->getDialect());
 
  250  if (!inlinerInterface ||
 
  251      !inlinerInterface->isLegalToInline(callOp, getArc(callOp), 
true))
 
  254  if (numOpsInArc.at(arcName) <= options.maxNonTrivialOpsInBody)
 
  257  return usersPerArc.at(arcName) == 1;
 
  260DefineOp InlineArcsAnalysis::getArc(mlir::CallOpInterface callOp)
 const {
 
  262      cast<SymbolRefAttr>(callOp.getCallableForCallee()).getLeafReference();
 
  263  return arcMap.at(arcName);
 
  266void ArcInliner::inlineCallsInRegion(ArrayRef<Region *> regionsWithCalls,
 
  267                                     ArrayRef<DefineOp> arcDefinitions,
 
  268                                     bool removeUnusedArcs) {
 
  269  analysis.analyze(regionsWithCalls, arcDefinitions);
 
  270  for (
auto *regionWithCalls : regionsWithCalls)
 
  271    inlineCallsInRegion(regionWithCalls);
 
  273  if (removeUnusedArcs)
 
  274    removeUnusedArcsInternal(arcDefinitions);
 
  277size_t InlineArcsAnalysis::getNumArcUses(StringAttr arcName)
 const {
 
  278  return usersPerArc.at(arcName);
 
  281void InlineArcsAnalysis::notifyInlinedCallInto(mlir::CallOpInterface callOp,
 
  283  StringAttr calledArcName =
 
  284      cast<mlir::SymbolRefAttr>(callOp.getCallableForCallee())
 
  286  --usersPerArc[calledArcName];
 
  287  ++statistics.numInlinedArcs;
 
  289  for (
auto calleeName : callsInArcBody[calledArcName]) {
 
  290    if (!usersPerArc.contains(calleeName))
 
  292    ++usersPerArc[calleeName];
 
  295  auto arc = dyn_cast<DefineOp>(region->getParentOp());
 
  299  StringAttr arcName = arc.getSymNameAttr();
 
  301  numOpsInArc[arcName] += numOpsInArc[calledArcName] - 1;
 
  302  auto &calls = callsInArcBody[arcName];
 
  303  auto *iter = llvm::find(calls, calledArcName);
 
  304  if (iter != calls.end())
 
  307  for (
auto calleeName : callsInArcBody[calledArcName]) {
 
  308    if (!usersPerArc.contains(calleeName))
 
  310    callsInArcBody[arcName].push_back(calleeName);
 
  314void InlineArcsAnalysis::notifyArcRemoved(DefineOp arc) {
 
  315  for (
auto calleeName : callsInArcBody[arc.getSymNameAttr()])
 
  316    --usersPerArc[calleeName];
 
  318  callsInArcBody[arc.getSymNameAttr()].clear();
 
  319  ++statistics.numRemovedArcs;
 
  322void InlineArcsPass::runOnOperation() {
 
  325  InlineArcsOptions options;
 
  326  options.intoArcsOnly = intoArcsOnly;
 
  327  options.maxNonTrivialOpsInBody = maxNonTrivialOpsInBody;
 
  328  InlineArcsStatistics statistics;
 
  329  InlineArcsAnalysis analysis(statistics, options);
 
  330  ArcInliner inliner(analysis);
 
  336  SmallVector<DefineOp> arcDefinitions;
 
  337  SmallVector<Region *> regions;
 
  338  for (Operation &op : *getOperation().getBody()) {
 
  339    if (
auto arc = dyn_cast<DefineOp>(&op)) {
 
  340      arcDefinitions.emplace_back(arc);
 
  341      regions.push_back(&arc.getBody());
 
  345    if (isa<hw::HWModuleOp, mlir::func::FuncOp, ModelOp>(&op))
 
  346      regions.push_back(&op.getRegion(0));
 
  352  inliner.inlineCallsInRegion(regions, arcDefinitions,
 
  357  numInlinedArcs = statistics.numInlinedArcs;
 
  358  numRemovedArcs = statistics.numRemovedArcs;
 
  359  numSingleUseArcs = statistics.numSingleUseArcs;
 
  360  numTrivialArcs = statistics.numTrivialArcs;
 
  364  return std::make_unique<InlineArcsPass>();
 
 
static Block * getBodyBlock(FModuleLike mod)
std::unique_ptr< mlir::Pass > createInlineArcsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.