22#include "mlir/IR/Attributes.h"
23#include "mlir/IR/Builders.h"
24#include "mlir/IR/BuiltinOps.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Pass/Pass.h"
27#include "mlir/Support/LogicalResult.h"
28#include "llvm/ADT/APInt.h"
29#include "llvm/ADT/ArrayRef.h"
30#include "llvm/ADT/DenseMap.h"
31#include "llvm/ADT/DenseSet.h"
32#include "llvm/ADT/MapVector.h"
33#include "llvm/ADT/STLFunctionalExtras.h"
34#include "llvm/ADT/SmallVector.h"
35#include "llvm/ADT/StringRef.h"
36#include "llvm/ADT/TypeSwitch.h"
37#include "llvm/Support/Debug.h"
// Debug category used by the LLVM_DEBUG output in this file.
40#define DEBUG_TYPE "synth-functional-reduction"
// NOTE(review): only the initializer of this declaration is visible here; it
// presumably belongs to `kTestClassAttrName`, the attribute test mode uses to
// tag equivalence classes (read by getTestEquivClass) — confirm.
43 "synth.test.fc_equiv_class";
// Instantiate the tablegen-generated FunctionalReduction pass base class.
47#define GEN_PASS_DEF_FUNCTIONALREDUCTION
48#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
/// Outcome of checking whether two values are functionally equivalent:
///  - Proved:    the values were shown to be equivalent;
///  - Disproved: the values were shown to differ;
///  - Unknown:   no conclusion was reached (presumably e.g. the solver hit its
///               conflict limit, or test mode found no matching class).
56enum class EquivResult { Proved, Disproved, Unknown };
/// Creates the incremental SAT solver used for equivalence checking, selected
/// by the `backend` name. Only the "auto" and "cadical" comparisons are visible
/// here; when creation fails, the pass reports
/// "expected auto, z3, or cadical".
/// NOTE(review): the branch bodies are not visible; presumably a null pointer
/// is returned for unknown or unavailable backends — confirm.
58std::unique_ptr<IncrementalSATSolver>
59createFunctionalReductionSATSolver(llvm::StringRef backend) {
60 if (backend ==
"auto") {
65 if (backend ==
"cadical")
/// Encodes i1 logic into CNF on a shared incremental SAT solver and answers
/// pairwise equivalence queries. The solver, the value->variable map, the set
/// of already-encoded values and the fresh-variable counter are all held by
/// reference. The caller owns them; in this file that is
/// FunctionalReductionSolver::initializeSATState.
72class FunctionalReductionSATBuilder {
  // Constructor (only partially visible): binds the caller-owned SAT state.
75 llvm::DenseMap<Value, int> &satVars,
76 llvm::DenseSet<Value> &encodedValues,
  /// Checks whether `lhs` and `rhs` compute the same function.
79 EquivResult verify(Value lhs, Value rhs);
  /// Returns the variable preallocated for `value`. Despite the name, this
  /// asserts on a missing entry instead of creating one.
82 int getOrCreateVar(Value value);
  /// Returns the signed literal for `value`, negated when `inverted` is true.
86 int getLiteral(Value value,
bool inverted =
false);
  /// Adds Tseitin clauses for outVar <-> AND(inputLits).
87 void addAndClauses(
int outVar, llvm::ArrayRef<int> inputLits);
  /// Adds Tseitin clauses for outVar <-> OR(inputLits).
88 void addOrClauses(
int outVar, llvm::ArrayRef<int> inputLits);
  /// Adds Tseitin clauses for outVar <-> (lhsLit XOR rhsLit).
89 void addXorClauses(
int outVar,
int lhsLit,
int rhsLit);
  /// Adds clauses for outVar <-> XOR of all inputLits (chain of binary XORs).
90 void addParityClauses(
int outVar, llvm::ArrayRef<int> inputLits);
  /// Encodes `value` and every not-yet-encoded value in its fan-in cone.
91 void encodeValue(Value value);
  // Caller-owned mapping from each collected i1 value to its SAT variable.
94 llvm::DenseMap<Value, int> &satVars;
  // Caller-owned set of values whose defining logic is already in the solver.
95 llvm::DenseSet<Value> &encodedValues;
99static bool isFunctionalReductionSimulatableOp(Operation *op) {
100 return isa<aig::AndInverterOp, comb::AndOp, comb::OrOp, comb::XorOp>(op);
/// Decides whether `lhs` and `rhs` are functionally equivalent with two
/// assumption-based solver calls. The first assumes lhs=1, rhs=0; the second
/// assumes lhs=0, rhs=1. Proved is returned only when neither call returns
/// early.
/// NOTE(review): the lines that inspect `result`, and any encoding of the two
/// cones before the variables are fetched, are not visible here. Presumably a
/// satisfiable query (a distinguishing assignment exists) yields Disproved,
/// and any other non-UNSAT result (e.g. conflict limit reached) yields
/// Unknown — confirm.
103EquivResult FunctionalReductionSATBuilder::verify(Value lhs, Value rhs) {
107 int lhsVar = getOrCreateVar(lhs);
108 int rhsVar = getOrCreateVar(rhs);
  // Query 1: can lhs be true while rhs is false?
112 solver.assume(lhsVar);
113 solver.assume(-rhsVar);
114 auto result = solver.solve();
116 return EquivResult::Disproved;
118 return EquivResult::Unknown;
  // Query 2: can lhs be false while rhs is true?
120 solver.assume(-lhsVar);
121 solver.assume(rhsVar);
122 result = solver.solve();
124 return EquivResult::Disproved;
126 return EquivResult::Unknown;
  // Reached only when neither query returned early.
128 return EquivResult::Proved;
/// Looks up the SAT variable that initializeSATState preallocated for `value`.
/// Despite its name, this never creates a variable: a missing entry trips the
/// assertion.
/// NOTE(review): the return statement is not visible; presumably it returns
/// `it->second`.
131int FunctionalReductionSATBuilder::getOrCreateVar(Value value) {
132 auto it = satVars.find(value);
133 assert(it != satVars.end() &&
"SAT variable must be preallocated");
/// Allocates a fresh auxiliary SAT variable and reserves it in the solver.
/// addParityClauses uses these for intermediate XOR-chain outputs. The shared
/// counter starts right after the per-value variables that initializeSATState
/// numbers 1..N.
/// NOTE(review): the return statement is not visible; presumably `freshVar`.
137int FunctionalReductionSATBuilder::createAuxVar() {
138 int freshVar = ++nextFreshVar;
139 solver.reserveVars(freshVar);
143int FunctionalReductionSATBuilder::getLiteral(Value value,
bool inverted) {
144 int lit = getOrCreateVar(value);
145 return inverted ? -
lit :
lit;
148void FunctionalReductionSATBuilder::addAndClauses(
149 int outVar, llvm::ArrayRef<int> inputLits) {
153 for (
int lit : inputLits)
154 solver.addClause({-outVar,
lit});
156 SmallVector<int> clause;
157 for (
int lit : inputLits)
158 clause.push_back(-
lit);
159 clause.push_back(outVar);
160 solver.addClause(clause);
163void FunctionalReductionSATBuilder::addOrClauses(
164 int outVar, llvm::ArrayRef<int> inputLits) {
169 for (
int lit : inputLits)
170 solver.addClause({-
lit, outVar});
172 SmallVector<int> clause;
173 clause.reserve(inputLits.size() + 1);
176 clause.push_back(-outVar);
177 clause.append(inputLits.begin(), inputLits.end());
178 solver.addClause(clause);
181void FunctionalReductionSATBuilder::addXorClauses(
int outVar,
int lhsLit,
186 solver.addClause({-lhsLit, -rhsLit, -outVar});
187 solver.addClause({lhsLit, rhsLit, -outVar});
188 solver.addClause({lhsLit, -rhsLit, outVar});
189 solver.addClause({-lhsLit, rhsLit, outVar});
192void FunctionalReductionSATBuilder::addParityClauses(
193 int outVar, llvm::ArrayRef<int> inputLits) {
194 assert(!inputLits.empty() &&
"parity requires at least one input");
195 if (inputLits.size() == 1) {
196 solver.addClause({-outVar, inputLits.front()});
197 solver.addClause({outVar, -inputLits.front()});
201 int accumulatedLit = inputLits.front();
205 for (
auto [index,
lit] :
llvm::enumerate(inputLits.drop_front())) {
206 bool isLast = index + 2 == inputLits.size();
207 int outLit = isLast ? outVar : createAuxVar();
208 addXorClauses(outLit, accumulatedLit,
lit);
209 accumulatedLit = outLit;
/// Adds CNF for `value` and, transitively, for every not-yet-encoded value in
/// its fan-in cone. An explicit worklist replaces recursion. Each supported op
/// is visited twice: first (readyToEncode == false) to re-push itself and then
/// its operands, and again (readyToEncode == true) after those operands have
/// been popped and handled (LIFO order), to emit its own clauses.
213void FunctionalReductionSATBuilder::encodeValue(Value value) {
214 SmallVector<std::pair<Value, bool>> worklist;
215 worklist.push_back({value,
false});
217 while (!worklist.empty()) {
218 auto [current, readyToEncode] = worklist.pop_back_val();
  // Already in the solver (possibly reached via another path): nothing to do.
219 if (encodedValues.contains(current))
222 Operation *op = current.getDefiningOp();
  // NOTE(review): guard hidden; presumably values without a defining op
  // (block arguments) are just marked encoded, leaving their variable free.
224 encodedValues.insert(current);
  // Integer constants become a unit clause fixing the variable to 0 or 1.
229 if (matchPattern(current, mlir::m_ConstantInt(&constantValue))) {
230 encodedValues.insert(current);
231 solver.addClause({constantValue.isZero() ? -getOrCreateVar(current)
232 : getOrCreateVar(current)});
  // Ops that cannot be encoded are marked encoded without clauses, so their
  // variable is unconstrained (treated like a primary input).
236 if (!isFunctionalReductionSimulatableOp(op)) {
240 encodedValues.insert(current);
  // First visit: schedule this op again, then its unencoded operands above it.
244 if (!readyToEncode) {
245 worklist.push_back({current,
true});
246 for (
auto input : op->getOperands()) {
247 assert(input.getType().isInteger(1) &&
248 "only i1 inputs should be simulated or encoded");
249 if (!encodedValues.contains(input))
250 worklist.push_back({input,
false});
  // Second visit: operands are handled, so emit this op's defining clauses.
255 encodedValues.insert(current);
256 int outVar = getOrCreateVar(current);
258 SmallVector<int> inputLits;
259 inputLits.reserve(op->getNumOperands());
260 TypeSwitch<Operation *>(op)
  // aig.and_inv: AND of inputs, each optionally inverted per its flag.
261 .Case<aig::AndInverterOp>([&](
auto andOp) {
262 for (
auto [input, inverted] :
263 llvm::zip(andOp.getInputs(), andOp.getInverted()))
264 inputLits.push_back(getLiteral(input, inverted));
265 addAndClauses(outVar, inputLits);
267 .Case<comb::AndOp>([&](
auto andOp) {
268 for (
auto input : andOp.getInputs())
269 inputLits.push_back(getLiteral(input));
270 addAndClauses(outVar, inputLits);
272 .Case<comb::OrOp>([&](
auto orOp) {
273 for (
auto input : orOp.getInputs())
274 inputLits.push_back(getLiteral(input));
275 addOrClauses(outVar, inputLits);
  // comb.xor: n-ary XOR, encoded as a parity chain.
277 .Case<comb::XorOp>([&](
auto xorOp) {
278 for (
auto input : xorOp.getInputs())
279 inputLits.push_back(getLiteral(input));
280 addParityClauses(outVar, inputLits);
  // Unreachable: isFunctionalReductionSimulatableOp filtered other ops above.
283 [](Operation *) { llvm_unreachable(
"unexpected supported op"); });
/// Drives functional reduction on one hw.module:
///  1. random bit-parallel simulation buckets i1 values by signature;
///  2. each bucket member is checked against the bucket's first member
///     (SAT in normal mode, a test attribute in test mode);
///  3. proven-equivalent values are merged behind synth.choice ops.
291class FunctionalReductionSolver {
  // `satSolver` may be null only in test mode (run() rejects otherwise).
293 FunctionalReductionSolver(
hw::HWModuleOp module,
unsigned numPatterns,
294 unsigned seed,
bool testTransformation,
295 std::unique_ptr<IncrementalSATSolver> satSolver)
296 : module(module), numPatterns(numPatterns), seed(seed),
297 testTransformation(testTransformation),
298 satSolver(std::move(satSolver)) {}
300 ~FunctionalReductionSolver() =
default;
  // Counters reported by run(); presumably members of `Stats`, whose opening
  // line is not visible here.
  // Number of candidate classes (signature groups with >= 2 values).
304 unsigned numEquivClasses = 0;
305 unsigned numProvedEquiv = 0;
306 unsigned numDisprovedEquiv = 0;
  // Queries that ended without a conclusion.
307 unsigned numUnknown = 0;
  // Total operands (representatives included) of all created choice ops.
308 unsigned numMergedNodes = 0;
  /// Runs the full pipeline; fails if no SAT solver is available outside test
  /// mode or if the logic network cannot be topologically sorted.
310 mlir::FailureOr<Stats>
run();
314 void collectValues();
315 void runSimulation();
316 llvm::APInt simulateValue(Value v);
319 void buildEquivalenceClasses();
322 void verifyCandidates();
323 void initializeSATState();
326 void mergeEquivalentNodes();
329 static Attribute getTestEquivClass(Value value);
330 static bool matchesTestEquivClass(Value lhs, Value rhs);
331 EquivResult verifyEquivalence(Value lhs, Value rhs);
  // Simulation width in bits; a positive multiple of 64 (checked by the pass).
337 unsigned numPatterns;
  // When set, equivalence comes from test attributes instead of SAT.
339 bool testTransformation;
  // Values given random patterns: i1 block arguments and results of ops that
  // are neither constant-like nor simulatable.
343 SmallVector<Value> primaryInputs;
  // Every collected i1 value, in collection order.
346 SmallVector<Value> allValues;
  // numPatterns-bit simulation signature per value.
349 llvm::DenseMap<Value, llvm::APInt> simSignatures;
  // Groups of values with identical signatures (size >= 2).
353 SmallVector<SmallVector<Value>> equivCandidates;
  // SAT state; the builder holds references into the four members below.
358 std::unique_ptr<IncrementalSATSolver> satSolver;
359 std::unique_ptr<FunctionalReductionSATBuilder> satBuilder;
360 llvm::DenseMap<Value, int> satVars;
361 llvm::DenseSet<Value> encodedValues;
  // Highest SAT variable handed out so far.
364 int nextFreshVar = 0;
/// Binds the builder to caller-owned SAT state by reference; nothing is
/// copied. The referenced objects must outlive the builder. In this file they
/// are members of FunctionalReductionSolver, which also owns the builder.
/// (The first parameter line is not visible here.)
368FunctionalReductionSATBuilder::FunctionalReductionSATBuilder(
370 llvm::DenseSet<Value> &encodedValues,
int &nextFreshVar)
371 : solver(solver), satVars(satVars), encodedValues(encodedValues),
372 nextFreshVar(nextFreshVar) {}
/// Returns the test-mode equivalence-class attribute for `value`'s defining op.
/// NOTE(review): the body is mostly hidden. Presumably it reads
/// `kTestClassAttrName` and returns a null Attribute for block arguments and
/// for ops without the attribute (matchesTestEquivClass treats null as "no
/// class") — confirm.
374Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
375 Operation *op = value.getDefiningOp();
381bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
382 Attribute lhsClass = getTestEquivClass(lhs);
383 Attribute rhsClass = getTestEquivClass(rhs);
384 return lhsClass && rhsClass && lhsClass == rhsClass;
/// Checks one candidate pair.
/// In test mode no solver is consulted: pairs sharing a test class are Proved
/// and all others are Unknown (never Disproved). Otherwise the query is
/// delegated to the SAT builder, which initializeSATState must have created.
387EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs) {
388 if (testTransformation) {
389 if (matchesTestEquivClass(lhs, rhs))
390 return EquivResult::Proved;
391 return EquivResult::Unknown;
393 assert(satBuilder &&
"SAT builder must be initialized before verification");
396 return satBuilder->verify(lhs, rhs);
/// Prepares the SAT state before verification. Each of the N collected values
/// gets variable index+1 (1..N). Variable 0 is never used, presumably because
/// a literal's sign encodes polarity and 0 cannot be negated. Auxiliary
/// variables start after N. The builder is created on top of this state and
/// holds references to it.
399void FunctionalReductionSolver::initializeSATState() {
400 assert(satSolver &&
"SAT solver must be initialized before SAT state setup");
403 encodedValues.clear();
404 satVars.reserve(allValues.size());
405 for (
auto [index, value] :
llvm::enumerate(allValues))
406 satVars[value] = index + 1;
  // Fresh auxiliary variables are allocated as ++nextFreshVar, i.e. from N+1.
407 nextFreshVar = allValues.size();
408 satSolver->reserveVars(allValues.size());
410 satBuilder = std::make_unique<FunctionalReductionSATBuilder>(
411 *satSolver, satVars, encodedValues, nextFreshVar);
/// Collects every i1 value of the module into `allValues`: first the i1 block
/// arguments, then i1 op results in walk order.
/// Block arguments, and results of ops that are neither constant-like nor
/// simulatable, are also recorded as primary inputs. These receive random
/// simulation patterns and unconstrained SAT variables.
/// NOTE(review): simulation relies on this order being topological; run()
/// presumably sorts the network before calling this — confirm.
418void FunctionalReductionSolver::collectValues() {
420 for (
auto arg : module.
getBodyBlock()->getArguments()) {
421 if (arg.getType().isInteger(1)) {
422 primaryInputs.push_back(arg);
423 allValues.push_back(arg);
430 module.walk([&](Operation *op) {
431 for (auto result : op->getResults()) {
432 if (!result.getType().isInteger(1))
435 allValues.push_back(result);
  // Opaque (non-constant, non-simulatable) results act as free inputs.
436 if (!op->hasTrait<OpTrait::ConstantLike>() &&
437 !isFunctionalReductionSimulatableOp(op)) {
439 primaryInputs.push_back(result);
444 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Collected "
445 << primaryInputs.size()
446 <<
" primary inputs (including unknown ops) and "
447 << allValues.size() <<
" total i1 values\n");
/// Runs bit-parallel random simulation. Each primary input gets a
/// numPatterns-bit signature from a std::mt19937_64 seeded with `seed`, so the
/// result is reproducible for a given seed. Every remaining value is then
/// evaluated in `allValues` order. numPatterns is a positive multiple of 64
/// (checked in runOnOperation), so `numWords` 64-bit words cover it exactly.
450void FunctionalReductionSolver::runSimulation() {
452 unsigned numWords = numPatterns / 64;
455 std::mt19937_64 rng(seed);
457 for (
auto input : primaryInputs) {
  // NOTE(review): the loop body filling `words` is not visible; presumably
  // each word is drawn from `rng`.
459 SmallVector<uint64_t> words(numWords);
460 for (
auto &word : words)
464 llvm::APInt
pattern(numPatterns, words);
465 simSignatures[input] =
pattern;
  // Derived values; primary inputs already have signatures and are skipped.
469 for (
auto value : allValues) {
470 if (simSignatures.count(value))
473 simSignatures[value] = simulateValue(value);
477 llvm::dbgs() <<
"FunctionalReduction: Simulation complete with "
478 << numPatterns <<
" patterns\n";
/// Computes the simulation signature of `v` from its operands' signatures.
/// Operand signatures must already exist; `simSignatures.at` asserts
/// otherwise.
482llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
483 Operation *op = v.getDefiningOp();
  // NOTE(review): guard hidden; presumably taken for values without a defining
  // op (block arguments), which were seeded as primary inputs.
485 return simSignatures.at(v);
486 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
  // aig.and_inv evaluates itself (including inversions) on the input words.
487 .Case<aig::AndInverterOp>([&](
auto op) {
488 SmallVector<llvm::APInt> inputSigs;
489 for (
auto input : op.getInputs())
490 inputSigs.push_back(simSignatures.at(input));
491 return op.evaluate(inputSigs);
  // Variadic AND/OR/XOR folded from their identity elements.
493 .Case<comb::AndOp>([&](
auto op) {
494 APInt result = APInt::getAllOnes(numPatterns);
495 for (
auto input : op.getInputs())
496 result &= simSignatures.at(input);
499 .Case<comb::OrOp>([&](
auto op) {
500 APInt result = APInt::getZero(numPatterns);
501 for (
auto input : op.getInputs())
502 result |= simSignatures.at(input);
505 .Case<comb::XorOp>([&](
auto op) {
506 APInt result = APInt::getZero(numPatterns);
507 for (
auto input : op.getInputs())
508 result ^= simSignatures.at(input);
  // Constant case (its header line is not visible): all-zeros or all-ones.
512 return op.getValue().isZero() ? APInt::getZero(numPatterns)
513 : APInt::getAllOnes(numPatterns);
  // Anything else must already have a signature (seeded as a primary input).
515 .Default([&](Operation *) {
518 return simSignatures.at(v);
/// Groups values with bit-identical simulation signatures. Every group with at
/// least two members becomes a candidate class. Members keep `allValues`
/// order, so `members.front()` (the representative used by verifyCandidates)
/// is the earliest collected value. Only exact matches are grouped; values
/// whose signatures are bitwise complements are not paired.
/// NOTE(review): the declaration of `sigGroups` is not visible. If it is a
/// hash map with unspecified iteration order (e.g. DenseMap), candidate order
/// and hence output order depend on hashing; a MapVector keeps it
/// deterministic — confirm. The loop variable `hash` holds the signature
/// itself, not a hash.
526void FunctionalReductionSolver::buildEquivalenceClasses() {
530 for (
auto value : allValues)
531 sigGroups[simSignatures.at(value)].push_back(value);
534 for (
auto &[hash, members] : sigGroups) {
  // Singletons have nothing to merge with.
535 if (members.size() <= 1)
537 equivCandidates.push_back(std::move(members));
539 stats.numEquivClasses = equivCandidates.size();
541 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Built "
542 << equivCandidates.size()
543 <<
" equivalence candidates\n");
/// Checks every non-first member of each candidate class against the class's
/// first member (the representative). Proven members are recorded in
/// `provenEquivalences[representative]`. Members are only compared with the
/// representative, never with each other, so as far as visible a member that
/// is not proven is simply not merged.
/// NOTE(review): `provenEquivalences[representative]` creates an entry even
/// when nothing is proven. mergeEquivalentNodes calls `members.back()` on each
/// entry, so the hidden tail of this loop presumably erases empty entries —
/// confirm.
553void FunctionalReductionSolver::verifyCandidates() {
555 llvm::dbgs() <<
"FunctionalReduction: Starting SAT verification with "
556 << equivCandidates.size() <<
" equivalence classes\n");
558 for (
auto &members : equivCandidates) {
561 auto representative = members.front();
562 auto &provenMembers = provenEquivalences[representative];
564 for (
auto member :
llvm::ArrayRef<Value>(members).drop_front()) {
565 EquivResult result = verifyEquivalence(representative, member);
566 if (result == EquivResult::Proved) {
567 stats.numProvedEquiv++;
568 provenMembers.push_back(member);
569 }
else if (result == EquivResult::Disproved) {
570 stats.numDisprovedEquiv++;
580 llvm::dbgs() <<
"FunctionalReduction: SAT verification complete. Proved "
581 << stats.numProvedEquiv <<
" equivalences\n");
/// For each representative with proven-equivalent members, creates one
/// synth.choice op whose operands are the representative followed by its
/// members. The op is inserted right after the last member. Members follow
/// `allValues` order, so this is presumably the latest-defined operand. All
/// other uses of every operand are then redirected to the choice op.
/// numMergedNodes counts every operand of every created choice, including
/// representatives.
/// NOTE(review): `members.back()` is undefined on an empty list. This relies
/// on verifyCandidates dropping representatives with no proven members (not
/// visible here) — confirm.
588void FunctionalReductionSolver::mergeEquivalentNodes() {
589 if (provenEquivalences.empty())
592 mlir::OpBuilder builder(module.getContext());
593 for (
auto &provenEquivSet : provenEquivalences) {
594 auto &[representative, members] = provenEquivSet;
597 SmallVector<Value> operands;
598 operands.reserve(members.size() + 1);
599 operands.push_back(representative);
600 operands.append(members);
601 builder.setInsertionPointAfterValue(members.back());
602 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
603 representative.getType(), operands);
604 stats.numMergedNodes += members.size() + 1;
  // Excluding `choice` keeps its own operands intact (no self-reference).
605 representative.replaceAllUsesExcept(choice, choice);
606 for (
auto value : members)
607 value.replaceAllUsesExcept(choice, choice);
610 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Merged "
611 << stats.numMergedNodes <<
" nodes\n");
/// Runs the full functional-reduction pipeline on the module and returns the
/// collected statistics. It fails when no SAT solver is available outside test
/// mode, or when the logic network cannot be topologically sorted.
/// NOTE(review): several pipeline calls (value collection, simulation,
/// verification) and the early returns sit on lines not visible here.
618mlir::FailureOr<FunctionalReductionSolver::Stats>
619FunctionalReductionSolver::run() {
621 llvm::dbgs() <<
"FunctionalReduction: Starting functional reduction with "
622 << numPatterns <<
" simulation patterns\n");
  // A solver is mandatory unless test mode supplies the equivalence oracle.
624 if (!testTransformation && !satSolver) {
626 << "FunctionalReduction requires a SAT solver, but none is "
627 "available in this build";
634 << "FunctionalReduction: Failed to topologically sort logic network";
640 if (allValues.empty()) {
641 LLVM_DEBUG(llvm::dbgs()
642 <<
"FunctionalReduction: No i1 values to process\n");
649 buildEquivalenceClasses();
650 if (equivCandidates.empty()) {
651 LLVM_DEBUG(llvm::dbgs()
652 <<
"FunctionalReduction: No equivalence candidates found\n");
  // SAT variables are only needed when a real solver answers the queries.
657 if (!testTransformation)
658 initializeSATState();
662 mergeEquivalentNodes();
664 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Complete. Stats:\n"
665 <<
" Equivalence classes: " << stats.numEquivClasses
667 <<
" Proved: " << stats.numProvedEquiv <<
"\n"
668 <<
" Disproved: " << stats.numDisprovedEquiv <<
"\n"
669 <<
" Unknown (limit): " << stats.numUnknown <<
"\n"
670 <<
" Merged: " << stats.numMergedNodes <<
"\n");
/// Pass wrapper: validates the options, builds the SAT solver (unless in test
/// mode), runs FunctionalReductionSolver on the module, and accumulates its
/// statistics into the pass's counters.
679struct FunctionalReductionPass
680 :
public circt::synth::impl::FunctionalReductionBase<
681 FunctionalReductionPass> {
682 using FunctionalReductionBase::FunctionalReductionBase;
  /// Adds one run's statistics to the pass-level counters (presumably the
  /// tablegen-declared pass statistics).
683 void updateStats(
const FunctionalReductionSolver::Stats &stats) {
684 numEquivClasses += stats.numEquivClasses;
685 numProvedEquiv += stats.numProvedEquiv;
686 numDisprovedEquiv += stats.numDisprovedEquiv;
687 numUnknown += stats.numUnknown;
688 numMergedNodes += stats.numMergedNodes;
691 void runOnOperation()
override {
692 auto module = getOperation();
693 LLVM_DEBUG(llvm::dbgs() <<
"Running FunctionalReduction pass on "
694 << module.getName() <<
"\n");
    // Simulation packs patterns into 64-bit words, so require a multiple of 64.
696 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
698 << "'num-random-patterns' must be a positive multiple of 64";
699 return signalPassFailure();
    // -1 is the smallest accepted value (presumably meaning "no limit").
701 if (conflictLimit < -1) {
703 << "'conflict-limit' must be greater than or equal to -1";
704 return signalPassFailure();
    // This local shadows the pass option of the same name; the option is read
    // through `this->satSolver` below.
707 std::unique_ptr<IncrementalSATSolver> satSolver;
708 if (!testTransformation) {
709 satSolver = createFunctionalReductionSATSolver(this->satSolver);
711 module.emitError() << "unsupported or unavailable SAT solver '"
713 << "' (expected auto, z3, or cadical)";
714 return signalPassFailure();
    // NOTE(review): if `conflictLimit` is wider than int, very large values
    // are narrowed here — confirm the option's declared type.
716 satSolver->setConflictLimit(
static_cast<int>(conflictLimit));
719 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
721 std::move(satSolver));
722 auto stats = fcSolver.run();
724 return signalPassFailure();
    // Without merged nodes the IR was not changed, so analyses stay valid.
726 if (stats->numMergedNodes == 0)
727 markAllAnalysesPreserved();
assert(baseType &&"element must be base type")
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
Abstract interface for incremental SAT solvers with an IPASIR-style API.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< IncrementalSATSolver > createZ3SATSolver()
Construct a Z3-backed incremental IPASIR-style SAT solver.
std::unique_ptr< IncrementalSATSolver > createCadicalSATSolver(const CadicalSATSolverOptions &options={})
Construct a CaDiCaL-backed incremental IPASIR-style SAT solver.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)