14#include "mlir/Analysis/CFGLoopInfo.h"
15#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
16#include "mlir/IR/Dominance.h"
17#include "mlir/IR/IRMapping.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/Pass/Pass.h"
20#include "llvm/ADT/PostOrderIterator.h"
21#include "llvm/Support/Debug.h"
23#define DEBUG_TYPE "llhd-unroll-loops"
27#define GEN_PASS_DEF_UNROLLLOOPSPASS
28#include "circt/Dialect/LLHD/Transforms/LLHDPasses.h.inc"
35using llvm::SmallDenseSet;
36using llvm::SmallSetVector;
// Clone a list of blocks into a region before the given block.
//
// `blocks` are the blocks to copy, `before` is the insertion position inside
// the target region, and `mapper` receives the old->new mapping for every
// cloned block and block argument (and is consulted when remapping operands).
//
// NOTE(review): this listing is an excerpt; several original lines (closing
// braces, the declaration that receives the clone options, the call that
// consumes the operand-remapping lambda) are not visible here.
// NOTE(review): `®ion` in the signature looks like a mis-encoded `&region`
// (the body refers to `region`) -- confirm against the real file.
45static void cloneBlocks(ArrayRef<Block *> blocks, Region ®ion,
46 Region::iterator before, IRMapping &mapper) {
// Pass 1: create one empty block per input block, record the block and
// block-argument mappings in `mapper`, and insert the new block at `before`.
54 SmallVector<Block *> newBlocks;
55 newBlocks.reserve(blocks.size());
56 for (
auto *block : blocks) {
57 auto *newBlock =
new Block();
58 mapper.map(block, newBlock);
59 for (
auto arg : block->getArguments())
60 mapper.map(arg, newBlock->addArgument(arg.getType(), arg.getLoc()));
61 region.getBlocks().insert(before, newBlock);
62 newBlocks.push_back(newBlock);
// Clone options: start from "all" and explicitly disable cloning of regions
// and operands; both are handled separately in pass 3 below.
71 Operation::CloneOptions::all().cloneRegions(
false).cloneOperands(
false);
// Pass 2: clone every operation of each old block into its new counterpart.
72 for (
auto [oldBlock, newBlock] : llvm::zip(blocks, newBlocks))
73 for (
auto &op : *oldBlock)
74 newBlock->push_back(op.clone(mapper, cloneOptions));
// Pass 3: walk old and new ops in lockstep, set the new op's operands to the
// mapped values (lookupOrDefault keeps values that have no mapping), and
// clone each nested region of the old op into the new op.
78 SmallVector<Value> operands;
79 for (
auto [oldBlock, newBlock] : llvm::zip(blocks, newBlocks)) {
80 for (
auto [oldOp, newOp] : llvm::zip(*oldBlock, *newBlock)) {
81 operands.resize(oldOp.getNumOperands());
83 oldOp.getOperands(), operands.begin(),
84 [&](Value operand) { return mapper.lookupOrDefault(operand); });
85 newOp.setOperands(operands);
86 for (
auto [oldRegion, newRegion] :
87 llvm::zip(oldOp.getRegions(), newOp.getRegions()))
88 oldRegion.cloneInto(&newRegion, mapper);
// Members of the `Loop` record. NOTE(review): the enclosing struct/class
// header and several other members used elsewhere in this file are not
// visible in this excerpt.
//
// Constructor: stores the numeric loop id and a reference to the CFG loop.
100 Loop(
unsigned loopId, CFGLoop &cfgLoop) : loopId(loopId), cfgLoop(cfgLoop) {}
// Reports (as a debug message) why this loop is being ignored; callers use
// the result directly as their own return value.
101 bool failMatch(
const Twine &msg)
const;
// Unrolls this loop, updating the given loop info as blocks are added/removed.
103 void unroll(CFGLoopInfo &cfgLoopInfo);
// The single terminator edge that leaves the loop; null until it is assigned.
110 BlockOperand *exitEdge =
nullptr;
// Comparison predicate taken from the comb.icmp that forms the exit condition.
121 comb::ICmpPredicate predicate;
// Constant added to the induction variable to produce its next value.
123 APInt indVarIncrement;
// Number of iterations to unroll; 0 until a supported bound shape is found.
129 unsigned tripCount = 0;
// Stream a short description of a loop: `#<loopId> from ` followed by the
// header block printed as an operand, and then the latch block printed as an
// operand. NOTE(review): the text emitted between the two blocks and the
// return statement are not visible in this excerpt.
133static llvm::raw_ostream &
operator<<(llvm::raw_ostream &os,
const Loop &loop) {
134 os <<
"#" << loop.loopId <<
" from ";
135 loop.cfgLoop.getHeader()->printAsOperand(os);
137 loop.cfgLoop.getLoopLatch()->printAsOperand(os);
// Emit a debug message of the form "- Ignoring loop <loop>: <msg>" through
// LLVM_DEBUG. NOTE(review): the rest of the statement and the return are not
// visible in this excerpt; presumably this returns false so that callers can
// write `return failMatch(...)` -- confirm.
142bool Loop::failMatch(
const Twine &msg)
const {
143 LLVM_DEBUG(llvm::dbgs() <<
"- Ignoring loop " << *
this <<
": " << msg
// Body of the loop-matching routine. NOTE(review): its signature line and a
// number of guards/closing braces are not visible in this excerpt. Every
// unsupported shape is reported through failMatch().
//
// Step 1: collect every terminator edge whose target block is not contained
// in the CFG loop, i.e. the edges leaving the loop.
152 SmallVector<BlockOperand *> exits;
153 for (
auto *block : cfgLoop.getBlocks())
154 for (auto &edge : block->getTerminator()->getBlockOperands())
155 if (!cfgLoop.contains(edge.
get()))
156 exits.push_back(&edge);
// Exactly one exit edge is required. NOTE(review): a loop with zero exit
// edges is also reported as "multiple exits".
157 if (exits.size() != 1)
158 return failMatch(
"multiple exits");
159 exitEdge = exits.back();
// Step 2: the exit edge must be owned by a cf.cond_br (the null check guarding
// the failure below is not visible in this excerpt).
162 auto exitBranch = dyn_cast<cf::CondBranchOp>(exitEdge->getOwner());
164 return failMatch(
"unsupported exit branch");
165 exitCondition = exitBranch.getCondition();
// The exit counts as inverted when the exit edge is block operand number 1 of
// the branch.
166 exitInverted = exitEdge->getOperandNumber() == 1;
// Step 3: the exit condition must be a comb.icmp whose right-hand side is a
// constant; the left-hand side is taken as the induction variable.
170 if (
auto icmpOp = exitCondition.getDefiningOp<comb::ICmpOp>()) {
171 IntegerAttr boundAttr;
172 if (!matchPattern(icmpOp.getRhs(), m_Constant(&boundAttr)))
173 return failMatch(
"non-constant loop bound");
174 indVar = icmpOp.getLhs();
175 predicate = icmpOp.getPredicate();
176 endBound = boundAttr.getValue();
178 return failMatch(
"unsupported exit condition");
// Negate the predicate. NOTE(review): the condition guarding this statement is
// not visible in this excerpt.
184 predicate = comb::ICmpOp::getNegatedPredicate(predicate);
// Step 4: the induction variable must be a block argument of the loop header.
187 auto *header = cfgLoop.getHeader();
188 auto *latch = cfgLoop.getLoopLatch();
189 auto indVarArg = dyn_cast<BlockArgument>(indVar);
190 if (!indVarArg || indVarArg.getOwner() != header)
191 return failMatch(
"induction variable is not a header block argument");
// Step 5: inspect every use of the header block (each incoming branch) and
// look at the value it passes for the induction variable argument.
192 IntegerAttr beginBoundAttr;
193 for (
auto &pred : header->getUses()) {
194 auto branchOp = dyn_cast<BranchOpInterface>(pred.getOwner());
196 return failMatch(
"header predecessor terminator is not a branch op");
197 auto indVarValue = branchOp.getSuccessorOperands(
198 pred.getOperandNumber())[indVarArg.getArgNumber()];
199 IntegerAttr boundAttr;
// The value passed from the latch is the next-iteration value; values passed
// from any other predecessor must be constants that agree on one initial bound.
200 if (pred.getOwner()->getBlock() == latch) {
201 indVarNext = indVarValue;
202 }
else if (matchPattern(indVarValue, m_Constant(&boundAttr))) {
204 beginBoundAttr = boundAttr;
205 else if (boundAttr != beginBoundAttr)
206 return failMatch(
"multiple initial bounds");
208 return failMatch(
"unsupported induction variable value");
212 return failMatch(
"no initial bound");
213 beginBound = beginBoundAttr.getValue();
// Step 6: the next-iteration value must be a two-operand comb.add whose
// operand 0 is the induction variable and whose operand 1 is a constant.
216 if (
auto addOp = indVarNext.getDefiningOp<
comb::AddOp>();
217 addOp && addOp.getNumOperands() == 2) {
218 if (addOp.getOperand(0) != indVarArg)
219 return failMatch(
"increment LHS not the induction variable");
221 if (!matchPattern(addOp.getOperand(1), m_Constant(&incAttr)))
222 return failMatch(
"increment RHS non-constant");
223 indVarIncrement = incAttr.getValue();
225 return failMatch(
"unsupported increment");
// Step 7: derive the trip count for the supported bound shapes.
// Unsigned `ult` loop from 0 in steps of 1 with an end bound below 1024.
230 if (predicate == comb::ICmpPredicate::ult && indVarIncrement == 1 &&
231 beginBound == 0 && endBound.ult(1024)) {
232 tripCount = endBound.getZExtValue();
// Signed `slt` loop from 0 in steps of 1 with a non-negative end bound below
// 1024.
236 if (predicate == comb::ICmpPredicate::slt && indVarIncrement == 1 &&
237 beginBound == 0 && !endBound.isNegative() && endBound.slt(1024)) {
238 tripCount = endBound.getZExtValue();
// `eq` predicate with a non-zero increment and identical begin/end bounds.
// NOTE(review): what is done in this case is not visible in this excerpt.
242 if (predicate == comb::ICmpPredicate::eq && indVarIncrement != 0 &&
243 beginBound == endBound) {
// Any other combination of predicate, bounds and increment is rejected.
247 return failMatch(
"unsupported loop bounds");
// Unroll this loop by cloning its body once per trip in front of the original
// header, then turning the original copy into the final, exiting iteration.
// `cfgLoopInfo` is updated as blocks are added to and removed from the loop.
// NOTE(review): several original lines (mapper declaration, the replacement
// values for the induction variable, closing braces, `continue` statements)
// are not visible in this excerpt.
252void Loop::unroll(CFGLoopInfo &cfgLoopInfo) {
// Only loops counting from 0 in steps of 1 towards a non-negative bound are
// expected here.
253 assert(beginBound == 0 && !endBound.isNegative() && indVarIncrement == 1);
254 LLVM_DEBUG(llvm::dbgs() <<
"- Unrolling loop " << *
this <<
"\n");
// Gather the loop's blocks in the order in which they appear in the region.
259 auto *header = cfgLoop.getHeader();
260 SmallVector<Block *> orderedBody;
261 for (
auto &block : *header->getParent())
262 if (cfgLoop.contains(&block))
263 orderedBody.push_back(&block);
266 auto *latch = cfgLoop.getLoopLatch();
267 OpBuilder builder(indVar.getContext());
// `indValue` tracks the induction variable's value for the current trip.
268 auto indValue = beginBound;
269 for (
unsigned trip = 0; trip < tripCount; ++trip) {
// Clone the body blocks immediately before the original header.
272 cloneBlocks(orderedBody, *header->getParent(), header->getIterator(),
274 auto *clonedHeader = mapper.lookup(header);
275 auto *clonedTail = mapper.lookup(latch);
// Replace all uses of this copy's induction variable (the replacement value is
// not visible in this excerpt; presumably a constant for `indValue`).
278 auto iterIndVar = mapper.lookup(indVar);
280 builder.setInsertionPointAfterValue(iterIndVar);
281 iterIndVar.replaceAllUsesWith(
// Redirect every edge into the original header that does not come from the
// original latch so that it enters the cloned header instead.
286 for (
auto &blockOperand :
llvm::make_early_inc_range(header->getUses()))
287 if (blockOperand.getOwner()->getBlock() != latch)
288 blockOperand.set(clonedHeader);
// Make the cloned latch continue into the original header rather than looping
// back to the cloned header.
292 for (
auto &blockOperand : clonedTail->getTerminator()->getBlockOperands())
293 if (blockOperand.
get() == clonedHeader)
294 blockOperand.set(header);
// Replace the cloned exit cond_br with an unconditional branch to the
// successor that is NOT the exit edge, i.e. the one that stays in the loop.
299 cast<cf::CondBranchOp>(mapper.lookup(exitEdge->getOwner()));
300 Block *continueDest = exitBranchOp.getTrueDest();
301 ValueRange continueDestOperands = exitBranchOp.getTrueDestOperands();
302 if (exitEdge->getOperandNumber() == 0) {
303 continueDest = exitBranchOp.getFalseDest();
304 continueDestOperands = exitBranchOp.getFalseDestOperands();
306 builder.setInsertionPoint(exitBranchOp);
307 cf::BranchOp::create(builder, exitBranchOp.getLoc(), continueDest,
308 continueDestOperands);
310 exitBranchOp.erase();
// Register the cloned blocks as members of this CFG loop.
313 for (
auto *block : orderedBody) {
314 auto *newBlock = mapper.lookup(block);
315 cfgLoop.addBasicBlockToLoop(newBlock, cfgLoopInfo);
// Advance the induction value for the next trip.
319 indValue += indVarIncrement;
// After all trips: replace the uses of the original induction variable (the
// replacement value is not visible in this excerpt).
326 builder.setInsertionPointAfterValue(indVar);
327 indVar.replaceAllUsesWith(
// Replace the original exit cond_br with an unconditional branch along the
// exit edge's successor.
333 auto exitBranchOp = cast<cf::CondBranchOp>(exitEdge->getOwner());
334 Block *exitDest = exitBranchOp.getTrueDest();
335 ValueRange exitDestOperands = exitBranchOp.getTrueDestOperands();
336 if (exitEdge->getOperandNumber() == 1) {
337 exitDest = exitBranchOp.getFalseDest();
338 exitDestOperands = exitBranchOp.getFalseDestOperands();
340 builder.setInsertionPoint(exitBranchOp);
341 cf::BranchOp::create(builder, exitBranchOp.getLoc(), exitDest,
344 exitBranchOp.erase();
// Prune loop blocks that no longer have any uses. Successors inside the loop
// are queued as well and re-checked for uses when they are dequeued.
348 SmallPtrSet<Block *, 8> blocksToPrune;
349 for (
auto *block : cfgLoop.getBlocks())
350 if (block->use_empty())
351 blocksToPrune.insert(block);
352 while (!blocksToPrune.empty()) {
353 auto *block = *blocksToPrune.begin();
354 blocksToPrune.erase(block);
355 if (!block->use_empty())
357 for (
auto *succ : block->getSuccessors())
358 if (cfgLoop.contains(succ))
359 blocksToPrune.insert(succ);
360 block->dropAllDefinedValueUses();
361 cfgLoopInfo.removeBlock(block);
// Merge straight-line chains: for a loop block ending in an unconditional
// branch to another loop block that has a single predecessor, replace the
// target's block arguments by the branch operands, splice the target's
// operations in front of the branch, and drop the target from the loop info.
370 for (
auto &block : *header->getParent()) {
371 if (!cfgLoop.contains(&block))
374 auto branchOp = dyn_cast<cf::BranchOp>(block.getTerminator());
377 auto *otherBlock = branchOp.getDest();
378 if (!cfgLoop.contains(otherBlock) || !otherBlock->getSinglePredecessor())
380 for (
auto [blockArg, branchArg] :
381 llvm::zip(otherBlock->getArguments(), branchOp.getDestOperands()))
382 blockArg.replaceAllUsesWith(branchArg);
383 block.getOperations().splice(branchOp->getIterator(),
384 otherBlock->getOperations());
386 cfgLoopInfo.removeBlock(otherBlock);
// Pass implementation built on the generated UnrollLoopsPassBase. It has the
// pass entry point plus an overload that processes one CombinationalOp.
397struct UnrollLoopsPass
398 :
public llhd::impl::UnrollLoopsPassBase<UnrollLoopsPass> {
// Pass entry point.
399 void runOnOperation()
override;
// Per-op worker for a single CombinationalOp.
400 void runOnOperation(CombinationalOp op);
// Pass entry point: iterate over the CombinationalOps returned by getOps on
// the pass's root operation. NOTE(review): the loop body is not visible in
// this excerpt; presumably it calls the CombinationalOp overload -- confirm.
404void UnrollLoopsPass::runOnOperation() {
405 for (
auto op : getOperation().getOps<CombinationalOp>())
// Find and unroll the loops in the body of a single CombinationalOp.
// NOTE(review): several original lines (guards, `continue`/`return`
// statements, the call that matches each loop, closing braces) are not
// visible in this excerpt.
409void UnrollLoopsPass::runOnOperation(CombinationalOp op) {
// Check for a single-block body (the action taken is not visible here;
// presumably an early return -- confirm).
412 if (op.getBody().hasOneBlock())
416 LLVM_DEBUG(llvm::dbgs() <<
"Unrolling loops in " << op.getLoc() <<
"\n");
// Build dominance information and, from its dominator tree, the CFG loop info
// for the body region.
417 DominanceInfo domInfo(op);
418 CFGLoopInfo cfgLoopInfo(domInfo.getDomTree(&op.getBody()));
// Visit the CFG loops in preorder and collect them as `Loop` records.
424 SmallVector<Loop> loops;
425 for (
auto *cfgLoop : cfgLoopInfo.getLoopsInPreorder()) {
428 auto *header = cfgLoop->getHeader();
429 auto *latch = cfgLoop->getLoopLatch();
// Debug dump of the CFG loop being considered.
434 llvm::dbgs() <<
"- ";
435 cfgLoop->print(llvm::dbgs(),
false,
false);
436 llvm::dbgs() <<
"\n";
// The loop id is the loop's index in `loops`.
438 Loop loop(loops.size(), *cfgLoop);
// Walk up the enclosing loops looking for one with the same header block; the
// failure below reports that case.
442 auto *parent = cfgLoop->getParentLoop();
443 while (parent && parent->getHeader() != header)
444 parent = parent->getParentLoop();
446 loop.failMatch(
"header block shared across multiple loops");
// Likewise look for an enclosing loop for which this loop's latch is also a
// latch.
452 parent = cfgLoop->getParentLoop();
453 while (parent && !parent->isLoopLatch(latch))
454 parent = parent->getParentLoop();
456 loop.failMatch(
"latch block shared across multiple loops");
462 loops.push_back(std::move(loop));
// Debug dump of every collected loop: CFG loop, exit edge target, exit
// condition, induction variable with bounds/predicate/increment, trip count.
470 auto &os = llvm::dbgs();
471 for (
auto &loop : loops) {
472 os <<
"- Loop " << loop <<
":\n";
474 loop.cfgLoop.print(os,
false,
false);
477 loop.exitEdge->get()->printAsOperand(os);
479 if (loop.exitInverted)
481 os << loop.exitCondition;
483 os <<
" - Induction variable: ";
484 loop.indVar.printAsOperand(os, OpPrintingFlags().useLocalScope());
485 os <<
", from " << loop.beginBound <<
", while " << loop.predicate <<
" "
486 << loop.endBound <<
", increment " << loop.indVarIncrement <<
"\n";
487 os <<
" - Trip count: " << loop.tripCount <<
"\n";
// Unroll the collected loops in the reverse of the order they were collected.
493 for (
auto &loop :
llvm::reverse(loops))
494 loop.unroll(cfgLoopInfo);
assert(baseType &&"element must be base type")
static void cloneBlocks(ArrayRef< Block * > blocks, Region &region, Region::iterator before, IRMapping &mapper)
Clone a list of blocks into a region before the given block.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
OS & operator<<(OS &os, const InnerSymTarget &target)
Printing InnerSymTarget's.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseLaterIfUnused(Operation *op)
Mark an op to be erased later if it is unused at that point.
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.