CIRCT 22.0.0git
Loading...
Searching...
No Matches
IndexSwitchToIf.cpp
Go to the documentation of this file.
1//===- IndexSwitchToIf.cpp - Index switch to if-else pass ---*-C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the SCF IndexSwitch to If-Else pass.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/IR/OperationSupport.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21#include "llvm/ADT/STLExtras.h"
22
namespace circt {
// Pull in the TableGen-generated pass base class used below
// (impl::IndexSwitchToIfBase); the define selects only this pass's definition.
#define GEN_PASS_DEF_INDEXSWITCHTOIF
#include "circt/Transforms/Passes.h.inc"
} // namespace circt
27
28using namespace mlir;
29using namespace circt;
30
31struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
32 using OpConversionPattern::OpConversionPattern;
33
34 LogicalResult
35 matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor,
36 ConversionPatternRewriter &rewriter) const override {
37 auto loc = switchOp.getLoc();
38
39 Region &defaultRegion = switchOp.getDefaultRegion();
40 bool hasResults = !switchOp.getResultTypes().empty();
41
42 SmallVector<Value> finalResults;
43 scf::IfOp prevIfOp = nullptr;
44
45 rewriter.setInsertionPointAfter(switchOp);
46 auto switchCases = switchOp.getCases();
47 Value switchOperand = adaptor.getArg();
48 if (!switchOperand)
49 return rewriter.notifyMatchFailure(switchOp,
50 "missing converted switch operand");
51 for (size_t i = 0; i < switchCases.size(); i++) {
52 auto caseValueInt = switchCases[i];
53 if (prevIfOp)
54 rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());
55
56 Value caseValue =
57 arith::ConstantIndexOp::create(rewriter, loc, caseValueInt);
58 Value cond = arith::CmpIOp::create(
59 rewriter, loc, arith::CmpIPredicate::eq, switchOperand, caseValue);
60
61 auto ifOp = scf::IfOp::create(rewriter, loc, switchOp.getResultTypes(),
62 cond, /*hasElseRegion=*/true);
63
64 Region &caseRegion = switchOp.getCaseRegions()[i];
65 rewriter.eraseBlock(&ifOp.getThenRegion().front());
66 rewriter.inlineRegionBefore(caseRegion, ifOp.getThenRegion(),
67 ifOp.getThenRegion().end());
68
69 if (i + 1 == switchCases.size()) {
70 rewriter.eraseBlock(&ifOp.getElseRegion().front());
71 rewriter.inlineRegionBefore(defaultRegion, ifOp.getElseRegion(),
72 ifOp.getElseRegion().end());
73 }
74
75 if (prevIfOp && hasResults) {
76 rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
77 scf::YieldOp::create(rewriter, loc, ifOp.getResults());
78 }
79
80 if (i == 0 && hasResults)
81 llvm::append_range(finalResults, ifOp.getResults());
82
83 prevIfOp = ifOp;
84 }
85
86 if (hasResults)
87 rewriter.replaceOp(switchOp, finalResults);
88 else
89 rewriter.eraseOp(switchOp);
90
91 return success();
92 }
93};
94
95namespace {
96
97struct IndexSwitchToIfPass
98 : public circt::impl::IndexSwitchToIfBase<IndexSwitchToIfPass> {
99public:
100 void runOnOperation() override {
101 auto *ctx = &getContext();
102 RewritePatternSet patterns(ctx);
103 ConversionTarget target(*ctx);
104
105 target.addLegalDialect<scf::SCFDialect>();
106 target.addLegalDialect<arith::ArithDialect>();
107 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
108 target.addIllegalOp<scf::IndexSwitchOp>();
109
111
112 if (applyPartialConversion(getOperation(), target, std::move(patterns))
113 .failed()) {
114 signalPassFailure();
115 return;
116 }
117 }
118};
119
120} // namespace
121
122namespace circt {
123std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass() {
124 return std::make_unique<IndexSwitchToIfPass>();
125}
126} // namespace circt
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createIndexSwitchToIfPass()
LogicalResult matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override