CIRCT  18.0.0git
MaximizeSSA.cpp
Go to the documentation of this file.
1 //===- MaximizeSSA.cpp - SSA Maximization Pass ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the SSA maximization pass as well as utilities
10 // for converting a function with standard control flow into maximal SSA form.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "PassDetail.h"
16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/raw_ostream.h"
23 
24 using namespace mlir;
25 using namespace circt;
26 
27 static Block *getDefiningBlock(Value value) {
28  // Value is either a block argument...
29  if (auto blockArg = dyn_cast<BlockArgument>(value); blockArg)
30  return blockArg.getParentBlock();
31 
32  // ... or an operation's result
33  auto *defOp = value.getDefiningOp();
34  assert(defOp);
35  return defOp->getBlock();
36 }
37 
38 static LogicalResult addArgToTerminator(Block *block, Block *predBlock,
39  Value value) {
40 
41  // Identify terminator branching instruction in predecessor block
42  auto branchOp = dyn_cast<BranchOpInterface>(predBlock->getTerminator());
43  if (!branchOp)
44  return predBlock->getTerminator()->emitError(
45  "Expected terminator operation of block to be a "
46  "branch-like operation");
47 
48  // In the predecessor block's terminator, find all successors that equal
49  // the block and add the value to the list of operands it's passed
50  for (auto [idx, succBlock] : llvm::enumerate(branchOp->getSuccessors()))
51  if (succBlock == block)
52  branchOp.getSuccessorOperands(idx).append(value);
53 
54  return success();
55 }
56 
57 bool circt::isRegionSSAMaximized(Region &region) {
58 
59  // Check whether all operands used within each block are also defined within
60  // the same block
61  for (auto &block : region.getBlocks())
62  for (auto &op : block.getOperations())
63  for (auto operand : op.getOperands())
64  if (getDefiningBlock(operand) != &block)
65  return false;
66 
67  return true;
68 }
69 
71  return true;
72 }
74  return true;
75 }
76 bool circt::SSAMaximizationStrategy::maximizeOp(Operation *op) { return true; }
78  return true;
79 }
80 
81 LogicalResult circt::maximizeSSA(Value value, PatternRewriter &rewriter) {
82 
83  // Identify the basic block in which the value is defined
84  Block *defBlock = getDefiningBlock(value);
85 
86  // Identify all basic blocks in which the value is used (excluding the
87  // value-defining block)
88  DenseSet<Block *> blocksUsing;
89  for (auto &use : value.getUses()) {
90  auto *block = use.getOwner()->getBlock();
91  if (block != defBlock)
92  blocksUsing.insert(block);
93  }
94 
95  // Prepare a stack to iterate over the list of basic blocks that must be
96  // modified for the value to be in maximum SSA form. At all points,
97  // blocksUsing is a non-strict superset of the elements contained in
98  // blocksToVisit
99  SmallVector<Block *> blocksToVisit(blocksUsing.begin(), blocksUsing.end());
100 
101  // Backtrack from all blocks using the value to the value-defining basic
102  // block, adding a new block argument for the value along the way. Keep
103  // track of which blocks have already been modified to avoid visiting a
104  // block more than once while backtracking (possible due to branching
105  // control flow)
106  DenseMap<Block *, BlockArgument> blockToArg;
107  while (!blocksToVisit.empty()) {
108  // Pop the basic block at the top of the stack
109  auto *block = blocksToVisit.pop_back_val();
110 
111  // Add an argument to the block to hold the value
112  blockToArg[block] =
113  block->addArgument(value.getType(), rewriter.getUnknownLoc());
114 
115  // In all unique block predecessors, modify the terminator branching
116  // instruction to also pass the value to the block
117  SmallPtrSet<Block *, 8> uniquePredecessors;
118  for (auto *predBlock : block->getPredecessors()) {
119  // If we have already visited the block predecessor, skip it. It's
120  // possible to get duplicate block predecessors when there exists a
121  // conditional branch with both branches going to the same block e.g.,
122  // cf.cond_br %cond, ^bbx, ^bbx
123  if (auto [_, newPredecessor] = uniquePredecessors.insert(predBlock);
124  !newPredecessor) {
125  continue;
126  }
127 
128  // Modify the terminator instruction
129  if (failed(addArgToTerminator(block, predBlock, value)))
130  return failure();
131 
132  // Now the predecessor block is using the value, so we must also make sure
133  // to visit it
134  if (predBlock != defBlock)
135  if (auto [_, blockNewlyUsing] = blocksUsing.insert(predBlock);
136  blockNewlyUsing)
137  blocksToVisit.push_back(predBlock);
138  }
139  }
140 
141  // Replace all uses of the value with the newly added block arguments
142  SmallVector<Operation *> users;
143  for (auto &use : value.getUses()) {
144  auto *owner = use.getOwner();
145  if (owner->getBlock() != defBlock)
146  users.push_back(owner);
147  }
148  for (auto *user : users)
149  user->replaceUsesOfWith(value, blockToArg[user->getBlock()]);
150 
151  return success();
152 }
153 
154 LogicalResult circt::maximizeSSA(Operation *op,
155  SSAMaximizationStrategy &strategy,
156  PatternRewriter &rewriter) {
157  // Apply SSA maximization on each of the operation's results
158  for (auto res : op->getResults())
159  if (strategy.maximizeResult(res))
160  if (failed(maximizeSSA(res, rewriter)))
161  return failure();
162 
163  return success();
164 }
165 
166 LogicalResult circt::maximizeSSA(Block *block,
167  SSAMaximizationStrategy &strategy,
168  PatternRewriter &rewriter) {
169  // Apply SSA maximization on each of the block's arguments
170  for (auto arg : block->getArguments())
171  if (strategy.maximizeArgument(arg))
172  if (failed(maximizeSSA(arg, rewriter)))
173  return failure();
174 
175  // Apply SSA maximization on each of the block's operations
176  for (auto &op : block->getOperations())
177  if (strategy.maximizeOp(&op))
178  if (failed(maximizeSSA(&op, strategy, rewriter)))
179  return failure();
180 
181  return success();
182 }
183 
184 LogicalResult circt::maximizeSSA(Region &region,
185  SSAMaximizationStrategy &strategy,
186  PatternRewriter &rewriter) {
187  // Apply SSA maximization on each of the region's block
188  for (auto &block : region.getBlocks())
189  if (strategy.maximizeBlock(&block))
190  if (failed(maximizeSSA(&block, strategy, rewriter)))
191  return failure();
192 
193  return success();
194 }
195 
196 namespace {
197 
198 struct MaxSSAConversion : public ConversionPattern {
199 public:
200  MaxSSAConversion(MLIRContext *context)
201  : ConversionPattern(MatchAnyOpTypeTag(), 1, context) {}
202  LogicalResult
203  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
204  ConversionPatternRewriter &rewriter) const override {
205  LogicalResult conversionStatus = success();
206  rewriter.updateRootInPlace(op, [&] {
207  for (auto &region : op->getRegions()) {
208  SSAMaximizationStrategy strategy;
209  if (failed(maximizeSSA(region, strategy, rewriter)))
210  conversionStatus = failure();
211  }
212  });
213  return conversionStatus;
214  }
215 };
216 
217 struct MaximizeSSAPass : public MaximizeSSABase<MaximizeSSAPass> {
218 public:
219  void runOnOperation() override {
220  auto *ctx = &getContext();
221 
222  RewritePatternSet patterns{ctx};
223  patterns.add<MaxSSAConversion>(ctx);
224  ConversionTarget target(*ctx);
225 
226  // SSA maximization should apply to all region-defining ops.
227  target.markUnknownOpDynamicallyLegal([](Operation *op) {
228  return llvm::all_of(op->getRegions(), isRegionSSAMaximized);
229  });
230 
231  // Each region is turned into maximal SSA form independently of the
232  // others. Function signatures are never modified by SSA maximization
233  if (failed(applyPartialConversion(getOperation(), target,
234  std::move(patterns))))
235  signalPassFailure();
236  }
237 };
238 
239 } // namespace
240 
241 namespace circt {
242 std::unique_ptr<mlir::Pass> createMaximizeSSAPass() {
243  return std::make_unique<MaximizeSSAPass>();
244 }
245 } // namespace circt
assert(baseType &&"element must be base type")
static LogicalResult addArgToTerminator(Block *block, Block *predBlock, Value value)
Definition: MaximizeSSA.cpp:38
static Block * getDefiningBlock(Value value)
Definition: MaximizeSSA.cpp:27
Strategy class to control the behavior of SSA maximization.
Definition: Passes.h:53
virtual bool maximizeResult(OpResult res)
Determines whether an operation's result should be SSA maximized.
Definition: MaximizeSSA.cpp:77
virtual bool maximizeArgument(BlockArgument arg)
Determines whether a block argument should be SSA maximized.
Definition: MaximizeSSA.cpp:73
virtual bool maximizeBlock(Block *block)
Determines whether a block should have the values it defines (i.e., block arguments and operation res...
Definition: MaximizeSSA.cpp:70
virtual bool maximizeOp(Operation *op)
Determines whether an operation should have its results SSA maximized.
Definition: MaximizeSSA.cpp:76
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createMaximizeSSAPass()
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
Definition: MaximizeSSA.cpp:81
bool isRegionSSAMaximized(Region &region)
Definition: MaximizeSSA.cpp:57