CIRCT 22.0.0git
Loading...
Searching...
No Matches
SplitFuncs.cpp
Go to the documentation of this file.
1//===- SplitFuncs.cpp -----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Analysis/Liveness.h"
12#include "mlir/Dialect/Func/IR/FuncOps.h"
13#include "mlir/IR/Builders.h"
14#include "mlir/IR/Value.h"
15#include "mlir/Pass/Pass.h"
16#include "mlir/Transforms/RegionUtils.h"
17#include "llvm/ADT/STLExtras.h"
18#include "llvm/Support/MathExtras.h"
19#include <string>
20
21#define DEBUG_TYPE "arc-split-funcs"
22
23namespace circt {
24namespace arc {
25#define GEN_PASS_DEF_SPLITFUNCS
26#include "circt/Dialect/Arc/ArcPasses.h.inc"
27} // namespace arc
28} // namespace circt
29
30using namespace mlir;
31using namespace circt;
32using namespace func;
33
34//===----------------------------------------------------------------------===//
35// Pass Implementation
36//===----------------------------------------------------------------------===//
37
38namespace {
39struct SplitFuncsPass : public arc::impl::SplitFuncsBase<SplitFuncsPass> {
40 using arc::impl::SplitFuncsBase<SplitFuncsPass>::SplitFuncsBase;
41 void runOnOperation() override;
42 LogicalResult lowerFunc(FuncOp funcOp);
43
44 SymbolTable *symbolTable;
45};
46} // namespace
47
48void SplitFuncsPass::runOnOperation() {
49 symbolTable = &getAnalysis<SymbolTable>();
50 for (auto op : llvm::make_early_inc_range(getOperation().getOps<FuncOp>())) {
51 // Ignore extern functions
52 if (op.isExternal())
53 continue;
54 if (failed(lowerFunc(op)))
55 return signalPassFailure();
56 }
57}
58
59LogicalResult SplitFuncsPass::lowerFunc(FuncOp funcOp) {
60 if (splitBound == 0)
61 return funcOp.emitError("Cannot split functions into functions of size 0.");
62 if (funcOp.getBody().getBlocks().size() > 1)
63 return funcOp.emitError("Regions with multiple blocks are not supported.");
64 assert(funcOp->getNumRegions() == 1);
65 unsigned numOps = funcOp.front().getOperations().size();
66 if (numOps < splitBound)
67 return success();
68 int numBlocks = llvm::divideCeil(numOps, splitBound);
69 OpBuilder opBuilder(funcOp->getContext());
70 SmallVector<Block *> blocks;
71 Block *frontBlock = &(funcOp.getBody().front());
72 blocks.push_back(frontBlock);
73 for (int i = 0; i < numBlocks - 1; ++i) {
74 SmallVector<Location> locs(frontBlock->getNumArguments(), funcOp.getLoc());
75 auto *block = opBuilder.createBlock(&(funcOp.getBody()), {},
76 frontBlock->getArgumentTypes(), locs);
77 blocks.push_back(block);
78 }
79
80 unsigned numOpsInBlock = 0;
81 SmallVector<Block *>::iterator blockIter = blocks.begin();
82 for (auto &op : llvm::make_early_inc_range(*frontBlock)) {
83 if (numOpsInBlock >= splitBound) {
84 ++blockIter;
85 numOpsInBlock = 0;
86 opBuilder.setInsertionPointToEnd(*blockIter);
87 }
88 ++numOpsInBlock;
89 // Don't bother moving ops to the original block
90 if (*blockIter == (frontBlock))
91 continue;
92 // Remove op from original block and insert in new block
93 op.remove();
94 (*blockIter)->push_back(&op);
95 }
96 DenseMap<Value, Value> argMap;
97 // Move function arguments to the block that will stay in the function
98 for (unsigned argIndex = 0; argIndex < frontBlock->getNumArguments();
99 ++argIndex) {
100 auto oldArg = frontBlock->getArgument(argIndex);
101 auto newArg = blocks.back()->getArgument(argIndex);
102 replaceAllUsesInRegionWith(oldArg, newArg, funcOp.getBody());
103 }
104 Liveness liveness(funcOp);
105 auto argTypes = blocks.back()->getArgumentTypes();
106 auto args = blocks.back()->getArguments();
107
108 // Create return ops
109 for (int i = blocks.size() - 2; i >= 0; --i) {
110 liveness = Liveness(funcOp);
111 Block *currentBlock = blocks[i];
112 Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
113 SmallVector<Value> outValues;
114 llvm::for_each(liveOut, [&outValues](auto el) {
115 if (!isa<BlockArgument>(el))
116 outValues.push_back(el);
117 });
118 opBuilder.setInsertionPointToEnd(currentBlock);
119 ReturnOp::create(opBuilder, funcOp->getLoc(), outValues);
120 }
121 // Create and populate new FuncOps
122 for (long unsigned i = 0; i < blocks.size() - 1; ++i) {
123 Block *currentBlock = blocks[i];
124 Liveness::ValueSetT liveOut = liveness.getLiveIn(blocks[i + 1]);
125 SmallVector<Type> outTypes;
126 SmallVector<Value> outValues;
127 llvm::for_each(liveOut, [&outTypes, &outValues](auto el) {
128 if (!isa<BlockArgument>(el)) {
129 outValues.push_back(el);
130 outTypes.push_back(el.getType());
131 }
132 });
133 opBuilder.setInsertionPoint(funcOp);
134 SmallString<64> funcName;
135 funcName.append(funcOp.getName());
136 funcName.append("_split_func");
137 funcName.append(std::to_string(i));
138 auto newFunc =
139 FuncOp::create(opBuilder, funcOp->getLoc(), funcName,
140 opBuilder.getFunctionType(argTypes, outTypes));
141 ++numFuncsCreated;
142 symbolTable->insert(newFunc);
143 auto *funcBlock = newFunc.addEntryBlock();
144 for (auto &op : make_early_inc_range(currentBlock->getOperations())) {
145 op.remove();
146 funcBlock->push_back(&op);
147 }
148 currentBlock->erase();
149
150 opBuilder.setInsertionPointToEnd(funcBlock);
151 for (auto [j, el] : llvm::enumerate(args))
152 replaceAllUsesInRegionWith(el, newFunc.getArgument(j),
153 newFunc.getRegion());
154 for (auto pair : argMap)
155 replaceAllUsesInRegionWith(pair.first, pair.second, newFunc.getRegion());
156 opBuilder.setInsertionPointToStart(blocks[i + 1]);
157 Operation *callOp = func::CallOp::create(opBuilder, funcOp->getLoc(),
158 outTypes, funcName, args);
159 auto callResults = callOp->getResults();
160 argMap.clear();
161 for (unsigned long k = 0; k < outValues.size(); ++k)
162 argMap.insert(std::pair(outValues[k], callResults[k]));
163 }
164 for (auto pair : argMap)
165 replaceAllUsesInRegionWith(pair.first, pair.second, funcOp.getRegion());
166 return success();
167}
assert(baseType &&"element must be base type")
Definition arc.py:1
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.