CIRCT  20.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements FIRRTL module deduplication.
10 //
11 //===----------------------------------------------------------------------===//
12 
23 #include "circt/Support/LLVM.h"
24 #include "mlir/IR/IRMapping.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Threading.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/DenseMapInfo.h"
31 #include "llvm/ADT/DepthFirstIterator.h"
32 #include "llvm/ADT/Hashing.h"
33 #include "llvm/ADT/PostOrderIterator.h"
34 #include "llvm/ADT/SmallPtrSet.h"
35 #include "llvm/ADT/TypeSwitch.h"
36 #include "llvm/Support/Format.h"
37 #include "llvm/Support/SHA256.h"
38 
39 namespace circt {
40 namespace firrtl {
41 #define GEN_PASS_DEF_DEDUP
42 #include "circt/Dialect/FIRRTL/Passes.h.inc"
43 } // namespace firrtl
44 } // namespace circt
45 
46 using namespace circt;
47 using namespace firrtl;
48 using hw::InnerRefAttr;
49 
50 //===----------------------------------------------------------------------===//
51 // Utility function for classifying a Symbol's dedup-ability.
52 //===----------------------------------------------------------------------===//
53 
54 /// Returns true if the module can be removed.
55 static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
56  // If the symbol is not private, it cannot be removed.
57  if (!symbol.isPrivate())
58  return false;
59  // Classes may be referenced in object types, so can not normally be removed
60  // if we can't find any symbol uses. Since we know that dedup will update the
61  // types of instances appropriately, we can ignore that and return true here.
62  if (isa<ClassLike>(*symbol))
63  return true;
64  // If module can not be removed even if no uses can be found, we can not
65  // delete it. The implication is that there are hidden symbol uses that dedup
66  // will not properly update.
67  if (!symbol.canDiscardOnUseEmpty())
68  return false;
69  // The module can be deleted.
70  return true;
71 }
72 
73 //===----------------------------------------------------------------------===//
74 // Hashing
75 //===----------------------------------------------------------------------===//
76 
77 llvm::raw_ostream &printHex(llvm::raw_ostream &stream,
78  ArrayRef<uint8_t> bytes) {
79  // Print the hash on a single line.
80  return stream << format_bytes(bytes, std::nullopt, 32) << "\n";
81 }
82 
83 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
84  return printHex(stream, data.result());
85 }
86 
87 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, std::string data) {
88  ArrayRef<uint8_t> bytes(reinterpret_cast<const uint8_t *>(data.c_str()),
89  data.length());
90  return printHex(stream, bytes);
91 }
92 
93 // This struct contains information to determine module module uniqueness. A
94 // first element is a structural hash of the module, and the second element is
95 // an array which tracks module names encountered in the walk. Since module
96 // names could be replaced during dedup, it's necessary to keep names up-to-date
97 // before actually combining them into structural hashes.
98 struct ModuleInfo {
99  // SHA256 hash.
100  std::array<uint8_t, 32> structuralHash;
101  // Module names referred by instance op in the module.
102  std::vector<StringAttr> referredModuleNames;
103 };
104 
105 static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
106  return lhs.structuralHash == rhs.structuralHash &&
107  lhs.referredModuleNames == rhs.referredModuleNames;
108 }
109 
110 /// This struct contains constant string attributes shared across different
111 /// threads.
113  explicit StructuralHasherSharedConstants(MLIRContext *context) {
114  portTypesAttr = StringAttr::get(context, "portTypes");
115  moduleNameAttr = StringAttr::get(context, "moduleName");
116  innerSymAttr = StringAttr::get(context, "inner_sym");
117  portSymbolsAttr = StringAttr::get(context, "portSymbols");
118  portNamesAttr = StringAttr::get(context, "portNames");
119  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
120  nonessentialAttributes.insert(StringAttr::get(context, "convention"));
121  nonessentialAttributes.insert(StringAttr::get(context, "name"));
122  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
123  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
124  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
125  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
126  nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
127  };
128 
129  // This is a cached "portTypes" string attr.
130  StringAttr portTypesAttr;
131 
132  // This is a cached "moduleName" string attr.
133  StringAttr moduleNameAttr;
134 
135  // This is a cached "inner_sym" string attr.
136  StringAttr innerSymAttr;
137 
138  // This is a cached "portSymbols" string attr.
139  StringAttr portSymbolsAttr;
140 
141  // This is a cached "portNames" string attr.
142  StringAttr portNamesAttr;
143 
144  // This is a set of every attribute we should ignore.
145  DenseSet<Attribute> nonessentialAttributes;
146 };
147 
150  : constants(constants) {}
151 
152  ModuleInfo getModuleInfo(FModuleLike module) {
153  update(&(*module));
154  return {sha.final(), std::move(referredModuleNames)};
155  }
156 
157 private:
158  // Get the identifier for an object. The identifier is assigned on first use.
159  unsigned getID(void *object) {
160  auto [it, inserted] = idTable.try_emplace(object, nextID);
161  if (inserted)
162  ++nextID;
163  return it->second;
164  }
165 
166  // Get the identifier for an IR object. Free the ID, too.
167  unsigned finalizeID(void *object) {
168  auto it = idTable.find(object);
169  if (it == idTable.end())
170  return nextID++;
171  auto id = it->second;
172  idTable.erase(it);
173  return id;
174  }
175 
176  unsigned getInnerSymID(StringAttr name) {
177  auto [it, inserted] = innerSymIDTable.try_emplace(name, nextInnerSymID);
178  if (inserted)
179  ++nextInnerSymID;
180  return it->second;
181  }
182 
183  void update(OpOperand &operand) {
184  auto value = operand.get();
185  if (auto result = dyn_cast<OpResult>(value)) {
186  auto *op = result.getOwner();
187  update(getID(op));
188  update(result.getResultNumber());
189  return;
190  }
191  if (auto argument = dyn_cast<BlockArgument>(value)) {
192  auto *block = argument.getOwner();
193  update(getID(block));
194  update(argument.getArgNumber());
195  return;
196  }
197  llvm_unreachable("Unknown value type");
198  }
199 
200  void update(const void *pointer) {
201  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
202  sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
203  }
204 
205  void update(size_t value) {
206  auto *addr = reinterpret_cast<const uint8_t *>(&value);
207  sha.update(ArrayRef<uint8_t>(addr, sizeof value));
208  }
209 
210  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
211 
212  // NOLINTNEXTLINE(misc-no-recursion)
213  void update(BundleType type) {
214  update(type.getTypeID());
215  for (auto &element : type.getElements()) {
216  update(element.isFlip);
217  update(element.type);
218  }
219  }
220 
221  // NOLINTNEXTLINE(misc-no-recursion)
222  void update(Type type) {
223  if (auto bundle = type_dyn_cast<BundleType>(type))
224  return update(bundle);
225  update(type.getAsOpaquePointer());
226  }
227 
228  void update(OpResult result) {
229  // Like instance ops, don't use object ops' result types since they might be
230  // replaced by dedup. Record the class names and lazily combine their hashes
231  // using the same mechanism as instances and modules.
232  if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
233  referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
234  return;
235  }
236 
237  update(result.getType());
238  }
239 
240  /// Hash the top level attribute dictionary of the operation. This function
241  /// has special handling for inner symbols, ports, and referenced modules.
242  void update(Operation *op, DictionaryAttr dict) {
243  for (auto namedAttr : dict) {
244  auto name = namedAttr.getName();
245  auto value = namedAttr.getValue();
246  // Skip names and annotations, except in certain cases.
247  // Names of ports are load bearing for classes, so we do hash those.
248  bool isClassPortNames =
249  isa<ClassLike>(op) && name == constants.portNamesAttr;
250  if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
251  continue;
252 
253  // Hash the attribute name (an interned pointer).
254  update(name.getAsOpaquePointer());
255 
256  // Hash the port types.
257  if (name == constants.portTypesAttr) {
258  auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
259  for (auto type : portTypes)
260  update(type);
261  continue;
262  }
263 
264  // Special case the InnerSymbols to ignore the symbol names.
265  if (name == constants.portSymbolsAttr) {
266  if (op->getNumRegions() != 1)
267  continue;
268  auto &region = op->getRegion(0);
269  if (region.getBlocks().empty())
270  continue;
271  for (auto sym : cast<ArrayAttr>(value).getAsRange<hw::InnerSymAttr>()) {
272  for (auto property : sym) {
273  update(property.getFieldID());
274  update(getInnerSymID(property.getName()));
275  }
276  }
277  continue;
278  }
279 
280  if (name == constants.innerSymAttr) {
281  auto innerSym = cast<hw::InnerSymAttr>(value);
282  for (auto property : innerSym) {
283  update(property.getFieldID());
284  update(getInnerSymID(property.getName()));
285  }
286  continue;
287  }
288 
289  // For instance op, don't use `moduleName` attributes since they might be
290  // replaced by dedup. Record the names and lazily combine their hashes.
291  // It is assumed that module names are hashed only through instance ops;
292  // it could cause suboptimal results if there was other operation that
293  // refers to module names through essential attributes.
294  if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
295  referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
296  continue;
297  }
298 
299  // TODO: properly handle DistinctAttr, including its use in paths.
300  // See https://github.com/llvm/circt/issues/6583.
301  if (isa<DistinctAttr>(value))
302  continue;
303 
304  // If this is an symbol reference, we need to perform name erasure.
305  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value))
306  update(getInnerSymID(innerRef.getName()));
307  else
308  update(value.getAsOpaquePointer());
309  }
310  }
311 
312  void update(mlir::OperationName name) {
313  // Operation names are interned.
314  update(name.getAsOpaquePointer());
315  }
316 
317  // NOLINTNEXTLINE(misc-no-recursion)
318  void update(Block *block) {
319  for (auto &op : llvm::reverse(*block))
320  update(&op);
321  for (auto type : block->getArgumentTypes())
322  update(type);
323  update(finalizeID(block));
324  update(position);
325  ++position;
326  }
327 
328  // NOLINTNEXTLINE(misc-no-recursion)
329  void update(Region *region) {
330  for (auto &block : llvm::reverse(region->getBlocks()))
331  update(&block);
332  update(position);
333  ++position;
334  }
335 
336  // NOLINTNEXTLINE(misc-no-recursion)
337  void update(Operation *op) {
338  // Hash the regions. We need to make sure an empty region doesn't hash the
339  // same as no region, so we include the number of regions.
340  update(op->getNumRegions());
341  for (auto &region : reverse(op->getRegions()))
342  update(&region);
343 
344  update(op->getName());
345 
346  // Record the uses for later hashing.
347  for (auto &operand : op->getOpOperands())
348  update(operand);
349 
350  // This happens after the numbering above, as it uses blockarg numbering
351  // for inner symbols.
352  update(op, op->getAttrDictionary());
353 
354  // Record any op results (types).
355  for (auto result : op->getResults())
356  update(result);
357 
358  // Incorporate the hash of uses we have already built.
359  update(finalizeID(op));
360  update(position);
361  ++position;
362  }
363 
364  // A map from an operation/block, to its identifier.
365  DenseMap<void *, unsigned> idTable;
366  unsigned nextID = 0;
367 
368  // A map from an inner symbol, to its identifier.
369  DenseMap<StringAttr, unsigned> innerSymIDTable;
370  unsigned nextInnerSymID = 0;
371 
372  // This keeps track of module names in the order of the appearance.
373  std::vector<StringAttr> referredModuleNames;
374 
375  // String constants.
377 
378  // This is the actual running hash calculation. This is a stateful element
379  // that should be reinitialized after each hash is produced.
380  llvm::SHA256 sha;
381 
382  // The index of the current op. Increment after handling each op.
383  size_t position = 0;
384 };
385 
386 //===----------------------------------------------------------------------===//
387 // Equivalence
388 //===----------------------------------------------------------------------===//
389 
390 /// This class is for reporting differences between two modules which should
391 /// have been deduplicated.
392 struct Equivalence {
393  Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
394  : instanceGraph(instanceGraph) {
395  noDedupClass = StringAttr::get(context, noDedupAnnoClass);
396  dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
397  portDirectionsAttr = StringAttr::get(context, "portDirections");
398  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
399  nonessentialAttributes.insert(StringAttr::get(context, "name"));
400  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
401  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
402  nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
403  nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
404  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
405  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
406  nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
407  }
408 
409  struct ModuleData {
410  ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
411  : a(a), b(b) {}
412  IRMapping map;
413  const hw::InnerSymbolTable &a;
414  const hw::InnerSymbolTable &b;
415  };
416 
417  std::string prettyPrint(Attribute attr) {
418  SmallString<64> buffer;
419  llvm::raw_svector_ostream os(buffer);
420  if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
421  os << "0x";
422  if (integerAttr.getType().isSignlessInteger())
423  integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
424  else
425  integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
426 
427  } else
428  os << attr;
429  return std::string(buffer);
430  }
431 
432  // NOLINTNEXTLINE(misc-no-recursion)
433  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
434  Operation *a, BundleType aType, Operation *b,
435  BundleType bType) {
436  if (aType.getNumElements() != bType.getNumElements()) {
437  diag.attachNote(a->getLoc())
438  << message << " bundle type has different number of elements";
439  diag.attachNote(b->getLoc()) << "second operation here";
440  return failure();
441  }
442 
443  for (auto elementPair :
444  llvm::zip(aType.getElements(), bType.getElements())) {
445  auto aElement = std::get<0>(elementPair);
446  auto bElement = std::get<1>(elementPair);
447  if (aElement.isFlip != bElement.isFlip) {
448  diag.attachNote(a->getLoc()) << message << " bundle element "
449  << aElement.name << " flip does not match";
450  diag.attachNote(b->getLoc()) << "second operation here";
451  return failure();
452  }
453 
454  if (failed(check(diag,
455  "bundle element \'" + aElement.name.getValue() + "'", a,
456  aElement.type, b, bElement.type)))
457  return failure();
458  }
459  return success();
460  }
461 
462  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
463  Operation *a, Type aType, Operation *b, Type bType) {
464  if (aType == bType)
465  return success();
466  if (auto aBundleType = type_dyn_cast<BundleType>(aType))
467  if (auto bBundleType = type_dyn_cast<BundleType>(bType))
468  return check(diag, message, a, aBundleType, b, bBundleType);
469  if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
470  aType != bType) {
471  diag.attachNote(a->getLoc())
472  << message << ", has a RefType with a different base type "
473  << type_cast<RefType>(aType).getType()
474  << " in the same position of the two modules marked as 'must dedup'. "
475  "(This may be due to Grand Central Taps or Views being different "
476  "between the two modules.)";
477  diag.attachNote(b->getLoc())
478  << "the second module has a different base type "
479  << type_cast<RefType>(bType).getType();
480  return failure();
481  }
482  diag.attachNote(a->getLoc())
483  << message << " types don't match, first type is " << aType;
484  diag.attachNote(b->getLoc()) << "second type is " << bType;
485  return failure();
486  }
487 
488  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
489  Block &aBlock, Operation *b, Block &bBlock) {
490 
491  // Block argument types.
492  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
493  auto portNo = 0;
494  auto emitMissingPort = [&](Value existsVal, Operation *opExists,
495  Operation *opDoesNotExist) {
496  StringRef portName;
497  auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
498  if (portNames)
499  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
500  portName = portNameAttr.getValue();
501  if (type_isa<RefType>(existsVal.getType())) {
502  diag.attachNote(opExists->getLoc())
503  << " contains a RefType port named '" + portName +
504  "' that only exists in one of the modules (can be due to "
505  "difference in Grand Central Tap or View of two modules "
506  "marked with must dedup)";
507  diag.attachNote(opDoesNotExist->getLoc())
508  << "second module to be deduped that does not have the RefType "
509  "port";
510  } else {
511  diag.attachNote(opExists->getLoc())
512  << "port '" + portName + "' only exists in one of the modules";
513  diag.attachNote(opDoesNotExist->getLoc())
514  << "second module to be deduped that does not have the port";
515  }
516  return failure();
517  };
518 
519  for (auto argPair :
520  llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
521  auto &aArg = std::get<0>(argPair);
522  auto &bArg = std::get<1>(argPair);
523  if (aArg.has_value() && bArg.has_value()) {
524  // TODO: we should print the port number if there are no port names, but
525  // there are always port names ;).
526  StringRef portName;
527  if (portNames) {
528  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
529  portName = portNameAttr.getValue();
530  }
531  // Assumption here that block arguments correspond to ports.
532  if (failed(check(diag, "module port '" + portName + "'", a,
533  aArg->getType(), b, bArg->getType())))
534  return failure();
535  data.map.map(aArg.value(), bArg.value());
536  portNo++;
537  continue;
538  }
539  if (!aArg.has_value())
540  std::swap(a, b);
541  return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
542  b);
543  }
544 
545  // Blocks operations.
546  auto aIt = aBlock.begin();
547  auto aEnd = aBlock.end();
548  auto bIt = bBlock.begin();
549  auto bEnd = bBlock.end();
550  while (aIt != aEnd && bIt != bEnd)
551  if (failed(check(diag, data, &*aIt++, &*bIt++)))
552  return failure();
553  if (aIt != aEnd) {
554  diag.attachNote(aIt->getLoc()) << "first block has more operations";
555  diag.attachNote(b->getLoc()) << "second block here";
556  return failure();
557  }
558  if (bIt != bEnd) {
559  diag.attachNote(bIt->getLoc()) << "second block has more operations";
560  diag.attachNote(a->getLoc()) << "first block here";
561  return failure();
562  }
563  return success();
564  }
565 
566  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
567  Region &aRegion, Operation *b, Region &bRegion) {
568  auto aIt = aRegion.begin();
569  auto aEnd = aRegion.end();
570  auto bIt = bRegion.begin();
571  auto bEnd = bRegion.end();
572 
573  // Region blocks.
574  while (aIt != aEnd && bIt != bEnd)
575  if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
576  return failure();
577  if (aIt != aEnd || bIt != bEnd) {
578  diag.attachNote(a->getLoc())
579  << "operation regions have different number of blocks";
580  diag.attachNote(b->getLoc()) << "second operation here";
581  return failure();
582  }
583  return success();
584  }
585 
586  LogicalResult check(InFlightDiagnostic &diag, Operation *a,
587  mlir::DenseBoolArrayAttr aAttr, Operation *b,
588  mlir::DenseBoolArrayAttr bAttr) {
589  if (aAttr == bAttr)
590  return success();
591  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
592  for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
593  auto aDirection = aAttr[i];
594  auto bDirection = bAttr[i];
595  if (aDirection != bDirection) {
596  auto &note = diag.attachNote(a->getLoc()) << "module port ";
597  if (portNames)
598  note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
599  else
600  note << i;
601  note << " directions don't match, first direction is '"
602  << direction::toString(aDirection) << "'";
603  diag.attachNote(b->getLoc()) << "second direction is '"
604  << direction::toString(bDirection) << "'";
605  return failure();
606  }
607  }
608  return success();
609  }
610 
611  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
612  DictionaryAttr aDict, Operation *b,
613  DictionaryAttr bDict) {
614  // Fast path.
615  if (aDict == bDict)
616  return success();
617 
618  DenseSet<Attribute> seenAttrs;
619  for (auto namedAttr : aDict) {
620  auto attrName = namedAttr.getName();
621  if (nonessentialAttributes.contains(attrName))
622  continue;
623 
624  auto aAttr = namedAttr.getValue();
625  auto bAttr = bDict.get(attrName);
626  if (!bAttr) {
627  diag.attachNote(a->getLoc())
628  << "second operation is missing attribute " << attrName;
629  diag.attachNote(b->getLoc()) << "second operation here";
630  return diag;
631  }
632 
633  if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
634  auto bRef = cast<hw::InnerRefAttr>(bAttr);
635  auto aRef = cast<hw::InnerRefAttr>(aAttr);
636  // See if they are pointing at the same operation or port.
637  auto aTarget = data.a.lookup(aRef.getName());
638  auto bTarget = data.b.lookup(bRef.getName());
639  if (!aTarget || !bTarget)
640  diag.attachNote(a->getLoc())
641  << "malformed ir, possibly violating use-before-def";
642  auto error = [&]() {
643  diag.attachNote(a->getLoc())
644  << "operations have different targets, first operation has "
645  << aTarget;
646  diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
647  return failure();
648  };
649  if (aTarget.isPort()) {
650  // If they are targeting ports, make sure its the same port number.
651  if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
652  return error();
653  } else {
654  // Otherwise make sure that they are targeting the same operation.
655  if (!bTarget.isOpOnly() ||
656  aTarget.getOp() != data.map.lookup(bTarget.getOp()))
657  return error();
658  }
659  if (aTarget.getField() != bTarget.getField())
660  return error();
661  } else if (attrName == portDirectionsAttr) {
662  // Special handling for the port directions attribute for better
663  // error messages.
664  if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
665  cast<mlir::DenseBoolArrayAttr>(bAttr))))
666  return failure();
667  } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
668  // TODO: properly handle DistinctAttr, including its use in paths.
669  // See https://github.com/llvm/circt/issues/6583
670  } else if (aAttr != bAttr) {
671  diag.attachNote(a->getLoc())
672  << "first operation has attribute '" << attrName.getValue()
673  << "' with value " << prettyPrint(aAttr);
674  diag.attachNote(b->getLoc())
675  << "second operation has value " << prettyPrint(bAttr);
676  return failure();
677  }
678  seenAttrs.insert(attrName);
679  }
680  if (aDict.getValue().size() != bDict.getValue().size()) {
681  for (auto namedAttr : bDict) {
682  auto attrName = namedAttr.getName();
683  // Skip the attribute if we don't care about this particular one or it
684  // is one that is known to be in both dictionaries.
685  if (nonessentialAttributes.contains(attrName) ||
686  seenAttrs.contains(attrName))
687  continue;
688  // We have found an attribute that is only in the second operation.
689  diag.attachNote(a->getLoc())
690  << "first operation is missing attribute " << attrName;
691  diag.attachNote(b->getLoc()) << "second operation here";
692  return failure();
693  }
694  }
695  return success();
696  }
697 
698  // NOLINTNEXTLINE(misc-no-recursion)
699  LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
700  FInstanceLike b) {
701  auto aName = a.getReferencedModuleNameAttr();
702  auto bName = b.getReferencedModuleNameAttr();
703  if (aName == bName)
704  return success();
705 
706  // If the modules instantiate are different we will want to know why the
707  // sub module did not dedupliate. This code recursively checks the child
708  // module.
709  auto aModule = instanceGraph.lookup(aName)->getModule();
710  auto bModule = instanceGraph.lookup(bName)->getModule();
711  // Create a new error for the submodule.
712  diag.attachNote(std::nullopt)
713  << "in instance " << a.getInstanceNameAttr() << " of " << aName
714  << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
715  check(diag, aModule, bModule);
716  return failure();
717  }
718 
719  // NOLINTNEXTLINE(misc-no-recursion)
720  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
721  Operation *b) {
722  // Operation name.
723  if (a->getName() != b->getName()) {
724  diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
725  diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
726  return failure();
727  }
728 
729  // If its an instance operaiton, perform some checking and possibly
730  // recurse.
731  if (auto aInst = dyn_cast<FInstanceLike>(a)) {
732  auto bInst = cast<FInstanceLike>(b);
733  if (failed(check(diag, aInst, bInst)))
734  return failure();
735  }
736 
737  // Operation results.
738  if (a->getNumResults() != b->getNumResults()) {
739  diag.attachNote(a->getLoc())
740  << "operations have different number of results";
741  diag.attachNote(b->getLoc()) << "second operation here";
742  return failure();
743  }
744  for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
745  auto &aValue = std::get<0>(resultPair);
746  auto &bValue = std::get<1>(resultPair);
747  if (failed(check(diag, "operation result", a, aValue.getType(), b,
748  bValue.getType())))
749  return failure();
750  data.map.map(aValue, bValue);
751  }
752 
753  // Operations operands.
754  if (a->getNumOperands() != b->getNumOperands()) {
755  diag.attachNote(a->getLoc())
756  << "operations have different number of operands";
757  diag.attachNote(b->getLoc()) << "second operation here";
758  return failure();
759  }
760  for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
761  auto &aValue = std::get<0>(operandPair);
762  auto &bValue = std::get<1>(operandPair);
763  if (bValue != data.map.lookup(aValue)) {
764  diag.attachNote(a->getLoc())
765  << "operations use different operands, first operand is '"
766  << getFieldName(
767  getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
768  .first
769  << "'";
770  diag.attachNote(b->getLoc())
771  << "second operand is '"
772  << getFieldName(
773  getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
774  .first
775  << "', but should have been '"
776  << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
777  /*lookThroughCasts=*/true))
778  .first
779  << "'";
780  return failure();
781  }
782  }
783  data.map.map(a, b);
784 
785  // Operation regions.
786  if (a->getNumRegions() != b->getNumRegions()) {
787  diag.attachNote(a->getLoc())
788  << "operations have different number of regions";
789  diag.attachNote(b->getLoc()) << "second operation here";
790  return failure();
791  }
792  for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
793  auto &aRegion = std::get<0>(regionPair);
794  auto &bRegion = std::get<1>(regionPair);
795  if (failed(check(diag, data, a, aRegion, b, bRegion)))
796  return failure();
797  }
798 
799  // Operation attributes.
800  if (failed(check(diag, data, a, a->getAttrDictionary(), b,
801  b->getAttrDictionary())))
802  return failure();
803  return success();
804  }
805 
806  // NOLINTNEXTLINE(misc-no-recursion)
807  void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
808  hw::InnerSymbolTable aTable(a);
809  hw::InnerSymbolTable bTable(b);
810  ModuleData data(aTable, bTable);
811  if (AnnotationSet::hasAnnotation(a, noDedupClass)) {
812  diag.attachNote(a->getLoc()) << "module marked NoDedup";
813  return;
814  }
815  if (AnnotationSet::hasAnnotation(b, noDedupClass)) {
816  diag.attachNote(b->getLoc()) << "module marked NoDedup";
817  return;
818  }
819  auto aSymbol = cast<mlir::SymbolOpInterface>(a);
820  auto bSymbol = cast<mlir::SymbolOpInterface>(b);
821  if (!canRemoveModule(aSymbol)) {
822  diag.attachNote(a->getLoc())
823  << "module is "
824  << (aSymbol.isPrivate() ? "private but not discardable" : "public");
825  return;
826  }
827  if (!canRemoveModule(bSymbol)) {
828  diag.attachNote(b->getLoc())
829  << "module is "
830  << (bSymbol.isPrivate() ? "private but not discardable" : "public");
831  return;
832  }
833  auto aGroup =
834  dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
835  auto bGroup = dyn_cast_or_null<StringAttr>(
836  b->getAttrOfType<StringAttr>(dedupGroupAttrName));
837  if (aGroup != bGroup) {
838  if (bGroup) {
839  diag.attachNote(b->getLoc())
840  << "module is in dedup group '" << bGroup.str() << "'";
841  } else {
842  diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
843  }
844  if (aGroup) {
845  diag.attachNote(a->getLoc())
846  << "module is in dedup group '" << aGroup.str() << "'";
847  } else {
848  diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
849  }
850  return;
851  }
852  if (failed(check(diag, data, a, b)))
853  return;
854  diag.attachNote(a->getLoc()) << "first module here";
855  diag.attachNote(b->getLoc()) << "second module here";
856  }
857 
858  // This is a cached "portDirections" string attr.
859  StringAttr portDirectionsAttr;
860  // This is a cached "NoDedup" annotation class string attr.
861  StringAttr noDedupClass;
862  // This is a cached string attr for the dedup group attribute.
863  StringAttr dedupGroupAttrName;
864 
865  // This is a set of every attribute we should ignore.
866  DenseSet<Attribute> nonessentialAttributes;
868 };
869 
870 //===----------------------------------------------------------------------===//
871 // Deduplication
872 //===----------------------------------------------------------------------===//
873 
874 // Custom location merging. This only keeps track of 8 annotations from ".fir"
875 // files, and however many annotations come from "real" sources. When
876 // deduplicating, modules tend not to have scala source locators, so we wind
877 // up fusing source locators for a module from every copy being deduped. There
878 // is little value in this (all the modules are identical by definition).
879 static Location mergeLoc(MLIRContext *context, Location to, Location from) {
880  // Unique the set of locations to be fused.
881  llvm::SmallSetVector<Location, 4> decomposedLocs;
882  // only track 8 "fir" locations
883  unsigned seenFIR = 0;
884  for (auto loc : {to, from}) {
885  // If the location is a fused location we decompose it if it has no
886  // metadata or the metadata is the same as the top level metadata.
887  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
888  // UnknownLoc's have already been removed from FusedLocs so we can
889  // simply add all of the internal locations.
890  for (auto loc : fusedLoc.getLocations()) {
891  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
892  if (fileLoc.getFilename().strref().ends_with(".fir")) {
893  ++seenFIR;
894  if (seenFIR > 8)
895  continue;
896  }
897  }
898  decomposedLocs.insert(loc);
899  }
900  continue;
901  }
902 
903  // Might need to skip this fir.
904  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
905  if (fileLoc.getFilename().strref().ends_with(".fir")) {
906  ++seenFIR;
907  if (seenFIR > 8)
908  continue;
909  }
910  }
911  // Otherwise, only add known locations to the set.
912  if (!isa<UnknownLoc>(loc))
913  decomposedLocs.insert(loc);
914  }
915 
916  auto locs = decomposedLocs.getArrayRef();
917 
918  // Handle the simple cases of less than two locations. Ensure the metadata (if
919  // provided) is not dropped.
920  if (locs.empty())
921  return UnknownLoc::get(context);
922  if (locs.size() == 1)
923  return locs.front();
924 
925  return FusedLoc::get(context, locs);
926 }
927 
928 struct Deduper {
929 
930  using RenameMap = DenseMap<StringAttr, StringAttr>;
931 
932  Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable,
933  NLATable *nlaTable, CircuitOp circuit)
934  : context(circuit->getContext()), instanceGraph(instanceGraph),
935  symbolTable(symbolTable), nlaTable(nlaTable),
936  nlaBlock(circuit.getBodyBlock()),
937  nonLocalString(StringAttr::get(context, "circt.nonlocal")),
938  classString(StringAttr::get(context, "class")) {
939  // Populate the NLA cache.
940  for (auto nla : circuit.getOps<hw::HierPathOp>())
941  nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
942  }
943 
944  /// Remove the "fromModule", and replace all references to it with the
945  /// "toModule". Modules should be deduplicated in a bottom-up order. Any
946  /// module which is not deduplicated needs to be recorded with the `record`
947  /// call.
948  void dedup(FModuleLike toModule, FModuleLike fromModule) {
949  // A map of operation (e.g. wires, nodes) names which are changed, which is
950  // used to update NLAs that reference the "fromModule".
951  RenameMap renameMap;
952 
953  // Merge the port locations.
954  SmallVector<Attribute> newLocs;
955  for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
956  fromModule.getPortLocations())) {
957  if (toLoc == fromLoc)
958  newLocs.push_back(toLoc);
959  else
960  newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
961  cast<LocationAttr>(fromLoc)));
962  }
963  toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
964 
965  // Merge the two modules.
966  mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
967 
968  // Rewrite NLAs pathing through these modules to refer to the to module. It
969  // is safe to do this at this point because NLAs cannot be one element long.
970  // This means that all NLAs which require more context cannot be targetting
971  // something in the module it self.
972  if (auto to = dyn_cast<FModuleOp>(*toModule))
973  rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
974  else
975  rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
976  fromModule.getModuleNameAttr());
977 
978  replaceInstances(toModule, fromModule);
979  }
980 
981  /// Record the usages of any NLA's in this module, so that we may update the
982  /// annotation if the parent module is deduped with another module.
983  void record(FModuleLike module) {
984  // Record any annotations on the module.
985  recordAnnotations(module);
986  // Record port annotations.
987  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
988  recordAnnotations(PortAnnoTarget(module, i));
989  // Record any annotations in the module body.
990  module->walk([&](Operation *op) { recordAnnotations(op); });
991  }
992 
993 private:
994  /// Get a cached namespace for a module.
995  hw::InnerSymbolNamespace &getNamespace(Operation *module) {
996  return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
997  .first->second;
998  }
999 
1000  /// For a specific annotation target, record all the unique NLAs which
1001  /// target it in the `targetMap`.
1003  for (auto anno : target.getAnnotations())
1004  if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1005  targetMap[nlaRef.getAttr()].insert(target);
1006  }
1007 
1008  /// Record all targets which use an NLA.
1009  void recordAnnotations(Operation *op) {
1010  // Record annotations.
1011  recordAnnotations(OpAnnoTarget(op));
1012 
1013  // Record port annotations only if this is a mem operation.
1014  auto mem = dyn_cast<MemOp>(op);
1015  if (!mem)
1016  return;
1017 
1018  // Record port annotations.
1019  for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1020  recordAnnotations(PortAnnoTarget(mem, i));
1021  }
1022 
1023  /// This deletes and replaces all instances of the "fromModule" with instances
1024  /// of the "toModule".
1025  void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1026  // Replace all instances of the other module.
1027  auto *fromNode =
1028  instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1029  auto *toNode = instanceGraph[toModule];
1030  auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1031  for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1032  auto inst = oldInstRec->getInstance();
1033  if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1034  instOp.setModuleNameAttr(toModuleRef);
1035  instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1036  } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1037  auto classLike = cast<ClassLike>(*toNode->getModule());
1038  ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1039  objectOp.getResult().setType(classType);
1040  }
1041  oldInstRec->getParent()->addInstance(inst, toNode);
1042  oldInstRec->erase();
1043  }
1044  instanceGraph.erase(fromNode);
1045  fromModule->erase();
1046  }
1047 
1048  /// Look up the instantiations of the `from` module and create an NLA for each
1049  /// one, appending the baseNamepath to each NLA. This is used to add more
1050  /// context to an already existing NLA. The `fromModule` is used to indicate
1051  /// which module the annotation is coming from before the merge, and will be
1052  /// used to create the namepaths.
1053  SmallVector<FlatSymbolRefAttr>
1054  createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1055  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1056  // Create an attribute array with a placeholder in the first element, where
1057  // the root refence of the NLA will be inserted.
1058  SmallVector<Attribute> namepath = {nullptr};
1059  namepath.append(baseNamepath.begin(), baseNamepath.end());
1060 
1061  auto loc = fromModule->getLoc();
1062  auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1063  SmallVector<FlatSymbolRefAttr> nlas;
1064  for (auto *instanceRecord : fromNode->uses()) {
1065  auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1066  auto inst = instanceRecord->getInstance();
1067  namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1068  auto arrayAttr = ArrayAttr::get(context, namepath);
1069  // Check the NLA cache to see if we already have this NLA.
1070  auto &cacheEntry = nlaCache[arrayAttr];
1071  if (!cacheEntry) {
1072  auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1073  loc, "nla", arrayAttr);
1074  // Insert it into the symbol table to get a unique name.
1075  symbolTable.insert(nla);
1076  // Store it in the cache.
1077  cacheEntry = nla.getNameAttr();
1078  nla.setVisibility(vis);
1079  nlaTable->addNLA(nla);
1080  }
1081  auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1082  nlas.push_back(nlaRef);
1083  }
1084  return nlas;
1085  }
1086 
1087  /// Look up the instantiations of this module and create an NLA for each one.
1088  /// This returns an array of symbol references which can be used to reference
1089  /// the NLAs.
1090  SmallVector<FlatSymbolRefAttr>
1091  createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1092  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1093  return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1094  }
1095 
1096  /// Clone the annotation for each NLA in a list. The attribute list should
1097  /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1098  /// should be the index of this field.
1099  void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1100  Annotation anno, ArrayRef<NamedAttribute> attributes,
1101  unsigned nonLocalIndex,
1102  SmallVectorImpl<Annotation> &newAnnotations) {
1103  SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1104  attributes.end());
1105  for (auto &nla : nlas) {
1106  // Add the new annotation.
1107  mutableAttributes[nonLocalIndex].setValue(nla);
1108  auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1109  // The original annotation records if its a subannotation.
1110  anno.setDict(dict);
1111  newAnnotations.push_back(anno);
1112  }
1113  }
1114 
1115  /// This erases the NLA op, and removes the NLA from every module's NLA map,
1116  /// but it does not delete the NLA reference from the target operation's
1117  /// annotations.
1118  void eraseNLA(hw::HierPathOp nla) {
1119  // Erase the NLA from the leaf module's nlaMap.
1120  targetMap.erase(nla.getNameAttr());
1121  nlaTable->erase(nla);
1122  nlaCache.erase(nla.getNamepathAttr());
1123  symbolTable.erase(nla);
1124  }
1125 
1126  /// Process all NLAs referencing the "from" module to point to the "to"
1127  /// module. This is used after merging two modules together.
1128  void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1129  FModuleOp fromModule) {
1130  auto toName = toModule.getNameAttr();
1131  auto fromName = fromModule.getNameAttr();
1132  // Create a copy of the current NLAs. We will be pushing and removing
1133  // NLAs from this op as we go.
1134  auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1135  // Change the NLA to target the toModule.
1136  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1137  // Now we walk the NLA searching for ones that require more context to be
1138  // added.
1139  for (auto nla : moduleNLAs) {
1140  auto elements = nla.getNamepath().getValue();
1141  // If we don't need to add more context, we're done here.
1142  if (nla.root() != toName)
1143  continue;
1144  // Create the replacement NLAs.
1145  SmallVector<Attribute> namepath(elements.begin(), elements.end());
1146  auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1147  // Copy out the targets, because we will be updating the map.
1148  auto &set = targetMap[nla.getSymNameAttr()];
1149  SmallVector<AnnoTarget> targets(set.begin(), set.end());
1150  // Replace the uses of the old NLA with the new NLAs.
1151  for (auto target : targets) {
1152  // We have to clone any annotation which uses the old NLA for each new
1153  // NLA. This array collects the new set of annotations.
1154  SmallVector<Annotation> newAnnotations;
1155  for (auto anno : target.getAnnotations()) {
1156  // Find the non-local field of the annotation.
1157  auto [it, found] = mlir::impl::findAttrSorted(
1158  anno.begin(), anno.end(), nonLocalString);
1159  // If this annotation doesn't use the target NLA, copy it with no
1160  // changes.
1161  if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1162  nla.getSymNameAttr()) {
1163  newAnnotations.push_back(anno);
1164  continue;
1165  }
1166  auto nonLocalIndex = std::distance(anno.begin(), it);
1167  // Clone the annotation and add it to the list of new annotations.
1168  cloneAnnotation(nlaRefs, anno,
1169  ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1170  nonLocalIndex, newAnnotations);
1171  }
1172 
1173  // Apply the new annotations to the operation.
1174  AnnotationSet annotations(newAnnotations, context);
1175  target.setAnnotations(annotations);
1176  // Record that target uses the NLA.
1177  for (auto nla : nlaRefs)
1178  targetMap[nla.getAttr()].insert(target);
1179  }
1180 
1181  // Erase the old NLA and remove it from all breadcrumbs.
1182  eraseNLA(nla);
1183  }
1184  }
1185 
1186  /// Process all the NLAs that the two modules participate in, replacing
1187  /// references to the "from" module with references to the "to" module, and
1188  /// adding more context if necessary.
1189  void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1190  FModuleOp fromModule) {
1191  addAnnotationContext(renameMap, toModule, toModule);
1192  addAnnotationContext(renameMap, toModule, fromModule);
1193  }
1194 
1195  // Update all NLAs which the "from" external module participates in to the
1196  // "toName".
1197  void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1198  StringAttr fromName) {
1199  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1200  }
1201 
1202  /// Take an annotation, and update it to be a non-local annotation. If the
1203  /// annotation is already non-local and has enough context, it will be skipped
1204  /// for now. Return true if the annotation was made non-local.
1205  bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1206  FModuleLike fromModule, Annotation anno,
1207  SmallVectorImpl<Annotation> &newAnnotations) {
1208  // Start constructing a new annotation, pushing a "circt.nonLocal" field
1209  // into the correct spot if its not already a non-local annotation.
1210  SmallVector<NamedAttribute> attributes;
1211  int nonLocalIndex = -1;
1212  for (const auto &val : llvm::enumerate(anno)) {
1213  auto attr = val.value();
1214  // Is this field "circt.nonlocal"?
1215  auto compare = attr.getName().compare(nonLocalString);
1216  assert(compare != 0 && "should not pass non-local annotations here");
1217  if (compare == 1) {
1218  // This annotation definitely does not have "circt.nonlocal" field. Push
1219  // an empty place holder for the non-local annotation.
1220  nonLocalIndex = val.index();
1221  attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1222  break;
1223  }
1224  // Otherwise push the current attribute and keep searching for the
1225  // "circt.nonlocal" field.
1226  attributes.push_back(attr);
1227  }
1228  if (nonLocalIndex == -1) {
1229  // Push an empty "circt.nonlocal" field to the last slot.
1230  nonLocalIndex = attributes.size();
1231  attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1232  } else {
1233  // Copy the remaining annotation fields in.
1234  attributes.append(anno.begin() + nonLocalIndex, anno.end());
1235  }
1236 
1237  // Construct the NLAs if we don't have any yet.
1238  auto nlaRefs = createNLAs(toModuleName, fromModule);
1239  for (auto nla : nlaRefs)
1240  targetMap[nla.getAttr()].insert(to);
1241 
1242  // Clone the annotation for each new NLA.
1243  cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1244  return true;
1245  }
1246 
1247  void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1248  FModuleLike fromModule, AnnotationSet annos,
1249  SmallVectorImpl<Annotation> &newAnnotations,
1250  SmallPtrSetImpl<Attribute> &dontTouches) {
1251  for (auto anno : annos) {
1252  if (anno.isClass(dontTouchAnnoClass)) {
1253  // Remove the nonlocal field of the annotation if it has one, since this
1254  // is a sticky annotation.
1255  anno.removeMember("circt.nonlocal");
1256  auto [it, inserted] = dontTouches.insert(anno.getAttr());
1257  if (inserted)
1258  newAnnotations.push_back(anno);
1259  continue;
1260  }
1261  // If the annotation is already non-local, we add it as is. It is already
1262  // added to the target map.
1263  if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1264  newAnnotations.push_back(anno);
1265  targetMap[nla.getAttr()].insert(to);
1266  continue;
1267  }
1268  // Otherwise make the annotation non-local and add it to the set.
1269  makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1270  newAnnotations);
1271  }
1272  }
1273 
1274  /// Merge the annotations of a specific target, either a operation or a port
1275  /// on an operation.
1276  void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1277  AnnotationSet toAnnos, FModuleLike fromModule,
1278  AnnoTarget from, AnnotationSet fromAnnos) {
1279  // This is a list of all the annotations which will be added to `to`.
1280  SmallVector<Annotation> newAnnotations;
1281 
1282  // We have special case handling of DontTouch to prevent it from being
1283  // turned into a non-local annotation, and to remove duplicates.
1284  llvm::SmallPtrSet<Attribute, 4> dontTouches;
1285 
1286  // Iterate the annotations, transforming most annotations into non-local
1287  // ones.
1288  copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1289  dontTouches);
1290  copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1291  dontTouches);
1292 
1293  // Copy over all the new annotations.
1294  if (!newAnnotations.empty())
1295  to.setAnnotations(AnnotationSet(newAnnotations, context));
1296  }
1297 
1298  /// Merge all annotations and port annotations on two operations.
1299  void mergeAnnotations(FModuleLike toModule, Operation *to,
1300  FModuleLike fromModule, Operation *from) {
1301  // Merge op annotations.
1302  mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1303  OpAnnoTarget(from), AnnotationSet(from));
1304 
1305  // Merge port annotations.
1306  if (toModule == to) {
1307  // Merge module port annotations.
1308  for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1309  mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1310  AnnotationSet::forPort(toModule, i), fromModule,
1311  PortAnnoTarget(fromModule, i),
1312  AnnotationSet::forPort(fromModule, i));
1313  } else if (auto toMem = dyn_cast<MemOp>(to)) {
1314  // Merge memory port annotations.
1315  auto fromMem = cast<MemOp>(from);
1316  for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1317  mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1318  AnnotationSet::forPort(toMem, i), fromModule,
1319  PortAnnoTarget(fromMem, i),
1320  AnnotationSet::forPort(fromMem, i));
1321  }
1322  }
1323 
1324  // Record the symbol name change of the operation or any of its ports when
1325  // merging two operations. The renamed symbols are used to update the
1326  // target of any NLAs. This will add symbols to the "to" operation if needed.
1327  void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1328  Operation *to, FModuleLike fromModule,
1329  Operation *from) {
1330  // If the "from" operation has an inner_sym, we need to make sure the
1331  // "to" operation also has an `inner_sym` and then record the renaming.
1332  if (auto fromSym = getInnerSymName(from)) {
1333  auto toSym =
1334  getOrAddInnerSym(to, [&](auto _) -> hw::InnerSymbolNamespace & {
1335  return getNamespace(toModule);
1336  });
1337  renameMap[fromSym] = toSym;
1338  }
1339 
1340  // If there are no port symbols on the "from" operation, we are done here.
1341  auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1342  if (!fromPortSyms || fromPortSyms.empty())
1343  return;
1344  // We have to map each "fromPort" to each "toPort".
1345  auto &moduleNamespace = getNamespace(toModule);
1346  auto portCount = fromPortSyms.size();
1347  auto portNames = to->getAttrOfType<ArrayAttr>("portNames");
1348  auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1349 
1350  // Create an array of new port symbols for the "to" operation, copy in the
1351  // old symbols if it has any, create an empty symbol array if it doesn't.
1352  SmallVector<Attribute> newPortSyms;
1353  if (toPortSyms.empty())
1354  newPortSyms.assign(portCount, hw::InnerSymAttr());
1355  else
1356  newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1357 
1358  for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1359  // If this fromPort doesn't have a symbol, move on to the next one.
1360  if (!fromPortSyms[portNo])
1361  continue;
1362  auto fromSym = cast<hw::InnerSymAttr>(fromPortSyms[portNo]);
1363 
1364  // If this toPort doesn't have a symbol, assign one.
1365  hw::InnerSymAttr toSym;
1366  if (!newPortSyms[portNo]) {
1367  // Get a reasonable base name for the port.
1368  StringRef symName = "inner_sym";
1369  if (portNames)
1370  symName = cast<StringAttr>(portNames[portNo]).getValue();
1371  // Create the symbol and store it into the array.
1372  toSym = hw::InnerSymAttr::get(
1373  StringAttr::get(context, moduleNamespace.newName(symName)));
1374  newPortSyms[portNo] = toSym;
1375  } else
1376  toSym = cast<hw::InnerSymAttr>(newPortSyms[portNo]);
1377 
1378  // Record the renaming.
1379  renameMap[fromSym.getSymName()] = toSym.getSymName();
1380  }
1381 
1382  // Commit the new symbol attribute.
1383  FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1384  cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1385  }
1386 
1387  /// Recursively merge two operations.
1388  // NOLINTNEXTLINE(misc-no-recursion)
1389  void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1390  FModuleLike fromModule, Operation *from) {
1391  // Merge the operation locations.
1392  if (to->getLoc() != from->getLoc())
1393  to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1394 
1395  // Recurse into any regions.
1396  for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1397  mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1398  std::get<1>(regions));
1399 
1400  // Record any inner_sym renamings that happened.
1401  recordSymRenames(renameMap, toModule, to, fromModule, from);
1402 
1403  // Merge the annotations.
1404  mergeAnnotations(toModule, to, fromModule, from);
1405  }
1406 
1407  /// Recursively merge two blocks.
1408  void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1409  FModuleLike fromModule, Block &fromBlock) {
1410  // Merge the block locations.
1411  for (auto [toArg, fromArg] :
1412  llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1413  if (toArg.getLoc() != fromArg.getLoc())
1414  toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1415 
1416  for (auto ops : llvm::zip(toBlock, fromBlock))
1417  mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1418  &std::get<1>(ops));
1419  }
1420 
1421  // Recursively merge two regions.
1422  void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1423  Region &toRegion, FModuleLike fromModule,
1424  Region &fromRegion) {
1425  for (auto blocks : llvm::zip(toRegion, fromRegion))
1426  mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1427  std::get<1>(blocks));
1428  }
1429 
1430  MLIRContext *context;
1432  SymbolTable &symbolTable;
1433 
1434  /// Cached nla table analysis.
1435  NLATable *nlaTable = nullptr;
1436 
1437  /// We insert all NLAs to the beginning of this block.
1438  Block *nlaBlock;
1439 
1440  // This maps an NLA to the operations and ports that uses it.
1441  DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1442 
1443  // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1444  // of the NLA's path to the name of the NLA which contains it.
1445  DenseMap<Attribute, Attribute> nlaCache;
1446 
1447  // Cached attributes for faster comparisons and attribute building.
1448  StringAttr nonLocalString;
1449  StringAttr classString;
1450 
1451  /// A module namespace cache.
1452  DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1453 };
1454 
1455 //===----------------------------------------------------------------------===//
1456 // Fixup
1457 //===----------------------------------------------------------------------===//
1458 
1459 /// This fixes up ClassLikes with ClassType ports, when the classes have
1460 /// deduped. For each ClassType port, if the object reference being assigned is
1461 /// a different type, update the port type. Returns true if the ClassOp was
1462 /// updated and the associated ObjectOps should be updated.
1463 bool fixupClassOp(ClassOp classOp) {
1464  // New port type attributes, if necessary.
1465  SmallVector<Attribute> newPortTypes;
1466  bool anyDifferences = false;
1467 
1468  // Check each port.
1469  for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1470  // Check if this port is a ClassType. If not, save the original type
1471  // attribute in case we need to update port types.
1472  auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1473  if (!portClassType) {
1474  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1475  continue;
1476  }
1477 
1478  // Check if this port is assigned a reference of a different ClassType.
1479  Type newPortClassType;
1480  BlockArgument portArg = classOp.getArgument(i);
1481  for (auto &use : portArg.getUses()) {
1482  if (auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1483  Type sourceType = propassign.getSrc().getType();
1484  if (propassign.getDest() == use.get() && sourceType != portClassType) {
1485  // Double check that all references are the same new type.
1486  if (newPortClassType) {
1487  assert(newPortClassType == sourceType &&
1488  "expected all references to be of the same type");
1489  continue;
1490  }
1491 
1492  newPortClassType = sourceType;
1493  }
1494  }
1495  }
1496 
1497  // If there was no difference, save the original type attribute in case we
1498  // need to update port types and move along.
1499  if (!newPortClassType) {
1500  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1501  continue;
1502  }
1503 
1504  // The port type changed, so update the block argument, save the new port
1505  // type attribute, and indicate there was a difference.
1506  classOp.getArgument(i).setType(newPortClassType);
1507  newPortTypes.push_back(TypeAttr::get(newPortClassType));
1508  anyDifferences = true;
1509  }
1510 
1511  // If necessary, update port types.
1512  if (anyDifferences)
1513  classOp.setPortTypes(newPortTypes);
1514 
1515  return anyDifferences;
1516 }
1517 
1518 /// This fixes up ObjectOps when the signature of their ClassOp changes. This
1519 /// amounts to updating the ObjectOp result type to match the newly updated
1520 /// ClassOp type.
1521 void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) {
1522  objectOp.getResult().setType(newClassType);
1523 }
1524 
1525 /// This fixes up connects when the field names of a bundle type changes. It
1526 /// finds all fields which were previously bulk connected and legalizes it
1527 /// into a connect for each field.
1528 void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1529  // If the types already match we can emit a connect.
1530  auto dstType = dst.getType();
1531  auto srcType = src.getType();
1532  if (dstType == srcType) {
1533  emitConnect(builder, dst, src);
1534  return;
1535  }
1536  // It must be a bundle type and the field name has changed. We have to
1537  // manually decompose the bulk connect into a connect for each field.
1538  auto dstBundle = type_cast<BundleType>(dstType);
1539  auto srcBundle = type_cast<BundleType>(srcType);
1540  for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1541  auto dstField = builder.create<SubfieldOp>(dst, i);
1542  auto srcField = builder.create<SubfieldOp>(src, i);
1543  if (dstBundle.getElement(i).isFlip) {
1544  std::swap(srcBundle, dstBundle);
1545  std::swap(srcField, dstField);
1546  }
1547  fixupConnect(builder, dstField, srcField);
1548  }
1549 }
1550 
1551 /// This is the root method to fixup module references when a module changes.
1552 /// It matches all the results of "to" module with the results of the "from"
1553 /// module.
1554 void fixupAllModules(InstanceGraph &instanceGraph) {
1555  for (auto *node : instanceGraph) {
1556  auto module = cast<FModuleLike>(*node->getModule());
1557 
1558  // Handle class declarations here.
1559  bool shouldFixupObjects = false;
1560  auto classOp = dyn_cast<ClassOp>(module.getOperation());
1561  if (classOp)
1562  shouldFixupObjects = fixupClassOp(classOp);
1563 
1564  for (auto *instRec : node->uses()) {
1565  // Handle object instantiations here.
1566  if (classOp) {
1567  if (shouldFixupObjects) {
1568  fixupObjectOp(instRec->getInstance<ObjectOp>(),
1569  classOp.getInstanceType());
1570  }
1571  continue;
1572  }
1573 
1574  auto inst = instRec->getInstance<InstanceOp>();
1575  // Only handle module instantiations here.
1576  if (!inst)
1577  continue;
1578  ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1579  builder.setInsertionPointAfter(inst);
1580  for (size_t i = 0, e = getNumPorts(module); i < e; ++i) {
1581  auto result = inst.getResult(i);
1582  auto newType = module.getPortType(i);
1583  auto oldType = result.getType();
1584  // If the type has not changed, we don't have to fix up anything.
1585  if (newType == oldType)
1586  continue;
1587  // If the type changed we transform it back to the old type with an
1588  // intermediate wire.
1589  auto wire =
1590  builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1591  result.replaceAllUsesWith(wire);
1592  result.setType(newType);
1593  if (inst.getPortDirection(i) == Direction::Out)
1594  fixupConnect(builder, wire, result);
1595  else
1596  fixupConnect(builder, result, wire);
1597  }
1598  }
1599  }
1600 }
1601 
1602 namespace llvm {
1603 /// A DenseMapInfo implementation for `ModuleInfo` that is a pair of
1604 /// llvm::SHA256 hashes, which are represented as std::array<uint8_t, 32>, and
1605 /// an array of string attributes. This allows us to create a DenseMap with
1606 /// `ModuleInfo` as keys.
1607 template <>
1608 struct DenseMapInfo<ModuleInfo> {
1609  static inline ModuleInfo getEmptyKey() {
1610  std::array<uint8_t, 32> key;
1611  std::fill(key.begin(), key.end(), ~0);
1612  return {key, {}};
1613  }
1614 
1615  static inline ModuleInfo getTombstoneKey() {
1616  std::array<uint8_t, 32> key;
1617  std::fill(key.begin(), key.end(), ~0 - 1);
1618  return {key, {}};
1619  }
1620 
1621  static unsigned getHashValue(const ModuleInfo &val) {
1622  // We assume SHA256 is already a good hash and just truncate down to the
1623  // number of bytes we need for DenseMap.
1624  unsigned hash;
1625  std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));
1626 
1627  // Combine module names.
1628  return llvm::hash_combine(
1629  hash, llvm::hash_combine_range(val.referredModuleNames.begin(),
1630  val.referredModuleNames.end()));
1631  }
1632 
1633  static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
1634  return lhs == rhs;
1635  }
1636 };
1637 } // namespace llvm
1638 
1639 //===----------------------------------------------------------------------===//
1640 // DedupPass
1641 //===----------------------------------------------------------------------===//
1642 
1643 namespace {
1644 class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1645  void runOnOperation() override {
1646  auto *context = &getContext();
1647  auto circuit = getOperation();
1648  auto &instanceGraph = getAnalysis<InstanceGraph>();
1649  auto *nlaTable = &getAnalysis<NLATable>();
1650  auto &symbolTable = getAnalysis<SymbolTable>();
1651  Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1652  Equivalence equiv(context, instanceGraph);
1653  auto anythingChanged = false;
1654 
1655  // Modules annotated with this should not be considered for deduplication.
1656  auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1657 
1658  // Only modules within the same group may be deduplicated.
1659  auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1660 
1661  // A map of all the module moduleInfo that we have calculated so far.
1662  llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1663 
1664  // We track the name of the module that each module is deduped into, so that
1665  // we can make sure all modules which are marked "must dedup" with each
1666  // other were all deduped to the same module.
1667  DenseMap<Attribute, StringAttr> dedupMap;
1668 
1669  // We must iterate the modules from the bottom up so that we can properly
1670  // deduplicate the modules. We copy the list of modules into a vector first
1671  // to avoid iterator invalidation while we mutate the instance graph.
1672  SmallVector<FModuleLike, 0> modules(
1673  llvm::map_range(llvm::post_order(&instanceGraph), [](auto *node) {
1674  return cast<FModuleLike>(*node->getModule());
1675  }));
1676 
1677  SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1678  StructuralHasherSharedConstants hasherConstants(&getContext());
1679 
1680  // Attribute name used to store dedup_group for this pass.
1681  auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1682 
1683  // Move dedup group annotations to attributes on the module.
1684  // This results in the desired behavior (included in hash),
1685  // and avoids unnecessary processing of these as annotations
1686  // that need to be tracked, made non-local, so on.
1687  for (auto module : modules) {
1688  llvm::SmallSetVector<StringAttr, 1> groups;
1690  module, [&groups, dedupGroupClass](Annotation annotation) {
1691  if (annotation.getClassAttr() != dedupGroupClass)
1692  return false;
1693  groups.insert(annotation.getMember<StringAttr>("group"));
1694  return true;
1695  });
1696  if (groups.size() > 1) {
1697  module.emitError("module belongs to multiple dedup groups: ") << groups;
1698  return signalPassFailure();
1699  }
1700  assert(!module->hasAttr(dedupGroupAttrName) &&
1701  "unexpected existing use of temporary dedup group attribute");
1702  if (!groups.empty())
1703  module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1704  }
1705 
1706  // Calculate module information parallelly.
1707  auto result = mlir::failableParallelForEach(
1708  context, llvm::seq(modules.size()), [&](unsigned idx) {
1709  auto module = modules[idx];
1710  // If the module is marked with NoDedup, just skip it.
1711  if (AnnotationSet::hasAnnotation(module, noDedupClass))
1712  return success();
1713 
1714  // Only dedup extmodule's with defname.
1715  if (auto ext = dyn_cast<FExtModuleOp>(*module);
1716  ext && !ext.getDefname().has_value())
1717  return success();
1718 
1719  StructuralHasher hasher(hasherConstants);
1720  // Calculate the hash of the module and referred module names.
1721  moduleInfos[idx] = hasher.getModuleInfo(module);
1722  return success();
1723  });
1724 
1725  if (result.failed())
1726  return signalPassFailure();
1727 
1728  for (auto [i, module] : llvm::enumerate(modules)) {
1729  auto moduleName = module.getModuleNameAttr();
1730  auto &maybeModuleInfo = moduleInfos[i];
1731  // If the hash was not calculated, we need to skip it.
1732  if (!maybeModuleInfo) {
1733  // We record it in the dedup map to help detect errors when the user
1734  // marks the module as both NoDedup and MustDedup. We do not record this
1735  // module in the hasher to make sure no other module dedups "into" this
1736  // one.
1737  dedupMap[moduleName] = moduleName;
1738  continue;
1739  }
1740 
1741  auto &moduleInfo = maybeModuleInfo.value();
1742 
1743  // Replace module names referred in the module with new names.
1744  for (auto &referredModule : moduleInfo.referredModuleNames)
1745  referredModule = dedupMap[referredModule];
1746 
1747  // Check if there a module with the same hash.
1748  auto it = moduleInfoToModule.find(moduleInfo);
1749  if (it != moduleInfoToModule.end()) {
1750  auto original = cast<FModuleLike>(it->second);
1751  auto originalName = original.getModuleNameAttr();
1752 
1753  // If the current module is public, and the original is private, we
1754  // want to dedup the private module into the public one.
1755  if (!canRemoveModule(module)) {
1756  // If both modules are public, then we can't dedup anything.
1757  if (!canRemoveModule(original))
1758  continue;
1759  // Swap the canonical module in the dedup map.
1760  for (auto &[originalName, dedupedName] : dedupMap)
1761  if (dedupedName == originalName)
1762  dedupedName = moduleName;
1763  // Update the module hash table to point to the new original, so all
1764  // future modules dedup with the new canonical module.
1765  it->second = module;
1766  // Swap the locals.
1767  std::swap(originalName, moduleName);
1768  std::swap(original, module);
1769  }
1770 
1771  // Record the group ID of the other module.
1772  dedupMap[moduleName] = originalName;
1773  deduper.dedup(original, module);
1774  ++erasedModules;
1775  anythingChanged = true;
1776  continue;
1777  }
1778  // Any module not deduplicated must be recorded.
1779  deduper.record(module);
1780  // Add the module to a new dedup group.
1781  dedupMap[moduleName] = moduleName;
1782  // Record the module info.
1783  moduleInfoToModule[std::move(moduleInfo)] = module;
1784  }
1785 
1786  // This part verifies that all modules marked by "MustDedup" have been
1787  // properly deduped with each other. For this check to succeed, all modules
1788  // have to been deduped to the same module. It is possible that a module was
1789  // deduped with the wrong thing.
1790 
1791  auto failed = false;
1792  // This parses the module name out of a target string.
1793  auto parseModule = [&](Attribute path) -> StringAttr {
1794  // Each module is listed as a target "~Circuit|Module" which we have to
1795  // parse.
1796  auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1797  return StringAttr::get(context, rhs);
1798  };
1799  // This gets the name of the module which the current module was deduped
1800  // with. If the named module isn't in the map, then we didn't encounter it
1801  // in the circuit.
1802  auto getLead = [&](StringAttr module) -> StringAttr {
1803  auto it = dedupMap.find(module);
1804  if (it == dedupMap.end()) {
1805  auto diag = emitError(circuit.getLoc(),
1806  "MustDeduplicateAnnotation references module ")
1807  << module << " which does not exist";
1808  failed = true;
1809  return nullptr;
1810  }
1811  return it->second;
1812  };
1813 
1814  AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1815  if (!annotation.isClass(mustDedupAnnoClass))
1816  return false;
1817  auto modules = annotation.getMember<ArrayAttr>("modules");
1818  if (!modules) {
1819  emitError(circuit.getLoc(),
1820  "MustDeduplicateAnnotation missing \"modules\" member");
1821  failed = true;
1822  return false;
1823  }
1824  // Empty module list has nothing to process.
1825  if (modules.empty())
1826  return true;
1827  // Get the first element.
1828  auto firstModule = parseModule(modules[0]);
1829  auto firstLead = getLead(firstModule);
1830  if (!firstLead)
1831  return false;
1832  // Verify that the remaining elements are all the same as the first.
1833  for (auto attr : modules.getValue().drop_front()) {
1834  auto nextModule = parseModule(attr);
1835  auto nextLead = getLead(nextModule);
1836  if (!nextLead)
1837  return false;
1838  if (firstLead != nextLead) {
1839  auto diag = emitError(circuit.getLoc(), "module ")
1840  << nextModule << " not deduplicated with " << firstModule;
1841  auto a = instanceGraph.lookup(firstLead)->getModule();
1842  auto b = instanceGraph.lookup(nextLead)->getModule();
1843  equiv.check(diag, a, b);
1844  failed = true;
1845  return false;
1846  }
1847  }
1848  return true;
1849  });
1850  if (failed)
1851  return signalPassFailure();
1852 
1853  // Remove all dedup group attributes, they only exist during this pass.
1854  for (auto module : circuit.getOps<FModuleLike>())
1855  module->removeDiscardableAttr(dedupGroupAttrName);
1856 
1857  // Walk all the modules and fixup the instance operation to return the
1858  // correct type. We delay this fixup until the end because doing it early
1859  // can block the deduplication of the parent modules.
1860  fixupAllModules(instanceGraph);
1861 
1862  markAnalysesPreserved<NLATable>();
1863  if (!anythingChanged)
1864  markAllAnalysesPreserved();
1865  }
1866 };
1867 } // end anonymous namespace
1868 
1869 std::unique_ptr<mlir::Pass> circt::firrtl::createDedupPass() {
1870  return std::make_unique<DedupPass>();
1871 }
assert(baseType &&"element must be base type")
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition: Dedup.cpp:879
void fixupObjectOp(ObjectOp objectOp, ClassType newClassType)
This fixes up ObjectOps when the signature of their ClassOp changes.
Definition: Dedup.cpp:1521
llvm::raw_ostream & printHex(llvm::raw_ostream &stream, ArrayRef< uint8_t > bytes)
Definition: Dedup.cpp:77
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition: Dedup.cpp:55
void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition: Dedup.cpp:1528
llvm::raw_ostream & printHash(llvm::raw_ostream &stream, llvm::SHA256 &data)
Definition: Dedup.cpp:83
void fixupAllModules(InstanceGraph &instanceGraph)
This is the root method to fixup module references when a module changes.
Definition: Dedup.cpp:1554
bool fixupClassOp(ClassOp classOp)
This fixes up ClassLikes with ClassType ports, when the classes have deduped.
Definition: Dedup.cpp:1463
static void mergeRegions(Region *region1, Region *region2)
Definition: HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition: NLATable.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
ClassType getInstanceTypeForClassLike(ClassLike classOp)
Definition: FIRRTLOps.cpp:1835
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the the InnerSymAttr has a symbol on the field specified.
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:301
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:1869
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:25
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition: FIRRTLOps.h:108
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
bool operator==(uint64_t a, const FVInt &b)
Definition: FVInt.h:649
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition: Utils.h:32
MLIRContext * context
Definition: Dedup.cpp:1430
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition: Dedup.cpp:1438
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition: Dedup.cpp:1009
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition: Dedup.cpp:1118
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition: Dedup.cpp:1299
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition: Dedup.cpp:1025
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition: Dedup.cpp:1091
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition: Dedup.cpp:983
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition: Dedup.cpp:1197
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition: Dedup.cpp:1422
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition: Dedup.cpp:948
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition: Dedup.cpp:1189
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition: Dedup.cpp:1002
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition: Dedup.cpp:1099
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition: Dedup.cpp:1327
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either a operation or a port on an operation.
Definition: Dedup.cpp:1276
StringAttr nonLocalString
Definition: Dedup.cpp:1448
SymbolTable & symbolTable
Definition: Dedup.cpp:1432
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition: Dedup.cpp:1054
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition: Dedup.cpp:1389
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition: Dedup.cpp:995
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition: Dedup.cpp:1452
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition: Dedup.cpp:1205
InstanceGraph & instanceGraph
Definition: Dedup.cpp:1431
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition: Dedup.cpp:1408
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition: Dedup.cpp:1441
StringAttr classString
Definition: Dedup.cpp:1449
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition: Dedup.cpp:1247
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition: Dedup.cpp:932
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition: Dedup.cpp:1128
DenseMap< StringAttr, StringAttr > RenameMap
Definition: Dedup.cpp:930
DenseMap< Attribute, Attribute > nlaCache
Definition: Dedup.cpp:1445
const hw::InnerSymbolTable & a
Definition: Dedup.cpp:413
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition: Dedup.cpp:410
const hw::InnerSymbolTable & b
Definition: Dedup.cpp:414
This class is for reporting differences between two modules which should have been deduplicated.
Definition: Dedup.cpp:392
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:866
std::string prettyPrint(Attribute attr)
Definition: Dedup.cpp:417
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition: Dedup.cpp:699
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition: Dedup.cpp:488
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition: Dedup.cpp:462
StringAttr noDedupClass
Definition: Dedup.cpp:861
StringAttr dedupGroupAttrName
Definition: Dedup.cpp:863
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition: Dedup.cpp:611
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition: Dedup.cpp:566
StringAttr portDirectionsAttr
Definition: Dedup.cpp:859
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition: Dedup.cpp:433
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition: Dedup.cpp:720
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition: Dedup.cpp:393
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition: Dedup.cpp:586
InstanceGraph & instanceGraph
Definition: Dedup.cpp:867
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition: Dedup.cpp:807
std::vector< StringAttr > referredModuleNames
Definition: Dedup.cpp:102
std::array< uint8_t, 32 > structuralHash
Definition: Dedup.cpp:100
This struct contains constant string attributes shared across different threads.
Definition: Dedup.cpp:112
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:145
StructuralHasherSharedConstants(MLIRContext *context)
Definition: Dedup.cpp:113
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition: Dedup.cpp:242
void update(Type type)
Definition: Dedup.cpp:222
void update(const void *pointer)
Definition: Dedup.cpp:200
DenseMap< void *, unsigned > idTable
Definition: Dedup.cpp:365
void update(Operation *op)
Definition: Dedup.cpp:337
llvm::SHA256 sha
Definition: Dedup.cpp:380
ModuleInfo getModuleInfo(FModuleLike module)
Definition: Dedup.cpp:152
void update(size_t value)
Definition: Dedup.cpp:205
unsigned getInnerSymID(StringAttr name)
Definition: Dedup.cpp:176
void update(BundleType type)
Definition: Dedup.cpp:213
unsigned getID(void *object)
Definition: Dedup.cpp:159
void update(OpResult result)
Definition: Dedup.cpp:228
void update(OpOperand &operand)
Definition: Dedup.cpp:183
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition: Dedup.cpp:149
void update(Region *region)
Definition: Dedup.cpp:329
void update(Block *block)
Definition: Dedup.cpp:318
std::vector< StringAttr > referredModuleNames
Definition: Dedup.cpp:373
DenseMap< StringAttr, unsigned > innerSymIDTable
Definition: Dedup.cpp:369
void update(TypeID typeID)
Definition: Dedup.cpp:210
const StructuralHasherSharedConstants & constants
Definition: Dedup.cpp:376
unsigned finalizeID(void *object)
Definition: Dedup.cpp:167
void update(mlir::OperationName name)
Definition: Dedup.cpp:312
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
Definition: Dedup.cpp:1609
static ModuleInfo getTombstoneKey()
Definition: Dedup.cpp:1615
static unsigned getHashValue(const ModuleInfo &val)
Definition: Dedup.cpp:1621
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)
Definition: Dedup.cpp:1633