11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "llvm/Support/Debug.h"
14#define DEBUG_TYPE "arc-merge-ifs"
18#define GEN_PASS_DEF_MERGEIFSPASS
19#include "circt/Dialect/Arc/ArcPasses.h.inc"
/// Pass that, per single block, moves ops closer to their users (`sinkOps`)
/// and then merges compatible `scf.if` ops (`mergeIfs`).
/// NOTE(review): this excerpt is missing several lines of the original
/// definition (e.g. original lines 33-35) as well as the closing brace.
28struct MergeIfsPass :
public arc::impl::MergeIfsPassBase<MergeIfsPass> {
/// Pass entry point; visits the regions nested under the root operation.
29 void runOnOperation()
override;
/// Processes a single block of a single-block region.
30 void runOnBlock(
Block &rootBlock);
/// Moves ops in `rootBlock` towards their users, possibly into the nested
/// block that contains all of them.
31 void sinkOps(
Block &rootBlock);
/// Merges `scf.if` ops in `rootBlock` that share the same condition.
32 void mergeIfs(
Block &rootBlock);
/// Per-if sets of state/memory values written / read by ops nested anywhere
/// inside that if; populated and updated by `mergeIfs`.
36 DenseMap<scf::IfOp, DenseSet<Value>> ifWrites, ifReads;
/// Per-if flag recording whether a side-effecting op was found inside it
/// (the exact condition is on a line not visible in this excerpt).
37 DenseMap<scf::IfOp, bool> ifHasSideEffects;
/// Walks every region nested under the pass's root operation in pre-order and
/// runs `runOnBlock` on each region that has exactly one block and may have
/// SSA dominance.
/// NOTE(review): original lines 42-43 and the end of this function are
/// missing from this excerpt.
41void MergeIfsPass::runOnOperation() {
44 getOperation()->walk<WalkOrder::PreOrder>([&](Region *region) {
// Only single-block regions are handled, presumably because the sinking and
// merging logic reasons about op order within one block — confirm.
45 if (region->hasOneBlock() && mlir::mayHaveSSADominance(*region))
46 runOnBlock(region->front());
/// Processes one block. Only the debug log naming the block's parent op is
/// visible here.
/// NOTE(review): the rest of the body (presumably invoking `sinkOps` and
/// `mergeIfs`) is missing from this excerpt — confirm against the full file.
52void MergeIfsPass::runOnBlock(Block &rootBlock) {
53 LLVM_DEBUG(llvm::dbgs() <<
"Running on block in "
54 << rootBlock.getParentOp()->getName() <<
"\n");
// Body fragment of `getPointerWrittenByOp`: returns the state/memory value
// being written by an op. The function signature and the fall-through return
// for other ops are not visible in this excerpt.
// A StateWriteOp writes its state operand.
65 if (
auto write = dyn_cast<StateWriteOp>(op))
66 return write.getState();
// A MemoryWriteOp writes its memory operand.
67 if (
auto write = dyn_cast<MemoryWriteOp>(op))
68 return write.getMemory();
// Body fragment of `getPointerReadByOp`: returns the state/memory value
// being read by an op. The function signature and the fall-through return
// for other ops are not visible in this excerpt.
// A StateReadOp reads its state operand.
74 if (
auto read = dyn_cast<StateReadOp>(op))
75 return read.getState();
// A MemoryReadOp reads its memory operand.
76 if (
auto read = dyn_cast<MemoryReadOp>(op))
77 return read.getMemory();
// Body fragment of `hasSideEffects`: checks whether an op has side effects,
// looking only at the op itself and not at its nested ops. The function
// signature and any earlier lines are not visible in this excerpt.
// Ops that model their memory effects report side effects unless they
// declare none.
85 if (
auto memEffects = dyn_cast<MemoryEffectOpInterface>(op))
86 return !memEffects.hasNoEffect();
// Other ops are treated as side-effecting unless their effects are
// recursive, i.e. derived from nested ops, which are not inspected here.
87 return !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
/// Ordering key assigned to ops by `sinkOps`. The pair is compared
/// lexicographically (std::pair operators).
95using OpOrder = std::pair<unsigned, unsigned>;
// Fragment of `OpAndOrder`: an (optional) op together with its OpOrder.
// NOTE(review): the struct header and parts of the bodies below (original
// lines 108-111, 114 onwards) are missing from this excerpt.
// A null `op` means "no op recorded yet".
100 Operation *op =
nullptr;
101 OpOrder order = {0, 0};
// True iff an op has been recorded.
103 explicit operator bool()
const {
return op; }
// Update towards the smaller order: taken if nothing is recorded yet, or if
// `other` holds an op with a strictly smaller order. The update statement
// itself is not visible here (presumably `*this = other`).
106 void minimize(
const OpAndOrder &other) {
107 if (!op || (other.op && other.order < order))
// Update towards the larger order: taken if nothing is recorded yet, or if
// `other` holds an op with a strictly larger order. The update statement
// itself is not visible here (presumably `*this = other`).
112 void maximize(
const OpAndOrder &other) {
113 if (!op || (other.op && other.order > order))
/// Moves ops of `rootBlock` down towards their users. If all users are inside
/// one block nested under a `rootBlock` op, the op may be moved into that
/// block. Ops are never moved past a later op that writes a state/memory
/// value they read, nor past a later side-effecting op.
/// NOTE(review): many original lines of this function are missing from this
/// excerpt (conditions, `earliest` declaration, loop exits, closing braces).
120void MergeIfsPass::sinkOps(Block &rootBlock) {
// Order key of each visited op. Ops are visited bottom-up, so a larger
// `first` component means an op that sits earlier in the block.
123 DenseMap<Operation *, OpOrder> opOrder;
// Per anchor op, the op that sunk ops were last inserted before.
127 DenseMap<Operation *, Operation *> insertionPoints;
// For each written state/memory value, the closest later op that writes it.
130 DenseMap<Value, Operation *> nextWrite;
// The closest later op containing a side effect.
132 Operation *nextSideEffect =
nullptr;
// Visit ops in reverse; early-inc range because the current op may be moved.
134 for (
auto &op :
llvm::make_early_inc_range(
llvm::reverse(rootBlock))) {
136 auto order = OpOrder{opOrder.size() + 1, 0};
137 opOrder[&op] = order;
// Scan the op and its nested ops for writes and other side effects.
140 bool opContainsWrites =
false;
143 op.walk([&](Operation *subOp) {
// Presumably guarded by `getPointerWrittenByOp(subOp)` — not visible here.
145 nextWrite[ptr] = &op;
146 opContainsWrites =
true;
147 }
else if (!isa<StateReadOp, MemoryReadOp>(subOp) &&
149 nextSideEffect = &op;
// The op must stay before any op recorded in `moveLimit`.
154 OpAndOrder moveLimit;
// Presumably for each value this op reads: limit at the next write to it.
157 if (
auto *write = nextWrite.lookup(ptr))
158 moveLimit.maximize({write, opOrder.lookup(write)});
161 moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)});
162 }
else if (opContainsWrites || nextSideEffect == &op) {
// Find the single nested block (if any) that contains all users.
168 Block *allUsesInBlock =
nullptr;
169 for (
auto *user : op.getUsers()) {
// A user directly in rootBlock rules out sinking into a nested block.
172 if (user->getBlock() == &rootBlock) {
173 allUsesInBlock =
nullptr;
// Climb to the ancestor whose parent op lives in rootBlock, so that
// `user->getBlock()` is a block directly nested in a rootBlock op.
178 while (user->getParentOp()->getBlock() != &rootBlock)
179 user = user->getParentOp();
183 if (!allUsesInBlock) {
184 allUsesInBlock = user->getBlock();
185 }
else if (allUsesInBlock != user->getBlock()) {
186 allUsesInBlock =
nullptr;
// `earliest` becomes the topmost rootBlock op the op must precede: the
// parent of the common user block, or else the topmost (ancestor of a) user.
194 if (allUsesInBlock) {
195 earliest.op = allUsesInBlock->getParentOp();
196 earliest.order = opOrder.lookup(earliest.op);
198 for (
auto *user : op.getUsers()) {
199 while (user->getBlock() != &rootBlock)
200 user = user->getParentOp();
202 earliest.maximize({user, opOrder.lookup(user)});
// Never move past a write/side effect the op must stay ahead of.
207 earliest.maximize(moveLimit);
// If nothing else constrains the op, move it into the users' block.
213 if (allUsesInBlock && allUsesInBlock->getParentOp() == earliest.op) {
214 op.moveBefore(allUsesInBlock, allUsesInBlock->begin());
217 LLVM_DEBUG(llvm::dbgs() <<
"- Sunk " << op <<
"\n");
// Otherwise insert before `earliest.op`. Reuse the previous insertion point
// for this anchor and step back over preceding ops that share the same
// `first` order component (ops already placed in front of this anchor).
224 auto &insertionPoint = insertionPoints[earliest.op];
225 if (insertionPoint) {
226 auto order = opOrder.lookup(insertionPoint);
227 assert(order.first == earliest.order.first);
228 assert(order.second >= earliest.order.second);
229 earliest.op = insertionPoint;
230 earliest.order = order;
232 while (
auto *prevOp = earliest.op->getPrevNode()) {
233 auto order = opOrder.lookup(prevOp);
234 if (order.first != earliest.order.first)
236 assert(order.second > earliest.order.second);
237 earliest.op = prevOp;
238 earliest.order = order;
240 insertionPoint = earliest.op;
243 if (op.getNextNode() != earliest.op) {
244 LLVM_DEBUG(llvm::dbgs() <<
"- Moved " << op <<
"\n");
245 op.moveBefore(earliest.op);
// The moved op takes its order from the insertion point. The assert guards
// the `second` component against overflow; the increment itself is
// presumably on the missing line 255 — confirm.
253 order = earliest.order;
254 assert(order.second <
unsigned(-1));
256 opOrder[&op] = order;
/// Merges `scf.if` ops of `rootBlock` into the preceding if when both
/// conditions:
/// - they have the same condition value, and
/// - neither has results, and
/// - both or neither have an else region.
/// Ops between the two ifs are first moved above the earlier if, provided
/// they do not conflict with its reads, writes or side effects.
/// NOTE(review): many original lines of this function are missing from this
/// excerpt (e.g. clearing of `ifWrites`/`ifReads`, `lastOp`, loop exits).
261void MergeIfsPass::mergeIfs(Block &rootBlock) {
264 ifHasSideEffects.clear();
// Snapshot the ifs up front, since merging erases some of them.
267 SmallVector<scf::IfOp> ifOps(rootBlock.getOps<scf::IfOp>());
// Record the values each if writes and reads, and whether it has side effects.
268 for (
auto ifOp : ifOps) {
269 ifHasSideEffects[ifOp] =
false;
270 ifOp.walk([&](Operation *op) {
272 ifWrites[ifOp].insert(ptr);
274 ifReads[ifOp].insert(ptr);
276 ifHasSideEffects[ifOp] =
true;
280 for (
auto ifOp : ifOps) {
// `lastOp` (declared outside this excerpt) tracks the previously seen if.
281 auto prevIfOp = std::exchange(lastOp, ifOp);
// Merging requires the same condition, no results, and matching
// presence of an else region.
287 if (ifOp.getCondition() != prevIfOp.getCondition())
289 if (ifOp.getNumResults() != 0 || prevIfOp.getNumResults() != 0)
291 if (ifOp.getElseRegion().empty() != prevIfOp.getElseRegion().empty())
// If the ifs are not adjacent, check that every op between them may be
// moved above `prevIfOp`.
296 if (ifOp->getPrevNode() != prevIfOp) {
301 bool allMovable =
true;
302 for (
auto &op :
llvm::make_range(
Block::iterator(prevIfOp->getNextNode()),
303 Block::iterator(ifOp))) {
304 auto result = op.walk([&](Operation *subOp) {
// Presumably: a write to a value prevIfOp writes or reads blocks the move.
307 if (ifWrites[prevIfOp].contains(ptr) ||
308 ifReads[prevIfOp].contains(ptr))
309 return WalkResult::interrupt();
// Presumably: a read of a value prevIfOp writes blocks the move.
312 if (ifWrites[prevIfOp].contains(ptr))
313 return WalkResult::interrupt();
// Presumably: a side effect blocks the move if prevIfOp has side effects.
316 if (ifHasSideEffects[prevIfOp])
317 return WalkResult::interrupt();
319 return WalkResult::advance();
321 if (result.wasInterrupted()) {
// Hoist the intervening ops above prevIfOp so the two ifs become adjacent
// (the loop exit condition is on a line not visible here).
330 while (
auto *op = prevIfOp->getNextNode()) {
333 LLVM_DEBUG(llvm::dbgs() <<
"- Moved before if " << *op <<
"\n");
334 op->moveBefore(prevIfOp);
335 ++numOpsMovedFromBetweenIfs;
// Drop prevIfOp's then-terminator and append all of ifOp's then-ops
// (including its terminator) to prevIfOp's then-block.
340 prevIfOp.thenYield().erase();
341 prevIfOp.thenBlock()->getOperations().splice(
342 prevIfOp.thenBlock()->end(), ifOp.thenBlock()->getOperations());
// Same for the else-block, when present.
345 if (ifOp.elseBlock()) {
346 prevIfOp.elseYield().erase();
347 prevIfOp.elseBlock()->getOperations().splice(
348 prevIfOp.elseBlock()->end(), ifOp.elseBlock()->getOperations());
// Fold ifOp's tracking data into prevIfOp and drop ifOp's entries.
351 ifReads[prevIfOp].insert_range(ifReads[ifOp]);
352 ifWrites[prevIfOp].insert_range(ifWrites[ifOp]);
353 ifHasSideEffects[prevIfOp] |= ifHasSideEffects[ifOp];
// NOTE(review): no `ifReads.erase(ifOp)` is visible here; it may be on the
// missing line 354 — confirm the entry is not left stale.
355 ifWrites.erase(ifOp);
356 ifHasSideEffects.erase(ifOp);
363 LLVM_DEBUG(llvm::dbgs() <<
"- Merged adjacent if ops\n");
assert(baseType &&"element must be base type")
static bool hasSideEffects(Operation *op)
Check if an operation has side effects, ignoring any nested ops.
static Value getPointerWrittenByOp(Operation *op)
Return the state/memory value being written by an op.
static Value getPointerReadByOp(Operation *op)
Return the state/memory value being read by an op.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.