CIRCT 22.0.0git
Loading...
Searching...
No Matches
CutRewriter.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18//
19//===----------------------------------------------------------------------===//
20
22
26#include "circt/Support/LLVM.h"
29#include "mlir/Analysis/TopologicalSortUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/Operation.h"
32#include "mlir/IR/RegionKindInterface.h"
33#include "mlir/IR/Value.h"
34#include "mlir/IR/ValueRange.h"
35#include "mlir/IR/Visitors.h"
36#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/APInt.h"
38#include "llvm/ADT/Bitset.h"
39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/MapVector.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/ScopeExit.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/iterator.h"
46#include "llvm/Support/Debug.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/LogicalResult.h"
49#include <algorithm>
50#include <functional>
51#include <memory>
52#include <optional>
53#include <string>
54
55#define DEBUG_TYPE "synth-cut-rewriter"
56
57using namespace circt;
58using namespace circt::synth;
59
60static bool isSupportedLogicOp(mlir::Operation *op) {
61 // Check if the operation is a combinational operation that can be simulated
62 // TODO: Extend this to allow comb.and/xor/or as well.
63 return isa<aig::AndInverterOp>(op);
64}
65
66static void simulateLogicOp(Operation *op, DenseMap<Value, llvm::APInt> &eval) {
68 "Operation must be a supported logic operation for simulation");
69
70 // Simulate the operation by evaluating its inputs and computing the output
71 // This is a simplified simulation for demonstration purposes
72 if (auto andOp = dyn_cast<aig::AndInverterOp>(op)) {
73 SmallVector<llvm::APInt, 2> inputs;
74 inputs.reserve(andOp.getInputs().size());
75 for (auto input : andOp.getInputs()) {
76 auto it = eval.find(input);
77 if (it == eval.end())
78 llvm::report_fatal_error("Input value not found in evaluation map");
79 inputs.push_back(it->second);
80 }
81 // Evaluate the and inverter
82 eval[andOp.getResult()] = andOp.evaluate(inputs);
83 return;
84 }
85
86 llvm::report_fatal_error(
87 "Unsupported operation for simulation. isSupportedLogicOp should "
88 "be used to check if the operation can be simulated.");
89}
90
91// Return true if the value is always a cut input.
92static bool isAlwaysCutInput(Value value) {
93 auto *op = value.getDefiningOp();
94 // If the value has no defining operation, it is an input
95 if (!op)
96 return true;
97
98 if (op->hasTrait<OpTrait::ConstantLike>()) {
99 // Constant values are never cut inputs.
100 return false;
101 }
102
103 return !isSupportedLogicOp(op);
104}
105
106// Return true if the new area/delay is better than the old area/delay in the
107// context of the given strategy.
109 ArrayRef<DelayType> newDelay, double oldArea,
110 ArrayRef<DelayType> oldDelay) {
112 // Compare by area first.
113 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
114 }
116 // Compare by delay first.
117 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
118 }
119 llvm_unreachable("Unknown mapping strategy");
120}
121
122LogicalResult
124
125 // Sort the operations topologically
126 auto walkResult = topOp->walk([&](Region *region) {
127 auto regionKindOp =
128 dyn_cast<mlir::RegionKindInterface>(region->getParentOp());
129 if (!regionKindOp ||
130 regionKindOp.hasSSADominance(region->getRegionNumber()))
131 return WalkResult::advance();
132
133 auto isOperationReady = [&](Value value, Operation *op) -> bool {
134 // Topologically sort simulatable ops and purely
135 // dataflow ops. Other operations can be scheduled.
136 return !(isSupportedLogicOp(op) ||
137 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp>(op));
138 };
139
140 // Graph region.
141 for (auto &block : *region) {
142 if (!mlir::sortTopologically(&block, isOperationReady))
143 return WalkResult::interrupt();
144 }
145 return WalkResult::advance();
146 });
147
148 if (walkResult.wasInterrupted())
149 return mlir::emitError(topOp->getLoc(),
150 "failed to sort operations topologically");
151 return success();
152}
153
154/// Get the truth table for an op.
155template <typename OpRange>
156FailureOr<BinaryTruthTable> static computeTruthTable(
157 mlir::ValueRange values, const OpRange &ops,
158 const llvm::SmallSetVector<mlir::Value, 4> &inputArgs) {
159 // Create a truth table for the operation
160 int64_t numInputs = inputArgs.size();
161 int64_t numOutputs = values.size();
162 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
163 if (numOutputs == 0)
164 return BinaryTruthTable(numInputs, 0);
165 if (numInputs >= maxTruthTableInputs)
166 return mlir::emitError(values.front().getLoc(),
167 "Truth table is too large");
168 return mlir::emitError(values.front().getLoc(),
169 "Multiple outputs are not supported yet");
170 }
171
172 // Create a truth table with the given number of inputs and outputs
173 BinaryTruthTable truthTable(numInputs, numOutputs);
174 // The truth table size is 2^numInputs
175 uint32_t tableSize = 1 << numInputs;
176 // Create a map to evaluate the operation
177 DenseMap<Value, APInt> eval;
178 for (uint32_t i = 0; i < numInputs; ++i) {
179 // Create alternating bit pattern for input i
180 // For input i, bits alternate every 2^i positions
181 uint32_t blockSize = 1 << i;
182 // Create the repeating pattern: blockSize zeros followed by blockSize ones
183 uint32_t patternWidth = 2 * blockSize;
184 APInt pattern(patternWidth, 0);
185 assert(patternWidth <= tableSize && "Pattern width exceeds table size");
186 // Set the upper half of the pattern to 1s
187 pattern = APInt::getHighBitsSet(patternWidth, blockSize);
188 // Use getSplat to repeat this pattern across the full
189 // table size
190
191 eval[inputArgs[i]] = APInt::getSplat(tableSize, pattern);
192 }
193 // Simulate the operation
194 for (auto *op : ops) {
195 if (op->getNumResults() == 0)
196 continue; // Skip operations with no results
197 if (!isSupportedLogicOp(op))
198 return op->emitError("Unsupported operation for truth table simulation");
199
200 // Simulate the operation
201 simulateLogicOp(op, eval);
202 }
203 // TODO: Currently numOutputs is always 1, so we can just return the first
204 // one.
205 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
206}
207
208FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
209 Block *block) {
210 // Get the input arguments from the block
211 llvm::SmallSetVector<Value, 4> inputs;
212 for (auto arg : block->getArguments())
213 inputs.insert(arg);
214
215 // If there are no inputs, return an empty truth table
216 if (inputs.empty())
217 return BinaryTruthTable();
218
219 return computeTruthTable(values, llvm::make_pointer_range(*block), inputs);
220}
221
222//===----------------------------------------------------------------------===//
223// Cut
224//===----------------------------------------------------------------------===//
225
226bool Cut::isTrivialCut() const {
227 // A cut is a trival cut if it has no operations and only one input
228 return operations.empty() && inputs.size() == 1;
229}
230
231mlir::Operation *Cut::getRoot() const {
232 return operations.empty()
233 ? nullptr
234 : operations.back(); // The last operation is the root
235}
236
238 // If the NPN is already computed, return it
239 if (npnClass)
240 return *npnClass;
241
242 auto truthTable = getTruthTable();
243
244 // Compute the NPN canonical form
245 auto canonicalForm = NPNClass::computeNPNCanonicalForm(truthTable);
246
247 npnClass.emplace(std::move(canonicalForm));
248 return *npnClass;
249}
250
251void Cut::getPermutatedInputs(const NPNClass &patternNPN,
252 SmallVectorImpl<Value> &permutedInputs) const {
253 auto npnClass = getNPNClass();
254 SmallVector<unsigned> idx;
255 npnClass.getInputPermutation(patternNPN, idx);
256 permutedInputs.reserve(idx.size());
257 for (auto inputIndex : idx) {
258 assert(inputIndex < inputs.size() && "Input index out of bounds");
259 permutedInputs.push_back(inputs[inputIndex]);
260 }
261}
262
263void Cut::dump(llvm::raw_ostream &os) const {
264 os << "// === Cut Dump ===\n";
265 os << "Cut with " << getInputSize() << " inputs and " << operations.size()
266 << " operations:\n";
267 if (isTrivialCut()) {
268 os << "Primary input cut: " << *inputs.begin() << "\n";
269 return;
270 }
271
272 os << "Inputs: \n";
273 for (auto [idx, input] : llvm::enumerate(inputs)) {
274 os << " Input " << idx << ": " << input << "\n";
275 }
276 os << "\nOperations: \n";
277 for (auto *op : operations) {
278 op->print(os);
279 os << "\n";
280 }
281 auto &npnClass = getNPNClass();
282 npnClass.dump(os);
283
284 os << "// === Cut End ===\n";
285}
286
287unsigned Cut::getInputSize() const { return inputs.size(); }
288
289unsigned Cut::getOutputSize() const { return getRoot()->getNumResults(); }
290
292 if (truthTable)
293 return *truthTable;
294
295 if (isTrivialCut()) {
296 // For a trivial cut, a truth table is simply the identity function.
297 // 0 -> 0, 1 -> 1
298 truthTable = BinaryTruthTable(1, 1, {llvm::APInt(2, 2)});
299 return *truthTable;
300 }
301
302 // Create a truth table with the given number of inputs and outputs
304
305 return *truthTable;
306}
307
308static Cut getAsTrivialCut(mlir::Value value) {
309 // Create a cut with a single root operation
310 Cut cut;
311 // There is no input for the primary input cut.
312 cut.inputs.insert(value);
313
314 return cut;
315}
316
317[[maybe_unused]] static bool isCutDerivedFromOperand(const Cut &cut,
318 Operation *op) {
319 if (auto *root = cut.getRoot())
320 return llvm::any_of(op->getOperands(),
321 [&](Value v) { return v.getDefiningOp() == root; });
322
323 assert(cut.isTrivialCut());
324 // If the cut is trivial, it has no operations, so it must be a primary input.
325 // In this case, the only operation that can be derived from it is the
326 // primary input itself.
327 return cut.inputs.size() == 1 &&
328 llvm::any_of(op->getOperands(),
329 [&](Value v) { return v == *cut.inputs.begin(); });
330}
331
332Cut Cut::mergeWith(const Cut &other, Operation *root) const {
333 assert(isCutDerivedFromOperand(*this, root) &&
334 isCutDerivedFromOperand(other, root) &&
335 "The operation must be a child of the current root operation");
336
337 // Create a new cut that combines this cut and the other cut
338 Cut newCut;
339 // Topological sort the operations in the new cut.
340 // TODO: Merge-sort `operations` and `other.operations` by operation index
341 // (since it's already topo-sorted, we can use a simple merge).
342 std::function<void(Operation *)> populateOperations = [&](Operation *op) {
343 // If the operation is already in the cut, skip it
344 if (newCut.operations.contains(op))
345 return;
346
347 // Add its operands to the worklist
348 for (auto value : op->getOperands()) {
349 if (isAlwaysCutInput(value))
350 continue;
351
352 // If the value is in *both* cuts inputs, it is an input. So skip
353 // it.
354 bool isInput = inputs.contains(value);
355 bool isOtherInput = other.inputs.contains(value);
356 // If the value is in this cut inputs, it is an input. So skip it
357 if (isInput && isOtherInput)
358 continue;
359
360 auto *defOp = value.getDefiningOp();
361
362 assert(defOp && "Value must have a defining operation since block"
363 "arguments are treated as inputs");
364
365 // Otherwise, check if the operation is in the other cut.
366 if (isInput)
367 if (!other.operations.contains(defOp)) // op is in the other cut.
368 continue;
369 if (isOtherInput)
370 if (!operations.contains(defOp)) // op is in this cut.
371 continue;
372 populateOperations(defOp);
373 }
374
375 // Add the operation to the cut
376 newCut.operations.insert(op);
377 };
378
379 populateOperations(root);
380
381 // Construct inputs.
382 for (auto *operation : newCut.operations) {
383 for (auto value : operation->getOperands()) {
384 if (isAlwaysCutInput(value)) {
385 newCut.inputs.insert(value);
386 continue;
387 }
388
389 auto *defOp = value.getDefiningOp();
390 assert(defOp && "Value must have a defining operation");
391
392 // If the operation is not in the cut, it is an input
393 if (!newCut.operations.contains(defOp))
394 // Add the input to the cut
395 newCut.inputs.insert(value);
396 }
397 }
398
399 // TODO: Sort the inputs by their defining operation.
400 // TODO: Update area and delay based on the merged cuts.
401
402 return newCut;
403}
404
405// Reroot the cut with a new root operation.
406// This is used to create a new cut with the same inputs and operations, but a
407// different root operation.
408Cut Cut::reRoot(Operation *root) const {
409 assert(isCutDerivedFromOperand(*this, root) &&
410 "The operation must be a child of the current root operation");
411 Cut newCut;
412 newCut.inputs = inputs;
413 newCut.operations = operations;
414 // Add the new root operation to the cut
415 newCut.operations.insert(root);
416 return newCut;
417}
418
419//===----------------------------------------------------------------------===//
420// MatchedPattern
421//===----------------------------------------------------------------------===//
422
423ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
424 assert(pattern && "Pattern must be set to get arrival time");
425 return arrivalTimes;
426}
427
429 assert(pattern && "Pattern must be set to get arrival time");
430 return arrivalTimes[index];
431}
432
434 assert(pattern && "Pattern must be set to get the pattern");
435 return pattern;
436}
437
439 assert(pattern && "Pattern must be set to get area");
440 return pattern->getArea();
441}
442
444 unsigned outputIndex) const {
445 assert(pattern && "Pattern must be set to get delay");
446 return pattern->getDelay(inputIndex, outputIndex);
447}
448
449//===----------------------------------------------------------------------===//
450// CutSet
451//===----------------------------------------------------------------------===//
452
454
455unsigned CutSet::size() const { return cuts.size(); }
456
458 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
459 cuts.push_back(std::move(cut));
460}
461
462ArrayRef<Cut> CutSet::getCuts() const { return cuts; }
463
464// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
465// exists another cut that is a subset of it. We use a bitset to represent the
466// inputs of each cut for efficient subset checking.
467static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut> &cuts) {
468 // First sort the cuts by input size (ascending). This ensures that when we
469 // iterate through the cuts, we always encounter smaller cuts first, allowing
470 // us to efficiently check for non-minimality. Stable sort to maintain
471 // relative order of cuts with the same input size.
472 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut &a, const Cut &b) {
473 return a.getInputSize() < b.getInputSize();
474 });
475
476 llvm::SmallVector<llvm::Bitset<64>, 4> inputBitMasks;
477 DenseMap<Value, unsigned> inputIndices;
478 auto getIndex = [&](Value v) -> unsigned {
479 auto it = inputIndices.find(v);
480 if (it != inputIndices.end())
481 return it->second;
482 unsigned index = inputIndices.size();
483 if (LLVM_UNLIKELY(index >= 64))
484 llvm::report_fatal_error(
485 "Too many unique inputs across cuts. Max 64 supported. Consider "
486 "increasing the compile-time constant.");
487 inputIndices[v] = index;
488 return index;
489 };
490
491 for (unsigned i = 0; i < cuts.size(); ++i) {
492 auto &cut = cuts[i];
493 // Create a unique identifier for the cut based on its inputs.
494 llvm::Bitset<64> inputsMask;
495 for (auto input : cut.inputs.getArrayRef())
496 inputsMask.set(getIndex(input));
497
498 bool isUnique = llvm::all_of(
499 inputBitMasks, [&](const llvm::Bitset<64> &existingCutInputMask) {
500 // If the bitset is a subset of the current inputsMask, it is not
501 // unique
502 return (existingCutInputMask & inputsMask) != existingCutInputMask;
503 });
504
505 if (!isUnique)
506 continue;
507
508 // If the cut is unique, keep it
509 size_t uniqueCount = inputBitMasks.size();
510 if (i != uniqueCount)
511 cuts[uniqueCount] = std::move(cut);
512 inputBitMasks.push_back(inputsMask);
513 }
514
515 unsigned uniqueCount = inputBitMasks.size();
516
517 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
518 << " Unique cuts: " << uniqueCount << "\n");
519
520 // Resize the cuts vector to the number of unique cuts found
521 cuts.resize(uniqueCount);
522}
523
525 const CutRewriterOptions &options,
526 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
527
528 // Step 1: Remove duplicate and non-minimal cuts to reduce the search space
529 // This eliminates cuts that are strictly dominated by others
531
532 // Step 2: Match each remaining cut against available patterns
533 // This computes timing and area information needed for prioritization
534 for (auto &cut : cuts) {
535 // Verify cut doesn't exceed input size limits
536 assert(cut.getInputSize() <= options.maxCutInputSize &&
537 "Cut input size exceeds maximum allowed size");
538
539 // Attempt to match the cut against available patterns
540 auto matched = matchCut(cut);
541 if (!matched)
542 continue; // No matching pattern found for this cut
543
544 // Store the matched pattern with the cut for later evaluation
545 cut.setMatchedPattern(std::move(*matched));
546 }
547
548 // Step 3: Sort cuts by priority to select the best ones
549 // Priority is determined by the optimization strategy:
550 // - Trivial cuts (direct connections) have highest priority
551 // - Among matched cuts, compare by area/delay based on the strategy
552 // - Matched cuts are preferred over unmatched cuts
553 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
554 // et al., ICCAD 2007 for more details.
555 // TODO: Use a priority queue instead of sorting for better performance.
556
557 // Partition the cuts into trivial and non-trivial cuts.
558 auto *trivialCutsEnd =
559 std::stable_partition(cuts.begin(), cuts.end(),
560 [](const Cut &cut) { return cut.isTrivialCut(); });
561
562 std::stable_sort(trivialCutsEnd, cuts.end(),
563 [&options](const Cut &a, const Cut &b) -> bool {
564 assert(!a.isTrivialCut() && !b.isTrivialCut() &&
565 "Trivial cuts should have been excluded");
566 const auto &aMatched = a.getMatchedPattern();
567 const auto &bMatched = b.getMatchedPattern();
568
569 // Both cuts have matched patterns.
570 if (aMatched && bMatched)
571 return compareDelayAndArea(
572 options.strategy, aMatched->getArea(),
573 aMatched->getArrivalTimes(), bMatched->getArea(),
574 bMatched->getArrivalTimes());
575
576 // Prefer cuts with matched patterns over those without
577 if (aMatched && !bMatched)
578 return true;
579 if (!aMatched && bMatched)
580 return false;
581
582 // Both cuts are unmatched - prefer smaller input size
583 return a.getInputSize() < b.getInputSize();
584 });
585
586 // Step 4: Limit the number of cuts to prevent exponential growth
587 // After sorting, keep only the best cuts up to the specified limit
588 if (cuts.size() > options.maxCutSizePerRoot)
589 cuts.resize(options.maxCutSizePerRoot);
590
591 // Select the best cut from the remaining candidates
592 for (auto &cut : cuts) {
593 const auto &currentMatch = cut.getMatchedPattern();
594 if (!currentMatch)
595 continue; // Skip cuts without matched patterns
596
597 // This is already sorted, so the first matched cut is the best.
598 bestCut = &cut;
599 break;
600 }
601
602 LLVM_DEBUG({
603 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
604 << (bestCut
605 ? "matched pattern to " + bestCut->getMatchedPattern()
606 ->getPattern()
607 ->getPatternName()
608 : "no matched pattern")
609 << "\n";
610 });
611
612 isFrozen = true; // Mark the cut set as frozen
613}
614
615//===----------------------------------------------------------------------===//
616// CutRewritePattern
617//===----------------------------------------------------------------------===//
618
620 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
621 return false;
622}
623
624//===----------------------------------------------------------------------===//
625// CutRewritePatternSet
626//===----------------------------------------------------------------------===//
627
629 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
630 : patterns(std::move(patterns)) {
631 // Initialize the NPN to pattern map
632 for (auto &pattern : this->patterns) {
633 SmallVector<NPNClass, 2> npnClasses;
634 auto result = pattern->useTruthTableMatcher(npnClasses);
635 if (result) {
636 for (auto npnClass : npnClasses) {
637 // Create a NPN class from the truth table
638 npnToPatternMap[{npnClass.truthTable.table,
639 npnClass.truthTable.numInputs}]
640 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
641 }
642 } else {
643 // If the pattern does not provide NPN classes, we use a special key
644 // to indicate that it should be considered for all cuts.
645 nonTruthTablePatterns.push_back(pattern.get());
646 }
647 }
648}
649
650//===----------------------------------------------------------------------===//
651// CutEnumerator
652//===----------------------------------------------------------------------===//
653
655 : options(options) {}
656
658 auto [cutSetPtr, inserted] =
659 cutSets.try_emplace(value, std::make_unique<CutSet>());
660 assert(inserted && "Cut set already exists for this value");
661 return cutSetPtr->second.get();
662}
663
664llvm::MapVector<Value, std::unique_ptr<CutSet>> CutEnumerator::takeVector() {
665 return std::move(cutSets);
666}
667
668void CutEnumerator::clear() { cutSets.clear(); }
669
670LogicalResult CutEnumerator::visit(Operation *op) {
671 if (isSupportedLogicOp(op))
672 return visitLogicOp(op);
673
674 // Skip other operations. If the operation is not a supported logic
675 // operation, we create a trivial cut lazily.
676 return success();
677}
678
679LogicalResult CutEnumerator::visitLogicOp(Operation *logicOp) {
680 assert(logicOp->getNumResults() == 1 &&
681 "Logic operation must have a single result");
682
683 Value result = logicOp->getResult(0);
684 unsigned numOperands = logicOp->getNumOperands();
685
686 // Validate operation constraints
687 // TODO: Variadic operations and non-single-bit results can be supported
688 if (numOperands > 2)
689 return logicOp->emitError("Cut enumeration supports at most 2 operands, "
690 "found: ")
691 << numOperands;
692 if (!logicOp->getOpResult(0).getType().isInteger(1))
693 return logicOp->emitError()
694 << "Supported logic operations must have a single bit "
695 "result type but found: "
696 << logicOp->getResult(0).getType();
697
698 SmallVector<const CutSet *, 2> operandCutSets;
699 operandCutSets.reserve(numOperands);
700 // Collect cut sets for each operand
701 for (unsigned i = 0; i < numOperands; ++i) {
702 auto *operandCutSet = getCutSet(logicOp->getOperand(i));
703 if (!operandCutSet)
704 return logicOp->emitError("Failed to get cut set for operand ")
705 << i << ": " << logicOp->getOperand(i);
706 operandCutSets.push_back(operandCutSet);
707 }
708
709 // Create the singleton cut (just this operation)
710 Cut primaryInputCut = getAsTrivialCut(result);
711
712 auto *resultCutSet = createNewCutSet(result);
713
714 // Add the singleton cut first
715 resultCutSet->addCut(primaryInputCut);
716
717 // Schedule cut set finalization when exiting this scope
718 auto prune = llvm::make_scope_exit([&]() {
719 // Finalize cut set: remove duplicates, limit size, and match patterns
720 resultCutSet->finalize(options, matchCut);
721 });
722
723 // Handle unary operations
724 if (numOperands == 1) {
725 const auto &inputCutSet = operandCutSets[0];
726
727 // Try to extend each input cut by including this operation
728 for (const Cut &inputCut : inputCutSet->getCuts()) {
729 Cut extendedCut = inputCut.reRoot(logicOp);
730 // Skip cuts that exceed input size limit
731 if (extendedCut.getInputSize() > options.maxCutInputSize)
732 continue;
733
734 resultCutSet->addCut(std::move(extendedCut));
735 }
736 return success();
737 }
738
739 // Handle binary operations (like AND, OR, XOR gates)
740 assert(numOperands == 2 && "Expected binary operation");
741
742 const auto *lhsCutSet = operandCutSets[0];
743 const auto *rhsCutSet = operandCutSets[1];
744
745 // Combine cuts from both inputs to create larger cuts
746 for (const Cut &lhsCut : lhsCutSet->getCuts()) {
747 for (const Cut &rhsCut : rhsCutSet->getCuts()) {
748 Cut mergedCut = lhsCut.mergeWith(rhsCut, logicOp);
749 // Skip cuts that exceed input size limit
750 if (mergedCut.getInputSize() > options.maxCutInputSize)
751 continue;
752
753 resultCutSet->addCut(std::move(mergedCut));
754 }
755 }
756
757 return success();
758}
759
761 Operation *topOp,
762 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
763 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
764 << "\n");
765 // Topologically sort the logic network
766 if (failed(topologicallySortLogicNetwork(topOp)))
767 return failure();
768
769 // Store the pattern matching function for use during cut finalization
770 this->matchCut = matchCut;
771
772 // Walk through all operations in the module in a topological manner
773 auto result = topOp->walk([&](Operation *op) {
774 if (failed(visit(op)))
775 return mlir::WalkResult::interrupt();
776 return mlir::WalkResult::advance();
777 });
778
779 if (result.wasInterrupted())
780 return failure();
781
782 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
783 return success();
784}
785
786const CutSet *CutEnumerator::getCutSet(Value value) {
787 // Check if cut set already exists
788 auto *it = cutSets.find(value);
789 if (it == cutSets.end()) {
790 // Create new cut set for an unprocessed value
791 auto cutSet = std::make_unique<CutSet>();
792 cutSet->addCut(getAsTrivialCut(value));
793 auto [newIt, inserted] = cutSets.insert({value, std::move(cutSet)});
794 assert(inserted && "Cut set already exists for this value");
795 (void)newIt;
796 it = newIt;
797 }
798
799 return it->second.get();
800}
801
802/// Generate a human-readable name for a value used in test output.
803/// This function creates meaningful names for values to make debug output
804/// and test results more readable and understandable.
805static StringRef
806getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
807 if (auto *op = value.getDefiningOp()) {
808 // Handle values defined by operations
809 // First, check if the operation already has a name hint attribute
810 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
811 return name.getValue();
812
813 // For single-result operations, generate a unique name based on operation
814 // type
815 if (op->getNumResults() == 1) {
816 auto opName = op->getName();
817 auto count = opCounter[opName]++;
818
819 // Create a unique name by appending a counter to the operation name
820 SmallString<16> nameStr;
821 nameStr += opName.getStringRef();
822 nameStr += "_";
823 nameStr += std::to_string(count);
824
825 // Store the generated name as a hint attribute for future reference
826 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
827 op->setAttr("sv.namehint", nameAttr);
828 return nameAttr;
829 }
830
831 // Multi-result operations or other cases get a generic name
832 return "<unknown>";
833 }
834
835 // Handle block arguments
836 auto blockArg = cast<BlockArgument>(value);
837 auto hwOp =
838 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
839 if (!hwOp)
840 return "<unknown>";
841
842 // Return the formal input name from the hardware module
843 return hwOp.getInputName(blockArg.getArgNumber());
844}
845
847 DenseMap<OperationName, unsigned> opCounter;
848 for (auto &[value, cutSetPtr] : cutSets) {
849 auto &cutSet = *cutSetPtr;
850 llvm::outs() << getTestVariableName(value, opCounter) << " "
851 << cutSet.getCuts().size() << " cuts:";
852 for (const Cut &cut : cutSet.getCuts()) {
853 llvm::outs() << " {";
854 llvm::interleaveComma(cut.inputs, llvm::outs(), [&](Value input) {
855 llvm::outs() << getTestVariableName(input, opCounter);
856 });
857 auto &pattern = cut.getMatchedPattern();
858 llvm::outs() << "}"
859 << "@t" << cut.getTruthTable().table.getZExtValue() << "d";
860 if (pattern) {
861 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
862 pattern->getArrivalTimes().end());
863 } else {
864 llvm::outs() << "0";
865 }
866 }
867 llvm::outs() << "\n";
868 }
869 llvm::outs() << "Cut enumeration completed successfully\n";
870}
871
872//===----------------------------------------------------------------------===//
873// CutRewriter
874//===----------------------------------------------------------------------===//
875
876LogicalResult CutRewriter::run(Operation *topOp) {
877 LLVM_DEBUG({
878 llvm::dbgs() << "Starting Cut Rewriter\n";
879 llvm::dbgs() << "Mode: "
881 : "timing")
882 << "\n";
883 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
884 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
885 llvm::dbgs() << "Max cuts per node: " << options.maxCutSizePerRoot << "\n";
886 });
887
888 // Currently we don't support patterns with multiple outputs.
889 // So check that.
890 // TODO: This must be removed when we support multiple outputs.
891 for (auto &pattern : patterns.patterns) {
892 if (pattern->getNumOutputs() > 1) {
893 return mlir::emitError(pattern->getLoc(),
894 "Cut rewriter does not support patterns with "
895 "multiple outputs yet");
896 }
897 }
898
899 // First sort the operations topologically to ensure we can process them
900 // in a valid order.
901 if (failed(topologicallySortLogicNetwork(topOp)))
902 return failure();
903
904 // Enumerate cuts for all nodes
905 if (failed(enumerateCuts(topOp)))
906 return failure();
907
908 // Dump cuts if testing priority cuts.
911 return success();
912 }
913
914 // Select best cuts and perform mapping
915 if (failed(runBottomUpRewrite(topOp)))
916 return failure();
917
918 return success();
919}
920
921LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
922 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
923
925 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
926 // Match the cut against the patterns
927 return patternMatchCut(cut);
928 });
929}
930
931ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
933 if (patterns.npnToPatternMap.empty())
934 return {};
935
936 auto &npnClass = cut.getNPNClass();
937 auto it = patterns.npnToPatternMap.find(
938 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
939 if (it == patterns.npnToPatternMap.end())
940 return {};
941 return it->getSecond();
942}
943
944std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
945 if (cut.isTrivialCut())
946 return {};
947
948 const CutRewritePattern *bestPattern = nullptr;
949 SmallVector<DelayType, 4> inputArrivalTimes;
950 SmallVector<DelayType, 1> bestArrivalTimes;
951 inputArrivalTimes.reserve(cut.getInputSize());
952 bestArrivalTimes.reserve(cut.getOutputSize());
953
954 // Compute arrival times for each input.
955 for (auto input : cut.inputs) {
956 assert(input.getType().isInteger(1));
957 if (isAlwaysCutInput(input)) {
958 // If the input is a primary input, it has no delay.
959 // TODO: This doesn't consider a global delay. Need to capture
960 // `arrivalTime` on the IR to make the primary input delays visible.
961 inputArrivalTimes.push_back(0);
962 continue;
963 }
964 auto *cutSet = cutEnumerator.getCutSet(input);
965 assert(cutSet && "Input must have a valid cut set");
966
967 // If there is no matching pattern, it means it's not possible to use the
968 // input in the cut rewriting. So abort early.
969 auto *bestCut = cutSet->getBestMatchedCut();
970 if (!bestCut)
971 return {};
972
973 const auto &matchedPattern = *bestCut->getMatchedPattern();
974
975 // Otherwise, the cut input is an op result. Get the arrival time
976 // from the matched pattern.
977 inputArrivalTimes.push_back(matchedPattern.getArrivalTime(
978 cast<mlir::OpResult>(input).getResultNumber()));
979 }
980
981 auto computeArrivalTimeAndPickBest =
982 [&](const CutRewritePattern *pattern,
983 llvm::function_ref<unsigned(unsigned)> mapIndex) {
984 SmallVector<DelayType, 1> outputArrivalTimes;
985 // Compute the maximum delay for each output from inputs.
986 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize();
987 outputIndex < outputSize; ++outputIndex) {
988 // Compute the arrival time for this output.
989 DelayType outputArrivalTime = 0;
990 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
991 inputIndex < inputSize; ++inputIndex) {
992 // Map pattern input i to cut input through NPN transformations
993 unsigned cutOriginalInput = mapIndex(inputIndex);
994 outputArrivalTime =
995 std::max(outputArrivalTime,
996 pattern->getDelay(cutOriginalInput, outputIndex) +
997 inputArrivalTimes[cutOriginalInput]);
998 }
999
1000 outputArrivalTimes.push_back(outputArrivalTime);
1001 }
1002
1003 // Update the arrival time
1004 if (!bestPattern ||
1006 outputArrivalTimes, bestPattern->getArea(),
1007 bestArrivalTimes)) {
1008 LLVM_DEBUG({
1009 llvm::dbgs() << "== Matched Pattern ==============\n";
1010 llvm::dbgs() << "Matching cut: \n";
1011 cut.dump(llvm::dbgs());
1012 llvm::dbgs() << "Found better pattern: "
1013 << pattern->getPatternName();
1014 llvm::dbgs() << " with area: " << pattern->getArea();
1015 llvm::dbgs() << " and input arrival times: ";
1016 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1017 llvm::dbgs() << " " << inputArrivalTimes[i];
1018 }
1019 llvm::dbgs() << " and arrival times: ";
1020
1021 for (auto arrivalTime : outputArrivalTimes) {
1022 llvm::dbgs() << " " << arrivalTime;
1023 }
1024 llvm::dbgs() << "\n";
1025 llvm::dbgs() << "== Matched Pattern End ==============\n";
1026 });
1027
1028 bestArrivalTimes = std::move(outputArrivalTimes);
1029 bestPattern = pattern;
1030 }
1031 };
1032
1033 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1034 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1035 "Pattern input size must match cut input size");
1036 if (!pattern->match(cut))
1037 continue;
1038 auto &cutNPN = cut.getNPNClass();
1039
1040 // Get the input mapping from pattern's NPN class to cut's NPN class
1041 SmallVector<unsigned> inputMapping;
1042 cutNPN.getInputPermutation(patternNPN, inputMapping);
1043 computeArrivalTimeAndPickBest(pattern,
1044 [&](unsigned i) { return inputMapping[i]; });
1045 }
1046
1047 for (const CutRewritePattern *pattern : patterns.nonTruthTablePatterns)
1048 if (pattern->match(cut))
1049 computeArrivalTimeAndPickBest(pattern, [&](unsigned i) { return i; });
1050
1051 if (!bestPattern)
1052 return {}; // No matching pattern found
1053
1054 return MatchedPattern(bestPattern, std::move(bestArrivalTimes));
1055}
1056
1057LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
1058 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
1059 auto cutVector = cutEnumerator.takeVector();
1061 UnusedOpPruner pruner;
1062 PatternRewriter rewriter(top->getContext());
1063 for (auto &[value, cutSet] : llvm::reverse(cutVector)) {
1064 if (value.use_empty()) {
1065 if (auto *op = value.getDefiningOp())
1066 pruner.eraseNow(op);
1067 continue;
1068 }
1069
1070 if (isAlwaysCutInput(value)) {
1071 // If the value is a primary input, skip it
1072 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
1073 continue;
1074 }
1075
1076 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
1077 auto *bestCut = cutSet->getBestMatchedCut();
1078 if (!bestCut) {
1080 continue; // No matching pattern found, skip this value
1081 return emitError(value.getLoc(), "No matching cut found for value: ")
1082 << value;
1083 }
1084
1085 rewriter.setInsertionPoint(bestCut->getRoot());
1086 const auto &matchedPattern = bestCut->getMatchedPattern();
1087 auto result = matchedPattern->getPattern()->rewrite(rewriter, *bestCut);
1088 if (failed(result))
1089 return failure();
1090
1091 rewriter.replaceOp(bestCut->getRoot(), *result);
1092
1094 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1095 (*result)->setAttr("test.arrival_times", array);
1096 }
1097 }
1098
1099 return success();
1100}
assert(baseType &&"element must be base type")
static bool isAlwaysCutInput(Value value)
static bool isSupportedLogicOp(mlir::Operation *op)
static Cut getAsTrivialCut(mlir::Value value)
static FailureOr< BinaryTruthTable > computeTruthTable(mlir::ValueRange values, const OpRange &ops, const llvm::SmallSetVector< mlir::Value, 4 > &inputArgs)
Get the truth table for an op.
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut > &cuts)
static bool isCutDerivedFromOperand(const Cut &cut, Operation *op)
static void simulateLogicOp(Operation *op, DenseMap< Value, llvm::APInt > &eval)
RewritePatternSet pattern
Strategy strategy
LogicalResult visit(Operation *op)
Visit a single operation and generate cuts for it.
const CutRewriterOptions & options
Configuration options for cut enumeration.
llvm::MapVector< Value, std::unique_ptr< CutSet > > cutSets
Maps values to their associated cut sets.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
LogicalResult visitLogicOp(Operation *logicOp)
Visit a combinational logic operation and generate cuts.
llvm::MapVector< Value, std::unique_ptr< CutSet > > takeVector()
Move ownership of all cut sets to caller.
CutSet * createNewCutSet(Value value)
Create a new cut set for a value.
const CutSet * getCutSet(Value value)
Get the cut set for a specific value.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
SmallVector< const CutRewritePattern *, 4 > nonTruthTablePatterns
Patterns that use custom matching logic instead of truth tables.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
llvm::SmallVector< Cut, 4 > cuts
Collection of cuts for this node.
void addCut(Cut cut)
Add a new cut to this set.
unsigned size() const
Get the number of cuts in this set.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut)
Finalize the cut set by removing duplicates and selecting the best pattern.
bool isFrozen
Whether cut set is finalized.
ArrayRef< Cut > getCuts() const
Get read-only access to all cuts in this set.
Represents a cut in the combinational logic network.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
llvm::SmallSetVector< mlir::Operation *, 4 > operations
Operations contained within this cut.
unsigned getOutputSize() const
Get the number of outputs from root operation.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permuted inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
llvm::SmallSetVector< mlir::Value, 4 > inputs
External inputs to this cut (cut boundary).
void dump(llvm::raw_ostream &os) const
const BinaryTruthTable & getTruthTable() const
Get the truth table for this cut.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
Cut mergeWith(const Cut &other, Operation *root) const
Merge this cut with another cut to form a new cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
Cut reRoot(Operation *root) const
Represents a cut that has been successfully matched to a rewriting pattern.
Definition CutRewriter.h:67
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const
Get the delay between specific input and output pins.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
Definition CutRewriter.h:69
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
Definition CutRewriter.h:71
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
int64_t DelayType
Definition CutRewriter.h:36
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:41
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Represents a boolean function as a truth table.
Definition NPNClass.h:38
Represents the canonical form of a boolean function under NPN equivalence.
Definition NPNClass.h:103
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
Definition NPNClass.cpp:210
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Definition NPNClass.cpp:187
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual DelayType getDelay(unsigned inputIndex, unsigned outputIndex) const =0
Get the delay between specific input and output.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
virtual double getArea() const =0
Get the area cost of this pattern.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when false, fail if there is a root operation that has no matching pattern.
bool attachDebugTiming
Attach arrival times to the rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.