11 #include "mlir/IR/Threading.h"
12 #include "mlir/Pass/Pass.h"
19 #define GEN_PASS_DEF_MODULESUMMARY
20 #include "circt/Dialect/FIRRTL/Passes.h.inc"
25 using namespace circt;
26 using namespace firrtl;
29 std::function<size_t(Type)>
getWidth = [&](Type type) ->
size_t {
30 return TypeSwitch<Type, size_t>(type)
31 .Case<BundleType>([&](BundleType bundle) ->
size_t {
33 for (
auto &elt : bundle) {
39 .Case<FEnumType>([&](FEnumType fenum) ->
size_t {
41 for (
auto &elt : fenum) {
45 return width + llvm::Log2_32_Ceil(fenum.getNumElements());
47 .Case<FVectorType>([&](
auto vector) ->
size_t {
48 auto w =
getWidth(vector.getElementType());
49 return w * vector.getNumElements();
51 .Case<IntType>([&](
IntType iType) {
57 .Case<ClockType, ResetType, AsyncResetType>([](Type) {
return 1; })
58 .Default([&](
auto t) {
return 0; });
64 struct ModuleSummaryPass
65 :
public circt::firrtl::impl::ModuleSummaryBase<ModuleSummaryPass> {
67 SmallVector<size_t> portSizes;
70 return portSizes == rhs.portSizes && opcount == rhs.opcount;
74 size_t countOps(FModuleOp mod) {
76 mod.walk([&](Operation *op) { retval += 1; });
80 SmallVector<size_t> portSig(FModuleOp mod) {
81 SmallVector<size_t> ports;
82 for (
auto p : mod.getPortTypes())
83 ports.push_back(
knownWidths(cast<TypeAttr>(p).getValue()));
87 void runOnOperation()
override;
93 inline llvm::hash_code
hash_value(
const ModuleSummaryPass::KeyTy &element) {
95 llvm::hash_combine_range(element.portSizes.begin(),
96 element.portSizes.end()),
105 using KeyTy = ModuleSummaryPass::KeyTy;
115 void ModuleSummaryPass::runOnOperation() {
116 auto circuit = getOperation();
118 using MapTy = DenseMap<KeyTy, SmallVector<FModuleOp>>;
121 std::mutex dataMutex;
123 mlir::parallelForEach(circuit.getContext(),
124 circuit.getBodyBlock()->getOps<FModuleOp>(),
126 auto p = portSig(mod);
127 auto n = countOps(mod);
128 const std::lock_guard<std::mutex> lock(dataMutex);
129 data[{p, n}].push_back(mod);
132 SmallVector<MapTy::value_type> sortedData(
data.begin(),
data.end());
133 std::sort(sortedData.begin(), sortedData.end(),
134 [](
const MapTy::value_type &lhs,
const MapTy::value_type &rhs) {
135 return std::get<0>(lhs).opcount * std::get<1>(lhs).size() *
136 std::get<1>(lhs).size() >
137 std::get<0>(rhs).opcount * std::get<1>(rhs).size() *
138 std::get<1>(rhs).size();
140 llvm::errs() <<
"cost, opcount, portcount, modcount, portBits, examplename\n";
141 for (
auto &p : sortedData)
142 if (p.second.size() > 1) {
143 llvm::errs() << (p.first.opcount * p.second.size() * p.second.size())
144 <<
"," << p.first.opcount <<
"," << p.first.portSizes.size()
145 <<
"," << p.second.size() <<
","
146 << std::accumulate(p.first.portSizes.begin(),
147 p.first.portSizes.end(), 0)
148 <<
"," << p.second[0].getName() <<
"\n";
150 markAllAnalysesPreserved();
154 return std::make_unique<ModuleSummaryPass>();
bool operator==(const ResetDomain &a, const ResetDomain &b)
static size_t knownWidths(Type type)
This is the common base class between SIntType and UIntType.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
uint64_t getWidth(Type t)
std::unique_ptr< mlir::Pass > createModuleSummaryPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
llvm::hash_code hash_value(const T &e)
static bool isEqual(const KeyTy &lhs, const KeyTy &rhs)
static KeyTy getEmptyKey()
static KeyTy getTombstoneKey()
ModuleSummaryPass::KeyTy KeyTy
static unsigned getHashValue(const KeyTy &val)