Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
23#include "circt/Support/LLVM.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/ADT/DenseMap.h"
28#include "llvm/ADT/DenseMapInfo.h"
29#include "llvm/ADT/Hashing.h"
30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/ADT/SmallPtrSet.h"
32#include "llvm/Support/Format.h"
33#include "llvm/Support/SHA256.h"
34
35namespace circt {
36namespace firrtl {
37#define GEN_PASS_DEF_DEDUP
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
39} // namespace firrtl
40} // namespace circt
41
42using namespace circt;
43using namespace firrtl;
44using hw::InnerRefAttr;
45
46//===----------------------------------------------------------------------===//
47// Utility function for classifying a Symbol's dedup-ability.
48//===----------------------------------------------------------------------===//
49
50/// Returns true if the module can be removed.
51static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
52 // If the symbol is not private, it cannot be removed.
53 if (!symbol.isPrivate())
54 return false;
55 // Classes may be referenced in object types, so can not normally be removed
56 // if we can't find any symbol uses. Since we know that dedup will update the
57 // types of instances appropriately, we can ignore that and return true here.
58 if (isa<ClassLike>(*symbol))
59 return true;
60 // If module can not be removed even if no uses can be found, we can not
61 // delete it. The implication is that there are hidden symbol uses that dedup
62 // will not properly update.
63 if (!symbol.canDiscardOnUseEmpty())
64 return false;
65 // The module can be deleted.
66 return true;
67}
68
69//===----------------------------------------------------------------------===//
70// Hashing
71//===----------------------------------------------------------------------===//
72
73llvm::raw_ostream &printHex(llvm::raw_ostream &stream,
74 ArrayRef<uint8_t> bytes) {
75 // Print the hash on a single line.
76 return stream << format_bytes(bytes, std::nullopt, 32) << "\n";
77}
78
79llvm::raw_ostream &printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
80 return printHex(stream, data.result());
81}
82
83llvm::raw_ostream &printHash(llvm::raw_ostream &stream, std::string data) {
84 ArrayRef<uint8_t> bytes(reinterpret_cast<const uint8_t *>(data.c_str()),
85 data.length());
86 return printHex(stream, bytes);
87}
88
89// This struct contains information to determine module module uniqueness. A
90// first element is a structural hash of the module, and the second element is
91// an array which tracks module names encountered in the walk. Since module
92// names could be replaced during dedup, it's necessary to keep names up-to-date
93// before actually combining them into structural hashes.
94struct ModuleInfo {
95 // SHA256 hash.
96 std::array<uint8_t, 32> structuralHash;
97 // Module names referred by instance op in the module.
98 std::vector<StringAttr> referredModuleNames;
99};
100
101static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
102 return lhs.structuralHash == rhs.structuralHash &&
104}
105
106/// This struct contains constant string attributes shared across different
107/// threads.
109 explicit StructuralHasherSharedConstants(MLIRContext *context) {
110 portTypesAttr = StringAttr::get(context, "portTypes");
111 moduleNameAttr = StringAttr::get(context, "moduleName");
112 portNamesAttr = StringAttr::get(context, "portNames");
113 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
114 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
115 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
116 nonessentialAttributes.insert(StringAttr::get(context, "name"));
117 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
118 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
119 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
120 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
121 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
122 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
123 };
124
125 // This is a cached "portTypes" string attr.
126 StringAttr portTypesAttr;
127
128 // This is a cached "moduleName" string attr.
129 StringAttr moduleNameAttr;
130
131 // This is a cached "portNames" string attr.
132 StringAttr portNamesAttr;
133
134 // This is a set of every attribute we should ignore.
135 DenseSet<Attribute> nonessentialAttributes;
136};
137
140 : constants(constants) {}
141
142 ModuleInfo getModuleInfo(FModuleLike module) {
144 update(&(*module));
145 return {sha.final(), std::move(referredModuleNames)};
146 }
147
148private:
149 /// Find all the ports and operations which may define an inner symbol
150 /// operations and give each a unique id. If the port/operation does define
151 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
152 /// field id. When we hash (local) references to this inner symbol, we will
153 /// hash in the id and the the field id.
154 void populateInnerSymIDTable(FModuleLike module) {
155 // Add port symbols. If no port has a symbol defined, the port symbol array
156 // will be totally empty.
157 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
158 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
159 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
160 }
161 // Add operation symbols.
162 size_t index = module.getNumPorts();
163 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
164 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
165 for (auto prop : innerSym)
166 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
167 }
168 ++index;
169 });
170 }
171
172 // Get the identifier for an object. The identifier is assigned on first use.
173 unsigned getID(void *object) {
174 auto [it, inserted] = idTable.try_emplace(object, nextID);
175 if (inserted)
176 ++nextID;
177 return it->second;
178 }
179
180 // Get the identifier for an IR object. Free the ID, too.
181 unsigned finalizeID(void *object) {
182 auto it = idTable.find(object);
183 if (it == idTable.end())
184 return nextID++;
185 auto id = it->second;
186 idTable.erase(it);
187 return id;
188 }
189
190 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
191 return innerSymIDTable.at(name);
192 }
193
194 void update(OpOperand &operand) {
195 auto value = operand.get();
196 if (auto result = dyn_cast<OpResult>(value)) {
197 auto *op = result.getOwner();
198 update(getID(op));
199 update(result.getResultNumber());
200 return;
201 }
202 if (auto argument = dyn_cast<BlockArgument>(value)) {
203 auto *block = argument.getOwner();
204 update(getID(block));
205 update(argument.getArgNumber());
206 return;
207 }
208 llvm_unreachable("Unknown value type");
209 }
210
211 void update(const void *pointer) {
212 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
213 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
214 }
215
216 void update(size_t value) {
217 auto *addr = reinterpret_cast<const uint8_t *>(&value);
218 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
219 }
220
/// Hash both members of a pair, first then second.
template <typename T, typename U>
void update(const std::pair<T, U> &pair) {
  const auto &[first, second] = pair;
  update(first);
  update(second);
}
226
227 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
228
229 // NOLINTNEXTLINE(misc-no-recursion)
230 void update(BundleType type) {
231 update(type.getTypeID());
232 for (auto &element : type.getElements()) {
233 update(element.isFlip);
234 update(element.type);
235 }
236 }
237
238 // NOLINTNEXTLINE(misc-no-recursion)
239 void update(Type type) {
240 if (auto bundle = type_dyn_cast<BundleType>(type))
241 return update(bundle);
242 update(type.getAsOpaquePointer());
243 }
244
245 void update(OpResult result) {
246 // Like instance ops, don't use object ops' result types since they might be
247 // replaced by dedup. Record the class names and lazily combine their hashes
248 // using the same mechanism as instances and modules.
249 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
250 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
251 return;
252 }
253
254 update(result.getType());
255 }
256
257 /// Hash the top level attribute dictionary of the operation. This function
258 /// has special handling for inner symbols, ports, and referenced modules.
259 void update(Operation *op, DictionaryAttr dict) {
260 for (auto namedAttr : dict) {
261 auto name = namedAttr.getName();
262 auto value = namedAttr.getValue();
263 // Skip names and annotations, except in certain cases.
264 // Names of ports are load bearing for classes, so we do hash those.
265 bool isClassPortNames =
266 isa<ClassLike>(op) && name == constants.portNamesAttr;
267 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
268 continue;
269
270 // Hash the attribute name (an interned pointer).
271 update(name.getAsOpaquePointer());
272
273 // Hash the port types.
274 if (name == constants.portTypesAttr) {
275 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
276 for (auto type : portTypes)
277 update(type);
278 continue;
279 }
280
281 // For instance op, don't use `moduleName` attributes since they might be
282 // replaced by dedup. Record the names and lazily combine their hashes.
283 // It is assumed that module names are hashed only through instance ops;
284 // it could cause suboptimal results if there was other operation that
285 // refers to module names through essential attributes.
286 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
287 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
288 continue;
289 }
290
291 // TODO: properly handle DistinctAttr, including its use in paths.
292 // See https://github.com/llvm/circt/issues/6583.
293 if (isa<DistinctAttr>(value))
294 continue;
295
296 // If this is an symbol reference, we need to perform name erasure.
297 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
298 update(getInnerSymID(innerRef.getName()));
299 continue;
300 }
301
302 // We don't need to handle this attribute specially, so hash its unique
303 // address.
304 update(value.getAsOpaquePointer());
305 }
306 }
307
308 void update(mlir::OperationName name) {
309 // Operation names are interned.
310 update(name.getAsOpaquePointer());
311 }
312
313 // NOLINTNEXTLINE(misc-no-recursion)
314 void update(Block *block) {
315 for (auto &op : llvm::reverse(*block))
316 update(&op);
317 for (auto type : block->getArgumentTypes())
318 update(type);
319 update(finalizeID(block));
320 update(position);
321 ++position;
322 }
323
324 // NOLINTNEXTLINE(misc-no-recursion)
325 void update(Region *region) {
326 for (auto &block : llvm::reverse(region->getBlocks()))
327 update(&block);
328 update(position);
329 ++position;
330 }
331
332 // NOLINTNEXTLINE(misc-no-recursion)
333 void update(Operation *op) {
334 // Hash the regions. We need to make sure an empty region doesn't hash the
335 // same as no region, so we include the number of regions.
336 update(op->getNumRegions());
337 for (auto &region : reverse(op->getRegions()))
338 update(&region);
339
340 update(op->getName());
341
342 // Record the uses for later hashing.
343 for (auto &operand : op->getOpOperands())
344 update(operand);
345
346 // This happens after the numbering above, as it uses blockarg numbering
347 // for inner symbols.
348 update(op, op->getAttrDictionary());
349
350 // Record any op results (types).
351 for (auto result : op->getResults())
352 update(result);
353
354 // Incorporate the hash of uses we have already built.
355 update(finalizeID(op));
356 update(position);
357 ++position;
358 }
359
360 // A map from an operation/block, to its identifier.
361 DenseMap<void *, unsigned> idTable;
362 unsigned nextID = 0;
363
364 // A map from an inner symbol, to its identifier.
365 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
366
367 // This keeps track of module names in the order of the appearance.
368 std::vector<StringAttr> referredModuleNames;
369
370 // String constants.
372
373 // This is the actual running hash calculation. This is a stateful element
374 // that should be reinitialized after each hash is produced.
375 llvm::SHA256 sha;
376
377 // The index of the current op. Increment after handling each op.
378 size_t position = 0;
379};
380
381//===----------------------------------------------------------------------===//
382// Equivalence
383//===----------------------------------------------------------------------===//
384
385/// This class is for reporting differences between two modules which should
386/// have been deduplicated.
390 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
391 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
392 portDirectionsAttr = StringAttr::get(context, "portDirections");
393 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
394 nonessentialAttributes.insert(StringAttr::get(context, "name"));
395 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
396 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
397 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
398 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
399 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
400 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
401 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
402 }
403
411
412 std::string prettyPrint(Attribute attr) {
413 SmallString<64> buffer;
414 llvm::raw_svector_ostream os(buffer);
415 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
416 os << "0x";
417 if (integerAttr.getType().isSignlessInteger())
418 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
419 else
420 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
421
422 } else
423 os << attr;
424 return std::string(buffer);
425 }
426
427 // NOLINTNEXTLINE(misc-no-recursion)
428 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
429 Operation *a, BundleType aType, Operation *b,
430 BundleType bType) {
431 if (aType.getNumElements() != bType.getNumElements()) {
432 diag.attachNote(a->getLoc())
433 << message << " bundle type has different number of elements";
434 diag.attachNote(b->getLoc()) << "second operation here";
435 return failure();
436 }
437
438 for (auto elementPair :
439 llvm::zip(aType.getElements(), bType.getElements())) {
440 auto aElement = std::get<0>(elementPair);
441 auto bElement = std::get<1>(elementPair);
442 if (aElement.isFlip != bElement.isFlip) {
443 diag.attachNote(a->getLoc()) << message << " bundle element "
444 << aElement.name << " flip does not match";
445 diag.attachNote(b->getLoc()) << "second operation here";
446 return failure();
447 }
448
449 if (failed(check(diag,
450 "bundle element \'" + aElement.name.getValue() + "'", a,
451 aElement.type, b, bElement.type)))
452 return failure();
453 }
454 return success();
455 }
456
457 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
458 Operation *a, Type aType, Operation *b, Type bType) {
459 if (aType == bType)
460 return success();
461 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
462 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
463 return check(diag, message, a, aBundleType, b, bBundleType);
464 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
465 aType != bType) {
466 diag.attachNote(a->getLoc())
467 << message << ", has a RefType with a different base type "
468 << type_cast<RefType>(aType).getType()
469 << " in the same position of the two modules marked as 'must dedup'. "
470 "(This may be due to Grand Central Taps or Views being different "
471 "between the two modules.)";
472 diag.attachNote(b->getLoc())
473 << "the second module has a different base type "
474 << type_cast<RefType>(bType).getType();
475 return failure();
476 }
477 diag.attachNote(a->getLoc())
478 << message << " types don't match, first type is " << aType;
479 diag.attachNote(b->getLoc()) << "second type is " << bType;
480 return failure();
481 }
482
483 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
484 Block &aBlock, Operation *b, Block &bBlock) {
485
486 // Block argument types.
487 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
488 auto portNo = 0;
489 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
490 Operation *opDoesNotExist) {
491 StringRef portName;
492 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
493 if (portNames)
494 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
495 portName = portNameAttr.getValue();
496 if (type_isa<RefType>(existsVal.getType())) {
497 diag.attachNote(opExists->getLoc())
498 << " contains a RefType port named '" + portName +
499 "' that only exists in one of the modules (can be due to "
500 "difference in Grand Central Tap or View of two modules "
501 "marked with must dedup)";
502 diag.attachNote(opDoesNotExist->getLoc())
503 << "second module to be deduped that does not have the RefType "
504 "port";
505 } else {
506 diag.attachNote(opExists->getLoc())
507 << "port '" + portName + "' only exists in one of the modules";
508 diag.attachNote(opDoesNotExist->getLoc())
509 << "second module to be deduped that does not have the port";
510 }
511 return failure();
512 };
513
514 for (auto argPair :
515 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
516 auto &aArg = std::get<0>(argPair);
517 auto &bArg = std::get<1>(argPair);
518 if (aArg.has_value() && bArg.has_value()) {
519 // TODO: we should print the port number if there are no port names, but
520 // there are always port names ;).
521 StringRef portName;
522 if (portNames) {
523 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
524 portName = portNameAttr.getValue();
525 }
526 // Assumption here that block arguments correspond to ports.
527 if (failed(check(diag, "module port '" + portName + "'", a,
528 aArg->getType(), b, bArg->getType())))
529 return failure();
530 data.map.map(aArg.value(), bArg.value());
531 portNo++;
532 continue;
533 }
534 if (!aArg.has_value())
535 std::swap(a, b);
536 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
537 b);
538 }
539
540 // Blocks operations.
541 auto aIt = aBlock.begin();
542 auto aEnd = aBlock.end();
543 auto bIt = bBlock.begin();
544 auto bEnd = bBlock.end();
545 while (aIt != aEnd && bIt != bEnd)
546 if (failed(check(diag, data, &*aIt++, &*bIt++)))
547 return failure();
548 if (aIt != aEnd) {
549 diag.attachNote(aIt->getLoc()) << "first block has more operations";
550 diag.attachNote(b->getLoc()) << "second block here";
551 return failure();
552 }
553 if (bIt != bEnd) {
554 diag.attachNote(bIt->getLoc()) << "second block has more operations";
555 diag.attachNote(a->getLoc()) << "first block here";
556 return failure();
557 }
558 return success();
559 }
560
561 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
562 Region &aRegion, Operation *b, Region &bRegion) {
563 auto aIt = aRegion.begin();
564 auto aEnd = aRegion.end();
565 auto bIt = bRegion.begin();
566 auto bEnd = bRegion.end();
567
568 // Region blocks.
569 while (aIt != aEnd && bIt != bEnd)
570 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
571 return failure();
572 if (aIt != aEnd || bIt != bEnd) {
573 diag.attachNote(a->getLoc())
574 << "operation regions have different number of blocks";
575 diag.attachNote(b->getLoc()) << "second operation here";
576 return failure();
577 }
578 return success();
579 }
580
581 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
582 mlir::DenseBoolArrayAttr aAttr, Operation *b,
583 mlir::DenseBoolArrayAttr bAttr) {
584 if (aAttr == bAttr)
585 return success();
586 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
587 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
588 auto aDirection = aAttr[i];
589 auto bDirection = bAttr[i];
590 if (aDirection != bDirection) {
591 auto &note = diag.attachNote(a->getLoc()) << "module port ";
592 if (portNames)
593 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
594 else
595 note << i;
596 note << " directions don't match, first direction is '"
597 << direction::toString(aDirection) << "'";
598 diag.attachNote(b->getLoc()) << "second direction is '"
599 << direction::toString(bDirection) << "'";
600 return failure();
601 }
602 }
603 return success();
604 }
605
606 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
607 DictionaryAttr aDict, Operation *b,
608 DictionaryAttr bDict) {
609 // Fast path.
610 if (aDict == bDict)
611 return success();
612
613 DenseSet<Attribute> seenAttrs;
614 for (auto namedAttr : aDict) {
615 auto attrName = namedAttr.getName();
616 if (nonessentialAttributes.contains(attrName))
617 continue;
618
619 auto aAttr = namedAttr.getValue();
620 auto bAttr = bDict.get(attrName);
621 if (!bAttr) {
622 diag.attachNote(a->getLoc())
623 << "second operation is missing attribute " << attrName;
624 diag.attachNote(b->getLoc()) << "second operation here";
625 return diag;
626 }
627
628 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
629 auto bRef = cast<hw::InnerRefAttr>(bAttr);
630 auto aRef = cast<hw::InnerRefAttr>(aAttr);
631 // See if they are pointing at the same operation or port.
632 auto aTarget = data.a.lookup(aRef.getName());
633 auto bTarget = data.b.lookup(bRef.getName());
634 if (!aTarget || !bTarget)
635 diag.attachNote(a->getLoc())
636 << "malformed ir, possibly violating use-before-def";
637 auto error = [&]() {
638 diag.attachNote(a->getLoc())
639 << "operations have different targets, first operation has "
640 << aTarget;
641 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
642 return failure();
643 };
644 if (aTarget.isPort()) {
645 // If they are targeting ports, make sure its the same port number.
646 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
647 return error();
648 } else {
649 // Otherwise make sure that they are targeting the same operation.
650 if (!bTarget.isOpOnly() ||
651 aTarget.getOp() != data.map.lookup(bTarget.getOp()))
652 return error();
653 }
654 if (aTarget.getField() != bTarget.getField())
655 return error();
656 } else if (attrName == portDirectionsAttr) {
657 // Special handling for the port directions attribute for better
658 // error messages.
659 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
660 cast<mlir::DenseBoolArrayAttr>(bAttr))))
661 return failure();
662 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
663 // TODO: properly handle DistinctAttr, including its use in paths.
664 // See https://github.com/llvm/circt/issues/6583
665 } else if (aAttr != bAttr) {
666 diag.attachNote(a->getLoc())
667 << "first operation has attribute '" << attrName.getValue()
668 << "' with value " << prettyPrint(aAttr);
669 diag.attachNote(b->getLoc())
670 << "second operation has value " << prettyPrint(bAttr);
671 return failure();
672 }
673 seenAttrs.insert(attrName);
674 }
675 if (aDict.getValue().size() != bDict.getValue().size()) {
676 for (auto namedAttr : bDict) {
677 auto attrName = namedAttr.getName();
678 // Skip the attribute if we don't care about this particular one or it
679 // is one that is known to be in both dictionaries.
680 if (nonessentialAttributes.contains(attrName) ||
681 seenAttrs.contains(attrName))
682 continue;
683 // We have found an attribute that is only in the second operation.
684 diag.attachNote(a->getLoc())
685 << "first operation is missing attribute " << attrName;
686 diag.attachNote(b->getLoc()) << "second operation here";
687 return failure();
688 }
689 }
690 return success();
691 }
692
693 // NOLINTNEXTLINE(misc-no-recursion)
694 LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
695 FInstanceLike b) {
696 auto aName = a.getReferencedModuleNameAttr();
697 auto bName = b.getReferencedModuleNameAttr();
698 if (aName == bName)
699 return success();
700
701 // If the modules instantiate are different we will want to know why the
702 // sub module did not dedupliate. This code recursively checks the child
703 // module.
704 auto aModule = instanceGraph.lookup(aName)->getModule();
705 auto bModule = instanceGraph.lookup(bName)->getModule();
706 // Create a new error for the submodule.
707 diag.attachNote(std::nullopt)
708 << "in instance " << a.getInstanceNameAttr() << " of " << aName
709 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
710 check(diag, aModule, bModule);
711 return failure();
712 }
713
714 // NOLINTNEXTLINE(misc-no-recursion)
715 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
716 Operation *b) {
717 // Operation name.
718 if (a->getName() != b->getName()) {
719 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
720 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
721 return failure();
722 }
723
724 // If its an instance operaiton, perform some checking and possibly
725 // recurse.
726 if (auto aInst = dyn_cast<FInstanceLike>(a)) {
727 auto bInst = cast<FInstanceLike>(b);
728 if (failed(check(diag, aInst, bInst)))
729 return failure();
730 }
731
732 // Operation results.
733 if (a->getNumResults() != b->getNumResults()) {
734 diag.attachNote(a->getLoc())
735 << "operations have different number of results";
736 diag.attachNote(b->getLoc()) << "second operation here";
737 return failure();
738 }
739 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
740 auto &aValue = std::get<0>(resultPair);
741 auto &bValue = std::get<1>(resultPair);
742 if (failed(check(diag, "operation result", a, aValue.getType(), b,
743 bValue.getType())))
744 return failure();
745 data.map.map(aValue, bValue);
746 }
747
748 // Operations operands.
749 if (a->getNumOperands() != b->getNumOperands()) {
750 diag.attachNote(a->getLoc())
751 << "operations have different number of operands";
752 diag.attachNote(b->getLoc()) << "second operation here";
753 return failure();
754 }
755 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
756 auto &aValue = std::get<0>(operandPair);
757 auto &bValue = std::get<1>(operandPair);
758 if (bValue != data.map.lookup(aValue)) {
759 diag.attachNote(a->getLoc())
760 << "operations use different operands, first operand is '"
761 << getFieldName(
762 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
763 .first
764 << "'";
765 diag.attachNote(b->getLoc())
766 << "second operand is '"
767 << getFieldName(
768 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
769 .first
770 << "', but should have been '"
771 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
772 /*lookThroughCasts=*/true))
773 .first
774 << "'";
775 return failure();
776 }
777 }
778 data.map.map(a, b);
779
780 // Operation regions.
781 if (a->getNumRegions() != b->getNumRegions()) {
782 diag.attachNote(a->getLoc())
783 << "operations have different number of regions";
784 diag.attachNote(b->getLoc()) << "second operation here";
785 return failure();
786 }
787 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
788 auto &aRegion = std::get<0>(regionPair);
789 auto &bRegion = std::get<1>(regionPair);
790 if (failed(check(diag, data, a, aRegion, b, bRegion)))
791 return failure();
792 }
793
794 // Operation attributes.
795 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
796 b->getAttrDictionary())))
797 return failure();
798 return success();
799 }
800
801 // NOLINTNEXTLINE(misc-no-recursion)
802 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
803 hw::InnerSymbolTable aTable(a);
804 hw::InnerSymbolTable bTable(b);
805 ModuleData data(aTable, bTable);
807 diag.attachNote(a->getLoc()) << "module marked NoDedup";
808 return;
809 }
811 diag.attachNote(b->getLoc()) << "module marked NoDedup";
812 return;
813 }
814 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
815 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
816 if (!canRemoveModule(aSymbol) && !canRemoveModule(bSymbol)) {
817 diag.attachNote(a->getLoc())
818 << "module is "
819 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
820 diag.attachNote(b->getLoc())
821 << "module is "
822 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
823 return;
824 }
825 auto aGroup =
826 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
827 auto bGroup = dyn_cast_or_null<StringAttr>(
828 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
829 if (aGroup != bGroup) {
830 if (bGroup) {
831 diag.attachNote(b->getLoc())
832 << "module is in dedup group '" << bGroup.str() << "'";
833 } else {
834 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
835 }
836 if (aGroup) {
837 diag.attachNote(a->getLoc())
838 << "module is in dedup group '" << aGroup.str() << "'";
839 } else {
840 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
841 }
842 return;
843 }
844 if (failed(check(diag, data, a, b)))
845 return;
846 diag.attachNote(a->getLoc()) << "first module here";
847 diag.attachNote(b->getLoc()) << "second module here";
848 }
849
850 // This is a cached "portDirections" string attr.
852 // This is a cached "NoDedup" annotation class string attr.
853 StringAttr noDedupClass;
854 // This is a cached string attr for the dedup group attribute.
856
857 // This is a set of every attribute we should ignore.
858 DenseSet<Attribute> nonessentialAttributes;
860};
861
862//===----------------------------------------------------------------------===//
863// Deduplication
864//===----------------------------------------------------------------------===//
865
866// Custom location merging. This only keeps track of 8 annotations from ".fir"
867// files, and however many annotations come from "real" sources. When
868// deduplicating, modules tend not to have scala source locators, so we wind
869// up fusing source locators for a module from every copy being deduped. There
870// is little value in this (all the modules are identical by definition).
871static Location mergeLoc(MLIRContext *context, Location to, Location from) {
872 // Unique the set of locations to be fused.
873 llvm::SmallSetVector<Location, 4> decomposedLocs;
874 // only track 8 "fir" locations
875 unsigned seenFIR = 0;
876 for (auto loc : {to, from}) {
877 // If the location is a fused location we decompose it if it has no
878 // metadata or the metadata is the same as the top level metadata.
879 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
880 // UnknownLoc's have already been removed from FusedLocs so we can
881 // simply add all of the internal locations.
882 for (auto loc : fusedLoc.getLocations()) {
883 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
884 if (fileLoc.getFilename().strref().ends_with(".fir")) {
885 ++seenFIR;
886 if (seenFIR > 8)
887 continue;
888 }
889 }
890 decomposedLocs.insert(loc);
891 }
892 continue;
893 }
894
895 // Might need to skip this fir.
896 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
897 if (fileLoc.getFilename().strref().ends_with(".fir")) {
898 ++seenFIR;
899 if (seenFIR > 8)
900 continue;
901 }
902 }
903 // Otherwise, only add known locations to the set.
904 if (!isa<UnknownLoc>(loc))
905 decomposedLocs.insert(loc);
906 }
907
908 auto locs = decomposedLocs.getArrayRef();
909
910 // Handle the simple cases of less than two locations. Ensure the metadata (if
911 // provided) is not dropped.
912 if (locs.empty())
913 return UnknownLoc::get(context);
914 if (locs.size() == 1)
915 return locs.front();
916
917 return FusedLoc::get(context, locs);
918}
919
920struct Deduper {
921
922 using RenameMap = DenseMap<StringAttr, StringAttr>;
923
925 NLATable *nlaTable, CircuitOp circuit)
926 : context(circuit->getContext()), instanceGraph(instanceGraph),
928 nlaBlock(circuit.getBodyBlock()),
929 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
930 classString(StringAttr::get(context, "class")) {
931 // Populate the NLA cache.
932 for (auto nla : circuit.getOps<hw::HierPathOp>())
933 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
934 }
935
936 /// Remove the "fromModule", and replace all references to it with the
937 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
938 /// module which is not deduplicated needs to be recorded with the `record`
939 /// call.
940 void dedup(FModuleLike toModule, FModuleLike fromModule) {
941 // A map of operation (e.g. wires, nodes) names which are changed, which is
942 // used to update NLAs that reference the "fromModule".
943 RenameMap renameMap;
944
945 // Merge the port locations.
946 SmallVector<Attribute> newLocs;
947 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
948 fromModule.getPortLocations())) {
949 if (toLoc == fromLoc)
950 newLocs.push_back(toLoc);
951 else
952 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
953 cast<LocationAttr>(fromLoc)));
954 }
955 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
956
957 // Merge the two modules.
958 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
959
960 // Rewrite NLAs pathing through these modules to refer to the to module. It
961 // is safe to do this at this point because NLAs cannot be one element long.
962 // This means that all NLAs which require more context cannot be targetting
963 // something in the module it self.
964 if (auto to = dyn_cast<FModuleOp>(*toModule))
965 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
966 else
967 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
968 fromModule.getModuleNameAttr());
969
970 replaceInstances(toModule, fromModule);
971 }
972
973 /// Record the usages of any NLA's in this module, so that we may update the
974 /// annotation if the parent module is deduped with another module.
975 void record(FModuleLike module) {
976 // Record any annotations on the module.
977 recordAnnotations(module);
978 // Record port annotations.
979 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
981 // Record any annotations in the module body.
982 module->walk([&](Operation *op) { recordAnnotations(op); });
983 }
984
985private:
986 /// Get a cached namespace for a module.
988 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
989 .first->second;
990 }
991
992 /// For a specific annotation target, record all the unique NLAs which
993 /// target it in the `targetMap`.
995 for (auto anno : target.getAnnotations())
996 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
997 targetMap[nlaRef.getAttr()].insert(target);
998 }
999
1000 /// Record all targets which use an NLA.
1001 void recordAnnotations(Operation *op) {
1002 // Record annotations.
1003 recordAnnotations(OpAnnoTarget(op));
1004
1005 // Record port annotations only if this is a mem operation.
1006 auto mem = dyn_cast<MemOp>(op);
1007 if (!mem)
1008 return;
1009
1010 // Record port annotations.
1011 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1012 recordAnnotations(PortAnnoTarget(mem, i));
1013 }
1014
1015 /// This deletes and replaces all instances of the "fromModule" with instances
1016 /// of the "toModule".
1017 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1018 // Replace all instances of the other module.
1019 auto *fromNode =
1020 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1021 auto *toNode = instanceGraph[toModule];
1022 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1023 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1024 auto inst = oldInstRec->getInstance();
1025 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1026 instOp.setModuleNameAttr(toModuleRef);
1027 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1028 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1029 auto classLike = cast<ClassLike>(*toNode->getModule());
1030 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1031 objectOp.getResult().setType(classType);
1032 }
1033 oldInstRec->getParent()->addInstance(inst, toNode);
1034 oldInstRec->erase();
1035 }
1036 instanceGraph.erase(fromNode);
1037 fromModule->erase();
1038 }
1039
1040 /// Look up the instantiations of the `from` module and create an NLA for each
1041 /// one, appending the baseNamepath to each NLA. This is used to add more
1042 /// context to an already existing NLA. The `fromModule` is used to indicate
1043 /// which module the annotation is coming from before the merge, and will be
1044 /// used to create the namepaths.
1045 SmallVector<FlatSymbolRefAttr>
1046 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1047 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1048 // Create an attribute array with a placeholder in the first element, where
1049 // the root refence of the NLA will be inserted.
1050 SmallVector<Attribute> namepath = {nullptr};
1051 namepath.append(baseNamepath.begin(), baseNamepath.end());
1052
1053 auto loc = fromModule->getLoc();
1054 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1055 SmallVector<FlatSymbolRefAttr> nlas;
1056 for (auto *instanceRecord : fromNode->uses()) {
1057 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1058 auto inst = instanceRecord->getInstance();
1059 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1060 auto arrayAttr = ArrayAttr::get(context, namepath);
1061 // Check the NLA cache to see if we already have this NLA.
1062 auto &cacheEntry = nlaCache[arrayAttr];
1063 if (!cacheEntry) {
1064 auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1065 loc, "nla", arrayAttr);
1066 // Insert it into the symbol table to get a unique name.
1067 symbolTable.insert(nla);
1068 // Store it in the cache.
1069 cacheEntry = nla.getNameAttr();
1070 nla.setVisibility(vis);
1071 nlaTable->addNLA(nla);
1072 }
1073 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1074 nlas.push_back(nlaRef);
1075 }
1076 return nlas;
1077 }
1078
1079 /// Look up the instantiations of this module and create an NLA for each one.
1080 /// This returns an array of symbol references which can be used to reference
1081 /// the NLAs.
1082 SmallVector<FlatSymbolRefAttr>
1083 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1084 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1085 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1086 }
1087
1088 /// Clone the annotation for each NLA in a list. The attribute list should
1089 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1090 /// should be the index of this field.
1091 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1092 Annotation anno, ArrayRef<NamedAttribute> attributes,
1093 unsigned nonLocalIndex,
1094 SmallVectorImpl<Annotation> &newAnnotations) {
1095 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1096 attributes.end());
1097 for (auto &nla : nlas) {
1098 // Add the new annotation.
1099 mutableAttributes[nonLocalIndex].setValue(nla);
1100 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1101 // The original annotation records if its a subannotation.
1102 anno.setDict(dict);
1103 newAnnotations.push_back(anno);
1104 }
1105 }
1106
1107 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1108 /// but it does not delete the NLA reference from the target operation's
1109 /// annotations.
1110 void eraseNLA(hw::HierPathOp nla) {
1111 // Erase the NLA from the leaf module's nlaMap.
1112 targetMap.erase(nla.getNameAttr());
1113 nlaTable->erase(nla);
1114 nlaCache.erase(nla.getNamepathAttr());
1115 symbolTable.erase(nla);
1116 }
1117
1118 /// Process all NLAs referencing the "from" module to point to the "to"
1119 /// module. This is used after merging two modules together.
1120 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1121 FModuleOp fromModule) {
1122 auto toName = toModule.getNameAttr();
1123 auto fromName = fromModule.getNameAttr();
1124 // Create a copy of the current NLAs. We will be pushing and removing
1125 // NLAs from this op as we go.
1126 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1127 // Change the NLA to target the toModule.
1128 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1129 // Now we walk the NLA searching for ones that require more context to be
1130 // added.
1131 for (auto nla : moduleNLAs) {
1132 auto elements = nla.getNamepath().getValue();
1133 // If we don't need to add more context, we're done here.
1134 if (nla.root() != toName)
1135 continue;
1136 // Create the replacement NLAs.
1137 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1138 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1139 // Copy out the targets, because we will be updating the map.
1140 auto &set = targetMap[nla.getSymNameAttr()];
1141 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1142 // Replace the uses of the old NLA with the new NLAs.
1143 for (auto target : targets) {
1144 // We have to clone any annotation which uses the old NLA for each new
1145 // NLA. This array collects the new set of annotations.
1146 SmallVector<Annotation> newAnnotations;
1147 for (auto anno : target.getAnnotations()) {
1148 // Find the non-local field of the annotation.
1149 auto [it, found] = mlir::impl::findAttrSorted(
1150 anno.begin(), anno.end(), nonLocalString);
1151 // If this annotation doesn't use the target NLA, copy it with no
1152 // changes.
1153 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1154 nla.getSymNameAttr()) {
1155 newAnnotations.push_back(anno);
1156 continue;
1157 }
1158 auto nonLocalIndex = std::distance(anno.begin(), it);
1159 // Clone the annotation and add it to the list of new annotations.
1160 cloneAnnotation(nlaRefs, anno,
1161 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1162 nonLocalIndex, newAnnotations);
1163 }
1164
1165 // Apply the new annotations to the operation.
1166 AnnotationSet annotations(newAnnotations, context);
1167 target.setAnnotations(annotations);
1168 // Record that target uses the NLA.
1169 for (auto nla : nlaRefs)
1170 targetMap[nla.getAttr()].insert(target);
1171 }
1172
1173 // Erase the old NLA and remove it from all breadcrumbs.
1174 eraseNLA(nla);
1175 }
1176 }
1177
1178 /// Process all the NLAs that the two modules participate in, replacing
1179 /// references to the "from" module with references to the "to" module, and
1180 /// adding more context if necessary.
1181 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1182 FModuleOp fromModule) {
1183 addAnnotationContext(renameMap, toModule, toModule);
1184 addAnnotationContext(renameMap, toModule, fromModule);
1185 }
1186
1187 // Update all NLAs which the "from" external module participates in to the
1188 // "toName".
1189 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1190 StringAttr fromName) {
1191 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1192 }
1193
1194 /// Take an annotation, and update it to be a non-local annotation. If the
1195 /// annotation is already non-local and has enough context, it will be skipped
1196 /// for now. Return true if the annotation was made non-local.
1197 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1198 FModuleLike fromModule, Annotation anno,
1199 SmallVectorImpl<Annotation> &newAnnotations) {
1200 // Start constructing a new annotation, pushing a "circt.nonLocal" field
1201 // into the correct spot if its not already a non-local annotation.
1202 SmallVector<NamedAttribute> attributes;
1203 int nonLocalIndex = -1;
1204 for (const auto &val : llvm::enumerate(anno)) {
1205 auto attr = val.value();
1206 // Is this field "circt.nonlocal"?
1207 auto compare = attr.getName().compare(nonLocalString);
1208 assert(compare != 0 && "should not pass non-local annotations here");
1209 if (compare == 1) {
1210 // This annotation definitely does not have "circt.nonlocal" field. Push
1211 // an empty place holder for the non-local annotation.
1212 nonLocalIndex = val.index();
1213 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1214 break;
1215 }
1216 // Otherwise push the current attribute and keep searching for the
1217 // "circt.nonlocal" field.
1218 attributes.push_back(attr);
1219 }
1220 if (nonLocalIndex == -1) {
1221 // Push an empty "circt.nonlocal" field to the last slot.
1222 nonLocalIndex = attributes.size();
1223 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1224 } else {
1225 // Copy the remaining annotation fields in.
1226 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1227 }
1228
1229 // Construct the NLAs if we don't have any yet.
1230 auto nlaRefs = createNLAs(toModuleName, fromModule);
1231 for (auto nla : nlaRefs)
1232 targetMap[nla.getAttr()].insert(to);
1233
1234 // Clone the annotation for each new NLA.
1235 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1236 return true;
1237 }
1238
1239 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1240 FModuleLike fromModule, AnnotationSet annos,
1241 SmallVectorImpl<Annotation> &newAnnotations,
1242 SmallPtrSetImpl<Attribute> &dontTouches) {
1243 for (auto anno : annos) {
1244 if (anno.isClass(dontTouchAnnoClass)) {
1245 // Remove the nonlocal field of the annotation if it has one, since this
1246 // is a sticky annotation.
1247 anno.removeMember("circt.nonlocal");
1248 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1249 if (inserted)
1250 newAnnotations.push_back(anno);
1251 continue;
1252 }
1253 // If the annotation is already non-local, we add it as is. It is already
1254 // added to the target map.
1255 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1256 newAnnotations.push_back(anno);
1257 targetMap[nla.getAttr()].insert(to);
1258 continue;
1259 }
1260 // Otherwise make the annotation non-local and add it to the set.
1261 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1262 newAnnotations);
1263 }
1264 }
1265
1266 /// Merge the annotations of a specific target, either a operation or a port
1267 /// on an operation.
1268 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1269 AnnotationSet toAnnos, FModuleLike fromModule,
1270 AnnoTarget from, AnnotationSet fromAnnos) {
1271 // This is a list of all the annotations which will be added to `to`.
1272 SmallVector<Annotation> newAnnotations;
1273
1274 // We have special case handling of DontTouch to prevent it from being
1275 // turned into a non-local annotation, and to remove duplicates.
1276 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1277
1278 // Iterate the annotations, transforming most annotations into non-local
1279 // ones.
1280 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1281 dontTouches);
1282 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1283 dontTouches);
1284
1285 // Copy over all the new annotations.
1286 if (!newAnnotations.empty())
1287 to.setAnnotations(AnnotationSet(newAnnotations, context));
1288 }
1289
1290 /// Merge all annotations and port annotations on two operations.
1291 void mergeAnnotations(FModuleLike toModule, Operation *to,
1292 FModuleLike fromModule, Operation *from) {
1293 // Merge op annotations.
1294 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1295 OpAnnoTarget(from), AnnotationSet(from));
1296
1297 // Merge port annotations.
1298 if (toModule == to) {
1299 // Merge module port annotations.
1300 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1301 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1302 AnnotationSet::forPort(toModule, i), fromModule,
1303 PortAnnoTarget(fromModule, i),
1304 AnnotationSet::forPort(fromModule, i));
1305 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1306 // Merge memory port annotations.
1307 auto fromMem = cast<MemOp>(from);
1308 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1309 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1310 AnnotationSet::forPort(toMem, i), fromModule,
1311 PortAnnoTarget(fromMem, i),
1312 AnnotationSet::forPort(fromMem, i));
1313 }
1314 }
1315
1316 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1317 hw::InnerSymAttr toSym,
1318 hw::InnerSymAttr fromSym) {
1319 if (fromSym && !fromSym.getProps().empty()) {
1320 auto &isn = getNamespace(toModule);
1321 // The properties for the new inner symbol..
1322 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1323 // If the "to" op already has an inner symbol, copy all its properties.
1324 if (toSym)
1325 llvm::append_range(newProps, toSym);
1326 // Add each property from the fromSym to the toSym.
1327 for (auto fromProp : fromSym) {
1328 hw::InnerSymPropertiesAttr newProp;
1329 auto *it = llvm::find_if(newProps, [&](auto p) {
1330 return p.getFieldID() == fromProp.getFieldID();
1331 });
1332 if (it != newProps.end()) {
1333 // If we already have an inner sym with the same field id, use
1334 // that.
1335 newProp = *it;
1336 // If the old symbol is public, we need to make the new one public.
1337 if (fromProp.getSymVisibility().getValue() == "public" &&
1338 newProp.getSymVisibility().getValue() != "public") {
1339 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1340 newProp.getFieldID(),
1341 fromProp.getSymVisibility());
1342 }
1343 } else {
1344 // We need to add a new property to the inner symbol for this field.
1345 auto newName = isn.newName(fromProp.getName().getValue());
1346 newProp = hw::InnerSymPropertiesAttr::get(
1347 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1348 fromProp.getSymVisibility());
1349 newProps.push_back(newProp);
1350 }
1351 renameMap[fromProp.getName()] = newProp.getName();
1352 }
1353 // Sort the fields by field id.
1354 llvm::sort(newProps, [](auto &p, auto &q) {
1355 return p.getFieldID() < q.getFieldID();
1356 });
1357 // Return the merged inner symbol.
1358 return hw::InnerSymAttr::get(context, newProps);
1359 }
1360 return hw::InnerSymAttr();
1361 }
1362
1363 // Record the symbol name change of the operation or any of its ports when
1364 // merging two operations. The renamed symbols are used to update the
1365 // target of any NLAs. This will add symbols to the "to" operation if needed.
1366 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1367 Operation *to, FModuleLike fromModule,
1368 Operation *from) {
1369 // If the "from" operation has an inner_sym, we need to make sure the
1370 // "to" operation also has an `inner_sym` and then record the renaming.
1371 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1372 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1373 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1374 toInnerSym.getInnerSymAttr(),
1375 fromInnerSym.getInnerSymAttr()))
1376 toInnerSym.setInnerSymbolAttr(newSymAttr);
1377 }
1378
1379 // If there are no port symbols on the "from" operation, we are done here.
1380 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1381 if (!fromPortSyms || fromPortSyms.empty())
1382 return;
1383 // We have to map each "fromPort" to each "toPort".
1384 auto portCount = fromPortSyms.size();
1385 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1386
1387 // Create an array of new port symbols for the "to" operation, copy in the
1388 // old symbols if it has any, create an empty symbol array if it doesn't.
1389 SmallVector<Attribute> newPortSyms;
1390 if (toPortSyms.empty())
1391 newPortSyms.assign(portCount, hw::InnerSymAttr());
1392 else
1393 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1394
1395 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1396 if (auto newPortSym = mergeInnerSymbols(
1397 renameMap, toModule,
1398 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1399 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1400 newPortSyms[portNo] = newPortSym;
1401 }
1402 }
1403
1404 // Commit the new symbol attribute.
1405 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1406 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1407 }
1408
1409 /// Recursively merge two operations.
1410 // NOLINTNEXTLINE(misc-no-recursion)
1411 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1412 FModuleLike fromModule, Operation *from) {
1413 // Merge the operation locations.
1414 if (to->getLoc() != from->getLoc())
1415 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1416
1417 // Recurse into any regions.
1418 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1419 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1420 std::get<1>(regions));
1421
1422 // Record any inner_sym renamings that happened.
1423 recordSymRenames(renameMap, toModule, to, fromModule, from);
1424
1425 // Merge the annotations.
1426 mergeAnnotations(toModule, to, fromModule, from);
1427 }
1428
1429 /// Recursively merge two blocks.
1430 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1431 FModuleLike fromModule, Block &fromBlock) {
1432 // Merge the block locations.
1433 for (auto [toArg, fromArg] :
1434 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1435 if (toArg.getLoc() != fromArg.getLoc())
1436 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1437
1438 for (auto ops : llvm::zip(toBlock, fromBlock))
1439 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1440 &std::get<1>(ops));
1441 }
1442
1443 // Recursively merge two regions.
1444 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1445 Region &toRegion, FModuleLike fromModule,
1446 Region &fromRegion) {
1447 for (auto blocks : llvm::zip(toRegion, fromRegion))
1448 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1449 std::get<1>(blocks));
1450 }
1451
1452 MLIRContext *context;
1454 SymbolTable &symbolTable;
1455
1456 /// Cached nla table analysis.
1457 NLATable *nlaTable = nullptr;
1458
1459 /// We insert all NLAs to the beginning of this block.
1460 Block *nlaBlock;
1461
1462 // This maps an NLA to the operations and ports that uses it.
1463 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1464
1465 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1466 // of the NLA's path to the name of the NLA which contains it.
1467 DenseMap<Attribute, Attribute> nlaCache;
1468
1469 // Cached attributes for faster comparisons and attribute building.
1470 StringAttr nonLocalString;
1471 StringAttr classString;
1472
1473 /// A module namespace cache.
1474 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1475};
1476
1477//===----------------------------------------------------------------------===//
1478// Fixup
1479//===----------------------------------------------------------------------===//
1480
1481/// This fixes up ClassLikes with ClassType ports, when the classes have
1482/// deduped. For each ClassType port, if the object reference being assigned is
1483/// a different type, update the port type. Returns true if the ClassOp was
1484/// updated and the associated ObjectOps should be updated.
1485bool fixupClassOp(ClassOp classOp) {
1486 // New port type attributes, if necessary.
1487 SmallVector<Attribute> newPortTypes;
1488 bool anyDifferences = false;
1489
1490 // Check each port.
1491 for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1492 // Check if this port is a ClassType. If not, save the original type
1493 // attribute in case we need to update port types.
1494 auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1495 if (!portClassType) {
1496 newPortTypes.push_back(classOp.getPortTypeAttr(i));
1497 continue;
1498 }
1499
1500 // Check if this port is assigned a reference of a different ClassType.
1501 Type newPortClassType;
1502 BlockArgument portArg = classOp.getArgument(i);
1503 for (auto &use : portArg.getUses()) {
1504 if (auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1505 Type sourceType = propassign.getSrc().getType();
1506 if (propassign.getDest() == use.get() && sourceType != portClassType) {
1507 // Double check that all references are the same new type.
1508 if (newPortClassType) {
1509 assert(newPortClassType == sourceType &&
1510 "expected all references to be of the same type");
1511 continue;
1512 }
1513
1514 newPortClassType = sourceType;
1515 }
1516 }
1517 }
1518
1519 // If there was no difference, save the original type attribute in case we
1520 // need to update port types and move along.
1521 if (!newPortClassType) {
1522 newPortTypes.push_back(classOp.getPortTypeAttr(i));
1523 continue;
1524 }
1525
1526 // The port type changed, so update the block argument, save the new port
1527 // type attribute, and indicate there was a difference.
1528 classOp.getArgument(i).setType(newPortClassType);
1529 newPortTypes.push_back(TypeAttr::get(newPortClassType));
1530 anyDifferences = true;
1531 }
1532
1533 // If necessary, update port types.
1534 if (anyDifferences)
1535 classOp.setPortTypes(newPortTypes);
1536
1537 return anyDifferences;
1538}
1539
1540/// This fixes up ObjectOps when the signature of their ClassOp changes. This
1541/// amounts to updating the ObjectOp result type to match the newly updated
1542/// ClassOp type.
1543void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) {
1544 objectOp.getResult().setType(newClassType);
1545}
1546
1547/// This fixes up connects when the field names of a bundle type changes. It
1548/// finds all fields which were previously bulk connected and legalizes it
1549/// into a connect for each field.
1550void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1551 // If the types already match we can emit a connect.
1552 auto dstType = dst.getType();
1553 auto srcType = src.getType();
1554 if (dstType == srcType) {
1555 emitConnect(builder, dst, src);
1556 return;
1557 }
1558 // It must be a bundle type and the field name has changed. We have to
1559 // manually decompose the bulk connect into a connect for each field.
1560 auto dstBundle = type_cast<BundleType>(dstType);
1561 auto srcBundle = type_cast<BundleType>(srcType);
1562 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1563 auto dstField = builder.create<SubfieldOp>(dst, i);
1564 auto srcField = builder.create<SubfieldOp>(src, i);
1565 if (dstBundle.getElement(i).isFlip) {
1566 std::swap(srcBundle, dstBundle);
1567 std::swap(srcField, dstField);
1568 }
1569 fixupConnect(builder, dstField, srcField);
1570 }
1571}
1572
1573/// This is the root method to fixup module references when a module changes.
1574/// It matches all the results of "to" module with the results of the "from"
1575/// module.
1576void fixupAllModules(InstanceGraph &instanceGraph) {
1577 for (auto *node : instanceGraph) {
1578 auto module = cast<FModuleLike>(*node->getModule());
1579
1580 // Handle class declarations here.
1581 bool shouldFixupObjects = false;
1582 auto classOp = dyn_cast<ClassOp>(module.getOperation());
1583 if (classOp)
1584 shouldFixupObjects = fixupClassOp(classOp);
1585
1586 for (auto *instRec : node->uses()) {
1587 // Handle object instantiations here.
1588 if (classOp) {
1589 if (shouldFixupObjects) {
1590 fixupObjectOp(instRec->getInstance<ObjectOp>(),
1591 classOp.getInstanceType());
1592 }
1593 continue;
1594 }
1595
1596 auto inst = instRec->getInstance<InstanceOp>();
1597 // Only handle module instantiations here.
1598 if (!inst)
1599 continue;
1600 ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1601 builder.setInsertionPointAfter(inst);
1602 for (size_t i = 0, e = getNumPorts(module); i < e; ++i) {
1603 auto result = inst.getResult(i);
1604 auto newType = module.getPortType(i);
1605 auto oldType = result.getType();
1606 // If the type has not changed, we don't have to fix up anything.
1607 if (newType == oldType)
1608 continue;
1609 // If the type changed we transform it back to the old type with an
1610 // intermediate wire.
1611 auto wire =
1612 builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1613 result.replaceAllUsesWith(wire);
1614 result.setType(newType);
1615 if (inst.getPortDirection(i) == Direction::Out)
1616 fixupConnect(builder, wire, result);
1617 else
1618 fixupConnect(builder, result, wire);
1619 }
1620 }
1621 }
1622}
1623
1624namespace llvm {
1625/// A DenseMapInfo implementation for `ModuleInfo` that is a pair of
1626/// llvm::SHA256 hashes, which are represented as std::array<uint8_t, 32>, and
1627/// an array of string attributes. This allows us to create a DenseMap with
1628/// `ModuleInfo` as keys.
1629template <>
1630struct DenseMapInfo<ModuleInfo> {
1631 static inline ModuleInfo getEmptyKey() {
1632 std::array<uint8_t, 32> key;
1633 std::fill(key.begin(), key.end(), ~0);
1634 return {key, {}};
1635 }
1636
1637 static inline ModuleInfo getTombstoneKey() {
1638 std::array<uint8_t, 32> key;
1639 std::fill(key.begin(), key.end(), ~0 - 1);
1640 return {key, {}};
1641 }
1642
1643 static unsigned getHashValue(const ModuleInfo &val) {
1644 // We assume SHA256 is already a good hash and just truncate down to the
1645 // number of bytes we need for DenseMap.
1646 unsigned hash;
1647 std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));
1648
1649 // Combine module names.
1650 return llvm::hash_combine(
1651 hash, llvm::hash_combine_range(val.referredModuleNames.begin(),
1652 val.referredModuleNames.end()));
1653 }
1654
1655 static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
1656 return lhs == rhs;
1657 }
1658};
1659} // namespace llvm
1660
1661//===----------------------------------------------------------------------===//
1662// DedupPass
1663//===----------------------------------------------------------------------===//
1664
1665namespace {
1666class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1667 void runOnOperation() override {
1668 auto *context = &getContext();
1669 auto circuit = getOperation();
1670 auto &instanceGraph = getAnalysis<InstanceGraph>();
1671 auto *nlaTable = &getAnalysis<NLATable>();
1672 auto &symbolTable = getAnalysis<SymbolTable>();
1673 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1674 Equivalence equiv(context, instanceGraph);
1675 auto anythingChanged = false;
1676
1677 // Modules annotated with this should not be considered for deduplication.
1678 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1679
1680 // Only modules within the same group may be deduplicated.
1681 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1682
1683 // A map of all the module moduleInfo that we have calculated so far.
1684 llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1685
1686 // We track the name of the module that each module is deduped into, so that
1687 // we can make sure all modules which are marked "must dedup" with each
1688 // other were all deduped to the same module.
1689 DenseMap<Attribute, StringAttr> dedupMap;
1690
1691 // We must iterate the modules from the bottom up so that we can properly
1692 // deduplicate the modules. We copy the list of modules into a vector first
1693 // to avoid iterator invalidation while we mutate the instance graph.
1694 SmallVector<FModuleLike, 0> modules(
1695 llvm::map_range(llvm::post_order(&instanceGraph), [](auto *node) {
1696 return cast<FModuleLike>(*node->getModule());
1697 }));
1698
1699 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1700 StructuralHasherSharedConstants hasherConstants(&getContext());
1701
1702 // Attribute name used to store dedup_group for this pass.
1703 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1704
1705 // Move dedup group annotations to attributes on the module.
1706 // This results in the desired behavior (included in hash),
1707 // and avoids unnecessary processing of these as annotations
1708 // that need to be tracked, made non-local, so on.
1709 for (auto module : modules) {
1710 llvm::SmallSetVector<StringAttr, 1> groups;
1712 module, [&groups, dedupGroupClass](Annotation annotation) {
1713 if (annotation.getClassAttr() != dedupGroupClass)
1714 return false;
1715 groups.insert(annotation.getMember<StringAttr>("group"));
1716 return true;
1717 });
1718 if (groups.size() > 1) {
1719 module.emitError("module belongs to multiple dedup groups: ") << groups;
1720 return signalPassFailure();
1721 }
1722 assert(!module->hasAttr(dedupGroupAttrName) &&
1723 "unexpected existing use of temporary dedup group attribute");
1724 if (!groups.empty())
1725 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1726 }
1727
1728 // Calculate module information parallelly.
1729 auto result = mlir::failableParallelForEach(
1730 context, llvm::seq(modules.size()), [&](unsigned idx) {
1731 auto module = modules[idx];
1732 // If the module is marked with NoDedup, just skip it.
1733 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1734 return success();
1735
1736 // Only dedup extmodule's with defname.
1737 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1738 ext && !ext.getDefname().has_value())
1739 return success();
1740
1741 StructuralHasher hasher(hasherConstants);
1742 // Calculate the hash of the module and referred module names.
1743 moduleInfos[idx] = hasher.getModuleInfo(module);
1744 return success();
1745 });
1746
1747 if (result.failed())
1748 return signalPassFailure();
1749
1750 for (auto [i, module] : llvm::enumerate(modules)) {
1751 auto moduleName = module.getModuleNameAttr();
1752 auto &maybeModuleInfo = moduleInfos[i];
1753 // If the hash was not calculated, we need to skip it.
1754 if (!maybeModuleInfo) {
1755 // We record it in the dedup map to help detect errors when the user
1756 // marks the module as both NoDedup and MustDedup. We do not record this
1757 // module in the hasher to make sure no other module dedups "into" this
1758 // one.
1759 dedupMap[moduleName] = moduleName;
1760 continue;
1761 }
1762
1763 auto &moduleInfo = maybeModuleInfo.value();
1764
1765 // Replace module names referred in the module with new names.
1766 for (auto &referredModule : moduleInfo.referredModuleNames)
1767 referredModule = dedupMap[referredModule];
1768
1769 // Check if there a module with the same hash.
1770 auto it = moduleInfoToModule.find(moduleInfo);
1771 if (it != moduleInfoToModule.end()) {
1772 auto original = cast<FModuleLike>(it->second);
1773 auto originalName = original.getModuleNameAttr();
1774
1775 // If the current module is public, and the original is private, we
1776 // want to dedup the private module into the public one.
1777 if (!canRemoveModule(module)) {
1778 // If both modules are public, then we can't dedup anything.
1779 if (!canRemoveModule(original))
1780 continue;
1781 // Swap the canonical module in the dedup map.
1782 for (auto &[originalName, dedupedName] : dedupMap)
1783 if (dedupedName == originalName)
1784 dedupedName = moduleName;
1785 // Update the module hash table to point to the new original, so all
1786 // future modules dedup with the new canonical module.
1787 it->second = module;
1788 // Swap the locals.
1789 std::swap(originalName, moduleName);
1790 std::swap(original, module);
1791 }
1792
1793 // Record the group ID of the other module.
1794 dedupMap[moduleName] = originalName;
1795 deduper.dedup(original, module);
1796 ++erasedModules;
1797 anythingChanged = true;
1798 continue;
1799 }
1800 // Any module not deduplicated must be recorded.
1801 deduper.record(module);
1802 // Add the module to a new dedup group.
1803 dedupMap[moduleName] = moduleName;
1804 // Record the module info.
1805 moduleInfoToModule[std::move(moduleInfo)] = module;
1806 }
1807
1808 // This part verifies that all modules marked by "MustDedup" have been
1809 // properly deduped with each other. For this check to succeed, all modules
1810 // have to been deduped to the same module. It is possible that a module was
1811 // deduped with the wrong thing.
1812
1813 auto failed = false;
1814 // This parses the module name out of a target string.
1815 auto parseModule = [&](Attribute path) -> StringAttr {
1816 // Each module is listed as a target "~Circuit|Module" which we have to
1817 // parse.
1818 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1819 return StringAttr::get(context, rhs);
1820 };
1821 // This gets the name of the module which the current module was deduped
1822 // with. If the named module isn't in the map, then we didn't encounter it
1823 // in the circuit.
1824 auto getLead = [&](StringAttr module) -> StringAttr {
1825 auto it = dedupMap.find(module);
1826 if (it == dedupMap.end()) {
1827 auto diag = emitError(circuit.getLoc(),
1828 "MustDeduplicateAnnotation references module ")
1829 << module << " which does not exist";
1830 failed = true;
1831 return nullptr;
1832 }
1833 return it->second;
1834 };
1835
1836 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1837 if (!annotation.isClass(mustDedupAnnoClass))
1838 return false;
1839 auto modules = annotation.getMember<ArrayAttr>("modules");
1840 if (!modules) {
1841 emitError(circuit.getLoc(),
1842 "MustDeduplicateAnnotation missing \"modules\" member");
1843 failed = true;
1844 return false;
1845 }
1846 // Empty module list has nothing to process.
1847 if (modules.empty())
1848 return true;
1849 // Get the first element.
1850 auto firstModule = parseModule(modules[0]);
1851 auto firstLead = getLead(firstModule);
1852 if (!firstLead)
1853 return false;
1854 // Verify that the remaining elements are all the same as the first.
1855 for (auto attr : modules.getValue().drop_front()) {
1856 auto nextModule = parseModule(attr);
1857 auto nextLead = getLead(nextModule);
1858 if (!nextLead)
1859 return false;
1860 if (firstLead != nextLead) {
1861 auto diag = emitError(circuit.getLoc(), "module ")
1862 << nextModule << " not deduplicated with " << firstModule;
1863 auto a = instanceGraph.lookup(firstLead)->getModule();
1864 auto b = instanceGraph.lookup(nextLead)->getModule();
1865 equiv.check(diag, a, b);
1866 failed = true;
1867 return false;
1868 }
1869 }
1870 return true;
1871 });
1872 if (failed)
1873 return signalPassFailure();
1874
1875 // Remove all dedup group attributes, they only exist during this pass.
1876 for (auto module : circuit.getOps<FModuleLike>())
1877 module->removeDiscardableAttr(dedupGroupAttrName);
1878
1879 // Walk all the modules and fixup the instance operation to return the
1880 // correct type. We delay this fixup until the end because doing it early
1881 // can block the deduplication of the parent modules.
1882 fixupAllModules(instanceGraph);
1883
1884 markAnalysesPreserved<NLATable>();
1885 if (!anythingChanged)
1886 markAllAnalysesPreserved();
1887 }
1888};
1889} // end anonymous namespace
1890
1891std::unique_ptr<mlir::Pass> circt::firrtl::createDedupPass() {
1892 return std::make_unique<DedupPass>();
1893}
assert(baseType &&"element must be base type")
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:871
void fixupObjectOp(ObjectOp objectOp, ClassType newClassType)
This fixes up ObjectOps when the signature of their ClassOp changes.
Definition Dedup.cpp:1543
llvm::raw_ostream & printHash(llvm::raw_ostream &stream, llvm::SHA256 &data)
Definition Dedup.cpp:79
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:51
void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition Dedup.cpp:1550
void fixupAllModules(InstanceGraph &instanceGraph)
This is the root method to fixup module references when a module changes.
Definition Dedup.cpp:1576
bool fixupClassOp(ClassOp classOp)
This fixes up ClassLikes with ClassType ports, when the classes have deduped.
Definition Dedup.cpp:1485
llvm::raw_ostream & printHex(llvm::raw_ostream &stream, ArrayRef< uint8_t > bytes)
Definition Dedup.cpp:73
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
Definition Dedup.cpp:1891
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1046
MLIRContext * context
Definition Dedup.cpp:1452
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1460
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1001
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1110
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1291
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1017
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:975
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1189
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1444
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:940
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1181
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1083
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:994
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1457
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1316
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1091
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1366
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
Definition Dedup.cpp:1268
StringAttr nonLocalString
Definition Dedup.cpp:1470
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:987
SymbolTable & symbolTable
Definition Dedup.cpp:1454
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1411
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1474
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1197
InstanceGraph & instanceGraph
Definition Dedup.cpp:1453
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1430
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1463
StringAttr classString
Definition Dedup.cpp:1471
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1239
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:924
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1120
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:922
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1467
const hw::InnerSymbolTable & a
Definition Dedup.cpp:408
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:405
const hw::InnerSymbolTable & b
Definition Dedup.cpp:409
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:387
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:858
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:412
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition Dedup.cpp:694
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:483
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:457
StringAttr noDedupClass
Definition Dedup.cpp:853
StringAttr dedupGroupAttrName
Definition Dedup.cpp:855
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:606
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:561
StringAttr portDirectionsAttr
Definition Dedup.cpp:851
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:428
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:715
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:388
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:581
InstanceGraph & instanceGraph
Definition Dedup.cpp:859
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:802
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:98
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:96
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:108
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:135
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:109
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:154
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:259
void update(Type type)
Definition Dedup.cpp:239
void update(const void *pointer)
Definition Dedup.cpp:211
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:361
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:222
void update(Operation *op)
Definition Dedup.cpp:333
llvm::SHA256 sha
Definition Dedup.cpp:375
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:365
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:142
void update(size_t value)
Definition Dedup.cpp:216
void update(BundleType type)
Definition Dedup.cpp:230
unsigned getID(void *object)
Definition Dedup.cpp:173
void update(OpResult result)
Definition Dedup.cpp:245
void update(OpOperand &operand)
Definition Dedup.cpp:194
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:139
void update(Region *region)
Definition Dedup.cpp:325
void update(Block *block)
Definition Dedup.cpp:314
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:368
void update(TypeID typeID)
Definition Dedup.cpp:227
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:371
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:190
unsigned finalizeID(void *object)
Definition Dedup.cpp:181
void update(mlir::OperationName name)
Definition Dedup.cpp:308
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
Definition Dedup.cpp:1631
static ModuleInfo getTombstoneKey()
Definition Dedup.cpp:1637
static unsigned getHashValue(const ModuleInfo &val)
Definition Dedup.cpp:1643
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)
Definition Dedup.cpp:1655