21#include "mlir/IR/Iterators.h"
22#include "mlir/IR/Threading.h"
23#include "llvm/ADT/DenseMap.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TinyPtrVector.h"
28#define DEBUG_TYPE "firrtl-infer-domains"
32#define GEN_PASS_DEF_INFERDOMAINS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
38using namespace firrtl;
41using mlir::ReverseIterator;
56 return cast<FlatSymbolRefAttr>(info[i]).getAttr();
63 return info.getAsRange<IntegerAttr>();
64 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
68static bool isPort(BlockArgument arg) {
69 return isa<FModuleOp>(arg.getOwner()->getParentOp());
74 auto arg = dyn_cast<BlockArgument>(value);
82 for (
auto *user : port.getUsers())
83 if (
auto connect = dyn_cast<FConnectLike>(user))
84 if (connect.getDest() == port)
109 DomainInfo(CircuitOp circuit) { processCircuit(circuit); }
111 ArrayRef<DomainOp> getDomains()
const {
return domainTable; }
112 size_t getNumDomains()
const {
return domainTable.size(); }
113 DomainOp getDomain(DomainTypeID
id)
const {
return domainTable[
id.index]; }
115 DomainTypeID getDomainTypeID(StringAttr name)
const {
116 return typeIDTable.at(name);
119 DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref)
const {
120 return getDomainTypeID(ref.getAttr());
123 DomainTypeID getDomainTypeID(ArrayAttr info,
size_t i)
const {
125 return getDomainTypeID(name);
128 DomainTypeID getDomainTypeID(
DomainValue value)
const {
129 if (
auto arg = dyn_cast<BlockArgument>(value)) {
130 auto *block = arg.getOwner();
131 auto *owner = block->getParentOp();
132 auto moduleOp = cast<FModuleOp>(owner);
133 auto info = moduleOp.getDomainInfoAttr();
134 auto i = arg.getArgNumber();
135 return getDomainTypeID(info, i);
138 auto result = dyn_cast<OpResult>(value);
139 auto *owner = result.getOwner();
141 auto info = TypeSwitch<Operation *, ArrayAttr>(owner)
142 .Case<InstanceOp, InstanceChoiceOp>(
143 [&](
auto inst) {
return inst.getDomainInfoAttr(); })
144 .Default([&](
auto inst) {
return nullptr; });
145 assert(info &&
"unable to obtain domain information from op");
147 auto i = result.getResultNumber();
148 return getDomainTypeID(info, i);
152 void processDomain(DomainOp op) {
153 auto index = domainTable.size();
154 auto name = op.getNameAttr();
155 domainTable.push_back(op);
156 typeIDTable.insert({name, {index}});
159 void processCircuit(CircuitOp circuit) {
160 for (
auto decl : circuit.getOps<DomainOp>())
165 SmallVector<DomainOp> domainTable;
168 DenseMap<StringAttr, DomainTypeID> typeIDTable;
173struct ModuleUpdateInfo {
175 ArrayAttr portDomainInfo;
186 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
187 clone.setDomainInfoAttr(update.portDomainInfo);
208 constexpr Term(TermKind kind) : kind(kind) {}
216struct TermBase : Term {
217 static bool classof(
const Term *term) {
return term->kind == K; }
218 TermBase() : Term(K) {}
224struct VariableTerm :
public TermBase<TermKind::Variable> {
225 VariableTerm() : leader(nullptr) {}
226 VariableTerm(Term *leader) : leader(leader) {}
233struct ValueTerm :
public TermBase<TermKind::Value> {
241struct RowTerm :
public TermBase<TermKind::Row> {
242 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
243 ArrayRef<Term *> elements;
252 if (
auto *var = dyn_cast<VariableTerm>(x)) {
253 if (var->leader ==
nullptr)
256 auto *leader =
find(var->leader);
257 if (leader != var->leader)
258 var->leader = leader;
267class VariableIDTable {
269 size_t get(VariableTerm *term) {
270 return table.insert({term, table.size() + 1}).first->second;
274 DenseMap<VariableTerm *, size_t> table;
279static void render(
const DomainInfo &info, Diagnostic &out,
280 VariableIDTable &idTable, Term *term) {
282 if (
auto *var = dyn_cast<VariableTerm>(term)) {
283 out <<
"?" << idTable.get(var);
286 if (
auto *val = dyn_cast<ValueTerm>(term)) {
287 auto value = val->value;
292 if (
auto *row = dyn_cast<RowTerm>(term)) {
295 for (
size_t i = 0, e = info.getNumDomains(); i < e; ++i) {
296 auto domainOp = info.getDomain(DomainTypeID{i});
301 out << domainOp.getName() <<
": ";
302 render(info, out, idTable, row->elements[i]);
309static LogicalResult
unify(Term *lhs, Term *rhs);
311static LogicalResult
unify(VariableTerm *x, Term *y) {
317static LogicalResult
unify(ValueTerm *xv, Term *y) {
318 if (
auto *yv = dyn_cast<VariableTerm>(y)) {
323 if (
auto *yv = dyn_cast<ValueTerm>(y))
324 return success(xv == yv);
330static LogicalResult
unify(RowTerm *lhsRow, Term *rhs) {
331 if (
auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
332 rhsVar->leader = lhsRow;
335 if (
auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
336 for (
auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
337 if (failed(
unify(x, y)))
345static LogicalResult
unify(Term *lhs, Term *rhs) {
352 if (
auto *lhsVar = dyn_cast<VariableTerm>(lhs))
353 return unify(lhsVar, rhs);
354 if (
auto *lhsVal = dyn_cast<ValueTerm>(lhs))
355 return unify(lhsVal, rhs);
356 if (
auto *lhsRow = dyn_cast<RowTerm>(lhs))
357 return unify(lhsRow, rhs);
361static void solve(Term *lhs, Term *rhs) {
362 [[maybe_unused]]
auto result =
unify(lhs, rhs);
363 assert(result.succeeded());
370 [[nodiscard]] RowTerm *allocRow(
size_t size) {
371 SmallVector<Term *> elements;
372 elements.resize(size);
373 return allocRow(elements);
377 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements) {
378 auto ds = allocArray(elements);
379 return alloc<RowTerm>(ds);
383 [[nodiscard]] VariableTerm *allocVar() {
return alloc<VariableTerm>(); }
386 [[nodiscard]] ValueTerm *allocVal(
DomainValue value) {
387 return alloc<ValueTerm>(value);
391 template <
typename T,
typename... Args>
392 [[nodiscard]] T *alloc(Args &&...args) {
393 static_assert(std::is_base_of_v<Term, T>,
"T must be a term");
394 return new (allocator) T(std::forward<Args>(args)...);
397 [[nodiscard]] ArrayRef<Term *> allocArray(ArrayRef<Term *> elements) {
398 auto size = elements.size();
402 auto *result = allocator.Allocate<Term *>(size);
403 llvm::uninitialized_copy(elements, result);
404 for (
size_t i = 0; i < size; ++i)
406 result[i] = alloc<VariableTerm>();
408 return ArrayRef(result, size);
411 llvm::BumpPtrAllocator allocator;
425 auto *term = getOptTermForDomain(value);
426 if (
auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
432 Term *getOptTermForDomain(
DomainValue value)
const {
433 assert(isa<DomainType>(value.getType()));
434 auto it = termTable.find(value);
435 if (it == termTable.end())
437 return find(it->second);
442 auto *term = getOptTermForDomain(value);
448 void setTermForDomain(
DomainValue value, Term *term) {
450 assert(!termTable.contains(value));
451 termTable.insert({value, term});
456 Term *getOptDomainAssociation(Value value)
const {
457 assert(isa<FIRRTLBaseType>(value.getType()));
458 auto it = associationTable.find(value);
459 if (it == associationTable.end())
461 return find(it->second);
466 Term *getDomainAssociation(Value value)
const {
467 auto *term = getOptDomainAssociation(value);
474 void setDomainAssociation(Value value, Term *term) {
475 assert(isa<FIRRTLBaseType>(value.getType()));
478 associationTable.insert({value, term});
483 DenseMap<Value, Term *> termTable;
486 DenseMap<Value, Term *> associationTable;
498 assert(isa<DomainType>(value.getType()));
499 if (
auto *term = table.getOptTermForDomain(value))
501 auto *term = allocator.allocVar();
502 table.setTermForDomain(value, term);
508 assert(isa<DomainType>(domain.getType()));
509 auto *newTerm = allocator.allocVal(domain);
510 auto *oldTerm = table.getOptTermForDomain(domain);
512 table.setTermForDomain(domain, newTerm);
516 [[maybe_unused]]
auto result =
unify(oldTerm, newTerm);
517 assert(result.succeeded());
523 TermAllocator &allocator,
524 DomainTable &table, Value value) {
525 assert(isa<FIRRTLBaseType>(value.getType()));
526 auto *term = table.getOptDomainAssociation(value);
530 auto *row = allocator.allocRow(info.getNumDomains());
531 table.setDomainAssociation(value, row);
536 if (
auto *row = dyn_cast<RowTerm>(term))
540 if (
auto *var = dyn_cast<VariableTerm>(term)) {
541 auto *row = allocator.allocRow(info.getNumDomains());
546 assert(
false &&
"unhandled term type");
550static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
551 auto ¬e = diag.attachNote(op->getLoc());
552 if (
auto mod = dyn_cast<FModuleOp>(op)) {
553 note <<
"in module " << mod.getModuleNameAttr();
556 if (
auto mod = dyn_cast<FExtModuleOp>(op)) {
557 note <<
"in extmodule " << mod.getModuleNameAttr();
560 if (
auto inst = dyn_cast<InstanceOp>(op)) {
561 note <<
"in instance " << inst.getInstanceNameAttr();
564 if (
auto inst = dyn_cast<InstanceChoiceOp>(op)) {
565 note <<
"in instance_choice " << inst.getNameAttr();
574 DomainTypeID domainTypeID, Term *term1,
576 VariableIDTable idTable;
578 auto portName = op.getPortNameAttr(i);
579 auto portLoc = op.getPortLocation(i);
580 auto domainDecl = info.getDomain(domainTypeID);
581 auto domainName = domainDecl.getNameAttr();
583 auto diag = emitError(portLoc);
584 diag <<
"illegal " << domainName <<
" crossing in port " << portName;
586 auto ¬e1 = diag.attachNote();
587 note1 <<
"1st instance: ";
588 render(info, note1, idTable, term1);
590 auto ¬e2 = diag.attachNote();
591 note2 <<
"2nd instance: ";
592 render(info, note2, idTable, term2);
599 DomainTypeID domainTypeID,
600 IntegerAttr domainPortIndexAttr1,
601 IntegerAttr domainPortIndexAttr2) {
602 VariableIDTable idTable;
603 auto portName = op.getPortNameAttr(i);
604 auto portLoc = op.getPortLocation(i);
605 auto domainDecl = info.getDomain(domainTypeID);
606 auto domainName = domainDecl.getNameAttr();
607 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
608 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
609 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
610 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
611 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
612 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
613 auto diag = emitError(portLoc);
614 diag <<
"duplicate " << domainName <<
" association for port " << portName;
615 auto ¬e1 = diag.attachNote(domainPortLoc1);
616 note1 <<
"associated with " << domainName <<
" port " << domainPortName1;
617 auto ¬e2 = diag.attachNote(domainPortLoc2);
618 note2 <<
"associated with " << domainName <<
" port " << domainPortName2;
626 auto name = op.getPortNameAttr(i);
627 auto diag = emitError(op->getLoc());
628 auto info = op.getDomainInfo();
629 diag <<
"unable to infer value for undriven domain port " << name;
630 for (
size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
631 if (
auto assocs = dyn_cast<ArrayAttr>(info[j])) {
632 for (
auto assoc : assocs) {
633 if (i == cast<IntegerAttr>(assoc).getValue()) {
634 auto name = op.getPortNameAttr(j);
635 auto loc = op.getPortLocation(j);
636 diag.attachNote(loc) <<
"associated with hardware port " << name;
647 const DomainInfo &info, T op,
648 const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
650 auto portName = op.getPortNameAttr(i);
651 auto portLoc = op.getPortLocation(i);
652 auto domainDecl = info.getDomain(typeID);
653 auto domainName = domainDecl.getNameAttr();
654 auto diag = emitError(portLoc) <<
"ambiguous " << domainName
655 <<
" association for port " << portName;
656 for (
auto e : exports) {
657 auto arg = cast<BlockArgument>(e);
658 auto name = op.getPortNameAttr(arg.getArgNumber());
659 auto loc = op.getPortLocation(arg.getArgNumber());
660 diag.attachNote(loc) <<
"candidate association " << name;
669 auto domainName = info.getDomain(typeID).getNameAttr();
670 auto portName = op.getPortNameAttr(i);
671 auto diag = emitError(op.getPortLocation(i))
672 <<
"missing " << domainName <<
" association for port "
679 TermAllocator &allocator,
680 DomainTable &table, Operation *op,
681 Value lhs, Value rhs) {
688 auto *lhsTerm = table.getOptDomainAssociation(lhs);
689 auto *rhsTerm = table.getOptDomainAssociation(rhs);
693 if (failed(
unify(lhsTerm, rhsTerm))) {
694 auto diag = op->emitOpError(
"illegal domain crossing in operation");
695 auto ¬e1 = diag.attachNote(lhs.getLoc());
697 note1 <<
"1st operand has domains: ";
698 VariableIDTable idTable;
699 render(info, note1, idTable, lhsTerm);
701 auto ¬e2 = diag.attachNote(rhs.getLoc());
702 note2 <<
"2nd operand has domains: ";
703 render(info, note2, idTable, rhsTerm);
708 table.setDomainAssociation(rhs, lhsTerm);
713 table.setDomainAssociation(lhs, rhsTerm);
717 auto *var = allocator.allocVar();
718 table.setDomainAssociation(lhs, var);
719 table.setDomainAssociation(rhs, var);
724 TermAllocator &allocator,
726 FModuleOp moduleOp) {
727 auto numDomains = info.getNumDomains();
728 auto domainInfo = moduleOp.getDomainInfoAttr();
729 auto numPorts = moduleOp.getNumPorts();
731 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
732 for (
size_t i = 0; i < numPorts; ++i) {
733 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
737 if (moduleOp.getPortDirection(i) == Direction::In)
740 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
743 for (
size_t i = 0; i < numPorts; ++i) {
744 BlockArgument port = moduleOp.getArgument(i);
745 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
749 SmallVector<IntegerAttr> associations(numDomains);
751 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
752 auto prevDomainPortIndex = associations[domainTypeID.index];
753 if (prevDomainPortIndex) {
755 prevDomainPortIndex, domainPortIndex);
758 associations[domainTypeID.index] = domainPortIndex;
761 SmallVector<Term *> elements(numDomains);
762 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
764 auto domainPortIndex = associations[domainTypeIndex];
765 if (!domainPortIndex)
767 auto domainPortValue =
768 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
769 elements[domainTypeIndex] =
773 auto *domainAssociations = allocator.allocRow(elements);
774 table.setDomainAssociation(port, domainAssociations);
782 TermAllocator &allocator,
783 DomainTable &table, T op) {
784 auto numDomains = info.getNumDomains();
785 auto domainInfo = op.getDomainInfoAttr();
786 auto numPorts = op.getNumPorts();
788 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
789 for (
size_t i = 0; i < numPorts; ++i) {
790 auto port = dyn_cast<DomainValue>(op.getResult(i));
794 if (op.getPortDirection(i) == Direction::Out)
797 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
800 for (
size_t i = 0; i < numPorts; ++i) {
801 Value port = op.getResult(i);
802 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
806 SmallVector<IntegerAttr> associations(numDomains);
808 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
809 auto prevDomainPortIndex = associations[domainTypeID.index];
810 if (prevDomainPortIndex) {
812 prevDomainPortIndex, domainPortIndex);
815 associations[domainTypeID.index] = domainPortIndex;
818 SmallVector<Term *> elements(numDomains);
819 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
821 auto domainPortIndex = associations[domainTypeIndex];
822 if (!domainPortIndex)
824 auto domainPortValue =
825 cast<DomainValue>(op.getResult(domainPortIndex.getUInt()));
826 elements[domainTypeIndex] =
830 auto *domainAssociations = allocator.allocRow(elements);
831 table.setDomainAssociation(port, domainAssociations);
837static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
841 auto moduleOp = op.getReferencedModuleNameAttr();
842 auto lookup = updateTable.find(moduleOp);
843 if (lookup != updateTable.end())
848static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
851 InstanceChoiceOp op) {
852 auto moduleOp = op.getDefaultTargetAttr().getAttr();
853 auto lookup = updateTable.find(moduleOp);
854 if (lookup != updateTable.end())
859static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
860 DomainTable &table, UnsafeDomainCastOp op) {
861 auto domains = op.getDomains();
866 auto input = op.getInput();
868 SmallVector<Term *> elements(inputRow->elements);
869 for (
auto value : op.getDomains()) {
870 auto domain = cast<DomainValue>(value);
871 auto typeID = info.getDomainTypeID(domain);
875 auto *row = allocator.allocRow(elements);
876 table.setDomainAssociation(op.getResult(), row);
880static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
881 DomainTable &table, DomainDefineOp op) {
882 auto src = op.getSrc();
883 auto dst = op.getDest();
886 if (succeeded(
unify(dstTerm, srcTerm)))
889 VariableIDTable idTable;
890 auto diag = op->emitOpError(
"failed to propagate source to destination");
891 auto ¬e1 = diag.attachNote();
892 note1 <<
"destination has underlying value: ";
893 render(info, note1, idTable, dstTerm);
895 auto ¬e2 = diag.attachNote(src.getLoc());
896 note2 <<
"source has underlying value: ";
897 render(info, note2, idTable, srcTerm);
901static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
905 if (
auto instance = dyn_cast<InstanceOp>(op))
906 return processOp(info, allocator, table, updateTable, instance);
907 if (
auto instance = dyn_cast<InstanceChoiceOp>(op))
908 return processOp(info, allocator, table, updateTable, instance);
909 if (
auto cast = dyn_cast<UnsafeDomainCastOp>(op))
910 return processOp(info, allocator, table, cast);
911 if (
auto def = dyn_cast<DomainDefineOp>(op))
912 return processOp(info, allocator, table, def);
918 for (
auto rhs : op->getOperands()) {
919 if (!isa<FIRRTLBaseType>(rhs.getType()))
921 if (
auto *op = rhs.getDefiningOp();
922 op && op->hasTrait<OpTrait::ConstantLike>())
928 for (
auto rhs : op->getResults()) {
929 if (!isa<FIRRTLBaseType>(rhs.getType()))
931 if (
auto *op = rhs.getDefiningOp();
932 op && op->hasTrait<OpTrait::ConstantLike>())
942 TermAllocator &allocator,
945 FModuleOp moduleOp) {
946 auto result = moduleOp.getBody().walk([&](Operation *op) -> WalkResult {
947 return processOp(info, allocator, table, updateTable, op);
949 return failure(result.wasInterrupted());
955 TermAllocator &allocator, DomainTable &table,
957 FModuleOp moduleOp) {
/// Maps a domain value to the module ports whose underlying domain is that
/// value, i.e. the ports that alias (export) it. A port whose domain was
/// solved to the port itself is recorded under its own value.
972using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
977 FModuleOp moduleOp) {
979 size_t numPorts = moduleOp.getNumPorts();
980 for (
size_t i = 0; i < numPorts; ++i) {
981 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
984 auto value = table.getOptUnderlyingDomain(port);
986 exports[value].push_back(port);
1007struct PendingUpdates {
1017 DomainTypeID typeID,
size_t ip, LocationAttr loc,
1018 VariableTerm *var, PendingUpdates &pending) {
1019 if (pending.solutions.contains(var))
1022 auto *
context = loc.getContext();
1023 auto domainDecl = info.getDomain(typeID);
1024 auto domainName = domainDecl.getNameAttr();
1026 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1027 auto portType = DomainType::get(loc.getContext());
1028 auto portDirection = Direction::In;
1029 auto portSym = StringAttr();
1031 auto portAnnos = std::nullopt;
1032 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1033 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1034 portAnnos, portDomainInfo);
1036 pending.solutions[var] = pending.insertions.size() + ip;
1037 pending.insertions.push_back({ip, portInfo});
1047 size_t ip, LocationAttr loc, ValueTerm *val,
1048 PendingUpdates &pending) {
1049 auto value = val->value;
1050 assert(isa<DomainType>(value.getType()));
1051 if (
isPort(value) || exports.contains(value) ||
1052 pending.exports.contains(value))
1055 auto *
context = loc.getContext();
1057 auto domainDecl = info.getDomain(typeID);
1058 auto domainName = domainDecl.getNameAttr();
1060 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1061 auto portType = DomainType::get(loc.getContext());
1062 auto portDirection = Direction::Out;
1063 auto portSym = StringAttr();
1064 auto portLoc = value.getLoc();
1065 auto portAnnos = std::nullopt;
1066 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1067 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1068 portAnnos, portDomainInfo);
1069 pending.exports[value] = pending.insertions.size() + ip;
1070 pending.insertions.push_back({ip, portInfo});
1075 PendingUpdates &pending,
1076 DomainTypeID typeID,
size_t ip,
1077 LocationAttr loc, Term *term,
1079 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1083 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1084 ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
1087 llvm_unreachable(
"invalid domain association");
1092 size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
1093 for (
auto [index, term] : llvm::enumerate(row->elements))
1095 ip, loc,
find(term), exports);
1099 TermAllocator &allocator,
1103 PendingUpdates &pending) {
1104 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1105 auto port = moduleOp.getArgument(i);
1106 auto type = port.getType();
1107 if (!isa<FIRRTLBaseType>(type))
1110 info, ns, exports, i, moduleOp.getPortLocation(i),
1116 TermAllocator &allocator,
1118 FModuleOp mod, PendingUpdates &pending) {
1120 auto names = mod.getPortNamesAttr();
1121 for (
auto name : names.getAsRange<StringAttr>())
1128 DomainTable &table, FModuleOp moduleOp,
1129 const PendingUpdates &pending) {
1131 moduleOp.insertPorts(pending.insertions);
1134 for (
auto [var, portIndex] : pending.solutions) {
1135 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1136 auto *solution = allocator.allocVal(portValue);
1137 solve(var, solution);
1138 exports[portValue].push_back(portValue);
1142 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1143 for (
auto [domainValue, portIndex] : pending.exports) {
1144 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1145 builder.setInsertionPointAfterValue(domainValue);
1146 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1148 exports[domainValue].push_back(portValue);
1149 table.setTermForDomain(portValue, allocator.allocVal(domainValue));
1155static SmallVector<Attribute>
1158 SmallVector<Attribute> result(info.getNumDomains());
1160 for (
auto domainPortIndexAttr : oldAssociations) {
1162 auto domainPortIndex = domainPortIndexAttr.getUInt();
1163 auto domainTypeID = info.getDomainTypeID(moduleDomainInfo, domainPortIndex);
1164 result[domainTypeID.index] = domainPortIndexAttr;
1172 const DomainTable &table,
1173 FModuleOp moduleOp) {
1174 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1175 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1176 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1177 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1181 auto *term = table.getOptTermForDomain(port);
1182 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1188 auto loc = port.getLoc();
1189 auto value = val->value;
1190 DomainDefineOp::create(builder, loc, port, value);
1199 const DomainTable &table,
1202 FModuleOp moduleOp) {
1207 auto *
context = moduleOp.getContext();
1208 auto numDomains = info.getNumDomains();
1209 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1210 auto numPorts = moduleOp.getNumPorts();
1211 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1213 for (
size_t i = 0; i < numPorts; ++i) {
1214 auto port = moduleOp.getArgument(i);
1215 auto type = port.getType();
1217 if (isa<DomainType>(type)) {
1218 newModuleDomainInfo[i] = oldModuleDomainInfo[i];
1222 if (isa<FIRRTLBaseType>(type)) {
1225 auto *row = cast<RowTerm>(table.getDomainAssociation(port));
1226 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1227 auto domainTypeID = DomainTypeID{domainIndex};
1228 if (associations[domainIndex])
1231 auto domain = cast<ValueTerm>(
find(row->elements[domainIndex]))->value;
1232 auto &exports = exportTable.at(domain);
1233 if (exports.empty()) {
1234 auto portName = moduleOp.getPortNameAttr(i);
1235 auto portLoc = moduleOp.getPortLocation(i);
1236 auto domainDecl = info.getDomain(domainTypeID);
1237 auto domainName = domainDecl.getNameAttr();
1238 auto diag = emitError(portLoc)
1239 <<
"private " << domainName <<
" association for port "
1241 diag.attachNote(domain.getLoc()) <<
"associated domain: " << domain;
1246 if (exports.size() > 1) {
1252 auto argument = cast<BlockArgument>(exports[0]);
1253 auto domainPortIndex = argument.getArgNumber();
1254 associations[domainTypeID.index] = IntegerAttr::get(
1255 IntegerType::get(
context, 32, IntegerType::Unsigned),
1259 newModuleDomainInfo[i] = ArrayAttr::get(
context, associations);
1263 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1266 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1267 moduleOp.setDomainInfoAttr(result);
1271template <
typename T>
1273 TermAllocator &allocator,
1274 DomainTable &table, T op) {
1275 OpBuilder builder(op.getContext());
1276 builder.setInsertionPointAfter(op);
1277 auto numPorts = op->getNumResults();
1278 for (
size_t i = 0; i < numPorts; ++i) {
1279 auto port = dyn_cast<DomainValue>(op.getResult(i));
1280 auto direction = op.getPortDirection(i);
1285 if (port && direction == Direction::In && !
isDriven(port)) {
1286 auto loc = port.getLoc();
1288 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1290 auto anon = DomainCreateAnonOp::create(builder, loc, name);
1291 solve(var, allocator.allocVal(anon));
1292 DomainDefineOp::create(builder, loc, port, anon);
1295 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1296 auto value = val->value;
1297 DomainDefineOp::create(builder, loc, port, value);
1300 llvm_unreachable(
"unhandled domain term type");
1307static LogicalResult
updateOp(
const DomainInfo &info, TermAllocator &allocator,
1308 DomainTable &table, Operation *op) {
1309 if (
auto instance = dyn_cast<InstanceOp>(op))
1311 if (
auto instance = dyn_cast<InstanceChoiceOp>(op))
1319 TermAllocator &allocator,
1320 DomainTable &table, FModuleOp moduleOp) {
1321 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1322 return updateOp(info, allocator, table, op);
1324 return failure(result.wasInterrupted());
1329 TermAllocator &allocator, DomainTable &table,
1332 PendingUpdates pending;
1337 ArrayAttr portDomainInfo;
1346 auto &entry = updates[op.getModuleNameAttr()];
1347 entry.portDomainInfo = portDomainInfo;
1348 entry.portInsertions = std::move(pending.insertions);
1362 FModuleLike moduleOp) {
1363 auto numDomains = info.getNumDomains();
1364 auto domainInfo = moduleOp.getDomainInfoAttr();
1365 auto numPorts = moduleOp.getNumPorts();
1367 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1368 for (
size_t i = 0; i < numPorts; ++i) {
1369 if (isa<DomainType>(moduleOp.getPortType(i)))
1370 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
1373 for (
size_t i = 0; i < numPorts; ++i) {
1374 auto type = type_dyn_cast<FIRRTLBaseType>(moduleOp.getPortType(i));
1379 SmallVector<IntegerAttr> associations(numDomains);
1381 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1382 auto prevDomainPortIndex = associations[domainTypeID.index];
1383 if (prevDomainPortIndex) {
1385 prevDomainPortIndex, domainPortIndex);
1388 associations[domainTypeID.index] = domainPortIndex;
1392 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1393 auto typeID = DomainTypeID{domainIndex};
1394 if (!associations[domainIndex]) {
1406 FModuleOp moduleOp) {
1407 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1408 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1409 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1413 auto name = moduleOp.getPortNameAttr(i);
1414 auto diag = emitError(moduleOp.getPortLocation(i))
1415 <<
"undriven domain port " << name;
1424template <
typename T>
1426 for (
size_t i = 0, e = op.getNumResults(); i < e; ++i) {
1427 auto port = dyn_cast<DomainValue>(op.getResult(i));
1428 auto type = port.getType();
1429 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1433 auto name = op.getPortNameAttr(i);
1434 auto diag = emitError(op.getPortLocation(i))
1435 <<
"undriven domain port " << name;
1444 if (
auto inst = dyn_cast<InstanceOp>(op))
1446 if (
auto inst = dyn_cast<InstanceChoiceOp>(op))
1453 auto result = moduleOp.getBody().walk(
1454 [&](Operation *op) -> WalkResult {
return checkOp(op); });
1455 return failure(result.wasInterrupted());
1463 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1464 [=](Operation *op) -> WalkResult {
1465 return TypeSwitch<Operation *, WalkResult>(op)
1466 .Case<FModuleLike>([](FModuleLike op) {
1467 auto n = op.getNumPorts();
1468 BitVector erasures(n);
1469 for (
size_t i = 0; i < n; ++i)
1470 if (isa<DomainType>(op.getPortType(i)))
1472 op.erasePorts(erasures);
1473 return WalkResult::advance();
1475 .Case<DomainDefineOp, DomainCreateAnonOp>([](Operation *op) {
1477 return WalkResult::advance();
1479 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1480 op.replaceAllUsesWith(op.getInput());
1482 return WalkResult::advance();
1484 .Case<WireOp>([](WireOp op) {
1485 if (isa<DomainType>(op.getType(0)))
1487 return WalkResult::advance();
1489 .Case<InstanceOp, InstanceChoiceOp>([](
auto op) {
1490 auto n = op.getNumPorts();
1491 BitVector erasures(n);
1492 for (
size_t i = 0; i < n; ++i)
1493 if (isa<DomainType>(op->getResult(i).getType()))
1495 op.cloneWithErasedPortsAndReplaceUses(erasures);
1497 return WalkResult::advance();
1499 .Default([](Operation *op) {
1501 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1502 if (isa<DomainType>(type)) {
1503 op->emitOpError(
"cannot be stripped");
1504 return WalkResult::interrupt();
1507 return WalkResult::advance();
1510 return failure(result.wasInterrupted());
1514 llvm::SmallVector<FModuleLike> modules;
1515 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1516 TypeSwitch<Operation *, void>(&op)
1517 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1518 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1530 FModuleOp moduleOp) {
1531 TermAllocator allocator;
1534 if (failed(
processModule(info, allocator, table, updates, moduleOp)))
1537 return updateModule(info, allocator, table, updates, moduleOp);
1542static LogicalResult
checkModule(
const DomainInfo &info, FModuleOp moduleOp) {
1552 TermAllocator allocator;
1555 return processModule(info, allocator, table, updateTable, moduleOp);
1560 FExtModuleOp moduleOp) {
1569 FModuleOp moduleOp) {
1573 TermAllocator allocator;
1575 if (failed(
processModule(info, allocator, table, updateTable, moduleOp)))
1585 const DomainInfo &info,
1588 assert(mode != InferDomainsMode::Strip);
1590 if (
auto moduleOp = dyn_cast<FModuleOp>(op)) {
1591 if (mode == InferDomainsMode::Check)
1594 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1600 if (
auto extModule = dyn_cast<FExtModuleOp>(op))
1607struct InferDomainsPass
1608 :
public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1610 void runOnOperation()
override {
1612 auto circuit = getOperation();
1614 if (mode == InferDomainsMode::Strip) {
1616 signalPassFailure();
1620 auto &instanceGraph = getAnalysis<InstanceGraph>();
1621 DomainInfo
info(circuit);
1623 DenseSet<Operation *> errored;
1624 instanceGraph.walkPostOrder([&](
auto &node) {
1625 auto moduleOp = node.getModule();
1626 for (
auto *inst : node) {
1627 if (errored.contains(inst->getTarget()->getModule())) {
1628 errored.insert(moduleOp);
1632 if (failed(
runOnModuleLike(mode, info, updateTable, node.getModule())))
1633 errored.insert(moduleOp);
1636 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult checkOp(Operation *op)
static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, InstanceOp op)
static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op)
Write the domain associations recorded in the domain table back to the IR.
static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1, IntegerAttr domainPortIndexAttr2)
static ExportTable initializeExportTable(const DomainTable &table, FModuleOp moduleOp)
Build a table of exported domains: a map from domains defined internally, to their set of aliasing ou...
static LogicalResult processModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
static void emitAmbiguousPortDomainAssociation(const DomainInfo &info, T op, const llvm::TinyPtrVector< DomainValue > &exports, DomainTypeID typeID, size_t i)
static LogicalResult processModulePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
static LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updates, FModuleOp moduleOp)
Solve for domains and then write the domain associations back to the IR.
static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info, const DomainTable &table, FModuleOp moduleOp)
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
llvm::MapVector< DomainValue, unsigned > PendingExports
A map from local domains to an aliasing port index, where that port has not yet been created.
static LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op)
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
static RowTerm * getDomainAssociationAsRow(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Value value)
Get the row of domains that a hardware value in the IR is associated with.
static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, DomainTypeID typeID, size_t i)
static void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, Namespace &ns, FModuleOp moduleOp, PendingUpdates &pending)
static T fixInstancePorts(T op, const ModuleUpdateInfo &update)
Apply the port changes of a moduleOp onto an instance-like op.
static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
After updating the port domain associations, walk the body of the moduleOp to fix up any child instan...
static void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, Term *term)
static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i)
From a domain info attribute, get the domain-type of a domain value at index i.
static void processDomainDefinition(TermAllocator &allocator, DomainTable &table, DomainValue domain)
static LogicalResult unify(Term *lhs, Term *rhs)
static LogicalResult updateModuleDomainInfo(const DomainInfo &info, const DomainTable &table, const ExportTable &exportTable, ArrayAttr &result, FModuleOp moduleOp)
After generalizing the moduleOp, all domains should be solved.
static LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op, Value lhs, Value rhs)
Unify the associated domain rows of two terms.
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, FModuleOp moduleOp)
Check that output domain ports are driven.
static SmallVector< Attribute > copyPortDomainAssociations(const DomainInfo &info, ArrayAttr moduleDomainInfo, size_t portIndex)
Copy the domain associations from the moduleOp domain info attribute into a small vector.
static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op)
static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, Term *term1, Term *term2)
static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, Term *term, const ExportTable &exports)
static void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator, ExportTable &exports, DomainTable &table, FModuleOp moduleOp, const PendingUpdates &pending)
static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, FModuleOp mod, PendingUpdates &pending)
static LogicalResult updateOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op)
static LogicalResult checkInstanceDomainPortDrivers(T op)
Check that the input domain ports are driven.
static void emitDomainPortInferenceError(T op, size_t i)
Emit an error when we fail to infer the concrete domain to drive to a domain port.
static void ensureExported(const DomainInfo &info, Namespace &ns, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending)
Ensure that the domain value is available in the signature of the moduleOp, so that subsequent hardwa...
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static Term * getTermForDomain(TermAllocator &allocator, DomainTable &table, DomainValue value)
Get the corresponding term for a domain in the IR.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
DenseMap< StringAttr, ModuleUpdateInfo > ModuleUpdateTable
static void ensureSolved(const DomainInfo &info, Namespace &ns, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending)
If var is not solved, solve it by recording a pending input port at the indicated insertion point.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static void solve(Term *lhs, Term *rhs)
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static LogicalResult checkModuleBody(FModuleOp moduleOp)
Check that instances under this module have driven domain input ports.
static LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike moduleOp)
Check that a module's hardware ports have complete domain associations.
static Term * find(Term *x)
static LogicalResult updateInstance(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult checkAndInferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult stripModule(FModuleLike op)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class represents a reference to a specific field or element of an aggregate value.
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
InferDomainsMode
The mode for the InferDomains pass.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.