CIRCT 20.0.0git
CombToArith.cpp
Go to the documentation of this file.
//===- CombToArith.cpp ----------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include <type_traits>
15
16namespace circt {
17#define GEN_PASS_DEF_CONVERTCOMBTOARITH
18#include "circt/Conversion/Passes.h.inc"
19} // namespace circt
20
21using namespace circt;
22using namespace hw;
23using namespace comb;
24using namespace mlir;
25using namespace arith;
26
//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//
30
31namespace {
32/// Lower a comb::ReplicateOp operation to a comb::ConcatOp
33struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
35
36 LogicalResult
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter) const override {
39
40 Type inputType = op.getInput().getType();
41 if (isa<IntegerType>(inputType) && inputType.getIntOrFloatBitWidth() == 1) {
42 Type outType = rewriter.getIntegerType(op.getMultiple());
43 rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
44 return success();
45 }
46
47 SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
48 rewriter.replaceOpWithNewOp<ConcatOp>(op, inputs);
49 return success();
50 }
51};
52
53/// Lower a hw::ConstantOp operation to a arith::ConstantOp
54struct HWConstantOpConversion : OpConversionPattern<hw::ConstantOp> {
56
57 LogicalResult
58 matchAndRewrite(hw::ConstantOp op, OpAdaptor adaptor,
59 ConversionPatternRewriter &rewriter) const override {
60
61 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
62 return success();
63 }
64};
65
66/// Lower a comb::ICmpOp operation to a arith::CmpIOp
67struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
69
70 LogicalResult
71 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73
74 CmpIPredicate pred;
75 switch (adaptor.getPredicate()) {
76 case ICmpPredicate::cne:
77 case ICmpPredicate::wne:
78 case ICmpPredicate::ne:
79 pred = CmpIPredicate::ne;
80 break;
81 case ICmpPredicate::ceq:
82 case ICmpPredicate::weq:
83 case ICmpPredicate::eq:
84 pred = CmpIPredicate::eq;
85 break;
86 case ICmpPredicate::sge:
87 pred = CmpIPredicate::sge;
88 break;
89 case ICmpPredicate::sgt:
90 pred = CmpIPredicate::sgt;
91 break;
92 case ICmpPredicate::sle:
93 pred = CmpIPredicate::sle;
94 break;
95 case ICmpPredicate::slt:
96 pred = CmpIPredicate::slt;
97 break;
98 case ICmpPredicate::uge:
99 pred = CmpIPredicate::uge;
100 break;
101 case ICmpPredicate::ugt:
102 pred = CmpIPredicate::ugt;
103 break;
104 case ICmpPredicate::ule:
105 pred = CmpIPredicate::ule;
106 break;
107 case ICmpPredicate::ult:
108 pred = CmpIPredicate::ult;
109 break;
110 }
111
112 rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
113 adaptor.getRhs());
114 return success();
115 }
116};
117
118/// Lower a comb::ExtractOp operation to the arith dialect
119struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
121
122 LogicalResult
123 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter) const override {
125
126 Value lowBit = rewriter.create<arith::ConstantOp>(
127 op.getLoc(),
128 IntegerAttr::get(adaptor.getInput().getType(), adaptor.getLowBit()));
129 Value shifted =
130 rewriter.create<ShRUIOp>(op.getLoc(), adaptor.getInput(), lowBit);
131 rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
132 shifted);
133 return success();
134 }
135};
136
137/// Lower a comb::ConcatOp operation to the arith dialect
138struct ConcatOpConversion : OpConversionPattern<ConcatOp> {
140
141 LogicalResult
142 matchAndRewrite(ConcatOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 Type type = op.getResult().getType();
145 Location loc = op.getLoc();
146
147 // Handle the trivial case where we have only one operand. The concat is a
148 // no-op in this case.
149 if (op.getNumOperands() == 1) {
150 rewriter.replaceOp(op, adaptor.getOperands().back());
151 return success();
152 }
153
154 // The operand at the least significant bit position (the one all the way on
155 // the right at the highest index) does not need to be shifted and can just
156 // be zero-extended to the final bit width.
157 Value aggregate =
158 rewriter.createOrFold<ExtUIOp>(loc, type, adaptor.getOperands().back());
159
160 // Shift and OR all the other operands onto the aggregate. Skip the last
161 // operand because it has already been incorporated into the aggregate.
162 unsigned offset = type.getIntOrFloatBitWidth();
163 for (auto operand : adaptor.getOperands().drop_back()) {
164 offset -= operand.getType().getIntOrFloatBitWidth();
165 auto offsetConst = rewriter.create<arith::ConstantOp>(
166 loc, IntegerAttr::get(type, offset));
167 auto extended = rewriter.createOrFold<ExtUIOp>(loc, type, operand);
168 auto shifted = rewriter.createOrFold<ShLIOp>(loc, extended, offsetConst);
169 aggregate = rewriter.createOrFold<OrIOp>(loc, aggregate, shifted);
170 }
171
172 rewriter.replaceOp(op, aggregate);
173 return success();
174 }
175};
176
177/// Lower the two-operand SourceOp to the two-operand TargetOp
178template <typename SourceOp, typename TargetOp>
179struct BinaryOpConversion : OpConversionPattern<SourceOp> {
181 using OpAdaptor = typename SourceOp::Adaptor;
182
183 LogicalResult
184 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const override {
186
187 rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
188 adaptor.getOperands());
189 return success();
190 }
191};
192
193/// Lowering for division operations that need to special-case zero-value
194/// divisors to not run coarser UB than CIRCT defines.
195template <typename SourceOp, typename TargetOp>
196struct DivOpConversion : OpConversionPattern<SourceOp> {
198 using OpAdaptor = typename SourceOp::Adaptor;
199
200 LogicalResult
201 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter) const override {
203 Location loc = op.getLoc();
204 Value zero = rewriter.create<arith::ConstantOp>(
205 loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 0));
206 Value one = rewriter.create<arith::ConstantOp>(
207 loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 1));
208 Value isZero = rewriter.create<arith::CmpIOp>(loc, CmpIPredicate::eq,
209 adaptor.getRhs(), zero);
210 Value divisor =
211 rewriter.create<arith::SelectOp>(loc, isZero, one, adaptor.getRhs());
212 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(), divisor);
213 return success();
214 }
215};
216
217/// Lower a comb::ReplicateOp operation to the LLVM dialect.
218template <typename SourceOp, typename TargetOp>
219struct VariadicOpConversion : OpConversionPattern<SourceOp> {
221 using OpAdaptor = typename SourceOp::Adaptor;
222
223 LogicalResult
224 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
225 ConversionPatternRewriter &rewriter) const override {
226
227 // TODO: building a tree would be better here
228 ValueRange operands = adaptor.getOperands();
229 Value runner = operands[0];
230 for (Value operand :
231 llvm::make_range(operands.begin() + 1, operands.end())) {
232 runner = rewriter.create<TargetOp>(op.getLoc(), runner, operand);
233 }
234 rewriter.replaceOp(op, runner);
235 return success();
236 }
237};
238
// Shifts by an amount greater than or equal to the width of the lhs are
// currently unspecified in arith and produce poison in LLVM IR. To prevent
// undefined behavior, the shift lowerings below handle this case explicitly.

243/// Lower the logical shift SourceOp to the logical shift TargetOp
244/// Ensure to produce zero for shift amounts greater than or equal to the width
245/// of the lhs
246template <typename SourceOp, typename TargetOp>
247struct LogicalShiftConversion : OpConversionPattern<SourceOp> {
249 using OpAdaptor = typename SourceOp::Adaptor;
250
251 LogicalResult
252 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
253 ConversionPatternRewriter &rewriter) const override {
254 unsigned shifteeWidth =
255 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
256 .getIntOrFloatBitWidth();
257 auto zeroConstOp = rewriter.create<arith::ConstantOp>(
258 op.getLoc(), IntegerAttr::get(adaptor.getLhs().getType(), 0));
259 auto maxShamtConstOp = rewriter.create<arith::ConstantOp>(
260 op.getLoc(),
261 IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth));
262 auto shiftOp = rewriter.createOrFold<TargetOp>(
263 op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
264 auto isAllZeroOp = rewriter.createOrFold<CmpIOp>(
265 op.getLoc(), CmpIPredicate::uge, adaptor.getRhs(),
266 maxShamtConstOp.getResult());
267 rewriter.replaceOpWithNewOp<SelectOp>(op, isAllZeroOp, zeroConstOp,
268 shiftOp);
269 return success();
270 }
271};
272
273/// Lower a comb::ShrSOp operation to a (saturating) arith::ShRSIOp
274struct ShrSOpConversion : OpConversionPattern<ShrSOp> {
276
277 LogicalResult
278 matchAndRewrite(ShrSOp op, OpAdaptor adaptor,
279 ConversionPatternRewriter &rewriter) const override {
280 unsigned shifteeWidth =
281 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
282 .getIntOrFloatBitWidth();
283 // Clamp the shift amount to shifteeWidth - 1
284 auto maxShamtMinusOneConstOp = rewriter.create<arith::ConstantOp>(
285 op.getLoc(),
286 IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth - 1));
287 auto shamtOp = rewriter.createOrFold<MinUIOp>(op.getLoc(), adaptor.getRhs(),
288 maxShamtMinusOneConstOp);
289 rewriter.replaceOpWithNewOp<ShRSIOp>(op, adaptor.getLhs(), shamtOp);
290 return success();
291 }
292};
293
294} // namespace
295
//===----------------------------------------------------------------------===//
// Convert Comb to Arith pass
//===----------------------------------------------------------------------===//
299
300namespace {
301struct ConvertCombToArithPass
302 : public circt::impl::ConvertCombToArithBase<ConvertCombToArithPass> {
303 void runOnOperation() override;
304};
305} // namespace
306
308 TypeConverter &converter, mlir::RewritePatternSet &patterns) {
309 patterns.add<
310 CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
311 ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
312 LogicalShiftConversion<ShlOp, ShLIOp>,
313 LogicalShiftConversion<ShrUOp, ShRUIOp>,
314 BinaryOpConversion<SubOp, SubIOp>, DivOpConversion<DivSOp, DivSIOp>,
315 DivOpConversion<DivUOp, DivUIOp>, DivOpConversion<ModSOp, RemSIOp>,
316 DivOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
317 VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
318 VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
319 VariadicOpConversion<XorOp, XOrIOp>>(converter, patterns.getContext());
320}
321
322void ConvertCombToArithPass::runOnOperation() {
323 ConversionTarget target(getContext());
324 target.addIllegalDialect<comb::CombDialect>();
325 target.addIllegalOp<hw::ConstantOp>();
326 target.addLegalDialect<ArithDialect>();
327 // Arith does not have an operation equivalent to comb.parity. A lowering
328 // would result in undesirably complex logic, therefore, we mark it legal
329 // here.
330 target.addLegalOp<comb::ParityOp>();
331
332 RewritePatternSet patterns(&getContext());
333 TypeConverter converter;
334 converter.addConversion([](Type type) { return type; });
335 // TODO: a pattern for comb.parity
337
338 if (failed(mlir::applyPartialConversion(getOperation(), target,
339 std::move(patterns))))
340 signalPassFailure();
341}
342
343std::unique_ptr<Pass> circt::createConvertCombToArithPass() {
344 return std::make_unique<ConvertCombToArithPass>();
345}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createConvertCombToArithPass()
Definition comb.py:1
Definition hw.py:1