26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Iterators.h"
28#include "mlir/IR/Threading.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TinyPtrVector.h"
34#define DEBUG_TYPE "firrtl-infer-domains"
38#define GEN_PASS_DEF_INFERDOMAINS
39#include "circt/Dialect/FIRRTL/Passes.h.inc"
44using namespace firrtl;
50using mlir::InFlightDiagnostic;
51using mlir::ReverseIterator;
69 return info.getAsRange<IntegerAttr>();
70 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
74static bool isPort(BlockArgument arg) {
75 return isa<FModuleOp>(arg.getOwner()->getParentOp());
80 auto arg = dyn_cast<BlockArgument>(value);
88 for (
auto *user : port.getUsers())
89 if (
auto connect = dyn_cast<FConnectLike>(user))
90 if (connect.getDest() == port)
97 return type_isa<FIRRTLBaseType, RefType>(type);
105 if (
auto *op = value.getDefiningOp())
106 if (op->hasTrait<OpTrait::ConstantLike>())
129struct ModuleUpdateInfo {
131 ArrayAttr portDomainInfo;
139 CircuitState(CircuitOp circuit,
InstanceGraph &instanceGraph,
141 : circuit(circuit), instanceGraph(instanceGraph),
142 innerRefNamespace(innerRefNamespace), mode(mode) {
143 processCircuit(circuit);
148 ArrayRef<DomainOp> getDomains()
const {
return domainTable; }
149 size_t getNumDomains()
const {
return domainTable.size(); }
150 DomainOp getDomain(DomainTypeID
id)
const {
return domainTable[
id.index]; }
151 DomainTypeID getDomainTypeID(Type type) {
return typeIDTable[type]; }
153 void dirty() { asmState =
nullptr; }
154 AsmState &getAsmState() {
156 asmState = std::make_unique<AsmState>(
157 circuit, mlir::OpPrintingFlags().assumeVerified());
162 size_t getVariableID(VariableTerm *term) {
163 return variableIDTable.insert({term, variableIDTable.size() + 1})
167 DenseMap<StringAttr, ModuleUpdateInfo> &getModuleUpdateTable() {
168 return moduleUpdateTable;
173 DenseSet<Value> inserted;
176 LogicalResult runOnModule(Operation *moduleOp);
178 void processDomain(DomainOp op) {
179 auto index = domainTable.size();
180 auto domainType = DomainType::getFromDomainOp(op);
181 domainTable.push_back(op);
182 typeIDTable.insert({domainType, {index}});
185 void processCircuit(CircuitOp circuit) {
186 for (
auto decl : circuit.getOps<DomainOp>())
194 SmallVector<DomainOp> domainTable;
195 DenseMap<Type, DomainTypeID> typeIDTable;
196 DenseMap<VariableTerm *, size_t> variableIDTable;
197 std::unique_ptr<AsmState> asmState;
198 DenseMap<StringAttr, ModuleUpdateInfo> moduleUpdateTable;
218 constexpr Term(TermKind kind) : kind(kind) {}
226struct TermBase : Term {
227 static bool classof(
const Term *term) {
return term->kind == K; }
228 TermBase() : Term(K) {}
234struct VariableTerm :
public TermBase<TermKind::Variable> {
235 VariableTerm() : leader(nullptr) {}
236 VariableTerm(Term *leader) : leader(leader) {}
243struct ValueTerm :
public TermBase<TermKind::Value> {
251struct RowTerm :
public TermBase<TermKind::Row> {
252 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
253 ArrayRef<Term *> elements;
272struct PendingUpdates {
282using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
287 explicit ModuleState(CircuitState &globals) : globals(globals) {}
289 ArrayRef<DomainOp> getDomains() {
return globals.getDomains(); }
290 size_t getNumDomains() {
return globals.getNumDomains(); }
291 DomainOp getDomain(DomainTypeID
id) {
return globals.getDomain(
id); }
292 DomainTypeID getDomainTypeID(Type type) {
293 return globals.getDomainTypeID(type);
295 DomainTypeID getDomainTypeID(FModuleLike module,
size_t i) {
296 return globals.getDomainTypeID(module.getPortType(i));
298 DomainTypeID getDomainTypeID(FInstanceLike op,
size_t i)
const {
299 return globals.getDomainTypeID(op->getResult(i).getType());
301 DomainTypeID getDomainTypeID(
DomainValue value)
const {
302 return globals.getDomainTypeID(value.getType());
304 auto &getModuleUpdateTable() {
return globals.getModuleUpdateTable(); }
306 mlir::AsmState &getAsmState() {
return globals.getAsmState(); }
307 void dirty() { globals.dirty(); }
309 template <
typename T>
310 void render(Operation *op, T &out);
311 template <
typename T>
312 void render(Value value, T &out);
313 template <
typename T>
314 void renderLong(Value value, T &out);
315 template <
typename T>
316 void render(Term *term, T &out);
317 template <
typename T>
319 template <
typename T>
320 Render<T> render(T &&subject);
322 RenderLong renderLong(Value value);
325 LogicalResult unify(Term *lhs, Term *rhs);
326 LogicalResult unify(VariableTerm *x, Term *y);
327 LogicalResult unify(ValueTerm *xv, Term *y);
328 LogicalResult unify(RowTerm *lhsRow, Term *rhs);
329 void solve(Term *lhs, Term *rhs);
331 [[nodiscard]] RowTerm *allocRow(
size_t size);
332 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements);
333 [[nodiscard]] VariableTerm *allocVar();
334 [[nodiscard]] ValueTerm *allocVal(
DomainValue value);
335 template <
typename T,
typename... Args>
336 T *alloc(Args &&...args);
337 ArrayRef<Term *> allocArray(ArrayRef<Term *> elements);
342 void setTermForDomain(
DomainValue value, Term *term);
344 Term *getOptDomainAssociation(Value value);
345 Term *getDomainAssociation(Value value);
346 void setDomainAssociation(Value value, Term *term);
349 RowTerm *getDomainAssociationAsRow(Value value);
351 void noteLocation(InFlightDiagnostic &diag, Operation *op);
352 void noteDomain(InFlightDiagnostic &diag,
DomainValue domain);
353 void noteDomainSource(InFlightDiagnostic &diag,
DomainValue domain);
354 void noteDomainSource(InFlightDiagnostic &diag, Term *term);
355 void emitDomainCrossingError(Operation *op, Value lhs, Term *lhsTerm,
356 Value rhs, Term *rhsTerm);
357 template <
typename T>
358 void emitDuplicatePortDomainError(T op,
size_t i, DomainTypeID domainTypeID,
359 IntegerAttr domainPortIndexAttr1,
360 IntegerAttr domainPortIndexAttr2);
361 template <
typename T>
362 void emitDomainPortInferenceError(T op,
size_t i);
363 template <
typename T>
364 void emitAmbiguousPortDomainAssociation(
365 T op,
const llvm::TinyPtrVector<DomainValue> &exports,
366 DomainTypeID typeID,
size_t i);
367 template <
typename T>
368 void emitMissingPortDomainAssociationError(T op, DomainTypeID typeID,
371 LogicalResult unifyAssociations(Operation *op, Value lhs, Value rhs);
372 template <
typename T>
373 LogicalResult unifyAssociations(Operation *op, T &&range);
374 LogicalResult unifyAssociations(Operation *op);
376 LogicalResult processModulePorts(FModuleOp moduleOp);
377 template <
typename T>
378 LogicalResult processInstancePorts(T op);
379 FInstanceLike fixInstancePorts(FInstanceLike op,
380 const ModuleUpdateInfo &update);
381 LogicalResult processOp(FInstanceLike op);
382 LogicalResult processOp(UnsafeDomainCastOp op);
383 LogicalResult processOp(DomainDefineOp op);
384 LogicalResult processOp(WireOp op);
385 LogicalResult processOp(RWProbeOp op);
386 LogicalResult processOp(Operation *op);
387 LogicalResult processModuleBody(FModuleOp moduleOp);
388 LogicalResult processModule(FModuleOp moduleOp);
390 ExportTable initializeExportTable(FModuleOp moduleOp);
391 void ensureSolved(
Namespace &ns, DomainTypeID typeID,
size_t ip,
392 LocationAttr loc, VariableTerm *var,
393 PendingUpdates &pending);
395 DomainTypeID typeID,
size_t ip, LocationAttr loc,
396 ValueTerm *val, PendingUpdates &pending);
397 void getUpdatesForDomainAssociationOfPort(
Namespace &ns,
398 PendingUpdates &pending,
399 DomainTypeID typeID,
size_t ip,
400 LocationAttr loc, Term *term,
402 void getUpdatesForDomainAssociationOfPort(
Namespace &ns,
404 size_t ip, LocationAttr loc,
406 PendingUpdates &pending);
407 void getUpdatesForModulePorts(FModuleOp moduleOp,
const ExportTable &exports,
409 void getUpdatesForModule(FModuleOp moduleOp,
const ExportTable &exports,
410 PendingUpdates &pending);
411 void applyUpdatesToModule(FModuleOp moduleOp,
ExportTable &exports,
412 const PendingUpdates &pending);
413 SmallVector<Attribute> copyPortDomainAssociations(FModuleOp moduleOp,
414 ArrayAttr moduleDomainInfo,
416 LogicalResult driveModuleOutputDomainPorts(FModuleOp moduleOp);
417 LogicalResult updateModuleDomainInfo(FModuleOp moduleOp,
421 solveVarWithAnonDomain(OpBuilder &builder,
422 DenseMap<DomainValue, DomainValue> &domainsInScope,
423 Operation *user, DomainType type, VariableTerm *var);
425 getDomainInScope(OpBuilder &builder,
426 DenseMap<DomainValue, DomainValue> &domainsInScope,
429 updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
431 LogicalResult updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
433 LogicalResult updateModuleBody(FModuleOp moduleOp);
434 LogicalResult updateModule(FModuleOp moduleOp);
436 LogicalResult checkModulePorts(FModuleLike moduleOp);
437 LogicalResult checkModuleDomainPortDrivers(FModuleOp moduleOp);
438 LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op);
439 LogicalResult checkModuleBody(FModuleOp moduleOp);
441 LogicalResult inferModule(FModuleOp moduleOp);
442 LogicalResult checkModule(FModuleOp moduleOp);
443 LogicalResult checkModule(FExtModuleOp extModuleOp);
444 LogicalResult checkAndInferModule(FModuleOp moduleOp);
447 CircuitState &globals;
448 DenseMap<Value, Term *> termTable;
449 DenseMap<Value, Term *> associationTable;
450 llvm::BumpPtrAllocator allocator;
455void ModuleState::render(Operation *op, T &out) {
456 op->print(out, getAsmState());
460void ModuleState::render(Value value, T &out) {
468 llvm::raw_string_ostream os(name);
469 value.printAsOperand(os, globals.getAsmState());
475void ModuleState::renderLong(Value value, T &out) {
476 if (
auto arg = dyn_cast<BlockArgument>(value)) {
477 if (
auto moduleOp = llvm::dyn_cast_if_present<FModuleLike>(
478 arg.getOwner()->getParentOp())) {
480 moduleOp.getPortDirection(arg.getArgNumber()));
481 out <<
" module port ";
483 }
else if (
auto result = dyn_cast<OpResult>(value)) {
484 auto *op = result.getOwner();
485 if (
auto inst = dyn_cast<FInstanceLike>(op)) {
487 inst.getPortDirection(result.getResultNumber()));
488 out <<
" instance port ";
497void ModuleState::render(Term *term, T &out) {
503 if (
auto *var = dyn_cast<VariableTerm>(term)) {
504 out <<
"?" << globals.getVariableID(var);
507 if (
auto *val = dyn_cast<ValueTerm>(term)) {
508 auto value = val->value;
512 if (
auto *row = dyn_cast<RowTerm>(term)) {
514 llvm::interleaveComma(
515 llvm::seq(
size_t(0), getNumDomains()), out, [&](
auto i) {
516 render(row->elements[i], out);
517 out <<
" : " << getDomain(DomainTypeID{i}).getSymName();
526struct ModuleState::Render {
532ModuleState::Render<T> ModuleState::render(T &&subject) {
533 return Render<T>{
this, std::forward<T>(subject)};
538 ModuleState::Render<T> r) {
539 r.state->render(r.subject, out);
558Term *ModuleState::find(Term *x) {
562 if (
auto *var = dyn_cast<VariableTerm>(x)) {
563 if (var->leader ==
nullptr)
566 auto *leader = find(var->leader);
567 if (leader != var->leader)
568 var->leader = leader;
575LogicalResult ModuleState::unify(VariableTerm *x, Term *y) {
581LogicalResult ModuleState::unify(ValueTerm *xv, Term *y) {
582 if (
auto *yv = dyn_cast<VariableTerm>(y)) {
587 if (
auto *yv = dyn_cast<ValueTerm>(y))
588 return success(xv == yv);
594LogicalResult ModuleState::unify(RowTerm *lhsRow, Term *rhs) {
595 if (
auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
596 rhsVar->leader = lhsRow;
599 if (
auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
600 for (
auto [x, y] :
llvm::zip_equal(lhsRow->elements, rhsRow->elements))
601 if (failed(unify(x, y)))
609LogicalResult ModuleState::unify(Term *lhs, Term *rhs) {
617 LLVM_DEBUG(llvm::dbgs().indent(6)
618 <<
"unify " << render(lhs) <<
" = " << render(rhs) <<
"\n");
620 if (
auto *lhsVar = dyn_cast<VariableTerm>(lhs))
621 return unify(lhsVar, rhs);
622 if (
auto *lhsVal = dyn_cast<ValueTerm>(lhs))
623 return unify(lhsVal, rhs);
624 if (
auto *lhsRow = dyn_cast<RowTerm>(lhs))
625 return unify(lhsRow, rhs);
629void ModuleState::solve(Term *lhs, Term *rhs) {
630 [[maybe_unused]]
auto result = unify(lhs, rhs);
631 assert(result.succeeded());
634RowTerm *ModuleState::allocRow(
size_t size) {
635 SmallVector<Term *> elements;
636 elements.resize(size);
637 return allocRow(elements);
640RowTerm *ModuleState::allocRow(ArrayRef<Term *> elements) {
641 auto ds = allocArray(elements);
642 return alloc<RowTerm>(ds);
645VariableTerm *ModuleState::allocVar() {
return alloc<VariableTerm>(); }
647ValueTerm *ModuleState::allocVal(
DomainValue value) {
648 return alloc<ValueTerm>(value);
651template <
typename T,
typename... Args>
652T *ModuleState::alloc(Args &&...args) {
653 static_assert(std::is_base_of_v<Term, T>,
"T must be a term");
654 return new (allocator) T(std::forward<Args>(args)...);
657ArrayRef<Term *> ModuleState::allocArray(ArrayRef<Term *> elements) {
658 auto size = elements.size();
662 auto *result = allocator.Allocate<Term *>(size);
663 llvm::uninitialized_copy(elements, result);
664 for (
size_t i = 0; i < size; ++i)
666 result[i] = alloc<VariableTerm>();
668 return ArrayRef(result, size);
672 auto *term = getOptTermForDomain(value);
673 if (
auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
678Term *ModuleState::getOptTermForDomain(
DomainValue value) {
679 assert(isa<DomainType>(value.getType()));
680 auto it = termTable.find(value);
681 if (it == termTable.end())
683 return find(it->second);
686Term *ModuleState::getTermForDomain(
DomainValue value) {
687 assert(isa<DomainType>(value.getType()));
688 if (
auto *term = getOptTermForDomain(value))
690 auto *term = allocVar();
691 setTermForDomain(value, term);
695void ModuleState::setTermForDomain(
DomainValue value, Term *term) {
697 assert(!termTable.contains(value));
698 termTable.insert({value, term});
699 LLVM_DEBUG(llvm::dbgs().indent(6)
700 <<
"set " << render(value) <<
" := " << render(term) <<
"\n");
703Term *ModuleState::getOptDomainAssociation(Value value) {
705 auto it = associationTable.find(value);
706 if (it == associationTable.end())
708 return find(it->second);
711Term *ModuleState::getDomainAssociation(Value value) {
712 auto *term = getOptDomainAssociation(value);
717void ModuleState::setDomainAssociation(Value value, Term *term) {
721 associationTable.insert({value, term});
723 llvm::dbgs().indent(6) <<
"set domains(" << render(value)
724 <<
") := " << render(term) <<
"\n";
728void ModuleState::processDomainDefinition(
DomainValue domain) {
729 assert(isa<DomainType>(domain.getType()));
730 auto *newTerm = allocVal(domain);
731 auto *oldTerm = getOptTermForDomain(domain);
733 setTermForDomain(domain, newTerm);
737 [[maybe_unused]]
auto result = unify(oldTerm, newTerm);
738 assert(result.succeeded());
741RowTerm *ModuleState::getDomainAssociationAsRow(Value value) {
743 auto *term = getOptDomainAssociation(value);
747 auto *row = allocRow(getNumDomains());
748 setDomainAssociation(value, row);
753 if (
auto *row = dyn_cast<RowTerm>(term))
757 if (
auto *var = dyn_cast<VariableTerm>(term)) {
758 auto *row = allocRow(getNumDomains());
763 assert(
false &&
"unhandled term type");
767void ModuleState::noteLocation(InFlightDiagnostic &diag, Operation *op) {
768 auto ¬e = diag.attachNote(op->getLoc());
769 if (
auto mod = dyn_cast<FModuleOp>(op)) {
770 note <<
"in module " << mod.getModuleNameAttr();
773 if (
auto mod = dyn_cast<FExtModuleOp>(op)) {
774 note <<
"in extmodule " << mod.getModuleNameAttr();
777 if (
auto inst = dyn_cast<InstanceOp>(op)) {
778 note <<
"in instance " << inst.getInstanceNameAttr();
781 if (
auto inst = dyn_cast<InstanceChoiceOp>(op)) {
782 note <<
"in instance_choice " << inst.getNameAttr();
789void ModuleState::noteDomain(InFlightDiagnostic &diag,
DomainValue domain) {
790 auto ¬e = diag.attachNote(domain.getLoc());
791 note << renderLong(domain);
793 if (globals.inserted.contains(domain)) {
794 note <<
" automatically inserted here";
798 note <<
" declared here";
801void ModuleState::noteDomainSource(InFlightDiagnostic &diag,
803 auto &irns = globals.getInnerRefNamespace();
804 SmallVector<FInstanceLike> stack;
805 llvm::SmallDenseSet<DomainValue> seen;
809 auto chaseConnect = [&]() {
810 for (
auto *user : domain.getUsers()) {
811 if (
auto defineOp = dyn_cast<DomainDefineOp>(user)) {
812 if (defineOp.getDest() != domain)
814 auto src = defineOp.getSrc();
815 diag.attachNote(defineOp.getLoc())
816 << renderLong(domain) <<
" aliases " << renderLong(src);
817 domain = defineOp.getSrc();
824 auto chaseModulePort = [&]() {
825 auto arg = dyn_cast<BlockArgument>(domain);
830 llvm::dyn_cast_if_present<FModuleOp>(arg.getOwner()->getParentOp());
834 auto name =
module.getModuleNameAttr();
835 while (!stack.empty()) {
836 auto instance = stack.back();
838 auto referenced = instance.getReferencedModuleNamesAttr().getValue();
839 if (llvm::is_contained(referenced, name)) {
840 domain = cast<DomainValue>(instance->getResult(arg.getArgNumber()));
847 auto chaseInstancePort = [&]() {
848 auto result = dyn_cast<OpResult>(domain);
852 auto inst = dyn_cast<FInstanceLike>(result.getOwner());
856 auto index = result.getResultNumber();
857 if (inst.getPortDirection(index) == Direction::In)
860 auto names = inst.getReferencedModuleNamesAttr().getAsRange<StringAttr>();
861 for (
auto name : names) {
862 auto moduleLike = cast<FModuleLike>(irns.symTable.lookup(name));
863 if (
auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation())) {
864 stack.push_back(inst);
865 domain = cast<DomainValue>(moduleOp.getArgument(index));
872 auto chaseUnderlying = [&]() {
873 if (
auto *term = getOptTermForDomain(domain)) {
874 if (
auto *val = dyn_cast<ValueTerm>(term)) {
875 if (domain != val->value) {
876 diag.attachNote(domain.getLoc())
877 << renderLong(domain) <<
" aliases " << renderLong(val->value);
887 auto [it, inserted] = seen.insert(domain);
891 noteDomain(diag, domain);
892 chaseConnect() || chaseModulePort() || chaseInstancePort() ||
897void ModuleState::noteDomainSource(InFlightDiagnostic &diag, Term *term) {
898 auto *val = dyn_cast<ValueTerm>(find(term));
902 noteDomainSource(diag, val->value);
905void ModuleState::emitDomainCrossingError(Operation *op, Value lhs,
906 Term *lhsTerm, Value rhs,
908 auto *lhsRow = cast<RowTerm>(lhsTerm);
909 auto *rhsRow = cast<RowTerm>(rhsTerm);
911 op->emitError(
"illegal domain crossing in operation between operands ");
915 auto ¬e1 = diag.attachNote(lhs.getLoc());
917 note1 <<
" has domains ";
918 render(lhsRow, note1);
919 auto ¬e2 = diag.attachNote(rhs.getLoc());
921 note2 <<
" has domains ";
922 render(rhsRow, note2);
924 for (
size_t i = 0, e = getNumDomains(); i < e; ++i) {
925 auto *lhsDomain = find(lhsRow->elements[i]);
926 auto *rhsDomain = find(rhsRow->elements[i]);
927 if (lhsDomain == rhsDomain)
930 noteDomainSource(diag, lhsDomain);
931 noteDomainSource(diag, rhsDomain);
936void ModuleState::emitDuplicatePortDomainError(
937 T op,
size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1,
938 IntegerAttr domainPortIndexAttr2) {
939 auto portName = op.getPortNameAttr(i);
940 auto portLoc = op.getPortLocation(i);
941 auto domainDecl = getDomain(domainTypeID);
942 auto domainName = domainDecl.getNameAttr();
943 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
944 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
945 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
946 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
947 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
948 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
949 auto diag = emitError(portLoc);
950 diag <<
"duplicate " << domainName <<
" association for port " << portName;
951 auto ¬e1 = diag.attachNote(domainPortLoc1);
952 note1 <<
"associated with " << domainName <<
" port " << domainPortName1;
953 auto ¬e2 = diag.attachNote(domainPortLoc2);
954 note2 <<
"associated with " << domainName <<
" port " << domainPortName2;
955 noteLocation(diag, op);
961void ModuleState::emitDomainPortInferenceError(T op,
size_t i) {
962 auto name = op.getPortNameAttr(i);
963 auto diag = emitError(op->getLoc());
964 auto info = op.getDomainInfo();
965 diag <<
"unable to infer value for undriven domain port " << name;
966 for (
size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
967 if (
auto assocs = dyn_cast<ArrayAttr>(info[j])) {
968 for (
auto assoc : assocs) {
969 if (i == cast<IntegerAttr>(assoc).getValue()) {
970 auto name = op.getPortNameAttr(j);
971 auto loc = op.getPortLocation(j);
972 diag.attachNote(loc) <<
"associated with hardware port " << name;
978 noteLocation(diag, op);
982void ModuleState::emitAmbiguousPortDomainAssociation(
983 T op,
const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
985 auto portName = op.getPortNameAttr(i);
986 auto portLoc = op.getPortLocation(i);
987 auto domainDecl = getDomain(typeID);
988 auto domainName = domainDecl.getNameAttr();
989 auto diag = emitError(portLoc) <<
"ambiguous " << domainName
990 <<
" association for port " << portName;
991 for (
auto e : exports) {
992 auto arg = cast<BlockArgument>(e);
993 auto name = op.getPortNameAttr(arg.getArgNumber());
994 auto loc = op.getPortLocation(arg.getArgNumber());
995 diag.attachNote(loc) <<
"candidate association " << name;
997 noteLocation(diag, op);
1000template <
typename T>
1001void ModuleState::emitMissingPortDomainAssociationError(T op,
1002 DomainTypeID typeID,
1004 auto domainName = getDomain(typeID).getNameAttr();
1005 auto portName = op.getPortNameAttr(i);
1006 auto diag = emitError(op.getPortLocation(i))
1007 <<
"missing " << domainName <<
" association for port "
1009 noteLocation(diag, op);
1012LogicalResult ModuleState::unifyAssociations(Operation *op, Value lhs,
1024 llvm::dbgs().indent(6) <<
"unify domains(" << render(lhs) <<
") = domains("
1025 << render(rhs) <<
")\n";
1028 auto *lhsTerm = getOptDomainAssociation(lhs);
1029 auto *rhsTerm = getOptDomainAssociation(rhs);
1033 if (failed(unify(lhsTerm, rhsTerm))) {
1034 emitDomainCrossingError(op, lhs, lhsTerm, rhs, rhsTerm);
1039 setDomainAssociation(rhs, lhsTerm);
1044 setDomainAssociation(lhs, rhsTerm);
1048 auto *var = allocVar();
1049 setDomainAssociation(lhs, var);
1050 setDomainAssociation(rhs, var);
1054template <
typename T>
1055LogicalResult ModuleState::unifyAssociations(Operation *op, T &&range) {
1057 for (
auto rhs : std::forward<T>(range)) {
1060 if (failed(unifyAssociations(op, lhs, rhs)))
1068LogicalResult ModuleState::unifyAssociations(Operation *op) {
1069 return unifyAssociations(
1070 op, llvm::concat<Value>(op->getOperands(), op->getResults()));
1073LogicalResult ModuleState::processModulePorts(FModuleOp moduleOp) {
1074 auto numDomains = getNumDomains();
1075 auto domainInfo = moduleOp.getDomainInfoAttr();
1076 auto numPorts = moduleOp.getNumPorts();
1078 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1079 for (
size_t i = 0; i < numPorts; ++i) {
1080 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1084 LLVM_DEBUG(llvm::dbgs().indent(4)
1085 <<
"process port " << render(port) <<
"\n");
1087 if (moduleOp.getPortDirection(i) == Direction::In)
1088 processDomainDefinition(port);
1090 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1093 for (
size_t i = 0; i < numPorts; ++i) {
1094 BlockArgument port = moduleOp.getArgument(i);
1098 LLVM_DEBUG(llvm::dbgs().indent(4)
1099 <<
"process port " << render(port) <<
"\n");
1101 SmallVector<IntegerAttr> associations(numDomains);
1103 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1104 auto prevDomainPortIndex = associations[domainTypeID.index];
1105 if (prevDomainPortIndex) {
1106 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1107 prevDomainPortIndex, domainPortIndex);
1110 associations[domainTypeID.index] = domainPortIndex;
1113 SmallVector<Term *> elements(numDomains);
1114 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1115 ++domainTypeIndex) {
1116 auto domainPortIndex = associations[domainTypeIndex];
1117 if (!domainPortIndex)
1119 auto domainPortValue =
1120 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
1121 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1124 auto *domainAssociations = allocRow(elements);
1125 setDomainAssociation(port, domainAssociations);
1131template <
typename T>
1132LogicalResult ModuleState::processInstancePorts(T op) {
1133 auto numDomains = getNumDomains();
1134 auto domainInfo = op.getDomainInfoAttr();
1135 auto numPorts = op.getNumPorts();
1137 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1138 for (
size_t i = 0; i < numPorts; ++i) {
1139 auto port = dyn_cast<DomainValue>(op->getResult(i));
1143 if (op.getPortDirection(i) == Direction::Out)
1144 processDomainDefinition(port);
1146 domainTypeIDTable[i] = getDomainTypeID(op, i);
1149 for (
size_t i = 0; i < numPorts; ++i) {
1150 Value port = op->getResult(i);
1154 SmallVector<IntegerAttr> associations(numDomains);
1156 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1157 auto prevDomainPortIndex = associations[domainTypeID.index];
1158 if (prevDomainPortIndex) {
1159 emitDuplicatePortDomainError(op, i, domainTypeID, prevDomainPortIndex,
1163 associations[domainTypeID.index] = domainPortIndex;
1166 SmallVector<Term *> elements(numDomains);
1167 for (
size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1168 ++domainTypeIndex) {
1169 auto domainPortIndex = associations[domainTypeIndex];
1170 if (!domainPortIndex)
1172 auto domainPortValue =
1173 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
1174 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1177 auto *domainAssociations = allocRow(elements);
1178 setDomainAssociation(port, domainAssociations);
1184FInstanceLike ModuleState::fixInstancePorts(FInstanceLike op,
1185 const ModuleUpdateInfo &update) {
1186 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
1187 clone.setDomainInfoAttr(update.portDomainInfo);
1190 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"fixup " << render(clone) <<
"\n");
1194LogicalResult ModuleState::processOp(FInstanceLike op) {
1196 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
1197 auto updateTable = getModuleUpdateTable();
1198 auto lookup = updateTable.find(moduleName);
1199 if (lookup != updateTable.end())
1200 op = fixInstancePorts(op, lookup->second);
1201 return processInstancePorts(op);
1204LogicalResult ModuleState::processOp(UnsafeDomainCastOp op) {
1205 auto domains = op.getDomains();
1206 if (domains.empty())
1207 return unifyAssociations(op, op.getInput(), op.getResult());
1209 auto input = op.getInput();
1211 SmallVector<Term *> elements(getNumDomains());
1213 auto *inputRow = getDomainAssociationAsRow(input);
1214 elements.assign(inputRow->elements);
1217 for (
auto value : op.getDomains()) {
1218 auto domain = cast<DomainValue>(value);
1219 auto typeID = getDomainTypeID(domain);
1220 elements[typeID.index] = getTermForDomain(domain);
1223 auto *row = allocRow(elements);
1224 setDomainAssociation(op.getResult(), row);
1228LogicalResult ModuleState::processOp(DomainDefineOp op) {
1229 auto src = op.getSrc();
1230 auto dst = op.getDest();
1232 auto *srcTerm = getTermForDomain(src);
1233 auto *dstTerm = getTermForDomain(dst);
1234 if (succeeded(unify(dstTerm, srcTerm)))
1239 <<
"defines a domain value that was inferred to be a different domain '";
1240 render(dstTerm, diag);
1246LogicalResult ModuleState::processOp(WireOp op) {
1253 if (op.getDomains().empty())
1254 return unifyAssociations(op, op.getResults());
1258 SmallVector<Term *> elements(getNumDomains());
1259 for (
auto domain : op.getDomains()) {
1260 auto domainValue = cast<DomainValue>(domain);
1261 auto typeID = getDomainTypeID(domainValue);
1262 elements[typeID.index] = getTermForDomain(domainValue);
1265 auto *row = allocRow(elements);
1266 for (
auto result : op.getResults())
1267 setDomainAssociation(result, row);
1272LogicalResult ModuleState::processOp(RWProbeOp op) {
1273 auto target = globals.getInnerRefNamespace().lookup(op.getTarget());
1275 if (target.isPort()) {
1276 auto targetOp = cast<FModuleOp>(target.getOp());
1277 auto targetValue = targetOp.getArgument(target.getPort());
1278 return unifyAssociations(op, targetValue, op.getResult());
1281 auto targetOp = cast<hw::InnerSymbolOpInterface>(target.getOp());
1282 auto targetValue = targetOp.getTargetResult();
1283 return unifyAssociations(op, targetValue, op.getResult());
1286LogicalResult ModuleState::processOp(Operation *op) {
1287 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"process " << render(op) <<
"\n");
1288 if (
auto instance = dyn_cast<FInstanceLike>(op))
1289 return processOp(instance);
1290 if (
auto wireOp = dyn_cast<WireOp>(op))
1291 return processOp(wireOp);
1292 if (
auto cast = dyn_cast<UnsafeDomainCastOp>(op))
1293 return processOp(cast);
1294 if (
auto def = dyn_cast<DomainDefineOp>(op))
1295 return processOp(def);
1296 if (
auto probe = dyn_cast<RWProbeOp>(op))
1297 return processOp(probe);
1298 if (
auto create = dyn_cast<DomainCreateOp>(op)) {
1299 processDomainDefinition(create);
1302 if (
auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
1303 processDomainDefinition(createAnon);
1307 return unifyAssociations(op);
1310LogicalResult ModuleState::processModuleBody(FModuleOp moduleOp) {
1313 .walk([&](Operation *op) -> WalkResult { return processOp(op); })
1317LogicalResult ModuleState::processModule(FModuleOp moduleOp) {
1318 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"processing:\n");
1319 if (failed(processModulePorts(moduleOp)))
1321 if (failed(processModuleBody(moduleOp)))
1326ExportTable ModuleState::initializeExportTable(FModuleOp moduleOp) {
1328 size_t numPorts = moduleOp.getNumPorts();
1329 for (
size_t i = 0; i < numPorts; ++i) {
1330 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1333 auto value = getOptUnderlyingDomain(port);
1335 exports[value].push_back(port);
1339 llvm::dbgs().indent(2) <<
"domain exports:\n";
1340 for (
auto entry : exports) {
1341 llvm::dbgs().indent(4) << render(entry.first) <<
" exported as ";
1342 llvm::interleaveComma(entry.second, llvm::dbgs(),
1343 [&](
auto e) { llvm::dbgs() << render(e); });
1344 llvm::dbgs() <<
"\n";
1351void ModuleState::ensureSolved(
Namespace &ns, DomainTypeID typeID,
size_t ip,
1352 LocationAttr loc, VariableTerm *var,
1353 PendingUpdates &pending) {
1354 if (pending.solutions.contains(var))
1357 auto *
context = loc.getContext();
1358 auto domainDecl = getDomain(typeID);
1359 auto domainName = domainDecl.getNameAttr();
1361 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1362 auto portType = DomainType::getFromDomainOp(domainDecl);
1363 auto portDirection = Direction::In;
1364 auto portSym = StringAttr();
1366 auto portAnnos = std::nullopt;
1368 auto portDomainInfo = ArrayAttr::get(
context, {});
1369 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1370 portAnnos, portDomainInfo);
1372 pending.solutions[var] = pending.insertions.size() + ip;
1373 pending.insertions.push_back({ip, portInfo});
1377 DomainTypeID typeID,
size_t ip,
1378 LocationAttr loc, ValueTerm *val,
1379 PendingUpdates &pending) {
1380 auto value = val->value;
1381 assert(isa<DomainType>(value.getType()));
1382 if (
isPort(value) || exports.contains(value) ||
1383 pending.exports.contains(value))
1386 auto *
context = loc.getContext();
1388 auto domainDecl = getDomain(typeID);
1389 auto domainName = domainDecl.getNameAttr();
1391 auto portName = StringAttr::get(
context, ns.
newName(domainName.getValue()));
1392 auto portType = DomainType::getFromDomainOp(domainDecl);
1393 auto portDirection = Direction::Out;
1394 auto portSym = StringAttr();
1395 auto portAnnos = std::nullopt;
1397 auto portDomainInfo = ArrayAttr::get(
context, {});
1398 PortInfo portInfo(portName, portType, portDirection, portSym, loc, portAnnos,
1400 pending.exports[value] = pending.insertions.size() + ip;
1401 pending.insertions.push_back({ip, portInfo});
/// Queue whatever port insertion is needed so that a port's association with
/// the domain of kind `typeID` can be expressed in the module signature.
///  - An unsolved VariableTerm is passed to ensureSolved, which may queue a
///    new input port.
///  - A concrete ValueTerm is passed to ensureExported, which may queue a new
///    output port.
/// Any other term kind is unreachable.
1404void ModuleState::getUpdatesForDomainAssociationOfPort(
1405 Namespace &ns, PendingUpdates &pending, DomainTypeID typeID,
size_t ip,
1406 LocationAttr loc, Term *term,
const ExportTable &exports) {
1407 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1408 ensureSolved(ns, typeID, ip, loc, var, pending);
1411 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1412 ensureExported(ns, exports, typeID, ip, loc, val, pending);
1415 llvm_unreachable(
"invalid domain association");
/// Row overload: visit each element of a port's association row.
/// The element's position in the row is used as its DomainTypeID.
/// Each element is delegated to the per-term overload using the term returned
/// by `find`, presumably the representative of the element's equivalence
/// class; TODO confirm.
1418void ModuleState::getUpdatesForDomainAssociationOfPort(
1420 RowTerm *row, PendingUpdates &pending) {
1421 for (
auto [index, term] :
llvm::enumerate(row->elements))
1422 getUpdatesForDomainAssociationOfPort(ns, pending, DomainTypeID{index}, ip,
1423 loc, find(term), exports);
/// Collect pending port insertions for the ports of `moduleOp`, driven by
/// each port's domain-association row. The port's own index `i` is used as
/// the insertion point. The lines between fetching the port and the call are
/// not shown in this listing; presumably they skip ports that have no
/// association row (e.g. domain-typed ports). TODO confirm.
1426void ModuleState::getUpdatesForModulePorts(FModuleOp moduleOp,
1429 PendingUpdates &pending) {
1430 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1431 auto port = moduleOp.getArgument(i);
1435 getUpdatesForDomainAssociationOfPort(
1436 ns, exports, i, moduleOp.getPortLocation(i),
1437 getDomainAssociationAsRow(port), pending);
/// Compute all pending port updates for `moduleOp`.
/// The existing port names are iterated first. The loop body is not shown;
/// presumably it seeds the namespace `ns` so that newly created port names do
/// not collide with existing ones. Per-port updates are then collected by
/// getUpdatesForModulePorts.
1441void ModuleState::getUpdatesForModule(FModuleOp moduleOp,
1443 PendingUpdates &pending) {
1445 auto names = moduleOp.getPortNamesAttr();
1446 for (
auto name : names.getAsRange<StringAttr>())
1448 getUpdatesForModulePorts(moduleOp, exports, ns, pending);
/// Apply the queued port insertions to `moduleOp` and wire the new ports up.
///  - Each new input port that solves a variable becomes that variable's
///    solution. It is recorded as exporting itself and is added to
///    `globals.inserted`.
///  - Each new output port is driven by a DomainDefineOp placed right after
///    the definition of the internal domain it exports. It is recorded as an
///    export of that domain, added to `globals.inserted`, and given a value
///    term for that domain.
1451void ModuleState::applyUpdatesToModule(FModuleOp moduleOp,
ExportTable &exports,
1452 const PendingUpdates &pending) {
1453 LLVM_DEBUG(llvm::dbgs().indent(2) <<
"applying updates:\n");
1455 moduleOp.insertPorts(pending.insertions);
// New input ports: bind each previously unsolved variable to its port.
1459 for (
auto [var, portIndex] : pending.solutions) {
1460 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1461 auto *solution = allocVal(portValue);
1462 LLVM_DEBUG(llvm::dbgs().indent(4)
1463 <<
"new-input " << render(portValue) <<
"\n");
1464 solve(var, solution);
1465 exports[portValue].push_back(portValue);
1466 globals.inserted.insert(portValue);
// New output ports: drive each from the internal domain it exports. The
// insertion point is moved to just after that domain's definition.
1470 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1471 for (
auto [domainValue, portIndex] : pending.exports) {
1472 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1473 builder.setInsertionPointAfterValue(domainValue);
1474 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1475 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"new-output " << render(portValue)
1476 <<
" := " << render(domainValue) <<
"\n");
1477 exports[domainValue].push_back(portValue);
1478 globals.inserted.insert(portValue);
1479 setTermForDomain(portValue, allocVal(domainValue));
/// Build a vector with one slot per domain kind, holding the associations
/// already recorded for port `portIndex`.
/// Each existing association (a domain-port index) is stored in the slot for
/// that domain port's DomainTypeID. Slots with no existing association stay
/// null. The declaration of `oldAssociations` and the return are not shown in
/// this listing. Presumably `oldAssociations` is the port's row taken from
/// `moduleDomainInfo` (e.g. via getPortDomainAssociation); TODO confirm.
1483SmallVector<Attribute> ModuleState::copyPortDomainAssociations(
1484 FModuleOp moduleOp, ArrayAttr moduleDomainInfo,
size_t portIndex) {
1485 SmallVector<Attribute> result(getNumDomains());
1487 for (
auto domainPortIndexAttr : oldAssociations) {
1488 auto domainPortIndex = domainPortIndexAttr.getUInt();
1489 auto domainTypeID = getDomainTypeID(moduleOp, domainPortIndex);
1490 result[domainTypeID.index] = domainPortIndexAttr;
/// Drive the output domain ports of `moduleOp` from their inferred domains.
/// Non-domain ports and input ports are skipped; a further skip condition is
/// not shown in this listing. For each remaining port, the port's term is
/// expected to be a concrete ValueTerm. Otherwise an inference error is
/// reported for that port; the exact guard line is not shown. A
/// DomainDefineOp connecting the port to the value is appended at the end of
/// the module body.
1495LogicalResult ModuleState::driveModuleOutputDomainPorts(FModuleOp moduleOp) {
1496 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1497 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1498 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1499 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
// The term may be absent; dyn_cast_if_present yields null in that case.
1503 auto *term = getOptTermForDomain(port);
1504 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1506 emitDomainPortInferenceError(moduleOp, i);
1510 auto loc = port.getLoc();
1511 auto value = val->value;
1512 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"connect " << render(port)
1513 <<
" := " << render(value) <<
"\n");
1514 DomainDefineOp::create(builder, loc, port, value);
/// Recompute the per-port domain-info attribute of `moduleOp` after
/// inference. The new attribute is written to the module and returned
/// through `result`.
///
/// Domain-typed ports get an empty association array. For other ports, the
/// existing associations are kept. Each still-missing domain kind is filled
/// with the index of the single port that exports the inferred domain.
/// Errors are emitted when the domain is private (no exporting port) or
/// ambiguous (more than one exporting port).
1520LogicalResult ModuleState::updateModuleDomainInfo(
1521 FModuleOp moduleOp,
const ExportTable &exportTable, ArrayAttr &result) {
1526 auto *
context = moduleOp.getContext();
1527 auto numDomains = getNumDomains();
1528 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1529 auto numPorts = moduleOp.getNumPorts();
1530 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1532 for (
size_t i = 0; i < numPorts; ++i) {
1533 auto port = moduleOp.getArgument(i);
1534 auto type = port.getType();
// Domain ports themselves carry no associations.
1536 if (isa<DomainType>(type)) {
1538 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
// The condition guarding this assignment is not shown. Presumably
// non-hardware ports also get an empty array; TODO confirm.
1543 newModuleDomainInfo[i] = ArrayAttr::get(
context, {});
// Start from the associations already recorded on the port.
1548 copyPortDomainAssociations(moduleOp, oldModuleDomainInfo, i);
1549 auto *row = cast<RowTerm>(getDomainAssociation(port));
1550 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1551 auto domainTypeID = DomainTypeID{domainIndex};
// Only fill slots that have no existing association.
1552 if (associations[domainIndex])
// The association is expected to have resolved to a concrete domain value
// by now; cast<> asserts otherwise.
1555 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1556 auto &exports = exportTable.at(domain);
// No port aliases this domain, so it cannot be named in the signature.
1557 if (exports.empty()) {
1558 auto portName = moduleOp.getPortNameAttr(i);
1559 auto portLoc = moduleOp.getPortLocation(i);
1560 auto domainDecl = getDomain(domainTypeID);
1561 auto domainName = domainDecl.getNameAttr();
1562 auto diag = emitError(portLoc) <<
"private " << domainName
1563 <<
" association for port " << portName;
1564 diag.attachNote(domain.getLoc()) <<
"associated domain: " << domain;
1565 noteLocation(diag, moduleOp);
// More than one port aliases this domain, so the choice is ambiguous.
1569 if (exports.size() > 1) {
1570 emitAmbiguousPortDomainAssociation(moduleOp, exports, domainTypeID, i);
// Exactly one exporting port: record its index as an unsigned 32-bit
// integer attribute.
1574 auto argument = cast<BlockArgument>(exports[0]);
1575 auto domainPortIndex = argument.getArgNumber();
1576 associations[domainTypeID.index] =
1577 IntegerAttr::get(IntegerType::get(
context, 32, IntegerType::Unsigned),
1581 newModuleDomainInfo[i] = ArrayAttr::get(
context, associations);
1584 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1585 moduleOp.setDomainInfoAttr(result);
// NOTE(review): the first line of this signature is not shown. From the
// call sites, this is presumably ModuleState::solveVarWithAnonDomain.
//
// Solve the unsolved variable `var` with a fresh anonymous domain of `type`.
// A DomainCreateAnonOp named after the domain type is created at the
// builder's current insertion point, using `user`'s location. It is then
// registered as in scope and added to `globals.inserted`. The line binding
// `anon` and the return are not shown.
1590 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1591 Operation *user, DomainType type, VariableTerm *var) {
1592 auto name = type.getName().getAttr();
1594 DomainCreateAnonOp::create(builder, user->getLoc(), type, name);
1596 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"create anon " << render(anon) <<
"\n");
1597 solve(var, allocVal(anon));
1598 domainsInScope[anon] = anon;
1599 globals.inserted.insert(anon);
// NOTE(review): the first line(s) of this signature are not shown. From the
// call sites, this is presumably ModuleState::getDomainInScope.
//
// Return a domain value, usable at the builder's insertion point, that
// stands for `domain`. A cached entry in `domainsInScope` is reused; the
// guard line is not shown. Otherwise a "bounce" WireOp of the same domain
// type is created and cached. It is driven from `domain` immediately after
// `domain`'s definition.
1604 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
// operator[] default-inserts a null entry when `domain` is absent; the
// reference is filled in below in that case.
1606 auto &domainInScope = domainsInScope[domain];
1608 return domainInScope;
1610 domainInScope = cast<DomainValue>(
1611 WireOp::create(builder, domain.getLoc(), domain.getType(),
1612 domain.getType().getName().getAttr())
// The guard restores the caller's insertion point when it goes out of scope.
1615 OpBuilder::InsertionGuard guard(builder);
1616 builder.setInsertionPointAfterValue(domain);
1617 DomainDefineOp::create(builder, domain.getLoc(), domainInScope, domain);
1619 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"bounce wire " << render(domainInScope)
1620 <<
" := " << render(domain) <<
"\n");
1621 return domainInScope;
/// Drive the undriven input domain ports of instance `op` after inference.
/// The instance's output domain ports are first registered as in scope. Each
/// undriven input domain port is then connected, right after the instance,
/// to one of two sources:
///  - a fresh anonymous domain, if its term is an unsolved variable;
///  - an in-scope stand-in for its concrete domain value, otherwise.
/// Any other term kind is unreachable.
1625ModuleState::updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
1627 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(op) <<
"\n");
1628 OpBuilder builder(op.getContext());
1629 builder.setInsertionPointAfter(op);
1630 auto numPorts = op->getNumResults();
// Output domain ports of the instance are usable domains from here on.
1632 for (
size_t i = 0; i < numPorts; ++i)
1633 if (
auto port = dyn_cast<DomainValue>(op->getResult(i)))
1634 if (op.getPortDirection(i) == Direction::Out)
1635 domainsInScope[port] = port;
1637 for (
size_t i = 0; i < numPorts; ++i) {
1638 auto port = dyn_cast<DomainValue>(op->getResult(i));
1639 auto direction = op.getPortDirection(i);
1643 if (port && direction == Direction::In && !
isDriven(port)) {
1644 auto loc = port.getLoc();
1645 auto *term = getTermForDomain(port);
1646 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1647 auto domain = solveVarWithAnonDomain(builder, domainsInScope, op,
1648 port.getType(), var);
1649 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"connect " << render(port)
1650 <<
" := " << render(domain) <<
"\n");
1651 DomainDefineOp::create(builder, loc, port, domain);
1654 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1655 auto domain = getDomainInScope(builder, domainsInScope, val->value);
1656 LLVM_DEBUG(llvm::dbgs().indent(6) <<
"connect " << render(port)
1657 <<
" := " << render(domain) <<
"\n");
1658 DomainDefineOp::create(builder, loc, port, domain);
1661 llvm_unreachable(
"unhandled domain term type");
/// Update a wire after inference.
///  - Domain-typed wire: drive it, right after the wire, from a fresh
///    anonymous domain if its term is an unsolved variable, or from an
///    in-scope stand-in for its concrete domain. Some guard lines before this
///    path are not shown in this listing.
///  - Otherwise: rebuild the wire's domain operands from its association row,
///    one operand per domain kind, resolving each element the same way.
///    Some lines before this path are not shown; presumably they filter out
///    wires that carry no domains. TODO confirm.
1669ModuleState::updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
1671 auto result = wireOp.getResult();
1673 if (
auto tgt = dyn_cast<DomainValue>(result)) {
1677 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(wireOp) <<
"\n");
1678 OpBuilder builder(wireOp);
1679 builder.setInsertionPointAfter(wireOp);
1680 auto *term = getTermForDomain(tgt);
1681 if (
auto *var = dyn_cast<VariableTerm>(term)) {
1682 auto src = solveVarWithAnonDomain(builder, domainsInScope, wireOp,
1683 tgt.getType(), var);
1684 LLVM_DEBUG(llvm::dbgs().indent(6)
1685 <<
"connect " << render(tgt) <<
" := " << render(src) <<
"\n");
1686 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1689 if (
auto *val = dyn_cast<ValueTerm>(term)) {
1690 auto src = getDomainInScope(builder, domainsInScope, val->value);
1691 LLVM_DEBUG(llvm::dbgs().indent(6)
1692 <<
"connect " << render(tgt) <<
" := " << render(src) <<
"\n");
1693 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1696 llvm_unreachable(
"unhandled domain term type");
// Non-domain wire path. Here the builder is positioned before the wire, so
// any anonymous domains or bounce wires created for its operands are
// defined before their use.
1702 LLVM_DEBUG(llvm::dbgs().indent(4) <<
"update " << render(wireOp) <<
"\n");
1703 OpBuilder builder(wireOp);
1704 auto *row = getDomainAssociationAsRow(wireOp.getResult());
1706 SmallVector<Value> domainOperands;
1707 for (
auto [i, element] :
llvm::enumerate(
1708 llvm::map_range(row->elements, [&](auto e) {
return find(e); }))) {
1709 if (
auto *val = dyn_cast<ValueTerm>(element)) {
1710 domainOperands.push_back(
1711 getDomainInScope(builder, domainsInScope, val->value));
// The row position `i` identifies the domain kind for the new anonymous
// domain.
1714 if (
auto *var = dyn_cast<VariableTerm>(element)) {
1715 auto type = DomainType::getFromDomainOp(getDomain(DomainTypeID{i}));
1717 solveVarWithAnonDomain(builder, domainsInScope, wireOp, type, var);
1718 domainOperands.push_back(domain);
1721 assert(0 &&
"unhandled domain type");
1723 wireOp.getDomainsMutable().assign(domainOperands);
/// Rewrite the body of `moduleOp` so that every inferred domain use is backed
/// by a value in scope. The scope is seeded with the module's input domain
/// ports. The walk then updates wires (the `.Case` line for wires is not
/// shown) and instances. Domains created by DomainCreateOp and
/// DomainCreateAnonOp are registered as in scope; other ops are left alone.
/// Returns failure if any update interrupts the walk.
1727LogicalResult ModuleState::updateModuleBody(FModuleOp moduleOp) {
1728 DenseMap<DomainValue, DomainValue> domainsInScope;
1730 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i)
1731 if (
auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i)))
1732 if (moduleOp.getPortDirection(i) == Direction::In)
1733 domainsInScope[port] = port;
1735 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1736 return TypeSwitch<Operation *, WalkResult>(op)
1738 [&](
auto wire) {
return updateWire(domainsInScope, wire); })
1739 .Case<FInstanceLike>([&](
auto instance) {
1740 return updateInstance(domainsInScope, instance);
1742 .Case<DomainCreateOp, DomainCreateAnonOp>([&](
auto domain) {
1743 domainsInScope[domain] = domain;
1746 .Default([&](
auto op) {
return success(); });
1748 return failure(result.wasInterrupted());
/// Apply inference results to module `moduleOp`. The steps run in order:
///  1. queue and apply new domain ports;
///  2. recompute the port domain-info attribute;
///  3. drive output domain ports;
///  4. record the module's new domain info and port insertions in the
///     circuit-wide update table;
///  5. rewrite the module body.
/// In debug builds, a per-port summary of the resulting domains is printed.
1751LogicalResult ModuleState::updateModule(FModuleOp moduleOp) {
1752 auto exports = initializeExportTable(moduleOp);
1753 PendingUpdates pending;
1754 getUpdatesForModule(moduleOp, exports, pending);
1755 applyUpdatesToModule(moduleOp, exports, pending);
1757 ArrayAttr portDomainInfo;
1758 if (failed(updateModuleDomainInfo(moduleOp, exports, portDomainInfo)))
1761 if (failed(driveModuleOutputDomainPorts(moduleOp)))
// Publish this module's signature changes, presumably so instances of it
// can be updated to match; TODO confirm. `pending.insertions` is moved from
// and must not be used afterwards.
1765 auto &entry = getModuleUpdateTable()[moduleOp.getModuleNameAttr()];
1766 entry.portDomainInfo = portDomainInfo;
1767 entry.portInsertions = std::move(pending.insertions);
1769 if (failed(updateModuleBody(moduleOp)))
1773 llvm::dbgs().indent(2) <<
"port summary:\n";
1774 for (
auto port : moduleOp.
getBodyBlock()->getArguments()) {
1775 llvm::dbgs().indent(4) << render(port);
1776 auto info = cast<ArrayAttr>(
1777 moduleOp.getDomainInfoAttrForPort(port.getArgNumber()));
1779 llvm::dbgs() <<
" domains [";
1780 llvm::interleaveComma(
1781 info.getAsRange<IntegerAttr>(), llvm::dbgs(), [&](
auto i) {
1782 llvm::dbgs() << render(moduleOp.getArgument(i.getUInt()));
1784 llvm::dbgs() <<
"]";
1786 llvm::dbgs() <<
"\n";
/// Verify the declared domain associations on the ports of `moduleOp`.
/// A table maps each domain-port index to its DomainTypeID. Then, for each
/// port, it is an error to associate two domain ports of the same kind, or
/// to leave some domain kind unassociated. The lines selecting which ports
/// are checked and iterating their associations are not shown in this
/// listing.
1793LogicalResult ModuleState::checkModulePorts(FModuleLike moduleOp) {
1794 auto numDomains = getNumDomains();
1795 auto domainInfo = moduleOp.getDomainInfoAttr();
1796 auto numPorts = moduleOp.getNumPorts();
// Domain-port index -> kind of domain that port carries.
1798 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1799 for (
size_t i = 0; i < numPorts; ++i) {
1800 if (isa<DomainType>(moduleOp.getPortType(i)))
1801 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1804 for (
size_t i = 0; i < numPorts; ++i) {
// One slot per domain kind; null means "not yet associated".
1809 SmallVector<IntegerAttr> associations(numDomains);
1811 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1812 auto prevDomainPortIndex = associations[domainTypeID.index];
1813 if (prevDomainPortIndex) {
1814 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1815 prevDomainPortIndex, domainPortIndex);
1818 associations[domainTypeID.index] = domainPortIndex;
1822 for (
size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1823 auto typeID = DomainTypeID{domainIndex};
1824 if (!associations[domainIndex]) {
1825 emitMissingPortDomainAssociationError(moduleOp, typeID, i);
/// Report an "undriven domain port" error for output domain ports of
/// `moduleOp`. Non-domain ports and input ports are skipped; the remaining
/// skip condition (presumably "is driven") is not shown in this listing.
1834LogicalResult ModuleState::checkModuleDomainPortDrivers(FModuleOp moduleOp) {
1835 for (
size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1836 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1837 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1841 auto name = moduleOp.getPortNameAttr(i);
1842 auto diag = emitError(moduleOp.getPortLocation(i))
1843 <<
"undriven domain port " << name;
1844 noteLocation(diag, moduleOp);
/// Report an "undriven domain port" error for input domain ports of instance
/// `op`. Non-domain results and output ports are skipped; the remaining skip
/// condition (presumably "is driven") is not shown in this listing.
1851LogicalResult ModuleState::checkInstanceDomainPortDrivers(FInstanceLike op) {
1852 for (
size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1853 auto port = dyn_cast<DomainValue>(op->getResult(i));
// NOTE(review): the line before this is not shown. It must skip results
// where the dyn_cast failed (`port` is null); otherwise port.getType()
// below is called on a null value. Verify that guard exists. Given the
// dyn_cast, the isa<DomainType> test below is then redundant.
1855 auto type = port.getType();
1856 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1860 auto name = op.getPortNameAttr(i);
1861 auto diag = emitError(op.getPortLocation(i))
1862 <<
"undriven domain port " << name;
1863 noteLocation(diag, op);
/// Check every instance-like op in the body of `moduleOp` for undriven input
/// domain ports. Returns failure if any check interrupts the walk.
1870LogicalResult ModuleState::checkModuleBody(FModuleOp moduleOp) {
1871 auto result = moduleOp.getBody().walk([&](FInstanceLike op) -> WalkResult {
1872 return checkInstanceDomainPortDrivers(op);
1874 return failure(result.wasInterrupted());
/// Full inference for `moduleOp`: process the module (processModule), then
/// apply the results, including signature changes (updateModule).
1877LogicalResult ModuleState::inferModule(FModuleOp moduleOp) {
1878 LLVM_DEBUG(llvm::dbgs() <<
"infer: " << moduleOp.getModuleName() <<
"\n");
1879 if (failed(processModule(moduleOp)))
1882 return updateModule(moduleOp);
/// Check-only mode for `moduleOp`. In order, it verifies port associations,
/// output domain port drivers, and instance input drivers in the body, then
/// runs processModule. No IR is updated here.
1885LogicalResult ModuleState::checkModule(FModuleOp moduleOp) {
1886 LLVM_DEBUG(llvm::dbgs() <<
"check: " << moduleOp.getModuleName() <<
"\n");
1887 if (failed(checkModulePorts(moduleOp)))
1890 if (failed(checkModuleDomainPortDrivers(moduleOp)))
1893 if (failed(checkModuleBody(moduleOp)))
1896 return processModule(moduleOp);
/// Check an external module. Only its declared port domain associations can
/// be verified, since it has no body.
1899LogicalResult ModuleState::checkModule(FExtModuleOp extModuleOp) {
1900 LLVM_DEBUG(llvm::dbgs() <<
"check: " << extModuleOp.getModuleName() <<
"\n");
1901 return checkModulePorts(extModuleOp);
/// Check the declared port associations of `moduleOp`, then infer domains
/// inside its body. Unlike updateModule, this path does not call
/// getUpdatesForModule or applyUpdatesToModule, so no ports are added. It
/// only drives output domain ports and rewrites the body.
1904LogicalResult ModuleState::checkAndInferModule(FModuleOp moduleOp) {
1905 LLVM_DEBUG(llvm::dbgs() <<
"check/infer: " << moduleOp.getModuleName()
1908 if (failed(checkModulePorts(moduleOp)))
1911 if (failed(processModule(moduleOp)))
1914 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1917 return updateModuleBody(moduleOp);
// NOTE(review): the signature of this function is not shown. Presumably it
// is `static LogicalResult stripModule(FModuleLike op)`; TODO confirm.
//
// Remove all domain information from `op`. The walk is post-order over
// reversed op lists, presumably so that users are visited (and erased)
// before the values they use.
//  - Modules: erase domain-typed ports.
//  - Domain define/create ops: handled by lines not shown (presumably erased).
//  - DomainSubfieldOp: remaining uses are replaced with an UnknownValueOp.
//  - UnsafeDomainCastOp: uses are forwarded to its input.
//  - Domain-typed wires: handled by lines not shown.
//  - Other wires: their domain operands are dropped.
//  - Instances: cloned without domain ports.
// Any other op that still has a domain-typed operand or result is an error.
1925 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1926 [=](Operation *op) -> WalkResult {
1927 return TypeSwitch<Operation *, WalkResult>(op)
1928 .Case<FModuleLike>([](FModuleLike op) {
1929 auto n = op.getNumPorts();
1930 BitVector erasures(n);
1931 for (
size_t i = 0; i < n; ++i)
1932 if (isa<DomainType>(op.getPortType(i)))
1934 op.erasePorts(erasures);
1935 return WalkResult::advance();
1937 .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
1940 return WalkResult::advance();
1942 .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
1943 if (!op->use_empty()) {
1944 OpBuilder builder(op);
1945 op.replaceAllUsesWith(
1946 UnknownValueOp::create(builder, op.getLoc(), op.getType())
1950 return WalkResult::advance();
1952 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1953 op.replaceAllUsesWith(op.getInput());
1955 return WalkResult::advance();
1957 .Case<WireOp>([](WireOp op) {
1959 if (isa<DomainType>(op.getType(0))) {
1961 return WalkResult::advance();
// NOTE(review): this erases all operands, which assumes a wire's only
// operands are its domain operands. TODO confirm.
1964 if (!op.getDomains().empty()) {
1965 op->eraseOperands(0, op.getNumOperands());
1967 return WalkResult::advance();
1969 .Case<FInstanceLike>([](
auto op) {
1970 auto n = op.getNumPorts();
1971 BitVector erasures(n);
1972 for (
size_t i = 0; i < n; ++i)
1973 if (isa<DomainType>(op->getResult(i).getType()))
1975 op.cloneWithErasedPortsAndReplaceUses(erasures);
1977 return WalkResult::advance();
1979 .Default([](Operation *op) {
1981 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1982 if (isa<DomainType>(type)) {
1983 op->emitOpError(
"cannot be stripped");
1984 return WalkResult::interrupt();
1987 return WalkResult::advance();
1990 return failure(result.wasInterrupted());
// NOTE(review): the signature and remainder of this function are not shown.
// Presumably this is `stripCircuit(MLIRContext *, CircuitOp circuit)`.
//
// Collect every module-like op in the circuit body and erase the domain
// declarations. make_early_inc_range allows the DomainOp being visited to be
// erased without invalidating the loop.
1994 llvm::SmallVector<FModuleLike> modules;
1995 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1996 TypeSwitch<Operation *, void>(&op)
1997 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1998 .Case<DomainOp>([](DomainOp op) { op.erase(); });
/// Run the configured mode on a single module. Strip mode is not handled
/// here (asserted).
///  - FModuleOp in Check mode: check only.
///  - FModuleOp in InferAll mode, or any private module: full inference,
///    which may change the module signature.
///  - Otherwise (a non-private module outside InferAll): check the declared
///    ports, then infer inside the body.
///  - FExtModuleOp: check ports only.
/// The handling of other ops is not shown in this listing.
2007LogicalResult CircuitState::runOnModule(Operation *op) {
2008 assert(mode != InferDomainsMode::Strip);
2009 ModuleState state(*
this);
2010 if (
auto moduleOp = dyn_cast<FModuleOp>(op)) {
2011 if (mode == InferDomainsMode::Check)
2012 return state.checkModule(moduleOp);
2014 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
2015 return state.inferModule(moduleOp);
2017 return state.checkAndInferModule(moduleOp);
2020 if (
auto extModuleOp = dyn_cast<FExtModuleOp>(op))
2021 return state.checkModule(extModuleOp);
/// Run the pass over every module in instance-graph post-order, presumably
/// so instantiated modules are processed before their instantiators. A
/// module that instantiates a module that already failed is marked as failed
/// too. The lines after that mark are not shown; presumably the module is
/// then skipped. Returns success only if no module failed.
2026LogicalResult CircuitState::run() {
2027 DenseSet<Operation *> errored;
2028 instanceGraph.walkPostOrder([&](
auto &node) {
2029 auto moduleOp = node.getModule();
2030 for (
auto *inst : node) {
2031 if (errored.contains(inst->getTarget()->getModule())) {
2032 errored.insert(moduleOp);
2036 if (failed(runOnModule(node.getModule())))
2037 errored.insert(moduleOp);
2039 return success(errored.empty());
/// Pass driver for domain inference on a FIRRTL circuit.
/// In Strip mode, a separate path runs; the call on the missing line is
/// presumably stripCircuit, and the pass fails if that reports failure.
/// Otherwise the instance graph, symbol table, and inner symbol tables are
/// obtained as analyses. A CircuitState is then built for the selected mode
/// and run, and the pass fails if any module fails.
2043struct InferDomainsPass
2044 :
public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
2046 void runOnOperation()
override {
2048 auto circuit = getOperation();
2050 if (mode == InferDomainsMode::Strip) {
2052 signalPassFailure();
2056 auto &instanceGraph = getAnalysis<InstanceGraph>();
2057 auto &symbolTable = getAnalysis<SymbolTable>();
2058 auto &innerSymbolTableCollection =
2059 getAnalysis<InnerSymbolTableCollection>();
2061 innerSymbolTableCollection};
2062 CircuitState state(circuit, instanceGraph, innerRefNamespace, mode);
2063 if (failed(state.run()))
2064 signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static bool isHardware(Type type)
True if a value of the given type could be associated with a domain.
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static LogicalResult stripModule(FModuleLike op)
static Block * getBodyBlock(FModuleLike mod)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
This graph tracks modules and where they are instantiated.
This class represents a collection of InnerSymbolTable's.
static StringRef toLongString(Direction direction)
InferDomainsMode
The mode for the InferDomains pass.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
This holds the name and type that describes the module's ports.
This class represents the namespace in which InnerRef's can be resolved.