CIRCT 23.0.0git
Loading...
Searching...
No Matches
CombToSynth.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main Comb to Synth Conversion Pass Implementation.
10//
11// High-level Comb Operations
12// |
13// v
14// +-------------------+
15// | and, or, xor, mux |
16// +---------+---------+
17// |
18// +-----+
19// | AIG |
20// +-----+
21//
22//===----------------------------------------------------------------------===//
23
31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "llvm/ADT/APInt.h"
34#include "llvm/ADT/PointerUnion.h"
35#include "llvm/Support/Debug.h"
36#include "llvm/Support/DivisionByConstantInfo.h"
37#include <array>
38
39#define DEBUG_TYPE "comb-to-synth"
40
41namespace circt {
42#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
43#include "circt/Conversion/Passes.h.inc"
44} // namespace circt
45
46using namespace circt;
47using namespace comb;
48
49//===----------------------------------------------------------------------===//
50// Utility Functions
51//===----------------------------------------------------------------------===//
52
53// A wrapper for comb::extractBits that returns a SmallVector<Value>.
54static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
55 SmallVector<Value> bits;
56 comb::extractBits(builder, val, bits);
57 return bits;
58}
59
60// Construct a mux tree for shift operations. `isLeftShift` controls the
61// direction of the shift operation and is used to determine order of the
62// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
63// to get the padding and extracted bits for each shift amount. `getPadding`
64// could return a nullptr as i0 value but except for that, these callbacks must
65// return a valid value for each shift amount in the range [0, maxShiftAmount].
66// The value for `maxShiftAmount` is used as the out-of-bounds value.
67template <bool isLeftShift>
68static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
69 Value shiftAmount, int64_t maxShiftAmount,
70 llvm::function_ref<Value(int64_t)> getPadding,
71 llvm::function_ref<Value(int64_t)> getExtract) {
72 // Extract individual bits from shift amount
73 auto bits = extractBits(rewriter, shiftAmount);
74
75 // Create nodes for each possible shift amount
76 SmallVector<Value> nodes;
77 nodes.reserve(maxShiftAmount);
78 for (int64_t i = 0; i < maxShiftAmount; ++i) {
79 Value extract = getExtract(i);
80 Value padding = getPadding(i);
81
82 if (!padding) {
83 nodes.push_back(extract);
84 continue;
85 }
86
87 // Concatenate extracted bits with padding
88 if (isLeftShift)
89 nodes.push_back(
90 rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
91 else
92 nodes.push_back(
93 rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
94 }
95
96 // Create out-of-bounds value
97 auto outOfBoundsValue = getPadding(maxShiftAmount);
98 assert(outOfBoundsValue && "outOfBoundsValue must be valid");
99
100 // Construct mux tree for shift operation
101 auto result =
102 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
103
104 // Add bounds checking
105 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
106 loc, ICmpPredicate::ult, shiftAmount,
107 hw::ConstantOp::create(rewriter, loc, shiftAmount.getType(),
108 maxShiftAmount));
109
110 return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
111 outOfBoundsValue);
112}
113
114// Return a majority function implemented with Comb operations. `carry` has
115// slightly smaller depth than the other inputs.
116static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a,
117 Value b, Value carry) {
118 // maj(a, b, c) = (c & (a ^ b)) | (a & b)
119 auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b}, true);
120 auto andOp =
121 comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB}, true);
122 auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b}, true);
123 return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB}, true);
124}
125
126static Value extractMSB(OpBuilder &builder, Value val) {
127 return builder.createOrFold<comb::ExtractOp>(
128 val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
129}
130
131static Value extractOtherThanMSB(OpBuilder &builder, Value val) {
132 return builder.createOrFold<comb::ExtractOp>(
133 val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
134}
135
136namespace {
137// A union of Value and IntegerAttr to cleanly handle constant values.
138using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
139} // namespace
140
141// Return the number of unknown bits and populate the concatenated values.
143 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
144 // Constant or zero width value are all known.
145 if (value.getType().isInteger(0))
146 return 0;
147
148 // Recursively count unknown bits for concat.
149 if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
150 int64_t totalUnknownBits = 0;
151 for (auto concatInput : llvm::reverse(concat.getInputs())) {
152 auto unknownBits =
153 getNumUnknownBitsAndPopulateValues(concatInput, values);
154 if (unknownBits < 0)
155 return unknownBits;
156 totalUnknownBits += unknownBits;
157 }
158 return totalUnknownBits;
159 }
160
161 // Constant value is known.
162 if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
163 values.push_back(constant.getValueAttr());
164 return 0;
165 }
166
167 // Consider other operations as unknown bits.
168 // TODO: We can handle replicate, extract, etc.
169 values.push_back(value);
170 return hw::getBitWidth(value.getType());
171}
172
173// Return a value that substitutes the unknown bits with the mask.
174static APInt
176 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
177 uint32_t mask) {
178 uint32_t bitPos = 0, unknownPos = 0;
179 APInt result(width, 0);
180 for (auto constantOrValue : constantOrValues) {
181 int64_t elemWidth;
182 if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
183 elemWidth = constant.getValue().getBitWidth();
184 result.insertBits(constant.getValue(), bitPos);
185 } else {
186 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
187 assert(elemWidth >= 0 && "unknown bit width");
188 assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
189 // Create a mask for the unknown bits.
190 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
191 result.insertBits(APInt(elemWidth, usedBits), bitPos);
192 unknownPos += elemWidth;
193 }
194 bitPos += elemWidth;
195 }
196
197 return result;
198}
199
200// Emulate a binary operation with unknown bits using a table lookup.
201// This function enumerates all possible combinations of unknown bits and
202// emulates the operation for each combination.
203static LogicalResult emulateBinaryOpForUnknownBits(
204 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
205 Operation *op,
206 llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
207 SmallVector<ConstantOrValue> lhsValues, rhsValues;
208
209 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
210 "op must be a single result binary operation");
211
212 auto lhs = op->getOperand(0);
213 auto rhs = op->getOperand(1);
214 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
215 auto loc = op->getLoc();
216 auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
217 auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);
218
219 // If unknown bit width is detected, abort the lowering.
220 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
221 return failure();
222
223 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
224 if (totalUnknownBits > maxEmulationUnknownBits)
225 return failure();
226
227 SmallVector<Value> emulatedResults;
228 emulatedResults.reserve(1 << totalUnknownBits);
229
230 // Emulate all possible cases.
231 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
232 auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
233 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
234 auto it = constantPool.find(attr);
235 if (it != constantPool.end())
236 return it->second;
237 auto constant = hw::ConstantOp::create(rewriter, loc, value);
238 constantPool[attr] = constant;
239 return constant;
240 };
241
242 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
243 lhsMask < lhsMaskEnd; ++lhsMask) {
244 APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
245 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
246 rhsMask < rhsMaskEnd; ++rhsMask) {
247 APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
248 // Emulate.
249 emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
250 }
251 }
252
253 // Create selectors for mux tree.
254 SmallVector<Value> selectors;
255 selectors.reserve(totalUnknownBits);
256 for (auto &concatedValues : {rhsValues, lhsValues})
257 for (auto valueOrConstant : concatedValues) {
258 auto value = dyn_cast<Value>(valueOrConstant);
259 if (!value)
260 continue;
261 extractBits(rewriter, value, selectors);
262 }
263
264 assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
265 "number of selectors must match");
266 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
267 getConstant(APInt::getZero(width)));
268
269 replaceOpAndCopyNamehint(rewriter, op, muxed);
270 return success();
271}
272
273static Value createLShrByConstant(OpBuilder &builder, Location loc, Value value,
274 unsigned amount) {
275 if (amount == 0)
276 return value;
277 return builder.createOrFold<comb::ShrUOp>(
278 loc, value,
280 builder, loc,
281 APInt(value.getType().getIntOrFloatBitWidth(), amount)));
282}
283
284static Value createAShrByConstant(OpBuilder &builder, Location loc, Value value,
285 unsigned amount) {
286 if (amount == 0)
287 return value;
288 return builder.createOrFold<comb::ShrSOp>(
289 loc, value,
291 builder, loc,
292 APInt(value.getType().getIntOrFloatBitWidth(), amount)));
293}
294
295template <bool isSigned>
296static Value createMulHigh(OpBuilder &builder, Location loc, Value lhs,
297 const APInt &rhs) {
298 unsigned width = lhs.getType().getIntOrFloatBitWidth();
299 auto destTy = builder.getIntegerType(width << 1);
300 // Compute the high half of a double-width product. For signed division,
301 // sign-extend both operands so this acts like a signed multiply-high.
302 Value wideLhs = isSigned ? comb::createOrFoldSExt(builder, loc, lhs, destTy)
303 : comb::createZExt(builder, loc, lhs, width << 1);
304 Value wideRhs = hw::ConstantOp::create(
305 builder, loc, isSigned ? rhs.sext(width << 1) : rhs.zext(width << 1));
306 Value product = builder.createOrFold<comb::MulOp>(
307 loc, ValueRange{wideLhs, wideRhs}, /*twoState=*/true);
308 return builder.createOrFold<comb::ExtractOp>(loc, product, width, width);
309}
310
311static Value lowerUnsignedDivByConstant(OpBuilder &builder, Location loc,
312 Value lhs, const APInt &divisor) {
313 auto info = llvm::UnsignedDivisionByConstantInfo::get(divisor);
314 Value q = createLShrByConstant(builder, loc, lhs, info.PreShift);
315 q = createMulHigh<false>(builder, loc, q, info.Magic);
316 if (info.IsAdd) {
317 Value diff = builder.createOrFold<comb::SubOp>(loc, lhs, q);
318 diff = createLShrByConstant(builder, loc, diff, 1);
319 q = builder.createOrFold<comb::AddOp>(loc, q, diff);
320 }
321 return createLShrByConstant(builder, loc, q, info.PostShift);
322}
323
324static Value lowerSignedDivByConstant(OpBuilder &builder, Location loc,
325 Value lhs, const APInt &divisor) {
326 unsigned width = lhs.getType().getIntOrFloatBitWidth();
327 auto info = llvm::SignedDivisionByConstantInfo::get(divisor);
328 Value q = createMulHigh<true>(builder, loc, lhs, info.Magic);
329 // Depending on the magic constant the signed magic may need to
330 // add or subtract the dividend before the final shift.
331 if (divisor.isStrictlyPositive() && info.Magic.isNegative())
332 q = builder.createOrFold<comb::AddOp>(loc, q, lhs);
333 else if (divisor.isNegative() && info.Magic.isStrictlyPositive())
334 q = builder.createOrFold<comb::SubOp>(loc, q, lhs);
335 q = createAShrByConstant(builder, loc, q, info.ShiftAmount);
336 // Signed division rounds to zero. Add one back for negative tentative
337 // quotients after the arithmetic shift.
338 Value signBit = builder.createOrFold<comb::ExtractOp>(loc, q, width - 1, 1);
339 Value signPadded = comb::createZExt(builder, loc, signBit, width);
340 return builder.createOrFold<comb::AddOp>(loc, q, signPadded);
341}
342
343//===----------------------------------------------------------------------===//
344// Conversion patterns
345//===----------------------------------------------------------------------===//
346
347namespace {
348
349/// Lower a comb::AndOp operation to synth::aig::AndInverterOp
350struct CombAndOpConversion : OpConversionPattern<AndOp> {
352
353 LogicalResult
354 matchAndRewrite(AndOp op, OpAdaptor adaptor,
355 ConversionPatternRewriter &rewriter) const override {
356 SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
357 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
358 rewriter, op, adaptor.getInputs(), nonInverts);
359 return success();
360 }
361};
362
363/// Lower a comb::OrOp operation to synth::aig::AndInverterOp with invert flags
364struct CombOrToAIGConversion : OpConversionPattern<OrOp> {
366
367 LogicalResult
368 matchAndRewrite(OrOp op, OpAdaptor adaptor,
369 ConversionPatternRewriter &rewriter) const override {
370 // Implement Or using And and invert flags: a | b = ~(~a & ~b)
371 SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
372 auto andOp = synth::aig::AndInverterOp::create(
373 rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
374 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
375 rewriter, op, andOp,
376 /*invert=*/true);
377 return success();
378 }
379};
380
381struct CombXorOpToSynthConversion : OpConversionPattern<XorOp> {
383
384 LogicalResult
385 matchAndRewrite(XorOp op, OpAdaptor adaptor,
386 ConversionPatternRewriter &rewriter) const override {
387 SmallVector<bool> inverted(adaptor.getInputs().size(), false);
388 replaceOpWithNewOpAndCopyNamehint<synth::XorInverterOp>(
389 rewriter, op, adaptor.getInputs(), inverted);
390 return success();
391 }
392};
393
394/// Lower a synth::XorOp operation to AIG operations
395struct SynthXorInverterOpConversion
396 : OpConversionPattern<synth::XorInverterOp> {
397 using OpConversionPattern<synth::XorInverterOp>::OpConversionPattern;
398
399 LogicalResult
400 matchAndRewrite(synth::XorInverterOp op, OpAdaptor adaptor,
401 ConversionPatternRewriter &rewriter) const override {
402 if (op.getNumOperands() != 2)
403 return failure();
404 // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)
405
406 // (a | b) = ~(~a & ~b)
407 // (~a | ~b) = ~(a & b)
408 auto inputs = adaptor.getInputs();
409 auto allNotInverts = op.getInverted();
410 std::array<bool, 2> allInverts = {!allNotInverts[0], !allNotInverts[1]};
411
412 auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
413 inputs, allInverts);
414 auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
415 inputs, allNotInverts);
416
417 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
418 rewriter, op, notAAndNotB, aAndB,
419 /*lhs_invert=*/true,
420 /*rhs_invert=*/true);
421 return success();
422 }
423};
424
425/// Lower a comb::MuxOp operation to synth::MuxInverterOps.
426struct CombMuxOpToSynthConversion : OpConversionPattern<MuxOp> {
428
429 LogicalResult
430 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
431 ConversionPatternRewriter &rewriter) const override {
432 Value cond = adaptor.getCond();
433 Value trueVal = adaptor.getTrueValue();
434 Value falseVal = adaptor.getFalseValue();
435
436 if (!op.getType().isInteger()) {
437 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
438 trueVal =
439 hw::BitcastOp::create(rewriter, op.getLoc(), widthType, trueVal);
440 falseVal =
441 hw::BitcastOp::create(rewriter, op.getLoc(), widthType, falseVal);
442 }
443
444 if (!trueVal.getType().isInteger(1))
445 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
446 cond);
447
448 Value result = synth::MuxInverterOp::create(rewriter, op.getLoc(), cond,
449 trueVal, falseVal);
450
451 if (result.getType() != op.getType())
452 result =
453 hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
454
455 replaceOpAndCopyNamehint(rewriter, op, result);
456 return success();
457 }
458};
459
460/// Lower a synth::MuxInverterOp operation to AIG operations.
461struct SynthMuxInverterOpConversion
462 : OpConversionPattern<synth::MuxInverterOp> {
463 using OpConversionPattern<synth::MuxInverterOp>::OpConversionPattern;
464
465 LogicalResult
466 matchAndRewrite(synth::MuxInverterOp op, OpAdaptor adaptor,
467 ConversionPatternRewriter &rewriter) const override {
468 auto inputs = adaptor.getInputs();
469 auto inverted = op.getInverted();
470
471 auto lhs = synth::aig::AndInverterOp::create(
472 rewriter, op.getLoc(), inputs[0], inputs[1], inverted[0], inverted[1]);
473
474 auto rhs = synth::aig::AndInverterOp::create(
475 rewriter, op.getLoc(), inputs[0], inputs[2], !inverted[0], inverted[2]);
476
477 auto nand = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), lhs,
478 rhs, true, true);
479 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(rewriter, op,
480 nand, true);
481 return success();
482 }
483};
484
485template <typename OpTy>
486struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
488 using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
489 LogicalResult
490 matchAndRewrite(OpTy op, OpAdaptor adaptor,
491 ConversionPatternRewriter &rewriter) const override {
492 auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
493 replaceOpAndCopyNamehint(rewriter, op, result);
494 return success();
495 }
496
497 static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
498 ConversionPatternRewriter &rewriter) {
499 Value lhs, rhs;
500 switch (operands.size()) {
501 case 0:
502 llvm_unreachable("cannot be called with empty operand range");
503 break;
504 case 1:
505 return operands[0];
506 case 2:
507 lhs = operands[0];
508 rhs = operands[1];
509 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
510 default:
511 auto firstHalf = operands.size() / 2;
512 lhs =
513 lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
514 rhs =
515 lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
516 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
517 }
518 }
519};
520
521//===----------------------------------------------------------------------===//
522// Adder Architecture Selection
523//===----------------------------------------------------------------------===//
524
525enum AdderArchitecture { RippleCarry, Sklanskey, KoggeStone, BrentKung };
526AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
527 auto strAttr = op->getAttrOfType<StringAttr>("synth.test.arch");
528 if (strAttr) {
529 return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
530 .Case("SKLANSKEY", Sklanskey)
531 .Case("KOGGE-STONE", KoggeStone)
532 .Case("BRENT-KUNG", BrentKung)
533 .Case("RIPPLE-CARRY", RippleCarry);
534 }
535 // Determine using width as a heuristic.
536 // TODO: Perform a more thorough analysis to motivate the choices or
537 // implement an adder synthesis algorithm to construct an optimal adder
538 // under the given timing constraints - see the work of Zimmermann
539
540 // For very small adders, overhead of a parallel prefix adder is likely not
541 // worth it.
542 if (width < 8)
543 return AdderArchitecture::RippleCarry;
544
545 // Sklanskey is a good compromise for high-performance, but has high fanout
546 // which may lead to wiring congestion for very large adders.
547 if (width <= 32)
548 return AdderArchitecture::Sklanskey;
549
550 // Kogge-Stone uses greater area than Sklanskey but has lower fanout thus
551 // may be preferable for larger adders.
552 return AdderArchitecture::KoggeStone;
553}
554
555//===----------------------------------------------------------------------===//
556// Parallel Prefix Tree
557//===----------------------------------------------------------------------===//
558
559// Implement the Kogge-Stone parallel prefix tree
560// Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
561// Slightly better delay than Brent-Kung, but more area.
562void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
563 SmallVector<Value> &pPrefix,
564 SmallVector<Value> &gPrefix) {
565
566 auto width = static_cast<int64_t>(pPrefix.size());
567 assert(width == static_cast<int64_t>(gPrefix.size()));
568 SmallVector<Value> pPrefixNew = pPrefix;
569 SmallVector<Value> gPrefixNew = gPrefix;
570
571 // Kogge-Stone parallel prefix computation
572 for (int64_t stride = 1; stride < width; stride *= 2) {
573
574 for (int64_t i = stride; i < width; ++i) {
575 int64_t j = i - stride;
576
577 // Group generate: g_i OR (p_i AND g_j)
578 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
579 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
580
581 // Group propagate: p_i AND p_j
582 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
583 }
584
585 pPrefix = pPrefixNew;
586 gPrefix = gPrefixNew;
587 }
588
589 LLVM_DEBUG({
590 int64_t stage = 0;
591 for (int64_t stride = 1; stride < width; stride *= 2) {
592 llvm::dbgs()
593 << "--------------------------------------- Kogge-Stone Stage "
594 << stage << "\n";
595 for (int64_t i = stride; i < width; ++i) {
596 int64_t j = i - stride;
597 // Group generate: g_i OR (p_i AND g_j)
598 llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
599 << " OR (P" << i << stage << " AND G" << j << stage
600 << ")\n";
601
602 // Group propagate: p_i AND p_j
603 llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
604 << " AND P" << j << stage << "\n";
605 }
606 ++stage;
607 }
608 });
609}
610
611// Implement the Sklansky parallel prefix tree
612// High fan-out, low depth, low area
613void lowerSklanskeyPrefixTree(OpBuilder &builder, Location loc,
614 SmallVector<Value> &pPrefix,
615 SmallVector<Value> &gPrefix) {
616 auto width = static_cast<int64_t>(pPrefix.size());
617 assert(width == static_cast<int64_t>(gPrefix.size()));
618 SmallVector<Value> pPrefixNew = pPrefix;
619 SmallVector<Value> gPrefixNew = gPrefix;
620 for (int64_t stride = 1; stride < width; stride *= 2) {
621 for (int64_t i = stride; i < width; i += 2 * stride) {
622 for (int64_t k = 0; k < stride && i + k < width; ++k) {
623 int64_t idx = i + k;
624 int64_t j = i - 1;
625
626 // Group generate: g_idx OR (p_idx AND g_j)
627 Value andPG =
628 comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
629 gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);
630
631 // Group propagate: p_idx AND p_j
632 pPrefixNew[idx] =
633 comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
634 }
635 }
636
637 pPrefix = pPrefixNew;
638 gPrefix = gPrefixNew;
639 }
640
641 LLVM_DEBUG({
642 int64_t stage = 0;
643 for (int64_t stride = 1; stride < width; stride *= 2) {
644 llvm::dbgs() << "--------------------------------------- Sklanskey Stage "
645 << stage << "\n";
646 for (int64_t i = stride; i < width; i += 2 * stride) {
647 for (int64_t k = 0; k < stride && i + k < width; ++k) {
648 int64_t idx = i + k;
649 int64_t j = i - 1;
650 // Group generate: g_i OR (p_i AND g_j)
651 llvm::dbgs() << "G" << idx << stage + 1 << " = G" << idx << stage
652 << " OR (P" << idx << stage << " AND G" << j << stage
653 << ")\n";
654
655 // Group propagate: p_i AND p_j
656 llvm::dbgs() << "P" << idx << stage + 1 << " = P" << idx << stage
657 << " AND P" << j << stage << "\n";
658 }
659 }
660 ++stage;
661 }
662 });
663}
664
665// Implement the Brent-Kung parallel prefix tree
666// Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
667// Slightly worse delay than Kogge-Stone, but less area.
668void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
669 SmallVector<Value> &pPrefix,
670 SmallVector<Value> &gPrefix) {
671 auto width = static_cast<int64_t>(pPrefix.size());
672 assert(width == static_cast<int64_t>(gPrefix.size()));
673 SmallVector<Value> pPrefixNew = pPrefix;
674 SmallVector<Value> gPrefixNew = gPrefix;
675 // Brent-Kung parallel prefix computation
676 // Forward phase
677 int64_t stride;
678 for (stride = 1; stride < width; stride *= 2) {
679 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
680 int64_t j = i - stride;
681
682 // Group generate: g_i OR (p_i AND g_j)
683 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
684 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
685
686 // Group propagate: p_i AND p_j
687 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
688 }
689 pPrefix = pPrefixNew;
690 gPrefix = gPrefixNew;
691 }
692
693 // Backward phase
694 for (; stride > 0; stride /= 2) {
695 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
696 int64_t j = i - stride;
697
698 // Group generate: g_i OR (p_i AND g_j)
699 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
700 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
701
702 // Group propagate: p_i AND p_j
703 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
704 }
705 pPrefix = pPrefixNew;
706 gPrefix = gPrefixNew;
707 }
708
709 LLVM_DEBUG({
710 int64_t stage = 0;
711 for (stride = 1; stride < width; stride *= 2) {
712 llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
713 << stage << " : Stride " << stride << "\n";
714 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
715 int64_t j = i - stride;
716
717 // Group generate: g_i OR (p_i AND g_j)
718 llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
719 << " OR (P" << i << stage << " AND G" << j << stage
720 << ")\n";
721
722 // Group propagate: p_i AND p_j
723 llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
724 << " AND P" << j << stage << "\n";
725 }
726 ++stage;
727 }
728
729 for (; stride > 0; stride /= 2) {
730 if (stride * 3 - 1 < width)
731 llvm::dbgs() << "--------------------------------------- Brent-Kung BW "
732 << stage << " : Stride " << stride << "\n";
733
734 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
735 int64_t j = i - stride;
736
737 // Group generate: g_i OR (p_i AND g_j)
738 llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
739 << " OR (P" << i << stage << " AND G" << j << stage
740 << ")\n";
741
742 // Group propagate: p_i AND p_j
743 llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
744 << " AND P" << j << stage << "\n";
745 }
746 --stage;
747 }
748 });
749}
750
751// TODO: Generalize to other parallel prefix trees.
752class LazyKoggeStonePrefixTree {
753public:
754 LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
755 ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
756 : builder(builder), loc(loc), width(width) {
757 assert(width > 0 && "width must be positive");
758 for (int64_t i = 0; i < width; ++i)
759 prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
760 }
761
762 // Get the final group and propagate values for bit i.
763 std::pair<Value, Value> getFinal(int64_t i) {
764 assert(i >= 0 && i < width && "i out of bounds");
765 // Final level is ceil(log2(width)) in Kogge-Stone.
766 return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
767 }
768
769private:
770 // Recursively get the group and propagate values for bit i at level `level`.
771 // Level 0 is the initial level with the input propagate and generate values.
772 // Level n computes the group and propagate values for a stride of 2^(n-1).
773 // Uses memoization to cache intermediate results.
774 std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
775 OpBuilder &builder;
776 Location loc;
777 int64_t width;
778 DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
779};
780
781std::pair<Value, Value>
782LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
783 assert(i < width && "i out of bounds");
784 auto key = std::make_pair(level, i);
785 auto it = prefixCache.find(key);
786 if (it != prefixCache.end())
787 return it->second;
788
789 assert(level > 0 && "If the level is 0, we should have hit the cache");
790
791 int64_t previousStride = 1ULL << (level - 1);
792 if (i < previousStride) {
793 // No dependency, just copy from the previous level.
794 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
795 prefixCache[key] = {propagateI, generateI};
796 return prefixCache[key];
797 }
798 // Get the dependency index.
799 int64_t j = i - previousStride;
800 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
801 auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
802 // Group generate: g_i OR (p_i AND g_j)
803 Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
804 Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
805 // Group propagate: p_i AND p_j
806 Value newPropagate =
807 comb::AndOp::create(builder, loc, propagateI, propagateJ);
808 prefixCache[key] = {newPropagate, newGenerate};
809 return prefixCache[key];
810}
811
// Conversion pattern for a two-operand comb.add. The addition is expanded bit
// by bit into XOR/AND logic, either as a ripple-carry chain or as a
// parallel-prefix network, depending on what determineAdderArch returns.
812struct CombAddOpConversion : OpConversionPattern<AddOp> {
// NOTE(review): listing line 813 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
814
    // Entry point. Fails to match unless there are exactly two inputs, replaces
    // a zero-width add with a constant 0, and otherwise dispatches to one of
    // the two lowering helpers below.
815 LogicalResult
816 matchAndRewrite(AddOp op, OpAdaptor adaptor,
817 ConversionPatternRewriter &rewriter) const override {
818 auto inputs = adaptor.getInputs();
819 // Lower only when there are two inputs.
820 // Variadic operands must be lowered in a different pattern.
821 if (inputs.size() != 2)
822 return failure();
823
824 auto width = op.getType().getIntOrFloatBitWidth();
825 // Skip a zero width value.
826 if (width == 0) {
827 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
828 op.getType(), 0);
829 return success();
830 }
831
832 // Check if the architecture is specified by an attribute.
833 auto arch = determineAdderArch(op, width);
834 if (arch == AdderArchitecture::RippleCarry)
835 return lowerRippleCarryAdder(op, inputs, rewriter);
836 return lowerParallelPrefixAdder(op, inputs, rewriter);
837 }
838
839 // Implement a basic ripple-carry adder for small bitwidths.
    // Each sum bit is the XOR of the two operand bits and, from the second bit
    // on, the running carry. The carry chain stops before the last bit, so no
    // carry-out is produced. Always returns success().
840 LogicalResult
841 lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
842 ConversionPatternRewriter &rewriter) const {
843 auto width = op.getType().getIntOrFloatBitWidth();
844 // Implement a naive Ripple-carry full adder.
    // `carry` stays null until the first bit has been processed; a null value
    // means "no carry-in" in the loop below.
845 Value carry;
846
847 auto aBits = extractBits(rewriter, inputs[0]);
848 auto bBits = extractBits(rewriter, inputs[1]);
849 SmallVector<Value> results;
850 results.resize(width);
851 for (int64_t i = 0; i < width; ++i) {
852 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
853 if (carry)
854 xorOperands.push_back(carry);
855
856 // sum[i] = xor(carry[i-1], a[i], b[i])
857 // NOTE: The result is stored in reverse order.
858 results[width - i - 1] =
859 comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);
860
861 // If this is the last bit, we are done.
862 if (i == width - 1)
863 break;
864
865 // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
866 if (!carry) {
867 // This is the first bit, so the carry is the next carry.
    // With no carry-in, the carry-out reduces to a[0] & b[0].
868 carry = comb::AndOp::create(rewriter, op.getLoc(),
869 ValueRange{aBits[i], bBits[i]}, true);
870 continue;
871 }
872
873 carry = createMajorityFunction(rewriter, op.getLoc(), aBits[i], bBits[i],
874 carry);
875 }
876 LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
877 << width << "\n");
878
879 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
880 return success();
881 }
882
883 // Implement a parallel prefix adder - with Sklanskey, Kogge-Stone or
    // Brent-Kung trees, selected by determineAdderArch.
884 // Will introduce unused signals for the carry bits but these will be removed
885 // by the AIG pass.
    // Must not be called when the selected architecture is RippleCarry (that
    // case is llvm_unreachable below). Always returns success().
886 LogicalResult
887 lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
888 ConversionPatternRewriter &rewriter) const {
889 auto width = op.getType().getIntOrFloatBitWidth();
890
891 auto aBits = extractBits(rewriter, inputs[0]);
892 auto bBits = extractBits(rewriter, inputs[1]);
893
894 // Construct propagate (p) and generate (g) signals
895 SmallVector<Value> p, g;
896 p.reserve(width);
897 g.reserve(width);
898
899 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
900 // p_i = a_i XOR b_i
901 p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
902 // g_i = a_i AND b_i
903 g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
904 }
905
906 LLVM_DEBUG({
907 llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
908 << "\n--------------------------------------- Init\n";
909
910 for (int64_t i = 0; i < width; ++i) {
911 // p_i = a_i XOR b_i
912 llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
913 // g_i = a_i AND b_i
914 llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
915 }
916 });
917
918 // Create copies of p and g for the prefix computation
    // The prefix-tree helpers receive the copies; the original `p` is still
    // needed for the final sum bits.
919 SmallVector<Value> pPrefix = p;
920 SmallVector<Value> gPrefix = g;
921
922 // Check if the architecture is specified by an attribute.
923 auto arch = determineAdderArch(op, width);
924
925 switch (arch) {
926 case AdderArchitecture::RippleCarry:
927 llvm_unreachable("Ripple-Carry should be handled separately");
928 break;
929 case AdderArchitecture::Sklanskey:
930 lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
931 break;
932 case AdderArchitecture::KoggeStone:
933 lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
934 break;
935 case AdderArchitecture::BrentKung:
936 lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
937 break;
938 }
939
940 // Generate result sum bits
941 // NOTE: The result is stored in reverse order.
942 SmallVector<Value> results;
943 results.resize(width);
944 // Sum bit 0 is just p[0] since carry_in = 0
945 results[width - 1] = p[0];
946
947 // For remaining bits, sum_i = p_i XOR g_(i-1)
948 // The carry into position i is the group generate from position i-1
949 for (int64_t i = 1; i < width; ++i)
950 results[width - 1 - i] =
951 comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
952
953 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
954
955 LLVM_DEBUG({
956 llvm::dbgs() << "--------------------------------------- Completion\n"
957 << "RES0 = P0\n";
958 for (int64_t i = 1; i < width; ++i)
959 llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
960 });
961
962 return success();
963 }
964};
965
// Conversion pattern for a two-operand comb.mul. It builds one row of AND
// partial products per bit of the second operand, reduces the rows to two
// addends with a datapath::CompressorTree, and emits a comb.add of the two
// addends.
966struct CombMulOpConversion : OpConversionPattern<MulOp> {
// NOTE(review): listing line 967 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
968 using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
    // Fails to match unless there are exactly two inputs. Zero-width and
    // one-bit multiplies are replaced directly, without a compressor tree.
969 LogicalResult
970 matchAndRewrite(MulOp op, OpAdaptor adaptor,
971 ConversionPatternRewriter &rewriter) const override {
972 if (adaptor.getInputs().size() != 2)
973 return failure();
974
975 Location loc = op.getLoc();
976 Value a = adaptor.getInputs()[0];
977 Value b = adaptor.getInputs()[1];
978 unsigned width = op.getType().getIntOrFloatBitWidth();
979
980 // Skip a zero width value.
    // NOTE(review): this path and the width == 1 path below call
    // rewriter.replaceOpWithNewOp / rewriter.replaceOp directly, whereas the
    // final path and the other patterns in this file use the
    // *AndCopyNamehint helpers -- confirm whether dropping the name hint here
    // is intentional.
981 if (width == 0) {
982 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
983 return success();
984 }
985
986 // Extract individual bits from operands
987 SmallVector<Value> aBits = extractBits(rewriter, a);
988 SmallVector<Value> bBits = extractBits(rewriter, b);
989
    // Single 1-bit zero constant used to pad the low end of each row.
990 auto falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
991
992 // Generate partial products
    // Row i starts with i zero bits followed by a[j] & b[i] for every j with
    // i + j < width, so each row has exactly `width` entries and product bits
    // at position >= width are never created.
993 SmallVector<SmallVector<Value>> partialProducts;
994 partialProducts.reserve(width);
995 for (unsigned i = 0; i < width; ++i) {
996 SmallVector<Value> row(i, falseValue);
997 row.reserve(width);
998 // Generate partial product bits
999 for (unsigned j = 0; i + j < width; ++j)
1000 row.push_back(
1001 rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));
1002
1003 partialProducts.push_back(row);
1004 }
1005
1006 // If the width is 1, we are done.
    // The only partial product, a[0] & b[0], is the whole result.
1007 if (width == 1) {
1008 rewriter.replaceOp(op, partialProducts[0][0]);
1009 return success();
1010 }
1011
1012 // Wallace tree reduction - reduce to two addends.
1013 datapath::CompressorTree comp(width, partialProducts, loc);
1014 auto addends = comp.compressToHeight(rewriter, 2);
1015
1016 // Sum the two addends using a carry-propagate adder
1017 auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
1018 replaceOpAndCopyNamehint(rewriter, op, newAdd);
1019 return success();
1020 }
1021};
1022
// Common base for the division/remainder conversion patterns. It stores the
// `maxEmulationUnknownBits` threshold that the derived patterns forward to
// emulateBinaryOpForUnknownBits.
1023template <typename OpTy>
1024struct DivModOpConversionBase : OpConversionPattern<OpTy> {
1025 DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
// NOTE(review): listing line 1026 (the start of the member-initializer list)
// is absent from this extract; confirm against the upstream file.
1027 maxEmulationUnknownBits(maxEmulationUnknownBits) {
    // The threshold must stay below 32; this is only checked in builds with
    // assertions enabled.
1028 assert(maxEmulationUnknownBits < 32 &&
1029 "maxEmulationUnknownBits must be less than 32");
1030 }
    // Threshold passed unchanged to the emulation fallback.
1031 const int64_t maxEmulationUnknownBits;
1032};
1033
// Conversion pattern for comb.divu. Strategies are tried in this order:
//  1. comb::convertDivUByPowerOfTwo,
//  2. a constant divisor (zero yields constant 0, anything else goes through
//     lowerUnsignedDivByConstant),
//  3. emulation bounded by `maxEmulationUnknownBits`.
1034struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
1035 using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
1036 LogicalResult
1037 matchAndRewrite(DivUOp op, OpAdaptor adaptor,
1038 ConversionPatternRewriter &rewriter) const override {
1039 // Check if the divisor is a power of two.
1040 if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
1041 return success();
1042
1043 // Lower constant divisors with magic-number division; otherwise fall back
1044 // to emulation for small rhs values.
1045 if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
1046 APInt divisor = rhsConst.getValue();
1047 // Division by zero is undefined, just return zero.
1048 if (divisor.isZero()) {
1049 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
1050 op.getType(), 0);
1051 return success();
1052 }
1053 replaceOpAndCopyNamehint(rewriter, op,
1054 lowerUnsignedDivByConstant(rewriter, op.getLoc(),
1055 adaptor.getLhs(),
1056 divisor));
1057 return success();
1058 }
1059
1060 // When rhs is not power of two and the number of unknown bits are small,
1061 // create a mux tree that emulates all possible cases.
// NOTE(review): listing line 1062 is absent from this extract -- presumably
// `return emulateBinaryOpForUnknownBits(`; confirm against the upstream file.
1063 rewriter, maxEmulationUnknownBits, op,
1064 [](const APInt &lhs, const APInt &rhs) {
1065 // Division by zero is undefined, just return zero.
1066 if (rhs.isZero())
1067 return APInt::getZero(rhs.getBitWidth());
1068 return lhs.udiv(rhs);
1069 });
1070 }
1071};
1072
// Conversion pattern for comb.modu. Strategies are tried in this order:
//  1. comb::convertModUByPowerOfTwo,
//  2. a constant divisor (zero yields constant 0, otherwise the remainder is
//     rebuilt as lhs - (lhs / rhs) * rhs),
//  3. emulation bounded by `maxEmulationUnknownBits`.
1073struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
1074 using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
1075 LogicalResult
1076 matchAndRewrite(ModUOp op, OpAdaptor adaptor,
1077 ConversionPatternRewriter &rewriter) const override {
1078 // Check if the divisor is a power of two.
1079 if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
1080 return success();
1081
1082 // Lower constant divisors by calculating q = lhs / rhs and returning
1083 // lhs - q * rhs; otherwise fall back to emulation for small rhs values.
1084 if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
1085 APInt divisor = rhsConst.getValue();
1086 // Remainder by zero is undefined, just return zero.
1087 if (divisor.isZero()) {
1088 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
1089 op.getType(), 0);
1090 return success();
1091 }
1092 auto loc = op.getLoc();
    // The quotient comes from the constant-divisor helper. The multiply and
    // subtract are emitted as comb ops (or folded) and are left for the other
    // patterns in this conversion to lower.
1093 Value q =
1094 lowerUnsignedDivByConstant(rewriter, loc, adaptor.getLhs(), divisor);
1095 Value product =
1096 rewriter.createOrFold<comb::MulOp>(loc, q, adaptor.getRhs());
1097 Value remainder =
1098 rewriter.createOrFold<comb::SubOp>(loc, adaptor.getLhs(), product);
1099 replaceOpAndCopyNamehint(rewriter, op, remainder);
1100
1101 return success();
1102 }
1103
1104 // When rhs is not power of two and the number of unknown bits are small,
1105 // create a mux tree that emulates all possible cases.
// NOTE(review): listing line 1106 is absent from this extract -- presumably
// `return emulateBinaryOpForUnknownBits(`; confirm against the upstream file.
1107 rewriter, maxEmulationUnknownBits, op,
1108 [](const APInt &lhs, const APInt &rhs) {
1109 // Division by zero is undefined, just return zero.
1110 if (rhs.isZero())
1111 return APInt::getZero(rhs.getBitWidth());
1112 return lhs.urem(rhs);
1113 });
1114 }
1115};
1116
// Conversion pattern for comb.divs. A constant divisor is special-cased for
// 0 (result 0), 1 (result lhs) and all-ones (result 0 - lhs); any other
// constant goes through lowerSignedDivByConstant. Non-constant divisors use
// the emulation fallback bounded by `maxEmulationUnknownBits`.
1117struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
1118 using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;
1119
1120 LogicalResult
1121 matchAndRewrite(DivSOp op, OpAdaptor adaptor,
1122 ConversionPatternRewriter &rewriter) const override {
1123 // Lower constant divisors with magic-number division; otherwise fall back
1124 // to emulation for small rhs values.
1125 if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
1126 APInt divisor = rhsConst.getValue();
1127 unsigned width = op.getType().getIntOrFloatBitWidth();
1128 // Division by zero is undefined, just return zero.
1129 if (divisor.isZero()) {
1130 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
1131 op.getType(), 0);
1132 return success();
1133 }
1134 // divs(lhs, 1) = lhs.
    // This check runs before the all-ones check below, so a 1-bit divisor
    // with value 1 (which is also all-ones) takes this branch.
1135 if (divisor.isOne()) {
1136 replaceOpAndCopyNamehint(rewriter, op, adaptor.getLhs());
1137 return success();
1138 }
1139 // divs(lhs, -1) = -lhs = sub(0, lhs).
1140 if (divisor.isAllOnes()) {
// NOTE(review): listing line 1141 is absent from this extract -- presumably
// `replaceOpAndCopyNamehint(`; confirm against the upstream file.
1142 rewriter, op,
1143 rewriter.createOrFold<comb::SubOp>(
1144 op.getLoc(),
1145 hw::ConstantOp::create(rewriter, op.getLoc(),
1146 APInt::getZero(width)),
1147 adaptor.getLhs()));
1148 return success();
1149 }
1150 replaceOpAndCopyNamehint(rewriter, op,
1151 lowerSignedDivByConstant(rewriter, op.getLoc(),
1152 adaptor.getLhs(),
1153 divisor));
1154 return success();
1155 }
1156
// NOTE(review): listing line 1157 is absent from this extract -- presumably
// `return emulateBinaryOpForUnknownBits(`; confirm against the upstream file.
1158 rewriter, maxEmulationUnknownBits, op,
1159 [](const APInt &lhs, const APInt &rhs) {
1160 // Division by zero is undefined, just return zero.
1161 if (rhs.isZero())
1162 return APInt::getZero(rhs.getBitWidth());
1163 return lhs.sdiv(rhs);
1164 });
1165 }
1166};
1167
// Conversion pattern for comb.mods. A constant divisor of 0, 1 or all-ones
// yields constant 0; any other constant rebuilds the remainder as
// lhs - (lhs / rhs) * rhs using lowerSignedDivByConstant for the quotient.
// Non-constant divisors use the emulation fallback bounded by
// `maxEmulationUnknownBits`.
1168struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
1169 using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
1170 LogicalResult
1171 matchAndRewrite(ModSOp op, OpAdaptor adaptor,
1172 ConversionPatternRewriter &rewriter) const override {
1173 // Lower constant divisors by calculating q = lhs / rhs and returning
1174 // lhs - q * rhs; otherwise fall back to emulation for small rhs values.
1175 if (auto rhsConst = adaptor.getRhs().getDefiningOp<hw::ConstantOp>()) {
1176 APInt divisor = rhsConst.getValue();
1177 // Remainder by 0 is undefined; remainder by +/-1 is always zero.
1178 if (divisor.isZero() || divisor.isOne() || divisor.isAllOnes()) {
1179 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
1180 op.getType(), 0);
1181 return success();
1182 }
1183 auto loc = op.getLoc();
    // The multiply and subtract are emitted as comb ops (or folded) and are
    // left for the other patterns in this conversion to lower.
1184 Value q =
1185 lowerSignedDivByConstant(rewriter, loc, adaptor.getLhs(), divisor);
1186 Value product =
1187 rewriter.createOrFold<comb::MulOp>(loc, q, adaptor.getRhs());
1188 Value remainder =
1189 rewriter.createOrFold<comb::SubOp>(loc, adaptor.getLhs(), product);
1190 replaceOpAndCopyNamehint(rewriter, op, remainder);
1191 return success();
1192 }
1193
// NOTE(review): listing line 1194 is absent from this extract -- presumably
// `return emulateBinaryOpForUnknownBits(`; confirm against the upstream file.
1195 rewriter, maxEmulationUnknownBits, op,
1196 [](const APInt &lhs, const APInt &rhs) {
1197 // Division by zero is undefined, just return zero.
1198 if (rhs.isZero())
1199 return APInt::getZero(rhs.getBitWidth());
1200 return lhs.srem(rhs);
1201 });
1202 }
1203};
1204
// Conversion pattern for comb.icmp. eq/ceq and ne/cne reduce the bitwise XOR
// of the operands. Unsigned orderings build a ripple or parallel-prefix
// comparator. Signed orderings compare the sign bits separately and reuse the
// unsigned comparator for the remaining bits. Any other predicate fails to
// match.
1205struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
// NOTE(review): listing line 1206 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
1207
1208 // Simple comparator for small bit widths
    // Builds `a < b` (or `a <= b` when `includeEq` is set) as a chain with one
    // stage per bit. The accumulator is seeded with the constant `includeEq`,
    // which is the result when all bits are equal.
1209 static Value constructRippleCarry(Location loc, Value a, Value b,
1210 bool includeEq,
1211 ConversionPatternRewriter &rewriter) {
1212 // Construct following unsigned comparison expressions.
1213 // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
1214 // a < b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] < b[n-1:0])
1215 auto aBits = extractBits(rewriter, a);
1216 auto bBits = extractBits(rewriter, b);
1217 Value acc = hw::ConstantOp::create(rewriter, loc, APInt(1, includeEq));
1218
1219 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
1220 auto aBitXorBBit =
1221 rewriter.createOrFold<comb::XorOp>(loc, aBit, bBit, true);
    // Inverting the XOR gives the per-bit equality a[i] == b[i].
1222 auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
1223 loc, aBitXorBBit, true);
    // presumably ~a[i] & b[i] (first operand inverted, second not), matching
    // the formula above -- confirm against the AndInverterOp builder.
1224 auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
1225 loc, aBit, bBit, true, false);
1226
1227 auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
1228 loc, ValueRange{aEqualB, acc}, true);
1229 acc = rewriter.createOrFold<comb::OrOp>(loc, pred, aBitAndBBit, true);
1230 }
1231 return acc;
1232 }
1233
1234 // Compute prefix comparison using parallel prefix algorithm
1235 // Note: This generates all intermediate prefix values even though we only
1236 // need the final result. Optimizing this to skip intermediate computations
1237 // is non-trivial because each iteration depends on results from previous
1238 // iterations. We rely on DCE passes to remove unused operations.
1239 // TODO: Lazily compute only the required prefix values. Kogge-Stone is
1240 // already implemented in a lazy manner below, but other architectures can
1241 // also be optimized.
    // `pPrefix` and `gPrefix` are taken by value; the Sklanskey and Brent-Kung
    // helpers update these local copies. Must not be called with RippleCarry.
    // The code indexes element `width - 1`, so the vectors must be non-empty.
1242 static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
1243 Location loc, SmallVector<Value> pPrefix,
1244 SmallVector<Value> gPrefix,
1245 bool includeEq, AdderArchitecture arch) {
1246 auto width = pPrefix.size();
1247 Value finalGroup, finalPropagate;
1248 // Apply the appropriate prefix tree algorithm
1249 switch (arch) {
1250 case AdderArchitecture::RippleCarry:
1251 llvm_unreachable("Ripple-Carry should be handled separately");
1252 break;
1253 case AdderArchitecture::Sklanskey: {
1254 lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
1255 finalGroup = gPrefix[width - 1];
1256 finalPropagate = pPrefix[width - 1];
1257 break;
1258 }
1259 case AdderArchitecture::KoggeStone:
1260 // Use lazy Kogge-Stone implementation to avoid computing all
1261 // intermediate prefix values.
1262 std::tie(finalPropagate, finalGroup) =
1263 LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
1264 .getFinal(width - 1);
1265 break;
1266 case AdderArchitecture::BrentKung: {
1267 lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
1268 finalGroup = gPrefix[width - 1];
1269 finalPropagate = pPrefix[width - 1];
1270 break;
1271 }
1272 }
1273
1274 // Final result: `finalGroup` gives us "a < b"
1275 if (includeEq) {
1276 // a <= b iff (a < b) OR (a == b)
1277 // a == b iff `finalPropagate` (all bits are equal)
1278 return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
1279 }
1280 // a < b iff `finalGroup`
1281 return finalGroup;
1282 }
1283
1284 // Construct an unsigned comparator using either ripple-carry or
1285 // parallel-prefix architecture. Comparison uses parallel prefix tree as an
1286 // internal component, so use `AdderArchitecture` enum to select architecture.
    // `isLess` selects ult/ule versus ugt/uge. `includeEq` adds the equality
    // case. `op` is only passed on to determineAdderArch.
1287 static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
1288 Value b, bool isLess, bool includeEq,
1289 ConversionPatternRewriter &rewriter) {
1290 // Normalize to a "less" comparison: for ugt/uge the operands are swapped
    // so that the code below only ever builds a < b or a <= b.
1291 if (!isLess)
1292 std::swap(a, b);
1293 auto width = a.getType().getIntOrFloatBitWidth();
1294
1295 // Check if the architecture is specified by an attribute.
1296 auto arch = determineAdderArch(op, width);
1297 if (arch == AdderArchitecture::RippleCarry)
1298 return constructRippleCarry(loc, a, b, includeEq, rewriter);
1299
1300 // For larger widths, use parallel prefix tree
1301 auto aBits = extractBits(rewriter, a);
1302 auto bBits = extractBits(rewriter, b);
1303
1304 // For comparison, we compute:
1305 // - Equal bits: eq_i = ~(a_i ^ b_i)
1306 // - Greater bits: gt_i = ~a_i & b_i (a_i < b_i)
1307 // - Propagate: p_i = eq_i (equality propagates)
1308 // - Generate: g_i = gt_i (greater-than generates)
1309 SmallVector<Value> eq, gt;
1310 eq.reserve(width);
1311 gt.reserve(width);
1312
    // 1-bit constant 1; XOR with it is used below as bitwise NOT.
1313 auto one =
1314 hw::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(1), 1);
1315
1316 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
1317 // eq_i = ~(a_i ^ b_i) = a_i == b_i
1318 auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
1319 eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));
1320
1321 // gt_i = ~a_i & b_i = a_i < b_i
1322 auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
1323 gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
1324 }
1325
1326 return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
1327 includeEq, arch);
1328 }
1329
    // Entry point: dispatches on the comparison predicate.
1330 LogicalResult
1331 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
1332 ConversionPatternRewriter &rewriter) const override {
1333 auto lhs = adaptor.getLhs();
1334 auto rhs = adaptor.getRhs();
1335
1336 switch (op.getPredicate()) {
    // Predicates without a case below are not lowered by this pattern.
1337 default:
1338 return failure();
1339
1340 case ICmpPredicate::eq:
1341 case ICmpPredicate::ceq: {
1342 // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
1343 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
1344 auto xorBits = extractBits(rewriter, xorOp);
    // One inversion flag per XOR bit, all set: the AND of the inverted bits.
1345 SmallVector<bool> allInverts(xorBits.size(), true);
1346 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
1347 rewriter, op, xorBits, allInverts);
1348 return success();
1349 }
1350
1351 case ICmpPredicate::ne:
1352 case ICmpPredicate::cne: {
1353 // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
1354 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
1355 replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
1356 rewriter, op, extractBits(rewriter, xorOp), true);
1357 return success();
1358 }
1359
1360 case ICmpPredicate::uge:
1361 case ICmpPredicate::ugt:
1362 case ICmpPredicate::ule:
1363 case ICmpPredicate::ult: {
1364 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
1365 op.getPredicate() == ICmpPredicate::ule;
1366 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
1367 op.getPredicate() == ICmpPredicate::ule;
1368 replaceOpAndCopyNamehint(rewriter, op,
1369 constructUnsignedCompare(op, op.getLoc(), lhs,
1370 rhs, isLess, includeEq,
1371 rewriter));
1372 return success();
1373 }
1374 case ICmpPredicate::slt:
1375 case ICmpPredicate::sle:
1376 case ICmpPredicate::sgt:
1377 case ICmpPredicate::sge: {
    // There is no sign bit to extract from a zero-width operand.
1378 if (lhs.getType().getIntOrFloatBitWidth() == 0)
1379 return rewriter.notifyMatchFailure(
1380 op.getLoc(), "i0 signed comparison is unsupported");
1381 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
1382 op.getPredicate() == ICmpPredicate::sle;
1383 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
1384 op.getPredicate() == ICmpPredicate::sle;
1385
1386 // Get a sign bit
1387 auto signA = extractMSB(rewriter, lhs);
1388 auto signB = extractMSB(rewriter, rhs);
1389 auto aRest = extractOtherThanMSB(rewriter, lhs);
1390 auto bRest = extractOtherThanMSB(rewriter, rhs);
1391
1392 // Compare magnitudes (all bits except sign)
1393 auto sameSignResult = constructUnsignedCompare(
1394 op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);
1395
1396 // XOR of signs: true if signs are different
1397 auto signsDiffer =
1398 comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
1399
1400 // Result when signs are different
    // For "less" predicates the result is lhs's sign bit; for "greater"
    // predicates it is rhs's sign bit.
1401 Value diffSignResult = isLess ? signA : signB;
1402
1403 // Final result: choose based on whether signs differ
1404 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1405 rewriter, op, signsDiffer, diffSignResult, sameSignResult);
1406 return success();
1407 }
1408 }
1409 }
1410};
1411
// Conversion pattern for comb.parity: replaces the op with a single variadic
// comb.xor over the individual bits of the input.
1412struct CombParityOpConversion : OpConversionPattern<ParityOp> {
// NOTE(review): listing line 1413 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
1414
    // Always succeeds.
1415 LogicalResult
1416 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
1417 ConversionPatternRewriter &rewriter) const override {
1418 // Parity is the XOR of all bits.
1419 replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
1420 rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
1421 return success();
1422 }
1423};
1424
// Conversion pattern for comb.shl. The shift is built with createShiftLogic,
// using zero padding and the low `width - index` bits of the left-hand side
// for a shift by `index`.
1425struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
// NOTE(review): listing line 1426 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
1427
    // Always succeeds.
1428 LogicalResult
1429 matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
1430 ConversionPatternRewriter &rewriter) const override {
1431 auto width = op.getType().getIntOrFloatBitWidth();
1432 auto lhs = adaptor.getLhs();
1433 auto result = createShiftLogic</*isLeftShift=*/true>(
1434 rewriter, op.getLoc(), adaptor.getRhs(), width,
1435 /*getPadding=*/
1436 [&](int64_t index) {
1437 // Don't create zero width value.
    // A null Value is returned instead of an i0 constant.
1438 if (index == 0)
1439 return Value();
1440 // Padding is 0 for left shift.
1441 return rewriter.createOrFold<hw::ConstantOp>(
1442 op.getLoc(), rewriter.getIntegerType(index), 0);
1443 },
1444 /*getExtract=*/
1445 [&](int64_t index) {
1446 assert(index < width && "index out of bounds");
1447 // Extract the bits from LSB.
1448 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
1449 width - index);
1450 });
1451
1452 replaceOpAndCopyNamehint(rewriter, op, result);
1453 return success();
1454 }
1455};
1456
// Conversion pattern for comb.shru. The shift is built with createShiftLogic,
// using zero padding and the `width - index` bits of the left-hand side
// starting at bit `index` for a shift by `index`.
1457struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
// NOTE(review): listing line 1458 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
1459
    // Always succeeds.
1460 LogicalResult
1461 matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
1462 ConversionPatternRewriter &rewriter) const override {
1463 auto width = op.getType().getIntOrFloatBitWidth();
1464 auto lhs = adaptor.getLhs();
1465 auto result = createShiftLogic</*isLeftShift=*/false>(
1466 rewriter, op.getLoc(), adaptor.getRhs(), width,
1467 /*getPadding=*/
1468 [&](int64_t index) {
1469 // Don't create zero width value.
    // A null Value is returned instead of an i0 constant.
1470 if (index == 0)
1471 return Value();
1472 // Padding is 0 for right shift.
1473 return rewriter.createOrFold<hw::ConstantOp>(
1474 op.getLoc(), rewriter.getIntegerType(index), 0);
1475 },
1476 /*getExtract=*/
1477 [&](int64_t index) {
1478 assert(index < width && "index out of bounds");
1479 // Extract the bits from MSB.
1480 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
1481 width - index);
1482 });
1483
1484 replaceOpAndCopyNamehint(rewriter, op, result);
1485 return success();
1486 }
1487};
1488
// Conversion pattern for comb.shrs. The sign bit is extracted once, and
// createShiftLogic is given `index + 1` replicated sign bits as padding plus
// the `width - index - 1` bits of the left-hand side starting at bit `index`.
// Zero-width operands are rejected.
1489struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
// NOTE(review): listing line 1490 is absent from this extract -- presumably the
// inherited-constructor using-declaration; confirm against the upstream file.
1491
1492 LogicalResult
1493 matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
1494 ConversionPatternRewriter &rewriter) const override {
1495 auto width = op.getType().getIntOrFloatBitWidth();
    // The sign-bit extract below uses bit `width - 1`, so width must be > 0.
1496 if (width == 0)
1497 return rewriter.notifyMatchFailure(op.getLoc(),
1498 "i0 signed shift is unsupported");
1499 auto lhs = adaptor.getLhs();
1500 // Get the sign bit.
1501 auto sign =
1502 rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1503
1504 // NOTE: The max shift amount is width - 1 because the sign bit is
1505 // already shifted out.
1506 auto result = createShiftLogic</*isLeftShift=*/false>(
1507 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
1508 /*getPadding=*/
1509 [&](int64_t index) {
    // Padding is index + 1 copies of the sign bit, so it is never zero-width.
1510 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
1511 index + 1);
1512 },
1513 /*getExtract=*/
1514 [&](int64_t index) {
    // Remaining bits of lhs starting at `index`, excluding the sign bit.
1515 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
1516 width - index - 1);
1517 });
1518
1519 replaceOpAndCopyNamehint(rewriter, op, result);
1520 return success();
1521 }
1522};
1523
1524} // namespace
1525
1526//===----------------------------------------------------------------------===//
1527// Convert Comb to AIG pass
1528//===----------------------------------------------------------------------===//
1529
1530namespace {
// Pass class for the Comb-to-Synth conversion. It derives from the generated
// impl::ConvertCombToSynthBase, inherits its constructors, and overrides
// runOnOperation (defined out of line).
1531struct ConvertCombToSynthPass
1532 : public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
1533 void runOnOperation() override;
1534 using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
1535};
1536} // namespace
1537
// Registers every conversion pattern used by the pass into `patterns`.
// `maxEmulationUnknownBits` is forwarded to the four div/mod patterns.
// `forceAIG` additionally registers the XorInverter/MuxInverter lowering
// patterns.
1538static void
// NOTE(review): listing line 1539 is absent from this extract -- presumably
// `populateCombToAIGConversionPatterns(RewritePatternSet &patterns,`; confirm
// against the upstream file.
1540 uint32_t maxEmulationUnknownBits,
1541 bool forceAIG) {
1542 patterns.add<
1543 // Bitwise Logical Ops
1544 CombAndOpConversion, CombParityOpConversion, CombXorOpToSynthConversion,
1545 CombMuxOpToSynthConversion,
1546 // Arithmetic Ops
1547 CombMulOpConversion, CombICmpOpConversion,
1548 // Shift Ops
1549 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1550 // Variadic ops that must be lowered to binary operations
1551 CombLowerVariadicOp<AddOp>, CombLowerVariadicOp<MulOp>>(
1552 patterns.getContext());
1553
    // Only when forcing AIG: lower the synth XorInverter/MuxInverter ops too.
1554 if (forceAIG) {
1555 patterns.add<SynthXorInverterOpConversion, SynthMuxInverterOpConversion>(
1556 patterns.getContext());
1557 }
    // Function-style pattern from the comb dialect (subtraction to addition).
1558 patterns.add(comb::convertSubToAdd);
1559
1560 patterns.add<CombOrToAIGConversion, CombAddOpConversion>(
1561 patterns.getContext());
1562 synth::populateVariadicAndInverterLoweringPatterns(patterns);
1563
1564 if (forceAIG)
1565 synth::populateVariadicXorInverterLoweringPatterns(patterns);
1566
1567 // Add div/mod patterns with a threshold given by the pass option.
1568 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1569 CombModSOpConversion>(patterns.getContext(),
1570 maxEmulationUnknownBits);
1571}
1572
// Pass body. It configures the conversion target, populates the patterns from
// the pass options (`maxEmulationUnknownBits`, `forceAIG`), and runs a partial
// dialect conversion. The pass is marked failed if the conversion fails.
1573void ConvertCombToSynthPass::runOnOperation() {
1574 ConversionTarget target(getContext());
1575
1576 // Comb is source dialect.
1577 target.addIllegalDialect<comb::CombDialect>();
1578 // Keep data movement operations like Extract, Concat and Replicate.
1579 target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
// NOTE(review): listing line 1580 (the rest of this addLegalOp list) is absent
// from this extract; confirm against the upstream file.
1581
1582 // Treat array operations as illegal. Strictly speaking, other than array
1583 // get operation with non-const index are legal in AIG but array types
1584 // prevent a bunch of optimizations so just lower them to integer
1585 // operations. It's required to run HWAggregateToComb pass before this pass.
// NOTE(review): listing line 1586 (the head of this op list, presumably a
// `target.addIllegalOp<...` call) is absent from this extract; confirm against
// the upstream file.
1587 hw::AggregateConstantOp>();
1588
    // Synth is the target dialect. With forceAIG, XorInverterOp and
    // MuxInverterOp are made illegal again so that they get lowered as well.
1589 target.addLegalDialect<synth::SynthDialect>();
1590 if (forceAIG)
1591 target.addIllegalOp<synth::XorInverterOp, synth::MuxInverterOp>();
1592
1593 // If additional legal ops are specified, add them to the target.
1594 if (!additionalLegalOps.empty())
1595 for (const auto &opName : additionalLegalOps)
1596 target.addLegalOp(OperationName(opName, &getContext()));
1597
1598 RewritePatternSet patterns(&getContext());
1599 populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits,
1600 forceAIG);
1601
1602 if (failed(mlir::applyPartialConversion(getOperation(), target,
1603 std::move(patterns))))
1604 return signalPassFailure();
1605}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static Value createAShrByConstant(OpBuilder &builder, Location loc, Value value, unsigned amount)
static Value createMulHigh(OpBuilder &builder, Location loc, Value lhs, const APInt &rhs)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static Value lowerSignedDivByConstant(OpBuilder &builder, Location loc, Value lhs, const APInt &divisor)
static Value createLShrByConstant(OpBuilder &builder, Location loc, Value value, unsigned amount)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry)
static Value extractOtherThanMSB(OpBuilder &builder, Value val)
static Value extractMSB(OpBuilder &builder, Value val)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits, bool forceAIG)
static Value lowerUnsignedDivByConstant(OpBuilder &builder, Location loc, Value lhs, const APInt &divisor)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1