32#include "mlir/Analysis/TopologicalSortUtils.h"
33#include "mlir/IR/Builders.h"
34#include "mlir/IR/Operation.h"
35#include "mlir/IR/RegionKindInterface.h"
36#include "mlir/IR/Value.h"
37#include "mlir/IR/ValueRange.h"
38#include "mlir/IR/Visitors.h"
39#include "mlir/Support/LLVM.h"
40#include "llvm/ADT/APInt.h"
41#include "llvm/ADT/Bitset.h"
42#include "llvm/ADT/DenseMap.h"
43#include "llvm/ADT/MapVector.h"
44#include "llvm/ADT/STLExtras.h"
45#include "llvm/ADT/ScopeExit.h"
46#include "llvm/ADT/SetVector.h"
47#include "llvm/ADT/SmallVector.h"
48#include "llvm/ADT/TypeSwitch.h"
49#include "llvm/ADT/iterator.h"
50#include "llvm/Support/Debug.h"
51#include "llvm/Support/ErrorHandling.h"
52#include "llvm/Support/LogicalResult.h"
59#define DEBUG_TYPE "synth-cut-rewriter"
80 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
95 "Index out of bounds in LogicNetwork::getValue");
100 SmallVectorImpl<Value> &values)
const {
102 values.reserve(indices.size());
103 for (uint32_t idx : indices)
114 Value result, ArrayRef<Signal> operands) {
122 const size_t estimatedSize =
123 block->getArguments().size() + block->getOperations().size();
125 gates.reserve(estimatedSize);
127 auto handleSingleInputGate = [&](Operation *op, Value result,
128 const Signal &inputSignal) {
129 if (!inputSignal.isInverted()) {
139 for (Value arg : block->getArguments()) {
144 auto handleOtherResults = [&](Operation *op) {
145 for (Value result : op->getResults()) {
146 if (result.getType().isInteger(1) && !
hasIndex(result))
151 auto getInvertibleSignal = [&](
auto op,
unsigned index) {
155 auto handleInvertibleBinaryGate = [&](
auto logicOp,
159 const auto inputs = logicOp.getInputs();
160 if (inputs.size() == 1) {
161 const Signal inputSignal = getInvertibleSignal(logicOp, 0);
162 handleSingleInputGate(logicOp, logicOp.getResult(), inputSignal);
165 if (inputs.size() == 2) {
166 const Signal lhsSignal = getInvertibleSignal(logicOp, 0);
167 const Signal rhsSignal = getInvertibleSignal(logicOp, 1);
168 addGate(logicOp, kind, {lhsSignal, rhsSignal});
172 handleOtherResults(logicOp);
177 for (Operation &op : block->getOperations()) {
178 LogicalResult result =
179 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
180 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
183 .Case<synth::XorInverterOp>([&](synth::XorInverterOp xorOp) {
187 if (xorOp->getNumOperands() != 2) {
188 handleOtherResults(xorOp);
199 Value result = constOp.getResult();
200 if (!result.getType().isInteger(1)) {
201 handleOtherResults(constOp);
209 .Case<synth::ChoiceOp>([&](synth::ChoiceOp choiceOp) {
210 if (!choiceOp.getType().isInteger(1)) {
211 handleOtherResults(choiceOp);
218 .Default([&](Operation *defaultOp) {
219 handleOtherResults(defaultOp);
248 const auto &gate = network.
getGate(index);
255 ArrayRef<DelayType> newDelay,
double oldArea,
256 ArrayRef<DelayType> oldDelay) {
258 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
260 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
261 llvm_unreachable(
"Unknown mapping strategy");
265 const auto isOperationReady = [](Value value, Operation *op) ->
bool {
268 return !(isa<aig::AndInverterOp, synth::XorInverterOp, synth::ChoiceOp,
274 return emitError(topOp->getLoc(),
275 "failed to sort operations topologically");
283 for (Value arg : block->getArguments())
284 inputArgs.insert(arg);
286 if (inputArgs.empty())
289 const int64_t numInputs = inputArgs.size();
290 const int64_t numOutputs = values.size();
295 return mlir::emitError(values.front().getLoc(),
296 "Truth table is too large");
297 return mlir::emitError(values.front().getLoc(),
298 "Multiple outputs are not supported yet");
302 DenseMap<Value, APInt> eval;
303 for (uint32_t i = 0; i < numInputs; ++i)
307 for (Operation &op : *block) {
308 if (op.getNumResults() == 0)
311 if (
auto choiceOp = dyn_cast<synth::ChoiceOp>(&op)) {
312 auto it = eval.find(choiceOp.getInputs().front());
313 if (it == eval.end())
314 return choiceOp.emitError(
"Input value not found in evaluation map");
315 eval[choiceOp.getResult()] = it->second;
316 }
else if (
auto logicOp = dyn_cast<BooleanLogicOpInterface>(&op)) {
317 for (
auto value : logicOp.getInputs())
318 if (!eval.contains(value))
319 return logicOp->emitError(
"Input value not found in evaluation map");
321 eval[logicOp.getResult()] =
322 logicOp.evaluateBooleanLogic([&](
unsigned i) ->
const APInt & {
323 return eval.find(logicOp.getInput(i))->second;
325 }
else if (
auto xorOp = dyn_cast<comb::XorOp>(&op)) {
327 auto it = eval.find(xorOp.getOperand(0));
328 if (it == eval.end())
329 return xorOp.emitError(
"Input value not found in evaluation map");
330 llvm::APInt result = it->second;
331 for (
unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
332 it = eval.find(xorOp.getOperand(i));
333 if (it == eval.end())
334 return xorOp.emitError(
"Input value not found in evaluation map");
335 result ^= it->second;
337 eval[xorOp.getResult()] = result;
338 }
else if (!isa<hw::OutputOp>(&op)) {
339 return op.emitError(
"Unsupported operation for truth table simulation");
367 npnClass.emplace(std::move(canonicalForm));
373 SmallVectorImpl<unsigned> &permutedIndices)
const {
375 npnClass.getInputPermutation(patternNPN, permutedIndices);
380 SmallVectorImpl<DelayType> &results)
const {
385 for (
auto inputIndex :
inputs) {
388 results.push_back(0);
391 auto *cutSet = enumerator.
getCutSet(inputIndex);
392 assert(cutSet &&
"Input must have a valid cut set");
396 auto *bestCut = cutSet->getBestMatchedCut();
403 mlir::Value inputValue = network.getValue(inputIndex);
407 cast<mlir::OpResult>(inputValue).getResultNumber()));
414 os <<
"// === Cut Dump ===\n";
419 os <<
" and root: " << *rootOp;
425 os <<
"Primary input cut: " << inputVal <<
"\n";
429 os <<
"Inputs (indices): \n";
430 for (
auto [idx, inputIndex] : llvm::enumerate(
inputs)) {
431 mlir::Value inputVal = network.
getValue(inputIndex);
432 os <<
" Input " << idx <<
" (index " << inputIndex <<
"): " << inputVal
437 os <<
"\nRoot operation: \n";
446 os <<
"// === Cut End ===\n";
455 return rootOp ? rootOp->getNumResults() : 1;
460 const llvm::APInt &a) {
465 llvm_unreachable(
"Unsupported unary operation for truth table computation");
470 const llvm::APInt &a,
471 const llvm::APInt &b) {
479 "Unsupported binary operation for truth table computation");
484 const llvm::APInt &a,
485 const llvm::APInt &b,
486 const llvm::APInt &c) {
489 return (a & b) | (a & c) | (b & c);
492 "Unsupported ternary operation for truth table computation");
500struct MergedTruthTableBuilder {
501 MergedTruthTableBuilder(ArrayRef<uint32_t> mergedInputs,
502 ArrayRef<const Cut *> operandCuts)
503 : mergedInputs(mergedInputs), numMergedInputs(mergedInputs.size()),
504 operandCuts(operandCuts) {
505 assert(llvm::is_sorted(mergedInputs) &&
"merged inputs must be sorted");
506 assert(llvm::adjacent_find(mergedInputs) == mergedInputs.end() &&
507 "merged inputs must be unique");
510 ArrayRef<uint32_t> mergedInputs;
511 unsigned numMergedInputs;
512 ArrayRef<const Cut *> operandCuts;
514 std::optional<unsigned> findMergedInputPosition(uint32_t operandIdx)
const {
515 auto *it = llvm::find(mergedInputs, operandIdx);
516 if (it == mergedInputs.end())
518 return static_cast<unsigned>(std::distance(mergedInputs.begin(), it));
521 const Cut *findOperandCut(uint32_t operandIdx)
const {
522 for (
const Cut *cut : operandCuts) {
526 cut->
isTrivialCut() ? cut->inputs[0] : cut->getRootIndex();
527 if (cutOutput == operandIdx)
533 void getInputMapping(
const Cut *cut,
534 SmallVectorImpl<unsigned> &mapping)
const {
536 mapping.reserve(cut->
inputs.size());
537 for (uint32_t idx : cut->inputs) {
538 auto *it = llvm::find(mergedInputs, idx);
539 assert(it != mergedInputs.end() &&
540 "cut input must exist in merged inputs");
541 mapping.push_back(
static_cast<unsigned>(it - mergedInputs.begin()));
545 llvm::APInt expandCutTruthTable(
const Cut *cut)
const {
547 SmallVector<unsigned, 8> inputMapping;
548 getInputMapping(cut, inputMapping);
550 cutTT.table, inputMapping, numMergedInputs);
553 llvm::APInt expandOperand(uint32_t operandIdx,
bool isInverted)
const {
554 llvm::APInt result(1, 0);
556 result = llvm::APInt::getZero(1U << numMergedInputs);
558 result = llvm::APInt::getAllOnes(1U << numMergedInputs);
559 }
else if (
auto pos = findMergedInputPosition(operandIdx)) {
562 }
else if (
const Cut *cut = findOperandCut(operandIdx)) {
565 result = expandCutTruthTable(cut);
567 llvm_unreachable(
"Operand not found in cuts or merged inputs");
571 result.flipAllBits();
576 auto getEdgeTT = [&](
unsigned edgeIdx) {
577 const auto &edge = rootGate.
edges[edgeIdx];
578 return expandOperand(edge.getIndex(), edge.isInverted());
590 getEdgeTT(0), getEdgeTT(1),
597 llvm_unreachable(
"Unsupported operation for truth table computation");
611 "non-trivial cuts must carry operand cuts for truth table expansion");
630 return std::includes(otherInputs.begin(), otherInputs.end(),
inputs.begin(),
636 cut.
inputs.push_back(index);
649 assert(
pattern &&
"Pattern must be set to get arrival time");
654 assert(
pattern &&
"Pattern must be set to get arrival time");
686 auto dumpInputs = [](llvm::raw_ostream &os,
687 const llvm::SmallVectorImpl<uint32_t> &inputs) {
689 llvm::interleaveComma(inputs, os);
694 std::stable_sort(cuts.begin(), cuts.end(), [](
const Cut *a,
const Cut *b) {
695 if (a->getInputSize() != b->getInputSize())
696 return a->getInputSize() < b->getInputSize();
697 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
698 b->inputs.begin(), b->inputs.end());
702 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
703 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
706 unsigned uniqueCount = 0;
707 for (
Cut *cut : cuts) {
712 if (uniqueCount > 0) {
713 Cut *lastKept = cuts[uniqueCount - 1];
719 bool isDominated =
false;
720 for (
unsigned existingSize = 1; existingSize < cutSize && !isDominated;
722 for (
const Cut *existingCut : keptBySize[existingSize]) {
723 if (!existingCut->dominates(*cut))
727 llvm::dbgs() <<
"Dropping non-minimal cut ";
728 dumpInputs(llvm::dbgs(), cut->
inputs);
729 llvm::dbgs() <<
" due to subset ";
730 dumpInputs(llvm::dbgs(), existingCut->inputs);
731 llvm::dbgs() <<
"\n";
741 cuts[uniqueCount++] = cut;
742 keptBySize[cutSize].push_back(cut);
745 LLVM_DEBUG(llvm::dbgs() <<
"Original cuts: " << cuts.size()
746 <<
" Unique cuts: " << uniqueCount <<
"\n");
749 cuts.resize(uniqueCount);
754 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut,
767 "Cut input size exceeds maximum allowed size");
769 if (
auto matched = matchCut(*cut))
783 auto *trivialCutsEnd =
784 std::stable_partition(
cuts.begin(),
cuts.end(),
785 [](
const Cut *cut) { return cut->isTrivialCut(); });
787 auto isBetterCut = [&options](
const Cut *a,
const Cut *b) {
788 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
789 "Trivial cuts should have been excluded");
790 const auto &aMatched = a->getMatchedPattern();
791 const auto &bMatched = b->getMatchedPattern();
793 if (aMatched && bMatched)
795 options.
strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
796 bMatched->getArea(), bMatched->getArrivalTimes());
798 if (
static_cast<bool>(aMatched) !=
static_cast<bool>(bMatched))
799 return static_cast<bool>(aMatched);
801 return a->getInputSize() < b->getInputSize();
803 std::stable_sort(trivialCutsEnd,
cuts.end(), isBetterCut);
820 llvm::dbgs() <<
"Finalized cut set with " <<
cuts.size() <<
" cuts and "
825 :
"no matched pattern")
837 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const {
846 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns)
850 SmallVector<NPNClass, 2> npnClasses;
851 auto result =
pattern->useTruthTableMatcher(npnClasses);
853 for (
auto npnClass : npnClasses) {
856 npnClass.truthTable.numInputs}]
857 .push_back(std::make_pair(std::move(npnClass),
pattern.get()));
872 : cutAllocator(stats.numCutsCreated),
873 cutSetAllocator(stats.numCutSetsCreated), options(options) {}
877 auto [cutSetPtr, inserted] =
cutSets.try_emplace(index, cutSet);
878 assert(inserted &&
"Cut set already exists for this index");
879 return cutSetPtr->second;
893 assert(logicOp && logicOp->getNumResults() == 1 &&
894 "Logic operation must have a single result");
897 auto choiceOp = cast<synth::ChoiceOp>(logicOp);
901 resultCutSet->addCut(primaryInputCut);
903 for (Value operand : choiceOp.getInputs()) {
906 return logicOp->emitError(
"Failed to get cut set for choice operand");
910 for (
const Cut *operandCut : operandCutSet->getCuts()) {
911 if (operandCut->isTrivialCut())
915 nodeIndex, operandCut->inputs, operandCut->getSignature(),
916 ArrayRef<const Cut *>{operandCut}, *operandCut->getTruthTable()));
925 unsigned numFanins = gate.getNumFanins();
930 return logicOp->emitError(
"Cut enumeration supports at most 3 operands, "
933 if (!logicOp->getOpResult(0).getType().isInteger(1))
934 return logicOp->emitError()
935 <<
"Supported logic operations must have a single bit "
936 "result type but found: "
937 << logicOp->getResult(0).getType();
941 SmallVector<std::pair<const CutSet *, unsigned>, 2> operandCutSets;
942 operandCutSets.reserve(numFanins);
945 for (
unsigned i = 0; i < numFanins; ++i) {
946 uint32_t faninIndex = gate.edges[i].getIndex();
947 auto *operandCutSet =
getCutSet(faninIndex);
949 return logicOp->emitError(
"Failed to get cut set for fanin index ")
954 unsigned maxInputCutSize = 0;
955 for (
auto *cut : operandCutSet->getCuts())
956 maxInputCutSize = std::max(maxInputCutSize, cut->
getInputSize());
957 operandCutSets.push_back(std::make_pair(operandCutSet, maxInputCutSize));
965 resultCutSet->addCut(primaryInputCut);
971 llvm::stable_sort(operandCutSets,
972 [](
const std::pair<const CutSet *, unsigned> &a,
973 const std::pair<const CutSet *, unsigned> &b) {
974 return a.second > b.second;
982 auto enumerateCutCombinations = [&](
auto &&self,
unsigned operandIdx,
983 SmallVector<const Cut *, 3> &cutPtrs,
984 uint64_t currentSig) ->
void {
986 if (operandIdx == numFanins) {
990 SmallVector<uint32_t, 6> mergedInputs;
991 auto appendMergedInput = [&](uint32_t value) {
995 if (!mergedInputs.empty() && mergedInputs.back() == value)
997 mergedInputs.push_back(value);
998 return mergedInputs.size() <= maxInputSize;
1001 if (numFanins == 1) {
1003 mergedInputs.reserve(
1004 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
1005 for (uint32_t value : cutPtrs[0]->inputs)
1006 if (!appendMergedInput(value))
1008 }
else if (numFanins == 2) {
1010 const auto &inputs0 = cutPtrs[0]->inputs;
1011 const auto &inputs1 = cutPtrs[1]->inputs;
1012 mergedInputs.reserve(
1013 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
1015 unsigned i = 0, j = 0;
1016 while (i < inputs0.size() || j < inputs1.size()) {
1018 if (j == inputs1.size() ||
1019 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
1020 next = inputs0[i++];
1021 if (j < inputs1.size() && inputs1[j] == next)
1024 next = inputs1[j++];
1026 if (!appendMergedInput(next))
1031 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
1032 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
1033 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
1034 mergedInputs.reserve(std::min<size_t>(
1035 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
1037 unsigned i = 0, j = 0, k = 0;
1038 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
1040 uint32_t minVal = UINT32_MAX;
1041 if (i < inputs0.size())
1042 minVal = std::min(minVal, inputs0[i]);
1043 if (j < inputs1.size())
1044 minVal = std::min(minVal, inputs1[j]);
1045 if (k < inputs2.size())
1046 minVal = std::min(minVal, inputs2[k]);
1049 if (i < inputs0.size() && inputs0[i] == minVal)
1051 if (j < inputs1.size() && inputs1[j] == minVal)
1053 if (k < inputs2.size() && inputs2[k] == minVal)
1056 if (!appendMergedInput(minVal))
1062 Cut *mergedCut =
cutAllocator.create(nodeIndex, mergedInputs, currentSig,
1063 ArrayRef<const Cut *>(cutPtrs));
1064 resultCutSet->addCut(mergedCut);
1067 if (mergedCut->
inputs.size() >= 4) {
1068 llvm::dbgs() <<
"Generated cut for node " << nodeIndex;
1070 llvm::dbgs() <<
" (" << logicOp->getName() <<
")";
1071 llvm::dbgs() <<
" inputs=";
1072 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1073 llvm::dbgs() <<
"\n";
1080 const CutSet *currentCutSet = operandCutSets[operandIdx].first;
1081 for (
const Cut *cut : currentCutSet->
getCuts()) {
1083 uint64_t newSig = currentSig | cutSig;
1084 if (
static_cast<unsigned>(llvm::popcount(newSig)) > maxInputSize)
1087 cutPtrs.push_back(cut);
1090 self(self, operandIdx + 1, cutPtrs, newSig);
1097 SmallVector<const Cut *, 3> cutPtrs;
1098 cutPtrs.reserve(numFanins);
1099 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs, 0ULL);
1109 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut) {
1110 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts for module: " << topOp->getName()
1120 auto &block = topOp->getRegion(0).getBlocks().front();
1126 if (!gate.isLogicGate())
1134 LLVM_DEBUG(llvm::dbgs() <<
"Cut enumeration completed successfully\n");
1140 auto it =
cutSets.find(index);
1145 cutSet->
addCut(trivialCut);
1146 auto [newIt, inserted] =
cutSets.insert({index, cutSet});
1147 assert(inserted &&
"Cut set already exists for this index");
1159 if (
auto *op = value.getDefiningOp()) {
1162 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint"))
1163 return name.getValue();
1167 if (op->getNumResults() == 1) {
1168 auto opName = op->getName();
1169 auto count = opCounter[opName]++;
1172 SmallString<16> nameStr;
1173 nameStr += opName.getStringRef();
1175 nameStr += std::to_string(count);
1178 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1179 op->setAttr(
"sv.namehint", nameAttr);
1188 auto blockArg = cast<BlockArgument>(value);
1190 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1195 return hwOp.getInputName(blockArg.getArgNumber());
1199 DenseMap<OperationName, unsigned> opCounter;
1201 auto it =
cutSets.find(index);
1204 auto &cutSet = *it->second;
1207 << cutSet.getCuts().size() <<
" cuts:";
1208 for (
const Cut *cut : cutSet.getCuts()) {
1209 llvm::outs() <<
" {";
1210 llvm::interleaveComma(cut->
inputs, llvm::outs(), [&](uint32_t inputIdx) {
1211 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1212 llvm::outs() << getTestVariableName(inputVal, opCounter);
1216 <<
"@t" << cut->
getTruthTable()->table.getZExtValue() <<
"d";
1218 llvm::outs() << *std::max_element(
pattern->getArrivalTimes().begin(),
1219 pattern->getArrivalTimes().end());
1221 llvm::outs() <<
"0";
1224 llvm::outs() <<
"\n";
1226 llvm::outs() <<
"Cut enumeration completed successfully\n";
1235 llvm::dbgs() <<
"Starting Cut Rewriter\n";
1236 llvm::dbgs() <<
"Mode: "
1248 if (
pattern->getNumOutputs() > 1) {
1249 return mlir::emitError(
pattern->getLoc(),
1250 "Cut rewriter does not support patterns with "
1251 "multiple outputs yet");
1278 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts...\n");
1281 topOp, [&](
const Cut &cut) -> std::optional<MatchedPattern> {
1287ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1289 if (
patterns.npnToPatternMap.empty())
1293 auto it =
patterns.npnToPatternMap.find(
1294 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1295 if (it ==
patterns.npnToPatternMap.end())
1297 return it->getSecond();
1306 SmallVector<DelayType, 4> inputArrivalTimes;
1307 SmallVector<DelayType, 1> bestArrivalTimes;
1308 double bestArea = 0.0;
1316 auto computeArrivalTimeAndPickBest =
1318 llvm::function_ref<unsigned(
unsigned)> mapIndex) {
1319 SmallVector<DelayType, 1> outputArrivalTimes;
1321 for (
unsigned outputIndex = 0, outputSize = cut.
getOutputSize(network);
1322 outputIndex < outputSize; ++outputIndex) {
1325 auto delays = matchResult.getDelays();
1326 for (
unsigned inputIndex = 0, inputSize = cut.
getInputSize();
1327 inputIndex < inputSize; ++inputIndex) {
1329 unsigned cutOriginalInput = mapIndex(inputIndex);
1331 std::max(outputArrivalTime,
1332 delays[outputIndex * inputSize + inputIndex] +
1333 inputArrivalTimes[cutOriginalInput]);
1336 outputArrivalTimes.push_back(outputArrivalTime);
1342 outputArrivalTimes, bestArea,
1343 bestArrivalTimes)) {
1345 llvm::dbgs() <<
"== Matched Pattern ==============\n";
1346 llvm::dbgs() <<
"Matching cut: \n";
1347 cut.
dump(llvm::dbgs(), network);
1348 llvm::dbgs() <<
"Found better pattern: "
1350 llvm::dbgs() <<
" with area: " << matchResult.area;
1351 llvm::dbgs() <<
" and input arrival times: ";
1352 for (
unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1353 llvm::dbgs() <<
" " << inputArrivalTimes[i];
1355 llvm::dbgs() <<
" and arrival times: ";
1357 for (
auto arrivalTime : outputArrivalTimes) {
1358 llvm::dbgs() <<
" " << arrivalTime;
1360 llvm::dbgs() <<
"\n";
1361 llvm::dbgs() <<
"== Matched Pattern End ==============\n";
1364 bestArrivalTimes = std::move(outputArrivalTimes);
1365 bestArea = matchResult.area;
1372 "Pattern input size must match cut input size");
1379 SmallVector<unsigned> inputMapping;
1381 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1382 [&](
unsigned i) {
return inputMapping[i]; });
1387 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1388 [&](
unsigned i) {
return i; });
1394 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1398 LLVM_DEBUG(llvm::dbgs() <<
"Performing cut-based rewriting...\n");
1405 PatternRewriter rewriter(top->getContext());
1408 for (
auto index : llvm::reverse(processingOrder)) {
1409 auto it = cutSets.find(index);
1410 if (it == cutSets.end())
1413 mlir::Value value = network.getValue(index);
1414 auto &cutSet = *it->second;
1416 if (value.use_empty()) {
1417 if (
auto *op = value.getDefiningOp())
1424 LLVM_DEBUG(llvm::dbgs() <<
"Skipping inputs: " << value <<
"\n");
1428 LLVM_DEBUG(llvm::dbgs() <<
"Cut set for value: " << value <<
"\n");
1429 auto *bestCut = cutSet.getBestMatchedCut();
1433 return emitError(value.getLoc(),
"No matching cut found for value: ")
1438 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1439 rewriter.setInsertionPoint(rootOp);
1440 const auto &matchedPattern = bestCut->getMatchedPattern();
1441 auto result = matchedPattern->getPattern()->rewrite(rewriter,
cutEnumerator,
1446 rewriter.replaceOp(rootOp, *result);
1450 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1451 (*result)->setAttr(
"test.arrival_times", array);
assert(baseType &&"element must be base type")
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Precomputed NPN canonicalization table for 4-input single-output functions.
bool lookup(const BinaryTruthTable &tt, NPNClass &result) const
Returns false if the given truth table shape is unsupported.
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
TrackedSpecificBumpPtrAllocator< CutSet > cutSetAllocator
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
void noteCutRewritten()
Record that one cut was successfully rewritten.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
TrackedSpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using the priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
static Cut getTrivialCut(uint32_t index)
Create a trivial cut for a value.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
uint64_t signature
Signature bitset for fast cut size estimation.
uint64_t getSignature() const
Get the signature of this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
void setTruthTable(BinaryTruthTable tt)
Set the truth table directly (used for incremental computation).
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
llvm::SmallVector< const Cut *, 3 > operandCuts
Operand cuts used to create this cut (for lazy TT computation).
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
void setSignature(uint64_t sig)
Set the signature of this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another (i.e., this cut's inputs are a subset of the other's inputs).
void getPermutatedInputIndices(const NPNTable *npnTable, const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permuted input indices for this cut based on the given pattern NPN.
unsigned getInputSize() const
Get the number of inputs to this cut.
void setMatchedPattern(MatchedPattern pattern)
Set the matched pattern for this cut.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
llvm::APInt expandTruthTableToInputSpace(const llvm::APInt &tt, ArrayRef< unsigned > inputMapping, unsigned numExpandedInputs)
Expand a truth table to a larger input space using the given input mapping.
OptimizationStrategy
Optimization strategy.
@ OptimizationStrategyArea
Optimize for minimal area.
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
const NPNTable * npnTable
Optional lookup table used to accelerate 4-input NPN canonicalization.
bool attachDebugTiming
Attach arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Signal edges[3]
Fanin edges (up to 3 inputs).
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
Kind getKind() const
Get the kind of this gate.
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Reserved 3-input gate kind.
@ PrimaryInput
Primary input to the network.
@ Choice
Choice node (synth.choice)
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.