CIRCT 22.0.0git
Loading...
Searching...
No Matches
MapArithToComb.cpp
Go to the documentation of this file.
//===- MapArithToComb.cpp - Arith-to-comb mapping pass ----------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Contains the definitions of the MapArithToComb pass.
//
//===----------------------------------------------------------------------===//
12
17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
20
21namespace circt {
22#define GEN_PASS_DEF_MAPARITHTOCOMBPASS
23#include "circt/Transforms/Passes.h.inc"
24} // namespace circt
25
26using namespace mlir;
27using namespace circt;
28
29namespace {
30
31// A type converter which legalizes integer types, thus ensuring that vector
32// types are illegal.
33class MapArithTypeConverter : public mlir::TypeConverter {
34public:
35 MapArithTypeConverter() {
36 addConversion([](Type type) {
37 if (hw::isHWValueType(type))
38 return type;
39
40 return Type();
41 });
42 }
43};
44
45template <typename TFrom, typename TTo, bool cloneAttrs = false>
46class OneToOnePattern : public OpConversionPattern<TFrom> {
47public:
49 using OpAdaptor = typename TFrom::Adaptor;
50
51 LogicalResult
52 matchAndRewrite(TFrom op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter) const override {
54 rewriter.replaceOpWithNewOp<TTo>(
55 op, adaptor.getOperands(),
56 cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
57 return success();
58 }
59};
60
61class ExtSConversionPattern : public OpConversionPattern<arith::ExtSIOp> {
62public:
63 using OpConversionPattern<arith::ExtSIOp>::OpConversionPattern;
64 using OpAdaptor = typename arith::ExtSIOp::Adaptor;
65
66 LogicalResult
67 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter) const override {
69 size_t outWidth = op.getType().getIntOrFloatBitWidth();
70 rewriter.replaceOp(op, comb::createOrFoldSExt(
71 op.getLoc(), op.getOperand(),
72 rewriter.getIntegerType(outWidth), rewriter));
73 return success();
74 }
75};
76
77class ExtZConversionPattern : public OpConversionPattern<arith::ExtUIOp> {
78public:
79 using OpConversionPattern<arith::ExtUIOp>::OpConversionPattern;
80 using OpAdaptor = typename arith::ExtUIOp::Adaptor;
81
82 LogicalResult
83 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter) const override {
85 auto loc = op.getLoc();
86 size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
87 size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
88
89 rewriter.replaceOp(op, comb::ConcatOp::create(
90 rewriter, loc,
92 rewriter, loc, APInt(outWidth - inWidth, 0)),
93 adaptor.getIn()));
94 return success();
95 }
96};
97
98class TruncateConversionPattern : public OpConversionPattern<arith::TruncIOp> {
99public:
100 using OpConversionPattern<arith::TruncIOp>::OpConversionPattern;
101 using OpAdaptor = typename arith::TruncIOp::Adaptor;
102
103 LogicalResult
104 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter) const override {
106 size_t outWidth = op.getType().getIntOrFloatBitWidth();
107 rewriter.replaceOpWithNewOp<comb::ExtractOp>(op, adaptor.getIn(), 0,
108 outWidth);
109 return success();
110 }
111};
112
113class CompConversionPattern : public OpConversionPattern<arith::CmpIOp> {
114public:
115 using OpConversionPattern<arith::CmpIOp>::OpConversionPattern;
116 using OpAdaptor = typename arith::CmpIOp::Adaptor;
117
118 static comb::ICmpPredicate
119 arithToCombPredicate(arith::CmpIPredicate predicate) {
120 switch (predicate) {
121 case arith::CmpIPredicate::eq:
122 return comb::ICmpPredicate::eq;
123 case arith::CmpIPredicate::ne:
124 return comb::ICmpPredicate::ne;
125 case arith::CmpIPredicate::slt:
126 return comb::ICmpPredicate::slt;
127 case arith::CmpIPredicate::ult:
128 return comb::ICmpPredicate::ult;
129 case arith::CmpIPredicate::sle:
130 return comb::ICmpPredicate::sle;
131 case arith::CmpIPredicate::ule:
132 return comb::ICmpPredicate::ule;
133 case arith::CmpIPredicate::sgt:
134 return comb::ICmpPredicate::sgt;
135 case arith::CmpIPredicate::ugt:
136 return comb::ICmpPredicate::ugt;
137 case arith::CmpIPredicate::sge:
138 return comb::ICmpPredicate::sge;
139 case arith::CmpIPredicate::uge:
140 return comb::ICmpPredicate::uge;
141 }
142 llvm_unreachable("Unknown predicate");
143 }
144
145 LogicalResult
146 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter) const override {
148 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
149 op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
150 adaptor.getRhs());
151 return success();
152 }
153};
154
155struct ConstantConversionPattern
156 : public OpConversionPattern<arith::ConstantOp> {
157 using OpConversionPattern::OpConversionPattern;
158
159 LogicalResult
160 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
161 ConversionPatternRewriter &rewriter) const override {
162 // `hw.constant` only supports integers.
163 if (!isa<IntegerType>(op.getType()))
164 return failure();
165
166 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
167 op, cast<IntegerAttr>(adaptor.getValue()));
168 return success();
169 }
170};
171
172struct MapArithToCombPass
173 : public circt::impl::MapArithToCombPassBase<MapArithToCombPass> {
174public:
175 MapArithToCombPass(bool enableBestEffortLowering) {
176 this->enableBestEffortLowering = enableBestEffortLowering;
177 }
178
179 void runOnOperation() override {
180 auto *ctx = &getContext();
181
182 ConversionTarget target(*ctx);
183 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
184 if (!enableBestEffortLowering) {
185 target.addIllegalDialect<arith::ArithDialect>();
186 } else {
187 // We make all arith operations with a potential lowering here (as
188 // specified in circt::populateArithToCombPatterns) illegal
189 target.addIllegalOp<arith::AddIOp>();
190 target.addIllegalOp<arith::SubIOp>();
191 target.addIllegalOp<arith::MulIOp>();
192 target.addIllegalOp<arith::DivSIOp>();
193 target.addIllegalOp<arith::DivUIOp>();
194 target.addIllegalOp<arith::RemSIOp>();
195 target.addIllegalOp<arith::RemUIOp>();
196 target.addIllegalOp<arith::AndIOp>();
197 target.addIllegalOp<arith::OrIOp>();
198 target.addIllegalOp<arith::XOrIOp>();
199 target.addIllegalOp<arith::ShLIOp>();
200 target.addIllegalOp<arith::ShRSIOp>();
201 target.addIllegalOp<arith::ShRUIOp>();
202 target.addIllegalOp<arith::SelectOp>();
203 target.addIllegalOp<arith::ExtSIOp>();
204 target.addIllegalOp<arith::ExtUIOp>();
205 target.addIllegalOp<arith::TruncIOp>();
206 target.addIllegalOp<arith::CmpIOp>();
207
208 // Force integer constants to be mapped to `hw.constant`.
209 target.addDynamicallyLegalOp<arith::ConstantOp>([](Operation *op) {
210 return !isa<IntegerType>(op->getResult(0).getType());
211 });
212 }
213 MapArithTypeConverter typeConverter;
214 RewritePatternSet patterns(ctx);
216
217 if (failed(applyPartialConversion(getOperation(), target,
218 std::move(patterns))))
219 signalPassFailure();
220 }
221};
222
223} // namespace
224
225void circt::populateArithToCombPatterns(mlir::RewritePatternSet &patterns,
226 TypeConverter &typeConverter) {
227 patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
228 OneToOnePattern<arith::SubIOp, comb::SubOp>,
229 OneToOnePattern<arith::MulIOp, comb::MulOp>,
230 OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
231 OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
232 OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
233 OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
234 OneToOnePattern<arith::AndIOp, comb::AndOp>,
235 OneToOnePattern<arith::OrIOp, comb::OrOp>,
236 OneToOnePattern<arith::XOrIOp, comb::XorOp>,
237 OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
238 OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
239 OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
240 OneToOnePattern<arith::SelectOp, comb::MuxOp>,
241 ExtSConversionPattern, ExtZConversionPattern,
242 TruncateConversionPattern, CompConversionPattern,
243 ConstantConversionPattern>(typeConverter,
244 patterns.getContext());
245}
246
247std::unique_ptr<mlir::Pass>
248circt::createMapArithToCombPass(bool enableBestEffortLowering) {
249 return std::make_unique<MapArithToCombPass>(enableBestEffortLowering);
250}
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMapArithToCombPass(bool enableBestEffortLowering=false)
void populateArithToCombPatterns(mlir::RewritePatternSet &patterns, TypeConverter &typeConverter)