CIRCT  19.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements FIRRTL module deduplication.
10 //
11 //===----------------------------------------------------------------------===//
12 
23 #include "circt/Support/LLVM.h"
24 #include "mlir/IR/IRMapping.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Threading.h"
27 #include "mlir/Pass/Pass.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/DenseMapInfo.h"
31 #include "llvm/ADT/DepthFirstIterator.h"
32 #include "llvm/ADT/PostOrderIterator.h"
33 #include "llvm/ADT/SmallPtrSet.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Format.h"
36 #include "llvm/Support/SHA256.h"
37 
38 namespace circt {
39 namespace firrtl {
40 #define GEN_PASS_DEF_DEDUP
41 #include "circt/Dialect/FIRRTL/Passes.h.inc"
42 } // namespace firrtl
43 } // namespace circt
44 
45 using namespace circt;
46 using namespace firrtl;
47 using hw::InnerRefAttr;
48 
49 //===----------------------------------------------------------------------===//
50 // Hashing
51 //===----------------------------------------------------------------------===//
52 
53 llvm::raw_ostream &printHex(llvm::raw_ostream &stream,
54  ArrayRef<uint8_t> bytes) {
55  // Print the hash on a single line.
56  return stream << format_bytes(bytes, std::nullopt, 32) << "\n";
57 }
58 
59 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
60  return printHex(stream, data.result());
61 }
62 
63 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, std::string data) {
64  ArrayRef<uint8_t> bytes(reinterpret_cast<const uint8_t *>(data.c_str()),
65  data.length());
66  return printHex(stream, bytes);
67 }
68 
69 // This struct contains information to determine module module uniqueness. A
70 // first element is a structural hash of the module, and the second element is
71 // an array which tracks module names encountered in the walk. Since module
72 // names could be replaced during dedup, it's necessary to keep names up-to-date
73 // before actually combining them into structural hashes.
74 struct ModuleInfo {
75  // SHA256 hash.
76  std::array<uint8_t, 32> structuralHash;
77  // Module names referred by instance op in the module.
78  mlir::ArrayAttr referredModuleNames;
79 };
80 
/// What an inner symbol points at, for hashing purposes: the
/// order-of-appearance index of the operation or value that carries the
/// symbol, and the field ID within it. Both fields default to 0 so that a
/// default-constructed target is never indeterminate.
struct SymbolTarget {
  uint64_t index = 0;
  uint64_t fieldID = 0;
};
85 
86 /// This struct contains constant string attributes shared across different
87 /// threads.
89  explicit StructuralHasherSharedConstants(MLIRContext *context) {
90  portTypesAttr = StringAttr::get(context, "portTypes");
91  moduleNameAttr = StringAttr::get(context, "moduleName");
92  innerSymAttr = StringAttr::get(context, "inner_sym");
93  portSymsAttr = StringAttr::get(context, "portSyms");
94  portNamesAttr = StringAttr::get(context, "portNames");
95  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
96  nonessentialAttributes.insert(StringAttr::get(context, "name"));
97  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
98  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
99  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
100  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
101  };
102 
103  // This is a cached "portTypes" string attr.
104  StringAttr portTypesAttr;
105 
106  // This is a cached "moduleName" string attr.
107  StringAttr moduleNameAttr;
108 
109  // This is a cached "inner_sym" string attr.
110  StringAttr innerSymAttr;
111 
112  // This is a cached "portSyms" string attr.
113  StringAttr portSymsAttr;
114 
115  // This is a cached "portNames" string attr.
116  StringAttr portNamesAttr;
117 
118  // This is a set of every attribute we should ignore.
119  DenseSet<Attribute> nonessentialAttributes;
120 };
121 
124  : constants(constants){};
125 
126  std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>
127  getHashAndModuleNames(FModuleLike module) {
128  update(&(*module));
129  auto hash = sha.final();
130  return {hash, referredModuleNames};
131  }
132 
133 private:
134  void update(const void *pointer) {
135  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
136  sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
137  }
138 
139  void update(size_t value) {
140  auto *addr = reinterpret_cast<const uint8_t *>(&value);
141  sha.update(ArrayRef<uint8_t>(addr, sizeof value));
142  }
143 
144  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
145 
146  // NOLINTNEXTLINE(misc-no-recursion)
147  void update(BundleType type) {
148  update(type.getTypeID());
149  for (auto &element : type.getElements()) {
150  update(element.isFlip);
151  update(element.type);
152  }
153  }
154 
155  // NOLINTNEXTLINE(misc-no-recursion)
156  void update(Type type) {
157  if (auto bundle = type_dyn_cast<BundleType>(type))
158  return update(bundle);
159  update(type.getAsOpaquePointer());
160  }
161 
162  void record(void *address) {
163  auto size = indices.size();
164  indices[address] = size;
165  }
166 
167  void update(BlockArgument arg) { record(arg.getAsOpaquePointer()); }
168 
169  void update(OpResult result) {
170  record(result.getAsOpaquePointer());
171 
172  // Like instance ops, don't use object ops' result types since they might be
173  // replaced by dedup. Record the class names and lazily combine their hashes
174  // using the same mechanism as instances and modules.
175  if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
176  referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
177  return;
178  }
179 
180  update(result.getType());
181  }
182 
183  void update(OpOperand &operand) {
184  // We hash the value's index as it apears in the block.
185  auto it = indices.find(operand.get().getAsOpaquePointer());
186  assert(it != indices.end() && "op should have been previously hashed");
187  update(it->second);
188  }
189 
190  void update(Operation *op, hw::InnerSymAttr attr) {
191  for (auto props : attr)
192  innerSymTargets[props.getName()] =
193  SymbolTarget{indices[op], props.getFieldID()};
194  }
195 
196  void update(Value value, hw::InnerSymAttr attr) {
197  for (auto props : attr)
198  innerSymTargets[props.getName()] =
199  SymbolTarget{indices[value.getAsOpaquePointer()], props.getFieldID()};
200  }
201 
202  void update(const SymbolTarget &target) {
203  update(target.index);
204  update(target.fieldID);
205  }
206 
207  void update(InnerRefAttr attr) {
208  // We hash the value's index as it apears in the block.
209  auto it = innerSymTargets.find(attr.getName());
210  assert(it != innerSymTargets.end() &&
211  "inner symbol should have been previously hashed");
212  update(attr.getTypeID());
213  update(it->second);
214  }
215 
216  /// Hash the top level attribute dictionary of the operation. This function
217  /// has special handling for inner symbols, ports, and referenced modules.
218  void update(Operation *op, DictionaryAttr dict) {
219  for (auto namedAttr : dict) {
220  auto name = namedAttr.getName();
221  auto value = namedAttr.getValue();
222  // Skip names and annotations, except in certain cases.
223  // Names of ports are load bearing for classes, so we do hash those.
224  bool isClassPortNames =
225  isa<ClassLike>(op) && name == constants.portNamesAttr;
226  if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
227  continue;
228 
229  // Hash the port types.
230  if (name == constants.portTypesAttr) {
231  auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
232  for (auto type : portTypes)
233  update(type);
234  continue;
235  }
236 
237  // Special case the InnerSymbols to ignore the symbol names.
238  if (name == constants.portSymsAttr) {
239  if (op->getNumRegions() != 1)
240  continue;
241  auto &region = op->getRegion(0);
242  if (region.getBlocks().empty())
243  continue;
244  auto *block = &region.front();
245  auto syms = cast<ArrayAttr>(value).getAsRange<hw::InnerSymAttr>();
246  if (syms.empty())
247  continue;
248  for (auto [arg, sym] : llvm::zip_equal(block->getArguments(), syms))
249  update(arg, sym);
250  continue;
251  }
252  if (name == constants.innerSymAttr) {
253  auto innerSym = cast<hw::InnerSymAttr>(value);
254  update(op, innerSym);
255  continue;
256  }
257 
258  // For instance op, don't use `moduleName` attributes since they might be
259  // replaced by dedup. Record the names and lazily combine their hashes.
260  // It is assumed that module names are hashed only through instance ops;
261  // it could cause suboptimal results if there was other operation that
262  // refers to module names through essential attributes.
263  if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
264  referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
265  continue;
266  }
267 
268  // Hash the interned pointer.
269  update(name.getAsOpaquePointer());
270 
271  // TODO: properly handle DistinctAttr, including its use in paths.
272  // See https://github.com/llvm/circt/issues/6583.
273  if (isa<DistinctAttr>(value))
274  continue;
275 
276  // If this is an symbol reference, we need to perform name erasure.
277  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value))
278  update(innerRef);
279  else
280  update(value.getAsOpaquePointer());
281  }
282  }
283 
284  void update(Block &block) {
285  // Hash the block arguments.
286  for (auto arg : block.getArguments())
287  update(arg);
288  // Hash the operations in the block.
289  for (auto &op : block)
290  update(&op);
291  }
292 
293  void update(mlir::OperationName name) {
294  // Operation names are interned.
295  update(name.getAsOpaquePointer());
296  }
297 
298  // NOLINTNEXTLINE(misc-no-recursion)
299  void update(Operation *op) {
300  record(op);
301  update(op->getName());
302  update(op, op->getAttrDictionary());
303  // Hash the operands.
304  for (auto &operand : op->getOpOperands())
305  update(operand);
306  // Hash the regions. We need to make sure an empty region doesn't hash the
307  // same as no region, so we include the number of regions.
308  update(op->getNumRegions());
309  for (auto &region : op->getRegions())
310  for (auto &block : region.getBlocks())
311  update(block);
312  // Record any op results.
313  for (auto result : op->getResults())
314  update(result);
315  }
316 
317  // Every operation and value is assigned a unique id based on their order of
318  // appearance
319  DenseMap<void *, unsigned> indices;
320 
321  // Every value is assigned a unique id based on their order of appearance.
322  DenseMap<StringAttr, SymbolTarget> innerSymTargets;
323 
324  // This keeps track of module names in the order of the appearance.
325  SmallVector<mlir::StringAttr> referredModuleNames;
326 
327  // String constants.
329 
330  // This is the actual running hash calculation. This is a stateful element
331  // that should be reinitialized after each hash is produced.
332  llvm::SHA256 sha;
333 };
334 
335 //===----------------------------------------------------------------------===//
336 // Equivalence
337 //===----------------------------------------------------------------------===//
338 
339 /// This class is for reporting differences between two modules which should
340 /// have been deduplicated.
341 struct Equivalence {
342  Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
343  : instanceGraph(instanceGraph) {
344  noDedupClass = StringAttr::get(context, noDedupAnnoClass);
345  dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
346  portDirectionsAttr = StringAttr::get(context, "portDirections");
347  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
348  nonessentialAttributes.insert(StringAttr::get(context, "name"));
349  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
350  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
351  nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
352  nonessentialAttributes.insert(StringAttr::get(context, "portSyms"));
353  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
354  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
355  nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
356  }
357 
358  struct ModuleData {
359  ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
360  : a(a), b(b) {}
361  IRMapping map;
362  const hw::InnerSymbolTable &a;
363  const hw::InnerSymbolTable &b;
364  };
365 
366  std::string prettyPrint(Attribute attr) {
367  SmallString<64> buffer;
368  llvm::raw_svector_ostream os(buffer);
369  if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
370  os << "0x";
371  if (integerAttr.getType().isSignlessInteger())
372  integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
373  else
374  integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
375 
376  } else
377  os << attr;
378  return std::string(buffer);
379  }
380 
381  // NOLINTNEXTLINE(misc-no-recursion)
382  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
383  Operation *a, BundleType aType, Operation *b,
384  BundleType bType) {
385  if (aType.getNumElements() != bType.getNumElements()) {
386  diag.attachNote(a->getLoc())
387  << message << " bundle type has different number of elements";
388  diag.attachNote(b->getLoc()) << "second operation here";
389  return failure();
390  }
391 
392  for (auto elementPair :
393  llvm::zip(aType.getElements(), bType.getElements())) {
394  auto aElement = std::get<0>(elementPair);
395  auto bElement = std::get<1>(elementPair);
396  if (aElement.isFlip != bElement.isFlip) {
397  diag.attachNote(a->getLoc()) << message << " bundle element "
398  << aElement.name << " flip does not match";
399  diag.attachNote(b->getLoc()) << "second operation here";
400  return failure();
401  }
402 
403  if (failed(check(diag,
404  "bundle element \'" + aElement.name.getValue() + "'", a,
405  aElement.type, b, bElement.type)))
406  return failure();
407  }
408  return success();
409  }
410 
411  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
412  Operation *a, Type aType, Operation *b, Type bType) {
413  if (aType == bType)
414  return success();
415  if (auto aBundleType = type_dyn_cast<BundleType>(aType))
416  if (auto bBundleType = type_dyn_cast<BundleType>(bType))
417  return check(diag, message, a, aBundleType, b, bBundleType);
418  if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
419  aType != bType) {
420  diag.attachNote(a->getLoc())
421  << message << ", has a RefType with a different base type "
422  << type_cast<RefType>(aType).getType()
423  << " in the same position of the two modules marked as 'must dedup'. "
424  "(This may be due to Grand Central Taps or Views being different "
425  "between the two modules.)";
426  diag.attachNote(b->getLoc())
427  << "the second module has a different base type "
428  << type_cast<RefType>(bType).getType();
429  return failure();
430  }
431  diag.attachNote(a->getLoc())
432  << message << " types don't match, first type is " << aType;
433  diag.attachNote(b->getLoc()) << "second type is " << bType;
434  return failure();
435  }
436 
437  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
438  Block &aBlock, Operation *b, Block &bBlock) {
439 
440  // Block argument types.
441  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
442  auto portNo = 0;
443  auto emitMissingPort = [&](Value existsVal, Operation *opExists,
444  Operation *opDoesNotExist) {
445  StringRef portName;
446  auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
447  if (portNames)
448  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
449  portName = portNameAttr.getValue();
450  if (type_isa<RefType>(existsVal.getType())) {
451  diag.attachNote(opExists->getLoc())
452  << " contains a RefType port named '" + portName +
453  "' that only exists in one of the modules (can be due to "
454  "difference in Grand Central Tap or View of two modules "
455  "marked with must dedup)";
456  diag.attachNote(opDoesNotExist->getLoc())
457  << "second module to be deduped that does not have the RefType "
458  "port";
459  } else {
460  diag.attachNote(opExists->getLoc())
461  << "port '" + portName + "' only exists in one of the modules";
462  diag.attachNote(opDoesNotExist->getLoc())
463  << "second module to be deduped that does not have the port";
464  }
465  return failure();
466  };
467 
468  for (auto argPair :
469  llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
470  auto &aArg = std::get<0>(argPair);
471  auto &bArg = std::get<1>(argPair);
472  if (aArg.has_value() && bArg.has_value()) {
473  // TODO: we should print the port number if there are no port names, but
474  // there are always port names ;).
475  StringRef portName;
476  if (portNames) {
477  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
478  portName = portNameAttr.getValue();
479  }
480  // Assumption here that block arguments correspond to ports.
481  if (failed(check(diag, "module port '" + portName + "'", a,
482  aArg->getType(), b, bArg->getType())))
483  return failure();
484  data.map.map(aArg.value(), bArg.value());
485  portNo++;
486  continue;
487  }
488  if (!aArg.has_value())
489  std::swap(a, b);
490  return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
491  b);
492  }
493 
494  // Blocks operations.
495  auto aIt = aBlock.begin();
496  auto aEnd = aBlock.end();
497  auto bIt = bBlock.begin();
498  auto bEnd = bBlock.end();
499  while (aIt != aEnd && bIt != bEnd)
500  if (failed(check(diag, data, &*aIt++, &*bIt++)))
501  return failure();
502  if (aIt != aEnd) {
503  diag.attachNote(aIt->getLoc()) << "first block has more operations";
504  diag.attachNote(b->getLoc()) << "second block here";
505  return failure();
506  }
507  if (bIt != bEnd) {
508  diag.attachNote(bIt->getLoc()) << "second block has more operations";
509  diag.attachNote(a->getLoc()) << "first block here";
510  return failure();
511  }
512  return success();
513  }
514 
515  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
516  Region &aRegion, Operation *b, Region &bRegion) {
517  auto aIt = aRegion.begin();
518  auto aEnd = aRegion.end();
519  auto bIt = bRegion.begin();
520  auto bEnd = bRegion.end();
521 
522  // Region blocks.
523  while (aIt != aEnd && bIt != bEnd)
524  if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
525  return failure();
526  if (aIt != aEnd || bIt != bEnd) {
527  diag.attachNote(a->getLoc())
528  << "operation regions have different number of blocks";
529  diag.attachNote(b->getLoc()) << "second operation here";
530  return failure();
531  }
532  return success();
533  }
534 
535  LogicalResult check(InFlightDiagnostic &diag, Operation *a,
536  mlir::DenseBoolArrayAttr aAttr, Operation *b,
537  mlir::DenseBoolArrayAttr bAttr) {
538  if (aAttr == bAttr)
539  return success();
540  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
541  for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
542  auto aDirection = aAttr[i];
543  auto bDirection = bAttr[i];
544  if (aDirection != bDirection) {
545  auto &note = diag.attachNote(a->getLoc()) << "module port ";
546  if (portNames)
547  note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
548  else
549  note << i;
550  note << " directions don't match, first direction is '"
551  << direction::toString(aDirection) << "'";
552  diag.attachNote(b->getLoc()) << "second direction is '"
553  << direction::toString(bDirection) << "'";
554  return failure();
555  }
556  }
557  return success();
558  }
559 
560  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
561  DictionaryAttr aDict, Operation *b,
562  DictionaryAttr bDict) {
563  // Fast path.
564  if (aDict == bDict)
565  return success();
566 
567  DenseSet<Attribute> seenAttrs;
568  for (auto namedAttr : aDict) {
569  auto attrName = namedAttr.getName();
570  if (nonessentialAttributes.contains(attrName))
571  continue;
572 
573  auto aAttr = namedAttr.getValue();
574  auto bAttr = bDict.get(attrName);
575  if (!bAttr) {
576  diag.attachNote(a->getLoc())
577  << "second operation is missing attribute " << attrName;
578  diag.attachNote(b->getLoc()) << "second operation here";
579  return diag;
580  }
581 
582  if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
583  auto bRef = cast<hw::InnerRefAttr>(bAttr);
584  auto aRef = cast<hw::InnerRefAttr>(aAttr);
585  // See if they are pointing at the same operation or port.
586  auto aTarget = data.a.lookup(aRef.getName());
587  auto bTarget = data.b.lookup(bRef.getName());
588  if (!aTarget || !bTarget)
589  diag.attachNote(a->getLoc())
590  << "malformed ir, possibly violating use-before-def";
591  auto error = [&]() {
592  diag.attachNote(a->getLoc())
593  << "operations have different targets, first operation has "
594  << aTarget;
595  diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
596  return failure();
597  };
598  if (aTarget.isPort()) {
599  // If they are targeting ports, make sure its the same port number.
600  if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
601  return error();
602  } else {
603  // Otherwise make sure that they are targeting the same operation.
604  if (!bTarget.isOpOnly() ||
605  aTarget.getOp() != data.map.lookup(bTarget.getOp()))
606  return error();
607  }
608  if (aTarget.getField() != bTarget.getField())
609  return error();
610  } else if (attrName == portDirectionsAttr) {
611  // Special handling for the port directions attribute for better
612  // error messages.
613  if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
614  cast<mlir::DenseBoolArrayAttr>(bAttr))))
615  return failure();
616  } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
617  // TODO: properly handle DistinctAttr, including its use in paths.
618  // See https://github.com/llvm/circt/issues/6583
619  } else if (aAttr != bAttr) {
620  diag.attachNote(a->getLoc())
621  << "first operation has attribute '" << attrName.getValue()
622  << "' with value " << prettyPrint(aAttr);
623  diag.attachNote(b->getLoc())
624  << "second operation has value " << prettyPrint(bAttr);
625  return failure();
626  }
627  seenAttrs.insert(attrName);
628  }
629  if (aDict.getValue().size() != bDict.getValue().size()) {
630  for (auto namedAttr : bDict) {
631  auto attrName = namedAttr.getName();
632  // Skip the attribute if we don't care about this particular one or it
633  // is one that is known to be in both dictionaries.
634  if (nonessentialAttributes.contains(attrName) ||
635  seenAttrs.contains(attrName))
636  continue;
637  // We have found an attribute that is only in the second operation.
638  diag.attachNote(a->getLoc())
639  << "first operation is missing attribute " << attrName;
640  diag.attachNote(b->getLoc()) << "second operation here";
641  return failure();
642  }
643  }
644  return success();
645  }
646 
647  // NOLINTNEXTLINE(misc-no-recursion)
648  LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
649  FInstanceLike b) {
650  auto aName = a.getReferencedModuleNameAttr();
651  auto bName = b.getReferencedModuleNameAttr();
652  if (aName == bName)
653  return success();
654 
655  // If the modules instantiate are different we will want to know why the
656  // sub module did not dedupliate. This code recursively checks the child
657  // module.
658  auto aModule = instanceGraph.lookup(aName)->getModule();
659  auto bModule = instanceGraph.lookup(bName)->getModule();
660  // Create a new error for the submodule.
661  diag.attachNote(std::nullopt)
662  << "in instance " << a.getInstanceNameAttr() << " of " << aName
663  << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
664  check(diag, aModule, bModule);
665  return failure();
666  }
667 
668  // NOLINTNEXTLINE(misc-no-recursion)
669  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
670  Operation *b) {
671  // Operation name.
672  if (a->getName() != b->getName()) {
673  diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
674  diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
675  return failure();
676  }
677 
678  // If its an instance operaiton, perform some checking and possibly
679  // recurse.
680  if (auto aInst = dyn_cast<FInstanceLike>(a)) {
681  auto bInst = cast<FInstanceLike>(b);
682  if (failed(check(diag, aInst, bInst)))
683  return failure();
684  }
685 
686  // Operation results.
687  if (a->getNumResults() != b->getNumResults()) {
688  diag.attachNote(a->getLoc())
689  << "operations have different number of results";
690  diag.attachNote(b->getLoc()) << "second operation here";
691  return failure();
692  }
693  for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
694  auto &aValue = std::get<0>(resultPair);
695  auto &bValue = std::get<1>(resultPair);
696  if (failed(check(diag, "operation result", a, aValue.getType(), b,
697  bValue.getType())))
698  return failure();
699  data.map.map(aValue, bValue);
700  }
701 
702  // Operations operands.
703  if (a->getNumOperands() != b->getNumOperands()) {
704  diag.attachNote(a->getLoc())
705  << "operations have different number of operands";
706  diag.attachNote(b->getLoc()) << "second operation here";
707  return failure();
708  }
709  for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
710  auto &aValue = std::get<0>(operandPair);
711  auto &bValue = std::get<1>(operandPair);
712  if (bValue != data.map.lookup(aValue)) {
713  diag.attachNote(a->getLoc())
714  << "operations use different operands, first operand is '"
715  << getFieldName(
716  getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
717  .first
718  << "'";
719  diag.attachNote(b->getLoc())
720  << "second operand is '"
721  << getFieldName(
722  getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
723  .first
724  << "', but should have been '"
725  << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
726  /*lookThroughCasts=*/true))
727  .first
728  << "'";
729  return failure();
730  }
731  }
732  data.map.map(a, b);
733 
734  // Operation regions.
735  if (a->getNumRegions() != b->getNumRegions()) {
736  diag.attachNote(a->getLoc())
737  << "operations have different number of regions";
738  diag.attachNote(b->getLoc()) << "second operation here";
739  return failure();
740  }
741  for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
742  auto &aRegion = std::get<0>(regionPair);
743  auto &bRegion = std::get<1>(regionPair);
744  if (failed(check(diag, data, a, aRegion, b, bRegion)))
745  return failure();
746  }
747 
748  // Operation attributes.
749  if (failed(check(diag, data, a, a->getAttrDictionary(), b,
750  b->getAttrDictionary())))
751  return failure();
752  return success();
753  }
754 
755  // NOLINTNEXTLINE(misc-no-recursion)
756  void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
757  hw::InnerSymbolTable aTable(a);
758  hw::InnerSymbolTable bTable(b);
759  ModuleData data(aTable, bTable);
760  if (AnnotationSet::hasAnnotation(a, noDedupClass)) {
761  diag.attachNote(a->getLoc()) << "module marked NoDedup";
762  return;
763  }
764  if (AnnotationSet::hasAnnotation(b, noDedupClass)) {
765  diag.attachNote(b->getLoc()) << "module marked NoDedup";
766  return;
767  }
768  auto aGroup =
769  dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
770  auto bGroup = dyn_cast_or_null<StringAttr>(
771  b->getAttrOfType<StringAttr>(dedupGroupAttrName));
772  if (aGroup != bGroup) {
773  if (bGroup) {
774  diag.attachNote(b->getLoc())
775  << "module is in dedup group '" << bGroup.str() << "'";
776  } else {
777  diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
778  }
779  if (aGroup) {
780  diag.attachNote(a->getLoc())
781  << "module is in dedup group '" << aGroup.str() << "'";
782  } else {
783  diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
784  }
785  return;
786  }
787  if (failed(check(diag, data, a, b)))
788  return;
789  diag.attachNote(a->getLoc()) << "first module here";
790  diag.attachNote(b->getLoc()) << "second module here";
791  }
792 
793  // This is a cached "portDirections" string attr.
794  StringAttr portDirectionsAttr;
795  // This is a cached "NoDedup" annotation class string attr.
796  StringAttr noDedupClass;
797  // This is a cached string attr for the dedup group attribute.
798  StringAttr dedupGroupAttrName;
799 
800  // This is a set of every attribute we should ignore.
801  DenseSet<Attribute> nonessentialAttributes;
803 };
804 
805 //===----------------------------------------------------------------------===//
806 // Deduplication
807 //===----------------------------------------------------------------------===//
808 
809 // Custom location merging. This only keeps track of 8 annotations from ".fir"
810 // files, and however many annotations come from "real" sources. When
811 // deduplicating, modules tend not to have scala source locators, so we wind
812 // up fusing source locators for a module from every copy being deduped. There
813 // is little value in this (all the modules are identical by definition).
814 static Location mergeLoc(MLIRContext *context, Location to, Location from) {
815  // Unique the set of locations to be fused.
816  llvm::SmallSetVector<Location, 4> decomposedLocs;
817  // only track 8 "fir" locations
818  unsigned seenFIR = 0;
819  for (auto loc : {to, from}) {
820  // If the location is a fused location we decompose it if it has no
821  // metadata or the metadata is the same as the top level metadata.
822  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
823  // UnknownLoc's have already been removed from FusedLocs so we can
824  // simply add all of the internal locations.
825  for (auto loc : fusedLoc.getLocations()) {
826  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
827  if (fileLoc.getFilename().strref().ends_with(".fir")) {
828  ++seenFIR;
829  if (seenFIR > 8)
830  continue;
831  }
832  }
833  decomposedLocs.insert(loc);
834  }
835  continue;
836  }
837 
838  // Might need to skip this fir.
839  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
840  if (fileLoc.getFilename().strref().ends_with(".fir")) {
841  ++seenFIR;
842  if (seenFIR > 8)
843  continue;
844  }
845  }
846  // Otherwise, only add known locations to the set.
847  if (!isa<UnknownLoc>(loc))
848  decomposedLocs.insert(loc);
849  }
850 
851  auto locs = decomposedLocs.getArrayRef();
852 
853  // Handle the simple cases of less than two locations. Ensure the metadata (if
854  // provided) is not dropped.
855  if (locs.empty())
856  return UnknownLoc::get(context);
857  if (locs.size() == 1)
858  return locs.front();
859 
860  return FusedLoc::get(context, locs);
861 }
862 
863 struct Deduper {
864 
865  using RenameMap = DenseMap<StringAttr, StringAttr>;
866 
867  Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable,
868  NLATable *nlaTable, CircuitOp circuit)
869  : context(circuit->getContext()), instanceGraph(instanceGraph),
870  symbolTable(symbolTable), nlaTable(nlaTable),
871  nlaBlock(circuit.getBodyBlock()),
872  nonLocalString(StringAttr::get(context, "circt.nonlocal")),
873  classString(StringAttr::get(context, "class")) {
874  // Populate the NLA cache.
875  for (auto nla : circuit.getOps<hw::HierPathOp>())
876  nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
877  }
878 
879  /// Remove the "fromModule", and replace all references to it with the
880  /// "toModule". Modules should be deduplicated in a bottom-up order. Any
881  /// module which is not deduplicated needs to be recorded with the `record`
882  /// call.
883  void dedup(FModuleLike toModule, FModuleLike fromModule) {
884  // A map of operation (e.g. wires, nodes) names which are changed, which is
885  // used to update NLAs that reference the "fromModule".
886  RenameMap renameMap;
887 
888  // Merge the port locations.
889  SmallVector<Attribute> newLocs;
890  for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
891  fromModule.getPortLocations())) {
892  if (toLoc == fromLoc)
893  newLocs.push_back(toLoc);
894  else
895  newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
896  cast<LocationAttr>(fromLoc)));
897  }
898  toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
899 
900  // Merge the two modules.
901  mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
902 
903  // Rewrite NLAs pathing through these modules to refer to the to module. It
904  // is safe to do this at this point because NLAs cannot be one element long.
905  // This means that all NLAs which require more context cannot be targetting
906  // something in the module it self.
907  if (auto to = dyn_cast<FModuleOp>(*toModule))
908  rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
909  else
910  rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
911  fromModule.getModuleNameAttr());
912 
913  replaceInstances(toModule, fromModule);
914  }
915 
916  /// Record the usages of any NLA's in this module, so that we may update the
917  /// annotation if the parent module is deduped with another module.
918  void record(FModuleLike module) {
919  // Record any annotations on the module.
920  recordAnnotations(module);
921  // Record port annotations.
922  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
923  recordAnnotations(PortAnnoTarget(module, i));
924  // Record any annotations in the module body.
925  module->walk([&](Operation *op) { recordAnnotations(op); });
926  }
927 
928 private:
929  /// Get a cached namespace for a module.
930  hw::InnerSymbolNamespace &getNamespace(Operation *module) {
931  return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
932  .first->second;
933  }
934 
935  /// For a specific annotation target, record all the unique NLAs which
936  /// target it in the `targetMap`.
938  for (auto anno : target.getAnnotations())
939  if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
940  targetMap[nlaRef.getAttr()].insert(target);
941  }
942 
943  /// Record all targets which use an NLA.
944  void recordAnnotations(Operation *op) {
945  // Record annotations.
946  recordAnnotations(OpAnnoTarget(op));
947 
948  // Record port annotations only if this is a mem operation.
949  auto mem = dyn_cast<MemOp>(op);
950  if (!mem)
951  return;
952 
953  // Record port annotations.
954  for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
955  recordAnnotations(PortAnnoTarget(mem, i));
956  }
957 
958  /// This deletes and replaces all instances of the "fromModule" with instances
959  /// of the "toModule".
960  void replaceInstances(FModuleLike toModule, Operation *fromModule) {
961  // Replace all instances of the other module.
962  auto *fromNode =
963  instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
964  auto *toNode = instanceGraph[toModule];
965  auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
966  for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
967  auto inst = oldInstRec->getInstance();
968  if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
969  instOp.setModuleNameAttr(toModuleRef);
970  instOp.setPortNamesAttr(toModule.getPortNamesAttr());
971  } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
972  auto classLike = cast<ClassLike>(*toNode->getModule());
973  ClassType classType = detail::getInstanceTypeForClassLike(classLike);
974  objectOp.getResult().setType(classType);
975  }
976  oldInstRec->getParent()->addInstance(inst, toNode);
977  oldInstRec->erase();
978  }
979  instanceGraph.erase(fromNode);
980  fromModule->erase();
981  }
982 
983  /// Look up the instantiations of the `from` module and create an NLA for each
984  /// one, appending the baseNamepath to each NLA. This is used to add more
985  /// context to an already existing NLA. The `fromModule` is used to indicate
986  /// which module the annotation is coming from before the merge, and will be
987  /// used to create the namepaths.
988  SmallVector<FlatSymbolRefAttr>
989  createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
990  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
991  // Create an attribute array with a placeholder in the first element, where
992  // the root refence of the NLA will be inserted.
993  SmallVector<Attribute> namepath = {nullptr};
994  namepath.append(baseNamepath.begin(), baseNamepath.end());
995 
996  auto loc = fromModule->getLoc();
997  auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
998  SmallVector<FlatSymbolRefAttr> nlas;
999  for (auto *instanceRecord : fromNode->uses()) {
1000  auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1001  auto inst = instanceRecord->getInstance();
1002  namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1003  auto arrayAttr = ArrayAttr::get(context, namepath);
1004  // Check the NLA cache to see if we already have this NLA.
1005  auto &cacheEntry = nlaCache[arrayAttr];
1006  if (!cacheEntry) {
1007  auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1008  loc, "nla", arrayAttr);
1009  // Insert it into the symbol table to get a unique name.
1010  symbolTable.insert(nla);
1011  // Store it in the cache.
1012  cacheEntry = nla.getNameAttr();
1013  nla.setVisibility(vis);
1014  nlaTable->addNLA(nla);
1015  }
1016  auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1017  nlas.push_back(nlaRef);
1018  }
1019  return nlas;
1020  }
1021 
1022  /// Look up the instantiations of this module and create an NLA for each one.
1023  /// This returns an array of symbol references which can be used to reference
1024  /// the NLAs.
1025  SmallVector<FlatSymbolRefAttr>
1026  createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1027  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1028  return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1029  }
1030 
1031  /// Clone the annotation for each NLA in a list. The attribute list should
1032  /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1033  /// should be the index of this field.
1034  void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1035  Annotation anno, ArrayRef<NamedAttribute> attributes,
1036  unsigned nonLocalIndex,
1037  SmallVectorImpl<Annotation> &newAnnotations) {
1038  SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1039  attributes.end());
1040  for (auto &nla : nlas) {
1041  // Add the new annotation.
1042  mutableAttributes[nonLocalIndex].setValue(nla);
1043  auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1044  // The original annotation records if its a subannotation.
1045  anno.setDict(dict);
1046  newAnnotations.push_back(anno);
1047  }
1048  }
1049 
1050  /// This erases the NLA op, and removes the NLA from every module's NLA map,
1051  /// but it does not delete the NLA reference from the target operation's
1052  /// annotations.
1053  void eraseNLA(hw::HierPathOp nla) {
1054  // Erase the NLA from the leaf module's nlaMap.
1055  targetMap.erase(nla.getNameAttr());
1056  nlaTable->erase(nla);
1057  nlaCache.erase(nla.getNamepathAttr());
1058  symbolTable.erase(nla);
1059  }
1060 
1061  /// Process all NLAs referencing the "from" module to point to the "to"
1062  /// module. This is used after merging two modules together.
1063  void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1064  FModuleOp fromModule) {
1065  auto toName = toModule.getNameAttr();
1066  auto fromName = fromModule.getNameAttr();
1067  // Create a copy of the current NLAs. We will be pushing and removing
1068  // NLAs from this op as we go.
1069  auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1070  // Change the NLA to target the toModule.
1071  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1072  // Now we walk the NLA searching for ones that require more context to be
1073  // added.
1074  for (auto nla : moduleNLAs) {
1075  auto elements = nla.getNamepath().getValue();
1076  // If we don't need to add more context, we're done here.
1077  if (nla.root() != toName)
1078  continue;
1079  // Create the replacement NLAs.
1080  SmallVector<Attribute> namepath(elements.begin(), elements.end());
1081  auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1082  // Copy out the targets, because we will be updating the map.
1083  auto &set = targetMap[nla.getSymNameAttr()];
1084  SmallVector<AnnoTarget> targets(set.begin(), set.end());
1085  // Replace the uses of the old NLA with the new NLAs.
1086  for (auto target : targets) {
1087  // We have to clone any annotation which uses the old NLA for each new
1088  // NLA. This array collects the new set of annotations.
1089  SmallVector<Annotation> newAnnotations;
1090  for (auto anno : target.getAnnotations()) {
1091  // Find the non-local field of the annotation.
1092  auto [it, found] = mlir::impl::findAttrSorted(
1093  anno.begin(), anno.end(), nonLocalString);
1094  // If this annotation doesn't use the target NLA, copy it with no
1095  // changes.
1096  if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1097  nla.getSymNameAttr()) {
1098  newAnnotations.push_back(anno);
1099  continue;
1100  }
1101  auto nonLocalIndex = std::distance(anno.begin(), it);
1102  // Clone the annotation and add it to the list of new annotations.
1103  cloneAnnotation(nlaRefs, anno,
1104  ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1105  nonLocalIndex, newAnnotations);
1106  }
1107 
1108  // Apply the new annotations to the operation.
1109  AnnotationSet annotations(newAnnotations, context);
1110  target.setAnnotations(annotations);
1111  // Record that target uses the NLA.
1112  for (auto nla : nlaRefs)
1113  targetMap[nla.getAttr()].insert(target);
1114  }
1115 
1116  // Erase the old NLA and remove it from all breadcrumbs.
1117  eraseNLA(nla);
1118  }
1119  }
1120 
1121  /// Process all the NLAs that the two modules participate in, replacing
1122  /// references to the "from" module with references to the "to" module, and
1123  /// adding more context if necessary.
1124  void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1125  FModuleOp fromModule) {
1126  addAnnotationContext(renameMap, toModule, toModule);
1127  addAnnotationContext(renameMap, toModule, fromModule);
1128  }
1129 
1130  // Update all NLAs which the "from" external module participates in to the
1131  // "toName".
1132  void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1133  StringAttr fromName) {
1134  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1135  }
1136 
  /// Turn the local annotation `anno` into non-local annotations and append
  /// them to `newAnnotations`.
  ///
  /// One NLA is used per instance of `fromModule`; createNLAs either creates
  /// it or reuses it from the NLA cache. Each path ends at the module named
  /// `toModuleName`. The annotation is cloned once per NLA, with its
  /// "circt.nonlocal" field pointing at that NLA, and `to` is recorded as a
  /// user of every NLA.
  ///
  /// `anno` must not already have a "circt.nonlocal" field; this is asserted.
  /// Always returns true.
  /// NOTE(review): if `fromModule` has no instances, no NLAs are produced and
  /// the annotation is dropped entirely. Confirm this is intended.
  bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
                              FModuleLike fromModule, Annotation anno,
                              SmallVectorImpl<Annotation> &newAnnotations) {
    // Build the field list of the new annotation. Fields are kept sorted by
    // name, so the "circt.nonlocal" placeholder goes right before the first
    // field whose name sorts after it.
    SmallVector<NamedAttribute> attributes;
    int nonLocalIndex = -1;
    for (const auto &val : llvm::enumerate(anno)) {
      auto attr = val.value();
      // Compare this field's name against "circt.nonlocal".
      auto compare = attr.getName().compare(nonLocalString);
      assert(compare != 0 && "should not pass non-local annotations here");
      if (compare == 1) {
        // This is the first field that sorts after "circt.nonlocal". Insert a
        // placeholder here and stop scanning; cloneAnnotation later replaces
        // its value with each NLA reference.
        nonLocalIndex = val.index();
        attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
        break;
      }
      // This field sorts before "circt.nonlocal"; copy it and keep scanning.
      attributes.push_back(attr);
    }
    if (nonLocalIndex == -1) {
      // Every field sorts before "circt.nonlocal": append the placeholder.
      nonLocalIndex = attributes.size();
      attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
    } else {
      // Copy the remaining fields, starting with the one that ended the scan
      // above.
      attributes.append(anno.begin() + nonLocalIndex, anno.end());
    }

    // Get one NLA per instance of `fromModule` (identical paths are reused
    // through the NLA cache), and record `to` as a user of each.
    auto nlaRefs = createNLAs(toModuleName, fromModule);
    for (auto nla : nlaRefs)
      targetMap[nla.getAttr()].insert(to);

    // Clone the annotation for each new NLA.
    cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
    return true;
  }
1181 
1182  void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1183  FModuleLike fromModule, AnnotationSet annos,
1184  SmallVectorImpl<Annotation> &newAnnotations,
1185  SmallPtrSetImpl<Attribute> &dontTouches) {
1186  for (auto anno : annos) {
1187  if (anno.isClass(dontTouchAnnoClass)) {
1188  // Remove the nonlocal field of the annotation if it has one, since this
1189  // is a sticky annotation.
1190  anno.removeMember("circt.nonlocal");
1191  auto [it, inserted] = dontTouches.insert(anno.getAttr());
1192  if (inserted)
1193  newAnnotations.push_back(anno);
1194  continue;
1195  }
1196  // If the annotation is already non-local, we add it as is. It is already
1197  // added to the target map.
1198  if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1199  newAnnotations.push_back(anno);
1200  targetMap[nla.getAttr()].insert(to);
1201  continue;
1202  }
1203  // Otherwise make the annotation non-local and add it to the set.
1204  makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1205  newAnnotations);
1206  }
1207  }
1208 
1209  /// Merge the annotations of a specific target, either a operation or a port
1210  /// on an operation.
1211  void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1212  AnnotationSet toAnnos, FModuleLike fromModule,
1213  AnnoTarget from, AnnotationSet fromAnnos) {
1214  // This is a list of all the annotations which will be added to `to`.
1215  SmallVector<Annotation> newAnnotations;
1216 
1217  // We have special case handling of DontTouch to prevent it from being
1218  // turned into a non-local annotation, and to remove duplicates.
1219  llvm::SmallPtrSet<Attribute, 4> dontTouches;
1220 
1221  // Iterate the annotations, transforming most annotations into non-local
1222  // ones.
1223  copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1224  dontTouches);
1225  copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1226  dontTouches);
1227 
1228  // Copy over all the new annotations.
1229  if (!newAnnotations.empty())
1230  to.setAnnotations(AnnotationSet(newAnnotations, context));
1231  }
1232 
1233  /// Merge all annotations and port annotations on two operations.
1234  void mergeAnnotations(FModuleLike toModule, Operation *to,
1235  FModuleLike fromModule, Operation *from) {
1236  // Merge op annotations.
1237  mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1238  OpAnnoTarget(from), AnnotationSet(from));
1239 
1240  // Merge port annotations.
1241  if (toModule == to) {
1242  // Merge module port annotations.
1243  for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1244  mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1245  AnnotationSet::forPort(toModule, i), fromModule,
1246  PortAnnoTarget(fromModule, i),
1247  AnnotationSet::forPort(fromModule, i));
1248  } else if (auto toMem = dyn_cast<MemOp>(to)) {
1249  // Merge memory port annotations.
1250  auto fromMem = cast<MemOp>(from);
1251  for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1252  mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1253  AnnotationSet::forPort(toMem, i), fromModule,
1254  PortAnnoTarget(fromMem, i),
1255  AnnotationSet::forPort(fromMem, i));
1256  }
1257  }
1258 
1259  // Record the symbol name change of the operation or any of its ports when
1260  // merging two operations. The renamed symbols are used to update the
1261  // target of any NLAs. This will add symbols to the "to" operation if needed.
1262  void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1263  Operation *to, FModuleLike fromModule,
1264  Operation *from) {
1265  // If the "from" operation has an inner_sym, we need to make sure the
1266  // "to" operation also has an `inner_sym` and then record the renaming.
1267  if (auto fromSym = getInnerSymName(from)) {
1268  auto toSym =
1269  getOrAddInnerSym(to, [&](auto _) -> hw::InnerSymbolNamespace & {
1270  return getNamespace(toModule);
1271  });
1272  renameMap[fromSym] = toSym;
1273  }
1274 
1275  // If there are no port symbols on the "from" operation, we are done here.
1276  auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSyms");
1277  if (!fromPortSyms || fromPortSyms.empty())
1278  return;
1279  // We have to map each "fromPort" to each "toPort".
1280  auto &moduleNamespace = getNamespace(toModule);
1281  auto portCount = fromPortSyms.size();
1282  auto portNames = to->getAttrOfType<ArrayAttr>("portNames");
1283  auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSyms");
1284 
1285  // Create an array of new port symbols for the "to" operation, copy in the
1286  // old symbols if it has any, create an empty symbol array if it doesn't.
1287  SmallVector<Attribute> newPortSyms;
1288  if (toPortSyms.empty())
1289  newPortSyms.assign(portCount, hw::InnerSymAttr());
1290  else
1291  newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1292 
1293  for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1294  // If this fromPort doesn't have a symbol, move on to the next one.
1295  if (!fromPortSyms[portNo])
1296  continue;
1297  auto fromSym = cast<hw::InnerSymAttr>(fromPortSyms[portNo]);
1298 
1299  // If this toPort doesn't have a symbol, assign one.
1300  hw::InnerSymAttr toSym;
1301  if (!newPortSyms[portNo]) {
1302  // Get a reasonable base name for the port.
1303  StringRef symName = "inner_sym";
1304  if (portNames)
1305  symName = cast<StringAttr>(portNames[portNo]).getValue();
1306  // Create the symbol and store it into the array.
1307  toSym = hw::InnerSymAttr::get(
1308  StringAttr::get(context, moduleNamespace.newName(symName)));
1309  newPortSyms[portNo] = toSym;
1310  } else
1311  toSym = cast<hw::InnerSymAttr>(newPortSyms[portNo]);
1312 
1313  // Record the renaming.
1314  renameMap[fromSym.getSymName()] = toSym.getSymName();
1315  }
1316 
1317  // Commit the new symbol attribute.
1318  cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1319  }
1320 
1321  /// Recursively merge two operations.
1322  // NOLINTNEXTLINE(misc-no-recursion)
1323  void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1324  FModuleLike fromModule, Operation *from) {
1325  // Merge the operation locations.
1326  if (to->getLoc() != from->getLoc())
1327  to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1328 
1329  // Recurse into any regions.
1330  for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1331  mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1332  std::get<1>(regions));
1333 
1334  // Record any inner_sym renamings that happened.
1335  recordSymRenames(renameMap, toModule, to, fromModule, from);
1336 
1337  // Merge the annotations.
1338  mergeAnnotations(toModule, to, fromModule, from);
1339  }
1340 
1341  /// Recursively merge two blocks.
1342  void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1343  FModuleLike fromModule, Block &fromBlock) {
1344  // Merge the block locations.
1345  for (auto [toArg, fromArg] :
1346  llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1347  if (toArg.getLoc() != fromArg.getLoc())
1348  toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1349 
1350  for (auto ops : llvm::zip(toBlock, fromBlock))
1351  mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1352  &std::get<1>(ops));
1353  }
1354 
1355  // Recursively merge two regions.
1356  void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1357  Region &toRegion, FModuleLike fromModule,
1358  Region &fromRegion) {
1359  for (auto blocks : llvm::zip(toRegion, fromRegion))
1360  mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1361  std::get<1>(blocks));
1362  }
1363 
  /// The MLIR context, used to build attributes and locations.
  MLIRContext *context;
  /// The circuit's symbol table. New NLAs are inserted through it, which
  /// makes their names unique, and erased NLAs are removed from it.
  SymbolTable &symbolTable;

  /// Cached NLA table analysis. It is updated as NLAs are created, renamed
  /// and erased.
  NLATable *nlaTable = nullptr;

  /// All new NLAs are inserted at the beginning of this block (the circuit
  /// body).
  Block *nlaBlock;

  // Maps the symbol name of an NLA to the annotation targets (operations and
  // ports) that use it.
  DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;

  // A cache to avoid creating duplicate NLAs. It maps the namepath ArrayAttr
  // of an NLA to the symbol name of the NLA that has that path.
  DenseMap<Attribute, Attribute> nlaCache;

  // Cached "circt.nonlocal" and "class" strings, for faster comparisons and
  // attribute building.
  // NOTE(review): `classString` is not referenced in the code visible here.
  StringAttr nonLocalString;
  StringAttr classString;

  /// Per-module inner symbol namespaces, created lazily by getNamespace().
  DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
};
1388 
1389 //===----------------------------------------------------------------------===//
1390 // Fixup
1391 //===----------------------------------------------------------------------===//
1392 
1393 /// This fixes up ClassLikes with ClassType ports, when the classes have
1394 /// deduped. For each ClassType port, if the object reference being assigned is
1395 /// a different type, update the port type. Returns true if the ClassOp was
1396 /// updated and the associated ObjectOps should be updated.
1397 bool fixupClassOp(ClassOp classOp) {
1398  // New port type attributes, if necessary.
1399  SmallVector<Attribute> newPortTypes;
1400  bool anyDifferences = false;
1401 
1402  // Check each port.
1403  for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1404  // Check if this port is a ClassType. If not, save the original type
1405  // attribute in case we need to update port types.
1406  auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1407  if (!portClassType) {
1408  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1409  continue;
1410  }
1411 
1412  // Check if this port is assigned a reference of a different ClassType.
1413  Type newPortClassType;
1414  BlockArgument portArg = classOp.getArgument(i);
1415  for (auto &use : portArg.getUses()) {
1416  if (auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1417  Type sourceType = propassign.getSrc().getType();
1418  if (propassign.getDest() == use.get() && sourceType != portClassType) {
1419  // Double check that all references are the same new type.
1420  if (newPortClassType) {
1421  assert(newPortClassType == sourceType &&
1422  "expected all references to be of the same type");
1423  continue;
1424  }
1425 
1426  newPortClassType = sourceType;
1427  }
1428  }
1429  }
1430 
1431  // If there was no difference, save the original type attribute in case we
1432  // need to update port types and move along.
1433  if (!newPortClassType) {
1434  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1435  continue;
1436  }
1437 
1438  // The port type changed, so update the block argument, save the new port
1439  // type attribute, and indicate there was a difference.
1440  classOp.getArgument(i).setType(newPortClassType);
1441  newPortTypes.push_back(TypeAttr::get(newPortClassType));
1442  anyDifferences = true;
1443  }
1444 
1445  // If necessary, update port types.
1446  if (anyDifferences)
1447  classOp.setPortTypes(newPortTypes);
1448 
1449  return anyDifferences;
1450 }
1451 
1452 /// This fixes up ObjectOps when the signature of their ClassOp changes. This
1453 /// amounts to updating the ObjectOp result type to match the newly updated
1454 /// ClassOp type.
1455 void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) {
1456  objectOp.getResult().setType(newClassType);
1457 }
1458 
1459 /// This fixes up connects when the field names of a bundle type changes. It
1460 /// finds all fields which were previously bulk connected and legalizes it
1461 /// into a connect for each field.
1462 void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1463  // If the types already match we can emit a connect.
1464  auto dstType = dst.getType();
1465  auto srcType = src.getType();
1466  if (dstType == srcType) {
1467  emitConnect(builder, dst, src);
1468  return;
1469  }
1470  // It must be a bundle type and the field name has changed. We have to
1471  // manually decompose the bulk connect into a connect for each field.
1472  auto dstBundle = type_cast<BundleType>(dstType);
1473  auto srcBundle = type_cast<BundleType>(srcType);
1474  for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1475  auto dstField = builder.create<SubfieldOp>(dst, i);
1476  auto srcField = builder.create<SubfieldOp>(src, i);
1477  if (dstBundle.getElement(i).isFlip) {
1478  std::swap(srcBundle, dstBundle);
1479  std::swap(srcField, dstField);
1480  }
1481  fixupConnect(builder, dstField, srcField);
1482  }
1483 }
1484 
1485 /// This is the root method to fixup module references when a module changes.
1486 /// It matches all the results of "to" module with the results of the "from"
1487 /// module.
1488 void fixupAllModules(InstanceGraph &instanceGraph) {
1489  for (auto *node : instanceGraph) {
1490  auto module = cast<FModuleLike>(*node->getModule());
1491 
1492  // Handle class declarations here.
1493  bool shouldFixupObjects = false;
1494  auto classOp = dyn_cast<ClassOp>(module.getOperation());
1495  if (classOp)
1496  shouldFixupObjects = fixupClassOp(classOp);
1497 
1498  for (auto *instRec : node->uses()) {
1499  // Handle object instantiations here.
1500  if (classOp) {
1501  if (shouldFixupObjects) {
1502  fixupObjectOp(instRec->getInstance<ObjectOp>(),
1503  classOp.getInstanceType());
1504  }
1505  continue;
1506  }
1507 
1508  auto inst = instRec->getInstance<InstanceOp>();
1509  // Only handle module instantiations here.
1510  if (!inst)
1511  continue;
1512  ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1513  builder.setInsertionPointAfter(inst);
1514  for (size_t i = 0, e = getNumPorts(module); i < e; ++i) {
1515  auto result = inst.getResult(i);
1516  auto newType = module.getPortType(i);
1517  auto oldType = result.getType();
1518  // If the type has not changed, we don't have to fix up anything.
1519  if (newType == oldType)
1520  continue;
1521  // If the type changed we transform it back to the old type with an
1522  // intermediate wire.
1523  auto wire =
1524  builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1525  result.replaceAllUsesWith(wire);
1526  result.setType(newType);
1527  if (inst.getPortDirection(i) == Direction::Out)
1528  fixupConnect(builder, wire, result);
1529  else
1530  fixupConnect(builder, result, wire);
1531  }
1532  }
1533  }
1534 }
1535 
1536 namespace llvm {
1537 /// A DenseMapInfo implementation for `ModuleInfo` that is a pair of
1538 /// llvm::SHA256 hashes, which are represented as std::array<uint8_t, 32>, and
1539 /// an array of string attributes. This allows us to create a DenseMap with
1540 /// `ModuleInfo` as keys.
1541 template <>
1542 struct DenseMapInfo<ModuleInfo> {
1543  static inline ModuleInfo getEmptyKey() {
1544  std::array<uint8_t, 32> key;
1545  std::fill(key.begin(), key.end(), ~0);
1546  return {key, DenseMapInfo<mlir::ArrayAttr>::getEmptyKey()};
1547  }
1548 
1549  static inline ModuleInfo getTombstoneKey() {
1550  std::array<uint8_t, 32> key;
1551  std::fill(key.begin(), key.end(), ~0 - 1);
1552  return {key, DenseMapInfo<mlir::ArrayAttr>::getTombstoneKey()};
1553  }
1554 
1555  static unsigned getHashValue(const ModuleInfo &val) {
1556  // We assume SHA256 is already a good hash and just truncate down to the
1557  // number of bytes we need for DenseMap.
1558  unsigned hash;
1559  std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));
1560 
1561  // Combine module names.
1562  return llvm::hash_combine(hash, val.referredModuleNames);
1563  }
1564 
1565  static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
1566  return lhs.structuralHash == rhs.structuralHash &&
1567  lhs.referredModuleNames == rhs.referredModuleNames;
1568  }
1569 };
1570 } // namespace llvm
1571 
1572 //===----------------------------------------------------------------------===//
1573 // DedupPass
1574 //===----------------------------------------------------------------------===//
1575 
1576 namespace {
1577 class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1578  void runOnOperation() override {
1579  auto *context = &getContext();
1580  auto circuit = getOperation();
1581  auto &instanceGraph = getAnalysis<InstanceGraph>();
1582  auto *nlaTable = &getAnalysis<NLATable>();
1583  auto &symbolTable = getAnalysis<SymbolTable>();
1584  Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1585  Equivalence equiv(context, instanceGraph);
1586  auto anythingChanged = false;
1587 
1588  // Modules annotated with this should not be considered for deduplication.
1589  auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1590 
1591  // Only modules within the same group may be deduplicated.
1592  auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1593 
1594  // A map of all the module moduleInfo that we have calculated so far.
1595  llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1596 
1597  // We track the name of the module that each module is deduped into, so that
1598  // we can make sure all modules which are marked "must dedup" with each
1599  // other were all deduped to the same module.
1600  DenseMap<Attribute, StringAttr> dedupMap;
1601 
1602  // We must iterate the modules from the bottom up so that we can properly
1603  // deduplicate the modules. We copy the list of modules into a vector first
1604  // to avoid iterator invalidation while we mutate the instance graph.
1605  SmallVector<FModuleLike, 0> modules(
1606  llvm::map_range(llvm::post_order(&instanceGraph), [](auto *node) {
1607  return cast<FModuleLike>(*node->getModule());
1608  }));
1609 
1610  SmallVector<std::optional<
1611  std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>>>
1612  hashesAndModuleNames(modules.size());
1613  StructuralHasherSharedConstants hasherConstants(&getContext());
1614 
1615  // Attribute name used to store dedup_group for this pass.
1616  auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1617 
1618  // Move dedup group annotations to attributes on the module.
1619  // This results in the desired behavior (included in hash),
1620  // and avoids unnecessary processing of these as annotations
1621  // that need to be tracked, made non-local, so on.
1622  for (auto module : modules) {
1623  llvm::SmallSetVector<StringAttr, 1> groups;
1625  module, [&groups, dedupGroupClass](Annotation annotation) {
1626  if (annotation.getClassAttr() != dedupGroupClass)
1627  return false;
1628  groups.insert(annotation.getMember<StringAttr>("group"));
1629  return true;
1630  });
1631  if (groups.size() > 1) {
1632  module.emitError("module belongs to multiple dedup groups: ") << groups;
1633  return signalPassFailure();
1634  }
1635  assert(!module->hasAttr(dedupGroupAttrName) &&
1636  "unexpected existing use of temporary dedup group attribute");
1637  if (!groups.empty())
1638  module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1639  }
1640 
1641  // Calculate module information parallelly.
1642  auto result = mlir::failableParallelForEach(
1643  context, llvm::seq(modules.size()), [&](unsigned idx) {
1644  auto module = modules[idx];
1645  // If the module is marked with NoDedup, just skip it.
1646  if (AnnotationSet::hasAnnotation(module, noDedupClass))
1647  return success();
1648 
1649  // Only dedup extmodule's with defname.
1650  if (auto ext = dyn_cast<FExtModuleOp>(*module);
1651  ext && !ext.getDefname().has_value())
1652  return success();
1653 
1654  // If module has symbol (name) that must be preserved even if unused,
1655  // skip it. All symbol uses must be supported, which is not true if
1656  // non-private.
1657  if (!module.isPrivate() ||
1658  (!module.canDiscardOnUseEmpty() && !isa<ClassLike>(*module))) {
1659  return success();
1660  }
1661 
1662  StructuralHasher hasher(hasherConstants);
1663  // Calculate the hash of the module and referred module names.
1664  hashesAndModuleNames[idx] = hasher.getHashAndModuleNames(module);
1665  return success();
1666  });
1667 
1668  if (result.failed())
1669  return signalPassFailure();
1670 
1671  for (auto [i, module] : llvm::enumerate(modules)) {
1672  auto moduleName = module.getModuleNameAttr();
1673  auto &hashAndModuleNamesOpt = hashesAndModuleNames[i];
1674  // If the hash was not calculated, we need to skip it.
1675  if (!hashAndModuleNamesOpt) {
1676  // We record it in the dedup map to help detect errors when the user
1677  // marks the module as both NoDedup and MustDedup. We do not record this
1678  // module in the hasher to make sure no other module dedups "into" this
1679  // one.
1680  dedupMap[moduleName] = moduleName;
1681  continue;
1682  }
1683 
1684  // Replace module names referred in the module with new names.
1685  SmallVector<mlir::Attribute> names;
1686  for (auto oldModuleName : hashAndModuleNamesOpt->second) {
1687  auto newModuleName = dedupMap[oldModuleName];
1688  names.push_back(newModuleName);
1689  }
1690 
1691  // Create a module info to use it as a key.
1692  ModuleInfo moduleInfo{hashAndModuleNamesOpt->first,
1693  mlir::ArrayAttr::get(module.getContext(), names)};
1694 
1695  // Check if there a module with the same hash.
1696  auto it = moduleInfoToModule.find(moduleInfo);
1697  if (it != moduleInfoToModule.end()) {
1698  auto original = cast<FModuleLike>(it->second);
1699  // Record the group ID of the other module.
1700  dedupMap[moduleName] = original.getModuleNameAttr();
1701  deduper.dedup(original, module);
1702  ++erasedModules;
1703  anythingChanged = true;
1704  continue;
1705  }
1706  // Any module not deduplicated must be recorded.
1707  deduper.record(module);
1708  // Add the module to a new dedup group.
1709  dedupMap[moduleName] = moduleName;
1710  // Record the module info.
1711  moduleInfoToModule[moduleInfo] = module;
1712  }
1713 
1714  // This part verifies that all modules marked by "MustDedup" have been
1715  // properly deduped with each other. For this check to succeed, all modules
1716  // have to been deduped to the same module. It is possible that a module was
1717  // deduped with the wrong thing.
1718 
1719  auto failed = false;
1720  // This parses the module name out of a target string.
1721  auto parseModule = [&](Attribute path) -> StringAttr {
1722  // Each module is listed as a target "~Circuit|Module" which we have to
1723  // parse.
1724  auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1725  return StringAttr::get(context, rhs);
1726  };
1727  // This gets the name of the module which the current module was deduped
1728  // with. If the named module isn't in the map, then we didn't encounter it
1729  // in the circuit.
1730  auto getLead = [&](StringAttr module) -> StringAttr {
1731  auto it = dedupMap.find(module);
1732  if (it == dedupMap.end()) {
1733  auto diag = emitError(circuit.getLoc(),
1734  "MustDeduplicateAnnotation references module ")
1735  << module << " which does not exist";
1736  failed = true;
1737  return nullptr;
1738  }
1739  return it->second;
1740  };
1741 
1742  AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1743  if (!annotation.isClass(mustDedupAnnoClass))
1744  return false;
1745  auto modules = annotation.getMember<ArrayAttr>("modules");
1746  if (!modules) {
1747  emitError(circuit.getLoc(),
1748  "MustDeduplicateAnnotation missing \"modules\" member");
1749  failed = true;
1750  return false;
1751  }
1752  // Empty module list has nothing to process.
1753  if (modules.empty())
1754  return true;
1755  // Get the first element.
1756  auto firstModule = parseModule(modules[0]);
1757  auto firstLead = getLead(firstModule);
1758  if (!firstLead)
1759  return false;
1760  // Verify that the remaining elements are all the same as the first.
1761  for (auto attr : modules.getValue().drop_front()) {
1762  auto nextModule = parseModule(attr);
1763  auto nextLead = getLead(nextModule);
1764  if (!nextLead)
1765  return false;
1766  if (firstLead != nextLead) {
1767  auto diag = emitError(circuit.getLoc(), "module ")
1768  << nextModule << " not deduplicated with " << firstModule;
1769  auto a = instanceGraph.lookup(firstLead)->getModule();
1770  auto b = instanceGraph.lookup(nextLead)->getModule();
1771  equiv.check(diag, a, b);
1772  failed = true;
1773  return false;
1774  }
1775  }
1776  return true;
1777  });
1778  if (failed)
1779  return signalPassFailure();
1780 
1781  // Remove all dedup group attributes, they only exist during this pass.
1782  for (auto module : circuit.getOps<FModuleLike>())
1783  module->removeDiscardableAttr(dedupGroupAttrName);
1784 
1785  // Walk all the modules and fixup the instance operation to return the
1786  // correct type. We delay this fixup until the end because doing it early
1787  // can block the deduplication of the parent modules.
1788  fixupAllModules(instanceGraph);
1789 
1790  markAnalysesPreserved<NLATable>();
1791  if (!anythingChanged)
1792  markAllAnalysesPreserved();
1793  }
1794 };
1795 } // end anonymous namespace
1796 
1797 std::unique_ptr<mlir::Pass> circt::firrtl::createDedupPass() {
1798  return std::make_unique<DedupPass>();
1799 }
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
Definition: FIRRTLOps.cpp:3929
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition: Dedup.cpp:814
void fixupObjectOp(ObjectOp objectOp, ClassType newClassType)
This fixes up ObjectOps when the signature of their ClassOp changes.
Definition: Dedup.cpp:1455
llvm::raw_ostream & printHex(llvm::raw_ostream &stream, ArrayRef< uint8_t > bytes)
Definition: Dedup.cpp:53
void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition: Dedup.cpp:1462
llvm::raw_ostream & printHash(llvm::raw_ostream &stream, llvm::SHA256 &data)
Definition: Dedup.cpp:59
void fixupAllModules(InstanceGraph &instanceGraph)
This is the root method to fixup module references when a module changes.
Definition: Dedup.cpp:1488
bool fixupClassOp(ClassOp classOp)
This fixes up ClassLikes with ClassType ports, when the classes have been deduped.
Definition: Dedup.cpp:1397
static void mergeRegions(Region *region1, Region *region2)
Definition: HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition: NLATable.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
ClassType getInstanceTypeForClassLike(ClassLike classOp)
Definition: FIRRTLOps.cpp:1803
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the InnerSymAttr has a symbol on the specified field.
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:298
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:1797
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:24
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition: FIRRTLOps.h:108
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
MLIRContext * context
Definition: Dedup.cpp:1364
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition: Dedup.cpp:1372
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition: Dedup.cpp:944
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition: Dedup.cpp:1053
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition: Dedup.cpp:1234
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition: Dedup.cpp:960
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition: Dedup.cpp:1026
void record(FModuleLike module)
Record the usages of any NLAs in this module, so that we may update the annotation if the parent mod...
Definition: Dedup.cpp:918
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition: Dedup.cpp:1132
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition: Dedup.cpp:1356
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition: Dedup.cpp:883
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition: Dedup.cpp:1124
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition: Dedup.cpp:937
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition: Dedup.cpp:1034
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition: Dedup.cpp:1262
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
Definition: Dedup.cpp:1211
StringAttr nonLocalString
Definition: Dedup.cpp:1382
SymbolTable & symbolTable
Definition: Dedup.cpp:1366
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition: Dedup.cpp:989
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition: Dedup.cpp:1323
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition: Dedup.cpp:930
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition: Dedup.cpp:1386
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition: Dedup.cpp:1140
InstanceGraph & instanceGraph
Definition: Dedup.cpp:1365
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition: Dedup.cpp:1342
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition: Dedup.cpp:1375
StringAttr classString
Definition: Dedup.cpp:1383
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition: Dedup.cpp:1182
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition: Dedup.cpp:867
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition: Dedup.cpp:1063
DenseMap< StringAttr, StringAttr > RenameMap
Definition: Dedup.cpp:865
DenseMap< Attribute, Attribute > nlaCache
Definition: Dedup.cpp:1379
const hw::InnerSymbolTable & a
Definition: Dedup.cpp:362
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition: Dedup.cpp:359
const hw::InnerSymbolTable & b
Definition: Dedup.cpp:363
This class is for reporting differences between two modules which should have been deduplicated.
Definition: Dedup.cpp:341
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:801
std::string prettyPrint(Attribute attr)
Definition: Dedup.cpp:366
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition: Dedup.cpp:648
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition: Dedup.cpp:437
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition: Dedup.cpp:411
StringAttr noDedupClass
Definition: Dedup.cpp:796
StringAttr dedupGroupAttrName
Definition: Dedup.cpp:798
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition: Dedup.cpp:560
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition: Dedup.cpp:515
StringAttr portDirectionsAttr
Definition: Dedup.cpp:794
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition: Dedup.cpp:382
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition: Dedup.cpp:669
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition: Dedup.cpp:342
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition: Dedup.cpp:535
InstanceGraph & instanceGraph
Definition: Dedup.cpp:802
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition: Dedup.cpp:756
std::array< uint8_t, 32 > structuralHash
Definition: Dedup.cpp:76
mlir::ArrayAttr referredModuleNames
Definition: Dedup.cpp:78
This struct contains constant string attributes shared across different threads.
Definition: Dedup.cpp:88
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:119
StructuralHasherSharedConstants(MLIRContext *context)
Definition: Dedup.cpp:89
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition: Dedup.cpp:218
void update(BlockArgument arg)
Definition: Dedup.cpp:167
void update(Type type)
Definition: Dedup.cpp:156
void update(const void *pointer)
Definition: Dedup.cpp:134
void update(InnerRefAttr attr)
Definition: Dedup.cpp:207
void update(Operation *op)
Definition: Dedup.cpp:299
void record(void *address)
Definition: Dedup.cpp:162
void update(const SymbolTarget &target)
Definition: Dedup.cpp:202
std::pair< std::array< uint8_t, 32 >, SmallVector< StringAttr > > getHashAndModuleNames(FModuleLike module)
Definition: Dedup.cpp:127
llvm::SHA256 sha
Definition: Dedup.cpp:332
void update(Operation *op, hw::InnerSymAttr attr)
Definition: Dedup.cpp:190
DenseMap< StringAttr, SymbolTarget > innerSymTargets
Definition: Dedup.cpp:322
void update(size_t value)
Definition: Dedup.cpp:139
SmallVector< mlir::StringAttr > referredModuleNames
Definition: Dedup.cpp:325
void update(Block &block)
Definition: Dedup.cpp:284
void update(BundleType type)
Definition: Dedup.cpp:147
void update(OpResult result)
Definition: Dedup.cpp:169
void update(OpOperand &operand)
Definition: Dedup.cpp:183
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition: Dedup.cpp:123
void update(TypeID typeID)
Definition: Dedup.cpp:144
const StructuralHasherSharedConstants & constants
Definition: Dedup.cpp:328
void update(Value value, hw::InnerSymAttr attr)
Definition: Dedup.cpp:196
DenseMap< void *, unsigned > indices
Definition: Dedup.cpp:319
void update(mlir::OperationName name)
Definition: Dedup.cpp:293
uint64_t fieldID
Definition: Dedup.cpp:83
uint64_t index
Definition: Dedup.cpp:82
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
Definition: Dedup.cpp:1543
static ModuleInfo getTombstoneKey()
Definition: Dedup.cpp:1549
static unsigned getHashValue(const ModuleInfo &val)
Definition: Dedup.cpp:1555
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)
Definition: Dedup.cpp:1565