23#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LogicalResult.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/MapVector.h"
34#include "llvm/ADT/STLFunctionalExtras.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/Debug.h"
// DEBUG_TYPE tags this file's LLVM_DEBUG output as "synth-functional-reduction".
41#define DEBUG_TYPE "synth-functional-reduction"
44 "synth.test.fc_equiv_class";
// Instantiate the tablegen-generated pass base class (FunctionalReductionBase).
48#define GEN_PASS_DEF_FUNCTIONALREDUCTION
49#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
// Verdict for one candidate equivalence:
//  - Proved:    no distinguishing assignment exists (or, in test mode, both
//               values carry the same test equivalence-class attribute).
//  - Disproved: a distinguishing assignment was found.
//  - Unknown:   no verdict. Counted as "Unknown (limit)" in the debug stats;
//               presumably the SAT conflict limit was hit. Test mode also
//               returns it when the test classes do not match.
57enum class EquivResult { Proved, Disproved, Unknown };
// Instantiate the incremental SAT backend named by `backend`.
// Only the "auto" and "cadical" branches are visible here. The pass's error
// message lists "auto, z3, or cadical" as the accepted names.
// NOTE(review): runOnOperation reports "unsupported or unavailable SAT solver"
// after calling this. Presumably a null result signals an unknown or
// not-compiled-in backend; confirm against the branches not shown here.
59std::unique_ptr<IncrementalSATSolver>
60createFunctionalReductionSATSolver(llvm::StringRef backend) {
61 if (backend ==
"auto") {
66 if (backend ==
"cadical")
// Lazily encodes the fan-in cones of queried i1 values as CNF clauses in an
// incremental SAT solver, and answers equivalence queries on top of them.
// The Value -> variable map and the set of already-encoded values are held
// by reference. They belong to the caller, so clauses emitted for one query
// are reused by later queries.
73class FunctionalReductionSATBuilder {
76 llvm::DenseMap<Value, int> &satVars,
77 llvm::DenseSet<Value> &encodedValues);
// Check lhs against rhs, or against the complement of rhs when `inverted`.
80 EquivResult verify(Value lhs, Value rhs,
bool inverted);
// Look up the preassigned SAT variable of `value` (asserts if missing).
83 int getOrCreateVar(Value value);
// SAT variables of `operands`, in operand order.
87 SmallVector<int> getOperandVars(ValueRange operands);
// Emit clauses for `value` and every not-yet-encoded value in its fan-in.
88 void encodeValue(Value value);
// Value -> SAT variable mapping (caller-owned).
91 llvm::DenseMap<Value, int> &satVars;
// Values whose defining clauses were already emitted (caller-owned).
92 llvm::DenseSet<Value> &encodedValues;
// Ops the pass can both simulate bit-parallel and encode as CNF: any
// BooleanLogicOpInterface op, plus comb.and / comb.or / comb.xor.
// collectValues treats i1 results of other non-constant ops as primary
// inputs.
95static bool isFunctionalReductionSimulatableOp(Operation *op) {
96 return isa<BooleanLogicOpInterface, comb::AndOp, comb::OrOp, comb::XorOp>(op);
// Decide whether lhs and rhs (rhs complemented when `inverted`) are
// functionally equivalent, using two incremental solves under assumptions.
// Each solve looks for an assignment on which the two literals differ:
//  - if one is found, the pair is Disproved;
//  - if a solve gives no verdict, the result is Unknown;
//  - only when both directions are refuted is the pair Proved.
// NOTE(review): the branch conditions, the cone encoding and the handling of
// `inverted` are on lines not visible here. Presumably: SAT -> Disproved,
// non-UNSAT -> Unknown, encodeValue(lhs/rhs) is called, and rhsVar is negated
// when inverted.
99EquivResult FunctionalReductionSATBuilder::verify(Value lhs, Value rhs,
104 int lhsVar = getOrCreateVar(lhs);
105 int rhsVar = getOrCreateVar(rhs);
// Direction 1: lhs literal true, rhs literal false.
111 solver.assume(lhsVar);
112 solver.assume(-rhsVar);
113 auto result = solver.solve();
115 return EquivResult::Disproved;
117 return EquivResult::Unknown;
// Direction 2: lhs literal false, rhs literal true.
119 solver.assume(-lhsVar);
120 solver.assume(rhsVar);
121 result = solver.solve();
123 return EquivResult::Disproved;
125 return EquivResult::Unknown;
// Neither direction admits a distinguishing assignment.
127 return EquivResult::Proved;
// Return the SAT variable preassigned to `value`; initializeSATState numbers
// every collected i1 value.
// Despite its name, the visible body only looks the value up. A value with
// no preassigned variable trips the assertion rather than getting a new
// variable.
130int FunctionalReductionSATBuilder::getOrCreateVar(Value value) {
131 auto it = satVars.find(value);
132 assert(it != satVars.end() &&
"SAT variable must be preallocated");
// Allocate a fresh solver variable for CNF encoders that need auxiliary
// variables. It is passed as the newVar callback to the
// BooleanLogicOpInterface and XOR encodings.
136int FunctionalReductionSATBuilder::createAuxVar() {
return solver.newVar(); }
139FunctionalReductionSATBuilder::getOperandVars(ValueRange operands) {
  // Collect the preassigned SAT variable of each operand. Operand order is
  // preserved so the result lines up with the op's input list.
140 SmallVector<int> vars;
141 vars.reserve(operands.size());
142 for (
auto operand : operands)
143 vars.push_back(getOrCreateVar(operand));
// Emit CNF clauses defining `value` and, transitively, every not-yet-encoded
// value in its fan-in cone. An explicit worklist replaces recursion:
//  - an entry (v, false) re-pushes (v, true) and then pushes v's unencoded
//    operands;
//  - because the worklist is LIFO, (v, true) is popped only after those
//    operands are handled, and v's own clauses are emitted at that point.
147void FunctionalReductionSATBuilder::encodeValue(Value value) {
148 SmallVector<std::pair<Value, bool>> worklist;
149 worklist.push_back({value,
false});
151 while (!worklist.empty()) {
152 auto [current, readyToEncode] = worklist.pop_back_val();
  // Values encoded earlier (possibly by a previous query) are skipped.
153 if (encodedValues.contains(current))
  // NOTE(review): the guard around the following insert is not visible.
  // Presumably values without a defining op (block arguments) are marked
  // encoded with no clauses, i.e. left unconstrained.
156 Operation *op = current.getDefiningOp();
158 encodedValues.insert(current);
  // Constants become unit clauses fixing the variable to 0 or 1.
163 if (matchPattern(current, mlir::m_ConstantInt(&constantValue))) {
164 encodedValues.insert(current);
165 solver.addClause({constantValue.isZero() ? -getOrCreateVar(current)
166 : getOrCreateVar(current)});
  // Results of ops that cannot be encoded are marked encoded without any
  // clauses, so they act as free variables. This matches their role as
  // primary inputs during simulation.
170 if (!isFunctionalReductionSimulatableOp(op)) {
174 encodedValues.insert(current);
  // First visit: schedule `current` again after its operands are encoded.
178 if (!readyToEncode) {
179 worklist.push_back({current,
true});
180 for (
auto input : op->getOperands()) {
181 assert(input.getType().isInteger(1) &&
182 "only i1 inputs should be simulated or encoded");
183 if (!encodedValues.contains(input))
184 worklist.push_back({input,
false});
  // Second visit: the operands have already been processed, so encode this
  // op as outVar <=> f(inputs).
189 encodedValues.insert(current);
190 int outVar = getOrCreateVar(current);
191 auto addClause = [&](llvm::ArrayRef<int> clause) {
192 solver.addClause(clause);
195 TypeSwitch<Operation *>(op)
196 .Case<BooleanLogicOpInterface>([&](
auto logicOp) {
197 auto inputVars = getOperandVars(logicOp.getInputs());
198 logicOp.emitCNF(outVar, inputVars, addClause,
199 [&]() {
return createAuxVar(); });
201 .Case<comb::AndOp>([&](
auto andOp) {
202 auto inputLits = getOperandVars(andOp.getInputs());
205 .Case<comb::OrOp>([&](
auto orOp) {
206 auto inputLits = getOperandVars(orOp.getInputs());
209 .Case<comb::XorOp>([&](
auto xorOp) {
210 auto inputLits = getOperandVars(xorOp.getInputs());
212 [&]() {
return createAuxVar(); });
215 [](Operation *) { llvm_unreachable(
"unexpected supported op"); });
// Drives functional reduction of one hw.module's i1 logic in three phases:
//  1. Random bit-parallel simulation groups values by signature.
//  2. Each candidate equivalence is checked, by SAT or, when
//     testTransformation is set, by test class attributes.
//  3. Proven equivalences are merged through synth::ChoiceOp.
223class FunctionalReductionSolver {
225 FunctionalReductionSolver(
hw::HWModuleOp module,
unsigned numPatterns,
226 unsigned seed,
bool testTransformation,
227 std::unique_ptr<IncrementalSATSolver> satSolver)
228 : module(module), numPatterns(numPatterns), seed(seed),
229 testTransformation(testTransformation),
230 satSolver(std::move(satSolver)) {}
232 ~FunctionalReductionSolver() =
default;
  // Counters accumulated into the pass statistics (see updateStats).
236 unsigned numEquivClasses = 0;
237 unsigned numProvedEquiv = 0;
238 unsigned numDisprovedEquiv = 0;
239 unsigned numUnknown = 0;
240 unsigned numMergedNodes = 0;
  // Run the whole flow. It fails when no SAT solver is available outside
  // test mode, or when the logic network cannot be topologically sorted.
242 mlir::FailureOr<Stats>
run();
  // Simulation phase.
246 void collectValues();
247 void runSimulation();
248 llvm::APInt simulateValue(Value v);
  // Group values with equal or complementary signatures.
251 void buildEquivalenceClasses();
  // Confirm or refute each candidate.
254 void verifyCandidates();
255 void initializeSATState();
  // Rewrite the IR using proven equivalences.
258 void mergeEquivalentNodes();
  // Test-mode helpers: equivalence is decided by a class attribute on the
  // defining op instead of by SAT.
261 static Attribute getTestEquivClass(Value value);
262 static bool matchesTestEquivClass(Value lhs, Value rhs);
263 EquivResult verifyEquivalence(Value lhs, Value rhs,
bool inverted);
269 unsigned numPatterns;
271 bool testTransformation;
  // i1 module arguments plus i1 results of non-constant ops that cannot be
  // simulated. These receive random patterns.
275 SmallVector<Value> primaryInputs;
  // Every collected i1 value: arguments first, then op results in walk order.
278 SmallVector<Value> allValues;
  // One numPatterns-bit simulation signature per collected value.
281 llvm::DenseMap<Value, llvm::APInt> simSignatures;
  // Candidate classes. Each member carries a flag saying whether it is the
  // complement of the class representative (the first entry).
285 SmallVector<SmallVector<std::pair<Value, bool>>> equivCandidates;
293 std::unique_ptr<IncrementalSATSolver> satSolver;
294 std::unique_ptr<FunctionalReductionSATBuilder> satBuilder;
  // Value -> SAT variable (allValues index + 1) and the already-encoded set.
  // Both are shared with satBuilder by reference.
295 llvm::DenseMap<Value, int> satVars;
296 llvm::DenseSet<Value> encodedValues;
// Bind the builder to the caller's solver, variable map, and encoded-value
// set. satVars and encodedValues are stored as references, so the caller
// must keep them alive for the builder's lifetime.
300FunctionalReductionSATBuilder::FunctionalReductionSATBuilder(
302 llvm::DenseSet<Value> &encodedValues)
303 : solver(solver), satVars(satVars), encodedValues(encodedValues) {}
// Return the test equivalence-class attribute on `value`'s defining op.
// NOTE(review): most of the body is not visible. Presumably it reads the
// "synth.test.fc_equiv_class" attribute and returns a null Attribute for
// block arguments and for ops without it.
305Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
306 Operation *op = value.getDefiningOp();
// In test mode two values match only if both carry a test class attribute
// and the two attributes are identical. Missing attributes never match.
312bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
313 Attribute lhsClass = getTestEquivClass(lhs);
314 Attribute rhsClass = getTestEquivClass(rhs);
315 return lhsClass && rhsClass && lhsClass == rhsClass;
// Check one candidate pair.
// In test mode the verdict comes only from the test class attributes: Proved
// or Unknown, never Disproved, and `inverted` is not consulted.
// Otherwise the check is delegated to the SAT builder, which must already
// have been created by initializeSATState.
318EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs,
321 if (testTransformation) {
322 if (matchesTestEquivClass(lhs, rhs))
323 return EquivResult::Proved;
324 return EquivResult::Unknown;
326 assert(satBuilder &&
"SAT builder must be initialized before verification");
329 return satBuilder->verify(lhs, rhs, inverted);
// Prepare the SAT state: number every collected i1 value 1..N (allValues
// index + 1; 0 is reserved as the clause terminator in the IPASIR-style API),
// reserve N solver variables, and create the SAT builder over the shared map
// and encoded set.
332void FunctionalReductionSolver::initializeSATState() {
333 assert(satSolver &&
"SAT solver must be initialized before SAT state setup");
336 encodedValues.clear();
337 satVars.reserve(allValues.size());
338 for (
auto [index, value] :
llvm::enumerate(allValues))
339 satVars[value] = index + 1;
340 satSolver->reserveVars(allValues.size());
342 satBuilder = std::make_unique<FunctionalReductionSATBuilder>(
343 *satSolver, satVars, encodedValues);
// Gather the i1 values the pass reasons about. Every i1 value is appended to
// allValues. The following also become primary inputs, which receive random
// simulation patterns:
//  - i1 module arguments;
//  - i1 results of ops that are neither ConstantLike nor simulatable.
// NOTE(review): `builder` and `i1Type` have no use on any visible line;
// confirm they are still needed.
350void FunctionalReductionSolver::collectValues() {
354 OpBuilder builder(module.getContext());
355 builder.setInsertionPointToStart(module.getBodyBlock());
356 auto i1Type = builder.getIntegerType(1);
360 for (
auto arg : module.
getBodyBlock()->getArguments()) {
361 if (arg.getType().isInteger(1)) {
362 primaryInputs.push_back(arg);
363 allValues.push_back(arg);
370 module.walk([&](Operation *op) {
371 for (auto result : op->getResults()) {
372 if (!result.getType().isInteger(1))
375 allValues.push_back(result);
376 if (!op->hasTrait<OpTrait::ConstantLike>() &&
377 !isFunctionalReductionSimulatableOp(op)) {
379 primaryInputs.push_back(result);
384 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Collected "
385 << primaryInputs.size()
386 <<
" primary inputs (including unknown ops) and "
387 << allValues.size() <<
" total i1 values\n");
// Bit-parallel random simulation.
// Each primary input gets a numPatterns-bit pattern built from 64-bit words
// drawn from a std::mt19937_64 seeded with `seed`, so runs are reproducible.
// The pass rejects non-multiples of 64, so numPatterns / 64 is exact.
// Every other value's signature is then computed in allValues order.
// NOTE(review): simulateValue reads operand signatures with .at(), so this
// relies on allValues listing operands before their users. run() presumably
// guarantees this by topologically sorting the network first.
390void FunctionalReductionSolver::runSimulation() {
392 unsigned numWords = numPatterns / 64;
395 std::mt19937_64 rng(seed);
397 for (
auto input : primaryInputs) {
399 SmallVector<uint64_t> words(numWords);
400 for (
auto &word : words)
404 llvm::APInt
pattern(numPatterns, words);
405 simSignatures[input] =
pattern;
409 for (
auto value : allValues) {
410 if (simSignatures.count(value))
413 simSignatures[value] = simulateValue(value);
417 llvm::dbgs() <<
"FunctionalReduction: Simulation complete with "
418 << numPatterns <<
" patterns\n";
// Compute the simulation signature of `v` from its operands' signatures.
// - AND starts from all-ones, OR and XOR from all-zeros, and each folds its
//   inputs in.
// - BooleanLogicOpInterface ops evaluate themselves through
//   evaluateBooleanLogic.
// - A constant maps to all-zeros or all-ones.
// - Values without a defining op, and unsupported ops (Default case), return
//   the signature already stored for `v`.
422llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
423 Operation *op = v.getDefiningOp();
425 return simSignatures.at(v);
426 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
427 .Case<BooleanLogicOpInterface>([&](
auto op) {
428 return op.evaluateBooleanLogic([&](
unsigned i) ->
const APInt & {
429 return simSignatures.at(op.getInput(i));
432 .Case<comb::AndOp>([&](
auto op) {
433 APInt result = APInt::getAllOnes(numPatterns);
434 for (
auto input : op.getInputs())
435 result &= simSignatures.at(input);
438 .Case<comb::OrOp>([&](
auto op) {
439 APInt result = APInt::getZero(numPatterns);
440 for (
auto input : op.getInputs())
441 result |= simSignatures.at(input);
444 .Case<comb::XorOp>([&](
auto op) {
445 APInt result = APInt::getZero(numPatterns);
446 for (
auto input : op.getInputs())
447 result ^= simSignatures.at(input);
451 return op.getValue().isZero() ? APInt::getZero(numPatterns)
452 : APInt::getAllOnes(numPatterns);
454 .Default([&](Operation *) {
457 return simSignatures.at(v);
// Bucket values by simulation signature so that a value and its complement
// share a bucket. Signatures whose top bit is set (isNegative) are
// complemented, and the member is recorded with its inversion flag.
// Buckets with more than one member become candidate classes; the first
// member is the representative.
// NOTE(review): the loop body at L273-L274 is not visible. Presumably it XORs
// each flag with repInverted, which would make flags relative to the
// representative; verifyCandidates asserts the representative is not
// inverted.
465void FunctionalReductionSolver::buildEquivalenceClasses() {
470 for (
auto value : allValues) {
471 auto signature = simSignatures.at(value);
472 bool inverted =
false;
473 if (signature.isNegative()) {
475 signature.flipAllBits();
477 sigGroups[signature].push_back({value, inverted});
482 for (
auto &[hash, members] : sigGroups) {
483 if (members.size() <= 1)
485 bool repInverted = members.front().second;
486 for (
auto &[_, inv] : members)
488 equivCandidates.push_back(std::move(members));
490 stats.numEquivClasses = equivCandidates.size();
492 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Built "
493 << equivCandidates.size()
494 <<
" equivalence candidates\n");
// Check every non-representative member of each candidate class against the
// class representative.
// Proven members, with their inversion flags, are recorded in
// provenEquivalences under the representative; numProvedEquiv and
// numDisprovedEquiv are updated.
// NOTE(review): the Unknown branch is not visible; presumably it increments
// numUnknown.
504void FunctionalReductionSolver::verifyCandidates() {
506 llvm::dbgs() <<
"FunctionalReduction: Starting SAT verification with "
507 << equivCandidates.size() <<
" equivalence classes\n");
// The first member is the representative; its flag must be false.
509 for (
auto &members : equivCandidates) {
512 auto [representative, repInversion] = members.front();
513 assert(!repInversion &&
"representative must not be inverted");
515 auto &provenMembers = provenEquivalences[representative];
518 for (
auto [member, inverted] :
519 llvm::ArrayRef<std::pair<Value, bool>>(members).drop_front()) {
520 EquivResult result = verifyEquivalence(representative, member, inverted);
521 if (result == EquivResult::Proved) {
522 stats.numProvedEquiv++;
523 provenMembers.push_back({member, inverted});
524 }
else if (result == EquivResult::Disproved) {
525 stats.numDisprovedEquiv++;
535 llvm::dbgs() <<
"FunctionalReduction: SAT verification complete. Proved "
536 << stats.numProvedEquiv <<
" equivalences\n");
// Rewrite the IR for every proven equivalence class.
//
// For each representative, members reachable from it through use chains are
// set aside. Presumably this is because making them choice operands would
// create a cycle once the representative's uses are redirected.
// The remaining ("safe") members become operands of a synth::ChoiceOp,
// together with the representative. A member proven equivalent to the
// complement is added through a fresh aig::AndInverterOp.
//
// All choices are created first; the uses are rewritten afterwards:
//  - uses of the representative and of safe members inside the choice's
//    block are redirected to the choice (presumably to the complemented
//    choice for inverted members);
//  - reachable members are redirected to the choice, and their defining ops
//    are erased once they have no uses left.
543void FunctionalReductionSolver::mergeEquivalentNodes() {
544 if (provenEquivalences.empty())
// Per-member rewrite record; operandInverter is set only for inverted safe
// members.
550 struct PlannedMember {
553 aig::AndInverterOp operandInverter;
// Everything needed to rewrite one class after all choices exist.
555 struct MergeRewritePlan {
556 Value representative;
557 SmallVector<PlannedMember> members;
559 SmallVector<PlannedMember> reachableMembers;
560 synth::ChoiceOp choice;
561 aig::AndInverterOp choiceNot;
564 mlir::OpBuilder builder(module.getContext());
// Replace uses of `from` with `to`, restricted to users in the block that
// defines `to` and accepted by `shouldReplaceOwner`.
565 auto replaceDominatedUses =
566 [](Value from, Value to,
567 llvm::function_ref<bool(Operation *)> shouldReplaceOwner) {
568 auto *defOp = to.getDefiningOp();
569 assert(defOp &&
"replacement value must be defined by an operation");
570 from.replaceUsesWithIf(to, [&](OpOperand &use) {
571 auto *user = use.getOwner();
575 return shouldReplaceOwner(user) &&
576 user->getBlock() == defOp->getBlock();
// Forward reachability over use chains: every result of every transitive
// user of `start` is added to `reachable`.
580 DenseSet<Value> reachable;
581 auto visitFrom = [&](Value start) {
582 SmallVector<Value> stack;
583 stack.push_back(start);
584 while (!stack.empty()) {
585 Value current = stack.pop_back_val();
586 if (!reachable.insert(current).second)
588 for (Operation *user : current.getUsers())
590 for (Value result : user->getResults())
591 stack.push_back(result);
595 SmallVector<MergeRewritePlan> rewritePlans;
596 rewritePlans.reserve(provenEquivalences.size());
// Iterate by reference: each entry holds a SmallVector of members, and a
// by-value loop variable would copy it on every iteration.
597 for (
auto &provenEquivSet : provenEquivalences) {
598 auto &[representative, members] = provenEquivSet;
602 visitFrom(representative);
// Split members into those in the representative's fan-out ("reachable")
// and those that can safely feed the choice.
605 SmallVector<std::pair<Value, bool>> safeMembers;
606 SmallVector<PlannedMember> plannedReachable;
607 for (
auto [member, inverted] : members) {
608 if (reachable.count(member)) {
609 plannedReachable.push_back({member, inverted, {}});
613 safeMembers.push_back({member, inverted});
616 if (safeMembers.empty())
// Insert the choice right after the last safe member.
619 builder.setInsertionPointAfterValue(safeMembers.back().first);
621 SmallVector<Value> operands;
622 operands.reserve(safeMembers.size() + 1);
623 operands.push_back(representative);
625 SmallVector<PlannedMember> plannedMembers;
626 plannedMembers.reserve(safeMembers.size());
627 bool hasInvertedMember =
false;
// Non-inverted members feed the choice directly; inverted members feed it
// through a fresh AndInverterOp.
628 for (
auto [member, inverted] : safeMembers) {
630 plannedMembers.emplace_back(PlannedMember{member, inverted, {}});
632 operands.push_back(member);
635 hasInvertedMember =
true;
638 planned.operandInverter =
639 aig::AndInverterOp::create(builder, member.getLoc(), member,
true);
640 operands.push_back(planned.operandInverter.getResult());
643 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
644 representative.getType(), operands);
// A complemented choice is materialized only when some member is inverted.
648 auto choiceNot = !hasInvertedMember
650 : aig::AndInverterOp::create(builder, choice.getLoc(),
// The representative plus each safe member counts as merged.
653 stats.numMergedNodes += safeMembers.size() + 1;
654 rewritePlans.push_back({representative, std::move(plannedMembers),
655 std::move(plannedReachable), choice, choiceNot});
// Apply the plans now that every choice has been created.
658 for (
auto &plan : rewritePlans) {
659 auto replaceValue = [&](
const PlannedMember &member) {
661 replaceDominatedUses(member.original, plan.choiceNot,
662 [&](Operation *user) {
668 return user != member.operandInverter &&
669 user != plan.choiceNot.getOperation();
672 replaceDominatedUses(member.original, plan.choice,
673 [&](Operation *user) {
674 return user != plan.choice.getOperation();
678 replaceDominatedUses(
679 plan.representative, plan.choice,
680 [&](Operation *user) { return user != plan.choice.getOperation(); });
681 for (
const auto &member : plan.members)
682 replaceValue(member);
// Reachable members: redirect same-block uses to the choice and erase the
// defining op once it has no uses left.
686 for (
auto &member : plan.reachableMembers) {
687 member.original.replaceUsesWithIf(plan.choice, [&](OpOperand &use) {
688 auto *user = use.getOwner();
689 return user->getBlock() == plan.choice->getBlock();
691 if (member.original.use_empty())
692 member.original.getDefiningOp()->erase();
696 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Merged "
697 << stats.numMergedNodes <<
" nodes\n");
// Top-level flow for one module:
//  1. Require a SAT solver unless in test mode.
//  2. Topologically sort the logic network.
//  3. Collect and simulate values; exit early when there are no i1 values.
//  4. Build candidate classes; exit early when there are none.
//  5. Initialize SAT state (outside test mode only), verify, and merge.
//  6. Sort the network again after merging.
// NOTE(review): the calls to collectValues, runSimulation and
// verifyCandidates are on lines not visible here; the sequence above is
// inferred from the surrounding messages.
704mlir::FailureOr<FunctionalReductionSolver::Stats>
705FunctionalReductionSolver::run() {
707 llvm::dbgs() <<
"FunctionalReduction: Starting functional reduction with "
708 << numPatterns <<
" simulation patterns\n");
// Outside test mode, equivalence checking needs a SAT backend.
710 if (!testTransformation && !satSolver) {
712 << "FunctionalReduction requires a SAT solver, but none is "
713 "available in this build";
720 << "FunctionalReduction: Failed to topologically sort logic network";
// Nothing to reduce when the module has no i1 values.
726 if (allValues.empty()) {
727 LLVM_DEBUG(llvm::dbgs()
728 <<
"FunctionalReduction: No i1 values to process\n");
735 buildEquivalenceClasses();
736 if (equivCandidates.empty()) {
737 LLVM_DEBUG(llvm::dbgs()
738 <<
"FunctionalReduction: No equivalence candidates found\n");
// Test mode decides equivalence from attributes and needs no SAT state.
743 if (!testTransformation)
744 initializeSATState();
// Merging inserts choice ops after existing values and redirects earlier
// uses to them; the network is re-sorted afterwards.
748 mergeEquivalentNodes();
753 << "FunctionalReduction: Failed to topologically sort logic network";
757 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Complete. Stats:\n"
758 <<
" Equivalence classes: " << stats.numEquivClasses
760 <<
" Proved: " << stats.numProvedEquiv <<
"\n"
761 <<
" Disproved: " << stats.numDisprovedEquiv <<
"\n"
762 <<
" Unknown (limit): " << stats.numUnknown <<
"\n"
763 <<
" Merged: " << stats.numMergedNodes <<
"\n");
// Pass wrapper around FunctionalReductionSolver. It validates the options,
// creates the SAT backend (skipped in test mode), runs the solver on the
// module, and accumulates the solver's counters into pass statistics.
772struct FunctionalReductionPass
773 :
public circt::synth::impl::FunctionalReductionBase<
774 FunctionalReductionPass> {
775 using FunctionalReductionBase::FunctionalReductionBase;
  // Add one solver run's counters to the pass statistics.
776 void updateStats(
const FunctionalReductionSolver::Stats &stats) {
777 numEquivClasses += stats.numEquivClasses;
778 numProvedEquiv += stats.numProvedEquiv;
779 numDisprovedEquiv += stats.numDisprovedEquiv;
780 numUnknown += stats.numUnknown;
781 numMergedNodes += stats.numMergedNodes;
784 void runOnOperation()
override {
785 auto module = getOperation();
786 LLVM_DEBUG(llvm::dbgs() <<
"Running FunctionalReduction pass on "
787 << module.getName() <<
"\n");
  // Simulation packs patterns into 64-bit words (numPatterns / 64), so the
  // count must be a positive multiple of 64.
789 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
791 << "'num-random-patterns' must be a positive multiple of 64";
792 return signalPassFailure();
  // NOTE(review): -1 is accepted; presumably it means "no conflict limit".
794 if (conflictLimit < -1) {
796 << "'conflict-limit' must be greater than or equal to -1";
797 return signalPassFailure();
  // No SAT backend is created in test mode.
800 std::unique_ptr<IncrementalSATSolver> satSolver;
801 if (!testTransformation) {
802 satSolver = createFunctionalReductionSATSolver(this->satSolver);
804 module.emitError() << "unsupported or unavailable SAT solver '"
806 << "' (expected auto, z3, or cadical)";
807 return signalPassFailure();
809 satSolver->setConflictLimit(
static_cast<int>(conflictLimit));
812 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
814 std::move(satSolver));
815 auto stats = fcSolver.run();
817 return signalPassFailure();
  // NOTE(review): run() topologically sorts the logic network, which may
  // reorder ops even when nothing was merged. Preserving all analyses is only
  // sound if that sort leaves the IR unchanged; confirm.
819 if (stats->numMergedNodes == 0)
820 markAllAnalysesPreserved();
assert(baseType &&"element must be base type")
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
Abstract interface for incremental SAT solvers with an IPASIR-style API.
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
std::unique_ptr< IncrementalSATSolver > createZ3SATSolver()
Construct a Z3-backed incremental IPASIR-style SAT solver.
std::unique_ptr< IncrementalSATSolver > createCadicalSATSolver(const CadicalSATSolverOptions &options={})
Construct a CaDiCaL-backed incremental IPASIR-style SAT solver.
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)