17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Pass/Pass.h"
19 #include "mlir/Transforms/DialectConversion.h"
22 #define GEN_PASS_DEF_MAPARITHTOCOMBPASS
23 #include "circt/Transforms/Passes.h.inc"
27 using namespace circt;
/// Type converter for the arith-to-comb mapping; derives from
/// mlir::TypeConverter.
33 class MapArithTypeConverter :
public mlir::TypeConverter {
35 MapArithTypeConverter() {
// Register a conversion callback that branches on whether the incoming type
// is an mlir::IntegerType. What each branch yields is not visible here.
36 addConversion([](Type type) {
37 if (isa<mlir::IntegerType>(type))
/// Generic conversion pattern: replaces a `TFrom` op with a newly created
/// `TTo` op.
///
/// Template parameters:
///   TFrom      - source op type being matched.
///   TTo        - op type created as the replacement.
///   cloneAttrs - when true, the source op's attributes are forwarded to the
///                replacement; defaults to false (empty attribute list).
45 template <
typename TFrom,
typename TTo,
bool cloneAttrs = false>
49 using OpAdaptor =
typename TFrom::Adaptor;
52 matchAndRewrite(TFrom op, OpAdaptor adaptor,
53 ConversionPatternRewriter &rewriter)
const override {
// The replacement is built from the adaptor's operands. The attribute list
// is op->getAttrs() when cloneAttrs is set, otherwise an empty ArrayRef.
54 rewriter.replaceOpWithNewOp<TTo>(
55 op, adaptor.getOperands(),
56 cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
// Conversion for arith::ExtSIOp.
64 using OpAdaptor =
typename arith::ExtSIOp::Adaptor;
67 matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
68 ConversionPatternRewriter &rewriter)
const override {
// Bit width of the op's result type; used to build the destination type.
69 size_t outWidth = op.getType().getIntOrFloatBitWidth();
// Arguments are (location, value, integer type of outWidth bits, builder).
// The callee is not visible on these lines; presumably a sign-extension
// helper - confirm.
// NOTE(review): this passes the original op's operand, whereas the ExtUIOp
// pattern reads its input through the adaptor - confirm this is intended.
71 op.getLoc(), op.getOperand(),
72 rewriter.getIntegerType(outWidth), rewriter));
// Conversion for arith::ExtUIOp.
80 using OpAdaptor =
typename arith::ExtUIOp::Adaptor;
83 matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
84 ConversionPatternRewriter &rewriter)
const override {
85 auto loc = op.getLoc();
// outWidth: bit width of the op's result type.
// inWidth: bit width of the adaptor-provided input's type.
86 size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
87 size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
// A zero-valued APInt whose width is the difference between the two widths.
// Presumably the padding bits combined with the input; the surrounding
// expression is not visible here - confirm.
92 loc, APInt(outWidth - inWidth, 0)),
// Conversion for arith::TruncIOp.
101 using OpAdaptor =
typename arith::TruncIOp::Adaptor;
104 matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter)
const override {
// Bit width of the op's (narrower) result type. How it is consumed is not
// visible on these lines.
106 size_t outWidth = op.getType().getIntOrFloatBitWidth();
// Conversion for arith::CmpIOp into comb::ICmpOp.
116 using OpAdaptor =
typename arith::CmpIOp::Adaptor;
/// Translates an arith integer-comparison predicate into the comb predicate
/// with the same name (eq, ne, slt, ult, sle, ule, sgt, ugt, sge, uge).
/// Reaching the end without a matching case hits
/// llvm_unreachable("Unknown predicate").
118 static comb::ICmpPredicate
119 arithToCombPredicate(arith::CmpIPredicate predicate) {
121 case arith::CmpIPredicate::eq:
122 return comb::ICmpPredicate::eq;
123 case arith::CmpIPredicate::ne:
124 return comb::ICmpPredicate::ne;
125 case arith::CmpIPredicate::slt:
126 return comb::ICmpPredicate::slt;
127 case arith::CmpIPredicate::ult:
128 return comb::ICmpPredicate::ult;
129 case arith::CmpIPredicate::sle:
130 return comb::ICmpPredicate::sle;
131 case arith::CmpIPredicate::ule:
132 return comb::ICmpPredicate::ule;
133 case arith::CmpIPredicate::sgt:
134 return comb::ICmpPredicate::sgt;
135 case arith::CmpIPredicate::ugt:
136 return comb::ICmpPredicate::ugt;
137 case arith::CmpIPredicate::sge:
138 return comb::ICmpPredicate::sge;
139 case arith::CmpIPredicate::uge:
140 return comb::ICmpPredicate::uge;
142 llvm_unreachable(
"Unknown predicate");
146 matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
147 ConversionPatternRewriter &rewriter)
const override {
// Replace the compare with a comb::ICmpOp, using the translated predicate
// and the adaptor's lhs. The remaining arguments are not visible here.
148 rewriter.replaceOpWithNewOp<comb::ICmpOp>(
149 op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
/// Pass that converts arith dialect ops into comb/hw ops using the
/// conversion patterns registered in runOnOperation().
155 struct MapArithToCombPass
156 :
public circt::impl::MapArithToCombPassBase<MapArithToCombPass> {
158 void runOnOperation()
override {
159 auto *ctx = &getContext();
// Conversion target: comb and hw are legal, everything in arith is illegal.
161 ConversionTarget target(*ctx);
162 target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
163 target.addIllegalDialect<arith::ArithDialect>();
164 MapArithTypeConverter typeConverter;
// Register one OneToOnePattern per direct arith -> comb/hw op pair, plus
// the four dedicated patterns for sign/zero extension, truncation and
// integer compare. arith::ConstantOp -> hw::ConstantOp is the only mapping
// that sets cloneAttrs = true (presumably to carry the constant's value
// attribute across - confirm).
167 patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
168 OneToOnePattern<arith::SubIOp, comb::SubOp>,
169 OneToOnePattern<arith::MulIOp, comb::MulOp>,
170 OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
171 OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
172 OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
173 OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
174 OneToOnePattern<arith::AndIOp, comb::AndOp>,
175 OneToOnePattern<arith::OrIOp, comb::OrOp>,
176 OneToOnePattern<arith::XOrIOp, comb::XorOp>,
177 OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
178 OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
179 OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
180 OneToOnePattern<arith::ConstantOp, hw::ConstantOp, true>,
181 OneToOnePattern<arith::SelectOp, comb::MuxOp>,
182 ExtSConversionPattern, ExtZConversionPattern,
183 TruncateConversionPattern, CompConversionPattern>(
// Run a partial conversion over the pass's root operation. The handling of
// a failed conversion is not visible on these lines.
186 if (failed(applyPartialConversion(getOperation(), target,
// Hand back a freshly allocated MapArithToCombPass owned by a unique_ptr.
195 return std::make_unique<MapArithToCombPass>();
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createMapArithToCombPass()