CIRCT 22.0.0git
Loading...
Searching...
No Matches
ConvertIndexToUInt.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
11#include "mlir/Dialect/Arith/IR/Arith.h"
12#include "mlir/IR/BuiltinAttributes.h"
13#include "mlir/IR/BuiltinOps.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/PatternMatch.h"
16#include "mlir/Support/LogicalResult.h"
17#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18
19namespace circt {
20#define GEN_PASS_DEF_CONVERTINDEXTOUINT
21#include "circt/Transforms/Passes.h.inc"
22} // namespace circt
23
24using namespace mlir;
25using namespace circt;
26
27namespace {
28
29/// Rewrite `arith.cmpi` operations that still reason about `index` values into
30/// pure integer comparisons so that subsequent hardware mappings only observe
31/// integer arithmetic.
32class IndexCmpToIntegerPattern : public OpRewritePattern<arith::CmpIOp> {
33public:
34 using OpRewritePattern::OpRewritePattern;
35
36 LogicalResult matchAndRewrite(arith::CmpIOp op,
37 PatternRewriter &rewriter) const override {
38 if (!op.getLhs().getType().isIndex() || !op.getRhs().getType().isIndex())
39 return failure();
40
41 FailureOr<IntegerType> targetType = getTargetIntegerType(op);
42 if (failed(targetType))
43 return failure();
44
45 // Peel index operands back to the original integer type: either drop an
46 // index_cast (only if it came from the exact target integer type) or
47 // rebuild an index constant as an integer constant. Anything else keeps
48 // the pattern from firing so we never rewrite mixed or ambiguous operands.
49 auto convertOperand = [&](Value operand) -> FailureOr<Value> {
50 if (auto castOp = operand.getDefiningOp<arith::IndexCastOp>()) {
51 Value source = castOp.getIn();
52 auto srcType = dyn_cast<IntegerType>(source.getType());
53 if (!srcType || srcType != *targetType)
54 return failure();
55 return source;
56 }
57
58 if (auto constOp = operand.getDefiningOp<arith::ConstantOp>()) {
59 if (!constOp.getType().isIndex())
60 return failure();
61
62 auto value = dyn_cast<IntegerAttr>(constOp.getValue());
63 if (!value)
64 return failure();
65
66 auto attr = rewriter.getIntegerAttr(*targetType, value.getInt());
67 auto newConst =
68 arith::ConstantOp::create(rewriter, constOp.getLoc(), attr);
69 return newConst.getResult();
70 }
71
72 return failure();
73 };
74
75 FailureOr<Value> lhs = convertOperand(op.getLhs());
76 FailureOr<Value> rhs = convertOperand(op.getRhs());
77 if (failed(lhs) || failed(rhs))
78 return failure();
79
80 rewriter.replaceOpWithNewOp<arith::CmpIOp>(op, op.getPredicate(), *lhs,
81 *rhs);
82 return success();
83 }
84
85private:
86 static FailureOr<IntegerType> getTargetIntegerType(arith::CmpIOp op) {
87 auto pickType = [](Value operand) -> FailureOr<IntegerType> {
88 if (auto castOp = operand.getDefiningOp<arith::IndexCastOp>()) {
89 if (auto srcType = dyn_cast<IntegerType>(castOp.getIn().getType()))
90 return srcType;
91 }
92 return failure();
93 };
94
95 auto lhsType = pickType(op.getLhs());
96 if (succeeded(lhsType))
97 return *lhsType;
98
99 auto rhsType = pickType(op.getRhs());
100 if (succeeded(rhsType))
101 return *rhsType;
102
103 return failure();
104 }
105};
106
107/// Drop `arith.index_cast` that became unused once comparisons were rewritten.
108class DropUnusedIndexCastPattern : public OpRewritePattern<arith::IndexCastOp> {
109public:
110 using OpRewritePattern::OpRewritePattern;
111
112 LogicalResult matchAndRewrite(arith::IndexCastOp op,
113 PatternRewriter &rewriter) const override {
114 if (!op->use_empty())
115 return failure();
116 rewriter.eraseOp(op);
117 return success();
118 }
119};
120
121/// Remove `arith.constant` index definitions that no longer feed any user.
122class DropUnusedIndexConstantPattern
123 : public OpRewritePattern<arith::ConstantOp> {
124public:
125 using OpRewritePattern::OpRewritePattern;
126
127 LogicalResult matchAndRewrite(arith::ConstantOp op,
128 PatternRewriter &rewriter) const override {
129 if (!op.getType().isIndex() || !op->use_empty())
130 return failure();
131 rewriter.eraseOp(op);
132 return success();
133 }
134};
135
136struct ConvertIndexToUIntPass
137 : public circt::impl::ConvertIndexToUIntBase<ConvertIndexToUIntPass> {
138 void runOnOperation() override {
139 MLIRContext *ctx = &getContext();
140 RewritePatternSet patterns(ctx);
141 patterns.add<IndexCmpToIntegerPattern, DropUnusedIndexCastPattern,
142 DropUnusedIndexConstantPattern>(ctx);
143
144 if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
145 signalPassFailure();
146 }
147};
148
149} // namespace
150
151std::unique_ptr<mlir::Pass> circt::createConvertIndexToUIntPass() {
152 return std::make_unique<ConvertIndexToUIntPass>();
153}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertIndexToUIntPass()