CIRCT 20.0.0git
Loading...
Searching...
No Matches
CombToAIG.cpp
Go to the documentation of this file.
1//===- CombToAIG.cpp - Comb to AIG Conversion Pass --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main Comb to AIG Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19#include "llvm/ADT/PointerUnion.h"
20
21namespace circt {
22#define GEN_PASS_DEF_CONVERTCOMBTOAIG
23#include "circt/Conversion/Passes.h.inc"
24} // namespace circt
25
26using namespace circt;
27using namespace comb;
28
29//===----------------------------------------------------------------------===//
30// Utility Functions
31//===----------------------------------------------------------------------===//
32
33// A wrapper for comb::extractBits that returns a SmallVector<Value>.
34static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
35 SmallVector<Value> bits;
36 comb::extractBits(builder, val, bits);
37 return bits;
38}
39
40// Construct a mux tree for shift operations. `isLeftShift` controls the
41// direction of the shift operation and is used to determine order of the
42// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
43// to get the padding and extracted bits for each shift amount. `getPadding`
44// could return a nullptr as i0 value but except for that, these callbacks must
45// return a valid value for each shift amount in the range [0, maxShiftAmount].
46// The value for `maxShiftAmount` is used as the out-of-bounds value.
47template <bool isLeftShift>
48static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
49 Value shiftAmount, int64_t maxShiftAmount,
50 llvm::function_ref<Value(int64_t)> getPadding,
51 llvm::function_ref<Value(int64_t)> getExtract) {
52 // Extract individual bits from shift amount
53 auto bits = extractBits(rewriter, shiftAmount);
54
55 // Create nodes for each possible shift amount
56 SmallVector<Value> nodes;
57 nodes.reserve(maxShiftAmount);
58 for (int64_t i = 0; i < maxShiftAmount; ++i) {
59 Value extract = getExtract(i);
60 Value padding = getPadding(i);
61
62 if (!padding) {
63 nodes.push_back(extract);
64 continue;
65 }
66
67 // Concatenate extracted bits with padding
68 if (isLeftShift)
69 nodes.push_back(
70 rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
71 else
72 nodes.push_back(
73 rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
74 }
75
76 // Create out-of-bounds value
77 auto outOfBoundsValue = getPadding(maxShiftAmount);
78 assert(outOfBoundsValue && "outOfBoundsValue must be valid");
79
80 // Construct mux tree for shift operation
81 auto result =
82 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
83
84 // Add bounds checking
85 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
86 loc, ICmpPredicate::ult, shiftAmount,
87 rewriter.create<hw::ConstantOp>(loc, shiftAmount.getType(),
88 maxShiftAmount));
89
90 return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
91 outOfBoundsValue);
92}
93
94namespace {
95// A union of Value and IntegerAttr to cleanly handle constant values.
96using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
97} // namespace
98
99// Return the number of unknown bits and populate the concatenated values.
101 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
102 // Constant or zero width value are all known.
103 if (value.getType().isInteger(0))
104 return 0;
105
106 // Recursively count unknown bits for concat.
107 if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
108 int64_t totalUnknownBits = 0;
109 for (auto concatInput : llvm::reverse(concat.getInputs())) {
110 auto unknownBits =
111 getNumUnknownBitsAndPopulateValues(concatInput, values);
112 if (unknownBits < 0)
113 return unknownBits;
114 totalUnknownBits += unknownBits;
115 }
116 return totalUnknownBits;
117 }
118
119 // Constant value is known.
120 if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
121 values.push_back(constant.getValueAttr());
122 return 0;
123 }
124
125 // Consider other operations as unknown bits.
126 // TODO: We can handle replicate, extract, etc.
127 values.push_back(value);
128 return hw::getBitWidth(value.getType());
129}
130
131// Return a value that substitutes the unknown bits with the mask.
132static APInt
134 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
135 uint32_t mask) {
136 uint32_t bitPos = 0, unknownPos = 0;
137 APInt result(width, 0);
138 for (auto constantOrValue : constantOrValues) {
139 int64_t elemWidth;
140 if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
141 elemWidth = constant.getValue().getBitWidth();
142 result.insertBits(constant.getValue(), bitPos);
143 } else {
144 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
145 assert(elemWidth >= 0 && "unknown bit width");
146 assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
147 // Create a mask for the unknown bits.
148 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
149 result.insertBits(APInt(elemWidth, usedBits), bitPos);
150 unknownPos += elemWidth;
151 }
152 bitPos += elemWidth;
153 }
154
155 return result;
156}
157
158// Emulate a binary operation with unknown bits using a table lookup.
159// This function enumerates all possible combinations of unknown bits and
160// emulates the operation for each combination.
161static LogicalResult emulateBinaryOpForUnknownBits(
162 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
163 Operation *op,
164 llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
165 SmallVector<ConstantOrValue> lhsValues, rhsValues;
166
167 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
168 "op must be a single result binary operation");
169
170 auto lhs = op->getOperand(0);
171 auto rhs = op->getOperand(1);
172 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
173 auto loc = op->getLoc();
174 auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
175 auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);
176
177 // If unknown bit width is detected, abort the lowering.
178 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
179 return failure();
180
181 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
182 if (totalUnknownBits > maxEmulationUnknownBits)
183 return failure();
184
185 SmallVector<Value> emulatedResults;
186 emulatedResults.reserve(1 << totalUnknownBits);
187
188 // Emulate all possible cases.
189 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
190 auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
191 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
192 auto it = constantPool.find(attr);
193 if (it != constantPool.end())
194 return it->second;
195 auto constant = rewriter.create<hw::ConstantOp>(loc, value);
196 constantPool[attr] = constant;
197 return constant;
198 };
199
200 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
201 lhsMask < lhsMaskEnd; ++lhsMask) {
202 APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
203 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
204 rhsMask < rhsMaskEnd; ++rhsMask) {
205 APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
206 // Emulate.
207 emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
208 }
209 }
210
211 // Create selectors for mux tree.
212 SmallVector<Value> selectors;
213 selectors.reserve(totalUnknownBits);
214 for (auto &concatedValues : {rhsValues, lhsValues})
215 for (auto valueOrConstant : concatedValues) {
216 auto value = dyn_cast<Value>(valueOrConstant);
217 if (!value)
218 continue;
219 extractBits(rewriter, value, selectors);
220 }
221
222 assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
223 "number of selectors must match");
224 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
225 getConstant(APInt::getZero(width)));
226
227 rewriter.replaceOp(op, muxed);
228 return success();
229}
230
231//===----------------------------------------------------------------------===//
232// Conversion patterns
233//===----------------------------------------------------------------------===//
234
235namespace {
236
237/// Lower a comb::AndOp operation to aig::AndInverterOp
238struct CombAndOpConversion : OpConversionPattern<AndOp> {
240
241 LogicalResult
242 matchAndRewrite(AndOp op, OpAdaptor adaptor,
243 ConversionPatternRewriter &rewriter) const override {
244 SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
245 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
246 nonInverts);
247 return success();
248 }
249};
250
251/// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags
252struct CombOrOpConversion : OpConversionPattern<OrOp> {
254
255 LogicalResult
256 matchAndRewrite(OrOp op, OpAdaptor adaptor,
257 ConversionPatternRewriter &rewriter) const override {
258 // Implement Or using And and invert flags: a | b = ~(~a & ~b)
259 SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
260 auto andOp = rewriter.create<aig::AndInverterOp>(
261 op.getLoc(), adaptor.getInputs(), allInverts);
262 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
263 /*invert=*/true);
264 return success();
265 }
266};
267
268/// Lower a comb::XorOp operation to AIG operations
269struct CombXorOpConversion : OpConversionPattern<XorOp> {
271
272 LogicalResult
273 matchAndRewrite(XorOp op, OpAdaptor adaptor,
274 ConversionPatternRewriter &rewriter) const override {
275 if (op.getNumOperands() != 2)
276 return failure();
277 // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)
278
279 // (a | b) = ~(~a & ~b)
280 // (~a | ~b) = ~(a & b)
281 auto inputs = adaptor.getInputs();
282 SmallVector<bool> allInverts(inputs.size(), true);
283 SmallVector<bool> allNotInverts(inputs.size(), false);
284
285 auto notAAndNotB =
286 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
287 auto aAndB =
288 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
289
290 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
291 /*lhs_invert=*/true,
292 /*rhs_invert=*/true);
293 return success();
294 }
295};
296
297template <typename OpTy>
298struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
300 using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
301 LogicalResult
302 matchAndRewrite(OpTy op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const override {
304 auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
305 rewriter.replaceOp(op, result);
306 return success();
307 }
308
309 static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
310 ConversionPatternRewriter &rewriter) {
311 Value lhs, rhs;
312 switch (operands.size()) {
313 case 0:
314 assert(false && "cannot be called with empty operand range");
315 break;
316 case 1:
317 return operands[0];
318 case 2:
319 lhs = operands[0];
320 rhs = operands[1];
321 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
322 default:
323 auto firstHalf = operands.size() / 2;
324 lhs =
325 lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
326 rhs =
327 lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
328 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
329 }
330 }
331};
332
333// Lower comb::MuxOp to AIG operations.
334struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
336
337 LogicalResult
338 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
339 ConversionPatternRewriter &rewriter) const override {
340 // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b)
341
342 Value cond = op.getCond();
343 auto trueVal = op.getTrueValue();
344 auto falseVal = op.getFalseValue();
345
346 if (!op.getType().isInteger()) {
347 // If the type of the mux is not integer, bitcast the operands first.
348 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
349 trueVal =
350 rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, trueVal);
351 falseVal =
352 rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, falseVal);
353 }
354
355 // Replicate condition if needed
356 if (!trueVal.getType().isInteger(1))
357 cond = rewriter.create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
358 cond);
359
360 // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
361 auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
362 auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
363 true, false);
364
365 Value result = rewriter.create<comb::OrOp>(op.getLoc(), lhs, rhs);
366 // Insert the bitcast if the type of the mux is not integer.
367 if (result.getType() != op.getType())
368 result =
369 rewriter.create<hw::BitcastOp>(op.getLoc(), op.getType(), result);
370 rewriter.replaceOp(op, result);
371 return success();
372 }
373};
374
375struct CombAddOpConversion : OpConversionPattern<AddOp> {
377 LogicalResult
378 matchAndRewrite(AddOp op, OpAdaptor adaptor,
379 ConversionPatternRewriter &rewriter) const override {
380 auto inputs = adaptor.getInputs();
381 // Lower only when there are two inputs.
382 // Variadic operands must be lowered in a different pattern.
383 if (inputs.size() != 2)
384 return failure();
385
386 auto width = op.getType().getIntOrFloatBitWidth();
387 // Skip a zero width value.
388 if (width == 0) {
389 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
390 return success();
391 }
392
393 // Implement a naive Ripple-carry full adder.
394 Value carry;
395
396 auto aBits = extractBits(rewriter, inputs[0]);
397 auto bBits = extractBits(rewriter, inputs[1]);
398 SmallVector<Value> results;
399 results.resize(width);
400 for (int64_t i = 0; i < width; ++i) {
401 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
402 if (carry)
403 xorOperands.push_back(carry);
404
405 // sum[i] = xor(carry[i-1], a[i], b[i])
406 // NOTE: The result is stored in reverse order.
407 results[width - i - 1] =
408 rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);
409
410 // If this is the last bit, we are done.
411 if (i == width - 1) {
412 break;
413 }
414
415 // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
416 Value nextCarry = rewriter.create<comb::AndOp>(
417 op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
418 if (!carry) {
419 // This is the first bit, so the carry is the next carry.
420 carry = nextCarry;
421 continue;
422 }
423
424 auto aXnorB = rewriter.create<comb::XorOp>(
425 op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
426 auto andOp = rewriter.create<comb::AndOp>(
427 op.getLoc(), ValueRange{carry, aXnorB}, true);
428 carry = rewriter.create<comb::OrOp>(op.getLoc(),
429 ValueRange{andOp, nextCarry}, true);
430 }
431
432 rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
433 return success();
434 }
435};
436
437struct CombSubOpConversion : OpConversionPattern<SubOp> {
439 LogicalResult
440 matchAndRewrite(SubOp op, OpAdaptor adaptor,
441 ConversionPatternRewriter &rewriter) const override {
442 auto lhs = op.getLhs();
443 auto rhs = op.getRhs();
444 // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
445 // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
446 // => add(lhs, ~rhs, 1)
447 auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
448 /*invert=*/true);
449 auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
450 rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
451 true);
452 return success();
453 }
454};
455
456struct CombMulOpConversion : OpConversionPattern<MulOp> {
458 using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
459 LogicalResult
460 matchAndRewrite(MulOp op, OpAdaptor adaptor,
461 ConversionPatternRewriter &rewriter) const override {
462 if (adaptor.getInputs().size() != 2)
463 return failure();
464
465 // FIXME: Currently it's lowered to a really naive implementation that
466 // chains add operations.
467
468 // a_{n}a_{n-1}...a_0 * b
469 // = sum_{i=0}^{n} a_i * 2^i * b
470 // = sum_{i=0}^{n} (a_i ? b : 0) << i
471 int64_t width = op.getType().getIntOrFloatBitWidth();
472 auto aBits = extractBits(rewriter, adaptor.getInputs()[0]);
473 SmallVector<Value> results;
474 auto rhs = op.getInputs()[1];
475 auto zero = rewriter.create<hw::ConstantOp>(op.getLoc(),
476 llvm::APInt::getZero(width));
477 for (int64_t i = 0; i < width; ++i) {
478 auto aBit = aBits[i];
479 auto andBit =
480 rewriter.createOrFold<comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
481 auto upperBits = rewriter.createOrFold<comb::ExtractOp>(
482 op.getLoc(), andBit, 0, width - i);
483 if (i == 0) {
484 results.push_back(upperBits);
485 continue;
486 }
487
488 auto lowerBits =
489 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(i));
490
491 auto shifted = rewriter.createOrFold<comb::ConcatOp>(
492 op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
493 results.push_back(shifted);
494 }
495
496 rewriter.replaceOpWithNewOp<comb::AddOp>(op, results, true);
497 return success();
498 }
499};
500
501template <typename OpTy>
502struct DivModOpConversionBase : OpConversionPattern<OpTy> {
503 DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
504 : OpConversionPattern<OpTy>(context),
505 maxEmulationUnknownBits(maxEmulationUnknownBits) {
506 assert(maxEmulationUnknownBits < 32 &&
507 "maxEmulationUnknownBits must be less than 32");
508 }
509 const int64_t maxEmulationUnknownBits;
510};
511
512struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
513 using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
514 LogicalResult
515 matchAndRewrite(DivUOp op, OpAdaptor adaptor,
516 ConversionPatternRewriter &rewriter) const override {
517 // Check if the divisor is a power of two.
518 if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
519 if (rhsConstantOp.getValue().isPowerOf2()) {
520 // Extract upper bits.
521 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
522 size_t width = op.getType().getIntOrFloatBitWidth();
523 Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
524 op.getLoc(), adaptor.getLhs(), extractAmount,
525 width - extractAmount);
526 Value constZero = rewriter.create<hw::ConstantOp>(
527 op.getLoc(), APInt::getZero(extractAmount));
528 rewriter.replaceOpWithNewOp<comb::ConcatOp>(
529 op, op.getType(), ArrayRef<Value>{constZero, upperBits});
530 return success();
531 }
532
533 // When rhs is not power of two and the number of unknown bits are small,
534 // create a mux tree that emulates all possible cases.
536 rewriter, maxEmulationUnknownBits, op,
537 [](const APInt &lhs, const APInt &rhs) {
538 // Division by zero is undefined, just return zero.
539 if (rhs.isZero())
540 return APInt::getZero(rhs.getBitWidth());
541 return lhs.udiv(rhs);
542 });
543 }
544};
545
546struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
547 using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
548 LogicalResult
549 matchAndRewrite(ModUOp op, OpAdaptor adaptor,
550 ConversionPatternRewriter &rewriter) const override {
551 // Check if the divisor is a power of two.
552 if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
553 if (rhsConstantOp.getValue().isPowerOf2()) {
554 // Extract lower bits.
555 size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
556 size_t width = op.getType().getIntOrFloatBitWidth();
557 Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
558 op.getLoc(), adaptor.getLhs(), 0, extractAmount);
559 Value constZero = rewriter.create<hw::ConstantOp>(
560 op.getLoc(), APInt::getZero(width - extractAmount));
561 rewriter.replaceOpWithNewOp<comb::ConcatOp>(
562 op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
563 return success();
564 }
565
566 // When rhs is not power of two and the number of unknown bits are small,
567 // create a mux tree that emulates all possible cases.
569 rewriter, maxEmulationUnknownBits, op,
570 [](const APInt &lhs, const APInt &rhs) {
571 // Division by zero is undefined, just return zero.
572 if (rhs.isZero())
573 return APInt::getZero(rhs.getBitWidth());
574 return lhs.urem(rhs);
575 });
576 }
577};
578
579struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
580 using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;
581
582 LogicalResult
583 matchAndRewrite(DivSOp op, OpAdaptor adaptor,
584 ConversionPatternRewriter &rewriter) const override {
585 // Currently only lower with emulation.
586 // TODO: Implement a signed division lowering at least for power of two.
588 rewriter, maxEmulationUnknownBits, op,
589 [](const APInt &lhs, const APInt &rhs) {
590 // Division by zero is undefined, just return zero.
591 if (rhs.isZero())
592 return APInt::getZero(rhs.getBitWidth());
593 return lhs.sdiv(rhs);
594 });
595 }
596};
597
598struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
599 using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
600 LogicalResult
601 matchAndRewrite(ModSOp op, OpAdaptor adaptor,
602 ConversionPatternRewriter &rewriter) const override {
603 // Currently only lower with emulation.
604 // TODO: Implement a signed modulus lowering at least for power of two.
606 rewriter, maxEmulationUnknownBits, op,
607 [](const APInt &lhs, const APInt &rhs) {
608 // Division by zero is undefined, just return zero.
609 if (rhs.isZero())
610 return APInt::getZero(rhs.getBitWidth());
611 return lhs.srem(rhs);
612 });
613 }
614};
615
616struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
618 static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
619 ArrayRef<Value> bBits, bool isLess,
620 bool includeEq,
621 ConversionPatternRewriter &rewriter) {
622 // Construct following unsigned comparison expressions.
623 // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
624 // a < b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] < b[n-1:0])
625 // a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0])
626 // a > b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] > b[n-1:0])
627 Value acc =
628 rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);
629
630 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
631 auto aBitXorBBit =
632 rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
633 auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
634 op.getLoc(), aBitXorBBit, true);
635 auto pred = rewriter.createOrFold<aig::AndInverterOp>(
636 op.getLoc(), aBit, bBit, isLess, !isLess);
637
638 auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
639 op.getLoc(), ValueRange{aEqualB, acc}, true);
640 acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
641 true);
642 }
643 return acc;
644 }
645
646 LogicalResult
647 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
648 ConversionPatternRewriter &rewriter) const override {
649 auto lhs = adaptor.getLhs();
650 auto rhs = adaptor.getRhs();
651
652 switch (op.getPredicate()) {
653 default:
654 return failure();
655
656 case ICmpPredicate::eq:
657 case ICmpPredicate::ceq: {
658 // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
659 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
660 auto xorBits = extractBits(rewriter, xorOp);
661 SmallVector<bool> allInverts(xorBits.size(), true);
662 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
663 return success();
664 }
665
666 case ICmpPredicate::ne:
667 case ICmpPredicate::cne: {
668 // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
669 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
670 rewriter.replaceOpWithNewOp<comb::OrOp>(op, extractBits(rewriter, xorOp),
671 true);
672 return success();
673 }
674
675 case ICmpPredicate::uge:
676 case ICmpPredicate::ugt:
677 case ICmpPredicate::ule:
678 case ICmpPredicate::ult: {
679 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
680 op.getPredicate() == ICmpPredicate::ule;
681 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
682 op.getPredicate() == ICmpPredicate::ule;
683 auto aBits = extractBits(rewriter, lhs);
684 auto bBits = extractBits(rewriter, rhs);
685 rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
686 includeEq, rewriter));
687 return success();
688 }
689 case ICmpPredicate::slt:
690 case ICmpPredicate::sle:
691 case ICmpPredicate::sgt:
692 case ICmpPredicate::sge: {
693 if (lhs.getType().getIntOrFloatBitWidth() == 0)
694 return rewriter.notifyMatchFailure(
695 op.getLoc(), "i0 signed comparison is unsupported");
696 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
697 op.getPredicate() == ICmpPredicate::sle;
698 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
699 op.getPredicate() == ICmpPredicate::sle;
700
701 auto aBits = extractBits(rewriter, lhs);
702 auto bBits = extractBits(rewriter, rhs);
703
704 // Get a sign bit
705 auto signA = aBits.back();
706 auto signB = bBits.back();
707
708 // Compare magnitudes (all bits except sign)
709 auto sameSignResult = constructUnsignedCompare(
710 op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
711 includeEq, rewriter);
712
713 // XOR of signs: true if signs are different
714 auto signsDiffer =
715 rewriter.create<comb::XorOp>(op.getLoc(), signA, signB);
716
717 // Result when signs are different
718 Value diffSignResult = isLess ? signA : signB;
719
720 // Final result: choose based on whether signs differ
721 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, signsDiffer, diffSignResult,
722 sameSignResult);
723 return success();
724 }
725 }
726 }
727};
728
729struct CombParityOpConversion : OpConversionPattern<ParityOp> {
731
732 LogicalResult
733 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
734 ConversionPatternRewriter &rewriter) const override {
735 // Parity is the XOR of all bits.
736 rewriter.replaceOpWithNewOp<comb::XorOp>(
737 op, extractBits(rewriter, adaptor.getInput()), true);
738 return success();
739 }
740};
741
742struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
744
745 LogicalResult
746 matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
747 ConversionPatternRewriter &rewriter) const override {
748 auto width = op.getType().getIntOrFloatBitWidth();
749 auto lhs = adaptor.getLhs();
750 auto result = createShiftLogic</*isLeftShift=*/true>(
751 rewriter, op.getLoc(), adaptor.getRhs(), width,
752 /*getPadding=*/
753 [&](int64_t index) {
754 // Don't create zero width value.
755 if (index == 0)
756 return Value();
757 // Padding is 0 for left shift.
758 return rewriter.createOrFold<hw::ConstantOp>(
759 op.getLoc(), rewriter.getIntegerType(index), 0);
760 },
761 /*getExtract=*/
762 [&](int64_t index) {
763 assert(index < width && "index out of bounds");
764 // Exract the bits from LSB.
765 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
766 width - index);
767 });
768
769 rewriter.replaceOp(op, result);
770 return success();
771 }
772};
773
774struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
776
777 LogicalResult
778 matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
779 ConversionPatternRewriter &rewriter) const override {
780 auto width = op.getType().getIntOrFloatBitWidth();
781 auto lhs = adaptor.getLhs();
782 auto result = createShiftLogic</*isLeftShift=*/false>(
783 rewriter, op.getLoc(), adaptor.getRhs(), width,
784 /*getPadding=*/
785 [&](int64_t index) {
786 // Don't create zero width value.
787 if (index == 0)
788 return Value();
789 // Padding is 0 for right shift.
790 return rewriter.createOrFold<hw::ConstantOp>(
791 op.getLoc(), rewriter.getIntegerType(index), 0);
792 },
793 /*getExtract=*/
794 [&](int64_t index) {
795 assert(index < width && "index out of bounds");
796 // Exract the bits from MSB.
797 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
798 width - index);
799 });
800
801 rewriter.replaceOp(op, result);
802 return success();
803 }
804};
805
806struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
808
809 LogicalResult
810 matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
811 ConversionPatternRewriter &rewriter) const override {
812 auto width = op.getType().getIntOrFloatBitWidth();
813 if (width == 0)
814 return rewriter.notifyMatchFailure(op.getLoc(),
815 "i0 signed shift is unsupported");
816 auto lhs = adaptor.getLhs();
817 // Get the sign bit.
818 auto sign =
819 rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
820
821 // NOTE: The max shift amount is width - 1 because the sign bit is
822 // already shifted out.
823 auto result = createShiftLogic</*isLeftShift=*/false>(
824 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
825 /*getPadding=*/
826 [&](int64_t index) {
827 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
828 index + 1);
829 },
830 /*getExtract=*/
831 [&](int64_t index) {
832 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
833 width - index - 1);
834 });
835
836 rewriter.replaceOp(op, result);
837 return success();
838 }
839};
840
841} // namespace
842
843//===----------------------------------------------------------------------===//
844// Convert Comb to AIG pass
845//===----------------------------------------------------------------------===//
846
847namespace {
848struct ConvertCombToAIGPass
849 : public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
850 void runOnOperation() override;
851 using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
852 using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
853 using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
854};
855} // namespace
856
857static void
859 uint32_t maxEmulationUnknownBits) {
860 patterns.add<
861 // Bitwise Logical Ops
862 CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
863 CombMuxOpConversion, CombParityOpConversion,
864 // Arithmetic Ops
865 CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
866 CombICmpOpConversion,
867 // Shift Ops
868 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
869 // Variadic ops that must be lowered to binary operations
870 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
871 CombLowerVariadicOp<MulOp>>(patterns.getContext());
872
873 // Add div/mod patterns with a threshold given by the pass option.
874 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
875 CombModSOpConversion>(patterns.getContext(),
876 maxEmulationUnknownBits);
877}
878
// Run the Comb to AIG conversion as a partial dialect conversion and signal
// pass failure if the conversion fails.
// NOTE(review): this listing is an extract with embedded line numbers; the
// numbering skips 886 and 892 below, so two lines of this function appear to
// be missing here - restore them from the original file before building.
879void ConvertCombToAIGPass::runOnOperation() {
880 ConversionTarget target(getContext());
881
882 // Comb is source dialect.
883 target.addIllegalDialect<comb::CombDialect>();
884 // Keep data movement operations like Extract, Concat and Replicate.
// NOTE(review): the continuation of this addLegalOp list is not visible here.
885 target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
887
888 // Treat array operations as illegal. Strictly speaking, other than array
889 // get operation with non-const index are legal in AIG but array types
890 // prevent a bunch of optimizations so just lower them to integer
891 // operations. It's required to run HWAggregateToComb pass before this pass.
// NOTE(review): the start of the statement that ends on the next line is not
// visible here; presumably it marks the array operations illegal - confirm.
893 hw::AggregateConstantOp>();
894
895 // AIG is target dialect.
896 target.addLegalDialect<aig::AIGDialect>();
897
898 // This is a test only option to add logical ops.
899 if (!additionalLegalOps.empty())
900 for (const auto &opName : additionalLegalOps)
901 target.addLegalOp(OperationName(opName, &getContext()));
902
// Collect the conversion patterns, passing the pass option through as the
// div/mod emulation threshold.
903 RewritePatternSet patterns(&getContext());
904 populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits);
905
906 if (failed(mlir::applyPartialConversion(getOperation(), target,
907 std::move(patterns))))
908 return signalPassFailure();
909}
assert(baseType &&"element must be base type")
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
Definition CombToAIG.cpp:34
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
Definition CombToAIG.cpp:48
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits)
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:441
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1