CIRCT 20.0.0git
Loading...
Searching...
No Matches
MapArithToComb.cpp
Go to the documentation of this file.
1//===- MapArithToComb.cpp - Arith-to-comb mapping pass ----------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Contains the definitions of the MapArithToComb pass.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace circt {
22#define GEN_PASS_DEF_MAPARITHTOCOMBPASS
23#include "circt/Transforms/Passes.h.inc"
24} // namespace circt
25
26using namespace mlir;
27using namespace circt;
28
29namespace {
30
31// A type converter which legalizes integer types, thus ensuring that vector
32// types are illegal.
33class MapArithTypeConverter : public mlir::TypeConverter {
34public:
35 MapArithTypeConverter() {
36 addConversion([](Type type) {
37 if (isa<mlir::IntegerType>(type))
38 return type;
39
40 return Type();
41 });
42 }
43};
44
45template <typename TFrom, typename TTo, bool cloneAttrs = false>
46class OneToOnePattern : public OpConversionPattern<TFrom> {
47public:
49 using OpAdaptor = typename TFrom::Adaptor;
50
51 LogicalResult
52 matchAndRewrite(TFrom op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 rewriter.replaceOpWithNewOp<TTo>(
55 op, adaptor.getOperands(),
56 cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
57 return success();
58 }
59};
60
61class ExtSConversionPattern : public OpConversionPattern<arith::ExtSIOp> {
62public:
63 using OpConversionPattern<arith::ExtSIOp>::OpConversionPattern;
64 using OpAdaptor = typename arith::ExtSIOp::Adaptor;
65
66 LogicalResult
67 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter) const override {
69 size_t outWidth = op.getType().getIntOrFloatBitWidth();
70 rewriter.replaceOp(op, comb::createOrFoldSExt(
71 op.getLoc(), op.getOperand(),
72 rewriter.getIntegerType(outWidth), rewriter));
73 return success();
74 }
75};
76
77class ExtZConversionPattern : public OpConversionPattern<arith::ExtUIOp> {
78public:
79 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
80 using OpAdaptor = typename arith::ExtUIOp::Adaptor;
81
82 LogicalResult
83 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter) const override {
85 auto loc = op.getLoc();
86 size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
87 size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
88
89 rewriter.replaceOp(op, rewriter.create<comb::ConcatOp>(
90 loc,
91 rewriter.create<hw::ConstantOp>(
92 loc, APInt(outWidth - inWidth, 0)),
93 adaptor.getIn()));
94 return success();
95 }
96};
97
98class TruncateConversionPattern : public OpConversionPattern<arith::TruncIOp> {
99public:
100 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
101 using OpAdaptor = typename arith::TruncIOp::Adaptor;
102
103 LogicalResult
104 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter) const override {
106 size_t outWidth = op.getType().getIntOrFloatBitWidth();
107 rewriter.replaceOpWithNewOp<comb::ExtractOp>(op, adaptor.getIn(), 0,
108 outWidth);
109 return success();
110 }
111};
112
113class CompConversionPattern : public OpConversionPattern<arith::CmpIOp> {
114public:
115 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
116 using OpAdaptor = typename arith::CmpIOp::Adaptor;
117
118 static comb::ICmpPredicate
119 arithToCombPredicate(arith::CmpIPredicate predicate) {
120 switch (predicate) {
121 case arith::CmpIPredicate::eq:
122 return comb::ICmpPredicate::eq;
123 case arith::CmpIPredicate::ne:
124 return comb::ICmpPredicate::ne;
125 case arith::CmpIPredicate::slt:
126 return comb::ICmpPredicate::slt;
127 case arith::CmpIPredicate::ult:
128 return comb::ICmpPredicate::ult;
129 case arith::CmpIPredicate::sle:
130 return comb::ICmpPredicate::sle;
131 case arith::CmpIPredicate::ule:
132 return comb::ICmpPredicate::ule;
133 case arith::CmpIPredicate::sgt:
134 return comb::ICmpPredicate::sgt;
135 case arith::CmpIPredicate::ugt:
136 return comb::ICmpPredicate::ugt;
137 case arith::CmpIPredicate::sge:
138 return comb::ICmpPredicate::sge;
139 case arith::CmpIPredicate::uge:
140 return comb::ICmpPredicate::uge;
141 }
142 llvm_unreachable("Unknown predicate");
143 }
144
145 LogicalResult
146 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
149 op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
150 adaptor.getRhs());
151 return success();
152 }
153};
154
155struct MapArithToCombPass
156 : public circt::impl::MapArithToCombPassBase<MapArithToCombPass> {
157public:
158 void runOnOperation() override {
159 auto *ctx = &getContext();
160
161 ConversionTarget target(*ctx);
162 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
163 target.addIllegalDialect<arith::ArithDialect>();
164 MapArithTypeConverter typeConverter;
165 RewritePatternSet patterns(ctx);
166
167 patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
168 OneToOnePattern<arith::SubIOp, comb::SubOp>,
169 OneToOnePattern<arith::MulIOp, comb::MulOp>,
170 OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
171 OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
172 OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
173 OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
174 OneToOnePattern<arith::AndIOp, comb::AndOp>,
175 OneToOnePattern<arith::OrIOp, comb::OrOp>,
176 OneToOnePattern<arith::XOrIOp, comb::XorOp>,
177 OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
178 OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
179 OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
180 OneToOnePattern<arith::ConstantOp, hw::ConstantOp, true>,
181 OneToOnePattern<arith::SelectOp, comb::MuxOp>,
182 ExtSConversionPattern, ExtZConversionPattern,
183 TruncateConversionPattern, CompConversionPattern>(
184 typeConverter, ctx);
185
186 if (failed(applyPartialConversion(getOperation(), target,
187 std::move(patterns))))
188 signalPassFailure();
189 }
190};
191
192} // namespace
193
194std::unique_ptr<mlir::Pass> circt::createMapArithToCombPass() {
195 return std::make_unique<MapArithToCombPass>();
196}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMapArithToCombPass()