CIRCT 22.0.0git
Loading...
Searching...
No Matches
InferDomains.cpp
Go to the documentation of this file.
1//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// InferDomains implements FIRRTL domain inference and checking. This pass is a
10// bottom-up transform acting on modules. For each moduleOp, we ensure there are
11// no domain crossings, and we make explicit the domain associations of ports.
12//
13//===----------------------------------------------------------------------===//
14
19#include "circt/Support/Debug.h"
21#include "llvm/ADT/DenseMap.h"
22#include "llvm/ADT/SmallVector.h"
23#include "llvm/ADT/TinyPtrVector.h"
24#include "llvm/Support/Debug.h"
25
26#define DEBUG_TYPE "firrtl-infer-domains"
27
28namespace circt {
29namespace firrtl {
30#define GEN_PASS_DEF_INFERDOMAINS
31#include "circt/Dialect/FIRRTL/Passes.h.inc"
32} // namespace firrtl
33} // namespace circt
34
35using namespace circt;
36using namespace firrtl;
37
38//====--------------------------------------------------------------------------
39// Helpers.
40//====--------------------------------------------------------------------------
41
42using DomainValue = mlir::TypedValue<DomainType>;
43
44using PortInsertions = SmallVector<std::pair<unsigned, PortInfo>>;
45
46/// From a domain info attribute, get the domain-type of a domain value at
47/// index i.
48static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) {
49 if (info.empty())
50 return nullptr;
51 return cast<FlatSymbolRefAttr>(info[i]).getAttr();
52}
53
54/// From a domain info attribute, get the row of associated domains for a
55/// hardware value at index i.
56static auto getPortDomainAssociation(ArrayAttr info, size_t i) {
57 if (info.empty())
58 return info.getAsRange<IntegerAttr>();
59 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
60}
61
62/// Return true if the value is a port on the module.
63static bool isPort(BlockArgument arg) {
64 return isa<FModuleOp>(arg.getOwner()->getParentOp());
65}
66
67/// Return true if the value is a port on the module.
68static bool isPort(Value value) {
69 auto arg = dyn_cast<BlockArgument>(value);
70 if (!arg)
71 return false;
72 return isPort(arg);
73}
74
75/// Returns true if the value is driven by a connect op.
76static bool isDriven(DomainValue port) {
77 for (auto *user : port.getUsers())
78 if (auto connect = dyn_cast<FConnectLike>(user))
79 if (connect.getDest() == port)
80 return true;
81 return false;
82}
83
84//====--------------------------------------------------------------------------
85// Global State.
86//====--------------------------------------------------------------------------
87
88/// Each domain type declared in the circuit is assigned a type-id, based on the
89/// order of declaration. Domain associations for hardware values are
90/// represented as a list, or row, of domains. The domains in a row are ordered
91/// according to their type's id.
92namespace {
93struct DomainTypeID {
94 size_t index;
95};
96} // namespace
97
98/// Information about the domains in the circuit. Able to map domains to their
99/// type ID, which in this pass is the canonical way to reference the type
100/// of a domain, as well as provide fast access to domain ops.
101namespace {
102class DomainInfo {
103public:
104 DomainInfo(CircuitOp circuit) { processCircuit(circuit); }
105
106 ArrayRef<DomainOp> getDomains() const { return domainTable; }
107 size_t getNumDomains() const { return domainTable.size(); }
108 DomainOp getDomain(DomainTypeID id) const { return domainTable[id.index]; }
109
110 DomainTypeID getDomainTypeID(StringAttr name) const {
111 return typeIDTable.at(name);
112 }
113
114 DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const {
115 return getDomainTypeID(ref.getAttr());
116 }
117
118 DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const {
119 auto name = getDomainPortTypeName(info, i);
120 return getDomainTypeID(name);
121 }
122
123 DomainTypeID getDomainTypeID(DomainValue value) const {
124 if (auto arg = dyn_cast<BlockArgument>(value)) {
125 auto *block = arg.getOwner();
126 auto *owner = block->getParentOp();
127 auto moduleOp = cast<FModuleOp>(owner);
128 auto info = moduleOp.getDomainInfoAttr();
129 auto i = arg.getArgNumber();
130 return getDomainTypeID(info, i);
131 }
132
133 auto result = dyn_cast<OpResult>(value);
134 auto *owner = result.getOwner();
135
136 auto info = TypeSwitch<Operation *, ArrayAttr>(owner)
137 .Case<InstanceOp, InstanceChoiceOp>(
138 [&](auto inst) { return inst.getDomainInfoAttr(); })
139 .Default([&](auto inst) { return nullptr; });
140 assert(info && "unable to obtain domain information from op");
141
142 auto i = result.getResultNumber();
143 return getDomainTypeID(info, i);
144 }
145
146private:
147 void processDomain(DomainOp op) {
148 auto index = domainTable.size();
149 auto name = op.getNameAttr();
150 domainTable.push_back(op);
151 typeIDTable.insert({name, {index}});
152 }
153
154 void processCircuit(CircuitOp circuit) {
155 for (auto decl : circuit.getOps<DomainOp>())
156 processDomain(decl);
157 }
158
159 /// A map from domain type ID to op.
160 SmallVector<DomainOp> domainTable;
161
162 /// A map from domain name to type ID.
163 DenseMap<StringAttr, DomainTypeID> typeIDTable;
164};
165
166/// Information about the changes made to the interface of a moduleOp, which can
167/// be replayed onto an instance.
168struct ModuleUpdateInfo {
169 /// The updated domain information for a moduleOp.
170 ArrayAttr portDomainInfo;
171 /// The domain ports which have been inserted into a moduleOp.
172 PortInsertions portInsertions;
173};
174} // namespace
175
176using ModuleUpdateTable = DenseMap<StringAttr, ModuleUpdateInfo>;
177
178/// Apply the port changes of a moduleOp onto an instance-like op.
179template <typename T>
180static T fixInstancePorts(T op, const ModuleUpdateInfo &update) {
181 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
182 clone.setDomainInfoAttr(update.portDomainInfo);
183 op->erase();
184 return clone;
185}
186
187//====--------------------------------------------------------------------------
188// Terms: Syntax for unifying domain and domain-rows.
189//====--------------------------------------------------------------------------
190
191/// The different sorts of terms in the unification engine.
192namespace {
193enum class TermKind {
194 Variable,
195 Value,
196 Row,
197};
198} // namespace
199
200/// A term in the unification engine.
201namespace {
202struct Term {
203 constexpr Term(TermKind kind) : kind(kind) {}
204 TermKind kind;
205};
206} // namespace
207
208/// Helper to define a term kind.
209namespace {
210template <TermKind K>
211struct TermBase : Term {
212 static bool classof(const Term *term) { return term->kind == K; }
213 TermBase() : Term(K) {}
214};
215} // namespace
216
217/// An unknown value.
218namespace {
219struct VariableTerm : public TermBase<TermKind::Variable> {
220 VariableTerm() : leader(nullptr) {}
221 VariableTerm(Term *leader) : leader(leader) {}
222 Term *leader;
223};
224} // namespace
225
226/// A concrete value defined in the IR.
227namespace {
228struct ValueTerm : public TermBase<TermKind::Value> {
229 ValueTerm(DomainValue value) : value(value) {}
230 DomainValue value;
231};
232} // namespace
233
234/// A row of domains.
235namespace {
236struct RowTerm : public TermBase<TermKind::Row> {
237 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
238 ArrayRef<Term *> elements;
239};
240} // namespace
241
242// NOLINTNEXTLINE(misc-no-recursion)
243static Term *find(Term *x) {
244 if (!x)
245 return nullptr;
246
247 if (auto *var = dyn_cast<VariableTerm>(x)) {
248 if (var->leader == nullptr)
249 return var;
250
251 auto *leader = find(var->leader);
252 if (leader != var->leader)
253 var->leader = leader;
254 return leader;
255 }
256
257 return x;
258}
259
260/// A helper for assigning low numeric IDs to variables for user-facing output.
261namespace {
262class VariableIDTable {
263public:
264 size_t get(VariableTerm *term) {
265 return table.insert({term, table.size() + 1}).first->second;
266 }
267
268private:
269 DenseMap<VariableTerm *, size_t> table;
270};
271} // namespace
272
273// NOLINTNEXTLINE(misc-no-recursion)
274static void render(const DomainInfo &info, Diagnostic &out,
275 VariableIDTable &idTable, Term *term) {
276 term = find(term);
277 if (auto *var = dyn_cast<VariableTerm>(term)) {
278 out << "?" << idTable.get(var);
279 return;
280 }
281 if (auto *val = dyn_cast<ValueTerm>(term)) {
282 auto value = val->value;
283 auto [name, _] = getFieldName(FieldRef(value, 0), false);
284 out << name;
285 return;
286 }
287 if (auto *row = dyn_cast<RowTerm>(term)) {
288 bool first = true;
289 out << "[";
290 for (size_t i = 0, e = info.getNumDomains(); i < e; ++i) {
291 auto domainOp = info.getDomain(DomainTypeID{i});
292 if (!first) {
293 out << ", ";
294 first = false;
295 }
296 out << domainOp.getName() << ": ";
297 render(info, out, idTable, row->elements[i]);
298 }
299 out << "]";
300 return;
301 }
302}
303
304static LogicalResult unify(Term *lhs, Term *rhs);
305
306static LogicalResult unify(VariableTerm *x, Term *y) {
307 assert(!x->leader);
308 x->leader = y;
309 return success();
310}
311
312static LogicalResult unify(ValueTerm *xv, Term *y) {
313 if (auto *yv = dyn_cast<VariableTerm>(y)) {
314 yv->leader = xv;
315 return success();
316 }
317
318 if (auto *yv = dyn_cast<ValueTerm>(y))
319 return success(xv == yv);
320
321 return failure();
322}
323
324// NOLINTNEXTLINE(misc-no-recursion)
325static LogicalResult unify(RowTerm *lhsRow, Term *rhs) {
326 if (auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
327 rhsVar->leader = lhsRow;
328 return success();
329 }
330 if (auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
331 for (auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
332 if (failed(unify(x, y)))
333 return failure();
334 return success();
335 }
336 return failure();
337}
338
339// NOLINTNEXTLINE(misc-no-recursion)
340static LogicalResult unify(Term *lhs, Term *rhs) {
341 if (!lhs || !rhs)
342 return success();
343 lhs = find(lhs);
344 rhs = find(rhs);
345 if (lhs == rhs)
346 return success();
347 if (auto *lhsVar = dyn_cast<VariableTerm>(lhs))
348 return unify(lhsVar, rhs);
349 if (auto *lhsVal = dyn_cast<ValueTerm>(lhs))
350 return unify(lhsVal, rhs);
351 if (auto *lhsRow = dyn_cast<RowTerm>(lhs))
352 return unify(lhsRow, rhs);
353 return failure();
354}
355
356static void solve(Term *lhs, Term *rhs) {
357 [[maybe_unused]] auto result = unify(lhs, rhs);
358 assert(result.succeeded());
359}
360
361namespace {
362class TermAllocator {
363public:
364 /// Allocate a row of fresh domain variables.
365 [[nodiscard]] RowTerm *allocRow(size_t size) {
366 SmallVector<Term *> elements;
367 elements.resize(size);
368 return allocRow(elements);
369 }
370
371 /// Allocate a row of terms.
372 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements) {
373 auto ds = allocArray(elements);
374 return alloc<RowTerm>(ds);
375 }
376
377 /// Allocate a fresh variable.
378 [[nodiscard]] VariableTerm *allocVar() { return alloc<VariableTerm>(); }
379
380 /// Allocate a concrete domain.
381 [[nodiscard]] ValueTerm *allocVal(DomainValue value) {
382 return alloc<ValueTerm>(value);
383 }
384
385private:
386 template <typename T, typename... Args>
387 [[nodiscard]] T *alloc(Args &&...args) {
388 static_assert(std::is_base_of_v<Term, T>, "T must be a term");
389 return new (allocator) T(std::forward<Args>(args)...);
390 }
391
392 [[nodiscard]] ArrayRef<Term *> allocArray(ArrayRef<Term *> elements) {
393 auto size = elements.size();
394 if (size == 0)
395 return {};
396
397 auto *result = allocator.Allocate<Term *>(size);
398 llvm::uninitialized_copy(elements, result);
399 for (size_t i = 0; i < size; ++i)
400 if (!result[i])
401 result[i] = alloc<VariableTerm>();
402
403 return ArrayRef(result, size);
404 }
405
406 llvm::BumpPtrAllocator allocator;
407};
408} // namespace
409
410//====--------------------------------------------------------------------------
411// DomainTable: A mapping from IR to terms.
412//====--------------------------------------------------------------------------
413
414namespace {
415/// Tracks domain infomation for IR values.
416class DomainTable {
417public:
418 /// If the domain value is an alias, returns the domain it aliases.
419 DomainValue getOptUnderlyingDomain(DomainValue value) const {
420 auto *term = getOptTermForDomain(value);
421 if (auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
422 return val->value;
423 return nullptr;
424 }
425
426 /// Get the corresponding term for a domain in the IR, or null if unset.
427 Term *getOptTermForDomain(DomainValue value) const {
428 assert(isa<DomainType>(value.getType()));
429 auto it = termTable.find(value);
430 if (it == termTable.end())
431 return nullptr;
432 return find(it->second);
433 }
434
435 /// Get the corresponding term for a domain in the IR.
436 Term *getTermForDomain(DomainValue value) const {
437 auto *term = getOptTermForDomain(value);
438 assert(term);
439 return term;
440 }
441
442 /// Record a mapping from domain in the IR to its corresponding term.
443 void setTermForDomain(DomainValue value, Term *term) {
444 assert(term);
445 assert(!termTable.contains(value));
446 termTable.insert({value, term});
447 }
448
449 /// For a hardware value, get the term which represents the row of associated
450 /// domains. If no mapping has been defined, returns nullptr.
451 Term *getOptDomainAssociation(Value value) const {
452 assert(isa<FIRRTLBaseType>(value.getType()));
453 auto it = associationTable.find(value);
454 if (it == associationTable.end())
455 return nullptr;
456 return find(it->second);
457 }
458
459 /// For a hardware value, get the term which represents the row of associated
460 /// domains.
461 Term *getDomainAssociation(Value value) const {
462 auto *term = getOptDomainAssociation(value);
463 assert(term);
464 return term;
465 }
466
467 /// Record a mapping from a hardware value in the IR to a term which
468 /// represents the row of domains it is associated with.
469 void setDomainAssociation(Value value, Term *term) {
470 assert(isa<FIRRTLBaseType>(value.getType()));
471 assert(term);
472 term = find(term);
473 associationTable.insert({value, term});
474 }
475
476private:
477 /// Map from domains in the IR to their underlying term.
478 DenseMap<Value, Term *> termTable;
479
480 /// A map from hardware values to their associated row of domains, as a term.
481 DenseMap<Value, Term *> associationTable;
482};
483} // namespace
484
485//====--------------------------------------------------------------------------
486// Module processing: solve for the domain associations of hardware.
487//====--------------------------------------------------------------------------
488
489/// Get the corresponding term for a domain in the IR. If we don't know what the
490/// term is, then map the domain in the IR to a variable term.
491static Term *getTermForDomain(TermAllocator &allocator, DomainTable &table,
492 DomainValue value) {
493 assert(isa<DomainType>(value.getType()));
494 if (auto *term = table.getOptTermForDomain(value))
495 return term;
496 auto *term = allocator.allocVar();
497 table.setTermForDomain(value, term);
498 return term;
499}
500
501static void processDomainDefinition(TermAllocator &allocator,
502 DomainTable &table, DomainValue domain) {
503 assert(isa<DomainType>(domain.getType()));
504 auto *newTerm = allocator.allocVal(domain);
505 auto *oldTerm = table.getOptTermForDomain(domain);
506 if (!oldTerm) {
507 table.setTermForDomain(domain, newTerm);
508 return;
509 }
510
511 [[maybe_unused]] auto result = unify(oldTerm, newTerm);
512 assert(result.succeeded());
513}
514
515/// Get the row of domains that a hardware value in the IR is associated with.
516/// The returned term is forced to be at least a row.
517static RowTerm *getDomainAssociationAsRow(const DomainInfo &info,
518 TermAllocator &allocator,
519 DomainTable &table, Value value) {
520 assert(isa<FIRRTLBaseType>(value.getType()));
521 auto *term = table.getOptDomainAssociation(value);
522
523 // If the term is unknown, allocate a fresh row and set the association.
524 if (!term) {
525 auto *row = allocator.allocRow(info.getNumDomains());
526 table.setDomainAssociation(value, row);
527 return row;
528 }
529
530 // If the term is already a row, return it.
531 if (auto *row = dyn_cast<RowTerm>(term))
532 return row;
533
534 // Otherwise, unify the term with a fresh row of domains.
535 if (auto *var = dyn_cast<VariableTerm>(term)) {
536 auto *row = allocator.allocRow(info.getNumDomains());
537 solve(var, row);
538 return row;
539 }
540
541 assert(false && "unhandled term type");
542 return nullptr;
543}
544
545static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
546 auto &note = diag.attachNote(op->getLoc());
547 if (auto mod = dyn_cast<FModuleOp>(op)) {
548 note << "in module " << mod.getModuleNameAttr();
549 return;
550 }
551 if (auto mod = dyn_cast<FExtModuleOp>(op)) {
552 note << "in extmodule " << mod.getModuleNameAttr();
553 return;
554 }
555 if (auto inst = dyn_cast<InstanceOp>(op)) {
556 note << "in instance " << inst.getInstanceNameAttr();
557 return;
558 }
559 if (auto inst = dyn_cast<InstanceChoiceOp>(op)) {
560 note << "in instance_choice " << inst.getNameAttr();
561 return;
562 }
563
564 note << "here";
565}
566
567template <typename T>
568static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i,
569 DomainTypeID domainTypeID, Term *term1,
570 Term *term2) {
571 VariableIDTable idTable;
572
573 auto portName = op.getPortNameAttr(i);
574 auto portLoc = op.getPortLocation(i);
575 auto domainDecl = info.getDomain(domainTypeID);
576 auto domainName = domainDecl.getNameAttr();
577
578 auto diag = emitError(portLoc);
579 diag << "illegal " << domainName << " crossing in port " << portName;
580
581 auto &note1 = diag.attachNote();
582 note1 << "1st instance: ";
583 render(info, note1, idTable, term1);
584
585 auto &note2 = diag.attachNote();
586 note2 << "2nd instance: ";
587 render(info, note2, idTable, term2);
588
589 noteLocation(diag, op);
590}
591
592template <typename T>
593static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i,
594 DomainTypeID domainTypeID,
595 IntegerAttr domainPortIndexAttr1,
596 IntegerAttr domainPortIndexAttr2) {
597 VariableIDTable idTable;
598 auto portName = op.getPortNameAttr(i);
599 auto portLoc = op.getPortLocation(i);
600 auto domainDecl = info.getDomain(domainTypeID);
601 auto domainName = domainDecl.getNameAttr();
602 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
603 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
604 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
605 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
606 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
607 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
608 auto diag = emitError(portLoc);
609 diag << "duplicate " << domainName << " association for port " << portName;
610 auto &note1 = diag.attachNote(domainPortLoc1);
611 note1 << "associated with " << domainName << " port " << domainPortName1;
612 auto &note2 = diag.attachNote(domainPortLoc2);
613 note2 << "associated with " << domainName << " port " << domainPortName2;
614 noteLocation(diag, op);
615}
616
617/// Emit an error when we fail to infer the concrete domain to drive to a
618/// domain port.
619template <typename T>
620static void emitDomainPortInferenceError(T op, size_t i) {
621 auto name = op.getPortNameAttr(i);
622 auto diag = emitError(op->getLoc());
623 auto info = op.getDomainInfo();
624 diag << "unable to infer value for undriven domain port " << name;
625 for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
626 if (auto assocs = dyn_cast<ArrayAttr>(info[j])) {
627 for (auto assoc : assocs) {
628 if (i == cast<IntegerAttr>(assoc).getValue()) {
629 auto name = op.getPortNameAttr(j);
630 auto loc = op.getPortLocation(j);
631 diag.attachNote(loc) << "associated with hardware port " << name;
632 break;
633 }
634 }
635 }
636 }
637 noteLocation(diag, op);
638}
639
/// Emit an "ambiguous <domain> association" error on port `i` of `op`, for the
/// domain type `typeID`. Each value in `exports` is cast to a BlockArgument and
/// reported as a "candidate association" note at the port with that argument
/// number. A final note locates `op`.
///
/// NOTE(review): the line carrying this function's name and return type is
/// missing between `template <typename T>` and the parameter list in this copy
/// of the file. Restore it from upstream; it is not reconstructed here to
/// avoid guessing the name.
template <typename T>
    const DomainInfo &info, T op,
    const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
    size_t i) {
  auto portName = op.getPortNameAttr(i);
  auto portLoc = op.getPortLocation(i);
  auto domainDecl = info.getDomain(typeID);
  auto domainName = domainDecl.getNameAttr();
  auto diag = emitError(portLoc) << "ambiguous " << domainName
                                 << " association for port " << portName;
  // Each candidate is assumed to be a port (block argument) of `op`.
  for (auto e : exports) {
    auto arg = cast<BlockArgument>(e);
    auto name = op.getPortNameAttr(arg.getArgNumber());
    auto loc = op.getPortLocation(arg.getArgNumber());
    diag.attachNote(loc) << "candidate association " << name;
  }
  noteLocation(diag, op);
}
659
660template <typename T>
661static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op,
662 DomainTypeID typeID,
663 size_t i) {
664 auto domainName = info.getDomain(typeID).getNameAttr();
665 auto portName = op.getPortNameAttr(i);
666 auto diag = emitError(op.getPortLocation(i))
667 << "missing " << domainName << " association for port "
668 << portName;
669 noteLocation(diag, op);
670}
671
672/// Unify the associated domain rows of two terms.
673static LogicalResult unifyAssociations(const DomainInfo &info,
674 TermAllocator &allocator,
675 DomainTable &table, Operation *op,
676 Value lhs, Value rhs) {
677 if (!lhs || !rhs)
678 return success();
679
680 if (lhs == rhs)
681 return success();
682
683 auto *lhsTerm = table.getOptDomainAssociation(lhs);
684 auto *rhsTerm = table.getOptDomainAssociation(rhs);
685
686 if (lhsTerm) {
687 if (rhsTerm) {
688 if (failed(unify(lhsTerm, rhsTerm))) {
689 auto diag = op->emitOpError("illegal domain crossing in operation");
690 auto &note1 = diag.attachNote(lhs.getLoc());
691
692 note1 << "1st operand has domains: ";
693 VariableIDTable idTable;
694 render(info, note1, idTable, lhsTerm);
695
696 auto &note2 = diag.attachNote(rhs.getLoc());
697 note2 << "2nd operand has domains: ";
698 render(info, note2, idTable, rhsTerm);
699
700 return failure();
701 }
702 }
703 table.setDomainAssociation(rhs, lhsTerm);
704 return success();
705 }
706
707 if (rhsTerm) {
708 table.setDomainAssociation(lhs, rhsTerm);
709 return success();
710 }
711
712 auto *var = allocator.allocVar();
713 table.setDomainAssociation(lhs, var);
714 table.setDomainAssociation(rhs, var);
715 return success();
716}
717
718static LogicalResult processModulePorts(const DomainInfo &info,
719 TermAllocator &allocator,
720 DomainTable &table,
721 FModuleOp moduleOp) {
722 auto numDomains = info.getNumDomains();
723 auto domainInfo = moduleOp.getDomainInfoAttr();
724 auto numPorts = moduleOp.getNumPorts();
725
726 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
727 for (size_t i = 0; i < numPorts; ++i) {
728 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
729 if (!port)
730 continue;
731
732 if (moduleOp.getPortDirection(i) == Direction::In)
733 processDomainDefinition(allocator, table, port);
734
735 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
736 }
737
738 for (size_t i = 0; i < numPorts; ++i) {
739 BlockArgument port = moduleOp.getArgument(i);
740 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
741 if (!type)
742 continue;
743
744 SmallVector<IntegerAttr> associations(numDomains);
745 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
746 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
747 auto prevDomainPortIndex = associations[domainTypeID.index];
748 if (prevDomainPortIndex) {
749 emitDuplicatePortDomainError(info, moduleOp, i, domainTypeID,
750 prevDomainPortIndex, domainPortIndex);
751 return failure();
752 }
753 associations[domainTypeID.index] = domainPortIndex;
754 }
755
756 SmallVector<Term *> elements(numDomains);
757 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
758 ++domainTypeIndex) {
759 auto domainPortIndex = associations[domainTypeIndex];
760 if (!domainPortIndex)
761 continue;
762 auto domainPortValue =
763 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
764 elements[domainTypeIndex] =
765 getTermForDomain(allocator, table, domainPortValue);
766 }
767
768 auto *domainAssociations = allocator.allocRow(elements);
769 table.setDomainAssociation(port, domainAssociations);
770 }
771
772 return success();
773}
774
775template <typename T>
776static LogicalResult processInstancePorts(const DomainInfo &info,
777 TermAllocator &allocator,
778 DomainTable &table, T op) {
779 auto numDomains = info.getNumDomains();
780 auto domainInfo = op.getDomainInfoAttr();
781 auto numPorts = op.getNumPorts();
782
783 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
784 for (size_t i = 0; i < numPorts; ++i) {
785 auto port = dyn_cast<DomainValue>(op.getResult(i));
786 if (!port)
787 continue;
788
789 if (op.getPortDirection(i) == Direction::Out)
790 processDomainDefinition(allocator, table, port);
791
792 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
793 }
794
795 for (size_t i = 0; i < numPorts; ++i) {
796 Value port = op.getResult(i);
797 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
798 if (!type)
799 continue;
800
801 SmallVector<IntegerAttr> associations(numDomains);
802 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
803 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
804 auto prevDomainPortIndex = associations[domainTypeID.index];
805 if (prevDomainPortIndex) {
806 emitDuplicatePortDomainError(info, op, i, domainTypeID,
807 prevDomainPortIndex, domainPortIndex);
808 return failure();
809 }
810 associations[domainTypeID.index] = domainPortIndex;
811 }
812
813 SmallVector<Term *> elements(numDomains);
814 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
815 ++domainTypeIndex) {
816 auto domainPortIndex = associations[domainTypeIndex];
817 if (!domainPortIndex)
818 continue;
819 auto domainPortValue =
820 cast<DomainValue>(op.getResult(domainPortIndex.getUInt()));
821 elements[domainTypeIndex] =
822 getTermForDomain(allocator, table, domainPortValue);
823 }
824
825 auto *domainAssociations = allocator.allocRow(elements);
826 table.setDomainAssociation(port, domainAssociations);
827 }
828
829 return success();
830}
831
832static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
833 DomainTable &table,
834 const ModuleUpdateTable &updateTable,
835 InstanceOp op) {
836 auto moduleOp = op.getReferencedModuleNameAttr();
837 auto lookup = updateTable.find(moduleOp);
838 if (lookup != updateTable.end())
839 op = fixInstancePorts(op, lookup->second);
840 return processInstancePorts(info, allocator, table, op);
841}
842
843static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
844 DomainTable &table,
845 const ModuleUpdateTable &updateTable,
846 InstanceChoiceOp op) {
847 auto moduleOp = op.getDefaultTargetAttr().getAttr();
848 auto lookup = updateTable.find(moduleOp);
849 if (lookup != updateTable.end())
850 op = fixInstancePorts(op, lookup->second);
851 return processInstancePorts(info, allocator, table, op);
852}
853
854static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
855 DomainTable &table, UnsafeDomainCastOp op) {
856 auto domains = op.getDomains();
857 if (domains.empty())
858 return unifyAssociations(info, allocator, table, op, op.getInput(),
859 op.getResult());
860
861 auto input = op.getInput();
862 RowTerm *inputRow = getDomainAssociationAsRow(info, allocator, table, input);
863 SmallVector<Term *> elements(inputRow->elements);
864 for (auto value : op.getDomains()) {
865 auto domain = cast<DomainValue>(value);
866 auto typeID = info.getDomainTypeID(domain);
867 elements[typeID.index] = getTermForDomain(allocator, table, domain);
868 }
869
870 auto *row = allocator.allocRow(elements);
871 table.setDomainAssociation(op.getResult(), row);
872 return success();
873}
874
875static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
876 DomainTable &table, DomainDefineOp op) {
877 auto src = op.getSrc();
878 auto dst = op.getDest();
879 auto *srcTerm = getTermForDomain(allocator, table, src);
880 auto *dstTerm = getTermForDomain(allocator, table, dst);
881 if (succeeded(unify(dstTerm, srcTerm)))
882 return success();
883
884 VariableIDTable idTable;
885 auto diag = op->emitOpError("failed to propagate source to destination");
886 auto &note1 = diag.attachNote();
887 note1 << "destination has underlying value: ";
888 render(info, note1, idTable, dstTerm);
889
890 auto &note2 = diag.attachNote(src.getLoc());
891 note2 << "source has underlying value: ";
892 render(info, note2, idTable, srcTerm);
893 return failure();
894}
895
896static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
897 DomainTable &table,
898 const ModuleUpdateTable &updateTable,
899 Operation *op) {
900 if (auto instance = dyn_cast<InstanceOp>(op))
901 return processOp(info, allocator, table, updateTable, instance);
902 if (auto instance = dyn_cast<InstanceChoiceOp>(op))
903 return processOp(info, allocator, table, updateTable, instance);
904 if (auto cast = dyn_cast<UnsafeDomainCastOp>(op))
905 return processOp(info, allocator, table, cast);
906 if (auto def = dyn_cast<DomainDefineOp>(op))
907 return processOp(info, allocator, table, def);
908
909 // For all other operations (including connections), propagate domains from
910 // operands to results. This is a conservative approach - all operands and
911 // results share the same domain associations.
912 Value lhs;
913 for (auto rhs : op->getOperands()) {
914 if (!isa<FIRRTLBaseType>(rhs.getType()))
915 continue;
916 if (auto *op = rhs.getDefiningOp();
917 op && op->hasTrait<OpTrait::ConstantLike>())
918 continue;
919 if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs)))
920 return failure();
921 lhs = rhs;
922 }
923 for (auto rhs : op->getResults()) {
924 if (!isa<FIRRTLBaseType>(rhs.getType()))
925 continue;
926 if (auto *op = rhs.getDefiningOp();
927 op && op->hasTrait<OpTrait::ConstantLike>())
928 continue;
929 if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs)))
930 return failure();
931 lhs = rhs;
932 }
933 return success();
934}
935
936static LogicalResult processModuleBody(const DomainInfo &info,
937 TermAllocator &allocator,
938 DomainTable &table,
939 const ModuleUpdateTable &updateTable,
940 FModuleOp moduleOp) {
941 auto result = moduleOp.getBody().walk([&](Operation *op) -> WalkResult {
942 return processOp(info, allocator, table, updateTable, op);
943 });
944 return failure(result.wasInterrupted());
945}
946
947/// Populate the domain table by processing the moduleOp. If the moduleOp has
948/// any domain crossing errors, return failure.
949static LogicalResult processModule(const DomainInfo &info,
950 TermAllocator &allocator, DomainTable &table,
951 const ModuleUpdateTable &updateTable,
952 FModuleOp moduleOp) {
953 if (failed(processModulePorts(info, allocator, table, moduleOp)))
954 return failure();
955 if (failed(processModuleBody(info, allocator, table, updateTable, moduleOp)))
956 return failure();
957 return success();
958}
959
960//===---------------------------------------------------------------------------
961// ExportTable
962//===---------------------------------------------------------------------------
963
964/// A map from domain IR values defined internal to the moduleOp, to ports that
965/// alias that domain. These ports make the domain useable as associations of
966/// ports, and we say these are exporting ports.
967using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
968
969/// Build a table of exported domains: a map from domains defined internally,
970/// to their set of aliasing output ports.
971static ExportTable initializeExportTable(const DomainTable &table,
972 FModuleOp moduleOp) {
973 ExportTable exports;
974 size_t numPorts = moduleOp.getNumPorts();
975 for (size_t i = 0; i < numPorts; ++i) {
976 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
977 if (!port)
978 continue;
979 auto value = table.getOptUnderlyingDomain(port);
980 if (value)
981 exports[value].push_back(port);
982 }
983
984 return exports;
985}
986
987//====--------------------------------------------------------------------------
988// Updating: write domains back to the IR.
989//====--------------------------------------------------------------------------
990
991/// A map from unsolved variables to a port index, where that port has not yet
992/// been created. Eventually we will have an input domain at the port index,
993/// which will be the solution to the recorded variable.
994using PendingSolutions = DenseMap<VariableTerm *, unsigned>;
995
996/// A map from local domains to an aliasing port index, where that port has not
997/// yet been created. Eventually we will be exporting the domain value at the
998/// port index.
999using PendingExports = llvm::MapVector<DomainValue, unsigned>;
1000
1001namespace {
1002struct PendingUpdates {
1003 PortInsertions insertions;
1004 PendingSolutions solutions;
1005 PendingExports exports;
1006};
1007} // namespace
1008
1009/// If `var` is not solved, solve it by recording a pending input port at
1010/// the indicated insertion point.
1011static void ensureSolved(const DomainInfo &info, Namespace &ns,
1012 DomainTypeID typeID, size_t ip, LocationAttr loc,
1013 VariableTerm *var, PendingUpdates &pending) {
1014 if (pending.solutions.contains(var))
1015 return;
1016
1017 auto *context = loc.getContext();
1018 auto domainDecl = info.getDomain(typeID);
1019 auto domainName = domainDecl.getNameAttr();
1020
1021 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1022 auto portType = DomainType::get(loc.getContext());
1023 auto portDirection = Direction::In;
1024 auto portSym = StringAttr();
1025 auto portLoc = loc;
1026 auto portAnnos = std::nullopt;
1027 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1028 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1029 portAnnos, portDomainInfo);
1030
1031 pending.solutions[var] = pending.insertions.size() + ip;
1032 pending.insertions.push_back({ip, portInfo});
1033}
1034
1035/// Ensure that the domain value is available in the signature of the moduleOp,
1036/// so that subsequent hardware ports may be associated with this domain.
1037// If the domain is defined internally in the moduleOp, ensure it is aliased by
1038// an
1039/// output port.
1040static void ensureExported(const DomainInfo &info, Namespace &ns,
1041 const ExportTable &exports, DomainTypeID typeID,
1042 size_t ip, LocationAttr loc, ValueTerm *val,
1043 PendingUpdates &pending) {
1044 auto value = val->value;
1045 assert(isa<DomainType>(value.getType()));
1046 if (isPort(value) || exports.contains(value) ||
1047 pending.exports.contains(value))
1048 return;
1049
1050 auto *context = loc.getContext();
1051
1052 auto domainDecl = info.getDomain(typeID);
1053 auto domainName = domainDecl.getNameAttr();
1054
1055 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1056 auto portType = DomainType::get(loc.getContext());
1057 auto portDirection = Direction::Out;
1058 auto portSym = StringAttr();
1059 auto portLoc = value.getLoc();
1060 auto portAnnos = std::nullopt;
1061 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1062 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1063 portAnnos, portDomainInfo);
1064 pending.exports[value] = pending.insertions.size() + ip;
1065 pending.insertions.push_back({ip, portInfo});
1066}
1067
1068static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info,
1069 Namespace &ns,
1070 PendingUpdates &pending,
1071 DomainTypeID typeID, size_t ip,
1072 LocationAttr loc, Term *term,
1073 const ExportTable &exports) {
1074 if (auto *var = dyn_cast<VariableTerm>(term)) {
1075 ensureSolved(info, ns, typeID, ip, loc, var, pending);
1076 return;
1077 }
1078 if (auto *val = dyn_cast<ValueTerm>(term)) {
1079 ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
1080 return;
1081 }
1082 llvm_unreachable("invalid domain association");
1083}
1084
1086 const DomainInfo &info, Namespace &ns, const ExportTable &exports,
1087 size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
1088 for (auto [index, term] : llvm::enumerate(row->elements))
1089 getUpdatesForDomainAssociationOfPort(info, ns, pending, DomainTypeID{index},
1090 ip, loc, find(term), exports);
1091}
1092
/// Walk the ports of `moduleOp` and, for each hardware (FIRRTL base type)
/// port, queue into `pending` the signature changes needed for its row of
/// domain associations. The row comes from getDomainAssociationAsRow. Ports of
/// any other type are skipped.
static void getUpdatesForModulePorts(const DomainInfo &info,
                                     TermAllocator &allocator,
                                     const ExportTable &exports,
                                     DomainTable &table, Namespace &ns,
                                     FModuleOp moduleOp,
                                     PendingUpdates &pending) {
  for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
    auto port = moduleOp.getArgument(i);
    auto type = port.getType();
    if (!isa<FIRRTLBaseType>(type))
      continue;
    // NOTE(review): the line naming the callee of the call below appears to
    // have been lost from this copy of the file (only its argument list
    // remains). Judging by the arguments, it is the per-row update helper
    // taking (info, ns, exports, ip, loc, RowTerm *, pending). Restore it from
    // upstream.
        info, ns, exports, i, moduleOp.getPortLocation(i),
        getDomainAssociationAsRow(info, allocator, table, port), pending);
  }
}
1109
1110static void getUpdatesForModule(const DomainInfo &info,
1111 TermAllocator &allocator,
1112 const ExportTable &exports, DomainTable &table,
1113 FModuleOp mod, PendingUpdates &pending) {
1114 Namespace ns;
1115 auto names = mod.getPortNamesAttr();
1116 for (auto name : names.getAsRange<StringAttr>())
1117 ns.add(name);
1118 getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending);
1119}
1120
1121static void applyUpdatesToModule(const DomainInfo &info,
1122 TermAllocator &allocator, ExportTable &exports,
1123 DomainTable &table, FModuleOp moduleOp,
1124 const PendingUpdates &pending) {
1125 // Put the domain ports in place.
1126 moduleOp.insertPorts(pending.insertions);
1127
1128 // Solve any variables and record them as "self-exporting".
1129 for (auto [var, portIndex] : pending.solutions) {
1130 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1131 auto *solution = allocator.allocVal(portValue);
1132 solve(var, solution);
1133 exports[portValue].push_back(portValue);
1134 }
1135
1136 // Drive the output ports, and record the export.
1137 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1138 for (auto [domainValue, portIndex] : pending.exports) {
1139 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1140 builder.setInsertionPointAfterValue(domainValue);
1141 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1142
1143 exports[domainValue].push_back(portValue);
1144 table.setTermForDomain(portValue, allocator.allocVal(domainValue));
1145 }
1146}
1147
1148/// Copy the domain associations from the moduleOp domain info attribute into a
1149/// small vector.
1150static SmallVector<Attribute>
1151copyPortDomainAssociations(const DomainInfo &info, ArrayAttr moduleDomainInfo,
1152 size_t portIndex) {
1153 SmallVector<Attribute> result(info.getNumDomains());
1154 auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex);
1155 for (auto domainPortIndexAttr : oldAssociations) {
1156
1157 auto domainPortIndex = domainPortIndexAttr.getUInt();
1158 auto domainTypeID = info.getDomainTypeID(moduleDomainInfo, domainPortIndex);
1159 result[domainTypeID.index] = domainPortIndexAttr;
1160 };
1161 return result;
1162}
1163
1164// If the port is an output domain, we may need to drive the output with
1165// a value. If we don't know what value to drive to the port, error.
1166static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info,
1167 const DomainTable &table,
1168 FModuleOp moduleOp) {
1169 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1170 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1171 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1172 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1173 isDriven(port))
1174 continue;
1175
1176 auto *term = table.getOptTermForDomain(port);
1177 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1178 if (!val) {
1179 emitDomainPortInferenceError(moduleOp, i);
1180 return failure();
1181 }
1182
1183 auto loc = port.getLoc();
1184 auto value = val->value;
1185 DomainDefineOp::create(builder, loc, port, value);
1186 }
1187
1188 return success();
1189}
1190
1191/// After generalizing the moduleOp, all domains should be solved. Reflect the
1192/// solved domain associations into the port domain info attribute.
1193static LogicalResult updateModuleDomainInfo(const DomainInfo &info,
1194 const DomainTable &table,
1195 const ExportTable &exportTable,
1196 ArrayAttr &result,
1197 FModuleOp moduleOp) {
1198 // At this point, all domain variables mentioned in ports have been
1199 // solved by generalizing the moduleOp (adding input domain ports). Now, we
1200 // have to form the new port domain information for the moduleOp by examining
1201 // the the associated domains of each port.
1202 auto *context = moduleOp.getContext();
1203 auto numDomains = info.getNumDomains();
1204 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1205 auto numPorts = moduleOp.getNumPorts();
1206 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1207
1208 for (size_t i = 0; i < numPorts; ++i) {
1209 auto port = moduleOp.getArgument(i);
1210 auto type = port.getType();
1211
1212 if (isa<DomainType>(type)) {
1213 newModuleDomainInfo[i] = oldModuleDomainInfo[i];
1214 continue;
1215 }
1216
1217 if (isa<FIRRTLBaseType>(type)) {
1218 auto associations =
1219 copyPortDomainAssociations(info, oldModuleDomainInfo, i);
1220 auto *row = cast<RowTerm>(table.getDomainAssociation(port));
1221 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1222 auto domainTypeID = DomainTypeID{domainIndex};
1223 if (associations[domainIndex])
1224 continue;
1225
1226 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1227 auto &exports = exportTable.at(domain);
1228 if (exports.empty()) {
1229 auto portName = moduleOp.getPortNameAttr(i);
1230 auto portLoc = moduleOp.getPortLocation(i);
1231 auto domainDecl = info.getDomain(domainTypeID);
1232 auto domainName = domainDecl.getNameAttr();
1233 auto diag = emitError(portLoc)
1234 << "private " << domainName << " association for port "
1235 << portName;
1236 diag.attachNote(domain.getLoc()) << "associated domain: " << domain;
1237 noteLocation(diag, moduleOp);
1238 return failure();
1239 }
1240
1241 if (exports.size() > 1) {
1242 emitAmbiguousPortDomainAssociation(info, moduleOp, exports,
1243 domainTypeID, i);
1244 return failure();
1245 }
1246
1247 auto argument = cast<BlockArgument>(exports[0]);
1248 auto domainPortIndex = argument.getArgNumber();
1249 associations[domainTypeID.index] = IntegerAttr::get(
1250 IntegerType::get(context, 32, IntegerType::Unsigned),
1251 domainPortIndex);
1252 }
1253
1254 newModuleDomainInfo[i] = ArrayAttr::get(context, associations);
1255 continue;
1256 }
1257
1258 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1259 }
1260
1261 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1262 moduleOp.setDomainInfoAttr(result);
1263 return success();
1264}
1265
1266template <typename T>
1267static LogicalResult updateInstance(const DomainInfo &info,
1268 TermAllocator &allocator,
1269 DomainTable &table, T op) {
1270 OpBuilder builder(op.getContext());
1271 builder.setInsertionPointAfter(op);
1272 auto numPorts = op->getNumResults();
1273 for (size_t i = 0; i < numPorts; ++i) {
1274 auto port = dyn_cast<DomainValue>(op.getResult(i));
1275 auto direction = op.getPortDirection(i);
1276
1277 // If the port is an input domain, we may need to drive the input with
1278 // a value. If we don't know what value to drive to the port, drive an
1279 // anonymous domain.
1280 if (port && direction == Direction::In && !isDriven(port)) {
1281 auto loc = port.getLoc();
1282 auto *term = getTermForDomain(allocator, table, port);
1283 if (auto *var = dyn_cast<VariableTerm>(term)) {
1284 auto name = getDomainPortTypeName(op.getDomainInfo(), i);
1285 auto anon = DomainCreateAnonOp::create(builder, loc, name);
1286 solve(var, allocator.allocVal(anon));
1287 DomainDefineOp::create(builder, loc, port, anon);
1288 continue;
1289 }
1290 if (auto *val = dyn_cast<ValueTerm>(term)) {
1291 auto value = val->value;
1292 DomainDefineOp::create(builder, loc, port, value);
1293 continue;
1294 }
1295 llvm_unreachable("unhandled domain term type");
1296 }
1297 }
1298
1299 return success();
1300}
1301
1302static LogicalResult updateOp(const DomainInfo &info, TermAllocator &allocator,
1303 DomainTable &table, Operation *op) {
1304 if (auto instance = dyn_cast<InstanceOp>(op))
1305 return updateInstance(info, allocator, table, instance);
1306 if (auto instance = dyn_cast<InstanceChoiceOp>(op))
1307 return updateInstance(info, allocator, table, instance);
1308 return success();
1309}
1310
1311/// After updating the port domain associations, walk the body of the moduleOp
1312/// to fix up any child instance modules.
1313static LogicalResult updateModuleBody(const DomainInfo &info,
1314 TermAllocator &allocator,
1315 DomainTable &table, FModuleOp moduleOp) {
1316 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1317 return updateOp(info, allocator, table, op);
1318 });
1319 return failure(result.wasInterrupted());
1320}
1321
1322/// Write the domain associations recorded in the domain table back to the IR.
1323static LogicalResult updateModule(const DomainInfo &info,
1324 TermAllocator &allocator, DomainTable &table,
1325 ModuleUpdateTable &updates, FModuleOp op) {
1326 auto exports = initializeExportTable(table, op);
1327 PendingUpdates pending;
1328 getUpdatesForModule(info, allocator, exports, table, op, pending);
1329 applyUpdatesToModule(info, allocator, exports, table, op, pending);
1330
1331 // Update the domain info for the moduleOp's ports.
1332 ArrayAttr portDomainInfo;
1333 if (failed(updateModuleDomainInfo(info, table, exports, portDomainInfo, op)))
1334 return failure();
1335
1336 // Drive output domain ports.
1337 if (failed(driveModuleOutputDomainPorts(info, table, op)))
1338 return failure();
1339
1340 // Record the updated interface change in the update table.
1341 auto &entry = updates[op.getModuleNameAttr()];
1342 entry.portDomainInfo = portDomainInfo;
1343 entry.portInsertions = std::move(pending.insertions);
1344
1345 if (failed(updateModuleBody(info, allocator, table, op)))
1346 return failure();
1347
1348 return success();
1349}
1350
1351//===---------------------------------------------------------------------------
1352// Checking: Check that a moduleOp has complete domain information.
1353//===---------------------------------------------------------------------------
1354
1355/// Check that a module's hardware ports have complete domain associations.
1356static LogicalResult checkModulePorts(const DomainInfo &info,
1357 FModuleLike moduleOp) {
1358 auto numDomains = info.getNumDomains();
1359 auto domainInfo = moduleOp.getDomainInfoAttr();
1360 auto numPorts = moduleOp.getNumPorts();
1361
1362 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1363 for (size_t i = 0; i < numPorts; ++i) {
1364 if (isa<DomainType>(moduleOp.getPortType(i)))
1365 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
1366 }
1367
1368 for (size_t i = 0; i < numPorts; ++i) {
1369 auto type = type_dyn_cast<FIRRTLBaseType>(moduleOp.getPortType(i));
1370 if (!type)
1371 continue;
1372
1373 // Record the domain associations of this port.
1374 SmallVector<IntegerAttr> associations(numDomains);
1375 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1376 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1377 auto prevDomainPortIndex = associations[domainTypeID.index];
1378 if (prevDomainPortIndex) {
1379 emitDuplicatePortDomainError(info, moduleOp, i, domainTypeID,
1380 prevDomainPortIndex, domainPortIndex);
1381 return failure();
1382 }
1383 associations[domainTypeID.index] = domainPortIndex;
1384 }
1385
1386 // Check the associations for completeness.
1387 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1388 auto typeID = DomainTypeID{domainIndex};
1389 if (!associations[domainIndex]) {
1390 emitMissingPortDomainAssociationError(info, moduleOp, typeID, i);
1391 return failure();
1392 }
1393 }
1394 }
1395
1396 return success();
1397}
1398
1399/// Check that output domain ports are driven.
1400static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info,
1401 FModuleOp moduleOp) {
1402 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1403 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1404 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1405 isDriven(port))
1406 continue;
1407
1408 auto name = moduleOp.getPortNameAttr(i);
1409 auto diag = emitError(moduleOp.getPortLocation(i))
1410 << "undriven domain port " << name;
1411 noteLocation(diag, moduleOp);
1412 return failure();
1413 }
1414
1415 return success();
1416}
1417
1418/// Check that the input domain ports are driven.
1419template <typename T>
1420static LogicalResult checkInstanceDomainPortDrivers(T op) {
1421 for (size_t i = 0, e = op.getNumResults(); i < e; ++i) {
1422 auto port = dyn_cast<DomainValue>(op.getResult(i));
1423 auto type = port.getType();
1424 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1425 isDriven(port))
1426 continue;
1427
1428 auto name = op.getPortNameAttr(i);
1429 auto diag = emitError(op.getPortLocation(i))
1430 << "undriven domain port " << name;
1431 noteLocation(diag, op);
1432 return failure();
1433 }
1434
1435 return success();
1436}
1437
1438static LogicalResult checkOp(Operation *op) {
1439 if (auto inst = dyn_cast<InstanceOp>(op))
1440 return checkInstanceDomainPortDrivers(inst);
1441 if (auto inst = dyn_cast<InstanceChoiceOp>(op))
1442 return checkInstanceDomainPortDrivers(inst);
1443 return success();
1444}
1445
1446/// Check that instances under this module have driven domain input ports.
1447static LogicalResult checkModuleBody(FModuleOp moduleOp) {
1448 auto result = moduleOp.getBody().walk(
1449 [&](Operation *op) -> WalkResult { return checkOp(op); });
1450 return failure(result.wasInterrupted());
1451}
1452
1453//===---------------------------------------------------------------------------
1454// InferDomainsPass: Top-level pass implementation.
1455//===---------------------------------------------------------------------------
1456
1457/// Solve for domains and then write the domain associations back to the IR.
1458static LogicalResult inferModule(const DomainInfo &info,
1459 ModuleUpdateTable &updates,
1460 FModuleOp moduleOp) {
1461 TermAllocator allocator;
1462 DomainTable table;
1463
1464 if (failed(processModule(info, allocator, table, updates, moduleOp)))
1465 return failure();
1466
1467 return updateModule(info, allocator, table, updates, moduleOp);
1468}
1469
1470/// Check that a module's ports are fully annotated, before performing domain
1471/// inference on the module.
1472static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp) {
1473 if (failed(checkModulePorts(info, moduleOp)))
1474 return failure();
1475
1476 if (failed(checkModuleDomainPortDrivers(info, moduleOp)))
1477 return failure();
1478
1479 if (failed(checkModuleBody(moduleOp)))
1480 return failure();
1481
1482 TermAllocator allocator;
1483 DomainTable table;
1484 ModuleUpdateTable updateTable;
1485 return processModule(info, allocator, table, updateTable, moduleOp);
1486}
1487
1488/// Check that an extmodule's ports are fully annotated.
1489static LogicalResult checkModule(const DomainInfo &info,
1490 FExtModuleOp moduleOp) {
1491 return checkModulePorts(info, moduleOp);
1492}
1493
1494/// Check that a module's ports are fully annotated, before performing domain
1495/// inference on the module. We use this when private module interfaces are
1496/// inferred but public module interfaces are checked.
1497static LogicalResult checkAndInferModule(const DomainInfo &info,
1498 ModuleUpdateTable &updateTable,
1499 FModuleOp moduleOp) {
1500 if (failed(checkModulePorts(info, moduleOp)))
1501 return failure();
1502
1503 TermAllocator allocator;
1504 DomainTable table;
1505 if (failed(processModule(info, allocator, table, updateTable, moduleOp)))
1506 return failure();
1507
1508 if (failed(driveModuleOutputDomainPorts(info, table, moduleOp)))
1509 return failure();
1510
1511 return updateModuleBody(info, allocator, table, moduleOp);
1512}
1513
1514static LogicalResult runOnModuleLike(InferDomainsMode mode,
1515 const DomainInfo &info,
1516 ModuleUpdateTable &updateTable,
1517 Operation *op) {
1518 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
1519 if (mode == InferDomainsMode::Check)
1520 return checkModule(info, moduleOp);
1521
1522 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1523 return inferModule(info, updateTable, moduleOp);
1524
1525 return checkAndInferModule(info, updateTable, moduleOp);
1526 }
1527
1528 if (auto extModule = dyn_cast<FExtModuleOp>(op))
1529 return checkModule(info, extModule);
1530
1531 return success();
1532}
1533
1534namespace {
1535struct InferDomainsPass
1536 : public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1537 using Base::Base;
1538 void runOnOperation() override {
1540 auto circuit = getOperation();
1541 auto &instanceGraph = getAnalysis<InstanceGraph>();
1542 DomainInfo info(circuit);
1543 ModuleUpdateTable updateTable;
1544 DenseSet<Operation *> errored;
1545 instanceGraph.walkPostOrder([&](auto &node) {
1546 auto moduleOp = node.getModule();
1547 for (auto *inst : node) {
1548 if (errored.contains(inst->getTarget()->getModule())) {
1549 errored.insert(moduleOp);
1550 return;
1551 }
1552 }
1553 if (failed(runOnModuleLike(mode, info, updateTable, node.getModule())))
1554 errored.insert(moduleOp);
1555 });
1556 if (errored.size())
1557 signalPassFailure();
1558 }
1559};
1560} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult checkOp(Operation *op)
static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, InstanceOp op)
static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op)
Write the domain associations recorded in the domain table back to the IR.
static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1, IntegerAttr domainPortIndexAttr2)
static ExportTable initializeExportTable(const DomainTable &table, FModuleOp moduleOp)
Build a table of exported domains: a map from domains defined internally, to their set of aliasing ou...
static LogicalResult processModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
static void emitAmbiguousPortDomainAssociation(const DomainInfo &info, T op, const llvm::TinyPtrVector< DomainValue > &exports, DomainTypeID typeID, size_t i)
static LogicalResult processModulePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
static LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updates, FModuleOp moduleOp)
Solve for domains and then write the domain associations back to the IR.
static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info, const DomainTable &table, FModuleOp moduleOp)
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
llvm::MapVector< DomainValue, unsigned > PendingExports
A map from local domains to an aliasing port index, where that port has not yet been created.
static LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op)
mlir::TypedValue< DomainType > DomainValue
static RowTerm * getDomainAssociationAsRow(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Value value)
Get the row of domains that a hardware value in the IR is associated with.
static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, DomainTypeID typeID, size_t i)
static void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, Namespace &ns, FModuleOp moduleOp, PendingUpdates &pending)
static T fixInstancePorts(T op, const ModuleUpdateInfo &update)
Apply the port changes of a moduleOp onto an instance-like op.
static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
After updating the port domain associations, walk the body of the moduleOp to fix up any child instan...
static void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, Term *term)
static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i)
From a domain info attribute, get the domain-type of a domain value at index i.
static void processDomainDefinition(TermAllocator &allocator, DomainTable &table, DomainValue domain)
static LogicalResult unify(Term *lhs, Term *rhs)
static LogicalResult updateModuleDomainInfo(const DomainInfo &info, const DomainTable &table, const ExportTable &exportTable, ArrayAttr &result, FModuleOp moduleOp)
After generalizing the moduleOp, all domains should be solved.
static LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op, Value lhs, Value rhs)
Unify the associated domain rows of two terms.
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, FModuleOp moduleOp)
Check that output domain ports are driven.
static SmallVector< Attribute > copyPortDomainAssociations(const DomainInfo &info, ArrayAttr moduleDomainInfo, size_t portIndex)
Copy the domain associations from the moduleOp domain info attribute into a small vector.
static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op)
static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, Term *term1, Term *term2)
static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, Term *term, const ExportTable &exports)
static void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator, ExportTable &exports, DomainTable &table, FModuleOp moduleOp, const PendingUpdates &pending)
static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, FModuleOp mod, PendingUpdates &pending)
static LogicalResult updateOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op)
static LogicalResult checkInstanceDomainPortDrivers(T op)
Check that the input domain ports are driven.
static void emitDomainPortInferenceError(T op, size_t i)
Emit an error when we fail to infer the concrete domain to drive to a domain port.
static void ensureExported(const DomainInfo &info, Namespace &ns, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending)
Ensure that the domain value is available in the signature of the moduleOp, so that subsequent hardwa...
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static Term * getTermForDomain(TermAllocator &allocator, DomainTable &table, DomainValue value)
Get the corresponding term for a domain in the IR.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
DenseMap< StringAttr, ModuleUpdateInfo > ModuleUpdateTable
static void ensureSolved(const DomainInfo &info, Namespace &ns, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending)
If var is not solved, solve it by recording a pending input port at the indicated insertion point.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static void solve(Term *lhs, Term *rhs)
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static LogicalResult checkModuleBody(FModuleOp moduleOp)
Check that instances under this module have driven domain input ports.
static LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike moduleOp)
Check that a module's hardware ports have complete domain associations.
static Term * find(Term *x)
static LogicalResult updateInstance(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult checkAndInferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class represents a reference to a specific field or element of an aggregate value.
Definition FieldRef.h:28
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
InferDomainsMode
The mode for the InferDomains pass.
Definition Passes.h:73
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.