26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Iterators.h"
28#include "mlir/IR/Threading.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TinyPtrVector.h"
34#define DEBUG_TYPE "firrtl-infer-domains"
38#define GEN_PASS_DEF_INFERDOMAINS
39#include "circt/Dialect/FIRRTL/Passes.h.inc"
44using namespace firrtl;
50using mlir::ReverseIterator;
68 return info.getAsRange<IntegerAttr>();
69 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
73static bool isPort(BlockArgument arg) {
74 return isa<FModuleOp>(arg.getOwner()->getParentOp());
79 auto arg = dyn_cast<BlockArgument>(value);
87 for (
auto *user : port.getUsers())
88 if (
auto connect = dyn_cast<FConnectLike>(user))
89 if (connect.getDest() == port)
96 return type_isa<FIRRTLBaseType, RefType>(type);
104 if (
auto *op = value.getDefiningOp())
105 if (op->hasTrait<OpTrait::ConstantLike>())
128struct ModuleUpdateInfo {
130 ArrayAttr portDomainInfo;
138 CircuitState(CircuitOp circuit,
InstanceGraph &instanceGraph,
140 : circuit(circuit), instanceGraph(instanceGraph),
141 innerRefNamespace(innerRefNamespace), mode(mode) {
142 processCircuit(circuit);
147 ArrayRef<DomainOp> getDomains()
const {
return domainTable; }
148 size_t getNumDomains()
const {
return domainTable.size(); }
149 DomainOp getDomain(DomainTypeID
id)
const {
return domainTable[
id.index]; }
150 DomainTypeID getDomainTypeID(Type type) {
return typeIDTable[type]; }
152 void dirty() { asmState =
nullptr; }
153 AsmState &getAsmState() {
155 asmState = std::make_unique<AsmState>(
156 circuit, mlir::OpPrintingFlags().assumeVerified());
161 size_t getVariableID(VariableTerm *term) {
162 return variableIDTable.insert({term, variableIDTable.size() + 1})
166 DenseMap<StringAttr, ModuleUpdateInfo> &getModuleUpdateTable() {
167 return moduleUpdateTable;
173 LogicalResult runOnModule(Operation *module);
175 void processDomain(DomainOp op) {
176 auto index = domainTable.size();
177 auto domainType = DomainType::getFromDomainOp(op);
178 domainTable.push_back(op);
179 typeIDTable.insert({domainType, {index}});
182 void processCircuit(CircuitOp circuit) {
183 for (
auto decl : circuit.getOps<DomainOp>())
191 SmallVector<DomainOp> domainTable;
192 DenseMap<Type, DomainTypeID> typeIDTable;
193 DenseMap<VariableTerm *, size_t> variableIDTable;
194 std::unique_ptr<AsmState> asmState;
195 DenseMap<StringAttr, ModuleUpdateInfo> moduleUpdateTable;
215 constexpr Term(TermKind kind) : kind(kind) {}
223struct TermBase : Term {
224 static bool classof(
const Term *term) {
return term->kind == K; }
225 TermBase() : Term(K) {}
231struct VariableTerm :
public TermBase<TermKind::Variable> {
232 VariableTerm() : leader(nullptr) {}
233 VariableTerm(Term *leader) : leader(leader) {}
240struct ValueTerm :
public TermBase<TermKind::Value> {
248struct RowTerm :
public TermBase<TermKind::Row> {
249 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
250 ArrayRef<Term *> elements;
269struct PendingUpdates {
279using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
284 explicit ModuleState(CircuitState &globals) : globals(globals) {}
286 ArrayRef<DomainOp> getDomains() {
return globals.getDomains(); }
287 size_t getNumDomains() {
return globals.getNumDomains(); }
288 DomainOp getDomain(DomainTypeID
id) {
return globals.getDomain(
id); }
289 DomainTypeID getDomainTypeID(Type type) {
290 return globals.getDomainTypeID(type);
292 DomainTypeID getDomainTypeID(FModuleLike module,
size_t i) {
293 return globals.getDomainTypeID(module.getPortType(i));
295 DomainTypeID getDomainTypeID(FInstanceLike op,
size_t i)
const {
296 return globals.getDomainTypeID(op->getResult(i).getType());
298 DomainTypeID getDomainTypeID(
DomainValue value)
const {
299 return globals.getDomainTypeID(value.getType());
301 auto &getModuleUpdateTable() {
return globals.getModuleUpdateTable(); }
303 mlir::AsmState &getAsmState() {
return globals.getAsmState(); }
304 void dirty() { globals.dirty(); }
306 template <
typename T>
307 void render(Operation *op, T &out);
308 template <
typename T>
309 void render(Value value, T &out);
310 template <
typename T>
311 void render(Term *term, T &out);
312 template <
typename T>
314 template <
typename T>
315 Render<T> render(T &&subject);
318 LogicalResult unify(Term *lhs, Term *rhs);
319 LogicalResult unify(VariableTerm *x, Term *y);
320 LogicalResult unify(ValueTerm *xv, Term *y);
321 LogicalResult unify(RowTerm *lhsRow, Term *rhs);
322 void solve(Term *lhs, Term *rhs);
324 [[nodiscard]] RowTerm *allocRow(
size_t size);
325 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements);
326 [[nodiscard]] VariableTerm *allocVar();
327 [[nodiscard]] ValueTerm *allocVal(
DomainValue value);
328 template <
typename T,
typename... Args>
329 T *alloc(Args &&...args);
330 ArrayRef<Term *> allocArray(ArrayRef<Term *> elements);
335 void setTermForDomain(
DomainValue value, Term *term);
337 Term *getOptDomainAssociation(Value value);
338 Term *getDomainAssociation(Value value);
339 void setDomainAssociation(Value value, Term *term);
342 RowTerm *getDomainAssociationAsRow(Value value);
344 void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op);
345 template <
typename T>
346 void emitDuplicatePortDomainError(T op,
size_t i, DomainTypeID domainTypeID,
347 IntegerAttr domainPortIndexAttr1,
348 IntegerAttr domainPortIndexAttr2);
349 template <
typename T>
350 void emitDomainPortInferenceError(T op,
size_t i);
351 template <
typename T>
352 void emitAmbiguousPortDomainAssociation(
353 T op,
const llvm::TinyPtrVector<DomainValue> &exports,
354 DomainTypeID typeID,
size_t i);
355 template <
typename T>
356 void emitMissingPortDomainAssociationError(T op, DomainTypeID typeID,
359 LogicalResult unifyAssociations(Operation *op, Value lhs, Value rhs);
360 template <
typename T>
361 LogicalResult unifyAssociations(Operation *op, T &&range);
362 LogicalResult unifyAssociations(Operation *op);
364 LogicalResult processModulePorts(FModuleOp moduleOp);
365 template <
typename T>
366 LogicalResult processInstancePorts(T op);
367 FInstanceLike fixInstancePorts(FInstanceLike op,
368 const ModuleUpdateInfo &update);
369 LogicalResult processOp(FInstanceLike op);
370 LogicalResult processOp(UnsafeDomainCastOp op);
371 LogicalResult processOp(DomainDefineOp op);
372 LogicalResult processOp(WireOp op);
373 LogicalResult processOp(RWProbeOp op);
374 LogicalResult processOp(Operation *op);
375 LogicalResult processModuleBody(FModuleOp moduleOp);
376 LogicalResult processModule(FModuleOp moduleOp);
378 ExportTable initializeExportTable(FModuleOp moduleOp);
379 void ensureSolved(
Namespace &ns, DomainTypeID typeID,
size_t ip,
380 LocationAttr loc, VariableTerm *var,
381 PendingUpdates &pending);
383 DomainTypeID typeID,
size_t ip, LocationAttr loc,
384 ValueTerm *val, PendingUpdates &pending);
385 void getUpdatesForDomainAssociationOfPort(
Namespace &ns,
386 PendingUpdates &pending,
387 DomainTypeID typeID,
size_t ip,
388 LocationAttr loc, Term *term,
390 void getUpdatesForDomainAssociationOfPort(
Namespace &ns,
392 size_t ip, LocationAttr loc,
394 PendingUpdates &pending);
395 void getUpdatesForModulePorts(FModuleOp moduleOp,
const ExportTable &exports,
397 void getUpdatesForModule(FModuleOp moduleOp,
const ExportTable &exports,
398 PendingUpdates &pending);
399 void applyUpdatesToModule(FModuleOp moduleOp,
ExportTable &exports,
400 const PendingUpdates &pending);
401 SmallVector<Attribute> copyPortDomainAssociations(FModuleOp moduleOp,
402 ArrayAttr moduleDomainInfo,
404 LogicalResult driveModuleOutputDomainPorts(FModuleOp moduleOp);
405 LogicalResult updateModuleDomainInfo(FModuleOp moduleOp,
409 solveVarWithAnonDomain(OpBuilder &builder,
410 DenseMap<DomainValue, DomainValue> &domainsInScope,
411 Operation *user, DomainType type, VariableTerm *var);
413 getDomainInScope(OpBuilder &builder,
414 DenseMap<DomainValue, DomainValue> &domainsInScope,
417 updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
419 LogicalResult updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
421 LogicalResult updateModuleBody(FModuleOp moduleOp);
422 LogicalResult updateModule(FModuleOp moduleOp);
424 LogicalResult checkModulePorts(FModuleLike moduleOp);
425 LogicalResult checkModuleDomainPortDrivers(FModuleOp moduleOp);
426 LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op);
427 LogicalResult checkModuleBody(FModuleOp moduleOp);
429 LogicalResult inferModule(FModuleOp moduleOp);
430 LogicalResult checkModule(FModuleOp moduleOp);
431 LogicalResult checkModule(FExtModuleOp extModuleOp);
432 LogicalResult checkAndInferModule(FModuleOp moduleOp);
435 CircuitState &globals;
436 DenseMap<Value, Term *> termTable;
437 DenseMap<Value, Term *> associationTable;
438 llvm::BumpPtrAllocator allocator;
443void ModuleState::render(Operation *op, T &out) {
444 op->print(out, getAsmState());
448void ModuleState::render(Value value, T &out) {
456 llvm::raw_string_ostream os(name);
457 value.printAsOperand(os, globals.getAsmState());
461 if (
auto type = dyn_cast<DomainType>(value.getType()))
462 out <<
" : " << type.getName().getValue();
467void ModuleState::render(Term *term, T &out) {
473 if (
auto *var = dyn_cast<VariableTerm>(term)) {
474 out <<
"?" << globals.getVariableID(var);
477 if (
auto *val = dyn_cast<ValueTerm>(term)) {
478 auto value = val->value;
482 if (
auto *row = dyn_cast<RowTerm>(term)) {
484 llvm::interleaveComma(llvm::seq(
size_t(0), getNumDomains()), out,
485 [&](
auto i) { render(row->elements[i], out); });
493struct ModuleState::Render {
499ModuleState::Render<T> ModuleState::render(T &&subject) {
500 return Render<T>{
this, std::forward<T>(subject)};
505 ModuleState::Render<T> r) {
506 r.state->render(r.subject, out);
511Term *ModuleState::find(Term *x) {
515 if (
auto *var = dyn_cast<VariableTerm>(x)) {
516 if (var->leader ==
nullptr)
519 auto *leader = find(var->leader);
520 if (leader != var->leader)
521 var->leader = leader;
528LogicalResult ModuleState::unify(VariableTerm *x, Term *y) {
534LogicalResult ModuleState::unify(ValueTerm *xv, Term *y) {
535 if (
auto *yv = dyn_cast<VariableTerm>(y)) {
540 if (
auto *yv = dyn_cast<ValueTerm>(y))
541 return success(xv == yv);
547LogicalResult ModuleState::unify(RowTerm *lhsRow, Term *rhs) {
548 if (
auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
549 rhsVar->leader = lhsRow;
552 if (
auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
553 for (
auto [x, y] :
llvm::zip_equal(lhsRow->elements, rhsRow->elements))
554 if (failed(unify(x, y)))
562LogicalResult ModuleState::unify(Term *lhs, Term *rhs) {
570 LLVM_DEBUG(llvm::dbgs().indent(6)
571 <<
"unify " << render(lhs) <<
" = " << render(rhs) <<
"\n");
573 if (
auto *lhsVar = dyn_cast<VariableTerm>(lhs))
574 return unify(lhsVar, rhs);
575 if (
auto *lhsVal = dyn_cast<ValueTerm>(lhs))
576 return unify(lhsVal, rhs);
577 if (
auto *lhsRow = dyn_cast<RowTerm>(lhs))
578 return unify(lhsRow, rhs);
582void ModuleState::solve(Term *lhs, Term *rhs) {
583 [[maybe_unused]]
auto result = unify(lhs, rhs);
584 assert(result.succeeded());
587RowTerm *ModuleState::allocRow(
size_t size) {
588 SmallVector<Term *> elements;
589 elements.resize(size);
590 return allocRow(elements);
593RowTerm *ModuleState::allocRow(ArrayRef<Term *> elements) {
594 auto ds = allocArray(elements);
595 return alloc<RowTerm>(ds);
598VariableTerm *ModuleState::allocVar() {
return alloc<VariableTerm>(); }
600ValueTerm *ModuleState::allocVal(
DomainValue value) {
601 return alloc<ValueTerm>(value);
604template <
typename T,
typename... Args>
605T *ModuleState::alloc(Args &&...args) {
606 static_assert(std::is_base_of_v<Term, T>,
"T must be a term");
607 return new (allocator) T(std::forward<Args>(args)...);
610ArrayRef<Term *> ModuleState::allocArray(ArrayRef<Term *> elements) {
611 auto size = elements.size();
615 auto *result = allocator.Allocate<Term *>(size);
616 llvm::uninitialized_copy(elements, result);
617 for (
size_t i = 0; i < size; ++i)
619 result[i] = alloc<VariableTerm>();
621 return ArrayRef(result, size);
625 auto *term = getOptTermForDomain(value);
626 if (
auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
631Term *ModuleState::getOptTermForDomain(
DomainValue value) {
632 assert(isa<DomainType>(value.getType()));
633 auto it = termTable.find(value);
634 if (it == termTable.end())
636 return find(it->second);
639Term *ModuleState::getTermForDomain(
DomainValue value) {
640 assert(isa<DomainType>(value.getType()));
641 if (
auto *term = getOptTermForDomain(value))
643 auto *term = allocVar();
644 setTermForDomain(value, term);
648void ModuleState::setTermForDomain(
DomainValue value, Term *term) {
650 assert(!termTable.contains(value));
651 termTable.insert({value, term});
652 LLVM_DEBUG(llvm::dbgs().indent(6)
653 <<
"set " << render(value) <<
" := " << render(term) <<
"\n");
656Term *ModuleState::getOptDomainAssociation(Value value) {
658 auto it = associationTable.find(value);
659 if (it == associationTable.end())
661 return find(it->second);
664Term *ModuleState::getDomainAssociation(Value value) {
665 auto *term = getOptDomainAssociation(value);
670void ModuleState::setDomainAssociation(Value value, Term *term) {
674 associationTable.insert({value, term});
676 llvm::dbgs().indent(6) <<
"set domains(" << render(value)
677 <<
") := " << render(term) <<
"\n";
681void ModuleState::processDomainDefinition(
DomainValue domain) {
682 assert(isa<DomainType>(domain.getType()));
683 auto *newTerm = allocVal(domain);
684 auto *oldTerm = getOptTermForDomain(domain);
686 setTermForDomain(domain, newTerm);
690 [[maybe_unused]]
auto result = unify(oldTerm, newTerm);
691 assert(result.succeeded());
694RowTerm *ModuleState::getDomainAssociationAsRow(Value value) {
696 auto *term = getOptDomainAssociation(value);
700 auto *row = allocRow(getNumDomains());
701 setDomainAssociation(value, row);
706 if (
auto *row = dyn_cast<RowTerm>(term))
710 if (
auto *var = dyn_cast<VariableTerm>(term)) {
711 auto *row = allocRow(getNumDomains());
716 assert(
false &&
"unhandled term type");
720void ModuleState::noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
721 auto ¬e = diag.attachNote(op->getLoc());
722 if (
auto mod = dyn_cast<FModuleOp>(op)) {
723 note <<
"in module " << mod.getModuleNameAttr();
726 if (
auto mod = dyn_cast<FExtModuleOp>(op)) {
727 note <<
"in extmodule " << mod.getModuleNameAttr();
730 if (
auto inst = dyn_cast<InstanceOp>(op)) {
731 note <<
"in instance " << inst.getInstanceNameAttr();
734 if (
auto inst = dyn_cast<InstanceChoiceOp>(op)) {
735 note <<
"in instance_choice " << inst.getNameAttr();
743void ModuleState::emitDuplicatePortDomainError(
744 T op,
size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1,
745 IntegerAttr domainPortIndexAttr2) {
746 auto portName = op.getPortNameAttr(i);
747 auto portLoc = op.getPortLocation(i);
748 auto domainDecl = getDomain(domainTypeID);
749 auto domainName = domainDecl.getNameAttr();
750 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
751 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
752 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
753 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
754 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
755 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
756 auto diag = emitError(portLoc);
757 diag <<
"duplicate " << domainName <<
" association for port " << portName;
758 auto ¬e1 = diag.attachNote(domainPortLoc1);
759 note1 <<
"associated with " << domainName <<
" port " << domainPortName1;
760 auto ¬e2 = diag.attachNote(domainPortLoc2);
761 note2 <<
"associated with " << domainName <<
" port " << domainPortName2;
762 noteLocation(diag, op);
768void ModuleState::emitDomainPortInferenceError(T op,
size_t i) {
769 auto name = op.getPortNameAttr(i);
770 auto diag = emitError(op->getLoc());
771 auto info = op.getDomainInfo();
772 diag <<
"unable to infer value for undriven domain port " << name;
773 for (
size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
774 if (
auto assocs = dyn_cast<ArrayAttr>(info[j])) {
775 for (
auto assoc : assocs) {
776 if (i == cast<IntegerAttr>(assoc).getValue()) {
777 auto name = op.getPortNameAttr(j);
778 auto loc = op.getPortLocation(j);
779 diag.attachNote(loc) <<
"associated with hardware port " << name;
785 noteLocation(diag, op);
789void ModuleState::emitAmbiguousPortDomainAssociation(
790 T op,
const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
792 auto portName = op.getPortNameAttr(i);
793 auto portLoc = op.getPortLocation(i);
794 auto domainDecl = getDomain(typeID);
795 auto domainName = domainDecl.getNameAttr();
796 auto diag = emitError(portLoc) <<
"ambiguous " << domainName
797 <<
" association for port " << portName;
798 for (
auto e : exports) {
799 auto arg = cast<BlockArgument>(e);
800 auto name = op.getPortNameAttr(arg.getArgNumber());
801 auto loc = op.getPortLocation(arg.getArgNumber());
802 diag.attachNote(loc) <<
"candidate association " << name;
804 noteLocation(diag, op);
808void ModuleState::emitMissingPortDomainAssociationError(T op,
811 auto domainName = getDomain(typeID).getNameAttr();
812 auto portName = op.getPortNameAttr(i);
813 auto diag = emitError(op.getPortLocation(i))
814 <<
"missing " << domainName <<
" association for port "
816 noteLocation(diag, op);
819LogicalResult ModuleState::unifyAssociations(Operation *op, Value lhs,
828 llvm::dbgs().indent(6) <<
"unify domains(" << render(lhs) <<
") = domains("
829 << render(rhs) <<
")\n";
832 auto *lhsTerm = getOptDomainAssociation(lhs);
833 auto *rhsTerm = getOptDomainAssociation(rhs);
837 if (failed(unify(lhsTerm, rhsTerm))) {
838 auto diag = op->emitOpError(
"illegal domain crossing in operation");
839 auto ¬e1 = diag.attachNote(lhs.getLoc());
840 note1 <<
"1st operand has domains: ";
841 render(lhsTerm, note1);
842 auto ¬e2 = diag.attachNote(rhs.getLoc());
843 note2 <<
"2nd operand has domains: ";
844 render(rhsTerm, note2);
849 setDomainAssociation(rhs, lhsTerm);
854 setDomainAssociation(lhs, rhsTerm);
858 auto *var = allocVar();
859 setDomainAssociation(lhs, var);
860 setDomainAssociation(rhs, var);
865LogicalResult ModuleState::unifyAssociations(Operation *op, T &&range) {
867 for (
auto rhs : std::forward<T>(range)) {
870 if (failed(unifyAssociations(op, lhs, rhs)))
878LogicalResult ModuleState::unifyAssociations(Operation *op) {
879 return unifyAssociations(
880 op, llvm::concat<Value>(op->getOperands(), op->getResults()));
883LogicalResult ModuleState::processModulePorts(FModuleOp moduleOp) {
884 auto numDomains = getNumDomains();
885 auto domainInfo = moduleOp.getDomainInfoAttr();
886 auto numPorts = moduleOp.getNumPorts();
888 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
889 for (
size_t i = 0; i < numPorts; ++i) {
890 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
894 LLVM_DEBUG(llvm::dbgs().indent(4)
895 <<
"process port " << render(port) <<
"\n");
897 if (moduleOp.getPortDirection(i) == Direction::In)
898 processDomainDefinition(port);
900 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
903 for (
size_t i = 0; i < numPorts; ++i) {
904 BlockArgument port = moduleOp.getArgument(i);
908 LLVM_DEBUG(llvm::dbgs().indent(4)
909 <<
"process port " << render(port) <<
"\n");
911 SmallVector<IntegerAttr> associations(numDomains);
913 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
914 auto prevDomainPortIndex = associations[domainTypeID.index];
915 if (prevDomainPortIndex) {
916 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
917 prevDomainPortIndex, domainPortIndex);
920 associations[domainTypeID.index] = domainPortIndex;
923 SmallVector<Term *> elements(numDomains);
924 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
926 auto domainPortIndex = associations[domainTypeIndex];
927 if (!domainPortIndex)
929 auto domainPortValue =
930 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
931 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
934 auto *domainAssociations = allocRow(elements);
935 setDomainAssociation(port, domainAssociations);
942LogicalResult ModuleState::processInstancePorts(T op) {
943 auto numDomains = getNumDomains();
944 auto domainInfo = op.getDomainInfoAttr();
945 auto numPorts = op.getNumPorts();
947 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
948 for (
size_t i = 0; i < numPorts; ++i) {
949 auto port = dyn_cast<DomainValue>(op->getResult(i));
953 if (op.getPortDirection(i) == Direction::Out)
954 processDomainDefinition(port);
956 domainTypeIDTable[i] = getDomainTypeID(op, i);
959 for (
size_t i = 0; i < numPorts; ++i) {
960 Value port = op->getResult(i);
964 SmallVector<IntegerAttr> associations(numDomains);
966 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
967 auto prevDomainPortIndex = associations[domainTypeID.index];
968 if (prevDomainPortIndex) {
969 emitDuplicatePortDomainError(op, i, domainTypeID, prevDomainPortIndex,
973 associations[domainTypeID.index] = domainPortIndex;
976 SmallVector<Term *> elements(numDomains);
977 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
979 auto domainPortIndex = associations[domainTypeIndex];
980 if (!domainPortIndex)
982 auto domainPortValue =
983 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
984 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
987 auto *domainAssociations = allocRow(elements);
988 setDomainAssociation(port, domainAssociations);
994FInstanceLike ModuleState::fixInstancePorts(FInstanceLike op,
995 const ModuleUpdateInfo &update) {
996 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
997 clone.setDomainInfoAttr(update.portDomainInfo);
1000 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"fixup " << render(clone) <<
"\n");
1004LogicalResult ModuleState::processOp(FInstanceLike op) {
1006 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
1007 auto updateTable = getModuleUpdateTable();
1008 auto lookup = updateTable.find(moduleName);
1009 if (lookup != updateTable.end())
1010 op = fixInstancePorts(op, lookup->second);
1011 return processInstancePorts(op);
1014LogicalResult ModuleState::processOp(UnsafeDomainCastOp op) {
1015 auto domains = op.getDomains();
1016 if (domains.empty())
1017 return unifyAssociations(op, op.getInput(), op.getResult());
1019 auto input = op.getInput();
1020 RowTerm *inputRow = getDomainAssociationAsRow(input);
1021 SmallVector<Term *> elements(inputRow->elements);
1022 for (
auto value : op.getDomains()) {
1023 auto domain = cast<DomainValue>(value);
1024 auto typeID = getDomainTypeID(domain);
1025 elements[typeID.index] = getTermForDomain(domain);
1028 auto *row = allocRow(elements);
1029 setDomainAssociation(op.getResult(), row);
1033LogicalResult ModuleState::processOp(DomainDefineOp op) {
1034 auto src = op.getSrc();
1035 auto dst = op.getDest();
1037 auto *srcTerm = getTermForDomain(src);
1038 auto *dstTerm = getTermForDomain(dst);
1039 if (succeeded(unify(dstTerm, srcTerm)))
1044 <<
"defines a domain value that was inferred to be a different domain '";
1045 render(dstTerm, diag);
1051LogicalResult ModuleState::processOp(WireOp op) {
1058 if (op.getDomains().empty())
1059 return unifyAssociations(op, op.getResults());
1063 SmallVector<Term *> elements(getNumDomains());
1064 for (
auto domain : op.getDomains()) {
1065 auto domainValue = cast<DomainValue>(domain);
1066 auto typeID = getDomainTypeID(domainValue);
1067 elements[typeID.index] = getTermForDomain(domainValue);
1070 auto *row = allocRow(elements);
1071 for (
auto result : op.getResults())
1072 setDomainAssociation(result, row);
1077LogicalResult ModuleState::processOp(RWProbeOp op) {
1078 auto target = globals.getInnerRefNamespace().lookup(op.getTarget());
1080 if (target.isPort()) {
1081 auto targetOp = cast<FModuleOp>(target.getOp());
1082 auto targetValue = targetOp.getArgument(target.getPort());
1083 return unifyAssociations(op, targetValue, op.getResult());
1086 auto targetOp = cast<hw::InnerSymbolOpInterface>(target.getOp());
1087 auto targetValue = targetOp.getTargetResult();
1088 return unifyAssociations(op, targetValue, op.getResult());
1091LogicalResult ModuleState::processOp(Operation *op) {
1092 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"process " << render(op) <<
"\n");
1093 if (
auto instance = dyn_cast<FInstanceLike>(op))
1094 return processOp(instance);
1095 if (
auto wireOp = dyn_cast<WireOp>(op))
1096 return processOp(wireOp);
1097 if (
auto cast = dyn_cast<UnsafeDomainCastOp>(op))
1098 return processOp(cast);
1099 if (
auto def = dyn_cast<DomainDefineOp>(op))
1100 return processOp(def);
1101 if (
auto probe = dyn_cast<RWProbeOp>(op))
1102 return processOp(probe);
1103 if (
auto create = dyn_cast<DomainCreateOp>(op)) {
1104 processDomainDefinition(create);
1107 if (
auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
1108 processDomainDefinition(createAnon);
1112 return unifyAssociations(op);
1115LogicalResult ModuleState::processModuleBody(FModuleOp moduleOp) {
1118 .walk([&](Operation *op) -> WalkResult { return processOp(op); })
1122LogicalResult ModuleState::processModule(FModuleOp moduleOp) {
1123 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"processing:\n");
1124 if (failed(processModulePorts(moduleOp)))
1126 if (failed(processModuleBody(moduleOp)))
1131ExportTable ModuleState::initializeExportTable(FModuleOp moduleOp) {
1133 size_t numPorts = moduleOp.getNumPorts();
1134 for (
size_t i = 0; i < numPorts; ++i) {
1135 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1138 auto value = getOptUnderlyingDomain(port);
1140 exports[value].push_back(port);
1144 llvm::dbgs().indent(2) <<
"domain exports:\n";
1145 for (
auto entry : exports) {
1146 llvm::dbgs().indent(4) << render(entry.first) <<
" exported as ";
1147 llvm::interleaveComma(entry.second, llvm::dbgs(),
1148 [&](
auto e) { llvm::dbgs() << render(e); });
1149 llvm::dbgs() <<
"\n";
1156void ModuleState::ensureSolved(
Namespace &ns, DomainTypeID typeID,
size_t ip,
1157 LocationAttr loc, VariableTerm *var,
1158 PendingUpdates &pending) {
1159 if (pending.solutions.contains(var))
1162 auto *
context = loc.getContext();
1163 auto domainDecl = getDomain(typeID);
1164 auto domainName = domainDecl.getNameAttr();
1166 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1167 auto portType = DomainType::getFromDomainOp(domainDecl);
1168 auto portDirection = Direction::In;
1169 auto portSym = StringAttr();
1171 auto portAnnos = std::nullopt;
1173 auto portDomainInfo = ArrayAttr::get(
context, {});
1174 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1175 portAnnos, portDomainInfo);
1177 pending.solutions[var] = pending.insertions.size() + ip;
1178 pending.insertions.push_back({ip, portInfo});
1182 DomainTypeID typeID,
size_t ip,
1183 LocationAttr loc, ValueTerm *val,
1184 PendingUpdates &pending) {
1185 auto value = val->value;
1186 assert(isa<DomainType>(value.getType()));
1187 if (
isPort(value) || exports.contains(value) ||
1188 pending.exports.contains(value))
1191 auto *
context = loc.getContext();
1193 auto domainDecl = getDomain(typeID);
1194 auto domainName = domainDecl.getNameAttr();
1196 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1197 auto portType = DomainType::getFromDomainOp(domainDecl);
1198 auto portDirection = Direction::Out;
1199 auto portSym = StringAttr();
1200 auto portLoc = value.getLoc();
1201 auto portAnnos = std::nullopt;
1203 auto portDomainInfo = ArrayAttr::get(
context, {});
1204 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1205 portAnnos, portDomainInfo);
1206 pending.exports[value] = pending.insertions.size() + ip;
1207 pending.insertions.push_back({ip, portInfo});
1210void ModuleState::getUpdatesForDomainAssociationOfPort(
1211 Namespace &ns, PendingUpdates &pending, DomainTypeID typeID,
size_t ip,
1212 LocationAttr loc, Term *term,
const ExportTable &exports) {
1213 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1214 ensureSolved(ns, typeID, ip, loc, var, pending);
1217 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1218 ensureExported(ns, exports, typeID, ip, loc, val, pending);
1221 llvm_unreachable(
"invalid domain association");
1224void ModuleState::getUpdatesForDomainAssociationOfPort(
1226 RowTerm *row, PendingUpdates &pending) {
1227 for (
auto [index, term] :
llvm::enumerate(row->elements))
1228 getUpdatesForDomainAssociationOfPort(ns, pending, DomainTypeID{index}, ip,
1229 loc, find(term), exports);
1232void ModuleState::getUpdatesForModulePorts(FModuleOp moduleOp,
1235 PendingUpdates &pending) {
1236 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1237 auto port = moduleOp.getArgument(i);
1241 getUpdatesForDomainAssociationOfPort(
1242 ns, exports, i, moduleOp.getPortLocation(i),
1243 getDomainAssociationAsRow(port), pending);
1247void ModuleState::getUpdatesForModule(FModuleOp moduleOp,
1249 PendingUpdates &pending) {
1251 auto names = moduleOp.getPortNamesAttr();
1252 for (
auto name : names.getAsRange<StringAttr>())
1254 getUpdatesForModulePorts(moduleOp, exports, ns, pending);
1257void ModuleState::applyUpdatesToModule(FModuleOp moduleOp,
ExportTable &exports,
1258 const PendingUpdates &pending) {
1259 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"applying updates:\n");
1261 moduleOp.insertPorts(pending.insertions);
1265 for (
auto [var, portIndex] : pending.solutions) {
1266 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1267 auto *solution = allocVal(portValue);
1268 LLVM_DEBUG(llvm::dbgs().indent(4)
1269 <<
"new-input " << render(portValue) <<
"\n");
1270 solve(var, solution);
1271 exports[portValue].push_back(portValue);
1275 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1276 for (
auto [domainValue, portIndex] : pending.exports) {
1277 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1278 builder.setInsertionPointAfterValue(domainValue);
1279 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1280 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"new-output " << render(portValue)
1281 <<
" := " << render(domainValue) <<
"\n");
1282 exports[domainValue].push_back(portValue);
1283 setTermForDomain(portValue, allocVal(domainValue));
1287SmallVector<Attribute> ModuleState::copyPortDomainAssociations(
1288 FModuleOp moduleOp, ArrayAttr moduleDomainInfo,
size_t portIndex) {
1289 SmallVector<Attribute> result(getNumDomains());
1291 for (
auto domainPortIndexAttr : oldAssociations) {
1292 auto domainPortIndex = domainPortIndexAttr.getUInt();
1293 auto domainTypeID = getDomainTypeID(moduleOp, domainPortIndex);
1294 result[domainTypeID.index] = domainPortIndexAttr;
1299LogicalResult ModuleState::driveModuleOutputDomainPorts(FModuleOp moduleOp) {
1300 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1301 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1302 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1303 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1307 auto *term = getOptTermForDomain(port);
1308 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1310 emitDomainPortInferenceError(moduleOp, i);
1314 auto loc = port.getLoc();
1315 auto value = val->value;
1316 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"connect " << render(port)
1317 <<
" := " << render(value) <<
"\n");
1318 DomainDefineOp::create(builder, loc, port, value);
1324LogicalResult ModuleState::updateModuleDomainInfo(
1325 FModuleOp moduleOp,
const ExportTable &exportTable, ArrayAttr &result) {
1330 auto *
context = moduleOp.getContext();
1331 auto numDomains = getNumDomains();
1332 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1333 auto numPorts = moduleOp.getNumPorts();
1334 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1336 for (
size_t i = 0; i < numPorts; ++i) {
1337 auto port = moduleOp.getArgument(i);
1338 auto type = port.getType();
1340 if (isa<DomainType>(type)) {
1342 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1347 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1352 copyPortDomainAssociations(moduleOp, oldModuleDomainInfo, i);
1353 auto *row = cast<RowTerm>(getDomainAssociation(port));
1354 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1355 auto domainTypeID = DomainTypeID{domainIndex};
1356 if (associations[domainIndex])
1359 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1360 auto &exports = exportTable.at(domain);
1361 if (exports.empty()) {
1362 auto portName = moduleOp.getPortNameAttr(i);
1363 auto portLoc = moduleOp.getPortLocation(i);
1364 auto domainDecl = getDomain(domainTypeID);
1365 auto domainName = domainDecl.getNameAttr();
1366 auto diag = emitError(portLoc) <<
"private " << domainName
1367 <<
" association for port " << portName;
1368 diag.attachNote(domain.getLoc()) <<
"associated domain: " << domain;
1369 noteLocation(diag, moduleOp);
1373 if (exports.size() > 1) {
1374 emitAmbiguousPortDomainAssociation(moduleOp, exports, domainTypeID, i);
1378 auto argument = cast<BlockArgument>(exports[0]);
1379 auto domainPortIndex = argument.getArgNumber();
1380 associations[domainTypeID.index] =
1381 IntegerAttr::get(IntegerType::get(
context, 32, IntegerType::Unsigned),
1385 newModuleDomainInfo[i] = ArrayAttr::get(
context, associations);
1388 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1389 moduleOp.setDomainInfoAttr(result);
1394 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1395 Operation *user, DomainType type, VariableTerm *var) {
1396 auto name = type.getName().getAttr();
1398 DomainCreateAnonOp::create(builder, user->getLoc(), type, name);
1400 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"create anon " << render(anon) <<
"\n");
1401 solve(var, allocVal(anon));
1402 domainsInScope[anon] = anon;
1407 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1409 auto &domainInScope = domainsInScope[domain];
1411 return domainInScope;
1413 domainInScope = cast<DomainValue>(
1414 WireOp::create(builder, domain.getLoc(), domain.getType(),
1415 domain.getType().getName().getAttr())
1418 OpBuilder::InsertionGuard guard(builder);
1419 builder.setInsertionPointAfterValue(domain);
1420 DomainDefineOp::create(builder, domain.getLoc(), domainInScope, domain);
1422 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"bounce wire " << render(domainInScope)
1423 <<
" := " << render(domain) <<
"\n");
1424 return domainInScope;
1428ModuleState::updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
1430 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(op) <<
"\n");
1431 OpBuilder builder(op.getContext());
1432 builder.setInsertionPointAfter(op);
1433 auto numPorts = op->getNumResults();
1435 for (
size_t i = 0; i < numPorts; ++i)
1436 if (
auto port = dyn_cast<DomainValue>(op->getResult(i)))
1437 if (op.getPortDirection(i) == Direction::Out)
1438 domainsInScope[port] = port;
1440 for (
size_t i = 0; i < numPorts; ++i) {
1441 auto port = dyn_cast<DomainValue>(op->getResult(i));
1442 auto direction = op.getPortDirection(i);
1446 if (port && direction == Direction::In && !
isDriven(port)) {
1447 auto loc = port.getLoc();
1448 auto *term = getTermForDomain(port);
1449 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1450 auto domain = solveVarWithAnonDomain(builder, domainsInScope, op,
1451 port.getType(), var);
1452 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"connect " << render(port)
1453 <<
" := " << render(domain) <<
"\n");
1454 DomainDefineOp::create(builder, loc, port, domain);
1457 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1458 auto domain = getDomainInScope(builder, domainsInScope, val->value);
1459 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"connect " << render(port)
1460 <<
" := " << render(domain) <<
"\n");
1461 DomainDefineOp::create(builder, loc, port, domain);
1464 llvm_unreachable(
"unhandled domain term type");
1472ModuleState::updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
1474 auto result = wireOp.getResult();
1478 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(wireOp) <<
"\n");
1479 OpBuilder builder(wireOp);
1480 auto *row = getDomainAssociationAsRow(wireOp.getResult());
1482 SmallVector<Value> domainOperands;
1483 for (
auto [i, element] :
llvm::enumerate(
1484 llvm::map_range(row->elements, [&](auto e) {
return find(e); }))) {
1485 if (
auto *val = dyn_cast<ValueTerm>(element)) {
1486 domainOperands.push_back(
1487 getDomainInScope(builder, domainsInScope, val->value));
1490 if (
auto *var = dyn_cast<VariableTerm>(element)) {
1491 auto type = DomainType::getFromDomainOp(getDomain(DomainTypeID{i}));
1493 solveVarWithAnonDomain(builder, domainsInScope, wireOp, type, var);
1494 domainOperands.push_back(domain);
1497 assert(0 &&
"unhandled domain type");
1499 wireOp.getDomainsMutable().assign(domainOperands);
1503LogicalResult ModuleState::updateModuleBody(FModuleOp moduleOp) {
1504 DenseMap<DomainValue, DomainValue> domainsInScope;
1506 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i)
1507 if (
auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i)))
1508 if (moduleOp.getPortDirection(i) == Direction::In)
1509 domainsInScope[port] = port;
1511 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1512 return TypeSwitch<Operation *, WalkResult>(op)
1514 [&](
auto wire) {
return updateWire(domainsInScope, wire); })
1515 .Case<FInstanceLike>([&](
auto instance) {
1516 return updateInstance(domainsInScope, instance);
1518 .Case<DomainCreateOp, DomainCreateAnonOp>([&](
auto domain) {
1519 domainsInScope[domain] = domain;
1522 .Default([&](
auto op) {
return success(); });
1524 return failure(result.wasInterrupted());
1527LogicalResult ModuleState::updateModule(FModuleOp moduleOp) {
1528 auto exports = initializeExportTable(moduleOp);
1529 PendingUpdates pending;
1530 getUpdatesForModule(moduleOp, exports, pending);
1531 applyUpdatesToModule(moduleOp, exports, pending);
1533 ArrayAttr portDomainInfo;
1534 if (failed(updateModuleDomainInfo(moduleOp, exports, portDomainInfo)))
1537 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1541 auto &entry = getModuleUpdateTable()[moduleOp.getModuleNameAttr()];
1542 entry.portDomainInfo = portDomainInfo;
1543 entry.portInsertions = std::move(pending.insertions);
1545 if (failed(updateModuleBody(moduleOp)))
1549 llvm::dbgs().indent(2) <<
"port summary:\n";
1550 for (
auto port : moduleOp.
getBodyBlock()->getArguments()) {
1551 llvm::dbgs().indent(4) << render(port);
1552 auto info = cast<ArrayAttr>(
1553 moduleOp.getDomainInfoAttrForPort(port.getArgNumber()));
1555 llvm::dbgs() <<
" domains [";
1556 llvm::interleaveComma(
1557 info.getAsRange<IntegerAttr>(), llvm::dbgs(), [&](
auto i) {
1558 llvm::dbgs() << render(moduleOp.getArgument(i.getUInt()));
1560 llvm::dbgs() <<
"]";
1562 llvm::dbgs() <<
"\n";
1569LogicalResult ModuleState::checkModulePorts(FModuleLike moduleOp) {
1570 auto numDomains = getNumDomains();
1571 auto domainInfo = moduleOp.getDomainInfoAttr();
1572 auto numPorts = moduleOp.getNumPorts();
1574 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1575 for (
size_t i = 0; i < numPorts; ++i) {
1576 if (isa<DomainType>(moduleOp.getPortType(i)))
1577 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1580 for (
size_t i = 0; i < numPorts; ++i) {
1585 SmallVector<IntegerAttr> associations(numDomains);
1587 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1588 auto prevDomainPortIndex = associations[domainTypeID.index];
1589 if (prevDomainPortIndex) {
1590 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1591 prevDomainPortIndex, domainPortIndex);
1594 associations[domainTypeID.index] = domainPortIndex;
1598 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1599 auto typeID = DomainTypeID{domainIndex};
1600 if (!associations[domainIndex]) {
1601 emitMissingPortDomainAssociationError(moduleOp, typeID, i);
1610LogicalResult ModuleState::checkModuleDomainPortDrivers(FModuleOp moduleOp) {
1611 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1612 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1613 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1617 auto name = moduleOp.getPortNameAttr(i);
1618 auto diag = emitError(moduleOp.getPortLocation(i))
1619 <<
"undriven domain port " << name;
1620 noteLocation(diag, moduleOp);
1627LogicalResult ModuleState::checkInstanceDomainPortDrivers(FInstanceLike op) {
1628 for (
size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1629 auto port = dyn_cast<DomainValue>(op->getResult(i));
1631 auto type = port.getType();
1632 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1636 auto name = op.getPortNameAttr(i);
1637 auto diag = emitError(op.getPortLocation(i))
1638 <<
"undriven domain port " << name;
1639 noteLocation(diag, op);
1646LogicalResult ModuleState::checkModuleBody(FModuleOp moduleOp) {
1647 auto result = moduleOp.getBody().walk([&](FInstanceLike op) -> WalkResult {
1648 return checkInstanceDomainPortDrivers(op);
1650 return failure(result.wasInterrupted());
1653LogicalResult ModuleState::inferModule(FModuleOp moduleOp) {
1654 LLVM_DEBUG(llvm::dbgs() <<
"infer: " << moduleOp.getModuleName() <<
"\n");
1655 if (failed(processModule(moduleOp)))
1658 return updateModule(moduleOp);
1661LogicalResult ModuleState::checkModule(FModuleOp moduleOp) {
1662 LLVM_DEBUG(llvm::dbgs() <<
"check: " << moduleOp.getModuleName() <<
"\n");
1663 if (failed(checkModulePorts(moduleOp)))
1666 if (failed(checkModuleDomainPortDrivers(moduleOp)))
1669 if (failed(checkModuleBody(moduleOp)))
1672 return processModule(moduleOp);
1675LogicalResult ModuleState::checkModule(FExtModuleOp extModuleOp) {
1676 LLVM_DEBUG(llvm::dbgs() <<
"check: " << extModuleOp.getModuleName() <<
"\n");
1677 return checkModulePorts(extModuleOp);
1680LogicalResult ModuleState::checkAndInferModule(FModuleOp moduleOp) {
1681 LLVM_DEBUG(llvm::dbgs() <<
"check/infer: " << moduleOp.getModuleName()
1684 if (failed(checkModulePorts(moduleOp)))
1687 if (failed(processModule(moduleOp)))
1690 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1693 return updateModuleBody(moduleOp);
1701 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1702 [=](Operation *op) -> WalkResult {
1703 return TypeSwitch<Operation *, WalkResult>(op)
1704 .Case<FModuleLike>([](FModuleLike op) {
1705 auto n = op.getNumPorts();
1706 BitVector erasures(n);
1707 for (
size_t i = 0; i < n; ++i)
1708 if (isa<DomainType>(op.getPortType(i)))
1710 op.erasePorts(erasures);
1711 return WalkResult::advance();
1713 .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
1716 return WalkResult::advance();
1718 .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
1719 if (!op->use_empty()) {
1720 OpBuilder builder(op);
1721 op.replaceAllUsesWith(
1722 UnknownValueOp::create(builder, op.getLoc(), op.getType())
1726 return WalkResult::advance();
1728 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1729 op.replaceAllUsesWith(op.getInput());
1731 return WalkResult::advance();
1733 .Case<WireOp>([](WireOp op) {
1735 if (isa<DomainType>(op.getType(0))) {
1737 return WalkResult::advance();
1740 if (!op.getDomains().empty()) {
1741 op->eraseOperands(0, op.getNumOperands());
1743 return WalkResult::advance();
1745 .Case<FInstanceLike>([](
auto op) {
1746 auto n = op.getNumPorts();
1747 BitVector erasures(n);
1748 for (
size_t i = 0; i < n; ++i)
1749 if (isa<DomainType>(op->getResult(i).getType()))
1751 op.cloneWithErasedPortsAndReplaceUses(erasures);
1753 return WalkResult::advance();
1755 .Default([](Operation *op) {
1757 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1758 if (isa<DomainType>(type)) {
1759 op->emitOpError(
"cannot be stripped");
1760 return WalkResult::interrupt();
1763 return WalkResult::advance();
1766 return failure(result.wasInterrupted());
1770 llvm::SmallVector<FModuleLike> modules;
1771 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1772 TypeSwitch<Operation *, void>(&op)
1773 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1774 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1783LogicalResult CircuitState::runOnModule(Operation *op) {
1784 assert(mode != InferDomainsMode::Strip);
1785 ModuleState state(*
this);
1786 if (
auto moduleOp = dyn_cast<FModuleOp>(op)) {
1787 if (mode == InferDomainsMode::Check)
1788 return state.checkModule(moduleOp);
1790 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1791 return state.inferModule(moduleOp);
1793 return state.checkAndInferModule(moduleOp);
1796 if (
auto extModuleOp = dyn_cast<FExtModuleOp>(op))
1797 return state.checkModule(extModuleOp);
1802LogicalResult CircuitState::run() {
1803 DenseSet<Operation *> errored;
1804 instanceGraph.walkPostOrder([&](
auto &node) {
1805 auto moduleOp = node.getModule();
1806 for (
auto *inst : node) {
1807 if (errored.contains(inst->getTarget()->getModule())) {
1808 errored.insert(moduleOp);
1812 if (failed(runOnModule(node.getModule())))
1813 errored.insert(moduleOp);
1815 return success(errored.empty());
1819struct InferDomainsPass
1820 :
public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1822 void runOnOperation()
override {
1824 auto circuit = getOperation();
1826 if (mode == InferDomainsMode::Strip) {
1828 signalPassFailure();
1832 auto &instanceGraph = getAnalysis<InstanceGraph>();
1833 auto &symbolTable = getAnalysis<SymbolTable>();
1834 auto &innerSymbolTableCollection =
1835 getAnalysis<InnerSymbolTableCollection>();
1837 innerSymbolTableCollection};
1838 CircuitState state(circuit, instanceGraph, innerRefNamespace, mode);
1839 if (failed(state.run()))
1840 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static bool isHardware(Type type)
True if a value of the given type could be associated with a domain.
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static LogicalResult stripModule(FModuleLike op)
static Block * getBodyBlock(FModuleLike mod)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This graph tracks modules and where they are instantiated.
This class represents a collection of InnerSymbolTable's.
InferDomainsMode
The mode for the InferDomains pass.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
This holds the name and type that describes the module's ports.
This class represents the namespace in which InnerRef's can be resolved.