11#include "mlir/Analysis/Liveness.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/RegionUtils.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/Support/MathExtras.h"
21#define DEBUG_TYPE "arc-split-funcs"
25#define GEN_PASS_DEF_SPLITFUNCS
26#include "circt/Dialect/Arc/ArcPasses.h.inc"
39struct SplitFuncsPass :
public arc::impl::SplitFuncsBase<SplitFuncsPass> {
40 using arc::impl::SplitFuncsBase<SplitFuncsPass>::SplitFuncsBase;
41 void runOnOperation()
override;
42 LogicalResult lowerFunc(FuncOp funcOp);
44 SymbolTable *symbolTable;
48void SplitFuncsPass::runOnOperation() {
49 symbolTable = &getAnalysis<SymbolTable>();
50 for (
auto op :
llvm::make_early_inc_range(getOperation().getOps<FuncOp>())) {
54 if (failed(lowerFunc(op)))
55 return signalPassFailure();
/// Splits the single-block body of `funcOp` into
/// llvm::divideCeil(numOps, splitBound) chunks of consecutive operations.
///
/// Every chunk but the last is moved into a new function named
/// `<funcOp name>_split_func<i>`. Each new function:
///   - is created just before `funcOp`;
///   - takes the same arguments as `funcOp`;
///   - returns the (non-argument) values that are live into the next chunk.
///
/// A call to split function i is inserted at the start of chunk i + 1. As a
/// result, the chunk left in `funcOp` calls split function n-2, which calls
/// n-3, and so on down to split function 0.
///
/// NOTE(review): this excerpt is incomplete. Several statements are not
/// visible:
///   - the guard on the first error;
///   - the body of the `numOps < splitBound` check;
///   - parts of the op-distribution loop;
///   - the declarations of `argMap` and `newFunc`;
///   - closing braces and the final return.
/// The numbers prefixed to many lines look like extraction residue rather
/// than code. The comments below describe only the visible statements.
59LogicalResult SplitFuncsPass::lowerFunc(FuncOp funcOp) {
// NOTE(review): no guarding condition is visible, so as written this returns
// immediately. The message suggests it should fire only when
// splitBound == 0 -- confirm.
61 return funcOp.emitError(
"Cannot split functions into functions of size 0.");
// Only single-block bodies are supported.
62 if (funcOp.getBody().getBlocks().size() > 1)
63 return funcOp.emitError(
"Regions with multiple blocks are not supported.");
64 assert(funcOp->getNumRegions() == 1);
// Counts every op in the entry block, including its terminator.
65 unsigned numOps = funcOp.front().getOperations().size();
// Bodies below the bound are not split.
// NOTE(review): the statement controlled by this `if` is not visible
// (presumably `return success();`) -- confirm.
66 if (numOps < splitBound)
// Round up so that a partial final chunk gets its own block.
// When numOps == splitBound this is 1, and the loops below do not split
// anything.
68 int numBlocks = llvm::divideCeil(numOps, splitBound);
69 OpBuilder opBuilder(funcOp->getContext());
// blocks[0] is the original entry block. For each additional chunk, a fresh
// block with the same argument types is appended to the function body.
70 SmallVector<Block *> blocks;
71 Block *frontBlock = &(funcOp.getBody().front());
72 blocks.push_back(frontBlock);
73 for (
int i = 0; i < numBlocks - 1; ++i) {
74 SmallVector<Location> locs(frontBlock->getNumArguments(), funcOp.getLoc());
75 auto *block = opBuilder.createBlock(&(funcOp.getBody()), {},
76 frontBlock->getArgumentTypes(), locs);
77 blocks.push_back(block);
// Walk the original ops and move on to the next block every splitBound ops.
// Ops of the first chunk stay in frontBlock; the others are appended to the
// block of their chunk.
// NOTE(review): the counter reset/increment, the iterator advance, the
// `continue` after the frontBlock check, and the unlinking of `op` from
// frontBlock before push_back are not visible here -- confirm.
80 unsigned numOpsInBlock = 0;
81 SmallVector<Block *>::iterator blockIter = blocks.begin();
82 for (
auto &op :
llvm::make_early_inc_range(*frontBlock)) {
83 if (numOpsInBlock >= splitBound) {
86 opBuilder.setInsertionPointToEnd(*blockIter);
90 if (*blockIter == (frontBlock))
94 (*blockIter)->push_back(&op);
// Only the last block stays in funcOp; blocks[0..n-2] are erased further
// down. So redirect every use of the original arguments to the last block's
// arguments.
// NOTE(review): the increment clause of this loop is not visible.
98 for (
unsigned argIndex = 0; argIndex < frontBlock->getNumArguments();
100 auto oldArg = frontBlock->getArgument(argIndex);
101 auto newArg = blocks.back()->getArgument(argIndex);
102 replaceAllUsesInRegionWith(oldArg, newArg, funcOp.getBody());
104 Liveness liveness(funcOp);
// Every split function takes the surviving block's argument types, and every
// call passes the surviving block's arguments.
105 auto argTypes = blocks.back()->getArgumentTypes();
106 auto args = blocks.back()->getArguments();
// Orders live values by where they are defined. This gives the return
// operands and the call results a deterministic order; the liveness value set
// is presumably pointer-keyed, so its iteration order is not meaningful.
// Callers filter out block arguments first, so getDefiningOp() is non-null.
// NOTE(review): as written, the first `return` is unconditional, which makes
// the same-block comparison unreachable. A guard (likely `opA == opB`, i.e.
// two results of one op) and a fallback for values defined in different
// blocks appear to be missing -- confirm.
109 auto sortFunc = [](Value
a, Value
b) {
110 auto *opA =
a.getDefiningOp();
111 auto *opB =
b.getDefiningOp();
113 return cast<OpResult>(a).getResultNumber() <
114 cast<OpResult>(b).getResultNumber();
115 if (opA->getBlock() == opB->getBlock())
116 return opA->isBeforeInBlock(opB);
// Visit chunk boundaries from last to first. Chunk i is terminated with a
// return of every non-argument value live into chunk i + 1. No branches link
// these blocks, so the live-in set of the next block serves as this chunk's
// live-out set.
// Liveness is recomputed on each step because the return added to
// blocks[i + 1] one step earlier creates new uses there. Those uses make
// earlier-defined values live into it, which propagates them down the chain.
121 for (
int i = blocks.size() - 2; i >= 0; --i) {
122 liveness = Liveness(funcOp);
123 Block *currentBlock = blocks[i];
124 Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
125 SmallVector<Value> outValues;
126 llvm::for_each(liveOut, [&outValues](
auto el) {
127 if (!isa<BlockArgument>(el))
128 outValues.push_back(el);
130 llvm::stable_sort(outValues, sortFunc);
131 opBuilder.setInsertionPointToEnd(currentBlock);
132 ReturnOp::create(opBuilder, funcOp->getLoc(), outValues);
// Outline every chunk but the last into its own function. The loop first
// rebuilds the same sorted value list as the loop above, so that the result
// types line up with the return operands created there.
135 for (
long unsigned i = 0; i < blocks.size() - 1; ++i) {
136 Block *currentBlock = blocks[i];
137 Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
138 SmallVector<Value> outValues;
139 llvm::for_each(liveOut, [&outValues](
auto el) {
140 if (!isa<BlockArgument>(el))
141 outValues.push_back(el);
143 llvm::stable_sort(outValues, sortFunc);
144 auto outTypes = llvm::to_vector(
145 llvm::map_range(outValues, [](Value v) {
return v.getType(); }));
// New functions are created immediately before funcOp.
146 opBuilder.setInsertionPoint(funcOp);
147 SmallString<64> funcName;
148 funcName.append(funcOp.getName());
149 funcName.append(
"_split_func");
150 funcName.append(std::to_string(i));
// NOTE(review): the binding of this op to `newFunc` (used below) is not
// visible.
152 FuncOp::create(opBuilder, funcOp->getLoc(), funcName,
153 opBuilder.getFunctionType(argTypes, outTypes));
// NOTE(review): if a symbol named funcName already exists,
// SymbolTable::insert may rename newFunc. The call created below still
// refers to funcName; consider using the name returned by insert().
155 symbolTable->insert(newFunc);
156 auto *funcBlock = newFunc.addEntryBlock();
// Move the chunk's ops (including the return added above) into the new
// function, then erase the now-empty block.
// NOTE(review): unlinking each op from currentBlock before push_back is not
// visible -- confirm.
157 for (
auto &op : make_early_inc_range(currentBlock->getOperations())) {
159 funcBlock->push_back(&op);
161 currentBlock->erase();
163 opBuilder.setInsertionPointToEnd(funcBlock);
// Inside the new function, the surviving block's arguments become the
// function's own arguments.
164 for (
auto [j, el] :
llvm::enumerate(args))
165 replaceAllUsesInRegionWith(el, newFunc.getArgument(j),
166 newFunc.getRegion());
// Values produced by earlier chunks are rewired to the call results recorded
// in argMap.
167 for (
auto pair : argMap)
168 replaceAllUsesInRegionWith(pair.first, pair.second, newFunc.getRegion());
// Call the new function at the start of the next chunk, passing the surviving
// block's arguments.
169 opBuilder.setInsertionPointToStart(blocks[i + 1]);
170 Operation *callOp = func::CallOp::create(opBuilder, funcOp->getLoc(),
171 outTypes, funcName, args);
172 auto callResults = callOp->getResults();
// Record which call result replaces each returned value.
// NOTE(review): the declaration of argMap is not visible. If it is a
// DenseMap- or std::map-like container, insert() keeps an existing entry.
// A value live across several chunk boundaries would then stay mapped to the
// first call's result, which lives in an earlier split function. Verify
// whether this should overwrite, e.g. `argMap[outValues[k]] = callResults[k]`.
174 for (
unsigned long k = 0; k < outValues.size(); ++k)
175 argMap.insert(std::pair(outValues[k], callResults[k]));
// Finally, rewire uses in the chunk that stays in funcOp.
177 for (
auto pair : argMap)
178 replaceAllUsesInRegionWith(pair.first, pair.second, funcOp.getRegion()); // NOTE(review): the rest of the function (presumably `return success();` and closing braces) is not visible.
assert(baseType &&"element must be base type")
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.