CIRCT 20.0.0git
Loading...
Searching...
No Matches
CombToAIG.cpp
Go to the documentation of this file.
1//===- CombToAIG.cpp - Comb to AIG Conversion Pass --------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This is the main Comb to AIG Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Pass/Pass.h"
18#include "mlir/Transforms/DialectConversion.h"
19
20namespace circt {
21#define GEN_PASS_DEF_CONVERTCOMBTOAIG
22#include "circt/Conversion/Passes.h.inc"
23} // namespace circt
24
25using namespace circt;
26using namespace comb;
27
28//===----------------------------------------------------------------------===//
29// Utility Functions
30//===----------------------------------------------------------------------===//
31
32// A wrapper for comb::extractBits that returns a SmallVector<Value>.
33static SmallVector<Value> extractBits(OpBuilder &builder, Value val) {
34 SmallVector<Value> bits;
35 comb::extractBits(builder, val, bits);
36 return bits;
37}
38
39// Construct a mux tree for shift operations. `isLeftShift` controls the
40// direction of the shift operation and is used to determine order of the
41// padding and extracted bits. Callbacks `getPadding` and `getExtract` are used
42// to get the padding and extracted bits for each shift amount. `getPadding`
43// could return a nullptr as i0 value but except for that, these callbacks must
44// return a valid value for each shift amount in the range [0, maxShiftAmount].
45// The value for `maxShiftAmount` is used as the out-of-bounds value.
46template <bool isLeftShift>
47static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc,
48 Value shiftAmount, int64_t maxShiftAmount,
49 llvm::function_ref<Value(int64_t)> getPadding,
50 llvm::function_ref<Value(int64_t)> getExtract) {
51 // Extract individual bits from shift amount
52 auto bits = extractBits(rewriter, shiftAmount);
53
54 // Create nodes for each possible shift amount
55 SmallVector<Value> nodes;
56 nodes.reserve(maxShiftAmount);
57 for (int64_t i = 0; i < maxShiftAmount; ++i) {
58 Value extract = getExtract(i);
59 Value padding = getPadding(i);
60
61 if (!padding) {
62 nodes.push_back(extract);
63 continue;
64 }
65
66 // Concatenate extracted bits with padding
67 if (isLeftShift)
68 nodes.push_back(
69 rewriter.createOrFold<comb::ConcatOp>(loc, extract, padding));
70 else
71 nodes.push_back(
72 rewriter.createOrFold<comb::ConcatOp>(loc, padding, extract));
73 }
74
75 // Create out-of-bounds value
76 auto outOfBoundsValue = getPadding(maxShiftAmount);
77 assert(outOfBoundsValue && "outOfBoundsValue must be valid");
78
79 // Construct mux tree for shift operation
80 auto result =
81 comb::constructMuxTree(rewriter, loc, bits, nodes, outOfBoundsValue);
82
83 // Add bounds checking
84 auto inBound = rewriter.createOrFold<comb::ICmpOp>(
85 loc, ICmpPredicate::ult, shiftAmount,
86 rewriter.create<hw::ConstantOp>(loc, shiftAmount.getType(),
87 maxShiftAmount));
88
89 return rewriter.createOrFold<comb::MuxOp>(loc, inBound, result,
90 outOfBoundsValue);
91}
92
93//===----------------------------------------------------------------------===//
94// Conversion patterns
95//===----------------------------------------------------------------------===//
96
97namespace {
98
99/// Lower a comb::AndOp operation to aig::AndInverterOp
100struct CombAndOpConversion : OpConversionPattern<AndOp> {
102
103 LogicalResult
104 matchAndRewrite(AndOp op, OpAdaptor adaptor,
105 ConversionPatternRewriter &rewriter) const override {
106 SmallVector<bool> nonInverts(adaptor.getInputs().size(), false);
107 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, adaptor.getInputs(),
108 nonInverts);
109 return success();
110 }
111};
112
113/// Lower a comb::OrOp operation to aig::AndInverterOp with invert flags
114struct CombOrOpConversion : OpConversionPattern<OrOp> {
116
117 LogicalResult
118 matchAndRewrite(OrOp op, OpAdaptor adaptor,
119 ConversionPatternRewriter &rewriter) const override {
120 // Implement Or using And and invert flags: a | b = ~(~a & ~b)
121 SmallVector<bool> allInverts(adaptor.getInputs().size(), true);
122 auto andOp = rewriter.create<aig::AndInverterOp>(
123 op.getLoc(), adaptor.getInputs(), allInverts);
124 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, andOp,
125 /*invert=*/true);
126 return success();
127 }
128};
129
130/// Lower a comb::XorOp operation to AIG operations
131struct CombXorOpConversion : OpConversionPattern<XorOp> {
133
134 LogicalResult
135 matchAndRewrite(XorOp op, OpAdaptor adaptor,
136 ConversionPatternRewriter &rewriter) const override {
137 if (op.getNumOperands() != 2)
138 return failure();
139 // Xor using And with invert flags: a ^ b = (a | b) & (~a | ~b)
140
141 // (a | b) = ~(~a & ~b)
142 // (~a | ~b) = ~(a & b)
143 auto inputs = adaptor.getInputs();
144 SmallVector<bool> allInverts(inputs.size(), true);
145 SmallVector<bool> allNotInverts(inputs.size(), false);
146
147 auto notAAndNotB =
148 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allInverts);
149 auto aAndB =
150 rewriter.create<aig::AndInverterOp>(op.getLoc(), inputs, allNotInverts);
151
152 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, notAAndNotB, aAndB,
153 /*lhs_invert=*/true,
154 /*rhs_invert=*/true);
155 return success();
156 }
157};
158
159template <typename OpTy>
160struct CombLowerVariadicOp : OpConversionPattern<OpTy> {
162 using OpAdaptor = typename OpConversionPattern<OpTy>::OpAdaptor;
163 LogicalResult
164 matchAndRewrite(OpTy op, OpAdaptor adaptor,
165 ConversionPatternRewriter &rewriter) const override {
166 auto result = lowerFullyAssociativeOp(op, op.getOperands(), rewriter);
167 rewriter.replaceOp(op, result);
168 return success();
169 }
170
171 static Value lowerFullyAssociativeOp(OpTy op, OperandRange operands,
172 ConversionPatternRewriter &rewriter) {
173 Value lhs, rhs;
174 switch (operands.size()) {
175 case 0:
176 assert(false && "cannot be called with empty operand range");
177 break;
178 case 1:
179 return operands[0];
180 case 2:
181 lhs = operands[0];
182 rhs = operands[1];
183 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
184 default:
185 auto firstHalf = operands.size() / 2;
186 lhs =
187 lowerFullyAssociativeOp(op, operands.take_front(firstHalf), rewriter);
188 rhs =
189 lowerFullyAssociativeOp(op, operands.drop_front(firstHalf), rewriter);
190 return rewriter.create<OpTy>(op.getLoc(), ValueRange{lhs, rhs}, true);
191 }
192 }
193};
194
195// Lower comb::MuxOp to AIG operations.
196struct CombMuxOpConversion : OpConversionPattern<MuxOp> {
198
199 LogicalResult
200 matchAndRewrite(MuxOp op, OpAdaptor adaptor,
201 ConversionPatternRewriter &rewriter) const override {
202 // Implement: c ? a : b = (replicate(c) & a) | (~replicate(c) & b)
203
204 Value cond = op.getCond();
205 auto trueVal = op.getTrueValue();
206 auto falseVal = op.getFalseValue();
207
208 if (!op.getType().isInteger()) {
209 // If the type of the mux is not integer, bitcast the operands first.
210 auto widthType = rewriter.getIntegerType(hw::getBitWidth(op.getType()));
211 trueVal =
212 rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, trueVal);
213 falseVal =
214 rewriter.create<hw::BitcastOp>(op->getLoc(), widthType, falseVal);
215 }
216
217 // Replicate condition if needed
218 if (!trueVal.getType().isInteger(1))
219 cond = rewriter.create<comb::ReplicateOp>(op.getLoc(), trueVal.getType(),
220 cond);
221
222 // c ? a : b => (replicate(c) & a) | (~replicate(c) & b)
223 auto lhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, trueVal);
224 auto rhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), cond, falseVal,
225 true, false);
226
227 Value result = rewriter.create<comb::OrOp>(op.getLoc(), lhs, rhs);
228 // Insert the bitcast if the type of the mux is not integer.
229 if (result.getType() != op.getType())
230 result =
231 rewriter.create<hw::BitcastOp>(op.getLoc(), op.getType(), result);
232 rewriter.replaceOp(op, result);
233 return success();
234 }
235};
236
237struct CombAddOpConversion : OpConversionPattern<AddOp> {
239 LogicalResult
240 matchAndRewrite(AddOp op, OpAdaptor adaptor,
241 ConversionPatternRewriter &rewriter) const override {
242 auto inputs = adaptor.getInputs();
243 // Lower only when there are two inputs.
244 // Variadic operands must be lowered in a different pattern.
245 if (inputs.size() != 2)
246 return failure();
247
248 auto width = op.getType().getIntOrFloatBitWidth();
249 // Skip a zero width value.
250 if (width == 0) {
251 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, op.getType(), 0);
252 return success();
253 }
254
255 // Implement a naive Ripple-carry full adder.
256 Value carry;
257
258 auto aBits = extractBits(rewriter, inputs[0]);
259 auto bBits = extractBits(rewriter, inputs[1]);
260 SmallVector<Value> results;
261 results.resize(width);
262 for (int64_t i = 0; i < width; ++i) {
263 SmallVector<Value> xorOperands = {aBits[i], bBits[i]};
264 if (carry)
265 xorOperands.push_back(carry);
266
267 // sum[i] = xor(carry[i-1], a[i], b[i])
268 // NOTE: The result is stored in reverse order.
269 results[width - i - 1] =
270 rewriter.create<comb::XorOp>(op.getLoc(), xorOperands, true);
271
272 // If this is the last bit, we are done.
273 if (i == width - 1) {
274 break;
275 }
276
277 // carry[i] = (carry[i-1] & (a[i] ^ b[i])) | (a[i] & b[i])
278 Value nextCarry = rewriter.create<comb::AndOp>(
279 op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
280 if (!carry) {
281 // This is the first bit, so the carry is the next carry.
282 carry = nextCarry;
283 continue;
284 }
285
286 auto aXnorB = rewriter.create<comb::XorOp>(
287 op.getLoc(), ValueRange{aBits[i], bBits[i]}, true);
288 auto andOp = rewriter.create<comb::AndOp>(
289 op.getLoc(), ValueRange{carry, aXnorB}, true);
290 carry = rewriter.create<comb::OrOp>(op.getLoc(),
291 ValueRange{andOp, nextCarry}, true);
292 }
293
294 rewriter.replaceOpWithNewOp<comb::ConcatOp>(op, results);
295 return success();
296 }
297};
298
299struct CombSubOpConversion : OpConversionPattern<SubOp> {
301 LogicalResult
302 matchAndRewrite(SubOp op, OpAdaptor adaptor,
303 ConversionPatternRewriter &rewriter) const override {
304 auto lhs = op.getLhs();
305 auto rhs = op.getRhs();
306 // Since `-rhs = ~rhs + 1` holds, rewrite `sub(lhs, rhs)` to:
307 // sub(lhs, rhs) => add(lhs, -rhs) => add(lhs, add(~rhs, 1))
308 // => add(lhs, ~rhs, 1)
309 auto notRhs = rewriter.create<aig::AndInverterOp>(op.getLoc(), rhs,
310 /*invert=*/true);
311 auto one = rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), 1);
312 rewriter.replaceOpWithNewOp<comb::AddOp>(op, ValueRange{lhs, notRhs, one},
313 true);
314 return success();
315 }
316};
317
318struct CombMulOpConversion : OpConversionPattern<MulOp> {
320 using OpAdaptor = typename OpConversionPattern<MulOp>::OpAdaptor;
321 LogicalResult
322 matchAndRewrite(MulOp op, OpAdaptor adaptor,
323 ConversionPatternRewriter &rewriter) const override {
324 if (adaptor.getInputs().size() != 2)
325 return failure();
326
327 // FIXME: Currently it's lowered to a really naive implementation that
328 // chains add operations.
329
330 // a_{n}a_{n-1}...a_0 * b
331 // = sum_{i=0}^{n} a_i * 2^i * b
332 // = sum_{i=0}^{n} (a_i ? b : 0) << i
333 int64_t width = op.getType().getIntOrFloatBitWidth();
334 auto aBits = extractBits(rewriter, adaptor.getInputs()[0]);
335 SmallVector<Value> results;
336 auto rhs = op.getInputs()[1];
337 auto zero = rewriter.create<hw::ConstantOp>(op.getLoc(),
338 llvm::APInt::getZero(width));
339 for (int64_t i = 0; i < width; ++i) {
340 auto aBit = aBits[i];
341 auto andBit =
342 rewriter.createOrFold<comb::MuxOp>(op.getLoc(), aBit, rhs, zero);
343 auto upperBits = rewriter.createOrFold<comb::ExtractOp>(
344 op.getLoc(), andBit, 0, width - i);
345 if (i == 0) {
346 results.push_back(upperBits);
347 continue;
348 }
349
350 auto lowerBits =
351 rewriter.create<hw::ConstantOp>(op.getLoc(), APInt::getZero(i));
352
353 auto shifted = rewriter.createOrFold<comb::ConcatOp>(
354 op.getLoc(), op.getType(), ValueRange{upperBits, lowerBits});
355 results.push_back(shifted);
356 }
357
358 rewriter.replaceOpWithNewOp<comb::AddOp>(op, results, true);
359 return success();
360 }
361};
362
363struct CombICmpOpConversion : OpConversionPattern<ICmpOp> {
365 static Value constructUnsignedCompare(ICmpOp op, ArrayRef<Value> aBits,
366 ArrayRef<Value> bBits, bool isLess,
367 bool includeEq,
368 ConversionPatternRewriter &rewriter) {
369 // Construct following unsigned comparison expressions.
370 // a <= b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] <= b[n-1:0])
371 // a < b ==> (~a[n] & b[n]) | (a[n] == b[n] & a[n-1:0] < b[n-1:0])
372 // a >= b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] >= b[n-1:0])
373 // a > b ==> ( a[n] & ~b[n]) | (a[n] == b[n] & a[n-1:0] > b[n-1:0])
374 Value acc =
375 rewriter.create<hw::ConstantOp>(op.getLoc(), op.getType(), includeEq);
376
377 for (auto [aBit, bBit] : llvm::zip(aBits, bBits)) {
378 auto aBitXorBBit =
379 rewriter.createOrFold<comb::XorOp>(op.getLoc(), aBit, bBit, true);
380 auto aEqualB = rewriter.createOrFold<aig::AndInverterOp>(
381 op.getLoc(), aBitXorBBit, true);
382 auto pred = rewriter.createOrFold<aig::AndInverterOp>(
383 op.getLoc(), aBit, bBit, isLess, !isLess);
384
385 auto aBitAndBBit = rewriter.createOrFold<comb::AndOp>(
386 op.getLoc(), ValueRange{aEqualB, acc}, true);
387 acc = rewriter.createOrFold<comb::OrOp>(op.getLoc(), pred, aBitAndBBit,
388 true);
389 }
390 return acc;
391 }
392
393 LogicalResult
394 matchAndRewrite(ICmpOp op, OpAdaptor adaptor,
395 ConversionPatternRewriter &rewriter) const override {
396 auto lhs = adaptor.getLhs();
397 auto rhs = adaptor.getRhs();
398
399 switch (op.getPredicate()) {
400 default:
401 return failure();
402
403 case ICmpPredicate::eq:
404 case ICmpPredicate::ceq: {
405 // a == b ==> ~(a[n] ^ b[n]) & ~(a[n-1] ^ b[n-1]) & ...
406 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
407 auto xorBits = extractBits(rewriter, xorOp);
408 SmallVector<bool> allInverts(xorBits.size(), true);
409 rewriter.replaceOpWithNewOp<aig::AndInverterOp>(op, xorBits, allInverts);
410 return success();
411 }
412
413 case ICmpPredicate::ne:
414 case ICmpPredicate::cne: {
415 // a != b ==> (a[n] ^ b[n]) | (a[n-1] ^ b[n-1]) | ...
416 auto xorOp = rewriter.createOrFold<comb::XorOp>(op.getLoc(), lhs, rhs);
417 rewriter.replaceOpWithNewOp<comb::OrOp>(op, extractBits(rewriter, xorOp),
418 true);
419 return success();
420 }
421
422 case ICmpPredicate::uge:
423 case ICmpPredicate::ugt:
424 case ICmpPredicate::ule:
425 case ICmpPredicate::ult: {
426 bool isLess = op.getPredicate() == ICmpPredicate::ult ||
427 op.getPredicate() == ICmpPredicate::ule;
428 bool includeEq = op.getPredicate() == ICmpPredicate::uge ||
429 op.getPredicate() == ICmpPredicate::ule;
430 auto aBits = extractBits(rewriter, lhs);
431 auto bBits = extractBits(rewriter, rhs);
432 rewriter.replaceOp(op, constructUnsignedCompare(op, aBits, bBits, isLess,
433 includeEq, rewriter));
434 return success();
435 }
436 case ICmpPredicate::slt:
437 case ICmpPredicate::sle:
438 case ICmpPredicate::sgt:
439 case ICmpPredicate::sge: {
440 if (lhs.getType().getIntOrFloatBitWidth() == 0)
441 return rewriter.notifyMatchFailure(
442 op.getLoc(), "i0 signed comparison is unsupported");
443 bool isLess = op.getPredicate() == ICmpPredicate::slt ||
444 op.getPredicate() == ICmpPredicate::sle;
445 bool includeEq = op.getPredicate() == ICmpPredicate::sge ||
446 op.getPredicate() == ICmpPredicate::sle;
447
448 auto aBits = extractBits(rewriter, lhs);
449 auto bBits = extractBits(rewriter, rhs);
450
451 // Get a sign bit
452 auto signA = aBits.back();
453 auto signB = bBits.back();
454
455 // Compare magnitudes (all bits except sign)
456 auto sameSignResult = constructUnsignedCompare(
457 op, ArrayRef(aBits).drop_back(), ArrayRef(bBits).drop_back(), isLess,
458 includeEq, rewriter);
459
460 // XOR of signs: true if signs are different
461 auto signsDiffer =
462 rewriter.create<comb::XorOp>(op.getLoc(), signA, signB);
463
464 // Result when signs are different
465 Value diffSignResult = isLess ? signA : signB;
466
467 // Final result: choose based on whether signs differ
468 rewriter.replaceOpWithNewOp<comb::MuxOp>(op, signsDiffer, diffSignResult,
469 sameSignResult);
470 return success();
471 }
472 }
473 }
474};
475
476struct CombParityOpConversion : OpConversionPattern<ParityOp> {
478
479 LogicalResult
480 matchAndRewrite(ParityOp op, OpAdaptor adaptor,
481 ConversionPatternRewriter &rewriter) const override {
482 // Parity is the XOR of all bits.
483 rewriter.replaceOpWithNewOp<comb::XorOp>(
484 op, extractBits(rewriter, adaptor.getInput()), true);
485 return success();
486 }
487};
488
489struct CombShlOpConversion : OpConversionPattern<comb::ShlOp> {
491
492 LogicalResult
493 matchAndRewrite(comb::ShlOp op, OpAdaptor adaptor,
494 ConversionPatternRewriter &rewriter) const override {
495 auto width = op.getType().getIntOrFloatBitWidth();
496 auto lhs = adaptor.getLhs();
497 auto result = createShiftLogic</*isLeftShift=*/true>(
498 rewriter, op.getLoc(), adaptor.getRhs(), width,
499 /*getPadding=*/
500 [&](int64_t index) {
501 // Don't create zero width value.
502 if (index == 0)
503 return Value();
504 // Padding is 0 for left shift.
505 return rewriter.createOrFold<hw::ConstantOp>(
506 op.getLoc(), rewriter.getIntegerType(index), 0);
507 },
508 /*getExtract=*/
509 [&](int64_t index) {
510 assert(index < width && "index out of bounds");
511 // Exract the bits from LSB.
512 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, 0,
513 width - index);
514 });
515
516 rewriter.replaceOp(op, result);
517 return success();
518 }
519};
520
521struct CombShrUOpConversion : OpConversionPattern<comb::ShrUOp> {
523
524 LogicalResult
525 matchAndRewrite(comb::ShrUOp op, OpAdaptor adaptor,
526 ConversionPatternRewriter &rewriter) const override {
527 auto width = op.getType().getIntOrFloatBitWidth();
528 auto lhs = adaptor.getLhs();
529 auto result = createShiftLogic</*isLeftShift=*/false>(
530 rewriter, op.getLoc(), adaptor.getRhs(), width,
531 /*getPadding=*/
532 [&](int64_t index) {
533 // Don't create zero width value.
534 if (index == 0)
535 return Value();
536 // Padding is 0 for right shift.
537 return rewriter.createOrFold<hw::ConstantOp>(
538 op.getLoc(), rewriter.getIntegerType(index), 0);
539 },
540 /*getExtract=*/
541 [&](int64_t index) {
542 assert(index < width && "index out of bounds");
543 // Exract the bits from MSB.
544 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
545 width - index);
546 });
547
548 rewriter.replaceOp(op, result);
549 return success();
550 }
551};
552
553struct CombShrSOpConversion : OpConversionPattern<comb::ShrSOp> {
555
556 LogicalResult
557 matchAndRewrite(comb::ShrSOp op, OpAdaptor adaptor,
558 ConversionPatternRewriter &rewriter) const override {
559 auto width = op.getType().getIntOrFloatBitWidth();
560 if (width == 0)
561 return rewriter.notifyMatchFailure(op.getLoc(),
562 "i0 signed shift is unsupported");
563 auto lhs = adaptor.getLhs();
564 // Get the sign bit.
565 auto sign =
566 rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, width - 1, 1);
567
568 // NOTE: The max shift amount is width - 1 because the sign bit is already
569 // shifted out.
570 auto result = createShiftLogic</*isLeftShift=*/false>(
571 rewriter, op.getLoc(), adaptor.getRhs(), width - 1,
572 /*getPadding=*/
573 [&](int64_t index) {
574 return rewriter.createOrFold<comb::ReplicateOp>(op.getLoc(), sign,
575 index + 1);
576 },
577 /*getExtract=*/
578 [&](int64_t index) {
579 return rewriter.createOrFold<comb::ExtractOp>(op.getLoc(), lhs, index,
580 width - index - 1);
581 });
582
583 rewriter.replaceOp(op, result);
584 return success();
585 }
586};
587
588} // namespace
589
590//===----------------------------------------------------------------------===//
591// Convert Comb to AIG pass
592//===----------------------------------------------------------------------===//
593
594namespace {
595struct ConvertCombToAIGPass
596 : public impl::ConvertCombToAIGBase<ConvertCombToAIGPass> {
597 void runOnOperation() override;
598 using ConvertCombToAIGBase<ConvertCombToAIGPass>::ConvertCombToAIGBase;
599 using ConvertCombToAIGBase<ConvertCombToAIGPass>::additionalLegalOps;
600};
601} // namespace
602
603static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns) {
604 patterns.add<
605 // Bitwise Logical Ops
606 CombAndOpConversion, CombOrOpConversion, CombXorOpConversion,
607 CombMuxOpConversion, CombParityOpConversion,
608 // Arithmetic Ops
609 CombAddOpConversion, CombSubOpConversion, CombMulOpConversion,
610 CombICmpOpConversion,
611 // Shift Ops
612 CombShlOpConversion, CombShrUOpConversion, CombShrSOpConversion,
613 // Variadic ops that must be lowered to binary operations
614 CombLowerVariadicOp<XorOp>, CombLowerVariadicOp<AddOp>,
615 CombLowerVariadicOp<MulOp>>(patterns.getContext());
616}
617
618void ConvertCombToAIGPass::runOnOperation() {
619 ConversionTarget target(getContext());
620
621 // Comb is source dialect.
622 target.addIllegalDialect<comb::CombDialect>();
623 // Keep data movement operations like Extract, Concat and Replicate.
624 target.addLegalOp<comb::ExtractOp, comb::ConcatOp, comb::ReplicateOp,
626
627 // Treat array operations as illegal. Strictly speaking, other than array get
628 // operation with non-const index are legal in AIG but array types prevent a
629 // bunch of optimizations so just lower them to integer operations. It's
630 // required to run HWAggregateToComb pass before this pass.
632 hw::AggregateConstantOp>();
633
634 // AIG is target dialect.
635 target.addLegalDialect<aig::AIGDialect>();
636
637 // This is a test only option to add logical ops.
638 if (!additionalLegalOps.empty())
639 for (const auto &opName : additionalLegalOps)
640 target.addLegalOp(OperationName(opName, &getContext()));
641
642 RewritePatternSet patterns(&getContext());
644
645 if (failed(mlir::applyPartialConversion(getOperation(), target,
646 std::move(patterns))))
647 return signalPassFailure();
648}
assert(baseType &&"element must be base type")
static SmallVector< Value > extractBits(OpBuilder &builder, Value val)
Definition CombToAIG.cpp:33
static Value createShiftLogic(ConversionPatternRewriter &rewriter, Location loc, Value shiftAmount, int64_t maxShiftAmount, llvm::function_ref< Value(int64_t)> getPadding, llvm::function_ref< Value(int64_t)> getExtract)
Definition CombToAIG.cpp:47
static void populateCombToAIGConversionPatterns(RewritePatternSet &patterns)
static Value lowerFullyAssociativeOp(Operation &op, OperandRange operands, SmallVector< Operation * > &newOps)
Lower a variadic fully-associative operation into an expression tree.
create(low_bit, result_type, input=None)
Definition comb.py:187
create(data_type, value)
Definition hw.py:441
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition comb.py:1