21#include "mlir/IR/Attributes.h"
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/BuiltinOps.h"
24#include "mlir/IR/PatternMatch.h"
25#include "mlir/Pass/Pass.h"
26#include "mlir/Support/LogicalResult.h"
27#include "llvm/ADT/APInt.h"
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/MapVector.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/Support/Debug.h"
37#define DEBUG_TYPE "synth-functional-reduction"
40 "synth.test.fc_equiv_class";
44#define GEN_PASS_DEF_FUNCTIONALREDUCTION
45#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc"
/// Outcome of checking whether two values are functionally equivalent; this is
/// the return type of FunctionalReductionSolver::verifyEquivalence.
53enum class EquivResult { Proved, Disproved, Unknown };
/// Functional-reduction driver for one hardware module. Judging by the member
/// functions below, it collects i1 values, assigns/simulates bit-pattern
/// signatures, groups values with identical signatures into candidates,
/// verifies the candidates and merges the proven ones.
59class FunctionalReductionSolver {
// Stores the module to process together with the simulation configuration
// (pattern count, RNG seed) and the test-mode flag.
61 FunctionalReductionSolver(
hw::HWModuleOp module,
unsigned numPatterns,
62 unsigned seed,
bool testTransformation)
63 : module(module), numPatterns(numPatterns), seed(seed),
64 testTransformation(testTransformation) {}
66 ~FunctionalReductionSolver() =
default;
// Counters updated by the solver and accessed elsewhere as `stats.<name>`.
// NOTE(review): the header of the enclosing `Stats` aggregate is not visible
// in this excerpt.
70 unsigned numEquivClasses = 0;
71 unsigned numProvedEquiv = 0;
72 unsigned numDisprovedEquiv = 0;
73 unsigned numUnknown = 0;
74 unsigned numMergedNodes = 0;
// Runs the reduction; yields the collected statistics or failure.
76 mlir::FailureOr<Stats>
run();
// Computes the simulation signature of `v` from signatures already stored in
// `simSignatures`.
82 llvm::APInt simulateValue(Value v);
// Groups values by identical simulation signature into `equivCandidates`.
85 void buildEquivalenceClasses();
// Checks each candidate group against its first member and records the proven
// equivalences.
88 void verifyCandidates();
// Replaces each proven-equivalent group by a single choice operation.
91 void mergeEquivalentNodes();
// Test-only helpers: look up / compare the test equivalence-class attribute of
// values.
94 static Attribute getTestEquivClass(Value value);
95 static bool matchesTestEquivClass(Value lhs, Value rhs);
// Decides whether `lhs` and `rhs` are equivalent.
96 EquivResult verifyEquivalence(Value lhs, Value rhs);
// Number of simulation patterns; also used as the bit width of each signature.
102 unsigned numPatterns;
// When set, verifyEquivalence decides purely from test equivalence classes.
104 bool testTransformation;
// Values that receive a generated pattern in runSimulation: i1 block arguments
// and i1 results of operations other than aig::AndInverterOp.
108 SmallVector<Value> primaryInputs;
// Every collected i1 value (block arguments and operation results).
111 SmallVector<Value> allValues;
// Simulation signature of each value.
114 llvm::DenseMap<Value, llvm::APInt> simSignatures;
// Groups of values sharing a simulation signature (equivalence candidates).
118 SmallVector<SmallVector<Value>> equivCandidates;
// Maps a representative value to the members proven equivalent to it; the
// MapVector keeps insertion order for iteration.
121 llvm::MapVector<Value, SmallVector<Value>> provenEquivalences;
/// Returns the test equivalence-class attribute associated with `value`,
/// starting from the operation that defines it.
/// NOTE(review): the remainder of the body is not visible in this excerpt.
/// Presumably it returns a null Attribute when `value` has no defining op or
/// the attribute is absent (callers null-check the result) — confirm.
126Attribute FunctionalReductionSolver::getTestEquivClass(Value value) {
127 Operation *op = value.getDefiningOp();
/// Returns true only when both `lhs` and `rhs` carry a (non-null) test
/// equivalence-class attribute and the two attributes compare equal.
133bool FunctionalReductionSolver::matchesTestEquivClass(Value lhs, Value rhs) {
134 Attribute lhsClass = getTestEquivClass(lhs);
135 Attribute rhsClass = getTestEquivClass(rhs);
// A missing class on either side never counts as a match.
136 return lhsClass && rhsClass && lhsClass == rhsClass;
/// Decides whether `lhs` and `rhs` are equivalent.
///
/// In test mode the answer comes solely from the test equivalence-class
/// attributes: Proved when they match, Unknown otherwise.
/// NOTE(review): outside test mode the only visible statement returns Unknown;
/// the lines between the two branches are not visible in this excerpt.
139EquivResult FunctionalReductionSolver::verifyEquivalence(Value lhs, Value rhs) {
140 if (testTransformation) {
141 if (matchesTestEquivClass(lhs, rhs))
142 return EquivResult::Proved;
143 return EquivResult::Unknown;
148 return EquivResult::Unknown;
/// Populates `allValues` and `primaryInputs` from the module.
///
/// Every i1 block argument and every i1 operation result goes into
/// `allValues`. Block arguments, and results of any operation that is not an
/// aig::AndInverterOp, are additionally treated as primary inputs.
155void FunctionalReductionSolver::collectValues() {
// i1 arguments of the module body are both primary inputs and tracked values.
157 for (
auto arg : module.
getBodyBlock()->getArguments()) {
158 if (arg.getType().isInteger(1)) {
159 primaryInputs.push_back(arg);
160 allValues.push_back(arg);
// Visit every operation in the module and inspect its results.
167 module.walk([&](Operation *op) {
168 for (auto result : op->getResults()) {
// NOTE(review): the statement guarded by this non-i1 check is not visible here
// (presumably it skips the result).
169 if (!result.getType().isInteger(1))
172 allValues.push_back(result);
// Results of operations other than AndInverterOp are not simulated through;
// they are treated like primary inputs.
173 if (!isa<aig::AndInverterOp>(op)) {
175 primaryInputs.push_back(result);
180 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Collected "
181 << primaryInputs.size()
182 <<
" primary inputs (including unknown ops) and "
183 << allValues.size() <<
" total i1 values\n");
/// Assigns a `numPatterns`-bit signature to every primary input and then
/// computes signatures for the remaining values in `allValues`.
186void FunctionalReductionSolver::runSimulation() {
// Integer division: any remainder of numPatterns modulo 64 is dropped here.
188 unsigned numWords = numPatterns / 64;
// 64-bit Mersenne Twister seeded from the configured seed.
191 std::mt19937_64 rng(seed);
193 for (
auto input : primaryInputs) {
195 SmallVector<uint64_t> words(numWords);
// NOTE(review): the loop body that fills each word is not visible in this
// excerpt (presumably it draws from `rng`).
196 for (
auto &word : words)
// Build the signature from the 64-bit words and store it for this input.
200 llvm::APInt
pattern(numPatterns, words);
201 simSignatures[input] =
pattern;
// Compute signatures for the other values, in `allValues` order.
205 for (
auto value : allValues) {
// NOTE(review): the statement guarded by this check is not visible here
// (presumably values that already have a signature are skipped).
206 if (simSignatures.count(value))
209 simSignatures[value] = simulateValue(value);
213 llvm::dbgs() <<
"FunctionalReduction: Simulation complete with "
214 << numPatterns <<
" patterns\n";
/// Returns the simulation signature of `v`.
///
/// For an aig::AndInverterOp the signature is obtained by evaluating the op on
/// the stored signatures of its inputs; for any other operation the stored
/// signature of `v` itself is returned.
/// NOTE(review): all lookups use `simSignatures.at(...)`, so the operands
/// presumably must already have signatures when this is called — confirm the
/// ordering guarantee of `allValues`.
218llvm::APInt FunctionalReductionSolver::simulateValue(Value v) {
219 Operation *op = v.getDefiningOp();
// NOTE(review): the condition guarding this early return is not visible in
// this excerpt (presumably the no-defining-op case).
221 return simSignatures.at(v);
222 return llvm::TypeSwitch<Operation *, llvm::APInt>(op)
223 .Case<aig::AndInverterOp>([&](
auto op) {
// Gather the already-computed signatures of every input, then evaluate.
224 SmallVector<llvm::APInt> inputSigs;
225 for (
auto input : op.getInputs())
226 inputSigs.push_back(simSignatures.at(input));
227 return op.evaluate(inputSigs);
// Any other operation: fall back to the signature stored for `v`.
229 .Default([&](Operation *) {
232 return simSignatures.at(v);
/// Groups all values by identical simulation signature and stores the groups
/// in `equivCandidates`; also records the number of groups in the stats.
240void FunctionalReductionSolver::buildEquivalenceClasses() {
// Keyed by the full signature; MapVector keeps groups in first-seen order.
242 llvm::MapVector<llvm::APInt, SmallVector<Value>> sigGroups;
244 for (
auto value : allValues)
245 sigGroups[simSignatures.at(value)].push_back(value);
// Despite its name, `hash` binds the APInt signature key of the group.
248 for (
auto &[hash, members] : sigGroups) {
// NOTE(review): the statement guarded by this size check is not visible here
// (presumably groups with at most one member are skipped).
249 if (members.size() <= 1)
// The member list is moved out of the temporary map into the candidate list.
251 equivCandidates.push_back(std::move(members));
253 stats.numEquivClasses = equivCandidates.size();
255 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Built "
256 << equivCandidates.size()
257 <<
" equivalence candidates\n");
/// Verifies every candidate group against its first member and records the
/// members proven equivalent in `provenEquivalences`.
267void FunctionalReductionSolver::verifyCandidates() {
269 llvm::dbgs() <<
"FunctionalReduction: Starting SAT verification with "
270 << equivCandidates.size() <<
" equivalence classes\n");
272 for (
auto &members : equivCandidates) {
// The first member of the group serves as the representative.
275 auto representative = members.front();
// operator[] creates the map entry for this representative if it does not
// exist yet, even when no member ends up being proven equivalent.
276 auto &provenMembers = provenEquivalences[representative];
// Compare each remaining member against the representative.
278 for (
auto member :
llvm::ArrayRef<Value>(members).drop_front()) {
279 EquivResult result = verifyEquivalence(representative, member);
280 if (result == EquivResult::Proved) {
281 stats.numProvedEquiv++;
282 provenMembers.push_back(member);
283 }
else if (result == EquivResult::Disproved) {
284 stats.numDisprovedEquiv++;
// NOTE(review): the remaining branch(es) of this chain are not visible in this
// excerpt.
294 llvm::dbgs() <<
"FunctionalReduction: SAT verification complete. Proved "
295 << stats.numProvedEquiv <<
" equivalences\n");
/// For each proven equivalence set, creates a synth::ChoiceOp over the
/// representative and its members and redirects their uses to that choice.
302void FunctionalReductionSolver::mergeEquivalentNodes() {
// NOTE(review): the statement guarded by this emptiness check is not visible
// here (presumably an early return).
303 if (provenEquivalences.empty())
306 mlir::OpBuilder builder(module.getContext());
307 for (
auto &provenEquivSet : provenEquivalences) {
308 auto &[representative, members] = provenEquivSet;
// NOTE(review): `members.back()` below requires `members` to be non-empty,
// while verifyCandidates creates an entry for every candidate group. Confirm
// that empty member lists are skipped in the lines not shown in this excerpt.
// Operands of the choice: the representative first, then every proven member.
311 SmallVector<Value> operands;
312 operands.reserve(members.size() + 1);
313 operands.push_back(representative);
314 operands.append(members);
// The choice is inserted right after the last member.
// NOTE(review): presumably the last member is defined after all other
// operands — confirm against the ordering of the collected values.
315 builder.setInsertionPointAfterValue(members.back());
316 auto choice = synth::ChoiceOp::create(builder, representative.getLoc(),
317 representative.getType(), operands);
// Counts the representative together with its members.
318 stats.numMergedNodes += members.size() + 1;
// Redirect all users to the choice, except the choice itself, which must keep
// the original values as its operands.
319 representative.replaceAllUsesExcept(choice, choice);
320 for (
auto value : members)
321 value.replaceAllUsesExcept(choice, choice);
324 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Merged "
325 << stats.numMergedNodes <<
" nodes\n");
/// Entry point of the solver: runs the reduction on the module and yields the
/// collected statistics, or failure.
/// NOTE(review): several statements of this function (including the return
/// statements and some phase calls) are not visible in this excerpt; the
/// comments below describe only what is shown.
332mlir::FailureOr<FunctionalReductionSolver::Stats>
333FunctionalReductionSolver::run() {
335 llvm::dbgs() <<
"FunctionalReduction: Starting functional reduction with "
336 << numPatterns <<
" simulation patterns\n");
// Diagnostic text for a failed topological sort of the logic network; the
// check that triggers it is not visible here.
341 << "FunctionalReduction: Failed to topologically sort logic network";
// Nothing to do when no i1 values were collected.
347 if (allValues.empty()) {
348 LLVM_DEBUG(llvm::dbgs()
349 <<
"FunctionalReduction: No i1 values to process\n");
// Group values by simulation signature; stop early if no group has candidates.
356 buildEquivalenceClasses();
357 if (equivCandidates.empty()) {
358 LLVM_DEBUG(llvm::dbgs()
359 <<
"FunctionalReduction: No equivalence candidates found\n");
// Rewrite the IR for the proven equivalences.
367 mergeEquivalentNodes();
369 LLVM_DEBUG(llvm::dbgs() <<
"FunctionalReduction: Complete. Stats:\n"
370 <<
" Equivalence classes: " << stats.numEquivClasses
372 <<
" Proved: " << stats.numProvedEquiv <<
"\n"
373 <<
" Disproved: " << stats.numDisprovedEquiv <<
"\n"
374 <<
" Unknown (limit): " << stats.numUnknown <<
"\n"
375 <<
" Merged: " << stats.numMergedNodes <<
"\n");
/// Pass wrapper around FunctionalReductionSolver: validates the pattern-count
/// option, runs the solver on the operation and accumulates its statistics.
384struct FunctionalReductionPass
385 :
public circt::synth::impl::FunctionalReductionBase<
386 FunctionalReductionPass> {
387 using FunctionalReductionBase::FunctionalReductionBase;
// Adds the counters of one solver run to the pass-level counters.
388 void updateStats(
const FunctionalReductionSolver::Stats &stats) {
389 numEquivClasses += stats.numEquivClasses;
390 numProvedEquiv += stats.numProvedEquiv;
391 numDisprovedEquiv += stats.numDisprovedEquiv;
392 numUnknown += stats.numUnknown;
393 numMergedNodes += stats.numMergedNodes;
396 void runOnOperation()
override {
397 auto module = getOperation();
398 LLVM_DEBUG(llvm::dbgs() <<
"Running FunctionalReduction pass on "
399 << module.getName() <<
"\n");
// Reject zero and any value whose low six bits are set, i.e. anything that is
// not a positive multiple of 64.
401 if (numRandomPatterns == 0 || (numRandomPatterns & 63U) != 0) {
403 << "'num-random-patterns' must be a positive multiple of 64";
404 return signalPassFailure();
// NOTE(review): the final constructor argument is not visible in this excerpt.
407 FunctionalReductionSolver fcSolver(module, numRandomPatterns, seed,
409 auto stats = fcSolver.run();
// NOTE(review): the condition guarding this failure path is not visible here
// (presumably a failed solver run).
411 return signalPassFailure();
// When no nodes were merged, all analyses are reported as preserved.
413 if (stats->numMergedNodes == 0)
414 markAllAnalysesPreserved();
static constexpr llvm::StringLiteral kTestClassAttrName
static Block * getBodyBlock(FModuleLike mod)
RewritePatternSet pattern
LogicalResult topologicallySortLogicNetwork(mlir::Operation *op)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)