CIRCT 23.0.0git
Loading...
Searching...
No Matches
Dedup.cpp
Go to the documentation of this file.
1//===- Dedup.cpp - FIRRTL module deduping -----------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements FIRRTL module deduplication.
10//
11//===----------------------------------------------------------------------===//
12
24#include "circt/Support/Debug.h"
25#include "circt/Support/LLVM.h"
26#include "mlir/IR/IRMapping.h"
27#include "mlir/IR/Threading.h"
28#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/DenseMapInfo.h"
31#include "llvm/ADT/Hashing.h"
32#include "llvm/ADT/PostOrderIterator.h"
33#include "llvm/ADT/SmallPtrSet.h"
34#include "llvm/Support/Debug.h"
35#include "llvm/Support/Format.h"
36#include "llvm/Support/SHA256.h"
37
38#define DEBUG_TYPE "firrtl-dedup"
39
40namespace circt {
41namespace firrtl {
42#define GEN_PASS_DEF_DEDUP
43#include "circt/Dialect/FIRRTL/Passes.h.inc"
44} // namespace firrtl
45} // namespace circt
46
47using namespace circt;
48using namespace firrtl;
49using hw::InnerRefAttr;
50
51//===----------------------------------------------------------------------===//
52// Utility function for classifying a Symbol's dedup-ability.
53//===----------------------------------------------------------------------===//
54
55/// Returns true if the module can be removed.
56static bool canRemoveModule(mlir::SymbolOpInterface symbol) {
57 // If the symbol is not private, it cannot be removed.
58 if (!symbol.isPrivate())
59 return false;
60 // Classes may be referenced in object types, so can not normally be removed
61 // if we can't find any symbol uses. Since we know that dedup will update the
62 // types of instances appropriately, we can ignore that and return true here.
63 if (isa<ClassLike>(*symbol))
64 return true;
65 // If module can not be removed even if no uses can be found, we can not
66 // delete it. The implication is that there are hidden symbol uses that dedup
67 // will not properly update.
68 if (!symbol.canDiscardOnUseEmpty())
69 return false;
70 // The module can be deleted.
71 return true;
72}
73
74//===----------------------------------------------------------------------===//
75// Hashing
76//===----------------------------------------------------------------------===//
77
78// This struct contains information to determine module module uniqueness. A
79// first element is a structural hash of the module, and the second element is
80// an array which tracks module names encountered in the walk. Since module
81// names could be replaced during dedup, it's necessary to keep names up-to-date
82// before actually combining them into structural hashes.
83struct ModuleInfo {
84 // SHA256 hash.
85 std::array<uint8_t, 32> structuralHash;
86 // Module names referred by instance op in the module.
87 std::vector<StringAttr> referredModuleNames;
88 // The operations that contain references to symbols that may be changed by
89 // dedup. These need a fixup pass after dedup. This is just an optimization
90 // and does not factor into the hash or the equality between `ModuleInfo`s.
91 std::vector<Operation *> symbolSensitiveOps;
92};
93
94static bool operator==(const ModuleInfo &lhs, const ModuleInfo &rhs) {
95 return lhs.structuralHash == rhs.structuralHash &&
97}
98
99/// This struct contains constant string attributes shared across different
100/// threads.
103 portTypesAttr = StringAttr::get(context, "portTypes");
104 moduleNameAttr = StringAttr::get(context, "moduleName");
105 portNamesAttr = StringAttr::get(context, "portNames");
106 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
107 nonessentialAttributes.insert(StringAttr::get(context, "convention"));
108 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
109 nonessentialAttributes.insert(StringAttr::get(context, "name"));
110 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
111 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
112 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
113 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
114 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
115 nonessentialAttributes.insert(StringAttr::get(context, "sym_visibility"));
116 };
117
118 // This is a cached "portTypes" string attr.
119 StringAttr portTypesAttr;
120
121 // This is a cached "moduleName" string attr.
122 StringAttr moduleNameAttr;
123
124 // This is a cached "moduleNames" string attr.
125 StringAttr moduleNamesAttr;
126
127 // This is a cached "portNames" string attr.
128 StringAttr portNamesAttr;
129
130 // This is a set of every attribute we should ignore.
131 DenseSet<Attribute> nonessentialAttributes;
132};
133
136 : constants(constants) {}
137
138 ModuleInfo getModuleInfo(FModuleLike module) {
140 update(&(*module));
141 return {sha.final(), std::move(referredModuleNames),
142 std::move(symbolSensitiveOps)};
143 }
144
145private:
146 /// Find all the ports and operations which may define an inner symbol
147 /// operations and give each a unique id. If the port/operation does define
148 /// an inner symbol, map the symbol name to a pair of the id and the symbol's
149 /// field id. When we hash (local) references to this inner symbol, we will
150 /// hash in the id and the the field id.
151 void populateInnerSymIDTable(FModuleLike module) {
152 // Add port symbols. If no port has a symbol defined, the port symbol array
153 // will be totally empty.
154 for (auto [index, innerSym] : llvm::enumerate(module.getPortSymbols())) {
155 for (auto prop : cast<hw::InnerSymAttr>(innerSym))
156 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
157 }
158 // Add operation symbols.
159 size_t index = module.getNumPorts();
160 module.walk([&](hw::InnerSymbolOpInterface innerSymOp) {
161 if (auto innerSym = innerSymOp.getInnerSymAttr()) {
162 for (auto prop : innerSym)
163 innerSymIDTable[prop.getName()] = std::pair(index, prop.getFieldID());
164 }
165 ++index;
166 });
167 }
168
169 // Get the identifier for an object. The identifier is assigned on first use.
170 unsigned getID(void *object) {
171 auto [it, inserted] = idTable.try_emplace(object, nextID);
172 if (inserted)
173 ++nextID;
174 return it->second;
175 }
176
177 // Get the identifier for an IR object. Free the ID, too.
178 unsigned finalizeID(void *object) {
179 auto it = idTable.find(object);
180 if (it == idTable.end())
181 return nextID++;
182 auto id = it->second;
183 idTable.erase(it);
184 return id;
185 }
186
187 std::pair<size_t, size_t> getInnerSymID(StringAttr name) {
188 return innerSymIDTable.at(name);
189 }
190
191 void update(OpOperand &operand) {
192 auto value = operand.get();
193 if (auto result = dyn_cast<OpResult>(value)) {
194 auto *op = result.getOwner();
195 update(getID(op));
196 update(result.getResultNumber());
197 return;
198 }
199 if (auto argument = dyn_cast<BlockArgument>(value)) {
200 auto *block = argument.getOwner();
201 update(getID(block));
202 update(argument.getArgNumber());
203 return;
204 }
205 llvm_unreachable("Unknown value type");
206 }
207
208 void update(const void *pointer) {
209 auto *addr = reinterpret_cast<const uint8_t *>(&pointer);
210 sha.update(ArrayRef<uint8_t>(addr, sizeof pointer));
211 }
212
213 void update(size_t value) {
214 auto *addr = reinterpret_cast<const uint8_t *>(&value);
215 sha.update(ArrayRef<uint8_t>(addr, sizeof value));
216 }
217
218 template <typename T, typename U>
219 void update(const std::pair<T, U> &pair) {
220 update(pair.first);
221 update(pair.second);
222 }
223
224 void update(TypeID typeID) { update(typeID.getAsOpaquePointer()); }
225
226 // NOLINTNEXTLINE(misc-no-recursion)
227 void update(BundleType type) {
228 update(type.getTypeID());
229 for (auto &element : type.getElements()) {
230 update(element.isFlip);
231 update(element.type);
232 }
233 }
234
235 // NOLINTNEXTLINE(misc-no-recursion)
236 void update(ClassType type) {
237 update(type.getTypeID());
238 // Don't hash the class name directly, since it may be replaced during
239 // dedup. Record the class name instead and lazily combine their hashes
240 // using the same mechanism as instances and modules.
241 hasSeenSymbol = true;
242 referredModuleNames.push_back(type.getNameAttr().getAttr());
243 for (auto &element : type.getElements()) {
244 update(element.name.getAsOpaquePointer());
245 update(element.type);
246 update(static_cast<unsigned>(element.direction));
247 }
248 }
249
250 // NOLINTNEXTLINE(misc-no-recursion)
251 void update(Type type) {
252 if (auto bundle = type_dyn_cast<BundleType>(type))
253 return update(bundle);
254 if (auto klass = type_dyn_cast<ClassType>(type))
255 return update(klass);
256 update(type.getAsOpaquePointer());
257 }
258
259 void update(OpResult result) {
260 // Like instance ops, don't use object ops' result types since they might be
261 // replaced by dedup. Record the class names and lazily combine their hashes
262 // using the same mechanism as instances and modules.
263 if (auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
264 hasSeenSymbol = true;
265 referredModuleNames.push_back(objectOp.getType().getNameAttr().getAttr());
266 return;
267 }
268
269 update(result.getType());
270 }
271
272 /// Hash the top level attribute dictionary of the operation. This function
273 /// has special handling for inner symbols, ports, and referenced modules.
274 void update(Operation *op, DictionaryAttr dict) {
275 for (auto namedAttr : dict) {
276 auto name = namedAttr.getName();
277 auto value = namedAttr.getValue();
278
279 // Check whether this attribute contains a nested symbol, just to make
280 // sure we revisit this op after dedup and update any symbols.
281 value.walk([&](FlatSymbolRefAttr) { hasSeenSymbol = true; });
282
283 // Skip names and annotations, except in certain cases.
284 // Names of ports are load bearing for classes, so we do hash those.
285 bool isClassPortNames =
286 isa<ClassLike>(op) && name == constants.portNamesAttr;
287 if (constants.nonessentialAttributes.contains(name) && !isClassPortNames)
288 continue;
289
290 // Hash the attribute name (an interned pointer).
291 update(name.getAsOpaquePointer());
292
293 // Hash the port types.
294 if (name == constants.portTypesAttr) {
295 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
296 for (auto type : portTypes)
297 update(type);
298 continue;
299 }
300
301 // For instance op, don't use `moduleName` attributes since they might be
302 // replaced by dedup. Record the names and lazily combine their hashes.
303 // It is assumed that module names are hashed only through instance ops;
304 // it could cause suboptimal results if there was other operation that
305 // refers to module names through essential attributes.
306 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
307 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(value).getAttr());
308 continue;
309 }
310
311 if (isa<InstanceChoiceOp>(op) && name == constants.moduleNamesAttr) {
312 for (auto module : cast<ArrayAttr>(value))
313 referredModuleNames.push_back(
314 cast<FlatSymbolRefAttr>(module).getAttr());
315 continue;
316 }
317
318 // TODO: properly handle DistinctAttr, including its use in paths.
319 // See https://github.com/llvm/circt/issues/6583.
320 if (isa<DistinctAttr>(value))
321 continue;
322
323 // If this is an symbol reference, we need to perform name erasure.
324 if (auto innerRef = dyn_cast<hw::InnerRefAttr>(value)) {
325 update(getInnerSymID(innerRef.getName()));
326 continue;
327 }
328
329 // We don't need to handle this attribute specially, so hash its unique
330 // address.
331 update(value.getAsOpaquePointer());
332 }
333 }
334
335 void update(mlir::OperationName name) {
336 // Operation names are interned.
337 update(name.getAsOpaquePointer());
338 }
339
340 // NOLINTNEXTLINE(misc-no-recursion)
341 void update(Block *block) {
342 for (auto &op : llvm::reverse(*block))
343 update(&op);
344 for (auto type : block->getArgumentTypes())
345 update(type);
346 update(finalizeID(block));
347 update(position);
348 ++position;
349 }
350
351 // NOLINTNEXTLINE(misc-no-recursion)
352 void update(Region *region) {
353 for (auto &block : llvm::reverse(region->getBlocks()))
354 update(&block);
355 update(position);
356 ++position;
357 }
358
359 // NOLINTNEXTLINE(misc-no-recursion)
360 void update(Operation *op) {
361 // Hash the regions. We need to make sure an empty region doesn't hash the
362 // same as no region, so we include the number of regions.
363 update(op->getNumRegions());
364 for (auto &region : reverse(op->getRegions()))
365 update(&region);
366
367 update(op->getName());
368
369 // Record the uses for later hashing.
370 for (auto &operand : op->getOpOperands())
371 update(operand);
372
373 // This happens after the numbering above, as it uses blockarg numbering
374 // for inner symbols.
375 hasSeenSymbol = false;
376 update(op, op->getAttrDictionary());
377
378 // Record any op results (types).
379 for (auto result : op->getResults())
380 update(result);
381
382 // If any of the operands, attributes, or results depended on symbols,
383 // remember this op for later adjustment.
384 if (hasSeenSymbol)
385 symbolSensitiveOps.push_back(op);
386
387 // Incorporate the hash of uses we have already built.
388 update(finalizeID(op));
389 update(position);
390 ++position;
391 }
392
393 // A map from an operation/block, to its identifier.
394 DenseMap<void *, unsigned> idTable;
395 unsigned nextID = 0;
396
397 // A map from an inner symbol, to its identifier.
398 DenseMap<StringAttr, std::pair<size_t, size_t>> innerSymIDTable;
399
400 // This keeps track of module names in the order of the appearance.
401 std::vector<StringAttr> referredModuleNames;
402
403 // String constants.
405
406 // This is the actual running hash calculation. This is a stateful element
407 // that should be reinitialized after each hash is produced.
408 llvm::SHA256 sha;
409
410 // The index of the current op. Increment after handling each op.
411 size_t position = 0;
412
413 // The operations that contain references to symbols that may be changed by
414 // dedup. These need a fixup pass after dedup.
415 bool hasSeenSymbol = false;
416 std::vector<Operation *> symbolSensitiveOps;
417};
418
419/// A reference to a `ModuleInfo` that compares and hashes like it. This is used
420/// to keep the potentially heavy module infos in a vector, and then populate
421/// maps with references to them.
426
427/// Allow `ModuleInfoRef` to be used as dense map keys. Hashes and compares the
428/// `ModuleInfo` the ref points to.
429template <>
434
438
439 static unsigned getHashValue(const ModuleInfoRef &ref) {
440 // We assume SHA256 is already a good hash and just truncate down to the
441 // number of bytes we need for DenseMap.
442 unsigned hash;
443 std::memcpy(&hash, ref.info->structuralHash.data(), sizeof(unsigned));
444
445 // Combine module names.
446 return llvm::hash_combine(
447 hash, llvm::hash_combine_range(ref.info->referredModuleNames.begin(),
448 ref.info->referredModuleNames.end()));
449 }
450
451 static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs) {
452 auto *empty = getEmptyKey().info;
453 auto *tombstone = getTombstoneKey().info;
454 if (lhs.info == empty || rhs.info == empty || lhs.info == tombstone ||
455 rhs.info == tombstone)
456 return lhs.info == rhs.info;
457 return *lhs.info == *rhs.info;
458 }
459};
460
461//===----------------------------------------------------------------------===//
462// Equivalence
463//===----------------------------------------------------------------------===//
464
465/// This class is for reporting differences between two modules which should
466/// have been deduplicated.
470 noDedupClass = StringAttr::get(context, noDedupAnnoClass);
471 dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
472 portDirectionsAttr = StringAttr::get(context, "portDirections");
473 nonessentialAttributes.insert(StringAttr::get(context, "annotations"));
474 nonessentialAttributes.insert(StringAttr::get(context, "name"));
475 nonessentialAttributes.insert(StringAttr::get(context, "portAnnotations"));
476 nonessentialAttributes.insert(StringAttr::get(context, "portNames"));
477 nonessentialAttributes.insert(StringAttr::get(context, "portTypes"));
478 nonessentialAttributes.insert(StringAttr::get(context, "portSymbols"));
479 nonessentialAttributes.insert(StringAttr::get(context, "portLocations"));
480 nonessentialAttributes.insert(StringAttr::get(context, "sym_name"));
481 nonessentialAttributes.insert(StringAttr::get(context, "inner_sym"));
482 }
483
491
492 std::string prettyPrint(Attribute attr) {
493 SmallString<64> buffer;
494 llvm::raw_svector_ostream os(buffer);
495 if (auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
496 os << "0x";
497 if (integerAttr.getType().isSignlessInteger())
498 integerAttr.getValue().toStringUnsigned(buffer, /*radix=*/16);
499 else
500 integerAttr.getAPSInt().toString(buffer, /*radix=*/16);
501
502 } else
503 os << attr;
504 return std::string(buffer);
505 }
506
507 // NOLINTNEXTLINE(misc-no-recursion)
508 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
509 Operation *a, BundleType aType, Operation *b,
510 BundleType bType) {
511 if (aType.getNumElements() != bType.getNumElements()) {
512 diag.attachNote(a->getLoc())
513 << message << " bundle type has different number of elements";
514 diag.attachNote(b->getLoc()) << "second operation here";
515 return failure();
516 }
517
518 for (auto elementPair :
519 llvm::zip(aType.getElements(), bType.getElements())) {
520 auto aElement = std::get<0>(elementPair);
521 auto bElement = std::get<1>(elementPair);
522 if (aElement.isFlip != bElement.isFlip) {
523 diag.attachNote(a->getLoc()) << message << " bundle element "
524 << aElement.name << " flip does not match";
525 diag.attachNote(b->getLoc()) << "second operation here";
526 return failure();
527 }
528
529 if (failed(check(diag,
530 message + " -> bundle element \'" +
531 aElement.name.getValue() + "'",
532 a, aElement.type, b, bElement.type)))
533 return failure();
534 }
535 return success();
536 }
537
538 LogicalResult check(InFlightDiagnostic &diag, const Twine &message,
539 Operation *a, Type aType, Operation *b, Type bType) {
540 if (aType == bType)
541 return success();
542 if (auto aBundleType = type_dyn_cast<BundleType>(aType))
543 if (auto bBundleType = type_dyn_cast<BundleType>(bType))
544 return check(diag, message, a, aBundleType, b, bBundleType);
545 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
546 aType != bType) {
547 diag.attachNote(a->getLoc())
548 << message << ", has a RefType with a different base type "
549 << type_cast<RefType>(aType).getType()
550 << " in the same position of the two modules marked as 'must dedup'. "
551 "(This may be due to Grand Central Taps or Views being different "
552 "between the two modules.)";
553 diag.attachNote(b->getLoc())
554 << "the second module has a different base type "
555 << type_cast<RefType>(bType).getType();
556 return failure();
557 }
558 diag.attachNote(a->getLoc())
559 << message << " types don't match, first type is " << aType;
560 diag.attachNote(b->getLoc()) << "second type is " << bType;
561 return failure();
562 }
563
564 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
565 Block &aBlock, Operation *b, Block &bBlock) {
566
567 // Block argument types.
568 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
569 auto portNo = 0;
570 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
571 Operation *opDoesNotExist) {
572 StringRef portName;
573 auto portNames = opExists->getAttrOfType<ArrayAttr>("portNames");
574 if (portNames)
575 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
576 portName = portNameAttr.getValue();
577 if (type_isa<RefType>(existsVal.getType())) {
578 diag.attachNote(opExists->getLoc())
579 << " contains a RefType port named '" + portName +
580 "' that only exists in one of the modules (can be due to "
581 "difference in Grand Central Tap or View of two modules "
582 "marked with must dedup)";
583 diag.attachNote(opDoesNotExist->getLoc())
584 << "second module to be deduped that does not have the RefType "
585 "port";
586 } else {
587 diag.attachNote(opExists->getLoc())
588 << "port '" + portName + "' only exists in one of the modules";
589 diag.attachNote(opDoesNotExist->getLoc())
590 << "second module to be deduped that does not have the port";
591 }
592 return failure();
593 };
594
595 for (auto argPair :
596 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
597 auto &aArg = std::get<0>(argPair);
598 auto &bArg = std::get<1>(argPair);
599 if (aArg.has_value() && bArg.has_value()) {
600 // TODO: we should print the port number if there are no port names, but
601 // there are always port names ;).
602 StringRef portName;
603 if (portNames) {
604 if (auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
605 portName = portNameAttr.getValue();
606 }
607 // Assumption here that block arguments correspond to ports.
608 if (failed(check(diag, "module port '" + portName + "'", a,
609 aArg->getType(), b, bArg->getType())))
610 return failure();
611 data.map.map(aArg.value(), bArg.value());
612 portNo++;
613 continue;
614 }
615 if (!aArg.has_value())
616 std::swap(a, b);
617 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
618 b);
619 }
620
621 // Blocks operations.
622 auto aIt = aBlock.begin();
623 auto aEnd = aBlock.end();
624 auto bIt = bBlock.begin();
625 auto bEnd = bBlock.end();
626 while (aIt != aEnd && bIt != bEnd)
627 if (failed(check(diag, data, &*aIt++, &*bIt++)))
628 return failure();
629 if (aIt != aEnd) {
630 diag.attachNote(aIt->getLoc()) << "first block has more operations";
631 diag.attachNote(b->getLoc()) << "second block here";
632 return failure();
633 }
634 if (bIt != bEnd) {
635 diag.attachNote(bIt->getLoc()) << "second block has more operations";
636 diag.attachNote(a->getLoc()) << "first block here";
637 return failure();
638 }
639 return success();
640 }
641
642 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
643 Region &aRegion, Operation *b, Region &bRegion) {
644 auto aIt = aRegion.begin();
645 auto aEnd = aRegion.end();
646 auto bIt = bRegion.begin();
647 auto bEnd = bRegion.end();
648
649 // Region blocks.
650 while (aIt != aEnd && bIt != bEnd)
651 if (failed(check(diag, data, a, *aIt++, b, *bIt++)))
652 return failure();
653 if (aIt != aEnd || bIt != bEnd) {
654 diag.attachNote(a->getLoc())
655 << "operation regions have different number of blocks";
656 diag.attachNote(b->getLoc()) << "second operation here";
657 return failure();
658 }
659 return success();
660 }
661
662 LogicalResult check(InFlightDiagnostic &diag, Operation *a,
663 mlir::DenseBoolArrayAttr aAttr, Operation *b,
664 mlir::DenseBoolArrayAttr bAttr) {
665 if (aAttr == bAttr)
666 return success();
667 auto portNames = a->getAttrOfType<ArrayAttr>("portNames");
668 for (unsigned i = 0, e = aAttr.size(); i < e; ++i) {
669 auto aDirection = aAttr[i];
670 auto bDirection = bAttr[i];
671 if (aDirection != bDirection) {
672 auto &note = diag.attachNote(a->getLoc()) << "module port ";
673 if (portNames)
674 note << "'" << cast<StringAttr>(portNames[i]).getValue() << "'";
675 else
676 note << i;
677 note << " directions don't match, first direction is '"
678 << direction::toString(aDirection) << "'";
679 diag.attachNote(b->getLoc()) << "second direction is '"
680 << direction::toString(bDirection) << "'";
681 return failure();
682 }
683 }
684 return success();
685 }
686
687 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
688 DictionaryAttr aDict, Operation *b,
689 DictionaryAttr bDict) {
690 // Fast path.
691 if (aDict == bDict)
692 return success();
693
694 DenseSet<Attribute> seenAttrs;
695 for (auto namedAttr : aDict) {
696 auto attrName = namedAttr.getName();
697 if (nonessentialAttributes.contains(attrName))
698 continue;
699
700 auto aAttr = namedAttr.getValue();
701 auto bAttr = bDict.get(attrName);
702 if (!bAttr) {
703 diag.attachNote(a->getLoc())
704 << "second operation is missing attribute " << attrName;
705 diag.attachNote(b->getLoc()) << "second operation here";
706 return diag;
707 }
708
709 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
710 auto bRef = cast<hw::InnerRefAttr>(bAttr);
711 auto aRef = cast<hw::InnerRefAttr>(aAttr);
712 // See if they are pointing at the same operation or port.
713 auto aTarget = data.a.lookup(aRef.getName());
714 auto bTarget = data.b.lookup(bRef.getName());
715 if (!aTarget || !bTarget)
716 diag.attachNote(a->getLoc())
717 << "malformed ir, possibly violating use-before-def";
718 auto error = [&]() {
719 diag.attachNote(a->getLoc())
720 << "operations have different targets, first operation has "
721 << aTarget;
722 diag.attachNote(b->getLoc()) << "second operation has " << bTarget;
723 return failure();
724 };
725 if (aTarget.isPort()) {
726 // If they are targeting ports, make sure its the same port number.
727 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
728 return error();
729 } else {
730 // Otherwise make sure that they are targeting the same operation.
731 if (!bTarget.isOpOnly() ||
732 data.map.lookupOrNull(aTarget.getOp()) != bTarget.getOp())
733 return error();
734 }
735 if (aTarget.getField() != bTarget.getField())
736 return error();
737 } else if (attrName == portDirectionsAttr) {
738 // Special handling for the port directions attribute for better
739 // error messages.
740 if (failed(check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
741 cast<mlir::DenseBoolArrayAttr>(bAttr))))
742 return failure();
743 } else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
744 // TODO: properly handle DistinctAttr, including its use in paths.
745 // See https://github.com/llvm/circt/issues/6583
746 } else if (aAttr != bAttr) {
747 diag.attachNote(a->getLoc())
748 << "first operation has attribute '" << attrName.getValue()
749 << "' with value " << prettyPrint(aAttr);
750 diag.attachNote(b->getLoc())
751 << "second operation has value " << prettyPrint(bAttr);
752 return failure();
753 }
754 seenAttrs.insert(attrName);
755 }
756 if (aDict.getValue().size() != bDict.getValue().size()) {
757 for (auto namedAttr : bDict) {
758 auto attrName = namedAttr.getName();
759 // Skip the attribute if we don't care about this particular one or it
760 // is one that is known to be in both dictionaries.
761 if (nonessentialAttributes.contains(attrName) ||
762 seenAttrs.contains(attrName))
763 continue;
764 // We have found an attribute that is only in the second operation.
765 diag.attachNote(a->getLoc())
766 << "first operation is missing attribute " << attrName;
767 diag.attachNote(b->getLoc()) << "second operation here";
768 return failure();
769 }
770 }
771 return success();
772 }
773
774 // NOLINTNEXTLINE(misc-no-recursion)
775 LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a,
776 FInstanceLike b) {
777 // Get the list of module names from the list (for InstanceOp/ObjectOp,
778 // there's only one)
779 auto aNames = a.getReferencedModuleNamesAttr();
780 auto bNames = b.getReferencedModuleNamesAttr();
781 if (aNames == bNames)
782 return success();
783
784 if (aNames.size() != bNames.size()) {
785 diag.attachNote(a->getLoc())
786 << "an instance has a different number of referenced "
787 "modules: first instance has "
788 << aNames.size() << " modules";
789 diag.attachNote(b->getLoc())
790 << "second instance has " << bNames.size() << " modules";
791 return failure();
792 }
793
794 for (auto [aName, bName] : llvm::zip(aNames.getAsRange<StringAttr>(),
795 bNames.getAsRange<StringAttr>())) {
796 if (aName == bName)
797 continue;
798 // If the modules instantiate are different we will want to know why the
799 // sub module did not dedupliate. This code recursively checks the child
800 // module.
801 auto aModule = instanceGraph.lookup(aName)->getModule();
802 auto bModule = instanceGraph.lookup(bName)->getModule();
803 // Create a new error for the submodule.
804 diag.attachNote(std::nullopt)
805 << "in instance " << a.getInstanceNameAttr() << " of " << aName
806 << ", and instance " << b.getInstanceNameAttr() << " of " << bName;
807 check(diag, aModule, bModule);
808 }
809 return failure();
810 }
811
812 // NOLINTNEXTLINE(misc-no-recursion)
813 LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a,
814 Operation *b) {
815 // Operation name.
816 if (a->getName() != b->getName()) {
817 diag.attachNote(a->getLoc()) << "first operation is a " << a->getName();
818 diag.attachNote(b->getLoc()) << "second operation is a " << b->getName();
819 return failure();
820 }
821
822 // If its an instance operaiton, perform some checking and possibly
823 // recurse.
824 if (auto aInst = dyn_cast<FInstanceLike>(a)) {
825 auto bInst = cast<FInstanceLike>(b);
826 if (failed(check(diag, aInst, bInst)))
827 return failure();
828 }
829
830 // Operation results.
831 if (a->getNumResults() != b->getNumResults()) {
832 diag.attachNote(a->getLoc())
833 << "operations have different number of results";
834 diag.attachNote(b->getLoc()) << "second operation here";
835 return failure();
836 }
837 for (auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
838 auto &aValue = std::get<0>(resultPair);
839 auto &bValue = std::get<1>(resultPair);
840 if (failed(check(diag, "operation result", a, aValue.getType(), b,
841 bValue.getType())))
842 return failure();
843 data.map.map(aValue, bValue);
844 }
845
846 // Operations operands.
847 if (a->getNumOperands() != b->getNumOperands()) {
848 diag.attachNote(a->getLoc())
849 << "operations have different number of operands";
850 diag.attachNote(b->getLoc()) << "second operation here";
851 return failure();
852 }
853 for (auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
854 auto &aValue = std::get<0>(operandPair);
855 auto &bValue = std::get<1>(operandPair);
856 if (bValue != data.map.lookup(aValue)) {
857 diag.attachNote(a->getLoc())
858 << "operations use different operands, first operand is '"
859 << getFieldName(
860 getFieldRefFromValue(aValue, /*lookThroughCasts=*/true))
861 .first
862 << "'";
863 diag.attachNote(b->getLoc())
864 << "second operand is '"
865 << getFieldName(
866 getFieldRefFromValue(bValue, /*lookThroughCasts=*/true))
867 .first
868 << "', but should have been '"
869 << getFieldName(getFieldRefFromValue(data.map.lookup(aValue),
870 /*lookThroughCasts=*/true))
871 .first
872 << "'";
873 return failure();
874 }
875 }
876 data.map.map(a, b);
877
878 // Operation regions.
879 if (a->getNumRegions() != b->getNumRegions()) {
880 diag.attachNote(a->getLoc())
881 << "operations have different number of regions";
882 diag.attachNote(b->getLoc()) << "second operation here";
883 return failure();
884 }
885 for (auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
886 auto &aRegion = std::get<0>(regionPair);
887 auto &bRegion = std::get<1>(regionPair);
888 if (failed(check(diag, data, a, aRegion, b, bRegion)))
889 return failure();
890 }
891
892 // Operation attributes.
893 if (failed(check(diag, data, a, a->getAttrDictionary(), b,
894 b->getAttrDictionary())))
895 return failure();
896 return success();
897 }
898
/// Explain why modules `a` and `b` were not deduplicated by attaching notes to
/// `diag`. The visible checks run in this order, and each one that reports a
/// problem attaches its notes and returns immediately:
///   1. the "NoDedup" marking on either module,
///   2. whether at least one of the two symbols can be removed,
///   3. whether both modules carry the same dedup group attribute,
///   4. the recursive structural comparison `check(diag, data, a, b)`.
/// If none of these reports anything, the locations of both modules are
/// attached as "first module here" / "second module here".
// NOTE(review): the embedded numbering skips 904 and 908 below, so the `if`
// headers guarding the two "module marked NoDedup" notes are not present in
// this listing; only their bodies are visible.
899 // NOLINTNEXTLINE(misc-no-recursion)
900 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
901 hw::InnerSymbolTable aTable(a);
902 hw::InnerSymbolTable bTable(b);
903 ModuleData data(aTable, bTable);
905 diag.attachNote(a->getLoc()) << "module marked NoDedup";
906 return;
907 }
909 diag.attachNote(b->getLoc()) << "module marked NoDedup";
910 return;
911 }
912 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
913 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
// Only reported when neither module can be removed.
914 if (!canRemoveModule(aSymbol) && !canRemoveModule(bSymbol)) {
915 diag.attachNote(a->getLoc())
916 << "module is "
917 << (aSymbol.isPrivate() ? "private but not discardable" : "public");
918 diag.attachNote(b->getLoc())
919 << "module is "
920 << (bSymbol.isPrivate() ? "private but not discardable" : "public");
921 return;
922 }
// NOTE(review): `aGroup` is read through getDiscardableAttr while `bGroup` is
// read through getAttrOfType<StringAttr>, which makes the outer
// dyn_cast_or_null on `bGroup` redundant. Confirm whether the two lookups are
// meant to differ.
923 auto aGroup =
924 dyn_cast_or_null<StringAttr>(a->getDiscardableAttr(dedupGroupAttrName));
925 auto bGroup = dyn_cast_or_null<StringAttr>(
926 b->getAttrOfType<StringAttr>(dedupGroupAttrName));
927 if (aGroup != bGroup) {
928 if (bGroup) {
929 diag.attachNote(b->getLoc())
930 << "module is in dedup group '" << bGroup.str() << "'";
931 } else {
932 diag.attachNote(b->getLoc()) << "module is not part of a dedup group";
933 }
934 if (aGroup) {
935 diag.attachNote(a->getLoc())
936 << "module is in dedup group '" << aGroup.str() << "'";
937 } else {
938 diag.attachNote(a->getLoc()) << "module is not part of a dedup group";
939 }
940 return;
941 }
// A failed structural comparison returns without adding anything further here.
942 if (failed(check(diag, data, a, b)))
943 return;
// No difference was found by any check above: just point at both modules.
944 diag.attachNote(a->getLoc()) << "first module here";
945 diag.attachNote(b->getLoc()) << "second module here";
946 }
947
948 // This is a cached "portDirections" string attr.
950 // This is a cached "NoDedup" annotation class string attr.
951 StringAttr noDedupClass;
952 // This is a cached string attr for the dedup group attribute.
954
955 // This is a set of every attribute we should ignore.
956 DenseSet<Attribute> nonessentialAttributes;
958};
959
960//===----------------------------------------------------------------------===//
961// Deduplication
962//===----------------------------------------------------------------------===//
963
964// Custom location merging. This only keeps track of 8 annotations from ".fir"
965// files, and however many annotations come from "real" sources. When
966// deduplicating, modules tend not to have scala source locators, so we wind
967// up fusing source locators for a module from every copy being deduped. There
968// is little value in this (all the modules are identical by definition).
969static Location mergeLoc(MLIRContext *context, Location to, Location from) {
970 // Unique the set of locations to be fused.
971 llvm::SmallSetVector<Location, 4> decomposedLocs;
972 // only track 8 "fir" locations
973 unsigned seenFIR = 0;
974 for (auto loc : {to, from}) {
975 // If the location is a fused location we decompose it if it has no
976 // metadata or the metadata is the same as the top level metadata.
977 if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
978 // UnknownLoc's have already been removed from FusedLocs so we can
979 // simply add all of the internal locations.
980 for (auto loc : fusedLoc.getLocations()) {
981 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
982 if (fileLoc.getFilename().strref().ends_with(".fir")) {
983 ++seenFIR;
984 if (seenFIR > 8)
985 continue;
986 }
987 }
988 decomposedLocs.insert(loc);
989 }
990 continue;
991 }
992
993 // Might need to skip this fir.
994 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
995 if (fileLoc.getFilename().strref().ends_with(".fir")) {
996 ++seenFIR;
997 if (seenFIR > 8)
998 continue;
999 }
1000 }
1001 // Otherwise, only add known locations to the set.
1002 if (!isa<UnknownLoc>(loc))
1003 decomposedLocs.insert(loc);
1004 }
1005
1006 auto locs = decomposedLocs.getArrayRef();
1007
1008 // Handle the simple cases of less than two locations. Ensure the metadata (if
1009 // provided) is not dropped.
1010 if (locs.empty())
1011 return UnknownLoc::get(context);
1012 if (locs.size() == 1)
1013 return locs.front();
1014
1015 return FusedLoc::get(context, locs);
1016}
1017
1018struct Deduper {
1019
1020 using RenameMap = DenseMap<StringAttr, StringAttr>;
1021
// Constructor. The visible initializers cache the circuit's MLIR context, the
// instance graph, the circuit body block (stored as `nlaBlock`) and the
// "circt.nonlocal" / "class" string attributes. The body then seeds `nlaCache`
// with one entry per hw::HierPathOp in the circuit, mapping its namepath
// attribute to its symbol name.
// NOTE(review): embedded lines 1022 and 1025 are absent from this listing, so
// the start of the signature and part of the member-initializer list are not
// visible here.
1023 NLATable *nlaTable, CircuitOp circuit)
1024 : context(circuit->getContext()), instanceGraph(instanceGraph),
1026 nlaBlock(circuit.getBodyBlock()),
1027 nonLocalString(StringAttr::get(context, "circt.nonlocal")),
1028 classString(StringAttr::get(context, "class")) {
1029 // Populate the NLA cache.
1030 for (auto nla : circuit.getOps<hw::HierPathOp>())
1031 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
1032 }
1033
1034 /// Remove the "fromModule", and replace all references to it with the
1035 /// "toModule". Modules should be deduplicated in a bottom-up order. Any
1036 /// module which is not deduplicated needs to be recorded with the `record`
1037 /// call.
1038 void dedup(FModuleLike toModule, FModuleLike fromModule) {
1039 // A map of operation (e.g. wires, nodes) names which are changed, which is
1040 // used to update NLAs that reference the "fromModule".
1041 RenameMap renameMap;
1042
1043 // Merge the port locations.
1044 SmallVector<Attribute> newLocs;
1045 for (auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
1046 fromModule.getPortLocations())) {
1047 if (toLoc == fromLoc)
1048 newLocs.push_back(toLoc);
1049 else
1050 newLocs.push_back(mergeLoc(context, cast<LocationAttr>(toLoc),
1051 cast<LocationAttr>(fromLoc)));
1052 }
1053 toModule->setAttr("portLocations", ArrayAttr::get(context, newLocs));
1054
1055 // Merge the two modules.
1056 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
1057
1058 // Rewrite NLAs pathing through these modules to refer to the to module. It
1059 // is safe to do this at this point because NLAs cannot be one element long.
1060 // This means that all NLAs which require more context cannot be targetting
1061 // something in the module it self.
1062 if (auto to = dyn_cast<FModuleOp>(*toModule))
1063 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
1064 else
1065 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
1066 fromModule.getModuleNameAttr());
1067
1068 replaceInstances(toModule, fromModule);
1069 }
1070
1071 /// Record the usages of any NLA's in this module, so that we may update the
1072 /// annotation if the parent module is deduped with another module.
/// Three sources are visited: the module operation itself, each of its ports,
/// and every operation reached by walking the module.
1073 void record(FModuleLike module) {
1074 // Record any annotations on the module.
1075 recordAnnotations(module);
1076 // Record port annotations.
1077 for (unsigned i = 0, e = getNumPorts(module); i < e; ++i)
// NOTE(review): embedded line 1078 (the body of the port loop above) is absent
// from this listing.
1079 // Record any annotations in the module body.
1080 module->walk([&](Operation *op) { recordAnnotations(op); });
1081 }
1082
1083private:
1084 /// Get a cached namespace for a module.
/// The entry in `moduleNamespaces` is created with try_emplace, constructed
/// from the module cast to FModuleLike, and the stored value is returned.
// NOTE(review): embedded line 1085 (the signature) is absent from this listing.
1086 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
1087 .first->second;
1088 }
1089
1090 /// For a specific annotation target, record all the unique NLAs which
1091 /// target it in the `targetMap`.
/// Only annotations carrying a "circt.nonlocal" FlatSymbolRefAttr member
/// contribute; the referenced symbol name is the key and `target` is inserted
/// into the set stored under it.
// NOTE(review): embedded line 1092 (the signature) is absent from this listing.
1093 for (auto anno : target.getAnnotations())
1094 if (auto nlaRef = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal"))
1095 targetMap[nlaRef.getAttr()].insert(target);
1096 }
1097
1098 /// Record all targets which use an NLA.
1099 void recordAnnotations(Operation *op) {
1100 // Record annotations.
1101 recordAnnotations(OpAnnoTarget(op));
1102
1103 // Record port annotations only if this is a mem operation.
1104 auto mem = dyn_cast<MemOp>(op);
1105 if (!mem)
1106 return;
1107
1108 // Record port annotations.
1109 for (unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1110 recordAnnotations(PortAnnoTarget(mem, i));
1111 }
1112
1113 /// This deletes and replaces all instances of the "fromModule" with instances
1114 /// of the "toModule".
1115 void replaceInstances(FModuleLike toModule, Operation *fromModule) {
1116 // Replace all instances of the other module.
1117 auto *fromNode =
1118 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1119 auto *toNode = instanceGraph[toModule];
1120 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1121 for (auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1122 auto inst = oldInstRec->getInstance();
1123 if (auto instOp = dyn_cast<InstanceOp>(*inst)) {
1124 instOp.setModuleNameAttr(toModuleRef);
1125 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1126 } else if (auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1127 auto classLike = cast<ClassLike>(*toNode->getModule());
1128 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1129 objectOp.getResult().setType(classType);
1130 } else if (auto instanceChoiceOp = dyn_cast<InstanceChoiceOp>(*inst)) {
1131 auto fromModuleName = fromNode->getModule().getModuleNameAttr();
1132 SmallVector<Attribute> newModules;
1133 for (auto module : instanceChoiceOp.getReferencedModuleNamesAttr()) {
1134 auto moduleName = cast<StringAttr>(module);
1135 if (moduleName == fromModuleName)
1136 newModules.push_back(toModuleRef);
1137 else
1138 newModules.push_back(FlatSymbolRefAttr::get(moduleName));
1139 }
1140 instanceChoiceOp.setModuleNamesAttr(
1141 ArrayAttr::get(context, newModules));
1142 instanceChoiceOp.setPortNamesAttr(toModule.getPortNamesAttr());
1143 }
1144 oldInstRec->getParent()->addInstance(inst, toNode);
1145 oldInstRec->erase();
1146 }
1147 instanceGraph.erase(fromNode);
1148 fromModule->erase();
1149 }
1150
1151 /// Look up the instantiations of the `from` module and create an NLA for each
1152 /// one, appending the baseNamepath to each NLA. This is used to add more
1153 /// context to an already existing NLA. The `fromModule` is used to indicate
1154 /// which module the annotation is coming from before the merge, and will be
1155 /// used to create the namepaths.
1156 SmallVector<FlatSymbolRefAttr>
1157 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1158 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1159 // Create an attribute array with a placeholder in the first element, where
1160 // the root refence of the NLA will be inserted.
1161 SmallVector<Attribute> namepath = {nullptr};
1162 namepath.append(baseNamepath.begin(), baseNamepath.end());
1163
1164 auto loc = fromModule->getLoc();
1165 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1166 SmallVector<FlatSymbolRefAttr> nlas;
1167 for (auto *instanceRecord : fromNode->uses()) {
1168 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1169 auto inst = instanceRecord->getInstance();
1170 namepath[0] = OpAnnoTarget(inst).getNLAReference(getNamespace(parent));
1171 auto arrayAttr = ArrayAttr::get(context, namepath);
1172 // Check the NLA cache to see if we already have this NLA.
1173 auto &cacheEntry = nlaCache[arrayAttr];
1174 if (!cacheEntry) {
1175 auto builder = OpBuilder::atBlockBegin(nlaBlock);
1176 auto nla = hw::HierPathOp::create(builder, loc, "nla", arrayAttr);
1177 // Insert it into the symbol table to get a unique name.
1178 symbolTable.insert(nla);
1179 // Store it in the cache.
1180 cacheEntry = nla.getNameAttr();
1181 nla.setVisibility(vis);
1182 nlaTable->addNLA(nla);
1183 }
1184 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1185 nlas.push_back(nlaRef);
1186 }
1187 return nlas;
1188 }
1189
1190 /// Look up the instantiations of this module and create an NLA for each one.
1191 /// This returns an array of symbol references which can be used to reference
1192 /// the NLAs.
1193 SmallVector<FlatSymbolRefAttr>
1194 createNLAs(StringAttr toModuleName, FModuleLike fromModule,
1195 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1196 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1197 }
1198
1199 /// Clone the annotation for each NLA in a list. The attribute list should
1200 /// have a placeholder for the "circt.nonlocal" field, and `nonLocalIndex`
1201 /// should be the index of this field.
1202 void cloneAnnotation(SmallVectorImpl<FlatSymbolRefAttr> &nlas,
1203 Annotation anno, ArrayRef<NamedAttribute> attributes,
1204 unsigned nonLocalIndex,
1205 SmallVectorImpl<Annotation> &newAnnotations) {
1206 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1207 attributes.end());
1208 for (auto &nla : nlas) {
1209 // Add the new annotation.
1210 mutableAttributes[nonLocalIndex].setValue(nla);
1211 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1212 // The original annotation records if its a subannotation.
1213 anno.setDict(dict);
1214 newAnnotations.push_back(anno);
1215 }
1216 }
1217
1218 /// This erases the NLA op, and removes the NLA from every module's NLA map,
1219 /// but it does not delete the NLA reference from the target operation's
1220 /// annotations.
1221 void eraseNLA(hw::HierPathOp nla) {
1222 // Erase the NLA from the leaf module's nlaMap.
1223 targetMap.erase(nla.getNameAttr());
1224 nlaTable->erase(nla);
1225 nlaCache.erase(nla.getNamepathAttr());
1226 symbolTable.erase(nla);
1227 }
1228
1229 /// Process all NLAs referencing the "from" module to point to the "to"
1230 /// module. This is used after merging two modules together.
1231 void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule,
1232 FModuleOp fromModule) {
1233 auto toName = toModule.getNameAttr();
1234 auto fromName = fromModule.getNameAttr();
1235 // Create a copy of the current NLAs. We will be pushing and removing
1236 // NLAs from this op as we go.
1237 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1238 // Change the NLA to target the toModule.
1239 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1240 // Now we walk the NLA searching for ones that require more context to be
1241 // added.
1242 for (auto nla : moduleNLAs) {
1243 auto elements = nla.getNamepath().getValue();
1244 // If we don't need to add more context, we're done here.
1245 if (nla.root() != toName)
1246 continue;
1247 // Create the replacement NLAs.
1248 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1249 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1250 // Copy out the targets, because we will be updating the map.
1251 auto &set = targetMap[nla.getSymNameAttr()];
1252 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1253 // Replace the uses of the old NLA with the new NLAs.
1254 for (auto target : targets) {
1255 // We have to clone any annotation which uses the old NLA for each new
1256 // NLA. This array collects the new set of annotations.
1257 SmallVector<Annotation> newAnnotations;
1258 for (auto anno : target.getAnnotations()) {
1259 // Find the non-local field of the annotation.
1260 auto [it, found] = mlir::impl::findAttrSorted(
1261 anno.begin(), anno.end(), nonLocalString);
1262 // If this annotation doesn't use the target NLA, copy it with no
1263 // changes.
1264 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1265 nla.getSymNameAttr()) {
1266 newAnnotations.push_back(anno);
1267 continue;
1268 }
1269 auto nonLocalIndex = std::distance(anno.begin(), it);
1270 // Clone the annotation and add it to the list of new annotations.
1271 cloneAnnotation(nlaRefs, anno,
1272 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1273 nonLocalIndex, newAnnotations);
1274 }
1275
1276 // Apply the new annotations to the operation.
1277 AnnotationSet annotations(newAnnotations, context);
1278 target.setAnnotations(annotations);
1279 // Record that target uses the NLA.
1280 for (auto nla : nlaRefs)
1281 targetMap[nla.getAttr()].insert(target);
1282 }
1283
1284 // Erase the old NLA and remove it from all breadcrumbs.
1285 eraseNLA(nla);
1286 }
1287 }
1288
1289 /// Process all the NLAs that the two modules participate in, replacing
1290 /// references to the "from" module with references to the "to" module, and
1291 /// adding more context if necessary.
1292 void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule,
1293 FModuleOp fromModule) {
1294 addAnnotationContext(renameMap, toModule, toModule);
1295 addAnnotationContext(renameMap, toModule, fromModule);
1296 }
1297
1298 // Update all NLAs which the "from" external module participates in to the
1299 // "toName".
1300 void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName,
1301 StringAttr fromName) {
1302 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1303 }
1304
1305 /// Take an annotation, and update it to be a non-local annotation. If the
1306 /// annotation is already non-local and has enough context, it will be skipped
1307 /// for now. Return true if the annotation was made non-local.
1308 bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to,
1309 FModuleLike fromModule, Annotation anno,
1310 SmallVectorImpl<Annotation> &newAnnotations) {
1311 // Start constructing a new annotation, pushing a "circt.nonLocal" field
1312 // into the correct spot if its not already a non-local annotation.
1313 SmallVector<NamedAttribute> attributes;
1314 int nonLocalIndex = -1;
1315 for (const auto &val : llvm::enumerate(anno)) {
1316 auto attr = val.value();
1317 // Is this field "circt.nonlocal"?
1318 auto compare = attr.getName().compare(nonLocalString);
1319 assert(compare != 0 && "should not pass non-local annotations here");
1320 if (compare == 1) {
1321 // This annotation definitely does not have "circt.nonlocal" field. Push
1322 // an empty place holder for the non-local annotation.
1323 nonLocalIndex = val.index();
1324 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1325 break;
1326 }
1327 // Otherwise push the current attribute and keep searching for the
1328 // "circt.nonlocal" field.
1329 attributes.push_back(attr);
1330 }
1331 if (nonLocalIndex == -1) {
1332 // Push an empty "circt.nonlocal" field to the last slot.
1333 nonLocalIndex = attributes.size();
1334 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1335 } else {
1336 // Copy the remaining annotation fields in.
1337 attributes.append(anno.begin() + nonLocalIndex, anno.end());
1338 }
1339
1340 // Construct the NLAs if we don't have any yet.
1341 auto nlaRefs = createNLAs(toModuleName, fromModule);
1342 for (auto nla : nlaRefs)
1343 targetMap[nla.getAttr()].insert(to);
1344
1345 // Clone the annotation for each new NLA.
1346 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1347 return true;
1348 }
1349
1350 void copyAnnotations(FModuleLike toModule, AnnoTarget to,
1351 FModuleLike fromModule, AnnotationSet annos,
1352 SmallVectorImpl<Annotation> &newAnnotations,
1353 SmallPtrSetImpl<Attribute> &dontTouches) {
1354 for (auto anno : annos) {
1355 if (anno.isClass(dontTouchAnnoClass)) {
1356 // Remove the nonlocal field of the annotation if it has one, since this
1357 // is a sticky annotation.
1358 anno.removeMember("circt.nonlocal");
1359 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1360 if (inserted)
1361 newAnnotations.push_back(anno);
1362 continue;
1363 }
1364 // If the annotation is already non-local, we add it as is. It is already
1365 // added to the target map.
1366 if (auto nla = anno.getMember<FlatSymbolRefAttr>("circt.nonlocal")) {
1367 newAnnotations.push_back(anno);
1368 targetMap[nla.getAttr()].insert(to);
1369 continue;
1370 }
1371 // Otherwise make the annotation non-local and add it to the set.
1372 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1373 newAnnotations);
1374 }
1375 }
1376
1377 /// Merge the annotations of a specific target, either a operation or a port
1378 /// on an operation.
1379 void mergeAnnotations(FModuleLike toModule, AnnoTarget to,
1380 AnnotationSet toAnnos, FModuleLike fromModule,
1381 AnnoTarget from, AnnotationSet fromAnnos) {
1382 // This is a list of all the annotations which will be added to `to`.
1383 SmallVector<Annotation> newAnnotations;
1384
1385 // We have special case handling of DontTouch to prevent it from being
1386 // turned into a non-local annotation, and to remove duplicates.
1387 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1388
1389 // Iterate the annotations, transforming most annotations into non-local
1390 // ones.
1391 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1392 dontTouches);
1393 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1394 dontTouches);
1395
1396 // Copy over all the new annotations.
1397 if (!newAnnotations.empty())
1398 to.setAnnotations(AnnotationSet(newAnnotations, context));
1399 }
1400
1401 /// Merge all annotations and port annotations on two operations.
1402 void mergeAnnotations(FModuleLike toModule, Operation *to,
1403 FModuleLike fromModule, Operation *from) {
1404 // Merge op annotations.
1405 mergeAnnotations(toModule, OpAnnoTarget(to), AnnotationSet(to), fromModule,
1406 OpAnnoTarget(from), AnnotationSet(from));
1407
1408 // Merge port annotations.
1409 if (toModule == to) {
1410 // Merge module port annotations.
1411 for (unsigned i = 0, e = getNumPorts(toModule); i < e; ++i)
1412 mergeAnnotations(toModule, PortAnnoTarget(toModule, i),
1413 AnnotationSet::forPort(toModule, i), fromModule,
1414 PortAnnoTarget(fromModule, i),
1415 AnnotationSet::forPort(fromModule, i));
1416 } else if (auto toMem = dyn_cast<MemOp>(to)) {
1417 // Merge memory port annotations.
1418 auto fromMem = cast<MemOp>(from);
1419 for (unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1420 mergeAnnotations(toModule, PortAnnoTarget(toMem, i),
1421 AnnotationSet::forPort(toMem, i), fromModule,
1422 PortAnnoTarget(fromMem, i),
1423 AnnotationSet::forPort(fromMem, i));
1424 }
1425 }
1426
1427 hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule,
1428 hw::InnerSymAttr toSym,
1429 hw::InnerSymAttr fromSym) {
1430 if (fromSym && !fromSym.getProps().empty()) {
1431 auto &isn = getNamespace(toModule);
1432 // The properties for the new inner symbol..
1433 SmallVector<hw::InnerSymPropertiesAttr> newProps;
1434 // If the "to" op already has an inner symbol, copy all its properties.
1435 if (toSym)
1436 llvm::append_range(newProps, toSym);
1437 // Add each property from the fromSym to the toSym.
1438 for (auto fromProp : fromSym) {
1439 hw::InnerSymPropertiesAttr newProp;
1440 auto *it = llvm::find_if(newProps, [&](auto p) {
1441 return p.getFieldID() == fromProp.getFieldID();
1442 });
1443 if (it != newProps.end()) {
1444 // If we already have an inner sym with the same field id, use
1445 // that.
1446 newProp = *it;
1447 // If the old symbol is public, we need to make the new one public.
1448 if (fromProp.getSymVisibility().getValue() == "public" &&
1449 newProp.getSymVisibility().getValue() != "public") {
1450 *it = hw::InnerSymPropertiesAttr::get(context, newProp.getName(),
1451 newProp.getFieldID(),
1452 fromProp.getSymVisibility());
1453 }
1454 } else {
1455 // We need to add a new property to the inner symbol for this field.
1456 auto newName = isn.newName(fromProp.getName().getValue());
1457 newProp = hw::InnerSymPropertiesAttr::get(
1458 context, StringAttr::get(context, newName), fromProp.getFieldID(),
1459 fromProp.getSymVisibility());
1460 newProps.push_back(newProp);
1461 }
1462 renameMap[fromProp.getName()] = newProp.getName();
1463 }
1464 // Sort the fields by field id.
1465 llvm::sort(newProps, [](auto &p, auto &q) {
1466 return p.getFieldID() < q.getFieldID();
1467 });
1468 // Return the merged inner symbol.
1469 return hw::InnerSymAttr::get(context, newProps);
1470 }
1471 return hw::InnerSymAttr();
1472 }
1473
1474 // Record the symbol name change of the operation or any of its ports when
1475 // merging two operations. The renamed symbols are used to update the
1476 // target of any NLAs. This will add symbols to the "to" operation if needed.
1477 void recordSymRenames(RenameMap &renameMap, FModuleLike toModule,
1478 Operation *to, FModuleLike fromModule,
1479 Operation *from) {
1480 // If the "from" operation has an inner_sym, we need to make sure the
1481 // "to" operation also has an `inner_sym` and then record the renaming.
1482 if (auto fromInnerSym = dyn_cast<hw::InnerSymbolOpInterface>(from)) {
1483 auto toInnerSym = cast<hw::InnerSymbolOpInterface>(to);
1484 if (auto newSymAttr = mergeInnerSymbols(renameMap, toModule,
1485 toInnerSym.getInnerSymAttr(),
1486 fromInnerSym.getInnerSymAttr()))
1487 toInnerSym.setInnerSymbolAttr(newSymAttr);
1488 }
1489
1490 // If there are no port symbols on the "from" operation, we are done here.
1491 auto fromPortSyms = from->getAttrOfType<ArrayAttr>("portSymbols");
1492 if (!fromPortSyms || fromPortSyms.empty())
1493 return;
1494 // We have to map each "fromPort" to each "toPort".
1495 auto portCount = fromPortSyms.size();
1496 auto toPortSyms = to->getAttrOfType<ArrayAttr>("portSymbols");
1497
1498 // Create an array of new port symbols for the "to" operation, copy in the
1499 // old symbols if it has any, create an empty symbol array if it doesn't.
1500 SmallVector<Attribute> newPortSyms;
1501 if (toPortSyms.empty())
1502 newPortSyms.assign(portCount, hw::InnerSymAttr());
1503 else
1504 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1505
1506 for (unsigned portNo = 0; portNo < portCount; ++portNo) {
1507 if (auto newPortSym = mergeInnerSymbols(
1508 renameMap, toModule,
1509 llvm::cast_if_present<hw::InnerSymAttr>(newPortSyms[portNo]),
1510 cast<hw::InnerSymAttr>(fromPortSyms[portNo]))) {
1511 newPortSyms[portNo] = newPortSym;
1512 }
1513 }
1514
1515 // Commit the new symbol attribute.
1516 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1517 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1518 }
1519
1520 /// Recursively merge two operations.
1521 // NOLINTNEXTLINE(misc-no-recursion)
1522 void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to,
1523 FModuleLike fromModule, Operation *from) {
1524 // Merge the operation locations.
1525 if (to->getLoc() != from->getLoc())
1526 to->setLoc(mergeLoc(context, to->getLoc(), from->getLoc()));
1527
1528 // Recurse into any regions.
1529 for (auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1530 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1531 std::get<1>(regions));
1532
1533 // Record any inner_sym renamings that happened.
1534 recordSymRenames(renameMap, toModule, to, fromModule, from);
1535
1536 // Merge the annotations.
1537 mergeAnnotations(toModule, to, fromModule, from);
1538 }
1539
1540 /// Recursively merge two blocks.
1541 void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock,
1542 FModuleLike fromModule, Block &fromBlock) {
1543 // Merge the block locations.
1544 for (auto [toArg, fromArg] :
1545 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1546 if (toArg.getLoc() != fromArg.getLoc())
1547 toArg.setLoc(mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1548
1549 for (auto ops : llvm::zip(toBlock, fromBlock))
1550 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1551 &std::get<1>(ops));
1552 }
1553
1554 // Recursively merge two regions.
1555 void mergeRegions(RenameMap &renameMap, FModuleLike toModule,
1556 Region &toRegion, FModuleLike fromModule,
1557 Region &fromRegion) {
1558 for (auto blocks : llvm::zip(toRegion, fromRegion))
1559 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1560 std::get<1>(blocks));
1561 }
1562
1563 MLIRContext *context;
1565 SymbolTable &symbolTable;
1566
1567 /// Cached nla table analysis.
1568 NLATable *nlaTable = nullptr;
1569
1570 /// We insert all NLAs to the beginning of this block.
1571 Block *nlaBlock;
1572
1573 // This maps an NLA to the operations and ports that uses it.
1574 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>> targetMap;
1575
1576 // This is a cache to avoid creating duplicate NLAs. This maps the ArrayAttr
1577 // of the NLA's path to the name of the NLA which contains it.
1578 DenseMap<Attribute, Attribute> nlaCache;
1579
1580 // Cached attributes for faster comparisons and attribute building.
1581 StringAttr nonLocalString;
1582 StringAttr classString;
1583
1584 /// A module namespace cache.
1585 DenseMap<Operation *, hw::InnerSymbolNamespace> moduleNamespaces;
1586};
1587
1588//===----------------------------------------------------------------------===//
1589// Fixup
1590//===----------------------------------------------------------------------===//
1591
1592/// This fixes up connects when the field names of a bundle type changes. It
1593/// finds all fields which were previously bulk connected and legalizes it
1594/// into a connect for each field.
1595static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src) {
1596 // If the types already match we can emit a connect.
1597 auto dstType = dst.getType();
1598 auto srcType = src.getType();
1599 if (dstType == srcType) {
1600 emitConnect(builder, dst, src);
1601 return;
1602 }
1603 // It must be a bundle type and the field name has changed. We have to
1604 // manually decompose the bulk connect into a connect for each field.
1605 auto dstBundle = type_cast<BundleType>(dstType);
1606 auto srcBundle = type_cast<BundleType>(srcType);
1607 for (unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1608 auto dstField = SubfieldOp::create(builder, dst, i);
1609 auto srcField = SubfieldOp::create(builder, src, i);
1610 if (dstBundle.getElement(i).isFlip) {
1611 std::swap(srcBundle, dstBundle);
1612 std::swap(srcField, dstField);
1613 }
1614 fixupConnect(builder, dstField, srcField);
1615 }
1616}
1617
1618/// Adjust the symbol references in an op. This includes updating its attributes
1619/// and types.
1620static void
1621fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph,
1622 const DenseMap<Attribute, StringAttr> &dedupMap) {
1623 // If this is an instance op, dedup may have subtly changed the port types.
1624 // For example, structurally different bundles may still dedup. In this case
1625 // we now have an instance op that produces result values of the old type, but
1626 // the port info on the instantiated module already represents the new type.
1627 // Fix this up by going through an intermediate wire.
1628 if (auto instOp = dyn_cast<InstanceOp>(op)) {
1629 ImplicitLocOpBuilder builder(instOp.getLoc(), instOp->getContext());
1630 builder.setInsertionPointAfter(instOp);
1631 auto module = instanceGraph.lookup(instOp.getModuleNameAttr().getAttr())
1632 ->getModule<FModuleLike>();
1633 for (auto [index, result] : llvm::enumerate(instOp.getResults())) {
1634 auto newType = module.getPortType(index);
1635 auto oldType = result.getType();
1636 // If the type has not changed, we don't have to fix up anything.
1637 if (newType == oldType)
1638 continue;
1639 LLVM_DEBUG(llvm::dbgs()
1640 << "- Updating instance port \"" << instOp.getInstanceName()
1641 << "." << instOp.getPortName(index) << "\" from " << oldType
1642 << " to " << newType << "\n");
1643
1644 // If the type changed we transform it back to the old type with an
1645 // intermediate wire.
1646 auto wire = WireOp::create(builder, oldType, instOp.getPortName(index))
1647 .getResult();
1648 result.replaceAllUsesWith(wire);
1649 result.setType(newType);
1650 if (instOp.getPortDirection(index) == Direction::Out)
1651 fixupConnect(builder, wire, result);
1652 else
1653 fixupConnect(builder, result, wire);
1654 }
1655 }
1656
1657 // Use an attribute/type replacer to look for references to old symbols that
1658 // need to be replaced with new symbols.
1659 mlir::AttrTypeReplacer replacer;
1660 replacer.addReplacement([&](FlatSymbolRefAttr symRef) {
1661 auto oldName = symRef.getAttr();
1662 auto newName = dedupMap.lookup(oldName);
1663 if (newName && newName != oldName) {
1664 auto newSymRef = FlatSymbolRefAttr::get(newName);
1665 LLVM_DEBUG(llvm::dbgs()
1666 << "- Updating " << symRef << " to " << newSymRef << " in "
1667 << op->getName() << " at " << op->getLoc() << "\n");
1668 return newSymRef;
1669 }
1670 return symRef;
1671 });
1672
1673 // Update attributes.
1674 op->setAttrs(cast<DictionaryAttr>(replacer.replace(op->getAttrDictionary())));
1675
1676 // Update the argument types.
1677 for (auto &region : op->getRegions())
1678 for (auto &block : region)
1679 for (auto arg : block.getArguments())
1680 arg.setType(replacer.replace(arg.getType()));
1681
1682 // Update result types.
1683 for (auto result : op->getResults())
1684 result.setType(replacer.replace(result.getType()));
1685}
1686
1687/// Adjust the symbol references in ops marked as sensitive to them. This
1688/// includes updating their attributes and types.
1690 InstanceGraph &instanceGraph,
1691 const DenseMap<Operation *, ModuleInfoRef> &moduleToModuleInfo,
1692 const DenseMap<Attribute, StringAttr> &dedupMap) {
1693 for (auto *node : instanceGraph) {
1694 // Look up the module info for this module, which contains the list of ops
1695 // that need to be updated.
1696 auto module = node->getModule<FModuleLike>();
1697 auto it = moduleToModuleInfo.find(module);
1698 if (it == moduleToModuleInfo.end())
1699 continue;
1700
1701 // Update each symbol-sensitive op individually.
1702 auto &ops = it->second.info->symbolSensitiveOps;
1703 if (ops.empty())
1704 continue;
1705 LLVM_DEBUG(llvm::dbgs()
1706 << "- Updating " << ops.size() << " symbol-sensitive ops in "
1707 << module.getNameAttr() << "\n");
1708 for (auto *op : ops)
1709 fixupSymbolSensitiveOp(op, instanceGraph, dedupMap);
1710 }
1711}
1712
1713//===----------------------------------------------------------------------===//
1714// DedupPass
1715//===----------------------------------------------------------------------===//
1716
1717namespace {
1718class DedupPass : public circt::firrtl::impl::DedupBase<DedupPass> {
1719 using DedupBase::DedupBase;
1720
1721 void runOnOperation() override {
1722 auto *context = &getContext();
1723 auto circuit = getOperation();
1724 auto &instanceGraph = getAnalysis<InstanceGraph>();
1725 auto *nlaTable = &getAnalysis<NLATable>();
1726 auto &symbolTable = getAnalysis<SymbolTable>();
1727 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1728 Equivalence equiv(context, instanceGraph);
1729 auto anythingChanged = false;
1730 LLVM_DEBUG({
1731 llvm::dbgs() << "\n";
1732 debugHeader(Twine("Dedup circuit \"") + circuit.getName() + "\"")
1733 << "\n\n";
1734 });
1735
1736 // Modules annotated with this should not be considered for deduplication.
1737 auto noDedupClass = StringAttr::get(context, noDedupAnnoClass);
1738
1739 // Only modules within the same group may be deduplicated.
1740 auto dedupGroupClass = StringAttr::get(context, dedupGroupAnnoClass);
1741
1742 // A map of all the module moduleInfo that we have calculated so far.
1743 DenseMap<ModuleInfoRef, Operation *> moduleInfoToModule;
1744 DenseMap<Operation *, ModuleInfoRef> moduleToModuleInfo;
1745
1746 // We track the name of the module that each module is deduped into, so that
1747 // we can make sure all modules which are marked "must dedup" with each
1748 // other were all deduped to the same module.
1749 DenseMap<Attribute, StringAttr> dedupMap;
1750
1751 // We must iterate the modules from the bottom up so that we can properly
1752 // deduplicate the modules. We copy the list of modules into a vector first
1753 // to avoid iterator invalidation while we mutate the instance graph.
1754 SmallVector<FModuleLike, 0> modules;
1755 instanceGraph.walkPostOrder([&](auto &node) {
1756 if (auto mod = dyn_cast<FModuleLike>(*node.getModule()))
1757 modules.push_back(mod);
1758 });
1759 LLVM_DEBUG(llvm::dbgs() << "Found " << modules.size() << " modules\n");
1760
1761 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1762 StructuralHasherSharedConstants hasherConstants(&getContext());
1763
1764 // Attribute name used to store dedup_group for this pass.
1765 auto dedupGroupAttrName = StringAttr::get(context, "firrtl.dedup_group");
1766
1767 // Move dedup group annotations to attributes on the module.
1768 // This results in the desired behavior (included in hash),
1769 // and avoids unnecessary processing of these as annotations
1770 // that need to be tracked, made non-local, so on.
1771 for (auto module : modules) {
1772 llvm::SmallSetVector<StringAttr, 1> groups;
1774 module, [&groups, dedupGroupClass](Annotation annotation) {
1775 if (annotation.getClassAttr() != dedupGroupClass)
1776 return false;
1777 groups.insert(annotation.getMember<StringAttr>("group"));
1778 return true;
1779 });
1780 if (groups.size() > 1) {
1781 module.emitError("module belongs to multiple dedup groups: ") << groups;
1782 return signalPassFailure();
1783 }
1784 assert(!module->hasAttr(dedupGroupAttrName) &&
1785 "unexpected existing use of temporary dedup group attribute");
1786 if (!groups.empty())
1787 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1788 }
1789
1790 // Calculate module information parallelly.
1791 LLVM_DEBUG(llvm::dbgs() << "Computing module information\n");
1792 auto result = mlir::failableParallelForEach(
1793 context, llvm::seq(modules.size()), [&](unsigned idx) {
1794 auto module = modules[idx];
1795 // If the module is marked with NoDedup, just skip it.
1796 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1797 return success();
1798
1799 // Only dedup extmodule's with defname.
1800 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1801 ext && !ext.getDefname().has_value())
1802 return success();
1803
1804 // Only dedup classes if enabled.
1805 if (isa<ClassOp>(*module) && !dedupClasses)
1806 return success();
1807
1808 StructuralHasher hasher(hasherConstants);
1809 // Calculate the hash of the module and referred module names.
1810 moduleInfos[idx] = hasher.getModuleInfo(module);
1811 return success();
1812 });
1813
1814 // Dump out the module hashes for debugging.
1815 LLVM_DEBUG({
1816 auto &os = llvm::dbgs();
1817 for (auto [module, info] : llvm::zip(modules, moduleInfos)) {
1818 os << "- Hash ";
1819 if (info) {
1820 os << llvm::format_bytes(info->structuralHash, std::nullopt, 32, 32);
1821 } else {
1822 os << "--------------------------------";
1823 os << "--------------------------------";
1824 }
1825 os << " for " << module.getModuleNameAttr() << "\n";
1826 }
1827 });
1828
1829 if (result.failed())
1830 return signalPassFailure();
1831
1832 LLVM_DEBUG(llvm::dbgs() << "Update modules\n");
1833 for (auto [i, module] : llvm::enumerate(modules)) {
1834 auto moduleName = module.getModuleNameAttr();
1835 auto &maybeModuleInfo = moduleInfos[i];
1836 // If the hash was not calculated, we need to skip it.
1837 if (!maybeModuleInfo) {
1838 // We record it in the dedup map to help detect errors when the user
1839 // marks the module as both NoDedup and MustDedup. We do not record this
1840 // module in the hasher to make sure no other module dedups "into" this
1841 // one.
1842 dedupMap[moduleName] = moduleName;
1843 continue;
1844 }
1845
1846 auto &moduleInfo = maybeModuleInfo.value();
1847 moduleToModuleInfo.try_emplace(module, &moduleInfo);
1848
1849 // Replace module names referred in the module with new names.
1850 for (auto &referredModule : moduleInfo.referredModuleNames)
1851 referredModule = dedupMap[referredModule];
1852
1853 // Check if there is a module with the same hash.
1854 auto it = moduleInfoToModule.find(&moduleInfo);
1855 if (it != moduleInfoToModule.end()) {
1856 auto original = cast<FModuleLike>(it->second);
1857 auto originalName = original.getModuleNameAttr();
1858
1859 // If the current module is public, and the original is private, we
1860 // want to dedup the private module into the public one.
1861 if (!canRemoveModule(module)) {
1862 // Record that this module's name is staying the same.
1863 dedupMap[moduleName] = moduleName;
1864 // If both modules are public, then we can't dedup anything.
1865 if (!canRemoveModule(original))
1866 continue;
1867 // Swap the canonical module in the dedup map.
1868 for (auto &[_, dedupedName] : dedupMap)
1869 if (dedupedName == originalName)
1870 dedupedName = moduleName;
1871 // Update the module hash table to point to the new original, so all
1872 // future modules dedup with the new canonical module.
1873 it->second = module;
1874 // Swap the locals.
1875 std::swap(originalName, moduleName);
1876 std::swap(original, module);
1877 }
1878
1879 // Record the group ID of the other module.
1880 LLVM_DEBUG(llvm::dbgs() << "- Replace " << moduleName << " with "
1881 << originalName << "\n");
1882 dedupMap[moduleName] = originalName;
1883 deduper.dedup(original, module);
1884 ++erasedModules;
1885 anythingChanged = true;
1886 continue;
1887 }
1888 // Any module not deduplicated must be recorded.
1889 deduper.record(module);
1890 // Add the module to a new dedup group.
1891 dedupMap[moduleName] = moduleName;
1892 // Record the module info.
1893 moduleInfoToModule[&moduleInfo] = module;
1894 }
1895
1896 // This part verifies that all modules marked by "MustDedup" have been
1897 // properly deduped with each other. For this check to succeed, all modules
1898 // have to been deduped to the same module. It is possible that a module was
1899 // deduped with the wrong thing.
1900
1901 auto failed = false;
1902 // This parses the module name out of a target string.
1903 auto parseModule = [&](Attribute path) -> StringAttr {
1904 // Each module is listed as a target "~Circuit|Module" which we have to
1905 // parse.
1906 auto [_, rhs] = cast<StringAttr>(path).getValue().split('|');
1907 return StringAttr::get(context, rhs);
1908 };
1909 // This gets the name of the module which the current module was deduped
1910 // with. If the named module isn't in the map, then we didn't encounter it
1911 // in the circuit.
1912 auto getLead = [&](StringAttr module) -> StringAttr {
1913 auto it = dedupMap.find(module);
1914 if (it == dedupMap.end()) {
1915 auto diag = emitError(circuit.getLoc(),
1916 "MustDeduplicateAnnotation references module ")
1917 << module << " which does not exist";
1918 failed = true;
1919 return nullptr;
1920 }
1921 return it->second;
1922 };
1923
1924 LLVM_DEBUG(llvm::dbgs() << "Update annotations\n");
1925 AnnotationSet::removeAnnotations(circuit, [&](Annotation annotation) {
1926 if (!annotation.isClass(mustDeduplicateAnnoClass))
1927 return false;
1928 auto modules = annotation.getMember<ArrayAttr>("modules");
1929 if (!modules) {
1930 emitError(circuit.getLoc(),
1931 "MustDeduplicateAnnotation missing \"modules\" member");
1932 failed = true;
1933 return false;
1934 }
1935 // Empty module list has nothing to process.
1936 if (modules.empty())
1937 return true;
1938 // Get the first element.
1939 auto firstModule = parseModule(modules[0]);
1940 auto firstLead = getLead(firstModule);
1941 if (!firstLead)
1942 return false;
1943 // Verify that the remaining elements are all the same as the first.
1944 for (auto attr : modules.getValue().drop_front()) {
1945 auto nextModule = parseModule(attr);
1946 auto nextLead = getLead(nextModule);
1947 if (!nextLead)
1948 return false;
1949 if (firstLead != nextLead) {
1950 auto diag = emitError(circuit.getLoc(), "module ")
1951 << nextModule << " not deduplicated with " << firstModule;
1952 auto a = instanceGraph.lookup(firstLead)->getModule();
1953 auto b = instanceGraph.lookup(nextLead)->getModule();
1954 equiv.check(diag, a, b);
1955 failed = true;
1956 return false;
1957 }
1958 }
1959 return true;
1960 });
1961 if (failed)
1962 return signalPassFailure();
1963
1964 // Remove all dedup group attributes, they only exist during this pass.
1965 for (auto module : circuit.getOps<FModuleLike>())
1966 module->removeDiscardableAttr(dedupGroupAttrName);
1967
1968 // Fixup all operations that we've found to be sensitive to symbol names.
1969 // This includes module and class instances, wires, connects, etc.
1970 fixupSymbolSensitiveOps(instanceGraph, moduleToModuleInfo, dedupMap);
1971
1972 markAnalysesPreserved<NLATable>();
1973 if (!anythingChanged)
1974 markAllAnalysesPreserved();
1975 }
1976};
1977} // end anonymous namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location mergeLoc(MLIRContext *context, Location to, Location from)
Definition Dedup.cpp:969
static void fixupConnect(ImplicitLocOpBuilder &builder, Value dst, Value src)
This fixes up connects when the field names of a bundle type changes.
Definition Dedup.cpp:1595
static bool canRemoveModule(mlir::SymbolOpInterface symbol)
Returns true if the module can be removed.
Definition Dedup.cpp:56
static void fixupSymbolSensitiveOp(Operation *op, InstanceGraph &instanceGraph, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in an op.
Definition Dedup.cpp:1621
static void fixupSymbolSensitiveOps(InstanceGraph &instanceGraph, const DenseMap< Operation *, ModuleInfoRef > &moduleToModuleInfo, const DenseMap< Attribute, StringAttr > &dedupMap)
Adjust the symbol references in ops marked as sensitive to them.
Definition Dedup.cpp:1689
static void mergeRegions(Region *region1, Region *region2)
Definition HWCleanup.cpp:77
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Definition NLATable.h:29
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
decltype(auto) walkPostOrder(Fn &&fn)
Perform a post-order walk across the modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
Definition FIRRTLEnums.h:44
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(const llvm::Twine &str, unsigned width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition Debug.cpp:17
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Definition Dedup.cpp:1157
MLIRContext * context
Definition Dedup.cpp:1563
Block * nlaBlock
We insert all NLAs to the beginning of this block.
Definition Dedup.cpp:1571
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
Definition Dedup.cpp:1099
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
Definition Dedup.cpp:1221
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
Definition Dedup.cpp:1402
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
Definition Dedup.cpp:1115
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
Definition Dedup.cpp:1073
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
Definition Dedup.cpp:1300
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
Definition Dedup.cpp:1555
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
Definition Dedup.cpp:1038
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
Definition Dedup.cpp:1292
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
Definition Dedup.cpp:1194
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
Definition Dedup.cpp:1092
NLATable * nlaTable
Cached nla table analysis.
Definition Dedup.cpp:1568
hw::InnerSymAttr mergeInnerSymbols(RenameMap &renameMap, FModuleLike toModule, hw::InnerSymAttr toSym, hw::InnerSymAttr fromSym)
Definition Dedup.cpp:1427
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
Definition Dedup.cpp:1202
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Definition Dedup.cpp:1477
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either a operation or a port on an operation.
Definition Dedup.cpp:1379
StringAttr nonLocalString
Definition Dedup.cpp:1581
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
Definition Dedup.cpp:1085
SymbolTable & symbolTable
Definition Dedup.cpp:1565
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
Definition Dedup.cpp:1522
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
Definition Dedup.cpp:1585
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
Definition Dedup.cpp:1308
InstanceGraph & instanceGraph
Definition Dedup.cpp:1564
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
Definition Dedup.cpp:1541
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
Definition Dedup.cpp:1574
StringAttr classString
Definition Dedup.cpp:1582
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Definition Dedup.cpp:1350
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
Definition Dedup.cpp:1022
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
Definition Dedup.cpp:1231
DenseMap< StringAttr, StringAttr > RenameMap
Definition Dedup.cpp:1020
DenseMap< Attribute, Attribute > nlaCache
Definition Dedup.cpp:1578
const hw::InnerSymbolTable & a
Definition Dedup.cpp:488
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
Definition Dedup.cpp:485
const hw::InnerSymbolTable & b
Definition Dedup.cpp:489
This class is for reporting differences between two modules which should have been deduplicated.
Definition Dedup.cpp:467
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:956
std::string prettyPrint(Attribute attr)
Definition Dedup.cpp:492
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
Definition Dedup.cpp:775
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
Definition Dedup.cpp:564
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
Definition Dedup.cpp:538
StringAttr noDedupClass
Definition Dedup.cpp:951
StringAttr dedupGroupAttrName
Definition Dedup.cpp:953
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
Definition Dedup.cpp:687
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
Definition Dedup.cpp:642
StringAttr portDirectionsAttr
Definition Dedup.cpp:949
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
Definition Dedup.cpp:508
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Definition Dedup.cpp:813
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
Definition Dedup.cpp:468
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
Definition Dedup.cpp:662
InstanceGraph & instanceGraph
Definition Dedup.cpp:957
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
Definition Dedup.cpp:900
A reference to a ModuleInfo that compares and hashes like it.
Definition Dedup.cpp:422
ModuleInfo * info
Definition Dedup.cpp:424
ModuleInfoRef(ModuleInfo *info)
Definition Dedup.cpp:423
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:91
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:87
std::array< uint8_t, 32 > structuralHash
Definition Dedup.cpp:85
This struct contains constant string attributes shared across different threads.
Definition Dedup.cpp:101
DenseSet< Attribute > nonessentialAttributes
Definition Dedup.cpp:131
StructuralHasherSharedConstants(MLIRContext *context)
Definition Dedup.cpp:102
void populateInnerSymIDTable(FModuleLike module)
Find all the ports and operations which may define an inner symbol operations and give each a unique ...
Definition Dedup.cpp:151
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
Definition Dedup.cpp:274
void update(Type type)
Definition Dedup.cpp:251
void update(const void *pointer)
Definition Dedup.cpp:208
void update(ClassType type)
Definition Dedup.cpp:236
DenseMap< void *, unsigned > idTable
Definition Dedup.cpp:394
void update(const std::pair< T, U > &pair)
Definition Dedup.cpp:219
void update(Operation *op)
Definition Dedup.cpp:360
llvm::SHA256 sha
Definition Dedup.cpp:408
DenseMap< StringAttr, std::pair< size_t, size_t > > innerSymIDTable
Definition Dedup.cpp:398
ModuleInfo getModuleInfo(FModuleLike module)
Definition Dedup.cpp:138
void update(size_t value)
Definition Dedup.cpp:213
void update(BundleType type)
Definition Dedup.cpp:227
unsigned getID(void *object)
Definition Dedup.cpp:170
void update(OpResult result)
Definition Dedup.cpp:259
void update(OpOperand &operand)
Definition Dedup.cpp:191
StructuralHasher(const StructuralHasherSharedConstants &constants)
Definition Dedup.cpp:135
std::vector< Operation * > symbolSensitiveOps
Definition Dedup.cpp:416
void update(Region *region)
Definition Dedup.cpp:352
void update(Block *block)
Definition Dedup.cpp:341
std::vector< StringAttr > referredModuleNames
Definition Dedup.cpp:401
void update(TypeID typeID)
Definition Dedup.cpp:224
const StructuralHasherSharedConstants & constants
Definition Dedup.cpp:404
std::pair< size_t, size_t > getInnerSymID(StringAttr name)
Definition Dedup.cpp:187
unsigned finalizeID(void *object)
Definition Dedup.cpp:178
void update(mlir::OperationName name)
Definition Dedup.cpp:335
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static bool isEqual(const ModuleInfoRef &lhs, const ModuleInfoRef &rhs)
Definition Dedup.cpp:451
static ModuleInfoRef getTombstoneKey()
Definition Dedup.cpp:435
static ModuleInfoRef getEmptyKey()
Definition Dedup.cpp:431
static unsigned getHashValue(const ModuleInfoRef &ref)
Definition Dedup.cpp:439