27#include "mlir/IR/IRMapping.h"
28#include "mlir/Pass/Pass.h"
29#include "llvm/ADT/BitVector.h"
30#include "llvm/ADT/SetOperations.h"
31#include "llvm/Support/Debug.h"
32#include "llvm/Support/FormatVariadic.h"
34#define DEBUG_TYPE "firrtl-inliner"
38#define GEN_PASS_DEF_INLINER
39#include "circt/Dialect/FIRRTL/Passes.h.inc"
44using namespace firrtl;
45using namespace chirrtl;
47using hw::InnerRefAttr;
// NOTE(review): this is a lossy listing of MutableNLA; the class header,
// some fields (e.g. `nla`, `dead`, `size`) and closing braces are not shown.
/// Position of each module in the NLA's namepath. The constructor fills it
/// from nla.modPart(i); modules added by reTop() are inserted with index 0.
69 DenseMap<Attribute, unsigned> symIdx;
/// One bit per namepath element. It starts all-true, and inlineModule()
/// clears the bit of a module that has been inlined.
72 BitVector inlinedSymbols;
/// Namepath index compared against in applyUpdates()/isLocal(); -1 means no
/// flattening has been recorded.
76 signed flattenPoint = -1;
/// Set by markModuleOnly().
87 bool moduleOnly =
false;
/// Additional (module, symbol) roots created by reTop().
93 SmallVector<InnerRefAttr> newTops;
/// Names of the modules recorded as new roots by reTop().
97 DenseSet<StringAttr> rootSet;
/// Module -> replacement inner symbol, recorded by setInnerSym().
106 DenseMap<Attribute, StringAttr> renames;
/// Return the inner symbol to use for module `lastMod`. This is the rename
/// recorded by setInnerSym() if there is one; otherwise it is the original
/// ref part of namepath element `idx`.
112 StringAttr lookupRename(Attribute lastMod,
unsigned idx = 0) {
113 if (renames.count(lastMod))
114 return renames[lastMod];
115 return nla.refPart(idx);
// NOTE(review): the constructor's signature line is not part of this listing.
// The constructor:
// - remembers the NLA and the circuit namespace (reTop() uses the namespace
//   to mint fresh names);
// - marks every namepath element as not inlined (all bits set);
// - indexes each module of the namepath by its position.
120 : nla(nla), circuitNamespace(circuitNamespace),
121 inlinedSymbols(BitVector(nla.getNamepath().size(),
true)),
122 size(nla.getNamepath().size()) {
123 for (
size_t i = 0, e = size; i != e; ++i)
124 symIdx.insert({nla.modPart(i), i});
// Message of the default constructor's guard; the guard itself is not shown.
137 "the default constructor for MutableNLA should never be used");
/// Mark this NLA dead. isDead() additionally requires that no new tops exist.
142 void markDead() { dead =
true; }
/// Mark this NLA as targeting only a module (queried by isModuleOnly()).
145 void markModuleOnly() { moduleOnly =
true; }
/// Return the original hierarchical path op this object wraps.
148 hw::HierPathOp getNLA() {
return nla; }
/// Materialize the recorded modifications as hw::HierPathOp(s). Local or dead
/// NLAs take the first branch. Otherwise the path is written back once for
/// the original root and once for each additional top.
/// NOTE(review): several lines are missing from this listing: branch bodies,
/// the builder `b` setup, the declarations of `lastMod`/`last`, and the
/// return statement.
156 hw::HierPathOp applyUpdates() {
158 if (isLocal() || isDead()) {
// Check for an NLA that was not modified at all (condition continues on a
// line not shown).
165 if (inlinedSymbols.all() && newTops.empty() && flattenPoint == -1 &&
// Build one path op named `sym`, rooted at `root`.
172 auto writeBack = [&](StringAttr root, StringAttr sym) -> hw::HierPathOp {
173 SmallVector<Attribute> namepath;
177 if (!inlinedSymbols.test(1))
180 namepath.push_back(InnerRefAttr::get(root, lookupRename(root)));
// Interior elements. When the next module's bit is cleared (it was
// inlined), the current module is remembered in `lastMod`. Later refs are
// then attributed to `lastMod` instead of their original module.
183 for (
signed i = 1, e = inlinedSymbols.size() - 1; i != e; ++i) {
184 if (i == flattenPoint) {
185 lastMod = nla.modPart(i);
189 if (!inlinedSymbols.test(i + 1)) {
191 lastMod = nla.modPart(i);
196 auto modPart = lastMod ? lastMod : nla.modPart(i);
197 auto refPart = lookupRename(modPart, i);
198 namepath.push_back(InnerRefAttr::get(modPart, refPart));
// Last element: one of the two forms below. The condition that chooses
// between them is not shown in this listing.
203 auto modPart = lastMod ? lastMod : nla.modPart(size - 1);
204 auto refPart = lookupRename(modPart, size - 1);
207 namepath.push_back(InnerRefAttr::get(modPart, refPart));
209 namepath.push_back(FlatSymbolRefAttr::get(modPart));
// The rewritten path keeps the original NLA's visibility.
211 auto hp = hw::HierPathOp::create(b, b.getUnknownLoc(), sym,
212 b.getArrayAttr(namepath));
213 hp.setVisibility(nla.getVisibility());
// A dead NLA may only get here if it still has additional tops.
218 assert(!dead || !newTops.empty());
220 last = writeBack(nla.root(), nla.getNameAttr());
221 for (
auto root : newTops)
222 last = writeBack(root.getModule(), root.getName());
// Debug dump of this NLA's state to llvm::errs(). It prints:
// - the original and rewritten forms;
// - the dead / isDead / module-only / local flags;
// - the raw words of inlinedSymbols in hex;
// - the flatten point and the recorded renames.
// NOTE(review): the enclosing function's signature and final line are not
// part of this listing.
229 llvm::errs() <<
" - orig: " << nla <<
"\n"
230 <<
" new: " << *
this <<
"\n"
231 <<
" dead: " << dead <<
"\n"
232 <<
" isDead: " << isDead() <<
"\n"
233 <<
" isModuleOnly: " << isModuleOnly() <<
"\n"
234 <<
" isLocal: " << isLocal() <<
"\n"
235 <<
" inlinedSymbols: [";
// Print the underlying storage words of the bit vector, not individual bits.
236 llvm::interleaveComma(inlinedSymbols.getData(), llvm::errs(), [](
auto a) {
237 llvm::errs() << llvm::formatv(
"{0:x-}", a);
239 llvm::errs() <<
"]\n"
240 <<
" flattenPoint: " << flattenPoint <<
"\n"
242 for (
auto rename : renames)
243 llvm::errs() <<
" - " << rename.first <<
" -> " << rename.second
/// Print the NLA in a `firrtl.nla @sym [...]` textual form, once per root:
/// the original root first, then each additional top. The path construction
/// mirrors the one in applyUpdates().
250 friend llvm::raw_ostream &
operator<<(llvm::raw_ostream &os, MutableNLA &x) {
// Print `mod` (and `sym`, when given) as an inner name reference.
251 auto writePathSegment = [&](StringAttr mod, StringAttr sym = {}) {
253 os <<
"#hw.innerNameRef<";
254 os <<
"@" << mod.getValue();
256 os <<
"::@" << sym.getValue() <<
">";
// Print the whole path for one root named `sym`.
259 auto writeOne = [&](StringAttr root, StringAttr sym) {
260 os <<
"firrtl.nla @" << sym.getValue() <<
" [";
264 if (!x.inlinedSymbols.test(1))
267 writePathSegment(root, x.lookupRename(root));
270 bool needsComma =
false;
271 for (
signed i = 1, e = x.inlinedSymbols.size() - 1; i != e; ++i) {
272 if (i == x.flattenPoint) {
273 lastMod = x.nla.modPart(i);
277 if (!x.inlinedSymbols.test(i + 1)) {
279 lastMod = x.nla.modPart(i);
285 auto modPart = lastMod ? lastMod : x.nla.modPart(i);
// NOTE(review): these three lines re-implement lookupRename(modPart, i).
286 auto refPart = x.nla.refPart(i);
287 if (x.renames.count(modPart))
288 refPart = x.renames[modPart];
289 writePathSegment(modPart, refPart);
296 auto modPart = lastMod ? lastMod : x.nla.modPart(x.size - 1);
297 auto refPart = x.nla.refPart(x.size - 1);
298 if (x.renames.count(modPart))
299 refPart = x.renames[modPart];
300 writePathSegment(modPart, refPart);
// Roots to print: the original one followed by every additional top.
304 SmallVector<InnerRefAttr> tops;
306 tops.push_back(InnerRefAttr::get(x.nla.root(), x.nla.getNameAttr()));
307 tops.append(x.newTops.begin(), x.newTops.end());
// True when more than one root will be printed.
309 bool multiary = !x.newTops.empty();
312 llvm::interleaveComma(tops, os, [&](InnerRefAttr a) {
313 writeOne(a.getModule(), a.getName());
/// Dead means marked dead and no new top has been created for it.
325 bool isDead() {
return dead && newTops.empty(); }
/// Whether markModuleOnly() was called.
328 bool isModuleOnly() {
return moduleOnly; }
// NOTE(review): this predicate's signature line is not part of this listing
// (presumably isLocal()). It returns true when every namepath element in
// [1, end) has been inlined (no bit still set). `end` stops just past the
// flatten point when one is recorded, and is the full size otherwise.
334 unsigned end = flattenPoint > -1 ? flattenPoint + 1 : inlinedSymbols.size();
335 return inlinedSymbols.find_first_in(1, end) == -1;
/// Whether `mod` is a root of this NLA: one of the additional roots, or the
/// original root but only while the NLA is dead.
/// NOTE(review): unlike the StringAttr overload below, this one does not
/// treat the original root as a root for a live NLA; confirm this is intended.
339 bool hasRoot(FModuleLike mod) {
340 return (isDead() && nla.root() == mod.getModuleNameAttr()) ||
341 rootSet.contains(mod.getModuleNameAttr());
/// Whether `modName` is the original root or one of the additional roots.
345 bool hasRoot(StringAttr modName) {
346 return (nla.root() == modName) || rootSet.contains(modName);
/// Record that `module` has been inlined into its parent by clearing its bit
/// in inlinedSymbols. The module must be on the namepath and must not be the
/// root.
350 void inlineModule(FModuleOp module) {
351 auto sym =
module.getNameAttr();
352 assert(sym != nla.root() &&
"unable to inline the root module");
353 assert(symIdx.count(sym) &&
"module is not in the symIdx map");
354 auto idx = symIdx[sym];
355 inlinedSymbols.reset(idx);
// Special case: the leaf module of a module-only NLA was inlined. The action
// taken is not shown in this listing.
358 if (idx == size - 1 && moduleOnly)
/// Record that `module` (which must be on the namepath) was flattened. The
/// index computed here is one less than the module's namepath position.
/// NOTE(review): the rest of this body is not part of this listing;
/// presumably `idx` is stored into flattenPoint (confirm). symIdx values are
/// unsigned, so if `module` is mapped to 0 (the root, or a module added by
/// reTop()) this subtraction wraps around.
365 void flattenModule(FModuleOp module) {
366 auto sym =
module.getNameAttr();
367 assert(symIdx.count(sym) &&
"module is not in the symIdx map");
368 auto idx = symIdx[sym] - 1;
/// Add `module` as a new root of this NLA.
/// - The first new top reuses the NLA's own symbol name; later tops get a
///   fresh name derived from it via the circuit namespace.
/// - The module is recorded in rootSet and mapped to namepath index 0
///   (DenseMap::insert does not overwrite an existing entry).
/// NOTE(review): the return statement is not part of this listing; callers
/// use the result as the new top's symbol.
376 StringAttr reTop(FModuleOp module) {
377 StringAttr sym = nla.getSymNameAttr();
378 if (!newTops.empty())
379 sym = StringAttr::get(nla.getContext(),
380 circuitNamespace->
newName(sym.getValue()));
381 newTops.push_back(InnerRefAttr::get(module.getNameAttr(), sym));
382 rootSet.insert(module.getNameAttr());
383 symIdx.insert({
module.getNameAttr(), 0});
/// The additional (module, symbol) roots created by reTop().
388 ArrayRef<InnerRefAttr> getAdditionalSymbols() {
return ArrayRef(newTops); }
/// Record that, inside `module`, this path now goes through `innerSym`.
/// `module` must be known to this NLA, and each module may be renamed once.
390 void setInnerSym(Attribute module, StringAttr innerSym) {
391 assert(symIdx.count(module) &&
"Mutable NLA did not contain symbol");
392 assert(!renames.count(module) &&
"Module already renamed");
393 renames.insert({module, innerSym});
// Map each result of `instance` in `mapper` to the wire at the same index.
// NOTE(review): the function's leading signature line is not part of this
// listing. `wires` is assumed to hold at least getNumResults() entries.
404 InstanceOp instance) {
405 for (
unsigned i = 0, e = instance.getNumResults(); i < e; ++i) {
406 auto result = instance.getResult(i);
407 auto wire = wires[i];
408 mapper.map(result, wire);
// In every op of `newOps`, rewrite each hw::InnerRefAttr that is a key of
// `map` into an inner ref on `istName` using the mapped name.
// NOTE(review): the leading signature line, the not-found branch and the rest
// of the returned pair are not part of this listing.
416 StringAttr istName) {
417 mlir::AttrTypeReplacer replacer;
418 replacer.addReplacement([&](hw::InnerRefAttr innerRef) {
419 auto it = map.find(innerRef);
423 return std::pair{hw::InnerRefAttr::get(istName, it->second),
426 llvm::for_each(newOps,
427 [&](
auto *op) { replacer.recursivelyReplaceElementsIn(op); });
// Make every property name of `old` unique in namespace `ns`, keeping each
// property's field ID and visibility. Record the old -> new names in `map`,
// keyed by inner refs on `istName`.
// NOTE(review): the signature line, the early return for a null/empty `old`,
// and the statements that skip unchanged properties and set `anyChanged` are
// not part of this listing.
435 StringAttr istName) {
436 if (!old || old.empty())
439 bool anyChanged =
false;
441 SmallVector<hw::InnerSymPropertiesAttr> newProps;
442 auto *context = old.getContext();
443 for (
auto &prop : old) {
444 auto newSym = ns.
newName(prop.getName().strref());
// The namespace handed back the same name, so the property is reused as is.
445 if (newSym == prop.getName()) {
446 newProps.push_back(prop);
449 auto newSymStrAttr = StringAttr::get(context, newSym);
450 auto newProp = hw::InnerSymPropertiesAttr::get(
451 context, newSymStrAttr, prop.getFieldID(), prop.getSymVisibility());
453 newProps.push_back(newProp);
// Build a new attribute only if some name actually changed.
456 auto newSymAttr = anyChanged ? hw::InnerSymAttr::get(context, newProps) : old;
// Record a mapping for every property, including identity mappings.
458 for (
auto [oldProp, newProp] : llvm::zip(old, newSymAttr)) {
459 assert(oldProp.getFieldID() == newProp.getFieldID());
461 map[hw::InnerRefAttr::get(istName, oldProp.getName())] = newProp.getName();
// Fragment of the Inliner class declaration.
// NOTE(review): the class header, access specifiers, run() declaration, the
// `circuit` member and closing braces are not part of this listing.
493 Inliner(CircuitOp circuit, SymbolTable &symbolTable);
// State shared while inlining into one module: the module, a namespace built
// from it, and the builder used to create ops.
501 struct ModuleInliningContext {
502 ModuleInliningContext(FModuleOp module)
503 : module(module), modNamespace(module), b(module.getContext()) {}
// State for one inlined instance: the shared context, the ops cloned for it
// (newOps), the wires standing in for the child's ports, and the child.
515 struct InliningLevel {
516 InliningLevel(ModuleInliningContext &mic, FModuleOp childModule)
517 : mic(mic), childModule(childModule) {}
520 ModuleInliningContext &mic;
524 SmallVector<Operation *> newOps;
526 SmallVector<Value> wires;
528 FModuleOp childModule;
// NOTE(review): only the tail of this statement is visible.
534 mic.module.getNameAttr());
// Whether `nla` is among the currently active hierarchical paths.
541 bool doesNLAMatchCurrentPath(hw::HierPathOp nla);
// Prefix/uniquify names of a cloned op; returns whether its inner sym changed.
545 bool rename(StringRef prefix, Operation *op, InliningLevel &il);
550 bool renameInstance(StringRef prefix, InliningLevel &il, InstanceOp oldInst,
552 const DenseMap<Attribute, Attribute> &symbolRenames);
556 void cloneAndRename(StringRef prefix, InliningLevel &il, IRMapping &mapper,
558 const DenseMap<Attribute, Attribute> &symbolRenames,
559 const DenseSet<Attribute> &localSymbols);
564 void mapPortsToWires(StringRef prefix, InliningLevel &il, IRMapping &mapper,
565 const DenseSet<Attribute> &localSymbols);
568 bool shouldFlatten(Operation *op);
571 bool shouldInline(Operation *op);
575 LogicalResult checkInstanceParents(InstanceOp instance);
581 inliningWalk(OpBuilder &builder, Block *block, IRMapping &mapper,
582 llvm::function_ref<LogicalResult(Operation *op)> process);
587 LogicalResult flattenInto(StringRef prefix, InliningLevel &il,
589 DenseSet<Attribute> localSymbols);
594 LogicalResult inlineInto(StringRef prefix, InliningLevel &il,
596 DenseMap<Attribute, Attribute> &symbolRenames);
599 LogicalResult flattenInstances(FModuleOp module);
602 LogicalResult inlineInstances(FModuleOp module);
606 void createDebugScope(InliningLevel &il, InstanceOp instance,
607 Value parentScope = {});
610 void identifyNLAsTargetingOnlyModules();
// Update activeHierpaths when descending into instance `instInnerSym` of
// `moduleName`:
// - at the top level (empty currentPath), every NLA through the instance
//   becomes active;
// - otherwise the active set is intersected with the NLAs through the
//   instance, and NLAs through it that are rooted at `moduleName` are added.
// NOTE(review): the `instPaths` declaration and the early exit of the first
// branch are not part of this listing.
616 void setActiveHierPaths(StringAttr moduleName, StringAttr instInnerSym) {
618 instOpHierPaths[InnerRefAttr::get(moduleName, instInnerSym)];
619 if (currentPath.empty()) {
620 activeHierpaths.insert(instPaths.begin(), instPaths.end());
623 DenseSet<StringAttr> hPaths(instPaths.begin(), instPaths.end());
626 llvm::set_intersect(activeHierpaths, hPaths);
629 for (
auto hPath : instPaths)
630 if (nlaMap[hPath].hasRoot(moduleName))
631 activeHierpaths.insert(hPath);
635 MLIRContext *context;
638 SymbolTable &symbolTable;
// Modules discovered as live. run() marks NLAs rooted at non-live modules dead.
642 DenseSet<Operation *> liveModules;
// Modules still to be processed by run().
645 SmallVector<FModuleOp, 16> worklist;
// NLA symbol -> mutable view of that NLA.
648 DenseMap<Attribute, MutableNLA> nlaMap;
// Root module name -> symbols of the NLAs rooted there.
651 DenseMap<Attribute, SmallVector<Attribute>> rootMap;
// (module, instance inner sym) pairs of the instances currently being processed.
656 SmallVector<std::pair<Attribute, Attribute>> currentPath;
// NLA symbols consistent with currentPath.
658 DenseSet<StringAttr> activeHierpaths;
// Instance inner ref -> symbols of the NLAs passing through that instance.
663 DenseMap<InnerRefAttr, SmallVector<StringAttr>> instOpHierPaths;
// Debug scopes created during inlining; unused ones are handled in run().
667 SmallVector<debug::ScopeOp> debugScopes;
// Return true if `nla`'s symbol is in the set of currently active paths.
674bool Inliner::doesNLAMatchCurrentPath(hw::HierPathOp nla) {
675 return (activeHierpaths.find(nla.getSymNameAttr()) != activeHierpaths.end());
// Rename a cloned op `op` for inlining level `il`:
// - debug.variable / debug.scope ops are only re-parented onto il.debugScope
//   and report no symbol change;
// - otherwise the "name" attribute (if any) gets `prefix`, and the op's inner
//   symbol is made unique (the call is only partially visible; presumably
//   uniqueInNamespace);
// - NLAs on the current path that reference the renamed symbol record the
//   new name for il.mic.module.
// Returns true iff the inner symbol attribute changed.
// NOTE(review): several lines (annotation loop, early returns) are missing
// from this listing.
682bool Inliner::rename(StringRef prefix, Operation *op, InliningLevel &il) {
685 auto updateDebugScope = [&](
auto op) {
687 op.getScopeMutable().assign(il.debugScope);
689 if (
auto varOp = dyn_cast<debug::VariableOp>(op))
690 return updateDebugScope(varOp),
false;
691 if (
auto scopeOp = dyn_cast<debug::ScopeOp>(op))
692 return updateDebugScope(scopeOp),
false;
695 if (
auto nameAttr = op->getAttrOfType<StringAttr>(
"name"))
696 op->setAttr(
"name", StringAttr::get(op->getContext(),
697 (prefix + nameAttr.getValue())));
701 auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(op);
704 auto oldSymAttr = symOp.getInnerSymAttr();
707 il.childModule.getNameAttr());
// Only when the root symbol name actually changed do NLAs need updating.
713 if (
auto newSymStrAttr = newSymAttr.getSymName();
714 newSymStrAttr && newSymStrAttr != oldSymAttr.getSymName()) {
716 auto sym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal");
724 auto &mnla = nlaMap[sym.getAttr()];
725 if (!doesNLAMatchCurrentPath(mnla.getNLA()))
727 mnla.setInnerSym(il.mic.module.getModuleNameAttr(), newSymStrAttr);
731 symOp.setInnerSymbolAttr(newSymAttr);
733 return newSymAttr != oldSymAttr;
// Rename `newInst`, a clone of `oldInst` from il.childModule, and carry over
// the NLAs that pass through it:
// - activeHierpaths is temporarily narrowed to the old instance and restored
//   before returning;
// - an NLA is kept if it is active itself, or if one of its additional tops
//   is active;
// - NLA names are then rewritten through `symbolRenames`.
// Returns whether the instance's inner symbol changed.
// NOTE(review): several lines (debug-scope condition, newSymAttr definition,
// closing braces) are missing from this listing.
736bool Inliner::renameInstance(
737 StringRef prefix, InliningLevel &il, InstanceOp oldInst, InstanceOp newInst,
738 const DenseMap<Attribute, Attribute> &symbolRenames) {
743 llvm::dbgs() <<
"Discarding parent debug scope for " << oldInst <<
"\n";
// Saved so the narrowing below can be undone at the end.
748 auto parentActivePaths = activeHierpaths;
749 assert(oldInst->getParentOfType<FModuleOp>() == il.childModule);
751 setActiveHierPaths(oldInst->getParentOfType<FModuleOp>().getNameAttr(),
// Collect the NLAs through the old instance that are still relevant.
756 SmallVector<StringAttr> validHierPaths;
757 auto oldParent = oldInst->getParentOfType<FModuleOp>().getNameAttr();
764 auto oldInnerRef = InnerRefAttr::get(oldParent, oldInstSym);
765 for (
auto old : instOpHierPaths[oldInnerRef]) {
769 if (activeHierpaths.find(old) != activeHierpaths.end())
770 validHierPaths.push_back(old);
774 for (
auto additionalSym : nlaMap[old].getAdditionalSymbols())
775 if (activeHierpaths.find(additionalSym.
getName()) !=
776 activeHierpaths.
end()) {
777 validHierPaths.push_back(old);
786 auto symbolChanged = rename(prefix, newInst, il);
// Register the surviving NLAs under the new instance's inner ref and tell each
// one about the instance's new symbol.
794 auto newInnerRef = InnerRefAttr::get(
795 newInst->getParentOfType<FModuleOp>().getNameAttr(), newSymAttr);
796 instOpHierPaths[newInnerRef] = validHierPaths;
798 for (
auto nla : instOpHierPaths[newInnerRef]) {
799 if (!nlaMap.count(nla))
801 auto &mnla = nlaMap[nla];
802 mnla.setInnerSym(newInnerRef.getModule(), newSymAttr);
// Apply NLA renames (from re-topping) to the list recorded above.
// NOTE(review): `innerRef` is computed exactly like `newInnerRef` above;
// reusing it would avoid the recomputation.
807 auto innerRef = InnerRefAttr::get(
808 newInst->getParentOfType<FModuleOp>().getNameAttr(), newSymAttr);
809 SmallVector<StringAttr> &nlaList = instOpHierPaths[innerRef];
811 for (
const auto &
en :
llvm::enumerate(nlaList)) {
812 auto oldNLA =
en.value();
813 if (
auto newSym = symbolRenames.lookup(oldNLA))
814 nlaList[
en.index()] = cast<StringAttr>(newSym);
817 activeHierpaths = std::move(parentActivePaths);
818 return symbolChanged;
// For each port of il.childModule, create a wire and map the port's block
// argument to it. Each wire gets:
// - the name `prefix` + port name, with a droppable name kind;
// - the port's filtered annotations;
// - a uniquified inner symbol.
// Each wire is appended to il.wires in port order.
// NOTE(review): the mapper parameter line, the annotation loop header and the
// wire-creation call head are not part of this listing.
826void Inliner::mapPortsToWires(StringRef prefix, InliningLevel &il,
828 const DenseSet<Attribute> &localSymbols) {
829 auto target = il.childModule;
830 auto portInfo = target.getPorts();
831 for (
unsigned i = 0, e = target.getNumPorts(); i < e; ++i) {
832 auto arg = target.getArgument(i);
834 auto type = type_cast<FIRRTLType>(arg.getType());
837 auto oldSymAttr = portInfo[i].sym;
840 il.mic.modNamespace, target.getNameAttr());
842 StringAttr newRootSymName, oldRootSymName;
844 oldRootSymName = oldSymAttr.getSymName();
846 newRootSymName = newSymAttr.getSymName();
848 SmallVector<Attribute> newAnnotations;
851 if (
auto sym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal")) {
852 auto &mnla = nlaMap[sym.getAttr()];
// NLAs that are not on the current path are handled on a line not shown.
854 if (!doesNLAMatchCurrentPath(mnla.getNLA()))
858 if (oldRootSymName != newRootSymName)
859 mnla.setInnerSym(il.mic.module.getModuleNameAttr(), newRootSymName);
// A local NLA no longer needs the non-local reference on the annotation.
861 if (mnla.isLocal() || localSymbols.count(sym.getAttr()))
862 anno.removeMember(
"circt.nonlocal");
864 newAnnotations.push_back(anno.getAttr());
869 il.mic.b, target.getLoc(), type,
870 StringAttr::get(context, (prefix + portInfo[i].getName())),
871 NameKindEnumAttr::get(context, NameKindEnum::DroppableName),
872 ArrayAttr::get(context, newAnnotations), newSymAttr,
875 il.wires.push_back(wire);
876 mapper.map(arg, wire);
// Clone `op` (which must have no regions) at il.mic.b's insertion point:
// - non-local annotations whose NLA is not on the current path are handled on
//   a line not shown;
// - "circt.nonlocal" is stripped from annotations whose NLA is local or
//   listed in `localSymbols`;
// - the clone is renamed (via renameInstance for instances);
// - the clone is recorded in il.newOps.
// NOTE(review): the `oldAnnotations` definition and some lines are not part
// of this listing.
883void Inliner::cloneAndRename(
884 StringRef prefix, InliningLevel &il, IRMapping &mapper, Operation &op,
885 const DenseMap<Attribute, Attribute> &symbolRenames,
886 const DenseSet<Attribute> &localSymbols) {
889 SmallVector<Annotation> newAnnotations;
890 for (
auto anno : oldAnnotations) {
893 if (
auto sym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal")) {
895 auto &mnla = nlaMap[sym.getAttr()];
897 if (!doesNLAMatchCurrentPath(mnla.getNLA()))
900 if (mnla.isLocal() || localSymbols.count(sym.getAttr()))
901 anno.removeMember(
"circt.nonlocal");
904 newAnnotations.push_back(anno);
// Region-holding ops are handled by inliningWalk, not here.
908 assert(op.getNumRegions() == 0 &&
909 "operation with regions should not reach cloneAndRename");
910 auto *newOp = il.mic.b.cloneWithoutRegions(op, mapper);
917 if (
auto oldInst = dyn_cast<InstanceOp>(op))
918 renameInstance(prefix, il, oldInst, cast<InstanceOp>(newOp), symbolRenames);
920 rename(prefix, newOp, il);
// Touch the clone's annotations only if there were or still are some.
924 if (!newAnnotations.empty() || !oldAnnotations.empty())
927 il.newOps.push_back(newOp);
// Module predicates used to decide whether an instance is flattened or
// inlined. Their bodies are not part of this listing.
// NOTE(review): presumably they test for the flatten / inline annotation
// classes; confirm.
930bool Inliner::shouldFlatten(Operation *op) {
934bool Inliner::shouldInline(Operation *op) {
// Clone the contents of `block` at the builder's current insertion point.
// - `process` is called for every op without regions and must leave the
//   insertion point unchanged.
// - Ops with regions must be LayerBlockOp, WhenOp or MatchOp. They are cloned
//   without regions, a single fresh block (with mapped arguments) is created
//   per region, and the region contents are walked with an explicit stack
//   instead of recursion.
// - An unsupported region-holding op produces an error.
// NOTE(review): the IPs struct header, the `source` declaration and several
// control-flow lines are not part of this listing.
938LogicalResult Inliner::inliningWalk(
939 OpBuilder &builder, Block *block, IRMapping &mapper,
940 llvm::function_ref<LogicalResult(Operation *op)> process) {
// Stack entry: where to insert clones, and the next source op to visit.
942 OpBuilder::InsertPoint target;
943 Block::iterator source;
946 SmallVector<IPs> inliningStack;
950 inliningStack.push_back(IPs{builder.saveInsertionPoint(), block->begin()});
951 OpBuilder::InsertionGuard guard(builder);
953 while (!inliningStack.empty()) {
954 auto target = inliningStack.back().target;
955 builder.restoreInsertionPoint(target);
// Take the current op and advance the cursor. An exhausted entry is popped now,
// so entries pushed for nested regions end up on top.
959 auto &ips = inliningStack.back();
960 source = &*ips.source;
961 auto end = source->getBlock()->end();
962 if (++ips.source == end)
963 inliningStack.pop_back();
967 if (source->getNumRegions() == 0) {
968 assert(builder.saveInsertionPoint().getPoint() == target.getPoint());
970 if (failed(process(source)))
972 assert(builder.saveInsertionPoint().getPoint() == target.getPoint());
978 if (!isa<LayerBlockOp, WhenOp, MatchOp>(source))
979 return source->emitError(
"unsupported operation '")
980 << source->getName() <<
"' cannot be inlined";
986 auto *newOp = builder.cloneWithoutRegions(*source, mapper);
// Regions are pushed in reverse so the first region ends up on top of the
// stack and is processed first.
987 for (
auto [newRegion, oldRegion] :
llvm::reverse(
988 llvm::zip_equal(newOp->getRegions(), source->getRegions()))) {
990 if (oldRegion.empty()) {
991 assert(newRegion.empty());
995 assert(oldRegion.hasOneBlock());
998 auto &oldBlock = oldRegion.getBlocks().front();
999 auto &newBlock = newRegion.emplaceBlock();
1000 mapper.map(&oldBlock, &newBlock);
1003 for (
auto arg : oldBlock.getArguments())
1004 mapper.map(arg, newBlock.addArgument(arg.getType(), arg.
getLoc()));
1006 if (oldBlock.empty())
1009 inliningStack.push_back(
1010 IPs{OpBuilder::InsertPoint(&newBlock, newBlock.begin()),
// Verify that every op between `instance` and its enclosing module is a
// LayerBlockOp. Otherwise emit an error on the instance, with a note on the
// offending parent.
1017LogicalResult Inliner::checkInstanceParents(InstanceOp instance) {
1018 auto *parent = instance->getParentOp();
1019 while (!isa<FModuleLike>(parent)) {
1020 if (!isa<LayerBlockOp>(parent))
1021 return instance->emitError(
"cannot inline instance")
1022 .attachNote(parent->getLoc())
1023 <<
"containing operation '" << parent->getName()
1024 <<
"' not safe to inline into";
1025 parent = parent->getParentOp();
// Recursively flatten the body of il.childModule into il.mic.module.
// - Non-instance ops are cloned and renamed.
// - Instances of non-FModuleOp modules are marked live and cloned.
// - FModuleOp instances are checked, their NLAs rooted at the child are
//   treated as local, their ports become wires, and their bodies are
//   flattened in with a nested name prefix.
// `localSymbols` is taken by value, so additions made here do not reach the
// caller.
// NOTE(review): some lines (the mapper parameter, early returns, the
// instInnerSym definition) are not part of this listing. `symbolRenames` is
// never populated in the visible code.
1031LogicalResult Inliner::flattenInto(StringRef prefix, InliningLevel &il,
1033 DenseSet<Attribute> localSymbols) {
1034 auto target = il.childModule;
1035 auto moduleName = target.getNameAttr();
1036 DenseMap<Attribute, Attribute> symbolRenames;
1038 LLVM_DEBUG(llvm::dbgs() <<
"flattening " << target.getModuleName() <<
" into "
1039 << il.mic.module.getModuleName() <<
"\n");
1040 auto visit = [&](Operation *op) {
1042 auto instance = dyn_cast<InstanceOp>(op);
1044 cloneAndRename(prefix, il, mapper, *op, symbolRenames, localSymbols);
1049 auto *moduleOp = symbolTable.lookup(instance.getModuleName());
1050 auto childModule = dyn_cast<FModuleOp>(moduleOp);
1052 liveModules.insert(moduleOp);
1054 cloneAndRename(prefix, il, mapper, *op, symbolRenames, localSymbols);
1058 if (failed(checkInstanceParents(instance)))
// NLAs rooted at the child become local once it is flattened into the parent.
1064 llvm::set_union(localSymbols, rootMap[childModule.getNameAttr()]);
// Descend into the instance, narrowing the active paths; both are restored
// after the recursive call.
1066 auto parentActivePaths = activeHierpaths;
1067 setActiveHierPaths(moduleName, instInnerSym);
1068 currentPath.emplace_back(moduleName, instInnerSym);
1070 InliningLevel childIL(il.mic, childModule);
1071 createDebugScope(childIL, instance, il.debugScope);
1074 auto nestedPrefix = (prefix + instance.getName() +
"_").str();
1075 mapPortsToWires(nestedPrefix, childIL, mapper, localSymbols);
1079 if (failed(flattenInto(nestedPrefix, childIL, mapper, localSymbols)))
1081 currentPath.pop_back();
1082 activeHierpaths = parentActivePaths;
1085 return inliningWalk(il.mic.b, target.getBodyBlock(), mapper, visit);
// Flatten every FModuleOp instance inside `module` (a module marked for
// flattening). Instances of other module kinds only mark their target live.
// For each flattened instance:
// - NLAs passing through it record the flattening;
// - NLAs rooted at the target are treated as local;
// - the instance results are replaced by the port wires;
// - the target body is flattened in at the instance position.
// The walk skips the instance afterwards, and failure is returned if the walk
// is interrupted.
// NOTE(review): several lines (conditions, instSym definitions, instance
// cleanup) are not part of this listing.
1088LogicalResult Inliner::flattenInstances(FModuleOp module) {
1089 auto moduleName =
module.getNameAttr();
1090 ModuleInliningContext mic(module);
1092 auto visit = [&](InstanceOp instance) {
1094 auto *targetModule = symbolTable.lookup(instance.getModuleName());
1095 auto target = dyn_cast<FModuleOp>(targetModule);
1097 liveModules.insert(targetModule);
1098 return WalkResult::advance();
1101 if (failed(checkInstanceParents(instance)))
1102 return WalkResult::interrupt();
1105 auto innerRef = InnerRefAttr::get(moduleName, instSym);
1109 for (
auto targetNLA : instOpHierPaths[innerRef])
1110 nlaMap[targetNLA].flattenModule(target);
1116 DenseSet<Attribute> localSymbols;
1117 llvm::set_union(localSymbols, rootMap[target.getNameAttr()]);
1119 auto parentActivePaths = activeHierpaths;
1120 setActiveHierPaths(moduleName, instInnerSym);
1121 currentPath.emplace_back(moduleName, instInnerSym);
// New ops (wires, flattened body) are created right before the instance.
1126 mic.b.setInsertionPoint(instance);
1128 InliningLevel il(mic, target);
1129 createDebugScope(il, instance);
1131 auto nestedPrefix = (instance.getName() +
"_").str();
1132 mapPortsToWires(nestedPrefix, il, mapper, localSymbols);
1133 for (
unsigned i = 0, e = instance.getNumResults(); i < e; ++i)
1134 instance.getResult(i).replaceAllUsesWith(il.wires[i]);
1137 if (failed(flattenInto(nestedPrefix, il, mapper, localSymbols)))
1138 return WalkResult::interrupt();
1139 currentPath.pop_back();
1140 activeHierpaths = parentActivePaths;
1144 return WalkResult::skip();
1146 return failure(module.getBodyBlock()
1147 ->walk<mlir::WalkOrder::PreOrder>(visit)
// Recursively inline the body of il.childModule into il.mic.module.
// - Non-instance ops are cloned and renamed.
// - Instances of non-FModuleOp modules are marked live and cloned.
// - Instances of modules not marked for inlining are cloned, and the module is
//   queued on the worklist the first time it becomes live.
// - For inlined instances:
//   - NLAs through the instance record the inlining or flattening (the
//     selecting condition is not shown; presumably `toBeFlattened`);
//   - NLAs rooted at the child are re-topped at il.mic.module, and the
//     instance gets a fresh inner symbol;
//   - the child is then flattened or recursively inlined.
// NOTE(review): the return-type line and several control-flow lines are not
// part of this listing.
1153Inliner::inlineInto(StringRef prefix, InliningLevel &il, IRMapping &mapper,
1154 DenseMap<Attribute, Attribute> &symbolRenames) {
1155 auto target = il.childModule;
1156 auto inlineToParent = il.mic.module;
1157 auto moduleName = target.getNameAttr();
1159 LLVM_DEBUG(llvm::dbgs() <<
"inlining " << target.getModuleName() <<
" into "
1160 << inlineToParent.getModuleName() <<
"\n");
1162 auto visit = [&](Operation *op) {
1164 auto instance = dyn_cast<InstanceOp>(op);
1166 cloneAndRename(prefix, il, mapper, *op, symbolRenames, {});
1171 auto *moduleOp = symbolTable.lookup(instance.getModuleName());
1172 auto childModule = dyn_cast<FModuleOp>(moduleOp);
1174 liveModules.insert(moduleOp);
1175 cloneAndRename(prefix, il, mapper, *op, symbolRenames, {});
1180 if (!shouldInline(childModule)) {
1181 if (liveModules.insert(childModule).second) {
1182 worklist.push_back(childModule);
1184 cloneAndRename(prefix, il, mapper, *op, symbolRenames, {});
1188 if (failed(checkInstanceParents(instance)))
1191 auto toBeFlattened = shouldFlatten(childModule);
1193 auto innerRef = InnerRefAttr::get(moduleName, instSym);
1197 for (
auto sym : instOpHierPaths[innerRef]) {
1199 nlaMap[sym].flattenModule(childModule);
1201 nlaMap[sym].inlineModule(childModule);
// Re-top NLAs rooted at the child onto the module being inlined into.
// NOTE(review): this local shadows the `symbolRenames` parameter within the
// lambda. Unlike inlineInstances, no `!toBeFlattened` guard is visible here;
// confirm that is intended.
1212 DenseMap<Attribute, Attribute> symbolRenames;
1213 if (!rootMap[childModule.getNameAttr()].empty()) {
1214 for (
auto sym : rootMap[childModule.getNameAttr()]) {
1215 auto &mnla = nlaMap[sym];
1218 sym = mnla.reTop(inlineToParent);
1221 instSym = StringAttr::get(
1222 context, il.mic.modNamespace.newName(instance.getName()));
1223 instance.setInnerSymAttr(hw::InnerSymAttr::get(instSym));
1225 instOpHierPaths[InnerRefAttr::get(moduleName, instSym)].push_back(
1226 cast<StringAttr>(sym));
1230 symbolRenames.insert({mnla.getNLA().getNameAttr(), sym});
1234 auto parentActivePaths = activeHierpaths;
1235 setActiveHierPaths(moduleName, instInnerSym);
1237 currentPath.emplace_back(moduleName, instInnerSym);
1239 InliningLevel childIL(il.mic, childModule);
1240 createDebugScope(childIL, instance, il.debugScope);
1243 auto nestedPrefix = (prefix + instance.getName() +
"_").str();
1244 mapPortsToWires(nestedPrefix, childIL, mapper, {});
1248 if (toBeFlattened) {
1249 if (failed(flattenInto(nestedPrefix, childIL, mapper, {})))
1252 if (failed(inlineInto(nestedPrefix, childIL, mapper, symbolRenames)))
1255 currentPath.pop_back();
1256 activeHierpaths = parentActivePaths;
1260 return inliningWalk(il.mic.b, target.getBodyBlock(), mapper, visit);
// Inline every instance in `module` whose target is marked for inlining.
// - Non-FModuleOp targets are marked live.
// - Targets not marked for inlining are marked live and queued once.
// - For each inlined instance:
//   - NLAs through it record the inlining or flattening;
//   - NLAs rooted at the target are re-topped at `module` (only when the
//     target is not flattened);
//   - the instance results are replaced by port wires;
//   - the target body is flattened or recursively inlined before the instance.
// NOTE(review): several lines (conditions, instSym definitions, the namespace
// lambda head) are not part of this listing.
1263LogicalResult Inliner::inlineInstances(FModuleOp module) {
1265 auto moduleName =
module.getNameAttr();
1266 ModuleInliningContext mic(module);
1268 auto visit = [&](InstanceOp instance) {
1270 auto *childModule = symbolTable.lookup(instance.getModuleName());
1271 auto target = dyn_cast<FModuleOp>(childModule);
1273 liveModules.insert(childModule);
1274 return WalkResult::advance();
1278 if (!shouldInline(target)) {
1279 if (liveModules.insert(target).second) {
1280 worklist.push_back(target);
1282 return WalkResult::advance();
1285 if (failed(checkInstanceParents(instance)))
1286 return WalkResult::interrupt();
1288 auto toBeFlattened = shouldFlatten(target);
1290 auto innerRef = InnerRefAttr::get(moduleName, instSym);
1294 for (
auto sym : instOpHierPaths[innerRef]) {
1296 nlaMap[sym].flattenModule(target);
1298 nlaMap[sym].inlineModule(target);
1305 DenseMap<Attribute, Attribute> symbolRenames;
1306 if (!rootMap[target.getNameAttr()].empty() && !toBeFlattened) {
1307 for (
auto sym : rootMap[target.getNameAttr()]) {
1308 auto &mnla = nlaMap[sym];
1309 sym = mnla.reTop(module);
1312 return mic.modNamespace;
1314 instOpHierPaths[InnerRefAttr::get(moduleName, instSym)].push_back(
1315 cast<StringAttr>(sym));
1319 symbolRenames.insert({mnla.getNLA().getNameAttr(), sym});
1323 auto parentActivePaths = activeHierpaths;
1324 setActiveHierPaths(moduleName, instInnerSym);
1326 currentPath.emplace_back(moduleName, instInnerSym);
// New ops are created right before the instance being replaced.
1330 mic.b.setInsertionPoint(instance);
1331 auto nestedPrefix = (instance.getName() +
"_").str();
1333 InliningLevel childIL(mic, target);
1334 createDebugScope(childIL, instance);
1336 mapPortsToWires(nestedPrefix, childIL, mapper, {});
1337 for (
unsigned i = 0, e = instance.getNumResults(); i < e; ++i)
1338 instance.getResult(i).replaceAllUsesWith(childIL.wires[i]);
1341 if (toBeFlattened) {
1342 if (failed(flattenInto(nestedPrefix, childIL, mapper, {})))
1343 return WalkResult::interrupt();
1347 if (failed(inlineInto(nestedPrefix, childIL, mapper, symbolRenames)))
1348 return WalkResult::interrupt();
1350 currentPath.pop_back();
1351 activeHierpaths = parentActivePaths;
1355 return WalkResult::skip();
1358 return failure(module.getBodyBlock()
1359 ->walk<mlir::WalkOrder::PreOrder>(visit)
// Create a debug.scope op for `instance` (instance name and module name),
// nested in `parentScope`, and remember it in debugScopes.
// NOTE(review): assigning the scope to il.debugScope is not shown in this
// listing.
1363void Inliner::createDebugScope(InliningLevel &il, InstanceOp instance,
1364 Value parentScope) {
1365 auto op = debug::ScopeOp::create(
1366 il.mic.b, instance.getLoc(), instance.getInstanceNameAttr(),
1367 instance.getModuleNameAttr().getAttr(), parentScope);
1368 debugScopes.push_back(op);
// Mark as module-only every NLA that ends in a module reference and that is
// never referenced via "circt.nonlocal" from within its target module. The
// scan covers module annotations, port annotations, and op and mem/instance
// port annotations.
// NOTE(review): several lines (the assert head, annotation loops, the reduce
// call head; presumably transformReduce) are not part of this listing.
1372void Inliner::identifyNLAsTargetingOnlyModules() {
1373 DenseSet<Operation *> nlaTargetedModules;
// Collect the modules targeted by module-reference NLAs.
1376 for (
auto &[sym, mnla] : nlaMap) {
1377 auto nla = mnla.getNLA();
1378 if (nla.isModule()) {
1379 auto mod = symbolTable.lookup<FModuleLike>(nla.leafMod());
1381 "NLA ends in module reference but does not target FModuleLike?");
1382 nlaTargetedModules.insert(mod);
// Return the NLA symbols referenced by annotations inside `mod`.
1387 auto scanForNLARefs = [&](FModuleLike mod) {
1388 DenseSet<StringAttr> referencedNLASyms;
1390 for (
auto anno : annos)
1391 if (auto sym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal"))
1392 referencedNLASyms.insert(sym.getAttr());
1395 for (
unsigned i = 0, e = mod.getNumPorts(); i != e; ++i)
1400 mod.walk([&](Operation *op) {
1401 if (op == mod.getOperation())
1406 TypeSwitch<Operation *>(op).Case<MemOp, InstanceOp>([&](
auto op) {
1407 for (
auto portAnnoAttr : op.getPortAnnotations())
1412 return referencedNLASyms;
// Union two symbol sets, reusing the first one's storage.
1416 auto mergeSets = [](
auto &&a,
auto &&b) {
1417 a.insert(b.begin(), b.end());
1418 return std::move(a);
1423 SmallVector<FModuleLike, 0> mods(nlaTargetedModules.begin(),
1424 nlaTargetedModules.end());
1425 auto nonModOnlyNLAs =
1427 mergeSets, scanForNLARefs);
1430 for (
auto &[_, mnla] : nlaMap) {
1431 auto nla = mnla.getNLA();
1432 if (nla.isModule() && !nonModOnlyNLAs.count(nla.getSymNameAttr()))
1433 mnla.markModuleOnly();
// Store the circuit, its MLIR context and the circuit's symbol table.
1437Inliner::Inliner(CircuitOp circuit, SymbolTable &symbolTable)
1438 : circuit(circuit), context(circuit.getContext()),
1439 symbolTable(symbolTable) {}
// Driver for the pass. NOTE(review): several lines (early continues/returns,
// the module erase, the assertion head, the annotation fetch) are not part of
// this listing.
1441LogicalResult Inliner::run() {
// Index every hierarchical path: by symbol, by root module, and by each
// instance inner ref it goes through.
1445 for (
auto nla : circuit.
getBodyBlock()->getOps<
hw::HierPathOp>()) {
1446 auto mnla = MutableNLA(nla, &circuitNamespace);
1447 nlaMap.insert({nla.getSymNameAttr(), mnla});
1448 rootMap[mnla.getNLA().root()].push_back(nla.getSymNameAttr());
1449 for (
auto p : nla.getNamepath())
1450 if (auto ref = dyn_cast<InnerRefAttr>(p))
1451 instOpHierPaths[ref].push_back(nla.getSymNameAttr());
1455 identifyNLAsTargetingOnlyModules();
// Seed liveModules/worklist from modules that cannot be discarded, and from
// modules referenced by symbol from ops other than hierarchical paths.
1458 for (
auto &op : circuit.getOps()) {
1460 if (
auto module = dyn_cast<FModuleLike>(op)) {
1461 if (module.canDiscardOnUseEmpty())
1463 liveModules.insert(module);
1464 if (isa<FModuleOp>(module))
1465 worklist.push_back(cast<FModuleOp>(module));
1470 if (isa<hw::HierPathOp>(op))
1474 auto symbolUses = SymbolTable::getSymbolUses(&op);
1477 for (
const auto &use : *symbolUses) {
1478 if (
auto flat = dyn_cast<FlatSymbolRefAttr>(use.getSymbolRef()))
1479 if (
auto moduleLike = symbolTable.lookup<FModuleLike>(flat.getAttr()))
1480 if (liveModules.insert(moduleLike).second)
1481 if (
auto module = dyn_cast<FModuleOp>(*moduleLike))
1482 worklist.push_back(module);
// Flatten or inline each live module; inlining may queue further modules.
1488 while (!worklist.empty()) {
1489 auto moduleOp = worklist.pop_back_val();
1490 if (shouldFlatten(moduleOp)) {
1491 if (failed(flattenInstances(moduleOp)))
1498 if (failed(inlineInstances(moduleOp)))
// Scopes without uses are handled here, newest first (action not shown).
1505 for (
auto scopeOp :
llvm::reverse(debugScopes))
1506 if (scopeOp.use_empty())
1508 debugScopes.clear();
// NLAs rooted at modules that never became live are marked dead.
1512 for (
auto mod :
llvm::make_early_inc_range(
1514 if (liveModules.count(mod))
1516 for (
auto nla : rootMap[mod.getModuleNameAttr()])
1517 nlaMap[nla].markDead();
// Sanity checks on the remaining modules.
1523 for (
auto mod : circuit.
getBodyBlock()->getOps<FModuleLike>()) {
1524 if (shouldInline(mod)) {
1526 "non-public module with inline annotation still present");
1529 assert(!shouldFlatten(mod) &&
"flatten annotation found on live module");
1533 llvm::dbgs() <<
"NLA modifications:\n";
1534 for (
auto nla : circuit.
getBodyBlock()->getOps<
hw::HierPathOp>()) {
1535 auto &mnla = nlaMap[nla.getNameAttr()];
// Write the recorded modifications back as hierarchical path ops.
1541 for (
auto &nla : nlaMap)
1542 nla.getSecond().applyUpdates();
// Fix up annotations that reference re-topped NLAs: for every additional top
// beyond the first, add a copy of the annotation whose "circt.nonlocal" member
// points at that top's symbol.
1546 for (
auto fmodule : circuit.
getBodyBlock()->getOps<FModuleOp>()) {
1547 SmallVector<Attribute> newAnnotations;
// Predicate for removeAnnotations(); the copies it builds are collected in
// newAnnotations. Its return values are on lines not shown.
1548 auto processNLAs = [&](
Annotation anno) ->
bool {
1549 if (
auto sym = anno.getMember<FlatSymbolRefAttr>(
"circt.nonlocal")) {
1553 if (!nlaMap.count(sym.getAttr()))
// NOTE(review): this copies the whole MutableNLA (including its maps) for
// every annotation; `auto &mnla` would avoid that.
1556 auto mnla = nlaMap[sym.getAttr()];
1567 auto newTops = mnla.getAdditionalSymbols();
1568 if (newTops.empty() || mnla.hasRoot(fmodule))
1574 NamedAttrList newAnnotation;
// The first new top reuses the original NLA symbol (see reTop), so only
// the remaining tops need copies.
1575 for (
auto rootAndSym : newTops.drop_front()) {
1576 for (
auto pair : anno.getDict()) {
1577 if (pair.getName().getValue() !=
"circt.nonlocal") {
1578 newAnnotation.push_back(pair);
1581 newAnnotation.push_back(
1582 {pair.getName(), FlatSymbolRefAttr::get(rootAndSym.getName())});
1584 newAnnotations.push_back(DictionaryAttr::get(context, newAnnotation));
1589 fmodule.walk([&](Operation *op) {
1593 if (annotations.empty())
1597 newAnnotations.clear();
1598 annotations.removeAnnotations(processNLAs);
1599 annotations.addAnnotations(newAnnotations);
1600 annotations.applyToOperation(op);
// The same fix-up for the module's port annotations.
1604 SmallVector<Attribute> newPortAnnotations;
1605 for (
auto port : fmodule.getPorts()) {
1606 newAnnotations.clear();
1607 port.annotations.removeAnnotations(processNLAs);
1608 port.annotations.addAnnotations(newAnnotations);
1609 newPortAnnotations.push_back(
1610 ArrayAttr::get(context, port.annotations.getArray()));
1612 fmodule->setAttr(
"portAnnotations",
1613 ArrayAttr::get(context, newPortAnnotations));
// Pass wrapper: run the Inliner on the circuit using the SymbolTable analysis,
// and signal pass failure if it fails.
1623class InlinerPass :
public circt::firrtl::impl::InlinerBase<InlinerPass> {
1624 void runOnOperation()
override {
1626 Inliner inliner(getOperation(), getAnalysis<SymbolTable>());
1627 if (failed(inliner.run()))
1628 signalPassFailure();
assert(baseType &&"element must be base type")
static void dump(DIModule &module, raw_indented_ostream &os)
static AnnotationSet forPort(Operation *op, size_t portNo)
static Location getLoc(DefSlot slot)
DenseMap< hw::InnerRefAttr, StringAttr > InnerRefToNewNameMap
static hw::InnerSymAttr uniqueInNamespace(hw::InnerSymAttr old, InnerRefToNewNameMap &map, hw::InnerSymbolNamespace &ns, StringAttr istName)
Generate a new inner symbol based on the old one and an appropriate namespace, creating map entries ...
static void mapResultsToWires(IRMapping &mapper, SmallVectorImpl< Value > &wires, InstanceOp instance)
This function is used after inlining a module, to handle the conversion between module ports and inst...
static void replaceInnerRefUsers(ArrayRef< Operation * > newOps, const InnerRefToNewNameMap &map, StringAttr istName)
Process each operation, updating InnerRefAttr's using the specified map and the given name as the con...
static Block * getBodyBlock(FModuleLike mod)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
bool applyToOperation(Operation *op) const
Store the annotations in this set in an operation's annotations attribute, overwriting any existing a...
bool hasAnnotation(StringRef className) const
Return true if we have an annotation with the specified class name.
static AnnotationSet forPort(FModuleLike op, size_t portNo)
Get an annotation set for the specified port.
This class provides a read-only projection of an annotation.
std::pair< hw::InnerSymAttr, StringAttr > getOrAddInnerSym(MLIRContext *context, hw::InnerSymAttr attr, uint64_t fieldID, llvm::function_ref< hw::InnerSymbolNamespace &()> getNamespace)
Ensure that the InnerSymAttr has a symbol on the specified field.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
constexpr const char * inlineAnnoClass
constexpr const char * flattenAnnoClass
StringAttr getInnerSymName(Operation *op)
Return the StringAttr for the inner_sym name, if it exists.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
static ResultTy transformReduce(MLIRContext *context, IterTy begin, IterTy end, ResultTy init, ReduceFuncTy reduce, TransformFuncTy transform)
Wrapper for llvm::parallelTransformReduce that performs the transform_reduce serially when MLIR multi...
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
llvm::raw_ostream & debugFooter(int width=80)
Write a boilerplate footer to the debug stream to indicate that a pass has ended.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
The namespace of a CircuitOp, generally inhabited by modules.