24 #include "mlir/IR/Dominance.h"
25 #include "mlir/IR/ImplicitLocOpBuilder.h"
26 #include "mlir/IR/Threading.h"
27 #include "llvm/ADT/EquivalenceClasses.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/ADT/TinyPtrVector.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "infer-resets"
35 using circt::igraph::InstanceOpInterface;
38 using llvm::BumpPtrAllocator;
39 using llvm::MapVector;
40 using llvm::SmallDenseSet;
41 using llvm::SmallSetVector;
43 using mlir::InferTypeOpInterface;
44 using mlir::WalkOrder;
46 using namespace circt;
47 using namespace firrtl;
64 std::optional<unsigned> existingPort;
65 StringAttr newPortName;
67 ResetDomain(Value reset) : reset(reset) {}
71 inline bool operator==(
const ResetDomain &a,
const ResetDomain &b) {
72 return (a.isTop == b.isTop && a.reset == b.reset);
74 inline bool operator!=(
const ResetDomain &a,
const ResetDomain &b) {
81 if (
auto arg = dyn_cast<BlockArgument>(reset)) {
82 auto module = cast<FModuleOp>(arg.getParentRegion()->getParentOp());
83 return {module.getPortNameAttr(arg.getArgNumber()), module};
85 auto op = reset.getDefiningOp();
86 return {op->getAttrOfType<StringAttr>(
"name"),
87 op->getParentOfType<FModuleOp>()};
102 auto it = cache.find(type);
103 if (it != cache.end())
105 auto nullBit = [&]() {
112 .
Case<ClockType>([&](
auto type) {
113 return builder.create<AsClockPrimOp>(nullBit());
115 .Case<AsyncResetType>([&](
auto type) {
116 return builder.create<AsAsyncResetPrimOp>(nullBit());
118 .Case<SIntType, UIntType>([&](
auto type) {
119 return builder.create<ConstantOp>(
120 type, APInt::getZero(type.getWidth().value_or(1)));
122 .Case<BundleType>([&](
auto type) {
123 auto wireOp =
builder.create<WireOp>(type);
124 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
125 auto fieldType = type.getElementTypePreservingConst(i);
128 builder.create<SubfieldOp>(fieldType, wireOp.getResult(), i);
129 builder.create<StrictConnectOp>(acc, zero);
131 return wireOp.getResult();
133 .Case<FVectorType>([&](
auto type) {
134 auto wireOp =
builder.create<WireOp>(type);
136 builder, type.getElementTypePreservingConst(), cache);
137 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
138 auto acc =
builder.create<SubindexOp>(zero.getType(),
139 wireOp.getResult(), i);
140 builder.create<StrictConnectOp>(acc, zero);
142 return wireOp.getResult();
144 .Case<ResetType, AnalogType>(
145 [&](
auto type) {
return builder.create<InvalidValueOp>(type); })
147 llvm_unreachable(
"switch handles all types");
150 cache.insert({type, value});
167 Value reset, Value resetValue) {
171 bool resetValueUsed =
false;
173 for (
auto &use : target.getUses()) {
174 Operation *useOp = use.getOwner();
175 builder.setInsertionPoint(useOp);
176 TypeSwitch<Operation *>(useOp)
179 .Case<ConnectOp, StrictConnectOp>([&](
auto op) {
180 if (op.getDest() != target)
182 LLVM_DEBUG(
llvm::dbgs() <<
" - Insert mux into " << op <<
"\n");
184 builder.create<MuxPrimOp>(reset, resetValue, op.getSrc());
185 op.getSrcMutable().assign(muxOp);
186 resetValueUsed =
true;
189 .Case<SubfieldOp>([&](
auto op) {
191 builder.create<SubfieldOp>(resetValue, op.getFieldIndexAttr());
193 resetValueUsed =
true;
195 resetSubValue.erase();
198 .Case<SubindexOp>([&](
auto op) {
200 builder.create<SubindexOp>(resetValue, op.getIndexAttr());
202 resetValueUsed =
true;
204 resetSubValue.erase();
207 .Case<SubaccessOp>([&](
auto op) {
208 if (op.getInput() != target)
211 builder.create<SubaccessOp>(resetValue, op.getIndex());
213 resetValueUsed =
true;
215 resetSubValue.erase();
218 return resetValueUsed;
233 bool operator<(
const ResetSignal &other)
const {
return field < other.field; }
234 bool operator==(
const ResetSignal &other)
const {
235 return field == other.field;
237 bool operator!=(
const ResetSignal &other)
const {
return !(*
this == other); }
257 using ResetDrives = SmallVector<ResetDrive, 1>;
260 using ResetNetwork = llvm::iterator_range<
261 llvm::EquivalenceClasses<ResetSignal>::member_iterator>;
264 enum class ResetKind { Async, Sync };
270 struct DenseMapInfo<ResetSignal> {
272 return ResetSignal{DenseMapInfo<FieldRef>::getEmptyKey(), {}};
275 return ResetSignal{DenseMapInfo<FieldRef>::getTombstoneKey(), {}};
280 static bool isEqual(
const ResetSignal &lhs,
const ResetSignal &rhs) {
286 template <
typename T>
289 case ResetKind::Async:
290 return os <<
"async";
291 case ResetKind::Sync:
403 struct InferResetsPass :
public InferResetsBase<InferResetsPass> {
404 void runOnOperation()
override;
405 void runOnOperationInner();
408 using InferResetsBase::InferResetsBase;
409 InferResetsPass(
const InferResetsPass &other) : InferResetsBase(other) {}
414 void traceResets(CircuitOp circuit);
415 void traceResets(InstanceOp inst);
416 void traceResets(Value dst, Value src, Location loc);
417 void traceResets(Value value);
418 void traceResets(Type dstType, Value dst,
unsigned dstID, Type srcType,
419 Value src,
unsigned srcID, Location loc);
421 LogicalResult inferAndUpdateResets();
423 LogicalResult updateReset(ResetNetwork net, ResetKind kind);
429 LogicalResult collectAnnos(CircuitOp circuit);
437 LogicalResult buildDomains(CircuitOp circuit);
438 void buildDomains(FModuleOp module,
const InstancePath &instPath,
440 unsigned indent = 0);
442 void determineImpl();
443 void determineImpl(FModuleOp module, ResetDomain &domain);
445 LogicalResult implementAsyncReset();
446 LogicalResult implementAsyncReset(FModuleOp module, ResetDomain &domain);
447 void implementAsyncReset(Operation *op, FModuleOp module, Value actualReset);
449 LogicalResult verifyNoAbstractReset();
455 ResetNetwork getResetNetwork(ResetSignal signal) {
456 return llvm::make_range(resetClasses.findLeader(signal),
457 resetClasses.member_end());
461 ResetDrives &getResetDrives(ResetNetwork net) {
462 return resetDrives[*net.begin()];
467 ResetSignal guessRoot(ResetNetwork net);
468 ResetSignal guessRoot(ResetSignal signal) {
469 return guessRoot(getResetNetwork(signal));
476 llvm::EquivalenceClasses<ResetSignal> resetClasses;
479 DenseMap<ResetSignal, ResetDrives> resetDrives;
484 DenseMap<Operation *, Value> annotatedResets;
488 MapVector<FModuleOp, SmallVector<std::pair<ResetDomain, InstancePath>, 1>>
495 std::unique_ptr<InstancePathCache> instancePathCache;
499 void InferResetsPass::runOnOperation() {
500 runOnOperationInner();
501 resetClasses = llvm::EquivalenceClasses<ResetSignal>();
503 annotatedResets.clear();
505 instancePathCache.reset(
nullptr);
506 markAnalysesPreserved<InstanceGraph>();
509 void InferResetsPass::runOnOperationInner() {
510 instanceGraph = &getAnalysis<InstanceGraph>();
511 instancePathCache = std::make_unique<InstancePathCache>(*instanceGraph);
514 traceResets(getOperation());
517 if (failed(inferAndUpdateResets()))
518 return signalPassFailure();
521 if (failed(collectAnnos(getOperation())))
522 return signalPassFailure();
525 if (failed(buildDomains(getOperation())))
526 return signalPassFailure();
532 if (failed(implementAsyncReset()))
533 return signalPassFailure();
536 if (failed(verifyNoAbstractReset()))
537 return signalPassFailure();
541 return std::make_unique<InferResetsPass>();
544 ResetSignal InferResetsPass::guessRoot(ResetNetwork net) {
545 ResetDrives &drives = getResetDrives(net);
546 ResetSignal bestSignal = *net.begin();
547 unsigned bestNumDrives = -1;
549 for (
auto signal : net) {
551 if (isa_and_nonnull<InvalidValueOp>(
552 signal.field.getValue().getDefiningOp()))
557 unsigned numDrives = 0;
558 for (
auto &drive : drives)
559 if (drive.dst == signal)
565 if (numDrives < bestNumDrives) {
566 bestNumDrives = numDrives;
585 .
Case<BundleType>([](
auto type) {
587 for (
auto e : type.getElements())
592 [](
auto type) {
return getMaxFieldID(type.getElementType()) + 1; })
593 .Default([](
auto) {
return 0; });
596 static unsigned getFieldID(BundleType type,
unsigned index) {
597 assert(index < type.getNumElements());
599 for (
unsigned i = 0; i < index; ++i)
607 assert(type.getNumElements() &&
"Bundle must have >0 fields");
609 for (
const auto &e : llvm::enumerate(type.getElements())) {
611 if (fieldID < numSubfields)
613 fieldID -= numSubfields;
615 assert(
false &&
"field id outside bundle");
621 if (oldType.isGround()) {
627 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
635 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
636 if (vectorType.getNumElements() == 0)
653 if (
auto arg = dyn_cast<BlockArgument>(value)) {
654 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
655 string += module.getPortName(arg.getArgNumber());
659 auto *op = value.getDefiningOp();
660 return TypeSwitch<Operation *, bool>(op)
661 .Case<InstanceOp, MemOp>([&](
auto op) {
662 string += op.getName();
665 op.getPortName(cast<OpResult>(value).getResultNumber()).getValue();
668 .Case<WireOp, RegOp, RegResetOp>([&](
auto op) {
669 string += op.getName();
672 .Default([](
auto) {
return false; });
676 SmallString<64> name;
681 auto type = value.getType();
684 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
687 auto &element = bundleType.getElements()[index];
690 string += element.name.getValue();
693 localID = localID -
getFieldID(bundleType, index);
694 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
697 type = vecType.getElementType();
704 llvm_unreachable(
"unsupported type");
716 return TypeSwitch<Type, bool>(type)
718 return type.getRecursiveTypeProperties().hasUninferredReset;
720 .Default([](
auto) {
return false; });
727 void InferResetsPass::traceResets(CircuitOp circuit) {
730 debugHeader(
"Tracing uninferred resets") <<
"\n\n";
733 SmallVector<std::pair<FModuleOp, SmallVector<Operation *>>> moduleToOps;
735 for (
auto module : circuit.getOps<FModuleOp>())
736 moduleToOps.push_back({module, {}});
738 mlir::parallelForEach(circuit.getContext(), moduleToOps, [](
auto &e) {
739 e.first.walk([&](Operation *op) {
743 op->getResultTypes(),
744 [](mlir::Type type) { return typeContainsReset(type); }) ||
745 llvm::any_of(op->getOperandTypes(), typeContainsReset))
746 e.second.push_back(op);
750 for (
auto &[_, ops] : moduleToOps)
751 for (
auto *op : ops) {
752 TypeSwitch<Operation *>(op)
753 .Case<FConnectLike>([&](
auto op) {
754 traceResets(op.getDest(), op.getSrc(), op.getLoc());
756 .Case<InstanceOp>([&](
auto op) { traceResets(op); })
757 .Case<RefSendOp>([&](
auto op) {
759 traceResets(op.getType().getType(), op.getResult(), 0,
760 op.getBase().getType().getPassiveType(), op.getBase(),
763 .Case<RefResolveOp>([&](
auto op) {
765 traceResets(op.getType(), op.getResult(), 0,
766 op.getRef().getType().getType(), op.getRef(), 0,
769 .Case<Forceable>([&](Forceable op) {
771 if (op.isForceable())
772 traceResets(op.getDataType(), op.getData(), 0, op.getDataType(),
773 op.getDataRef(), 0, op.getLoc());
775 .Case<UninferredResetCastOp, ConstCastOp, RefCastOp>([&](
auto op) {
776 traceResets(op.getResult(), op.getInput(), op.getLoc());
778 .Case<InvalidValueOp>([&](
auto op) {
787 auto type = op.getType();
790 LLVM_DEBUG(
llvm::dbgs() <<
"Uniquify " << op <<
"\n");
791 ImplicitLocOpBuilder
builder(op->getLoc(), op);
793 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
799 auto newOp =
builder.create<InvalidValueOp>(type);
804 .Case<SubfieldOp>([&](
auto op) {
807 BundleType bundleType = op.getInput().getType();
808 auto index = op.getFieldIndex();
809 traceResets(op.getType(), op.getResult(), 0,
810 bundleType.getElements()[index].type, op.getInput(),
814 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
827 FVectorType vectorType = op.getInput().getType();
828 traceResets(op.getType(), op.getResult(), 0,
829 vectorType.getElementType(), op.getInput(),
833 .Case<RefSubOp>([&](RefSubOp op) {
835 auto aggType = op.getInput().getType().getType();
836 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(aggType)
837 .Case<FVectorType>([](
auto type) {
840 .Case<BundleType>([&](
auto type) {
843 traceResets(op.getType(), op.getResult(), 0,
844 op.getResult().getType(), op.getInput(), fieldID,
852 void InferResetsPass::traceResets(InstanceOp inst) {
854 auto module = inst.getReferencedModule<FModuleOp>(*instanceGraph);
857 LLVM_DEBUG(
llvm::dbgs() <<
"Visiting instance " << inst.getName() <<
"\n");
860 for (
const auto &it : llvm::enumerate(inst.getResults())) {
861 auto dir = module.getPortDirection(it.index());
862 Value dstPort = module.getArgument(it.index());
863 Value srcPort = it.value();
864 if (dir == Direction::Out)
865 std::swap(dstPort, srcPort);
866 traceResets(dstPort, srcPort, it.value().getLoc());
872 void InferResetsPass::traceResets(Value dst, Value src, Location loc) {
874 traceResets(dst.getType(), dst, 0, src.getType(), src, 0, loc);
879 void InferResetsPass::traceResets(Type dstType, Value dst,
unsigned dstID,
880 Type srcType, Value src,
unsigned srcID,
882 if (
auto dstBundle = type_dyn_cast<BundleType>(dstType)) {
883 auto srcBundle = type_cast<BundleType>(srcType);
884 for (
unsigned dstIdx = 0, e = dstBundle.getNumElements(); dstIdx < e;
886 auto dstField = dstBundle.getElements()[dstIdx].name;
887 auto srcIdx = srcBundle.getElementIndex(dstField);
890 auto &dstElt = dstBundle.getElements()[dstIdx];
891 auto &srcElt = srcBundle.getElements()[*srcIdx];
893 traceResets(srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
894 dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
897 traceResets(dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
898 srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
905 if (
auto dstVector = type_dyn_cast<FVectorType>(dstType)) {
906 auto srcVector = type_cast<FVectorType>(srcType);
907 auto srcElType = srcVector.getElementType();
908 auto dstElType = dstVector.getElementType();
921 traceResets(dstElType, dst, dstID +
getFieldID(dstVector), srcElType, src,
927 if (
auto dstRef = type_dyn_cast<RefType>(dstType)) {
928 auto srcRef = type_cast<RefType>(srcType);
929 return traceResets(dstRef.getType(), dst, dstID, srcRef.getType(), src,
934 auto dstBase = type_dyn_cast<FIRRTLBaseType>(dstType);
935 auto srcBase = type_dyn_cast<FIRRTLBaseType>(srcType);
936 if (!dstBase || !srcBase)
938 if (!type_isa<ResetType>(dstBase) && !type_isa<ResetType>(srcBase))
943 LLVM_DEBUG(
llvm::dbgs() <<
"Visiting driver '" << dstField <<
"' = '"
944 << srcField <<
"' (" << dstType <<
" = " << srcType
950 ResetSignal dstLeader =
951 *resetClasses.findLeader(resetClasses.insert({dstField, dstBase}));
952 ResetSignal srcLeader =
953 *resetClasses.findLeader(resetClasses.insert({srcField, srcBase}));
956 ResetSignal unionLeader = *resetClasses.unionSets(dstLeader, srcLeader);
957 assert(unionLeader == dstLeader || unionLeader == srcLeader);
962 if (dstLeader != srcLeader) {
963 auto &unionDrives = resetDrives[unionLeader];
964 auto mergedDrivesIt =
965 resetDrives.find(unionLeader == dstLeader ? srcLeader : dstLeader);
966 if (mergedDrivesIt != resetDrives.end()) {
967 unionDrives.append(mergedDrivesIt->second);
968 resetDrives.erase(mergedDrivesIt);
974 resetDrives[unionLeader].push_back(
975 {{dstField, dstBase}, {srcField, srcBase}, loc});
982 LogicalResult InferResetsPass::inferAndUpdateResets() {
987 for (
auto it = resetClasses.begin(),
end = resetClasses.end(); it !=
end;
991 ResetNetwork net = llvm::make_range(resetClasses.member_begin(it),
992 resetClasses.member_end());
995 auto kind = inferReset(net);
1000 if (failed(updateReset(net, *kind)))
1007 LLVM_DEBUG(
llvm::dbgs() <<
"Inferring reset network with "
1008 << std::distance(net.begin(), net.end())
1012 unsigned asyncDrives = 0;
1013 unsigned syncDrives = 0;
1014 unsigned invalidDrives = 0;
1015 for (ResetSignal signal : net) {
1017 if (type_isa<AsyncResetType>(signal.type))
1019 else if (type_isa<UIntType>(signal.type))
1022 isa_and_nonnull<InvalidValueOp>(
1023 signal.field.getValue().getDefiningOp()))
1026 LLVM_DEBUG(
llvm::dbgs() <<
"- Found " << asyncDrives <<
" async, "
1027 << syncDrives <<
" sync, " << invalidDrives
1028 <<
" invalid drives\n");
1031 if (asyncDrives == 0 && syncDrives == 0 && invalidDrives == 0) {
1032 ResetSignal root = guessRoot(net);
1033 auto diag = mlir::emitError(root.field.getValue().getLoc())
1034 <<
"reset network never driven with concrete type";
1035 for (ResetSignal signal : net)
1036 diag.attachNote(signal.field.getLoc()) <<
"here: ";
1041 if (asyncDrives > 0 && syncDrives > 0) {
1042 ResetSignal root = guessRoot(net);
1043 bool majorityAsync = asyncDrives >= syncDrives;
1044 auto diag = mlir::emitError(root.field.getValue().getLoc())
1046 SmallString<32> fieldName;
1048 diag <<
" \"" << fieldName <<
"\"";
1049 diag <<
" simultaneously connected to async and sync resets";
1050 diag.attachNote(root.field.getValue().getLoc())
1051 <<
"majority of connections to this reset are "
1052 << (majorityAsync ?
"async" :
"sync");
1053 for (
auto &drive : getResetDrives(net)) {
1054 if ((type_isa<AsyncResetType>(drive.dst.type) && !majorityAsync) ||
1055 (type_isa<AsyncResetType>(drive.src.type) && !majorityAsync) ||
1056 (type_isa<UIntType>(drive.dst.type) && majorityAsync) ||
1057 (type_isa<UIntType>(drive.src.type) && majorityAsync))
1058 diag.attachNote(drive.loc)
1059 << (type_isa<AsyncResetType>(drive.src.type) ?
"async" :
"sync")
1068 auto kind = (asyncDrives ? ResetKind::Async : ResetKind::Sync);
1069 LLVM_DEBUG(
llvm::dbgs() <<
"- Inferred as " << kind <<
"\n");
1077 LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) {
1078 LLVM_DEBUG(
llvm::dbgs() <<
"Updating reset network with "
1079 << std::distance(net.begin(), net.end())
1080 <<
" nodes to " << kind <<
"\n");
1084 if (kind == ResetKind::Async)
1092 SmallSetVector<Operation *, 16> worklist;
1093 SmallDenseSet<Operation *> moduleWorklist;
1094 SmallDenseSet<std::pair<Operation *, Operation *>> extmoduleWorklist;
1095 for (
auto signal : net) {
1096 Value value = signal.field.getValue();
1097 if (!isa<BlockArgument>(value) &&
1098 !isa_and_nonnull<WireOp, RegOp, RegResetOp, InstanceOp, InvalidValueOp,
1099 ConstCastOp, RefCastOp, UninferredResetCastOp>(
1100 value.getDefiningOp()))
1102 if (updateReset(signal.field, resetType)) {
1103 for (
auto user : value.getUsers())
1104 worklist.insert(user);
1105 if (
auto blockArg = dyn_cast<BlockArgument>(value))
1106 moduleWorklist.insert(blockArg.getOwner()->getParentOp());
1107 else if (
auto instOp = value.getDefiningOp<InstanceOp>()) {
1108 if (
auto extmodule =
1109 instOp.getReferencedModule<FExtModuleOp>(*instanceGraph))
1110 extmoduleWorklist.insert({extmodule, instOp});
1111 }
else if (
auto uncast = value.getDefiningOp<UninferredResetCastOp>()) {
1112 uncast.replaceAllUsesWith(uncast.getInput());
1122 while (!worklist.empty()) {
1123 auto *wop = worklist.pop_back_val();
1124 SmallVector<Type, 2> types;
1125 if (
auto op = dyn_cast<InferTypeOpInterface>(wop)) {
1127 SmallVector<Type, 2> types;
1128 if (failed(op.inferReturnTypes(op->getContext(), op->getLoc(),
1129 op->getOperands(), op->getAttrDictionary(),
1130 op->getPropertiesStorage(),
1131 op->getRegions(), types)))
1136 for (
auto it : llvm::zip(op->getResults(), types)) {
1137 auto newType = std::get<1>(it);
1138 if (std::get<0>(it).getType() == newType)
1140 std::get<0>(it).setType(newType);
1141 for (
auto *user : std::get<0>(it).getUsers())
1142 worklist.insert(user);
1144 LLVM_DEBUG(
llvm::dbgs() <<
"- Inferred " << *op <<
"\n");
1145 }
else if (
auto uop = dyn_cast<UninferredResetCastOp>(wop)) {
1146 for (
auto *user : uop.getResult().getUsers())
1147 worklist.insert(user);
1148 uop.replaceAllUsesWith(uop.getInput());
1149 LLVM_DEBUG(
llvm::dbgs() <<
"- Inferred " << uop <<
"\n");
1155 for (
auto *op : moduleWorklist) {
1156 auto module = dyn_cast<FModuleOp>(op);
1160 SmallVector<Attribute> argTypes;
1161 argTypes.reserve(module.getNumPorts());
1162 for (
auto arg : module.getArguments())
1165 module->setAttr(FModuleLike::getPortTypesAttrName(),
1168 <<
"- Updated type of module '" << module.getName() <<
"'\n");
1172 for (
auto pair : extmoduleWorklist) {
1173 auto module = cast<FExtModuleOp>(pair.first);
1174 auto instOp = cast<InstanceOp>(pair.second);
1176 SmallVector<Attribute> types;
1177 for (
auto type : instOp.getResultTypes())
1180 module->setAttr(FModuleLike::getPortTypesAttrName(),
1183 <<
"- Updated type of extmodule '" << module.getName() <<
"'\n");
1193 if (oldType.isGround()) {
1199 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
1201 SmallVector<BundleType::BundleElement> fields(bundleType.begin(),
1204 fields[index].type, fieldID -
getFieldID(bundleType, index), fieldType);
1209 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
1210 auto newType =
updateType(vectorType.getElementType(),
1211 fieldID -
getFieldID(vectorType), fieldType);
1213 vectorType.isConst());
1216 llvm_unreachable(
"unknown aggregate type");
1223 auto oldType = type_cast<FIRRTLType>(field.
getValue().getType());
1229 if (oldType == newType)
1231 LLVM_DEBUG(
llvm::dbgs() <<
"- Updating '" << field <<
"' from " << oldType
1232 <<
" to " << newType <<
"\n");
1241 LogicalResult InferResetsPass::collectAnnos(CircuitOp circuit) {
1244 debugHeader(
"Gather async reset annotations") <<
"\n\n";
1246 SmallVector<std::pair<FModuleOp, std::optional<Value>>> results;
1247 for (
auto module : circuit.getOps<FModuleOp>())
1248 results.push_back({module, {}});
1250 if (failed(mlir::failableParallelForEach(
1251 circuit.getContext(), results, [&](
auto &moduleAndResult) {
1252 auto result = collectAnnos(moduleAndResult.first);
1255 moduleAndResult.second = *result;
1260 for (
auto [module, reset] : results)
1261 if (reset.has_value())
1262 annotatedResets.insert({module, *reset});
1267 InferResetsPass::collectAnnos(FModuleOp module) {
1268 bool anyFailed =
false;
1269 SmallSetVector<std::pair<Annotation, Location>, 4> conflictingAnnos;
1273 bool ignore =
false;
1275 if (!moduleAnnos.empty()) {
1276 moduleAnnos.removeAnnotations([&](
Annotation anno) {
1279 conflictingAnnos.insert({anno, module.getLoc()});
1284 module.emitError(
"'FullAsyncResetAnnotation' cannot target module; "
1285 "must target port or wire/node instead");
1290 moduleAnnos.applyToOperation(module);
1297 AnnotationSet::removePortAnnotations(module, [&](
unsigned argNum,
1299 Value arg = module.getArgument(argNum);
1301 if (!isa<AsyncResetType>(arg.getType())) {
1302 mlir::emitError(arg.getLoc(),
"'IgnoreFullAsyncResetAnnotation' must "
1303 "target async reset, but targets ")
1309 conflictingAnnos.insert({anno, reset.getLoc()});
1315 mlir::emitError(arg.getLoc(),
1316 "'IgnoreFullAsyncResetAnnotation' cannot target port; "
1317 "must target module instead");
1326 module.walk([&](Operation *op) {
1327 AnnotationSet::removeAnnotations(op, [&](
Annotation anno) {
1329 if (!isa<WireOp, NodeOp>(op)) {
1334 "reset annotations must target module, port, or wire/node");
1342 auto resultType = op->getResult(0).getType();
1344 if (!isa<AsyncResetType>(resultType)) {
1345 mlir::emitError(op->getLoc(),
"'IgnoreFullAsyncResetAnnotation' must "
1346 "target async reset, but targets ")
1351 reset = op->getResult(0);
1352 conflictingAnnos.insert({anno, reset.getLoc()});
1358 "'IgnoreFullAsyncResetAnnotation' cannot target wire/node; must "
1359 "target module instead");
1371 if (!ignore && !reset) {
1373 <<
"No reset annotation for " << module.getName() <<
"\n");
1374 return std::optional<Value>();
1378 if (conflictingAnnos.size() > 1) {
1379 auto diag = module.emitError(
"multiple reset annotations on module '")
1380 << module.getName() <<
"'";
1381 for (
auto &annoAndLoc : conflictingAnnos)
1382 diag.attachNote(annoAndLoc.second)
1383 <<
"conflicting " << annoAndLoc.first.getClassAttr() <<
":";
1389 llvm::dbgs() <<
"Annotated reset for " << module.getName() <<
": ";
1392 else if (
auto arg = dyn_cast<BlockArgument>(reset))
1393 llvm::dbgs() <<
"port " << module.getPortName(arg.getArgNumber()) <<
"\n";
1396 << reset.getDefiningOp()->getAttrOfType<StringAttr>(
"name")
1402 return std::optional<Value>(reset);
1414 LogicalResult InferResetsPass::buildDomains(CircuitOp circuit) {
1417 debugHeader(
"Build async reset domains") <<
"\n\n";
1421 auto &instGraph = getAnalysis<InstanceGraph>();
1422 auto module = dyn_cast<FModuleOp>(*instGraph.getTopLevelNode()->getModule());
1425 <<
"Skipping circuit because main module is no `firrtl.module`");
1428 buildDomains(module,
InstancePath{}, Value{}, instGraph);
1431 bool anyFailed =
false;
1432 for (
auto &it : domains) {
1433 auto module = cast<FModuleOp>(it.first);
1434 auto &domainConflicts = it.second;
1435 if (domainConflicts.size() <= 1)
1439 SmallDenseSet<Value> printedDomainResets;
1440 auto diag = module.emitError(
"module '")
1442 <<
"' instantiated in different reset domains";
1443 for (
auto &it : domainConflicts) {
1444 ResetDomain &domain = it.first;
1445 const auto &path = it.second;
1446 auto inst = path.leaf();
1447 auto loc = path.empty() ? module.getLoc() : inst.getLoc();
1448 auto ¬e = diag.attachNote(loc);
1452 note <<
"root instance";
1454 note <<
"instance '";
1457 [&](InstanceOpInterface inst) { note << inst.getInstanceName(); },
1458 [&]() { note <<
"/"; });
1466 note <<
" reset domain rooted at '" << nameAndModule.first.getValue()
1467 <<
"' of module '" << nameAndModule.second.getName() <<
"'";
1470 if (printedDomainResets.insert(domain.reset).second) {
1471 diag.attachNote(domain.reset.getLoc())
1472 <<
"reset domain '" << nameAndModule.first.getValue()
1473 <<
"' of module '" << nameAndModule.second.getName()
1474 <<
"' declared here:";
1477 note <<
" no reset domain";
1480 return failure(anyFailed);
1483 void InferResetsPass::buildDomains(FModuleOp module,
1488 llvm::dbgs().indent(indent * 2) <<
"Visiting ";
1489 if (instPath.
empty())
1493 llvm::dbgs() <<
" (" << module.getName() <<
")\n";
1497 ResetDomain domain(parentReset);
1498 auto it = annotatedResets.find(module);
1499 if (it != annotatedResets.end()) {
1500 domain.isTop =
true;
1501 domain.reset = it->second;
1507 auto &entries = domains[module];
1508 if (llvm::all_of(entries,
1509 [&](
const auto &entry) {
return entry.first != domain; }))
1510 entries.push_back({domain, instPath});
1513 for (
auto *record : *instGraph[module]) {
1514 auto submodule = dyn_cast<FModuleOp>(*record->getTarget()->getModule());
1518 instancePathCache->appendInstance(instPath, record->getInstance());
1519 buildDomains(submodule, childPath, domain.reset, instGraph, indent + 1);
1524 void InferResetsPass::determineImpl() {
1527 debugHeader(
"Determine implementation") <<
"\n\n";
1529 for (
auto &it : domains) {
1530 auto module = cast<FModuleOp>(it.first);
1531 auto &domain = it.second.back().first;
1532 determineImpl(module, domain);
1552 void InferResetsPass::determineImpl(FModuleOp module, ResetDomain &domain) {
1555 LLVM_DEBUG(
llvm::dbgs() <<
"Planning reset for " << module.getName() <<
"\n");
1560 LLVM_DEBUG(
llvm::dbgs() <<
"- Rooting at local value "
1562 domain.existingValue = domain.reset;
1563 if (
auto blockArg = dyn_cast<BlockArgument>(domain.reset))
1564 domain.existingPort = blockArg.getArgNumber();
1571 auto neededType = domain.reset.getType();
1572 LLVM_DEBUG(
llvm::dbgs() <<
"- Looking for existing port " << neededName
1574 auto portNames = module.getPortNames();
1575 auto ports = llvm::zip(portNames, module.getArguments());
1576 auto portIt = llvm::find_if(
1577 ports, [&](
auto port) {
return std::get<0>(port) == neededName; });
1578 if (portIt != ports.end() && std::get<1>(*portIt).getType() == neededType) {
1580 <<
"- Reusing existing port " << neededName <<
"\n");
1581 domain.existingValue = std::get<1>(*portIt);
1582 domain.existingPort = std::distance(ports.begin(), portIt);
1592 if (portIt != ports.end()) {
1594 <<
"- Existing " << neededName <<
" has incompatible type "
1595 << std::get<1>(*portIt).getType() <<
"\n");
1597 unsigned suffix = 0;
1601 Twine(
"_") + Twine(suffix++));
1602 }
while (llvm::is_contained(portNames, newName));
1604 <<
"- Creating uniquified port " << newName <<
"\n");
1605 domain.newPortName = newName;
1611 LLVM_DEBUG(
llvm::dbgs() <<
"- Creating new port " << neededName <<
"\n");
1612 domain.newPortName = neededName;
1620 LogicalResult InferResetsPass::implementAsyncReset() {
1625 for (
auto &it : domains)
1626 if (failed(implementAsyncReset(cast<FModuleOp>(it.first),
1627 it.second.back().first)))
1637 LogicalResult InferResetsPass::implementAsyncReset(FModuleOp module,
1638 ResetDomain &domain) {
1639 LLVM_DEBUG(
llvm::dbgs() <<
"Implementing async reset for " << module.getName()
1643 if (!domain.reset) {
1645 <<
"- Skipping because module explicitly has no domain\n");
1650 Value actualReset = domain.existingValue;
1651 if (domain.newPortName) {
1652 PortInfo portInfo{domain.newPortName,
1656 domain.reset.getLoc()};
1657 module.insertPorts({{0, portInfo}});
1658 actualReset = module.getArgument(0);
1660 <<
"- Inserted port " << domain.newPortName <<
"\n");
1665 if (
auto blockArg = dyn_cast<BlockArgument>(actualReset))
1666 llvm::dbgs() <<
"port #" << blockArg.getArgNumber() <<
" ";
1674 SmallVector<Operation *> opsToUpdate;
1675 module.walk([&](Operation *op) {
1676 if (isa<InstanceOp, RegOp, RegResetOp>(op))
1677 opsToUpdate.push_back(op);
1684 if (!isa<BlockArgument>(actualReset)) {
1685 mlir::DominanceInfo dom(module);
1690 auto *resetOp = actualReset.getDefiningOp();
1691 if (!opsToUpdate.empty() && !dom.dominates(resetOp, opsToUpdate[0])) {
1693 <<
"- Reset doesn't dominate all uses, needs to be moved\n");
1697 auto nodeOp = dyn_cast<NodeOp>(resetOp);
1698 if (nodeOp && !dom.dominates(nodeOp.getInput(), opsToUpdate[0])) {
1700 <<
"- Promoting node to wire for move: " << nodeOp <<
"\n");
1701 ImplicitLocOpBuilder
builder(nodeOp.getLoc(), nodeOp);
1702 auto wireOp =
builder.create<WireOp>(
1703 nodeOp.getResult().getType(), nodeOp.getNameAttr(),
1704 nodeOp.getNameKindAttr(), nodeOp.getAnnotationsAttr(),
1705 nodeOp.getInnerSymAttr(), nodeOp.getForceableAttr());
1706 builder.create<StrictConnectOp>(wireOp.getResult(), nodeOp.getInput());
1707 nodeOp->replaceAllUsesWith(wireOp);
1710 actualReset = wireOp.getResult();
1711 domain.existingValue = wireOp.getResult();
1716 Block *targetBlock = dom.findNearestCommonDominator(
1717 resetOp->getBlock(), opsToUpdate[0]->getBlock());
1719 if (targetBlock != resetOp->getBlock())
1720 llvm::dbgs() <<
"- Needs to be moved to different block\n";
1729 auto getParentInBlock = [](Operation *op,
Block *block) {
1730 while (op && op->getBlock() != block)
1731 op = op->getParentOp();
1734 auto *resetOpInTarget = getParentInBlock(resetOp, targetBlock);
1735 auto *firstOpInTarget = getParentInBlock(opsToUpdate[0], targetBlock);
1741 if (resetOpInTarget->isBeforeInBlock(firstOpInTarget))
1742 resetOp->moveBefore(resetOpInTarget);
1744 resetOp->moveBefore(firstOpInTarget);
1749 for (
auto *op : opsToUpdate)
1750 implementAsyncReset(op, module, actualReset);
1757 void InferResetsPass::implementAsyncReset(Operation *op, FModuleOp module,
1758 Value actualReset) {
1759 ImplicitLocOpBuilder
builder(op->getLoc(), op);
1762 if (
auto instOp = dyn_cast<InstanceOp>(op)) {
1766 auto refModule = instOp.getReferencedModule<FModuleOp>(*instanceGraph);
1769 auto domainIt = domains.find(refModule);
1770 if (domainIt == domains.end())
1772 auto &domain = domainIt->second.back().first;
1776 <<
"- Update instance '" << instOp.getName() <<
"'\n");
1780 if (domain.newPortName) {
1781 LLVM_DEBUG(
llvm::dbgs() <<
" - Adding new result as reset\n");
1783 auto newInstOp = instOp.cloneAndInsertPorts(
1785 {domain.newPortName,
1786 type_cast<FIRRTLBaseType>(actualReset.getType()),
1788 instReset = newInstOp.getResult(0);
1791 instOp.replaceAllUsesWith(newInstOp.getResults().drop_front());
1792 instanceGraph->replaceInstance(instOp, newInstOp);
1795 }
else if (domain.existingPort.has_value()) {
1796 auto idx = *domain.existingPort;
1797 instReset = instOp.getResult(idx);
1798 LLVM_DEBUG(
llvm::dbgs() <<
" - Using result #" << idx <<
" as reset\n");
1808 assert(instReset && actualReset);
1809 builder.setInsertionPointAfter(instOp);
1810 builder.create<StrictConnectOp>(instReset, actualReset);
1815 if (
auto regOp = dyn_cast<RegOp>(op)) {
1819 LLVM_DEBUG(
llvm::dbgs() <<
"- Adding async reset to " << regOp <<
"\n");
1821 auto newRegOp =
builder.create<RegResetOp>(
1822 regOp.getResult().getType(), regOp.getClockVal(), actualReset, zero,
1823 regOp.getNameAttr(), regOp.getNameKindAttr(), regOp.getAnnotations(),
1824 regOp.getInnerSymAttr(), regOp.getForceableAttr());
1825 regOp.getResult().replaceAllUsesWith(newRegOp.getResult());
1826 if (regOp.getForceable())
1827 regOp.getRef().replaceAllUsesWith(newRegOp.getRef());
1833 if (
auto regOp = dyn_cast<RegResetOp>(op)) {
1835 if (type_isa<AsyncResetType>(regOp.getResetSignal().getType())) {
1837 <<
"- Skipping (has async reset) " << regOp <<
"\n");
1840 if (failed(regOp.verifyInvariants()))
1841 signalPassFailure();
1844 LLVM_DEBUG(
llvm::dbgs() <<
"- Updating reset of " << regOp <<
"\n");
1846 auto reset = regOp.getResetSignal();
1847 auto value = regOp.getResetValue();
1853 builder.setInsertionPointAfterValue(regOp.getResult());
1854 auto mux =
builder.create<MuxPrimOp>(reset, value, regOp.getResult());
1858 builder.setInsertionPoint(regOp);
1860 regOp.getResetSignalMutable().assign(actualReset);
1861 regOp.getResetValueMutable().assign(zero);
1865 LogicalResult InferResetsPass::verifyNoAbstractReset() {
1866 bool hasAbstractResetPorts =
false;
1867 for (FModuleLike module :
1868 getOperation().getBodyBlock()->getOps<FModuleLike>()) {
1869 for (
PortInfo port : module.getPorts()) {
1870 if (getBaseOfType<ResetType>(port.type)) {
1871 auto diag = emitError(port.loc)
1872 <<
"a port \"" << port.getName()
1873 <<
"\" with abstract reset type was unable to be "
1874 "inferred by InferResets (is this a top-level port?)";
1875 diag.attachNote(module->getLoc())
1876 <<
"the module with this uninferred reset port was defined here";
1877 hasAbstractResetPorts =
true;
1882 if (hasAbstractResetPorts)
assert(baseType &&"element must be base type")
static Value createZeroValue(ImplicitLocOpBuilder &builder, FIRRTLBaseType type, SmallDenseMap< FIRRTLBaseType, Value > &cache)
Construct a zero value of the given type using the given builder.
static unsigned getFieldID(BundleType type, unsigned index)
bool operator!=(const ResetDomain &a, const ResetDomain &b)
static std::pair< StringAttr, FModuleOp > getResetNameAndModule(Value reset)
Return the name and parent module of a reset.
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static bool isUselessVec(FIRRTLBaseType oldType, unsigned fieldID)
bool operator==(const ResetDomain &a, const ResetDomain &b)
static StringAttr getResetName(Value reset)
Return the name of a reset.
static bool insertResetMux(ImplicitLocOpBuilder &builder, Value target, Value reset, Value resetValue)
Helper function that inserts reset multiplexer into all ConnectOps with the given target.
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static bool typeContainsReset(Type type)
Check whether a type contains a ResetType.
static bool getDeclName(Value value, SmallString< 32 > &string)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class provides a read-only projection of an annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
An instance path composed of a series of instances.
InstanceOpInterface leaf() const
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeMemToRegAnnoClass
igraph::InstancePathCache InstancePathCache
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
constexpr const char * fullAsyncResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
T & operator<<(T &os, FIRVersion version)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * ignoreFullAsyncResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::unique_ptr< mlir::Pass > createInferResetsPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
bool operator<(const AppID &a, const AppID &b)
mlir::raw_indented_ostream & dbgs()
This holds the name and type that describes the module's ports.
static ResetSignal getEmptyKey()
static ResetSignal getTombstoneKey()
static bool isEqual(const ResetSignal &lhs, const ResetSignal &rhs)
static unsigned getHashValue(const ResetSignal &x)