CIRCT 23.0.0git
Loading...
Searching...
No Matches
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
24#include "circt/Support/Debug.h"
25#include "circt/Support/LLVM.h"
26#include "mlir/IR/IRMapping.h"
27#include "mlir/IR/Threading.h"
28#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/DenseMapInfo.h"
31#include "llvm/ADT/Hashing.h"
32#include "llvm/ADT/PostOrderIterator.h"
33#include "llvm/ADT/SmallPtrSet.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/Format.h"
36#include "llvm/Support/SHA256.h"
37
38#define DEBUG_TYPE "firrtl-dedup"
39
40namespace circt {
41namespace firrtl {
42#define GEN_PASS_DEF_DEDUP
43#include "circt/Dialect/FIRRTL/Passes.h.inc"
44} // namespace firrtl
45} // namespace circt
46
47using namespace circt;
48using namespace firrtl;
49using hw::InnerRefAttr;
50
51//===----------------------------------------------------------------------===//
52// Utility function for classifying a Symbol's dedup-ability.
53//===----------------------------------------------------------------------===//
54
55/// Returns true if the module can be removed.
56static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
57 // If the symbol is not private, it cannot be removed.
58 if (!symbol.isPrivate())
59 return false;
60 // Classes may be referenced in object types, so can not normally be removed
61 // if we can't find any symbol uses. Since we know that dedup will update the
62 // types of instances appropriately, we can ignore that and return true here.
63 if (isa<ClassLike>(*symbol))
64 return true;
65 // If module can not be removed even if no uses can be found, we can not
66 // delete it. The implication is that there are hidden symbol uses that dedup
67 // will not properly update.
68 if (!symbol.canDiscardOnUseEmpty())
69 return false;
70 // The module can be deleted.
71 return true;
72}
73
74//===----------------------------------------------------------------------===//
75// Hashing
76//===----------------------------------------------------------------------===//
77
78// This struct contains information to determine module module uniqueness. A
79// first element is a structural hash of the module, and the second element is
80// an array which tracks module names encountered in the walk. Since module
81// names could be replaced during dedup, it's necessary to keep names up-to-date
82// before actually combining them into structural hashes.
83struct ModuleInfo {
84 // SHA256 hash.
85 std::array<uint8_t, 32> structuralHash;
86 // Module names referred by instance op in the module.
87 std::vector<StringAttr> referredModuleNames;
88 // The operations that contain references to symbols that may be changed by
89 // dedup. These need a fixup pass after dedup. This is just an optimization
90 // and does not factor into the hash or the equality between `ModuleInfo`s.
91 std::vector<Operation *> symbolSensitiveOps;
92};
93
94static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
95 return lhs.structuralHash == rhs.structuralHash &&
97}
98
99/// This struct contains constant string attributes shared across different
100/// threads.
103 portTypesAttr = StringAttr::get(context, "portTypes");
104 moduleNameAttr = StringAttr::get(context, "moduleName");
105 portNamesAttr = StringAttr::get(context, "portNames");
106 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
107 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
108 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
109 nonessentialAttributes.insert(StringAttr::get(context, "name"));
110 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
111 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
112 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
113 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
114 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
115 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
116 };
117
118 // This is a cached "portTypes" string attr.
119 StringAttr portTypesAttr;
120
121 // This is a cached "moduleName" string attr.
122 StringAttr moduleNameAttr;
123
124 // This is a cached "moduleNames" string attr.
125 StringAttr moduleNamesAttr;
126
127 // This is a cached "portNames" string attr.
128 StringAttr portNamesAttr;
129
130 // This is a set of every attribute we should ignore.
131 DenseSet<Attribute> nonessentialAttributes;
132};
133
136 : constants(constants) {}
137
138 ModuleInfo getModuleInfo(FModuleLike module) {
140 update(&(*module));
141 return {sha.final(), std::move(referredModuleNames),
142 std::move(symbolSensitiveOps)};
143 }
144
145private:
146 /// Find all the ports and operations which may define an inner symbol
147 /// operations and give each a unique id. If the port/operation does define
148 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
149 /// field id. When we hash (local) references to this inner symbol, we will
150 /// hash in the id and the the field id.
151 void populateInnerSymIDTable(FModuleLike module) {
152 // Add port symbols. If no port has a symbol defined, the port symbol array
153 // will be totally empty.
154 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
155 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
156 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
157 }
158 // Add operation symbols.
159 size_t index = module.getNumPorts();
160 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
161 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
162 for (auto prop : innerSym)
163 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
164 }
165 ++index;
166 });
167 }
168
169 // Get the identifier for an object. The identifier is assigned on first use.
170 unsigned getID(void *object) {
171 auto [it, inserted] = idTable.try_emplace(object, nextID);
172 if (inserted)
173 ++nextID;
174 return it->second;
175 }
176
177 // Get the identifier for an IR object. Free the ID, too.
178 unsigned finalizeID(void *object) {
179 auto it = idTable.find(object);
180 if (it == idTable.end())
181 return nextID++;
182 auto id = it->second;
183 idTable.erase(it);
184 return id;
185 }
186
187 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
188 return innerSymIDTable.at(name);
189 }
190
191 void update(OpOperand &operand) {
192 auto value = operand.get();
193 if (auto result = dyn_cast<OpResult>(value)) {
194 auto *op = result.getOwner();
195 update(getID(op));
196 update(result.getResultNumber());
197 return;
198 }
199 if (auto argument = dyn_cast<BlockArgument>(value)) {
200 auto *block = argument.getOwner();
201 update(getID(block));
202 update(argument.getArgNumber());
203 return;
204 }
205 llvm_unreachable("Unknown value type");
206 }
207
208 void update(const void *pointer) {
209 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
210 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
211 }
212
213 void update(size_t value) {
214 auto *addr = reinterpret_cast<const uint8_t *>(&value);
215 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
216 }
217
218 template <typename T, typename U>
219 void update(const std::pair<T, U> &pair) {
220 update(pair.first);
221 update(pair.second);
222 }
223
224 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
225
226 // NOLINTNEXTLINE(misc-no-recursion)
227 void update(BundleType type) {
228 update(type.getTypeID());
229 for (auto &element : type.getElements()) {
230 update(element.isFlip);
231 update(element.type);
232 }
233 }
234
235 // NOLINTNEXTLINE(misc-no-recursion)
236 void update(ClassType type) {
237 update(type.getTypeID());
238 // Don't hash the class name directly, since it may be replaced during
239 // dedup. Record the class name instead and lazily combine their hashes
240 // using the same mechanism as instances and modules.
241 hasSeenSymbol = true;
242 referredModuleNames.push_back(type.getNameAttr().getAttr());
243 for (auto &element : type.getElements()) {
244 update(element.name.getAsOpaquePointer());
245 update(element.type);
246 update(static_cast<unsigned>(element.direction));
247 }
248 }
249
250 // NOLINTNEXTLINE(misc-no-recursion)
251 void update(Type type) {
252 if (auto bundle = type_dyn_cast<BundleType>(type))
253 return update(bundle);
254 if (auto klass = type_dyn_cast<ClassType>(type))
255 return update(klass);
256 update(type.getAsOpaquePointer());
257 }
258
259 void update(OpResult result) {
260 // Like instance ops, don't use object ops' result types since they might be
261 // replaced by dedup. Record the class names and lazily combine their hashes
262 // using the same mechanism as instances and modules.
263 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
264 hasSeenSymbol = true;
265 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
266 return;
267 }
268
269 update(result.getType());
270 }
271
272 /// Hash the top level attribute dictionary of the operation. This function
273 /// has special handling for inner symbols, ports, and referenced modules.
274 void update(Operation *op, DictionaryAttr dict) {
275 for (auto namedAttr : dict) {
276 auto name = namedAttr.getName();
277 auto value = namedAttr.getValue();
278
279 // Check whether this attribute contains a nested symbol, just to make
280 // sure we revisit this op after dedup and update any symbols.
281 value.walk([&](FlatSymbolRefAttr) { hasSeenSymbol = true; });
282
283 // Skip names and annotations, except in certain cases.
284 // Names of ports are load bearing for classes, so we do hash those.
285 bool isClassPortNames =
286 isa<ClassLike>(op) && name == constants.portNamesAttr;
287 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
288 continue;
289
290 // Hash the attribute name (an interned pointer).
291 update(name.getAsOpaquePointer());
292
293 // Hash the port types.
294 if (name == constants.portTypesAttr) {
295 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
296 for (auto type : portTypes)
297 update(type);
298 continue;
299 }
300
301 // For instance op, don't use `moduleName` attributes since they might be
302 // replaced by dedup. Record the names and lazily combine their hashes.
303 // It is assumed that module names are hashed only through instance ops;
304 // it could cause suboptimal results if there was other operation that
305 // refers to module names through essential attributes.
306 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
307 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
308 continue;
309 }
310
311 if (isa<InstanceChoiceOp>(op) && name == constants.moduleNamesAttr) {
312 for (auto module : cast<ArrayAttr>(value))
313 referredModuleNames.push_back(
314 cast<FlatSymbolRefAttr>(module).getAttr());
315 continue;
316 }
317
318 // TODO: properly handle DistinctAttr, including its use in paths.
319 // See https://github.com/llvm/circt/issues/6583.
320 if (isa<DistinctAttr>(value))
321 continue;
322
323 // If this is an symbol reference, we need to perform name erasure.
324 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
325 update(getInnerSymID(innerRef.getName()));
326 continue;
327 }
328
329 // We don't need to handle this attribute specially, so hash its unique
330 // address.
331 update(value.getAsOpaquePointer());
332 }
333 }
334
335 void update(mlir::OperationName name) {
336 // Operation names are interned.
337 update(name.getAsOpaquePointer());
338 }
339
340 // NOLINTNEXTLINE(misc-no-recursion)
341 void update(Block *block) {
342 for (auto &op : llvm::reverse(*block))
343 update(&op);
344 for (auto type : block->getArgumentTypes())
345 update(type);
346 update(finalizeID(block));
347 update(position);
348 ++position;
349 }
350
351 // NOLINTNEXTLINE(misc-no-recursion)
352 void update(Region *region) {
353 for (auto &block : llvm::reverse(region->getBlocks()))
354 update(&block);
355 update(position);
356 ++position;
357 }
358
359 // NOLINTNEXTLINE(misc-no-recursion)
360 void update(Operation *op) {
361 // Hash the regions. We need to make sure an empty region doesn't hash the
362 // same as no region, so we include the number of regions.
363 update(op->getNumRegions());
364 for (auto &region : reverse(op->getRegions()))
365 update(&region);
366
367 update(op->getName());
368
369 // Record the uses for later hashing.
370 for (auto &operand : op->getOpOperands())
371 update(operand);
372
373 // This happens after the numbering above, as it uses blockarg numbering
374 // for inner symbols.
375 hasSeenSymbol = false;
376 update(op, op->getAttrDictionary());
377
378 // Record any op results (types).
379 for (auto result : op->getResults())
380 update(result);
381
382 // If any of the operands, attributes, or results depended on symbols,
383 // remember this op for later adjustment.
384 if (hasSeenSymbol)
385 symbolSensitiveOps.push_back(op);
386
387 // Incorporate the hash of uses we have already built.
388 update(finalizeID(op));
389 update(position);
390 ++position;
391 }
392
393 // A map from an operation/block, to its identifier.
394 DenseMap<void *, unsigned> idTable;
395 unsigned nextID = 0;
396
397 // A map from an inner symbol, to its identifier.
398 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
399
400 // This keeps track of module names in the order of the appearance.
401 std::vector<StringAttr> referredModuleNames;
402
403 // String constants.
405
406 // This is the actual running hash calculation. This is a stateful element
407 // that should be reinitialized after each hash is produced.
408 llvm::SHA256 sha;
409
410 // The index of the current op. Increment after handling each op.
411 size_t position = 0;
412
413 // The operations that contain references to symbols that may be changed by
414 // dedup. These need a fixup pass after dedup.
415 bool hasSeenSymbol = false;
416 std::vector<Operation *> symbolSensitiveOps;
417};
418
419/// A reference to a `ModuleInfo` that compares and hashes like it. This is used
420/// to keep the potentially heavy module infos in a vector, and then populate
421/// maps with references to them.
426
427/// Allow `ModuleInfoRef` to be used as dense map keys. Hashes and compares the
428/// `ModuleInfo` the ref points to.
429template <>
434
438
439 static unsigned getHashValue(const ModuleInfoRef &ref) {
440 // We assume SHA256 is already a good hash and just truncate down to the
441 // number of bytes we need for DenseMap.
442 unsigned hash;
443 std::memcpy(&hash, ref.info->structuralHash.data(), sizeof(unsigned));
444
445 // Combine module names.
446 return llvm::hash_combine(
447 hash, llvm::hash_combine_range(ref.info->referredModuleNames.begin(),
448 ref.info->referredModuleNames.end()));
449 }
450
451 static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs) {
452 auto *empty = getEmptyKey().info;
453 auto *tombstone = getTombstoneKey().info;
454 if (lhs.info == empty || rhs.info == empty || lhs.info == tombstone ||
455 rhs.info == tombstone)
456 return lhs.info == rhs.info;
457 return *lhs.info == *rhs.info;
458 }
459};
460
461//===----------------------------------------------------------------------===//
462// Equivalence
463//===----------------------------------------------------------------------===//
464
465/// This class is for reporting differences between two modules which should
466/// have been deduplicated.
470 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
471 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
472 portDirectionsAttr = StringAttr::get(context, "portDirections");
473 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
474 nonessentialAttributes.insert(StringAttr::get(context, "name"));
475 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
476 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
477 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
478 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
479 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
480 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
481 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
482 }
483
491
492 std::string prettyPrint(Attribute attr) {
493 SmallString<64> buffer;
494 llvm::raw_svector_ostream os(buffer);
495 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
496 os << "0x";
497 if (integerAttr.getType().isSignlessInteger())
498 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
499 else
500 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
501
502 } else
503 os << attr;
504 return std::string(buffer);
505 }
506
507 // NOLINTNEXTLINE(misc-no-recursion)
508 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
509 Operation *a, BundleType aType, Operation *b,
510 BundleType bType) {
511 if (aType.getNumElements() != bType.getNumElements()) {
512 diag.attachNote(a->getLoc())
513 << message << " bundle type has different number of elements";
514 diag.attachNote(b->getLoc()) << "second operation here";
515 return failure();
516 }
517
518 for (auto elementPair :
519 llvm::zip(aType.getElements(), bType.getElements())) {
520 auto aElement = std::get<0>(elementPair);
521 auto bElement = std::get<1>(elementPair);
522 if (aElement.isFlip != bElement.isFlip) {
523 diag.attachNote(a->getLoc()) << message << " bundle element "
524 << aElement.name << " flip does not match";
525 diag.attachNote(b->getLoc()) << "second operation here";
526 return failure();
527 }
528
529 if (failed(check(diag,
530 message + " -> bundle element \'" +
531 aElement.name.getValue() + "'",
532 a, aElement.type, b, bElement.type)))
533 return failure();
534 }
535 return success();
536 }
537
538 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
539 Operation *a, Type aType, Operation *b, Type bType) {
540 if (aType == bType)
541 return success();
542 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
543 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
544 return check(diag, message, a, aBundleType, b, bBundleType);
545 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
546 aType != bType) {
547 diag.attachNote(a->getLoc())
548 << message << ", has a RefType with a different base type "
549 << type_cast<RefType>(aType).getType()
550 << " in the same position of the two modules marked as 'must dedup'. "
551 "(This may be due to Grand Central Taps or Views being different "
552 "between the two modules.)";
553 diag.attachNote(b->getLoc())
554 << "the second module has a different base type "
555 << type_cast<RefType>(bType).getType();
556 return failure();
557 }
558 diag.attachNote(a->getLoc())
559 << message << " types don't match, first type is " << aType;
560 diag.attachNote(b->getLoc()) << "second type is " << bType;
561 return failure();
562 }
563
564 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
565 Block &aBlock, Operation *b, Block &bBlock) {
566
567 // Block argument types.
568 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
569 auto portNo = 0;
570 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
571 Operation *opDoesNotExist) {
572 StringRef portName;
573 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
574 if (portNames)
575 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
576 portName = portNameAttr.getValue();
577 if (type_isa<RefType>(existsVal.getType())) {
578 diag.attachNote(opExists->getLoc())
579 << " contains a RefType port named '" + portName +
580 "' that only exists in one of the modules (can be due to "
581 "difference in Grand Central Tap or View of two modules "
582 "marked with must dedup)";
583 diag.attachNote(opDoesNotExist->getLoc())
584 << "second module to be deduped that does not have the RefType "
585 "port";
586 } else {
587 diag.attachNote(opExists->getLoc())
588 << "port '" + portName + "' only exists in one of the modules";
589 diag.attachNote(opDoesNotExist->getLoc())
590 << "second module to be deduped that does not have the port";
591 }
592 return failure();
593 };
594
595 for (auto argPair :
596 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
597 auto &aArg = std::get<0>(argPair);
598 auto &bArg = std::get<1>(argPair);
599 if (aArg.has_value() && bArg.has_value()) {
600 // TODO: we should print the port number if there are no port names, but
601 // there are always port names ;).
602 StringRef portName;
603 if (portNames) {
604 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
605 portName = portNameAttr.getValue();
606 }
607 // Assumption here that block arguments correspond to ports.
608 if (failed(check(diag, "module port '" + portName + "'", a,
609 aArg->getType(), b, bArg->getType())))
610 return failure();
611 data.map.map(aArg.value(), bArg.value());
612 portNo++;
613 continue;
614 }
615 if (!aArg.has_value())
616 std::swap(a, b);
617 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
618 b);
619 }
620
621 // Blocks operations.
622 auto aIt = aBlock.begin();
623 auto aEnd = aBlock.end();
624 auto bIt = bBlock.begin();
625 auto bEnd = bBlock.end();
626 while (aIt != aEnd && bIt != bEnd)
627 if (failed(check(diag, data, &*aIt++, &*bIt++)))
628 return failure();
629 if (aIt != aEnd) {
630 diag.attachNote(aIt->getLoc()) << "first block has more operations";
631 diag.attachNote(b->getLoc()) << "second block here";
632 return failure();
633 }
634 if (bIt != bEnd) {
635 diag.attachNote(bIt->getLoc()) << "second block has more operations";
636 diag.attachNote(a->getLoc()) << "first block here";
637 return failure();
638 }
639 return success();
640 }
641
642 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
643 Region &aRegion, Operation *b, Region &bRegion) {
644 auto aIt = aRegion.begin();
645 auto aEnd = aRegion.end();
646 auto bIt = bRegion.begin();
647 auto bEnd = bRegion.end();
648
649 // Region blocks.
650 while (aIt != aEnd && bIt != bEnd)
651 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
652 return failure();
653 if (aIt != aEnd || bIt != bEnd) {
654 diag.attachNote(a->getLoc())
655 << "operation regions have different number of blocks";
656 diag.attachNote(b->getLoc()) << "second operation here";
657 return failure();
658 }
659 return success();
660 }
661
662 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
663 mlir::DenseBoolArrayAttr aAttr, Operation *b,
664 mlir::DenseBoolArrayAttr bAttr) {
665 if (aAttr == bAttr)
666 return success();
667 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
668 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
669 auto aDirection = aAttr[i];
670 auto bDirection = bAttr[i];
671 if (aDirection != bDirection) {
672 auto &note = diag.attachNote(a->getLoc()) << "module port ";
673 if (portNames)
674 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
675 else
676 note << i;
677 note << " directions don't match, first direction is '"
678 << direction::toString(aDirection) << "'";
679 diag.attachNote(b->getLoc()) << "second direction is '"
680 << direction::toString(bDirection) << "'";
681 return failure();
682 }
683 }
684 return success();
685 }
686
687 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
688 DictionaryAttr aDict, Operation *b,
689 DictionaryAttr bDict) {
690 // Fast path.
691 if (aDict == bDict)
692 return success();
693
694 DenseSet<Attribute> seenAttrs;
695 for (auto namedAttr : aDict) {
696 auto attrName = namedAttr.getName();
697 if (nonessentialAttributes.contains(attrName))
698 continue;
699
700 auto aAttr = namedAttr.getValue();
701 auto bAttr = bDict.get(attrName);
702 if (!bAttr) {
703 diag.attachNote(a->getLoc())
704 << "second operation is missing attribute " << attrName;
705 diag.attachNote(b->getLoc()) << "second operation here";
706 return diag;
707 }
708
709 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
710 auto bRef = cast<hw::InnerRefAttr>(bAttr);
711 auto aRef = cast<hw::InnerRefAttr>(aAttr);
712 // See if they are pointing at the same operation or port.
713 auto aTarget = data.a.lookup(aRef.getName());
714 auto bTarget = data.b.lookup(bRef.getName());
715 if (!aTarget || !bTarget)
716 diag.attachNote(a->getLoc())
717 << "malformed ir, possibly violating use-before-def";
718 auto error = [&]() {
719 diag.attachNote(a->getLoc())
720 << "operations have different targets, first operation has "
721 << aTarget;
722 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
723 return failure();
724 };
725 if (aTarget.isPort()) {
726 // If they are targeting ports, make sure its the same port number.
727 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
728 return error();
729 } else {
730 // Otherwise make sure that they are targeting the same operation.
731 if (!bTarget.isOpOnly() ||
732 data.map.lookupOrNull(aTarget.getOp()) != bTarget.getOp())
733 return error();
734 }
735 if (aTarget.getField() != bTarget.getField())
736 return error();
737 } else if (attrName == portDirectionsAttr) {
738 // Special handling for the port directions attribute for better
739 // error messages.
740 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
741 cast<mlir::DenseBoolArrayAttr>(bAttr))))
742 return failure();
743 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
744 // TODO: properly handle DistinctAttr, including its use in paths.
745 // See https://github.com/llvm/circt/issues/6583
746 } else if (aAttr != bAttr) {
747 diag.attachNote(a->getLoc())
748 << "first operation has attribute '" << attrName.getValue()
749 << "' with value " << prettyPrint(aAttr);
750 diag.attachNote(b->getLoc())
751 << "second operation has value " << prettyPrint(bAttr);
752 return failure();
753 }
754 seenAttrs.insert(attrName);
755 }
756 if (aDict.getValue().size() != bDict.getValue().size()) {
757 for (auto namedAttr : bDict) {
758 auto attrName = namedAttr.getName();
759 // Skip the attribute if we don't care about this particular one or it
760 // is one that is known to be in both dictionaries.
761 if (nonessentialAttributes.contains(attrName) ||
762 seenAttrs.contains(attrName))
763 continue;
764 // We have found an attribute that is only in the second operation.
765 diag.attachNote(a->getLoc())
766 << "first operation is missing attribute " << attrName;
767 diag.attachNote(b->getLoc()) << "second operation here";
768 return failure();
769 }
770 }
771 return success();
772 }
773
774 // NOLINTNEXTLINE(misc-no-recursion)
775 LogicalResult check(InFlightDiagnostic &diag, igraph::InstanceOpInterface a,
776 igraph::InstanceOpInterface b) {
777 // Get the list of module names from the list (for InstanceOp/ObjectOp,
778 // there's only one)
779 auto aNames = a.getReferencedModuleNamesAttr();
780 auto bNames = b.getReferencedModuleNamesAttr();
781 if (aNames == bNames)
782 return success();
783
784 if (aNames.size() != bNames.size()) {
785 diag.attachNote(a->getLoc())
786 << "an instance has a different number of referenced "
787 "modules: first instance has "
788 << aNames.size() << " modules";
789 diag.attachNote(b->getLoc())
790 << "second instance has " << bNames.size() << " modules";
791 return failure();
792 }
793
794 for (auto [aName, bName] : llvm::zip(aNames.getAsRange<StringAttr>(),
795 bNames.getAsRange<StringAttr>())) {
796 if (aName == bName)
797 continue;
798 // If the modules instantiate are different we will want to know why the
799 // sub module did not dedupliate. This code recursively checks the child
800 // module.
801 auto aModule = instanceGraph.lookup(aName)->getModule();
802 auto bModule = instanceGraph.lookup(bName)->getModule();
803 // Create a new error for the submodule.
804 diag.attachNote(std::nullopt)
805 << "in instance " << a.getInstanceNameAttr() << " of " << aName
806 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
807 check(diag, aModule, bModule);
808 }
809 return failure();
810 }
811
812 // NOLINTNEXTLINE(misc-no-recursion)
813 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
814 Operation *b) {
815 // Operation name.
816 if (a->getName() != b->getName()) {
817 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
818 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
819 return failure();
820 }
821
822 // If it's a firrtl operation that implements InstanceOpInterface
823 // (InstanceOp/InstanceChoiceOp/ObjectOp) perform some checking and possibly
824 // recurse.
825 if (auto aInst = dyn_cast<igraph::InstanceOpInterface>(a))
826 if (auto bInst = dyn_cast<igraph::InstanceOpInterface>(b))
827 if (isa_and_nonnull<firrtl::FIRRTLDialect>(a->getDialect()) &&
828 isa_and_nonnull<firrtl::FIRRTLDialect>(b->getDialect()) &&
829 failed(check(diag, aInst, bInst)))
830 return failure();
831
832 // Operation results.
833 if (a->getNumResults() != b->getNumResults()) {
834 diag.attachNote(a->getLoc())
835 << "operations have different number of results";
836 diag.attachNote(b->getLoc()) << "second operation here";
837 return failure();
838 }
839 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
840 auto &aValue = std::get<0>(resultPair);
841 auto &bValue = std::get<1>(resultPair);
842 if (failed(check(diag, "operation result", a, aValue.getType(), b,
843 bValue.getType())))
844 return failure();
845 data.map.map(aValue, bValue);
846 }
847
848 // Operations operands.
849 if (a->getNumOperands() != b->getNumOperands()) {
850 diag.attachNote(a->getLoc())
851 << "operations have different number of operands";
852 diag.attachNote(b->getLoc()) << "second operation here";
853 return failure();
854 }
855 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
856 auto &aValue = std::get<0>(operandPair);
857 auto &bValue = std::get<1>(operandPair);
858 if (bValue != data.map.lookup(aValue)) {
859 diag.attachNote(a->getLoc())
860 << "operations use different operands, first operand is '"
861 << getFieldName(
862 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
863 .first
864 << "'";
865 diag.attachNote(b->getLoc())
866 << "second operand is '"
867 << getFieldName(
868 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
869 .first
870 << "', but should have been '"
871 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
872 /*lookThroughCasts=*/true))
873 .first
874 << "'";
875 return failure();
876 }
877 }
878 data.map.map(a, b);
879
880 // Operation regions.
881 if (a->getNumRegions() != b->getNumRegions()) {
882 diag.attachNote(a->getLoc())
883 << "operations have different number of regions";
884 diag.attachNote(b->getLoc()) << "second operation here";
885 return failure();
886 }
887 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
888 auto &aRegion = std::get<0>(regionPair);
889 auto &bRegion = std::get<1>(regionPair);
890 if (failed(check(diag, data, a, aRegion, b, bRegion)))
891 return failure();
892 }
893
894 // Operation attributes.
895 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
896 b->getAttrDictionary())))
897 return failure();
898 return success();
899 }
900
// Explain why the modules `a` and `b` were not deduplicated by attaching notes
// to `diag`. The checks below run in order; the first one that finds a
// difference attaches its notes and returns. If none of them finds a
// difference, notes pointing at both modules are attached instead.
901 // NOLINTNEXTLINE(misc-no-recursion)
902 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
903 hw::InnerSymbolTable aTable(a);
904 hw::InnerSymbolTable bTable(b);
905 ModuleData data(aTable, bTable);
// Report a NoDedup marking on the first module.
907 diag.attachNote(a->getLoc()) << "module marked NoDedup";
908 return;
909 }
// Report a NoDedup marking on the second module.
911 diag.attachNote(b->getLoc()) << "module marked NoDedup";
912 return;
913 }
914 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
915 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
// Visibility is only reported when neither of the two modules can be
// removed; the note text distinguishes public from private modules.
916 if (!canRemoveModule(aSymbol) && !canRemoveModule(bSymbol)) {
917 diag.attachNote(a->getLoc())
918 << "module is "
919 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
920 diag.attachNote(b->getLoc())
921 << "module is "
922 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
923 return;
924 }
// Both modules must agree on their dedup group attribute (or both lack it).
// NOTE(review): the two lookups use different accessors (getDiscardableAttr
// vs. getAttrOfType<StringAttr>), and the dyn_cast_or_null wrapped around the
// latter looks redundant -- consider unifying them.
925 auto aGroup =
926 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
927 auto bGroup = dyn_cast_or_null<StringAttr>(
928 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
929 if (aGroup != bGroup) {
930 if (bGroup) {
931 diag.attachNote(b->getLoc())
932 << "module is in dedup group '" << bGroup.str() << "'";
933 } else {
934 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
935 }
936 if (aGroup) {
937 diag.attachNote(a->getLoc())
938 << "module is in dedup group '" << aGroup.str() << "'";
939 } else {
940 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
941 }
942 return;
943 }
// Compare the two operations themselves; on failure that comparison has
// already attached its own notes.
944 if (failed(check(diag, data, a, b)))
945 return;
// No difference was found: just point at the two modules.
946 diag.attachNote(a->getLoc()) << "first module here";
947 diag.attachNote(b->getLoc()) << "second module here";
948 }
949
950 // This is a cached "portDirections" string attr.
952 // This is a cached "NoDedup" annotation class string attr.
953 StringAttr noDedupClass;
954 // This is a cached string attr for the dedup group attribute.
956
957 // This is a set of every attribute we should ignore.
958 DenseSet<Attribute> nonessentialAttributes;
960};
961
962//===----------------------------------------------------------------------===//
963// Deduplication
964//===----------------------------------------------------------------------===//
965
966// Custom location merging. This only keeps track of 8 annotations from ".fir"
967// files, and however many annotations come from "real" sources. When
968// deduplicating, modules tend not to have scala source locators, so we wind
969// up fusing source locators for a module from every copy being deduped. There
970// is little value in this (all the modules are identical by definition).
971static Location mergeLoc(MLIRContext *context, Location to, Location from) {
972 // Unique the set of locations to be fused.
973 llvm::SmallSetVector<Location, 4> decomposedLocs;
974 // only track 8 "fir" locations
975 unsigned seenFIR = 0;
976 for (auto loc : {to, from}) {
977 // If the location is a fused location we decompose it if it has no
978 // metadata or the metadata is the same as the top level metadata.
979 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
980 // UnknownLoc's have already been removed from FusedLocs so we can
981 // simply add all of the internal locations.
982 for (auto loc : fusedLoc.getLocations()) {
983 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
984 if (fileLoc.getFilename().strref().ends_with(".fir")) {
985 ++seenFIR;
986 if (seenFIR > 8)
987 continue;
988 }
989 }
990 decomposedLocs.insert(loc);
991 }
992 continue;
993 }
994
995 // Might need to skip this fir.
996 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
997 if (fileLoc.getFilename().strref().ends_with(".fir")) {
998 ++seenFIR;
999 if (seenFIR > 8)
1000 continue;
1001 }
1002 }
1003 // Otherwise, only add known locations to the set.
1004 if (!isa<UnknownLoc>(loc))
1005 decomposedLocs.insert(loc);
1006 }
1007
1008 auto locs = decomposedLocs.getArrayRef();
1009
1010 // Handle the simple cases of less than two locations. Ensure the metadata (if
1011 // provided) is not dropped.
1012 if (locs.empty())
1013 return UnknownLoc::get(context);
1014 if (locs.size() == 1)
1015 return locs.front();
1016
1017 return FusedLoc::get(context, locs);
1018}
1019
1020struct Deduper {
1021
1022 using RenameMap = DenseMap<StringAttr, StringAttr>;
1023
1025 NLATable *nlaTable, CircuitOp circuit)
1026 : context(circuit->getContext()), instanceGraph(instanceGraph),
1028 nlaBlock(circuit.getBodyBlock()),
1029 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
1030 classString(StringAttr::get(context, "class")) {
1031 // Populate the NLA cache: map the namepath of every hierarchical path op
// already in the circuit to its symbol name, so that createNLAs can reuse an
// existing path instead of creating a duplicate.
1032 for (auto nla : circuit.getOps<hw::HierPathOp>())
1033 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
1034 }
1035
1036 /// Remove the "fromModule", and replace all references to it with the
1037 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
1038 /// module which is not deduplicated needs to be recorded with the `record`
1039 /// call.
1040 void dedup(FModuleLike toModule, FModuleLike fromModule) {
1041 // A map of operation (e.g. wires, nodes) names which are changed, which is
1042 // used to update NLAs that reference the "fromModule".
1043 RenameMap renameMap;
1044
1045 // Merge the port locations.
1046 SmallVector<Attribute> newLocs;
1047 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
1048 fromModule.getPortLocations())) {
1049 if (toLoc == fromLoc)
1050 newLocs.push_back(toLoc);
1051 else
1052 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
1053 cast<LocationAttr>(fromLoc)));
1054 }
1055 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
1056
1057 // Merge the two modules.
1058 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
1059
1060 // Rewrite NLAs pathing through these modules to refer to the to module. It
1061 // is safe to do this at this point because NLAs cannot be one element long.
1062 // This means that all NLAs which require more context cannot be targetting
1063 // something in the module it self.
1064 if (auto to = dyn_cast<FModuleOp>(*toModule))
1065 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
1066 else
1067 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
1068 fromModule.getModuleNameAttr());
1069
1070 replaceInstances(toModule, fromModule);
1071 }
1072
1073 /// Record the usages of any NLA's in this module, so that we may update the
1074 /// annotation if the parent module is deduped with another module.
/// Covers the module's own annotations, its ports, and every operation
/// visited by walking the module.
1075 void record(FModuleLike module) {
1076 // Record any annotations on the module.
1077 recordAnnotations(module);
1078 // Record port annotations.
1079 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
1081 // Record any annotations in the module body.
1082 module->walk([&](Operation *op) { recordAnnotations(op); });
1083 }
1084
1085private:
1086 /// Get a cached namespace for a module. The namespace is constructed from
/// the module and stored in `moduleNamespaces` the first time it is requested;
/// later requests return the stored entry.
1088 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
1089 .first->second;
1090 }
1091
1092 /// For a specific annotation target, record all the unique NLAs which
1093 /// target it in the `targetMap`. Annotations without a "circt.nonlocal"
/// symbol reference are ignored.
1095 for (auto anno : target.getAnnotations())
1096 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1097 targetMap[nlaRef.getAttr()].insert(target);
1098 }
1099
1100 /// Record all targets which use an NLA.
1101 void recordAnnotations(Operation *op) {
1102 // Record annotations.
1103 recordAnnotations(OpAnnoTarget(op));
1104
1105 // Record port annotations only if this is a mem operation.
1106 auto mem = dyn_cast<MemOp>(op);
1107 if (!mem)
1108 return;
1109
1110 // Record port annotations.
1111 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1112 recordAnnotations(PortAnnoTarget(mem, i));
1113 }
1114
1115 /// This deletes and replaces all instances of the "fromModule" with instances
1116 /// of the "toModule".
1117 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1118 // Replace all instances of the other module.
1119 auto *fromNode =
1120 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1121 auto *toNode = instanceGraph[toModule];
1122 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1123 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1124 auto inst = oldInstRec->getInstance();
1125 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1126 instOp.setModuleNameAttr(toModuleRef);
1127 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1128 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1129 auto classLike = cast<ClassLike>(*toNode->getModule());
1130 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1131 objectOp.getResult().setType(classType);
1132 } else if (auto instanceChoiceOp = dyn_cast<InstanceChoiceOp>(*inst)) {
1133 auto fromModuleName = fromNode->getModule().getModuleNameAttr();
1134 SmallVector<Attribute> newModules;
1135 for (auto module : instanceChoiceOp.getReferencedModuleNamesAttr()) {
1136 auto moduleName = cast<StringAttr>(module);
1137 if (moduleName == fromModuleName)
1138 newModules.push_back(toModuleRef);
1139 else
1140 newModules.push_back(FlatSymbolRefAttr::get(moduleName));
1141 }
1142 instanceChoiceOp.setModuleNamesAttr(
1143 ArrayAttr::get(context, newModules));
1144 instanceChoiceOp.setPortNamesAttr(toModule.getPortNamesAttr());
1145 }
1146 oldInstRec->getParent()->addInstance(inst, toNode);
1147 oldInstRec->erase();
1148 }
1149 instanceGraph.erase(fromNode);
1150 fromModule->erase();
1151 }
1152
1153 /// Look up the instantiations of the `from` module and create an NLA for each
1154 /// one, appending the baseNamepath to each NLA. This is used to add more
1155 /// context to an already existing NLA. The `fromModule` is used to indicate
1156 /// which module the annotation is coming from before the merge, and will be
1157 /// used to create the namepaths.
1158 SmallVector<FlatSymbolRefAttr>
1159 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1160 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1161 // Create an attribute array with a placeholder in the first element, where
1162 // the root refence of the NLA will be inserted.
1163 SmallVector<Attribute> namepath = {nullptr};
1164 namepath.append(baseNamepath.begin(), baseNamepath.end());
1165
1166 auto loc = fromModule->getLoc();
1167 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1168 SmallVector<FlatSymbolRefAttr> nlas;
1169 for (auto *instanceRecord : fromNode->uses()) {
1170 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1171 auto inst = instanceRecord->getInstance();
1172 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1173 auto arrayAttr = ArrayAttr::get(context, namepath);
1174 // Check the NLA cache to see if we already have this NLA.
1175 auto &cacheEntry = nlaCache[arrayAttr];
1176 if (!cacheEntry) {
1177 auto builder = OpBuilder::atBlockBegin(nlaBlock);
1178 auto nla = hw::HierPathOp::create(builder, loc, "nla", arrayAttr);
1179 // Insert it into the symbol table to get a unique name.
1180 symbolTable.insert(nla);
1181 // Store it in the cache.
1182 cacheEntry = nla.getNameAttr();
1183 nla.setVisibility(vis);
1184 nlaTable->addNLA(nla);
1185 }
1186 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1187 nlas.push_back(nlaRef);
1188 }
1189 return nlas;
1190 }
1191
1192 /// Look up the instantiations of this module and create an NLA for each one.
1193 /// This returns an array of symbol references which can be used to reference
1194 /// the NLAs.
1195 SmallVector<FlatSymbolRefAttr>
1196 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1197 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1198 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1199 }
1200
1201 /// Clone the annotation for each NLA in a list. The attribute list should
1202 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1203 /// should be the index of this field.
1204 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1205 Annotation anno, ArrayRef<NamedAttribute> attributes,
1206 unsigned nonLocalIndex,
1207 SmallVectorImpl<Annotation> &newAnnotations) {
1208 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1209 attributes.end());
1210 for (auto &nla : nlas) {
1211 // Add the new annotation.
1212 mutableAttributes[nonLocalIndex].setValue(nla);
1213 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1214 // The original annotation records if its a subannotation.
1215 anno.setDict(dict);
1216 newAnnotations.push_back(anno);
1217 }
1218 }
1219
1220 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1221 /// but it does not delete the NLA reference from the target operation's
1222 /// annotations.
1223 void eraseNLA(hw::HierPathOp nla) {
1224 // Erase the NLA from the leaf module's nlaMap.
1225 targetMap.erase(nla.getNameAttr());
1226 nlaTable->erase(nla);
// Drop the namepath -> name cache entry so createNLAs will not hand out the
// name of an erased op.
1227 nlaCache.erase(nla.getNamepathAttr());
// Must come last: the statements above still read attributes of `nla`.
1228 symbolTable.erase(nla);
1229 }
1230
1231 /// Process all NLAs referencing the "from" module to point to the "to"
1232 /// module. This is used after merging two modules together.
/// NLAs that end up rooted at the "to" module are replaced by one new NLA per
/// instantiation of `fromModule`, and every annotation using the old NLA is
/// cloned once for each replacement.
1233 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1234 FModuleOp fromModule) {
1235 auto toName = toModule.getNameAttr();
1236 auto fromName = fromModule.getNameAttr();
1237 // Create a copy of the current NLAs. We will be pushing and removing
1238 // NLAs from this op as we go.
1239 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1240 // Change the NLA to target the toModule.
1241 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1242 // Now we walk the NLA searching for ones that require more context to be
1243 // added.
1244 for (auto nla : moduleNLAs) {
1245 auto elements = nla.getNamepath().getValue();
1246 // If we don't need to add more context, we're done here.
// Only NLAs whose root is the "to" module are extended.
1247 if (nla.root() != toName)
1248 continue;
1249 // Create the replacement NLAs.
1250 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1251 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1252 // Copy out the targets, because we will be updating the map.
1253 auto &set = targetMap[nla.getSymNameAttr()];
1254 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1255 // Replace the uses of the old NLA with the new NLAs.
1256 for (auto target : targets) {
1257 // We have to clone any annotation which uses the old NLA for each new
1258 // NLA. This array collects the new set of annotations.
1259 SmallVector<Annotation> newAnnotations;
1260 for (auto anno : target.getAnnotations()) {
1261 // Find the non-local field of the annotation.
1262 auto [it, found] = mlir::impl::findAttrSorted(
1263 anno.begin(), anno.end(), nonLocalString);
1264 // If this annotation doesn't use the target NLA, copy it with no
1265 // changes.
1266 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1267 nla.getSymNameAttr()) {
1268 newAnnotations.push_back(anno);
1269 continue;
1270 }
1271 auto nonLocalIndex = std::distance(anno.begin(), it);
1272 // Clone the annotation and add it to the list of new annotations.
1273 cloneAnnotation(nlaRefs, anno,
1274 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1275 nonLocalIndex, newAnnotations);
1276 }
1277
1278 // Apply the new annotations to the operation.
1279 AnnotationSet annotations(newAnnotations, context);
1280 target.setAnnotations(annotations);
1281 // Record that target uses the NLA.
// Note: this loop variable shadows the outer `nla`; here it is a reference to
// one of the replacement NLAs.
1282 for (auto nla : nlaRefs)
1283 targetMap[nla.getAttr()].insert(target);
1284 }
1285
1286 // Erase the old NLA and remove it from all breadcrumbs.
1287 eraseNLA(nla);
1288 }
1289 }
1290
1291 /// Process all the NLAs that the two modules participate in, replacing
1292 /// references to the "from" module with references to the "to" module, and
1293 /// adding more context if necessary.
1294 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1295 FModuleOp fromModule) {
1296 addAnnotationContext(renameMap, toModule, toModule);
1297 addAnnotationContext(renameMap, toModule, fromModule);
1298 }
1299
1300 // Update all NLAs which the "from" external module participates in to the
1301 // "toName". This only forwards to the NLA table's rename; unlike
// rewriteModuleNLAs, no new NLAs are created here.
1302 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1303 StringAttr fromName) {
1304 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1305 }
1306
1307 /// Take a local annotation and turn it into a non-local one: a
1308 /// "circt.nonlocal" field is inserted at its sorted position and the
1309 /// annotation is cloned once for every NLA created for the instantiations of
/// `fromModule`. Annotations which are already non-local must not be passed
/// here (this is asserted). The clones are appended to `newAnnotations` and
/// `to` is recorded in the target map for each NLA. Always returns true.
1310 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1311 FModuleLike fromModule, Annotation anno,
1312 SmallVectorImpl<Annotation> &newAnnotations) {
1313 // Start constructing a new annotation, pushing a "circt.nonlocal" field
1314 // into the correct spot if its not already a non-local annotation.
1315 SmallVector<NamedAttribute> attributes;
1316 int nonLocalIndex = -1;
1317 for (const auto &val : llvm::enumerate(anno)) {
1318 auto attr = val.value();
1319 // Is this field "circt.nonlocal"?
1320 auto compare = attr.getName().compare(nonLocalString);
1321 assert(compare != 0 && "should not pass non-local annotations here");
1322 if (compare == 1) {
1323 // This annotation definitely does not have "circt.nonlocal" field. Push
1324 // an empty place holder for the non-local annotation.
// The placeholder value is overwritten for each NLA by cloneAnnotation.
1325 nonLocalIndex = val.index();
1326 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1327 break;
1328 }
1329 // Otherwise push the current attribute and keep searching for the
1330 // "circt.nonlocal" field.
1331 attributes.push_back(attr);
1332 }
1333 if (nonLocalIndex == -1) {
1334 // Push an empty "circt.nonlocal" field to the last slot.
1335 nonLocalIndex = attributes.size();
1336 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1337 } else {
1338 // Copy the remaining annotation fields in.
1339 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1340 }
1341
1342 // Construct the NLAs if we don't have any yet.
1343 auto nlaRefs = createNLAs(toModuleName, fromModule);
1344 for (auto nla : nlaRefs)
1345 targetMap[nla.getAttr()].insert(to);
1346
1347 // Clone the annotation for each new NLA.
1348 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1349 return true;
1350 }
1351
1352 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1353 FModuleLike fromModule, AnnotationSet annos,
1354 SmallVectorImpl<Annotation> &newAnnotations,
1355 SmallPtrSetImpl<Attribute> &dontTouches) {
1356 for (auto anno : annos) {
1357 if (anno.isClass(dontTouchAnnoClass)) {
1358 // Remove the nonlocal field of the annotation if it has one, since this
1359 // is a sticky annotation.
1360 anno.removeMember("circt.nonlocal");
1361 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1362 if (inserted)
1363 newAnnotations.push_back(anno);
1364 continue;
1365 }
1366 // If the annotation is already non-local, we add it as is. It is already
1367 // added to the target map.
1368 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1369 newAnnotations.push_back(anno);
1370 targetMap[nla.getAttr()].insert(to);
1371 continue;
1372 }
1373 // Otherwise make the annotation non-local and add it to the set.
1374 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1375 newAnnotations);
1376 }
1377 }
1378
1379 /// Merge the annotations of a specific target, either a operation or a port
1380 /// on an operation.
1381 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1382 AnnotationSet toAnnos, FModuleLike fromModule,
1383 AnnoTarget from, AnnotationSet fromAnnos) {
1384 // This is a list of all the annotations which will be added to `to`.
1385 SmallVector<Annotation> newAnnotations;
1386
1387 // We have special case handling of DontTouch to prevent it from being
1388 // turned into a non-local annotation, and to remove duplicates.
1389 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1390
1391 // Iterate the annotations, transforming most annotations into non-local
1392 // ones.
1393 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1394 dontTouches);
1395 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1396 dontTouches);
1397
1398 // Copy over all the new annotations.
1399 if (!newAnnotations.empty())
1400 to.setAnnotations(AnnotationSet(newAnnotations, context));
1401 }
1402
1403 /// Merge all annotations and port annotations on two operations.
1404 void mergeAnnotations(FModuleLike toModule, Operation *to,
1405 FModuleLike fromModule, Operation *from) {
1406 // Merge op annotations.
1407 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1408 OpAnnoTarget(from), AnnotationSet(from));
1409
1410 // Merge port annotations.
1411 if (toModule == to) {
1412 // Merge module port annotations.
1413 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1414 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1415 AnnotationSet::forPort(toModule, i), fromModule,
1416 PortAnnoTarget(fromModule, i),
1417 AnnotationSet::forPort(fromModule, i));
1418 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1419 // Merge memory port annotations.
1420 auto fromMem = cast<MemOp>(from);
1421 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1422 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1423 AnnotationSet::forPort(toMem, i), fromModule,
1424 PortAnnoTarget(fromMem, i),
1425 AnnotationSet::forPort(fromMem, i));
1426 }
1427 }
1428
// Merge the inner symbol `fromSym` into `toSym` and return the combined
// symbol, recording in `renameMap` the new name of every symbol that came
// from `fromSym`. A null attribute is returned when `fromSym` is null or has
// no properties, in which case there is nothing to merge.
1429 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1430 hw::InnerSymAttr toSym,
1431 hw::InnerSymAttr fromSym) {
1432 if (fromSym && !fromSym.getProps().empty()) {
1433 auto &isn = getNamespace(toModule);
1434 // The properties for the new inner symbol..
1435 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1436 // If the "to" op already has an inner symbol, copy all its properties.
1437 if (toSym)
1438 llvm::append_range(newProps, toSym);
1439 // Add each property from the fromSym to the toSym.
1440 for (auto fromProp : fromSym) {
1441 hw::InnerSymPropertiesAttr newProp;
1442 auto *it = llvm::find_if(newProps, [&](auto p) {
1443 return p.getFieldID() == fromProp.getFieldID();
1444 });
1445 if (it != newProps.end()) {
1446 // If we already have an inner sym with the same field id, use
1447 // that.
1448 newProp = *it;
1449 // If the old symbol is public, we need to make the new one public.
// Only the entry stored in `newProps` is replaced; `newProp` keeps the
// earlier copy, whose name is the same.
1450 if (fromProp.getSymVisibility().getValue() == "public" &&
1451 newProp.getSymVisibility().getValue() != "public") {
1452 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1453 newProp.getFieldID(),
1454 fromProp.getSymVisibility());
1455 }
1456 } else {
1457 // We need to add a new property to the inner symbol for this field.
// The name is drawn from the "to" module's namespace.
1458 auto newName = isn.newName(fromProp.getName().getValue());
1459 newProp = hw::InnerSymPropertiesAttr::get(
1460 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1461 fromProp.getSymVisibility());
1462 newProps.push_back(newProp);
1463 }
1464 renameMap[fromProp.getName()] = newProp.getName();
1465 }
1466 // Sort the fields by field id.
1467 llvm::sort(newProps, [](auto &p, auto &q) {
1468 return p.getFieldID() < q.getFieldID();
1469 });
1470 // Return the merged inner symbol.
1471 return hw::InnerSymAttr::get(context, newProps);
1472 }
1473 return hw::InnerSymAttr();
1474 }
1475
1476 // Record the symbol name change of the operation or any of its ports when
1477 // merging two operations. The renamed symbols are used to update the
1478 // target of any NLAs. This will add symbols to the "to" operation if needed.
1479 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1480 Operation *to, FModuleLike fromModule,
1481 Operation *from) {
1482 // If the "from" operation has an inner_sym, we need to make sure the
1483 // "to" operation also has an `inner_sym` and then record the renaming.
1484 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1485 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1486 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1487 toInnerSym.getInnerSymAttr(),
1488 fromInnerSym.getInnerSymAttr()))
1489 toInnerSym.setInnerSymbolAttr(newSymAttr);
1490 }
1491
1492 // If there are no port symbols on the "from" operation, we are done here.
1493 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1494 if (!fromPortSyms || fromPortSyms.empty())
1495 return;
1496 // We have to map each "fromPort" to each "toPort".
1497 auto portCount = fromPortSyms.size();
1498 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1499
1500 // Create an array of new port symbols for the "to" operation, copy in the
1501 // old symbols if it has any, create an empty symbol array if it doesn't.
1502 SmallVector<Attribute> newPortSyms;
1503 if (toPortSyms.empty())
1504 newPortSyms.assign(portCount, hw::InnerSymAttr());
1505 else
1506 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1507
1508 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1509 if (auto newPortSym = mergeInnerSymbols(
1510 renameMap, toModule,
1511 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1512 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1513 newPortSyms[portNo] = newPortSym;
1514 }
1515 }
1516
1517 // Commit the new symbol attribute.
1518 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1519 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1520 }
1521
1522 /// Recursively merge two operations.
1523 // NOLINTNEXTLINE(misc-no-recursion)
1524 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1525 FModuleLike fromModule, Operation *from) {
1526 // Merge the operation locations.
1527 if (to->getLoc() != from->getLoc())
1528 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1529
1530 // Recurse into any regions.
1531 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1532 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1533 std::get<1>(regions));
1534
1535 // Record any inner_sym renamings that happened.
1536 recordSymRenames(renameMap, toModule, to, fromModule, from);
1537
1538 // Merge the annotations.
1539 mergeAnnotations(toModule, to, fromModule, from);
1540 }
1541
1542 /// Recursively merge two blocks.
1543 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1544 FModuleLike fromModule, Block &fromBlock) {
1545 // Merge the block locations.
1546 for (auto [toArg, fromArg] :
1547 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1548 if (toArg.getLoc() != fromArg.getLoc())
1549 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1550
1551 for (auto ops : llvm::zip(toBlock, fromBlock))
1552 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1553 &std::get<1>(ops));
1554 }
1555
1556 // Recursively merge two regions.
1557 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1558 Region &toRegion, FModuleLike fromModule,
1559 Region &fromRegion) {
1560 for (auto blocks : llvm::zip(toRegion, fromRegion))
1561 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1562 std::get<1>(blocks));
1563 }
1564
1565 MLIRContext *context;
1567 SymbolTable &symbolTable;
1568
1569 /// Cached nla table analysis.
1570 NLATable *nlaTable = nullptr;
1571
1572 /// We insert all NLAs to the beginning of this block.
1573 Block *nlaBlock;
1574
1575 // This maps an NLA to the operations and ports that use it.
1576 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1577
1578 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAttr
1579 // of the NLA's path to the name of the NLA which contains it.
1580 DenseMap<Attribute, Attribute> nlaCache;
1581
1582 // Cached attributes for faster comparisons and attribute building.
1583 StringAttr nonLocalString;
1584 StringAttr classString;
1585
1586 /// A module namespace cache.
1587 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1588};
1589
1590//===----------------------------------------------------------------------===//
1591// Fixup
1592//===----------------------------------------------------------------------===//
1593
1594/// This fixes up connects when the field names of a bundle type changes. It
1595/// finds all fields which were previously bulk connected and legalizes it
1596/// into a connect for each field.
1597static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1598 // If the types already match we can emit a connect.
1599 auto dstType = dst.getType();
1600 auto srcType = src.getType();
1601 if (dstType == srcType) {
1602 emitConnect(builder, dst, src);
1603 return;
1604 }
1605 // It must be a bundle type and the field name has changed. We have to
1606 // manually decompose the bulk connect into a connect for each field.
1607 auto dstBundle = type_cast<BundleType>(dstType);
1608 auto srcBundle = type_cast<BundleType>(srcType);
1609 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1610 auto dstField = SubfieldOp::create(builder, dst, i);
1611 auto srcField = SubfieldOp::create(builder, src, i);
1612 if (dstBundle.getElement(i).isFlip) {
1613 std::swap(srcBundle, dstBundle);
1614 std::swap(srcField, dstField);
1615 }
1616 fixupConnect(builder, dstField, srcField);
1617 }
1618}
1619
1620/// Adjust the symbol references in an op. This includes updating its attributes
1621/// and types. `dedupMap` is consulted for the new name of each symbol found
/// in a flat symbol reference; symbols without an entry, or whose entry is
/// the same name, are left untouched.
1622static void
1623fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph,
1624 const DenseMap<Attribute, StringAttr> &dedupMap) {
1625 // If this is an instance op, dedup may have subtly changed the port types.
1626 // For example, structurally different bundles may still dedup. In this case
1627 // we now have an instance op that produces result values of the old type, but
1628 // the port info on the instantiated module already represents the new type.
1629 // Fix this up by going through an intermediate wire.
1630 if (auto instOp = dyn_cast<InstanceOp>(op)) {
1631 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp->getContext());
1632 builder.setInsertionPointAfter(instOp);
1633 auto module = instanceGraph.lookup(instOp.getModuleNameAttr().getAttr())
1634 ->getModule<FModuleLike>();
1635 for (auto [index, result] : llvm::enumerate(instOp.getResults())) {
1636 auto newType = module.getPortType(index);
1637 auto oldType = result.getType();
1638 // If the type has not changed, we don't have to fix up anything.
1639 if (newType == oldType)
1640 continue;
1641 LLVM_DEBUG(llvm::dbgs()
1642 << "- Updating instance port \"" << instOp.getInstanceName()
1643 << "." << instOp.getPortName(index) << "\" from " << oldType
1644 << " to " << newType << "\n");
1645
1646 // If the type changed we transform it back to the old type with an
1647 // intermediate wire.
1648 auto wire = WireOp::create(builder, oldType, instOp.getPortName(index))
1649 .getResult();
// Order matters: existing users are moved to the wire first, then the
// instance result is retyped, and only then are the two connected.
1650 result.replaceAllUsesWith(wire);
1651 result.setType(newType);
// The connect direction follows the port direction.
1652 if (instOp.getPortDirection(index) == Direction::Out)
1653 fixupConnect(builder, wire, result);
1654 else
1655 fixupConnect(builder, result, wire);
1656 }
1657 }
1658
1659 // Use an attribute/type replacer to look for references to old symbols that
1660 // need to be replaced with new symbols.
1661 mlir::AttrTypeReplacer replacer;
1662 replacer.addReplacement([&](FlatSymbolRefAttr symRef) {
1663 auto oldName = symRef.getAttr();
1664 auto newName = dedupMap.lookup(oldName);
1665 if (newName && newName != oldName) {
1666 auto newSymRef = FlatSymbolRefAttr::get(newName);
1667 LLVM_DEBUG(llvm::dbgs()
1668 << "- Updating " << symRef << " to " << newSymRef << " in "
1669 << op->getName() << " at " << op->getLoc() << "\n");
1670 return newSymRef;
1671 }
1672 return symRef;
1673 });
1674
1675 // Update attributes.
1676 op->setAttrs(cast<DictionaryAttr>(replacer.replace(op->getAttrDictionary())));
1677
1678 // Update the argument types.
1679 for (auto &region : op->getRegions())
1680 for (auto &block : region)
1681 for (auto arg : block.getArguments())
1682 arg.setType(replacer.replace(arg.getType()));
1683
1684 // Update result types.
1685 for (auto result : op->getResults())
1686 result.setType(replacer.replace(result.getType()));
1687}
1688
1689/// Adjust the symbol references in ops marked as sensitive to them. This
1690/// includes updating their attributes and types.
1692 InstanceGraph &instanceGraph,
1693 const DenseMap<Operation *, ModuleInfoRef> &moduleToModuleInfo,
1694 const DenseMap<Attribute, StringAttr> &dedupMap) {
1695 for (auto *node : instanceGraph) {
1696 // Look up the module info for this module, which contains the list of ops
1697 // that need to be updated.
1698 auto module = node->getModule<FModuleLike>();
1699 auto it = moduleToModuleInfo.find(module);
1700 if (it == moduleToModuleInfo.end())
1701 continue;
1702
1703 // Update each symbol-sensitive op individually.
1704 auto &ops = it->second.info->symbolSensitiveOps;
1705 if (ops.empty())
1706 continue;
1707 LLVM_DEBUG(llvm::dbgs()
1708 << "- Updating " << ops.size() << " symbol-sensitive ops in "
1709 << module.getNameAttr() << "\n");
1710 for (auto *op : ops)
1711 fixupSymbolSensitiveOp(op, instanceGraph, dedupMap);
1712 }
1713}
1714
1715//===----------------------------------------------------------------------===//
1716// DedupPass
1717//===----------------------------------------------------------------------===//
1718
1719namespace {
1720class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1721 using DedupBase::DedupBase;
1722
1723 void runOnOperation() override {
1724 auto *context = &getContext();
1725 auto circuit = getOperation();
1726 auto &instanceGraph = getAnalysis<InstanceGraph>();
1727 auto *nlaTable = &getAnalysis<NLATable>();
1728 auto &symbolTable = getAnalysis<SymbolTable>();
1729 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1730 Equivalence equiv(context, instanceGraph);
1731 auto anythingChanged = false;
1732 LLVM_DEBUG({
1733 llvm::dbgs() << "\n";
1734 debugHeader(Twine("Dedup circuit \"") + circuit.getName() + "\"")
1735 << "\n\n";
1736 });
1737
1738 // Modules annotated with this should not be considered for deduplication.
1739 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1740
1741 // Only modules within the same group may be deduplicated.
1742 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1743
1744 // A map of all the module moduleInfo that we have calculated so far.
1745 DenseMap<ModuleInfoRef, Operation *> moduleInfoToModule;
1746 DenseMap<Operation *, ModuleInfoRef> moduleToModuleInfo;
1747
1748 // We track the name of the module that each module is deduped into, so that
1749 // we can make sure all modules which are marked "must dedup" with each
1750 // other were all deduped to the same module.
1751 DenseMap<Attribute, StringAttr> dedupMap;
1752
1753 // We must iterate the modules from the bottom up so that we can properly
1754 // deduplicate the modules. We copy the list of modules into a vector first
1755 // to avoid iterator invalidation while we mutate the instance graph.
1756 SmallVector<FModuleLike, 0> modules;
1757 instanceGraph.walkPostOrder([&](auto &node) {
1758 if (auto mod = dyn_cast<FModuleLike>(*node.getModule()))
1759 modules.push_back(mod);
1760 });
1761 LLVM_DEBUG(llvm::dbgs() << "Found " << modules.size() << " modules\n");
1762
1763 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1764 StructuralHasherSharedConstants hasherConstants(&getContext());
1765
1766 // Attribute name used to store dedup_group for this pass.
1767 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1768
1769 // Move dedup group annotations to attributes on the module.
1770 // This results in the desired behavior (included in hash),
1771 // and avoids unnecessary processing of these as annotations
1772 // that need to be tracked, made non-local, so on.
1773 for (auto module : modules) {
1774 llvm::SmallSetVector<StringAttr, 1> groups;
1776 module, [&groups, dedupGroupClass](Annotation annotation) {
1777 if (annotation.getClassAttr() != dedupGroupClass)
1778 return false;
1779 groups.insert(annotation.getMember<StringAttr>("group"));
1780 return true;
1781 });
1782 if (groups.size() > 1) {
1783 module.emitError("module belongs to multiple dedup groups: ") << groups;
1784 return signalPassFailure();
1785 }
1786 assert(!module->hasAttr(dedupGroupAttrName) &&
1787 "unexpected existing use of temporary dedup group attribute");
1788 if (!groups.empty())
1789 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1790 }
1791
1792 // Calculate module information parallelly.
1793 LLVM_DEBUG(llvm::dbgs() << "Computing module information\n");
1794 auto result = mlir::failableParallelForEach(
1795 context, llvm::seq(modules.size()), [&](unsigned idx) {
1796 auto module = modules[idx];
1797 // If the module is marked with NoDedup, just skip it.
1798 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1799 return success();
1800
1801 // Only dedup extmodule's with defname.
1802 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1803 ext && !ext.getDefname().has_value())
1804 return success();
1805
1806 // Only dedup classes if enabled.
1807 if (isa<ClassOp>(*module) && !dedupClasses)
1808 return success();
1809
1810 StructuralHasher hasher(hasherConstants);
1811 // Calculate the hash of the module and referred module names.
1812 moduleInfos[idx] = hasher.getModuleInfo(module);
1813 return success();
1814 });
1815
1816 // Dump out the module hashes for debugging.
1817 LLVM_DEBUG({
1818 auto &os = llvm::dbgs();
1819 for (auto [module, info] : llvm::zip(modules, moduleInfos)) {
1820 os << "- Hash ";
1821 if (info) {
1822 os << llvm::format_bytes(info->structuralHash, std::nullopt, 32, 32);
1823 } else {
1824 os << "--------------------------------";
1825 os << "--------------------------------";
1826 }
1827 os << " for " << module.getModuleNameAttr() << "\n";
1828 }
1829 });
1830
1831 if (result.failed())
1832 return signalPassFailure();
1833
1834 LLVM_DEBUG(llvm::dbgs() << "Update modules\n");
1835 for (auto [i, module] : llvm::enumerate(modules)) {
1836 auto moduleName = module.getModuleNameAttr();
1837 auto &maybeModuleInfo = moduleInfos[i];
1838 // If the hash was not calculated, we need to skip it.
1839 if (!maybeModuleInfo) {
1840 // We record it in the dedup map to help detect errors when the user
1841 // marks the module as both NoDedup and MustDedup. We do not record this
1842 // module in the hasher to make sure no other module dedups "into" this
1843 // one.
1844 dedupMap[moduleName] = moduleName;
1845 continue;
1846 }
1847
1848 auto &moduleInfo = maybeModuleInfo.value();
1849 moduleToModuleInfo.try_emplace(module, &moduleInfo);
1850
1851 // Replace module names referred in the module with new names.
1852 for (auto &referredModule : moduleInfo.referredModuleNames)
1853 referredModule = dedupMap[referredModule];
1854
1855 // Check if there is a module with the same hash.
1856 auto it = moduleInfoToModule.find(&moduleInfo);
1857 if (it != moduleInfoToModule.end()) {
1858 auto original = cast<FModuleLike>(it->second);
1859 auto originalName = original.getModuleNameAttr();
1860
1861 // If the current module is public, and the original is private, we
1862 // want to dedup the private module into the public one.
1863 if (!canRemoveModule(module)) {
1864 // Record that this module's name is staying the same.
1865 dedupMap[moduleName] = moduleName;
1866 // If both modules are public, then we can't dedup anything.
1867 if (!canRemoveModule(original))
1868 continue;
1869 // Swap the canonical module in the dedup map.
1870 for (auto &[_, dedupedName] : dedupMap)
1871 if (dedupedName == originalName)
1872 dedupedName = moduleName;
1873 // Update the module hash table to point to the new original, so all
1874 // future modules dedup with the new canonical module.
1875 it->second = module;
1876 // Swap the locals.
1877 std::swap(originalName, moduleName);
1878 std::swap(original, module);
1879 }
1880
1881 // Record the group ID of the other module.
1882 LLVM_DEBUG(llvm::dbgs() << "- Replace " << moduleName << " with "
1883 << originalName << "\n");
1884 dedupMap[moduleName] = originalName;
1885 deduper.dedup(original, module);
1886 ++erasedModules;
1887 anythingChanged = true;
1888 continue;
1889 }
1890 // Any module not deduplicated must be recorded.
1891 deduper.record(module);
1892 // Add the module to a new dedup group.
1893 dedupMap[moduleName] = moduleName;
1894 // Record the module info.
1895 moduleInfoToModule[&moduleInfo] = module;
1896 }
1897
1898 // This part verifies that all modules marked by "MustDedup" have been
1899 // properly deduped with each other. For this check to succeed, all modules
1900 // have to been deduped to the same module. It is possible that a module was
1901 // deduped with the wrong thing.
1902
1903 auto failed = false;
1904 // This parses the module name out of a target string.
1905 auto parseModule = [&](Attribute path) -> StringAttr {
1906 // Each module is listed as a target "~Circuit|Module" which we have to
1907 // parse.
1908 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1909 return StringAttr::get(context, rhs);
1910 };
1911 // This gets the name of the module which the current module was deduped
1912 // with. If the named module isn't in the map, then we didn't encounter it
1913 // in the circuit.
1914 auto getLead = [&](StringAttr module) -> StringAttr {
1915 auto it = dedupMap.find(module);
1916 if (it == dedupMap.end()) {
1917 auto diag = emitError(circuit.getLoc(),
1918 "MustDeduplicateAnnotation references module ")
1919 << module << " which does not exist";
1920 failed = true;
1921 return nullptr;
1922 }
1923 return it->second;
1924 };
1925
1926 LLVM_DEBUG(llvm::dbgs() << "Update annotations\n");
1927 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1928 if (!annotation.isClass(mustDeduplicateAnnoClass))
1929 return false;
1930 auto modules = annotation.getMember<ArrayAttr>("modules");
1931 if (!modules) {
1932 emitError(circuit.getLoc(),
1933 "MustDeduplicateAnnotation missing \"modules\" member");
1934 failed = true;
1935 return false;
1936 }
1937 // Empty module list has nothing to process.
1938 if (modules.empty())
1939 return true;
1940 // Get the first element.
1941 auto firstModule = parseModule(modules[0]);
1942 auto firstLead = getLead(firstModule);
1943 if (!firstLead)
1944 return false;
1945 // Verify that the remaining elements are all the same as the first.
1946 for (auto attr : modules.getValue().drop_front()) {
1947 auto nextModule = parseModule(attr);
1948 auto nextLead = getLead(nextModule);
1949 if (!nextLead)
1950 return false;
1951 if (firstLead != nextLead) {
1952 auto diag = emitError(circuit.getLoc(), "module ")
1953 << nextModule << " not deduplicated with " << firstModule;
1954 auto a = instanceGraph.lookup(firstLead)->getModule();
1955 auto b = instanceGraph.lookup(nextLead)->getModule();
1956 equiv.check(diag, a, b);
1957 failed = true;
1958 return false;
1959 }
1960 }
1961 return true;
1962 });
1963 if (failed)
1964 return signalPassFailure();
1965
1966 // Remove all dedup group attributes, they only exist during this pass.
1967 for (auto module : circuit.getOps<FModuleLike>())
1968 module->removeDiscardableAttr(dedupGroupAttrName);
1969
1970 // Fixup all operations that we've found to be sensitive to symbol names.
1971 // This includes module and class instances, wires, connects, etc.
1972 fixupSymbolSensitiveOps(instanceGraph, moduleToModuleInfo, dedupMap);
1973
1974 markAnalysesPreserved<NLATable>();
1975 if (!anythingChanged)
1976 markAllAnalysesPreserved();
1977 }
1978};
1979} // end anonymous namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:971
static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition Dedup.cpp:1597
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:56
static void fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in an op.
Definition Dedup.cpp:1623
static void fixupSymbolSensitiveOps(InstanceGraph &instanceGraph, const DenseMap< Operation *, ModuleInfoRef > &moduleToModuleInfo, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in ops marked as sensitive to them.
Definition Dedup.cpp:1691
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
decltype(auto) walkPostOrder(Fn &&fn)
Perform a post-order walk across the modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
Definition FIRRTLEnums.h:44
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(const llvm::Twine &str, unsigned width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition Debug.cpp:17
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1159
MLIRContext * context
Definition Dedup.cpp:1565
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1573
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1101
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1223
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1404
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1117
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:1075
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1302
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1557
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:1040
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1294
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1196
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:1094
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1570
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1429
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1204
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1479
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either a operation or a port on an operation.
Definition Dedup.cpp:1381
StringAttr nonLocalString
Definition Dedup.cpp:1583
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:1087
SymbolTable & symbolTable
Definition Dedup.cpp:1567
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1524
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1587
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1310
InstanceGraph & instanceGraph
Definition Dedup.cpp:1566
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1543
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1576
StringAttr classString
Definition Dedup.cpp:1584
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1352
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:1024
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1233
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:1022
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1580
const hw::InnerSymbolTable & a
Definition Dedup.cpp:488
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:485
const hw::InnerSymbolTable & b
Definition Dedup.cpp:489
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:467
LogicalResult check(InFlightDiagnostic &diag, igraph::InstanceOpInterface a, igraph::InstanceOpInterface b)
Definition Dedup.cpp:775
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:958
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:492
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:564
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:538
StringAttr noDedupClass
Definition Dedup.cpp:953
StringAttr dedupGroupAttrName
Definition Dedup.cpp:955
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:687
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:642
StringAttr portDirectionsAttr
Definition Dedup.cpp:951
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:508
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:813
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:468
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:662
InstanceGraph & instanceGraph
Definition Dedup.cpp:959
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:902
A reference to a ModuleInfo that compares and hashes like it.
Definition Dedup.cpp:422
ModuleInfo * info
Definition Dedup.cpp:424
ModuleInfoRef(ModuleInfo *info)
Definition Dedup.cpp:423
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:91
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:87
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:85
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:101
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:131
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:102
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:151
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:274
void update(Type type)
Definition Dedup.cpp:251
void update(const void *pointer)
Definition Dedup.cpp:208
void update(ClassType type)
Definition Dedup.cpp:236
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:394
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:219
void update(Operation *op)
Definition Dedup.cpp:360
llvm::SHA256 sha
Definition Dedup.cpp:408
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:398
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:138
void update(size_t value)
Definition Dedup.cpp:213
void update(BundleType type)
Definition Dedup.cpp:227
unsigned getID(void *object)
Definition Dedup.cpp:170
void update(OpResult result)
Definition Dedup.cpp:259
void update(OpOperand &operand)
Definition Dedup.cpp:191
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:135
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:416
void update(Region *region)
Definition Dedup.cpp:352
void update(Block *block)
Definition Dedup.cpp:341
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:401
void update(TypeID typeID)
Definition Dedup.cpp:224
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:404
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:187
unsigned finalizeID(void *object)
Definition Dedup.cpp:178
void update(mlir::OperationName name)
Definition Dedup.cpp:335
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs)
Definition Dedup.cpp:451
static ModuleInfoRef getTombstoneKey()
Definition Dedup.cpp:435
static ModuleInfoRef getEmptyKey()
Definition Dedup.cpp:431
static unsigned getHashValue(const ModuleInfoRef &ref)
Definition Dedup.cpp:439