CIRCT 20.0.0git
Loading...
Searching...
No Matches
MaximizeSSA.cpp
Go to the documentation of this file.
1//===- MaximizeSSA.cpp - SSA Maximization Pass ------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the SSA maximization pass as well as utilities
10// for converting a function with standard control flow into maximal SSA form.
11//
12//===----------------------------------------------------------------------===//
13
15#include "mlir/Dialect/Func/IR/FuncOps.h"
16#include "mlir/IR/MLIRContext.h"
17#include "mlir/Pass/Pass.h"
18#include "mlir/Support/LogicalResult.h"
19#include "mlir/Transforms/DialectConversion.h"
20#include "llvm/ADT/STLExtras.h"
21#include "llvm/Support/raw_ostream.h"
22
23namespace circt {
24#define GEN_PASS_DEF_MAXIMIZESSA
25#include "circt/Transforms/Passes.h.inc"
26} // namespace circt
27
28using namespace mlir;
29using namespace circt;
30
31static Block *getDefiningBlock(Value value) {
32 // Value is either a block argument...
33 if (auto blockArg = dyn_cast<BlockArgument>(value); blockArg)
34 return blockArg.getParentBlock();
35
36 // ... or an operation's result
37 auto *defOp = value.getDefiningOp();
38 assert(defOp);
39 return defOp->getBlock();
40}
41
42static LogicalResult addArgToTerminator(Block *block, Block *predBlock,
43 Value value) {
44
45 // Identify terminator branching instruction in predecessor block
46 auto branchOp = dyn_cast<BranchOpInterface>(predBlock->getTerminator());
47 if (!branchOp)
48 return predBlock->getTerminator()->emitError(
49 "Expected terminator operation of block to be a "
50 "branch-like operation");
51
52 // In the predecessor block's terminator, find all successors that equal
53 // the block and add the value to the list of operands it's passed
54 for (auto [idx, succBlock] : llvm::enumerate(branchOp->getSuccessors()))
55 if (succBlock == block)
56 branchOp.getSuccessorOperands(idx).append(value);
57
58 return success();
59}
60
61bool circt::isRegionSSAMaximized(Region &region) {
62
63 // Check whether all operands used within each block are also defined within
64 // the same block
65 for (auto &block : region.getBlocks())
66 for (auto &op : block.getOperations())
67 for (auto operand : op.getOperands())
68 if (getDefiningBlock(operand) != &block)
69 return false;
70
71 return true;
72}
73
75 return true;
76}
78 return true;
79}
80bool circt::SSAMaximizationStrategy::maximizeOp(Operation *op) { return true; }
82 return true;
83}
84
85LogicalResult circt::maximizeSSA(Value value, PatternRewriter &rewriter) {
86
87 // Identify the basic block in which the value is defined
88 Block *defBlock = getDefiningBlock(value);
89
90 // Identify all basic blocks in which the value is used (excluding the
91 // value-defining block)
92 DenseSet<Block *> blocksUsing;
93 for (auto &use : value.getUses()) {
94 auto *block = use.getOwner()->getBlock();
95 if (block != defBlock)
96 blocksUsing.insert(block);
97 }
98
99 // Prepare a stack to iterate over the list of basic blocks that must be
100 // modified for the value to be in maximum SSA form. At all points,
101 // blocksUsing is a non-strict superset of the elements contained in
102 // blocksToVisit
103 SmallVector<Block *> blocksToVisit(blocksUsing.begin(), blocksUsing.end());
104
105 // Backtrack from all blocks using the value to the value-defining basic
106 // block, adding a new block argument for the value along the way. Keep
107 // track of which blocks have already been modified to avoid visiting a
108 // block more than once while backtracking (possible due to branching
109 // control flow)
110 DenseMap<Block *, BlockArgument> blockToArg;
111 while (!blocksToVisit.empty()) {
112 // Pop the basic block at the top of the stack
113 auto *block = blocksToVisit.pop_back_val();
114
115 // Add an argument to the block to hold the value
116 blockToArg[block] =
117 block->addArgument(value.getType(), rewriter.getUnknownLoc());
118
119 // In all unique block predecessors, modify the terminator branching
120 // instruction to also pass the value to the block
121 SmallPtrSet<Block *, 8> uniquePredecessors;
122 for (auto *predBlock : block->getPredecessors()) {
123 // If we have already visited the block predecessor, skip it. It's
124 // possible to get duplicate block predecessors when there exists a
125 // conditional branch with both branches going to the same block e.g.,
126 // cf.cond_br %cond, ^bbx, ^bbx
127 if (auto [_, newPredecessor] = uniquePredecessors.insert(predBlock);
128 !newPredecessor) {
129 continue;
130 }
131
132 // Modify the terminator instruction
133 if (failed(addArgToTerminator(block, predBlock, value)))
134 return failure();
135
136 // Now the predecessor block is using the value, so we must also make sure
137 // to visit it
138 if (predBlock != defBlock)
139 if (auto [_, blockNewlyUsing] = blocksUsing.insert(predBlock);
140 blockNewlyUsing)
141 blocksToVisit.push_back(predBlock);
142 }
143 }
144
145 // Replace all uses of the value with the newly added block arguments
146 SmallVector<Operation *> users;
147 for (auto &use : value.getUses()) {
148 auto *owner = use.getOwner();
149 if (owner->getBlock() != defBlock)
150 users.push_back(owner);
151 }
152 for (auto *user : users)
153 user->replaceUsesOfWith(value, blockToArg[user->getBlock()]);
154
155 return success();
156}
157
158LogicalResult circt::maximizeSSA(Operation *op,
160 PatternRewriter &rewriter) {
161 // Apply SSA maximization on each of the operation's results
162 for (auto res : op->getResults())
163 if (strategy.maximizeResult(res))
164 if (failed(maximizeSSA(res, rewriter)))
165 return failure();
166
167 return success();
168}
169
170LogicalResult circt::maximizeSSA(Block *block,
172 PatternRewriter &rewriter) {
173 // Apply SSA maximization on each of the block's arguments
174 for (auto arg : block->getArguments())
175 if (strategy.maximizeArgument(arg))
176 if (failed(maximizeSSA(arg, rewriter)))
177 return failure();
178
179 // Apply SSA maximization on each of the block's operations
180 for (auto &op : block->getOperations())
181 if (strategy.maximizeOp(&op))
182 if (failed(maximizeSSA(&op, strategy, rewriter)))
183 return failure();
184
185 return success();
186}
187
188LogicalResult circt::maximizeSSA(Region &region,
190 PatternRewriter &rewriter) {
191 // Apply SSA maximization on each of the region's block
192 for (auto &block : region.getBlocks())
193 if (strategy.maximizeBlock(&block))
194 if (failed(maximizeSSA(&block, strategy, rewriter)))
195 return failure();
196
197 return success();
198}
199
200namespace {
201
202struct MaxSSAConversion : public ConversionPattern {
203public:
204 MaxSSAConversion(MLIRContext *context)
205 : ConversionPattern(MatchAnyOpTypeTag(), 1, context) {}
206 LogicalResult
207 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
208 ConversionPatternRewriter &rewriter) const override {
209 LogicalResult conversionStatus = success();
210 rewriter.modifyOpInPlace(op, [&] {
211 for (auto &region : op->getRegions()) {
213 if (failed(maximizeSSA(region, strategy, rewriter)))
214 conversionStatus = failure();
215 }
216 });
217 return conversionStatus;
218 }
219};
220
221struct MaximizeSSAPass : public circt::impl::MaximizeSSABase<MaximizeSSAPass> {
222public:
223 void runOnOperation() override {
224 auto *ctx = &getContext();
225
226 RewritePatternSet patterns{ctx};
227 patterns.add<MaxSSAConversion>(ctx);
228 ConversionTarget target(*ctx);
229
230 // SSA maximization should apply to all region-defining ops.
231 target.markUnknownOpDynamicallyLegal([](Operation *op) {
232 return llvm::all_of(op->getRegions(), isRegionSSAMaximized);
233 });
234
235 // Each region is turned into maximal SSA form independently of the
236 // others. Function signatures are never modified by SSA maximization
237 if (failed(applyPartialConversion(getOperation(), target,
238 std::move(patterns))))
239 signalPassFailure();
240 }
241};
242
243} // namespace
244
245namespace circt {
246std::unique_ptr<mlir::Pass> createMaximizeSSAPass() {
247 return std::make_unique<MaximizeSSAPass>();
248}
249} // namespace circt
assert(baseType &&"element must be base type")
static LogicalResult addArgToTerminator(Block *block, Block *predBlock, Value value)
static Block * getDefiningBlock(Value value)
Strategy strategy
Strategy class to control the behavior of SSA maximization.
Definition Passes.h:73
virtual bool maximizeResult(OpResult res)
Determines whether an operation's result should be SSA maximized.
virtual bool maximizeArgument(BlockArgument arg)
Determines whether a block argument should be SSA maximized.
virtual bool maximizeBlock(Block *block)
Determines whether a block should have the values it defines (i.e., block arguments and operation results) SSA maximized.
virtual bool maximizeOp(Operation *op)
Determines whether an operation should have its results SSA maximized.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMaximizeSSAPass()
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
bool isRegionSSAMaximized(Region &region)