CIRCT 20.0.0git
Loading...
Searching...
No Matches
ModuleSummary.cpp
Go to the documentation of this file.
1//===- ModuleSummary.cpp ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Threading.h"
12#include "mlir/Pass/Pass.h"
13
14#include <mutex>
15#include <numeric>
16
17namespace circt {
18namespace firrtl {
19#define GEN_PASS_DEF_MODULESUMMARY
20#include "circt/Dialect/FIRRTL/Passes.h.inc"
21} // namespace firrtl
22} // namespace circt
23
24using namespace mlir;
25using namespace circt;
26using namespace firrtl;
27
28static size_t knownWidths(Type type) {
29 std::function<size_t(Type)> getWidth = [&](Type type) -> size_t {
30 return TypeSwitch<Type, size_t>(type)
31 .Case<BundleType>([&](BundleType bundle) -> size_t {
32 size_t width = 0;
33 for (auto &elt : bundle) {
34 auto w = getWidth(elt.type);
35 width += w;
36 }
37 return width;
38 })
39 .Case<FEnumType>([&](FEnumType fenum) -> size_t {
40 size_t width = 0;
41 for (auto &elt : fenum) {
42 auto w = getWidth(elt.type);
43 width = std::max(width, w);
44 }
45 return width + llvm::Log2_32_Ceil(fenum.getNumElements());
46 })
47 .Case<FVectorType>([&](auto vector) -> size_t {
48 auto w = getWidth(vector.getElementType());
49 return w * vector.getNumElements();
50 })
51 .Case<IntType>([&](IntType iType) {
52 auto v = iType.getWidth();
53 if (v)
54 return *v;
55 return 0;
56 })
57 .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
58 .Default([&](auto t) { return 0; });
59 };
60 return getWidth(type);
61}
62
63namespace {
64struct ModuleSummaryPass
65 : public circt::firrtl::impl::ModuleSummaryBase<ModuleSummaryPass> {
66 struct KeyTy {
67 SmallVector<size_t> portSizes;
68 size_t opcount;
69 bool operator==(const KeyTy &rhs) const {
70 return portSizes == rhs.portSizes && opcount == rhs.opcount;
71 }
72 };
73
74 size_t countOps(FModuleOp mod) {
75 size_t retval = 0;
76 mod.walk([&](Operation *op) { retval += 1; });
77 return retval;
78 }
79
80 SmallVector<size_t> portSig(FModuleOp mod) {
81 SmallVector<size_t> ports;
82 for (auto p : mod.getPortTypes())
83 ports.push_back(knownWidths(cast<TypeAttr>(p).getValue()));
84 return ports;
85 }
86
87 void runOnOperation() override;
88};
89} // namespace
90
91namespace mlir {
92// NOLINTNEXTLINE(readability-identifier-naming)
93inline llvm::hash_code hash_value(const ModuleSummaryPass::KeyTy &element) {
94 return llvm::hash_combine(element.portSizes.size(),
95 llvm::hash_combine_range(element.portSizes.begin(),
96 element.portSizes.end()),
97 element.opcount);
98}
99} // namespace mlir
100
101namespace llvm {
102// Type hash just like pointers.
103template <>
104struct DenseMapInfo<ModuleSummaryPass::KeyTy> {
105 using KeyTy = ModuleSummaryPass::KeyTy;
106 static KeyTy getEmptyKey() { return {{}, ~0ULL}; }
107 static KeyTy getTombstoneKey() { return {{}, ~0ULL - 1}; }
108 static unsigned getHashValue(const KeyTy &val) {
109 return mlir::hash_value(val);
110 }
111 static bool isEqual(const KeyTy &lhs, const KeyTy &rhs) { return lhs == rhs; }
112};
113} // namespace llvm
114
115void ModuleSummaryPass::runOnOperation() {
116 auto circuit = getOperation();
117
118 using MapTy = DenseMap<KeyTy, SmallVector<FModuleOp>>;
119 MapTy data;
120
121 std::mutex dataMutex; // protects data
122
123 mlir::parallelForEach(circuit.getContext(),
124 circuit.getBodyBlock()->getOps<FModuleOp>(),
125 [&](auto mod) {
126 auto p = portSig(mod);
127 auto n = countOps(mod);
128 const std::lock_guard<std::mutex> lock(dataMutex);
129 data[{p, n}].push_back(mod);
130 });
131
132 SmallVector<MapTy::value_type> sortedData(data.begin(), data.end());
133 std::sort(sortedData.begin(), sortedData.end(),
134 [](const MapTy::value_type &lhs, const MapTy::value_type &rhs) {
135 return std::get<0>(lhs).opcount * std::get<1>(lhs).size() *
136 std::get<1>(lhs).size() >
137 std::get<0>(rhs).opcount * std::get<1>(rhs).size() *
138 std::get<1>(rhs).size();
139 });
140 llvm::errs() << "cost, opcount, portcount, modcount, portBits, examplename\n";
141 for (auto &p : sortedData)
142 if (p.second.size() > 1) {
143 llvm::errs() << (p.first.opcount * p.second.size() * p.second.size())
144 << "," << p.first.opcount << "," << p.first.portSizes.size()
145 << "," << p.second.size() << ","
146 << std::accumulate(p.first.portSizes.begin(),
147 p.first.portSizes.end(), 0)
148 << "," << p.second[0].getName() << "\n";
149 }
150 markAllAnalysesPreserved();
151}
152
153std::unique_ptr<Pass> firrtl::createModuleSummaryPass() {
154 return std::make_unique<ModuleSummaryPass>();
155}
static size_t knownWidths(Type type)
This is the common base class between SIntType and UIntType.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
std::unique_ptr< mlir::Pass > createModuleSummaryPass()
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::hash_code hash_value(const T &e)
static bool isEqual(const KeyTy &lhs, const KeyTy &rhs)
static unsigned getHashValue(const KeyTy &val)