32#include "mlir/Analysis/TopologicalSortUtils.h"
33#include "mlir/IR/Builders.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/IR/RegionKindInterface.h"
36#include "mlir/IR/Value.h"
37#include "mlir/IR/ValueRange.h"
38#include "mlir/IR/Visitors.h"
39#include "mlir/Support/LLVM.h"
40#include "llvm/ADT/APInt.h"
41#include "llvm/ADT/Bitset.h"
42#include "llvm/ADT/DenseMap.h"
43#include "llvm/ADT/MapVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/ADT/iterator.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/ErrorHandling.h"
52#include "llvm/Support/LogicalResult.h"
59#define DEBUG_TYPE "synth-cut-rewriter"
80 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
95 "Index out of bounds in LogicNetwork::getValue");
100 SmallVectorImpl<Value> &values)
const {
102 values.reserve(indices.size());
103 for (uint32_t idx : indices)
114 Value result, ArrayRef<Signal> operands) {
122 const size_t estimatedSize =
123 block->getArguments().size() + block->getOperations().size();
125 gates.reserve(estimatedSize);
127 auto handleSingleInputGate = [&](Operation *op, Value result,
128 const Signal &inputSignal) {
129 if (!inputSignal.isInverted()) {
139 for (Value arg : block->getArguments()) {
144 auto handleOtherResults = [&](Operation *op) {
145 for (Value result : op->getResults()) {
146 if (result.getType().isInteger(1) && !
hasIndex(result))
151 auto getInvertibleSignal = [&](
auto op,
unsigned index) {
155 auto handleInvertibleBinaryGate = [&](
auto logicOp,
159 const auto inputs = logicOp.getInputs();
160 if (inputs.size() == 1) {
161 const Signal inputSignal = getInvertibleSignal(logicOp, 0);
162 handleSingleInputGate(logicOp, logicOp.getResult(), inputSignal);
165 if (inputs.size() == 2) {
166 const Signal lhsSignal = getInvertibleSignal(logicOp, 0);
167 const Signal rhsSignal = getInvertibleSignal(logicOp, 1);
168 addGate(logicOp, kind, {lhsSignal, rhsSignal});
172 handleOtherResults(logicOp);
176 auto handleInvertibleTernaryGate = [&](
auto logicOp,
178 if (!logicOp.getType().isInteger(1)) {
179 handleOtherResults(logicOp);
182 const Signal aSignal = getInvertibleSignal(logicOp, 0);
183 const Signal bSignal = getInvertibleSignal(logicOp, 1);
184 const Signal cSignal = getInvertibleSignal(logicOp, 2);
185 addGate(logicOp, kind, {aSignal, bSignal, cSignal});
190 for (Operation &op : block->getOperations()) {
191 LogicalResult result =
192 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
193 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
196 .Case<synth::XorInverterOp>([&](synth::XorInverterOp xorOp) {
199 .Case<synth::MuxInverterOp>([&](synth::MuxInverterOp muxOp) {
202 .Case<synth::DotOp>([&](synth::DotOp dotOp) {
205 .Case<synth::MajorityOp>([&](synth::MajorityOp majOp) {
208 .Case<synth::OneHotOp>([&](synth::OneHotOp oneHotOp) {
209 return handleInvertibleTernaryGate(oneHotOp,
212 .Case<synth::GambleOp>([&](synth::GambleOp gambleOp) {
213 return handleInvertibleTernaryGate(gambleOp,
217 if (xorOp->getNumOperands() != 2) {
218 handleOtherResults(xorOp);
229 Value result = constOp.getResult();
230 if (!result.getType().isInteger(1)) {
231 handleOtherResults(constOp);
239 .Case<synth::ChoiceOp>([&](synth::ChoiceOp choiceOp) {
240 if (!choiceOp.getType().isInteger(1)) {
241 handleOtherResults(choiceOp);
248 .Default([&](Operation *defaultOp) {
249 handleOtherResults(defaultOp);
278 const auto &gate = network.
getGate(index);
285 ArrayRef<DelayType> newDelay,
double oldArea,
286 ArrayRef<DelayType> oldDelay) {
288 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
290 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
291 llvm_unreachable(
"Unknown mapping strategy");
295 const auto isOperationReady = [](Value value, Operation *op) ->
bool {
302 return emitError(topOp->getLoc(),
303 "failed to sort operations topologically");
311 for (Value arg : block->getArguments())
312 inputArgs.insert(arg);
314 if (inputArgs.empty())
317 const int64_t numInputs = inputArgs.size();
318 const int64_t numOutputs = values.size();
323 return mlir::emitError(values.front().getLoc(),
324 "Truth table is too large");
325 return mlir::emitError(values.front().getLoc(),
326 "Multiple outputs are not supported yet");
330 DenseMap<Value, APInt> eval;
331 for (uint32_t i = 0; i < numInputs; ++i)
335 for (Operation &op : block->without_terminator()) {
336 if (op.getNumResults() != 1 ||
338 return op.emitError(
"Unsupported operation for truth table simulation");
340 if (
auto choiceOp = dyn_cast<synth::ChoiceOp>(&op)) {
341 auto it = eval.find(choiceOp.getInputs().front());
342 if (it == eval.end())
343 return choiceOp.emitError(
"Input value not found in evaluation map");
344 eval[choiceOp.getResult()] = it->second;
345 }
else if (
auto logicOp = dyn_cast<BooleanLogicOpInterface>(&op)) {
346 for (
auto value : logicOp.getInputs())
347 if (!eval.contains(value))
348 return logicOp->emitError(
"Input value not found in evaluation map");
350 eval[logicOp.getResult()] =
351 logicOp.evaluateBooleanLogic([&](
unsigned i) ->
const APInt & {
352 return eval.find(logicOp.getInput(i))->second;
354 }
else if (
auto xorOp = dyn_cast<comb::XorOp>(&op)) {
356 auto it = eval.find(xorOp.getOperand(0));
357 if (it == eval.end())
358 return xorOp.emitError(
"Input value not found in evaluation map");
359 llvm::APInt result = it->second;
360 for (
unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
361 it = eval.find(xorOp.getOperand(i));
362 if (it == eval.end())
363 return xorOp.emitError(
"Input value not found in evaluation map");
364 result ^= it->second;
366 eval[xorOp.getResult()] = result;
367 }
else if (
auto constantOp = dyn_cast<hw::ConstantOp>(&op)) {
368 auto tableSize = 1ULL << numInputs;
369 eval[constantOp.getResult()] = constantOp.getValue().isZero()
370 ? llvm::APInt::getZero(tableSize)
371 : llvm::APInt::getAllOnes(tableSize);
373 return op.emitError(
"Unsupported operation for truth table simulation");
401 npnClass.emplace(std::move(canonicalForm));
407 SmallVectorImpl<unsigned> &permutedIndices)
const {
409 npnClass.getInputPermutation(patternNPN, permutedIndices);
414 SmallVectorImpl<DelayType> &results)
const {
419 for (
auto inputIndex :
inputs) {
422 results.push_back(0);
425 auto *cutSet = enumerator.
getCutSet(inputIndex);
426 assert(cutSet &&
"Input must have a valid cut set");
430 auto *bestCut = cutSet->getBestMatchedCut();
437 mlir::Value inputValue = network.getValue(inputIndex);
441 cast<mlir::OpResult>(inputValue).getResultNumber()));
448 os <<
"// === Cut Dump ===\n";
453 os <<
" and root: " << *rootOp;
459 os <<
"Primary input cut: " << inputVal <<
"\n";
463 os <<
"Inputs (indices): \n";
464 for (
auto [idx, inputIndex] : llvm::enumerate(
inputs)) {
465 mlir::Value inputVal = network.
getValue(inputIndex);
466 os <<
" Input " << idx <<
" (index " << inputIndex <<
"): " << inputVal
471 os <<
"\nRoot operation: \n";
480 os <<
"// === Cut End ===\n";
489 return rootOp ? rootOp->getNumResults() : 1;
494 const llvm::APInt &a) {
499 llvm_unreachable(
"Unsupported unary operation for truth table computation");
504 const llvm::APInt &a,
505 const llvm::APInt &b) {
513 "Unsupported binary operation for truth table computation");
518 const llvm::APInt &a,
519 const llvm::APInt &b,
520 const llvm::APInt &c) {
534 "Unsupported ternary operation for truth table computation");
542struct MergedTruthTableBuilder {
543 MergedTruthTableBuilder(ArrayRef<uint32_t> mergedInputs,
544 ArrayRef<const Cut *> operandCuts)
545 : mergedInputs(mergedInputs), numMergedInputs(mergedInputs.size()),
546 operandCuts(operandCuts) {
547 assert(llvm::is_sorted(mergedInputs) &&
"merged inputs must be sorted");
548 assert(llvm::adjacent_find(mergedInputs) == mergedInputs.end() &&
549 "merged inputs must be unique");
552 ArrayRef<uint32_t> mergedInputs;
553 unsigned numMergedInputs;
554 ArrayRef<const Cut *> operandCuts;
556 std::optional<unsigned> findMergedInputPosition(uint32_t operandIdx)
const {
557 auto *it = llvm::find(mergedInputs, operandIdx);
558 if (it == mergedInputs.end())
560 return static_cast<unsigned>(std::distance(mergedInputs.begin(), it));
563 const Cut *findOperandCut(uint32_t operandIdx)
const {
564 for (
const Cut *cut : operandCuts) {
568 cut->
isTrivialCut() ? cut->inputs[0] : cut->getRootIndex();
569 if (cutOutput == operandIdx)
575 void getInputMapping(
const Cut *cut,
576 SmallVectorImpl<unsigned> &mapping)
const {
578 mapping.reserve(cut->
inputs.size());
579 for (uint32_t idx : cut->inputs) {
580 auto *it = llvm::find(mergedInputs, idx);
581 assert(it != mergedInputs.end() &&
582 "cut input must exist in merged inputs");
583 mapping.push_back(
static_cast<unsigned>(it - mergedInputs.begin()));
587 llvm::APInt expandCutTruthTable(
const Cut *cut)
const {
589 SmallVector<unsigned, 8> inputMapping;
590 getInputMapping(cut, inputMapping);
592 cutTT.table, inputMapping, numMergedInputs);
595 llvm::APInt expandOperand(uint32_t operandIdx,
bool isInverted)
const {
596 llvm::APInt result(1, 0);
598 result = llvm::APInt::getZero(1U << numMergedInputs);
600 result = llvm::APInt::getAllOnes(1U << numMergedInputs);
601 }
else if (
auto pos = findMergedInputPosition(operandIdx)) {
604 }
else if (
const Cut *cut = findOperandCut(operandIdx)) {
607 result = expandCutTruthTable(cut);
609 llvm_unreachable(
"Operand not found in cuts or merged inputs");
613 result.flipAllBits();
618 auto getEdgeTT = [&](
unsigned edgeIdx) {
619 const auto &edge = rootGate.
edges[edgeIdx];
620 return expandOperand(edge.getIndex(), edge.isInverted());
636 getEdgeTT(0), getEdgeTT(1),
643 llvm_unreachable(
"Unsupported operation for truth table computation");
657 "non-trivial cuts must carry operand cuts for truth table expansion");
676 return std::includes(otherInputs.begin(), otherInputs.end(),
inputs.begin(),
682 cut.
inputs.push_back(index);
695 assert(
pattern &&
"Pattern must be set to get arrival time");
700 assert(
pattern &&
"Pattern must be set to get arrival time");
732 auto dumpInputs = [](llvm::raw_ostream &os,
733 const llvm::SmallVectorImpl<uint32_t> &inputs) {
735 llvm::interleaveComma(inputs, os);
740 std::stable_sort(cuts.begin(), cuts.end(), [](
const Cut *a,
const Cut *b) {
741 if (a->getInputSize() != b->getInputSize())
742 return a->getInputSize() < b->getInputSize();
743 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
744 b->inputs.begin(), b->inputs.end());
748 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
749 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
752 unsigned uniqueCount = 0;
753 for (
Cut *cut : cuts) {
758 if (uniqueCount > 0) {
759 Cut *lastKept = cuts[uniqueCount - 1];
765 bool isDominated =
false;
766 for (
unsigned existingSize = 1; existingSize < cutSize && !isDominated;
768 for (
const Cut *existingCut : keptBySize[existingSize]) {
769 if (!existingCut->dominates(*cut))
773 llvm::dbgs() <<
"Dropping non-minimal cut ";
774 dumpInputs(llvm::dbgs(), cut->
inputs);
775 llvm::dbgs() <<
" due to subset ";
776 dumpInputs(llvm::dbgs(), existingCut->inputs);
777 llvm::dbgs() <<
"\n";
787 cuts[uniqueCount++] = cut;
788 keptBySize[cutSize].push_back(cut);
791 LLVM_DEBUG(llvm::dbgs() <<
"Original cuts: " << cuts.size()
792 <<
" Unique cuts: " << uniqueCount <<
"\n");
795 cuts.resize(uniqueCount);
800 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut,
813 "Cut input size exceeds maximum allowed size");
815 if (
auto matched = matchCut(*cut))
829 auto *trivialCutsEnd =
830 std::stable_partition(
cuts.begin(),
cuts.end(),
831 [](
const Cut *cut) { return cut->isTrivialCut(); });
833 auto isBetterCut = [&options](
const Cut *a,
const Cut *b) {
835 "Trivial cuts should have been excluded");
837 const auto &bMatched = b->getMatchedPattern();
839 if (aMatched && bMatched)
841 options.
strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
842 bMatched->getArea(), bMatched->getArrivalTimes());
844 if (
static_cast<bool>(aMatched) !=
static_cast<bool>(bMatched))
845 return static_cast<bool>(aMatched);
849 std::stable_sort(trivialCutsEnd,
cuts.end(), isBetterCut);
866 llvm::dbgs() <<
"Finalized cut set with " <<
cuts.size() <<
" cuts and "
871 :
"no matched pattern")
883 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const {
892 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns)
896 SmallVector<NPNClass, 2> npnClasses;
897 auto result =
pattern->useTruthTableMatcher(npnClasses);
899 for (
auto npnClass : npnClasses) {
902 npnClass.truthTable.numInputs}]
903 .push_back(std::make_pair(std::move(npnClass),
pattern.get()));
918 : cutAllocator(stats.numCutsCreated),
919 cutSetAllocator(stats.numCutSetsCreated), options(options) {}
923 auto [cutSetPtr, inserted] =
cutSets.try_emplace(index, cutSet);
924 assert(inserted &&
"Cut set already exists for this index");
925 return cutSetPtr->second;
939 assert(logicOp && logicOp->getNumResults() == 1 &&
940 "Logic operation must have a single result");
943 auto choiceOp = cast<synth::ChoiceOp>(logicOp);
947 resultCutSet->addCut(primaryInputCut);
949 for (Value operand : choiceOp.getInputs()) {
952 return logicOp->emitError(
"Failed to get cut set for choice operand");
956 for (
const Cut *operandCut : operandCutSet->getCuts()) {
957 if (operandCut->isTrivialCut())
961 nodeIndex, operandCut->inputs, operandCut->getSignature(),
962 ArrayRef<const Cut *>{operandCut}, *operandCut->getTruthTable()));
971 unsigned numFanins = gate.getNumFanins();
976 return logicOp->emitError(
"Cut enumeration supports at most 3 operands, "
979 if (!logicOp->getOpResult(0).getType().isInteger(1))
980 return logicOp->emitError()
981 <<
"Supported logic operations must have a single bit "
982 "result type but found: "
983 << logicOp->getResult(0).getType();
987 SmallVector<std::pair<const CutSet *, unsigned>, 2> operandCutSets;
988 operandCutSets.reserve(numFanins);
991 for (
unsigned i = 0; i < numFanins; ++i) {
992 uint32_t faninIndex = gate.edges[i].getIndex();
993 auto *operandCutSet =
getCutSet(faninIndex);
995 return logicOp->emitError(
"Failed to get cut set for fanin index ")
1000 unsigned maxInputCutSize = 0;
1001 for (
auto *cut : operandCutSet->getCuts())
1002 maxInputCutSize = std::max(maxInputCutSize, cut->
getInputSize());
1003 operandCutSets.push_back(std::make_pair(operandCutSet, maxInputCutSize));
1011 resultCutSet->addCut(primaryInputCut);
1017 llvm::stable_sort(operandCutSets,
1018 [](
const std::pair<const CutSet *, unsigned> &a,
1019 const std::pair<const CutSet *, unsigned> &b) {
1020 return a.second > b.second;
1028 auto enumerateCutCombinations = [&](
auto &&self,
unsigned operandIdx,
1029 SmallVector<const Cut *, 3> &cutPtrs,
1030 uint64_t currentSig) ->
void {
1032 if (operandIdx == numFanins) {
1036 SmallVector<uint32_t, 6> mergedInputs;
1037 auto appendMergedInput = [&](uint32_t value) {
1041 if (!mergedInputs.empty() && mergedInputs.back() == value)
1043 mergedInputs.push_back(value);
1044 return mergedInputs.size() <= maxInputSize;
1047 if (numFanins == 1) {
1049 mergedInputs.reserve(
1050 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
1051 for (uint32_t value : cutPtrs[0]->inputs)
1052 if (!appendMergedInput(value))
1054 }
else if (numFanins == 2) {
1056 const auto &inputs0 = cutPtrs[0]->inputs;
1057 const auto &inputs1 = cutPtrs[1]->inputs;
1058 mergedInputs.reserve(
1059 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
1061 unsigned i = 0, j = 0;
1062 while (i < inputs0.size() || j < inputs1.size()) {
1064 if (j == inputs1.size() ||
1065 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
1066 next = inputs0[i++];
1067 if (j < inputs1.size() && inputs1[j] == next)
1070 next = inputs1[j++];
1072 if (!appendMergedInput(next))
1077 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
1078 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
1079 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
1080 mergedInputs.reserve(std::min<size_t>(
1081 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
1083 unsigned i = 0, j = 0, k = 0;
1084 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
1086 uint32_t minVal = UINT32_MAX;
1087 if (i < inputs0.size())
1088 minVal = std::min(minVal, inputs0[i]);
1089 if (j < inputs1.size())
1090 minVal = std::min(minVal, inputs1[j]);
1091 if (k < inputs2.size())
1092 minVal = std::min(minVal, inputs2[k]);
1095 if (i < inputs0.size() && inputs0[i] == minVal)
1097 if (j < inputs1.size() && inputs1[j] == minVal)
1099 if (k < inputs2.size() && inputs2[k] == minVal)
1102 if (!appendMergedInput(minVal))
1108 Cut *mergedCut =
cutAllocator.create(nodeIndex, mergedInputs, currentSig,
1109 ArrayRef<const Cut *>(cutPtrs));
1110 resultCutSet->addCut(mergedCut);
1113 if (mergedCut->
inputs.size() >= 4) {
1114 llvm::dbgs() <<
"Generated cut for node " << nodeIndex;
1116 llvm::dbgs() <<
" (" << logicOp->getName() <<
")";
1117 llvm::dbgs() <<
" inputs=";
1118 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1119 llvm::dbgs() <<
"\n";
1126 const CutSet *currentCutSet = operandCutSets[operandIdx].first;
1127 for (
const Cut *cut : currentCutSet->
getCuts()) {
1129 uint64_t newSig = currentSig | cutSig;
1130 if (
static_cast<unsigned>(llvm::popcount(newSig)) > maxInputSize)
1133 cutPtrs.push_back(cut);
1136 self(self, operandIdx + 1, cutPtrs, newSig);
1143 SmallVector<const Cut *, 3> cutPtrs;
1144 cutPtrs.reserve(numFanins);
1145 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs, 0ULL);
1155 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut) {
1156 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts for module: " << topOp->getName()
1166 auto &block = topOp->getRegion(0).getBlocks().front();
1172 if (!gate.isLogicGate())
1180 LLVM_DEBUG(llvm::dbgs() <<
"Cut enumeration completed successfully\n");
1186 auto it =
cutSets.find(index);
1191 cutSet->
addCut(trivialCut);
1192 auto [newIt, inserted] =
cutSets.insert({index, cutSet});
1193 assert(inserted &&
"Cut set already exists for this index");
1205 if (
auto *op = value.getDefiningOp()) {
1208 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint"))
1209 return name.getValue();
1213 if (op->getNumResults() == 1) {
1214 auto opName = op->getName();
1215 auto count = opCounter[opName]++;
1218 SmallString<16> nameStr;
1219 nameStr += opName.getStringRef();
1221 nameStr += std::to_string(count);
1224 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1225 op->setAttr(
"sv.namehint", nameAttr);
1234 auto blockArg = cast<BlockArgument>(value);
1236 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1241 return hwOp.getInputName(blockArg.getArgNumber());
1245 DenseMap<OperationName, unsigned> opCounter;
1247 auto it =
cutSets.find(index);
1250 auto &cutSet = *it->second;
1253 << cutSet.getCuts().size() <<
" cuts:";
1254 for (
const Cut *cut : cutSet.getCuts()) {
1255 llvm::outs() <<
" {";
1256 llvm::interleaveComma(cut->
inputs, llvm::outs(), [&](uint32_t inputIdx) {
1257 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1258 llvm::outs() << getTestVariableName(inputVal, opCounter);
1262 <<
"@t" << cut->
getTruthTable()->table.getZExtValue() <<
"d";
1264 llvm::outs() << *std::max_element(
pattern->getArrivalTimes().begin(),
1265 pattern->getArrivalTimes().end());
1267 llvm::outs() <<
"0";
1270 llvm::outs() <<
"\n";
1272 llvm::outs() <<
"Cut enumeration completed successfully\n";
1281 llvm::dbgs() <<
"Starting Cut Rewriter\n";
1282 llvm::dbgs() <<
"Mode: "
1294 if (
pattern->getNumOutputs() > 1) {
1295 return mlir::emitError(
pattern->getLoc(),
1296 "Cut rewriter does not support patterns with "
1297 "multiple outputs yet");
1324 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts...\n");
1327 topOp, [&](
const Cut &cut) -> std::optional<MatchedPattern> {
1333ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1335 if (
patterns.npnToPatternMap.empty())
1339 auto it =
patterns.npnToPatternMap.find(
1340 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1341 if (it ==
patterns.npnToPatternMap.end())
1343 return it->getSecond();
1352 SmallVector<DelayType, 4> inputArrivalTimes;
1353 SmallVector<DelayType, 1> bestArrivalTimes;
1354 double bestArea = 0.0;
1362 auto computeArrivalTimeAndPickBest =
1364 llvm::function_ref<unsigned(
unsigned)> mapIndex) {
1365 SmallVector<DelayType, 1> outputArrivalTimes;
1367 for (
unsigned outputIndex = 0, outputSize = cut.
getOutputSize(network);
1368 outputIndex < outputSize; ++outputIndex) {
1371 auto delays = matchResult.getDelays();
1372 for (
unsigned inputIndex = 0, inputSize = cut.
getInputSize();
1373 inputIndex < inputSize; ++inputIndex) {
1375 unsigned cutOriginalInput = mapIndex(inputIndex);
1377 std::max(outputArrivalTime,
1378 delays[outputIndex * inputSize + inputIndex] +
1379 inputArrivalTimes[cutOriginalInput]);
1382 outputArrivalTimes.push_back(outputArrivalTime);
1388 outputArrivalTimes, bestArea,
1389 bestArrivalTimes)) {
1391 llvm::dbgs() <<
"== Matched Pattern ==============\n";
1392 llvm::dbgs() <<
"Matching cut: \n";
1393 cut.
dump(llvm::dbgs(), network);
1394 llvm::dbgs() <<
"Found better pattern: "
1396 llvm::dbgs() <<
" with area: " << matchResult.area;
1397 llvm::dbgs() <<
" and input arrival times: ";
1398 for (
unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1399 llvm::dbgs() <<
" " << inputArrivalTimes[i];
1401 llvm::dbgs() <<
" and arrival times: ";
1403 for (
auto arrivalTime : outputArrivalTimes) {
1404 llvm::dbgs() <<
" " << arrivalTime;
1406 llvm::dbgs() <<
"\n";
1407 llvm::dbgs() <<
"== Matched Pattern End ==============\n";
1410 bestArrivalTimes = std::move(outputArrivalTimes);
1411 bestArea = matchResult.area;
1418 "Pattern input size must match cut input size");
1425 SmallVector<unsigned> inputMapping;
1427 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1428 [&](
unsigned i) {
return inputMapping[i]; });
1433 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1434 [&](
unsigned i) {
return i; });
1440 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1444 LLVM_DEBUG(llvm::dbgs() <<
"Performing cut-based rewriting...\n");
1451 PatternRewriter rewriter(top->getContext());
1454 for (
auto index : llvm::reverse(processingOrder)) {
1455 auto it = cutSets.find(index);
1456 if (it == cutSets.end())
1459 mlir::Value value = network.getValue(index);
1460 auto &cutSet = *it->second;
1462 if (value.use_empty()) {
1463 if (
auto *op = value.getDefiningOp())
1470 LLVM_DEBUG(llvm::dbgs() <<
"Skipping inputs: " << value <<
"\n");
1474 LLVM_DEBUG(llvm::dbgs() <<
"Cut set for value: " << value <<
"\n");
1475 auto *bestCut = cutSet.getBestMatchedCut();
1479 return emitError(value.getLoc(),
"No matching cut found for value: ")
1484 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1485 rewriter.setInsertionPoint(rootOp);
1486 const auto &matchedPattern = bestCut->getMatchedPattern();
1487 auto result = matchedPattern->getPattern()->rewrite(rewriter,
cutEnumerator,
1492 rewriter.replaceOp(rootOp, *result);
1496 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1497 (*result)->setAttr(
"test.arrival_times", array);
assert(baseType &&"element must be base type")
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Precomputed NPN canonicalization table for 4-input single-output functions.
bool lookup(const BinaryTruthTable &tt, NPNClass &result) const
Returns false if the given truth table shape is unsupported.
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void noteCutRewritten()
Record that one cut was successfully rewritten.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::APInt expandTruthTableToInputSpace(const llvm::APInt &tt, ArrayRef< unsigned > inputMapping, unsigned numExpandedInputs)
Expand a truth table to a larger input space using the given input mapping.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
OptimizationStrategy
Optimization strategy.
@ OptimizationStrategyArea
Optimize for minimal area.
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
T evaluateDotLogic(const T &x, const T &y, const T &z)
Evaluate the Boolean function x ^ (z | (x & y)).
T evaluateGambleLogic(const T &a, const T &b, const T &c)
T evaluateMajorityLogic(const T &a, const T &b, const T &c)
T evaluateMuxLogic(const T &a, const T &b, const T &c)
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
T evaluateOneHotLogic(const T &a, const T &b, const T &c)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Dot3
Ordered DOT gate (3-input, synth.dot)
@ OneHot3
OneHot gate (3-input, synth.onehot)
@ Identity
Identity gate (used for 1-input inverter)
@ Gamble3
Ordered Gamble gate (3-input, synth.gamble)
@ Mux3
Ordered MUX gate (3-input, synth.mux_inv)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.