CIRCT 20.0.0git
Loading...
Searching...
No Matches
IndexSwitchToIf.cpp
Go to the documentation of this file.
1//===- IndexSwitchToIf.cpp - Index switch to if-else pass ---*-C++-*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the SCF IndexSwitch to If-Else pass.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
15#include "mlir/Conversion/LLVMCommon/Pattern.h"
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/Dialect/SCF/IR/SCF.h"
18#include "mlir/IR/OperationSupport.h"
19#include "mlir/Pass/Pass.h"
20#include "mlir/Transforms/DialectConversion.h"
21
22namespace circt {
23#define GEN_PASS_DEF_INDEXSWITCHTOIF
24#include "circt/Transforms/Passes.h.inc"
25} // namespace circt
26
27using namespace mlir;
28using namespace circt;
29
30struct SwitchToIfConversion : public OpConversionPattern<scf::IndexSwitchOp> {
31 using OpConversionPattern::OpConversionPattern;
32
33 LogicalResult
34 matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor,
35 ConversionPatternRewriter &rewriter) const override {
36 auto loc = switchOp.getLoc();
37
38 Region &defaultRegion = switchOp.getDefaultRegion();
39 bool hasResults = !switchOp.getResultTypes().empty();
40
41 Value finalResult;
42 scf::IfOp prevIfOp = nullptr;
43
44 rewriter.setInsertionPointAfter(switchOp);
45 auto switchCases = switchOp.getCases();
46 for (size_t i = 0; i < switchCases.size(); i++) {
47 auto caseValueInt = switchCases[i];
48 if (prevIfOp)
49 rewriter.setInsertionPointToStart(&prevIfOp.getElseRegion().front());
50
51 Value caseValue =
52 rewriter.create<arith::ConstantIndexOp>(loc, caseValueInt);
53 Value cond = rewriter.create<arith::CmpIOp>(
54 loc, arith::CmpIPredicate::eq, switchOp.getOperand(), caseValue);
55
56 auto ifOp = rewriter.create<scf::IfOp>(loc, switchOp.getResultTypes(),
57 cond, /*hasElseRegion=*/true);
58
59 Region &caseRegion = switchOp.getCaseRegions()[i];
60 rewriter.eraseBlock(&ifOp.getThenRegion().front());
61 rewriter.inlineRegionBefore(caseRegion, ifOp.getThenRegion(),
62 ifOp.getThenRegion().end());
63
64 if (i + 1 == switchCases.size()) {
65 rewriter.eraseBlock(&ifOp.getElseRegion().front());
66 rewriter.inlineRegionBefore(defaultRegion, ifOp.getElseRegion(),
67 ifOp.getElseRegion().end());
68 }
69
70 if (prevIfOp && hasResults) {
71 rewriter.setInsertionPointToEnd(&prevIfOp.getElseRegion().front());
72 rewriter.create<scf::YieldOp>(loc, ifOp.getResult(0));
73 }
74
75 if (i == 0 && hasResults)
76 finalResult = ifOp.getResult(0);
77
78 prevIfOp = ifOp;
79 }
80
81 if (hasResults)
82 rewriter.replaceOp(switchOp, finalResult);
83 else
84 rewriter.eraseOp(switchOp);
85
86 return success();
87 }
88};
89
90namespace {
91
92struct IndexSwitchToIfPass
93 : public circt::impl::IndexSwitchToIfBase<IndexSwitchToIfPass> {
94public:
95 void runOnOperation() override {
96 auto *ctx = &getContext();
97 RewritePatternSet patterns(ctx);
98 ConversionTarget target(*ctx);
99
100 target.addLegalDialect<scf::SCFDialect>();
101 target.addLegalDialect<arith::ArithDialect>();
102 target.addLegalOp<ModuleOp, func::FuncOp, func::ReturnOp>();
103 target.addIllegalOp<scf::IndexSwitchOp>();
104
106
107 if (applyPartialConversion(getOperation(), target, std::move(patterns))
108 .failed()) {
109 signalPassFailure();
110 return;
111 }
112 }
113};
114
115} // namespace
116
117namespace circt {
118std::unique_ptr<mlir::Pass> createIndexSwitchToIfPass() {
119 return std::make_unique<IndexSwitchToIfPass>();
120}
121} // namespace circt
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createIndexSwitchToIfPass()
LogicalResult matchAndRewrite(scf::IndexSwitchOp switchOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override