CIRCT 22.0.0git
Loading...
Searching...
No Matches
IndexSwitchToIf.cpp
Go to the documentation of this file.
1//===- IndexSwitchToIf.cpp - Index switch to if-else pass ---*-C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the SCF IndexSwitch to If-Else pass.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/IR/OperationSupport.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21
22namespace circt {
23#define GEN_PASS_DEF_INDEXSWITCHTOIF
24#include "circt/Transforms/Passes.h.inc"
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29
30struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
31 using OpConversionPattern::OpConversionPattern;
32
33 LogicalResult
34 matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 auto loc = switchOp.getLoc();
37
38 Region &defaultRegion = switchOp.getDefaultRegion();
39 bool hasResults = !switchOp.getResultTypes().empty();
40
41 Value finalResult;
42 scf::IfOp prevIfOp = nullptr;
43
44 rewriter.setInsertionPointAfter(switchOp);
45 auto switchCases = switchOp.getCases();
46 for (size_t i = 0; i < switchCases.size(); i++) {
47 auto caseValueInt = switchCases[i];
48 if (prevIfOp)
49 rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());
50
51 Value caseValue =
52 arith::ConstantIndexOp::create(rewriter, loc, caseValueInt);
53 Value cond =
54 arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
55 switchOp.getOperand(), caseValue);
56
57 auto ifOp = scf::IfOp::create(rewriter, loc, switchOp.getResultTypes(),
58 cond, /*hasElseRegion=*/true);
59
60 Region &caseRegion = switchOp.getCaseRegions()[i];
61 rewriter.eraseBlock(&ifOp.getThenRegion().front());
62 rewriter.inlineRegionBefore(caseRegion, ifOp.getThenRegion(),
63 ifOp.getThenRegion().end());
64
65 if (i + 1 == switchCases.size()) {
66 rewriter.eraseBlock(&ifOp.getElseRegion().front());
67 rewriter.inlineRegionBefore(defaultRegion, ifOp.getElseRegion(),
68 ifOp.getElseRegion().end());
69 }
70
71 if (prevIfOp && hasResults) {
72 rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
73 scf::YieldOp::create(rewriter, loc, ifOp.getResult(0));
74 }
75
76 if (i == 0 && hasResults)
77 finalResult = ifOp.getResult(0);
78
79 prevIfOp = ifOp;
80 }
81
82 if (hasResults)
83 rewriter.replaceOp(switchOp, finalResult);
84 else
85 rewriter.eraseOp(switchOp);
86
87 return success();
88 }
89};
90
91namespace {
92
93struct IndexSwitchToIfPass
94 : public circt::impl::IndexSwitchToIfBase<IndexSwitchToIfPass> {
95public:
96 void runOnOperation() override {
97 auto *ctx = &getContext();
98 RewritePatternSet patterns(ctx);
99 ConversionTarget target(*ctx);
100
101 target.addLegalDialect<scf::SCFDialect>();
102 target.addLegalDialect<arith::ArithDialect>();
103 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
104 target.addIllegalOp<scf::IndexSwitchOp>();
105
107
108 if (applyPartialConversion(getOperation(), target, std::move(patterns))
109 .failed()) {
110 signalPassFailure();
111 return;
112 }
113 }
114};
115
116} // namespace
117
118namespace circt {
119std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass() {
120 return std::make_unique<IndexSwitchToIfPass>();
121}
122} // namespace circt
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createIndexSwitchToIfPass()
LogicalResult matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override