CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18//
19//===----------------------------------------------------------------------===//
20
22
26#include "circt/Support/LLVM.h"
29#include "mlir/Analysis/TopologicalSortUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/Operation.h"
32#include "mlir/IR/RegionKindInterface.h"
33#include "mlir/IR/Value.h"
34#include "mlir/IR/ValueRange.h"
35#include "mlir/IR/Visitors.h"
36#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/APInt.h"
38#include "llvm/ADT/Bitset.h"
39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/MapVector.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/ScopeExit.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/TypeSwitch.h"
46#include "llvm/ADT/iterator.h"
47#include "llvm/Support/Debug.h"
48#include "llvm/Support/ErrorHandling.h"
49#include "llvm/Support/LogicalResult.h"
50#include <algorithm>
51#include <functional>
52#include <memory>
53#include <optional>
54#include <string>
55
56#define DEBUG_TYPE "synth-cut-rewriter"
57
58using namespace circt;
59using namespace circt::synth;
60
61//===----------------------------------------------------------------------===//
62// LogicNetwork
63//===----------------------------------------------------------------------===//
64
65uint32_t LogicNetwork::getOrCreateIndex(Value value) {
66 auto [it, inserted] = valueToIndex.try_emplace(value, gates.size());
67 if (inserted) {
68 indexToValue.push_back(value);
69 gates.emplace_back(); // Will be filled in later
70 }
71 return it->second;
72}
73
74uint32_t LogicNetwork::getIndex(Value value) const {
75 const auto it = valueToIndex.find(value);
76 assert(it != valueToIndex.end() &&
77 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
78 "hasIndex first");
79 return it->second;
80}
81
82bool LogicNetwork::hasIndex(Value value) const {
83 return valueToIndex.contains(value);
84}
85
86Value LogicNetwork::getValue(uint32_t index) const {
87 // Index 0 and 1 are reserved for constants, they have no associated Value
88 if (index == kConstant0 || index == kConstant1)
89 return Value();
90
91 assert(index < indexToValue.size() &&
92 "Index out of bounds in LogicNetwork::getValue");
93 return indexToValue[index];
94}
95
96void LogicNetwork::getValues(ArrayRef<uint32_t> indices,
97 SmallVectorImpl<Value> &values) const {
98 values.clear();
99 values.reserve(indices.size());
100 for (uint32_t idx : indices)
101 values.push_back(getValue(idx));
102}
103
// Register `value` as a primary input of the network and return its index.
// Calling this again for the same value returns the same index, because the
// index comes from getOrCreateIndex.
// NOTE(review): in this rendering a line is missing between the two
// statements. As shown, the code only looks up or reserves the slot and never
// sets the gate's kind. Confirm that the full source marks the slot as a
// primary-input gate.
104uint32_t LogicNetwork::addPrimaryInput(Value value) {
105 const uint32_t index = getOrCreateIndex(value);
107 return index;
108}
109
110uint32_t LogicNetwork::addGate(Operation *op, LogicNetworkGate::Kind kind,
111 Value result, ArrayRef<Signal> operands) {
112 const uint32_t index = getOrCreateIndex(result);
113 gates[index] = LogicNetworkGate(op, kind, operands);
114 return index;
115}
116
// Populate the network from the operations of `block`, visited in block order.
// - Block arguments are registered as primary inputs first.
// - aig::AndInverterOp:
//   - one input: becomes an alias of its input if not inverted, or an
//     inverting Identity gate if inverted;
//   - two inputs: becomes an And2 gate;
//   - more inputs: its i1 results become primary inputs.
// - comb::XorOp:
//   - exactly two operands: becomes an Xor2 gate;
//   - otherwise: its i1 results become primary inputs.
// - mig::MajorityInverterOp:
//   - one operand: handled like the one-input AND;
//   - three operands: becomes a three-input gate;
//   - otherwise: its i1 results become primary inputs.
// - i1 hw::ConstantOp results alias the reserved kConstant0/kConstant1
//   indices.
// - Any other operation: its i1 results become primary inputs. Non-i1
//   results of such operations are not indexed at all.
// Result widths of AND/XOR/MAJ ops are not checked here.
// NOTE(review): visitLogicOp later rejects gates whose result is not i1.
// Every path in the visible code returns success().
117LogicalResult LogicNetwork::buildFromBlock(Block *block) {
118 // Pre-size vectors to reduce reallocations (rough estimate)
119 const size_t estimatedSize =
120 block->getArguments().size() + block->getOperations().size();
121 indexToValue.reserve(estimatedSize);
122 gates.reserve(estimatedSize);
123
// One-input gate: a non-inverted input only aliases `result` to the input's
// index (no new gate is created). An inverted input creates an Identity gate
// whose single edge carries the inversion.
124 auto handleSingleInputGate = [&](Operation *op, Value result,
125 const Signal &inputSignal) {
126 if (!inputSignal.isInverted()) {
127 // Non-inverted buffer: directly alias the result to the input
128 valueToIndex[result] = inputSignal.getIndex();
129 return;
130 }
131 // Inverted operation: create a NOT gate
132 addGate(op, LogicNetworkGate::Identity, result, {inputSignal});
133 };
134
135 // Ensure all block arguments are indexed as primary inputs first
136 for (Value arg : block->getArguments()) {
137 if (!hasIndex(arg))
138 addPrimaryInput(arg);
139 }
140
// Fallback for operations that are not modelled as gates: each i1 result that
// has no index yet becomes a primary input (a cut boundary).
141 auto handleOtherResults = [&](Operation *op) {
142 for (Value result : op->getResults()) {
143 if (result.getType().isInteger(1) && !hasIndex(result))
144 addPrimaryInput(result);
145 }
146 };
147
// NOTE(review): this walks the block in its current order. It presumably
// relies on the block having been sorted first (e.g. by
// topologicallySortLogicNetwork). getOrCreateSignal is not visible here;
// confirm how it handles operands that are not yet indexed.
148 // Process operations in topological order
149 for (Operation &op : block->getOperations()) {
150 LogicalResult result =
151 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
152 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
153 const auto inputs = andOp.getInputs();
154 if (inputs.size() == 1) {
155 // Single-input AND is a buffer or NOT gate
156 const Signal inputSignal =
157 getOrCreateSignal(inputs[0], andOp.isInverted(0));
158 handleSingleInputGate(andOp, andOp.getResult(), inputSignal);
159 } else if (inputs.size() == 2) {
160 const Signal lhsSignal =
161 getOrCreateSignal(inputs[0], andOp.isInverted(0));
162 const Signal rhsSignal =
163 getOrCreateSignal(inputs[1], andOp.isInverted(1));
164 addGate(andOp, LogicNetworkGate::And2, {lhsSignal, rhsSignal});
165 } else {
166 // Variadic AND gates with >2 inputs are treated as primary
167 // inputs.
168 handleOtherResults(andOp);
169 }
170 return success();
171 })
172 .Case<comb::XorOp>([&](comb::XorOp xorOp) {
173 if (xorOp->getNumOperands() != 2) {
174 handleOtherResults(xorOp);
175 return success();
176 }
// comb.xor operands carry no inversion flags, so both edges are
// non-inverted.
177 const Signal lhsSignal =
178 getOrCreateSignal(xorOp.getOperand(0), false);
179 const Signal rhsSignal =
180 getOrCreateSignal(xorOp.getOperand(1), false);
181 addGate(xorOp, LogicNetworkGate::Xor2, {lhsSignal, rhsSignal});
182 return success();
183 })
184 .Case<synth::mig::MajorityInverterOp>(
185 [&](synth::mig::MajorityInverterOp majOp) {
186 if (majOp->getNumOperands() == 1) {
187 // Single input = inverter
188 const Signal inputSignal = getOrCreateSignal(
189 majOp.getOperand(0), majOp.isInverted(0));
190 handleSingleInputGate(majOp, majOp.getResult(),
191 inputSignal);
192 return success();
193 }
194 if (majOp->getNumOperands() != 3) {
195 // Variadic MAJ is treated as primary inputs.
196 handleOtherResults(majOp);
197 return success();
198 }
199 const Signal aSignal = getOrCreateSignal(majOp.getOperand(0),
200 majOp.isInverted(0));
201 const Signal bSignal = getOrCreateSignal(majOp.getOperand(1),
202 majOp.isInverted(1));
203 const Signal cSignal = getOrCreateSignal(majOp.getOperand(2),
204 majOp.isInverted(2));
// NOTE(review): the line that opens the gate-construction call taking
// the three signals below is missing from this rendering.
206 {aSignal, bSignal, cSignal});
207 return success();
208 })
209 .Case<hw::ConstantOp>([&](hw::ConstantOp constOp) {
210 Value result = constOp.getResult();
211 if (!result.getType().isInteger(1)) {
212 handleOtherResults(constOp);
213 return success();
214 }
// i1 constants create no gate: the result aliases one of the two
// reserved constant slots.
215 uint32_t constIdx =
216 constOp.getValue().isZero() ? kConstant0 : kConstant1;
217 valueToIndex[result] = constIdx;
218 return success();
219 })
220 .Default([&](Operation *defaultOp) {
221 handleOtherResults(defaultOp);
222 return success();
223 });
224
225 if (failed(result))
226 return result;
227 }
228
229 return success();
230}
231
// NOTE(review): the signature line of this definition is not visible in this
// rendering; the body resets the LogicNetwork.
// Drop all value/index mappings and gates. Then re-create the two reserved
// constant slots, so that index 0 is constant 0 and index 1 is constant 1.
// These are the kConstant0/kConstant1 slots that getValue maps to null.
233 valueToIndex.clear();
234 indexToValue.clear();
235 gates.clear();
236 // Re-add the constant nodes (index 0 = const0, index 1 = const1)
237 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
238 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
239 // Placeholders for constants in indexToValue
240 indexToValue.push_back(Value()); // const0
241 indexToValue.push_back(Value()); // const1
242}
243
244//===----------------------------------------------------------------------===//
245// Helper functions
246//===----------------------------------------------------------------------===//
247
248// Return true if the gate at the given index is always a cut input.
249static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index) {
250 const auto &gate = network.getGate(index);
251 return gate.isAlwaysCutInput();
252}
253
254// Return true if the new area/delay is better than the old area/delay in the
255// context of the given strategy.
257 ArrayRef<DelayType> newDelay, double oldArea,
258 ArrayRef<DelayType> oldDelay) {
// NOTE(review): the signature head and the strategy checks that guard the
// two returns below are not visible in this rendering. Presumably:
//  - the first return is for an area-driven strategy (area first, delay as
//    tie-breaker);
//  - the second is for a timing-driven strategy (delay first, area as
//    tie-breaker).
// Confirm against the full source.
// Delay vectors are compared with ArrayRef's relational operators
// (presumably lexicographic).
260 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
262 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
// Only reached when `strategy` matches none of the handled cases.
263 llvm_unreachable("Unknown mapping strategy");
264}
265
266LogicalResult circt::synth::topologicallySortLogicNetwork(Operation *topOp) {
267 const auto isOperationReady = [](Value value, Operation *op) -> bool {
268 // Topologically sort AIG ops, MIG ops, and dataflow ops. Other operations
269 // can be scheduled.
270 return !(isa<aig::AndInverterOp, mig::MajorityInverterOp>(op) ||
271 isa<comb::XorOp, comb::AndOp, comb::ExtractOp, comb::ReplicateOp,
272 comb::ConcatOp>(op));
273 };
274
275 if (failed(topologicallySortGraphRegionBlocks(topOp, isOperationReady)))
276 return emitError(topOp->getLoc(),
277 "failed to sort operations topologically");
278 return success();
279}
280
281/// Get the truth table for operations within a block.
282FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
283 Block *block) {
284 llvm::SmallSetVector<Value, 4> inputArgs;
285 for (Value arg : block->getArguments())
286 inputArgs.insert(arg);
287
288 if (inputArgs.empty())
289 return BinaryTruthTable();
290
291 const int64_t numInputs = inputArgs.size();
292 const int64_t numOutputs = values.size();
293 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
294 if (numOutputs == 0)
295 return BinaryTruthTable(numInputs, 0);
296 if (numInputs >= maxTruthTableInputs)
297 return mlir::emitError(values.front().getLoc(),
298 "Truth table is too large");
299 return mlir::emitError(values.front().getLoc(),
300 "Multiple outputs are not supported yet");
301 }
302
303 // Create a map to evaluate the operation
304 DenseMap<Value, APInt> eval;
305 for (uint32_t i = 0; i < numInputs; ++i)
306 eval[inputArgs[i]] = circt::createVarMask(numInputs, i, true);
307
308 // Simulate the operations in the block
309 for (Operation &op : *block) {
310 if (op.getNumResults() == 0)
311 continue;
312
313 // Support AIG, XOR, and MIG operations
314 if (auto andOp = dyn_cast<aig::AndInverterOp>(&op)) {
315 SmallVector<llvm::APInt, 2> inputs;
316 inputs.reserve(andOp.getInputs().size());
317 for (auto input : andOp.getInputs()) {
318 auto it = eval.find(input);
319 if (it == eval.end())
320 return andOp.emitError("Input value not found in evaluation map");
321 inputs.push_back(it->second);
322 }
323 eval[andOp.getResult()] = andOp.evaluate(inputs);
324 } else if (auto xorOp = dyn_cast<comb::XorOp>(&op)) {
325 auto it = eval.find(xorOp.getOperand(0));
326 if (it == eval.end())
327 return xorOp.emitError("Input value not found in evaluation map");
328 llvm::APInt result = it->second;
329 for (unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
330 it = eval.find(xorOp.getOperand(i));
331 if (it == eval.end())
332 return xorOp.emitError("Input value not found in evaluation map");
333 result ^= it->second;
334 }
335 eval[xorOp.getResult()] = result;
336 } else if (auto migOp = dyn_cast<synth::mig::MajorityInverterOp>(&op)) {
337 SmallVector<llvm::APInt, 3> inputs;
338 inputs.reserve(migOp.getInputs().size());
339 for (auto input : migOp.getInputs()) {
340 auto it = eval.find(input);
341 if (it == eval.end())
342 return migOp.emitError("Input value not found in evaluation map");
343 inputs.push_back(it->second);
344 }
345 eval[migOp.getResult()] = migOp.evaluate(inputs);
346 } else if (!isa<hw::OutputOp>(&op)) {
347 return op.emitError("Unsupported operation for truth table simulation");
348 }
349 }
350
351 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
352}
353
354//===----------------------------------------------------------------------===//
355// Cut
356//===----------------------------------------------------------------------===//
357
358bool Cut::isTrivialCut() const {
359 // A cut is a trivial cut if it has no root (rootIndex == 0 sentinel)
360 // and only one input
361 return rootIndex == 0 && inputs.size() == 1;
362}
363
// NOTE(review): the signature line of this definition is not visible here.
// Return the NPN canonical class of this cut. It is computed on first use and
// cached in `npnClass`. The cut's truth table must already have been computed,
// because it is dereferenced unconditionally below.
365 // If the NPN is already computed, return it
366 if (npnClass)
367 return *npnClass;
368
369 const auto &truthTable = *getTruthTable();
370
371 // Compute the NPN canonical form
372 auto canonicalForm = NPNClass::computeNPNCanonicalForm(truthTable);
373
374 npnClass.emplace(std::move(canonicalForm));
375 return *npnClass;
376}
377
379 const NPNClass &patternNPN,
380 SmallVectorImpl<unsigned> &permutedIndices) const {
381 auto npnClass = getNPNClass();
382 npnClass.getInputPermutation(patternNPN, permutedIndices);
383}
384
// Append the arrival time at each input of this cut to `results`, one entry
// per input in input order:
// - inputs that are always cut inputs contribute 0;
// - every other input takes the arrival time, for that value's result
//   number, of the pattern matched by the best cut of its cut set.
// Returns failure if some such input has no matched best cut.
// NOTE(review): one line of the signature (between the return type and the
// `results` parameter) is not visible in this rendering.
385LogicalResult
387 SmallVectorImpl<DelayType> &results) const {
388 results.reserve(getInputSize());
389 const auto &network = enumerator.getLogicNetwork();
390
391 // Compute arrival times for each input.
392 for (auto inputIndex : inputs) {
393 if (isAlwaysCutInput(network, inputIndex)) {
394 // If the input is a primary input, it has no delay.
395 results.push_back(0);
396 continue;
397 }
398 auto *cutSet = enumerator.getCutSet(inputIndex);
399 assert(cutSet && "Input must have a valid cut set");
400
401 // If there is no matching pattern, it means it's not possible to use the
402 // input in the cut rewriting. Return empty vector to indicate failure.
// NOTE(review): `results` is not cleared on this path, so entries already
// appended for earlier inputs remain.
403 auto *bestCut = cutSet->getBestMatchedCut();
404 if (!bestCut)
405 return failure();
406
407 const auto &matchedPattern = *bestCut->getMatchedPattern();
408
409 // Get the value for result number lookup
// Assumes every non-cut-input index maps to an OpResult; the cast asserts
// otherwise.
410 mlir::Value inputValue = network.getValue(inputIndex);
411 // Otherwise, the cut input is an op result. Get the arrival time
412 // from the matched pattern.
413 results.push_back(matchedPattern.getArrivalTime(
414 cast<mlir::OpResult>(inputValue).getResultNumber()));
415 }
416
417 return success();
418}
419
420void Cut::dump(llvm::raw_ostream &os, const LogicNetwork &network) const {
421 os << "// === Cut Dump ===\n";
422 os << "Cut with " << getInputSize() << " inputs";
423 if (rootIndex != 0) {
424 auto *rootOp = network.getGate(rootIndex).getOperation();
425 if (rootOp)
426 os << " and root: " << *rootOp;
427 }
428 os << "\n";
429
430 if (isTrivialCut()) {
431 mlir::Value inputVal = network.getValue(inputs[0]);
432 os << "Primary input cut: " << inputVal << "\n";
433 return;
434 }
435
436 os << "Inputs (indices): \n";
437 for (auto [idx, inputIndex] : llvm::enumerate(inputs)) {
438 mlir::Value inputVal = network.getValue(inputIndex);
439 os << " Input " << idx << " (index " << inputIndex << "): " << inputVal
440 << "\n";
441 }
442
443 if (rootIndex != 0) {
444 os << "\nRoot operation: \n";
445 if (auto *rootOp = network.getGate(rootIndex).getOperation())
446 rootOp->print(os);
447 os << "\n";
448 }
449
450 auto &npnClass = getNPNClass();
451 npnClass.dump(os);
452
453 os << "// === Cut End ===\n";
454}
455
456unsigned Cut::getInputSize() const { return inputs.size(); }
457
458unsigned Cut::getOutputSize(const LogicNetwork &network) const {
459 if (rootIndex == 0)
460 return 1; // Trivial cut has 1 output
461 auto *rootOp = network.getGate(rootIndex).getOperation();
462 return rootOp ? rootOp->getNumResults() : 1;
463}
464
465/// Simulate a gate and return its truth table.
// Truth-table semantics of a one-input gate: the output table is the input
// table unchanged. Any inversion is applied to the edge by the caller before
// this is invoked.
// NOTE(review): the `case` label of the handled kind is not visible in this
// rendering; presumably it is the Identity kind.
466static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
467 const llvm::APInt &a) {
468 switch (kind) {
470 return a;
471 default:
472 llvm_unreachable("Unsupported unary operation for truth table computation");
473 }
474}
475
// Truth-table semantics of a two-input gate, evaluated bitwise over whole
// truth tables: one handled kind is AND, the other XOR.
// NOTE(review): the `case` labels are not visible in this rendering;
// presumably And2 and Xor2 respectively.
476static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
477 const llvm::APInt &a,
478 const llvm::APInt &b) {
479 switch (kind) {
481 return a & b;
483 return a ^ b;
484 default:
485 llvm_unreachable(
486 "Unsupported binary operation for truth table computation");
487 }
488}
489
// Truth-table semantics of a three-input gate: bitwise majority of the three
// tables, i.e. a bit is set when at least two of the inputs are set.
// NOTE(review): the `case` label is not visible in this rendering;
// presumably the three-input majority kind.
490static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
491 const llvm::APInt &a,
492 const llvm::APInt &b,
493 const llvm::APInt &c) {
494 switch (kind) {
496 return (a & b) | (a & c) | (b & c);
497 default:
498 llvm_unreachable(
499 "Unsupported ternary operation for truth table computation");
500 }
501}
502
503/// Simulate a gate and return its truth table.
// Recursively compute the truth table of the gate at `index`, as a function
// of the cut inputs, memoizing every result in `cache`.
// - Cut inputs must be pre-seeded in `cache` with their variable masks.
//   Reaching the primary-input branch without a cached entry is treated as
//   unreachable.
// - A fanin reached through an inverted edge contributes the complement of
//   its table.
// - Constant tables are 2^numInputs bits wide.
// NOTE(review): the `case` labels of the switch are not visible in this
// rendering; the surviving comments indicate constant, primary-input,
// two-input, three-input and one-input groups, in that order.
504static llvm::APInt simulateGate(const LogicNetwork &network, uint32_t index,
505 llvm::DenseMap<uint32_t, llvm::APInt> &cache,
506 unsigned numInputs) {
507 // Check cache first
508 auto cacheIt = cache.find(index);
509 if (cacheIt != cache.end())
510 return cacheIt->second;
511
512 const auto &gate = network.getGate(index);
513 llvm::APInt result;
514
// Truth table seen through an edge: the fanin's table, complemented when the
// edge is inverted.
515 auto getEdgeTT = [&](const Signal &edge) {
516 auto tt = simulateGate(network, edge.getIndex(), cache, numInputs);
517 if (edge.isInverted())
518 tt.flipAllBits();
519 return tt;
520 };
521
522 switch (gate.getKind()) {
524 // Constant 0 or 1 - return all zeros or all ones
525 if (index == LogicNetwork::kConstant0)
526 result = llvm::APInt::getZero(1U << numInputs);
527 else
528 result = llvm::APInt::getAllOnes(1U << numInputs);
529 break;
530 }
531
533 // Should be in cache already as cut input
534 llvm_unreachable("Primary input not in cache - not a cut input?");
535
538 result = applyGateSemantics(gate.getKind(), getEdgeTT(gate.edges[0]),
539 getEdgeTT(gate.edges[1]));
540 break;
541 }
542
544 result =
545 applyGateSemantics(gate.getKind(), getEdgeTT(gate.edges[0]),
546 getEdgeTT(gate.edges[1]), getEdgeTT(gate.edges[2]));
547 break;
548 }
549
551 result = applyGateSemantics(gate.getKind(), getEdgeTT(gate.edges[0]));
552 break;
553 }
554 }
555
556 cache[index] = result;
557 return result;
558}
559
// NOTE(review): the signature line of this definition is not visible in this
// rendering; getAsTrivialCut calls it as computeTruthTable(network).
// Compute and store this cut's truth table:
// - a trivial cut gets the one-variable identity table;
// - otherwise the cone rooted at `rootIndex` is simulated, with input i
//   bound to the i-th variable mask.
// NOTE(review): the input-count limit is enforced with llvm_unreachable,
// which is undefined behaviour rather than a checked error in release builds.
561 if (isTrivialCut()) {
562 // For a trivial cut, a truth table is simply the identity function.
563 // 0 -> 0, 1 -> 1
564 truthTable.emplace(1, 1, llvm::APInt(2, 2));
565 return;
566 }
567
568 unsigned numInputs = inputs.size();
569 if (numInputs >= maxTruthTableInputs) {
570 llvm_unreachable("Too many inputs for truth table computation");
571 }
572
573 // Initialize cache with input variable masks
574 llvm::DenseMap<uint32_t, llvm::APInt> cache;
575 for (unsigned i = 0; i < numInputs; ++i) {
576 cache[inputs[i]] = circt::createVarMask(numInputs, i, true);
577 }
578
579 // Simulate from root
580 llvm::APInt result = simulateGate(network, rootIndex, cache, numInputs);
581
582 truthTable.emplace(numInputs, 1, result);
583}
584
// NOTE(review): the signature line is not visible in this rendering;
// CutSet::finalize calls this as computeTruthTableFromOperands(logicNetwork).
// Despite its name, the visible body does not use the stored operand cuts. It
// delegates to computeTruthTable, which re-simulates the cone from the root.
586 computeTruthTable(network);
587}
588
589bool Cut::dominates(const Cut &other) const { return dominates(other.inputs); }
590
591bool Cut::dominates(ArrayRef<uint32_t> otherInputs) const {
592 if (getInputSize() > otherInputs.size())
593 return false;
594
595 return std::includes(otherInputs.begin(), otherInputs.end(), inputs.begin(),
596 inputs.end());
597}
598
599static Cut getAsTrivialCut(uint32_t index, const LogicNetwork &network) {
600 // Create a trivial cut for a value
601 Cut cut;
602 cut.inputs.push_back(index);
603 // Compute truth table eagerly for trivial cut
604 cut.computeTruthTable(network);
605 return cut;
606}
607
608//===----------------------------------------------------------------------===//
609// MatchedPattern
610//===----------------------------------------------------------------------===//
611
612ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
613 assert(pattern && "Pattern must be set to get arrival time");
614 return arrivalTimes;
615}
616
// NOTE(review): the signature line is not visible in this rendering.
// Arrival time recorded at position `index`. Callers pass an output's result
// number. `index` is not range-checked here beyond whatever the underlying
// container does.
618 assert(pattern && "Pattern must be set to get arrival time");
619 return arrivalTimes[index];
620}
621
// NOTE(review): the signature line is not visible in this rendering.
// Return the rewrite pattern that matched. It must be non-null; this is
// asserted in debug builds.
623 assert(pattern && "Pattern must be set to get the pattern");
624 return pattern;
625}
626
// NOTE(review): the signature line is not visible in this rendering.
// Area recorded for the matched pattern. A pattern must have been set; this is
// asserted in debug builds.
628 assert(pattern && "Pattern must be set to get area");
629 return area;
630}
631
632//===----------------------------------------------------------------------===//
633// CutSet
634//===----------------------------------------------------------------------===//
635
637
638unsigned CutSet::size() const { return cuts.size(); }
639
640void CutSet::addCut(Cut *cut) {
641 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
642 cuts.push_back(cut);
643}
644
645ArrayRef<Cut *> CutSet::getCuts() const { return cuts; }
646
647// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
648// exists another cut that is a subset of it.
649static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut *> &cuts) {
650 auto dumpInputs = [](llvm::raw_ostream &os,
651 const llvm::SmallVectorImpl<uint32_t> &inputs) {
652 os << "{";
653 llvm::interleaveComma(inputs, os);
654 os << "}";
655 };
656 // Sort by size, then lexicographically by inputs. This enables cheap exact
657 // duplicate elimination and tighter candidate filtering for subset checks.
658 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut *a, const Cut *b) {
659 if (a->getInputSize() != b->getInputSize())
660 return a->getInputSize() < b->getInputSize();
661 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
662 b->inputs.begin(), b->inputs.end());
663 });
664
665 // Group kept cuts by input size so subset checks only visit smaller cuts.
666 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
667 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
668
669 // Compact kept cuts in-place.
670 unsigned uniqueCount = 0;
671 for (Cut *cut : cuts) {
672 unsigned cutSize = cut->getInputSize();
673
674 // Fast exact duplicate check: with lexicographic sort, duplicates are
675 // adjacent among cuts with equal size.
676 if (uniqueCount > 0) {
677 Cut *lastKept = cuts[uniqueCount - 1];
678 if (lastKept->getInputSize() == cutSize &&
679 lastKept->inputs == cut->inputs)
680 continue;
681 }
682
683 bool isDominated = false;
684 for (unsigned existingSize = 1; existingSize < cutSize && !isDominated;
685 ++existingSize) {
686 for (const Cut *existingCut : keptBySize[existingSize]) {
687 if (!existingCut->dominates(*cut))
688 continue;
689
690 LLVM_DEBUG({
691 llvm::dbgs() << "Dropping non-minimal cut ";
692 dumpInputs(llvm::dbgs(), cut->inputs);
693 llvm::dbgs() << " due to subset ";
694 dumpInputs(llvm::dbgs(), existingCut->inputs);
695 llvm::dbgs() << "\n";
696 });
697 isDominated = true;
698 break;
699 }
700 }
701
702 if (isDominated)
703 continue;
704
705 cuts[uniqueCount++] = cut;
706 keptBySize[cutSize].push_back(cut);
707 }
708
709 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
710 << " Unique cuts: " << uniqueCount << "\n");
711
712 // Resize the cuts vector to the number of surviving cuts.
713 cuts.resize(uniqueCount);
714}
715
717 const CutRewriterOptions &options,
718 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
719 const LogicNetwork &logicNetwork) {
// Finalize this cut set, in order:
// 1. prune duplicate and non-minimal cuts;
// 2. make sure every cut has a truth table;
// 3. try to match each cut through `matchCut`;
// 4. order the cuts: trivial cuts first, then the rest by `isBetterCut`;
// 5. keep at most options.maxCutSizePerRoot cuts;
// 6. set `bestCut` to the first cut that has a matched pattern;
// 7. freeze the set so no more cuts can be added.
// NOTE(review): the signature head and the pruning statement that follows
// the comment below are not visible in this rendering. The pruning call is
// presumably removeDuplicateAndNonMinimalCuts(cuts).
720
721 // Remove duplicate/non-minimal cuts first so all follow-up work only runs on
722 // survivors.
724
725 // Compute truth tables lazily, then match cuts to collect timing/area data.
726 for (Cut *cut : cuts) {
727 if (!cut->getTruthTable().has_value())
728 cut->computeTruthTableFromOperands(logicNetwork);
729
730 assert(cut->getInputSize() <= options.maxCutInputSize &&
731 "Cut input size exceeds maximum allowed size");
732
733 if (auto matched = matchCut(*cut))
734 cut->setMatchedPattern(std::move(*matched));
735 }
736
737 // Sort cuts by priority to select the most promising ones.
738 // Priority is determined by the optimization strategy:
739 // - Trivial cuts (direct connections) have highest priority
740 // - Among matched cuts, compare by area/delay based on the strategy
741 // - Matched cuts are preferred over unmatched cuts
742 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
743 // et al., ICCAD 2007 for more details.
744 // TODO: Use a priority queue instead of sorting for better performance.
745
746 // Partition the cuts into trivial and non-trivial cuts.
747 auto *trivialCutsEnd =
748 std::stable_partition(cuts.begin(), cuts.end(),
749 [](const Cut *cut) { return cut->isTrivialCut(); });
750
// Strict ordering on non-trivial cuts:
// - two matched cuts compare by compareDelayAndArea for the configured
//   strategy;
// - a matched cut beats an unmatched one;
// - two unmatched cuts: fewer inputs wins.
751 auto isBetterCut = [&options](const Cut *a, const Cut *b) {
752 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
753 "Trivial cuts should have been excluded");
754 const auto &aMatched = a->getMatchedPattern();
755 const auto &bMatched = b->getMatchedPattern();
756
757 if (aMatched && bMatched)
758 return compareDelayAndArea(
759 options.strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
760 bMatched->getArea(), bMatched->getArrivalTimes());
761
762 if (static_cast<bool>(aMatched) != static_cast<bool>(bMatched))
763 return static_cast<bool>(aMatched);
764
765 return a->getInputSize() < b->getInputSize();
766 };
767 std::stable_sort(trivialCutsEnd, cuts.end(), isBetterCut);
768
// Trivial cuts sit at the front, so truncation removes the lowest-priority
// non-trivial cuts first.
769 // Keep only the top-K cuts to bound growth.
770 if (cuts.size() > options.maxCutSizePerRoot)
771 cuts.resize(options.maxCutSizePerRoot);
772
// NOTE(review): this scan also covers the trivial cuts at the front, so a
// trivial cut with a matched pattern would be chosen as bestCut.
773 // Select the best cut from the remaining candidates.
774 bestCut = nullptr;
775 for (Cut *cut : cuts) {
776 const auto &currentMatch = cut->getMatchedPattern();
777 if (!currentMatch)
778 continue;
779 bestCut = cut;
780 break;
781 }
782
783 LLVM_DEBUG({
784 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
785 << (bestCut
786 ? "matched pattern to " + bestCut->getMatchedPattern()
787 ->getPattern()
788 ->getPatternName()
789 : "no matched pattern")
790 << "\n";
791 });
792
793 isFrozen = true; // Mark the cut set as frozen
794}
795
796//===----------------------------------------------------------------------===//
797// CutRewritePattern
798//===----------------------------------------------------------------------===//
799
801 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
// Default implementation: report that this pattern provides no NPN classes
// and leave `matchingNPNClasses` untouched. CutRewritePatternSet then places
// the pattern in its list of patterns that are not looked up by NPN class.
// NOTE(review): the first line of the signature is not visible in this
// rendering.
802 return false;
803}
804
805//===----------------------------------------------------------------------===//
806// CutRewritePatternSet
807//===----------------------------------------------------------------------===//
808
810 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
811 : patterns(std::move(patterns)) {
// Take ownership of `patterns` and index them for lookup:
// - patterns whose useTruthTableMatcher() returns true are stored in
//   npnToPatternMap under each reported class's (truth table, input count)
//   key, paired with that class;
// - all other patterns are collected in nonNPNPatterns.
// NOTE(review): the signature's first line is not visible in this rendering.
812 // Initialize the NPN to pattern map
813 for (auto &pattern : this->patterns) {
814 SmallVector<NPNClass, 2> npnClasses;
815 auto result = pattern->useTruthTableMatcher(npnClasses);
816 if (result) {
// NOTE(review): `auto npnClass` copies each class before the copy is moved
// into the map; iterating by reference would avoid that copy.
817 for (auto npnClass : npnClasses) {
818 // Create a NPN class from the truth table
819 npnToPatternMap[{npnClass.truthTable.table,
820 npnClass.truthTable.numInputs}]
821 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
822 }
823 } else {
824 // If the pattern does not provide NPN classes, we use a special key
825 // to indicate that it should be considered for all cuts.
// NOTE(review): no special map key is used in the visible code; such
// patterns are simply appended to nonNPNPatterns.
826 nonNPNPatterns.push_back(pattern.get());
827 }
828 }
829}
830
831//===----------------------------------------------------------------------===//
832// CutEnumerator
833//===----------------------------------------------------------------------===//
834
// NOTE(review): the constructor's signature line is not visible in this
// rendering. The visible part only stores the given rewriter options.
836 : options(options) {}
837
// NOTE(review): the signature line is not visible in this rendering.
// Construct an empty CutSet in the enumerator's cut-set allocator, register it
// for `index`, and return the registered set. Each index may receive only one
// cut set; this is asserted in debug builds. In release builds a repeated index
// returns the previously registered set, and the new allocation stays unused
// until the allocator is cleared.
839 CutSet *cutSet = new (cutSetAllocator.Allocate()) CutSet();
840 auto [cutSetPtr, inserted] = cutSets.try_emplace(index, cutSet);
841 assert(inserted && "Cut set already exists for this index");
842 return cutSetPtr->second;
843}
844
// NOTE(review): the signature line and one statement between the two groups
// below are not visible in this rendering.
// Reset enumeration state:
// - forget the per-node cut sets and the processing order;
// - destroy every cut and cut set allocated from the enumerator's allocators.
846 cutSets.clear();
847 processingOrder.clear();
849 cutAllocator.DestroyAll();
850 cutSetAllocator.DestroyAll();
851}
852
// Enumerate the cuts of the logic gate at `nodeIndex`.
// Emits an error and fails when:
// - the gate has more than 3 fanins;
// - its result is not i1;
// - a fanin has no cut set yet (so fanins must be visited first).
// Otherwise it:
// - creates the node's cut set, seeded with the node's trivial cut;
// - adds one merged cut per combination of fanin cuts, skipping combinations
//   whose merged input count exceeds options.maxCutInputSize;
// - finalizes the set when the function returns (via scope_exit).
853LogicalResult CutEnumerator::visitLogicOp(uint32_t nodeIndex) {
854 const auto &gate = logicNetwork.getGate(nodeIndex);
855 auto *logicOp = gate.getOperation();
856 assert(logicOp && logicOp->getNumResults() == 1 &&
857 "Logic operation must have a single result");
858
859 unsigned numFanins = gate.getNumFanins();
860
861 // Validate operation constraints
862 // TODO: Variadic operations and non-single-bit results can be supported
// NOTE(review): numFanins == 0 is not rejected here. With zero fanins the
// merge below would take the three-way branch and index an empty cutPtrs.
// Presumably such gates are never visited; confirm with the caller.
863 if (numFanins > 3)
864 return logicOp->emitError("Cut enumeration supports at most 3 operands, "
865 "found: ")
866 << numFanins;
867 if (!logicOp->getOpResult(0).getType().isInteger(1))
868 return logicOp->emitError()
869 << "Supported logic operations must have a single bit "
870 "result type but found: "
871 << logicOp->getResult(0).getType();
872
873 SmallVector<const CutSet *, 2> operandCutSets;
874 operandCutSets.reserve(numFanins);
875
876 // Collect cut sets for each fanin (using LogicNetwork edges)
877 for (unsigned i = 0; i < numFanins; ++i) {
878 uint32_t faninIndex = gate.edges[i].getIndex();
879 auto *operandCutSet = getCutSet(faninIndex);
880 if (!operandCutSet)
881 return logicOp->emitError("Failed to get cut set for fanin index ")
882 << faninIndex;
883 operandCutSets.push_back(operandCutSet);
884 }
885
886 // Create the trivial cut for this node's output
887 Cut *primaryInputCut = new (cutAllocator.Allocate())
888 Cut(getAsTrivialCut(nodeIndex, logicNetwork));
889
890 auto *resultCutSet = createNewCutSet(nodeIndex);
891 processingOrder.push_back(nodeIndex);
892 resultCutSet->addCut(primaryInputCut);
893
// Finalization runs on every return from here on, including the early
// returns inside the enumeration lambda below.
894 // Schedule cut set finalization when exiting this scope
895 llvm::scope_exit prune([&]() {
896 // Finalize cut set: remove duplicates, limit size, and match patterns
897 resultCutSet->finalize(options, matchCut, logicNetwork);
898 });
899
900 // Cache maxCutInputSize to avoid repeated access
901 unsigned maxInputSize = options.maxCutInputSize;
902
903 // This lambda generates nested loops at runtime to iterate over all
904 // combinations of cuts from N operands
905 auto enumerateCutCombinations =
906 [&](auto &&self, unsigned operandIdx,
907 SmallVector<const Cut *, 3> &cutPtrs) -> void {
908 // Base case: all operands processed, create merged cut
909 if (operandIdx == numFanins) {
910 // Efficient k-way merge: inputs are sorted, so dedup and constant
911 // filtering can be done while merging. Abort early once we exceed the
912 // cut-size limit to avoid building doomed merged cuts.
913 SmallVector<uint32_t, 6> mergedInputs;
// Append `value` to the merged inputs, skipping constants and repeats.
// Returns false once the merged cut exceeds maxInputSize, so the caller
// abandons this combination.
// NOTE(review): the second half of the constant test is missing from this
// rendering; presumably it checks LogicNetwork::kConstant1.
// If every input is a constant, the merged cut ends up with zero inputs.
914 auto appendMergedInput = [&](uint32_t value) {
915 if (value == LogicNetwork::kConstant0 ||
917 return true;
918 if (!mergedInputs.empty() && mergedInputs.back() == value)
919 return true;
920 mergedInputs.push_back(value);
921 return mergedInputs.size() <= maxInputSize;
922 };
923
924 if (numFanins == 1) {
925 // Single input: copy while filtering constants.
926 mergedInputs.reserve(
927 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
928 for (uint32_t value : cutPtrs[0]->inputs)
929 if (!appendMergedInput(value))
930 return;
931 } else if (numFanins == 2) {
932 // Two-way merge (common case for AND gates)
933 const auto &inputs0 = cutPtrs[0]->inputs;
934 const auto &inputs1 = cutPtrs[1]->inputs;
935 mergedInputs.reserve(
936 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
937
// Standard sorted merge; when both heads are equal, both advance so the
// shared input is emitted once.
938 unsigned i = 0, j = 0;
939 while (i < inputs0.size() || j < inputs1.size()) {
940 uint32_t next;
941 if (j == inputs1.size() ||
942 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
943 next = inputs0[i++];
944 if (j < inputs1.size() && inputs1[j] == next)
945 ++j;
946 } else {
947 next = inputs1[j++];
948 }
949 if (!appendMergedInput(next))
950 return;
951 }
952 } else {
953 // Three-way merge (for MAJ/MUX gates)
954 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
955 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
956 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
957 mergedInputs.reserve(std::min<size_t>(
958 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
959
960 unsigned i = 0, j = 0, k = 0;
961 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
962 // Find minimum among available elements
963 uint32_t minVal = UINT32_MAX;
964 if (i < inputs0.size())
965 minVal = std::min(minVal, inputs0[i]);
966 if (j < inputs1.size())
967 minVal = std::min(minVal, inputs1[j]);
968 if (k < inputs2.size())
969 minVal = std::min(minVal, inputs2[k]);
970
971 // Advance all iterators pointing to minVal (handles duplicates)
972 if (i < inputs0.size() && inputs0[i] == minVal)
973 i++;
974 if (j < inputs1.size() && inputs1[j] == minVal)
975 j++;
976 if (k < inputs2.size() && inputs2[k] == minVal)
977 k++;
978
979 if (!appendMergedInput(minVal))
980 return;
981 }
982 }
983
984 // Create the merged cut.
985 Cut *mergedCut = new (cutAllocator.Allocate()) Cut();
986 mergedCut->setRootIndex(nodeIndex);
987 mergedCut->inputs = std::move(mergedInputs);
988
989 // Store operand cuts for lazy truth table computation using fast
990 // incremental method (after duplicate removal in finalize)
991 mergedCut->setOperandCuts(cutPtrs);
992 resultCutSet->addCut(mergedCut);
993
// `logicOp` is non-null here (asserted above); the check below is only
// defensive.
994 LLVM_DEBUG({
995 if (mergedCut->inputs.size() >= 4) {
996 llvm::dbgs() << "Generated cut for node " << nodeIndex;
997 if (logicOp)
998 llvm::dbgs() << " (" << logicOp->getName() << ")";
999 llvm::dbgs() << " inputs=";
1000 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1001 llvm::dbgs() << "\n";
1002 }
1003 });
1004 return;
1005 }
1006
1007 // Recursive case: iterate over cuts for current operand
1008 const CutSet *currentCutSet = operandCutSets[operandIdx];
1009 for (const Cut *cut : currentCutSet->getCuts()) {
1010 cutPtrs.push_back(cut);
1011
1012 // Recurse to next operand
1013 self(self, operandIdx + 1, cutPtrs);
1014
1015 cutPtrs.pop_back();
1016 }
1017 };
1018
1019 // Start recursion with an empty cut pointer list.
1020 SmallVector<const Cut *, 3> cutPtrs;
1021 cutPtrs.reserve(numFanins);
1022 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs);
1023
1024 return success();
1025}
1026
1028 Operation *topOp,
1029 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
1030 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
1031 << "\n");
1032 // Topologically sort the logic network
1033 if (failed(topologicallySortLogicNetwork(topOp)))
1034 return failure();
1035
1036 // Store the pattern matching function for use during cut finalization
1037 this->matchCut = matchCut;
1038
1039 // Build the flat logic network representation for efficient simulation
1040 auto &block = topOp->getRegion(0).getBlocks().front();
1041 if (failed(logicNetwork.buildFromBlock(&block)))
1042 return failure();
1043
1044 for (const auto &[index, gate] : llvm::enumerate(logicNetwork.getGates())) {
1045 // Skip non-logic gates.
1046 if (!gate.isLogicGate())
1047 continue;
1048
1049 // Ensure cut set exists for each logic gate
1050 if (failed(visitLogicOp(index)))
1051 return failure();
1052 }
1053
1054 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
1055 return success();
1056}
1057
1058const CutSet *CutEnumerator::getCutSet(uint32_t index) {
1059 // Check if cut set already exists
1060 auto it = cutSets.find(index);
1061 if (it == cutSets.end()) {
1062 // Create new cut set for an unprocessed value (primary input or other)
1063 CutSet *cutSet = new (cutSetAllocator.Allocate()) CutSet();
1064 Cut *trivialCut =
1065 new (cutAllocator.Allocate()) Cut(getAsTrivialCut(index, logicNetwork));
1066 cutSet->addCut(trivialCut);
1067 auto [newIt, inserted] = cutSets.insert({index, cutSet});
1068 assert(inserted && "Cut set already exists for this index");
1069 it = newIt;
1070 }
1071
1072 return it->second;
1073}
1074
1075/// Generate a human-readable name for a value used in test output.
1076/// This function creates meaningful names for values to make debug output
1077/// and test results more readable and understandable.
1078static StringRef
1079getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
1080 if (auto *op = value.getDefiningOp()) {
1081 // Handle values defined by operations
1082 // First, check if the operation already has a name hint attribute
1083 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
1084 return name.getValue();
1085
1086 // For single-result operations, generate a unique name based on operation
1087 // type
1088 if (op->getNumResults() == 1) {
1089 auto opName = op->getName();
1090 auto count = opCounter[opName]++;
1091
1092 // Create a unique name by appending a counter to the operation name
1093 SmallString<16> nameStr;
1094 nameStr += opName.getStringRef();
1095 nameStr += "_";
1096 nameStr += std::to_string(count);
1097
1098 // Store the generated name as a hint attribute for future reference
1099 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1100 op->setAttr("sv.namehint", nameAttr);
1101 return nameAttr;
1102 }
1103
1104 // Multi-result operations or other cases get a generic name
1105 return "<unknown>";
1106 }
1107
1108 // Handle block arguments
1109 auto blockArg = cast<BlockArgument>(value);
1110 auto hwOp =
1111 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1112 if (!hwOp)
1113 return "<unknown>";
1114
1115 // Return the formal input name from the hardware module
1116 return hwOp.getInputName(blockArg.getArgNumber());
1117}
1118
1120 DenseMap<OperationName, unsigned> opCounter;
1121 for (auto index : processingOrder) {
1122 auto it = cutSets.find(index);
1123 if (it == cutSets.end())
1124 continue;
1125 auto &cutSet = *it->second;
1126 mlir::Value value = logicNetwork.getValue(index);
1127 llvm::outs() << getTestVariableName(value, opCounter) << " "
1128 << cutSet.getCuts().size() << " cuts:";
1129 for (const Cut *cut : cutSet.getCuts()) {
1130 llvm::outs() << " {";
1131 llvm::interleaveComma(cut->inputs, llvm::outs(), [&](uint32_t inputIdx) {
1132 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1133 llvm::outs() << getTestVariableName(inputVal, opCounter);
1134 });
1135 auto &pattern = cut->getMatchedPattern();
1136 llvm::outs() << "}"
1137 << "@t" << cut->getTruthTable()->table.getZExtValue() << "d";
1138 if (pattern) {
1139 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
1140 pattern->getArrivalTimes().end());
1141 } else {
1142 llvm::outs() << "0";
1143 }
1144 }
1145 llvm::outs() << "\n";
1146 }
1147 llvm::outs() << "Cut enumeration completed successfully\n";
1148}
1149
1150//===----------------------------------------------------------------------===//
1151// CutRewriter
1152//===----------------------------------------------------------------------===//
1153
1154LogicalResult CutRewriter::run(Operation *topOp) {
1155 LLVM_DEBUG({
1156 llvm::dbgs() << "Starting Cut Rewriter\n";
1157 llvm::dbgs() << "Mode: "
1159 : "timing")
1160 << "\n";
1161 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
1162 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
1163 });
1164
1165 // Currently we don't support patterns with multiple outputs.
1166 // So check that.
1167 // TODO: This must be removed when we support multiple outputs.
1168 for (auto &pattern : patterns.patterns) {
1169 if (pattern->getNumOutputs() > 1) {
1170 return mlir::emitError(pattern->getLoc(),
1171 "Cut rewriter does not support patterns with "
1172 "multiple outputs yet");
1173 }
1174 }
1175
1176 // First sort the operations topologically to ensure we can process them
1177 // in a valid order.
1178 if (failed(topologicallySortLogicNetwork(topOp)))
1179 return failure();
1180
1181 // Enumerate cuts for all nodes (initial delay-oriented selection)
1182 if (failed(enumerateCuts(topOp)))
1183 return failure();
1184
1185 // Dump cuts if testing priority cuts.
1188 return success();
1189 }
1190
1191 // Select best cuts and perform mapping
1192 if (failed(runBottomUpRewrite(topOp)))
1193 return failure();
1194
1195 return success();
1196}
1197
1198LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
1199 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
1200
1202 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
1203 // Match the cut against the patterns
1204 return patternMatchCut(cut);
1205 });
1206}
1207
1208ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1210 if (patterns.npnToPatternMap.empty())
1211 return {};
1212
1213 auto &npnClass = cut.getNPNClass();
1214 auto it = patterns.npnToPatternMap.find(
1215 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1216 if (it == patterns.npnToPatternMap.end())
1217 return {};
1218 return it->getSecond();
1219}
1220
1221std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
1222 if (cut.isTrivialCut())
1223 return {};
1224
1225 const auto &network = cutEnumerator.getLogicNetwork();
1226 const CutRewritePattern *bestPattern = nullptr;
1227 SmallVector<DelayType, 4> inputArrivalTimes;
1228 SmallVector<DelayType, 1> bestArrivalTimes;
1229 double bestArea = 0.0;
1230 inputArrivalTimes.reserve(cut.getInputSize());
1231 bestArrivalTimes.reserve(cut.getOutputSize(network));
1232
1233 // Compute arrival times for each input.
1234 if (failed(cut.getInputArrivalTimes(cutEnumerator, inputArrivalTimes)))
1235 return {};
1236
1237 auto computeArrivalTimeAndPickBest =
1238 [&](const CutRewritePattern *pattern, const MatchResult &matchResult,
1239 llvm::function_ref<unsigned(unsigned)> mapIndex) {
1240 SmallVector<DelayType, 1> outputArrivalTimes;
1241 // Compute the maximum delay for each output from inputs.
1242 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize(network);
1243 outputIndex < outputSize; ++outputIndex) {
1244 // Compute the arrival time for this output.
1245 DelayType outputArrivalTime = 0;
1246 auto delays = matchResult.getDelays();
1247 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
1248 inputIndex < inputSize; ++inputIndex) {
1249 // Map pattern input i to cut input through NPN transformations
1250 unsigned cutOriginalInput = mapIndex(inputIndex);
1251 outputArrivalTime =
1252 std::max(outputArrivalTime,
1253 delays[outputIndex * inputSize + inputIndex] +
1254 inputArrivalTimes[cutOriginalInput]);
1255 }
1256
1257 outputArrivalTimes.push_back(outputArrivalTime);
1258 }
1259
1260 // Update the arrival time
1261 if (!bestPattern ||
1262 compareDelayAndArea(options.strategy, matchResult.area,
1263 outputArrivalTimes, bestArea,
1264 bestArrivalTimes)) {
1265 LLVM_DEBUG({
1266 llvm::dbgs() << "== Matched Pattern ==============\n";
1267 llvm::dbgs() << "Matching cut: \n";
1268 cut.dump(llvm::dbgs(), network);
1269 llvm::dbgs() << "Found better pattern: "
1270 << pattern->getPatternName();
1271 llvm::dbgs() << " with area: " << matchResult.area;
1272 llvm::dbgs() << " and input arrival times: ";
1273 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1274 llvm::dbgs() << " " << inputArrivalTimes[i];
1275 }
1276 llvm::dbgs() << " and arrival times: ";
1277
1278 for (auto arrivalTime : outputArrivalTimes) {
1279 llvm::dbgs() << " " << arrivalTime;
1280 }
1281 llvm::dbgs() << "\n";
1282 llvm::dbgs() << "== Matched Pattern End ==============\n";
1283 });
1284
1285 bestArrivalTimes = std::move(outputArrivalTimes);
1286 bestArea = matchResult.area;
1287 bestPattern = pattern;
1288 }
1289 };
1290
1291 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1292 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1293 "Pattern input size must match cut input size");
1294 auto matchResult = pattern->match(cutEnumerator, cut);
1295 if (!matchResult)
1296 continue;
1297 auto &cutNPN = cut.getNPNClass();
1298
1299 // Get the input mapping from pattern's NPN class to cut's NPN class
1300 SmallVector<unsigned> inputMapping;
1301 cutNPN.getInputPermutation(patternNPN, inputMapping);
1302 computeArrivalTimeAndPickBest(pattern, *matchResult,
1303 [&](unsigned i) { return inputMapping[i]; });
1304 }
1305
1306 for (const CutRewritePattern *pattern : patterns.nonNPNPatterns) {
1307 if (auto matchResult = pattern->match(cutEnumerator, cut))
1308 computeArrivalTimeAndPickBest(pattern, *matchResult,
1309 [&](unsigned i) { return i; });
1310 }
1311
1312 if (!bestPattern)
1313 return {}; // No matching pattern found
1314
1315 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1316}
1317
1318LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
1319 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
1320 const auto &network = cutEnumerator.getLogicNetwork();
1321 const auto &cutSets = cutEnumerator.getCutSets();
1322 auto processingOrder = cutEnumerator.getProcessingOrder();
1323
1324 // Note: Don't clear cutEnumerator yet - we need it during rewrite
1325 UnusedOpPruner pruner;
1326 PatternRewriter rewriter(top->getContext());
1327
1328 // Process in reverse topological order
1329 for (auto index : llvm::reverse(processingOrder)) {
1330 auto it = cutSets.find(index);
1331 if (it == cutSets.end())
1332 continue;
1333
1334 mlir::Value value = network.getValue(index);
1335 auto &cutSet = *it->second;
1336
1337 if (value.use_empty()) {
1338 if (auto *op = value.getDefiningOp())
1339 pruner.eraseNow(op);
1340 continue;
1341 }
1342
1343 if (isAlwaysCutInput(network, index)) {
1344 // If the value is a primary input, skip it
1345 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
1346 continue;
1347 }
1348
1349 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
1350 auto *bestCut = cutSet.getBestMatchedCut();
1351 if (!bestCut) {
1353 continue; // No matching pattern found, skip this value
1354 return emitError(value.getLoc(), "No matching cut found for value: ")
1355 << value;
1356 }
1357
1358 // Get the root operation from LogicNetwork
1359 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1360 rewriter.setInsertionPoint(rootOp);
1361 const auto &matchedPattern = bestCut->getMatchedPattern();
1362 auto result = matchedPattern->getPattern()->rewrite(rewriter, cutEnumerator,
1363 *bestCut);
1364 if (failed(result))
1365 return failure();
1366
1367 rewriter.replaceOp(rootOp, *result);
1368
1370 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1371 (*result)->setAttr("test.arrival_times", array);
1372 }
1373 }
1374
1375 // Clear the enumerator after rewriting is complete
1377 return success();
1378}
assert(baseType &&"element must be base type")
static llvm::APInt simulateGate(const LogicNetwork &network, uint32_t index, llvm::DenseMap< uint32_t, llvm::APInt > &cache, unsigned numInputs)
Simulate a gate and return its truth table.
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static Cut getAsTrivialCut(uint32_t index, const LogicNetwork &network)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Strategy strategy
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
llvm::SpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
llvm::SpecificBumpPtrAllocator< CutSet > cutSetAllocator
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
void getPermutatedInputIndices(const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
void computeTruthTable(const LogicNetwork &network)
Compute and cache the truth table for this cut using the LogicNetwork.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:39
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:313
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:44
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
Topologically sort the logic network operations within the given operation so they can be processed in a valid order.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Definition TruthTable.h:39
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:104
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Majority gate (3-input, mig::MajOp)
@ PrimaryInput
Primary input to the network.
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.
Definition CutRewriter.h:74