CIRCT 23.0.0git
Loading...
Searching...
No Matches
FunctionalReduction.cpp
Go to the documentation of this file.
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass implements FunctionalReduction (Functionally Reduced And-Inverter
// Graph, FRAIG) optimization. It identifies and merges functionally equivalent
// nodes through simulation-based candidate detection followed by SAT-based
// verification. The intended verifier is a built-in minimal CDCL SAT solver;
// it is not wired in yet (see the TODO in verifyEquivalence), so outside the
// test mode every candidate pair is currently reported as Unknown.
//
//===----------------------------------------------------------------------===//

21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Support/LogicalResult.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
35#include <random>
36
37#define DEBUG_TYPE "synth-functional-reduction"
38
39static constexpr llvm::StringLiteral kTestClassAttrName =
40 "synth.test.fc_equiv_class";
41
42namespace circt {
43namespace synth {
44#define GEN_PASS_DEF_FUNCTIONALREDUCTION
45#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
46} // namespace synth
47} // namespace circt
48
49using namespace circt;
50using namespace circt::synth;
51
52namespace {
/// Outcome of checking whether two i1 values compute the same function.
enum class EquivResult {
  /// The two values were shown to be equivalent.
  Proved,
  /// The two values were shown to differ.
  Disproved,
  /// Neither proved nor refuted; this is what verifyEquivalence currently
  /// returns outside test mode.
  Unknown
};

55//===----------------------------------------------------------------------===//
56// Core Functional Reduction Implementation
57//===----------------------------------------------------------------------===//
58
59class FunctionalReductionSolver {
60public:
61 FunctionalReductionSolver(hw::HWModuleOp module, unsigned numPatterns,
62 unsigned seed, bool testTransformation)
63 : module(module), numPatterns(numPatterns), seed(seed),
64 testTransformation(testTransformation) {}
65
66 ~FunctionalReductionSolver() = default;
67
68 /// Run the Functional Reduction algorithm and return statistics.
69 struct Stats {
70 unsigned numEquivClasses = 0;
71 unsigned numProvedEquiv = 0;
72 unsigned numDisprovedEquiv = 0;
73 unsigned numUnknown = 0;
74 unsigned numMergedNodes = 0;
75 };
76 mlir::FailureOr<Stats> run();
77
78private:
79 // Phase 1: Collect i1 values and run simulation
80 void collectValues();
81 void runSimulation();
82 llvm::APInt simulateValue(Value v);
83
84 // Phase 2: Build equivalence classes from simulation
85 void buildEquivalenceClasses();
86
87 // Phase 3: SAT-based verification with per-class solver
88 void verifyCandidates();
89
90 // Phase 4: Merge equivalent nodes
91 void mergeEquivalentNodes();
92
93 // Test transformation helpers.
94 static Attribute getTestEquivClass(Value value);
95 static bool matchesTestEquivClass(Value lhs, Value rhs);
96 EquivResult verifyEquivalence(Value lhs, Value rhs);
97
98 // Module being processed
99 hw::HWModuleOp module;
100
101 // Configuration
102 unsigned numPatterns;
103 unsigned seed;
104 bool testTransformation;
105
106 // Primary inputs (block arguments or results of unknown operations treated as
107 // inputs)
108 SmallVector<Value> primaryInputs;
109
110 // All i1 values in topological order
111 SmallVector<Value> allValues;
112
113 // Simulation signatures: value -> APInt simulation result
114 llvm::DenseMap<Value, llvm::APInt> simSignatures;
115
116 // Equivalence candidates: groups of values with identical simulation
117 // signatures
118 SmallVector<SmallVector<Value>> equivCandidates;
119
120 // Proven equivalences: representative -> proven equivalent members.
121 llvm::MapVector<Value, SmallVector<Value>> provenEquivalences;
122
123 Stats stats;
124};
125
126Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
127 Operation *op = value.getDefiningOp();
128 if (!op)
129 return {};
130 return op->getAttr(kTestClassAttrName);
131}
132
133bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
134 Attribute lhsClass = getTestEquivClass(lhs);
135 Attribute rhsClass = getTestEquivClass(rhs);
136 return lhsClass && rhsClass && lhsClass == rhsClass;
137}
138
139EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs) {
140 if (testTransformation) {
141 if (matchesTestEquivClass(lhs, rhs))
142 return EquivResult::Proved;
143 return EquivResult::Unknown;
144 }
145
146 // TODO: Implement actual SAT-based verification here. For now, we return
147 // Unknown.
148 return EquivResult::Unknown;
149}
150
151//===----------------------------------------------------------------------===//
152// Phase 1: Collect values and run simulation
153//===----------------------------------------------------------------------===//
154
155void FunctionalReductionSolver::collectValues() {
156 // Collect block arguments (primary inputs) that are i1
157 for (auto arg : module.getBodyBlock()->getArguments()) {
158 if (arg.getType().isInteger(1)) {
159 primaryInputs.push_back(arg);
160 allValues.push_back(arg);
161 }
162 }
163
164 // Walk operations and collect i1 results
165 // - AIG/MIG operations: add to allValues for simulation
166 // - Unknown operations: treat as inputs (assign random patterns)
167 module.walk([&](Operation *op) {
168 for (auto result : op->getResults()) {
169 if (!result.getType().isInteger(1))
170 continue;
171
172 allValues.push_back(result);
173 if (!isa<aig::AndInverterOp>(op)) {
174 // Unknown operations - treat as primary inputs
175 primaryInputs.push_back(result);
176 }
177 }
178 });
179
180 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Collected "
181 << primaryInputs.size()
182 << " primary inputs (including unknown ops) and "
183 << allValues.size() << " total i1 values\n");
184}
185
186void FunctionalReductionSolver::runSimulation() {
187 // Calculate number of 64-bit words needed for numPatterns bits
188 unsigned numWords = numPatterns / 64;
189
190 // Create seeded random number generator for deterministic patterns
191 std::mt19937_64 rng(seed);
192
193 for (auto input : primaryInputs) {
194 // Generate random words using seeded RNG
195 SmallVector<uint64_t> words(numWords);
196 for (auto &word : words)
197 word = rng();
198
199 // Construct APInt directly from words
200 llvm::APInt pattern(numPatterns, words);
201 simSignatures[input] = pattern;
202 }
203
204 // Propagate simulation through the circuit in topological order
205 for (auto value : allValues) {
206 if (simSignatures.count(value))
207 continue; // Already computed (primary input)
208
209 simSignatures[value] = simulateValue(value);
210 }
211
212 LLVM_DEBUG({
213 llvm::dbgs() << "FunctionalReduction: Simulation complete with "
214 << numPatterns << " patterns\n";
215 });
216}
217
218llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
219 Operation *op = v.getDefiningOp();
220 if (!op)
221 return simSignatures.at(v);
222 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
223 .Case<aig::AndInverterOp>([&](auto op) {
224 SmallVector<llvm::APInt> inputSigs;
225 for (auto input : op.getInputs())
226 inputSigs.push_back(simSignatures.at(input));
227 return op.evaluate(inputSigs);
228 })
229 .Default([&](Operation *) {
230 // Unknown operation - treat as input (already assigned a random
231 // pattern)
232 return simSignatures.at(v);
233 });
234}
235
236//===----------------------------------------------------------------------===//
237// Phase 2: Build equivalence classes from simulation
238//===----------------------------------------------------------------------===//
239
240void FunctionalReductionSolver::buildEquivalenceClasses() {
241 // Map from signature to list of values
242 llvm::MapVector<llvm::APInt, SmallVector<Value>> sigGroups;
243
244 for (auto value : allValues)
245 sigGroups[simSignatures.at(value)].push_back(value);
246
247 // Build equivalence candidates for groups with >1 member.
248 for (auto &[hash, members] : sigGroups) {
249 if (members.size() <= 1)
250 continue;
251 equivCandidates.push_back(std::move(members));
252 }
253 stats.numEquivClasses = equivCandidates.size();
254
255 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Built "
256 << equivCandidates.size()
257 << " equivalence candidates\n");
258}
259
260//===----------------------------------------------------------------------===//
261// Phase 3: SAT-based verification with per-class solvers
262//
263// For each equivalence class candidates, verify each member against the
264// representative using a SAT solver.
265//===----------------------------------------------------------------------===//
266
267void FunctionalReductionSolver::verifyCandidates() {
268 LLVM_DEBUG(
269 llvm::dbgs() << "FunctionalReduction: Starting SAT verification with "
270 << equivCandidates.size() << " equivalence classes\n");
271
272 for (auto &members : equivCandidates) {
273 if (members.empty())
274 continue;
275 auto representative = members.front();
276 auto &provenMembers = provenEquivalences[representative];
277 // Representative is the canonical node for this class.
278 for (auto member : llvm::ArrayRef<Value>(members).drop_front()) {
279 EquivResult result = verifyEquivalence(representative, member);
280 if (result == EquivResult::Proved) {
281 stats.numProvedEquiv++;
282 provenMembers.push_back(member);
283 } else if (result == EquivResult::Disproved) {
284 stats.numDisprovedEquiv++;
285 // TODO: Refine equivalence classes based on counterexamples from SAT
286 // solver
287 } else {
288 stats.numUnknown++;
289 }
290 }
291 }
292
293 LLVM_DEBUG(
294 llvm::dbgs() << "FunctionalReduction: SAT verification complete. Proved "
295 << stats.numProvedEquiv << " equivalences\n");
296}
297
298//===----------------------------------------------------------------------===//
299// Phase 4: Merge equivalent nodes
300//===----------------------------------------------------------------------===//
301
302void FunctionalReductionSolver::mergeEquivalentNodes() {
303 if (provenEquivalences.empty())
304 return;
305
306 mlir::OpBuilder builder(module.getContext());
307 for (auto &provenEquivSet : provenEquivalences) {
308 auto &[representative, members] = provenEquivSet;
309 if (members.empty())
310 continue;
311 SmallVector<Value> operands;
312 operands.reserve(members.size() + 1);
313 operands.push_back(representative);
314 operands.append(members);
315 builder.setInsertionPointAfterValue(members.back());
316 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
317 representative.getType(), operands);
318 stats.numMergedNodes += members.size() + 1;
319 representative.replaceAllUsesExcept(choice, choice);
320 for (auto value : members)
321 value.replaceAllUsesExcept(choice, choice);
322 }
323
324 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Merged "
325 << stats.numMergedNodes << " nodes\n");
326}
327
328//===----------------------------------------------------------------------===//
329// Main Functional Reduction algorithm
330//===----------------------------------------------------------------------===//
331
332mlir::FailureOr<FunctionalReductionSolver::Stats>
333FunctionalReductionSolver::run() {
334 LLVM_DEBUG(
335 llvm::dbgs() << "FunctionalReduction: Starting functional reduction with "
336 << numPatterns << " simulation patterns\n");
337 // Topologically sort the values
338
340 module->emitError()
341 << "FunctionalReduction: Failed to topologically sort logic network";
342 return failure();
343 }
344
345 // Phase 1: Collect values and run simulation
346 collectValues();
347 if (allValues.empty()) {
348 LLVM_DEBUG(llvm::dbgs()
349 << "FunctionalReduction: No i1 values to process\n");
350 return stats;
351 }
352
353 runSimulation();
354
355 // Phase 2: Build equivalence classes
356 buildEquivalenceClasses();
357 if (equivCandidates.empty()) {
358 LLVM_DEBUG(llvm::dbgs()
359 << "FunctionalReduction: No equivalence candidates found\n");
360 return stats;
361 }
362
363 // Phase 3: SAT-based verification
364 verifyCandidates();
365
366 // Phase 4: Merge equivalent nodes
367 mergeEquivalentNodes();
368
369 LLVM_DEBUG(llvm::dbgs() << "FunctionalReduction: Complete. Stats:\n"
370 << " Equivalence classes: " << stats.numEquivClasses
371 << "\n"
372 << " Proved: " << stats.numProvedEquiv << "\n"
373 << " Disproved: " << stats.numDisprovedEquiv << "\n"
374 << " Unknown (limit): " << stats.numUnknown << "\n"
375 << " Merged: " << stats.numMergedNodes << "\n");
376
377 return stats;
378}
379
380//===----------------------------------------------------------------------===//
381// Pass implementation
382//===----------------------------------------------------------------------===//
383
384struct FunctionalReductionPass
385 : public circt::synth::impl::FunctionalReductionBase<
386 FunctionalReductionPass> {
387 using FunctionalReductionBase::FunctionalReductionBase;
388 void updateStats(const FunctionalReductionSolver::Stats &stats) {
389 numEquivClasses += stats.numEquivClasses;
390 numProvedEquiv += stats.numProvedEquiv;
391 numDisprovedEquiv += stats.numDisprovedEquiv;
392 numUnknown += stats.numUnknown;
393 numMergedNodes += stats.numMergedNodes;
394 }
395
396 void runOnOperation() override {
397 auto module = getOperation();
398 LLVM_DEBUG(llvm::dbgs() << "Running FunctionalReduction pass on "
399 << module.getName() << "\n");
400
401 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
402 module.emitError()
403 << "'num-random-patterns' must be a positive multiple of 64";
404 return signalPassFailure();
405 }
406
407 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
408 testTransformation);
409 auto stats = fcSolver.run();
410 if (failed(stats))
411 return signalPassFailure();
412 updateStats(*stats);
413 if (stats->numMergedNodes == 0)
414 markAllAnalysesPreserved();
415 }
416};
417
418} // namespace
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:445
Definition synth.py:1