25 #include "mlir/IR/IRMapping.h"
26 #include "mlir/IR/ImplicitLocOpBuilder.h"
27 #include "mlir/IR/Threading.h"
28 #include "mlir/Support/LogicalResult.h"
29 #include "llvm/ADT/DenseMap.h"
30 #include "llvm/ADT/DenseMapInfo.h"
31 #include "llvm/ADT/DepthFirstIterator.h"
32 #include "llvm/ADT/PostOrderIterator.h"
33 #include "llvm/ADT/SmallPtrSet.h"
34 #include "llvm/ADT/TypeSwitch.h"
35 #include "llvm/Support/Format.h"
36 #include "llvm/Support/SHA256.h"
38 using namespace circt;
39 using namespace firrtl;
40 using hw::InnerRefAttr;
46 llvm::raw_ostream &
printHex(llvm::raw_ostream &stream,
47 ArrayRef<uint8_t> bytes) {
49 return stream << format_bytes(bytes, std::nullopt, 32) <<
"\n";
52 llvm::raw_ostream &
printHash(llvm::raw_ostream &stream, llvm::SHA256 &data) {
56 llvm::raw_ostream &
printHash(llvm::raw_ostream &stream, std::string data) {
57 ArrayRef<uint8_t> bytes(
reinterpret_cast<const uint8_t *
>(
data.c_str()),
89 nonessentialAttributes.insert(
StringAttr::get(context,
"portAnnotations"));
91 nonessentialAttributes.insert(
StringAttr::get(context,
"portLocations"));
113 : constants(constants){};
115 std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>
119 sha.update(group.str());
120 auto hash = sha.final();
121 return {hash, referredModuleNames};
126 auto *
addr =
reinterpret_cast<const uint8_t *
>(&pointer);
127 sha.update(ArrayRef<uint8_t>(
addr,
sizeof pointer));
131 auto *
addr =
reinterpret_cast<const uint8_t *
>(&
value);
132 sha.update(ArrayRef<uint8_t>(
addr,
sizeof value));
139 update(type.getTypeID());
140 for (
auto &element : type.getElements()) {
141 update(element.isFlip);
142 update(element.type);
148 if (
auto bundle = type_dyn_cast<BundleType>(type))
149 return update(bundle);
150 update(type.getAsOpaquePointer());
154 auto size = indices.size();
155 indices[address] = size;
158 void update(BlockArgument arg) { record(arg.getAsOpaquePointer()); }
161 record(result.getAsOpaquePointer());
162 update(result.getType());
167 auto it = indices.find(operand.get().getAsOpaquePointer());
168 assert(it != indices.end() &&
"op should have been previously hashed");
172 void update(Operation *op, hw::InnerSymAttr attr) {
173 for (
auto props : attr)
174 innerSymTargets[props.getName()] =
179 for (
auto props : attr)
180 innerSymTargets[props.getName()] =
185 update(target.
index);
191 auto it = innerSymTargets.find(attr.getName());
192 assert(it != innerSymTargets.end() &&
193 "inner symbol should have been previously hashed");
194 update(attr.getTypeID());
200 void update(Operation *op, DictionaryAttr dict) {
201 for (
auto namedAttr : dict) {
202 auto name = namedAttr.getName();
203 auto value = namedAttr.getValue();
205 if (constants.nonessentialAttributes.contains(name))
209 if (name == constants.portTypesAttr) {
210 auto portTypes = cast<ArrayAttr>(
value).getAsValueRange<TypeAttr>();
211 for (
auto type : portTypes)
217 if (name == constants.portSymsAttr) {
218 if (op->getNumRegions() != 1)
220 auto ®ion = op->getRegion(0);
221 if (region.getBlocks().empty())
223 auto *block = ®ion.front();
224 auto syms = cast<ArrayAttr>(
value).getAsRange<hw::InnerSymAttr>();
227 for (
auto [arg, sym] : llvm::zip_equal(block->getArguments(), syms))
231 if (name == constants.innerSymAttr) {
232 auto innerSym = cast<hw::InnerSymAttr>(
value);
233 update(op, innerSym);
242 if (isa<InstanceOp>(op) && name == constants.moduleNameAttr) {
243 referredModuleNames.push_back(cast<FlatSymbolRefAttr>(
value).
getAttr());
248 update(name.getAsOpaquePointer());
251 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(
value))
254 update(
value.getAsOpaquePointer());
260 for (
auto arg : block.getArguments())
263 for (
auto &op : block)
269 update(name.getAsOpaquePointer());
275 update(op->getName());
276 update(op, op->getAttrDictionary());
278 for (
auto &operand : op->getOpOperands())
282 update(op->getNumRegions());
283 for (
auto ®ion : op->getRegions())
284 for (
auto &block : region.getBlocks())
287 for (
auto result : op->getResults())
317 : instanceGraph(instanceGraph) {
321 nonessentialAttributes.insert(
StringAttr::get(context,
"annotations"));
323 nonessentialAttributes.insert(
StringAttr::get(context,
"portAnnotations"));
327 nonessentialAttributes.insert(
StringAttr::get(context,
"portLocations"));
333 ModuleData(
const hw::InnerSymbolTable &a,
const hw::InnerSymbolTable &b)
336 const hw::InnerSymbolTable &
a;
337 const hw::InnerSymbolTable &
b;
341 SmallString<64> buffer;
342 llvm::raw_svector_ostream os(buffer);
343 if (
auto integerAttr = dyn_cast<IntegerAttr>(attr)) {
345 if (integerAttr.getType().isSignlessInteger())
346 integerAttr.getValue().toStringUnsigned(buffer, 16);
348 integerAttr.getAPSInt().toString(buffer, 16);
352 return std::string(buffer);
356 LogicalResult
check(InFlightDiagnostic &diag,
const Twine &message,
357 Operation *a, BundleType aType, Operation *b,
359 if (aType.getNumElements() != bType.getNumElements()) {
360 diag.attachNote(a->getLoc())
361 << message <<
" bundle type has different number of elements";
362 diag.attachNote(b->getLoc()) <<
"second operation here";
366 for (
auto elementPair :
367 llvm::zip(aType.getElements(), bType.getElements())) {
368 auto aElement = std::get<0>(elementPair);
369 auto bElement = std::get<1>(elementPair);
370 if (aElement.isFlip != bElement.isFlip) {
371 diag.attachNote(a->getLoc()) << message <<
" bundle element "
372 << aElement.name <<
" flip does not match";
373 diag.attachNote(b->getLoc()) <<
"second operation here";
377 if (failed(check(diag,
378 "bundle element \'" + aElement.name.getValue() +
"'", a,
379 aElement.type, b, bElement.type)))
385 LogicalResult
check(InFlightDiagnostic &diag,
const Twine &message,
386 Operation *a, Type aType, Operation *b, Type bType) {
389 if (
auto aBundleType = type_dyn_cast<BundleType>(aType))
390 if (
auto bBundleType = type_dyn_cast<BundleType>(bType))
391 return check(diag, message, a, aBundleType, b, bBundleType);
392 if (type_isa<RefType>(aType) && type_isa<RefType>(bType) &&
394 diag.attachNote(a->getLoc())
395 << message <<
", has a RefType with a different base type "
396 << type_cast<RefType>(aType).getType()
397 <<
" in the same position of the two modules marked as 'must dedup'. "
398 "(This may be due to Grand Central Taps or Views being different "
399 "between the two modules.)";
400 diag.attachNote(b->getLoc())
401 <<
"the second module has a different base type "
402 << type_cast<RefType>(bType).getType();
405 diag.attachNote(a->getLoc())
406 << message <<
" types don't match, first type is " << aType;
407 diag.attachNote(b->getLoc()) <<
"second type is " << bType;
412 Block &aBlock, Operation *b, Block &bBlock) {
415 auto portNames = a->getAttrOfType<ArrayAttr>(
"portNames");
417 auto emitMissingPort = [&](Value existsVal, Operation *opExists,
418 Operation *opDoesNotExist) {
420 auto portNames = opExists->getAttrOfType<ArrayAttr>(
"portNames");
422 if (
auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
423 portName = portNameAttr.getValue();
424 if (type_isa<RefType>(existsVal.getType())) {
425 diag.attachNote(opExists->getLoc())
426 <<
" contains a RefType port named '" + portName +
427 "' that only exists in one of the modules (can be due to "
428 "difference in Grand Central Tap or View of two modules "
429 "marked with must dedup)";
430 diag.attachNote(opDoesNotExist->getLoc())
431 <<
"second module to be deduped that does not have the RefType "
434 diag.attachNote(opExists->getLoc())
435 <<
"port '" + portName +
"' only exists in one of the modules";
436 diag.attachNote(opDoesNotExist->getLoc())
437 <<
"second module to be deduped that does not have the port";
443 llvm::zip_longest(aBlock.getArguments(), bBlock.getArguments())) {
444 auto &aArg = std::get<0>(argPair);
445 auto &bArg = std::get<1>(argPair);
446 if (aArg.has_value() && bArg.has_value()) {
451 if (
auto portNameAttr = dyn_cast<StringAttr>(portNames[portNo]))
452 portName = portNameAttr.getValue();
455 if (failed(check(diag,
"module port '" + portName +
"'", a,
456 aArg->getType(), b, bArg->getType())))
458 data.map.map(aArg.value(), bArg.value());
462 if (!aArg.has_value())
464 return emitMissingPort(aArg.has_value() ? aArg.value() : bArg.value(), a,
469 auto aIt = aBlock.begin();
470 auto aEnd = aBlock.end();
471 auto bIt = bBlock.begin();
472 auto bEnd = bBlock.end();
473 while (aIt != aEnd && bIt != bEnd)
474 if (failed(check(diag,
data, &*aIt++, &*bIt++)))
477 diag.attachNote(aIt->getLoc()) <<
"first block has more operations";
478 diag.attachNote(b->getLoc()) <<
"second block here";
482 diag.attachNote(bIt->getLoc()) <<
"second block has more operations";
483 diag.attachNote(a->getLoc()) <<
"first block here";
490 Region &aRegion, Operation *b, Region &bRegion) {
491 auto aIt = aRegion.begin();
492 auto aEnd = aRegion.end();
493 auto bIt = bRegion.begin();
494 auto bEnd = bRegion.end();
497 while (aIt != aEnd && bIt != bEnd)
498 if (failed(check(diag,
data, a, *aIt++, b, *bIt++)))
500 if (aIt != aEnd || bIt != bEnd) {
501 diag.attachNote(a->getLoc())
502 <<
"operation regions have different number of blocks";
503 diag.attachNote(b->getLoc()) <<
"second operation here";
509 LogicalResult
check(InFlightDiagnostic &diag, Operation *a, IntegerAttr aAttr,
510 Operation *b, IntegerAttr bAttr) {
515 auto portNames = a->getAttrOfType<ArrayAttr>(
"portNames");
516 for (
unsigned i = 0, e = aDirections.size(); i < e; ++i) {
517 auto aDirection = aDirections[i];
518 auto bDirection = bDirections[i];
519 if (aDirection != bDirection) {
520 auto ¬e = diag.attachNote(a->getLoc()) <<
"module port ";
522 note <<
"'" << cast<StringAttr>(portNames[i]).getValue() <<
"'";
525 note <<
" directions don't match, first direction is '"
527 diag.attachNote(b->getLoc()) <<
"second direction is '"
536 DictionaryAttr aDict, Operation *b,
537 DictionaryAttr bDict) {
542 DenseSet<Attribute> seenAttrs;
543 for (
auto namedAttr : aDict) {
544 auto attrName = namedAttr.getName();
545 if (nonessentialAttributes.contains(attrName))
548 auto aAttr = namedAttr.getValue();
549 auto bAttr = bDict.get(attrName);
551 diag.attachNote(a->getLoc())
552 <<
"second operation is missing attribute " << attrName;
553 diag.attachNote(b->getLoc()) <<
"second operation here";
557 if (isa<hw::InnerRefAttr>(aAttr) && isa<hw::InnerRefAttr>(bAttr)) {
558 auto bRef = cast<hw::InnerRefAttr>(bAttr);
559 auto aRef = cast<hw::InnerRefAttr>(aAttr);
561 auto aTarget =
data.a.lookup(aRef.getName());
562 auto bTarget =
data.b.lookup(bRef.getName());
563 if (!aTarget || !bTarget)
564 diag.attachNote(a->getLoc())
565 <<
"malformed ir, possibly violating use-before-def";
567 diag.attachNote(a->getLoc())
568 <<
"operations have different targets, first operation has "
570 diag.attachNote(b->getLoc()) <<
"second operation has " << bTarget;
573 if (aTarget.isPort()) {
575 if (!bTarget.isPort() || aTarget.getPort() != bTarget.getPort())
579 if (!bTarget.isOpOnly() ||
580 aTarget.getOp() !=
data.map.lookup(bTarget.getOp()))
583 if (aTarget.getField() != bTarget.getField())
585 }
else if (attrName == portDirectionsAttr) {
588 if (failed(check(diag, a, cast<IntegerAttr>(aAttr), b,
589 cast<IntegerAttr>(bAttr))))
591 }
else if (aAttr != bAttr) {
592 diag.attachNote(a->getLoc())
593 <<
"first operation has attribute '" << attrName.getValue()
594 <<
"' with value " << prettyPrint(aAttr);
595 diag.attachNote(b->getLoc())
596 <<
"second operation has value " << prettyPrint(bAttr);
599 seenAttrs.insert(attrName);
601 if (aDict.getValue().size() != bDict.getValue().size()) {
602 for (
auto namedAttr : bDict) {
603 auto attrName = namedAttr.getName();
606 if (nonessentialAttributes.contains(attrName) ||
607 seenAttrs.contains(attrName))
610 diag.attachNote(a->getLoc())
611 <<
"first operation is missing attribute " << attrName;
612 diag.attachNote(b->getLoc()) <<
"second operation here";
620 LogicalResult
check(InFlightDiagnostic &diag, InstanceOp a, InstanceOp b) {
621 auto aName = a.getModuleNameAttr().getAttr();
622 auto bName = b.getModuleNameAttr().getAttr();
626 if (aName != bName) {
627 auto aModule = instanceGraph.getReferencedModule(a);
628 auto bModule = instanceGraph.getReferencedModule(b);
630 diag.attachNote(std::nullopt)
631 <<
"in instance " << a.getNameAttr() <<
" of " << aName
632 <<
", and instance " << b.getNameAttr() <<
" of " << bName;
633 check(diag, aModule, bModule);
643 if (a->getName() != b->getName()) {
644 diag.attachNote(a->getLoc()) <<
"first operation is a " << a->getName();
645 diag.attachNote(b->getLoc()) <<
"second operation is a " << b->getName();
651 if (
auto aInst = dyn_cast<InstanceOp>(a)) {
652 auto bInst = cast<InstanceOp>(b);
653 if (failed(check(diag, aInst, bInst)))
658 if (a->getNumResults() != b->getNumResults()) {
659 diag.attachNote(a->getLoc())
660 <<
"operations have different number of results";
661 diag.attachNote(b->getLoc()) <<
"second operation here";
664 for (
auto resultPair : llvm::zip(a->getResults(), b->getResults())) {
665 auto &aValue = std::get<0>(resultPair);
666 auto &bValue = std::get<1>(resultPair);
667 if (failed(check(diag,
"operation result", a, aValue.getType(), b,
670 data.map.map(aValue, bValue);
674 if (a->getNumOperands() != b->getNumOperands()) {
675 diag.attachNote(a->getLoc())
676 <<
"operations have different number of operands";
677 diag.attachNote(b->getLoc()) <<
"second operation here";
680 for (
auto operandPair : llvm::zip(a->getOperands(), b->getOperands())) {
681 auto &aValue = std::get<0>(operandPair);
682 auto &bValue = std::get<1>(operandPair);
683 if (bValue !=
data.map.lookup(aValue)) {
684 diag.attachNote(a->getLoc())
685 <<
"operations use different operands, first operand is '"
690 diag.attachNote(b->getLoc())
691 <<
"second operand is '"
695 <<
"', but should have been '"
706 if (a->getNumRegions() != b->getNumRegions()) {
707 diag.attachNote(a->getLoc())
708 <<
"operations have different number of regions";
709 diag.attachNote(b->getLoc()) <<
"second operation here";
712 for (
auto regionPair : llvm::zip(a->getRegions(), b->getRegions())) {
713 auto &aRegion = std::get<0>(regionPair);
714 auto &bRegion = std::get<1>(regionPair);
715 if (failed(check(diag,
data, a, aRegion, b, bRegion)))
720 if (failed(check(diag,
data, a, a->getAttrDictionary(), b,
721 b->getAttrDictionary())))
727 void check(InFlightDiagnostic &diag, Operation *a, Operation *b) {
728 hw::InnerSymbolTable aTable(a);
729 hw::InnerSymbolTable bTable(b);
734 diag.attachNote(a->getLoc()) <<
"module marked NoDedup";
738 diag.attachNote(b->getLoc()) <<
"module marked NoDedup";
749 if (aGroup != bGroup) {
751 diag.attachNote(b->getLoc())
752 <<
"module is in dedup group '" << bGroup.str() <<
"'";
754 diag.attachNote(b->getLoc()) <<
"module is not part of a dedup group";
757 diag.attachNote(a->getLoc())
758 <<
"module is in dedup group '" << aGroup.str() <<
"'";
760 diag.attachNote(a->getLoc()) <<
"module is not part of a dedup group";
764 if (failed(check(diag,
data, a, b)))
766 diag.attachNote(a->getLoc()) <<
"first module here";
767 diag.attachNote(b->getLoc()) <<
"second module here";
790 static Location
mergeLoc(MLIRContext *context, Location to, Location from) {
792 llvm::SmallSetVector<Location, 4> decomposedLocs;
794 unsigned seenFIR = 0;
795 for (
auto loc : {to, from}) {
798 if (
auto fusedLoc = dyn_cast<FusedLoc>(loc)) {
801 for (
auto loc : fusedLoc.getLocations()) {
802 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
803 if (fileLoc.getFilename().strref().endswith(
".fir")) {
809 decomposedLocs.insert(loc);
815 if (FileLineColLoc fileLoc = dyn_cast<FileLineColLoc>(loc)) {
816 if (fileLoc.getFilename().strref().endswith(
".fir")) {
823 if (!isa<UnknownLoc>(loc))
824 decomposedLocs.insert(loc);
827 auto locs = decomposedLocs.getArrayRef();
833 if (locs.size() == 1)
844 NLATable *nlaTable, CircuitOp circuit)
845 : context(circuit->getContext()), instanceGraph(instanceGraph),
846 symbolTable(symbolTable), nlaTable(nlaTable),
847 nlaBlock(circuit.getBodyBlock()),
848 nonLocalString(StringAttr::
get(context,
"circt.nonlocal")),
849 classString(StringAttr::
get(context,
"class")) {
851 for (
auto nla : circuit.getOps<hw::HierPathOp>())
852 nlaCache[nla.getNamepathAttr()] = nla.getSymNameAttr();
859 void dedup(FModuleLike toModule, FModuleLike fromModule) {
865 SmallVector<Attribute> newLocs;
866 for (
auto [toLoc, fromLoc] : llvm::zip(toModule.getPortLocations(),
867 fromModule.getPortLocations())) {
868 if (toLoc == fromLoc)
869 newLocs.push_back(toLoc);
871 newLocs.push_back(
mergeLoc(context, cast<LocationAttr>(toLoc),
872 cast<LocationAttr>(fromLoc)));
874 toModule->setAttr(
"portLocations",
ArrayAttr::get(context, newLocs));
877 mergeOps(renameMap, toModule, toModule, fromModule, fromModule);
883 if (
auto to = dyn_cast<FModuleOp>(*toModule))
884 rewriteModuleNLAs(renameMap, to, cast<FModuleOp>(*fromModule));
886 rewriteExtModuleNLAs(renameMap, toModule.getModuleNameAttr(),
887 fromModule.getModuleNameAttr());
889 replaceInstances(toModule, fromModule);
896 recordAnnotations(module);
898 for (
unsigned i = 0, e =
getNumPorts(module); i < e; ++i)
901 module->walk([&](Operation *op) { recordAnnotations(op); });
907 return moduleNamespaces.try_emplace(module, cast<FModuleLike>(module))
915 if (
auto nlaRef = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
916 targetMap[nlaRef.getAttr()].insert(target);
925 auto mem = dyn_cast<MemOp>(op);
930 for (
unsigned i = 0, e = mem->getNumResults(); i < e; ++i)
939 instanceGraph[::cast<igraph::ModuleOpInterface>(fromModule)];
940 auto *toNode = instanceGraph[toModule];
942 for (
auto *oldInstRec : llvm::make_early_inc_range(fromNode->uses())) {
943 auto inst = ::cast<InstanceOp>(*oldInstRec->getInstance());
944 inst.setModuleNameAttr(toModuleRef);
945 inst.setPortNamesAttr(toModule.getPortNamesAttr());
946 oldInstRec->getParent()->addInstance(inst, toNode);
949 instanceGraph.erase(fromNode);
958 SmallVector<FlatSymbolRefAttr>
959 createNLAs(Operation *fromModule, ArrayRef<Attribute> baseNamepath,
960 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
963 SmallVector<Attribute> namepath = {
nullptr};
964 namepath.append(baseNamepath.begin(), baseNamepath.end());
966 auto loc = fromModule->getLoc();
967 auto *fromNode = instanceGraph[cast<igraph::ModuleOpInterface>(fromModule)];
968 SmallVector<FlatSymbolRefAttr> nlas;
969 for (
auto *instanceRecord : fromNode->uses()) {
970 auto parent = cast<FModuleOp>(*instanceRecord->getParent()->getModule());
971 auto inst = instanceRecord->getInstance();
975 auto &cacheEntry = nlaCache[arrayAttr];
977 auto nla = OpBuilder::atBlockBegin(nlaBlock).create<hw::HierPathOp>(
978 loc,
"nla", arrayAttr);
980 symbolTable.insert(nla);
982 cacheEntry = nla.getNameAttr();
983 nla.setVisibility(vis);
984 nlaTable->addNLA(nla);
987 nlas.push_back(nlaRef);
995 SmallVector<FlatSymbolRefAttr>
997 SymbolTable::Visibility vis = SymbolTable::Visibility::Private) {
1005 Annotation anno, ArrayRef<NamedAttribute> attributes,
1006 unsigned nonLocalIndex,
1007 SmallVectorImpl<Annotation> &newAnnotations) {
1008 SmallVector<NamedAttribute> mutableAttributes(attributes.begin(),
1010 for (
auto &nla : nlas) {
1012 mutableAttributes[nonLocalIndex].setValue(nla);
1013 auto dict = DictionaryAttr::getWithSorted(context, mutableAttributes);
1016 newAnnotations.push_back(anno);
1025 targetMap.erase(nla.getNameAttr());
1026 nlaTable->erase(nla);
1027 nlaCache.erase(nla.getNamepathAttr());
1028 symbolTable.erase(nla);
1034 FModuleOp fromModule) {
1035 auto toName = toModule.getNameAttr();
1036 auto fromName = fromModule.getNameAttr();
1039 auto moduleNLAs = nlaTable->lookup(fromModule.getNameAttr()).vec();
1041 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1044 for (
auto nla : moduleNLAs) {
1045 auto elements = nla.getNamepath().getValue();
1047 if (nla.root() != toName)
1050 SmallVector<Attribute> namepath(elements.begin(), elements.end());
1051 auto nlaRefs = createNLAs(fromModule, namepath, nla.getVisibility());
1053 auto &set = targetMap[nla.getSymNameAttr()];
1054 SmallVector<AnnoTarget> targets(set.begin(), set.end());
1056 for (
auto target : targets) {
1059 SmallVector<Annotation> newAnnotations;
1060 for (
auto anno : target.getAnnotations()) {
1062 auto [it, found] = mlir::impl::findAttrSorted(
1063 anno.begin(), anno.end(), nonLocalString);
1066 if (!found || cast<FlatSymbolRefAttr>(it->getValue()).getAttr() !=
1067 nla.getSymNameAttr()) {
1068 newAnnotations.push_back(anno);
1071 auto nonLocalIndex = std::distance(anno.begin(), it);
1073 cloneAnnotation(nlaRefs, anno,
1074 ArrayRef<NamedAttribute>(anno.begin(), anno.end()),
1075 nonLocalIndex, newAnnotations);
1080 target.setAnnotations(annotations);
1082 for (
auto nla : nlaRefs)
1083 targetMap[nla.getAttr()].insert(target);
1095 FModuleOp fromModule) {
1096 addAnnotationContext(renameMap, toModule, toModule);
1097 addAnnotationContext(renameMap, toModule, fromModule);
1103 StringAttr fromName) {
1104 nlaTable->renameModuleAndInnerRef(toName, fromName, renameMap);
1112 SmallVectorImpl<Annotation> &newAnnotations) {
1115 SmallVector<NamedAttribute> attributes;
1116 int nonLocalIndex = -1;
1117 for (
const auto &val : llvm::enumerate(anno)) {
1118 auto attr = val.value();
1120 auto compare = attr.getName().compare(nonLocalString);
1121 assert(compare != 0 &&
"should not pass non-local annotations here");
1125 nonLocalIndex = val.index();
1126 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1131 attributes.push_back(attr);
1133 if (nonLocalIndex == -1) {
1135 nonLocalIndex = attributes.size();
1136 attributes.push_back(NamedAttribute(nonLocalString, nonLocalString));
1139 attributes.append(anno.
begin() + nonLocalIndex, anno.
end());
1143 auto nlaRefs = createNLAs(toModuleName, fromModule);
1144 for (
auto nla : nlaRefs)
1145 targetMap[nla.getAttr()].insert(to);
1148 cloneAnnotation(nlaRefs, anno, attributes, nonLocalIndex, newAnnotations);
1154 SmallVectorImpl<Annotation> &newAnnotations,
1155 SmallPtrSetImpl<Attribute> &dontTouches) {
1156 for (
auto anno : annos) {
1160 anno.removeMember(
"circt.nonlocal");
1161 auto [it, inserted] = dontTouches.insert(anno.getAttr());
1163 newAnnotations.push_back(anno);
1168 if (
auto nla = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal")) {
1169 newAnnotations.push_back(anno);
1170 targetMap[nla.getAttr()].insert(to);
1174 makeAnnotationNonLocal(toModule.getModuleNameAttr(), to, fromModule, anno,
1185 SmallVector<Annotation> newAnnotations;
1189 llvm::SmallPtrSet<Attribute, 4> dontTouches;
1193 copyAnnotations(toModule, to, toModule, toAnnos, newAnnotations,
1195 copyAnnotations(toModule, to, fromModule, fromAnnos, newAnnotations,
1199 if (!newAnnotations.empty())
1205 FModuleLike fromModule, Operation *from) {
1211 if (toModule == to) {
1213 for (
unsigned i = 0, e =
getNumPorts(toModule); i < e; ++i)
1218 }
else if (
auto toMem = dyn_cast<MemOp>(to)) {
1220 auto fromMem = cast<MemOp>(from);
1221 for (
unsigned i = 0, e = toMem.getNumResults(); i < e; ++i)
1233 Operation *to, FModuleLike fromModule,
1240 return getNamespace(toModule);
1242 renameMap[fromSym] = toSym;
1246 auto fromPortSyms = from->getAttrOfType<ArrayAttr>(
"portSyms");
1247 if (!fromPortSyms || fromPortSyms.empty())
1250 auto &moduleNamespace = getNamespace(toModule);
1251 auto portCount = fromPortSyms.size();
1252 auto portNames = to->getAttrOfType<ArrayAttr>(
"portNames");
1253 auto toPortSyms = to->getAttrOfType<ArrayAttr>(
"portSyms");
1257 SmallVector<Attribute> newPortSyms;
1258 if (toPortSyms.empty())
1259 newPortSyms.assign(portCount, hw::InnerSymAttr());
1261 newPortSyms.assign(toPortSyms.begin(), toPortSyms.end());
1263 for (
unsigned portNo = 0; portNo < portCount; ++portNo) {
1265 if (!fromPortSyms[portNo])
1267 auto fromSym = fromPortSyms[portNo].cast<hw::InnerSymAttr>();
1270 hw::InnerSymAttr toSym;
1271 if (!newPortSyms[portNo]) {
1273 StringRef symName =
"inner_sym";
1275 symName = cast<StringAttr>(portNames[portNo]).getValue();
1279 newPortSyms[portNo] = toSym;
1281 toSym = newPortSyms[portNo].cast<hw::InnerSymAttr>();
1284 renameMap[fromSym.getSymName()] = toSym.getSymName();
1288 cast<FModuleLike>(to).setPortSymbols(newPortSyms);
1294 FModuleLike fromModule, Operation *from) {
1296 if (to->getLoc() != from->getLoc())
1297 to->setLoc(
mergeLoc(context, to->getLoc(), from->getLoc()));
1300 for (
auto regions : llvm::zip(to->getRegions(), from->getRegions()))
1301 mergeRegions(renameMap, toModule, std::get<0>(regions), fromModule,
1302 std::get<1>(regions));
1305 recordSymRenames(renameMap, toModule, to, fromModule, from);
1308 mergeAnnotations(toModule, to, fromModule, from);
1313 FModuleLike fromModule, Block &fromBlock) {
1315 for (
auto [toArg, fromArg] :
1316 llvm::zip(toBlock.getArguments(), fromBlock.getArguments()))
1317 if (toArg.getLoc() != fromArg.getLoc())
1318 toArg.setLoc(
mergeLoc(context, toArg.getLoc(), fromArg.getLoc()));
1320 for (
auto ops : llvm::zip(toBlock, fromBlock))
1321 mergeOps(renameMap, toModule, &std::get<0>(ops), fromModule,
1327 Region &toRegion, FModuleLike fromModule,
1328 Region &fromRegion) {
1329 for (
auto blocks : llvm::zip(toRegion, fromRegion))
1330 mergeBlocks(renameMap, toModule, std::get<0>(blocks), fromModule,
1331 std::get<1>(blocks));
1345 DenseMap<Attribute, llvm::SmallDenseSet<AnnoTarget>>
targetMap;
1368 auto dstType = dst.getType();
1369 auto srcType = src.getType();
1370 if (dstType == srcType) {
1376 auto dstBundle = type_cast<BundleType>(dstType);
1377 auto srcBundle = type_cast<BundleType>(srcType);
1378 for (
unsigned i = 0; i < dstBundle.getNumElements(); ++i) {
1379 auto dstField =
builder.create<SubfieldOp>(dst, i);
1380 auto srcField =
builder.create<SubfieldOp>(src, i);
1381 if (dstBundle.getElement(i).isFlip) {
1382 std::swap(srcBundle, dstBundle);
1383 std::swap(srcField, dstField);
1393 for (
auto *node : instanceGraph) {
1394 auto module = cast<FModuleLike>(*node->getModule());
1395 for (
auto *instRec : node->uses()) {
1396 auto inst = instRec->getInstance<InstanceOp>();
1400 ImplicitLocOpBuilder
builder(inst.getLoc(), inst->getContext());
1401 builder.setInsertionPointAfter(inst);
1402 for (
unsigned i = 0, e =
getNumPorts(module); i < e; ++i) {
1403 auto result = inst.getResult(i);
1404 auto newType = module.getPortType(i);
1405 auto oldType = result.getType();
1407 if (newType == oldType)
1412 builder.create<WireOp>(oldType, inst.getPortName(i)).getResult();
1413 result.replaceAllUsesWith(wire);
1414 result.setType(newType);
1430 struct DenseMapInfo<ModuleInfo> {
1432 std::array<uint8_t, 32> key;
1433 std::fill(key.begin(), key.end(), ~0);
1434 return {key, DenseMapInfo<mlir::ArrayAttr>::getEmptyKey()};
1438 std::array<uint8_t, 32> key;
1439 std::fill(key.begin(), key.end(), ~0 - 1);
1440 return {key, DenseMapInfo<mlir::ArrayAttr>::getTombstoneKey()};
1447 std::memcpy(&hash, val.structuralHash.data(),
sizeof(
unsigned));
1450 return llvm::hash_combine(hash, val.referredModuleNames);
1453 static bool isEqual(
const ModuleInfo &lhs,
const ModuleInfo &rhs) {
1454 return lhs.structuralHash == rhs.structuralHash &&
1455 lhs.referredModuleNames == rhs.referredModuleNames;
1465 class DedupPass :
public DedupBase<DedupPass> {
1466 void runOnOperation()
override {
1467 auto *context = &getContext();
1468 auto circuit = getOperation();
1469 auto &instanceGraph = getAnalysis<InstanceGraph>();
1470 auto *nlaTable = &getAnalysis<NLATable>();
1471 auto &symbolTable = getAnalysis<SymbolTable>();
1472 Deduper deduper(instanceGraph, symbolTable, nlaTable, circuit);
1474 auto anythingChanged =
false;
1483 llvm::DenseMap<ModuleInfo, Operation *> moduleInfoToModule;
1488 DenseMap<Attribute, StringAttr> dedupMap;
1493 SmallVector<FModuleLike, 0> modules(
1494 llvm::map_range(llvm::post_order(&instanceGraph), [](
auto *node) {
1495 return cast<FModuleLike>(*node->getModule());
1498 SmallVector<std::optional<
1499 std::pair<std::array<uint8_t, 32>, SmallVector<StringAttr>>>>
1500 hashesAndModuleNames(modules.size());
1504 auto result = mlir::failableParallelForEach(
1505 context, llvm::seq(modules.size()), [&](
unsigned idx) {
1506 auto module = modules[idx];
1507 AnnotationSet annotations(module);
1509 if (annotations.hasAnnotation(noDedupClass))
1513 if (llvm::any_of(module.getPorts(), [&](PortInfo port) {
1514 return type_isa<RefType>(port.type) && port.isInput();
1519 if (
auto ext = dyn_cast<FExtModuleOp>(*module);
1520 ext && !ext.getDefname().has_value())
1526 if (!module.isPrivate() || !module.canDiscardOnUseEmpty()) {
1533 if (isa<ClassLike>(*module)) {
1537 llvm::SmallSetVector<StringAttr, 1> groups;
1538 for (
auto annotation : annotations) {
1539 if (annotation.getClass() == dedupGroupClass)
1540 groups.insert(annotation.getMember<StringAttr>(
"group"));
1542 if (groups.size() > 1) {
1543 module.emitError(
"module belongs to multiple dedup groups: ")
1547 auto dedupGroup = groups.empty() ? StringAttr() : groups.front();
1551 hashesAndModuleNames[idx] =
1552 hasher.getHashAndModuleNames(module, dedupGroup);
1556 if (result.failed())
1557 return signalPassFailure();
1559 for (
auto [i, module] : llvm::enumerate(modules)) {
1560 auto moduleName = module.getModuleNameAttr();
1561 auto &hashAndModuleNamesOpt = hashesAndModuleNames[i];
1563 if (!hashAndModuleNamesOpt) {
1568 dedupMap[moduleName] = moduleName;
1573 SmallVector<mlir::Attribute> names;
1574 for (
auto oldModuleName : hashAndModuleNamesOpt->second) {
1575 auto newModuleName = dedupMap[oldModuleName];
1576 names.push_back(newModuleName);
1580 ModuleInfo moduleInfo{hashAndModuleNamesOpt->first,
1584 auto it = moduleInfoToModule.find(moduleInfo);
1585 if (it != moduleInfoToModule.end()) {
1586 auto original = cast<FModuleLike>(it->second);
1588 dedupMap[moduleName] = original.getModuleNameAttr();
1589 deduper.dedup(original, module);
1591 anythingChanged =
true;
1595 deduper.record(module);
1597 dedupMap[moduleName] = moduleName;
1599 moduleInfoToModule[moduleInfo] = module;
1607 auto failed =
false;
1609 auto parseModule = [&](Attribute path) -> StringAttr {
1612 auto [_, rhs] = cast<StringAttr>(path).getValue().split(
'|');
1618 auto getLead = [&](StringAttr module) -> StringAttr {
1619 auto it = dedupMap.find(module);
1620 if (it == dedupMap.end()) {
1621 auto diag = emitError(circuit.getLoc(),
1622 "MustDeduplicateAnnotation references module ")
1623 << module <<
" which does not exist";
1636 auto modules = annotation.
getMember<ArrayAttr>(
"modules");
1638 emitError(circuit.getLoc(),
1639 "MustDeduplicateAnnotation missing \"modules\" member");
1644 if (modules.size() == 0)
1647 auto firstModule = parseModule(modules[0]);
1648 auto firstLead = getLead(firstModule);
1652 for (
auto attr : modules.getValue().drop_front()) {
1653 auto nextModule = parseModule(attr);
1654 auto nextLead = getLead(nextModule);
1657 if (firstLead != nextLead) {
1658 auto diag = emitError(circuit.getLoc(),
"module ")
1659 << nextModule <<
" not deduplicated with " << firstModule;
1660 auto a = instanceGraph.lookup(firstLead)->getModule();
1661 auto b = instanceGraph.lookup(nextLead)->getModule();
1662 equiv.check(diag, a, b);
1670 return signalPassFailure();
1672 for (
auto module : circuit.getOps<FModuleOp>())
1680 markAnalysesPreserved<NLATable>();
1681 if (!anythingChanged)
1682 markAllAnalysesPreserved();
1688 return std::make_unique<DedupPass>();
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
static void mergeRegions(Region *region1, Region *region2)
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
Annotation getAnnotation(StringRef className) const
If this annotation set has an annotation with the specified class name, return it.
This class provides a read-only projection of an annotation.
void setDict(DictionaryAttr dict)
Set the data dictionary of this attribute.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
This graph tracks modules and where they are instantiated.
This table tracks nlas and what modules participate in them.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
StringRef toString(Direction direction)
SmallVector< Direction > unpackAttribute(IntegerAttr directions)
Turn a packed representation of port attributes into a vector that can be worked with.
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
constexpr const char * mustDedupAnnoClass
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the InnerSymAttr has a symbol on the specified field.
constexpr const char * noDedupAnnoClass
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.)
constexpr const char * dedupGroupAnnoClass
std::unique_ptr< mlir::Pass > createDedupPass()
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * dontTouchAnnoClass
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Block * nlaBlock
We insert all NLAs at the beginning of this block.
void recordAnnotations(Operation *op)
Record all targets which use an NLA.
void eraseNLA(hw::HierPathOp nla)
This erases the NLA op, and removes the NLA from every module's NLA map, but it does not delete the N...
void mergeAnnotations(FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Merge all annotations and port annotations on two operations.
void replaceInstances(FModuleLike toModule, Operation *fromModule)
This deletes and replaces all instances of the "fromModule" with instances of the "toModule".
SmallVector< FlatSymbolRefAttr > createNLAs(StringAttr toModuleName, FModuleLike fromModule, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of this module and create an NLA for each one.
void record(FModuleLike module)
Record the usages of any NLAs in this module, so that we may update the annotation if the parent mod...
void rewriteExtModuleNLAs(RenameMap &renameMap, StringAttr toName, StringAttr fromName)
void mergeRegions(RenameMap &renameMap, FModuleLike toModule, Region &toRegion, FModuleLike fromModule, Region &fromRegion)
void dedup(FModuleLike toModule, FModuleLike fromModule)
Remove the "fromModule", and replace all references to it with the "toModule".
void rewriteModuleNLAs(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all the NLAs that the two modules participate in, replacing references to the "from" module w...
void recordAnnotations(AnnoTarget target)
For a specific annotation target, record all the unique NLAs which target it in the targetMap.
void cloneAnnotation(SmallVectorImpl< FlatSymbolRefAttr > &nlas, Annotation anno, ArrayRef< NamedAttribute > attributes, unsigned nonLocalIndex, SmallVectorImpl< Annotation > &newAnnotations)
Clone the annotation for each NLA in a list.
void recordSymRenames(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
void mergeAnnotations(FModuleLike toModule, AnnoTarget to, AnnotationSet toAnnos, FModuleLike fromModule, AnnoTarget from, AnnotationSet fromAnnos)
Merge the annotations of a specific target, either an operation or a port on an operation.
StringAttr nonLocalString
SymbolTable & symbolTable
SmallVector< FlatSymbolRefAttr > createNLAs(Operation *fromModule, ArrayRef< Attribute > baseNamepath, SymbolTable::Visibility vis=SymbolTable::Visibility::Private)
Look up the instantiations of the from module and create an NLA for each one, appending the baseNamep...
void mergeOps(RenameMap &renameMap, FModuleLike toModule, Operation *to, FModuleLike fromModule, Operation *from)
Recursively merge two operations.
hw::InnerSymbolNamespace & getNamespace(Operation *module)
Get a cached namespace for a module.
DenseMap< Operation *, hw::InnerSymbolNamespace > moduleNamespaces
A module namespace cache.
bool makeAnnotationNonLocal(StringAttr toModuleName, AnnoTarget to, FModuleLike fromModule, Annotation anno, SmallVectorImpl< Annotation > &newAnnotations)
Take an annotation, and update it to be a non-local annotation.
InstanceGraph & instanceGraph
void mergeBlocks(RenameMap &renameMap, FModuleLike toModule, Block &toBlock, FModuleLike fromModule, Block &fromBlock)
Recursively merge two blocks.
DenseMap< Attribute, llvm::SmallDenseSet< AnnoTarget > > targetMap
void copyAnnotations(FModuleLike toModule, AnnoTarget to, FModuleLike fromModule, AnnotationSet annos, SmallVectorImpl< Annotation > &newAnnotations, SmallPtrSetImpl< Attribute > &dontTouches)
Deduper(InstanceGraph &instanceGraph, SymbolTable &symbolTable, NLATable *nlaTable, CircuitOp circuit)
void addAnnotationContext(RenameMap &renameMap, FModuleOp toModule, FModuleOp fromModule)
Process all NLAs referencing the "from" module to point to the "to" module.
DenseMap< StringAttr, StringAttr > RenameMap
DenseMap< Attribute, Attribute > nlaCache
const hw::InnerSymbolTable & a
ModuleData(const hw::InnerSymbolTable &a, const hw::InnerSymbolTable &b)
const hw::InnerSymbolTable & b
This class is for reporting differences between two modules which should have been deduplicated.
LogicalResult check(InFlightDiagnostic &diag, InstanceOp a, InstanceOp b)
DenseSet< Attribute > nonessentialAttributes
std::string prettyPrint(Attribute attr)
LogicalResult check(InFlightDiagnostic &diag, Operation *a, IntegerAttr aAttr, Operation *b, IntegerAttr bAttr)
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Block &aBlock, Operation *b, Block &bBlock)
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, Type aType, Operation *b, Type bType)
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, DictionaryAttr aDict, Operation *b, DictionaryAttr bDict)
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Region &aRegion, Operation *b, Region &bRegion)
StringAttr portDirectionsAttr
LogicalResult check(InFlightDiagnostic &diag, const Twine &message, Operation *a, BundleType aType, Operation *b, BundleType bType)
StringAttr dedupGroupClass
LogicalResult check(InFlightDiagnostic &diag, ModuleData &data, Operation *a, Operation *b)
Equivalence(MLIRContext *context, InstanceGraph &instanceGraph)
InstanceGraph & instanceGraph
void check(InFlightDiagnostic &diag, Operation *a, Operation *b)
std::array< uint8_t, 32 > structuralHash
mlir::ArrayAttr referredModuleNames
This struct contains constant string attributes shared across different threads.
StringAttr moduleNameAttr
DenseSet< Attribute > nonessentialAttributes
StructuralHasherSharedConstants(MLIRContext *context)
void update(Operation *op, DictionaryAttr dict)
Hash the top level attribute dictionary of the operation.
void update(BlockArgument arg)
std::pair< std::array< uint8_t, 32 >, SmallVector< StringAttr > > getHashAndModuleNames(FModuleLike module, StringAttr group)
void update(const void *pointer)
void update(InnerRefAttr attr)
void update(Operation *op)
void record(void *address)
void update(const SymbolTarget &target)
void update(Operation *op, hw::InnerSymAttr attr)
DenseMap< StringAttr, SymbolTarget > innerSymTargets
void update(size_t value)
SmallVector< mlir::StringAttr > referredModuleNames
void update(Block &block)
void update(BundleType type)
void update(OpResult result)
void update(OpOperand &operand)
StructuralHasher(const StructuralHasherSharedConstants &constants)
void update(TypeID typeID)
const StructuralHasherSharedConstants & constants
void update(Value value, hw::InnerSymAttr attr)
DenseMap< void *, unsigned > indices
void update(mlir::OperationName name)
An annotation target is used to keep track of something that is targeted by an Annotation.
AnnotationSet getAnnotations() const
Get the annotations associated with the target.
void setAnnotations(AnnotationSet annotations) const
Set the annotations associated with the target.
This represents an annotation targeting a specific operation.
Attribute getNLAReference(hw::InnerSymbolNamespace &moduleNamespace) const
This represents an annotation targeting a specific port of a module, memory, or instance.
static ModuleInfo getEmptyKey()
static ModuleInfo getTombstoneKey()
static unsigned getHashValue(const ModuleInfo &val)
static bool isEqual(const ModuleInfo &lhs, const ModuleInfo &rhs)