CIRCT 23.0.0git
Loading...
Searching...
No Matches
SynthOps.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
16#include "mlir/Analysis/TopologicalSortUtils.h"
17#include "mlir/IR/BuiltinAttributes.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/OpDefinition.h"
20#include "mlir/IR/PatternMatch.h"
21#include "mlir/IR/Value.h"
22#include "llvm/ADT/APInt.h"
23#include "llvm/ADT/SmallVector.h"
24#include "llvm/Support/Casting.h"
25#include "llvm/Support/LogicalResult.h"
26
27using namespace mlir;
28using namespace circt;
29using namespace circt::synth;
30using namespace circt::synth::aig;
31using namespace circt::comb;
32using namespace matchers;
33
34#define GET_OP_CLASSES
35#include "circt/Dialect/Synth/Synth.cpp.inc"
36
37namespace {
38
39inline llvm::KnownBits applyInversion(llvm::KnownBits value, bool inverted) {
40 if (inverted)
41 std::swap(value.Zero, value.One);
42 return value;
43}
44
45template <typename SubType>
46struct ComplementMatcher {
47 SubType lhs;
48 ComplementMatcher(SubType lhs) : lhs(std::move(lhs)) {}
49 bool match(Operation *op) {
50 auto boolOp = dyn_cast<BooleanLogicOpInterface>(op);
51 return boolOp && boolOp.getInputs().size() == 1 && boolOp.isInverted(0) &&
52 lhs.match(op->getOperand(0));
53 }
54};
55
56template <typename SubType>
57static inline ComplementMatcher<SubType> m_Complement(const SubType &subExpr) {
58 return ComplementMatcher<SubType>(subExpr);
59}
60
61} // namespace
62
63LogicalResult ChoiceOp::verify() {
64 if (getNumOperands() < 1)
65 return emitOpError("requires at least one operand");
66 return success();
67}
68
69OpFoldResult ChoiceOp::fold(FoldAdaptor adaptor) {
70 if (adaptor.getInputs().size() == 1)
71 return getOperand(0);
72 return {};
73}
74
75// Canonicalize a network of synth.choice operations by computing their
76// transitive closure and flattening them into a single choice operation.
77// This merges nested choices and deduplicates shared operands.
78// Pattern matched:
79// %0 = synth.choice %x, %y, %z
80// %1 = synth.choice %0, %u
81// %2 = synth.choice %z, %v
82// =>
83// %merged = synth.choice %x, %y, %z, %u, %v
84LogicalResult ChoiceOp::canonicalize(ChoiceOp op, PatternRewriter &rewriter) {
85 llvm::SetVector<Value> worklist;
87
88 auto addToWorklist = [&](ChoiceOp choice) -> bool {
89 if (choice->getBlock() == op->getBlock() && visitedChoices.insert(choice)) {
90 worklist.insert(choice.getInputs().begin(), choice.getInputs().end());
91 return true;
92 }
93 return false;
94 };
95
96 addToWorklist(op);
97
98 bool mergedOtherChoices = false;
99
100 // Look up and down at definitions and users.
101 for (unsigned i = 0; i < worklist.size(); ++i) {
102 Value val = worklist[i];
103 if (auto defOp = val.getDefiningOp<synth::ChoiceOp>()) {
104
105 if (addToWorklist(defOp))
106 mergedOtherChoices = true;
107 }
108
109 for (Operation *user : val.getUsers()) {
110 if (auto userChoice = llvm::dyn_cast<synth::ChoiceOp>(user)) {
111 if (addToWorklist(userChoice)) {
112 mergedOtherChoices = true;
113 }
114 }
115 }
116 }
117
118 llvm::SmallVector<mlir::Value> finalOperands;
119 for (Value v : worklist) {
120 if (!visitedChoices.contains(v.getDefiningOp())) {
121 finalOperands.push_back(v);
122 }
123 }
124
125 if (!mergedOtherChoices && finalOperands.size() == op.getInputs().size())
126 return llvm::failure();
127
128 auto newChoice = synth::ChoiceOp::create(rewriter, op->getLoc(), op.getType(),
129 finalOperands);
130 for (Operation *visited : visitedChoices.takeVector())
131 rewriter.replaceOp(visited, newChoice);
132
133 for (auto value : newChoice.getInputs())
134 rewriter.replaceAllUsesExcept(value, newChoice.getResult(), newChoice);
135
136 return success();
137}
138
139//===----------------------------------------------------------------------===//
140// AndInverterOp
141//===----------------------------------------------------------------------===//
142
143bool AndInverterOp::areInputsPermutationInvariant() { return true; }
144
145OpFoldResult AndInverterOp::fold(FoldAdaptor adaptor) {
146 if (getNumOperands() == 1 && !isInverted(0))
147 return getOperand(0);
148
149 auto inputs = adaptor.getInputs();
150 if (inputs.size() == 2)
151 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
152 auto value = intAttr.getValue();
153 if (isInverted(1))
154 value = ~value;
155 if (value.isZero())
156 return IntegerAttr::get(
157 IntegerType::get(getContext(), value.getBitWidth()), value);
158 if (value.isAllOnes()) {
159 if (isInverted(0))
160 return {};
161
162 return getOperand(0);
163 }
164 }
165 return {};
166}
167
168LogicalResult AndInverterOp::canonicalize(AndInverterOp op,
169 PatternRewriter &rewriter) {
171 SmallVector<Value> uniqueValues;
172 SmallVector<bool> uniqueInverts;
173
174 APInt constValue =
175 APInt::getAllOnes(op.getResult().getType().getIntOrFloatBitWidth());
176
177 bool invertedConstFound = false;
178 bool flippedFound = false;
179
180 for (auto [value, inverted] : llvm::zip(op.getInputs(), op.getInverted())) {
181 bool newInverted = inverted;
182 if (auto constOp = value.getDefiningOp<hw::ConstantOp>()) {
183 if (inverted) {
184 constValue &= ~constOp.getValue();
185 invertedConstFound = true;
186 } else {
187 constValue &= constOp.getValue();
188 }
189 continue;
190 }
191
192 if (auto andInverterOp = value.getDefiningOp<synth::aig::AndInverterOp>()) {
193 if (andInverterOp.getInputs().size() == 1 &&
194 andInverterOp.isInverted(0)) {
195 value = andInverterOp.getOperand(0);
196 newInverted = andInverterOp.isInverted(0) ^ inverted;
197 flippedFound = true;
198 }
199 }
200
201 auto it = seen.find(value);
202 if (it == seen.end()) {
203 seen.insert({value, newInverted});
204 uniqueValues.push_back(value);
205 uniqueInverts.push_back(newInverted);
206 } else if (it->second != newInverted) {
207 // replace with const 0
208 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
209 op, APInt::getZero(value.getType().getIntOrFloatBitWidth()));
210 return success();
211 }
212 }
213
214 // If the constant is zero, we can just replace with zero.
215 if (constValue.isZero()) {
216 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
217 return success();
218 }
219
220 // No change.
221 if ((uniqueValues.size() == op.getInputs().size() && !flippedFound) ||
222 (!constValue.isAllOnes() && !invertedConstFound &&
223 uniqueValues.size() + 1 == op.getInputs().size()))
224 return failure();
225
226 if (!constValue.isAllOnes()) {
227 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), constValue);
228 uniqueInverts.push_back(false);
229 uniqueValues.push_back(constOp);
230 }
231
232 // It means the input is reduced to all ones.
233 if (uniqueValues.empty()) {
234 rewriter.replaceOpWithNewOp<hw::ConstantOp>(op, constValue);
235 return success();
236 }
237
238 // build new op with reduced input values
239 replaceOpWithNewOpAndCopyNamehint<synth::aig::AndInverterOp>(
240 rewriter, op, uniqueValues, uniqueInverts);
241 return success();
242}
243
244APInt AndInverterOp::evaluateBooleanLogicWithoutInversion(
245 llvm::ArrayRef<APInt> inputs) {
246 assert(!inputs.empty() && "expected non-empty input list");
247 APInt result = APInt::getAllOnes(inputs.front().getBitWidth());
248 for (const APInt &input : inputs)
249 result &= input;
250 return result;
251}
252
253bool AndInverterOp::supportsNumInputs(unsigned numInputs) {
254 return numInputs >= 1;
255}
256
257llvm::KnownBits AndInverterOp::computeKnownBits(
258 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
259 assert(getNumOperands() > 0 && "Expected non-empty input list");
260
261 auto width = getInputKnownBits(0).getBitWidth();
262 llvm::KnownBits result(width);
263 result.One = APInt::getAllOnes(width);
264 result.Zero = APInt::getZero(width);
265
266 for (auto [i, inverted] : llvm::enumerate(getInverted()))
267 result &= applyInversion(getInputKnownBits(i), inverted);
268
269 return result;
270}
271
272int64_t AndInverterOp::getLogicDepthCost() {
273 return llvm::Log2_64_Ceil(getNumOperands());
274}
275
276std::optional<uint64_t> AndInverterOp::getLogicAreaCost() {
277 int64_t bitWidth = hw::getBitWidth(getType());
278 if (bitWidth < 0)
279 return std::nullopt;
280 return static_cast<uint64_t>(getNumOperands() - 1) * bitWidth;
281}
282
283void AndInverterOp::emitCNFWithoutInversion(
284 int outVar, llvm::ArrayRef<int> inputVars,
285 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
286 llvm::function_ref<int()> newVar) {
287 (void)newVar;
288 circt::addAndClauses(outVar, inputVars, addClause);
289}
290
291//===----------------------------------------------------------------------===//
292// XorInverterOp
293//===----------------------------------------------------------------------===//
294
295bool XorInverterOp::areInputsPermutationInvariant() { return true; }
296
297OpFoldResult XorInverterOp::fold(FoldAdaptor adaptor) {
298 // xor_inv(a) -> a
299 if (getNumOperands() == 1 && !isInverted(0))
300 return getOperand(0);
301
302 auto inputs = adaptor.getInputs();
303 if (inputs.size() == 2)
304 if (auto intAttr = dyn_cast_or_null<IntegerAttr>(inputs[1])) {
305 auto value = intAttr.getValue();
306 if (isInverted(1))
307 value = ~value;
308 // xor_inv(a, 0000000) -> a
309 if (value.isZero())
310 return getOperand(0);
311 }
312 return {};
313}
314
315LogicalResult XorInverterOp::canonicalize(XorInverterOp op,
316 PatternRewriter &rewriter) {
317
318 // Map to store active (non-canceled) operands and their inversion state
319 SmallMapVector<Value, bool, 4> activeOperands;
320
321 // XOR identity is zero; accumulate all constant operands here.
322 APInt constValue =
323 APInt::getZero(op.getResult().getType().getIntOrFloatBitWidth());
324
325 bool constFound = false;
326 bool changed = false;
327
328 for (auto [value, inverted] : llvm::zip(op.getInputs(), op.getInverted())) {
329 Value currentValue = value;
330 bool newInverted = inverted;
331
332 // xor_inv(a, c0, c1) -> xor_inv(a, c0 ^ c1)
333 // xor_inv(a, not c0) -> xor_inv(a, ~c0)
334 if (auto constOp = currentValue.getDefiningOp<hw::ConstantOp>()) {
335 APInt val = constOp.getValue();
336 if (newInverted)
337 val = ~val;
338 constValue ^= val;
339 constFound = true;
340 continue;
341 }
342
343 // xor_inv(a, not (xor_inv/aig_inv not b)) -> xor_inv(a, b)
344 Value matchedVal;
345 if (newInverted &&
346 matchPattern(currentValue, m_Complement(m_Any(&matchedVal)))) {
347 currentValue = matchedVal;
348 newInverted = false; // double inversion cancels out
349 changed = true;
350 }
351
352 // xor_inv (a, a, b) -> b
353 // xor_inv (a, not a, b) -> ~b
354 if (activeOperands.count(currentValue)) {
355 // If we see the value again, they cancel out.
356 // If one was inverted and the other wasn't (x ^ ~x), it results in a '1'.
357 if (activeOperands[currentValue] != newInverted)
358 constValue.flipAllBits();
359 activeOperands.erase(currentValue);
360 changed = true;
361 } else {
362 activeOperands[currentValue] = newInverted;
363 }
364 }
365
366 // No constants were folded and no operands cancelled out. There is nothing to
367 // do.
368 if (!changed && !constFound && activeOperands.size() == op.getInputs().size())
369 return failure();
370
371 // xor_inv(a, 1111111) -> xor_inv(not a)
372 // xor_inv(a, c0, c1) -> xor_inv(a, c0^c1)
373 if (!constValue.isZero()) {
374 if (constValue.isAllOnes() && !activeOperands.empty()) {
375 // Propagate ones as an inversion on the last operand.
376 activeOperands.back().second = !activeOperands.back().second;
377 } else {
378 if (op.getInputs().size() == 2 && !op.getInverted()[1] &&
379 activeOperands.size() == 1)
380 return failure();
381 auto constOp = hw::ConstantOp::create(rewriter, op.getLoc(), constValue);
382 activeOperands.insert({constOp, false});
383 }
384 }
385
386 if (activeOperands.empty()) {
387 rewriter.replaceOpWithNewOp<hw::ConstantOp>(
388 op, APInt::getZero(op.getResult().getType().getIntOrFloatBitWidth()));
389 return success();
390 }
391
392 replaceOpAndCopyNamehint(rewriter, op,
393 XorInverterOp::create(rewriter, op.getLoc(),
394 activeOperands.getArrayRef()));
395 return success();
396}
397
398APInt XorInverterOp::evaluateBooleanLogicWithoutInversion(
399 llvm::ArrayRef<APInt> inputs) {
400 assert(!inputs.empty() && "expected non-empty input list");
401 APInt result = APInt::getZero(inputs.front().getBitWidth());
402 for (const APInt &input : inputs)
403 result ^= input;
404 return result;
405}
406
407bool XorInverterOp::supportsNumInputs(unsigned numInputs) {
408 return numInputs >= 1;
409}
410
411llvm::KnownBits XorInverterOp::computeKnownBits(
412 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
413 assert(getNumOperands() > 0 && "Expected non-empty input list");
414
415 llvm::KnownBits result(getInputKnownBits(0).getBitWidth());
416 for (auto [i, inverted] : llvm::enumerate(getInverted()))
417 result ^= applyInversion(getInputKnownBits(i), inverted);
418 return result;
419}
420
421int64_t XorInverterOp::getLogicDepthCost() {
422 return llvm::Log2_64_Ceil(getNumOperands());
423}
424
425std::optional<uint64_t> XorInverterOp::getLogicAreaCost() {
426 int64_t bitWidth = hw::getBitWidth(getType());
427 if (bitWidth < 0)
428 return std::nullopt;
429 return static_cast<uint64_t>(getNumOperands() - 1) * bitWidth;
430}
431
432void XorInverterOp::emitCNFWithoutInversion(
433 int outVar, llvm::ArrayRef<int> inputVars,
434 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
435 llvm::function_ref<int()> newVar) {
436 circt::addParityClauses(outVar, inputVars, addClause, newVar);
437}
438
439//===----------------------------------------------------------------------===//
440// DotOp
441//===----------------------------------------------------------------------===//
442
443ParseResult DotOp::parse(OpAsmParser &parser, OperationState &result) {
444 SmallVector<OpAsmParser::UnresolvedOperand> operands;
445 Type resultType;
446 DenseBoolArrayAttr inverted;
447 NamedAttrList attrs;
448
449 if (parseVariadicInvertibleOperands(parser, operands, resultType, inverted,
450 attrs))
451 return failure();
452 if (operands.size() != 3)
453 return parser.emitError(parser.getCurrentLocation())
454 << "expected exactly three operands";
455 if (parser.resolveOperands(operands, resultType, result.operands))
456 return failure();
457
458 result.addTypes(resultType);
459 result.addAttributes(attrs);
460 result.addAttribute("inverted", inverted);
461 return success();
462}
463
464void DotOp::print(OpAsmPrinter &printer) {
465 printer << ' ';
466 printVariadicInvertibleOperands(printer, getOperation(), getOperands(),
467 getType(), getInvertedAttr(),
468 (*this)->getAttrDictionary());
469}
470
471LogicalResult DotOp::verify() {
472 if (getInverted().size() != 3)
473 return emitOpError("requires exactly three inversion flags");
474 return success();
475}
476
477APInt DotOp::evaluateBooleanLogicWithoutInversion(
478 llvm::ArrayRef<APInt> inputs) {
479 assert(supportsNumInputs(inputs.size()) &&
480 "dot expects exactly three operands");
481 return evaluateDotLogic(inputs[0], inputs[1], inputs[2]);
482}
483
484bool DotOp::areInputsPermutationInvariant() { return false; }
485
486bool DotOp::supportsNumInputs(unsigned numInputs) { return numInputs == 3; }
487
488llvm::KnownBits DotOp::computeKnownBits(
489 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
490 auto x = applyInversion(getInputKnownBits(0), isInverted(0));
491 auto y = applyInversion(getInputKnownBits(1), isInverted(1));
492 auto z = applyInversion(getInputKnownBits(2), isInverted(2));
493 return evaluateDotLogic(x, y, z);
494}
495
496std::optional<uint64_t> DotOp::getLogicAreaCost() {
497 int64_t bitWidth = hw::getBitWidth(getType());
498 if (bitWidth < 0)
499 return std::nullopt;
500 return static_cast<uint64_t>(bitWidth);
501}
502
503void DotOp::emitCNFWithoutInversion(
504 int outVar, llvm::ArrayRef<int> inputVars,
505 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
506 llvm::function_ref<int()> newVar) {
507 assert(inputVars.size() == 3 && "expected one SAT variable per operand");
508 int andVar = newVar();
509 int orVar = newVar();
510 // andVar = x and y
511 circt::addAndClauses(andVar, {inputVars[0], inputVars[1]}, addClause);
512 // orVar = z or andVar
513 circt::addOrClauses(orVar, {inputVars[2], andVar}, addClause);
514 // outVar = x xor orVar
515 circt::addXorClauses(outVar, inputVars[0], orVar, addClause);
516}
517
518//===----------------------------------------------------------------------===//
519// MajorityOp
520//===----------------------------------------------------------------------===//
521
522ParseResult MajorityOp::parse(OpAsmParser &parser, OperationState &result) {
523 SmallVector<OpAsmParser::UnresolvedOperand> operands;
524 Type resultType;
525 DenseBoolArrayAttr inverted;
526 NamedAttrList attrs;
527
528 if (parseVariadicInvertibleOperands(parser, operands, resultType, inverted,
529 attrs))
530 return failure();
531 if (operands.size() != 3)
532 return parser.emitError(parser.getCurrentLocation())
533 << "expected exactly three operands";
534 if (parser.resolveOperands(operands, resultType, result.operands))
535 return failure();
536
537 result.addTypes(resultType);
538 result.addAttributes(attrs);
539 result.addAttribute("inverted", inverted);
540 return success();
541}
542
543void MajorityOp::print(OpAsmPrinter &printer) {
544 printer << ' ';
545 printVariadicInvertibleOperands(printer, getOperation(), getOperands(),
546 getType(), getInvertedAttr(),
547 (*this)->getAttrDictionary());
548}
549
550LogicalResult MajorityOp::verify() {
551 if (getNumOperands() != 3)
552 return emitOpError("requires exactly three operands");
553 if (getInverted().size() != 3)
554 return emitOpError("requires exactly three inversion flags");
555 return success();
556}
557
558bool MajorityOp::areInputsPermutationInvariant() { return true; }
559
560bool MajorityOp::supportsNumInputs(unsigned numInputs) {
561 return numInputs == 3;
562}
563
564std::optional<uint64_t> MajorityOp::getLogicAreaCost() {
565 int64_t bitWidth = hw::getBitWidth(getType());
566 if (bitWidth < 0)
567 return std::nullopt;
568 return static_cast<uint64_t>(bitWidth);
569}
570
571llvm::KnownBits MajorityOp::computeKnownBits(
572 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
573 auto a = applyInversion(getInputKnownBits(0), isInverted(0));
574 auto b = applyInversion(getInputKnownBits(1), isInverted(1));
575 auto c = applyInversion(getInputKnownBits(2), isInverted(2));
576 return evaluateMajorityLogic(a, b, c);
577}
578
579APInt MajorityOp::evaluateBooleanLogicWithoutInversion(
580 llvm::ArrayRef<APInt> inputs) {
581 assert(inputs.size() == 3 && "majority requires exactly three inputs");
582 return evaluateMajorityLogic(inputs[0], inputs[1], inputs[2]);
583}
584
585void MajorityOp::emitCNFWithoutInversion(
586 int outVar, llvm::ArrayRef<int> inputVars,
587 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
588 llvm::function_ref<int()> newVar) {
589 assert(inputVars.size() == 3 && "expected exactly three inputs");
590 int ab = newVar();
591 int ac = newVar();
592 int bc = newVar();
593 // ab = a & b
594 circt::addAndClauses(ab, {inputVars[0], inputVars[1]}, addClause);
595 // ac = a & c
596 circt::addAndClauses(ac, {inputVars[0], inputVars[2]}, addClause);
597 // bc = b & c
598 circt::addAndClauses(bc, {inputVars[1], inputVars[2]}, addClause);
599 // out = ab | ac | bc
600 circt::addOrClauses(outVar, {ab, ac, bc}, addClause);
601}
602
604 Location loc, ValueRange operands, ArrayRef<bool> inverts,
605 PatternRewriter &rewriter,
606 llvm::function_ref<Value(Value, bool)> createUnary,
607 llvm::function_ref<Value(Value, Value, bool, bool)> createBinary) {
608 switch (operands.size()) {
609 case 0:
610 assert(0 && "cannot be called with empty operand range");
611 break;
612 case 1:
613 return inverts[0] ? createUnary(operands[0], true) : operands[0];
614 case 2:
615 return createBinary(operands[0], operands[1], inverts[0], inverts[1]);
616 default:
617 auto firstHalf = operands.size() / 2;
618 auto lhs = lowerVariadicInvertibleOp(loc, operands.take_front(firstHalf),
619 inverts.take_front(firstHalf),
620 rewriter, createUnary, createBinary);
621 auto rhs = lowerVariadicInvertibleOp(loc, operands.drop_front(firstHalf),
622 inverts.drop_front(firstHalf),
623 rewriter, createUnary, createBinary);
624 return createBinary(lhs, rhs, false, false);
625 }
626 return Value();
627}
628
629template <typename OpTy>
631 PatternRewriter &rewriter) {
632 if (op.getInputs().size() <= 2)
633 return failure();
634 auto result = lowerVariadicInvertibleOp(
635 op.getLoc(), op.getOperands(), op.getInverted(), rewriter,
636 [&](Value input, bool invert) {
637 return OpTy::create(rewriter, op.getLoc(), input, invert);
638 },
639 [&](Value lhs, Value rhs, bool invertLhs, bool invertRhs) {
640 return OpTy::create(rewriter, op.getLoc(), lhs, rhs, invertLhs,
641 invertRhs);
642 });
643 replaceOpAndCopyNamehint(rewriter, op, result);
644 return success();
645}
646
648 RewritePatternSet &patterns) {
649 patterns.add(lowerVariadicAndInverterOpConversion<aig::AndInverterOp>);
650}
651
653 RewritePatternSet &patterns) {
654 patterns.add(lowerVariadicAndInverterOpConversion<XorInverterOp>);
655}
656
657bool circt::synth::isLogicNetworkOp(Operation *op) {
658 return isa<synth::BooleanLogicOpInterface, synth::ChoiceOp, comb::ExtractOp,
659 comb::ReplicateOp, comb::ConcatOp>(op);
660}
661
663 mlir::Operation *op,
664 llvm::function_ref<bool(mlir::Value, mlir::Operation *)> isOperandReady) {
665 // Sort the operations topologically
666 auto walkResult = op->walk([&](Region *region) {
667 auto regionKindOp =
668 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
669 if (!regionKindOp ||
670 regionKindOp.hasSSADominance(region->getRegionNumber()))
671 return WalkResult::advance();
672
673 // Graph region.
674 for (auto &block : *region) {
675 if (!mlir::sortTopologically(&block, isOperandReady))
676 return WalkResult::interrupt();
677 }
678 return WalkResult::advance();
679 });
680
681 return success(!walkResult.wasInterrupted());
682}
683
684//===----------------------------------------------------------------------===//
685// OneHotOp
686//===----------------------------------------------------------------------===//
687
688ParseResult OneHotOp::parse(OpAsmParser &parser, OperationState &result) {
689 SmallVector<OpAsmParser::UnresolvedOperand> operands;
690 Type resultType;
691 DenseBoolArrayAttr inverted;
692 NamedAttrList attrs;
693
694 if (parseVariadicInvertibleOperands(parser, operands, resultType, inverted,
695 attrs))
696 return failure();
697 if (operands.size() != 3)
698 return parser.emitError(parser.getCurrentLocation())
699 << "expected exactly three operands";
700 if (parser.resolveOperands(operands, resultType, result.operands))
701 return failure();
702
703 result.addTypes(resultType);
704 result.addAttributes(attrs);
705 result.addAttribute("inverted", inverted);
706 return success();
707}
708
709void OneHotOp::print(OpAsmPrinter &printer) {
710 printer << ' ';
711 printVariadicInvertibleOperands(printer, getOperation(), getOperands(),
712 getType(), getInvertedAttr(),
713 (*this)->getAttrDictionary());
714}
715
716LogicalResult OneHotOp::verify() {
717 if (getNumOperands() != 3)
718 return emitOpError("requires exactly three operands");
719 if (getInverted().size() != 3)
720 return emitOpError("requires exactly three inversion flags");
721 return success();
722}
723
724bool OneHotOp::areInputsPermutationInvariant() { return true; }
725
726bool OneHotOp::supportsNumInputs(unsigned numInputs) { return numInputs == 3; }
727
728std::optional<uint64_t> OneHotOp::getLogicAreaCost() {
729 int64_t bitWidth = hw::getBitWidth(getType());
730 if (bitWidth < 0)
731 return std::nullopt;
732 return static_cast<uint64_t>(bitWidth);
733}
734
735llvm::KnownBits OneHotOp::computeKnownBits(
736 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
737 auto a = applyInversion(getInputKnownBits(0), isInverted(0));
738 auto b = applyInversion(getInputKnownBits(1), isInverted(1));
739 auto c = applyInversion(getInputKnownBits(2), isInverted(2));
740 return evaluateOneHotLogic(a, b, c);
741}
742
743APInt OneHotOp::evaluateBooleanLogicWithoutInversion(
744 llvm::ArrayRef<APInt> inputs) {
745 assert(inputs.size() == 3 && "onehot requires exactly three inputs");
746 return evaluateOneHotLogic(inputs[0], inputs[1], inputs[2]);
747}
748
749void OneHotOp::emitCNFWithoutInversion(
750 int outVar, llvm::ArrayRef<int> inputVars,
751 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
752 llvm::function_ref<int()> newVar) {
753 assert(inputVars.size() == 3 && "expected exactly three inputs");
754
755 // parity = a ^ b ^ c.
756 int parity = newVar();
757 circt::addParityClauses(parity, inputVars, addClause, newVar);
758
759 // allSet = a & b & c.
760 int allSet = newVar();
761 circt::addAndClauses(allSet, inputVars, addClause);
762
763 // out = (a ^ b ^ c) & ~(a & b & c).
764 circt::addAndClauses(outVar, {parity, -allSet}, addClause);
765}
766
767//===----------------------------------------------------------------------===//
768// MuxInverterOp
769//===----------------------------------------------------------------------===//
770
771LogicalResult MuxInverterOp::verify() {
772 if (getNumOperands() != 3)
773 return emitOpError("requires exactly three operands");
774 if (getInverted().size() != 3)
775 return emitOpError("requires exactly three inversion flags");
776 return success();
777}
778
779bool MuxInverterOp::areInputsPermutationInvariant() { return false; }
780
781APInt MuxInverterOp::evaluateBooleanLogicWithoutInversion(
782 llvm::ArrayRef<APInt> inputs) {
783 assert(inputs.size() == 3 && "expected exactly three inputs");
784 return evaluateMuxLogic(inputs[0], inputs[1], inputs[2]);
785}
786
787bool MuxInverterOp::supportsNumInputs(unsigned numInputs) {
788 return numInputs == 3;
789}
790
791llvm::KnownBits MuxInverterOp::computeKnownBits(
792 llvm::function_ref<const llvm::KnownBits &(unsigned)> getInputKnownBits) {
793 assert(getNumOperands() == 3 && "expected exactly three inputs");
794
795 auto a = applyInversion(getInputKnownBits(0), isInverted(0));
796 auto b = applyInversion(getInputKnownBits(1), isInverted(1));
797 auto c = applyInversion(getInputKnownBits(2), isInverted(2));
798
799 return evaluateMuxLogic(a, b, c);
800}
801
802int64_t MuxInverterOp::getLogicDepthCost() { return 2; }
803
804std::optional<uint64_t> MuxInverterOp::getLogicAreaCost() {
805 int64_t bitWidth = hw::getBitWidth(getType());
806 if (bitWidth < 0)
807 return std::nullopt;
808 return static_cast<uint64_t>(bitWidth);
809}
810
811void MuxInverterOp::emitCNFWithoutInversion(
812 int outVar, llvm::ArrayRef<int> inputVars,
813 llvm::function_ref<void(llvm::ArrayRef<int>)> addClause,
814 llvm::function_ref<int()> newVar) {
815 assert(inputVars.size() == 3 && "expected exactly three inputs");
816
817 int cond = inputVars[0];
818 int trueValue = inputVars[1];
819 int falseValue = inputVars[2];
820
821 int lhs = newVar();
822 int rhs = newVar();
823
824 // lhs = cond & trueValue
825 circt::addAndClauses(lhs, {cond, trueValue}, addClause);
826 // rhs = ~cond & falseValue
827 circt::addAndClauses(rhs, {-cond, falseValue}, addClause);
828 // out = lhs | rhs
829 circt::addOrClauses(outVar, {lhs, rhs}, addClause);
830}
assert(baseType &&"element must be base type")
static ComplementMatcher< SubType > m_Complement(const SubType &subExpr)
Definition CombFolds.cpp:85
LogicalResult lowerVariadicAndInverterOpConversion(OpTy op, PatternRewriter &rewriter)
Definition SynthOps.cpp:630
static Value lowerVariadicInvertibleOp(Location loc, ValueRange operands, ArrayRef< bool > inverts, PatternRewriter &rewriter, llvm::function_ref< Value(Value, bool)> createUnary, llvm::function_ref< Value(Value, Value, bool, bool)> createBinary)
Definition SynthOps.cpp:603
create(data_type, value)
Definition hw.py:433
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
void populateVariadicXorInverterLoweringPatterns(mlir::RewritePatternSet &patterns)
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:662
T evaluateDotLogic(const T &x, const T &y, const T &z)
Evaluate the Boolean function x ^ (z | (x & y)).
Definition SynthOps.h:134
T evaluateMajorityLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:139
T evaluateMuxLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:160
bool isLogicNetworkOp(mlir::Operation *op)
T evaluateOneHotLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:154
void populateVariadicAndInverterLoweringPatterns(mlir::RewritePatternSet &patterns)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseVariadicInvertibleOperands(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, Type &resultType, mlir::DenseBoolArrayAttr &inverted, NamedAttrList &attrDict)
Parse a variadic list of operands that may be prefixed with an optional not keyword.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
void addXorClauses(int outVar, int lhsLit, int rhsLit, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> (lhsLit xor rhsLit).
void printVariadicInvertibleOperands(OpAsmPrinter &printer, Operation *op, OperandRange operands, Type resultType, mlir::DenseBoolArrayAttr inverted, DictionaryAttr attrDict)
Print a variadic list of operands that may be prefixed with an optional not keyword.
void replaceOpAndCopyNamehint(PatternRewriter &rewriter, Operation *op, Value newValue)
A wrapper of PatternRewriter::replaceOp to propagate "sv.namehint" attribute.
Definition Naming.cpp:73
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).