//===- CombToAIG.cpp - Comb to AIG Conversion Pass --------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main Comb to AIG Conversion Pass Implementation.
//
//===----------------------------------------------------------------------===//

#include "circt/Conversion/CombToAIG.h"
#include "circt/Dialect/AIG/AIGOps.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Debug.h"

#define DEBUG_TYPE "comb-to-aig"

namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTOAIG
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace comb;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift and determines the order of the padding and the
// extracted bits. The callbacks `getPadding` and `getExtract` provide the
// padding and extracted bits for each shift amount. `getPadding` may return
// a null Value to represent a zero-width (i0) value; apart from that, both
// callbacks must return a valid value for each shift amount in the range
// [0, maxShiftAmount]. The value for `maxShiftAmount` is used as the
// out-of-bounds value.
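// As an illustration, for an i4 left shift `x << s` the nodes are:
//   nodes[0] = x[3:0]            nodes[1] = {x[2:0], 1'b0}
//   nodes[2] = {x[1:0], 2'b00}   nodes[3] = {x[0], 3'b000}
// A mux tree over the bits of `s` selects among them, and any `s >= 4`
// selects the out-of-bounds value (4'b0000 here).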
template <bool isLeftShift>
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
                              Value shiftAmount, int64_t maxShiftAmount,
                              llvm::function_ref<Value(int64_t)> getPadding,
                              llvm::function_ref<Value(int64_t)> getExtract) {
  // Extract individual bits from the shift amount.
  auto bits = extractBits(rewriter, shiftAmount);

  // Create nodes for each possible shift amount.
  SmallVector<Value> nodes;
  nodes.reserve(maxShiftAmount);
  for (int64_t i = 0; i < maxShiftAmount; ++i) {
    Value extract = getExtract(i);
    Value padding = getPadding(i);

    if (!padding) {
      nodes.push_back(extract);
      continue;
    }

    // Concatenate the extracted bits with the padding.
    if (isLeftShift)
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
    else
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
  }

  // Create the out-of-bounds value.
  auto outOfBoundsValue = getPadding(maxShiftAmount);
  assert(outOfBoundsValue && "outOfBoundsValue must be valid");

  // Construct a mux tree for the shift operation.
  auto result =
      comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

  // Add bounds checking.
  auto inBound = rewriter.createOrFold<comb::ICmpOp>(
      loc, ICmpPredicate::ult, shiftAmount,
      rewriter.create<hw::ConstantOp>(loc, shiftAmount.getType(),
                                      maxShiftAmount));

  return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
                                            outOfBoundsValue);
}

namespace {
// A union of Value and IntegerAttr to cleanly handle constant values.
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
} // namespace

// Return the number of unknown bits and populate the concatenated values.
static int64_t getNumUnknownBitsAndPopulateValues(
    Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
  // A zero-width value has no unknown bits.
  if (value.getType().isInteger(0))
    return 0;

  // Recursively count unknown bits for concat.
  if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
    int64_t totalUnknownBits = 0;
    for (auto concatInput : llvm::reverse(concat.getInputs())) {
      auto unknownBits =
          getNumUnknownBitsAndPopulateValues(concatInput, values);
      if (unknownBits < 0)
        return unknownBits;
      totalUnknownBits += unknownBits;
    }
    return totalUnknownBits;
  }

  // A constant value is fully known.
  if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
    values.push_back(constant.getValueAttr());
    return 0;
  }

  // Treat all other operations as unknown bits.
  // TODO: We can handle replicate, extract, etc.
  values.push_back(value);
  return hw::getBitWidth(value.getType());
}

// Return a value that substitutes the unknown bits with the mask.
static APInt
substitueMaskToValues(size_t width,
                      llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
                      uint32_t mask) {
  uint32_t bitPos = 0, unknownPos = 0;
  APInt result(width, 0);
  for (auto constantOrValue : constantOrValues) {
    int64_t elemWidth;
    if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
      elemWidth = constant.getValue().getBitWidth();
      result.insertBits(constant.getValue(), bitPos);
    } else {
      elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
      assert(elemWidth >= 0 && "unknown bit width");
      assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
      // Create a mask for the unknown bits.
      uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
      result.insertBits(APInt(elemWidth, usedBits), bitPos);
      unknownPos += elemWidth;
    }
    bitPos += elemWidth;
  }

  return result;
}

// Emulate a binary operation with unknown bits using a table lookup.
// This function enumerates all possible combinations of unknown bits and
// emulates the operation for each combination.
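// As an illustration, if the lhs is a constant and the rhs has two unknown
// bits, the 2^2 = 4 possible results are precomputed as constants, and a
// 4-leaf mux tree, selected by the two unknown bits themselves, picks the
// right one at runtime.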
static LogicalResult emulateBinaryOpForUnknownBits(
    ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
    Operation *op,
    llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
  SmallVector<ConstantOrValue> lhsValues, rhsValues;

  assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
         "op must be a single result binary operation");

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
  auto loc = op->getLoc();
  auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
  auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

  // If an unknown bit width is detected, abort the lowering.
  if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
    return failure();

  int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
  if (totalUnknownBits > maxEmulationUnknownBits)
    return failure();

  SmallVector<Value> emulatedResults;
  emulatedResults.reserve(1 << totalUnknownBits);

  // Emulate all possible cases.
  DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
  auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
    auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
    auto it = constantPool.find(attr);
    if (it != constantPool.end())
      return it->second;
    auto constant = rewriter.create<hw::ConstantOp>(loc, value);
    constantPool[attr] = constant;
    return constant;
  };

  for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
       lhsMask < lhsMaskEnd; ++lhsMask) {
    APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
    for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
         rhsMask < rhsMaskEnd; ++rhsMask) {
      APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
      // Emulate.
      emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
    }
  }

  // Create selectors for the mux tree.
  SmallVector<Value> selectors;
  selectors.reserve(totalUnknownBits);
  for (auto &concatedValues : {rhsValues, lhsValues})
    for (auto valueOrConstant : concatedValues) {
      auto value = dyn_cast<Value>(valueOrConstant);
      if (!value)
        continue;
      extractBits(rewriter, value, selectors);
    }

  assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
         "number of selectors must match");
  auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
                                getConstant(APInt::getZero(width)));

  rewriter.replaceOp(op, muxed);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Lower a comb::AndOp operation to aig::AndInverterOp.
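/// For example, a two-input `comb.and %a, %b` maps directly onto a single
/// AndInverterOp with no inputs inverted.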
struct CombAndOpConversion : OpConversionPattern<AndOp> {
  using OpConversionPattern<AndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
    rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
                                                    nonInverts);
    return success();
  }
};

/// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags.
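/// For example, a two-input or is built from two AndInverterOps by
/// De Morgan's law:
///   a | b = ~(~a & ~b)
/// i.e. one AndInverterOp with both inputs inverted, followed by a second
/// one that inverts its single operand.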
struct CombOrOpConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement Or using And and invert flags: a | b = ~(~a & ~b)
    SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
    auto andOp = rewriter.create<aig::AndInverterOp>(
        op.getLoc(), adaptor.getInputs(), allInverts);
    rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
                                                    /*invert=*/true);
    return success();
  }
};

/// Lower a comb::XorOp operation to AIG operations.
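/// For a two-input xor, the identity used below is
///   a ^ b = (a | b) & ~(a & b) = ~(~a & ~b) & ~(a & b)
/// i.e. two AndInverterOps feeding a final AndInverterOp that inverts both
/// of its operands.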
struct CombXorOpConversion : OpConversionPattern<XorOp> {
  using OpConversionPattern<XorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)

    // (a | b) = ~(~a & ~b)
    // (~a | ~b) = ~(a & b)
    auto inputs = adaptor.getInputs();
    SmallVector<bool> allInverts(inputs.size(), true);
    SmallVector<bool> allNotInverts(inputs.size(), false);

    auto notAAndNotB =
        rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
    auto aAndB =
        rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);

    rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
                                                    /*lhs_invert=*/true,
                                                    /*rhs_invert=*/true);
    return success();
  }
};

template <typename OpTy>
struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
    rewriter.replaceOp(op, result);
    return success();
  }

  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
                                       ConversionPatternRewriter &rewriter) {
    Value lhs, rhs;
    switch (operands.size()) {
    case 0:
      llvm_unreachable("cannot be called with empty operand range");
    case 1:
      return operands[0];
    case 2:
      lhs = operands[0];
      rhs = operands[1];
      return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
    default:
      auto firstHalf = operands.size() / 2;
      lhs =
          lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
      rhs =
          lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
      return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
    }
  }
};

// Lower comb::MuxOp to AIG operations.
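// For example, an i4 mux replicates the i1 condition to i4 so that
//   mux(c, a, b) = (cccc & a) | (~cccc & b)
// reduces to the And/Or lowerings above.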
struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
  using OpConversionPattern<MuxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b)

    Value cond = op.getCond();
    auto trueVal = op.getTrueValue();
    auto falseVal = op.getFalseValue();

    if (!op.getType().isInteger()) {
      // If the type of the mux is not an integer, bitcast the operands first.
      auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
      trueVal =
          rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, trueVal);
      falseVal =
          rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, falseVal);
    }

    // Replicate the condition if needed.
    if (!trueVal.getType().isInteger(1))
      cond = rewriter.create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
                                                cond);

    // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
    auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
    auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
                                                   true, false);

    Value result = rewriter.create<comb::OrOp>(op.getLoc(), lhs, rhs);
    // Insert a bitcast if the type of the mux is not an integer.
    if (result.getType() != op.getType())
      result =
          rewriter.create<hw::BitcastOp>(op.getLoc(), op.getType(), result);
    rewriter.replaceOp(op, result);
    return success();
  }
};

struct CombAddOpConversion : OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputs = adaptor.getInputs();
    // Lower only when there are exactly two inputs.
    // Variadic operands must be lowered in a different pattern.
    if (inputs.size() != 2)
      return failure();

    auto width = op.getType().getIntOrFloatBitWidth();
    // Skip a zero-width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
      return success();
    }

    if (width < 8)
      lowerRippleCarryAdder(op, inputs, rewriter);
    else
      lowerParallelPrefixAdder(op, inputs, rewriter);

    return success();
  }

  // Implement a basic ripple-carry adder for small bitwidths.
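  // As an illustration, a 2-bit addition unrolls to:
  //   sum[0]   = a[0] ^ b[0]
  //   carry[0] = a[0] & b[0]
  //   sum[1]   = a[1] ^ b[1] ^ carry[0]
  // i.e. a half adder followed by a chain of full adders.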
  void lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
                             ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    // Implement a naive ripple-carry full adder.
    Value carry;

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    SmallVector<Value> results;
    results.resize(width);
    for (int64_t i = 0; i < width; ++i) {
      SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
      if (carry)
        xorOperands.push_back(carry);

      // sum[i] = xor(carry[i-1], a[i], b[i])
      // NOTE: The result is stored in reverse order.
      results[width - i - 1] =
          rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);

      // If this is the last bit, we are done.
      if (i == width - 1)
        break;

      // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
      Value nextCarry = rewriter.create<comb::AndOp>(
          op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
      if (!carry) {
        // This is the first bit, so the carry is simply a[0] & b[0].
        carry = nextCarry;
        continue;
      }

      auto aXorB = rewriter.create<comb::XorOp>(
          op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
      auto andOp = rewriter.create<comb::AndOp>(
          op.getLoc(), ValueRange{carry, aXorB}, true);
      carry = rewriter.create<comb::OrOp>(op.getLoc(),
                                          ValueRange{andOp, nextCarry}, true);
    }
    LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
                            << width << "\n");

    rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
  }

  // Implement a parallel prefix adder with a Kogge-Stone or Brent-Kung tree.
  // This will introduce unused signals for the carry bits, but these will be
  // removed by later AIG passes.
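  // Sketch of the scheme: with p_i = a_i ^ b_i and g_i = a_i & b_i, a prefix
  // tree computes group generates G_i such that the carry into bit i+1 is
  // G_i; the sum bits are then p_0 and p_i ^ G_{i-1} for i >= 1.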
  void lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    // Construct propagate (p) and generate (g) signals.
    SmallVector<Value> p, g;
    p.reserve(width);
    g.reserve(width);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // p_i = a_i XOR b_i
      p.push_back(rewriter.create<comb::XorOp>(op.getLoc(), aBit, bBit));
      // g_i = a_i AND b_i
      g.push_back(rewriter.create<comb::AndOp>(op.getLoc(), aBit, bBit));
    }

    LLVM_DEBUG({
      llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
                   << "\n--------------------------------------- Init\n";

      for (int64_t i = 0; i < width; ++i) {
        // p_i = a_i XOR b_i
        llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
        // g_i = a_i AND b_i
        llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
      }
    });

    // Create copies of p and g for the prefix computation.
    SmallVector<Value> pPrefix = p;
    SmallVector<Value> gPrefix = g;
    if (width < 32)
      lowerKoggeStonePrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
    else
      lowerBrentKungPrefixTree(op, inputs, rewriter, pPrefix, gPrefix);

    // Generate the resulting sum bits.
    // NOTE: The result is stored in reverse order.
    SmallVector<Value> results;
    results.resize(width);
    // Sum bit 0 is just p[0] since carry_in = 0.
    results[width - 1] = p[0];

    // For the remaining bits, sum_i = p_i XOR c_(i-1).
    // The carry into position i is the group generate from position i-1.
    for (int64_t i = 1; i < width; ++i)
      results[width - 1 - i] =
          rewriter.create<comb::XorOp>(op.getLoc(), p[i], gPrefix[i - 1]);

    rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);

    LLVM_DEBUG({
      llvm::dbgs() << "--------------------------------------- Completion\n"
                   << "RES0 = P0\n";
      for (int64_t i = 1; i < width; ++i)
        llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
    });
  }

  // Implement the Kogge-Stone parallel prefix tree.
  // Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
  // Slightly better delay than Brent-Kung, but more area.
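  // For example, for width 8 Kogge-Stone needs log2(8) = 3 stages with
  // strides 1, 2 and 4; in each stage every position i >= stride combines
  // with position i - stride, giving O(width * log(width)) prefix cells.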
  void lowerKoggeStonePrefixTree(comb::AddOp op, ValueRange inputs,
                                 ConversionPatternRewriter &rewriter,
                                 SmallVector<Value> &pPrefix,
                                 SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    // Kogge-Stone parallel prefix computation.
    for (int64_t stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride; i < width; ++i) {
        int64_t j = i - stride;
        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            rewriter.create<comb::AndOp>(op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            rewriter.create<comb::OrOp>(op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefix[i] =
            rewriter.create<comb::AndOp>(op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }
    LLVM_DEBUG({
      int64_t stage = 0;
      for (int64_t stride = 1; stride < width; stride *= 2) {
        llvm::dbgs()
            << "--------------------------------------- Kogge-Stone Stage "
            << stage << "\n";
        for (int64_t i = stride; i < width; ++i) {
          int64_t j = i - stride;
          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        ++stage;
      }
    });
  }

  // Implement the Brent-Kung parallel prefix tree.
  // Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
  // Slightly worse delay than Kogge-Stone, but less area.
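  // For example, for width 8 the forward phase (strides 1, 2, 4) updates
  // positions 1, 3, 5, 7, then 3, 7, then 7; the backward phase fills in the
  // remaining positions (5, then 2, 4, 6), using roughly half the prefix
  // cells of Kogge-Stone at the cost of extra logic depth.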
  void lowerBrentKungPrefixTree(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter,
                                SmallVector<Value> &pPrefix,
                                SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    // Brent-Kung parallel prefix computation.
    // Forward phase.
    int64_t stride;
    for (stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            rewriter.create<comb::AndOp>(op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            rewriter.create<comb::OrOp>(op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefix[i] =
            rewriter.create<comb::AndOp>(op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    // Backward phase.
    for (; stride > 0; stride /= 2) {
      for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            rewriter.create<comb::AndOp>(op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefix[i] =
            rewriter.create<comb::OrOp>(op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefix[i] =
            rewriter.create<comb::AndOp>(op.getLoc(), pPrefix[i], pPrefix[j]);
      }
    }

    LLVM_DEBUG({
      int64_t stage = 0;
      for (stride = 1; stride < width; stride *= 2) {
        llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
                     << stage << " : Stride " << stride << "\n";
        for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;

          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        ++stage;
      }

      for (; stride > 0; stride /= 2) {
        if (stride * 3 - 1 < width)
          llvm::dbgs()
              << "--------------------------------------- Brent-Kung BW "
              << stage << " : Stride " << stride << "\n";

        for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;

          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        --stage;
      }
    });
  }
};

struct CombSubOpConversion : OpConversionPattern<SubOp> {
  using OpConversionPattern<SubOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SubOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = op.getLhs();
    auto rhs = op.getRhs();
    // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
    // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
    // => add(lhs, ~rhs, 1)
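    // E.g. in i4: 5 - 3 = 5 + ~3 + 1 = 0b0101 + 0b1100 + 0b0001
    //           = 0b0010 (mod 16) = 2.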
    auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
                                                      /*invert=*/true);
    auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
    rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
                                             true);
    return success();
  }
};

struct CombMulOpConversion : OpConversionPattern<MulOp> {
  using OpConversionPattern<MulOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
  LogicalResult
  matchAndRewrite(MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getInputs().size() != 2)
      return failure();

    // FIXME: Currently this is lowered to a naive implementation that
    // chains add operations.

    // a_{n}a_{n-1}...a_0 * b
    // = sum_{i=0}^{n} a_i * 2^i * b
    // = sum_{i=0}^{n} (a_i ? b : 0) << i
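    // E.g. in i3 with a = 0b011 and b = 0b010:
    //   (a_0 ? b : 0)      = 0b010
    //   (a_1 ? b : 0) << 1 = 0b100
    //   (a_2 ? b : 0) << 2 = 0b000
    // and their sum is 0b110 = 6 = 3 * 2 (mod 8).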
    int64_t width = op.getType().getIntOrFloatBitWidth();
    auto aBits = extractBits(rewriter, adaptor.getInputs()[0]);
    SmallVector<Value> results;
    auto rhs = op.getInputs()[1];
    auto zero = rewriter.create<hw::ConstantOp>(op.getLoc(),
                                                llvm::APInt::getZero(width));
    for (int64_t i = 0; i < width; ++i) {
      auto aBit = aBits[i];
      auto andBit =
          rewriter.createOrFold<comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
      auto upperBits = rewriter.createOrFold<comb::ExtractOp>(
          op.getLoc(), andBit, 0, width - i);
      if (i == 0) {
        results.push_back(upperBits);
        continue;
      }

      auto lowerBits =
          rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(i));

      auto shifted = rewriter.createOrFold<comb::ConcatOp>(
          op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
      results.push_back(shifted);
    }

    rewriter.replaceOpWithNewOp<comb::AddOp>(op, results, true);
    return success();
  }
};

template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
  DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
      : OpConversionPattern<OpTy>(context),
        maxEmulationUnknownBits(maxEmulationUnknownBits) {
    assert(maxEmulationUnknownBits < 32 &&
           "maxEmulationUnknownBits must be less than 32");
  }
  const int64_t maxEmulationUnknownBits;
};

struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
  using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(DivUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
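    // E.g. dividing an i8 value x by the constant 8 yields {3'b000, x[7:3]},
    // i.e. a plain right shift by log2(8) = 3 bits.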
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        // Extract the upper bits.
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), extractAmount,
            width - extractAmount);
        Value constZero = rewriter.create<hw::ConstantOp>(
            op.getLoc(), APInt::getZero(extractAmount));
        rewriter.replaceOpWithNewOp<comb::ConcatOp>(
            op, op.getType(), ArrayRef<Value>{constZero, upperBits});
        return success();
      }

    // When the divisor is not a power of two and the number of unknown bits
    // is small, create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.udiv(rhs);
        });
  }
};

struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
  using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
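    // E.g. taking an i8 value x modulo the constant 8 yields
    // {5'b00000, x[2:0]}, i.e. just the low log2(8) = 3 bits, zero-extended.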
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        // Extract the lower bits.
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), 0, extractAmount);
        Value constZero = rewriter.create<hw::ConstantOp>(
            op.getLoc(), APInt::getZero(width - extractAmount));
        rewriter.replaceOpWithNewOp<comb::ConcatOp>(
            op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
        return success();
      }

    // When the divisor is not a power of two and the number of unknown bits
    // is small, create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.urem(rhs);
        });
  }
};

struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
  using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently this is only lowered with emulation.
    // TODO: Implement a signed division lowering at least for powers of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.sdiv(rhs);
        });
  }
};

struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
  using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently this is only lowered with emulation.
    // TODO: Implement a signed modulus lowering at least for powers of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.srem(rhs);
        });
  }
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
  using OpConversionPattern<ICmpOp>::OpConversionPattern;
  static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
                                        ArrayRef<Value> bBits, bool isLess,
                                        bool includeEq,
                                        ConversionPatternRewriter &rewriter) {
    // Construct the following unsigned comparison expressions:
    // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
    // a <  b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <  b[n-1:0])
    // a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0])
    // a >  b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >  b[n-1:0])
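    // E.g. for a 2-bit ult, the fold below builds, starting from the LSB:
    //   acc0 = 0
    //   acc1 = (~a[0] & b[0]) | ((a[0] == b[0]) & acc0)
    //   res  = (~a[1] & b[1]) | ((a[1] == b[1]) & acc1)
    // so the most significant differing bit decides the result.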
    Value acc =
        rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      auto aBitXorBBit =
          rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
      auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
          op.getLoc(), aBitXorBBit, true);
      auto pred = rewriter.createOrFold<aig::AndInverterOp>(
          op.getLoc(), aBit, bBit, isLess, !isLess);

      auto eqAndAcc = rewriter.createOrFold<comb::AndOp>(
          op.getLoc(), ValueRange{aEqualB, acc}, true);
      acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, eqAndAcc,
                                              true);
    }
    return acc;
  }

  LogicalResult
  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();

    switch (op.getPredicate()) {
    default:
      return failure();

    case ICmpPredicate::eq:
    case ICmpPredicate::ceq: {
      // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      auto xorBits = extractBits(rewriter, xorOp);
      SmallVector<bool> allInverts(xorBits.size(), true);
      rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
      return success();
    }

    case ICmpPredicate::ne:
    case ICmpPredicate::cne: {
      // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      rewriter.replaceOpWithNewOp<comb::OrOp>(op, extractBits(rewriter, xorOp),
                                              true);
      return success();
    }

    case ICmpPredicate::uge:
    case ICmpPredicate::ugt:
    case ICmpPredicate::ule:
    case ICmpPredicate::ult: {
      bool isLess = op.getPredicate() == ICmpPredicate::ult ||
                    op.getPredicate() == ICmpPredicate::ule;
      bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
                       op.getPredicate() == ICmpPredicate::ule;
      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);
      rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
                                                      includeEq, rewriter));
      return success();
    }
    case ICmpPredicate::slt:
    case ICmpPredicate::sle:
    case ICmpPredicate::sgt:
    case ICmpPredicate::sge: {
      if (lhs.getType().getIntOrFloatBitWidth() == 0)
        return rewriter.notifyMatchFailure(
            op.getLoc(), "i0 signed comparison is unsupported");
      bool isLess = op.getPredicate() == ICmpPredicate::slt ||
                    op.getPredicate() == ICmpPredicate::sle;
      bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
                       op.getPredicate() == ICmpPredicate::sle;

      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);

      // Get the sign bits.
      auto signA = aBits.back();
      auto signB = bBits.back();

      // Compare the magnitudes (all bits except the sign).
      auto sameSignResult = constructUnsignedCompare(
          op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
          includeEq, rewriter);

      // XOR of the signs: true if the signs differ.
      auto signsDiffer =
          rewriter.create<comb::XorOp>(op.getLoc(), signA, signB);

      // Result when the signs differ.
      Value diffSignResult = isLess ? signA : signB;

      // Final result: choose based on whether the signs differ.
      rewriter.replaceOpWithNewOp<comb::MuxOp>(op, signsDiffer, diffSignResult,
                                               sameSignResult);
      return success();
    }
    }
  }
};

struct CombParityOpConversion : OpConversionPattern<ParityOp> {
  using OpConversionPattern<ParityOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Parity is the XOR of all bits.
    rewriter.replaceOpWithNewOp<comb::XorOp>(
        op, extractBits(rewriter, adaptor.getInput()), true);
    return success();
  }
};

struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
  using OpConversionPattern<comb::ShlOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/true>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for a left shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the LSB.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
                                                        width - index);
        });

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
  using OpConversionPattern<comb::ShrUOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for a logical right shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the MSB.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index);
        });

    rewriter.replaceOp(op, result);
    return success();
  }
};

struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
  using OpConversionPattern<comb::ShrSOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "i0 signed shift is unsupported");
    auto lhs = adaptor.getLhs();
    // Get the sign bit.
    auto sign =
        rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

    // NOTE: The max shift amount is width - 1 because any shift amount of
    // width - 1 or more yields the same result: the replicated sign bit.
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
        /*getPadding=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
                                                          index + 1);
        },
        /*getExtract=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index - 1);
        });

    rewriter.replaceOp(op, result);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Convert Comb to AIG pass
//===----------------------------------------------------------------------===//

namespace {
struct ConvertCombToAIGPass
    : public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
  void runOnOperation() override;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
  using ConvertCombToAIGBase<ConvertCombToAIGPass>::maxEmulationUnknownBits;
};
} // namespace

static void
populateCombToAIGConversionPatterns(RewritePatternSet &patterns,
                                    uint32_t maxEmulationUnknownBits) {
  patterns.add<
      // Bitwise Logical Ops
      CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
      CombMuxOpConversion, CombParityOpConversion,
      // Arithmetic Ops
      CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
      CombICmpOpConversion,
      // Shift Ops
      CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
      // Variadic ops that must be lowered to binary operations
      CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
      CombLowerVariadicOp<MulOp>>(patterns.getContext());

  // Add div/mod patterns with a threshold given by the pass option.
  patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
               CombModSOpConversion>(patterns.getContext(),
                                     maxEmulationUnknownBits);
}

void ConvertCombToAIGPass::runOnOperation() {
  ConversionTarget target(getContext());

  // Comb is the source dialect.
  target.addIllegalDialect<comb::CombDialect>();
  // Keep data movement operations like Extract, Concat and Replicate.
  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
                    hw::BitcastOp, hw::ConstantOp>();

  // Treat array operations as illegal. Strictly speaking, all array
  // operations other than an array get with a non-constant index are legal
  // in AIG, but array types prevent a bunch of optimizations, so just lower
  // them to integer operations. The HWAggregateToComb pass is required to
  // run before this pass.
  target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
                      hw::AggregateConstantOp>();

  // AIG is the target dialect.
  target.addLegalDialect<aig::AIGDialect>();

  // If additional legal ops are specified, add them to the target.
  if (!additionalLegalOps.empty())
    for (const auto &opName : additionalLegalOps)
      target.addLegalOp(OperationName(opName, &getContext()));

  RewritePatternSet patterns(&getContext());
  populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();
}