CIRCT 23.0.0git
Loading...
Searching...
No Matches
CombToArith.cpp
Go to the documentation of this file.
1//===- CombToArith.cpp ----------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "circt/Conversion/CombToArith.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
15
16namespace circt {
17#define GEN_PASS_DEF_CONVERTCOMBTOARITH
18#include "circt/Conversion/Passes.h.inc"
19} // namespace circt
20
21using namespace circt;
22using namespace hw;
23using namespace comb;
24using namespace mlir;
25using namespace arith;
26
27//===----------------------------------------------------------------------===//
28// Conversion patterns
29//===----------------------------------------------------------------------===//
30
31namespace {
32/// Lower a comb::ReplicateOp operation to a comb::ConcatOp
33struct CombReplicateOpConversion : OpConversionPattern<ReplicateOp> {
35
36 LogicalResult
37 matchAndRewrite(ReplicateOp op, OpAdaptor adaptor,
38 ConversionPatternRewriter &rewriter) const override {
39
40 Type inputType = op.getInput().getType();
41 if (isa<IntegerType>(inputType) && inputType.getIntOrFloatBitWidth() == 1) {
42 Type outType = rewriter.getIntegerType(op.getMultiple());
43 rewriter.replaceOpWithNewOp<ExtSIOp>(op, outType, adaptor.getInput());
44 return success();
45 }
46
47 SmallVector<Value> inputs(op.getMultiple(), adaptor.getInput());
48 rewriter.replaceOpWithNewOp<ConcatOp>(op, inputs);
49 return success();
50 }
51};
52
53/// Lower a hw::ConstantOp operation to a arith::ConstantOp
54struct HWConstantOpConversion : OpConversionPattern<hw::ConstantOp> {
56
57 LogicalResult
58 matchAndRewrite(hw::ConstantOp op, OpAdaptor adaptor,
59 ConversionPatternRewriter &rewriter) const override {
60
61 rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, adaptor.getValueAttr());
62 return success();
63 }
64};
65
66/// Lower a comb::ICmpOp operation to a arith::CmpIOp
67struct IcmpOpConversion : OpConversionPattern<ICmpOp> {
69
70 LogicalResult
71 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
72 ConversionPatternRewriter &rewriter) const override {
73
74 CmpIPredicate pred;
75 switch (adaptor.getPredicate()) {
76 case ICmpPredicate::cne:
77 case ICmpPredicate::wne:
78 case ICmpPredicate::ne:
79 pred = CmpIPredicate::ne;
80 break;
81 case ICmpPredicate::ceq:
82 case ICmpPredicate::weq:
83 case ICmpPredicate::eq:
84 pred = CmpIPredicate::eq;
85 break;
86 case ICmpPredicate::sge:
87 pred = CmpIPredicate::sge;
88 break;
89 case ICmpPredicate::sgt:
90 pred = CmpIPredicate::sgt;
91 break;
92 case ICmpPredicate::sle:
93 pred = CmpIPredicate::sle;
94 break;
95 case ICmpPredicate::slt:
96 pred = CmpIPredicate::slt;
97 break;
98 case ICmpPredicate::uge:
99 pred = CmpIPredicate::uge;
100 break;
101 case ICmpPredicate::ugt:
102 pred = CmpIPredicate::ugt;
103 break;
104 case ICmpPredicate::ule:
105 pred = CmpIPredicate::ule;
106 break;
107 case ICmpPredicate::ult:
108 pred = CmpIPredicate::ult;
109 break;
110 }
111
112 rewriter.replaceOpWithNewOp<CmpIOp>(op, pred, adaptor.getLhs(),
113 adaptor.getRhs());
114 return success();
115 }
116};
117
118/// Lower a comb::ExtractOp operation to the arith dialect
119struct ExtractOpConversion : OpConversionPattern<ExtractOp> {
121
122 LogicalResult
123 matchAndRewrite(ExtractOp op, OpAdaptor adaptor,
124 ConversionPatternRewriter &rewriter) const override {
125
126 Value lowBit = arith::ConstantOp::create(
127 rewriter, op.getLoc(),
128 IntegerAttr::get(adaptor.getInput().getType(), adaptor.getLowBit()));
129 Value shifted =
130 ShRUIOp::create(rewriter, op.getLoc(), adaptor.getInput(), lowBit);
131 rewriter.replaceOpWithNewOp<TruncIOp>(op, op.getResult().getType(),
132 shifted);
133 return success();
134 }
135};
136
137/// Lower a comb::ConcatOp operation to the arith dialect
138struct ConcatOpConversion : OpConversionPattern<ConcatOp> {
140
141 LogicalResult
142 matchAndRewrite(ConcatOp op, OpAdaptor adaptor,
143 ConversionPatternRewriter &rewriter) const override {
144 Type type = op.getResult().getType();
145 Location loc = op.getLoc();
146
147 // Handle the trivial case where we have only one operand. The concat is a
148 // no-op in this case.
149 if (op.getNumOperands() == 1) {
150 rewriter.replaceOp(op, adaptor.getOperands().back());
151 return success();
152 }
153
154 // The operand at the least significant bit position (the one all the way on
155 // the right at the highest index) does not need to be shifted and can just
156 // be zero-extended to the final bit width.
157 Value aggregate =
158 rewriter.createOrFold<ExtUIOp>(loc, type, adaptor.getOperands().back());
159
160 // Shift and OR all the other operands onto the aggregate. Skip the last
161 // operand because it has already been incorporated into the aggregate.
162 unsigned offset = type.getIntOrFloatBitWidth();
163 for (auto operand : adaptor.getOperands().drop_back()) {
164 offset -= operand.getType().getIntOrFloatBitWidth();
165 auto offsetConst = arith::ConstantOp::create(
166 rewriter, loc, IntegerAttr::get(type, offset));
167 auto extended = rewriter.createOrFold<ExtUIOp>(loc, type, operand);
168 auto shifted = rewriter.createOrFold<ShLIOp>(loc, extended, offsetConst);
169 aggregate = rewriter.createOrFold<OrIOp>(loc, aggregate, shifted);
170 }
171
172 rewriter.replaceOp(op, aggregate);
173 return success();
174 }
175};
176
177/// Lower the two-operand SourceOp to the two-operand TargetOp
178template <typename SourceOp, typename TargetOp>
179struct BinaryOpConversion : OpConversionPattern<SourceOp> {
181 using OpAdaptor = typename SourceOp::Adaptor;
182
183 LogicalResult
184 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
185 ConversionPatternRewriter &rewriter) const override {
186
187 rewriter.replaceOpWithNewOp<TargetOp>(op, op.getResult().getType(),
188 adaptor.getOperands());
189 return success();
190 }
191};
192
193/// Lowering for unsigned division / remainder operations that need to
194/// special-case zero-value divisors to not run coarser UB than CIRCT defines.
195template <typename SourceOp, typename TargetOp>
196struct DivUOpConversion : OpConversionPattern<SourceOp> {
198 using OpAdaptor = typename SourceOp::Adaptor;
199
200 LogicalResult
201 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
202 ConversionPatternRewriter &rewriter) const override {
203 // Rewrite: divu(a, b) ~>
204 // isZero = (b == 0)
205 // divisor = isZero ? 1 : b;
206 // result = divu(a, divisor)
207 Location loc = op.getLoc();
208 Value zero = arith::ConstantOp::create(
209 rewriter, loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 0));
210 Value one = arith::ConstantOp::create(
211 rewriter, loc, rewriter.getIntegerAttr(adaptor.getRhs().getType(), 1));
212 Value isZero = arith::CmpIOp::create(rewriter, loc, CmpIPredicate::eq,
213 adaptor.getRhs(), zero);
214 Value divisor =
215 arith::SelectOp::create(rewriter, loc, isZero, one, adaptor.getRhs());
216 rewriter.replaceOpWithNewOp<TargetOp>(op, adaptor.getLhs(), divisor);
217 return success();
218 }
219};
220
221/// Lowering for signed division / remainder that need to special-case INT_MIN /
222/// -1 (and division by zero).
223template <typename SourceOp, typename TargetOp, bool IsRem>
224class DivSOpConversion : public OpConversionPattern<SourceOp> {
225public:
227 using OpAdaptor = typename SourceOp::Adaptor;
228
229 LogicalResult
230 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
231 ConversionPatternRewriter &rewriter) const override {
232 Value dividend = adaptor.getLhs();
233 Value divisor = adaptor.getRhs();
234 Type ty = op.getType();
235
236 // Rewrite: divs(a, b) ~>
237 // isZero = (b == 0)
238 // isOverflow = (b == -1 && a == INT_MIN)
239 // pred = isZero || isOverflow
240 // rhs_safe = pred ? 1 : b;
241 // c = divs(a, rhs_safe)
242 // result = isOverflow ? INT_MIN : c;
243 //
244 // mods is the same except the result is zero when isOverflow is true. These
245 // values were chosen to align with the behavior of existing Verilog
246 // simulators.
247 ImplicitLocOpBuilder b(op.getLoc(), rewriter);
248 auto eq = [&](Value lhs, Value rhs) {
249 return arith::CmpIOp::create(b, CmpIPredicate::eq, lhs, rhs);
250 };
251 auto and_ = [&](Value lhs, Value rhs) {
252 return arith::AndIOp::create(b, lhs, rhs);
253 };
254 auto or_ = [&](Value lhs, Value rhs) {
255 return arith::OrIOp::create(b, lhs, rhs);
256 };
257
258 int bitwidth = ty.getIntOrFloatBitWidth();
259 Value zero = arith::ConstantOp::create(b, rewriter.getIntegerAttr(ty, 0));
260 Value one = arith::ConstantOp::create(b, rewriter.getIntegerAttr(ty, 1));
261 Value int_min = arith::ConstantOp::create(
262 b, rewriter.getIntegerAttr(ty, APInt::getSignedMinValue(bitwidth)));
263 Value minus_one = arith::ConstantOp::create(
264 b, rewriter.getIntegerAttr(ty, APInt::getAllOnes(bitwidth)));
265
266 Value isZero = eq(divisor, zero);
267 Value isOverflow = and_(eq(dividend, int_min), eq(divisor, minus_one));
268 Value pred = or_(isZero, isOverflow);
269 Value safeDivisor = arith::SelectOp::create(b, pred, one, divisor);
270 auto newOp = TargetOp::create(b, dividend, safeDivisor);
271
272 Value resultIfOverflow = IsRem ? zero : int_min;
273 Value result =
274 arith::SelectOp::create(b, isOverflow, resultIfOverflow, newOp);
275 rewriter.replaceOp(op, result);
276 return success();
277 }
278};
279
280/// Lower a comb::ReplicateOp operation to the LLVM dialect.
281template <typename SourceOp, typename TargetOp>
282struct VariadicOpConversion : OpConversionPattern<SourceOp> {
284 using OpAdaptor = typename SourceOp::Adaptor;
285
286 LogicalResult
287 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
288 ConversionPatternRewriter &rewriter) const override {
289
290 // TODO: building a tree would be better here
291 ValueRange operands = adaptor.getOperands();
292 Value runner = operands[0];
293 for (Value operand :
294 llvm::make_range(operands.begin() + 1, operands.end())) {
295 runner = TargetOp::create(rewriter, op.getLoc(), runner, operand);
296 }
297 rewriter.replaceOp(op, runner);
298 return success();
299 }
300};
301
302// Shifts greater than or equal to the width of the lhs are currently
303// unspecified in arith and produce poison in LLVM IR. To prevent undefined
304// behaviour we handle this case explicitly.
305
306/// Lower the logical shift SourceOp to the logical shift TargetOp
307/// Ensure to produce zero for shift amounts greater than or equal to the width
308/// of the lhs
309template <typename SourceOp, typename TargetOp>
310struct LogicalShiftConversion : OpConversionPattern<SourceOp> {
312 using OpAdaptor = typename SourceOp::Adaptor;
313
314 LogicalResult
315 matchAndRewrite(SourceOp op, OpAdaptor adaptor,
316 ConversionPatternRewriter &rewriter) const override {
317 unsigned shifteeWidth =
318 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
319 .getIntOrFloatBitWidth();
320 auto zeroConstOp = arith::ConstantOp::create(
321 rewriter, op.getLoc(), IntegerAttr::get(adaptor.getLhs().getType(), 0));
322 auto maxShamtConstOp = arith::ConstantOp::create(
323 rewriter, op.getLoc(),
324 IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth));
325 auto shiftOp = rewriter.createOrFold<TargetOp>(
326 op.getLoc(), adaptor.getLhs(), adaptor.getRhs());
327 auto isAllZeroOp = rewriter.createOrFold<CmpIOp>(
328 op.getLoc(), CmpIPredicate::uge, adaptor.getRhs(),
329 maxShamtConstOp.getResult());
330 rewriter.replaceOpWithNewOp<SelectOp>(op, isAllZeroOp, zeroConstOp,
331 shiftOp);
332 return success();
333 }
334};
335
336/// Lower a comb::ShrSOp operation to a (saturating) arith::ShRSIOp
337struct ShrSOpConversion : OpConversionPattern<ShrSOp> {
339
340 LogicalResult
341 matchAndRewrite(ShrSOp op, OpAdaptor adaptor,
342 ConversionPatternRewriter &rewriter) const override {
343 unsigned shifteeWidth =
344 hw::type_cast<IntegerType>(adaptor.getLhs().getType())
345 .getIntOrFloatBitWidth();
346 // Clamp the shift amount to shifteeWidth - 1
347 auto maxShamtMinusOneConstOp = arith::ConstantOp::create(
348 rewriter, op.getLoc(),
349 IntegerAttr::get(adaptor.getLhs().getType(), shifteeWidth - 1));
350 auto shamtOp = rewriter.createOrFold<MinUIOp>(op.getLoc(), adaptor.getRhs(),
351 maxShamtMinusOneConstOp);
352 rewriter.replaceOpWithNewOp<ShRSIOp>(op, adaptor.getLhs(), shamtOp);
353 return success();
354 }
355};
356
357} // namespace
358
359//===----------------------------------------------------------------------===//
360// Convert Comb to Arith pass
361//===----------------------------------------------------------------------===//
362
363namespace {
364struct ConvertCombToArithPass
365 : public circt::impl::ConvertCombToArithBase<ConvertCombToArithPass> {
366 void runOnOperation() override;
367};
368} // namespace
369
371 TypeConverter &converter, mlir::RewritePatternSet &patterns) {
372 patterns.add<
373 CombReplicateOpConversion, HWConstantOpConversion, IcmpOpConversion,
374 ExtractOpConversion, ConcatOpConversion, ShrSOpConversion,
375 LogicalShiftConversion<ShlOp, ShLIOp>,
376 LogicalShiftConversion<ShrUOp, ShRUIOp>,
377 BinaryOpConversion<SubOp, SubIOp>,
378 DivSOpConversion<DivSOp, DivSIOp, /*IsRem=*/false>,
379 DivUOpConversion<DivUOp, DivUIOp>,
380 DivSOpConversion<ModSOp, RemSIOp, /*IsRem=*/true>,
381 DivUOpConversion<ModUOp, RemUIOp>, BinaryOpConversion<MuxOp, SelectOp>,
382 VariadicOpConversion<AddOp, AddIOp>, VariadicOpConversion<MulOp, MulIOp>,
383 VariadicOpConversion<AndOp, AndIOp>, VariadicOpConversion<OrOp, OrIOp>,
384 VariadicOpConversion<XorOp, XOrIOp>>(converter, patterns.getContext());
385}
386
387void ConvertCombToArithPass::runOnOperation() {
388 ConversionTarget target(getContext());
389 target.addIllegalDialect<comb::CombDialect>();
390 target.addIllegalOp<hw::ConstantOp>();
391 target.addLegalDialect<ArithDialect>();
392 // Arith does not have an operation equivalent to comb.parity. A lowering
393 // would result in undesirably complex logic, therefore, we mark it legal
394 // here.
395 target.addLegalOp<comb::ParityOp>();
396 // Arith does not have bitreverse, so we leave it for the CombToLLVM pass.
397 target.addLegalOp<comb::ReverseOp>();
398 // This pass is intended to rewrite Comb ops into Arith ops. Other dialects
399 // (e.g. LLVM) may legitimately be present when this pass is used in custom
400 // pipelines. Treat all unknown operations as legal so we don't attempt to
401 // fold/legalize unrelated ops.
402 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
403 RewritePatternSet patterns(&getContext());
404 TypeConverter converter;
405 converter.addConversion([](Type type) { return type; });
406 // TODO: a pattern for comb.parity
408
409 ConversionConfig config;
410 config.allowPatternRollback = false;
411 if (failed(mlir::applyPartialConversion(getOperation(), target,
412 std::move(patterns), config)))
413 signalPassFailure();
414}
415
416std::unique_ptr<Pass> circt::createConvertCombToArithPass() {
417 return std::make_unique<ConvertCombToArithPass>();
418}
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void populateCombToArithConversionPatterns(TypeConverter &converter, RewritePatternSet &patterns)
std::unique_ptr< Pass > createConvertCombToArithPass()
Definition comb.py:1
Definition hw.py:1