CIRCT 23.0.0git
Loading...
Searching...
No Matches
InferDomains.cpp
Go to the documentation of this file.
1//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// InferDomains implements FIRRTL domain inference and checking. This pass is
10// a bottom-up transform acting on modules. For each moduleOp, we ensure there
11// are no domain crossings, and we make explicit the domain associations of
12// ports.
13//
// This pass does not strictly require ExpandWhens to have run first, but it
// should be scheduled after it. If ExpandWhens has not run, then duplicate
// connections will influence domain inference, and this can result in errors.
17//
18//===----------------------------------------------------------------------===//
19
24#include "circt/Support/Debug.h"
26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Iterators.h"
28#include "mlir/IR/Threading.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TinyPtrVector.h"
33
// Tag used by LLVM_DEBUG output in this file.
#define DEBUG_TYPE "firrtl-infer-domains"

// Instantiate the tablegen-generated definitions for the InferDomains pass,
// selected by GEN_PASS_DEF_INFERDOMAINS.
namespace circt {
namespace firrtl {
#define GEN_PASS_DEF_INFERDOMAINS
#include "circt/Dialect/FIRRTL/Passes.h.inc"
} // namespace firrtl
} // namespace circt

using namespace circt;
using namespace firrtl;

using llvm::concat;
using mlir::AsmState;
using mlir::InFlightDiagnostic;
using mlir::ReverseIterator;

// Forward declaration: CircuitState (below) refers to VariableTerm before the
// term hierarchy is defined.
namespace {
struct VariableTerm;
} // namespace
56
57//====--------------------------------------------------------------------------
58// Helpers.
59//====--------------------------------------------------------------------------
60
/// An SSA value statically known to have FIRRTL domain type.
using DomainValue = mlir::TypedValue<DomainType>;

/// A list of ports to insert, each paired with an unsigned index (presumably
/// the insertion position; TODO confirm against the consumer of these lists).
using PortInsertions = SmallVector<std::pair<unsigned, PortInfo>>;
64
65/// From a domain info attribute, get the row of associated domains for a
66/// hardware value at index i.
67static auto getPortDomainAssociation(ArrayAttr info, size_t i) {
68 if (info.empty())
69 return info.getAsRange<IntegerAttr>();
70 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
71}
72
73/// Return true if the value is a port on the module.
74static bool isPort(BlockArgument arg) {
75 return isa<FModuleOp>(arg.getOwner()->getParentOp());
76}
77
78/// Return true if the value is a port on the module.
79static bool isPort(Value value) {
80 auto arg = dyn_cast<BlockArgument>(value);
81 if (!arg)
82 return false;
83 return isPort(arg);
84}
85
86/// Returns true if the value is driven by a connect op.
87static bool isDriven(DomainValue port) {
88 for (auto *user : port.getUsers())
89 if (auto connect = dyn_cast<FConnectLike>(user))
90 if (connect.getDest() == port)
91 return true;
92 return false;
93}
94
95/// True if a value of the given type could be associated with a domain.
96static bool isHardware(Type type) {
97 return type_isa<FIRRTLBaseType, RefType>(type);
98}
99
100/// True if the given value could be association with a domain.
101static bool isHardware(Value value) {
102 if (!isHardware(value.getType()))
103 return false;
104
105 if (auto *op = value.getDefiningOp())
106 if (op->hasTrait<OpTrait::ConstantLike>())
107 return false;
108
109 return true;
110}
111
112//====--------------------------------------------------------------------------
113// Global State.
114//====--------------------------------------------------------------------------
115
namespace {
/// Dense identifier for a domain type. Domain types are numbered in the order
/// their declarations appear in the circuit. A hardware value's domain
/// associations form a row of domains ordered by these ids.
struct DomainTypeID {
  /// Position of the domain declaration in the circuit's domain table.
  size_t index;
};
} // namespace
125
126/// Information about the changes made to the interface of a moduleOp, which can
127/// be replayed onto an instance.
128namespace {
129struct ModuleUpdateInfo {
130 /// The updated domain information for a moduleOp.
131 ArrayAttr portDomainInfo;
132 /// The domain ports which have been inserted into a moduleOp.
133 PortInsertions portInsertions;
134};
135} // namespace
136
137namespace {
138struct CircuitState {
139 CircuitState(CircuitOp circuit, InstanceGraph &instanceGraph,
140 InnerRefNamespace &innerRefNamespace, InferDomainsMode mode)
141 : circuit(circuit), instanceGraph(instanceGraph),
142 innerRefNamespace(innerRefNamespace), mode(mode) {
143 processCircuit(circuit);
144 }
145
146 LogicalResult run();
147
148 ArrayRef<DomainOp> getDomains() const { return domainTable; }
149 size_t getNumDomains() const { return domainTable.size(); }
150 DomainOp getDomain(DomainTypeID id) const { return domainTable[id.index]; }
151 DomainTypeID getDomainTypeID(Type type) { return typeIDTable[type]; }
152
153 void dirty() { asmState = nullptr; }
154 AsmState &getAsmState() {
155 if (!asmState) {
156 asmState = std::make_unique<AsmState>(
157 circuit, mlir::OpPrintingFlags().assumeVerified());
158 }
159 return *asmState;
160 }
161
162 size_t getVariableID(VariableTerm *term) {
163 return variableIDTable.insert({term, variableIDTable.size() + 1})
164 .first->second;
165 }
166
167 DenseMap<StringAttr, ModuleUpdateInfo> &getModuleUpdateTable() {
168 return moduleUpdateTable;
169 }
170
171 InnerRefNamespace &getInnerRefNamespace() { return innerRefNamespace; }
172
173 DenseSet<Value> inserted;
174
175private:
176 LogicalResult runOnModule(Operation *moduleOp);
177
178 void processDomain(DomainOp op) {
179 auto index = domainTable.size();
180 auto domainType = DomainType::getFromDomainOp(op);
181 domainTable.push_back(op);
182 typeIDTable.insert({domainType, {index}});
183 }
184
185 void processCircuit(CircuitOp circuit) {
186 for (auto decl : circuit.getOps<DomainOp>())
187 processDomain(decl);
188 }
189
190 CircuitOp circuit;
191 InstanceGraph &instanceGraph;
192 InnerRefNamespace &innerRefNamespace;
193 InferDomainsMode mode;
194 SmallVector<DomainOp> domainTable;
195 DenseMap<Type, DomainTypeID> typeIDTable;
196 DenseMap<VariableTerm *, size_t> variableIDTable;
197 std::unique_ptr<AsmState> asmState;
198 DenseMap<StringAttr, ModuleUpdateInfo> moduleUpdateTable;
199};
200} // namespace
201
202//====--------------------------------------------------------------------------
203// Terms: Syntax for unifying domain and domain-rows.
204//====--------------------------------------------------------------------------
205
/// The different sorts of terms in the unification engine.
namespace {
enum class TermKind {
  Variable,
  Value,
  Row,
};
} // namespace

/// A term in the unification engine. The kind tag is what the classof hooks
/// of the concrete term types inspect (LLVM-style isa/cast/dyn_cast).
namespace {
struct Term {
  explicit constexpr Term(TermKind kind) : kind(kind) {}
  TermKind kind;
};
} // namespace

/// Helper to define a term kind: fixes the kind tag of the derived term and
/// provides its classof hook.
namespace {
template <TermKind K>
struct TermBase : Term {
  static bool classof(const Term *term) { return term->kind == K; }
  TermBase() : Term(K) {}
};
} // namespace

/// An unknown value. An unbound variable has a null leader. Once unified, the
/// leader points at the term the variable was bound to.
namespace {
struct VariableTerm : public TermBase<TermKind::Variable> {
  VariableTerm() = default;
  explicit VariableTerm(Term *leader) : leader(leader) {}
  Term *leader = nullptr;
};
} // namespace
240
241/// A concrete value defined in the IR.
242namespace {
243struct ValueTerm : public TermBase<TermKind::Value> {
244 ValueTerm(DomainValue value) : value(value) {}
245 DomainValue value;
246};
247} // namespace
248
249/// A row of domains.
250namespace {
251struct RowTerm : public TermBase<TermKind::Row> {
252 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
253 ArrayRef<Term *> elements;
254};
255} // namespace
256
257//====--------------------------------------------------------------------------
258// Module processing: solve for the domain associations of hardware.
259//====--------------------------------------------------------------------------
260
/// A map from unsolved variables to a port index, where that port has not yet
/// been created. Eventually an input domain port will exist at that index, and
/// that port will be the solution to the recorded variable.
using PendingSolutions = DenseMap<VariableTerm *, unsigned>;
265
266/// A map from local domains to an aliasing port index, where that port has not
267/// yet been created. Eventually we will be exporting the domain value at the
268/// port index.
270
271namespace {
272struct PendingUpdates {
273 PortInsertions insertions;
274 PendingSolutions solutions;
275 PendingExports exports;
276};
277} // namespace
278
279/// A map from domain IR values defined internal to the moduleOp, to ports that
280/// alias that domain. These ports make the domain useable as associations of
281/// ports, and we say these are exporting ports.
282using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
283
284namespace {
285class ModuleState {
286public:
287 explicit ModuleState(CircuitState &globals) : globals(globals) {}
288
289 ArrayRef<DomainOp> getDomains() { return globals.getDomains(); }
290 size_t getNumDomains() { return globals.getNumDomains(); }
291 DomainOp getDomain(DomainTypeID id) { return globals.getDomain(id); }
292 DomainTypeID getDomainTypeID(Type type) {
293 return globals.getDomainTypeID(type);
294 }
295 DomainTypeID getDomainTypeID(FModuleLike module, size_t i) {
296 return globals.getDomainTypeID(module.getPortType(i));
297 }
298 DomainTypeID getDomainTypeID(FInstanceLike op, size_t i) const {
299 return globals.getDomainTypeID(op->getResult(i).getType());
300 }
301 DomainTypeID getDomainTypeID(DomainValue value) const {
302 return globals.getDomainTypeID(value.getType());
303 }
304 auto &getModuleUpdateTable() { return globals.getModuleUpdateTable(); }
305
306 mlir::AsmState &getAsmState() { return globals.getAsmState(); }
307 void dirty() { globals.dirty(); }
308
309 template <typename T>
310 void render(Operation *op, T &out);
311 template <typename T>
312 void render(Value value, T &out);
313 template <typename T>
314 void renderLong(Value value, T &out);
315 template <typename T>
316 void render(Term *term, T &out);
317 template <typename T>
318 struct Render;
319 template <typename T>
320 Render<T> render(T &&subject);
321 struct RenderLong;
322 RenderLong renderLong(Value value);
323
324 Term *find(Term *x);
325 LogicalResult unify(Term *lhs, Term *rhs);
326 LogicalResult unify(VariableTerm *x, Term *y);
327 LogicalResult unify(ValueTerm *xv, Term *y);
328 LogicalResult unify(RowTerm *lhsRow, Term *rhs);
329 void solve(Term *lhs, Term *rhs);
330
331 [[nodiscard]] RowTerm *allocRow(size_t size);
332 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements);
333 [[nodiscard]] VariableTerm *allocVar();
334 [[nodiscard]] ValueTerm *allocVal(DomainValue value);
335 template <typename T, typename... Args>
336 T *alloc(Args &&...args);
337 ArrayRef<Term *> allocArray(ArrayRef<Term *> elements);
338
339 DomainValue getOptUnderlyingDomain(DomainValue value);
340 Term *getOptTermForDomain(DomainValue value);
341 Term *getTermForDomain(DomainValue value);
342 void setTermForDomain(DomainValue value, Term *term);
343
344 Term *getOptDomainAssociation(Value value);
345 Term *getDomainAssociation(Value value);
346 void setDomainAssociation(Value value, Term *term);
347
348 void processDomainDefinition(DomainValue domain);
349 RowTerm *getDomainAssociationAsRow(Value value);
350
351 void noteLocation(InFlightDiagnostic &diag, Operation *op);
352 void noteDomain(InFlightDiagnostic &diag, DomainValue domain);
353 void noteDomainSource(InFlightDiagnostic &diag, DomainValue domain);
354 void noteDomainSource(InFlightDiagnostic &diag, Term *term);
355 void emitDomainCrossingError(Operation *op, Value lhs, Term *lhsTerm,
356 Value rhs, Term *rhsTerm);
357 template <typename T>
358 void emitDuplicatePortDomainError(T op, size_t i, DomainTypeID domainTypeID,
359 IntegerAttr domainPortIndexAttr1,
360 IntegerAttr domainPortIndexAttr2);
361 template <typename T>
362 void emitDomainPortInferenceError(T op, size_t i);
363 template <typename T>
364 void emitAmbiguousPortDomainAssociation(
365 T op, const llvm::TinyPtrVector<DomainValue> &exports,
366 DomainTypeID typeID, size_t i);
367 template <typename T>
368 void emitMissingPortDomainAssociationError(T op, DomainTypeID typeID,
369 size_t i);
370
371 LogicalResult unifyAssociations(Operation *op, Value lhs, Value rhs);
372 template <typename T>
373 LogicalResult unifyAssociations(Operation *op, T &&range);
374 LogicalResult unifyAssociations(Operation *op);
375
376 LogicalResult processModulePorts(FModuleOp moduleOp);
377 template <typename T>
378 LogicalResult processInstancePorts(T op);
379 FInstanceLike fixInstancePorts(FInstanceLike op,
380 const ModuleUpdateInfo &update);
381 LogicalResult processOp(FInstanceLike op);
382 LogicalResult processOp(UnsafeDomainCastOp op);
383 LogicalResult processOp(DomainDefineOp op);
384 LogicalResult processOp(WireOp op);
385 LogicalResult processOp(RWProbeOp op);
386 LogicalResult processOp(Operation *op);
387 LogicalResult processModuleBody(FModuleOp moduleOp);
388 LogicalResult processModule(FModuleOp moduleOp);
389
390 ExportTable initializeExportTable(FModuleOp moduleOp);
391 void ensureSolved(Namespace &ns, DomainTypeID typeID, size_t ip,
392 LocationAttr loc, VariableTerm *var,
393 PendingUpdates &pending);
394 void ensureExported(Namespace &ns, const ExportTable &exports,
395 DomainTypeID typeID, size_t ip, LocationAttr loc,
396 ValueTerm *val, PendingUpdates &pending);
397 void getUpdatesForDomainAssociationOfPort(Namespace &ns,
398 PendingUpdates &pending,
399 DomainTypeID typeID, size_t ip,
400 LocationAttr loc, Term *term,
401 const ExportTable &exports);
402 void getUpdatesForDomainAssociationOfPort(Namespace &ns,
403 const ExportTable &exports,
404 size_t ip, LocationAttr loc,
405 RowTerm *row,
406 PendingUpdates &pending);
407 void getUpdatesForModulePorts(FModuleOp moduleOp, const ExportTable &exports,
408 Namespace &ns, PendingUpdates &pending);
409 void getUpdatesForModule(FModuleOp moduleOp, const ExportTable &exports,
410 PendingUpdates &pending);
411 void applyUpdatesToModule(FModuleOp moduleOp, ExportTable &exports,
412 const PendingUpdates &pending);
413 SmallVector<Attribute> copyPortDomainAssociations(FModuleOp moduleOp,
414 ArrayAttr moduleDomainInfo,
415 size_t portIndex);
416 LogicalResult driveModuleOutputDomainPorts(FModuleOp moduleOp);
417 LogicalResult updateModuleDomainInfo(FModuleOp moduleOp,
418 const ExportTable &exportTable,
419 ArrayAttr &result);
421 solveVarWithAnonDomain(OpBuilder &builder,
422 DenseMap<DomainValue, DomainValue> &domainsInScope,
423 Operation *user, DomainType type, VariableTerm *var);
425 getDomainInScope(OpBuilder &builder,
426 DenseMap<DomainValue, DomainValue> &domainsInScope,
427 DomainValue domain);
428 LogicalResult
429 updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
430 FInstanceLike op);
431 LogicalResult updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
432 WireOp wireOp);
433 LogicalResult updateModuleBody(FModuleOp moduleOp);
434 LogicalResult updateModule(FModuleOp moduleOp);
435
436 LogicalResult checkModulePorts(FModuleLike moduleOp);
437 LogicalResult checkModuleDomainPortDrivers(FModuleOp moduleOp);
438 LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op);
439 LogicalResult checkModuleBody(FModuleOp moduleOp);
440
441 LogicalResult inferModule(FModuleOp moduleOp);
442 LogicalResult checkModule(FModuleOp moduleOp);
443 LogicalResult checkModule(FExtModuleOp extModuleOp);
444 LogicalResult checkAndInferModule(FModuleOp moduleOp);
445
446private:
447 CircuitState &globals;
448 DenseMap<Value, Term *> termTable;
449 DenseMap<Value, Term *> associationTable;
450 llvm::BumpPtrAllocator allocator;
451};
452} // namespace
453
454template <typename T>
455void ModuleState::render(Operation *op, T &out) {
456 op->print(out, getAsmState());
457}
458
459template <typename T>
460void ModuleState::render(Value value, T &out) {
461 if (!value) {
462 out << "null";
463 return;
464 }
465
466 auto [name, _] = getFieldName(value);
467 if (name.empty()) {
468 llvm::raw_string_ostream os(name);
469 value.printAsOperand(os, globals.getAsmState());
470 }
471 out << name;
472}
473
474template <typename T>
475void ModuleState::renderLong(Value value, T &out) {
476 if (auto arg = dyn_cast<BlockArgument>(value)) {
477 if (auto moduleOp = llvm::dyn_cast_if_present<FModuleLike>(
478 arg.getOwner()->getParentOp())) {
480 moduleOp.getPortDirection(arg.getArgNumber()));
481 out << " module port ";
482 }
483 } else if (auto result = dyn_cast<OpResult>(value)) {
484 auto *op = result.getOwner();
485 if (auto inst = dyn_cast<FInstanceLike>(op)) {
487 inst.getPortDirection(result.getResultNumber()));
488 out << " instance port ";
489 }
490 }
491
492 render(value, out);
493}
494
495template <typename T>
496// NOLINTNEXTLINE(misc-no-recursion)
497void ModuleState::render(Term *term, T &out) {
498 if (!term) {
499 out << "null";
500 return;
501 }
502 term = find(term);
503 if (auto *var = dyn_cast<VariableTerm>(term)) {
504 out << "?" << globals.getVariableID(var);
505 return;
506 }
507 if (auto *val = dyn_cast<ValueTerm>(term)) {
508 auto value = val->value;
509 render(value, out);
510 return;
511 }
512 if (auto *row = dyn_cast<RowTerm>(term)) {
513 out << "[";
514 llvm::interleaveComma(
515 llvm::seq(size_t(0), getNumDomains()), out, [&](auto i) {
516 render(row->elements[i], out);
517 out << " : " << getDomain(DomainTypeID{i}).getSymName();
518 });
519 out << "]";
520 return;
521 }
522 out << "unknown";
523}
524
525template <typename T>
526struct ModuleState::Render {
527 ModuleState *state;
529};
530
/// Wrap \p subject in a Render object that calls the matching render overload
/// when streamed. This allows inline use in debug output, e.g.
/// `llvm::dbgs() << render(term)`.
template <typename T>
ModuleState::Render<T> ModuleState::render(T &&subject) {
  return Render<T>{this, std::forward<T>(subject)};
}

/// Stream a Render wrapper by dispatching to ModuleState::render.
template <typename T>
static llvm::raw_ostream &operator<<(llvm::raw_ostream &out,
                                     ModuleState::Render<T> r) {
  r.state->render(r.subject, out);
  return out;
}
542
544 ModuleState *state;
545 Value value;
546};
547
/// Wrap \p value so that streaming it into a diagnostic prints its long form
/// via renderLong(Value, T &).
ModuleState::RenderLong ModuleState::renderLong(Value value) {
  return RenderLong{this, value};
}

/// Stream a RenderLong wrapper into a diagnostic.
static Diagnostic &operator<<(Diagnostic &diag, ModuleState::RenderLong r) {
  r.state->renderLong(r.value, diag);
  return diag;
}
556
557// NOLINTNEXTLINE(misc-no-recursion)
558Term *ModuleState::find(Term *x) {
559 if (!x)
560 return nullptr;
561
562 if (auto *var = dyn_cast<VariableTerm>(x)) {
563 if (var->leader == nullptr)
564 return var;
565
566 auto *leader = find(var->leader);
567 if (leader != var->leader)
568 var->leader = leader;
569 return leader;
570 }
571
572 return x;
573}
574
575LogicalResult ModuleState::unify(VariableTerm *x, Term *y) {
576 assert(!x->leader);
577 x->leader = y;
578 return success();
579}
580
581LogicalResult ModuleState::unify(ValueTerm *xv, Term *y) {
582 if (auto *yv = dyn_cast<VariableTerm>(y)) {
583 yv->leader = xv;
584 return success();
585 }
586
587 if (auto *yv = dyn_cast<ValueTerm>(y))
588 return success(xv == yv);
589
590 return failure();
591}
592
593// NOLINTNEXTLINE(misc-no-recursion)
594LogicalResult ModuleState::unify(RowTerm *lhsRow, Term *rhs) {
595 if (auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
596 rhsVar->leader = lhsRow;
597 return success();
598 }
599 if (auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
600 for (auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
601 if (failed(unify(x, y)))
602 return failure();
603 return success();
604 }
605 return failure();
606}
607
608// NOLINTNEXTLINE(misc-no-recursion)
609LogicalResult ModuleState::unify(Term *lhs, Term *rhs) {
610 if (!lhs || !rhs)
611 return success();
612 lhs = find(lhs);
613 rhs = find(rhs);
614 if (lhs == rhs)
615 return success();
616
617 LLVM_DEBUG(llvm::dbgs().indent(6)
618 << "unify " << render(lhs) << " = " << render(rhs) << "\n");
619
620 if (auto *lhsVar = dyn_cast<VariableTerm>(lhs))
621 return unify(lhsVar, rhs);
622 if (auto *lhsVal = dyn_cast<ValueTerm>(lhs))
623 return unify(lhsVal, rhs);
624 if (auto *lhsRow = dyn_cast<RowTerm>(lhs))
625 return unify(lhsRow, rhs);
626 return failure();
627}
628
629void ModuleState::solve(Term *lhs, Term *rhs) {
630 [[maybe_unused]] auto result = unify(lhs, rhs);
631 assert(result.succeeded());
632}
633
634RowTerm *ModuleState::allocRow(size_t size) {
635 SmallVector<Term *> elements;
636 elements.resize(size);
637 return allocRow(elements);
638}
639
640RowTerm *ModuleState::allocRow(ArrayRef<Term *> elements) {
641 auto ds = allocArray(elements);
642 return alloc<RowTerm>(ds);
643}
644
645VariableTerm *ModuleState::allocVar() { return alloc<VariableTerm>(); }
646
647ValueTerm *ModuleState::allocVal(DomainValue value) {
648 return alloc<ValueTerm>(value);
649}
650
651template <typename T, typename... Args>
652T *ModuleState::alloc(Args &&...args) {
653 static_assert(std::is_base_of_v<Term, T>, "T must be a term");
654 return new (allocator) T(std::forward<Args>(args)...);
655}
656
657ArrayRef<Term *> ModuleState::allocArray(ArrayRef<Term *> elements) {
658 auto size = elements.size();
659 if (size == 0)
660 return {};
661
662 auto *result = allocator.Allocate<Term *>(size);
663 llvm::uninitialized_copy(elements, result);
664 for (size_t i = 0; i < size; ++i)
665 if (!result[i])
666 result[i] = alloc<VariableTerm>();
667
668 return ArrayRef(result, size);
669}
670
671DomainValue ModuleState::getOptUnderlyingDomain(DomainValue value) {
672 auto *term = getOptTermForDomain(value);
673 if (auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
674 return val->value;
675 return nullptr;
676}
677
678Term *ModuleState::getOptTermForDomain(DomainValue value) {
679 assert(isa<DomainType>(value.getType()));
680 auto it = termTable.find(value);
681 if (it == termTable.end())
682 return nullptr;
683 return find(it->second);
684}
685
686Term *ModuleState::getTermForDomain(DomainValue value) {
687 assert(isa<DomainType>(value.getType()));
688 if (auto *term = getOptTermForDomain(value))
689 return term;
690 auto *term = allocVar();
691 setTermForDomain(value, term);
692 return term;
693}
694
695void ModuleState::setTermForDomain(DomainValue value, Term *term) {
696 assert(term);
697 assert(!termTable.contains(value));
698 termTable.insert({value, term});
699 LLVM_DEBUG(llvm::dbgs().indent(6)
700 << "set " << render(value) << " := " << render(term) << "\n");
701}
702
703Term *ModuleState::getOptDomainAssociation(Value value) {
704 assert(isHardware(value));
705 auto it = associationTable.find(value);
706 if (it == associationTable.end())
707 return nullptr;
708 return find(it->second);
709}
710
711Term *ModuleState::getDomainAssociation(Value value) {
712 auto *term = getOptDomainAssociation(value);
713 assert(term);
714 return term;
715}
716
717void ModuleState::setDomainAssociation(Value value, Term *term) {
718 assert(isHardware(value));
719 assert(term);
720 term = find(term);
721 associationTable.insert({value, term});
722 LLVM_DEBUG({
723 llvm::dbgs().indent(6) << "set domains(" << render(value)
724 << ") := " << render(term) << "\n";
725 });
726}
727
728void ModuleState::processDomainDefinition(DomainValue domain) {
729 assert(isa<DomainType>(domain.getType()));
730 auto *newTerm = allocVal(domain);
731 auto *oldTerm = getOptTermForDomain(domain);
732 if (!oldTerm) {
733 setTermForDomain(domain, newTerm);
734 return;
735 }
736
737 [[maybe_unused]] auto result = unify(oldTerm, newTerm);
738 assert(result.succeeded());
739}
740
741RowTerm *ModuleState::getDomainAssociationAsRow(Value value) {
742 assert(isHardware(value));
743 auto *term = getOptDomainAssociation(value);
744
745 // If the term is unknown, allocate a fresh row and set the association.
746 if (!term) {
747 auto *row = allocRow(getNumDomains());
748 setDomainAssociation(value, row);
749 return row;
750 }
751
752 // If the term is already a row, return it.
753 if (auto *row = dyn_cast<RowTerm>(term))
754 return row;
755
756 // Otherwise, unify the term with a fresh row of domains.
757 if (auto *var = dyn_cast<VariableTerm>(term)) {
758 auto *row = allocRow(getNumDomains());
759 solve(var, row);
760 return row;
761 }
762
763 assert(false && "unhandled term type");
764 return nullptr;
765}
766
767void ModuleState::noteLocation(InFlightDiagnostic &diag, Operation *op) {
768 auto &note = diag.attachNote(op->getLoc());
769 if (auto mod = dyn_cast<FModuleOp>(op)) {
770 note << "in module " << mod.getModuleNameAttr();
771 return;
772 }
773 if (auto mod = dyn_cast<FExtModuleOp>(op)) {
774 note << "in extmodule " << mod.getModuleNameAttr();
775 return;
776 }
777 if (auto inst = dyn_cast<InstanceOp>(op)) {
778 note << "in instance " << inst.getInstanceNameAttr();
779 return;
780 }
781 if (auto inst = dyn_cast<InstanceChoiceOp>(op)) {
782 note << "in instance_choice " << inst.getNameAttr();
783 return;
784 }
785
786 note << "here";
787}
788
789void ModuleState::noteDomain(InFlightDiagnostic &diag, DomainValue domain) {
790 auto &note = diag.attachNote(domain.getLoc());
791 note << renderLong(domain);
792
793 if (globals.inserted.contains(domain)) {
794 note << " automatically inserted here";
795 return;
796 }
797
798 note << " declared here";
799}
800
/// Attach notes to \p diag that trace where \p domain comes from.
///
/// Starting from \p domain, the walk repeatedly notes the current domain value
/// and then tries, in order, to step to its source:
///   1. through a DomainDefineOp whose destination is the value;
///   2. from a module port back to the result of the instance entered on the
///      way down (tracked with an explicit instance stack);
///   3. from a non-input instance port into the matching port of the first
///      referenced module that is an FModuleOp;
///   4. to the concrete domain value the value's term resolves to, if that is
///      a different value.
/// The walk stops as soon as a value is revisited. This also happens when no
/// step applies, because the unchanged value is seen again.
void ModuleState::noteDomainSource(InFlightDiagnostic &diag,
                                   DomainValue domain) {
  auto &irns = globals.getInnerRefNamespace();
  // Instances entered by chaseInstancePort; chaseModulePort uses them to map a
  // module port back out to the instance result.
  SmallVector<FInstanceLike> stack;
  // Values already visited; guarantees termination.
  llvm::SmallDenseSet<DomainValue> seen;

  // This is reusing "domain" across iterations of the while loop.

  // Step through a domain define that drives the current value.
  auto chaseConnect = [&]() {
    for (auto *user : domain.getUsers()) {
      if (auto defineOp = dyn_cast<DomainDefineOp>(user)) {
        if (defineOp.getDest() != domain)
          continue;
        auto src = defineOp.getSrc();
        diag.attachNote(defineOp.getLoc())
            << renderLong(domain) << " aliases " << renderLong(src);
        domain = defineOp.getSrc();
        return true;
      }
    }
    return false;
  };

  // Step from a port of an FModuleOp to the matching result of the most
  // recently entered instance that references this module. Instances on the
  // stack that do not reference the module are popped and discarded.
  auto chaseModulePort = [&]() {
    auto arg = dyn_cast<BlockArgument>(domain);
    if (!arg)
      return false;

    auto module =
        llvm::dyn_cast_if_present<FModuleOp>(arg.getOwner()->getParentOp());
    if (!module)
      return false;

    auto name = module.getModuleNameAttr();
    while (!stack.empty()) {
      auto instance = stack.back();
      stack.pop_back();
      auto referenced = instance.getReferencedModuleNamesAttr().getValue();
      if (llvm::is_contained(referenced, name)) {
        domain = cast<DomainValue>(instance->getResult(arg.getArgNumber()));
        return true;
      }
    }
    return false;
  };

  // Step from a non-input instance port into the same-numbered argument of
  // the first referenced module that is an FModuleOp. The instance is pushed
  // so the walk can later come back out through it.
  auto chaseInstancePort = [&]() {
    auto result = dyn_cast<OpResult>(domain);
    if (!result)
      return false;

    auto inst = dyn_cast<FInstanceLike>(result.getOwner());
    if (!inst)
      return false;

    auto index = result.getResultNumber();
    if (inst.getPortDirection(index) == Direction::In)
      return false;

    auto names = inst.getReferencedModuleNamesAttr().getAsRange<StringAttr>();
    for (auto name : names) {
      auto moduleLike = cast<FModuleLike>(irns.symTable.lookup(name));
      if (auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation())) {
        stack.push_back(inst);
        domain = cast<DomainValue>(moduleOp.getArgument(index));
        return true;
      }
    }
    return false;
  };

  // Step to the concrete domain value that this value's term resolves to,
  // when that is a different value.
  auto chaseUnderlying = [&]() {
    if (auto *term = getOptTermForDomain(domain)) {
      if (auto *val = dyn_cast<ValueTerm>(term)) {
        if (domain != val->value) {
          diag.attachNote(domain.getLoc())
              << renderLong(domain) << " aliases " << renderLong(val->value);
          domain = val->value;
          return true;
        }
      }
    }
    return false;
  };

  while (true) {
    auto [it, inserted] = seen.insert(domain);
    if (!inserted)
      return;

    noteDomain(diag, domain);
    // Take the first step that applies. If none does, domain is unchanged and
    // the next iteration terminates the walk.
    chaseConnect() || chaseModulePort() || chaseInstancePort() ||
        chaseUnderlying();
  }
}
896
897void ModuleState::noteDomainSource(InFlightDiagnostic &diag, Term *term) {
898 auto *val = dyn_cast<ValueTerm>(find(term));
899 if (!val)
900 return;
901
902 noteDomainSource(diag, val->value);
903}
904
905void ModuleState::emitDomainCrossingError(Operation *op, Value lhs,
906 Term *lhsTerm, Value rhs,
907 Term *rhsTerm) {
908 auto *lhsRow = cast<RowTerm>(lhsTerm);
909 auto *rhsRow = cast<RowTerm>(rhsTerm);
910 auto diag =
911 op->emitError("illegal domain crossing in operation between operands ");
912 render(lhs, diag);
913 diag << " and ";
914 render(rhs, diag);
915 auto &note1 = diag.attachNote(lhs.getLoc());
916 render(lhs, note1);
917 note1 << " has domains ";
918 render(lhsRow, note1);
919 auto &note2 = diag.attachNote(rhs.getLoc());
920 render(rhs, note2);
921 note2 << " has domains ";
922 render(rhsRow, note2);
923
924 for (size_t i = 0, e = getNumDomains(); i < e; ++i) {
925 auto *lhsDomain = find(lhsRow->elements[i]);
926 auto *rhsDomain = find(rhsRow->elements[i]);
927 if (lhsDomain == rhsDomain)
928 continue;
929
930 noteDomainSource(diag, lhsDomain);
931 noteDomainSource(diag, rhsDomain);
932 }
933}
934
935template <typename T>
936void ModuleState::emitDuplicatePortDomainError(
937 T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1,
938 IntegerAttr domainPortIndexAttr2) {
939 auto portName = op.getPortNameAttr(i);
940 auto portLoc = op.getPortLocation(i);
941 auto domainDecl = getDomain(domainTypeID);
942 auto domainName = domainDecl.getNameAttr();
943 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
944 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
945 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
946 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
947 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
948 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
949 auto diag = emitError(portLoc);
950 diag << "duplicate " << domainName << " association for port " << portName;
951 auto &note1 = diag.attachNote(domainPortLoc1);
952 note1 << "associated with " << domainName << " port " << domainPortName1;
953 auto &note2 = diag.attachNote(domainPortLoc2);
954 note2 << "associated with " << domainName << " port " << domainPortName2;
955 noteLocation(diag, op);
956}
957
958/// Emit an error when we fail to infer the concrete domain to drive to a
959/// domain port.
960template <typename T>
961void ModuleState::emitDomainPortInferenceError(T op, size_t i) {
962 auto name = op.getPortNameAttr(i);
963 auto diag = emitError(op->getLoc());
964 auto info = op.getDomainInfo();
965 diag << "unable to infer value for undriven domain port " << name;
966 for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
967 if (auto assocs = dyn_cast<ArrayAttr>(info[j])) {
968 for (auto assoc : assocs) {
969 if (i == cast<IntegerAttr>(assoc).getValue()) {
970 auto name = op.getPortNameAttr(j);
971 auto loc = op.getPortLocation(j);
972 diag.attachNote(loc) << "associated with hardware port " << name;
973 break;
974 }
975 }
976 }
977 }
978 noteLocation(diag, op);
979}
980
981template <typename T>
982void ModuleState::emitAmbiguousPortDomainAssociation(
983 T op, const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
984 size_t i) {
985 auto portName = op.getPortNameAttr(i);
986 auto portLoc = op.getPortLocation(i);
987 auto domainDecl = getDomain(typeID);
988 auto domainName = domainDecl.getNameAttr();
989 auto diag = emitError(portLoc) << "ambiguous " << domainName
990 << " association for port " << portName;
991 for (auto e : exports) {
992 auto arg = cast<BlockArgument>(e);
993 auto name = op.getPortNameAttr(arg.getArgNumber());
994 auto loc = op.getPortLocation(arg.getArgNumber());
995 diag.attachNote(loc) << "candidate association " << name;
996 }
997 noteLocation(diag, op);
998}
999
1000template <typename T>
1001void ModuleState::emitMissingPortDomainAssociationError(T op,
1002 DomainTypeID typeID,
1003 size_t i) {
1004 auto domainName = getDomain(typeID).getNameAttr();
1005 auto portName = op.getPortNameAttr(i);
1006 auto diag = emitError(op.getPortLocation(i))
1007 << "missing " << domainName << " association for port "
1008 << portName;
1009 noteLocation(diag, op);
1010}
1011
1012LogicalResult ModuleState::unifyAssociations(Operation *op, Value lhs,
1013 Value rhs) {
1014 if (!lhs || !rhs)
1015 return success();
1016
1017 if (lhs == rhs)
1018 return success();
1019
1020 if (!isHardware(lhs) || !isHardware(rhs))
1021 return success();
1022
1023 LLVM_DEBUG({
1024 llvm::dbgs().indent(6) << "unify domains(" << render(lhs) << ") = domains("
1025 << render(rhs) << ")\n";
1026 });
1027
1028 auto *lhsTerm = getOptDomainAssociation(lhs);
1029 auto *rhsTerm = getOptDomainAssociation(rhs);
1030
1031 if (lhsTerm) {
1032 if (rhsTerm) {
1033 if (failed(unify(lhsTerm, rhsTerm))) {
1034 emitDomainCrossingError(op, lhs, lhsTerm, rhs, rhsTerm);
1035 return failure();
1036 }
1037 return success();
1038 }
1039 setDomainAssociation(rhs, lhsTerm);
1040 return success();
1041 }
1042
1043 if (rhsTerm) {
1044 setDomainAssociation(lhs, rhsTerm);
1045 return success();
1046 }
1047
1048 auto *var = allocVar();
1049 setDomainAssociation(lhs, var);
1050 setDomainAssociation(rhs, var);
1051 return success();
1052}
1053
1054template <typename T>
1055LogicalResult ModuleState::unifyAssociations(Operation *op, T &&range) {
1056 Value lhs;
1057 for (auto rhs : std::forward<T>(range)) {
1058 if (!isHardware(rhs))
1059 continue;
1060 if (failed(unifyAssociations(op, lhs, rhs)))
1061 return failure();
1062 lhs = rhs;
1063 }
1064
1065 return success();
1066}
1067
1068LogicalResult ModuleState::unifyAssociations(Operation *op) {
1069 return unifyAssociations(
1070 op, llvm::concat<Value>(op->getOperands(), op->getResults()));
1071}
1072
1073LogicalResult ModuleState::processModulePorts(FModuleOp moduleOp) {
1074 auto numDomains = getNumDomains();
1075 auto domainInfo = moduleOp.getDomainInfoAttr();
1076 auto numPorts = moduleOp.getNumPorts();
1077
1078 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1079 for (size_t i = 0; i < numPorts; ++i) {
1080 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1081 if (!port)
1082 continue;
1083
1084 LLVM_DEBUG(llvm::dbgs().indent(4)
1085 << "process port " << render(port) << "\n");
1086
1087 if (moduleOp.getPortDirection(i) == Direction::In)
1088 processDomainDefinition(port);
1089
1090 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1091 }
1092
1093 for (size_t i = 0; i < numPorts; ++i) {
1094 BlockArgument port = moduleOp.getArgument(i);
1095 if (!isHardware(port))
1096 continue;
1097
1098 LLVM_DEBUG(llvm::dbgs().indent(4)
1099 << "process port " << render(port) << "\n");
1100
1101 SmallVector<IntegerAttr> associations(numDomains);
1102 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1103 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1104 auto prevDomainPortIndex = associations[domainTypeID.index];
1105 if (prevDomainPortIndex) {
1106 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1107 prevDomainPortIndex, domainPortIndex);
1108 return failure();
1109 }
1110 associations[domainTypeID.index] = domainPortIndex;
1111 }
1112
1113 SmallVector<Term *> elements(numDomains);
1114 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1115 ++domainTypeIndex) {
1116 auto domainPortIndex = associations[domainTypeIndex];
1117 if (!domainPortIndex)
1118 continue;
1119 auto domainPortValue =
1120 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
1121 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1122 }
1123
1124 auto *domainAssociations = allocRow(elements);
1125 setDomainAssociation(port, domainAssociations);
1126 }
1127
1128 return success();
1129}
1130
1131template <typename T>
1132LogicalResult ModuleState::processInstancePorts(T op) {
1133 auto numDomains = getNumDomains();
1134 auto domainInfo = op.getDomainInfoAttr();
1135 auto numPorts = op.getNumPorts();
1136
1137 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1138 for (size_t i = 0; i < numPorts; ++i) {
1139 auto port = dyn_cast<DomainValue>(op->getResult(i));
1140 if (!port)
1141 continue;
1142
1143 if (op.getPortDirection(i) == Direction::Out)
1144 processDomainDefinition(port);
1145
1146 domainTypeIDTable[i] = getDomainTypeID(op, i);
1147 }
1148
1149 for (size_t i = 0; i < numPorts; ++i) {
1150 Value port = op->getResult(i);
1151 if (!isHardware(port))
1152 continue;
1153
1154 SmallVector<IntegerAttr> associations(numDomains);
1155 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1156 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1157 auto prevDomainPortIndex = associations[domainTypeID.index];
1158 if (prevDomainPortIndex) {
1159 emitDuplicatePortDomainError(op, i, domainTypeID, prevDomainPortIndex,
1160 domainPortIndex);
1161 return failure();
1162 }
1163 associations[domainTypeID.index] = domainPortIndex;
1164 }
1165
1166 SmallVector<Term *> elements(numDomains);
1167 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1168 ++domainTypeIndex) {
1169 auto domainPortIndex = associations[domainTypeIndex];
1170 if (!domainPortIndex)
1171 continue;
1172 auto domainPortValue =
1173 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
1174 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1175 }
1176
1177 auto *domainAssociations = allocRow(elements);
1178 setDomainAssociation(port, domainAssociations);
1179 }
1180
1181 return success();
1182}
1183
1184FInstanceLike ModuleState::fixInstancePorts(FInstanceLike op,
1185 const ModuleUpdateInfo &update) {
1186 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
1187 clone.setDomainInfoAttr(update.portDomainInfo);
1188 op->erase();
1189 dirty();
1190 LLVM_DEBUG(llvm::dbgs().indent(6) << "fixup " << render(clone) << "\n");
1191 return clone;
1192}
1193
1194LogicalResult ModuleState::processOp(FInstanceLike op) {
1195 auto moduleName =
1196 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
1197 auto updateTable = getModuleUpdateTable();
1198 auto lookup = updateTable.find(moduleName);
1199 if (lookup != updateTable.end())
1200 op = fixInstancePorts(op, lookup->second);
1201 return processInstancePorts(op);
1202}
1203
1204LogicalResult ModuleState::processOp(UnsafeDomainCastOp op) {
1205 auto domains = op.getDomains();
1206 if (domains.empty())
1207 return unifyAssociations(op, op.getInput(), op.getResult());
1208
1209 auto input = op.getInput();
1210
1211 SmallVector<Term *> elements(getNumDomains());
1212 if (isHardware(input)) {
1213 auto *inputRow = getDomainAssociationAsRow(input);
1214 elements.assign(inputRow->elements);
1215 }
1216
1217 for (auto value : op.getDomains()) {
1218 auto domain = cast<DomainValue>(value);
1219 auto typeID = getDomainTypeID(domain);
1220 elements[typeID.index] = getTermForDomain(domain);
1221 }
1222
1223 auto *row = allocRow(elements);
1224 setDomainAssociation(op.getResult(), row);
1225 return success();
1226}
1227
1228LogicalResult ModuleState::processOp(DomainDefineOp op) {
1229 auto src = op.getSrc();
1230 auto dst = op.getDest();
1231
1232 auto *srcTerm = getTermForDomain(src);
1233 auto *dstTerm = getTermForDomain(dst);
1234 if (succeeded(unify(dstTerm, srcTerm)))
1235 return success();
1236
1237 auto diag =
1238 op->emitOpError()
1239 << "defines a domain value that was inferred to be a different domain '";
1240 render(dstTerm, diag);
1241 diag << "'";
1242
1243 return failure();
1244}
1245
1246LogicalResult ModuleState::processOp(WireOp op) {
1247 // If the wire has explicit domain operands, seed the domain table with them
1248 // as constraints. When this op is visited, connections have not yet been
1249 // processed (wire declarations precede their uses), so the existing row
1250 // contains only fresh variables that unify unconditionally. Any conflict
1251 // between an explicit wire domain and a connection's inferred domain is
1252 // caught later by the connection's own processOp.
1253 if (op.getDomains().empty())
1254 return unifyAssociations(op, op.getResults());
1255
1256 // Build a row with the explicitly-specified domain slots filled in and set
1257 // it as the association for this wire result.
1258 SmallVector<Term *> elements(getNumDomains());
1259 for (auto domain : op.getDomains()) {
1260 auto domainValue = cast<DomainValue>(domain);
1261 auto typeID = getDomainTypeID(domainValue);
1262 elements[typeID.index] = getTermForDomain(domainValue);
1263 }
1264
1265 auto *row = allocRow(elements);
1266 for (auto result : op.getResults())
1267 setDomainAssociation(result, row);
1268
1269 return success();
1270}
1271
1272LogicalResult ModuleState::processOp(RWProbeOp op) {
1273 auto target = globals.getInnerRefNamespace().lookup(op.getTarget());
1274
1275 if (target.isPort()) {
1276 auto targetOp = cast<FModuleOp>(target.getOp());
1277 auto targetValue = targetOp.getArgument(target.getPort());
1278 return unifyAssociations(op, targetValue, op.getResult());
1279 }
1280
1281 auto targetOp = cast<hw::InnerSymbolOpInterface>(target.getOp());
1282 auto targetValue = targetOp.getTargetResult();
1283 return unifyAssociations(op, targetValue, op.getResult());
1284}
1285
1286LogicalResult ModuleState::processOp(Operation *op) {
1287 LLVM_DEBUG(llvm::dbgs().indent(4) << "process " << render(op) << "\n");
1288 if (auto instance = dyn_cast<FInstanceLike>(op))
1289 return processOp(instance);
1290 if (auto wireOp = dyn_cast<WireOp>(op))
1291 return processOp(wireOp);
1292 if (auto cast = dyn_cast<UnsafeDomainCastOp>(op))
1293 return processOp(cast);
1294 if (auto def = dyn_cast<DomainDefineOp>(op))
1295 return processOp(def);
1296 if (auto probe = dyn_cast<RWProbeOp>(op))
1297 return processOp(probe);
1298 if (auto create = dyn_cast<DomainCreateOp>(op)) {
1299 processDomainDefinition(create);
1300 return success();
1301 }
1302 if (auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
1303 processDomainDefinition(createAnon);
1304 return success();
1305 }
1306
1307 return unifyAssociations(op);
1308}
1309
1310LogicalResult ModuleState::processModuleBody(FModuleOp moduleOp) {
1311 return failure(
1312 moduleOp.getBody()
1313 .walk([&](Operation *op) -> WalkResult { return processOp(op); })
1314 .wasInterrupted());
1315}
1316
1317LogicalResult ModuleState::processModule(FModuleOp moduleOp) {
1318 LLVM_DEBUG(llvm::dbgs().indent(2) << "processing:\n");
1319 if (failed(processModulePorts(moduleOp)))
1320 return failure();
1321 if (failed(processModuleBody(moduleOp)))
1322 return failure();
1323 return success();
1324}
1325
1326ExportTable ModuleState::initializeExportTable(FModuleOp moduleOp) {
1327 ExportTable exports;
1328 size_t numPorts = moduleOp.getNumPorts();
1329 for (size_t i = 0; i < numPorts; ++i) {
1330 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1331 if (!port)
1332 continue;
1333 auto value = getOptUnderlyingDomain(port);
1334 if (value)
1335 exports[value].push_back(port);
1336 }
1337
1338 LLVM_DEBUG({
1339 llvm::dbgs().indent(2) << "domain exports:\n";
1340 for (auto entry : exports) {
1341 llvm::dbgs().indent(4) << render(entry.first) << " exported as ";
1342 llvm::interleaveComma(entry.second, llvm::dbgs(),
1343 [&](auto e) { llvm::dbgs() << render(e); });
1344 llvm::dbgs() << "\n";
1345 }
1346 });
1347
1348 return exports;
1349}
1350
1351void ModuleState::ensureSolved(Namespace &ns, DomainTypeID typeID, size_t ip,
1352 LocationAttr loc, VariableTerm *var,
1353 PendingUpdates &pending) {
1354 if (pending.solutions.contains(var))
1355 return;
1356
1357 auto *context = loc.getContext();
1358 auto domainDecl = getDomain(typeID);
1359 auto domainName = domainDecl.getNameAttr();
1360
1361 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1362 auto portType = DomainType::getFromDomainOp(domainDecl);
1363 auto portDirection = Direction::In;
1364 auto portSym = StringAttr();
1365 auto portLoc = loc;
1366 auto portAnnos = std::nullopt;
1367 // Domain type ports have no associations (domain info is in the type).
1368 auto portDomainInfo = ArrayAttr::get(context, {});
1369 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1370 portAnnos, portDomainInfo);
1371
1372 pending.solutions[var] = pending.insertions.size() + ip;
1373 pending.insertions.push_back({ip, portInfo});
1374}
1375
1376void ModuleState::ensureExported(Namespace &ns, const ExportTable &exports,
1377 DomainTypeID typeID, size_t ip,
1378 LocationAttr loc, ValueTerm *val,
1379 PendingUpdates &pending) {
1380 auto value = val->value;
1381 assert(isa<DomainType>(value.getType()));
1382 if (isPort(value) || exports.contains(value) ||
1383 pending.exports.contains(value))
1384 return;
1385
1386 auto *context = loc.getContext();
1387
1388 auto domainDecl = getDomain(typeID);
1389 auto domainName = domainDecl.getNameAttr();
1390
1391 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1392 auto portType = DomainType::getFromDomainOp(domainDecl);
1393 auto portDirection = Direction::Out;
1394 auto portSym = StringAttr();
1395 auto portAnnos = std::nullopt;
1396 // Domain type ports have no associations (domain info is in the type).
1397 auto portDomainInfo = ArrayAttr::get(context, {});
1398 PortInfo portInfo(portName, portType, portDirection, portSym, loc, portAnnos,
1399 portDomainInfo);
1400 pending.exports[value] = pending.insertions.size() + ip;
1401 pending.insertions.push_back({ip, portInfo});
1402}
1403
1404void ModuleState::getUpdatesForDomainAssociationOfPort(
1405 Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip,
1406 LocationAttr loc, Term *term, const ExportTable &exports) {
1407 if (auto *var = dyn_cast<VariableTerm>(term)) {
1408 ensureSolved(ns, typeID, ip, loc, var, pending);
1409 return;
1410 }
1411 if (auto *val = dyn_cast<ValueTerm>(term)) {
1412 ensureExported(ns, exports, typeID, ip, loc, val, pending);
1413 return;
1414 }
1415 llvm_unreachable("invalid domain association");
1416}
1417
1418void ModuleState::getUpdatesForDomainAssociationOfPort(
1419 Namespace &ns, const ExportTable &exports, size_t ip, LocationAttr loc,
1420 RowTerm *row, PendingUpdates &pending) {
1421 for (auto [index, term] : llvm::enumerate(row->elements))
1422 getUpdatesForDomainAssociationOfPort(ns, pending, DomainTypeID{index}, ip,
1423 loc, find(term), exports);
1424}
1425
1426void ModuleState::getUpdatesForModulePorts(FModuleOp moduleOp,
1427 const ExportTable &exports,
1428 Namespace &ns,
1429 PendingUpdates &pending) {
1430 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1431 auto port = moduleOp.getArgument(i);
1432 if (!isHardware(port))
1433 continue;
1434
1435 getUpdatesForDomainAssociationOfPort(
1436 ns, exports, i, moduleOp.getPortLocation(i),
1437 getDomainAssociationAsRow(port), pending);
1438 }
1439}
1440
1441void ModuleState::getUpdatesForModule(FModuleOp moduleOp,
1442 const ExportTable &exports,
1443 PendingUpdates &pending) {
1444 Namespace ns;
1445 auto names = moduleOp.getPortNamesAttr();
1446 for (auto name : names.getAsRange<StringAttr>())
1447 ns.add(name);
1448 getUpdatesForModulePorts(moduleOp, exports, ns, pending);
1449}
1450
/// Apply the queued interface updates in `pending` to `moduleOp`.
///
/// This inserts the new domain ports, solves each pending variable with its
/// new input domain port, and drives each new output domain port from the
/// domain it exports. Both kinds of new ports are added to `exports` and to
/// `globals.inserted`.
void ModuleState::applyUpdatesToModule(FModuleOp moduleOp, ExportTable &exports,
                                       const PendingUpdates &pending) {
  LLVM_DEBUG(llvm::dbgs().indent(2) << "applying updates:\n");
  // Put the domain ports in place. The indices stored in `pending.solutions`
  // and `pending.exports` are port positions after this insertion.
  moduleOp.insertPorts(pending.insertions);
  dirty();

  // Solve any variables and record them as "self-exporting".
  for (auto [var, portIndex] : pending.solutions) {
    auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
    auto *solution = allocVal(portValue);
    LLVM_DEBUG(llvm::dbgs().indent(4)
               << "new-input " << render(portValue) << "\n");
    solve(var, solution);
    exports[portValue].push_back(portValue);
    globals.inserted.insert(portValue);
  }

  // Drive the output ports, and record the export.
  auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
  for (auto [domainValue, portIndex] : pending.exports) {
    auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
    // Place the driver immediately after the exported value's definition.
    builder.setInsertionPointAfterValue(domainValue);
    DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
    LLVM_DEBUG(llvm::dbgs().indent(4) << "new-output " << render(portValue)
                                      << " := " << render(domainValue) << "\n");
    exports[domainValue].push_back(portValue);
    globals.inserted.insert(portValue);
    // Record that the new output port denotes `domainValue`.
    setTermForDomain(portValue, allocVal(domainValue));
  }
}
1482
1483SmallVector<Attribute> ModuleState::copyPortDomainAssociations(
1484 FModuleOp moduleOp, ArrayAttr moduleDomainInfo, size_t portIndex) {
1485 SmallVector<Attribute> result(getNumDomains());
1486 auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex);
1487 for (auto domainPortIndexAttr : oldAssociations) {
1488 auto domainPortIndex = domainPortIndexAttr.getUInt();
1489 auto domainTypeID = getDomainTypeID(moduleOp, domainPortIndex);
1490 result[domainTypeID.index] = domainPortIndexAttr;
1491 };
1492 return result;
1493}
1494
1495LogicalResult ModuleState::driveModuleOutputDomainPorts(FModuleOp moduleOp) {
1496 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1497 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1498 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1499 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1500 isDriven(port))
1501 continue;
1502
1503 auto *term = getOptTermForDomain(port);
1504 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1505 if (!val) {
1506 emitDomainPortInferenceError(moduleOp, i);
1507 return failure();
1508 }
1509
1510 auto loc = port.getLoc();
1511 auto value = val->value;
1512 LLVM_DEBUG(llvm::dbgs().indent(4) << "connect " << render(port)
1513 << " := " << render(value) << "\n");
1514 DomainDefineOp::create(builder, loc, port, value);
1515 }
1516
1517 return success();
1518}
1519
1520LogicalResult ModuleState::updateModuleDomainInfo(
1521 FModuleOp moduleOp, const ExportTable &exportTable, ArrayAttr &result) {
1522 // At this point, all domain variables mentioned in ports have been
1523 // solved by generalizing the moduleOp (adding input domain ports). Now, we
1524 // have to form the new port domain information for the moduleOp by examining
1525 // the the associated domains of each port.
1526 auto *context = moduleOp.getContext();
1527 auto numDomains = getNumDomains();
1528 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1529 auto numPorts = moduleOp.getNumPorts();
1530 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1531
1532 for (size_t i = 0; i < numPorts; ++i) {
1533 auto port = moduleOp.getArgument(i);
1534 auto type = port.getType();
1535
1536 if (isa<DomainType>(type)) {
1537 // Domain type ports have no associations (domain info is in the type).
1538 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1539 continue;
1540 }
1541
1542 if (!isHardware(port)) {
1543 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1544 continue;
1545 }
1546
1547 auto associations =
1548 copyPortDomainAssociations(moduleOp, oldModuleDomainInfo, i);
1549 auto *row = cast<RowTerm>(getDomainAssociation(port));
1550 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1551 auto domainTypeID = DomainTypeID{domainIndex};
1552 if (associations[domainIndex])
1553 continue;
1554
1555 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1556 auto &exports = exportTable.at(domain);
1557 if (exports.empty()) {
1558 auto portName = moduleOp.getPortNameAttr(i);
1559 auto portLoc = moduleOp.getPortLocation(i);
1560 auto domainDecl = getDomain(domainTypeID);
1561 auto domainName = domainDecl.getNameAttr();
1562 auto diag = emitError(portLoc) << "private " << domainName
1563 << " association for port " << portName;
1564 diag.attachNote(domain.getLoc()) << "associated domain: " << domain;
1565 noteLocation(diag, moduleOp);
1566 return failure();
1567 }
1568
1569 if (exports.size() > 1) {
1570 emitAmbiguousPortDomainAssociation(moduleOp, exports, domainTypeID, i);
1571 return failure();
1572 }
1573
1574 auto argument = cast<BlockArgument>(exports[0]);
1575 auto domainPortIndex = argument.getArgNumber();
1576 associations[domainTypeID.index] =
1577 IntegerAttr::get(IntegerType::get(context, 32, IntegerType::Unsigned),
1578 domainPortIndex);
1579 }
1580
1581 newModuleDomainInfo[i] = ArrayAttr::get(context, associations);
1582 }
1583
1584 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1585 moduleOp.setDomainInfoAttr(result);
1586 return success();
1587}
1588
1589DomainValue ModuleState::solveVarWithAnonDomain(
1590 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1591 Operation *user, DomainType type, VariableTerm *var) {
1592 auto name = type.getName().getAttr();
1593 DomainValue anon =
1594 DomainCreateAnonOp::create(builder, user->getLoc(), type, name);
1595 dirty();
1596 LLVM_DEBUG(llvm::dbgs().indent(6) << "create anon " << render(anon) << "\n");
1597 solve(var, allocVal(anon));
1598 domainsInScope[anon] = anon;
1599 globals.inserted.insert(anon);
1600 return anon;
1601}
1602
/// Return a value usable in place of `domain` at the builder's current
/// insertion point.
///
/// If `domainsInScope` already maps `domain` to a value, that value is
/// returned. Otherwise a domain-typed "bounce" wire is created at the
/// insertion point. A `DomainDefineOp` placed right after `domain`'s
/// definition drives the wire from `domain`, and the wire is cached in
/// `domainsInScope`.
/// NOTE(review): presumably this lets a domain be referenced where the
/// original value is not in scope (e.g. defined in another region) — confirm.
DomainValue ModuleState::getDomainInScope(
    OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
    DomainValue domain) {
  // Reference into the map; on a miss this default-inserts a null entry that
  // is filled in below. Nothing below inserts into `domainsInScope`, so the
  // reference stays valid.
  auto &domainInScope = domainsInScope[domain];
  if (domainInScope)
    return domainInScope;

  // The bounce wire is named after the domain kind.
  domainInScope = cast<DomainValue>(
      WireOp::create(builder, domain.getLoc(), domain.getType(),
                     domain.getType().getName().getAttr())
          .getResult());

  // Drive the wire right after `domain` is defined. The guard restores the
  // caller's insertion point when this function returns.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPointAfterValue(domain);
  DomainDefineOp::create(builder, domain.getLoc(), domainInScope, domain);
  dirty();
  LLVM_DEBUG(llvm::dbgs().indent(6) << "bounce wire " << render(domainInScope)
                                    << " := " << render(domain) << "\n");
  return domainInScope;
}
1623
1624LogicalResult
1625ModuleState::updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
1626 FInstanceLike op) {
1627 LLVM_DEBUG(llvm::dbgs().indent(4) << "update " << render(op) << "\n");
1628 OpBuilder builder(op.getContext());
1629 builder.setInsertionPointAfter(op);
1630 auto numPorts = op->getNumResults();
1631
1632 for (size_t i = 0; i < numPorts; ++i)
1633 if (auto port = dyn_cast<DomainValue>(op->getResult(i)))
1634 if (op.getPortDirection(i) == Direction::Out)
1635 domainsInScope[port] = port;
1636
1637 for (size_t i = 0; i < numPorts; ++i) {
1638 auto port = dyn_cast<DomainValue>(op->getResult(i));
1639 auto direction = op.getPortDirection(i);
1640 // If the port is an input domain, we may need to drive the input with
1641 // a value. If we don't know what value to drive to the port, drive an
1642 // anonymous domain.
1643 if (port && direction == Direction::In && !isDriven(port)) {
1644 auto loc = port.getLoc();
1645 auto *term = getTermForDomain(port);
1646 if (auto *var = dyn_cast<VariableTerm>(term)) {
1647 auto domain = solveVarWithAnonDomain(builder, domainsInScope, op,
1648 port.getType(), var);
1649 LLVM_DEBUG(llvm::dbgs().indent(6) << "connect " << render(port)
1650 << " := " << render(domain) << "\n");
1651 DomainDefineOp::create(builder, loc, port, domain);
1652 continue;
1653 }
1654 if (auto *val = dyn_cast<ValueTerm>(term)) {
1655 auto domain = getDomainInScope(builder, domainsInScope, val->value);
1656 LLVM_DEBUG(llvm::dbgs().indent(6) << "connect " << render(port)
1657 << " := " << render(domain) << "\n");
1658 DomainDefineOp::create(builder, loc, port, domain);
1659 continue;
1660 }
1661 llvm_unreachable("unhandled domain term type");
1662 }
1663 }
1664
1665 return success();
1666}
1667
1668LogicalResult
1669ModuleState::updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
1670 WireOp wireOp) {
1671 auto result = wireOp.getResult();
1672
1673 if (auto tgt = dyn_cast<DomainValue>(result)) {
1674 if (isDriven(tgt))
1675 return success();
1676
1677 LLVM_DEBUG(llvm::dbgs().indent(4) << "update " << render(wireOp) << "\n");
1678 OpBuilder builder(wireOp);
1679 builder.setInsertionPointAfter(wireOp);
1680 auto *term = getTermForDomain(tgt);
1681 if (auto *var = dyn_cast<VariableTerm>(term)) {
1682 auto src = solveVarWithAnonDomain(builder, domainsInScope, wireOp,
1683 tgt.getType(), var);
1684 LLVM_DEBUG(llvm::dbgs().indent(6)
1685 << "connect " << render(tgt) << " := " << render(src) << "\n");
1686 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1687 return success();
1688 }
1689 if (auto *val = dyn_cast<ValueTerm>(term)) {
1690 auto src = getDomainInScope(builder, domainsInScope, val->value);
1691 LLVM_DEBUG(llvm::dbgs().indent(6)
1692 << "connect " << render(tgt) << " := " << render(src) << "\n");
1693 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1694 return success();
1695 }
1696 llvm_unreachable("unhandled domain term type");
1697 }
1698
1699 if (!isHardware(result))
1700 return success();
1701
1702 LLVM_DEBUG(llvm::dbgs().indent(4) << "update " << render(wireOp) << "\n");
1703 OpBuilder builder(wireOp);
1704 auto *row = getDomainAssociationAsRow(wireOp.getResult());
1705
1706 SmallVector<Value> domainOperands;
1707 for (auto [i, element] : llvm::enumerate(
1708 llvm::map_range(row->elements, [&](auto e) { return find(e); }))) {
1709 if (auto *val = dyn_cast<ValueTerm>(element)) {
1710 domainOperands.push_back(
1711 getDomainInScope(builder, domainsInScope, val->value));
1712 continue;
1713 }
1714 if (auto *var = dyn_cast<VariableTerm>(element)) {
1715 auto type = DomainType::getFromDomainOp(getDomain(DomainTypeID{i}));
1716 auto domain =
1717 solveVarWithAnonDomain(builder, domainsInScope, wireOp, type, var);
1718 domainOperands.push_back(domain);
1719 continue;
1720 }
1721 assert(0 && "unhandled domain type");
1722 }
1723 wireOp.getDomainsMutable().assign(domainOperands);
1724 return success();
1725}
1726
1727LogicalResult ModuleState::updateModuleBody(FModuleOp moduleOp) {
1728 DenseMap<DomainValue, DomainValue> domainsInScope;
1729
1730 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i)
1731 if (auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i)))
1732 if (moduleOp.getPortDirection(i) == Direction::In)
1733 domainsInScope[port] = port;
1734
1735 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1736 return TypeSwitch<Operation *, WalkResult>(op)
1737 .Case<WireOp>(
1738 [&](auto wire) { return updateWire(domainsInScope, wire); })
1739 .Case<FInstanceLike>([&](auto instance) {
1740 return updateInstance(domainsInScope, instance);
1741 })
1742 .Case<DomainCreateOp, DomainCreateAnonOp>([&](auto domain) {
1743 domainsInScope[domain] = domain;
1744 return success();
1745 })
1746 .Default([&](auto op) { return success(); });
1747 });
1748 return failure(result.wasInterrupted());
1749}
1750
/// Rewrite `moduleOp` after its domain constraints have been processed.
///
/// Steps, in order:
/// 1. Generalize the interface by adding input domain ports for unsolved
///    variables and output domain ports for unexported domains.
/// 2. Recompute the port domain info.
/// 3. Drive the undriven output domain ports.
/// 4. Record the interface change in the module update table, which
///    `processOp(FInstanceLike)` consults when fixing up instances.
/// 5. Materialize domains inside the body.
LogicalResult ModuleState::updateModule(FModuleOp moduleOp) {
  auto exports = initializeExportTable(moduleOp);
  PendingUpdates pending;
  getUpdatesForModule(moduleOp, exports, pending);
  applyUpdatesToModule(moduleOp, exports, pending);

  ArrayAttr portDomainInfo;
  if (failed(updateModuleDomainInfo(moduleOp, exports, portDomainInfo)))
    return failure();

  if (failed(driveModuleOutputDomainPorts(moduleOp)))
    return failure();

  // Record the updated interface change in the update table.
  // `pending.insertions` is moved out and not used again below.
  auto &entry = getModuleUpdateTable()[moduleOp.getModuleNameAttr()];
  entry.portDomainInfo = portDomainInfo;
  entry.portInsertions = std::move(pending.insertions);

  if (failed(updateModuleBody(moduleOp)))
    return failure();

  LLVM_DEBUG({
    llvm::dbgs().indent(2) << "port summary:\n";
    for (auto port : moduleOp.getBodyBlock()->getArguments()) {
      llvm::dbgs().indent(4) << render(port);
      auto info = cast<ArrayAttr>(
          moduleOp.getDomainInfoAttrForPort(port.getArgNumber()));
      if (info.size()) {
        llvm::dbgs() << " domains [";
        llvm::interleaveComma(
            info.getAsRange<IntegerAttr>(), llvm::dbgs(), [&](auto i) {
              llvm::dbgs() << render(moduleOp.getArgument(i.getUInt()));
            });
        llvm::dbgs() << "]";
      }
      llvm::dbgs() << "\n";
    }
  });

  return success();
}
1792
1793LogicalResult ModuleState::checkModulePorts(FModuleLike moduleOp) {
1794 auto numDomains = getNumDomains();
1795 auto domainInfo = moduleOp.getDomainInfoAttr();
1796 auto numPorts = moduleOp.getNumPorts();
1797
1798 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1799 for (size_t i = 0; i < numPorts; ++i) {
1800 if (isa<DomainType>(moduleOp.getPortType(i)))
1801 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1802 }
1803
1804 for (size_t i = 0; i < numPorts; ++i) {
1805 if (!isHardware(moduleOp.getPortType(i)))
1806 continue;
1807
1808 // Record the domain associations of this port.
1809 SmallVector<IntegerAttr> associations(numDomains);
1810 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1811 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1812 auto prevDomainPortIndex = associations[domainTypeID.index];
1813 if (prevDomainPortIndex) {
1814 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1815 prevDomainPortIndex, domainPortIndex);
1816 return failure();
1817 }
1818 associations[domainTypeID.index] = domainPortIndex;
1819 }
1820
1821 // Check the associations for completeness.
1822 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1823 auto typeID = DomainTypeID{domainIndex};
1824 if (!associations[domainIndex]) {
1825 emitMissingPortDomainAssociationError(moduleOp, typeID, i);
1826 return failure();
1827 }
1828 }
1829 }
1830
1831 return success();
1832}
1833
1834LogicalResult ModuleState::checkModuleDomainPortDrivers(FModuleOp moduleOp) {
1835 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1836 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1837 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1838 isDriven(port))
1839 continue;
1840
1841 auto name = moduleOp.getPortNameAttr(i);
1842 auto diag = emitError(moduleOp.getPortLocation(i))
1843 << "undriven domain port " << name;
1844 noteLocation(diag, moduleOp);
1845 return failure();
1846 }
1847
1848 return success();
1849}
1850
1851LogicalResult ModuleState::checkInstanceDomainPortDrivers(FInstanceLike op) {
1852 for (size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1853 auto port = dyn_cast<DomainValue>(op->getResult(i));
1854
1855 auto type = port.getType();
1856 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1857 isDriven(port))
1858 continue;
1859
1860 auto name = op.getPortNameAttr(i);
1861 auto diag = emitError(op.getPortLocation(i))
1862 << "undriven domain port " << name;
1863 noteLocation(diag, op);
1864 return failure();
1865 }
1866
1867 return success();
1868}
1869
1870LogicalResult ModuleState::checkModuleBody(FModuleOp moduleOp) {
1871 auto result = moduleOp.getBody().walk([&](FInstanceLike op) -> WalkResult {
1872 return checkInstanceDomainPortDrivers(op);
1873 });
1874 return failure(result.wasInterrupted());
1875}
1876
1877LogicalResult ModuleState::inferModule(FModuleOp moduleOp) {
1878 LLVM_DEBUG(llvm::dbgs() << "infer: " << moduleOp.getModuleName() << "\n");
1879 if (failed(processModule(moduleOp)))
1880 return failure();
1881
1882 return updateModule(moduleOp);
1883}
1884
1885LogicalResult ModuleState::checkModule(FModuleOp moduleOp) {
1886 LLVM_DEBUG(llvm::dbgs() << "check: " << moduleOp.getModuleName() << "\n");
1887 if (failed(checkModulePorts(moduleOp)))
1888 return failure();
1889
1890 if (failed(checkModuleDomainPortDrivers(moduleOp)))
1891 return failure();
1892
1893 if (failed(checkModuleBody(moduleOp)))
1894 return failure();
1895
1896 return processModule(moduleOp);
1897}
1898
1899LogicalResult ModuleState::checkModule(FExtModuleOp extModuleOp) {
1900 LLVM_DEBUG(llvm::dbgs() << "check: " << extModuleOp.getModuleName() << "\n");
1901 return checkModulePorts(extModuleOp);
1902}
1903
1904LogicalResult ModuleState::checkAndInferModule(FModuleOp moduleOp) {
1905 LLVM_DEBUG(llvm::dbgs() << "check/infer: " << moduleOp.getModuleName()
1906 << "\n");
1907
1908 if (failed(checkModulePorts(moduleOp)))
1909 return failure();
1910
1911 if (failed(processModule(moduleOp)))
1912 return failure();
1913
1914 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1915 return failure();
1916
1917 return updateModuleBody(moduleOp);
1918}
1919
1920//===---------------------------------------------------------------------------
1921// Domain Stripping.
1922//===---------------------------------------------------------------------------
1923
1924static LogicalResult stripModule(FModuleLike op) {
1925 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1926 [=](Operation *op) -> WalkResult {
1927 return TypeSwitch<Operation *, WalkResult>(op)
1928 .Case<FModuleLike>([](FModuleLike op) {
1929 auto n = op.getNumPorts();
1930 BitVector erasures(n);
1931 for (size_t i = 0; i < n; ++i)
1932 if (isa<DomainType>(op.getPortType(i)))
1933 erasures.set(i);
1934 op.erasePorts(erasures);
1935 return WalkResult::advance();
1936 })
1937 .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
1938 [](Operation *op) {
1939 op->erase();
1940 return WalkResult::advance();
1941 })
1942 .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
1943 if (!op->use_empty()) {
1944 OpBuilder builder(op);
1945 op.replaceAllUsesWith(
1946 UnknownValueOp::create(builder, op.getLoc(), op.getType())
1947 .getResult());
1948 }
1949 op.erase();
1950 return WalkResult::advance();
1951 })
1952 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1953 op.replaceAllUsesWith(op.getInput());
1954 op.erase();
1955 return WalkResult::advance();
1956 })
1957 .Case<WireOp>([](WireOp op) {
1958 // Erase wires of DomainType
1959 if (isa<DomainType>(op.getType(0))) {
1960 op->erase();
1961 return WalkResult::advance();
1962 }
1963 // Erase domain operands from regular wires
1964 if (!op.getDomains().empty()) {
1965 op->eraseOperands(0, op.getNumOperands());
1966 }
1967 return WalkResult::advance();
1968 })
1969 .Case<FInstanceLike>([](auto op) {
1970 auto n = op.getNumPorts();
1971 BitVector erasures(n);
1972 for (size_t i = 0; i < n; ++i)
1973 if (isa<DomainType>(op->getResult(i).getType()))
1974 erasures.set(i);
1975 op.cloneWithErasedPortsAndReplaceUses(erasures);
1976 op.erase();
1977 return WalkResult::advance();
1978 })
1979 .Default([](Operation *op) {
1980 for (auto type :
1981 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1982 if (isa<DomainType>(type)) {
1983 op->emitOpError("cannot be stripped");
1984 return WalkResult::interrupt();
1985 }
1986 }
1987 return WalkResult::advance();
1988 });
1989 });
1990 return failure(result.wasInterrupted());
1991}
1992
1993static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit) {
1994 llvm::SmallVector<FModuleLike> modules;
1995 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1996 TypeSwitch<Operation *, void>(&op)
1997 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1998 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1999 }
2000 return failableParallelForEach(context, modules, stripModule);
2001}
2002
2003//===---------------------------------------------------------------------------
2004// InferDomainsPass: Top-level pass implementation.
2005//===---------------------------------------------------------------------------
2006
2007LogicalResult CircuitState::runOnModule(Operation *op) {
2008 assert(mode != InferDomainsMode::Strip);
2009 ModuleState state(*this);
2010 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
2011 if (mode == InferDomainsMode::Check)
2012 return state.checkModule(moduleOp);
2013
2014 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
2015 return state.inferModule(moduleOp);
2016
2017 return state.checkAndInferModule(moduleOp);
2018 }
2019
2020 if (auto extModuleOp = dyn_cast<FExtModuleOp>(op))
2021 return state.checkModule(extModuleOp);
2022
2023 return success();
2024}
2025
2026LogicalResult CircuitState::run() {
2027 DenseSet<Operation *> errored;
2028 instanceGraph.walkPostOrder([&](auto &node) {
2029 auto moduleOp = node.getModule();
2030 for (auto *inst : node) {
2031 if (errored.contains(inst->getTarget()->getModule())) {
2032 errored.insert(moduleOp);
2033 return;
2034 }
2035 }
2036 if (failed(runOnModule(node.getModule())))
2037 errored.insert(moduleOp);
2038 });
2039 return success(errored.empty());
2040}
2041
2042namespace {
2043struct InferDomainsPass
2044 : public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
2045 using Base::Base;
2046 void runOnOperation() override {
2048 auto circuit = getOperation();
2049
2050 if (mode == InferDomainsMode::Strip) {
2051 if (failed(stripCircuit(&getContext(), circuit)))
2052 signalPassFailure();
2053 return;
2054 }
2055
2056 auto &instanceGraph = getAnalysis<InstanceGraph>();
2057 auto &symbolTable = getAnalysis<SymbolTable>();
2058 auto &innerSymbolTableCollection =
2059 getAnalysis<InnerSymbolTableCollection>();
2060 circt::hw::InnerRefNamespace innerRefNamespace{symbolTable,
2061 innerSymbolTableCollection};
2062 CircuitState state(circuit, instanceGraph, innerRefNamespace, mode);
2063 if (failed(state.run()))
2064 signalPassFailure();
2065 }
2066};
2067} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static bool isHardware(Type type)
True if a value of the given type could be associated with a domain.
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static LogicalResult stripModule(FModuleLike op)
static Block * getBodyBlock(FModuleLike mod)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This graph tracks modules and where they are instantiated.
This class represents a collection of InnerSymbolTable's.
static StringRef toLongString(Direction direction)
Definition FIRRTLEnums.h:48
InferDomainsMode
The mode for the InferDomains pass.
Definition Passes.h:73
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:2405
This holds the name and type that describes the module's ports.
This class represents the namespace in which InnerRef's can be resolved.