CIRCT  18.0.0git
CombToArith.cpp
Go to the documentation of this file.
1 //===- CombToArith.cpp ----------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "../PassDetail.h"
12 #include "circt/Dialect/HW/HWOps.h"
13 #include "mlir/Dialect/Arith/IR/Arith.h"
14 #include "mlir/Transforms/DialectConversion.h"
15 
16 using namespace circt;
17 using namespace hw;
18 using namespace comb;
19 using namespace mlir;
20 using namespace arith;
21 
22 //===----------------------------------------------------------------------===//
23 // Conversion patterns
24 //===----------------------------------------------------------------------===//
25 
26 namespace {
27 /// Lower a comb::ReplicateOp operation to a comb::ConcatOp
28 struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
30 
31  LogicalResult
32  matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
33  ConversionPatternRewriter &rewriter) const override {
34 
35  Type inputType = op.getInput().getType();
36  if (inputType.isa<IntegerType>() &&
37  inputType.getIntOrFloatBitWidth() == 1) {
38  Type outType = rewriter.getIntegerType(op.getMultiple());
39  rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
40  return success();
41  }
42 
43  SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
44  rewriter.replaceOpWithNewOp<ConcatOp>(op, inputs);
45  return success();
46  }
47 };
48 
49 /// Lower a hw::ConstantOp operation to a arith::ConstantOp
50 struct HWConstantOpConversion : OpConversionPattern<hw::ConstantOp> {
52 
53  LogicalResult
54  matchAndRewrite(hw::ConstantOp op, OpAdaptor adaptor,
55  ConversionPatternRewriter &rewriter) const override {
56 
57  rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
58  return success();
59  }
60 };
61 
62 /// Lower a comb::ICmpOp operation to a arith::CmpIOp
63 struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
65 
66  LogicalResult
67  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
68  ConversionPatternRewriter &rewriter) const override {
69 
70  CmpIPredicate pred;
71  switch (adaptor.getPredicate()) {
72  case ICmpPredicate::cne:
73  case ICmpPredicate::wne:
74  case ICmpPredicate::ne:
75  pred = CmpIPredicate::ne;
76  break;
77  case ICmpPredicate::ceq:
78  case ICmpPredicate::weq:
79  case ICmpPredicate::eq:
80  pred = CmpIPredicate::eq;
81  break;
82  case ICmpPredicate::sge:
83  pred = CmpIPredicate::sge;
84  break;
85  case ICmpPredicate::sgt:
86  pred = CmpIPredicate::sgt;
87  break;
88  case ICmpPredicate::sle:
89  pred = CmpIPredicate::sle;
90  break;
91  case ICmpPredicate::slt:
92  pred = CmpIPredicate::slt;
93  break;
94  case ICmpPredicate::uge:
95  pred = CmpIPredicate::uge;
96  break;
97  case ICmpPredicate::ugt:
98  pred = CmpIPredicate::ugt;
99  break;
100  case ICmpPredicate::ule:
101  pred = CmpIPredicate::ule;
102  break;
103  case ICmpPredicate::ult:
104  pred = CmpIPredicate::ult;
105  break;
106  }
107 
108  rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
109  adaptor.getRhs());
110  return success();
111  }
112 };
113 
114 /// Lower a comb::ExtractOp operation to the arith dialect
115 struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
117 
118  LogicalResult
119  matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
120  ConversionPatternRewriter &rewriter) const override {
121 
122  Value lowBit = rewriter.create<arith::ConstantOp>(
123  op.getLoc(),
124  IntegerAttr::get(adaptor.getInput().getType(), adaptor.getLowBit()));
125  Value shifted =
126  rewriter.create<ShRUIOp>(op.getLoc(), adaptor.getInput(), lowBit);
127  rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
128  shifted);
129  return success();
130  }
131 };
132 
133 /// Lower a comb::ConcatOp operation to the arith dialect
134 struct ConcatOpConversion : OpConversionPattern<ConcatOp> {
136 
137  LogicalResult
138  matchAndRewrite(ConcatOp op, OpAdaptor adaptor,
139  ConversionPatternRewriter &rewriter) const override {
140  Type type = op.getResult().getType();
141  Location loc = op.getLoc();
142  unsigned nextInsertion = type.getIntOrFloatBitWidth();
143 
144  Value aggregate =
145  rewriter.create<arith::ConstantOp>(loc, IntegerAttr::get(type, 0));
146 
147  for (unsigned i = 0, e = op.getNumOperands(); i < e; i++) {
148  nextInsertion -=
149  adaptor.getOperands()[i].getType().getIntOrFloatBitWidth();
150 
151  Value nextInsValue = rewriter.create<arith::ConstantOp>(
152  loc, IntegerAttr::get(type, nextInsertion));
153  Value extended =
154  rewriter.create<ExtUIOp>(loc, type, adaptor.getOperands()[i]);
155  Value shifted = rewriter.create<ShLIOp>(loc, extended, nextInsValue);
156  aggregate = rewriter.create<OrIOp>(loc, aggregate, shifted);
157  }
158 
159  rewriter.replaceOp(op, aggregate);
160  return success();
161  }
162 };
163 
164 /// Lower the two-operand SourceOp to the two-operand TargetOp
165 template <typename SourceOp, typename TargetOp>
166 struct BinaryOpConversion : OpConversionPattern<SourceOp> {
168  using OpAdaptor = typename SourceOp::Adaptor;
169 
170  LogicalResult
171  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
172  ConversionPatternRewriter &rewriter) const override {
173 
174  rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
175  adaptor.getOperands());
176  return success();
177  }
178 };
179 
180 /// Lower a comb::ReplicateOp operation to the LLVM dialect.
181 template <typename SourceOp, typename TargetOp>
182 struct VariadicOpConversion : OpConversionPattern<SourceOp> {
184  using OpAdaptor = typename SourceOp::Adaptor;
185 
186  LogicalResult
187  matchAndRewrite(SourceOp op, OpAdaptor adaptor,
188  ConversionPatternRewriter &rewriter) const override {
189 
190  // TODO: building a tree would be better here
191  ValueRange operands = adaptor.getOperands();
192  Value runner = operands[0];
193  for (Value operand :
194  llvm::make_range(operands.begin() + 1, operands.end())) {
195  runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
196  }
197  rewriter.replaceOp(op, runner);
198  return success();
199  }
200 };
201 } // namespace
202 
203 //===----------------------------------------------------------------------===//
204 // Convert Comb to Arith pass
205 //===----------------------------------------------------------------------===//
206 
207 namespace {
208 struct ConvertCombToArithPass
209  : public ConvertCombToArithBase<ConvertCombToArithPass> {
210  void runOnOperation() override;
211 };
212 } // namespace
213 
215  TypeConverter &converter, mlir::RewritePatternSet &patterns) {
216  patterns.add<
217  CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
218  ExtractOpConversion, ConcatOpConversion,
219  BinaryOpConversion<ShlOp, ShLIOp>, BinaryOpConversion<ShrSOp, ShRSIOp>,
220  BinaryOpConversion<ShrUOp, ShRUIOp>, BinaryOpConversion<SubOp, SubIOp>,
221  BinaryOpConversion<DivSOp, DivSIOp>, BinaryOpConversion<DivUOp, DivUIOp>,
222  BinaryOpConversion<ModSOp, RemSIOp>, BinaryOpConversion<ModUOp, RemUIOp>,
223  BinaryOpConversion<MuxOp, SelectOp>, VariadicOpConversion<AddOp, AddIOp>,
224  VariadicOpConversion<MulOp, MulIOp>, VariadicOpConversion<AndOp, AndIOp>,
225  VariadicOpConversion<OrOp, OrIOp>, VariadicOpConversion<XorOp, XOrIOp>>(
226  converter, patterns.getContext());
227 }
228 
229 void ConvertCombToArithPass::runOnOperation() {
230  ConversionTarget target(getContext());
231  target.addIllegalDialect<comb::CombDialect>();
232  target.addIllegalOp<hw::ConstantOp>();
233  target.addLegalDialect<ArithDialect>();
234  // Arith does not have an operation equivalent to comb.parity. A lowering
235  // would result in undesirably complex logic, therefore, we mark it legal
236  // here.
237  target.addLegalOp<comb::ParityOp>();
238 
239  RewritePatternSet patterns(&getContext());
240  TypeConverter converter;
241  converter.addConversion([](Type type) { return type; });
242  // TODO: a pattern for comb.parity
244 
245  if (failed(mlir::applyPartialConversion(getOperation(), target,
246  std::move(patterns))))
247  signalPassFailure();
248 }
249 
250 std::unique_ptr<OperationPass<ModuleOp>> circt::createConvertCombToArithPass() {
251  return std::make_unique<ConvertCombToArithPass>();
252 }
llvm::SmallVector< StringAttr > inputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< OperationPass< ModuleOp > > createConvertCombToArithPass()
Definition: comb.py:1
Definition: hw.py:1