//===- CombToSynth.cpp - Comb to Synth Conversion Pass ---------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main Comb to Synth conversion pass implementation.
//
//          High-level Comb Operations
//                  |               |
//                  v               |
//        +-------------------+     |
//        | and, or, xor, mux |     |
//        +---------+---------+     |
//                  |               |
//          +-------+-------+       |
//          v               v       v
//       +-----+         +-----+
//       | AIG |-------->| MIG |
//       +-----+         +-----+
//
//===----------------------------------------------------------------------===//

#include "circt/Conversion/CombToSynth.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Synth/SynthOps.h"
#include "circt/Support/Naming.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Debug.h"
#include <array>

#define DEBUG_TYPE "comb-to-synth"

namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace comb;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift and determines the order in which the padding and
// extracted bits are concatenated. The callbacks `getPadding` and `getExtract`
// return the padding and the extracted bits for each shift amount.
// `getPadding` may return a null Value to represent a zero-width (i0) padding,
// but otherwise both callbacks must return a valid value for each shift amount
// in the range [0, maxShiftAmount]. The value for `maxShiftAmount` is used as
// the out-of-bounds value.
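// For example, a 4-bit left shift produces one candidate per in-range shift
// amount and selects among them with a mux tree on the shift-amount bits:
//   node0 = a[3:0]           (shift by 0, no padding)
//   node1 = {a[2:0], 1'b0}   (shift by 1)
//   node2 = {a[1:0], 2'b00}  (shift by 2)
//   node3 = {a[0],   3'b000} (shift by 3)
// Any shift amount >= 4 selects the all-padding out-of-bounds value.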
template <bool isLeftShift>
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
                              Value shiftAmount, int64_t maxShiftAmount,
                              llvm::function_ref<Value(int64_t)> getPadding,
                              llvm::function_ref<Value(int64_t)> getExtract) {
  // Extract individual bits from shift amount.
  auto bits = extractBits(rewriter, shiftAmount);

  // Create nodes for each possible shift amount.
  SmallVector<Value> nodes;
  nodes.reserve(maxShiftAmount);
  for (int64_t i = 0; i < maxShiftAmount; ++i) {
    Value extract = getExtract(i);
    Value padding = getPadding(i);

    if (!padding) {
      nodes.push_back(extract);
      continue;
    }

    // Concatenate extracted bits with padding.
    if (isLeftShift)
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
    else
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
  }

  // Create the out-of-bounds value.
  auto outOfBoundsValue = getPadding(maxShiftAmount);
  assert(outOfBoundsValue && "outOfBoundsValue must be valid");

  // Construct the mux tree for the shift operation.
  auto result =
      comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

  // Add bounds checking.
  auto inBound = rewriter.createOrFold<comb::ICmpOp>(
      loc, ICmpPredicate::ult, shiftAmount,
      hw::ConstantOp::create(rewriter, loc, shiftAmount.getType(),
                             maxShiftAmount));

  return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
                                            outOfBoundsValue);
}

// Return a majority operation if MIG is enabled; otherwise return a majority
// function implemented with Comb operations. In the latter case the `carry`
// input sits at a slightly smaller logic depth than the other inputs.
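// maj(a, b, c) is 1 iff at least two of its inputs are 1; it is exactly the
// carry-out function of a full adder: carry[i] = maj(a[i], b[i], carry[i-1]).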
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a,
                                    Value b, Value carry,
                                    bool useMajorityInverterOp) {
  if (useMajorityInverterOp) {
    std::array<Value, 3> inputs = {a, b, carry};
    std::array<bool, 3> inverts = {false, false, false};
    return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs,
                                                  inverts);
  }

  // maj(a, b, c) = (c & (a ^ b)) | (a & b)
  auto aXorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b}, true);
  auto andOp =
      comb::AndOp::create(rewriter, loc, ValueRange{carry, aXorB}, true);
  auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b}, true);
  return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB}, true);
}

static Value extractMSB(OpBuilder &builder, Value val) {
  return builder.createOrFold<comb::ExtractOp>(
      val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
}

static Value extractOtherThanMSB(OpBuilder &builder, Value val) {
  return builder.createOrFold<comb::ExtractOp>(
      val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
}

namespace {
// A union of Value and IntegerAttr to cleanly handle constant values.
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
} // namespace

// Return the number of unknown bits and populate the concatenated values.
static int64_t getNumUnknownBitsAndPopulateValues(
    Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
  // A zero-width value has no unknown bits.
  if (value.getType().isInteger(0))
    return 0;

  // Recursively count unknown bits for concat.
  if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
    int64_t totalUnknownBits = 0;
    for (auto concatInput : llvm::reverse(concat.getInputs())) {
      auto unknownBits =
          getNumUnknownBitsAndPopulateValues(concatInput, values);
      if (unknownBits < 0)
        return unknownBits;
      totalUnknownBits += unknownBits;
    }
    return totalUnknownBits;
  }

  // A constant value is fully known.
  if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
    values.push_back(constant.getValueAttr());
    return 0;
  }

  // Treat all other operations as entirely unknown bits.
  // TODO: We can handle replicate, extract, etc.
  values.push_back(value);
  return hw::getBitWidth(value.getType());
}

// Return a value that substitutes the unknown bits with the mask.
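// For example, with constantOrValues = [2'b10, x : i3] (LSB first) and
// mask = 0b101, the three unknown bits of x are replaced by 0b101, yielding
// {3'b101, 2'b10} = 5'b10110.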
static APInt
substituteMaskToValues(size_t width,
                       llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
                       uint32_t mask) {
  uint32_t bitPos = 0, unknownPos = 0;
  APInt result(width, 0);
  for (auto constantOrValue : constantOrValues) {
    int64_t elemWidth;
    if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
      elemWidth = constant.getValue().getBitWidth();
      result.insertBits(constant.getValue(), bitPos);
    } else {
      elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
      assert(elemWidth >= 0 && "unknown bit width");
      assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
      // Create a mask for the unknown bits.
      uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
      result.insertBits(APInt(elemWidth, usedBits), bitPos);
      unknownPos += elemWidth;
    }
    bitPos += elemWidth;
  }

  return result;
}
209
210// Emulate a binary operation with unknown bits using a table lookup.
211// This function enumerates all possible combinations of unknown bits and
212// emulates the operation for each combination.
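// For example (as used by the divide/modulo patterns below), a udiv of the
// 2-bit value {x, 1'b1} (x : i1 unknown) by the constant 2'b10 precomputes
// both cases as constants (01 / 10 = 0 and 11 / 10 = 1) and selects between
// them with a mux on x.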
static LogicalResult emulateBinaryOpForUnknownBits(
    ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
    Operation *op,
    llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
  SmallVector<ConstantOrValue> lhsValues, rhsValues;

  assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
         "op must be a single result binary operation");

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
  auto loc = op->getLoc();
  auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
  auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

  // If an unknown bit width is detected, abort the lowering.
  if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
    return failure();

  int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
  if (totalUnknownBits > maxEmulationUnknownBits)
    return failure();

  SmallVector<Value> emulatedResults;
  emulatedResults.reserve(1 << totalUnknownBits);

  // Emulate all possible cases.
  DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
  auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
    auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
    auto it = constantPool.find(attr);
    if (it != constantPool.end())
      return it->second;
    auto constant = hw::ConstantOp::create(rewriter, loc, value);
    constantPool[attr] = constant;
    return constant;
  };

  for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
       lhsMask < lhsMaskEnd; ++lhsMask) {
    APInt lhsValue = substituteMaskToValues(width, lhsValues, lhsMask);
    for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
         rhsMask < rhsMaskEnd; ++rhsMask) {
      APInt rhsValue = substituteMaskToValues(width, rhsValues, rhsMask);
      // Emulate.
      emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
    }
  }

  // Create selectors for the mux tree.
  SmallVector<Value> selectors;
  selectors.reserve(totalUnknownBits);
  for (auto &concatenatedValues : {rhsValues, lhsValues})
    for (auto valueOrConstant : concatenatedValues) {
      auto value = dyn_cast<Value>(valueOrConstant);
      if (!value)
        continue;
      extractBits(rewriter, value, selectors);
    }

  assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
         "number of selectors must match");
  auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
                                getConstant(APInt::getZero(width)));

  replaceOpAndCopyNamehint(rewriter, op, muxed);
  return success();
}

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Lower a comb::AndOp operation to synth::aig::AndInverterOp
struct CombAndOpConversion : OpConversionPattern<AndOp> {
  using OpConversionPattern<AndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, adaptor.getInputs(), nonInverts);
    return success();
  }
};

/// Lower a comb::OrOp operation to synth::aig::AndInverterOp with invert flags
struct CombOrToAIGConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement Or using And and invert flags: a | b = ~(~a & ~b)
    SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
    auto andOp = synth::aig::AndInverterOp::create(
        rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, andOp,
        /*invert=*/true);
    return success();
  }
};

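/// Lower a two-operand comb::OrOp to synth::mig::MajorityInverterOp using
/// the identity maj(a, b, 1) = a | b.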
struct CombOrToMIGConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    SmallVector<Value, 3> inputs(adaptor.getInputs());
    auto one = hw::ConstantOp::create(
        rewriter, op.getLoc(),
        APInt::getAllOnes(hw::getBitWidth(op.getType())));
    inputs.push_back(one);
    std::array<bool, 3> inverts = {false, false, false};
    replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
        rewriter, op, inputs, inverts);
    return success();
  }
};

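/// Lower a one- or two-input synth::aig::AndInverterOp to
/// synth::mig::MajorityInverterOp using the identity maj(a, b, 0) = a & b.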
struct AndInverterToMIGConversion
    : OpConversionPattern<synth::aig::AndInverterOp> {
  using OpConversionPattern<synth::aig::AndInverterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(synth::aig::AndInverterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() > 2)
      return failure();
    if (op.getNumOperands() == 1) {
      SmallVector<bool, 1> inverts{op.getInverted()[0]};
      replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
          rewriter, op, adaptor.getInputs(), inverts);
      return success();
    }
    SmallVector<Value, 3> inputs(adaptor.getInputs());
    auto zero = hw::ConstantOp::create(
        rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
    inputs.push_back(zero);
    SmallVector<bool, 3> inverts(adaptor.getInverted());
    inverts.push_back(false);
    replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
        rewriter, op, inputs, inverts);
    return success();
  }
};

struct MajorityInverterToAIGConversion
    : OpConversionPattern<synth::mig::MajorityInverterOp> {
  using OpConversionPattern<
      synth::mig::MajorityInverterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(synth::mig::MajorityInverterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Only handle 1- or 3-input majority inverters for now.
    if (op.getNumOperands() > 3)
      return failure();

    if (op.getNumOperands() == 1) {
      bool inverts[1] = {op.getInverted()[0]};
      replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
          rewriter, op, adaptor.getInputs(), inverts);
      return success();
    }

    assert(op.getNumOperands() == 3 && "Expected 3 operands for majority op");
    auto getProduct = [&](unsigned idx1, unsigned idx2) {
      bool inverts[2] = {op.getInverted()[idx1], op.getInverted()[idx2]};
      return synth::aig::AndInverterOp::create(
          rewriter, op.getLoc(),
          ValueRange{adaptor.getInputs()[idx1], adaptor.getInputs()[idx2]},
          inverts);
    };

    // Majority is the OR of the three pairwise products. Materialize that OR
    // as a NOR followed by an inverter so it stays in AIG form.
    Value products[3] = {getProduct(0, 1), getProduct(0, 2), getProduct(1, 2)};
    bool allInverted[3] = {true, true, true};
    auto notOr = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                   products, allInverted);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, notOr,
        /*invert=*/true);
    return success();
  }
};

/// Lower a comb::XorOp operation to AIG operations
struct CombXorOpConversion : OpConversionPattern<XorOp> {
  using OpConversionPattern<XorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)

    // (a | b) = ~(~a & ~b)
    // (~a | ~b) = ~(a & b)
    auto inputs = adaptor.getInputs();
    SmallVector<bool> allInverts(inputs.size(), true);
    SmallVector<bool> allNotInverts(inputs.size(), false);

    auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                         inputs, allInverts);
    auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                   inputs, allNotInverts);

    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, notAAndNotB, aAndB,
        /*lhs_invert=*/true,
        /*rhs_invert=*/true);
    return success();
  }
};

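/// Lower a variadic, fully-associative operation into a balanced binary tree
/// of two-operand operations so the result has logarithmic depth.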
template <typename OpTy>
struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }

  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
                                       ConversionPatternRewriter &rewriter) {
    Value lhs, rhs;
    switch (operands.size()) {
    case 0:
      llvm_unreachable("cannot be called with empty operand range");
      break;
    case 1:
      return operands[0];
    case 2:
      lhs = operands[0];
      rhs = operands[1];
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    default:
      auto firstHalf = operands.size() / 2;
      lhs =
          lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
      rhs =
          lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    }
  }
};
471
472// Lower comb::MuxOp to AIG operations.
473struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
475
476 LogicalResult
477 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
478 ConversionPatternRewriter &rewriter) const override {
479 Value cond = op.getCond();
480 auto trueVal = op.getTrueValue();
481 auto falseVal = op.getFalseValue();
482
483 if (!op.getType().isInteger()) {
484 // If the type of the mux is not integer, bitcast the operands first.
485 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
486 trueVal =
487 hw::BitcastOp::create(rewriter, op->getLoc(), widthType, trueVal);
488 falseVal =
489 hw::BitcastOp::create(rewriter, op->getLoc(), widthType, falseVal);
490 }
491
492 // Replicate condition if needed
493 if (!trueVal.getType().isInteger(1))
494 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
495 cond);
496
497 // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
498 auto lhs =
499 synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
500 auto rhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond,
501 falseVal, true, false);
502
503 Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
504 // Insert the bitcast if the type of the mux is not integer.
505 if (result.getType() != op.getType())
506 result =
507 hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
508 replaceOpAndCopyNamehint(rewriter, op, result);
509 return success();
510 }
511};

//===----------------------------------------------------------------------===//
// Adder Architecture Selection
//===----------------------------------------------------------------------===//

enum AdderArchitecture { RippleCarry, Sklanskey, KoggeStone, BrentKung };
AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
  auto strAttr = op->getAttrOfType<StringAttr>("synth.test.arch");
  if (strAttr) {
    return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
        .Case("SKLANSKEY", Sklanskey)
        .Case("KOGGE-STONE", KoggeStone)
        .Case("BRENT-KUNG", BrentKung)
        .Case("RIPPLE-CARRY", RippleCarry);
  }
  // Otherwise, choose an architecture using the width as a heuristic.
  // TODO: Perform a more thorough analysis to motivate the choices, or
  // implement an adder synthesis algorithm to construct an optimal adder
  // under the given timing constraints - see the work of Zimmermann.

  // For very small adders, the overhead of a parallel prefix adder is likely
  // not worth it.
  if (width < 8)
    return AdderArchitecture::RippleCarry;

  // Sklansky is a good compromise for high performance, but has high fanout,
  // which may lead to wiring congestion for very large adders.
  if (width <= 32)
    return AdderArchitecture::Sklanskey;

  // Kogge-Stone uses more area than Sklansky but has lower fanout and thus
  // may be preferable for larger adders.
  return AdderArchitecture::KoggeStone;
}

//===----------------------------------------------------------------------===//
// Parallel Prefix Tree
//===----------------------------------------------------------------------===//

// Implement the Kogge-Stone parallel prefix tree.
// Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
// Slightly better delay than Brent-Kung, but more area.
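// For example, for width 8 the loop runs strides 1, 2, 4, i.e. three stages.
// After the stage with stride s, (pPrefix[i], gPrefix[i]) covers a window of
// up to 2*s bits ending at bit i, so after ceil(log2(width)) stages every bit
// covers its entire prefix.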
void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
                               SmallVector<Value> &pPrefix,
                               SmallVector<Value> &gPrefix) {

  auto width = static_cast<int64_t>(pPrefix.size());
  assert(width == static_cast<int64_t>(gPrefix.size()));
  SmallVector<Value> pPrefixNew = pPrefix;
  SmallVector<Value> gPrefixNew = gPrefix;

  // Kogge-Stone parallel prefix computation
  for (int64_t stride = 1; stride < width; stride *= 2) {

    for (int64_t i = stride; i < width; ++i) {
      int64_t j = i - stride;

      // Group generate: g_i OR (p_i AND g_j)
      Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
      gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);

      // Group propagate: p_i AND p_j
      pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
    }

    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  LLVM_DEBUG({
    int64_t stage = 0;
    for (int64_t stride = 1; stride < width; stride *= 2) {
      llvm::dbgs()
          << "--------------------------------------- Kogge-Stone Stage "
          << stage << "\n";
      for (int64_t i = stride; i < width; ++i) {
        int64_t j = i - stride;
        // Group generate: g_i OR (p_i AND g_j)
        llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                     << " OR (P" << i << stage << " AND G" << j << stage
                     << ")\n";

        // Group propagate: p_i AND p_j
        llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                     << " AND P" << j << stage << "\n";
      }
      ++stage;
    }
  });
}

// Implement the Sklansky parallel prefix tree.
// High fan-out, low depth, low area.
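// For example, for width 8 the stage with stride s merges each block of s
// bits starting at an odd multiple of s with the last bit of the preceding
// block (j = i - 1), completing all bits after ceil(log2(width)) stages.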
void lowerSklanskeyPrefixTree(OpBuilder &builder, Location loc,
                              SmallVector<Value> &pPrefix,
                              SmallVector<Value> &gPrefix) {
  auto width = static_cast<int64_t>(pPrefix.size());
  assert(width == static_cast<int64_t>(gPrefix.size()));
  SmallVector<Value> pPrefixNew = pPrefix;
  SmallVector<Value> gPrefixNew = gPrefix;
  for (int64_t stride = 1; stride < width; stride *= 2) {
    for (int64_t i = stride; i < width; i += 2 * stride) {
      for (int64_t k = 0; k < stride && i + k < width; ++k) {
        int64_t idx = i + k;
        int64_t j = i - 1;

        // Group generate: g_idx OR (p_idx AND g_j)
        Value andPG =
            comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
        gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);

        // Group propagate: p_idx AND p_j
        pPrefixNew[idx] =
            comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
      }
    }

    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  LLVM_DEBUG({
    int64_t stage = 0;
    for (int64_t stride = 1; stride < width; stride *= 2) {
      llvm::dbgs() << "--------------------------------------- Sklanskey Stage "
                   << stage << "\n";
      for (int64_t i = stride; i < width; i += 2 * stride) {
        for (int64_t k = 0; k < stride && i + k < width; ++k) {
          int64_t idx = i + k;
          int64_t j = i - 1;
          // Group generate: g_idx OR (p_idx AND g_j)
          llvm::dbgs() << "G" << idx << stage + 1 << " = G" << idx << stage
                       << " OR (P" << idx << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_idx AND p_j
          llvm::dbgs() << "P" << idx << stage + 1 << " = P" << idx << stage
                       << " AND P" << j << stage << "\n";
        }
      }
      ++stage;
    }
  });
}

// Implement the Brent-Kung parallel prefix tree.
// Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
// Slightly worse delay than Kogge-Stone, but less area.
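// For example, for width 8 the forward phase computes prefixes at bits 1, 3,
// and 7 (strides 1, 2, 4), and the backward phase fills in the remaining bits
// (bit 5 at stride 2; bits 2, 4, and 6 at stride 1), using roughly twice the
// logic levels of Kogge-Stone but far fewer prefix cells.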
void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
                              SmallVector<Value> &pPrefix,
                              SmallVector<Value> &gPrefix) {
  auto width = static_cast<int64_t>(pPrefix.size());
  assert(width == static_cast<int64_t>(gPrefix.size()));
  SmallVector<Value> pPrefixNew = pPrefix;
  SmallVector<Value> gPrefixNew = gPrefix;
  // Brent-Kung parallel prefix computation
  // Forward phase
  int64_t stride;
  for (stride = 1; stride < width; stride *= 2) {
    for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
      int64_t j = i - stride;

      // Group generate: g_i OR (p_i AND g_j)
      Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
      gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);

      // Group propagate: p_i AND p_j
      pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
    }
    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  // Backward phase
  for (; stride > 0; stride /= 2) {
    for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
      int64_t j = i - stride;

      // Group generate: g_i OR (p_i AND g_j)
      Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
      gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);

      // Group propagate: p_i AND p_j
      pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
    }
    pPrefix = pPrefixNew;
    gPrefix = gPrefixNew;
  }

  LLVM_DEBUG({
    int64_t stage = 0;
    for (stride = 1; stride < width; stride *= 2) {
      llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
                   << stage << " : Stride " << stride << "\n";
      for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                     << " OR (P" << i << stage << " AND G" << j << stage
                     << ")\n";

        // Group propagate: p_i AND p_j
        llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                     << " AND P" << j << stage << "\n";
      }
      ++stage;
    }

    for (; stride > 0; stride /= 2) {
      if (stride * 3 - 1 < width)
        llvm::dbgs() << "--------------------------------------- Brent-Kung BW "
                     << stage << " : Stride " << stride << "\n";

      for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                     << " OR (P" << i << stage << " AND G" << j << stage
                     << ")\n";

        // Group propagate: p_i AND p_j
        llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                     << " AND P" << j << stage << "\n";
      }
      --stage;
    }
  });
}

// TODO: Generalize to other parallel prefix trees.
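// Lazily computes Kogge-Stone prefix values: only the (level, bit) nodes
// reachable from the requested final bit are materialized, and intermediate
// results are memoized in `prefixCache`.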
class LazyKoggeStonePrefixTree {
public:
  LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
                           ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
      : builder(builder), loc(loc), width(width) {
    assert(width > 0 && "width must be positive");
    for (int64_t i = 0; i < width; ++i)
      prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
  }

  // Get the final group and propagate values for bit i.
  std::pair<Value, Value> getFinal(int64_t i) {
    assert(i >= 0 && i < width && "i out of bounds");
    // The final level is ceil(log2(width)) in Kogge-Stone.
    return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
  }

private:
  // Recursively get the group and propagate values for bit i at level `level`.
  // Level 0 is the initial level with the input propagate and generate values.
  // Level n computes the group and propagate values for a stride of 2^(n-1).
  // Uses memoization to cache intermediate results.
  std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
  OpBuilder &builder;
  Location loc;
  int64_t width;
  DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
};

std::pair<Value, Value>
LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
  assert(i < width && "i out of bounds");
  auto key = std::make_pair(level, i);
  auto it = prefixCache.find(key);
  if (it != prefixCache.end())
    return it->second;

  assert(level > 0 && "If the level is 0, we should have hit the cache");

  int64_t previousStride = 1ULL << (level - 1);
  if (i < previousStride) {
    // No dependency, just copy from the previous level.
    auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
    prefixCache[key] = {propagateI, generateI};
    return prefixCache[key];
  }
  // Get the dependency index.
  int64_t j = i - previousStride;
  auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
  auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
  // Group generate: g_i OR (p_i AND g_j)
  Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
  Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
  // Group propagate: p_i AND p_j
  Value newPropagate =
      comb::AndOp::create(builder, loc, propagateI, propagateJ);
  prefixCache[key] = {newPropagate, newGenerate};
  return prefixCache[key];
}

template <bool lowerToMIG>
struct CombAddOpConversion : OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputs = adaptor.getInputs();
    // Lower only when there are two inputs.
    // Variadic operands must be lowered in a different pattern.
    if (inputs.size() != 2)
      return failure();

    auto width = op.getType().getIntOrFloatBitWidth();
    // Lower a zero-width value to a constant.
    if (width == 0) {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        op.getType(), 0);
      return success();
    }

    // Check if the architecture is specified by an attribute.
    auto arch = determineAdderArch(op, width);
    if (arch == AdderArchitecture::RippleCarry)
      return lowerRippleCarryAdder(op, inputs, rewriter);
    return lowerParallelPrefixAdder(op, inputs, rewriter);
  }

  // Implement a basic ripple-carry adder for small bit widths.
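  // For each bit: sum[i] = a[i] ^ b[i] ^ carry[i-1] and
  // carry[i] = maj(a[i], b[i], carry[i-1]), giving O(width) depth but
  // minimal area.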
  LogicalResult
  lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
                        ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    // Implement a naive ripple-carry full adder.
    Value carry;

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    SmallVector<Value> results;
    results.resize(width);
    for (int64_t i = 0; i < width; ++i) {
      SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
      if (carry)
        xorOperands.push_back(carry);

      // sum[i] = xor(carry[i-1], a[i], b[i])
      // NOTE: The result is stored in reverse order.
      results[width - i - 1] =
          comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);

      // If this is the last bit, we are done.
      if (i == width - 1)
        break;

      // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
      if (!carry) {
        // For the first bit there is no incoming carry, so carry[0] is
        // simply a[0] & b[0].
        carry = comb::AndOp::create(rewriter, op.getLoc(),
                                    ValueRange{aBits[i], bBits[i]}, true);
        continue;
      }

      carry = createMajorityFunction(rewriter, op.getLoc(), aBits[i], bBits[i],
                                     carry, lowerToMIG);
    }
    LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
                            << width << "\n");

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
    return success();
  }

  // Implement a parallel prefix adder with a Kogge-Stone, Sklansky, or
  // Brent-Kung tree. This will introduce unused signals for the carry bits,
  // but these will be removed by the AIG pass.
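  // For example, an 8-bit add computes p_i = a_i ^ b_i and g_i = a_i & b_i,
  // runs a prefix tree so that gPrefix[i-1] becomes the carry into bit i,
  // and then forms sum_i = p_i ^ gPrefix[i-1] in O(log width) depth.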
  LogicalResult
  lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
                           ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);

    // Construct propagate (p) and generate (g) signals.
    SmallVector<Value> p, g;
    p.reserve(width);
    g.reserve(width);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // p_i = a_i XOR b_i
      p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
      // g_i = a_i AND b_i
      g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
    }

    LLVM_DEBUG({
      llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
                   << "\n--------------------------------------- Init\n";

      for (int64_t i = 0; i < width; ++i) {
        // p_i = a_i XOR b_i
        llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
        // g_i = a_i AND b_i
        llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
      }
    });

    // Create copies of p and g for the prefix computation.
    SmallVector<Value> pPrefix = p;
    SmallVector<Value> gPrefix = g;

    // Check if the architecture is specified by an attribute.
    auto arch = determineAdderArch(op, width);

    switch (arch) {
    case AdderArchitecture::RippleCarry:
      llvm_unreachable("Ripple-Carry should be handled separately");
      break;
    case AdderArchitecture::Sklanskey:
      lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
      break;
    case AdderArchitecture::KoggeStone:
      lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
      break;
    case AdderArchitecture::BrentKung:
      lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
      break;
    }

    // Generate the result sum bits.
    // NOTE: The result is stored in reverse order.
    SmallVector<Value> results;
    results.resize(width);
    // Sum bit 0 is just p[0] since carry_in = 0.
    results[width - 1] = p[0];

    // For the remaining bits, sum_i = p_i XOR g_(i-1);
    // the carry into position i is the group generate from position i-1.
    for (int64_t i = 1; i < width; ++i)
      results[width - 1 - i] =
          comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);

    LLVM_DEBUG({
      llvm::dbgs() << "--------------------------------------- Completion\n"
                   << "RES0 = P0\n";
      for (int64_t i = 1; i < width; ++i)
        llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
    });

    return success();
  }
};

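/// Lower comb::MulOp by building an AND-array of shifted partial products,
/// compressing it with a compressor tree down to two addends, and summing
/// those with a single carry-propagate adder.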
struct CombMulOpConversion : OpConversionPattern<MulOp> {
  using OpConversionPattern<MulOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
  LogicalResult
  matchAndRewrite(MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getInputs().size() != 2)
      return failure();

    Location loc = op.getLoc();
    Value a = adaptor.getInputs()[0];
    Value b = adaptor.getInputs()[1];
    unsigned width = op.getType().getIntOrFloatBitWidth();

    // Lower a zero-width value to a constant.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
      return success();
    }

    // Extract individual bits from the operands.
    SmallVector<Value> aBits = extractBits(rewriter, a);
    SmallVector<Value> bBits = extractBits(rewriter, b);

    auto falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

    // Generate partial products.
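    // Row i is (a << i) masked by b[i]: i constant-zero LSBs followed by
    // a[j] & b[i] for the remaining width - i columns, as in long
    // multiplication.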
    SmallVector<SmallVector<Value>> partialProducts;
    partialProducts.reserve(width);
    for (unsigned i = 0; i < width; ++i) {
      SmallVector<Value> row(i, falseValue);
      row.reserve(width);
      // Generate the partial product bits.
      for (unsigned j = 0; i + j < width; ++j)
        row.push_back(
            rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));

      partialProducts.push_back(row);
    }

    // If the width is 1, we are done.
    if (width == 1) {
      rewriter.replaceOp(op, partialProducts[0][0]);
      return success();
    }

    // Wallace tree reduction: reduce the partial products to two addends.
    datapath::CompressorTree comp(width, partialProducts, loc);
    auto addends = comp.compressToHeight(rewriter, 2);

    // Sum the two addends using a carry-propagate adder.
    auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
    replaceOpAndCopyNamehint(rewriter, op, newAdd);
    return success();
  }
};

template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
  DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
      : OpConversionPattern<OpTy>(context),
        maxEmulationUnknownBits(maxEmulationUnknownBits) {
    assert(maxEmulationUnknownBits < 32 &&
           "maxEmulationUnknownBits must be less than 32");
  }
  const int64_t maxEmulationUnknownBits;
};
1026
1027struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
1028 using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
1029 LogicalResult
1030 matchAndRewrite(DivUOp op, OpAdaptor adaptor,
1031 ConversionPatternRewriter &rewriter) const override {
1032 // Check if the divisor is a power of two.
1033 if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
1034 return success();
1035
1036 // When rhs is not power of two and the number of unknown bits are small,
1037 // create a mux tree that emulates all possible cases.
1039 rewriter, maxEmulationUnknownBits, op,
1040 [](const APInt &lhs, const APInt &rhs) {
1041 // Division by zero is undefined, just return zero.
1042 if (rhs.isZero())
1043 return APInt::getZero(rhs.getBitWidth());
1044 return lhs.udiv(rhs);
1045 });
1046 }
1047};
1048
1049struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
1050 using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
1051 LogicalResult
1052 matchAndRewrite(ModUOp op, OpAdaptor adaptor,
1053 ConversionPatternRewriter &rewriter) const override {
1054 // Check if the divisor is a power of two.
1055 if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
1056 return success();
1057
1058 // When rhs is not power of two and the number of unknown bits are small,
1059 // create a mux tree that emulates all possible cases.
1061 rewriter, maxEmulationUnknownBits, op,
1062 [](const APInt &lhs, const APInt &rhs) {
1063 // Division by zero is undefined, just return zero.
1064 if (rhs.isZero())
1065 return APInt::getZero(rhs.getBitWidth());
1066 return lhs.urem(rhs);
1067 });
1068 }
1069};

struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
  using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed division lowering at least for powers of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.sdiv(rhs);
        });
  }
};

struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
  using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed modulus lowering at least for powers of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined; just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.srem(rhs);
        });
  }
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
  using OpConversionPattern<ICmpOp>::OpConversionPattern;

  // Simple comparator for small bit widths.
  static Value constructRippleCarry(Location loc, Value a, Value b,
                                    bool includeEq,
                                    ConversionPatternRewriter &rewriter) {
    // Construct the following unsigned comparison expressions:
    // a <= b ==> (~a[n] & b[n]) | ((a[n] == b[n]) & (a[n-1:0] <= b[n-1:0]))
    // a <  b ==> (~a[n] & b[n]) | ((a[n] == b[n]) & (a[n-1:0] <  b[n-1:0]))
    auto aBits = extractBits(rewriter, a);
    auto bBits = extractBits(rewriter, b);
    Value acc = hw::ConstantOp::create(rewriter, loc, APInt(1, includeEq));

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      auto aBitXorBBit =
          rewriter.createOrFold<comb::XorOp>(loc, aBit, bBit, true);
      auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
          loc, aBitXorBBit, true);
      auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
          loc, aBit, bBit, true, false);

      auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
          loc, ValueRange{aEqualB, acc}, true);
      acc = rewriter.createOrFold<comb::OrOp>(loc, pred, aBitAndBBit, true);
    }
    return acc;
  }

  // Compute a prefix comparison using a parallel prefix algorithm.
  // Note: This generates all intermediate prefix values even though we only
  // need the final result. Optimizing this to skip intermediate computations
  // is non-trivial because each iteration depends on results from previous
  // iterations. We rely on DCE passes to remove unused operations.
  // TODO: Lazily compute only the required prefix values. Kogge-Stone is
  // already implemented in a lazy manner below, but other architectures can
  // also be optimized.
  static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
                                       Location loc, SmallVector<Value> pPrefix,
                                       SmallVector<Value> gPrefix,
                                       bool includeEq, AdderArchitecture arch) {
    auto width = pPrefix.size();
    Value finalGroup, finalPropagate;
    // Apply the appropriate prefix tree algorithm.
    switch (arch) {
    case AdderArchitecture::RippleCarry:
      llvm_unreachable("Ripple-Carry should be handled separately");
      break;
    case AdderArchitecture::Sklanskey: {
      lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
      finalGroup = gPrefix[width - 1];
      finalPropagate = pPrefix[width - 1];
      break;
    }
    case AdderArchitecture::KoggeStone:
      // Use the lazy Kogge-Stone implementation to avoid computing all
      // intermediate prefix values.
      std::tie(finalPropagate, finalGroup) =
          LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
              .getFinal(width - 1);
      break;
    case AdderArchitecture::BrentKung: {
      lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
      finalGroup = gPrefix[width - 1];
      finalPropagate = pPrefix[width - 1];
      break;
    }
    }

    // Final result: `finalGroup` gives us "a < b".
    if (includeEq) {
      // a <= b iff (a < b) OR (a == b)
      // a == b iff `finalPropagate` (all bits are equal)
      return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
    }
    // a < b iff `finalGroup`
    return finalGroup;
  }

  // Construct an unsigned comparator using either a ripple-carry or a
  // parallel-prefix architecture. The comparison uses a parallel prefix tree
  // as an internal component, so the `AdderArchitecture` enum is reused to
  // select the architecture.
  static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
                                        Value b, bool isLess, bool includeEq,
                                        ConversionPatternRewriter &rewriter) {
    // Normalize to a less-than comparison by swapping the operands.
    if (!isLess)
      std::swap(a, b);
    auto width = a.getType().getIntOrFloatBitWidth();

    // Check if the architecture is specified by an attribute.
    auto arch = determineAdderArch(op, width);
    if (arch == AdderArchitecture::RippleCarry)
      return constructRippleCarry(loc, a, b, includeEq, rewriter);

    // For larger widths, use a parallel prefix tree.
    auto aBits = extractBits(rewriter, a);
    auto bBits = extractBits(rewriter, b);

    // For the comparison, we compute:
    // - Equal bits: eq_i = ~(a_i ^ b_i)
    // - Greater bits: gt_i = ~a_i & b_i (a_i < b_i)
    // - Propagate: p_i = eq_i (equality propagates)
    // - Generate: g_i = gt_i (greater-than generates)
    SmallVector<Value> eq, gt;
    eq.reserve(width);
    gt.reserve(width);

    auto one =
        hw::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(1), 1);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // eq_i = ~(a_i ^ b_i) = a_i == b_i
      auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
      eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));

      // gt_i = ~a_i & b_i = a_i < b_i
      auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
      gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
    }

    return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
                                   includeEq, arch);
  }

  LogicalResult
  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();

    switch (op.getPredicate()) {
    default:
      return failure();

    case ICmpPredicate::eq:
    case ICmpPredicate::ceq: {
      // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      auto xorBits = extractBits(rewriter, xorOp);
      SmallVector<bool> allInverts(xorBits.size(), true);
      replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
          rewriter, op, xorBits, allInverts);
      return success();
    }

    case ICmpPredicate::ne:
    case ICmpPredicate::cne: {
      // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
          rewriter, op, extractBits(rewriter, xorOp), true);
      return success();
    }

    case ICmpPredicate::uge:
    case ICmpPredicate::ugt:
    case ICmpPredicate::ule:
    case ICmpPredicate::ult: {
      bool isLess = op.getPredicate() == ICmpPredicate::ult ||
                    op.getPredicate() == ICmpPredicate::ule;
      bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
                       op.getPredicate() == ICmpPredicate::ule;
      replaceOpAndCopyNamehint(rewriter, op,
                               constructUnsignedCompare(op, op.getLoc(), lhs,
                                                        rhs, isLess, includeEq,
                                                        rewriter));
      return success();
    }
    case ICmpPredicate::slt:
    case ICmpPredicate::sle:
    case ICmpPredicate::sgt:
    case ICmpPredicate::sge: {
      if (lhs.getType().getIntOrFloatBitWidth() == 0)
        return rewriter.notifyMatchFailure(
            op.getLoc(), "i0 signed comparison is unsupported");
      bool isLess = op.getPredicate() == ICmpPredicate::slt ||
                    op.getPredicate() == ICmpPredicate::sle;
      bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
                       op.getPredicate() == ICmpPredicate::sle;

      // Get the sign bits.
      auto signA = extractMSB(rewriter, lhs);
      auto signB = extractMSB(rewriter, rhs);
      auto aRest = extractOtherThanMSB(rewriter, lhs);
      auto bRest = extractOtherThanMSB(rewriter, rhs);

      // Compare the magnitudes (all bits except the sign).
      auto sameSignResult = constructUnsignedCompare(
          op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);

      // XOR of the signs: true if the signs differ.
      auto signsDiffer =
          comb::XorOp::create(rewriter, op.getLoc(), signA, signB);

      // Result when the signs differ.
      Value diffSignResult = isLess ? signA : signB;

      // Final result: choose based on whether the signs differ.
      replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
          rewriter, op, signsDiffer, diffSignResult, sameSignResult);
      return success();
    }
    }
  }
};

struct CombParityOpConversion : OpConversionPattern<ParityOp> {
  using OpConversionPattern<ParityOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Parity is the XOR of all bits.
    replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
        rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
    return success();
  }
};

struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
  using OpConversionPattern<comb::ShlOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/true>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // The padding is 0 for a left shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the LSB.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
  using OpConversionPattern<comb::ShrUOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // The padding is 0 for a right shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the MSB.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
  using OpConversionPattern<comb::ShrSOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "i0 signed shift is unsupported");
    auto lhs = adaptor.getLhs();
    // Get the sign bit.
    auto sign =
        rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

    // NOTE: The max shift amount is width - 1: shifting by width - 1 or more
    // yields the sign bit replicated across the full width, which is exactly
    // the out-of-bounds value.
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
        /*getPadding=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
                                                          index + 1);
        },
        /*getExtract=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index - 1);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

} // namespace

//===----------------------------------------------------------------------===//
// Convert Comb to Synth pass
//===----------------------------------------------------------------------===//

namespace {
struct ConvertCombToSynthPass
    : public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
  void runOnOperation() override;
  using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
};
} // namespace

static void
populateCombToAIGConversionPatterns(RewritePatternSet &patterns,
                                    uint32_t maxEmulationUnknownBits,
                                    bool lowerToMIG) {
  patterns.add<
      // Bitwise Logical Ops
      CombAndOpConversion, CombXorOpConversion, CombMuxOpConversion,
      CombParityOpConversion,
      // Arithmetic Ops
      CombMulOpConversion, CombICmpOpConversion,
      // Shift Ops
      CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
      // Variadic ops that must be lowered to binary operations
      CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
      CombLowerVariadicOp<MulOp>>(patterns.getContext());

  patterns.add(comb::convertSubToAdd);

  if (lowerToMIG) {
    patterns.add<CombOrToMIGConversion, CombLowerVariadicOp<OrOp>,
                 AndInverterToMIGConversion,
                 CombAddOpConversion</*lowerToMIG=*/true>>(
        patterns.getContext());
  } else {
    patterns.add<CombOrToAIGConversion, MajorityInverterToAIGConversion,
                 CombAddOpConversion</*lowerToMIG=*/false>>(
        patterns.getContext());
  }

  // Add div/mod patterns with a threshold given by the pass option.
  patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
               CombModSOpConversion>(patterns.getContext(),
                                     maxEmulationUnknownBits);
}

void ConvertCombToSynthPass::runOnOperation() {
  ConversionTarget target(getContext());

  // Comb is the source dialect.
  target.addIllegalDialect<comb::CombDialect>();
  // Keep data movement operations like Extract, Concat and Replicate.
  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
                    hw::BitcastOp, hw::ConstantOp>();

  // Treat array operations as illegal. Strictly speaking, array operations
  // other than an array get with a non-constant index are legal in AIG, but
  // array types prevent a number of optimizations, so lower them to integer
  // operations as well. The HWAggregateToComb pass must be run before this
  // pass.
  target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
                      hw::AggregateConstantOp>();

  target.addLegalDialect<synth::SynthDialect>();

  if (targetIR == CombToSynthTargetIR::AIG) {
    // AIG is the target dialect.
    target.addIllegalOp<synth::mig::MajorityInverterOp>();
  } else if (targetIR == CombToSynthTargetIR::MIG) {
    target.addIllegalOp<synth::aig::AndInverterOp>();
  }

  // If additional legal ops are specified, add them to the target.
  if (!additionalLegalOps.empty())
    for (const auto &opName : additionalLegalOps)
      target.addLegalOp(OperationName(opName, &getContext()));

  RewritePatternSet patterns(&getContext());
  populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits,
                                      targetIR == CombToSynthTargetIR::MIG);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();
}