23#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LogicalResult.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/MapVector.h"
34#include "llvm/ADT/STLFunctionalExtras.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/Debug.h"
41#define DEBUG_TYPE "synth-functional-reduction"
44 "synth.test.fc_equiv_class";
48#define GEN_PASS_DEF_FUNCTIONALREDUCTION
49#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
/// Outcome of an equivalence query. This is the return type of
/// FunctionalReductionSATBuilder::verify and
/// FunctionalReductionSolver::verifyEquivalence.
57enum class EquivResult { Proved, Disproved, Unknown };
59std::unique_ptr<IncrementalSATSolver>
60createFunctionalReductionSATSolver(llvm::StringRef backend) {
61 if (backend ==
"auto") {
66 if (backend ==
"cadical")
73class FunctionalReductionSATBuilder {
76 llvm::DenseMap<Value, int> &satVars,
77 llvm::DenseSet<Value> &encodedValues);
80 EquivResult verify(Value lhs, Value rhs,
bool inverted);
83 int getOrCreateVar(Value value);
87 SmallVector<int> getOperandVars(ValueRange operands);
88 void encodeValue(Value value);
91 llvm::DenseMap<Value, int> &satVars;
92 llvm::DenseSet<Value> &encodedValues;
/// Returns true when `op` is one of the kinds listed in the isa<> check below:
/// an op implementing BooleanLogicOpInterface, or comb::AndOp, comb::OrOp or
/// comb::XorOp.
95static bool isFunctionalReductionSimulatableOp(Operation *op) {
96 return isa<BooleanLogicOpInterface, comb::AndOp, comb::OrOp, comb::XorOp>(op);
/// Compares `lhs` and `rhs` with two assumption-based solver queries, one per
/// polarity, and reports the outcome as an EquivResult.
/// NOTE(review): the handling of `inverted` and the conditions tested on each
/// solve() result are not visible in this listing; confirm against the full
/// source.
99EquivResult FunctionalReductionSATBuilder::verify(Value lhs, Value rhs,
104 int lhsVar = getOrCreateVar(lhs);
105 int rhsVar = getOrCreateVar(rhs);
// First query: assume lhsVar together with the negated rhsVar literal.
111 solver.assume(lhsVar);
112 solver.assume(-rhsVar);
113 auto result = solver.solve();
115 return EquivResult::Disproved;
117 return EquivResult::Unknown;
// Second query: the opposite polarity, negated lhsVar together with rhsVar.
119 solver.assume(-lhsVar);
120 solver.assume(rhsVar);
121 result = solver.solve();
123 return EquivResult::Disproved;
125 return EquivResult::Unknown;
// Reached only when neither query returned early above.
127 return EquivResult::Proved;
/// Looks up `value` in the `satVars` map and asserts that an entry exists.
/// Despite the name, the visible code never allocates a variable; the assert
/// message requires variables to be preallocated.
130int FunctionalReductionSATBuilder::getOrCreateVar(Value value) {
131 auto it = satVars.find(value);
132 assert(it != satVars.end() &&
"SAT variable must be preallocated");
/// Returns a variable obtained from the solver's newVar().
136int FunctionalReductionSATBuilder::createAuxVar() {
return solver.newVar(); }
/// Collects the SAT variable of every operand, in operand order, by calling
/// getOrCreateVar on each one.
139FunctionalReductionSATBuilder::getOperandVars(ValueRange operands) {
140 SmallVector<int> vars;
// Size the vector up front; one entry is appended per operand.
141 vars.reserve(operands.size());
142 for (
auto operand : operands)
143 vars.push_back(getOrCreateVar(operand));
147void FunctionalReductionSATBuilder::encodeValue(Value value) {
148 SmallVector<std::pair<Value, bool>> worklist;
149 worklist.push_back({value,
false});
151 while (!worklist.empty()) {
152 auto [current, readyToEncode] = worklist.pop_back_val();
153 if (encodedValues.contains(current))
156 Operation *op = current.getDefiningOp();
158 encodedValues.insert(current);
163 if (matchPattern(current, mlir::m_ConstantInt(&constantValue))) {
164 encodedValues.insert(current);
165 solver.addClause({constantValue.isZero() ? -getOrCreateVar(current)
166 : getOrCreateVar(current)});
170 if (!isFunctionalReductionSimulatableOp(op)) {
174 encodedValues.insert(current);
178 if (!readyToEncode) {
179 worklist.push_back({current,
true});
180 for (
auto input : op->getOperands()) {
181 assert(input.getType().isInteger(1) &&
182 "only i1 inputs should be simulated or encoded");
183 if (!encodedValues.contains(input))
184 worklist.push_back({input,
false});
189 encodedValues.insert(current);
190 int outVar = getOrCreateVar(current);
191 auto addClause = [&](llvm::ArrayRef<int> clause) {
192 solver.addClause(clause);
195 TypeSwitch<Operation *>(op)
196 .Case<BooleanLogicOpInterface>([&](
auto logicOp) {
197 auto inputVars = getOperandVars(logicOp.getInputs());
198 logicOp.emitCNF(outVar, inputVars, addClause,
199 [&]() {
return createAuxVar(); });
201 .Case<comb::AndOp>([&](
auto andOp) {
202 auto inputLits = getOperandVars(andOp.getInputs());
205 .Case<comb::OrOp>([&](
auto orOp) {
206 auto inputLits = getOperandVars(orOp.getInputs());
209 .Case<comb::XorOp>([&](
auto xorOp) {
210 auto inputLits = getOperandVars(xorOp.getInputs());
212 [&]() {
return createAuxVar(); });
215 [](Operation *) { llvm_unreachable(
"unexpected supported op"); });
223class FunctionalReductionSolver {
225 FunctionalReductionSolver(
hw::HWModuleOp module,
unsigned numPatterns,
226 unsigned seed,
bool testTransformation,
227 std::unique_ptr<IncrementalSATSolver> satSolver)
228 : module(module), numPatterns(numPatterns), seed(seed),
229 testTransformation(testTransformation),
230 satSolver(std::move(satSolver)) {}
232 ~FunctionalReductionSolver() =
default;
236 unsigned numEquivClasses = 0;
237 unsigned numProvedEquiv = 0;
238 unsigned numDisprovedEquiv = 0;
239 unsigned numUnknown = 0;
240 unsigned numMergedNodes = 0;
242 mlir::FailureOr<Stats>
run();
246 void collectValues();
247 void runSimulation();
248 llvm::APInt simulateValue(Value v);
251 void buildEquivalenceClasses();
254 void verifyCandidates();
255 void initializeSATState();
258 void mergeEquivalentNodes();
261 static Attribute getTestEquivClass(Value value);
262 static bool matchesTestEquivClass(Value lhs, Value rhs);
263 EquivResult verifyEquivalence(Value lhs, Value rhs,
bool inverted);
269 unsigned numPatterns;
271 bool testTransformation;
275 SmallVector<Value> primaryInputs;
278 SmallVector<Value> allValues;
281 llvm::DenseMap<Value, llvm::APInt> simSignatures;
285 SmallVector<SmallVector<std::pair<Value, bool>>> equivCandidates;
293 std::unique_ptr<IncrementalSATSolver> satSolver;
294 std::unique_ptr<FunctionalReductionSATBuilder> satBuilder;
295 llvm::DenseMap<Value, int> satVars;
296 llvm::DenseSet<Value> encodedValues;
300FunctionalReductionSATBuilder::FunctionalReductionSATBuilder(
302 llvm::DenseSet<Value> &encodedValues)
303 : solver(solver), satVars(satVars), encodedValues(encodedValues) {}
305Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
306 Operation *op = value.getDefiningOp();
/// Returns true only when getTestEquivClass yields a non-null attribute for
/// both `lhs` and `rhs` and the two attributes compare equal.
312bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
313 Attribute lhsClass = getTestEquivClass(lhs);
314 Attribute rhsClass = getTestEquivClass(rhs);
// A null attribute on either side never matches.
315 return lhsClass && rhsClass && lhsClass == rhsClass;
/// Decides whether `lhs` and `rhs` are equivalent.
/// When `testTransformation` is set, the answer comes from the test
/// equivalence-class attributes: Proved on a match, Unknown otherwise.
/// Otherwise the query is forwarded to the SAT builder's verify().
318EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs,
321 if (testTransformation) {
322 if (matchesTestEquivClass(lhs, rhs))
323 return EquivResult::Proved;
324 return EquivResult::Unknown;
// SAT path: the builder must already have been created.
326 assert(satBuilder &&
"SAT builder must be initialized before verification");
329 return satBuilder->verify(lhs, rhs, inverted);
/// Prepares the SAT state: assigns one SAT variable to every entry of
/// `allValues`, reserves that many variables in the solver, and constructs the
/// SAT builder over the solver, `satVars` and `encodedValues`.
332void FunctionalReductionSolver::initializeSATState() {
333 assert(satSolver &&
"SAT solver must be initialized before SAT state setup");
336 encodedValues.clear();
337 satVars.reserve(allValues.size());
// Variables are numbered from 1 in `allValues` order (index + 1).
// NOTE(review): presumably 0 is avoided because literals are negated with
// unary minus elsewhere in this file; confirm against the solver API.
338 for (
auto [index, value] :
llvm::enumerate(allValues))
339 satVars[value] = index + 1;
340 satSolver->reserveVars(allValues.size());
342 satBuilder = std::make_unique<FunctionalReductionSATBuilder>(
343 *satSolver, satVars, encodedValues);
/// Populates `allValues` and `primaryInputs` from the module's i1 block
/// arguments and i1 operation results.
350void FunctionalReductionSolver::collectValues() {
// Every i1 block argument is recorded as both a primary input and a value.
352 for (
auto arg : module.
getBodyBlock()->getArguments()) {
353 if (arg.getType().isInteger(1)) {
354 primaryInputs.push_back(arg);
355 allValues.push_back(arg);
// Walk all operations and record their i1 results.
362 module.walk([&](Operation *op) {
363 for (auto result : op->getResults()) {
364 if (!result.getType().isInteger(1))
367 allValues.push_back(result);
// A result whose op is neither ConstantLike nor simulatable is also recorded
// as a primary input.
368 if (!op->hasTrait<OpTrait::ConstantLike>() &&
369 !isFunctionalReductionSimulatableOp(op)) {
371 primaryInputs.push_back(result);
376 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Collected "
377 << primaryInputs.size()
378 <<
" primary inputs (including unknown ops) and "
379 << allValues.size() <<
" total i1 values\n");
/// Fills `simSignatures`: each primary input gets a `numPatterns`-bit pattern
/// built from 64-bit words, and every other value in `allValues` gets the
/// result of simulateValue().
/// NOTE(review): the statement that fills each word is not visible in this
/// listing; presumably it draws from `rng`. Confirm against the full source.
382void FunctionalReductionSolver::runSimulation() {
// Integer division: any remainder of numPatterns modulo 64 is dropped here.
384 unsigned numWords = numPatterns / 64;
387 std::mt19937_64 rng(seed);
389 for (
auto input : primaryInputs) {
391 SmallVector<uint64_t> words(numWords);
392 for (
auto &word : words)
396 llvm::APInt
pattern(numPatterns, words);
397 simSignatures[input] =
pattern;
// Values in `allValues` are visited in stored order; those that already have
// a signature are detected by the count() check below.
401 for (
auto value : allValues) {
402 if (simSignatures.count(value))
405 simSignatures[value] = simulateValue(value);
409 llvm::dbgs() <<
"FunctionalReduction: Simulation complete with "
410 << numPatterns <<
" patterns\n";
414llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
415 Operation *op = v.getDefiningOp();
417 return simSignatures.at(v);
418 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
419 .Case<BooleanLogicOpInterface>([&](
auto op) {
420 return op.evaluateBooleanLogic([&](
unsigned i) ->
const APInt & {
421 return simSignatures.at(op.getInput(i));
424 .Case<comb::AndOp>([&](
auto op) {
425 APInt result = APInt::getAllOnes(numPatterns);
426 for (
auto input : op.getInputs())
427 result &= simSignatures.at(input);
430 .Case<comb::OrOp>([&](
auto op) {
431 APInt result = APInt::getZero(numPatterns);
432 for (
auto input : op.getInputs())
433 result |= simSignatures.at(input);
436 .Case<comb::XorOp>([&](
auto op) {
437 APInt result = APInt::getZero(numPatterns);
438 for (
auto input : op.getInputs())
439 result ^= simSignatures.at(input);
443 return op.getValue().isZero() ? APInt::getZero(numPatterns)
444 : APInt::getAllOnes(numPatterns);
446 .Default([&](Operation *) {
449 return simSignatures.at(v);
457void FunctionalReductionSolver::buildEquivalenceClasses() {
462 for (
auto value : allValues) {
463 auto signature = simSignatures.at(value);
464 bool inverted =
false;
465 if (signature.isNegative()) {
467 signature.flipAllBits();
469 sigGroups[signature].push_back({value, inverted});
474 for (
auto &[hash, members] : sigGroups) {
475 if (members.size() <= 1)
477 bool repInverted = members.front().second;
478 for (
auto &[_, inv] : members)
480 equivCandidates.push_back(std::move(members));
482 stats.numEquivClasses = equivCandidates.size();
484 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Built "
485 << equivCandidates.size()
486 <<
" equivalence candidates\n");
496void FunctionalReductionSolver::verifyCandidates() {
498 llvm::dbgs() <<
"FunctionalReduction: Starting SAT verification with "
499 << equivCandidates.size() <<
" equivalence classes\n");
501 for (
auto &members : equivCandidates) {
504 auto [representative, repInversion] = members.front();
505 assert(!repInversion &&
"representative must not be inverted");
507 auto &provenMembers = provenEquivalences[representative];
510 for (
auto [member, inverted] :
511 llvm::ArrayRef<std::pair<Value, bool>>(members).drop_front()) {
512 EquivResult result = verifyEquivalence(representative, member, inverted);
513 if (result == EquivResult::Proved) {
514 stats.numProvedEquiv++;
515 provenMembers.push_back({member, inverted});
516 }
else if (result == EquivResult::Disproved) {
517 stats.numDisprovedEquiv++;
527 llvm::dbgs() <<
"FunctionalReduction: SAT verification complete. Proved "
528 << stats.numProvedEquiv <<
" equivalences\n");
535void FunctionalReductionSolver::mergeEquivalentNodes() {
536 if (provenEquivalences.empty())
542 struct PlannedMember {
545 aig::AndInverterOp operandInverter;
547 struct MergeRewritePlan {
548 Value representative;
549 SmallVector<PlannedMember> members;
551 SmallVector<PlannedMember> reachableMembers;
552 synth::ChoiceOp choice;
553 aig::AndInverterOp choiceNot;
556 mlir::OpBuilder builder(module.getContext());
557 auto replaceDominatedUses =
558 [](Value from, Value to,
559 llvm::function_ref<bool(Operation *)> shouldReplaceOwner) {
560 auto *defOp = to.getDefiningOp();
561 assert(defOp &&
"replacement value must be defined by an operation");
562 from.replaceUsesWithIf(to, [&](OpOperand &use) {
563 auto *user = use.getOwner();
567 return shouldReplaceOwner(user) &&
568 user->getBlock() == defOp->getBlock();
572 DenseSet<Value> reachable;
573 auto visitFrom = [&](Value start) {
574 SmallVector<Value> stack;
575 stack.push_back(start);
576 while (!stack.empty()) {
577 Value current = stack.pop_back_val();
578 if (!reachable.insert(current).second)
580 for (Operation *user : current.getUsers())
582 for (Value result : user->getResults())
583 stack.push_back(result);
587 SmallVector<MergeRewritePlan> rewritePlans;
588 rewritePlans.reserve(provenEquivalences.size());
589 for (
auto provenEquivSet : provenEquivalences) {
590 auto &[representative, members] = provenEquivSet;
594 visitFrom(representative);
597 SmallVector<std::pair<Value, bool>> safeMembers;
598 SmallVector<PlannedMember> plannedReachable;
599 for (
auto [member, inverted] : members) {
600 if (reachable.count(member)) {
601 plannedReachable.push_back({member, inverted, {}});
605 safeMembers.push_back({member, inverted});
608 if (safeMembers.empty())
611 builder.setInsertionPointAfterValue(safeMembers.back().first);
613 SmallVector<Value> operands;
614 operands.reserve(safeMembers.size() + 1);
615 operands.push_back(representative);
617 SmallVector<PlannedMember> plannedMembers;
618 plannedMembers.reserve(safeMembers.size());
619 bool hasInvertedMember =
false;
620 for (
auto [member, inverted] : safeMembers) {
622 plannedMembers.emplace_back(PlannedMember{member, inverted, {}});
624 operands.push_back(member);
627 hasInvertedMember =
true;
630 planned.operandInverter =
631 aig::AndInverterOp::create(builder, member.getLoc(), member,
true);
632 operands.push_back(planned.operandInverter.getResult());
635 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
636 representative.getType(), operands);
640 auto choiceNot = !hasInvertedMember
642 : aig::AndInverterOp::create(builder, choice.getLoc(),
645 stats.numMergedNodes += safeMembers.size() + 1;
646 rewritePlans.push_back({representative, std::move(plannedMembers),
647 std::move(plannedReachable), choice, choiceNot});
650 for (
auto &plan : rewritePlans) {
651 auto replaceValue = [&](
const PlannedMember &member) {
653 replaceDominatedUses(member.original, plan.choiceNot,
654 [&](Operation *user) {
660 return user != member.operandInverter &&
661 user != plan.choiceNot.getOperation();
664 replaceDominatedUses(member.original, plan.choice,
665 [&](Operation *user) {
666 return user != plan.choice.getOperation();
670 replaceDominatedUses(
671 plan.representative, plan.choice,
672 [&](Operation *user) { return user != plan.choice.getOperation(); });
673 for (
const auto &member : plan.members)
674 replaceValue(member);
678 for (
auto &member : plan.reachableMembers) {
679 member.original.replaceUsesWithIf(plan.choice, [&](OpOperand &use) {
680 auto *user = use.getOwner();
681 return user->getBlock() == plan.choice->getBlock();
683 if (member.original.use_empty())
684 member.original.getDefiningOp()->erase();
688 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Merged "
689 << stats.numMergedNodes <<
" nodes\n");
696mlir::FailureOr<FunctionalReductionSolver::Stats>
697FunctionalReductionSolver::run() {
699 llvm::dbgs() <<
"FunctionalReduction: Starting functional reduction with "
700 << numPatterns <<
" simulation patterns\n");
702 if (!testTransformation && !satSolver) {
704 << "FunctionalReduction requires a SAT solver, but none is "
705 "available in this build";
712 << "FunctionalReduction: Failed to topologically sort logic network";
718 if (allValues.empty()) {
719 LLVM_DEBUG(llvm::dbgs()
720 <<
"FunctionalReduction: No i1 values to process\n");
727 buildEquivalenceClasses();
728 if (equivCandidates.empty()) {
729 LLVM_DEBUG(llvm::dbgs()
730 <<
"FunctionalReduction: No equivalence candidates found\n");
735 if (!testTransformation)
736 initializeSATState();
740 mergeEquivalentNodes();
745 << "FunctionalReduction: Failed to topologically sort logic network";
749 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Complete. Stats:\n"
750 <<
" Equivalence classes: " << stats.numEquivClasses
752 <<
" Proved: " << stats.numProvedEquiv <<
"\n"
753 <<
" Disproved: " << stats.numDisprovedEquiv <<
"\n"
754 <<
" Unknown (limit): " << stats.numUnknown <<
"\n"
755 <<
" Merged: " << stats.numMergedNodes <<
"\n");
764struct FunctionalReductionPass
765 :
public circt::synth::impl::FunctionalReductionBase<
766 FunctionalReductionPass> {
767 using FunctionalReductionBase::FunctionalReductionBase;
768 void updateStats(
const FunctionalReductionSolver::Stats &stats) {
769 numEquivClasses += stats.numEquivClasses;
770 numProvedEquiv += stats.numProvedEquiv;
771 numDisprovedEquiv += stats.numDisprovedEquiv;
772 numUnknown += stats.numUnknown;
773 numMergedNodes += stats.numMergedNodes;
776 void runOnOperation()
override {
777 auto module = getOperation();
778 LLVM_DEBUG(llvm::dbgs() <<
"Running FunctionalReduction pass on "
779 << module.getName() <<
"\n");
781 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
783 << "'num-random-patterns' must be a positive multiple of 64";
784 return signalPassFailure();
786 if (conflictLimit < -1) {
788 << "'conflict-limit' must be greater than or equal to -1";
789 return signalPassFailure();
792 std::unique_ptr<IncrementalSATSolver> satSolver;
793 if (!testTransformation) {
794 satSolver = createFunctionalReductionSATSolver(this->satSolver);
796 module.emitError() << "unsupported or unavailable SAT solver '"
798 << "' (expected auto, z3, or cadical)";
799 return signalPassFailure();
801 satSolver->setConflictLimit(
static_cast<int>(conflictLimit));
804 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
806 std::move(satSolver));
807 auto stats = fcSolver.run();
809 return signalPassFailure();
811 if (stats->numMergedNodes == 0)
812 markAllAnalysesPreserved();
assert(baseType &&"element must be base type")
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
Abstract interface for incremental SAT solvers with an IPASIR-style API.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
std::unique_ptr< IncrementalSATSolver > createZ3SATSolver()
Construct a Z3-backed incremental IPASIR-style SAT solver.
std::unique_ptr< IncrementalSATSolver > createCadicalSATSolver(const CadicalSATSolverOptions &options={})
Construct a CaDiCaL-backed incremental IPASIR-style SAT solver.
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)