10#include "mlir/Analysis/CFGLoopInfo.h"
11#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
12#include "mlir/Conversion/LLVMCommon/Pattern.h"
13#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
14#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/Dominance.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "llvm/ADT/TypeSwitch.h"
22#define GEN_PASS_DEF_INSERTMERGEBLOCKS
23#include "circt/Transforms/Passes.h.inc"
33 ConversionPatternRewriter &rewriter) {
34 rewriter.setInsertionPointToEnd(block);
35 auto term = block->getTerminator();
36 return llvm::TypeSwitch<Operation *, LogicalResult>(term)
37 .Case<cf::BranchOp>([&](
auto branchOp) {
38 rewriter.replaceOpWithNewOp<cf::BranchOp>(branchOp, newDest,
39 branchOp->getOperands());
42 .Case<cf::CondBranchOp>([&](
auto condBr) {
43 auto cond = condBr.getCondition();
45 Block *trueDest = condBr.getTrueDest();
46 Block *falseDest = condBr.getFalseDest();
49 if (trueDest == oldDest)
52 if (falseDest == oldDest)
55 rewriter.replaceOpWithNewOp<cf::CondBranchOp>(
56 condBr, cond, trueDest, condBr.getTrueOperands(), falseDest,
57 condBr.getFalseOperands());
60 .Default([&](Operation *op) {
61 return op->emitError(
"Unexpected terminator that cannot be handled.");
68 ConversionPatternRewriter &rewriter) {
69 auto blockArgTypes = oldSucc->getArgumentTypes();
70 SmallVector<Location> argLocs(blockArgTypes.size(), rewriter.getUnknownLoc());
72 Block *res = rewriter.createBlock(oldSucc, blockArgTypes, argLocs);
73 rewriter.create<cf::BranchOp>(rewriter.getUnknownLoc(), oldSucc,
87 DualGraph(Region &r, CFGLoopInfo &loopInfo);
89 size_t getNumPredecessors(Block *b) {
return predCnts.lookup(b); }
90 void getPredecessors(Block *b, SmallVectorImpl<Block *> &res);
92 size_t getNumSuccessors(Block *b) {
return succMap.lookup(b).size(); }
93 ArrayRef<Block *> getSuccessors(Block *b) {
94 return succMap.find(b)->getSecond();
99 Block *lookupDualBlock(Block *b);
100 DenseMap<Block *, size_t> getPredCountMapCopy() {
return predCnts; }
103 CFGLoopInfo &loopInfo;
105 DenseMap<Block *, SmallVector<Block *>> succMap;
106 DenseMap<Block *, size_t> predCnts;
110DualGraph::DualGraph(Region &r, CFGLoopInfo &loopInfo)
111 : loopInfo(loopInfo), succMap(), predCnts() {
113 CFGLoop *loop = loopInfo.getLoopFor(&b);
115 if (loop && loop->getHeader() != &b)
119 SmallVector<Block *> &succs =
120 succMap.try_emplace(&b, SmallVector<Block *>()).first->getSecond();
124 unsigned predCnt = 0;
125 for (
auto *pred : b.getPredecessors())
126 if (!loop || !loop->contains(pred))
129 if (loop && loop->getHeader() == &b)
130 loop->getExitBlocks(succs);
132 llvm::copy(b.getSuccessors(), std::back_inserter(succs));
134 predCnts.try_emplace(&b, predCnt);
138Block *DualGraph::lookupDualBlock(Block *b) {
139 CFGLoop *loop = loopInfo.getLoopFor(b);
143 return loop->getHeader();
146void DualGraph::getPredecessors(Block *b, SmallVectorImpl<Block *> &res) {
147 CFGLoop *loop = loopInfo.getLoopFor(b);
148 assert((!loop || loop->getHeader() == b) &&
149 "can only get predecessors of blocks in the graph");
151 for (
auto *pred : b->getPredecessors()) {
152 if (loop && loop->contains(pred))
155 if (CFGLoop *predLoop = loopInfo.getLoopFor(pred)) {
156 assert(predLoop->getExitBlock() &&
157 "multiple exit blocks are not yet supported");
158 res.push_back(predLoop->getHeader());
166using BlockToBlockMap = DenseMap<Block *, Block *>;
181 ConversionPatternRewriter &rewriter,
183 SmallVector<Block *> preds;
184 llvm::copy(currBlock->getPredecessors(), std::back_inserter(preds));
187 DenseMap<Block *, Block *> predsToConsider;
189 while (!preds.empty()) {
190 Block *pred = preds.pop_back_val();
191 Block *splitBlock = splitInfo.out.lookup(graph.lookupDualBlock(pred));
192 if (splitBlock == predDom)
197 if (predsToConsider.count(splitBlock) == 0) {
200 predsToConsider.try_emplace(splitBlock, pred);
205 Block *other = predsToConsider.lookup(splitBlock);
206 predsToConsider.erase(splitBlock);
214 Block *splitIn = splitInfo.in.lookup(splitBlock);
215 splitInfo.in.try_emplace(*
mergeBlock, splitIn);
217 splitInfo.out.try_emplace(*
mergeBlock, splitIn);
221 if (!predsToConsider.empty())
222 return currBlock->getParentOp()->emitError(
223 "irregular control flow is not yet supported");
229 for (
auto &info : loopInfo.getTopLevelLoops())
231 if (!info->getExitBlock())
232 return r.getParentOp()->emitError(
233 "multiple exit blocks are not yet supported");
247 ConversionPatternRewriter &rewriter) {
248 Block *entry = &r.front();
249 DominanceInfo domInfo(r.getParentOp());
251 CFGLoopInfo loopInfo(domInfo.getDomTree(&r));
256 SmallVector<Block *> stack;
257 stack.push_back(entry);
261 DualGraph graph(r, loopInfo);
264 auto predsToVisit = graph.getPredCountMapCopy();
268 while (!stack.empty()) {
269 Block *currBlock = stack.pop_back_val();
272 Block *out =
nullptr;
274 bool isMergeBlock = graph.getNumPredecessors(currBlock) > 1;
275 bool isSplitBlock = graph.getNumSuccessors(currBlock) > 1;
277 SmallVector<Block *> preds;
278 graph.getPredecessors(currBlock, preds);
281 Block *predDom = currBlock;
282 for (
auto *pred : preds) {
283 predDom = domInfo.findNearestCommonDominator(predDom, pred);
293 in = splitInfo.in.lookup(predDom);
294 }
else if (!preds.empty()) {
295 Block *pred = preds.front();
297 in = splitInfo.out.lookup(pred);
305 splitInfo.in.try_emplace(currBlock, in);
306 splitInfo.out.try_emplace(currBlock, out);
308 for (
auto *succ : graph.getSuccessors(currBlock)) {
309 auto it = predsToVisit.find(succ);
310 unsigned predsRemaining = --(it->getSecond());
313 if (predsRemaining == 0)
314 stack.push_back(succ);
323using PtrSet = SmallPtrSet<Operation *, 4>;
327 FuncOpPattern(PtrSet &rewrittenFuncs, MLIRContext *ctx)
331 matchAndRewrite(func::FuncOp op, OpAdaptor adaptor,
332 ConversionPatternRewriter &rewriter)
const override {
333 rewriter.startOpModification(op);
335 if (!op.isExternal())
337 rewriter.cancelOpModification(op);
341 rewriter.finalizeOpModification(op);
342 rewrittenFuncs.insert(op);
348 PtrSet &rewrittenFuncs;
351struct InsertMergeBlocksPass
352 :
public circt::impl::InsertMergeBlocksBase<InsertMergeBlocksPass> {
354 void runOnOperation()
override {
355 auto *ctx = &getContext();
358 PtrSet rewrittenFuncs;
359 patterns.add<FuncOpPattern>(rewrittenFuncs, ctx);
361 ConversionTarget target(*ctx);
362 target.addDynamicallyLegalOp<func::FuncOp>(
363 [&](func::FuncOp func) {
return rewrittenFuncs.contains(func); });
364 target.addLegalDialect<cf::ControlFlowDialect>();
366 if (applyPartialConversion(getOperation(), target, std::move(
patterns))
376 return std::make_unique<InsertMergeBlocksPass>();
assert(baseType &&"element must be base type")
static void mergeBlock(Block &destination, Block::iterator insertPoint, Block &source)
Moves all operations from a source block into a destination block.
static LogicalResult changeBranchTarget(Block *block, Block *oldDest, Block *newDest, ConversionPatternRewriter &rewriter)
Replaces the branch to oldDest in block's terminator with an equivalent operation that instead branches to newDest.
static FailureOr< Block * > buildMergeBlock(Block *b1, Block *b2, Block *oldSucc, ConversionPatternRewriter &rewriter)
Creates a new intermediate block that b1 and b2 branch to.
static LogicalResult preconditionCheck(Region &r, CFGLoopInfo &loopInfo)
Checks preconditions of this transformation.
static LogicalResult buildMergeBlocks(Block *currBlock, SplitInfo &splitInfo, Block *predDom, ConversionPatternRewriter &rewriter, DualGraph &graph)
Builds a binary merge block tree for the predecessors of currBlock.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createInsertMergeBlocksPass()
mlir::LogicalResult insertMergeBlocks(mlir::Region &r, mlir::ConversionPatternRewriter &rewriter)
Insert additional blocks that serve as counterparts to the blocks that diverged the control flow.