29#include "mlir/Analysis/TopologicalSortUtils.h"
30#include "mlir/IR/Builders.h"
31#include "mlir/IR/Operation.h"
32#include "mlir/IR/RegionKindInterface.h"
33#include "mlir/IR/Value.h"
34#include "mlir/IR/ValueRange.h"
35#include "mlir/IR/Visitors.h"
36#include "mlir/Support/LLVM.h"
37#include "llvm/ADT/APInt.h"
38#include "llvm/ADT/Bitset.h"
39#include "llvm/ADT/DenseMap.h"
40#include "llvm/ADT/MapVector.h"
41#include "llvm/ADT/STLExtras.h"
42#include "llvm/ADT/ScopeExit.h"
43#include "llvm/ADT/SetVector.h"
44#include "llvm/ADT/SmallVector.h"
45#include "llvm/ADT/TypeSwitch.h"
46#include "llvm/ADT/iterator.h"
47#include "llvm/Support/Debug.h"
48#include "llvm/Support/ErrorHandling.h"
49#include "llvm/Support/LogicalResult.h"
56#define DEBUG_TYPE "synth-cut-rewriter"
77 "Value not found in LogicNetwork - use getOrCreateIndex or check with "
92 "Index out of bounds in LogicNetwork::getValue");
97 SmallVectorImpl<Value> &values)
const {
99 values.reserve(indices.size());
100 for (uint32_t idx : indices)
111 Value result, ArrayRef<Signal> operands) {
119 const size_t estimatedSize =
120 block->getArguments().size() + block->getOperations().size();
122 gates.reserve(estimatedSize);
124 auto handleSingleInputGate = [&](Operation *op, Value result,
125 const Signal &inputSignal) {
126 if (!inputSignal.isInverted()) {
136 for (Value arg : block->getArguments()) {
141 auto handleOtherResults = [&](Operation *op) {
142 for (Value result : op->getResults()) {
143 if (result.getType().isInteger(1) && !
hasIndex(result))
149 for (Operation &op : block->getOperations()) {
150 LogicalResult result =
151 llvm::TypeSwitch<Operation *, LogicalResult>(&op)
152 .Case<aig::AndInverterOp>([&](aig::AndInverterOp andOp) {
153 const auto inputs = andOp.getInputs();
154 if (inputs.size() == 1) {
156 const Signal inputSignal =
158 handleSingleInputGate(andOp, andOp.getResult(), inputSignal);
159 }
else if (inputs.size() == 2) {
168 handleOtherResults(andOp);
173 if (xorOp->getNumOperands() != 2) {
174 handleOtherResults(xorOp);
184 .Case<synth::mig::MajorityInverterOp>(
185 [&](synth::mig::MajorityInverterOp majOp) {
186 if (majOp->getNumOperands() == 1) {
189 majOp.getOperand(0), majOp.isInverted(0));
190 handleSingleInputGate(majOp, majOp.getResult(),
194 if (majOp->getNumOperands() != 3) {
196 handleOtherResults(majOp);
200 majOp.isInverted(0));
202 majOp.isInverted(1));
204 majOp.isInverted(2));
206 {aSignal, bSignal, cSignal});
210 Value result = constOp.getResult();
211 if (!result.getType().isInteger(1)) {
212 handleOtherResults(constOp);
220 .Default([&](Operation *defaultOp) {
221 handleOtherResults(defaultOp);
250 const auto &gate = network.
getGate(index);
257 ArrayRef<DelayType> newDelay,
double oldArea,
258 ArrayRef<DelayType> oldDelay) {
260 return newArea < oldArea || (newArea == oldArea && newDelay < oldDelay);
262 return newDelay < oldDelay || (newDelay == oldDelay && newArea < oldArea);
263 llvm_unreachable(
"Unknown mapping strategy");
267 const auto isOperationReady = [](Value value, Operation *op) ->
bool {
270 return !(isa<aig::AndInverterOp, mig::MajorityInverterOp>(op) ||
276 return emitError(topOp->getLoc(),
277 "failed to sort operations topologically");
284 llvm::SmallSetVector<Value, 4> inputArgs;
285 for (Value arg : block->getArguments())
286 inputArgs.insert(arg);
288 if (inputArgs.empty())
291 const int64_t numInputs = inputArgs.size();
292 const int64_t numOutputs = values.size();
297 return mlir::emitError(values.front().getLoc(),
298 "Truth table is too large");
299 return mlir::emitError(values.front().getLoc(),
300 "Multiple outputs are not supported yet");
304 DenseMap<Value, APInt> eval;
305 for (uint32_t i = 0; i < numInputs; ++i)
309 for (Operation &op : *block) {
310 if (op.getNumResults() == 0)
314 if (
auto andOp = dyn_cast<aig::AndInverterOp>(&op)) {
315 SmallVector<llvm::APInt, 2> inputs;
316 inputs.reserve(andOp.getInputs().size());
317 for (
auto input : andOp.getInputs()) {
318 auto it = eval.find(input);
319 if (it == eval.end())
320 return andOp.emitError(
"Input value not found in evaluation map");
321 inputs.push_back(it->second);
323 eval[andOp.getResult()] = andOp.evaluate(inputs);
324 }
else if (
auto xorOp = dyn_cast<comb::XorOp>(&op)) {
325 auto it = eval.find(xorOp.getOperand(0));
326 if (it == eval.end())
327 return xorOp.emitError(
"Input value not found in evaluation map");
328 llvm::APInt result = it->second;
329 for (
unsigned i = 1; i < xorOp.getNumOperands(); ++i) {
330 it = eval.find(xorOp.getOperand(i));
331 if (it == eval.end())
332 return xorOp.emitError(
"Input value not found in evaluation map");
333 result ^= it->second;
335 eval[xorOp.getResult()] = result;
336 }
else if (
auto migOp = dyn_cast<synth::mig::MajorityInverterOp>(&op)) {
337 SmallVector<llvm::APInt, 3> inputs;
338 inputs.reserve(migOp.getInputs().size());
339 for (
auto input : migOp.getInputs()) {
340 auto it = eval.find(input);
341 if (it == eval.end())
342 return migOp.emitError(
"Input value not found in evaluation map");
343 inputs.push_back(it->second);
345 eval[migOp.getResult()] = migOp.evaluate(inputs);
346 }
else if (!isa<hw::OutputOp>(&op)) {
347 return op.emitError(
"Unsupported operation for truth table simulation");
374 npnClass.emplace(std::move(canonicalForm));
380 SmallVectorImpl<unsigned> &permutedIndices)
const {
382 npnClass.getInputPermutation(patternNPN, permutedIndices);
387 SmallVectorImpl<DelayType> &results)
const {
392 for (
auto inputIndex :
inputs) {
395 results.push_back(0);
398 auto *cutSet = enumerator.
getCutSet(inputIndex);
399 assert(cutSet &&
"Input must have a valid cut set");
403 auto *bestCut = cutSet->getBestMatchedCut();
410 mlir::Value inputValue = network.getValue(inputIndex);
414 cast<mlir::OpResult>(inputValue).getResultNumber()));
421 os <<
"// === Cut Dump ===\n";
426 os <<
" and root: " << *rootOp;
432 os <<
"Primary input cut: " << inputVal <<
"\n";
436 os <<
"Inputs (indices): \n";
437 for (
auto [idx, inputIndex] : llvm::enumerate(
inputs)) {
438 mlir::Value inputVal = network.
getValue(inputIndex);
439 os <<
" Input " << idx <<
" (index " << inputIndex <<
"): " << inputVal
444 os <<
"\nRoot operation: \n";
453 os <<
"// === Cut End ===\n";
462 return rootOp ? rootOp->getNumResults() : 1;
467 const llvm::APInt &a) {
472 llvm_unreachable(
"Unsupported unary operation for truth table computation");
477 const llvm::APInt &a,
478 const llvm::APInt &b) {
486 "Unsupported binary operation for truth table computation");
491 const llvm::APInt &a,
492 const llvm::APInt &b,
493 const llvm::APInt &c) {
496 return (a & b) | (a & c) | (b & c);
499 "Unsupported ternary operation for truth table computation");
505 llvm::DenseMap<uint32_t, llvm::APInt> &cache,
506 unsigned numInputs) {
508 auto cacheIt = cache.find(index);
509 if (cacheIt != cache.end())
510 return cacheIt->second;
512 const auto &gate = network.
getGate(index);
515 auto getEdgeTT = [&](
const Signal &edge) {
517 if (edge.isInverted())
522 switch (gate.getKind()) {
526 result = llvm::APInt::getZero(1U << numInputs);
528 result = llvm::APInt::getAllOnes(1U << numInputs);
534 llvm_unreachable(
"Primary input not in cache - not a cut input?");
539 getEdgeTT(gate.edges[1]));
546 getEdgeTT(gate.edges[1]), getEdgeTT(gate.edges[2]));
556 cache[index] = result;
568 unsigned numInputs =
inputs.size();
570 llvm_unreachable(
"Too many inputs for truth table computation");
574 llvm::DenseMap<uint32_t, llvm::APInt> cache;
575 for (
unsigned i = 0; i < numInputs; ++i) {
595 return std::includes(otherInputs.begin(), otherInputs.end(),
inputs.begin(),
602 cut.
inputs.push_back(index);
613 assert(
pattern &&
"Pattern must be set to get arrival time");
618 assert(
pattern &&
"Pattern must be set to get arrival time");
650 auto dumpInputs = [](llvm::raw_ostream &os,
651 const llvm::SmallVectorImpl<uint32_t> &inputs) {
653 llvm::interleaveComma(inputs, os);
658 std::stable_sort(cuts.begin(), cuts.end(), [](
const Cut *a,
const Cut *b) {
659 if (a->getInputSize() != b->getInputSize())
660 return a->getInputSize() < b->getInputSize();
661 return std::lexicographical_compare(a->inputs.begin(), a->inputs.end(),
662 b->inputs.begin(), b->inputs.end());
666 unsigned maxCutSize = cuts.empty() ? 0 : cuts.back()->getInputSize();
667 llvm::SmallVector<llvm::SmallVector<Cut *, 4>, 16> keptBySize(maxCutSize + 1);
670 unsigned uniqueCount = 0;
671 for (
Cut *cut : cuts) {
672 unsigned cutSize = cut->getInputSize();
676 if (uniqueCount > 0) {
677 Cut *lastKept = cuts[uniqueCount - 1];
679 lastKept->
inputs == cut->inputs)
683 bool isDominated =
false;
684 for (
unsigned existingSize = 1; existingSize < cutSize && !isDominated;
686 for (
const Cut *existingCut : keptBySize[existingSize]) {
687 if (!existingCut->dominates(*cut))
691 llvm::dbgs() <<
"Dropping non-minimal cut ";
692 dumpInputs(llvm::dbgs(), cut->inputs);
693 llvm::dbgs() <<
" due to subset ";
694 dumpInputs(llvm::dbgs(), existingCut->inputs);
695 llvm::dbgs() <<
"\n";
705 cuts[uniqueCount++] = cut;
706 keptBySize[cutSize].push_back(cut);
709 LLVM_DEBUG(llvm::dbgs() <<
"Original cuts: " << cuts.size()
710 <<
" Unique cuts: " << uniqueCount <<
"\n");
713 cuts.resize(uniqueCount);
718 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut,
727 if (!cut->getTruthTable().has_value())
728 cut->computeTruthTableFromOperands(logicNetwork);
731 "Cut input size exceeds maximum allowed size");
733 if (
auto matched = matchCut(*cut))
734 cut->setMatchedPattern(std::move(*matched));
747 auto *trivialCutsEnd =
748 std::stable_partition(
cuts.begin(),
cuts.end(),
749 [](
const Cut *cut) { return cut->isTrivialCut(); });
751 auto isBetterCut = [&options](
const Cut *a,
const Cut *b) {
752 assert(!a->isTrivialCut() && !b->isTrivialCut() &&
753 "Trivial cuts should have been excluded");
754 const auto &aMatched = a->getMatchedPattern();
755 const auto &bMatched = b->getMatchedPattern();
757 if (aMatched && bMatched)
759 options.
strategy, aMatched->getArea(), aMatched->getArrivalTimes(),
760 bMatched->getArea(), bMatched->getArrivalTimes());
762 if (
static_cast<bool>(aMatched) !=
static_cast<bool>(bMatched))
763 return static_cast<bool>(aMatched);
765 return a->getInputSize() < b->getInputSize();
767 std::stable_sort(trivialCutsEnd,
cuts.end(), isBetterCut);
784 llvm::dbgs() <<
"Finalized cut set with " <<
cuts.size() <<
" cuts and "
789 :
"no matched pattern")
801 SmallVectorImpl<NPNClass> &matchingNPNClasses)
const {
810 llvm::SmallVector<std::unique_ptr<CutRewritePattern>, 4>
patterns)
814 SmallVector<NPNClass, 2> npnClasses;
815 auto result =
pattern->useTruthTableMatcher(npnClasses);
817 for (
auto npnClass : npnClasses) {
820 npnClass.truthTable.numInputs}]
821 .push_back(std::make_pair(std::move(npnClass),
pattern.get()));
836 : options(options) {}
840 auto [cutSetPtr, inserted] =
cutSets.try_emplace(index, cutSet);
841 assert(inserted &&
"Cut set already exists for this index");
842 return cutSetPtr->second;
856 assert(logicOp && logicOp->getNumResults() == 1 &&
857 "Logic operation must have a single result");
859 unsigned numFanins = gate.getNumFanins();
864 return logicOp->emitError(
"Cut enumeration supports at most 3 operands, "
867 if (!logicOp->getOpResult(0).getType().isInteger(1))
868 return logicOp->emitError()
869 <<
"Supported logic operations must have a single bit "
870 "result type but found: "
871 << logicOp->getResult(0).getType();
873 SmallVector<const CutSet *, 2> operandCutSets;
874 operandCutSets.reserve(numFanins);
877 for (
unsigned i = 0; i < numFanins; ++i) {
878 uint32_t faninIndex = gate.edges[i].getIndex();
879 auto *operandCutSet =
getCutSet(faninIndex);
881 return logicOp->emitError(
"Failed to get cut set for fanin index ")
883 operandCutSets.push_back(operandCutSet);
892 resultCutSet->addCut(primaryInputCut);
895 llvm::scope_exit prune([&]() {
905 auto enumerateCutCombinations =
906 [&](
auto &&self,
unsigned operandIdx,
907 SmallVector<const Cut *, 3> &cutPtrs) ->
void {
909 if (operandIdx == numFanins) {
913 SmallVector<uint32_t, 6> mergedInputs;
914 auto appendMergedInput = [&](uint32_t value) {
918 if (!mergedInputs.empty() && mergedInputs.back() == value)
920 mergedInputs.push_back(value);
921 return mergedInputs.size() <= maxInputSize;
924 if (numFanins == 1) {
926 mergedInputs.reserve(
927 std::min<size_t>(cutPtrs[0]->inputs.size(), maxInputSize));
928 for (uint32_t value : cutPtrs[0]->inputs)
929 if (!appendMergedInput(value))
931 }
else if (numFanins == 2) {
933 const auto &inputs0 = cutPtrs[0]->inputs;
934 const auto &inputs1 = cutPtrs[1]->inputs;
935 mergedInputs.reserve(
936 std::min<size_t>(inputs0.size() + inputs1.size(), maxInputSize));
938 unsigned i = 0, j = 0;
939 while (i < inputs0.size() || j < inputs1.size()) {
941 if (j == inputs1.size() ||
942 (i < inputs0.size() && inputs0[i] <= inputs1[j])) {
944 if (j < inputs1.size() && inputs1[j] == next)
949 if (!appendMergedInput(next))
954 const SmallVectorImpl<uint32_t> &inputs0 = cutPtrs[0]->inputs;
955 const SmallVectorImpl<uint32_t> &inputs1 = cutPtrs[1]->inputs;
956 const SmallVectorImpl<uint32_t> &inputs2 = cutPtrs[2]->inputs;
957 mergedInputs.reserve(std::min<size_t>(
958 inputs0.size() + inputs1.size() + inputs2.size(), maxInputSize));
960 unsigned i = 0, j = 0, k = 0;
961 while (i < inputs0.size() || j < inputs1.size() || k < inputs2.size()) {
963 uint32_t minVal = UINT32_MAX;
964 if (i < inputs0.size())
965 minVal = std::min(minVal, inputs0[i]);
966 if (j < inputs1.size())
967 minVal = std::min(minVal, inputs1[j]);
968 if (k < inputs2.size())
969 minVal = std::min(minVal, inputs2[k]);
972 if (i < inputs0.size() && inputs0[i] == minVal)
974 if (j < inputs1.size() && inputs1[j] == minVal)
976 if (k < inputs2.size() && inputs2[k] == minVal)
979 if (!appendMergedInput(minVal))
987 mergedCut->
inputs = std::move(mergedInputs);
992 resultCutSet->addCut(mergedCut);
995 if (mergedCut->
inputs.size() >= 4) {
996 llvm::dbgs() <<
"Generated cut for node " << nodeIndex;
998 llvm::dbgs() <<
" (" << logicOp->getName() <<
")";
999 llvm::dbgs() <<
" inputs=";
1000 llvm::interleaveComma(mergedCut->inputs, llvm::dbgs());
1001 llvm::dbgs() <<
"\n";
1008 const CutSet *currentCutSet = operandCutSets[operandIdx];
1009 for (
const Cut *cut : currentCutSet->
getCuts()) {
1010 cutPtrs.push_back(cut);
1013 self(self, operandIdx + 1, cutPtrs);
1020 SmallVector<const Cut *, 3> cutPtrs;
1021 cutPtrs.reserve(numFanins);
1022 enumerateCutCombinations(enumerateCutCombinations, 0, cutPtrs);
1029 llvm::function_ref<std::optional<MatchedPattern>(
const Cut &)> matchCut) {
1030 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts for module: " << topOp->getName()
1040 auto &block = topOp->getRegion(0).getBlocks().front();
1046 if (!gate.isLogicGate())
1054 LLVM_DEBUG(llvm::dbgs() <<
"Cut enumeration completed successfully\n");
1060 auto it =
cutSets.find(index);
1066 cutSet->
addCut(trivialCut);
1067 auto [newIt, inserted] =
cutSets.insert({index, cutSet});
1068 assert(inserted &&
"Cut set already exists for this index");
1080 if (
auto *op = value.getDefiningOp()) {
1083 if (
auto name = op->getAttrOfType<StringAttr>(
"sv.namehint"))
1084 return name.getValue();
1088 if (op->getNumResults() == 1) {
1089 auto opName = op->getName();
1090 auto count = opCounter[opName]++;
1093 SmallString<16> nameStr;
1094 nameStr += opName.getStringRef();
1096 nameStr += std::to_string(count);
1099 auto nameAttr = StringAttr::get(op->getContext(), nameStr);
1100 op->setAttr(
"sv.namehint", nameAttr);
1109 auto blockArg = cast<BlockArgument>(value);
1111 dyn_cast<circt::hw::HWModuleOp>(blockArg.getOwner()->getParentOp());
1116 return hwOp.getInputName(blockArg.getArgNumber());
1120 DenseMap<OperationName, unsigned> opCounter;
1122 auto it =
cutSets.find(index);
1125 auto &cutSet = *it->second;
1128 << cutSet.getCuts().size() <<
" cuts:";
1129 for (
const Cut *cut : cutSet.getCuts()) {
1130 llvm::outs() <<
" {";
1131 llvm::interleaveComma(cut->inputs, llvm::outs(), [&](uint32_t inputIdx) {
1132 mlir::Value inputVal = logicNetwork.getValue(inputIdx);
1133 llvm::outs() << getTestVariableName(inputVal, opCounter);
1135 auto &
pattern = cut->getMatchedPattern();
1137 <<
"@t" << cut->getTruthTable()->table.getZExtValue() <<
"d";
1139 llvm::outs() << *std::max_element(
pattern->getArrivalTimes().begin(),
1140 pattern->getArrivalTimes().end());
1142 llvm::outs() <<
"0";
1145 llvm::outs() <<
"\n";
1147 llvm::outs() <<
"Cut enumeration completed successfully\n";
1156 llvm::dbgs() <<
"Starting Cut Rewriter\n";
1157 llvm::dbgs() <<
"Mode: "
1169 if (
pattern->getNumOutputs() > 1) {
1170 return mlir::emitError(
pattern->getLoc(),
1171 "Cut rewriter does not support patterns with "
1172 "multiple outputs yet");
1199 LLVM_DEBUG(llvm::dbgs() <<
"Enumerating cuts...\n");
1202 topOp, [&](
const Cut &cut) -> std::optional<MatchedPattern> {
1208ArrayRef<std::pair<NPNClass, const CutRewritePattern *>>
1210 if (
patterns.npnToPatternMap.empty())
1214 auto it =
patterns.npnToPatternMap.find(
1215 {npnClass.truthTable.table, npnClass.truthTable.numInputs});
1216 if (it ==
patterns.npnToPatternMap.end())
1218 return it->getSecond();
1227 SmallVector<DelayType, 4> inputArrivalTimes;
1228 SmallVector<DelayType, 1> bestArrivalTimes;
1229 double bestArea = 0.0;
1237 auto computeArrivalTimeAndPickBest =
1239 llvm::function_ref<unsigned(
unsigned)> mapIndex) {
1240 SmallVector<DelayType, 1> outputArrivalTimes;
1242 for (
unsigned outputIndex = 0, outputSize = cut.
getOutputSize(network);
1243 outputIndex < outputSize; ++outputIndex) {
1246 auto delays = matchResult.getDelays();
1247 for (
unsigned inputIndex = 0, inputSize = cut.
getInputSize();
1248 inputIndex < inputSize; ++inputIndex) {
1250 unsigned cutOriginalInput = mapIndex(inputIndex);
1252 std::max(outputArrivalTime,
1253 delays[outputIndex * inputSize + inputIndex] +
1254 inputArrivalTimes[cutOriginalInput]);
1257 outputArrivalTimes.push_back(outputArrivalTime);
1263 outputArrivalTimes, bestArea,
1264 bestArrivalTimes)) {
1266 llvm::dbgs() <<
"== Matched Pattern ==============\n";
1267 llvm::dbgs() <<
"Matching cut: \n";
1268 cut.
dump(llvm::dbgs(), network);
1269 llvm::dbgs() <<
"Found better pattern: "
1271 llvm::dbgs() <<
" with area: " << matchResult.area;
1272 llvm::dbgs() <<
" and input arrival times: ";
1273 for (
unsigned i = 0; i < inputArrivalTimes.size(); ++i) {
1274 llvm::dbgs() <<
" " << inputArrivalTimes[i];
1276 llvm::dbgs() <<
" and arrival times: ";
1278 for (
auto arrivalTime : outputArrivalTimes) {
1279 llvm::dbgs() <<
" " << arrivalTime;
1281 llvm::dbgs() <<
"\n";
1282 llvm::dbgs() <<
"== Matched Pattern End ==============\n";
1285 bestArrivalTimes = std::move(outputArrivalTimes);
1286 bestArea = matchResult.area;
1293 "Pattern input size must match cut input size");
1300 SmallVector<unsigned> inputMapping;
1302 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1303 [&](
unsigned i) {
return inputMapping[i]; });
1308 computeArrivalTimeAndPickBest(
pattern, *matchResult,
1309 [&](
unsigned i) {
return i; });
1315 return MatchedPattern(bestPattern, std::move(bestArrivalTimes), bestArea);
1319 LLVM_DEBUG(llvm::dbgs() <<
"Performing cut-based rewriting...\n");
1326 PatternRewriter rewriter(top->getContext());
1329 for (
auto index : llvm::reverse(processingOrder)) {
1330 auto it = cutSets.find(index);
1331 if (it == cutSets.end())
1334 mlir::Value value = network.getValue(index);
1335 auto &cutSet = *it->second;
1337 if (value.use_empty()) {
1338 if (
auto *op = value.getDefiningOp())
1345 LLVM_DEBUG(llvm::dbgs() <<
"Skipping inputs: " << value <<
"\n");
1349 LLVM_DEBUG(llvm::dbgs() <<
"Cut set for value: " << value <<
"\n");
1350 auto *bestCut = cutSet.getBestMatchedCut();
1354 return emitError(value.getLoc(),
"No matching cut found for value: ")
1359 auto *rootOp = network.getGate(bestCut->getRootIndex()).getOperation();
1360 rewriter.setInsertionPoint(rootOp);
1361 const auto &matchedPattern = bestCut->getMatchedPattern();
1362 auto result = matchedPattern->getPattern()->rewrite(rewriter,
cutEnumerator,
1367 rewriter.replaceOp(rootOp, *result);
1370 auto array = rewriter.getI64ArrayAttr(matchedPattern->getArrivalTimes());
1371 (*result)->setAttr(
"test.arrival_times", array);
assert(baseType &&"element must be base type")
static llvm::APInt simulateGate(const LogicNetwork &network, uint32_t index, llvm::DenseMap< uint32_t, llvm::APInt > &cache, unsigned numInputs)
Simulate a gate and return its truth table.
static llvm::APInt applyGateSemantics(LogicNetworkGate::Kind kind, const llvm::APInt &a)
Simulate a gate and return its truth table.
static void removeDuplicateAndNonMinimalCuts(SmallVectorImpl< Cut * > &cuts)
static Cut getAsTrivialCut(uint32_t index, const LogicNetwork &network)
static bool isAlwaysCutInput(const LogicNetwork &network, uint32_t index)
static StringRef getTestVariableName(Value value, DenseMap< OperationName, unsigned > &opCounter)
Generate a human-readable name for a value used in test output.
static bool compareDelayAndArea(OptimizationStrategy strategy, double newArea, ArrayRef< DelayType > newDelay, double oldArea, ArrayRef< DelayType > oldDelay)
RewritePatternSet pattern
Cut enumeration engine for combinational logic networks.
LogicalResult visitLogicOp(uint32_t nodeIndex)
Visit a combinational logic operation and generate cuts.
llvm::SmallVector< uint32_t > processingOrder
Indices in processing order.
const LogicNetwork & getLogicNetwork() const
Get the logic network (read-only).
const llvm::DenseMap< uint32_t, CutSet * > & getCutSets() const
Get cut sets (indexed by LogicNetwork index).
CutSet * createNewCutSet(uint32_t index)
Create a new cut set for an index.
ArrayRef< uint32_t > getProcessingOrder() const
Get the processing order.
const CutRewriterOptions & options
Configuration options for cut enumeration.
LogicalResult enumerateCuts(Operation *topOp, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut=[](const Cut &) { return std::nullopt;})
Enumerate cuts for all nodes in the given module.
llvm::SpecificBumpPtrAllocator< Cut > cutAllocator
Typed bump allocators for fast allocation with destructors.
CutEnumerator(const CutRewriterOptions &options)
Constructor for cut enumerator.
void clear()
Clear all cut sets and reset the enumerator.
llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut
Function to match cuts against available patterns.
llvm::DenseMap< uint32_t, CutSet * > cutSets
Maps indices to their associated cut sets.
const CutSet * getCutSet(uint32_t index)
Get the cut set for a specific index.
llvm::SpecificBumpPtrAllocator< CutSet > cutSetAllocator
LogicNetwork logicNetwork
Flat logic network representation used during enumeration/rewrite.
llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns
Owned collection of all rewriting patterns.
SmallVector< const CutRewritePattern *, 4 > nonNPNPatterns
Patterns that use custom matching logic instead of NPN lookup.
DenseMap< std::pair< APInt, unsigned >, SmallVector< std::pair< NPNClass, const CutRewritePattern * > > > npnToPatternMap
Fast lookup table mapping NPN canonical forms to matching patterns.
CutRewritePatternSet(llvm::SmallVector< std::unique_ptr< CutRewritePattern >, 4 > patterns)
Constructor that takes ownership of the provided patterns.
const CutRewriterOptions & options
Configuration options.
ArrayRef< std::pair< NPNClass, const CutRewritePattern * > > getMatchingPatternsFromTruthTable(const Cut &cut) const
Find patterns that match a cut's truth table.
std::optional< MatchedPattern > patternMatchCut(const Cut &cut)
Match a cut against available patterns and compute arrival time.
LogicalResult enumerateCuts(Operation *topOp)
Enumerate cuts for all nodes in the given module.
LogicalResult run(Operation *topOp)
Execute the complete cut-based rewriting algorithm.
CutEnumerator cutEnumerator
LogicalResult runBottomUpRewrite(Operation *topOp)
Perform the actual circuit rewriting using selected patterns.
Manages a collection of cuts for a single logic node using priority cuts algorithm.
Cut * getBestMatchedCut() const
Get the cut associated with the best matched pattern.
void addCut(Cut *cut)
Add a new cut to this set using bump allocator.
unsigned size() const
Get the number of cuts in this set.
llvm::SmallVector< Cut *, 12 > cuts
Collection of cuts for this node.
ArrayRef< Cut * > getCuts() const
Get read-only access to all cuts in this set.
bool isFrozen
Whether cut set is finalized.
void finalize(const CutRewriterOptions &options, llvm::function_ref< std::optional< MatchedPattern >(const Cut &)> matchCut, const LogicNetwork &logicNetwork)
Finalize the cut set by removing duplicates and selecting the best pattern.
Represents a cut in the combinational logic network.
void setRootIndex(uint32_t idx)
Set the root index of this cut.
std::optional< NPNClass > npnClass
Cached NPN canonical form for this cut.
void dump(llvm::raw_ostream &os, const LogicNetwork &network) const
void getPermutatedInputIndices(const NPNClass &patternNPN, SmallVectorImpl< unsigned > &permutedIndices) const
Get the permutated inputs for this cut based on the given pattern NPN.
std::optional< MatchedPattern > matchedPattern
const std::optional< MatchedPattern > & getMatchedPattern() const
Get the matched pattern for this cut.
unsigned getOutputSize(const LogicNetwork &network) const
Get the number of outputs from root operation.
const std::optional< BinaryTruthTable > & getTruthTable() const
Get the truth table for this cut.
const NPNClass & getNPNClass() const
Get the NPN canonical form for this cut.
llvm::SmallVector< uint32_t, 6 > inputs
External inputs to this cut (cut boundary).
LogicalResult getInputArrivalTimes(CutEnumerator &enumerator, SmallVectorImpl< DelayType > &results) const
Get arrival times for each input of this cut.
void computeTruthTableFromOperands(const LogicNetwork &network)
Compute truth table using fast incremental method from operand cuts.
void setOperandCuts(ArrayRef< const Cut * > cuts)
Set operand cuts for lazy truth table computation.
std::optional< BinaryTruthTable > truthTable
Cached truth table for this cut.
bool dominates(const Cut &other) const
Check if this cut dominates another cut.
unsigned getInputSize() const
Get the number of inputs to this cut.
void computeTruthTable(const LogicNetwork &network)
Compute and cache the truth table for this cut using the LogicNetwork.
bool isTrivialCut() const
Check if this cut represents a trivial cut.
uint32_t rootIndex
Root index in LogicNetwork (0 indicates no root for a trivial cut).
Flat logic network representation for efficient cut enumeration.
llvm::SmallVector< LogicNetworkGate > gates
Vector of all gates in the network.
bool hasIndex(Value value) const
Check if a value has been indexed.
static constexpr uint32_t kConstant0
Special constant indices.
uint32_t getOrCreateIndex(Value value)
Get or create an index for a value.
llvm::DenseMap< Value, uint32_t > valueToIndex
Map from MLIR Value to network index.
ArrayRef< LogicNetworkGate > getGates() const
uint32_t getIndex(Value value) const
Get the raw index for a value.
void getValues(ArrayRef< uint32_t > indices, SmallVectorImpl< Value > &values) const
Fill values for the given raw indices.
uint32_t addPrimaryInput(Value value)
Add a primary input to the network.
Value getValue(uint32_t index) const
Get the value for a given raw index.
LogicalResult buildFromBlock(Block *block)
Build the logic network from a region/block in topological order.
Signal getOrCreateSignal(Value value, bool inverted)
Get or create a Signal for a value.
static constexpr uint32_t kConstant1
llvm::SmallVector< Value > indexToValue
Map from network index to MLIR Value.
void clear()
Clear the network and reset to initial state.
const LogicNetworkGate & getGate(uint32_t index) const
Get the gate at a given index.
uint32_t addGate(Operation *op, LogicNetworkGate::Kind kind, Value result, llvm::ArrayRef< Signal > operands={})
Add a gate with explicit result value and operand signals.
Represents a cut that has been successfully matched to a rewriting pattern.
double area
Area cost of this pattern.
DelayType getArrivalTime(unsigned outputIndex) const
Get the arrival time of signals through this pattern.
ArrayRef< DelayType > getArrivalTimes() const
const CutRewritePattern * pattern
The matched library pattern.
double getArea() const
Get the area cost of using this pattern.
const CutRewritePattern * getPattern() const
Get the library pattern that was matched.
SmallVector< DelayType, 1 > arrivalTimes
Arrival times of outputs from this pattern.
OptimizationStrategy
Optimization strategy.
@ OptimizationStrategyArea
Optimize for minimal area.
@ OptimizationStrategyTiming
Optimize for minimal critical path delay.
FailureOr< BinaryTruthTable > getTruthTable(ValueRange values, Block *block)
Get the truth table for operations within a block.
LogicalResult topologicallySortGraphRegionBlocks(mlir::Operation *op, llvm::function_ref< bool(mlir::Value, mlir::Operation *)> isOperandReady)
This function performs a topological sort on the operations within each block of graph regions in the...
static constexpr unsigned maxTruthTableInputs
Maximum number of inputs supported for truth table generation.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::APInt createVarMask(unsigned numVars, unsigned varIndex, bool positive)
Create a mask for a variable in the truth table.
Represents a boolean function as a truth table.
Represents the canonical form of a boolean function under NPN equivalence.
static NPNClass computeNPNCanonicalForm(const BinaryTruthTable &tt)
Compute the canonical NPN form for a given truth table.
void getInputPermutation(const NPNClass &targetNPN, llvm::SmallVectorImpl< unsigned > &permutation) const
Get input permutation from this NPN class to another equivalent NPN class.
Utility that tracks operations that have potentially become unused and allows them to be cleaned up a...
void eraseNow(Operation *op)
Erase an operation immediately, and remove it from the set of ops to be removed later.
Base class for cut rewriting patterns used in combinational logic optimization.
virtual bool useTruthTableMatcher(SmallVectorImpl< NPNClass > &matchingNPNClasses) const
Specify truth tables that this pattern can match.
Configuration options for the cut-based rewriting algorithm.
unsigned maxCutInputSize
Maximum number of inputs allowed for any cut.
unsigned maxCutSizePerRoot
Maximum number of cuts to maintain per logic node.
bool allowNoMatch
Fail if there is a root operation that has no matching pattern.
bool attachDebugTiming
Put arrival times to rewritten operations.
OptimizationStrategy strategy
Optimization strategy (area vs. timing).
bool testPriorityCuts
Run priority cuts enumeration and dump the cut sets.
Represents a single gate/node in the flat logic network.
Operation * getOperation() const
Get the operation pointer (nullptr for constants).
bool isAlwaysCutInput() const
Check if this should always be a cut input (PI or constant).
@ And2
AND gate (2-input, aig::AndInverterOp)
@ Identity
Identity gate (used for 1-input inverter)
@ Maj3
Majority gate (3-input, mig::MajOp)
@ PrimaryInput
Primary input to the network.
@ Constant
Constant 0/1 node (index 0 = const0, index 1 = const1)
Result of matching a cut against a pattern.
Edge representation in the logic network.