CIRCT 23.0.0git
Loading...
Searching...
No Matches
FunctionalReduction.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements FunctionalReduction (Functionally Reduced And-Inverter
10// Graph) optimization. It identifies and merges functionally equivalent nodes
11// through simulation-based candidate detection followed by SAT-based
12// verification.
13//
14//===----------------------------------------------------------------------===//
15
23#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LogicalResult.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/MapVector.h"
34#include "llvm/ADT/STLFunctionalExtras.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/Debug.h"
39#include <random>
40
#define DEBUG_TYPE "synth-functional-reduction"

/// Attribute consulted only when the pass runs with `testTransformation`
/// enabled: instead of asking a SAT solver, two candidate values are treated
/// as proven equivalent iff both defining ops carry this attribute with the
/// same value; otherwise the query result is Unknown.
static constexpr llvm::StringLiteral kTestClassAttrName =
    "synth.test.fc_equiv_class";
45
46namespace circt {
47namespace synth {
48#define GEN_PASS_DEF_FUNCTIONALREDUCTION
49#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
50} // namespace synth
51} // namespace circt
52
53using namespace circt;
54using namespace circt::synth;
55
56namespace {
/// Outcome of a single equivalence query between two i1 values.
enum class EquivResult {
  Proved,    ///< Both halves of the miter were UNSAT (or test classes match).
  Disproved, ///< The solver found a distinguishing input assignment.
  Unknown    ///< Neither proved nor disproved (e.g. solver gave up).
};
58
59std::unique_ptr<IncrementalSATSolver>
60createFunctionalReductionSATSolver(llvm::StringRef backend) {
61 if (backend == "auto") {
62 if (auto solver = createCadicalSATSolver())
63 return solver;
64 return createZ3SATSolver();
65 }
66 if (backend == "cadical")
68 if (backend == "z3")
69 return createZ3SATSolver();
70 return {};
71}
72
73class FunctionalReductionSATBuilder {
74public:
75 FunctionalReductionSATBuilder(IncrementalSATSolver &solver,
76 llvm::DenseMap<Value, int> &satVars,
77 llvm::DenseSet<Value> &encodedValues);
78
79 // If inverted, negates rhs in the SAT encoding to check lhs == NOT(rhs).
80 EquivResult verify(Value lhs, Value rhs, bool inverted);
81
82private:
83 int getOrCreateVar(Value value);
84 // Create a fresh SAT variable for an intermediate Boolean subexpression that
85 // does not correspond to an MLIR value.
86 int createAuxVar();
87 SmallVector<int> getOperandVars(ValueRange operands);
88 void encodeValue(Value value);
89
91 llvm::DenseMap<Value, int> &satVars;
92 llvm::DenseSet<Value> &encodedValues;
93};
94
95static bool isFunctionalReductionSimulatableOp(Operation *op) {
96 return isa<BooleanLogicOpInterface, comb::AndOp, comb::OrOp, comb::XorOp>(op);
97}
98
99EquivResult FunctionalReductionSATBuilder::verify(Value lhs, Value rhs,
100 bool inverted) {
101 encodeValue(lhs);
102 encodeValue(rhs);
103
104 int lhsVar = getOrCreateVar(lhs);
105 int rhsVar = getOrCreateVar(rhs);
106
107 if (inverted)
108 rhsVar = -rhsVar;
109 // Check the two halves of the XOR miter separately. If either assignment is
110 // satisfiable, the solver found a distinguishing input pattern.
111 solver.assume(lhsVar);
112 solver.assume(-rhsVar);
113 auto result = solver.solve();
114 if (result == IncrementalSATSolver::kSAT)
115 return EquivResult::Disproved;
116 if (result != IncrementalSATSolver::kUNSAT)
117 return EquivResult::Unknown;
118
119 solver.assume(-lhsVar);
120 solver.assume(rhsVar);
121 result = solver.solve();
122 if (result == IncrementalSATSolver::kSAT)
123 return EquivResult::Disproved;
124 if (result != IncrementalSATSolver::kUNSAT)
125 return EquivResult::Unknown;
126
127 return EquivResult::Proved;
128}
129
130int FunctionalReductionSATBuilder::getOrCreateVar(Value value) {
131 auto it = satVars.find(value);
132 assert(it != satVars.end() && "SAT variable must be preallocated");
133 return it->second;
134}
135
136int FunctionalReductionSATBuilder::createAuxVar() { return solver.newVar(); }
137
138SmallVector<int>
139FunctionalReductionSATBuilder::getOperandVars(ValueRange operands) {
140 SmallVector<int> vars;
141 vars.reserve(operands.size());
142 for (auto operand : operands)
143 vars.push_back(getOrCreateVar(operand));
144 return vars;
145}
146
147void FunctionalReductionSATBuilder::encodeValue(Value value) {
148 SmallVector<std::pair<Value, bool>> worklist;
149 worklist.push_back({value, false});
150
151 while (!worklist.empty()) {
152 auto [current, readyToEncode] = worklist.pop_back_val();
153 if (encodedValues.contains(current))
154 continue;
155
156 Operation *op = current.getDefiningOp();
157 if (!op) {
158 encodedValues.insert(current);
159 continue;
160 }
161
162 APInt constantValue;
163 if (matchPattern(current, mlir::m_ConstantInt(&constantValue))) {
164 encodedValues.insert(current);
165 solver.addClause({constantValue.isZero() ? -getOrCreateVar(current)
166 : getOrCreateVar(current)});
167 continue;
168 }
169
170 if (!isFunctionalReductionSimulatableOp(op)) {
171 // Unsupported operations remain unconstrained, just like block
172 // arguments. Since we only prove equivalence from UNSAT, omitting these
173 // clauses may miss a proof but cannot create a false proof.
174 encodedValues.insert(current);
175 continue;
176 }
177
178 if (!readyToEncode) {
179 worklist.push_back({current, true});
180 for (auto input : op->getOperands()) {
181 assert(input.getType().isInteger(1) &&
182 "only i1 inputs should be simulated or encoded");
183 if (!encodedValues.contains(input))
184 worklist.push_back({input, false});
185 }
186 continue;
187 }
188
189 encodedValues.insert(current);
190 int outVar = getOrCreateVar(current);
191 auto addClause = [&](llvm::ArrayRef<int> clause) {
192 solver.addClause(clause);
193 };
194
195 TypeSwitch<Operation *>(op)
196 .Case<BooleanLogicOpInterface>([&](auto logicOp) {
197 auto inputVars = getOperandVars(logicOp.getInputs());
198 logicOp.emitCNF(outVar, inputVars, addClause,
199 [&]() { return createAuxVar(); });
200 })
201 .Case<comb::AndOp>([&](auto andOp) {
202 auto inputLits = getOperandVars(andOp.getInputs());
203 circt::addAndClauses(outVar, inputLits, addClause);
204 })
205 .Case<comb::OrOp>([&](auto orOp) {
206 auto inputLits = getOperandVars(orOp.getInputs());
207 circt::addOrClauses(outVar, inputLits, addClause);
208 })
209 .Case<comb::XorOp>([&](auto xorOp) {
210 auto inputLits = getOperandVars(xorOp.getInputs());
211 circt::addParityClauses(outVar, inputLits, addClause,
212 [&]() { return createAuxVar(); });
213 })
214 .Default(
215 [](Operation *) { llvm_unreachable("unexpected supported op"); });
216 }
217}
218
219//===----------------------------------------------------------------------===//
220// Core Functional Reduction Implementation
221//===----------------------------------------------------------------------===//
222
223class FunctionalReductionSolver {
224public:
225 FunctionalReductionSolver(hw::HWModuleOp module, unsigned numPatterns,
226 unsigned seed, bool testTransformation,
227 std::unique_ptr<IncrementalSATSolver> satSolver)
228 : module(module), numPatterns(numPatterns), seed(seed),
229 testTransformation(testTransformation),
230 satSolver(std::move(satSolver)) {}
231
232 ~FunctionalReductionSolver() = default;
233
234 /// Run the Functional Reduction algorithm and return statistics.
235 struct Stats {
236 unsigned numEquivClasses = 0;
237 unsigned numProvedEquiv = 0;
238 unsigned numDisprovedEquiv = 0;
239 unsigned numUnknown = 0;
240 unsigned numMergedNodes = 0;
241 };
242 mlir::FailureOr<Stats> run();
243
244private:
245 // Phase 1: Collect i1 values and run simulation
246 void collectValues();
247 void runSimulation();
248 llvm::APInt simulateValue(Value v);
249
250 // Phase 2: Build equivalence classes from simulation
251 void buildEquivalenceClasses();
252
253 // Phase 3: SAT-based verification with per-class solver
254 void verifyCandidates();
255 void initializeSATState();
256
257 // Phase 4: Merge equivalent nodes
258 void mergeEquivalentNodes();
259
260 // Test transformation helpers.
261 static Attribute getTestEquivClass(Value value);
262 static bool matchesTestEquivClass(Value lhs, Value rhs);
263 EquivResult verifyEquivalence(Value lhs, Value rhs, bool inverted);
264
265 // Module being processed
266 hw::HWModuleOp module;
267
268 // Configuration
269 unsigned numPatterns;
270 unsigned seed;
271 bool testTransformation;
272
273 // Primary inputs (block arguments or results of unknown operations treated as
274 // inputs)
275 SmallVector<Value> primaryInputs;
276
277 // All i1 values in topological order
278 SmallVector<Value> allValues;
279
280 // Simulation signatures: value -> APInt simulation result
281 llvm::DenseMap<Value, llvm::APInt> simSignatures;
282
283 // Equivalence candidates: groups of values with identical or inverted
284 // simulation signatures, tracked with an inversion flag
285 SmallVector<SmallVector<std::pair<Value, bool>>> equivCandidates;
286
287 // Proven equivalences: representative -> proven equivalent members with
288 // inversion flag indicating whether the member is inverted relative to
289 // representative
291 provenEquivalences;
292
293 std::unique_ptr<IncrementalSATSolver> satSolver;
294 std::unique_ptr<FunctionalReductionSATBuilder> satBuilder;
295 llvm::DenseMap<Value, int> satVars;
296 llvm::DenseSet<Value> encodedValues;
297 Stats stats;
298};
299
300FunctionalReductionSATBuilder::FunctionalReductionSATBuilder(
301 IncrementalSATSolver &solver, llvm::DenseMap<Value, int> &satVars,
302 llvm::DenseSet<Value> &encodedValues)
303 : solver(solver), satVars(satVars), encodedValues(encodedValues) {}
304
305Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
306 Operation *op = value.getDefiningOp();
307 if (!op)
308 return {};
309 return op->getAttr(kTestClassAttrName);
310}
311
312bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
313 Attribute lhsClass = getTestEquivClass(lhs);
314 Attribute rhsClass = getTestEquivClass(rhs);
315 return lhsClass && rhsClass && lhsClass == rhsClass;
316}
317
318EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs,
319 bool inverted) {
320
321 if (testTransformation) {
322 if (matchesTestEquivClass(lhs, rhs))
323 return EquivResult::Proved;
324 return EquivResult::Unknown;
325 }
326 assert(satBuilder && "SAT builder must be initialized before verification");
327 // SAT-based equivalence checking builds a miter for the two candidate nodes
328 // and proves that no input assignment can make them differ.
329 return satBuilder->verify(lhs, rhs, inverted);
330}
331
332void FunctionalReductionSolver::initializeSATState() {
333 assert(satSolver && "SAT solver must be initialized before SAT state setup");
334
335 satVars.clear();
336 encodedValues.clear();
337 satVars.reserve(allValues.size());
338 for (auto [index, value] : llvm::enumerate(allValues))
339 satVars[value] = index + 1;
340 satSolver->reserveVars(allValues.size());
341
342 satBuilder = std::make_unique<FunctionalReductionSATBuilder>(
343 *satSolver, satVars, encodedValues);
344}
345
346//===----------------------------------------------------------------------===//
347// Phase 1: Collect values and run simulation
348//===----------------------------------------------------------------------===//
349
350void FunctionalReductionSolver::collectValues() {
351 // Collect block arguments (primary inputs) that are i1
352 for (auto arg : module.getBodyBlock()->getArguments()) {
353 if (arg.getType().isInteger(1)) {
354 primaryInputs.push_back(arg);
355 allValues.push_back(arg);
356 }
357 }
358
359 // Walk operations and collect i1 results
360 // - AIG operations: add to allValues for simulation
361 // - Unknown operations: treat as inputs (assign random patterns)
362 module.walk([&](Operation *op) {
363 for (auto result : op->getResults()) {
364 if (!result.getType().isInteger(1))
365 continue;
366
367 allValues.push_back(result);
368 if (!op->hasTrait<OpTrait::ConstantLike>() &&
369 !isFunctionalReductionSimulatableOp(op)) {
370 // Unknown operations - treat as primary inputs
371 primaryInputs.push_back(result);
372 }
373 }
374 });
375
376 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Collected "
377 << primaryInputs.size()
378 << " primary inputs (including unknown ops) and "
379 << allValues.size() << " total i1 values\n");
380}
381
382void FunctionalReductionSolver::runSimulation() {
383 // Calculate number of 64-bit words needed for numPatterns bits
384 unsigned numWords = numPatterns / 64;
385
386 // Create seeded random number generator for deterministic patterns
387 std::mt19937_64 rng(seed);
388
389 for (auto input : primaryInputs) {
390 // Generate random words using seeded RNG
391 SmallVector<uint64_t> words(numWords);
392 for (auto &word : words)
393 word = rng();
394
395 // Construct APInt directly from words
396 llvm::APInt pattern(numPatterns, words);
397 simSignatures[input] = pattern;
398 }
399
400 // Propagate simulation through the circuit in topological order
401 for (auto value : allValues) {
402 if (simSignatures.count(value))
403 continue; // Already computed (primary input)
404
405 simSignatures[value] = simulateValue(value);
406 }
407
408 LLVM_DEBUG({
409 llvm::dbgs() << "FunctionalReduction: Simulation complete with "
410 << numPatterns << " patterns\n";
411 });
412}
413
414llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
415 Operation *op = v.getDefiningOp();
416 if (!op)
417 return simSignatures.at(v);
418 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
419 .Case<BooleanLogicOpInterface>([&](auto op) {
420 return op.evaluateBooleanLogic([&](unsigned i) -> const APInt & {
421 return simSignatures.at(op.getInput(i));
422 });
423 })
424 .Case<comb::AndOp>([&](auto op) {
425 APInt result = APInt::getAllOnes(numPatterns);
426 for (auto input : op.getInputs())
427 result &= simSignatures.at(input);
428 return result;
429 })
430 .Case<comb::OrOp>([&](auto op) {
431 APInt result = APInt::getZero(numPatterns);
432 for (auto input : op.getInputs())
433 result |= simSignatures.at(input);
434 return result;
435 })
436 .Case<comb::XorOp>([&](auto op) {
437 APInt result = APInt::getZero(numPatterns);
438 for (auto input : op.getInputs())
439 result ^= simSignatures.at(input);
440 return result;
441 })
442 .Case([&](hw::ConstantOp op) {
443 return op.getValue().isZero() ? APInt::getZero(numPatterns)
444 : APInt::getAllOnes(numPatterns);
445 })
446 .Default([&](Operation *) {
447 // Unknown operation - treat as input (already assigned a random
448 // pattern)
449 return simSignatures.at(v);
450 });
451}
452
453//===----------------------------------------------------------------------===//
454// Phase 2: Build equivalence classes from simulation
455//===----------------------------------------------------------------------===//
456
457void FunctionalReductionSolver::buildEquivalenceClasses() {
458 // Map from canonical signature to list of {value, inverted pairs}
459 // Inverted signals share the same canonical signature since inversion
460 // is zero cost in synthesis
462 for (auto value : allValues) {
463 auto signature = simSignatures.at(value);
464 bool inverted = false;
465 if (signature.isNegative()) {
466 inverted = true;
467 signature.flipAllBits();
468 }
469 sigGroups[signature].push_back({value, inverted});
470 }
471
472 // Build equivalence candidates for groups with >1 member.
473 // Re-normalize so inverted is relative to representative (first member)
474 for (auto &[hash, members] : sigGroups) {
475 if (members.size() <= 1)
476 continue;
477 bool repInverted = members.front().second;
478 for (auto &[_, inv] : members)
479 inv ^= repInverted;
480 equivCandidates.push_back(std::move(members));
481 }
482 stats.numEquivClasses = equivCandidates.size();
483
484 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Built "
485 << equivCandidates.size()
486 << " equivalence candidates\n");
487}
488
489//===----------------------------------------------------------------------===//
// Phase 3: SAT-based verification
//
// For each equivalence class candidate, verify each member against the
// representative using a single incremental SAT solver shared by all classes.
494//===----------------------------------------------------------------------===//
495
496void FunctionalReductionSolver::verifyCandidates() {
497 LLVM_DEBUG(
498 llvm::dbgs() << "FunctionalReduction: Starting SAT verification with "
499 << equivCandidates.size() << " equivalence classes\n");
500
501 for (auto &members : equivCandidates) {
502 if (members.empty())
503 continue;
504 auto [representative, repInversion] = members.front();
505 assert(!repInversion && "representative must not be inverted");
506 (void)repInversion;
507 auto &provenMembers = provenEquivalences[representative];
508 // Representative is the canonical node for this class. Members can be
509 // inverted relative to the representative, tracked by the inversion flag
510 for (auto [member, inverted] :
511 llvm::ArrayRef<std::pair<Value, bool>>(members).drop_front()) {
512 EquivResult result = verifyEquivalence(representative, member, inverted);
513 if (result == EquivResult::Proved) {
514 stats.numProvedEquiv++;
515 provenMembers.push_back({member, inverted});
516 } else if (result == EquivResult::Disproved) {
517 stats.numDisprovedEquiv++;
518 // TODO: Refine equivalence classes based on counterexamples from SAT
519 // solver
520 } else {
521 stats.numUnknown++;
522 }
523 }
524 }
525
526 LLVM_DEBUG(
527 llvm::dbgs() << "FunctionalReduction: SAT verification complete. Proved "
528 << stats.numProvedEquiv << " equivalences\n");
529}
530
531//===----------------------------------------------------------------------===//
532// Phase 4: Merge equivalent nodes
533//===----------------------------------------------------------------------===//
534
535void FunctionalReductionSolver::mergeEquivalentNodes() {
536 if (provenEquivalences.empty())
537 return;
538
539 // Build all replacement IR first, then perform use rewrites in a second
540 // phase. This keeps `isBeforeInBlock` queries anchored to the final block
541 // order instead of an order that is still being mutated by insertion.
542 struct PlannedMember {
543 Value original;
544 bool inverted;
545 aig::AndInverterOp operandInverter;
546 };
547 struct MergeRewritePlan {
548 Value representative;
549 SmallVector<PlannedMember> members;
550 // Members which are at risk of reaching their representative
551 SmallVector<PlannedMember> reachableMembers;
552 synth::ChoiceOp choice;
553 aig::AndInverterOp choiceNot;
554 };
555
556 mlir::OpBuilder builder(module.getContext());
557 auto replaceDominatedUses =
558 [](Value from, Value to,
559 llvm::function_ref<bool(Operation *)> shouldReplaceOwner) {
560 auto *defOp = to.getDefiningOp();
561 assert(defOp && "replacement value must be defined by an operation");
562 from.replaceUsesWithIf(to, [&](OpOperand &use) {
563 auto *user = use.getOwner();
564 // Restrict rewrites to uses after the replacement value's definition
565 // in the same block so merging cannot introduce use-before-def edges
566 // or SSA cycles.
567 return shouldReplaceOwner(user) &&
568 user->getBlock() == defOp->getBlock();
569 });
570 };
571
572 DenseSet<Value> reachable;
573 auto visitFrom = [&](Value start) {
574 SmallVector<Value> stack;
575 stack.push_back(start);
576 while (!stack.empty()) {
577 Value current = stack.pop_back_val();
578 if (!reachable.insert(current).second)
579 continue;
580 for (Operation *user : current.getUsers())
581 if (isLogicNetworkOp(user))
582 for (Value result : user->getResults())
583 stack.push_back(result);
584 }
585 };
586
587 SmallVector<MergeRewritePlan> rewritePlans;
588 rewritePlans.reserve(provenEquivalences.size());
589 for (auto provenEquivSet : provenEquivalences) {
590 auto &[representative, members] = provenEquivSet;
591 if (members.empty())
592 continue;
593 // Mark all values reachable from representative before checking members.
594 visitFrom(representative);
595
596 // Greedily filter for members that can create a cycle with representative
597 SmallVector<std::pair<Value, bool>> safeMembers;
598 SmallVector<PlannedMember> plannedReachable;
599 for (auto [member, inverted] : members) {
600 if (reachable.count(member)) {
601 plannedReachable.push_back({member, inverted, {}});
602 continue;
603 }
604 visitFrom(member); // Visit users
605 safeMembers.push_back({member, inverted});
606 }
607
608 if (safeMembers.empty())
609 continue;
610
611 builder.setInsertionPointAfterValue(safeMembers.back().first);
612
613 SmallVector<Value> operands;
614 operands.reserve(safeMembers.size() + 1);
615 operands.push_back(representative);
616
617 SmallVector<PlannedMember> plannedMembers;
618 plannedMembers.reserve(safeMembers.size());
619 bool hasInvertedMember = false;
620 for (auto [member, inverted] : safeMembers) {
621 auto &planned =
622 plannedMembers.emplace_back(PlannedMember{member, inverted, {}});
623 if (!inverted) {
624 operands.push_back(member);
625 continue;
626 }
627 hasInvertedMember = true;
628 // If the member is inverted relative to the representative, we
629 // create an inverter for the choice operand
630 planned.operandInverter =
631 aig::AndInverterOp::create(builder, member.getLoc(), member, true);
632 operands.push_back(planned.operandInverter.getResult());
633 }
634
635 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
636 representative.getType(), operands);
637
638 // If there is an inverted member, we need to create an inverter for the
639 // choice result as well
640 auto choiceNot = !hasInvertedMember
641 ? nullptr
642 : aig::AndInverterOp::create(builder, choice.getLoc(),
643 choice, true);
644
645 stats.numMergedNodes += safeMembers.size() + 1;
646 rewritePlans.push_back({representative, std::move(plannedMembers),
647 std::move(plannedReachable), choice, choiceNot});
648 }
649
650 for (auto &plan : rewritePlans) {
651 auto replaceValue = [&](const PlannedMember &member) {
652 if (member.inverted)
653 replaceDominatedUses(member.original, plan.choiceNot,
654 [&](Operation *user) {
655 // Do not rewrite the freshly created operand
656 // inverter or the choice result inverter. This
657 // avoids creating an immediate cycle when
658 // merging an inverted node into its
659 // representative.
660 return user != member.operandInverter &&
661 user != plan.choiceNot.getOperation();
662 });
663 else
664 replaceDominatedUses(member.original, plan.choice,
665 [&](Operation *user) {
666 return user != plan.choice.getOperation();
667 });
668 };
669
670 replaceDominatedUses(
671 plan.representative, plan.choice,
672 [&](Operation *user) { return user != plan.choice.getOperation(); });
673 for (const auto &member : plan.members)
674 replaceValue(member);
675
676 // Reachable members are redundant here so either replace their uses with
677 // choice or erase if they have no uses left.
678 for (auto &member : plan.reachableMembers) {
679 member.original.replaceUsesWithIf(plan.choice, [&](OpOperand &use) {
680 auto *user = use.getOwner();
681 return user->getBlock() == plan.choice->getBlock();
682 });
683 if (member.original.use_empty())
684 member.original.getDefiningOp()->erase();
685 }
686 }
687
688 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Merged "
689 << stats.numMergedNodes << " nodes\n");
690}
691
692//===----------------------------------------------------------------------===//
693// Main Functional Reduction algorithm
694//===----------------------------------------------------------------------===//
695
696mlir::FailureOr<FunctionalReductionSolver::Stats>
697FunctionalReductionSolver::run() {
698 LLVM_DEBUG(
699 llvm::dbgs() << "FunctionalReduction: Starting functional reduction with "
700 << numPatterns << " simulation patterns\n");
701
702 if (!testTransformation && !satSolver) {
703 module->emitError()
704 << "FunctionalReduction requires a SAT solver, but none is "
705 "available in this build";
706 return failure();
707 }
708
709 // Topologically sort the values
711 module->emitError()
712 << "FunctionalReduction: Failed to topologically sort logic network";
713 return failure();
714 }
715
716 // Phase 1: Collect values and run simulation
717 collectValues();
718 if (allValues.empty()) {
719 LLVM_DEBUG(llvm::dbgs()
720 << "FunctionalReduction: No i1 values to process\n");
721 return stats;
722 }
723
724 runSimulation();
725
726 // Phase 2: Build equivalence classes
727 buildEquivalenceClasses();
728 if (equivCandidates.empty()) {
729 LLVM_DEBUG(llvm::dbgs()
730 << "FunctionalReduction: No equivalence candidates found\n");
731 return stats;
732 }
733
734 // Phase 3: SAT-based verification
735 if (!testTransformation)
736 initializeSATState();
737 verifyCandidates();
738
739 // Phase 4: Merge equivalent nodes
740 mergeEquivalentNodes();
741
742 // Re-sort after merging to restore topological order after choice insertion.
744 module->emitError()
745 << "FunctionalReduction: Failed to topologically sort logic network";
746 return failure();
747 }
748
749 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Complete. Stats:\n"
750 << " Equivalence classes: " << stats.numEquivClasses
751 << "\n"
752 << " Proved: " << stats.numProvedEquiv << "\n"
753 << " Disproved: " << stats.numDisprovedEquiv << "\n"
754 << " Unknown (limit): " << stats.numUnknown << "\n"
755 << " Merged: " << stats.numMergedNodes << "\n");
756
757 return stats;
758}
759
760//===----------------------------------------------------------------------===//
761// Pass implementation
762//===----------------------------------------------------------------------===//
763
764struct FunctionalReductionPass
765 : public circt::synth::impl::FunctionalReductionBase<
766 FunctionalReductionPass> {
767 using FunctionalReductionBase::FunctionalReductionBase;
768 void updateStats(const FunctionalReductionSolver::Stats &stats) {
769 numEquivClasses += stats.numEquivClasses;
770 numProvedEquiv += stats.numProvedEquiv;
771 numDisprovedEquiv += stats.numDisprovedEquiv;
772 numUnknown += stats.numUnknown;
773 numMergedNodes += stats.numMergedNodes;
774 }
775
776 void runOnOperation() override {
777 auto module = getOperation();
778 LLVM_DEBUG(llvm::dbgs() << "Running FunctionalReduction pass on "
779 << module.getName() << "\n");
780
781 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
782 module.emitError()
783 << "'num-random-patterns' must be a positive multiple of 64";
784 return signalPassFailure();
785 }
786 if (conflictLimit < -1) {
787 module.emitError()
788 << "'conflict-limit' must be greater than or equal to -1";
789 return signalPassFailure();
790 }
791
792 std::unique_ptr<IncrementalSATSolver> satSolver;
793 if (!testTransformation) {
794 satSolver = createFunctionalReductionSATSolver(this->satSolver);
795 if (!satSolver) {
796 module.emitError() << "unsupported or unavailable SAT solver '"
797 << this->satSolver
798 << "' (expected auto, z3, or cadical)";
799 return signalPassFailure();
800 }
801 satSolver->setConflictLimit(static_cast<int>(conflictLimit));
802 }
803
804 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
805 testTransformation,
806 std::move(satSolver));
807 auto stats = fcSolver.run();
808 if (failed(stats))
809 return signalPassFailure();
810 updateStats(*stats);
811 if (stats->numMergedNodes == 0)
812 markAllAnalysesPreserved();
813 }
814};
815
816} // namespace
assert(baseType &&"element must be base type")
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
Abstract interface for incremental SAT solvers with an IPASIR-style API.
Definition SATSolver.h:158
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
std::unique_ptr< IncrementalSATSolver > createZ3SATSolver()
Construct a Z3-backed incremental IPASIR-style SAT solver.
std::unique_ptr< IncrementalSATSolver > createCadicalSATSolver(const CadicalSATSolverOptions &options={})
Construct a CaDiCaL-backed incremental IPASIR-style SAT solver.
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:879
Definition synth.py:1