CIRCT 22.0.0git
CombToSynth.cpp
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This is the main Comb to Synth Conversion Pass Implementation.
//
//       High-level Comb Operations
//              |                  |
//              v                  |
//    +-------------------+        |
//    | and, or, xor, mux |        |
//    +---------+---------+        |
//              |                  |
//        +-----+-----+            |
//        v           v            v
//     +-----+     +-----+
//     | AIG |---->| MIG |
//     +-----+     +-----+
//
//===----------------------------------------------------------------------===//

#include "circt/Conversion/CombToSynth.h"
#include "circt/Dialect/Comb/CombOps.h"
#include "circt/Dialect/Datapath/DatapathOps.h"
#include "circt/Dialect/HW/HWOps.h"
#include "circt/Dialect/Synth/SynthOps.h"
#include "circt/Support/Naming.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/Support/Debug.h"
#include <array>

#define DEBUG_TYPE "comb-to-synth"

namespace circt {
#define GEN_PASS_DEF_CONVERTCOMBTOSYNTH
#include "circt/Conversion/Passes.h.inc"
} // namespace circt

using namespace circt;
using namespace comb;

//===----------------------------------------------------------------------===//
// Utility Functions
//===----------------------------------------------------------------------===//

// A wrapper for comb::extractBits that returns a SmallVector<Value>.
static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
  SmallVector<Value> bits;
  comb::extractBits(builder, val, bits);
  return bits;
}

// Construct a mux tree for shift operations. `isLeftShift` controls the
// direction of the shift and determines the order in which the padding and
// extracted bits are concatenated. The callbacks `getPadding` and
// `getExtract` return the padding and extracted bits for each shift amount.
// `getPadding` may return a null Value to stand in for a zero-width (i0)
// padding; apart from that, both callbacks must return a valid value for
// every shift amount in the range [0, maxShiftAmount]. The value for
// `maxShiftAmount` is used as the out-of-bounds value.
template <bool isLeftShift>
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
                              Value shiftAmount, int64_t maxShiftAmount,
                              llvm::function_ref<Value(int64_t)> getPadding,
                              llvm::function_ref<Value(int64_t)> getExtract) {
  // Extract individual bits from shift amount.
  auto bits = extractBits(rewriter, shiftAmount);

  // Create nodes for each possible shift amount.
  SmallVector<Value> nodes;
  nodes.reserve(maxShiftAmount);
  for (int64_t i = 0; i < maxShiftAmount; ++i) {
    Value extract = getExtract(i);
    Value padding = getPadding(i);

    if (!padding) {
      nodes.push_back(extract);
      continue;
    }

    // Concatenate extracted bits with padding.
    if (isLeftShift)
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
    else
      nodes.push_back(
          rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
  }

  // Create out-of-bounds value.
  auto outOfBoundsValue = getPadding(maxShiftAmount);
  assert(outOfBoundsValue && "outOfBoundsValue must be valid");

  // Construct mux tree for shift operation.
  auto result =
      comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);

  // Add bounds checking.
  auto inBound = rewriter.createOrFold<comb::ICmpOp>(
      loc, ICmpPredicate::ult, shiftAmount,
      hw::ConstantOp::create(rewriter, loc, shiftAmount.getType(),
                             maxShiftAmount));

  return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
                                            outOfBoundsValue);
}
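
// As a rough illustration (names here are hypothetical, not taken from an
// actual lowering), an i4 left shift `%a << %s` built with this helper has
// the candidates
//   node[0] = %a
//   node[1] = concat(extract(%a, 0, 3), 0b0)
//   node[2] = concat(extract(%a, 0, 2), 0b00)
//   node[3] = concat(extract(%a, 0, 1), 0b000)
// selects among them with a mux tree driven by the bits of %s, and finally
// muxes in the all-zero out-of-bounds value when %s >= 4.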

// Return a majority operation if MIG is enabled; otherwise return a majority
// function implemented with Comb operations. In the latter case `carry` ends
// up with slightly smaller depth than the other inputs.
static Value createMajorityFunction(OpBuilder &rewriter, Location loc, Value a,
                                    Value b, Value carry,
                                    bool useMajorityInverterOp) {
  if (useMajorityInverterOp) {
    std::array<Value, 3> inputs = {a, b, carry};
    std::array<bool, 3> inverts = {false, false, false};
    return synth::mig::MajorityInverterOp::create(rewriter, loc, inputs,
                                                  inverts);
  }

  // maj(a, b, c) = (c & (a ^ b)) | (a & b)
  auto aXorB = comb::XorOp::create(rewriter, loc, ValueRange{a, b}, true);
  auto andOp =
      comb::AndOp::create(rewriter, loc, ValueRange{carry, aXorB}, true);
  auto aAndB = comb::AndOp::create(rewriter, loc, ValueRange{a, b}, true);
  return comb::OrOp::create(rewriter, loc, ValueRange{andOp, aAndB}, true);
}
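
// For reference, the three-input majority function is 1 iff at least two of
// its inputs are 1:
//   maj(a, b, c) = (a & b) | (a & c) | (b & c) = (c & (a ^ b)) | (a & b)
// In the second form, `c` passes through two logic levels (AND, OR) while `a`
// and `b` pass through three (XOR, AND, OR), which is why the carry input has
// slightly smaller depth in the ripple-carry adder below.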

namespace {
// A union of Value and IntegerAttr to cleanly handle constant values.
using ConstantOrValue = llvm::PointerUnion<Value, mlir::IntegerAttr>;
} // namespace

// Return the number of unknown bits and populate the concatenated values.
static int64_t getNumUnknownBitsAndPopulateValues(
    Value value, llvm::SmallVectorImpl<ConstantOrValue> &values) {
  // Constant or zero-width values are fully known.
  if (value.getType().isInteger(0))
    return 0;

  // Recursively count unknown bits for concat.
  if (auto concat = value.getDefiningOp<comb::ConcatOp>()) {
    int64_t totalUnknownBits = 0;
    for (auto concatInput : llvm::reverse(concat.getInputs())) {
      auto unknownBits =
          getNumUnknownBitsAndPopulateValues(concatInput, values);
      if (unknownBits < 0)
        return unknownBits;
      totalUnknownBits += unknownBits;
    }
    return totalUnknownBits;
  }

  // Constant values are known.
  if (auto constant = value.getDefiningOp<hw::ConstantOp>()) {
    values.push_back(constant.getValueAttr());
    return 0;
  }

  // Consider all other operations as unknown bits.
  // TODO: We can handle replicate, extract, etc.
  values.push_back(value);
  return hw::getBitWidth(value.getType());
}

// Return a value that substitutes the unknown bits with the mask.
static APInt
substitueMaskToValues(size_t width,
                      llvm::SmallVectorImpl<ConstantOrValue> &constantOrValues,
                      uint32_t mask) {
  uint32_t bitPos = 0, unknownPos = 0;
  APInt result(width, 0);
  for (auto constantOrValue : constantOrValues) {
    int64_t elemWidth;
    if (auto constant = dyn_cast<IntegerAttr>(constantOrValue)) {
      elemWidth = constant.getValue().getBitWidth();
      result.insertBits(constant.getValue(), bitPos);
    } else {
      elemWidth = hw::getBitWidth(cast<Value>(constantOrValue).getType());
      assert(elemWidth >= 0 && "unknown bit width");
      assert(elemWidth + unknownPos < 32 && "unknown bit width too large");
      // Create a mask for the unknown bits.
      uint32_t usedBits = (mask >> unknownPos) & ((1 << elemWidth) - 1);
      result.insertBits(APInt(elemWidth, usedBits), bitPos);
      unknownPos += elemWidth;
    }
    bitPos += elemWidth;
  }

  return result;
}
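
// A small worked example (the values are illustrative only): for width = 4
// and constantOrValues = [IntegerAttr 0b1 : i1, Value %x : i2,
// IntegerAttr 0b0 : i1] (LSB first), mask = 0b10 yields
//   result = {0b0, 0b10, 0b1} = 0b0101,
// i.e. the constant bits are copied through unchanged and the two unknown
// bits of %x are replaced by the low two bits of the mask.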

// Emulate a binary operation with unknown bits using a table lookup.
// This function enumerates all possible combinations of unknown bits and
// emulates the operation for each combination.
static LogicalResult emulateBinaryOpForUnknownBits(
    ConversionPatternRewriter &rewriter, int64_t maxEmulationUnknownBits,
    Operation *op,
    llvm::function_ref<APInt(const APInt &, const APInt &)> emulate) {
  SmallVector<ConstantOrValue> lhsValues, rhsValues;

  assert(op->getNumResults() == 1 && op->getNumOperands() == 2 &&
         "op must be a single result binary operation");

  auto lhs = op->getOperand(0);
  auto rhs = op->getOperand(1);
  auto width = op->getResult(0).getType().getIntOrFloatBitWidth();
  auto loc = op->getLoc();
  auto numLhsUnknownBits = getNumUnknownBitsAndPopulateValues(lhs, lhsValues);
  auto numRhsUnknownBits = getNumUnknownBitsAndPopulateValues(rhs, rhsValues);

  // If an unknown bit width is detected, abort the lowering.
  if (numLhsUnknownBits < 0 || numRhsUnknownBits < 0)
    return failure();

  int64_t totalUnknownBits = numLhsUnknownBits + numRhsUnknownBits;
  if (totalUnknownBits > maxEmulationUnknownBits)
    return failure();

  SmallVector<Value> emulatedResults;
  emulatedResults.reserve(1 << totalUnknownBits);

  // Emulate all possible cases.
  DenseMap<IntegerAttr, hw::ConstantOp> constantPool;
  auto getConstant = [&](const APInt &value) -> hw::ConstantOp {
    auto attr = rewriter.getIntegerAttr(rewriter.getIntegerType(width), value);
    auto it = constantPool.find(attr);
    if (it != constantPool.end())
      return it->second;
    auto constant = hw::ConstantOp::create(rewriter, loc, value);
    constantPool[attr] = constant;
    return constant;
  };

  for (uint32_t lhsMask = 0, lhsMaskEnd = 1 << numLhsUnknownBits;
       lhsMask < lhsMaskEnd; ++lhsMask) {
    APInt lhsValue = substitueMaskToValues(width, lhsValues, lhsMask);
    for (uint32_t rhsMask = 0, rhsMaskEnd = 1 << numRhsUnknownBits;
         rhsMask < rhsMaskEnd; ++rhsMask) {
      APInt rhsValue = substitueMaskToValues(width, rhsValues, rhsMask);
      // Emulate.
      emulatedResults.push_back(getConstant(emulate(lhsValue, rhsValue)));
    }
  }

  // Create selectors for mux tree.
  SmallVector<Value> selectors;
  selectors.reserve(totalUnknownBits);
  for (auto &concatedValues : {rhsValues, lhsValues})
    for (auto valueOrConstant : concatedValues) {
      auto value = dyn_cast<Value>(valueOrConstant);
      if (!value)
        continue;
      extractBits(rewriter, value, selectors);
    }

  assert(totalUnknownBits == static_cast<int64_t>(selectors.size()) &&
         "number of selectors must match");
  auto muxed = constructMuxTree(rewriter, loc, selectors, emulatedResults,
                                getConstant(APInt::getZero(width)));

  replaceOpAndCopyNamehint(rewriter, op, muxed);
  return success();
}
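
// As an illustration (hypothetical operands, not from a real design):
// emulating udiv(%lhs, %rhs) where %lhs = concat(%x : i1, 0b1 : i1) and
// %rhs = 0b10 has a single unknown bit, so two cases are enumerated:
//   %x = 0: lhs = 0b01, 0b01 / 0b10 = 0b00
//   %x = 1: lhs = 0b11, 0b11 / 0b10 = 0b01
// and the op is replaced by a mux tree over the selector %x that picks
// between the two precomputed constants.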

//===----------------------------------------------------------------------===//
// Conversion patterns
//===----------------------------------------------------------------------===//

namespace {

/// Lower a comb::AndOp operation to synth::aig::AndInverterOp
struct CombAndOpConversion : OpConversionPattern<AndOp> {
  using OpConversionPattern<AndOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(AndOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, adaptor.getInputs(), nonInverts);
    return success();
  }
};

/// Lower a comb::OrOp operation to synth::aig::AndInverterOp with invert flags
struct CombOrToAIGConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Implement Or using And and invert flags: a | b = ~(~a & ~b)
    SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
    auto andOp = synth::aig::AndInverterOp::create(
        rewriter, op.getLoc(), adaptor.getInputs(), allInverts);
    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, andOp,
        /*invert=*/true);
    return success();
  }
};

/// Lower a two-operand comb::OrOp to a MIG majority operation using the
/// identity or(a, b) = maj(a, b, 1).
struct CombOrToMIGConversion : OpConversionPattern<OrOp> {
  using OpConversionPattern<OrOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OrOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    SmallVector<Value, 3> inputs(adaptor.getInputs());
    auto one = hw::ConstantOp::create(
        rewriter, op.getLoc(),
        APInt::getAllOnes(hw::getBitWidth(op.getType())));
    inputs.push_back(one);
    std::array<bool, 3> inverts = {false, false, false};
    replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
        rewriter, op, inputs, inverts);
    return success();
  }
};

/// Lower a one- or two-operand synth::aig::AndInverterOp to a MIG majority
/// operation using the identity and(a, b) = maj(a, b, 0).
struct AndInverterToMIGConversion
    : OpConversionPattern<synth::aig::AndInverterOp> {
  using OpConversionPattern<synth::aig::AndInverterOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(synth::aig::AndInverterOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() > 2)
      return failure();
    if (op.getNumOperands() == 1) {
      SmallVector<bool, 1> inverts{op.getInverted()[0]};
      replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
          rewriter, op, adaptor.getInputs(), inverts);
      return success();
    }
    SmallVector<Value, 3> inputs(adaptor.getInputs());
    auto zero = hw::ConstantOp::create(
        rewriter, op.getLoc(), APInt::getZero(hw::getBitWidth(op.getType())));
    inputs.push_back(zero);
    SmallVector<bool, 3> inverts(adaptor.getInverted());
    inverts.push_back(false);
    replaceOpWithNewOpAndCopyNamehint<synth::mig::MajorityInverterOp>(
        rewriter, op, inputs, inverts);
    return success();
  }
};

/// Lower a comb::XorOp operation to AIG operations
struct CombXorOpConversion : OpConversionPattern<XorOp> {
  using OpConversionPattern<XorOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(XorOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (op.getNumOperands() != 2)
      return failure();
    // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)

    // (a | b) = ~(~a & ~b)
    // (~a | ~b) = ~(a & b)
    auto inputs = adaptor.getInputs();
    SmallVector<bool> allInverts(inputs.size(), true);
    SmallVector<bool> allNotInverts(inputs.size(), false);

    auto notAAndNotB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                         inputs, allInverts);
    auto aAndB = synth::aig::AndInverterOp::create(rewriter, op.getLoc(),
                                                   inputs, allNotInverts);

    replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
        rewriter, op, notAAndNotB, aAndB,
        /*lhs_invert=*/true,
        /*rhs_invert=*/true);
    return success();
  }
};

/// Lower a variadic, fully associative operation to a balanced binary tree.
template <typename OpTy>
struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
  using OpConversionPattern<OpTy>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
  LogicalResult
  matchAndRewrite(OpTy op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }

  static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
                                       ConversionPatternRewriter &rewriter) {
    Value lhs, rhs;
    switch (operands.size()) {
    case 0:
      llvm_unreachable("cannot be called with empty operand range");
      break;
    case 1:
      return operands[0];
    case 2:
      lhs = operands[0];
      rhs = operands[1];
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    default:
      auto firstHalf = operands.size() / 2;
      lhs =
          lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
      rhs =
          lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
      return OpTy::create(rewriter, op.getLoc(), ValueRange{lhs, rhs}, true);
    }
  }
};
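
// For example, a variadic xor(a, b, c, d) is lowered by the pattern above to
// the balanced tree xor(xor(a, b), xor(c, d)) rather than a linear chain,
// keeping the logic depth at O(log n) in the number of operands.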

// Lower comb::MuxOp to AIG operations.
struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
  using OpConversionPattern<MuxOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(MuxOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value cond = op.getCond();
    auto trueVal = op.getTrueValue();
    auto falseVal = op.getFalseValue();

    if (!op.getType().isInteger()) {
      // If the type of the mux is not an integer, bitcast the operands first.
      auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
      trueVal =
          hw::BitcastOp::create(rewriter, op->getLoc(), widthType, trueVal);
      falseVal =
          hw::BitcastOp::create(rewriter, op->getLoc(), widthType, falseVal);
    }

    // Replicate the condition if needed.
    if (!trueVal.getType().isInteger(1))
      cond = comb::ReplicateOp::create(rewriter, op.getLoc(), trueVal.getType(),
                                       cond);

    // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
    auto lhs =
        synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond, trueVal);
    auto rhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), cond,
                                                 falseVal, true, false);

    Value result = comb::OrOp::create(rewriter, op.getLoc(), lhs, rhs);
    // Insert a bitcast back if the type of the mux is not an integer.
    if (result.getType() != op.getType())
      result =
          hw::BitcastOp::create(rewriter, op.getLoc(), op.getType(), result);
    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

template <bool lowerToMIG>
struct CombAddOpConversion : OpConversionPattern<AddOp> {
  using OpConversionPattern<AddOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(AddOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto inputs = adaptor.getInputs();
    // Lower only when there are exactly two inputs; variadic adds are
    // lowered by a separate pattern.
    if (inputs.size() != 2)
      return failure();

    auto width = op.getType().getIntOrFloatBitWidth();
    // Skip a zero-width value.
    if (width == 0) {
      replaceOpWithNewOpAndCopyNamehint<hw::ConstantOp>(rewriter, op,
                                                        op.getType(), 0);
      return success();
    }

    if (width < 8)
      lowerRippleCarryAdder(op, inputs, rewriter);
    else
      lowerParallelPrefixAdder(op, inputs, rewriter);

    return success();
  }

  // Implement a basic ripple-carry adder for small bitwidths.
  void lowerRippleCarryAdder(comb::AddOp op, ValueRange inputs,
                             ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    // Implement a naive ripple-carry full adder.
    Value carry;

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    SmallVector<Value> results;
    results.resize(width);
    for (int64_t i = 0; i < width; ++i) {
      SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
      if (carry)
        xorOperands.push_back(carry);

      // sum[i] = xor(carry[i-1], a[i], b[i])
      // NOTE: The result is stored in reverse order.
      results[width - i - 1] =
          comb::XorOp::create(rewriter, op.getLoc(), xorOperands, true);

      // If this is the last bit, we are done.
      if (i == width - 1)
        break;

      // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
      if (!carry) {
        // There is no incoming carry for the first bit, so the carry out is
        // simply a[0] & b[0].
        carry = comb::AndOp::create(rewriter, op.getLoc(),
                                    ValueRange{aBits[i], bBits[i]}, true);
        continue;
      }

      carry = createMajorityFunction(rewriter, op.getLoc(), aBits[i], bBits[i],
                                     carry, lowerToMIG);
    }
    LLVM_DEBUG(llvm::dbgs() << "Lower comb.add to Ripple-Carry Adder of width "
                            << width << "\n");

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);
  }

  // Implement a parallel-prefix adder with Kogge-Stone or Brent-Kung trees.
  // This introduces unused signals for the carry bits, but those are removed
  // by the AIG pass.
  void lowerParallelPrefixAdder(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter) const {
    auto width = op.getType().getIntOrFloatBitWidth();

    auto aBits = extractBits(rewriter, inputs[0]);
    auto bBits = extractBits(rewriter, inputs[1]);
    // Construct propagate (p) and generate (g) signals.
    SmallVector<Value> p, g;
    p.reserve(width);
    g.reserve(width);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      // p_i = a_i XOR b_i
      p.push_back(comb::XorOp::create(rewriter, op.getLoc(), aBit, bBit));
      // g_i = a_i AND b_i
      g.push_back(comb::AndOp::create(rewriter, op.getLoc(), aBit, bBit));
    }

    LLVM_DEBUG({
      llvm::dbgs() << "Lower comb.add to Parallel-Prefix of width " << width
                   << "\n--------------------------------------- Init\n";

      for (int64_t i = 0; i < width; ++i) {
        // p_i = a_i XOR b_i
        llvm::dbgs() << "P0" << i << " = A" << i << " XOR B" << i << "\n";
        // g_i = a_i AND b_i
        llvm::dbgs() << "G0" << i << " = A" << i << " AND B" << i << "\n";
      }
    });

    // Create copies of p and g for the prefix computation.
    SmallVector<Value> pPrefix = p;
    SmallVector<Value> gPrefix = g;
    if (width < 32)
      lowerKoggeStonePrefixTree(op, inputs, rewriter, pPrefix, gPrefix);
    else
      lowerBrentKungPrefixTree(op, inputs, rewriter, pPrefix, gPrefix);

    // Generate the result sum bits.
    // NOTE: The result is stored in reverse order.
    SmallVector<Value> results;
    results.resize(width);
    // Sum bit 0 is just p[0] since carry_in = 0.
    results[width - 1] = p[0];

    // For the remaining bits, sum_i = p_i XOR g_(i-1):
    // the carry into position i is the group generate from position i-1.
    for (int64_t i = 1; i < width; ++i)
      results[width - 1 - i] =
          comb::XorOp::create(rewriter, op.getLoc(), p[i], gPrefix[i - 1]);

    replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(rewriter, op, results);

    LLVM_DEBUG({
      llvm::dbgs() << "--------------------------------------- Completion\n"
                   << "RES0 = P0\n";
      for (int64_t i = 1; i < width; ++i)
        llvm::dbgs() << "RES" << i << " = P" << i << " XOR G" << i - 1 << "\n";
    });
  }
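
  // A sketch of the completion step for width = 4 (writing G_i for the group
  // generate covering bit range [i:0]):
  //   sum[0] = p[0]
  //   sum[1] = p[1] ^ G_0
  //   sum[2] = p[2] ^ G_1
  //   sum[3] = p[3] ^ G_2
  // Only the group generates are consumed here; the group propagates are
  // needed solely while building the prefix tree.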

  // Implement the Kogge-Stone parallel prefix tree
  // Described in https://en.wikipedia.org/wiki/Kogge%E2%80%93Stone_adder
  // Slightly better delay than Brent-Kung, but more area.
  void lowerKoggeStonePrefixTree(comb::AddOp op, ValueRange inputs,
                                 ConversionPatternRewriter &rewriter,
                                 SmallVector<Value> &pPrefix,
                                 SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    SmallVector<Value> pPrefixNew = pPrefix;
    SmallVector<Value> gPrefixNew = gPrefix;

    // Kogge-Stone parallel prefix computation
    for (int64_t stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride; i < width; ++i) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefixNew[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefixNew[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
      pPrefix = pPrefixNew;
      gPrefix = gPrefixNew;
    }
    LLVM_DEBUG({
      int64_t stage = 0;
      for (int64_t stride = 1; stride < width; stride *= 2) {
        llvm::dbgs()
            << "--------------------------------------- Kogge-Stone Stage "
            << stage << "\n";
        for (int64_t i = stride; i < width; ++i) {
          int64_t j = i - stride;
          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        ++stage;
      }
    });
  }
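
  // For width = 8 this runs stages with strides 1, 2 and 4, i.e. ceil(log2
  // width) stages, after which gPrefix[i] covers the full range [i:0]. Every
  // bit position is active in every stage, which is what buys the minimal
  // depth at the cost of roughly width * log2(width) prefix cells.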

  // Implement the Brent-Kung parallel prefix tree
  // Described in https://en.wikipedia.org/wiki/Brent%E2%80%93Kung_adder
  // Slightly worse delay than Kogge-Stone, but less area.
  void lowerBrentKungPrefixTree(comb::AddOp op, ValueRange inputs,
                                ConversionPatternRewriter &rewriter,
                                SmallVector<Value> &pPrefix,
                                SmallVector<Value> &gPrefix) const {
    auto width = op.getType().getIntOrFloatBitWidth();
    SmallVector<Value> pPrefixNew = pPrefix;
    SmallVector<Value> gPrefixNew = gPrefix;
    // Brent-Kung parallel prefix computation
    // Forward phase
    int64_t stride;
    for (stride = 1; stride < width; stride *= 2) {
      for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefixNew[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefixNew[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
      pPrefix = pPrefixNew;
      gPrefix = gPrefixNew;
    }

    // Backward phase
    for (; stride > 0; stride /= 2) {
      for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
        int64_t j = i - stride;

        // Group generate: g_i OR (p_i AND g_j)
        Value andPG =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], gPrefix[j]);
        gPrefixNew[i] =
            comb::OrOp::create(rewriter, op.getLoc(), gPrefix[i], andPG);

        // Group propagate: p_i AND p_j
        pPrefixNew[i] =
            comb::AndOp::create(rewriter, op.getLoc(), pPrefix[i], pPrefix[j]);
      }
      pPrefix = pPrefixNew;
      gPrefix = gPrefixNew;
    }

    LLVM_DEBUG({
      int64_t stage = 0;
      for (stride = 1; stride < width; stride *= 2) {
        llvm::dbgs() << "--------------------------------------- Brent-Kung FW "
                     << stage << " : Stride " << stride << "\n";
        for (int64_t i = stride * 2 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;

          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        ++stage;
      }

      for (; stride > 0; stride /= 2) {
        if (stride * 3 - 1 < width)
          llvm::dbgs()
              << "--------------------------------------- Brent-Kung BW "
              << stage << " : Stride " << stride << "\n";

        for (int64_t i = stride * 3 - 1; i < width; i += stride * 2) {
          int64_t j = i - stride;

          // Group generate: g_i OR (p_i AND g_j)
          llvm::dbgs() << "G" << i << stage + 1 << " = G" << i << stage
                       << " OR (P" << i << stage << " AND G" << j << stage
                       << ")\n";

          // Group propagate: p_i AND p_j
          llvm::dbgs() << "P" << i << stage + 1 << " = P" << i << stage
                       << " AND P" << j << stage << "\n";
        }
        --stage;
      }
    });
  }
};

struct CombSubOpConversion : OpConversionPattern<SubOp> {
  using OpConversionPattern<SubOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(SubOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = op.getLhs();
    auto rhs = op.getRhs();
    // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
    // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
    //               => add(lhs, ~rhs, 1)
    auto notRhs = synth::aig::AndInverterOp::create(rewriter, op.getLoc(), rhs,
                                                    /*invert=*/true);
    auto one = hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), 1);
    replaceOpWithNewOpAndCopyNamehint<comb::AddOp>(
        rewriter, op, ValueRange{lhs, notRhs, one}, true);
    return success();
  }
};

struct CombMulOpConversion : OpConversionPattern<MulOp> {
  using OpConversionPattern<MulOp>::OpConversionPattern;
  using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
  LogicalResult
  matchAndRewrite(MulOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (adaptor.getInputs().size() != 2)
      return failure();

    Location loc = op.getLoc();
    Value a = adaptor.getInputs()[0];
    Value b = adaptor.getInputs()[1];
    unsigned width = op.getType().getIntOrFloatBitWidth();

    // Skip a zero-width value.
    if (width == 0) {
      rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
      return success();
    }

    // Extract individual bits from the operands.
    SmallVector<Value> aBits = extractBits(rewriter, a);
    SmallVector<Value> bBits = extractBits(rewriter, b);

    auto falseValue = hw::ConstantOp::create(rewriter, loc, APInt(1, 0));

    // Generate the partial products.
    SmallVector<SmallVector<Value>> partialProducts;
    partialProducts.reserve(width);
    for (unsigned i = 0; i < width; ++i) {
      SmallVector<Value> row(i, falseValue);
      row.reserve(width);
      // Generate the partial product bits.
      for (unsigned j = 0; i + j < width; ++j)
        row.push_back(
            rewriter.createOrFold<comb::AndOp>(loc, aBits[j], bBits[i]));

      partialProducts.push_back(row);
    }

    // If the width is 1, we are done.
    if (width == 1) {
      rewriter.replaceOp(op, partialProducts[0][0]);
      return success();
    }

    // Wallace tree reduction: reduce the partial products to two addends.
    datapath::CompressorTree comp(width, partialProducts, loc);
    auto addends = comp.compressToHeight(rewriter, 2);

    // Sum the two addends using a carry-propagate adder.
    auto newAdd = comb::AddOp::create(rewriter, loc, addends, true);
    replaceOpAndCopyNamehint(rewriter, op, newAdd);
    return success();
  }
};
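
// Partial-product layout for a hypothetical width = 3 multiply (row i is
// a * b[i], pre-shifted left by i and truncated to the result width; shown
// MSB first):
//   row 0: { a2&b0, a1&b0, a0&b0 }
//   row 1: { a1&b1, a0&b1, 0     }
//   row 2: { a0&b2, 0,     0     }
// The compressor tree sums these three rows down to two, and the final
// comb.add performs the remaining carry propagation.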

template <typename OpTy>
struct DivModOpConversionBase : OpConversionPattern<OpTy> {
  DivModOpConversionBase(MLIRContext *context, int64_t maxEmulationUnknownBits)
      : OpConversionPattern<OpTy>(context),
        maxEmulationUnknownBits(maxEmulationUnknownBits) {
    assert(maxEmulationUnknownBits < 32 &&
           "maxEmulationUnknownBits must be less than 32");
  }
  const int64_t maxEmulationUnknownBits;
};

struct CombDivUOpConversion : DivModOpConversionBase<DivUOp> {
  using DivModOpConversionBase<DivUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(DivUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        // Extract the upper bits.
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value upperBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), extractAmount,
            width - extractAmount);
        Value constZero = hw::ConstantOp::create(rewriter, op.getLoc(),
                                                 APInt::getZero(extractAmount));
        replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
            rewriter, op, op.getType(), ArrayRef<Value>{constZero, upperBits});
        return success();
      }

    // When rhs is not a power of two and the number of unknown bits is small,
    // create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.udiv(rhs);
        });
  }
};
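
// E.g. unsigned division of an i8 value by the constant 4 reduces to a
// constant shift:
//   udiv(%x, 4) == concat(0b00, extract(%x, 2, 6))
// so no divider logic at all is emitted for power-of-two divisors.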

struct CombModUOpConversion : DivModOpConversionBase<ModUOp> {
  using DivModOpConversionBase<ModUOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Check if the divisor is a power of two.
    if (auto rhsConstantOp = adaptor.getRhs().getDefiningOp<hw::ConstantOp>())
      if (rhsConstantOp.getValue().isPowerOf2()) {
        // Extract the lower bits.
        size_t extractAmount = rhsConstantOp.getValue().ceilLogBase2();
        size_t width = op.getType().getIntOrFloatBitWidth();
        Value lowerBits = rewriter.createOrFold<comb::ExtractOp>(
            op.getLoc(), adaptor.getLhs(), 0, extractAmount);
        Value constZero = hw::ConstantOp::create(
            rewriter, op.getLoc(), APInt::getZero(width - extractAmount));
        replaceOpWithNewOpAndCopyNamehint<comb::ConcatOp>(
            rewriter, op, op.getType(), ArrayRef<Value>{constZero, lowerBits});
        return success();
      }

    // When rhs is not a power of two and the number of unknown bits is small,
    // create a mux tree that emulates all possible cases.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.urem(rhs);
        });
  }
};

struct CombDivSOpConversion : DivModOpConversionBase<DivSOp> {
  using DivModOpConversionBase<DivSOp>::DivModOpConversionBase;

  LogicalResult
  matchAndRewrite(DivSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed division lowering at least for power of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.sdiv(rhs);
        });
  }
};

struct CombModSOpConversion : DivModOpConversionBase<ModSOp> {
  using DivModOpConversionBase<ModSOp>::DivModOpConversionBase;
  LogicalResult
  matchAndRewrite(ModSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Currently only lower with emulation.
    // TODO: Implement a signed modulus lowering at least for power of two.
    return emulateBinaryOpForUnknownBits(
        rewriter, maxEmulationUnknownBits, op,
        [](const APInt &lhs, const APInt &rhs) {
          // Division by zero is undefined, just return zero.
          if (rhs.isZero())
            return APInt::getZero(rhs.getBitWidth());
          return lhs.srem(rhs);
        });
  }
};

struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
  using OpConversionPattern<ICmpOp>::OpConversionPattern;
  static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
                                        ArrayRef<Value> bBits, bool isLess,
                                        bool includeEq,
                                        ConversionPatternRewriter &rewriter) {
    // Construct the following unsigned comparison expressions.
    // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
    // a <  b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <  b[n-1:0])
    // a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0])
    // a >  b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >  b[n-1:0])
    Value acc =
        hw::ConstantOp::create(rewriter, op.getLoc(), op.getType(), includeEq);

    for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
      auto aBitXorBBit =
          rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
      auto aEqualB = rewriter.createOrFold<synth::aig::AndInverterOp>(
          op.getLoc(), aBitXorBBit, true);
      auto pred = rewriter.createOrFold<synth::aig::AndInverterOp>(
          op.getLoc(), aBit, bBit, isLess, !isLess);

      auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
          op.getLoc(), ValueRange{aEqualB, acc}, true);
      acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
                                              true);
    }
    return acc;
  }
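
  // Trace for a hypothetical 2-bit unsigned a < b (isLess = true,
  // includeEq = false), iterating from the LSB:
  //   acc0 = false
  //   acc1 = (~a[0] & b[0]) | (~(a[0] ^ b[0]) & acc0)
  //   acc2 = (~a[1] & b[1]) | (~(a[1] ^ b[1]) & acc1)
  // acc2 is the result: the MSB decides the comparison unless the MSBs are
  // equal, in which case the decision falls through to the lower bits.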

  LogicalResult
  matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto lhs = adaptor.getLhs();
    auto rhs = adaptor.getRhs();

    switch (op.getPredicate()) {
    default:
      return failure();

    case ICmpPredicate::eq:
    case ICmpPredicate::ceq: {
      // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      auto xorBits = extractBits(rewriter, xorOp);
      SmallVector<bool> allInverts(xorBits.size(), true);
      replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
          rewriter, op, xorBits, allInverts);
      return success();
    }

    case ICmpPredicate::ne:
    case ICmpPredicate::cne: {
      // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
      auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
      replaceOpWithNewOpAndCopyNamehint<comb::OrOp>(
          rewriter, op, extractBits(rewriter, xorOp), true);
      return success();
    }

    case ICmpPredicate::uge:
    case ICmpPredicate::ugt:
    case ICmpPredicate::ule:
    case ICmpPredicate::ult: {
      bool isLess = op.getPredicate() == ICmpPredicate::ult ||
                    op.getPredicate() == ICmpPredicate::ule;
      bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
                       op.getPredicate() == ICmpPredicate::ule;
      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);
      replaceOpAndCopyNamehint(rewriter, op,
                               constructUnsignedCompare(op, aBits, bBits,
                                                        isLess, includeEq,
                                                        rewriter));
      return success();
    }
    case ICmpPredicate::slt:
    case ICmpPredicate::sle:
    case ICmpPredicate::sgt:
    case ICmpPredicate::sge: {
      if (lhs.getType().getIntOrFloatBitWidth() == 0)
        return rewriter.notifyMatchFailure(
            op.getLoc(), "i0 signed comparison is unsupported");
      bool isLess = op.getPredicate() == ICmpPredicate::slt ||
                    op.getPredicate() == ICmpPredicate::sle;
      bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
                       op.getPredicate() == ICmpPredicate::sle;

      auto aBits = extractBits(rewriter, lhs);
      auto bBits = extractBits(rewriter, rhs);

      // Get the sign bits.
      auto signA = aBits.back();
      auto signB = bBits.back();

      // Compare the magnitudes (all bits except the sign).
      auto sameSignResult = constructUnsignedCompare(
          op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
          includeEq, rewriter);

      // XOR of the signs: true if the signs differ.
      auto signsDiffer =
          comb::XorOp::create(rewriter, op.getLoc(), signA, signB);

      // Result when the signs differ.
      Value diffSignResult = isLess ? signA : signB;

      // Final result: choose based on whether the signs differ.
      replaceOpWithNewOpAndCopyNamehint<comb::MuxOp>(
          rewriter, op, signsDiffer, diffSignResult, sameSignResult);
      return success();
    }
    }
  }
};

struct CombParityOpConversion : OpConversionPattern<ParityOp> {
  using OpConversionPattern<ParityOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(ParityOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Parity is the XOR of all bits.
    replaceOpWithNewOpAndCopyNamehint<comb::XorOp>(
        rewriter, op, extractBits(rewriter, adaptor.getInput()), true);
    return success();
  }
};

struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
  using OpConversionPattern<comb::ShlOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/true>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for left shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the LSB.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};
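
// E.g. for an i4 shift-left, the candidate for shift amount 2 is
// concat(extract(%lhs, 0, 2), 0b00); shift amounts >= 4 select the all-zero
// out-of-bounds constant.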

struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
  using OpConversionPattern<comb::ShrUOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    auto lhs = adaptor.getLhs();
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width,
        /*getPadding=*/
        [&](int64_t index) {
          // Don't create a zero-width value.
          if (index == 0)
            return Value();
          // Padding is 0 for logical right shift.
          return rewriter.createOrFold<hw::ConstantOp>(
              op.getLoc(), rewriter.getIntegerType(index), 0);
        },
        /*getExtract=*/
        [&](int64_t index) {
          assert(index < width && "index out of bounds");
          // Extract the bits from the MSB side.
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};

struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
  using OpConversionPattern<comb::ShrSOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto width = op.getType().getIntOrFloatBitWidth();
    if (width == 0)
      return rewriter.notifyMatchFailure(op.getLoc(),
                                         "i0 signed shift is unsupported");
    auto lhs = adaptor.getLhs();
    // Get the sign bit.
    auto sign =
        rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);

    // NOTE: The max shift amount is width - 1 because the sign bit is
    // already shifted out.
    auto result = createShiftLogic</*isLeftShift=*/false>(
        rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
        /*getPadding=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
                                                          index + 1);
        },
        /*getExtract=*/
        [&](int64_t index) {
          return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
                                                        width - index - 1);
        });

    replaceOpAndCopyNamehint(rewriter, op, result);
    return success();
  }
};
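
// E.g. for an i4 arithmetic shift-right, the candidate for shift amount 2 is
// concat(replicate(sign, 3), extract(%lhs, 2, 1)), and any amount >= 3
// selects replicate(sign, 4), matching the semantics of shifting in copies
// of the sign bit.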

} // namespace

//===----------------------------------------------------------------------===//
// Convert Comb to Synth pass
//===----------------------------------------------------------------------===//

namespace {
struct ConvertCombToSynthPass
    : public impl::ConvertCombToSynthBase<ConvertCombToSynthPass> {
  void runOnOperation() override;
  using ConvertCombToSynthBase<ConvertCombToSynthPass>::ConvertCombToSynthBase;
};
} // namespace

static void
populateCombToAIGConversionPatterns(RewritePatternSet &patterns,
                                    uint32_t maxEmulationUnknownBits,
                                    bool lowerToMIG) {
  patterns.add<
      // Bitwise Logical Ops
      CombAndOpConversion, CombXorOpConversion, CombMuxOpConversion,
      CombParityOpConversion,
      // Arithmetic Ops
      CombSubOpConversion, CombMulOpConversion, CombICmpOpConversion,
      // Shift Ops
      CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
      // Variadic ops that must be lowered to binary operations
      CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
      CombLowerVariadicOp<MulOp>>(patterns.getContext());

  if (lowerToMIG) {
    patterns.add<CombOrToMIGConversion, CombLowerVariadicOp<OrOp>,
                 AndInverterToMIGConversion,
                 CombAddOpConversion</*lowerToMIG=*/true>>(
        patterns.getContext());
  } else {
    patterns.add<CombOrToAIGConversion,
                 CombAddOpConversion</*lowerToMIG=*/false>>(
        patterns.getContext());
  }

  // Add div/mod patterns with a threshold given by the pass option.
  patterns.add<CombDivUOpConversion, CombModUOpConversion, CombDivSOpConversion,
               CombModSOpConversion>(patterns.getContext(),
                                     maxEmulationUnknownBits);
}

void ConvertCombToSynthPass::runOnOperation() {
  ConversionTarget target(getContext());

  // Comb is the source dialect.
  target.addIllegalDialect<comb::CombDialect>();
  // Keep data movement operations like Extract, Concat and Replicate.
  target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
                    hw::BitcastOp>();

  // Treat array operations as illegal. Strictly speaking, everything other
  // than array get operations with non-constant indices is legal in AIG, but
  // array types prevent a bunch of optimizations, so just lower them to
  // integer operations. The HWAggregateToComb pass must be run before this
  // pass.
  target.addIllegalOp<hw::ArrayGetOp, hw::ArrayCreateOp, hw::ArrayConcatOp,
                      hw::AggregateConstantOp>();

  target.addLegalDialect<synth::SynthDialect>();

  if (targetIR == CombToSynthTargetIR::AIG) {
    // AIG is the target dialect.
    target.addIllegalOp<synth::mig::MajorityInverterOp>();
  } else if (targetIR == CombToSynthTargetIR::MIG) {
    target.addIllegalOp<synth::aig::AndInverterOp>();
  }

  // If additional legal ops are specified, add them to the target.
  if (!additionalLegalOps.empty())
    for (const auto &opName : additionalLegalOps)
      target.addLegalOp(OperationName(opName, &getContext()));

  RewritePatternSet patterns(&getContext());
  populateCombToAIGConversionPatterns(patterns, maxEmulationUnknownBits,
                                      targetIR == CombToSynthTargetIR::MIG);

  if (failed(mlir::applyPartialConversion(getOperation(), target,
                                          std::move(patterns))))
    return signalPassFailure();
}