17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
22#define GEN_PASS_DEF_MAPARITHTOCOMBPASS
23#include "circt/Transforms/Passes.h.inc"
33class MapArithTypeConverter :
public mlir::TypeConverter {
35 MapArithTypeConverter() {
36 addConversion([](Type type) {
37 if (hw::isHWValueType(type))
45template <
typename TFrom,
typename TTo,
bool cloneAttrs = false>
49 using OpAdaptor =
typename TFrom::Adaptor;
52 matchAndRewrite(TFrom op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override {
54 rewriter.replaceOpWithNewOp<TTo>(
55 op, adaptor.getOperands(),
56 cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
64 using OpAdaptor =
typename arith::ExtSIOp::Adaptor;
67 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const override {
69 size_t outWidth = op.getType().getIntOrFloatBitWidth();
71 op, comb::createOrFoldSExt(rewriter, op.getLoc(), op.getOperand(),
72 rewriter.getIntegerType(outWidth)));
80 using OpAdaptor =
typename arith::ExtUIOp::Adaptor;
83 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter)
const override {
85 auto loc = op.getLoc();
86 size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
87 size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
89 rewriter.replaceOp(op, comb::ConcatOp::create(
92 rewriter, loc, APInt(outWidth - inWidth, 0)),
101 using OpAdaptor =
typename arith::TruncIOp::Adaptor;
104 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
106 size_t outWidth = op.getType().getIntOrFloatBitWidth();
116 using OpAdaptor =
typename arith::CmpIOp::Adaptor;
118 static comb::ICmpPredicate
119 arithToCombPredicate(arith::CmpIPredicate predicate) {
121 case arith::CmpIPredicate::eq:
122 return comb::ICmpPredicate::eq;
123 case arith::CmpIPredicate::ne:
124 return comb::ICmpPredicate::ne;
125 case arith::CmpIPredicate::slt:
126 return comb::ICmpPredicate::slt;
127 case arith::CmpIPredicate::ult:
128 return comb::ICmpPredicate::ult;
129 case arith::CmpIPredicate::sle:
130 return comb::ICmpPredicate::sle;
131 case arith::CmpIPredicate::ule:
132 return comb::ICmpPredicate::ule;
133 case arith::CmpIPredicate::sgt:
134 return comb::ICmpPredicate::sgt;
135 case arith::CmpIPredicate::ugt:
136 return comb::ICmpPredicate::ugt;
137 case arith::CmpIPredicate::sge:
138 return comb::ICmpPredicate::sge;
139 case arith::CmpIPredicate::uge:
140 return comb::ICmpPredicate::uge;
142 llvm_unreachable(
"Unknown predicate");
146 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const override {
148 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
149 op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
155struct ConstantConversionPattern
157 using OpConversionPattern::OpConversionPattern;
160 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
161 ConversionPatternRewriter &rewriter)
const override {
163 if (!isa<IntegerType>(op.getType()))
167 op, cast<IntegerAttr>(adaptor.getValue()));
174template <
typename ArithOp, comb::ICmpPredicate pred>
178 using OpAdaptor =
typename ArithOp::Adaptor;
181 matchAndRewrite(ArithOp op, OpAdaptor adaptor,
182 ConversionPatternRewriter &rewriter)
const override {
185 auto cmp = comb::ICmpOp::create(rewriter, op.getLoc(), pred,
186 adaptor.getLhs(), adaptor.getRhs());
187 rewriter.replaceOpWithNewOp<
comb::MuxOp>(op, cmp, adaptor.getLhs(),
193struct MapArithToCombPass
194 :
public circt::impl::MapArithToCombPassBase<MapArithToCombPass> {
196 MapArithToCombPass(
bool enableBestEffortLowering) {
197 this->enableBestEffortLowering = enableBestEffortLowering;
200 void runOnOperation()
override {
201 auto *ctx = &getContext();
203 ConversionTarget target(*ctx);
204 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
205 if (!enableBestEffortLowering) {
206 target.addIllegalDialect<arith::ArithDialect>();
210 target.addIllegalOp<arith::AddIOp>();
211 target.addIllegalOp<arith::SubIOp>();
212 target.addIllegalOp<arith::MulIOp>();
213 target.addIllegalOp<arith::DivSIOp>();
214 target.addIllegalOp<arith::DivUIOp>();
215 target.addIllegalOp<arith::RemSIOp>();
216 target.addIllegalOp<arith::RemUIOp>();
217 target.addIllegalOp<arith::AndIOp>();
218 target.addIllegalOp<arith::OrIOp>();
219 target.addIllegalOp<arith::XOrIOp>();
220 target.addIllegalOp<arith::ShLIOp>();
221 target.addIllegalOp<arith::ShRSIOp>();
222 target.addIllegalOp<arith::ShRUIOp>();
223 target.addIllegalOp<arith::SelectOp>();
224 target.addIllegalOp<arith::ExtSIOp>();
225 target.addIllegalOp<arith::ExtUIOp>();
226 target.addIllegalOp<arith::TruncIOp>();
227 target.addIllegalOp<arith::CmpIOp>();
228 target.addIllegalOp<arith::MaxSIOp>();
229 target.addIllegalOp<arith::MaxUIOp>();
230 target.addIllegalOp<arith::MinSIOp>();
231 target.addIllegalOp<arith::MinUIOp>();
234 target.addDynamicallyLegalOp<arith::ConstantOp>([](Operation *op) {
235 return !isa<IntegerType>(op->getResult(0).getType());
238 MapArithTypeConverter typeConverter;
242 if (failed(applyPartialConversion(getOperation(), target,
251 TypeConverter &typeConverter) {
253 OneToOnePattern<arith::AddIOp, comb::AddOp>,
254 OneToOnePattern<arith::SubIOp, comb::SubOp>,
255 OneToOnePattern<arith::MulIOp, comb::MulOp>,
256 OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
257 OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
258 OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
259 OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
260 OneToOnePattern<arith::AndIOp, comb::AndOp>,
261 OneToOnePattern<arith::OrIOp, comb::OrOp>,
262 OneToOnePattern<arith::XOrIOp, comb::XorOp>,
263 OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
264 OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
265 OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
266 OneToOnePattern<arith::SelectOp, comb::MuxOp>,
267 MinMaxConversionPattern<arith::MaxSIOp, comb::ICmpPredicate::sge>,
268 MinMaxConversionPattern<arith::MaxUIOp, comb::ICmpPredicate::uge>,
269 MinMaxConversionPattern<arith::MinSIOp, comb::ICmpPredicate::sle>,
270 MinMaxConversionPattern<arith::MinUIOp, comb::ICmpPredicate::ule>,
271 ExtSConversionPattern, ExtZConversionPattern, TruncateConversionPattern,
272 CompConversionPattern, ConstantConversionPattern>(typeConverter,
276std::unique_ptr<mlir::Pass>
278 return std::make_unique<MapArithToCombPass>(enableBestEffortLowering);
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMapArithToCombPass(bool enableBestEffortLowering=false)
void populateArithToCombPatterns(mlir::RewritePatternSet &patterns, TypeConverter &typeConverter)