CIRCT  19.0.0git
Dedup.cpp
Go to the documentation of this file.
1 //===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements FIRRTL module deduplication.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
24 #include "circt/Support/LLVM.h"
25 #include "mlir/IR/IRMapping.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/Threading.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/DenseMapInfo.h"
31 #include "llvm/ADT/DepthFirstIterator.h"
32 #include "llvm/ADT/PostOrderIterator.h"
33 #include "llvm/ADT/SmallPtrSet.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Format.h"
36 #include "llvm/Support/SHA256.h"
37 
38 using namespace circt;
39 using namespace firrtl;
40 using hw::InnerRefAttr;
41 
42 //===----------------------------------------------------------------------===//
43 // Hashing
44 //===----------------------------------------------------------------------===//
45 
46 llvm::raw_ostream &printHex(llvm::raw_ostream &stream,
47  ArrayRef<uint8_t> bytes) {
48  // Print the hash on a single line.
49  return stream << format_bytes(bytes, std::nullopt, 32) << "\n";
50 }
51 
52 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
53  return printHex(stream, data.result());
54 }
55 
56 llvm::raw_ostream &printHash(llvm::raw_ostream &stream, std::string data) {
57  ArrayRef<uint8_t> bytes(reinterpret_cast<const uint8_t *>(data.c_str()),
58  data.length());
59  return printHex(stream, bytes);
60 }
61 
62 // This struct contains information to determine module module uniqueness. A
63 // first element is a structural hash of the module, and the second element is
64 // an array which tracks module names encountered in the walk. Since module
65 // names could be replaced during dedup, it's necessary to keep names up-to-date
66 // before actually combining them into structural hashes.
67 struct ModuleInfo {
68  // SHA256 hash.
69  std::array<uint8_t, 32> structuralHash;
70  // Module names referred by instance op in the module.
71  mlir::ArrayAttr referredModuleNames;
72 };
73 
/// The position an inner symbol resolves to while hashing: the sequential
/// index of the operation or value that carries it, plus the field ID within
/// that operation or value. Both members default to zero so a
/// default-constructed target never holds indeterminate values.
struct SymbolTarget {
  uint64_t index = 0;
  uint64_t fieldID = 0;
};
78 
79 /// This struct contains constant string attributes shared across different
80 /// threads.
82  explicit StructuralHasherSharedConstants(MLIRContext *context) {
83  portTypesAttr = StringAttr::get(context, "portTypes");
84  moduleNameAttr = StringAttr::get(context, "moduleName");
85  innerSymAttr = StringAttr::get(context, "inner_sym");
86  portSymsAttr = StringAttr::get(context, "portSyms");
87  portNamesAttr = StringAttr::get(context, "portNames");
88  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
89  nonessentialAttributes.insert(StringAttr::get(context, "name"));
90  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
91  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
92  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
93  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
94  };
95 
96  // This is a cached "portTypes" string attr.
97  StringAttr portTypesAttr;
98 
99  // This is a cached "moduleName" string attr.
100  StringAttr moduleNameAttr;
101 
102  // This is a cached "inner_sym" string attr.
103  StringAttr innerSymAttr;
104 
105  // This is a cached "portSyms" string attr.
106  StringAttr portSymsAttr;
107 
108  // This is a cached "portNames" string attr.
109  StringAttr portNamesAttr;
110 
111  // This is a set of every attribute we should ignore.
112  DenseSet<Attribute> nonessentialAttributes;
113 };
114 
117  : constants(constants){};
118 
119  std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>
120  getHashAndModuleNames(FModuleLike module) {
121  update(&(*module));
122  auto hash = sha.final();
123  return {hash, referredModuleNames};
124  }
125 
126 private:
127  void update(const void *pointer) {
128  auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
129  sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
130  }
131 
132  void update(size_t value) {
133  auto *addr = reinterpret_cast<const uint8_t *>(&value);
134  sha.update(ArrayRef<uint8_t>(addr, sizeof value));
135  }
136 
137  void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
138 
139  // NOLINTNEXTLINE(misc-no-recursion)
140  void update(BundleType type) {
141  update(type.getTypeID());
142  for (auto &element : type.getElements()) {
143  update(element.isFlip);
144  update(element.type);
145  }
146  }
147 
148  // NOLINTNEXTLINE(misc-no-recursion)
149  void update(Type type) {
150  if (auto bundle = type_dyn_cast<BundleType>(type))
151  return update(bundle);
152  update(type.getAsOpaquePointer());
153  }
154 
155  void record(void *address) {
156  auto size = indices.size();
157  indices[address] = size;
158  }
159 
160  void update(BlockArgument arg) { record(arg.getAsOpaquePointer()); }
161 
162  void update(OpResult result) {
163  record(result.getAsOpaquePointer());
164 
165  // Like instance ops, don't use object ops' result types since they might be
166  // replaced by dedup. Record the class names and lazily combine their hashes
167  // using the same mechanism as instances and modules.
168  if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
169  referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
170  return;
171  }
172 
173  update(result.getType());
174  }
175 
176  void update(OpOperand &operand) {
177  // We hash the value's index as it apears in the block.
178  auto it = indices.find(operand.get().getAsOpaquePointer());
179  assert(it != indices.end() && "op should have been previously hashed");
180  update(it->second);
181  }
182 
183  void update(Operation *op, hw::InnerSymAttr attr) {
184  for (auto props : attr)
185  innerSymTargets[props.getName()] =
186  SymbolTarget{indices[op], props.getFieldID()};
187  }
188 
189  void update(Value value, hw::InnerSymAttr attr) {
190  for (auto props : attr)
191  innerSymTargets[props.getName()] =
192  SymbolTarget{indices[value.getAsOpaquePointer()], props.getFieldID()};
193  }
194 
195  void update(const SymbolTarget &target) {
196  update(target.index);
197  update(target.fieldID);
198  }
199 
200  void update(InnerRefAttr attr) {
201  // We hash the value's index as it apears in the block.
202  auto it = innerSymTargets.find(attr.getName());
203  assert(it != innerSymTargets.end() &&
204  "inner symbol should have been previously hashed");
205  update(attr.getTypeID());
206  update(it->second);
207  }
208 
209  /// Hash the top level attribute dictionary of the operation. This function
210  /// has special handling for inner symbols, ports, and referenced modules.
211  void update(Operation *op, DictionaryAttr dict) {
212  for (auto namedAttr : dict) {
213  auto name = namedAttr.getName();
214  auto value = namedAttr.getValue();
215  // Skip names and annotations, except in certain cases.
216  // Names of ports are load bearing for classes, so we do hash those.
217  bool isClassPortNames =
218  isa<ClassLike>(op) && name == constants.portNamesAttr;
219  if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
220  continue;
221 
222  // Hash the port types.
223  if (name == constants.portTypesAttr) {
224  auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
225  for (auto type : portTypes)
226  update(type);
227  continue;
228  }
229 
230  // Special case the InnerSymbols to ignore the symbol names.
231  if (name == constants.portSymsAttr) {
232  if (op->getNumRegions() != 1)
233  continue;
234  auto &region = op->getRegion(0);
235  if (region.getBlocks().empty())
236  continue;
237  auto *block = &region.front();
238  auto syms = cast<ArrayAttr>(value).getAsRange<hw::InnerSymAttr>();
239  if (syms.empty())
240  continue;
241  for (auto [arg, sym] : llvm::zip_equal(block->getArguments(), syms))
242  update(arg, sym);
243  continue;
244  }
245  if (name == constants.innerSymAttr) {
246  auto innerSym = cast<hw::InnerSymAttr>(value);
247  update(op, innerSym);
248  continue;
249  }
250 
251  // For instance op, don't use `moduleName` attributes since they might be
252  // replaced by dedup. Record the names and lazily combine their hashes.
253  // It is assumed that module names are hashed only through instance ops;
254  // it could cause suboptimal results if there was other operation that
255  // refers to module names through essential attributes.
256  if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
257  referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
258  continue;
259  }
260 
261  // Hash the interned pointer.
262  update(name.getAsOpaquePointer());
263 
264  // TODO: properly handle DistinctAttr, including its use in paths.
265  // See https://github.com/llvm/circt/issues/6583.
266  if (isa<DistinctAttr>(value))
267  continue;
268 
269  // If this is an symbol reference, we need to perform name erasure.
270  if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value))
271  update(innerRef);
272  else
273  update(value.getAsOpaquePointer());
274  }
275  }
276 
277  void update(Block &block) {
278  // Hash the block arguments.
279  for (auto arg : block.getArguments())
280  update(arg);
281  // Hash the operations in the block.
282  for (auto &op : block)
283  update(&op);
284  }
285 
286  void update(mlir::OperationName name) {
287  // Operation names are interned.
288  update(name.getAsOpaquePointer());
289  }
290 
291  // NOLINTNEXTLINE(misc-no-recursion)
292  void update(Operation *op) {
293  record(op);
294  update(op->getName());
295  update(op, op->getAttrDictionary());
296  // Hash the operands.
297  for (auto &operand : op->getOpOperands())
298  update(operand);
299  // Hash the regions. We need to make sure an empty region doesn't hash the
300  // same as no region, so we include the number of regions.
301  update(op->getNumRegions());
302  for (auto &region : op->getRegions())
303  for (auto &block : region.getBlocks())
304  update(block);
305  // Record any op results.
306  for (auto result : op->getResults())
307  update(result);
308  }
309 
310  // Every operation and value is assigned a unique id based on their order of
311  // appearance
312  DenseMap<void *, unsigned> indices;
313 
314  // Every value is assigned a unique id based on their order of appearance.
315  DenseMap<StringAttr, SymbolTarget> innerSymTargets;
316 
317  // This keeps track of module names in the order of the appearance.
318  SmallVector<mlir::StringAttr> referredModuleNames;
319 
320  // String constants.
322 
// This is the actual running hash calculation. This is a stateful element
// that should be reinitialized after each hash is produced. The raw-byte
// `update` overloads are what feed data into it.
llvm::SHA256 sha;
326 };
327 
328 //===----------------------------------------------------------------------===//
329 // Equivalence
330 //===----------------------------------------------------------------------===//
331 
332 /// This class is for reporting differences between two modules which should
333 /// have been deduplicated.
334 struct Equivalence {
335  Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
336  : instanceGraph(instanceGraph) {
337  noDedupClass = StringAttr::get(context, noDedupAnnoClass);
338  dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
339  portDirectionsAttr = StringAttr::get(context, "portDirections");
340  nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
341  nonessentialAttributes.insert(StringAttr::get(context, "name"));
342  nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
343  nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
344  nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
345  nonessentialAttributes.insert(StringAttr::get(context, "portSyms"));
346  nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
347  nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
348  nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
349  }
350 
351  struct ModuleData {
352  ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
353  : a(a), b(b) {}
354  IRMapping map;
355  const hw::InnerSymbolTable &a;
356  const hw::InnerSymbolTable &b;
357  };
358 
359  std::string prettyPrint(Attribute attr) {
360  SmallString<64> buffer;
361  llvm::raw_svector_ostream os(buffer);
362  if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
363  os << "0x";
364  if (integerAttr.getType().isSignlessInteger())
365  integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
366  else
367  integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
368 
369  } else
370  os << attr;
371  return std::string(buffer);
372  }
373 
374  // NOLINTNEXTLINE(misc-no-recursion)
375  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
376  Operation *a, BundleType aType, Operation *b,
377  BundleType bType) {
378  if (aType.getNumElements() != bType.getNumElements()) {
379  diag.attachNote(a->getLoc())
380  << message << " bundle type has different number of elements";
381  diag.attachNote(b->getLoc()) << "second operation here";
382  return failure();
383  }
384 
385  for (auto elementPair :
386  llvm::zip(aType.getElements(), bType.getElements())) {
387  auto aElement = std::get<0>(elementPair);
388  auto bElement = std::get<1>(elementPair);
389  if (aElement.isFlip != bElement.isFlip) {
390  diag.attachNote(a->getLoc()) << message << " bundle element "
391  << aElement.name << " flip does not match";
392  diag.attachNote(b->getLoc()) << "second operation here";
393  return failure();
394  }
395 
396  if (failed(check(diag,
397  "bundle element \'" + aElement.name.getValue() + "'", a,
398  aElement.type, b, bElement.type)))
399  return failure();
400  }
401  return success();
402  }
403 
404  LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
405  Operation *a, Type aType, Operation *b, Type bType) {
406  if (aType == bType)
407  return success();
408  if (auto aBundleType = type_dyn_cast<BundleType>(aType))
409  if (auto bBundleType = type_dyn_cast<BundleType>(bType))
410  return check(diag, message, a, aBundleType, b, bBundleType);
411  if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
412  aType != bType) {
413  diag.attachNote(a->getLoc())
414  << message << ", has a RefType with a different base type "
415  << type_cast<RefType>(aType).getType()
416  << " in the same position of the two modules marked as 'must dedup'. "
417  "(This may be due to Grand Central Taps or Views being different "
418  "between the two modules.)";
419  diag.attachNote(b->getLoc())
420  << "the second module has a different base type "
421  << type_cast<RefType>(bType).getType();
422  return failure();
423  }
424  diag.attachNote(a->getLoc())
425  << message << " types don't match, first type is " << aType;
426  diag.attachNote(b->getLoc()) << "second type is " << bType;
427  return failure();
428  }
429 
430  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
431  Block &aBlock, Operation *b, Block &bBlock) {
432 
433  // Block argument types.
434  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
435  auto portNo = 0;
436  auto emitMissingPort = [&](Value existsVal, Operation *opExists,
437  Operation *opDoesNotExist) {
438  StringRef portName;
439  auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
440  if (portNames)
441  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
442  portName = portNameAttr.getValue();
443  if (type_isa<RefType>(existsVal.getType())) {
444  diag.attachNote(opExists->getLoc())
445  << " contains a RefType port named '" + portName +
446  "' that only exists in one of the modules (can be due to "
447  "difference in Grand Central Tap or View of two modules "
448  "marked with must dedup)";
449  diag.attachNote(opDoesNotExist->getLoc())
450  << "second module to be deduped that does not have the RefType "
451  "port";
452  } else {
453  diag.attachNote(opExists->getLoc())
454  << "port '" + portName + "' only exists in one of the modules";
455  diag.attachNote(opDoesNotExist->getLoc())
456  << "second module to be deduped that does not have the port";
457  }
458  return failure();
459  };
460 
461  for (auto argPair :
462  llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
463  auto &aArg = std::get<0>(argPair);
464  auto &bArg = std::get<1>(argPair);
465  if (aArg.has_value() && bArg.has_value()) {
466  // TODO: we should print the port number if there are no port names, but
467  // there are always port names ;).
468  StringRef portName;
469  if (portNames) {
470  if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
471  portName = portNameAttr.getValue();
472  }
473  // Assumption here that block arguments correspond to ports.
474  if (failed(check(diag, "module port '" + portName + "'", a,
475  aArg->getType(), b, bArg->getType())))
476  return failure();
477  data.map.map(aArg.value(), bArg.value());
478  portNo++;
479  continue;
480  }
481  if (!aArg.has_value())
482  std::swap(a, b);
483  return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
484  b);
485  }
486 
487  // Blocks operations.
488  auto aIt = aBlock.begin();
489  auto aEnd = aBlock.end();
490  auto bIt = bBlock.begin();
491  auto bEnd = bBlock.end();
492  while (aIt != aEnd && bIt != bEnd)
493  if (failed(check(diag, data, &*aIt++, &*bIt++)))
494  return failure();
495  if (aIt != aEnd) {
496  diag.attachNote(aIt->getLoc()) << "first block has more operations";
497  diag.attachNote(b->getLoc()) << "second block here";
498  return failure();
499  }
500  if (bIt != bEnd) {
501  diag.attachNote(bIt->getLoc()) << "second block has more operations";
502  diag.attachNote(a->getLoc()) << "first block here";
503  return failure();
504  }
505  return success();
506  }
507 
508  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
509  Region &aRegion, Operation *b, Region &bRegion) {
510  auto aIt = aRegion.begin();
511  auto aEnd = aRegion.end();
512  auto bIt = bRegion.begin();
513  auto bEnd = bRegion.end();
514 
515  // Region blocks.
516  while (aIt != aEnd && bIt != bEnd)
517  if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
518  return failure();
519  if (aIt != aEnd || bIt != bEnd) {
520  diag.attachNote(a->getLoc())
521  << "operation regions have different number of blocks";
522  diag.attachNote(b->getLoc()) << "second operation here";
523  return failure();
524  }
525  return success();
526  }
527 
528  LogicalResult check(InFlightDiagnostic &diag, Operation *a,
529  mlir::DenseBoolArrayAttr aAttr, Operation *b,
530  mlir::DenseBoolArrayAttr bAttr) {
531  if (aAttr == bAttr)
532  return success();
533  auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
534  for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
535  auto aDirection = aAttr[i];
536  auto bDirection = bAttr[i];
537  if (aDirection != bDirection) {
538  auto &note = diag.attachNote(a->getLoc()) << "module port ";
539  if (portNames)
540  note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
541  else
542  note << i;
543  note << " directions don't match, first direction is '"
544  << direction::toString(aDirection) << "'";
545  diag.attachNote(b->getLoc()) << "second direction is '"
546  << direction::toString(bDirection) << "'";
547  return failure();
548  }
549  }
550  return success();
551  }
552 
553  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
554  DictionaryAttr aDict, Operation *b,
555  DictionaryAttr bDict) {
556  // Fast path.
557  if (aDict == bDict)
558  return success();
559 
560  DenseSet<Attribute> seenAttrs;
561  for (auto namedAttr : aDict) {
562  auto attrName = namedAttr.getName();
563  if (nonessentialAttributes.contains(attrName))
564  continue;
565 
566  auto aAttr = namedAttr.getValue();
567  auto bAttr = bDict.get(attrName);
568  if (!bAttr) {
569  diag.attachNote(a->getLoc())
570  << "second operation is missing attribute " << attrName;
571  diag.attachNote(b->getLoc()) << "second operation here";
572  return diag;
573  }
574 
575  if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
576  auto bRef = cast<hw::InnerRefAttr>(bAttr);
577  auto aRef = cast<hw::InnerRefAttr>(aAttr);
578  // See if they are pointing at the same operation or port.
579  auto aTarget = data.a.lookup(aRef.getName());
580  auto bTarget = data.b.lookup(bRef.getName());
581  if (!aTarget || !bTarget)
582  diag.attachNote(a->getLoc())
583  << "malformed ir, possibly violating use-before-def";
584  auto error = [&]() {
585  diag.attachNote(a->getLoc())
586  << "operations have different targets, first operation has "
587  << aTarget;
588  diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
589  return failure();
590  };
591  if (aTarget.isPort()) {
592  // If they are targeting ports, make sure its the same port number.
593  if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
594  return error();
595  } else {
596  // Otherwise make sure that they are targeting the same operation.
597  if (!bTarget.isOpOnly() ||
598  aTarget.getOp() != data.map.lookup(bTarget.getOp()))
599  return error();
600  }
601  if (aTarget.getField() != bTarget.getField())
602  return error();
603  } else if (attrName == portDirectionsAttr) {
604  // Special handling for the port directions attribute for better
605  // error messages.
606  if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
607  cast<mlir::DenseBoolArrayAttr>(bAttr))))
608  return failure();
609  } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
610  // TODO: properly handle DistinctAttr, including its use in paths.
611  // See https://github.com/llvm/circt/issues/6583
612  } else if (aAttr != bAttr) {
613  diag.attachNote(a->getLoc())
614  << "first operation has attribute '" << attrName.getValue()
615  << "' with value " << prettyPrint(aAttr);
616  diag.attachNote(b->getLoc())
617  << "second operation has value " << prettyPrint(bAttr);
618  return failure();
619  }
620  seenAttrs.insert(attrName);
621  }
622  if (aDict.getValue().size() != bDict.getValue().size()) {
623  for (auto namedAttr : bDict) {
624  auto attrName = namedAttr.getName();
625  // Skip the attribute if we don't care about this particular one or it
626  // is one that is known to be in both dictionaries.
627  if (nonessentialAttributes.contains(attrName) ||
628  seenAttrs.contains(attrName))
629  continue;
630  // We have found an attribute that is only in the second operation.
631  diag.attachNote(a->getLoc())
632  << "first operation is missing attribute " << attrName;
633  diag.attachNote(b->getLoc()) << "second operation here";
634  return failure();
635  }
636  }
637  return success();
638  }
639 
640  // NOLINTNEXTLINE(misc-no-recursion)
641  LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
642  FInstanceLike b) {
643  auto aName = a.getReferencedModuleNameAttr();
644  auto bName = b.getReferencedModuleNameAttr();
645  if (aName == bName)
646  return success();
647 
648  // If the modules instantiate are different we will want to know why the
649  // sub module did not dedupliate. This code recursively checks the child
650  // module.
651  auto aModule = instanceGraph.lookup(aName)->getModule();
652  auto bModule = instanceGraph.lookup(bName)->getModule();
653  // Create a new error for the submodule.
654  diag.attachNote(std::nullopt)
655  << "in instance " << a.getInstanceNameAttr() << " of " << aName
656  << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
657  check(diag, aModule, bModule);
658  return failure();
659  }
660 
661  // NOLINTNEXTLINE(misc-no-recursion)
662  LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
663  Operation *b) {
664  // Operation name.
665  if (a->getName() != b->getName()) {
666  diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
667  diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
668  return failure();
669  }
670 
671  // If its an instance operaiton, perform some checking and possibly
672  // recurse.
673  if (auto aInst = dyn_cast<FInstanceLike>(a)) {
674  auto bInst = cast<FInstanceLike>(b);
675  if (failed(check(diag, aInst, bInst)))
676  return failure();
677  }
678 
679  // Operation results.
680  if (a->getNumResults() != b->getNumResults()) {
681  diag.attachNote(a->getLoc())
682  << "operations have different number of results";
683  diag.attachNote(b->getLoc()) << "second operation here";
684  return failure();
685  }
686  for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
687  auto &aValue = std::get<0>(resultPair);
688  auto &bValue = std::get<1>(resultPair);
689  if (failed(check(diag, "operation result", a, aValue.getType(), b,
690  bValue.getType())))
691  return failure();
692  data.map.map(aValue, bValue);
693  }
694 
695  // Operations operands.
696  if (a->getNumOperands() != b->getNumOperands()) {
697  diag.attachNote(a->getLoc())
698  << "operations have different number of operands";
699  diag.attachNote(b->getLoc()) << "second operation here";
700  return failure();
701  }
702  for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
703  auto &aValue = std::get<0>(operandPair);
704  auto &bValue = std::get<1>(operandPair);
705  if (bValue != data.map.lookup(aValue)) {
706  diag.attachNote(a->getLoc())
707  << "operations use different operands, first operand is '"
708  << getFieldName(
709  getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
710  .first
711  << "'";
712  diag.attachNote(b->getLoc())
713  << "second operand is '"
714  << getFieldName(
715  getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
716  .first
717  << "', but should have been '"
718  << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
719  /*lookThroughCasts=*/true))
720  .first
721  << "'";
722  return failure();
723  }
724  }
725  data.map.map(a, b);
726 
727  // Operation regions.
728  if (a->getNumRegions() != b->getNumRegions()) {
729  diag.attachNote(a->getLoc())
730  << "operations have different number of regions";
731  diag.attachNote(b->getLoc()) << "second operation here";
732  return failure();
733  }
734  for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
735  auto &aRegion = std::get<0>(regionPair);
736  auto &bRegion = std::get<1>(regionPair);
737  if (failed(check(diag, data, a, aRegion, b, bRegion)))
738  return failure();
739  }
740 
741  // Operation attributes.
742  if (failed(check(diag, data, a, a->getAttrDictionary(), b,
743  b->getAttrDictionary())))
744  return failure();
745  return success();
746  }
747 
748  // NOLINTNEXTLINE(misc-no-recursion)
749  void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
750  hw::InnerSymbolTable aTable(a);
751  hw::InnerSymbolTable bTable(b);
752  ModuleData data(aTable, bTable);
753  AnnotationSet aAnnos(a);
754  AnnotationSet bAnnos(b);
755  if (aAnnos.hasAnnotation(noDedupClass)) {
756  diag.attachNote(a->getLoc()) << "module marked NoDedup";
757  return;
758  }
759  if (bAnnos.hasAnnotation(noDedupClass)) {
760  diag.attachNote(b->getLoc()) << "module marked NoDedup";
761  return;
762  }
763  auto aGroup =
764  dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
765  auto bGroup = dyn_cast_or_null<StringAttr>(
766  b->getAttrOfType<StringAttr>(dedupGroupAttrName));
767  if (aGroup != bGroup) {
768  if (bGroup) {
769  diag.attachNote(b->getLoc())
770  << "module is in dedup group '" << bGroup.str() << "'";
771  } else {
772  diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
773  }
774  if (aGroup) {
775  diag.attachNote(a->getLoc())
776  << "module is in dedup group '" << aGroup.str() << "'";
777  } else {
778  diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
779  }
780  return;
781  }
782  if (failed(check(diag, data, a, b)))
783  return;
784  diag.attachNote(a->getLoc()) << "first module here";
785  diag.attachNote(b->getLoc()) << "second module here";
786  }
787 
788  // This is a cached "portDirections" string attr.
789  StringAttr portDirectionsAttr;
790  // This is a cached "NoDedup" annotation class string attr.
791  StringAttr noDedupClass;
792  // This is a cached string attr for the dedup group attribute.
793  StringAttr dedupGroupAttrName;
794 
795  // This is a set of every attribute we should ignore.
796  DenseSet<Attribute> nonessentialAttributes;
798 };
799 
800 //===----------------------------------------------------------------------===//
801 // Deduplication
802 //===----------------------------------------------------------------------===//
803 
804 // Custom location merging. This only keeps track of 8 annotations from ".fir"
805 // files, and however many annotations come from "real" sources. When
806 // deduplicating, modules tend not to have scala source locators, so we wind
807 // up fusing source locators for a module from every copy being deduped. There
808 // is little value in this (all the modules are identical by definition).
809 static Location mergeLoc(MLIRContext *context, Location to, Location from) {
810  // Unique the set of locations to be fused.
811  llvm::SmallSetVector<Location, 4> decomposedLocs;
812  // only track 8 "fir" locations
813  unsigned seenFIR = 0;
814  for (auto loc : {to, from}) {
815  // If the location is a fused location we decompose it if it has no
816  // metadata or the metadata is the same as the top level metadata.
817  if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
818  // UnknownLoc's have already been removed from FusedLocs so we can
819  // simply add all of the internal locations.
820  for (auto loc : fusedLoc.getLocations()) {
821  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
822  if (fileLoc.getFilename().strref().ends_with(".fir")) {
823  ++seenFIR;
824  if (seenFIR > 8)
825  continue;
826  }
827  }
828  decomposedLocs.insert(loc);
829  }
830  continue;
831  }
832 
833  // Might need to skip this fir.
834  if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
835  if (fileLoc.getFilename().strref().ends_with(".fir")) {
836  ++seenFIR;
837  if (seenFIR > 8)
838  continue;
839  }
840  }
841  // Otherwise, only add known locations to the set.
842  if (!isa<UnknownLoc>(loc))
843  decomposedLocs.insert(loc);
844  }
845 
846  auto locs = decomposedLocs.getArrayRef();
847 
848  // Handle the simple cases of less than two locations. Ensure the metadata (if
849  // provided) is not dropped.
850  if (locs.empty())
851  return UnknownLoc::get(context);
852  if (locs.size() == 1)
853  return locs.front();
854 
855  return FusedLoc::get(context, locs);
856 }
857 
858 struct Deduper {
859 
860  using RenameMap = DenseMap<StringAttr, StringAttr>;
861 
862  Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable,
863  NLATable *nlaTable, CircuitOp circuit)
864  : context(circuit->getContext()), instanceGraph(instanceGraph),
865  symbolTable(symbolTable), nlaTable(nlaTable),
866  nlaBlock(circuit.getBodyBlock()),
867  nonLocalString(StringAttr::get(context, "circt.nonlocal")),
868  classString(StringAttr::get(context, "class")) {
869  // Populate the NLA cache.
870  for (auto nla : circuit.getOps<hw::HierPathOp>())
871  nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
872  }
873 
874  /// Remove the "fromModule", and replace all references to it with the
875  /// "toModule". Modules should be deduplicated in a bottom-up order. Any
876  /// module which is not deduplicated needs to be recorded with the `record`
877  /// call.
878  void dedup(FModuleLike toModule, FModuleLike fromModule) {
879  // A map of operation (e.g. wires, nodes) names which are changed, which is
880  // used to update NLAs that reference the "fromModule".
881  RenameMap renameMap;
882 
883  // Merge the port locations.
884  SmallVector<Attribute> newLocs;
885  for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
886  fromModule.getPortLocations())) {
887  if (toLoc == fromLoc)
888  newLocs.push_back(toLoc);
889  else
890  newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
891  cast<LocationAttr>(fromLoc)));
892  }
893  toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
894 
895  // Merge the two modules.
896  mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
897 
898  // Rewrite NLAs pathing through these modules to refer to the to module. It
899  // is safe to do this at this point because NLAs cannot be one element long.
900  // This means that all NLAs which require more context cannot be targetting
901  // something in the module it self.
902  if (auto to = dyn_cast<FModuleOp>(*toModule))
903  rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
904  else
905  rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
906  fromModule.getModuleNameAttr());
907 
908  replaceInstances(toModule, fromModule);
909  }
910 
911  /// Record the usages of any NLA's in this module, so that we may update the
912  /// annotation if the parent module is deduped with another module.
913  void record(FModuleLike module) {
914  // Record any annotations on the module.
915  recordAnnotations(module);
916  // Record port annotations.
917  for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
918  recordAnnotations(PortAnnoTarget(module, i));
919  // Record any annotations in the module body.
920  module->walk([&](Operation *op) { recordAnnotations(op); });
921  }
922 
923 private:
924  /// Get a cached namespace for a module.
925  hw::InnerSymbolNamespace &getNamespace(Operation *module) {
926  return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
927  .first->second;
928  }
929 
930  /// For a specific annotation target, record all the unique NLAs which
931  /// target it in the `targetMap`.
933  for (auto anno : target.getAnnotations())
934  if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
935  targetMap[nlaRef.getAttr()].insert(target);
936  }
937 
938  /// Record all targets which use an NLA.
939  void recordAnnotations(Operation *op) {
940  // Record annotations.
941  recordAnnotations(OpAnnoTarget(op));
942 
943  // Record port annotations only if this is a mem operation.
944  auto mem = dyn_cast<MemOp>(op);
945  if (!mem)
946  return;
947 
948  // Record port annotations.
949  for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
950  recordAnnotations(PortAnnoTarget(mem, i));
951  }
952 
953  /// This deletes and replaces all instances of the "fromModule" with instances
954  /// of the "toModule".
955  void replaceInstances(FModuleLike toModule, Operation *fromModule) {
956  // Replace all instances of the other module.
957  auto *fromNode =
958  instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
959  auto *toNode = instanceGraph[toModule];
960  auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
961  for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
962  auto inst = oldInstRec->getInstance();
963  if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
964  instOp.setModuleNameAttr(toModuleRef);
965  instOp.setPortNamesAttr(toModule.getPortNamesAttr());
966  } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
967  auto classLike = cast<ClassLike>(*toNode->getModule());
968  ClassType classType = detail::getInstanceTypeForClassLike(classLike);
969  objectOp.getResult().setType(classType);
970  }
971  oldInstRec->getParent()->addInstance(inst, toNode);
972  oldInstRec->erase();
973  }
974  instanceGraph.erase(fromNode);
975  fromModule->erase();
976  }
977 
978  /// Look up the instantiations of the `from` module and create an NLA for each
979  /// one, appending the baseNamepath to each NLA. This is used to add more
980  /// context to an already existing NLA. The `fromModule` is used to indicate
981  /// which module the annotation is coming from before the merge, and will be
982  /// used to create the namepaths.
983  SmallVector<FlatSymbolRefAttr>
984  createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
985  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
986  // Create an attribute array with a placeholder in the first element, where
987  // the root refence of the NLA will be inserted.
988  SmallVector<Attribute> namepath = {nullptr};
989  namepath.append(baseNamepath.begin(), baseNamepath.end());
990 
991  auto loc = fromModule->getLoc();
992  auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
993  SmallVector<FlatSymbolRefAttr> nlas;
994  for (auto *instanceRecord : fromNode->uses()) {
995  auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
996  auto inst = instanceRecord->getInstance();
997  namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
998  auto arrayAttr = ArrayAttr::get(context, namepath);
999  // Check the NLA cache to see if we already have this NLA.
1000  auto &cacheEntry = nlaCache[arrayAttr];
1001  if (!cacheEntry) {
1002  auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1003  loc, "nla", arrayAttr);
1004  // Insert it into the symbol table to get a unique name.
1005  symbolTable.insert(nla);
1006  // Store it in the cache.
1007  cacheEntry = nla.getNameAttr();
1008  nla.setVisibility(vis);
1009  nlaTable->addNLA(nla);
1010  }
1011  auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1012  nlas.push_back(nlaRef);
1013  }
1014  return nlas;
1015  }
1016 
1017  /// Look up the instantiations of this module and create an NLA for each one.
1018  /// This returns an array of symbol references which can be used to reference
1019  /// the NLAs.
1020  SmallVector<FlatSymbolRefAttr>
1021  createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1022  SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1023  return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1024  }
1025 
1026  /// Clone the annotation for each NLA in a list. The attribute list should
1027  /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1028  /// should be the index of this field.
1029  void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1030  Annotation anno, ArrayRef<NamedAttribute> attributes,
1031  unsigned nonLocalIndex,
1032  SmallVectorImpl<Annotation> &newAnnotations) {
1033  SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1034  attributes.end());
1035  for (auto &nla : nlas) {
1036  // Add the new annotation.
1037  mutableAttributes[nonLocalIndex].setValue(nla);
1038  auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1039  // The original annotation records if its a subannotation.
1040  anno.setDict(dict);
1041  newAnnotations.push_back(anno);
1042  }
1043  }
1044 
1045  /// This erases the NLA op, and removes the NLA from every module's NLA map,
1046  /// but it does not delete the NLA reference from the target operation's
1047  /// annotations.
1048  void eraseNLA(hw::HierPathOp nla) {
1049  // Erase the NLA from the leaf module's nlaMap.
1050  targetMap.erase(nla.getNameAttr());
1051  nlaTable->erase(nla);
1052  nlaCache.erase(nla.getNamepathAttr());
1053  symbolTable.erase(nla);
1054  }
1055 
1056  /// Process all NLAs referencing the "from" module to point to the "to"
1057  /// module. This is used after merging two modules together.
1058  void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1059  FModuleOp fromModule) {
1060  auto toName = toModule.getNameAttr();
1061  auto fromName = fromModule.getNameAttr();
1062  // Create a copy of the current NLAs. We will be pushing and removing
1063  // NLAs from this op as we go.
1064  auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1065  // Change the NLA to target the toModule.
1066  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1067  // Now we walk the NLA searching for ones that require more context to be
1068  // added.
1069  for (auto nla : moduleNLAs) {
1070  auto elements = nla.getNamepath().getValue();
1071  // If we don't need to add more context, we're done here.
1072  if (nla.root() != toName)
1073  continue;
1074  // Create the replacement NLAs.
1075  SmallVector<Attribute> namepath(elements.begin(), elements.end());
1076  auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1077  // Copy out the targets, because we will be updating the map.
1078  auto &set = targetMap[nla.getSymNameAttr()];
1079  SmallVector<AnnoTarget> targets(set.begin(), set.end());
1080  // Replace the uses of the old NLA with the new NLAs.
1081  for (auto target : targets) {
1082  // We have to clone any annotation which uses the old NLA for each new
1083  // NLA. This array collects the new set of annotations.
1084  SmallVector<Annotation> newAnnotations;
1085  for (auto anno : target.getAnnotations()) {
1086  // Find the non-local field of the annotation.
1087  auto [it, found] = mlir::impl::findAttrSorted(
1088  anno.begin(), anno.end(), nonLocalString);
1089  // If this annotation doesn't use the target NLA, copy it with no
1090  // changes.
1091  if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1092  nla.getSymNameAttr()) {
1093  newAnnotations.push_back(anno);
1094  continue;
1095  }
1096  auto nonLocalIndex = std::distance(anno.begin(), it);
1097  // Clone the annotation and add it to the list of new annotations.
1098  cloneAnnotation(nlaRefs, anno,
1099  ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1100  nonLocalIndex, newAnnotations);
1101  }
1102 
1103  // Apply the new annotations to the operation.
1104  AnnotationSet annotations(newAnnotations, context);
1105  target.setAnnotations(annotations);
1106  // Record that target uses the NLA.
1107  for (auto nla : nlaRefs)
1108  targetMap[nla.getAttr()].insert(target);
1109  }
1110 
1111  // Erase the old NLA and remove it from all breadcrumbs.
1112  eraseNLA(nla);
1113  }
1114  }
1115 
1116  /// Process all the NLAs that the two modules participate in, replacing
1117  /// references to the "from" module with references to the "to" module, and
1118  /// adding more context if necessary.
1119  void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1120  FModuleOp fromModule) {
1121  addAnnotationContext(renameMap, toModule, toModule);
1122  addAnnotationContext(renameMap, toModule, fromModule);
1123  }
1124 
1125  // Update all NLAs which the "from" external module participates in to the
1126  // "toName".
1127  void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1128  StringAttr fromName) {
1129  nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1130  }
1131 
1132  /// Take an annotation, and update it to be a non-local annotation. If the
1133  /// annotation is already non-local and has enough context, it will be skipped
1134  /// for now. Return true if the annotation was made non-local.
1135  bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1136  FModuleLike fromModule, Annotation anno,
1137  SmallVectorImpl<Annotation> &newAnnotations) {
1138  // Start constructing a new annotation, pushing a "circt.nonLocal" field
1139  // into the correct spot if its not already a non-local annotation.
1140  SmallVector<NamedAttribute> attributes;
1141  int nonLocalIndex = -1;
1142  for (const auto &val : llvm::enumerate(anno)) {
1143  auto attr = val.value();
1144  // Is this field "circt.nonlocal"?
1145  auto compare = attr.getName().compare(nonLocalString);
1146  assert(compare != 0 && "should not pass non-local annotations here");
1147  if (compare == 1) {
1148  // This annotation definitely does not have "circt.nonlocal" field. Push
1149  // an empty place holder for the non-local annotation.
1150  nonLocalIndex = val.index();
1151  attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1152  break;
1153  }
1154  // Otherwise push the current attribute and keep searching for the
1155  // "circt.nonlocal" field.
1156  attributes.push_back(attr);
1157  }
1158  if (nonLocalIndex == -1) {
1159  // Push an empty "circt.nonlocal" field to the last slot.
1160  nonLocalIndex = attributes.size();
1161  attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1162  } else {
1163  // Copy the remaining annotation fields in.
1164  attributes.append(anno.begin() + nonLocalIndex, anno.end());
1165  }
1166 
1167  // Construct the NLAs if we don't have any yet.
1168  auto nlaRefs = createNLAs(toModuleName, fromModule);
1169  for (auto nla : nlaRefs)
1170  targetMap[nla.getAttr()].insert(to);
1171 
1172  // Clone the annotation for each new NLA.
1173  cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1174  return true;
1175  }
1176 
1177  void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1178  FModuleLike fromModule, AnnotationSet annos,
1179  SmallVectorImpl<Annotation> &newAnnotations,
1180  SmallPtrSetImpl<Attribute> &dontTouches) {
1181  for (auto anno : annos) {
1182  if (anno.isClass(dontTouchAnnoClass)) {
1183  // Remove the nonlocal field of the annotation if it has one, since this
1184  // is a sticky annotation.
1185  anno.removeMember("circt.nonlocal");
1186  auto [it, inserted] = dontTouches.insert(anno.getAttr());
1187  if (inserted)
1188  newAnnotations.push_back(anno);
1189  continue;
1190  }
1191  // If the annotation is already non-local, we add it as is. It is already
1192  // added to the target map.
1193  if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1194  newAnnotations.push_back(anno);
1195  targetMap[nla.getAttr()].insert(to);
1196  continue;
1197  }
1198  // Otherwise make the annotation non-local and add it to the set.
1199  makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1200  newAnnotations);
1201  }
1202  }
1203 
1204  /// Merge the annotations of a specific target, either a operation or a port
1205  /// on an operation.
1206  void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1207  AnnotationSet toAnnos, FModuleLike fromModule,
1208  AnnoTarget from, AnnotationSet fromAnnos) {
1209  // This is a list of all the annotations which will be added to `to`.
1210  SmallVector<Annotation> newAnnotations;
1211 
1212  // We have special case handling of DontTouch to prevent it from being
1213  // turned into a non-local annotation, and to remove duplicates.
1214  llvm::SmallPtrSet<Attribute, 4> dontTouches;
1215 
1216  // Iterate the annotations, transforming most annotations into non-local
1217  // ones.
1218  copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1219  dontTouches);
1220  copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1221  dontTouches);
1222 
1223  // Copy over all the new annotations.
1224  if (!newAnnotations.empty())
1225  to.setAnnotations(AnnotationSet(newAnnotations, context));
1226  }
1227 
1228  /// Merge all annotations and port annotations on two operations.
1229  void mergeAnnotations(FModuleLike toModule, Operation *to,
1230  FModuleLike fromModule, Operation *from) {
1231  // Merge op annotations.
1232  mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1233  OpAnnoTarget(from), AnnotationSet(from));
1234 
1235  // Merge port annotations.
1236  if (toModule == to) {
1237  // Merge module port annotations.
1238  for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1239  mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1240  AnnotationSet::forPort(toModule, i), fromModule,
1241  PortAnnoTarget(fromModule, i),
1242  AnnotationSet::forPort(fromModule, i));
1243  } else if (auto toMem = dyn_cast<MemOp>(to)) {
1244  // Merge memory port annotations.
1245  auto fromMem = cast<MemOp>(from);
1246  for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1247  mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1248  AnnotationSet::forPort(toMem, i), fromModule,
1249  PortAnnoTarget(fromMem, i),
1250  AnnotationSet::forPort(fromMem, i));
1251  }
1252  }
1253 
1254  // Record the symbol name change of the operation or any of its ports when
1255  // merging two operations. The renamed symbols are used to update the
1256  // target of any NLAs. This will add symbols to the "to" operation if needed.
1257  void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1258  Operation *to, FModuleLike fromModule,
1259  Operation *from) {
1260  // If the "from" operation has an inner_sym, we need to make sure the
1261  // "to" operation also has an `inner_sym` and then record the renaming.
1262  if (auto fromSym = getInnerSymName(from)) {
1263  auto toSym =
1264  getOrAddInnerSym(to, [&](auto _) -> hw::InnerSymbolNamespace & {
1265  return getNamespace(toModule);
1266  });
1267  renameMap[fromSym] = toSym;
1268  }
1269 
1270  // If there are no port symbols on the "from" operation, we are done here.
1271  auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSyms");
1272  if (!fromPortSyms || fromPortSyms.empty())
1273  return;
1274  // We have to map each "fromPort" to each "toPort".
1275  auto &moduleNamespace = getNamespace(toModule);
1276  auto portCount = fromPortSyms.size();
1277  auto portNames = to->getAttrOfType<ArrayAttr>("portNames");
1278  auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSyms");
1279 
1280  // Create an array of new port symbols for the "to" operation, copy in the
1281  // old symbols if it has any, create an empty symbol array if it doesn't.
1282  SmallVector<Attribute> newPortSyms;
1283  if (toPortSyms.empty())
1284  newPortSyms.assign(portCount, hw::InnerSymAttr());
1285  else
1286  newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1287 
1288  for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1289  // If this fromPort doesn't have a symbol, move on to the next one.
1290  if (!fromPortSyms[portNo])
1291  continue;
1292  auto fromSym = cast<hw::InnerSymAttr>(fromPortSyms[portNo]);
1293 
1294  // If this toPort doesn't have a symbol, assign one.
1295  hw::InnerSymAttr toSym;
1296  if (!newPortSyms[portNo]) {
1297  // Get a reasonable base name for the port.
1298  StringRef symName = "inner_sym";
1299  if (portNames)
1300  symName = cast<StringAttr>(portNames[portNo]).getValue();
1301  // Create the symbol and store it into the array.
1302  toSym = hw::InnerSymAttr::get(
1303  StringAttr::get(context, moduleNamespace.newName(symName)));
1304  newPortSyms[portNo] = toSym;
1305  } else
1306  toSym = cast<hw::InnerSymAttr>(newPortSyms[portNo]);
1307 
1308  // Record the renaming.
1309  renameMap[fromSym.getSymName()] = toSym.getSymName();
1310  }
1311 
1312  // Commit the new symbol attribute.
1313  cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1314  }
1315 
1316  /// Recursively merge two operations.
1317  // NOLINTNEXTLINE(misc-no-recursion)
1318  void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1319  FModuleLike fromModule, Operation *from) {
1320  // Merge the operation locations.
1321  if (to->getLoc() != from->getLoc())
1322  to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1323 
1324  // Recurse into any regions.
1325  for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1326  mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1327  std::get<1>(regions));
1328 
1329  // Record any inner_sym renamings that happened.
1330  recordSymRenames(renameMap, toModule, to, fromModule, from);
1331 
1332  // Merge the annotations.
1333  mergeAnnotations(toModule, to, fromModule, from);
1334  }
1335 
1336  /// Recursively merge two blocks.
1337  void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1338  FModuleLike fromModule, Block &fromBlock) {
1339  // Merge the block locations.
1340  for (auto [toArg, fromArg] :
1341  llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1342  if (toArg.getLoc() != fromArg.getLoc())
1343  toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1344 
1345  for (auto ops : llvm::zip(toBlock, fromBlock))
1346  mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1347  &std::get<1>(ops));
1348  }
1349 
1350  // Recursively merge two regions.
1351  void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1352  Region &toRegion, FModuleLike fromModule,
1353  Region &fromRegion) {
1354  for (auto blocks : llvm::zip(toRegion, fromRegion))
1355  mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1356  std::get<1>(blocks));
1357  }
1358 
1359  MLIRContext *context;
1361  SymbolTable &symbolTable;
1362 
1363  /// Cached nla table analysis.
1364  NLATable *nlaTable = nullptr;
1365 
1366  /// We insert all NLAs to the beginning of this block.
1367  Block *nlaBlock;
1368 
1369  // This maps an NLA to the operations and ports that uses it.
1370  DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1371 
1372  // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1373  // of the NLA's path to the name of the NLA which contains it.
1374  DenseMap<Attribute, Attribute> nlaCache;
1375 
1376  // Cached attributes for faster comparisons and attribute building.
1377  StringAttr nonLocalString;
1378  StringAttr classString;
1379 
1380  /// A module namespace cache.
1381  DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1382 };
1383 
1384 //===----------------------------------------------------------------------===//
1385 // Fixup
1386 //===----------------------------------------------------------------------===//
1387 
1388 /// This fixes up ClassLikes with ClassType ports, when the classes have
1389 /// deduped. For each ClassType port, if the object reference being assigned is
1390 /// a different type, update the port type. Returns true if the ClassOp was
1391 /// updated and the associated ObjectOps should be updated.
1392 bool fixupClassOp(ClassOp classOp) {
1393  // New port type attributes, if necessary.
1394  SmallVector<Attribute> newPortTypes;
1395  bool anyDifferences = false;
1396 
1397  // Check each port.
1398  for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1399  // Check if this port is a ClassType. If not, save the original type
1400  // attribute in case we need to update port types.
1401  auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1402  if (!portClassType) {
1403  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1404  continue;
1405  }
1406 
1407  // Check if this port is assigned a reference of a different ClassType.
1408  Type newPortClassType;
1409  BlockArgument portArg = classOp.getArgument(i);
1410  for (auto &use : portArg.getUses()) {
1411  if (auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1412  Type sourceType = propassign.getSrc().getType();
1413  if (propassign.getDest() == use.get() && sourceType != portClassType) {
1414  // Double check that all references are the same new type.
1415  if (newPortClassType) {
1416  assert(newPortClassType == sourceType &&
1417  "expected all references to be of the same type");
1418  continue;
1419  }
1420 
1421  newPortClassType = sourceType;
1422  }
1423  }
1424  }
1425 
1426  // If there was no difference, save the original type attribute in case we
1427  // need to update port types and move along.
1428  if (!newPortClassType) {
1429  newPortTypes.push_back(classOp.getPortTypeAttr(i));
1430  continue;
1431  }
1432 
1433  // The port type changed, so update the block argument, save the new port
1434  // type attribute, and indicate there was a difference.
1435  classOp.getArgument(i).setType(newPortClassType);
1436  newPortTypes.push_back(TypeAttr::get(newPortClassType));
1437  anyDifferences = true;
1438  }
1439 
1440  // If necessary, update port types.
1441  if (anyDifferences)
1442  classOp.setPortTypes(newPortTypes);
1443 
1444  return anyDifferences;
1445 }
1446 
1447 /// This fixes up ObjectOps when the signature of their ClassOp changes. This
1448 /// amounts to updating the ObjectOp result type to match the newly updated
1449 /// ClassOp type.
1450 void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) {
1451  objectOp.getResult().setType(newClassType);
1452 }
1453 
1454 /// This fixes up connects when the field names of a bundle type changes. It
1455 /// finds all fields which were previously bulk connected and legalizes it
1456 /// into a connect for each field.
1457 void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1458  // If the types already match we can emit a connect.
1459  auto dstType = dst.getType();
1460  auto srcType = src.getType();
1461  if (dstType == srcType) {
1462  emitConnect(builder, dst, src);
1463  return;
1464  }
1465  // It must be a bundle type and the field name has changed. We have to
1466  // manually decompose the bulk connect into a connect for each field.
1467  auto dstBundle = type_cast<BundleType>(dstType);
1468  auto srcBundle = type_cast<BundleType>(srcType);
1469  for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1470  auto dstField = builder.create<SubfieldOp>(dst, i);
1471  auto srcField = builder.create<SubfieldOp>(src, i);
1472  if (dstBundle.getElement(i).isFlip) {
1473  std::swap(srcBundle, dstBundle);
1474  std::swap(srcField, dstField);
1475  }
1476  fixupConnect(builder, dstField, srcField);
1477  }
1478 }
1479 
1480 /// This is the root method to fixup module references when a module changes.
1481 /// It matches all the results of "to" module with the results of the "from"
1482 /// module.
1483 void fixupAllModules(InstanceGraph &instanceGraph) {
1484  for (auto *node : instanceGraph) {
1485  auto module = cast<FModuleLike>(*node->getModule());
1486 
1487  // Handle class declarations here.
1488  bool shouldFixupObjects = false;
1489  auto classOp = dyn_cast<ClassOp>(module.getOperation());
1490  if (classOp)
1491  shouldFixupObjects = fixupClassOp(classOp);
1492 
1493  for (auto *instRec : node->uses()) {
1494  // Handle object instantiations here.
1495  if (classOp) {
1496  if (shouldFixupObjects) {
1497  fixupObjectOp(instRec->getInstance<ObjectOp>(),
1498  classOp.getInstanceType());
1499  }
1500  continue;
1501  }
1502 
1503  auto inst = instRec->getInstance<InstanceOp>();
1504  // Only handle module instantiations here.
1505  if (!inst)
1506  continue;
1507  ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1508  builder.setInsertionPointAfter(inst);
1509  for (size_t i = 0, e = getNumPorts(module); i < e; ++i) {
1510  auto result = inst.getResult(i);
1511  auto newType = module.getPortType(i);
1512  auto oldType = result.getType();
1513  // If the type has not changed, we don't have to fix up anything.
1514  if (newType == oldType)
1515  continue;
1516  // If the type changed we transform it back to the old type with an
1517  // intermediate wire.
1518  auto wire =
1519  builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1520  result.replaceAllUsesWith(wire);
1521  result.setType(newType);
1522  if (inst.getPortDirection(i) == Direction::Out)
1523  fixupConnect(builder, wire, result);
1524  else
1525  fixupConnect(builder, result, wire);
1526  }
1527  }
1528  }
1529 }
1530 
1531 namespace llvm {
1532 /// A DenseMapInfo implementation for `ModuleInfo` that is a pair of
1533 /// llvm::SHA256 hashes, which are represented as std::array<uint8_t, 32>, and
1534 /// an array of string attributes. This allows us to create a DenseMap with
1535 /// `ModuleInfo` as keys.
1536 template <>
1537 struct DenseMapInfo<ModuleInfo> {
1538  static inline ModuleInfo getEmptyKey() {
1539  std::array<uint8_t, 32> key;
1540  std::fill(key.begin(), key.end(), ~0);
1541  return {key, DenseMapInfo<mlir::ArrayAttr>::getEmptyKey()};
1542  }
1543 
1544  static inline ModuleInfo getTombstoneKey() {
1545  std::array<uint8_t, 32> key;
1546  std::fill(key.begin(), key.end(), ~0 - 1);
1547  return {key, DenseMapInfo<mlir::ArrayAttr>::getTombstoneKey()};
1548  }
1549 
1550  static unsigned getHashValue(const ModuleInfo &val) {
1551  // We assume SHA256 is already a good hash and just truncate down to the
1552  // number of bytes we need for DenseMap.
1553  unsigned hash;
1554  std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));
1555 
1556  // Combine module names.
1557  return llvm::hash_combine(hash, val.referredModuleNames);
1558  }
1559 
1560  static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
1561  return lhs.structuralHash == rhs.structuralHash &&
1562  lhs.referredModuleNames == rhs.referredModuleNames;
1563  }
1564 };
1565 } // namespace llvm
1566 
1567 //===----------------------------------------------------------------------===//
1568 // DedupPass
1569 //===----------------------------------------------------------------------===//
1570 
1571 namespace {
1572 class DedupPass : public DedupBase<DedupPass> {
1573  void runOnOperation() override {
1574  auto *context = &getContext();
1575  auto circuit = getOperation();
1576  auto &instanceGraph = getAnalysis<InstanceGraph>();
1577  auto *nlaTable = &getAnalysis<NLATable>();
1578  auto &symbolTable = getAnalysis<SymbolTable>();
1579  Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1580  Equivalence equiv(context, instanceGraph);
1581  auto anythingChanged = false;
1582 
1583  // Modules annotated with this should not be considered for deduplication.
1584  auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1585 
1586  // Only modules within the same group may be deduplicated.
1587  auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1588 
1589  // A map of all the module moduleInfo that we have calculated so far.
1590  llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1591 
1592  // We track the name of the module that each module is deduped into, so that
1593  // we can make sure all modules which are marked "must dedup" with each
1594  // other were all deduped to the same module.
1595  DenseMap<Attribute, StringAttr> dedupMap;
1596 
1597  // We must iterate the modules from the bottom up so that we can properly
1598  // deduplicate the modules. We copy the list of modules into a vector first
1599  // to avoid iterator invalidation while we mutate the instance graph.
1600  SmallVector<FModuleLike, 0> modules(
1601  llvm::map_range(llvm::post_order(&instanceGraph), [](auto *node) {
1602  return cast<FModuleLike>(*node->getModule());
1603  }));
1604 
1605  SmallVector<std::optional<
1606  std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>>>
1607  hashesAndModuleNames(modules.size());
1608  StructuralHasherSharedConstants hasherConstants(&getContext());
1609 
1610  // Attribute name used to store dedup_group for this pass.
1611  auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1612 
1613  // Move dedup group annotations to attributes on the module.
1614  // This results in the desired behavior (included in hash),
1615  // and avoids unnecessary processing of these as annotations
1616  // that need to be tracked, made non-local, so on.
1617  for (auto module : modules) {
1618  llvm::SmallSetVector<StringAttr, 1> groups;
1620  module, [&groups, dedupGroupClass](Annotation annotation) {
1621  if (annotation.getClassAttr() != dedupGroupClass)
1622  return false;
1623  groups.insert(annotation.getMember<StringAttr>("group"));
1624  return true;
1625  });
1626  if (groups.size() > 1) {
1627  module.emitError("module belongs to multiple dedup groups: ") << groups;
1628  return signalPassFailure();
1629  }
1630  assert(!module->hasAttr(dedupGroupAttrName) &&
1631  "unexpected existing use of temporary dedup group attribute");
1632  if (!groups.empty())
1633  module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1634  }
1635 
1636  // Calculate module information parallelly.
1637  auto result = mlir::failableParallelForEach(
1638  context, llvm::seq(modules.size()), [&](unsigned idx) {
1639  auto module = modules[idx];
1640  AnnotationSet annotations(module);
1641  // If the module is marked with NoDedup, just skip it.
1642  if (annotations.hasAnnotation(noDedupClass))
1643  return success();
1644 
1645  // Only dedup extmodule's with defname.
1646  if (auto ext = dyn_cast<FExtModuleOp>(*module);
1647  ext && !ext.getDefname().has_value())
1648  return success();
1649 
1650  // If module has symbol (name) that must be preserved even if unused,
1651  // skip it. All symbol uses must be supported, which is not true if
1652  // non-private.
1653  if (!module.isPrivate() ||
1654  (!module.canDiscardOnUseEmpty() && !isa<ClassLike>(*module))) {
1655  return success();
1656  }
1657 
1658  StructuralHasher hasher(hasherConstants);
1659  // Calculate the hash of the module and referred module names.
1660  hashesAndModuleNames[idx] = hasher.getHashAndModuleNames(module);
1661  return success();
1662  });
1663 
1664  if (result.failed())
1665  return signalPassFailure();
1666 
1667  for (auto [i, module] : llvm::enumerate(modules)) {
1668  auto moduleName = module.getModuleNameAttr();
1669  auto &hashAndModuleNamesOpt = hashesAndModuleNames[i];
1670  // If the hash was not calculated, we need to skip it.
1671  if (!hashAndModuleNamesOpt) {
1672  // We record it in the dedup map to help detect errors when the user
1673  // marks the module as both NoDedup and MustDedup. We do not record this
1674  // module in the hasher to make sure no other module dedups "into" this
1675  // one.
1676  dedupMap[moduleName] = moduleName;
1677  continue;
1678  }
1679 
1680  // Replace module names referred in the module with new names.
1681  SmallVector<mlir::Attribute> names;
1682  for (auto oldModuleName : hashAndModuleNamesOpt->second) {
1683  auto newModuleName = dedupMap[oldModuleName];
1684  names.push_back(newModuleName);
1685  }
1686 
1687  // Create a module info to use it as a key.
1688  ModuleInfo moduleInfo{hashAndModuleNamesOpt->first,
1689  mlir::ArrayAttr::get(module.getContext(), names)};
1690 
1691  // Check if there a module with the same hash.
1692  auto it = moduleInfoToModule.find(moduleInfo);
1693  if (it != moduleInfoToModule.end()) {
1694  auto original = cast<FModuleLike>(it->second);
1695  // Record the group ID of the other module.
1696  dedupMap[moduleName] = original.getModuleNameAttr();
1697  deduper.dedup(original, module);
1698  ++erasedModules;
1699  anythingChanged = true;
1700  continue;
1701  }
1702  // Any module not deduplicated must be recorded.
1703  deduper.record(module);
1704  // Add the module to a new dedup group.
1705  dedupMap[moduleName] = moduleName;
1706  // Record the module info.
1707  moduleInfoToModule[moduleInfo] = module;
1708  }
1709 
1710  // This part verifies that all modules marked by "MustDedup" have been
1711  // properly deduped with each other. For this check to succeed, all modules
1712  // have to been deduped to the same module. It is possible that a module was
1713  // deduped with the wrong thing.
1714 
1715  auto failed = false;
1716  // This parses the module name out of a target string.
1717  auto parseModule = [&](Attribute path) -> StringAttr {
1718  // Each module is listed as a target "~Circuit|Module" which we have to
1719  // parse.
1720  auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1721  return StringAttr::get(context, rhs);
1722  };
1723  // This gets the name of the module which the current module was deduped
1724  // with. If the named module isn't in the map, then we didn't encounter it
1725  // in the circuit.
1726  auto getLead = [&](StringAttr module) -> StringAttr {
1727  auto it = dedupMap.find(module);
1728  if (it == dedupMap.end()) {
1729  auto diag = emitError(circuit.getLoc(),
1730  "MustDeduplicateAnnotation references module ")
1731  << module << " which does not exist";
1732  failed = true;
1733  return 0;
1734  }
1735  return it->second;
1736  };
1737 
1738  AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1739  // If we have already failed, don't process any more annotations.
1740  if (failed)
1741  return false;
1742  if (!annotation.isClass(mustDedupAnnoClass))
1743  return false;
1744  auto modules = annotation.getMember<ArrayAttr>("modules");
1745  if (!modules) {
1746  emitError(circuit.getLoc(),
1747  "MustDeduplicateAnnotation missing \"modules\" member");
1748  failed = true;
1749  return false;
1750  }
1751  // Empty module list has nothing to process.
1752  if (modules.size() == 0)
1753  return true;
1754  // Get the first element.
1755  auto firstModule = parseModule(modules[0]);
1756  auto firstLead = getLead(firstModule);
1757  if (failed)
1758  return false;
1759  // Verify that the remaining elements are all the same as the first.
1760  for (auto attr : modules.getValue().drop_front()) {
1761  auto nextModule = parseModule(attr);
1762  auto nextLead = getLead(nextModule);
1763  if (failed)
1764  return false;
1765  if (firstLead != nextLead) {
1766  auto diag = emitError(circuit.getLoc(), "module ")
1767  << nextModule << " not deduplicated with " << firstModule;
1768  auto a = instanceGraph.lookup(firstLead)->getModule();
1769  auto b = instanceGraph.lookup(nextLead)->getModule();
1770  equiv.check(diag, a, b);
1771  failed = true;
1772  return false;
1773  }
1774  }
1775  return true;
1776  });
1777  if (failed)
1778  return signalPassFailure();
1779 
1780  // Remove all dedup group attributes, they only exist during this pass.
1781  for (auto module : circuit.getOps<FModuleLike>())
1782  module->removeDiscardableAttr(dedupGroupAttrName);
1783 
1784  // Walk all the modules and fixup the instance operation to return the
1785  // correct type. We delay this fixup until the end because doing it early
1786  // can block the deduplication of the parent modules.
1787  fixupAllModules(instanceGraph);
1788 
1789  markAnalysesPreserved<NLATable>();
1790  if (!anythingChanged)
1791  markAllAnalysesPreserved();
1792  }
1793 };
1794 } // end anonymous namespace
1795 
1796 std::unique_ptr<mlir::Pass> circt::firrtl::createDedupPass() {
1797  return std::make_unique<DedupPass>();
1798 }
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
Definition: FIRRTLOps.cpp:3852
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition: Dedup.cpp:809
void fixupObjectOp(ObjectOp objectOp, ClassType newClassType)
This fixes up ObjectOps when the signature of their ClassOp changes.
Definition: Dedup.cpp:1450
llvm::raw_ostream & printHex(llvm::raw_ostream &stream, ArrayRef< uint8_t > bytes)
Definition: Dedup.cpp:46
void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition: Dedup.cpp:1457
llvm::raw_ostream & printHash(llvm::raw_ostream &stream, llvm::SHA256 &data)
Definition: Dedup.cpp:52
void fixupAllModules(InstanceGraph &instanceGraph)
This is the root method to fixup module references when a module changes.
Definition: Dedup.cpp:1483
bool fixupClassOp(ClassOp classOp)
This fixes up ClassLikes with ClassType ports when the classes have been deduplicated.
Definition: Dedup.cpp:1392
static void mergeRegions(Region *region1, Region *region2)
Definition: HWCleanup.cpp:70
Builder builder
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition: NLATable.h:29
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
ClassType getInstanceTypeForClassLike(ClassLike classOp)
Definition: FIRRTLOps.cpp:1776
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the InnerSymAttr has a symbol on the specified field.
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:271
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
Definition: Dedup.cpp:1796
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
Definition: FIRRTLUtils.cpp:24
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
Definition: FIRRTLOps.h:107
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
MLIRContext * context
Definition: Dedup.cpp:1359
Block * nlaBlock
We insert all NLAs at the beginning of this block.
Definition: Dedup.cpp:1367
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition: Dedup.cpp:939
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition: Dedup.cpp:1048
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition: Dedup.cpp:1229
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition: Dedup.cpp:955
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition: Dedup.cpp:1021
void record(FModuleLike module)
Record the usages of any NLAs in this module, so that we may update the annotation if the parent mod...
Definition: Dedup.cpp:913
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition: Dedup.cpp:1127
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition: Dedup.cpp:1351
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition: Dedup.cpp:878
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition: Dedup.cpp:1119
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition: Dedup.cpp:932
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition: Dedup.cpp:1029
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition: Dedup.cpp:1257
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
Definition: Dedup.cpp:1206
StringAttr nonLocalString
Definition: Dedup.cpp:1377
SymbolTable & symbolTable
Definition: Dedup.cpp:1361
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition: Dedup.cpp:984
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition: Dedup.cpp:1318
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition: Dedup.cpp:925
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition: Dedup.cpp:1381
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition: Dedup.cpp:1135
InstanceGraph & instanceGraph
Definition: Dedup.cpp:1360
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition: Dedup.cpp:1337
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition: Dedup.cpp:1370
StringAttr classString
Definition: Dedup.cpp:1378
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition: Dedup.cpp:1177
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition: Dedup.cpp:862
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition: Dedup.cpp:1058
DenseMap< StringAttr, StringAttr > RenameMap
Definition: Dedup.cpp:860
DenseMap< Attribute, Attribute > nlaCache
Definition: Dedup.cpp:1374
const hw::InnerSymbolTable & a
Definition: Dedup.cpp:355
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition: Dedup.cpp:352
const hw::InnerSymbolTable & b
Definition: Dedup.cpp:356
This class is for reporting differences between two modules which should have been deduplicated.
Definition: Dedup.cpp:334
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:796
std::string prettyPrint(Attribute attr)
Definition: Dedup.cpp:359
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition: Dedup.cpp:641
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition: Dedup.cpp:430
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition: Dedup.cpp:404
StringAttr noDedupClass
Definition: Dedup.cpp:791
StringAttr dedupGroupAttrName
Definition: Dedup.cpp:793
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition: Dedup.cpp:553
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition: Dedup.cpp:508
StringAttr portDirectionsAttr
Definition: Dedup.cpp:789
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition: Dedup.cpp:375
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition: Dedup.cpp:662
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition: Dedup.cpp:335
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition: Dedup.cpp:528
InstanceGraph & instanceGraph
Definition: Dedup.cpp:797
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition: Dedup.cpp:749
std::array< uint8_t, 32 > structuralHash
Definition: Dedup.cpp:69
mlir::ArrayAttr referredModuleNames
Definition: Dedup.cpp:71
This struct contains constant string attributes shared across different threads.
Definition: Dedup.cpp:81
DenseSet< Attribute > nonessentialAttributes
Definition: Dedup.cpp:112
StructuralHasherSharedConstants(MLIRContext *context)
Definition: Dedup.cpp:82
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition: Dedup.cpp:211
void update(BlockArgument arg)
Definition: Dedup.cpp:160
void update(Type type)
Definition: Dedup.cpp:149
void update(const void *pointer)
Definition: Dedup.cpp:127
void update(InnerRefAttr attr)
Definition: Dedup.cpp:200
void update(Operation *op)
Definition: Dedup.cpp:292
void record(void *address)
Definition: Dedup.cpp:155
void update(const SymbolTarget &target)
Definition: Dedup.cpp:195
std::pair< std::array< uint8_t, 32 >, SmallVector< StringAttr > > getHashAndModuleNames(FModuleLike module)
Definition: Dedup.cpp:120
llvm::SHA256 sha
Definition: Dedup.cpp:325
void update(Operation *op, hw::InnerSymAttr attr)
Definition: Dedup.cpp:183
DenseMap< StringAttr, SymbolTarget > innerSymTargets
Definition: Dedup.cpp:315
void update(size_t value)
Definition: Dedup.cpp:132
SmallVector< mlir::StringAttr > referredModuleNames
Definition: Dedup.cpp:318
void update(Block &block)
Definition: Dedup.cpp:277
void update(BundleType type)
Definition: Dedup.cpp:140
void update(OpResult result)
Definition: Dedup.cpp:162
void update(OpOperand &operand)
Definition: Dedup.cpp:176
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition: Dedup.cpp:116
void update(TypeID typeID)
Definition: Dedup.cpp:137
const StructuralHasherSharedConstants & constants
Definition: Dedup.cpp:321
void update(Value value, hw::InnerSymAttr attr)
Definition: Dedup.cpp:189
DenseMap< void *, unsigned > indices
Definition: Dedup.cpp:312
void update(mlir::OperationName name)
Definition: Dedup.cpp:286
uint64_t fieldID
Definition: Dedup.cpp:76
uint64_t index
Definition: Dedup.cpp:75
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
Definition: Dedup.cpp:1538
static ModuleInfo getTombstoneKey()
Definition: Dedup.cpp:1544
static unsigned getHashValue(const ModuleInfo &val)
Definition: Dedup.cpp:1550
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)
Definition: Dedup.cpp:1560