CIRCT 23.0.0git
Loading...
Searching...
No Matches
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
24#include "circt/Support/Debug.h"
25#include "circt/Support/LLVM.h"
26#include "mlir/IR/IRMapping.h"
27#include "mlir/IR/Threading.h"
28#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/DenseMapInfo.h"
31#include "llvm/ADT/Hashing.h"
32#include "llvm/ADT/PostOrderIterator.h"
33#include "llvm/ADT/SmallPtrSet.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/Format.h"
36#include "llvm/Support/SHA256.h"
37
38#define DEBUG_TYPE "firrtl-dedup"
39
40namespace circt {
41namespace firrtl {
42#define GEN_PASS_DEF_DEDUP
43#include "circt/Dialect/FIRRTL/Passes.h.inc"
44} // namespace firrtl
45} // namespace circt
46
47using namespace circt;
48using namespace firrtl;
49using hw::InnerRefAttr;
50
51//===----------------------------------------------------------------------===//
52// Utility function for classifying a Symbol's dedup-ability.
53//===----------------------------------------------------------------------===//
54
55/// Returns true if the module can be removed.
56static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
57 // If the symbol is not private, it cannot be removed.
58 if (!symbol.isPrivate())
59 return false;
60 // Classes may be referenced in object types, so can not normally be removed
61 // if we can't find any symbol uses. Since we know that dedup will update the
62 // types of instances appropriately, we can ignore that and return true here.
63 if (isa<ClassLike>(*symbol))
64 return true;
65 // If module can not be removed even if no uses can be found, we can not
66 // delete it. The implication is that there are hidden symbol uses that dedup
67 // will not properly update.
68 if (!symbol.canDiscardOnUseEmpty())
69 return false;
70 // The module can be deleted.
71 return true;
72}
73
74//===----------------------------------------------------------------------===//
75// Hashing
76//===----------------------------------------------------------------------===//
77
// This struct contains the information used to determine module uniqueness.
// The first element is a structural hash of the module, and the second element
// is an array which tracks module names encountered in the walk. Since module
// names could be replaced during dedup, it's necessary to keep names up-to-date
// before actually combining them into structural hashes.
struct ModuleInfo {
  // SHA256 digest (32 bytes) of the module structure.
  std::array<uint8_t, 32> structuralHash;
  // Module and class names referred to by the module (through instance-like
  // ops, object ops and class types), in the order they were encountered.
  std::vector<StringAttr> referredModuleNames;
  // The operations that contain references to symbols that may be changed by
  // dedup. These need a fixup pass after dedup. This is just an optimization
  // and does not factor into the hash or the equality between `ModuleInfo`s.
  std::vector<Operation *> symbolSensitiveOps;
};
93
94static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
95 return lhs.structuralHash == rhs.structuralHash &&
97}
98
99/// This struct contains constant string attributes shared across different
100/// threads.
103 portTypesAttr = StringAttr::get(context, "portTypes");
104 moduleNameAttr = StringAttr::get(context, "moduleName");
105 portNamesAttr = StringAttr::get(context, "portNames");
106 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
107 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
108 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
109 nonessentialAttributes.insert(StringAttr::get(context, "name"));
110 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
111 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
112 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
113 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
114 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
115 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
116 };
117
118 // This is a cached "portTypes" string attr.
119 StringAttr portTypesAttr;
120
121 // This is a cached "moduleName" string attr.
122 StringAttr moduleNameAttr;
123
124 // This is a cached "moduleNames" string attr.
125 StringAttr moduleNamesAttr;
126
127 // This is a cached "portNames" string attr.
128 StringAttr portNamesAttr;
129
130 // This is a set of every attribute we should ignore.
131 DenseSet<Attribute> nonessentialAttributes;
132};
133
136 : constants(constants) {}
137
138 ModuleInfo getModuleInfo(FModuleLike module) {
140 update(&(*module));
141 return {sha.final(), std::move(referredModuleNames),
142 std::move(symbolSensitiveOps)};
143 }
144
145private:
146 /// Find all the ports and operations which may define an inner symbol
147 /// operations and give each a unique id. If the port/operation does define
148 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
149 /// field id. When we hash (local) references to this inner symbol, we will
150 /// hash in the id and the the field id.
151 void populateInnerSymIDTable(FModuleLike module) {
152 // Add port symbols. If no port has a symbol defined, the port symbol array
153 // will be totally empty.
154 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
155 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
156 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
157 }
158 // Add operation symbols.
159 size_t index = module.getNumPorts();
160 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
161 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
162 for (auto prop : innerSym)
163 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
164 }
165 ++index;
166 });
167 }
168
169 // Get the identifier for an object. The identifier is assigned on first use.
170 unsigned getID(void *object) {
171 auto [it, inserted] = idTable.try_emplace(object, nextID);
172 if (inserted)
173 ++nextID;
174 return it->second;
175 }
176
177 // Get the identifier for an IR object. Free the ID, too.
178 unsigned finalizeID(void *object) {
179 auto it = idTable.find(object);
180 if (it == idTable.end())
181 return nextID++;
182 auto id = it->second;
183 idTable.erase(it);
184 return id;
185 }
186
187 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
188 return innerSymIDTable.at(name);
189 }
190
191 void update(OpOperand &operand) {
192 auto value = operand.get();
193 if (auto result = dyn_cast<OpResult>(value)) {
194 auto *op = result.getOwner();
195 update(getID(op));
196 update(result.getResultNumber());
197 return;
198 }
199 if (auto argument = dyn_cast<BlockArgument>(value)) {
200 auto *block = argument.getOwner();
201 update(getID(block));
202 update(argument.getArgNumber());
203 return;
204 }
205 llvm_unreachable("Unknown value type");
206 }
207
208 void update(const void *pointer) {
209 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
210 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
211 }
212
213 void update(size_t value) {
214 auto *addr = reinterpret_cast<const uint8_t *>(&value);
215 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
216 }
217
218 template <typename T, typename U>
219 void update(const std::pair<T, U> &pair) {
220 update(pair.first);
221 update(pair.second);
222 }
223
224 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
225
226 // NOLINTNEXTLINE(misc-no-recursion)
227 void update(BundleType type) {
228 update(type.getTypeID());
229 for (auto &element : type.getElements()) {
230 update(element.isFlip);
231 update(element.type);
232 }
233 }
234
235 // NOLINTNEXTLINE(misc-no-recursion)
236 void update(ClassType type) {
237 update(type.getTypeID());
238 // Don't hash the class name directly, since it may be replaced during
239 // dedup. Record the class name instead and lazily combine their hashes
240 // using the same mechanism as instances and modules.
241 hasSeenSymbol = true;
242 referredModuleNames.push_back(type.getNameAttr().getAttr());
243 for (auto &element : type.getElements()) {
244 update(element.name.getAsOpaquePointer());
245 update(element.type);
246 update(static_cast<unsigned>(element.direction));
247 }
248 }
249
250 // NOLINTNEXTLINE(misc-no-recursion)
251 void update(Type type) {
252 if (auto bundle = type_dyn_cast<BundleType>(type))
253 return update(bundle);
254 if (auto klass = type_dyn_cast<ClassType>(type))
255 return update(klass);
256 update(type.getAsOpaquePointer());
257 }
258
259 void update(OpResult result) {
260 // Like instance ops, don't use object ops' result types since they might be
261 // replaced by dedup. Record the class names and lazily combine their hashes
262 // using the same mechanism as instances and modules.
263 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
264 hasSeenSymbol = true;
265 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
266 return;
267 }
268
269 update(result.getType());
270 }
271
272 /// Hash the top level attribute dictionary of the operation. This function
273 /// has special handling for inner symbols, ports, and referenced modules.
274 void update(Operation *op, DictionaryAttr dict) {
275 for (auto namedAttr : dict) {
276 auto name = namedAttr.getName();
277 auto value = namedAttr.getValue();
278
279 // Check whether this attribute contains a nested symbol, just to make
280 // sure we revisit this op after dedup and update any symbols.
281 value.walk([&](FlatSymbolRefAttr) { hasSeenSymbol = true; });
282
283 // Skip names and annotations, except in certain cases.
284 // Names of ports are load bearing for classes, so we do hash those.
285 bool isClassPortNames =
286 isa<ClassLike>(op) && name == constants.portNamesAttr;
287 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
288 continue;
289
290 // Hash the attribute name (an interned pointer).
291 update(name.getAsOpaquePointer());
292
293 // Hash the port types.
294 if (name == constants.portTypesAttr) {
295 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
296 for (auto type : portTypes)
297 update(type);
298 continue;
299 }
300
301 // For instance op, don't use `moduleName` attributes since they might be
302 // replaced by dedup. Record the names and lazily combine their hashes.
303 // It is assumed that module names are hashed only through instance ops;
304 // it could cause suboptimal results if there was other operation that
305 // refers to module names through essential attributes.
306 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
307 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
308 continue;
309 }
310
311 if (isa<InstanceChoiceOp>(op) && name == constants.moduleNamesAttr) {
312 for (auto module : cast<ArrayAttr>(value))
313 referredModuleNames.push_back(
314 cast<FlatSymbolRefAttr>(module).getAttr());
315 continue;
316 }
317
318 // TODO: properly handle DistinctAttr, including its use in paths.
319 // See https://github.com/llvm/circt/issues/6583.
320 if (isa<DistinctAttr>(value))
321 continue;
322
323 // If this is an symbol reference, we need to perform name erasure.
324 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
325 update(getInnerSymID(innerRef.getName()));
326 continue;
327 }
328
329 // We don't need to handle this attribute specially, so hash its unique
330 // address.
331 update(value.getAsOpaquePointer());
332 }
333 }
334
335 void update(mlir::OperationName name) {
336 // Operation names are interned.
337 update(name.getAsOpaquePointer());
338 }
339
340 // NOLINTNEXTLINE(misc-no-recursion)
341 void update(Block *block) {
342 for (auto &op : llvm::reverse(*block))
343 update(&op);
344 for (auto type : block->getArgumentTypes())
345 update(type);
346 update(finalizeID(block));
347 update(position);
348 ++position;
349 }
350
351 // NOLINTNEXTLINE(misc-no-recursion)
352 void update(Region *region) {
353 for (auto &block : llvm::reverse(region->getBlocks()))
354 update(&block);
355 update(position);
356 ++position;
357 }
358
359 // NOLINTNEXTLINE(misc-no-recursion)
360 void update(Operation *op) {
361 // Hash the regions. We need to make sure an empty region doesn't hash the
362 // same as no region, so we include the number of regions.
363 update(op->getNumRegions());
364 for (auto &region : reverse(op->getRegions()))
365 update(&region);
366
367 update(op->getName());
368
369 // Record the uses for later hashing.
370 for (auto &operand : op->getOpOperands())
371 update(operand);
372
373 // This happens after the numbering above, as it uses blockarg numbering
374 // for inner symbols.
375 hasSeenSymbol = false;
376 update(op, op->getAttrDictionary());
377
378 // Record any op results (types).
379 for (auto result : op->getResults())
380 update(result);
381
382 // If any of the operands, attributes, or results depended on symbols,
383 // remember this op for later adjustment.
384 if (hasSeenSymbol)
385 symbolSensitiveOps.push_back(op);
386
387 // Incorporate the hash of uses we have already built.
388 update(finalizeID(op));
389 update(position);
390 ++position;
391 }
392
393 // A map from an operation/block, to its identifier.
394 DenseMap<void *, unsigned> idTable;
395 unsigned nextID = 0;
396
397 // A map from an inner symbol, to its identifier.
398 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
399
400 // This keeps track of module names in the order of the appearance.
401 std::vector<StringAttr> referredModuleNames;
402
403 // String constants.
405
406 // This is the actual running hash calculation. This is a stateful element
407 // that should be reinitialized after each hash is produced.
408 llvm::SHA256 sha;
409
410 // The index of the current op. Increment after handling each op.
411 size_t position = 0;
412
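  // `hasSeenSymbol` is set while hashing an operation when a reference to a
  // symbol that may be changed by dedup is encountered. `symbolSensitiveOps`
  // collects the operations containing such references; they need a fixup
  // pass after dedup.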
415 bool hasSeenSymbol = false;
416 std::vector<Operation *> symbolSensitiveOps;
417};
418
419/// A reference to a `ModuleInfo` that compares and hashes like it. This is used
420/// to keep the potentially heavy module infos in a vector, and then populate
421/// maps with references to them.
426
427/// Allow `ModuleInfoRef` to be used as dense map keys. Hashes and compares the
428/// `ModuleInfo` the ref points to.
429template <>
431 static unsigned getHashValue(const ModuleInfoRef &ref) {
432 // We assume SHA256 is already a good hash and just truncate down to the
433 // number of bytes we need for DenseMap.
434 unsigned hash;
435 std::memcpy(&hash, ref.info->structuralHash.data(), sizeof(unsigned));
436
437 // Combine module names.
438 return llvm::hash_combine(
439 hash, llvm::hash_combine_range(ref.info->referredModuleNames.begin(),
440 ref.info->referredModuleNames.end()));
441 }
442
443 static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs) {
444 if (!lhs.info || !rhs.info)
445 return lhs.info == rhs.info;
446 return *lhs.info == *rhs.info;
447 }
448};
449
450//===----------------------------------------------------------------------===//
451// Equivalence
452//===----------------------------------------------------------------------===//
453
454/// This class is for reporting differences between two modules which should
455/// have been deduplicated.
459 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
460 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
461 portDirectionsAttr = StringAttr::get(context, "portDirections");
462 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
463 nonessentialAttributes.insert(StringAttr::get(context, "name"));
464 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
465 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
466 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
467 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
468 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
469 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
470 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
471 }
472
480
481 std::string prettyPrint(Attribute attr) {
482 SmallString<64> buffer;
483 llvm::raw_svector_ostream os(buffer);
484 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
485 os << "0x";
486 if (integerAttr.getType().isSignlessInteger())
487 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
488 else
489 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
490
491 } else
492 os << attr;
493 return std::string(buffer);
494 }
495
496 // NOLINTNEXTLINE(misc-no-recursion)
497 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
498 Operation *a, BundleType aType, Operation *b,
499 BundleType bType) {
500 if (aType.getNumElements() != bType.getNumElements()) {
501 diag.attachNote(a->getLoc())
502 << message << " bundle type has different number of elements";
503 diag.attachNote(b->getLoc()) << "second operation here";
504 return failure();
505 }
506
507 for (auto elementPair :
508 llvm::zip(aType.getElements(), bType.getElements())) {
509 auto aElement = std::get<0>(elementPair);
510 auto bElement = std::get<1>(elementPair);
511 if (aElement.isFlip != bElement.isFlip) {
512 diag.attachNote(a->getLoc()) << message << " bundle element "
513 << aElement.name << " flip does not match";
514 diag.attachNote(b->getLoc()) << "second operation here";
515 return failure();
516 }
517
518 if (failed(check(diag,
519 message + " -> bundle element \'" +
520 aElement.name.getValue() + "'",
521 a, aElement.type, b, bElement.type)))
522 return failure();
523 }
524 return success();
525 }
526
527 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
528 Operation *a, Type aType, Operation *b, Type bType) {
529 if (aType == bType)
530 return success();
531 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
532 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
533 return check(diag, message, a, aBundleType, b, bBundleType);
534 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
535 aType != bType) {
536 diag.attachNote(a->getLoc())
537 << message << ", has a RefType with a different base type "
538 << type_cast<RefType>(aType).getType()
539 << " in the same position of the two modules marked as 'must dedup'. "
540 "(This may be due to Grand Central Taps or Views being different "
541 "between the two modules.)";
542 diag.attachNote(b->getLoc())
543 << "the second module has a different base type "
544 << type_cast<RefType>(bType).getType();
545 return failure();
546 }
547 diag.attachNote(a->getLoc())
548 << message << " types don't match, first type is " << aType;
549 diag.attachNote(b->getLoc()) << "second type is " << bType;
550 return failure();
551 }
552
553 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
554 Block &aBlock, Operation *b, Block &bBlock) {
555
556 // Block argument types.
557 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
558 auto portNo = 0;
559 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
560 Operation *opDoesNotExist) {
561 StringRef portName;
562 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
563 if (portNames)
564 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
565 portName = portNameAttr.getValue();
566 if (type_isa<RefType>(existsVal.getType())) {
567 diag.attachNote(opExists->getLoc())
568 << " contains a RefType port named '" + portName +
569 "' that only exists in one of the modules (can be due to "
570 "difference in Grand Central Tap or View of two modules "
571 "marked with must dedup)";
572 diag.attachNote(opDoesNotExist->getLoc())
573 << "second module to be deduped that does not have the RefType "
574 "port";
575 } else {
576 diag.attachNote(opExists->getLoc())
577 << "port '" + portName + "' only exists in one of the modules";
578 diag.attachNote(opDoesNotExist->getLoc())
579 << "second module to be deduped that does not have the port";
580 }
581 return failure();
582 };
583
584 for (auto argPair :
585 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
586 auto &aArg = std::get<0>(argPair);
587 auto &bArg = std::get<1>(argPair);
588 if (aArg.has_value() && bArg.has_value()) {
589 // TODO: we should print the port number if there are no port names, but
590 // there are always port names ;).
591 StringRef portName;
592 if (portNames) {
593 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
594 portName = portNameAttr.getValue();
595 }
596 // Assumption here that block arguments correspond to ports.
597 if (failed(check(diag, "module port '" + portName + "'", a,
598 aArg->getType(), b, bArg->getType())))
599 return failure();
600 data.map.map(aArg.value(), bArg.value());
601 portNo++;
602 continue;
603 }
604 if (!aArg.has_value())
605 std::swap(a, b);
606 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
607 b);
608 }
609
610 // Blocks operations.
611 auto aIt = aBlock.begin();
612 auto aEnd = aBlock.end();
613 auto bIt = bBlock.begin();
614 auto bEnd = bBlock.end();
615 while (aIt != aEnd && bIt != bEnd)
616 if (failed(check(diag, data, &*aIt++, &*bIt++)))
617 return failure();
618 if (aIt != aEnd) {
619 diag.attachNote(aIt->getLoc()) << "first block has more operations";
620 diag.attachNote(b->getLoc()) << "second block here";
621 return failure();
622 }
623 if (bIt != bEnd) {
624 diag.attachNote(bIt->getLoc()) << "second block has more operations";
625 diag.attachNote(a->getLoc()) << "first block here";
626 return failure();
627 }
628 return success();
629 }
630
631 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
632 Region &aRegion, Operation *b, Region &bRegion) {
633 auto aIt = aRegion.begin();
634 auto aEnd = aRegion.end();
635 auto bIt = bRegion.begin();
636 auto bEnd = bRegion.end();
637
638 // Region blocks.
639 while (aIt != aEnd && bIt != bEnd)
640 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
641 return failure();
642 if (aIt != aEnd || bIt != bEnd) {
643 diag.attachNote(a->getLoc())
644 << "operation regions have different number of blocks";
645 diag.attachNote(b->getLoc()) << "second operation here";
646 return failure();
647 }
648 return success();
649 }
650
651 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
652 mlir::DenseBoolArrayAttr aAttr, Operation *b,
653 mlir::DenseBoolArrayAttr bAttr) {
654 if (aAttr == bAttr)
655 return success();
656 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
657 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
658 auto aDirection = aAttr[i];
659 auto bDirection = bAttr[i];
660 if (aDirection != bDirection) {
661 auto &note = diag.attachNote(a->getLoc()) << "module port ";
662 if (portNames)
663 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
664 else
665 note << i;
666 note << " directions don't match, first direction is '"
667 << direction::toString(aDirection) << "'";
668 diag.attachNote(b->getLoc()) << "second direction is '"
669 << direction::toString(bDirection) << "'";
670 return failure();
671 }
672 }
673 return success();
674 }
675
676 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
677 DictionaryAttr aDict, Operation *b,
678 DictionaryAttr bDict) {
679 // Fast path.
680 if (aDict == bDict)
681 return success();
682
683 DenseSet<Attribute> seenAttrs;
684 for (auto namedAttr : aDict) {
685 auto attrName = namedAttr.getName();
686 if (nonessentialAttributes.contains(attrName))
687 continue;
688
689 auto aAttr = namedAttr.getValue();
690 auto bAttr = bDict.get(attrName);
691 if (!bAttr) {
692 diag.attachNote(a->getLoc())
693 << "second operation is missing attribute " << attrName;
694 diag.attachNote(b->getLoc()) << "second operation here";
695 return diag;
696 }
697
698 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
699 auto bRef = cast<hw::InnerRefAttr>(bAttr);
700 auto aRef = cast<hw::InnerRefAttr>(aAttr);
701 // See if they are pointing at the same operation or port.
702 auto aTarget = data.a.lookup(aRef.getName());
703 auto bTarget = data.b.lookup(bRef.getName());
704 if (!aTarget || !bTarget)
705 diag.attachNote(a->getLoc())
706 << "malformed ir, possibly violating use-before-def";
707 auto error = [&]() {
708 diag.attachNote(a->getLoc())
709 << "operations have different targets, first operation has "
710 << aTarget;
711 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
712 return failure();
713 };
714 if (aTarget.isPort()) {
715 // If they are targeting ports, make sure its the same port number.
716 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
717 return error();
718 } else {
719 // Otherwise make sure that they are targeting the same operation.
720 if (!bTarget.isOpOnly() ||
721 data.map.lookupOrNull(aTarget.getOp()) != bTarget.getOp())
722 return error();
723 }
724 if (aTarget.getField() != bTarget.getField())
725 return error();
726 } else if (attrName == portDirectionsAttr) {
727 // Special handling for the port directions attribute for better
728 // error messages.
729 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
730 cast<mlir::DenseBoolArrayAttr>(bAttr))))
731 return failure();
732 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
733 // TODO: properly handle DistinctAttr, including its use in paths.
734 // See https://github.com/llvm/circt/issues/6583
735 } else if (aAttr != bAttr) {
736 diag.attachNote(a->getLoc())
737 << "first operation has attribute '" << attrName.getValue()
738 << "' with value " << prettyPrint(aAttr);
739 diag.attachNote(b->getLoc())
740 << "second operation has value " << prettyPrint(bAttr);
741 return failure();
742 }
743 seenAttrs.insert(attrName);
744 }
745 if (aDict.getValue().size() != bDict.getValue().size()) {
746 for (auto namedAttr : bDict) {
747 auto attrName = namedAttr.getName();
748 // Skip the attribute if we don't care about this particular one or it
749 // is one that is known to be in both dictionaries.
750 if (nonessentialAttributes.contains(attrName) ||
751 seenAttrs.contains(attrName))
752 continue;
753 // We have found an attribute that is only in the second operation.
754 diag.attachNote(a->getLoc())
755 << "first operation is missing attribute " << attrName;
756 diag.attachNote(b->getLoc()) << "second operation here";
757 return failure();
758 }
759 }
760 return success();
761 }
762
763 // NOLINTNEXTLINE(misc-no-recursion)
764 LogicalResult check(InFlightDiagnostic &diag, igraph::InstanceOpInterface a,
765 igraph::InstanceOpInterface b) {
766 // Get the list of module names from the list (for InstanceOp/ObjectOp,
767 // there's only one)
768 auto aNames = a.getReferencedModuleNamesAttr();
769 auto bNames = b.getReferencedModuleNamesAttr();
770 if (aNames == bNames)
771 return success();
772
773 if (aNames.size() != bNames.size()) {
774 diag.attachNote(a->getLoc())
775 << "an instance has a different number of referenced "
776 "modules: first instance has "
777 << aNames.size() << " modules";
778 diag.attachNote(b->getLoc())
779 << "second instance has " << bNames.size() << " modules";
780 return failure();
781 }
782
783 for (auto [aName, bName] : llvm::zip(aNames.getAsRange<StringAttr>(),
784 bNames.getAsRange<StringAttr>())) {
785 if (aName == bName)
786 continue;
787 // If the modules instantiate are different we will want to know why the
788 // sub module did not dedupliate. This code recursively checks the child
789 // module.
790 auto aModule = instanceGraph.lookup(aName)->getModule();
791 auto bModule = instanceGraph.lookup(bName)->getModule();
792 // Create a new error for the submodule.
793 diag.attachNote(std::nullopt)
794 << "in instance " << a.getInstanceNameAttr() << " of " << aName
795 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
796 check(diag, aModule, bModule);
797 }
798 return failure();
799 }
800
801 // NOLINTNEXTLINE(misc-no-recursion)
802 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
803 Operation *b) {
804 // Operation name.
805 if (a->getName() != b->getName()) {
806 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
807 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
808 return failure();
809 }
810
811 // If it's a firrtl operation that implements InstanceOpInterface
812 // (InstanceOp/InstanceChoiceOp/ObjectOp) perform some checking and possibly
813 // recurse.
814 if (auto aInst = dyn_cast<igraph::InstanceOpInterface>(a))
815 if (auto bInst = dyn_cast<igraph::InstanceOpInterface>(b))
816 if (isa_and_nonnull<firrtl::FIRRTLDialect>(a->getDialect()) &&
817 isa_and_nonnull<firrtl::FIRRTLDialect>(b->getDialect()) &&
818 failed(check(diag, aInst, bInst)))
819 return failure();
820
821 // Operation results.
822 if (a->getNumResults() != b->getNumResults()) {
823 diag.attachNote(a->getLoc())
824 << "operations have different number of results";
825 diag.attachNote(b->getLoc()) << "second operation here";
826 return failure();
827 }
828 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
829 auto &aValue = std::get<0>(resultPair);
830 auto &bValue = std::get<1>(resultPair);
831 if (failed(check(diag, "operation result", a, aValue.getType(), b,
832 bValue.getType())))
833 return failure();
834 data.map.map(aValue, bValue);
835 }
836
837 // Operations operands.
838 if (a->getNumOperands() != b->getNumOperands()) {
839 diag.attachNote(a->getLoc())
840 << "operations have different number of operands";
841 diag.attachNote(b->getLoc()) << "second operation here";
842 return failure();
843 }
844 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
845 auto &aValue = std::get<0>(operandPair);
846 auto &bValue = std::get<1>(operandPair);
847 if (bValue != data.map.lookup(aValue)) {
848 diag.attachNote(a->getLoc())
849 << "operations use different operands, first operand is '"
850 << getFieldName(
851 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
852 .first
853 << "'";
854 diag.attachNote(b->getLoc())
855 << "second operand is '"
856 << getFieldName(
857 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
858 .first
859 << "', but should have been '"
860 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
861 /*lookThroughCasts=*/true))
862 .first
863 << "'";
864 return failure();
865 }
866 }
867 data.map.map(a, b);
868
869 // Operation regions.
870 if (a->getNumRegions() != b->getNumRegions()) {
871 diag.attachNote(a->getLoc())
872 << "operations have different number of regions";
873 diag.attachNote(b->getLoc()) << "second operation here";
874 return failure();
875 }
876 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
877 auto &aRegion = std::get<0>(regionPair);
878 auto &bRegion = std::get<1>(regionPair);
879 if (failed(check(diag, data, a, aRegion, b, bRegion)))
880 return failure();
881 }
882
883 // Operation attributes.
884 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
885 b->getAttrDictionary())))
886 return failure();
887 return success();
888 }
889
890 // NOLINTNEXTLINE(misc-no-recursion)
891 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
892 hw::InnerSymbolTable aTable(a);
893 hw::InnerSymbolTable bTable(b);
894 ModuleData data(aTable, bTable);
896 diag.attachNote(a->getLoc()) << "module marked NoDedup";
897 return;
898 }
900 diag.attachNote(b->getLoc()) << "module marked NoDedup";
901 return;
902 }
903 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
904 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
905 if (!canRemoveModule(aSymbol) && !canRemoveModule(bSymbol)) {
906 diag.attachNote(a->getLoc())
907 << "module is "
908 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
909 diag.attachNote(b->getLoc())
910 << "module is "
911 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
912 return;
913 }
914 auto aGroup =
915 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
916 auto bGroup = dyn_cast_or_null<StringAttr>(
917 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
918 if (aGroup != bGroup) {
919 if (bGroup) {
920 diag.attachNote(b->getLoc())
921 << "module is in dedup group '" << bGroup.str() << "'";
922 } else {
923 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
924 }
925 if (aGroup) {
926 diag.attachNote(a->getLoc())
927 << "module is in dedup group '" << aGroup.str() << "'";
928 } else {
929 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
930 }
931 return;
932 }
933 if (failed(check(diag, data, a, b)))
934 return;
935 diag.attachNote(a->getLoc()) << "first module here";
936 diag.attachNote(b->getLoc()) << "second module here";
937 }
938
939 // This is a cached "portDirections" string attr.
941 // This is a cached "NoDedup" annotation class string attr.
942 StringAttr noDedupClass;
943 // This is a cached string attr for the dedup group attribute.
945
946 // This is a set of every attribute we should ignore.
947 DenseSet<Attribute> nonessentialAttributes;
949};
950
951//===----------------------------------------------------------------------===//
952// Deduplication
953//===----------------------------------------------------------------------===//
954
955// Custom location merging. This only keeps track of 8 annotations from ".fir"
956// files, and however many annotations come from "real" sources. When
957// deduplicating, modules tend not to have scala source locators, so we wind
958// up fusing source locators for a module from every copy being deduped. There
959// is little value in this (all the modules are identical by definition).
960static Location mergeLoc(MLIRContext *context, Location to, Location from) {
961 // Unique the set of locations to be fused.
963 // only track 8 "fir" locations
964 unsigned seenFIR = 0;
965 for (auto loc : {to, from}) {
966 // If the location is a fused location we decompose it if it has no
967 // metadata or the metadata is the same as the top level metadata.
968 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
969 // UnknownLoc's have already been removed from FusedLocs so we can
970 // simply add all of the internal locations.
971 for (auto loc : fusedLoc.getLocations()) {
972 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
973 if (fileLoc.getFilename().strref().ends_with(".fir")) {
974 ++seenFIR;
975 if (seenFIR > 8)
976 continue;
977 }
978 }
979 decomposedLocs.insert(loc);
980 }
981 continue;
982 }
983
984 // Might need to skip this fir.
985 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
986 if (fileLoc.getFilename().strref().ends_with(".fir")) {
987 ++seenFIR;
988 if (seenFIR > 8)
989 continue;
990 }
991 }
992 // Otherwise, only add known locations to the set.
993 if (!isa<UnknownLoc>(loc))
994 decomposedLocs.insert(loc);
995 }
996
997 auto locs = decomposedLocs.getArrayRef();
998
999 // Handle the simple cases of less than two locations. Ensure the metadata (if
1000 // provided) is not dropped.
1001 if (locs.empty())
1002 return UnknownLoc::get(context);
1003 if (locs.size() == 1)
1004 return locs.front();
1005
1006 return FusedLoc::get(context, locs);
1007}
1008
1009struct Deduper {
1010
1011 using RenameMap = DenseMap<StringAttr, StringAttr>;
1012
1014 NLATable *nlaTable, CircuitOp circuit)
1015 : context(circuit->getContext()), instanceGraph(instanceGraph),
1017 nlaBlock(circuit.getBodyBlock()),
1018 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
1019 classString(StringAttr::get(context, "class")) {
1020 // Populate the NLA cache.
1021 for (auto nla : circuit.getOps<hw::HierPathOp>())
1022 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
1023 }
1024
1025 /// Remove the "fromModule", and replace all references to it with the
1026 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
1027 /// module which is not deduplicated needs to be recorded with the `record`
1028 /// call.
1029 void dedup(FModuleLike toModule, FModuleLike fromModule) {
1030 // A map of operation (e.g. wires, nodes) names which are changed, which is
1031 // used to update NLAs that reference the "fromModule".
1032 RenameMap renameMap;
1033
1034 // Merge the port locations.
1035 SmallVector<Attribute> newLocs;
1036 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
1037 fromModule.getPortLocations())) {
1038 if (toLoc == fromLoc)
1039 newLocs.push_back(toLoc);
1040 else
1041 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
1042 cast<LocationAttr>(fromLoc)));
1043 }
1044 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
1045
1046 // Merge the two modules.
1047 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
1048
1049 // Rewrite NLAs pathing through these modules to refer to the to module. It
1050 // is safe to do this at this point because NLAs cannot be one element long.
1051 // This means that all NLAs which require more context cannot be targetting
1052 // something in the module it self.
1053 if (auto to = dyn_cast<FModuleOp>(*toModule))
1054 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
1055 else
1056 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
1057 fromModule.getModuleNameAttr());
1058
1059 replaceInstances(toModule, fromModule);
1060 }
1061
1062 /// Record the usages of any NLA's in this module, so that we may update the
1063 /// annotation if the parent module is deduped with another module.
1064 void record(FModuleLike module) {
1065 // Record any annotations on the module.
1066 recordAnnotations(module);
1067 // Record port annotations.
1068 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
1070 // Record any annotations in the module body.
1071 module->walk([&](Operation *op) { recordAnnotations(op); });
1072 }
1073
1074private:
1075 /// Get a cached namespace for a module.
1077 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
1078 .first->second;
1079 }
1080
1081 /// For a specific annotation target, record all the unique NLAs which
1082 /// target it in the `targetMap`.
1084 for (auto anno : target.getAnnotations())
1085 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1086 targetMap[nlaRef.getAttr()].insert(target);
1087 }
1088
1089 /// Record all targets which use an NLA.
1090 void recordAnnotations(Operation *op) {
1091 // Record annotations.
1092 recordAnnotations(OpAnnoTarget(op));
1093
1094 // Record port annotations only if this is a mem operation.
1095 auto mem = dyn_cast<MemOp>(op);
1096 if (!mem)
1097 return;
1098
1099 // Record port annotations.
1100 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1101 recordAnnotations(PortAnnoTarget(mem, i));
1102 }
1103
1104 /// This deletes and replaces all instances of the "fromModule" with instances
1105 /// of the "toModule".
1106 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1107 // Replace all instances of the other module.
1108 auto *fromNode =
1109 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1110 auto *toNode = instanceGraph[toModule];
1111 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1112 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1113 auto inst = oldInstRec->getInstance();
1114 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1115 instOp.setModuleNameAttr(toModuleRef);
1116 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1117 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1118 auto classLike = cast<ClassLike>(*toNode->getModule());
1119 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1120 objectOp.getResult().setType(classType);
1121 } else if (auto instanceChoiceOp = dyn_cast<InstanceChoiceOp>(*inst)) {
1122 auto fromModuleName = fromNode->getModule().getModuleNameAttr();
1123 SmallVector<Attribute> newModules;
1124 for (auto module : instanceChoiceOp.getReferencedModuleNamesAttr()) {
1125 auto moduleName = cast<StringAttr>(module);
1126 if (moduleName == fromModuleName)
1127 newModules.push_back(toModuleRef);
1128 else
1129 newModules.push_back(FlatSymbolRefAttr::get(moduleName));
1130 }
1131 instanceChoiceOp.setModuleNamesAttr(
1132 ArrayAttr::get(context, newModules));
1133 instanceChoiceOp.setPortNamesAttr(toModule.getPortNamesAttr());
1134 }
1135 oldInstRec->getParent()->addInstance(inst, toNode);
1136 oldInstRec->erase();
1137 }
1138 instanceGraph.erase(fromNode);
1139 fromModule->erase();
1140 }
1141
1142 /// Look up the instantiations of the `from` module and create an NLA for each
1143 /// one, appending the baseNamepath to each NLA. This is used to add more
1144 /// context to an already existing NLA. The `fromModule` is used to indicate
1145 /// which module the annotation is coming from before the merge, and will be
1146 /// used to create the namepaths.
1147 SmallVector<FlatSymbolRefAttr>
1148 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1149 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1150 // Create an attribute array with a placeholder in the first element, where
1151 // the root refence of the NLA will be inserted.
1152 SmallVector<Attribute> namepath = {nullptr};
1153 namepath.append(baseNamepath.begin(), baseNamepath.end());
1154
1155 auto loc = fromModule->getLoc();
1156 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1157 SmallVector<FlatSymbolRefAttr> nlas;
1158 for (auto *instanceRecord : fromNode->uses()) {
1159 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1160 auto inst = instanceRecord->getInstance();
1161 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1162 auto arrayAttr = ArrayAttr::get(context, namepath);
1163 // Check the NLA cache to see if we already have this NLA.
1164 auto &cacheEntry = nlaCache[arrayAttr];
1165 if (!cacheEntry) {
1166 auto builder = OpBuilder::atBlockBegin(nlaBlock);
1167 auto nla = hw::HierPathOp::create(builder, loc, "nla", arrayAttr);
1168 // Insert it into the symbol table to get a unique name.
1169 symbolTable.insert(nla);
1170 // Store it in the cache.
1171 cacheEntry = nla.getNameAttr();
1172 nla.setVisibility(vis);
1173 nlaTable->addNLA(nla);
1174 }
1175 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1176 nlas.push_back(nlaRef);
1177 }
1178 return nlas;
1179 }
1180
1181 /// Look up the instantiations of this module and create an NLA for each one.
1182 /// This returns an array of symbol references which can be used to reference
1183 /// the NLAs.
1184 SmallVector<FlatSymbolRefAttr>
1185 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1186 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1187 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1188 }
1189
1190 /// Clone the annotation for each NLA in a list. The attribute list should
1191 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1192 /// should be the index of this field.
1193 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1194 Annotation anno, ArrayRef<NamedAttribute> attributes,
1195 unsigned nonLocalIndex,
1196 SmallVectorImpl<Annotation> &newAnnotations) {
1197 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1198 attributes.end());
1199 for (auto &nla : nlas) {
1200 // Add the new annotation.
1201 mutableAttributes[nonLocalIndex].setValue(nla);
1202 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1203 // The original annotation records if its a subannotation.
1204 anno.setDict(dict);
1205 newAnnotations.push_back(anno);
1206 }
1207 }
1208
1209 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1210 /// but it does not delete the NLA reference from the target operation's
1211 /// annotations.
1212 void eraseNLA(hw::HierPathOp nla) {
1213 // Erase the NLA from the leaf module's nlaMap.
1214 targetMap.erase(nla.getNameAttr());
1215 nlaTable->erase(nla);
1216 nlaCache.erase(nla.getNamepathAttr());
1217 symbolTable.erase(nla);
1218 }
1219
1220 /// Process all NLAs referencing the "from" module to point to the "to"
1221 /// module. This is used after merging two modules together.
1222 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1223 FModuleOp fromModule) {
1224 auto toName = toModule.getNameAttr();
1225 auto fromName = fromModule.getNameAttr();
1226 // Create a copy of the current NLAs. We will be pushing and removing
1227 // NLAs from this op as we go.
1228 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1229 // Change the NLA to target the toModule.
1230 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1231 // Now we walk the NLA searching for ones that require more context to be
1232 // added.
1233 for (auto nla : moduleNLAs) {
1234 auto elements = nla.getNamepath().getValue();
1235 // If we don't need to add more context, we're done here.
1236 if (nla.root() != toName)
1237 continue;
1238 // Create the replacement NLAs.
1239 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1240 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1241 // Copy out the targets, because we will be updating the map.
1242 auto &set = targetMap[nla.getSymNameAttr()];
1243 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1244 // Replace the uses of the old NLA with the new NLAs.
1245 for (auto target : targets) {
1246 // We have to clone any annotation which uses the old NLA for each new
1247 // NLA. This array collects the new set of annotations.
1248 SmallVector<Annotation> newAnnotations;
1249 for (auto anno : target.getAnnotations()) {
1250 // Find the non-local field of the annotation.
1251 auto [it, found] = mlir::impl::findAttrSorted(
1252 anno.begin(), anno.end(), nonLocalString);
1253 // If this annotation doesn't use the target NLA, copy it with no
1254 // changes.
1255 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1256 nla.getSymNameAttr()) {
1257 newAnnotations.push_back(anno);
1258 continue;
1259 }
1260 auto nonLocalIndex = std::distance(anno.begin(), it);
1261 // Clone the annotation and add it to the list of new annotations.
1262 cloneAnnotation(nlaRefs, anno,
1263 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1264 nonLocalIndex, newAnnotations);
1265 }
1266
1267 // Apply the new annotations to the operation.
1268 AnnotationSet annotations(newAnnotations, context);
1269 target.setAnnotations(annotations);
1270 // Record that target uses the NLA.
1271 for (auto nla : nlaRefs)
1272 targetMap[nla.getAttr()].insert(target);
1273 }
1274
1275 // Erase the old NLA and remove it from all breadcrumbs.
1276 eraseNLA(nla);
1277 }
1278 }
1279
1280 /// Process all the NLAs that the two modules participate in, replacing
1281 /// references to the "from" module with references to the "to" module, and
1282 /// adding more context if necessary.
1283 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1284 FModuleOp fromModule) {
1285 addAnnotationContext(renameMap, toModule, toModule);
1286 addAnnotationContext(renameMap, toModule, fromModule);
1287 }
1288
1289 // Update all NLAs which the "from" external module participates in to the
1290 // "toName".
1291 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1292 StringAttr fromName) {
1293 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1294 }
1295
1296 /// Take an annotation, and update it to be a non-local annotation. If the
1297 /// annotation is already non-local and has enough context, it will be skipped
1298 /// for now. Return true if the annotation was made non-local.
1299 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1300 FModuleLike fromModule, Annotation anno,
1301 SmallVectorImpl<Annotation> &newAnnotations) {
1302 // Start constructing a new annotation, pushing a "circt.nonLocal" field
1303 // into the correct spot if its not already a non-local annotation.
1304 SmallVector<NamedAttribute> attributes;
1305 int nonLocalIndex = -1;
1306 for (const auto &val : llvm::enumerate(anno)) {
1307 auto attr = val.value();
1308 // Is this field "circt.nonlocal"?
1309 auto compare = attr.getName().compare(nonLocalString);
1310 assert(compare != 0 && "should not pass non-local annotations here");
1311 if (compare == 1) {
1312 // This annotation definitely does not have "circt.nonlocal" field. Push
1313 // an empty place holder for the non-local annotation.
1314 nonLocalIndex = val.index();
1315 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1316 break;
1317 }
1318 // Otherwise push the current attribute and keep searching for the
1319 // "circt.nonlocal" field.
1320 attributes.push_back(attr);
1321 }
1322 if (nonLocalIndex == -1) {
1323 // Push an empty "circt.nonlocal" field to the last slot.
1324 nonLocalIndex = attributes.size();
1325 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1326 } else {
1327 // Copy the remaining annotation fields in.
1328 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1329 }
1330
1331 // Construct the NLAs if we don't have any yet.
1332 auto nlaRefs = createNLAs(toModuleName, fromModule);
1333 for (auto nla : nlaRefs)
1334 targetMap[nla.getAttr()].insert(to);
1335
1336 // Clone the annotation for each new NLA.
1337 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1338 return true;
1339 }
1340
1341 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1342 FModuleLike fromModule, AnnotationSet annos,
1343 SmallVectorImpl<Annotation> &newAnnotations,
1344 SmallPtrSetImpl<Attribute> &dontTouches) {
1345 for (auto anno : annos) {
1346 if (anno.isClass(dontTouchAnnoClass)) {
1347 // Remove the nonlocal field of the annotation if it has one, since this
1348 // is a sticky annotation.
1349 anno.removeMember("circt.nonlocal");
1350 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1351 if (inserted)
1352 newAnnotations.push_back(anno);
1353 continue;
1354 }
1355 // If the annotation is already non-local, we add it as is. It is already
1356 // added to the target map.
1357 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1358 newAnnotations.push_back(anno);
1359 targetMap[nla.getAttr()].insert(to);
1360 continue;
1361 }
1362 // Otherwise make the annotation non-local and add it to the set.
1363 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1364 newAnnotations);
1365 }
1366 }
1367
1368 /// Merge the annotations of a specific target, either a operation or a port
1369 /// on an operation.
1370 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1371 AnnotationSet toAnnos, FModuleLike fromModule,
1372 AnnoTarget from, AnnotationSet fromAnnos) {
1373 // This is a list of all the annotations which will be added to `to`.
1374 SmallVector<Annotation> newAnnotations;
1375
1376 // We have special case handling of DontTouch to prevent it from being
1377 // turned into a non-local annotation, and to remove duplicates.
1378 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1379
1380 // Iterate the annotations, transforming most annotations into non-local
1381 // ones.
1382 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1383 dontTouches);
1384 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1385 dontTouches);
1386
1387 // Copy over all the new annotations.
1388 if (!newAnnotations.empty())
1389 to.setAnnotations(AnnotationSet(newAnnotations, context));
1390 }
1391
1392 /// Merge all annotations and port annotations on two operations.
1393 void mergeAnnotations(FModuleLike toModule, Operation *to,
1394 FModuleLike fromModule, Operation *from) {
1395 // Merge op annotations.
1396 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1397 OpAnnoTarget(from), AnnotationSet(from));
1398
1399 // Merge port annotations.
1400 if (toModule == to) {
1401 // Merge module port annotations.
1402 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1403 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1404 AnnotationSet::forPort(toModule, i), fromModule,
1405 PortAnnoTarget(fromModule, i),
1406 AnnotationSet::forPort(fromModule, i));
1407 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1408 // Merge memory port annotations.
1409 auto fromMem = cast<MemOp>(from);
1410 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1411 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1412 AnnotationSet::forPort(toMem, i), fromModule,
1413 PortAnnoTarget(fromMem, i),
1414 AnnotationSet::forPort(fromMem, i));
1415 }
1416 }
1417
1418 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1419 hw::InnerSymAttr toSym,
1420 hw::InnerSymAttr fromSym) {
1421 if (fromSym && !fromSym.getProps().empty()) {
1422 auto &isn = getNamespace(toModule);
1423 // The properties for the new inner symbol..
1424 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1425 // If the "to" op already has an inner symbol, copy all its properties.
1426 if (toSym)
1427 llvm::append_range(newProps, toSym);
1428 // Add each property from the fromSym to the toSym.
1429 for (auto fromProp : fromSym) {
1430 hw::InnerSymPropertiesAttr newProp;
1431 auto *it = llvm::find_if(newProps, [&](auto p) {
1432 return p.getFieldID() == fromProp.getFieldID();
1433 });
1434 if (it != newProps.end()) {
1435 // If we already have an inner sym with the same field id, use
1436 // that.
1437 newProp = *it;
1438 // If the old symbol is public, we need to make the new one public.
1439 if (fromProp.getSymVisibility().getValue() == "public" &&
1440 newProp.getSymVisibility().getValue() != "public") {
1441 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1442 newProp.getFieldID(),
1443 fromProp.getSymVisibility());
1444 }
1445 } else {
1446 // We need to add a new property to the inner symbol for this field.
1447 auto newName = isn.newName(fromProp.getName().getValue());
1448 newProp = hw::InnerSymPropertiesAttr::get(
1449 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1450 fromProp.getSymVisibility());
1451 newProps.push_back(newProp);
1452 }
1453 renameMap[fromProp.getName()] = newProp.getName();
1454 }
1455 // Sort the fields by field id.
1456 llvm::sort(newProps, [](auto &p, auto &q) {
1457 return p.getFieldID() < q.getFieldID();
1458 });
1459 // Return the merged inner symbol.
1460 return hw::InnerSymAttr::get(context, newProps);
1461 }
1462 return hw::InnerSymAttr();
1463 }
1464
1465 // Record the symbol name change of the operation or any of its ports when
1466 // merging two operations. The renamed symbols are used to update the
1467 // target of any NLAs. This will add symbols to the "to" operation if needed.
1468 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1469 Operation *to, FModuleLike fromModule,
1470 Operation *from) {
1471 // If the "from" operation has an inner_sym, we need to make sure the
1472 // "to" operation also has an `inner_sym` and then record the renaming.
1473 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1474 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1475 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1476 toInnerSym.getInnerSymAttr(),
1477 fromInnerSym.getInnerSymAttr()))
1478 toInnerSym.setInnerSymbolAttr(newSymAttr);
1479 }
1480
1481 // If there are no port symbols on the "from" operation, we are done here.
1482 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1483 if (!fromPortSyms || fromPortSyms.empty())
1484 return;
1485 // We have to map each "fromPort" to each "toPort".
1486 auto portCount = fromPortSyms.size();
1487 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1488
1489 // Create an array of new port symbols for the "to" operation, copy in the
1490 // old symbols if it has any, create an empty symbol array if it doesn't.
1491 SmallVector<Attribute> newPortSyms;
1492 if (toPortSyms.empty())
1493 newPortSyms.assign(portCount, hw::InnerSymAttr());
1494 else
1495 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1496
1497 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1498 if (auto newPortSym = mergeInnerSymbols(
1499 renameMap, toModule,
1500 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1501 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1502 newPortSyms[portNo] = newPortSym;
1503 }
1504 }
1505
1506 // Commit the new symbol attribute.
1507 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1508 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1509 }
1510
1511 /// Recursively merge two operations.
1512 // NOLINTNEXTLINE(misc-no-recursion)
1513 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1514 FModuleLike fromModule, Operation *from) {
1515 // Merge the operation locations.
1516 if (to->getLoc() != from->getLoc())
1517 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1518
1519 // Recurse into any regions.
1520 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1521 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1522 std::get<1>(regions));
1523
1524 // Record any inner_sym renamings that happened.
1525 recordSymRenames(renameMap, toModule, to, fromModule, from);
1526
1527 // Merge the annotations.
1528 mergeAnnotations(toModule, to, fromModule, from);
1529 }
1530
1531 /// Recursively merge two blocks.
1532 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1533 FModuleLike fromModule, Block &fromBlock) {
1534 // Merge the block locations.
1535 for (auto [toArg, fromArg] :
1536 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1537 if (toArg.getLoc() != fromArg.getLoc())
1538 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1539
1540 for (auto ops : llvm::zip(toBlock, fromBlock))
1541 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1542 &std::get<1>(ops));
1543 }
1544
1545 // Recursively merge two regions.
1546 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1547 Region &toRegion, FModuleLike fromModule,
1548 Region &fromRegion) {
1549 for (auto blocks : llvm::zip(toRegion, fromRegion))
1550 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1551 std::get<1>(blocks));
1552 }
1553
1554 MLIRContext *context;
1556 SymbolTable &symbolTable;
1557
1558 /// Cached nla table analysis.
1559 NLATable *nlaTable = nullptr;
1560
1561 /// We insert all NLAs to the beginning of this block.
1562 Block *nlaBlock;
1563
1564 // This maps an NLA to the operations and ports that uses it.
1565 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1566
1567 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1568 // of the NLA's path to the name of the NLA which contains it.
1569 DenseMap<Attribute, Attribute> nlaCache;
1570
1571 // Cached attributes for faster comparisons and attribute building.
1572 StringAttr nonLocalString;
1573 StringAttr classString;
1574
1575 /// A module namespace cache.
1576 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1577};
1578
1579//===----------------------------------------------------------------------===//
1580// Fixup
1581//===----------------------------------------------------------------------===//
1582
1583/// This fixes up connects when the field names of a bundle type changes. It
1584/// finds all fields which were previously bulk connected and legalizes it
1585/// into a connect for each field.
1586static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1587 // If the types already match we can emit a connect.
1588 auto dstType = dst.getType();
1589 auto srcType = src.getType();
1590 if (dstType == srcType) {
1591 emitConnect(builder, dst, src);
1592 return;
1593 }
1594 // It must be a bundle type and the field name has changed. We have to
1595 // manually decompose the bulk connect into a connect for each field.
1596 auto dstBundle = type_cast<BundleType>(dstType);
1597 auto srcBundle = type_cast<BundleType>(srcType);
1598 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1599 auto dstField = SubfieldOp::create(builder, dst, i);
1600 auto srcField = SubfieldOp::create(builder, src, i);
1601 if (dstBundle.getElement(i).isFlip) {
1602 std::swap(srcBundle, dstBundle);
1603 std::swap(srcField, dstField);
1604 }
1605 fixupConnect(builder, dstField, srcField);
1606 }
1607}
1608
1609/// Adjust the symbol references in an op. This includes updating its attributes
1610/// and types.
1611static void
1612fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph,
1613 const DenseMap<Attribute, StringAttr> &dedupMap) {
1614 // If this is an instance op, dedup may have subtly changed the port types.
1615 // For example, structurally different bundles may still dedup. In this case
1616 // we now have an instance op that produces result values of the old type, but
1617 // the port info on the instantiated module already represents the new type.
1618 // Fix this up by going through an intermediate wire.
1619 if (auto instOp = dyn_cast<InstanceOp>(op)) {
1620 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp->getContext());
1621 builder.setInsertionPointAfter(instOp);
1622 auto module = instanceGraph.lookup(instOp.getModuleNameAttr().getAttr())
1623 ->getModule<FModuleLike>();
1624 for (auto [index, result] : llvm::enumerate(instOp.getResults())) {
1625 auto newType = module.getPortType(index);
1626 auto oldType = result.getType();
1627 // If the type has not changed, we don't have to fix up anything.
1628 if (newType == oldType)
1629 continue;
1630 LLVM_DEBUG(llvm::dbgs()
1631 << "- Updating instance port \"" << instOp.getInstanceName()
1632 << "." << instOp.getPortName(index) << "\" from " << oldType
1633 << " to " << newType << "\n");
1634
1635 // If the type changed we transform it back to the old type with an
1636 // intermediate wire.
1637 auto wire = WireOp::create(builder, oldType, instOp.getPortName(index))
1638 .getResult();
1639 result.replaceAllUsesWith(wire);
1640 result.setType(newType);
1641 if (instOp.getPortDirection(index) == Direction::Out)
1642 fixupConnect(builder, wire, result);
1643 else
1644 fixupConnect(builder, result, wire);
1645 }
1646 }
1647
1648 // Use an attribute/type replacer to look for references to old symbols that
1649 // need to be replaced with new symbols.
1650 mlir::AttrTypeReplacer replacer;
1651 replacer.addReplacement([&](FlatSymbolRefAttr symRef) {
1652 auto oldName = symRef.getAttr();
1653 auto newName = dedupMap.lookup(oldName);
1654 if (newName && newName != oldName) {
1655 auto newSymRef = FlatSymbolRefAttr::get(newName);
1656 LLVM_DEBUG(llvm::dbgs()
1657 << "- Updating " << symRef << " to " << newSymRef << " in "
1658 << op->getName() << " at " << op->getLoc() << "\n");
1659 return newSymRef;
1660 }
1661 return symRef;
1662 });
1663
1664 // Update attributes.
1665 op->setAttrs(cast<DictionaryAttr>(replacer.replace(op->getAttrDictionary())));
1666
1667 // Update the argument types.
1668 for (auto &region : op->getRegions())
1669 for (auto &block : region)
1670 for (auto arg : block.getArguments())
1671 arg.setType(replacer.replace(arg.getType()));
1672
1673 // Update result types.
1674 for (auto result : op->getResults())
1675 result.setType(replacer.replace(result.getType()));
1676}
1677
1678/// Adjust the symbol references in ops marked as sensitive to them. This
1679/// includes updating their attributes and types.
1681 InstanceGraph &instanceGraph,
1682 const DenseMap<Operation *, ModuleInfoRef> &moduleToModuleInfo,
1683 const DenseMap<Attribute, StringAttr> &dedupMap) {
1684 for (auto *node : instanceGraph) {
1685 // Look up the module info for this module, which contains the list of ops
1686 // that need to be updated.
1687 auto module = node->getModule<FModuleLike>();
1688 auto it = moduleToModuleInfo.find(module);
1689 if (it == moduleToModuleInfo.end())
1690 continue;
1691
1692 // Update each symbol-sensitive op individually.
1693 auto &ops = it->second.info->symbolSensitiveOps;
1694 if (ops.empty())
1695 continue;
1696 LLVM_DEBUG(llvm::dbgs()
1697 << "- Updating " << ops.size() << " symbol-sensitive ops in "
1698 << module.getNameAttr() << "\n");
1699 for (auto *op : ops)
1700 fixupSymbolSensitiveOp(op, instanceGraph, dedupMap);
1701 }
1702}
1703
1704//===----------------------------------------------------------------------===//
1705// DedupPass
1706//===----------------------------------------------------------------------===//
1707
1708namespace {
1709class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1710 using DedupBase::DedupBase;
1711
1712 void runOnOperation() override {
1713 auto *context = &getContext();
1714 auto circuit = getOperation();
1715 auto &instanceGraph = getAnalysis<InstanceGraph>();
1716 auto *nlaTable = &getAnalysis<NLATable>();
1717 auto &symbolTable = getAnalysis<SymbolTable>();
1718 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1719 Equivalence equiv(context, instanceGraph);
1720 auto anythingChanged = false;
1721 LLVM_DEBUG({
1722 llvm::dbgs() << "\n";
1723 debugHeader(Twine("Dedup circuit \"") + circuit.getName() + "\"")
1724 << "\n\n";
1725 });
1726
1727 // Modules annotated with this should not be considered for deduplication.
1728 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1729
1730 // Only modules within the same group may be deduplicated.
1731 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1732
1733 // A map of all the module moduleInfo that we have calculated so far.
1734 DenseMap<ModuleInfoRef, Operation *> moduleInfoToModule;
1735 DenseMap<Operation *, ModuleInfoRef> moduleToModuleInfo;
1736
1737 // We track the name of the module that each module is deduped into, so that
1738 // we can make sure all modules which are marked "must dedup" with each
1739 // other were all deduped to the same module.
1740 DenseMap<Attribute, StringAttr> dedupMap;
1741
1742 // We must iterate the modules from the bottom up so that we can properly
1743 // deduplicate the modules. We copy the list of modules into a vector first
1744 // to avoid iterator invalidation while we mutate the instance graph.
1745 SmallVector<FModuleLike, 0> modules;
1746 instanceGraph.walkPostOrder([&](auto &node) {
1747 if (auto mod = dyn_cast<FModuleLike>(*node.getModule()))
1748 modules.push_back(mod);
1749 });
1750 LLVM_DEBUG(llvm::dbgs() << "Found " << modules.size() << " modules\n");
1751
1752 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1753 StructuralHasherSharedConstants hasherConstants(&getContext());
1754
1755 // Attribute name used to store dedup_group for this pass.
1756 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1757
1758 // Move dedup group annotations to attributes on the module.
1759 // This results in the desired behavior (included in hash),
1760 // and avoids unnecessary processing of these as annotations
1761 // that need to be tracked, made non-local, so on.
1762 for (auto module : modules) {
1765 module, [&groups, dedupGroupClass](Annotation annotation) {
1766 if (annotation.getClassAttr() != dedupGroupClass)
1767 return false;
1768 groups.insert(annotation.getMember<StringAttr>("group"));
1769 return true;
1770 });
1771 if (groups.size() > 1) {
1772 module.emitError("module belongs to multiple dedup groups: ") << groups;
1773 return signalPassFailure();
1774 }
1775 assert(!module->hasAttr(dedupGroupAttrName) &&
1776 "unexpected existing use of temporary dedup group attribute");
1777 if (!groups.empty())
1778 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1779 }
1780
1781 // Calculate module information parallelly.
1782 LLVM_DEBUG(llvm::dbgs() << "Computing module information\n");
1783 auto result = mlir::failableParallelForEach(
1784 context, llvm::seq(modules.size()), [&](unsigned idx) {
1785 auto module = modules[idx];
1786 // If the module is marked with NoDedup, just skip it.
1787 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1788 return success();
1789
1790 // Only dedup extmodule's with defname.
1791 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1792 ext && !ext.getDefname().has_value())
1793 return success();
1794
1795 // Only dedup classes if enabled.
1796 if (isa<ClassOp>(*module) && !dedupClasses)
1797 return success();
1798
1799 StructuralHasher hasher(hasherConstants);
1800 // Calculate the hash of the module and referred module names.
1801 moduleInfos[idx] = hasher.getModuleInfo(module);
1802 return success();
1803 });
1804
1805 // Dump out the module hashes for debugging.
1806 LLVM_DEBUG({
1807 auto &os = llvm::dbgs();
1808 for (auto [module, info] : llvm::zip(modules, moduleInfos)) {
1809 os << "- Hash ";
1810 if (info) {
1811 os << llvm::format_bytes(info->structuralHash, std::nullopt, 32, 32);
1812 } else {
1813 os << "--------------------------------";
1814 os << "--------------------------------";
1815 }
1816 os << " for " << module.getModuleNameAttr() << "\n";
1817 }
1818 });
1819
1820 if (result.failed())
1821 return signalPassFailure();
1822
1823 LLVM_DEBUG(llvm::dbgs() << "Update modules\n");
1824 for (auto [i, module] : llvm::enumerate(modules)) {
1825 auto moduleName = module.getModuleNameAttr();
1826 auto &maybeModuleInfo = moduleInfos[i];
1827 // If the hash was not calculated, we need to skip it.
1828 if (!maybeModuleInfo) {
1829 // We record it in the dedup map to help detect errors when the user
1830 // marks the module as both NoDedup and MustDedup. We do not record this
1831 // module in the hasher to make sure no other module dedups "into" this
1832 // one.
1833 dedupMap[moduleName] = moduleName;
1834 continue;
1835 }
1836
1837 auto &moduleInfo = maybeModuleInfo.value();
1838 moduleToModuleInfo.try_emplace(module, &moduleInfo);
1839
1840 // Replace module names referred in the module with new names.
1841 for (auto &referredModule : moduleInfo.referredModuleNames)
1842 referredModule = dedupMap[referredModule];
1843
1844 // Check if there is a module with the same hash.
1845 auto it = moduleInfoToModule.find(&moduleInfo);
1846 if (it != moduleInfoToModule.end()) {
1847 auto original = cast<FModuleLike>(it->second);
1848 auto originalName = original.getModuleNameAttr();
1849
1850 // If the current module is public, and the original is private, we
1851 // want to dedup the private module into the public one.
1852 if (!canRemoveModule(module)) {
1853 // Record that this module's name is staying the same.
1854 dedupMap[moduleName] = moduleName;
1855 // If both modules are public, then we can't dedup anything.
1856 if (!canRemoveModule(original))
1857 continue;
1858 // Swap the canonical module in the dedup map.
1859 for (auto &[_, dedupedName] : dedupMap)
1860 if (dedupedName == originalName)
1861 dedupedName = moduleName;
1862 // Update the module hash table to point to the new original, so all
1863 // future modules dedup with the new canonical module.
1864 it->second = module;
1865 // Swap the locals.
1866 std::swap(originalName, moduleName);
1867 std::swap(original, module);
1868 }
1869
1870 // Record the group ID of the other module.
1871 LLVM_DEBUG(llvm::dbgs() << "- Replace " << moduleName << " with "
1872 << originalName << "\n");
1873 dedupMap[moduleName] = originalName;
1874 deduper.dedup(original, module);
1875 ++erasedModules;
1876 anythingChanged = true;
1877 continue;
1878 }
1879 // Any module not deduplicated must be recorded.
1880 deduper.record(module);
1881 // Add the module to a new dedup group.
1882 dedupMap[moduleName] = moduleName;
1883 // Record the module info.
1884 moduleInfoToModule[&moduleInfo] = module;
1885 }
1886
1887 // This part verifies that all modules marked by "MustDedup" have been
1888 // properly deduped with each other. For this check to succeed, all modules
1889 // have to been deduped to the same module. It is possible that a module was
1890 // deduped with the wrong thing.
1891
1892 auto failed = false;
1893 // This parses the module name out of a target string.
1894 auto parseModule = [&](Attribute path) -> StringAttr {
1895 // Each module is listed as a target "~Circuit|Module" which we have to
1896 // parse.
1897 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1898 return StringAttr::get(context, rhs);
1899 };
1900 // This gets the name of the module which the current module was deduped
1901 // with. If the named module isn't in the map, then we didn't encounter it
1902 // in the circuit.
1903 auto getLead = [&](StringAttr module) -> StringAttr {
1904 auto it = dedupMap.find(module);
1905 if (it == dedupMap.end()) {
1906 auto diag = emitError(circuit.getLoc(),
1907 "MustDeduplicateAnnotation references module ")
1908 << module << " which does not exist";
1909 failed = true;
1910 return nullptr;
1911 }
1912 return it->second;
1913 };
1914
1915 LLVM_DEBUG(llvm::dbgs() << "Update annotations\n");
1916 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1917 if (!annotation.isClass(mustDeduplicateAnnoClass))
1918 return false;
1919 auto modules = annotation.getMember<ArrayAttr>("modules");
1920 if (!modules) {
1921 emitError(circuit.getLoc(),
1922 "MustDeduplicateAnnotation missing \"modules\" member");
1923 failed = true;
1924 return false;
1925 }
1926 // Empty module list has nothing to process.
1927 if (modules.empty())
1928 return true;
1929 // Get the first element.
1930 auto firstModule = parseModule(modules[0]);
1931 auto firstLead = getLead(firstModule);
1932 if (!firstLead)
1933 return false;
1934 // Verify that the remaining elements are all the same as the first.
1935 for (auto attr : modules.getValue().drop_front()) {
1936 auto nextModule = parseModule(attr);
1937 auto nextLead = getLead(nextModule);
1938 if (!nextLead)
1939 return false;
1940 if (firstLead != nextLead) {
1941 auto diag = emitError(circuit.getLoc(), "module ")
1942 << nextModule << " not deduplicated with " << firstModule;
1943 auto a = instanceGraph.lookup(firstLead)->getModule();
1944 auto b = instanceGraph.lookup(nextLead)->getModule();
1945 equiv.check(diag, a, b);
1946 failed = true;
1947 return false;
1948 }
1949 }
1950 return true;
1951 });
1952 if (failed)
1953 return signalPassFailure();
1954
1955 // Remove all dedup group attributes, they only exist during this pass.
1956 for (auto module : circuit.getOps<FModuleLike>())
1957 module->removeDiscardableAttr(dedupGroupAttrName);
1958
1959 // Fixup all operations that we've found to be sensitive to symbol names.
1960 // This includes module and class instances, wires, connects, etc.
1961 fixupSymbolSensitiveOps(instanceGraph, moduleToModuleInfo, dedupMap);
1962
1963 markAnalysesPreserved<NLATable>();
1964 if (!anythingChanged)
1965 markAllAnalysesPreserved();
1966 }
1967};
1968} // end anonymous namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:960
static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type change.
Definition Dedup.cpp:1586
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:56
static void fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in an op.
Definition Dedup.cpp:1612
static void fixupSymbolSensitiveOps(InstanceGraph &instanceGraph, const DenseMap< Operation *, ModuleInfoRef > &moduleToModuleInfo, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in ops marked as sensitive to them.
Definition Dedup.cpp:1680
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
decltype(auto) walkPostOrder(Fn &&fn)
Perform a post-order walk across the modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
Definition FIRRTLEnums.h:44
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:36
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(const llvm::Twine &str, unsigned width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition Debug.cpp:17
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1148
MLIRContext * context
Definition Dedup.cpp:1554
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1562
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1090
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1212
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1393
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1106
void record(FModuleLike module)
Record the usages of any NLAs in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:1064
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1291
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1546
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:1029
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1283
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1185
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:1083
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1559
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1418
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1193
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1468
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
Definition Dedup.cpp:1370
StringAttr nonLocalString
Definition Dedup.cpp:1572
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:1076
SymbolTable & symbolTable
Definition Dedup.cpp:1556
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1513
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1576
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1299
InstanceGraph & instanceGraph
Definition Dedup.cpp:1555
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1532
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1565
StringAttr classString
Definition Dedup.cpp:1573
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1341
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:1013
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1222
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:1011
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1569
const hw::InnerSymbolTable & a
Definition Dedup.cpp:477
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:474
const hw::InnerSymbolTable & b
Definition Dedup.cpp:478
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:456
LogicalResult check(InFlightDiagnostic &diag, igraph::InstanceOpInterface a, igraph::InstanceOpInterface b)
Definition Dedup.cpp:764
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:947
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:481
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:553
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:527
StringAttr noDedupClass
Definition Dedup.cpp:942
StringAttr dedupGroupAttrName
Definition Dedup.cpp:944
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:676
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:631
StringAttr portDirectionsAttr
Definition Dedup.cpp:940
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:497
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:802
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:457
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:651
InstanceGraph & instanceGraph
Definition Dedup.cpp:948
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:891
A reference to a ModuleInfo that compares and hashes like it.
Definition Dedup.cpp:422
ModuleInfo * info
Definition Dedup.cpp:424
ModuleInfoRef(ModuleInfo *info)
Definition Dedup.cpp:423
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:91
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:87
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:85
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:101
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:131
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:102
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:151
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:274
void update(Type type)
Definition Dedup.cpp:251
void update(const void *pointer)
Definition Dedup.cpp:208
void update(ClassType type)
Definition Dedup.cpp:236
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:394
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:219
void update(Operation *op)
Definition Dedup.cpp:360
llvm::SHA256 sha
Definition Dedup.cpp:408
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:398
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:138
void update(size_t value)
Definition Dedup.cpp:213
void update(BundleType type)
Definition Dedup.cpp:227
unsigned getID(void *object)
Definition Dedup.cpp:170
void update(OpResult result)
Definition Dedup.cpp:259
void update(OpOperand &operand)
Definition Dedup.cpp:191
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:135
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:416
void update(Region *region)
Definition Dedup.cpp:352
void update(Block *block)
Definition Dedup.cpp:341
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:401
void update(TypeID typeID)
Definition Dedup.cpp:224
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:404
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:187
unsigned finalizeID(void *object)
Definition Dedup.cpp:178
void update(mlir::OperationName name)
Definition Dedup.cpp:335
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs)
Definition Dedup.cpp:443
static unsigned getHashValue(const ModuleInfoRef &ref)
Definition Dedup.cpp:431