CIRCT  20.0.0git
ModuleSummary.cpp
Go to the documentation of this file.
1 //===- ModuleSummary.cpp ----------------------------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "mlir/IR/Threading.h"
12 #include "mlir/Pass/Pass.h"
13 
14 #include <mutex>
15 #include <numeric>
16 
17 namespace circt {
18 namespace firrtl {
19 #define GEN_PASS_DEF_MODULESUMMARY
20 #include "circt/Dialect/FIRRTL/Passes.h.inc"
21 } // namespace firrtl
22 } // namespace circt
23 
24 using namespace mlir;
25 using namespace circt;
26 using namespace firrtl;
27 
28 static size_t knownWidths(Type type) {
29  std::function<size_t(Type)> getWidth = [&](Type type) -> size_t {
30  return TypeSwitch<Type, size_t>(type)
31  .Case<BundleType>([&](BundleType bundle) -> size_t {
32  size_t width = 0;
33  for (auto &elt : bundle) {
34  auto w = getWidth(elt.type);
35  width += w;
36  }
37  return width;
38  })
39  .Case<FEnumType>([&](FEnumType fenum) -> size_t {
40  size_t width = 0;
41  for (auto &elt : fenum) {
42  auto w = getWidth(elt.type);
43  width = std::max(width, w);
44  }
45  return width + llvm::Log2_32_Ceil(fenum.getNumElements());
46  })
47  .Case<FVectorType>([&](auto vector) -> size_t {
48  auto w = getWidth(vector.getElementType());
49  return w * vector.getNumElements();
50  })
51  .Case<IntType>([&](IntType iType) {
52  auto v = iType.getWidth();
53  if (v)
54  return *v;
55  return 0;
56  })
57  .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
58  .Default([&](auto t) { return 0; });
59  };
60  return getWidth(type);
61 }
62 
63 namespace {
64 struct ModuleSummaryPass
65  : public circt::firrtl::impl::ModuleSummaryBase<ModuleSummaryPass> {
66  struct KeyTy {
67  SmallVector<size_t> portSizes;
68  size_t opcount;
69  bool operator==(const KeyTy &rhs) const {
70  return portSizes == rhs.portSizes && opcount == rhs.opcount;
71  }
72  };
73 
74  size_t countOps(FModuleOp mod) {
75  size_t retval = 0;
76  mod.walk([&](Operation *op) { retval += 1; });
77  return retval;
78  }
79 
80  SmallVector<size_t> portSig(FModuleOp mod) {
81  SmallVector<size_t> ports;
82  for (auto p : mod.getPortTypes())
83  ports.push_back(knownWidths(cast<TypeAttr>(p).getValue()));
84  return ports;
85  }
86 
87  void runOnOperation() override;
88 };
89 } // namespace
90 
91 namespace mlir {
92 // NOLINTNEXTLINE(readability-identifier-naming)
93 inline llvm::hash_code hash_value(const ModuleSummaryPass::KeyTy &element) {
94  return llvm::hash_combine(element.portSizes.size(),
95  llvm::hash_combine_range(element.portSizes.begin(),
96  element.portSizes.end()),
97  element.opcount);
98 }
99 } // namespace mlir
100 
101 namespace llvm {
102 // Type hash just like pointers.
103 template <>
104 struct DenseMapInfo<ModuleSummaryPass::KeyTy> {
105  using KeyTy = ModuleSummaryPass::KeyTy;
106  static KeyTy getEmptyKey() { return {{}, ~0ULL}; }
107  static KeyTy getTombstoneKey() { return {{}, ~0ULL - 1}; }
108  static unsigned getHashValue(const KeyTy &val) {
109  return mlir::hash_value(val);
110  }
111  static bool isEqual(const KeyTy &lhs, const KeyTy &rhs) { return lhs == rhs; }
112 };
113 } // namespace llvm
114 
115 void ModuleSummaryPass::runOnOperation() {
116  auto circuit = getOperation();
117 
118  using MapTy = DenseMap<KeyTy, SmallVector<FModuleOp>>;
119  MapTy data;
120 
121  std::mutex dataMutex; // protects data
122 
123  mlir::parallelForEach(circuit.getContext(),
124  circuit.getBodyBlock()->getOps<FModuleOp>(),
125  [&](auto mod) {
126  auto p = portSig(mod);
127  auto n = countOps(mod);
128  const std::lock_guard<std::mutex> lock(dataMutex);
129  data[{p, n}].push_back(mod);
130  });
131 
132  SmallVector<MapTy::value_type> sortedData(data.begin(), data.end());
133  std::sort(sortedData.begin(), sortedData.end(),
134  [](const MapTy::value_type &lhs, const MapTy::value_type &rhs) {
135  return std::get<0>(lhs).opcount * std::get<1>(lhs).size() *
136  std::get<1>(lhs).size() >
137  std::get<0>(rhs).opcount * std::get<1>(rhs).size() *
138  std::get<1>(rhs).size();
139  });
140  llvm::errs() << "cost, opcount, portcount, modcount, portBits, examplename\n";
141  for (auto &p : sortedData)
142  if (p.second.size() > 1) {
143  llvm::errs() << (p.first.opcount * p.second.size() * p.second.size())
144  << "," << p.first.opcount << "," << p.first.portSizes.size()
145  << "," << p.second.size() << ","
146  << std::accumulate(p.first.portSizes.begin(),
147  p.first.portSizes.end(), 0)
148  << "," << p.second[0].getName() << "\n";
149  }
150  markAllAnalysesPreserved();
151 }
152 
153 std::unique_ptr<Pass> firrtl::createModuleSummaryPass() {
154  return std::make_unique<ModuleSummaryPass>();
155 }
int32_t width
Definition: FIRRTL.cpp:36
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:78
static size_t knownWidths(Type type)
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:296
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:275
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
std::unique_ptr< mlir::Pass > createModuleSummaryPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition: Utils.h:32
llvm::hash_code hash_value(const T &e)
static bool isEqual(const KeyTy &lhs, const KeyTy &rhs)
static unsigned getHashValue(const KeyTy &val)