16 #include "mlir/Dialect/Func/IR/FuncOps.h"
17 #include "mlir/IR/MLIRContext.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Support/LogicalResult.h"
20 #include "mlir/Transforms/DialectConversion.h"
21 #include "llvm/ADT/STLExtras.h"
22 #include "llvm/Support/raw_ostream.h"
25 using namespace circt;
29 if (
auto blockArg = dyn_cast<BlockArgument>(value); blockArg)
30 return blockArg.getParentBlock();
33 auto *defOp = value.getDefiningOp();
35 return defOp->getBlock();
42 auto branchOp = dyn_cast<BranchOpInterface>(predBlock->getTerminator());
44 return predBlock->getTerminator()->emitError(
45 "Expected terminator operation of block to be a "
46 "branch-like operation");
50 for (
auto [idx, succBlock] : llvm::enumerate(branchOp->getSuccessors()))
51 if (succBlock == block)
52 branchOp.getSuccessorOperands(idx).append(value);
61 for (
auto &block : region.getBlocks())
62 for (
auto &op : block.getOperations())
63 for (
auto operand : op.getOperands())
88 DenseSet<Block *> blocksUsing;
89 for (
auto &use : value.getUses()) {
90 auto *block = use.getOwner()->getBlock();
91 if (block != defBlock)
92 blocksUsing.insert(block);
99 SmallVector<Block *> blocksToVisit(blocksUsing.begin(), blocksUsing.end());
106 DenseMap<Block *, BlockArgument> blockToArg;
107 while (!blocksToVisit.empty()) {
109 auto *block = blocksToVisit.pop_back_val();
113 block->addArgument(value.getType(), rewriter.getUnknownLoc());
117 SmallPtrSet<Block *, 8> uniquePredecessors;
118 for (
auto *predBlock : block->getPredecessors()) {
123 if (
auto [_, newPredecessor] = uniquePredecessors.insert(predBlock);
134 if (predBlock != defBlock)
135 if (
auto [_, blockNewlyUsing] = blocksUsing.insert(predBlock);
137 blocksToVisit.push_back(predBlock);
142 SmallVector<Operation *> users;
143 for (
auto &use : value.getUses()) {
144 auto *owner = use.getOwner();
145 if (owner->getBlock() != defBlock)
146 users.push_back(owner);
148 for (
auto *user : users)
149 user->replaceUsesOfWith(value, blockToArg[user->getBlock()]);
156 PatternRewriter &rewriter) {
158 for (
auto res : op->getResults())
168 PatternRewriter &rewriter) {
170 for (
auto arg : block->getArguments())
176 for (
auto &op : block->getOperations())
186 PatternRewriter &rewriter) {
188 for (
auto &block : region.getBlocks())
190 if (failed(
maximizeSSA(&block, strategy, rewriter)))
198 struct MaxSSAConversion :
public ConversionPattern {
200 MaxSSAConversion(MLIRContext *context)
201 : ConversionPattern(MatchAnyOpTypeTag(), 1, context) {}
203 matchAndRewrite(Operation *op, ArrayRef<Value> operands,
204 ConversionPatternRewriter &rewriter)
const override {
205 LogicalResult conversionStatus = success();
206 rewriter.updateRootInPlace(op, [&] {
207 for (
auto ®ion : op->getRegions()) {
208 SSAMaximizationStrategy strategy;
209 if (failed(maximizeSSA(region, strategy, rewriter)))
210 conversionStatus = failure();
213 return conversionStatus;
217 struct MaximizeSSAPass :
public MaximizeSSABase<MaximizeSSAPass> {
219 void runOnOperation()
override {
220 auto *ctx = &getContext();
223 patterns.add<MaxSSAConversion>(ctx);
224 ConversionTarget target(*ctx);
227 target.markUnknownOpDynamicallyLegal([](Operation *op) {
233 if (failed(applyPartialConversion(getOperation(), target,
243 return std::make_unique<MaximizeSSAPass>();
assert(baseType &&"element must be base type")
static LogicalResult addArgToTerminator(Block *block, Block *predBlock, Value value)
static Block * getDefiningBlock(Value value)
Strategy class to control the behavior of SSA maximization.
virtual bool maximizeResult(OpResult res)
Determines whether an operation's result should be SSA maximized.
virtual bool maximizeArgument(BlockArgument arg)
Determines whether a block argument should be SSA maximized.
virtual bool maximizeBlock(Block *block)
Determines whether a block should have the values it defines (i.e., block arguments and operation res...
virtual bool maximizeOp(Operation *op)
Determines whether an operation should have its results SSA maximized.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createMaximizeSSAPass()
LogicalResult maximizeSSA(Value value, PatternRewriter &rewriter)
Converts a single value within a function into maximal SSA form.
bool isRegionSSAMaximized(Region ®ion)