CIRCT 23.0.0git
Loading...
Searching...
No Matches
FunctionalReduction.cpp
Go to the documentation of this file.
1//===----------------------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass implements FunctionalReduction (Functionally Reduced And-Inverter
10// Graph) optimization. It identifies and merges functionally equivalent nodes
11// through simulation-based candidate detection followed by SAT-based
12// verification.
13//
14//===----------------------------------------------------------------------===//
15
23#include "mlir/IR/Attributes.h"
24#include "mlir/IR/Builders.h"
25#include "mlir/IR/BuiltinOps.h"
26#include "mlir/IR/PatternMatch.h"
27#include "mlir/Pass/Pass.h"
28#include "mlir/Support/LogicalResult.h"
29#include "llvm/ADT/APInt.h"
30#include "llvm/ADT/ArrayRef.h"
31#include "llvm/ADT/DenseMap.h"
32#include "llvm/ADT/DenseSet.h"
33#include "llvm/ADT/MapVector.h"
34#include "llvm/ADT/STLFunctionalExtras.h"
35#include "llvm/ADT/SmallVector.h"
36#include "llvm/ADT/StringRef.h"
37#include "llvm/ADT/TypeSwitch.h"
38#include "llvm/Support/Debug.h"
39#include <random>
40
41#define DEBUG_TYPE "synth-functional-reduction"
42
43static constexpr llvm::StringLiteral kTestClassAttrName =
44 "synth.test.fc_equiv_class";
45
46namespace circt {
47namespace synth {
48#define GEN_PASS_DEF_FUNCTIONALREDUCTION
49#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
50} // namespace synth
51} // namespace circt
52
53using namespace circt;
54using namespace circt::synth;
55
56namespace {
/// Outcome of a single equivalence query between two candidate values.
enum class EquivResult {
  /// No distinguishing input exists (UNSAT miter, or matching test class).
  Proved,
  /// A distinguishing input assignment was found (SAT miter).
  Disproved,
  /// The query was inconclusive (the solver returned neither SAT nor UNSAT,
  /// presumably because a resource limit such as the conflict limit was hit).
  Unknown
};
58
59std::unique_ptr<IncrementalSATSolver>
60createFunctionalReductionSATSolver(llvm::StringRef backend) {
61 if (backend == "auto") {
62 if (auto solver = createCadicalSATSolver())
63 return solver;
64 return createZ3SATSolver();
65 }
66 if (backend == "cadical")
68 if (backend == "z3")
69 return createZ3SATSolver();
70 return {};
71}
72
73class FunctionalReductionSATBuilder {
74public:
75 FunctionalReductionSATBuilder(IncrementalSATSolver &solver,
76 llvm::DenseMap<Value, int> &satVars,
77 llvm::DenseSet<Value> &encodedValues);
78
79 // If inverted, negates rhs in the SAT encoding to check lhs == NOT(rhs).
80 EquivResult verify(Value lhs, Value rhs, bool inverted);
81
82private:
83 int getOrCreateVar(Value value);
84 // Create a fresh SAT variable for an intermediate Boolean subexpression that
85 // does not correspond to an MLIR value.
86 int createAuxVar();
87 SmallVector<int> getOperandVars(ValueRange operands);
88 void encodeValue(Value value);
89
91 llvm::DenseMap<Value, int> &satVars;
92 llvm::DenseSet<Value> &encodedValues;
93};
94
95static bool isFunctionalReductionSimulatableOp(Operation *op) {
96 return isa<BooleanLogicOpInterface, comb::AndOp, comb::OrOp, comb::XorOp>(op);
97}
98
99EquivResult FunctionalReductionSATBuilder::verify(Value lhs, Value rhs,
100 bool inverted) {
101 encodeValue(lhs);
102 encodeValue(rhs);
103
104 int lhsVar = getOrCreateVar(lhs);
105 int rhsVar = getOrCreateVar(rhs);
106
107 if (inverted)
108 rhsVar = -rhsVar;
109 // Check the two halves of the XOR miter separately. If either assignment is
110 // satisfiable, the solver found a distinguishing input pattern.
111 solver.assume(lhsVar);
112 solver.assume(-rhsVar);
113 auto result = solver.solve();
114 if (result == IncrementalSATSolver::kSAT)
115 return EquivResult::Disproved;
116 if (result != IncrementalSATSolver::kUNSAT)
117 return EquivResult::Unknown;
118
119 solver.assume(-lhsVar);
120 solver.assume(rhsVar);
121 result = solver.solve();
122 if (result == IncrementalSATSolver::kSAT)
123 return EquivResult::Disproved;
124 if (result != IncrementalSATSolver::kUNSAT)
125 return EquivResult::Unknown;
126
127 return EquivResult::Proved;
128}
129
130int FunctionalReductionSATBuilder::getOrCreateVar(Value value) {
131 auto it = satVars.find(value);
132 assert(it != satVars.end() && "SAT variable must be preallocated");
133 return it->second;
134}
135
136int FunctionalReductionSATBuilder::createAuxVar() { return solver.newVar(); }
137
138SmallVector<int>
139FunctionalReductionSATBuilder::getOperandVars(ValueRange operands) {
140 SmallVector<int> vars;
141 vars.reserve(operands.size());
142 for (auto operand : operands)
143 vars.push_back(getOrCreateVar(operand));
144 return vars;
145}
146
147void FunctionalReductionSATBuilder::encodeValue(Value value) {
148 SmallVector<std::pair<Value, bool>> worklist;
149 worklist.push_back({value, false});
150
151 while (!worklist.empty()) {
152 auto [current, readyToEncode] = worklist.pop_back_val();
153 if (encodedValues.contains(current))
154 continue;
155
156 Operation *op = current.getDefiningOp();
157 if (!op) {
158 encodedValues.insert(current);
159 continue;
160 }
161
162 APInt constantValue;
163 if (matchPattern(current, mlir::m_ConstantInt(&constantValue))) {
164 encodedValues.insert(current);
165 solver.addClause({constantValue.isZero() ? -getOrCreateVar(current)
166 : getOrCreateVar(current)});
167 continue;
168 }
169
170 if (!isFunctionalReductionSimulatableOp(op)) {
171 // Unsupported operations remain unconstrained, just like block
172 // arguments. Since we only prove equivalence from UNSAT, omitting these
173 // clauses may miss a proof but cannot create a false proof.
174 encodedValues.insert(current);
175 continue;
176 }
177
178 if (!readyToEncode) {
179 worklist.push_back({current, true});
180 for (auto input : op->getOperands()) {
181 assert(input.getType().isInteger(1) &&
182 "only i1 inputs should be simulated or encoded");
183 if (!encodedValues.contains(input))
184 worklist.push_back({input, false});
185 }
186 continue;
187 }
188
189 encodedValues.insert(current);
190 int outVar = getOrCreateVar(current);
191 auto addClause = [&](llvm::ArrayRef<int> clause) {
192 solver.addClause(clause);
193 };
194
195 TypeSwitch<Operation *>(op)
196 .Case<BooleanLogicOpInterface>([&](auto logicOp) {
197 auto inputVars = getOperandVars(logicOp.getInputs());
198 logicOp.emitCNF(outVar, inputVars, addClause,
199 [&]() { return createAuxVar(); });
200 })
201 .Case<comb::AndOp>([&](auto andOp) {
202 auto inputLits = getOperandVars(andOp.getInputs());
203 circt::addAndClauses(outVar, inputLits, addClause);
204 })
205 .Case<comb::OrOp>([&](auto orOp) {
206 auto inputLits = getOperandVars(orOp.getInputs());
207 circt::addOrClauses(outVar, inputLits, addClause);
208 })
209 .Case<comb::XorOp>([&](auto xorOp) {
210 auto inputLits = getOperandVars(xorOp.getInputs());
211 circt::addParityClauses(outVar, inputLits, addClause,
212 [&]() { return createAuxVar(); });
213 })
214 .Default(
215 [](Operation *) { llvm_unreachable("unexpected supported op"); });
216 }
217}
218
219//===----------------------------------------------------------------------===//
220// Core Functional Reduction Implementation
221//===----------------------------------------------------------------------===//
222
223class FunctionalReductionSolver {
224public:
225 FunctionalReductionSolver(hw::HWModuleOp module, unsigned numPatterns,
226 unsigned seed, bool testTransformation,
227 std::unique_ptr<IncrementalSATSolver> satSolver)
228 : module(module), numPatterns(numPatterns), seed(seed),
229 testTransformation(testTransformation),
230 satSolver(std::move(satSolver)) {}
231
232 ~FunctionalReductionSolver() = default;
233
234 /// Run the Functional Reduction algorithm and return statistics.
235 struct Stats {
236 unsigned numEquivClasses = 0;
237 unsigned numProvedEquiv = 0;
238 unsigned numDisprovedEquiv = 0;
239 unsigned numUnknown = 0;
240 unsigned numMergedNodes = 0;
241 };
242 mlir::FailureOr<Stats> run();
243
244private:
245 // Phase 1: Collect i1 values and run simulation
246 void collectValues();
247 void runSimulation();
248 llvm::APInt simulateValue(Value v);
249
250 // Phase 2: Build equivalence classes from simulation
251 void buildEquivalenceClasses();
252
253 // Phase 3: SAT-based verification with per-class solver
254 void verifyCandidates();
255 void initializeSATState();
256
257 // Phase 4: Merge equivalent nodes
258 void mergeEquivalentNodes();
259
260 // Test transformation helpers.
261 static Attribute getTestEquivClass(Value value);
262 static bool matchesTestEquivClass(Value lhs, Value rhs);
263 EquivResult verifyEquivalence(Value lhs, Value rhs, bool inverted);
264
265 // Module being processed
266 hw::HWModuleOp module;
267
268 // Configuration
269 unsigned numPatterns;
270 unsigned seed;
271 bool testTransformation;
272
273 // Primary inputs (block arguments or results of unknown operations treated as
274 // inputs)
275 SmallVector<Value> primaryInputs;
276
277 // All i1 values in topological order
278 SmallVector<Value> allValues;
279
280 // Simulation signatures: value -> APInt simulation result
281 llvm::DenseMap<Value, llvm::APInt> simSignatures;
282
283 // Equivalence candidates: groups of values with identical or inverted
284 // simulation signatures, tracked with an inversion flag
285 SmallVector<SmallVector<std::pair<Value, bool>>> equivCandidates;
286
287 // Proven equivalences: representative -> proven equivalent members with
288 // inversion flag indicating whether the member is inverted relative to
289 // representative
291 provenEquivalences;
292
293 std::unique_ptr<IncrementalSATSolver> satSolver;
294 std::unique_ptr<FunctionalReductionSATBuilder> satBuilder;
295 llvm::DenseMap<Value, int> satVars;
296 llvm::DenseSet<Value> encodedValues;
297 Stats stats;
298};
299
300FunctionalReductionSATBuilder::FunctionalReductionSATBuilder(
301 IncrementalSATSolver &solver, llvm::DenseMap<Value, int> &satVars,
302 llvm::DenseSet<Value> &encodedValues)
303 : solver(solver), satVars(satVars), encodedValues(encodedValues) {}
304
305Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
306 Operation *op = value.getDefiningOp();
307 if (!op)
308 return {};
309 return op->getAttr(kTestClassAttrName);
310}
311
312bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
313 Attribute lhsClass = getTestEquivClass(lhs);
314 Attribute rhsClass = getTestEquivClass(rhs);
315 return lhsClass && rhsClass && lhsClass == rhsClass;
316}
317
318EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs,
319 bool inverted) {
320
321 if (testTransformation) {
322 if (matchesTestEquivClass(lhs, rhs))
323 return EquivResult::Proved;
324 return EquivResult::Unknown;
325 }
326 assert(satBuilder && "SAT builder must be initialized before verification");
327 // SAT-based equivalence checking builds a miter for the two candidate nodes
328 // and proves that no input assignment can make them differ.
329 return satBuilder->verify(lhs, rhs, inverted);
330}
331
332void FunctionalReductionSolver::initializeSATState() {
333 assert(satSolver && "SAT solver must be initialized before SAT state setup");
334
335 satVars.clear();
336 encodedValues.clear();
337 satVars.reserve(allValues.size());
338 for (auto [index, value] : llvm::enumerate(allValues))
339 satVars[value] = index + 1;
340 satSolver->reserveVars(allValues.size());
341
342 satBuilder = std::make_unique<FunctionalReductionSATBuilder>(
343 *satSolver, satVars, encodedValues);
344}
345
346//===----------------------------------------------------------------------===//
347// Phase 1: Collect values and run simulation
348//===----------------------------------------------------------------------===//
349
350void FunctionalReductionSolver::collectValues() {
351
352 // Seed zero constants so nodes can be merged
353 // if input IR does not contain constants already.
354 OpBuilder builder(module.getContext());
355 builder.setInsertionPointToStart(module.getBodyBlock());
356 auto i1Type = builder.getIntegerType(1);
357 hw::ConstantOp::create(builder, module.getLoc(), i1Type, 0);
358
359 // Collect block arguments (primary inputs) that are i1
360 for (auto arg : module.getBodyBlock()->getArguments()) {
361 if (arg.getType().isInteger(1)) {
362 primaryInputs.push_back(arg);
363 allValues.push_back(arg);
364 }
365 }
366
367 // Walk operations and collect i1 results
368 // - AIG operations: add to allValues for simulation
369 // - Unknown operations: treat as inputs (assign random patterns)
370 module.walk([&](Operation *op) {
371 for (auto result : op->getResults()) {
372 if (!result.getType().isInteger(1))
373 continue;
374
375 allValues.push_back(result);
376 if (!op->hasTrait<OpTrait::ConstantLike>() &&
377 !isFunctionalReductionSimulatableOp(op)) {
378 // Unknown operations - treat as primary inputs
379 primaryInputs.push_back(result);
380 }
381 }
382 });
383
384 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Collected "
385 << primaryInputs.size()
386 << " primary inputs (including unknown ops) and "
387 << allValues.size() << " total i1 values\n");
388}
389
390void FunctionalReductionSolver::runSimulation() {
391 // Calculate number of 64-bit words needed for numPatterns bits
392 unsigned numWords = numPatterns / 64;
393
394 // Create seeded random number generator for deterministic patterns
395 std::mt19937_64 rng(seed);
396
397 for (auto input : primaryInputs) {
398 // Generate random words using seeded RNG
399 SmallVector<uint64_t> words(numWords);
400 for (auto &word : words)
401 word = rng();
402
403 // Construct APInt directly from words
404 llvm::APInt pattern(numPatterns, words);
405 simSignatures[input] = pattern;
406 }
407
408 // Propagate simulation through the circuit in topological order
409 for (auto value : allValues) {
410 if (simSignatures.count(value))
411 continue; // Already computed (primary input)
412
413 simSignatures[value] = simulateValue(value);
414 }
415
416 LLVM_DEBUG({
417 llvm::dbgs() << "FunctionalReduction: Simulation complete with "
418 << numPatterns << " patterns\n";
419 });
420}
421
422llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
423 Operation *op = v.getDefiningOp();
424 if (!op)
425 return simSignatures.at(v);
426 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
427 .Case<BooleanLogicOpInterface>([&](auto op) {
428 return op.evaluateBooleanLogic([&](unsigned i) -> const APInt & {
429 return simSignatures.at(op.getInput(i));
430 });
431 })
432 .Case<comb::AndOp>([&](auto op) {
433 APInt result = APInt::getAllOnes(numPatterns);
434 for (auto input : op.getInputs())
435 result &= simSignatures.at(input);
436 return result;
437 })
438 .Case<comb::OrOp>([&](auto op) {
439 APInt result = APInt::getZero(numPatterns);
440 for (auto input : op.getInputs())
441 result |= simSignatures.at(input);
442 return result;
443 })
444 .Case<comb::XorOp>([&](auto op) {
445 APInt result = APInt::getZero(numPatterns);
446 for (auto input : op.getInputs())
447 result ^= simSignatures.at(input);
448 return result;
449 })
450 .Case([&](hw::ConstantOp op) {
451 return op.getValue().isZero() ? APInt::getZero(numPatterns)
452 : APInt::getAllOnes(numPatterns);
453 })
454 .Default([&](Operation *) {
455 // Unknown operation - treat as input (already assigned a random
456 // pattern)
457 return simSignatures.at(v);
458 });
459}
460
461//===----------------------------------------------------------------------===//
462// Phase 2: Build equivalence classes from simulation
463//===----------------------------------------------------------------------===//
464
465void FunctionalReductionSolver::buildEquivalenceClasses() {
466 // Map from canonical signature to list of {value, inverted pairs}
467 // Inverted signals share the same canonical signature since inversion
468 // is zero cost in synthesis
470 for (auto value : allValues) {
471 auto signature = simSignatures.at(value);
472 bool inverted = false;
473 if (signature.isNegative()) {
474 inverted = true;
475 signature.flipAllBits();
476 }
477 sigGroups[signature].push_back({value, inverted});
478 }
479
480 // Build equivalence candidates for groups with >1 member.
481 // Re-normalize so inverted is relative to representative (first member)
482 for (auto &[hash, members] : sigGroups) {
483 if (members.size() <= 1)
484 continue;
485 bool repInverted = members.front().second;
486 for (auto &[_, inv] : members)
487 inv ^= repInverted;
488 equivCandidates.push_back(std::move(members));
489 }
490 stats.numEquivClasses = equivCandidates.size();
491
492 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Built "
493 << equivCandidates.size()
494 << " equivalence candidates\n");
495}
496
497//===----------------------------------------------------------------------===//
// Phase 3: SAT-based verification
//
// For each candidate equivalence class, verify each member against the
// representative using a single incremental SAT solver shared by all classes.
502//===----------------------------------------------------------------------===//
503
504void FunctionalReductionSolver::verifyCandidates() {
505 LLVM_DEBUG(
506 llvm::dbgs() << "FunctionalReduction: Starting SAT verification with "
507 << equivCandidates.size() << " equivalence classes\n");
508
509 for (auto &members : equivCandidates) {
510 if (members.empty())
511 continue;
512 auto [representative, repInversion] = members.front();
513 assert(!repInversion && "representative must not be inverted");
514 (void)repInversion;
515 auto &provenMembers = provenEquivalences[representative];
516 // Representative is the canonical node for this class. Members can be
517 // inverted relative to the representative, tracked by the inversion flag
518 for (auto [member, inverted] :
519 llvm::ArrayRef<std::pair<Value, bool>>(members).drop_front()) {
520 EquivResult result = verifyEquivalence(representative, member, inverted);
521 if (result == EquivResult::Proved) {
522 stats.numProvedEquiv++;
523 provenMembers.push_back({member, inverted});
524 } else if (result == EquivResult::Disproved) {
525 stats.numDisprovedEquiv++;
526 // TODO: Refine equivalence classes based on counterexamples from SAT
527 // solver
528 } else {
529 stats.numUnknown++;
530 }
531 }
532 }
533
534 LLVM_DEBUG(
535 llvm::dbgs() << "FunctionalReduction: SAT verification complete. Proved "
536 << stats.numProvedEquiv << " equivalences\n");
537}
538
539//===----------------------------------------------------------------------===//
540// Phase 4: Merge equivalent nodes
541//===----------------------------------------------------------------------===//
542
543void FunctionalReductionSolver::mergeEquivalentNodes() {
544 if (provenEquivalences.empty())
545 return;
546
547 // Build all replacement IR first, then perform use rewrites in a second
548 // phase. This keeps `isBeforeInBlock` queries anchored to the final block
549 // order instead of an order that is still being mutated by insertion.
550 struct PlannedMember {
551 Value original;
552 bool inverted;
553 aig::AndInverterOp operandInverter;
554 };
555 struct MergeRewritePlan {
556 Value representative;
557 SmallVector<PlannedMember> members;
558 // Members which are at risk of reaching their representative
559 SmallVector<PlannedMember> reachableMembers;
560 synth::ChoiceOp choice;
561 aig::AndInverterOp choiceNot;
562 };
563
564 mlir::OpBuilder builder(module.getContext());
565 auto replaceDominatedUses =
566 [](Value from, Value to,
567 llvm::function_ref<bool(Operation *)> shouldReplaceOwner) {
568 auto *defOp = to.getDefiningOp();
569 assert(defOp && "replacement value must be defined by an operation");
570 from.replaceUsesWithIf(to, [&](OpOperand &use) {
571 auto *user = use.getOwner();
572 // Restrict rewrites to uses after the replacement value's definition
573 // in the same block so merging cannot introduce use-before-def edges
574 // or SSA cycles.
575 return shouldReplaceOwner(user) &&
576 user->getBlock() == defOp->getBlock();
577 });
578 };
579
580 DenseSet<Value> reachable;
581 auto visitFrom = [&](Value start) {
582 SmallVector<Value> stack;
583 stack.push_back(start);
584 while (!stack.empty()) {
585 Value current = stack.pop_back_val();
586 if (!reachable.insert(current).second)
587 continue;
588 for (Operation *user : current.getUsers())
589 if (isLogicNetworkOp(user))
590 for (Value result : user->getResults())
591 stack.push_back(result);
592 }
593 };
594
595 SmallVector<MergeRewritePlan> rewritePlans;
596 rewritePlans.reserve(provenEquivalences.size());
597 for (auto provenEquivSet : provenEquivalences) {
598 auto &[representative, members] = provenEquivSet;
599 if (members.empty())
600 continue;
601 // Mark all values reachable from representative before checking members.
602 visitFrom(representative);
603
604 // Greedily filter for members that can create a cycle with representative
605 SmallVector<std::pair<Value, bool>> safeMembers;
606 SmallVector<PlannedMember> plannedReachable;
607 for (auto [member, inverted] : members) {
608 if (reachable.count(member)) {
609 plannedReachable.push_back({member, inverted, {}});
610 continue;
611 }
612 visitFrom(member); // Visit users
613 safeMembers.push_back({member, inverted});
614 }
615
616 if (safeMembers.empty())
617 continue;
618
619 builder.setInsertionPointAfterValue(safeMembers.back().first);
620
621 SmallVector<Value> operands;
622 operands.reserve(safeMembers.size() + 1);
623 operands.push_back(representative);
624
625 SmallVector<PlannedMember> plannedMembers;
626 plannedMembers.reserve(safeMembers.size());
627 bool hasInvertedMember = false;
628 for (auto [member, inverted] : safeMembers) {
629 auto &planned =
630 plannedMembers.emplace_back(PlannedMember{member, inverted, {}});
631 if (!inverted) {
632 operands.push_back(member);
633 continue;
634 }
635 hasInvertedMember = true;
636 // If the member is inverted relative to the representative, we
637 // create an inverter for the choice operand
638 planned.operandInverter =
639 aig::AndInverterOp::create(builder, member.getLoc(), member, true);
640 operands.push_back(planned.operandInverter.getResult());
641 }
642
643 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
644 representative.getType(), operands);
645
646 // If there is an inverted member, we need to create an inverter for the
647 // choice result as well
648 auto choiceNot = !hasInvertedMember
649 ? nullptr
650 : aig::AndInverterOp::create(builder, choice.getLoc(),
651 choice, true);
652
653 stats.numMergedNodes += safeMembers.size() + 1;
654 rewritePlans.push_back({representative, std::move(plannedMembers),
655 std::move(plannedReachable), choice, choiceNot});
656 }
657
658 for (auto &plan : rewritePlans) {
659 auto replaceValue = [&](const PlannedMember &member) {
660 if (member.inverted)
661 replaceDominatedUses(member.original, plan.choiceNot,
662 [&](Operation *user) {
663 // Do not rewrite the freshly created operand
664 // inverter or the choice result inverter. This
665 // avoids creating an immediate cycle when
666 // merging an inverted node into its
667 // representative.
668 return user != member.operandInverter &&
669 user != plan.choiceNot.getOperation();
670 });
671 else
672 replaceDominatedUses(member.original, plan.choice,
673 [&](Operation *user) {
674 return user != plan.choice.getOperation();
675 });
676 };
677
678 replaceDominatedUses(
679 plan.representative, plan.choice,
680 [&](Operation *user) { return user != plan.choice.getOperation(); });
681 for (const auto &member : plan.members)
682 replaceValue(member);
683
684 // Reachable members are redundant here so either replace their uses with
685 // choice or erase if they have no uses left.
686 for (auto &member : plan.reachableMembers) {
687 member.original.replaceUsesWithIf(plan.choice, [&](OpOperand &use) {
688 auto *user = use.getOwner();
689 return user->getBlock() == plan.choice->getBlock();
690 });
691 if (member.original.use_empty())
692 member.original.getDefiningOp()->erase();
693 }
694 }
695
696 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Merged "
697 << stats.numMergedNodes << " nodes\n");
698}
699
700//===----------------------------------------------------------------------===//
701// Main Functional Reduction algorithm
702//===----------------------------------------------------------------------===//
703
704mlir::FailureOr<FunctionalReductionSolver::Stats>
705FunctionalReductionSolver::run() {
706 LLVM_DEBUG(
707 llvm::dbgs() << "FunctionalReduction: Starting functional reduction with "
708 << numPatterns << " simulation patterns\n");
709
710 if (!testTransformation && !satSolver) {
711 module->emitError()
712 << "FunctionalReduction requires a SAT solver, but none is "
713 "available in this build";
714 return failure();
715 }
716
717 // Topologically sort the values
719 module->emitError()
720 << "FunctionalReduction: Failed to topologically sort logic network";
721 return failure();
722 }
723
724 // Phase 1: Collect values and run simulation
725 collectValues();
726 if (allValues.empty()) {
727 LLVM_DEBUG(llvm::dbgs()
728 << "FunctionalReduction: No i1 values to process\n");
729 return stats;
730 }
731
732 runSimulation();
733
734 // Phase 2: Build equivalence classes
735 buildEquivalenceClasses();
736 if (equivCandidates.empty()) {
737 LLVM_DEBUG(llvm::dbgs()
738 << "FunctionalReduction: No equivalence candidates found\n");
739 return stats;
740 }
741
742 // Phase 3: SAT-based verification
743 if (!testTransformation)
744 initializeSATState();
745 verifyCandidates();
746
747 // Phase 4: Merge equivalent nodes
748 mergeEquivalentNodes();
749
750 // Re-sort after merging to restore topological order after choice insertion.
752 module->emitError()
753 << "FunctionalReduction: Failed to topologically sort logic network";
754 return failure();
755 }
756
757 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Complete. Stats:\n"
758 << " Equivalence classes: " << stats.numEquivClasses
759 << "\n"
760 << " Proved: " << stats.numProvedEquiv << "\n"
761 << " Disproved: " << stats.numDisprovedEquiv << "\n"
762 << " Unknown (limit): " << stats.numUnknown << "\n"
763 << " Merged: " << stats.numMergedNodes << "\n");
764
765 return stats;
766}
767
768//===----------------------------------------------------------------------===//
769// Pass implementation
770//===----------------------------------------------------------------------===//
771
772struct FunctionalReductionPass
773 : public circt::synth::impl::FunctionalReductionBase<
774 FunctionalReductionPass> {
775 using FunctionalReductionBase::FunctionalReductionBase;
776 void updateStats(const FunctionalReductionSolver::Stats &stats) {
777 numEquivClasses += stats.numEquivClasses;
778 numProvedEquiv += stats.numProvedEquiv;
779 numDisprovedEquiv += stats.numDisprovedEquiv;
780 numUnknown += stats.numUnknown;
781 numMergedNodes += stats.numMergedNodes;
782 }
783
784 void runOnOperation() override {
785 auto module = getOperation();
786 LLVM_DEBUG(llvm::dbgs() << "Running FunctionalReduction pass on "
787 << module.getName() << "\n");
788
789 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
790 module.emitError()
791 << "'num-random-patterns' must be a positive multiple of 64";
792 return signalPassFailure();
793 }
794 if (conflictLimit < -1) {
795 module.emitError()
796 << "'conflict-limit' must be greater than or equal to -1";
797 return signalPassFailure();
798 }
799
800 std::unique_ptr<IncrementalSATSolver> satSolver;
801 if (!testTransformation) {
802 satSolver = createFunctionalReductionSATSolver(this->satSolver);
803 if (!satSolver) {
804 module.emitError() << "unsupported or unavailable SAT solver '"
805 << this->satSolver
806 << "' (expected auto, z3, or cadical)";
807 return signalPassFailure();
808 }
809 satSolver->setConflictLimit(static_cast<int>(conflictLimit));
810 }
811
812 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
813 testTransformation,
814 std::move(satSolver));
815 auto stats = fcSolver.run();
816 if (failed(stats))
817 return signalPassFailure();
818 updateStats(*stats);
819 if (stats->numMergedNodes == 0)
820 markAllAnalysesPreserved();
821 }
822};
823
824} // namespace
assert(baseType &&"element must be base type")
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
Abstract interface for incremental SAT solvers with an IPASIR-style API.
Definition SATSolver.h:158
create(data_type, value)
Definition hw.py:433
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
bool isLogicNetworkOp(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
void addAndClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> and(inputLits).
std::unique_ptr< IncrementalSATSolver > createZ3SATSolver()
Construct a Z3-backed incremental IPASIR-style SAT solver.
std::unique_ptr< IncrementalSATSolver > createCadicalSATSolver(const CadicalSATSolverOptions &options={})
Construct a CaDiCaL-backed incremental IPASIR-style SAT solver.
void addOrClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause)
Emit clauses encoding outVar <=> or(inputLits).
void addParityClauses(int outVar, llvm::ArrayRef< int > inputLits, llvm::function_ref< void(llvm::ArrayRef< int >)> addClause, llvm::function_ref< int()> newVar)
Emit clauses encoding outVar <=> parity(inputLits).
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:1716
Definition synth.py:1