CIRCT 22.0.0git
Loading...
Searching...
No Matches
CutRewriter.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18//
19//===----------------------------------------------------------------------===//
20
22
26#include "circt/Support/LLVM.h"
29#include "mlir/Analysis/TopologicalSortUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/Operation.h"
32#include "mlir/IR/RegionKindInterface.h"
33#include "mlir/IR/Value.h"
34#include "mlir/IR/ValueRange.h"
35#include "mlir/IR/Visitors.h"
36#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/APInt.h"
38#include "llvm/ADT/Bitset.h"
39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/MapVector.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/ScopeExit.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/iterator.h"
46#include "llvm/Support/Debug.h"
47#include "llvm/Support/ErrorHandling.h"
48#include "llvm/Support/LogicalResult.h"
49#include <algorithm>
50#include <functional>
51#include <memory>
52#include <optional>
53#include <string>
54
55#define DEBUG_TYPE "synth-cut-rewriter"
56
57using namespace circt;
58using namespace circt::synth;
59
60static bool isSupportedLogicOp(mlir::Operation *op) {
61 // Check if the operation is a combinational operation that can be simulated
62 // TODO: Extend this to allow comb.and/xor/or as well.
63 return isa<aig::AndInverterOp>(op);
64}
65
/// Simulate one supported logic operation.
/// The APInt value of every input is read from `eval`, and the operation's
/// result value is written back into `eval`.
/// Aborts via llvm::report_fatal_error in two cases: an input has no entry in
/// `eval` yet, or the operation kind is not handled.
/// NOTE(review): the first line of the statement that the string literal below
/// belongs to is missing from this listing.
66static void simulateLogicOp(Operation *op, DenseMap<Value, llvm::APInt> &eval) {
68 "Operation must be a supported logic operation for simulation");
69
70 // Simulate the operation by evaluating its inputs and computing the output
71 // Only aig::AndInverterOp is handled; its result comes from AndInverterOp::evaluate.
72 if (auto andOp = dyn_cast<aig::AndInverterOp>(op)) {
73 SmallVector<llvm::APInt, 2> inputs;
74 inputs.reserve(andOp.getInputs().size());
75 for (auto input : andOp.getInputs()) {
76 auto it = eval.find(input);
77 if (it == eval.end())
78 llvm::report_fatal_error("Input value not found in evaluation map");
79 inputs.push_back(it->second);
80 }
81 // Evaluate the and inverter
82 eval[andOp.getResult()] = andOp.evaluate(inputs);
83 return;
84 }
85
86 llvm::report_fatal_error(
87 "Unsupported operation for simulation. isSupportedLogicOp should "
88 "be used to check if the operation can be simulated.");
89}
90
91// Return true if the value is always a cut input.
92static bool isAlwaysCutInput(Value value) {
93 auto *op = value.getDefiningOp();
94 // If the value has no defining operation, it is an input
95 if (!op)
96 return true;
97
98 if (op->hasTrait<OpTrait::ConstantLike>()) {
99 // Constant values are never cut inputs.
100 return false;
101 }
102
103 return !isSupportedLogicOp(op);
104}
105
106// Return true if the new area/delay is better than the old area/delay in the
107// context of the given strategy.
// Equal costs return false.
// One branch compares area first and breaks ties on arrival times. The other
// compares arrival times first and breaks ties on area. These are presumably
// the area- and timing-oriented strategies; confirm against the missing
// branch conditions. Any other strategy is treated as unreachable.
// NOTE(review): arrival-time vectors are compared with ArrayRef's relational
// operators (presumably element-wise lexicographic). Confirm.
// NOTE(review): the first signature line and both strategy checks are missing
// from this listing.
109 ArrayRef<DelayType> newDelay, double oldArea,
110 ArrayRef<DelayType> oldDelay) {
112 // Compare by area first.
113 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
114 }
116 // Compare by delay first.
117 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
118 }
119 llvm_unreachable("Unknown mapping strategy");
120}
121
// Topologically sort the blocks under `topOp` so that supported logic
// operations and comb.extract/replicate/concat follow the definitions of their
// operands. On failure an error is emitted at `topOp` and failure() returned.
// NOTE(review): the line carrying the function name and parameter is missing
// from this listing.
122LogicalResult
124
125 auto isOperationReady = [](Value value, Operation *op) -> bool {
126 // Only simulatable logic ops and purely dataflow ops are ordered after their
127 // operands; any other op reports every operand ready (the callback appears to be queried per operand).
128 return !(isSupportedLogicOp(op) ||
129 isa<comb::ExtractOp, comb::ReplicateOp, comb::ConcatOp>(op));
130 };
131
132 auto result = topologicallySortGraphRegionBlocks(topOp, isOperationReady);
133 if (failed(result))
134 return mlir::emitError(topOp->getLoc(),
135 "failed to sort operations topologically");
136 return success();
137}
138
139/// Get the truth table for an op.
140template <typename OpRange>
141FailureOr<BinaryTruthTable> static computeTruthTable(
142 mlir::ValueRange values, const OpRange &ops,
143 const llvm::SmallSetVector<mlir::Value, 4> &inputArgs) {
144 // Create a truth table for the operation
145 int64_t numInputs = inputArgs.size();
146 int64_t numOutputs = values.size();
147 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
148 if (numOutputs == 0)
149 return BinaryTruthTable(numInputs, 0);
150 if (numInputs >= maxTruthTableInputs)
151 return mlir::emitError(values.front().getLoc(),
152 "Truth table is too large");
153 return mlir::emitError(values.front().getLoc(),
154 "Multiple outputs are not supported yet");
155 }
156
157 // Create a truth table with the given number of inputs and outputs
158 BinaryTruthTable truthTable(numInputs, numOutputs);
159 // The truth table size is 2^numInputs
160 // Create a map to evaluate the operation
161 DenseMap<Value, APInt> eval;
162 for (uint32_t i = 0; i < numInputs; ++i)
163 eval[inputArgs[i]] = circt::createVarMask(numInputs, i, true);
164 // Simulate the operation
165 for (auto *op : ops) {
166 if (op->getNumResults() == 0)
167 continue; // Skip operations with no results
168 if (!isSupportedLogicOp(op))
169 return op->emitError("Unsupported operation for truth table simulation");
170
171 // Simulate the operation
172 simulateLogicOp(op, eval);
173 }
174 // TODO: Currently numOutputs is always 1, so we can just return the first
175 // one.
176 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
177}
178
179FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
180 Block *block) {
181 // Get the input arguments from the block
182 llvm::SmallSetVector<Value, 4> inputs;
183 for (auto arg : block->getArguments())
184 inputs.insert(arg);
185
186 // If there are no inputs, return an empty truth table
187 if (inputs.empty())
188 return BinaryTruthTable();
189
190 return computeTruthTable(values, llvm::make_pointer_range(*block), inputs);
191}
192
193//===----------------------------------------------------------------------===//
194// Cut
195//===----------------------------------------------------------------------===//
196
197bool Cut::isTrivialCut() const {
198 // A cut is a trival cut if it has no operations and only one input
199 return operations.empty() && inputs.size() == 1;
200}
201
202mlir::Operation *Cut::getRoot() const {
203 return operations.empty()
204 ? nullptr
205 : operations.back(); // The last operation is the root
206}
207
// Lazily compute the NPN canonical form of this cut's truth table and cache it
// in `npnClass`; later calls return the cached value.
// NOTE(review): the signature line is missing from this listing.
209 // If the NPN is already computed, return it
210 if (npnClass)
211 return *npnClass;
212
// NOTE(review): if getTruthTable() returns a reference, `auto` copies the
// table here; `const auto &` would avoid that.
213 auto truthTable = getTruthTable();
214
215 // Compute the NPN canonical form
216 auto canonicalForm = NPNClass::computeNPNCanonicalForm(truthTable);
217
218 npnClass.emplace(std::move(canonicalForm));
219 return *npnClass;
220}
221
222void Cut::getPermutatedInputs(const NPNClass &patternNPN,
223 SmallVectorImpl<Value> &permutedInputs) const {
224 auto npnClass = getNPNClass();
225 SmallVector<unsigned> idx;
226 npnClass.getInputPermutation(patternNPN, idx);
227 permutedInputs.reserve(idx.size());
228 for (auto inputIndex : idx) {
229 assert(inputIndex < inputs.size() && "Input index out of bounds");
230 permutedInputs.push_back(inputs[inputIndex]);
231 }
232}
233
// Append one arrival time per cut input to `results`, in input order.
// NOTE(review): the line with the function name and first parameter(s) is
// missing from this listing; `enumerator` is used below.
234LogicalResult
236 SmallVectorImpl<DelayType> &results) const {
237 results.reserve(getInputSize());
238
239 // Compute arrival times for each input.
240 for (auto input : inputs) {
241 if (isAlwaysCutInput(input)) {
242 // An always-cut-input (e.g. a block argument) is given arrival time 0.
243 results.push_back(0);
244 continue;
245 }
246 auto *cutSet = enumerator.getCutSet(input);
247 assert(cutSet && "Input must have a valid cut set");
248
249 // If there is no matching pattern, it means it's not possible to use the
250 // input in the cut rewriting. Return failure(); `results` may already hold entries for earlier inputs.
251 auto *bestCut = cutSet->getBestMatchedCut();
252 if (!bestCut)
253 return failure();
254
255 const auto &matchedPattern = *bestCut->getMatchedPattern();
256
257 // Otherwise, the cut input is an op result. Get the arrival time
258 // from the matched pattern.
259 results.push_back(matchedPattern.getArrivalTime(
260 cast<mlir::OpResult>(input).getResultNumber()));
261 }
262
263 return success();
264}
265
266void Cut::dump(llvm::raw_ostream &os) const {
267 os << "// === Cut Dump ===\n";
268 os << "Cut with " << getInputSize() << " inputs and " << operations.size()
269 << " operations:\n";
270 if (isTrivialCut()) {
271 os << "Primary input cut: " << *inputs.begin() << "\n";
272 return;
273 }
274
275 os << "Inputs: \n";
276 for (auto [idx, input] : llvm::enumerate(inputs)) {
277 os << " Input " << idx << ": " << input << "\n";
278 }
279 os << "\nOperations: \n";
280 for (auto *op : operations) {
281 op->print(os);
282 os << "\n";
283 }
284 auto &npnClass = getNPNClass();
285 npnClass.dump(os);
286
287 os << "// === Cut End ===\n";
288}
289
290unsigned Cut::getInputSize() const { return inputs.size(); }
291
292unsigned Cut::getOutputSize() const { return getRoot()->getNumResults(); }
293
// Return this cut's truth table, computing and caching it on first use.
// NOTE(review): the signature line, and the line that fills `truthTable` for
// non-trivial cuts, are missing from this listing.
295 if (truthTable)
296 return *truthTable;
297
298 if (isTrivialCut()) {
299 // For a trivial cut, a truth table is simply the identity function.
300 // 0 -> 0, 1 -> 1 (APInt(2, 2) == 0b10: row 0 is 0, row 1 is 1)
301 truthTable = BinaryTruthTable(1, 1, {llvm::APInt(2, 2)});
302 return *truthTable;
303 }
304
305 // Create a truth table with the given number of inputs and outputs
307
308 return *truthTable;
309}
310
311static Cut getAsTrivialCut(mlir::Value value) {
312 // Create a cut with a single root operation
313 Cut cut;
314 // There is no input for the primary input cut.
315 cut.inputs.insert(value);
316
317 return cut;
318}
319
320[[maybe_unused]] static bool isCutDerivedFromOperand(const Cut &cut,
321 Operation *op) {
322 if (auto *root = cut.getRoot())
323 return llvm::any_of(op->getOperands(),
324 [&](Value v) { return v.getDefiningOp() == root; });
325
326 assert(cut.isTrivialCut());
327 // If the cut is trivial, it has no operations, so it must be a primary input.
328 // In this case, the only operation that can be derived from it is the
329 // primary input itself.
330 return cut.inputs.size() == 1 &&
331 llvm::any_of(op->getOperands(),
332 [&](Value v) { return v == *cut.inputs.begin(); });
333}
334
335Cut Cut::mergeWith(const Cut &other, Operation *root) const {
336 assert(isCutDerivedFromOperand(*this, root) &&
337 isCutDerivedFromOperand(other, root) &&
338 "The operation must be a child of the current root operation");
339
340 // Create a new cut that combines this cut and the other cut
341 Cut newCut;
342 // Topological sort the operations in the new cut.
343 // TODO: Merge-sort `operations` and `other.operations` by operation index
344 // (since it's already topo-sorted, we can use a simple merge).
345 std::function<void(Operation *)> populateOperations = [&](Operation *op) {
346 // If the operation is already in the cut, skip it
347 if (newCut.operations.contains(op))
348 return;
349
350 // Add its operands to the worklist
351 for (auto value : op->getOperands()) {
352 if (isAlwaysCutInput(value))
353 continue;
354
355 // If the value is in *both* cuts inputs, it is an input. So skip
356 // it.
357 bool isInput = inputs.contains(value);
358 bool isOtherInput = other.inputs.contains(value);
359 // If the value is in this cut inputs, it is an input. So skip it
360 if (isInput && isOtherInput)
361 continue;
362
363 auto *defOp = value.getDefiningOp();
364
365 assert(defOp && "Value must have a defining operation since block"
366 "arguments are treated as inputs");
367
368 // Otherwise, check if the operation is in the other cut.
369 if (isInput)
370 if (!other.operations.contains(defOp)) // op is in the other cut.
371 continue;
372 if (isOtherInput)
373 if (!operations.contains(defOp)) // op is in this cut.
374 continue;
375 populateOperations(defOp);
376 }
377
378 // Add the operation to the cut
379 newCut.operations.insert(op);
380 };
381
382 populateOperations(root);
383
384 // Construct inputs.
385 for (auto *operation : newCut.operations) {
386 for (auto value : operation->getOperands()) {
387 if (isAlwaysCutInput(value)) {
388 newCut.inputs.insert(value);
389 continue;
390 }
391
392 auto *defOp = value.getDefiningOp();
393 assert(defOp && "Value must have a defining operation");
394
395 // If the operation is not in the cut, it is an input
396 if (!newCut.operations.contains(defOp))
397 // Add the input to the cut
398 newCut.inputs.insert(value);
399 }
400 }
401
402 // TODO: Sort the inputs by their defining operation.
403 // TODO: Update area and delay based on the merged cuts.
404
405 return newCut;
406}
407
408// Reroot the cut with a new root operation.
409// This is used to create a new cut with the same inputs and operations, but a
410// different root operation.
411Cut Cut::reRoot(Operation *root) const {
412 assert(isCutDerivedFromOperand(*this, root) &&
413 "The operation must be a child of the current root operation");
414 Cut newCut;
415 newCut.inputs = inputs;
416 newCut.operations = operations;
417 // Add the new root operation to the cut
418 newCut.operations.insert(root);
419 return newCut;
420}
421
422//===----------------------------------------------------------------------===//
423// MatchedPattern
424//===----------------------------------------------------------------------===//
425
426ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
427 assert(pattern && "Pattern must be set to get arrival time");
428 return arrivalTimes;
429}
430
// Arrival time of output `index`; no explicit bounds check is done here.
// NOTE(review): the signature line is missing from this listing.
432 assert(pattern && "Pattern must be set to get arrival time");
433 return arrivalTimes[index];
434}
435
// The matched rewrite pattern; it must be set.
// NOTE(review): the signature line is missing from this listing.
437 assert(pattern && "Pattern must be set to get the pattern");
438 return pattern;
439}
440
// Area recorded for the matched pattern; the pattern must be set.
// NOTE(review): the signature line is missing from this listing.
442 assert(pattern && "Pattern must be set to get area");
443 return area;
444}
445
446//===----------------------------------------------------------------------===//
447// CutSet
448//===----------------------------------------------------------------------===//
449
451
452unsigned CutSet::size() const { return cuts.size(); }
453
// Append `cut` to this set. Adding to a set that has been finalized (frozen) is
// a programming error caught in debug builds.
// NOTE(review): the signature line is missing from this listing.
455 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
456 cuts.push_back(std::move(cut));
457}
458
459ArrayRef<Cut> CutSet::getCuts() const { return cuts; }
460
461// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
462// exists another cut that is a subset of it. We use a bitset to represent the
463// inputs of each cut for efficient subset checking.
464static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut> &cuts) {
465 // First sort the cuts by input size (ascending). This ensures that when we
466 // iterate through the cuts, we always encounter smaller cuts first, allowing
467 // us to efficiently check for non-minimality. Stable sort to maintain
468 // relative order of cuts with the same input size.
469 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut &a, const Cut &b) {
470 return a.getInputSize() < b.getInputSize();
471 });
472
473 llvm::SmallVector<llvm::Bitset<64>, 4> inputBitMasks;
474 DenseMap<Value, unsigned> inputIndices;
475 auto getIndex = [&](Value v) -> unsigned {
476 auto it = inputIndices.find(v);
477 if (it != inputIndices.end())
478 return it->second;
479 unsigned index = inputIndices.size();
480 if (LLVM_UNLIKELY(index >= 64))
481 llvm::report_fatal_error(
482 "Too many unique inputs across cuts. Max 64 supported. Consider "
483 "increasing the compile-time constant.");
484 inputIndices[v] = index;
485 return index;
486 };
487
488 for (unsigned i = 0; i < cuts.size(); ++i) {
489 auto &cut = cuts[i];
490 // Create a unique identifier for the cut based on its inputs.
491 llvm::Bitset<64> inputsMask;
492 for (auto input : cut.inputs.getArrayRef())
493 inputsMask.set(getIndex(input));
494
495 bool isUnique = llvm::all_of(
496 inputBitMasks, [&](const llvm::Bitset<64> &existingCutInputMask) {
497 // If the bitset is a subset of the current inputsMask, it is not
498 // unique
499 return (existingCutInputMask & inputsMask) != existingCutInputMask;
500 });
501
502 if (!isUnique)
503 continue;
504
505 // If the cut is unique, keep it
506 size_t uniqueCount = inputBitMasks.size();
507 if (i != uniqueCount)
508 cuts[uniqueCount] = std::move(cut);
509 inputBitMasks.push_back(inputsMask);
510 }
511
512 unsigned uniqueCount = inputBitMasks.size();
513
514 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
515 << " Unique cuts: " << uniqueCount << "\n");
516
517 // Resize the cuts vector to the number of unique cuts found
518 cuts.resize(uniqueCount);
519}
520
// Finalize this cut set in four steps:
//   1. Prune duplicate and non-minimal cuts.
//   2. Attach a matched pattern to each cut that `matchCut` accepts.
//   3. Rank the cuts.
//   4. Truncate to `options.maxCutSizePerRoot`.
// Finally, record the first matched cut as `bestCut` and freeze the set.
// NOTE(review): the signature's first line is missing from this listing.
522 const CutRewriterOptions &options,
523 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
524
525 // Step 1: Remove duplicate and non-minimal cuts to reduce the search space
526 // This eliminates cuts that are strictly dominated by others
// NOTE(review): the Step 1 call itself is missing from this listing.
528
529 // Step 2: Match each remaining cut against available patterns
530 // This computes timing and area information needed for prioritization
531 for (auto &cut : cuts) {
532 // Verify cut doesn't exceed input size limits
533 assert(cut.getInputSize() <= options.maxCutInputSize &&
534 "Cut input size exceeds maximum allowed size");
535
536 // Attempt to match the cut against available patterns
537 auto matched = matchCut(cut);
538 if (!matched)
539 continue; // No matching pattern found for this cut
540
541 // Store the matched pattern with the cut for later evaluation
542 cut.setMatchedPattern(std::move(*matched));
543 }
544
545 // Step 3: Sort cuts by priority to select the best ones
546 // Priority is determined by the optimization strategy:
547 // - Trivial cuts (direct connections) are placed first
548 // - Among matched cuts, compare by area/delay based on the strategy
549 // - Matched cuts are preferred over unmatched cuts
550 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
551 // et al., ICCAD 2007 for more details.
552 // TODO: Use a priority queue instead of sorting for better performance.
553
554 // Partition the cuts into trivial and non-trivial cuts.
555 auto *trivialCutsEnd =
556 std::stable_partition(cuts.begin(), cuts.end(),
557 [](const Cut &cut) { return cut.isTrivialCut(); });
558
// Rank only the non-trivial cuts:
//   - matched vs. matched: compareDelayAndArea decides;
//   - matched beats unmatched;
//   - unmatched vs. unmatched: the smaller input count wins.
559 std::stable_sort(trivialCutsEnd, cuts.end(),
560 [&options](const Cut &a, const Cut &b) -> bool {
561 assert(!a.isTrivialCut() && !b.isTrivialCut() &&
562 "Trivial cuts should have been excluded");
563 const auto &aMatched = a.getMatchedPattern();
564 const auto &bMatched = b.getMatchedPattern();
565
566 // Both cuts have matched patterns.
567 if (aMatched && bMatched)
568 return compareDelayAndArea(
569 options.strategy, aMatched->getArea(),
570 aMatched->getArrivalTimes(), bMatched->getArea(),
571 bMatched->getArrivalTimes());
572
573 // Prefer cuts with matched patterns over those without
574 if (aMatched && !bMatched)
575 return true;
576 if (!aMatched && bMatched)
577 return false;
578
579 // Both cuts are unmatched - prefer smaller input size
580 return a.getInputSize() < b.getInputSize();
581 });
582
583 // Step 4: Limit the number of cuts to prevent exponential growth
584 // After sorting, keep only the best cuts up to the specified limit
// (trivial cuts come first, so they survive unless the limit is 0).
585 if (cuts.size() > options.maxCutSizePerRoot)
586 cuts.resize(options.maxCutSizePerRoot);
587
588 // Select the best cut from the remaining candidates
589 for (auto &cut : cuts) {
590 const auto &currentMatch = cut.getMatchedPattern();
591 if (!currentMatch)
592 continue; // Skip cuts without matched patterns
593
594 // This is already sorted, so the first matched cut is the best.
595 bestCut = &cut;
596 break;
597 }
598
599 LLVM_DEBUG({
600 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
601 << (bestCut
602 ? "matched pattern to " + bestCut->getMatchedPattern()
603 ->getPattern()
604 ->getPatternName()
605 : "no matched pattern")
606 << "\n";
607 });
608
609 isFrozen = true; // Mark the cut set as frozen
610}
611
612//===----------------------------------------------------------------------===//
613// CutRewritePattern
614//===----------------------------------------------------------------------===//
615
// Default implementation: the pattern provides no NPN classes and leaves
// `matchingNPNClasses` untouched. The CutRewritePatternSet constructor
// therefore records such a pattern in `nonNPNPatterns`.
// NOTE(review): the signature's first line is missing from this listing.
617 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
618 return false;
619}
620
621//===----------------------------------------------------------------------===//
622// CutRewritePatternSet
623//===----------------------------------------------------------------------===//
624
// Take ownership of `patterns` and index them for matching.
// - Patterns that report NPN classes through useTruthTableMatcher are filed in
//   `npnToPatternMap`, keyed by (truth table bits, input count).
// - All other patterns go to `nonNPNPatterns`.
// NOTE(review): the signature's first line is missing from this listing.
626 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
627 : patterns(std::move(patterns)) {
628 // Initialize the NPN to pattern map
629 for (auto &pattern : this->patterns) {
630 SmallVector<NPNClass, 2> npnClasses;
631 auto result = pattern->useTruthTableMatcher(npnClasses);
632 if (result) {
// NOTE(review): `auto` copies each NPNClass before it is moved below;
// iterating by reference would avoid the copy.
633 for (auto npnClass : npnClasses) {
634 // Index the pattern under this NPN class's truth table and input count.
635 npnToPatternMap[{npnClass.truthTable.table,
636 npnClass.truthTable.numInputs}]
637 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
638 }
639 } else {
640 // If the pattern does not provide NPN classes, record it in nonNPNPatterns
641 // so that it can be considered for all cuts.
642 nonNPNPatterns.push_back(pattern.get());
643 }
644 }
645}
646
647//===----------------------------------------------------------------------===//
648// CutEnumerator
649//===----------------------------------------------------------------------===//
650
// Initialize `options` from the argument; the constructor body is empty.
// NOTE(review): the signature line is missing from this listing.
652 : options(options) {}
653
// Create an empty cut set for `value` and return it. A cut set must not
// already exist for `value` (checked in debug builds).
// NOTE(review): the signature line is missing from this listing.
// `cutSetPtr` is a MapVector iterator, not a CutSet pointer.
655 auto [cutSetPtr, inserted] =
656 cutSets.try_emplace(value, std::make_unique<CutSet>());
657 assert(inserted && "Cut set already exists for this value");
658 return cutSetPtr->second.get();
659}
660
661llvm::MapVector<Value, std::unique_ptr<CutSet>> CutEnumerator::takeVector() {
662 return std::move(cutSets);
663}
664
665void CutEnumerator::clear() { cutSets.clear(); }
666
667LogicalResult CutEnumerator::visit(Operation *op) {
668 if (isSupportedLogicOp(op))
669 return visitLogicOp(op);
670
671 // Skip other operations. If the operation is not a supported logic
672 // operation, we create a trivial cut lazily.
673 return success();
674}
675
676LogicalResult CutEnumerator::visitLogicOp(Operation *logicOp) {
677 assert(logicOp->getNumResults() == 1 &&
678 "Logic operation must have a single result");
679
680 Value result = logicOp->getResult(0);
681 unsigned numOperands = logicOp->getNumOperands();
682
683 // Validate operation constraints
684 // TODO: Variadic operations and non-single-bit results can be supported
685 if (numOperands > 2)
686 return logicOp->emitError("Cut enumeration supports at most 2 operands, "
687 "found: ")
688 << numOperands;
689 if (!logicOp->getOpResult(0).getType().isInteger(1))
690 return logicOp->emitError()
691 << "Supported logic operations must have a single bit "
692 "result type but found: "
693 << logicOp->getResult(0).getType();
694
695 SmallVector<const CutSet *, 2> operandCutSets;
696 operandCutSets.reserve(numOperands);
697 // Collect cut sets for each operand
698 for (unsigned i = 0; i < numOperands; ++i) {
699 auto *operandCutSet = getCutSet(logicOp->getOperand(i));
700 if (!operandCutSet)
701 return logicOp->emitError("Failed to get cut set for operand ")
702 << i << ": " << logicOp->getOperand(i);
703 operandCutSets.push_back(operandCutSet);
704 }
705
706 // Create the singleton cut (just this operation)
707 Cut primaryInputCut = getAsTrivialCut(result);
708
709 auto *resultCutSet = createNewCutSet(result);
710
711 // Add the singleton cut first
712 resultCutSet->addCut(primaryInputCut);
713
714 // Schedule cut set finalization when exiting this scope
715 llvm::scope_exit prune([&]() {
716 // Finalize cut set: remove duplicates, limit size, and match patterns
717 resultCutSet->finalize(options, matchCut);
718 });
719
720 // Handle unary operations
721 if (numOperands == 1) {
722 const auto &inputCutSet = operandCutSets[0];
723
724 // Try to extend each input cut by including this operation
725 for (const Cut &inputCut : inputCutSet->getCuts()) {
726 Cut extendedCut = inputCut.reRoot(logicOp);
727 // Skip cuts that exceed input size limit
728 if (extendedCut.getInputSize() > options.maxCutInputSize)
729 continue;
730
731 resultCutSet->addCut(std::move(extendedCut));
732 }
733 return success();
734 }
735
736 // Handle binary operations (like AND, OR, XOR gates)
737 assert(numOperands == 2 && "Expected binary operation");
738
739 const auto *lhsCutSet = operandCutSets[0];
740 const auto *rhsCutSet = operandCutSets[1];
741
742 // Combine cuts from both inputs to create larger cuts
743 for (const Cut &lhsCut : lhsCutSet->getCuts()) {
744 for (const Cut &rhsCut : rhsCutSet->getCuts()) {
745 Cut mergedCut = lhsCut.mergeWith(rhsCut, logicOp);
746 // Skip cuts that exceed input size limit
747 if (mergedCut.getInputSize() > options.maxCutInputSize)
748 continue;
749
750 resultCutSet->addCut(std::move(mergedCut));
751 }
752 }
753
754 return success();
755}
756
// Enumerate cuts for every operation under `topOp`.
// The logic network is first sorted topologically. Each operation is then
// visited during a walk, and the walk stops at the first visit failure.
// `matchCut` is used by each cut set's finalization.
// NOTE(review): the signature's first line is missing from this listing.
758 Operation *topOp,
759 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
760 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
761 << "\n");
762 // Topologically sort the logic network
763 if (failed(topologicallySortLogicNetwork(topOp)))
764 return failure();
765
766 // Store the pattern matching function for use during cut finalization.
// NOTE(review): llvm::function_ref does not own its callable. The callable
// must outlive every later use of `this->matchCut`, not just this call.
767 this->matchCut = matchCut;
768
769 // Walk through all operations in the module; the order relies on the sort above
770 auto result = topOp->walk([&](Operation *op) {
771 if (failed(visit(op)))
772 return mlir::WalkResult::interrupt();
773 return mlir::WalkResult::advance();
774 });
775
776 if (result.wasInterrupted())
777 return failure();
778
779 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
780 return success();
781}
782
783const CutSet *CutEnumerator::getCutSet(Value value) {
784 // Check if cut set already exists
785 auto *it = cutSets.find(value);
786 if (it == cutSets.end()) {
787 // Create new cut set for an unprocessed value
788 auto cutSet = std::make_unique<CutSet>();
789 cutSet->addCut(getAsTrivialCut(value));
790 auto [newIt, inserted] = cutSets.insert({value, std::move(cutSet)});
791 assert(inserted && "Cut set already exists for this value");
792 (void)newIt;
793 it = newIt;
794 }
795
796 return it->second.get();
797}
798
799/// Generate a human-readable name for a value used in test output.
800/// This function creates meaningful names for values to make debug output
801/// and test results more readable and understandable.
802static StringRef
803getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
804 if (auto *op = value.getDefiningOp()) {
805 // Handle values defined by operations
806 // First, check if the operation already has a name hint attribute
807 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
808 return name.getValue();
809
810 // For single-result operations, generate a unique name based on operation
811 // type
812 if (op->getNumResults() == 1) {
813 auto opName = op->getName();
814 auto count = opCounter[opName]++;
815
816 // Create a unique name by appending a counter to the operation name
817 SmallString<16> nameStr;
818 nameStr += opName.getStringRef();
819 nameStr += "_";
820 nameStr += std::to_string(count);
821
822 // Store the generated name as a hint attribute for future reference
823 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
824 op->setAttr("sv.namehint", nameAttr);
825 return nameAttr;
826 }
827
828 // Multi-result operations or other cases get a generic name
829 return "<unknown>";
830 }
831
832 // Handle block arguments
833 auto blockArg = cast<BlockArgument>(value);
834 auto hwOp =
835 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
836 if (!hwOp)
837 return "<unknown>";
838
839 // Return the formal input name from the hardware module
840 return hwOp.getInputName(blockArg.getArgNumber());
841}
842
// Print every collected cut set to stdout (used by tests).
// Each line has the form
//   "<value> <N> cuts: {inputs}@t<truth table>d<max arrival>"
// with one "{...}@t...d..." group per cut.
// Naming goes through getTestVariableName, which may add `sv.namehint`
// attributes to the IR.
// NOTE(review): the signature line is missing from this listing.
844 DenseMap<OperationName, unsigned> opCounter;
845 for (auto &[value, cutSetPtr] : cutSets) {
846 auto &cutSet = *cutSetPtr;
847 llvm::outs() << getTestVariableName(value, opCounter) << " "
848 << cutSet.getCuts().size() << " cuts:";
849 for (const Cut &cut : cutSet.getCuts()) {
850 llvm::outs() << " {";
851 llvm::interleaveComma(cut.inputs, llvm::outs(), [&](Value input) {
852 llvm::outs() << getTestVariableName(input, opCounter);
853 });
854 auto &pattern = cut.getMatchedPattern();
// NOTE(review): APInt::getZExtValue requires the value to fit in 64 bits;
// truth tables of cuts with more than 6 inputs are wider than that.
855 llvm::outs() << "}"
856 << "@t" << cut.getTruthTable().table.getZExtValue() << "d";
857 if (pattern) {
// NOTE(review): dereferencing max_element's result is undefined if the
// pattern has no arrival times.
858 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
859 pattern->getArrivalTimes().end());
860 } else {
861 llvm::outs() << "0";
862 }
863 }
864 llvm::outs() << "\n";
865 }
866 llvm::outs() << "Cut enumeration completed successfully\n";
867}
868
869//===----------------------------------------------------------------------===//
870// CutRewriter
871//===----------------------------------------------------------------------===//
872
// Run the cut rewriter on `topOp`:
//   1. Reject patterns with multiple outputs.
//   2. Sort the logic network topologically.
//   3. Enumerate cuts.
//   4. Either stop after the test-dump branch, or perform the bottom-up
//      rewrite.
// NOTE(review): lines 877 and 906-907 are missing from this listing.
873LogicalResult CutRewriter::run(Operation *topOp) {
874 LLVM_DEBUG({
875 llvm::dbgs() << "Starting Cut Rewriter\n";
876 llvm::dbgs() << "Mode: "
878 : "timing")
879 << "\n";
880 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
// NOTE(review): "Max cut size" and "Max cuts per node" both print
// maxCutSizePerRoot.
881 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
882 llvm::dbgs() << "Max cuts per node: " << options.maxCutSizePerRoot << "\n";
883 });
884
885 // Currently we don't support patterns with multiple outputs.
886 // So check that.
887 // TODO: This must be removed when we support multiple outputs.
888 for (auto &pattern : patterns.patterns) {
889 if (pattern->getNumOutputs() > 1) {
890 return mlir::emitError(pattern->getLoc(),
891 "Cut rewriter does not support patterns with "
892 "multiple outputs yet");
893 }
894 }
895
896 // First sort the operations topologically to ensure we can process them
897 // in a valid order.
// NOTE(review): CutEnumerator::enumerateCuts sorts the network again.
898 if (failed(topologicallySortLogicNetwork(topOp)))
899 return failure();
900
901 // Enumerate cuts for all nodes
902 if (failed(enumerateCuts(topOp)))
903 return failure();
904
905 // Dump cuts if testing priority cuts.
908 return success();
909 }
910
911 // Select best cuts and perform mapping
912 if (failed(runBottomUpRewrite(topOp)))
913 return failure();
914
915 return success();
916}
917
// Enumerate cuts under `topOp`. Each cut is matched with patternMatchCut when
// its cut set is finalized.
// NOTE(review): the line that starts the delegated call (line 921) is missing
// from this listing.
918LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
919 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
920
922 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
923 // Match the cut against the patterns
924 return patternMatchCut(cut);
925 });
926}
927
// Look up the patterns registered for `cut`'s NPN class.
// The lookup key is (canonical truth table bits, input count). An empty list is
// returned when no NPN-indexed patterns exist or none match.
// NOTE(review): the line carrying the function name is missing from this
// listing.
928ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
930 if (patterns.npnToPatternMap.empty())
931 return {};
932
933 auto &npnClass = cut.getNPNClass();
934 auto it = patterns.npnToPatternMap.find(
935 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
936 if (it == patterns.npnToPatternMap.end())
937 return {};
938 return it->getSecond();
939}
940
941std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
942 if (cut.isTrivialCut())
943 return {};
944
945 const CutRewritePattern *bestPattern = nullptr;
946 SmallVector<DelayType, 4> inputArrivalTimes;
947 SmallVector<DelayType, 1> bestArrivalTimes;
948 double bestArea = 0.0;
949 inputArrivalTimes.reserve(cut.getInputSize());
950 bestArrivalTimes.reserve(cut.getOutputSize());
951
952 // Compute arrival times for each input.
953 for (auto input : cut.inputs) {
954 assert(input.getType().isInteger(1));
955 if (isAlwaysCutInput(input)) {
956 // If the input is a primary input, it has no delay.
957 // TODO: This doesn't consider a global delay. Need to capture
958 // `arrivalTime` on the IR to make the primary input delays visible.
959 inputArrivalTimes.push_back(0);
960 continue;
961 }
962 auto *cutSet = cutEnumerator.getCutSet(input);
963 assert(cutSet && "Input must have a valid cut set");
964
965 // If there is no matching pattern, it means it's not possible to use the
966 // input in the cut rewriting. So abort early.
967 auto *bestCut = cutSet->getBestMatchedCut();
968 if (!bestCut)
969 return {};
970
971 const auto &matchedPattern = *bestCut->getMatchedPattern();
972
973 // Otherwise, the cut input is an op result. Get the arrival time
974 // from the matched pattern.
975 inputArrivalTimes.push_back(matchedPattern.getArrivalTime(
976 cast<mlir::OpResult>(input).getResultNumber()));
977 }
978
979 auto computeArrivalTimeAndPickBest =
980 [&](const CutRewritePattern *pattern, const MatchResult &matchResult,
981 llvm::function_ref<unsigned(unsigned)> mapIndex) {
982 SmallVector<DelayType, 1> outputArrivalTimes;
983 // Compute the maximum delay for each output from inputs.
984 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize();
985 outputIndex < outputSize; ++outputIndex) {
986 // Compute the arrival time for this output.
987 DelayType outputArrivalTime = 0;
988 auto delays = matchResult.getDelays();
989 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
990 inputIndex < inputSize; ++inputIndex) {
991 // Map pattern input i to cut input through NPN transformations
992 unsigned cutOriginalInput = mapIndex(inputIndex);
993 outputArrivalTime =
994 std::max(outputArrivalTime,
995 delays[outputIndex * inputSize + inputIndex] +
996 inputArrivalTimes[cutOriginalInput]);
997 }
998
999 outputArrivalTimes.push_back(outputArrivalTime);
1000 }
1001
1002 // Update the arrival time
1003 if (!bestPattern ||
1004 compareDelayAndArea(options.strategy, matchResult.area,
1005 outputArrivalTimes, bestArea,
1006 bestArrivalTimes)) {
1007 LLVM_DEBUG({
1008 llvm::dbgs() << "== Matched Pattern ==============\n";
1009 llvm::dbgs() << "Matching cut: \n";
1010 cut.dump(llvm::dbgs());
1011 llvm::dbgs() << "Found better pattern: "
1012 << pattern->getPatternName();
1013 llvm::dbgs() << " with area: " << matchResult.area;
1014 llvm::dbgs() << " and input arrival times: ";
1015 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1016 llvm::dbgs() << " " << inputArrivalTimes[i];
1017 }
1018 llvm::dbgs() << " and arrival times: ";
1019
1020 for (auto arrivalTime : outputArrivalTimes) {
1021 llvm::dbgs() << " " << arrivalTime;
1022 }
1023 llvm::dbgs() << "\n";
1024 llvm::dbgs() << "== Matched Pattern End ==============\n";
1025 });
1026
1027 bestArrivalTimes = std::move(outputArrivalTimes);
1028 bestArea = matchResult.area;
1029 bestPattern = pattern;
1030 }
1031 };
1032
1033 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1034 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1035 "Pattern input size must match cut input size");
1036 auto matchResult = pattern->match(cutEnumerator, cut);
1037 if (!matchResult)
1038 continue;
1039 auto &cutNPN = cut.getNPNClass();
1040
1041 // Get the input mapping from pattern's NPN class to cut's NPN class
1042 SmallVector<unsigned> inputMapping;
1043 cutNPN.getInputPermutation(patternNPN, inputMapping);
1044 computeArrivalTimeAndPickBest(pattern, *matchResult,
1045 [&](unsigned i) { return inputMapping[i]; });
1046 }
1047
1048 for (const CutRewritePattern *pattern : patterns.nonNPNPatterns) {
1049 if (auto matchResult = pattern->match(cutEnumerator, cut))
1050 computeArrivalTimeAndPickBest(pattern, *matchResult,
1051 [&](unsigned i) { return i; });
1052 }
1053
1054 if (!bestPattern)
1055 return {}; // No matching pattern found
1056
1057 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1058}
1059
// Rewrite the network using the best matched cut recorded for each value.
// Cut sets are visited in reverse of the enumerator's insertion order.
// NOTE(review): this presumably means users are rewritten before their
// operands, if enumeration order is topological -- confirm.
// Values that became unused are erased through an UnusedOpPruner. Returns
// failure if a pattern's rewrite fails, or if a value has no matched cut (see
// the note on the missing guard below).
1060LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
 1061 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
 1062 const auto &cutVector = cutEnumerator.getCutSets();
 1063 // Note: Don't clear cutEnumerator yet - we need it during rewrite
 1064 UnusedOpPruner pruner;
 1065 PatternRewriter rewriter(top->getContext());
 1066 for (auto &[value, cutSet] : llvm::reverse(cutVector)) {
// Dead value (e.g. all of its users were absorbed by an earlier rewrite):
// erase its defining op instead of rewriting it.
 1067 if (value.use_empty()) {
 1068 if (auto *op = value.getDefiningOp())
 1069 pruner.eraseNow(op);
 1070 continue;
 1071 }
 1072
 1073 if (isAlwaysCutInput(value)) {
 1074 // If the value is a primary input, skip it
 1075 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
 1076 continue;
 1077 }
 1078
 1079 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
 1080 auto *bestCut = cutSet->getBestMatchedCut();
 1081 if (!bestCut) {
// NOTE(review): listing line 1082 is missing here. Presumably it is a guard
// such as `if (options.allowNoMatch)`; without it the `continue` below would be
// unconditional and the error unreachable -- confirm.
 1083 continue; // No matching pattern found, skip this value
 1084 return emitError(value.getLoc(), "No matching cut found for value: ")
 1085 << value;
 1086 }
 1087
// Build the replacement right before the cut's root op, then swap the root op
// for the op returned by the pattern.
 1088 rewriter.setInsertionPoint(bestCut->getRoot());
 1089 const auto &matchedPattern = bestCut->getMatchedPattern();
 1090 auto result = matchedPattern->getPattern()->rewrite(rewriter, cutEnumerator,
 1091 *bestCut);
 1092 if (failed(result))
 1093 return failure();
 1094
 1095 rewriter.replaceOp(bestCut->getRoot(), *result);
 1096
// NOTE(review): listing line 1097 is missing. The extra closing brace on 1100
// implies an enclosing `if` opened there, presumably on
// `options.attachDebugTiming` -- confirm. The block records the chosen
// pattern's arrival times on the new op as `test.arrival_times`.
 1098 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
 1099 (*result)->setAttr("test.arrival_times", array);
 1100 }
 1101 }
 1102
 1103 // Clear the enumerator after rewriting is complete
// NOTE(review): listing line 1104 is missing. Presumably it is the
// `cutEnumerator.clear()` call the comment above refers to.
 1105 return success();
 1106}
assert(baseType &&"element must be base type")
static bool isAlwaysCutInput(Value value)
static bool isSupportedLogicOp(mlir::Operation *op)
static Cut getAsTrivialCut(mlir::Value value)
static FailureOr< BinaryTruthTable > computeTruthTable(mlir::ValueRange values, const OpRange &ops, const llvm::SmallSetVector< mlir::Value, 4 > &inputArgs)
Get the truth table for an op.
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut > &cuts)
static bool isCutDerivedFromOperand(const Cut &cut, Operation *op)
static void simulateLogicOp(Operation *op, DenseMap< Value, llvm::APInt > &eval)
RewritePatternSet pattern
Strategy strategy
Cut enumeration engine for combinational logic networks.
LogicalResult visit(Operation *op)
Visit a single operation and generate cuts for it.
const CutRewriterOptions & options
Configuration options for cut enumeration.
llvm::MapVector< Value, std::unique_ptr< CutSet > > cutSets
Maps values to their associated cut sets.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
const llvm::MapVector< Value, std::unique_ptr< CutSet > > & getCutSets() const
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
LogicalResult visitLogicOp(Operation *logicOp)
Visit a combinational logic operation and generate cuts.
llvm::MapVector< Value, std::unique_ptr< CutSet > > takeVector()
Move ownership of all cut sets to caller.
CutSet * createNewCutSet(Value value)
Create a new cut set for a value.
const CutSet * getCutSet(Value value)
Get the cut set for a specific value.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
llvm::SmallVector< Cut, 4 > cuts
Collection of cuts for this node.
void addCut(Cut cut)
Add a new cut to this set.
unsigned size() const
Get the number of cuts in this set.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut)
Finalize the cut set by removing duplicates and selecting the best pattern.
bool isFrozen
Whether cut set is finalized.
ArrayRef< Cut > getCuts() const
Get read-only access to all cuts in this set.
Represents a cut in the combinational logic network.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
llvm::SmallSetVector< mlir::Operation *, 4 > operations
Operations contained within this cut.
unsigned getOutputSize() const
Get the number of outputs from root operation.
void getPermutatedInputs(const NPNClass &patternNPN, SmallVectorImpl< Value > &permutedInputs) const
Get the permutated inputs for this cut based on the given pattern NPN.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
mlir::Operation * getRoot() const
Get the root operation of this cut.
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
llvm::SmallSetVector< mlir::Value, 4 > inputs
External inputs to this cut (cut boundary).
void dump(llvm::raw_ostream &os) const
const BinaryTruthTable & getTruthTable() const
Get the truth table for this cut.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
Cut mergeWith(const Cut &other, Operation *root) const
Merge this cut with another cut to form a new cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
Cut reRoot(Operation *root) const
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
int64_t DelayType
Definition CutRewriter.h:36
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:306
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:41
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Definition TruthTable.h:39
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:104
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
If set, root operations that have no matching pattern are skipped instead of causing a failure.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Result of matching a cut against a pattern.
Definition CutRewriter.h:74