CIRCT 23.0.0git
Loading...
Searching...
No Matches
InferDomains.cpp
Go to the documentation of this file.
1//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// InferDomains implements FIRRTL domain inference and checking. This pass is
10// a bottom-up transform acting on modules. For each moduleOp, we ensure there
11// are no domain crossings, and we make explicit the domain associations of
12// ports.
13//
14// This pass does not require that ExpandWhens has run, but it should have run.
15// If ExpandWhens has not been run, then duplicate connections will influence
16// domain inference and this can result in errors.
17//
18//===----------------------------------------------------------------------===//
19
24#include "circt/Support/Debug.h"
26#include "mlir/IR/AsmState.h"
27#include "mlir/IR/Iterators.h"
28#include "mlir/IR/Threading.h"
29#include "llvm/ADT/DenseMap.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/SmallVector.h"
32#include "llvm/ADT/TinyPtrVector.h"
33
34#define DEBUG_TYPE "firrtl-infer-domains"
35
36namespace circt {
37namespace firrtl {
38#define GEN_PASS_DEF_INFERDOMAINS
39#include "circt/Dialect/FIRRTL/Passes.h.inc"
40} // namespace firrtl
41} // namespace circt
42
43using namespace circt;
44using namespace firrtl;
45
48using llvm::concat;
49using mlir::AsmState;
50using mlir::InFlightDiagnostic;
51using mlir::ReverseIterator;
52
53namespace {
54struct VariableTerm;
55} // namespace
56
57//====--------------------------------------------------------------------------
58// Helpers.
59//====--------------------------------------------------------------------------
60
61using DomainValue = mlir::TypedValue<DomainType>;
62
63using PortInsertions = SmallVector<std::pair<unsigned, PortInfo>>;
64
65/// From a domain info attribute, get the row of associated domains for a
66/// hardware value at index i.
67static auto getPortDomainAssociation(ArrayAttr info, size_t i) {
68 if (info.empty())
69 return info.getAsRange<IntegerAttr>();
70 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
71}
72
73/// Return true if the value is a port on the module.
74static bool isPort(BlockArgument arg) {
75 return isa<FModuleOp>(arg.getOwner()->getParentOp());
76}
77
78/// Return true if the value is a port on the module.
79static bool isPort(Value value) {
80 auto arg = dyn_cast<BlockArgument>(value);
81 if (!arg)
82 return false;
83 return isPort(arg);
84}
85
86/// Returns true if the value is driven by a connect op.
87static bool isDriven(DomainValue port) {
88 for (auto *user : port.getUsers())
89 if (auto connect = dyn_cast<FConnectLike>(user))
90 if (connect.getDest() == port)
91 return true;
92 return false;
93}
94
95/// True if a value of the given type could be associated with a domain.
96static bool isHardware(Type type) {
97 return type_isa<FIRRTLBaseType, RefType>(type);
98}
99
100/// True if the given value could be association with a domain.
101static bool isHardware(Value value) {
102 if (!isHardware(value.getType()))
103 return false;
104
105 if (auto *op = value.getDefiningOp())
106 if (op->hasTrait<OpTrait::ConstantLike>())
107 return false;
108
109 return true;
110}
111
112//====--------------------------------------------------------------------------
113// Global State.
114//====--------------------------------------------------------------------------
115
namespace {
/// Identifies a domain type by its position in the circuit's declaration
/// order: the first declared domain has index 0, the next index 1, and so on.
///
/// Domain associations for hardware values are represented as a row with one
/// entry per declared domain type; the entries of a row are ordered by this
/// index.
struct DomainTypeID {
  /// Zero-based declaration position of the domain type.
  size_t index;
};
} // namespace
125
126/// Information about the changes made to the interface of a moduleOp, which can
127/// be replayed onto an instance.
128namespace {
129struct ModuleUpdateInfo {
130 /// The updated domain information for a moduleOp.
131 ArrayAttr portDomainInfo;
132 /// The domain ports which have been inserted into a moduleOp.
133 PortInsertions portInsertions;
134};
135} // namespace
136
137namespace {
138struct CircuitState {
139 CircuitState(CircuitOp circuit, InstanceGraph &instanceGraph,
140 InnerRefNamespace &innerRefNamespace, InferDomainsMode mode)
141 : circuit(circuit), instanceGraph(instanceGraph),
142 innerRefNamespace(innerRefNamespace), mode(mode) {
143 processCircuit(circuit);
144 }
145
146 LogicalResult run();
147
148 ArrayRef<DomainOp> getDomains() const { return domainTable; }
149 size_t getNumDomains() const { return domainTable.size(); }
150 DomainOp getDomain(DomainTypeID id) const { return domainTable[id.index]; }
151 DomainTypeID getDomainTypeID(Type type) { return typeIDTable[type]; }
152
153 void dirty() { asmState = nullptr; }
154 AsmState &getAsmState() {
155 if (!asmState) {
156 asmState = std::make_unique<AsmState>(
157 circuit, mlir::OpPrintingFlags().assumeVerified());
158 }
159 return *asmState;
160 }
161
162 size_t getVariableID(VariableTerm *term) {
163 return variableIDTable.insert({term, variableIDTable.size() + 1})
164 .first->second;
165 }
166
167 DenseMap<StringAttr, ModuleUpdateInfo> &getModuleUpdateTable() {
168 return moduleUpdateTable;
169 }
170
171 InnerRefNamespace &getInnerRefNamespace() { return innerRefNamespace; }
172
173 DenseSet<Value> inserted;
174
175private:
176 LogicalResult runOnModule(Operation *moduleOp);
177
178 void processDomain(DomainOp op) {
179 auto index = domainTable.size();
180 auto domainType = DomainType::getFromDomainOp(op);
181 domainTable.push_back(op);
182 typeIDTable.insert({domainType, {index}});
183 }
184
185 void processCircuit(CircuitOp circuit) {
186 for (auto decl : circuit.getOps<DomainOp>())
187 processDomain(decl);
188 }
189
190 CircuitOp circuit;
191 InstanceGraph &instanceGraph;
192 InnerRefNamespace &innerRefNamespace;
193 InferDomainsMode mode;
194 SmallVector<DomainOp> domainTable;
195 DenseMap<Type, DomainTypeID> typeIDTable;
196 DenseMap<VariableTerm *, size_t> variableIDTable;
197 std::unique_ptr<AsmState> asmState;
198 DenseMap<StringAttr, ModuleUpdateInfo> moduleUpdateTable;
199};
200} // namespace
201
202//====--------------------------------------------------------------------------
203// Terms: Syntax for unifying domain and domain-rows.
204//====--------------------------------------------------------------------------
205
namespace {
/// Discriminator for the kinds of term handled by the unification engine.
enum class TermKind {
  /// An unknown that may later be bound to another term.
  Variable,
  /// A concrete domain value defined in the IR.
  Value,
  /// A row of terms.
  Row,
};
} // namespace
214
215/// A term in the unification engine.
216namespace {
217struct Term {
218 constexpr Term(TermKind kind) : kind(kind) {}
219 TermKind kind;
220};
221} // namespace
222
223/// Helper to define a term kind.
224namespace {
225template <TermKind K>
226struct TermBase : Term {
227 static bool classof(const Term *term) { return term->kind == K; }
228 TermBase() : Term(K) {}
229};
230} // namespace
231
232/// An unknown value.
233namespace {
234struct VariableTerm : public TermBase<TermKind::Variable> {
235 VariableTerm() : leader(nullptr) {}
236 VariableTerm(Term *leader) : leader(leader) {}
237 Term *leader;
238};
239} // namespace
240
241/// A concrete value defined in the IR.
242namespace {
243struct ValueTerm : public TermBase<TermKind::Value> {
244 ValueTerm(DomainValue value) : value(value) {}
245 DomainValue value;
246};
247} // namespace
248
249/// A row of domains.
250namespace {
251struct RowTerm : public TermBase<TermKind::Row> {
252 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
253 ArrayRef<Term *> elements;
254};
255} // namespace
256
257//====--------------------------------------------------------------------------
258// Module processing: solve for the domain associations of hardware.
259//====--------------------------------------------------------------------------
260
261/// A map from unsolved variables to a port index, where that port has not yet
262/// been created. Eventually we will have an input domain at the port index,
263/// which will be the solution to the recorded variable.
264using PendingSolutions = DenseMap<VariableTerm *, unsigned>;
265
266/// A map from local domains to an aliasing port index, where that port has not
267/// yet been created. Eventually we will be exporting the domain value at the
268/// port index.
270
271namespace {
272struct PendingUpdates {
273 PortInsertions insertions;
274 PendingSolutions solutions;
275 PendingExports exports;
276};
277} // namespace
278
279/// A map from domain IR values defined internal to the moduleOp, to ports that
280/// alias that domain. These ports make the domain useable as associations of
281/// ports, and we say these are exporting ports.
282using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
283
284namespace {
/// Per-module solver state.
///
/// Holds a reference to the circuit-wide state, a table mapping values to
/// their terms (`termTable`), a table mapping values to their associated
/// terms (`associationTable`), and the bump allocator used for term storage.
///
/// NOTE(review): in this listing the declarations of solveVarWithAnonDomain
/// and getDomainInScope have lost the line carrying their return type;
/// confirm against the original file.
285class ModuleState {
286public:
287 explicit ModuleState(CircuitState &globals) : globals(globals) {}
288
 // Accessors that forward to the circuit-wide state.
289 ArrayRef<DomainOp> getDomains() { return globals.getDomains(); }
290 size_t getNumDomains() { return globals.getNumDomains(); }
291 DomainOp getDomain(DomainTypeID id) { return globals.getDomain(id); }
292 DomainTypeID getDomainTypeID(Type type) {
293 return globals.getDomainTypeID(type);
294 }
 // Domain type id of port `i` of a module, looked up by the port's type.
295 DomainTypeID getDomainTypeID(FModuleLike module, size_t i) {
296 return globals.getDomainTypeID(module.getPortType(i));
297 }
 // Domain type id of result `i` of an instance, looked up by result type.
298 DomainTypeID getDomainTypeID(FInstanceLike op, size_t i) const {
299 return globals.getDomainTypeID(op->getResult(i).getType());
300 }
301 DomainTypeID getDomainTypeID(DomainValue value) const {
302 return globals.getDomainTypeID(value.getType());
303 }
304 auto &getModuleUpdateTable() { return globals.getModuleUpdateTable(); }
305
 // Access to, and invalidation of, the shared AsmState used for printing.
306 mlir::AsmState &getAsmState() { return globals.getAsmState(); }
307 void dirty() { globals.dirty(); }
308
 // Printing of operations, values, and terms into a stream or diagnostic.
 // The one-argument overloads return small wrapper objects that are printed
 // through operator<<.
309 template <typename T>
310 void render(Operation *op, T &out);
311 template <typename T>
312 void render(Value value, T &out);
313 template <typename T>
314 void renderLong(Value value, T &out);
315 template <typename T>
316 void render(Term *term, T &out);
317 template <typename T>
318 struct Render;
319 template <typename T>
320 Render<T> render(T &&subject);
321 struct RenderLong;
322 RenderLong renderLong(Value value);
323
 // Unification engine: find() resolves a term to its representative,
 // unify() reports failure as a LogicalResult, solve() asserts success.
324 Term *find(Term *x);
325 LogicalResult unify(Term *lhs, Term *rhs);
326 LogicalResult unify(VariableTerm *x, Term *y);
327 LogicalResult unify(ValueTerm *xv, Term *y);
328 LogicalResult unify(RowTerm *lhsRow, Term *rhs);
329 void solve(Term *lhs, Term *rhs);
330
 // Term allocation; all storage comes from `allocator`.
331 [[nodiscard]] RowTerm *allocRow(size_t size);
332 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements);
333 [[nodiscard]] VariableTerm *allocVar();
334 [[nodiscard]] ValueTerm *allocVal(DomainValue value);
335 template <typename T, typename... Args>
336 T *alloc(Args &&...args);
337 ArrayRef<Term *> allocArray(ArrayRef<Term *> elements);
338
 // Lookup and update of `termTable` (terms for domain values).
339 DomainValue getOptUnderlyingDomain(DomainValue value);
340 Term *getOptTermForDomain(DomainValue value);
341 Term *getTermForDomain(DomainValue value);
342 void setTermForDomain(DomainValue value, Term *term);
343
 // Lookup and update of `associationTable` (terms for hardware values).
344 Term *getOptDomainAssociation(Value value);
345 Term *getDomainAssociation(Value value);
346 void setDomainAssociation(Value value, Term *term);
347
348 void processDomainDefinition(DomainValue domain);
349 RowTerm *getDomainAssociationAsRow(Value value);
350
 // Diagnostic helpers: note* attach notes to an in-flight diagnostic,
 // emit* create a new error.
351 void noteLocation(InFlightDiagnostic &diag, Operation *op);
352 void noteDomain(InFlightDiagnostic &diag, DomainValue domain);
353 void noteDomainSource(InFlightDiagnostic &diag, DomainValue domain);
354 void noteDomainSource(InFlightDiagnostic &diag, Term *term);
355 void emitDomainCrossingError(Operation *op, Value lhs, Term *lhsTerm,
356 Value rhs, Term *rhsTerm);
357 template <typename T>
358 void emitDuplicatePortDomainError(T op, size_t i, DomainTypeID domainTypeID,
359 IntegerAttr domainPortIndexAttr1,
360 IntegerAttr domainPortIndexAttr2);
361 template <typename T>
362 void emitDomainPortInferenceError(T op, size_t i);
363 template <typename T>
364 void emitAmbiguousPortDomainAssociation(
365 T op, const llvm::TinyPtrVector<DomainValue> &exports,
366 DomainTypeID typeID, size_t i);
367 template <typename T>
368 void emitMissingPortDomainAssociationError(T op, DomainTypeID typeID,
369 size_t i);
370
 // Unify the domain associations of two values, of a range of values, or of
 // all operands and results of an operation.
371 LogicalResult unifyAssociations(Operation *op, Value lhs, Value rhs);
372 template <typename T>
373 LogicalResult unifyAssociations(Operation *op, T &&range);
374 LogicalResult unifyAssociations(Operation *op);
375
 // Processing of module ports, instance ports, and individual operations.
 // Most of these are defined outside this excerpt; see their definitions.
376 LogicalResult processModulePorts(FModuleOp moduleOp);
377 template <typename T>
378 LogicalResult processInstancePorts(T op);
379 FInstanceLike fixInstancePorts(FInstanceLike op,
380 const ModuleUpdateInfo &update);
381 LogicalResult processOp(FInstanceLike op);
382 LogicalResult processOp(UnsafeDomainCastOp op);
383 LogicalResult processOp(DomainDefineOp op);
384 LogicalResult processOp(WireOp op);
385 LogicalResult processOp(RWProbeOp op);
386 LogicalResult processOp(Operation *op);
387 LogicalResult processModuleBody(FModuleOp moduleOp);
388 LogicalResult processModule(FModuleOp moduleOp);
389
 // Computation and application of interface updates (defined outside this
 // excerpt).
390 ExportTable initializeExportTable(FModuleOp moduleOp);
391 void ensureSolved(Namespace &ns, DomainTypeID typeID, size_t ip,
392 LocationAttr loc, VariableTerm *var,
393 PendingUpdates &pending);
394 void ensureExported(Namespace &ns, const ExportTable &exports,
395 DomainTypeID typeID, size_t ip, LocationAttr loc,
396 ValueTerm *val, PendingUpdates &pending);
397 void getUpdatesForDomainAssociationOfPort(Namespace &ns,
398 PendingUpdates &pending,
399 DomainTypeID typeID, size_t ip,
400 LocationAttr loc, Term *term,
401 const ExportTable &exports);
402 void getUpdatesForDomainAssociationOfPort(Namespace &ns,
403 const ExportTable &exports,
404 size_t ip, LocationAttr loc,
405 RowTerm *row,
406 PendingUpdates &pending);
407 void getUpdatesForModulePorts(FModuleOp moduleOp, const ExportTable &exports,
408 Namespace &ns, PendingUpdates &pending);
409 void getUpdatesForModule(FModuleOp moduleOp, const ExportTable &exports,
410 PendingUpdates &pending);
411 void applyUpdatesToModule(FModuleOp moduleOp, ExportTable &exports,
412 const PendingUpdates &pending);
413 SmallVector<Attribute> copyPortDomainAssociations(FModuleOp moduleOp,
414 ArrayAttr moduleDomainInfo,
415 size_t portIndex);
416 LogicalResult driveModuleOutputDomainPorts(FModuleOp moduleOp);
417 LogicalResult updateModuleDomainInfo(FModuleOp moduleOp,
418 const ExportTable &exportTable,
419 ArrayAttr &result);
 // NOTE(review): the return-type line of the next declaration is missing
 // from this listing.
421 solveVarWithAnonDomain(OpBuilder &builder,
422 DenseMap<DomainValue, DomainValue> &domainsInScope,
423 Operation *user, DomainType type, VariableTerm *var);
 // NOTE(review): the return-type line of the next declaration is missing
 // from this listing.
425 getDomainInScope(OpBuilder &builder,
426 DenseMap<DomainValue, DomainValue> &domainsInScope,
427 DomainValue domain);
428 LogicalResult
429 updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
430 FInstanceLike op);
431 LogicalResult updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
432 WireOp wireOp);
433 LogicalResult updateModuleBody(FModuleOp moduleOp);
434 LogicalResult updateModule(FModuleOp moduleOp);
435
 // Checking entry points (defined outside this excerpt).
436 LogicalResult checkModulePorts(FModuleLike moduleOp);
437 LogicalResult checkModuleDomainPortDrivers(FModuleOp moduleOp);
438 LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op);
439 LogicalResult checkModuleBody(FModuleOp moduleOp);
440
441 LogicalResult inferModule(FModuleOp moduleOp);
442 LogicalResult checkModule(FModuleOp moduleOp);
443 LogicalResult checkModule(FExtModuleOp extModuleOp);
444 LogicalResult checkAndInferModule(FModuleOp moduleOp);
445
446private:
 // Circuit-wide state shared by all modules.
447 CircuitState &globals;
 // Term recorded for each domain value.
448 DenseMap<Value, Term *> termTable;
 // Term recorded as the domain association of each hardware value.
449 DenseMap<Value, Term *> associationTable;
 // Backing storage for every term and row array allocated by this state.
450 llvm::BumpPtrAllocator allocator;
451};
452} // namespace
453
454template <typename T>
455void ModuleState::render(Operation *op, T &out) {
456 op->print(out, getAsmState());
457}
458
459template <typename T>
460void ModuleState::render(Value value, T &out) {
461 if (!value) {
462 out << "null";
463 return;
464 }
465
466 auto [name, _] = getFieldName(value);
467 if (name.empty()) {
468 llvm::raw_string_ostream os(name);
469 value.printAsOperand(os, globals.getAsmState());
470 }
471 out << name;
472}
473
/// Print a longer description of `value` to `out`: a prefix that depends on
/// where the value comes from, followed by the short rendering of the value.
///
/// For a block argument whose block is owned by an FModuleLike, the prefix
/// ends in " module port "; for a result of an FInstanceLike, it ends in
/// " instance port ". All other values get no prefix.
///
/// NOTE(review): in this listing the line that opens each statement consuming
/// the port direction is missing; presumably the direction is printed before
/// the port kind. Confirm against the original file.
474template <typename T>
475void ModuleState::renderLong(Value value, T &out) {
476 if (auto arg = dyn_cast<BlockArgument>(value)) {
 // The parent op may be absent, hence the null-tolerant cast.
477 if (auto moduleOp = llvm::dyn_cast_if_present<FModuleLike>(
478 arg.getOwner()->getParentOp())) {
480 moduleOp.getPortDirection(arg.getArgNumber()));
481 out << " module port ";
482 }
483 } else if (auto result = dyn_cast<OpResult>(value)) {
484 auto *op = result.getOwner();
485 if (auto inst = dyn_cast<FInstanceLike>(op)) {
487 inst.getPortDirection(result.getResultNumber()));
488 out << " instance port ";
489 }
490 }
491
 // Always finish with the short name of the value.
492 render(value, out);
493}
494
495template <typename T>
496// NOLINTNEXTLINE(misc-no-recursion)
497void ModuleState::render(Term *term, T &out) {
498 if (!term) {
499 out << "null";
500 return;
501 }
502 term = find(term);
503 if (auto *var = dyn_cast<VariableTerm>(term)) {
504 out << "?" << globals.getVariableID(var);
505 return;
506 }
507 if (auto *val = dyn_cast<ValueTerm>(term)) {
508 auto value = val->value;
509 render(value, out);
510 return;
511 }
512 if (auto *row = dyn_cast<RowTerm>(term)) {
513 out << "[";
514 llvm::interleaveComma(
515 llvm::seq(size_t(0), getNumDomains()), out, [&](auto i) {
516 render(row->elements[i], out);
517 out << " : " << getDomain(DomainTypeID{i}).getSymName();
518 });
519 out << "]";
520 return;
521 }
522 out << "unknown";
523}
524
/// Wrapper returned by the one-argument ModuleState::render: it pairs the
/// module state with a subject so the subject can be printed via operator<<.
///
/// NOTE(review): the member holding the subject is missing from this listing
/// (other code reads it as `subject`); confirm against the original file.
525template <typename T>
526struct ModuleState::Render {
 // The state whose render() overloads do the printing.
527 ModuleState *state;
529};
530
531template <typename T>
532ModuleState::Render<T> ModuleState::render(T &&subject) {
533 return Render<T>{this, std::forward<T>(subject)};
534}
535
536template <typename T>
537static llvm::raw_ostream &operator<<(llvm::raw_ostream &out,
538 ModuleState::Render<T> r) {
539 r.state->render(r.subject, out);
540 return out;
541}
542
544 ModuleState *state;
545 Value value;
546};
547
548ModuleState::RenderLong ModuleState::renderLong(Value value) {
549 return RenderLong{this, value};
550}
551
552static Diagnostic &operator<<(Diagnostic &diag, ModuleState::RenderLong r) {
553 r.state->renderLong(r.value, diag);
554 return diag;
555}
556
557// NOLINTNEXTLINE(misc-no-recursion)
558Term *ModuleState::find(Term *x) {
559 if (!x)
560 return nullptr;
561
562 if (auto *var = dyn_cast<VariableTerm>(x)) {
563 if (var->leader == nullptr)
564 return var;
565
566 auto *leader = find(var->leader);
567 if (leader != var->leader)
568 var->leader = leader;
569 return leader;
570 }
571
572 return x;
573}
574
575LogicalResult ModuleState::unify(VariableTerm *x, Term *y) {
576 assert(!x->leader);
577 x->leader = y;
578 return success();
579}
580
581LogicalResult ModuleState::unify(ValueTerm *xv, Term *y) {
582 if (auto *yv = dyn_cast<VariableTerm>(y)) {
583 yv->leader = xv;
584 return success();
585 }
586
587 if (auto *yv = dyn_cast<ValueTerm>(y))
588 return success(xv == yv);
589
590 return failure();
591}
592
593// NOLINTNEXTLINE(misc-no-recursion)
594LogicalResult ModuleState::unify(RowTerm *lhsRow, Term *rhs) {
595 if (auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
596 rhsVar->leader = lhsRow;
597 return success();
598 }
599 if (auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
600 for (auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
601 if (failed(unify(x, y)))
602 return failure();
603 return success();
604 }
605 return failure();
606}
607
608// NOLINTNEXTLINE(misc-no-recursion)
609LogicalResult ModuleState::unify(Term *lhs, Term *rhs) {
610 if (!lhs || !rhs)
611 return success();
612 lhs = find(lhs);
613 rhs = find(rhs);
614 if (lhs == rhs)
615 return success();
616
617 LLVM_DEBUG(llvm::dbgs().indent(6)
618 << "unify " << render(lhs) << " = " << render(rhs) << "\n");
619
620 if (auto *lhsVar = dyn_cast<VariableTerm>(lhs))
621 return unify(lhsVar, rhs);
622 if (auto *lhsVal = dyn_cast<ValueTerm>(lhs))
623 return unify(lhsVal, rhs);
624 if (auto *lhsRow = dyn_cast<RowTerm>(lhs))
625 return unify(lhsRow, rhs);
626 return failure();
627}
628
629void ModuleState::solve(Term *lhs, Term *rhs) {
630 [[maybe_unused]] auto result = unify(lhs, rhs);
631 assert(result.succeeded());
632}
633
634RowTerm *ModuleState::allocRow(size_t size) {
635 SmallVector<Term *> elements;
636 elements.resize(size);
637 return allocRow(elements);
638}
639
640RowTerm *ModuleState::allocRow(ArrayRef<Term *> elements) {
641 auto ds = allocArray(elements);
642 return alloc<RowTerm>(ds);
643}
644
645VariableTerm *ModuleState::allocVar() { return alloc<VariableTerm>(); }
646
647ValueTerm *ModuleState::allocVal(DomainValue value) {
648 return alloc<ValueTerm>(value);
649}
650
651template <typename T, typename... Args>
652T *ModuleState::alloc(Args &&...args) {
653 static_assert(std::is_base_of_v<Term, T>, "T must be a term");
654 return new (allocator) T(std::forward<Args>(args)...);
655}
656
657ArrayRef<Term *> ModuleState::allocArray(ArrayRef<Term *> elements) {
658 auto size = elements.size();
659 if (size == 0)
660 return {};
661
662 auto *result = allocator.Allocate<Term *>(size);
663 llvm::uninitialized_copy(elements, result);
664 for (size_t i = 0; i < size; ++i)
665 if (!result[i])
666 result[i] = alloc<VariableTerm>();
667
668 return ArrayRef(result, size);
669}
670
671DomainValue ModuleState::getOptUnderlyingDomain(DomainValue value) {
672 auto *term = getOptTermForDomain(value);
673 if (auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
674 return val->value;
675 return nullptr;
676}
677
678Term *ModuleState::getOptTermForDomain(DomainValue value) {
679 assert(isa<DomainType>(value.getType()));
680 auto it = termTable.find(value);
681 if (it == termTable.end())
682 return nullptr;
683 return find(it->second);
684}
685
686Term *ModuleState::getTermForDomain(DomainValue value) {
687 assert(isa<DomainType>(value.getType()));
688 if (auto *term = getOptTermForDomain(value))
689 return term;
690 auto *term = allocVar();
691 setTermForDomain(value, term);
692 return term;
693}
694
695void ModuleState::setTermForDomain(DomainValue value, Term *term) {
696 assert(term);
697 assert(!termTable.contains(value));
698 termTable.insert({value, term});
699 LLVM_DEBUG(llvm::dbgs().indent(6)
700 << "set " << render(value) << " := " << render(term) << "\n");
701}
702
703Term *ModuleState::getOptDomainAssociation(Value value) {
704 assert(isHardware(value));
705 auto it = associationTable.find(value);
706 if (it == associationTable.end())
707 return nullptr;
708 return find(it->second);
709}
710
711Term *ModuleState::getDomainAssociation(Value value) {
712 auto *term = getOptDomainAssociation(value);
713 assert(term);
714 return term;
715}
716
717void ModuleState::setDomainAssociation(Value value, Term *term) {
718 assert(isHardware(value));
719 assert(term);
720 term = find(term);
721 associationTable.insert({value, term});
722 LLVM_DEBUG({
723 llvm::dbgs().indent(6) << "set domains(" << render(value)
724 << ") := " << render(term) << "\n";
725 });
726}
727
728void ModuleState::processDomainDefinition(DomainValue domain) {
729 assert(isa<DomainType>(domain.getType()));
730 auto *newTerm = allocVal(domain);
731 auto *oldTerm = getOptTermForDomain(domain);
732 if (!oldTerm) {
733 setTermForDomain(domain, newTerm);
734 return;
735 }
736
737 [[maybe_unused]] auto result = unify(oldTerm, newTerm);
738 assert(result.succeeded());
739}
740
741RowTerm *ModuleState::getDomainAssociationAsRow(Value value) {
742 assert(isHardware(value));
743 auto *term = getOptDomainAssociation(value);
744
745 // If the term is unknown, allocate a fresh row and set the association.
746 if (!term) {
747 auto *row = allocRow(getNumDomains());
748 setDomainAssociation(value, row);
749 return row;
750 }
751
752 // If the term is already a row, return it.
753 if (auto *row = dyn_cast<RowTerm>(term))
754 return row;
755
756 // Otherwise, unify the term with a fresh row of domains.
757 if (auto *var = dyn_cast<VariableTerm>(term)) {
758 auto *row = allocRow(getNumDomains());
759 solve(var, row);
760 return row;
761 }
762
763 assert(false && "unhandled term type");
764 return nullptr;
765}
766
767void ModuleState::noteLocation(InFlightDiagnostic &diag, Operation *op) {
768 auto &note = diag.attachNote(op->getLoc());
769 if (auto mod = dyn_cast<FModuleOp>(op)) {
770 note << "in module " << mod.getModuleNameAttr();
771 return;
772 }
773 if (auto mod = dyn_cast<FExtModuleOp>(op)) {
774 note << "in extmodule " << mod.getModuleNameAttr();
775 return;
776 }
777 if (auto inst = dyn_cast<InstanceOp>(op)) {
778 note << "in instance " << inst.getInstanceNameAttr();
779 return;
780 }
781 if (auto inst = dyn_cast<InstanceChoiceOp>(op)) {
782 note << "in instance_choice " << inst.getNameAttr();
783 return;
784 }
785
786 note << "here";
787}
788
789void ModuleState::noteDomain(InFlightDiagnostic &diag, DomainValue domain) {
790 auto &note = diag.attachNote(domain.getLoc());
791 note << renderLong(domain);
792
793 if (globals.inserted.contains(domain)) {
794 note << " automatically inserted here";
795 return;
796 }
797
798 note << " declared here";
799}
800
801void ModuleState::noteDomainSource(InFlightDiagnostic &diag,
802 DomainValue domain) {
803 auto &irns = globals.getInnerRefNamespace();
804 SmallVector<FInstanceLike> stack;
805 llvm::SmallDenseSet<DomainValue> seen;
806
807 // This is reusing "domain" across iterations of the while loop.
808
809 auto chaseConnect = [&]() {
810 for (auto *user : domain.getUsers()) {
811 if (auto defineOp = dyn_cast<DomainDefineOp>(user)) {
812 if (defineOp.getDest() != domain)
813 continue;
814 auto src = defineOp.getSrc();
815 diag.attachNote(defineOp.getLoc())
816 << renderLong(domain) << " aliases " << renderLong(src);
817 domain = defineOp.getSrc();
818 return true;
819 }
820 }
821 return false;
822 };
823
824 auto chaseModulePort = [&]() {
825 auto arg = dyn_cast<BlockArgument>(domain);
826 if (!arg)
827 return false;
828
829 auto module =
830 llvm::dyn_cast_if_present<FModuleOp>(arg.getOwner()->getParentOp());
831 if (!module)
832 return false;
833
834 auto name = module.getModuleNameAttr();
835 while (!stack.empty()) {
836 auto instance = stack.back();
837 stack.pop_back();
838 auto referenced = instance.getReferencedModuleNamesAttr().getValue();
839 if (llvm::is_contained(referenced, name)) {
840 domain = cast<DomainValue>(instance->getResult(arg.getArgNumber()));
841 return true;
842 }
843 }
844 return false;
845 };
846
847 auto chaseInstancePort = [&]() {
848 auto result = dyn_cast<OpResult>(domain);
849 if (!result)
850 return false;
851
852 auto inst = dyn_cast<FInstanceLike>(result.getOwner());
853 if (!inst)
854 return false;
855
856 auto index = result.getResultNumber();
857 if (inst.getPortDirection(index) == Direction::In)
858 return false;
859
860 auto names = inst.getReferencedModuleNamesAttr().getAsRange<StringAttr>();
861 for (auto name : names) {
862 auto moduleLike = cast<FModuleLike>(irns.symTable.lookup(name));
863 if (auto moduleOp = dyn_cast<FModuleOp>(moduleLike.getOperation())) {
864 stack.push_back(inst);
865 domain = cast<DomainValue>(moduleOp.getArgument(index));
866 return true;
867 }
868 }
869 return false;
870 };
871
872 auto chaseUnderlying = [&]() {
873 if (auto *term = getOptTermForDomain(domain)) {
874 if (auto *val = dyn_cast<ValueTerm>(term)) {
875 if (domain != val->value) {
876 diag.attachNote(domain.getLoc())
877 << renderLong(domain) << " aliases " << renderLong(val->value);
878 domain = val->value;
879 return true;
880 }
881 }
882 }
883 return false;
884 };
885
886 while (true) {
887 auto [it, inserted] = seen.insert(domain);
888 if (!inserted)
889 return;
890
891 noteDomain(diag, domain);
892 chaseConnect() || chaseModulePort() || chaseInstancePort() ||
893 chaseUnderlying();
894 }
895}
896
897void ModuleState::noteDomainSource(InFlightDiagnostic &diag, Term *term) {
898 auto *val = dyn_cast<ValueTerm>(find(term));
899 if (!val)
900 return;
901
902 noteDomainSource(diag, val->value);
903}
904
905void ModuleState::emitDomainCrossingError(Operation *op, Value lhs,
906 Term *lhsTerm, Value rhs,
907 Term *rhsTerm) {
908 auto *lhsRow = cast<RowTerm>(lhsTerm);
909 auto *rhsRow = cast<RowTerm>(rhsTerm);
910 auto diag =
911 op->emitError("illegal domain crossing in operation between operands ");
912 render(lhs, diag);
913 diag << " and ";
914 render(rhs, diag);
915 auto &note1 = diag.attachNote(lhs.getLoc());
916 render(lhs, note1);
917 note1 << " has domains ";
918 render(lhsRow, note1);
919 auto &note2 = diag.attachNote(rhs.getLoc());
920 render(rhs, note2);
921 note2 << " has domains ";
922 render(rhsRow, note2);
923
924 for (size_t i = 0, e = getNumDomains(); i < e; ++i) {
925 auto *lhsDomain = find(lhsRow->elements[i]);
926 auto *rhsDomain = find(rhsRow->elements[i]);
927 if (lhsDomain == rhsDomain)
928 continue;
929
930 noteDomainSource(diag, lhsDomain);
931 noteDomainSource(diag, rhsDomain);
932 }
933}
934
935template <typename T>
936void ModuleState::emitDuplicatePortDomainError(
937 T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1,
938 IntegerAttr domainPortIndexAttr2) {
939 auto portName = op.getPortNameAttr(i);
940 auto portLoc = op.getPortLocation(i);
941 auto domainDecl = getDomain(domainTypeID);
942 auto domainName = domainDecl.getNameAttr();
943 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
944 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
945 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
946 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
947 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
948 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
949 auto diag = emitError(portLoc);
950 diag << "duplicate " << domainName << " association for port " << portName;
951 auto &note1 = diag.attachNote(domainPortLoc1);
952 note1 << "associated with " << domainName << " port " << domainPortName1;
953 auto &note2 = diag.attachNote(domainPortLoc2);
954 note2 << "associated with " << domainName << " port " << domainPortName2;
955 noteLocation(diag, op);
956}
957
958/// Emit an error when we fail to infer the concrete domain to drive to a
959/// domain port.
960template <typename T>
961void ModuleState::emitDomainPortInferenceError(T op, size_t i) {
962 auto name = op.getPortNameAttr(i);
963 auto diag = emitError(op->getLoc());
964 auto info = op.getDomainInfo();
965 diag << "unable to infer value for undriven domain port " << name;
966 for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
967 if (auto assocs = dyn_cast<ArrayAttr>(info[j])) {
968 for (auto assoc : assocs) {
969 if (i == cast<IntegerAttr>(assoc).getValue()) {
970 auto name = op.getPortNameAttr(j);
971 auto loc = op.getPortLocation(j);
972 diag.attachNote(loc) << "associated with hardware port " << name;
973 break;
974 }
975 }
976 }
977 }
978 noteLocation(diag, op);
979}
980
981template <typename T>
982void ModuleState::emitAmbiguousPortDomainAssociation(
983 T op, const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
984 size_t i) {
985 auto portName = op.getPortNameAttr(i);
986 auto portLoc = op.getPortLocation(i);
987 auto domainDecl = getDomain(typeID);
988 auto domainName = domainDecl.getNameAttr();
989 auto diag = emitError(portLoc) << "ambiguous " << domainName
990 << " association for port " << portName;
991 for (auto e : exports) {
992 auto arg = cast<BlockArgument>(e);
993 auto name = op.getPortNameAttr(arg.getArgNumber());
994 auto loc = op.getPortLocation(arg.getArgNumber());
995 diag.attachNote(loc) << "candidate association " << name;
996 }
997 noteLocation(diag, op);
998}
999
1000template <typename T>
1001void ModuleState::emitMissingPortDomainAssociationError(T op,
1002 DomainTypeID typeID,
1003 size_t i) {
1004 auto domainName = getDomain(typeID).getNameAttr();
1005 auto portName = op.getPortNameAttr(i);
1006 auto diag = emitError(op.getPortLocation(i))
1007 << "missing " << domainName << " association for port "
1008 << portName;
1009 noteLocation(diag, op);
1010}
1011
1012LogicalResult ModuleState::unifyAssociations(Operation *op, Value lhs,
1013 Value rhs) {
1014 if (!lhs || !rhs)
1015 return success();
1016
1017 if (lhs == rhs)
1018 return success();
1019
1020 LLVM_DEBUG({
1021 llvm::dbgs().indent(6) << "unify domains(" << render(lhs) << ") = domains("
1022 << render(rhs) << ")\n";
1023 });
1024
1025 auto *lhsTerm = getOptDomainAssociation(lhs);
1026 auto *rhsTerm = getOptDomainAssociation(rhs);
1027
1028 if (lhsTerm) {
1029 if (rhsTerm) {
1030 if (failed(unify(lhsTerm, rhsTerm))) {
1031 emitDomainCrossingError(op, lhs, lhsTerm, rhs, rhsTerm);
1032 return failure();
1033 }
1034 return success();
1035 }
1036 setDomainAssociation(rhs, lhsTerm);
1037 return success();
1038 }
1039
1040 if (rhsTerm) {
1041 setDomainAssociation(lhs, rhsTerm);
1042 return success();
1043 }
1044
1045 auto *var = allocVar();
1046 setDomainAssociation(lhs, var);
1047 setDomainAssociation(rhs, var);
1048 return success();
1049}
1050
1051template <typename T>
1052LogicalResult ModuleState::unifyAssociations(Operation *op, T &&range) {
1053 Value lhs;
1054 for (auto rhs : std::forward<T>(range)) {
1055 if (!isHardware(rhs))
1056 continue;
1057 if (failed(unifyAssociations(op, lhs, rhs)))
1058 return failure();
1059 lhs = rhs;
1060 }
1061
1062 return success();
1063}
1064
1065LogicalResult ModuleState::unifyAssociations(Operation *op) {
1066 return unifyAssociations(
1067 op, llvm::concat<Value>(op->getOperands(), op->getResults()));
1068}
1069
1070LogicalResult ModuleState::processModulePorts(FModuleOp moduleOp) {
1071 auto numDomains = getNumDomains();
1072 auto domainInfo = moduleOp.getDomainInfoAttr();
1073 auto numPorts = moduleOp.getNumPorts();
1074
1075 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1076 for (size_t i = 0; i < numPorts; ++i) {
1077 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1078 if (!port)
1079 continue;
1080
1081 LLVM_DEBUG(llvm::dbgs().indent(4)
1082 << "process port " << render(port) << "\n");
1083
1084 if (moduleOp.getPortDirection(i) == Direction::In)
1085 processDomainDefinition(port);
1086
1087 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1088 }
1089
1090 for (size_t i = 0; i < numPorts; ++i) {
1091 BlockArgument port = moduleOp.getArgument(i);
1092 if (!isHardware(port))
1093 continue;
1094
1095 LLVM_DEBUG(llvm::dbgs().indent(4)
1096 << "process port " << render(port) << "\n");
1097
1098 SmallVector<IntegerAttr> associations(numDomains);
1099 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1100 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1101 auto prevDomainPortIndex = associations[domainTypeID.index];
1102 if (prevDomainPortIndex) {
1103 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1104 prevDomainPortIndex, domainPortIndex);
1105 return failure();
1106 }
1107 associations[domainTypeID.index] = domainPortIndex;
1108 }
1109
1110 SmallVector<Term *> elements(numDomains);
1111 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1112 ++domainTypeIndex) {
1113 auto domainPortIndex = associations[domainTypeIndex];
1114 if (!domainPortIndex)
1115 continue;
1116 auto domainPortValue =
1117 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
1118 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1119 }
1120
1121 auto *domainAssociations = allocRow(elements);
1122 setDomainAssociation(port, domainAssociations);
1123 }
1124
1125 return success();
1126}
1127
1128template <typename T>
1129LogicalResult ModuleState::processInstancePorts(T op) {
1130 auto numDomains = getNumDomains();
1131 auto domainInfo = op.getDomainInfoAttr();
1132 auto numPorts = op.getNumPorts();
1133
1134 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1135 for (size_t i = 0; i < numPorts; ++i) {
1136 auto port = dyn_cast<DomainValue>(op->getResult(i));
1137 if (!port)
1138 continue;
1139
1140 if (op.getPortDirection(i) == Direction::Out)
1141 processDomainDefinition(port);
1142
1143 domainTypeIDTable[i] = getDomainTypeID(op, i);
1144 }
1145
1146 for (size_t i = 0; i < numPorts; ++i) {
1147 Value port = op->getResult(i);
1148 if (!isHardware(port))
1149 continue;
1150
1151 SmallVector<IntegerAttr> associations(numDomains);
1152 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1153 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1154 auto prevDomainPortIndex = associations[domainTypeID.index];
1155 if (prevDomainPortIndex) {
1156 emitDuplicatePortDomainError(op, i, domainTypeID, prevDomainPortIndex,
1157 domainPortIndex);
1158 return failure();
1159 }
1160 associations[domainTypeID.index] = domainPortIndex;
1161 }
1162
1163 SmallVector<Term *> elements(numDomains);
1164 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
1165 ++domainTypeIndex) {
1166 auto domainPortIndex = associations[domainTypeIndex];
1167 if (!domainPortIndex)
1168 continue;
1169 auto domainPortValue =
1170 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
1171 elements[domainTypeIndex] = getTermForDomain(domainPortValue);
1172 }
1173
1174 auto *domainAssociations = allocRow(elements);
1175 setDomainAssociation(port, domainAssociations);
1176 }
1177
1178 return success();
1179}
1180
1181FInstanceLike ModuleState::fixInstancePorts(FInstanceLike op,
1182 const ModuleUpdateInfo &update) {
1183 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
1184 clone.setDomainInfoAttr(update.portDomainInfo);
1185 op->erase();
1186 dirty();
1187 LLVM_DEBUG(llvm::dbgs().indent(6) << "fixup " << render(clone) << "\n");
1188 return clone;
1189}
1190
1191LogicalResult ModuleState::processOp(FInstanceLike op) {
1192 auto moduleName =
1193 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
1194 auto updateTable = getModuleUpdateTable();
1195 auto lookup = updateTable.find(moduleName);
1196 if (lookup != updateTable.end())
1197 op = fixInstancePorts(op, lookup->second);
1198 return processInstancePorts(op);
1199}
1200
1201LogicalResult ModuleState::processOp(UnsafeDomainCastOp op) {
1202 auto domains = op.getDomains();
1203 if (domains.empty())
1204 return unifyAssociations(op, op.getInput(), op.getResult());
1205
1206 auto input = op.getInput();
1207 RowTerm *inputRow = getDomainAssociationAsRow(input);
1208 SmallVector<Term *> elements(inputRow->elements);
1209 for (auto value : op.getDomains()) {
1210 auto domain = cast<DomainValue>(value);
1211 auto typeID = getDomainTypeID(domain);
1212 elements[typeID.index] = getTermForDomain(domain);
1213 }
1214
1215 auto *row = allocRow(elements);
1216 setDomainAssociation(op.getResult(), row);
1217 return success();
1218}
1219
1220LogicalResult ModuleState::processOp(DomainDefineOp op) {
1221 auto src = op.getSrc();
1222 auto dst = op.getDest();
1223
1224 auto *srcTerm = getTermForDomain(src);
1225 auto *dstTerm = getTermForDomain(dst);
1226 if (succeeded(unify(dstTerm, srcTerm)))
1227 return success();
1228
1229 auto diag =
1230 op->emitOpError()
1231 << "defines a domain value that was inferred to be a different domain '";
1232 render(dstTerm, diag);
1233 diag << "'";
1234
1235 return failure();
1236}
1237
1238LogicalResult ModuleState::processOp(WireOp op) {
1239 // If the wire has explicit domain operands, seed the domain table with them
1240 // as constraints. When this op is visited, connections have not yet been
1241 // processed (wire declarations precede their uses), so the existing row
1242 // contains only fresh variables that unify unconditionally. Any conflict
1243 // between an explicit wire domain and a connection's inferred domain is
1244 // caught later by the connection's own processOp.
1245 if (op.getDomains().empty())
1246 return unifyAssociations(op, op.getResults());
1247
1248 // Build a row with the explicitly-specified domain slots filled in and set
1249 // it as the association for this wire result.
1250 SmallVector<Term *> elements(getNumDomains());
1251 for (auto domain : op.getDomains()) {
1252 auto domainValue = cast<DomainValue>(domain);
1253 auto typeID = getDomainTypeID(domainValue);
1254 elements[typeID.index] = getTermForDomain(domainValue);
1255 }
1256
1257 auto *row = allocRow(elements);
1258 for (auto result : op.getResults())
1259 setDomainAssociation(result, row);
1260
1261 return success();
1262}
1263
1264LogicalResult ModuleState::processOp(RWProbeOp op) {
1265 auto target = globals.getInnerRefNamespace().lookup(op.getTarget());
1266
1267 if (target.isPort()) {
1268 auto targetOp = cast<FModuleOp>(target.getOp());
1269 auto targetValue = targetOp.getArgument(target.getPort());
1270 return unifyAssociations(op, targetValue, op.getResult());
1271 }
1272
1273 auto targetOp = cast<hw::InnerSymbolOpInterface>(target.getOp());
1274 auto targetValue = targetOp.getTargetResult();
1275 return unifyAssociations(op, targetValue, op.getResult());
1276}
1277
1278LogicalResult ModuleState::processOp(Operation *op) {
1279 LLVM_DEBUG(llvm::dbgs().indent(4) << "process " << render(op) << "\n");
1280 if (auto instance = dyn_cast<FInstanceLike>(op))
1281 return processOp(instance);
1282 if (auto wireOp = dyn_cast<WireOp>(op))
1283 return processOp(wireOp);
1284 if (auto cast = dyn_cast<UnsafeDomainCastOp>(op))
1285 return processOp(cast);
1286 if (auto def = dyn_cast<DomainDefineOp>(op))
1287 return processOp(def);
1288 if (auto probe = dyn_cast<RWProbeOp>(op))
1289 return processOp(probe);
1290 if (auto create = dyn_cast<DomainCreateOp>(op)) {
1291 processDomainDefinition(create);
1292 return success();
1293 }
1294 if (auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
1295 processDomainDefinition(createAnon);
1296 return success();
1297 }
1298
1299 return unifyAssociations(op);
1300}
1301
1302LogicalResult ModuleState::processModuleBody(FModuleOp moduleOp) {
1303 return failure(
1304 moduleOp.getBody()
1305 .walk([&](Operation *op) -> WalkResult { return processOp(op); })
1306 .wasInterrupted());
1307}
1308
1309LogicalResult ModuleState::processModule(FModuleOp moduleOp) {
1310 LLVM_DEBUG(llvm::dbgs().indent(2) << "processing:\n");
1311 if (failed(processModulePorts(moduleOp)))
1312 return failure();
1313 if (failed(processModuleBody(moduleOp)))
1314 return failure();
1315 return success();
1316}
1317
1318ExportTable ModuleState::initializeExportTable(FModuleOp moduleOp) {
1319 ExportTable exports;
1320 size_t numPorts = moduleOp.getNumPorts();
1321 for (size_t i = 0; i < numPorts; ++i) {
1322 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1323 if (!port)
1324 continue;
1325 auto value = getOptUnderlyingDomain(port);
1326 if (value)
1327 exports[value].push_back(port);
1328 }
1329
1330 LLVM_DEBUG({
1331 llvm::dbgs().indent(2) << "domain exports:\n";
1332 for (auto entry : exports) {
1333 llvm::dbgs().indent(4) << render(entry.first) << " exported as ";
1334 llvm::interleaveComma(entry.second, llvm::dbgs(),
1335 [&](auto e) { llvm::dbgs() << render(e); });
1336 llvm::dbgs() << "\n";
1337 }
1338 });
1339
1340 return exports;
1341}
1342
1343void ModuleState::ensureSolved(Namespace &ns, DomainTypeID typeID, size_t ip,
1344 LocationAttr loc, VariableTerm *var,
1345 PendingUpdates &pending) {
1346 if (pending.solutions.contains(var))
1347 return;
1348
1349 auto *context = loc.getContext();
1350 auto domainDecl = getDomain(typeID);
1351 auto domainName = domainDecl.getNameAttr();
1352
1353 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1354 auto portType = DomainType::getFromDomainOp(domainDecl);
1355 auto portDirection = Direction::In;
1356 auto portSym = StringAttr();
1357 auto portLoc = loc;
1358 auto portAnnos = std::nullopt;
1359 // Domain type ports have no associations (domain info is in the type).
1360 auto portDomainInfo = ArrayAttr::get(context, {});
1361 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1362 portAnnos, portDomainInfo);
1363
1364 pending.solutions[var] = pending.insertions.size() + ip;
1365 pending.insertions.push_back({ip, portInfo});
1366}
1367
1368void ModuleState::ensureExported(Namespace &ns, const ExportTable &exports,
1369 DomainTypeID typeID, size_t ip,
1370 LocationAttr loc, ValueTerm *val,
1371 PendingUpdates &pending) {
1372 auto value = val->value;
1373 assert(isa<DomainType>(value.getType()));
1374 if (isPort(value) || exports.contains(value) ||
1375 pending.exports.contains(value))
1376 return;
1377
1378 auto *context = loc.getContext();
1379
1380 auto domainDecl = getDomain(typeID);
1381 auto domainName = domainDecl.getNameAttr();
1382
1383 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1384 auto portType = DomainType::getFromDomainOp(domainDecl);
1385 auto portDirection = Direction::Out;
1386 auto portSym = StringAttr();
1387 auto portAnnos = std::nullopt;
1388 // Domain type ports have no associations (domain info is in the type).
1389 auto portDomainInfo = ArrayAttr::get(context, {});
1390 PortInfo portInfo(portName, portType, portDirection, portSym, loc, portAnnos,
1391 portDomainInfo);
1392 pending.exports[value] = pending.insertions.size() + ip;
1393 pending.insertions.push_back({ip, portInfo});
1394}
1395
1396void ModuleState::getUpdatesForDomainAssociationOfPort(
1397 Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip,
1398 LocationAttr loc, Term *term, const ExportTable &exports) {
1399 if (auto *var = dyn_cast<VariableTerm>(term)) {
1400 ensureSolved(ns, typeID, ip, loc, var, pending);
1401 return;
1402 }
1403 if (auto *val = dyn_cast<ValueTerm>(term)) {
1404 ensureExported(ns, exports, typeID, ip, loc, val, pending);
1405 return;
1406 }
1407 llvm_unreachable("invalid domain association");
1408}
1409
1410void ModuleState::getUpdatesForDomainAssociationOfPort(
1411 Namespace &ns, const ExportTable &exports, size_t ip, LocationAttr loc,
1412 RowTerm *row, PendingUpdates &pending) {
1413 for (auto [index, term] : llvm::enumerate(row->elements))
1414 getUpdatesForDomainAssociationOfPort(ns, pending, DomainTypeID{index}, ip,
1415 loc, find(term), exports);
1416}
1417
1418void ModuleState::getUpdatesForModulePorts(FModuleOp moduleOp,
1419 const ExportTable &exports,
1420 Namespace &ns,
1421 PendingUpdates &pending) {
1422 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1423 auto port = moduleOp.getArgument(i);
1424 if (!isHardware(port))
1425 continue;
1426
1427 getUpdatesForDomainAssociationOfPort(
1428 ns, exports, i, moduleOp.getPortLocation(i),
1429 getDomainAssociationAsRow(port), pending);
1430 }
1431}
1432
1433void ModuleState::getUpdatesForModule(FModuleOp moduleOp,
1434 const ExportTable &exports,
1435 PendingUpdates &pending) {
1436 Namespace ns;
1437 auto names = moduleOp.getPortNamesAttr();
1438 for (auto name : names.getAsRange<StringAttr>())
1439 ns.add(name);
1440 getUpdatesForModulePorts(moduleOp, exports, ns, pending);
1441}
1442
1443void ModuleState::applyUpdatesToModule(FModuleOp moduleOp, ExportTable &exports,
1444 const PendingUpdates &pending) {
1445 LLVM_DEBUG(llvm::dbgs().indent(2) << "applying updates:\n");
1446 // Put the domain ports in place.
1447 moduleOp.insertPorts(pending.insertions);
1448 dirty();
1449
1450 // Solve any variables and record them as "self-exporting".
1451 for (auto [var, portIndex] : pending.solutions) {
1452 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1453 auto *solution = allocVal(portValue);
1454 LLVM_DEBUG(llvm::dbgs().indent(4)
1455 << "new-input " << render(portValue) << "\n");
1456 solve(var, solution);
1457 exports[portValue].push_back(portValue);
1458 globals.inserted.insert(portValue);
1459 }
1460
1461 // Drive the output ports, and record the export.
1462 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1463 for (auto [domainValue, portIndex] : pending.exports) {
1464 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1465 builder.setInsertionPointAfterValue(domainValue);
1466 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1467 LLVM_DEBUG(llvm::dbgs().indent(4) << "new-output " << render(portValue)
1468 << " := " << render(domainValue) << "\n");
1469 exports[domainValue].push_back(portValue);
1470 globals.inserted.insert(portValue);
1471 setTermForDomain(portValue, allocVal(domainValue));
1472 }
1473}
1474
1475SmallVector<Attribute> ModuleState::copyPortDomainAssociations(
1476 FModuleOp moduleOp, ArrayAttr moduleDomainInfo, size_t portIndex) {
1477 SmallVector<Attribute> result(getNumDomains());
1478 auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex);
1479 for (auto domainPortIndexAttr : oldAssociations) {
1480 auto domainPortIndex = domainPortIndexAttr.getUInt();
1481 auto domainTypeID = getDomainTypeID(moduleOp, domainPortIndex);
1482 result[domainTypeID.index] = domainPortIndexAttr;
1483 };
1484 return result;
1485}
1486
1487LogicalResult ModuleState::driveModuleOutputDomainPorts(FModuleOp moduleOp) {
1488 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1489 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1490 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1491 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1492 isDriven(port))
1493 continue;
1494
1495 auto *term = getOptTermForDomain(port);
1496 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1497 if (!val) {
1498 emitDomainPortInferenceError(moduleOp, i);
1499 return failure();
1500 }
1501
1502 auto loc = port.getLoc();
1503 auto value = val->value;
1504 LLVM_DEBUG(llvm::dbgs().indent(4) << "connect " << render(port)
1505 << " := " << render(value) << "\n");
1506 DomainDefineOp::create(builder, loc, port, value);
1507 }
1508
1509 return success();
1510}
1511
1512LogicalResult ModuleState::updateModuleDomainInfo(
1513 FModuleOp moduleOp, const ExportTable &exportTable, ArrayAttr &result) {
1514 // At this point, all domain variables mentioned in ports have been
1515 // solved by generalizing the moduleOp (adding input domain ports). Now, we
1516 // have to form the new port domain information for the moduleOp by examining
1517 // the the associated domains of each port.
1518 auto *context = moduleOp.getContext();
1519 auto numDomains = getNumDomains();
1520 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1521 auto numPorts = moduleOp.getNumPorts();
1522 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1523
1524 for (size_t i = 0; i < numPorts; ++i) {
1525 auto port = moduleOp.getArgument(i);
1526 auto type = port.getType();
1527
1528 if (isa<DomainType>(type)) {
1529 // Domain type ports have no associations (domain info is in the type).
1530 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1531 continue;
1532 }
1533
1534 if (!isHardware(port)) {
1535 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1536 continue;
1537 }
1538
1539 auto associations =
1540 copyPortDomainAssociations(moduleOp, oldModuleDomainInfo, i);
1541 auto *row = cast<RowTerm>(getDomainAssociation(port));
1542 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1543 auto domainTypeID = DomainTypeID{domainIndex};
1544 if (associations[domainIndex])
1545 continue;
1546
1547 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1548 auto &exports = exportTable.at(domain);
1549 if (exports.empty()) {
1550 auto portName = moduleOp.getPortNameAttr(i);
1551 auto portLoc = moduleOp.getPortLocation(i);
1552 auto domainDecl = getDomain(domainTypeID);
1553 auto domainName = domainDecl.getNameAttr();
1554 auto diag = emitError(portLoc) << "private " << domainName
1555 << " association for port " << portName;
1556 diag.attachNote(domain.getLoc()) << "associated domain: " << domain;
1557 noteLocation(diag, moduleOp);
1558 return failure();
1559 }
1560
1561 if (exports.size() > 1) {
1562 emitAmbiguousPortDomainAssociation(moduleOp, exports, domainTypeID, i);
1563 return failure();
1564 }
1565
1566 auto argument = cast<BlockArgument>(exports[0]);
1567 auto domainPortIndex = argument.getArgNumber();
1568 associations[domainTypeID.index] =
1569 IntegerAttr::get(IntegerType::get(context, 32, IntegerType::Unsigned),
1570 domainPortIndex);
1571 }
1572
1573 newModuleDomainInfo[i] = ArrayAttr::get(context, associations);
1574 }
1575
1576 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1577 moduleOp.setDomainInfoAttr(result);
1578 return success();
1579}
1580
1581DomainValue ModuleState::solveVarWithAnonDomain(
1582 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1583 Operation *user, DomainType type, VariableTerm *var) {
1584 auto name = type.getName().getAttr();
1585 DomainValue anon =
1586 DomainCreateAnonOp::create(builder, user->getLoc(), type, name);
1587 dirty();
1588 LLVM_DEBUG(llvm::dbgs().indent(6) << "create anon " << render(anon) << "\n");
1589 solve(var, allocVal(anon));
1590 domainsInScope[anon] = anon;
1591 globals.inserted.insert(anon);
1592 return anon;
1593}
1594
1595DomainValue ModuleState::getDomainInScope(
1596 OpBuilder &builder, DenseMap<DomainValue, DomainValue> &domainsInScope,
1597 DomainValue domain) {
1598 auto &domainInScope = domainsInScope[domain];
1599 if (domainInScope)
1600 return domainInScope;
1601
1602 domainInScope = cast<DomainValue>(
1603 WireOp::create(builder, domain.getLoc(), domain.getType(),
1604 domain.getType().getName().getAttr())
1605 .getResult());
1606
1607 OpBuilder::InsertionGuard guard(builder);
1608 builder.setInsertionPointAfterValue(domain);
1609 DomainDefineOp::create(builder, domain.getLoc(), domainInScope, domain);
1610 dirty();
1611 LLVM_DEBUG(llvm::dbgs().indent(6) << "bounce wire " << render(domainInScope)
1612 << " := " << render(domain) << "\n");
1613 return domainInScope;
1614}
1615
1616LogicalResult
1617ModuleState::updateInstance(DenseMap<DomainValue, DomainValue> &domainsInScope,
1618 FInstanceLike op) {
1619 LLVM_DEBUG(llvm::dbgs().indent(4) << "update " << render(op) << "\n");
1620 OpBuilder builder(op.getContext());
1621 builder.setInsertionPointAfter(op);
1622 auto numPorts = op->getNumResults();
1623
1624 for (size_t i = 0; i < numPorts; ++i)
1625 if (auto port = dyn_cast<DomainValue>(op->getResult(i)))
1626 if (op.getPortDirection(i) == Direction::Out)
1627 domainsInScope[port] = port;
1628
1629 for (size_t i = 0; i < numPorts; ++i) {
1630 auto port = dyn_cast<DomainValue>(op->getResult(i));
1631 auto direction = op.getPortDirection(i);
1632 // If the port is an input domain, we may need to drive the input with
1633 // a value. If we don't know what value to drive to the port, drive an
1634 // anonymous domain.
1635 if (port && direction == Direction::In && !isDriven(port)) {
1636 auto loc = port.getLoc();
1637 auto *term = getTermForDomain(port);
1638 if (auto *var = dyn_cast<VariableTerm>(term)) {
1639 auto domain = solveVarWithAnonDomain(builder, domainsInScope, op,
1640 port.getType(), var);
1641 LLVM_DEBUG(llvm::dbgs().indent(6) << "connect " << render(port)
1642 << " := " << render(domain) << "\n");
1643 DomainDefineOp::create(builder, loc, port, domain);
1644 continue;
1645 }
1646 if (auto *val = dyn_cast<ValueTerm>(term)) {
1647 auto domain = getDomainInScope(builder, domainsInScope, val->value);
1648 LLVM_DEBUG(llvm::dbgs().indent(6) << "connect " << render(port)
1649 << " := " << render(domain) << "\n");
1650 DomainDefineOp::create(builder, loc, port, domain);
1651 continue;
1652 }
1653 llvm_unreachable("unhandled domain term type");
1654 }
1655 }
1656
1657 return success();
1658}
1659
1660LogicalResult
1661ModuleState::updateWire(DenseMap<DomainValue, DomainValue> &domainsInScope,
1662 WireOp wireOp) {
1663 auto result = wireOp.getResult();
1664
1665 if (auto tgt = dyn_cast<DomainValue>(result)) {
1666 if (isDriven(tgt))
1667 return success();
1668
1669 LLVM_DEBUG(llvm::dbgs().indent(4) << "update " << render(wireOp) << "\n");
1670 OpBuilder builder(wireOp);
1671 builder.setInsertionPointAfter(wireOp);
1672 auto *term = getTermForDomain(tgt);
1673 if (auto *var = dyn_cast<VariableTerm>(term)) {
1674 auto src = solveVarWithAnonDomain(builder, domainsInScope, wireOp,
1675 tgt.getType(), var);
1676 LLVM_DEBUG(llvm::dbgs().indent(6)
1677 << "connect " << render(tgt) << " := " << render(src) << "\n");
1678 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1679 return success();
1680 }
1681 if (auto *val = dyn_cast<ValueTerm>(term)) {
1682 auto src = getDomainInScope(builder, domainsInScope, val->value);
1683 LLVM_DEBUG(llvm::dbgs().indent(6)
1684 << "connect " << render(tgt) << " := " << render(src) << "\n");
1685 DomainDefineOp::create(builder, wireOp.getLoc(), tgt, src);
1686 return success();
1687 }
1688 llvm_unreachable("unhandled domain term type");
1689 }
1690
1691 if (!isHardware(result))
1692 return success();
1693
1694 LLVM_DEBUG(llvm::dbgs().indent(4) << "update " << render(wireOp) << "\n");
1695 OpBuilder builder(wireOp);
1696 auto *row = getDomainAssociationAsRow(wireOp.getResult());
1697
1698 SmallVector<Value> domainOperands;
1699 for (auto [i, element] : llvm::enumerate(
1700 llvm::map_range(row->elements, [&](auto e) { return find(e); }))) {
1701 if (auto *val = dyn_cast<ValueTerm>(element)) {
1702 domainOperands.push_back(
1703 getDomainInScope(builder, domainsInScope, val->value));
1704 continue;
1705 }
1706 if (auto *var = dyn_cast<VariableTerm>(element)) {
1707 auto type = DomainType::getFromDomainOp(getDomain(DomainTypeID{i}));
1708 auto domain =
1709 solveVarWithAnonDomain(builder, domainsInScope, wireOp, type, var);
1710 domainOperands.push_back(domain);
1711 continue;
1712 }
1713 assert(0 && "unhandled domain type");
1714 }
1715 wireOp.getDomainsMutable().assign(domainOperands);
1716 return success();
1717}
1718
1719LogicalResult ModuleState::updateModuleBody(FModuleOp moduleOp) {
1720 DenseMap<DomainValue, DomainValue> domainsInScope;
1721
1722 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i)
1723 if (auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i)))
1724 if (moduleOp.getPortDirection(i) == Direction::In)
1725 domainsInScope[port] = port;
1726
1727 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1728 return TypeSwitch<Operation *, WalkResult>(op)
1729 .Case<WireOp>(
1730 [&](auto wire) { return updateWire(domainsInScope, wire); })
1731 .Case<FInstanceLike>([&](auto instance) {
1732 return updateInstance(domainsInScope, instance);
1733 })
1734 .Case<DomainCreateOp, DomainCreateAnonOp>([&](auto domain) {
1735 domainsInScope[domain] = domain;
1736 return success();
1737 })
1738 .Default([&](auto op) { return success(); });
1739 });
1740 return failure(result.wasInterrupted());
1741}
1742
1743LogicalResult ModuleState::updateModule(FModuleOp moduleOp) {
1744 auto exports = initializeExportTable(moduleOp);
1745 PendingUpdates pending;
1746 getUpdatesForModule(moduleOp, exports, pending);
1747 applyUpdatesToModule(moduleOp, exports, pending);
1748
1749 ArrayAttr portDomainInfo;
1750 if (failed(updateModuleDomainInfo(moduleOp, exports, portDomainInfo)))
1751 return failure();
1752
1753 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1754 return failure();
1755
1756 // Record the updated interface change in the update
1757 auto &entry = getModuleUpdateTable()[moduleOp.getModuleNameAttr()];
1758 entry.portDomainInfo = portDomainInfo;
1759 entry.portInsertions = std::move(pending.insertions);
1760
1761 if (failed(updateModuleBody(moduleOp)))
1762 return failure();
1763
1764 LLVM_DEBUG({
1765 llvm::dbgs().indent(2) << "port summary:\n";
1766 for (auto port : moduleOp.getBodyBlock()->getArguments()) {
1767 llvm::dbgs().indent(4) << render(port);
1768 auto info = cast<ArrayAttr>(
1769 moduleOp.getDomainInfoAttrForPort(port.getArgNumber()));
1770 if (info.size()) {
1771 llvm::dbgs() << " domains [";
1772 llvm::interleaveComma(
1773 info.getAsRange<IntegerAttr>(), llvm::dbgs(), [&](auto i) {
1774 llvm::dbgs() << render(moduleOp.getArgument(i.getUInt()));
1775 });
1776 llvm::dbgs() << "]";
1777 }
1778 llvm::dbgs() << "\n";
1779 }
1780 });
1781
1782 return success();
1783}
1784
1785LogicalResult ModuleState::checkModulePorts(FModuleLike moduleOp) {
1786 auto numDomains = getNumDomains();
1787 auto domainInfo = moduleOp.getDomainInfoAttr();
1788 auto numPorts = moduleOp.getNumPorts();
1789
1790 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1791 for (size_t i = 0; i < numPorts; ++i) {
1792 if (isa<DomainType>(moduleOp.getPortType(i)))
1793 domainTypeIDTable[i] = getDomainTypeID(moduleOp, i);
1794 }
1795
1796 for (size_t i = 0; i < numPorts; ++i) {
1797 if (!isHardware(moduleOp.getPortType(i)))
1798 continue;
1799
1800 // Record the domain associations of this port.
1801 SmallVector<IntegerAttr> associations(numDomains);
1802 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1803 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1804 auto prevDomainPortIndex = associations[domainTypeID.index];
1805 if (prevDomainPortIndex) {
1806 emitDuplicatePortDomainError(moduleOp, i, domainTypeID,
1807 prevDomainPortIndex, domainPortIndex);
1808 return failure();
1809 }
1810 associations[domainTypeID.index] = domainPortIndex;
1811 }
1812
1813 // Check the associations for completeness.
1814 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1815 auto typeID = DomainTypeID{domainIndex};
1816 if (!associations[domainIndex]) {
1817 emitMissingPortDomainAssociationError(moduleOp, typeID, i);
1818 return failure();
1819 }
1820 }
1821 }
1822
1823 return success();
1824}
1825
1826LogicalResult ModuleState::checkModuleDomainPortDrivers(FModuleOp moduleOp) {
1827 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1828 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1829 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1830 isDriven(port))
1831 continue;
1832
1833 auto name = moduleOp.getPortNameAttr(i);
1834 auto diag = emitError(moduleOp.getPortLocation(i))
1835 << "undriven domain port " << name;
1836 noteLocation(diag, moduleOp);
1837 return failure();
1838 }
1839
1840 return success();
1841}
1842
1843LogicalResult ModuleState::checkInstanceDomainPortDrivers(FInstanceLike op) {
1844 for (size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1845 auto port = dyn_cast<DomainValue>(op->getResult(i));
1846
1847 auto type = port.getType();
1848 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1849 isDriven(port))
1850 continue;
1851
1852 auto name = op.getPortNameAttr(i);
1853 auto diag = emitError(op.getPortLocation(i))
1854 << "undriven domain port " << name;
1855 noteLocation(diag, op);
1856 return failure();
1857 }
1858
1859 return success();
1860}
1861
1862LogicalResult ModuleState::checkModuleBody(FModuleOp moduleOp) {
1863 auto result = moduleOp.getBody().walk([&](FInstanceLike op) -> WalkResult {
1864 return checkInstanceDomainPortDrivers(op);
1865 });
1866 return failure(result.wasInterrupted());
1867}
1868
1869LogicalResult ModuleState::inferModule(FModuleOp moduleOp) {
1870 LLVM_DEBUG(llvm::dbgs() << "infer: " << moduleOp.getModuleName() << "\n");
1871 if (failed(processModule(moduleOp)))
1872 return failure();
1873
1874 return updateModule(moduleOp);
1875}
1876
1877LogicalResult ModuleState::checkModule(FModuleOp moduleOp) {
1878 LLVM_DEBUG(llvm::dbgs() << "check: " << moduleOp.getModuleName() << "\n");
1879 if (failed(checkModulePorts(moduleOp)))
1880 return failure();
1881
1882 if (failed(checkModuleDomainPortDrivers(moduleOp)))
1883 return failure();
1884
1885 if (failed(checkModuleBody(moduleOp)))
1886 return failure();
1887
1888 return processModule(moduleOp);
1889}
1890
1891LogicalResult ModuleState::checkModule(FExtModuleOp extModuleOp) {
1892 LLVM_DEBUG(llvm::dbgs() << "check: " << extModuleOp.getModuleName() << "\n");
1893 return checkModulePorts(extModuleOp);
1894}
1895
1896LogicalResult ModuleState::checkAndInferModule(FModuleOp moduleOp) {
1897 LLVM_DEBUG(llvm::dbgs() << "check/infer: " << moduleOp.getModuleName()
1898 << "\n");
1899
1900 if (failed(checkModulePorts(moduleOp)))
1901 return failure();
1902
1903 if (failed(processModule(moduleOp)))
1904 return failure();
1905
1906 if (failed(driveModuleOutputDomainPorts(moduleOp)))
1907 return failure();
1908
1909 return updateModuleBody(moduleOp);
1910}
1911
1912//===---------------------------------------------------------------------------
1913// Domain Stripping.
1914//===---------------------------------------------------------------------------
1915
1916static LogicalResult stripModule(FModuleLike op) {
1917 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1918 [=](Operation *op) -> WalkResult {
1919 return TypeSwitch<Operation *, WalkResult>(op)
1920 .Case<FModuleLike>([](FModuleLike op) {
1921 auto n = op.getNumPorts();
1922 BitVector erasures(n);
1923 for (size_t i = 0; i < n; ++i)
1924 if (isa<DomainType>(op.getPortType(i)))
1925 erasures.set(i);
1926 op.erasePorts(erasures);
1927 return WalkResult::advance();
1928 })
1929 .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
1930 [](Operation *op) {
1931 op->erase();
1932 return WalkResult::advance();
1933 })
1934 .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
1935 if (!op->use_empty()) {
1936 OpBuilder builder(op);
1937 op.replaceAllUsesWith(
1938 UnknownValueOp::create(builder, op.getLoc(), op.getType())
1939 .getResult());
1940 }
1941 op.erase();
1942 return WalkResult::advance();
1943 })
1944 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1945 op.replaceAllUsesWith(op.getInput());
1946 op.erase();
1947 return WalkResult::advance();
1948 })
1949 .Case<WireOp>([](WireOp op) {
1950 // Erase wires of DomainType
1951 if (isa<DomainType>(op.getType(0))) {
1952 op->erase();
1953 return WalkResult::advance();
1954 }
1955 // Erase domain operands from regular wires
1956 if (!op.getDomains().empty()) {
1957 op->eraseOperands(0, op.getNumOperands());
1958 }
1959 return WalkResult::advance();
1960 })
1961 .Case<FInstanceLike>([](auto op) {
1962 auto n = op.getNumPorts();
1963 BitVector erasures(n);
1964 for (size_t i = 0; i < n; ++i)
1965 if (isa<DomainType>(op->getResult(i).getType()))
1966 erasures.set(i);
1967 op.cloneWithErasedPortsAndReplaceUses(erasures);
1968 op.erase();
1969 return WalkResult::advance();
1970 })
1971 .Default([](Operation *op) {
1972 for (auto type :
1973 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1974 if (isa<DomainType>(type)) {
1975 op->emitOpError("cannot be stripped");
1976 return WalkResult::interrupt();
1977 }
1978 }
1979 return WalkResult::advance();
1980 });
1981 });
1982 return failure(result.wasInterrupted());
1983}
1984
1985static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit) {
1986 llvm::SmallVector<FModuleLike> modules;
1987 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1988 TypeSwitch<Operation *, void>(&op)
1989 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1990 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1991 }
1992 return failableParallelForEach(context, modules, stripModule);
1993}
1994
1995//===---------------------------------------------------------------------------
1996// InferDomainsPass: Top-level pass implementation.
1997//===---------------------------------------------------------------------------
1998
1999LogicalResult CircuitState::runOnModule(Operation *op) {
2000 assert(mode != InferDomainsMode::Strip);
2001 ModuleState state(*this);
2002 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
2003 if (mode == InferDomainsMode::Check)
2004 return state.checkModule(moduleOp);
2005
2006 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
2007 return state.inferModule(moduleOp);
2008
2009 return state.checkAndInferModule(moduleOp);
2010 }
2011
2012 if (auto extModuleOp = dyn_cast<FExtModuleOp>(op))
2013 return state.checkModule(extModuleOp);
2014
2015 return success();
2016}
2017
2018LogicalResult CircuitState::run() {
2019 DenseSet<Operation *> errored;
2020 instanceGraph.walkPostOrder([&](auto &node) {
2021 auto moduleOp = node.getModule();
2022 for (auto *inst : node) {
2023 if (errored.contains(inst->getTarget()->getModule())) {
2024 errored.insert(moduleOp);
2025 return;
2026 }
2027 }
2028 if (failed(runOnModule(node.getModule())))
2029 errored.insert(moduleOp);
2030 });
2031 return success(errored.empty());
2032}
2033
2034namespace {
2035struct InferDomainsPass
2036 : public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
2037 using Base::Base;
2038 void runOnOperation() override {
2040 auto circuit = getOperation();
2041
2042 if (mode == InferDomainsMode::Strip) {
2043 if (failed(stripCircuit(&getContext(), circuit)))
2044 signalPassFailure();
2045 return;
2046 }
2047
2048 auto &instanceGraph = getAnalysis<InstanceGraph>();
2049 auto &symbolTable = getAnalysis<SymbolTable>();
2050 auto &innerSymbolTableCollection =
2051 getAnalysis<InnerSymbolTableCollection>();
2052 circt::hw::InnerRefNamespace innerRefNamespace{symbolTable,
2053 innerSymbolTableCollection};
2054 CircuitState state(circuit, instanceGraph, innerRefNamespace, mode);
2055 if (failed(state.run()))
2056 signalPassFailure();
2057 }
2058};
2059} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static bool isHardware(Type type)
True if a value of the given type could be associated with a domain.
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static LogicalResult stripModule(FModuleLike op)
static Block * getBodyBlock(FModuleLike mod)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
This graph tracks modules and where they are instantiated.
This class represents a collection of InnerSymbolTable's.
static StringRef toLongString(Direction direction)
Definition FIRRTLEnums.h:48
InferDomainsMode
The mode for the InferDomains pass.
Definition Passes.h:73
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:1716
This holds the name and type that describes the module's ports.
This class represents the namespace in which InnerRef's can be resolved.