CIRCT 22.0.0git
Loading...
Searching...
No Matches
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
23#include "circt/Support/Debug.h"
24#include "circt/Support/LLVM.h"
25#include "mlir/IR/IRMapping.h"
26#include "mlir/IR/Threading.h"
27#include "mlir/Pass/Pass.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/DenseMapInfo.h"
30#include "llvm/ADT/Hashing.h"
31#include "llvm/ADT/PostOrderIterator.h"
32#include "llvm/ADT/SmallPtrSet.h"
33#include "llvm/Support/Debug.h"
34#include "llvm/Support/Format.h"
35#include "llvm/Support/SHA256.h"
36
37#define DEBUG_TYPE "firrtl-dedup"
38
39namespace circt {
40namespace firrtl {
41#define GEN_PASS_DEF_DEDUP
42#include "circt/Dialect/FIRRTL/Passes.h.inc"
43} // namespace firrtl
44} // namespace circt
45
46using namespace circt;
47using namespace firrtl;
48using hw::InnerRefAttr;
49
50//===----------------------------------------------------------------------===//
51// Utility function for classifying a Symbol's dedup-ability.
52//===----------------------------------------------------------------------===//
53
54/// Returns true if the module can be removed.
55static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
56 // If the symbol is not private, it cannot be removed.
57 if (!symbol.isPrivate())
58 return false;
59 // Classes may be referenced in object types, so can not normally be removed
60 // if we can't find any symbol uses. Since we know that dedup will update the
61 // types of instances appropriately, we can ignore that and return true here.
62 if (isa<ClassLike>(*symbol))
63 return true;
64 // If module can not be removed even if no uses can be found, we can not
65 // delete it. The implication is that there are hidden symbol uses that dedup
66 // will not properly update.
67 if (!symbol.canDiscardOnUseEmpty())
68 return false;
69 // The module can be deleted.
70 return true;
71}
72
73//===----------------------------------------------------------------------===//
74// Hashing
75//===----------------------------------------------------------------------===//
76
77// This struct contains information to determine module module uniqueness. A
78// first element is a structural hash of the module, and the second element is
79// an array which tracks module names encountered in the walk. Since module
80// names could be replaced during dedup, it's necessary to keep names up-to-date
81// before actually combining them into structural hashes.
82struct ModuleInfo {
83 // SHA256 hash.
84 std::array<uint8_t, 32> structuralHash;
85 // Module names referred by instance op in the module.
86 std::vector<StringAttr> referredModuleNames;
87 // The operations that contain references to symbols that may be changed by
88 // dedup. These need a fixup pass after dedup. This is just an optimization
89 // and does not factor into the hash or the equality between `ModuleInfo`s.
90 std::vector<Operation *> symbolSensitiveOps;
91};
92
93static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
94 return lhs.structuralHash == rhs.structuralHash &&
96}
97
98/// This struct contains constant string attributes shared across different
99/// threads.
101 explicit StructuralHasherSharedConstants(MLIRContext *context) {
102 portTypesAttr = StringAttr::get(context, "portTypes");
103 moduleNameAttr = StringAttr::get(context, "moduleName");
104 portNamesAttr = StringAttr::get(context, "portNames");
105 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
106 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
107 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
108 nonessentialAttributes.insert(StringAttr::get(context, "name"));
109 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
110 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
111 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
112 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
113 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
114 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
115 };
116
117 // This is a cached "portTypes" string attr.
118 StringAttr portTypesAttr;
119
120 // This is a cached "moduleName" string attr.
121 StringAttr moduleNameAttr;
122
123 // This is a cached "portNames" string attr.
124 StringAttr portNamesAttr;
125
126 // This is a set of every attribute we should ignore.
127 DenseSet<Attribute> nonessentialAttributes;
128};
129
132 : constants(constants) {}
133
134 ModuleInfo getModuleInfo(FModuleLike module) {
136 update(&(*module));
137 return {sha.final(), std::move(referredModuleNames),
138 std::move(symbolSensitiveOps)};
139 }
140
141private:
142 /// Find all the ports and operations which may define an inner symbol
143 /// operations and give each a unique id. If the port/operation does define
144 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
145 /// field id. When we hash (local) references to this inner symbol, we will
146 /// hash in the id and the the field id.
147 void populateInnerSymIDTable(FModuleLike module) {
148 // Add port symbols. If no port has a symbol defined, the port symbol array
149 // will be totally empty.
150 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
151 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
152 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
153 }
154 // Add operation symbols.
155 size_t index = module.getNumPorts();
156 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
157 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
158 for (auto prop : innerSym)
159 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
160 }
161 ++index;
162 });
163 }
164
165 // Get the identifier for an object. The identifier is assigned on first use.
166 unsigned getID(void *object) {
167 auto [it, inserted] = idTable.try_emplace(object, nextID);
168 if (inserted)
169 ++nextID;
170 return it->second;
171 }
172
173 // Get the identifier for an IR object. Free the ID, too.
174 unsigned finalizeID(void *object) {
175 auto it = idTable.find(object);
176 if (it == idTable.end())
177 return nextID++;
178 auto id = it->second;
179 idTable.erase(it);
180 return id;
181 }
182
183 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
184 return innerSymIDTable.at(name);
185 }
186
187 void update(OpOperand &operand) {
188 auto value = operand.get();
189 if (auto result = dyn_cast<OpResult>(value)) {
190 auto *op = result.getOwner();
191 update(getID(op));
192 update(result.getResultNumber());
193 return;
194 }
195 if (auto argument = dyn_cast<BlockArgument>(value)) {
196 auto *block = argument.getOwner();
197 update(getID(block));
198 update(argument.getArgNumber());
199 return;
200 }
201 llvm_unreachable("Unknown value type");
202 }
203
204 void update(const void *pointer) {
205 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
206 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
207 }
208
209 void update(size_t value) {
210 auto *addr = reinterpret_cast<const uint8_t *>(&value);
211 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
212 }
213
// Hash both halves of a pair, first element before second.
template <typename T, typename U>
void update(const std::pair<T, U> &pair) {
  const auto &[first, second] = pair;
  update(first);
  update(second);
}
219
220 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
221
222 // NOLINTNEXTLINE(misc-no-recursion)
223 void update(BundleType type) {
224 update(type.getTypeID());
225 for (auto &element : type.getElements()) {
226 update(element.isFlip);
227 update(element.type);
228 }
229 }
230
231 // NOLINTNEXTLINE(misc-no-recursion)
232 void update(ClassType type) {
233 update(type.getTypeID());
234 // Don't hash the class name directly, since it may be replaced during
235 // dedup. Record the class name instead and lazily combine their hashes
236 // using the same mechanism as instances and modules.
237 hasSeenSymbol = true;
238 referredModuleNames.push_back(type.getNameAttr().getAttr());
239 for (auto &element : type.getElements()) {
240 update(element.name.getAsOpaquePointer());
241 update(element.type);
242 update(static_cast<unsigned>(element.direction));
243 }
244 }
245
246 // NOLINTNEXTLINE(misc-no-recursion)
247 void update(Type type) {
248 if (auto bundle = type_dyn_cast<BundleType>(type))
249 return update(bundle);
250 if (auto klass = type_dyn_cast<ClassType>(type))
251 return update(klass);
252 update(type.getAsOpaquePointer());
253 }
254
255 void update(OpResult result) {
256 // Like instance ops, don't use object ops' result types since they might be
257 // replaced by dedup. Record the class names and lazily combine their hashes
258 // using the same mechanism as instances and modules.
259 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
260 hasSeenSymbol = true;
261 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
262 return;
263 }
264
265 update(result.getType());
266 }
267
268 /// Hash the top level attribute dictionary of the operation. This function
269 /// has special handling for inner symbols, ports, and referenced modules.
270 void update(Operation *op, DictionaryAttr dict) {
271 for (auto namedAttr : dict) {
272 auto name = namedAttr.getName();
273 auto value = namedAttr.getValue();
274
275 // Check whether this attribute contains a nested symbol, just to make
276 // sure we revisit this op after dedup and update any symbols.
277 value.walk([&](FlatSymbolRefAttr) { hasSeenSymbol = true; });
278
279 // Skip names and annotations, except in certain cases.
280 // Names of ports are load bearing for classes, so we do hash those.
281 bool isClassPortNames =
282 isa<ClassLike>(op) && name == constants.portNamesAttr;
283 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
284 continue;
285
286 // Hash the attribute name (an interned pointer).
287 update(name.getAsOpaquePointer());
288
289 // Hash the port types.
290 if (name == constants.portTypesAttr) {
291 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
292 for (auto type : portTypes)
293 update(type);
294 continue;
295 }
296
297 // For instance op, don't use `moduleName` attributes since they might be
298 // replaced by dedup. Record the names and lazily combine their hashes.
299 // It is assumed that module names are hashed only through instance ops;
300 // it could cause suboptimal results if there was other operation that
301 // refers to module names through essential attributes.
302 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
303 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
304 continue;
305 }
306
307 // TODO: properly handle DistinctAttr, including its use in paths.
308 // See https://github.com/llvm/circt/issues/6583.
309 if (isa<DistinctAttr>(value))
310 continue;
311
312 // If this is an symbol reference, we need to perform name erasure.
313 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
314 update(getInnerSymID(innerRef.getName()));
315 continue;
316 }
317
318 // We don't need to handle this attribute specially, so hash its unique
319 // address.
320 update(value.getAsOpaquePointer());
321 }
322 }
323
324 void update(mlir::OperationName name) {
325 // Operation names are interned.
326 update(name.getAsOpaquePointer());
327 }
328
329 // NOLINTNEXTLINE(misc-no-recursion)
330 void update(Block *block) {
331 for (auto &op : llvm::reverse(*block))
332 update(&op);
333 for (auto type : block->getArgumentTypes())
334 update(type);
335 update(finalizeID(block));
336 update(position);
337 ++position;
338 }
339
340 // NOLINTNEXTLINE(misc-no-recursion)
341 void update(Region *region) {
342 for (auto &block : llvm::reverse(region->getBlocks()))
343 update(&block);
344 update(position);
345 ++position;
346 }
347
348 // NOLINTNEXTLINE(misc-no-recursion)
349 void update(Operation *op) {
350 // Hash the regions. We need to make sure an empty region doesn't hash the
351 // same as no region, so we include the number of regions.
352 update(op->getNumRegions());
353 for (auto &region : reverse(op->getRegions()))
354 update(&region);
355
356 update(op->getName());
357
358 // Record the uses for later hashing.
359 for (auto &operand : op->getOpOperands())
360 update(operand);
361
362 // This happens after the numbering above, as it uses blockarg numbering
363 // for inner symbols.
364 hasSeenSymbol = false;
365 update(op, op->getAttrDictionary());
366
367 // Record any op results (types).
368 for (auto result : op->getResults())
369 update(result);
370
371 // If any of the operands, attributes, or results depended on symbols,
372 // remember this op for later adjustment.
373 if (hasSeenSymbol)
374 symbolSensitiveOps.push_back(op);
375
376 // Incorporate the hash of uses we have already built.
377 update(finalizeID(op));
378 update(position);
379 ++position;
380 }
381
382 // A map from an operation/block, to its identifier.
383 DenseMap<void *, unsigned> idTable;
384 unsigned nextID = 0;
385
386 // A map from an inner symbol, to its identifier.
387 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
388
389 // This keeps track of module names in the order of the appearance.
390 std::vector<StringAttr> referredModuleNames;
391
392 // String constants.
394
395 // This is the actual running hash calculation. This is a stateful element
396 // that should be reinitialized after each hash is produced.
397 llvm::SHA256 sha;
398
399 // The index of the current op. Increment after handling each op.
400 size_t position = 0;
401
402 // The operations that contain references to symbols that may be changed by
403 // dedup. These need a fixup pass after dedup.
404 bool hasSeenSymbol = false;
405 std::vector<Operation *> symbolSensitiveOps;
406};
407
408/// A reference to a `ModuleInfo` that compares and hashes like it. This is used
409/// to keep the potentially heavy module infos in a vector, and then populate
410/// maps with references to them.
415
416/// Allow `ModuleInfoRef` to be used as dense map keys. Hashes and compares the
417/// `ModuleInfo` the ref points to.
418template <>
423
427
428 static unsigned getHashValue(const ModuleInfoRef &ref) {
429 // We assume SHA256 is already a good hash and just truncate down to the
430 // number of bytes we need for DenseMap.
431 unsigned hash;
432 std::memcpy(&hash, ref.info->structuralHash.data(), sizeof(unsigned));
433
434 // Combine module names.
435 return llvm::hash_combine(
436 hash, llvm::hash_combine_range(ref.info->referredModuleNames.begin(),
437 ref.info->referredModuleNames.end()));
438 }
439
440 static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs) {
441 auto *empty = getEmptyKey().info;
442 auto *tombstone = getTombstoneKey().info;
443 if (lhs.info == empty || rhs.info == empty || lhs.info == tombstone ||
444 rhs.info == tombstone)
445 return lhs.info == rhs.info;
446 return *lhs.info == *rhs.info;
447 }
448};
449
450//===----------------------------------------------------------------------===//
451// Equivalence
452//===----------------------------------------------------------------------===//
453
454/// This class is for reporting differences between two modules which should
455/// have been deduplicated.
459 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
460 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
461 portDirectionsAttr = StringAttr::get(context, "portDirections");
462 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
463 nonessentialAttributes.insert(StringAttr::get(context, "name"));
464 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
465 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
466 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
467 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
468 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
469 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
470 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
471 }
472
480
481 std::string prettyPrint(Attribute attr) {
482 SmallString<64> buffer;
483 llvm::raw_svector_ostream os(buffer);
484 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
485 os << "0x";
486 if (integerAttr.getType().isSignlessInteger())
487 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
488 else
489 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
490
491 } else
492 os << attr;
493 return std::string(buffer);
494 }
495
496 // NOLINTNEXTLINE(misc-no-recursion)
497 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
498 Operation *a, BundleType aType, Operation *b,
499 BundleType bType) {
500 if (aType.getNumElements() != bType.getNumElements()) {
501 diag.attachNote(a->getLoc())
502 << message << " bundle type has different number of elements";
503 diag.attachNote(b->getLoc()) << "second operation here";
504 return failure();
505 }
506
507 for (auto elementPair :
508 llvm::zip(aType.getElements(), bType.getElements())) {
509 auto aElement = std::get<0>(elementPair);
510 auto bElement = std::get<1>(elementPair);
511 if (aElement.isFlip != bElement.isFlip) {
512 diag.attachNote(a->getLoc()) << message << " bundle element "
513 << aElement.name << " flip does not match";
514 diag.attachNote(b->getLoc()) << "second operation here";
515 return failure();
516 }
517
518 if (failed(check(diag,
519 "bundle element \'" + aElement.name.getValue() + "'", a,
520 aElement.type, b, bElement.type)))
521 return failure();
522 }
523 return success();
524 }
525
526 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
527 Operation *a, Type aType, Operation *b, Type bType) {
528 if (aType == bType)
529 return success();
530 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
531 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
532 return check(diag, message, a, aBundleType, b, bBundleType);
533 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
534 aType != bType) {
535 diag.attachNote(a->getLoc())
536 << message << ", has a RefType with a different base type "
537 << type_cast<RefType>(aType).getType()
538 << " in the same position of the two modules marked as 'must dedup'. "
539 "(This may be due to Grand Central Taps or Views being different "
540 "between the two modules.)";
541 diag.attachNote(b->getLoc())
542 << "the second module has a different base type "
543 << type_cast<RefType>(bType).getType();
544 return failure();
545 }
546 diag.attachNote(a->getLoc())
547 << message << " types don't match, first type is " << aType;
548 diag.attachNote(b->getLoc()) << "second type is " << bType;
549 return failure();
550 }
551
552 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
553 Block &aBlock, Operation *b, Block &bBlock) {
554
555 // Block argument types.
556 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
557 auto portNo = 0;
558 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
559 Operation *opDoesNotExist) {
560 StringRef portName;
561 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
562 if (portNames)
563 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
564 portName = portNameAttr.getValue();
565 if (type_isa<RefType>(existsVal.getType())) {
566 diag.attachNote(opExists->getLoc())
567 << " contains a RefType port named '" + portName +
568 "' that only exists in one of the modules (can be due to "
569 "difference in Grand Central Tap or View of two modules "
570 "marked with must dedup)";
571 diag.attachNote(opDoesNotExist->getLoc())
572 << "second module to be deduped that does not have the RefType "
573 "port";
574 } else {
575 diag.attachNote(opExists->getLoc())
576 << "port '" + portName + "' only exists in one of the modules";
577 diag.attachNote(opDoesNotExist->getLoc())
578 << "second module to be deduped that does not have the port";
579 }
580 return failure();
581 };
582
583 for (auto argPair :
584 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
585 auto &aArg = std::get<0>(argPair);
586 auto &bArg = std::get<1>(argPair);
587 if (aArg.has_value() && bArg.has_value()) {
588 // TODO: we should print the port number if there are no port names, but
589 // there are always port names ;).
590 StringRef portName;
591 if (portNames) {
592 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
593 portName = portNameAttr.getValue();
594 }
595 // Assumption here that block arguments correspond to ports.
596 if (failed(check(diag, "module port '" + portName + "'", a,
597 aArg->getType(), b, bArg->getType())))
598 return failure();
599 data.map.map(aArg.value(), bArg.value());
600 portNo++;
601 continue;
602 }
603 if (!aArg.has_value())
604 std::swap(a, b);
605 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
606 b);
607 }
608
609 // Blocks operations.
610 auto aIt = aBlock.begin();
611 auto aEnd = aBlock.end();
612 auto bIt = bBlock.begin();
613 auto bEnd = bBlock.end();
614 while (aIt != aEnd && bIt != bEnd)
615 if (failed(check(diag, data, &*aIt++, &*bIt++)))
616 return failure();
617 if (aIt != aEnd) {
618 diag.attachNote(aIt->getLoc()) << "first block has more operations";
619 diag.attachNote(b->getLoc()) << "second block here";
620 return failure();
621 }
622 if (bIt != bEnd) {
623 diag.attachNote(bIt->getLoc()) << "second block has more operations";
624 diag.attachNote(a->getLoc()) << "first block here";
625 return failure();
626 }
627 return success();
628 }
629
630 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
631 Region &aRegion, Operation *b, Region &bRegion) {
632 auto aIt = aRegion.begin();
633 auto aEnd = aRegion.end();
634 auto bIt = bRegion.begin();
635 auto bEnd = bRegion.end();
636
637 // Region blocks.
638 while (aIt != aEnd && bIt != bEnd)
639 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
640 return failure();
641 if (aIt != aEnd || bIt != bEnd) {
642 diag.attachNote(a->getLoc())
643 << "operation regions have different number of blocks";
644 diag.attachNote(b->getLoc()) << "second operation here";
645 return failure();
646 }
647 return success();
648 }
649
650 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
651 mlir::DenseBoolArrayAttr aAttr, Operation *b,
652 mlir::DenseBoolArrayAttr bAttr) {
653 if (aAttr == bAttr)
654 return success();
655 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
656 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
657 auto aDirection = aAttr[i];
658 auto bDirection = bAttr[i];
659 if (aDirection != bDirection) {
660 auto &note = diag.attachNote(a->getLoc()) << "module port ";
661 if (portNames)
662 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
663 else
664 note << i;
665 note << " directions don't match, first direction is '"
666 << direction::toString(aDirection) << "'";
667 diag.attachNote(b->getLoc()) << "second direction is '"
668 << direction::toString(bDirection) << "'";
669 return failure();
670 }
671 }
672 return success();
673 }
674
675 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
676 DictionaryAttr aDict, Operation *b,
677 DictionaryAttr bDict) {
678 // Fast path.
679 if (aDict == bDict)
680 return success();
681
682 DenseSet<Attribute> seenAttrs;
683 for (auto namedAttr : aDict) {
684 auto attrName = namedAttr.getName();
685 if (nonessentialAttributes.contains(attrName))
686 continue;
687
688 auto aAttr = namedAttr.getValue();
689 auto bAttr = bDict.get(attrName);
690 if (!bAttr) {
691 diag.attachNote(a->getLoc())
692 << "second operation is missing attribute " << attrName;
693 diag.attachNote(b->getLoc()) << "second operation here";
694 return diag;
695 }
696
697 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
698 auto bRef = cast<hw::InnerRefAttr>(bAttr);
699 auto aRef = cast<hw::InnerRefAttr>(aAttr);
700 // See if they are pointing at the same operation or port.
701 auto aTarget = data.a.lookup(aRef.getName());
702 auto bTarget = data.b.lookup(bRef.getName());
703 if (!aTarget || !bTarget)
704 diag.attachNote(a->getLoc())
705 << "malformed ir, possibly violating use-before-def";
706 auto error = [&]() {
707 diag.attachNote(a->getLoc())
708 << "operations have different targets, first operation has "
709 << aTarget;
710 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
711 return failure();
712 };
713 if (aTarget.isPort()) {
714 // If they are targeting ports, make sure its the same port number.
715 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
716 return error();
717 } else {
718 // Otherwise make sure that they are targeting the same operation.
719 if (!bTarget.isOpOnly() ||
720 aTarget.getOp() != data.map.lookup(bTarget.getOp()))
721 return error();
722 }
723 if (aTarget.getField() != bTarget.getField())
724 return error();
725 } else if (attrName == portDirectionsAttr) {
726 // Special handling for the port directions attribute for better
727 // error messages.
728 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
729 cast<mlir::DenseBoolArrayAttr>(bAttr))))
730 return failure();
731 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
732 // TODO: properly handle DistinctAttr, including its use in paths.
733 // See https://github.com/llvm/circt/issues/6583
734 } else if (aAttr != bAttr) {
735 diag.attachNote(a->getLoc())
736 << "first operation has attribute '" << attrName.getValue()
737 << "' with value " << prettyPrint(aAttr);
738 diag.attachNote(b->getLoc())
739 << "second operation has value " << prettyPrint(bAttr);
740 return failure();
741 }
742 seenAttrs.insert(attrName);
743 }
744 if (aDict.getValue().size() != bDict.getValue().size()) {
745 for (auto namedAttr : bDict) {
746 auto attrName = namedAttr.getName();
747 // Skip the attribute if we don't care about this particular one or it
748 // is one that is known to be in both dictionaries.
749 if (nonessentialAttributes.contains(attrName) ||
750 seenAttrs.contains(attrName))
751 continue;
752 // We have found an attribute that is only in the second operation.
753 diag.attachNote(a->getLoc())
754 << "first operation is missing attribute " << attrName;
755 diag.attachNote(b->getLoc()) << "second operation here";
756 return failure();
757 }
758 }
759 return success();
760 }
761
762 // NOLINTNEXTLINE(misc-no-recursion)
763 LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
764 FInstanceLike b) {
765 auto aName = a.getReferencedModuleNameAttr();
766 auto bName = b.getReferencedModuleNameAttr();
767 if (aName == bName)
768 return success();
769
770 // If the modules instantiate are different we will want to know why the
771 // sub module did not dedupliate. This code recursively checks the child
772 // module.
773 auto aModule = instanceGraph.lookup(aName)->getModule();
774 auto bModule = instanceGraph.lookup(bName)->getModule();
775 // Create a new error for the submodule.
776 diag.attachNote(std::nullopt)
777 << "in instance " << a.getInstanceNameAttr() << " of " << aName
778 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
779 check(diag, aModule, bModule);
780 return failure();
781 }
782
783 // NOLINTNEXTLINE(misc-no-recursion)
784 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
785 Operation *b) {
786 // Operation name.
787 if (a->getName() != b->getName()) {
788 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
789 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
790 return failure();
791 }
792
793 // If its an instance operaiton, perform some checking and possibly
794 // recurse.
795 if (auto aInst = dyn_cast<FInstanceLike>(a)) {
796 auto bInst = cast<FInstanceLike>(b);
797 if (failed(check(diag, aInst, bInst)))
798 return failure();
799 }
800
801 // Operation results.
802 if (a->getNumResults() != b->getNumResults()) {
803 diag.attachNote(a->getLoc())
804 << "operations have different number of results";
805 diag.attachNote(b->getLoc()) << "second operation here";
806 return failure();
807 }
808 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
809 auto &aValue = std::get<0>(resultPair);
810 auto &bValue = std::get<1>(resultPair);
811 if (failed(check(diag, "operation result", a, aValue.getType(), b,
812 bValue.getType())))
813 return failure();
814 data.map.map(aValue, bValue);
815 }
816
817 // Operations operands.
818 if (a->getNumOperands() != b->getNumOperands()) {
819 diag.attachNote(a->getLoc())
820 << "operations have different number of operands";
821 diag.attachNote(b->getLoc()) << "second operation here";
822 return failure();
823 }
824 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
825 auto &aValue = std::get<0>(operandPair);
826 auto &bValue = std::get<1>(operandPair);
827 if (bValue != data.map.lookup(aValue)) {
828 diag.attachNote(a->getLoc())
829 << "operations use different operands, first operand is '"
830 << getFieldName(
831 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
832 .first
833 << "'";
834 diag.attachNote(b->getLoc())
835 << "second operand is '"
836 << getFieldName(
837 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
838 .first
839 << "', but should have been '"
840 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
841 /*lookThroughCasts=*/true))
842 .first
843 << "'";
844 return failure();
845 }
846 }
847 data.map.map(a, b);
848
849 // Operation regions.
850 if (a->getNumRegions() != b->getNumRegions()) {
851 diag.attachNote(a->getLoc())
852 << "operations have different number of regions";
853 diag.attachNote(b->getLoc()) << "second operation here";
854 return failure();
855 }
856 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
857 auto &aRegion = std::get<0>(regionPair);
858 auto &bRegion = std::get<1>(regionPair);
859 if (failed(check(diag, data, a, aRegion, b, bRegion)))
860 return failure();
861 }
862
863 // Operation attributes.
864 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
865 b->getAttrDictionary())))
866 return failure();
867 return success();
868 }
869
870 // NOLINTNEXTLINE(misc-no-recursion)
871 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
872 hw::InnerSymbolTable aTable(a);
873 hw::InnerSymbolTable bTable(b);
874 ModuleData data(aTable, bTable);
876 diag.attachNote(a->getLoc()) << "module marked NoDedup";
877 return;
878 }
880 diag.attachNote(b->getLoc()) << "module marked NoDedup";
881 return;
882 }
883 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
884 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
885 if (!canRemoveModule(aSymbol) && !canRemoveModule(bSymbol)) {
886 diag.attachNote(a->getLoc())
887 << "module is "
888 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
889 diag.attachNote(b->getLoc())
890 << "module is "
891 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
892 return;
893 }
894 auto aGroup =
895 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
896 auto bGroup = dyn_cast_or_null<StringAttr>(
897 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
898 if (aGroup != bGroup) {
899 if (bGroup) {
900 diag.attachNote(b->getLoc())
901 << "module is in dedup group '" << bGroup.str() << "'";
902 } else {
903 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
904 }
905 if (aGroup) {
906 diag.attachNote(a->getLoc())
907 << "module is in dedup group '" << aGroup.str() << "'";
908 } else {
909 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
910 }
911 return;
912 }
913 if (failed(check(diag, data, a, b)))
914 return;
915 diag.attachNote(a->getLoc()) << "first module here";
916 diag.attachNote(b->getLoc()) << "second module here";
917 }
918
919 // This is a cached "portDirections" string attr.
921 // This is a cached "NoDedup" annotation class string attr.
922 StringAttr noDedupClass;
923 // This is a cached string attr for the dedup group attribute.
925
926 // This is a set of every attribute we should ignore.
927 DenseSet<Attribute> nonessentialAttributes;
929};
930
931//===----------------------------------------------------------------------===//
932// Deduplication
933//===----------------------------------------------------------------------===//
934
935// Custom location merging. This only keeps track of 8 annotations from ".fir"
936// files, and however many annotations come from "real" sources. When
937// deduplicating, modules tend not to have scala source locators, so we wind
938// up fusing source locators for a module from every copy being deduped. There
939// is little value in this (all the modules are identical by definition).
940static Location mergeLoc(MLIRContext *context, Location to, Location from) {
941 // Unique the set of locations to be fused.
942 llvm::SmallSetVector<Location, 4> decomposedLocs;
943 // only track 8 "fir" locations
944 unsigned seenFIR = 0;
945 for (auto loc : {to, from}) {
946 // If the location is a fused location we decompose it if it has no
947 // metadata or the metadata is the same as the top level metadata.
948 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
949 // UnknownLoc's have already been removed from FusedLocs so we can
950 // simply add all of the internal locations.
951 for (auto loc : fusedLoc.getLocations()) {
952 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
953 if (fileLoc.getFilename().strref().ends_with(".fir")) {
954 ++seenFIR;
955 if (seenFIR > 8)
956 continue;
957 }
958 }
959 decomposedLocs.insert(loc);
960 }
961 continue;
962 }
963
964 // Might need to skip this fir.
965 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
966 if (fileLoc.getFilename().strref().ends_with(".fir")) {
967 ++seenFIR;
968 if (seenFIR > 8)
969 continue;
970 }
971 }
972 // Otherwise, only add known locations to the set.
973 if (!isa<UnknownLoc>(loc))
974 decomposedLocs.insert(loc);
975 }
976
977 auto locs = decomposedLocs.getArrayRef();
978
979 // Handle the simple cases of less than two locations. Ensure the metadata (if
980 // provided) is not dropped.
981 if (locs.empty())
982 return UnknownLoc::get(context);
983 if (locs.size() == 1)
984 return locs.front();
985
986 return FusedLoc::get(context, locs);
987}
988
989struct Deduper {
990
991 using RenameMap = DenseMap<StringAttr, StringAttr>;
992
994 NLATable *nlaTable, CircuitOp circuit)
995 : context(circuit->getContext()), instanceGraph(instanceGraph),
997 nlaBlock(circuit.getBodyBlock()),
998 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
999 classString(StringAttr::get(context, "class")) {
1000 // Populate the NLA cache.
1001 for (auto nla : circuit.getOps<hw::HierPathOp>())
1002 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
1003 }
1004
1005 /// Remove the "fromModule", and replace all references to it with the
1006 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
1007 /// module which is not deduplicated needs to be recorded with the `record`
1008 /// call.
1009 void dedup(FModuleLike toModule, FModuleLike fromModule) {
1010 // A map of operation (e.g. wires, nodes) names which are changed, which is
1011 // used to update NLAs that reference the "fromModule".
1012 RenameMap renameMap;
1013
1014 // Merge the port locations.
1015 SmallVector<Attribute> newLocs;
1016 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
1017 fromModule.getPortLocations())) {
1018 if (toLoc == fromLoc)
1019 newLocs.push_back(toLoc);
1020 else
1021 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
1022 cast<LocationAttr>(fromLoc)));
1023 }
1024 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
1025
1026 // Merge the two modules.
1027 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
1028
1029 // Rewrite NLAs pathing through these modules to refer to the to module. It
1030 // is safe to do this at this point because NLAs cannot be one element long.
1031 // This means that all NLAs which require more context cannot be targetting
1032 // something in the module it self.
1033 if (auto to = dyn_cast<FModuleOp>(*toModule))
1034 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
1035 else
1036 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
1037 fromModule.getModuleNameAttr());
1038
1039 replaceInstances(toModule, fromModule);
1040 }
1041
1042 /// Record the usages of any NLA's in this module, so that we may update the
1043 /// annotation if the parent module is deduped with another module.
1044 void record(FModuleLike module) {
1045 // Record any annotations on the module.
1046 recordAnnotations(module);
1047 // Record port annotations.
1048 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
1050 // Record any annotations in the module body.
1051 module->walk([&](Operation *op) { recordAnnotations(op); });
1052 }
1053
1054private:
1055 /// Get a cached namespace for a module.
1057 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
1058 .first->second;
1059 }
1060
1061 /// For a specific annotation target, record all the unique NLAs which
1062 /// target it in the `targetMap`.
1064 for (auto anno : target.getAnnotations())
1065 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1066 targetMap[nlaRef.getAttr()].insert(target);
1067 }
1068
1069 /// Record all targets which use an NLA.
1070 void recordAnnotations(Operation *op) {
1071 // Record annotations.
1072 recordAnnotations(OpAnnoTarget(op));
1073
1074 // Record port annotations only if this is a mem operation.
1075 auto mem = dyn_cast<MemOp>(op);
1076 if (!mem)
1077 return;
1078
1079 // Record port annotations.
1080 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1081 recordAnnotations(PortAnnoTarget(mem, i));
1082 }
1083
1084 /// This deletes and replaces all instances of the "fromModule" with instances
1085 /// of the "toModule".
1086 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1087 // Replace all instances of the other module.
1088 auto *fromNode =
1089 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1090 auto *toNode = instanceGraph[toModule];
1091 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1092 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1093 auto inst = oldInstRec->getInstance();
1094 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1095 instOp.setModuleNameAttr(toModuleRef);
1096 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1097 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1098 auto classLike = cast<ClassLike>(*toNode->getModule());
1099 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1100 objectOp.getResult().setType(classType);
1101 }
1102 oldInstRec->getParent()->addInstance(inst, toNode);
1103 oldInstRec->erase();
1104 }
1105 instanceGraph.erase(fromNode);
1106 fromModule->erase();
1107 }
1108
1109 /// Look up the instantiations of the `from` module and create an NLA for each
1110 /// one, appending the baseNamepath to each NLA. This is used to add more
1111 /// context to an already existing NLA. The `fromModule` is used to indicate
1112 /// which module the annotation is coming from before the merge, and will be
1113 /// used to create the namepaths.
1114 SmallVector<FlatSymbolRefAttr>
1115 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1116 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1117 // Create an attribute array with a placeholder in the first element, where
1118 // the root refence of the NLA will be inserted.
1119 SmallVector<Attribute> namepath = {nullptr};
1120 namepath.append(baseNamepath.begin(), baseNamepath.end());
1121
1122 auto loc = fromModule->getLoc();
1123 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1124 SmallVector<FlatSymbolRefAttr> nlas;
1125 for (auto *instanceRecord : fromNode->uses()) {
1126 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1127 auto inst = instanceRecord->getInstance();
1128 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1129 auto arrayAttr = ArrayAttr::get(context, namepath);
1130 // Check the NLA cache to see if we already have this NLA.
1131 auto &cacheEntry = nlaCache[arrayAttr];
1132 if (!cacheEntry) {
1133 auto builder = OpBuilder::atBlockBegin(nlaBlock);
1134 auto nla = hw::HierPathOp::create(builder, loc, "nla", arrayAttr);
1135 // Insert it into the symbol table to get a unique name.
1136 symbolTable.insert(nla);
1137 // Store it in the cache.
1138 cacheEntry = nla.getNameAttr();
1139 nla.setVisibility(vis);
1140 nlaTable->addNLA(nla);
1141 }
1142 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1143 nlas.push_back(nlaRef);
1144 }
1145 return nlas;
1146 }
1147
1148 /// Look up the instantiations of this module and create an NLA for each one.
1149 /// This returns an array of symbol references which can be used to reference
1150 /// the NLAs.
1151 SmallVector<FlatSymbolRefAttr>
1152 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1153 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1154 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1155 }
1156
1157 /// Clone the annotation for each NLA in a list. The attribute list should
1158 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1159 /// should be the index of this field.
1160 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1161 Annotation anno, ArrayRef<NamedAttribute> attributes,
1162 unsigned nonLocalIndex,
1163 SmallVectorImpl<Annotation> &newAnnotations) {
1164 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1165 attributes.end());
1166 for (auto &nla : nlas) {
1167 // Add the new annotation.
1168 mutableAttributes[nonLocalIndex].setValue(nla);
1169 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1170 // The original annotation records if its a subannotation.
1171 anno.setDict(dict);
1172 newAnnotations.push_back(anno);
1173 }
1174 }
1175
1176 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1177 /// but it does not delete the NLA reference from the target operation's
1178 /// annotations.
1179 void eraseNLA(hw::HierPathOp nla) {
1180 // Erase the NLA from the leaf module's nlaMap.
1181 targetMap.erase(nla.getNameAttr());
1182 nlaTable->erase(nla);
1183 nlaCache.erase(nla.getNamepathAttr());
1184 symbolTable.erase(nla);
1185 }
1186
1187 /// Process all NLAs referencing the "from" module to point to the "to"
1188 /// module. This is used after merging two modules together.
1189 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1190 FModuleOp fromModule) {
1191 auto toName = toModule.getNameAttr();
1192 auto fromName = fromModule.getNameAttr();
1193 // Create a copy of the current NLAs. We will be pushing and removing
1194 // NLAs from this op as we go.
1195 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1196 // Change the NLA to target the toModule.
1197 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1198 // Now we walk the NLA searching for ones that require more context to be
1199 // added.
1200 for (auto nla : moduleNLAs) {
1201 auto elements = nla.getNamepath().getValue();
1202 // If we don't need to add more context, we're done here.
1203 if (nla.root() != toName)
1204 continue;
1205 // Create the replacement NLAs.
1206 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1207 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1208 // Copy out the targets, because we will be updating the map.
1209 auto &set = targetMap[nla.getSymNameAttr()];
1210 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1211 // Replace the uses of the old NLA with the new NLAs.
1212 for (auto target : targets) {
1213 // We have to clone any annotation which uses the old NLA for each new
1214 // NLA. This array collects the new set of annotations.
1215 SmallVector<Annotation> newAnnotations;
1216 for (auto anno : target.getAnnotations()) {
1217 // Find the non-local field of the annotation.
1218 auto [it, found] = mlir::impl::findAttrSorted(
1219 anno.begin(), anno.end(), nonLocalString);
1220 // If this annotation doesn't use the target NLA, copy it with no
1221 // changes.
1222 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1223 nla.getSymNameAttr()) {
1224 newAnnotations.push_back(anno);
1225 continue;
1226 }
1227 auto nonLocalIndex = std::distance(anno.begin(), it);
1228 // Clone the annotation and add it to the list of new annotations.
1229 cloneAnnotation(nlaRefs, anno,
1230 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1231 nonLocalIndex, newAnnotations);
1232 }
1233
1234 // Apply the new annotations to the operation.
1235 AnnotationSet annotations(newAnnotations, context);
1236 target.setAnnotations(annotations);
1237 // Record that target uses the NLA.
1238 for (auto nla : nlaRefs)
1239 targetMap[nla.getAttr()].insert(target);
1240 }
1241
1242 // Erase the old NLA and remove it from all breadcrumbs.
1243 eraseNLA(nla);
1244 }
1245 }
1246
1247 /// Process all the NLAs that the two modules participate in, replacing
1248 /// references to the "from" module with references to the "to" module, and
1249 /// adding more context if necessary.
1250 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1251 FModuleOp fromModule) {
1252 addAnnotationContext(renameMap, toModule, toModule);
1253 addAnnotationContext(renameMap, toModule, fromModule);
1254 }
1255
1256 // Update all NLAs which the "from" external module participates in to the
1257 // "toName".
1258 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1259 StringAttr fromName) {
1260 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1261 }
1262
1263 /// Take an annotation, and update it to be a non-local annotation. If the
1264 /// annotation is already non-local and has enough context, it will be skipped
1265 /// for now. Return true if the annotation was made non-local.
1266 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1267 FModuleLike fromModule, Annotation anno,
1268 SmallVectorImpl<Annotation> &newAnnotations) {
1269 // Start constructing a new annotation, pushing a "circt.nonLocal" field
1270 // into the correct spot if its not already a non-local annotation.
1271 SmallVector<NamedAttribute> attributes;
1272 int nonLocalIndex = -1;
1273 for (const auto &val : llvm::enumerate(anno)) {
1274 auto attr = val.value();
1275 // Is this field "circt.nonlocal"?
1276 auto compare = attr.getName().compare(nonLocalString);
1277 assert(compare != 0 && "should not pass non-local annotations here");
1278 if (compare == 1) {
1279 // This annotation definitely does not have "circt.nonlocal" field. Push
1280 // an empty place holder for the non-local annotation.
1281 nonLocalIndex = val.index();
1282 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1283 break;
1284 }
1285 // Otherwise push the current attribute and keep searching for the
1286 // "circt.nonlocal" field.
1287 attributes.push_back(attr);
1288 }
1289 if (nonLocalIndex == -1) {
1290 // Push an empty "circt.nonlocal" field to the last slot.
1291 nonLocalIndex = attributes.size();
1292 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1293 } else {
1294 // Copy the remaining annotation fields in.
1295 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1296 }
1297
1298 // Construct the NLAs if we don't have any yet.
1299 auto nlaRefs = createNLAs(toModuleName, fromModule);
1300 for (auto nla : nlaRefs)
1301 targetMap[nla.getAttr()].insert(to);
1302
1303 // Clone the annotation for each new NLA.
1304 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1305 return true;
1306 }
1307
1308 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1309 FModuleLike fromModule, AnnotationSet annos,
1310 SmallVectorImpl<Annotation> &newAnnotations,
1311 SmallPtrSetImpl<Attribute> &dontTouches) {
1312 for (auto anno : annos) {
1313 if (anno.isClass(dontTouchAnnoClass)) {
1314 // Remove the nonlocal field of the annotation if it has one, since this
1315 // is a sticky annotation.
1316 anno.removeMember("circt.nonlocal");
1317 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1318 if (inserted)
1319 newAnnotations.push_back(anno);
1320 continue;
1321 }
1322 // If the annotation is already non-local, we add it as is. It is already
1323 // added to the target map.
1324 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1325 newAnnotations.push_back(anno);
1326 targetMap[nla.getAttr()].insert(to);
1327 continue;
1328 }
1329 // Otherwise make the annotation non-local and add it to the set.
1330 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1331 newAnnotations);
1332 }
1333 }
1334
1335 /// Merge the annotations of a specific target, either a operation or a port
1336 /// on an operation.
1337 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1338 AnnotationSet toAnnos, FModuleLike fromModule,
1339 AnnoTarget from, AnnotationSet fromAnnos) {
1340 // This is a list of all the annotations which will be added to `to`.
1341 SmallVector<Annotation> newAnnotations;
1342
1343 // We have special case handling of DontTouch to prevent it from being
1344 // turned into a non-local annotation, and to remove duplicates.
1345 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1346
1347 // Iterate the annotations, transforming most annotations into non-local
1348 // ones.
1349 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1350 dontTouches);
1351 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1352 dontTouches);
1353
1354 // Copy over all the new annotations.
1355 if (!newAnnotations.empty())
1356 to.setAnnotations(AnnotationSet(newAnnotations, context));
1357 }
1358
1359 /// Merge all annotations and port annotations on two operations.
1360 void mergeAnnotations(FModuleLike toModule, Operation *to,
1361 FModuleLike fromModule, Operation *from) {
1362 // Merge op annotations.
1363 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1364 OpAnnoTarget(from), AnnotationSet(from));
1365
1366 // Merge port annotations.
1367 if (toModule == to) {
1368 // Merge module port annotations.
1369 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1370 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1371 AnnotationSet::forPort(toModule, i), fromModule,
1372 PortAnnoTarget(fromModule, i),
1373 AnnotationSet::forPort(fromModule, i));
1374 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1375 // Merge memory port annotations.
1376 auto fromMem = cast<MemOp>(from);
1377 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1378 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1379 AnnotationSet::forPort(toMem, i), fromModule,
1380 PortAnnoTarget(fromMem, i),
1381 AnnotationSet::forPort(fromMem, i));
1382 }
1383 }
1384
1385 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1386 hw::InnerSymAttr toSym,
1387 hw::InnerSymAttr fromSym) {
1388 if (fromSym && !fromSym.getProps().empty()) {
1389 auto &isn = getNamespace(toModule);
1390 // The properties for the new inner symbol..
1391 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1392 // If the "to" op already has an inner symbol, copy all its properties.
1393 if (toSym)
1394 llvm::append_range(newProps, toSym);
1395 // Add each property from the fromSym to the toSym.
1396 for (auto fromProp : fromSym) {
1397 hw::InnerSymPropertiesAttr newProp;
1398 auto *it = llvm::find_if(newProps, [&](auto p) {
1399 return p.getFieldID() == fromProp.getFieldID();
1400 });
1401 if (it != newProps.end()) {
1402 // If we already have an inner sym with the same field id, use
1403 // that.
1404 newProp = *it;
1405 // If the old symbol is public, we need to make the new one public.
1406 if (fromProp.getSymVisibility().getValue() == "public" &&
1407 newProp.getSymVisibility().getValue() != "public") {
1408 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1409 newProp.getFieldID(),
1410 fromProp.getSymVisibility());
1411 }
1412 } else {
1413 // We need to add a new property to the inner symbol for this field.
1414 auto newName = isn.newName(fromProp.getName().getValue());
1415 newProp = hw::InnerSymPropertiesAttr::get(
1416 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1417 fromProp.getSymVisibility());
1418 newProps.push_back(newProp);
1419 }
1420 renameMap[fromProp.getName()] = newProp.getName();
1421 }
1422 // Sort the fields by field id.
1423 llvm::sort(newProps, [](auto &p, auto &q) {
1424 return p.getFieldID() < q.getFieldID();
1425 });
1426 // Return the merged inner symbol.
1427 return hw::InnerSymAttr::get(context, newProps);
1428 }
1429 return hw::InnerSymAttr();
1430 }
1431
1432 // Record the symbol name change of the operation or any of its ports when
1433 // merging two operations. The renamed symbols are used to update the
1434 // target of any NLAs. This will add symbols to the "to" operation if needed.
1435 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1436 Operation *to, FModuleLike fromModule,
1437 Operation *from) {
1438 // If the "from" operation has an inner_sym, we need to make sure the
1439 // "to" operation also has an `inner_sym` and then record the renaming.
1440 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1441 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1442 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1443 toInnerSym.getInnerSymAttr(),
1444 fromInnerSym.getInnerSymAttr()))
1445 toInnerSym.setInnerSymbolAttr(newSymAttr);
1446 }
1447
1448 // If there are no port symbols on the "from" operation, we are done here.
1449 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1450 if (!fromPortSyms || fromPortSyms.empty())
1451 return;
1452 // We have to map each "fromPort" to each "toPort".
1453 auto portCount = fromPortSyms.size();
1454 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1455
1456 // Create an array of new port symbols for the "to" operation, copy in the
1457 // old symbols if it has any, create an empty symbol array if it doesn't.
1458 SmallVector<Attribute> newPortSyms;
1459 if (toPortSyms.empty())
1460 newPortSyms.assign(portCount, hw::InnerSymAttr());
1461 else
1462 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1463
1464 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1465 if (auto newPortSym = mergeInnerSymbols(
1466 renameMap, toModule,
1467 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1468 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1469 newPortSyms[portNo] = newPortSym;
1470 }
1471 }
1472
1473 // Commit the new symbol attribute.
1474 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1475 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1476 }
1477
1478 /// Recursively merge two operations.
1479 // NOLINTNEXTLINE(misc-no-recursion)
1480 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1481 FModuleLike fromModule, Operation *from) {
1482 // Merge the operation locations.
1483 if (to->getLoc() != from->getLoc())
1484 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1485
1486 // Recurse into any regions.
1487 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1488 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1489 std::get<1>(regions));
1490
1491 // Record any inner_sym renamings that happened.
1492 recordSymRenames(renameMap, toModule, to, fromModule, from);
1493
1494 // Merge the annotations.
1495 mergeAnnotations(toModule, to, fromModule, from);
1496 }
1497
1498 /// Recursively merge two blocks.
1499 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1500 FModuleLike fromModule, Block &fromBlock) {
1501 // Merge the block locations.
1502 for (auto [toArg, fromArg] :
1503 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1504 if (toArg.getLoc() != fromArg.getLoc())
1505 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1506
1507 for (auto ops : llvm::zip(toBlock, fromBlock))
1508 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1509 &std::get<1>(ops));
1510 }
1511
1512 // Recursively merge two regions.
1513 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1514 Region &toRegion, FModuleLike fromModule,
1515 Region &fromRegion) {
1516 for (auto blocks : llvm::zip(toRegion, fromRegion))
1517 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1518 std::get<1>(blocks));
1519 }
1520
1521 MLIRContext *context;
1523 SymbolTable &symbolTable;
1524
1525 /// Cached nla table analysis.
1526 NLATable *nlaTable = nullptr;
1527
1528 /// We insert all NLAs to the beginning of this block.
1529 Block *nlaBlock;
1530
1531 // This maps an NLA to the operations and ports that uses it.
1532 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1533
1534 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAtr
1535 // of the NLA's path to the name of the NLA which contains it.
1536 DenseMap<Attribute, Attribute> nlaCache;
1537
1538 // Cached attributes for faster comparisons and attribute building.
1539 StringAttr nonLocalString;
1540 StringAttr classString;
1541
1542 /// A module namespace cache.
1543 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1544};
1545
1546//===----------------------------------------------------------------------===//
1547// Fixup
1548//===----------------------------------------------------------------------===//
1549
1550/// This fixes up connects when the field names of a bundle type changes. It
1551/// finds all fields which were previously bulk connected and legalizes it
1552/// into a connect for each field.
1553static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1554 // If the types already match we can emit a connect.
1555 auto dstType = dst.getType();
1556 auto srcType = src.getType();
1557 if (dstType == srcType) {
1558 emitConnect(builder, dst, src);
1559 return;
1560 }
1561 // It must be a bundle type and the field name has changed. We have to
1562 // manually decompose the bulk connect into a connect for each field.
1563 auto dstBundle = type_cast<BundleType>(dstType);
1564 auto srcBundle = type_cast<BundleType>(srcType);
1565 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1566 auto dstField = SubfieldOp::create(builder, dst, i);
1567 auto srcField = SubfieldOp::create(builder, src, i);
1568 if (dstBundle.getElement(i).isFlip) {
1569 std::swap(srcBundle, dstBundle);
1570 std::swap(srcField, dstField);
1571 }
1572 fixupConnect(builder, dstField, srcField);
1573 }
1574}
1575
1576/// Adjust the symbol references in an op. Thsi includes updating its attributes
1577/// and types.
1578static void
1579fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph,
1580 const DenseMap<Attribute, StringAttr> &dedupMap) {
1581 // If this is an instance op, dedup may have subtly changed the port types.
1582 // For example, structurally different bundles may still dedup. In this case
1583 // we now have an instance op that produces result values of the old type, but
1584 // the port info on the instantiated module already represents the new type.
1585 // Fix this up by going through an intermediate wire.
1586 if (auto instOp = dyn_cast<InstanceOp>(op)) {
1587 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp->getContext());
1588 builder.setInsertionPointAfter(instOp);
1589 auto module = instanceGraph.lookup(instOp.getModuleNameAttr().getAttr())
1590 ->getModule<FModuleLike>();
1591 for (auto [index, result] : llvm::enumerate(instOp.getResults())) {
1592 auto newType = module.getPortType(index);
1593 auto oldType = result.getType();
1594 // If the type has not changed, we don't have to fix up anything.
1595 if (newType == oldType)
1596 continue;
1597 LLVM_DEBUG(llvm::dbgs()
1598 << "- Updating instance port \"" << instOp.getInstanceName()
1599 << "." << instOp.getPortName(index).getValue() << "\" from "
1600 << oldType << " to " << newType << "\n");
1601
1602 // If the type changed we transform it back to the old type with an
1603 // intermediate wire.
1604 auto wire = WireOp::create(builder, oldType, instOp.getPortName(index))
1605 .getResult();
1606 result.replaceAllUsesWith(wire);
1607 result.setType(newType);
1608 if (instOp.getPortDirection(index) == Direction::Out)
1609 fixupConnect(builder, wire, result);
1610 else
1611 fixupConnect(builder, result, wire);
1612 }
1613 }
1614
1615 // Use an attribute/type replacer to look for references to old symbols that
1616 // need to be replaced with new symbols.
1617 mlir::AttrTypeReplacer replacer;
1618 replacer.addReplacement([&](FlatSymbolRefAttr symRef) {
1619 auto oldName = symRef.getAttr();
1620 auto newName = dedupMap.lookup(oldName);
1621 if (newName && newName != oldName) {
1622 auto newSymRef = FlatSymbolRefAttr::get(newName);
1623 LLVM_DEBUG(llvm::dbgs()
1624 << "- Updating " << symRef << " to " << newSymRef << " in "
1625 << op->getName() << " at " << op->getLoc() << "\n");
1626 return newSymRef;
1627 }
1628 return symRef;
1629 });
1630
1631 // Update attributes.
1632 op->setAttrs(cast<DictionaryAttr>(replacer.replace(op->getAttrDictionary())));
1633
1634 // Update the argument types.
1635 for (auto &region : op->getRegions())
1636 for (auto &block : region)
1637 for (auto arg : block.getArguments())
1638 arg.setType(replacer.replace(arg.getType()));
1639
1640 // Update result types.
1641 for (auto result : op->getResults())
1642 result.setType(replacer.replace(result.getType()));
1643}
1644
1645 /// Adjust the symbol references in ops marked as sensitive to them. This
1646 /// includes updating their attributes and types.
// NOTE(review): the line carrying this function's return type and name is
// absent from this listing, so only the parameter list and body are shown.
// Parameters as used by the body:
//   instanceGraph      - iterated to visit every module node; also passed on
//                        to fixupSymbolSensitiveOp.
//   moduleToModuleInfo - maps a module to its module info, which holds the
//                        list of symbol-sensitive ops of that module. Modules
//                        without an entry are skipped.
//   dedupMap           - passed on unchanged to fixupSymbolSensitiveOp.
1648 InstanceGraph &instanceGraph,
1649 const DenseMap<Operation *, ModuleInfoRef> &moduleToModuleInfo,
1650 const DenseMap<Attribute, StringAttr> &dedupMap) {
1651 for (auto *node : instanceGraph) {
1652 // Look up the module info for this module, which contains the list of ops
1653 // that need to be updated.
1654 auto module = node->getModule<FModuleLike>();
1655 auto it = moduleToModuleInfo.find(module);
1656 if (it == moduleToModuleInfo.end())
1657 continue;
1658
1659 // Update each symbol-sensitive op individually.
// Modules with no symbol-sensitive ops are skipped before the debug output.
1660 auto &ops = it->second.info->symbolSensitiveOps;
1661 if (ops.empty())
1662 continue;
1663 LLVM_DEBUG(llvm::dbgs()
1664 << "- Updating " << ops.size() << " symbol-sensitive ops in "
1665 << module.getNameAttr() << "\n");
1666 for (auto *op : ops)
1667 fixupSymbolSensitiveOp(op, instanceGraph, dedupMap);
1668 }
1669}
1670
1671//===----------------------------------------------------------------------===//
1672// DedupPass
1673//===----------------------------------------------------------------------===//
1674
1675namespace {
1676class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1677 void runOnOperation() override {
1678 auto *context = &getContext();
1679 auto circuit = getOperation();
1680 auto &instanceGraph = getAnalysis<InstanceGraph>();
1681 auto *nlaTable = &getAnalysis<NLATable>();
1682 auto &symbolTable = getAnalysis<SymbolTable>();
1683 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1684 Equivalence equiv(context, instanceGraph);
1685 auto anythingChanged = false;
1686 LLVM_DEBUG({
1687 llvm::dbgs() << "\n";
1688 debugHeader(Twine("Dedup circuit \"") + circuit.getName() + "\"")
1689 << "\n\n";
1690 });
1691
1692 // Modules annotated with this should not be considered for deduplication.
1693 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1694
1695 // Only modules within the same group may be deduplicated.
1696 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1697
1698 // A map of all the module moduleInfo that we have calculated so far.
1699 DenseMap<ModuleInfoRef, Operation *> moduleInfoToModule;
1700 DenseMap<Operation *, ModuleInfoRef> moduleToModuleInfo;
1701
1702 // We track the name of the module that each module is deduped into, so that
1703 // we can make sure all modules which are marked "must dedup" with each
1704 // other were all deduped to the same module.
1705 DenseMap<Attribute, StringAttr> dedupMap;
1706
1707 // We must iterate the modules from the bottom up so that we can properly
1708 // deduplicate the modules. We copy the list of modules into a vector first
1709 // to avoid iterator invalidation while we mutate the instance graph.
1710 SmallVector<FModuleLike, 0> modules;
1711 instanceGraph.walkPostOrder([&](auto &node) {
1712 if (auto mod = dyn_cast<FModuleLike>(*node.getModule()))
1713 modules.push_back(mod);
1714 });
1715 LLVM_DEBUG(llvm::dbgs() << "Found " << modules.size() << " modules\n");
1716
1717 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1718 StructuralHasherSharedConstants hasherConstants(&getContext());
1719
1720 // Attribute name used to store dedup_group for this pass.
1721 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1722
1723 // Move dedup group annotations to attributes on the module.
1724 // This results in the desired behavior (included in hash),
1725 // and avoids unnecessary processing of these as annotations
1726 // that need to be tracked, made non-local, so on.
1727 for (auto module : modules) {
1728 llvm::SmallSetVector<StringAttr, 1> groups;
1730 module, [&groups, dedupGroupClass](Annotation annotation) {
1731 if (annotation.getClassAttr() != dedupGroupClass)
1732 return false;
1733 groups.insert(annotation.getMember<StringAttr>("group"));
1734 return true;
1735 });
1736 if (groups.size() > 1) {
1737 module.emitError("module belongs to multiple dedup groups: ") << groups;
1738 return signalPassFailure();
1739 }
1740 assert(!module->hasAttr(dedupGroupAttrName) &&
1741 "unexpected existing use of temporary dedup group attribute");
1742 if (!groups.empty())
1743 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1744 }
1745
1746 // Calculate module information parallelly.
1747 LLVM_DEBUG(llvm::dbgs() << "Computing module information\n");
1748 auto result = mlir::failableParallelForEach(
1749 context, llvm::seq(modules.size()), [&](unsigned idx) {
1750 auto module = modules[idx];
1751 // If the module is marked with NoDedup, just skip it.
1752 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1753 return success();
1754
1755 // Only dedup extmodule's with defname.
1756 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1757 ext && !ext.getDefname().has_value())
1758 return success();
1759
1760 StructuralHasher hasher(hasherConstants);
1761 // Calculate the hash of the module and referred module names.
1762 moduleInfos[idx] = hasher.getModuleInfo(module);
1763 return success();
1764 });
1765
1766 // Dump out the module hashes for debugging.
1767 LLVM_DEBUG({
1768 auto &os = llvm::dbgs();
1769 for (auto [module, info] : llvm::zip(modules, moduleInfos)) {
1770 os << "- Hash ";
1771 if (info) {
1772 os << llvm::format_bytes(info->structuralHash, std::nullopt, 32, 32);
1773 } else {
1774 os << "--------------------------------";
1775 os << "--------------------------------";
1776 }
1777 os << " for " << module.getModuleNameAttr() << "\n";
1778 }
1779 });
1780
1781 if (result.failed())
1782 return signalPassFailure();
1783
1784 LLVM_DEBUG(llvm::dbgs() << "Update modules\n");
1785 for (auto [i, module] : llvm::enumerate(modules)) {
1786 auto moduleName = module.getModuleNameAttr();
1787 auto &maybeModuleInfo = moduleInfos[i];
1788 // If the hash was not calculated, we need to skip it.
1789 if (!maybeModuleInfo) {
1790 // We record it in the dedup map to help detect errors when the user
1791 // marks the module as both NoDedup and MustDedup. We do not record this
1792 // module in the hasher to make sure no other module dedups "into" this
1793 // one.
1794 dedupMap[moduleName] = moduleName;
1795 continue;
1796 }
1797
1798 auto &moduleInfo = maybeModuleInfo.value();
1799 moduleToModuleInfo.try_emplace(module, &moduleInfo);
1800
1801 // Replace module names referred in the module with new names.
1802 for (auto &referredModule : moduleInfo.referredModuleNames)
1803 referredModule = dedupMap[referredModule];
1804
1805 // Check if there is a module with the same hash.
1806 auto it = moduleInfoToModule.find(&moduleInfo);
1807 if (it != moduleInfoToModule.end()) {
1808 auto original = cast<FModuleLike>(it->second);
1809 auto originalName = original.getModuleNameAttr();
1810
1811 // If the current module is public, and the original is private, we
1812 // want to dedup the private module into the public one.
1813 if (!canRemoveModule(module)) {
1814 // If both modules are public, then we can't dedup anything.
1815 if (!canRemoveModule(original))
1816 continue;
1817 // Swap the canonical module in the dedup map.
1818 for (auto &[originalName, dedupedName] : dedupMap)
1819 if (dedupedName == originalName)
1820 dedupedName = moduleName;
1821 // Update the module hash table to point to the new original, so all
1822 // future modules dedup with the new canonical module.
1823 it->second = module;
1824 // Swap the locals.
1825 std::swap(originalName, moduleName);
1826 std::swap(original, module);
1827 }
1828
1829 // Record the group ID of the other module.
1830 LLVM_DEBUG(llvm::dbgs() << "- Replace " << moduleName << " with "
1831 << originalName << "\n");
1832 dedupMap[moduleName] = originalName;
1833 deduper.dedup(original, module);
1834 ++erasedModules;
1835 anythingChanged = true;
1836 continue;
1837 }
1838 // Any module not deduplicated must be recorded.
1839 deduper.record(module);
1840 // Add the module to a new dedup group.
1841 dedupMap[moduleName] = moduleName;
1842 // Record the module info.
1843 moduleInfoToModule[&moduleInfo] = module;
1844 }
1845
1846 // This part verifies that all modules marked by "MustDedup" have been
1847 // properly deduped with each other. For this check to succeed, all modules
1848 // have to been deduped to the same module. It is possible that a module was
1849 // deduped with the wrong thing.
1850
1851 auto failed = false;
1852 // This parses the module name out of a target string.
1853 auto parseModule = [&](Attribute path) -> StringAttr {
1854 // Each module is listed as a target "~Circuit|Module" which we have to
1855 // parse.
1856 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1857 return StringAttr::get(context, rhs);
1858 };
1859 // This gets the name of the module which the current module was deduped
1860 // with. If the named module isn't in the map, then we didn't encounter it
1861 // in the circuit.
1862 auto getLead = [&](StringAttr module) -> StringAttr {
1863 auto it = dedupMap.find(module);
1864 if (it == dedupMap.end()) {
1865 auto diag = emitError(circuit.getLoc(),
1866 "MustDeduplicateAnnotation references module ")
1867 << module << " which does not exist";
1868 failed = true;
1869 return nullptr;
1870 }
1871 return it->second;
1872 };
1873
1874 LLVM_DEBUG(llvm::dbgs() << "Update annotations\n");
1875 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1876 if (!annotation.isClass(mustDedupAnnoClass))
1877 return false;
1878 auto modules = annotation.getMember<ArrayAttr>("modules");
1879 if (!modules) {
1880 emitError(circuit.getLoc(),
1881 "MustDeduplicateAnnotation missing \"modules\" member");
1882 failed = true;
1883 return false;
1884 }
1885 // Empty module list has nothing to process.
1886 if (modules.empty())
1887 return true;
1888 // Get the first element.
1889 auto firstModule = parseModule(modules[0]);
1890 auto firstLead = getLead(firstModule);
1891 if (!firstLead)
1892 return false;
1893 // Verify that the remaining elements are all the same as the first.
1894 for (auto attr : modules.getValue().drop_front()) {
1895 auto nextModule = parseModule(attr);
1896 auto nextLead = getLead(nextModule);
1897 if (!nextLead)
1898 return false;
1899 if (firstLead != nextLead) {
1900 auto diag = emitError(circuit.getLoc(), "module ")
1901 << nextModule << " not deduplicated with " << firstModule;
1902 auto a = instanceGraph.lookup(firstLead)->getModule();
1903 auto b = instanceGraph.lookup(nextLead)->getModule();
1904 equiv.check(diag, a, b);
1905 failed = true;
1906 return false;
1907 }
1908 }
1909 return true;
1910 });
1911 if (failed)
1912 return signalPassFailure();
1913
1914 // Remove all dedup group attributes, they only exist during this pass.
1915 for (auto module : circuit.getOps<FModuleLike>())
1916 module->removeDiscardableAttr(dedupGroupAttrName);
1917
1918 // Fixup all operations that we've found to be sensitive to symbol names.
1919 // This includes module and class instances, wires, connects, etc.
1920 fixupSymbolSensitiveOps(instanceGraph, moduleToModuleInfo, dedupMap);
1921
1922 markAnalysesPreserved<NLATable>();
1923 if (!anythingChanged)
1924 markAllAnalysesPreserved();
1925 }
1926};
1927} // end anonymous namespace
assert(baseType &&"element must be base type")
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:940
static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition Dedup.cpp:1553
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:55
static void fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in an op.
Definition Dedup.cpp:1579
static void fixupSymbolSensitiveOps(InstanceGraph &instanceGraph, const DenseMap< Operation *, ModuleInfoRef > &moduleToModuleInfo, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in ops marked as sensitive to them.
Definition Dedup.cpp:1647
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
decltype(auto) walkPostOrder(Fn &&fn)
Perform a post-order walk across the modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
constexpr const char * dedupGroupAnnoClass
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(const llvm::Twine &str, unsigned width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition Debug.cpp:17
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1115
MLIRContext * context
Definition Dedup.cpp:1521
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1529
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1070
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1179
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1360
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1086
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:1044
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1258
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1513
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:1009
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1250
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1152
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:1063
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1526
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1385
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1160
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1435
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either a operation or a port on an operation.
Definition Dedup.cpp:1337
StringAttr nonLocalString
Definition Dedup.cpp:1539
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:1056
SymbolTable & symbolTable
Definition Dedup.cpp:1523
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1480
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1543
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1266
InstanceGraph & instanceGraph
Definition Dedup.cpp:1522
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1499
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1532
StringAttr classString
Definition Dedup.cpp:1540
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1308
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:993
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1189
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:991
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1536
const hw::InnerSymbolTable & a
Definition Dedup.cpp:477
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:474
const hw::InnerSymbolTable & b
Definition Dedup.cpp:478
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:456
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:927
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:481
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition Dedup.cpp:763
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:552
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:526
StringAttr noDedupClass
Definition Dedup.cpp:922
StringAttr dedupGroupAttrName
Definition Dedup.cpp:924
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:675
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:630
StringAttr portDirectionsAttr
Definition Dedup.cpp:920
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:497
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:784
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:457
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:650
InstanceGraph & instanceGraph
Definition Dedup.cpp:928
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:871
A reference to a ModuleInfo that compares and hashes like it.
Definition Dedup.cpp:411
ModuleInfo * info
Definition Dedup.cpp:413
ModuleInfoRef(ModuleInfo *info)
Definition Dedup.cpp:412
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:90
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:86
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:84
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:100
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:127
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:101
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:147
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:270
void update(Type type)
Definition Dedup.cpp:247
void update(const void *pointer)
Definition Dedup.cpp:204
void update(ClassType type)
Definition Dedup.cpp:232
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:383
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:215
void update(Operation *op)
Definition Dedup.cpp:349
llvm::SHA256 sha
Definition Dedup.cpp:397
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:387
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:134
void update(size_t value)
Definition Dedup.cpp:209
void update(BundleType type)
Definition Dedup.cpp:223
unsigned getID(void *object)
Definition Dedup.cpp:166
void update(OpResult result)
Definition Dedup.cpp:255
void update(OpOperand &operand)
Definition Dedup.cpp:187
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:131
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:405
void update(Region *region)
Definition Dedup.cpp:341
void update(Block *block)
Definition Dedup.cpp:330
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:390
void update(TypeID typeID)
Definition Dedup.cpp:220
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:393
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:183
unsigned finalizeID(void *object)
Definition Dedup.cpp:174
void update(mlir::OperationName name)
Definition Dedup.cpp:324
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs)
Definition Dedup.cpp:440
static ModuleInfoRef getTombstoneKey()
Definition Dedup.cpp:424
static ModuleInfoRef getEmptyKey()
Definition Dedup.cpp:420
static unsigned getHashValue(const ModuleInfoRef &ref)
Definition Dedup.cpp:428