23 #include "mlir/IR/Dominance.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/Threading.h"
26 #include "mlir/Pass/Pass.h"
27 #include "llvm/ADT/EquivalenceClasses.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/TinyPtrVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "infer-resets"
37 #define GEN_PASS_DEF_INFERRESETS
38 #include "circt/Dialect/FIRRTL/Passes.h.inc"
42 using circt::igraph::InstanceOpInterface;
45 using llvm::BumpPtrAllocator;
46 using llvm::MapVector;
47 using llvm::SmallDenseSet;
48 using llvm::SmallSetVector;
49 using mlir::FailureOr;
50 using mlir::InferTypeOpInterface;
51 using mlir::WalkOrder;
53 using namespace circt;
54 using namespace firrtl;
71 std::optional<unsigned> existingPort;
72 StringAttr newPortName;
74 ResetDomain(Value reset) : reset(reset) {}
78 inline bool operator==(
const ResetDomain &a,
const ResetDomain &b) {
79 return (a.isTop == b.isTop && a.reset == b.reset);
81 inline bool operator!=(
const ResetDomain &a,
const ResetDomain &b) {
88 if (
auto arg = dyn_cast<BlockArgument>(reset)) {
89 auto module = cast<FModuleOp>(arg.getParentRegion()->getParentOp());
90 return {module.getPortNameAttr(arg.getArgNumber()), module};
92 auto op = reset.getDefiningOp();
93 return {op->getAttrOfType<StringAttr>(
"name"),
94 op->getParentOfType<FModuleOp>()};
109 auto it = cache.find(type);
110 if (it != cache.end())
112 auto nullBit = [&]() {
119 .
Case<ClockType>([&](
auto type) {
120 return builder.create<AsClockPrimOp>(nullBit());
122 .Case<AsyncResetType>([&](
auto type) {
123 return builder.create<AsAsyncResetPrimOp>(nullBit());
125 .Case<SIntType, UIntType>([&](
auto type) {
126 return builder.create<ConstantOp>(
127 type, APInt::getZero(type.getWidth().value_or(1)));
129 .Case<BundleType>([&](
auto type) {
130 auto wireOp = builder.create<WireOp>(type);
131 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
132 auto fieldType = type.getElementTypePreservingConst(i);
135 builder.create<SubfieldOp>(fieldType, wireOp.getResult(), i);
138 return wireOp.getResult();
140 .Case<FVectorType>([&](
auto type) {
141 auto wireOp = builder.create<WireOp>(type);
143 builder, type.getElementTypePreservingConst(), cache);
144 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
145 auto acc = builder.create<SubindexOp>(zero.getType(),
146 wireOp.getResult(), i);
149 return wireOp.getResult();
151 .Case<ResetType, AnalogType>(
152 [&](
auto type) {
return builder.create<InvalidValueOp>(type); })
154 llvm_unreachable(
"switch handles all types");
157 cache.insert({type, value});
174 Value reset, Value resetValue) {
178 bool resetValueUsed =
false;
180 for (
auto &use : target.getUses()) {
181 Operation *useOp = use.getOwner();
182 builder.setInsertionPoint(useOp);
183 TypeSwitch<Operation *>(useOp)
186 .Case<ConnectOp, MatchingConnectOp>([&](
auto op) {
187 if (op.getDest() != target)
189 LLVM_DEBUG(llvm::dbgs() <<
" - Insert mux into " << op <<
"\n");
191 builder.create<MuxPrimOp>(reset, resetValue, op.getSrc());
192 op.getSrcMutable().assign(muxOp);
193 resetValueUsed =
true;
196 .Case<SubfieldOp>([&](
auto op) {
198 builder.create<SubfieldOp>(resetValue, op.getFieldIndexAttr());
200 resetValueUsed =
true;
202 resetSubValue.erase();
205 .Case<SubindexOp>([&](
auto op) {
207 builder.create<SubindexOp>(resetValue, op.getIndexAttr());
209 resetValueUsed =
true;
211 resetSubValue.erase();
214 .Case<SubaccessOp>([&](
auto op) {
215 if (op.getInput() != target)
218 builder.create<SubaccessOp>(resetValue, op.getIndex());
220 resetValueUsed =
true;
222 resetSubValue.erase();
225 return resetValueUsed;
240 bool operator<(
const ResetSignal &other)
const {
return field < other.field; }
241 bool operator==(
const ResetSignal &other)
const {
242 return field == other.field;
244 bool operator!=(
const ResetSignal &other)
const {
return !(*
this == other); }
264 using ResetDrives = SmallVector<ResetDrive, 1>;
267 using ResetNetwork = llvm::iterator_range<
268 llvm::EquivalenceClasses<ResetSignal>::member_iterator>;
271 enum class ResetKind { Async, Sync };
273 static StringRef resetKindToStringRef(
const ResetKind &kind) {
275 case ResetKind::Async:
277 case ResetKind::Sync:
280 llvm_unreachable(
"unhandled reset kind");
296 static bool isEqual(
const ResetSignal &lhs,
const ResetSignal &rhs) {
302 template <
typename T>
305 case ResetKind::Async:
306 return os <<
"async";
307 case ResetKind::Sync:
417 struct InferResetsPass
418 :
public circt::firrtl::impl::InferResetsBase<InferResetsPass> {
419 void runOnOperation()
override;
420 void runOnOperationInner();
423 using InferResetsBase::InferResetsBase;
424 InferResetsPass(
const InferResetsPass &other) : InferResetsBase(other) {}
429 void traceResets(CircuitOp circuit);
430 void traceResets(InstanceOp inst);
431 void traceResets(Value dst, Value src, Location loc);
432 void traceResets(Value value);
433 void traceResets(Type dstType, Value dst,
unsigned dstID, Type srcType,
434 Value src,
unsigned srcID, Location loc);
436 LogicalResult inferAndUpdateResets();
437 FailureOr<ResetKind> inferReset(ResetNetwork net);
438 LogicalResult updateReset(ResetNetwork net, ResetKind kind);
444 LogicalResult collectAnnos(CircuitOp circuit);
450 FailureOr<std::optional<Value>> collectAnnos(FModuleOp module);
452 LogicalResult buildDomains(CircuitOp circuit);
453 void buildDomains(FModuleOp module,
const InstancePath &instPath,
455 unsigned indent = 0);
457 void determineImpl();
458 void determineImpl(FModuleOp module, ResetDomain &domain);
460 LogicalResult implementFullReset();
461 LogicalResult implementFullReset(FModuleOp module, ResetDomain &domain);
462 void implementFullReset(Operation *op, FModuleOp module, Value actualReset);
464 LogicalResult verifyNoAbstractReset();
470 ResetNetwork getResetNetwork(ResetSignal signal) {
471 return llvm::make_range(resetClasses.findLeader(signal),
472 resetClasses.member_end());
476 ResetDrives &getResetDrives(ResetNetwork net) {
477 return resetDrives[*net.begin()];
482 ResetSignal guessRoot(ResetNetwork net);
483 ResetSignal guessRoot(ResetSignal signal) {
484 return guessRoot(getResetNetwork(signal));
491 llvm::EquivalenceClasses<ResetSignal> resetClasses;
494 DenseMap<ResetSignal, ResetDrives> resetDrives;
499 DenseMap<Operation *, Value> annotatedResets;
503 MapVector<FModuleOp, SmallVector<std::pair<ResetDomain, InstancePath>, 1>>
510 std::unique_ptr<InstancePathCache> instancePathCache;
514 void InferResetsPass::runOnOperation() {
515 runOnOperationInner();
516 resetClasses = llvm::EquivalenceClasses<ResetSignal>();
518 annotatedResets.clear();
520 instancePathCache.reset(
nullptr);
521 markAnalysesPreserved<InstanceGraph>();
524 void InferResetsPass::runOnOperationInner() {
525 instanceGraph = &getAnalysis<InstanceGraph>();
526 instancePathCache = std::make_unique<InstancePathCache>(*instanceGraph);
529 traceResets(getOperation());
532 if (failed(inferAndUpdateResets()))
533 return signalPassFailure();
536 if (failed(collectAnnos(getOperation())))
537 return signalPassFailure();
540 if (failed(buildDomains(getOperation())))
541 return signalPassFailure();
547 if (failed(implementFullReset()))
548 return signalPassFailure();
551 if (failed(verifyNoAbstractReset()))
552 return signalPassFailure();
556 return std::make_unique<InferResetsPass>();
559 ResetSignal InferResetsPass::guessRoot(ResetNetwork net) {
560 ResetDrives &drives = getResetDrives(net);
561 ResetSignal bestSignal = *net.begin();
562 unsigned bestNumDrives = -1;
564 for (
auto signal : net) {
566 if (isa_and_nonnull<InvalidValueOp>(
567 signal.field.getValue().getDefiningOp()))
572 unsigned numDrives = 0;
573 for (
auto &drive : drives)
574 if (drive.dst == signal)
580 if (numDrives < bestNumDrives) {
581 bestNumDrives = numDrives;
600 .
Case<BundleType>([](
auto type) {
602 for (
auto e : type.getElements())
607 [](
auto type) {
return getMaxFieldID(type.getElementType()) + 1; })
608 .Default([](
auto) {
return 0; });
611 static unsigned getFieldID(BundleType type,
unsigned index) {
612 assert(index < type.getNumElements());
614 for (
unsigned i = 0; i < index; ++i)
622 assert(type.getNumElements() &&
"Bundle must have >0 fields");
624 for (
const auto &e : llvm::enumerate(type.getElements())) {
626 if (fieldID < numSubfields)
628 fieldID -= numSubfields;
630 assert(
false &&
"field id outside bundle");
636 if (oldType.isGround()) {
642 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
650 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
651 if (vectorType.getNumElements() == 0)
668 if (
auto arg = dyn_cast<BlockArgument>(value)) {
669 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
670 string += module.getPortName(arg.getArgNumber());
674 auto *op = value.getDefiningOp();
675 return TypeSwitch<Operation *, bool>(op)
676 .Case<InstanceOp, MemOp>([&](
auto op) {
677 string += op.getName();
680 op.getPortName(cast<OpResult>(value).getResultNumber()).getValue();
683 .Case<WireOp, NodeOp, RegOp, RegResetOp>([&](
auto op) {
684 string += op.getName();
687 .Default([](
auto) {
return false; });
691 SmallString<64> name;
696 auto type = value.getType();
699 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
702 auto &element = bundleType.getElements()[index];
705 string += element.name.getValue();
708 localID = localID -
getFieldID(bundleType, index);
709 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
712 type = vecType.getElementType();
719 llvm_unreachable(
"unsupported type");
731 return TypeSwitch<Type, bool>(type)
733 return type.getRecursiveTypeProperties().hasUninferredReset;
735 .Default([](
auto) {
return false; });
742 void InferResetsPass::traceResets(CircuitOp circuit) {
744 llvm::dbgs() <<
"\n";
745 debugHeader(
"Tracing uninferred resets") <<
"\n\n";
748 SmallVector<std::pair<FModuleOp, SmallVector<Operation *>>> moduleToOps;
750 for (
auto module : circuit.getOps<FModuleOp>())
751 moduleToOps.push_back({module, {}});
753 hw::InnerRefNamespace irn{getAnalysis<SymbolTable>(),
754 getAnalysis<hw::InnerSymbolTableCollection>()};
756 mlir::parallelForEach(circuit.getContext(), moduleToOps, [](
auto &e) {
757 e.first.walk([&](Operation *op) {
761 op->getResultTypes(),
762 [](mlir::Type type) { return typeContainsReset(type); }) ||
763 llvm::any_of(op->getOperandTypes(), typeContainsReset))
764 e.second.push_back(op);
768 for (
auto &[_, ops] : moduleToOps)
769 for (
auto *op : ops) {
770 TypeSwitch<Operation *>(op)
771 .Case<FConnectLike>([&](
auto op) {
772 traceResets(op.getDest(), op.getSrc(), op.getLoc());
774 .Case<InstanceOp>([&](
auto op) { traceResets(op); })
775 .Case<RefSendOp>([&](
auto op) {
777 traceResets(op.getType().getType(), op.getResult(), 0,
778 op.getBase().getType().getPassiveType(), op.getBase(),
781 .Case<RefResolveOp>([&](
auto op) {
783 traceResets(op.getType(), op.getResult(), 0,
784 op.getRef().getType().getType(), op.getRef(), 0,
787 .Case<Forceable>([&](Forceable op) {
788 if (
auto node = dyn_cast<NodeOp>(op.getOperation()))
789 traceResets(node.getResult(), node.getInput(), node.getLoc());
791 if (op.isForceable())
792 traceResets(op.getDataType(), op.getData(), 0, op.getDataType(),
793 op.getDataRef(), 0, op.getLoc());
795 .Case<RWProbeOp>([&](RWProbeOp op) {
796 auto ist = irn.lookup(op.getTarget());
799 auto baseType = op.getType().getType();
800 traceResets(baseType, op.getResult(), 0, baseType.getPassiveType(),
801 ref.getValue(), ref.getFieldID(), op.getLoc());
803 .Case<UninferredResetCastOp, ConstCastOp, RefCastOp>([&](
auto op) {
804 traceResets(op.getResult(), op.getInput(), op.getLoc());
806 .Case<InvalidValueOp>([&](
auto op) {
815 auto type = op.getType();
818 LLVM_DEBUG(llvm::dbgs() <<
"Uniquify " << op <<
"\n");
819 ImplicitLocOpBuilder builder(op->getLoc(), op);
821 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
827 auto newOp = builder.create<InvalidValueOp>(type);
832 .Case<SubfieldOp>([&](
auto op) {
835 BundleType bundleType = op.getInput().getType();
836 auto index = op.getFieldIndex();
837 traceResets(op.getType(), op.getResult(), 0,
838 bundleType.getElements()[index].type, op.getInput(),
842 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
855 FVectorType vectorType = op.getInput().getType();
856 traceResets(op.getType(), op.getResult(), 0,
857 vectorType.getElementType(), op.getInput(),
861 .Case<RefSubOp>([&](RefSubOp op) {
863 auto aggType = op.getInput().getType().getType();
864 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(aggType)
865 .Case<FVectorType>([](
auto type) {
868 .Case<BundleType>([&](
auto type) {
871 traceResets(op.getType(), op.getResult(), 0,
872 op.getResult().getType(), op.getInput(), fieldID,
880 void InferResetsPass::traceResets(InstanceOp inst) {
882 auto module = inst.getReferencedModule<FModuleOp>(*instanceGraph);
885 LLVM_DEBUG(llvm::dbgs() <<
"Visiting instance " << inst.getName() <<
"\n");
888 for (
const auto &it : llvm::enumerate(inst.getResults())) {
889 auto dir = module.getPortDirection(it.index());
890 Value dstPort = module.getArgument(it.index());
891 Value srcPort = it.value();
892 if (dir == Direction::Out)
893 std::swap(dstPort, srcPort);
894 traceResets(dstPort, srcPort, it.value().getLoc());
900 void InferResetsPass::traceResets(Value dst, Value src, Location loc) {
902 traceResets(dst.getType(), dst, 0, src.getType(), src, 0, loc);
907 void InferResetsPass::traceResets(Type dstType, Value dst,
unsigned dstID,
908 Type srcType, Value src,
unsigned srcID,
910 if (
auto dstBundle = type_dyn_cast<BundleType>(dstType)) {
911 auto srcBundle = type_cast<BundleType>(srcType);
912 for (
unsigned dstIdx = 0, e = dstBundle.getNumElements(); dstIdx < e;
914 auto dstField = dstBundle.getElements()[dstIdx].name;
915 auto srcIdx = srcBundle.getElementIndex(dstField);
918 auto &dstElt = dstBundle.getElements()[dstIdx];
919 auto &srcElt = srcBundle.getElements()[*srcIdx];
921 traceResets(srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
922 dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
925 traceResets(dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
926 srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
933 if (
auto dstVector = type_dyn_cast<FVectorType>(dstType)) {
934 auto srcVector = type_cast<FVectorType>(srcType);
935 auto srcElType = srcVector.getElementType();
936 auto dstElType = dstVector.getElementType();
949 traceResets(dstElType, dst, dstID +
getFieldID(dstVector), srcElType, src,
955 if (
auto dstRef = type_dyn_cast<RefType>(dstType)) {
956 auto srcRef = type_cast<RefType>(srcType);
957 return traceResets(dstRef.getType(), dst, dstID, srcRef.getType(), src,
962 auto dstBase = type_dyn_cast<FIRRTLBaseType>(dstType);
963 auto srcBase = type_dyn_cast<FIRRTLBaseType>(srcType);
964 if (!dstBase || !srcBase)
966 if (!type_isa<ResetType>(dstBase) && !type_isa<ResetType>(srcBase))
971 LLVM_DEBUG(llvm::dbgs() <<
"Visiting driver '" << dstField <<
"' = '"
972 << srcField <<
"' (" << dstType <<
" = " << srcType
978 ResetSignal dstLeader =
979 *resetClasses.findLeader(resetClasses.insert({dstField, dstBase}));
980 ResetSignal srcLeader =
981 *resetClasses.findLeader(resetClasses.insert({srcField, srcBase}));
984 ResetSignal unionLeader = *resetClasses.unionSets(dstLeader, srcLeader);
985 assert(unionLeader == dstLeader || unionLeader == srcLeader);
990 if (dstLeader != srcLeader) {
991 auto &unionDrives = resetDrives[unionLeader];
992 auto mergedDrivesIt =
993 resetDrives.find(unionLeader == dstLeader ? srcLeader : dstLeader);
994 if (mergedDrivesIt != resetDrives.end()) {
995 unionDrives.append(mergedDrivesIt->second);
996 resetDrives.erase(mergedDrivesIt);
1002 resetDrives[unionLeader].push_back(
1003 {{dstField, dstBase}, {srcField, srcBase}, loc});
1010 LogicalResult InferResetsPass::inferAndUpdateResets() {
1012 llvm::dbgs() <<
"\n";
1015 for (
auto it = resetClasses.begin(),
end = resetClasses.end(); it !=
end;
1017 if (!it->isLeader())
1019 ResetNetwork net = llvm::make_range(resetClasses.member_begin(it),
1020 resetClasses.member_end());
1023 auto kind = inferReset(net);
1028 if (failed(updateReset(net, *kind)))
1034 FailureOr<ResetKind> InferResetsPass::inferReset(ResetNetwork net) {
1035 LLVM_DEBUG(llvm::dbgs() <<
"Inferring reset network with "
1036 << std::distance(net.begin(), net.end())
1040 unsigned asyncDrives = 0;
1041 unsigned syncDrives = 0;
1042 unsigned invalidDrives = 0;
1043 for (ResetSignal signal : net) {
1045 if (type_isa<AsyncResetType>(signal.type))
1047 else if (type_isa<UIntType>(signal.type))
1050 isa_and_nonnull<InvalidValueOp>(
1051 signal.field.getValue().getDefiningOp()))
1054 LLVM_DEBUG(llvm::dbgs() <<
"- Found " << asyncDrives <<
" async, "
1055 << syncDrives <<
" sync, " << invalidDrives
1056 <<
" invalid drives\n");
1059 if (asyncDrives == 0 && syncDrives == 0 && invalidDrives == 0) {
1060 ResetSignal root = guessRoot(net);
1061 auto diag = mlir::emitError(root.field.getValue().getLoc())
1062 <<
"reset network never driven with concrete type";
1063 for (ResetSignal signal : net)
1064 diag.attachNote(signal.field.getLoc()) <<
"here: ";
1069 if (asyncDrives > 0 && syncDrives > 0) {
1070 ResetSignal root = guessRoot(net);
1071 bool majorityAsync = asyncDrives >= syncDrives;
1072 auto diag = mlir::emitError(root.field.getValue().getLoc())
1074 SmallString<32> fieldName;
1076 diag <<
" \"" << fieldName <<
"\"";
1077 diag <<
" simultaneously connected to async and sync resets";
1078 diag.attachNote(root.field.getValue().getLoc())
1079 <<
"majority of connections to this reset are "
1080 << (majorityAsync ?
"async" :
"sync");
1081 for (
auto &drive : getResetDrives(net)) {
1082 if ((type_isa<AsyncResetType>(drive.dst.type) && !majorityAsync) ||
1083 (type_isa<AsyncResetType>(drive.src.type) && !majorityAsync) ||
1084 (type_isa<UIntType>(drive.dst.type) && majorityAsync) ||
1085 (type_isa<UIntType>(drive.src.type) && majorityAsync))
1086 diag.attachNote(drive.loc)
1087 << (type_isa<AsyncResetType>(drive.src.type) ?
"async" :
"sync")
1096 auto kind = (asyncDrives ? ResetKind::Async : ResetKind::Sync);
1097 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred as " << kind <<
"\n");
1105 LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) {
1106 LLVM_DEBUG(llvm::dbgs() <<
"Updating reset network with "
1107 << std::distance(net.begin(), net.end())
1108 <<
" nodes to " << kind <<
"\n");
1112 if (kind == ResetKind::Async)
1120 SmallSetVector<Operation *, 16> worklist;
1121 SmallDenseSet<Operation *> moduleWorklist;
1122 SmallDenseSet<std::pair<Operation *, Operation *>> extmoduleWorklist;
1123 for (
auto signal : net) {
1124 Value value = signal.field.getValue();
1125 if (!isa<BlockArgument>(value) &&
1126 !isa_and_nonnull<WireOp, RegOp, RegResetOp, InstanceOp, InvalidValueOp,
1127 ConstCastOp, RefCastOp, UninferredResetCastOp,
1128 RWProbeOp>(value.getDefiningOp()))
1130 if (updateReset(signal.field, resetType)) {
1131 for (
auto user : value.getUsers())
1132 worklist.insert(user);
1133 if (
auto blockArg = dyn_cast<BlockArgument>(value))
1134 moduleWorklist.insert(blockArg.getOwner()->getParentOp());
1135 else if (
auto instOp = value.getDefiningOp<InstanceOp>()) {
1136 if (
auto extmodule =
1137 instOp.getReferencedModule<FExtModuleOp>(*instanceGraph))
1138 extmoduleWorklist.insert({extmodule, instOp});
1139 }
else if (
auto uncast = value.getDefiningOp<UninferredResetCastOp>()) {
1140 uncast.replaceAllUsesWith(uncast.getInput());
1150 while (!worklist.empty()) {
1151 auto *wop = worklist.pop_back_val();
1152 SmallVector<Type, 2> types;
1153 if (
auto op = dyn_cast<InferTypeOpInterface>(wop)) {
1155 SmallVector<Type, 2> types;
1156 if (failed(op.inferReturnTypes(op->getContext(), op->getLoc(),
1157 op->getOperands(), op->getAttrDictionary(),
1158 op->getPropertiesStorage(),
1159 op->getRegions(), types)))
1164 for (
auto it : llvm::zip(op->getResults(), types)) {
1165 auto newType = std::get<1>(it);
1166 if (std::get<0>(it).getType() == newType)
1168 std::get<0>(it).setType(newType);
1169 for (
auto *user : std::get<0>(it).getUsers())
1170 worklist.insert(user);
1172 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred " << *op <<
"\n");
1173 }
else if (
auto uop = dyn_cast<UninferredResetCastOp>(wop)) {
1174 for (
auto *user : uop.getResult().getUsers())
1175 worklist.insert(user);
1176 uop.replaceAllUsesWith(uop.getInput());
1177 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred " << uop <<
"\n");
1183 for (
auto *op : moduleWorklist) {
1184 auto module = dyn_cast<FModuleOp>(op);
1188 SmallVector<Attribute> argTypes;
1189 argTypes.reserve(module.getNumPorts());
1190 for (
auto arg : module.getArguments())
1193 module.setPortTypesAttr(
ArrayAttr::get(op->getContext(), argTypes));
1194 LLVM_DEBUG(llvm::dbgs()
1195 <<
"- Updated type of module '" << module.getName() <<
"'\n");
1199 for (
auto pair : extmoduleWorklist) {
1200 auto module = cast<FExtModuleOp>(pair.first);
1201 auto instOp = cast<InstanceOp>(pair.second);
1203 SmallVector<Attribute> types;
1204 for (
auto type : instOp.getResultTypes())
1207 module.setPortTypesAttr(
ArrayAttr::get(module->getContext(), types));
1208 LLVM_DEBUG(llvm::dbgs()
1209 <<
"- Updated type of extmodule '" << module.getName() <<
"'\n");
1219 if (oldType.isGround()) {
1225 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
1227 SmallVector<BundleType::BundleElement> fields(bundleType.begin(),
1230 fields[index].type, fieldID -
getFieldID(bundleType, index), fieldType);
1235 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
1236 auto newType =
updateType(vectorType.getElementType(),
1237 fieldID -
getFieldID(vectorType), fieldType);
1239 vectorType.isConst());
1242 llvm_unreachable(
"unknown aggregate type");
1249 auto oldType = type_cast<FIRRTLType>(field.
getValue().getType());
1255 if (oldType == newType)
1257 LLVM_DEBUG(llvm::dbgs() <<
"- Updating '" << field <<
"' from " << oldType
1258 <<
" to " << newType <<
"\n");
1267 LogicalResult InferResetsPass::collectAnnos(CircuitOp circuit) {
1269 llvm::dbgs() <<
"\n";
1270 debugHeader(
"Gather reset annotations") <<
"\n\n";
1272 SmallVector<std::pair<FModuleOp, std::optional<Value>>> results;
1273 for (
auto module : circuit.getOps<FModuleOp>())
1274 results.push_back({module, {}});
1276 if (failed(mlir::failableParallelForEach(
1277 circuit.getContext(), results, [&](
auto &moduleAndResult) {
1278 auto result = collectAnnos(moduleAndResult.first);
1281 moduleAndResult.second = *result;
1286 for (
auto [module, reset] : results)
1287 if (reset.has_value())
1288 annotatedResets.insert({module, *reset});
1292 FailureOr<std::optional<Value>>
1293 InferResetsPass::collectAnnos(FModuleOp module) {
1294 bool anyFailed =
false;
1295 SmallSetVector<std::pair<Annotation, Location>, 4> conflictingAnnos;
1299 bool ignore =
false;
1300 AnnotationSet::removeAnnotations(module, [&](
Annotation anno) {
1303 conflictingAnnos.insert({anno, module.getLoc()});
1308 module.emitError(
"''FullResetAnnotation' cannot target module; must "
1309 "target port or wire/node instead");
1320 auto checkAnnotations = [&](
Annotation anno, Value arg) {
1322 ResetKind expectedResetKind;
1323 if (
auto rt = anno.
getMember<StringAttr>(
"resetType")) {
1325 expectedResetKind = ResetKind::Sync;
1326 }
else if (rt ==
"async") {
1327 expectedResetKind = ResetKind::Async;
1329 mlir::emitError(arg.getLoc(),
1330 "'FullResetAnnotation' requires resetType == 'sync' "
1331 "| 'async', but got resetType == ")
1337 mlir::emitError(arg.getLoc(),
1338 "'FullResetAnnotation' requires resetType == "
1339 "'sync' | 'async', but got no resetType");
1344 bool isAsync = expectedResetKind == ResetKind::Async;
1345 bool validUint =
false;
1346 if (
auto uintT = dyn_cast<UIntType>(arg.getType()))
1347 validUint = uintT.getWidth() == 1;
1348 if ((isAsync && !isa<AsyncResetType>(arg.getType())) ||
1349 (!isAsync && !validUint)) {
1350 auto kind = resetKindToStringRef(expectedResetKind);
1351 mlir::emitError(arg.getLoc(),
1352 "'FullResetAnnotation' with resetType == '")
1353 << kind <<
"' must target " << kind <<
" reset, but targets "
1360 conflictingAnnos.insert({anno, reset.getLoc()});
1366 mlir::emitError(arg.getLoc(),
1367 "'ExcludeFromFullResetAnnotation' cannot "
1368 "target port/wire/node; must target module instead");
1374 AnnotationSet::removePortAnnotations(module,
1376 Value arg = module.getArgument(argNum);
1377 return checkAnnotations(anno, arg);
1383 module.getBody().walk([&](Operation *op) {
1385 if (!isa<WireOp, NodeOp>(op)) {
1390 "reset annotations must target module, port, or wire/node");
1397 AnnotationSet::removeAnnotations(op, [&](
Annotation anno) {
1398 auto arg = op->getResult(0);
1399 return checkAnnotations(anno, arg);
1408 if (!ignore && !reset) {
1409 LLVM_DEBUG(llvm::dbgs()
1410 <<
"No reset annotation for " << module.getName() <<
"\n");
1411 return std::optional<Value>();
1415 if (conflictingAnnos.size() > 1) {
1416 auto diag = module.emitError(
"multiple reset annotations on module '")
1417 << module.getName() <<
"'";
1418 for (
auto &annoAndLoc : conflictingAnnos)
1419 diag.attachNote(annoAndLoc.second)
1420 <<
"conflicting " << annoAndLoc.first.getClassAttr() <<
":";
1426 llvm::dbgs() <<
"Annotated reset for " << module.getName() <<
": ";
1428 llvm::dbgs() <<
"no domain\n";
1429 else if (
auto arg = dyn_cast<BlockArgument>(reset))
1430 llvm::dbgs() <<
"port " << module.getPortName(arg.getArgNumber()) <<
"\n";
1432 llvm::dbgs() <<
"wire "
1433 << reset.getDefiningOp()->getAttrOfType<StringAttr>(
"name")
1439 return std::optional<Value>(reset);
1451 LogicalResult InferResetsPass::buildDomains(CircuitOp circuit) {
1453 llvm::dbgs() <<
"\n";
1454 debugHeader(
"Build full reset domains") <<
"\n\n";
1458 auto &instGraph = getAnalysis<InstanceGraph>();
1459 auto module = dyn_cast<FModuleOp>(*instGraph.getTopLevelNode()->getModule());
1461 LLVM_DEBUG(llvm::dbgs()
1462 <<
"Skipping circuit because main module is no `firrtl.module`");
1465 buildDomains(module,
InstancePath{}, Value{}, instGraph);
1468 bool anyFailed =
false;
1469 for (
auto &it : domains) {
1470 auto module = cast<FModuleOp>(it.first);
1471 auto &domainConflicts = it.second;
1472 if (domainConflicts.size() <= 1)
1476 SmallDenseSet<Value> printedDomainResets;
1477 auto diag = module.emitError(
"module '")
1479 <<
"' instantiated in different reset domains";
1480 for (
auto &it : domainConflicts) {
1481 ResetDomain &domain = it.first;
1482 const auto &path = it.second;
1483 auto inst = path.leaf();
1484 auto loc = path.empty() ? module.getLoc() : inst.getLoc();
1485 auto ¬e = diag.attachNote(loc);
1489 note <<
"root instance";
1491 note <<
"instance '";
1494 [&](InstanceOpInterface inst) { note << inst.getInstanceName(); },
1495 [&]() { note <<
"/"; });
1503 note <<
" reset domain rooted at '" << nameAndModule.first.getValue()
1504 <<
"' of module '" << nameAndModule.second.getName() <<
"'";
1507 if (printedDomainResets.insert(domain.reset).second) {
1508 diag.attachNote(domain.reset.getLoc())
1509 <<
"reset domain '" << nameAndModule.first.getValue()
1510 <<
"' of module '" << nameAndModule.second.getName()
1511 <<
"' declared here:";
1514 note <<
" no reset domain";
1517 return failure(anyFailed);
1520 void InferResetsPass::buildDomains(FModuleOp module,
1525 llvm::dbgs().indent(indent * 2) <<
"Visiting ";
1526 if (instPath.
empty())
1527 llvm::dbgs() <<
"$root";
1529 llvm::dbgs() << instPath.
leaf().getInstanceName();
1530 llvm::dbgs() <<
" (" << module.getName() <<
")\n";
1534 ResetDomain domain(parentReset);
1535 auto it = annotatedResets.find(module);
1536 if (it != annotatedResets.end()) {
1537 domain.isTop =
true;
1538 domain.reset = it->second;
1544 auto &entries = domains[module];
1545 if (llvm::all_of(entries,
1546 [&](
const auto &entry) {
return entry.first != domain; }))
1547 entries.push_back({domain, instPath});
1550 for (
auto *record : *instGraph[module]) {
1551 auto submodule = dyn_cast<FModuleOp>(*record->getTarget()->getModule());
1555 instancePathCache->appendInstance(instPath, record->getInstance());
1556 buildDomains(submodule, childPath, domain.reset, instGraph, indent + 1);
1561 void InferResetsPass::determineImpl() {
1563 llvm::dbgs() <<
"\n";
1564 debugHeader(
"Determine implementation") <<
"\n\n";
1566 for (
auto &it : domains) {
1567 auto module = cast<FModuleOp>(it.first);
1568 auto &domain = it.second.back().first;
1569 determineImpl(module, domain);
1589 void InferResetsPass::determineImpl(FModuleOp module, ResetDomain &domain) {
1592 LLVM_DEBUG(llvm::dbgs() <<
"Planning reset for " << module.getName() <<
"\n");
1597 LLVM_DEBUG(llvm::dbgs() <<
"- Rooting at local value "
1599 domain.existingValue = domain.reset;
1600 if (
auto blockArg = dyn_cast<BlockArgument>(domain.reset))
1601 domain.existingPort = blockArg.getArgNumber();
1608 auto neededType = domain.reset.getType();
1609 LLVM_DEBUG(llvm::dbgs() <<
"- Looking for existing port " << neededName
1611 auto portNames = module.getPortNames();
1612 auto ports = llvm::zip(portNames, module.getArguments());
1613 auto portIt = llvm::find_if(
1614 ports, [&](
auto port) {
return std::get<0>(port) == neededName; });
1615 if (portIt != ports.end() && std::get<1>(*portIt).getType() == neededType) {
1616 LLVM_DEBUG(llvm::dbgs()
1617 <<
"- Reusing existing port " << neededName <<
"\n");
1618 domain.existingValue = std::get<1>(*portIt);
1619 domain.existingPort = std::distance(ports.begin(), portIt);
1629 if (portIt != ports.end()) {
1630 LLVM_DEBUG(llvm::dbgs()
1631 <<
"- Existing " << neededName <<
" has incompatible type "
1632 << std::get<1>(*portIt).getType() <<
"\n");
1634 unsigned suffix = 0;
1638 Twine(
"_") + Twine(suffix++));
1639 }
while (llvm::is_contained(portNames, newName));
1640 LLVM_DEBUG(llvm::dbgs()
1641 <<
"- Creating uniquified port " << newName <<
"\n");
1642 domain.newPortName = newName;
1648 LLVM_DEBUG(llvm::dbgs() <<
"- Creating new port " << neededName <<
"\n");
1649 domain.newPortName = neededName;
1657 LogicalResult InferResetsPass::implementFullReset() {
1659 llvm::dbgs() <<
"\n";
1662 for (
auto &it : domains)
1663 if (failed(implementFullReset(cast<FModuleOp>(it.first),
1664 it.second.back().first)))
1674 LogicalResult InferResetsPass::implementFullReset(FModuleOp module,
1675 ResetDomain &domain) {
1676 LLVM_DEBUG(llvm::dbgs() <<
"Implementing full reset for " << module.getName()
1680 if (!domain.reset) {
1681 LLVM_DEBUG(llvm::dbgs()
1682 <<
"- Skipping because module explicitly has no domain\n");
1687 auto *context = module.getContext();
1692 annotations.applyToOperation(module);
1695 Value actualReset = domain.existingValue;
1696 if (domain.newPortName) {
1697 PortInfo portInfo{domain.newPortName,
1698 domain.reset.getType(),
1701 domain.reset.getLoc()};
1702 module.insertPorts({{0, portInfo}});
1703 actualReset = module.getArgument(0);
1704 LLVM_DEBUG(llvm::dbgs()
1705 <<
"- Inserted port " << domain.newPortName <<
"\n");
1709 llvm::dbgs() <<
"- Using ";
1710 if (
auto blockArg = dyn_cast<BlockArgument>(actualReset))
1711 llvm::dbgs() <<
"port #" << blockArg.getArgNumber() <<
" ";
1713 llvm::dbgs() <<
"wire/node ";
1719 SmallVector<Operation *> opsToUpdate;
1720 module.walk([&](Operation *op) {
1721 if (isa<InstanceOp, RegOp, RegResetOp>(op))
1722 opsToUpdate.push_back(op);
1729 if (!isa<BlockArgument>(actualReset)) {
1730 mlir::DominanceInfo dom(module);
1735 auto *resetOp = actualReset.getDefiningOp();
1736 if (!opsToUpdate.empty() && !dom.dominates(resetOp, opsToUpdate[0])) {
1737 LLVM_DEBUG(llvm::dbgs()
1738 <<
"- Reset doesn't dominate all uses, needs to be moved\n");
1742 auto nodeOp = dyn_cast<NodeOp>(resetOp);
1743 if (nodeOp && !dom.dominates(nodeOp.getInput(), opsToUpdate[0])) {
1744 LLVM_DEBUG(llvm::dbgs()
1745 <<
"- Promoting node to wire for move: " << nodeOp <<
"\n");
1746 auto builder = ImplicitLocOpBuilder::atBlockBegin(nodeOp.getLoc(),
1747 nodeOp->getBlock());
1748 auto wireOp = builder.create<WireOp>(
1749 nodeOp.getResult().getType(), nodeOp.getNameAttr(),
1750 nodeOp.getNameKindAttr(), nodeOp.getAnnotationsAttr(),
1751 nodeOp.getInnerSymAttr(), nodeOp.getForceableAttr());
1753 nodeOp->replaceAllUsesWith(wireOp);
1754 nodeOp->removeAttr(nodeOp.getInnerSymAttrName());
1758 nodeOp.setNameKind(NameKindEnum::DroppableName);
1759 nodeOp.setAnnotationsAttr(
ArrayAttr::get(builder.getContext(), {}));
1760 builder.setInsertionPointAfter(nodeOp);
1761 emitConnect(builder, wireOp.getResult(), nodeOp.getResult());
1763 actualReset = wireOp.getResult();
1764 domain.existingValue = wireOp.getResult();
1769 Block *targetBlock = dom.findNearestCommonDominator(
1770 resetOp->getBlock(), opsToUpdate[0]->getBlock());
1772 if (targetBlock != resetOp->getBlock())
1773 llvm::dbgs() <<
"- Needs to be moved to different block\n";
1782 auto getParentInBlock = [](Operation *op,
Block *block) {
1783 while (op && op->getBlock() != block)
1784 op = op->getParentOp();
1787 auto *resetOpInTarget = getParentInBlock(resetOp, targetBlock);
1788 auto *firstOpInTarget = getParentInBlock(opsToUpdate[0], targetBlock);
1794 if (resetOpInTarget->isBeforeInBlock(firstOpInTarget))
1795 resetOp->moveBefore(resetOpInTarget);
1797 resetOp->moveBefore(firstOpInTarget);
1802 for (
auto *op : opsToUpdate)
1803 implementFullReset(op, module, actualReset);
// Apply the full reset `actualReset` to a single operation `op` of `module`.
// Three kinds of operations are handled below:
//   - InstanceOp: locate the instance result that acts as the child module's
//     reset. It is either a newly added port or an existing port recorded in
//     the child module's reset domain.
//   - RegOp: replace it with a RegResetOp that resets to a zero value when
//     `actualReset` is asserted.
//   - RegResetOp: leave it alone when it already has an async reset or when
//     the full reset is a UInt (sync) reset. Otherwise, fold its existing
//     reset into a mux and switch its reset operands to `actualReset` and a
//     zero value.
// NOTE(review): this rendering of the function omits several source lines
// (e.g. early returns, the definition of `zero`, and closing braces). The
// comments below describe only the lines that are visible.
1810 void InferResetsPass::implementFullReset(Operation *op, FModuleOp module,
1811 Value actualReset) {
// New ops are created with `op`'s location. The builder's insertion point
// starts right before `op`.
1812 ImplicitLocOpBuilder builder(op->getLoc(), op);
// Instances: find the instance result that serves as the child's reset.
1815 if (
auto instOp = dyn_cast<InstanceOp>(op)) {
1819 auto refModule = instOp.getReferencedModule<FModuleOp>(*instanceGraph);
// Look up the reset domain recorded for the instantiated module. The
// handling of a missing entry is not visible here (presumably an early
// return, so the instance is left untouched).
1822 auto domainIt = domains.find(refModule);
1823 if (domainIt == domains.end())
// Use the last entry recorded for this module.
1825 auto &domain = domainIt->second.back().first;
1828 LLVM_DEBUG(llvm::dbgs()
1829 <<
"- Update instance '" << instOp.getName() <<
"'\n");
// The child module gained a new reset port, so:
//   1. Clone the instance with that port inserted.
//   2. Take result #0 as the instance's reset.
//   3. Redirect all uses of the old instance's results to the remaining
//      results of the clone.
// The port-index argument of cloneAndInsertPorts is on a line not visible
// here.
1833 if (domain.newPortName) {
1834 LLVM_DEBUG(llvm::dbgs() <<
" - Adding new result as reset\n");
1836 auto newInstOp = instOp.cloneAndInsertPorts(
1838 {domain.newPortName,
1839 type_cast<FIRRTLBaseType>(actualReset.getType()),
1841 instReset = newInstOp.getResult(0);
1844 instOp.replaceAllUsesWith(newInstOp.getResults().drop_front());
// Keep the instance graph in sync with the replacement instance.
1845 instanceGraph->replaceInstance(instOp, newInstOp);
1848 }
else if (domain.existingPort.has_value()) {
// The child module already has a reset port: use the matching result.
1849 auto idx = *domain.existingPort;
1850 instReset = instOp.getResult(idx);
1851 LLVM_DEBUG(llvm::dbgs() <<
" - Using result #" << idx <<
" as reset\n");
// NOTE(review): the code that drives `instReset` from `actualReset` is on
// lines not visible here. Only the sanity check and the insertion-point
// setup (after the instance) are shown.
1861 assert(instReset && actualReset);
1862 builder.setInsertionPointAfter(instOp);
// Registers without a reset: replace with an equivalent RegResetOp. The new
// register is reset by `actualReset` to `zero`, which is defined on a line
// not visible here. It carries over the type, clock, name, name kind,
// annotations, inner symbol, and forceability.
1868 if (
auto regOp = dyn_cast<RegOp>(op)) {
1872 LLVM_DEBUG(llvm::dbgs() <<
"- Adding full reset to " << regOp <<
"\n");
1874 auto newRegOp = builder.create<RegResetOp>(
1875 regOp.getResult().getType(), regOp.getClockVal(), actualReset, zero,
1876 regOp.getNameAttr(), regOp.getNameKindAttr(), regOp.getAnnotations(),
1877 regOp.getInnerSymAttr(), regOp.getForceableAttr());
1878 regOp.getResult().replaceAllUsesWith(newRegOp.getResult());
// A forceable register also exposes a ref result; redirect its uses too.
1879 if (regOp.getForceable())
1880 regOp.getRef().replaceAllUsesWith(newRegOp.getRef());
// Registers that already have a reset.
1886 if (
auto regOp = dyn_cast<RegResetOp>(op)) {
// Leave the register untouched if either holds:
//   - its existing reset is an AsyncReset, or
//   - the full reset being applied is a UInt (sync) reset.
// In that case only re-verify the register, and fail the pass if it is
// invalid.
1889 if (type_isa<AsyncResetType>(regOp.getResetSignal().getType()) ||
1890 type_isa<UIntType>(actualReset.getType())) {
1891 LLVM_DEBUG(llvm::dbgs() <<
"- Skipping (has reset) " << regOp <<
"\n");
1894 if (failed(regOp.verifyInvariants()))
1895 signalPassFailure();
1898 LLVM_DEBUG(llvm::dbgs() <<
"- Updating reset of " << regOp <<
"\n");
// For registers not skipped above, the existing (non-async) reset moves into
// the data path. A mux created right after the register selects:
//   - the old reset value while the old reset is asserted, and
//   - the current register value otherwise.
1900 auto reset = regOp.getResetSignal();
1901 auto value = regOp.getResetValue();
1907 builder.setInsertionPointAfterValue(regOp.getResult());
1908 auto mux = builder.create<MuxPrimOp>(reset, value, regOp.getResult());
// NOTE(review): how `mux` is consumed is on lines not visible here.
// The register's reset operands are then rewired to the full reset and a
// zero value (`zero` is defined on a line not visible here).
1912 builder.setInsertionPoint(regOp);
1914 regOp.getResetSignalMutable().assign(actualReset);
1915 regOp.getResetValueMutable().assign(zero);
1919 LogicalResult InferResetsPass::verifyNoAbstractReset() {
1920 bool hasAbstractResetPorts =
false;
1921 for (FModuleLike module :
1922 getOperation().
getBodyBlock()->getOps<FModuleLike>()) {
1923 for (
PortInfo port : module.getPorts()) {
1924 if (getBaseOfType<ResetType>(port.type)) {
1925 auto diag = emitError(port.loc)
1926 <<
"a port \"" << port.getName()
1927 <<
"\" with abstract reset type was unable to be "
1928 "inferred by InferResets (is this a top-level port?)";
1929 diag.attachNote(module->getLoc())
1930 <<
"the module with this uninferred reset port was defined here";
1931 hasAbstractResetPorts =
true;
1936 if (hasAbstractResetPorts)
assert(baseType &&"element must be base type")
static Value createZeroValue(ImplicitLocOpBuilder &builder, FIRRTLBaseType type, SmallDenseMap< FIRRTLBaseType, Value > &cache)
Construct a zero value of the given type using the given builder.
static unsigned getFieldID(BundleType type, unsigned index)
static std::pair< StringAttr, FModuleOp > getResetNameAndModule(Value reset)
Return the name and parent module of a reset.
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static bool isUselessVec(FIRRTLBaseType oldType, unsigned fieldID)
static StringAttr getResetName(Value reset)
Return the name of a reset.
static bool insertResetMux(ImplicitLocOpBuilder &builder, Value target, Value reset, Value resetValue)
Helper function that inserts reset multiplexer into all ConnectOps with the given target.
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static bool typeContainsReset(Type type)
Check whether a type contains a ResetType.
static bool getDeclName(Value value, SmallString< 32 > &string)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
static Block * getBodyBlock(FModuleLike mod)
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class provides a read-only projection of an annotation.
AttrClass getMember(StringAttr name) const
Return a member of the annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
An instance path composed of a series of instances.
InstanceOpInterface leaf() const
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeFromFullResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
constexpr const char * excludeMemToRegAnnoClass
igraph::InstancePathCache InstancePathCache
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
constexpr const char * fullResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::unique_ptr< mlir::Pass > createInferResetsPass()
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
bool operator==(uint64_t a, const FVInt &b)
bool operator!=(uint64_t a, const FVInt &b)
bool operator<(const AppID &a, const AppID &b)
This holds the name and type that describes the module's ports.
static ResetSignal getEmptyKey()
static ResetSignal getTombstoneKey()
static bool isEqual(const ResetSignal &lhs, const ResetSignal &rhs)
static unsigned getHashValue(const ResetSignal &x)