CIRCT  20.0.0git
CombToArith.cpp
Go to the documentation of this file.
1 //===- CombToArith.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
12 #include "mlir/Dialect/Arith/IR/Arith.h"
13 #include "mlir/Pass/Pass.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 namespace circt {
17 #define GEN_PASS_DEF_CONVERTCOMBTOARITH
18 #include "circt/Conversion/Passes.h.inc"
19 } // namespace circt
20 
21 using namespace circt;
22 using namespace hw;
23 using namespace comb;
24 using namespace mlir;
25 using namespace arith;
26 
27 //===----------------------------------------------------------------------===//
28 // Conversion patterns
29 //===----------------------------------------------------------------------===//
30 
31 namespace {
32 /// Lower a comb::ReplicateOp operation to a comb::ConcatOp
33 struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
35 
36  LogicalResult
37  matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38  ConversionPatternRewriter &rewriter) const override {
39 
40  Type inputType = op.getInput().getType();
41  if (isa<IntegerType>(inputType) && inputType.getIntOrFloatBitWidth() == 1) {
42  Type outType = rewriter.getIntegerType(op.getMultiple());
43  rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
44  return success();
45  }
46 
47  SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
48  rewriter.replaceOpWithNewOp<ConcatOp>(op, inputs);
49  return success();
50  }
51 };
52 
53 /// Lower a hw::ConstantOp operation to a arith::ConstantOp
54 struct HWConstantOpConversion : OpConversionPattern<hw::ConstantOp> {
56 
57  LogicalResult
58  matchAndRewrite(hw::ConstantOp op, OpAdaptor adaptor,
59  ConversionPatternRewriter &rewriter) const override {
60 
61  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
62  return success();
63  }
64 };
65 
66 /// Lower a comb::ICmpOp operation to a arith::CmpIOp
67 struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
69 
70  LogicalResult
71  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
72  ConversionPatternRewriter &rewriter) const override {
73 
74  CmpIPredicate pred;
75  switch (adaptor.getPredicate()) {
76  case ICmpPredicate::cne:
77  case ICmpPredicate::wne:
78  case ICmpPredicate::ne:
79  pred = CmpIPredicate::ne;
80  break;
81  case ICmpPredicate::ceq:
82  case ICmpPredicate::weq:
83  case ICmpPredicate::eq:
84  pred = CmpIPredicate::eq;
85  break;
86  case ICmpPredicate::sge:
87  pred = CmpIPredicate::sge;
88  break;
89  case ICmpPredicate::sgt:
90  pred = CmpIPredicate::sgt;
91  break;
92  case ICmpPredicate::sle:
93  pred = CmpIPredicate::sle;
94  break;
95  case ICmpPredicate::slt:
96  pred = CmpIPredicate::slt;
97  break;
98  case ICmpPredicate::uge:
99  pred = CmpIPredicate::uge;
100  break;
101  case ICmpPredicate::ugt:
102  pred = CmpIPredicate::ugt;
103  break;
104  case ICmpPredicate::ule:
105  pred = CmpIPredicate::ule;
106  break;
107  case ICmpPredicate::ult:
108  pred = CmpIPredicate::ult;
109  break;
110  }
111 
112  rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
113  adaptor.getRhs());
114  return success();
115  }
116 };
117 
118 /// Lower a comb::ExtractOp operation to the arith dialect
119 struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
121 
122  LogicalResult
123  matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
124  ConversionPatternRewriter &rewriter) const override {
125 
126  Value lowBit = rewriter.create<arith::ConstantOp>(
127  op.getLoc(),
128  IntegerAttr::get(adaptor.getInput().getType(), adaptor.getLowBit()));
129  Value shifted =
130  rewriter.create<ShRUIOp>(op.getLoc(), adaptor.getInput(), lowBit);
131  rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
132  shifted);
133  return success();
134  }
135 };
136 
137 /// Lower a comb::ConcatOp operation to the arith dialect
138 struct ConcatOpConversion : OpConversionPattern<ConcatOp> {
140 
141  LogicalResult
142  matchAndRewrite(ConcatOp op, OpAdaptor adaptor,
143  ConversionPatternRewriter &rewriter) const override {
144  Type type = op.getResult().getType();
145  Location loc = op.getLoc();
146 
147  // Handle the trivial case where we have only one operand. The concat is a
148  // no-op in this case.
149  if (op.getNumOperands() == 1) {
150  rewriter.replaceOp(op, adaptor.getOperands().back());
151  return success();
152  }
153 
154  // The operand at the least significant bit position (the one all the way on
155  // the right at the highest index) does not need to be shifted and can just
156  // be zero-extended to the final bit width.
157  Value aggregate =
158  rewriter.createOrFold<ExtUIOp>(loc, type, adaptor.getOperands().back());
159 
160  // Shift and OR all the other operands onto the aggregate. Skip the last
161  // operand because it has already been incorporated into the aggregate.
162  unsigned offset = type.getIntOrFloatBitWidth();
163  for (auto operand : adaptor.getOperands().drop_back()) {
164  offset -= operand.getType().getIntOrFloatBitWidth();
165  auto offsetConst = rewriter.create<arith::ConstantOp>(
166  loc, IntegerAttr::get(type, offset));
167  auto extended = rewriter.createOrFold<ExtUIOp>(loc, type, operand);
168  auto shifted = rewriter.createOrFold<ShLIOp>(loc, extended, offsetConst);
169  aggregate = rewriter.createOrFold<OrIOp>(loc, aggregate, shifted);
170  }
171 
172  rewriter.replaceOp(op, aggregate);
173  return success();
174  }
175 };
176 
177 /// Lower the two-operand SourceOp to the two-operand TargetOp
178 template <typename SourceOp, typename TargetOp>
179 struct BinaryOpConversion : OpConversionPattern<SourceOp> {
181  using OpAdaptor = typename SourceOp::Adaptor;
182 
183  LogicalResult
184  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185  ConversionPatternRewriter &rewriter) const override {
186 
187  rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
188  adaptor.getOperands());
189  return success();
190  }
191 };
192 
193 /// Lowering for division operations that need to special-case zero-value
194 /// divisors to not run coarser UB than CIRCT defines.
195 template <typename SourceOp, typename TargetOp>
196 struct DivOpConversion : OpConversionPattern<SourceOp> {
198  using OpAdaptor = typename SourceOp::Adaptor;
199 
200  LogicalResult
201  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
202  ConversionPatternRewriter &rewriter) const override {
203  Location loc = op.getLoc();
204  Value zero = rewriter.create<arith::ConstantOp>(
205  loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 0));
206  Value one = rewriter.create<arith::ConstantOp>(
207  loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 1));
208  Value isZero = rewriter.create<arith::CmpIOp>(loc, CmpIPredicate::eq,
209  adaptor.getRhs(), zero);
210  Value divisor =
211  rewriter.create<arith::SelectOp>(loc, isZero, one, adaptor.getRhs());
212  rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(), divisor);
213  return success();
214  }
215 };
216 
217 /// Lower a comb::ReplicateOp operation to the LLVM dialect.
218 template <typename SourceOp, typename TargetOp>
219 struct VariadicOpConversion : OpConversionPattern<SourceOp> {
221  using OpAdaptor = typename SourceOp::Adaptor;
222 
223  LogicalResult
224  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
225  ConversionPatternRewriter &rewriter) const override {
226 
227  // TODO: building a tree would be better here
228  ValueRange operands = adaptor.getOperands();
229  Value runner = operands[0];
230  for (Value operand :
231  llvm::make_range(operands.begin() + 1, operands.end())) {
232  runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
233  }
234  rewriter.replaceOp(op, runner);
235  return success();
236  }
237 };
238 
239 // Shifts greater than or equal to the width of the lhs are currently
240 // unspecified in arith and produce poison in LLVM IR. To prevent undefined
241 // behaviour we handle this case explicitly.
242 
243 /// Lower the logical shift SourceOp to the logical shift TargetOp
244 /// Ensure to produce zero for shift amounts greater than or equal to the width
245 /// of the lhs
246 template <typename SourceOp, typename TargetOp>
247 struct LogicalShiftConversion : OpConversionPattern<SourceOp> {
249  using OpAdaptor = typename SourceOp::Adaptor;
250 
251  LogicalResult
252  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
253  ConversionPatternRewriter &rewriter) const override {
254  unsigned shifteeWidth =
255  hw::type_cast<IntegerType>(adaptor.getLhs().getType())
256  .getIntOrFloatBitWidth();
257  auto zeroConstOp = rewriter.create<arith::ConstantOp>(
258  op.getLoc(), IntegerAttr::get(adaptor.getLhs().getType(), 0));
259  auto maxShamtConstOp = rewriter.create<arith::ConstantOp>(
260  op.getLoc(),
261  IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth));
262  auto shiftOp = rewriter.createOrFold<TargetOp>(
263  op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
264  auto isAllZeroOp = rewriter.createOrFold<CmpIOp>(
265  op.getLoc(), CmpIPredicate::uge, adaptor.getRhs(),
266  maxShamtConstOp.getResult());
267  rewriter.replaceOpWithNewOp<SelectOp>(op, isAllZeroOp, zeroConstOp,
268  shiftOp);
269  return success();
270  }
271 };
272 
273 /// Lower a comb::ShrSOp operation to a (saturating) arith::ShRSIOp
274 struct ShrSOpConversion : OpConversionPattern<ShrSOp> {
276 
277  LogicalResult
278  matchAndRewrite(ShrSOp op, OpAdaptor adaptor,
279  ConversionPatternRewriter &rewriter) const override {
280  unsigned shifteeWidth =
281  hw::type_cast<IntegerType>(adaptor.getLhs().getType())
282  .getIntOrFloatBitWidth();
283  // Clamp the shift amount to shifteeWidth - 1
284  auto maxShamtMinusOneConstOp = rewriter.create<arith::ConstantOp>(
285  op.getLoc(),
286  IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth - 1));
287  auto shamtOp = rewriter.createOrFold<MinUIOp>(op.getLoc(), adaptor.getRhs(),
288  maxShamtMinusOneConstOp);
289  rewriter.replaceOpWithNewOp<ShRSIOp>(op, adaptor.getLhs(), shamtOp);
290  return success();
291  }
292 };
293 
294 } // namespace
295 
296 //===----------------------------------------------------------------------===//
297 // Convert Comb to Arith pass
298 //===----------------------------------------------------------------------===//
299 
300 namespace {
301 struct ConvertCombToArithPass
302  : public circt::impl::ConvertCombToArithBase<ConvertCombToArithPass> {
303  void runOnOperation() override;
304 };
305 } // namespace
306 
308  TypeConverter &converter, mlir::RewritePatternSet &patterns) {
309  patterns.add<
310  CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
311  ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
312  LogicalShiftConversion<ShlOp, ShLIOp>,
313  LogicalShiftConversion<ShrUOp, ShRUIOp>,
314  BinaryOpConversion<SubOp, SubIOp>, DivOpConversion<DivSOp, DivSIOp>,
315  DivOpConversion<DivUOp, DivUIOp>, DivOpConversion<ModSOp, RemSIOp>,
316  DivOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
317  VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
318  VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
319  VariadicOpConversion<XorOp, XOrIOp>>(converter, patterns.getContext());
320 }
321 
322 void ConvertCombToArithPass::runOnOperation() {
323  ConversionTarget target(getContext());
324  target.addIllegalDialect<comb::CombDialect>();
325  target.addIllegalOp<hw::ConstantOp>();
326  target.addLegalDialect<ArithDialect>();
327  // Arith does not have an operation equivalent to comb.parity. A lowering
328  // would result in undesirably complex logic, therefore, we mark it legal
329  // here.
330  target.addLegalOp<comb::ParityOp>();
331 
332  RewritePatternSet patterns(&getContext());
333  TypeConverter converter;
334  converter.addConversion([](Type type) { return type; });
335  // TODO: a pattern for comb.parity
337 
338  if (failed(mlir::applyPartialConversion(getOperation(), target,
339  std::move(patterns))))
340  signalPassFailure();
341 }
342 
343 std::unique_ptr<OperationPass<ModuleOp>> circt::createConvertCombToArithPass() {
344  return std::make_unique<ConvertCombToArithPass>();
345 }
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass< ModuleOp > > createConvertCombToArithPass()
Definition: comb.py:1
Definition: hw.py:1