12#include "mlir/Dialect/Arith/IR/Arith.h"
13#include "mlir/Pass/Pass.h"
14#include "mlir/Transforms/DialectConversion.h"
17#define GEN_PASS_DEF_CONVERTCOMBTOARITH
18#include "circt/Conversion/Passes.h.inc"
// Lower comb.replicate. A 1-bit integer input is sign-extended to the result
// width, so every result bit is a copy of the input bit; any other input is
// rewritten as a comb.concat of getMultiple() copies of itself.
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter)
const override {
40 Type inputType = op.getInput().getType();
// Single-bit fast path: replicating one bit N times equals sext i1 -> iN.
41 if (isa<IntegerType>(inputType) && inputType.getIntOrFloatBitWidth() == 1) {
42 Type outType = rewriter.getIntegerType(op.getMultiple());
43 rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
// General case: build a concat of the repeated input. The comb dialect is
// illegal in this pass, so ConcatOpConversion lowers the new concat in turn.
47 SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
48 rewriter.replaceOpWithNewOp<
ConcatOp>(op, inputs);
// Lower the hw constant (presumably hw.constant, per HWConstantOpConversion in
// the pattern list) to an arith.constant carrying the same value attribute.
59 ConversionPatternRewriter &rewriter)
const override {
61 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
// Lower comb.icmp to arith.cmpi by translating the comparison predicate.
71 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter)
const override {
75 switch (adaptor.getPredicate()) {
// Case (cne/ceq) and wildcard (wne/weq) equality fold onto plain ne/eq.
// NOTE(review): this is exact only for two-valued operands, where case and
// wildcard equality coincide with ordinary equality -- confirm no X/Z inputs.
76 case ICmpPredicate::cne:
77 case ICmpPredicate::wne:
78 case ICmpPredicate::ne:
79 pred = CmpIPredicate::ne;
81 case ICmpPredicate::ceq:
82 case ICmpPredicate::weq:
83 case ICmpPredicate::eq:
84 pred = CmpIPredicate::eq;
// Ordering predicates map one-to-one (s* = signed, u* = unsigned).
86 case ICmpPredicate::sge:
87 pred = CmpIPredicate::sge;
89 case ICmpPredicate::sgt:
90 pred = CmpIPredicate::sgt;
92 case ICmpPredicate::sle:
93 pred = CmpIPredicate::sle;
95 case ICmpPredicate::slt:
96 pred = CmpIPredicate::slt;
98 case ICmpPredicate::uge:
99 pred = CmpIPredicate::uge;
101 case ICmpPredicate::ugt:
102 pred = CmpIPredicate::ugt;
104 case ICmpPredicate::ule:
105 pred = CmpIPredicate::ule;
107 case ICmpPredicate::ult:
108 pred = CmpIPredicate::ult;
// Emit the arith comparison on the converted operands.
112 rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
// Lower comb.extract: logically shift the input right by the low bit index,
// then truncate to the result width.
123 matchAndRewrite(
ExtractOp op, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter)
const override {
// The shift amount is materialized as a constant of the input's own type.
126 Value lowBit = arith::ConstantOp::create(
127 rewriter, op.getLoc(),
128 IntegerAttr::get(adaptor.getInput().getType(), adaptor.getLowBit()));
130 ShRUIOp::create(rewriter, op.getLoc(), adaptor.getInput(), lowBit);
// Keep only the low result-width bits of the shifted value.
131 rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
// Lower comb.concat by zero-extending every operand to the result width,
// shifting it to its bit position, and OR-ing the pieces together. The first
// operand ends up in the most-significant bits; the last in the least.
142 matchAndRewrite(
ConcatOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter)
const override {
144 Type type = op.getResult().getType();
145 Location loc = op.getLoc();
// A single-operand concat is just that operand.
149 if (op.getNumOperands() == 1) {
150 rewriter.replaceOp(op, adaptor.getOperands().back());
// The last operand occupies the low bits, so it is only zero-extended.
158 rewriter.createOrFold<ExtUIOp>(loc, type, adaptor.getOperands().back());
// Walk the remaining operands front to back; the offset counts down from the
// result width, so each operand is placed just below its predecessor.
162 unsigned offset = type.getIntOrFloatBitWidth();
163 for (
auto operand : adaptor.getOperands().drop_back()) {
164 offset -= operand.getType().getIntOrFloatBitWidth();
165 auto offsetConst = arith::ConstantOp::create(
166 rewriter, loc, IntegerAttr::get(type, offset));
167 auto extended = rewriter.createOrFold<ExtUIOp>(loc, type, operand);
168 auto shifted = rewriter.createOrFold<ShLIOp>(loc, extended, offsetConst);
// The pieces occupy disjoint bit ranges, so OR merges them without overlap.
169 aggregate = rewriter.createOrFold<OrIOp>(loc, aggregate, shifted);
172 rewriter.replaceOp(op, aggregate);
// Generic 1:1 lowering: replace a SourceOp with a TargetOp of the same result
// type, forwarding all converted operands in order. Registered here for
// SubOp -> SubIOp and MuxOp -> SelectOp.
178template <
typename SourceOp,
typename TargetOp>
181 using OpAdaptor =
typename SourceOp::Adaptor;
184 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter)
const override {
187 rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
188 adaptor.getOperands());
// Lower unsigned division/remainder (DivUOp -> DivUIOp, ModUOp -> RemUIOp).
// A zero divisor is replaced by one before the arith op is built, because
// arith.divui/remui by zero is undefined behavior.
195template <
typename SourceOp,
typename TargetOp>
198 using OpAdaptor =
typename SourceOp::Adaptor;
201 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter)
const override {
207 Location loc = op.getLoc();
208 Value zero = arith::ConstantOp::create(
209 rewriter, loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 0));
210 Value one = arith::ConstantOp::create(
211 rewriter, loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 1));
// divisor = (rhs == 0) ? 1 : rhs
212 Value isZero = arith::CmpIOp::create(rewriter, loc, CmpIPredicate::eq,
213 adaptor.getRhs(), zero);
215 arith::SelectOp::create(rewriter, loc, isZero, one, adaptor.getRhs());
216 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(), divisor);
// Lower signed division (IsRem = false: DivSOp -> DivSIOp) or remainder
// (IsRem = true: ModSOp -> RemSIOp). The inputs arith leaves undefined -- a
// zero divisor and the INT_MIN / -1 overflow -- are avoided by dividing by one
// instead. On overflow the result is then forced to the two's-complement
// wrapped value: INT_MIN for division, 0 for remainder. For a zero divisor the
// result is whatever TargetOp yields with a divisor of one.
223template <
typename SourceOp,
typename TargetOp,
bool IsRem>
227 using OpAdaptor =
typename SourceOp::Adaptor;
230 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter)
const override {
232 Value dividend = adaptor.getLhs();
233 Value divisor = adaptor.getRhs();
234 Type ty = op.getType();
// Small builder helpers for the guard expressions below.
247 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
248 auto eq = [&](Value lhs, Value rhs) {
249 return arith::CmpIOp::create(b, CmpIPredicate::eq, lhs, rhs);
251 auto and_ = [&](Value lhs, Value rhs) {
252 return arith::AndIOp::create(b, lhs, rhs);
254 auto or_ = [&](Value lhs, Value rhs) {
255 return arith::OrIOp::create(b, lhs, rhs);
// Constants 0, 1, INT_MIN and -1 (all ones) at the operand width.
258 int bitwidth = ty.getIntOrFloatBitWidth();
259 Value zero = arith::ConstantOp::create(b, rewriter.getIntegerAttr(ty, 0));
260 Value one = arith::ConstantOp::create(b, rewriter.getIntegerAttr(ty, 1));
261 Value int_min = arith::ConstantOp::create(
262 b, rewriter.getIntegerAttr(ty, APInt::getSignedMinValue(bitwidth)));
263 Value minus_one = arith::ConstantOp::create(
264 b, rewriter.getIntegerAttr(ty, APInt::getAllOnes(bitwidth)));
// Divide by one whenever the divisor is zero or the division would overflow.
266 Value isZero = eq(divisor, zero);
267 Value isOverflow = and_(eq(dividend, int_min), eq(divisor, minus_one));
268 Value pred = or_(isZero, isOverflow);
269 Value safeDivisor = arith::SelectOp::create(b, pred, one, divisor);
270 auto newOp = TargetOp::create(b, dividend, safeDivisor);
// Patch in the wrapped result for the overflow case only.
272 Value resultIfOverflow = IsRem ? zero : int_min;
274 arith::SelectOp::create(b, isOverflow, resultIfOverflow, newOp);
275 rewriter.replaceOp(op, result);
// Lower a variadic comb op (add, mul, and, or, xor in the pattern list) to a
// left-folded chain of binary TargetOps: ((x0 op x1) op x2) op ...
// NOTE(review): operands[0] is read unconditionally, so this assumes the op
// has at least one operand -- confirm the source op's verifier enforces that.
281template <
typename SourceOp,
typename TargetOp>
284 using OpAdaptor =
typename SourceOp::Adaptor;
287 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter)
const override {
291 ValueRange operands = adaptor.getOperands();
292 Value runner = operands[0];
294 llvm::make_range(operands.begin() + 1, operands.
end())) {
295 runner = TargetOp::create(rewriter, op.getLoc(), runner, operand);
297 rewriter.replaceOp(op, runner);
// Lower comb.shl / comb.shru to arith.shli / arith.shrui. Shift amounts >= the
// bit width must yield zero, but arith shifts by such amounts yield poison, so
// the arith result is replaced with zero via a select in that case.
309template <
typename SourceOp,
typename TargetOp>
312 using OpAdaptor =
typename SourceOp::Adaptor;
315 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter)
const override {
317 unsigned shifteeWidth =
318 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
319 .getIntOrFloatBitWidth();
320 auto zeroConstOp = arith::ConstantOp::create(
321 rewriter, op.getLoc(), IntegerAttr::get(adaptor.getLhs().getType(), 0));
// The bit width as a value of the shiftee's type, for the unsigned range check.
322 auto maxShamtConstOp = arith::ConstantOp::create(
323 rewriter, op.getLoc(),
324 IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth));
325 auto shiftOp = rewriter.createOrFold<TargetOp>(
326 op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
// true when the shift amount pushes every bit out (rhs >= width, unsigned).
327 auto isAllZeroOp = rewriter.createOrFold<CmpIOp>(
328 op.getLoc(), CmpIPredicate::uge, adaptor.getRhs(),
329 maxShamtConstOp.getResult());
330 rewriter.replaceOpWithNewOp<SelectOp>(op, isAllZeroOp, zeroConstOp,
// Lower comb.shrs to arith.shrsi. An arithmetic right shift by width-1 or more
// fills every bit with the sign bit, so clamping the shift amount to width-1
// with minui gives the same result while keeping the arith shift in range.
// NOTE(review): shifteeWidth is unsigned; a 0-bit shiftee would make
// `shifteeWidth - 1` wrap around -- confirm i0 cannot reach this pattern.
341 matchAndRewrite(
ShrSOp op, OpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter)
const override {
343 unsigned shifteeWidth =
344 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
345 .getIntOrFloatBitWidth();
347 auto maxShamtMinusOneConstOp = arith::ConstantOp::create(
348 rewriter, op.getLoc(),
349 IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth - 1));
// shamt = min(rhs, width - 1), compared unsigned.
350 auto shamtOp = rewriter.createOrFold<MinUIOp>(op.getLoc(), adaptor.getRhs(),
351 maxShamtMinusOneConstOp);
352 rewriter.replaceOpWithNewOp<ShRSIOp>(op, adaptor.getLhs(), shamtOp);
// Pass that lowers comb ops (and the hw constant) to the arith dialect using
// the patterns from populateCombToArithConversionPatterns. Its body is
// runOnOperation(), defined out of line.
364struct ConvertCombToArithPass
365 :
public circt::impl::ConvertCombToArithBase<ConvertCombToArithPass> {
366 void runOnOperation()
override;
// converter: type converter shared by every registered pattern.
// patterns: receives all comb -> arith conversion patterns of this file.
371 TypeConverter &converter, mlir::RewritePatternSet &
patterns) {
373 CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
374 ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
375 LogicalShiftConversion<ShlOp, ShLIOp>,
376 LogicalShiftConversion<ShrUOp, ShRUIOp>,
377 BinaryOpConversion<SubOp, SubIOp>,
// DivSOpConversion's bool argument is IsRem: false = divide, true = remainder.
378 DivSOpConversion<
DivSOp, DivSIOp,
false>,
379 DivUOpConversion<DivUOp, DivUIOp>,
380 DivSOpConversion<
ModSOp, RemSIOp,
true>,
381 DivUOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
382 VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
383 VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
384 VariadicOpConversion<XorOp, XOrIOp>>(converter,
patterns.getContext());
// Run a partial conversion of the current operation: the comb dialect is
// illegal (comb.reverse is explicitly kept legal), arith is legal, and any op
// not otherwise classified is treated as legal and left untouched.
387void ConvertCombToArithPass::runOnOperation() {
388 ConversionTarget target(getContext());
389 target.addIllegalDialect<comb::CombDialect>();
391 target.addLegalDialect<ArithDialect>();
397 target.addLegalOp<comb::ReverseOp>();
402 target.markUnknownOpDynamicallyLegal([](Operation *) {
return true; });
403 RewritePatternSet
patterns(&getContext());
// Identity type conversion: every type is passed through unchanged.
404 TypeConverter converter;
405 converter.addConversion([](Type type) {
return type; });
// Pattern rollback is disabled. NOTE(review): per MLIR's ConversionConfig this
// presumably requires patterns not to fail after modifying IR -- confirm.
409 ConversionConfig config;
410 config.allowPatternRollback =
false;
411 if (failed(mlir::applyPartialConversion(getOperation(), target,
// Factory: returns a newly allocated ConvertCombToArithPass owned by the caller.
417 return std::make_unique<ConvertCombToArithPass>();
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
// Adds this file's comb -> arith conversion patterns to `patterns`.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
// Creates the comb -> arith conversion pass.
std::unique_ptr< Pass > createConvertCombToArithPass()