CIRCT 23.0.0git
Loading...
Searching...
No Matches
CutRewriter.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements a DAG-based boolean matching cut rewriting algorithm for
10// applications like technology/LUT mapping and combinational logic
11// optimization. The algorithm uses priority cuts and NPN
12// (Negation-Permutation-Negation) canonical forms to efficiently match cuts
13// against rewriting patterns.
14//
15// References:
16// "Combinational and Sequential Mapping with Priority Cuts", Alan Mishchenko,
17// Sungmin Cho, Satrajit Chatterjee and Robert Brayton, ICCAD 2007
18// "Improvements to technology mapping for LUT-based FPGAs", Alan Mishchenko,
19// Satrajit Chatterjee and Robert Brayton, FPGA 2006
20//
21//===----------------------------------------------------------------------===//
22
24
29#include "circt/Support/LLVM.h"
32#include "mlir/Analysis/TopologicalSortUtils.h"
33#include "mlir/IR/Builders.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/IR/RegionKindInterface.h"
36#include "mlir/IR/Value.h"
37#include "mlir/IR/ValueRange.h"
38#include "mlir/IR/Visitors.h"
39#include "mlir/Support/LLVM.h"
40#include "llvm/ADT/APInt.h"
41#include "llvm/ADT/Bitset.h"
42#include "llvm/ADT/DenseMap.h"
43#include "llvm/ADT/MapVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/ADT/iterator.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/ErrorHandling.h"
52#include "llvm/Support/LogicalResult.h"
53#include <algorithm>
54#include <functional>
55#include <memory>
56#include <optional>
57#include <string>
58
59#define DEBUG_TYPE "synth-cut-rewriter"
60
61using namespace circt;
62using namespace circt::synth;
63
64//===----------------------------------------------------------------------===//
65// LogicNetwork
66//===----------------------------------------------------------------------===//
67
68uint32_t LogicNetwork::getOrCreateIndex(Value value) {
69 auto [it, inserted] = valueToIndex.try_emplace(value, gates.size());
70 if (inserted) {
71 indexToValue.push_back(value);
72 gates.emplace_back(); // Will be filled in later
73 }
74 return it->second;
75}
76
77uint32_t LogicNetwork::getIndex(Value value) const {
78 const auto it = valueToIndex.find(value);
79 assert(it != valueToIndex.end() &&
80 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
81 "hasIndex first");
82 return it->second;
83}
84
85bool LogicNetwork::hasIndex(Value value) const {
86 return valueToIndex.contains(value);
87}
88
89Value LogicNetwork::getValue(uint32_t index) const {
90 // Index 0 and 1 are reserved for constants, they have no associated Value
91 if (index == kConstant0 || index == kConstant1)
92 return Value();
93
94 assert(index < indexToValue.size() &&
95 "Index out of bounds in LogicNetwork::getValue");
96 return indexToValue[index];
97}
98
99void LogicNetwork::getValues(ArrayRef<uint32_t> indices,
100 SmallVectorImpl<Value> &values) const {
101 values.clear();
102 values.reserve(indices.size());
103 for (uint32_t idx : indices)
104 values.push_back(getValue(idx));
105}
106
/// Register `value` in the network and return its index, as obtained from
/// getOrCreateIndex.
107uint32_t LogicNetwork::addPrimaryInput(Value value) {
108 const uint32_t index = getOrCreateIndex(value);
// NOTE(review): the embedded listing numbers jump from 108 to 110, so one
// source line at this position is not visible in this copy. Confirm against
// the upstream file before relying on this body.
110 return index;
111}
112
113uint32_t LogicNetwork::addGate(Operation *op, LogicNetworkGate::Kind kind,
114 Value result, ArrayRef<Signal> operands) {
115 const uint32_t index = getOrCreateIndex(result);
116 gates[index] = LogicNetworkGate(op, kind, operands);
117 return index;
118}
119
/// Populate the network from the contents of `block`.
///
/// Block arguments are registered first through addPrimaryInput. Each
/// operation is then dispatched on its type: the recognised gate ops are
/// recorded with addGate, 1-bit hw::ConstantOp results are aliased to the
/// reserved constant indices, and for everything else the i1 results are
/// registered through addPrimaryInput. Every handler visible here returns
/// success(); the first failing handler, if any, would abort the walk.
///
/// NOTE(review): this copy of the function has gaps in its embedded listing
/// numbers (155->157, 176->178, 209->211); the lines at those positions are
/// not visible here and are flagged inline below.
120LogicalResult LogicNetwork::buildFromBlock(Block *block) {
121 // Pre-size vectors to reduce reallocations (rough estimate)
122 const size_t estimatedSize =
123 block->getArguments().size() + block->getOperations().size();
124 indexToValue.reserve(estimatedSize);
125 gates.reserve(estimatedSize);
126
// Single-input form of a gate: a non-inverted input needs no node of its own,
// an inverted one is recorded as an Identity gate over the inverted signal.
127 auto handleSingleInputGate = [&](Operation *op, Value result,
128 const Signal &inputSignal) {
129 if (!inputSignal.isInverted()) {
130 // Non-inverted buffer: directly alias the result to the input
131 valueToIndex[result] = inputSignal.getIndex();
132 return;
133 }
134 // Inverted operation: create a NOT gate
135 addGate(op, LogicNetworkGate::Identity, result, {inputSignal});
136 };
137
138 // Ensure all block arguments are indexed as primary inputs first
139 for (Value arg : block->getArguments()) {
140 if (!hasIndex(arg))
141 addPrimaryInput(arg);
142 }
143
// Fallback for operations without a dedicated node kind: every i1 result that
// is not yet indexed is registered through addPrimaryInput.
144 auto handleOtherResults = [&](Operation *op) {
145 for (Value result : op->getResults()) {
146 if (result.getType().isInteger(1) && !hasIndex(result))
147 addPrimaryInput(result);
148 }
149 };
150
// Builds the Signal for operand `index` of `op`, carrying the op's per-operand
// inversion flag.
151 auto getInvertibleSignal = [&](auto op, unsigned index) {
152 return getOrCreateSignal(op.getOperand(index), op.isInverted(index));
153 };
154
155 auto handleInvertibleBinaryGate = [&](auto logicOp,
// NOTE(review): listing line 156 (the rest of this lambda's parameter list) is
// not visible in this copy; the body below uses a `kind` parameter.
157 // The cut rewriter only has dedicated nodes for single-bit unary/binary
158 // gates. Wider or variadic forms stay as opaque cut inputs for now.
159 const auto inputs = logicOp.getInputs();
160 if (inputs.size() == 1) {
161 const Signal inputSignal = getInvertibleSignal(logicOp, 0);
162 handleSingleInputGate(logicOp, logicOp.getResult(), inputSignal);
163 return success();
164 }
165 if (inputs.size() == 2) {
166 const Signal lhsSignal = getInvertibleSignal(logicOp, 0);
167 const Signal rhsSignal = getInvertibleSignal(logicOp, 1);
168 addGate(logicOp, kind, {lhsSignal, rhsSignal});
// No early return here: control falls through to handleOtherResults, which
// skips results that already have an index. Presumably that makes the call a
// no-op after addGate; the three-argument addGate overload is not visible in
// this file, so TODO confirm.
169 }
170 // Variadic gates with >2 inputs are treated as primary
171 // inputs for now.
172 handleOtherResults(logicOp);
173 return success();
174 };
175
176 auto handleInvertibleTernaryGate = [&](auto logicOp,
// NOTE(review): listing line 177 (the rest of this lambda's parameter list) is
// not visible in this copy; the body below uses a `kind` parameter.
178 if (!logicOp.getType().isInteger(1)) {
179 handleOtherResults(logicOp);
180 return success();
181 }
182 const Signal aSignal = getInvertibleSignal(logicOp, 0);
183 const Signal bSignal = getInvertibleSignal(logicOp, 1);
184 const Signal cSignal = getInvertibleSignal(logicOp, 2);
185 addGate(logicOp, kind, {aSignal, bSignal, cSignal});
186 return success();
187 };
188
189 // Process operations in topological order
190 for (Operation &op : block->getOperations()) {
191 LogicalResult result =
192 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
193 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
194 return handleInvertibleBinaryGate(andOp, LogicNetworkGate::And2);
195 })
196 .Case<synth::XorInverterOp>([&](synth::XorInverterOp xorOp) {
197 return handleInvertibleBinaryGate(xorOp, LogicNetworkGate::Xor2);
198 })
199 .Case<synth::MuxInverterOp>([&](synth::MuxInverterOp muxOp) {
200 return handleInvertibleTernaryGate(muxOp, LogicNetworkGate::Mux3);
201 })
202 .Case<synth::DotOp>([&](synth::DotOp dotOp) {
203 return handleInvertibleTernaryGate(dotOp, LogicNetworkGate::Dot3);
204 })
205 .Case<synth::MajorityOp>([&](synth::MajorityOp majOp) {
206 return handleInvertibleTernaryGate(majOp, LogicNetworkGate::Maj3);
207 })
208 .Case<synth::OneHotOp>([&](synth::OneHotOp oneHotOp) {
209 return handleInvertibleTernaryGate(oneHotOp,
// NOTE(review): listing line 210 (the second argument and the end of this
// call) is not visible in this copy.
211 })
// comb::XorOp carries no inversion flags, so both operands are taken
// non-inverted; only the exactly-two-operand form becomes an Xor2 gate.
212 .Case<comb::XorOp>([&](comb::XorOp xorOp) {
213 if (xorOp->getNumOperands() != 2) {
214 handleOtherResults(xorOp);
215 return success();
216 }
217 const Signal lhsSignal =
218 getOrCreateSignal(xorOp.getOperand(0), false);
219 const Signal rhsSignal =
220 getOrCreateSignal(xorOp.getOperand(1), false);
221 addGate(xorOp, LogicNetworkGate::Xor2, {lhsSignal, rhsSignal});
222 return success();
223 })
// 1-bit constants get no node of their own: the result is aliased to one of
// the two reserved constant indices.
224 .Case<hw::ConstantOp>([&](hw::ConstantOp constOp) {
225 Value result = constOp.getResult();
226 if (!result.getType().isInteger(1)) {
227 handleOtherResults(constOp);
228 return success();
229 }
230 uint32_t constIdx =
231 constOp.getValue().isZero() ? kConstant0 : kConstant1;
232 valueToIndex[result] = constIdx;
233 return success();
234 })
// A 1-bit choice is recorded as a Choice gate with an empty operand list.
235 .Case<synth::ChoiceOp>([&](synth::ChoiceOp choiceOp) {
236 if (!choiceOp.getType().isInteger(1)) {
237 handleOtherResults(choiceOp);
238 return success();
239 }
240 addGate(choiceOp, LogicNetworkGate::Choice, choiceOp.getResult(),
241 {});
242 return success();
243 })
244 .Default([&](Operation *defaultOp) {
245 handleOtherResults(defaultOp);
246 return success();
247 });
248
249 if (failed(result))
250 return result;
251 }
252
253 return success();
254}
255
257 valueToIndex.clear();
258 indexToValue.clear();
259 gates.clear();
260 // Re-add the constant nodes (index 0 = const0, index 1 = const1)
261 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
262 gates.emplace_back(nullptr, LogicNetworkGate::Constant);
263 // Placeholders for constants in indexToValue
264 indexToValue.push_back(Value()); // const0
265 indexToValue.push_back(Value()); // const1
266}
267
268//===----------------------------------------------------------------------===//
269// Helper functions
270//===----------------------------------------------------------------------===//
271
272// Return true if the gate at the given index is always a cut input.
273static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index) {
274 const auto &gate = network.getGate(index);
275 return gate.isAlwaysCutInput();
276}
277
278// Return true if the new area/delay is better than the old area/delay in the
279// context of the given strategy.
281 ArrayRef<DelayType> newDelay, double oldArea,
282 ArrayRef<DelayType> oldDelay) {
284 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
286 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
287 llvm_unreachable("Unknown mapping strategy");
288}
289
/// Topologically sort the graph-region blocks under `topOp` using
/// topologicallySortGraphRegionBlocks with the readiness predicate below.
/// Emits an error at the location of `topOp` and returns it if sorting fails;
/// returns success() otherwise.
290LogicalResult circt::synth::topologicallySortLogicNetwork(Operation *topOp) {
291 const auto isOperationReady = [](Value value, Operation *op) -> bool {
292 // Topologically sort AIG ops and dataflow ops. Other operations
293 // can be scheduled.
// NOTE(review): listing line 294 (the statement that forms this predicate's
// body) is not visible in this copy.
295 };
296
297 if (failed(topologicallySortGraphRegionBlocks(topOp, isOperationReady)))
298 return emitError(topOp->getLoc(),
299 "failed to sort operations topologically");
300 return success();
301}
302
303/// Get the truth table for operations within a block.
///
/// The block arguments are the truth-table inputs and `values` are the
/// outputs. Returns:
///  - a default-constructed BinaryTruthTable when the block has no arguments;
///  - BinaryTruthTable(numInputs, 0) when `values` is empty;
///  - an error when the input count reaches maxTruthTableInputs, or when more
///    than one output is requested;
///  - an error when an operation has other than one result, a result whose
///    bit width is not 1, an operand that has not been evaluated yet, or is
///    of a kind not handled below;
///  - otherwise the single-output table for values[0].
304FailureOr<BinaryTruthTable> circt::synth::getTruthTable(ValueRange values,
305 Block *block) {
// NOTE(review): listing line 306 is not visible in this copy; it presumably
// declares `inputArgs`, which is used below with insert(), empty(), size()
// and operator[]. Confirm against the upstream file.
307 for (Value arg : block->getArguments())
308 inputArgs.insert(arg);
309
310 if (inputArgs.empty())
311 return BinaryTruthTable();
312
313 const int64_t numInputs = inputArgs.size();
314 const int64_t numOutputs = values.size();
// The zero-output case is checked first, so values.front() is only reached
// when at least one output exists.
315 if (LLVM_UNLIKELY(numOutputs != 1 || numInputs >= maxTruthTableInputs)) {
316 if (numOutputs == 0)
317 return BinaryTruthTable(numInputs, 0);
318 if (numInputs >= maxTruthTableInputs)
319 return mlir::emitError(values.front().getLoc(),
320 "Truth table is too large");
321 return mlir::emitError(values.front().getLoc(),
322 "Multiple outputs are not supported yet");
323 }
324
325 // Create a map to evaluate the operation
326 DenseMap<Value, APInt> eval;
327 for (uint32_t i = 0; i < numInputs; ++i)
328 eval[inputArgs[i]] = circt::createVarMask(numInputs, i, true);
329
330 // Simulate the operations in the block
331 for (Operation &op : block->without_terminator()) {
332 if (op.getNumResults() != 1 ||
333 hw::getBitWidth(op.getResult(0).getType()) != 1)
334 return op.emitError("Unsupported operation for truth table simulation");
335
// A choice is evaluated as its first input only.
336 if (auto choiceOp = dyn_cast<synth::ChoiceOp>(&op)) {
337 auto it = eval.find(choiceOp.getInputs().front());
338 if (it == eval.end())
339 return choiceOp.emitError("Input value not found in evaluation map");
340 eval[choiceOp.getResult()] = it->second;
341 } else if (auto logicOp = dyn_cast<BooleanLogicOpInterface>(&op)) {
// All inputs are verified to be present first, so the find() in the callback
// below is dereferenced without a further check.
342 for (auto value : logicOp.getInputs())
343 if (!eval.contains(value))
344 return logicOp->emitError("Input value not found in evaluation map");
345
346 eval[logicOp.getResult()] =
347 logicOp.evaluateBooleanLogic([&](unsigned i) -> const APInt & {
348 return eval.find(logicOp.getInput(i))->second;
349 });
350 } else if (auto xorOp = dyn_cast<comb::XorOp>(&op)) {
351 // TODO: Define Xor as Synth op.
// XOR-fold the evaluated masks of all operands, starting from operand 0.
352 auto it = eval.find(xorOp.getOperand(0));
353 if (it == eval.end())
354 return xorOp.emitError("Input value not found in evaluation map");
355 llvm::APInt result = it->second;
356 for (unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
357 it = eval.find(xorOp.getOperand(i));
358 if (it == eval.end())
359 return xorOp.emitError("Input value not found in evaluation map");
360 result ^= it->second;
361 }
362 eval[xorOp.getResult()] = result;
363 } else if (auto constantOp = dyn_cast<hw::ConstantOp>(&op)) {
// A constant becomes an all-zeros or all-ones mask of 2^numInputs bits.
364 auto tableSize = 1ULL << numInputs;
365 eval[constantOp.getResult()] = constantOp.getValue().isZero()
366 ? llvm::APInt::getZero(tableSize)
367 : llvm::APInt::getAllOnes(tableSize);
368 } else {
369 return op.emitError("Unsupported operation for truth table simulation");
370 }
371 }
372
// NOTE(review): this uses operator[] on the map; if values[0] was never
// evaluated above, that would presumably insert and return a
// default-constructed APInt rather than report an error. Confirm the intent.
373 return BinaryTruthTable(numInputs, 1, eval[values[0]]);
374}
375
376//===----------------------------------------------------------------------===//
377// Cut
378//===----------------------------------------------------------------------===//
379
380bool Cut::isTrivialCut() const {
381 // A cut is a trivial cut if it has no root (rootIndex == 0 sentinel)
382 // and only one input
383 return rootIndex == 0 && inputs.size() == 1;
384}
385
386const NPNClass &Cut::getNPNClass() const { return getNPNClass(nullptr); }
387
/// Return the NPN class of this cut, computing and caching it on first use.
///
/// If `npnClass` is already populated it is returned as-is. Otherwise the
/// cut's truth table is looked up in `npnTable` when one is provided, and the
/// result is stored in `npnClass` before being returned.
388const NPNClass &Cut::getNPNClass(const NPNTable *npnTable) const {
389 if (npnClass)
390 return *npnClass;
391
// The truth table is dereferenced unconditionally, so it must already be set.
392 const auto &truthTable = *getTruthTable();
393 NPNClass canonicalForm;
394 if (!npnTable || !npnTable->lookup(truthTable, canonicalForm))
// NOTE(review): listing line 395 (the statement guarded by the `if` above,
// i.e. the fallback used when there is no table or the lookup fails) is not
// visible in this copy. Confirm against the upstream file.
396
397 npnClass.emplace(std::move(canonicalForm));
398 return *npnClass;
399}
400
402 const NPNTable *npnTable, const NPNClass &patternNPN,
403 SmallVectorImpl<unsigned> &permutedIndices) const {
404 const auto &npnClass = getNPNClass(npnTable);
405 npnClass.getInputPermutation(patternNPN, permutedIndices);
406}
407
408LogicalResult
410 SmallVectorImpl<DelayType> &results) const {
411 results.reserve(getInputSize());
412 const auto &network = enumerator.getLogicNetwork();
413
414 // Compute arrival times for each input.
415 for (auto inputIndex : inputs) {
416 if (isAlwaysCutInput(network, inputIndex)) {
417 // If the input is a primary input, it has no delay.
418 results.push_back(0);
419 continue;
420 }
421 auto *cutSet = enumerator.getCutSet(inputIndex);
422 assert(cutSet && "Input must have a valid cut set");
423
424 // If there is no matching pattern, it means it's not possible to use the
425 // input in the cut rewriting. Return empty vector to indicate failure.
426 auto *bestCut = cutSet->getBestMatchedCut();
427 if (!bestCut)
428 return failure();
429
430 const auto &matchedPattern = *bestCut->getMatchedPattern();
431
432 // Get the value for result number lookup
433 mlir::Value inputValue = network.getValue(inputIndex);
434 // Otherwise, the cut input is an op result. Get the arrival time
435 // from the matched pattern.
436 results.push_back(matchedPattern.getArrivalTime(
437 cast<mlir::OpResult>(inputValue).getResultNumber()));
438 }
439
440 return success();
441}
442
443void Cut::dump(llvm::raw_ostream &os, const LogicNetwork &network) const {
444 os << "// === Cut Dump ===\n";
445 os << "Cut with " << getInputSize() << " inputs";
446 if (rootIndex != 0) {
447 auto *rootOp = network.getGate(rootIndex).getOperation();
448 if (rootOp)
449 os << " and root: " << *rootOp;
450 }
451 os << "\n";
452
453 if (isTrivialCut()) {
454 mlir::Value inputVal = network.getValue(inputs[0]);
455 os << "Primary input cut: " << inputVal << "\n";
456 return;
457 }
458
459 os << "Inputs (indices): \n";
460 for (auto [idx, inputIndex] : llvm::enumerate(inputs)) {
461 mlir::Value inputVal = network.getValue(inputIndex);
462 os << " Input " << idx << " (index " << inputIndex << "): " << inputVal
463 << "\n";
464 }
465
466 if (rootIndex != 0) {
467 os << "\nRoot operation: \n";
468 if (auto *rootOp = network.getGate(rootIndex).getOperation())
469 rootOp->print(os);
470 os << "\n";
471 }
472
473 auto &npnClass = getNPNClass();
474 npnClass.dump(os);
475
476 os << "// === Cut End ===\n";
477}
478
479unsigned Cut::getInputSize() const { return inputs.size(); }
480
481unsigned Cut::getOutputSize(const LogicNetwork &network) const {
482 if (rootIndex == 0)
483 return 1; // Trivial cut has 1 output
484 auto *rootOp = network.getGate(rootIndex).getOperation();
485 return rootOp ? rootOp->getNumResults() : 1;
486}
487
488/// Simulate a gate and return its truth table.
/// Single-operand overload: the one handled kind returns `a` unchanged; any
/// other kind reaches llvm_unreachable.
489static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
490 const llvm::APInt &a) {
491 switch (kind) {
// NOTE(review): listing line 492 (the case label for the return below) is not
// visible in this copy.
493 return a;
494 default:
495 llvm_unreachable("Unsupported unary operation for truth table computation");
496 }
497}
498
/// Two-operand overload: depending on `kind`, returns the bitwise AND or the
/// bitwise XOR of the operand truth tables; any other kind reaches
/// llvm_unreachable.
499static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
500 const llvm::APInt &a,
501 const llvm::APInt &b) {
502 switch (kind) {
// NOTE(review): listing lines 503 and 505 (the case labels for the two
// returns below) are not visible in this copy.
504 return a & b;
506 return a ^ b;
507 default:
508 llvm_unreachable(
509 "Unsupported binary operation for truth table computation");
510 }
511}
512
/// Three-operand overload: depending on `kind`, delegates to the mux,
/// majority, dot or one-hot evaluation helper; any other kind reaches
/// llvm_unreachable.
513static inline llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind,
514 const llvm::APInt &a,
515 const llvm::APInt &b,
516 const llvm::APInt &c) {
517 switch (kind) {
// NOTE(review): listing lines 518, 520, 522 and 524 (the case labels for the
// four returns below) are not visible in this copy.
519 return evaluateMuxLogic(a, b, c);
521 return evaluateMajorityLogic(a, b, c);
523 return evaluateDotLogic(a, b, c);
525 return evaluateOneHotLogic(a, b, c);
526 default:
527 llvm_unreachable(
528 "Unsupported ternary operation for truth table computation");
529 }
530}
531
532namespace {
533
534// Helper class to build a merged truth table for a cut based on its operand
535// cuts
//
// `mergedInputs` is the sorted, duplicate-free input list of the cut being
// built (both properties are asserted in the constructor). Each operand of
// the root gate is expressed as a truth table over those inputs and the
// gate's semantics are then applied to the operand tables.
//
// NOTE(review): both ArrayRef members are non-owning views, so the builder
// presumably must not outlive the arrays passed to the constructor.
536struct MergedTruthTableBuilder {
537 MergedTruthTableBuilder(ArrayRef<uint32_t> mergedInputs,
538 ArrayRef<const Cut *> operandCuts)
539 : mergedInputs(mergedInputs), numMergedInputs(mergedInputs.size()),
540 operandCuts(operandCuts) {
541 assert(llvm::is_sorted(mergedInputs) && "merged inputs must be sorted");
542 assert(llvm::adjacent_find(mergedInputs) == mergedInputs.end() &&
543 "merged inputs must be unique");
544 }
545
546 ArrayRef<uint32_t> mergedInputs;
547 unsigned numMergedInputs;
548 ArrayRef<const Cut *> operandCuts;
549
// Position of `operandIdx` within mergedInputs, or nullopt when it is not one
// of the merged inputs. Linear search.
550 std::optional<unsigned> findMergedInputPosition(uint32_t operandIdx) const {
551 auto *it = llvm::find(mergedInputs, operandIdx);
552 if (it == mergedInputs.end())
553 return std::nullopt;
554 return static_cast<unsigned>(std::distance(mergedInputs.begin(), it));
555 }
556
// First operand cut whose output node is `operandIdx`, or nullptr. The output
// of a trivial cut is its single input; otherwise it is the cut's root index.
// Null entries in operandCuts are skipped.
557 const Cut *findOperandCut(uint32_t operandIdx) const {
558 for (const Cut *cut : operandCuts) {
559 if (!cut)
560 continue;
561 uint32_t cutOutput =
562 cut->isTrivialCut() ? cut->inputs[0] : cut->getRootIndex();
563 if (cutOutput == operandIdx)
564 return cut;
565 }
566 return nullptr;
567 }
568
// Fills `mapping` so that mapping[i] is the position in mergedInputs of the
// i-th input of `cut`. Every input of `cut` must be a merged input (asserted).
569 void getInputMapping(const Cut *cut,
570 SmallVectorImpl<unsigned> &mapping) const {
571 mapping.clear();
572 mapping.reserve(cut->inputs.size());
573 for (uint32_t idx : cut->inputs) {
574 auto *it = llvm::find(mergedInputs, idx);
575 assert(it != mergedInputs.end() &&
576 "cut input must exist in merged inputs");
577 mapping.push_back(static_cast<unsigned>(it - mergedInputs.begin()));
578 }
579 }
580
// Re-expresses the truth table of `cut` over the merged input space, using
// the input mapping computed above. The cut's truth table must be set.
581 llvm::APInt expandCutTruthTable(const Cut *cut) const {
582 const auto &cutTT = *cut->getTruthTable();
583 SmallVector<unsigned, 8> inputMapping;
584 getInputMapping(cut, inputMapping);
// NOTE(review): listing line 585 (the start of the return statement, naming
// the function these three arguments are passed to) is not visible in this
// copy.
586 cutTT.table, inputMapping, numMergedInputs);
587 }
588
// Truth table of one root-gate operand over the merged inputs. Resolution
// order: reserved constant indices, then direct merged inputs, then operand
// cuts. The table is complemented when the edge is inverted.
589 llvm::APInt expandOperand(uint32_t operandIdx, bool isInverted) const {
590 llvm::APInt result(1, 0);
591 if (operandIdx == LogicNetwork::kConstant0) {
592 result = llvm::APInt::getZero(1U << numMergedInputs);
593 } else if (operandIdx == LogicNetwork::kConstant1) {
594 result = llvm::APInt::getAllOnes(1U << numMergedInputs);
595 } else if (auto pos = findMergedInputPosition(operandIdx)) {
596 // Direct cut inputs already live in the merged input space.
597 result = circt::createVarMask(numMergedInputs, *pos, true);
598 } else if (const Cut *cut = findOperandCut(operandIdx)) {
599 // Internal operands reuse the operand cut truth table after expanding it
600 // to this root cut's merged input space.
601 result = expandCutTruthTable(cut);
602 } else {
603 llvm_unreachable("Operand not found in cuts or merged inputs");
604 }
605
606 if (isInverted)
607 result.flipAllBits();
608 return result;
609 }
610
// Truth table of the cut rooted at `rootGate`: expands each fan-in edge and
// applies the gate semantics for the gate's kind. The result always has one
// output over numMergedInputs inputs.
611 BinaryTruthTable computeForGate(const LogicNetworkGate &rootGate) const {
612 auto getEdgeTT = [&](unsigned edgeIdx) {
613 const auto &edge = rootGate.edges[edgeIdx];
614 return expandOperand(edge.getIndex(), edge.isInverted());
615 };
616
617 switch (rootGate.getKind()) {
// NOTE(review): the case labels of this switch (listing lines 618-619, 623,
// 628, 633, 638 and 643) are not visible in this copy; only the returns for
// the two-, three- and one-operand forms remain.
620 return BinaryTruthTable(
621 numMergedInputs, 1,
622 applyGateSemantics(rootGate.getKind(), getEdgeTT(0), getEdgeTT(1)));
624 return BinaryTruthTable(numMergedInputs, 1,
625 applyGateSemantics(rootGate.getKind(),
626 getEdgeTT(0), getEdgeTT(1),
627 getEdgeTT(2)));
629 return BinaryTruthTable(numMergedInputs, 1,
630 applyGateSemantics(rootGate.getKind(),
631 getEdgeTT(0), getEdgeTT(1),
632 getEdgeTT(2)));
634 return BinaryTruthTable(numMergedInputs, 1,
635 applyGateSemantics(rootGate.getKind(),
636 getEdgeTT(0), getEdgeTT(1),
637 getEdgeTT(2)));
639 return BinaryTruthTable(numMergedInputs, 1,
640 applyGateSemantics(rootGate.getKind(),
641 getEdgeTT(0), getEdgeTT(1),
642 getEdgeTT(2)));
644 return BinaryTruthTable(
645 numMergedInputs, 1,
646 applyGateSemantics(rootGate.getKind(), getEdgeTT(0)));
647 default:
648 llvm_unreachable("Unsupported operation for truth table computation");
649 }
650 }
651};
652
653} // namespace
654
656 if (isTrivialCut()) {
657 assert(truthTable && "trivial cuts should have their truth table pre-set");
658 return;
659 }
660
661 assert(!operandCuts.empty() &&
662 "non-trivial cuts must carry operand cuts for truth table expansion");
663
664 const auto &rootGate = network.getGate(rootIndex);
665 truthTable.emplace(
666 MergedTruthTableBuilder(inputs, operandCuts).computeForGate(rootGate));
667}
668
669bool Cut::dominates(const Cut &other) const {
670 return dominates(other.inputs, other.signature);
671}
672
673bool Cut::dominates(ArrayRef<uint32_t> otherInputs, uint64_t otherSig) const {
674
675 if (getInputSize() > otherInputs.size())
676 return false;
677
678 if ((signature & otherSig) != signature)
679 return false;
680
681 return std::includes(otherInputs.begin(), otherInputs.end(), inputs.begin(),
682 inputs.end());
683}
684
685Cut Cut::getTrivialCut(uint32_t index) {
686 Cut cut;
687 cut.inputs.push_back(index);
688 // The truth table for a trivial cut is just the identity function on its
689 // single input.
690 cut.setTruthTable(BinaryTruthTable(1, 1, llvm::APInt(2, 2)));
691 cut.setSignature(1ULL << (index % 64)); // Set signature bit for this input
692 return cut;
693}
694
695//===----------------------------------------------------------------------===//
696// MatchedPattern
697//===----------------------------------------------------------------------===//
698
699ArrayRef<DelayType> MatchedPattern::getArrivalTimes() const {
700 assert(pattern && "Pattern must be set to get arrival time");
701 return arrivalTimes;
702}
703
705 assert(pattern && "Pattern must be set to get arrival time");
706 return arrivalTimes[index];
707}
708
710 assert(pattern && "Pattern must be set to get the pattern");
711 return pattern;
712}
713
715 assert(pattern && "Pattern must be set to get area");
716 return area;
717}
718
719//===----------------------------------------------------------------------===//
720// CutSet
721//===----------------------------------------------------------------------===//
722
724
725unsigned CutSet::size() const { return cuts.size(); }
726
727void CutSet::addCut(Cut *cut) {
728 assert(!isFrozen && "Cannot add cuts to a frozen cut set");
729 cuts.push_back(cut);
730}
731
732ArrayRef<Cut *> CutSet::getCuts() const { return cuts; }
733
734// Remove duplicate cuts and non-minimal cuts. A cut is non-minimal if there
735// exists another cut that is a subset of it.
736static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl<Cut *> &cuts) {
737 auto dumpInputs = [](llvm::raw_ostream &os,
738 const llvm::SmallVectorImpl<uint32_t> &inputs) {
739 os << "{";
740 llvm::interleaveComma(inputs, os);
741 os << "}";
742 };
743 // Sort by size, then lexicographically by inputs. This enables cheap exact
744 // duplicate elimination and tighter candidate filtering for subset checks.
745 std::stable_sort(cuts.begin(), cuts.end(), [](const Cut *a, const Cut *b) {
746 if (a->getInputSize() != b->getInputSize())
747 return a->getInputSize() < b->getInputSize();
748 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
749 b->inputs.begin(), b->inputs.end());
750 });
751
752 // Group kept cuts by input size so subset checks only visit smaller cuts.
753 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
754 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
755
756 // Compact kept cuts in-place.
757 unsigned uniqueCount = 0;
758 for (Cut *cut : cuts) {
759 unsigned cutSize = cut->getInputSize();
760
761 // Fast exact duplicate check: with lexicographic sort, duplicates are
762 // adjacent among cuts with equal size.
763 if (uniqueCount > 0) {
764 Cut *lastKept = cuts[uniqueCount - 1];
765 if (lastKept->getInputSize() == cutSize &&
766 lastKept->inputs == cut->inputs)
767 continue;
768 }
769
770 bool isDominated = false;
771 for (unsigned existingSize = 1; existingSize < cutSize && !isDominated;
772 ++existingSize) {
773 for (const Cut *existingCut : keptBySize[existingSize]) {
774 if (!existingCut->dominates(*cut))
775 continue;
776
777 LLVM_DEBUG({
778 llvm::dbgs() << "Dropping non-minimal cut ";
779 dumpInputs(llvm::dbgs(), cut->inputs);
780 llvm::dbgs() << " due to subset ";
781 dumpInputs(llvm::dbgs(), existingCut->inputs);
782 llvm::dbgs() << "\n";
783 });
784 isDominated = true;
785 break;
786 }
787 }
788
789 if (isDominated)
790 continue;
791
792 cuts[uniqueCount++] = cut;
793 keptBySize[cutSize].push_back(cut);
794 }
795
796 LLVM_DEBUG(llvm::dbgs() << "Original cuts: " << cuts.size()
797 << " Unique cuts: " << uniqueCount << "\n");
798
799 // Resize the cuts vector to the number of surviving cuts.
800 cuts.resize(uniqueCount);
801}
802
804 const CutRewriterOptions &options,
805 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut,
806 const LogicNetwork &logicNetwork) {
807
808 // Remove duplicate/non-minimal cuts first so all follow-up work only runs on
809 // survivors.
811
812 // Compute truth tables lazily, then match cuts to collect timing/area data.
813 for (Cut *cut : cuts) {
814 if (!cut->getTruthTable().has_value())
815 cut->computeTruthTableFromOperands(logicNetwork);
816
817 assert(cut->getInputSize() <= options.maxCutInputSize &&
818 "Cut input size exceeds maximum allowed size");
819
820 if (auto matched = matchCut(*cut))
821 cut->setMatchedPattern(std::move(*matched));
822 }
823
824 // Sort cuts by priority to select the most promising ones.
825 // Priority is determined by the optimization strategy:
826 // - Trivial cuts (direct connections) have highest priority
827 // - Among matched cuts, compare by area/delay based on the strategy
828 // - Matched cuts are preferred over unmatched cuts
829 // See "Combinational and Sequential Mapping with Priority Cuts" by Mishchenko
830 // et al., ICCAD 2007 for more details.
831 // TODO: Use a priority queue instead of sorting for better performance.
832
833 // Partition the cuts into trivial and non-trivial cuts.
834 auto *trivialCutsEnd =
835 std::stable_partition(cuts.begin(), cuts.end(),
836 [](const Cut *cut) { return cut->isTrivialCut(); });
837
838 auto isBetterCut = [&options](const Cut *a, const Cut *b) {
839 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
840 "Trivial cuts should have been excluded");
841 const auto &aMatched = a->getMatchedPattern();
842 const auto &bMatched = b->getMatchedPattern();
843
844 if (aMatched && bMatched)
845 return compareDelayAndArea(
846 options.strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
847 bMatched->getArea(), bMatched->getArrivalTimes());
848
849 if (static_cast<bool>(aMatched) != static_cast<bool>(bMatched))
850 return static_cast<bool>(aMatched);
851
852 return a->getInputSize() < b->getInputSize();
853 };
854 std::stable_sort(trivialCutsEnd, cuts.end(), isBetterCut);
855
856 // Keep only the top-K cuts to bound growth.
857 if (cuts.size() > options.maxCutSizePerRoot)
858 cuts.resize(options.maxCutSizePerRoot);
859
860 // Select the best cut from the remaining candidates.
861 bestCut = nullptr;
862 for (Cut *cut : cuts) {
863 const auto &currentMatch = cut->getMatchedPattern();
864 if (!currentMatch)
865 continue;
866 bestCut = cut;
867 break;
868 }
869
870 LLVM_DEBUG({
871 llvm::dbgs() << "Finalized cut set with " << cuts.size() << " cuts and "
872 << (bestCut
873 ? "matched pattern to " + bestCut->getMatchedPattern()
874 ->getPattern()
875 ->getPatternName()
876 : "no matched pattern")
877 << "\n";
878 });
879
880 isFrozen = true; // Mark the cut set as frozen
881}
882
883//===----------------------------------------------------------------------===//
884// CutRewritePattern
885//===----------------------------------------------------------------------===//
886
888 SmallVectorImpl<NPNClass> &matchingNPNClasses) const {
889 return false;
890}
891
892//===----------------------------------------------------------------------===//
893// CutRewritePatternSet
894//===----------------------------------------------------------------------===//
895
897 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4> patterns)
898 : patterns(std::move(patterns)) {
899 // Initialize the NPN to pattern map
900 for (auto &pattern : this->patterns) {
901 SmallVector<NPNClass, 2> npnClasses;
902 auto result = pattern->useTruthTableMatcher(npnClasses);
903 if (result) {
904 for (auto npnClass : npnClasses) {
905 // Create a NPN class from the truth table
906 npnToPatternMap[{npnClass.truthTable.table,
907 npnClass.truthTable.numInputs}]
908 .push_back(std::make_pair(std::move(npnClass), pattern.get()));
909 }
910 } else {
911 // If the pattern does not provide NPN classes, we use a special key
912 // to indicate that it should be considered for all cuts.
913 nonNPNPatterns.push_back(pattern.get());
914 }
915 }
916}
917
918//===----------------------------------------------------------------------===//
919// CutEnumerator
920//===----------------------------------------------------------------------===//
921
923 : cutAllocator(stats.numCutsCreated),
924 cutSetAllocator(stats.numCutSetsCreated), options(options) {}
925
927 CutSet *cutSet = cutSetAllocator.create();
928 auto [cutSetPtr, inserted] = cutSets.try_emplace(index, cutSet);
929 assert(inserted && "Cut set already exists for this index");
930 return cutSetPtr->second;
931}
932
934 cutSets.clear();
935 processingOrder.clear();
937 cutAllocator.DestroyAll();
938 cutSetAllocator.DestroyAll();
939}
940
/// Visit a combinational logic operation and generate cuts.
///
/// Builds and finalizes the cut set for the gate at `nodeIndex`:
///  - The trivial cut for `nodeIndex` is always added first.
///  - For a Choice gate, every non-trivial cut of every operand is re-created
///    with `nodeIndex` as root, keeping the operand cut's inputs, signature
///    and truth table.
///  - For any other gate, one cut is picked from each fanin's cut set in every
///    possible combination; the picked cuts' input lists are merged (constant
///    indices dropped, duplicates collapsed) and combinations whose merged
///    list exceeds options.maxCutInputSize are discarded.
/// The node is appended to processingOrder and the resulting set is finalized
/// with the stored matchCut callback.
///
/// \param nodeIndex index of the gate in logicNetwork.
/// \returns failure (after emitting an error on the operation) if an operand
///          cut set is null, the gate has more than 3 fanins, or the result is
///          not a 1-bit integer; success otherwise.
941LogicalResult CutEnumerator::visitLogicOp(uint32_t nodeIndex) {
942 const auto &gate = logicNetwork.getGate(nodeIndex);
943 auto *logicOp = gate.getOperation();
944 assert(logicOp && logicOp->getNumResults() == 1 &&
945 "Logic operation must have a single result");
946
947 if (gate.getKind() == LogicNetworkGate::Choice) {
948 auto choiceOp = cast<synth::ChoiceOp>(logicOp);
949 auto *resultCutSet = createNewCutSet(nodeIndex);
950 Cut *primaryInputCut = cutAllocator.create(Cut::getTrivialCut(nodeIndex));
951 processingOrder.push_back(nodeIndex);
952 resultCutSet->addCut(primaryInputCut);
953
954 for (Value operand : choiceOp.getInputs()) {
955 auto *operandCutSet = getCutSet(logicNetwork.getIndex(operand));
956 if (!operandCutSet)
957 return logicOp->emitError("Failed to get cut set for choice operand");
958
959 // Choice nodes do not introduce new logic. They forward each non-trivial
960 // operand cut as an equivalent alternative for the same root.
961 for (const Cut *operandCut : operandCutSet->getCuts()) {
962 if (operandCut->isTrivialCut())
963 continue;
964
// The operand cut's truth table is dereferenced here, so it is assumed to
// be populated already — TODO confirm that finalize() guarantees this.
965 resultCutSet->addCut(cutAllocator.create(
966 nodeIndex, operandCut->inputs, operandCut->getSignature(),
967 ArrayRef<const Cut *>{operandCut}, *operandCut->getTruthTable()));
968 }
969 }
970
971 // Finalize cut set: remove duplicates, limit size, and match patterns
972 resultCutSet->finalize(options, matchCut, logicNetwork);
973 return success();
974 }
975
976 unsigned numFanins = gate.getNumFanins();
977
978 // Validate operation constraints
979 // TODO: Variadic operations and non-single-bit results can be supported
980 if (numFanins > 3)
981 return logicOp->emitError("Cut enumeration supports at most 3 operands, "
982 "found: ")
983 << numFanins;
984 if (!logicOp->getOpResult(0).getType().isInteger(1))
985 return logicOp->emitError()
986 << "Supported logic operations must have a single bit "
987 "result type but found: "
988 << logicOp->getResult(0).getType();
989
990 // A vector to hold cut sets for each operand along with their max cut input
991 // size.
992 SmallVector<std::pair<const CutSet *, unsigned>, 2> operandCutSets;
993 operandCutSets.reserve(numFanins);
994
995 // Collect cut sets for each fanin (using LogicNetwork edges)
996 for (unsigned i = 0; i < numFanins; ++i) {
997 uint32_t faninIndex = gate.edges[i].getIndex();
998 auto *operandCutSet = getCutSet(faninIndex);
999 if (!operandCutSet)
1000 return logicOp->emitError("Failed to get cut set for fanin index ")
1001 << faninIndex;
1002
1003 // Find the largest cut size among the operand's cuts for sorting heuristic
1004 // later.
1005 unsigned maxInputCutSize = 0;
1006 for (auto *cut : operandCutSet->getCuts())
1007 maxInputCutSize = std::max(maxInputCutSize, cut->getInputSize());
1008 operandCutSets.push_back(std::make_pair(operandCutSet, maxInputCutSize));
1009 }
1010
1011 // Create the trivial cut for this node's output
1012 Cut *primaryInputCut = cutAllocator.create(Cut::getTrivialCut(nodeIndex));
1013
1014 auto *resultCutSet = createNewCutSet(nodeIndex);
1015 processingOrder.push_back(nodeIndex);
1016 resultCutSet->addCut(primaryInputCut);
1017
1018 // Sort operand cut sets by their largest cut size in descending order. This
1019 // heuristic improves efficiency of the k-way merge when generating cuts for
1020 // the current node by maximizing the chance of early pruning when the merged
1021 // cut exceeds the input size limit.
// NOTE(review): after this sort, the cut pointers collected below follow the
// sorted operand order, not the order of gate.edges. Confirm that consumers
// of a cut's operand cuts do not assume fanin-edge order.
1022 llvm::stable_sort(operandCutSets,
1023 [](const std::pair<const CutSet *, unsigned> &a,
1024 const std::pair<const CutSet *, unsigned> &b) {
1025 return a.second > b.second;
1026 });
1027
1028 // Cache maxCutInputSize to avoid repeated access
1029 unsigned maxInputSize = options.maxCutInputSize;
1030
1031 // This lambda generates nested loops at runtime to iterate over all
1032 // combinations of cuts from N operands
// `cutPtrs` holds the cut chosen so far for each already-visited operand and
// `currentSig` is the bitwise OR of their signatures.
1033 auto enumerateCutCombinations = [&](auto &&self, unsigned operandIdx,
1034 SmallVector<const Cut *, 3> &cutPtrs,
1035 uint64_t currentSig) -> void {
1036 // Base case: all operands processed, create merged cut
1037 if (operandIdx == numFanins) {
1038 // Efficient k-way merge: inputs are sorted, so dedup and constant
1039 // filtering can be done while merging. Abort early once we exceed the
1040 // cut-size limit to avoid building doomed merged cuts.
1041 SmallVector<uint32_t, 6> mergedInputs;
// Appends `value` unless it is a constant index or equals the last
// appended value. Returns false only when the append pushed the merged
// list past maxInputSize, which tells the caller to drop this combination.
1042 auto appendMergedInput = [&](uint32_t value) {
1043 if (value == LogicNetwork::kConstant0 ||
1044 value == LogicNetwork::kConstant1)
1045 return true;
1046 if (!mergedInputs.empty() && mergedInputs.back() == value)
1047 return true;
1048 mergedInputs.push_back(value);
1049 return mergedInputs.size() <= maxInputSize;
1050 };
1051
1052 if (numFanins == 1) {
1053 // Single input: copy while filtering constants.
1054 mergedInputs.reserve(
1055 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
1056 for (uint32_t value : cutPtrs[0]->inputs)
1057 if (!appendMergedInput(value))
1058 return;
1059 } else if (numFanins == 2) {
1060 // Two-way merge (common case for AND gates)
1061 const auto &inputs0 = cutPtrs[0]->inputs;
1062 const auto &inputs1 = cutPtrs[1]->inputs;
1063 mergedInputs.reserve(
1064 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
1065
1066 unsigned i = 0, j = 0;
1067 while (i < inputs0.size() || j < inputs1.size()) {
1068 uint32_t next;
// Take from inputs0 when inputs1 is exhausted or inputs0's head is not
// larger; on a tie both cursors advance so the value is emitted once.
1069 if (j == inputs1.size() ||
1070 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
1071 next = inputs0[i++];
1072 if (j < inputs1.size() && inputs1[j] == next)
1073 ++j;
1074 } else {
1075 next = inputs1[j++];
1076 }
1077 if (!appendMergedInput(next))
1078 return;
1079 }
1080 } else {
1081 // Three-way merge (for MAJ/MUX gates)
1082 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
1083 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
1084 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
1085 mergedInputs.reserve(std::min<size_t>(
1086 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
1087
1088 unsigned i = 0, j = 0, k = 0;
1089 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
1090 // Find minimum among available elements
1091 uint32_t minVal = UINT32_MAX;
1092 if (i < inputs0.size())
1093 minVal = std::min(minVal, inputs0[i]);
1094 if (j < inputs1.size())
1095 minVal = std::min(minVal, inputs1[j]);
1096 if (k < inputs2.size())
1097 minVal = std::min(minVal, inputs2[k]);
1098
1099 // Advance all iterators pointing to minVal (handles duplicates)
1100 if (i < inputs0.size() && inputs0[i] == minVal)
1101 i++;
1102 if (j < inputs1.size() && inputs1[j] == minVal)
1103 j++;
1104 if (k < inputs2.size() && inputs2[k] == minVal)
1105 k++;
1106
1107 if (!appendMergedInput(minVal))
1108 return;
1109 }
1110 }
1111
1112 // Create the merged cut.
1113 Cut *mergedCut = cutAllocator.create(nodeIndex, mergedInputs, currentSig,
1114 ArrayRef<const Cut *>(cutPtrs));
1115 resultCutSet->addCut(mergedCut);
1116
1117 LLVM_DEBUG({
1118 if (mergedCut->inputs.size() >= 4) {
1119 llvm::dbgs() << "Generated cut for node " << nodeIndex;
1120 if (logicOp)
1121 llvm::dbgs() << " (" << logicOp->getName() << ")";
1122 llvm::dbgs() << " inputs=";
1123 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1124 llvm::dbgs() << "\n";
1125 }
1126 });
1127 return;
1128 }
1129
1130 // Recursive case: iterate over cuts for current operand
1131 const CutSet *currentCutSet = operandCutSets[operandIdx].first;
1132 for (const Cut *cut : currentCutSet->getCuts()) {
1133 uint64_t cutSig = cut->getSignature();
1134 uint64_t newSig = currentSig | cutSig;
// Cheap pre-filter before the merge: skip the combination when the OR-ed
// signature already has more set bits than the allowed input count.
1135 if (static_cast<unsigned>(llvm::popcount(newSig)) > maxInputSize)
1136 continue; // Early rejection based on signature
1137
1138 cutPtrs.push_back(cut);
1139
1140 // Recurse to next operand
1141 self(self, operandIdx + 1, cutPtrs, newSig);
1142
1143 cutPtrs.pop_back();
1144 }
1145 };
1146
1147 // Start the recursion with empty cut pointer list and zero signature
1148 SmallVector<const Cut *, 3> cutPtrs;
1149 cutPtrs.reserve(numFanins);
1150 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs, 0ULL);
1151
1152 // Finalize cut set: remove duplicates, limit size, and match patterns
1153 resultCutSet->finalize(options, matchCut, logicNetwork);
1154
1155 return success();
1156}
1157
// Enumerate cuts for all nodes in the given module.
// Sorts the logic network under `topOp` topologically, builds logicNetwork
// from the first block of region 0, and calls visitLogicOp on every logic gate
// in gate order. Returns failure as soon as any of these steps fails.
// NOTE(review): the signature line (listing line 1158) is absent from this
// rendering; presumably LogicalResult CutEnumerator::enumerateCuts(.
1159 Operation *topOp,
1160 llvm::function_ref<std::optional<MatchedPattern>(const Cut &)> matchCut) {
1161 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts for module: " << topOp->getName()
1162 << "\n");
1163 // Topologically sort the logic network
1164 if (failed(topologicallySortLogicNetwork(topOp)))
1165 return failure();
1166
1167 // Store the pattern matching function for use during cut finalization
// llvm::function_ref is a non-owning reference, so the stored member is only
// safe to call while the caller's callable is still alive.
1168 this->matchCut = matchCut;
1169
1170 // Build the flat logic network representation for efficient simulation
// Assumes region 0 of topOp contains at least one block — TODO confirm.
1171 auto &block = topOp->getRegion(0).getBlocks().front();
1172 if (failed(logicNetwork.buildFromBlock(&block)))
1173 return failure();
1174
1175 for (const auto &[index, gate] : llvm::enumerate(logicNetwork.getGates())) {
1176 // Skip non-logic gates.
1177 if (!gate.isLogicGate())
1178 continue;
1179
1180 // Ensure cut set exists for each logic gate
1181 if (failed(visitLogicOp(index)))
1182 return failure();
1183 }
1184
1185 LLVM_DEBUG(llvm::dbgs() << "Cut enumeration completed successfully\n");
1186 return success();
1187}
1188
/// Get the cut set for a specific index.
///
/// If no cut set is registered for `index` yet, one is created on the fly
/// containing only the trivial cut for that index, stored in cutSets, and
/// returned. As written, this function therefore never returns null.
1189const CutSet *CutEnumerator::getCutSet(uint32_t index) {
1190 // Check if cut set already exists
1191 auto it = cutSets.find(index);
1192 if (it == cutSets.end()) {
1193 // Create new cut set for an unprocessed value (primary input or other)
1194 CutSet *cutSet = cutSetAllocator.create();
1195 Cut *trivialCut = cutAllocator.create(Cut::getTrivialCut(index));
1196 cutSet->addCut(trivialCut);
// The lookup above just missed, so this insert is expected to succeed.
1197 auto [newIt, inserted] = cutSets.insert({index, cutSet});
1198 assert(inserted && "Cut set already exists for this index");
1199 it = newIt;
1200 }
1201
1202 return it->second;
1203}
1204
1205/// Generate a human-readable name for a value used in test output.
1206/// This function creates meaningful names for values to make debug output
1207/// and test results more readable and understandable.
///
/// Resolution order:
///  - value defined by an op that has an "sv.namehint" string attribute: that
///    attribute's value;
///  - value defined by a single-result op without the attribute: a new name
///    "<op name>_<n>", where n is the current per-op-name count in
///    `opCounter`;
///  - value defined by any other op: "<unknown>";
///  - block argument whose owning block's parent op is an hw::HWModuleOp: the
///    module's input name for that argument number;
///  - any other block argument: "<unknown>".
///
/// Side effects: when a name is generated, `opCounter` is incremented for
/// that op name and the name is written back to the op as "sv.namehint", so
/// later calls return the same name.
1208static StringRef
1209getTestVariableName(Value value, DenseMap<OperationName, unsigned> &opCounter) {
1210 if (auto *op = value.getDefiningOp()) {
1211 // Handle values defined by operations
1212 // First, check if the operation already has a name hint attribute
1213 if (auto name = op->getAttrOfType<StringAttr>("sv.namehint"))
1214 return name.getValue();
1215
1216 // For single-result operations, generate a unique name based on operation
1217 // type
1218 if (op->getNumResults() == 1) {
1219 auto opName = op->getName();
1220 auto count = opCounter[opName]++;
1221
1222 // Create a unique name by appending a counter to the operation name
1223 SmallString<16> nameStr;
1224 nameStr += opName.getStringRef();
1225 nameStr += "_";
1226 nameStr += std::to_string(count);
1227
1228 // Store the generated name as a hint attribute for future reference
// The returned name comes from the attribute, not from the local nameStr
// buffer, which goes out of scope on return.
1229 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1230 op->setAttr("sv.namehint", nameAttr);
1231 return nameAttr;
1232 }
1233
1234 // Multi-result operations or other cases get a generic name
1235 return "<unknown>";
1236 }
1237
1238 // Handle block arguments
1239 auto blockArg = cast<BlockArgument>(value);
1240 auto hwOp =
1241 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1242 if (!hwOp)
1243 return "<unknown>";
1244
1245 // Return the formal input name from the hardware module
1246 return hwOp.getInputName(blockArg.getArgNumber());
1247}
1248
// Print every cut set to llvm::outs() in processing order, one line per node:
//   <name> <N> cuts: {<in>, <in>, ...}@t<truth table>d<delay> ...
// where <delay> is the largest arrival time of the cut's matched pattern, or
// 0 when the cut has no matched pattern. A fixed completion line is printed
// last. Names come from getTestVariableName, which may attach "sv.namehint"
// attributes to operations as a side effect.
// NOTE(review): the signature line (listing line 1249) is absent from this
// rendering; presumably void CutEnumerator::dump() const {.
1250 DenseMap<OperationName, unsigned> opCounter;
1251 for (auto index : processingOrder) {
1252 auto it = cutSets.find(index);
1253 if (it == cutSets.end())
1254 continue;
1255 auto &cutSet = *it->second;
1256 mlir::Value value = logicNetwork.getValue(index);
1257 llvm::outs() << getTestVariableName(value, opCounter) << " "
1258 << cutSet.getCuts().size() << " cuts:";
1259 for (const Cut *cut : cutSet.getCuts()) {
1260 llvm::outs() << " {";
1261 llvm::interleaveComma(cut->inputs, llvm::outs(), [&](uint32_t inputIdx) {
1262 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1263 llvm::outs() << getTestVariableName(inputVal, opCounter);
1264 });
1265 auto &pattern = cut->getMatchedPattern();
// The truth table optional is dereferenced unconditionally, so every cut is
// assumed to have one at this point — TODO confirm.
1266 llvm::outs() << "}"
1267 << "@t" << cut->getTruthTable()->table.getZExtValue() << "d";
1268 if (pattern) {
// Dereferencing max_element assumes a matched pattern has at least one
// arrival time — TODO confirm.
1269 llvm::outs() << *std::max_element(pattern->getArrivalTimes().begin(),
1270 pattern->getArrivalTimes().end());
1271 } else {
1272 llvm::outs() << "0";
1273 }
1274 }
1275 llvm::outs() << "\n";
1276 }
1277 llvm::outs() << "Cut enumeration completed successfully\n";
1278}
1279
1280//===----------------------------------------------------------------------===//
1281// CutRewriter
1282//===----------------------------------------------------------------------===//
1283
/// Execute the complete cut-based rewriting algorithm.
///
/// Steps: reject pattern sets containing a multi-output pattern,
/// topologically sort the logic network under `topOp`, enumerate cuts, then
/// run the bottom-up rewrite. Any failing step makes the whole run fail.
1284LogicalResult CutRewriter::run(Operation *topOp) {
1285 LLVM_DEBUG({
1286 llvm::dbgs() << "Starting Cut Rewriter\n";
1287 llvm::dbgs() << "Mode: "
// NOTE(review): listing line 1288 is absent from this rendering; presumably
// the condition and "area" arm of a ternary on options.strategy.
1289 : "timing")
1290 << "\n";
1291 llvm::dbgs() << "Max input size: " << options.maxCutInputSize << "\n";
1292 llvm::dbgs() << "Max cut size: " << options.maxCutSizePerRoot << "\n";
1293 });
1294
1295 // Currently we don't support patterns with multiple outputs.
1296 // So check that.
1297 // TODO: This must be removed when we support multiple outputs.
1298 for (auto &pattern : patterns.patterns) {
1299 if (pattern->getNumOutputs() > 1) {
1300 return mlir::emitError(pattern->getLoc(),
1301 "Cut rewriter does not support patterns with "
1302 "multiple outputs yet");
1303 }
1304 }
1305
1306 // First sort the operations topologically to ensure we can process them
1307 // in a valid order.
// NOTE(review): CutEnumerator::enumerateCuts also sorts the same op before
// building the network, so this sort looks redundant — confirm.
1308 if (failed(topologicallySortLogicNetwork(topOp)))
1309 return failure();
1310
1311 // Enumerate cuts for all nodes (initial delay-oriented selection)
1312 if (failed(enumerateCuts(topOp)))
1313 return failure();
1314
1315 // Dump cuts if testing priority cuts.
// NOTE(review): listing lines 1316-1317 are absent from this rendering;
// presumably `if (options.testPriorityCuts) {` followed by a dump call.
1318 return success();
1319 }
1320
1321 // Select best cuts and perform mapping
1322 if (failed(runBottomUpRewrite(topOp)))
1323 return failure();
1324
1325 return success();
1326}
1327
/// Enumerate cuts for all nodes in the given module.
///
/// Passes `topOp` on together with a callback that matches each cut through
/// patternMatchCut.
1328LogicalResult CutRewriter::enumerateCuts(Operation *topOp) {
1329 LLVM_DEBUG(llvm::dbgs() << "Enumerating cuts...\n");
1330
// NOTE(review): listing line 1331 is absent from this rendering; presumably
// `return cutEnumerator.enumerateCuts(`.
1332 topOp, [&](const Cut &cut) -> std::optional<MatchedPattern> {
1333 // Match the cut against the patterns
1334 return patternMatchCut(cut);
1335 });
1336}
1337
/// Find patterns that match a cut's truth table.
///
/// Looks up the cut's NPN class (canonical truth table and input count) in
/// patterns.npnToPatternMap. Returns the registered (NPN class, pattern)
/// pairs, or an empty range when the map is empty or has no such key.
1338ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
// NOTE(review): listing line 1339 (function name and parameter list) is
// absent from this rendering.
// Early exit avoids computing the cut's NPN class when nothing is registered.
1340 if (patterns.npnToPatternMap.empty())
1341 return {};
1342
1343 auto &npnClass = cut.getNPNClass(options.npnTable);
1344 auto it = patterns.npnToPatternMap.find(
1345 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1346 if (it == patterns.npnToPatternMap.end())
1347 return {};
1348 return it->getSecond();
1349}
1350
/// Match a cut against available patterns and compute arrival time.
///
/// Candidates are tried in two groups: first the patterns registered for the
/// cut's NPN class, then the patterns in patterns.nonNPNPatterns. For every
/// candidate whose match() succeeds, the arrival time of each output is the
/// maximum over inputs of (pattern delay + input arrival time). The first
/// successful candidate is always kept; a later one replaces it only when
/// compareDelayAndArea returns true for options.strategy.
///
/// \returns std::nullopt for a trivial cut, when the input arrival times
///          cannot be computed, or when no candidate matches; otherwise the
///          best pattern with its output arrival times and area.
1351std::optional<MatchedPattern> CutRewriter::patternMatchCut(const Cut &cut) {
1352 if (cut.isTrivialCut())
1353 return {};
1354
1355 const auto &network = cutEnumerator.getLogicNetwork();
1356 const CutRewritePattern *bestPattern = nullptr;
1357 SmallVector<DelayType, 4> inputArrivalTimes;
1358 SmallVector<DelayType, 1> bestArrivalTimes;
1359 double bestArea = 0.0;
1360 inputArrivalTimes.reserve(cut.getInputSize());
1361 bestArrivalTimes.reserve(cut.getOutputSize(network));
1362
1363 // Compute arrival times for each input.
1364 if (failed(cut.getInputArrivalTimes(cutEnumerator, inputArrivalTimes)))
1365 return {};
1366
// Scores one successful match and updates bestPattern / bestArrivalTimes /
// bestArea if it wins. `mapIndex` translates a pattern input position into
// an index of inputArrivalTimes.
1367 auto computeArrivalTimeAndPickBest =
1368 [&](const CutRewritePattern *pattern, const MatchResult &matchResult,
1369 llvm::function_ref<unsigned(unsigned)> mapIndex) {
1370 SmallVector<DelayType, 1> outputArrivalTimes;
1371 // Compute the maximum delay for each output from inputs.
1372 for (unsigned outputIndex = 0, outputSize = cut.getOutputSize(network);
1373 outputIndex < outputSize; ++outputIndex) {
1374 // Compute the arrival time for this output.
1375 DelayType outputArrivalTime = 0;
1376 auto delays = matchResult.getDelays();
1377 for (unsigned inputIndex = 0, inputSize = cut.getInputSize();
1378 inputIndex < inputSize; ++inputIndex) {
1379 // Map pattern input i to cut input through NPN transformations
1380 unsigned cutOriginalInput = mapIndex(inputIndex);
// `delays` is read as a flat array with one row of inputSize entries per
// output.
1381 outputArrivalTime =
1382 std::max(outputArrivalTime,
1383 delays[outputIndex * inputSize + inputIndex] +
1384 inputArrivalTimes[cutOriginalInput]);
1385 }
1386
1387 outputArrivalTimes.push_back(outputArrivalTime);
1388 }
1389
1390 // Update the arrival time
1391 if (!bestPattern ||
1392 compareDelayAndArea(options.strategy, matchResult.area,
1393 outputArrivalTimes, bestArea,
1394 bestArrivalTimes)) {
1395 LLVM_DEBUG({
1396 llvm::dbgs() << "== Matched Pattern ==============\n";
1397 llvm::dbgs() << "Matching cut: \n";
1398 cut.dump(llvm::dbgs(), network);
1399 llvm::dbgs() << "Found better pattern: "
1400 << pattern->getPatternName();
1401 llvm::dbgs() << " with area: " << matchResult.area;
1402 llvm::dbgs() << " and input arrival times: ";
1403 for (unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1404 llvm::dbgs() << " " << inputArrivalTimes[i];
1405 }
1406 llvm::dbgs() << " and arrival times: ";
1407
1408 for (auto arrivalTime : outputArrivalTimes) {
1409 llvm::dbgs() << " " << arrivalTime;
1410 }
1411 llvm::dbgs() << "\n";
1412 llvm::dbgs() << "== Matched Pattern End ==============\n";
1413 });
1414
1415 bestArrivalTimes = std::move(outputArrivalTimes);
1416 bestArea = matchResult.area;
1417 bestPattern = pattern;
1418 }
1419 };
1420
// Group 1: patterns registered under the cut's NPN class.
1421 for (auto &[patternNPN, pattern] : getMatchingPatternsFromTruthTable(cut)) {
1422 assert(patternNPN.truthTable.numInputs == cut.getInputSize() &&
1423 "Pattern input size must match cut input size");
1424 auto matchResult = pattern->match(cutEnumerator, cut);
1425 if (!matchResult)
1426 continue;
1427 auto &cutNPN = cut.getNPNClass(options.npnTable);
1428
1429 // Get the input mapping from pattern's NPN class to cut's NPN class
1430 SmallVector<unsigned> inputMapping;
1431 cutNPN.getInputPermutation(patternNPN, inputMapping);
1432 computeArrivalTimeAndPickBest(pattern, *matchResult,
1433 [&](unsigned i) { return inputMapping[i]; });
1434 }
1435
// Group 2: patterns with custom matching; inputs are used in cut order.
1436 for (const CutRewritePattern *pattern : patterns.nonNPNPatterns) {
1437 if (auto matchResult = pattern->match(cutEnumerator, cut))
1438 computeArrivalTimeAndPickBest(pattern, *matchResult,
1439 [&](unsigned i) { return i; });
1440 }
1441
1442 if (!bestPattern)
1443 return {}; // No matching pattern found
1444
1445 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1446}
1447
/// Perform the actual circuit rewriting using selected patterns.
///
/// Walks the enumerator's processing order back to front. For each node that
/// has a cut set: a value without uses has its defining op erased; nodes that
/// are always cut inputs are skipped; otherwise the best matched cut of the
/// node is rewritten by its pattern and the root op is replaced with the
/// result. Returns failure if a pattern's rewrite fails.
1448LogicalResult CutRewriter::runBottomUpRewrite(Operation *top) {
1449 LLVM_DEBUG(llvm::dbgs() << "Performing cut-based rewriting...\n");
1450 const auto &network = cutEnumerator.getLogicNetwork();
1451 const auto &cutSets = cutEnumerator.getCutSets();
1452 auto processingOrder = cutEnumerator.getProcessingOrder();
1453
1454 // Note: Don't clear cutEnumerator yet - we need it during rewrite
1455 UnusedOpPruner pruner;
1456 PatternRewriter rewriter(top->getContext());
1457
1458 // Process in reverse topological order
1459 for (auto index : llvm::reverse(processingOrder)) {
1460 auto it = cutSets.find(index);
1461 if (it == cutSets.end())
1462 continue;
1463
1464 mlir::Value value = network.getValue(index);
1465 auto &cutSet = *it->second;
1466
// A value with no remaining uses needs no mapping; its defining op, if any,
// is erased immediately.
1467 if (value.use_empty()) {
1468 if (auto *op = value.getDefiningOp())
1469 pruner.eraseNow(op);
1470 continue;
1471 }
1472
1473 if (isAlwaysCutInput(network, index)) {
1474 // If the value is a primary input, skip it
1475 LLVM_DEBUG(llvm::dbgs() << "Skipping inputs: " << value << "\n");
1476 continue;
1477 }
1478
1479 LLVM_DEBUG(llvm::dbgs() << "Cut set for value: " << value << "\n");
1480 auto *bestCut = cutSet.getBestMatchedCut();
1481 if (!bestCut) {
// NOTE(review): listing line 1482 is absent from this rendering; presumably
// `if (options.allowNoMatch)`, which would make the `continue` conditional
// and the error below reachable — confirm.
1483 continue; // No matching pattern found, skip this value
1484 return emitError(value.getLoc(), "No matching cut found for value: ")
1485 << value;
1486 }
1487
1488 // Get the root operation from LogicNetwork
1489 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1490 rewriter.setInsertionPoint(rootOp);
1491 const auto &matchedPattern = bestCut->getMatchedPattern();
1492 auto result = matchedPattern->getPattern()->rewrite(rewriter, cutEnumerator,
1493 *bestCut);
1494 if (failed(result))
1495 return failure();
1496
1497 rewriter.replaceOp(rootOp, *result);
// NOTE(review): listing lines 1498 and 1500 are absent from this rendering;
// presumably a cutEnumerator.noteCutRewritten() call and an
// `if (options.attachDebugTiming) {` guard around the attribute below.
1499
1501 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1502 (*result)->setAttr("test.arrival_times", array);
1503 }
1504 }
1505
1506 // Clear the enumerator after rewriting is complete
// NOTE(review): listing line 1507 is absent from this rendering; presumably
// cutEnumerator.clear().
1508 return success();
1509}
assert(baseType &&"element must be base type")
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Strategy strategy
Precomputed NPN canonicalization table for 4-input single-output functions.
Definition TruthTable.h:168
bool lookup(const BinaryTruthTable &tt, NPNClass &result) const
Returns false if the given truth table shape is unsupported.
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void noteCutRewritten()
Record that one cut was successfully rewritten.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::APInt expandTruthTableToInputSpace(const llvm::APInt &tt, ArrayRef< unsigned > inputMapping, unsigned numExpandedInputs)
Expand a truth table to a larger input space using the given input mapping.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
OptimizationStrategy
Optimization strategy.
Definition SynthPasses.h:24
@ OptimizationStrategyArea
Optimize for minimal area.
Definition SynthPasses.h:25
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
Definition SynthPasses.h:26
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
int64_t DelayType
Definition CutRewriter.h:40
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
Definition SynthOps.cpp:662
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
Definition CutRewriter.h:45
T evaluateDotLogic(const T &x, const T &y, const T &z)
Evaluate the Boolean function x ^ (z | (x & y)).
Definition SynthOps.h:134
T evaluateMajorityLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:139
T evaluateMuxLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:160
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
T evaluateOneHotLogic(const T &a, const T &b, const T &c)
Definition SynthOps.h:154
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Definition TruthTable.h:41
Represents the canonical form of a boolean function under NPN equivalence.
Definition TruthTable.h:106
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Dot3
Ordered DOT gate (3-input, synth.dot)
@ OneHot3
OneHot gate (3-input, synth.onehot)
@ Identity
Identity gate (used for 1-input inverter)
@ Mux3
Ordered MUX gate (3-input, synth.mux_inv)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.
Definition CutRewriter.h:75