CIRCT 20.0.0git
Loading...
Searching...
No Matches
MergeIfs.cpp
Go to the documentation of this file.
1//===- MergeIfs.cpp -------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Dialect/SCF/IR/SCF.h"
12#include "llvm/Support/Debug.h"
13
14#define DEBUG_TYPE "arc-merge-ifs"
15
16namespace circt {
17namespace arc {
18#define GEN_PASS_DEF_MERGEIFSPASS
19#include "circt/Dialect/Arc/ArcPasses.h.inc"
20} // namespace arc
21} // namespace circt
22
23using namespace mlir;
24using namespace circt;
25using namespace arc;
26
namespace {
/// Pass that moves operations closer to their users and merges adjacent
/// `scf.if` operations that share the same condition. Runs each single-block
/// region to a fixed point.
struct MergeIfsPass : public arc::impl::MergeIfsPassBase<MergeIfsPass> {
  void runOnOperation() override;
  /// Run the sink/merge combination on one block until no more changes occur.
  void runOnBlock(Block &rootBlock);
  /// Sink ops towards their users within `rootBlock`; sets `anyChanges`.
  void sinkOps(Block &rootBlock);
  /// Merge adjacent same-condition `scf.if` ops; sets `anyChanges`.
  void mergeIfs(Block &rootBlock);

private:
  // Set by sinkOps/mergeIfs when they modify the IR; drives the fixed-point
  // iteration in runOnBlock.
  bool anyChanges;
};
} // namespace
38
39void MergeIfsPass::runOnOperation() {
40 // Go through the regions recursively, from outer regions to nested regions,
41 // and try to move/sink/merge ops in each.
42 getOperation()->walk<WalkOrder::PreOrder>([&](Region *region) {
43 if (region->hasOneBlock() && mlir::mayHaveSSADominance(*region))
44 runOnBlock(region->front());
45 });
46}
47
48/// Iteratively sink ops into block, move them closer to their uses, and merge
49/// adjacent `scf.if` operations.
50void MergeIfsPass::runOnBlock(Block &rootBlock) {
51 LLVM_DEBUG(llvm::dbgs() << "Running on block in "
52 << rootBlock.getParentOp()->getName() << "\n");
53 do {
54 ++numIterations;
55 anyChanges = false;
56 sinkOps(rootBlock);
57 mergeIfs(rootBlock);
58 } while (anyChanges);
59}
60
61/// Return the state/memory value being written by an op.
62static Value getPointerWrittenByOp(Operation *op) {
63 if (auto write = dyn_cast<StateWriteOp>(op))
64 return write.getState();
65 if (auto write = dyn_cast<MemoryWriteOp>(op))
66 return write.getMemory();
67 return {};
68}
69
70/// Return the state/memory value being read by an op.
71static Value getPointerReadByOp(Operation *op) {
72 if (auto read = dyn_cast<StateReadOp>(op))
73 return read.getState();
74 if (auto read = dyn_cast<MemoryReadOp>(op))
75 return read.getMemory();
76 return {};
77}
78
79/// Check if an operation has side effects, ignoring any nested ops. This is
80/// useful if we're traversing all nested ops anyway, and we are only interested
81/// in the current op's side effects.
82static bool hasSideEffects(Operation *op) {
83 if (auto memEffects = dyn_cast<MemoryEffectOpInterface>(op))
84 return !memEffects.hasNoEffect();
85 return !op->hasTrait<OpTrait::HasRecursiveMemoryEffects>();
86}
87
88namespace {
89/// An integer indicating the position of an operation in its parent block. The
90/// first field is the initial order/position assigned. The second field is used
91/// to order ops that were moved to the same location, which makes them have the
92/// same first field.
93using OpOrder = std::pair<unsigned, unsigned>;
94
95/// A helper that tracks an op and its order, and allows for convenient
96/// substitution with another op that has a higher/lower order.
97struct OpAndOrder {
98 Operation *op = nullptr;
99 OpOrder order = {0, 0};
100
101 explicit operator bool() const { return op; }
102
103 /// Assign `other` if its order is lower than this op, or this op is null.
104 void minimize(const OpAndOrder &other) {
105 if (!op || (other.op && other.order < order))
106 *this = other;
107 }
108
109 /// Assign `other` if its order is higher than this op, or this op is null.
110 void maximize(const OpAndOrder &other) {
111 if (!op || (other.op && other.order > order))
112 *this = other;
113 }
114};
115} // namespace
116
117/// Sink operations as close to their users as possible.
118void MergeIfsPass::sinkOps(Block &rootBlock) {
119 // A numeric position assigned to ops as we encounter them. Ops at the end of
120 // the block get the lowest order number, ops at the beginning the highest.
121 DenseMap<Operation *, OpOrder> opOrder;
122 // A lookup table that indicates where ops should be inserted. This is used to
123 // maintain the original op order if multiple ops pile up before the same
124 // other op that blocks their move.
125 DenseMap<Operation *, Operation *> insertionPoints;
126 // The write ops to each state/memory pointer we've seen so far. ("Next"
127 // because we run from the end to the beginning of the block.)
128 DenseMap<Value, Operation *> nextWrite;
129 // The most recent op that has an unknown (non-read/write) side-effect.
130 Operation *nextSideEffect = nullptr;
131
132 for (auto &op : llvm::make_early_inc_range(llvm::reverse(rootBlock))) {
133 // Assign an order to this op.
134 auto order = OpOrder{opOrder.size() + 1, 0};
135 opOrder[&op] = order;
136
137 // Analyze the side effects in the op.
138 op.walk([&](Operation *subOp) {
139 if (auto ptr = getPointerWrittenByOp(subOp))
140 nextWrite[ptr] = &op;
141 else if (!isa<StateReadOp, MemoryReadOp>(subOp) && hasSideEffects(subOp))
142 nextSideEffect = &op;
143 });
144
145 // Determine how much the op can be moved.
146 OpAndOrder moveLimit;
147 if (auto ptr = getPointerReadByOp(&op)) {
148 // Don't move across writes to the same state/memory.
149 if (auto *write = nextWrite.lookup(ptr))
150 moveLimit.maximize({write, opOrder.lookup(write)});
151 // Don't move across general side-effecting ops.
152 if (nextSideEffect)
153 moveLimit.maximize({nextSideEffect, opOrder.lookup(nextSideEffect)});
154 } else if (isa<StateWriteOp, MemoryWriteOp>(&op) || nextSideEffect == &op) {
155 // Don't move writes or side-effecting ops.
156 continue;
157 }
158
159 // Find the block that contains all uses.
160 Block *allUsesInBlock = nullptr;
161 for (auto *user : op.getUsers()) {
162 // If this user is directly in the root block there's no chance of sinking
163 // the current op anywhere.
164 if (user->getBlock() == &rootBlock) {
165 allUsesInBlock = nullptr;
166 break;
167 }
168
169 // Find the operation in the root block that contains this user.
170 while (user->getParentOp()->getBlock() != &rootBlock)
171 user = user->getParentOp();
172 assert(user);
173
174 // Check that all users sit in the same op in the root block.
175 if (!allUsesInBlock) {
176 allUsesInBlock = user->getBlock();
177 } else if (allUsesInBlock != user->getBlock()) {
178 allUsesInBlock = nullptr;
179 break;
180 }
181 }
182
183 // If no single block exists that contains all uses, find the earliest op in
184 // the root block that uses the current op.
185 OpAndOrder earliest;
186 if (allUsesInBlock) {
187 earliest.op = allUsesInBlock->getParentOp();
188 earliest.order = opOrder.lookup(earliest.op);
189 } else {
190 for (auto *user : op.getUsers()) {
191 while (user->getBlock() != &rootBlock)
192 user = user->getParentOp();
193 assert(user);
194 earliest.maximize({user, opOrder.lookup(user)});
195 }
196 }
197
198 // Ensure we don't move past the move limit imposed by side effects.
199 earliest.maximize(moveLimit);
200 if (!earliest)
201 continue;
202
203 // Either move the op inside the single block that contains all uses, or
204 // move it to just before its earliest user.
205 if (allUsesInBlock && allUsesInBlock->getParentOp() == earliest.op) {
206 op.moveBefore(allUsesInBlock, allUsesInBlock->begin());
207 ++numOpsSunk;
208 anyChanges = true;
209 LLVM_DEBUG(llvm::dbgs() << "- Sunk " << op << "\n");
210 } else {
211 // Insert above other ops that we have already moved to this earliest op.
212 // This ensures the original op order is maintained and we are not
213 // spuriously flipping ops around. This also works without the
214 // `insertionPoint` lookup, but can cause significant linear scanning to
215 // find the op before which we want to insert.
216 auto &insertionPoint = insertionPoints[earliest.op];
217 if (insertionPoint) {
218 auto order = opOrder.lookup(insertionPoint);
219 assert(order.first == earliest.order.first);
220 assert(order.second >= earliest.order.second);
221 earliest.op = insertionPoint;
222 earliest.order = order;
223 }
224 while (auto *prevOp = earliest.op->getPrevNode()) {
225 auto order = opOrder.lookup(prevOp);
226 if (order.first != earliest.order.first)
227 break;
228 assert(order.second > earliest.order.second);
229 earliest.op = prevOp;
230 earliest.order = order;
231 }
232 insertionPoint = earliest.op;
233
234 // Only move if the op isn't already in the right spot.
235 if (op.getNextNode() != earliest.op) {
236 LLVM_DEBUG(llvm::dbgs() << "- Moved " << op << "\n");
237 op.moveBefore(earliest.op);
238 ++numOpsMovedToUser;
239 anyChanges = true;
240 }
241
242 // Update the current op's order to reflect where it has been inserted.
243 // This ensures that later moves to the same pile of moved ops do not
244 // reorder the operations.
245 order = earliest.order;
246 assert(order.second < unsigned(-1));
247 ++order.second;
248 opOrder[&op] = order;
249 }
250 }
251}
252
/// Merge adjacent `scf.if` operations in `rootBlock` that have the same
/// condition, no results, and matching presence/absence of an else region.
/// Ops sitting between two such ifs are hoisted above the first if when side
/// effects permit, so the ifs become adjacent and can be merged. Sets
/// `anyChanges` when a merge happens.
void MergeIfsPass::mergeIfs(Block &rootBlock) {
  // Pointers written/read anywhere inside the previous `scf.if`; reused
  // across iterations to avoid reallocation.
  DenseSet<Value> prevIfWrites, prevIfReads;

  scf::IfOp lastOp;
  for (auto ifOp : rootBlock.getOps<scf::IfOp>()) {
    // `prevIfOp` is the if seen in the previous iteration (if any).
    auto prevIfOp = std::exchange(lastOp, ifOp);
    if (!prevIfOp)
      continue;

    // Only handle simple cases for now. (Same condition, no results, and both
    // ifs either have or don't have an else block.)
    if (ifOp.getCondition() != prevIfOp.getCondition())
      continue;
    if (ifOp.getNumResults() != 0 || prevIfOp.getNumResults() != 0)
      continue;
    if (ifOp.getElseRegion().empty() != prevIfOp.getElseRegion().empty())
      continue;

    // Try to move ops in between the `scf.if` ops above the previous `scf.if`
    // in order to make them immediately adjacent.
    if (ifOp->getPrevNode() != prevIfOp) {
      // Determine the side effects inside the previous if op.
      bool prevIfHasSideEffects = false;
      prevIfWrites.clear();
      prevIfReads.clear();
      prevIfOp.walk([&](Operation *op) {
        if (auto ptr = getPointerWrittenByOp(op))
          prevIfWrites.insert(ptr);
        else if (auto ptr = getPointerReadByOp(op))
          prevIfReads.insert(ptr);
        else if (!prevIfHasSideEffects && hasSideEffects(op))
          prevIfHasSideEffects = true;
      });

      // Check if it is legal to throw all ops over the previous `scf.if` op,
      // given the side effects. We don't move the ops yet to ensure we can move
      // *all* of them at once afterwards. Otherwise this optimization would
      // race with the sink-to-users optimization.
      bool allMovable = true;
      for (auto &op : llvm::make_range(Block::iterator(prevIfOp->getNextNode()),
                                       Block::iterator(ifOp))) {
        auto result = op.walk([&](Operation *subOp) {
          if (auto ptr = getPointerWrittenByOp(subOp)) {
            // We can't move writes over writes or reads of the same state.
            if (prevIfWrites.contains(ptr) || prevIfReads.contains(ptr))
              return WalkResult::interrupt();
          } else if (auto ptr = getPointerReadByOp(subOp)) {
            // We can't move reads over writes to the same state.
            if (prevIfWrites.contains(ptr))
              return WalkResult::interrupt();
          } else if (hasSideEffects(subOp)) {
            // We can't move side-effecting ops over other side-effecting ops.
            if (prevIfHasSideEffects)
              return WalkResult::interrupt();
          }
          return WalkResult::advance();
        });
        if (result.wasInterrupted()) {
          allMovable = false;
          break;
        }
      }
      if (!allMovable)
        continue;

      // At this point we know that all ops can be moved. Do so.
      while (auto *op = prevIfOp->getNextNode()) {
        if (op == ifOp)
          break;
        LLVM_DEBUG(llvm::dbgs() << "- Moved before if " << *op << "\n");
        op->moveBefore(prevIfOp);
        ++numOpsMovedFromBetweenIfs;
      }
    }

    // Merge the then-blocks: drop the previous if's yield terminator, then
    // splice its ops to the front of the current if's then-block.
    prevIfOp.thenYield().erase();
    ifOp.thenBlock()->getOperations().splice(
        ifOp.thenBlock()->begin(), prevIfOp.thenBlock()->getOperations());

    // Merge the else-blocks if present.
    if (ifOp.elseBlock()) {
      prevIfOp.elseYield().erase();
      ifOp.elseBlock()->getOperations().splice(
          ifOp.elseBlock()->begin(), prevIfOp.elseBlock()->getOperations());
    }

    // Clean up the now-empty previous if.
    prevIfOp.erase();
    anyChanges = true;
    ++numIfsMerged;
    LLVM_DEBUG(llvm::dbgs() << "- Merged adjacent if ops\n");
  }
}
assert(baseType &&"element must be base type")
static bool hasSideEffects(Operation *op)
Check if an operation has side effects, ignoring any nested ops.
Definition MergeIfs.cpp:82
static Value getPointerWrittenByOp(Operation *op)
Return the state/memory value being written by an op.
Definition MergeIfs.cpp:62
static Value getPointerReadByOp(Operation *op)
Return the state/memory value being read by an op.
Definition MergeIfs.cpp:71
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.