CIRCT  19.0.0git
CombToArith.cpp
Go to the documentation of this file.
1 //===- CombToArith.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "circt/Conversion/CombToArith.h"
#include "../PassDetail.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace circt;
using namespace hw;
using namespace comb;
using namespace mlir;
using namespace arith;
21 
22 //===----------------------------------------------------------------------===//
23 // Conversion patterns
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 /// Lower a comb::ReplicateOp operation to a comb::ConcatOp
28 struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
30 
31  LogicalResult
32  matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
33  ConversionPatternRewriter &rewriter) const override {
34 
35  Type inputType = op.getInput().getType();
36  if (inputType.isa<IntegerType>() &&
37  inputType.getIntOrFloatBitWidth() == 1) {
38  Type outType = rewriter.getIntegerType(op.getMultiple());
39  rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
40  return success();
41  }
42 
43  SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
44  rewriter.replaceOpWithNewOp<ConcatOp>(op, inputs);
45  return success();
46  }
47 };
48 
49 /// Lower a hw::ConstantOp operation to a arith::ConstantOp
50 struct HWConstantOpConversion : OpConversionPattern<hw::ConstantOp> {
52 
53  LogicalResult
54  matchAndRewrite(hw::ConstantOp op, OpAdaptor adaptor,
55  ConversionPatternRewriter &rewriter) const override {
56 
57  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
58  return success();
59  }
60 };
61 
62 /// Lower a comb::ICmpOp operation to a arith::CmpIOp
63 struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
65 
66  LogicalResult
67  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override {
69 
70  CmpIPredicate pred;
71  switch (adaptor.getPredicate()) {
72  case ICmpPredicate::cne:
73  case ICmpPredicate::wne:
74  case ICmpPredicate::ne:
75  pred = CmpIPredicate::ne;
76  break;
77  case ICmpPredicate::ceq:
78  case ICmpPredicate::weq:
79  case ICmpPredicate::eq:
80  pred = CmpIPredicate::eq;
81  break;
82  case ICmpPredicate::sge:
83  pred = CmpIPredicate::sge;
84  break;
85  case ICmpPredicate::sgt:
86  pred = CmpIPredicate::sgt;
87  break;
88  case ICmpPredicate::sle:
89  pred = CmpIPredicate::sle;
90  break;
91  case ICmpPredicate::slt:
92  pred = CmpIPredicate::slt;
93  break;
94  case ICmpPredicate::uge:
95  pred = CmpIPredicate::uge;
96  break;
97  case ICmpPredicate::ugt:
98  pred = CmpIPredicate::ugt;
99  break;
100  case ICmpPredicate::ule:
101  pred = CmpIPredicate::ule;
102  break;
103  case ICmpPredicate::ult:
104  pred = CmpIPredicate::ult;
105  break;
106  }
107 
108  rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
109  adaptor.getRhs());
110  return success();
111  }
112 };
113 
114 /// Lower a comb::ExtractOp operation to the arith dialect
115 struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
117 
118  LogicalResult
119  matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override {
121 
122  Value lowBit = rewriter.create<arith::ConstantOp>(
123  op.getLoc(),
124  IntegerAttr::get(adaptor.getInput().getType(), adaptor.getLowBit()));
125  Value shifted =
126  rewriter.create<ShRUIOp>(op.getLoc(), adaptor.getInput(), lowBit);
127  rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
128  shifted);
129  return success();
130  }
131 };
132 
133 /// Lower a comb::ConcatOp operation to the arith dialect
134 struct ConcatOpConversion : OpConversionPattern<ConcatOp> {
136 
137  LogicalResult
138  matchAndRewrite(ConcatOp op, OpAdaptor adaptor,
139  ConversionPatternRewriter &rewriter) const override {
140  Type type = op.getResult().getType();
141  Location loc = op.getLoc();
142 
143  // Handle the trivial case where we have only one operand. The concat is a
144  // no-op in this case.
145  if (op.getNumOperands() == 1) {
146  rewriter.replaceOp(op, adaptor.getOperands().back());
147  return success();
148  }
149 
150  // The operand at the least significant bit position (the one all the way on
151  // the right at the highest index) does not need to be shifted and can just
152  // be zero-extended to the final bit width.
153  Value aggregate =
154  rewriter.createOrFold<ExtUIOp>(loc, type, adaptor.getOperands().back());
155 
156  // Shift and OR all the other operands onto the aggregate. Skip the last
157  // operand because it has already been incorporated into the aggregate.
158  unsigned offset = type.getIntOrFloatBitWidth();
159  for (auto operand : adaptor.getOperands().drop_back()) {
160  offset -= operand.getType().getIntOrFloatBitWidth();
161  auto offsetConst = rewriter.create<arith::ConstantOp>(
162  loc, IntegerAttr::get(type, offset));
163  auto extended = rewriter.createOrFold<ExtUIOp>(loc, type, operand);
164  auto shifted = rewriter.createOrFold<ShLIOp>(loc, extended, offsetConst);
165  aggregate = rewriter.createOrFold<OrIOp>(loc, aggregate, shifted);
166  }
167 
168  rewriter.replaceOp(op, aggregate);
169  return success();
170  }
171 };
172 
173 /// Lower the two-operand SourceOp to the two-operand TargetOp
174 template <typename SourceOp, typename TargetOp>
175 struct BinaryOpConversion : OpConversionPattern<SourceOp> {
177  using OpAdaptor = typename SourceOp::Adaptor;
178 
179  LogicalResult
180  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
181  ConversionPatternRewriter &rewriter) const override {
182 
183  rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
184  adaptor.getOperands());
185  return success();
186  }
187 };
188 
189 /// Lower a comb::ReplicateOp operation to the LLVM dialect.
190 template <typename SourceOp, typename TargetOp>
191 struct VariadicOpConversion : OpConversionPattern<SourceOp> {
193  using OpAdaptor = typename SourceOp::Adaptor;
194 
195  LogicalResult
196  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
197  ConversionPatternRewriter &rewriter) const override {
198 
199  // TODO: building a tree would be better here
200  ValueRange operands = adaptor.getOperands();
201  Value runner = operands[0];
202  for (Value operand :
203  llvm::make_range(operands.begin() + 1, operands.end())) {
204  runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
205  }
206  rewriter.replaceOp(op, runner);
207  return success();
208  }
209 };
210 
211 // Shifts greater than or equal to the width of the lhs are currently
212 // unspecified in arith and produce poison in LLVM IR. To prevent undefined
213 // behaviour we handle this case explicitly.
214 
215 /// Lower the logical shift SourceOp to the logical shift TargetOp
216 /// Ensure to produce zero for shift amounts greater than or equal to the width
217 /// of the lhs
218 template <typename SourceOp, typename TargetOp>
219 struct LogicalShiftConversion : OpConversionPattern<SourceOp> {
221  using OpAdaptor = typename SourceOp::Adaptor;
222 
223  LogicalResult
224  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
225  ConversionPatternRewriter &rewriter) const override {
226  unsigned shifteeWidth =
227  hw::type_cast<IntegerType>(adaptor.getLhs().getType())
228  .getIntOrFloatBitWidth();
229  auto zeroConstOp = rewriter.create<arith::ConstantOp>(
230  op.getLoc(), IntegerAttr::get(adaptor.getLhs().getType(), 0));
231  auto maxShamtConstOp = rewriter.create<arith::ConstantOp>(
232  op.getLoc(),
233  IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth));
234  auto shiftOp = rewriter.createOrFold<TargetOp>(
235  op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
236  auto isAllZeroOp = rewriter.createOrFold<CmpIOp>(
237  op.getLoc(), CmpIPredicate::uge, adaptor.getRhs(),
238  maxShamtConstOp.getResult());
239  rewriter.replaceOpWithNewOp<SelectOp>(op, isAllZeroOp, zeroConstOp,
240  shiftOp);
241  return success();
242  }
243 };
244 
245 /// Lower a comb::ShrSOp operation to a (saturating) arith::ShRSIOp
246 struct ShrSOpConversion : OpConversionPattern<ShrSOp> {
248 
249  LogicalResult
250  matchAndRewrite(ShrSOp op, OpAdaptor adaptor,
251  ConversionPatternRewriter &rewriter) const override {
252  unsigned shifteeWidth =
253  hw::type_cast<IntegerType>(adaptor.getLhs().getType())
254  .getIntOrFloatBitWidth();
255  // Clamp the shift amount to shifteeWidth - 1
256  auto maxShamtMinusOneConstOp = rewriter.create<arith::ConstantOp>(
257  op.getLoc(),
258  IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth - 1));
259  auto shamtOp = rewriter.createOrFold<MinUIOp>(op.getLoc(), adaptor.getRhs(),
260  maxShamtMinusOneConstOp);
261  rewriter.replaceOpWithNewOp<ShRSIOp>(op, adaptor.getLhs(), shamtOp);
262  return success();
263  }
264 };
265 
266 } // namespace
267 
268 //===----------------------------------------------------------------------===//
269 // Convert Comb to Arith pass
270 //===----------------------------------------------------------------------===//
271 
272 namespace {
273 struct ConvertCombToArithPass
274  : public ConvertCombToArithBase<ConvertCombToArithPass> {
275  void runOnOperation() override;
276 };
277 } // namespace
278 
280  TypeConverter &converter, mlir::RewritePatternSet &patterns) {
281  patterns.add<
282  CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
283  ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
284  LogicalShiftConversion<ShlOp, ShLIOp>,
285  LogicalShiftConversion<ShrUOp, ShRUIOp>,
286  BinaryOpConversion<SubOp, SubIOp>, BinaryOpConversion<DivSOp, DivSIOp>,
287  BinaryOpConversion<DivUOp, DivUIOp>, BinaryOpConversion<ModSOp, RemSIOp>,
288  BinaryOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
289  VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
290  VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
291  VariadicOpConversion<XorOp, XOrIOp>>(converter, patterns.getContext());
292 }
293 
294 void ConvertCombToArithPass::runOnOperation() {
295  ConversionTarget target(getContext());
296  target.addIllegalDialect<comb::CombDialect>();
297  target.addIllegalOp<hw::ConstantOp>();
298  target.addLegalDialect<ArithDialect>();
299  // Arith does not have an operation equivalent to comb.parity. A lowering
300  // would result in undesirably complex logic, therefore, we mark it legal
301  // here.
302  target.addLegalOp<comb::ParityOp>();
303 
304  RewritePatternSet patterns(&getContext());
305  TypeConverter converter;
306  converter.addConversion([](Type type) { return type; });
307  // TODO: a pattern for comb.parity
309 
310  if (failed(mlir::applyPartialConversion(getOperation(), target,
311  std::move(patterns))))
312  signalPassFailure();
313 }
314 
315 std::unique_ptr<OperationPass<ModuleOp>> circt::createConvertCombToArithPass() {
316  return std::make_unique<ConvertCombToArithPass>();
317 }
llvm::SmallVector< StringAttr > inputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass< ModuleOp > > createConvertCombToArithPass()
Definition: comb.py:1
Definition: hw.py:1