CIRCT 23.0.0git
Loading...
Searching...
No Matches
ModuleSummary.cpp
Go to the documentation of this file.
1//===- ModuleSummary.cpp ----------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/IR/Threading.h"
12#include "mlir/Pass/Pass.h"
13
14#include <mutex>
15#include <numeric>
16
17namespace circt {
18namespace firrtl {
19#define GEN_PASS_DEF_MODULESUMMARY
20#include "circt/Dialect/FIRRTL/Passes.h.inc"
21} // namespace firrtl
22} // namespace circt
23
24using namespace mlir;
25using namespace circt;
26using namespace firrtl;
27
28static size_t knownWidths(Type type) {
29 std::function<size_t(Type)> getWidth = [&](Type type) -> size_t {
30 return TypeSwitch<Type, size_t>(type)
31 .Case<BundleType>([&](BundleType bundle) -> size_t {
32 size_t width = 0;
33 for (auto &elt : bundle) {
34 auto w = getWidth(elt.type);
35 width += w;
36 }
37 return width;
38 })
39 .Case<FEnumType>([&](FEnumType fenum) -> size_t {
40 size_t width = 0;
41 for (auto &elt : fenum) {
42 auto w = getWidth(elt.type);
43 width = std::max(width, w);
44 }
45 return width + fenum.getTagWidth();
46 })
47 .Case<FVectorType>([&](auto vector) -> size_t {
48 auto w = getWidth(vector.getElementType());
49 return w * vector.getNumElements();
50 })
51 .Case<IntType>([&](IntType iType) {
52 auto v = iType.getWidth();
53 if (v)
54 return *v;
55 return 0;
56 })
57 .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
58 .Default([&](auto t) { return 0; });
59 };
60 return getWidth(type);
61}
62
63namespace {
64struct ModuleSummaryPass
65 : public circt::firrtl::impl::ModuleSummaryBase<ModuleSummaryPass> {
66 struct KeyTy {
67 SmallVector<size_t> portSizes;
68 size_t opcount;
69 bool operator==(const KeyTy &rhs) const {
70 return portSizes == rhs.portSizes && opcount == rhs.opcount;
71 }
72 };
73
74 size_t countOps(FModuleOp mod) {
75 size_t retval = 0;
76 mod.walk([&](Operation *op) { retval += 1; });
77 return retval;
78 }
79
80 SmallVector<size_t> portSig(FModuleOp mod) {
81 SmallVector<size_t> ports;
82 for (auto p : mod.getPortTypes())
83 ports.push_back(knownWidths(cast<TypeAttr>(p).getValue()));
84 return ports;
85 }
86
87 void runOnOperation() override;
88};
89} // namespace
90
91namespace mlir {
92// NOLINTNEXTLINE(readability-identifier-naming)
93inline llvm::hash_code hash_value(const ModuleSummaryPass::KeyTy &element) {
94 return llvm::hash_combine(element.portSizes.size(),
95 llvm::hash_combine_range(element.portSizes.begin(),
96 element.portSizes.end()),
97 element.opcount);
98}
99} // namespace mlir
100
101namespace llvm {
102// Type hash just like pointers.
103template <>
104struct DenseMapInfo<ModuleSummaryPass::KeyTy> {
105 using KeyTy = ModuleSummaryPass::KeyTy;
106 static unsigned getHashValue(const KeyTy &val) {
107 return mlir::hash_value(val);
108 }
109 static bool isEqual(const KeyTy &lhs, const KeyTy &rhs) { return lhs == rhs; }
110};
111} // namespace llvm
112
113void ModuleSummaryPass::runOnOperation() {
114 auto circuit = getOperation();
115
116 using MapTy = DenseMap<KeyTy, SmallVector<FModuleOp>>;
117 MapTy data;
118
119 std::mutex dataMutex; // protects data
120
121 mlir::parallelForEach(circuit.getContext(),
122 circuit.getBodyBlock()->getOps<FModuleOp>(),
123 [&](auto mod) {
124 auto p = portSig(mod);
125 auto n = countOps(mod);
126 const std::lock_guard<std::mutex> lock(dataMutex);
127 data[{p, n}].push_back(mod);
128 });
129
130 SmallVector<MapTy::value_type> sortedData(data.begin(), data.end());
131 std::sort(sortedData.begin(), sortedData.end(),
132 [](const MapTy::value_type &lhs, const MapTy::value_type &rhs) {
133 return std::get<0>(lhs).opcount * std::get<1>(lhs).size() *
134 std::get<1>(lhs).size() >
135 std::get<0>(rhs).opcount * std::get<1>(rhs).size() *
136 std::get<1>(rhs).size();
137 });
138 llvm::errs() << "cost, opcount, portcount, modcount, portBits, examplename\n";
139 for (auto &p : sortedData)
140 if (p.second.size() > 1) {
141 llvm::errs() << (p.first.opcount * p.second.size() * p.second.size())
142 << "," << p.first.opcount << "," << p.first.portSizes.size()
143 << "," << p.second.size() << ","
144 << std::accumulate(p.first.portSizes.begin(),
145 p.first.portSizes.end(), 0)
146 << "," << p.second[0].getName() << "\n";
147 }
148 markAllAnalysesPreserved();
149}
static size_t knownWidths(Type type)
This is the common base class between SIntType and UIntType.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:36
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::hash_code hash_value(const T &e)
static bool isEqual(const KeyTy &lhs, const KeyTy &rhs)
static unsigned getHashValue(const KeyTy &val)