24#include "mlir/IR/IRMapping.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/ADT/DenseMap.h"
28#include "llvm/ADT/DenseMapInfo.h"
29#include "llvm/ADT/Hashing.h"
30#include "llvm/ADT/PostOrderIterator.h"
31#include "llvm/ADT/SmallPtrSet.h"
32#include "llvm/Support/Format.h"
33#include "llvm/Support/SHA256.h"
37#define GEN_PASS_DEF_DEDUP
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
43using namespace firrtl;
44using hw::InnerRefAttr;
53 if (!symbol.isPrivate())
58 if (isa<ClassLike>(*symbol))
63 if (!symbol.canDiscardOnUseEmpty())
73llvm::raw_ostream &
printHex(llvm::raw_ostream &stream,
74 ArrayRef<uint8_t> bytes) {
76 return stream << format_bytes(bytes, std::nullopt, 32) <<
"\n";
79llvm::raw_ostream &
printHash(llvm::raw_ostream &stream, llvm::SHA256 &
data) {
83llvm::raw_ostream &
printHash(llvm::raw_ostream &stream, std::string
data) {
84 ArrayRef<uint8_t> bytes(
reinterpret_cast<const uint8_t *
>(
data.c_str()),
146 : constants(constants) {}
164 auto it =
idTable.find(
object);
167 auto id = it->second;
180 auto value = operand.get();
181 if (
auto result = dyn_cast<OpResult>(value)) {
182 auto *op = result.getOwner();
184 update(result.getResultNumber());
187 if (
auto argument = dyn_cast<BlockArgument>(value)) {
188 auto *block = argument.getOwner();
190 update(argument.getArgNumber());
193 llvm_unreachable(
"Unknown value type");
197 auto *
addr =
reinterpret_cast<const uint8_t *
>(&pointer);
198 sha.update(ArrayRef<uint8_t>(
addr,
sizeof pointer));
202 auto *
addr =
reinterpret_cast<const uint8_t *
>(&value);
203 sha.update(ArrayRef<uint8_t>(
addr,
sizeof value));
211 for (
auto &element : type.getElements()) {
219 if (
auto bundle = type_dyn_cast<BundleType>(type))
221 update(type.getAsOpaquePointer());
228 if (
auto objectOp = dyn_cast<ObjectOp>(result.getOwner())) {
238 void update(Operation *op, DictionaryAttr dict) {
239 for (
auto namedAttr : dict) {
240 auto name = namedAttr.getName();
241 auto value = namedAttr.getValue();
244 bool isClassPortNames =
250 update(name.getAsOpaquePointer());
254 auto portTypes = cast<ArrayAttr>(value).getAsValueRange<TypeAttr>();
255 for (
auto type : portTypes)
262 if (op->getNumRegions() != 1)
264 auto ®ion = op->getRegion(0);
265 if (region.getBlocks().empty())
267 for (
auto sym : cast<ArrayAttr>(value).getAsRange<hw::InnerSymAttr>()) {
268 for (
auto property : sym) {
269 update(property.getFieldID());
277 auto innerSym = cast<hw::InnerSymAttr>(value);
278 for (
auto property : innerSym) {
279 update(property.getFieldID());
297 if (isa<DistinctAttr>(value))
301 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(value))
304 update(value.getAsOpaquePointer());
310 update(name.getAsOpaquePointer());
315 for (
auto &op : llvm::reverse(*block))
317 for (
auto type : block->getArgumentTypes())
326 for (
auto &block : llvm::reverse(region->getBlocks()))
336 update(op->getNumRegions());
337 for (
auto ®ion : reverse(op->getRegions()))
343 for (
auto &operand : op->getOpOperands())
348 update(op, op->getAttrDictionary());
351 for (
auto result : op->getResults())
414 SmallString<64> buffer;
415 llvm::raw_svector_ostream os(buffer);
416 if (
auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
418 if (integerAttr.getType().isSignlessInteger())
419 integerAttr.getValue().toStringUnsigned(buffer, 16);
421 integerAttr.getAPSInt().toString(buffer, 16);
425 return std::string(buffer);
429 LogicalResult
check(InFlightDiagnostic &diag,
const Twine &message,
430 Operation *a, BundleType aType, Operation *b,
432 if (aType.getNumElements() != bType.getNumElements()) {
433 diag.attachNote(a->getLoc())
434 << message <<
" bundle type has different number of elements";
435 diag.attachNote(b->getLoc()) <<
"second operation here";
439 for (
auto elementPair :
440 llvm::zip(aType.getElements(), bType.getElements())) {
441 auto aElement = std::get<0>(elementPair);
442 auto bElement = std::get<1>(elementPair);
443 if (aElement.isFlip != bElement.isFlip) {
444 diag.attachNote(a->getLoc()) << message <<
" bundle element "
445 << aElement.name <<
" flip does not match";
446 diag.attachNote(b->getLoc()) <<
"second operation here";
450 if (failed(
check(diag,
451 "bundle element \'" + aElement.name.getValue() +
"'", a,
452 aElement.type, b, bElement.type)))
458 LogicalResult
check(InFlightDiagnostic &diag,
const Twine &message,
459 Operation *a, Type aType, Operation *b, Type bType) {
462 if (
auto aBundleType = type_dyn_cast<BundleType>(aType))
463 if (
auto bBundleType = type_dyn_cast<BundleType>(bType))
464 return check(diag, message, a, aBundleType, b, bBundleType);
465 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
467 diag.attachNote(a->getLoc())
468 << message <<
", has a RefType with a different base type "
469 << type_cast<RefType>(aType).getType()
470 <<
" in the same position of the two modules marked as 'must dedup'. "
471 "(This may be due to Grand Central Taps or Views being different "
472 "between the two modules.)";
473 diag.attachNote(b->getLoc())
474 <<
"the second module has a different base type "
475 << type_cast<RefType>(bType).getType();
478 diag.attachNote(a->getLoc())
479 << message <<
" types don't match, first type is " << aType;
480 diag.attachNote(b->getLoc()) <<
"second type is " << bType;
485 Block &aBlock, Operation *b, Block &bBlock) {
488 auto portNames = a->getAttrOfType<ArrayAttr>(
"portNames");
490 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
491 Operation *opDoesNotExist) {
493 auto portNames = opExists->getAttrOfType<ArrayAttr>(
"portNames");
495 if (
auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
496 portName = portNameAttr.getValue();
497 if (type_isa<RefType>(existsVal.getType())) {
498 diag.attachNote(opExists->getLoc())
499 <<
" contains a RefType port named '" + portName +
500 "' that only exists in one of the modules (can be due to "
501 "difference in Grand Central Tap or View of two modules "
502 "marked with must dedup)";
503 diag.attachNote(opDoesNotExist->getLoc())
504 <<
"second module to be deduped that does not have the RefType "
507 diag.attachNote(opExists->getLoc())
508 <<
"port '" + portName +
"' only exists in one of the modules";
509 diag.attachNote(opDoesNotExist->getLoc())
510 <<
"second module to be deduped that does not have the port";
516 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
517 auto &aArg = std::get<0>(argPair);
518 auto &bArg = std::get<1>(argPair);
519 if (aArg.has_value() && bArg.has_value()) {
524 if (
auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
525 portName = portNameAttr.getValue();
528 if (failed(
check(diag,
"module port '" + portName +
"'", a,
529 aArg->getType(), b, bArg->getType())))
531 data.map.map(aArg.value(), bArg.value());
535 if (!aArg.has_value())
537 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
542 auto aIt = aBlock.begin();
543 auto aEnd = aBlock.end();
544 auto bIt = bBlock.begin();
545 auto bEnd = bBlock.end();
546 while (aIt != aEnd && bIt != bEnd)
547 if (failed(
check(diag,
data, &*aIt++, &*bIt++)))
550 diag.attachNote(aIt->getLoc()) <<
"first block has more operations";
551 diag.attachNote(b->getLoc()) <<
"second block here";
555 diag.attachNote(bIt->getLoc()) <<
"second block has more operations";
556 diag.attachNote(a->getLoc()) <<
"first block here";
563 Region &aRegion, Operation *b, Region &bRegion) {
564 auto aIt = aRegion.begin();
565 auto aEnd = aRegion.end();
566 auto bIt = bRegion.begin();
567 auto bEnd = bRegion.end();
570 while (aIt != aEnd && bIt != bEnd)
571 if (failed(
check(diag,
data, a, *aIt++, b, *bIt++)))
573 if (aIt != aEnd || bIt != bEnd) {
574 diag.attachNote(a->getLoc())
575 <<
"operation regions have different number of blocks";
576 diag.attachNote(b->getLoc()) <<
"second operation here";
582 LogicalResult
check(InFlightDiagnostic &diag, Operation *a,
583 mlir::DenseBoolArrayAttr aAttr, Operation *b,
584 mlir::DenseBoolArrayAttr bAttr) {
587 auto portNames = a->getAttrOfType<ArrayAttr>(
"portNames");
588 for (
unsigned i = 0, e = aAttr.size(); i < e; ++i) {
589 auto aDirection = aAttr[i];
590 auto bDirection = bAttr[i];
591 if (aDirection != bDirection) {
592 auto ¬e = diag.attachNote(a->getLoc()) <<
"module port ";
594 note <<
"'" << cast<StringAttr>(portNames[i]).getValue() <<
"'";
597 note <<
" directions don't match, first direction is '"
599 diag.attachNote(b->getLoc()) <<
"second direction is '"
608 DictionaryAttr aDict, Operation *b,
609 DictionaryAttr bDict) {
614 DenseSet<Attribute> seenAttrs;
615 for (
auto namedAttr : aDict) {
616 auto attrName = namedAttr.getName();
620 auto aAttr = namedAttr.getValue();
621 auto bAttr = bDict.get(attrName);
623 diag.attachNote(a->getLoc())
624 <<
"second operation is missing attribute " << attrName;
625 diag.attachNote(b->getLoc()) <<
"second operation here";
629 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
630 auto bRef = cast<hw::InnerRefAttr>(bAttr);
631 auto aRef = cast<hw::InnerRefAttr>(aAttr);
633 auto aTarget =
data.a.lookup(aRef.getName());
634 auto bTarget =
data.b.lookup(bRef.getName());
635 if (!aTarget || !bTarget)
636 diag.attachNote(a->getLoc())
637 <<
"malformed ir, possibly violating use-before-def";
639 diag.attachNote(a->getLoc())
640 <<
"operations have different targets, first operation has "
642 diag.attachNote(b->getLoc()) <<
"second operation has " << bTarget;
645 if (aTarget.isPort()) {
647 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
651 if (!bTarget.isOpOnly() ||
652 aTarget.getOp() !=
data.map.lookup(bTarget.getOp()))
655 if (aTarget.getField() != bTarget.getField())
660 if (failed(
check(diag, a, cast<mlir::DenseBoolArrayAttr>(aAttr), b,
661 cast<mlir::DenseBoolArrayAttr>(bAttr))))
663 }
else if (isa<DistinctAttr>(aAttr) && isa<DistinctAttr>(bAttr)) {
666 }
else if (aAttr != bAttr) {
667 diag.attachNote(a->getLoc())
668 <<
"first operation has attribute '" << attrName.getValue()
670 diag.attachNote(b->getLoc())
671 <<
"second operation has value " <<
prettyPrint(bAttr);
674 seenAttrs.insert(attrName);
676 if (aDict.getValue().size() != bDict.getValue().size()) {
677 for (
auto namedAttr : bDict) {
678 auto attrName = namedAttr.getName();
682 seenAttrs.contains(attrName))
685 diag.attachNote(a->getLoc())
686 <<
"first operation is missing attribute " << attrName;
687 diag.attachNote(b->getLoc()) <<
"second operation here";
695 LogicalResult
check(InFlightDiagnostic &diag, FInstanceLike a,
697 auto aName = a.getReferencedModuleNameAttr();
698 auto bName = b.getReferencedModuleNameAttr();
708 diag.attachNote(std::nullopt)
709 <<
"in instance " << a.getInstanceNameAttr() <<
" of " << aName
710 <<
", and instance " << b.getInstanceNameAttr() <<
" of " << bName;
711 check(diag, aModule, bModule);
719 if (a->getName() != b->getName()) {
720 diag.attachNote(a->getLoc()) <<
"first operation is a " << a->getName();
721 diag.attachNote(b->getLoc()) <<
"second operation is a " << b->getName();
727 if (
auto aInst = dyn_cast<FInstanceLike>(a)) {
728 auto bInst = cast<FInstanceLike>(b);
729 if (failed(
check(diag, aInst, bInst)))
734 if (a->getNumResults() != b->getNumResults()) {
735 diag.attachNote(a->getLoc())
736 <<
"operations have different number of results";
737 diag.attachNote(b->getLoc()) <<
"second operation here";
740 for (
auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
741 auto &aValue = std::get<0>(resultPair);
742 auto &bValue = std::get<1>(resultPair);
743 if (failed(
check(diag,
"operation result", a, aValue.getType(), b,
746 data.map.map(aValue, bValue);
750 if (a->getNumOperands() != b->getNumOperands()) {
751 diag.attachNote(a->getLoc())
752 <<
"operations have different number of operands";
753 diag.attachNote(b->getLoc()) <<
"second operation here";
756 for (
auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
757 auto &aValue = std::get<0>(operandPair);
758 auto &bValue = std::get<1>(operandPair);
759 if (bValue !=
data.map.lookup(aValue)) {
760 diag.attachNote(a->getLoc())
761 <<
"operations use different operands, first operand is '"
766 diag.attachNote(b->getLoc())
767 <<
"second operand is '"
771 <<
"', but should have been '"
782 if (a->getNumRegions() != b->getNumRegions()) {
783 diag.attachNote(a->getLoc())
784 <<
"operations have different number of regions";
785 diag.attachNote(b->getLoc()) <<
"second operation here";
788 for (
auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
789 auto &aRegion = std::get<0>(regionPair);
790 auto &bRegion = std::get<1>(regionPair);
791 if (failed(
check(diag,
data, a, aRegion, b, bRegion)))
796 if (failed(
check(diag,
data, a, a->getAttrDictionary(), b,
797 b->getAttrDictionary())))
803 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
808 diag.attachNote(a->getLoc()) <<
"module marked NoDedup";
812 diag.attachNote(b->getLoc()) <<
"module marked NoDedup";
815 auto aSymbol = cast<mlir::SymbolOpInterface>(a);
816 auto bSymbol = cast<mlir::SymbolOpInterface>(b);
818 diag.attachNote(a->getLoc())
820 << (aSymbol.isPrivate() ?
"private but not discardable" :
"public");
824 diag.attachNote(b->getLoc())
826 << (bSymbol.isPrivate() ?
"private but not discardable" :
"public");
831 auto bGroup = dyn_cast_or_null<StringAttr>(
833 if (aGroup != bGroup) {
835 diag.attachNote(b->getLoc())
836 <<
"module is in dedup group '" << bGroup.str() <<
"'";
838 diag.attachNote(b->getLoc()) <<
"module is not part of a dedup group";
841 diag.attachNote(a->getLoc())
842 <<
"module is in dedup group '" << aGroup.str() <<
"'";
844 diag.attachNote(a->getLoc()) <<
"module is not part of a dedup group";
850 diag.attachNote(a->getLoc()) <<
"first module here";
851 diag.attachNote(b->getLoc()) <<
"second module here";
875static Location
mergeLoc(MLIRContext *context, Location to, Location from) {
877 llvm::SmallSetVector<Location, 4> decomposedLocs;
879 unsigned seenFIR = 0;
880 for (
auto loc : {to, from}) {
883 if (
auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
886 for (
auto loc : fusedLoc.getLocations()) {
887 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
888 if (fileLoc.getFilename().strref().ends_with(
".fir")) {
894 decomposedLocs.insert(loc);
900 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
901 if (fileLoc.getFilename().strref().ends_with(
".fir")) {
908 if (!isa<UnknownLoc>(loc))
909 decomposedLocs.insert(loc);
912 auto locs = decomposedLocs.getArrayRef();
917 return UnknownLoc::get(context);
918 if (locs.size() == 1)
921 return FusedLoc::get(context, locs);
936 for (
auto nla : circuit.getOps<hw::HierPathOp>())
937 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
944 void dedup(FModuleLike toModule, FModuleLike fromModule) {
950 SmallVector<Attribute> newLocs;
951 for (
auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
952 fromModule.getPortLocations())) {
953 if (toLoc == fromLoc)
954 newLocs.push_back(toLoc);
957 cast<LocationAttr>(fromLoc)));
959 toModule->setAttr(
"portLocations", ArrayAttr::get(
context, newLocs));
962 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
968 if (
auto to = dyn_cast<FModuleOp>(*toModule))
972 fromModule.getModuleNameAttr());
983 for (
unsigned i = 0, e =
getNumPorts(module); i < e; ++i)
986 module->walk([&](Operation *op) { recordAnnotations(op); });
992 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
1000 if (
auto nlaRef = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
1001 targetMap[nlaRef.getAttr()].insert(target);
1010 auto mem = dyn_cast<MemOp>(op);
1015 for (
unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
1024 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
1025 auto *toNode = instanceGraph[toModule];
1026 auto toModuleRef = FlatSymbolRefAttr::get(toModule.getModuleNameAttr());
1027 for (
auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
1028 auto inst = oldInstRec->getInstance();
1029 if (
auto instOp = dyn_cast<InstanceOp>(*inst)) {
1030 instOp.setModuleNameAttr(toModuleRef);
1031 instOp.setPortNamesAttr(toModule.getPortNamesAttr());
1032 }
else if (
auto objectOp = dyn_cast<ObjectOp>(*inst)) {
1033 auto classLike = cast<ClassLike>(*toNode->getModule());
1034 ClassType classType = detail::getInstanceTypeForClassLike(classLike);
1035 objectOp.getResult().setType(classType);
1037 oldInstRec->getParent()->addInstance(inst, toNode);
1038 oldInstRec->erase();
1040 instanceGraph.erase(fromNode);
1041 fromModule->erase();
1049 SmallVector<FlatSymbolRefAttr>
1050 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
1051 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1054 SmallVector<Attribute> namepath = {
nullptr};
1055 namepath.append(baseNamepath.begin(), baseNamepath.end());
1057 auto loc = fromModule->getLoc();
1058 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
1059 SmallVector<FlatSymbolRefAttr> nlas;
1060 for (
auto *instanceRecord : fromNode->uses()) {
1061 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
1062 auto inst = instanceRecord->getInstance();
1064 auto arrayAttr = ArrayAttr::get(context, namepath);
1066 auto &cacheEntry = nlaCache[arrayAttr];
1068 auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
1069 loc,
"nla", arrayAttr);
1071 symbolTable.insert(nla);
1073 cacheEntry = nla.getNameAttr();
1074 nla.setVisibility(vis);
1075 nlaTable->addNLA(nla);
1077 auto nlaRef = FlatSymbolRefAttr::get(cast<StringAttr>(cacheEntry));
1078 nlas.push_back(nlaRef);
1086 SmallVector<FlatSymbolRefAttr>
1088 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1089 return createNLAs(fromModule, FlatSymbolRefAttr::get(toModuleName), vis);
1096 Annotation anno, ArrayRef<NamedAttribute> attributes,
1097 unsigned nonLocalIndex,
1098 SmallVectorImpl<Annotation> &newAnnotations) {
1099 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1101 for (
auto &nla : nlas) {
1103 mutableAttributes[nonLocalIndex].setValue(nla);
1104 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1107 newAnnotations.push_back(anno);
1116 targetMap.erase(nla.getNameAttr());
1117 nlaTable->erase(nla);
1118 nlaCache.erase(nla.getNamepathAttr());
1119 symbolTable.erase(nla);
1125 FModuleOp fromModule) {
1126 auto toName = toModule.getNameAttr();
1127 auto fromName = fromModule.getNameAttr();
1130 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1132 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1135 for (
auto nla : moduleNLAs) {
1136 auto elements = nla.getNamepath().getValue();
1138 if (nla.root() != toName)
1141 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1142 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1144 auto &set = targetMap[nla.getSymNameAttr()];
1145 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1147 for (
auto target : targets) {
1150 SmallVector<Annotation> newAnnotations;
1151 for (
auto anno : target.getAnnotations()) {
1153 auto [it, found] = mlir::impl::findAttrSorted(
1154 anno.begin(), anno.end(), nonLocalString);
1157 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1158 nla.getSymNameAttr()) {
1159 newAnnotations.push_back(anno);
1162 auto nonLocalIndex = std::distance(anno.begin(), it);
1164 cloneAnnotation(nlaRefs, anno,
1165 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1166 nonLocalIndex, newAnnotations);
1171 target.setAnnotations(annotations);
1173 for (
auto nla : nlaRefs)
1174 targetMap[nla.getAttr()].insert(target);
1186 FModuleOp fromModule) {
1187 addAnnotationContext(renameMap, toModule, toModule);
1188 addAnnotationContext(renameMap, toModule, fromModule);
1194 StringAttr fromName) {
1195 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1203 SmallVectorImpl<Annotation> &newAnnotations) {
1206 SmallVector<NamedAttribute> attributes;
1207 int nonLocalIndex = -1;
1208 for (
const auto &val : llvm::enumerate(anno)) {
1209 auto attr = val.value();
1211 auto compare = attr.getName().compare(nonLocalString);
1212 assert(compare != 0 &&
"should not pass non-local annotations here");
1216 nonLocalIndex = val.index();
1217 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1222 attributes.push_back(attr);
1224 if (nonLocalIndex == -1) {
1226 nonLocalIndex = attributes.size();
1227 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1230 attributes.append(anno.
begin() + nonLocalIndex, anno.
end());
1234 auto nlaRefs = createNLAs(toModuleName, fromModule);
1235 for (
auto nla : nlaRefs)
1236 targetMap[nla.getAttr()].insert(to);
1239 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1245 SmallVectorImpl<Annotation> &newAnnotations,
1246 SmallPtrSetImpl<Attribute> &dontTouches) {
1247 for (
auto anno : annos) {
1251 anno.removeMember(
"circt.nonlocal");
1252 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1254 newAnnotations.push_back(anno);
1259 if (
auto nla = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal")) {
1260 newAnnotations.push_back(anno);
1261 targetMap[nla.getAttr()].insert(to);
1265 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1276 SmallVector<Annotation> newAnnotations;
1280 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1284 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1286 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1290 if (!newAnnotations.empty())
1296 FModuleLike fromModule, Operation *from) {
1302 if (toModule == to) {
1304 for (
unsigned i = 0, e =
getNumPorts(toModule); i < e; ++i)
1309 }
else if (
auto toMem = dyn_cast<MemOp>(to)) {
1311 auto fromMem = cast<MemOp>(from);
1312 for (
unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1324 Operation *to, FModuleLike fromModule,
1331 return getNamespace(toModule);
1333 renameMap[fromSym] = toSym;
1337 auto fromPortSyms = from->getAttrOfType<ArrayAttr>(
"portSymbols");
1338 if (!fromPortSyms || fromPortSyms.empty())
1341 auto &moduleNamespace = getNamespace(toModule);
1342 auto portCount = fromPortSyms.size();
1343 auto portNames = to->getAttrOfType<ArrayAttr>(
"portNames");
1344 auto toPortSyms = to->getAttrOfType<ArrayAttr>(
"portSymbols");
1348 SmallVector<Attribute> newPortSyms;
1349 if (toPortSyms.empty())
1350 newPortSyms.assign(portCount, hw::InnerSymAttr());
1352 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1354 for (
unsigned portNo = 0; portNo < portCount; ++portNo) {
1356 if (!fromPortSyms[portNo])
1358 auto fromSym = cast<hw::InnerSymAttr>(fromPortSyms[portNo]);
1361 hw::InnerSymAttr toSym;
1362 if (!newPortSyms[portNo]) {
1364 StringRef symName =
"inner_sym";
1366 symName = cast<StringAttr>(portNames[portNo]).getValue();
1368 toSym = hw::InnerSymAttr::get(
1369 StringAttr::get(context, moduleNamespace.newName(symName)));
1370 newPortSyms[portNo] = toSym;
1372 toSym = cast<hw::InnerSymAttr>(newPortSyms[portNo]);
1375 renameMap[fromSym.getSymName()] = toSym.getSymName();
1379 FModuleLike::fixupPortSymsArray(newPortSyms, toModule.getContext());
1380 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1386 FModuleLike fromModule, Operation *from) {
1388 if (to->getLoc() != from->getLoc())
1389 to->setLoc(
mergeLoc(context, to->getLoc(), from->getLoc()));
1392 for (
auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1393 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1394 std::get<1>(regions));
1397 recordSymRenames(renameMap, toModule, to, fromModule, from);
1400 mergeAnnotations(toModule, to, fromModule, from);
1405 FModuleLike fromModule, Block &fromBlock) {
1407 for (
auto [toArg, fromArg] :
1408 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1409 if (toArg.getLoc() != fromArg.getLoc())
1410 toArg.setLoc(
mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1412 for (
auto ops : llvm::zip(toBlock, fromBlock))
1413 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1419 Region &toRegion, FModuleLike fromModule,
1420 Region &fromRegion) {
1421 for (
auto blocks : llvm::zip(toRegion, fromRegion))
1422 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1423 std::get<1>(blocks));
1437 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>>
targetMap;
1461 SmallVector<Attribute> newPortTypes;
1462 bool anyDifferences =
false;
1465 for (
size_t i = 0, e = classOp.getNumPorts(); i < e; ++i) {
1468 auto portClassType = dyn_cast<ClassType>(classOp.getPortType(i));
1469 if (!portClassType) {
1470 newPortTypes.push_back(classOp.getPortTypeAttr(i));
1475 Type newPortClassType;
1476 BlockArgument portArg = classOp.getArgument(i);
1477 for (
auto &use : portArg.getUses()) {
1478 if (
auto propassign = dyn_cast<PropAssignOp>(use.getOwner())) {
1479 Type sourceType = propassign.getSrc().getType();
1480 if (propassign.getDest() == use.get() && sourceType != portClassType) {
1482 if (newPortClassType) {
1483 assert(newPortClassType == sourceType &&
1484 "expected all references to be of the same type");
1488 newPortClassType = sourceType;
1495 if (!newPortClassType) {
1496 newPortTypes.push_back(classOp.getPortTypeAttr(i));
1502 classOp.getArgument(i).setType(newPortClassType);
1503 newPortTypes.push_back(TypeAttr::get(newPortClassType));
1504 anyDifferences =
true;
1509 classOp.setPortTypes(newPortTypes);
1511 return anyDifferences;
1518 objectOp.getResult().setType(newClassType);
1526 auto dstType = dst.getType();
1527 auto srcType = src.getType();
1528 if (dstType == srcType) {
1534 auto dstBundle = type_cast<BundleType>(dstType);
1535 auto srcBundle = type_cast<BundleType>(srcType);
1536 for (
unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1537 auto dstField = builder.create<SubfieldOp>(dst, i);
1538 auto srcField = builder.create<SubfieldOp>(src, i);
1539 if (dstBundle.getElement(i).isFlip) {
1540 std::swap(srcBundle, dstBundle);
1541 std::swap(srcField, dstField);
1551 for (
auto *node : instanceGraph) {
1552 auto module = cast<FModuleLike>(*node->getModule());
1555 bool shouldFixupObjects =
false;
1556 auto classOp = dyn_cast<ClassOp>(module.getOperation());
1560 for (
auto *instRec : node->uses()) {
1563 if (shouldFixupObjects) {
1565 classOp.getInstanceType());
1570 auto inst = instRec->getInstance<InstanceOp>();
1574 ImplicitLocOpBuilder builder(inst.getLoc(), inst->getContext());
1575 builder.setInsertionPointAfter(inst);
1576 for (
size_t i = 0, e =
getNumPorts(module); i < e; ++i) {
1577 auto result = inst.getResult(i);
1578 auto newType =
module.getPortType(i);
1579 auto oldType = result.getType();
1581 if (newType == oldType)
1586 builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1587 result.replaceAllUsesWith(wire);
1588 result.setType(newType);
1589 if (inst.getPortDirection(i) == Direction::Out)
1606 std::array<uint8_t, 32> key;
1607 std::fill(key.begin(), key.end(), ~0);
1612 std::array<uint8_t, 32> key;
1613 std::fill(key.begin(), key.end(), ~0 - 1);
1621 std::memcpy(&hash, val.structuralHash.data(),
sizeof(
unsigned));
1624 return llvm::hash_combine(
1625 hash, llvm::hash_combine_range(val.referredModuleNames.begin(),
1626 val.referredModuleNames.end()));
1629 static bool isEqual(
const ModuleInfo &lhs,
const ModuleInfo &rhs) {
1640class DedupPass :
public circt::firrtl::impl::DedupBase<DedupPass> {
1641 void runOnOperation()
override {
1642 auto *context = &getContext();
1643 auto circuit = getOperation();
1644 auto &instanceGraph = getAnalysis<InstanceGraph>();
1645 auto *nlaTable = &getAnalysis<NLATable>();
1646 auto &symbolTable = getAnalysis<SymbolTable>();
1647 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1649 auto anythingChanged =
false;
1658 llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1663 DenseMap<Attribute, StringAttr> dedupMap;
1668 SmallVector<FModuleLike, 0> modules(
1669 llvm::map_range(llvm::post_order(&instanceGraph), [](
auto *node) {
1670 return cast<FModuleLike>(*node->getModule());
1673 SmallVector<std::optional<ModuleInfo>> moduleInfos(modules.size());
1677 auto dedupGroupAttrName = StringAttr::get(context,
"firrtl.dedup_group");
1683 for (
auto module : modules) {
1684 llvm::SmallSetVector<StringAttr, 1> groups;
1686 module, [&groups, dedupGroupClass](
Annotation annotation) {
1689 groups.insert(annotation.
getMember<StringAttr>(
"group"));
1692 if (groups.size() > 1) {
1693 module.emitError("module belongs to multiple dedup groups: ") << groups;
1694 return signalPassFailure();
1696 assert(!module->hasAttr(dedupGroupAttrName) &&
1697 "unexpected existing use of temporary dedup group attribute");
1698 if (!groups.empty())
1699 module->setDiscardableAttr(dedupGroupAttrName, groups.front());
1703 auto result = mlir::failableParallelForEach(
1704 context, llvm::seq(modules.size()), [&](
unsigned idx) {
1705 auto module = modules[idx];
1707 if (AnnotationSet::hasAnnotation(module, noDedupClass))
1711 if (auto ext = dyn_cast<FExtModuleOp>(*module);
1712 ext && !ext.getDefname().has_value())
1715 StructuralHasher hasher(hasherConstants);
1717 moduleInfos[idx] = hasher.getModuleInfo(module);
1721 if (result.failed())
1722 return signalPassFailure();
1724 for (
auto [i, module] :
llvm::enumerate(modules)) {
1725 auto moduleName =
module.getModuleNameAttr();
1726 auto &maybeModuleInfo = moduleInfos[i];
1728 if (!maybeModuleInfo) {
1733 dedupMap[moduleName] = moduleName;
1737 auto &moduleInfo = maybeModuleInfo.value();
1740 for (
auto &referredModule : moduleInfo.referredModuleNames)
1741 referredModule = dedupMap[referredModule];
1744 auto it = moduleInfoToModule.find(moduleInfo);
1745 if (it != moduleInfoToModule.end()) {
1746 auto original = cast<FModuleLike>(it->second);
1747 auto originalName = original.getModuleNameAttr();
1756 for (
auto &[originalName, dedupedName] : dedupMap)
1757 if (dedupedName == originalName)
1758 dedupedName = moduleName;
1761 it->second =
module;
1763 std::swap(originalName, moduleName);
1764 std::swap(original, module);
1768 dedupMap[moduleName] = originalName;
1769 deduper.dedup(original, module);
1771 anythingChanged =
true;
1775 deduper.record(module);
1777 dedupMap[moduleName] = moduleName;
1779 moduleInfoToModule[std::move(moduleInfo)] =
module;
1787 auto failed =
false;
1789 auto parseModule = [&](Attribute path) -> StringAttr {
1792 auto [_, rhs] = cast<StringAttr>(path).getValue().split(
'|');
1793 return StringAttr::get(context, rhs);
1798 auto getLead = [&](StringAttr module) -> StringAttr {
1799 auto it = dedupMap.find(module);
1800 if (it == dedupMap.end()) {
1801 auto diag = emitError(circuit.getLoc(),
1802 "MustDeduplicateAnnotation references module ")
1803 <<
module << " which does not exist";
1813 auto modules = annotation.
getMember<ArrayAttr>(
"modules");
1815 emitError(circuit.getLoc(),
1816 "MustDeduplicateAnnotation missing \"modules\" member");
1821 if (modules.empty())
1824 auto firstModule = parseModule(modules[0]);
1825 auto firstLead = getLead(firstModule);
1829 for (
auto attr : modules.getValue().drop_front()) {
1830 auto nextModule = parseModule(attr);
1831 auto nextLead = getLead(nextModule);
1834 if (firstLead != nextLead) {
1835 auto diag = emitError(circuit.getLoc(),
"module ")
1836 << nextModule <<
" not deduplicated with " << firstModule;
1839 equiv.check(diag, a, b);
1847 return signalPassFailure();
1850 for (
auto module : circuit.getOps<FModuleLike>())
1851 module->removeDiscardableAttr(dedupGroupAttrName);
1858 markAnalysesPreserved<NLATable>();
1859 if (!anythingChanged)
1860 markAllAnalysesPreserved();
1866 return std::make_unique<DedupPass>();
assert(baseType &&"element must be base type")
static void mergeRegions(Region *region1, Region *region2)
static Block * getBodyBlock(FModuleLike mod)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
StringAttr getClassAttr() const
Return the 'class' that this annotation is representing.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
A table of inner symbols and their resolutions.
auto getModule()
Get the module that this node is tracking.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
static StringRef toString(Direction direction)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the InnerSymAttr has a symbol on the specified field.
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.).
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
static bool operator==(const ModulePort &a, const ModulePort &b)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
Block * nlaBlock
We insert all NLAs to the beginning of this block.
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
void record(FModuleLike module)
Record the usages of any NLA's in this module, so that we may update the annotation if the parent mod...
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
NLATable * nlaTable
Cached nla table analysis.
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
StringAttr nonLocalString
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
SymbolTable & symbolTable
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
InstanceGraph & instanceGraph
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
DenseMap< StringAttr, StringAttr > RenameMap
DenseMap< Attribute, Attribute > nlaCache
const hw::InnerSymbolTable & a
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
const hw::InnerSymbolTable & b
This class is for reporting differences between two modules which should have been deduplicated.
DenseSet< Attribute > nonessentialAttributes
std::string prettyPrint(Attribute attr)
LogicalResult check(InFlightDiagnostic &diag, FInstanceLike a, FInstanceLike b)
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
StringAttr dedupGroupAttrName
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
StringAttr portDirectionsAttr
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
LogicalResult check(InFlightDiagnostic &diag, Operation *a, mlir::DenseBoolArrayAttr aAttr, Operation *b, mlir::DenseBoolArrayAttr bAttr)
InstanceGraph & instanceGraph
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
std::vector< StringAttr > referredModuleNames
std::array< uint8_t, 32 > structuralHash
This struct contains constant string attributes shared across different threads.
StringAttr portSymbolsAttr
StringAttr moduleNameAttr
DenseSet< Attribute > nonessentialAttributes
StructuralHasherSharedConstants(MLIRContext *context)
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
void update(const void *pointer)
DenseMap< void *, unsigned > idTable
void update(Operation *op)
ModuleInfo getModuleInfo(FModuleLike module)
void update(size_t value)
unsigned getInnerSymID(StringAttr name)
void update(BundleType type)
unsigned getID(void *object)
void update(OpResult result)
void update(OpOperand &operand)
StructuralHasher(const StructuralHasherSharedConstants &constants)
void update(Region *region)
void update(Block *block)
std::vector< StringAttr > referredModuleNames
DenseMap< StringAttr, unsigned > innerSymIDTable
void update(TypeID typeID)
const StructuralHasherSharedConstants & constants
unsigned finalizeID(void *object)
void update(mlir::OperationName name)
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
static ModuleInfo getTombstoneKey()
static unsigned getHashValue(const ModuleInfo &val)
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)