22 #include "mlir/IR/Dominance.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/Threading.h"
25 #include "llvm/ADT/EquivalenceClasses.h"
26 #include "llvm/ADT/SetVector.h"
27 #include "llvm/ADT/TinyPtrVector.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/Debug.h"
31 #define DEBUG_TYPE "infer-resets"
33 using llvm::BumpPtrAllocator;
34 using llvm::MapVector;
35 using llvm::SmallDenseSet;
36 using llvm::SmallSetVector;
38 using mlir::InferTypeOpInterface;
39 using mlir::WalkOrder;
41 using namespace circt;
42 using namespace firrtl;
57 os <<
"/" << inst.getInstanceName() <<
":"
58 << inst.getReferencedModuleName();
66 auto last = path.back();
67 return last.getInstanceName();
82 std::optional<unsigned> existingPort;
83 StringAttr newPortName;
85 ResetDomain(Value reset) : reset(reset) {}
89 inline bool operator==(
const ResetDomain &a,
const ResetDomain &b) {
90 return (a.isTop == b.isTop && a.reset == b.reset);
92 inline bool operator!=(
const ResetDomain &a,
const ResetDomain &b) {
99 if (
auto arg = dyn_cast<BlockArgument>(reset)) {
100 auto module = cast<FModuleOp>(arg.getParentRegion()->getParentOp());
101 return {module.getPortNameAttr(arg.getArgNumber()), module};
103 auto op = reset.getDefiningOp();
104 return {op->getAttrOfType<StringAttr>(
"name"),
105 op->getParentOfType<FModuleOp>()};
120 auto it = cache.find(type);
121 if (it != cache.end())
123 auto nullBit = [&]() {
130 .
Case<ClockType>([&](
auto type) {
131 return builder.create<AsClockPrimOp>(nullBit());
133 .Case<AsyncResetType>([&](
auto type) {
134 return builder.create<AsAsyncResetPrimOp>(nullBit());
136 .Case<SIntType, UIntType>([&](
auto type) {
137 return builder.create<ConstantOp>(
138 type, APInt::getZero(type.getWidth().value_or(1)));
140 .Case<BundleType>([&](
auto type) {
141 auto wireOp =
builder.create<WireOp>(type);
142 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
143 auto fieldType = type.getElementTypePreservingConst(i);
146 builder.create<SubfieldOp>(fieldType, wireOp.getResult(), i);
147 builder.create<StrictConnectOp>(acc, zero);
149 return wireOp.getResult();
151 .Case<FVectorType>([&](
auto type) {
152 auto wireOp =
builder.create<WireOp>(type);
154 builder, type.getElementTypePreservingConst(), cache);
155 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
156 auto acc =
builder.create<SubindexOp>(zero.getType(),
157 wireOp.getResult(), i);
158 builder.create<StrictConnectOp>(acc, zero);
160 return wireOp.getResult();
162 .Case<ResetType, AnalogType>(
163 [&](
auto type) {
return builder.create<InvalidValueOp>(type); })
165 llvm_unreachable(
"switch handles all types");
168 cache.insert({type, value});
185 Value reset, Value resetValue) {
189 bool resetValueUsed =
false;
191 for (
auto &use : target.getUses()) {
192 Operation *useOp = use.getOwner();
193 builder.setInsertionPoint(useOp);
194 TypeSwitch<Operation *>(useOp)
197 .Case<ConnectOp, StrictConnectOp>([&](
auto op) {
198 if (op.getDest() != target)
200 LLVM_DEBUG(
llvm::dbgs() <<
" - Insert mux into " << op <<
"\n");
202 builder.create<MuxPrimOp>(reset, resetValue, op.getSrc());
203 op.getSrcMutable().assign(muxOp);
204 resetValueUsed =
true;
207 .Case<SubfieldOp>([&](
auto op) {
209 builder.create<SubfieldOp>(resetValue, op.getFieldIndexAttr());
211 resetValueUsed =
true;
213 resetSubValue.erase();
216 .Case<SubindexOp>([&](
auto op) {
218 builder.create<SubindexOp>(resetValue, op.getIndexAttr());
220 resetValueUsed =
true;
222 resetSubValue.erase();
225 .Case<SubaccessOp>([&](
auto op) {
226 if (op.getInput() != target)
229 builder.create<SubaccessOp>(resetValue, op.getIndex());
231 resetValueUsed =
true;
233 resetSubValue.erase();
236 return resetValueUsed;
251 bool operator<(
const ResetSignal &other)
const {
return field < other.field; }
252 bool operator==(
const ResetSignal &other)
const {
253 return field == other.field;
255 bool operator!=(
const ResetSignal &other)
const {
return !(*
this == other); }
275 using ResetDrives = SmallVector<ResetDrive, 1>;
278 using ResetNetwork = llvm::iterator_range<
279 llvm::EquivalenceClasses<ResetSignal>::member_iterator>;
282 enum class ResetKind { Async, Sync };
288 struct DenseMapInfo<ResetSignal> {
290 return ResetSignal{DenseMapInfo<FieldRef>::getEmptyKey(), {}};
293 return ResetSignal{DenseMapInfo<FieldRef>::getTombstoneKey(), {}};
298 static bool isEqual(
const ResetSignal &lhs,
const ResetSignal &rhs) {
304 template <
typename T>
307 case ResetKind::Async:
308 return os <<
"async";
309 case ResetKind::Sync:
421 struct InferResetsPass :
public InferResetsBase<InferResetsPass> {
422 void runOnOperation()
override;
423 void runOnOperationInner();
426 using InferResetsBase::InferResetsBase;
427 InferResetsPass(
const InferResetsPass &other) : InferResetsBase(other) {}
432 void traceResets(CircuitOp circuit);
433 void traceResets(InstanceOp inst);
434 void traceResets(Value dst, Value src, Location loc);
435 void traceResets(Value value);
436 void traceResets(Type dstType, Value dst,
unsigned dstID, Type srcType,
437 Value src,
unsigned srcID, Location loc);
439 LogicalResult inferAndUpdateResets();
441 LogicalResult updateReset(ResetNetwork net, ResetKind kind);
447 LogicalResult collectAnnos(CircuitOp circuit);
455 LogicalResult buildDomains(CircuitOp circuit);
458 unsigned indent = 0);
460 void determineImpl();
461 void determineImpl(FModuleOp module, ResetDomain &domain);
463 LogicalResult implementAsyncReset();
464 LogicalResult implementAsyncReset(FModuleOp module, ResetDomain &domain);
465 void implementAsyncReset(Operation *op, FModuleOp module, Value actualReset);
467 LogicalResult verifyNoAbstractReset();
473 ResetNetwork getResetNetwork(ResetSignal signal) {
474 return llvm::make_range(resetClasses.findLeader(signal),
475 resetClasses.member_end());
479 ResetDrives &getResetDrives(ResetNetwork net) {
480 return resetDrives[*net.begin()];
485 ResetSignal guessRoot(ResetNetwork net);
486 ResetSignal guessRoot(ResetSignal signal) {
487 return guessRoot(getResetNetwork(signal));
494 llvm::EquivalenceClasses<ResetSignal> resetClasses;
497 DenseMap<ResetSignal, ResetDrives> resetDrives;
502 DenseMap<Operation *, Value> annotatedResets;
506 MapVector<FModuleOp, SmallVector<std::pair<ResetDomain, InstancePathVec>, 1>>
514 void InferResetsPass::runOnOperation() {
515 runOnOperationInner();
516 resetClasses = llvm::EquivalenceClasses<ResetSignal>();
518 annotatedResets.clear();
520 markAnalysesPreserved<InstanceGraph>();
523 void InferResetsPass::runOnOperationInner() {
524 instanceGraph = &getAnalysis<InstanceGraph>();
527 traceResets(getOperation());
530 if (failed(inferAndUpdateResets()))
531 return signalPassFailure();
534 if (failed(collectAnnos(getOperation())))
535 return signalPassFailure();
538 if (failed(buildDomains(getOperation())))
539 return signalPassFailure();
545 if (failed(implementAsyncReset()))
546 return signalPassFailure();
549 if (failed(verifyNoAbstractReset()))
550 return signalPassFailure();
554 return std::make_unique<InferResetsPass>();
557 ResetSignal InferResetsPass::guessRoot(ResetNetwork net) {
558 ResetDrives &drives = getResetDrives(net);
559 ResetSignal bestSignal = *net.begin();
560 unsigned bestNumDrives = -1;
562 for (
auto signal : net) {
564 if (isa_and_nonnull<InvalidValueOp>(
565 signal.field.getValue().getDefiningOp()))
570 unsigned numDrives = 0;
571 for (
auto &drive : drives)
572 if (drive.dst == signal)
578 if (numDrives < bestNumDrives) {
579 bestNumDrives = numDrives;
598 .
Case<BundleType>([](
auto type) {
600 for (
auto e : type.getElements())
605 [](
auto type) {
return getMaxFieldID(type.getElementType()) + 1; })
606 .Default([](
auto) {
return 0; });
609 static unsigned getFieldID(BundleType type,
unsigned index) {
610 assert(index < type.getNumElements());
612 for (
unsigned i = 0; i < index; ++i)
620 assert(type.getNumElements() &&
"Bundle must have >0 fields");
622 for (
const auto &e : llvm::enumerate(type.getElements())) {
624 if (fieldID < numSubfields)
626 fieldID -= numSubfields;
628 assert(
false &&
"field id outside bundle");
634 if (oldType.isGround()) {
640 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
648 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
649 if (vectorType.getNumElements() == 0)
666 if (
auto arg = dyn_cast<BlockArgument>(value)) {
667 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
668 string += module.getPortName(arg.getArgNumber());
672 auto *op = value.getDefiningOp();
673 return TypeSwitch<Operation *, bool>(op)
674 .Case<InstanceOp, MemOp>([&](
auto op) {
675 string += op.getName();
678 op.getPortName(cast<OpResult>(value).getResultNumber()).getValue();
681 .Case<WireOp, RegOp, RegResetOp>([&](
auto op) {
682 string += op.getName();
685 .Default([](
auto) {
return false; });
689 SmallString<64> name;
694 auto type = value.getType();
697 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
700 auto &element = bundleType.getElements()[index];
703 string += element.name.getValue();
706 localID = localID -
getFieldID(bundleType, index);
707 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
710 type = vecType.getElementType();
717 llvm_unreachable(
"unsupported type");
729 return TypeSwitch<Type, bool>(type)
731 return type.getRecursiveTypeProperties().hasUninferredReset;
733 .Default([](
auto) {
return false; });
740 void InferResetsPass::traceResets(CircuitOp circuit) {
742 llvm::dbgs() <<
"\n===----- Tracing uninferred resets -----===\n\n");
744 SmallVector<std::pair<FModuleOp, SmallVector<Operation *>>> moduleToOps;
746 for (
auto module : circuit.getOps<FModuleOp>())
747 moduleToOps.push_back({module, {}});
749 mlir::parallelForEach(circuit.getContext(), moduleToOps, [](
auto &e) {
750 e.first.walk([&](Operation *op) {
754 op->getResultTypes(),
755 [](mlir::Type type) { return typeContainsReset(type); }) ||
756 llvm::any_of(op->getOperandTypes(), typeContainsReset))
757 e.second.push_back(op);
761 for (
auto &[_, ops] : moduleToOps)
762 for (
auto *op : ops) {
763 TypeSwitch<Operation *>(op)
764 .Case<FConnectLike>([&](
auto op) {
765 traceResets(op.getDest(), op.getSrc(), op.getLoc());
767 .Case<InstanceOp>([&](
auto op) { traceResets(op); })
768 .Case<RefSendOp>([&](
auto op) {
770 traceResets(op.getType().getType(), op.getResult(), 0,
771 op.getBase().getType().getPassiveType(), op.getBase(),
774 .Case<RefResolveOp>([&](
auto op) {
776 traceResets(op.getType(), op.getResult(), 0,
777 op.getRef().getType().getType(), op.getRef(), 0,
780 .Case<Forceable>([&](Forceable op) {
782 if (op.isForceable())
783 traceResets(op.getDataType(), op.getData(), 0, op.getDataType(),
784 op.getDataRef(), 0, op.getLoc());
786 .Case<UninferredResetCastOp, ConstCastOp, RefCastOp>([&](
auto op) {
787 traceResets(op.getResult(), op.getInput(), op.getLoc());
789 .Case<InvalidValueOp>([&](
auto op) {
798 auto type = op.getType();
801 LLVM_DEBUG(
llvm::dbgs() <<
"Uniquify " << op <<
"\n");
802 ImplicitLocOpBuilder
builder(op->getLoc(), op);
804 llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
810 auto newOp =
builder.create<InvalidValueOp>(type);
815 .Case<SubfieldOp>([&](
auto op) {
818 BundleType bundleType = op.getInput().getType();
819 auto index = op.getFieldIndex();
820 traceResets(op.getType(), op.getResult(), 0,
821 bundleType.getElements()[index].type, op.getInput(),
825 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
838 FVectorType vectorType = op.getInput().getType();
839 traceResets(op.getType(), op.getResult(), 0,
840 vectorType.getElementType(), op.getInput(),
844 .Case<RefSubOp>([&](RefSubOp op) {
846 auto aggType = op.getInput().getType().getType();
847 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(aggType)
848 .Case<FVectorType>([](
auto type) {
851 .Case<BundleType>([&](
auto type) {
854 traceResets(op.getType(), op.getResult(), 0,
855 op.getResult().getType(), op.getInput(), fieldID,
863 void InferResetsPass::traceResets(InstanceOp inst) {
865 auto module = dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(inst));
868 LLVM_DEBUG(
llvm::dbgs() <<
"Visiting instance " << inst.getName() <<
"\n");
871 auto dirs = module.getPortDirections();
872 for (
const auto &it : llvm::enumerate(inst.getResults())) {
873 auto dir = module.getPortDirection(it.index());
874 Value dstPort = module.getArgument(it.index());
875 Value srcPort = it.value();
876 if (dir == Direction::Out)
877 std::swap(dstPort, srcPort);
878 traceResets(dstPort, srcPort, it.value().getLoc());
884 void InferResetsPass::traceResets(Value dst, Value src, Location loc) {
886 traceResets(dst.getType(), dst, 0, src.getType(), src, 0, loc);
891 void InferResetsPass::traceResets(Type dstType, Value dst,
unsigned dstID,
892 Type srcType, Value src,
unsigned srcID,
894 if (
auto dstBundle = type_dyn_cast<BundleType>(dstType)) {
895 auto srcBundle = type_cast<BundleType>(srcType);
896 for (
unsigned dstIdx = 0, e = dstBundle.getNumElements(); dstIdx < e;
898 auto dstField = dstBundle.getElements()[dstIdx].name;
899 auto srcIdx = srcBundle.getElementIndex(dstField);
902 auto &dstElt = dstBundle.getElements()[dstIdx];
903 auto &srcElt = srcBundle.getElements()[*srcIdx];
905 traceResets(srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
906 dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
909 traceResets(dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
910 srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
917 if (
auto dstVector = type_dyn_cast<FVectorType>(dstType)) {
918 auto srcVector = type_cast<FVectorType>(srcType);
919 auto srcElType = srcVector.getElementType();
920 auto dstElType = dstVector.getElementType();
933 traceResets(dstElType, dst, dstID +
getFieldID(dstVector), srcElType, src,
939 if (
auto dstRef = type_dyn_cast<RefType>(dstType)) {
940 auto srcRef = type_cast<RefType>(srcType);
941 return traceResets(dstRef.getType(), dst, dstID, srcRef.getType(), src,
946 auto dstBase = type_dyn_cast<FIRRTLBaseType>(dstType);
947 auto srcBase = type_dyn_cast<FIRRTLBaseType>(srcType);
948 if (!dstBase || !srcBase)
950 if (!type_isa<ResetType>(dstBase) && !type_isa<ResetType>(srcBase))
955 LLVM_DEBUG(
llvm::dbgs() <<
"Visiting driver '" << dstField <<
"' = '"
956 << srcField <<
"' (" << dstType <<
" = " << srcType
962 ResetSignal dstLeader =
963 *resetClasses.findLeader(resetClasses.insert({dstField, dstBase}));
964 ResetSignal srcLeader =
965 *resetClasses.findLeader(resetClasses.insert({srcField, srcBase}));
968 ResetSignal unionLeader = *resetClasses.unionSets(dstLeader, srcLeader);
969 assert(unionLeader == dstLeader || unionLeader == srcLeader);
974 if (dstLeader != srcLeader) {
975 auto &unionDrives = resetDrives[unionLeader];
976 auto mergedDrivesIt =
977 resetDrives.find(unionLeader == dstLeader ? srcLeader : dstLeader);
978 if (mergedDrivesIt != resetDrives.end()) {
979 unionDrives.append(mergedDrivesIt->second);
980 resetDrives.erase(mergedDrivesIt);
986 resetDrives[unionLeader].push_back(
987 {{dstField, dstBase}, {srcField, srcBase}, loc});
994 LogicalResult InferResetsPass::inferAndUpdateResets() {
995 LLVM_DEBUG(
llvm::dbgs() <<
"\n===----- Infer reset types -----===\n\n");
996 for (
auto it = resetClasses.begin(), end = resetClasses.end(); it != end;
1000 ResetNetwork net = llvm::make_range(resetClasses.member_begin(it),
1001 resetClasses.member_end());
1004 auto kind = inferReset(net);
1009 if (failed(updateReset(net, *kind)))
1016 LLVM_DEBUG(
llvm::dbgs() <<
"Inferring reset network with "
1017 << std::distance(net.begin(), net.end())
1021 unsigned asyncDrives = 0;
1022 unsigned syncDrives = 0;
1023 unsigned invalidDrives = 0;
1024 for (ResetSignal signal : net) {
1026 if (type_isa<AsyncResetType>(signal.type))
1028 else if (type_isa<UIntType>(signal.type))
1031 isa_and_nonnull<InvalidValueOp>(
1032 signal.field.getValue().getDefiningOp()))
1035 LLVM_DEBUG(
llvm::dbgs() <<
"- Found " << asyncDrives <<
" async, "
1036 << syncDrives <<
" sync, " << invalidDrives
1037 <<
" invalid drives\n");
1040 if (asyncDrives == 0 && syncDrives == 0 && invalidDrives == 0) {
1041 ResetSignal root = guessRoot(net);
1042 auto diag = mlir::emitError(root.field.getValue().getLoc())
1043 <<
"reset network never driven with concrete type";
1044 for (ResetSignal signal : net)
1045 diag.attachNote(signal.field.getLoc()) <<
"here: ";
1050 if (asyncDrives > 0 && syncDrives > 0) {
1051 ResetSignal root = guessRoot(net);
1052 bool majorityAsync = asyncDrives >= syncDrives;
1053 auto diag = mlir::emitError(root.field.getValue().getLoc())
1055 SmallString<32> fieldName;
1057 diag <<
" \"" << fieldName <<
"\"";
1058 diag <<
" simultaneously connected to async and sync resets";
1059 diag.attachNote(root.field.getValue().getLoc())
1060 <<
"majority of connections to this reset are "
1061 << (majorityAsync ?
"async" :
"sync");
1062 for (
auto &drive : getResetDrives(net)) {
1063 if ((type_isa<AsyncResetType>(drive.dst.type) && !majorityAsync) ||
1064 (type_isa<AsyncResetType>(drive.src.type) && !majorityAsync) ||
1065 (type_isa<UIntType>(drive.dst.type) && majorityAsync) ||
1066 (type_isa<UIntType>(drive.src.type) && majorityAsync))
1067 diag.attachNote(drive.loc)
1068 << (type_isa<AsyncResetType>(drive.src.type) ?
"async" :
"sync")
1077 auto kind = (asyncDrives ? ResetKind::Async : ResetKind::Sync);
1078 LLVM_DEBUG(
llvm::dbgs() <<
"- Inferred as " << kind <<
"\n");
1086 LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) {
1087 LLVM_DEBUG(
llvm::dbgs() <<
"Updating reset network with "
1088 << std::distance(net.begin(), net.end())
1089 <<
" nodes to " << kind <<
"\n");
1093 if (kind == ResetKind::Async)
1101 SmallSetVector<Operation *, 16> worklist;
1102 SmallDenseSet<Operation *> moduleWorklist;
1103 SmallDenseSet<std::pair<Operation *, Operation *>> extmoduleWorklist;
1104 for (
auto signal : net) {
1105 Value value = signal.field.getValue();
1106 if (!isa<BlockArgument>(value) &&
1107 !isa_and_nonnull<WireOp, RegOp, RegResetOp, InstanceOp, InvalidValueOp,
1108 ConstCastOp, RefCastOp, UninferredResetCastOp>(
1109 value.getDefiningOp()))
1111 if (updateReset(signal.field, resetType)) {
1112 for (
auto user : value.getUsers())
1113 worklist.insert(user);
1114 if (
auto blockArg = dyn_cast<BlockArgument>(value))
1115 moduleWorklist.insert(blockArg.getOwner()->getParentOp());
1116 else if (
auto instOp = value.getDefiningOp<InstanceOp>()) {
1117 if (
auto extmodule = dyn_cast<FExtModuleOp>(
1118 *instanceGraph->getReferencedModule(instOp)))
1119 extmoduleWorklist.insert({extmodule, instOp});
1120 }
else if (
auto uncast = value.getDefiningOp<UninferredResetCastOp>()) {
1121 uncast.replaceAllUsesWith(uncast.getInput());
1131 while (!worklist.empty()) {
1132 auto *wop = worklist.pop_back_val();
1133 SmallVector<Type, 2> types;
1134 if (
auto op = dyn_cast<InferTypeOpInterface>(wop)) {
1136 SmallVector<Type, 2> types;
1137 if (failed(op.inferReturnTypes(op->getContext(), op->getLoc(),
1138 op->getOperands(), op->getAttrDictionary(),
1139 op->getPropertiesStorage(),
1140 op->getRegions(), types)))
1145 for (
auto it : llvm::zip(op->getResults(), types)) {
1146 auto newType = std::get<1>(it);
1147 if (std::get<0>(it).getType() == newType)
1149 std::get<0>(it).setType(newType);
1150 for (
auto *user : std::get<0>(it).getUsers())
1151 worklist.insert(user);
1153 LLVM_DEBUG(
llvm::dbgs() <<
"- Inferred " << *op <<
"\n");
1154 }
else if (
auto uop = dyn_cast<UninferredResetCastOp>(wop)) {
1155 for (
auto *user : uop.getResult().getUsers())
1156 worklist.insert(user);
1157 uop.replaceAllUsesWith(uop.getInput());
1158 LLVM_DEBUG(
llvm::dbgs() <<
"- Inferred " << uop <<
"\n");
1164 for (
auto *op : moduleWorklist) {
1165 auto module = dyn_cast<FModuleOp>(op);
1169 SmallVector<Attribute> argTypes;
1170 argTypes.reserve(module.getNumPorts());
1171 for (
auto arg : module.getArguments())
1174 module->setAttr(FModuleLike::getPortTypesAttrName(),
1177 <<
"- Updated type of module '" << module.getName() <<
"'\n");
1181 for (
auto pair : extmoduleWorklist) {
1182 auto module = cast<FExtModuleOp>(pair.first);
1183 auto instOp = cast<InstanceOp>(pair.second);
1185 SmallVector<Attribute> types;
1186 for (
auto type : instOp.getResultTypes())
1189 module->setAttr(FModuleLike::getPortTypesAttrName(),
1192 <<
"- Updated type of extmodule '" << module.getName() <<
"'\n");
1202 if (oldType.isGround()) {
1208 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
1210 SmallVector<BundleType::BundleElement> fields(bundleType.begin(),
1213 fields[index].type, fieldID -
getFieldID(bundleType, index), fieldType);
1218 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
1219 auto newType =
updateType(vectorType.getElementType(),
1220 fieldID -
getFieldID(vectorType), fieldType);
1222 vectorType.isConst());
1225 llvm_unreachable(
"unknown aggregate type");
1232 auto oldType = type_cast<FIRRTLType>(field.
getValue().getType());
1238 if (oldType == newType)
1240 LLVM_DEBUG(
llvm::dbgs() <<
"- Updating '" << field <<
"' from " << oldType
1241 <<
" to " << newType <<
"\n");
1250 LogicalResult InferResetsPass::collectAnnos(CircuitOp circuit) {
1252 llvm::dbgs() <<
"\n===----- Gather async reset annotations -----===\n\n");
1253 SmallVector<std::pair<FModuleOp, std::optional<Value>>> results;
1254 for (
auto module : circuit.getOps<FModuleOp>())
1255 results.push_back({module, {}});
1257 if (failed(mlir::failableParallelForEach(
1258 circuit.getContext(), results, [&](
auto &moduleAndResult) {
1259 auto result = collectAnnos(moduleAndResult.first);
1262 moduleAndResult.second = *result;
1267 for (
auto [module, reset] : results)
1268 if (reset.has_value())
1269 annotatedResets.insert({module, *reset});
1274 InferResetsPass::collectAnnos(FModuleOp module) {
1275 bool anyFailed =
false;
1276 SmallSetVector<std::pair<Annotation, Location>, 4> conflictingAnnos;
1280 bool ignore =
false;
1282 if (!moduleAnnos.empty()) {
1283 moduleAnnos.removeAnnotations([&](
Annotation anno) {
1286 conflictingAnnos.insert({anno, module.getLoc()});
1291 module.emitError(
"'FullAsyncResetAnnotation' cannot target module; "
1292 "must target port or wire/node instead");
1297 moduleAnnos.applyToOperation(module);
1304 AnnotationSet::removePortAnnotations(module, [&](
unsigned argNum,
1306 Value arg = module.getArgument(argNum);
1309 conflictingAnnos.insert({anno, reset.getLoc()});
1314 mlir::emitError(arg.getLoc(),
1315 "'IgnoreFullAsyncResetAnnotation' cannot target port; "
1316 "must target module instead");
1325 module.walk([&](Operation *op) {
1326 AnnotationSet::removeAnnotations(op, [&](
Annotation anno) {
1328 if (!isa<WireOp, NodeOp>(op)) {
1333 "reset annotations must target module, port, or wire/node");
1342 reset = op->getResult(0);
1343 conflictingAnnos.insert({anno, reset.getLoc()});
1349 "'IgnoreFullAsyncResetAnnotation' cannot target wire/node; must "
1350 "target module instead");
1362 if (!ignore && !reset) {
1364 <<
"No reset annotation for " << module.getName() <<
"\n");
1365 return std::optional<Value>();
1369 if (conflictingAnnos.size() > 1) {
1370 auto diag = module.emitError(
"multiple reset annotations on module '")
1371 << module.getName() <<
"'";
1372 for (
auto &annoAndLoc : conflictingAnnos)
1373 diag.attachNote(annoAndLoc.second)
1374 <<
"conflicting " << annoAndLoc.first.getClassAttr() <<
":";
1380 llvm::dbgs() <<
"Annotated reset for " << module.getName() <<
": ";
1383 else if (
auto arg = dyn_cast<BlockArgument>(reset))
1384 llvm::dbgs() <<
"port " << module.getPortName(arg.getArgNumber()) <<
"\n";
1387 << reset.getDefiningOp()->getAttrOfType<StringAttr>(
"name")
1393 return std::optional<Value>(reset);
1405 LogicalResult InferResetsPass::buildDomains(CircuitOp circuit) {
1407 llvm::dbgs() <<
"\n===----- Build async reset domains -----===\n\n");
1410 auto &instGraph = getAnalysis<InstanceGraph>();
1411 auto module = dyn_cast<FModuleOp>(*instGraph.getTopLevelNode()->getModule());
1414 <<
"Skipping circuit because main module is no `firrtl.module`");
1420 bool anyFailed =
false;
1421 for (
auto &it : domains) {
1422 auto module = cast<FModuleOp>(it.first);
1423 auto &domainConflicts = it.second;
1424 if (domainConflicts.size() <= 1)
1428 SmallDenseSet<Value> printedDomainResets;
1429 auto diag = module.emitError(
"module '")
1431 <<
"' instantiated in different reset domains";
1432 for (
auto &it : domainConflicts) {
1433 ResetDomain &domain = it.first;
1435 auto inst = path.back();
1436 auto loc = path.empty() ? module.getLoc() : inst.getLoc();
1437 auto ¬e = diag.attachNote(loc);
1441 note <<
"root instance";
1443 note <<
"instance '";
1445 path, [&](
InstanceLike inst) { note << inst.getInstanceName(); },
1446 [&]() { note <<
"/"; });
1454 note <<
" reset domain rooted at '" << nameAndModule.first.getValue()
1455 <<
"' of module '" << nameAndModule.second.getName() <<
"'";
1458 if (printedDomainResets.insert(domain.reset).second) {
1459 diag.attachNote(domain.reset.getLoc())
1460 <<
"reset domain '" << nameAndModule.first.getValue()
1461 <<
"' of module '" << nameAndModule.second.getName()
1462 <<
"' declared here:";
1465 note <<
" no reset domain";
1468 return failure(anyFailed);
1471 void InferResetsPass::buildDomains(FModuleOp module,
1476 <<
"Visiting " <<
getTail(instPath) <<
" (" << module.getName()
1480 ResetDomain domain(parentReset);
1481 auto it = annotatedResets.find(module);
1482 if (it != annotatedResets.end()) {
1483 domain.isTop =
true;
1484 domain.reset = it->second;
1490 auto &entries = domains[module];
1491 if (llvm::all_of(entries,
1492 [&](
const auto &entry) {
return entry.first != domain; }))
1493 entries.push_back({domain, instPath});
1497 for (
auto *record : *instGraph[module]) {
1498 auto submodule = dyn_cast<FModuleOp>(*record->getTarget()->getModule());
1501 childPath.push_back(cast<InstanceLike>(*record->getInstance()));
1502 buildDomains(submodule, childPath, domain.reset, instGraph, indent + 1);
1503 childPath.pop_back();
1508 void InferResetsPass::determineImpl() {
1510 llvm::dbgs() <<
"\n===----- Determine implementation -----===\n\n");
1511 for (
auto &it : domains) {
1512 auto module = cast<FModuleOp>(it.first);
1513 auto &domain = it.second.back().first;
1514 determineImpl(module, domain);
1534 void InferResetsPass::determineImpl(FModuleOp module, ResetDomain &domain) {
1537 LLVM_DEBUG(
llvm::dbgs() <<
"Planning reset for " << module.getName() <<
"\n");
1542 LLVM_DEBUG(
llvm::dbgs() <<
"- Rooting at local value "
1544 domain.existingValue = domain.reset;
1545 if (
auto blockArg = dyn_cast<BlockArgument>(domain.reset))
1546 domain.existingPort = blockArg.getArgNumber();
1553 auto neededType = domain.reset.getType();
1554 LLVM_DEBUG(
llvm::dbgs() <<
"- Looking for existing port " << neededName
1556 auto portNames = module.getPortNames();
1557 auto ports = llvm::zip(portNames, module.getArguments());
1558 auto portIt = llvm::find_if(
1559 ports, [&](
auto port) {
return std::get<0>(port) == neededName; });
1560 if (portIt != ports.end() && std::get<1>(*portIt).getType() == neededType) {
1562 <<
"- Reusing existing port " << neededName <<
"\n");
1563 domain.existingValue = std::get<1>(*portIt);
1564 domain.existingPort = std::distance(ports.begin(), portIt);
1574 if (portIt != ports.end()) {
1576 <<
"- Existing " << neededName <<
" has incompatible type "
1577 << std::get<1>(*portIt).getType() <<
"\n");
1579 unsigned suffix = 0;
1583 Twine(
"_") + Twine(suffix++));
1584 }
while (llvm::is_contained(portNames, newName));
1586 <<
"- Creating uniquified port " << newName <<
"\n");
1587 domain.newPortName = newName;
1593 LLVM_DEBUG(
llvm::dbgs() <<
"- Creating new port " << neededName <<
"\n");
1594 domain.newPortName = neededName;
1602 LogicalResult InferResetsPass::implementAsyncReset() {
1603 LLVM_DEBUG(
llvm::dbgs() <<
"\n===----- Implement async resets -----===\n\n");
1604 for (
auto &it : domains)
1605 if (failed(implementAsyncReset(cast<FModuleOp>(it.first),
1606 it.second.back().first)))
1616 LogicalResult InferResetsPass::implementAsyncReset(FModuleOp module,
1617 ResetDomain &domain) {
1618 LLVM_DEBUG(
llvm::dbgs() <<
"Implementing async reset for " << module.getName()
1622 if (!domain.reset) {
1624 <<
"- Skipping because module explicitly has no domain\n");
1629 Value actualReset = domain.existingValue;
1630 if (domain.newPortName) {
1631 PortInfo portInfo{domain.newPortName,
1635 domain.reset.getLoc()};
1636 module.insertPorts({{0, portInfo}});
1637 actualReset = module.getArgument(0);
1639 <<
"- Inserted port " << domain.newPortName <<
"\n");
1644 if (
auto blockArg = dyn_cast<BlockArgument>(actualReset))
1645 llvm::dbgs() <<
"port #" << blockArg.getArgNumber() <<
" ";
1653 SmallVector<Operation *> opsToUpdate;
1654 module.walk([&](Operation *op) {
1655 if (isa<InstanceOp, RegOp, RegResetOp>(op))
1656 opsToUpdate.push_back(op);
1663 if (!isa<BlockArgument>(actualReset)) {
1664 mlir::DominanceInfo dom(module);
1669 auto *resetOp = actualReset.getDefiningOp();
1670 if (!opsToUpdate.empty() && !dom.dominates(resetOp, opsToUpdate[0])) {
1672 <<
"- Reset doesn't dominate all uses, needs to be moved\n");
1676 auto nodeOp = dyn_cast<NodeOp>(resetOp);
1677 if (nodeOp && !dom.dominates(nodeOp.getInput(), opsToUpdate[0])) {
1679 <<
"- Promoting node to wire for move: " << nodeOp <<
"\n");
1680 ImplicitLocOpBuilder
builder(nodeOp.getLoc(), nodeOp);
1681 auto wireOp =
builder.create<WireOp>(
1682 nodeOp.getResult().getType(), nodeOp.getNameAttr(),
1683 nodeOp.getNameKindAttr(), nodeOp.getAnnotationsAttr(),
1684 nodeOp.getInnerSymAttr(), nodeOp.getForceableAttr());
1685 builder.create<StrictConnectOp>(wireOp.getResult(), nodeOp.getInput());
1686 nodeOp->replaceAllUsesWith(wireOp);
1689 actualReset = wireOp.getResult();
1690 domain.existingValue = wireOp.getResult();
1695 Block *targetBlock = dom.findNearestCommonDominator(
1696 resetOp->getBlock(), opsToUpdate[0]->getBlock());
1698 if (targetBlock != resetOp->getBlock())
1699 llvm::dbgs() <<
"- Needs to be moved to different block\n";
1708 auto getParentInBlock = [](Operation *op,
Block *block) {
1709 while (op && op->getBlock() != block)
1710 op = op->getParentOp();
1713 auto *resetOpInTarget = getParentInBlock(resetOp, targetBlock);
1714 auto *firstOpInTarget = getParentInBlock(opsToUpdate[0], targetBlock);
1720 if (resetOpInTarget->isBeforeInBlock(firstOpInTarget))
1721 resetOp->moveBefore(resetOpInTarget);
1723 resetOp->moveBefore(firstOpInTarget);
1728 for (
auto *op : opsToUpdate)
1729 implementAsyncReset(op, module, actualReset);
1736 void InferResetsPass::implementAsyncReset(Operation *op, FModuleOp module,
1737 Value actualReset) {
1738 ImplicitLocOpBuilder
builder(op->getLoc(), op);
1741 if (
auto instOp = dyn_cast<InstanceOp>(op)) {
1746 dyn_cast<FModuleOp>(*instanceGraph->getReferencedModule(instOp));
1749 auto domainIt = domains.find(refModule);
1750 if (domainIt == domains.end())
1752 auto &domain = domainIt->second.back().first;
1756 <<
"- Update instance '" << instOp.getName() <<
"'\n");
1760 if (domain.newPortName) {
1761 LLVM_DEBUG(
llvm::dbgs() <<
" - Adding new result as reset\n");
1763 auto newInstOp = instOp.cloneAndInsertPorts(
1765 {domain.newPortName,
1766 type_cast<FIRRTLBaseType>(actualReset.getType()),
1768 instReset = newInstOp.getResult(0);
1771 instOp.replaceAllUsesWith(newInstOp.getResults().drop_front());
1772 instanceGraph->replaceInstance(instOp, newInstOp);
1775 }
else if (domain.existingPort.has_value()) {
1776 auto idx = *domain.existingPort;
1777 instReset = instOp.getResult(idx);
1778 LLVM_DEBUG(
llvm::dbgs() <<
" - Using result #" << idx <<
" as reset\n");
1788 assert(instReset && actualReset);
1789 builder.setInsertionPointAfter(instOp);
1790 builder.create<StrictConnectOp>(instReset, actualReset);
1795 if (
auto regOp = dyn_cast<RegOp>(op)) {
1799 LLVM_DEBUG(
llvm::dbgs() <<
"- Adding async reset to " << regOp <<
"\n");
1801 auto newRegOp =
builder.create<RegResetOp>(
1802 regOp.getResult().getType(), regOp.getClockVal(), actualReset, zero,
1803 regOp.getNameAttr(), regOp.getNameKindAttr(), regOp.getAnnotations(),
1804 regOp.getInnerSymAttr(), regOp.getForceableAttr());
1805 regOp.getResult().replaceAllUsesWith(newRegOp.getResult());
1806 if (regOp.getForceable())
1807 regOp.getRef().replaceAllUsesWith(newRegOp.getRef());
1813 if (
auto regOp = dyn_cast<RegResetOp>(op)) {
1815 if (type_isa<AsyncResetType>(regOp.getResetSignal().getType())) {
1817 <<
"- Skipping (has async reset) " << regOp <<
"\n");
1820 if (failed(regOp.verifyInvariants()))
1821 signalPassFailure();
1824 LLVM_DEBUG(
llvm::dbgs() <<
"- Updating reset of " << regOp <<
"\n");
1826 auto reset = regOp.getResetSignal();
1827 auto value = regOp.getResetValue();
1833 builder.setInsertionPointAfterValue(regOp.getResult());
1834 auto mux =
builder.create<MuxPrimOp>(reset, value, regOp.getResult());
1838 builder.setInsertionPoint(regOp);
1840 regOp.getResetSignalMutable().assign(actualReset);
1841 regOp.getResetValueMutable().assign(zero);
1845 LogicalResult InferResetsPass::verifyNoAbstractReset() {
1846 bool hasAbstractResetPorts =
false;
1847 for (FModuleLike module :
1848 getOperation().getBodyBlock()->getOps<FModuleLike>()) {
1849 for (
PortInfo port : module.getPorts()) {
1850 if (getBaseOfType<ResetType>(port.type)) {
1851 auto diag = emitError(port.loc)
1852 <<
"a port \"" << port.getName()
1853 <<
"\" with abstract reset type was unable to be "
1854 "inferred by InferResets (is this a top-level port?)";
1855 diag.attachNote(module->getLoc())
1856 <<
"the module with this uninferred reset port was defined here";
1857 hasAbstractResetPorts =
true;
1862 if (hasAbstractResetPorts)
assert(baseType &&"element must be base type")
static Value createZeroValue(ImplicitLocOpBuilder &builder, FIRRTLBaseType type, SmallDenseMap< FIRRTLBaseType, Value > &cache)
Construct a zero value of the given type using the given builder.
static unsigned getFieldID(BundleType type, unsigned index)
bool operator!=(const ResetDomain &a, const ResetDomain &b)
static std::pair< StringAttr, FModuleOp > getResetNameAndModule(Value reset)
Return the name and parent module of a reset.
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
::circt::igraph::InstanceOpInterface InstanceLike
An absolute instance path.
static bool isUselessVec(FIRRTLBaseType oldType, unsigned fieldID)
bool operator==(const ResetDomain &a, const ResetDomain &b)
static StringAttr getResetName(Value reset)
Return the name of a reset.
static bool insertResetMux(ImplicitLocOpBuilder &builder, Value target, Value reset, Value resetValue)
Helper function that inserts reset multiplexer into all ConnectOps with the given target.
SmallVector< InstanceLike > InstancePathVec
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static bool typeContainsReset(Type type)
Check whether a type contains a ResetType.
ArrayRef< InstanceLike > InstancePathRef
static StringRef getTail(InstancePathRef path)
static bool getDeclName(Value value, SmallString< 32 > &string)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Value getValue() const
Get the Value which created this location.
This class provides a read-only projection over the MLIR attributes that represent a set of annotatio...
This class provides a read-only projection of an annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeMemToRegAnnoClass
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
constexpr const char * fullAsyncResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
T & operator<<(T &os, FIRVersion version)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
constexpr const char * ignoreFullAsyncResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::unique_ptr< mlir::Pass > createInferResetsPass()
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
inline ::llvm::hash_code hash_value(const FieldRef &fieldRef)
Get a hash code for a FieldRef.
mlir::raw_indented_ostream & dbgs()
This holds the name and type that describes the module's ports.
static ResetSignal getEmptyKey()
static ResetSignal getTombstoneKey()
static bool isEqual(const ResetSignal &lhs, const ResetSignal &rhs)
static unsigned getHashValue(const ResetSignal &x)