CIRCT  19.0.0git
MapArithToComb.cpp
Go to the documentation of this file.
1 //===- MapArithToComb.cpp - Arith-to-comb mapping pass ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the MapArithToComb pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetail.h"
16 #include "circt/Dialect/HW/HWOps.h"
18 #include "mlir/Dialect/Arith/IR/Arith.h"
19 #include "mlir/Transforms/DialectConversion.h"
20 
21 using namespace mlir;
22 using namespace circt;
23 
24 namespace {
25 
26 // A type converter which legalizes integer types, thus ensuring that vector
27 // types are illegal.
28 class MapArithTypeConverter : public mlir::TypeConverter {
29 public:
30  MapArithTypeConverter() {
31  addConversion([](Type type) {
32  if (type.isa<mlir::IntegerType>())
33  return type;
34 
35  return Type();
36  });
37  }
38 };
39 
40 template <typename TFrom, typename TTo, bool cloneAttrs = false>
41 class OneToOnePattern : public OpConversionPattern<TFrom> {
42 public:
44  using OpAdaptor = typename TFrom::Adaptor;
45 
46  LogicalResult
47  matchAndRewrite(TFrom op, OpAdaptor adaptor,
48  ConversionPatternRewriter &rewriter) const override {
49  rewriter.replaceOpWithNewOp<TTo>(
50  op, adaptor.getOperands(),
51  cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
52  return success();
53  }
54 };
55 
56 class ExtSConversionPattern : public OpConversionPattern<arith::ExtSIOp> {
57 public:
59  using OpAdaptor = typename arith::ExtSIOp::Adaptor;
60 
61  LogicalResult
62  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
63  ConversionPatternRewriter &rewriter) const override {
64  size_t outWidth = op.getType().getIntOrFloatBitWidth();
65  rewriter.replaceOp(op, comb::createOrFoldSExt(
66  op.getLoc(), op.getOperand(),
67  rewriter.getIntegerType(outWidth), rewriter));
68  return success();
69  }
70 };
71 
72 class ExtZConversionPattern : public OpConversionPattern<arith::ExtUIOp> {
73 public:
75  using OpAdaptor = typename arith::ExtUIOp::Adaptor;
76 
77  LogicalResult
78  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
79  ConversionPatternRewriter &rewriter) const override {
80  auto loc = op.getLoc();
81  size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
82  size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
83 
84  rewriter.replaceOp(op, rewriter.create<comb::ConcatOp>(
85  loc,
86  rewriter.create<hw::ConstantOp>(
87  loc, APInt(outWidth - inWidth, 0)),
88  adaptor.getIn()));
89  return success();
90  }
91 };
92 
93 class TruncateConversionPattern : public OpConversionPattern<arith::TruncIOp> {
94 public:
96  using OpAdaptor = typename arith::TruncIOp::Adaptor;
97 
98  LogicalResult
99  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
100  ConversionPatternRewriter &rewriter) const override {
101  size_t outWidth = op.getType().getIntOrFloatBitWidth();
102  rewriter.replaceOpWithNewOp<comb::ExtractOp>(op, adaptor.getIn(), 0,
103  outWidth);
104  return success();
105  }
106 };
107 
108 class CompConversionPattern : public OpConversionPattern<arith::CmpIOp> {
109 public:
111  using OpAdaptor = typename arith::CmpIOp::Adaptor;
112 
113  static comb::ICmpPredicate
114  arithToCombPredicate(arith::CmpIPredicate predicate) {
115  switch (predicate) {
116  case arith::CmpIPredicate::eq:
117  return comb::ICmpPredicate::eq;
118  case arith::CmpIPredicate::ne:
119  return comb::ICmpPredicate::ne;
120  case arith::CmpIPredicate::slt:
121  return comb::ICmpPredicate::slt;
122  case arith::CmpIPredicate::ult:
123  return comb::ICmpPredicate::ult;
124  case arith::CmpIPredicate::sle:
125  return comb::ICmpPredicate::sle;
126  case arith::CmpIPredicate::ule:
127  return comb::ICmpPredicate::ule;
128  case arith::CmpIPredicate::sgt:
129  return comb::ICmpPredicate::sgt;
130  case arith::CmpIPredicate::ugt:
131  return comb::ICmpPredicate::ugt;
132  case arith::CmpIPredicate::sge:
133  return comb::ICmpPredicate::sge;
134  case arith::CmpIPredicate::uge:
135  return comb::ICmpPredicate::uge;
136  }
137  llvm_unreachable("Unknown predicate");
138  }
139 
140  LogicalResult
141  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
142  ConversionPatternRewriter &rewriter) const override {
143  rewriter.replaceOpWithNewOp<comb::ICmpOp>(
144  op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
145  adaptor.getRhs());
146  return success();
147  }
148 };
149 
150 struct MapArithToCombPass : public MapArithToCombPassBase<MapArithToCombPass> {
151 public:
152  void runOnOperation() override {
153  auto *ctx = &getContext();
154 
155  ConversionTarget target(*ctx);
156  target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
157  target.addIllegalDialect<arith::ArithDialect>();
158  MapArithTypeConverter typeConverter;
159  RewritePatternSet patterns(ctx);
160 
161  patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
162  OneToOnePattern<arith::SubIOp, comb::SubOp>,
163  OneToOnePattern<arith::MulIOp, comb::MulOp>,
164  OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
165  OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
166  OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
167  OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
168  OneToOnePattern<arith::AndIOp, comb::AndOp>,
169  OneToOnePattern<arith::OrIOp, comb::OrOp>,
170  OneToOnePattern<arith::XOrIOp, comb::XorOp>,
171  OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
172  OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
173  OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
174  OneToOnePattern<arith::ConstantOp, hw::ConstantOp, true>,
175  OneToOnePattern<arith::SelectOp, comb::MuxOp>,
176  ExtSConversionPattern, ExtZConversionPattern,
177  TruncateConversionPattern, CompConversionPattern>(
178  typeConverter, ctx);
179 
180  if (failed(applyPartialConversion(getOperation(), target,
181  std::move(patterns))))
182  signalPassFailure();
183  }
184 };
185 
186 } // namespace
187 
188 std::unique_ptr<mlir::Pass> circt::createMapArithToCombPass() {
189  return std::make_unique<MapArithToCombPass>();
190 }
Value createOrFoldSExt(Location loc, Value value, Type destTy, OpBuilder &builder)
Create a sign extension operation from a value of integer type to an equal or larger integer type.
Definition: CombOps.cpp:25
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createMapArithToCombPass()