11 #include "mlir/Dialect/SCF/IR/SCF.h"
12 #include "llvm/Support/Debug.h"
14 #define DEBUG_TYPE "arc-merge-ifs"
18 #define GEN_PASS_DEF_MERGEIFSPASS
19 #include "circt/Dialect/Arc/ArcPasses.h.inc"
24 using namespace circt;
// Pass class built on the generated MergeIfsPassBase. The members visible in
// this excerpt are the pass entry point and three helpers that each operate
// on a single root block.
28 struct MergeIfsPass :
public arc::impl::MergeIfsPassBase<MergeIfsPass> {
// Entry point; overrides the base-class hook.
29 void runOnOperation()
override;
// Per-block driver, invoked by runOnOperation() on a region's only block.
30 void runOnBlock(
Block &rootBlock);
// Moves ops within `rootBlock` towards their users (see the definition).
31 void sinkOps(
Block &rootBlock);
// Combines scf.if ops found in `rootBlock` (see the definition).
32 void mergeIfs(
Block &rootBlock);
// Visits every region nested under the pass's root operation in pre-order
// and hands the region's single block to runOnBlock().
39 void MergeIfsPass::runOnOperation() {
42 getOperation()->walk<WalkOrder::PreOrder>([&](Region *region) {
// Only regions with exactly one block for which mlir::mayHaveSSADominance()
// holds are processed; all other regions are left untouched by this callback.
43 if (region->hasOneBlock() && mlir::mayHaveSSADominance(*region))
44 runOnBlock(region->front());
// Per-block driver. The only part of the body visible in this excerpt is the
// debug trace below, which prints the name of the op that owns `rootBlock`.
50 void MergeIfsPass::runOnBlock(Block &rootBlock) {
51 LLVM_DEBUG(llvm::dbgs() <<
"Running on block in "
52 << rootBlock.getParentOp()->getName() <<
"\n");
// Yields the value written through by `op`: getState() of a StateWriteOp, or
// getMemory() of a MemoryWriteOp.
// NOTE(review): the enclosing signature and the fall-through result are not
// visible in this excerpt; presumably this is the body of
// getPointerWrittenByOp -- confirm against the full file.
63 if (
auto write = dyn_cast<StateWriteOp>(op))
64 return write.getState();
65 if (
auto write = dyn_cast<MemoryWriteOp>(op))
66 return write.getMemory();
// Yields the value read through by `op`: getState() of a StateReadOp, or
// getMemory() of a MemoryReadOp.
// NOTE(review): the enclosing signature and the fall-through result are not
// visible in this excerpt; presumably this is the body of
// getPointerReadByOp -- confirm against the full file.
72 if (
auto read = dyn_cast<StateReadOp>(op))
73 return read.getState();
74 if (
auto read = dyn_cast<MemoryReadOp>(op))
75 return read.getMemory();
// NOTE(review): the enclosing signature and any earlier checks are not
// visible in this excerpt; presumably this is the tail of hasSideEffects.
// An op that implements MemoryEffectOpInterface counts as side-effecting
// unless the interface reports that it has no effect.
83 if (
auto memEffects = dyn_cast<MemoryEffectOpInterface>(op))
84 return !memEffects.hasNoEffect();
// Without the interface, the op counts as side-effecting unless it carries
// the HasRecursiveMemoryEffects trait.
85 return !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
// Ordering key attached to an operation: a pair of unsigned integers, so
// keys compare lexicographically (first component, then second).
93 using OpOrder = std::pair<unsigned, unsigned>;
// NOTE(review): the header of the enclosing struct is not visible in this
// excerpt; the parameter types below suggest it is named OpAndOrder.
// The tracked operation (null while nothing is tracked) and its ordering key.
98 Operation *op =
nullptr;
99 OpOrder order = {0, 0};
// True when an operation is currently tracked.
101 explicit operator bool()
const {
return op; }
// The visible condition holds when nothing is tracked yet, or when `other`
// tracks an op whose key is strictly smaller. The statement it guards is not
// visible here; presumably `other` replaces the current contents.
104 void minimize(
const OpAndOrder &other) {
105 if (!op || (other.op && other.order < order))
// Mirror image of minimize(): the condition holds when nothing is tracked
// yet, or when `other` tracks an op whose key is strictly larger.
110 void maximize(
const OpAndOrder &other) {
111 if (!op || (other.op && other.order > order))
// Walks `rootBlock` from its last op to its first and moves ops closer to
// their users, either into a nested block that holds all users or to a
// position further down the root block. Several lines of the body are not
// visible in this excerpt; comments below describe only what is shown.
118 void MergeIfsPass::sinkOps(Block &rootBlock) {
// Ordering key assigned to each op seen so far.
121 DenseMap<Operation *, OpOrder> opOrder;
// Cached insertion position per target op, so later moves in front of the
// same target continue from where the previous one ended.
125 DenseMap<Operation *, Operation *> insertionPoints;
// For each written state/memory value, the root-block op recorded for it
// most recently during the reverse walk.
128 DenseMap<Value, Operation *> nextWrite;
// The root-block op most recently recorded as containing a side effect.
130 Operation *nextSideEffect =
nullptr;
// Reverse iteration with an early-increment range: the current op may be
// moved (see the moveBefore calls below) without breaking the iteration.
132 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(rootBlock))) {
// Initial key: first component is one more than the number of ops keyed so
// far, second component is 0. With the reverse walk, ops nearer the start of
// the block therefore receive larger first components.
134 auto order = OpOrder{opOrder.size() + 1, 0};
135 opOrder[&op] = order;
// Inspect `op` and everything nested in it. `ptr` is defined on a line not
// shown here; when the branch for it is taken, `op` is recorded as the write
// for that value. Otherwise a sub-op that is not a state/memory read and has
// side effects makes `op` the recorded side-effecting op.
138 op.walk([&](Operation *subOp) {
140 nextWrite[ptr] = &op;
141 else if (!isa<StateReadOp, MemoryReadOp>(subOp) &&
hasSideEffects(subOp))
142 nextSideEffect = &op;
// Bound on how far `op` may be moved.
146 OpAndOrder moveLimit;
// Raise the bound to the recorded write of `ptr`, if there is one, and to
// the recorded side-effecting op. (`ptr` and any guard around the second
// call are on lines not shown here.)
149 if (
auto *write = nextWrite.lookup(ptr))
150 moveLimit.maximize({write, opOrder.lookup(write)});
153 moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)});
154 }
// Case where `op` itself is a state/memory write or is the recorded
// side-effecting op; the body of this branch is not visible in this excerpt.
else if (isa<StateWriteOp, MemoryWriteOp>(&op) || nextSideEffect == &op) {
// Determine whether every user of `op` sits in one and the same nested block.
// The pointer ends up null when that is not the case.
160 Block *allUsesInBlock =
nullptr;
161 for (
auto *user : op.getUsers()) {
// A user located directly in the root block rules out a common nested block.
164 if (user->getBlock() == &rootBlock) {
165 allUsesInBlock =
nullptr;
// Climb the parent chain until the user's parent op lives in the root block,
// i.e. until `user` is in a block directly owned by a root-block op.
170 while (user->getParentOp()->getBlock() != &rootBlock)
171 user = user->getParentOp();
// The first user fixes the candidate block; a later user in a different
// block clears it again.
175 if (!allUsesInBlock) {
176 allUsesInBlock = user->getBlock();
177 }
else if (allUsesInBlock != user->getBlock()) {
178 allUsesInBlock =
nullptr;
// With a common nested block, the target is the root-block op that owns it.
// (`earliest` is declared on a line not shown here.)
186 if (allUsesInBlock) {
187 earliest.op = allUsesInBlock->getParentOp();
188 earliest.order = opOrder.lookup(earliest.op);
// Lift each user to its ancestor in the root block and keep the one with the
// largest key. NOTE(review): presumably this loop is the else-branch of the
// check above; the connecting line is not visible.
190 for (
auto *user : op.getUsers()) {
191 while (user->getBlock() != &rootBlock)
192 user = user->getParentOp();
194 earliest.maximize({user, opOrder.lookup(user)});
// Apply the bound computed above to the chosen target.
199 earliest.maximize(moveLimit);
// If all users share a nested block and the target is still the op owning
// that block, move `op` to the very beginning of that block.
205 if (allUsesInBlock && allUsesInBlock->getParentOp() == earliest.op) {
206 op.moveBefore(allUsesInBlock, allUsesInBlock->begin());
209 LLVM_DEBUG(llvm::dbgs() <<
"- Sunk " << op <<
"\n");
// Otherwise `op` is placed in front of the target within the root block.
// Start from the cached insertion position for this target when one exists;
// the asserts check that it shares the target's first key component and that
// its second component is not smaller.
216 auto &insertionPoint = insertionPoints[earliest.op];
217 if (insertionPoint) {
218 auto order = opOrder.lookup(insertionPoint);
219 assert(order.first == earliest.order.first);
220 assert(order.second >= earliest.order.second);
221 earliest.op = insertionPoint;
222 earliest.order = order;
// Step backwards over preceding ops. The statement guarded by the mismatch
// check is not visible; presumably the loop stops at the first op whose
// first key component differs.
224 while (
auto *prevOp = earliest.op->getPrevNode()) {
225 auto order = opOrder.lookup(prevOp);
226 if (order.first != earliest.order.first)
228 assert(order.second > earliest.order.second);
229 earliest.op = prevOp;
230 earliest.order = order;
// Remember where this group currently begins for the next lookup.
232 insertionPoint = earliest.op;
// Move only when `op` is not already immediately in front of the target.
235 if (op.getNextNode() != earliest.op) {
236 LLVM_DEBUG(llvm::dbgs() <<
"- Moved " << op <<
"\n");
237 op.moveBefore(earliest.op);
// `op` takes over the key of the position it was placed at. The assert
// guards the second component against already being at its maximum value;
// NOTE(review): presumably it is incremented on a line not shown here.
245 order = earliest.order;
246 assert(order.second <
unsigned(-1));
248 opOrder[&op] = order;
// Walks the scf.if ops of `rootBlock` in order and folds each one's
// predecessor into it when the two are compatible. Several lines of the body
// are not visible in this excerpt; comments below describe only what is
// shown.
253 void MergeIfsPass::mergeIfs(Block &rootBlock) {
// Values recorded as written / read inside the previous if op.
254 DenseSet<Value> prevIfWrites, prevIfReads;
257 for (
auto ifOp : rootBlock.getOps<scf::IfOp>()) {
// `lastOp` (declared on a line not shown) is updated to the current if op;
// its old value is the candidate to merge with.
258 auto prevIfOp = std::exchange(lastOp, ifOp);
// Compatibility checks: same condition value, neither op produces results,
// and both agree on whether the else region is empty. The statements these
// checks guard are not visible; presumably a mismatch skips this pair.
264 if (ifOp.getCondition() != prevIfOp.getCondition())
266 if (ifOp.getNumResults() != 0 || prevIfOp.getNumResults() != 0)
268 if (ifOp.getElseRegion().empty() != prevIfOp.getElseRegion().empty())
// The two if ops are not directly adjacent: ops sit between them.
273 if (ifOp->getPrevNode() != prevIfOp) {
// Summarise the previous if op: which values it writes and reads, and
// whether it has side effects. (`ptr` and the conditions selecting each
// statement are on lines not shown here.)
275 bool prevIfHasSideEffects =
false;
276 prevIfWrites.clear();
278 prevIfOp.walk([&](Operation *op) {
280 prevIfWrites.insert(ptr);
282 prevIfReads.insert(ptr);
284 prevIfHasSideEffects =
true;
// Check every op between the two if ops, including nested ops, for
// conflicts with that summary.
291 bool allMovable =
true;
292 for (
auto &op : llvm::make_range(Block::iterator(prevIfOp->getNextNode()),
293 Block::iterator(ifOp))) {
294 auto result = op.walk([&](Operation *subOp) {
// Three conflict checks, each aborting the walk: `ptr` is among the previous
// if's writes or reads; `ptr` is among its writes; the previous if has side
// effects. The conditions choosing which check applies to `subOp` are on
// lines not shown here.
297 if (prevIfWrites.contains(ptr) || prevIfReads.contains(ptr))
298 return WalkResult::interrupt();
301 if (prevIfWrites.contains(ptr))
302 return WalkResult::interrupt();
305 if (prevIfHasSideEffects)
306 return WalkResult::interrupt();
308 return WalkResult::advance();
// An interrupted walk means a conflict was found; the handling is not
// visible in this excerpt.
310 if (result.wasInterrupted()) {
// Hoist the ops that follow the previous if op to just in front of it and
// count each move in the numOpsMovedFromBetweenIfs statistic. The loop's
// exit condition is on a line not shown; presumably it stops at `ifOp`.
319 while (
auto *op = prevIfOp->getNextNode()) {
322 LLVM_DEBUG(llvm::dbgs() <<
"- Moved before if " << *op <<
"\n");
323 op->moveBefore(prevIfOp);
324 ++numOpsMovedFromBetweenIfs;
// Merge the then-branches: drop the previous if's yield, then splice its
// remaining then-block ops to the front of the current if's then-block.
329 prevIfOp.thenYield().erase();
330 ifOp.thenBlock()->getOperations().splice(
331 ifOp.thenBlock()->begin(), prevIfOp.thenBlock()->getOperations());
// Same for the else-branches, when the current if op has an else block.
334 if (ifOp.elseBlock()) {
335 prevIfOp.elseYield().erase();
336 ifOp.elseBlock()->getOperations().splice(
337 ifOp.elseBlock()->begin(), prevIfOp.elseBlock()->getOperations());
344 LLVM_DEBUG(llvm::dbgs() <<
"- Merged adjacent if ops\n");
assert(baseType &&"element must be base type")
static bool hasSideEffects(Operation *op)
Check if an operation has side effects, ignoring any nested ops.
static Value getPointerWrittenByOp(Operation *op)
Return the state/memory value being written by an op.
static Value getPointerReadByOp(Operation *op)
Return the state/memory value being read by an op.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.