11 #include "mlir/Analysis/Liveness.h"
12 #include "mlir/Dialect/Func/IR/FuncOps.h"
13 #include "mlir/IR/Builders.h"
14 #include "mlir/IR/Value.h"
15 #include "mlir/Pass/Pass.h"
16 #include "mlir/Transforms/RegionUtils.h"
17 #include "llvm/ADT/STLExtras.h"
18 #include "llvm/Support/MathExtras.h"
21 #define DEBUG_TYPE "arc-split-funcs"
25 #define GEN_PASS_DEF_SPLITFUNCS
26 #include "circt/Dialect/Arc/ArcPasses.h.inc"
31 using namespace circt;
39 struct SplitFuncsPass :
public arc::impl::SplitFuncsBase<SplitFuncsPass> {
40 using arc::impl::SplitFuncsBase<SplitFuncsPass>::SplitFuncsBase;
41 void runOnOperation()
override;
42 LogicalResult lowerFunc(FuncOp funcOp);
44 SymbolTable *symbolTable;
48 void SplitFuncsPass::runOnOperation() {
49 symbolTable = &getAnalysis<SymbolTable>();
50 for (
auto op : llvm::make_early_inc_range(getOperation().getOps<FuncOp>()))
51 if (failed(lowerFunc(op)))
52 return signalPassFailure();
55 LogicalResult SplitFuncsPass::lowerFunc(FuncOp funcOp) {
57 return funcOp.emitError(
"Cannot split functions into functions of size 0.");
58 if (funcOp.getBody().getBlocks().size() > 1)
59 return funcOp.emitError(
"Regions with multiple blocks are not supported.");
60 assert(funcOp->getNumRegions() == 1);
61 unsigned numOps = funcOp.front().getOperations().size();
62 if (numOps < splitBound)
64 int numBlocks = llvm::divideCeil(numOps, splitBound);
65 OpBuilder opBuilder(funcOp->getContext());
66 SmallVector<Block *> blocks;
67 Block *frontBlock = &(funcOp.getBody().front());
68 blocks.push_back(frontBlock);
69 for (
int i = 0; i < numBlocks - 1; ++i) {
70 SmallVector<Location> locs(frontBlock->getNumArguments(), funcOp.getLoc());
71 auto *block = opBuilder.createBlock(&(funcOp.getBody()), {},
72 frontBlock->getArgumentTypes(), locs);
73 blocks.push_back(block);
76 unsigned numOpsInBlock = 0;
77 SmallVector<Block *>::iterator blockIter = blocks.begin();
78 for (
auto &op : llvm::make_early_inc_range(*frontBlock)) {
79 if (numOpsInBlock >= splitBound) {
82 opBuilder.setInsertionPointToEnd(*blockIter);
86 if (*blockIter == (frontBlock))
90 (*blockIter)->push_back(&op);
92 DenseMap<Value, Value> argMap;
94 for (
unsigned argIndex = 0; argIndex < frontBlock->getNumArguments();
96 auto oldArg = frontBlock->getArgument(argIndex);
97 auto newArg = blocks.back()->getArgument(argIndex);
98 replaceAllUsesInRegionWith(oldArg, newArg, funcOp.getBody());
100 Liveness liveness(funcOp);
101 auto argTypes = blocks.back()->getArgumentTypes();
102 auto args = blocks.back()->getArguments();
105 for (
int i = blocks.size() - 2; i >= 0; --i) {
106 liveness = Liveness(funcOp);
107 Block *currentBlock = blocks[i];
108 Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
109 SmallVector<Value> outValues;
110 llvm::for_each(liveOut, [&outValues](
auto el) {
111 if (!isa<BlockArgument>(el))
112 outValues.push_back(el);
114 opBuilder.setInsertionPointToEnd(currentBlock);
115 opBuilder.create<ReturnOp>(funcOp->getLoc(), outValues);
118 for (
long unsigned i = 0; i < blocks.size() - 1; ++i) {
119 Block *currentBlock = blocks[i];
120 Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
121 SmallVector<Type> outTypes;
122 SmallVector<Value> outValues;
123 llvm::for_each(liveOut, [&outTypes, &outValues](
auto el) {
124 if (!isa<BlockArgument>(el)) {
125 outValues.push_back(el);
126 outTypes.push_back(el.getType());
129 opBuilder.setInsertionPoint(funcOp);
130 SmallString<64> funcName;
131 funcName.append(funcOp.getName());
132 funcName.append(
"_split_func");
133 funcName.append(std::to_string(i));
135 opBuilder.create<FuncOp>(funcOp->getLoc(), funcName,
136 opBuilder.getFunctionType(argTypes, outTypes));
138 symbolTable->insert(newFunc);
139 auto *funcBlock = newFunc.addEntryBlock();
140 for (
auto &op : make_early_inc_range(currentBlock->getOperations())) {
142 funcBlock->push_back(&op);
144 currentBlock->erase();
146 opBuilder.setInsertionPointToEnd(funcBlock);
147 for (
auto [j, el] : llvm::enumerate(args))
148 replaceAllUsesInRegionWith(el, newFunc.getArgument(j),
149 newFunc.getRegion());
150 for (
auto pair : argMap)
151 replaceAllUsesInRegionWith(pair.first, pair.second, newFunc.getRegion());
152 opBuilder.setInsertionPointToStart(blocks[i + 1]);
153 Operation *callOp = opBuilder.create<func::CallOp>(
154 funcOp->getLoc(), outTypes, funcName, args);
155 auto callResults = callOp->getResults();
157 for (
unsigned long k = 0; k < outValues.size(); ++k)
158 argMap.insert(std::pair(outValues[k], callResults[k]));
160 for (
auto pair : argMap)
161 replaceAllUsesInRegionWith(pair.first, pair.second, funcOp.getRegion());
assert(baseType &&"element must be base type")
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.