// CutRewriter.cpp -- source listing taken from the CIRCT 23.0.0git documentation.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18// "Improvements to technology mapping for LUT-based FPGAs", Alan Mishchenko,
19// Satrajit Chatterjee and Robert Brayton, FPGA 2006
20//
21//===----------------------------------------------------------------------===//
22
24
29#include "circt/Support/LLVM.h"
32#include "mlir/Analysis/TopologicalSortUtils.h"
33#include "mlir/IR/Builders.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/IR/RegionKindInterface.h"
36#include "mlir/IR/Value.h"
37#include "mlir/IR/ValueRange.h"
38#include "mlir/IR/Visitors.h"
39#include "mlir/Support/LLVM.h"
40#include "llvm/ADT/APInt.h"
41#include "llvm/ADT/Bitset.h"
42#include "llvm/ADT/DenseMap.h"
43#include "llvm/ADT/MapVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/ADT/iterator.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/ErrorHandling.h"
52#include "llvm/Support/LogicalResult.h"
53#include <algorithm>
54#include <functional>
55#include <memory>
56#include <optional>
57#include <string>
58
59#define DEBUG_TYPE "synth-cut-rewriter"
60
61using namespace circt;
62using namespace circt::synth;
63
64//===----------------------------------------------------------------------===//
65// LogicNetwork
66//===----------------------------------------------------------------------===//
67
68uint32_t LogicNetwork::getOrCreateIndex(Value value) {
69 auto [it, inserted] = valueToIndex.try_emplace(value, gates.size());
70 if (inserted) {
71 indexToValue.push_back(value);
72 gates.emplace_back(); // Will be filled in later
73 }
74 return it->second;
75}
76
77uint32_t LogicNetwork::getIndex(Value value) const {
78 const auto it = valueToIndex.find(value);
79 assert(it != valueToIndex.end() &&
80 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
81 "hasIndex first");
82 return it->second;
83}
84
85bool LogicNetwork::hasIndex(Value value) const {
86 return valueToIndex.contains(value);
87}
88
89Value LogicNetwork::getValue(uint32_t index) const {
90 // Index 0 and 1 are reserved for constants, they have no associated Value
91 if (index == kConstant0 || index == kConstant1)
92 return Value();
93
94 assert(index < indexToValue.size() &&
95 "Index out of bounds in LogicNetwork::getValue");
96 return indexToValue[index];
97}
98
99void LogicNetwork::getValues(ArrayRef<uint32_t> indices,
100 SmallVectorImpl<Value> &values) const {
101 values.clear();
102 values.reserve(indices.size());
103 for (uint32_t idx : indices)
104 values.push_back(getValue(idx));
105}
106
/// Register `value` in the network through getOrCreateIndex and return the
/// index it was given.
// NOTE(review): the embedded numbering skips from 108 to 110, so one statement
// of this function is not visible in this listing; confirm against the
// original file.
107uint32_t LogicNetwork::addPrimaryInput(Value value) {
108 const uint32_t index = getOrCreateIndex(value);
110 return index;
111}
112
113uint32_t LogicNetwork::addGate(Operation *op, LogicNetworkGate::Kind kind,
114 Value result, ArrayRef<Signal> operands) {
115 const uint32_t index = getOrCreateIndex(result);
116 gates[index] = LogicNetworkGate(op, kind, operands);
117 return index;
118}
119
/// Populate the network from the contents of `block`.
///
/// Block arguments are indexed first through addPrimaryInput. Each operation
/// is then dispatched by kind: and-inverter / xor-inverter ops with one or two
/// inputs and two-operand comb.xor ops become gates, 1-bit hw.constant results
/// are aliased to the reserved constant indices, 1-bit synth.choice results
/// become Choice gates, and any other 1-bit result that has no index yet is
/// indexed through addPrimaryInput. Every visible handler returns success().
120LogicalResult LogicNetwork::buildFromBlock(Block *block) {
121 // Pre-size vectors to reduce reallocations (rough estimate)
122 const size_t estimatedSize =
123 block->getArguments().size() + block->getOperations().size();
124 indexToValue.reserve(estimatedSize);
125 gates.reserve(estimatedSize);
126
// Handle a gate with a single input: a plain buffer adds no node at all, an
// inverting one is recorded as an Identity gate over the inverted signal.
127 auto handleSingleInputGate = [&](Operation *op, Value result,
128 const Signal &inputSignal) {
129 if (!inputSignal.isInverted()) {
130 // Non-inverted buffer: directly alias the result to the input
131 valueToIndex[result] = inputSignal.getIndex();
132 return;
133 }
134 // Inverted operation: record an Identity gate fed by the inverted signal
135 addGate(op, LogicNetworkGate::Identity, result, {inputSignal});
136 };
137
138 // Ensure all block arguments are indexed as primary inputs first
139 for (Value arg : block->getArguments()) {
140 if (!hasIndex(arg))
141 addPrimaryInput(arg);
142 }
143
// Fallback for results that get no dedicated gate: every 1-bit result without
// an index is indexed through addPrimaryInput; other result types are ignored.
144 auto handleOtherResults = [&](Operation *op) {
145 for (Value result : op->getResults()) {
146 if (result.getType().isInteger(1) && !hasIndex(result))
147 addPrimaryInput(result);
148 }
149 };
150
// Build the signal for operand `index` of `op`, carrying the op's per-operand
// inversion flag.
151 auto getInvertibleSignal = [&](auto op, unsigned index) {
152 return getOrCreateSignal(op.getOperand(index), op.isInverted(index));
153 };
154
// NOTE(review): the embedded numbering skips source line 156 here, so the rest
// of this lambda's parameter list (which must introduce `kind`, used below) is
// not visible in this listing.
155 auto handleInvertibleBinaryGate = [&](auto logicOp,
157 // The cut rewriter only has dedicated nodes for single-bit unary/binary
158 // gates. Wider or variadic forms stay as opaque cut inputs for now.
159 const auto inputs = logicOp.getInputs();
160 if (inputs.size() == 1) {
161 const Signal inputSignal = getInvertibleSignal(logicOp, 0);
162 handleSingleInputGate(logicOp, logicOp.getResult(), inputSignal);
163 return success();
164 }
165 if (inputs.size() == 2) {
166 const Signal lhsSignal = getInvertibleSignal(logicOp, 0);
167 const Signal rhsSignal = getInvertibleSignal(logicOp, 1);
168 addGate(logicOp, kind, {lhsSignal, rhsSignal});
// NOTE(review): no return is visible after addGate, so control continues into
// handleOtherResults below. That call only touches results that still have no
// index -- confirm the fall-through is intended.
169 }
170 // Variadic gates with >2 inputs are treated as primary
171 // inputs for now.
172 handleOtherResults(logicOp);
173 return success();
174 };
175
176 // Process operations in topological order
177 for (Operation &op : block->getOperations()) {
178 LogicalResult result =
179 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
180 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
181 return handleInvertibleBinaryGate(andOp, LogicNetworkGate::And2);
182 })
183 .Case<synth::XorInverterOp>([&](synth::XorInverterOp xorOp) {
184 return handleInvertibleBinaryGate(xorOp, LogicNetworkGate::Xor2);
185 })
186 .Case<comb::XorOp>([&](comb::XorOp xorOp) {
// Only the two-operand form becomes an Xor2 gate; other arities use the
// generic fallback.
187 if (xorOp->getNumOperands() != 2) {
188 handleOtherResults(xorOp);
189 return success();
190 }
191 const Signal lhsSignal =
192 getOrCreateSignal(xorOp.getOperand(0), false);
193 const Signal rhsSignal =
194 getOrCreateSignal(xorOp.getOperand(1), false);
195 addGate(xorOp, LogicNetworkGate::Xor2, {lhsSignal, rhsSignal});
196 return success();
197 })
198 .Case<hw::ConstantOp>([&](hw::ConstantOp constOp) {
199 Value result = constOp.getResult();
200 if (!result.getType().isInteger(1)) {
201 handleOtherResults(constOp);
202 return success();
203 }
// 1-bit constants add no node: the result is aliased to a reserved index.
204 uint32_t constIdx =
205 constOp.getValue().isZero() ? kConstant0 : kConstant1;
206 valueToIndex[result] = constIdx;
207 return success();
208 })
209 .Case<synth::ChoiceOp>([&](synth::ChoiceOp choiceOp) {
210 if (!choiceOp.getType().isInteger(1)) {
211 handleOtherResults(choiceOp);
212 return success();
213 }
// Choice gates are recorded with an empty operand list.
214 addGate(choiceOp, LogicNetworkGate::Choice, choiceOp.getResult(),
215 {});
216 return success();
217 })
218 .Default([&](Operation *defaultOp) {
219 handleOtherResults(defaultOp);
220 return success();
221 });
222
223 if (failed(result))
224 return result;
225 }
226
227 return success();
228}
229
// NOTE(review): the function signature (source line 230) is missing from this
// listing; only the body is visible. The body empties all three containers and
// then re-seeds the two reserved constant slots, so indices 0 and 1 stay
// occupied by Constant gates with null Value placeholders.
231 valueToIndex.clear();
232 indexToValue.clear();
233 gates.clear();
234 // Re-add the constant nodes (index 0 = const0, index 1 = const1)
235 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
236 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
237 // Placeholders for constants in indexToValue
238 indexToValue.push_back(Value()); // const0
239 indexToValue.push_back(Value()); // const1
240}
241
242//===----------------------------------------------------------------------===//
243// Helper functions
244//===----------------------------------------------------------------------===//
245
246// Return true if the gate at the given index is always a cut input.
247static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index) {
248 const auto &gate = network.getGate(index);
249 return gate.isAlwaysCutInput();
250}
251
252// Return true if the new area/delay is better than the old area/delay in the
253// context of the given strategy.
// One visible return compares area first and breaks ties on delay; the other
// compares delay first and breaks ties on area.
// NOTE(review): source lines 254, 257 and 259 are missing from this listing
// (the start of the signature and the two conditions that select between the
// return statements), so which strategy picks which ordering cannot be
// confirmed from here.
255 ArrayRef<DelayType> newDelay, double oldArea,
256 ArrayRef<DelayType> oldDelay) {
258 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
260 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
261 llvm_unreachable("Unknown mapping strategy");
262}
263
/// Topologically sort the graph-region blocks under `topOp`, emitting an
/// error at `topOp`'s location if the sort fails.
264LogicalResult circt::synth::topologicallySortLogicNetwork(Operation *topOp) {
// Predicate handed to the sorter: an operand defined by `op` counts as ready
// unless `op` is one of the listed logic/dataflow kinds.
// NOTE(review): source line 269 (part of the isa<> type list) is missing from
// this listing, so the list below is incomplete.
265 const auto isOperationReady = [](Value value, Operation *op) -> bool {
266 // Topologically sort AIG ops and dataflow ops. Other operations
267 // can be scheduled.
268 return !(isa<aig::AndInverterOp, synth::XorInverterOp, synth::ChoiceOp,
270 comb::ReplicateOp, comb::ConcatOp>(op));
271 };
272
273 if (failed(topologicallySortGraphRegionBlocks(topOp, isOperationReady)))
274 return emitError(topOp->getLoc(),
275 "failed to sort operations topologically");
276 return success();
277}
278
279/// Get the truth table for operations within a block.
///
/// The block arguments are the table inputs and `values` are the outputs.
/// A block without arguments yields a default-constructed table; zero outputs
/// yield a table with zero outputs. More than one output, or an input count
/// of maxTruthTableInputs or more, is reported as an error. Otherwise every
/// operation in the block is simulated on per-input variable masks and the
/// mask computed for `values[0]` becomes the single-output table.
280FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
281 Block *block) {
// NOTE(review): source line 282 (the declaration of `inputArgs`) is missing
// from this listing.
283 for (Value arg : block->getArguments())
284 inputArgs.insert(arg);
285
286 if (inputArgs.empty())
287 return BinaryTruthTable();
288
289 const int64_t numInputs = inputArgs.size();
290 const int64_t numOutputs = values.size();
291 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
// The zero-output case is checked first, so it wins even when the input count
// is too large.
292 if (numOutputs == 0)
293 return BinaryTruthTable(numInputs, 0);
294 if (numInputs >= maxTruthTableInputs)
295 return mlir::emitError(values.front().getLoc(),
296 "Truth table is too large");
297 return mlir::emitError(values.front().getLoc(),
298 "Multiple outputs are not supported yet");
299 }
300
301 // Create a map to evaluate the operation
302 DenseMap<Value, APInt> eval;
303 for (uint32_t i = 0; i < numInputs; ++i)
304 eval[inputArgs[i]] = circt::createVarMask(numInputs, i, true);
305
306 // Simulate the operations in the block
307 for (Operation &op : *block) {
308 if (op.getNumResults() == 0)
309 continue;
310
311 if (auto choiceOp = dyn_cast<synth::ChoiceOp>(&op)) {
// A choice is simulated as its first input only.
312 auto it = eval.find(choiceOp.getInputs().front());
313 if (it == eval.end())
314 return choiceOp.emitError("Input value not found in evaluation map");
315 eval[choiceOp.getResult()] = it->second;
316 } else if (auto logicOp = dyn_cast<BooleanLogicOpInterface>(&op)) {
// Check all inputs up front so the callback below can dereference find()
// without a check.
317 for (auto value : logicOp.getInputs())
318 if (!eval.contains(value))
319 return logicOp->emitError("Input value not found in evaluation map");
320
321 eval[logicOp.getResult()] =
322 logicOp.evaluateBooleanLogic([&](unsigned i) -> const APInt & {
323 return eval.find(logicOp.getInput(i))->second;
324 });
325 } else if (auto xorOp = dyn_cast<comb::XorOp>(&op)) {
326 // TODO: Define Xor as Synth op.
327 auto it = eval.find(xorOp.getOperand(0));
328 if (it == eval.end())
329 return xorOp.emitError("Input value not found in evaluation map");
330 llvm::APInt result = it->second;
// Fold the remaining operands into the running XOR.
331 for (unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
332 it = eval.find(xorOp.getOperand(i));
333 if (it == eval.end())
334 return xorOp.emitError("Input value not found in evaluation map");
335 result ^= it->second;
336 }
337 eval[xorOp.getResult()] = result;
338 } else if (!isa<hw::OutputOp>(&op)) {
339 return op.emitError("Unsupported operation for truth table simulation");
340 }
341 }
342
// NOTE(review): operator[] is used for the final lookup, so an output value
// that was never simulated would presumably produce a default-constructed
// APInt rather than an error -- confirm callers always pass evaluated values.
343 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
344}
345
346//===----------------------------------------------------------------------===//
347// Cut
348//===----------------------------------------------------------------------===//
349
350bool Cut::isTrivialCut() const {
351 // A cut is a trivial cut if it has no root (rootIndex == 0 sentinel)
352 // and only one input
353 return rootIndex == 0 && inputs.size() == 1;
354}
355
356const NPNClass &Cut::getNPNClass() const { return getNPNClass(nullptr); }
357
/// Return the NPN class of this cut, computing and caching it in `npnClass`
/// on first use. When `npnTable` is non-null its lookup() is tried first.
358const NPNClass &Cut::getNPNClass(const NPNTable *npnTable) const {
// Already computed: return the cached class.
359 if (npnClass)
360 return *npnClass;
361
362 const auto &truthTable = *getTruthTable();
363 NPNClass canonicalForm;
// NOTE(review): source line 365 -- the statement controlled by this `if`, i.e.
// what happens when there is no table or the lookup misses -- is missing from
// this listing.
364 if (!npnTable || !npnTable->lookup(truthTable, canonicalForm))
366
367 npnClass.emplace(std::move(canonicalForm));
368 return *npnClass;
369}
370
// NOTE(review): the first line of this signature (source line 371) is missing
// from this listing. The visible body fetches this cut's NPN class (using
// `npnTable` if given) and asks it for the input permutation relative to
// `patternNPN`, which is written into `permutedIndices`.
372 const NPNTable *npnTable, const NPNClass &patternNPN,
373 SmallVectorImpl<unsigned> &permutedIndices) const {
374 const auto &npnClass = getNPNClass(npnTable);
375 npnClass.getInputPermutation(patternNPN, permutedIndices);
376}
377
/// Append one arrival time per cut input to `results`. Inputs that are always
/// cut inputs contribute 0; every other input takes the arrival time recorded
/// by the best matched cut of its cut set. Fails as soon as one input has no
/// matched cut.
// NOTE(review): source line 379 (the function name and first parameter, which
// must introduce `enumerator`) is missing from this listing.
378LogicalResult
380 SmallVectorImpl<DelayType> &results) const {
381 results.reserve(getInputSize());
382 const auto &network = enumerator.getLogicNetwork();
383
384 // Compute arrival times for each input.
385 for (auto inputIndex : inputs) {
386 if (isAlwaysCutInput(network, inputIndex)) {
387 // If the input is a primary input, it has no delay.
388 results.push_back(0);
389 continue;
390 }
391 auto *cutSet = enumerator.getCutSet(inputIndex);
392 assert(cutSet && "Input must have a valid cut set");
393
394 // If there is no matching pattern, it means it's not possible to use the
395 // input in the cut rewriting. Return failure to signal this to the caller.
396 auto *bestCut = cutSet->getBestMatchedCut();
397 if (!bestCut)
398 return failure();
399
400 const auto &matchedPattern = *bestCut->getMatchedPattern();
401
402 // Get the value for result number lookup
403 mlir::Value inputValue = network.getValue(inputIndex);
404 // Otherwise, the cut input is an op result. Get the arrival time
405 // from the matched pattern.
406 results.push_back(matchedPattern.getArrivalTime(
407 cast<mlir::OpResult>(inputValue).getResultNumber()));
408 }
409
410 return success();
411}
412
413void Cut::dump(llvm::raw_ostream &os, const LogicNetwork &network) const {
414 os << "// === Cut Dump ===\n";
415 os << "Cut with " << getInputSize() << " inputs";
416 if (rootIndex != 0) {
417 auto *rootOp = network.getGate(rootIndex).getOperation();
418 if (rootOp)
419 os << " and root: " << *rootOp;
420 }
421 os << "\n";
422
423 if (isTrivialCut()) {
424 mlir::Value inputVal = network.getValue(inputs[0]);
425 os << "Primary input cut: " << inputVal << "\n";
426 return;
427 }
428
429 os << "Inputs (indices): \n";
430 for (auto [idx, inputIndex] : llvm::enumerate(inputs)) {
431 mlir::Value inputVal = network.getValue(inputIndex);
432 os << " Input " << idx << " (index " << inputIndex << "): " << inputVal
433 << "\n";
434 }
435
436 if (rootIndex != 0) {
437 os << "\nRoot operation: \n";
438 if (auto *rootOp = network.getGate(rootIndex).getOperation())
439 rootOp->print(os);
440 os << "\n";
441 }
442
443 auto &npnClass = getNPNClass();
444 npnClass.dump(os);
445
446 os << "// === Cut End ===\n";
447}
448
449unsigned Cut::getInputSize() const { return inputs.size(); }
450
451unsigned Cut::getOutputSize(const LogicNetwork &network) const {
452 if (rootIndex == 0)
453 return 1; // Trivial cut has 1 output
454 auto *rootOp = network.getGate(rootIndex).getOperation();
455 return rootOp ? rootOp->getNumResults() : 1;
456}
457
458/// Simulate a gate and return its truth table.
/// Unary overload: the one visible case returns the operand table unchanged.
// NOTE(review): source line 462 (the case label guarding `return a;`) is
// missing from this listing.
459static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
460 const llvm::APInt &a) {
461 switch (kind) {
463 return a;
464 default:
465 llvm_unreachable("Unsupported unary operation for truth table computation");
466 }
467}
468
/// Binary overload: combines two operand truth tables bitwise, with AND for
/// one gate kind and XOR for another.
// NOTE(review): source lines 473 and 475 (the case labels for the two returns)
// are missing from this listing.
469static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
470 const llvm::APInt &a,
471 const llvm::APInt &b) {
472 switch (kind) {
474 return a & b;
476 return a ^ b;
477 default:
478 llvm_unreachable(
479 "Unsupported binary operation for truth table computation");
480 }
481}
482
/// Ternary overload: the visible case computes the bitwise majority of the
/// three operand truth tables (a bit is set when at least two operands have it
/// set).
// NOTE(review): source line 488 (the case label guarding the return) is
// missing from this listing.
483static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
484 const llvm::APInt &a,
485 const llvm::APInt &b,
486 const llvm::APInt &c) {
487 switch (kind) {
489 return (a & b) | (a & c) | (b & c);
490 default:
491 llvm_unreachable(
492 "Unsupported ternary operation for truth table computation");
493 }
494}
495
496namespace {
497
498// Helper class to build a merged truth table for a cut based on its operand
499// cuts
// `mergedInputs` is the sorted, duplicate-free input list of the cut being
// built (both properties are asserted in the constructor) and `operandCuts`
// are the cuts chosen for the root gate's operands. Both are held as ArrayRef
// views of the caller's storage.
500struct MergedTruthTableBuilder {
501 MergedTruthTableBuilder(ArrayRef<uint32_t> mergedInputs,
502 ArrayRef<const Cut *> operandCuts)
503 : mergedInputs(mergedInputs), numMergedInputs(mergedInputs.size()),
504 operandCuts(operandCuts) {
505 assert(llvm::is_sorted(mergedInputs) && "merged inputs must be sorted");
506 assert(llvm::adjacent_find(mergedInputs) == mergedInputs.end() &&
507 "merged inputs must be unique");
508 }
509
510 ArrayRef<uint32_t> mergedInputs;
511 unsigned numMergedInputs;
512 ArrayRef<const Cut *> operandCuts;
513
// Position of `operandIdx` within mergedInputs, or nullopt when the operand is
// not itself one of the merged inputs.
514 std::optional<unsigned> findMergedInputPosition(uint32_t operandIdx) const {
515 auto *it = llvm::find(mergedInputs, operandIdx);
516 if (it == mergedInputs.end())
517 return std::nullopt;
518 return static_cast<unsigned>(std::distance(mergedInputs.begin(), it));
519 }
520
// Find the operand cut whose output is `operandIdx`. A trivial cut's output is
// its single input; otherwise it is the cut's root index. Null entries are
// skipped; returns nullptr when nothing matches.
521 const Cut *findOperandCut(uint32_t operandIdx) const {
522 for (const Cut *cut : operandCuts) {
523 if (!cut)
524 continue;
525 uint32_t cutOutput =
526 cut->isTrivialCut() ? cut->inputs[0] : cut->getRootIndex();
527 if (cutOutput == operandIdx)
528 return cut;
529 }
530 return nullptr;
531 }
532
// For each input of `cut`, record its position in mergedInputs. Every cut
// input is required (by assertion) to be present there.
533 void getInputMapping(const Cut *cut,
534 SmallVectorImpl<unsigned> &mapping) const {
535 mapping.clear();
536 mapping.reserve(cut->inputs.size());
537 for (uint32_t idx : cut->inputs) {
538 auto *it = llvm::find(mergedInputs, idx);
539 assert(it != mergedInputs.end() &&
540 "cut input must exist in merged inputs");
541 mapping.push_back(static_cast<unsigned>(it - mergedInputs.begin()));
542 }
543 }
544
// Re-express `cut`'s truth table over the merged input space.
// NOTE(review): source line 549 (the start of the return statement, naming the
// function that receives the three arguments below) is missing from this
// listing.
545 llvm::APInt expandCutTruthTable(const Cut *cut) const {
546 const auto &cutTT = *cut->getTruthTable();
547 SmallVector<unsigned, 8> inputMapping;
548 getInputMapping(cut, inputMapping);
550 cutTT.table, inputMapping, numMergedInputs);
551 }
552
// Truth table of one operand over the merged inputs: all-zeros / all-ones for
// the reserved constants, a variable mask for a direct merged input, or the
// expanded table of the matching operand cut. The result is bit-flipped when
// the edge is inverted.
553 llvm::APInt expandOperand(uint32_t operandIdx, bool isInverted) const {
554 llvm::APInt result(1, 0);
555 if (operandIdx == LogicNetwork::kConstant0) {
556 result = llvm::APInt::getZero(1U << numMergedInputs);
557 } else if (operandIdx == LogicNetwork::kConstant1) {
558 result = llvm::APInt::getAllOnes(1U << numMergedInputs);
559 } else if (auto pos = findMergedInputPosition(operandIdx)) {
560 // Direct cut inputs already live in the merged input space.
561 result = circt::createVarMask(numMergedInputs, *pos, true);
562 } else if (const Cut *cut = findOperandCut(operandIdx)) {
563 // Internal operands reuse the operand cut truth table after expanding it
564 // to this root cut's merged input space.
565 result = expandCutTruthTable(cut);
566 } else {
567 llvm_unreachable("Operand not found in cuts or merged inputs");
568 }
569
570 if (isInverted)
571 result.flipAllBits();
572 return result;
573 }
574
// Truth table of `rootGate` over the merged inputs: each fanin edge is
// expanded with expandOperand and the results are combined by the
// applyGateSemantics overload matching the number of edges used.
// NOTE(review): source lines 582, 583, 587 and 592 (the case labels selecting
// the two-, three- and one-operand branches) are missing from this listing.
575 BinaryTruthTable computeForGate(const LogicNetworkGate &rootGate) const {
576 auto getEdgeTT = [&](unsigned edgeIdx) {
577 const auto &edge = rootGate.edges[edgeIdx];
578 return expandOperand(edge.getIndex(), edge.isInverted());
579 };
580
581 switch (rootGate.getKind()) {
584 return BinaryTruthTable(
585 numMergedInputs, 1,
586 applyGateSemantics(rootGate.getKind(), getEdgeTT(0), getEdgeTT(1)));
588 return BinaryTruthTable(numMergedInputs, 1,
589 applyGateSemantics(rootGate.getKind(),
590 getEdgeTT(0), getEdgeTT(1),
591 getEdgeTT(2)));
593 return BinaryTruthTable(
594 numMergedInputs, 1,
595 applyGateSemantics(rootGate.getKind(), getEdgeTT(0)));
596 default:
597 llvm_unreachable("Unsupported operation for truth table computation");
598 }
599 }
600};
601
602} // namespace
603
// NOTE(review): the function signature (source line 604) is missing from this
// listing; `network` used below is presumably its parameter.
// Trivial cuts must already carry a truth table and are left untouched. For
// any other cut the table is derived from the root gate and the stored operand
// cuts via MergedTruthTableBuilder and stored in `truthTable`.
605 if (isTrivialCut()) {
606 assert(truthTable && "trivial cuts should have their truth table pre-set");
607 return;
608 }
609
610 assert(!operandCuts.empty() &&
611 "non-trivial cuts must carry operand cuts for truth table expansion");
612
613 const auto &rootGate = network.getGate(rootIndex);
614 truthTable.emplace(
615 MergedTruthTableBuilder(inputs, operandCuts).computeForGate(rootGate));
616}
617
618bool Cut::dominates(const Cut &other) const {
619 return dominates(other.inputs, other.signature);
620}
621
622bool Cut::dominates(ArrayRef<uint32_t> otherInputs, uint64_t otherSig) const {
623
624 if (getInputSize() > otherInputs.size())
625 return false;
626
627 if ((signature & otherSig) != signature)
628 return false;
629
630 return std::includes(otherInputs.begin(), otherInputs.end(), inputs.begin(),
631 inputs.end());
632}
633
634Cut Cut::getTrivialCut(uint32_t index) {
635 Cut cut;
636 cut.inputs.push_back(index);
637 // The truth table for a trivial cut is just the identity function on its
638 // single input.
639 cut.setTruthTable(BinaryTruthTable(1, 1, llvm::APInt(2, 2)));
640 cut.setSignature(1ULL << (index % 64)); // Set signature bit for this input
641 return cut;
642}
643
644//===----------------------------------------------------------------------===//
645// MatchedPattern
646//===----------------------------------------------------------------------===//
647
648ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
649 assert(pattern && "Pattern must be set to get arrival time");
650 return arrivalTimes;
651}
652
// NOTE(review): the function signature (source line 653) is missing from this
// listing. The body returns the stored arrival time at position `index` and
// asserts that a pattern has been set.
654 assert(pattern && "Pattern must be set to get arrival time");
655 return arrivalTimes[index];
656}
657
// NOTE(review): the function signature (source line 658) is missing from this
// listing. The body returns the stored pattern pointer after asserting that it
// is set.
659 assert(pattern && "Pattern must be set to get the pattern");
660 return pattern;
661}
662
// NOTE(review): the function signature (source line 663) is missing from this
// listing. The body returns the stored area and asserts that a pattern has
// been set.
664 assert(pattern && "Pattern must be set to get area");
665 return area;
666}
667
668//===----------------------------------------------------------------------===//
669// CutSet
670//===----------------------------------------------------------------------===//
671
673
674unsigned CutSet::size() const { return cuts.size(); }
675
676void CutSet::addCut(Cut *cut) {
677 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
678 cuts.push_back(cut);
679}
680
681ArrayRef<Cut *> CutSet::getCuts() const { return cuts; }
682
683// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
684// exists another cut that is a subset of it.
685static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut *> &cuts) {
686 auto dumpInputs = [](llvm::raw_ostream &os,
687 const llvm::SmallVectorImpl<uint32_t> &inputs) {
688 os << "{";
689 llvm::interleaveComma(inputs, os);
690 os << "}";
691 };
692 // Sort by size, then lexicographically by inputs. This enables cheap exact
693 // duplicate elimination and tighter candidate filtering for subset checks.
694 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut *a, const Cut *b) {
695 if (a->getInputSize() != b->getInputSize())
696 return a->getInputSize() < b->getInputSize();
697 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
698 b->inputs.begin(), b->inputs.end());
699 });
700
701 // Group kept cuts by input size so subset checks only visit smaller cuts.
702 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
703 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
704
705 // Compact kept cuts in-place.
706 unsigned uniqueCount = 0;
707 for (Cut *cut : cuts) {
708 unsigned cutSize = cut->getInputSize();
709
710 // Fast exact duplicate check: with lexicographic sort, duplicates are
711 // adjacent among cuts with equal size.
712 if (uniqueCount > 0) {
713 Cut *lastKept = cuts[uniqueCount - 1];
714 if (lastKept->getInputSize() == cutSize &&
715 lastKept->inputs == cut->inputs)
716 continue;
717 }
718
719 bool isDominated = false;
720 for (unsigned existingSize = 1; existingSize < cutSize && !isDominated;
721 ++existingSize) {
722 for (const Cut *existingCut : keptBySize[existingSize]) {
723 if (!existingCut->dominates(*cut))
724 continue;
725
726 LLVM_DEBUG({
727 llvm::dbgs() << "Dropping non-minimal cut ";
728 dumpInputs(llvm::dbgs(), cut->inputs);
729 llvm::dbgs() << " due to subset ";
730 dumpInputs(llvm::dbgs(), existingCut->inputs);
731 llvm::dbgs() << "\n";
732 });
733 isDominated = true;
734 break;
735 }
736 }
737
738 if (isDominated)
739 continue;
740
741 cuts[uniqueCount++] = cut;
742 keptBySize[cutSize].push_back(cut);
743 }
744
745 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
746 << " Unique cuts: " << uniqueCount << "\n");
747
748 // Resize the cuts vector to the number of surviving cuts.
749 cuts.resize(uniqueCount);
750}
751
// NOTE(review): the first line of this signature (source line 752) is missing
// from this listing.
// Finalize the cut set: compute missing truth tables, match every cut through
// `matchCut`, order the cuts (trivial cuts first, then by match quality), keep
// at most options.maxCutSizePerRoot of them, record the first matched survivor
// as `bestCut`, and freeze the set.
753 const CutRewriterOptions &options,
754 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
755 const LogicNetwork &logicNetwork) {
756
757 // Remove duplicate/non-minimal cuts first so all follow-up work only runs on
758 // survivors.
// NOTE(review): source line 759 -- the statement the comment above describes --
// is missing from this listing.
760
761 // Compute truth tables lazily, then match cuts to collect timing/area data.
762 for (Cut *cut : cuts) {
763 if (!cut->getTruthTable().has_value())
764 cut->computeTruthTableFromOperands(logicNetwork);
765
766 assert(cut->getInputSize() <= options.maxCutInputSize &&
767 "Cut input size exceeds maximum allowed size");
768
769 if (auto matched = matchCut(*cut))
770 cut->setMatchedPattern(std::move(*matched));
771 }
772
773 // Sort cuts by priority to select the most promising ones.
774 // Priority is determined by the optimization strategy:
775 // - Trivial cuts (direct connections) have highest priority
776 // - Among matched cuts, compare by area/delay based on the strategy
777 // - Matched cuts are preferred over unmatched cuts
778 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
779 // et al., ICCAD 2007 for more details.
780 // TODO: Use a priority queue instead of sorting for better performance.
781
782 // Partition the cuts into trivial and non-trivial cuts.
783 auto *trivialCutsEnd =
784 std::stable_partition(cuts.begin(), cuts.end(),
785 [](const Cut *cut) { return cut->isTrivialCut(); });
786
// Ordering for non-trivial cuts: two matched cuts are compared through
// compareDelayAndArea, a matched cut beats an unmatched one, and two unmatched
// cuts are ordered by input count (fewer inputs first).
787 auto isBetterCut = [&options](const Cut *a, const Cut *b) {
788 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
789 "Trivial cuts should have been excluded");
790 const auto &aMatched = a->getMatchedPattern();
791 const auto &bMatched = b->getMatchedPattern();
792
793 if (aMatched && bMatched)
794 return compareDelayAndArea(
795 options.strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
796 bMatched->getArea(), bMatched->getArrivalTimes());
797
798 if (static_cast<bool>(aMatched) != static_cast<bool>(bMatched))
799 return static_cast<bool>(aMatched);
800
801 return a->getInputSize() < b->getInputSize();
802 };
803 std::stable_sort(trivialCutsEnd, cuts.end(), isBetterCut);
804
805 // Keep only the top-K cuts to bound growth.
806 if (cuts.size() > options.maxCutSizePerRoot)
807 cuts.resize(options.maxCutSizePerRoot);
808
809 // Select the best cut from the remaining candidates.
// The first cut in priority order that carries a matched pattern wins.
810 bestCut = nullptr;
811 for (Cut *cut : cuts) {
812 const auto &currentMatch = cut->getMatchedPattern();
813 if (!currentMatch)
814 continue;
815 bestCut = cut;
816 break;
817 }
818
819 LLVM_DEBUG({
820 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
821 << (bestCut
822 ? "matched pattern to " + bestCut->getMatchedPattern()
823 ->getPattern()
824 ->getPatternName()
825 : "no matched pattern")
826 << "\n";
827 });
828
829 isFrozen = true; // Mark the cut set as frozen
830}
831
832//===----------------------------------------------------------------------===//
833// CutRewritePattern
834//===----------------------------------------------------------------------===//
835
// NOTE(review): the first line of this signature (source line 836) is missing
// from this listing. This implementation leaves `matchingNPNClasses` untouched
// and returns false.
837 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
838 return false;
839}
840
841//===----------------------------------------------------------------------===//
842// CutRewritePatternSet
843//===----------------------------------------------------------------------===//
844
// NOTE(review): the first line of this constructor's signature (source line
// 845) is missing from this listing.
// Takes ownership of `patterns` and indexes them. A pattern whose
// useTruthTableMatcher() returns true is registered once per NPN class it
// reports, keyed by that class's truth table bits and input count; every other
// pattern is appended to `nonNPNPatterns`.
846 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
847 : patterns(std::move(patterns)) {
848 // Initialize the NPN to pattern map
// Iterate the member (this->patterns): the parameter of the same name has
// been moved from.
849 for (auto &pattern : this->patterns) {
850 SmallVector<NPNClass, 2> npnClasses;
851 auto result = pattern->useTruthTableMatcher(npnClasses);
852 if (result) {
// Each class is taken by value so it can be moved into the map entry.
853 for (auto npnClass : npnClasses) {
854 // Create a NPN class from the truth table
855 npnToPatternMap[{npnClass.truthTable.table,
856 npnClass.truthTable.numInputs}]
857 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
858 }
859 } else {
860 // If the pattern does not provide NPN classes, we use a special key
861 // to indicate that it should be considered for all cuts.
862 nonNPNPatterns.push_back(pattern.get());
863 }
864 }
865}
866
867//===----------------------------------------------------------------------===//
868// CutEnumerator
869//===----------------------------------------------------------------------===//
870
// NOTE(review): the constructor signature (source line 871) is missing from
// this listing; `options` and `stats` are presumably its parameters.
// Both allocators are constructed from a counter in `stats`
// (numCutsCreated / numCutSetsCreated), and `options` is stored.
872 : cutAllocator(stats.numCutsCreated),
873 cutSetAllocator(stats.numCutSetsCreated), options(options) {}
874
// NOTE(review): the function signature (source line 875) is missing from this
// listing; `index` is presumably its parameter.
// Allocate a fresh cut set and register it under `index`. An existing entry
// for that index is a programming error (asserted).
876 CutSet *cutSet = cutSetAllocator.create();
877 auto [cutSetPtr, inserted] = cutSets.try_emplace(index, cutSet);
878 assert(inserted && "Cut set already exists for this index");
879 return cutSetPtr->second;
880}
881
// NOTE(review): the function signature (source line 882) and one statement
// (source line 885) are missing from this listing.
// Drops the index-to-cut-set map and the processing order, then calls
// DestroyAll() on both allocators.
883 cutSets.clear();
884 processingOrder.clear();
886 cutAllocator.DestroyAll();
887 cutSetAllocator.DestroyAll();
888}
889
890LogicalResult CutEnumerator::visitLogicOp(uint32_t nodeIndex) {
891 const auto &gate = logicNetwork.getGate(nodeIndex);
892 auto *logicOp = gate.getOperation();
893 assert(logicOp && logicOp->getNumResults() == 1 &&
894 "Logic operation must have a single result");
895
896 if (gate.getKind() == LogicNetworkGate::Choice) {
897 auto choiceOp = cast<synth::ChoiceOp>(logicOp);
898 auto *resultCutSet = createNewCutSet(nodeIndex);
899 Cut *primaryInputCut = cutAllocator.create(Cut::getTrivialCut(nodeIndex));
900 processingOrder.push_back(nodeIndex);
901 resultCutSet->addCut(primaryInputCut);
902
903 for (Value operand : choiceOp.getInputs()) {
904 auto *operandCutSet = getCutSet(logicNetwork.getIndex(operand));
905 if (!operandCutSet)
906 return logicOp->emitError("Failed to get cut set for choice operand");
907
908 // Choice nodes do not introduce new logic. They forward each non-trivial
909 // operand cut as an equivalent alternative for the same root.
910 for (const Cut *operandCut : operandCutSet->getCuts()) {
911 if (operandCut->isTrivialCut())
912 continue;
913
914 resultCutSet->addCut(cutAllocator.create(
915 nodeIndex, operandCut->inputs, operandCut->getSignature(),
916 ArrayRef<const Cut *>{operandCut}, *operandCut->getTruthTable()));
917 }
918 }
919
920 // Finalize cut set: remove duplicates, limit size, and match patterns
921 resultCutSet->finalize(options, matchCut, logicNetwork);
922 return success();
923 }
924
925 unsigned numFanins = gate.getNumFanins();
926
927 // Validate operation constraints
928 // TODO: Variadic operations and non-single-bit results can be supported
929 if (numFanins > 3)
930 return logicOp->emitError("Cut enumeration supports at most 3 operands, "
931 "found: ")
932 << numFanins;
933 if (!logicOp->getOpResult(0).getType().isInteger(1))
934 return logicOp->emitError()
935 << "Supported logic operations must have a single bit "
936 "result type but found: "
937 << logicOp->getResult(0).getType();
938
939 // A vector to hold cut sets for each operand along with their max cut input
940 // size.
941 SmallVector<std::pair<const CutSet *, unsigned>, 2> operandCutSets;
942 operandCutSets.reserve(numFanins);
943
944 // Collect cut sets for each fanin (using LogicNetwork edges)
945 for (unsigned i = 0; i < numFanins; ++i) {
946 uint32_t faninIndex = gate.edges[i].getIndex();
947 auto *operandCutSet = getCutSet(faninIndex);
948 if (!operandCutSet)
949 return logicOp->emitError("Failed to get cut set for fanin index ")
950 << faninIndex;
951
952 // Find the largest cut size among the operand's cuts for sorting heuristic
953 // later.
954 unsigned maxInputCutSize = 0;
955 for (auto *cut : operandCutSet->getCuts())
956 maxInputCutSize = std::max(maxInputCutSize, cut->getInputSize());
957 operandCutSets.push_back(std::make_pair(operandCutSet, maxInputCutSize));
958 }
959
960 // Create the trivial cut for this node's output
961 Cut *primaryInputCut = cutAllocator.create(Cut::getTrivialCut(nodeIndex));
962
963 auto *resultCutSet = createNewCutSet(nodeIndex);
964 processingOrder.push_back(nodeIndex);
965 resultCutSet->addCut(primaryInputCut);
966
967 // Sort operand cut sets by their largest cut size in descending order. This
968 // heuristic improves efficiency of the k-way merge when generating cuts for
969 // the current node by maximizing the chance of early pruning when the merged
970 // cut exceeds the input size limit.
971 llvm::stable_sort(operandCutSets,
972 [](const std::pair<const CutSet *, unsigned> &a,
973 const std::pair<const CutSet *, unsigned> &b) {
974 return a.second > b.second;
975 });
976
977 // Cache maxCutInputSize to avoid repeated access
978 unsigned maxInputSize = options.maxCutInputSize;
979
980 // This lambda generates nested loops at runtime to iterate over all
981 // combinations of cuts from N operands
982 auto enumerateCutCombinations = [&](auto &&self, unsigned operandIdx,
983 SmallVector<const Cut *, 3> &cutPtrs,
984 uint64_t currentSig) -> void {
985 // Base case: all operands processed, create merged cut
986 if (operandIdx == numFanins) {
987 // Efficient k-way merge: inputs are sorted, so dedup and constant
988 // filtering can be done while merging. Abort early once we exceed the
989 // cut-size limit to avoid building doomed merged cuts.
990 SmallVector<uint32_t, 6> mergedInputs;
991 auto appendMergedInput = [&](uint32_t value) {
992 if (value == LogicNetwork::kConstant0 ||
994 return true;
995 if (!mergedInputs.empty() && mergedInputs.back() == value)
996 return true;
997 mergedInputs.push_back(value);
998 return mergedInputs.size() <= maxInputSize;
999 };
1000
1001 if (numFanins == 1) {
1002 // Single input: copy while filtering constants.
1003 mergedInputs.reserve(
1004 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
1005 for (uint32_t value : cutPtrs[0]->inputs)
1006 if (!appendMergedInput(value))
1007 return;
1008 } else if (numFanins == 2) {
1009 // Two-way merge (common case for AND gates)
1010 const auto &inputs0 = cutPtrs[0]->inputs;
1011 const auto &inputs1 = cutPtrs[1]->inputs;
1012 mergedInputs.reserve(
1013 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
1014
1015 unsigned i = 0, j = 0;
1016 while (i < inputs0.size() || j < inputs1.size()) {
1017 uint32_t next;
1018 if (j == inputs1.size() ||
1019 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
1020 next = inputs0[i++];
1021 if (j < inputs1.size() && inputs1[j] == next)
1022 ++j;
1023 } else {
1024 next = inputs1[j++];
1025 }
1026 if (!appendMergedInput(next))
1027 return;
1028 }
1029 } else {
1030 // Three-way merge (for MAJ/MUX gates)
1031 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
1032 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
1033 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
1034 mergedInputs.reserve(std::min<size_t>(
1035 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
1036
1037 unsigned i = 0, j = 0, k = 0;
1038 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
1039 // Find minimum among available elements
1040 uint32_t minVal = UINT32_MAX;
1041 if (i < inputs0.size())
1042 minVal = std::min(minVal, inputs0[i]);
1043 if (j < inputs1.size())
1044 minVal = std::min(minVal, inputs1[j]);
1045 if (k < inputs2.size())
1046 minVal = std::min(minVal, inputs2[k]);
1047
1048 // Advance all iterators pointing to minVal (handles duplicates)
1049 if (i < inputs0.size() && inputs0[i] == minVal)
1050 i++;
1051 if (j < inputs1.size() && inputs1[j] == minVal)
1052 j++;
1053 if (k < inputs2.size() && inputs2[k] == minVal)
1054 k++;
1055
1056 if (!appendMergedInput(minVal))
1057 return;
1058 }
1059 }
1060
1061 // Create the merged cut.
1062 Cut *mergedCut = cutAllocator.create(nodeIndex, mergedInputs, currentSig,
1063 ArrayRef<const Cut *>(cutPtrs));
1064 resultCutSet->addCut(mergedCut);
1065
1066 LLVM_DEBUG({
1067 if (mergedCut->inputs.size() >= 4) {
1068 llvm::dbgs() << "Generated cut for node " << nodeIndex;
1069 if (logicOp)
1070 llvm::dbgs() << " (" << logicOp->getName() << ")";
1071 llvm::dbgs() << " inputs=";
1072 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1073 llvm::dbgs() << "\n";
1074 }
1075 });
1076 return;
1077 }
1078
1079 // Recursive case: iterate over cuts for current operand
1080 const CutSet *currentCutSet = operandCutSets[operandIdx].first;
1081 for (const Cut *cut : currentCutSet->getCuts()) {
1082 uint64_t cutSig = cut->getSignature();
1083 uint64_t newSig = currentSig | cutSig;
1084 if (static_cast<unsigned>(llvm::popcount(newSig)) > maxInputSize)
1085 continue; // Early rejection based on signature
1086
1087 cutPtrs.push_back(cut);
1088
1089 // Recurse to next operand
1090 self(self, operandIdx + 1, cutPtrs, newSig);
1091
1092 cutPtrs.pop_back();
1093 }
1094 };
1095
1096 // Start the recursion with empty cut pointer list and zero signature
1097 SmallVector<const Cut *, 3> cutPtrs;
1098 cutPtrs.reserve(numFanins);
1099 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs, 0ULL);
1100
1101 // Finalize cut set: remove duplicates, limit size, and match patterns
1102 resultCutSet->finalize(options, matchCut, logicNetwork);
1103
1104 return success();
1105}
1106
1108 Operation *topOp,
1109 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
1110 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
1111 << "\n");
1112 // Topologically sort the logic network
1113 if (failed(topologicallySortLogicNetwork(topOp)))
1114 return failure();
1115
1116 // Store the pattern matching function for use during cut finalization
1117 this->matchCut = matchCut;
1118
1119 // Build the flat logic network representation for efficient simulation
1120 auto &block = topOp->getRegion(0).getBlocks().front();
1121 if (failed(logicNetwork.buildFromBlock(&block)))
1122 return failure();
1123
1124 for (const auto &[index, gate] : llvm::enumerate(logicNetwork.getGates())) {
1125 // Skip non-logic gates.
1126 if (!gate.isLogicGate())
1127 continue;
1128
1129 // Ensure cut set exists for each logic gate
1130 if (failed(visitLogicOp(index)))
1131 return failure();
1132 }
1133
1134 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
1135 return success();
1136}
1137
1138const CutSet *CutEnumerator::getCutSet(uint32_t index) {
1139 // Check if cut set already exists
1140 auto it = cutSets.find(index);
1141 if (it == cutSets.end()) {
1142 // Create new cut set for an unprocessed value (primary input or other)
1143 CutSet *cutSet = cutSetAllocator.create();
1144 Cut *trivialCut = cutAllocator.create(Cut::getTrivialCut(index));
1145 cutSet->addCut(trivialCut);
1146 auto [newIt, inserted] = cutSets.insert({index, cutSet});
1147 assert(inserted && "Cut set already exists for this index");
1148 it = newIt;
1149 }
1150
1151 return it->second;
1152}
1153
1154/// Generate a human-readable name for a value used in test output.
1155/// This function creates meaningful names for values to make debug output
1156/// and test results more readable and understandable.
1157static StringRef
1158getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
1159 if (auto *op = value.getDefiningOp()) {
1160 // Handle values defined by operations
1161 // First, check if the operation already has a name hint attribute
1162 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
1163 return name.getValue();
1164
1165 // For single-result operations, generate a unique name based on operation
1166 // type
1167 if (op->getNumResults() == 1) {
1168 auto opName = op->getName();
1169 auto count = opCounter[opName]++;
1170
1171 // Create a unique name by appending a counter to the operation name
1172 SmallString<16> nameStr;
1173 nameStr += opName.getStringRef();
1174 nameStr += "_";
1175 nameStr += std::to_string(count);
1176
1177 // Store the generated name as a hint attribute for future reference
1178 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1179 op->setAttr("sv.namehint", nameAttr);
1180 return nameAttr;
1181 }
1182
1183 // Multi-result operations or other cases get a generic name
1184 return "<unknown>";
1185 }
1186
1187 // Handle block arguments
1188 auto blockArg = cast<BlockArgument>(value);
1189 auto hwOp =
1190 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1191 if (!hwOp)
1192 return "<unknown>";
1193
1194 // Return the formal input name from the hardware module
1195 return hwOp.getInputName(blockArg.getArgNumber());
1196}
1197
// NOTE(review): the line carrying this function's signature (listing line
// 1198) is absent from this copy. The body uses `processingOrder`, `cutSets`
// and `logicNetwork`, so it is presumably a CutEnumerator member — confirm
// against the upstream source.
//
// Prints the enumerated cut sets to llvm::outs(). For every index in
// `processingOrder` that has a cut set, one line is written: the value's test
// name, the number of cuts, and then each cut as
// ` {<input names>}@t<truth table>d<arrival>`, where <arrival> is the largest
// arrival time of the cut's matched pattern, or "0" when the cut has no
// matched pattern.
1199 DenseMap<OperationName, unsigned> opCounter;
1200 for (auto index : processingOrder) {
1201 auto it = cutSets.find(index);
// Indices without a registered cut set are silently skipped.
1202 if (it == cutSets.end())
1203 continue;
1204 auto &cutSet = *it->second;
1205 mlir::Value value = logicNetwork.getValue(index);
1206 llvm::outs() << getTestVariableName(value, opCounter) << " "
1207 << cutSet.getCuts().size() << " cuts:";
1208 for (const Cut *cut : cutSet.getCuts()) {
1209 llvm::outs() << " {";
// Inputs are printed by name, comma separated, in the cut's stored order.
1210 llvm::interleaveComma(cut->inputs, llvm::outs(), [&](uint32_t inputIdx) {
1211 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1212 llvm::outs() << getTestVariableName(inputVal, opCounter);
1213 });
1214 auto &pattern = cut->getMatchedPattern();
// NOTE(review): the truth table optional is dereferenced without a check, so
// this presumably relies on every cut having a truth table here — confirm.
1215 llvm::outs() << "}"
1216 << "@t" << cut->getTruthTable()->table.getZExtValue() << "d";
1217 if (pattern) {
// Report the worst (largest) arrival time over the pattern's outputs.
1218 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
1219 pattern->getArrivalTimes().end());
1220 } else {
1221 llvm::outs() << "0";
1222 }
1223 }
1224 llvm::outs() << "\n";
1225 }
1226 llvm::outs() << "Cut enumeration completed successfully\n";
1227}
1228
1229//===----------------------------------------------------------------------===//
1230// CutRewriter
1231//===----------------------------------------------------------------------===//
1232
// Entry point of the cut rewriter. In order, it: rejects pattern sets that
// contain multi-output patterns, topologically sorts the logic under `topOp`,
// enumerates cuts, and finally runs the bottom-up rewrite. Returns failure as
// soon as any of these steps fails.
1233LogicalResult CutRewriter::run(Operation *topOp) {
1234 LLVM_DEBUG({
1235 llvm::dbgs() << "Starting Cut Rewriter\n";
1236 llvm::dbgs() << "Mode: "
// NOTE(review): listing line 1237 is absent from this copy; it presumably
// holds the condition and first operand of the conditional expression that
// the `: "timing")` below completes — confirm against the upstream source.
1238 : "timing")
1239 << "\n";
1240 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
1241 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
1242 });
1243
1244 // Currently we don't support patterns with multiple outputs.
1245 // So check that.
1246 // TODO: This must be removed when we support multiple outputs.
1247 for (auto &pattern : patterns.patterns) {
1248 if (pattern->getNumOutputs() > 1) {
1249 return mlir::emitError(pattern->getLoc(),
1250 "Cut rewriter does not support patterns with "
1251 "multiple outputs yet");
1252 }
1253 }
1254
1255 // First sort the operations topologically to ensure we can process them
1256 // in a valid order.
1257 if (failed(topologicallySortLogicNetwork(topOp)))
1258 return failure();
1259
1260 // Enumerate cuts for all nodes (initial delay-oriented selection)
1261 if (failed(enumerateCuts(topOp)))
1262 return failure();
1263
1264 // Dump cuts if testing priority cuts.
// NOTE(review): listing lines 1265-1266 are absent from this copy; they
// presumably open a conditional on a test option and print the cut sets
// before the early `return success()` below — confirm against upstream.
1267 return success();
1268 }
1269
1270 // Select best cuts and perform mapping
1271 if (failed(runBottomUpRewrite(topOp)))
1272 return failure();
1273
1274 return success();
1275}
1276
1277LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
1278 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
1279
1281 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
1282 // Match the cut against the patterns
1283 return patternMatchCut(cut);
1284 });
1285}
1286
1287ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1289 if (patterns.npnToPatternMap.empty())
1290 return {};
1291
1292 auto &npnClass = cut.getNPNClass(options.npnTable);
1293 auto it = patterns.npnToPatternMap.find(
1294 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1295 if (it == patterns.npnToPatternMap.end())
1296 return {};
1297 return it->getSecond();
1298}
1299
1300std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
1301 if (cut.isTrivialCut())
1302 return {};
1303
1304 const auto &network = cutEnumerator.getLogicNetwork();
1305 const CutRewritePattern *bestPattern = nullptr;
1306 SmallVector<DelayType, 4> inputArrivalTimes;
1307 SmallVector<DelayType, 1> bestArrivalTimes;
1308 double bestArea = 0.0;
1309 inputArrivalTimes.reserve(cut.getInputSize());
1310 bestArrivalTimes.reserve(cut.getOutputSize(network));
1311
1312 // Compute arrival times for each input.
1313 if (failed(cut.getInputArrivalTimes(cutEnumerator, inputArrivalTimes)))
1314 return {};
1315
1316 auto computeArrivalTimeAndPickBest =
1317 [&](const CutRewritePattern *pattern, const MatchResult &matchResult,
1318 llvm::function_ref<unsigned(unsigned)> mapIndex) {
1319 SmallVector<DelayType, 1> outputArrivalTimes;
1320 // Compute the maximum delay for each output from inputs.
1321 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize(network);
1322 outputIndex < outputSize; ++outputIndex) {
1323 // Compute the arrival time for this output.
1324 DelayType outputArrivalTime = 0;
1325 auto delays = matchResult.getDelays();
1326 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
1327 inputIndex < inputSize; ++inputIndex) {
1328 // Map pattern input i to cut input through NPN transformations
1329 unsigned cutOriginalInput = mapIndex(inputIndex);
1330 outputArrivalTime =
1331 std::max(outputArrivalTime,
1332 delays[outputIndex * inputSize + inputIndex] +
1333 inputArrivalTimes[cutOriginalInput]);
1334 }
1335
1336 outputArrivalTimes.push_back(outputArrivalTime);
1337 }
1338
1339 // Update the arrival time
1340 if (!bestPattern ||
1341 compareDelayAndArea(options.strategy, matchResult.area,
1342 outputArrivalTimes, bestArea,
1343 bestArrivalTimes)) {
1344 LLVM_DEBUG({
1345 llvm::dbgs() << "== Matched Pattern ==============\n";
1346 llvm::dbgs() << "Matching cut: \n";
1347 cut.dump(llvm::dbgs(), network);
1348 llvm::dbgs() << "Found better pattern: "
1349 << pattern->getPatternName();
1350 llvm::dbgs() << " with area: " << matchResult.area;
1351 llvm::dbgs() << " and input arrival times: ";
1352 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1353 llvm::dbgs() << " " << inputArrivalTimes[i];
1354 }
1355 llvm::dbgs() << " and arrival times: ";
1356
1357 for (auto arrivalTime : outputArrivalTimes) {
1358 llvm::dbgs() << " " << arrivalTime;
1359 }
1360 llvm::dbgs() << "\n";
1361 llvm::dbgs() << "== Matched Pattern End ==============\n";
1362 });
1363
1364 bestArrivalTimes = std::move(outputArrivalTimes);
1365 bestArea = matchResult.area;
1366 bestPattern = pattern;
1367 }
1368 };
1369
1370 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1371 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1372 "Pattern input size must match cut input size");
1373 auto matchResult = pattern->match(cutEnumerator, cut);
1374 if (!matchResult)
1375 continue;
1376 auto &cutNPN = cut.getNPNClass(options.npnTable);
1377
1378 // Get the input mapping from pattern's NPN class to cut's NPN class
1379 SmallVector<unsigned> inputMapping;
1380 cutNPN.getInputPermutation(patternNPN, inputMapping);
1381 computeArrivalTimeAndPickBest(pattern, *matchResult,
1382 [&](unsigned i) { return inputMapping[i]; });
1383 }
1384
1385 for (const CutRewritePattern *pattern : patterns.nonNPNPatterns) {
1386 if (auto matchResult = pattern->match(cutEnumerator, cut))
1387 computeArrivalTimeAndPickBest(pattern, *matchResult,
1388 [&](unsigned i) { return i; });
1389 }
1390
1391 if (!bestPattern)
1392 return {}; // No matching pattern found
1393
1394 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1395}
1396
// Applies the selected patterns to the IR. The enumerator's processing order
// is walked in reverse; for each index that has a cut set, the value's
// defining op is erased if the value is unused, always-cut-input nodes are
// skipped, and otherwise the best matched cut's pattern is asked to produce a
// replacement for the cut's root operation. Returns failure if a pattern's
// rewrite fails.
1397LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
1398 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
1399 const auto &network = cutEnumerator.getLogicNetwork();
1400 const auto &cutSets = cutEnumerator.getCutSets();
1401 auto processingOrder = cutEnumerator.getProcessingOrder();
1402
1403 // Note: Don't clear cutEnumerator yet - we need it during rewrite
1404 UnusedOpPruner pruner;
1405 PatternRewriter rewriter(top->getContext());
1406
1407 // Process in reverse topological order
1408 for (auto index : llvm::reverse(processingOrder)) {
1409 auto it = cutSets.find(index);
1410 if (it == cutSets.end())
1411 continue;
1412
1413 mlir::Value value = network.getValue(index);
1414 auto &cutSet = *it->second;
1415
// A value with no remaining uses needs no rewrite; its defining op (if any)
// is erased right away.
1416 if (value.use_empty()) {
1417 if (auto *op = value.getDefiningOp())
1418 pruner.eraseNow(op);
1419 continue;
1420 }
1421
1422 if (isAlwaysCutInput(network, index)) {
1423 // If the value is a primary input, skip it
1424 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
1425 continue;
1426 }
1427
1428 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
1429 auto *bestCut = cutSet.getBestMatchedCut();
1430 if (!bestCut) {
// NOTE(review): listing line 1431 is absent from this copy; it presumably
// holds a condition guarding the `continue` below, since otherwise the
// emitError that follows would be unreachable — confirm against upstream.
1432 continue; // No matching pattern found, skip this value
1433 return emitError(value.getLoc(), "No matching cut found for value: ")
1434 << value;
1435 }
1436
1437 // Get the root operation from LogicNetwork
1438 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
// New operations created by the pattern are inserted right before the root.
1439 rewriter.setInsertionPoint(rootOp);
1440 const auto &matchedPattern = bestCut->getMatchedPattern();
1441 auto result = matchedPattern->getPattern()->rewrite(rewriter, cutEnumerator,
1442 *bestCut);
1443 if (failed(result))
1444 return failure();
1445
1446 rewriter.replaceOp(rootOp, *result);
// NOTE(review): listing lines 1447 and 1449 are absent from this copy. Line
// 1449 presumably opens the conditional block closed at listing line 1452;
// the content of line 1447 cannot be determined from here — confirm both
// against the upstream source.
1448
1450 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1451 (*result)->setAttr("test.arrival_times", array);
1452 }
1453 }
1454
1455 // Clear the enumerator after rewriting is complete
// NOTE(review): listing line 1456 is absent from this copy; per the comment
// above it presumably clears the cut enumerator — confirm against upstream.
1457 return success();
1458}
assert(baseType &&"element must be base type")
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Strategy strategy
Precomputed NPN canonicalization table for 4-input single-output functions.
Definition TruthTable.h:168
bool lookup(const BinaryTruthTable &tt, NPNClass &result) const
Returns false if the given truth table shape is unsupported.
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void noteCutRewritten()
Record that one cut was successfully rewritten.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::APInt expandTruthTableToInputSpace(const llvm::APInt &tt, ArrayRef< unsigned > inputMapping, unsigned numExpandedInputs)
Expand a truth table to a larger input space using the given input mapping.
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:376
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Definition TruthTable.h:41
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.
Definition CutRewriter.h:75