21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/TinyPtrVector.h"
24#include "llvm/Support/Debug.h"
26#define DEBUG_TYPE "firrtl-infer-domains"
30#define GEN_PASS_DEF_INFERDOMAINS
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
36using namespace firrtl;
51 return cast<FlatSymbolRefAttr>(info[i]).getAttr();
58 return info.getAsRange<IntegerAttr>();
59 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
63static bool isPort(BlockArgument arg) {
64 return isa<FModuleOp>(arg.getOwner()->getParentOp());
69 auto arg = dyn_cast<BlockArgument>(value);
77 for (
auto *user : port.getUsers())
78 if (
auto connect = dyn_cast<FConnectLike>(user))
79 if (connect.getDest() == port)
104 DomainInfo(CircuitOp circuit) { processCircuit(circuit); }
106 ArrayRef<DomainOp> getDomains()
const {
return domainTable; }
107 size_t getNumDomains()
const {
return domainTable.size(); }
108 DomainOp getDomain(DomainTypeID
id)
const {
return domainTable[
id.index]; }
110 DomainTypeID getDomainTypeID(StringAttr name)
const {
111 return typeIDTable.at(name);
114 DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref)
const {
115 return getDomainTypeID(ref.getAttr());
118 DomainTypeID getDomainTypeID(ArrayAttr info,
size_t i)
const {
120 return getDomainTypeID(name);
123 DomainTypeID getDomainTypeID(
DomainValue value)
const {
124 if (
auto arg = dyn_cast<BlockArgument>(value)) {
125 auto *block = arg.getOwner();
126 auto *owner = block->getParentOp();
127 auto moduleOp = cast<FModuleOp>(owner);
128 auto info = moduleOp.getDomainInfoAttr();
129 auto i = arg.getArgNumber();
130 return getDomainTypeID(info, i);
133 auto result = dyn_cast<OpResult>(value);
134 auto *owner = result.getOwner();
136 auto info = TypeSwitch<Operation *, ArrayAttr>(owner)
137 .Case<InstanceOp, InstanceChoiceOp>(
138 [&](
auto inst) {
return inst.getDomainInfoAttr(); })
139 .Default([&](
auto inst) {
return nullptr; });
140 assert(info &&
"unable to obtain domain information from op");
142 auto i = result.getResultNumber();
143 return getDomainTypeID(info, i);
147 void processDomain(DomainOp op) {
148 auto index = domainTable.size();
149 auto name = op.getNameAttr();
150 domainTable.push_back(op);
151 typeIDTable.insert({name, {index}});
154 void processCircuit(CircuitOp circuit) {
155 for (
auto decl : circuit.getOps<DomainOp>())
160 SmallVector<DomainOp> domainTable;
163 DenseMap<StringAttr, DomainTypeID> typeIDTable;
168struct ModuleUpdateInfo {
170 ArrayAttr portDomainInfo;
181 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
182 clone.setDomainInfoAttr(update.portDomainInfo);
203 constexpr Term(TermKind kind) : kind(kind) {}
211struct TermBase : Term {
212 static bool classof(
const Term *term) {
return term->kind == K; }
213 TermBase() : Term(K) {}
219struct VariableTerm :
public TermBase<TermKind::Variable> {
220 VariableTerm() : leader(nullptr) {}
221 VariableTerm(Term *leader) : leader(leader) {}
228struct ValueTerm :
public TermBase<TermKind::Value> {
236struct RowTerm :
public TermBase<TermKind::Row> {
237 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
238 ArrayRef<Term *> elements;
247 if (
auto *var = dyn_cast<VariableTerm>(x)) {
248 if (var->leader ==
nullptr)
251 auto *leader =
find(var->leader);
252 if (leader != var->leader)
253 var->leader = leader;
262class VariableIDTable {
264 size_t get(VariableTerm *term) {
265 return table.insert({term, table.size() + 1}).first->second;
269 DenseMap<VariableTerm *, size_t> table;
274static void render(
const DomainInfo &info, Diagnostic &out,
275 VariableIDTable &idTable, Term *term) {
277 if (
auto *var = dyn_cast<VariableTerm>(term)) {
278 out <<
"?" << idTable.get(var);
281 if (
auto *val = dyn_cast<ValueTerm>(term)) {
282 auto value = val->value;
287 if (
auto *row = dyn_cast<RowTerm>(term)) {
290 for (
size_t i = 0, e = info.getNumDomains(); i < e; ++i) {
291 auto domainOp = info.getDomain(DomainTypeID{i});
296 out << domainOp.getName() <<
": ";
297 render(info, out, idTable, row->elements[i]);
304static LogicalResult
unify(Term *lhs, Term *rhs);
306static LogicalResult
unify(VariableTerm *x, Term *y) {
312static LogicalResult
unify(ValueTerm *xv, Term *y) {
313 if (
auto *yv = dyn_cast<VariableTerm>(y)) {
318 if (
auto *yv = dyn_cast<ValueTerm>(y))
319 return success(xv == yv);
325static LogicalResult
unify(RowTerm *lhsRow, Term *rhs) {
326 if (
auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
327 rhsVar->leader = lhsRow;
330 if (
auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
331 for (
auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
332 if (failed(
unify(x, y)))
340static LogicalResult
unify(Term *lhs, Term *rhs) {
347 if (
auto *lhsVar = dyn_cast<VariableTerm>(lhs))
348 return unify(lhsVar, rhs);
349 if (
auto *lhsVal = dyn_cast<ValueTerm>(lhs))
350 return unify(lhsVal, rhs);
351 if (
auto *lhsRow = dyn_cast<RowTerm>(lhs))
352 return unify(lhsRow, rhs);
356static void solve(Term *lhs, Term *rhs) {
357 [[maybe_unused]]
auto result =
unify(lhs, rhs);
358 assert(result.succeeded());
365 [[nodiscard]] RowTerm *allocRow(
size_t size) {
366 SmallVector<Term *> elements;
367 elements.resize(size);
368 return allocRow(elements);
372 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements) {
373 auto ds = allocArray(elements);
374 return alloc<RowTerm>(ds);
378 [[nodiscard]] VariableTerm *allocVar() {
return alloc<VariableTerm>(); }
381 [[nodiscard]] ValueTerm *allocVal(
DomainValue value) {
382 return alloc<ValueTerm>(value);
386 template <
typename T,
typename... Args>
387 [[nodiscard]] T *alloc(Args &&...args) {
388 static_assert(std::is_base_of_v<Term, T>,
"T must be a term");
389 return new (allocator) T(std::forward<Args>(args)...);
392 [[nodiscard]] ArrayRef<Term *> allocArray(ArrayRef<Term *> elements) {
393 auto size = elements.size();
397 auto *result = allocator.Allocate<Term *>(size);
398 llvm::uninitialized_copy(elements, result);
399 for (
size_t i = 0; i < size; ++i)
401 result[i] = alloc<VariableTerm>();
403 return ArrayRef(result, size);
406 llvm::BumpPtrAllocator allocator;
420 auto *term = getOptTermForDomain(value);
421 if (
auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
427 Term *getOptTermForDomain(
DomainValue value)
const {
428 assert(isa<DomainType>(value.getType()));
429 auto it = termTable.find(value);
430 if (it == termTable.end())
432 return find(it->second);
437 auto *term = getOptTermForDomain(value);
443 void setTermForDomain(
DomainValue value, Term *term) {
445 assert(!termTable.contains(value));
446 termTable.insert({value, term});
451 Term *getOptDomainAssociation(Value value)
const {
452 assert(isa<FIRRTLBaseType>(value.getType()));
453 auto it = associationTable.find(value);
454 if (it == associationTable.end())
456 return find(it->second);
461 Term *getDomainAssociation(Value value)
const {
462 auto *term = getOptDomainAssociation(value);
469 void setDomainAssociation(Value value, Term *term) {
470 assert(isa<FIRRTLBaseType>(value.getType()));
473 associationTable.insert({value, term});
478 DenseMap<Value, Term *> termTable;
481 DenseMap<Value, Term *> associationTable;
493 assert(isa<DomainType>(value.getType()));
494 if (
auto *term = table.getOptTermForDomain(value))
496 auto *term = allocator.allocVar();
497 table.setTermForDomain(value, term);
503 assert(isa<DomainType>(domain.getType()));
504 auto *newTerm = allocator.allocVal(domain);
505 auto *oldTerm = table.getOptTermForDomain(domain);
507 table.setTermForDomain(domain, newTerm);
511 [[maybe_unused]]
auto result =
unify(oldTerm, newTerm);
512 assert(result.succeeded());
518 TermAllocator &allocator,
519 DomainTable &table, Value value) {
520 assert(isa<FIRRTLBaseType>(value.getType()));
521 auto *term = table.getOptDomainAssociation(value);
525 auto *row = allocator.allocRow(info.getNumDomains());
526 table.setDomainAssociation(value, row);
531 if (
auto *row = dyn_cast<RowTerm>(term))
535 if (
auto *var = dyn_cast<VariableTerm>(term)) {
536 auto *row = allocator.allocRow(info.getNumDomains());
541 assert(
false &&
"unhandled term type");
545static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
546 auto ¬e = diag.attachNote(op->getLoc());
547 if (
auto mod = dyn_cast<FModuleOp>(op)) {
548 note <<
"in module " << mod.getModuleNameAttr();
551 if (
auto mod = dyn_cast<FExtModuleOp>(op)) {
552 note <<
"in extmodule " << mod.getModuleNameAttr();
555 if (
auto inst = dyn_cast<InstanceOp>(op)) {
556 note <<
"in instance " << inst.getInstanceNameAttr();
559 if (
auto inst = dyn_cast<InstanceChoiceOp>(op)) {
560 note <<
"in instance_choice " << inst.getNameAttr();
569 DomainTypeID domainTypeID, Term *term1,
571 VariableIDTable idTable;
573 auto portName = op.getPortNameAttr(i);
574 auto portLoc = op.getPortLocation(i);
575 auto domainDecl = info.getDomain(domainTypeID);
576 auto domainName = domainDecl.getNameAttr();
578 auto diag = emitError(portLoc);
579 diag <<
"illegal " << domainName <<
" crossing in port " << portName;
581 auto ¬e1 = diag.attachNote();
582 note1 <<
"1st instance: ";
583 render(info, note1, idTable, term1);
585 auto ¬e2 = diag.attachNote();
586 note2 <<
"2nd instance: ";
587 render(info, note2, idTable, term2);
594 DomainTypeID domainTypeID,
595 IntegerAttr domainPortIndexAttr1,
596 IntegerAttr domainPortIndexAttr2) {
597 VariableIDTable idTable;
598 auto portName = op.getPortNameAttr(i);
599 auto portLoc = op.getPortLocation(i);
600 auto domainDecl = info.getDomain(domainTypeID);
601 auto domainName = domainDecl.getNameAttr();
602 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
603 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
604 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
605 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
606 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
607 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
608 auto diag = emitError(portLoc);
609 diag <<
"duplicate " << domainName <<
" association for port " << portName;
610 auto ¬e1 = diag.attachNote(domainPortLoc1);
611 note1 <<
"associated with " << domainName <<
" port " << domainPortName1;
612 auto ¬e2 = diag.attachNote(domainPortLoc2);
613 note2 <<
"associated with " << domainName <<
" port " << domainPortName2;
621 auto name = op.getPortNameAttr(i);
622 auto diag = emitError(op->getLoc());
623 auto info = op.getDomainInfo();
624 diag <<
"unable to infer value for undriven domain port " << name;
625 for (
size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
626 if (
auto assocs = dyn_cast<ArrayAttr>(info[j])) {
627 for (
auto assoc : assocs) {
628 if (i == cast<IntegerAttr>(assoc).getValue()) {
629 auto name = op.getPortNameAttr(j);
630 auto loc = op.getPortLocation(j);
631 diag.attachNote(loc) <<
"associated with hardware port " << name;
642 const DomainInfo &info, T op,
643 const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
645 auto portName = op.getPortNameAttr(i);
646 auto portLoc = op.getPortLocation(i);
647 auto domainDecl = info.getDomain(typeID);
648 auto domainName = domainDecl.getNameAttr();
649 auto diag = emitError(portLoc) <<
"ambiguous " << domainName
650 <<
" association for port " << portName;
651 for (
auto e : exports) {
652 auto arg = cast<BlockArgument>(e);
653 auto name = op.getPortNameAttr(arg.getArgNumber());
654 auto loc = op.getPortLocation(arg.getArgNumber());
655 diag.attachNote(loc) <<
"candidate association " << name;
664 auto domainName = info.getDomain(typeID).getNameAttr();
665 auto portName = op.getPortNameAttr(i);
666 auto diag = emitError(op.getPortLocation(i))
667 <<
"missing " << domainName <<
" association for port "
674 TermAllocator &allocator,
675 DomainTable &table, Operation *op,
676 Value lhs, Value rhs) {
683 auto *lhsTerm = table.getOptDomainAssociation(lhs);
684 auto *rhsTerm = table.getOptDomainAssociation(rhs);
688 if (failed(
unify(lhsTerm, rhsTerm))) {
689 auto diag = op->emitOpError(
"illegal domain crossing in operation");
690 auto ¬e1 = diag.attachNote(lhs.getLoc());
692 note1 <<
"1st operand has domains: ";
693 VariableIDTable idTable;
694 render(info, note1, idTable, lhsTerm);
696 auto ¬e2 = diag.attachNote(rhs.getLoc());
697 note2 <<
"2nd operand has domains: ";
698 render(info, note2, idTable, rhsTerm);
703 table.setDomainAssociation(rhs, lhsTerm);
708 table.setDomainAssociation(lhs, rhsTerm);
712 auto *var = allocator.allocVar();
713 table.setDomainAssociation(lhs, var);
714 table.setDomainAssociation(rhs, var);
719 TermAllocator &allocator,
721 FModuleOp moduleOp) {
722 auto numDomains = info.getNumDomains();
723 auto domainInfo = moduleOp.getDomainInfoAttr();
724 auto numPorts = moduleOp.getNumPorts();
726 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
727 for (
size_t i = 0; i < numPorts; ++i) {
728 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
732 if (moduleOp.getPortDirection(i) == Direction::In)
735 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
738 for (
size_t i = 0; i < numPorts; ++i) {
739 BlockArgument port = moduleOp.getArgument(i);
740 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
744 SmallVector<IntegerAttr> associations(numDomains);
746 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
747 auto prevDomainPortIndex = associations[domainTypeID.index];
748 if (prevDomainPortIndex) {
750 prevDomainPortIndex, domainPortIndex);
753 associations[domainTypeID.index] = domainPortIndex;
756 SmallVector<Term *> elements(numDomains);
757 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
759 auto domainPortIndex = associations[domainTypeIndex];
760 if (!domainPortIndex)
762 auto domainPortValue =
763 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
764 elements[domainTypeIndex] =
768 auto *domainAssociations = allocator.allocRow(elements);
769 table.setDomainAssociation(port, domainAssociations);
777 TermAllocator &allocator,
778 DomainTable &table, T op) {
779 auto numDomains = info.getNumDomains();
780 auto domainInfo = op.getDomainInfoAttr();
781 auto numPorts = op.getNumPorts();
783 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
784 for (
size_t i = 0; i < numPorts; ++i) {
785 auto port = dyn_cast<DomainValue>(op.getResult(i));
789 if (op.getPortDirection(i) == Direction::Out)
792 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
795 for (
size_t i = 0; i < numPorts; ++i) {
796 Value port = op.getResult(i);
797 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
801 SmallVector<IntegerAttr> associations(numDomains);
803 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
804 auto prevDomainPortIndex = associations[domainTypeID.index];
805 if (prevDomainPortIndex) {
807 prevDomainPortIndex, domainPortIndex);
810 associations[domainTypeID.index] = domainPortIndex;
813 SmallVector<Term *> elements(numDomains);
814 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
816 auto domainPortIndex = associations[domainTypeIndex];
817 if (!domainPortIndex)
819 auto domainPortValue =
820 cast<DomainValue>(op.getResult(domainPortIndex.getUInt()));
821 elements[domainTypeIndex] =
825 auto *domainAssociations = allocator.allocRow(elements);
826 table.setDomainAssociation(port, domainAssociations);
832static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
836 auto moduleOp = op.getReferencedModuleNameAttr();
837 auto lookup = updateTable.find(moduleOp);
838 if (lookup != updateTable.end())
843static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
846 InstanceChoiceOp op) {
847 auto moduleOp = op.getDefaultTargetAttr().getAttr();
848 auto lookup = updateTable.find(moduleOp);
849 if (lookup != updateTable.end())
854static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
855 DomainTable &table, UnsafeDomainCastOp op) {
856 auto domains = op.getDomains();
861 auto input = op.getInput();
863 SmallVector<Term *> elements(inputRow->elements);
864 for (
auto value : op.getDomains()) {
865 auto domain = cast<DomainValue>(value);
866 auto typeID = info.getDomainTypeID(domain);
870 auto *row = allocator.allocRow(elements);
871 table.setDomainAssociation(op.getResult(), row);
875static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
876 DomainTable &table, DomainDefineOp op) {
877 auto src = op.getSrc();
878 auto dst = op.getDest();
881 if (succeeded(
unify(dstTerm, srcTerm)))
884 VariableIDTable idTable;
885 auto diag = op->emitOpError(
"failed to propagate source to destination");
886 auto ¬e1 = diag.attachNote();
887 note1 <<
"destination has underlying value: ";
888 render(info, note1, idTable, dstTerm);
890 auto ¬e2 = diag.attachNote(src.getLoc());
891 note2 <<
"source has underlying value: ";
892 render(info, note2, idTable, srcTerm);
896static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
900 if (
auto instance = dyn_cast<InstanceOp>(op))
901 return processOp(info, allocator, table, updateTable, instance);
902 if (
auto instance = dyn_cast<InstanceChoiceOp>(op))
903 return processOp(info, allocator, table, updateTable, instance);
904 if (
auto cast = dyn_cast<UnsafeDomainCastOp>(op))
905 return processOp(info, allocator, table, cast);
906 if (
auto def = dyn_cast<DomainDefineOp>(op))
907 return processOp(info, allocator, table, def);
913 for (
auto rhs : op->getOperands()) {
914 if (!isa<FIRRTLBaseType>(rhs.getType()))
916 if (
auto *op = rhs.getDefiningOp();
917 op && op->hasTrait<OpTrait::ConstantLike>())
923 for (
auto rhs : op->getResults()) {
924 if (!isa<FIRRTLBaseType>(rhs.getType()))
926 if (
auto *op = rhs.getDefiningOp();
927 op && op->hasTrait<OpTrait::ConstantLike>())
937 TermAllocator &allocator,
940 FModuleOp moduleOp) {
941 auto result = moduleOp.getBody().walk([&](Operation *op) -> WalkResult {
942 return processOp(info, allocator, table, updateTable, op);
944 return failure(result.wasInterrupted());
950 TermAllocator &allocator, DomainTable &table,
952 FModuleOp moduleOp) {
967using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
972 FModuleOp moduleOp) {
974 size_t numPorts = moduleOp.getNumPorts();
975 for (
size_t i = 0; i < numPorts; ++i) {
976 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
979 auto value = table.getOptUnderlyingDomain(port);
981 exports[value].push_back(port);
1002struct PendingUpdates {
1012 DomainTypeID typeID,
size_t ip, LocationAttr loc,
1013 VariableTerm *var, PendingUpdates &pending) {
1014 if (pending.solutions.contains(var))
1017 auto *
context = loc.getContext();
1018 auto domainDecl = info.getDomain(typeID);
1019 auto domainName = domainDecl.getNameAttr();
1021 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1022 auto portType = DomainType::get(loc.getContext());
1023 auto portDirection = Direction::In;
1024 auto portSym = StringAttr();
1026 auto portAnnos = std::nullopt;
1027 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1028 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1029 portAnnos, portDomainInfo);
1031 pending.solutions[var] = pending.insertions.size() + ip;
1032 pending.insertions.push_back({ip, portInfo});
1042 size_t ip, LocationAttr loc, ValueTerm *val,
1043 PendingUpdates &pending) {
1044 auto value = val->value;
1045 assert(isa<DomainType>(value.getType()));
1046 if (
isPort(value) || exports.contains(value) ||
1047 pending.exports.contains(value))
1050 auto *
context = loc.getContext();
1052 auto domainDecl = info.getDomain(typeID);
1053 auto domainName = domainDecl.getNameAttr();
1055 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1056 auto portType = DomainType::get(loc.getContext());
1057 auto portDirection = Direction::Out;
1058 auto portSym = StringAttr();
1059 auto portLoc = value.getLoc();
1060 auto portAnnos = std::nullopt;
1061 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1062 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1063 portAnnos, portDomainInfo);
1064 pending.exports[value] = pending.insertions.size() + ip;
1065 pending.insertions.push_back({ip, portInfo});
1070 PendingUpdates &pending,
1071 DomainTypeID typeID,
size_t ip,
1072 LocationAttr loc, Term *term,
1074 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1078 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1079 ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
1082 llvm_unreachable(
"invalid domain association");
1087 size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
1088 for (
auto [index, term] : llvm::enumerate(row->elements))
1090 ip, loc,
find(term), exports);
1094 TermAllocator &allocator,
1098 PendingUpdates &pending) {
1099 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1100 auto port = moduleOp.getArgument(i);
1101 auto type = port.getType();
1102 if (!isa<FIRRTLBaseType>(type))
1105 info, ns, exports, i, moduleOp.getPortLocation(i),
1111 TermAllocator &allocator,
1113 FModuleOp mod, PendingUpdates &pending) {
1115 auto names = mod.getPortNamesAttr();
1116 for (
auto name : names.getAsRange<StringAttr>())
1123 DomainTable &table, FModuleOp moduleOp,
1124 const PendingUpdates &pending) {
1126 moduleOp.insertPorts(pending.insertions);
1129 for (
auto [var, portIndex] : pending.solutions) {
1130 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1131 auto *solution = allocator.allocVal(portValue);
1132 solve(var, solution);
1133 exports[portValue].push_back(portValue);
1137 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1138 for (
auto [domainValue, portIndex] : pending.exports) {
1139 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1140 builder.setInsertionPointAfterValue(domainValue);
1141 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1143 exports[domainValue].push_back(portValue);
1144 table.setTermForDomain(portValue, allocator.allocVal(domainValue));
1150static SmallVector<Attribute>
1153 SmallVector<Attribute> result(info.getNumDomains());
1155 for (
auto domainPortIndexAttr : oldAssociations) {
1157 auto domainPortIndex = domainPortIndexAttr.getUInt();
1158 auto domainTypeID = info.getDomainTypeID(moduleDomainInfo, domainPortIndex);
1159 result[domainTypeID.index] = domainPortIndexAttr;
1167 const DomainTable &table,
1168 FModuleOp moduleOp) {
1169 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1170 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1171 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1172 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1176 auto *term = table.getOptTermForDomain(port);
1177 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1183 auto loc = port.getLoc();
1184 auto value = val->value;
1185 DomainDefineOp::create(builder, loc, port, value);
1194 const DomainTable &table,
1197 FModuleOp moduleOp) {
1202 auto *
context = moduleOp.getContext();
1203 auto numDomains = info.getNumDomains();
1204 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1205 auto numPorts = moduleOp.getNumPorts();
1206 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1208 for (
size_t i = 0; i < numPorts; ++i) {
1209 auto port = moduleOp.getArgument(i);
1210 auto type = port.getType();
1212 if (isa<DomainType>(type)) {
1213 newModuleDomainInfo[i] = oldModuleDomainInfo[i];
1217 if (isa<FIRRTLBaseType>(type)) {
1220 auto *row = cast<RowTerm>(table.getDomainAssociation(port));
1221 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1222 auto domainTypeID = DomainTypeID{domainIndex};
1223 if (associations[domainIndex])
1226 auto domain = cast<ValueTerm>(
find(row->elements[domainIndex]))->value;
1227 auto &exports = exportTable.at(domain);
1228 if (exports.empty()) {
1229 auto portName = moduleOp.getPortNameAttr(i);
1230 auto portLoc = moduleOp.getPortLocation(i);
1231 auto domainDecl = info.getDomain(domainTypeID);
1232 auto domainName = domainDecl.getNameAttr();
1233 auto diag = emitError(portLoc)
1234 <<
"private " << domainName <<
" association for port "
1236 diag.attachNote(domain.getLoc()) <<
"associated domain: " << domain;
1241 if (exports.size() > 1) {
1247 auto argument = cast<BlockArgument>(exports[0]);
1248 auto domainPortIndex = argument.getArgNumber();
1249 associations[domainTypeID.index] = IntegerAttr::get(
1250 IntegerType::get(
context, 32, IntegerType::Unsigned),
1254 newModuleDomainInfo[i] = ArrayAttr::get(
context, associations);
1258 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1261 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1262 moduleOp.setDomainInfoAttr(result);
1266template <
typename T>
1268 TermAllocator &allocator,
1269 DomainTable &table, T op) {
1270 OpBuilder builder(op.getContext());
1271 builder.setInsertionPointAfter(op);
1272 auto numPorts = op->getNumResults();
1273 for (
size_t i = 0; i < numPorts; ++i) {
1274 auto port = dyn_cast<DomainValue>(op.getResult(i));
1275 auto direction = op.getPortDirection(i);
1280 if (port && direction == Direction::In && !
isDriven(port)) {
1281 auto loc = port.getLoc();
1283 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1285 auto anon = DomainCreateAnonOp::create(builder, loc, name);
1286 solve(var, allocator.allocVal(anon));
1287 DomainDefineOp::create(builder, loc, port, anon);
1290 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1291 auto value = val->value;
1292 DomainDefineOp::create(builder, loc, port, value);
1295 llvm_unreachable(
"unhandled domain term type");
1302static LogicalResult
updateOp(
const DomainInfo &info, TermAllocator &allocator,
1303 DomainTable &table, Operation *op) {
1304 if (
auto instance = dyn_cast<InstanceOp>(op))
1306 if (
auto instance = dyn_cast<InstanceChoiceOp>(op))
1314 TermAllocator &allocator,
1315 DomainTable &table, FModuleOp moduleOp) {
1316 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1317 return updateOp(info, allocator, table, op);
1319 return failure(result.wasInterrupted());
1324 TermAllocator &allocator, DomainTable &table,
1327 PendingUpdates pending;
1332 ArrayAttr portDomainInfo;
1341 auto &entry = updates[op.getModuleNameAttr()];
1342 entry.portDomainInfo = portDomainInfo;
1343 entry.portInsertions = std::move(pending.insertions);
1357 FModuleLike moduleOp) {
1358 auto numDomains = info.getNumDomains();
1359 auto domainInfo = moduleOp.getDomainInfoAttr();
1360 auto numPorts = moduleOp.getNumPorts();
1362 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1363 for (
size_t i = 0; i < numPorts; ++i) {
1364 if (isa<DomainType>(moduleOp.getPortType(i)))
1365 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
1368 for (
size_t i = 0; i < numPorts; ++i) {
1369 auto type = type_dyn_cast<FIRRTLBaseType>(moduleOp.getPortType(i));
1374 SmallVector<IntegerAttr> associations(numDomains);
1376 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1377 auto prevDomainPortIndex = associations[domainTypeID.index];
1378 if (prevDomainPortIndex) {
1380 prevDomainPortIndex, domainPortIndex);
1383 associations[domainTypeID.index] = domainPortIndex;
1387 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1388 auto typeID = DomainTypeID{domainIndex};
1389 if (!associations[domainIndex]) {
1401 FModuleOp moduleOp) {
1402 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1403 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1404 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1408 auto name = moduleOp.getPortNameAttr(i);
1409 auto diag = emitError(moduleOp.getPortLocation(i))
1410 <<
"undriven domain port " << name;
1419template <
typename T>
1421 for (
size_t i = 0, e = op.getNumResults(); i < e; ++i) {
1422 auto port = dyn_cast<DomainValue>(op.getResult(i));
1423 auto type = port.getType();
1424 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1428 auto name = op.getPortNameAttr(i);
1429 auto diag = emitError(op.getPortLocation(i))
1430 <<
"undriven domain port " << name;
1439 if (
auto inst = dyn_cast<InstanceOp>(op))
1441 if (
auto inst = dyn_cast<InstanceChoiceOp>(op))
1448 auto result = moduleOp.getBody().walk(
1449 [&](Operation *op) -> WalkResult {
return checkOp(op); });
1450 return failure(result.wasInterrupted());
1460 FModuleOp moduleOp) {
1461 TermAllocator allocator;
1464 if (failed(
processModule(info, allocator, table, updates, moduleOp)))
1467 return updateModule(info, allocator, table, updates, moduleOp);
1472static LogicalResult
checkModule(
const DomainInfo &info, FModuleOp moduleOp) {
1482 TermAllocator allocator;
1485 return processModule(info, allocator, table, updateTable, moduleOp);
1490 FExtModuleOp moduleOp) {
1499 FModuleOp moduleOp) {
1503 TermAllocator allocator;
1505 if (failed(
processModule(info, allocator, table, updateTable, moduleOp)))
1515 const DomainInfo &info,
1518 if (
auto moduleOp = dyn_cast<FModuleOp>(op)) {
1519 if (mode == InferDomainsMode::Check)
1522 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1528 if (
auto extModule = dyn_cast<FExtModuleOp>(op))
1535struct InferDomainsPass
1536 :
public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1538 void runOnOperation()
override {
1540 auto circuit = getOperation();
1541 auto &instanceGraph = getAnalysis<InstanceGraph>();
1542 DomainInfo
info(circuit);
1544 DenseSet<Operation *> errored;
1545 instanceGraph.walkPostOrder([&](
auto &node) {
1546 auto moduleOp = node.getModule();
1547 for (
auto *inst : node) {
1548 if (errored.contains(inst->getTarget()->getModule())) {
1549 errored.insert(moduleOp);
1553 if (failed(
runOnModuleLike(mode, info, updateTable, node.getModule())))
1554 errored.insert(moduleOp);
1557 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult checkOp(Operation *op)
static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, InstanceOp op)
static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op)
Write the domain associations recorded in the domain table back to the IR.
static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1, IntegerAttr domainPortIndexAttr2)
static ExportTable initializeExportTable(const DomainTable &table, FModuleOp moduleOp)
Build a table of exported domains: a map from domains defined internally, to their set of aliasing ou...
static LogicalResult processModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
static void emitAmbiguousPortDomainAssociation(const DomainInfo &info, T op, const llvm::TinyPtrVector< DomainValue > &exports, DomainTypeID typeID, size_t i)
static LogicalResult processModulePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
static LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updates, FModuleOp moduleOp)
Solve for domains and then write the domain associations back to the IR.
static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info, const DomainTable &table, FModuleOp moduleOp)
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
llvm::MapVector< DomainValue, unsigned > PendingExports
A map from local domains to an aliasing port index, where that port has not yet been created.
static LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op)
mlir::TypedValue< DomainType > DomainValue
static RowTerm * getDomainAssociationAsRow(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Value value)
Get the row of domains that a hardware value in the IR is associated with.
static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, DomainTypeID typeID, size_t i)
static void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, Namespace &ns, FModuleOp moduleOp, PendingUpdates &pending)
static T fixInstancePorts(T op, const ModuleUpdateInfo &update)
Apply the port changes of a moduleOp onto an instance-like op.
static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
After updating the port domain associations, walk the body of the moduleOp to fix up any child instances.
static void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, Term *term)
static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i)
From a domain info attribute, get the domain-type of a domain value at index i.
static void processDomainDefinition(TermAllocator &allocator, DomainTable &table, DomainValue domain)
static LogicalResult unify(Term *lhs, Term *rhs)
static LogicalResult updateModuleDomainInfo(const DomainInfo &info, const DomainTable &table, const ExportTable &exportTable, ArrayAttr &result, FModuleOp moduleOp)
After generalizing the moduleOp, all domains should be solved.
static LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op, Value lhs, Value rhs)
Unify the associated domain rows of two values.
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, FModuleOp moduleOp)
Check that output domain ports are driven.
static SmallVector< Attribute > copyPortDomainAssociations(const DomainInfo &info, ArrayAttr moduleDomainInfo, size_t portIndex)
Copy the domain associations from the moduleOp domain info attribute into a small vector.
static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op)
static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, Term *term1, Term *term2)
static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, Term *term, const ExportTable &exports)
static void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator, ExportTable &exports, DomainTable &table, FModuleOp moduleOp, const PendingUpdates &pending)
static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, FModuleOp mod, PendingUpdates &pending)
static LogicalResult updateOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op)
static LogicalResult checkInstanceDomainPortDrivers(T op)
Check that the input domain ports are driven.
static void emitDomainPortInferenceError(T op, size_t i)
Emit an error when we fail to infer the concrete domain to drive to a domain port.
static void ensureExported(const DomainInfo &info, Namespace &ns, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending)
Ensure that the domain value is available in the signature of the moduleOp, so that subsequent hardware ports can be associated with it.
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internally to the moduleOp, to ports that alias that domain.
static Term * getTermForDomain(TermAllocator &allocator, DomainTable &table, DomainValue value)
Get the corresponding term for a domain in the IR.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
DenseMap< StringAttr, ModuleUpdateInfo > ModuleUpdateTable
static void ensureSolved(const DomainInfo &info, Namespace &ns, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending)
If var is not solved, solve it by recording a pending input port at the indicated insertion point.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static void solve(Term *lhs, Term *rhs)
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static LogicalResult checkModuleBody(FModuleOp moduleOp)
Check that instances under this module have driven domain input ports.
static LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike moduleOp)
Check that a module's hardware ports have complete domain associations.
static Term * find(Term *x)
static LogicalResult updateInstance(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult checkAndInferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Check that a module's ports are fully annotated, then perform domain inference on the module.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class represents a reference to a specific field or element of an aggregate value.
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
InferDomainsMode
The mode for the InferDomains pass.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.