Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
23#include "circt/Support/LLVM.h"
24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/ADT/DenseMap.h"
28#include "llvm/ADT/DenseMapInfo.h"
29#include "llvm/ADT/Hashing.h"
30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/ADT/SmallPtrSet.h"
32#include "llvm/Support/Format.h"
33#include "llvm/Support/SHA256.h"
34
35namespace circt {
36namespace firrtl {
37#define GEN_PASS_DEF_DEDUP
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
39} // namespace firrtl
40} // namespace circt
41
42using namespace circt;
43using namespace firrtl;
44using hw::InnerRefAttr;
45
46//===----------------------------------------------------------------------===//
47// Utility function for classifying a Symbol's dedup-ability.
48//===----------------------------------------------------------------------===//
49
50/// Returns true if the module can be removed.
51static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
52 // If the symbol is not private, it cannot be removed.
53 if (!symbol.isPrivate())
54 return false;
55 // Classes may be referenced in object types, so can not normally be removed
56 // if we can't find any symbol uses. Since we know that dedup will update the
57 // types of instances appropriately, we can ignore that and return true here.
58 if (isa<ClassLike>(*symbol))
59 return true;
60 // If module can not be removed even if no uses can be found, we can not
61 // delete it. The implication is that there are hidden symbol uses that dedup
62 // will not properly update.
63 if (!symbol.canDiscardOnUseEmpty())
64 return false;
65 // The module can be deleted.
66 return true;
67}
68
69//===----------------------------------------------------------------------===//
70// Hashing
71//===----------------------------------------------------------------------===//
72
73llvm::raw_ostream &printHex(llvm::raw_ostream &stream,
74 ArrayRef<uint8_t> bytes) {
75 // Print the hash on a single line.
76 return stream << format_bytes(bytes, std::nullopt, 32) << "\n";
77}
78
79llvm::raw_ostream &printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
80 return printHex(stream, data.result());
81}
82
83llvm::raw_ostream &printHash(llvm::raw_ostream &stream, std::string data) {
84 ArrayRef<uint8_t> bytes(reinterpret_cast<const uint8_t *>(data.c_str()),
85 data.length());
86 return printHex(stream, bytes);
87}
88
89// This struct contains information to determine module uniqueness. The
90// first element is a structural hash of the module, and the second element is
91// an array which tracks module names encountered in the walk. Since module
92// names could be replaced during dedup, it's necessary to keep names up-to-date
93// before actually combining them into structural hashes.
// Field order matters: StructuralHasher::getModuleInfo builds this with
// aggregate initialization ({hash, names}).
94struct ModuleInfo {
95 // SHA256 hash.
96 std::array<uint8_t, 32> structuralHash;
97 // Names of modules referenced by instance ops and of classes referenced by
// object ops in the module, in the order the hasher encountered them.
98 std::vector<StringAttr> referredModuleNames;
99};
100
101static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
102 return lhs.structuralHash == rhs.structuralHash &&
104}
105
106/// This struct contains constant string attributes shared across different
107/// threads.
109 explicit StructuralHasherSharedConstants(MLIRContext *context) {
110 portTypesAttr = StringAttr::get(context, "portTypes");
111 moduleNameAttr = StringAttr::get(context, "moduleName");
112 portNamesAttr = StringAttr::get(context, "portNames");
113 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
114 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
115 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
116 nonessentialAttributes.insert(StringAttr::get(context, "name"));
117 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
118 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
119 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
120 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
121 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
122 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
123 };
124
125 // This is a cached "portTypes" string attr.
// The hasher hashes the port types listed under this name one by one rather
// than hashing the array attribute itself.
126 StringAttr portTypesAttr;
127
128 // This is a cached "moduleName" string attr.
// On instance ops, the value under this name is recorded by name instead of
// being hashed, because dedup may rename the referenced module.
129 StringAttr moduleNameAttr;
130
131 // This is a cached "portNames" string attr.
// Although listed as nonessential, port names are still hashed for classes.
132 StringAttr portNamesAttr;
133
134 // This is a set of every attribute we should ignore.
// Membership is tested by attribute name when walking an op's attribute
// dictionary.
135 DenseSet<Attribute> nonessentialAttributes;
136};
137
140 : constants(constants) {}
141
142 ModuleInfo getModuleInfo(FModuleLike module) {
144 update(&(*module));
145 return {sha.final(), std::move(referredModuleNames)};
146 }
147
148private:
149 /// Find all the ports and operations which may define an inner symbol
150 /// operations and give each a unique id. If the port/operation does define
151 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
152 /// field id. When we hash (local) references to this inner symbol, we will
153 /// hash in the id and the the field id.
154 void populateInnerSymIDTable(FModuleLike module) {
155 // Add port symbols. If no port has a symbol defined, the port symbol array
156 // will be totally empty.
157 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
158 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
159 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
160 }
161 // Add operation symbols.
162 size_t index = module.getNumPorts();
163 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
164 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
165 for (auto prop : innerSym)
166 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
167 }
168 ++index;
169 });
170 }
171
172 // Get the identifier for an object. The identifier is assigned on first use.
173 unsigned getID(void *object) {
174 auto [it, inserted] = idTable.try_emplace(object, nextID);
175 if (inserted)
176 ++nextID;
177 return it->second;
178 }
179
180 // Get the identifier for an IR object. Free the ID, too.
181 unsigned finalizeID(void *object) {
182 auto it = idTable.find(object);
183 if (it == idTable.end())
184 return nextID++;
185 auto id = it->second;
186 idTable.erase(it);
187 return id;
188 }
189
190 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
191 return innerSymIDTable.at(name);
192 }
193
194 void update(OpOperand &operand) {
195 auto value = operand.get();
196 if (auto result = dyn_cast<OpResult>(value)) {
197 auto *op = result.getOwner();
198 update(getID(op));
199 update(result.getResultNumber());
200 return;
201 }
202 if (auto argument = dyn_cast<BlockArgument>(value)) {
203 auto *block = argument.getOwner();
204 update(getID(block));
205 update(argument.getArgNumber());
206 return;
207 }
208 llvm_unreachable("Unknown value type");
209 }
210
211 void update(const void *pointer) {
212 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
213 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
214 }
215
216 void update(size_t value) {
217 auto *addr = reinterpret_cast<const uint8_t *>(&value);
218 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
219 }
220
/// Hash both members of \p pair, first then second.
template <typename T, typename U>
void update(const std::pair<T, U> &pair) {
  const auto &[first, second] = pair;
  update(first);
  update(second);
}
226
227 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
228
229 // NOLINTNEXTLINE(misc-no-recursion)
230 void update(BundleType type) {
231 update(type.getTypeID());
232 for (auto &element : type.getElements()) {
233 update(element.isFlip);
234 update(element.type);
235 }
236 }
237
238 // NOLINTNEXTLINE(misc-no-recursion)
239 void update(Type type) {
240 if (auto bundle = type_dyn_cast<BundleType>(type))
241 return update(bundle);
242 update(type.getAsOpaquePointer());
243 }
244
245 void update(OpResult result) {
246 // Like instance ops, don't use object ops' result types since they might be
247 // replaced by dedup. Record the class names and lazily combine their hashes
248 // using the same mechanism as instances and modules.
249 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
250 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
251 return;
252 }
253
254 update(result.getType());
255 }
256
257 /// Hash the top level attribute dictionary of the operation. This function
258 /// has special handling for inner symbols, ports, and referenced modules.
///
/// Nonessential attributes are skipped entirely, except "portNames" on
/// classes. For every other attribute the (interned) name is hashed, followed
/// by a value-dependent encoding:
///  - "portTypes": each port type, hashed via update(Type);
///  - "moduleName" on instance ops: nothing; the name is recorded in
///    referredModuleNames instead;
///  - DistinctAttr values: nothing (see the TODO below);
///  - inner references: the (index, field ID) of the referenced symbol;
///  - anything else: the attribute's opaque storage pointer.
259 void update(Operation *op, DictionaryAttr dict) {
260 for (auto namedAttr : dict) {
261 auto name = namedAttr.getName();
262 auto value = namedAttr.getValue();
263 // Skip names and annotations, except in certain cases.
264 // Names of ports are load bearing for classes, so we do hash those.
265 bool isClassPortNames =
266 isa<ClassLike>(op) && name == constants.portNamesAttr;
267 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
268 continue;
269
270 // Hash the attribute name (an interned pointer).
271 update(name.getAsOpaquePointer());
272
273 // Hash the port types.
274 if (name == constants.portTypesAttr) {
275 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
276 for (auto type : portTypes)
277 update(type);
278 continue;
279 }
280
281 // For instance op, don't use `moduleName` attributes since they might be
282 // replaced by dedup. Record the names and lazily combine their hashes.
283 // It is assumed that module names are hashed only through instance ops;
284 // it could cause suboptimal results if there was other operation that
285 // refers to module names through essential attributes.
286 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
287 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
288 continue;
289 }
290
291 // TODO: properly handle DistinctAttr, including its use in paths.
292 // See https://github.com/llvm/circt/issues/6583.
293 if (isa<DistinctAttr>(value))
294 continue;
295
296 // If this is an inner symbol reference, we need to perform name erasure:
// hash the target's (index, field ID) instead of its name. Only the symbol
// name is looked up (the module part of the InnerRefAttr is ignored), and
// getInnerSymID requires the name to be registered for this module.
297 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
298 update(getInnerSymID(innerRef.getName()));
299 continue;
300 }
301
302 // We don't need to handle this attribute specially, so hash its unique
303 // address.
304 update(value.getAsOpaquePointer());
305 }
306 }
307
308 void update(mlir::OperationName name) {
309 // Operation names are interned.
310 update(name.getAsOpaquePointer());
311 }
312
313 // NOLINTNEXTLINE(misc-no-recursion)
314 void update(Block *block) {
315 for (auto &op : llvm::reverse(*block))
316 update(&op);
317 for (auto type : block->getArgumentTypes())
318 update(type);
319 update(finalizeID(block));
320 update(position);
321 ++position;
322 }
323
324 // NOLINTNEXTLINE(misc-no-recursion)
325 void update(Region *region) {
326 for (auto &block : llvm::reverse(region->getBlocks()))
327 update(&block);
328 update(position);
329 ++position;
330 }
331
/// Hash one operation together with everything nested inside it, in this
/// order:
///  1. the regions;
///  2. the operation name;
///  3. the operands;
///  4. the attribute dictionary;
///  5. the result types;
///  6. the operation's own identifier and the running position.
///
/// Blocks are walked from their last operation to their first, so users are
/// normally hashed before their definitions. getID() hands out an identifier
/// when an operand is hashed, and finalizeID() below retrieves (and frees) it.
332 // NOLINTNEXTLINE(misc-no-recursion)
333 void update(Operation *op) {
334 // Hash the regions. We need to make sure an empty region doesn't hash the
335 // same as no region, so we include the number of regions.
336 update(op->getNumRegions());
337 for (auto &region : reverse(op->getRegions()))
338 update(&region);
339
340 update(op->getName());
341
342 // Hash each operand as (defining op/block identifier, result/arg number).
343 for (auto &operand : op->getOpOperands())
344 update(operand);
345
346 // Hash the essential attributes. Inner symbol references are resolved
347 // through innerSymIDTable (see populateInnerSymIDTable).
348 update(op, op->getAttrDictionary());
349
350 // Record any op results (types).
351 for (auto result : op->getResults())
352 update(result);
353
354 // Hash this op's identifier: the one its users already saw, or a fresh one
// if nothing used it. Then mix in and advance the running position.
355 update(finalizeID(op));
356 update(position);
357 ++position;
358 }
359
360 // A map from an operation/block to its identifier. Entries are created on
// first use by getID() and removed again by finalizeID().
361 DenseMap<void *, unsigned> idTable;
// The next identifier to hand out; it only ever increases.
362 unsigned nextID = 0;
363
364 // A map from an inner symbol name to the (port/operation index, field ID)
// pair assigned by populateInnerSymIDTable().
365 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
366
367 // Names of modules referenced by instance ops and of classes referenced by
// object ops, in the order they were encountered while hashing.
368 std::vector<StringAttr> referredModuleNames;
369
370 // String constants.
372
373 // This is the actual running hash calculation. This is a stateful element
374 // that should be reinitialized after each hash is produced.
375 llvm::SHA256 sha;
376
377 // A running counter that is mixed into the hash and then incremented after
// each operation, block, and region is hashed.
378 size_t position = 0;
379};
380
381//===----------------------------------------------------------------------===//
382// Equivalence
383//===----------------------------------------------------------------------===//
384
385/// This class is for reporting differences between two modules which should
386/// have been deduplicated.
390 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
391 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
392 portDirectionsAttr = StringAttr::get(context, "portDirections");
393 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
394 nonessentialAttributes.insert(StringAttr::get(context, "name"));
395 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
396 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
397 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
398 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
399 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
400 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
401 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
402 }
403
411
412 std::string prettyPrint(Attribute attr) {
413 SmallString<64> buffer;
414 llvm::raw_svector_ostream os(buffer);
415 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
416 os << "0x";
417 if (integerAttr.getType().isSignlessInteger())
418 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
419 else
420 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
421
422 } else
423 os << attr;
424 return std::string(buffer);
425 }
426
427 // NOLINTNEXTLINE(misc-no-recursion)
428 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
429 Operation *a, BundleType aType, Operation *b,
430 BundleType bType) {
431 if (aType.getNumElements() != bType.getNumElements()) {
432 diag.attachNote(a->getLoc())
433 << message << " bundle type has different number of elements";
434 diag.attachNote(b->getLoc()) << "second operation here";
435 return failure();
436 }
437
438 for (auto elementPair :
439 llvm::zip(aType.getElements(), bType.getElements())) {
440 auto aElement = std::get<0>(elementPair);
441 auto bElement = std::get<1>(elementPair);
442 if (aElement.isFlip != bElement.isFlip) {
443 diag.attachNote(a->getLoc()) << message << " bundle element "
444 << aElement.name << " flip does not match";
445 diag.attachNote(b->getLoc()) << "second operation here";
446 return failure();
447 }
448
449 if (failed(check(diag,
450 "bundle element \'" + aElement.name.getValue() + "'", a,
451 aElement.type, b, bElement.type)))
452 return failure();
453 }
454 return success();
455 }
456
457 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
458 Operation *a, Type aType, Operation *b, Type bType) {
459 if (aType == bType)
460 return success();
461 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
462 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
463 return check(diag, message, a, aBundleType, b, bBundleType);
464 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
465 aType != bType) {
466 diag.attachNote(a->getLoc())
467 << message << ", has a RefType with a different base type "
468 << type_cast<RefType>(aType).getType()
469 << " in the same position of the two modules marked as 'must dedup'. "
470 "(This may be due to Grand Central Taps or Views being different "
471 "between the two modules.)";
472 diag.attachNote(b->getLoc())
473 << "the second module has a different base type "
474 << type_cast<RefType>(bType).getType();
475 return failure();
476 }
477 diag.attachNote(a->getLoc())
478 << message << " types don't match, first type is " << aType;
479 diag.attachNote(b->getLoc()) << "second type is " << bType;
480 return failure();
481 }
482
483 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
484 Block &aBlock, Operation *b, Block &bBlock) {
485
486 // Block argument types.
487 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
488 auto portNo = 0;
489 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
490 Operation *opDoesNotExist) {
491 StringRef portName;
492 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
493 if (portNames)
494 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
495 portName = portNameAttr.getValue();
496 if (type_isa<RefType>(existsVal.getType())) {
497 diag.attachNote(opExists->getLoc())
498 << " contains a RefType port named '" + portName +
499 "' that only exists in one of the modules (can be due to "
500 "difference in Grand Central Tap or View of two modules "
501 "marked with must dedup)";
502 diag.attachNote(opDoesNotExist->getLoc())
503 << "second module to be deduped that does not have the RefType "
504 "port";
505 } else {
506 diag.attachNote(opExists->getLoc())
507 << "port '" + portName + "' only exists in one of the modules";
508 diag.attachNote(opDoesNotExist->getLoc())
509 << "second module to be deduped that does not have the port";
510 }
511 return failure();
512 };
513
514 for (auto argPair :
515 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
516 auto &aArg = std::get<0>(argPair);
517 auto &bArg = std::get<1>(argPair);
518 if (aArg.has_value() && bArg.has_value()) {
519 // TODO: we should print the port number if there are no port names, but
520 // there are always port names ;).
521 StringRef portName;
522 if (portNames) {
523 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
524 portName = portNameAttr.getValue();
525 }
526 // Assumption here that block arguments correspond to ports.
527 if (failed(check(diag, "module port '" + portName + "'", a,
528 aArg->getType(), b, bArg->getType())))
529 return failure();
530 data.map.map(aArg.value(), bArg.value());
531 portNo++;
532 continue;
533 }
534 if (!aArg.has_value())
535 std::swap(a, b);
536 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
537 b);
538 }
539
540 // Blocks operations.
541 auto aIt = aBlock.begin();
542 auto aEnd = aBlock.end();
543 auto bIt = bBlock.begin();
544 auto bEnd = bBlock.end();
545 while (aIt != aEnd && bIt != bEnd)
546 if (failed(check(diag, data, &*aIt++, &*bIt++)))
547 return failure();
548 if (aIt != aEnd) {
549 diag.attachNote(aIt->getLoc()) << "first block has more operations";
550 diag.attachNote(b->getLoc()) << "second block here";
551 return failure();
552 }
553 if (bIt != bEnd) {
554 diag.attachNote(bIt->getLoc()) << "second block has more operations";
555 diag.attachNote(a->getLoc()) << "first block here";
556 return failure();
557 }
558 return success();
559 }
560
561 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
562 Region &aRegion, Operation *b, Region &bRegion) {
563 auto aIt = aRegion.begin();
564 auto aEnd = aRegion.end();
565 auto bIt = bRegion.begin();
566 auto bEnd = bRegion.end();
567
568 // Region blocks.
569 while (aIt != aEnd && bIt != bEnd)
570 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
571 return failure();
572 if (aIt != aEnd || bIt != bEnd) {
573 diag.attachNote(a->getLoc())
574 << "operation regions have different number of blocks";
575 diag.attachNote(b->getLoc()) << "second operation here";
576 return failure();
577 }
578 return success();
579 }
580
581 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
582 mlir::DenseBoolArrayAttr aAttr, Operation *b,
583 mlir::DenseBoolArrayAttr bAttr) {
584 if (aAttr == bAttr)
585 return success();
586 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
587 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
588 auto aDirection = aAttr[i];
589 auto bDirection = bAttr[i];
590 if (aDirection != bDirection) {
591 auto &note = diag.attachNote(a->getLoc()) << "module port ";
592 if (portNames)
593 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
594 else
595 note << i;
596 note << " directions don't match, first direction is '"
597 << direction::toString(aDirection) << "'";
598 diag.attachNote(b->getLoc()) << "second direction is '"
599 << direction::toString(bDirection) << "'";
600 return failure();
601 }
602 }
603 return success();
604 }
605
606 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
607 DictionaryAttr aDict, Operation *b,
608 DictionaryAttr bDict) {
609 // Fast path.
610 if (aDict == bDict)
611 return success();
612
613 DenseSet<Attribute> seenAttrs;
614 for (auto namedAttr : aDict) {
615 auto attrName = namedAttr.getName();
616 if (nonessentialAttributes.contains(attrName))
617 continue;
618
619 auto aAttr = namedAttr.getValue();
620 auto bAttr = bDict.get(attrName);
621 if (!bAttr) {
622 diag.attachNote(a->getLoc())
623 << "second operation is missing attribute " << attrName;
624 diag.attachNote(b->getLoc()) << "second operation here";
625 return diag;
626 }
627
628 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
629 auto bRef = cast<hw::InnerRefAttr>(bAttr);
630 auto aRef = cast<hw::InnerRefAttr>(aAttr);
631 // See if they are pointing at the same operation or port.
632 auto aTarget = data.a.lookup(aRef.getName());
633 auto bTarget = data.b.lookup(bRef.getName());
634 if (!aTarget || !bTarget)
635 diag.attachNote(a->getLoc())
636 << "malformed ir, possibly violating use-before-def";
637 auto error = [&]() {
638 diag.attachNote(a->getLoc())
639 << "operations have different targets, first operation has "
640 << aTarget;
641 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
642 return failure();
643 };
644 if (aTarget.isPort()) {
645 // If they are targeting ports, make sure its the same port number.
646 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
647 return error();
648 } else {
649 // Otherwise make sure that they are targeting the same operation.
650 if (!bTarget.isOpOnly() ||
651 aTarget.getOp() != data.map.lookup(bTarget.getOp()))
652 return error();
653 }
654 if (aTarget.getField() != bTarget.getField())
655 return error();
656 } else if (attrName == portDirectionsAttr) {
657 // Special handling for the port directions attribute for better
658 // error messages.
659 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
660 cast<mlir::DenseBoolArrayAttr>(bAttr))))
661 return failure();
662 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
663 // TODO: properly handle DistinctAttr, including its use in paths.
664 // See https://github.com/llvm/circt/issues/6583
665 } else if (aAttr != bAttr) {
666 diag.attachNote(a->getLoc())
667 << "first operation has attribute '" << attrName.getValue()
668 << "' with value " << prettyPrint(aAttr);
669 diag.attachNote(b->getLoc())
670 << "second operation has value " << prettyPrint(bAttr);
671 return failure();
672 }
673 seenAttrs.insert(attrName);
674 }
675 if (aDict.getValue().size() != bDict.getValue().size()) {
676 for (auto namedAttr : bDict) {
677 auto attrName = namedAttr.getName();
678 // Skip the attribute if we don't care about this particular one or it
679 // is one that is known to be in both dictionaries.
680 if (nonessentialAttributes.contains(attrName) ||
681 seenAttrs.contains(attrName))
682 continue;
683 // We have found an attribute that is only in the second operation.
684 diag.attachNote(a->getLoc())
685 << "first operation is missing attribute " << attrName;
686 diag.attachNote(b->getLoc()) << "second operation here";
687 return failure();
688 }
689 }
690 return success();
691 }
692
693 // NOLINTNEXTLINE(misc-no-recursion)
694 LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
695 FInstanceLike b) {
696 auto aName = a.getReferencedModuleNameAttr();
697 auto bName = b.getReferencedModuleNameAttr();
698 if (aName == bName)
699 return success();
700
701 // If the modules instantiate are different we will want to know why the
702 // sub module did not dedupliate. This code recursively checks the child
703 // module.
704 auto aModule = instanceGraph.lookup(aName)->getModule();
705 auto bModule = instanceGraph.lookup(bName)->getModule();
706 // Create a new error for the submodule.
707 diag.attachNote(std::nullopt)
708 << "in instance " << a.getInstanceNameAttr() << " of " << aName
709 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
710 check(diag, aModule, bModule);
711 return failure();
712 }
713
/// Compare two operations and report the first difference through \p diag.
///
/// The checks run in this order:
///  1. the operation name;
///  2. for instances, the referenced modules (recursively);
///  3. the result count and result types;
///  4. the operands;
///  5. the regions;
///  6. the attributes.
///
/// Results are mapped (first -> second) in data.map before operands are
/// compared. The operation pair itself is mapped before regions and
/// attributes are checked, so later inner-reference comparisons can resolve
/// it.
714 // NOLINTNEXTLINE(misc-no-recursion)
715 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
716 Operation *b) {
717 // Operation name.
718 if (a->getName() != b->getName()) {
719 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
720 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
721 return failure();
722 }
723
724 // If it's an instance operation, perform some checking and possibly
725 // recurse.
726 if (auto aInst = dyn_cast<FInstanceLike>(a)) {
727 auto bInst = cast<FInstanceLike>(b);
728 if (failed(check(diag, aInst, bInst)))
729 return failure();
730 }
731
732 // Operation results.
733 if (a->getNumResults() != b->getNumResults()) {
734 diag.attachNote(a->getLoc())
735 << "operations have different number of results";
736 diag.attachNote(b->getLoc()) << "second operation here";
737 return failure();
738 }
739 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
740 auto &aValue = std::get<0>(resultPair);
741 auto &bValue = std::get<1>(resultPair);
742 if (failed(check(diag, "operation result", a, aValue.getType(), b,
743 bValue.getType())))
744 return failure();
// Record the correspondence so later uses of these results can be checked.
745 data.map.map(aValue, bValue);
746 }
747
748 // Operations operands.
749 if (a->getNumOperands() != b->getNumOperands()) {
750 diag.attachNote(a->getLoc())
751 << "operations have different number of operands";
752 diag.attachNote(b->getLoc()) << "second operation here";
753 return failure();
754 }
755 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
756 auto &aValue = std::get<0>(operandPair);
757 auto &bValue = std::get<1>(operandPair);
// Each operand of the second op must be the value that the corresponding
// operand of the first op maps to.
// NOTE(review): IRMapping::lookup asserts when aValue has not been mapped
// yet (e.g. a use that precedes its definition) -- confirm this cannot
// happen for the IR compared here.
758 if (bValue != data.map.lookup(aValue)) {
759 diag.attachNote(a->getLoc())
760 << "operations use different operands, first operand is '"
761 << getFieldName(
762 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
763 .first
764 << "'";
765 diag.attachNote(b->getLoc())
766 << "second operand is '"
767 << getFieldName(
768 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
769 .first
770 << "', but should have been '"
771 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
772 /*lookThroughCasts=*/true))
773 .first
774 << "'";
775 return failure();
776 }
777 }
// Map the operations themselves; inner-reference checks on attributes look
// operations up in this mapping.
778 data.map.map(a, b);
779
780 // Operation regions.
781 if (a->getNumRegions() != b->getNumRegions()) {
782 diag.attachNote(a->getLoc())
783 << "operations have different number of regions";
784 diag.attachNote(b->getLoc()) << "second operation here";
785 return failure();
786 }
787 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
788 auto &aRegion = std::get<0>(regionPair);
789 auto &bRegion = std::get<1>(regionPair);
790 if (failed(check(diag, data, a, aRegion, b, bRegion)))
791 return failure();
792 }
793
794 // Operation attributes.
795 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
796 b->getAttrDictionary())))
797 return failure();
798 return success();
799 }
800
801 // NOLINTNEXTLINE(misc-no-recursion)
802 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
803 hw::InnerSymbolTable aTable(a);
804 hw::InnerSymbolTable bTable(b);
805 ModuleData data(aTable, bTable);
807 diag.attachNote(a->getLoc()) << "module marked NoDedup";
808 return;
809 }
811 diag.attachNote(b->getLoc()) << "module marked NoDedup";
812 return;
813 }
814 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
815 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
816 if (!canRemoveModule(aSymbol)) {
817 diag.attachNote(a->getLoc())
818 << "module is "
819 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
820 return;
821 }
822 if (!canRemoveModule(bSymbol)) {
823 diag.attachNote(b->getLoc())
824 << "module is "
825 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
826 return;
827 }
828 auto aGroup =
829 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
830 auto bGroup = dyn_cast_or_null<StringAttr>(
831 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
832 if (aGroup != bGroup) {
833 if (bGroup) {
834 diag.attachNote(b->getLoc())
835 << "module is in dedup group '" << bGroup.str() << "'";
836 } else {
837 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
838 }
839 if (aGroup) {
840 diag.attachNote(a->getLoc())
841 << "module is in dedup group '" << aGroup.str() << "'";
842 } else {
843 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
844 }
845 return;
846 }
847 if (failed(check(diag, data, a, b)))
848 return;
849 diag.attachNote(a->getLoc()) << "first module here";
850 diag.attachNote(b->getLoc()) << "second module here";
851 }
852
853 // This is a cached "portDirections" string attr.
855 // This is a cached "NoDedup" annotation class string attr.
856 StringAttr noDedupClass;
857 // This is a cached string attr for the dedup group attribute.
859
860 // This is a set of every attribute we should ignore.
861 DenseSet<Attribute> nonessentialAttributes;
863};
864
865//===----------------------------------------------------------------------===//
866// Deduplication
867//===----------------------------------------------------------------------===//
868
869// Custom location merging. This only keeps track of 8 annotations from ".fir"
870// files, and however many annotations come from "real" sources. When
871// deduplicating, modules tend not to have scala source locators, so we wind
872// up fusing source locators for a module from every copy being deduped. There
873// is little value in this (all the modules are identical by definition).
874static Location mergeLoc(MLIRContext *context, Location to, Location from) {
875 // Unique the set of locations to be fused.
876 llvm::SmallSetVector<Location, 4> decomposedLocs;
877 // only track 8 "fir" locations
878 unsigned seenFIR = 0;
879 for (auto loc : {to, from}) {
880 // If the location is a fused location we decompose it if it has no
881 // metadata or the metadata is the same as the top level metadata.
882 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
883 // UnknownLoc's have already been removed from FusedLocs so we can
884 // simply add all of the internal locations.
885 for (auto loc : fusedLoc.getLocations()) {
886 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
887 if (fileLoc.getFilename().strref().ends_with(".fir")) {
888 ++seenFIR;
889 if (seenFIR > 8)
890 continue;
891 }
892 }
893 decomposedLocs.insert(loc);
894 }
895 continue;
896 }
897
898 // Might need to skip this fir.
899 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
900 if (fileLoc.getFilename().strref().ends_with(".fir")) {
901 ++seenFIR;
902 if (seenFIR > 8)
903 continue;
904 }
905 }
906 // Otherwise, only add known locations to the set.
907 if (!isa<UnknownLoc>(loc))
908 decomposedLocs.insert(loc);
909 }
910
911 auto locs = decomposedLocs.getArrayRef();
912
913 // Handle the simple cases of less than two locations. Ensure the metadata (if
914 // provided) is not dropped.
915 if (locs.empty())
916 return UnknownLoc::get(context);
917 if (locs.size() == 1)
918 return locs.front();
919
920 return FusedLoc::get(context, locs);
921}
922
923struct Deduper {
924
925 using RenameMap = DenseMap<StringAttr, StringAttr>;
926
928 NLATable *nlaTable, CircuitOp circuit)
929 : context(circuit->getContext()), instanceGraph(instanceGraph),
931 nlaBlock(circuit.getBodyBlock()),
932 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
933 classString(StringAttr::get(context, "class")) {
934 // Populate the NLA cache.
935 for (auto nla : circuit.getOps<hw::HierPathOp>())
936 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
937 }
938
939 /// Remove the "fromModule", and replace all references to it with the
940 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
941 /// module which is not deduplicated needs to be recorded with the `record`
942 /// call.
943 void dedup(FModuleLike toModule, FModuleLike fromModule) {
944 // A map of operation (e.g. wires, nodes) names which are changed, which is
945 // used to update NLAs that reference the "fromModule".
946 RenameMap renameMap;
947
948 // Merge the port locations.
949 SmallVector<Attribute> newLocs;
950 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
951 fromModule.getPortLocations())) {
952 if (toLoc == fromLoc)
953 newLocs.push_back(toLoc);
954 else
955 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
956 cast<LocationAttr>(fromLoc)));
957 }
958 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
959
960 // Merge the two modules.
961 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
962
963 // Rewrite NLAs pathing through these modules to refer to the to module. It
964 // is safe to do this at this point because NLAs cannot be one element long.
965 // This means that all NLAs which require more context cannot be targetting
966 // something in the module it self.
967 if (auto to = dyn_cast<FModuleOp>(*toModule))
968 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
969 else
970 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
971 fromModule.getModuleNameAttr());
972
973 replaceInstances(toModule, fromModule);
974 }
975
976 /// Record the usages of any NLA's in this module, so that we may update the
977 /// annotation if the parent module is deduped with another module.
978 void record(FModuleLike module) {
979 // Record any annotations on the module.
980 recordAnnotations(module);
981 // Record port annotations.
982 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
984 // Record any annotations in the module body.
985 module->walk([&](Operation *op) { recordAnnotations(op); });
986 }
987
988private:
989 /// Get a cached namespace for a module.
991 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
992 .first->second;
993 }
994
995 /// For a specific annotation target, record all the unique NLAs which
996 /// target it in the `targetMap`.
998 for (auto anno : target.getAnnotations())
999 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1000 targetMap[nlaRef.getAttr()].insert(target);
1001 }
1002
1003 /// Record all targets which use an NLA.
1004 void recordAnnotations(Operation *op) {
1005 // Record annotations.
1006 recordAnnotations(OpAnnoTarget(op));
1007
1008 // Record port annotations only if this is a mem operation.
1009 auto mem = dyn_cast<MemOp>(op);
1010 if (!mem)
1011 return;
1012
1013 // Record port annotations.
1014 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1015 recordAnnotations(PortAnnoTarget(mem, i));
1016 }
1017
1018 /// This deletes and replaces all instances of the "fromModule" with instances
1019 /// of the "toModule".
1020 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1021 // Replace all instances of the other module.
1022 auto *fromNode =
1023 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1024 auto *toNode = instanceGraph[toModule];
1025 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1026 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1027 auto inst = oldInstRec->getInstance();
1028 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1029 instOp.setModuleNameAttr(toModuleRef);
1030 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1031 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1032 auto classLike = cast<ClassLike>(*toNode->getModule());
1033 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1034 objectOp.getResult().setType(classType);
1035 }
1036 oldInstRec->getParent()->addInstance(inst, toNode);
1037 oldInstRec->erase();
1038 }
1039 instanceGraph.erase(fromNode);
1040 fromModule->erase();
1041 }
1042
1043 /// Look up the instantiations of the `from` module and create an NLA for each
1044 /// one, appending the baseNamepath to each NLA. This is used to add more
1045 /// context to an already existing NLA. The `fromModule` is used to indicate
1046 /// which module the annotation is coming from before the merge, and will be
1047 /// used to create the namepaths.
1048 SmallVector<FlatSymbolRefAttr>
1049 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1050 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1051 // Create an attribute array with a placeholder in the first element, where
1052 // the root refence of the NLA will be inserted.
1053 SmallVector<Attribute> namepath = {nullptr};
1054 namepath.append(baseNamepath.begin(), baseNamepath.end());
1055
1056 auto loc = fromModule->getLoc();
1057 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1058 SmallVector<FlatSymbolRefAttr> nlas;
1059 for (auto *instanceRecord : fromNode->uses()) {
1060 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1061 auto inst = instanceRecord->getInstance();
1062 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1063 auto arrayAttr = ArrayAttr::get(context, namepath);
1064 // Check the NLA cache to see if we already have this NLA.
1065 auto &cacheEntry = nlaCache[arrayAttr];
1066 if (!cacheEntry) {
1067 auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1068 loc, "nla", arrayAttr);
1069 // Insert it into the symbol table to get a unique name.
1070 symbolTable.insert(nla);
1071 // Store it in the cache.
1072 cacheEntry = nla.getNameAttr();
1073 nla.setVisibility(vis);
1074 nlaTable->addNLA(nla);
1075 }
1076 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1077 nlas.push_back(nlaRef);
1078 }
1079 return nlas;
1080 }
1081
1082 /// Look up the instantiations of this module and create an NLA for each one.
1083 /// This returns an array of symbol references which can be used to reference
1084 /// the NLAs.
1085 SmallVector<FlatSymbolRefAttr>
1086 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1087 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1088 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1089 }
1090
1091 /// Clone the annotation for each NLA in a list. The attribute list should
1092 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1093 /// should be the index of this field.
1094 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1095 Annotation anno, ArrayRef<NamedAttribute> attributes,
1096 unsigned nonLocalIndex,
1097 SmallVectorImpl<Annotation> &newAnnotations) {
1098 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1099 attributes.end());
1100 for (auto &nla : nlas) {
1101 // Add the new annotation.
1102 mutableAttributes[nonLocalIndex].setValue(nla);
1103 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1104 // The original annotation records if its a subannotation.
1105 anno.setDict(dict);
1106 newAnnotations.push_back(anno);
1107 }
1108 }
1109
1110 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1111 /// but it does not delete the NLA reference from the target operation's
1112 /// annotations.
1113 void eraseNLA(hw::HierPathOp nla) {
1114 // Erase the NLA from the leaf module's nlaMap.
1115 targetMap.erase(nla.getNameAttr());
1116 nlaTable->erase(nla);
1117 nlaCache.erase(nla.getNamepathAttr());
1118 symbolTable.erase(nla);
1119 }
1120
1121 /// Process all NLAs referencing the "from" module to point to the "to"
1122 /// module. This is used after merging two modules together.
1123 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1124 FModuleOp fromModule) {
1125 auto toName = toModule.getNameAttr();
1126 auto fromName = fromModule.getNameAttr();
1127 // Create a copy of the current NLAs. We will be pushing and removing
1128 // NLAs from this op as we go.
1129 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1130 // Change the NLA to target the toModule.
1131 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1132 // Now we walk the NLA searching for ones that require more context to be
1133 // added.
1134 for (auto nla : moduleNLAs) {
1135 auto elements = nla.getNamepath().getValue();
1136 // If we don't need to add more context, we're done here.
1137 if (nla.root() != toName)
1138 continue;
1139 // Create the replacement NLAs.
1140 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1141 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1142 // Copy out the targets, because we will be updating the map.
1143 auto &set = targetMap[nla.getSymNameAttr()];
1144 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1145 // Replace the uses of the old NLA with the new NLAs.
1146 for (auto target : targets) {
1147 // We have to clone any annotation which uses the old NLA for each new
1148 // NLA. This array collects the new set of annotations.
1149 SmallVector<Annotation> newAnnotations;
1150 for (auto anno : target.getAnnotations()) {
1151 // Find the non-local field of the annotation.
1152 auto [it, found] = mlir::impl::findAttrSorted(
1153 anno.begin(), anno.end(), nonLocalString);
1154 // If this annotation doesn't use the target NLA, copy it with no
1155 // changes.
1156 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1157 nla.getSymNameAttr()) {
1158 newAnnotations.push_back(anno);
1159 continue;
1160 }
1161 auto nonLocalIndex = std::distance(anno.begin(), it);
1162 // Clone the annotation and add it to the list of new annotations.
1163 cloneAnnotation(nlaRefs, anno,
1164 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1165 nonLocalIndex, newAnnotations);
1166 }
1167
1168 // Apply the new annotations to the operation.
1169 AnnotationSet annotations(newAnnotations, context);
1170 target.setAnnotations(annotations);
1171 // Record that target uses the NLA.
1172 for (auto nla : nlaRefs)
1173 targetMap[nla.getAttr()].insert(target);
1174 }
1175
1176 // Erase the old NLA and remove it from all breadcrumbs.
1177 eraseNLA(nla);
1178 }
1179 }
1180
1181 /// Process all the NLAs that the two modules participate in, replacing
1182 /// references to the "from" module with references to the "to" module, and
1183 /// adding more context if necessary.
1184 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1185 FModuleOp fromModule) {
1186 addAnnotationContext(renameMap, toModule, toModule);
1187 addAnnotationContext(renameMap, toModule, fromModule);
1188 }
1189
1190 // Update all NLAs which the "from" external module participates in to the
1191 // "toName".
1192 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1193 StringAttr fromName) {
1194 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1195 }
1196
1197 /// Take an annotation, and update it to be a non-local annotation. If the
1198 /// annotation is already non-local and has enough context, it will be skipped
1199 /// for now. Return true if the annotation was made non-local.
1200 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1201 FModuleLike fromModule, Annotation anno,
1202 SmallVectorImpl<Annotation> &newAnnotations) {
1203 // Start constructing a new annotation, pushing a "circt.nonLocal" field
1204 // into the correct spot if its not already a non-local annotation.
1205 SmallVector<NamedAttribute> attributes;
1206 int nonLocalIndex = -1;
1207 for (const auto &val : llvm::enumerate(anno)) {
1208 auto attr = val.value();
1209 // Is this field "circt.nonlocal"?
1210 auto compare = attr.getName().compare(nonLocalString);
1211 assert(compare != 0 && "should not pass non-local annotations here");
1212 if (compare == 1) {
1213 // This annotation definitely does not have "circt.nonlocal" field. Push
1214 // an empty place holder for the non-local annotation.
1215 nonLocalIndex = val.index();
1216 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1217 break;
1218 }
1219 // Otherwise push the current attribute and keep searching for the
1220 // "circt.nonlocal" field.
1221 attributes.push_back(attr);
1222 }
1223 if (nonLocalIndex == -1) {
1224 // Push an empty "circt.nonlocal" field to the last slot.
1225 nonLocalIndex = attributes.size();
1226 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1227 } else {
1228 // Copy the remaining annotation fields in.
1229 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1230 }
1231
1232 // Construct the NLAs if we don't have any yet.
1233 auto nlaRefs = createNLAs(toModuleName, fromModule);
1234 for (auto nla : nlaRefs)
1235 targetMap[nla.getAttr()].insert(to);
1236
1237 // Clone the annotation for each new NLA.
1238 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1239 return true;
1240 }
1241
1242 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1243 FModuleLike fromModule, AnnotationSet annos,
1244 SmallVectorImpl<Annotation> &newAnnotations,
1245 SmallPtrSetImpl<Attribute> &dontTouches) {
1246 for (auto anno : annos) {
1247 if (anno.isClass(dontTouchAnnoClass)) {
1248 // Remove the nonlocal field of the annotation if it has one, since this
1249 // is a sticky annotation.
1250 anno.removeMember("circt.nonlocal");
1251 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1252 if (inserted)
1253 newAnnotations.push_back(anno);
1254 continue;
1255 }
1256 // If the annotation is already non-local, we add it as is. It is already
1257 // added to the target map.
1258 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1259 newAnnotations.push_back(anno);
1260 targetMap[nla.getAttr()].insert(to);
1261 continue;
1262 }
1263 // Otherwise make the annotation non-local and add it to the set.
1264 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1265 newAnnotations);
1266 }
1267 }
1268
1269 /// Merge the annotations of a specific target, either a operation or a port
1270 /// on an operation.
1271 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1272 AnnotationSet toAnnos, FModuleLike fromModule,
1273 AnnoTarget from, AnnotationSet fromAnnos) {
1274 // This is a list of all the annotations which will be added to `to`.
1275 SmallVector<Annotation> newAnnotations;
1276
1277 // We have special case handling of DontTouch to prevent it from being
1278 // turned into a non-local annotation, and to remove duplicates.
1279 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1280
1281 // Iterate the annotations, transforming most annotations into non-local
1282 // ones.
1283 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1284 dontTouches);
1285 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1286 dontTouches);
1287
1288 // Copy over all the new annotations.
1289 if (!newAnnotations.empty())
1290 to.setAnnotations(AnnotationSet(newAnnotations, context));
1291 }
1292
1293 /// Merge all annotations and port annotations on two operations.
1294 void mergeAnnotations(FModuleLike toModule, Operation *to,
1295 FModuleLike fromModule, Operation *from) {
1296 // Merge op annotations.
1297 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1298 OpAnnoTarget(from), AnnotationSet(from));
1299
1300 // Merge port annotations.
1301 if (toModule == to) {
1302 // Merge module port annotations.
1303 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1304 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1305 AnnotationSet::forPort(toModule, i), fromModule,
1306 PortAnnoTarget(fromModule, i),
1307 AnnotationSet::forPort(fromModule, i));
1308 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1309 // Merge memory port annotations.
1310 auto fromMem = cast<MemOp>(from);
1311 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1312 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1313 AnnotationSet::forPort(toMem, i), fromModule,
1314 PortAnnoTarget(fromMem, i),
1315 AnnotationSet::forPort(fromMem, i));
1316 }
1317 }
1318
1319 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1320 hw::InnerSymAttr toSym,
1321 hw::InnerSymAttr fromSym) {
1322 if (fromSym && !fromSym.getProps().empty()) {
1323 auto &isn = getNamespace(toModule);
1324 // The properties for the new inner symbol..
1325 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1326 // If the "to" op already has an inner symbol, copy all its properties.
1327 if (toSym)
1328 llvm::append_range(newProps, toSym);
1329 // Add each property from the fromSym to the toSym.
1330 for (auto fromProp : fromSym) {
1331 hw::InnerSymPropertiesAttr newProp;
1332 auto *it = llvm::find_if(newProps, [&](auto p) {
1333 return p.getFieldID() == fromProp.getFieldID();
1334 });
1335 if (it != newProps.end()) {
1336 // If we already have an inner sym with the same field id, use
1337 // that.
1338 newProp = *it;
1339 // If the old symbol is public, we need to make the new one public.
1340 if (fromProp.getSymVisibility().getValue() == "public" &&
1341 newProp.getSymVisibility().getValue() != "public") {
1342 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1343 newProp.getFieldID(),
1344 fromProp.getSymVisibility());
1345 }
1346 } else {
1347 // We need to add a new property to the inner symbol for this field.
1348 auto newName = isn.newName(fromProp.getName().getValue());
1349 newProp = hw::InnerSymPropertiesAttr::get(
1350 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1351 fromProp.getSymVisibility());
1352 newProps.push_back(newProp);
1353 }
1354 renameMap[fromProp.getName()] = newProp.getName();
1355 }
1356 // Sort the fields by field id.
1357 llvm::sort(newProps, [](auto &p, auto &q) {
1358 return p.getFieldID() < q.getFieldID();
1359 });
1360 // Return the merged inner symbol.
1361 return hw::InnerSymAttr::get(context, newProps);
1362 }
1363 return hw::InnerSymAttr();
1364 }
1365
1366 // Record the symbol name change of the operation or any of its ports when
1367 // merging two operations. The renamed symbols are used to update the
1368 // target of any NLAs. This will add symbols to the "to" operation if needed.
1369 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1370 Operation *to, FModuleLike fromModule,
1371 Operation *from) {
1372 // If the "from" operation has an inner_sym, we need to make sure the
1373 // "to" operation also has an `inner_sym` and then record the renaming.
1374 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1375 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1376 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1377 toInnerSym.getInnerSymAttr(),
1378 fromInnerSym.getInnerSymAttr()))
1379 toInnerSym.setInnerSymbolAttr(newSymAttr);
1380 }
1381
1382 // If there are no port symbols on the "from" operation, we are done here.
1383 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1384 if (!fromPortSyms || fromPortSyms.empty())
1385 return;
1386 // We have to map each "fromPort" to each "toPort".
1387 auto portCount = fromPortSyms.size();
1388 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1389
1390 // Create an array of new port symbols for the "to" operation, copy in the
1391 // old symbols if it has any, create an empty symbol array if it doesn't.
1392 SmallVector<Attribute> newPortSyms;
1393 if (toPortSyms.empty())
1394 newPortSyms.assign(portCount, hw::InnerSymAttr());
1395 else
1396 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1397
1398 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1399 if (auto newPortSym = mergeInnerSymbols(
1400 renameMap, toModule,
1401 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1402 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1403 newPortSyms[portNo] = newPortSym;
1404 }
1405 }
1406
1407 // Commit the new symbol attribute.
1408 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1409 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1410 }
1411
1412 /// Recursively merge two operations.
1413 // NOLINTNEXTLINE(misc-no-recursion)
1414 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1415 FModuleLike fromModule, Operation *from) {
1416 // Merge the operation locations.
1417 if (to->getLoc() != from->getLoc())
1418 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1419
1420 // Recurse into any regions.
1421 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1422 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1423 std::get<1>(regions));
1424
1425 // Record any inner_sym renamings that happened.
1426 recordSymRenames(renameMap, toModule, to, fromModule, from);
1427
1428 // Merge the annotations.
1429 mergeAnnotations(toModule, to, fromModule, from);
1430 }
1431
1432 /// Recursively merge two blocks.
1433 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1434 FModuleLike fromModule, Block &fromBlock) {
1435 // Merge the block locations.
1436 for (auto [toArg, fromArg] :
1437 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1438 if (toArg.getLoc() != fromArg.getLoc())
1439 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1440
1441 for (auto ops : llvm::zip(toBlock, fromBlock))
1442 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1443 &std::get<1>(ops));
1444 }
1445
1446 // Recursively merge two regions.
1447 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1448 Region &toRegion, FModuleLike fromModule,
1449 Region &fromRegion) {
1450 for (auto blocks : llvm::zip(toRegion, fromRegion))
1451 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1452 std::get<1>(blocks));
1453 }
1454
1455 MLIRContext *context;
1457 SymbolTable &symbolTable;
1458
1459 /// Cached nla table analysis.
1460 NLATable *nlaTable = nullptr;
1461
1462 /// We insert all NLAs to the beginning of this block.
1463 Block *nlaBlock;
1464
1465 // This maps an NLA to the operations and ports that uses it.
1466 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1467
1468 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1469 // of the NLA's path to the name of the NLA which contains it.
1470 DenseMap<Attribute, Attribute> nlaCache;
1471
1472 // Cached attributes for faster comparisons and attribute building.
1473 StringAttr nonLocalString;
1474 StringAttr classString;
1475
1476 /// A module namespace cache.
1477 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1478};
1479
1480//===----------------------------------------------------------------------===//
1481// Fixup
1482//===----------------------------------------------------------------------===//
1483
1484/// This fixes up ClassLikes with ClassType ports, when the classes have
1485/// deduped. For each ClassType port, if the object reference being assigned is
1486/// a different type, update the port type. Returns true if the ClassOp was
1487/// updated and the associated ObjectOps should be updated.
1488bool fixupClassOp(ClassOp classOp) {
1489 // New port type attributes, if necessary.
1490 SmallVector<Attribute> newPortTypes;
1491 bool anyDifferences = false;
1492
1493 // Check each port.
1494 for (size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1495 // Check if this port is a ClassType. If not, save the original type
1496 // attribute in case we need to update port types.
1497 auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1498 if (!portClassType) {
1499 newPortTypes.push_back(classOp.getPortTypeAttr(i));
1500 continue;
1501 }
1502
1503 // Check if this port is assigned a reference of a different ClassType.
1504 Type newPortClassType;
1505 BlockArgument portArg = classOp.getArgument(i);
1506 for (auto &use : portArg.getUses()) {
1507 if (auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1508 Type sourceType = propassign.getSrc().getType();
1509 if (propassign.getDest() == use.get() && sourceType != portClassType) {
1510 // Double check that all references are the same new type.
1511 if (newPortClassType) {
1512 assert(newPortClassType == sourceType &&
1513 "expected all references to be of the same type");
1514 continue;
1515 }
1516
1517 newPortClassType = sourceType;
1518 }
1519 }
1520 }
1521
1522 // If there was no difference, save the original type attribute in case we
1523 // need to update port types and move along.
1524 if (!newPortClassType) {
1525 newPortTypes.push_back(classOp.getPortTypeAttr(i));
1526 continue;
1527 }
1528
1529 // The port type changed, so update the block argument, save the new port
1530 // type attribute, and indicate there was a difference.
1531 classOp.getArgument(i).setType(newPortClassType);
1532 newPortTypes.push_back(TypeAttr::get(newPortClassType));
1533 anyDifferences = true;
1534 }
1535
1536 // If necessary, update port types.
1537 if (anyDifferences)
1538 classOp.setPortTypes(newPortTypes);
1539
1540 return anyDifferences;
1541}
1542
1543/// This fixes up ObjectOps when the signature of their ClassOp changes. This
1544/// amounts to updating the ObjectOp result type to match the newly updated
1545/// ClassOp type.
1546void fixupObjectOp(ObjectOp objectOp, ClassType newClassType) {
1547 objectOp.getResult().setType(newClassType);
1548}
1549
1550/// This fixes up connects when the field names of a bundle type changes. It
1551/// finds all fields which were previously bulk connected and legalizes it
1552/// into a connect for each field.
1553void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1554 // If the types already match we can emit a connect.
1555 auto dstType = dst.getType();
1556 auto srcType = src.getType();
1557 if (dstType == srcType) {
1558 emitConnect(builder, dst, src);
1559 return;
1560 }
1561 // It must be a bundle type and the field name has changed. We have to
1562 // manually decompose the bulk connect into a connect for each field.
1563 auto dstBundle = type_cast<BundleType>(dstType);
1564 auto srcBundle = type_cast<BundleType>(srcType);
1565 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1566 auto dstField = builder.create<SubfieldOp>(dst, i);
1567 auto srcField = builder.create<SubfieldOp>(src, i);
1568 if (dstBundle.getElement(i).isFlip) {
1569 std::swap(srcBundle, dstBundle);
1570 std::swap(srcField, dstField);
1571 }
1572 fixupConnect(builder, dstField, srcField);
1573 }
1574}
1575
1576/// This is the root method to fixup module references when a module changes.
1577/// It matches all the results of "to" module with the results of the "from"
1578/// module.
1579void fixupAllModules(InstanceGraph &instanceGraph) {
1580 for (auto *node : instanceGraph) {
1581 auto module = cast<FModuleLike>(*node->getModule());
1582
1583 // Handle class declarations here.
1584 bool shouldFixupObjects = false;
1585 auto classOp = dyn_cast<ClassOp>(module.getOperation());
1586 if (classOp)
1587 shouldFixupObjects = fixupClassOp(classOp);
1588
1589 for (auto *instRec : node->uses()) {
1590 // Handle object instantiations here.
1591 if (classOp) {
1592 if (shouldFixupObjects) {
1593 fixupObjectOp(instRec->getInstance<ObjectOp>(),
1594 classOp.getInstanceType());
1595 }
1596 continue;
1597 }
1598
1599 auto inst = instRec->getInstance<InstanceOp>();
1600 // Only handle module instantiations here.
1601 if (!inst)
1602 continue;
1603 ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1604 builder.setInsertionPointAfter(inst);
1605 for (size_t i = 0, e = getNumPorts(module); i < e; ++i) {
1606 auto result = inst.getResult(i);
1607 auto newType = module.getPortType(i);
1608 auto oldType = result.getType();
1609 // If the type has not changed, we don't have to fix up anything.
1610 if (newType == oldType)
1611 continue;
1612 // If the type changed we transform it back to the old type with an
1613 // intermediate wire.
1614 auto wire =
1615 builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1616 result.replaceAllUsesWith(wire);
1617 result.setType(newType);
1618 if (inst.getPortDirection(i) == Direction::Out)
1619 fixupConnect(builder, wire, result);
1620 else
1621 fixupConnect(builder, result, wire);
1622 }
1623 }
1624 }
1625}
1626
1627namespace llvm {
1628/// A DenseMapInfo implementation for `ModuleInfo` that is a pair of
1629/// llvm::SHA256 hashes, which are represented as std::array<uint8_t, 32>, and
1630/// an array of string attributes. This allows us to create a DenseMap with
1631/// `ModuleInfo` as keys.
1632template <>
1633struct DenseMapInfo<ModuleInfo> {
1634 static inline ModuleInfo getEmptyKey() {
1635 std::array<uint8_t, 32> key;
1636 std::fill(key.begin(), key.end(), ~0);
1637 return {key, {}};
1638 }
1639
1640 static inline ModuleInfo getTombstoneKey() {
1641 std::array<uint8_t, 32> key;
1642 std::fill(key.begin(), key.end(), ~0 - 1);
1643 return {key, {}};
1644 }
1645
1646 static unsigned getHashValue(const ModuleInfo &val) {
1647 // We assume SHA256 is already a good hash and just truncate down to the
1648 // number of bytes we need for DenseMap.
1649 unsigned hash;
1650 std::memcpy(&hash, val.structuralHash.data(), sizeof(unsigned));
1651
1652 // Combine module names.
1653 return llvm::hash_combine(
1654 hash, llvm::hash_combine_range(val.referredModuleNames.begin(),
1655 val.referredModuleNames.end()));
1656 }
1657
1658 static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs) {
1659 return lhs == rhs;
1660 }
1661};
1662} // namespace llvm
1663
1664//===----------------------------------------------------------------------===//
1665// DedupPass
1666//===----------------------------------------------------------------------===//
1667
1668namespace {
1669class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1670 void runOnOperation() override {
1671 auto *context = &getContext();
1672 auto circuit = getOperation();
1673 auto &instanceGraph = getAnalysis<InstanceGraph>();
1674 auto *nlaTable = &getAnalysis<NLATable>();
1675 auto &symbolTable = getAnalysis<SymbolTable>();
1676 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1677 Equivalence equiv(context, instanceGraph);
1678 auto anythingChanged = false;
1679
1680 // Modules annotated with this should not be considered for deduplication.
1681 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1682
1683 // Only modules within the same group may be deduplicated.
1684 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1685
1686 // A map of all the module moduleInfo that we have calculated so far.
1687 llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1688
1689 // We track the name of the module that each module is deduped into, so that
1690 // we can make sure all modules which are marked "must dedup" with each
1691 // other were all deduped to the same module.
1692 DenseMap<Attribute, StringAttr> dedupMap;
1693
1694 // We must iterate the modules from the bottom up so that we can properly
1695 // deduplicate the modules. We copy the list of modules into a vector first
1696 // to avoid iterator invalidation while we mutate the instance graph.
1697 SmallVector<FModuleLike, 0> modules(
1698 llvm::map_range(llvm::post_order(&instanceGraph), [](auto *node) {
1699 return cast<FModuleLike>(*node->getModule());
1700 }));
1701
1702 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1703 StructuralHasherSharedConstants hasherConstants(&getContext());
1704
1705 // Attribute name used to store dedup_group for this pass.
1706 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1707
1708 // Move dedup group annotations to attributes on the module.
1709 // This results in the desired behavior (included in hash),
1710 // and avoids unnecessary processing of these as annotations
1711 // that need to be tracked, made non-local, so on.
1712 for (auto module : modules) {
1713 llvm::SmallSetVector<StringAttr, 1> groups;
1715 module, [&groups, dedupGroupClass](Annotation annotation) {
1716 if (annotation.getClassAttr() != dedupGroupClass)
1717 return false;
1718 groups.insert(annotation.getMember<StringAttr>("group"));
1719 return true;
1720 });
1721 if (groups.size() > 1) {
1722 module.emitError("module belongs to multiple dedup groups: ") << groups;
1723 return signalPassFailure();
1724 }
1725 assert(!module->hasAttr(dedupGroupAttrName) &&
1726 "unexpected existing use of temporary dedup group attribute");
1727 if (!groups.empty())
1728 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1729 }
1730
1731 // Calculate module information parallelly.
1732 auto result = mlir::failableParallelForEach(
1733 context, llvm::seq(modules.size()), [&](unsigned idx) {
1734 auto module = modules[idx];
1735 // If the module is marked with NoDedup, just skip it.
1736 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1737 return success();
1738
1739 // Only dedup extmodule's with defname.
1740 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1741 ext && !ext.getDefname().has_value())
1742 return success();
1743
1744 StructuralHasher hasher(hasherConstants);
1745 // Calculate the hash of the module and referred module names.
1746 moduleInfos[idx] = hasher.getModuleInfo(module);
1747 return success();
1748 });
1749
1750 if (result.failed())
1751 return signalPassFailure();
1752
1753 for (auto [i, module] : llvm::enumerate(modules)) {
1754 auto moduleName = module.getModuleNameAttr();
1755 auto &maybeModuleInfo = moduleInfos[i];
1756 // If the hash was not calculated, we need to skip it.
1757 if (!maybeModuleInfo) {
1758 // We record it in the dedup map to help detect errors when the user
1759 // marks the module as both NoDedup and MustDedup. We do not record this
1760 // module in the hasher to make sure no other module dedups "into" this
1761 // one.
1762 dedupMap[moduleName] = moduleName;
1763 continue;
1764 }
1765
1766 auto &moduleInfo = maybeModuleInfo.value();
1767
1768 // Replace module names referred in the module with new names.
1769 for (auto &referredModule : moduleInfo.referredModuleNames)
1770 referredModule = dedupMap[referredModule];
1771
1772 // Check if there a module with the same hash.
1773 auto it = moduleInfoToModule.find(moduleInfo);
1774 if (it != moduleInfoToModule.end()) {
1775 auto original = cast<FModuleLike>(it->second);
1776 auto originalName = original.getModuleNameAttr();
1777
1778 // If the current module is public, and the original is private, we
1779 // want to dedup the private module into the public one.
1780 if (!canRemoveModule(module)) {
1781 // If both modules are public, then we can't dedup anything.
1782 if (!canRemoveModule(original))
1783 continue;
1784 // Swap the canonical module in the dedup map.
1785 for (auto &[originalName, dedupedName] : dedupMap)
1786 if (dedupedName == originalName)
1787 dedupedName = moduleName;
1788 // Update the module hash table to point to the new original, so all
1789 // future modules dedup with the new canonical module.
1790 it->second = module;
1791 // Swap the locals.
1792 std::swap(originalName, moduleName);
1793 std::swap(original, module);
1794 }
1795
1796 // Record the group ID of the other module.
1797 dedupMap[moduleName] = originalName;
1798 deduper.dedup(original, module);
1799 ++erasedModules;
1800 anythingChanged = true;
1801 continue;
1802 }
1803 // Any module not deduplicated must be recorded.
1804 deduper.record(module);
1805 // Add the module to a new dedup group.
1806 dedupMap[moduleName] = moduleName;
1807 // Record the module info.
1808 moduleInfoToModule[std::move(moduleInfo)] = module;
1809 }
1810
1811 // This part verifies that all modules marked by "MustDedup" have been
1812 // properly deduped with each other. For this check to succeed, all modules
1813 // have to been deduped to the same module. It is possible that a module was
1814 // deduped with the wrong thing.
1815
1816 auto failed = false;
1817 // This parses the module name out of a target string.
1818 auto parseModule = [&](Attribute path) -> StringAttr {
1819 // Each module is listed as a target "~Circuit|Module" which we have to
1820 // parse.
1821 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1822 return StringAttr::get(context, rhs);
1823 };
1824 // This gets the name of the module which the current module was deduped
1825 // with. If the named module isn't in the map, then we didn't encounter it
1826 // in the circuit.
1827 auto getLead = [&](StringAttr module) -> StringAttr {
1828 auto it = dedupMap.find(module);
1829 if (it == dedupMap.end()) {
1830 auto diag = emitError(circuit.getLoc(),
1831 "MustDeduplicateAnnotation references module ")
1832 << module << " which does not exist";
1833 failed = true;
1834 return nullptr;
1835 }
1836 return it->second;
1837 };
1838
1839 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1840 if (!annotation.isClass(mustDedupAnnoClass))
1841 return false;
1842 auto modules = annotation.getMember<ArrayAttr>("modules");
1843 if (!modules) {
1844 emitError(circuit.getLoc(),
1845 "MustDeduplicateAnnotation missing \"modules\" member");
1846 failed = true;
1847 return false;
1848 }
1849 // Empty module list has nothing to process.
1850 if (modules.empty())
1851 return true;
1852 // Get the first element.
1853 auto firstModule = parseModule(modules[0]);
1854 auto firstLead = getLead(firstModule);
1855 if (!firstLead)
1856 return false;
1857 // Verify that the remaining elements are all the same as the first.
1858 for (auto attr : modules.getValue().drop_front()) {
1859 auto nextModule = parseModule(attr);
1860 auto nextLead = getLead(nextModule);
1861 if (!nextLead)
1862 return false;
1863 if (firstLead != nextLead) {
1864 auto diag = emitError(circuit.getLoc(), "module ")
1865 << nextModule << " not deduplicated with " << firstModule;
1866 auto a = instanceGraph.lookup(firstLead)->getModule();
1867 auto b = instanceGraph.lookup(nextLead)->getModule();
1868 equiv.check(diag, a, b);
1869 failed = true;
1870 return false;
1871 }
1872 }
1873 return true;
1874 });
1875 if (failed)
1876 return signalPassFailure();
1877
1878 // Remove all dedup group attributes, they only exist during this pass.
1879 for (auto module : circuit.getOps<FModuleLike>())
1880 module->removeDiscardableAttr(dedupGroupAttrName);
1881
1882 // Walk all the modules and fixup the instance operation to return the
1883 // correct type. We delay this fixup until the end because doing it early
1884 // can block the deduplication of the parent modules.
1885 fixupAllModules(instanceGraph);
1886
1887 markAnalysesPreserved<NLATable>();
1888 if (!anythingChanged)
1889 markAllAnalysesPreserved();
1890 }
1891};
1892} // end anonymous namespace
1893
1894std::unique_ptr<mlir::Pass> circt::firrtl::createDedupPass() {
1895 return std::make_unique<DedupPass>();
1896}
assert(baseType &&"element must be base type")
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:874
void fixupObjectOp(ObjectOp objectOp, ClassType newClassType)
This fixes up ObjectOps when the signature of their ClassOp changes.
Definition Dedup.cpp:1546
llvm::raw_ostream & printHash(llvm::raw_ostream &stream, llvm::SHA256 &data)
Definition Dedup.cpp:79
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:51
void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition Dedup.cpp:1553
void fixupAllModules(InstanceGraph &instanceGraph)
This is the root method to fixup module references when a module changes.
Definition Dedup.cpp:1579
bool fixupClassOp(ClassOp classOp)
This fixes up ClassLikes with ClassType ports, when the classes have been deduped.
Definition Dedup.cpp:1488
llvm::raw_ostream & printHex(llvm::raw_ostream &stream, ArrayRef< uint8_t > bytes)
Definition Dedup.cpp:73
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
Definition Dedup.cpp:1894
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1049
MLIRContext * context
Definition Dedup.cpp:1455
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1463
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1004
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1113
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1294
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1020
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:978
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1192
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1447
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:943
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1184
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1086
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:997
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1460
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1319
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1094
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1369
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
Definition Dedup.cpp:1271
StringAttr nonLocalString
Definition Dedup.cpp:1473
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:990
SymbolTable & symbolTable
Definition Dedup.cpp:1457
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1414
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1477
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1200
InstanceGraph & instanceGraph
Definition Dedup.cpp:1456
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1433
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1466
StringAttr classString
Definition Dedup.cpp:1474
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1242
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:927
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1123
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:925
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1470
const hw::InnerSymbolTable & a
Definition Dedup.cpp:408
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:405
const hw::InnerSymbolTable & b
Definition Dedup.cpp:409
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:387
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:861
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:412
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition Dedup.cpp:694
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:483
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:457
StringAttr noDedupClass
Definition Dedup.cpp:856
StringAttr dedupGroupAttrName
Definition Dedup.cpp:858
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:606
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:561
StringAttr portDirectionsAttr
Definition Dedup.cpp:854
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:428
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:715
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:388
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:581
InstanceGraph & instanceGraph
Definition Dedup.cpp:862
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:802
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:98
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:96
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:108
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:135
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:109
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:154
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:259
void update(Type type)
Definition Dedup.cpp:239
void update(const void *pointer)
Definition Dedup.cpp:211
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:361
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:222
void update(Operation *op)
Definition Dedup.cpp:333
llvm::SHA256 sha
Definition Dedup.cpp:375
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:365
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:142
void update(size_t value)
Definition Dedup.cpp:216
void update(BundleType type)
Definition Dedup.cpp:230
unsigned getID(void *object)
Definition Dedup.cpp:173
void update(OpResult result)
Definition Dedup.cpp:245
void update(OpOperand &operand)
Definition Dedup.cpp:194
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:139
void update(Region *region)
Definition Dedup.cpp:325
void update(Block *block)
Definition Dedup.cpp:314
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:368
void update(TypeID typeID)
Definition Dedup.cpp:227
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:371
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:190
unsigned finalizeID(void *object)
Definition Dedup.cpp:181
void update(mlir::OperationName name)
Definition Dedup.cpp:308
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
Definition Dedup.cpp:1634
static ModuleInfo getTombstoneKey()
Definition Dedup.cpp:1640
static unsigned getHashValue(const ModuleInfo &val)
Definition Dedup.cpp:1646
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)
Definition Dedup.cpp:1658