13#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
14#include "mlir/IR/Dominance.h"
15#include "mlir/IR/Matchers.h"
16#include "mlir/Pass/Pass.h"
17#include "llvm/ADT/PostOrderIterator.h"
18#include "llvm/Support/Debug.h"
20#define DEBUG_TYPE "llhd-remove-control-flow"
24#define GEN_PASS_DEF_REMOVECONTROLFLOWPASS
25#include "circt/Dialect/LLHD/Transforms/LLHDPasses.h.inc"
/// Wraps an SSA value as a (possibly dynamic) condition. The value is stored
/// with tag 0; if it matches a constant one (`m_One()`) or constant zero
/// (`m_Zero()`), the condition is normalized to the constant
/// `Condition(true)` / `Condition(false)` instead.
// NOTE(review): some lines of this constructor are omitted from this listing
// (e.g. a presumable null-value guard) — confirm against the full source.
42 Condition(Value value) : pair(value, 0) {
44 if (matchPattern(value, m_One()))
45 *
this = Condition(
true);
46 if (matchPattern(value, m_Zero()))
47 *
this = Condition(
false);
/// Creates a constant condition: null pointer with tag 1 for true and tag 2
/// for false (tag 0 is used for a wrapped SSA value).
50 Condition(
bool konst) : pair(nullptr, konst ? 1 : 2) {}
/// Returns true if the condition is set, i.e. it holds either an SSA value or
/// one of the constant tags. A null pointer with tag 0 is the "empty"
/// condition — presumably what a default-constructed Condition holds, e.g. the
/// result of a failed `decisions.lookup(...)`.
52 explicit operator bool()
const {
53 return pair.getPointer() !=
nullptr || pair.getInt() != 0;
/// True iff this is the constant-true condition (null pointer, tag 1).
56 bool isTrue()
const {
return !pair.getPointer() && pair.getInt() == 1; }
/// True iff this is the constant-false condition (null pointer, tag 2).
57 bool isFalse()
const {
return !pair.getPointer() && pair.getInt() == 2; }
/// Returns the wrapped SSA value; null for constant and empty conditions.
58 Value getValue()
const {
return pair.getPointer(); }
/// Returns an SSA value for this condition at `loc`. For a non-constant
/// condition this is simply the wrapped value.
// NOTE(review): the handling of constant conditions is omitted from this
// listing; presumably it uses `builder` to create constant values — confirm.
62 Value materialize(OpBuilder &builder, Location loc)
const {
67 return pair.getPointer();
/// Logical OR of this condition and `other`. Either side being constant true
/// short-circuits the computation. Otherwise, for two dynamic values, a
/// `comb::OrOp` is built with `createOrFold` at the location of this value.
// NOTE(review): the short-circuit results and the constant-false cases are
// omitted from this listing.
70 Condition orWith(Condition other, OpBuilder &builder)
const {
71 if (isTrue() || other.isTrue())
77 return builder.createOrFold<
comb::OrOp>(getValue().getLoc(), getValue(),
/// Logical AND of this condition and `other`. Either side being constant false
/// short-circuits the computation. Otherwise, for two dynamic values, a
/// `comb::AndOp` is built with `createOrFold` at the location of this value.
// NOTE(review): the short-circuit results and the constant-true cases are
// omitted from this listing.
81 Condition andWith(Condition other, OpBuilder &builder)
const {
82 if (isFalse() || other.isFalse())
88 return builder.createOrFold<
comb::AndOp>(getValue().getLoc(), getValue(),
/// Logical negation. A dynamic condition is negated via
/// `comb::createOrFoldNot` at the location of the wrapped value.
// NOTE(review): the constant cases are omitted from this listing; presumably
// they swap true and false without creating IR — confirm.
92 Condition inverted(OpBuilder &builder)
const {
97 return comb::createOrFoldNot(getValue().
getLoc(), getValue(), builder);
/// Packed representation: an SSA value (or null) plus a 2-bit tag.
/// Tag 0 = wrapped value (or empty if the pointer is null), 1 = constant true,
/// 2 = constant false.
101 llvm::PointerIntPair<Value, 2> pair;
/// Parameters of getBranchDecisionsFromDominatorToTarget (its name line is
/// omitted from this listing), which computes the condition under which
/// control flows from `dominator` to `target`:
///   builder   - used to build the OR/AND/NOT logic that combines conditions.
///   dominator - block from which control flow is traced; its own decision is
///               seeded as constant true.
///   target    - block whose reaching condition is returned.
///   decisions - cache keyed by (dominator, block); an entry is added for every
///               block visited between `dominator` and `target`.
113 OpBuilder &builder, Block *dominator, Block *target,
114 SmallDenseMap<std::pair<Block *, Block *>, Condition> &decisions) {
// Reuse a previously computed decision if one is cached.
115 if (
auto decision = decisions.lookup({dominator, target}))
// Pre-marking the dominator as visited stops the inverse walk below from
// going past it.
118 SmallPtrSet<Block *, 8> visitedBlocks;
119 visitedBlocks.insert(dominator);
// Control trivially reaches the dominator from itself.
120 if (
auto &decision = decisions[{dominator, dominator}]; !decision)
121 decision = Condition(
true);
// Walk from `target` back through its predecessors, visiting each block after
// its predecessors, so every predecessor's decision is known when a block is
// processed.
128 for (
auto *block : llvm::inverse_post_order_ext(target, visitedBlocks)) {
// A block is reached if it is reached through any of its predecessors.
129 auto merged = Condition(
false);
130 for (
auto *pred : block->getPredecessors()) {
131 auto predDecision = decisions.lookup({dominator, pred});
// run() only admits cf.br and cf.cond_br as branching terminators, so a
// predecessor with other than exactly one successor is a cond_br.
133 if (pred->getTerminator()->getNumSuccessors() != 1) {
134 auto condBr = cast<cf::CondBranchOp>(pred->getTerminator());
// Both edges lead to `block`: the branch condition is irrelevant.
135 if (condBr.getTrueDest() == condBr.getFalseDest()) {
136 merged = merged.orWith(predDecision, builder);
// Only one edge leads here: require the branch condition, negated when the
// edge taken is the false edge.
138 auto cond = Condition(condBr.getCondition());
139 if (condBr.getFalseDest() == block)
140 cond = cond.inverted(builder);
141 merged = merged.orWith(cond.andWith(predDecision, builder), builder);
// Unconditional branch: reached whenever the predecessor is reached.
144 merged = merged.orWith(predDecision, builder);
148 decisions.insert({{dominator, block}, merged});
151 return decisions.lookup({dominator, target});
161 CFRemover(Region ®ion) : region(region) {}
/// Blocks of the region, ordered so that each block comes after all of its
/// predecessors; filled in by run().
168 SmallVector<Block *> sortedBlocks;
/// Dominance information for the region's parent op; run() recomputes it
/// after modifying the CFG.
170 DominanceInfo domInfo;
/// Flattens the control flow of `region` into straight-line code in its entry
/// block. Block arguments are replaced by values selected with multiplexers
/// (comb::MuxOp) driven by the branch conditions, and all blocks are merged
/// into the entry block. The transformation gives up when it finds a loop, an
/// op with memory effects, or a terminator other than yield / cf.br /
/// cf.cond_br. These checks run before any IR is modified.
// NOTE(review): many lines of this definition are omitted from this listing
// (including the early returns after the "giving up" messages); the comments
// below describe only the visible code. Fixed here: `®ion` was a mis-encoded
// `&region` in the createBlock call.
174void CFRemover::run() {
175 LLVM_DEBUG(llvm::dbgs() <<
"Removing control flow in " << region.getLoc()
// Phase 1: order the blocks so that every block follows all its predecessors,
// and verify the region is loop-free. `ipoSet` is the shared visited set of
// the inverse post-order walks; `visitedBlocks` holds blocks already emitted
// into `sortedBlocks`.
181 SmallVector<YieldOp, 2> yieldOps;
182 SmallPtrSet<Block *, 8> visitedBlocks, ipoSet;
183 for (
auto &block : region) {
184 for (
auto *ipoBlock :
llvm::inverse_post_order_ext(&block, ipoSet)) {
// A predecessor that has not been emitted yet is still on the walk's stack,
// which means this block is part of a cycle.
185 if (!llvm::all_of(ipoBlock->getPredecessors(), [&](
auto *pred) {
186 return visitedBlocks.contains(pred);
188 LLVM_DEBUG(llvm::dbgs() <<
"- Loop detected, giving up\n");
191 visitedBlocks.insert(ipoBlock);
192 sortedBlocks.push_back(ipoBlock);
// All ops end up executing unconditionally in the entry block, so none of them
// may have memory effects.
196 for (
auto &op : block) {
197 if (!isMemoryEffectFree(&op)) {
198 LLVM_DEBUG(llvm::dbgs() <<
"- Has side effects, giving up\n");
// Only yields and (conditional) branches can be turned into muxes.
204 if (!isa<YieldOp, cf::BranchOp, cf::CondBranchOp>(block.getTerminator())) {
205 LLVM_DEBUG(llvm::dbgs()
206 <<
"- Has unsupported terminator "
207 << block.getTerminator()->getName() <<
", giving up\n");
// Remember every yield so they can be unified below.
212 if (
auto yieldOp = dyn_cast<YieldOp>(block.getTerminator()))
213 yieldOps.push_back(yieldOp);
// Phase 2: ensure a single yield. With several yields, create a trailing
// block whose arguments receive the yielded values and which yields them,
// and insert a branch to it in front of each original yield (their removal
// is not visible in this listing).
// NOTE(review): `yieldOps[0]` assumes at least one yield was found; any guard
// for the empty case is not visible here — confirm.
217 auto yieldOp = yieldOps[0];
218 if (yieldOps.size() > 1) {
219 LLVM_DEBUG(llvm::dbgs() <<
"- Creating single yield block\n");
220 OpBuilder builder(region.getContext());
221 SmallVector<Location> locs(yieldOps[0].getNumOperands(), region.getLoc());
222 auto *yieldBlock = builder.createBlock(&region, region.end(),
223 yieldOps[0].getOperandTypes(), locs);
// Appending keeps `sortedBlocks` ordered: all predecessors of the new block
// are already in it.
224 sortedBlocks.push_back(yieldBlock);
226 YieldOp::create(builder, region.getLoc(), yieldBlock->getArguments());
227 for (
auto yieldOp : yieldOps) {
228 builder.setInsertionPoint(yieldOp);
229 cf::BranchOp::create(builder, yieldOp.getLoc(), yieldBlock,
230 yieldOp.getOperands());
// The CFG may have changed above, so recompute dominance.
236 domInfo = DominanceInfo(region.getParentOp());
// Phase 3: merge the blocks into the entry block (the first block of the
// order), predecessors before successors.
242 auto *entryBlock = sortedBlocks.front();
243 for (
auto *block : sortedBlocks) {
// Blocks unreachable from the entry are not merged.
244 if (!domInfo.isReachableFromEntry(block))
247 llvm::dbgs() <<
"- Merging block ";
248 block->printAsOperand(llvm::dbgs());
249 llvm::dbgs() <<
"\n";
// Nearest common dominator of all reachable predecessors. Branch decisions are
// traced from this block to each predecessor.
255 auto *domBlock = block;
256 for (
auto *pred : block->getPredecessors())
257 if (domInfo.isReachableFromEntry(pred))
258 domBlock = domInfo.findNearestCommonDominator(domBlock, pred);
260 llvm::dbgs() <<
" - Common dominator: ";
261 domBlock->printAsOperand(llvm::dbgs());
262 llvm::dbgs() <<
"\n";
// New logic goes in front of the entry block's terminator, where all merged
// ops end up.
266 OpBuilder builder(entryBlock->getTerminator());
267 SmallVector<Value> mergedArgs;
268 SmallPtrSet<Block *, 4> seenPreds;
// Fold the branch operands of each distinct, reachable predecessor into
// `mergedArgs`.
269 for (
auto *pred : block->getPredecessors()) {
271 if (!seenPreds.insert(pred).second)
275 if (!domInfo.isReachableFromEntry(pred))
// Merges one predecessor's branch operands into `mergedArgs`. The decision of
// reaching `pred` from `domBlock` is combined (AND) with the branch condition
// `cond`, which is inverted when requested. Where that decision holds, `arg`
// is selected; constant decisions avoid creating a mux.
// NOTE(review): the empty-`mergedArgs` seeding path and the `invCond` /
// unset-`cond` guards are omitted from this listing.
283 auto mergeArgs = [&](ValueRange args, Condition cond,
bool invCond) {
284 if (mergedArgs.empty()) {
289 builder, domBlock, pred, decisionCache);
292 cond = cond.inverted(builder);
293 decision = decision.andWith(cond, builder);
295 for (
auto [mergedArg, arg] :
llvm::zip(mergedArgs, args)) {
296 if (decision.isTrue())
298 else if (decision.isFalse())
302 arg.getLoc(), decision.materialize(builder, arg.getLoc()), arg,
// Dispatch on the kind of the predecessor's terminator.
308 if (
auto condBrOp = dyn_cast<cf::CondBranchOp>(pred->getTerminator())) {
// Both edges lead here: select the operands with a mux on the branch condition
// and merge them without an extra condition.
309 if (condBrOp.getTrueDest() == condBrOp.getFalseDest()) {
313 LLVM_DEBUG(llvm::dbgs() <<
" - Both from " << condBrOp <<
"\n");
314 SmallVector<Value> mergedOperands;
315 mergedOperands.reserve(block->getNumArguments());
316 for (
auto [trueArg, falseArg] :
317 llvm::zip(condBrOp.getTrueDestOperands(),
318 condBrOp.getFalseDestOperands())) {
319 mergedOperands.push_back(builder.createOrFold<
comb::MuxOp>(
320 trueArg.getLoc(), condBrOp.getCondition(), trueArg, falseArg));
322 mergeArgs(mergedOperands, Value{},
false);
323 }
// Only the true edge leads here.
else if (condBrOp.getTrueDest() == block) {
325 LLVM_DEBUG(llvm::dbgs() <<
" - True from " << condBrOp <<
"\n");
326 mergeArgs(condBrOp.getTrueDestOperands(), condBrOp.getCondition(),
// Only the false edge leads here.
330 LLVM_DEBUG(llvm::dbgs() <<
" - False from " << condBrOp <<
"\n");
331 mergeArgs(condBrOp.getFalseDestOperands(), condBrOp.getCondition(),
// Unconditional branch.
335 auto brOp = cast<cf::BranchOp>(pred->getTerminator());
336 LLVM_DEBUG(llvm::dbgs() <<
" - From " << brOp <<
"\n");
337 mergeArgs(brOp.getDestOperands(), Value{},
false);
// The block's arguments are now fully determined by the merged values.
340 for (
auto [blockArg, mergedArg] :
341 llvm::zip(block->getArguments(), mergedArgs))
342 blockArg.replaceAllUsesWith(mergedArg);
// Move the block's ops in front of the entry block's terminator (the end of
// the spliced range is not visible in this listing).
345 if (block != entryBlock)
346 entryBlock->getOperations().splice(--entryBlock->end(),
347 block->getOperations(), block->begin(),
// Phase 4: make the single yield the entry block's terminator.
352 if (yieldOp != entryBlock->getTerminator()) {
353 yieldOp->moveBefore(entryBlock->getTerminator());
354 entryBlock->getTerminator()->erase();
// Clean up the merged non-entry blocks (the loop bodies are not visible in
// this listing).
360 for (
auto *block : sortedBlocks)
361 if (block != entryBlock)
363 for (
auto *block : sortedBlocks)
364 if (block != entryBlock)
/// Pass implementation. The base class is generated from the pass definition
/// pulled in via GEN_PASS_DEF_REMOVECONTROLFLOWPASS and LLHDPasses.h.inc.
375struct RemoveControlFlowPass
376 :
public llhd::impl::RemoveControlFlowPassBase<RemoveControlFlowPass> {
377 void runOnOperation()
override;
/// Runs a fresh CFRemover on the body region of each CombinationalOp returned
/// by `getOps<CombinationalOp>()` on the pass's root operation.
// NOTE(review): getOps only yields ops directly in the root op's region(s),
// not deeply nested ones — confirm that this is the intended scope.
381void RemoveControlFlowPass::runOnOperation() {
382 for (
auto op : getOperation().getOps<CombinationalOp>())
383 CFRemover(op.getBody()).
run();
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
static Condition getBranchDecisionsFromDominatorToTarget(OpBuilder &builder, Block *dominator, Block *target, SmallDenseMap< std::pair< Block *, Block * >, Condition > &decisions)
Compute the branch decisions that cause control to flow from the dominator to the target block.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)