CIRCT 22.0.0git
Loading...
Searching...
No Matches
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
23#include "circt/Support/Debug.h"
24#include "circt/Support/LLVM.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/Threading.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/DenseMapInfo.h"
30#include "llvm/ADT/Hashing.h"
31#include "llvm/ADT/PostOrderIterator.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/Format.h"
35#include "llvm/Support/SHA256.h"
36
37#define DEBUG_TYPE "firrtl-dedup"
38
39namespace circt {
40namespace firrtl {
41#define GEN_PASS_DEF_DEDUP
42#include "circt/Dialect/FIRRTL/Passes.h.inc"
43} // namespace firrtl
44} // namespace circt
45
46using namespace circt;
47using namespace firrtl;
48using hw::InnerRefAttr;
49
50//===----------------------------------------------------------------------===//
51// Utility function for classifying a Symbol's dedup-ability.
52//===----------------------------------------------------------------------===//
53
54/// Returns true if the module can be removed.
55static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
56 // If the symbol is not private, it cannot be removed.
57 if (!symbol.isPrivate())
58 return false;
59 // Classes may be referenced in object types, so can not normally be removed
60 // if we can't find any symbol uses. Since we know that dedup will update the
61 // types of instances appropriately, we can ignore that and return true here.
62 if (isa<ClassLike>(*symbol))
63 return true;
64 // If module can not be removed even if no uses can be found, we can not
65 // delete it. The implication is that there are hidden symbol uses that dedup
66 // will not properly update.
67 if (!symbol.canDiscardOnUseEmpty())
68 return false;
69 // The module can be deleted.
70 return true;
71}
72
73//===----------------------------------------------------------------------===//
74// Hashing
75//===----------------------------------------------------------------------===//
76
77// This struct contains information to determine module module uniqueness. A
78// first element is a structural hash of the module, and the second element is
79// an array which tracks module names encountered in the walk. Since module
80// names could be replaced during dedup, it's necessary to keep names up-to-date
81// before actually combining them into structural hashes.
82struct ModuleInfo {
83 // SHA256 hash.
84 std::array<uint8_t, 32> structuralHash;
85 // Module names referred by instance op in the module.
86 std::vector<StringAttr> referredModuleNames;
87 // The operations that contain references to symbols that may be changed by
88 // dedup. These need a fixup pass after dedup. This is just an optimization
89 // and does not factor into the hash or the equality between `ModuleInfo`s.
90 std::vector<Operation *> symbolSensitiveOps;
91};
92
93static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
94 return lhs.structuralHash == rhs.structuralHash &&
96}
97
98/// This struct contains constant string attributes shared across different
99/// threads.
102 portTypesAttr = StringAttr::get(context, "portTypes");
103 moduleNameAttr = StringAttr::get(context, "moduleName");
104 portNamesAttr = StringAttr::get(context, "portNames");
105 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
106 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
107 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
108 nonessentialAttributes.insert(StringAttr::get(context, "name"));
109 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
110 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
111 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
112 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
113 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
114 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
115 };
116
117 // This is a cached "portTypes" string attr.
118 StringAttr portTypesAttr;
119
120 // This is a cached "moduleName" string attr.
121 StringAttr moduleNameAttr;
122
123 // This is a cached "portNames" string attr.
124 StringAttr portNamesAttr;
125
126 // This is a set of every attribute we should ignore.
127 DenseSet<Attribute> nonessentialAttributes;
128};
129
132 : constants(constants) {}
133
/// Hash `module` and package the outcome as a `ModuleInfo`: the finalized
/// SHA256 digest, followed by the referred module names and the
/// symbol-sensitive operations collected while hashing. Both vectors are
/// moved out of this hasher.
134 ModuleInfo getModuleInfo(FModuleLike module) {
 // Feed the module operation itself through the recursive hasher.
136 update(&(*module));
137 return {sha.final(), std::move(referredModuleNames),
138 std::move(symbolSensitiveOps)};
139 }
140
141private:
142 /// Find all the ports and operations which may define an inner symbol
143 /// operations and give each a unique id. If the port/operation does define
144 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
145 /// field id. When we hash (local) references to this inner symbol, we will
146 /// hash in the id and the the field id.
147 void populateInnerSymIDTable(FModuleLike module) {
148 // Add port symbols. If no port has a symbol defined, the port symbol array
149 // will be totally empty.
150 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
151 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
152 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
153 }
154 // Add operation symbols.
155 size_t index = module.getNumPorts();
156 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
157 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
158 for (auto prop : innerSym)
159 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
160 }
161 ++index;
162 });
163 }
164
165 // Get the identifier for an object. The identifier is assigned on first use.
166 unsigned getID(void *object) {
167 auto [it, inserted] = idTable.try_emplace(object, nextID);
168 if (inserted)
169 ++nextID;
170 return it->second;
171 }
172
173 // Get the identifier for an IR object. Free the ID, too.
174 unsigned finalizeID(void *object) {
175 auto it = idTable.find(object);
176 if (it == idTable.end())
177 return nextID++;
178 auto id = it->second;
179 idTable.erase(it);
180 return id;
181 }
182
// Return the (id, field id) pair stored in `innerSymIDTable` for the inner
// symbol `name`.
// NOTE(review): presumably the symbol must already have been registered by
// populateInnerSymIDTable; behavior of `at` on a missing key is not visible
// here -- confirm.
183 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
184 return innerSymIDTable.at(name);
185 }
186
187 void update(OpOperand &operand) {
188 auto value = operand.get();
189 if (auto result = dyn_cast<OpResult>(value)) {
190 auto *op = result.getOwner();
191 update(getID(op));
192 update(result.getResultNumber());
193 return;
194 }
195 if (auto argument = dyn_cast<BlockArgument>(value)) {
196 auto *block = argument.getOwner();
197 update(getID(block));
198 update(argument.getArgNumber());
199 return;
200 }
201 llvm_unreachable("Unknown value type");
202 }
203
204 void update(const void *pointer) {
205 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
206 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
207 }
208
209 void update(size_t value) {
210 auto *addr = reinterpret_cast<const uint8_t *>(&value);
211 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
212 }
213
214 template <typename T, typename U>
215 void update(const std::pair<T, U> &pair) {
216 update(pair.first);
217 update(pair.second);
218 }
219
220 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
221
222 // NOLINTNEXTLINE(misc-no-recursion)
223 void update(BundleType type) {
224 update(type.getTypeID());
225 for (auto &element : type.getElements()) {
226 update(element.isFlip);
227 update(element.type);
228 }
229 }
230
231 // NOLINTNEXTLINE(misc-no-recursion)
232 void update(ClassType type) {
233 update(type.getTypeID());
234 // Don't hash the class name directly, since it may be replaced during
235 // dedup. Record the class name instead and lazily combine their hashes
236 // using the same mechanism as instances and modules.
237 hasSeenSymbol = true;
238 referredModuleNames.push_back(type.getNameAttr().getAttr());
239 for (auto &element : type.getElements()) {
240 update(element.name.getAsOpaquePointer());
241 update(element.type);
242 update(static_cast<unsigned>(element.direction));
243 }
244 }
245
246 // NOLINTNEXTLINE(misc-no-recursion)
247 void update(Type type) {
248 if (auto bundle = type_dyn_cast<BundleType>(type))
249 return update(bundle);
250 if (auto klass = type_dyn_cast<ClassType>(type))
251 return update(klass);
252 update(type.getAsOpaquePointer());
253 }
254
255 void update(OpResult result) {
256 // Like instance ops, don't use object ops' result types since they might be
257 // replaced by dedup. Record the class names and lazily combine their hashes
258 // using the same mechanism as instances and modules.
259 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
260 hasSeenSymbol = true;
261 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
262 return;
263 }
264
265 update(result.getType());
266 }
267
268 /// Hash the top level attribute dictionary of the operation. This function
269 /// has special handling for inner symbols, ports, and referenced modules.
270 void update(Operation *op, DictionaryAttr dict) {
271 for (auto namedAttr : dict) {
272 auto name = namedAttr.getName();
273 auto value = namedAttr.getValue();
274
275 // Check whether this attribute contains a nested symbol, just to make
276 // sure we revisit this op after dedup and update any symbols.
277 value.walk([&](FlatSymbolRefAttr) { hasSeenSymbol = true; });
278
279 // Skip names and annotations, except in certain cases.
280 // Names of ports are load bearing for classes, so we do hash those.
281 bool isClassPortNames =
282 isa<ClassLike>(op) && name == constants.portNamesAttr;
283 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
284 continue;
285
286 // Hash the attribute name (an interned pointer).
287 update(name.getAsOpaquePointer());
288
289 // Hash the port types.
290 if (name == constants.portTypesAttr) {
291 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
292 for (auto type : portTypes)
293 update(type);
294 continue;
295 }
296
297 // For instance op, don't use `moduleName` attributes since they might be
298 // replaced by dedup. Record the names and lazily combine their hashes.
299 // It is assumed that module names are hashed only through instance ops;
300 // it could cause suboptimal results if there was other operation that
301 // refers to module names through essential attributes.
302 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
303 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
304 continue;
305 }
306
307 // TODO: properly handle DistinctAttr, including its use in paths.
308 // See https://github.com/llvm/circt/issues/6583.
309 if (isa<DistinctAttr>(value))
310 continue;
311
312 // If this is an symbol reference, we need to perform name erasure.
313 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
314 update(getInnerSymID(innerRef.getName()));
315 continue;
316 }
317
318 // We don't need to handle this attribute specially, so hash its unique
319 // address.
320 update(value.getAsOpaquePointer());
321 }
322 }
323
// Hash an operation name through its opaque pointer.
324 void update(mlir::OperationName name) {
325 // Operation names are interned, so the pointer identifies the name.
326 update(name.getAsOpaquePointer());
327 }
328
329 // NOLINTNEXTLINE(misc-no-recursion)
330 void update(Block *block) {
331 for (auto &op : llvm::reverse(*block))
332 update(&op);
333 for (auto type : block->getArgumentTypes())
334 update(type);
335 update(finalizeID(block));
336 update(position);
337 ++position;
338 }
339
340 // NOLINTNEXTLINE(misc-no-recursion)
341 void update(Region *region) {
342 for (auto &block : llvm::reverse(region->getBlocks()))
343 update(&block);
344 update(position);
345 ++position;
346 }
347
348 // NOLINTNEXTLINE(misc-no-recursion)
349 void update(Operation *op) {
350 // Hash the regions. We need to make sure an empty region doesn't hash the
351 // same as no region, so we include the number of regions.
352 update(op->getNumRegions());
353 for (auto &region : reverse(op->getRegions()))
354 update(&region);
355
356 update(op->getName());
357
358 // Record the uses for later hashing.
359 for (auto &operand : op->getOpOperands())
360 update(operand);
361
362 // This happens after the numbering above, as it uses blockarg numbering
363 // for inner symbols.
364 hasSeenSymbol = false;
365 update(op, op->getAttrDictionary());
366
367 // Record any op results (types).
368 for (auto result : op->getResults())
369 update(result);
370
371 // If any of the operands, attributes, or results depended on symbols,
372 // remember this op for later adjustment.
373 if (hasSeenSymbol)
374 symbolSensitiveOps.push_back(op);
375
376 // Incorporate the hash of uses we have already built.
377 update(finalizeID(op));
378 update(position);
379 ++position;
380 }
381
382 // A map from an operation/block, to its identifier.
383 DenseMap<void *, unsigned> idTable;
384 unsigned nextID = 0;
385
386 // A map from an inner symbol, to its identifier.
387 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
388
389 // This keeps track of module names in the order of the appearance.
390 std::vector<StringAttr> referredModuleNames;
391
392 // String constants.
394
395 // This is the actual running hash calculation. This is a stateful element
396 // that should be reinitialized after each hash is produced.
397 llvm::SHA256 sha;
398
399 // The index of the current op. Increment after handling each op.
400 size_t position = 0;
401
402 // The operations that contain references to symbols that may be changed by
403 // dedup. These need a fixup pass after dedup.
404 bool hasSeenSymbol = false;
405 std::vector<Operation *> symbolSensitiveOps;
406};
407
408/// A reference to a `ModuleInfo` that compares and hashes like it. This is used
409/// to keep the potentially heavy module infos in a vector, and then populate
410/// maps with references to them.
415
416/// Allow `ModuleInfoRef` to be used as dense map keys. Hashes and compares the
417/// `ModuleInfo` the ref points to.
418template <>
423
427
428 static unsigned getHashValue(const ModuleInfoRef &ref) {
429 // We assume SHA256 is already a good hash and just truncate down to the
430 // number of bytes we need for DenseMap.
431 unsigned hash;
432 std::memcpy(&hash, ref.info->structuralHash.data(), sizeof(unsigned));
433
434 // Combine module names.
435 return llvm::hash_combine(
436 hash, llvm::hash_combine_range(ref.info->referredModuleNames.begin(),
437 ref.info->referredModuleNames.end()));
438 }
439
440 static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs) {
441 auto *empty = getEmptyKey().info;
442 auto *tombstone = getTombstoneKey().info;
443 if (lhs.info == empty || rhs.info == empty || lhs.info == tombstone ||
444 rhs.info == tombstone)
445 return lhs.info == rhs.info;
446 return *lhs.info == *rhs.info;
447 }
448};
449
450//===----------------------------------------------------------------------===//
451// Equivalence
452//===----------------------------------------------------------------------===//
453
454/// This class is for reporting differences between two modules which should
455/// have been deduplicated.
459 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
460 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
461 portDirectionsAttr = StringAttr::get(context, "portDirections");
462 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
463 nonessentialAttributes.insert(StringAttr::get(context, "name"));
464 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
465 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
466 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
467 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
468 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
469 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
470 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
471 }
472
480
481 std::string prettyPrint(Attribute attr) {
482 SmallString<64> buffer;
483 llvm::raw_svector_ostream os(buffer);
484 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
485 os << "0x";
486 if (integerAttr.getType().isSignlessInteger())
487 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
488 else
489 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
490
491 } else
492 os << attr;
493 return std::string(buffer);
494 }
495
496 // NOLINTNEXTLINE(misc-no-recursion)
497 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
498 Operation *a, BundleType aType, Operation *b,
499 BundleType bType) {
500 if (aType.getNumElements() != bType.getNumElements()) {
501 diag.attachNote(a->getLoc())
502 << message << " bundle type has different number of elements";
503 diag.attachNote(b->getLoc()) << "second operation here";
504 return failure();
505 }
506
507 for (auto elementPair :
508 llvm::zip(aType.getElements(), bType.getElements())) {
509 auto aElement = std::get<0>(elementPair);
510 auto bElement = std::get<1>(elementPair);
511 if (aElement.isFlip != bElement.isFlip) {
512 diag.attachNote(a->getLoc()) << message << " bundle element "
513 << aElement.name << " flip does not match";
514 diag.attachNote(b->getLoc()) << "second operation here";
515 return failure();
516 }
517
518 if (failed(check(diag,
519 message + " -> bundle element \'" +
520 aElement.name.getValue() + "'",
521 a, aElement.type, b, bElement.type)))
522 return failure();
523 }
524 return success();
525 }
526
527 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
528 Operation *a, Type aType, Operation *b, Type bType) {
529 if (aType == bType)
530 return success();
531 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
532 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
533 return check(diag, message, a, aBundleType, b, bBundleType);
534 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
535 aType != bType) {
536 diag.attachNote(a->getLoc())
537 << message << ", has a RefType with a different base type "
538 << type_cast<RefType>(aType).getType()
539 << " in the same position of the two modules marked as 'must dedup'. "
540 "(This may be due to Grand Central Taps or Views being different "
541 "between the two modules.)";
542 diag.attachNote(b->getLoc())
543 << "the second module has a different base type "
544 << type_cast<RefType>(bType).getType();
545 return failure();
546 }
547 diag.attachNote(a->getLoc())
548 << message << " types don't match, first type is " << aType;
549 diag.attachNote(b->getLoc()) << "second type is " << bType;
550 return failure();
551 }
552
553 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
554 Block &aBlock, Operation *b, Block &bBlock) {
555
556 // Block argument types.
557 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
558 auto portNo = 0;
559 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
560 Operation *opDoesNotExist) {
561 StringRef portName;
562 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
563 if (portNames)
564 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
565 portName = portNameAttr.getValue();
566 if (type_isa<RefType>(existsVal.getType())) {
567 diag.attachNote(opExists->getLoc())
568 << " contains a RefType port named '" + portName +
569 "' that only exists in one of the modules (can be due to "
570 "difference in Grand Central Tap or View of two modules "
571 "marked with must dedup)";
572 diag.attachNote(opDoesNotExist->getLoc())
573 << "second module to be deduped that does not have the RefType "
574 "port";
575 } else {
576 diag.attachNote(opExists->getLoc())
577 << "port '" + portName + "' only exists in one of the modules";
578 diag.attachNote(opDoesNotExist->getLoc())
579 << "second module to be deduped that does not have the port";
580 }
581 return failure();
582 };
583
584 for (auto argPair :
585 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
586 auto &aArg = std::get<0>(argPair);
587 auto &bArg = std::get<1>(argPair);
588 if (aArg.has_value() && bArg.has_value()) {
589 // TODO: we should print the port number if there are no port names, but
590 // there are always port names ;).
591 StringRef portName;
592 if (portNames) {
593 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
594 portName = portNameAttr.getValue();
595 }
596 // Assumption here that block arguments correspond to ports.
597 if (failed(check(diag, "module port '" + portName + "'", a,
598 aArg->getType(), b, bArg->getType())))
599 return failure();
600 data.map.map(aArg.value(), bArg.value());
601 portNo++;
602 continue;
603 }
604 if (!aArg.has_value())
605 std::swap(a, b);
606 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
607 b);
608 }
609
610 // Blocks operations.
611 auto aIt = aBlock.begin();
612 auto aEnd = aBlock.end();
613 auto bIt = bBlock.begin();
614 auto bEnd = bBlock.end();
615 while (aIt != aEnd && bIt != bEnd)
616 if (failed(check(diag, data, &*aIt++, &*bIt++)))
617 return failure();
618 if (aIt != aEnd) {
619 diag.attachNote(aIt->getLoc()) << "first block has more operations";
620 diag.attachNote(b->getLoc()) << "second block here";
621 return failure();
622 }
623 if (bIt != bEnd) {
624 diag.attachNote(bIt->getLoc()) << "second block has more operations";
625 diag.attachNote(a->getLoc()) << "first block here";
626 return failure();
627 }
628 return success();
629 }
630
631 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
632 Region &aRegion, Operation *b, Region &bRegion) {
633 auto aIt = aRegion.begin();
634 auto aEnd = aRegion.end();
635 auto bIt = bRegion.begin();
636 auto bEnd = bRegion.end();
637
638 // Region blocks.
639 while (aIt != aEnd && bIt != bEnd)
640 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
641 return failure();
642 if (aIt != aEnd || bIt != bEnd) {
643 diag.attachNote(a->getLoc())
644 << "operation regions have different number of blocks";
645 diag.attachNote(b->getLoc()) << "second operation here";
646 return failure();
647 }
648 return success();
649 }
650
651 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
652 mlir::DenseBoolArrayAttr aAttr, Operation *b,
653 mlir::DenseBoolArrayAttr bAttr) {
654 if (aAttr == bAttr)
655 return success();
656 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
657 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
658 auto aDirection = aAttr[i];
659 auto bDirection = bAttr[i];
660 if (aDirection != bDirection) {
661 auto &note = diag.attachNote(a->getLoc()) << "module port ";
662 if (portNames)
663 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
664 else
665 note << i;
666 note << " directions don't match, first direction is '"
667 << direction::toString(aDirection) << "'";
668 diag.attachNote(b->getLoc()) << "second direction is '"
669 << direction::toString(bDirection) << "'";
670 return failure();
671 }
672 }
673 return success();
674 }
675
676 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
677 DictionaryAttr aDict, Operation *b,
678 DictionaryAttr bDict) {
679 // Fast path.
680 if (aDict == bDict)
681 return success();
682
683 DenseSet<Attribute> seenAttrs;
684 for (auto namedAttr : aDict) {
685 auto attrName = namedAttr.getName();
686 if (nonessentialAttributes.contains(attrName))
687 continue;
688
689 auto aAttr = namedAttr.getValue();
690 auto bAttr = bDict.get(attrName);
691 if (!bAttr) {
692 diag.attachNote(a->getLoc())
693 << "second operation is missing attribute " << attrName;
694 diag.attachNote(b->getLoc()) << "second operation here";
695 return diag;
696 }
697
698 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
699 auto bRef = cast<hw::InnerRefAttr>(bAttr);
700 auto aRef = cast<hw::InnerRefAttr>(aAttr);
701 // See if they are pointing at the same operation or port.
702 auto aTarget = data.a.lookup(aRef.getName());
703 auto bTarget = data.b.lookup(bRef.getName());
704 if (!aTarget || !bTarget)
705 diag.attachNote(a->getLoc())
706 << "malformed ir, possibly violating use-before-def";
707 auto error = [&]() {
708 diag.attachNote(a->getLoc())
709 << "operations have different targets, first operation has "
710 << aTarget;
711 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
712 return failure();
713 };
714 if (aTarget.isPort()) {
715 // If they are targeting ports, make sure its the same port number.
716 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
717 return error();
718 } else {
719 // Otherwise make sure that they are targeting the same operation.
720 if (!bTarget.isOpOnly() ||
721 data.map.lookupOrNull(aTarget.getOp()) != bTarget.getOp())
722 return error();
723 }
724 if (aTarget.getField() != bTarget.getField())
725 return error();
726 } else if (attrName == portDirectionsAttr) {
727 // Special handling for the port directions attribute for better
728 // error messages.
729 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
730 cast<mlir::DenseBoolArrayAttr>(bAttr))))
731 return failure();
732 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
733 // TODO: properly handle DistinctAttr, including its use in paths.
734 // See https://github.com/llvm/circt/issues/6583
735 } else if (aAttr != bAttr) {
736 diag.attachNote(a->getLoc())
737 << "first operation has attribute '" << attrName.getValue()
738 << "' with value " << prettyPrint(aAttr);
739 diag.attachNote(b->getLoc())
740 << "second operation has value " << prettyPrint(bAttr);
741 return failure();
742 }
743 seenAttrs.insert(attrName);
744 }
745 if (aDict.getValue().size() != bDict.getValue().size()) {
746 for (auto namedAttr : bDict) {
747 auto attrName = namedAttr.getName();
748 // Skip the attribute if we don't care about this particular one or it
749 // is one that is known to be in both dictionaries.
750 if (nonessentialAttributes.contains(attrName) ||
751 seenAttrs.contains(attrName))
752 continue;
753 // We have found an attribute that is only in the second operation.
754 diag.attachNote(a->getLoc())
755 << "first operation is missing attribute " << attrName;
756 diag.attachNote(b->getLoc()) << "second operation here";
757 return failure();
758 }
759 }
760 return success();
761 }
762
763 // NOLINTNEXTLINE(misc-no-recursion)
764 LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
765 FInstanceLike b) {
766 auto aName = a.getReferencedModuleNameAttr();
767 auto bName = b.getReferencedModuleNameAttr();
768 if (aName == bName)
769 return success();
770
771 // If the modules instantiate are different we will want to know why the
772 // sub module did not dedupliate. This code recursively checks the child
773 // module.
774 auto aModule = instanceGraph.lookup(aName)->getModule();
775 auto bModule = instanceGraph.lookup(bName)->getModule();
776 // Create a new error for the submodule.
777 diag.attachNote(std::nullopt)
778 << "in instance " << a.getInstanceNameAttr() << " of " << aName
779 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
780 check(diag, aModule, bModule);
781 return failure();
782 }
783
784 // NOLINTNEXTLINE(misc-no-recursion)
785 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
786 Operation *b) {
787 // Operation name.
788 if (a->getName() != b->getName()) {
789 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
790 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
791 return failure();
792 }
793
794 // If its an instance operaiton, perform some checking and possibly
795 // recurse.
796 if (auto aInst = dyn_cast<FInstanceLike>(a)) {
797 auto bInst = cast<FInstanceLike>(b);
798 if (failed(check(diag, aInst, bInst)))
799 return failure();
800 }
801
802 // Operation results.
803 if (a->getNumResults() != b->getNumResults()) {
804 diag.attachNote(a->getLoc())
805 << "operations have different number of results";
806 diag.attachNote(b->getLoc()) << "second operation here";
807 return failure();
808 }
809 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
810 auto &aValue = std::get<0>(resultPair);
811 auto &bValue = std::get<1>(resultPair);
812 if (failed(check(diag, "operation result", a, aValue.getType(), b,
813 bValue.getType())))
814 return failure();
815 data.map.map(aValue, bValue);
816 }
817
818 // Operations operands.
819 if (a->getNumOperands() != b->getNumOperands()) {
820 diag.attachNote(a->getLoc())
821 << "operations have different number of operands";
822 diag.attachNote(b->getLoc()) << "second operation here";
823 return failure();
824 }
825 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
826 auto &aValue = std::get<0>(operandPair);
827 auto &bValue = std::get<1>(operandPair);
828 if (bValue != data.map.lookup(aValue)) {
829 diag.attachNote(a->getLoc())
830 << "operations use different operands, first operand is '"
831 << getFieldName(
832 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
833 .first
834 << "'";
835 diag.attachNote(b->getLoc())
836 << "second operand is '"
837 << getFieldName(
838 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
839 .first
840 << "', but should have been '"
841 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
842 /*lookThroughCasts=*/true))
843 .first
844 << "'";
845 return failure();
846 }
847 }
848 data.map.map(a, b);
849
850 // Operation regions.
851 if (a->getNumRegions() != b->getNumRegions()) {
852 diag.attachNote(a->getLoc())
853 << "operations have different number of regions";
854 diag.attachNote(b->getLoc()) << "second operation here";
855 return failure();
856 }
857 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
858 auto &aRegion = std::get<0>(regionPair);
859 auto &bRegion = std::get<1>(regionPair);
860 if (failed(check(diag, data, a, aRegion, b, bRegion)))
861 return failure();
862 }
863
864 // Operation attributes.
865 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
866 b->getAttrDictionary())))
867 return failure();
868 return success();
869 }
870
// Explain why the two module-level operations `a` and `b` did not
// deduplicate by attaching notes to `diag`. Module-level obstacles are
// reported first ("module marked NoDedup", neither module being removable,
// differing dedup groups); only then are the two operations compared
// structurally. If the structural comparison finds no difference either, the
// two modules are simply pointed out.
871 // NOLINTNEXTLINE(misc-no-recursion)
872 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
873 hw::InnerSymbolTable aTable(a);
874 hw::InnerSymbolTable bTable(b);
875 ModuleData data(aTable, bTable);
 // NOTE(review): the condition guarding this early return is not visible
 // here; presumably it tests `a` for the NoDedup annotation -- confirm.
877 diag.attachNote(a->getLoc()) << "module marked NoDedup";
878 return;
879 }
 // NOTE(review): as above, presumably the same test applied to `b` -- confirm.
881 diag.attachNote(b->getLoc()) << "module marked NoDedup";
882 return;
883 }
884 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
885 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
 // Reported only when neither of the two modules can be removed.
886 if (!canRemoveModule(aSymbol) && !canRemoveModule(bSymbol)) {
887 diag.attachNote(a->getLoc())
888 << "module is "
889 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
890 diag.attachNote(b->getLoc())
891 << "module is "
892 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
893 return;
894 }
 // The dedup group attribute of each module; null when the attribute is
 // absent or is not a string.
895 auto aGroup =
896 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
897 auto bGroup = dyn_cast_or_null<StringAttr>(
898 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
899 if (aGroup != bGroup) {
900 if (bGroup) {
901 diag.attachNote(b->getLoc())
902 << "module is in dedup group '" << bGroup.str() << "'";
903 } else {
904 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
905 }
906 if (aGroup) {
907 diag.attachNote(a->getLoc())
908 << "module is in dedup group '" << aGroup.str() << "'";
909 } else {
910 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
911 }
912 return;
913 }
 // A failed structural check has already attached its own notes.
914 if (failed(check(diag, data, a, b)))
915 return;
 // No difference was found; fall back to pointing at both modules.
916 diag.attachNote(a->getLoc()) << "first module here";
917 diag.attachNote(b->getLoc()) << "second module here";
918 }
919
920 // This is a cached "portDirections" string attr.
922 // This is a cached "NoDedup" annotation class string attr.
923 StringAttr noDedupClass;
924 // This is a cached string attr for the dedup group attribute.
926
927 // This is a set of every attribute we should ignore.
928 DenseSet<Attribute> nonessentialAttributes;
930};
931
932//===----------------------------------------------------------------------===//
933// Deduplication
934//===----------------------------------------------------------------------===//
935
936// Custom location merging. This only keeps track of 8 annotations from ".fir"
937// files, and however many annotations come from "real" sources. When
938// deduplicating, modules tend not to have scala source locators, so we wind
939// up fusing source locators for a module from every copy being deduped. There
940// is little value in this (all the modules are identical by definition).
941static Location mergeLoc(MLIRContext *context, Location to, Location from) {
942 // Unique the set of locations to be fused.
943 llvm::SmallSetVector<Location, 4> decomposedLocs;
944 // only track 8 "fir" locations
945 unsigned seenFIR = 0;
946 for (auto loc : {to, from}) {
947 // If the location is a fused location we decompose it if it has no
948 // metadata or the metadata is the same as the top level metadata.
949 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
950 // UnknownLoc's have already been removed from FusedLocs so we can
951 // simply add all of the internal locations.
952 for (auto loc : fusedLoc.getLocations()) {
953 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
954 if (fileLoc.getFilename().strref().ends_with(".fir")) {
955 ++seenFIR;
956 if (seenFIR > 8)
957 continue;
958 }
959 }
960 decomposedLocs.insert(loc);
961 }
962 continue;
963 }
964
965 // Might need to skip this fir.
966 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
967 if (fileLoc.getFilename().strref().ends_with(".fir")) {
968 ++seenFIR;
969 if (seenFIR > 8)
970 continue;
971 }
972 }
973 // Otherwise, only add known locations to the set.
974 if (!isa<UnknownLoc>(loc))
975 decomposedLocs.insert(loc);
976 }
977
978 auto locs = decomposedLocs.getArrayRef();
979
980 // Handle the simple cases of less than two locations. Ensure the metadata (if
981 // provided) is not dropped.
982 if (locs.empty())
983 return UnknownLoc::get(context);
984 if (locs.size() == 1)
985 return locs.front();
986
987 return FusedLoc::get(context, locs);
988}
989
990struct Deduper {
991
992 using RenameMap = DenseMap<StringAttr, StringAttr>;
993
995 NLATable *nlaTable, CircuitOp circuit)
996 : context(circuit->getContext()), instanceGraph(instanceGraph),
998 nlaBlock(circuit.getBodyBlock()),
999 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
1000 classString(StringAttr::get(context, "class")) {
 // NOTE(review): listing lines 994 and 997 (start of the signature and part
 // of the initializer list) are absent from this view -- confirm against the
 // original file.
1001 // Populate the NLA cache: map the namepath of every hierarchical path op
 // already in the circuit to its symbol name, so createNLAs can reuse an
 // existing op instead of creating a duplicate.
1002 for (auto nla : circuit.getOps<hw::HierPathOp>())
1003 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
1004 }
1005
1006 /// Remove the "fromModule", and replace all references to it with the
1007 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
1008 /// module which is not deduplicated needs to be recorded with the `record`
1009 /// call.
1010 void dedup(FModuleLike toModule, FModuleLike fromModule) {
1011 // A map of operation (e.g. wires, nodes) names which are changed, which is
1012 // used to update NLAs that reference the "fromModule".
1013 RenameMap renameMap;
1014
1015 // Merge the port locations.
1016 SmallVector<Attribute> newLocs;
1017 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
1018 fromModule.getPortLocations())) {
1019 if (toLoc == fromLoc)
1020 newLocs.push_back(toLoc);
1021 else
1022 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
1023 cast<LocationAttr>(fromLoc)));
1024 }
1025 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
1026
1027 // Merge the two modules.
1028 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
1029
1030 // Rewrite NLAs pathing through these modules to refer to the to module. It
1031 // is safe to do this at this point because NLAs cannot be one element long.
1032 // This means that all NLAs which require more context cannot be targetting
1033 // something in the module it self.
1034 if (auto to = dyn_cast<FModuleOp>(*toModule))
1035 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
1036 else
1037 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
1038 fromModule.getModuleNameAttr());
1039
1040 replaceInstances(toModule, fromModule);
1041 }
1042
1043 /// Record the usages of any NLA's in this module, so that we may update the
1044 /// annotation if the parent module is deduped with another module.
 /// Covers the module op itself, its ports, and every operation visited by
 /// walking the module.
1045 void record(FModuleLike module) {
1046 // Record any annotations on the module.
1047 recordAnnotations(module);
1048 // Record port annotations.
1049 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
 // NOTE(review): listing line 1050 (the body of the loop above) is absent
 // from this view; presumably it records the annotations of port `i` --
 // confirm against the original file.
1051 // Record any annotations in the module body.
1052 module->walk([&](Operation *op) { recordAnnotations(op); });
1053 }
1054
1055private:
1056 /// Get a cached namespace for a module.
 /// The namespace is constructed from the module on first request and stored
 /// in `moduleNamespaces`; later requests return the stored entry.
 // NOTE(review): listing line 1057 (the signature) is absent from this view
 // -- confirm against the original file.
1058 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
1059 .first->second;
1060 }
1061
1062 /// For a specific annotation target, record all the unique NLAs which
1063 /// target it in the `targetMap`.
 /// Annotations without a "circt.nonlocal" symbol reference are ignored.
 // NOTE(review): listing line 1064 (the signature) is absent from this view
 // -- confirm against the original file.
1065 for (auto anno : target.getAnnotations())
1066 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1067 targetMap[nlaRef.getAttr()].insert(target);
1068 }
1069
1070 /// Record all targets which use an NLA.
1071 void recordAnnotations(Operation *op) {
1072 // Record annotations.
1073 recordAnnotations(OpAnnoTarget(op));
1074
1075 // Record port annotations only if this is a mem operation.
1076 auto mem = dyn_cast<MemOp>(op);
1077 if (!mem)
1078 return;
1079
1080 // Record port annotations.
1081 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1082 recordAnnotations(PortAnnoTarget(mem, i));
1083 }
1084
1085 /// This deletes and replaces all instances of the "fromModule" with instances
1086 /// of the "toModule".
1087 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1088 // Replace all instances of the other module.
1089 auto *fromNode =
1090 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1091 auto *toNode = instanceGraph[toModule];
1092 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1093 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1094 auto inst = oldInstRec->getInstance();
1095 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1096 instOp.setModuleNameAttr(toModuleRef);
1097 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1098 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1099 auto classLike = cast<ClassLike>(*toNode->getModule());
1100 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1101 objectOp.getResult().setType(classType);
1102 }
1103 oldInstRec->getParent()->addInstance(inst, toNode);
1104 oldInstRec->erase();
1105 }
1106 instanceGraph.erase(fromNode);
1107 fromModule->erase();
1108 }
1109
1110 /// Look up the instantiations of the `from` module and create an NLA for each
1111 /// one, appending the baseNamepath to each NLA. This is used to add more
1112 /// context to an already existing NLA. The `fromModule` is used to indicate
1113 /// which module the annotation is coming from before the merge, and will be
1114 /// used to create the namepaths.
1115 SmallVector<FlatSymbolRefAttr>
1116 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1117 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1118 // Create an attribute array with a placeholder in the first element, where
1119 // the root refence of the NLA will be inserted.
1120 SmallVector<Attribute> namepath = {nullptr};
1121 namepath.append(baseNamepath.begin(), baseNamepath.end());
1122
1123 auto loc = fromModule->getLoc();
1124 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1125 SmallVector<FlatSymbolRefAttr> nlas;
1126 for (auto *instanceRecord : fromNode->uses()) {
1127 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1128 auto inst = instanceRecord->getInstance();
1129 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1130 auto arrayAttr = ArrayAttr::get(context, namepath);
1131 // Check the NLA cache to see if we already have this NLA.
1132 auto &cacheEntry = nlaCache[arrayAttr];
1133 if (!cacheEntry) {
1134 auto builder = OpBuilder::atBlockBegin(nlaBlock);
1135 auto nla = hw::HierPathOp::create(builder, loc, "nla", arrayAttr);
1136 // Insert it into the symbol table to get a unique name.
1137 symbolTable.insert(nla);
1138 // Store it in the cache.
1139 cacheEntry = nla.getNameAttr();
1140 nla.setVisibility(vis);
1141 nlaTable->addNLA(nla);
1142 }
1143 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1144 nlas.push_back(nlaRef);
1145 }
1146 return nlas;
1147 }
1148
1149 /// Look up the instantiations of this module and create an NLA for each one.
1150 /// This returns an array of symbol references which can be used to reference
1151 /// the NLAs.
1152 SmallVector<FlatSymbolRefAttr>
1153 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1154 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1155 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1156 }
1157
1158 /// Clone the annotation for each NLA in a list. The attribute list should
1159 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1160 /// should be the index of this field.
1161 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1162 Annotation anno, ArrayRef<NamedAttribute> attributes,
1163 unsigned nonLocalIndex,
1164 SmallVectorImpl<Annotation> &newAnnotations) {
1165 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1166 attributes.end());
1167 for (auto &nla : nlas) {
1168 // Add the new annotation.
1169 mutableAttributes[nonLocalIndex].setValue(nla);
1170 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1171 // The original annotation records if its a subannotation.
1172 anno.setDict(dict);
1173 newAnnotations.push_back(anno);
1174 }
1175 }
1176
1177 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1178 /// but it does not delete the NLA reference from the target operation's
1179 /// annotations.
1180 void eraseNLA(hw::HierPathOp nla) {
1181 // Erase the NLA from the leaf module's nlaMap.
1182 targetMap.erase(nla.getNameAttr());
1183 nlaTable->erase(nla);
1184 nlaCache.erase(nla.getNamepathAttr());
1185 symbolTable.erase(nla);
1186 }
1187
1188 /// Process all NLAs referencing the "from" module to point to the "to"
1189 /// module. This is used after merging two modules together.
1190 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1191 FModuleOp fromModule) {
1192 auto toName = toModule.getNameAttr();
1193 auto fromName = fromModule.getNameAttr();
1194 // Create a copy of the current NLAs. We will be pushing and removing
1195 // NLAs from this op as we go.
1196 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1197 // Change the NLA to target the toModule.
1198 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1199 // Now we walk the NLA searching for ones that require more context to be
1200 // added.
1201 for (auto nla : moduleNLAs) {
1202 auto elements = nla.getNamepath().getValue();
1203 // If we don't need to add more context, we're done here.
1204 if (nla.root() != toName)
1205 continue;
1206 // Create the replacement NLAs.
1207 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1208 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1209 // Copy out the targets, because we will be updating the map.
1210 auto &set = targetMap[nla.getSymNameAttr()];
1211 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1212 // Replace the uses of the old NLA with the new NLAs.
1213 for (auto target : targets) {
1214 // We have to clone any annotation which uses the old NLA for each new
1215 // NLA. This array collects the new set of annotations.
1216 SmallVector<Annotation> newAnnotations;
1217 for (auto anno : target.getAnnotations()) {
1218 // Find the non-local field of the annotation.
1219 auto [it, found] = mlir::impl::findAttrSorted(
1220 anno.begin(), anno.end(), nonLocalString);
1221 // If this annotation doesn't use the target NLA, copy it with no
1222 // changes.
1223 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1224 nla.getSymNameAttr()) {
1225 newAnnotations.push_back(anno);
1226 continue;
1227 }
1228 auto nonLocalIndex = std::distance(anno.begin(), it);
1229 // Clone the annotation and add it to the list of new annotations.
1230 cloneAnnotation(nlaRefs, anno,
1231 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1232 nonLocalIndex, newAnnotations);
1233 }
1234
1235 // Apply the new annotations to the operation.
1236 AnnotationSet annotations(newAnnotations, context);
1237 target.setAnnotations(annotations);
1238 // Record that target uses the NLA.
1239 for (auto nla : nlaRefs)
1240 targetMap[nla.getAttr()].insert(target);
1241 }
1242
1243 // Erase the old NLA and remove it from all breadcrumbs.
1244 eraseNLA(nla);
1245 }
1246 }
1247
1248 /// Process all the NLAs that the two modules participate in, replacing
1249 /// references to the "from" module with references to the "to" module, and
1250 /// adding more context if necessary.
1251 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1252 FModuleOp fromModule) {
1253 addAnnotationContext(renameMap, toModule, toModule);
1254 addAnnotationContext(renameMap, toModule, fromModule);
1255 }
1256
1257 // Update all NLAs which the "from" external module participates in to the
1258 // "toName".
1259 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1260 StringAttr fromName) {
1261 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1262 }
1263
1264 /// Take an annotation, and update it to be a non-local annotation. If the
1265 /// annotation is already non-local and has enough context, it will be skipped
1266 /// for now. Return true if the annotation was made non-local.
1267 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1268 FModuleLike fromModule, Annotation anno,
1269 SmallVectorImpl<Annotation> &newAnnotations) {
1270 // Start constructing a new annotation, pushing a "circt.nonLocal" field
1271 // into the correct spot if its not already a non-local annotation.
1272 SmallVector<NamedAttribute> attributes;
1273 int nonLocalIndex = -1;
1274 for (const auto &val : llvm::enumerate(anno)) {
1275 auto attr = val.value();
1276 // Is this field "circt.nonlocal"?
1277 auto compare = attr.getName().compare(nonLocalString);
1278 assert(compare != 0 && "should not pass non-local annotations here");
1279 if (compare == 1) {
1280 // This annotation definitely does not have "circt.nonlocal" field. Push
1281 // an empty place holder for the non-local annotation.
1282 nonLocalIndex = val.index();
1283 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1284 break;
1285 }
1286 // Otherwise push the current attribute and keep searching for the
1287 // "circt.nonlocal" field.
1288 attributes.push_back(attr);
1289 }
1290 if (nonLocalIndex == -1) {
1291 // Push an empty "circt.nonlocal" field to the last slot.
1292 nonLocalIndex = attributes.size();
1293 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1294 } else {
1295 // Copy the remaining annotation fields in.
1296 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1297 }
1298
1299 // Construct the NLAs if we don't have any yet.
1300 auto nlaRefs = createNLAs(toModuleName, fromModule);
1301 for (auto nla : nlaRefs)
1302 targetMap[nla.getAttr()].insert(to);
1303
1304 // Clone the annotation for each new NLA.
1305 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1306 return true;
1307 }
1308
1309 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1310 FModuleLike fromModule, AnnotationSet annos,
1311 SmallVectorImpl<Annotation> &newAnnotations,
1312 SmallPtrSetImpl<Attribute> &dontTouches) {
1313 for (auto anno : annos) {
1314 if (anno.isClass(dontTouchAnnoClass)) {
1315 // Remove the nonlocal field of the annotation if it has one, since this
1316 // is a sticky annotation.
1317 anno.removeMember("circt.nonlocal");
1318 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1319 if (inserted)
1320 newAnnotations.push_back(anno);
1321 continue;
1322 }
1323 // If the annotation is already non-local, we add it as is. It is already
1324 // added to the target map.
1325 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1326 newAnnotations.push_back(anno);
1327 targetMap[nla.getAttr()].insert(to);
1328 continue;
1329 }
1330 // Otherwise make the annotation non-local and add it to the set.
1331 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1332 newAnnotations);
1333 }
1334 }
1335
1336 /// Merge the annotations of a specific target, either a operation or a port
1337 /// on an operation.
1338 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1339 AnnotationSet toAnnos, FModuleLike fromModule,
1340 AnnoTarget from, AnnotationSet fromAnnos) {
1341 // This is a list of all the annotations which will be added to `to`.
1342 SmallVector<Annotation> newAnnotations;
1343
1344 // We have special case handling of DontTouch to prevent it from being
1345 // turned into a non-local annotation, and to remove duplicates.
1346 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1347
1348 // Iterate the annotations, transforming most annotations into non-local
1349 // ones.
1350 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1351 dontTouches);
1352 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1353 dontTouches);
1354
1355 // Copy over all the new annotations.
1356 if (!newAnnotations.empty())
1357 to.setAnnotations(AnnotationSet(newAnnotations, context));
1358 }
1359
1360 /// Merge all annotations and port annotations on two operations.
1361 void mergeAnnotations(FModuleLike toModule, Operation *to,
1362 FModuleLike fromModule, Operation *from) {
1363 // Merge op annotations.
1364 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1365 OpAnnoTarget(from), AnnotationSet(from));
1366
1367 // Merge port annotations.
1368 if (toModule == to) {
1369 // Merge module port annotations.
1370 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1371 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1372 AnnotationSet::forPort(toModule, i), fromModule,
1373 PortAnnoTarget(fromModule, i),
1374 AnnotationSet::forPort(fromModule, i));
1375 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1376 // Merge memory port annotations.
1377 auto fromMem = cast<MemOp>(from);
1378 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1379 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1380 AnnotationSet::forPort(toMem, i), fromModule,
1381 PortAnnoTarget(fromMem, i),
1382 AnnotationSet::forPort(fromMem, i));
1383 }
1384 }
1385
1386 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1387 hw::InnerSymAttr toSym,
1388 hw::InnerSymAttr fromSym) {
1389 if (fromSym && !fromSym.getProps().empty()) {
1390 auto &isn = getNamespace(toModule);
1391 // The properties for the new inner symbol..
1392 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1393 // If the "to" op already has an inner symbol, copy all its properties.
1394 if (toSym)
1395 llvm::append_range(newProps, toSym);
1396 // Add each property from the fromSym to the toSym.
1397 for (auto fromProp : fromSym) {
1398 hw::InnerSymPropertiesAttr newProp;
1399 auto *it = llvm::find_if(newProps, [&](auto p) {
1400 return p.getFieldID() == fromProp.getFieldID();
1401 });
1402 if (it != newProps.end()) {
1403 // If we already have an inner sym with the same field id, use
1404 // that.
1405 newProp = *it;
1406 // If the old symbol is public, we need to make the new one public.
1407 if (fromProp.getSymVisibility().getValue() == "public" &&
1408 newProp.getSymVisibility().getValue() != "public") {
1409 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1410 newProp.getFieldID(),
1411 fromProp.getSymVisibility());
1412 }
1413 } else {
1414 // We need to add a new property to the inner symbol for this field.
1415 auto newName = isn.newName(fromProp.getName().getValue());
1416 newProp = hw::InnerSymPropertiesAttr::get(
1417 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1418 fromProp.getSymVisibility());
1419 newProps.push_back(newProp);
1420 }
1421 renameMap[fromProp.getName()] = newProp.getName();
1422 }
1423 // Sort the fields by field id.
1424 llvm::sort(newProps, [](auto &p, auto &q) {
1425 return p.getFieldID() < q.getFieldID();
1426 });
1427 // Return the merged inner symbol.
1428 return hw::InnerSymAttr::get(context, newProps);
1429 }
1430 return hw::InnerSymAttr();
1431 }
1432
1433 // Record the symbol name change of the operation or any of its ports when
1434 // merging two operations. The renamed symbols are used to update the
1435 // target of any NLAs. This will add symbols to the "to" operation if needed.
1436 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1437 Operation *to, FModuleLike fromModule,
1438 Operation *from) {
1439 // If the "from" operation has an inner_sym, we need to make sure the
1440 // "to" operation also has an `inner_sym` and then record the renaming.
1441 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1442 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1443 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1444 toInnerSym.getInnerSymAttr(),
1445 fromInnerSym.getInnerSymAttr()))
1446 toInnerSym.setInnerSymbolAttr(newSymAttr);
1447 }
1448
1449 // If there are no port symbols on the "from" operation, we are done here.
1450 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1451 if (!fromPortSyms || fromPortSyms.empty())
1452 return;
1453 // We have to map each "fromPort" to each "toPort".
1454 auto portCount = fromPortSyms.size();
1455 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1456
1457 // Create an array of new port symbols for the "to" operation, copy in the
1458 // old symbols if it has any, create an empty symbol array if it doesn't.
1459 SmallVector<Attribute> newPortSyms;
1460 if (toPortSyms.empty())
1461 newPortSyms.assign(portCount, hw::InnerSymAttr());
1462 else
1463 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1464
1465 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1466 if (auto newPortSym = mergeInnerSymbols(
1467 renameMap, toModule,
1468 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1469 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1470 newPortSyms[portNo] = newPortSym;
1471 }
1472 }
1473
1474 // Commit the new symbol attribute.
1475 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1476 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1477 }
1478
1479 /// Recursively merge two operations.
1480 // NOLINTNEXTLINE(misc-no-recursion)
1481 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1482 FModuleLike fromModule, Operation *from) {
1483 // Merge the operation locations.
1484 if (to->getLoc() != from->getLoc())
1485 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1486
1487 // Recurse into any regions.
1488 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1489 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1490 std::get<1>(regions));
1491
1492 // Record any inner_sym renamings that happened.
1493 recordSymRenames(renameMap, toModule, to, fromModule, from);
1494
1495 // Merge the annotations.
1496 mergeAnnotations(toModule, to, fromModule, from);
1497 }
1498
1499 /// Recursively merge two blocks.
1500 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1501 FModuleLike fromModule, Block &fromBlock) {
1502 // Merge the block locations.
1503 for (auto [toArg, fromArg] :
1504 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1505 if (toArg.getLoc() != fromArg.getLoc())
1506 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1507
1508 for (auto ops : llvm::zip(toBlock, fromBlock))
1509 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1510 &std::get<1>(ops));
1511 }
1512
1513 // Recursively merge two regions.
1514 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1515 Region &toRegion, FModuleLike fromModule,
1516 Region &fromRegion) {
1517 for (auto blocks : llvm::zip(toRegion, fromRegion))
1518 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1519 std::get<1>(blocks));
1520 }
1521
1522 MLIRContext *context;
1524 SymbolTable &symbolTable;
1525
1526 /// Cached nla table analysis.
1527 NLATable *nlaTable = nullptr;
1528
1529 /// We insert all NLAs to the beginning of this block.
1530 Block *nlaBlock;
1531
1532 // This maps an NLA to the operations and ports that uses it.
1533 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1534
1535 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAttr
1536 // of the NLA's path to the name of the NLA which contains it.
1537 DenseMap<Attribute, Attribute> nlaCache;
1538
1539 // Cached attributes for faster comparisons and attribute building.
1540 StringAttr nonLocalString;
1541 StringAttr classString;
1542
1543 /// A module namespace cache.
1544 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1545};
1546
1547//===----------------------------------------------------------------------===//
1548// Fixup
1549//===----------------------------------------------------------------------===//
1550
1551/// This fixes up connects when the field names of a bundle type changes. It
1552/// finds all fields which were previously bulk connected and legalizes it
1553/// into a connect for each field.
1554static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1555 // If the types already match we can emit a connect.
1556 auto dstType = dst.getType();
1557 auto srcType = src.getType();
1558 if (dstType == srcType) {
1559 emitConnect(builder, dst, src);
1560 return;
1561 }
1562 // It must be a bundle type and the field name has changed. We have to
1563 // manually decompose the bulk connect into a connect for each field.
1564 auto dstBundle = type_cast<BundleType>(dstType);
1565 auto srcBundle = type_cast<BundleType>(srcType);
1566 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1567 auto dstField = SubfieldOp::create(builder, dst, i);
1568 auto srcField = SubfieldOp::create(builder, src, i);
1569 if (dstBundle.getElement(i).isFlip) {
1570 std::swap(srcBundle, dstBundle);
1571 std::swap(srcField, dstField);
1572 }
1573 fixupConnect(builder, dstField, srcField);
1574 }
1575}
1576
1577/// Adjust the symbol references in an op. This includes updating its attributes
1578/// and types.
1579static void
1580fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph,
1581 const DenseMap<Attribute, StringAttr> &dedupMap) {
1582 // If this is an instance op, dedup may have subtly changed the port types.
1583 // For example, structurally different bundles may still dedup. In this case
1584 // we now have an instance op that produces result values of the old type, but
1585 // the port info on the instantiated module already represents the new type.
1586 // Fix this up by going through an intermediate wire.
1587 if (auto instOp = dyn_cast<InstanceOp>(op)) {
1588 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp->getContext());
1589 builder.setInsertionPointAfter(instOp);
1590 auto module = instanceGraph.lookup(instOp.getModuleNameAttr().getAttr())
1591 ->getModule<FModuleLike>();
1592 for (auto [index, result] : llvm::enumerate(instOp.getResults())) {
1593 auto newType = module.getPortType(index);
1594 auto oldType = result.getType();
1595 // If the type has not changed, we don't have to fix up anything.
1596 if (newType == oldType)
1597 continue;
1598 LLVM_DEBUG(llvm::dbgs()
1599 << "- Updating instance port \"" << instOp.getInstanceName()
1600 << "." << instOp.getPortName(index) << "\" from " << oldType
1601 << " to " << newType << "\n");
1602
1603 // If the type changed we transform it back to the old type with an
1604 // intermediate wire.
1605 auto wire = WireOp::create(builder, oldType, instOp.getPortName(index))
1606 .getResult();
1607 result.replaceAllUsesWith(wire);
1608 result.setType(newType);
1609 if (instOp.getPortDirection(index) == Direction::Out)
1610 fixupConnect(builder, wire, result);
1611 else
1612 fixupConnect(builder, result, wire);
1613 }
1614 }
1615
1616 // Use an attribute/type replacer to look for references to old symbols that
1617 // need to be replaced with new symbols.
1618 mlir::AttrTypeReplacer replacer;
1619 replacer.addReplacement([&](FlatSymbolRefAttr symRef) {
1620 auto oldName = symRef.getAttr();
1621 auto newName = dedupMap.lookup(oldName);
1622 if (newName && newName != oldName) {
1623 auto newSymRef = FlatSymbolRefAttr::get(newName);
1624 LLVM_DEBUG(llvm::dbgs()
1625 << "- Updating " << symRef << " to " << newSymRef << " in "
1626 << op->getName() << " at " << op->getLoc() << "\n");
1627 return newSymRef;
1628 }
1629 return symRef;
1630 });
1631
1632 // Update attributes.
1633 op->setAttrs(cast<DictionaryAttr>(replacer.replace(op->getAttrDictionary())));
1634
1635 // Update the argument types.
1636 for (auto &region : op->getRegions())
1637 for (auto &block : region)
1638 for (auto arg : block.getArguments())
1639 arg.setType(replacer.replace(arg.getType()));
1640
1641 // Update result types.
1642 for (auto result : op->getResults())
1643 result.setType(replacer.replace(result.getType()));
1644}
1645
1646 /// Adjust the symbol references in ops marked as sensitive to them. This
1647 /// includes updating their attributes and types.
 /// Visits every node of the instance graph; modules without an entry in
 /// `moduleToModuleInfo`, or whose list of symbol-sensitive ops is empty, are
 /// skipped. Each remaining op is handed to `fixupSymbolSensitiveOp`.
 // NOTE(review): listing line 1648 (the start of the signature) is absent
 // from this view -- confirm against the original file.
1649 InstanceGraph &instanceGraph,
1650 const DenseMap<Operation *, ModuleInfoRef> &moduleToModuleInfo,
1651 const DenseMap<Attribute, StringAttr> &dedupMap) {
1652 for (auto *node : instanceGraph) {
1653 // Look up the module info for this module, which contains the list of ops
1654 // that need to be updated.
1655 auto module = node->getModule<FModuleLike>();
1656 auto it = moduleToModuleInfo.find(module);
1657 if (it == moduleToModuleInfo.end())
1658 continue;
1659
1660 // Update each symbol-sensitive op individually.
1661 auto &ops = it->second.info->symbolSensitiveOps;
1662 if (ops.empty())
1663 continue;
1664 LLVM_DEBUG(llvm::dbgs()
1665 << "- Updating " << ops.size() << " symbol-sensitive ops in "
1666 << module.getNameAttr() << "\n");
1667 for (auto *op : ops)
1668 fixupSymbolSensitiveOp(op, instanceGraph, dedupMap);
1669 }
1670 }
1671
1672//===----------------------------------------------------------------------===//
1673// DedupPass
1674//===----------------------------------------------------------------------===//
1675
1676namespace {
1677class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1678 using DedupBase::DedupBase;
1679
1680 void runOnOperation() override {
1681 auto *context = &getContext();
1682 auto circuit = getOperation();
1683 auto &instanceGraph = getAnalysis<InstanceGraph>();
1684 auto *nlaTable = &getAnalysis<NLATable>();
1685 auto &symbolTable = getAnalysis<SymbolTable>();
1686 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1687 Equivalence equiv(context, instanceGraph);
1688 auto anythingChanged = false;
1689 LLVM_DEBUG({
1690 llvm::dbgs() << "\n";
1691 debugHeader(Twine("Dedup circuit \"") + circuit.getName() + "\"")
1692 << "\n\n";
1693 });
1694
1695 // Modules annotated with this should not be considered for deduplication.
1696 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1697
1698 // Only modules within the same group may be deduplicated.
1699 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1700
1701 // A map of all the module moduleInfo that we have calculated so far.
1702 DenseMap<ModuleInfoRef, Operation *> moduleInfoToModule;
1703 DenseMap<Operation *, ModuleInfoRef> moduleToModuleInfo;
1704
1705 // We track the name of the module that each module is deduped into, so that
1706 // we can make sure all modules which are marked "must dedup" with each
1707 // other were all deduped to the same module.
1708 DenseMap<Attribute, StringAttr> dedupMap;
1709
1710 // We must iterate the modules from the bottom up so that we can properly
1711 // deduplicate the modules. We copy the list of modules into a vector first
1712 // to avoid iterator invalidation while we mutate the instance graph.
1713 SmallVector<FModuleLike, 0> modules;
1714 instanceGraph.walkPostOrder([&](auto &node) {
1715 if (auto mod = dyn_cast<FModuleLike>(*node.getModule()))
1716 modules.push_back(mod);
1717 });
1718 LLVM_DEBUG(llvm::dbgs() << "Found " << modules.size() << " modules\n");
1719
1720 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1721 StructuralHasherSharedConstants hasherConstants(&getContext());
1722
1723 // Attribute name used to store dedup_group for this pass.
1724 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1725
1726 // Move dedup group annotations to attributes on the module.
1727 // This results in the desired behavior (included in hash),
1728 // and avoids unnecessary processing of these as annotations
1729 // that need to be tracked, made non-local, so on.
1730 for (auto module : modules) {
1731 llvm::SmallSetVector<StringAttr, 1> groups;
1733 module, [&groups, dedupGroupClass](Annotation annotation) {
1734 if (annotation.getClassAttr() != dedupGroupClass)
1735 return false;
1736 groups.insert(annotation.getMember<StringAttr>("group"));
1737 return true;
1738 });
1739 if (groups.size() > 1) {
1740 module.emitError("module belongs to multiple dedup groups: ") << groups;
1741 return signalPassFailure();
1742 }
1743 assert(!module->hasAttr(dedupGroupAttrName) &&
1744 "unexpected existing use of temporary dedup group attribute");
1745 if (!groups.empty())
1746 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1747 }
1748
1749 // Calculate module information parallelly.
1750 LLVM_DEBUG(llvm::dbgs() << "Computing module information\n");
1751 auto result = mlir::failableParallelForEach(
1752 context, llvm::seq(modules.size()), [&](unsigned idx) {
1753 auto module = modules[idx];
1754 // If the module is marked with NoDedup, just skip it.
1755 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1756 return success();
1757
1758 // Only dedup extmodule's with defname.
1759 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1760 ext && !ext.getDefname().has_value())
1761 return success();
1762
1763 // Only dedup classes if enabled.
1764 if (isa<ClassOp>(*module) && !dedupClasses)
1765 return success();
1766
1767 StructuralHasher hasher(hasherConstants);
1768 // Calculate the hash of the module and referred module names.
1769 moduleInfos[idx] = hasher.getModuleInfo(module);
1770 return success();
1771 });
1772
1773 // Dump out the module hashes for debugging.
1774 LLVM_DEBUG({
1775 auto &os = llvm::dbgs();
1776 for (auto [module, info] : llvm::zip(modules, moduleInfos)) {
1777 os << "- Hash ";
1778 if (info) {
1779 os << llvm::format_bytes(info->structuralHash, std::nullopt, 32, 32);
1780 } else {
1781 os << "--------------------------------";
1782 os << "--------------------------------";
1783 }
1784 os << " for " << module.getModuleNameAttr() << "\n";
1785 }
1786 });
1787
1788 if (result.failed())
1789 return signalPassFailure();
1790
1791 LLVM_DEBUG(llvm::dbgs() << "Update modules\n");
1792 for (auto [i, module] : llvm::enumerate(modules)) {
1793 auto moduleName = module.getModuleNameAttr();
1794 auto &maybeModuleInfo = moduleInfos[i];
1795 // If the hash was not calculated, we need to skip it.
1796 if (!maybeModuleInfo) {
1797 // We record it in the dedup map to help detect errors when the user
1798 // marks the module as both NoDedup and MustDedup. We do not record this
1799 // module in the hasher to make sure no other module dedups "into" this
1800 // one.
1801 dedupMap[moduleName] = moduleName;
1802 continue;
1803 }
1804
1805 auto &moduleInfo = maybeModuleInfo.value();
1806 moduleToModuleInfo.try_emplace(module, &moduleInfo);
1807
1808 // Replace module names referred in the module with new names.
1809 for (auto &referredModule : moduleInfo.referredModuleNames)
1810 referredModule = dedupMap[referredModule];
1811
1812 // Check if there is a module with the same hash.
1813 auto it = moduleInfoToModule.find(&moduleInfo);
1814 if (it != moduleInfoToModule.end()) {
1815 auto original = cast<FModuleLike>(it->second);
1816 auto originalName = original.getModuleNameAttr();
1817
1818 // If the current module is public, and the original is private, we
1819 // want to dedup the private module into the public one.
1820 if (!canRemoveModule(module)) {
1821 // If both modules are public, then we can't dedup anything.
1822 if (!canRemoveModule(original))
1823 continue;
1824 // Swap the canonical module in the dedup map.
1825 for (auto &[_, dedupedName] : dedupMap)
1826 if (dedupedName == originalName)
1827 dedupedName = moduleName;
1828 // Update the module hash table to point to the new original, so all
1829 // future modules dedup with the new canonical module.
1830 it->second = module;
1831 // Swap the locals.
1832 std::swap(originalName, moduleName);
1833 std::swap(original, module);
1834 }
1835
1836 // Record the group ID of the other module.
1837 LLVM_DEBUG(llvm::dbgs() << "- Replace " << moduleName << " with "
1838 << originalName << "\n");
1839 dedupMap[moduleName] = originalName;
1840 deduper.dedup(original, module);
1841 ++erasedModules;
1842 anythingChanged = true;
1843 continue;
1844 }
1845 // Any module not deduplicated must be recorded.
1846 deduper.record(module);
1847 // Add the module to a new dedup group.
1848 dedupMap[moduleName] = moduleName;
1849 // Record the module info.
1850 moduleInfoToModule[&moduleInfo] = module;
1851 }
1852
1853 // This part verifies that all modules marked by "MustDedup" have been
1854 // properly deduped with each other. For this check to succeed, all modules
1855 // have to been deduped to the same module. It is possible that a module was
1856 // deduped with the wrong thing.
1857
1858 auto failed = false;
1859 // This parses the module name out of a target string.
1860 auto parseModule = [&](Attribute path) -> StringAttr {
1861 // Each module is listed as a target "~Circuit|Module" which we have to
1862 // parse.
1863 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1864 return StringAttr::get(context, rhs);
1865 };
1866 // This gets the name of the module which the current module was deduped
1867 // with. If the named module isn't in the map, then we didn't encounter it
1868 // in the circuit.
1869 auto getLead = [&](StringAttr module) -> StringAttr {
1870 auto it = dedupMap.find(module);
1871 if (it == dedupMap.end()) {
1872 auto diag = emitError(circuit.getLoc(),
1873 "MustDeduplicateAnnotation references module ")
1874 << module << " which does not exist";
1875 failed = true;
1876 return nullptr;
1877 }
1878 return it->second;
1879 };
1880
1881 LLVM_DEBUG(llvm::dbgs() << "Update annotations\n");
1882 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1883 if (!annotation.isClass(mustDedupAnnoClass))
1884 return false;
1885 auto modules = annotation.getMember<ArrayAttr>("modules");
1886 if (!modules) {
1887 emitError(circuit.getLoc(),
1888 "MustDeduplicateAnnotation missing \"modules\" member");
1889 failed = true;
1890 return false;
1891 }
1892 // Empty module list has nothing to process.
1893 if (modules.empty())
1894 return true;
1895 // Get the first element.
1896 auto firstModule = parseModule(modules[0]);
1897 auto firstLead = getLead(firstModule);
1898 if (!firstLead)
1899 return false;
1900 // Verify that the remaining elements are all the same as the first.
1901 for (auto attr : modules.getValue().drop_front()) {
1902 auto nextModule = parseModule(attr);
1903 auto nextLead = getLead(nextModule);
1904 if (!nextLead)
1905 return false;
1906 if (firstLead != nextLead) {
1907 auto diag = emitError(circuit.getLoc(), "module ")
1908 << nextModule << " not deduplicated with " << firstModule;
1909 auto a = instanceGraph.lookup(firstLead)->getModule();
1910 auto b = instanceGraph.lookup(nextLead)->getModule();
1911 equiv.check(diag, a, b);
1912 failed = true;
1913 return false;
1914 }
1915 }
1916 return true;
1917 });
1918 if (failed)
1919 return signalPassFailure();
1920
1921 // Remove all dedup group attributes, they only exist during this pass.
1922 for (auto module : circuit.getOps<FModuleLike>())
1923 module->removeDiscardableAttr(dedupGroupAttrName);
1924
1925 // Fixup all operations that we've found to be sensitive to symbol names.
1926 // This includes module and class instances, wires, connects, etc.
1927 fixupSymbolSensitiveOps(instanceGraph, moduleToModuleInfo, dedupMap);
1928
1929 markAnalysesPreserved<NLATable>();
1930 if (!anythingChanged)
1931 markAllAnalysesPreserved();
1932 }
1933};
1934} // end anonymous namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:941
static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition Dedup.cpp:1554
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:55
static void fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in an op.
Definition Dedup.cpp:1580
static void fixupSymbolSensitiveOps(InstanceGraph &instanceGraph, const DenseMap< Operation *, ModuleInfoRef > &moduleToModuleInfo, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in ops marked as sensitive to them.
Definition Dedup.cpp:1648
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
decltype(auto) walkPostOrder(Fn &&fn)
Perform a post-order walk across the modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
Definition FIRRTLEnums.h:44
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
constexpr const char * dedupGroupAnnoClass
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(const llvm::Twine &str, unsigned width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition Debug.cpp:17
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1116
MLIRContext * context
Definition Dedup.cpp:1522
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1530
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1071
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1180
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1361
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1087
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:1045
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1259
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1514
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:1010
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1251
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1153
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:1064
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1527
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1386
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1161
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1436
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either a operation or a port on an operation.
Definition Dedup.cpp:1338
StringAttr nonLocalString
Definition Dedup.cpp:1540
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:1057
SymbolTable & symbolTable
Definition Dedup.cpp:1524
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1481
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1544
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1267
InstanceGraph & instanceGraph
Definition Dedup.cpp:1523
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1500
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1533
StringAttr classString
Definition Dedup.cpp:1541
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1309
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:994
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1190
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:992
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1537
const hw::InnerSymbolTable & a
Definition Dedup.cpp:477
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:474
const hw::InnerSymbolTable & b
Definition Dedup.cpp:478
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:456
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:928
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:481
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition Dedup.cpp:764
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:553
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:527
StringAttr noDedupClass
Definition Dedup.cpp:923
StringAttr dedupGroupAttrName
Definition Dedup.cpp:925
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:676
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:631
StringAttr portDirectionsAttr
Definition Dedup.cpp:921
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:497
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:785
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:457
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:651
InstanceGraph & instanceGraph
Definition Dedup.cpp:929
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:872
A reference to a ModuleInfo that compares and hashes like it.
Definition Dedup.cpp:411
ModuleInfo * info
Definition Dedup.cpp:413
ModuleInfoRef(ModuleInfo *info)
Definition Dedup.cpp:412
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:90
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:86
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:84
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:100
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:127
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:101
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:147
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:270
void update(Type type)
Definition Dedup.cpp:247
void update(const void *pointer)
Definition Dedup.cpp:204
void update(ClassType type)
Definition Dedup.cpp:232
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:383
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:215
void update(Operation *op)
Definition Dedup.cpp:349
llvm::SHA256 sha
Definition Dedup.cpp:397
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:387
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:134
void update(size_t value)
Definition Dedup.cpp:209
void update(BundleType type)
Definition Dedup.cpp:223
unsigned getID(void *object)
Definition Dedup.cpp:166
void update(OpResult result)
Definition Dedup.cpp:255
void update(OpOperand &operand)
Definition Dedup.cpp:187
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:131
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:405
void update(Region *region)
Definition Dedup.cpp:341
void update(Block *block)
Definition Dedup.cpp:330
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:390
void update(TypeID typeID)
Definition Dedup.cpp:220
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:393
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:183
unsigned finalizeID(void *object)
Definition Dedup.cpp:174
void update(mlir::OperationName name)
Definition Dedup.cpp:324
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs)
Definition Dedup.cpp:440
static ModuleInfoRef getTombstoneKey()
Definition Dedup.cpp:424
static ModuleInfoRef getEmptyKey()
Definition Dedup.cpp:420
static unsigned getHashValue(const ModuleInfoRef &ref)
Definition Dedup.cpp:428