//===- MergeIfs.cpp -------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
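//
// This pass sinks operations in each single-block region as close to their
// users as possible and merges adjacent `scf.if` operations that share the
// same condition. As a rough, illustrative sketch (not exact Arc/SCF IR
// syntax), two conditional regions such as
//
//   scf.if %cond { ...writes to state A... }
//   scf.if %cond { ...writes to state B... }
//
// are combined into a single one:
//
//   scf.if %cond {
//     ...writes to state A...
//     ...writes to state B...
//   }
//
//===----------------------------------------------------------------------===//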

#include "circt/Dialect/Arc/ArcOps.h"
#include "circt/Dialect/Arc/ArcPasses.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "arc-merge-ifs"

namespace circt {
namespace arc {
#define GEN_PASS_DEF_MERGEIFSPASS
#include "circt/Dialect/Arc/ArcPasses.h.inc"
} // namespace arc
} // namespace circt

using namespace mlir;
using namespace circt;
using namespace arc;

namespace {
struct MergeIfsPass : public arc::impl::MergeIfsPassBase<MergeIfsPass> {
  void runOnOperation() override;
  void runOnBlock(Block &rootBlock);
  void sinkOps(Block &rootBlock);
  void mergeIfs(Block &rootBlock);

private:
  bool anyChanges;
};
} // namespace

void MergeIfsPass::runOnOperation() {
  // Go through the regions recursively, from outer regions to nested regions,
  // and try to move/sink/merge ops in each.
  getOperation()->walk<WalkOrder::PreOrder>([&](Region *region) {
    if (region->hasOneBlock() && mlir::mayHaveSSADominance(*region))
      runOnBlock(region->front());
  });
}

/// Iteratively sink ops into the block, move them closer to their uses, and
/// merge adjacent `scf.if` operations.
void MergeIfsPass::runOnBlock(Block &rootBlock) {
  LLVM_DEBUG(llvm::dbgs() << "Running on block in "
                          << rootBlock.getParentOp()->getName() << "\n");
  do {
    ++numIterations;
    anyChanges = false;
    sinkOps(rootBlock);
    mergeIfs(rootBlock);
  } while (anyChanges);
}

/// Return the state/memory value being written by an op.
static Value getPointerWrittenByOp(Operation *op) {
  if (auto write = dyn_cast<StateWriteOp>(op))
    return write.getState();
  if (auto write = dyn_cast<MemoryWriteOp>(op))
    return write.getMemory();
  return {};
}

/// Return the state/memory value being read by an op.
static Value getPointerReadByOp(Operation *op) {
  if (auto read = dyn_cast<StateReadOp>(op))
    return read.getState();
  if (auto read = dyn_cast<MemoryReadOp>(op))
    return read.getMemory();
  return {};
}

/// Check if an operation has side effects, ignoring any nested ops. This is
/// useful if we're traversing all nested ops anyway, and we are only
/// interested in the current op's side effects.
static bool hasSideEffects(Operation *op) {
  if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op))
    return !memEffects.hasNoEffect();
  return !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
}
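// Note: ops with nested regions, e.g. `scf.if`, carry the
// `HasRecursiveMemoryEffects` trait, so this helper reports no effects for
// such an op itself; reads/writes nested inside it are visited explicitly by
// the walks in `sinkOps` and `mergeIfs` below.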

namespace {
/// A pair of integers indicating the position of an operation in its parent
/// block. The first field is the initial order/position assigned. The second
/// field is used to order ops that were moved to the same location, which
/// makes them have the same first field.
using OpOrder = std::pair<unsigned, unsigned>;
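// For example, ops that pile up in front of the same unmoved op (with order
// {7, 0}, say) keep its first field and receive increasing second fields
// {7, 1}, {7, 2}, ..., so their relative order stays well-defined.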

/// A helper that tracks an op and its order, and allows for convenient
/// substitution with another op that has a higher/lower order.
struct OpAndOrder {
  Operation *op = nullptr;
  OpOrder order = {0, 0};

  explicit operator bool() const { return op; }

  /// Assign `other` if its order is lower than this op, or this op is null.
  void minimize(const OpAndOrder &other) {
    if (!op || (other.op && other.order < order))
      *this = other;
  }

  /// Assign `other` if its order is higher than this op, or this op is null.
  void maximize(const OpAndOrder &other) {
    if (!op || (other.op && other.order > order))
      *this = other;
  }
};
} // namespace

/// Sink operations as close to their users as possible.
void MergeIfsPass::sinkOps(Block &rootBlock) {
  // A numeric position assigned to ops as we encounter them. Ops at the end
  // of the block get the lowest order number, ops at the beginning the
  // highest.
  DenseMap<Operation *, OpOrder> opOrder;
  // A lookup table that indicates where ops should be inserted. This is used
  // to maintain the original op order if multiple ops pile up before the same
  // other op that blocks their move.
  DenseMap<Operation *, Operation *> insertionPoints;
  // The next write op for each state/memory pointer seen so far. ("Next"
  // because we run from the end to the beginning of the block.)
  DenseMap<Value, Operation *> nextWrite;
  // The most recent op that has an unknown (non-read/write) side effect.
  Operation *nextSideEffect = nullptr;

  for (auto &op : llvm::make_early_inc_range(llvm::reverse(rootBlock))) {
    // Assign an order to this op.
    auto order = OpOrder{opOrder.size() + 1, 0};
    opOrder[&op] = order;

    // Analyze the side effects in the op.
    op.walk([&](Operation *subOp) {
      if (auto ptr = getPointerWrittenByOp(subOp))
        nextWrite[ptr] = &op;
      else if (!isa<StateReadOp, MemoryReadOp>(subOp) && hasSideEffects(subOp))
        nextSideEffect = &op;
    });

    // Determine how much the op can be moved.
    OpAndOrder moveLimit;
    if (auto ptr = getPointerReadByOp(&op)) {
      // Don't move across writes to the same state/memory.
      if (auto *write = nextWrite.lookup(ptr))
        moveLimit.maximize({write, opOrder.lookup(write)});
      // Don't move across general side-effecting ops.
      if (nextSideEffect)
        moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)});
    } else if (isa<StateWriteOp, MemoryWriteOp>(&op) ||
               nextSideEffect == &op) {
      // Don't move writes or side-effecting ops.
      continue;
    }

    // Find the block that contains all uses.
    Block *allUsesInBlock = nullptr;
    for (auto *user : op.getUsers()) {
      // If this user is directly in the root block there's no chance of
      // sinking the current op anywhere.
      if (user->getBlock() == &rootBlock) {
        allUsesInBlock = nullptr;
        break;
      }

      // Find the operation in the root block that contains this user.
      while (user->getParentOp()->getBlock() != &rootBlock)
        user = user->getParentOp();
      assert(user);

      // Check that all users sit in the same op in the root block.
      if (!allUsesInBlock) {
        allUsesInBlock = user->getBlock();
      } else if (allUsesInBlock != user->getBlock()) {
        allUsesInBlock = nullptr;
        break;
      }
    }

    // If no single block exists that contains all uses, find the earliest op
    // in the root block that uses the current op.
    OpAndOrder earliest;
    if (allUsesInBlock) {
      earliest.op = allUsesInBlock->getParentOp();
      earliest.order = opOrder.lookup(earliest.op);
    } else {
      for (auto *user : op.getUsers()) {
        while (user->getBlock() != &rootBlock)
          user = user->getParentOp();
        assert(user);
        earliest.maximize({user, opOrder.lookup(user)});
      }
    }

    // Ensure we don't move past the move limit imposed by side effects.
    earliest.maximize(moveLimit);
    if (!earliest)
      continue;

    // Either move the op inside the single block that contains all uses, or
    // move it to just before its earliest user.
    if (allUsesInBlock && allUsesInBlock->getParentOp() == earliest.op) {
      op.moveBefore(allUsesInBlock, allUsesInBlock->begin());
      ++numOpsSunk;
      anyChanges = true;
      LLVM_DEBUG(llvm::dbgs() << "- Sunk " << op << "\n");
    } else {
      // Insert above other ops that we have already moved to this earliest
      // op. This ensures the original op order is maintained and we are not
      // spuriously flipping ops around. This also works without the
      // `insertionPoint` lookup, but can cause significant linear scanning to
      // find the op before which we want to insert.
      auto &insertionPoint = insertionPoints[earliest.op];
      if (insertionPoint) {
        auto order = opOrder.lookup(insertionPoint);
        assert(order.first == earliest.order.first);
        assert(order.second >= earliest.order.second);
        earliest.op = insertionPoint;
        earliest.order = order;
      }
      while (auto *prevOp = earliest.op->getPrevNode()) {
        auto order = opOrder.lookup(prevOp);
        if (order.first != earliest.order.first)
          break;
        assert(order.second > earliest.order.second);
        earliest.op = prevOp;
        earliest.order = order;
      }
      insertionPoint = earliest.op;

      // Only move if the op isn't already in the right spot.
      if (op.getNextNode() != earliest.op) {
        LLVM_DEBUG(llvm::dbgs() << "- Moved " << op << "\n");
        op.moveBefore(earliest.op);
        ++numOpsMovedToUser;
        anyChanges = true;
      }

      // Update the current op's order to reflect where it has been inserted.
      // This ensures that later moves to the same pile of moved ops do not
      // reorder the operations.
      order = earliest.order;
      assert(order.second < unsigned(-1));
      ++order.second;
      opOrder[&op] = order;
    }
  }
}

void MergeIfsPass::mergeIfs(Block &rootBlock) {
  DenseSet<Value> prevIfWrites, prevIfReads;

  scf::IfOp lastOp;
  for (auto ifOp : rootBlock.getOps<scf::IfOp>()) {
    auto prevIfOp = std::exchange(lastOp, ifOp);
    if (!prevIfOp)
      continue;

    // Only handle simple cases for now. (Same condition, no results, and both
    // ifs either have or don't have an else block.)
    if (ifOp.getCondition() != prevIfOp.getCondition())
      continue;
    if (ifOp.getNumResults() != 0 || prevIfOp.getNumResults() != 0)
      continue;
    if (ifOp.getElseRegion().empty() != prevIfOp.getElseRegion().empty())
      continue;

    // Try to move ops in between the `scf.if` ops above the previous `scf.if`
    // in order to make them immediately adjacent.
    if (ifOp->getPrevNode() != prevIfOp) {
      // Determine the side effects inside the previous if op.
      bool prevIfHasSideEffects = false;
      prevIfWrites.clear();
      prevIfReads.clear();
      prevIfOp.walk([&](Operation *op) {
        if (auto ptr = getPointerWrittenByOp(op))
          prevIfWrites.insert(ptr);
        else if (auto ptr = getPointerReadByOp(op))
          prevIfReads.insert(ptr);
        else if (!prevIfHasSideEffects && hasSideEffects(op))
          prevIfHasSideEffects = true;
      });

      // Check if it is legal to throw all ops over the previous `scf.if` op,
      // given the side effects. We don't move the ops yet to ensure we can
      // move *all* of them at once afterwards. Otherwise this optimization
      // would race with the sink-to-users optimization.
      bool allMovable = true;
      for (auto &op :
           llvm::make_range(Block::iterator(prevIfOp->getNextNode()),
                            Block::iterator(ifOp))) {
        auto result = op.walk([&](Operation *subOp) {
          if (auto ptr = getPointerWrittenByOp(subOp)) {
            // We can't move writes over writes or reads of the same state.
            if (prevIfWrites.contains(ptr) || prevIfReads.contains(ptr))
              return WalkResult::interrupt();
          } else if (auto ptr = getPointerReadByOp(subOp)) {
            // We can't move reads over writes to the same state.
            if (prevIfWrites.contains(ptr))
              return WalkResult::interrupt();
          } else if (hasSideEffects(subOp)) {
            // We can't move side-effecting ops over other side-effecting ops.
            if (prevIfHasSideEffects)
              return WalkResult::interrupt();
          }
          return WalkResult::advance();
        });
        if (result.wasInterrupted()) {
          allMovable = false;
          break;
        }
      }
      if (!allMovable)
        continue;

      // At this point we know that all ops can be moved. Do so.
      while (auto *op = prevIfOp->getNextNode()) {
        if (op == ifOp)
          break;
        LLVM_DEBUG(llvm::dbgs() << "- Moved before if " << *op << "\n");
        op->moveBefore(prevIfOp);
        ++numOpsMovedFromBetweenIfs;
      }
    }

    // Merge the then-blocks.
    prevIfOp.thenYield().erase();
    ifOp.thenBlock()->getOperations().splice(
        ifOp.thenBlock()->begin(), prevIfOp.thenBlock()->getOperations());

    // Merge the else-blocks if present.
    if (ifOp.elseBlock()) {
      prevIfOp.elseYield().erase();
      ifOp.elseBlock()->getOperations().splice(
          ifOp.elseBlock()->begin(), prevIfOp.elseBlock()->getOperations());
    }

    // Clean up.
    prevIfOp.erase();
    anyChanges = true;
    ++numIfsMerged;
    LLVM_DEBUG(llvm::dbgs() << "- Merged adjacent if ops\n");
  }
}
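
// Note: this pass is typically exercised through `circt-opt`; assuming the
// registered pass name matches the DEBUG_TYPE above, an invocation would look
// like `circt-opt --arc-merge-ifs input.mlir`.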