CIRCT  20.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements FIRRTL module deduplication.
10 //
11 //===----------------------------------------------------------------------===//
12 
23 #include "circt/Support/LLVM.h"
24 #include "mlir/IR/IRMapping.h"
25 #include "mlir/IR/Threading.h"
26 #include "mlir/Pass/Pass.h"
27 #include "llvm/ADT/DenseMap.h"
28 #include "llvm/ADT/DenseMapInfo.h"
29 #include "llvm/ADT/Hashing.h"
30 #include "llvm/ADT/PostOrderIterator.h"
31 #include "llvm/ADT/SmallPtrSet.h"
32 #include "llvm/Support/Format.h"
33 #include "llvm/Support/SHA256.h"
34 
35 namespace circt {
36 namespace firrtl {
37 #define GEN_PASS_DEF_DEDUP
38 #include "circt/Dialect/FIRRTL/Passes.h.inc"
39 } // namespace firrtl
40 } // namespace circt
41 
42 using namespace circt;
43 using namespace firrtl;
44 using hw::InnerRefAttr;
45 
46 //===----------------------------------------------------------------------===//
47 // Utility function for classifying a Symbol's dedup-ability.
48 //===----------------------------------------------------------------------===//
49 
50 /// Returns true if the module can be removed.
51 static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
52  // If the symbol is not private, it cannot be removed.
53  if (!symbol.isPrivate())
54  return false;
55  // Classes may be referenced in object types, so can not normally be removed
56  // if we can't find any symbol uses. Since we know that dedup will update the
57  // types of instances appropriately, we can ignore that and return true here.
58  if (isa<ClassLike>(*symbol))
59  return true;
60  // If module can not be removed even if no uses can be found, we can not
61  // delete it. The implication is that there are hidden symbol uses that dedup
62  // will not properly update.
63  if (!symbol.canDiscardOnUseEmpty())
64  return false;
65  // The module can be deleted.
66  return true;
67 }
68 
69 //===----------------------------------------------------------------------===//
70 // Hashing
71 //===----------------------------------------------------------------------===//
72 
73 llvm::raw_ostream &printHex(llvm::raw_ostream &stream,
74  ArrayRef<uint8_t> bytes) {
75  // Print the hash on a single line.
76  return stream << format_bytes(bytes, std::nullopt, 32) << "\n";
77 }
78 
79 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
80  return printHex(stream, data.result());
81 }
82 
83 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, std::string data) {
84  ArrayRef<uint8_t> bytes(reinterpret_cast<const uint8_t *>(data.c_str()),
85  data.length());
86  return printHex(stream, bytes);
87 }
88 
89 // This struct contains the information used to determine module uniqueness.
90 // The first element is a structural hash of the module, and the second element
91 // is an array which tracks module names encountered in the walk. Since module
92 // names could be replaced during dedup, it's necessary to keep names up-to-date
93 // before actually combining them into structural hashes.
94 struct ModuleInfo {
95  // SHA256 hash of the module's structure (32 bytes).
96  std::array<uint8_t, 32> structuralHash;
97  // Names referred to from within the module, in the order they were recorded.
98  std::vector<StringAttr> referredModuleNames;
99 };
100 
101 static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
102  return lhs.structuralHash == rhs.structuralHash &&
103  lhs.referredModuleNames == rhs.referredModuleNames;
104 }
105 
106 /// This struct contains constant string attributes shared across different
107 /// threads.
109  explicit StructuralHasherSharedConstants(MLIRContext *context) {
110  portTypesAttr = StringAttr::get(context, "portTypes");
111  moduleNameAttr = StringAttr::get(context, "moduleName");
112  innerSymAttr = StringAttr::get(context, "inner_sym");
113  portSymbolsAttr = StringAttr::get(context, "portSymbols");
114  portNamesAttr = StringAttr::get(context, "portNames");
115  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
116  nonessentialAttributes.insert(StringAttr::get(context, "convention"));
117  nonessentialAttributes.insert(StringAttr::get(context, "name"));
118  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
119  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
120  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
121  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
122  nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
123  };
124 
125  // This is a cached "portTypes" string attr.
126  StringAttr portTypesAttr;
127 
128  // This is a cached "moduleName" string attr.
129  StringAttr moduleNameAttr;
130 
131  // This is a cached "inner_sym" string attr.
132  StringAttr innerSymAttr;
133 
134  // This is a cached "portSymbols" string attr.
135  StringAttr portSymbolsAttr;
136 
137  // This is a cached "portNames" string attr.
138  StringAttr portNamesAttr;
139 
140  // This is a set of every attribute we should ignore.
141  DenseSet<Attribute> nonessentialAttributes;
142 };
143 
146  : constants(constants) {}
147 
148  ModuleInfo getModuleInfo(FModuleLike module) {
149  update(&(*module));
150  return {sha.final(), std::move(referredModuleNames)};
151  }
152 
153 private:
154  // Get the identifier for an object. The identifier is assigned on first use.
155  unsigned getID(void *object) {
156  auto [it, inserted] = idTable.try_emplace(object, nextID);
157  if (inserted)
158  ++nextID;
159  return it->second;
160  }
161 
162  // Get the identifier for an IR object. Free the ID, too.
163  unsigned finalizeID(void *object) {
164  auto it = idTable.find(object);
165  if (it == idTable.end())
166  return nextID++;
167  auto id = it->second;
168  idTable.erase(it);
169  return id;
170  }
171 
172  unsigned getInnerSymID(StringAttr name) {
173  auto [it, inserted] = innerSymIDTable.try_emplace(name, nextInnerSymID);
174  if (inserted)
175  ++nextInnerSymID;
176  return it->second;
177  }
178 
179  void update(OpOperand &operand) {
180  auto value = operand.get();
181  if (auto result = dyn_cast<OpResult>(value)) {
182  auto *op = result.getOwner();
183  update(getID(op));
184  update(result.getResultNumber());
185  return;
186  }
187  if (auto argument = dyn_cast<BlockArgument>(value)) {
188  auto *block = argument.getOwner();
189  update(getID(block));
190  update(argument.getArgNumber());
191  return;
192  }
193  llvm_unreachable("Unknown value type");
194  }
195 
196  void update(const void *pointer) {
197  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
198  sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
199  }
200 
201  void update(size_t value) {
202  auto *addr = reinterpret_cast<const uint8_t *>(&value);
203  sha.update(ArrayRef<uint8_t>(addr, sizeof value));
204  }
205 
206  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
207 
208  // NOLINTNEXTLINE(misc-no-recursion)
209  void update(BundleType type) {
210  update(type.getTypeID());
211  for (auto &element : type.getElements()) {
212  update(element.isFlip);
213  update(element.type);
214  }
215  }
216 
217  // NOLINTNEXTLINE(misc-no-recursion)
218  void update(Type type) {
219  if (auto bundle = type_dyn_cast<BundleType>(type))
220  return update(bundle);
221  update(type.getAsOpaquePointer());
222  }
223 
224  void update(OpResult result) {
225  // Like instance ops, don't use object ops' result types since they might be
226  // replaced by dedup. Record the class names and lazily combine their hashes
227  // using the same mechanism as instances and modules.
228  if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
229  referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
230  return;
231  }
232 
233  update(result.getType());
234  }
235 
236  /// Hash the top level attribute dictionary of the operation. This function
237  /// has special handling for inner symbols, ports, and referenced modules.
238  void update(Operation *op, DictionaryAttr dict) {
239  for (auto namedAttr : dict) {
240  auto name = namedAttr.getName();
241  auto value = namedAttr.getValue();
242  // Skip names and annotations, except in certain cases.
243  // Names of ports are load bearing for classes, so we do hash those.
244  bool isClassPortNames =
245  isa<ClassLike>(op) && name == constants.portNamesAttr;
246  if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
247  continue;
248 
249  // Hash the attribute name (an interned pointer).
250  update(name.getAsOpaquePointer());
251 
252  // Hash the port types.
253  if (name == constants.portTypesAttr) {
254  auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
255  for (auto type : portTypes)
256  update(type);
257  continue;
258  }
259 
260  // Special case the InnerSymbols to ignore the symbol names.
261  if (name == constants.portSymbolsAttr) {
262  if (op->getNumRegions() != 1)
263  continue;
264  auto &region = op->getRegion(0);
265  if (region.getBlocks().empty())
266  continue;
267  for (auto sym : cast<ArrayAttr>(value).getAsRange<hw::InnerSymAttr>()) {
268  for (auto property : sym) {
269  update(property.getFieldID());
270  update(getInnerSymID(property.getName()));
271  }
272  }
273  continue;
274  }
275 
276  if (name == constants.innerSymAttr) {
277  auto innerSym = cast<hw::InnerSymAttr>(value);
278  for (auto property : innerSym) {
279  update(property.getFieldID());
280  update(getInnerSymID(property.getName()));
281  }
282  continue;
283  }
284 
285  // For instance op, don't use `moduleName` attributes since they might be
286  // replaced by dedup. Record the names and lazily combine their hashes.
287  // It is assumed that module names are hashed only through instance ops;
288  // it could cause suboptimal results if there was other operation that
289  // refers to module names through essential attributes.
290  if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
291  referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
292  continue;
293  }
294 
295  // TODO: properly handle DistinctAttr, including its use in paths.
296  // See https://github.com/llvm/circt/issues/6583.
297  if (isa<DistinctAttr>(value))
298  continue;
299 
300  // If this is an symbol reference, we need to perform name erasure.
301  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value))
302  update(getInnerSymID(innerRef.getName()));
303  else
304  update(value.getAsOpaquePointer());
305  }
306  }
307 
308  void update(mlir::OperationName name) {
309  // Operation names are interned.
310  update(name.getAsOpaquePointer());
311  }
312 
313  // NOLINTNEXTLINE(misc-no-recursion)
314  void update(Block *block) {
315  for (auto &op : llvm::reverse(*block))
316  update(&op);
317  for (auto type : block->getArgumentTypes())
318  update(type);
319  update(finalizeID(block));
320  update(position);
321  ++position;
322  }
323 
324  // NOLINTNEXTLINE(misc-no-recursion)
325  void update(Region *region) {
326  for (auto &block : llvm::reverse(region->getBlocks()))
327  update(&block);
328  update(position);
329  ++position;
330  }
331 
332  // NOLINTNEXTLINE(misc-no-recursion)
333  void update(Operation *op) {
334  // Hash the regions. We need to make sure an empty region doesn't hash the
335  // same as no region, so we include the number of regions.
336  update(op->getNumRegions());
337  for (auto &region : reverse(op->getRegions()))
338  update(&region);
339 
340  update(op->getName());
341 
342  // Record the uses for later hashing.
343  for (auto &operand : op->getOpOperands())
344  update(operand);
345 
346  // This happens after the numbering above, as it uses blockarg numbering
347  // for inner symbols.
348  update(op, op->getAttrDictionary());
349 
350  // Record any op results (types).
351  for (auto result : op->getResults())
352  update(result);
353 
354  // Incorporate the hash of uses we have already built.
355  update(finalizeID(op));
356  update(position);
357  ++position;
358  }
359 
360  // A map from an operation/block, to its identifier.
361  DenseMap<void *, unsigned> idTable;
362  unsigned nextID = 0;
363 
364  // A map from an inner symbol, to its identifier.
365  DenseMap<StringAttr, unsigned> innerSymIDTable;
366  unsigned nextInnerSymID = 0;
367 
368  // This keeps track of module names in the order of the appearance.
369  std::vector<StringAttr> referredModuleNames;
370 
371  // String constants.
373 
374  // This is the actual running hash calculation. This is a stateful element
375  // that should be reinitialized after each hash is produced.
376  llvm::SHA256 sha;
377 
378  // The index of the current op. Increment after handling each op.
379  size_t position = 0;
380 };
381 
382 //===----------------------------------------------------------------------===//
383 // Equivalence
384 //===----------------------------------------------------------------------===//
385 
386 /// This class is for reporting differences between two modules which should
387 /// have been deduplicated.
388 struct Equivalence {
389  Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
390  : instanceGraph(instanceGraph) {
391  noDedupClass = StringAttr::get(context, noDedupAnnoClass);
392  dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
393  portDirectionsAttr = StringAttr::get(context, "portDirections");
394  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
395  nonessentialAttributes.insert(StringAttr::get(context, "name"));
396  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
397  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
398  nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
399  nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
400  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
401  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
402  nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
403  }
404 
405  struct ModuleData {
406  ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
407  : a(a), b(b) {}
408  IRMapping map;
409  const hw::InnerSymbolTable &a;
410  const hw::InnerSymbolTable &b;
411  };
412 
413  std::string prettyPrint(Attribute attr) {
414  SmallString<64> buffer;
415  llvm::raw_svector_ostream os(buffer);
416  if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
417  os << "0x";
418  if (integerAttr.getType().isSignlessInteger())
419  integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
420  else
421  integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
422 
423  } else
424  os << attr;
425  return std::string(buffer);
426  }
427 
428  // NOLINTNEXTLINE(misc-no-recursion)
429  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
430  Operation *a, BundleType aType, Operation *b,
431  BundleType bType) {
432  if (aType.getNumElements() != bType.getNumElements()) {
433  diag.attachNote(a->getLoc())
434  << message << " bundle type has different number of elements";
435  diag.attachNote(b->getLoc()) << "second operation here";
436  return failure();
437  }
438 
439  for (auto elementPair :
440  llvm::zip(aType.getElements(), bType.getElements())) {
441  auto aElement = std::get<0>(elementPair);
442  auto bElement = std::get<1>(elementPair);
443  if (aElement.isFlip != bElement.isFlip) {
444  diag.attachNote(a->getLoc()) << message << " bundle element "
445  << aElement.name << " flip does not match";
446  diag.attachNote(b->getLoc()) << "second operation here";
447  return failure();
448  }
449 
450  if (failed(check(diag,
451  "bundle element \'" + aElement.name.getValue() + "'", a,
452  aElement.type, b, bElement.type)))
453  return failure();
454  }
455  return success();
456  }
457 
458  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
459  Operation *a, Type aType, Operation *b, Type bType) {
460  if (aType == bType)
461  return success();
462  if (auto aBundleType = type_dyn_cast<BundleType>(aType))
463  if (auto bBundleType = type_dyn_cast<BundleType>(bType))
464  return check(diag, message, a, aBundleType, b, bBundleType);
465  if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
466  aType != bType) {
467  diag.attachNote(a->getLoc())
468  << message << ", has a RefType with a different base type "
469  << type_cast<RefType>(aType).getType()
470  << " in the same position of the two modules marked as 'must dedup'. "
471  "(This may be due to Grand Central Taps or Views being different "
472  "between the two modules.)";
473  diag.attachNote(b->getLoc())
474  << "the second module has a different base type "
475  << type_cast<RefType>(bType).getType();
476  return failure();
477  }
478  diag.attachNote(a->getLoc())
479  << message << " types don't match, first type is " << aType;
480  diag.attachNote(b->getLoc()) << "second type is " << bType;
481  return failure();
482  }
483 
484  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
485  Block &aBlock, Operation *b, Block &bBlock) {
486 
487  // Block argument types.
488  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
489  auto portNo = 0;
490  auto emitMissingPort = [&](Value existsVal, Operation *opExists,
491  Operation *opDoesNotExist) {
492  StringRef portName;
493  auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
494  if (portNames)
495  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
496  portName = portNameAttr.getValue();
497  if (type_isa<RefType>(existsVal.getType())) {
498  diag.attachNote(opExists->getLoc())
499  << " contains a RefType port named '" + portName +
500  "' that only exists in one of the modules (can be due to "
501  "difference in Grand Central Tap or View of two modules "
502  "marked with must dedup)";
503  diag.attachNote(opDoesNotExist->getLoc())
504  << "second module to be deduped that does not have the RefType "
505  "port";
506  } else {
507  diag.attachNote(opExists->getLoc())
508  << "port '" + portName + "' only exists in one of the modules";
509  diag.attachNote(opDoesNotExist->getLoc())
510  << "second module to be deduped that does not have the port";
511  }
512  return failure();
513  };
514 
515  for (auto argPair :
516  llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
517  auto &aArg = std::get<0>(argPair);
518  auto &bArg = std::get<1>(argPair);
519  if (aArg.has_value() && bArg.has_value()) {
520  // TODO: we should print the port number if there are no port names, but
521  // there are always port names ;).
522  StringRef portName;
523  if (portNames) {
524  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
525  portName = portNameAttr.getValue();
526  }
527  // Assumption here that block arguments correspond to ports.
528  if (failed(check(diag, "module port '" + portName + "'", a,
529  aArg->getType(), b, bArg->getType())))
530  return failure();
531  data.map.map(aArg.value(), bArg.value());
532  portNo++;
533  continue;
534  }
535  if (!aArg.has_value())
536  std::swap(a, b);
537  return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
538  b);
539  }
540 
541  // Blocks operations.
542  auto aIt = aBlock.begin();
543  auto aEnd = aBlock.end();
544  auto bIt = bBlock.begin();
545  auto bEnd = bBlock.end();
546  while (aIt != aEnd && bIt != bEnd)
547  if (failed(check(diag, data, &*aIt++, &*bIt++)))
548  return failure();
549  if (aIt != aEnd) {
550  diag.attachNote(aIt->getLoc()) << "first block has more operations";
551  diag.attachNote(b->getLoc()) << "second block here";
552  return failure();
553  }
554  if (bIt != bEnd) {
555  diag.attachNote(bIt->getLoc()) << "second block has more operations";
556  diag.attachNote(a->getLoc()) << "first block here";
557  return failure();
558  }
559  return success();
560  }
561 
562  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
563  Region &aRegion, Operation *b, Region &bRegion) {
564  auto aIt = aRegion.begin();
565  auto aEnd = aRegion.end();
566  auto bIt = bRegion.begin();
567  auto bEnd = bRegion.end();
568 
569  // Region blocks.
570  while (aIt != aEnd && bIt != bEnd)
571  if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
572  return failure();
573  if (aIt != aEnd || bIt != bEnd) {
574  diag.attachNote(a->getLoc())
575  << "operation regions have different number of blocks";
576  diag.attachNote(b->getLoc()) << "second operation here";
577  return failure();
578  }
579  return success();
580  }
581 
582  LogicalResult check(InFlightDiagnostic &diag, Operation *a,
583  mlir::DenseBoolArrayAttr aAttr, Operation *b,
584  mlir::DenseBoolArrayAttr bAttr) {
585  if (aAttr == bAttr)
586  return success();
587  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
588  for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
589  auto aDirection = aAttr[i];
590  auto bDirection = bAttr[i];
591  if (aDirection != bDirection) {
592  auto &note = diag.attachNote(a->getLoc()) << "module port ";
593  if (portNames)
594  note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
595  else
596  note << i;
597  note << " directions don't match, first direction is '"
598  << direction::toString(aDirection) << "'";
599  diag.attachNote(b->getLoc()) << "second direction is '"
600  << direction::toString(bDirection) << "'";
601  return failure();
602  }
603  }
604  return success();
605  }
606 
607  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
608  DictionaryAttr aDict, Operation *b,
609  DictionaryAttr bDict) {
610  // Fast path.
611  if (aDict == bDict)
612  return success();
613 
614  DenseSet<Attribute> seenAttrs;
615  for (auto namedAttr : aDict) {
616  auto attrName = namedAttr.getName();
617  if (nonessentialAttributes.contains(attrName))
618  continue;
619 
620  auto aAttr = namedAttr.getValue();
621  auto bAttr = bDict.get(attrName);
622  if (!bAttr) {
623  diag.attachNote(a->getLoc())
624  << "second operation is missing attribute " << attrName;
625  diag.attachNote(b->getLoc()) << "second operation here";
626  return diag;
627  }
628 
629  if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
630  auto bRef = cast<hw::InnerRefAttr>(bAttr);
631  auto aRef = cast<hw::InnerRefAttr>(aAttr);
632  // See if they are pointing at the same operation or port.
633  auto aTarget = data.a.lookup(aRef.getName());
634  auto bTarget = data.b.lookup(bRef.getName());
635  if (!aTarget || !bTarget)
636  diag.attachNote(a->getLoc())
637  << "malformed ir, possibly violating use-before-def";
638  auto error = [&]() {
639  diag.attachNote(a->getLoc())
640  << "operations have different targets, first operation has "
641  << aTarget;
642  diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
643  return failure();
644  };
645  if (aTarget.isPort()) {
646  // If they are targeting ports, make sure its the same port number.
647  if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
648  return error();
649  } else {
650  // Otherwise make sure that they are targeting the same operation.
651  if (!bTarget.isOpOnly() ||
652  aTarget.getOp() != data.map.lookup(bTarget.getOp()))
653  return error();
654  }
655  if (aTarget.getField() != bTarget.getField())
656  return error();
657  } else if (attrName == portDirectionsAttr) {
658  // Special handling for the port directions attribute for better
659  // error messages.
660  if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
661  cast<mlir::DenseBoolArrayAttr>(bAttr))))
662  return failure();
663  } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
664  // TODO: properly handle DistinctAttr, including its use in paths.
665  // See https://github.com/llvm/circt/issues/6583
666  } else if (aAttr != bAttr) {
667  diag.attachNote(a->getLoc())
668  << "first operation has attribute '" << attrName.getValue()
669  << "' with value " << prettyPrint(aAttr);
670  diag.attachNote(b->getLoc())
671  << "second operation has value " << prettyPrint(bAttr);
672  return failure();
673  }
674  seenAttrs.insert(attrName);
675  }
676  if (aDict.getValue().size() != bDict.getValue().size()) {
677  for (auto namedAttr : bDict) {
678  auto attrName = namedAttr.getName();
679  // Skip the attribute if we don't care about this particular one or it
680  // is one that is known to be in both dictionaries.
681  if (nonessentialAttributes.contains(attrName) ||
682  seenAttrs.contains(attrName))
683  continue;
684  // We have found an attribute that is only in the second operation.
685  diag.attachNote(a->getLoc())
686  << "first operation is missing attribute " << attrName;
687  diag.attachNote(b->getLoc()) << "second operation here";
688  return failure();
689  }
690  }
691  return success();
692  }
693 
694  // NOLINTNEXTLINE(misc-no-recursion)
695  LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
696  FInstanceLike b) {
697  auto aName = a.getReferencedModuleNameAttr();
698  auto bName = b.getReferencedModuleNameAttr();
699  if (aName == bName)
700  return success();
701 
702  // If the modules instantiate are different we will want to know why the
703  // sub module did not dedupliate. This code recursively checks the child
704  // module.
705  auto aModule = instanceGraph.lookup(aName)->getModule();
706  auto bModule = instanceGraph.lookup(bName)->getModule();
707  // Create a new error for the submodule.
708  diag.attachNote(std::nullopt)
709  << "in instance " << a.getInstanceNameAttr() << " of " << aName
710  << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
711  check(diag, aModule, bModule);
712  return failure();
713  }
714 
715  // NOLINTNEXTLINE(misc-no-recursion)
716  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
717  Operation *b) {
718  // Operation name.
719  if (a->getName() != b->getName()) {
720  diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
721  diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
722  return failure();
723  }
724 
725  // If its an instance operaiton, perform some checking and possibly
726  // recurse.
727  if (auto aInst = dyn_cast<FInstanceLike>(a)) {
728  auto bInst = cast<FInstanceLike>(b);
729  if (failed(check(diag, aInst, bInst)))
730  return failure();
731  }
732 
733  // Operation results.
734  if (a->getNumResults() != b->getNumResults()) {
735  diag.attachNote(a->getLoc())
736  << "operations have different number of results";
737  diag.attachNote(b->getLoc()) << "second operation here";
738  return failure();
739  }
740  for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
741  auto &aValue = std::get<0>(resultPair);
742  auto &bValue = std::get<1>(resultPair);
743  if (failed(check(diag, "operation result", a, aValue.getType(), b,
744  bValue.getType())))
745  return failure();
746  data.map.map(aValue, bValue);
747  }
748 
749  // Operations operands.
750  if (a->getNumOperands() != b->getNumOperands()) {
751  diag.attachNote(a->getLoc())
752  << "operations have different number of operands";
753  diag.attachNote(b->getLoc()) << "second operation here";
754  return failure();
755  }
756  for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
757  auto &aValue = std::get<0>(operandPair);
758  auto &bValue = std::get<1>(operandPair);
759  if (bValue != data.map.lookup(aValue)) {
760  diag.attachNote(a->getLoc())
761  << "operations use different operands, first operand is '"
762  << getFieldName(
763  getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
764  .first
765  << "'";
766  diag.attachNote(b->getLoc())
767  << "second operand is '"
768  << getFieldName(
769  getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
770  .first
771  << "', but should have been '"
772  << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
773  /*lookThroughCasts=*/true))
774  .first
775  << "'";
776  return failure();
777  }
778  }
779  data.map.map(a, b);
780 
781  // Operation regions.
782  if (a->getNumRegions() != b->getNumRegions()) {
783  diag.attachNote(a->getLoc())
784  << "operations have different number of regions";
785  diag.attachNote(b->getLoc()) << "second operation here";
786  return failure();
787  }
788  for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
789  auto &aRegion = std::get<0>(regionPair);
790  auto &bRegion = std::get<1>(regionPair);
791  if (failed(check(diag, data, a, aRegion, b, bRegion)))
792  return failure();
793  }
794 
795  // Operation attributes.
796  if (failed(check(diag, data, a, a->getAttrDictionary(), b,
797  b->getAttrDictionary())))
798  return failure();
799  return success();
800  }
801 
802  // NOLINTNEXTLINE(misc-no-recursion)
803  void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
804  hw::InnerSymbolTable aTable(a);
805  hw::InnerSymbolTable bTable(b);
806  ModuleData data(aTable, bTable);
807  if (AnnotationSet::hasAnnotation(a, noDedupClass)) {
808  diag.attachNote(a->getLoc()) << "module marked NoDedup";
809  return;
810  }
811  if (AnnotationSet::hasAnnotation(b, noDedupClass)) {
812  diag.attachNote(b->getLoc()) << "module marked NoDedup";
813  return;
814  }
815  auto aSymbol = cast<mlir::SymbolOpInterface>(a);
816  auto bSymbol = cast<mlir::SymbolOpInterface>(b);
817  if (!canRemoveModule(aSymbol)) {
818  diag.attachNote(a->getLoc())
819  << "module is "
820  << (aSymbol.isPrivate() ? "private but not discardable" : "public");
821  return;
822  }
823  if (!canRemoveModule(bSymbol)) {
824  diag.attachNote(b->getLoc())
825  << "module is "
826  << (bSymbol.isPrivate() ? "private but not discardable" : "public");
827  return;
828  }
829  auto aGroup =
830  dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
831  auto bGroup = dyn_cast_or_null<StringAttr>(
832  b->getAttrOfType<StringAttr>(dedupGroupAttrName));
833  if (aGroup != bGroup) {
834  if (bGroup) {
835  diag.attachNote(b->getLoc())
836  << "module is in dedup group '" << bGroup.str() << "'";
837  } else {
838  diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
839  }
840  if (aGroup) {
841  diag.attachNote(a->getLoc())
842  << "module is in dedup group '" << aGroup.str() << "'";
843  } else {
844  diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
845  }
846  return;
847  }
848  if (failed(check(diag, data, a, b)))
849  return;
850  diag.attachNote(a->getLoc()) << "first module here";
851  diag.attachNote(b->getLoc()) << "second module here";
852  }
853 
854  // This is a cached "portDirections" string attr.
855  StringAttr portDirectionsAttr;
856  // This is a cached "NoDedup" annotation class string attr.
857  StringAttr noDedupClass;
858  // This is a cached string attr for the dedup group attribute.
859  StringAttr dedupGroupAttrName;
860 
861  // This is a set of every attribute we should ignore.
862  DenseSet<Attribute> nonessentialAttributes;
864 };
865 
866 //===----------------------------------------------------------------------===//
867 // Deduplication
868 //===----------------------------------------------------------------------===//
869 
870 // Custom location merging. This only keeps track of 8 annotations from ".fir"
871 // files, and however many annotations come from "real" sources. When
872 // deduplicating, modules tend not to have scala source locators, so we wind
873 // up fusing source locators for a module from every copy being deduped. There
874 // is little value in this (all the modules are identical by definition).
875 static Location mergeLoc(MLIRContext *context, Location to, Location from) {
876  // Unique the set of locations to be fused.
877  llvm::SmallSetVector<Location, 4> decomposedLocs;
878  // only track 8 "fir" locations
879  unsigned seenFIR = 0;
880  for (auto loc : {to, from}) {
881  // If the location is a fused location we decompose it if it has no
882  // metadata or the metadata is the same as the top level metadata.
883  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
884  // UnknownLoc's have already been removed from FusedLocs so we can
885  // simply add all of the internal locations.
886  for (auto loc : fusedLoc.getLocations()) {
887  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
888  if (fileLoc.getFilename().strref().ends_with(".fir")) {
889  ++seenFIR;
890  if (seenFIR > 8)
891  continue;
892  }
893  }
894  decomposedLocs.insert(loc);
895  }
896  continue;
897  }
898 
899  // Might need to skip this fir.
900  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
901  if (fileLoc.getFilename().strref().ends_with(".fir")) {
902  ++seenFIR;
903  if (seenFIR > 8)
904  continue;
905  }
906  }
907  // Otherwise, only add known locations to the set.
908  if (!isa<UnknownLoc>(loc))
909  decomposedLocs.insert(loc);
910  }
911 
912  auto locs = decomposedLocs.getArrayRef();
913 
914  // Handle the simple cases of less than two locations. Ensure the metadata (if
915  // provided) is not dropped.
916  if (locs.empty())
917  return UnknownLoc::get(context);
918  if (locs.size() == 1)
919  return locs.front();
920 
921  return FusedLoc::get(context, locs);
922 }
923 
924 struct Deduper {
925 
926  using RenameMap = DenseMap<StringAttr, StringAttr>;
927 
928  Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable,
929  NLATable *nlaTable, CircuitOp circuit)
930  : context(circuit->getContext()), instanceGraph(instanceGraph),
931  symbolTable(symbolTable), nlaTable(nlaTable),
932  nlaBlock(circuit.getBodyBlock()),
933  nonLocalString(StringAttr::get(context, "circt.nonlocal")),
934  classString(StringAttr::get(context, "class")) {
935  // Populate the NLA cache.
936  for (auto nla : circuit.getOps<hw::HierPathOp>())
937  nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
938  }
939 
940  /// Remove the "fromModule", and replace all references to it with the
941  /// "toModule". Modules should be deduplicated in a bottom-up order. Any
942  /// module which is not deduplicated needs to be recorded with the `record`
943  /// call.
944  void dedup(FModuleLike toModule, FModuleLike fromModule) {
945  // A map of operation (e.g. wires, nodes) names which are changed, which is
946  // used to update NLAs that reference the "fromModule".
947  RenameMap renameMap;
948 
949  // Merge the port locations.
950  SmallVector<Attribute> newLocs;
951  for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
952  fromModule.getPortLocations())) {
953  if (toLoc == fromLoc)
954  newLocs.push_back(toLoc);
955  else
956  newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
957  cast<LocationAttr>(fromLoc)));
958  }
959  toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
960 
961  // Merge the two modules.
962  mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
963 
964  // Rewrite NLAs pathing through these modules to refer to the to module. It
965  // is safe to do this at this point because NLAs cannot be one element long.
966  // This means that all NLAs which require more context cannot be targetting
967  // something in the module it self.
968  if (auto to = dyn_cast<FModuleOp>(*toModule))
969  rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
970  else
971  rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
972  fromModule.getModuleNameAttr());
973 
974  replaceInstances(toModule, fromModule);
975  }
976 
977  /// Record the usages of any NLA's in this module, so that we may update the
978  /// annotation if the parent module is deduped with another module.
979  void record(FModuleLike module) {
980  // Record any annotations on the module.
981  recordAnnotations(module);
982  // Record port annotations.
983  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
984  recordAnnotations(PortAnnoTarget(module, i));
985  // Record any annotations in the module body.
986  module->walk([&](Operation *op) { recordAnnotations(op); });
987  }
988 
989 private:
990  /// Get a cached namespace for a module.
991  hw::InnerSymbolNamespace &getNamespace(Operation *module) {
992  return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
993  .first->second;
994  }
995 
996  /// For a specific annotation target, record all the unique NLAs which
997  /// target it in the `targetMap`.
999  for (auto anno : target.getAnnotations())
1000  if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1001  targetMap[nlaRef.getAttr()].insert(target);
1002  }
1003 
1004  /// Record all targets which use an NLA.
1005  void recordAnnotations(Operation *op) {
1006  // Record annotations.
1007  recordAnnotations(OpAnnoTarget(op));
1008 
1009  // Record port annotations only if this is a mem operation.
1010  auto mem = dyn_cast<MemOp>(op);
1011  if (!mem)
1012  return;
1013 
1014  // Record port annotations.
1015  for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1016  recordAnnotations(PortAnnoTarget(mem, i));
1017  }
1018 
1019  /// This deletes and replaces all instances of the "fromModule" with instances
1020  /// of the "toModule".
1021  void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1022  // Replace all instances of the other module.
1023  auto *fromNode =
1024  instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1025  auto *toNode = instanceGraph[toModule];
1026  auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1027  for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1028  auto inst = oldInstRec->getInstance();
1029  if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1030  instOp.setModuleNameAttr(toModuleRef);
1031  instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1032  } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1033  auto classLike = cast<ClassLike>(*toNode->getModule());
1034  ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1035  objectOp.getResult().setType(classType);
1036  }
1037  oldInstRec->getParent()->addInstance(inst, toNode);
1038  oldInstRec->erase();
1039  }
1040  instanceGraph.erase(fromNode);
1041  fromModule->erase();
1042  }
1043 
1044  /// Look up the instantiations of the `from` module and create an NLA for each
1045  /// one, appending the baseNamepath to each NLA. This is used to add more
1046  /// context to an already existing NLA. The `fromModule` is used to indicate
1047  /// which module the annotation is coming from before the merge, and will be
1048  /// used to create the namepaths.
1049  SmallVector<FlatSymbolRefAttr>
1050  createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1051  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1052  // Create an attribute array with a placeholder in the first element, where
1053  // the root refence of the NLA will be inserted.
1054  SmallVector<Attribute> namepath = {nullptr};
1055  namepath.append(baseNamepath.begin(), baseNamepath.end());
1056 
1057  auto loc = fromModule->getLoc();
1058  auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1059  SmallVector<FlatSymbolRefAttr> nlas;
1060  for (auto *instanceRecord : fromNode->uses()) {
1061  auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1062  auto inst = instanceRecord->getInstance();
1063  namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1064  auto arrayAttr = ArrayAttr::get(context, namepath);
1065  // Check the NLA cache to see if we already have this NLA.
1066  auto &cacheEntry = nlaCache[arrayAttr];
1067  if (!cacheEntry) {
1068  auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1069  loc, "nla", arrayAttr);
1070  // Insert it into the symbol table to get a unique name.
1071  symbolTable.insert(nla);
1072  // Store it in the cache.
1073  cacheEntry = nla.getNameAttr();
1074  nla.setVisibility(vis);
1075  nlaTable->addNLA(nla);
1076  }
1077  auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1078  nlas.push_back(nlaRef);
1079  }
1080  return nlas;
1081  }
1082 
1083  /// Look up the instantiations of this module and create an NLA for each one.
1084  /// This returns an array of symbol references which can be used to reference
1085  /// the NLAs.
1086  SmallVector<FlatSymbolRefAttr>
1087  createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1088  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1089  return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1090  }
1091 
1092  /// Clone the annotation for each NLA in a list. The attribute list should
1093  /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1094  /// should be the index of this field.
1095  void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1096  Annotation anno, ArrayRef<NamedAttribute> attributes,
1097  unsigned nonLocalIndex,
1098  SmallVectorImpl<Annotation> &newAnnotations) {
1099  SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1100  attributes.end());
1101  for (auto &nla : nlas) {
1102  // Add the new annotation.
1103  mutableAttributes[nonLocalIndex].setValue(nla);
1104  auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1105  // The original annotation records if its a subannotation.
1106  anno.setDict(dict);
1107  newAnnotations.push_back(anno);
1108  }
1109  }
1110 
1111  /// This erases the NLA op, and removes the NLA from every module's NLA map,
1112  /// but it does not delete the NLA reference from the target operation's
1113  /// annotations.
1114  void eraseNLA(hw::HierPathOp nla) {
1115  // Erase the NLA from the leaf module's nlaMap.
1116  targetMap.erase(nla.getNameAttr());
1117  nlaTable->erase(nla);
1118  nlaCache.erase(nla.getNamepathAttr());
1119  symbolTable.erase(nla);
1120  }
1121 
1122  /// Process all NLAs referencing the "from" module to point to the "to"
1123  /// module. This is used after merging two modules together.
1124  void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1125  FModuleOp fromModule) {
1126  auto toName = toModule.getNameAttr();
1127  auto fromName = fromModule.getNameAttr();
1128  // Create a copy of the current NLAs. We will be pushing and removing
1129  // NLAs from this op as we go.
1130  auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1131  // Change the NLA to target the toModule.
1132  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1133  // Now we walk the NLA searching for ones that require more context to be
1134  // added.
1135  for (auto nla : moduleNLAs) {
1136  auto elements = nla.getNamepath().getValue();
1137  // If we don't need to add more context, we're done here.
1138  if (nla.root() != toName)
1139  continue;
1140  // Create the replacement NLAs.
1141  SmallVector<Attribute> namepath(elements.begin(), elements.end());
1142  auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1143  // Copy out the targets, because we will be updating the map.
1144  auto &set = targetMap[nla.getSymNameAttr()];
1145  SmallVector<AnnoTarget> targets(set.begin(), set.end());
1146  // Replace the uses of the old NLA with the new NLAs.
1147  for (auto target : targets) {
1148  // We have to clone any annotation which uses the old NLA for each new
1149  // NLA. This array collects the new set of annotations.
1150  SmallVector<Annotation> newAnnotations;
1151  for (auto anno : target.getAnnotations()) {
1152  // Find the non-local field of the annotation.
1153  auto [it, found] = mlir::impl::findAttrSorted(
1154  anno.begin(), anno.end(), nonLocalString);
1155  // If this annotation doesn't use the target NLA, copy it with no
1156  // changes.
1157  if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1158  nla.getSymNameAttr()) {
1159  newAnnotations.push_back(anno);
1160  continue;
1161  }
1162  auto nonLocalIndex = std::distance(anno.begin(), it);
1163  // Clone the annotation and add it to the list of new annotations.
1164  cloneAnnotation(nlaRefs, anno,
1165  ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1166  nonLocalIndex, newAnnotations);
1167  }
1168 
1169  // Apply the new annotations to the operation.
1170  AnnotationSet annotations(newAnnotations, context);
1171  target.setAnnotations(annotations);
1172  // Record that target uses the NLA.
1173  for (auto nla : nlaRefs)
1174  targetMap[nla.getAttr()].insert(target);
1175  }
1176 
1177  // Erase the old NLA and remove it from all breadcrumbs.
1178  eraseNLA(nla);
1179  }
1180  }
1181 
1182  /// Process all the NLAs that the two modules participate in, replacing
1183  /// references to the "from" module with references to the "to" module, and
1184  /// adding more context if necessary.
1185  void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1186  FModuleOp fromModule) {
1187  addAnnotationContext(renameMap, toModule, toModule);
1188  addAnnotationContext(renameMap, toModule, fromModule);
1189  }
1190 
1191  // Update all NLAs which the "from" external module participates in to the
1192  // "toName".
1193  void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1194  StringAttr fromName) {
1195  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1196  }
1197 
1198  /// Take an annotation, and update it to be a non-local annotation. If the
1199  /// annotation is already non-local and has enough context, it will be skipped
1200  /// for now. Return true if the annotation was made non-local.
1201  bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1202  FModuleLike fromModule, Annotation anno,
1203  SmallVectorImpl<Annotation> &newAnnotations) {
1204  // Start constructing a new annotation, pushing a "circt.nonLocal" field
1205  // into the correct spot if its not already a non-local annotation.
1206  SmallVector<NamedAttribute> attributes;
1207  int nonLocalIndex = -1;
1208  for (const auto &val : llvm::enumerate(anno)) {
1209  auto attr = val.value();
1210  // Is this field "circt.nonlocal"?
1211  auto compare = attr.getName().compare(nonLocalString);
1212  assert(compare != 0 && "should not pass non-local annotations here");
1213  if (compare == 1) {
1214  // This annotation definitely does not have "circt.nonlocal" field. Push
1215  // an empty place holder for the non-local annotation.
1216  nonLocalIndex = val.index();
1217  attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1218  break;
1219  }
1220  // Otherwise push the current attribute and keep searching for the
1221  // "circt.nonlocal" field.
1222  attributes.push_back(attr);
1223  }
1224  if (nonLocalIndex == -1) {
1225  // Push an empty "circt.nonlocal" field to the last slot.
1226  nonLocalIndex = attributes.size();
1227  attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1228  } else {
1229  // Copy the remaining annotation fields in.
1230  attributes.append(anno.begin() + nonLocalIndex, anno.end());
1231  }
1232 
1233  // Construct the NLAs if we don't have any yet.
1234  auto nlaRefs = createNLAs(toModuleName, fromModule);
1235  for (auto nla : nlaRefs)
1236  targetMap[nla.getAttr()].insert(to);
1237 
1238  // Clone the annotation for each new NLA.
1239  cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1240  return true;
1241  }
1242 
1243  void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1244  FModuleLike fromModule, AnnotationSet annos,
1245  SmallVectorImpl<Annotation> &newAnnotations,
1246  SmallPtrSetImpl<Attribute> &dontTouches) {
1247  for (auto anno : annos) {
1248  if (anno.isClass(dontTouchAnnoClass)) {
1249  // Remove the nonlocal field of the annotation if it has one, since this
1250  // is a sticky annotation.
1251  anno.removeMember("circt.nonlocal");
1252  auto [it, inserted] = dontTouches.insert(anno.getAttr());
1253  if (inserted)
1254  newAnnotations.push_back(anno);
1255  continue;
1256  }
1257  // If the annotation is already non-local, we add it as is. It is already
1258  // added to the target map.
1259  if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1260  newAnnotations.push_back(anno);
1261  targetMap[nla.getAttr()].insert(to);
1262  continue;
1263  }
1264  // Otherwise make the annotation non-local and add it to the set.
1265  makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1266  newAnnotations);
1267  }
1268  }
1269 
1270  /// Merge the annotations of a specific target, either a operation or a port
1271  /// on an operation.
1272  void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1273  AnnotationSet toAnnos, FModuleLike fromModule,
1274  AnnoTarget from, AnnotationSet fromAnnos) {
1275  // This is a list of all the annotations which will be added to `to`.
1276  SmallVector<Annotation> newAnnotations;
1277 
1278  // We have special case handling of DontTouch to prevent it from being
1279  // turned into a non-local annotation, and to remove duplicates.
1280  llvm::SmallPtrSet<Attribute, 4> dontTouches;
1281 
1282  // Iterate the annotations, transforming most annotations into non-local
1283  // ones.
1284  copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1285  dontTouches);
1286  copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1287  dontTouches);
1288 
1289  // Copy over all the new annotations.
1290  if (!newAnnotations.empty())
1291  to.setAnnotations(AnnotationSet(newAnnotations, context));
1292  }
1293 
1294  /// Merge all annotations and port annotations on two operations.
1295  void mergeAnnotations(FModuleLike toModule, Operation *to,
1296  FModuleLike fromModule, Operation *from) {
1297  // Merge op annotations.
1298  mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1299  OpAnnoTarget(from), AnnotationSet(from));
1300 
1301  // Merge port annotations.
1302  if (toModule == to) {
1303  // Merge module port annotations.
1304  for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1305  mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1306  AnnotationSet::forPort(toModule, i), fromModule,
1307  PortAnnoTarget(fromModule, i),
1308  AnnotationSet::forPort(fromModule, i));
1309  } else if (auto toMem = dyn_cast<MemOp>(to)) {
1310  // Merge memory port annotations.
1311  auto fromMem = cast<MemOp>(from);
1312  for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1313  mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1314  AnnotationSet::forPort(toMem, i), fromModule,
1315  PortAnnoTarget(fromMem, i),
1316  AnnotationSet::forPort(fromMem, i));
1317  }
1318  }
1319 
1320  // Record the symbol name change of the operation or any of its ports when
1321  // merging two operations. The renamed symbols are used to update the
1322  // target of any NLAs. This will add symbols to the "to" operation if needed.
1323  void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1324  Operation *to, FModuleLike fromModule,
1325  Operation *from) {
1326  // If the "from" operation has an inner_sym, we need to make sure the
1327  // "to" operation also has an `inner_sym` and then record the renaming.
1328  if (auto fromSym = getInnerSymName(from)) {
1329  auto toSym =
1330  getOrAddInnerSym(to, [&](auto _) -> hw::InnerSymbolNamespace & {
1331  return getNamespace(toModule);
1332  });
1333  renameMap[fromSym] = toSym;
1334  }
1335 
1336  // If there are no port symbols on the "from" operation, we are done here.
1337  auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1338  if (!fromPortSyms || fromPortSyms.empty())
1339  return;
1340  // We have to map each "fromPort" to each "toPort".
1341  auto &moduleNamespace = getNamespace(toModule);
1342  auto portCount = fromPortSyms.size();
1343  auto portNames = to->getAttrOfType<ArrayAttr>("portNames");
1344  auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1345 
1346  // Create an array of new port symbols for the "to" operation, copy in the
1347  // old symbols if it has any, create an empty symbol array if it doesn't.
1348  SmallVector<Attribute> newPortSyms;
1349  if (toPortSyms.empty())
1350  newPortSyms.assign(portCount, hw::InnerSymAttr());
1351  else
1352  newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1353 
1354  for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1355  // If this fromPort doesn't have a symbol, move on to the next one.
1356  if (!fromPortSyms[portNo])
1357  continue;
1358  auto fromSym = cast<hw::InnerSymAttr>(fromPortSyms[portNo]);
1359 
1360  // If this toPort doesn't have a symbol, assign one.
1361  hw::InnerSymAttr toSym;
1362  if (!newPortSyms[portNo]) {
1363  // Get a reasonable base name for the port.
1364  StringRef symName = "inner_sym";
1365  if (portNames)
1366  symName = cast<StringAttr>(portNames[portNo]).getValue();
1367  // Create the symbol and store it into the array.
1368  toSym = hw::InnerSymAttr::get(
1369  StringAttr::get(context, moduleNamespace.newName(symName)));
1370  newPortSyms[portNo] = toSym;
1371  } else
1372  toSym = cast<hw::InnerSymAttr>(newPortSyms[portNo]);
1373 
1374  // Record the renaming.
1375  renameMap[fromSym.getSymName()] = toSym.getSymName();
1376  }
1377 
1378  // Commit the new symbol attribute.
1379  FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1380  cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1381  }
1382 
1383  /// Recursively merge two operations.
1384  // NOLINTNEXTLINE(misc-no-recursion)
1385  void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1386  FModuleLike fromModule, Operation *from) {
1387  // Merge the operation locations.
1388  if (to->getLoc() != from->getLoc())
1389  to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1390 
1391  // Recurse into any regions.
1392  for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1393  mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1394  std::get<1>(regions));
1395 
1396  // Record any inner_sym renamings that happened.
1397  recordSymRenames(renameMap, toModule, to, fromModule, from);
1398 
1399  // Merge the annotations.
1400  mergeAnnotations(toModule, to, fromModule, from);
1401  }
1402 
1403  /// Recursively merge two blocks.
1404  void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1405  FModuleLike fromModule, Block &fromBlock) {
1406  // Merge the block locations.
1407  for (auto [toArg, fromArg] :
1408  llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1409  if (toArg.getLoc() != fromArg.getLoc())
1410  toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1411 
1412  for (auto ops : llvm::zip(toBlock, fromBlock))
1413  mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1414  &std::get<1>(ops));
1415  }
1416 
1417  // Recursively merge two regions.
1418  void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1419  Region &toRegion, FModuleLike fromModule,
1420  Region &fromRegion) {
1421  for (auto blocks : llvm::zip(toRegion, fromRegion))
1422  mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1423  std::get<1>(blocks));
1424  }
1425 
1426  MLIRContext *context;
1428  SymbolTable &symbolTable;
1429 
1430  /// Cached nla table analysis.
1431  NLATable *nlaTable = nullptr;
1432 
1433  /// We insert all NLAs to the beginning of this block.
1434  Block *nlaBlock;
1435 
1436  // This maps an NLA to the operations and ports that uses it.
1437  DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1438 
1439  // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1440  // of the NLA's path to the name of the NLA which contains it.
1441  DenseMap<Attribute, Attribute> nlaCache;
1442 
1443  // Cached attributes for faster comparisons and attribute building.
1444  StringAttr nonLocalString;
1445  StringAttr classString;
1446 
1447  /// A module namespace cache.
1448  DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1449 };
1450 
1451 //===----------------------------------------------------------------------===//
1452 // Fixup
1453 //===----------------------------------------------------------------------===//
1454 
1455 /// This fixes up ClassLikes with ClassType ports, when the classes have
1456 /// deduped. For each ClassType port, if the object reference being assigned is
1457 /// a different type, update the port type. Returns true if the ClassOp was
1458 /// updated and the associated ObjectOps should be updated.
1459 bool fixupClassOp(ClassOp classOp) {
1460  // New port type attributes, if necessary.
1461  SmallVector<Attribute> newPortTypes;
1462  bool anyDifferences = false;
1463 
1464  // Check each port.
1465  for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1466  // Check if this port is a ClassType. If not, save the original type
1467  // attribute in case we need to update port types.
1468  auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1469  if (!portClassType) {
1470  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1471  continue;
1472  }
1473 
1474  // Check if this port is assigned a reference of a different ClassType.
1475  Type newPortClassType;
1476  BlockArgument portArg = classOp.getArgument(i);
1477  for (auto &use : portArg.getUses()) {
1478  if (auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1479  Type sourceType = propassign.getSrc().getType();
1480  if (propassign.getDest() == use.get() && sourceType != portClassType) {
1481  // Double check that all references are the same new type.
1482  if (newPortClassType) {
1483  assert(newPortClassType == sourceType &&
1484  "expected all references to be of the same type");
1485  continue;
1486  }
1487 
1488  newPortClassType = sourceType;
1489  }
1490  }
1491  }
1492 
1493  // If there was no difference, save the original type attribute in case we
1494  // need to update port types and move along.
1495  if (!newPortClassType) {
1496  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1497  continue;
1498  }
1499 
1500  // The port type changed, so update the block argument, save the new port
1501  // type attribute, and indicate there was a difference.
1502  classOp.getArgument(i).setType(newPortClassType);
1503  newPortTypes.push_back(TypeAttr::get(newPortClassType));
1504  anyDifferences = true;
1505  }
1506 
1507  // If necessary, update port types.
1508  if (anyDifferences)
1509  classOp.setPortTypes(newPortTypes);
1510 
1511  return anyDifferences;
1512 }
1513 
1514 /// This fixes up ObjectOps when the signature of their ClassOp changes. This
1515 /// amounts to updating the ObjectOp result type to match the newly updated
1516 /// ClassOp type.
1517 void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) {
1518  objectOp.getResult().setType(newClassType);
1519 }
1520 
1521 /// This fixes up connects when the field names of a bundle type changes. It
1522 /// finds all fields which were previously bulk connected and legalizes it
1523 /// into a connect for each field.
1524 void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1525  // If the types already match we can emit a connect.
1526  auto dstType = dst.getType();
1527  auto srcType = src.getType();
1528  if (dstType == srcType) {
1529  emitConnect(builder, dst, src);
1530  return;
1531  }
1532  // It must be a bundle type and the field name has changed. We have to
1533  // manually decompose the bulk connect into a connect for each field.
1534  auto dstBundle = type_cast<BundleType>(dstType);
1535  auto srcBundle = type_cast<BundleType>(srcType);
1536  for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1537  auto dstField = builder.create<SubfieldOp>(dst, i);
1538  auto srcField = builder.create<SubfieldOp>(src, i);
1539  if (dstBundle.getElement(i).isFlip) {
1540  std::swap(srcBundle, dstBundle);
1541  std::swap(srcField, dstField);
1542  }
1543  fixupConnect(builder, dstField, srcField);
1544  }
1545 }
1546 
1547 /// This is the root method to fixup module references when a module changes.
1548 /// It matches all the results of "to" module with the results of the "from"
1549 /// module.
1550 void fixupAllModules(InstanceGraph &instanceGraph) {
1551  for (auto *node : instanceGraph) {
1552  auto module = cast<FModuleLike>(*node->getModule());
1553 
1554  // Handle class declarations here.
1555  bool shouldFixupObjects = false;
1556  auto classOp = dyn_cast<ClassOp>(module.getOperation());
1557  if (classOp)
1558  shouldFixupObjects = fixupClassOp(classOp);
1559 
1560  for (auto *instRec : node->uses()) {
1561  // Handle object instantiations here.
1562  if (classOp) {
1563  if (shouldFixupObjects) {
1564  fixupObjectOp(instRec->getInstance<ObjectOp>(),
1565  classOp.getInstanceType());
1566  }
1567  continue;
1568  }
1569 
1570  auto inst = instRec->getInstance<InstanceOp>();
1571  // Only handle module instantiations here.
1572  if (!inst)
1573  continue;
1574  ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1575  builder.setInsertionPointAfter(inst);
1576  for (size_t i = 0, e = getNumPorts(module); i < e; ++i) {
1577  auto result = inst.getResult(i);
1578  auto newType = module.getPortType(i);
1579  auto oldType = result.getType();
1580  // If the type has not changed, we don't have to fix up anything.
1581  if (newType == oldType)
1582  continue;
1583  // If the type changed we transform it back to the old type with an
1584  // intermediate wire.
1585  auto wire =
1586  builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1587  result.replaceAllUsesWith(wire);
1588  result.setType(newType);
1589  if (inst.getPortDirection(i) == Direction::Out)
1590  fixupConnect(builder, wire, result);
1591  else
1592  fixupConnect(builder, result, wire);
1593  }
1594  }
1595  }
1596 }
1597 
1598 namespace llvm {
1599 /// A DenseMapInfo implementation for `ModuleInfo` that is a pair of
1600 /// llvm::SHA256 hashes, which are represented as std::array<uint8_t, 32>, and
1601 /// an array of string attributes. This allows us to create a DenseMap with
1602 /// `ModuleInfo` as keys.
1603 template <>
1604 struct DenseMapInfo<ModuleInfo> {
1605  static inline ModuleInfo getEmptyKey() {
1606  std::array<uint8_t, 32> key;
1607  std::fill(key.begin(), key.end(), ~0);
1608  return {key, {}};
1609  }
1610 
1611  static inline ModuleInfo getTombstoneKey() {
1612  std::array<uint8_t, 32> key;
1613  std::fill(key.begin(), key.end(), ~0 - 1);
1614  return {key, {}};
1615  }
1616 
1617  static unsigned getHashValue(const ModuleInfo &val) {
1618  // We assume SHA256 is already a good hash and just truncate down to the
1619  // number of bytes we need for DenseMap.
1620  unsigned hash;
1621  std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));
1622 
1623  // Combine module names.
1624  return llvm::hash_combine(
1625  hash, llvm::hash_combine_range(val.referredModuleNames.begin(),
1626  val.referredModuleNames.end()));
1627  }
1628 
1629  static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
1630  return lhs == rhs;
1631  }
1632 };
1633 } // namespace llvm
1634 
1635 //===----------------------------------------------------------------------===//
1636 // DedupPass
1637 //===----------------------------------------------------------------------===//
1638 
1639 namespace {
1640 class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1641  void runOnOperation() override {
1642  auto *context = &getContext();
1643  auto circuit = getOperation();
1644  auto &instanceGraph = getAnalysis<InstanceGraph>();
1645  auto *nlaTable = &getAnalysis<NLATable>();
1646  auto &symbolTable = getAnalysis<SymbolTable>();
1647  Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1648  Equivalence equiv(context, instanceGraph);
1649  auto anythingChanged = false;
1650 
1651  // Modules annotated with this should not be considered for deduplication.
1652  auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1653 
1654  // Only modules within the same group may be deduplicated.
1655  auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1656 
1657  // A map of all the module moduleInfo that we have calculated so far.
1658  llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1659 
1660  // We track the name of the module that each module is deduped into, so that
1661  // we can make sure all modules which are marked "must dedup" with each
1662  // other were all deduped to the same module.
1663  DenseMap<Attribute, StringAttr> dedupMap;
1664 
1665  // We must iterate the modules from the bottom up so that we can properly
1666  // deduplicate the modules. We copy the list of modules into a vector first
1667  // to avoid iterator invalidation while we mutate the instance graph.
1668  SmallVector<FModuleLike, 0> modules(
1669  llvm::map_range(llvm::post_order(&instanceGraph), [](auto *node) {
1670  return cast<FModuleLike>(*node->getModule());
1671  }));
1672 
1673  SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1674  StructuralHasherSharedConstants hasherConstants(&getContext());
1675 
1676  // Attribute name used to store dedup_group for this pass.
1677  auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1678 
1679  // Move dedup group annotations to attributes on the module.
1680  // This results in the desired behavior (included in hash),
1681  // and avoids unnecessary processing of these as annotations
1682  // that need to be tracked, made non-local, so on.
1683  for (auto module : modules) {
1684  llvm::SmallSetVector<StringAttr, 1> groups;
1686  module, [&groups, dedupGroupClass](Annotation annotation) {
1687  if (annotation.getClassAttr() != dedupGroupClass)
1688  return false;
1689  groups.insert(annotation.getMember<StringAttr>("group"));
1690  return true;
1691  });
1692  if (groups.size() > 1) {
1693  module.emitError("module belongs to multiple dedup groups: ") << groups;
1694  return signalPassFailure();
1695  }
1696  assert(!module->hasAttr(dedupGroupAttrName) &&
1697  "unexpected existing use of temporary dedup group attribute");
1698  if (!groups.empty())
1699  module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1700  }
1701 
1702  // Calculate module information parallelly.
1703  auto result = mlir::failableParallelForEach(
1704  context, llvm::seq(modules.size()), [&](unsigned idx) {
1705  auto module = modules[idx];
1706  // If the module is marked with NoDedup, just skip it.
1707  if (AnnotationSet::hasAnnotation(module, noDedupClass))
1708  return success();
1709 
1710  // Only dedup extmodule's with defname.
1711  if (auto ext = dyn_cast<FExtModuleOp>(*module);
1712  ext && !ext.getDefname().has_value())
1713  return success();
1714 
1715  StructuralHasher hasher(hasherConstants);
1716  // Calculate the hash of the module and referred module names.
1717  moduleInfos[idx] = hasher.getModuleInfo(module);
1718  return success();
1719  });
1720 
1721  if (result.failed())
1722  return signalPassFailure();
1723 
1724  for (auto [i, module] : llvm::enumerate(modules)) {
1725  auto moduleName = module.getModuleNameAttr();
1726  auto &maybeModuleInfo = moduleInfos[i];
1727  // If the hash was not calculated, we need to skip it.
1728  if (!maybeModuleInfo) {
1729  // We record it in the dedup map to help detect errors when the user
1730  // marks the module as both NoDedup and MustDedup. We do not record this
1731  // module in the hasher to make sure no other module dedups "into" this
1732  // one.
1733  dedupMap[moduleName] = moduleName;
1734  continue;
1735  }
1736 
1737  auto &moduleInfo = maybeModuleInfo.value();
1738 
1739  // Replace module names referred in the module with new names.
1740  for (auto &referredModule : moduleInfo.referredModuleNames)
1741  referredModule = dedupMap[referredModule];
1742 
1743  // Check if there a module with the same hash.
1744  auto it = moduleInfoToModule.find(moduleInfo);
1745  if (it != moduleInfoToModule.end()) {
1746  auto original = cast<FModuleLike>(it->second);
1747  auto originalName = original.getModuleNameAttr();
1748 
1749  // If the current module is public, and the original is private, we
1750  // want to dedup the private module into the public one.
1751  if (!canRemoveModule(module)) {
1752  // If both modules are public, then we can't dedup anything.
1753  if (!canRemoveModule(original))
1754  continue;
1755  // Swap the canonical module in the dedup map.
1756  for (auto &[originalName, dedupedName] : dedupMap)
1757  if (dedupedName == originalName)
1758  dedupedName = moduleName;
1759  // Update the module hash table to point to the new original, so all
1760  // future modules dedup with the new canonical module.
1761  it->second = module;
1762  // Swap the locals.
1763  std::swap(originalName, moduleName);
1764  std::swap(original, module);
1765  }
1766 
1767  // Record the group ID of the other module.
1768  dedupMap[moduleName] = originalName;
1769  deduper.dedup(original, module);
1770  ++erasedModules;
1771  anythingChanged = true;
1772  continue;
1773  }
1774  // Any module not deduplicated must be recorded.
1775  deduper.record(module);
1776  // Add the module to a new dedup group.
1777  dedupMap[moduleName] = moduleName;
1778  // Record the module info.
1779  moduleInfoToModule[std::move(moduleInfo)] = module;
1780  }
1781 
1782  // This part verifies that all modules marked by "MustDedup" have been
1783  // properly deduped with each other. For this check to succeed, all modules
1784  // have to been deduped to the same module. It is possible that a module was
1785  // deduped with the wrong thing.
1786 
1787  auto failed = false;
1788  // This parses the module name out of a target string.
1789  auto parseModule = [&](Attribute path) -> StringAttr {
1790  // Each module is listed as a target "~Circuit|Module" which we have to
1791  // parse.
1792  auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1793  return StringAttr::get(context, rhs);
1794  };
1795  // This gets the name of the module which the current module was deduped
1796  // with. If the named module isn't in the map, then we didn't encounter it
1797  // in the circuit.
1798  auto getLead = [&](StringAttr module) -> StringAttr {
1799  auto it = dedupMap.find(module);
1800  if (it == dedupMap.end()) {
1801  auto diag = emitError(circuit.getLoc(),
1802  "MustDeduplicateAnnotation references module ")
1803  << module << " which does not exist";
1804  failed = true;
1805  return nullptr;
1806  }
1807  return it->second;
1808  };
1809 
1810  AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1811  if (!annotation.isClass(mustDedupAnnoClass))
1812  return false;
1813  auto modules = annotation.getMember<ArrayAttr>("modules");
1814  if (!modules) {
1815  emitError(circuit.getLoc(),
1816  "MustDeduplicateAnnotation missing \"modules\" member");
1817  failed = true;
1818  return false;
1819  }
1820  // Empty module list has nothing to process.
1821  if (modules.empty())
1822  return true;
1823  // Get the first element.
1824  auto firstModule = parseModule(modules[0]);
1825  auto firstLead = getLead(firstModule);
1826  if (!firstLead)
1827  return false;
1828  // Verify that the remaining elements are all the same as the first.
1829  for (auto attr : modules.getValue().drop_front()) {
1830  auto nextModule = parseModule(attr);
1831  auto nextLead = getLead(nextModule);
1832  if (!nextLead)
1833  return false;
1834  if (firstLead != nextLead) {
1835  auto diag = emitError(circuit.getLoc(), "module ")
1836  << nextModule << " not deduplicated with " << firstModule;
1837  auto a = instanceGraph.lookup(firstLead)->getModule();
1838  auto b = instanceGraph.lookup(nextLead)->getModule();
1839  equiv.check(diag, a, b);
1840  failed = true;
1841  return false;
1842  }
1843  }
1844  return true;
1845  });
1846  if (failed)
1847  return signalPassFailure();
1848 
1849  // Remove all dedup group attributes, they only exist during this pass.
1850  for (auto module : circuit.getOps<FModuleLike>())
1851  module->removeDiscardableAttr(dedupGroupAttrName);
1852 
1853  // Walk all the modules and fixup the instance operation to return the
1854  // correct type. We delay this fixup until the end because doing it early
1855  // can block the deduplication of the parent modules.
1856  fixupAllModules(instanceGraph);
1857 
1858  markAnalysesPreserved<NLATable>();
1859  if (!anythingChanged)
1860  markAllAnalysesPreserved();
1861  }
1862 };
1863 } // end anonymous namespace
1864 
1865 std::unique_ptr<mlir::Pass> circt::firrtl::createDedupPass() {
1866  return std::make_unique<DedupPass>();
1867 }
assert(baseType &&"element must be base type")
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition: Dedup.cpp:875
void fixupObjectOp(ObjectOp objectOp, ClassType newClassType)
This fixes up ObjectOps when the signature of their ClassOp changes.
Definition: Dedup.cpp:1517
llvm::raw_ostream & printHex(llvm::raw_ostream &stream, ArrayRef< uint8_t > bytes)
Definition: Dedup.cpp:73
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition: Dedup.cpp:51
void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition: Dedup.cpp:1524
llvm::raw_ostream & printHash(llvm::raw_ostream &stream, llvm::SHA256 &data)
Definition: Dedup.cpp:79
void fixupAllModules(InstanceGraph &instanceGraph)
This is the root method to fixup module references when a module changes.
Definition: Dedup.cpp:1550
bool fixupClassOp(ClassOp classOp)
This fixes up ClassLikes with ClassType ports, when the classes have deduped.
Definition: Dedup.cpp:1459
static void mergeRegions(Region *region1, Region *region2)
Definition: HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition: NLATable.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
ClassType getInstanceTypeForClassLike(ClassLike classOp)
Definition: FIRRTLOps.cpp:1835
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the the InnerSymAttr has a symbol on the field specified.
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:301
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:1865
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:25
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition: FIRRTLOps.h:108
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
bool operator==(uint64_t a, const FVInt &b)
Definition: FVInt.h:650
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition: Utils.h:32
MLIRContext * context
Definition: Dedup.cpp:1426
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition: Dedup.cpp:1434
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition: Dedup.cpp:1005
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition: Dedup.cpp:1114
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition: Dedup.cpp:1295
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition: Dedup.cpp:1021
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition: Dedup.cpp:1087
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition: Dedup.cpp:979
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition: Dedup.cpp:1193
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition: Dedup.cpp:1418
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition: Dedup.cpp:944
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition: Dedup.cpp:1185
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition: Dedup.cpp:998
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition: Dedup.cpp:1095
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition: Dedup.cpp:1323
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either a operation or a port on an operation.
Definition: Dedup.cpp:1272
StringAttr nonLocalString
Definition: Dedup.cpp:1444
SymbolTable & symbolTable
Definition: Dedup.cpp:1428
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition: Dedup.cpp:1050
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition: Dedup.cpp:1385
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition: Dedup.cpp:991
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition: Dedup.cpp:1448
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition: Dedup.cpp:1201
InstanceGraph & instanceGraph
Definition: Dedup.cpp:1427
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition: Dedup.cpp:1404
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition: Dedup.cpp:1437
StringAttr classString
Definition: Dedup.cpp:1445
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition: Dedup.cpp:1243
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition: Dedup.cpp:928
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition: Dedup.cpp:1124
DenseMap< StringAttr, StringAttr > RenameMap
Definition: Dedup.cpp:926
DenseMap< Attribute, Attribute > nlaCache
Definition: Dedup.cpp:1441
const hw::InnerSymbolTable & a
Definition: Dedup.cpp:409
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition: Dedup.cpp:406
const hw::InnerSymbolTable & b
Definition: Dedup.cpp:410
This class is for reporting differences between two modules which should have been deduplicated.
Definition: Dedup.cpp:388
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:862
std::string prettyPrint(Attribute attr)
Definition: Dedup.cpp:413
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition: Dedup.cpp:695
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition: Dedup.cpp:484
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition: Dedup.cpp:458
StringAttr noDedupClass
Definition: Dedup.cpp:857
StringAttr dedupGroupAttrName
Definition: Dedup.cpp:859
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition: Dedup.cpp:607
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition: Dedup.cpp:562
StringAttr portDirectionsAttr
Definition: Dedup.cpp:855
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition: Dedup.cpp:429
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition: Dedup.cpp:716
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition: Dedup.cpp:389
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition: Dedup.cpp:582
InstanceGraph & instanceGraph
Definition: Dedup.cpp:863
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition: Dedup.cpp:803
std::vector< StringAttr > referredModuleNames
Definition: Dedup.cpp:98
std::array< uint8_t, 32 > structuralHash
Definition: Dedup.cpp:96
This struct contains constant string attributes shared across different threads.
Definition: Dedup.cpp:108
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:141
StructuralHasherSharedConstants(MLIRContext *context)
Definition: Dedup.cpp:109
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition: Dedup.cpp:238
void update(Type type)
Definition: Dedup.cpp:218
void update(const void *pointer)
Definition: Dedup.cpp:196
DenseMap< void *, unsigned > idTable
Definition: Dedup.cpp:361
void update(Operation *op)
Definition: Dedup.cpp:333
llvm::SHA256 sha
Definition: Dedup.cpp:376
ModuleInfo getModuleInfo(FModuleLike module)
Definition: Dedup.cpp:148
void update(size_t value)
Definition: Dedup.cpp:201
unsigned getInnerSymID(StringAttr name)
Definition: Dedup.cpp:172
void update(BundleType type)
Definition: Dedup.cpp:209
unsigned getID(void *object)
Definition: Dedup.cpp:155
void update(OpResult result)
Definition: Dedup.cpp:224
void update(OpOperand &operand)
Definition: Dedup.cpp:179
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition: Dedup.cpp:145
void update(Region *region)
Definition: Dedup.cpp:325
void update(Block *block)
Definition: Dedup.cpp:314
std::vector< StringAttr > referredModuleNames
Definition: Dedup.cpp:369
DenseMap< StringAttr, unsigned > innerSymIDTable
Definition: Dedup.cpp:365
void update(TypeID typeID)
Definition: Dedup.cpp:206
const StructuralHasherSharedConstants & constants
Definition: Dedup.cpp:372
unsigned finalizeID(void *object)
Definition: Dedup.cpp:163
void update(mlir::OperationName name)
Definition: Dedup.cpp:308
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
Definition: Dedup.cpp:1605
static ModuleInfo getTombstoneKey()
Definition: Dedup.cpp:1611
static unsigned getHashValue(const ModuleInfo &val)
Definition: Dedup.cpp:1617
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)
Definition: Dedup.cpp:1629