CIRCT  19.0.0git
ModuleSummary.cpp
Go to the documentation of this file.
1 //===- ModuleSummary.cpp ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "PassDetails.h"
10 #include "mlir/IR/Threading.h"
11 
12 #include <mutex>
13 #include <numeric>
14 
15 using namespace mlir;
16 using namespace circt;
17 using namespace firrtl;
18 
19 static size_t knownWidths(Type type) {
20  std::function<size_t(Type)> getWidth = [&](Type type) -> size_t {
21  return TypeSwitch<Type, size_t>(type)
22  .Case<BundleType>([&](BundleType bundle) -> size_t {
23  size_t width = 0;
24  for (auto &elt : bundle) {
25  auto w = getWidth(elt.type);
26  width += w;
27  }
28  return width;
29  })
30  .Case<FEnumType>([&](FEnumType fenum) -> size_t {
31  size_t width = 0;
32  for (auto &elt : fenum) {
33  auto w = getWidth(elt.type);
34  width = std::max(width, w);
35  }
36  return width + llvm::Log2_32_Ceil(fenum.getNumElements());
37  })
38  .Case<FVectorType>([&](auto vector) -> size_t {
39  auto w = getWidth(vector.getElementType());
40  return w * vector.getNumElements();
41  })
42  .Case<IntType>([&](IntType iType) {
43  auto v = iType.getWidth();
44  if (v)
45  return *v;
46  return 0;
47  })
48  .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
49  .Default([&](auto t) { return 0; });
50  };
51  return getWidth(type);
52 }
53 
54 namespace {
55 struct ModuleSummaryPass : public ModuleSummaryBase<ModuleSummaryPass> {
56  struct KeyTy {
57  SmallVector<size_t> portSizes;
58  size_t opcount;
59  bool operator==(const KeyTy &rhs) const {
60  return portSizes == rhs.portSizes && opcount == rhs.opcount;
61  }
62  };
63 
64  size_t countOps(FModuleOp mod) {
65  size_t retval = 0;
66  mod.walk([&](Operation *op) { retval += 1; });
67  return retval;
68  }
69 
70  SmallVector<size_t> portSig(FModuleOp mod) {
71  SmallVector<size_t> ports;
72  for (auto p : mod.getPortTypes())
73  ports.push_back(knownWidths(cast<TypeAttr>(p).getValue()));
74  return ports;
75  }
76 
77  void runOnOperation() override;
78 };
79 } // namespace
80 
81 namespace mlir {
82 // NOLINTNEXTLINE(readability-identifier-naming)
83 inline llvm::hash_code hash_value(const ModuleSummaryPass::KeyTy &element) {
84  return llvm::hash_combine(element.portSizes.size(),
85  llvm::hash_combine_range(element.portSizes.begin(),
86  element.portSizes.end()),
87  element.opcount);
88 }
89 } // namespace mlir
90 
91 namespace llvm {
92 // Type hash just like pointers.
93 template <>
94 struct DenseMapInfo<ModuleSummaryPass::KeyTy> {
95  using KeyTy = ModuleSummaryPass::KeyTy;
96  static KeyTy getEmptyKey() { return {{}, ~0ULL}; }
97  static KeyTy getTombstoneKey() { return {{}, ~0ULL - 1}; }
98  static unsigned getHashValue(const KeyTy &val) {
99  return mlir::hash_value(val);
100  }
101  static bool isEqual(const KeyTy &lhs, const KeyTy &rhs) { return lhs == rhs; }
102 };
103 } // namespace llvm
104 
105 void ModuleSummaryPass::runOnOperation() {
106  auto circuit = getOperation();
107 
108  using MapTy = DenseMap<KeyTy, SmallVector<FModuleOp>>;
109  MapTy data;
110 
111  std::mutex dataMutex; // protects data
112 
113  mlir::parallelForEach(circuit.getContext(),
114  circuit.getBodyBlock()->getOps<FModuleOp>(),
115  [&](auto mod) {
116  auto p = portSig(mod);
117  auto n = countOps(mod);
118  const std::lock_guard<std::mutex> lock(dataMutex);
119  data[{p, n}].push_back(mod);
120  });
121 
122  SmallVector<MapTy::value_type> sortedData(data.begin(), data.end());
123  std::sort(sortedData.begin(), sortedData.end(),
124  [](const MapTy::value_type &lhs, const MapTy::value_type &rhs) {
125  return std::get<0>(lhs).opcount * std::get<1>(lhs).size() *
126  std::get<1>(lhs).size() >
127  std::get<0>(rhs).opcount * std::get<1>(rhs).size() *
128  std::get<1>(rhs).size();
129  });
130  llvm::errs() << "cost, opcount, portcount, modcount, portBits, examplename\n";
131  for (auto &p : sortedData)
132  if (p.second.size() > 1) {
133  llvm::errs() << (p.first.opcount * p.second.size() * p.second.size())
134  << "," << p.first.opcount << "," << p.first.portSizes.size()
135  << "," << p.second.size() << ","
136  << std::accumulate(p.first.portSizes.begin(),
137  p.first.portSizes.end(), 0)
138  << "," << p.second[0].getName() << "\n";
139  }
140  markAllAnalysesPreserved();
141 }
142 
143 std::unique_ptr<Pass> firrtl::createModuleSummaryPass() {
144  return std::make_unique<ModuleSummaryPass>();
145 }
int32_t width
Definition: FIRRTL.cpp:36
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:71
static size_t knownWidths(Type type)
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:294
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:273
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
std::unique_ptr< mlir::Pass > createModuleSummaryPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::hash_code hash_value(const T &e)
static bool isEqual(const KeyTy &lhs, const KeyTy &rhs)
static unsigned getHashValue(const KeyTy &val)