CIRCT 23.0.0git
Loading...
Searching...
No Matches
CombToSynth.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main Comb to Synth Conversion Pass Implementation.
10//
// High-level Comb Operations
//           |
//           v
// +-------------------+
// | and, or, xor, mux |
// +---------+---------+
//           |
//        +-----+
//        | AIG |
//        +-----+
21//
22//===----------------------------------------------------------------------===//
23
31#include "mlir/Pass/Pass.h"
32#include "mlir/Transforms/DialectConversion.h"
33#include "llvm/ADT/APInt.h"
34#include "llvm/ADT/PointerUnion.h"
35#include "llvm/Support/Debug.h"
36#include <array>
37
38#define DEBUG_TYPE "comb-to-synth"
39
40namespace circt {
41#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
42#include "circt/Conversion/Passes.h.inc"
43} // namespace circt
44
45using namespace circt;
46using namespace comb;
47
48//===----------------------------------------------------------------------===//
49// Utility Functions
50//===----------------------------------------------------------------------===//
51
52// A wrapper for comb::extractBits that returns a SmallVector<Value>.
53static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
54 SmallVector<Value> bits;
55 comb::extractBits(builder, val, bits);
56 return bits;
57}
58
59// Construct a mux tree for shift operations. `isLeftShift` controls the
60// direction of the shift operation and is used to determine order of the
61// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
62// to get the padding and extracted bits for each shift amount. `getPadding`
63// could return a nullptr as i0 value but except for that, these callbacks must
64// return a valid value for each shift amount in the range [0, maxShiftAmount].
65// The value for `maxShiftAmount` is used as the out-of-bounds value.
66template <bool isLeftShift>
67static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
68 Value shiftAmount, int64_t maxShiftAmount,
69 llvm::function_ref<Value(int64_t)> getPadding,
70 llvm::function_ref<Value(int64_t)> getExtract) {
71 // Extract individual bits from shift amount
72 auto bits = extractBits(rewriter, shiftAmount);
73
74 // Create nodes for each possible shift amount
75 SmallVector<Value> nodes;
76 nodes.reserve(maxShiftAmount);
77 for (int64_t i = 0; i < maxShiftAmount; ++i) {
78 Value extract = getExtract(i);
79 Value padding = getPadding(i);
80
81 if (!padding) {
82 nodes.push_back(extract);
83 continue;
84 }
85
86 // Concatenate extracted bits with padding
87 if (isLeftShift)
88 nodes.push_back(
89 rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
90 else
91 nodes.push_back(
92 rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
93 }
94
95 // Create out-of-bounds value
96 auto outOfBoundsValue = getPadding(maxShiftAmount);
97 assert(outOfBoundsValue && "outOfBoundsValue must be valid");
98
99 // Construct mux tree for shift operation
100 auto result =
101 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
102
103 // Add bounds checking
104 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
105 loc, ICmpPredicate::ult, shiftAmount,
106 hw::ConstantOp::create(rewriter, loc, shiftAmount.getType(),
107 maxShiftAmount));
108
109 return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
110 outOfBoundsValue);
111}
112
113// Return a majority function implemented with Comb operations. `carry` has
114// slightly smaller depth than the other inputs.
115static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a,
116 Value b, Value carry) {
117 // maj(a, b, c) = (c & (a ^ b)) | (a & b)
118 auto aXnorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b}, true);
119 auto andOp =
120 comb::AndOp::create(rewriter, loc, ValueRange{carry, aXnorB}, true);
121 auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b}, true);
122 return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB}, true);
123}
124
125static Value extractMSB(OpBuilder &builder, Value val) {
126 return builder.createOrFold<comb::ExtractOp>(
127 val.getLoc(), val, val.getType().getIntOrFloatBitWidth() - 1, 1);
128}
129
130static Value extractOtherThanMSB(OpBuilder &builder, Value val) {
131 return builder.createOrFold<comb::ExtractOp>(
132 val.getLoc(), val, 0, val.getType().getIntOrFloatBitWidth() - 1);
133}
134
135namespace {
136// Holds either an SSA Value or a constant IntegerAttr, so that the elements
// of a flattened concat can be stored in one list whether or not they are
// compile-time constants.
137using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
138} // namespace
139
140// Return the number of unknown bits and populate the concatenated values.
142 Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
143 // Constant or zero width value are all known.
144 if (value.getType().isInteger(0))
145 return 0;
146
147 // Recursively count unknown bits for concat.
148 if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
149 int64_t totalUnknownBits = 0;
150 for (auto concatInput : llvm::reverse(concat.getInputs())) {
151 auto unknownBits =
152 getNumUnknownBitsAndPopulateValues(concatInput, values);
153 if (unknownBits < 0)
154 return unknownBits;
155 totalUnknownBits += unknownBits;
156 }
157 return totalUnknownBits;
158 }
159
160 // Constant value is known.
161 if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
162 values.push_back(constant.getValueAttr());
163 return 0;
164 }
165
166 // Consider other operations as unknown bits.
167 // TODO: We can handle replicate, extract, etc.
168 values.push_back(value);
169 return hw::getBitWidth(value.getType());
170}
171
172// Return a value that substitutes the unknown bits with the mask.
173static APInt
175 llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
176 uint32_t mask) {
177 uint32_t bitPos = 0, unknownPos = 0;
178 APInt result(width, 0);
179 for (auto constantOrValue : constantOrValues) {
180 int64_t elemWidth;
181 if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
182 elemWidth = constant.getValue().getBitWidth();
183 result.insertBits(constant.getValue(), bitPos);
184 } else {
185 elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
186 assert(elemWidth >= 0 && "unknown bit width");
187 assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
188 // Create a mask for the unknown bits.
189 uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
190 result.insertBits(APInt(elemWidth, usedBits), bitPos);
191 unknownPos += elemWidth;
192 }
193 bitPos += elemWidth;
194 }
195
196 return result;
197}
198
199// Emulate a binary operation with unknown bits using a table lookup.
200// This function enumerates all possible combinations of unknown bits and
201// emulates the operation for each combination.
202static LogicalResult emulateBinaryOpForUnknownBits(
203 ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
204 Operation *op,
205 llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
206 SmallVector<ConstantOrValue> lhsValues, rhsValues;
207
208 assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
209 "op must be a single result binary operation");
210
211 auto lhs = op->getOperand(0);
212 auto rhs = op->getOperand(1);
213 auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
214 auto loc = op->getLoc();
215 auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
216 auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);
217
218 // If unknown bit width is detected, abort the lowering.
219 if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
220 return failure();
221
222 int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
223 if (totalUnknownBits > maxEmulationUnknownBits)
224 return failure();
225
226 SmallVector<Value> emulatedResults;
227 emulatedResults.reserve(1 << totalUnknownBits);
228
229 // Emulate all possible cases.
230 DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
231 auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
232 auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
233 auto it = constantPool.find(attr);
234 if (it != constantPool.end())
235 return it->second;
236 auto constant = hw::ConstantOp::create(rewriter, loc, value);
237 constantPool[attr] = constant;
238 return constant;
239 };
240
241 for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
242 lhsMask < lhsMaskEnd; ++lhsMask) {
243 APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
244 for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
245 rhsMask < rhsMaskEnd; ++rhsMask) {
246 APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
247 // Emulate.
248 emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
249 }
250 }
251
252 // Create selectors for mux tree.
253 SmallVector<Value> selectors;
254 selectors.reserve(totalUnknownBits);
255 for (auto &concatedValues : {rhsValues, lhsValues})
256 for (auto valueOrConstant : concatedValues) {
257 auto value = dyn_cast<Value>(valueOrConstant);
258 if (!value)
259 continue;
260 extractBits(rewriter, value, selectors);
261 }
262
263 assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
264 "number of selectors must match");
265 auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
266 getConstant(APInt::getZero(width)));
267
268 replaceOpAndCopyNamehint(rewriter, op, muxed);
269 return success();
270}
271
272//===----------------------------------------------------------------------===//
273// Conversion patterns
274//===----------------------------------------------------------------------===//
275
276namespace {
277
278/// Lower a comb::AndOp operation to synth::aig::AndInverterOp
279struct CombAndOpConversion : OpConversionPattern<AndOp> {
281
282 LogicalResult
283 matchAndRewrite(AndOp op, OpAdaptor adaptor,
284 ConversionPatternRewriter &rewriter) const override {
285 SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
286 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
287 rewriter, op, adaptor.getInputs(), nonInverts);
288 return success();
289 }
290};
291
292/// Lower a comb::OrOp operation to synth::aig::AndInverterOp with invert flags
293struct CombOrToAIGConversion : OpConversionPattern<OrOp> {
295
296 LogicalResult
297 matchAndRewrite(OrOp op, OpAdaptor adaptor,
298 ConversionPatternRewriter &rewriter) const override {
299 // Implement Or using And and invert flags: a | b = ~(~a & ~b)
300 SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
301 auto andOp = synth::aig::AndInverterOp::create(
302 rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
303 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
304 rewriter, op, andOp,
305 /*invert=*/true);
306 return success();
307 }
308};
309
310struct CombXorOpToSynthConversion : OpConversionPattern<XorOp> {
312
313 LogicalResult
314 matchAndRewrite(XorOp op, OpAdaptor adaptor,
315 ConversionPatternRewriter &rewriter) const override {
316 SmallVector<bool> inverted(adaptor.getInputs().size(), false);
317 replaceOpWithNewOpAndCopyNamehint<synth::XorInverterOp>(
318 rewriter, op, adaptor.getInputs(), inverted);
319 return success();
320 }
321};
322
323/// Lower a synth::XorOp operation to AIG operations
324struct SynthXorInverterOpConversion
325 : OpConversionPattern<synth::XorInverterOp> {
326 using OpConversionPattern<synth::XorInverterOp>::OpConversionPattern;
327
328 LogicalResult
329 matchAndRewrite(synth::XorInverterOp op, OpAdaptor adaptor,
330 ConversionPatternRewriter &rewriter) const override {
331 if (op.getNumOperands() != 2)
332 return failure();
333 // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)
334
335 // (a | b) = ~(~a & ~b)
336 // (~a | ~b) = ~(a & b)
337 auto inputs = adaptor.getInputs();
338 auto allNotInverts = op.getInverted();
339 std::array<bool, 2> allInverts = {!allNotInverts[0], !allNotInverts[1]};
340
341 auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
342 inputs, allInverts);
343 auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
344 inputs, allNotInverts);
345
346 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
347 rewriter, op, notAAndNotB, aAndB,
348 /*lhs_invert=*/true,
349 /*rhs_invert=*/true);
350 return success();
351 }
352};
353
354/// Lower a comb::MuxOp operation to synth::MuxInverterOps.
355struct CombMuxOpToSynthConversion : OpConversionPattern<MuxOp> {
357
358 LogicalResult
359 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
360 ConversionPatternRewriter &rewriter) const override {
361 Value cond = adaptor.getCond();
362 Value trueVal = adaptor.getTrueValue();
363 Value falseVal = adaptor.getFalseValue();
364
365 if (!op.getType().isInteger()) {
366 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
367 trueVal =
368 hw::BitcastOp::create(rewriter, op.getLoc(), widthType, trueVal);
369 falseVal =
370 hw::BitcastOp::create(rewriter, op.getLoc(), widthType, falseVal);
371 }
372
373 if (!trueVal.getType().isInteger(1))
374 cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
375 cond);
376
377 Value result = synth::MuxInverterOp::create(rewriter, op.getLoc(), cond,
378 trueVal, falseVal);
379
380 if (result.getType() != op.getType())
381 result =
382 hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
383
384 replaceOpAndCopyNamehint(rewriter, op, result);
385 return success();
386 }
387};
388
389/// Lower a synth::MuxInverterOp operation to AIG operations.
390struct SynthMuxInverterOpConversion
391 : OpConversionPattern<synth::MuxInverterOp> {
392 using OpConversionPattern<synth::MuxInverterOp>::OpConversionPattern;
393
394 LogicalResult
395 matchAndRewrite(synth::MuxInverterOp op, OpAdaptor adaptor,
396 ConversionPatternRewriter &rewriter) const override {
397 auto inputs = adaptor.getInputs();
398 auto inverted = op.getInverted();
399
400 auto lhs = synth::aig::AndInverterOp::create(
401 rewriter, op.getLoc(), inputs[0], inputs[1], inverted[0], inverted[1]);
402
403 auto rhs = synth::aig::AndInverterOp::create(
404 rewriter, op.getLoc(), inputs[0], inputs[2], !inverted[0], inverted[2]);
405
406 auto nand = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), lhs,
407 rhs, true, true);
408 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(rewriter, op,
409 nand, true);
410 return success();
411 }
412};
413
414template <typename OpTy>
415struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
417 using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
418 LogicalResult
419 matchAndRewrite(OpTy op, OpAdaptor adaptor,
420 ConversionPatternRewriter &rewriter) const override {
421 auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
422 replaceOpAndCopyNamehint(rewriter, op, result);
423 return success();
424 }
425
426 static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
427 ConversionPatternRewriter &rewriter) {
428 Value lhs, rhs;
429 switch (operands.size()) {
430 case 0:
431 llvm_unreachable("cannot be called with empty operand range");
432 break;
433 case 1:
434 return operands[0];
435 case 2:
436 lhs = operands[0];
437 rhs = operands[1];
438 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
439 default:
440 auto firstHalf = operands.size() / 2;
441 lhs =
442 lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
443 rhs =
444 lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
445 return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
446 }
447 }
448};
449
450//===----------------------------------------------------------------------===//
451// Adder Architecture Selection
452//===----------------------------------------------------------------------===//
453
454enum AdderArchitecture { RippleCarry, Sklanskey, KoggeStone, BrentKung };
455AdderArchitecture determineAdderArch(Operation *op, int64_t width) {
456 auto strAttr = op->getAttrOfType<StringAttr>("synth.test.arch");
457 if (strAttr) {
458 return llvm::StringSwitch<AdderArchitecture>(strAttr.getValue())
459 .Case("SKLANSKEY", Sklanskey)
460 .Case("KOGGE-STONE", KoggeStone)
461 .Case("BRENT-KUNG", BrentKung)
462 .Case("RIPPLE-CARRY", RippleCarry);
463 }
464 // Determine using width as a heuristic.
465 // TODO: Perform a more thorough analysis to motivate the choices or
466 // implement an adder synthesis algorithm to construct an optimal adder
467 // under the given timing constraints - see the work of Zimmermann
468
469 // For very small adders, overhead of a parallel prefix adder is likely not
470 // worth it.
471 if (width < 8)
472 return AdderArchitecture::RippleCarry;
473
474 // Sklanskey is a good compromise for high-performance, but has high fanout
475 // which may lead to wiring congestion for very large adders.
476 if (width <= 32)
477 return AdderArchitecture::Sklanskey;
478
479 // Kogge-Stone uses greater area than Sklanskey but has lower fanout thus
480 // may be preferable for larger adders.
481 return AdderArchitecture::KoggeStone;
482}
483
484//===----------------------------------------------------------------------===//
485// Parallel Prefix Tree
486//===----------------------------------------------------------------------===//
487
488// Implement the Kogge-Stone parallel prefix tree
489// Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
490// Slightly better delay than Brent-Kung, but more area.
491void lowerKoggeStonePrefixTree(OpBuilder &builder, Location loc,
492 SmallVector<Value> &pPrefix,
493 SmallVector<Value> &gPrefix) {
494
495 auto width = static_cast<int64_t>(pPrefix.size());
496 assert(width == static_cast<int64_t>(gPrefix.size()));
497 SmallVector<Value> pPrefixNew = pPrefix;
498 SmallVector<Value> gPrefixNew = gPrefix;
499
500 // Kogge-Stone parallel prefix computation
501 for (int64_t stride = 1; stride < width; stride *= 2) {
502
503 for (int64_t i = stride; i < width; ++i) {
504 int64_t j = i - stride;
505
506 // Group generate: g_i OR (p_i AND g_j)
507 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
508 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
509
510 // Group propagate: p_i AND p_j
511 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
512 }
513
514 pPrefix = pPrefixNew;
515 gPrefix = gPrefixNew;
516 }
517
518 LLVM_DEBUG({
519 int64_t stage = 0;
520 for (int64_t stride = 1; stride < width; stride *= 2) {
521 llvm::dbgs()
522 << "--------------------------------------- Kogge-Stone Stage "
523 << stage << "\n";
524 for (int64_t i = stride; i < width; ++i) {
525 int64_t j = i - stride;
526 // Group generate: g_i OR (p_i AND g_j)
527 llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
528 << " OR (P" << i << stage << " AND G" << j << stage
529 << ")\n";
530
531 // Group propagate: p_i AND p_j
532 llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
533 << " AND P" << j << stage << "\n";
534 }
535 ++stage;
536 }
537 });
538}
539
540// Implement the Sklansky parallel prefix tree
541// High fan-out, low depth, low area
542void lowerSklanskeyPrefixTree(OpBuilder &builder, Location loc,
543 SmallVector<Value> &pPrefix,
544 SmallVector<Value> &gPrefix) {
545 auto width = static_cast<int64_t>(pPrefix.size());
546 assert(width == static_cast<int64_t>(gPrefix.size()));
547 SmallVector<Value> pPrefixNew = pPrefix;
548 SmallVector<Value> gPrefixNew = gPrefix;
549 for (int64_t stride = 1; stride < width; stride *= 2) {
550 for (int64_t i = stride; i < width; i += 2 * stride) {
551 for (int64_t k = 0; k < stride && i + k < width; ++k) {
552 int64_t idx = i + k;
553 int64_t j = i - 1;
554
555 // Group generate: g_idx OR (p_idx AND g_j)
556 Value andPG =
557 comb::AndOp::create(builder, loc, pPrefix[idx], gPrefix[j]);
558 gPrefixNew[idx] = comb::OrOp::create(builder, loc, gPrefix[idx], andPG);
559
560 // Group propagate: p_idx AND p_j
561 pPrefixNew[idx] =
562 comb::AndOp::create(builder, loc, pPrefix[idx], pPrefix[j]);
563 }
564 }
565
566 pPrefix = pPrefixNew;
567 gPrefix = gPrefixNew;
568 }
569
570 LLVM_DEBUG({
571 int64_t stage = 0;
572 for (int64_t stride = 1; stride < width; stride *= 2) {
573 llvm::dbgs() << "--------------------------------------- Sklanskey Stage "
574 << stage << "\n";
575 for (int64_t i = stride; i < width; i += 2 * stride) {
576 for (int64_t k = 0; k < stride && i + k < width; ++k) {
577 int64_t idx = i + k;
578 int64_t j = i - 1;
579 // Group generate: g_i OR (p_i AND g_j)
580 llvm::dbgs() << "G" << idx << stage + 1 << " = G" << idx << stage
581 << " OR (P" << idx << stage << " AND G" << j << stage
582 << ")\n";
583
584 // Group propagate: p_i AND p_j
585 llvm::dbgs() << "P" << idx << stage + 1 << " = P" << idx << stage
586 << " AND P" << j << stage << "\n";
587 }
588 }
589 ++stage;
590 }
591 });
592}
593
594// Implement the Brent-Kung parallel prefix tree
595// Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
596// Slightly worse delay than Kogge-Stone, but less area.
597void lowerBrentKungPrefixTree(OpBuilder &builder, Location loc,
598 SmallVector<Value> &pPrefix,
599 SmallVector<Value> &gPrefix) {
600 auto width = static_cast<int64_t>(pPrefix.size());
601 assert(width == static_cast<int64_t>(gPrefix.size()));
602 SmallVector<Value> pPrefixNew = pPrefix;
603 SmallVector<Value> gPrefixNew = gPrefix;
604 // Brent-Kung parallel prefix computation
605 // Forward phase
606 int64_t stride;
607 for (stride = 1; stride < width; stride *= 2) {
608 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
609 int64_t j = i - stride;
610
611 // Group generate: g_i OR (p_i AND g_j)
612 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
613 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
614
615 // Group propagate: p_i AND p_j
616 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
617 }
618 pPrefix = pPrefixNew;
619 gPrefix = gPrefixNew;
620 }
621
622 // Backward phase
623 for (; stride > 0; stride /= 2) {
624 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
625 int64_t j = i - stride;
626
627 // Group generate: g_i OR (p_i AND g_j)
628 Value andPG = comb::AndOp::create(builder, loc, pPrefix[i], gPrefix[j]);
629 gPrefixNew[i] = comb::OrOp::create(builder, loc, gPrefix[i], andPG);
630
631 // Group propagate: p_i AND p_j
632 pPrefixNew[i] = comb::AndOp::create(builder, loc, pPrefix[i], pPrefix[j]);
633 }
634 pPrefix = pPrefixNew;
635 gPrefix = gPrefixNew;
636 }
637
638 LLVM_DEBUG({
639 int64_t stage = 0;
640 for (stride = 1; stride < width; stride *= 2) {
641 llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
642 << stage << " : Stride " << stride << "\n";
643 for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
644 int64_t j = i - stride;
645
646 // Group generate: g_i OR (p_i AND g_j)
647 llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
648 << " OR (P" << i << stage << " AND G" << j << stage
649 << ")\n";
650
651 // Group propagate: p_i AND p_j
652 llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
653 << " AND P" << j << stage << "\n";
654 }
655 ++stage;
656 }
657
658 for (; stride > 0; stride /= 2) {
659 if (stride * 3 - 1 < width)
660 llvm::dbgs() << "--------------------------------------- Brent-Kung BW "
661 << stage << " : Stride " << stride << "\n";
662
663 for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
664 int64_t j = i - stride;
665
666 // Group generate: g_i OR (p_i AND g_j)
667 llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
668 << " OR (P" << i << stage << " AND G" << j << stage
669 << ")\n";
670
671 // Group propagate: p_i AND p_j
672 llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
673 << " AND P" << j << stage << "\n";
674 }
675 --stage;
676 }
677 });
678}
679
680// TODO: Generalize to other parallel prefix trees.
681class LazyKoggeStonePrefixTree {
682public:
683 LazyKoggeStonePrefixTree(OpBuilder &builder, Location loc, int64_t width,
684 ArrayRef<Value> pPrefix, ArrayRef<Value> gPrefix)
685 : builder(builder), loc(loc), width(width) {
686 assert(width > 0 && "width must be positive");
687 for (int64_t i = 0; i < width; ++i)
688 prefixCache[{0, i}] = {pPrefix[i], gPrefix[i]};
689 }
690
691 // Get the final group and propagate values for bit i.
692 std::pair<Value, Value> getFinal(int64_t i) {
693 assert(i >= 0 && i < width && "i out of bounds");
694 // Final level is ceil(log2(width)) in Kogge-Stone.
695 return getGroupAndPropagate(llvm::Log2_64_Ceil(width), i);
696 }
697
698private:
699 // Recursively get the group and propagate values for bit i at level `level`.
700 // Level 0 is the initial level with the input propagate and generate values.
701 // Level n computes the group and propagate values for a stride of 2^(n-1).
702 // Uses memoization to cache intermediate results.
703 std::pair<Value, Value> getGroupAndPropagate(int64_t level, int64_t i);
704 OpBuilder &builder;
705 Location loc;
706 int64_t width;
707 DenseMap<std::pair<int64_t, int64_t>, std::pair<Value, Value>> prefixCache;
708};
709
710std::pair<Value, Value>
711LazyKoggeStonePrefixTree::getGroupAndPropagate(int64_t level, int64_t i) {
712 assert(i < width && "i out of bounds");
713 auto key = std::make_pair(level, i);
714 auto it = prefixCache.find(key);
715 if (it != prefixCache.end())
716 return it->second;
717
718 assert(level > 0 && "If the level is 0, we should have hit the cache");
719
720 int64_t previousStride = 1ULL << (level - 1);
721 if (i < previousStride) {
722 // No dependency, just copy from the previous level.
723 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
724 prefixCache[key] = {propagateI, generateI};
725 return prefixCache[key];
726 }
727 // Get the dependency index.
728 int64_t j = i - previousStride;
729 auto [propagateI, generateI] = getGroupAndPropagate(level - 1, i);
730 auto [propagateJ, generateJ] = getGroupAndPropagate(level - 1, j);
731 // Group generate: g_i OR (p_i AND g_j)
732 Value andPG = comb::AndOp::create(builder, loc, propagateI, generateJ);
733 Value newGenerate = comb::OrOp::create(builder, loc, generateI, andPG);
734 // Group propagate: p_i AND p_j
735 Value newPropagate =
736 comb::AndOp::create(builder, loc, propagateI, propagateJ);
737 prefixCache[key] = {newPropagate, newGenerate};
738 return prefixCache[key];
739}
740
741struct CombAddOpConversion : OpConversionPattern<AddOp> {
743
744 LogicalResult
745 matchAndRewrite(AddOp op, OpAdaptor adaptor,
746 ConversionPatternRewriter &rewriter) const override {
747 auto inputs = adaptor.getInputs();
748 // Lower only when there are two inputs.
749 // Variadic operands must be lowered in a different pattern.
750 if (inputs.size() != 2)
751 return failure();
752
753 auto width = op.getType().getIntOrFloatBitWidth();
754 // Skip a zero width value.
755 if (width == 0) {
756 replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
757 op.getType(), 0);
758 return success();
759 }
760
761 // Check if the architecture is specified by an attribute.
762 auto arch = determineAdderArch(op, width);
763 if (arch == AdderArchitecture::RippleCarry)
764 return lowerRippleCarryAdder(op, inputs, rewriter);
765 return lowerParallelPrefixAdder(op, inputs, rewriter);
766 }
767
768 // Implement a basic ripple-carry adder for small bitwidths.
769 LogicalResult
770 lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
771 ConversionPatternRewriter &rewriter) const {
772 auto width = op.getType().getIntOrFloatBitWidth();
773 // Implement a naive Ripple-carry full adder.
774 Value carry;
775
776 auto aBits = extractBits(rewriter, inputs[0]);
777 auto bBits = extractBits(rewriter, inputs[1]);
778 SmallVector<Value> results;
779 results.resize(width);
780 for (int64_t i = 0; i < width; ++i) {
781 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
782 if (carry)
783 xorOperands.push_back(carry);
784
785 // sum[i] = xor(carry[i-1], a[i], b[i])
786 // NOTE: The result is stored in reverse order.
787 results[width - i - 1] =
788 comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);
789
790 // If this is the last bit, we are done.
791 if (i == width - 1)
792 break;
793
794 // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
795 if (!carry) {
796 // This is the first bit, so the carry is the next carry.
797 carry = comb::AndOp::create(rewriter, op.getLoc(),
798 ValueRange{aBits[i], bBits[i]}, true);
799 continue;
800 }
801
802 carry = createMajorityFunction(rewriter, op.getLoc(), aBits[i], bBits[i],
803 carry);
804 }
805 LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
806 << width << "\n");
807
808 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
809 return success();
810 }
811
812 // Implement a parallel prefix adder - with Kogge-Stone or Brent-Kung trees
813 // Will introduce unused signals for the carry bits but these will be removed
814 // by the AIG pass.
815 LogicalResult
816 lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
817 ConversionPatternRewriter &rewriter) const {
818 auto width = op.getType().getIntOrFloatBitWidth();
819
820 auto aBits = extractBits(rewriter, inputs[0]);
821 auto bBits = extractBits(rewriter, inputs[1]);
822
823 // Construct propagate (p) and generate (g) signals
824 SmallVector<Value> p, g;
825 p.reserve(width);
826 g.reserve(width);
827
828 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
829 // p_i = a_i XOR b_i
830 p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
831 // g_i = a_i AND b_i
832 g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
833 }
834
835 LLVM_DEBUG({
836 llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
837 << "\n--------------------------------------- Init\n";
838
839 for (int64_t i = 0; i < width; ++i) {
840 // p_i = a_i XOR b_i
841 llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
842 // g_i = a_i AND b_i
843 llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
844 }
845 });
846
847 // Create copies of p and g for the prefix computation
848 SmallVector<Value> pPrefix = p;
849 SmallVector<Value> gPrefix = g;
850
851 // Check if the architecture is specified by an attribute.
852 auto arch = determineAdderArch(op, width);
853
854 switch (arch) {
855 case AdderArchitecture::RippleCarry:
856 llvm_unreachable("Ripple-Carry should be handled separately");
857 break;
858 case AdderArchitecture::Sklanskey:
859 lowerSklanskeyPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
860 break;
861 case AdderArchitecture::KoggeStone:
862 lowerKoggeStonePrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
863 break;
864 case AdderArchitecture::BrentKung:
865 lowerBrentKungPrefixTree(rewriter, op.getLoc(), pPrefix, gPrefix);
866 break;
867 }
868
869 // Generate result sum bits
870 // NOTE: The result is stored in reverse order.
871 SmallVector<Value> results;
872 results.resize(width);
873 // Sum bit 0 is just p[0] since carry_in = 0
874 results[width - 1] = p[0];
875
876 // For remaining bits, sum_i = p_i XOR g_(i-1)
877 // The carry into position i is the group generate from position i-1
878 for (int64_t i = 1; i < width; ++i)
879 results[width - 1 - i] =
880 comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);
881
882 replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
883
884 LLVM_DEBUG({
885 llvm::dbgs() << "--------------------------------------- Completion\n"
886 << "RES0 = P0\n";
887 for (int64_t i = 1; i < width; ++i)
888 llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
889 });
890
891 return success();
892 }
893};
894
895struct CombMulOpConversion : OpConversionPattern<MulOp> {
897 using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
898 LogicalResult
899 matchAndRewrite(MulOp op, OpAdaptor adaptor,
900 ConversionPatternRewriter &rewriter) const override {
901 if (adaptor.getInputs().size() != 2)
902 return failure();
903
904 Location loc = op.getLoc();
905 Value a = adaptor.getInputs()[0];
906 Value b = adaptor.getInputs()[1];
907 unsigned width = op.getType().getIntOrFloatBitWidth();
908
909 // Skip a zero width value.
910 if (width == 0) {
911 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
912 return success();
913 }
914
915 // Extract individual bits from operands
916 SmallVector<Value> aBits = extractBits(rewriter, a);
917 SmallVector<Value> bBits = extractBits(rewriter, b);
918
919 auto falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));
920
921 // Generate partial products
922 SmallVector<SmallVector<Value>> partialProducts;
923 partialProducts.reserve(width);
924 for (unsigned i = 0; i < width; ++i) {
925 SmallVector<Value> row(i, falseValue);
926 row.reserve(width);
927 // Generate partial product bits
928 for (unsigned j = 0; i + j < width; ++j)
929 row.push_back(
930 rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));
931
932 partialProducts.push_back(row);
933 }
934
935 // If the width is 1, we are done.
936 if (width == 1) {
937 rewriter.replaceOp(op, partialProducts[0][0]);
938 return success();
939 }
940
941 // Wallace tree reduction - reduce to two addends.
942 datapath::CompressorTree comp(width, partialProducts, loc);
943 auto addends = comp.compressToHeight(rewriter, 2);
944
945 // Sum the two addends using a carry-propagate adder
946 auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
947 replaceOpAndCopyNamehint(rewriter, op, newAdd);
948 return success();
949 }
950};
951
952template <typename OpTy>
953struct DivModOpConversionBase : OpConversionPattern<OpTy> {
954 DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
956 maxEmulationUnknownBits(maxEmulationUnknownBits) {
957 assert(maxEmulationUnknownBits < 32 &&
958 "maxEmulationUnknownBits must be less than 32");
959 }
960 const int64_t maxEmulationUnknownBits;
961};
962
963struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
964 using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
965 LogicalResult
966 matchAndRewrite(DivUOp op, OpAdaptor adaptor,
967 ConversionPatternRewriter &rewriter) const override {
968 // Check if the divisor is a power of two.
969 if (llvm::succeeded(comb::convertDivUByPowerOfTwo(op, rewriter)))
970 return success();
971
972 // When rhs is not power of two and the number of unknown bits are small,
973 // create a mux tree that emulates all possible cases.
975 rewriter, maxEmulationUnknownBits, op,
976 [](const APInt &lhs, const APInt &rhs) {
977 // Division by zero is undefined, just return zero.
978 if (rhs.isZero())
979 return APInt::getZero(rhs.getBitWidth());
980 return lhs.udiv(rhs);
981 });
982 }
983};
984
985struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
986 using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
987 LogicalResult
988 matchAndRewrite(ModUOp op, OpAdaptor adaptor,
989 ConversionPatternRewriter &rewriter) const override {
990 // Check if the divisor is a power of two.
991 if (llvm::succeeded(comb::convertModUByPowerOfTwo(op, rewriter)))
992 return success();
993
994 // When rhs is not power of two and the number of unknown bits are small,
995 // create a mux tree that emulates all possible cases.
997 rewriter, maxEmulationUnknownBits, op,
998 [](const APInt &lhs, const APInt &rhs) {
999 // Division by zero is undefined, just return zero.
1000 if (rhs.isZero())
1001 return APInt::getZero(rhs.getBitWidth());
1002 return lhs.urem(rhs);
1003 });
1004 }
1005};
1006
1007struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
1008 using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;
1009
1010 LogicalResult
1011 matchAndRewrite(DivSOp op, OpAdaptor adaptor,
1012 ConversionPatternRewriter &rewriter) const override {
1013 // Currently only lower with emulation.
1014 // TODO: Implement a signed division lowering at least for power of two.
1016 rewriter, maxEmulationUnknownBits, op,
1017 [](const APInt &lhs, const APInt &rhs) {
1018 // Division by zero is undefined, just return zero.
1019 if (rhs.isZero())
1020 return APInt::getZero(rhs.getBitWidth());
1021 return lhs.sdiv(rhs);
1022 });
1023 }
1024};
1025
1026struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
1027 using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
1028 LogicalResult
1029 matchAndRewrite(ModSOp op, OpAdaptor adaptor,
1030 ConversionPatternRewriter &rewriter) const override {
1031 // Currently only lower with emulation.
1032 // TODO: Implement a signed modulus lowering at least for power of two.
1034 rewriter, maxEmulationUnknownBits, op,
1035 [](const APInt &lhs, const APInt &rhs) {
1036 // Division by zero is undefined, just return zero.
1037 if (rhs.isZero())
1038 return APInt::getZero(rhs.getBitWidth());
1039 return lhs.srem(rhs);
1040 });
1041 }
1042};
1043
1044struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
1046
1047 // Simple comparator for small bit widths
1048 static Value constructRippleCarry(Location loc, Value a, Value b,
1049 bool includeEq,
1050 ConversionPatternRewriter &rewriter) {
1051 // Construct following unsigned comparison expressions.
1052 // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
1053 // a < b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] < b[n-1:0])
1054 auto aBits = extractBits(rewriter, a);
1055 auto bBits = extractBits(rewriter, b);
1056 Value acc = hw::ConstantOp::create(rewriter, loc, APInt(1, includeEq));
1057
1058 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
1059 auto aBitXorBBit =
1060 rewriter.createOrFold<comb::XorOp>(loc, aBit, bBit, true);
1061 auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
1062 loc, aBitXorBBit, true);
1063 auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
1064 loc, aBit, bBit, true, false);
1065
1066 auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
1067 loc, ValueRange{aEqualB, acc}, true);
1068 acc = rewriter.createOrFold<comb::OrOp>(loc, pred, aBitAndBBit, true);
1069 }
1070 return acc;
1071 }
1072
1073 // Compute prefix comparison using parallel prefix algorithm
1074 // Note: This generates all intermediate prefix values even though we only
1075 // need the final result. Optimizing this to skip intermediate computations
1076 // is non-trivial because each iteration depends on results from previous
1077 // iterations. We rely on DCE passes to remove unused operations.
1078 // TODO: Lazily compute only the required prefix values. Kogge-Stone is
1079 // already implemented in a lazy manner below, but other architectures can
1080 // also be optimized.
1081 static Value computePrefixComparison(ConversionPatternRewriter &rewriter,
1082 Location loc, SmallVector<Value> pPrefix,
1083 SmallVector<Value> gPrefix,
1084 bool includeEq, AdderArchitecture arch) {
1085 auto width = pPrefix.size();
1086 Value finalGroup, finalPropagate;
1087 // Apply the appropriate prefix tree algorithm
1088 switch (arch) {
1089 case AdderArchitecture::RippleCarry:
1090 llvm_unreachable("Ripple-Carry should be handled separately");
1091 break;
1092 case AdderArchitecture::Sklanskey: {
1093 lowerSklanskeyPrefixTree(rewriter, loc, pPrefix, gPrefix);
1094 finalGroup = gPrefix[width - 1];
1095 finalPropagate = pPrefix[width - 1];
1096 break;
1097 }
1098 case AdderArchitecture::KoggeStone:
1099 // Use lazy Kogge-Stone implementation to avoid computing all
1100 // intermediate prefix values.
1101 std::tie(finalPropagate, finalGroup) =
1102 LazyKoggeStonePrefixTree(rewriter, loc, width, pPrefix, gPrefix)
1103 .getFinal(width - 1);
1104 break;
1105 case AdderArchitecture::BrentKung: {
1106 lowerBrentKungPrefixTree(rewriter, loc, pPrefix, gPrefix);
1107 finalGroup = gPrefix[width - 1];
1108 finalPropagate = pPrefix[width - 1];
1109 break;
1110 }
1111 }
1112
1113 // Final result: `finalGroup` gives us "a < b"
1114 if (includeEq) {
1115 // a <= b iff (a < b) OR (a == b)
1116 // a == b iff `finalPropagate` (all bits are equal)
1117 return comb::OrOp::create(rewriter, loc, finalGroup, finalPropagate);
1118 }
1119 // a < b iff `finalGroup`
1120 return finalGroup;
1121 }
1122
1123 // Construct an unsigned comparator using either ripple-carry or
1124 // parallel-prefix architecture. Comparison uses parallel prefix tree as an
1125 // internal component, so use `AdderArchitecture` enum to select architecture.
1126 static Value constructUnsignedCompare(Operation *op, Location loc, Value a,
1127 Value b, bool isLess, bool includeEq,
1128 ConversionPatternRewriter &rewriter) {
1129 // Ensure a <= b by swapping for simplicity.
1130 if (!isLess)
1131 std::swap(a, b);
1132 auto width = a.getType().getIntOrFloatBitWidth();
1133
1134 // Check if the architecture is specified by an attribute.
1135 auto arch = determineAdderArch(op, width);
1136 if (arch == AdderArchitecture::RippleCarry)
1137 return constructRippleCarry(loc, a, b, includeEq, rewriter);
1138
1139 // For larger widths, use parallel prefix tree
1140 auto aBits = extractBits(rewriter, a);
1141 auto bBits = extractBits(rewriter, b);
1142
1143 // For comparison, we compute:
1144 // - Equal bits: eq_i = ~(a_i ^ b_i)
1145 // - Greater bits: gt_i = ~a_i & b_i (a_i < b_i)
1146 // - Propagate: p_i = eq_i (equality propagates)
1147 // - Generate: g_i = gt_i (greater-than generates)
1148 SmallVector<Value> eq, gt;
1149 eq.reserve(width);
1150 gt.reserve(width);
1151
1152 auto one =
1153 hw::ConstantOp::create(rewriter, loc, rewriter.getIntegerType(1), 1);
1154
1155 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
1156 // eq_i = ~(a_i ^ b_i) = a_i == b_i
1157 auto xorBit = comb::XorOp::create(rewriter, loc, aBit, bBit);
1158 eq.push_back(comb::XorOp::create(rewriter, loc, xorBit, one));
1159
1160 // gt_i = ~a_i & b_i = a_i < b_i
1161 auto notA = comb::XorOp::create(rewriter, loc, aBit, one);
1162 gt.push_back(comb::AndOp::create(rewriter, loc, notA, bBit));
1163 }
1164
1165 return computePrefixComparison(rewriter, loc, std::move(eq), std::move(gt),
1166 includeEq, arch);
1167 }
1168
1169 LogicalResult
1170 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
1171 ConversionPatternRewriter &rewriter) const override {
1172 auto lhs = adaptor.getLhs();
1173 auto rhs = adaptor.getRhs();
1174
1175 switch (op.getPredicate()) {
1176 default:
1177 return failure();
1178
1179 case ICmpPredicate::eq:
1180 case ICmpPredicate::ceq: {
1181 // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
1182 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
1183 auto xorBits = extractBits(rewriter, xorOp);
1184 SmallVector<bool> allInverts(xorBits.size(), true);
1185 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
1186 rewriter, op, xorBits, allInverts);
1187 return success();
1188 }
1189
1190 case ICmpPredicate::ne:
1191 case ICmpPredicate::cne: {
1192 // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
1193 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
1194 replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
1195 rewriter, op, extractBits(rewriter, xorOp), true);
1196 return success();
1197 }
1198
1199 case ICmpPredicate::uge:
1200 case ICmpPredicate::ugt:
1201 case ICmpPredicate::ule:
1202 case ICmpPredicate::ult: {
1203 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
1204 op.getPredicate() == ICmpPredicate::ule;
1205 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
1206 op.getPredicate() == ICmpPredicate::ule;
1207 replaceOpAndCopyNamehint(rewriter, op,
1208 constructUnsignedCompare(op, op.getLoc(), lhs,
1209 rhs, isLess, includeEq,
1210 rewriter));
1211 return success();
1212 }
1213 case ICmpPredicate::slt:
1214 case ICmpPredicate::sle:
1215 case ICmpPredicate::sgt:
1216 case ICmpPredicate::sge: {
1217 if (lhs.getType().getIntOrFloatBitWidth() == 0)
1218 return rewriter.notifyMatchFailure(
1219 op.getLoc(), "i0 signed comparison is unsupported");
1220 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
1221 op.getPredicate() == ICmpPredicate::sle;
1222 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
1223 op.getPredicate() == ICmpPredicate::sle;
1224
1225 // Get a sign bit
1226 auto signA = extractMSB(rewriter, lhs);
1227 auto signB = extractMSB(rewriter, rhs);
1228 auto aRest = extractOtherThanMSB(rewriter, lhs);
1229 auto bRest = extractOtherThanMSB(rewriter, rhs);
1230
1231 // Compare magnitudes (all bits except sign)
1232 auto sameSignResult = constructUnsignedCompare(
1233 op, op.getLoc(), aRest, bRest, isLess, includeEq, rewriter);
1234
1235 // XOR of signs: true if signs are different
1236 auto signsDiffer =
1237 comb::XorOp::create(rewriter, op.getLoc(), signA, signB);
1238
1239 // Result when signs are different
1240 Value diffSignResult = isLess ? signA : signB;
1241
1242 // Final result: choose based on whether signs differ
1243 replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
1244 rewriter, op, signsDiffer, diffSignResult, sameSignResult);
1245 return success();
1246 }
1247 }
1248 }
1249};
1250
1251struct CombParityOpConversion : OpConversionPattern<ParityOp> {
1253
1254 LogicalResult
1255 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
1256 ConversionPatternRewriter &rewriter) const override {
1257 // Parity is the XOR of all bits.
1258 replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
1259 rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
1260 return success();
1261 }
1262};
1263
1264struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
1266
1267 LogicalResult
1268 matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
1269 ConversionPatternRewriter &rewriter) const override {
1270 auto width = op.getType().getIntOrFloatBitWidth();
1271 auto lhs = adaptor.getLhs();
1272 auto result = createShiftLogic</*isLeftShift=*/true>(
1273 rewriter, op.getLoc(), adaptor.getRhs(), width,
1274 /*getPadding=*/
1275 [&](int64_t index) {
1276 // Don't create zero width value.
1277 if (index == 0)
1278 return Value();
1279 // Padding is 0 for left shift.
1280 return rewriter.createOrFold<hw::ConstantOp>(
1281 op.getLoc(), rewriter.getIntegerType(index), 0);
1282 },
1283 /*getExtract=*/
1284 [&](int64_t index) {
1285 assert(index < width && "index out of bounds");
1286 // Exract the bits from LSB.
1287 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
1288 width - index);
1289 });
1290
1291 replaceOpAndCopyNamehint(rewriter, op, result);
1292 return success();
1293 }
1294};
1295
1296struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
1298
1299 LogicalResult
1300 matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
1301 ConversionPatternRewriter &rewriter) const override {
1302 auto width = op.getType().getIntOrFloatBitWidth();
1303 auto lhs = adaptor.getLhs();
1304 auto result = createShiftLogic</*isLeftShift=*/false>(
1305 rewriter, op.getLoc(), adaptor.getRhs(), width,
1306 /*getPadding=*/
1307 [&](int64_t index) {
1308 // Don't create zero width value.
1309 if (index == 0)
1310 return Value();
1311 // Padding is 0 for right shift.
1312 return rewriter.createOrFold<hw::ConstantOp>(
1313 op.getLoc(), rewriter.getIntegerType(index), 0);
1314 },
1315 /*getExtract=*/
1316 [&](int64_t index) {
1317 assert(index < width && "index out of bounds");
1318 // Exract the bits from MSB.
1319 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
1320 width - index);
1321 });
1322
1323 replaceOpAndCopyNamehint(rewriter, op, result);
1324 return success();
1325 }
1326};
1327
1328struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
1330
1331 LogicalResult
1332 matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
1333 ConversionPatternRewriter &rewriter) const override {
1334 auto width = op.getType().getIntOrFloatBitWidth();
1335 if (width == 0)
1336 return rewriter.notifyMatchFailure(op.getLoc(),
1337 "i0 signed shift is unsupported");
1338 auto lhs = adaptor.getLhs();
1339 // Get the sign bit.
1340 auto sign =
1341 rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
1342
1343 // NOTE: The max shift amount is width - 1 because the sign bit is
1344 // already shifted out.
1345 auto result = createShiftLogic</*isLeftShift=*/false>(
1346 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
1347 /*getPadding=*/
1348 [&](int64_t index) {
1349 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
1350 index + 1);
1351 },
1352 /*getExtract=*/
1353 [&](int64_t index) {
1354 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
1355 width - index - 1);
1356 });
1357
1358 replaceOpAndCopyNamehint(rewriter, op, result);
1359 return success();
1360 }
1361};
1362
1363} // namespace
1364
1365//===----------------------------------------------------------------------===//
1366// Convert Comb to AIG pass
1367//===----------------------------------------------------------------------===//
1368
1369namespace {
1370struct ConvertCombToSynthPass
1371 : public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
1372 void runOnOperation() override;
1373 using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
1374};
1375} // namespace
1376
1377static void
1379 uint32_t maxEmulationUnknownBits,
1380 bool forceAIG) {
1381 patterns.add<
1382 // Bitwise Logical Ops
1383 CombAndOpConversion, CombParityOpConversion, CombXorOpToSynthConversion,
1384 CombMuxOpToSynthConversion,
1385 // Arithmetic Ops
1386 CombMulOpConversion, CombICmpOpConversion,
1387 // Shift Ops
1388 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
1389 // Variadic ops that must be lowered to binary operations
1390 CombLowerVariadicOp<AddOp>, CombLowerVariadicOp<MulOp>>(
1391 patterns.getContext());
1392
1393 if (forceAIG) {
1394 patterns.add<SynthXorInverterOpConversion, SynthMuxInverterOpConversion>(
1395 patterns.getContext());
1396 }
1397 patterns.add(comb::convertSubToAdd);
1398
1399 patterns.add<CombOrToAIGConversion, CombAddOpConversion>(
1400 patterns.getContext());
1401 synth::populateVariadicAndInverterLoweringPatterns(patterns);
1402
1403 if (forceAIG)
1404 synth::populateVariadicXorInverterLoweringPatterns(patterns);
1405
1406 // Add div/mod patterns with a threshold given by the pass option.
1407 patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
1408 CombModSOpConversion>(patterns.getContext(),
1409 maxEmulationUnknownBits);
1410}
1411
1412void ConvertCombToSynthPass::runOnOperation() {
1413 ConversionTarget target(getContext());
1414
1415 // Comb is source dialect.
1416 target.addIllegalDialect<comb::CombDialect>();
1417 // Keep data movement operations like Extract, Concat and Replicate.
1418 target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
1420
1421 // Treat array operations as illegal. Strictly speaking, other than array
1422 // get operation with non-const index are legal in AIG but array types
1423 // prevent a bunch of optimizations so just lower them to integer
1424 // operations. It's required to run HWAggregateToComb pass before this pass.
1426 hw::AggregateConstantOp>();
1427
1428 target.addLegalDialect<synth::SynthDialect>();
1429 if (forceAIG)
1430 target.addIllegalOp<synth::XorInverterOp, synth::MuxInverterOp>();
1431
1432 // If additional legal ops are specified, add them to the target.
1433 if (!additionalLegalOps.empty())
1434 for (const auto &opName : additionalLegalOps)
1435 target.addLegalOp(OperationName(opName, &getContext()));
1436
1437 RewritePatternSet patterns(&getContext());
1438 populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits,
1439 forceAIG);
1440
1441 if (failed(mlir::applyPartialConversion(getOperation(), target,
1442 std::move(patterns))))
1443 return signalPassFailure();
1444}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
static APInt substitueMaskToValues(size_t width, llvm::SmallVectorImpl< ConstantOrValue > &constantOrValues, uint32_t mask)
static LogicalResult emulateBinaryOpForUnknownBits(ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits, Operation *op, llvm::function_ref< APInt(const APInt &, const APInt &)> emulate)
static int64_t getNumUnknownBitsAndPopulateValues(Value value, llvm::SmallVectorImpl< ConstantOrValue > &values)
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a, Value b, Value carry)
static Value extractOtherThanMSB(OpBuilder &builder, Value val)
static Value extractMSB(OpBuilder &builder, Value val)
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns, uint32_t maxEmulationUnknownBits, bool forceAIG)
static std::unique_ptr< Context > context
static std::optional< APSInt > getConstant(Attribute operand)
Determine the value of a constant operand for the sake of constant folding.
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
create(data_type, value)
Definition hw.py:441
create(data_type, value)
Definition hw.py:433
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
Definition comb.py:1