CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18// "Improvements to technology mapping for LUT-based FPGAs", Alan Mishchenko,
19// Satrajit Chatterjee and Robert Brayton, FPGA 2006
20//
21//===----------------------------------------------------------------------===//
22
24
29#include "circt/Support/LLVM.h"
32#include "mlir/Analysis/TopologicalSortUtils.h"
33#include "mlir/IR/Builders.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/IR/RegionKindInterface.h"
36#include "mlir/IR/Value.h"
37#include "mlir/IR/ValueRange.h"
38#include "mlir/IR/Visitors.h"
39#include "mlir/Support/LLVM.h"
40#include "llvm/ADT/APInt.h"
41#include "llvm/ADT/Bitset.h"
42#include "llvm/ADT/DenseMap.h"
43#include "llvm/ADT/MapVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/ADT/iterator.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/ErrorHandling.h"
52#include "llvm/Support/LogicalResult.h"
53#include <algorithm>
54#include <functional>
55#include <memory>
56#include <optional>
57#include <string>
58
59#define DEBUG_TYPE "synth-cut-rewriter"
60
61using namespace circt;
62using namespace circt::synth;
63
64//===----------------------------------------------------------------------===//
65// LogicNetwork
66//===----------------------------------------------------------------------===//
67
68uint32_t LogicNetwork::getOrCreateIndex(Value value) {
69 auto [it, inserted] = valueToIndex.try_emplace(value, gates.size());
70 if (inserted) {
71 indexToValue.push_back(value);
72 gates.emplace_back(); // Will be filled in later
73 }
74 return it->second;
75}
76
77uint32_t LogicNetwork::getIndex(Value value) const {
78 const auto it = valueToIndex.find(value);
79 assert(it != valueToIndex.end() &&
80 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
81 "hasIndex first");
82 return it->second;
83}
84
85bool LogicNetwork::hasIndex(Value value) const {
86 return valueToIndex.contains(value);
87}
88
89Value LogicNetwork::getValue(uint32_t index) const {
90 // Index 0 and 1 are reserved for constants, they have no associated Value
91 if (index == kConstant0 || index == kConstant1)
92 return Value();
93
94 assert(index < indexToValue.size() &&
95 "Index out of bounds in LogicNetwork::getValue");
96 return indexToValue[index];
97}
98
99void LogicNetwork::getValues(ArrayRef<uint32_t> indices,
100 SmallVectorImpl<Value> &values) const {
101 values.clear();
102 values.reserve(indices.size());
103 for (uint32_t idx : indices)
104 values.push_back(getValue(idx));
105}
106
// Register `value` in the network through getOrCreateIndex and return the
// index it is bound to.
// NOTE(review): the embedded source numbering jumps from 108 to 110, so one
// line of this body is missing from this listing; check the upstream file
// before relying on what is shown here.
107uint32_t LogicNetwork::addPrimaryInput(Value value) {
108 const uint32_t index = getOrCreateIndex(value);
110 return index;
111}
112
113uint32_t LogicNetwork::addGate(Operation *op, LogicNetworkGate::Kind kind,
114 Value result, ArrayRef<Signal> operands) {
115 const uint32_t index = getOrCreateIndex(result);
116 gates[index] = LogicNetworkGate(op, kind, operands);
117 return index;
118}
119
// Populate this network from the operations of `block`.
//
// Block arguments are registered as primary inputs first. Each operation is
// then dispatched by type: the recognized logic ops become gates, 1-bit
// hw::ConstantOp results are aliased to the reserved constant indices, and
// every remaining i1 result without an index is registered as a primary
// input. Every handler visible here returns success().
// NOTE(review): embedded source lines 156, 177, 210 and 214 are missing from
// this listing (the tail of two lambda parameter lists and the trailing
// arguments of the OneHotOp/GambleOp cases).
120LogicalResult LogicNetwork::buildFromBlock(Block *block) {
121 // Pre-size vectors to reduce reallocations (rough estimate)
122 const size_t estimatedSize =
123 block->getArguments().size() + block->getOperations().size();
124 indexToValue.reserve(estimatedSize);
125 gates.reserve(estimatedSize);
126
// Handle a gate with exactly one operand.
127 auto handleSingleInputGate = [&](Operation *op, Value result,
128 const Signal &inputSignal) {
129 if (!inputSignal.isInverted()) {
130 // Non-inverted buffer: directly alias the result to the input
131 valueToIndex[result] = inputSignal.getIndex();
132 return;
133 }
134 // Inverted operation: recorded as an Identity gate over the inverted signal
135 addGate(op, LogicNetworkGate::Identity, result, {inputSignal});
136 };
137
138 // Ensure all block arguments are indexed as primary inputs first
139 for (Value arg : block->getArguments()) {
140 if (!hasIndex(arg))
141 addPrimaryInput(arg);
142 }
143
// Fallback: every i1 result of `op` that has no index yet becomes a primary
// input. Results that already have an index are left untouched.
144 auto handleOtherResults = [&](Operation *op) {
145 for (Value result : op->getResults()) {
146 if (result.getType().isInteger(1) && !hasIndex(result))
147 addPrimaryInput(result);
148 }
149 };
150
// Build the signal for operand `index` of `op`, carrying the op's
// per-operand inversion flag.
151 auto getInvertibleSignal = [&](auto op, unsigned index) {
152 return getOrCreateSignal(op.getOperand(index), op.isInverted(index));
153 };
154
155 auto handleInvertibleBinaryGate = [&](auto logicOp,
157 // The cut rewriter only has dedicated nodes for single-bit unary/binary
158 // gates. Wider or variadic forms stay as opaque cut inputs for now.
159 const auto inputs = logicOp.getInputs();
160 if (inputs.size() == 1) {
161 const Signal inputSignal = getInvertibleSignal(logicOp, 0);
162 handleSingleInputGate(logicOp, logicOp.getResult(), inputSignal);
163 return success();
164 }
165 if (inputs.size() == 2) {
166 const Signal lhsSignal = getInvertibleSignal(logicOp, 0);
167 const Signal rhsSignal = getInvertibleSignal(logicOp, 1);
168 addGate(logicOp, kind, {lhsSignal, rhsSignal});
// NOTE(review): this branch does not return; control continues into
// handleOtherResults below.
169 }
170 // Variadic gates with >2 inputs are treated as primary
171 // inputs for now.
172 handleOtherResults(logicOp);
173 return success();
174 };
175
176 auto handleInvertibleTernaryGate = [&](auto logicOp,
// Results that are not i1 are handed to the fallback instead of becoming a
// three-input gate.
178 if (!logicOp.getType().isInteger(1)) {
179 handleOtherResults(logicOp);
180 return success();
181 }
182 const Signal aSignal = getInvertibleSignal(logicOp, 0);
183 const Signal bSignal = getInvertibleSignal(logicOp, 1);
184 const Signal cSignal = getInvertibleSignal(logicOp, 2);
185 addGate(logicOp, kind, {aSignal, bSignal, cSignal});
186 return success();
187 };
188
189 // Process operations in block order (assumed to already be topologically
// sorted by the caller - TODO confirm)
190 for (Operation &op : block->getOperations()) {
191 LogicalResult result =
192 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
193 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
194 return handleInvertibleBinaryGate(andOp, LogicNetworkGate::And2);
195 })
196 .Case<synth::XorInverterOp>([&](synth::XorInverterOp xorOp) {
197 return handleInvertibleBinaryGate(xorOp, LogicNetworkGate::Xor2);
198 })
199 .Case<synth::MuxInverterOp>([&](synth::MuxInverterOp muxOp) {
200 return handleInvertibleTernaryGate(muxOp, LogicNetworkGate::Mux3);
201 })
202 .Case<synth::DotOp>([&](synth::DotOp dotOp) {
203 return handleInvertibleTernaryGate(dotOp, LogicNetworkGate::Dot3);
204 })
205 .Case<synth::MajorityOp>([&](synth::MajorityOp majOp) {
206 return handleInvertibleTernaryGate(majOp, LogicNetworkGate::Maj3);
207 })
208 .Case<synth::OneHotOp>([&](synth::OneHotOp oneHotOp) {
209 return handleInvertibleTernaryGate(oneHotOp,
211 })
212 .Case<synth::GambleOp>([&](synth::GambleOp gambleOp) {
213 return handleInvertibleTernaryGate(gambleOp,
215 })
216 .Case<comb::XorOp>([&](comb::XorOp xorOp) {
// Only the two-operand form becomes an Xor2 gate; both operands are taken
// non-inverted.
217 if (xorOp->getNumOperands() != 2) {
218 handleOtherResults(xorOp);
219 return success();
220 }
221 const Signal lhsSignal =
222 getOrCreateSignal(xorOp.getOperand(0), false);
223 const Signal rhsSignal =
224 getOrCreateSignal(xorOp.getOperand(1), false);
225 addGate(xorOp, LogicNetworkGate::Xor2, {lhsSignal, rhsSignal});
226 return success();
227 })
228 .Case<hw::ConstantOp>([&](hw::ConstantOp constOp) {
229 Value result = constOp.getResult();
230 if (!result.getType().isInteger(1)) {
231 handleOtherResults(constOp);
232 return success();
233 }
// A 1-bit constant gets no node of its own: its result is mapped straight
// onto the reserved constant index.
234 uint32_t constIdx =
235 constOp.getValue().isZero() ? kConstant0 : kConstant1;
236 valueToIndex[result] = constIdx;
237 return success();
238 })
239 .Case<synth::ChoiceOp>([&](synth::ChoiceOp choiceOp) {
240 if (!choiceOp.getType().isInteger(1)) {
241 handleOtherResults(choiceOp);
242 return success();
243 }
// The Choice gate is recorded with an empty operand list.
244 addGate(choiceOp, LogicNetworkGate::Choice, choiceOp.getResult(),
245 {});
246 return success();
247 })
248 .Default([&](Operation *defaultOp) {
249 handleOtherResults(defaultOp);
250 return success();
251 });
252
253 if (failed(result))
254 return result;
255 }
256
257 return success();
258}
259
// Reset the network: drop every value mapping and gate, then re-create the
// two reserved constant slots (index 0 and index 1) in both `gates` and
// `indexToValue`.
// NOTE(review): the signature line (embedded source line 260) is missing from
// this listing; only the body is visible.
261 valueToIndex.clear();
262 indexToValue.clear();
263 gates.clear();
264 // Re-add the constant nodes (index 0 = const0, index 1 = const1)
265 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
266 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
267 // Placeholders for constants in indexToValue
268 indexToValue.push_back(Value()); // const0
269 indexToValue.push_back(Value()); // const1
270}
271
272//===----------------------------------------------------------------------===//
273// Helper functions
274//===----------------------------------------------------------------------===//
275
276// Return true if the gate at the given index is always a cut input.
277static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index) {
278 const auto &gate = network.getGate(index);
279 return gate.isAlwaysCutInput();
280}
281
282// Return true if the new area/delay is better than the old area/delay in the
283// context of the given strategy.
//
// Two orderings are visible below: area first with delay as the tie-breaker,
// and delay first with area as the tie-breaker. Reaching the end without
// returning is treated as an unknown strategy.
// NOTE(review): embedded source lines 284, 287 and 289 are missing from this
// listing - the start of the signature and the conditions that select between
// the two return statements.
285 ArrayRef<DelayType> newDelay, double oldArea,
286 ArrayRef<DelayType> oldDelay) {
288 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
290 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
291 llvm_unreachable("Unknown mapping strategy");
292}
293
// Topologically sort the graph-region blocks under `topOp`. Emits an error
// at the location of `topOp` and returns it when the sort fails; returns
// success() otherwise.
// NOTE(review): embedded source line 298 - the body of the `isOperationReady`
// predicate - is missing from this listing.
294LogicalResult circt::synth::topologicallySortLogicNetwork(Operation *topOp) {
295 const auto isOperationReady = [](Value value, Operation *op) -> bool {
296 // Topologically sort AIG ops and dataflow ops. Other operations
297 // can be scheduled.
299 };
300
301 if (failed(topologicallySortGraphRegionBlocks(topOp, isOperationReady)))
302 return emitError(topOp->getLoc(),
303 "failed to sort operations topologically");
304 return success();
305}
306
307/// Get the truth table for operations within a block.
///
/// The block arguments act as the truth-table inputs and `values` as the
/// requested outputs. Returns a default-constructed table when the block has
/// no arguments and a zero-output table when `values` is empty. Emits an
/// error when more than one output is requested, when the input count
/// reaches `maxTruthTableInputs`, when an operation has other than one 1-bit
/// result, when an operand has not been evaluated yet, or when an operation
/// kind is not handled by the simulation below.
/// NOTE(review): embedded source line 310 (the declaration of `inputArgs`)
/// is missing from this listing.
308FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
309 Block *block) {
311 for (Value arg : block->getArguments())
312 inputArgs.insert(arg);
313
314 if (inputArgs.empty())
315 return BinaryTruthTable();
316
317 const int64_t numInputs = inputArgs.size();
318 const int64_t numOutputs = values.size();
319 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
320 if (numOutputs == 0)
321 return BinaryTruthTable(numInputs, 0);
322 if (numInputs >= maxTruthTableInputs)
323 return mlir::emitError(values.front().getLoc(),
324 "Truth table is too large");
325 return mlir::emitError(values.front().getLoc(),
326 "Multiple outputs are not supported yet");
327 }
328
329 // Create a map to evaluate the operation
// Each block argument is seeded with the variable mask produced by
// createVarMask for its position.
330 DenseMap<Value, APInt> eval;
331 for (uint32_t i = 0; i < numInputs; ++i)
332 eval[inputArgs[i]] = circt::createVarMask(numInputs, i, true);
333
334 // Simulate the operations in the block
335 for (Operation &op : block->without_terminator()) {
336 if (op.getNumResults() != 1 ||
337 hw::getBitWidth(op.getResult(0).getType()) != 1)
338 return op.emitError("Unsupported operation for truth table simulation");
339
340 if (auto choiceOp = dyn_cast<synth::ChoiceOp>(&op)) {
// A choice takes the table of its first input.
341 auto it = eval.find(choiceOp.getInputs().front());
342 if (it == eval.end())
343 return choiceOp.emitError("Input value not found in evaluation map");
344 eval[choiceOp.getResult()] = it->second;
345 } else if (auto logicOp = dyn_cast<BooleanLogicOpInterface>(&op)) {
// Check every input up front so the callback below can dereference the
// result of find() without a further check.
346 for (auto value : logicOp.getInputs())
347 if (!eval.contains(value))
348 return logicOp->emitError("Input value not found in evaluation map");
349
350 eval[logicOp.getResult()] =
351 logicOp.evaluateBooleanLogic([&](unsigned i) -> const APInt & {
352 return eval.find(logicOp.getInput(i))->second;
353 });
354 } else if (auto xorOp = dyn_cast<comb::XorOp>(&op)) {
355 // TODO: Define Xor as Synth op.
// XOR the tables of all operands together, starting from operand 0.
356 auto it = eval.find(xorOp.getOperand(0));
357 if (it == eval.end())
358 return xorOp.emitError("Input value not found in evaluation map");
359 llvm::APInt result = it->second;
360 for (unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
361 it = eval.find(xorOp.getOperand(i));
362 if (it == eval.end())
363 return xorOp.emitError("Input value not found in evaluation map");
364 result ^= it->second;
365 }
366 eval[xorOp.getResult()] = result;
367 } else if (auto constantOp = dyn_cast<hw::ConstantOp>(&op)) {
// A constant becomes an all-zero or all-one table of 2^numInputs bits.
368 auto tableSize = 1ULL << numInputs;
369 eval[constantOp.getResult()] = constantOp.getValue().isZero()
370 ? llvm::APInt::getZero(tableSize)
371 : llvm::APInt::getAllOnes(tableSize);
372 } else {
373 return op.emitError("Unsupported operation for truth table simulation");
374 }
375 }
376
// NOTE(review): `eval[...]` is a subscript lookup, not find(); presumably it
// inserts a default APInt if values[0] was never evaluated - confirm that
// callers only pass values produced inside the block (or block arguments).
377 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
378}
379
380//===----------------------------------------------------------------------===//
381// Cut
382//===----------------------------------------------------------------------===//
383
384bool Cut::isTrivialCut() const {
385 // A cut is a trivial cut if it has no root (rootIndex == 0 sentinel)
386 // and only one input
387 return rootIndex == 0 && inputs.size() == 1;
388}
389
390const NPNClass &Cut::getNPNClass() const { return getNPNClass(nullptr); }
391
// Return the NPN class of this cut, computing and caching it in `npnClass`
// on first use. When `npnTable` is non-null its lookup() is tried first.
// NOTE(review): embedded source line 399 - the statement guarded by the `if`
// below, i.e. what happens when no table is given or the lookup fails - is
// missing from this listing.
392const NPNClass &Cut::getNPNClass(const NPNTable *npnTable) const {
// Fast path: already cached.
393 if (npnClass)
394 return *npnClass;
395
396 const auto &truthTable = *getTruthTable();
397 NPNClass canonicalForm;
398 if (!npnTable || !npnTable->lookup(truthTable, canonicalForm))
400
401 npnClass.emplace(std::move(canonicalForm));
402 return *npnClass;
403}
404
// Fill `permutedIndices` with the input permutation between this cut's NPN
// class (obtained via getNPNClass(npnTable)) and `patternNPN`, as computed
// by NPNClass::getInputPermutation.
// NOTE(review): the first signature line (embedded source line 405) is
// missing from this listing.
406 const NPNTable *npnTable, const NPNClass &patternNPN,
407 SmallVectorImpl<unsigned> &permutedIndices) const {
408 const auto &npnClass = getNPNClass(npnTable);
409 npnClass.getInputPermutation(patternNPN, permutedIndices);
410}
411
// Append one arrival time per cut input to `results` (the vector is reserved
// but not cleared here). Returns failure() as soon as an input has no best
// matched cut; success() otherwise.
// NOTE(review): embedded source line 413 (the function name and first
// parameter, which declares `enumerator`) is missing from this listing.
412LogicalResult
414 SmallVectorImpl<DelayType> &results) const {
415 results.reserve(getInputSize());
416 const auto &network = enumerator.getLogicNetwork();
417
418 // Compute arrival times for each input.
419 for (auto inputIndex : inputs) {
420 if (isAlwaysCutInput(network, inputIndex)) {
421 // Inputs that are always cut inputs contribute no delay.
422 results.push_back(0);
423 continue;
424 }
425 auto *cutSet = enumerator.getCutSet(inputIndex);
426 assert(cutSet && "Input must have a valid cut set");
427
428 // If there is no matching pattern, it means it's not possible to use the
429 // input in the cut rewriting. Report failure to the caller.
430 auto *bestCut = cutSet->getBestMatchedCut();
431 if (!bestCut)
432 return failure();
433
434 const auto &matchedPattern = *bestCut->getMatchedPattern();
435
436 // Get the value for result number lookup
437 mlir::Value inputValue = network.getValue(inputIndex);
438 // Otherwise, the cut input is an op result. Get the arrival time
439 // from the matched pattern.
440 results.push_back(matchedPattern.getArrivalTime(
441 cast<mlir::OpResult>(inputValue).getResultNumber()));
442 }
443
444 return success();
445}
446
447void Cut::dump(llvm::raw_ostream &os, const LogicNetwork &network) const {
448 os << "// === Cut Dump ===\n";
449 os << "Cut with " << getInputSize() << " inputs";
450 if (rootIndex != 0) {
451 auto *rootOp = network.getGate(rootIndex).getOperation();
452 if (rootOp)
453 os << " and root: " << *rootOp;
454 }
455 os << "\n";
456
457 if (isTrivialCut()) {
458 mlir::Value inputVal = network.getValue(inputs[0]);
459 os << "Primary input cut: " << inputVal << "\n";
460 return;
461 }
462
463 os << "Inputs (indices): \n";
464 for (auto [idx, inputIndex] : llvm::enumerate(inputs)) {
465 mlir::Value inputVal = network.getValue(inputIndex);
466 os << " Input " << idx << " (index " << inputIndex << "): " << inputVal
467 << "\n";
468 }
469
470 if (rootIndex != 0) {
471 os << "\nRoot operation: \n";
472 if (auto *rootOp = network.getGate(rootIndex).getOperation())
473 rootOp->print(os);
474 os << "\n";
475 }
476
477 auto &npnClass = getNPNClass();
478 npnClass.dump(os);
479
480 os << "// === Cut End ===\n";
481}
482
483unsigned Cut::getInputSize() const { return inputs.size(); }
484
485unsigned Cut::getOutputSize(const LogicNetwork &network) const {
486 if (rootIndex == 0)
487 return 1; // Trivial cut has 1 output
488 auto *rootOp = network.getGate(rootIndex).getOperation();
489 return rootOp ? rootOp->getNumResults() : 1;
490}
491
492/// Simulate a gate and return its truth table.
///
/// Unary overload: the one visible case passes `a` through unchanged; any
/// other kind is unreachable.
/// NOTE(review): the case label (embedded source line 496) is missing from
/// this listing.
493static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
494 const llvm::APInt &a) {
495 switch (kind) {
497 return a;
498 default:
499 llvm_unreachable("Unsupported unary operation for truth table computation");
500 }
501}
502
// Binary overload: combines the operand truth tables `a` and `b` with a
// bitwise AND or a bitwise XOR depending on `kind`; any other kind is
// unreachable.
// NOTE(review): the two case labels (embedded source lines 507 and 509) are
// missing from this listing.
503static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
504 const llvm::APInt &a,
505 const llvm::APInt &b) {
506 switch (kind) {
508 return a & b;
510 return a ^ b;
511 default:
512 llvm_unreachable(
513 "Unsupported binary operation for truth table computation");
514 }
515}
516
// Ternary overload: forwards the operand truth tables `a`, `b`, `c` to one
// of the evaluate*Logic helpers (mux, majority, dot, one-hot, gamble)
// depending on `kind`; any other kind is unreachable.
// NOTE(review): the case labels (embedded source lines 522, 524, 526, 528
// and 530) are missing from this listing.
517static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
518 const llvm::APInt &a,
519 const llvm::APInt &b,
520 const llvm::APInt &c) {
521 switch (kind) {
523 return evaluateMuxLogic(a, b, c);
525 return evaluateMajorityLogic(a, b, c);
527 return evaluateDotLogic(a, b, c);
529 return evaluateOneHotLogic(a, b, c);
531 return evaluateGambleLogic(a, b, c);
532 default:
533 llvm_unreachable(
534 "Unsupported ternary operation for truth table computation");
535 }
536}
537
538namespace {
539
540// Helper class to build a merged truth table for a cut based on its operand
541// cuts
//
// `mergedInputs` is the sorted, duplicate-free input list of the cut being
// built (checked by the constructor asserts); `operandCuts` are the cuts
// chosen for the root gate's operands. Both are held as ArrayRef views.
// NOTE(review): embedded source lines 591, 624-625, 629-633 and 638 are
// missing from this listing (the call in expandCutTruthTable and the case
// labels in computeForGate).
542struct MergedTruthTableBuilder {
543 MergedTruthTableBuilder(ArrayRef<uint32_t> mergedInputs,
544 ArrayRef<const Cut *> operandCuts)
545 : mergedInputs(mergedInputs), numMergedInputs(mergedInputs.size()),
546 operandCuts(operandCuts) {
547 assert(llvm::is_sorted(mergedInputs) && "merged inputs must be sorted");
548 assert(llvm::adjacent_find(mergedInputs) == mergedInputs.end() &&
549 "merged inputs must be unique");
550 }
551
552 ArrayRef<uint32_t> mergedInputs;
553 unsigned numMergedInputs;
554 ArrayRef<const Cut *> operandCuts;
555
// Position of `operandIdx` within `mergedInputs`, or nullopt when it is not
// one of the merged inputs.
556 std::optional<unsigned> findMergedInputPosition(uint32_t operandIdx) const {
557 auto *it = llvm::find(mergedInputs, operandIdx);
558 if (it == mergedInputs.end())
559 return std::nullopt;
560 return static_cast<unsigned>(std::distance(mergedInputs.begin(), it));
561 }
562
// First operand cut whose output node is `operandIdx`, or nullptr. The
// output of a trivial cut is its single input; otherwise it is the cut's
// root index. Null entries in `operandCuts` are skipped.
563 const Cut *findOperandCut(uint32_t operandIdx) const {
564 for (const Cut *cut : operandCuts) {
565 if (!cut)
566 continue;
567 uint32_t cutOutput =
568 cut->isTrivialCut() ? cut->inputs[0] : cut->getRootIndex();
569 if (cutOutput == operandIdx)
570 return cut;
571 }
572 return nullptr;
573 }
574
// For every input of `cut`, record its position within `mergedInputs`.
// `mapping` is cleared first; each cut input must be present in
// `mergedInputs` (asserted).
575 void getInputMapping(const Cut *cut,
576 SmallVectorImpl<unsigned> &mapping) const {
577 mapping.clear();
578 mapping.reserve(cut->inputs.size());
579 for (uint32_t idx : cut->inputs) {
580 auto *it = llvm::find(mergedInputs, idx);
581 assert(it != mergedInputs.end() &&
582 "cut input must exist in merged inputs");
583 mapping.push_back(static_cast<unsigned>(it - mergedInputs.begin()));
584 }
585 }
586
// Re-express the truth table of `cut` over the merged input space.
587 llvm::APInt expandCutTruthTable(const Cut *cut) const {
588 const auto &cutTT = *cut->getTruthTable();
589 SmallVector<unsigned, 8> inputMapping;
590 getInputMapping(cut, inputMapping);
592 cutTT.table, inputMapping, numMergedInputs);
593 }
594
// Truth table of one gate operand over the merged inputs, with all bits
// flipped when the edge is inverted. Resolution order: constant 0,
// constant 1, direct merged input, operand cut. Anything else is
// unreachable.
595 llvm::APInt expandOperand(uint32_t operandIdx, bool isInverted) const {
596 llvm::APInt result(1, 0);
597 if (operandIdx == LogicNetwork::kConstant0) {
598 result = llvm::APInt::getZero(1U << numMergedInputs);
599 } else if (operandIdx == LogicNetwork::kConstant1) {
600 result = llvm::APInt::getAllOnes(1U << numMergedInputs);
601 } else if (auto pos = findMergedInputPosition(operandIdx)) {
602 // Direct cut inputs already live in the merged input space.
603 result = circt::createVarMask(numMergedInputs, *pos, true);
604 } else if (const Cut *cut = findOperandCut(operandIdx)) {
605 // Internal operands reuse the operand cut truth table after expanding it
606 // to this root cut's merged input space.
607 result = expandCutTruthTable(cut);
608 } else {
609 llvm_unreachable("Operand not found in cuts or merged inputs");
610 }
611
612 if (isInverted)
613 result.flipAllBits();
614 return result;
615 }
616
// Single-output truth table of `rootGate` over the merged inputs: each edge
// is expanded with expandOperand and the results are combined through the
// applyGateSemantics overload with the matching number of operands.
617 BinaryTruthTable computeForGate(const LogicNetworkGate &rootGate) const {
618 auto getEdgeTT = [&](unsigned edgeIdx) {
619 const auto &edge = rootGate.edges[edgeIdx];
620 return expandOperand(edge.getIndex(), edge.isInverted());
621 };
622
623 switch (rootGate.getKind()) {
626 return BinaryTruthTable(
627 numMergedInputs, 1,
628 applyGateSemantics(rootGate.getKind(), getEdgeTT(0), getEdgeTT(1)));
634 return BinaryTruthTable(numMergedInputs, 1,
635 applyGateSemantics(rootGate.getKind(),
636 getEdgeTT(0), getEdgeTT(1),
637 getEdgeTT(2)));
639 return BinaryTruthTable(
640 numMergedInputs, 1,
641 applyGateSemantics(rootGate.getKind(), getEdgeTT(0)));
642 default:
643 llvm_unreachable("Unsupported operation for truth table computation");
644 }
645 }
646};
647
648} // namespace
649
// Compute and store this cut's truth table from its operand cuts. Trivial
// cuts are expected to carry a pre-set table and are left untouched; other
// cuts must have operand cuts, and their table is built by
// MergedTruthTableBuilder for the root gate.
// NOTE(review): the signature line (embedded source line 650, which declares
// `network`) is missing from this listing.
651 if (isTrivialCut()) {
652 assert(truthTable && "trivial cuts should have their truth table pre-set");
653 return;
654 }
655
656 assert(!operandCuts.empty() &&
657 "non-trivial cuts must carry operand cuts for truth table expansion");
658
659 const auto &rootGate = network.getGate(rootIndex);
660 truthTable.emplace(
661 MergedTruthTableBuilder(inputs, operandCuts).computeForGate(rootGate));
662}
663
664bool Cut::dominates(const Cut &other) const {
665 return dominates(other.inputs, other.signature);
666}
667
668bool Cut::dominates(ArrayRef<uint32_t> otherInputs, uint64_t otherSig) const {
669
670 if (getInputSize() > otherInputs.size())
671 return false;
672
673 if ((signature & otherSig) != signature)
674 return false;
675
676 return std::includes(otherInputs.begin(), otherInputs.end(), inputs.begin(),
677 inputs.end());
678}
679
680Cut Cut::getTrivialCut(uint32_t index) {
681 Cut cut;
682 cut.inputs.push_back(index);
683 // The truth table for a trivial cut is just the identity function on its
684 // single input.
685 cut.setTruthTable(BinaryTruthTable(1, 1, llvm::APInt(2, 2)));
686 cut.setSignature(1ULL << (index % 64)); // Set signature bit for this input
687 return cut;
688}
689
690//===----------------------------------------------------------------------===//
691// MatchedPattern
692//===----------------------------------------------------------------------===//
693
694ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
695 assert(pattern && "Pattern must be set to get arrival time");
696 return arrivalTimes;
697}
698
// Return the arrival time stored at position `index`. Asserts that a
// pattern is set; `index` itself is not range-checked here.
// NOTE(review): the signature line (embedded source line 699) is missing
// from this listing.
700 assert(pattern && "Pattern must be set to get arrival time");
701 return arrivalTimes[index];
702}
703
// Return the stored pattern. Asserts that a pattern is set.
// NOTE(review): the signature line (embedded source line 704) is missing
// from this listing.
705 assert(pattern && "Pattern must be set to get the pattern");
706 return pattern;
707}
708
// Return the stored area. Asserts that a pattern is set.
// NOTE(review): the signature line (embedded source line 709) is missing
// from this listing.
710 assert(pattern && "Pattern must be set to get area");
711 return area;
712}
713
714//===----------------------------------------------------------------------===//
715// CutSet
716//===----------------------------------------------------------------------===//
717
719
720unsigned CutSet::size() const { return cuts.size(); }
721
722void CutSet::addCut(Cut *cut) {
723 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
724 cuts.push_back(cut);
725}
726
727ArrayRef<Cut *> CutSet::getCuts() const { return cuts; }
728
729// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
730// exists another cut that is a subset of it.
731static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut *> &cuts) {
732 auto dumpInputs = [](llvm::raw_ostream &os,
733 const llvm::SmallVectorImpl<uint32_t> &inputs) {
734 os << "{";
735 llvm::interleaveComma(inputs, os);
736 os << "}";
737 };
738 // Sort by size, then lexicographically by inputs. This enables cheap exact
739 // duplicate elimination and tighter candidate filtering for subset checks.
740 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut *a, const Cut *b) {
741 if (a->getInputSize() != b->getInputSize())
742 return a->getInputSize() < b->getInputSize();
743 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
744 b->inputs.begin(), b->inputs.end());
745 });
746
747 // Group kept cuts by input size so subset checks only visit smaller cuts.
748 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
749 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
750
751 // Compact kept cuts in-place.
752 unsigned uniqueCount = 0;
753 for (Cut *cut : cuts) {
754 unsigned cutSize = cut->getInputSize();
755
756 // Fast exact duplicate check: with lexicographic sort, duplicates are
757 // adjacent among cuts with equal size.
758 if (uniqueCount > 0) {
759 Cut *lastKept = cuts[uniqueCount - 1];
760 if (lastKept->getInputSize() == cutSize &&
761 lastKept->inputs == cut->inputs)
762 continue;
763 }
764
765 bool isDominated = false;
766 for (unsigned existingSize = 1; existingSize < cutSize && !isDominated;
767 ++existingSize) {
768 for (const Cut *existingCut : keptBySize[existingSize]) {
769 if (!existingCut->dominates(*cut))
770 continue;
771
772 LLVM_DEBUG({
773 llvm::dbgs() << "Dropping non-minimal cut ";
774 dumpInputs(llvm::dbgs(), cut->inputs);
775 llvm::dbgs() << " due to subset ";
776 dumpInputs(llvm::dbgs(), existingCut->inputs);
777 llvm::dbgs() << "\n";
778 });
779 isDominated = true;
780 break;
781 }
782 }
783
784 if (isDominated)
785 continue;
786
787 cuts[uniqueCount++] = cut;
788 keptBySize[cutSize].push_back(cut);
789 }
790
791 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
792 << " Unique cuts: " << uniqueCount << "\n");
793
794 // Resize the cuts vector to the number of surviving cuts.
795 cuts.resize(uniqueCount);
796}
797
// Finalize this cut set: make sure every cut has a truth table, try to match
// each cut with `matchCut`, order the cuts (trivial cuts first, then by the
// comparator below), truncate to `options.maxCutSizePerRoot`, pick the first
// remaining matched cut as `bestCut`, and freeze the set.
// NOTE(review): embedded source lines 798 (start of the signature) and 805
// (the statement announced by the "Remove duplicate/non-minimal cuts"
// comment) are missing from this listing.
799 const CutRewriterOptions &options,
800 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
801 const LogicNetwork &logicNetwork) {
802
803 // Remove duplicate/non-minimal cuts first so all follow-up work only runs on
804 // survivors.
806
807 // Compute truth tables lazily, then match cuts to collect timing/area data.
808 for (Cut *cut : cuts) {
809 if (!cut->getTruthTable().has_value())
810 cut->computeTruthTableFromOperands(logicNetwork);
811
812 assert(cut->getInputSize() <= options.maxCutInputSize &&
813 "Cut input size exceeds maximum allowed size");
814
815 if (auto matched = matchCut(*cut))
816 cut->setMatchedPattern(std::move(*matched));
817 }
818
819 // Sort cuts by priority to select the most promising ones.
820 // Priority is determined by the optimization strategy:
821 // - Trivial cuts (direct connections) have highest priority
822 // - Among matched cuts, compare by area/delay based on the strategy
823 // - Matched cuts are preferred over unmatched cuts
824 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
825 // et al., ICCAD 2007 for more details.
826 // TODO: Use a priority queue instead of sorting for better performance.
827
828 // Partition the cuts into trivial and non-trivial cuts.
829 auto *trivialCutsEnd =
830 std::stable_partition(cuts.begin(), cuts.end(),
831 [](const Cut *cut) { return cut->isTrivialCut(); });
832
// Ordering for non-trivial cuts: two matched cuts are compared by
// area/delay under the strategy, a matched cut beats an unmatched one, and
// two unmatched cuts are ordered by input count.
833 auto isBetterCut = [&options](const Cut *a, const Cut *b) {
834 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
835 "Trivial cuts should have been excluded");
836 const auto &aMatched = a->getMatchedPattern();
837 const auto &bMatched = b->getMatchedPattern();
838
839 if (aMatched && bMatched)
840 return compareDelayAndArea(
841 options.strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
842 bMatched->getArea(), bMatched->getArrivalTimes());
843
844 if (static_cast<bool>(aMatched) != static_cast<bool>(bMatched))
845 return static_cast<bool>(aMatched);
846
847 return a->getInputSize() < b->getInputSize();
848 };
849 std::stable_sort(trivialCutsEnd, cuts.end(), isBetterCut);
850
851 // Keep only the top-K cuts to bound growth.
852 if (cuts.size() > options.maxCutSizePerRoot)
853 cuts.resize(options.maxCutSizePerRoot);
854
855 // Select the best cut from the remaining candidates.
// This is the first cut, in the order established above, that carries a
// matched pattern; `bestCut` stays null when none does.
856 bestCut = nullptr;
857 for (Cut *cut : cuts) {
858 const auto &currentMatch = cut->getMatchedPattern();
859 if (!currentMatch)
860 continue;
861 bestCut = cut;
862 break;
863 }
864
865 LLVM_DEBUG({
866 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
867 << (bestCut
868 ? "matched pattern to " + bestCut->getMatchedPattern()
869 ->getPattern()
870 ->getPatternName()
871 : "no matched pattern")
872 << "\n";
873 });
874
875 isFrozen = true; // Mark the cut set as frozen
876}
877
878//===----------------------------------------------------------------------===//
879// CutRewritePattern
880//===----------------------------------------------------------------------===//
881
// This implementation always returns false and leaves `matchingNPNClasses`
// untouched.
// NOTE(review): the first signature line (embedded source line 882) is
// missing from this listing.
883 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
884 return false;
885}
886
887//===----------------------------------------------------------------------===//
888// CutRewritePatternSet
889//===----------------------------------------------------------------------===//
890
// Take ownership of `patterns` and index them for matching. A pattern whose
// useTruthTableMatcher() returns true is registered in `npnToPatternMap`
// under every NPN class it reported, keyed by (truth table, input count);
// every other pattern goes into `nonNPNPatterns`.
// NOTE(review): the first signature line (embedded source line 891) is
// missing from this listing.
892 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
893 : patterns(std::move(patterns)) {
894 // Initialize the NPN to pattern map
895 for (auto &pattern : this->patterns) {
896 SmallVector<NPNClass, 2> npnClasses;
897 auto result = pattern->useTruthTableMatcher(npnClasses);
898 if (result) {
// `npnClass` is a per-iteration copy, so moving it into the map entry below
// leaves `npnClasses` itself intact.
899 for (auto npnClass : npnClasses) {
900 // Create a NPN class from the truth table
901 npnToPatternMap[{npnClass.truthTable.table,
902 npnClass.truthTable.numInputs}]
903 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
904 }
905 } else {
906 // If the pattern does not provide NPN classes, we use a special key
907 // to indicate that it should be considered for all cuts.
908 nonNPNPatterns.push_back(pattern.get());
909 }
910 }
911}
912
913//===----------------------------------------------------------------------===//
914// CutEnumerator
915//===----------------------------------------------------------------------===//
916
// Constructor initializer list: both allocators are handed a counter from
// `stats`, and `options` is stored. The constructor body is empty.
// NOTE(review): the signature line (embedded source line 917, which declares
// `options` and where `stats` comes from) is missing from this listing.
918 : cutAllocator(stats.numCutsCreated),
919 cutSetAllocator(stats.numCutSetsCreated), options(options) {}
920
// Allocate a new CutSet, register it in `cutSets` under `index`, and return
// it. Asserts that no cut set was registered for that index before.
// NOTE(review): the signature line (embedded source line 921, which declares
// `index`) is missing from this listing.
922 CutSet *cutSet = cutSetAllocator.create();
923 auto [cutSetPtr, inserted] = cutSets.try_emplace(index, cutSet);
924 assert(inserted && "Cut set already exists for this index");
925 return cutSetPtr->second;
926}
927
// Reset the enumerator: forget all cut sets and the processing order, then
// call DestroyAll() on both allocators.
// NOTE(review): embedded source lines 928 (the signature) and 931 (one
// statement between the two groups below) are missing from this listing.
929 cutSets.clear();
930 processingOrder.clear();
932 cutAllocator.DestroyAll();
933 cutSetAllocator.DestroyAll();
934}
935
936LogicalResult CutEnumerator::visitLogicOp(uint32_t nodeIndex) {
937 const auto &gate = logicNetwork.getGate(nodeIndex);
938 auto *logicOp = gate.getOperation();
939 assert(logicOp && logicOp->getNumResults() == 1 &&
940 "Logic operation must have a single result");
941
942 if (gate.getKind() == LogicNetworkGate::Choice) {
943 auto choiceOp = cast<synth::ChoiceOp>(logicOp);
944 auto *resultCutSet = createNewCutSet(nodeIndex);
945 Cut *primaryInputCut = cutAllocator.create(Cut::getTrivialCut(nodeIndex));
946 processingOrder.push_back(nodeIndex);
947 resultCutSet->addCut(primaryInputCut);
948
949 for (Value operand : choiceOp.getInputs()) {
950 auto *operandCutSet = getCutSet(logicNetwork.getIndex(operand));
951 if (!operandCutSet)
952 return logicOp->emitError("Failed to get cut set for choice operand");
953
954 // Choice nodes do not introduce new logic. They forward each non-trivial
955 // operand cut as an equivalent alternative for the same root.
956 for (const Cut *operandCut : operandCutSet->getCuts()) {
957 if (operandCut->isTrivialCut())
958 continue;
959
960 resultCutSet->addCut(cutAllocator.create(
961 nodeIndex, operandCut->inputs, operandCut->getSignature(),
962 ArrayRef<const Cut *>{operandCut}, *operandCut->getTruthTable()));
963 }
964 }
965
966 // Finalize cut set: remove duplicates, limit size, and match patterns
967 resultCutSet->finalize(options, matchCut, logicNetwork);
968 return success();
969 }
970
971 unsigned numFanins = gate.getNumFanins();
972
973 // Validate operation constraints
974 // TODO: Variadic operations and non-single-bit results can be supported
975 if (numFanins > 3)
976 return logicOp->emitError("Cut enumeration supports at most 3 operands, "
977 "found: ")
978 << numFanins;
979 if (!logicOp->getOpResult(0).getType().isInteger(1))
980 return logicOp->emitError()
981 << "Supported logic operations must have a single bit "
982 "result type but found: "
983 << logicOp->getResult(0).getType();
984
985 // A vector to hold cut sets for each operand along with their max cut input
986 // size.
987 SmallVector<std::pair<const CutSet *, unsigned>, 2> operandCutSets;
988 operandCutSets.reserve(numFanins);
989
990 // Collect cut sets for each fanin (using LogicNetwork edges)
991 for (unsigned i = 0; i < numFanins; ++i) {
992 uint32_t faninIndex = gate.edges[i].getIndex();
993 auto *operandCutSet = getCutSet(faninIndex);
994 if (!operandCutSet)
995 return logicOp->emitError("Failed to get cut set for fanin index ")
996 << faninIndex;
997
998 // Find the largest cut size among the operand's cuts for sorting heuristic
999 // later.
1000 unsigned maxInputCutSize = 0;
1001 for (auto *cut : operandCutSet->getCuts())
1002 maxInputCutSize = std::max(maxInputCutSize, cut->getInputSize());
1003 operandCutSets.push_back(std::make_pair(operandCutSet, maxInputCutSize));
1004 }
1005
1006 // Create the trivial cut for this node's output
1007 Cut *primaryInputCut = cutAllocator.create(Cut::getTrivialCut(nodeIndex));
1008
1009 auto *resultCutSet = createNewCutSet(nodeIndex);
1010 processingOrder.push_back(nodeIndex);
1011 resultCutSet->addCut(primaryInputCut);
1012
1013 // Sort operand cut sets by their largest cut size in descending order. This
1014 // heuristic improves efficiency of the k-way merge when generating cuts for
1015 // the current node by maximizing the chance of early pruning when the merged
1016 // cut exceeds the input size limit.
1017 llvm::stable_sort(operandCutSets,
1018 [](const std::pair<const CutSet *, unsigned> &a,
1019 const std::pair<const CutSet *, unsigned> &b) {
1020 return a.second > b.second;
1021 });
1022
1023 // Cache maxCutInputSize to avoid repeated access
1024 unsigned maxInputSize = options.maxCutInputSize;
1025
1026 // This lambda generates nested loops at runtime to iterate over all
1027 // combinations of cuts from N operands
1028 auto enumerateCutCombinations = [&](auto &&self, unsigned operandIdx,
1029 SmallVector<const Cut *, 3> &cutPtrs,
1030 uint64_t currentSig) -> void {
1031 // Base case: all operands processed, create merged cut
1032 if (operandIdx == numFanins) {
1033 // Efficient k-way merge: inputs are sorted, so dedup and constant
1034 // filtering can be done while merging. Abort early once we exceed the
1035 // cut-size limit to avoid building doomed merged cuts.
1036 SmallVector<uint32_t, 6> mergedInputs;
1037 auto appendMergedInput = [&](uint32_t value) {
1038 if (value == LogicNetwork::kConstant0 ||
1039 value == LogicNetwork::kConstant1)
1040 return true;
1041 if (!mergedInputs.empty() && mergedInputs.back() == value)
1042 return true;
1043 mergedInputs.push_back(value);
1044 return mergedInputs.size() <= maxInputSize;
1045 };
1046
1047 if (numFanins == 1) {
1048 // Single input: copy while filtering constants.
1049 mergedInputs.reserve(
1050 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
1051 for (uint32_t value : cutPtrs[0]->inputs)
1052 if (!appendMergedInput(value))
1053 return;
1054 } else if (numFanins == 2) {
1055 // Two-way merge (common case for AND gates)
1056 const auto &inputs0 = cutPtrs[0]->inputs;
1057 const auto &inputs1 = cutPtrs[1]->inputs;
1058 mergedInputs.reserve(
1059 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
1060
1061 unsigned i = 0, j = 0;
1062 while (i < inputs0.size() || j < inputs1.size()) {
1063 uint32_t next;
1064 if (j == inputs1.size() ||
1065 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
1066 next = inputs0[i++];
1067 if (j < inputs1.size() && inputs1[j] == next)
1068 ++j;
1069 } else {
1070 next = inputs1[j++];
1071 }
1072 if (!appendMergedInput(next))
1073 return;
1074 }
1075 } else {
1076 // Three-way merge (for MAJ/MUX gates)
1077 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
1078 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
1079 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
1080 mergedInputs.reserve(std::min<size_t>(
1081 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
1082
1083 unsigned i = 0, j = 0, k = 0;
1084 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
1085 // Find minimum among available elements
1086 uint32_t minVal = UINT32_MAX;
1087 if (i < inputs0.size())
1088 minVal = std::min(minVal, inputs0[i]);
1089 if (j < inputs1.size())
1090 minVal = std::min(minVal, inputs1[j]);
1091 if (k < inputs2.size())
1092 minVal = std::min(minVal, inputs2[k]);
1093
1094 // Advance all iterators pointing to minVal (handles duplicates)
1095 if (i < inputs0.size() && inputs0[i] == minVal)
1096 i++;
1097 if (j < inputs1.size() && inputs1[j] == minVal)
1098 j++;
1099 if (k < inputs2.size() && inputs2[k] == minVal)
1100 k++;
1101
1102 if (!appendMergedInput(minVal))
1103 return;
1104 }
1105 }
1106
1107 // Create the merged cut.
1108 Cut *mergedCut = cutAllocator.create(nodeIndex, mergedInputs, currentSig,
1109 ArrayRef<const Cut *>(cutPtrs));
1110 resultCutSet->addCut(mergedCut);
1111
1112 LLVM_DEBUG({
1113 if (mergedCut->inputs.size() >= 4) {
1114 llvm::dbgs() << "Generated cut for node " << nodeIndex;
1115 if (logicOp)
1116 llvm::dbgs() << " (" << logicOp->getName() << ")";
1117 llvm::dbgs() << " inputs=";
1118 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1119 llvm::dbgs() << "\n";
1120 }
1121 });
1122 return;
1123 }
1124
1125 // Recursive case: iterate over cuts for current operand
1126 const CutSet *currentCutSet = operandCutSets[operandIdx].first;
1127 for (const Cut *cut : currentCutSet->getCuts()) {
1128 uint64_t cutSig = cut->getSignature();
1129 uint64_t newSig = currentSig | cutSig;
1130 if (static_cast<unsigned>(llvm::popcount(newSig)) > maxInputSize)
1131 continue; // Early rejection based on signature
1132
1133 cutPtrs.push_back(cut);
1134
1135 // Recurse to next operand
1136 self(self, operandIdx + 1, cutPtrs, newSig);
1137
1138 cutPtrs.pop_back();
1139 }
1140 };
1141
1142 // Start the recursion with empty cut pointer list and zero signature
1143 SmallVector<const Cut *, 3> cutPtrs;
1144 cutPtrs.reserve(numFanins);
1145 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs, 0ULL);
1146
1147 // Finalize cut set: remove duplicates, limit size, and match patterns
1148 resultCutSet->finalize(options, matchCut, logicNetwork);
1149
1150 return success();
1151}
1152
1154 Operation *topOp,
1155 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
1156 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
1157 << "\n");
1158 // Topologically sort the logic network
1159 if (failed(topologicallySortLogicNetwork(topOp)))
1160 return failure();
1161
1162 // Store the pattern matching function for use during cut finalization
1163 this->matchCut = matchCut;
1164
1165 // Build the flat logic network representation for efficient simulation
1166 auto &block = topOp->getRegion(0).getBlocks().front();
1167 if (failed(logicNetwork.buildFromBlock(&block)))
1168 return failure();
1169
1170 for (const auto &[index, gate] : llvm::enumerate(logicNetwork.getGates())) {
1171 // Skip non-logic gates.
1172 if (!gate.isLogicGate())
1173 continue;
1174
1175 // Ensure cut set exists for each logic gate
1176 if (failed(visitLogicOp(index)))
1177 return failure();
1178 }
1179
1180 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
1181 return success();
1182}
1183
1184const CutSet *CutEnumerator::getCutSet(uint32_t index) {
1185 // Check if cut set already exists
1186 auto it = cutSets.find(index);
1187 if (it == cutSets.end()) {
1188 // Create new cut set for an unprocessed value (primary input or other)
1189 CutSet *cutSet = cutSetAllocator.create();
1190 Cut *trivialCut = cutAllocator.create(Cut::getTrivialCut(index));
1191 cutSet->addCut(trivialCut);
1192 auto [newIt, inserted] = cutSets.insert({index, cutSet});
1193 assert(inserted && "Cut set already exists for this index");
1194 it = newIt;
1195 }
1196
1197 return it->second;
1198}
1199
1200/// Generate a human-readable name for a value used in test output.
1201/// This function creates meaningful names for values to make debug output
1202/// and test results more readable and understandable.
1203static StringRef
1204getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
1205 if (auto *op = value.getDefiningOp()) {
1206 // Handle values defined by operations
1207 // First, check if the operation already has a name hint attribute
1208 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
1209 return name.getValue();
1210
1211 // For single-result operations, generate a unique name based on operation
1212 // type
1213 if (op->getNumResults() == 1) {
1214 auto opName = op->getName();
1215 auto count = opCounter[opName]++;
1216
1217 // Create a unique name by appending a counter to the operation name
1218 SmallString<16> nameStr;
1219 nameStr += opName.getStringRef();
1220 nameStr += "_";
1221 nameStr += std::to_string(count);
1222
1223 // Store the generated name as a hint attribute for future reference
1224 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1225 op->setAttr("sv.namehint", nameAttr);
1226 return nameAttr;
1227 }
1228
1229 // Multi-result operations or other cases get a generic name
1230 return "<unknown>";
1231 }
1232
1233 // Handle block arguments
1234 auto blockArg = cast<BlockArgument>(value);
1235 auto hwOp =
1236 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1237 if (!hwOp)
1238 return "<unknown>";
1239
1240 // Return the formal input name from the hardware module
1241 return hwOp.getInputName(blockArg.getArgNumber());
1242}
1243
1245 DenseMap<OperationName, unsigned> opCounter;
1246 for (auto index : processingOrder) {
1247 auto it = cutSets.find(index);
1248 if (it == cutSets.end())
1249 continue;
1250 auto &cutSet = *it->second;
1251 mlir::Value value = logicNetwork.getValue(index);
1252 llvm::outs() << getTestVariableName(value, opCounter) << " "
1253 << cutSet.getCuts().size() << " cuts:";
1254 for (const Cut *cut : cutSet.getCuts()) {
1255 llvm::outs() << " {";
1256 llvm::interleaveComma(cut->inputs, llvm::outs(), [&](uint32_t inputIdx) {
1257 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1258 llvm::outs() << getTestVariableName(inputVal, opCounter);
1259 });
1260 auto &pattern = cut->getMatchedPattern();
1261 llvm::outs() << "}"
1262 << "@t" << cut->getTruthTable()->table.getZExtValue() << "d";
1263 if (pattern) {
1264 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
1265 pattern->getArrivalTimes().end());
1266 } else {
1267 llvm::outs() << "0";
1268 }
1269 }
1270 llvm::outs() << "\n";
1271 }
1272 llvm::outs() << "Cut enumeration completed successfully\n";
1273}
1274
1275//===----------------------------------------------------------------------===//
1276// CutRewriter
1277//===----------------------------------------------------------------------===//
1278
1279LogicalResult CutRewriter::run(Operation *topOp) {
1280 LLVM_DEBUG({
1281 llvm::dbgs() << "Starting Cut Rewriter\n";
1282 llvm::dbgs() << "Mode: "
1284 : "timing")
1285 << "\n";
1286 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
1287 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
1288 });
1289
1290 // Currently we don't support patterns with multiple outputs.
1291 // So check that.
1292 // TODO: This must be removed when we support multiple outputs.
1293 for (auto &pattern : patterns.patterns) {
1294 if (pattern->getNumOutputs() > 1) {
1295 return mlir::emitError(pattern->getLoc(),
1296 "Cut rewriter does not support patterns with "
1297 "multiple outputs yet");
1298 }
1299 }
1300
1301 // First sort the operations topologically to ensure we can process them
1302 // in a valid order.
1303 if (failed(topologicallySortLogicNetwork(topOp)))
1304 return failure();
1305
1306 // Enumerate cuts for all nodes (initial delay-oriented selection)
1307 if (failed(enumerateCuts(topOp)))
1308 return failure();
1309
1310 // Dump cuts if testing priority cuts.
1313 return success();
1314 }
1315
1316 // Select best cuts and perform mapping
1317 if (failed(runBottomUpRewrite(topOp)))
1318 return failure();
1319
1320 return success();
1321}
1322
1323LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
1324 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
1325
1327 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
1328 // Match the cut against the patterns
1329 return patternMatchCut(cut);
1330 });
1331}
1332
1333ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1335 if (patterns.npnToPatternMap.empty())
1336 return {};
1337
1338 auto &npnClass = cut.getNPNClass(options.npnTable);
1339 auto it = patterns.npnToPatternMap.find(
1340 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1341 if (it == patterns.npnToPatternMap.end())
1342 return {};
1343 return it->getSecond();
1344}
1345
1346std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
1347 if (cut.isTrivialCut())
1348 return {};
1349
1350 const auto &network = cutEnumerator.getLogicNetwork();
1351 const CutRewritePattern *bestPattern = nullptr;
1352 SmallVector<DelayType, 4> inputArrivalTimes;
1353 SmallVector<DelayType, 1> bestArrivalTimes;
1354 double bestArea = 0.0;
1355 inputArrivalTimes.reserve(cut.getInputSize());
1356 bestArrivalTimes.reserve(cut.getOutputSize(network));
1357
1358 // Compute arrival times for each input.
1359 if (failed(cut.getInputArrivalTimes(cutEnumerator, inputArrivalTimes)))
1360 return {};
1361
1362 auto computeArrivalTimeAndPickBest =
1363 [&](const CutRewritePattern *pattern, const MatchResult &matchResult,
1364 llvm::function_ref<unsigned(unsigned)> mapIndex) {
1365 SmallVector<DelayType, 1> outputArrivalTimes;
1366 // Compute the maximum delay for each output from inputs.
1367 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize(network);
1368 outputIndex < outputSize; ++outputIndex) {
1369 // Compute the arrival time for this output.
1370 DelayType outputArrivalTime = 0;
1371 auto delays = matchResult.getDelays();
1372 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
1373 inputIndex < inputSize; ++inputIndex) {
1374 // Map pattern input i to cut input through NPN transformations
1375 unsigned cutOriginalInput = mapIndex(inputIndex);
1376 outputArrivalTime =
1377 std::max(outputArrivalTime,
1378 delays[outputIndex * inputSize + inputIndex] +
1379 inputArrivalTimes[cutOriginalInput]);
1380 }
1381
1382 outputArrivalTimes.push_back(outputArrivalTime);
1383 }
1384
1385 // Update the arrival time
1386 if (!bestPattern ||
1387 compareDelayAndArea(options.strategy, matchResult.area,
1388 outputArrivalTimes, bestArea,
1389 bestArrivalTimes)) {
1390 LLVM_DEBUG({
1391 llvm::dbgs() << "== Matched Pattern ==============\n";
1392 llvm::dbgs() << "Matching cut: \n";
1393 cut.dump(llvm::dbgs(), network);
1394 llvm::dbgs() << "Found better pattern: "
1395 << pattern->getPatternName();
1396 llvm::dbgs() << " with area: " << matchResult.area;
1397 llvm::dbgs() << " and input arrival times: ";
1398 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1399 llvm::dbgs() << " " << inputArrivalTimes[i];
1400 }
1401 llvm::dbgs() << " and arrival times: ";
1402
1403 for (auto arrivalTime : outputArrivalTimes) {
1404 llvm::dbgs() << " " << arrivalTime;
1405 }
1406 llvm::dbgs() << "\n";
1407 llvm::dbgs() << "== Matched Pattern End ==============\n";
1408 });
1409
1410 bestArrivalTimes = std::move(outputArrivalTimes);
1411 bestArea = matchResult.area;
1412 bestPattern = pattern;
1413 }
1414 };
1415
1416 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1417 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1418 "Pattern input size must match cut input size");
1419 auto matchResult = pattern->match(cutEnumerator, cut);
1420 if (!matchResult)
1421 continue;
1422 auto &cutNPN = cut.getNPNClass(options.npnTable);
1423
1424 // Get the input mapping from pattern's NPN class to cut's NPN class
1425 SmallVector<unsigned> inputMapping;
1426 cutNPN.getInputPermutation(patternNPN, inputMapping);
1427 computeArrivalTimeAndPickBest(pattern, *matchResult,
1428 [&](unsigned i) { return inputMapping[i]; });
1429 }
1430
1431 for (const CutRewritePattern *pattern : patterns.nonNPNPatterns) {
1432 if (auto matchResult = pattern->match(cutEnumerator, cut))
1433 computeArrivalTimeAndPickBest(pattern, *matchResult,
1434 [&](unsigned i) { return i; });
1435 }
1436
1437 if (!bestPattern)
1438 return {}; // No matching pattern found
1439
1440 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1441}
1442
1443LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
1444 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
1445 const auto &network = cutEnumerator.getLogicNetwork();
1446 const auto &cutSets = cutEnumerator.getCutSets();
1447 auto processingOrder = cutEnumerator.getProcessingOrder();
1448
1449 // Note: Don't clear cutEnumerator yet - we need it during rewrite
1450 UnusedOpPruner pruner;
1451 PatternRewriter rewriter(top->getContext());
1452
1453 // Process in reverse topological order
1454 for (auto index : llvm::reverse(processingOrder)) {
1455 auto it = cutSets.find(index);
1456 if (it == cutSets.end())
1457 continue;
1458
1459 mlir::Value value = network.getValue(index);
1460 auto &cutSet = *it->second;
1461
1462 if (value.use_empty()) {
1463 if (auto *op = value.getDefiningOp())
1464 pruner.eraseNow(op);
1465 continue;
1466 }
1467
1468 if (isAlwaysCutInput(network, index)) {
1469 // If the value is a primary input, skip it
1470 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
1471 continue;
1472 }
1473
1474 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
1475 auto *bestCut = cutSet.getBestMatchedCut();
1476 if (!bestCut) {
1478 continue; // No matching pattern found, skip this value
1479 return emitError(value.getLoc(), "No matching cut found for value: ")
1480 << value;
1481 }
1482
1483 // Get the root operation from LogicNetwork
1484 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1485 rewriter.setInsertionPoint(rootOp);
1486 const auto &matchedPattern = bestCut->getMatchedPattern();
1487 auto result = matchedPattern->getPattern()->rewrite(rewriter, cutEnumerator,
1488 *bestCut);
1489 if (failed(result))
1490 return failure();
1491
1492 rewriter.replaceOp(rootOp, *result);
1494
1496 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1497 (*result)->setAttr("test.arrival_times", array);
1498 }
1499 }
1500
1501 // Clear the enumerator after rewriting is complete
1503 return success();
1504}
assert(baseType &&"element must be base type")
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Strategy strategy
Precomputed NPN canonicalization table for 4-input single-output functions.
Definition TruthTable.h:168
bool lookup(const BinaryTruthTable &tt, NPNClass &result) const
Returns false if the given truth table shape is unsupported.
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void noteCutRewritten()
Record that one cut was successfully rewritten.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::APInt expandTruthTableToInputSpace(const llvm::APInt &tt, ArrayRef< unsigned > inputMapping, unsigned numExpandedInputs)
Expand a truth table to a larger input space using the given input mapping.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:539
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
T evaluateDotLogic(const T &x, const T &y, const T &z)
Evaluate the Boolean function x ^ (z | (x & y)).
Definition SynthOps.h:134
T evaluateGambleLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:172
T evaluateMajorityLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:139
T evaluateMuxLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:167
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
T evaluateOneHotLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:161
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Definition TruthTable.h:41
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Dot3
Ordered DOT gate (3-input, synth.dot)
@ OneHot3
OneHot gate (3-input, synth.onehot)
@ Identity
Identity gate (used for 1-input inverter)
@ Gamble3
Ordered Gamble gate (3-input, synth.gamble)
@ Mux3
Ordered MUX gate (3-input, synth.mux_inv)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.
Definition CutRewriter.h:75