CIRCT  20.0.0git
MaximizeSSA.cpp
Go to the documentation of this file.
1 //===- MaximizeSSA.cpp - SSA Maximization Pass ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the SSA maximization pass as well as utilities
10 // for converting a function with standard control flow into maximal SSA form.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "circt/Transforms/Passes.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/raw_ostream.h"
22 
23 namespace circt {
24 #define GEN_PASS_DEF_MAXIMIZESSA
25 #include "circt/Transforms/Passes.h.inc"
26 } // namespace circt
27 
28 using namespace mlir;
29 using namespace circt;
30 
31 static Block *getDefiningBlock(Value value) {
32  // Value is either a block argument...
33  if (auto blockArg = dyn_cast<BlockArgument>(value); blockArg)
34  return blockArg.getParentBlock();
35 
36  // ... or an operation's result
37  auto *defOp = value.getDefiningOp();
38  assert(defOp);
39  return defOp->getBlock();
40 }
41 
42 static LogicalResult addArgToTerminator(Block *block, Block *predBlock,
43  Value value) {
44 
45  // Identify terminator branching instruction in predecessor block
46  auto branchOp = dyn_cast<BranchOpInterface>(predBlock->getTerminator());
47  if (!branchOp)
48  return predBlock->getTerminator()->emitError(
49  "Expected terminator operation of block to be a "
50  "branch-like operation");
51 
52  // In the predecessor block's terminator, find all successors that equal
53  // the block and add the value to the list of operands it's passed
54  for (auto [idx, succBlock] : llvm::enumerate(branchOp->getSuccessors()))
55  if (succBlock == block)
56  branchOp.getSuccessorOperands(idx).append(value);
57 
58  return success();
59 }
60 
61 bool circt::isRegionSSAMaximized(Region &region) {
62 
63  // Check whether all operands used within each block are also defined within
64  // the same block
65  for (auto &block : region.getBlocks())
66  for (auto &op : block.getOperations())
67  for (auto operand : op.getOperands())
68  if (getDefiningBlock(operand) != &block)
69  return false;
70 
71  return true;
72 }
73 
75  return true;
76 }
78  return true;
79 }
80 bool circt::SSAMaximizationStrategy::maximizeOp(Operation *op) { return true; }
82  return true;
83 }
84 
85 LogicalResult circt::maximizeSSA(Value value, PatternRewriter &rewriter) {
86 
87  // Identify the basic block in which the value is defined
88  Block *defBlock = getDefiningBlock(value);
89 
90  // Identify all basic blocks in which the value is used (excluding the
91  // value-defining block)
92  DenseSet<Block *> blocksUsing;
93  for (auto &use : value.getUses()) {
94  auto *block = use.getOwner()->getBlock();
95  if (block != defBlock)
96  blocksUsing.insert(block);
97  }
98 
99  // Prepare a stack to iterate over the list of basic blocks that must be
100  // modified for the value to be in maximum SSA form. At all points,
101  // blocksUsing is a non-strict superset of the elements contained in
102  // blocksToVisit
103  SmallVector<Block *> blocksToVisit(blocksUsing.begin(), blocksUsing.end());
104 
105  // Backtrack from all blocks using the value to the value-defining basic
106  // block, adding a new block argument for the value along the way. Keep
107  // track of which blocks have already been modified to avoid visiting a
108  // block more than once while backtracking (possible due to branching
109  // control flow)
110  DenseMap<Block *, BlockArgument> blockToArg;
111  while (!blocksToVisit.empty()) {
112  // Pop the basic block at the top of the stack
113  auto *block = blocksToVisit.pop_back_val();
114 
115  // Add an argument to the block to hold the value
116  blockToArg[block] =
117  block->addArgument(value.getType(), rewriter.getUnknownLoc());
118 
119  // In all unique block predecessors, modify the terminator branching
120  // instruction to also pass the value to the block
121  SmallPtrSet<Block *, 8> uniquePredecessors;
122  for (auto *predBlock : block->getPredecessors()) {
123  // If we have already visited the block predecessor, skip it. It's
124  // possible to get duplicate block predecessors when there exists a
125  // conditional branch with both branches going to the same block e.g.,
126  // cf.cond_br %cond, ^bbx, ^bbx
127  if (auto [_, newPredecessor] = uniquePredecessors.insert(predBlock);
128  !newPredecessor) {
129  continue;
130  }
131 
132  // Modify the terminator instruction
133  if (failed(addArgToTerminator(block, predBlock, value)))
134  return failure();
135 
136  // Now the predecessor block is using the value, so we must also make sure
137  // to visit it
138  if (predBlock != defBlock)
139  if (auto [_, blockNewlyUsing] = blocksUsing.insert(predBlock);
140  blockNewlyUsing)
141  blocksToVisit.push_back(predBlock);
142  }
143  }
144 
145  // Replace all uses of the value with the newly added block arguments
146  SmallVector<Operation *> users;
147  for (auto &use : value.getUses()) {
148  auto *owner = use.getOwner();
149  if (owner->getBlock() != defBlock)
150  users.push_back(owner);
151  }
152  for (auto *user : users)
153  user->replaceUsesOfWith(value, blockToArg[user->getBlock()]);
154 
155  return success();
156 }
157 
158 LogicalResult circt::maximizeSSA(Operation *op,
160  PatternRewriter &rewriter) {
161  // Apply SSA maximization on each of the operation's results
162  for (auto res : op->getResults())
163  if (strategy.maximizeResult(res))
164  if (failed(maximizeSSA(res, rewriter)))
165  return failure();
166 
167  return success();
168 }
169 
170 LogicalResult circt::maximizeSSA(Block *block,
172  PatternRewriter &rewriter) {
173  // Apply SSA maximization on each of the block's arguments
174  for (auto arg : block->getArguments())
175  if (strategy.maximizeArgument(arg))
176  if (failed(maximizeSSA(arg, rewriter)))
177  return failure();
178 
179  // Apply SSA maximization on each of the block's operations
180  for (auto &op : block->getOperations())
181  if (strategy.maximizeOp(&op))
182  if (failed(maximizeSSA(&op, strategy, rewriter)))
183  return failure();
184 
185  return success();
186 }
187 
188 LogicalResult circt::maximizeSSA(Region &region,
190  PatternRewriter &rewriter) {
191  // Apply SSA maximization on each of the region's block
192  for (auto &block : region.getBlocks())
193  if (strategy.maximizeBlock(&block))
194  if (failed(maximizeSSA(&block, strategy, rewriter)))
195  return failure();
196 
197  return success();
198 }
199 
200 namespace {
201 
202 struct MaxSSAConversion : public ConversionPattern {
203 public:
204  MaxSSAConversion(MLIRContext *context)
205  : ConversionPattern(MatchAnyOpTypeTag(), 1, context) {}
206  LogicalResult
207  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
208  ConversionPatternRewriter &rewriter) const override {
209  LogicalResult conversionStatus = success();
210  rewriter.modifyOpInPlace(op, [&] {
211  for (auto &region : op->getRegions()) {
212  SSAMaximizationStrategy strategy;
213  if (failed(maximizeSSA(region, strategy, rewriter)))
214  conversionStatus = failure();
215  }
216  });
217  return conversionStatus;
218  }
219 };
220 
221 struct MaximizeSSAPass : public circt::impl::MaximizeSSABase<MaximizeSSAPass> {
222 public:
223  void runOnOperation() override {
224  auto *ctx = &getContext();
225 
226  RewritePatternSet patterns{ctx};
227  patterns.add<MaxSSAConversion>(ctx);
228  ConversionTarget target(*ctx);
229 
230  // SSA maximization should apply to all region-defining ops.
231  target.markUnknownOpDynamicallyLegal([](Operation *op) {
232  return llvm::all_of(op->getRegions(), isRegionSSAMaximized);
233  });
234 
235  // Each region is turned into maximal SSA form independently of the
236  // others. Function signatures are never modified by SSA maximization
237  if (failed(applyPartialConversion(getOperation(), target,
238  std::move(patterns))))
239  signalPassFailure();
240  }
241 };
242 
243 } // namespace
244 
245 namespace circt {
246 std::unique_ptr<mlir::Pass> createMaximizeSSAPass() {
247  return std::make_unique<MaximizeSSAPass>();
248 }
249 } // namespace circt
assert(baseType &&"element must be base type")
static LogicalResult addArgToTerminator(Block *block, Block *predBlock, Value value)
Definition: MaximizeSSA.cpp:42
static Block * getDefiningBlock(Value value)
Definition: MaximizeSSA.cpp:31
Strategy strategy
Strategy class to control the behavior of SSA maximization.
Definition: Passes.h:69
virtual bool maximizeResult(OpResult res)
Determines whether an operation's result should be SSA maximized.
Definition: MaximizeSSA.cpp:81
virtual bool maximizeArgument(BlockArgument arg)
Determines whether a block argument should be SSA maximized.
Definition: MaximizeSSA.cpp:77
virtual bool maximizeBlock(Block *block)
Determines whether a block should have the values it defines (i.e., block arguments and operation res...
Definition: MaximizeSSA.cpp:74
virtual bool maximizeOp(Operation *op)
Determines whether an operation should have its results SSA maximized.
Definition: MaximizeSSA.cpp:80
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createMaximizeSSAPass()
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
Definition: MaximizeSSA.cpp:85
bool isRegionSSAMaximized(Region &region)
Definition: MaximizeSSA.cpp:61