23 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/Threading.h"
26 #include "mlir/Pass/Pass.h"
27 #include "llvm/ADT/EquivalenceClasses.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/TinyPtrVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "infer-resets"
37 #define GEN_PASS_DEF_INFERRESETS
38 #include "circt/Dialect/FIRRTL/Passes.h.inc"
42 using circt::igraph::InstanceOpInterface;
45 using llvm::BumpPtrAllocator;
46 using llvm::MapVector;
47 using llvm::SmallDenseSet;
48 using llvm::SmallSetVector;
49 using mlir::FailureOr;
50 using mlir::InferTypeOpInterface;
51 using mlir::WalkOrder;
53 using namespace circt;
54 using namespace firrtl;
71 std::optional<unsigned> existingPort;
72 StringAttr newPortName;
74 ResetDomain(Value reset) : reset(reset) {}
78 inline bool operator==(
const ResetDomain &a,
const ResetDomain &b) {
79 return (a.isTop == b.isTop && a.reset == b.reset);
81 inline bool operator!=(
const ResetDomain &a,
const ResetDomain &b) {
88 if (
auto arg = dyn_cast<BlockArgument>(reset)) {
89 auto module = cast<FModuleOp>(arg.getParentRegion()->getParentOp());
90 return {module.getPortNameAttr(arg.getArgNumber()), module};
92 auto op = reset.getDefiningOp();
93 return {op->getAttrOfType<StringAttr>(
"name"),
94 op->getParentOfType<FModuleOp>()};
109 auto it = cache.find(type);
110 if (it != cache.end())
112 auto nullBit = [&]() {
119 .
Case<ClockType>([&](
auto type) {
120 return builder.create<AsClockPrimOp>(nullBit());
122 .Case<AsyncResetType>([&](
auto type) {
123 return builder.create<AsAsyncResetPrimOp>(nullBit());
125 .Case<SIntType, UIntType>([&](
auto type) {
126 return builder.create<ConstantOp>(
127 type, APInt::getZero(type.getWidth().value_or(1)));
129 .Case<BundleType>([&](
auto type) {
130 auto wireOp = builder.create<WireOp>(type);
131 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
132 auto fieldType = type.getElementTypePreservingConst(i);
135 builder.create<SubfieldOp>(fieldType, wireOp.getResult(), i);
138 return wireOp.getResult();
140 .Case<FVectorType>([&](
auto type) {
141 auto wireOp = builder.create<WireOp>(type);
143 builder, type.getElementTypePreservingConst(), cache);
144 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
145 auto acc = builder.create<SubindexOp>(zero.getType(),
146 wireOp.getResult(), i);
149 return wireOp.getResult();
151 .Case<ResetType, AnalogType>(
152 [&](
auto type) {
return builder.create<InvalidValueOp>(type); })
154 llvm_unreachable(
"switch handles all types");
157 cache.insert({type, value});
174 Value reset, Value resetValue) {
178 bool resetValueUsed =
false;
180 for (
auto &use : target.getUses()) {
181 Operation *useOp = use.getOwner();
182 builder.setInsertionPoint(useOp);
183 TypeSwitch<Operation *>(useOp)
186 .Case<ConnectOp, MatchingConnectOp>([&](
auto op) {
187 if (op.getDest() != target)
189 LLVM_DEBUG(llvm::dbgs() <<
" - Insert mux into " << op <<
"\n");
191 builder.create<MuxPrimOp>(reset, resetValue, op.getSrc());
192 op.getSrcMutable().assign(muxOp);
193 resetValueUsed =
true;
196 .Case<SubfieldOp>([&](
auto op) {
198 builder.create<SubfieldOp>(resetValue, op.getFieldIndexAttr());
200 resetValueUsed =
true;
202 resetSubValue.erase();
205 .Case<SubindexOp>([&](
auto op) {
207 builder.create<SubindexOp>(resetValue, op.getIndexAttr());
209 resetValueUsed =
true;
211 resetSubValue.erase();
214 .Case<SubaccessOp>([&](
auto op) {
215 if (op.getInput() != target)
218 builder.create<SubaccessOp>(resetValue, op.getIndex());
220 resetValueUsed =
true;
222 resetSubValue.erase();
225 return resetValueUsed;
240 bool operator<(
const ResetSignal &other)
const {
return field < other.field; }
241 bool operator==(
const ResetSignal &other)
const {
242 return field == other.field;
244 bool operator!=(
const ResetSignal &other)
const {
return !(*
this == other); }
264 using ResetDrives = SmallVector<ResetDrive, 1>;
267 using ResetNetwork = llvm::iterator_range<
268 llvm::EquivalenceClasses<ResetSignal>::member_iterator>;
/// The kind of reset a reset network is inferred to be. `inferReset` counts
/// `AsyncResetType` signals as async drives and `UIntType` signals as sync
/// drives, then picks `Async` whenever any async drive is present.
enum class ResetKind {
  Async,
  Sync,
};
277 struct DenseMapInfo<ResetSignal> {
279 return ResetSignal{DenseMapInfo<FieldRef>::getEmptyKey(), {}};
282 return ResetSignal{DenseMapInfo<FieldRef>::getTombstoneKey(), {}};
287 static bool isEqual(
const ResetSignal &lhs,
const ResetSignal &rhs) {
293 template <
typename T>
296 case ResetKind::Async:
297 return os <<
"async";
298 case ResetKind::Sync:
410 struct InferResetsPass
411 :
public circt::firrtl::impl::InferResetsBase<InferResetsPass> {
412 void runOnOperation()
override;
413 void runOnOperationInner();
416 using InferResetsBase::InferResetsBase;
417 InferResetsPass(
const InferResetsPass &other) : InferResetsBase(other) {}
422 void traceResets(CircuitOp circuit);
423 void traceResets(InstanceOp inst);
424 void traceResets(Value dst, Value src, Location loc);
425 void traceResets(Value value);
426 void traceResets(Type dstType, Value dst,
unsigned dstID, Type srcType,
427 Value src,
unsigned srcID, Location loc);
429 LogicalResult inferAndUpdateResets();
430 FailureOr<ResetKind> inferReset(ResetNetwork net);
431 LogicalResult updateReset(ResetNetwork net, ResetKind kind);
437 LogicalResult collectAnnos(CircuitOp circuit);
443 FailureOr<std::optional<Value>> collectAnnos(FModuleOp module);
445 LogicalResult buildDomains(CircuitOp circuit);
446 void buildDomains(FModuleOp module,
const InstancePath &instPath,
448 unsigned indent = 0);
450 void determineImpl();
451 void determineImpl(FModuleOp module, ResetDomain &domain);
453 LogicalResult implementAsyncReset();
454 LogicalResult implementAsyncReset(FModuleOp module, ResetDomain &domain);
455 void implementAsyncReset(Operation *op, FModuleOp module, Value actualReset);
457 LogicalResult verifyNoAbstractReset();
463 ResetNetwork getResetNetwork(ResetSignal signal) {
464 return llvm::make_range(resetClasses.findLeader(signal),
465 resetClasses.member_end());
469 ResetDrives &getResetDrives(ResetNetwork net) {
470 return resetDrives[*net.begin()];
475 ResetSignal guessRoot(ResetNetwork net);
476 ResetSignal guessRoot(ResetSignal signal) {
477 return guessRoot(getResetNetwork(signal));
484 llvm::EquivalenceClasses<ResetSignal> resetClasses;
487 DenseMap<ResetSignal, ResetDrives> resetDrives;
492 DenseMap<Operation *, Value> annotatedResets;
496 MapVector<FModuleOp, SmallVector<std::pair<ResetDomain, InstancePath>, 1>>
503 std::unique_ptr<InstancePathCache> instancePathCache;
507 void InferResetsPass::runOnOperation() {
508 runOnOperationInner();
509 resetClasses = llvm::EquivalenceClasses<ResetSignal>();
511 annotatedResets.clear();
513 instancePathCache.reset(
nullptr);
514 markAnalysesPreserved<InstanceGraph>();
517 void InferResetsPass::runOnOperationInner() {
518 instanceGraph = &getAnalysis<InstanceGraph>();
519 instancePathCache = std::make_unique<InstancePathCache>(*instanceGraph);
522 traceResets(getOperation());
525 if (failed(inferAndUpdateResets()))
526 return signalPassFailure();
529 if (failed(collectAnnos(getOperation())))
530 return signalPassFailure();
533 if (failed(buildDomains(getOperation())))
534 return signalPassFailure();
540 if (failed(implementAsyncReset()))
541 return signalPassFailure();
544 if (failed(verifyNoAbstractReset()))
545 return signalPassFailure();
549 return std::make_unique<InferResetsPass>();
552 ResetSignal InferResetsPass::guessRoot(ResetNetwork net) {
553 ResetDrives &drives = getResetDrives(net);
554 ResetSignal bestSignal = *net.begin();
555 unsigned bestNumDrives = -1;
557 for (
auto signal : net) {
559 if (isa_and_nonnull<InvalidValueOp>(
560 signal.field.getValue().getDefiningOp()))
565 unsigned numDrives = 0;
566 for (
auto &drive : drives)
567 if (drive.dst == signal)
573 if (numDrives < bestNumDrives) {
574 bestNumDrives = numDrives;
593 .
Case<BundleType>([](
auto type) {
595 for (
auto e : type.getElements())
600 [](
auto type) {
return getMaxFieldID(type.getElementType()) + 1; })
601 .Default([](
auto) {
return 0; });
604 static unsigned getFieldID(BundleType type,
unsigned index) {
605 assert(index < type.getNumElements());
607 for (
unsigned i = 0; i < index; ++i)
615 assert(type.getNumElements() &&
"Bundle must have >0 fields");
617 for (
const auto &e : llvm::enumerate(type.getElements())) {
619 if (fieldID < numSubfields)
621 fieldID -= numSubfields;
623 assert(
false &&
"field id outside bundle");
629 if (oldType.isGround()) {
635 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
643 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
644 if (vectorType.getNumElements() == 0)
661 if (
auto arg = dyn_cast<BlockArgument>(value)) {
662 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
663 string += module.getPortName(arg.getArgNumber());
667 auto *op = value.getDefiningOp();
668 return TypeSwitch<Operation *, bool>(op)
669 .Case<InstanceOp, MemOp>([&](
auto op) {
670 string += op.getName();
673 op.getPortName(cast<OpResult>(value).getResultNumber()).getValue();
676 .Case<WireOp, NodeOp, RegOp, RegResetOp>([&](
auto op) {
677 string += op.getName();
680 .Default([](
auto) {
return false; });
684 SmallString<64> name;
689 auto type = value.getType();
692 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
695 auto &element = bundleType.getElements()[index];
698 string += element.name.getValue();
701 localID = localID -
getFieldID(bundleType, index);
702 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
705 type = vecType.getElementType();
712 llvm_unreachable(
"unsupported type");
724 return TypeSwitch<Type, bool>(type)
726 return type.getRecursiveTypeProperties().hasUninferredReset;
728 .Default([](
auto) {
return false; });
735 void InferResetsPass::traceResets(CircuitOp circuit) {
737 llvm::dbgs() <<
"\n";
738 debugHeader(
"Tracing uninferred resets") <<
"\n\n";
741 SmallVector<std::pair<FModuleOp, SmallVector<Operation *>>> moduleToOps;
743 for (
auto module : circuit.getOps<FModuleOp>())
744 moduleToOps.push_back({module, {}});
746 hw::InnerRefNamespace irn{getAnalysis<SymbolTable>(),
747 getAnalysis<hw::InnerSymbolTableCollection>()};
749 mlir::parallelForEach(circuit.getContext(), moduleToOps, [](
auto &e) {
750 e.first.walk([&](Operation *op) {
754 op->getResultTypes(),
755 [](mlir::Type type) { return typeContainsReset(type); }) ||
756 llvm::any_of(op->getOperandTypes(), typeContainsReset))
757 e.second.push_back(op);
761 for (
auto &[_, ops] : moduleToOps)
762 for (
auto *op : ops) {
763 TypeSwitch<Operation *>(op)
764 .Case<FConnectLike>([&](
auto op) {
765 traceResets(op.getDest(), op.getSrc(), op.getLoc());
767 .Case<InstanceOp>([&](
auto op) { traceResets(op); })
768 .Case<RefSendOp>([&](
auto op) {
770 traceResets(op.getType().getType(), op.getResult(), 0,
771 op.getBase().getType().getPassiveType(), op.getBase(),
774 .Case<RefResolveOp>([&](
auto op) {
776 traceResets(op.getType(), op.getResult(), 0,
777 op.getRef().getType().getType(), op.getRef(), 0,
780 .Case<Forceable>([&](Forceable op) {
781 if (
auto node = dyn_cast<NodeOp>(op.getOperation()))
782 traceResets(node.getResult(), node.getInput(), node.getLoc());
784 if (op.isForceable())
785 traceResets(op.getDataType(), op.getData(), 0, op.getDataType(),
786 op.getDataRef(), 0, op.getLoc());
788 .Case<RWProbeOp>([&](RWProbeOp op) {
789 auto ist = irn.lookup(op.getTarget());
792 auto baseType = op.getType().getType();
793 traceResets(baseType, op.getResult(), 0, baseType.getPassiveType(),
794 ref.getValue(), ref.getFieldID(), op.getLoc());
796 .Case<UninferredResetCastOp, ConstCastOp, RefCastOp>([&](
auto op) {
797 traceResets(op.getResult(), op.getInput(), op.getLoc());
799 .Case<InvalidValueOp>([&](
auto op) {
808 auto type = op.getType();
811 LLVM_DEBUG(llvm::dbgs() <<
"Uniquify " << op <<
"\n");
812 ImplicitLocOpBuilder builder(op->getLoc(), op);
814 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
820 auto newOp = builder.create<InvalidValueOp>(type);
825 .Case<SubfieldOp>([&](
auto op) {
828 BundleType bundleType = op.getInput().getType();
829 auto index = op.getFieldIndex();
830 traceResets(op.getType(), op.getResult(), 0,
831 bundleType.getElements()[index].type, op.getInput(),
835 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
848 FVectorType vectorType = op.getInput().getType();
849 traceResets(op.getType(), op.getResult(), 0,
850 vectorType.getElementType(), op.getInput(),
854 .Case<RefSubOp>([&](RefSubOp op) {
856 auto aggType = op.getInput().getType().getType();
857 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(aggType)
858 .Case<FVectorType>([](
auto type) {
861 .Case<BundleType>([&](
auto type) {
864 traceResets(op.getType(), op.getResult(), 0,
865 op.getResult().getType(), op.getInput(), fieldID,
873 void InferResetsPass::traceResets(InstanceOp inst) {
875 auto module = inst.getReferencedModule<FModuleOp>(*instanceGraph);
878 LLVM_DEBUG(llvm::dbgs() <<
"Visiting instance " << inst.getName() <<
"\n");
881 for (
const auto &it : llvm::enumerate(inst.getResults())) {
882 auto dir = module.getPortDirection(it.index());
883 Value dstPort = module.getArgument(it.index());
884 Value srcPort = it.value();
885 if (dir == Direction::Out)
886 std::swap(dstPort, srcPort);
887 traceResets(dstPort, srcPort, it.value().getLoc());
893 void InferResetsPass::traceResets(Value dst, Value src, Location loc) {
895 traceResets(dst.getType(), dst, 0, src.getType(), src, 0, loc);
900 void InferResetsPass::traceResets(Type dstType, Value dst,
unsigned dstID,
901 Type srcType, Value src,
unsigned srcID,
903 if (
auto dstBundle = type_dyn_cast<BundleType>(dstType)) {
904 auto srcBundle = type_cast<BundleType>(srcType);
905 for (
unsigned dstIdx = 0, e = dstBundle.getNumElements(); dstIdx < e;
907 auto dstField = dstBundle.getElements()[dstIdx].name;
908 auto srcIdx = srcBundle.getElementIndex(dstField);
911 auto &dstElt = dstBundle.getElements()[dstIdx];
912 auto &srcElt = srcBundle.getElements()[*srcIdx];
914 traceResets(srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
915 dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
918 traceResets(dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
919 srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
926 if (
auto dstVector = type_dyn_cast<FVectorType>(dstType)) {
927 auto srcVector = type_cast<FVectorType>(srcType);
928 auto srcElType = srcVector.getElementType();
929 auto dstElType = dstVector.getElementType();
942 traceResets(dstElType, dst, dstID +
getFieldID(dstVector), srcElType, src,
948 if (
auto dstRef = type_dyn_cast<RefType>(dstType)) {
949 auto srcRef = type_cast<RefType>(srcType);
950 return traceResets(dstRef.getType(), dst, dstID, srcRef.getType(), src,
955 auto dstBase = type_dyn_cast<FIRRTLBaseType>(dstType);
956 auto srcBase = type_dyn_cast<FIRRTLBaseType>(srcType);
957 if (!dstBase || !srcBase)
959 if (!type_isa<ResetType>(dstBase) && !type_isa<ResetType>(srcBase))
964 LLVM_DEBUG(llvm::dbgs() <<
"Visiting driver '" << dstField <<
"' = '"
965 << srcField <<
"' (" << dstType <<
" = " << srcType
971 ResetSignal dstLeader =
972 *resetClasses.findLeader(resetClasses.insert({dstField, dstBase}));
973 ResetSignal srcLeader =
974 *resetClasses.findLeader(resetClasses.insert({srcField, srcBase}));
977 ResetSignal unionLeader = *resetClasses.unionSets(dstLeader, srcLeader);
978 assert(unionLeader == dstLeader || unionLeader == srcLeader);
983 if (dstLeader != srcLeader) {
984 auto &unionDrives = resetDrives[unionLeader];
985 auto mergedDrivesIt =
986 resetDrives.find(unionLeader == dstLeader ? srcLeader : dstLeader);
987 if (mergedDrivesIt != resetDrives.end()) {
988 unionDrives.append(mergedDrivesIt->second);
989 resetDrives.erase(mergedDrivesIt);
995 resetDrives[unionLeader].push_back(
996 {{dstField, dstBase}, {srcField, srcBase}, loc});
1003 LogicalResult InferResetsPass::inferAndUpdateResets() {
1005 llvm::dbgs() <<
"\n";
1008 for (
auto it = resetClasses.begin(),
end = resetClasses.end(); it !=
end;
1010 if (!it->isLeader())
1012 ResetNetwork net = llvm::make_range(resetClasses.member_begin(it),
1013 resetClasses.member_end());
1016 auto kind = inferReset(net);
1021 if (failed(updateReset(net, *kind)))
1027 FailureOr<ResetKind> InferResetsPass::inferReset(ResetNetwork net) {
1028 LLVM_DEBUG(llvm::dbgs() <<
"Inferring reset network with "
1029 << std::distance(net.begin(), net.end())
1033 unsigned asyncDrives = 0;
1034 unsigned syncDrives = 0;
1035 unsigned invalidDrives = 0;
1036 for (ResetSignal signal : net) {
1038 if (type_isa<AsyncResetType>(signal.type))
1040 else if (type_isa<UIntType>(signal.type))
1043 isa_and_nonnull<InvalidValueOp>(
1044 signal.field.getValue().getDefiningOp()))
1047 LLVM_DEBUG(llvm::dbgs() <<
"- Found " << asyncDrives <<
" async, "
1048 << syncDrives <<
" sync, " << invalidDrives
1049 <<
" invalid drives\n");
1052 if (asyncDrives == 0 && syncDrives == 0 && invalidDrives == 0) {
1053 ResetSignal root = guessRoot(net);
1054 auto diag = mlir::emitError(root.field.getValue().getLoc())
1055 <<
"reset network never driven with concrete type";
1056 for (ResetSignal signal : net)
1057 diag.attachNote(signal.field.getLoc()) <<
"here: ";
1062 if (asyncDrives > 0 && syncDrives > 0) {
1063 ResetSignal root = guessRoot(net);
1064 bool majorityAsync = asyncDrives >= syncDrives;
1065 auto diag = mlir::emitError(root.field.getValue().getLoc())
1067 SmallString<32> fieldName;
1069 diag <<
" \"" << fieldName <<
"\"";
1070 diag <<
" simultaneously connected to async and sync resets";
1071 diag.attachNote(root.field.getValue().getLoc())
1072 <<
"majority of connections to this reset are "
1073 << (majorityAsync ?
"async" :
"sync");
1074 for (
auto &drive : getResetDrives(net)) {
1075 if ((type_isa<AsyncResetType>(drive.dst.type) && !majorityAsync) ||
1076 (type_isa<AsyncResetType>(drive.src.type) && !majorityAsync) ||
1077 (type_isa<UIntType>(drive.dst.type) && majorityAsync) ||
1078 (type_isa<UIntType>(drive.src.type) && majorityAsync))
1079 diag.attachNote(drive.loc)
1080 << (type_isa<AsyncResetType>(drive.src.type) ?
"async" :
"sync")
1089 auto kind = (asyncDrives ? ResetKind::Async : ResetKind::Sync);
1090 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred as " << kind <<
"\n");
1098 LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) {
1099 LLVM_DEBUG(llvm::dbgs() <<
"Updating reset network with "
1100 << std::distance(net.begin(), net.end())
1101 <<
" nodes to " << kind <<
"\n");
1105 if (kind == ResetKind::Async)
1113 SmallSetVector<Operation *, 16> worklist;
1114 SmallDenseSet<Operation *> moduleWorklist;
1115 SmallDenseSet<std::pair<Operation *, Operation *>> extmoduleWorklist;
1116 for (
auto signal : net) {
1117 Value value = signal.field.getValue();
1118 if (!isa<BlockArgument>(value) &&
1119 !isa_and_nonnull<WireOp, RegOp, RegResetOp, InstanceOp, InvalidValueOp,
1120 ConstCastOp, RefCastOp, UninferredResetCastOp,
1121 RWProbeOp>(value.getDefiningOp()))
1123 if (updateReset(signal.field, resetType)) {
1124 for (
auto user : value.getUsers())
1125 worklist.insert(user);
1126 if (
auto blockArg = dyn_cast<BlockArgument>(value))
1127 moduleWorklist.insert(blockArg.getOwner()->getParentOp());
1128 else if (
auto instOp = value.getDefiningOp<InstanceOp>()) {
1129 if (
auto extmodule =
1130 instOp.getReferencedModule<FExtModuleOp>(*instanceGraph))
1131 extmoduleWorklist.insert({extmodule, instOp});
1132 }
else if (
auto uncast = value.getDefiningOp<UninferredResetCastOp>()) {
1133 uncast.replaceAllUsesWith(uncast.getInput());
1143 while (!worklist.empty()) {
1144 auto *wop = worklist.pop_back_val();
1145 SmallVector<Type, 2> types;
1146 if (
auto op = dyn_cast<InferTypeOpInterface>(wop)) {
1148 SmallVector<Type, 2> types;
1149 if (failed(op.inferReturnTypes(op->getContext(), op->getLoc(),
1150 op->getOperands(), op->getAttrDictionary(),
1151 op->getPropertiesStorage(),
1152 op->getRegions(), types)))
1157 for (
auto it : llvm::zip(op->getResults(), types)) {
1158 auto newType = std::get<1>(it);
1159 if (std::get<0>(it).getType() == newType)
1161 std::get<0>(it).setType(newType);
1162 for (
auto *user : std::get<0>(it).getUsers())
1163 worklist.insert(user);
1165 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred " << *op <<
"\n");
1166 }
else if (
auto uop = dyn_cast<UninferredResetCastOp>(wop)) {
1167 for (
auto *user : uop.getResult().getUsers())
1168 worklist.insert(user);
1169 uop.replaceAllUsesWith(uop.getInput());
1170 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred " << uop <<
"\n");
1176 for (
auto *op : moduleWorklist) {
1177 auto module = dyn_cast<FModuleOp>(op);
1181 SmallVector<Attribute> argTypes;
1182 argTypes.reserve(module.getNumPorts());
1183 for (
auto arg : module.getArguments())
1186 module->setAttr(FModuleLike::getPortTypesAttrName(),
1188 LLVM_DEBUG(llvm::dbgs()
1189 <<
"- Updated type of module '" << module.getName() <<
"'\n");
1193 for (
auto pair : extmoduleWorklist) {
1194 auto module = cast<FExtModuleOp>(pair.first);
1195 auto instOp = cast<InstanceOp>(pair.second);
1197 SmallVector<Attribute> types;
1198 for (
auto type : instOp.getResultTypes())
1201 module->setAttr(FModuleLike::getPortTypesAttrName(),
1203 LLVM_DEBUG(llvm::dbgs()
1204 <<
"- Updated type of extmodule '" << module.getName() <<
"'\n");
1214 if (oldType.isGround()) {
1220 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
1222 SmallVector<BundleType::BundleElement> fields(bundleType.begin(),
1225 fields[index].type, fieldID -
getFieldID(bundleType, index), fieldType);
1230 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
1231 auto newType =
updateType(vectorType.getElementType(),
1232 fieldID -
getFieldID(vectorType), fieldType);
1234 vectorType.isConst());
1237 llvm_unreachable(
"unknown aggregate type");
1244 auto oldType = type_cast<FIRRTLType>(field.
getValue().getType());
1250 if (oldType == newType)
1252 LLVM_DEBUG(llvm::dbgs() <<
"- Updating '" << field <<
"' from " << oldType
1253 <<
" to " << newType <<
"\n");
1262 LogicalResult InferResetsPass::collectAnnos(CircuitOp circuit) {
1264 llvm::dbgs() <<
"\n";
1265 debugHeader(
"Gather async reset annotations") <<
"\n\n";
1267 SmallVector<std::pair<FModuleOp, std::optional<Value>>> results;
1268 for (
auto module : circuit.getOps<FModuleOp>())
1269 results.push_back({module, {}});
1271 if (failed(mlir::failableParallelForEach(
1272 circuit.getContext(), results, [&](
auto &moduleAndResult) {
1273 auto result = collectAnnos(moduleAndResult.first);
1276 moduleAndResult.second = *result;
1281 for (
auto [module, reset] : results)
1282 if (reset.has_value())
1283 annotatedResets.insert({module, *reset});
1287 FailureOr<std::optional<Value>>
1288 InferResetsPass::collectAnnos(FModuleOp module) {
1289 bool anyFailed =
false;
1290 SmallSetVector<std::pair<Annotation, Location>, 4> conflictingAnnos;
1294 bool ignore =
false;
1295 AnnotationSet::removeAnnotations(module, [&](
Annotation anno) {
1298 conflictingAnnos.insert({anno, module.getLoc()});
1303 module.emitError(
"'FullAsyncResetAnnotation' cannot target module; "
1304 "must target port or wire/node instead");
1314 AnnotationSet::removePortAnnotations(module, [&](
unsigned argNum,
1316 Value arg = module.getArgument(argNum);
1318 if (!isa<AsyncResetType>(arg.getType())) {
1319 mlir::emitError(arg.getLoc(),
"'FullAsyncResetAnnotation' must "
1320 "target async reset, but targets ")
1326 conflictingAnnos.insert({anno, reset.getLoc()});
1332 mlir::emitError(arg.getLoc(),
1333 "'IgnoreFullAsyncResetAnnotation' cannot target port; "
1334 "must target module instead");
1343 module.getBody().walk([&](Operation *op) {
1345 if (!isa<WireOp, NodeOp>(op)) {
1350 "reset annotations must target module, port, or wire/node");
1357 AnnotationSet::removeAnnotations(op, [&](
Annotation anno) {
1359 auto resultType = op->getResult(0).getType();
1360 if (!isa<AsyncResetType>(resultType)) {
1361 mlir::emitError(op->getLoc(),
"'FullAsyncResetAnnotation' must "
1362 "target async reset, but targets ")
1367 reset = op->getResult(0);
1368 conflictingAnnos.insert({anno, reset.getLoc()});
1374 "'IgnoreFullAsyncResetAnnotation' cannot target wire/node; must "
1375 "target module instead");
1387 if (!ignore && !reset) {
1388 LLVM_DEBUG(llvm::dbgs()
1389 <<
"No reset annotation for " << module.getName() <<
"\n");
1390 return std::optional<Value>();
1394 if (conflictingAnnos.size() > 1) {
1395 auto diag = module.emitError(
"multiple reset annotations on module '")
1396 << module.getName() <<
"'";
1397 for (
auto &annoAndLoc : conflictingAnnos)
1398 diag.attachNote(annoAndLoc.second)
1399 <<
"conflicting " << annoAndLoc.first.getClassAttr() <<
":";
1405 llvm::dbgs() <<
"Annotated reset for " << module.getName() <<
": ";
1407 llvm::dbgs() <<
"no domain\n";
1408 else if (
auto arg = dyn_cast<BlockArgument>(reset))
1409 llvm::dbgs() <<
"port " << module.getPortName(arg.getArgNumber()) <<
"\n";
1411 llvm::dbgs() <<
"wire "
1412 << reset.getDefiningOp()->getAttrOfType<StringAttr>(
"name")
1418 return std::optional<Value>(reset);
1430 LogicalResult InferResetsPass::buildDomains(CircuitOp circuit) {
1432 llvm::dbgs() <<
"\n";
1433 debugHeader(
"Build async reset domains") <<
"\n\n";
1437 auto &instGraph = getAnalysis<InstanceGraph>();
1438 auto module = dyn_cast<FModuleOp>(*instGraph.getTopLevelNode()->getModule());
1440 LLVM_DEBUG(llvm::dbgs()
1441 <<
"Skipping circuit because main module is no `firrtl.module`");
1444 buildDomains(module,
InstancePath{}, Value{}, instGraph);
1447 bool anyFailed =
false;
1448 for (
auto &it : domains) {
1449 auto module = cast<FModuleOp>(it.first);
1450 auto &domainConflicts = it.second;
1451 if (domainConflicts.size() <= 1)
1455 SmallDenseSet<Value> printedDomainResets;
1456 auto diag = module.emitError(
"module '")
1458 <<
"' instantiated in different reset domains";
1459 for (
auto &it : domainConflicts) {
1460 ResetDomain &domain = it.first;
1461 const auto &path = it.second;
1462 auto inst = path.leaf();
1463 auto loc = path.empty() ? module.getLoc() : inst.getLoc();
1464 auto ¬e = diag.attachNote(loc);
1468 note <<
"root instance";
1470 note <<
"instance '";
1473 [&](InstanceOpInterface inst) { note << inst.getInstanceName(); },
1474 [&]() { note <<
"/"; });
1482 note <<
" reset domain rooted at '" << nameAndModule.first.getValue()
1483 <<
"' of module '" << nameAndModule.second.getName() <<
"'";
1486 if (printedDomainResets.insert(domain.reset).second) {
1487 diag.attachNote(domain.reset.getLoc())
1488 <<
"reset domain '" << nameAndModule.first.getValue()
1489 <<
"' of module '" << nameAndModule.second.getName()
1490 <<
"' declared here:";
1493 note <<
" no reset domain";
1496 return failure(anyFailed);
1499 void InferResetsPass::buildDomains(FModuleOp module,
1504 llvm::dbgs().indent(indent * 2) <<
"Visiting ";
1505 if (instPath.
empty())
1506 llvm::dbgs() <<
"$root";
1508 llvm::dbgs() << instPath.
leaf().getInstanceName();
1509 llvm::dbgs() <<
" (" << module.getName() <<
")\n";
1513 ResetDomain domain(parentReset);
1514 auto it = annotatedResets.find(module);
1515 if (it != annotatedResets.end()) {
1516 domain.isTop =
true;
1517 domain.reset = it->second;
1523 auto &entries = domains[module];
1524 if (llvm::all_of(entries,
1525 [&](
const auto &entry) {
return entry.first != domain; }))
1526 entries.push_back({domain, instPath});
1529 for (
auto *record : *instGraph[module]) {
1530 auto submodule = dyn_cast<FModuleOp>(*record->getTarget()->getModule());
1534 instancePathCache->appendInstance(instPath, record->getInstance());
1535 buildDomains(submodule, childPath, domain.reset, instGraph, indent + 1);
1540 void InferResetsPass::determineImpl() {
1542 llvm::dbgs() <<
"\n";
1543 debugHeader(
"Determine implementation") <<
"\n\n";
1545 for (
auto &it : domains) {
1546 auto module = cast<FModuleOp>(it.first);
1547 auto &domain = it.second.back().first;
1548 determineImpl(module, domain);
1568 void InferResetsPass::determineImpl(FModuleOp module, ResetDomain &domain) {
1571 LLVM_DEBUG(llvm::dbgs() <<
"Planning reset for " << module.getName() <<
"\n");
1576 LLVM_DEBUG(llvm::dbgs() <<
"- Rooting at local value "
1578 domain.existingValue = domain.reset;
1579 if (
auto blockArg = dyn_cast<BlockArgument>(domain.reset))
1580 domain.existingPort = blockArg.getArgNumber();
1587 auto neededType = domain.reset.getType();
1588 LLVM_DEBUG(llvm::dbgs() <<
"- Looking for existing port " << neededName
1590 auto portNames = module.getPortNames();
1591 auto ports = llvm::zip(portNames, module.getArguments());
1592 auto portIt = llvm::find_if(
1593 ports, [&](
auto port) {
return std::get<0>(port) == neededName; });
1594 if (portIt != ports.end() && std::get<1>(*portIt).getType() == neededType) {
1595 LLVM_DEBUG(llvm::dbgs()
1596 <<
"- Reusing existing port " << neededName <<
"\n");
1597 domain.existingValue = std::get<1>(*portIt);
1598 domain.existingPort = std::distance(ports.begin(), portIt);
1608 if (portIt != ports.end()) {
1609 LLVM_DEBUG(llvm::dbgs()
1610 <<
"- Existing " << neededName <<
" has incompatible type "
1611 << std::get<1>(*portIt).getType() <<
"\n");
1613 unsigned suffix = 0;
1617 Twine(
"_") + Twine(suffix++));
1618 }
while (llvm::is_contained(portNames, newName));
1619 LLVM_DEBUG(llvm::dbgs()
1620 <<
"- Creating uniquified port " << newName <<
"\n");
1621 domain.newPortName = newName;
1627 LLVM_DEBUG(llvm::dbgs() <<
"- Creating new port " << neededName <<
"\n");
1628 domain.newPortName = neededName;
1636 LogicalResult InferResetsPass::implementAsyncReset() {
1638 llvm::dbgs() <<
"\n";
1641 for (
auto &it : domains)
1642 if (failed(implementAsyncReset(cast<FModuleOp>(it.first),
1643 it.second.back().first)))
1653 LogicalResult InferResetsPass::implementAsyncReset(FModuleOp module,
1654 ResetDomain &domain) {
1655 LLVM_DEBUG(llvm::dbgs() <<
"Implementing async reset for " << module.getName()
1659 if (!domain.reset) {
1660 LLVM_DEBUG(llvm::dbgs()
1661 <<
"- Skipping because module explicitly has no domain\n");
1666 Value actualReset = domain.existingValue;
1667 if (domain.newPortName) {
1668 PortInfo portInfo{domain.newPortName,
1672 domain.reset.getLoc()};
1673 module.insertPorts({{0, portInfo}});
1674 actualReset = module.getArgument(0);
1675 LLVM_DEBUG(llvm::dbgs()
1676 <<
"- Inserted port " << domain.newPortName <<
"\n");
1680 llvm::dbgs() <<
"- Using ";
1681 if (
auto blockArg = dyn_cast<BlockArgument>(actualReset))
1682 llvm::dbgs() <<
"port #" << blockArg.getArgNumber() <<
" ";
1684 llvm::dbgs() <<
"wire/node ";
1690 SmallVector<Operation *> opsToUpdate;
1691 module.walk([&](Operation *op) {
1692 if (isa<InstanceOp, RegOp, RegResetOp>(op))
1693 opsToUpdate.push_back(op);
1700 if (!isa<BlockArgument>(actualReset)) {
1701 mlir::DominanceInfo dom(module);
1706 auto *resetOp = actualReset.getDefiningOp();
1707 if (!opsToUpdate.empty() && !dom.dominates(resetOp, opsToUpdate[0])) {
1708 LLVM_DEBUG(llvm::dbgs()
1709 <<
"- Reset doesn't dominate all uses, needs to be moved\n");
1713 auto nodeOp = dyn_cast<NodeOp>(resetOp);
1714 if (nodeOp && !dom.dominates(nodeOp.getInput(), opsToUpdate[0])) {
1715 LLVM_DEBUG(llvm::dbgs()
1716 <<
"- Promoting node to wire for move: " << nodeOp <<
"\n");
1717 auto builder = ImplicitLocOpBuilder::atBlockBegin(nodeOp.getLoc(),
1718 nodeOp->getBlock());
1719 auto wireOp = builder.create<WireOp>(
1720 nodeOp.getResult().getType(), nodeOp.getNameAttr(),
1721 nodeOp.getNameKindAttr(), nodeOp.getAnnotationsAttr(),
1722 nodeOp.getInnerSymAttr(), nodeOp.getForceableAttr());
1724 nodeOp->replaceAllUsesWith(wireOp);
1725 nodeOp->removeAttr(nodeOp.getInnerSymAttrName());
1729 nodeOp.setNameKind(NameKindEnum::DroppableName);
1730 nodeOp.setAnnotationsAttr(
ArrayAttr::get(builder.getContext(), {}));
1731 builder.setInsertionPointAfter(nodeOp);
1732 emitConnect(builder, wireOp.getResult(), nodeOp.getResult());
1734 actualReset = wireOp.getResult();
1735 domain.existingValue = wireOp.getResult();
1740 Block *targetBlock = dom.findNearestCommonDominator(
1741 resetOp->getBlock(), opsToUpdate[0]->getBlock());
1743 if (targetBlock != resetOp->getBlock())
1744 llvm::dbgs() <<
"- Needs to be moved to different block\n";
1753 auto getParentInBlock = [](Operation *op,
Block *block) {
1754 while (op && op->getBlock() != block)
1755 op = op->getParentOp();
1758 auto *resetOpInTarget = getParentInBlock(resetOp, targetBlock);
1759 auto *firstOpInTarget = getParentInBlock(opsToUpdate[0], targetBlock);
1765 if (resetOpInTarget->isBeforeInBlock(firstOpInTarget))
1766 resetOp->moveBefore(resetOpInTarget);
1768 resetOp->moveBefore(firstOpInTarget);
1773 for (
auto *op : opsToUpdate)
1774 implementAsyncReset(op, module, actualReset);
/// Rewrite a single operation `op` of `module` so that it is reset by
/// `actualReset`.
///
/// The visible code handles three kinds of operation:
///  - InstanceOp: looks up the reset domain recorded for the instantiated
///    module and obtains the instance-side reset value (`instReset`), either
///    from a newly inserted port (`domain.newPortName`) or from an existing
///    port (`domain.existingPort`).
///  - RegOp: replaced by a RegResetOp whose reset signal is `actualReset` and
///    whose reset value is `zero`.
///  - RegResetOp: if the reset is already an AsyncResetType it is logged as
///    skipped and re-verified; otherwise the old reset behavior is rebuilt as
///    a mux and the register's reset signal/value are retargeted to
///    `actualReset`/`zero`.
///
/// NOTE(review): this listing omits several lines of the body (the gaps in
/// the embedded line numbers). Among the omitted lines are the declarations
/// of `zero` and `instReset`, the statements guarded by some `if`s, and
/// follow-up uses such as the connect after the instance and the use of
/// `mux`. Confirm those details against the full source.
1781 void InferResetsPass::implementAsyncReset(Operation *op, FModuleOp module,
1782 Value actualReset) {
// The builder's insertion point starts at `op`, so new ops are created
// just before it unless the insertion point is moved.
1783 ImplicitLocOpBuilder builder(op->getLoc(), op);
// Instances: obtain the reset value exposed by the instantiated module.
1786 if (
auto instOp = dyn_cast<InstanceOp>(op)) {
1790 auto refModule = instOp.getReferencedModule<FModuleOp>(*instanceGraph);
// Look up the reset domain inferred for the referenced module. Only the last
// recorded entry (`back()`) is used.
// NOTE(review): the statement guarded by the `end()` check is not shown in
// this listing. It is presumably an early exit when the module has no
// recorded domain; confirm.
1793 auto domainIt = domains.find(refModule);
1794 if (domainIt == domains.end())
1796 auto &domain = domainIt->second.back().first;
1799 LLVM_DEBUG(llvm::dbgs()
1800 <<
"- Update instance '" << instOp.getName() <<
"'\n");
// Case 1: the domain asks for a new port. Clone the instance with an extra
// port named `newPortName`, typed like `actualReset`. Result 0 of the clone
// is taken as the reset, and the remaining results replace the uses of the
// old instance's results.
1804 if (domain.newPortName) {
1805 LLVM_DEBUG(llvm::dbgs() <<
" - Adding new result as reset\n");
1807 auto newInstOp = instOp.cloneAndInsertPorts(
1809 {domain.newPortName,
1810 type_cast<FIRRTLBaseType>(actualReset.getType()),
1812 instReset = newInstOp.getResult(0);
1815 instOp.replaceAllUsesWith(newInstOp.getResults().drop_front());
// Keep the instance graph in sync with the replacement instance.
1816 instanceGraph->replaceInstance(instOp, newInstOp);
1819 }
else if (domain.existingPort.has_value()) {
// Case 2: reuse the module's existing reset port at index `idx`.
1820 auto idx = *domain.existingPort;
1821 instReset = instOp.getResult(idx);
1822 LLVM_DEBUG(llvm::dbgs() <<
" - Using result #" << idx <<
" as reset\n");
// Both the instance-side reset and the module's reset must be known here.
// New ops are then inserted after the instance. The op(s) created there are
// on lines not shown in this listing.
1832 assert(instReset && actualReset);
1833 builder.setInsertionPointAfter(instOp);
// Reset-less registers: build a RegResetOp in front of the old register,
// copying its clock, name, name kind, annotations, inner symbol and
// forceable attribute.
// NOTE(review): `zero` is declared on a line not shown here. It is presumably
// a zero value of the register's type; confirm.
1839 if (
auto regOp = dyn_cast<RegOp>(op)) {
1843 LLVM_DEBUG(llvm::dbgs() <<
"- Adding async reset to " << regOp <<
"\n");
1845 auto newRegOp = builder.create<RegResetOp>(
1846 regOp.getResult().getType(), regOp.getClockVal(), actualReset, zero,
1847 regOp.getNameAttr(), regOp.getNameKindAttr(), regOp.getAnnotations(),
1848 regOp.getInnerSymAttr(), regOp.getForceableAttr());
// Redirect all users of the old register's value to the new register.
1849 regOp.getResult().replaceAllUsesWith(newRegOp.getResult());
// A forceable register also has a ref result; redirect its users too.
1850 if (regOp.getForceable())
1851 regOp.getRef().replaceAllUsesWith(newRegOp.getRef());
// Registers that already have a reset.
1857 if (
auto regOp = dyn_cast<RegResetOp>(op)) {
// An existing async reset is logged as skipped and the op is re-verified.
// A verification failure fails the pass.
1859 if (type_isa<AsyncResetType>(regOp.getResetSignal().getType())) {
1860 LLVM_DEBUG(llvm::dbgs()
1861 <<
"- Skipping (has async reset) " << regOp <<
"\n");
1864 if (failed(regOp.verifyInvariants()))
1865 signalPassFailure();
1868 LLVM_DEBUG(llvm::dbgs() <<
"- Updating reset of " << regOp <<
"\n");
// Keep the register's original (non-async) reset behavior by building
// mux(reset, value, reg) right after the register's result. How `mux` is
// used afterwards is on lines not shown in this listing.
1870 auto reset = regOp.getResetSignal();
1871 auto value = regOp.getResetValue();
1877 builder.setInsertionPointAfterValue(regOp.getResult());
1878 auto mux = builder.create<MuxPrimOp>(reset, value, regOp.getResult());
// Move the insertion point back before the register, then retarget the
// register's own reset operands to the async reset and the zero value.
1882 builder.setInsertionPoint(regOp);
1884 regOp.getResetSignalMutable().assign(actualReset);
1885 regOp.getResetValueMutable().assign(zero);
1889 LogicalResult InferResetsPass::verifyNoAbstractReset() {
1890 bool hasAbstractResetPorts =
false;
1891 for (FModuleLike module :
1892 getOperation().
getBodyBlock()->getOps<FModuleLike>()) {
1893 for (
PortInfo port : module.getPorts()) {
1894 if (getBaseOfType<ResetType>(port.type)) {
1895 auto diag = emitError(port.loc)
1896 <<
"a port \"" << port.getName()
1897 <<
"\" with abstract reset type was unable to be "
1898 "inferred by InferResets (is this a top-level port?)";
1899 diag.attachNote(module->getLoc())
1900 <<
"the module with this uninferred reset port was defined here";
1901 hasAbstractResetPorts =
true;
1906 if (hasAbstractResetPorts)
assert(baseType &&"element must be base type")
static Value createZeroValue(ImplicitLocOpBuilder &builder, FIRRTLBaseType type, SmallDenseMap< FIRRTLBaseType, Value > &cache)
Construct a zero value of the given type using the given builder.
static unsigned getFieldID(BundleType type, unsigned index)
bool operator!=(const ResetDomain &a, const ResetDomain &b)
static std::pair< StringAttr, FModuleOp > getResetNameAndModule(Value reset)
Return the name and parent module of a reset.
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static bool isUselessVec(FIRRTLBaseType oldType, unsigned fieldID)
bool operator==(const ResetDomain &a, const ResetDomain &b)
static StringAttr getResetName(Value reset)
Return the name of a reset.
static bool insertResetMux(ImplicitLocOpBuilder &builder, Value target, Value reset, Value resetValue)
Helper function that inserts reset multiplexer into all ConnectOps with the given target.
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static bool typeContainsReset(Type type)
Check whether a type contains a ResetType.
static bool getDeclName(Value value, SmallString< 32 > &string)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
static Block * getBodyBlock(FModuleLike mod)
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
This class provides a read-only projection of an annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
An instance path composed of a series of instances.
InstanceOpInterface leaf() const
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
constexpr const char * excludeMemToRegAnnoClass
igraph::InstancePathCache InstancePathCache
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
constexpr const char * fullAsyncResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
T & operator<<(T &os, FIRVersion version)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * ignoreFullAsyncResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::unique_ptr< mlir::Pass > createInferResetsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
bool operator<(const AppID &a, const AppID &b)
This holds the name and type that describes the module's ports.
static ResetSignal getEmptyKey()
static ResetSignal getTombstoneKey()
static bool isEqual(const ResetSignal &lhs, const ResetSignal &rhs)
static unsigned getHashValue(const ResetSignal &x)