18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Transforms/DialectConversion.h"
22 using namespace circt;
// Type converter used by the arith-to-comb conversion pass.
// The visible part registers a conversion callback that tests whether the
// incoming type is an mlir::IntegerType; the rest of the callback (what it
// returns for integer vs. other types) is not shown in this excerpt.
28 class MapArithTypeConverter :
public mlir::TypeConverter {
30 MapArithTypeConverter() {
31 addConversion([](Type type) {
// NOTE(review): the member form `type.isa<T>()` is deprecated in newer MLIR
// releases in favour of the free function `isa<T>(type)`; consider migrating.
32 if (type.isa<mlir::IntegerType>())
// Generic conversion pattern: replaces an op of type TFrom with a new op of
// type TTo, forwarding the adaptor's (converted) operands unchanged.
//
// Template parameters:
//   TFrom      - source op type (arith ops in this file's pattern list).
//   TTo        - replacement op type (comb or hw ops in this file).
//   cloneAttrs - if true, every attribute of the source op is copied onto the
//                replacement; if false, the replacement is built with an
//                empty attribute list, i.e. source attributes are dropped.
//                Only the arith::ConstantOp -> hw::ConstantOp instantiation in
//                this file sets it (presumably so the constant's value
//                attribute is carried over; confirm).
40 template <
typename TFrom,
typename TTo,
bool cloneAttrs = false>
44 using OpAdaptor =
typename TFrom::Adaptor;
47 matchAndRewrite(TFrom op, OpAdaptor adaptor,
48 ConversionPatternRewriter &rewriter)
const override {
// No explicit result type is passed, so TTo must offer a builder taking
// (operands, attributes) that derives its result type itself; confirm for
// each TTo used.
49 rewriter.replaceOpWithNewOp<TTo>(
50 op, adaptor.getOperands(),
51 cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
// Lowers arith.extsi: sign-extends the operand to the bit width of the op's
// result type. The start of the replacement call is elided in this excerpt.
// The visible arguments (loc, value, dest type, builder) match the signature
// of createOrFoldSExt, so that helper is presumably the one called; confirm.
59 using OpAdaptor =
typename arith::ExtSIOp::Adaptor;
62 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
63 ConversionPatternRewriter &rewriter)
const override {
64 size_t outWidth = op.getType().getIntOrFloatBitWidth();
// NOTE(review): this passes the original operand (op.getOperand()), while the
// ExtUI, CmpI and one-to-one patterns read operands from the adaptor. In a
// dialect conversion the adaptor's values are the converted ones; consider
// adaptor.getIn() here for consistency.
66 op.getLoc(), op.getOperand(),
67 rewriter.getIntegerType(outWidth), rewriter));
// Lowers arith.extui: zero-extends the operand to the result width.
// The visible code computes the output width (from the original op's result)
// and the input width (from the adaptor's converted operand). It then builds
// an all-zero APInt of (outWidth - inWidth) bits. Presumably that zero
// constant is concatenated above the input bits; the concat itself is not
// shown in this excerpt.
75 using OpAdaptor =
typename arith::ExtUIOp::Adaptor;
78 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
79 ConversionPatternRewriter &rewriter)
const override {
80 auto loc = op.getLoc();
81 size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
82 size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
// Zero padding. outWidth and inWidth are size_t, so this relies on
// outWidth > inWidth (expected to be enforced by arith.extui's verifier;
// confirm). Otherwise the subtraction would wrap or yield a 0-bit width.
87 loc, APInt(outWidth - inWidth, 0)),
// Lowers arith.trunci. Only the computation of the result bit width is
// visible in this excerpt. Presumably the low outWidth bits of the operand
// are extracted; confirm against the full source.
96 using OpAdaptor =
typename arith::TruncIOp::Adaptor;
99 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
100 ConversionPatternRewriter &rewriter)
const override {
101 size_t outWidth = op.getType().getIntOrFloatBitWidth();
// Lowers arith.cmpi to comb::ICmpOp. The predicate is translated by
// arithToCombPredicate, and the operands come from the adaptor (lhs visible
// here; rhs presumably follows on the elided line).
111 using OpAdaptor =
typename arith::CmpIOp::Adaptor;
// Maps an arith::CmpIPredicate to the comb::ICmpPredicate of the same name.
// Each of the ten enumerators listed (eq, ne, slt, ult, sle, ule, sgt, ugt,
// sge, uge) has a one-to-one counterpart. Any value not handled by a case
// falls through to llvm_unreachable.
113 static comb::ICmpPredicate
114 arithToCombPredicate(arith::CmpIPredicate predicate) {
116 case arith::CmpIPredicate::eq:
117 return comb::ICmpPredicate::eq;
118 case arith::CmpIPredicate::ne:
119 return comb::ICmpPredicate::ne;
120 case arith::CmpIPredicate::slt:
121 return comb::ICmpPredicate::slt;
122 case arith::CmpIPredicate::ult:
123 return comb::ICmpPredicate::ult;
124 case arith::CmpIPredicate::sle:
125 return comb::ICmpPredicate::sle;
126 case arith::CmpIPredicate::ule:
127 return comb::ICmpPredicate::ule;
128 case arith::CmpIPredicate::sgt:
129 return comb::ICmpPredicate::sgt;
130 case arith::CmpIPredicate::ugt:
131 return comb::ICmpPredicate::ugt;
132 case arith::CmpIPredicate::sge:
133 return comb::ICmpPredicate::sge;
134 case arith::CmpIPredicate::uge:
135 return comb::ICmpPredicate::uge;
137 llvm_unreachable(
"Unknown predicate");
141 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
142 ConversionPatternRewriter &rewriter)
const override {
143 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
144 op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
// Pass that rewrites arith integer ops into comb/hw ops using the conversion
// patterns defined above, via a partial dialect conversion.
150 struct MapArithToCombPass :
public MapArithToCombPassBase<MapArithToCombPass> {
152 void runOnOperation()
override {
153 auto *ctx = &getContext();
// comb and hw ops are legal targets. The whole arith dialect is marked
// illegal, so partial conversion fails if any arith op cannot be rewritten.
155 ConversionTarget target(*ctx);
156 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
157 target.addIllegalDialect<arith::ArithDialect>();
158 MapArithTypeConverter typeConverter;
// NOTE(review): arith ops with no pattern below (e.g. floating-point ops or
// other integer ops not listed) are still illegal, so their presence makes
// the conversion fail.
// NOTE(review): RemSI/RemUI map to comb ModS/ModU. This assumes comb's "mod"
// ops have remainder semantics (result sign follows the dividend, like
// arith.remsi); confirm against the Comb dialect spec.
161 patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
162 OneToOnePattern<arith::SubIOp, comb::SubOp>,
163 OneToOnePattern<arith::MulIOp, comb::MulOp>,
164 OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
165 OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
166 OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
167 OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
168 OneToOnePattern<arith::AndIOp, comb::AndOp>,
169 OneToOnePattern<arith::OrIOp, comb::OrOp>,
170 OneToOnePattern<arith::XOrIOp, comb::XorOp>,
171 OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
172 OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
173 OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
174 OneToOnePattern<arith::ConstantOp, hw::ConstantOp, true>,
175 OneToOnePattern<arith::SelectOp, comb::MuxOp>,
176 ExtSConversionPattern, ExtZConversionPattern,
177 TruncateConversionPattern, CompConversionPattern>(
// Failure handling is in the elided body of this `if` (presumably
// signalPassFailure(); confirm).
180 if (failed(applyPartialConversion(getOperation(), target,
// Factory: returns a newly allocated MapArithToCombPass whose ownership is
// transferred to the caller through the returned unique_ptr.
189 return std::make_unique<MapArithToCombPass>();
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMapArithToCombPass()