23#include "mlir/IR/Dominance.h"
24#include "mlir/IR/ImplicitLocOpBuilder.h"
25#include "mlir/IR/Threading.h"
26#include "mlir/Pass/Pass.h"
27#include "llvm/ADT/EquivalenceClasses.h"
28#include "llvm/ADT/SetVector.h"
29#include "llvm/ADT/TypeSwitch.h"
30#include "llvm/Support/Debug.h"
32#define DEBUG_TYPE "infer-resets"
36#define GEN_PASS_DEF_INFERRESETS
37#include "circt/Dialect/FIRRTL/Passes.h.inc"
41using circt::igraph::InstanceOpInterface;
44using llvm::BumpPtrAllocator;
46using llvm::SmallDenseSet;
47using llvm::SmallSetVector;
49using mlir::InferTypeOpInterface;
52using namespace firrtl;
69 std::optional<unsigned> existingPort;
70 StringAttr newPortName;
72 ResetDomain(Value reset) : reset(reset) {}
76inline bool operator==(
const ResetDomain &a,
const ResetDomain &b) {
77 return (a.isTop == b.isTop && a.reset == b.reset);
79inline bool operator!=(
const ResetDomain &a,
const ResetDomain &b) {
86 if (
auto arg = dyn_cast<BlockArgument>(reset)) {
87 auto module = cast<FModuleOp>(arg.getParentRegion()->getParentOp());
88 return {
module.getPortNameAttr(arg.getArgNumber()), module};
90 auto op = reset.getDefiningOp();
91 return {op->getAttrOfType<StringAttr>(
"name"),
92 op->getParentOfType<FModuleOp>()};
107 auto it = cache.find(type);
108 if (it != cache.end())
110 auto nullBit = [&]() {
112 builder, UIntType::get(builder.getContext(), 1,
true),
117 .
Case<ClockType>([&](
auto type) {
118 return builder.create<AsClockPrimOp>(nullBit());
120 .Case<AsyncResetType>([&](
auto type) {
121 return builder.create<AsAsyncResetPrimOp>(nullBit());
123 .Case<SIntType, UIntType>([&](
auto type) {
124 return builder.create<ConstantOp>(
125 type, APInt::getZero(type.getWidth().value_or(1)));
127 .Case<BundleType>([&](
auto type) {
128 auto wireOp = builder.create<WireOp>(type);
129 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
130 auto fieldType = type.getElementTypePreservingConst(i);
133 builder.create<SubfieldOp>(fieldType, wireOp.getResult(), i);
136 return wireOp.getResult();
138 .Case<FVectorType>([&](
auto type) {
139 auto wireOp = builder.create<WireOp>(type);
141 builder, type.getElementTypePreservingConst(), cache);
142 for (
unsigned i = 0, e = type.getNumElements(); i < e; ++i) {
143 auto acc = builder.create<SubindexOp>(zero.getType(),
144 wireOp.getResult(), i);
147 return wireOp.getResult();
149 .Case<ResetType, AnalogType>(
150 [&](
auto type) {
return builder.create<InvalidValueOp>(type); })
152 llvm_unreachable(
"switch handles all types");
155 cache.insert({type, value});
172 Value reset, Value resetValue) {
176 bool resetValueUsed =
false;
178 for (
auto &use : target.getUses()) {
179 Operation *useOp = use.getOwner();
180 builder.setInsertionPoint(useOp);
181 TypeSwitch<Operation *>(useOp)
184 .Case<ConnectOp, MatchingConnectOp>([&](
auto op) {
185 if (op.getDest() != target)
187 LLVM_DEBUG(llvm::dbgs() <<
" - Insert mux into " << op <<
"\n");
189 builder.create<MuxPrimOp>(reset, resetValue, op.getSrc());
190 op.getSrcMutable().assign(muxOp);
191 resetValueUsed =
true;
194 .Case<SubfieldOp>([&](
auto op) {
196 builder.create<SubfieldOp>(resetValue, op.getFieldIndexAttr());
198 resetValueUsed =
true;
200 resetSubValue.erase();
203 .Case<SubindexOp>([&](
auto op) {
205 builder.create<SubindexOp>(resetValue, op.getIndexAttr());
207 resetValueUsed =
true;
209 resetSubValue.erase();
212 .Case<SubaccessOp>([&](
auto op) {
213 if (op.getInput() != target)
216 builder.create<SubaccessOp>(resetValue, op.getIndex());
218 resetValueUsed =
true;
220 resetSubValue.erase();
223 return resetValueUsed;
// Strict ordering between two reset signals. Only the `field` members are
// compared; no other member takes part in the ordering.
238 bool operator<(
const ResetSignal &other)
const {
return field < other.field; }
239 bool operator==(
const ResetSignal &other)
const {
240 return field == other.field;
// Inequality, defined as the logical negation of `*this == other`, so it
// always stays consistent with the equality comparison.
242 bool operator!=(
const ResetSignal &other)
const {
return !(*
this == other); }
// A list of `ResetDrive` entries, stored in a `SmallVector` with an inline
// capacity of one element.
262using ResetDrives = SmallVector<ResetDrive, 1>;
// A range over `ResetSignal` entries, expressed as a pair of
// `llvm::EquivalenceClasses<ResetSignal>` member iterators.
265using ResetNetwork = llvm::iterator_range<
266 llvm::EquivalenceClasses<ResetSignal>::member_iterator>;
// The reset kinds distinguished in this file: `Async` and `Sync`.
269enum class ResetKind { Async, Sync };
271static StringRef resetKindToStringRef(
const ResetKind &kind) {
273 case ResetKind::Async:
275 case ResetKind::Sync:
278 llvm_unreachable(
"unhandled reset kind");
294 static bool isEqual(
const ResetSignal &lhs,
const ResetSignal &rhs) {
303 case ResetKind::Async:
304 return os <<
"async";
305 case ResetKind::Sync:
415struct InferResetsPass
416 :
public circt::firrtl::impl::InferResetsBase<InferResetsPass> {
417 void runOnOperation()
override;
418 void runOnOperationInner();
421 using InferResetsBase::InferResetsBase;
// Copy constructor. Only the `InferResetsBase` base is initialized from
// `other`; the initializer list names nothing else and the body is empty.
422 InferResetsPass(
const InferResetsPass &other) : InferResetsBase(other) {}
427 void traceResets(CircuitOp circuit);
428 void traceResets(InstanceOp inst);
429 void traceResets(Value dst, Value src, Location loc);
430 void traceResets(Value value);
431 void traceResets(Type dstType, Value dst,
unsigned dstID, Type srcType,
432 Value src,
unsigned srcID, Location loc);
434 LogicalResult inferAndUpdateResets();
435 FailureOr<ResetKind> inferReset(ResetNetwork net);
436 LogicalResult updateReset(ResetNetwork net, ResetKind kind);
442 LogicalResult collectAnnos(CircuitOp circuit);
448 FailureOr<std::optional<Value>> collectAnnos(FModuleOp module);
450 LogicalResult buildDomains(CircuitOp circuit);
451 void buildDomains(FModuleOp module,
const InstancePath &instPath,
453 unsigned indent = 0);
455 void determineImpl();
456 void determineImpl(FModuleOp module, ResetDomain &domain);
458 LogicalResult implementFullReset();
459 LogicalResult implementFullReset(FModuleOp module, ResetDomain &domain);
460 void implementFullReset(Operation *op, FModuleOp module, Value actualReset);
462 LogicalResult verifyNoAbstractReset();
468 ResetNetwork getResetNetwork(ResetSignal signal) {
469 return llvm::make_range(resetClasses.findLeader(signal),
470 resetClasses.member_end());
474 ResetDrives &getResetDrives(ResetNetwork net) {
475 return resetDrives[*net.begin()];
480 ResetSignal guessRoot(ResetNetwork net);
481 ResetSignal guessRoot(ResetSignal signal) {
482 return guessRoot(getResetNetwork(signal));
489 llvm::EquivalenceClasses<ResetSignal> resetClasses;
492 DenseMap<ResetSignal, ResetDrives> resetDrives;
497 DenseMap<Operation *, Value> annotatedResets;
501 MapVector<FModuleOp, SmallVector<std::pair<ResetDomain, InstancePath>, 1>>
508 std::unique_ptr<InstancePathCache> instancePathCache;
512void InferResetsPass::runOnOperation() {
513 runOnOperationInner();
514 resetClasses = llvm::EquivalenceClasses<ResetSignal>();
516 annotatedResets.clear();
518 instancePathCache.reset(
nullptr);
519 markAnalysesPreserved<InstanceGraph>();
522void InferResetsPass::runOnOperationInner() {
523 instanceGraph = &getAnalysis<InstanceGraph>();
524 instancePathCache = std::make_unique<InstancePathCache>(*instanceGraph);
527 traceResets(getOperation());
530 if (failed(inferAndUpdateResets()))
531 return signalPassFailure();
534 if (failed(collectAnnos(getOperation())))
535 return signalPassFailure();
538 if (failed(buildDomains(getOperation())))
539 return signalPassFailure();
545 if (failed(implementFullReset()))
546 return signalPassFailure();
549 if (failed(verifyNoAbstractReset()))
550 return signalPassFailure();
554 return std::make_unique<InferResetsPass>();
557ResetSignal InferResetsPass::guessRoot(ResetNetwork net) {
558 ResetDrives &drives = getResetDrives(net);
559 ResetSignal bestSignal = *net.begin();
560 unsigned bestNumDrives = -1;
562 for (
auto signal : net) {
564 if (isa_and_nonnull<InvalidValueOp>(
565 signal.field.getValue().getDefiningOp()))
570 unsigned numDrives = 0;
571 for (
auto &drive : drives)
572 if (drive.dst == signal)
578 if (numDrives < bestNumDrives) {
579 bestNumDrives = numDrives;
598 .
Case<BundleType>([](
auto type) {
600 for (
auto e : type.getElements())
605 [](
auto type) {
return getMaxFieldID(type.getElementType()) + 1; })
606 .Default([](
auto) {
return 0; });
610 assert(index < type.getNumElements());
612 for (
unsigned i = 0; i < index; ++i)
620 assert(type.getNumElements() &&
"Bundle must have >0 fields");
622 for (
const auto &e : llvm::enumerate(type.getElements())) {
624 if (fieldID < numSubfields)
626 fieldID -= numSubfields;
628 assert(
false &&
"field id outside bundle");
634 if (oldType.isGround()) {
640 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
648 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
649 if (vectorType.getNumElements() == 0)
666 if (
auto arg = dyn_cast<BlockArgument>(value)) {
667 auto module = cast<FModuleOp>(arg.getOwner()->getParentOp());
668 string +=
module.getPortName(arg.getArgNumber());
672 auto *op = value.getDefiningOp();
673 return TypeSwitch<Operation *, bool>(op)
674 .Case<InstanceOp, MemOp>([&](
auto op) {
675 string += op.getName();
678 op.getPortName(cast<OpResult>(value).getResultNumber()).getValue();
681 .Case<WireOp, NodeOp, RegOp, RegResetOp>([&](
auto op) {
682 string += op.getName();
685 .Default([](
auto) {
return false; });
689 SmallString<64> name;
694 auto type = value.getType();
697 if (
auto bundleType = type_dyn_cast<BundleType>(type)) {
700 auto &element = bundleType.getElements()[index];
703 string += element.name.getValue();
706 localID = localID -
getFieldID(bundleType, index);
707 }
else if (
auto vecType = type_dyn_cast<FVectorType>(type)) {
710 type = vecType.getElementType();
717 llvm_unreachable(
"unsupported type");
729 return TypeSwitch<Type, bool>(type)
731 return type.getRecursiveTypeProperties().hasUninferredReset;
733 .Default([](
auto) {
return false; });
740void InferResetsPass::traceResets(CircuitOp circuit) {
742 llvm::dbgs() <<
"\n";
743 debugHeader(
"Tracing uninferred resets") <<
"\n\n";
746 SmallVector<std::pair<FModuleOp, SmallVector<Operation *>>> moduleToOps;
748 for (
auto module : circuit.getOps<FModuleOp>())
749 moduleToOps.push_back({module, {}});
752 getAnalysis<hw::InnerSymbolTableCollection>()};
754 mlir::parallelForEach(circuit.getContext(), moduleToOps, [](
auto &e) {
755 e.first.walk([&](Operation *op) {
759 op->getResultTypes(),
760 [](mlir::Type type) { return typeContainsReset(type); }) ||
761 llvm::any_of(op->getOperandTypes(), typeContainsReset))
762 e.second.push_back(op);
766 for (
auto &[_, ops] : moduleToOps)
767 for (auto *op : ops) {
768 TypeSwitch<Operation *>(op)
769 .Case<FConnectLike>([&](
auto op) {
770 traceResets(op.getDest(), op.getSrc(), op.getLoc());
772 .Case<InstanceOp>([&](
auto op) { traceResets(op); })
773 .Case<RefSendOp>([&](
auto op) {
775 traceResets(op.getType().getType(), op.getResult(), 0,
776 op.getBase().getType().getPassiveType(), op.getBase(),
779 .Case<RefResolveOp>([&](
auto op) {
781 traceResets(op.getType(), op.getResult(), 0,
782 op.getRef().getType().getType(), op.getRef(), 0,
785 .Case<Forceable>([&](Forceable op) {
786 if (
auto node = dyn_cast<NodeOp>(op.getOperation()))
787 traceResets(node.getResult(), node.getInput(), node.getLoc());
789 if (op.isForceable())
790 traceResets(op.getDataType(), op.getData(), 0, op.getDataType(),
791 op.getDataRef(), 0, op.getLoc());
793 .Case<RWProbeOp>([&](RWProbeOp op) {
794 auto ist = irn.lookup(op.getTarget());
797 auto baseType = op.getType().getType();
798 traceResets(baseType, op.getResult(), 0, baseType.getPassiveType(),
799 ref.getValue(), ref.getFieldID(), op.getLoc());
801 .Case<UninferredResetCastOp, ConstCastOp, RefCastOp>([&](
auto op) {
802 traceResets(op.getResult(), op.getInput(), op.getLoc());
804 .Case<InvalidValueOp>([&](
auto op) {
813 auto type = op.getType();
816 LLVM_DEBUG(llvm::dbgs() <<
"Uniquify " << op <<
"\n");
817 ImplicitLocOpBuilder builder(op->getLoc(), op);
819 llvm::make_early_inc_range(
llvm::drop_begin(op->getUses()))) {
825 auto newOp = builder.create<InvalidValueOp>(type);
830 .Case<SubfieldOp>([&](
auto op) {
833 BundleType bundleType = op.getInput().getType();
834 auto index = op.getFieldIndex();
835 traceResets(op.getType(), op.getResult(), 0,
836 bundleType.getElements()[index].type, op.getInput(),
840 .Case<SubindexOp, SubaccessOp>([&](
auto op) {
853 FVectorType vectorType = op.getInput().getType();
854 traceResets(op.getType(), op.getResult(), 0,
855 vectorType.getElementType(), op.getInput(),
859 .Case<RefSubOp>([&](RefSubOp op) {
861 auto aggType = op.getInput().getType().getType();
862 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(aggType)
863 .Case<FVectorType>([](
auto type) {
866 .Case<BundleType>([&](
auto type) {
869 traceResets(op.getType(), op.getResult(), 0,
870 op.getResult().getType(), op.getInput(), fieldID,
878void InferResetsPass::traceResets(InstanceOp inst) {
880 auto module = inst.getReferencedModule<FModuleOp>(*instanceGraph);
883 LLVM_DEBUG(llvm::dbgs() <<
"Visiting instance " << inst.getName() <<
"\n");
886 for (
const auto &it :
llvm::enumerate(inst.getResults())) {
887 auto dir =
module.getPortDirection(it.index());
888 Value dstPort =
module.getArgument(it.index());
889 Value srcPort = it.value();
890 if (dir == Direction::Out)
891 std::swap(dstPort, srcPort);
892 traceResets(dstPort, srcPort, it.value().getLoc());
898void InferResetsPass::traceResets(Value dst, Value src, Location loc) {
900 traceResets(dst.getType(), dst, 0, src.getType(), src, 0, loc);
905void InferResetsPass::traceResets(Type dstType, Value dst,
unsigned dstID,
906 Type srcType, Value src,
unsigned srcID,
908 if (
auto dstBundle = type_dyn_cast<BundleType>(dstType)) {
909 auto srcBundle = type_cast<BundleType>(srcType);
910 for (
unsigned dstIdx = 0, e = dstBundle.getNumElements(); dstIdx < e;
912 auto dstField = dstBundle.getElements()[dstIdx].name;
913 auto srcIdx = srcBundle.getElementIndex(dstField);
916 auto &dstElt = dstBundle.getElements()[dstIdx];
917 auto &srcElt = srcBundle.getElements()[*srcIdx];
919 traceResets(srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
920 dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
923 traceResets(dstElt.type, dst, dstID +
getFieldID(dstBundle, dstIdx),
924 srcElt.type, src, srcID +
getFieldID(srcBundle, *srcIdx),
931 if (
auto dstVector = type_dyn_cast<FVectorType>(dstType)) {
932 auto srcVector = type_cast<FVectorType>(srcType);
933 auto srcElType = srcVector.getElementType();
934 auto dstElType = dstVector.getElementType();
947 traceResets(dstElType, dst, dstID +
getFieldID(dstVector), srcElType, src,
953 if (
auto dstRef = type_dyn_cast<RefType>(dstType)) {
954 auto srcRef = type_cast<RefType>(srcType);
955 return traceResets(dstRef.getType(), dst, dstID, srcRef.getType(), src,
960 auto dstBase = type_dyn_cast<FIRRTLBaseType>(dstType);
961 auto srcBase = type_dyn_cast<FIRRTLBaseType>(srcType);
962 if (!dstBase || !srcBase)
964 if (!type_isa<ResetType>(dstBase) && !type_isa<ResetType>(srcBase))
969 LLVM_DEBUG(llvm::dbgs() <<
"Visiting driver '" << dstField <<
"' = '"
970 << srcField <<
"' (" << dstType <<
" = " << srcType
976 ResetSignal dstLeader =
977 *resetClasses.findLeader(resetClasses.insert({dstField, dstBase}));
978 ResetSignal srcLeader =
979 *resetClasses.findLeader(resetClasses.insert({srcField, srcBase}));
982 ResetSignal unionLeader = *resetClasses.unionSets(dstLeader, srcLeader);
983 assert(unionLeader == dstLeader || unionLeader == srcLeader);
988 if (dstLeader != srcLeader) {
989 auto &unionDrives = resetDrives[unionLeader];
990 auto mergedDrivesIt =
991 resetDrives.find(unionLeader == dstLeader ? srcLeader : dstLeader);
992 if (mergedDrivesIt != resetDrives.end()) {
993 unionDrives.append(mergedDrivesIt->second);
994 resetDrives.erase(mergedDrivesIt);
1000 resetDrives[unionLeader].push_back(
1001 {{dstField, dstBase}, {srcField, srcBase}, loc});
1008LogicalResult InferResetsPass::inferAndUpdateResets() {
1010 llvm::dbgs() <<
"\n";
1013 for (
auto it = resetClasses.begin(),
end = resetClasses.end(); it !=
end;
1015 if (!it->isLeader())
1017 ResetNetwork net = llvm::make_range(resetClasses.member_begin(it),
1018 resetClasses.member_end());
1021 auto kind = inferReset(net);
1026 if (failed(updateReset(net, *kind)))
1032FailureOr<ResetKind> InferResetsPass::inferReset(ResetNetwork net) {
1033 LLVM_DEBUG(llvm::dbgs() <<
"Inferring reset network with "
1034 << std::distance(net.begin(), net.end())
1038 unsigned asyncDrives = 0;
1039 unsigned syncDrives = 0;
1040 unsigned invalidDrives = 0;
1041 for (ResetSignal signal : net) {
1043 if (type_isa<AsyncResetType>(signal.type))
1045 else if (type_isa<UIntType>(signal.type))
1048 isa_and_nonnull<InvalidValueOp>(
1049 signal.field.getValue().getDefiningOp()))
1052 LLVM_DEBUG(llvm::dbgs() <<
"- Found " << asyncDrives <<
" async, "
1053 << syncDrives <<
" sync, " << invalidDrives
1054 <<
" invalid drives\n");
1057 if (asyncDrives == 0 && syncDrives == 0 && invalidDrives == 0) {
1058 ResetSignal root = guessRoot(net);
1059 auto diag = mlir::emitError(root.field.getValue().getLoc())
1060 <<
"reset network never driven with concrete type";
1061 for (ResetSignal signal : net)
1062 diag.attachNote(signal.field.getLoc()) <<
"here: ";
1067 if (asyncDrives > 0 && syncDrives > 0) {
1068 ResetSignal root = guessRoot(net);
1069 bool majorityAsync = asyncDrives >= syncDrives;
1070 auto diag = mlir::emitError(root.field.getValue().getLoc())
1072 SmallString<32> fieldName;
1074 diag <<
" \"" << fieldName <<
"\"";
1075 diag <<
" simultaneously connected to async and sync resets";
1076 diag.attachNote(root.field.getValue().getLoc())
1077 <<
"majority of connections to this reset are "
1078 << (majorityAsync ?
"async" :
"sync");
1079 for (
auto &drive : getResetDrives(net)) {
1080 if ((type_isa<AsyncResetType>(drive.dst.type) && !majorityAsync) ||
1081 (type_isa<AsyncResetType>(drive.src.type) && !majorityAsync) ||
1082 (type_isa<UIntType>(drive.dst.type) && majorityAsync) ||
1083 (type_isa<UIntType>(drive.src.type) && majorityAsync))
1084 diag.attachNote(drive.loc)
1085 << (type_isa<AsyncResetType>(drive.src.type) ?
"async" :
"sync")
1094 auto kind = (asyncDrives ? ResetKind::Async : ResetKind::Sync);
1095 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred as " << kind <<
"\n");
1103LogicalResult InferResetsPass::updateReset(ResetNetwork net, ResetKind kind) {
1104 LLVM_DEBUG(llvm::dbgs() <<
"Updating reset network with "
1105 << std::distance(net.begin(), net.end())
1106 <<
" nodes to " << kind <<
"\n");
1110 if (kind == ResetKind::Async)
1111 resetType = AsyncResetType::get(&getContext());
1113 resetType = UIntType::get(&getContext(), 1);
1118 SmallSetVector<Operation *, 16> worklist;
1119 SmallDenseSet<Operation *> moduleWorklist;
1120 SmallDenseSet<std::pair<Operation *, Operation *>> extmoduleWorklist;
1121 for (
auto signal : net) {
1122 Value value = signal.field.getValue();
1123 if (!isa<BlockArgument>(value) &&
1124 !isa_and_nonnull<WireOp, RegOp, RegResetOp, InstanceOp, InvalidValueOp,
1125 ConstCastOp, RefCastOp, UninferredResetCastOp,
1126 RWProbeOp>(value.getDefiningOp()))
1128 if (updateReset(signal.field, resetType)) {
1129 for (
auto user : value.getUsers())
1130 worklist.insert(user);
1131 if (
auto blockArg = dyn_cast<BlockArgument>(value))
1132 moduleWorklist.insert(blockArg.getOwner()->getParentOp());
1133 else if (
auto instOp = value.getDefiningOp<InstanceOp>()) {
1134 if (
auto extmodule =
1135 instOp.getReferencedModule<FExtModuleOp>(*instanceGraph))
1136 extmoduleWorklist.insert({extmodule, instOp});
1137 }
else if (
auto uncast = value.getDefiningOp<UninferredResetCastOp>()) {
1138 uncast.replaceAllUsesWith(uncast.getInput());
1148 while (!worklist.empty()) {
1149 auto *wop = worklist.pop_back_val();
1150 SmallVector<Type, 2> types;
1151 if (
auto op = dyn_cast<InferTypeOpInterface>(wop)) {
1153 SmallVector<Type, 2> types;
1154 if (failed(op.inferReturnTypes(op->getContext(), op->getLoc(),
1155 op->getOperands(), op->getAttrDictionary(),
1156 op->getPropertiesStorage(),
1157 op->getRegions(), types)))
1162 for (
auto it :
llvm::zip(op->getResults(), types)) {
1163 auto newType = std::get<1>(it);
1164 if (std::get<0>(it).getType() == newType)
1166 std::get<0>(it).setType(newType);
1167 for (
auto *user : std::
get<0>(it).getUsers())
1168 worklist.insert(user);
1170 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred " << *op <<
"\n");
1171 }
else if (
auto uop = dyn_cast<UninferredResetCastOp>(wop)) {
1172 for (
auto *user : uop.getResult().getUsers())
1173 worklist.insert(user);
1174 uop.replaceAllUsesWith(uop.getInput());
1175 LLVM_DEBUG(llvm::dbgs() <<
"- Inferred " << uop <<
"\n");
1181 for (
auto *op : moduleWorklist) {
1182 auto module = dyn_cast<FModuleOp>(op);
1186 SmallVector<Attribute> argTypes;
1187 argTypes.reserve(module.getNumPorts());
1188 for (
auto arg : module.getArguments())
1189 argTypes.push_back(TypeAttr::
get(arg.getType()));
1191 module.setPortTypesAttr(ArrayAttr::get(op->getContext(), argTypes));
1192 LLVM_DEBUG(llvm::dbgs()
1193 <<
"- Updated type of module '" << module.getName() <<
"'\n");
1197 for (
auto pair : extmoduleWorklist) {
1198 auto module = cast<FExtModuleOp>(pair.first);
1199 auto instOp = cast<InstanceOp>(pair.second);
1201 SmallVector<Attribute> types;
1202 for (
auto type : instOp.getResultTypes())
1203 types.push_back(TypeAttr::
get(type));
1205 module.setPortTypesAttr(ArrayAttr::get(module->getContext(), types));
1206 LLVM_DEBUG(llvm::dbgs()
1207 <<
"- Updated type of extmodule '" << module.getName() <<
"'\n");
1217 if (oldType.isGround()) {
1223 if (
auto bundleType = type_dyn_cast<BundleType>(oldType)) {
1225 SmallVector<BundleType::BundleElement> fields(bundleType.begin(),
1228 fields[index].type, fieldID -
getFieldID(bundleType, index), fieldType);
1229 return BundleType::get(oldType.getContext(), fields, bundleType.
isConst());
1233 if (
auto vectorType = type_dyn_cast<FVectorType>(oldType)) {
1234 auto newType =
updateType(vectorType.getElementType(),
1235 fieldID -
getFieldID(vectorType), fieldType);
1236 return FVectorType::get(newType, vectorType.getNumElements(),
1237 vectorType.isConst());
1240 llvm_unreachable(
"unknown aggregate type");
1247 auto oldType = type_cast<FIRRTLType>(field.
getValue().getType());
1253 if (oldType == newType)
1255 LLVM_DEBUG(llvm::dbgs() <<
"- Updating '" << field <<
"' from " << oldType
1256 <<
" to " << newType <<
"\n");
1265LogicalResult InferResetsPass::collectAnnos(CircuitOp circuit) {
1267 llvm::dbgs() <<
"\n";
1268 debugHeader(
"Gather reset annotations") <<
"\n\n";
1270 SmallVector<std::pair<FModuleOp, std::optional<Value>>> results;
1271 for (
auto module : circuit.getOps<FModuleOp>())
1272 results.push_back({module, {}});
1274 if (failed(mlir::failableParallelForEach(
1275 circuit.getContext(), results, [&](
auto &moduleAndResult) {
1276 auto result = collectAnnos(moduleAndResult.first);
1279 moduleAndResult.second = *result;
1284 for (
auto [module, reset] : results)
1285 if (reset.has_value())
1286 annotatedResets.insert({module, *reset});
1290FailureOr<std::optional<Value>>
1291InferResetsPass::collectAnnos(FModuleOp module) {
1292 bool anyFailed =
false;
1293 SmallSetVector<std::pair<Annotation, Location>, 4> conflictingAnnos;
1297 bool ignore =
false;
1301 conflictingAnnos.insert({anno, module.getLoc()});
1306 module.emitError("''FullResetAnnotation' cannot target module; must
"
1307 "target port or wire/node instead
");
1315 // Consume any reset annotations on module ports.
1317 // Helper for checking annotations and determining the reset
1318 auto checkAnnotations = [&](Annotation anno, Value arg) {
1319 if (anno.isClass(fullResetAnnoClass)) {
1320 ResetKind expectedResetKind;
1321 if (auto rt = anno.getMember<StringAttr>("resetType
")) {
1323 expectedResetKind = ResetKind::Sync;
1324 } else if (rt == "async
") {
1325 expectedResetKind = ResetKind::Async;
1327 mlir::emitError(arg.getLoc(),
1328 "'FullResetAnnotation' requires resetType ==
'sync' "
1329 "|
'async', but got resetType ==
")
1335 mlir::emitError(arg.getLoc(),
1336 "'FullResetAnnotation' requires resetType ==
"
1337 "'sync' |
'async', but got no resetType
");
1341 // Check that the type is well-formed
1342 bool isAsync = expectedResetKind == ResetKind::Async;
1343 bool validUint = false;
1344 if (auto uintT = dyn_cast<UIntType>(arg.getType()))
1345 validUint = uintT.getWidth() == 1;
1346 if ((isAsync && !isa<AsyncResetType>(arg.getType())) ||
1347 (!isAsync && !validUint)) {
1348 auto kind = resetKindToStringRef(expectedResetKind);
1349 mlir::emitError(arg.getLoc(),
1350 "'FullResetAnnotation' with resetType ==
'")
1351 << kind << "' must target
" << kind << " reset, but targets
"
1358 conflictingAnnos.insert({anno, reset.getLoc()});
1362 if (anno.isClass(excludeFromFullResetAnnoClass)) {
1364 mlir::emitError(arg.getLoc(),
1365 "'ExcludeFromFullResetAnnotation' cannot
"
1366 "target port/wire/node; must target
module instead");
1374 Value arg =
module.getArgument(argNum);
1375 return checkAnnotations(anno, arg);
1381 module.getBody().walk([&](Operation *op) {
1383 if (!isa<WireOp, NodeOp>(op)) {
1384 if (AnnotationSet::hasAnnotation(op, fullResetAnnoClass,
1385 excludeFromFullResetAnnoClass)) {
1388 "reset annotations must target module, port, or wire/node");
1396 auto arg = op->getResult(0);
1397 return checkAnnotations(anno, arg);
1406 if (!ignore && !reset) {
1407 LLVM_DEBUG(llvm::dbgs()
1408 <<
"No reset annotation for " << module.getName() <<
"\n");
1409 return std::optional<Value>();
1413 if (conflictingAnnos.size() > 1) {
1414 auto diag =
module.emitError("multiple reset annotations on module '")
1415 << module.getName() << "'";
1416 for (
auto &annoAndLoc : conflictingAnnos)
1417 diag.attachNote(annoAndLoc.second)
1418 <<
"conflicting " << annoAndLoc.first.getClassAttr() <<
":";
1424 llvm::dbgs() <<
"Annotated reset for " <<
module.getName() << ": ";
1426 llvm::dbgs() <<
"no domain\n";
1427 else if (
auto arg = dyn_cast<BlockArgument>(reset))
1428 llvm::dbgs() <<
"port " <<
module.getPortName(arg.getArgNumber()) << "\n";
1430 llvm::dbgs() <<
"wire "
1431 << reset.getDefiningOp()->getAttrOfType<StringAttr>(
"name")
1437 return std::optional<Value>(reset);
1449LogicalResult InferResetsPass::buildDomains(CircuitOp circuit) {
1451 llvm::dbgs() <<
"\n";
1452 debugHeader(
"Build full reset domains") <<
"\n\n";
1456 auto &instGraph = getAnalysis<InstanceGraph>();
1457 auto module = dyn_cast<FModuleOp>(*instGraph.getTopLevelNode()->getModule());
1459 LLVM_DEBUG(llvm::dbgs()
1460 <<
"Skipping circuit because main module is no `firrtl.module`");
1463 buildDomains(module,
InstancePath{}, Value{}, instGraph);
1466 bool anyFailed =
false;
1467 for (
auto &it : domains) {
1468 auto module = cast<FModuleOp>(it.first);
1469 auto &domainConflicts = it.second;
1470 if (domainConflicts.size() <= 1)
1474 SmallDenseSet<Value> printedDomainResets;
1475 auto diag =
module.emitError("module '")
1477 << "' instantiated in different reset domains";
1478 for (
auto &it : domainConflicts) {
1479 ResetDomain &domain = it.first;
1480 const auto &path = it.second;
1481 auto inst = path.leaf();
1482 auto loc = path.empty() ?
module.getLoc() : inst.getLoc();
1483 auto ¬e = diag.attachNote(loc);
1487 note <<
"root instance";
1489 note <<
"instance '";
1492 [&](InstanceOpInterface inst) { note << inst.getInstanceName(); },
1493 [&]() { note <<
"/"; });
1501 note <<
" reset domain rooted at '" << nameAndModule.first.getValue()
1502 <<
"' of module '" << nameAndModule.second.getName() <<
"'";
1505 if (printedDomainResets.insert(domain.reset).second) {
1506 diag.attachNote(domain.reset.getLoc())
1507 <<
"reset domain '" << nameAndModule.first.getValue()
1508 <<
"' of module '" << nameAndModule.second.getName()
1509 <<
"' declared here:";
1512 note <<
" no reset domain";
1515 return failure(anyFailed);
1518void InferResetsPass::buildDomains(FModuleOp module,
1523 llvm::dbgs().indent(indent * 2) <<
"Visiting ";
1524 if (instPath.
empty())
1525 llvm::dbgs() <<
"$root";
1527 llvm::dbgs() << instPath.
leaf().getInstanceName();
1528 llvm::dbgs() <<
" (" <<
module.getName() << ")\n";
1532 ResetDomain domain(parentReset);
1533 auto it = annotatedResets.find(module);
1534 if (it != annotatedResets.end()) {
1535 domain.isTop =
true;
1536 domain.reset = it->second;
1542 auto &entries = domains[module];
1543 if (llvm::all_of(entries,
1544 [&](
const auto &entry) {
return entry.first != domain; }))
1545 entries.push_back({domain, instPath});
1548 for (
auto *record : *instGraph[module]) {
1549 auto submodule = dyn_cast<FModuleOp>(*record->getTarget()->getModule());
1553 instancePathCache->appendInstance(instPath, record->getInstance());
1554 buildDomains(submodule, childPath, domain.reset, instGraph, indent + 1);
1559void InferResetsPass::determineImpl() {
1561 llvm::dbgs() <<
"\n";
1562 debugHeader(
"Determine implementation") <<
"\n\n";
1564 for (
auto &it : domains) {
1565 auto module = cast<FModuleOp>(it.first);
1566 auto &domain = it.second.back().first;
1567 determineImpl(module, domain);
1587void InferResetsPass::determineImpl(FModuleOp module, ResetDomain &domain) {
1590 LLVM_DEBUG(llvm::dbgs() <<
"Planning reset for " << module.getName() <<
"\n");
1595 LLVM_DEBUG(llvm::dbgs() <<
"- Rooting at local value "
1597 domain.existingValue = domain.reset;
1598 if (
auto blockArg = dyn_cast<BlockArgument>(domain.reset))
1599 domain.existingPort = blockArg.getArgNumber();
1606 auto neededType = domain.reset.getType();
1607 LLVM_DEBUG(llvm::dbgs() <<
"- Looking for existing port " << neededName
1609 auto portNames =
module.getPortNames();
1610 auto ports = llvm::zip(portNames, module.getArguments());
1611 auto portIt = llvm::find_if(
1612 ports, [&](
auto port) {
return std::get<0>(port) == neededName; });
1613 if (portIt != ports.end() && std::get<1>(*portIt).getType() == neededType) {
1614 LLVM_DEBUG(llvm::dbgs()
1615 <<
"- Reusing existing port " << neededName <<
"\n");
1616 domain.existingValue = std::get<1>(*portIt);
1617 domain.existingPort = std::distance(ports.begin(), portIt);
1627 if (portIt != ports.end()) {
1628 LLVM_DEBUG(llvm::dbgs()
1629 <<
"- Existing " << neededName <<
" has incompatible type "
1630 << std::get<1>(*portIt).getType() <<
"\n");
1632 unsigned suffix = 0;
1635 StringAttr::get(&getContext(), Twine(neededName.getValue()) +
1636 Twine(
"_") + Twine(suffix++));
1637 }
while (llvm::is_contained(portNames, newName));
1638 LLVM_DEBUG(llvm::dbgs()
1639 <<
"- Creating uniquified port " << newName <<
"\n");
1640 domain.newPortName = newName;
1646 LLVM_DEBUG(llvm::dbgs() <<
"- Creating new port " << neededName <<
"\n");
1647 domain.newPortName = neededName;
1655LogicalResult InferResetsPass::implementFullReset() {
1657 llvm::dbgs() <<
"\n";
1660 for (
auto &it : domains)
1661 if (failed(implementFullReset(cast<FModuleOp>(it.first),
1662 it.second.back().first)))
/// Implement the full reset for a single module.
///
/// \param module  The module to update.
/// \param domain  The reset domain chosen for `module`. It is taken by
///                non-const reference: `existingValue` is overwritten when a
///                node is promoted to a wire below.
///
/// Visible steps: skip modules whose domain has no reset, annotate the module,
/// optionally insert a new reset port at index 0, collect all instances and
/// registers, make sure a non-port reset dominates the first collected op
/// (moving it if needed), then update every collected op.
///
/// NOTE(review): this listing omits some lines of the definition (the embedded
/// line numbers skip), including the return statements.
1672LogicalResult InferResetsPass::implementFullReset(FModuleOp module,
1673 ResetDomain &domain) {
1674 LLVM_DEBUG(llvm::dbgs() <<
"Implementing full reset for " << module.getName()
// A null `domain.reset` is reported as the module explicitly having no domain.
1678 if (!domain.reset) {
1679 LLVM_DEBUG(llvm::dbgs()
1680 <<
"- Skipping because module explicitly has no domain\n");
// Add an annotation dictionary with a "class" entry and apply the annotation
// set to the module. The class value is on a line missing from this listing.
1685 auto *context =
module.getContext();
1687 annotations.addAnnotations(DictionaryAttr::get(
1688 context, NamedAttribute(StringAttr::get(context,
"class"),
1690 annotations.applyToOperation(module);
// Start from the value already recorded on the domain; a newly inserted port
// replaces it below.
1693 Value actualReset = domain.existingValue;
// When a new port name was chosen, insert that port at index 0, typed and
// located like `domain.reset`, and use module argument 0 as the reset.
1694 if (domain.newPortName) {
1695 PortInfo portInfo{domain.newPortName,
1696 domain.reset.getType(),
1699 domain.reset.getLoc()};
1700 module.insertPorts({{0, portInfo}});
1701 actualReset =
module.getArgument(0);
1702 LLVM_DEBUG(llvm::dbgs()
1703 <<
"- Inserted port " << domain.newPortName <<
"\n");
// Debug output: say whether the chosen reset is a port (block argument) or a
// wire/node.
1707 llvm::dbgs() <<
"- Using ";
1708 if (
auto blockArg = dyn_cast<BlockArgument>(actualReset))
1709 llvm::dbgs() <<
"port #" << blockArg.getArgNumber() <<
" ";
1711 llvm::dbgs() <<
"wire/node ";
// Collect every instance and register (with or without reset) in the module.
// They are gathered first and updated in a separate loop at the end.
1717 SmallVector<Operation *> opsToUpdate;
1718 module.walk([&](Operation *op) {
1719 if (isa<InstanceOp, RegOp, RegResetOp>(op))
1720 opsToUpdate.push_back(op);
// A reset that is not a block argument is produced by an operation, and that
// operation must dominate the first collected op.
1727 if (!isa<BlockArgument>(actualReset)) {
1728 mlir::DominanceInfo dom(module);
1733 auto *resetOp = actualReset.getDefiningOp();
1734 if (!opsToUpdate.empty() && !dom.dominates(resetOp, opsToUpdate[0])) {
1735 LLVM_DEBUG(llvm::dbgs()
1736 <<
"- Reset doesn't dominate all uses, needs to be moved\n");
// If the reset is a node whose input also fails to dominate the first
// collected op, the node cannot simply be moved. Instead a wire is created at
// the start of the node's block with the node's type, name, name kind,
// annotations, inner symbol and forceable attribute.
1740 auto nodeOp = dyn_cast<NodeOp>(resetOp);
1741 if (nodeOp && !dom.dominates(nodeOp.getInput(), opsToUpdate[0])) {
1742 LLVM_DEBUG(llvm::dbgs()
1743 <<
"- Promoting node to wire for move: " << nodeOp <<
"\n");
1744 auto builder = ImplicitLocOpBuilder::atBlockBegin(nodeOp.getLoc(),
1745 nodeOp->getBlock());
1746 auto wireOp = builder.create<WireOp>(
1747 nodeOp.getResult().getType(), nodeOp.getNameAttr(),
1748 nodeOp.getNameKindAttr(), nodeOp.getAnnotationsAttr(),
1749 nodeOp.getInnerSymAttr(), nodeOp.getForceableAttr());
// Redirect all users of the node to the wire, and strip from the node the
// inner symbol and annotations that were copied onto the wire.
1751 nodeOp->replaceAllUsesWith(wireOp);
1752 nodeOp->removeAttr(nodeOp.getInnerSymAttrName());
1756 nodeOp.setNameKind(NameKindEnum::DroppableName);
1757 nodeOp.setAnnotationsAttr(ArrayAttr::get(builder.getContext(), {}));
// Drive the wire from the node, with the connect placed right after the node.
1758 builder.setInsertionPointAfter(nodeOp);
1759 emitConnect(builder, wireOp.getResult(), nodeOp.getResult());
// From here on the wire is the reset, both locally and on the domain.
1761 actualReset = wireOp.getResult();
1762 domain.existingValue = wireOp.getResult();
// Find the nearest common dominator of the reset's block and the block of the
// first collected op.
1767 Block *targetBlock = dom.findNearestCommonDominator(
1768 resetOp->getBlock(), opsToUpdate[0]->getBlock());
1770 if (targetBlock != resetOp->getBlock())
1771 llvm::dbgs() <<
"- Needs to be moved to different block\n";
// Helper: walk up through parent operations until reaching one whose block is
// `block`, or until `op` becomes null.
1780 auto getParentInBlock = [](Operation *op,
Block *block) {
1781 while (op && op->getBlock() != block)
1782 op = op->getParentOp();
1785 auto *resetOpInTarget = getParentInBlock(resetOp, targetBlock);
1786 auto *firstOpInTarget = getParentInBlock(opsToUpdate[0], targetBlock);
// Move the reset op in front of whichever of the two ancestors is relevant.
// NOTE(review): the keyword joining the two moves (line 1794) is missing from
// this listing; presumably an `else` -- confirm against the original file.
1792 if (resetOpInTarget->isBeforeInBlock(firstOpInTarget))
1793 resetOp->moveBefore(resetOpInTarget);
1795 resetOp->moveBefore(firstOpInTarget);
// Update each collected instance/register to use the chosen reset.
1800 for (
auto *op : opsToUpdate)
1801 implementFullReset(op, module, actualReset);
/// Update a single operation to use the given reset.
///
/// \param op           The operation to update. The visible code handles
///                     `InstanceOp`, `RegOp` and `RegResetOp`.
/// \param module       The module passed along by the caller.
/// \param actualReset  The reset value to use for `op`.
///
/// NOTE(review): this listing omits some lines of the definition (the embedded
/// line numbers skip). In particular the declarations of `instReset` and
/// `zero`, the early returns, and whatever follows the insertion-point changes
/// are not visible here.
1808void InferResetsPass::implementFullReset(Operation *op, FModuleOp module,
1809 Value actualReset) {
// New ops are inserted before `op` unless the insertion point is changed below.
1810 ImplicitLocOpBuilder builder(op->getLoc(), op);
// Instances: look up the domain recorded for the referenced module. As in the
// other overloads, only the last element stored for that module is used.
1813 if (
auto instOp = dyn_cast<InstanceOp>(op)) {
1817 auto refModule = instOp.getReferencedModule<FModuleOp>(*instanceGraph);
1820 auto domainIt = domains.find(refModule);
1821 if (domainIt == domains.end())
1823 auto &domain = domainIt->second.back().first;
1826 LLVM_DEBUG(llvm::dbgs()
1827 <<
"- Update instance '" << instOp.getName() <<
"'\n");
// The referenced module has a new reset port: clone the instance with that
// port inserted, typed like `actualReset`. Result 0 of the clone is the
// instance-side reset; the old results map onto the clone's results with the
// first one dropped. The instance graph is pointed at the clone.
1831 if (domain.newPortName) {
1832 LLVM_DEBUG(llvm::dbgs() <<
" - Adding new result as reset\n");
1834 auto newInstOp = instOp.cloneAndInsertPorts(
1836 {domain.newPortName,
1837 type_cast<FIRRTLBaseType>(actualReset.getType()),
1839 instReset = newInstOp.getResult(0);
1842 instOp.replaceAllUsesWith(newInstOp.getResults().drop_front());
1843 instanceGraph->replaceInstance(instOp, newInstOp);
1846 }
// The referenced module reuses an existing port: take the instance result at
// that port index as the reset.
else if (domain.existingPort.has_value()) {
1847 auto idx = *domain.existingPort;
1848 instReset = instOp.getResult(idx);
1849 LLVM_DEBUG(llvm::dbgs() <<
" - Using result #" << idx <<
" as reset\n");
// Both the instance-side and the module-side reset must be set at this point.
1859 assert(instReset && actualReset);
1860 builder.setInsertionPointAfter(instOp);
// Registers without reset: create a `RegResetOp` with the same result type,
// clock, name, name kind, annotations, inner symbol and forceable attribute,
// using `actualReset` as reset signal and `zero` as reset value. Then redirect
// users of the old result and, if the register is forceable, of its ref.
1866 if (
auto regOp = dyn_cast<RegOp>(op)) {
1870 LLVM_DEBUG(llvm::dbgs() <<
"- Adding full reset to " << regOp <<
"\n");
1872 auto newRegOp = builder.create<RegResetOp>(
1873 regOp.getResult().getType(), regOp.getClockVal(), actualReset, zero,
1874 regOp.getNameAttr(), regOp.getNameKindAttr(), regOp.getAnnotations(),
1875 regOp.getInnerSymAttr(), regOp.getForceableAttr());
1876 regOp.getResult().replaceAllUsesWith(newRegOp.getResult());
1877 if (regOp.getForceable())
1878 regOp.getRef().replaceAllUsesWith(newRegOp.getRef());
// Registers that already have a reset.
1884 if (
auto regOp = dyn_cast<RegResetOp>(op)) {
// Left alone when the register's own reset signal has `AsyncResetType` or when
// `actualReset` has `UIntType`. The register's invariants are still verified
// and a failure marks the pass as failed.
1887 if (type_isa<AsyncResetType>(regOp.getResetSignal().getType()) ||
1888 type_isa<UIntType>(actualReset.getType())) {
1889 LLVM_DEBUG(llvm::dbgs() <<
"- Skipping (has reset) " << regOp <<
"\n");
1892 if (failed(regOp.verifyInvariants()))
1893 signalPassFailure();
1896 LLVM_DEBUG(llvm::dbgs() <<
"- Updating reset of " << regOp <<
"\n");
// Remember the register's current reset signal and value before they are
// overwritten below.
1898 auto reset = regOp.getResetSignal();
1899 auto value = regOp.getResetValue();
// Build a mux right after the register: it selects the old reset value when
// the old reset signal is set, and the register's own result otherwise.
1905 builder.setInsertionPointAfterValue(regOp.getResult());
1906 auto mux = builder.create<MuxPrimOp>(reset, value, regOp.getResult());
1910 builder.setInsertionPoint(regOp);
// Replace the register's reset signal and reset value with `actualReset` and
// `zero`.
1912 regOp.getResetSignalMutable().assign(actualReset);
1913 regOp.getResetValueMutable().assign(zero);
/// Check the ports of every `FModuleLike` op in the body block of the pass's
/// root operation for remaining abstract reset types.
///
/// For each port where `getBaseOfType<ResetType>` finds a match, an error is
/// emitted at the port's location, with a note pointing at the module, and
/// `hasAbstractResetPorts` is set.
///
/// NOTE(review): the listing ends right after the final `if`, so the returned
/// values are not visible here.
1917LogicalResult InferResetsPass::verifyNoAbstractReset() {
1918 bool hasAbstractResetPorts =
false;
// The loop does not stop at the first offender, so every such port gets its
// own diagnostic.
1919 for (FModuleLike module :
1920 getOperation().
getBodyBlock()->getOps<FModuleLike>()) {
1921 for (
PortInfo port : module.getPorts()) {
1922 if (getBaseOfType<ResetType>(port.type)) {
1923 auto diag = emitError(port.loc)
1924 <<
"a port \"" << port.getName()
1925 <<
"\" with abstract reset type was unable to be "
1926 "inferred by InferResets (is this a top-level port?)";
1927 diag.attachNote(module->getLoc())
1928 <<
"the module with this uninferred reset port was defined here";
1929 hasAbstractResetPorts =
true;
1934 if (hasAbstractResetPorts)
assert(baseType &&"element must be base type")
static Value createZeroValue(ImplicitLocOpBuilder &builder, FIRRTLBaseType type, SmallDenseMap< FIRRTLBaseType, Value > &cache)
Construct a zero value of the given type using the given builder.
static unsigned getFieldID(BundleType type, unsigned index)
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
static bool isUselessVec(FIRRTLBaseType oldType, unsigned fieldID)
static StringAttr getResetName(Value reset)
Return the name of a reset.
static bool insertResetMux(ImplicitLocOpBuilder &builder, Value target, Value reset, Value resetValue)
Helper function that inserts reset multiplexer into all ConnectOps with the given target.
static bool typeContainsReset(Type type)
Check whether a type contains a ResetType.
static bool getDeclName(Value value, SmallString< 32 > &string)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static std::pair< StringAttr, FModuleOp > getResetNameAndModule(Value reset)
Return the name and parent module of a reset.
static InstancePath empty
static Block * getBodyBlock(FModuleLike mod)
This class represents a reference to a specific field or element of an aggregate value.
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundle.
Value getValue() const
Get the Value which created this location.
This class provides a read-only projection over the MLIR attributes that represent a set of annotations.
bool removeAnnotations(llvm::function_ref< bool(Annotation)> predicate)
Remove all annotations from this annotation set for which predicate returns true.
static bool removePortAnnotations(Operation *module, llvm::function_ref< bool(unsigned, Annotation)> predicate)
Remove all port annotations from a module or extmodule for which predicate returns true.
This class provides a read-only projection of an annotation.
bool isClass(Args... names) const
Return true if this annotation matches any of the specified class names.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast for dynamic cast.
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
This graph tracks modules and where they are instantiated.
An instance path composed of a series of instances.
InstanceOpInterface leaf() const
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
constexpr const char * excludeFromFullResetAnnoClass
Annotation that marks a module as not belonging to any reset domain.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
constexpr const char * excludeMemToRegAnnoClass
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
FIRRTLType mapBaseType(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
constexpr const char * fullResetAnnoClass
Annotation that marks a reset (port or wire) and domain.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void emitConnect(OpBuilder &builder, Location loc, Value lhs, Value rhs)
Emit a connect between two values.
std::unique_ptr< mlir::Pass > createInferResetsPass()
static bool operator==(const ModulePort &a, const ModulePort &b)
static llvm::hash_code hash_value(const ModulePort &port)
bool operator<(const DictEntry &entry, const DictEntry &other)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
bool operator!=(uint64_t a, const FVInt &b)
This holds the name and type that describes the module's ports.
This class represents the namespace in which InnerRef's can be resolved.
A data structure that caches and provides absolute paths to module instances in the IR.
static ResetSignal getEmptyKey()
static ResetSignal getTombstoneKey()
static bool isEqual(const ResetSignal &lhs, const ResetSignal &rhs)
static unsigned getHashValue(const ResetSignal &x)