26#include "mlir/IR/Iterators.h"
27#include "mlir/IR/Threading.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/TinyPtrVector.h"
33#define DEBUG_TYPE "firrtl-infer-domains"
37#define GEN_PASS_DEF_INFERDOMAINS
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
43using namespace firrtl;
46using mlir::ReverseIterator;
60 return info.getAsRange<IntegerAttr>();
61 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
65static bool isPort(BlockArgument arg) {
66 return isa<FModuleOp>(arg.getOwner()->getParentOp());
71 auto arg = dyn_cast<BlockArgument>(value);
79 for (
auto *user : port.getUsers())
80 if (
auto connect = dyn_cast<FConnectLike>(user))
81 if (connect.getDest() == port)
106 DomainInfo(CircuitOp circuit) { processCircuit(circuit); }
108 ArrayRef<DomainOp> getDomains()
const {
return domainTable; }
109 size_t getNumDomains()
const {
return domainTable.size(); }
110 DomainOp getDomain(DomainTypeID
id)
const {
return domainTable[
id.index]; }
112 DomainTypeID getDomainTypeID(Type type)
const {
return typeIDTable.at(type); }
114 DomainTypeID getDomainTypeID(FModuleLike module,
size_t i)
const {
115 return getDomainTypeID(module.getPortType(i));
118 DomainTypeID getDomainTypeID(FInstanceLike op,
size_t i)
const {
119 return getDomainTypeID(op->getResult(i).getType());
122 DomainTypeID getDomainTypeID(
DomainValue value)
const {
123 return getDomainTypeID(value.getType());
127 void processDomain(DomainOp op) {
128 auto index = domainTable.size();
129 auto domainType = DomainType::getFromDomainOp(op);
130 domainTable.push_back(op);
131 typeIDTable.insert({domainType, {index}});
134 void processCircuit(CircuitOp circuit) {
135 for (
auto decl : circuit.getOps<DomainOp>())
140 SmallVector<DomainOp> domainTable;
143 DenseMap<Type, DomainTypeID> typeIDTable;
148struct ModuleUpdateInfo {
150 ArrayAttr portDomainInfo;
160 const ModuleUpdateInfo &update) {
161 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
162 clone.setDomainInfoAttr(update.portDomainInfo);
183 constexpr Term(TermKind kind) : kind(kind) {}
191struct TermBase : Term {
192 static bool classof(
const Term *term) {
return term->kind == K; }
193 TermBase() : Term(K) {}
199struct VariableTerm :
public TermBase<TermKind::Variable> {
200 VariableTerm() : leader(nullptr) {}
201 VariableTerm(Term *leader) : leader(leader) {}
208struct ValueTerm :
public TermBase<TermKind::Value> {
216struct RowTerm :
public TermBase<TermKind::Row> {
217 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
218 ArrayRef<Term *> elements;
227 if (
auto *var = dyn_cast<VariableTerm>(x)) {
228 if (var->leader ==
nullptr)
231 auto *leader =
find(var->leader);
232 if (leader != var->leader)
233 var->leader = leader;
242class VariableIDTable {
244 size_t get(VariableTerm *term) {
245 return table.insert({term, table.size() + 1}).first->second;
249 DenseMap<VariableTerm *, size_t> table;
254static void render(
const DomainInfo &info, Diagnostic &out,
255 VariableIDTable &idTable, Term *term) {
257 if (
auto *var = dyn_cast<VariableTerm>(term)) {
258 out <<
"?" << idTable.get(var);
261 if (
auto *val = dyn_cast<ValueTerm>(term)) {
262 auto value = val->value;
267 if (
auto *row = dyn_cast<RowTerm>(term)) {
270 for (
size_t i = 0, e = info.getNumDomains(); i < e; ++i) {
271 auto domainOp = info.getDomain(DomainTypeID{i});
276 out << domainOp.getName() <<
": ";
277 render(info, out, idTable, row->elements[i]);
284static LogicalResult
unify(Term *lhs, Term *rhs);
286static LogicalResult
unify(VariableTerm *x, Term *y) {
292static LogicalResult
unify(ValueTerm *xv, Term *y) {
293 if (
auto *yv = dyn_cast<VariableTerm>(y)) {
298 if (
auto *yv = dyn_cast<ValueTerm>(y))
299 return success(xv == yv);
305static LogicalResult
unify(RowTerm *lhsRow, Term *rhs) {
306 if (
auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
307 rhsVar->leader = lhsRow;
310 if (
auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
311 for (
auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
312 if (failed(
unify(x, y)))
320static LogicalResult
unify(Term *lhs, Term *rhs) {
327 if (
auto *lhsVar = dyn_cast<VariableTerm>(lhs))
328 return unify(lhsVar, rhs);
329 if (
auto *lhsVal = dyn_cast<ValueTerm>(lhs))
330 return unify(lhsVal, rhs);
331 if (
auto *lhsRow = dyn_cast<RowTerm>(lhs))
332 return unify(lhsRow, rhs);
336static void solve(Term *lhs, Term *rhs) {
337 [[maybe_unused]]
auto result =
unify(lhs, rhs);
338 assert(result.succeeded());
345 [[nodiscard]] RowTerm *allocRow(
size_t size) {
346 SmallVector<Term *> elements;
347 elements.resize(size);
348 return allocRow(elements);
352 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements) {
353 auto ds = allocArray(elements);
354 return alloc<RowTerm>(ds);
358 [[nodiscard]] VariableTerm *allocVar() {
return alloc<VariableTerm>(); }
361 [[nodiscard]] ValueTerm *allocVal(
DomainValue value) {
362 return alloc<ValueTerm>(value);
366 template <
typename T,
typename... Args>
367 [[nodiscard]] T *alloc(Args &&...args) {
368 static_assert(std::is_base_of_v<Term, T>,
"T must be a term");
369 return new (allocator) T(std::forward<Args>(args)...);
372 [[nodiscard]] ArrayRef<Term *> allocArray(ArrayRef<Term *> elements) {
373 auto size = elements.size();
377 auto *result = allocator.Allocate<Term *>(size);
378 llvm::uninitialized_copy(elements, result);
379 for (
size_t i = 0; i < size; ++i)
381 result[i] = alloc<VariableTerm>();
383 return ArrayRef(result, size);
386 llvm::BumpPtrAllocator allocator;
400 auto *term = getOptTermForDomain(value);
401 if (
auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
407 Term *getOptTermForDomain(
DomainValue value)
const {
408 assert(isa<DomainType>(value.getType()));
409 auto it = termTable.find(value);
410 if (it == termTable.end())
412 return find(it->second);
417 auto *term = getOptTermForDomain(value);
423 void setTermForDomain(
DomainValue value, Term *term) {
425 assert(!termTable.contains(value));
426 termTable.insert({value, term});
431 Term *getOptDomainAssociation(Value value)
const {
432 assert(isa<FIRRTLBaseType>(value.getType()));
433 auto it = associationTable.find(value);
434 if (it == associationTable.end())
436 return find(it->second);
441 Term *getDomainAssociation(Value value)
const {
442 auto *term = getOptDomainAssociation(value);
449 void setDomainAssociation(Value value, Term *term) {
450 assert(isa<FIRRTLBaseType>(value.getType()));
453 associationTable.insert({value, term});
458 DenseMap<Value, Term *> termTable;
461 DenseMap<Value, Term *> associationTable;
473 assert(isa<DomainType>(value.getType()));
474 if (
auto *term = table.getOptTermForDomain(value))
476 auto *term = allocator.allocVar();
477 table.setTermForDomain(value, term);
483 assert(isa<DomainType>(domain.getType()));
484 auto *newTerm = allocator.allocVal(domain);
485 auto *oldTerm = table.getOptTermForDomain(domain);
487 table.setTermForDomain(domain, newTerm);
491 [[maybe_unused]]
auto result =
unify(oldTerm, newTerm);
492 assert(result.succeeded());
498 TermAllocator &allocator,
499 DomainTable &table, Value value) {
500 assert(isa<FIRRTLBaseType>(value.getType()));
501 auto *term = table.getOptDomainAssociation(value);
505 auto *row = allocator.allocRow(info.getNumDomains());
506 table.setDomainAssociation(value, row);
511 if (
auto *row = dyn_cast<RowTerm>(term))
515 if (
auto *var = dyn_cast<VariableTerm>(term)) {
516 auto *row = allocator.allocRow(info.getNumDomains());
521 assert(
false &&
"unhandled term type");
525static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
526 auto ¬e = diag.attachNote(op->getLoc());
527 if (
auto mod = dyn_cast<FModuleOp>(op)) {
528 note <<
"in module " << mod.getModuleNameAttr();
531 if (
auto mod = dyn_cast<FExtModuleOp>(op)) {
532 note <<
"in extmodule " << mod.getModuleNameAttr();
535 if (
auto inst = dyn_cast<InstanceOp>(op)) {
536 note <<
"in instance " << inst.getInstanceNameAttr();
539 if (
auto inst = dyn_cast<InstanceChoiceOp>(op)) {
540 note <<
"in instance_choice " << inst.getNameAttr();
549 DomainTypeID domainTypeID, Term *term1,
551 VariableIDTable idTable;
553 auto portName = op.getPortNameAttr(i);
554 auto portLoc = op.getPortLocation(i);
555 auto domainDecl = info.getDomain(domainTypeID);
556 auto domainName = domainDecl.getNameAttr();
558 auto diag = emitError(portLoc);
559 diag <<
"illegal " << domainName <<
" crossing in port " << portName;
561 auto ¬e1 = diag.attachNote();
562 note1 <<
"1st instance: ";
563 render(info, note1, idTable, term1);
565 auto ¬e2 = diag.attachNote();
566 note2 <<
"2nd instance: ";
567 render(info, note2, idTable, term2);
574 DomainTypeID domainTypeID,
575 IntegerAttr domainPortIndexAttr1,
576 IntegerAttr domainPortIndexAttr2) {
577 VariableIDTable idTable;
578 auto portName = op.getPortNameAttr(i);
579 auto portLoc = op.getPortLocation(i);
580 auto domainDecl = info.getDomain(domainTypeID);
581 auto domainName = domainDecl.getNameAttr();
582 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
583 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
584 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
585 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
586 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
587 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
588 auto diag = emitError(portLoc);
589 diag <<
"duplicate " << domainName <<
" association for port " << portName;
590 auto ¬e1 = diag.attachNote(domainPortLoc1);
591 note1 <<
"associated with " << domainName <<
" port " << domainPortName1;
592 auto ¬e2 = diag.attachNote(domainPortLoc2);
593 note2 <<
"associated with " << domainName <<
" port " << domainPortName2;
601 auto name = op.getPortNameAttr(i);
602 auto diag = emitError(op->getLoc());
603 auto info = op.getDomainInfo();
604 diag <<
"unable to infer value for undriven domain port " << name;
605 for (
size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
606 if (
auto assocs = dyn_cast<ArrayAttr>(info[j])) {
607 for (
auto assoc : assocs) {
608 if (i == cast<IntegerAttr>(assoc).getValue()) {
609 auto name = op.getPortNameAttr(j);
610 auto loc = op.getPortLocation(j);
611 diag.attachNote(loc) <<
"associated with hardware port " << name;
622 const DomainInfo &info, T op,
623 const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
625 auto portName = op.getPortNameAttr(i);
626 auto portLoc = op.getPortLocation(i);
627 auto domainDecl = info.getDomain(typeID);
628 auto domainName = domainDecl.getNameAttr();
629 auto diag = emitError(portLoc) <<
"ambiguous " << domainName
630 <<
" association for port " << portName;
631 for (
auto e : exports) {
632 auto arg = cast<BlockArgument>(e);
633 auto name = op.getPortNameAttr(arg.getArgNumber());
634 auto loc = op.getPortLocation(arg.getArgNumber());
635 diag.attachNote(loc) <<
"candidate association " << name;
644 auto domainName = info.getDomain(typeID).getNameAttr();
645 auto portName = op.getPortNameAttr(i);
646 auto diag = emitError(op.getPortLocation(i))
647 <<
"missing " << domainName <<
" association for port "
654 TermAllocator &allocator,
655 DomainTable &table, Operation *op,
656 Value lhs, Value rhs) {
663 auto *lhsTerm = table.getOptDomainAssociation(lhs);
664 auto *rhsTerm = table.getOptDomainAssociation(rhs);
668 if (failed(
unify(lhsTerm, rhsTerm))) {
669 auto diag = op->emitOpError(
"illegal domain crossing in operation");
670 auto ¬e1 = diag.attachNote(lhs.getLoc());
672 note1 <<
"1st operand has domains: ";
673 VariableIDTable idTable;
674 render(info, note1, idTable, lhsTerm);
676 auto ¬e2 = diag.attachNote(rhs.getLoc());
677 note2 <<
"2nd operand has domains: ";
678 render(info, note2, idTable, rhsTerm);
683 table.setDomainAssociation(rhs, lhsTerm);
688 table.setDomainAssociation(lhs, rhsTerm);
692 auto *var = allocator.allocVar();
693 table.setDomainAssociation(lhs, var);
694 table.setDomainAssociation(rhs, var);
699 TermAllocator &allocator,
701 FModuleOp moduleOp) {
702 auto numDomains = info.getNumDomains();
703 auto domainInfo = moduleOp.getDomainInfoAttr();
704 auto numPorts = moduleOp.getNumPorts();
706 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
707 for (
size_t i = 0; i < numPorts; ++i) {
708 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
712 if (moduleOp.getPortDirection(i) == Direction::In)
715 domainTypeIDTable[i] = info.getDomainTypeID(moduleOp, i);
718 for (
size_t i = 0; i < numPorts; ++i) {
719 BlockArgument port = moduleOp.getArgument(i);
720 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
724 SmallVector<IntegerAttr> associations(numDomains);
726 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
727 auto prevDomainPortIndex = associations[domainTypeID.index];
728 if (prevDomainPortIndex) {
730 prevDomainPortIndex, domainPortIndex);
733 associations[domainTypeID.index] = domainPortIndex;
736 SmallVector<Term *> elements(numDomains);
737 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
739 auto domainPortIndex = associations[domainTypeIndex];
740 if (!domainPortIndex)
742 auto domainPortValue =
743 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
744 elements[domainTypeIndex] =
748 auto *domainAssociations = allocator.allocRow(elements);
749 table.setDomainAssociation(port, domainAssociations);
757 TermAllocator &allocator,
758 DomainTable &table, T op) {
759 auto numDomains = info.getNumDomains();
760 auto domainInfo = op.getDomainInfoAttr();
761 auto numPorts = op.getNumPorts();
763 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
764 for (
size_t i = 0; i < numPorts; ++i) {
765 auto port = dyn_cast<DomainValue>(op->getResult(i));
769 if (op.getPortDirection(i) == Direction::Out)
772 domainTypeIDTable[i] = info.getDomainTypeID(op, i);
775 for (
size_t i = 0; i < numPorts; ++i) {
776 Value port = op->getResult(i);
777 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
781 SmallVector<IntegerAttr> associations(numDomains);
783 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
784 auto prevDomainPortIndex = associations[domainTypeID.index];
785 if (prevDomainPortIndex) {
787 prevDomainPortIndex, domainPortIndex);
790 associations[domainTypeID.index] = domainPortIndex;
793 SmallVector<Term *> elements(numDomains);
794 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
796 auto domainPortIndex = associations[domainTypeIndex];
797 if (!domainPortIndex)
799 auto domainPortValue =
800 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
801 elements[domainTypeIndex] =
805 auto *domainAssociations = allocator.allocRow(elements);
806 table.setDomainAssociation(port, domainAssociations);
812static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
817 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
818 auto lookup = updateTable.find(moduleName);
819 if (lookup != updateTable.end())
824static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
825 DomainTable &table, UnsafeDomainCastOp op) {
826 auto domains = op.getDomains();
831 auto input = op.getInput();
833 SmallVector<Term *> elements(inputRow->elements);
834 for (
auto value : op.getDomains()) {
835 auto domain = cast<DomainValue>(value);
836 auto typeID = info.getDomainTypeID(domain);
840 auto *row = allocator.allocRow(elements);
841 table.setDomainAssociation(op.getResult(), row);
845static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
846 DomainTable &table, DomainDefineOp op) {
847 auto src = op.getSrc();
848 auto dst = op.getDest();
851 if (succeeded(
unify(dstTerm, srcTerm)))
856 <<
"defines a domain value that was inferred to be a different domain '";
857 VariableIDTable idTable;
859 render(info, *diag.getUnderlyingDiagnostic(), idTable, dstTerm);
865static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
866 DomainTable &table, WireOp op) {
873 if (op.getDomains().empty())
878 SmallVector<Term *> elements(info.getNumDomains());
879 for (
auto domain : op.getDomains()) {
880 auto domainValue = cast<DomainValue>(domain);
881 auto typeID = info.getDomainTypeID(domainValue);
884 table.setDomainAssociation(op.getResult(), allocator.allocRow(elements));
889static LogicalResult
processOp(
const DomainInfo &info, TermAllocator &allocator,
893 if (
auto instance = dyn_cast<FInstanceLike>(op))
894 return processOp(info, allocator, table, updateTable, instance);
895 if (
auto wireOp = dyn_cast<WireOp>(op))
896 return processOp(info, allocator, table, wireOp);
897 if (
auto cast = dyn_cast<UnsafeDomainCastOp>(op))
898 return processOp(info, allocator, table, cast);
899 if (
auto def = dyn_cast<DomainDefineOp>(op))
900 return processOp(info, allocator, table, def);
901 if (
auto create = dyn_cast<DomainCreateOp>(op)) {
905 if (
auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
914 for (
auto rhs : op->getOperands()) {
915 if (!isa<FIRRTLBaseType>(rhs.getType()))
917 if (
auto *op = rhs.getDefiningOp();
918 op && op->hasTrait<OpTrait::ConstantLike>())
924 for (
auto rhs : op->getResults()) {
925 if (!isa<FIRRTLBaseType>(rhs.getType()))
927 if (
auto *op = rhs.getDefiningOp();
928 op && op->hasTrait<OpTrait::ConstantLike>())
938 TermAllocator &allocator,
941 FModuleOp moduleOp) {
942 return failure(moduleOp.getBody()
943 .walk([&](Operation *op) -> WalkResult {
944 return processOp(info, allocator, table, updateTable,
953 TermAllocator &allocator, DomainTable &table,
955 FModuleOp moduleOp) {
970using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
975 FModuleOp moduleOp) {
977 size_t numPorts = moduleOp.getNumPorts();
978 for (
size_t i = 0; i < numPorts; ++i) {
979 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
982 auto value = table.getOptUnderlyingDomain(port);
984 exports[value].push_back(port);
1005struct PendingUpdates {
1015 DomainTypeID typeID,
size_t ip, LocationAttr loc,
1016 VariableTerm *var, PendingUpdates &pending) {
1017 if (pending.solutions.contains(var))
1020 auto *
context = loc.getContext();
1021 auto domainDecl = info.getDomain(typeID);
1022 auto domainName = domainDecl.getNameAttr();
1024 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1025 auto portType = DomainType::getFromDomainOp(domainDecl);
1026 auto portDirection = Direction::In;
1027 auto portSym = StringAttr();
1029 auto portAnnos = std::nullopt;
1031 auto portDomainInfo = ArrayAttr::get(
context, {});
1032 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1033 portAnnos, portDomainInfo);
1035 pending.solutions[var] = pending.insertions.size() + ip;
1036 pending.insertions.push_back({ip, portInfo});
1046 size_t ip, LocationAttr loc, ValueTerm *val,
1047 PendingUpdates &pending) {
1048 auto value = val->value;
1049 assert(isa<DomainType>(value.getType()));
1050 if (
isPort(value) || exports.contains(value) ||
1051 pending.exports.contains(value))
1054 auto *
context = loc.getContext();
1056 auto domainDecl = info.getDomain(typeID);
1057 auto domainName = domainDecl.getNameAttr();
1059 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1060 auto portType = DomainType::getFromDomainOp(domainDecl);
1061 auto portDirection = Direction::Out;
1062 auto portSym = StringAttr();
1063 auto portLoc = value.getLoc();
1064 auto portAnnos = std::nullopt;
1066 auto portDomainInfo = ArrayAttr::get(
context, {});
1067 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1068 portAnnos, portDomainInfo);
1069 pending.exports[value] = pending.insertions.size() + ip;
1070 pending.insertions.push_back({ip, portInfo});
1075 PendingUpdates &pending,
1076 DomainTypeID typeID,
size_t ip,
1077 LocationAttr loc, Term *term,
1079 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1083 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1084 ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
1087 llvm_unreachable(
"invalid domain association");
1092 size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
1093 for (
auto [index, term] : llvm::enumerate(row->elements))
1095 ip, loc,
find(term), exports);
1099 TermAllocator &allocator,
1103 PendingUpdates &pending) {
1104 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1105 auto port = moduleOp.getArgument(i);
1106 auto type = port.getType();
1107 if (!isa<FIRRTLBaseType>(type))
1110 info, ns, exports, i, moduleOp.getPortLocation(i),
1116 TermAllocator &allocator,
1118 FModuleOp mod, PendingUpdates &pending) {
1120 auto names = mod.getPortNamesAttr();
1121 for (
auto name : names.getAsRange<StringAttr>())
1128 DomainTable &table, FModuleOp moduleOp,
1129 const PendingUpdates &pending) {
1131 moduleOp.insertPorts(pending.insertions);
1134 for (
auto [var, portIndex] : pending.solutions) {
1135 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1136 auto *solution = allocator.allocVal(portValue);
1137 solve(var, solution);
1138 exports[portValue].push_back(portValue);
1142 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1143 for (
auto [domainValue, portIndex] : pending.exports) {
1144 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1145 builder.setInsertionPointAfterValue(domainValue);
1146 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1148 exports[domainValue].push_back(portValue);
1149 table.setTermForDomain(portValue, allocator.allocVal(domainValue));
1155static SmallVector<Attribute>
1157 ArrayAttr moduleDomainInfo,
size_t portIndex) {
1158 SmallVector<Attribute> result(info.getNumDomains());
1160 for (
auto domainPortIndexAttr : oldAssociations) {
1162 auto domainPortIndex = domainPortIndexAttr.getUInt();
1163 auto domainTypeID = info.getDomainTypeID(moduleOp, domainPortIndex);
1164 result[domainTypeID.index] = domainPortIndexAttr;
1172 const DomainTable &table,
1173 FModuleOp moduleOp) {
1174 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1175 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1176 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1177 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1181 auto *term = table.getOptTermForDomain(port);
1182 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1188 auto loc = port.getLoc();
1189 auto value = val->value;
1190 DomainDefineOp::create(builder, loc, port, value);
1199 const DomainTable &table,
1202 FModuleOp moduleOp) {
1207 auto *
context = moduleOp.getContext();
1208 auto numDomains = info.getNumDomains();
1209 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1210 auto numPorts = moduleOp.getNumPorts();
1211 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1213 for (
size_t i = 0; i < numPorts; ++i) {
1214 auto port = moduleOp.getArgument(i);
1215 auto type = port.getType();
1217 if (isa<DomainType>(type)) {
1219 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1223 if (isa<FIRRTLBaseType>(type)) {
1226 auto *row = cast<RowTerm>(table.getDomainAssociation(port));
1227 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1228 auto domainTypeID = DomainTypeID{domainIndex};
1229 if (associations[domainIndex])
1232 auto domain = cast<ValueTerm>(
find(row->elements[domainIndex]))->value;
1233 auto &exports = exportTable.at(domain);
1234 if (exports.empty()) {
1235 auto portName = moduleOp.getPortNameAttr(i);
1236 auto portLoc = moduleOp.getPortLocation(i);
1237 auto domainDecl = info.getDomain(domainTypeID);
1238 auto domainName = domainDecl.getNameAttr();
1239 auto diag = emitError(portLoc)
1240 <<
"private " << domainName <<
" association for port "
1242 diag.attachNote(domain.getLoc()) <<
"associated domain: " << domain;
1247 if (exports.size() > 1) {
1253 auto argument = cast<BlockArgument>(exports[0]);
1254 auto domainPortIndex = argument.getArgNumber();
1255 associations[domainTypeID.index] = IntegerAttr::get(
1256 IntegerType::get(
context, 32, IntegerType::Unsigned),
1260 newModuleDomainInfo[i] = ArrayAttr::get(
context, associations);
1264 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1267 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1268 moduleOp.setDomainInfoAttr(result);
1273 TermAllocator &allocator,
1274 DomainTable &table, FInstanceLike op,
1275 OpBuilder &builder) {
1276 auto numPorts = op->getNumResults();
1277 for (
size_t i = 0; i < numPorts; ++i) {
1278 auto port = dyn_cast<DomainValue>(op->getResult(i));
1279 auto direction = op.getPortDirection(i);
1284 if (port && direction == Direction::In && !
isDriven(port)) {
1285 auto loc = port.getLoc();
1287 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1288 auto domainType = cast<DomainType>(op->getResult(i).getType());
1289 auto domainTypeID = info.getDomainTypeID(domainType);
1290 auto domainDecl = info.getDomain(domainTypeID);
1291 auto name = domainDecl.getNameAttr();
1294 OpBuilder::InsertionGuard guard(builder);
1295 builder.setInsertionPointAfter(op);
1296 anon = DomainCreateAnonOp::create(builder, loc, domainType, name);
1298 solve(var, allocator.allocVal(anon));
1300 DomainDefineOp::create(builder, loc, port, anon);
1303 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1304 auto value = val->value;
1306 DomainDefineOp::create(builder, loc, port, value);
1309 llvm_unreachable(
"unhandled domain term type");
1318 TermAllocator &allocator, DomainTable &table,
1320 auto result = wireOp.getResult();
1321 if (!isa<FIRRTLBaseType>(result.getType()))
1325 auto *term = table.getOptDomainAssociation(result);
1329 auto *row = dyn_cast<RowTerm>(
find(term));
1334 SmallVector<Value> domainOperands;
1335 for (
auto *element : llvm::map_range(row->elements,
find))
1336 if (
auto *val = dyn_cast_or_null<ValueTerm>(element))
1337 domainOperands.push_back(val->value);
1341 if (!domainOperands.empty() && wireOp.getDomains().empty())
1342 wireOp->setOperands(domainOperands);
1350 TermAllocator &allocator,
1351 DomainTable &table, FModuleOp moduleOp) {
1354 OpBuilder builder(moduleOp.getContext());
1355 builder.setInsertionPointToEnd(moduleOp.getBodyBlock());
1358 auto instanceResult =
1359 moduleOp.getBodyBlock()->walk([&](FInstanceLike op) -> WalkResult {
1362 if (instanceResult.wasInterrupted())
1366 auto wireResult = moduleOp.getBodyBlock()->walk([&](WireOp op) -> WalkResult {
1367 return updateWire(info, allocator, table, op);
1369 return failure(wireResult.wasInterrupted());
1374 TermAllocator &allocator, DomainTable &table,
1377 PendingUpdates pending;
1382 ArrayAttr portDomainInfo;
1391 auto &entry = updates[op.getModuleNameAttr()];
1392 entry.portDomainInfo = portDomainInfo;
1393 entry.portInsertions = std::move(pending.insertions);
1407 FModuleLike moduleOp) {
1408 auto numDomains = info.getNumDomains();
1409 auto domainInfo = moduleOp.getDomainInfoAttr();
1410 auto numPorts = moduleOp.getNumPorts();
1412 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1413 for (
size_t i = 0; i < numPorts; ++i) {
1414 if (isa<DomainType>(moduleOp.getPortType(i)))
1415 domainTypeIDTable[i] = info.getDomainTypeID(moduleOp, i);
1418 for (
size_t i = 0; i < numPorts; ++i) {
1419 auto type = type_dyn_cast<FIRRTLBaseType>(moduleOp.getPortType(i));
1424 SmallVector<IntegerAttr> associations(numDomains);
1426 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1427 auto prevDomainPortIndex = associations[domainTypeID.index];
1428 if (prevDomainPortIndex) {
1430 prevDomainPortIndex, domainPortIndex);
1433 associations[domainTypeID.index] = domainPortIndex;
1437 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1438 auto typeID = DomainTypeID{domainIndex};
1439 if (!associations[domainIndex]) {
1451 FModuleOp moduleOp) {
1452 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1453 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1454 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1458 auto name = moduleOp.getPortNameAttr(i);
1459 auto diag = emitError(moduleOp.getPortLocation(i))
1460 <<
"undriven domain port " << name;
1470 for (
size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1471 auto port = dyn_cast<DomainValue>(op->getResult(i));
1473 auto type = port.getType();
1474 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1478 auto name = op.getPortNameAttr(i);
1479 auto diag = emitError(op.getPortLocation(i))
1480 <<
"undriven domain port " << name;
1490 auto result = moduleOp.getBody().walk([](FInstanceLike op) -> WalkResult {
1493 return failure(result.wasInterrupted());
1501 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1502 [=](Operation *op) -> WalkResult {
1503 return TypeSwitch<Operation *, WalkResult>(op)
1504 .Case<FModuleLike>([](FModuleLike op) {
1505 auto n = op.getNumPorts();
1506 BitVector erasures(n);
1507 for (
size_t i = 0; i < n; ++i)
1508 if (isa<DomainType>(op.getPortType(i)))
1510 op.erasePorts(erasures);
1511 return WalkResult::advance();
1513 .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
1516 return WalkResult::advance();
1518 .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
1519 if (!op->use_empty()) {
1520 OpBuilder builder(op);
1521 op.replaceAllUsesWith(
1522 UnknownValueOp::create(builder, op.getLoc(), op.getType())
1526 return WalkResult::advance();
1528 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1529 op.replaceAllUsesWith(op.getInput());
1531 return WalkResult::advance();
1533 .Case<WireOp>([](WireOp op) {
1535 if (isa<DomainType>(op.getType(0))) {
1537 return WalkResult::advance();
1540 if (!op.getDomains().empty()) {
1541 op->eraseOperands(0, op.getNumOperands());
1543 return WalkResult::advance();
1545 .Case<FInstanceLike>([](
auto op) {
1546 auto n = op.getNumPorts();
1547 BitVector erasures(n);
1548 for (
size_t i = 0; i < n; ++i)
1549 if (isa<DomainType>(op->getResult(i).getType()))
1551 op.cloneWithErasedPortsAndReplaceUses(erasures);
1553 return WalkResult::advance();
1555 .Default([](Operation *op) {
1557 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1558 if (isa<DomainType>(type)) {
1559 op->emitOpError(
"cannot be stripped");
1560 return WalkResult::interrupt();
1563 return WalkResult::advance();
1566 return failure(result.wasInterrupted());
1570 llvm::SmallVector<FModuleLike> modules;
1571 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1572 TypeSwitch<Operation *, void>(&op)
1573 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1574 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1586 FModuleOp moduleOp) {
1587 TermAllocator allocator;
1590 if (failed(
processModule(info, allocator, table, updates, moduleOp)))
1593 return updateModule(info, allocator, table, updates, moduleOp);
1598static LogicalResult
checkModule(
const DomainInfo &info, FModuleOp moduleOp) {
1608 TermAllocator allocator;
1611 return processModule(info, allocator, table, updateTable, moduleOp);
1616 FExtModuleOp moduleOp) {
1625 FModuleOp moduleOp) {
1629 TermAllocator allocator;
1631 if (failed(
processModule(info, allocator, table, updateTable, moduleOp)))
1641 const DomainInfo &info,
1644 assert(mode != InferDomainsMode::Strip);
1646 if (
auto moduleOp = dyn_cast<FModuleOp>(op)) {
1647 if (mode == InferDomainsMode::Check)
1650 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1656 if (
auto extModule = dyn_cast<FExtModuleOp>(op))
1663struct InferDomainsPass
1664 :
public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1666 void runOnOperation()
override {
1668 auto circuit = getOperation();
1670 if (mode == InferDomainsMode::Strip) {
1672 signalPassFailure();
1676 auto &instanceGraph = getAnalysis<InstanceGraph>();
1677 DomainInfo
info(circuit);
1679 DenseSet<Operation *> errored;
1680 instanceGraph.walkPostOrder([&](
auto &node) {
1681 auto moduleOp = node.getModule();
1682 for (
auto *inst : node) {
1683 if (errored.contains(inst->getTarget()->getModule())) {
1684 errored.insert(moduleOp);
1688 if (failed(
runOnModuleLike(mode, info, updateTable, node.getModule())))
1689 errored.insert(moduleOp);
1692 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op)
Write the domain associations recorded in the domain table back to the IR.
static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1, IntegerAttr domainPortIndexAttr2)
static ExportTable initializeExportTable(const DomainTable &table, FModuleOp moduleOp)
Build a table of exported domains: a map from domains defined internally, to their set of aliasing ou...
static LogicalResult processModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
static LogicalResult updateInstance(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FInstanceLike op, OpBuilder &builder)
static LogicalResult updateWire(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, WireOp wireOp)
Update a wire operation with inferred domain associations.
static void emitAmbiguousPortDomainAssociation(const DomainInfo &info, T op, const llvm::TinyPtrVector< DomainValue > &exports, DomainTypeID typeID, size_t i)
static LogicalResult processModulePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
static LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updates, FModuleOp moduleOp)
Solve for domains and then write the domain associations back to the IR.
static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info, const DomainTable &table, FModuleOp moduleOp)
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
llvm::MapVector< DomainValue, unsigned > PendingExports
A map from local domains to an aliasing port index, where that port has not yet been created.
static LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op)
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
static RowTerm * getDomainAssociationAsRow(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Value value)
Get the row of domains that a hardware value in the IR is associated with.
static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, DomainTypeID typeID, size_t i)
static void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, Namespace &ns, FModuleOp moduleOp, PendingUpdates &pending)
static LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op)
Check that the input domain ports are driven.
static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
After updating the port domain associations, walk the body of the moduleOp to fix up any child instan...
static FInstanceLike fixInstancePorts(FInstanceLike op, const ModuleUpdateInfo &update)
Apply the port changes of a moduleOp onto an instance-like op.
static void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, Term *term)
static void processDomainDefinition(TermAllocator &allocator, DomainTable &table, DomainValue domain)
static LogicalResult unify(Term *lhs, Term *rhs)
static LogicalResult updateModuleDomainInfo(const DomainInfo &info, const DomainTable &table, const ExportTable &exportTable, ArrayAttr &result, FModuleOp moduleOp)
After generalizing the moduleOp, all domains should be solved.
static LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op, Value lhs, Value rhs)
Unify the associated domain rows of two terms.
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FInstanceLike op)
static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, FModuleOp moduleOp)
Check that output domain ports are driven.
static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op)
static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, Term *term1, Term *term2)
static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, Term *term, const ExportTable &exports)
static void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator, ExportTable &exports, DomainTable &table, FModuleOp moduleOp, const PendingUpdates &pending)
static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, FModuleOp mod, PendingUpdates &pending)
static SmallVector< Attribute > copyPortDomainAssociations(const DomainInfo &info, FModuleLike moduleOp, ArrayAttr moduleDomainInfo, size_t portIndex)
Copy the domain associations from the moduleOp domain info attribute into a small vector.
static void emitDomainPortInferenceError(T op, size_t i)
Emit an error when we fail to infer the concrete domain to drive to a domain port.
static void ensureExported(const DomainInfo &info, Namespace &ns, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending)
Ensure that the domain value is available in the signature of the moduleOp, so that subsequent hardwa...
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static Term * getTermForDomain(TermAllocator &allocator, DomainTable &table, DomainValue value)
Get the corresponding term for a domain in the IR.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
DenseMap< StringAttr, ModuleUpdateInfo > ModuleUpdateTable
static void ensureSolved(const DomainInfo &info, Namespace &ns, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending)
If var is not solved, solve it by recording a pending input port at the indicated insertion point.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static void solve(Term *lhs, Term *rhs)
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static LogicalResult checkModuleBody(FModuleOp moduleOp)
Check that instances under this module have driven domain input ports.
static LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike moduleOp)
Check that a module's hardware ports have complete domain associations.
static Term * find(Term *x)
static LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult checkAndInferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult stripModule(FModuleLike op)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
This class represents a reference to a specific field or element of an aggregate value.
A namespace that is used to store existing names and generate new names in some scope within the IR.
void add(mlir::ModuleOp module)
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
InferDomainsMode
The mode for the InferDomains pass.
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.