17#include "mlir/Dialect/Arith/IR/Arith.h"
18#include "mlir/Pass/Pass.h"
19#include "mlir/Transforms/DialectConversion.h"
22#define GEN_PASS_DEF_MAPARITHTOCOMBPASS
23#include "circt/Transforms/Passes.h.inc"
33class MapArithTypeConverter :
public mlir::TypeConverter {
35 MapArithTypeConverter() {
36 addConversion([](Type type) {
37 if (hw::isHWValueType(type))
45template <
typename TFrom,
typename TTo,
bool cloneAttrs = false>
49 using OpAdaptor =
typename TFrom::Adaptor;
52 matchAndRewrite(TFrom op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override {
54 rewriter.replaceOpWithNewOp<TTo>(
55 op, adaptor.getOperands(),
56 cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
64 using OpAdaptor =
typename arith::ExtSIOp::Adaptor;
67 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const override {
69 size_t outWidth = op.getType().getIntOrFloatBitWidth();
70 rewriter.replaceOp(op, comb::createOrFoldSExt(
71 op.getLoc(), op.getOperand(),
72 rewriter.getIntegerType(outWidth), rewriter));
80 using OpAdaptor =
typename arith::ExtUIOp::Adaptor;
83 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter)
const override {
85 auto loc = op.getLoc();
86 size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
87 size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
89 rewriter.replaceOp(op, comb::ConcatOp::create(
92 rewriter, loc, APInt(outWidth - inWidth, 0)),
101 using OpAdaptor =
typename arith::TruncIOp::Adaptor;
104 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
106 size_t outWidth = op.getType().getIntOrFloatBitWidth();
116 using OpAdaptor =
typename arith::CmpIOp::Adaptor;
118 static comb::ICmpPredicate
119 arithToCombPredicate(arith::CmpIPredicate predicate) {
121 case arith::CmpIPredicate::eq:
122 return comb::ICmpPredicate::eq;
123 case arith::CmpIPredicate::ne:
124 return comb::ICmpPredicate::ne;
125 case arith::CmpIPredicate::slt:
126 return comb::ICmpPredicate::slt;
127 case arith::CmpIPredicate::ult:
128 return comb::ICmpPredicate::ult;
129 case arith::CmpIPredicate::sle:
130 return comb::ICmpPredicate::sle;
131 case arith::CmpIPredicate::ule:
132 return comb::ICmpPredicate::ule;
133 case arith::CmpIPredicate::sgt:
134 return comb::ICmpPredicate::sgt;
135 case arith::CmpIPredicate::ugt:
136 return comb::ICmpPredicate::ugt;
137 case arith::CmpIPredicate::sge:
138 return comb::ICmpPredicate::sge;
139 case arith::CmpIPredicate::uge:
140 return comb::ICmpPredicate::uge;
142 llvm_unreachable(
"Unknown predicate");
146 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const override {
148 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
149 op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
155struct ConstantConversionPattern
157 using OpConversionPattern::OpConversionPattern;
160 matchAndRewrite(arith::ConstantOp op, OpAdaptor adaptor,
161 ConversionPatternRewriter &rewriter)
const override {
163 if (!isa<IntegerType>(op.getType()))
167 op, cast<IntegerAttr>(adaptor.getValue()));
172struct MapArithToCombPass
173 :
public circt::impl::MapArithToCombPassBase<MapArithToCombPass> {
175 MapArithToCombPass(
bool enableBestEffortLowering) {
176 this->enableBestEffortLowering = enableBestEffortLowering;
179 void runOnOperation()
override {
180 auto *ctx = &getContext();
182 ConversionTarget target(*ctx);
183 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
184 if (!enableBestEffortLowering) {
185 target.addIllegalDialect<arith::ArithDialect>();
189 target.addIllegalOp<arith::AddIOp>();
190 target.addIllegalOp<arith::SubIOp>();
191 target.addIllegalOp<arith::MulIOp>();
192 target.addIllegalOp<arith::DivSIOp>();
193 target.addIllegalOp<arith::DivUIOp>();
194 target.addIllegalOp<arith::RemSIOp>();
195 target.addIllegalOp<arith::RemUIOp>();
196 target.addIllegalOp<arith::AndIOp>();
197 target.addIllegalOp<arith::OrIOp>();
198 target.addIllegalOp<arith::XOrIOp>();
199 target.addIllegalOp<arith::ShLIOp>();
200 target.addIllegalOp<arith::ShRSIOp>();
201 target.addIllegalOp<arith::ShRUIOp>();
202 target.addIllegalOp<arith::SelectOp>();
203 target.addIllegalOp<arith::ExtSIOp>();
204 target.addIllegalOp<arith::ExtUIOp>();
205 target.addIllegalOp<arith::TruncIOp>();
206 target.addIllegalOp<arith::CmpIOp>();
209 target.addDynamicallyLegalOp<arith::ConstantOp>([](Operation *op) {
210 return !isa<IntegerType>(op->getResult(0).getType());
213 MapArithTypeConverter typeConverter;
217 if (failed(applyPartialConversion(getOperation(), target,
226 TypeConverter &typeConverter) {
227 patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
228 OneToOnePattern<arith::SubIOp, comb::SubOp>,
229 OneToOnePattern<arith::MulIOp, comb::MulOp>,
230 OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
231 OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
232 OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
233 OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
234 OneToOnePattern<arith::AndIOp, comb::AndOp>,
235 OneToOnePattern<arith::OrIOp, comb::OrOp>,
236 OneToOnePattern<arith::XOrIOp, comb::XorOp>,
237 OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
238 OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
239 OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
240 OneToOnePattern<arith::SelectOp, comb::MuxOp>,
241 ExtSConversionPattern, ExtZConversionPattern,
242 TruncateConversionPattern, CompConversionPattern,
243 ConstantConversionPattern>(typeConverter,
247std::unique_ptr<mlir::Pass>
249 return std::make_unique<MapArithToCombPass>(enableBestEffortLowering);
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMapArithToCombPass(bool enableBestEffortLowering=false)
void populateArithToCombPatterns(mlir::RewritePatternSet &patterns, TypeConverter &typeConverter)