CIRCT  20.0.0git
SplitFuncs.cpp
Go to the documentation of this file.
1 //===- SplitFuncs.cpp -----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/Analysis/Liveness.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/RegionUtils.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/Support/MathExtras.h"
19 #include <string>
20 
21 #define DEBUG_TYPE "arc-split-funcs"
22 
23 namespace circt {
24 namespace arc {
25 #define GEN_PASS_DEF_SPLITFUNCS
26 #include "circt/Dialect/Arc/ArcPasses.h.inc"
27 } // namespace arc
28 } // namespace circt
29 
30 using namespace mlir;
31 using namespace circt;
32 using namespace func;
33 
34 //===----------------------------------------------------------------------===//
35 // Pass Implementation
36 //===----------------------------------------------------------------------===//
37 
38 namespace {
39 struct SplitFuncsPass : public arc::impl::SplitFuncsBase<SplitFuncsPass> {
40  using arc::impl::SplitFuncsBase<SplitFuncsPass>::SplitFuncsBase;
41  void runOnOperation() override;
42  LogicalResult lowerFunc(FuncOp funcOp);
43 
44  SymbolTable *symbolTable;
45 };
46 } // namespace
47 
48 void SplitFuncsPass::runOnOperation() {
49  symbolTable = &getAnalysis<SymbolTable>();
50  for (auto op : llvm::make_early_inc_range(getOperation().getOps<FuncOp>()))
51  if (failed(lowerFunc(op)))
52  return signalPassFailure();
53 }
54 
55 LogicalResult SplitFuncsPass::lowerFunc(FuncOp funcOp) {
56  if (splitBound == 0)
57  return funcOp.emitError("Cannot split functions into functions of size 0.");
58  if (funcOp.getBody().getBlocks().size() > 1)
59  return funcOp.emitError("Regions with multiple blocks are not supported.");
60  assert(funcOp->getNumRegions() == 1);
61  unsigned numOps = funcOp.front().getOperations().size();
62  if (numOps < splitBound)
63  return success();
64  int numBlocks = llvm::divideCeil(numOps, splitBound);
65  OpBuilder opBuilder(funcOp->getContext());
66  SmallVector<Block *> blocks;
67  Block *frontBlock = &(funcOp.getBody().front());
68  blocks.push_back(frontBlock);
69  for (int i = 0; i < numBlocks - 1; ++i) {
70  SmallVector<Location> locs(frontBlock->getNumArguments(), funcOp.getLoc());
71  auto *block = opBuilder.createBlock(&(funcOp.getBody()), {},
72  frontBlock->getArgumentTypes(), locs);
73  blocks.push_back(block);
74  }
75 
76  unsigned numOpsInBlock = 0;
77  SmallVector<Block *>::iterator blockIter = blocks.begin();
78  for (auto &op : llvm::make_early_inc_range(*frontBlock)) {
79  if (numOpsInBlock >= splitBound) {
80  ++blockIter;
81  numOpsInBlock = 0;
82  opBuilder.setInsertionPointToEnd(*blockIter);
83  }
84  ++numOpsInBlock;
85  // Don't bother moving ops to the original block
86  if (*blockIter == (frontBlock))
87  continue;
88  // Remove op from original block and insert in new block
89  op.remove();
90  (*blockIter)->push_back(&op);
91  }
92  DenseMap<Value, Value> argMap;
93  // Move function arguments to the block that will stay in the function
94  for (unsigned argIndex = 0; argIndex < frontBlock->getNumArguments();
95  ++argIndex) {
96  auto oldArg = frontBlock->getArgument(argIndex);
97  auto newArg = blocks.back()->getArgument(argIndex);
98  replaceAllUsesInRegionWith(oldArg, newArg, funcOp.getBody());
99  }
100  Liveness liveness(funcOp);
101  auto argTypes = blocks.back()->getArgumentTypes();
102  auto args = blocks.back()->getArguments();
103 
104  // Create return ops
105  for (int i = blocks.size() - 2; i >= 0; --i) {
106  liveness = Liveness(funcOp);
107  Block *currentBlock = blocks[i];
108  Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
109  SmallVector<Value> outValues;
110  llvm::for_each(liveOut, [&outValues](auto el) {
111  if (!isa<BlockArgument>(el))
112  outValues.push_back(el);
113  });
114  opBuilder.setInsertionPointToEnd(currentBlock);
115  opBuilder.create<ReturnOp>(funcOp->getLoc(), outValues);
116  }
117  // Create and populate new FuncOps
118  for (long unsigned i = 0; i < blocks.size() - 1; ++i) {
119  Block *currentBlock = blocks[i];
120  Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
121  SmallVector<Type> outTypes;
122  SmallVector<Value> outValues;
123  llvm::for_each(liveOut, [&outTypes, &outValues](auto el) {
124  if (!isa<BlockArgument>(el)) {
125  outValues.push_back(el);
126  outTypes.push_back(el.getType());
127  }
128  });
129  opBuilder.setInsertionPoint(funcOp);
130  SmallString<64> funcName;
131  funcName.append(funcOp.getName());
132  funcName.append("_split_func");
133  funcName.append(std::to_string(i));
134  auto newFunc =
135  opBuilder.create<FuncOp>(funcOp->getLoc(), funcName,
136  opBuilder.getFunctionType(argTypes, outTypes));
137  ++numFuncsCreated;
138  symbolTable->insert(newFunc);
139  auto *funcBlock = newFunc.addEntryBlock();
140  for (auto &op : make_early_inc_range(currentBlock->getOperations())) {
141  op.remove();
142  funcBlock->push_back(&op);
143  }
144  currentBlock->erase();
145 
146  opBuilder.setInsertionPointToEnd(funcBlock);
147  for (auto [j, el] : llvm::enumerate(args))
148  replaceAllUsesInRegionWith(el, newFunc.getArgument(j),
149  newFunc.getRegion());
150  for (auto pair : argMap)
151  replaceAllUsesInRegionWith(pair.first, pair.second, newFunc.getRegion());
152  opBuilder.setInsertionPointToStart(blocks[i + 1]);
153  Operation *callOp = opBuilder.create<func::CallOp>(
154  funcOp->getLoc(), outTypes, funcName, args);
155  auto callResults = callOp->getResults();
156  argMap.clear();
157  for (unsigned long k = 0; k < outValues.size(); ++k)
158  argMap.insert(std::pair(outValues[k], callResults[k]));
159  }
160  for (auto pair : argMap)
161  replaceAllUsesInRegionWith(pair.first, pair.second, funcOp.getRegion());
162  return success();
163 }
assert(baseType &&"element must be base type")
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21