Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
MergeIfs.cpp
Go to the documentation of this file.
1//===- MergeIfs.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "llvm/Support/Debug.h"
13
14#define DEBUG_TYPE "arc-merge-ifs"
15
16namespace circt {
17namespace arc {
18#define GEN_PASS_DEF_MERGEIFSPASS
19#include "circt/Dialect/Arc/ArcPasses.h.inc"
20} // namespace arc
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25using namespace arc;
26
namespace {
/// Pass that sinks ops towards their users and merges adjacent `scf.if` ops
/// with identical conditions, iterating to a fixed point per block.
struct MergeIfsPass : public arc::impl::MergeIfsPassBase<MergeIfsPass> {
  void runOnOperation() override;
  /// Run the sink/merge fixed-point loop on a single block.
  void runOnBlock(Block &rootBlock);
  /// Sink movable ops closer to (or into the regions of) their users.
  void sinkOps(Block &rootBlock);
  /// Merge adjacent `scf.if` ops in the block that share a condition.
  void mergeIfs(Block &rootBlock);

private:
  // Set by sinkOps/mergeIfs whenever they modify the IR; drives the
  // do-while fixed-point loop in runOnBlock.
  bool anyChanges;
  // Per-`scf.if` cache of the state/memory pointers written/read anywhere
  // inside the op, populated by mergeIfs.
  DenseMap<scf::IfOp, DenseSet<Value>> ifWrites, ifReads;
  // Per-`scf.if` flag: contains an op with other (non-read/write) side
  // effects.
  DenseMap<scf::IfOp, bool> ifHasSideEffects;
};
} // namespace
40
41void MergeIfsPass::runOnOperation() {
42 // Go through the regions recursively, from outer regions to nested regions,
43 // and try to move/sink/merge ops in each.
44 getOperation()->walk<WalkOrder::PreOrder>([&](Region *region) {
45 if (region->hasOneBlock() && mlir::mayHaveSSADominance(*region))
46 runOnBlock(region->front());
47 });
48}
49
50/// Iteratively sink ops into block, move them closer to their uses, and merge
51/// adjacent `scf.if` operations.
52void MergeIfsPass::runOnBlock(Block &rootBlock) {
53 LLVM_DEBUG(llvm::dbgs() << "Running on block in "
54 << rootBlock.getParentOp()->getName() << "\n");
55 do {
56 ++numIterations;
57 anyChanges = false;
58 sinkOps(rootBlock);
59 mergeIfs(rootBlock);
60 } while (anyChanges);
61}
62
63/// Return the state/memory value being written by an op.
64static Value getPointerWrittenByOp(Operation *op) {
65 if (auto write = dyn_cast<StateWriteOp>(op))
66 return write.getState();
67 if (auto write = dyn_cast<MemoryWriteOp>(op))
68 return write.getMemory();
69 return {};
70}
71
72/// Return the state/memory value being read by an op.
73static Value getPointerReadByOp(Operation *op) {
74 if (auto read = dyn_cast<StateReadOp>(op))
75 return read.getState();
76 if (auto read = dyn_cast<MemoryReadOp>(op))
77 return read.getMemory();
78 return {};
79}
80
81/// Check if an operation has side effects, ignoring any nested ops. This is
82/// useful if we're traversing all nested ops anyway, and we are only interested
83/// in the current op's side effects.
84static bool hasSideEffects(Operation *op) {
85 if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op))
86 return !memEffects.hasNoEffect();
87 return !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
88}
89
90namespace {
91/// An integer indicating the position of an operation in its parent block. The
92/// first field is the initial order/position assigned. The second field is used
93/// to order ops that were moved to the same location, which makes them have the
94/// same first field.
95using OpOrder = std::pair<unsigned, unsigned>;
96
97/// A helper that tracks an op and its order, and allows for convenient
98/// substitution with another op that has a higher/lower order.
99struct OpAndOrder {
100 Operation *op = nullptr;
101 OpOrder order = {0, 0};
102
103 explicit operator bool() const { return op; }
104
105 /// Assign `other` if its order is lower than this op, or this op is null.
106 void minimize(const OpAndOrder &other) {
107 if (!op || (other.op && other.order < order))
108 *this = other;
109 }
110
111 /// Assign `other` if its order is higher than this op, or this op is null.
112 void maximize(const OpAndOrder &other) {
113 if (!op || (other.op && other.order > order))
114 *this = other;
115 }
116};
117} // namespace
118
/// Sink operations as close to their users as possible.
///
/// Walks the block from back to front. For each op it records the writes and
/// other side effects seen so far (i.e. the ones *after* the op), then either
/// sinks the op into the single nested block containing all its users, or
/// moves it directly in front of its earliest user in the root block, as far
/// as the recorded side effects permit.
void MergeIfsPass::sinkOps(Block &rootBlock) {
  // A numeric position assigned to ops as we encounter them. Ops at the end of
  // the block get the lowest order number, ops at the beginning the highest.
  DenseMap<Operation *, OpOrder> opOrder;
  // A lookup table that indicates where ops should be inserted. This is used
  // to maintain the original op order if multiple ops pile up before the same
  // other op that blocks their move.
  DenseMap<Operation *, Operation *> insertionPoints;
  // The write ops to each state/memory pointer we've seen so far. ("Next"
  // because we run from the end to the beginning of the block.)
  DenseMap<Value, Operation *> nextWrite;
  // The most recent op that has an unknown (non-read/write) side-effect.
  Operation *nextSideEffect = nullptr;

  // Early-inc iteration so the current op can be moved without invalidating
  // the traversal.
  for (auto &op : llvm::make_early_inc_range(llvm::reverse(rootBlock))) {
    // Assign an order to this op.
    auto order = OpOrder{opOrder.size() + 1, 0};
    opOrder[&op] = order;
    // Track whether the op is, or contains, any writes (and thus can't
    // generally be moved into a block)
    bool opContainsWrites = false;

    // Analyze the side effects in the op (including all nested ops).
    op.walk([&](Operation *subOp) {
      if (auto ptr = getPointerWrittenByOp(subOp)) {
        nextWrite[ptr] = &op;
        opContainsWrites = true;
      } else if (!isa<StateReadOp, MemoryReadOp>(subOp) &&
                 hasSideEffects(subOp)) {
        nextSideEffect = &op;
      }
    });

    // Determine how much the op can be moved.
    OpAndOrder moveLimit;
    if (auto ptr = getPointerReadByOp(&op)) {
      // Don't move across writes to the same state/memory.
      if (auto *write = nextWrite.lookup(ptr))
        moveLimit.maximize({write, opOrder.lookup(write)});
      // Don't move across general side-effecting ops.
      if (nextSideEffect)
        moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)});
    } else if (opContainsWrites || nextSideEffect == &op) {
      // Don't move writes or side-effecting ops.
      continue;
    }

    // Find the block that contains all uses.
    Block *allUsesInBlock = nullptr;
    for (auto *user : op.getUsers()) {
      // If this user is directly in the root block there's no chance of
      // sinking the current op anywhere.
      if (user->getBlock() == &rootBlock) {
        allUsesInBlock = nullptr;
        break;
      }

      // Find the operation in the root block that contains this user.
      while (user->getParentOp()->getBlock() != &rootBlock)
        user = user->getParentOp();
      assert(user);

      // Check that all users sit in the same op in the root block.
      if (!allUsesInBlock) {
        allUsesInBlock = user->getBlock();
      } else if (allUsesInBlock != user->getBlock()) {
        allUsesInBlock = nullptr;
        break;
      }
    }

    // If no single block exists that contains all uses, find the earliest op
    // in the root block that uses the current op.
    OpAndOrder earliest;
    if (allUsesInBlock) {
      earliest.op = allUsesInBlock->getParentOp();
      earliest.order = opOrder.lookup(earliest.op);
    } else {
      for (auto *user : op.getUsers()) {
        // Ascend to the ancestor of the user that sits in the root block.
        while (user->getBlock() != &rootBlock)
          user = user->getParentOp();
        assert(user);
        earliest.maximize({user, opOrder.lookup(user)});
      }
    }

    // Ensure we don't move past the move limit imposed by side effects.
    earliest.maximize(moveLimit);
    if (!earliest)
      continue;

    // Either move the op inside the single block that contains all uses, or
    // move it to just before its earliest user.
    if (allUsesInBlock && allUsesInBlock->getParentOp() == earliest.op) {
      op.moveBefore(allUsesInBlock, allUsesInBlock->begin());
      ++numOpsSunk;
      anyChanges = true;
      LLVM_DEBUG(llvm::dbgs() << "- Sunk " << op << "\n");
    } else {
      // Insert above other ops that we have already moved to this earliest op.
      // This ensures the original op order is maintained and we are not
      // spuriously flipping ops around. This also works without the
      // `insertionPoint` lookup, but can cause significant linear scanning to
      // find the op before which we want to insert.
      auto &insertionPoint = insertionPoints[earliest.op];
      if (insertionPoint) {
        auto order = opOrder.lookup(insertionPoint);
        assert(order.first == earliest.order.first);
        assert(order.second >= earliest.order.second);
        earliest.op = insertionPoint;
        earliest.order = order;
      }
      // Walk up through ops that share the same primary order (i.e. were
      // already piled up at this insertion point) to keep original ordering.
      while (auto *prevOp = earliest.op->getPrevNode()) {
        auto order = opOrder.lookup(prevOp);
        if (order.first != earliest.order.first)
          break;
        assert(order.second > earliest.order.second);
        earliest.op = prevOp;
        earliest.order = order;
      }
      insertionPoint = earliest.op;

      // Only move if the op isn't already in the right spot.
      if (op.getNextNode() != earliest.op) {
        LLVM_DEBUG(llvm::dbgs() << "- Moved " << op << "\n");
        op.moveBefore(earliest.op);
        ++numOpsMovedToUser;
        anyChanges = true;
      }

      // Update the current op's order to reflect where it has been inserted.
      // This ensures that later moves to the same pile of moved ops do not
      // reorder the operations.
      order = earliest.order;
      assert(order.second < unsigned(-1));
      ++order.second;
      opOrder[&op] = order;
    }
  }
}
260
/// Merge adjacent `scf.if` ops in the block that have the same condition.
///
/// First caches the read/write/side-effect footprint of every `scf.if`, then
/// walks consecutive pairs: when the ops between two compatible ifs can
/// legally be hoisted above the first one, they are moved, and the second if's
/// then/else bodies are spliced onto the first.
void MergeIfsPass::mergeIfs(Block &rootBlock) {
  ifReads.clear();
  ifWrites.clear();
  ifHasSideEffects.clear();

  scf::IfOp lastOp;
  // Snapshot the if ops up front since merging mutates the block.
  SmallVector<scf::IfOp> ifOps(rootBlock.getOps<scf::IfOp>());
  // Precompute the side-effect summary for every if op.
  for (auto ifOp : ifOps) {
    ifHasSideEffects[ifOp] = false;
    ifOp.walk([&](Operation *op) {
      if (auto ptr = getPointerWrittenByOp(op))
        ifWrites[ifOp].insert(ptr);
      else if (auto ptr = getPointerReadByOp(op))
        ifReads[ifOp].insert(ptr);
      else if (!ifHasSideEffects[ifOp] && hasSideEffects(op))
        ifHasSideEffects[ifOp] = true;
    });
  }

  for (auto ifOp : ifOps) {
    auto prevIfOp = std::exchange(lastOp, ifOp);
    if (!prevIfOp)
      continue;

    // Only handle simple cases for now. (Same condition, no results, and both
    // ifs either have or don't have an else block.)
    if (ifOp.getCondition() != prevIfOp.getCondition())
      continue;
    if (ifOp.getNumResults() != 0 || prevIfOp.getNumResults() != 0)
      continue;
    if (ifOp.getElseRegion().empty() != prevIfOp.getElseRegion().empty())
      continue;

    // Try to move ops in between the `scf.if` ops above the previous `scf.if`
    // in order to make them immediately adjacent.
    if (ifOp->getPrevNode() != prevIfOp) {
      // Check if it is legal to throw all ops over the previous `scf.if` op,
      // given the side effects. We don't move the ops yet to ensure we can
      // move *all* of them at once afterwards. Otherwise this optimization
      // would race with the sink-to-users optimization.
      bool allMovable = true;
      for (auto &op : llvm::make_range(Block::iterator(prevIfOp->getNextNode()),
                                       Block::iterator(ifOp))) {
        auto result = op.walk([&](Operation *subOp) {
          if (auto ptr = getPointerWrittenByOp(subOp)) {
            // We can't move writes over writes or reads of the same state.
            if (ifWrites[prevIfOp].contains(ptr) ||
                ifReads[prevIfOp].contains(ptr))
              return WalkResult::interrupt();
          } else if (auptr_check: {} auto ptr = getPointerReadByOp(subOp)) {
            // We can't move reads over writes to the same state.
            if (ifWrites[prevIfOp].contains(ptr))
              return WalkResult::interrupt();
          } else if (hasSideEffects(subOp)) {
            // We can't move side-effecting ops over other side-effecting ops.
            if (ifHasSideEffects[prevIfOp])
              return WalkResult::interrupt();
          }
          return WalkResult::advance();
        });
        if (result.wasInterrupted()) {
          allMovable = false;
          break;
        }
      }
      if (!allMovable)
        continue;

      // At this point we know that all ops can be moved. Do so.
      while (auto *op = prevIfOp->getNextNode()) {
        if (op == ifOp)
          break;
        LLVM_DEBUG(llvm::dbgs() << "- Moved before if " << *op << "\n");
        op->moveBefore(prevIfOp);
        ++numOpsMovedFromBetweenIfs;
      }
    }

    // Merge the then-blocks.
    prevIfOp.thenYield().erase();
    prevIfOp.thenBlock()->getOperations().splice(
        prevIfOp.thenBlock()->end(), ifOp.thenBlock()->getOperations());

    // Merge the else-blocks if present.
    if (ifOp.elseBlock()) {
      prevIfOp.elseYield().erase();
      prevIfOp.elseBlock()->getOperations().splice(
          prevIfOp.elseBlock()->end(), ifOp.elseBlock()->getOperations());
    }

    // Fold the merged op's side-effect summary into the surviving op so later
    // merges against `prevIfOp` see the combined effects.
    ifReads[prevIfOp].insert_range(ifReads[ifOp]);
    ifWrites[prevIfOp].insert_range(ifWrites[ifOp]);
    ifHasSideEffects[prevIfOp] |= ifHasSideEffects[ifOp];
    ifReads.erase(ifOp);
    ifWrites.erase(ifOp);
    ifHasSideEffects.erase(ifOp);

    // Clean up.
    ifOp.erase();
    lastOp = prevIfOp;
    anyChanges = true;
    ++numIfsMerged;
    LLVM_DEBUG(llvm::dbgs() << "- Merged adjacent if ops\n");
  }
}
assert(baseType && "element must be base type")
static bool hasSideEffects(Operation *op)
Check if an operation has side effects, ignoring any nested ops.
Definition MergeIfs.cpp:84
static Value getPointerWrittenByOp(Operation *op)
Return the state/memory value being written by an op.
Definition MergeIfs.cpp:64
static Value getPointerReadByOp(Operation *op)
Return the state/memory value being read by an op.
Definition MergeIfs.cpp:73
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.