26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Iterators.h"
28#include "mlir/IR/Threading.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TinyPtrVector.h"
34#define DEBUG_TYPE "firrtl-infer-domains"
38#define GEN_PASS_DEF_INFERDOMAINS
39#include "circt/Dialect/FIRRTL/Passes.h.inc"
44using namespace firrtl;
50using mlir::InFlightDiagnostic;
51using mlir::ReverseIterator;
69 return info.getAsRange<IntegerAttr>();
70 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
74static bool isPort(BlockArgument arg) {
75 return isa<FModuleOp>(arg.getOwner()->getParentOp());
80 auto arg = dyn_cast<BlockArgument>(value);
88 for (
auto *user : port.getUsers())
89 if (
auto connect = dyn_cast<FConnectLike>(user))
90 if (connect.getDest() == port)
97 return type_isa<FIRRTLBaseType, RefType>(type);
105 if (
auto *op = value.getDefiningOp())
106 if (op->hasTrait<OpTrait::ConstantLike>())
129struct ModuleUpdateInfo {
131 ArrayAttr portDomainInfo;
139 CircuitState(CircuitOp circuit,
InstanceGraph &instanceGraph,
141 : circuit(circuit), instanceGraph(instanceGraph),
142 innerRefNamespace(innerRefNamespace), mode(mode) {
143 processCircuit(circuit);
148 ArrayRef<DomainOp> getDomains()
const {
return domainTable; }
149 size_t getNumDomains()
const {
return domainTable.size(); }
150 DomainOp getDomain(DomainTypeID
id)
const {
return domainTable[
id.index]; }
151 DomainTypeID getDomainTypeID(Type type) {
return typeIDTable[type]; }
153 void dirty() { asmState =
nullptr; }
154 AsmState &getAsmState() {
156 asmState = std::make_unique<AsmState>(
157 circuit, mlir::OpPrintingFlags().assumeVerified());
162 size_t getVariableID(VariableTerm *term) {
163 return variableIDTable.insert({term, variableIDTable.size() + 1})
167 DenseMap<StringAttr, ModuleUpdateInfo> &getModuleUpdateTable() {
168 return moduleUpdateTable;
173 DenseSet<Value> inserted;
176 LogicalResult runOnModule(Operation *moduleOp);
178 void processDomain(DomainOp op) {
179 auto index = domainTable.size();
180 auto domainType = DomainType::getFromDomainOp(op);
181 domainTable.push_back(op);
182 typeIDTable.insert({domainType, {index}});
185 void processCircuit(CircuitOp circuit) {
186 for (
auto decl : circuit.getOps<DomainOp>())
194 SmallVector<DomainOp> domainTable;
195 DenseMap<Type, DomainTypeID> typeIDTable;
196 DenseMap<VariableTerm *, size_t> variableIDTable;
197 std::unique_ptr<AsmState> asmState;
198 DenseMap<StringAttr, ModuleUpdateInfo> moduleUpdateTable;
218 constexpr Term(TermKind kind) : kind(kind) {}
226struct TermBase : Term {
227 static bool classof(
const Term *term) {
return term->kind == K; }
228 TermBase() : Term(K) {}
234struct VariableTerm :
public TermBase<TermKind::Variable> {
235 VariableTerm() : leader(nullptr) {}
236 VariableTerm(Term *leader) : leader(leader) {}
243struct ValueTerm :
public TermBase<TermKind::Value> {
251struct RowTerm :
public TermBase<TermKind::Row> {
252 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
253 ArrayRef<Term *> elements;
272struct PendingUpdates {
282using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
287 explicit ModuleState(CircuitState &globals) : globals(globals) {}
289 ArrayRef<DomainOp> getDomains() {
return globals.getDomains(); }
290 size_t getNumDomains() {
return globals.getNumDomains(); }
291 DomainOp getDomain(DomainTypeID
id) {
return globals.getDomain(
id); }
292 DomainTypeID getDomainTypeID(Type type) {
293 return globals.getDomainTypeID(type);
295 DomainTypeID getDomainTypeID(FModuleLike module,
size_t i) {
296 return globals.getDomainTypeID(module.getPortType(i));
298 DomainTypeID getDomainTypeID(FInstanceLike op,
size_t i)
const {
299 return globals.getDomainTypeID(op->getResult(i).getType());
301 DomainTypeID getDomainTypeID(
DomainValue value)
const {
302 return globals.getDomainTypeID(value.getType());
304 auto &getModuleUpdateTable() {
return globals.getModuleUpdateTable(); }
306 mlir::AsmState &getAsmState() {
return globals.getAsmState(); }
307 void dirty() { globals.dirty(); }
309 template <
typename T>
310 void render(Operation *op, T &out);
311 template <
typename T>
312 void render(Value value, T &out);
313 template <
typename T>
314 void renderLong(Value value, T &out);
315 template <
typename T>
316 void render(Term *term, T &out);
317 template <
typename T>
319 template <
typename T>
320 Render<T> render(T &&subject);
322 RenderLong renderLong(Value value);
325 LogicalResult unify(Term *lhs, Term *rhs);
326 LogicalResult unify(VariableTerm *x, Term *y);
327 LogicalResult unify(ValueTerm *xv, Term *y);
328 LogicalResult unify(RowTerm *lhsRow, Term *rhs);
329 void solve(Term *lhs, Term *rhs);
331 [[nodiscard]] RowTerm *allocRow(
size_t size);
332 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements);
333 [[nodiscard]] VariableTerm *allocVar();
334 [[nodiscard]] ValueTerm *allocVal(
DomainValue value);
335 template <
typename T,
typename... Args>
336 T *alloc(Args &&...args);
337 ArrayRef<Term *> allocArray(ArrayRef<Term *> elements);
342 void setTermForDomain(
DomainValue value, Term *term);
344 Term *getOptDomainAssociation(Value value);
345 Term *getDomainAssociation(Value value);
346 void setDomainAssociation(Value value, Term *term);
349 RowTerm *getDomainAssociationAsRow(Value value);
351 void noteLocation(InFlightDiagnostic &diag, Operation *op);
352 void noteDomain(InFlightDiagnostic &diag,
DomainValue domain);
353 void noteDomainSource(InFlightDiagnostic &diag,
DomainValue domain);
354 void noteDomainSource(InFlightDiagnostic &diag, Term *term);
355 void emitDomainCrossingError(Operation *op, Value lhs, Term *lhsTerm,
356 Value rhs, Term *rhsTerm);
357 template <
typename T>
358 void emitDuplicatePortDomainError(T op,
size_t i, DomainTypeID domainTypeID,
359 IntegerAttr domainPortIndexAttr1,
360 IntegerAttr domainPortIndexAttr2);
361 template <
typename T>
362 void emitDomainPortInferenceError(T op,
size_t i);
363 template <
typename T>
364 void emitAmbiguousPortDomainAssociation(
365 T op,
const llvm::TinyPtrVector<DomainValue> &exports,
366 DomainTypeID typeID,
size_t i);
367 template <
typename T>
368 void emitMissingPortDomainAssociationError(T op, DomainTypeID typeID,
371 LogicalResult unifyAssociations(Operation *op, Value lhs, Value rhs);
372 template <
typename T>
373 LogicalResult unifyAssociations(Operation *op, T &&range);
374 LogicalResult unifyAssociations(Operation *op);
376 LogicalResult processModulePorts(FModuleOp moduleOp);
377 template <
typename T>
378 LogicalResult processInstancePorts(T op);
379 FInstanceLike fixInstancePorts(FInstanceLike op,
380 const ModuleUpdateInfo &update);
381 LogicalResult processOp(FInstanceLike op);
382 LogicalResult processOp(UnsafeDomainCastOp op);
383 LogicalResult processOp(DomainDefineOp op);
384 LogicalResult processOp(WireOp op);
385 LogicalResult processOp(RWProbeOp op);
386 LogicalResult processOp(Operation *op);
387 LogicalResult processModuleBody(FModuleOp moduleOp);
388 LogicalResult processModule(FModuleOp moduleOp);
390 ExportTable initializeExportTable(FModuleOp moduleOp);
391 void ensureSolved(
Namespace &ns, DomainTypeID typeID,
size_t ip,
392 LocationAttr loc, VariableTerm *var,
393 PendingUpdates &pending);
395 DomainTypeID typeID,
size_t ip, LocationAttr loc,
396 ValueTerm *val, PendingUpdates &pending);
397 void getUpdatesForDomainAssociationOfPort(
Namespace &ns,
398 PendingUpdates &pending,
399 DomainTypeID typeID,
size_t ip,
400 LocationAttr loc, Term *term,
402 void getUpdatesForDomainAssociationOfPort(
Namespace &ns,
404 size_t ip, LocationAttr loc,
406 PendingUpdates &pending);
407 void getUpdatesForModulePorts(FModuleOp moduleOp,
const ExportTable &exports,
409 void getUpdatesForModule(FModuleOp moduleOp,
const ExportTable &exports,
410 PendingUpdates &pending);
411 void applyUpdatesToModule(FModuleOp moduleOp,
ExportTable &exports,
412 const PendingUpdates &pending);
413 SmallVector<Attribute> copyPortDomainAssociations(FModuleOp moduleOp,
414 ArrayAttr moduleDomainInfo,
416 LogicalResult driveModuleOutputDomainPorts(FModuleOp moduleOp);
417 LogicalResult updateModuleDomainInfo(FModuleOp moduleOp,
421 solveVarWithAnonDomain(OpBuilder &builder,
422 DenseMap<DomainValue, DomainValue> &domainsInScope,
423 Operation *user, DomainType type, VariableTerm *var);
425 getDomainInScope(OpBuilder &builder,
426 DenseMap<DomainValue, DomainValue> &domainsInScope,
429 updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
431 LogicalResult updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
433 LogicalResult updateModuleBody(FModuleOp moduleOp);
434 LogicalResult updateModule(FModuleOp moduleOp);
436 LogicalResult checkModulePorts(FModuleLike moduleOp);
437 LogicalResult checkModuleDomainPortDrivers(FModuleOp moduleOp);
438 LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op);
439 LogicalResult checkModuleBody(FModuleOp moduleOp);
441 LogicalResult inferModule(FModuleOp moduleOp);
442 LogicalResult checkModule(FModuleOp moduleOp);
443 LogicalResult checkModule(FExtModuleOp extModuleOp);
444 LogicalResult checkAndInferModule(FModuleOp moduleOp);
447 CircuitState &globals;
448 DenseMap<Value, Term *> termTable;
449 DenseMap<Value, Term *> associationTable;
450 llvm::BumpPtrAllocator allocator;
455void ModuleState::render(Operation *op, T &out) {
456 op->print(out, getAsmState());
460void ModuleState::render(Value value, T &out) {
468 llvm::raw_string_ostream os(name);
469 value.printAsOperand(os, globals.getAsmState());
475void ModuleState::renderLong(Value value, T &out) {
476 if (
auto arg = dyn_cast<BlockArgument>(value)) {
477 if (
auto moduleOp = llvm::dyn_cast_if_present<FModuleLike>(
478 arg.getOwner()->getParentOp())) {
480 moduleOp.getPortDirection(arg.getArgNumber()));
481 out <<
" module port ";
483 }
else if (
auto result = dyn_cast<OpResult>(value)) {
484 auto *op = result.getOwner();
485 if (
auto inst = dyn_cast<FInstanceLike>(op)) {
487 inst.getPortDirection(result.getResultNumber()));
488 out <<
" instance port ";
497void ModuleState::render(Term *term, T &out) {
503 if (
auto *var = dyn_cast<VariableTerm>(term)) {
504 out <<
"?" << globals.getVariableID(var);
507 if (
auto *val = dyn_cast<ValueTerm>(term)) {
508 auto value = val->value;
512 if (
auto *row = dyn_cast<RowTerm>(term)) {
514 llvm::interleaveComma(
515 llvm::seq(
size_t(0), getNumDomains()), out, [&](
auto i) {
516 render(row->elements[i], out);
517 out <<
" : " << getDomain(DomainTypeID{i}).getSymName();
526struct ModuleState::Render {
532ModuleState::Render<T> ModuleState::render(T &&subject) {
533 return Render<T>{
this, std::forward<T>(subject)};
538 ModuleState::Render<T> r) {
539 r.state->render(r.subject, out);
558Term *ModuleState::find(Term *x) {
562 if (
auto *var = dyn_cast<VariableTerm>(x)) {
563 if (var->leader ==
nullptr)
566 auto *leader = find(var->leader);
567 if (leader != var->leader)
568 var->leader = leader;
575LogicalResult ModuleState::unify(VariableTerm *x, Term *y) {
581LogicalResult ModuleState::unify(ValueTerm *xv, Term *y) {
582 if (
auto *yv = dyn_cast<VariableTerm>(y)) {
587 if (
auto *yv = dyn_cast<ValueTerm>(y))
588 return success(xv == yv);
594LogicalResult ModuleState::unify(RowTerm *lhsRow, Term *rhs) {
595 if (
auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
596 rhsVar->leader = lhsRow;
599 if (
auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
600 for (
auto [x, y] :
llvm::zip_equal(lhsRow->elements, rhsRow->elements))
601 if (failed(unify(x, y)))
609LogicalResult ModuleState::unify(Term *lhs, Term *rhs) {
617 LLVM_DEBUG(llvm::dbgs().indent(6)
618 <<
"unify " << render(lhs) <<
" = " << render(rhs) <<
"\n");
620 if (
auto *lhsVar = dyn_cast<VariableTerm>(lhs))
621 return unify(lhsVar, rhs);
622 if (
auto *lhsVal = dyn_cast<ValueTerm>(lhs))
623 return unify(lhsVal, rhs);
624 if (
auto *lhsRow = dyn_cast<RowTerm>(lhs))
625 return unify(lhsRow, rhs);
629void ModuleState::solve(Term *lhs, Term *rhs) {
630 [[maybe_unused]]
auto result = unify(lhs, rhs);
631 assert(result.succeeded());
634RowTerm *ModuleState::allocRow(
size_t size) {
635 SmallVector<Term *> elements;
636 elements.resize(size);
637 return allocRow(elements);
640RowTerm *ModuleState::allocRow(ArrayRef<Term *> elements) {
641 auto ds = allocArray(elements);
642 return alloc<RowTerm>(ds);
645VariableTerm *ModuleState::allocVar() {
return alloc<VariableTerm>(); }
647ValueTerm *ModuleState::allocVal(
DomainValue value) {
648 return alloc<ValueTerm>(value);
651template <
typename T,
typename... Args>
652T *ModuleState::alloc(Args &&...args) {
653 static_assert(std::is_base_of_v<Term, T>,
"T must be a term");
654 return new (allocator) T(std::forward<Args>(args)...);
657ArrayRef<Term *> ModuleState::allocArray(ArrayRef<Term *> elements) {
658 auto size = elements.size();
662 auto *result = allocator.Allocate<Term *>(size);
663 llvm::uninitialized_copy(elements, result);
664 for (
size_t i = 0; i < size; ++i)
666 result[i] = alloc<VariableTerm>();
668 return ArrayRef(result, size);
672 auto *term = getOptTermForDomain(value);
673 if (
auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
678Term *ModuleState::getOptTermForDomain(
DomainValue value) {
679 assert(isa<DomainType>(value.getType()));
680 auto it = termTable.find(value);
681 if (it == termTable.end())
683 return find(it->second);
686Term *ModuleState::getTermForDomain(
DomainValue value) {
687 assert(isa<DomainType>(value.getType()));
688 if (
auto *term = getOptTermForDomain(value))
690 auto *term = allocVar();
691 setTermForDomain(value, term);
695void ModuleState::setTermForDomain(
DomainValue value, Term *term) {
697 assert(!termTable.contains(value));
698 termTable.insert({value, term});
699 LLVM_DEBUG(llvm::dbgs().indent(6)
700 <<
"set " << render(value) <<
" := " << render(term) <<
"\n");
703Term *ModuleState::getOptDomainAssociation(Value value) {
705 auto it = associationTable.find(value);
706 if (it == associationTable.end())
708 return find(it->second);
711Term *ModuleState::getDomainAssociation(Value value) {
712 auto *term = getOptDomainAssociation(value);
717void ModuleState::setDomainAssociation(Value value, Term *term) {
721 associationTable.insert({value, term});
723 llvm::dbgs().indent(6) <<
"set domains(" << render(value)
724 <<
") := " << render(term) <<
"\n";
728void ModuleState::processDomainDefinition(
DomainValue domain) {
729 assert(isa<DomainType>(domain.getType()));
730 auto *newTerm = allocVal(domain);
731 auto *oldTerm = getOptTermForDomain(domain);
733 setTermForDomain(domain, newTerm);
737 [[maybe_unused]]
auto result = unify(oldTerm, newTerm);
738 assert(result.succeeded());
741RowTerm *ModuleState::getDomainAssociationAsRow(Value value) {
743 auto *term = getOptDomainAssociation(value);
747 auto *row = allocRow(getNumDomains());
748 setDomainAssociation(value, row);
753 if (
auto *row = dyn_cast<RowTerm>(term))
757 if (
auto *var = dyn_cast<VariableTerm>(term)) {
758 auto *row = allocRow(getNumDomains());
763 assert(
false &&
"unhandled term type");
767void ModuleState::noteLocation(InFlightDiagnostic &diag, Operation *op) {
768 auto ¬e = diag.attachNote(op->getLoc());
769 if (
auto mod = dyn_cast<FModuleOp>(op)) {
770 note <<
"in module " << mod.getModuleNameAttr();
773 if (
auto mod = dyn_cast<FExtModuleOp>(op)) {
774 note <<
"in extmodule " << mod.getModuleNameAttr();
777 if (
auto inst = dyn_cast<InstanceOp>(op)) {
778 note <<
"in instance " << inst.getInstanceNameAttr();
781 if (
auto inst = dyn_cast<InstanceChoiceOp>(op)) {
782 note <<
"in instance_choice " << inst.getNameAttr();
789void ModuleState::noteDomain(InFlightDiagnostic &diag,
DomainValue domain) {
790 auto ¬e = diag.attachNote(domain.getLoc());
791 note << renderLong(domain);
793 if (globals.inserted.contains(domain)) {
794 note <<
" automatically inserted here";
798 note <<
" declared here";
801void ModuleState::noteDomainSource(InFlightDiagnostic &diag,
803 auto &irns = globals.getInnerRefNamespace();
804 SmallVector<FInstanceLike> stack;
805 llvm::SmallDenseSet<DomainValue> seen;
809 auto chaseConnect = [&]() {
810 for (
auto *user : domain.getUsers()) {
811 if (
auto defineOp = dyn_cast<DomainDefineOp>(user)) {
812 if (defineOp.getDest() != domain)
814 auto src = defineOp.getSrc();
815 diag.attachNote(defineOp.getLoc())
816 << renderLong(domain) <<
" aliases " << renderLong(src);
817 domain = defineOp.getSrc();
824 auto chaseModulePort = [&]() {
825 auto arg = dyn_cast<BlockArgument>(domain);
830 llvm::dyn_cast_if_present<FModuleOp>(arg.getOwner()->getParentOp());
834 auto name =
module.getModuleNameAttr();
835 while (!stack.empty()) {
836 auto instance = stack.back();
838 auto referenced = instance.getReferencedModuleNamesAttr().getValue();
839 if (llvm::is_contained(referenced, name)) {
840 domain = cast<DomainValue>(instance->getResult(arg.getArgNumber()));
847 auto chaseInstancePort = [&]() {
848 auto result = dyn_cast<OpResult>(domain);
852 auto inst = dyn_cast<FInstanceLike>(result.getOwner());
856 auto index = result.getResultNumber();
857 if (inst.getPortDirection(index) == Direction::In)
860 auto names = inst.getReferencedModuleNamesAttr().getAsRange<StringAttr>();
861 for (
auto name : names) {
862 auto moduleLike = cast<FModuleLike>(irns.symTable.lookup(name));
863 if (
auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation())) {
864 stack.push_back(inst);
865 domain = cast<DomainValue>(moduleOp.getArgument(index));
872 auto chaseUnderlying = [&]() {
873 if (
auto *term = getOptTermForDomain(domain)) {
874 if (
auto *val = dyn_cast<ValueTerm>(term)) {
875 if (domain != val->value) {
876 diag.attachNote(domain.getLoc())
877 << renderLong(domain) <<
" aliases " << renderLong(val->value);
887 auto [it, inserted] = seen.insert(domain);
891 noteDomain(diag, domain);
892 chaseConnect() || chaseModulePort() || chaseInstancePort() ||
897void ModuleState::noteDomainSource(InFlightDiagnostic &diag, Term *term) {
898 auto *val = dyn_cast<ValueTerm>(find(term));
902 noteDomainSource(diag, val->value);
905void ModuleState::emitDomainCrossingError(Operation *op, Value lhs,
906 Term *lhsTerm, Value rhs,
908 auto *lhsRow = cast<RowTerm>(lhsTerm);
909 auto *rhsRow = cast<RowTerm>(rhsTerm);
911 op->emitError(
"illegal domain crossing in operation between operands ");
915 auto ¬e1 = diag.attachNote(lhs.getLoc());
917 note1 <<
" has domains ";
918 render(lhsRow, note1);
919 auto ¬e2 = diag.attachNote(rhs.getLoc());
921 note2 <<
" has domains ";
922 render(rhsRow, note2);
924 for (
size_t i = 0, e = getNumDomains(); i < e; ++i) {
925 auto *lhsDomain = find(lhsRow->elements[i]);
926 auto *rhsDomain = find(rhsRow->elements[i]);
927 if (lhsDomain == rhsDomain)
930 noteDomainSource(diag, lhsDomain);
931 noteDomainSource(diag, rhsDomain);
936void ModuleState::emitDuplicatePortDomainError(
937 T op,
size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1,
938 IntegerAttr domainPortIndexAttr2) {
939 auto portName = op.getPortNameAttr(i);
940 auto portLoc = op.getPortLocation(i);
941 auto domainDecl = getDomain(domainTypeID);
942 auto domainName = domainDecl.getNameAttr();
943 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
944 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
945 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
946 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
947 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
948 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
949 auto diag = emitError(portLoc);
950 diag <<
"duplicate " << domainName <<
" association for port " << portName;
951 auto ¬e1 = diag.attachNote(domainPortLoc1);
952 note1 <<
"associated with " << domainName <<
" port " << domainPortName1;
953 auto ¬e2 = diag.attachNote(domainPortLoc2);
954 note2 <<
"associated with " << domainName <<
" port " << domainPortName2;
955 noteLocation(diag, op);
961void ModuleState::emitDomainPortInferenceError(T op,
size_t i) {
962 auto name = op.getPortNameAttr(i);
963 auto diag = emitError(op->getLoc());
964 auto info = op.getDomainInfo();
965 diag <<
"unable to infer value for undriven domain port " << name;
966 for (
size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
967 if (
auto assocs = dyn_cast<ArrayAttr>(info[j])) {
968 for (
auto assoc : assocs) {
969 if (i == cast<IntegerAttr>(assoc).getValue()) {
970 auto name = op.getPortNameAttr(j);
971 auto loc = op.getPortLocation(j);
972 diag.attachNote(loc) <<
"associated with hardware port " << name;
978 noteLocation(diag, op);
982void ModuleState::emitAmbiguousPortDomainAssociation(
983 T op,
const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
985 auto portName = op.getPortNameAttr(i);
986 auto portLoc = op.getPortLocation(i);
987 auto domainDecl = getDomain(typeID);
988 auto domainName = domainDecl.getNameAttr();
989 auto diag = emitError(portLoc) <<
"ambiguous " << domainName
990 <<
" association for port " << portName;
991 for (
auto e : exports) {
992 auto arg = cast<BlockArgument>(e);
993 auto name = op.getPortNameAttr(arg.getArgNumber());
994 auto loc = op.getPortLocation(arg.getArgNumber());
995 diag.attachNote(loc) <<
"candidate association " << name;
997 noteLocation(diag, op);
1000template <
typename T>
1001void ModuleState::emitMissingPortDomainAssociationError(T op,
1002 DomainTypeID typeID,
1004 auto domainName = getDomain(typeID).getNameAttr();
1005 auto portName = op.getPortNameAttr(i);
1006 auto diag = emitError(op.getPortLocation(i))
1007 <<
"missing " << domainName <<
" association for port "
1009 noteLocation(diag, op);
1012LogicalResult ModuleState::unifyAssociations(Operation *op, Value lhs,
1021 llvm::dbgs().indent(6) <<
"unify domains(" << render(lhs) <<
") = domains("
1022 << render(rhs) <<
")\n";
1025 auto *lhsTerm = getOptDomainAssociation(lhs);
1026 auto *rhsTerm = getOptDomainAssociation(rhs);
1030 if (failed(unify(lhsTerm, rhsTerm))) {
1031 emitDomainCrossingError(op, lhs, lhsTerm, rhs, rhsTerm);
1036 setDomainAssociation(rhs, lhsTerm);
1041 setDomainAssociation(lhs, rhsTerm);
1045 auto *var = allocVar();
1046 setDomainAssociation(lhs, var);
1047 setDomainAssociation(rhs, var);
1051template <
typename T>
1052LogicalResult ModuleState::unifyAssociations(Operation *op, T &&range) {
1054 for (
auto rhs : std::forward<T>(range)) {
1057 if (failed(unifyAssociations(op, lhs, rhs)))
1065LogicalResult ModuleState::unifyAssociations(Operation *op) {
1066 return unifyAssociations(
1067 op, llvm::concat<Value>(op->getOperands(), op->getResults()));
1070LogicalResult ModuleState::processModulePorts(FModuleOp moduleOp) {
1071 auto numDomains = getNumDomains();
1072 auto domainInfo = moduleOp.getDomainInfoAttr();
1073 auto numPorts = moduleOp.getNumPorts();
1075 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1076 for (
size_t i = 0; i < numPorts; ++i) {
1077 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1081 LLVM_DEBUG(llvm::dbgs().indent(4)
1082 <<
"process port " << render(port) <<
"\n");
1084 if (moduleOp.getPortDirection(i) == Direction::In)
1085 processDomainDefinition(port);
1087 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1090 for (
size_t i = 0; i < numPorts; ++i) {
1091 BlockArgument port = moduleOp.getArgument(i);
1095 LLVM_DEBUG(llvm::dbgs().indent(4)
1096 <<
"process port " << render(port) <<
"\n");
1098 SmallVector<IntegerAttr> associations(numDomains);
1100 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1101 auto prevDomainPortIndex = associations[domainTypeID.index];
1102 if (prevDomainPortIndex) {
1103 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1104 prevDomainPortIndex, domainPortIndex);
1107 associations[domainTypeID.index] = domainPortIndex;
1110 SmallVector<Term *> elements(numDomains);
1111 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1112 ++domainTypeIndex) {
1113 auto domainPortIndex = associations[domainTypeIndex];
1114 if (!domainPortIndex)
1116 auto domainPortValue =
1117 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
1118 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1121 auto *domainAssociations = allocRow(elements);
1122 setDomainAssociation(port, domainAssociations);
1128template <
typename T>
1129LogicalResult ModuleState::processInstancePorts(T op) {
1130 auto numDomains = getNumDomains();
1131 auto domainInfo = op.getDomainInfoAttr();
1132 auto numPorts = op.getNumPorts();
1134 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1135 for (
size_t i = 0; i < numPorts; ++i) {
1136 auto port = dyn_cast<DomainValue>(op->getResult(i));
1140 if (op.getPortDirection(i) == Direction::Out)
1141 processDomainDefinition(port);
1143 domainTypeIDTable[i] = getDomainTypeID(op, i);
1146 for (
size_t i = 0; i < numPorts; ++i) {
1147 Value port = op->getResult(i);
1151 SmallVector<IntegerAttr> associations(numDomains);
1153 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1154 auto prevDomainPortIndex = associations[domainTypeID.index];
1155 if (prevDomainPortIndex) {
1156 emitDuplicatePortDomainError(op, i, domainTypeID, prevDomainPortIndex,
1160 associations[domainTypeID.index] = domainPortIndex;
1163 SmallVector<Term *> elements(numDomains);
1164 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1165 ++domainTypeIndex) {
1166 auto domainPortIndex = associations[domainTypeIndex];
1167 if (!domainPortIndex)
1169 auto domainPortValue =
1170 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
1171 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1174 auto *domainAssociations = allocRow(elements);
1175 setDomainAssociation(port, domainAssociations);
1181FInstanceLike ModuleState::fixInstancePorts(FInstanceLike op,
1182 const ModuleUpdateInfo &update) {
1183 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
1184 clone.setDomainInfoAttr(update.portDomainInfo);
1187 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"fixup " << render(clone) <<
"\n");
1191LogicalResult ModuleState::processOp(FInstanceLike op) {
1193 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
1194 auto updateTable = getModuleUpdateTable();
1195 auto lookup = updateTable.find(moduleName);
1196 if (lookup != updateTable.end())
1197 op = fixInstancePorts(op, lookup->second);
1198 return processInstancePorts(op);
1201LogicalResult ModuleState::processOp(UnsafeDomainCastOp op) {
1202 auto domains = op.getDomains();
1203 if (domains.empty())
1204 return unifyAssociations(op, op.getInput(), op.getResult());
1206 auto input = op.getInput();
1207 RowTerm *inputRow = getDomainAssociationAsRow(input);
1208 SmallVector<Term *> elements(inputRow->elements);
1209 for (
auto value : op.getDomains()) {
1210 auto domain = cast<DomainValue>(value);
1211 auto typeID = getDomainTypeID(domain);
1212 elements[typeID.index] = getTermForDomain(domain);
1215 auto *row = allocRow(elements);
1216 setDomainAssociation(op.getResult(), row);
1220LogicalResult ModuleState::processOp(DomainDefineOp op) {
1221 auto src = op.getSrc();
1222 auto dst = op.getDest();
1224 auto *srcTerm = getTermForDomain(src);
1225 auto *dstTerm = getTermForDomain(dst);
1226 if (succeeded(unify(dstTerm, srcTerm)))
1231 <<
"defines a domain value that was inferred to be a different domain '";
1232 render(dstTerm, diag);
1238LogicalResult ModuleState::processOp(WireOp op) {
1245 if (op.getDomains().empty())
1246 return unifyAssociations(op, op.getResults());
1250 SmallVector<Term *> elements(getNumDomains());
1251 for (
auto domain : op.getDomains()) {
1252 auto domainValue = cast<DomainValue>(domain);
1253 auto typeID = getDomainTypeID(domainValue);
1254 elements[typeID.index] = getTermForDomain(domainValue);
1257 auto *row = allocRow(elements);
1258 for (
auto result : op.getResults())
1259 setDomainAssociation(result, row);
1264LogicalResult ModuleState::processOp(RWProbeOp op) {
1265 auto target = globals.getInnerRefNamespace().lookup(op.getTarget());
1267 if (target.isPort()) {
1268 auto targetOp = cast<FModuleOp>(target.getOp());
1269 auto targetValue = targetOp.getArgument(target.getPort());
1270 return unifyAssociations(op, targetValue, op.getResult());
1273 auto targetOp = cast<hw::InnerSymbolOpInterface>(target.getOp());
1274 auto targetValue = targetOp.getTargetResult();
1275 return unifyAssociations(op, targetValue, op.getResult());
1278LogicalResult ModuleState::processOp(Operation *op) {
1279 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"process " << render(op) <<
"\n");
1280 if (
auto instance = dyn_cast<FInstanceLike>(op))
1281 return processOp(instance);
1282 if (
auto wireOp = dyn_cast<WireOp>(op))
1283 return processOp(wireOp);
1284 if (
auto cast = dyn_cast<UnsafeDomainCastOp>(op))
1285 return processOp(cast);
1286 if (
auto def = dyn_cast<DomainDefineOp>(op))
1287 return processOp(def);
1288 if (
auto probe = dyn_cast<RWProbeOp>(op))
1289 return processOp(probe);
1290 if (
auto create = dyn_cast<DomainCreateOp>(op)) {
1291 processDomainDefinition(create);
1294 if (
auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
1295 processDomainDefinition(createAnon);
1299 return unifyAssociations(op);
1302LogicalResult ModuleState::processModuleBody(FModuleOp moduleOp) {
1305 .walk([&](Operation *op) -> WalkResult { return processOp(op); })
1309LogicalResult ModuleState::processModule(FModuleOp moduleOp) {
1310 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"processing:\n");
1311 if (failed(processModulePorts(moduleOp)))
1313 if (failed(processModuleBody(moduleOp)))
1318ExportTable ModuleState::initializeExportTable(FModuleOp moduleOp) {
1320 size_t numPorts = moduleOp.getNumPorts();
1321 for (
size_t i = 0; i < numPorts; ++i) {
1322 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1325 auto value = getOptUnderlyingDomain(port);
1327 exports[value].push_back(port);
1331 llvm::dbgs().indent(2) <<
"domain exports:\n";
1332 for (
auto entry : exports) {
1333 llvm::dbgs().indent(4) << render(entry.first) <<
" exported as ";
1334 llvm::interleaveComma(entry.second, llvm::dbgs(),
1335 [&](
auto e) { llvm::dbgs() << render(e); });
1336 llvm::dbgs() <<
"\n";
1343void ModuleState::ensureSolved(
Namespace &ns, DomainTypeID typeID,
size_t ip,
1344 LocationAttr loc, VariableTerm *var,
1345 PendingUpdates &pending) {
1346 if (pending.solutions.contains(var))
1349 auto *
context = loc.getContext();
1350 auto domainDecl = getDomain(typeID);
1351 auto domainName = domainDecl.getNameAttr();
1353 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1354 auto portType = DomainType::getFromDomainOp(domainDecl);
1355 auto portDirection = Direction::In;
1356 auto portSym = StringAttr();
1358 auto portAnnos = std::nullopt;
1360 auto portDomainInfo = ArrayAttr::get(
context, {});
1361 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1362 portAnnos, portDomainInfo);
1364 pending.solutions[var] = pending.insertions.size() + ip;
1365 pending.insertions.push_back({ip, portInfo});
1369 DomainTypeID typeID,
size_t ip,
1370 LocationAttr loc, ValueTerm *val,
1371 PendingUpdates &pending) {
1372 auto value = val->value;
1373 assert(isa<DomainType>(value.getType()));
1374 if (
isPort(value) || exports.contains(value) ||
1375 pending.exports.contains(value))
1378 auto *
context = loc.getContext();
1380 auto domainDecl = getDomain(typeID);
1381 auto domainName = domainDecl.getNameAttr();
1383 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1384 auto portType = DomainType::getFromDomainOp(domainDecl);
1385 auto portDirection = Direction::Out;
1386 auto portSym = StringAttr();
1387 auto portAnnos = std::nullopt;
1389 auto portDomainInfo = ArrayAttr::get(
context, {});
1390 PortInfo portInfo(portName, portType, portDirection, portSym, loc, portAnnos,
1392 pending.exports[value] = pending.insertions.size() + ip;
1393 pending.insertions.push_back({ip, portInfo});
/// Queue the port changes needed so that a port's association with the
/// domain of type `typeID` can be expressed on the module interface.
/// An unsolved variable gets a new input domain port (ensureSolved). A
/// concrete domain value gets exported through a new output port if it is
/// not already visible (ensureExported). Any other term kind is a
/// programming error.
/// NOTE(review): original lines 1401-1402, 1405-1406 (presumably the
/// `return;`s closing each case) and 1408 are missing from this listing.
1396void ModuleState::getUpdatesForDomainAssociationOfPort(
1397 Namespace &ns, PendingUpdates &pending, DomainTypeID typeID,
size_t ip,
1398 LocationAttr loc, Term *term,
const ExportTable &exports) {
1399 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1400 ensureSolved(ns, typeID, ip, loc, var, pending);
1403 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1404 ensureExported(ns, exports, typeID, ip, loc, val, pending);
1407 llvm_unreachable(
"invalid domain association");
/// Row overload. The row holds one term per domain type, and its position
/// is the DomainTypeID index. Each element is passed through `find`
/// (presumably to obtain its representative term) before the per-domain
/// overload is invoked with that element's DomainTypeID.
/// NOTE(review): original line 1411 (the middle of the parameter list) and
/// the closing lines are missing from this listing.
1410void ModuleState::getUpdatesForDomainAssociationOfPort(
1412 RowTerm *row, PendingUpdates &pending) {
1413 for (
auto [index, term] :
llvm::enumerate(row->elements))
1414 getUpdatesForDomainAssociationOfPort(ns, pending, DomainTypeID{index}, ip,
1415 loc, find(term), exports);
/// Queue the port insertions needed to express the domain associations of
/// the module's ports. Each port's association row
/// (getDomainAssociationAsRow) is processed with the port's own index `i` as
/// the insertion point, and the port's location is passed along.
/// NOTE(review): original lines 1419-1420 (part of the parameter list) and
/// 1424-1426 (presumably a filter on which ports are considered) are missing
/// from this listing.
1418void ModuleState::getUpdatesForModulePorts(FModuleOp moduleOp,
1421 PendingUpdates &pending) {
1422 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1423 auto port = moduleOp.getArgument(i);
1427 getUpdatesForDomainAssociationOfPort(
1428 ns, exports, i, moduleOp.getPortLocation(i),
1429 getDomainAssociationAsRow(port), pending);
/// Compute every pending port update for `moduleOp`.
/// The existing port names are iterated first, presumably to seed the
/// namespace `ns` so that newly created port names do not collide with them.
/// The loop body (original line 1439) and the declaration of `ns`
/// (presumably line 1436) are not shown in this listing.
1433void ModuleState::getUpdatesForModule(FModuleOp moduleOp,
1435 PendingUpdates &pending) {
1437 auto names = moduleOp.getPortNamesAttr();
1438 for (
auto name : names.getAsRange<StringAttr>())
1440 getUpdatesForModulePorts(moduleOp, exports, ns, pending);
/// Apply the updates computed for `moduleOp`. Every queued port is inserted
/// with a single insertPorts call, and then the new ports are connected:
///  - New input ports become the solutions of the variables they were queued
///    for, and each is recorded in `exports` as aliasing itself.
///  - New output ports are driven from the internal domain value they export.
///    The define is placed right after that value's definition. Each port is
///    recorded in `exports` as an alias of that value and is given that value
///    as its term.
/// Every new port is also added to `globals.inserted`.
/// NOTE(review): original lines 1446, 1448-1450, 1459-1461 and 1472+ are
/// missing from this listing.
1443void ModuleState::applyUpdatesToModule(FModuleOp moduleOp,
ExportTable &exports,
1444 const PendingUpdates &pending) {
1445 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"applying updates:\n");
1447 moduleOp.insertPorts(pending.insertions);
// New input ports: each one solves the variable it was created for.
1451 for (
auto [var, portIndex] : pending.solutions) {
1452 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1453 auto *solution = allocVal(portValue);
1454 LLVM_DEBUG(llvm::dbgs().indent(4)
1455 <<
"new-input " << render(portValue) <<
"\n");
1456 solve(var, solution);
1457 exports[portValue].push_back(portValue);
1458 globals.inserted.insert(portValue);
// New output ports: drive each one from the internal value it re-exports.
1462 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1463 for (
auto [domainValue, portIndex] : pending.exports) {
1464 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1465 builder.setInsertionPointAfterValue(domainValue);
1466 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1467 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"new-output " << render(portValue)
1468 <<
" := " << render(domainValue) <<
"\n");
1469 exports[domainValue].push_back(portValue);
1470 globals.inserted.insert(portValue);
1471 setTermForDomain(portValue, allocVal(domainValue));
/// Return the existing domain associations of port `portIndex` as a vector
/// with one slot per domain type, indexed by DomainTypeID::index. Each slot
/// holds the association's port-index attribute. Slots for domain types with
/// no existing association are left as null attributes.
/// NOTE(review): original line 1478, presumably defining `oldAssociations`
/// from `moduleDomainInfo` and `portIndex`, is missing from this listing.
1475SmallVector<Attribute> ModuleState::copyPortDomainAssociations(
1476 FModuleOp moduleOp, ArrayAttr moduleDomainInfo,
size_t portIndex) {
1477 SmallVector<Attribute> result(getNumDomains());
1479 for (
auto domainPortIndexAttr : oldAssociations) {
1480 auto domainPortIndex = domainPortIndexAttr.getUInt();
1481 auto domainTypeID = getDomainTypeID(moduleOp, domainPortIndex);
1482 result[domainTypeID.index] = domainPortIndexAttr;
/// Drive the module's output domain ports from their inferred domains.
/// Non-domain ports and input ports are skipped. Original lines 1492-1494,
/// which may skip further ports, are not shown. Each remaining port's term
/// must be a concrete ValueTerm. The port is then driven from that value by
/// a DomainDefineOp appended at the end of the module body. Otherwise an
/// inference error is emitted for the port (presumably when `val` is null;
/// the guard on original line 1497 is not shown).
1487LogicalResult ModuleState::driveModuleOutputDomainPorts(FModuleOp moduleOp) {
1488 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1489 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1490 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1491 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1495 auto *term = getOptTermForDomain(port);
1496 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1498 emitDomainPortInferenceError(moduleOp, i);
// The port's domain was inferred to a concrete value: connect it.
1502 auto loc = port.getLoc();
1503 auto value = val->value;
1504 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"connect " << render(port)
1505 <<
" := " << render(value) <<
"\n");
1506 DomainDefineOp::create(builder, loc, port, value);
/// Recompute the per-port domain info attribute of `moduleOp`, store it on
/// the module, and return it through `result`.
///  - Domain-typed ports get an empty association list. A second category of
///    ports also gets an empty list; its guard (original lines 1531-1534) is
///    not shown, but presumably covers non-hardware ports.
///  - For the remaining ports, existing associations are kept
///    (copyPortDomainAssociations). Each remaining domain type is filled in
///    with the index of the port that exports the inferred domain value,
///    looked up in `exportTable`.
/// Errors are emitted when the inferred domain is exported by no port
/// ("private" association) or by more than one port (ambiguous association).
/// NOTE(review): many original lines are missing from this listing, e.g.
/// 1514-1517, 1523, 1527, 1529, 1531-1539, 1545-1546, 1558-1560, 1563-1565,
/// 1570-1572, 1574-1575 and 1578+. These include the guards and early
/// returns around the cases above.
1512LogicalResult ModuleState::updateModuleDomainInfo(
1513 FModuleOp moduleOp,
const ExportTable &exportTable, ArrayAttr &result) {
1518 auto *
context = moduleOp.getContext();
1519 auto numDomains = getNumDomains();
1520 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1521 auto numPorts = moduleOp.getNumPorts();
1522 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1524 for (
size_t i = 0; i < numPorts; ++i) {
1525 auto port = moduleOp.getArgument(i);
1526 auto type = port.getType();
// Domain-typed ports get an empty association list.
1528 if (isa<DomainType>(type)) {
1530 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1535 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
1540 copyPortDomainAssociations(moduleOp, oldModuleDomainInfo, i);
1541 auto *row = cast<RowTerm>(getDomainAssociation(port));
1542 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1543 auto domainTypeID = DomainTypeID{domainIndex};
// Associations that already exist are not recomputed.
1544 if (associations[domainIndex])
// The inferred domain must be aliased by exactly one port to be nameable on
// the interface.
1547 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1548 auto &exports = exportTable.at(domain);
1549 if (exports.empty()) {
1550 auto portName = moduleOp.getPortNameAttr(i);
1551 auto portLoc = moduleOp.getPortLocation(i);
1552 auto domainDecl = getDomain(domainTypeID);
1553 auto domainName = domainDecl.getNameAttr();
1554 auto diag = emitError(portLoc) <<
"private " << domainName
1555 <<
" association for port " << portName;
1556 diag.attachNote(domain.getLoc()) <<
"associated domain: " << domain;
1557 noteLocation(diag, moduleOp);
1561 if (exports.size() > 1) {
1562 emitAmbiguousPortDomainAssociation(moduleOp, exports, domainTypeID, i);
// Exactly one port aliases the domain. Record that port's index as an
// unsigned 32-bit integer attribute.
1566 auto argument = cast<BlockArgument>(exports[0]);
1567 auto domainPortIndex = argument.getArgNumber();
1568 associations[domainTypeID.index] =
1569 IntegerAttr::get(IntegerType::get(
context, 32, IntegerType::Unsigned),
1573 newModuleDomainInfo[i] = ArrayAttr::get(
context, associations);
1576 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1577 moduleOp.setDomainInfoAttr(result);
// solveVarWithAnonDomain (the first signature line, original line 1581, is
// not shown). Solve `var` with a fresh anonymous domain of type `type`. The
// domain is named after the domain type and created at the builder's
// insertion point with `user`'s location. It is registered as in scope in
// `domainsInScope` and recorded in `globals.inserted`. Callers use the
// returned value as the solution.
// NOTE(review): original lines 1585 (presumably binding `anon`), 1587 and
// 1592+ (presumably the return) are missing from this listing.
1582 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1583 Operation *user, DomainType type, VariableTerm *var) {
1584 auto name = type.getName().getAttr();
1586 DomainCreateAnonOp::create(builder, user->getLoc(), type, name);
1588 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"create anon " << render(anon) <<
"\n");
1589 solve(var, allocVal(anon));
1590 domainsInScope[anon] = anon;
1591 globals.inserted.insert(anon);
// getDomainInScope (signature lines 1595 and 1597 are not shown). Return a
// value usable in place of `domain` at the builder's insertion point.
// If `domainsInScope` already maps `domain` to a value, that value is
// returned (presumably; the check on original line 1599 is not shown).
// Otherwise a "bounce" wire of the same domain type, named after the domain
// type, is created at the insertion point. The wire is driven from `domain`
// right after `domain`'s definition, cached in `domainsInScope`, and
// returned. Note that `domainsInScope[domain]` default-inserts an entry.
1596 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1598 auto &domainInScope = domainsInScope[domain];
1600 return domainInScope;
1602 domainInScope = cast<DomainValue>(
1603 WireOp::create(builder, domain.getLoc(), domain.getType(),
1604 domain.getType().getName().getAttr())
// Drive the wire directly after `domain` is defined. The guard restores the
// builder's insertion point when it goes out of scope.
1607 OpBuilder::InsertionGuard guard(builder);
1608 builder.setInsertionPointAfterValue(domain);
1609 DomainDefineOp::create(builder, domain.getLoc(), domainInScope, domain);
1611 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"bounce wire " << render(domainInScope)
1612 <<
" := " << render(domain) <<
"\n");
1613 return domainInScope;
/// Drive the undriven input domain ports of instance `op`.
/// The instance's output domain ports are first added to `domainsInScope`.
/// Then each input domain port that is not already driven (isDriven) is
/// handled as follows:
///  - If its term is still an unsolved variable, the variable is solved with
///    a fresh anonymous domain and the port is driven from it.
///  - If its term is a concrete value, the port is driven from a copy of that
///    value that is usable here (getDomainInScope).
/// The builder starts right after the instance. Any other term kind is a
/// programming error.
/// NOTE(review): the return type (original line 1616), the rest of the
/// parameter list (1618) and lines 1623, 1628, 1632-1634, 1644-1645,
/// 1651-1652 and 1654+ are missing from this listing.
1617ModuleState::updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
1619 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(op) <<
"\n");
1620 OpBuilder builder(op.getContext());
1621 builder.setInsertionPointAfter(op);
1622 auto numPorts = op->getNumResults();
// The instance's output domain ports can be used as domains from here on.
1624 for (
size_t i = 0; i < numPorts; ++i)
1625 if (
auto port = dyn_cast<DomainValue>(op->getResult(i)))
1626 if (op.getPortDirection(i) == Direction::Out)
1627 domainsInScope[port] = port;
1629 for (
size_t i = 0; i < numPorts; ++i) {
1630 auto port = dyn_cast<DomainValue>(op->getResult(i));
1631 auto direction = op.getPortDirection(i);
// Only undriven input domain ports need a driver.
1635 if (port && direction == Direction::In && !
isDriven(port)) {
1636 auto loc = port.getLoc();
1637 auto *term = getTermForDomain(port);
1638 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1639 auto domain = solveVarWithAnonDomain(builder, domainsInScope, op,
1640 port.getType(), var);
1641 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"connect " << render(port)
1642 <<
" := " << render(domain) <<
"\n");
1643 DomainDefineOp::create(builder, loc, port, domain);
1646 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1647 auto domain = getDomainInScope(builder, domainsInScope, val->value);
1648 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"connect " << render(port)
1649 <<
" := " << render(domain) <<
"\n");
1650 DomainDefineOp::create(builder, loc, port, domain);
1653 llvm_unreachable(
"unhandled domain term type");
/// Update wire `wireOp` with the inferred domain information.
///  - A domain-typed wire is driven from its inferred domain. That is a fresh
///    anonymous domain when its term is still an unsolved variable, and
///    otherwise a copy of the concrete value that is usable here
///    (getDomainInScope).
///  - Other wires (the guarding lines are not shown) have their domain
///    operands rebuilt from their association row. There is one operand per
///    domain type, in DomainTypeID order, and any unsolved variable is solved
///    with a fresh anonymous domain of that type.
/// NOTE(review): the return type (original line 1660), the rest of the
/// parameter list (1662) and lines 1664, 1666-1668, 1679-1680, 1686-1687,
/// 1689-1693, 1697, 1704-1705, 1708, 1711-1712, 1714 and 1716+ are missing
/// from this listing.
1661ModuleState::updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
1663 auto result = wireOp.getResult();
// Domain-typed wire: give it a driver.
1665 if (
auto tgt = dyn_cast<DomainValue>(result)) {
1669 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(wireOp) <<
"\n");
1670 OpBuilder builder(wireOp);
1671 builder.setInsertionPointAfter(wireOp);
1672 auto *term = getTermForDomain(tgt);
1673 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1674 auto src = solveVarWithAnonDomain(builder, domainsInScope, wireOp,
1675 tgt.getType(), var);
1676 LLVM_DEBUG(llvm::dbgs().indent(6)
1677 <<
"connect " << render(tgt) <<
" := " << render(src) <<
"\n");
1678 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1681 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1682 auto src = getDomainInScope(builder, domainsInScope, val->value);
1683 LLVM_DEBUG(llvm::dbgs().indent(6)
1684 <<
"connect " << render(tgt) <<
" := " << render(src) <<
"\n");
1685 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1688 llvm_unreachable(
"unhandled domain term type");
// Presumably a hardware wire (the guard is not shown): recompute its domain
// operands, one per domain type, from the inferred association row.
1694 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(wireOp) <<
"\n");
1695 OpBuilder builder(wireOp);
1696 auto *row = getDomainAssociationAsRow(wireOp.getResult());
1698 SmallVector<Value> domainOperands;
1699 for (
auto [i, element] :
llvm::enumerate(
1700 llvm::map_range(row->elements, [&](auto e) {
return find(e); }))) {
1701 if (
auto *val = dyn_cast<ValueTerm>(element)) {
1702 domainOperands.push_back(
1703 getDomainInScope(builder, domainsInScope, val->value));
1706 if (
auto *var = dyn_cast<VariableTerm>(element)) {
1707 auto type = DomainType::getFromDomainOp(getDomain(DomainTypeID{i}));
1709 solveVarWithAnonDomain(builder, domainsInScope, wireOp, type, var);
1710 domainOperands.push_back(domain);
1713 assert(0 &&
"unhandled domain type");
1715 wireOp.getDomainsMutable().assign(domainOperands);
/// Materialize the inferred domains in the body of `moduleOp`. The module's
/// input domain ports are in scope from the start. The body is then walked:
/// each wire is updated (updateWire), each instance is updated
/// (updateInstance), and each explicitly created domain is added to
/// `domainsInScope`. Fails if the walk is interrupted.
/// NOTE(review): original lines 1721, 1726, 1729 (presumably the WireOp
/// case), 1733, 1736-1737, 1739 and 1741+ are missing from this listing.
1719LogicalResult ModuleState::updateModuleBody(FModuleOp moduleOp) {
1720 DenseMap<DomainValue, DomainValue> domainsInScope;
// Input domain ports are usable anywhere in the body.
1722 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i)
1723 if (
auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i)))
1724 if (moduleOp.getPortDirection(i) == Direction::In)
1725 domainsInScope[port] = port;
// Dispatch on each operation kind in the body.
1727 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1728 return TypeSwitch<Operation *, WalkResult>(op)
1730 [&](
auto wire) {
return updateWire(domainsInScope, wire); })
1731 .Case<FInstanceLike>([&](
auto instance) {
1732 return updateInstance(domainsInScope, instance);
1734 .Case<DomainCreateOp, DomainCreateAnonOp>([&](
auto domain) {
1735 domainsInScope[domain] = domain;
1738 .Default([&](
auto op) {
return success(); });
1740 return failure(result.wasInterrupted());
/// Run the full inference update on `moduleOp`:
///  1. compute and apply the pending port insertions (new input ports for
///     unsolved domains, new output ports for exported internal domains);
///  2. rebuild the port domain info attribute;
///  3. drive the module's output domain ports;
///  4. record the new domain info and the applied port insertions in the
///     circuit-wide module update table, keyed by module name (presumably so
///     that instances of this module can be updated later);
///  5. update the module body (wires and instances).
/// In debug builds a per-port summary of the resulting domain associations is
/// printed.
/// NOTE(review): original lines 1748, 1751-1752, 1754-1756, 1760,
/// 1762-1764, 1770, 1775, 1777 and 1779+ (early returns on failure and the
/// debug-block delimiters) are missing from this listing.
1743LogicalResult ModuleState::updateModule(FModuleOp moduleOp) {
1744 auto exports = initializeExportTable(moduleOp);
1745 PendingUpdates pending;
1746 getUpdatesForModule(moduleOp, exports, pending);
1747 applyUpdatesToModule(moduleOp, exports, pending);
1749 ArrayAttr portDomainInfo;
1750 if (failed(updateModuleDomainInfo(moduleOp, exports, portDomainInfo)))
1753 if (failed(driveModuleOutputDomainPorts(moduleOp)))
// Publish the interface changes for this module. `pending.insertions` is
// moved from here and must not be used afterwards.
1757 auto &entry = getModuleUpdateTable()[moduleOp.getModuleNameAttr()];
1758 entry.portDomainInfo = portDomainInfo;
1759 entry.portInsertions = std::move(pending.insertions);
1761 if (failed(updateModuleBody(moduleOp)))
// Debug-only summary: each port followed by the ports it is associated with.
1765 llvm::dbgs().indent(2) <<
"port summary:\n";
1766 for (
auto port : moduleOp.
getBodyBlock()->getArguments()) {
1767 llvm::dbgs().indent(4) << render(port);
1768 auto info = cast<ArrayAttr>(
1769 moduleOp.getDomainInfoAttrForPort(port.getArgNumber()));
1771 llvm::dbgs() <<
" domains [";
1772 llvm::interleaveComma(
1773 info.getAsRange<IntegerAttr>(), llvm::dbgs(), [&](
auto i) {
1774 llvm::dbgs() << render(moduleOp.getArgument(i.getUInt()));
1776 llvm::dbgs() <<
"]";
1778 llvm::dbgs() <<
"\n";
/// Check the declared domain associations of `moduleOp`'s ports. A checked
/// port may name at most one domain port of each domain type, and must name
/// one for every domain type. Duplicate and missing associations are
/// reported as errors. Which ports are checked is decided by original lines
/// 1797-1800, which are not shown (presumably hardware ports only).
/// NOTE(review): original lines 1789, 1794-1795, 1797-1800, 1802 (presumably
/// the loop over the port's associations), 1808-1809, 1811-1813 and 1818+
/// are missing from this listing.
1785LogicalResult ModuleState::checkModulePorts(FModuleLike moduleOp) {
1786 auto numDomains = getNumDomains();
1787 auto domainInfo = moduleOp.getDomainInfoAttr();
1788 auto numPorts = moduleOp.getNumPorts();
// Map each domain-typed port's index to its domain type.
1790 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1791 for (
size_t i = 0; i < numPorts; ++i) {
1792 if (isa<DomainType>(moduleOp.getPortType(i)))
1793 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1796 for (
size_t i = 0; i < numPorts; ++i) {
// One slot per domain type: reject a second association of the same type.
1801 SmallVector<IntegerAttr> associations(numDomains);
1803 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1804 auto prevDomainPortIndex = associations[domainTypeID.index];
1805 if (prevDomainPortIndex) {
1806 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1807 prevDomainPortIndex, domainPortIndex);
1810 associations[domainTypeID.index] = domainPortIndex;
// Every domain type must be covered.
1814 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1815 auto typeID = DomainTypeID{domainIndex};
1816 if (!associations[domainIndex]) {
1817 emitMissingPortDomainAssociationError(moduleOp, typeID, i);
/// Report an "undriven domain port" error for output domain ports of
/// `moduleOp` that are not skipped. The skip condition continues on original
/// lines 1830-1832, which are missing from this listing; presumably it skips
/// ports that are already driven.
1826LogicalResult ModuleState::checkModuleDomainPortDrivers(FModuleOp moduleOp) {
1827 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1828 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1829 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1833 auto name = moduleOp.getPortNameAttr(i);
1834 auto diag = emitError(moduleOp.getPortLocation(i))
1835 <<
"undriven domain port " << name;
1836 noteLocation(diag, moduleOp);
/// Report an "undriven domain port" error for input domain ports of instance
/// `op` that are not skipped. The skip condition continues on original lines
/// 1849-1851, which are missing from this listing; presumably it skips ports
/// that are already driven.
/// NOTE(review): `dyn_cast` yields a null value for non-domain results, so
/// `port.getType()` below is only safe if original line 1846 (not shown)
/// skips null ports. In that case the `!isa<DomainType>(type)` test is
/// redundant. Confirm against upstream.
1843LogicalResult ModuleState::checkInstanceDomainPortDrivers(FInstanceLike op) {
1844 for (
size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1845 auto port = dyn_cast<DomainValue>(op->getResult(i));
1847 auto type = port.getType();
1848 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1852 auto name = op.getPortNameAttr(i);
1853 auto diag = emitError(op.getPortLocation(i))
1854 <<
"undriven domain port " << name;
1855 noteLocation(diag, op);
/// Check every instance-like op in `moduleOp`'s body for undriven input
/// domain ports. A failed check converts to WalkResult::interrupt, which
/// stops the walk and makes this function return failure.
1862LogicalResult ModuleState::checkModuleBody(FModuleOp moduleOp) {
1863 auto result = moduleOp.getBody().walk([&](FInstanceLike op) -> WalkResult {
1864 return checkInstanceDomainPortDrivers(op);
1866 return failure(result.wasInterrupted());
/// Full inference for a module: run processModule, then apply the inferred
/// updates with updateModule, which may add ports. runOnModule uses this in
/// InferAll mode, and for private modules in the other non-Check modes.
/// NOTE(review): original lines 1872-1873 (presumably `return failure();`)
/// are missing from this listing.
1869LogicalResult ModuleState::inferModule(FModuleOp moduleOp) {
1870 LLVM_DEBUG(llvm::dbgs() <<
"infer: " << moduleOp.getModuleName() <<
"\n");
1871 if (failed(processModule(moduleOp)))
1874 return updateModule(moduleOp);
/// Check mode for a module with a body. It checks, in order, the port domain
/// associations, the module's own output domain port drivers, and the input
/// domain port drivers of instances in the body. It then runs processModule.
/// NOTE(review): original lines 1880-1881, 1883-1884 and 1886-1887
/// (presumably early `return failure();`s) are missing from this listing.
1877LogicalResult ModuleState::checkModule(FModuleOp moduleOp) {
1878 LLVM_DEBUG(llvm::dbgs() <<
"check: " << moduleOp.getModuleName() <<
"\n");
1879 if (failed(checkModulePorts(moduleOp)))
1882 if (failed(checkModuleDomainPortDrivers(moduleOp)))
1885 if (failed(checkModuleBody(moduleOp)))
1888 return processModule(moduleOp);
/// Check an external module. Only its port domain associations are checked.
/// runOnModule uses this for external modules in every non-Strip mode.
1891LogicalResult ModuleState::checkModule(FExtModuleOp extModuleOp) {
1892 LLVM_DEBUG(llvm::dbgs() <<
"check: " << extModuleOp.getModuleName() <<
"\n");
1893 return checkModulePorts(extModuleOp);
/// Check the interface and infer the body of a module. runOnModule uses this
/// for public modules when the mode is neither Check nor InferAll. The port
/// domain associations are checked rather than inferred. The module is then
/// processed, its output domain ports are driven, and its body is updated.
/// None of the calls visible here add ports or rewrite the domain info
/// attribute.
/// NOTE(review): original lines 1898-1899 (the end of the debug statement),
/// 1901-1902, 1904-1905 and 1907-1908 (presumably early returns on failure)
/// are missing from this listing.
1896LogicalResult ModuleState::checkAndInferModule(FModuleOp moduleOp) {
1897 LLVM_DEBUG(llvm::dbgs() <<
"check/infer: " << moduleOp.getModuleName()
1900 if (failed(checkModulePorts(moduleOp)))
1903 if (failed(processModule(moduleOp)))
1906 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1909 return updateModuleBody(moduleOp);
// stripModule (its signature, original line 1916, is not shown). Remove all
// domain information from `op` and the operations nested in it. The walk is
// post-order with reversed iteration, presumably so that users are visited
// before the operations defining the values they use.
//  - Module-like ops: domain-typed ports are marked (the set() on original
//    line 1925 is not shown) and erased.
//  - Domain define/create ops: handled on original lines 1930-1931, not
//    shown (presumably erased).
//  - Domain subfields: any remaining uses are replaced with an
//    UnknownValueOp of the same type. The op itself is presumably erased on
//    a line that is not shown.
//  - Unsafe domain casts: uses are forwarded to the cast's input.
//  - Wires: domain-typed wires are handled on lines not shown (presumably
//    erased). Other wires with domains have all of their operands erased.
//  - Instance-like ops: cloned without their domain-typed ports, with uses
//    replaced.
//  - Any other op with a domain-typed operand or result is an error
//    ("cannot be stripped") and interrupts the walk.
// Returns failure if the walk was interrupted.
1917 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1918 [=](Operation *op) -> WalkResult {
1919 return TypeSwitch<Operation *, WalkResult>(op)
1920 .Case<FModuleLike>([](FModuleLike op) {
1921 auto n = op.getNumPorts();
1922 BitVector erasures(n);
1923 for (
size_t i = 0; i < n; ++i)
1924 if (isa<DomainType>(op.getPortType(i)))
1926 op.erasePorts(erasures);
1927 return WalkResult::advance();
1929 .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
1932 return WalkResult::advance();
1934 .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
1935 if (!op->use_empty()) {
1936 OpBuilder builder(op);
1937 op.replaceAllUsesWith(
1938 UnknownValueOp::create(builder, op.getLoc(), op.getType())
1942 return WalkResult::advance();
1944 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1945 op.replaceAllUsesWith(op.getInput());
1947 return WalkResult::advance();
1949 .Case<WireOp>([](WireOp op) {
1951 if (isa<DomainType>(op.getType(0))) {
1953 return WalkResult::advance();
1956 if (!op.getDomains().empty()) {
1957 op->eraseOperands(0, op.getNumOperands());
1959 return WalkResult::advance();
1961 .Case<FInstanceLike>([](
auto op) {
1962 auto n = op.getNumPorts();
1963 BitVector erasures(n);
1964 for (
size_t i = 0; i < n; ++i)
1965 if (isa<DomainType>(op->getResult(i).getType()))
1967 op.cloneWithErasedPortsAndReplaceUses(erasures);
1969 return WalkResult::advance();
1971 .Default([](Operation *op) {
1973 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1974 if (isa<DomainType>(type)) {
1975 op->emitOpError(
"cannot be stripped");
1976 return WalkResult::interrupt();
1979 return WalkResult::advance();
1982 return failure(result.wasInterrupted());
// stripCircuit (signature not shown). Collect the circuit's module-like ops
// and erase its domain declarations. make_early_inc_range allows erasing the
// current op while iterating. The remainder (original lines 1991-1998,
// presumably stripping each collected module) is missing from this listing.
1986 llvm::SmallVector<FModuleLike> modules;
1987 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1988 TypeSwitch<Operation *, void>(&op)
1989 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1990 .Case<DomainOp>([](DomainOp op) { op.erase(); });
/// Check and/or infer domains for a single module according to `mode`:
///  - Check mode: modules are only checked.
///  - InferAll mode, or a private module: full inference (ports may change).
///  - Otherwise (a public module): check the interface and infer the body.
/// External modules are always only checked. Must not be called in Strip
/// mode.
/// NOTE(review): original lines 2005, 2008, 2010-2011 and 2014+ (including
/// the result for any other kind of module) are missing from this listing.
1999LogicalResult CircuitState::runOnModule(Operation *op) {
2000 assert(mode != InferDomainsMode::Strip);
2001 ModuleState state(*
this);
2002 if (
auto moduleOp = dyn_cast<FModuleOp>(op)) {
2003 if (mode == InferDomainsMode::Check)
2004 return state.checkModule(moduleOp);
2006 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
2007 return state.inferModule(moduleOp);
2009 return state.checkAndInferModule(moduleOp);
2012 if (
auto extModuleOp = dyn_cast<FExtModuleOp>(op))
2013 return state.checkModule(extModuleOp);
/// Process every module in a post-order walk of the instance graph, so that
/// (presumably) a module's instantiated modules are processed before it.
/// A module that instantiates an errored module is itself marked errored.
/// The lines that follow (original 2025-2027) are missing and presumably skip
/// the module. Succeeds only if no module errored.
2018LogicalResult CircuitState::run() {
2019 DenseSet<Operation *> errored;
2020 instanceGraph.walkPostOrder([&](
auto &node) {
2021 auto moduleOp = node.getModule();
2022 for (
auto *inst : node) {
2023 if (errored.contains(inst->getTarget()->getModule())) {
2024 errored.insert(moduleOp);
2028 if (failed(runOnModule(node.getModule())))
2029 errored.insert(moduleOp);
2031 return success(errored.empty());
2035struct InferDomainsPass
2036 :
public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
2038 void runOnOperation()
override {
2040 auto circuit = getOperation();
2042 if (mode == InferDomainsMode::Strip) {
2044 signalPassFailure();
2048 auto &instanceGraph = getAnalysis<InstanceGraph>();
2049 auto &symbolTable = getAnalysis<SymbolTable>();
2050 auto &innerSymbolTableCollection =
2051 getAnalysis<InnerSymbolTableCollection>();
2053 innerSymbolTableCollection};
2054 CircuitState state(circuit, instanceGraph, innerRefNamespace, mode);
2055 if (failed(state.run()))
2056 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static bool isHardware(Type type)
True if a value of the given type could be associated with a domain.
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static LogicalResult stripModule(FModuleLike op)
static Block * getBodyBlock(FModuleLike mod)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This graph tracks modules and where they are instantiated.
This class represents a collection of InnerSymbolTable's.
static StringRef toLongString(Direction direction)
InferDomainsMode
The mode for the InferDomains pass.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
This holds the name and type that describes the module's ports.
This class represents the namespace in which InnerRef's can be resolved.