CIRCT 22.0.0git
Loading...
Searching...
No Matches
UnrollLoops.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
14#include "mlir/Analysis/CFGLoopInfo.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/IR/Dominance.h"
17#include "mlir/IR/IRMapping.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/Support/Debug.h"
22
23#define DEBUG_TYPE "llhd-unroll-loops"
24
25namespace circt {
26namespace llhd {
27#define GEN_PASS_DEF_UNROLLLOOPSPASS
28#include "circt/Dialect/LLHD/Transforms/LLHDPasses.h.inc"
29} // namespace llhd
30} // namespace circt
31
32using namespace mlir;
33using namespace circt;
34using namespace llhd;
35using llvm::SmallDenseSet;
36using llvm::SmallSetVector;
37
38//===----------------------------------------------------------------------===//
39// Utilities
40//===----------------------------------------------------------------------===//
41
42/// Clone a list of blocks into a region before the given block.
43///
44/// See `Region::cloneInto` for the original code that clones an entire region.
45static void cloneBlocks(ArrayRef<Block *> blocks, Region &region,
46 Region::iterator before, IRMapping &mapper) {
47 // If the list is empty there is nothing to clone.
48 if (blocks.empty())
49 return;
50
51 // First clone all the blocks and block arguments and map them, but don't yet
52 // clone the operations, as they may otherwise add a use to a block that has
53 // not yet been mapped
54 SmallVector<Block *> newBlocks;
55 newBlocks.reserve(blocks.size());
56 for (auto *block : blocks) {
57 auto *newBlock = new Block();
58 mapper.map(block, newBlock);
59 for (auto arg : block->getArguments())
60 mapper.map(arg, newBlock->addArgument(arg.getType(), arg.getLoc()));
61 region.getBlocks().insert(before, newBlock);
62 newBlocks.push_back(newBlock);
63 }
64
65 // Now follow up with creating the operations, but don't yet clone their
66 // regions, nor set their operands. Setting the successors is safe as all have
67 // already been mapped. We are essentially just creating the operation results
68 // to be able to map them. Cloning the operands and region as well would lead
69 // to uses of operations not yet mapped.
70 auto cloneOptions =
71 Operation::CloneOptions::all().cloneRegions(false).cloneOperands(false);
72 for (auto [oldBlock, newBlock] : llvm::zip(blocks, newBlocks))
73 for (auto &op : *oldBlock)
74 newBlock->push_back(op.clone(mapper, cloneOptions));
75
76 // Finally now that all operation results have been mapped, set the operands
77 // and clone the regions.
78 SmallVector<Value> operands;
79 for (auto [oldBlock, newBlock] : llvm::zip(blocks, newBlocks)) {
80 for (auto [oldOp, newOp] : llvm::zip(*oldBlock, *newBlock)) {
81 operands.resize(oldOp.getNumOperands());
82 llvm::transform(
83 oldOp.getOperands(), operands.begin(),
84 [&](Value operand) { return mapper.lookupOrDefault(operand); });
85 newOp.setOperands(operands);
86 for (auto [oldRegion, newRegion] :
87 llvm::zip(oldOp.getRegions(), newOp.getRegions()))
88 oldRegion.cloneInto(&newRegion, mapper);
89 }
90 }
91}
92
93//===----------------------------------------------------------------------===//
94// Loop Unroller
95//===----------------------------------------------------------------------===//
96
97namespace {
98/// A data structure tracking information on a single loop.
99struct Loop {
100 Loop(unsigned loopId, CFGLoop &cfgLoop) : loopId(loopId), cfgLoop(cfgLoop) {}
101 bool failMatch(const Twine &msg) const;
102 bool match();
103 void unroll(CFGLoopInfo &cfgLoopInfo);
104
105 /// A numeric identifier for debugging purposes.
106 unsigned loopId;
107 /// Loop analysis information about this specific loop.
108 CFGLoop &cfgLoop;
109 /// The CFG edge exiting the loop.
110 BlockOperand *exitEdge = nullptr;
111 /// The SSA value holding the exit condition.
112 Value exitCondition;
113 /// Whether the exit condition is inverted, i.e. the contination condition.
114 bool exitInverted;
115 /// The induction variable.
116 Value indVar;
117 /// The updated induction variable passed into the next loop iteration.
118 Value indVarNext;
119 /// The continuation predicate. The loop continues until the induction
120 /// variable compared against the end bound no longer matches this predicate.
121 comb::ICmpPredicate predicate;
122 /// The induction variable increment.
123 APInt indVarIncrement;
124 /// The initial value for the induction variable.
125 APInt beginBound;
126 /// The final value for the induction variable.
127 APInt endBound;
128 /// The number of iterations of the loop.
129 unsigned tripCount = 0;
130};
131} // namespace
132
133static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Loop &loop) {
134 os << "#" << loop.loopId << " from ";
135 loop.cfgLoop.getHeader()->printAsOperand(os);
136 os << " to ";
137 loop.cfgLoop.getLoopLatch()->printAsOperand(os);
138 return os;
139}
140
141/// Helper to print a debug message on match failure and return false.
142bool Loop::failMatch(const Twine &msg) const {
143 LLVM_DEBUG(llvm::dbgs() << "- Ignoring loop " << *this << ": " << msg
144 << "\n");
145 return false;
146}
147
148/// Check that the loop matches the specific pattern we understand, and extract
149/// the loop condition and induction variable.
150bool Loop::match() {
151 // Ensure that there is a unique exit point and condition for the loop.
152 SmallVector<BlockOperand *> exits;
153 for (auto *block : cfgLoop.getBlocks())
154 for (auto &edge : block->getTerminator()->getBlockOperands())
155 if (!cfgLoop.contains(edge.get()))
156 exits.push_back(&edge);
157 if (exits.size() != 1)
158 return failMatch("multiple exits");
159 exitEdge = exits.back();
160
161 // The terminator doing the exit must be a conditional branch.
162 auto exitBranch = dyn_cast<cf::CondBranchOp>(exitEdge->getOwner());
163 if (!exitBranch)
164 return failMatch("unsupported exit branch");
165 exitCondition = exitBranch.getCondition();
166 exitInverted = exitEdge->getOperandNumber() == 1;
167
168 // Determine one of the loop bounds and the induction variable based on the
169 // exit condition.
170 if (auto icmpOp = exitCondition.getDefiningOp<comb::ICmpOp>()) {
171 IntegerAttr boundAttr;
172 if (!matchPattern(icmpOp.getRhs(), m_Constant(&boundAttr)))
173 return failMatch("non-constant loop bound");
174 indVar = icmpOp.getLhs();
175 predicate = icmpOp.getPredicate();
176 endBound = boundAttr.getValue();
177 } else {
178 return failMatch("unsupported exit condition");
179 }
180
181 // If the exit condition is not inverted, the predicate is the exit predicate.
182 // Negate it such that we have a continuation predicate.
183 if (!exitInverted)
184 predicate = comb::ICmpOp::getNegatedPredicate(predicate);
185
186 // Determine the initial and next value of the induction variable.
187 auto *header = cfgLoop.getHeader();
188 auto *latch = cfgLoop.getLoopLatch();
189 auto indVarArg = dyn_cast<BlockArgument>(indVar);
190 if (!indVarArg || indVarArg.getOwner() != header)
191 return failMatch("induction variable is not a header block argument");
192 IntegerAttr beginBoundAttr;
193 for (auto &pred : header->getUses()) {
194 auto branchOp = dyn_cast<BranchOpInterface>(pred.getOwner());
195 if (!branchOp)
196 return failMatch("header predecessor terminator is not a branch op");
197 auto indVarValue = branchOp.getSuccessorOperands(
198 pred.getOperandNumber())[indVarArg.getArgNumber()];
199 IntegerAttr boundAttr;
200 if (pred.getOwner()->getBlock() == latch) {
201 indVarNext = indVarValue;
202 } else if (matchPattern(indVarValue, m_Constant(&boundAttr))) {
203 if (!beginBoundAttr)
204 beginBoundAttr = boundAttr;
205 else if (boundAttr != beginBoundAttr)
206 return failMatch("multiple initial bounds");
207 } else {
208 return failMatch("unsupported induction variable value");
209 }
210 }
211 if (!beginBoundAttr)
212 return failMatch("no initial bound");
213 beginBound = beginBoundAttr.getValue();
214
215 // Pattern match the increment operation on the induction variable.
216 if (auto addOp = indVarNext.getDefiningOp<comb::AddOp>();
217 addOp && addOp.getNumOperands() == 2) {
218 if (addOp.getOperand(0) != indVarArg)
219 return failMatch("increment LHS not the induction variable");
220 IntegerAttr incAttr;
221 if (!matchPattern(addOp.getOperand(1), m_Constant(&incAttr)))
222 return failMatch("increment RHS non-constant");
223 indVarIncrement = incAttr.getValue();
224 } else {
225 return failMatch("unsupported increment");
226 }
227
228 // Determine the trip count and loop behavior. We're very picky for now.
229 // for (unsigned i = 0; i < N; i += 1) with N < 1024
230 if (predicate == comb::ICmpPredicate::ult && indVarIncrement == 1 &&
231 beginBound == 0 && endBound.ult(1024)) {
232 tripCount = endBound.getZExtValue();
233 return true;
234 }
235 // for (signed i = 0; i < N; i += 1) with 0 <= N < 1024
236 if (predicate == comb::ICmpPredicate::slt && indVarIncrement == 1 &&
237 beginBound == 0 && !endBound.isNegative() && endBound.slt(1024)) {
238 tripCount = endBound.getZExtValue();
239 return true;
240 }
241 // for (signless i = N; i == N; i += S) with S != 0
242 if (predicate == comb::ICmpPredicate::eq && indVarIncrement != 0 &&
243 beginBound == endBound) {
244 tripCount = 1;
245 return true;
246 }
247 return failMatch("unsupported loop bounds");
248}
249
250/// Unroll the loop by cloning its body blocks and replacing the induction
251/// variable with constant iteration indices.
252void Loop::unroll(CFGLoopInfo &cfgLoopInfo) {
253 assert(beginBound == 0 && !endBound.isNegative() && indVarIncrement == 1);
254 LLVM_DEBUG(llvm::dbgs() << "- Unrolling loop " << *this << "\n");
255 UnusedOpPruner pruner;
256
257 // Sort the blocks in the body. This is not strictly necessary, but makes the
258 // pass a lot easier to reason about in tests.
259 auto *header = cfgLoop.getHeader();
260 SmallVector<Block *> orderedBody;
261 for (auto &block : *header->getParent())
262 if (cfgLoop.contains(&block))
263 orderedBody.push_back(&block);
264
265 // Copy the loop body for every iteration of the loop.
266 auto *latch = cfgLoop.getLoopLatch();
267 OpBuilder builder(indVar.getContext());
268 auto indValue = beginBound;
269 for (unsigned trip = 0; trip < tripCount; ++trip) {
270 // Clone the loop body.
271 IRMapping mapper;
272 cloneBlocks(orderedBody, *header->getParent(), header->getIterator(),
273 mapper);
274 auto *clonedHeader = mapper.lookup(header);
275 auto *clonedTail = mapper.lookup(latch);
276
277 // Replace the induction variable with the concrete value.
278 auto iterIndVar = mapper.lookup(indVar);
279 pruner.eraseLaterIfUnused(iterIndVar);
280 builder.setInsertionPointAfterValue(iterIndVar);
281 iterIndVar.replaceAllUsesWith(
282 hw::ConstantOp::create(builder, iterIndVar.getLoc(), indValue));
283
284 // Update all edges to the original loop header to point to the cloned loop
285 // header. Leave the original back-edge untouched.
286 for (auto &blockOperand : llvm::make_early_inc_range(header->getUses()))
287 if (blockOperand.getOwner()->getBlock() != latch)
288 blockOperand.set(clonedHeader);
289
290 // Update the back-edge in the cloned latch to point to the original loop
291 // header, i.e. the next iteration, instead of the cloned loop header.
292 for (auto &blockOperand : clonedTail->getTerminator()->getBlockOperands())
293 if (blockOperand.get() == clonedHeader)
294 blockOperand.set(header);
295
296 // Remove the exit edge in the cloned body, since we statically know that
297 // the loop will continue.
298 auto exitBranchOp =
299 cast<cf::CondBranchOp>(mapper.lookup(exitEdge->getOwner()));
300 Block *continueDest = exitBranchOp.getTrueDest();
301 ValueRange continueDestOperands = exitBranchOp.getTrueDestOperands();
302 if (exitEdge->getOperandNumber() == 0) {
303 continueDest = exitBranchOp.getFalseDest();
304 continueDestOperands = exitBranchOp.getFalseDestOperands();
305 }
306 builder.setInsertionPoint(exitBranchOp);
307 cf::BranchOp::create(builder, exitBranchOp.getLoc(), continueDest,
308 continueDestOperands);
309 pruner.eraseLaterIfUnused(exitBranchOp.getOperands());
310 exitBranchOp.erase();
311
312 // Add the new blocks to the loop body.
313 for (auto *block : orderedBody) {
314 auto *newBlock = mapper.lookup(block);
315 cfgLoop.addBasicBlockToLoop(newBlock, cfgLoopInfo);
316 }
317
318 // Increment the induction variable value.
319 indValue += indVarIncrement;
320 }
321
322 // Now that the loop body has been cloned once for each trip throughout the
323 // loop, we can clean up the final iteration by always breaking out of the
324 // loop. Start by replacing the induction variable with the final value.
325 pruner.eraseLaterIfUnused(indVar);
326 builder.setInsertionPointAfterValue(indVar);
327 indVar.replaceAllUsesWith(
328 hw::ConstantOp::create(builder, indVar.getLoc(), indValue));
329 indVar = {};
330
331 // Remove the continue edge of the exit branch in the loop body, since we
332 // statically know that the loop will exit.
333 auto exitBranchOp = cast<cf::CondBranchOp>(exitEdge->getOwner());
334 Block *exitDest = exitBranchOp.getTrueDest();
335 ValueRange exitDestOperands = exitBranchOp.getTrueDestOperands();
336 if (exitEdge->getOperandNumber() == 1) {
337 exitDest = exitBranchOp.getFalseDest();
338 exitDestOperands = exitBranchOp.getFalseDestOperands();
339 }
340 builder.setInsertionPoint(exitBranchOp);
341 cf::BranchOp::create(builder, exitBranchOp.getLoc(), exitDest,
342 exitDestOperands);
343 pruner.eraseLaterIfUnused(exitBranchOp.getOperands());
344 exitBranchOp.erase();
345 exitEdge = nullptr;
346
347 // Prune any body blocks that have become unreachable.
348 SmallPtrSet<Block *, 8> blocksToPrune;
349 for (auto *block : cfgLoop.getBlocks())
350 if (block->use_empty())
351 blocksToPrune.insert(block);
352 while (!blocksToPrune.empty()) {
353 auto *block = *blocksToPrune.begin();
354 blocksToPrune.erase(block);
355 if (!block->use_empty())
356 continue;
357 for (auto *succ : block->getSuccessors())
358 if (cfgLoop.contains(succ))
359 blocksToPrune.insert(succ);
360 block->dropAllDefinedValueUses();
361 cfgLoopInfo.removeBlock(block);
362 block->erase();
363 }
364
365 // Remove any unused operations and block arguments.
366 pruner.eraseNow();
367
368 // Collapse trivial branches to avoid carrying a lot of useless blocks around
369 // especially when unrolling nested loops.
370 for (auto &block : *header->getParent()) {
371 if (!cfgLoop.contains(&block))
372 continue;
373 while (true) {
374 auto branchOp = dyn_cast<cf::BranchOp>(block.getTerminator());
375 if (!branchOp)
376 break;
377 auto *otherBlock = branchOp.getDest();
378 if (!cfgLoop.contains(otherBlock) || !otherBlock->getSinglePredecessor())
379 break;
380 for (auto [blockArg, branchArg] :
381 llvm::zip(otherBlock->getArguments(), branchOp.getDestOperands()))
382 blockArg.replaceAllUsesWith(branchArg);
383 block.getOperations().splice(branchOp->getIterator(),
384 otherBlock->getOperations());
385 branchOp.erase();
386 cfgLoopInfo.removeBlock(otherBlock);
387 otherBlock->erase();
388 }
389 }
390}
391
392//===----------------------------------------------------------------------===//
393// Pass Infrastructure
394//===----------------------------------------------------------------------===//
395
396namespace {
397struct UnrollLoopsPass
398 : public llhd::impl::UnrollLoopsPassBase<UnrollLoopsPass> {
399 void runOnOperation() override;
400 void runOnOperation(CombinationalOp op);
401};
402} // namespace
403
404void UnrollLoopsPass::runOnOperation() {
405 for (auto op : getOperation().getOps<CombinationalOp>())
406 runOnOperation(op);
407}
408
409void UnrollLoopsPass::runOnOperation(CombinationalOp op) {
410 // There's nothing to do if we only have a single block. MLIR even refuses to
411 // compute a dominator tree in that case.
412 if (op.getBody().hasOneBlock())
413 return;
414
415 // Find the loops.
416 LLVM_DEBUG(llvm::dbgs() << "Unrolling loops in " << op.getLoc() << "\n");
417 DominanceInfo domInfo(op);
418 CFGLoopInfo cfgLoopInfo(domInfo.getDomTree(&op.getBody()));
419
420 // We only support simple loops where there is a single back-edge to the
421 // header, and the latch block has a back-edge to a single header. Create a
422 // data structure for each loop we can potentially unroll. The loops are in
423 // preorder, with outer loops appearing before their child loops.
424 SmallVector<Loop> loops;
425 for (auto *cfgLoop : cfgLoopInfo.getLoopsInPreorder()) {
426 // To simplify unrolling we need a unique latch block branching back to the
427 // header.
428 auto *header = cfgLoop->getHeader();
429 auto *latch = cfgLoop->getLoopLatch();
430 if (!latch)
431 continue;
432
433 LLVM_DEBUG({
434 llvm::dbgs() << "- ";
435 cfgLoop->print(llvm::dbgs(), false, false);
436 llvm::dbgs() << "\n";
437 });
438 Loop loop(loops.size(), *cfgLoop);
439
440 // Ensure that the header block is only a header for the current loop. This
441 // simplifies unrolling.
442 auto *parent = cfgLoop->getParentLoop();
443 while (parent && parent->getHeader() != header)
444 parent = parent->getParentLoop();
445 if (parent) {
446 loop.failMatch("header block shared across multiple loops");
447 continue;
448 }
449
450 // Ensure that the latch block is only a latch for the current loop. This
451 // simplifies unrolling.
452 parent = cfgLoop->getParentLoop();
453 while (parent && !parent->isLoopLatch(latch))
454 parent = parent->getParentLoop();
455 if (parent) {
456 loop.failMatch("latch block shared across multiple loops");
457 continue;
458 }
459
460 // Check if the loop body matches the pattern we can unroll.
461 if (loop.match())
462 loops.push_back(std::move(loop));
463 }
464
465 if (loops.empty())
466 return;
467
468 // Dump some debugging information about the loops we've found.
469 LLVM_DEBUG({
470 auto &os = llvm::dbgs();
471 for (auto &loop : loops) {
472 os << "- Loop " << loop << ":\n";
473 os << " - ";
474 loop.cfgLoop.print(os, false, false);
475 os << "\n";
476 os << " - Exit: ";
477 loop.exitEdge->get()->printAsOperand(os);
478 os << " if ";
479 if (loop.exitInverted)
480 os << "not ";
481 os << loop.exitCondition;
482 os << "\n";
483 os << " - Induction variable: ";
484 loop.indVar.printAsOperand(os, OpPrintingFlags().useLocalScope());
485 os << ", from " << loop.beginBound << ", while " << loop.predicate << " "
486 << loop.endBound << ", increment " << loop.indVarIncrement << "\n";
487 os << " - Trip count: " << loop.tripCount << "\n";
488 }
489 });
490
491 // Unroll the loops. Handling the loops in reverse unrolls inner loops before
492 // their parent loops.
493 for (auto &loop : llvm::reverse(loops))
494 loop.unroll(cfgLoopInfo);
495}
assert(baseType &&"element must be base type")
static void cloneBlocks(ArrayRef< Block * > blocks, Region &region, Region::iterator before, IRMapping &mapper)
Clone a list of blocks into a region before the given block.
create(data_type, value)
Definition hw.py:433
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseLaterIfUnused(Operation *op)
Mark an op the be erased later if it is unused at that point.
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.