CIRCT  20.0.0git
MapArithToComb.cpp
Go to the documentation of this file.
1 //===- MapArithToComb.cpp - Arith-to-comb mapping pass ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Contains the definitions of the MapArithToComb pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Transforms/Passes.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
20 
21 namespace circt {
22 #define GEN_PASS_DEF_MAPARITHTOCOMBPASS
23 #include "circt/Transforms/Passes.h.inc"
24 } // namespace circt
25 
26 using namespace mlir;
27 using namespace circt;
28 
29 namespace {
30 
31 // A type converter which legalizes integer types, thus ensuring that vector
32 // types are illegal.
33 class MapArithTypeConverter : public mlir::TypeConverter {
34 public:
35  MapArithTypeConverter() {
36  addConversion([](Type type) {
37  if (isa<mlir::IntegerType>(type))
38  return type;
39 
40  return Type();
41  });
42  }
43 };
44 
45 template <typename TFrom, typename TTo, bool cloneAttrs = false>
46 class OneToOnePattern : public OpConversionPattern<TFrom> {
47 public:
49  using OpAdaptor = typename TFrom::Adaptor;
50 
51  LogicalResult
52  matchAndRewrite(TFrom op, OpAdaptor adaptor,
53  ConversionPatternRewriter &rewriter) const override {
54  rewriter.replaceOpWithNewOp<TTo>(
55  op, adaptor.getOperands(),
56  cloneAttrs ? op->getAttrs() : ArrayRef<::mlir::NamedAttribute>());
57  return success();
58  }
59 };
60 
61 class ExtSConversionPattern : public OpConversionPattern<arith::ExtSIOp> {
62 public:
64  using OpAdaptor = typename arith::ExtSIOp::Adaptor;
65 
66  LogicalResult
67  matchAndRewrite(arith::ExtSIOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override {
69  size_t outWidth = op.getType().getIntOrFloatBitWidth();
70  rewriter.replaceOp(op, comb::createOrFoldSExt(
71  op.getLoc(), op.getOperand(),
72  rewriter.getIntegerType(outWidth), rewriter));
73  return success();
74  }
75 };
76 
77 class ExtZConversionPattern : public OpConversionPattern<arith::ExtUIOp> {
78 public:
80  using OpAdaptor = typename arith::ExtUIOp::Adaptor;
81 
82  LogicalResult
83  matchAndRewrite(arith::ExtUIOp op, OpAdaptor adaptor,
84  ConversionPatternRewriter &rewriter) const override {
85  auto loc = op.getLoc();
86  size_t outWidth = op.getOut().getType().getIntOrFloatBitWidth();
87  size_t inWidth = adaptor.getIn().getType().getIntOrFloatBitWidth();
88 
89  rewriter.replaceOp(op, rewriter.create<comb::ConcatOp>(
90  loc,
91  rewriter.create<hw::ConstantOp>(
92  loc, APInt(outWidth - inWidth, 0)),
93  adaptor.getIn()));
94  return success();
95  }
96 };
97 
98 class TruncateConversionPattern : public OpConversionPattern<arith::TruncIOp> {
99 public:
101  using OpAdaptor = typename arith::TruncIOp::Adaptor;
102 
103  LogicalResult
104  matchAndRewrite(arith::TruncIOp op, OpAdaptor adaptor,
105  ConversionPatternRewriter &rewriter) const override {
106  size_t outWidth = op.getType().getIntOrFloatBitWidth();
107  rewriter.replaceOpWithNewOp<comb::ExtractOp>(op, adaptor.getIn(), 0,
108  outWidth);
109  return success();
110  }
111 };
112 
113 class CompConversionPattern : public OpConversionPattern<arith::CmpIOp> {
114 public:
116  using OpAdaptor = typename arith::CmpIOp::Adaptor;
117 
118  static comb::ICmpPredicate
119  arithToCombPredicate(arith::CmpIPredicate predicate) {
120  switch (predicate) {
121  case arith::CmpIPredicate::eq:
122  return comb::ICmpPredicate::eq;
123  case arith::CmpIPredicate::ne:
124  return comb::ICmpPredicate::ne;
125  case arith::CmpIPredicate::slt:
126  return comb::ICmpPredicate::slt;
127  case arith::CmpIPredicate::ult:
128  return comb::ICmpPredicate::ult;
129  case arith::CmpIPredicate::sle:
130  return comb::ICmpPredicate::sle;
131  case arith::CmpIPredicate::ule:
132  return comb::ICmpPredicate::ule;
133  case arith::CmpIPredicate::sgt:
134  return comb::ICmpPredicate::sgt;
135  case arith::CmpIPredicate::ugt:
136  return comb::ICmpPredicate::ugt;
137  case arith::CmpIPredicate::sge:
138  return comb::ICmpPredicate::sge;
139  case arith::CmpIPredicate::uge:
140  return comb::ICmpPredicate::uge;
141  }
142  llvm_unreachable("Unknown predicate");
143  }
144 
145  LogicalResult
146  matchAndRewrite(arith::CmpIOp op, OpAdaptor adaptor,
147  ConversionPatternRewriter &rewriter) const override {
148  rewriter.replaceOpWithNewOp<comb::ICmpOp>(
149  op, arithToCombPredicate(op.getPredicate()), adaptor.getLhs(),
150  adaptor.getRhs());
151  return success();
152  }
153 };
154 
155 struct MapArithToCombPass
156  : public circt::impl::MapArithToCombPassBase<MapArithToCombPass> {
157 public:
158  void runOnOperation() override {
159  auto *ctx = &getContext();
160 
161  ConversionTarget target(*ctx);
162  target.addLegalDialect<comb::CombDialect, hw::HWDialect>();
163  target.addIllegalDialect<arith::ArithDialect>();
164  MapArithTypeConverter typeConverter;
165  RewritePatternSet patterns(ctx);
166 
167  patterns.insert<OneToOnePattern<arith::AddIOp, comb::AddOp>,
168  OneToOnePattern<arith::SubIOp, comb::SubOp>,
169  OneToOnePattern<arith::MulIOp, comb::MulOp>,
170  OneToOnePattern<arith::DivSIOp, comb::DivSOp>,
171  OneToOnePattern<arith::DivUIOp, comb::DivUOp>,
172  OneToOnePattern<arith::RemSIOp, comb::ModSOp>,
173  OneToOnePattern<arith::RemUIOp, comb::ModUOp>,
174  OneToOnePattern<arith::AndIOp, comb::AndOp>,
175  OneToOnePattern<arith::OrIOp, comb::OrOp>,
176  OneToOnePattern<arith::XOrIOp, comb::XorOp>,
177  OneToOnePattern<arith::ShLIOp, comb::ShlOp>,
178  OneToOnePattern<arith::ShRSIOp, comb::ShrSOp>,
179  OneToOnePattern<arith::ShRUIOp, comb::ShrUOp>,
180  OneToOnePattern<arith::ConstantOp, hw::ConstantOp, true>,
181  OneToOnePattern<arith::SelectOp, comb::MuxOp>,
182  ExtSConversionPattern, ExtZConversionPattern,
183  TruncateConversionPattern, CompConversionPattern>(
184  typeConverter, ctx);
185 
186  if (failed(applyPartialConversion(getOperation(), target,
187  std::move(patterns))))
188  signalPassFailure();
189  }
190 };
191 
192 } // namespace
193 
194 std::unique_ptr<mlir::Pass> circt::createMapArithToCombPass() {
195  return std::make_unique<MapArithToCombPass>();
196 }
// Doxygen cross-references for symbols used in this file:
//
//   Value createOrFoldSExt(Location loc, Value value, Type destTy,
//                          OpBuilder &builder)
//     Create a sign extension operation from a value of integer type to an
//     equal or larger integer type.
//     Definition: CombOps.cpp:25
//
//   The InstanceGraph op interface, see InstanceGraphInterface.td for more
//   details.
//     Definition: DebugAnalysis.h:21
//
//   std::unique_ptr<mlir::Pass> createMapArithToCombPass()