CIRCT 22.0.0git
Loading...
Searching...
No Matches
CutRewriter.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18//
19//===----------------------------------------------------------------------===//
20
22
26#include "circt/Support/LLVM.h"
29#include "mlir/Analysis/TopologicalSortUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/Operation.h"
32#include "mlir/IR/RegionKindInterface.h"
33#include "mlir/IR/Value.h"
34#include "mlir/IR/ValueRange.h"
35#include "mlir/IR/Visitors.h"
36#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/APInt.h"
38#include "llvm/ADT/Bitset.h"
39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/MapVector.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/ScopeExit.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/iterator.h"
46#include "llvm/Support/Debug.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/LogicalResult.h"
49#include <algorithm>
50#include <functional>
51#include <memory>
52#include <optional>
53#include <string>
54
55#define DEBUG_TYPE "synth-cut-rewriter"
56
57using namespace circt;
58using namespace circt::synth;
59
60static bool isSupportedLogicOp(mlir::Operation *op) {
61 // Check if the operation is a combinational operation that can be simulated
62 // TODO: Extend this to allow comb.and/xor/or as well.
63 return isa<aig::AndInverterOp>(op);
64}
65
66static void simulateLogicOp(Operation *op, DenseMap<Value, llvm::APInt> &eval) {
68 "Operation must be a supported logic operation for simulation");
69
70 // Simulate the operation by evaluating its inputs and computing the output
71 // This is a simplified simulation for demonstration purposes
72 if (auto andOp = dyn_cast<aig::AndInverterOp>(op)) {
73 SmallVector<llvm::APInt, 2> inputs;
74 inputs.reserve(andOp.getInputs().size());
75 for (auto input : andOp.getInputs()) {
76 auto it = eval.find(input);
77 if (it == eval.end())
78 llvm::report_fatal_error("Input value not found in evaluation map");
79 inputs.push_back(it->second);
80 }
81 // Evaluate the and inverter
82 eval[andOp.getResult()] = andOp.evaluate(inputs);
83 return;
84 }
85
86 llvm::report_fatal_error(
87 "Unsupported operation for simulation. isSupportedLogicOp should "
88 "be used to check if the operation can be simulated.");
89}
90
91// Return true if the value is always a cut input.
92static bool isAlwaysCutInput(Value value) {
93 auto *op = value.getDefiningOp();
94 // If the value has no defining operation, it is an input
95 if (!op)
96 return true;
97
98 if (op->hasTrait<OpTrait::ConstantLike>()) {
99 // Constant values are never cut inputs.
100 return false;
101 }
102
103 return !isSupportedLogicOp(op);
104}
105
// NOTE(review): the first line of the signature and the two `if` headers that
// select the branch are absent from this listing; only the parameters
// `newDelay`, `oldArea`, `oldDelay` and the branch bodies are visible.
// Callers in this file invoke it as `compareDelayAndArea(strategy, newArea,
// newDelay, oldArea, oldDelay)`.
106// Return true if the new area/delay is better than the old area/delay in the
107// context of the given strategy.
109 ArrayRef<DelayType> newDelay, double oldArea,
110 ArrayRef<DelayType> oldDelay) {
112 // Compare by area first.
// Delay (compared with operator<) is only the tie-breaker when areas are equal.
113 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
114 }
116 // Compare by delay first.
// Area is only the tie-breaker when the delays compare equal.
117 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
118 }
// Neither branch returned: the strategy value is not one handled above.
119 llvm_unreachable("Unknown mapping strategy");
120}
121
// Topologically sort the graph-region blocks under `topOp` and emit an error
// at `topOp`'s location if sorting fails.
// NOTE(review): the line with the function name and parameter list is absent
// from this listing; the body uses a variable named `topOp`.
122LogicalResult
124
// Predicate handed to the sorter. It returns false for supported logic ops and
// for comb extract/replicate/concat, and true for every other operation.
125 auto isOperationReady = [](Value value, Operation *op) -> bool {
126 // Topologically sort simulatable ops and purely
127 // dataflow ops. Other operations can be scheduled.
128 return !(isSupportedLogicOp(op) ||
129 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp>(op));
130 };
131
132 auto result = topologicallySortGraphRegionBlocks(topOp, isOperationReady);
133 if (failed(result))
134 return mlir::emitError(topOp->getLoc(),
135 "failed to sort operations topologically");
136 return success();
137}
138
139/// Get the truth table for an op.
140template <typename OpRange>
141FailureOr<BinaryTruthTable> static computeTruthTable(
142 mlir::ValueRange values, const OpRange &ops,
143 const llvm::SmallSetVector<mlir::Value, 4> &inputArgs) {
144 // Create a truth table for the operation
145 int64_t numInputs = inputArgs.size();
146 int64_t numOutputs = values.size();
147 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
148 if (numOutputs == 0)
149 return BinaryTruthTable(numInputs, 0);
150 if (numInputs >= maxTruthTableInputs)
151 return mlir::emitError(values.front().getLoc(),
152 "Truth table is too large");
153 return mlir::emitError(values.front().getLoc(),
154 "Multiple outputs are not supported yet");
155 }
156
157 // Create a truth table with the given number of inputs and outputs
158 BinaryTruthTable truthTable(numInputs, numOutputs);
159 // The truth table size is 2^numInputs
160 uint32_t tableSize = 1 << numInputs;
161 // Create a map to evaluate the operation
162 DenseMap<Value, APInt> eval;
163 for (uint32_t i = 0; i < numInputs; ++i) {
164 // Create alternating bit pattern for input i
165 // For input i, bits alternate every 2^i positions
166 uint32_t blockSize = 1 << i;
167 // Create the repeating pattern: blockSize zeros followed by blockSize ones
168 uint32_t patternWidth = 2 * blockSize;
169 APInt pattern(patternWidth, 0);
170 assert(patternWidth <= tableSize && "Pattern width exceeds table size");
171 // Set the upper half of the pattern to 1s
172 pattern = APInt::getHighBitsSet(patternWidth, blockSize);
173 // Use getSplat to repeat this pattern across the full
174 // table size
175
176 eval[inputArgs[i]] = APInt::getSplat(tableSize, pattern);
177 }
178 // Simulate the operation
179 for (auto *op : ops) {
180 if (op->getNumResults() == 0)
181 continue; // Skip operations with no results
182 if (!isSupportedLogicOp(op))
183 return op->emitError("Unsupported operation for truth table simulation");
184
185 // Simulate the operation
186 simulateLogicOp(op, eval);
187 }
188 // TODO: Currently numOutputs is always 1, so we can just return the first
189 // one.
190 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
191}
192
193FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
194 Block *block) {
195 // Get the input arguments from the block
196 llvm::SmallSetVector<Value, 4> inputs;
197 for (auto arg : block->getArguments())
198 inputs.insert(arg);
199
200 // If there are no inputs, return an empty truth table
201 if (inputs.empty())
202 return BinaryTruthTable();
203
204 return computeTruthTable(values, llvm::make_pointer_range(*block), inputs);
205}
206
207//===----------------------------------------------------------------------===//
208// Cut
209//===----------------------------------------------------------------------===//
210
211bool Cut::isTrivialCut() const {
212 // A cut is a trival cut if it has no operations and only one input
213 return operations.empty() && inputs.size() == 1;
214}
215
216mlir::Operation *Cut::getRoot() const {
217 return operations.empty()
218 ? nullptr
219 : operations.back(); // The last operation is the root
220}
221
// Return the NPN canonical form of this cut's truth table, computing it and
// storing it in `npnClass` on first use.
// NOTE(review): the signature line is absent from this listing; other code in
// this file calls this as `getNPNClass()` on a `Cut`.
223 // If the NPN is already computed, return it
224 if (npnClass)
225 return *npnClass;
226
// NOTE(review): `auto` here makes a copy if getTruthTable() returns a
// reference — confirm whether that is intended.
227 auto truthTable = getTruthTable();
228
229 // Compute the NPN canonical form
230 auto canonicalForm = NPNClass::computeNPNCanonicalForm(truthTable);
231
// Store the result so later calls take the early-return path above.
232 npnClass.emplace(std::move(canonicalForm));
233 return *npnClass;
234}
235
236void Cut::getPermutatedInputs(const NPNClass &patternNPN,
237 SmallVectorImpl<Value> &permutedInputs) const {
238 auto npnClass = getNPNClass();
239 SmallVector<unsigned> idx;
240 npnClass.getInputPermutation(patternNPN, idx);
241 permutedInputs.reserve(idx.size());
242 for (auto inputIndex : idx) {
243 assert(inputIndex < inputs.size() && "Input index out of bounds");
244 permutedInputs.push_back(inputs[inputIndex]);
245 }
246}
247
248void Cut::dump(llvm::raw_ostream &os) const {
249 os << "// === Cut Dump ===\n";
250 os << "Cut with " << getInputSize() << " inputs and " << operations.size()
251 << " operations:\n";
252 if (isTrivialCut()) {
253 os << "Primary input cut: " << *inputs.begin() << "\n";
254 return;
255 }
256
257 os << "Inputs: \n";
258 for (auto [idx, input] : llvm::enumerate(inputs)) {
259 os << " Input " << idx << ": " << input << "\n";
260 }
261 os << "\nOperations: \n";
262 for (auto *op : operations) {
263 op->print(os);
264 os << "\n";
265 }
266 auto &npnClass = getNPNClass();
267 npnClass.dump(os);
268
269 os << "// === Cut End ===\n";
270}
271
272unsigned Cut::getInputSize() const { return inputs.size(); }
273
274unsigned Cut::getOutputSize() const { return getRoot()->getNumResults(); }
275
// Return the truth table of this cut, computing it and storing it in
// `truthTable` on first use.
// NOTE(review): the signature line and the statement that fills `truthTable`
// for non-trivial cuts are absent from this listing.
277 if (truthTable)
278 return *truthTable;
279
280 if (isTrivialCut()) {
281 // For a trivial cut, a truth table is simply the identity function.
282 // 0 -> 0, 1 -> 1
// APInt(2, 2) is the 2-bit value 0b10: bit 0 is 0 and bit 1 is 1.
283 truthTable = BinaryTruthTable(1, 1, {llvm::APInt(2, 2)});
284 return *truthTable;
285 }
286
287 // Create a truth table with the given number of inputs and outputs
289
290 return *truthTable;
291}
292
293static Cut getAsTrivialCut(mlir::Value value) {
294 // Create a cut with a single root operation
295 Cut cut;
296 // There is no input for the primary input cut.
297 cut.inputs.insert(value);
298
299 return cut;
300}
301
302[[maybe_unused]] static bool isCutDerivedFromOperand(const Cut &cut,
303 Operation *op) {
304 if (auto *root = cut.getRoot())
305 return llvm::any_of(op->getOperands(),
306 [&](Value v) { return v.getDefiningOp() == root; });
307
308 assert(cut.isTrivialCut());
309 // If the cut is trivial, it has no operations, so it must be a primary input.
310 // In this case, the only operation that can be derived from it is the
311 // primary input itself.
312 return cut.inputs.size() == 1 &&
313 llvm::any_of(op->getOperands(),
314 [&](Value v) { return v == *cut.inputs.begin(); });
315}
316
// Merge this cut with `other` into a new cut rooted at `root`.
//
// Both cuts must belong to operands of `root` (asserted). The operations of
// the merged cut are collected by a depth-first walk from `root` over operand
// definitions, inserting each operation after its operands. The inputs are
// then rebuilt as every operand of a collected operation that is either an
// always-cut-input value or defined outside the collected set.
317Cut Cut::mergeWith(const Cut &other, Operation *root) const {
318 assert(isCutDerivedFromOperand(*this, root) &&
319 isCutDerivedFromOperand(other, root) &&
320 "The operation must be a child of the current root operation");
321
322 // Create a new cut that combines this cut and the other cut
323 Cut newCut;
324 // Topological sort the operations in the new cut.
325 // TODO: Merge-sort `operations` and `other.operations` by operation index
326 // (since it's already topo-sorted, we can use a simple merge).
// Recursive post-order visitor; std::function lets the lambda call itself.
327 std::function<void(Operation *)> populateOperations = [&](Operation *op) {
328 // If the operation is already in the cut, skip it
329 if (newCut.operations.contains(op))
330 return;
331
332 // Visit the defining operations of the operands first
333 for (auto value : op->getOperands()) {
334 if (isAlwaysCutInput(value))
335 continue;
336
337 // If the value is in *both* cuts inputs, it is an input. So skip
338 // it.
339 bool isInput = inputs.contains(value);
340 bool isOtherInput = other.inputs.contains(value);
341 // Listed as an input by both cuts: do not descend into its definition
342 if (isInput && isOtherInput)
343 continue;
344
345 auto *defOp = value.getDefiningOp();
346
347 assert(defOp && "Value must have a defining operation since block"
348 "arguments are treated as inputs");
349
350 // An input of only one cut is expanded only if the other cut contains
// its defining operation; otherwise it is skipped here.
351 if (isInput)
352 if (!other.operations.contains(defOp)) // defOp is not in the other cut.
353 continue;
354 if (isOtherInput)
355 if (!operations.contains(defOp)) // defOp is not in this cut.
356 continue;
357 populateOperations(defOp);
358 }
359
360 // Add the operation to the cut
361 newCut.operations.insert(op);
362 };
363
364 populateOperations(root);
365
366 // Construct inputs.
367 for (auto *operation : newCut.operations) {
368 for (auto value : operation->getOperands()) {
369 if (isAlwaysCutInput(value)) {
370 newCut.inputs.insert(value);
371 continue;
372 }
373
374 auto *defOp = value.getDefiningOp();
375 assert(defOp && "Value must have a defining operation");
376
377 // If the operation is not in the cut, it is an input
378 if (!newCut.operations.contains(defOp))
379 // Add the input to the cut
380 newCut.inputs.insert(value);
381 }
382 }
383
384 // TODO: Sort the inputs by their defining operation.
385 // TODO: Update area and delay based on the merged cuts.
386
387 return newCut;
388}
389
390// Reroot the cut with a new root operation.
391// This is used to create a new cut with the same inputs and operations, but a
392// different root operation.
393Cut Cut::reRoot(Operation *root) const {
394 assert(isCutDerivedFromOperand(*this, root) &&
395 "The operation must be a child of the current root operation");
396 Cut newCut;
397 newCut.inputs = inputs;
398 newCut.operations = operations;
399 // Add the new root operation to the cut
400 newCut.operations.insert(root);
401 return newCut;
402}
403
404//===----------------------------------------------------------------------===//
405// MatchedPattern
406//===----------------------------------------------------------------------===//
407
408ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
409 assert(pattern && "Pattern must be set to get arrival time");
410 return arrivalTimes;
411}
412
// Return the arrival time stored at position `index`. A pattern must have
// been set (asserted).
// NOTE(review): the signature line is absent from this listing.
414 assert(pattern && "Pattern must be set to get arrival time");
415 return arrivalTimes[index];
416}
417
// Return the stored pattern, which must have been set (asserted).
// NOTE(review): the signature line is absent from this listing.
419 assert(pattern && "Pattern must be set to get the pattern");
420 return pattern;
421}
422
// Return the area reported by the stored pattern, which must have been set
// (asserted).
// NOTE(review): the signature line is absent from this listing.
424 assert(pattern && "Pattern must be set to get area");
425 return pattern->getArea();
426}
427
// Return the delay the stored pattern reports for the pair
// (`inputIndex`, `outputIndex`). A pattern must have been set (asserted).
// NOTE(review): the first line of the signature is absent from this listing.
429 unsigned outputIndex) const {
430 assert(pattern && "Pattern must be set to get delay");
431 return pattern->getDelay(inputIndex, outputIndex);
432}
433
434//===----------------------------------------------------------------------===//
435// CutSet
436//===----------------------------------------------------------------------===//
437
439
440unsigned CutSet::size() const { return cuts.size(); }
441
// Append `cut` to the set by moving it into `cuts`. The set must not be
// frozen (asserted).
// NOTE(review): the signature line is absent from this listing.
443 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
444 cuts.push_back(std::move(cut));
445}
446
447ArrayRef<Cut> CutSet::getCuts() const { return cuts; }
448
449// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
450// exists another cut that is a subset of it. We use a bitset to represent the
451// inputs of each cut for efficient subset checking.
452static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut> &cuts) {
453 // First sort the cuts by input size (ascending). This ensures that when we
454 // iterate through the cuts, we always encounter smaller cuts first, allowing
455 // us to efficiently check for non-minimality. Stable sort to maintain
456 // relative order of cuts with the same input size.
457 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut &a, const Cut &b) {
458 return a.getInputSize() < b.getInputSize();
459 });
460
461 llvm::SmallVector<llvm::Bitset<64>, 4> inputBitMasks;
462 DenseMap<Value, unsigned> inputIndices;
463 auto getIndex = [&](Value v) -> unsigned {
464 auto it = inputIndices.find(v);
465 if (it != inputIndices.end())
466 return it->second;
467 unsigned index = inputIndices.size();
468 if (LLVM_UNLIKELY(index >= 64))
469 llvm::report_fatal_error(
470 "Too many unique inputs across cuts. Max 64 supported. Consider "
471 "increasing the compile-time constant.");
472 inputIndices[v] = index;
473 return index;
474 };
475
476 for (unsigned i = 0; i < cuts.size(); ++i) {
477 auto &cut = cuts[i];
478 // Create a unique identifier for the cut based on its inputs.
479 llvm::Bitset<64> inputsMask;
480 for (auto input : cut.inputs.getArrayRef())
481 inputsMask.set(getIndex(input));
482
483 bool isUnique = llvm::all_of(
484 inputBitMasks, [&](const llvm::Bitset<64> &existingCutInputMask) {
485 // If the bitset is a subset of the current inputsMask, it is not
486 // unique
487 return (existingCutInputMask & inputsMask) != existingCutInputMask;
488 });
489
490 if (!isUnique)
491 continue;
492
493 // If the cut is unique, keep it
494 size_t uniqueCount = inputBitMasks.size();
495 if (i != uniqueCount)
496 cuts[uniqueCount] = std::move(cut);
497 inputBitMasks.push_back(inputsMask);
498 }
499
500 unsigned uniqueCount = inputBitMasks.size();
501
502 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
503 << " Unique cuts: " << uniqueCount << "\n");
504
505 // Resize the cuts vector to the number of unique cuts found
506 cuts.resize(uniqueCount);
507}
508
// Finalize the cut set: match every cut against the patterns via `matchCut`,
// order the cuts by priority, truncate to `options.maxCutSizePerRoot`, record
// the best matched cut in `bestCut`, and freeze the set.
// NOTE(review): the first line of the signature and the statement belonging to
// "Step 1" below are absent from this listing.
510 const CutRewriterOptions &options,
511 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
512
513 // Step 1: Remove duplicate and non-minimal cuts to reduce the search space
514 // This eliminates cuts that are strictly dominated by others
516
517 // Step 2: Match each remaining cut against available patterns
518 // This computes timing and area information needed for prioritization
519 for (auto &cut : cuts) {
520 // Verify cut doesn't exceed input size limits
521 assert(cut.getInputSize() <= options.maxCutInputSize &&
522 "Cut input size exceeds maximum allowed size");
523
524 // Attempt to match the cut against available patterns
525 auto matched = matchCut(cut);
526 if (!matched)
527 continue; // No matching pattern found for this cut
528
529 // Store the matched pattern with the cut for later evaluation
530 cut.setMatchedPattern(std::move(*matched));
531 }
532
533 // Step 3: Sort cuts by priority to select the best ones
534 // Priority is determined by the optimization strategy:
535 // - Trivial cuts (direct connections) have highest priority
536 // - Among matched cuts, compare by area/delay based on the strategy
537 // - Matched cuts are preferred over unmatched cuts
538 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
539 // et al., ICCAD 2007 for more details.
540 // TODO: Use a priority queue instead of sorting for better performance.
541
542 // Partition the cuts into trivial and non-trivial cuts.
// Trivial cuts are moved to the front and are not part of the sort below.
543 auto *trivialCutsEnd =
544 std::stable_partition(cuts.begin(), cuts.end(),
545 [](const Cut &cut) { return cut.isTrivialCut(); });
546
547 std::stable_sort(trivialCutsEnd, cuts.end(),
548 [&options](const Cut &a, const Cut &b) -> bool {
549 assert(!a.isTrivialCut() && !b.isTrivialCut() &&
550 "Trivial cuts should have been excluded");
551 const auto &aMatched = a.getMatchedPattern();
552 const auto &bMatched = b.getMatchedPattern();
553
554 // Both cuts have matched patterns.
555 if (aMatched && bMatched)
556 return compareDelayAndArea(
557 options.strategy, aMatched->getArea(),
558 aMatched->getArrivalTimes(), bMatched->getArea(),
559 bMatched->getArrivalTimes());
560
561 // Prefer cuts with matched patterns over those without
562 if (aMatched && !bMatched)
563 return true;
564 if (!aMatched && bMatched)
565 return false;
566
567 // Both cuts are unmatched - prefer smaller input size
568 return a.getInputSize() < b.getInputSize();
569 });
570
571 // Step 4: Limit the number of cuts to prevent exponential growth
572 // After sorting, keep only the best cuts up to the specified limit
573 if (cuts.size() > options.maxCutSizePerRoot)
574 cuts.resize(options.maxCutSizePerRoot);
575
576 // Select the best cut from the remaining candidates
577 for (auto &cut : cuts) {
578 const auto &currentMatch = cut.getMatchedPattern();
579 if (!currentMatch)
580 continue; // Skip cuts without matched patterns
581
582 // This is already sorted, so the first matched cut is the best.
583 bestCut = &cut;
584 break;
585 }
586
587 LLVM_DEBUG({
588 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
589 << (bestCut
590 ? "matched pattern to " + bestCut->getMatchedPattern()
591 ->getPattern()
592 ->getPatternName()
593 : "no matched pattern")
594 << "\n";
595 });
596
597 isFrozen = true; // Mark the cut set as frozen
598}
599
600//===----------------------------------------------------------------------===//
601// CutRewritePattern
602//===----------------------------------------------------------------------===//
603
// This implementation leaves `matchingNPNClasses` untouched and returns false.
// NOTE(review): the first line of the signature is absent from this listing;
// the pattern-set constructor in this file calls
// `pattern->useTruthTableMatcher(npnClasses)` and treats a false result as
// "no NPN classes provided".
605 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
606 return false;
607}
608
609//===----------------------------------------------------------------------===//
610// CutRewritePatternSet
611//===----------------------------------------------------------------------===//
612
// Take ownership of `patterns` and index them: patterns that report NPN
// classes go into `npnToPatternMap`, all others into `nonTruthTablePatterns`.
// NOTE(review): the first line of the signature is absent from this listing.
614 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
615 : patterns(std::move(patterns)) {
616 // Initialize the NPN to pattern map
// `this->patterns` is the member; the parameter of the same name was moved
// from in the initializer list.
617 for (auto &pattern : this->patterns) {
618 SmallVector<NPNClass, 2> npnClasses;
619 auto result = pattern->useTruthTableMatcher(npnClasses);
620 if (result) {
621 for (auto npnClass : npnClasses) {
622 // Key the map by the class's truth table bits and input count
623 npnToPatternMap[{npnClass.truthTable.table,
624 npnClass.truthTable.numInputs}]
625 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
626 }
627 } else {
628 // If the pattern does not provide NPN classes, keep it in a separate
629 // list of patterns that are not matched by truth table.
630 nonTruthTablePatterns.push_back(pattern.get());
631 }
632 }
633}
634
635//===----------------------------------------------------------------------===//
636// CutEnumerator
637//===----------------------------------------------------------------------===//
638
// Constructor tail: stores the given `options` in the `options` member.
// NOTE(review): the signature line is absent from this listing.
640 : options(options) {}
641
// Create an empty CutSet for `value` in `cutSets` and return a pointer to it.
// A cut set for `value` must not exist yet (asserted).
// NOTE(review): the signature line is absent from this listing.
643 auto [cutSetPtr, inserted] =
644 cutSets.try_emplace(value, std::make_unique<CutSet>());
645 assert(inserted && "Cut set already exists for this value");
646 return cutSetPtr->second.get();
647}
648
649llvm::MapVector<Value, std::unique_ptr<CutSet>> CutEnumerator::takeVector() {
650 return std::move(cutSets);
651}
652
653void CutEnumerator::clear() { cutSets.clear(); }
654
655LogicalResult CutEnumerator::visit(Operation *op) {
656 if (isSupportedLogicOp(op))
657 return visitLogicOp(op);
658
659 // Skip other operations. If the operation is not a supported logic
660 // operation, we create a trivial cut lazily.
661 return success();
662}
663
664LogicalResult CutEnumerator::visitLogicOp(Operation *logicOp) {
665 assert(logicOp->getNumResults() == 1 &&
666 "Logic operation must have a single result");
667
668 Value result = logicOp->getResult(0);
669 unsigned numOperands = logicOp->getNumOperands();
670
671 // Validate operation constraints
672 // TODO: Variadic operations and non-single-bit results can be supported
673 if (numOperands > 2)
674 return logicOp->emitError("Cut enumeration supports at most 2 operands, "
675 "found: ")
676 << numOperands;
677 if (!logicOp->getOpResult(0).getType().isInteger(1))
678 return logicOp->emitError()
679 << "Supported logic operations must have a single bit "
680 "result type but found: "
681 << logicOp->getResult(0).getType();
682
683 SmallVector<const CutSet *, 2> operandCutSets;
684 operandCutSets.reserve(numOperands);
685 // Collect cut sets for each operand
686 for (unsigned i = 0; i < numOperands; ++i) {
687 auto *operandCutSet = getCutSet(logicOp->getOperand(i));
688 if (!operandCutSet)
689 return logicOp->emitError("Failed to get cut set for operand ")
690 << i << ": " << logicOp->getOperand(i);
691 operandCutSets.push_back(operandCutSet);
692 }
693
694 // Create the singleton cut (just this operation)
695 Cut primaryInputCut = getAsTrivialCut(result);
696
697 auto *resultCutSet = createNewCutSet(result);
698
699 // Add the singleton cut first
700 resultCutSet->addCut(primaryInputCut);
701
702 // Schedule cut set finalization when exiting this scope
703 auto prune = llvm::make_scope_exit([&]() {
704 // Finalize cut set: remove duplicates, limit size, and match patterns
705 resultCutSet->finalize(options, matchCut);
706 });
707
708 // Handle unary operations
709 if (numOperands == 1) {
710 const auto &inputCutSet = operandCutSets[0];
711
712 // Try to extend each input cut by including this operation
713 for (const Cut &inputCut : inputCutSet->getCuts()) {
714 Cut extendedCut = inputCut.reRoot(logicOp);
715 // Skip cuts that exceed input size limit
716 if (extendedCut.getInputSize() > options.maxCutInputSize)
717 continue;
718
719 resultCutSet->addCut(std::move(extendedCut));
720 }
721 return success();
722 }
723
724 // Handle binary operations (like AND, OR, XOR gates)
725 assert(numOperands == 2 && "Expected binary operation");
726
727 const auto *lhsCutSet = operandCutSets[0];
728 const auto *rhsCutSet = operandCutSets[1];
729
730 // Combine cuts from both inputs to create larger cuts
731 for (const Cut &lhsCut : lhsCutSet->getCuts()) {
732 for (const Cut &rhsCut : rhsCutSet->getCuts()) {
733 Cut mergedCut = lhsCut.mergeWith(rhsCut, logicOp);
734 // Skip cuts that exceed input size limit
735 if (mergedCut.getInputSize() > options.maxCutInputSize)
736 continue;
737
738 resultCutSet->addCut(std::move(mergedCut));
739 }
740 }
741
742 return success();
743}
744
// Enumerate cuts for every operation under `topOp`: sort the logic network
// topologically, remember `matchCut` for use during finalization, then walk
// all operations and visit each one. Returns failure if sorting fails or any
// visit fails.
// NOTE(review): the first line of the signature is absent from this listing.
746 Operation *topOp,
747 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
748 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
749 << "\n");
750 // Topologically sort the logic network
751 if (failed(topologicallySortLogicNetwork(topOp)))
752 return failure();
753
754 // Store the pattern matching function for use during cut finalization
// NOTE(review): `matchCut` is a non-owning function_ref; the stored copy is
// presumably only valid while the caller's callable is alive — confirm.
755 this->matchCut = matchCut;
756
757 // Walk through all operations in the module in a topological manner
758 auto result = topOp->walk([&](Operation *op) {
759 if (failed(visit(op)))
760 return mlir::WalkResult::interrupt();
761 return mlir::WalkResult::advance();
762 });
763
764 if (result.wasInterrupted())
765 return failure();
766
767 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
768 return success();
769}
770
771const CutSet *CutEnumerator::getCutSet(Value value) {
772 // Check if cut set already exists
773 auto *it = cutSets.find(value);
774 if (it == cutSets.end()) {
775 // Create new cut set for an unprocessed value
776 auto cutSet = std::make_unique<CutSet>();
777 cutSet->addCut(getAsTrivialCut(value));
778 auto [newIt, inserted] = cutSets.insert({value, std::move(cutSet)});
779 assert(inserted && "Cut set already exists for this value");
780 (void)newIt;
781 it = newIt;
782 }
783
784 return it->second.get();
785}
786
787/// Generate a human-readable name for a value used in test output.
788/// This function creates meaningful names for values to make debug output
789/// and test results more readable and understandable.
790static StringRef
791getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
792 if (auto *op = value.getDefiningOp()) {
793 // Handle values defined by operations
794 // First, check if the operation already has a name hint attribute
795 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
796 return name.getValue();
797
798 // For single-result operations, generate a unique name based on operation
799 // type
800 if (op->getNumResults() == 1) {
801 auto opName = op->getName();
802 auto count = opCounter[opName]++;
803
804 // Create a unique name by appending a counter to the operation name
805 SmallString<16> nameStr;
806 nameStr += opName.getStringRef();
807 nameStr += "_";
808 nameStr += std::to_string(count);
809
810 // Store the generated name as a hint attribute for future reference
811 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
812 op->setAttr("sv.namehint", nameAttr);
813 return nameAttr;
814 }
815
816 // Multi-result operations or other cases get a generic name
817 return "<unknown>";
818 }
819
820 // Handle block arguments
821 auto blockArg = cast<BlockArgument>(value);
822 auto hwOp =
823 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
824 if (!hwOp)
825 return "<unknown>";
826
827 // Return the formal input name from the hardware module
828 return hwOp.getInputName(blockArg.getArgNumber());
829}
830
// Print every cut set to llvm::outs(), one line per value: the value's test
// name, the number of cuts, and for each cut its inputs, truth table value
// and delay, followed by a final completion message.
// NOTE(review): the signature line is absent from this listing.
832 DenseMap<OperationName, unsigned> opCounter;
833 for (auto &[value, cutSetPtr] : cutSets) {
834 auto &cutSet = *cutSetPtr;
835 llvm::outs() << getTestVariableName(value, opCounter) << " "
836 << cutSet.getCuts().size() << " cuts:";
837 for (const Cut &cut : cutSet.getCuts()) {
838 llvm::outs() << " {";
839 llvm::interleaveComma(cut.inputs, llvm::outs(), [&](Value input) {
840 llvm::outs() << getTestVariableName(input, opCounter);
841 });
842 auto &pattern = cut.getMatchedPattern();
// "@t" is followed by the truth table bits as an integer, "d" by the delay.
843 llvm::outs() << "}"
844 << "@t" << cut.getTruthTable().table.getZExtValue() << "d";
845 if (pattern) {
// The printed delay is the largest arrival time of the matched pattern.
846 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
847 pattern->getArrivalTimes().end());
848 } else {
// Cuts without a matched pattern print a delay of 0.
849 llvm::outs() << "0";
850 }
851 }
852 llvm::outs() << "\n";
853 }
854 llvm::outs() << "Cut enumeration completed successfully\n";
855}
856
857//===----------------------------------------------------------------------===//
858// CutRewriter
859//===----------------------------------------------------------------------===//
860
// Run the cut rewriter on `topOp`: reject patterns with multiple outputs,
// topologically sort the logic network, enumerate cuts, and then perform the
// bottom-up rewrite. Returns failure as soon as any step fails.
// NOTE(review): several lines are absent from this listing: part of the
// "Mode:" debug expression and the condition guarding the early
// `return success()` after cut enumeration.
// NOTE(review): the "Max cut size" and "Max cuts per node" debug lines both
// print `options.maxCutSizePerRoot` — confirm whether one is redundant.
861LogicalResult CutRewriter::run(Operation *topOp) {
862 LLVM_DEBUG({
863 llvm::dbgs() << "Starting Cut Rewriter\n";
864 llvm::dbgs() << "Mode: "
866 : "timing")
867 << "\n";
868 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
869 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
870 llvm::dbgs() << "Max cuts per node: " << options.maxCutSizePerRoot << "\n";
871 });
872
873 // Currently we don't support patterns with multiple outputs.
874 // So check that.
875 // TODO: This must be removed when we support multiple outputs.
876 for (auto &pattern : patterns.patterns) {
877 if (pattern->getNumOutputs() > 1) {
878 return mlir::emitError(pattern->getLoc(),
879 "Cut rewriter does not support patterns with "
880 "multiple outputs yet");
881 }
882 }
883
884 // First sort the operations topologically to ensure we can process them
885 // in a valid order.
886 if (failed(topologicallySortLogicNetwork(topOp)))
887 return failure();
888
889 // Enumerate cuts for all nodes
890 if (failed(enumerateCuts(topOp)))
891 return failure();
892
893 // Dump cuts if testing priority cuts.
896 return success();
897 }
898
899 // Select best cuts and perform mapping
900 if (failed(runBottomUpRewrite(topOp)))
901 return failure();
902
903 return success();
904}
905
// Enumerate cuts under `topOp`, passing a callback that matches each cut
// against the patterns through `patternMatchCut`.
// NOTE(review): the line that starts the call receiving `topOp` and the
// lambda is absent from this listing.
906LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
907 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
908
910 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
911 // Match the cut against the patterns
912 return patternMatchCut(cut);
913 });
914}
915
// Look up the (NPN class, pattern) pairs registered for the NPN class of
// `cut`. The lookup key is the class's truth table bits and input count.
// Returns an empty range when no patterns are registered at all or none match.
// NOTE(review): the line with the function name and parameter list is absent
// from this listing; the body uses a variable named `cut`.
916ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
// Skip computing the NPN class when the map is empty.
918 if (patterns.npnToPatternMap.empty())
919 return {};
920
921 auto &npnClass = cut.getNPNClass();
922 auto it = patterns.npnToPatternMap.find(
923 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
924 if (it == patterns.npnToPatternMap.end())
925 return {};
926 return it->getSecond();
927}
928
/// Match `cut` against the available patterns and select the best candidate.
///
/// Steps, in order:
///  1. Trivial cuts are rejected immediately.
///  2. An arrival time is collected for every cut input: 0 for inputs that
///     isAlwaysCutInput() accepts, otherwise the arrival time recorded by the
///     best matched cut of the value producing that input.
///  3. Every truth-table pattern registered for the cut's NPN class and every
///     non-truth-table pattern whose match() accepts the cut is evaluated; the
///     per-output arrival times are computed and the preferred pattern kept.
///
/// \returns An empty optional when the cut is trivial, when one of its inputs
///          has no best matched cut, or when no pattern matches. Otherwise a
///          MatchedPattern holding the chosen pattern and its output arrival
///          times.
929std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
930 if (cut.isTrivialCut())
931 return {};
932
// State of the best candidate seen so far; bestPattern stays null until a
// pattern has been accepted.
933 const CutRewritePattern *bestPattern = nullptr;
934 SmallVector<DelayType, 4> inputArrivalTimes;
935 SmallVector<DelayType, 1> bestArrivalTimes;
936 inputArrivalTimes.reserve(cut.getInputSize());
937 bestArrivalTimes.reserve(cut.getOutputSize());
938
939 // Compute arrival times for each input.
940 for (auto input : cut.inputs) {
941 assert(input.getType().isInteger(1));
942 if (isAlwaysCutInput(input)) {
943 // If the input is a primary input, it has no delay.
944 // TODO: This doesn't consider a global delay. Need to capture
945 // `arrivalTime` on the IR to make the primary input delays visible.
946 inputArrivalTimes.push_back(0);
947 continue;
948 }
949 auto *cutSet = cutEnumerator.getCutSet(input);
950 assert(cutSet && "Input must have a valid cut set");
951
952 // If there is no matching pattern, it means it's not possible to use the
953 // input in the cut rewriting. So abort early.
954 auto *bestCut = cutSet->getBestMatchedCut();
955 if (!bestCut)
956 return {};
957
958 const auto &matchedPattern = *bestCut->getMatchedPattern();
959
960 // Otherwise, the cut input is an op result. Get the arrival time
961 // from the matched pattern.
962 inputArrivalTimes.push_back(matchedPattern.getArrivalTime(
963 cast<mlir::OpResult>(input).getResultNumber()));
964 }
965
// Helper: compute the output arrival times `pattern` would produce on this
// cut and record it as the best candidate when it is the first one seen or
// when the comparison below prefers it. `mapIndex` translates a loop index
// into the index used for both the delay and the arrival-time lookups.
966 auto computeArrivalTimeAndPickBest =
967 [&](const CutRewritePattern *pattern,
968 llvm::function_ref<unsigned(unsigned)> mapIndex) {
969 SmallVector<DelayType, 1> outputArrivalTimes;
970 // Compute the maximum delay for each output from inputs.
971 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize();
972 outputIndex < outputSize; ++outputIndex) {
973 // Compute the arrival time for this output.
974 DelayType outputArrivalTime = 0;
975 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
976 inputIndex < inputSize; ++inputIndex) {
977 // Map pattern input i to cut input through NPN transformations
978 unsigned cutOriginalInput = mapIndex(inputIndex);
// NOTE(review): the mapped index is passed to pattern->getDelay() as well
// as used to index inputArrivalTimes; confirm getDelay() expects the
// mapped (cut-side) index rather than the unmapped loop index.
979 outputArrivalTime =
980 std::max(outputArrivalTime,
981 pattern->getDelay(cutOriginalInput, outputIndex) +
982 inputArrivalTimes[cutOriginalInput]);
983 }
984
985 outputArrivalTimes.push_back(outputArrivalTime);
986 }
987
988 // Update the arrival time
989 if (!bestPattern ||
// NOTE(review): listing line 990 is absent from this extract. The
// remaining arguments suggest it is the head of a compareDelayAndArea(...)
// call supplying the strategy and the candidate pattern's area -- confirm
// against the real file.
991 outputArrivalTimes, bestPattern->getArea(),
992 bestArrivalTimes)) {
993 LLVM_DEBUG({
994 llvm::dbgs() << "== Matched Pattern ==============\n";
995 llvm::dbgs() << "Matching cut: \n";
996 cut.dump(llvm::dbgs());
997 llvm::dbgs() << "Found better pattern: "
998 << pattern->getPatternName();
999 llvm::dbgs() << " with area: " << pattern->getArea();
1000 llvm::dbgs() << " and input arrival times: ";
1001 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1002 llvm::dbgs() << " " << inputArrivalTimes[i];
1003 }
1004 llvm::dbgs() << " and arrival times: ";
1005
1006 for (auto arrivalTime : outputArrivalTimes) {
1007 llvm::dbgs() << " " << arrivalTime;
1008 }
1009 llvm::dbgs() << "\n";
1010 llvm::dbgs() << "== Matched Pattern End ==============\n";
1011 });
1012
// Adopt this pattern as the current best.
1013 bestArrivalTimes = std::move(outputArrivalTimes);
1014 bestPattern = pattern;
1015 }
1016 };
1017
// First pass: patterns found through the cut's NPN class. Each one must
// still accept the cut through match() before it is evaluated.
1018 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1019 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1020 "Pattern input size must match cut input size");
1021 if (!pattern->match(cut))
1022 continue;
1023 auto &cutNPN = cut.getNPNClass();
1024
1025 // Get the input mapping from pattern's NPN class to cut's NPN class
1026 SmallVector<unsigned> inputMapping;
1027 cutNPN.getInputPermutation(patternNPN, inputMapping);
1028 computeArrivalTimeAndPickBest(pattern,
1029 [&](unsigned i) { return inputMapping[i]; });
1030 }
1031
// Second pass: patterns with custom matching; inputs are used in their
// original order (identity mapping).
1032 for (const CutRewritePattern *pattern : patterns.nonTruthTablePatterns)
1033 if (pattern->match(cut))
1034 computeArrivalTimeAndPickBest(pattern, [&](unsigned i) { return i; });
1035
1036 if (!bestPattern)
1037 return {}; // No matching pattern found
1038
1039 return MatchedPattern(bestPattern, std::move(bestArrivalTimes));
1040}
1041
/// Rewrite the circuit using the best matched cut of each enumerated value.
///
/// The enumerated (value, cut set) pairs are taken from the cut enumerator and
/// visited in reverse order. For each value that still has uses and is not an
/// always-cut input, the pattern of its best matched cut is asked to rewrite
/// the cut, and the cut's root operation is replaced by the result.
///
/// \param top Operation providing the MLIRContext for the rewriter.
/// \returns failure() when a pattern's rewrite fails (and, per the emitError
///          path below, when a value has no matching cut); success() otherwise.
1042LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
1043 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
// Take ownership of the enumerated cut sets before rewriting.
1044 auto cutVector = cutEnumerator.takeVector();
// NOTE(review): listing line 1045 is absent from this extract; its content
// cannot be determined from here -- check the real file.
1046 UnusedOpPruner pruner;
1047 PatternRewriter rewriter(top->getContext());
1048 for (auto &[value, cutSet] : llvm::reverse(cutVector)) {
// A value without uses needs no rewrite; erase its defining op (if any).
1049 if (value.use_empty()) {
1050 if (auto *op = value.getDefiningOp())
1051 pruner.eraseNow(op);
1052 continue;
1053 }
1054
1055 if (isAlwaysCutInput(value)) {
1056 // If the value is a primary input, skip it
1057 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
1058 continue;
1059 }
1060
1061 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
1062 auto *bestCut = cutSet->getBestMatchedCut();
1063 if (!bestCut) {
// NOTE(review): listing line 1064 is absent from this extract. As shown,
// the `return emitError` below would be unreachable, so the `continue` is
// presumably guarded by a condition such as `if (options.allowNoMatch)` --
// confirm against the real file.
1065 continue; // No matching pattern found, skip this value
1066 return emitError(value.getLoc(), "No matching cut found for value: ")
1067 << value;
1068 }
1069
// Emit the replacement right before the root of the chosen cut.
1070 rewriter.setInsertionPoint(bestCut->getRoot());
1071 const auto &matchedPattern = bestCut->getMatchedPattern();
1072 auto result = matchedPattern->getPattern()->rewrite(rewriter, *bestCut);
1073 if (failed(result))
1074 return failure();
1075
1076 rewriter.replaceOp(bestCut->getRoot(), *result);
1077
// NOTE(review): listing line 1078 is absent from this extract. The extra
// closing brace after the setAttr call suggests a guarding block,
// presumably `if (options.attachDebugTiming) {` -- confirm against the
// real file.
// Record the matched pattern's arrival times on the replacement op.
1079 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1080 (*result)->setAttr("test.arrival_times", array);
1081 }
1082 }
1083
1084 return success();
1085}
assert(baseType &&"element must be base type")
static bool isAlwaysCutInput(Value value)
static bool isSupportedLogicOp(mlir::Operation *op)
static Cut getAsTrivialCut(mlir::Value value)
static FailureOr< BinaryTruthTable > computeTruthTable(mlir::ValueRange values, const OpRange &ops, const llvm::SmallSetVector< mlir::Value, 4 > &inputArgs)
Get the truth table for an op.
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut > &cuts)
static bool isCutDerivedFromOperand(const Cut &cut, Operation *op)
static void simulateLogicOp(Operation *op, DenseMap< Value, llvm::APInt > &eval)
RewritePatternSet pattern
Strategy strategy
LogicalResult visit(Operation *op)
Visit a single operation and generate cuts for it.
const CutRewriterOptions & options
Configuration options for cut enumeration.
llvm::MapVector< Value, std::unique_ptr< CutSet > > cutSets
Maps values to their associated cut sets.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
LogicalResult visitLogicOp(Operation *logicOp)
Visit a combinational logic operation and generate cuts.
llvm::MapVector< Value, std::unique_ptr< CutSet > > takeVector()
Move ownership of all cut sets to caller.
CutSet * createNewCutSet(Value value)
Create a new cut set for a value.
const CutSet * getCutSet(Value value)
Get the cut set for a specific value.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
SmallVector< const CutRewritePattern *, 4 > nonTruthTablePatterns
Patterns that use custom matching logic instead of truth tables.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
llvm::SmallVector< Cut, 4 > cuts
Collection of cuts for this node.
void addCut(Cut cut)
Add a new cut to this set.
unsigned size() const
Get the number of cuts in this set.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut)
Finalize the cut set by removing duplicates and selecting the best pattern.
bool isFrozen
Whether cut set is finalized.
ArrayRef< Cut > getCuts() const
Get read-only access to all cuts in this set.
Represents a cut in the combinational logic network.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
llvm::SmallSetVector< mlir::Operation *, 4 > operations
Operations contained within this cut.
unsigned getOutputSize() const
Get the number of outputs from root operation.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permuted inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
llvm::SmallSetVector< mlir::Value, 4 > inputs
External inputs to this cut (cut boundary).
void dump(llvm::raw_ostream &os) const
const BinaryTruthTable & getTruthTable() const
Get the truth table for this cut.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
Cut mergeWith(const Cut &other, Operation *root) const
Merge this cut with another cut to form a new cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
Cut reRoot(Operation *root) const
Represents a cut that has been successfully matched to a rewriting pattern.
Definition CutRewriter.h:67
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const
Get the delay between specific input and output pins.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
Definition CutRewriter.h:69
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
Definition CutRewriter.h:71
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
int64_t DelayType
Definition CutRewriter.h:36
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:307
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:41
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Represents a boolean function as a truth table.
Definition NPNClass.h:38
Represents the canonical form of a boolean function under NPN equivalence.
Definition NPNClass.h:103
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
Definition NPNClass.cpp:210
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Definition NPNClass.cpp:187
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const =0
Get the delay between specific input and output.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual double getArea() const =0
Get the area cost of this pattern.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
If false, fail when there is a root operation that has no matching pattern.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.