32#include "mlir/Analysis/TopologicalSortUtils.h"
33#include "mlir/IR/Builders.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/IR/RegionKindInterface.h"
36#include "mlir/IR/Value.h"
37#include "mlir/IR/ValueRange.h"
38#include "mlir/IR/Visitors.h"
39#include "mlir/Support/LLVM.h"
40#include "llvm/ADT/APInt.h"
41#include "llvm/ADT/Bitset.h"
42#include "llvm/ADT/DenseMap.h"
43#include "llvm/ADT/MapVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/ADT/iterator.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/ErrorHandling.h"
52#include "llvm/Support/LogicalResult.h"
59#define DEBUG_TYPE "synth-cut-rewriter"
80 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
95 "Index out of bounds in LogicNetwork::getValue");
100 SmallVectorImpl<Value> &values)
const {
102 values.reserve(indices.size());
103 for (uint32_t idx : indices)
114 Value result, ArrayRef<Signal> operands) {
122 const size_t estimatedSize =
123 block->getArguments().size() + block->getOperations().size();
125 gates.reserve(estimatedSize);
127 auto handleSingleInputGate = [&](Operation *op, Value result,
128 const Signal &inputSignal) {
129 if (!inputSignal.isInverted()) {
139 for (Value arg : block->getArguments()) {
144 auto handleOtherResults = [&](Operation *op) {
145 for (Value result : op->getResults()) {
146 if (result.getType().isInteger(1) && !
hasIndex(result))
151 auto getInvertibleSignal = [&](
auto op,
unsigned index) {
155 auto handleInvertibleBinaryGate = [&](
auto logicOp,
159 const auto inputs = logicOp.getInputs();
160 if (inputs.size() == 1) {
161 const Signal inputSignal = getInvertibleSignal(logicOp, 0);
162 handleSingleInputGate(logicOp, logicOp.getResult(), inputSignal);
165 if (inputs.size() == 2) {
166 const Signal lhsSignal = getInvertibleSignal(logicOp, 0);
167 const Signal rhsSignal = getInvertibleSignal(logicOp, 1);
168 addGate(logicOp, kind, {lhsSignal, rhsSignal});
172 handleOtherResults(logicOp);
176 auto handleInvertibleTernaryGate = [&](
auto logicOp,
178 if (!logicOp.getType().isInteger(1)) {
179 handleOtherResults(logicOp);
182 const Signal aSignal = getInvertibleSignal(logicOp, 0);
183 const Signal bSignal = getInvertibleSignal(logicOp, 1);
184 const Signal cSignal = getInvertibleSignal(logicOp, 2);
185 addGate(logicOp, kind, {aSignal, bSignal, cSignal});
190 for (Operation &op : block->getOperations()) {
191 LogicalResult result =
192 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
193 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
196 .Case<synth::XorInverterOp>([&](synth::XorInverterOp xorOp) {
199 .Case<synth::MuxInverterOp>([&](synth::MuxInverterOp muxOp) {
202 .Case<synth::DotOp>([&](synth::DotOp dotOp) {
205 .Case<synth::MajorityOp>([&](synth::MajorityOp majOp) {
208 .Case<synth::OneHotOp>([&](synth::OneHotOp oneHotOp) {
209 return handleInvertibleTernaryGate(oneHotOp,
213 if (xorOp->getNumOperands() != 2) {
214 handleOtherResults(xorOp);
225 Value result = constOp.getResult();
226 if (!result.getType().isInteger(1)) {
227 handleOtherResults(constOp);
235 .Case<synth::ChoiceOp>([&](synth::ChoiceOp choiceOp) {
236 if (!choiceOp.getType().isInteger(1)) {
237 handleOtherResults(choiceOp);
244 .Default([&](Operation *defaultOp) {
245 handleOtherResults(defaultOp);
274 const auto &gate = network.
getGate(index);
281 ArrayRef<DelayType> newDelay,
double oldArea,
282 ArrayRef<DelayType> oldDelay) {
284 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
286 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
287 llvm_unreachable(
"Unknown mapping strategy");
291 const auto isOperationReady = [](Value value, Operation *op) ->
bool {
298 return emitError(topOp->getLoc(),
299 "failed to sort operations topologically");
307 for (Value arg : block->getArguments())
308 inputArgs.insert(arg);
310 if (inputArgs.empty())
313 const int64_t numInputs = inputArgs.size();
314 const int64_t numOutputs = values.size();
319 return mlir::emitError(values.front().getLoc(),
320 "Truth table is too large");
321 return mlir::emitError(values.front().getLoc(),
322 "Multiple outputs are not supported yet");
326 DenseMap<Value, APInt> eval;
327 for (uint32_t i = 0; i < numInputs; ++i)
331 for (Operation &op : block->without_terminator()) {
332 if (op.getNumResults() != 1 ||
334 return op.emitError(
"Unsupported operation for truth table simulation");
336 if (
auto choiceOp = dyn_cast<synth::ChoiceOp>(&op)) {
337 auto it = eval.find(choiceOp.getInputs().front());
338 if (it == eval.end())
339 return choiceOp.emitError(
"Input value not found in evaluation map");
340 eval[choiceOp.getResult()] = it->second;
341 }
else if (
auto logicOp = dyn_cast<BooleanLogicOpInterface>(&op)) {
342 for (
auto value : logicOp.getInputs())
343 if (!eval.contains(value))
344 return logicOp->emitError(
"Input value not found in evaluation map");
346 eval[logicOp.getResult()] =
347 logicOp.evaluateBooleanLogic([&](
unsigned i) ->
const APInt & {
348 return eval.find(logicOp.getInput(i))->second;
350 }
else if (
auto xorOp = dyn_cast<comb::XorOp>(&op)) {
352 auto it = eval.find(xorOp.getOperand(0));
353 if (it == eval.end())
354 return xorOp.emitError(
"Input value not found in evaluation map");
355 llvm::APInt result = it->second;
356 for (
unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
357 it = eval.find(xorOp.getOperand(i));
358 if (it == eval.end())
359 return xorOp.emitError(
"Input value not found in evaluation map");
360 result ^= it->second;
362 eval[xorOp.getResult()] = result;
363 }
else if (
auto constantOp = dyn_cast<hw::ConstantOp>(&op)) {
364 auto tableSize = 1ULL << numInputs;
365 eval[constantOp.getResult()] = constantOp.getValue().isZero()
366 ? llvm::APInt::getZero(tableSize)
367 : llvm::APInt::getAllOnes(tableSize);
369 return op.emitError(
"Unsupported operation for truth table simulation");
397 npnClass.emplace(std::move(canonicalForm));
403 SmallVectorImpl<unsigned> &permutedIndices)
const {
405 npnClass.getInputPermutation(patternNPN, permutedIndices);
410 SmallVectorImpl<DelayType> &results)
const {
415 for (
auto inputIndex :
inputs) {
418 results.push_back(0);
421 auto *cutSet = enumerator.
getCutSet(inputIndex);
422 assert(cutSet &&
"Input must have a valid cut set");
426 auto *bestCut = cutSet->getBestMatchedCut();
433 mlir::Value inputValue = network.getValue(inputIndex);
437 cast<mlir::OpResult>(inputValue).getResultNumber()));
444 os <<
"// === Cut Dump ===\n";
449 os <<
" and root: " << *rootOp;
455 os <<
"Primary input cut: " << inputVal <<
"\n";
459 os <<
"Inputs (indices): \n";
460 for (
auto [idx, inputIndex] : llvm::enumerate(
inputs)) {
461 mlir::Value inputVal = network.
getValue(inputIndex);
462 os <<
" Input " << idx <<
" (index " << inputIndex <<
"): " << inputVal
467 os <<
"\nRoot operation: \n";
476 os <<
"// === Cut End ===\n";
485 return rootOp ? rootOp->getNumResults() : 1;
490 const llvm::APInt &a) {
495 llvm_unreachable(
"Unsupported unary operation for truth table computation");
500 const llvm::APInt &a,
501 const llvm::APInt &b) {
509 "Unsupported binary operation for truth table computation");
514 const llvm::APInt &a,
515 const llvm::APInt &b,
516 const llvm::APInt &c) {
528 "Unsupported ternary operation for truth table computation");
536struct MergedTruthTableBuilder {
537 MergedTruthTableBuilder(ArrayRef<uint32_t> mergedInputs,
538 ArrayRef<const Cut *> operandCuts)
539 : mergedInputs(mergedInputs), numMergedInputs(mergedInputs.size()),
540 operandCuts(operandCuts) {
541 assert(llvm::is_sorted(mergedInputs) &&
"merged inputs must be sorted");
542 assert(llvm::adjacent_find(mergedInputs) == mergedInputs.end() &&
543 "merged inputs must be unique");
546 ArrayRef<uint32_t> mergedInputs;
547 unsigned numMergedInputs;
548 ArrayRef<const Cut *> operandCuts;
550 std::optional<unsigned> findMergedInputPosition(uint32_t operandIdx)
const {
551 auto *it = llvm::find(mergedInputs, operandIdx);
552 if (it == mergedInputs.end())
554 return static_cast<unsigned>(std::distance(mergedInputs.begin(), it));
557 const Cut *findOperandCut(uint32_t operandIdx)
const {
558 for (
const Cut *cut : operandCuts) {
562 cut->
isTrivialCut() ? cut->inputs[0] : cut->getRootIndex();
563 if (cutOutput == operandIdx)
569 void getInputMapping(
const Cut *cut,
570 SmallVectorImpl<unsigned> &mapping)
const {
572 mapping.reserve(cut->
inputs.size());
573 for (uint32_t idx : cut->inputs) {
574 auto *it = llvm::find(mergedInputs, idx);
575 assert(it != mergedInputs.end() &&
576 "cut input must exist in merged inputs");
577 mapping.push_back(
static_cast<unsigned>(it - mergedInputs.begin()));
581 llvm::APInt expandCutTruthTable(
const Cut *cut)
const {
583 SmallVector<unsigned, 8> inputMapping;
584 getInputMapping(cut, inputMapping);
586 cutTT.table, inputMapping, numMergedInputs);
589 llvm::APInt expandOperand(uint32_t operandIdx,
bool isInverted)
const {
590 llvm::APInt result(1, 0);
592 result = llvm::APInt::getZero(1U << numMergedInputs);
594 result = llvm::APInt::getAllOnes(1U << numMergedInputs);
595 }
else if (
auto pos = findMergedInputPosition(operandIdx)) {
598 }
else if (
const Cut *cut = findOperandCut(operandIdx)) {
601 result = expandCutTruthTable(cut);
603 llvm_unreachable(
"Operand not found in cuts or merged inputs");
607 result.flipAllBits();
612 auto getEdgeTT = [&](
unsigned edgeIdx) {
613 const auto &edge = rootGate.
edges[edgeIdx];
614 return expandOperand(edge.getIndex(), edge.isInverted());
626 getEdgeTT(0), getEdgeTT(1),
631 getEdgeTT(0), getEdgeTT(1),
636 getEdgeTT(0), getEdgeTT(1),
641 getEdgeTT(0), getEdgeTT(1),
648 llvm_unreachable(
"Unsupported operation for truth table computation");
662 "non-trivial cuts must carry operand cuts for truth table expansion");
681 return std::includes(otherInputs.begin(), otherInputs.end(),
inputs.begin(),
687 cut.
inputs.push_back(index);
700 assert(
pattern &&
"Pattern must be set to get arrival time");
705 assert(
pattern &&
"Pattern must be set to get arrival time");
737 auto dumpInputs = [](llvm::raw_ostream &os,
738 const llvm::SmallVectorImpl<uint32_t> &inputs) {
740 llvm::interleaveComma(inputs, os);
745 std::stable_sort(cuts.begin(), cuts.end(), [](
const Cut *a,
const Cut *b) {
746 if (a->getInputSize() != b->getInputSize())
747 return a->getInputSize() < b->getInputSize();
748 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
749 b->inputs.begin(), b->inputs.end());
753 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
754 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
757 unsigned uniqueCount = 0;
758 for (
Cut *cut : cuts) {
763 if (uniqueCount > 0) {
764 Cut *lastKept = cuts[uniqueCount - 1];
770 bool isDominated =
false;
771 for (
unsigned existingSize = 1; existingSize < cutSize && !isDominated;
773 for (
const Cut *existingCut : keptBySize[existingSize]) {
774 if (!existingCut->dominates(*cut))
778 llvm::dbgs() <<
"Dropping non-minimal cut ";
779 dumpInputs(llvm::dbgs(), cut->
inputs);
780 llvm::dbgs() <<
" due to subset ";
781 dumpInputs(llvm::dbgs(), existingCut->inputs);
782 llvm::dbgs() <<
"\n";
792 cuts[uniqueCount++] = cut;
793 keptBySize[cutSize].push_back(cut);
796 LLVM_DEBUG(llvm::dbgs() <<
"Original cuts: " << cuts.size()
797 <<
" Unique cuts: " << uniqueCount <<
"\n");
800 cuts.resize(uniqueCount);
805 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut,
818 "Cut input size exceeds maximum allowed size");
820 if (
auto matched = matchCut(*cut))
834 auto *trivialCutsEnd =
835 std::stable_partition(
cuts.begin(),
cuts.end(),
836 [](
const Cut *cut) { return cut->isTrivialCut(); });
838 auto isBetterCut = [&options](
const Cut *a,
const Cut *b) {
839 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
840 "Trivial cuts should have been excluded");
841 const auto &aMatched = a->getMatchedPattern();
842 const auto &bMatched = b->getMatchedPattern();
844 if (aMatched && bMatched)
846 options.
strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
847 bMatched->getArea(), bMatched->getArrivalTimes());
849 if (
static_cast<bool>(aMatched) !=
static_cast<bool>(bMatched))
850 return static_cast<bool>(aMatched);
852 return a->getInputSize() < b->getInputSize();
854 std::stable_sort(trivialCutsEnd,
cuts.end(), isBetterCut);
871 llvm::dbgs() <<
"Finalized cut set with " <<
cuts.size() <<
" cuts and "
876 :
"no matched pattern")
888 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const {
897 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns)
901 SmallVector<NPNClass, 2> npnClasses;
902 auto result =
pattern->useTruthTableMatcher(npnClasses);
904 for (
auto npnClass : npnClasses) {
907 npnClass.truthTable.numInputs}]
908 .push_back(std::make_pair(std::move(npnClass),
pattern.get()));
923 : cutAllocator(stats.numCutsCreated),
924 cutSetAllocator(stats.numCutSetsCreated), options(options) {}
928 auto [cutSetPtr, inserted] =
cutSets.try_emplace(index, cutSet);
929 assert(inserted &&
"Cut set already exists for this index");
930 return cutSetPtr->second;
944 assert(logicOp && logicOp->getNumResults() == 1 &&
945 "Logic operation must have a single result");
948 auto choiceOp = cast<synth::ChoiceOp>(logicOp);
952 resultCutSet->addCut(primaryInputCut);
954 for (Value operand : choiceOp.getInputs()) {
957 return logicOp->emitError(
"Failed to get cut set for choice operand");
961 for (
const Cut *operandCut : operandCutSet->getCuts()) {
962 if (operandCut->isTrivialCut())
966 nodeIndex, operandCut->inputs, operandCut->getSignature(),
967 ArrayRef<const Cut *>{operandCut}, *operandCut->getTruthTable()));
976 unsigned numFanins = gate.getNumFanins();
981 return logicOp->emitError(
"Cut enumeration supports at most 3 operands, "
984 if (!logicOp->getOpResult(0).getType().isInteger(1))
985 return logicOp->emitError()
986 <<
"Supported logic operations must have a single bit "
987 "result type but found: "
988 << logicOp->getResult(0).getType();
992 SmallVector<std::pair<const CutSet *, unsigned>, 2> operandCutSets;
993 operandCutSets.reserve(numFanins);
996 for (
unsigned i = 0; i < numFanins; ++i) {
997 uint32_t faninIndex = gate.edges[i].getIndex();
998 auto *operandCutSet =
getCutSet(faninIndex);
1000 return logicOp->emitError(
"Failed to get cut set for fanin index ")
1005 unsigned maxInputCutSize = 0;
1006 for (
auto *cut : operandCutSet->getCuts())
1007 maxInputCutSize = std::max(maxInputCutSize, cut->
getInputSize());
1008 operandCutSets.push_back(std::make_pair(operandCutSet, maxInputCutSize));
1016 resultCutSet->addCut(primaryInputCut);
1022 llvm::stable_sort(operandCutSets,
1023 [](
const std::pair<const CutSet *, unsigned> &a,
1024 const std::pair<const CutSet *, unsigned> &b) {
1025 return a.second > b.second;
1033 auto enumerateCutCombinations = [&](
auto &&self,
unsigned operandIdx,
1034 SmallVector<const Cut *, 3> &cutPtrs,
1035 uint64_t currentSig) ->
void {
1037 if (operandIdx == numFanins) {
1041 SmallVector<uint32_t, 6> mergedInputs;
1042 auto appendMergedInput = [&](uint32_t value) {
1046 if (!mergedInputs.empty() && mergedInputs.back() == value)
1048 mergedInputs.push_back(value);
1049 return mergedInputs.size() <= maxInputSize;
1052 if (numFanins == 1) {
1054 mergedInputs.reserve(
1055 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
1056 for (uint32_t value : cutPtrs[0]->inputs)
1057 if (!appendMergedInput(value))
1059 }
else if (numFanins == 2) {
1061 const auto &inputs0 = cutPtrs[0]->inputs;
1062 const auto &inputs1 = cutPtrs[1]->inputs;
1063 mergedInputs.reserve(
1064 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
1066 unsigned i = 0, j = 0;
1067 while (i < inputs0.size() || j < inputs1.size()) {
1069 if (j == inputs1.size() ||
1070 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
1071 next = inputs0[i++];
1072 if (j < inputs1.size() && inputs1[j] == next)
1075 next = inputs1[j++];
1077 if (!appendMergedInput(next))
1082 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
1083 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
1084 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
1085 mergedInputs.reserve(std::min<size_t>(
1086 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
1088 unsigned i = 0, j = 0, k = 0;
1089 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
1091 uint32_t minVal = UINT32_MAX;
1092 if (i < inputs0.size())
1093 minVal = std::min(minVal, inputs0[i]);
1094 if (j < inputs1.size())
1095 minVal = std::min(minVal, inputs1[j]);
1096 if (k < inputs2.size())
1097 minVal = std::min(minVal, inputs2[k]);
1100 if (i < inputs0.size() && inputs0[i] == minVal)
1102 if (j < inputs1.size() && inputs1[j] == minVal)
1104 if (k < inputs2.size() && inputs2[k] == minVal)
1107 if (!appendMergedInput(minVal))
1113 Cut *mergedCut =
cutAllocator.create(nodeIndex, mergedInputs, currentSig,
1114 ArrayRef<const Cut *>(cutPtrs));
1115 resultCutSet->addCut(mergedCut);
1118 if (mergedCut->
inputs.size() >= 4) {
1119 llvm::dbgs() <<
"Generated cut for node " << nodeIndex;
1121 llvm::dbgs() <<
" (" << logicOp->getName() <<
")";
1122 llvm::dbgs() <<
" inputs=";
1123 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1124 llvm::dbgs() <<
"\n";
1131 const CutSet *currentCutSet = operandCutSets[operandIdx].first;
1132 for (
const Cut *cut : currentCutSet->
getCuts()) {
1134 uint64_t newSig = currentSig | cutSig;
1135 if (
static_cast<unsigned>(llvm::popcount(newSig)) > maxInputSize)
1138 cutPtrs.push_back(cut);
1141 self(self, operandIdx + 1, cutPtrs, newSig);
1148 SmallVector<const Cut *, 3> cutPtrs;
1149 cutPtrs.reserve(numFanins);
1150 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs, 0ULL);
1160 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut) {
1161 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts for module: " << topOp->getName()
1171 auto &block = topOp->getRegion(0).getBlocks().front();
1177 if (!gate.isLogicGate())
1185 LLVM_DEBUG(llvm::dbgs() <<
"Cut enumeration completed successfully\n");
1191 auto it =
cutSets.find(index);
1196 cutSet->
addCut(trivialCut);
1197 auto [newIt, inserted] =
cutSets.insert({index, cutSet});
1198 assert(inserted &&
"Cut set already exists for this index");
1210 if (
auto *op = value.getDefiningOp()) {
1213 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint"))
1214 return name.getValue();
1218 if (op->getNumResults() == 1) {
1219 auto opName = op->getName();
1220 auto count = opCounter[opName]++;
1223 SmallString<16> nameStr;
1224 nameStr += opName.getStringRef();
1226 nameStr += std::to_string(count);
1229 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1230 op->setAttr(
"sv.namehint", nameAttr);
1239 auto blockArg = cast<BlockArgument>(value);
1241 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1246 return hwOp.getInputName(blockArg.getArgNumber());
1250 DenseMap<OperationName, unsigned> opCounter;
1252 auto it =
cutSets.find(index);
1255 auto &cutSet = *it->second;
1258 << cutSet.getCuts().size() <<
" cuts:";
1259 for (
const Cut *cut : cutSet.getCuts()) {
1260 llvm::outs() <<
" {";
1261 llvm::interleaveComma(cut->
inputs, llvm::outs(), [&](uint32_t inputIdx) {
1262 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1263 llvm::outs() << getTestVariableName(inputVal, opCounter);
1267 <<
"@t" << cut->
getTruthTable()->table.getZExtValue() <<
"d";
1269 llvm::outs() << *std::max_element(
pattern->getArrivalTimes().begin(),
1270 pattern->getArrivalTimes().end());
1272 llvm::outs() <<
"0";
1275 llvm::outs() <<
"\n";
1277 llvm::outs() <<
"Cut enumeration completed successfully\n";
1286 llvm::dbgs() <<
"Starting Cut Rewriter\n";
1287 llvm::dbgs() <<
"Mode: "
1299 if (
pattern->getNumOutputs() > 1) {
1300 return mlir::emitError(
pattern->getLoc(),
1301 "Cut rewriter does not support patterns with "
1302 "multiple outputs yet");
1329 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts...\n");
1332 topOp, [&](
const Cut &cut) -> std::optional<MatchedPattern> {
1338ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1340 if (
patterns.npnToPatternMap.empty())
1344 auto it =
patterns.npnToPatternMap.find(
1345 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1346 if (it ==
patterns.npnToPatternMap.end())
1348 return it->getSecond();
1357 SmallVector<DelayType, 4> inputArrivalTimes;
1358 SmallVector<DelayType, 1> bestArrivalTimes;
1359 double bestArea = 0.0;
1367 auto computeArrivalTimeAndPickBest =
1369 llvm::function_ref<unsigned(
unsigned)> mapIndex) {
1370 SmallVector<DelayType, 1> outputArrivalTimes;
1372 for (
unsigned outputIndex = 0, outputSize = cut.
getOutputSize(network);
1373 outputIndex < outputSize; ++outputIndex) {
1376 auto delays = matchResult.getDelays();
1377 for (
unsigned inputIndex = 0, inputSize = cut.
getInputSize();
1378 inputIndex < inputSize; ++inputIndex) {
1380 unsigned cutOriginalInput = mapIndex(inputIndex);
1382 std::max(outputArrivalTime,
1383 delays[outputIndex * inputSize + inputIndex] +
1384 inputArrivalTimes[cutOriginalInput]);
1387 outputArrivalTimes.push_back(outputArrivalTime);
1393 outputArrivalTimes, bestArea,
1394 bestArrivalTimes)) {
1396 llvm::dbgs() <<
"== Matched Pattern ==============\n";
1397 llvm::dbgs() <<
"Matching cut: \n";
1398 cut.
dump(llvm::dbgs(), network);
1399 llvm::dbgs() <<
"Found better pattern: "
1401 llvm::dbgs() <<
" with area: " << matchResult.area;
1402 llvm::dbgs() <<
" and input arrival times: ";
1403 for (
unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1404 llvm::dbgs() <<
" " << inputArrivalTimes[i];
1406 llvm::dbgs() <<
" and arrival times: ";
1408 for (
auto arrivalTime : outputArrivalTimes) {
1409 llvm::dbgs() <<
" " << arrivalTime;
1411 llvm::dbgs() <<
"\n";
1412 llvm::dbgs() <<
"== Matched Pattern End ==============\n";
1415 bestArrivalTimes = std::move(outputArrivalTimes);
1416 bestArea = matchResult.area;
1423 "Pattern input size must match cut input size");
1430 SmallVector<unsigned> inputMapping;
1432 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1433 [&](
unsigned i) {
return inputMapping[i]; });
1438 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1439 [&](
unsigned i) {
return i; });
1445 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1449 LLVM_DEBUG(llvm::dbgs() <<
"Performing cut-based rewriting...\n");
1456 PatternRewriter rewriter(top->getContext());
1459 for (
auto index : llvm::reverse(processingOrder)) {
1460 auto it = cutSets.find(index);
1461 if (it == cutSets.end())
1464 mlir::Value value = network.getValue(index);
1465 auto &cutSet = *it->second;
1467 if (value.use_empty()) {
1468 if (
auto *op = value.getDefiningOp())
1475 LLVM_DEBUG(llvm::dbgs() <<
"Skipping inputs: " << value <<
"\n");
1479 LLVM_DEBUG(llvm::dbgs() <<
"Cut set for value: " << value <<
"\n");
1480 auto *bestCut = cutSet.getBestMatchedCut();
1484 return emitError(value.getLoc(),
"No matching cut found for value: ")
1489 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1490 rewriter.setInsertionPoint(rootOp);
1491 const auto &matchedPattern = bestCut->getMatchedPattern();
1492 auto result = matchedPattern->getPattern()->rewrite(rewriter,
cutEnumerator,
1497 rewriter.replaceOp(rootOp, *result);
1501 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1502 (*result)->setAttr(
"test.arrival_times", array);
assert(baseType &&"element must be base type")
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Precomputed NPN canonicalization table for 4-input single-output functions.
bool lookup(const BinaryTruthTable &tt, NPNClass &result) const
Returns false if the given truth table shape is unsupported.
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void noteCutRewritten()
Record that one cut was successfully rewritten.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using the bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether the cut set has been finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs of the root operation (1 if the cut has no root operation).
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute the truth table incrementally from the truth tables of the operand cuts.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (kept for lazy truth table computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted input indices for this cut based on the given pattern's NPN class.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of the given output of this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::APInt expandTruthTableToInputSpace(const llvm::APInt &tt, ArrayRef< unsigned > inputMapping, unsigned numExpandedInputs)
Expand a truth table to a larger input space using the given input mapping.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
OptimizationStrategy
Optimization strategy.
@ OptimizationStrategyArea
Optimize for minimal area.
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
T evaluateDotLogic(const T &x, const T &y, const T &z)
Evaluate the Boolean function x ^ (z | (x & y)).
T evaluateMajorityLogic(const T &a, const T &b, const T &c)
T evaluateMuxLogic(const T &a, const T &b, const T &c)
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
T evaluateOneHotLogic(const T &a, const T &b, const T &c)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up at once.
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Allow root operations that have no matching pattern; when false, rewriting fails if any root operation has no match.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Dot3
Ordered DOT gate (3-input, synth.dot)
@ OneHot3
OneHot gate (3-input, synth.onehot)
@ Identity
Identity gate (used for 1-input inverter)
@ Mux3
Ordered MUX gate (3-input, synth.mux_inv)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.