CIRCT 23.0.0git
Loading...
Searching...
No Matches
InferDomains.cpp
Go to the documentation of this file.
1//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// InferDomains implements FIRRTL domain inference and checking. This pass is a
10// bottom-up transform acting on modules. For each moduleOp, we ensure there are
11// no domain crossings, and we make explicit the domain associations of ports.
12//
13//===----------------------------------------------------------------------===//
14
19#include "circt/Support/Debug.h"
21#include "mlir/IR/Iterators.h"
22#include "mlir/IR/Threading.h"
23#include "llvm/ADT/DenseMap.h"
24#include "llvm/ADT/STLExtras.h"
25#include "llvm/ADT/SmallVector.h"
26#include "llvm/ADT/TinyPtrVector.h"
27
28#define DEBUG_TYPE "firrtl-infer-domains"
29
30namespace circt {
31namespace firrtl {
32#define GEN_PASS_DEF_INFERDOMAINS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
34} // namespace firrtl
35} // namespace circt
36
37using namespace circt;
38using namespace firrtl;
39
40using llvm::concat;
41using mlir::ReverseIterator;
42
43//====--------------------------------------------------------------------------
44// Helpers.
45//====--------------------------------------------------------------------------
46
/// A value of FIRRTL domain type.
using DomainValue = mlir::TypedValue<DomainType>;

/// A list of ports to insert, each paired with an unsigned position
/// (presumably the port index at which the port is inserted — confirm against
/// cloneWithInsertedPortsAndReplaceUses).
using PortInsertions = SmallVector<std::pair<unsigned, PortInfo>>;
50
51/// From a domain info attribute, get the domain-type of a domain value at
52/// index i.
53static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i) {
54 if (info.empty())
55 return nullptr;
56 return cast<FlatSymbolRefAttr>(info[i]).getAttr();
57}
58
59/// From a domain info attribute, get the row of associated domains for a
60/// hardware value at index i.
61static auto getPortDomainAssociation(ArrayAttr info, size_t i) {
62 if (info.empty())
63 return info.getAsRange<IntegerAttr>();
64 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
65}
66
67/// Return true if the value is a port on the module.
68static bool isPort(BlockArgument arg) {
69 return isa<FModuleOp>(arg.getOwner()->getParentOp());
70}
71
72/// Return true if the value is a port on the module.
73static bool isPort(Value value) {
74 auto arg = dyn_cast<BlockArgument>(value);
75 if (!arg)
76 return false;
77 return isPort(arg);
78}
79
80/// Returns true if the value is driven by a connect op.
81static bool isDriven(DomainValue port) {
82 for (auto *user : port.getUsers())
83 if (auto connect = dyn_cast<FConnectLike>(user))
84 if (connect.getDest() == port)
85 return true;
86 return false;
87}
88
89//====--------------------------------------------------------------------------
90// Global State.
91//====--------------------------------------------------------------------------
92
/// Each domain type declared in the circuit is assigned a type-id, based on the
/// order of declaration. Domain associations for hardware values are
/// represented as a list, or row, of domains. The domains in a row are ordered
/// according to their type's id.
namespace {
struct DomainTypeID {
  /// Position of the domain's declaration among all domain declarations in the
  /// circuit; also the position of that domain type within a row.
  size_t index;
};
} // namespace
102
103/// Information about the domains in the circuit. Able to map domains to their
104/// type ID, which in this pass is the canonical way to reference the type
105/// of a domain, as well as provide fast access to domain ops.
106namespace {
107class DomainInfo {
108public:
109 DomainInfo(CircuitOp circuit) { processCircuit(circuit); }
110
111 ArrayRef<DomainOp> getDomains() const { return domainTable; }
112 size_t getNumDomains() const { return domainTable.size(); }
113 DomainOp getDomain(DomainTypeID id) const { return domainTable[id.index]; }
114
115 DomainTypeID getDomainTypeID(StringAttr name) const {
116 return typeIDTable.at(name);
117 }
118
119 DomainTypeID getDomainTypeID(FlatSymbolRefAttr ref) const {
120 return getDomainTypeID(ref.getAttr());
121 }
122
123 DomainTypeID getDomainTypeID(ArrayAttr info, size_t i) const {
124 auto name = getDomainPortTypeName(info, i);
125 return getDomainTypeID(name);
126 }
127
128 DomainTypeID getDomainTypeID(DomainValue value) const {
129 if (auto arg = dyn_cast<BlockArgument>(value)) {
130 auto *block = arg.getOwner();
131 auto *owner = block->getParentOp();
132 auto moduleOp = cast<FModuleOp>(owner);
133 auto info = moduleOp.getDomainInfoAttr();
134 auto i = arg.getArgNumber();
135 return getDomainTypeID(info, i);
136 }
137
138 auto result = dyn_cast<OpResult>(value);
139 auto *owner = result.getOwner();
140
141 auto info = TypeSwitch<Operation *, ArrayAttr>(owner)
142 .Case<InstanceOp, InstanceChoiceOp>(
143 [&](auto inst) { return inst.getDomainInfoAttr(); })
144 .Default([&](auto inst) { return nullptr; });
145 assert(info && "unable to obtain domain information from op");
146
147 auto i = result.getResultNumber();
148 return getDomainTypeID(info, i);
149 }
150
151private:
152 void processDomain(DomainOp op) {
153 auto index = domainTable.size();
154 auto name = op.getNameAttr();
155 domainTable.push_back(op);
156 typeIDTable.insert({name, {index}});
157 }
158
159 void processCircuit(CircuitOp circuit) {
160 for (auto decl : circuit.getOps<DomainOp>())
161 processDomain(decl);
162 }
163
164 /// A map from domain type ID to op.
165 SmallVector<DomainOp> domainTable;
166
167 /// A map from domain name to type ID.
168 DenseMap<StringAttr, DomainTypeID> typeIDTable;
169};
170
171/// Information about the changes made to the interface of a moduleOp, which can
172/// be replayed onto an instance.
173struct ModuleUpdateInfo {
174 /// The updated domain information for a moduleOp.
175 ArrayAttr portDomainInfo;
176 /// The domain ports which have been inserted into a moduleOp.
177 PortInsertions portInsertions;
178};
179} // namespace
180
/// A map from a module's name to the interface changes made to that module,
/// which are replayed onto instances that reference it.
using ModuleUpdateTable = DenseMap<StringAttr, ModuleUpdateInfo>;
182
183/// Apply the port changes of a moduleOp onto an instance-like op.
184template <typename T>
185static T fixInstancePorts(T op, const ModuleUpdateInfo &update) {
186 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
187 clone.setDomainInfoAttr(update.portDomainInfo);
188 op->erase();
189 return clone;
190}
191
192//====--------------------------------------------------------------------------
193// Terms: Syntax for unifying domain and domain-rows.
194//====--------------------------------------------------------------------------
195
196/// The different sorts of terms in the unification engine.
197namespace {
198enum class TermKind {
199 Variable,
200 Value,
201 Row,
202};
203} // namespace
204
/// A term in the unification engine. Concrete term kinds derive from this via
/// TermBase; `kind` records which kind a term is, enabling LLVM-style
/// isa/cast/dyn_cast on terms.
namespace {
struct Term {
  constexpr Term(TermKind kind) : kind(kind) {}
  /// The dynamic kind of this term.
  TermKind kind;
};
} // namespace
212
/// Helper to define a term kind. Sets the base `kind` to K on construction and
/// provides the `classof` hook used by isa/cast/dyn_cast, which matches any
/// term whose `kind` equals K.
namespace {
template <TermKind K>
struct TermBase : Term {
  static bool classof(const Term *term) { return term->kind == K; }
  TermBase() : Term(K) {}
};
} // namespace
221
222/// An unknown value.
223namespace {
224struct VariableTerm : public TermBase<TermKind::Variable> {
225 VariableTerm() : leader(nullptr) {}
226 VariableTerm(Term *leader) : leader(leader) {}
227 Term *leader;
228};
229} // namespace
230
/// A concrete value defined in the IR: a known domain.
namespace {
struct ValueTerm : public TermBase<TermKind::Value> {
  ValueTerm(DomainValue value) : value(value) {}
  /// The IR value that defines this domain.
  DomainValue value;
};
} // namespace
238
/// A row of domains. Element `i` is the term for the domain type whose
/// DomainTypeID index is `i`.
namespace {
struct RowTerm : public TermBase<TermKind::Row> {
  RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
  /// Non-owning view of the element terms; the storage must outlive the row.
  ArrayRef<Term *> elements;
};
} // namespace
246
247// NOLINTNEXTLINE(misc-no-recursion)
248static Term *find(Term *x) {
249 if (!x)
250 return nullptr;
251
252 if (auto *var = dyn_cast<VariableTerm>(x)) {
253 if (var->leader == nullptr)
254 return var;
255
256 auto *leader = find(var->leader);
257 if (leader != var->leader)
258 var->leader = leader;
259 return leader;
260 }
261
262 return x;
263}
264
265/// A helper for assigning low numeric IDs to variables for user-facing output.
266namespace {
267class VariableIDTable {
268public:
269 size_t get(VariableTerm *term) {
270 return table.insert({term, table.size() + 1}).first->second;
271 }
272
273private:
274 DenseMap<VariableTerm *, size_t> table;
275};
276} // namespace
277
278// NOLINTNEXTLINE(misc-no-recursion)
279static void render(const DomainInfo &info, Diagnostic &out,
280 VariableIDTable &idTable, Term *term) {
281 term = find(term);
282 if (auto *var = dyn_cast<VariableTerm>(term)) {
283 out << "?" << idTable.get(var);
284 return;
285 }
286 if (auto *val = dyn_cast<ValueTerm>(term)) {
287 auto value = val->value;
288 auto [name, _] = getFieldName(FieldRef(value, 0), false);
289 out << name;
290 return;
291 }
292 if (auto *row = dyn_cast<RowTerm>(term)) {
293 bool first = true;
294 out << "[";
295 for (size_t i = 0, e = info.getNumDomains(); i < e; ++i) {
296 auto domainOp = info.getDomain(DomainTypeID{i});
297 if (!first) {
298 out << ", ";
299 first = false;
300 }
301 out << domainOp.getName() << ": ";
302 render(info, out, idTable, row->elements[i]);
303 }
304 out << "]";
305 return;
306 }
307}
308
309static LogicalResult unify(Term *lhs, Term *rhs);
310
311static LogicalResult unify(VariableTerm *x, Term *y) {
312 assert(!x->leader);
313 x->leader = y;
314 return success();
315}
316
317static LogicalResult unify(ValueTerm *xv, Term *y) {
318 if (auto *yv = dyn_cast<VariableTerm>(y)) {
319 yv->leader = xv;
320 return success();
321 }
322
323 if (auto *yv = dyn_cast<ValueTerm>(y))
324 return success(xv == yv);
325
326 return failure();
327}
328
329// NOLINTNEXTLINE(misc-no-recursion)
330static LogicalResult unify(RowTerm *lhsRow, Term *rhs) {
331 if (auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
332 rhsVar->leader = lhsRow;
333 return success();
334 }
335 if (auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
336 for (auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
337 if (failed(unify(x, y)))
338 return failure();
339 return success();
340 }
341 return failure();
342}
343
344// NOLINTNEXTLINE(misc-no-recursion)
345static LogicalResult unify(Term *lhs, Term *rhs) {
346 if (!lhs || !rhs)
347 return success();
348 lhs = find(lhs);
349 rhs = find(rhs);
350 if (lhs == rhs)
351 return success();
352 if (auto *lhsVar = dyn_cast<VariableTerm>(lhs))
353 return unify(lhsVar, rhs);
354 if (auto *lhsVal = dyn_cast<ValueTerm>(lhs))
355 return unify(lhsVal, rhs);
356 if (auto *lhsRow = dyn_cast<RowTerm>(lhs))
357 return unify(lhsRow, rhs);
358 return failure();
359}
360
361static void solve(Term *lhs, Term *rhs) {
362 [[maybe_unused]] auto result = unify(lhs, rhs);
363 assert(result.succeeded());
364}
365
366namespace {
367class TermAllocator {
368public:
369 /// Allocate a row of fresh domain variables.
370 [[nodiscard]] RowTerm *allocRow(size_t size) {
371 SmallVector<Term *> elements;
372 elements.resize(size);
373 return allocRow(elements);
374 }
375
376 /// Allocate a row of terms.
377 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements) {
378 auto ds = allocArray(elements);
379 return alloc<RowTerm>(ds);
380 }
381
382 /// Allocate a fresh variable.
383 [[nodiscard]] VariableTerm *allocVar() { return alloc<VariableTerm>(); }
384
385 /// Allocate a concrete domain.
386 [[nodiscard]] ValueTerm *allocVal(DomainValue value) {
387 return alloc<ValueTerm>(value);
388 }
389
390private:
391 template <typename T, typename... Args>
392 [[nodiscard]] T *alloc(Args &&...args) {
393 static_assert(std::is_base_of_v<Term, T>, "T must be a term");
394 return new (allocator) T(std::forward<Args>(args)...);
395 }
396
397 [[nodiscard]] ArrayRef<Term *> allocArray(ArrayRef<Term *> elements) {
398 auto size = elements.size();
399 if (size == 0)
400 return {};
401
402 auto *result = allocator.Allocate<Term *>(size);
403 llvm::uninitialized_copy(elements, result);
404 for (size_t i = 0; i < size; ++i)
405 if (!result[i])
406 result[i] = alloc<VariableTerm>();
407
408 return ArrayRef(result, size);
409 }
410
411 llvm::BumpPtrAllocator allocator;
412};
413} // namespace
414
415//====--------------------------------------------------------------------------
416// DomainTable: A mapping from IR to terms.
417//====--------------------------------------------------------------------------
418
419namespace {
420/// Tracks domain infomation for IR values.
421class DomainTable {
422public:
423 /// If the domain value is an alias, returns the domain it aliases.
424 DomainValue getOptUnderlyingDomain(DomainValue value) const {
425 auto *term = getOptTermForDomain(value);
426 if (auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
427 return val->value;
428 return nullptr;
429 }
430
431 /// Get the corresponding term for a domain in the IR, or null if unset.
432 Term *getOptTermForDomain(DomainValue value) const {
433 assert(isa<DomainType>(value.getType()));
434 auto it = termTable.find(value);
435 if (it == termTable.end())
436 return nullptr;
437 return find(it->second);
438 }
439
440 /// Get the corresponding term for a domain in the IR.
441 Term *getTermForDomain(DomainValue value) const {
442 auto *term = getOptTermForDomain(value);
443 assert(term);
444 return term;
445 }
446
447 /// Record a mapping from domain in the IR to its corresponding term.
448 void setTermForDomain(DomainValue value, Term *term) {
449 assert(term);
450 assert(!termTable.contains(value));
451 termTable.insert({value, term});
452 }
453
454 /// For a hardware value, get the term which represents the row of associated
455 /// domains. If no mapping has been defined, returns nullptr.
456 Term *getOptDomainAssociation(Value value) const {
457 assert(isa<FIRRTLBaseType>(value.getType()));
458 auto it = associationTable.find(value);
459 if (it == associationTable.end())
460 return nullptr;
461 return find(it->second);
462 }
463
464 /// For a hardware value, get the term which represents the row of associated
465 /// domains.
466 Term *getDomainAssociation(Value value) const {
467 auto *term = getOptDomainAssociation(value);
468 assert(term);
469 return term;
470 }
471
472 /// Record a mapping from a hardware value in the IR to a term which
473 /// represents the row of domains it is associated with.
474 void setDomainAssociation(Value value, Term *term) {
475 assert(isa<FIRRTLBaseType>(value.getType()));
476 assert(term);
477 term = find(term);
478 associationTable.insert({value, term});
479 }
480
481private:
482 /// Map from domains in the IR to their underlying term.
483 DenseMap<Value, Term *> termTable;
484
485 /// A map from hardware values to their associated row of domains, as a term.
486 DenseMap<Value, Term *> associationTable;
487};
488} // namespace
489
490//====--------------------------------------------------------------------------
491// Module processing: solve for the domain associations of hardware.
492//====--------------------------------------------------------------------------
493
494/// Get the corresponding term for a domain in the IR. If we don't know what the
495/// term is, then map the domain in the IR to a variable term.
496static Term *getTermForDomain(TermAllocator &allocator, DomainTable &table,
497 DomainValue value) {
498 assert(isa<DomainType>(value.getType()));
499 if (auto *term = table.getOptTermForDomain(value))
500 return term;
501 auto *term = allocator.allocVar();
502 table.setTermForDomain(value, term);
503 return term;
504}
505
506static void processDomainDefinition(TermAllocator &allocator,
507 DomainTable &table, DomainValue domain) {
508 assert(isa<DomainType>(domain.getType()));
509 auto *newTerm = allocator.allocVal(domain);
510 auto *oldTerm = table.getOptTermForDomain(domain);
511 if (!oldTerm) {
512 table.setTermForDomain(domain, newTerm);
513 return;
514 }
515
516 [[maybe_unused]] auto result = unify(oldTerm, newTerm);
517 assert(result.succeeded());
518}
519
520/// Get the row of domains that a hardware value in the IR is associated with.
521/// The returned term is forced to be at least a row.
522static RowTerm *getDomainAssociationAsRow(const DomainInfo &info,
523 TermAllocator &allocator,
524 DomainTable &table, Value value) {
525 assert(isa<FIRRTLBaseType>(value.getType()));
526 auto *term = table.getOptDomainAssociation(value);
527
528 // If the term is unknown, allocate a fresh row and set the association.
529 if (!term) {
530 auto *row = allocator.allocRow(info.getNumDomains());
531 table.setDomainAssociation(value, row);
532 return row;
533 }
534
535 // If the term is already a row, return it.
536 if (auto *row = dyn_cast<RowTerm>(term))
537 return row;
538
539 // Otherwise, unify the term with a fresh row of domains.
540 if (auto *var = dyn_cast<VariableTerm>(term)) {
541 auto *row = allocator.allocRow(info.getNumDomains());
542 solve(var, row);
543 return row;
544 }
545
546 assert(false && "unhandled term type");
547 return nullptr;
548}
549
550static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
551 auto &note = diag.attachNote(op->getLoc());
552 if (auto mod = dyn_cast<FModuleOp>(op)) {
553 note << "in module " << mod.getModuleNameAttr();
554 return;
555 }
556 if (auto mod = dyn_cast<FExtModuleOp>(op)) {
557 note << "in extmodule " << mod.getModuleNameAttr();
558 return;
559 }
560 if (auto inst = dyn_cast<InstanceOp>(op)) {
561 note << "in instance " << inst.getInstanceNameAttr();
562 return;
563 }
564 if (auto inst = dyn_cast<InstanceChoiceOp>(op)) {
565 note << "in instance_choice " << inst.getNameAttr();
566 return;
567 }
568
569 note << "here";
570}
571
572template <typename T>
573static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i,
574 DomainTypeID domainTypeID, Term *term1,
575 Term *term2) {
576 VariableIDTable idTable;
577
578 auto portName = op.getPortNameAttr(i);
579 auto portLoc = op.getPortLocation(i);
580 auto domainDecl = info.getDomain(domainTypeID);
581 auto domainName = domainDecl.getNameAttr();
582
583 auto diag = emitError(portLoc);
584 diag << "illegal " << domainName << " crossing in port " << portName;
585
586 auto &note1 = diag.attachNote();
587 note1 << "1st instance: ";
588 render(info, note1, idTable, term1);
589
590 auto &note2 = diag.attachNote();
591 note2 << "2nd instance: ";
592 render(info, note2, idTable, term2);
593
594 noteLocation(diag, op);
595}
596
597template <typename T>
598static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i,
599 DomainTypeID domainTypeID,
600 IntegerAttr domainPortIndexAttr1,
601 IntegerAttr domainPortIndexAttr2) {
602 VariableIDTable idTable;
603 auto portName = op.getPortNameAttr(i);
604 auto portLoc = op.getPortLocation(i);
605 auto domainDecl = info.getDomain(domainTypeID);
606 auto domainName = domainDecl.getNameAttr();
607 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
608 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
609 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
610 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
611 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
612 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
613 auto diag = emitError(portLoc);
614 diag << "duplicate " << domainName << " association for port " << portName;
615 auto &note1 = diag.attachNote(domainPortLoc1);
616 note1 << "associated with " << domainName << " port " << domainPortName1;
617 auto &note2 = diag.attachNote(domainPortLoc2);
618 note2 << "associated with " << domainName << " port " << domainPortName2;
619 noteLocation(diag, op);
620}
621
622/// Emit an error when we fail to infer the concrete domain to drive to a
623/// domain port.
624template <typename T>
625static void emitDomainPortInferenceError(T op, size_t i) {
626 auto name = op.getPortNameAttr(i);
627 auto diag = emitError(op->getLoc());
628 auto info = op.getDomainInfo();
629 diag << "unable to infer value for undriven domain port " << name;
630 for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
631 if (auto assocs = dyn_cast<ArrayAttr>(info[j])) {
632 for (auto assoc : assocs) {
633 if (i == cast<IntegerAttr>(assoc).getValue()) {
634 auto name = op.getPortNameAttr(j);
635 auto loc = op.getPortLocation(j);
636 diag.attachNote(loc) << "associated with hardware port " << name;
637 break;
638 }
639 }
640 }
641 }
642 noteLocation(diag, op);
643}
644
645template <typename T>
647 const DomainInfo &info, T op,
648 const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
649 size_t i) {
650 auto portName = op.getPortNameAttr(i);
651 auto portLoc = op.getPortLocation(i);
652 auto domainDecl = info.getDomain(typeID);
653 auto domainName = domainDecl.getNameAttr();
654 auto diag = emitError(portLoc) << "ambiguous " << domainName
655 << " association for port " << portName;
656 for (auto e : exports) {
657 auto arg = cast<BlockArgument>(e);
658 auto name = op.getPortNameAttr(arg.getArgNumber());
659 auto loc = op.getPortLocation(arg.getArgNumber());
660 diag.attachNote(loc) << "candidate association " << name;
661 }
662 noteLocation(diag, op);
663}
664
665template <typename T>
666static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op,
667 DomainTypeID typeID,
668 size_t i) {
669 auto domainName = info.getDomain(typeID).getNameAttr();
670 auto portName = op.getPortNameAttr(i);
671 auto diag = emitError(op.getPortLocation(i))
672 << "missing " << domainName << " association for port "
673 << portName;
674 noteLocation(diag, op);
675}
676
677/// Unify the associated domain rows of two terms.
678static LogicalResult unifyAssociations(const DomainInfo &info,
679 TermAllocator &allocator,
680 DomainTable &table, Operation *op,
681 Value lhs, Value rhs) {
682 if (!lhs || !rhs)
683 return success();
684
685 if (lhs == rhs)
686 return success();
687
688 auto *lhsTerm = table.getOptDomainAssociation(lhs);
689 auto *rhsTerm = table.getOptDomainAssociation(rhs);
690
691 if (lhsTerm) {
692 if (rhsTerm) {
693 if (failed(unify(lhsTerm, rhsTerm))) {
694 auto diag = op->emitOpError("illegal domain crossing in operation");
695 auto &note1 = diag.attachNote(lhs.getLoc());
696
697 note1 << "1st operand has domains: ";
698 VariableIDTable idTable;
699 render(info, note1, idTable, lhsTerm);
700
701 auto &note2 = diag.attachNote(rhs.getLoc());
702 note2 << "2nd operand has domains: ";
703 render(info, note2, idTable, rhsTerm);
704
705 return failure();
706 }
707 }
708 table.setDomainAssociation(rhs, lhsTerm);
709 return success();
710 }
711
712 if (rhsTerm) {
713 table.setDomainAssociation(lhs, rhsTerm);
714 return success();
715 }
716
717 auto *var = allocator.allocVar();
718 table.setDomainAssociation(lhs, var);
719 table.setDomainAssociation(rhs, var);
720 return success();
721}
722
723static LogicalResult processModulePorts(const DomainInfo &info,
724 TermAllocator &allocator,
725 DomainTable &table,
726 FModuleOp moduleOp) {
727 auto numDomains = info.getNumDomains();
728 auto domainInfo = moduleOp.getDomainInfoAttr();
729 auto numPorts = moduleOp.getNumPorts();
730
731 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
732 for (size_t i = 0; i < numPorts; ++i) {
733 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
734 if (!port)
735 continue;
736
737 if (moduleOp.getPortDirection(i) == Direction::In)
738 processDomainDefinition(allocator, table, port);
739
740 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
741 }
742
743 for (size_t i = 0; i < numPorts; ++i) {
744 BlockArgument port = moduleOp.getArgument(i);
745 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
746 if (!type)
747 continue;
748
749 SmallVector<IntegerAttr> associations(numDomains);
750 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
751 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
752 auto prevDomainPortIndex = associations[domainTypeID.index];
753 if (prevDomainPortIndex) {
754 emitDuplicatePortDomainError(info, moduleOp, i, domainTypeID,
755 prevDomainPortIndex, domainPortIndex);
756 return failure();
757 }
758 associations[domainTypeID.index] = domainPortIndex;
759 }
760
761 SmallVector<Term *> elements(numDomains);
762 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
763 ++domainTypeIndex) {
764 auto domainPortIndex = associations[domainTypeIndex];
765 if (!domainPortIndex)
766 continue;
767 auto domainPortValue =
768 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
769 elements[domainTypeIndex] =
770 getTermForDomain(allocator, table, domainPortValue);
771 }
772
773 auto *domainAssociations = allocator.allocRow(elements);
774 table.setDomainAssociation(port, domainAssociations);
775 }
776
777 return success();
778}
779
780template <typename T>
781static LogicalResult processInstancePorts(const DomainInfo &info,
782 TermAllocator &allocator,
783 DomainTable &table, T op) {
784 auto numDomains = info.getNumDomains();
785 auto domainInfo = op.getDomainInfoAttr();
786 auto numPorts = op.getNumPorts();
787
788 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
789 for (size_t i = 0; i < numPorts; ++i) {
790 auto port = dyn_cast<DomainValue>(op.getResult(i));
791 if (!port)
792 continue;
793
794 if (op.getPortDirection(i) == Direction::Out)
795 processDomainDefinition(allocator, table, port);
796
797 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
798 }
799
800 for (size_t i = 0; i < numPorts; ++i) {
801 Value port = op.getResult(i);
802 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
803 if (!type)
804 continue;
805
806 SmallVector<IntegerAttr> associations(numDomains);
807 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
808 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
809 auto prevDomainPortIndex = associations[domainTypeID.index];
810 if (prevDomainPortIndex) {
811 emitDuplicatePortDomainError(info, op, i, domainTypeID,
812 prevDomainPortIndex, domainPortIndex);
813 return failure();
814 }
815 associations[domainTypeID.index] = domainPortIndex;
816 }
817
818 SmallVector<Term *> elements(numDomains);
819 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
820 ++domainTypeIndex) {
821 auto domainPortIndex = associations[domainTypeIndex];
822 if (!domainPortIndex)
823 continue;
824 auto domainPortValue =
825 cast<DomainValue>(op.getResult(domainPortIndex.getUInt()));
826 elements[domainTypeIndex] =
827 getTermForDomain(allocator, table, domainPortValue);
828 }
829
830 auto *domainAssociations = allocator.allocRow(elements);
831 table.setDomainAssociation(port, domainAssociations);
832 }
833
834 return success();
835}
836
837static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
838 DomainTable &table,
839 const ModuleUpdateTable &updateTable,
840 InstanceOp op) {
841 auto moduleOp = op.getReferencedModuleNameAttr();
842 auto lookup = updateTable.find(moduleOp);
843 if (lookup != updateTable.end())
844 op = fixInstancePorts(op, lookup->second);
845 return processInstancePorts(info, allocator, table, op);
846}
847
848static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
849 DomainTable &table,
850 const ModuleUpdateTable &updateTable,
851 InstanceChoiceOp op) {
852 auto moduleOp = op.getDefaultTargetAttr().getAttr();
853 auto lookup = updateTable.find(moduleOp);
854 if (lookup != updateTable.end())
855 op = fixInstancePorts(op, lookup->second);
856 return processInstancePorts(info, allocator, table, op);
857}
858
859static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
860 DomainTable &table, UnsafeDomainCastOp op) {
861 auto domains = op.getDomains();
862 if (domains.empty())
863 return unifyAssociations(info, allocator, table, op, op.getInput(),
864 op.getResult());
865
866 auto input = op.getInput();
867 RowTerm *inputRow = getDomainAssociationAsRow(info, allocator, table, input);
868 SmallVector<Term *> elements(inputRow->elements);
869 for (auto value : op.getDomains()) {
870 auto domain = cast<DomainValue>(value);
871 auto typeID = info.getDomainTypeID(domain);
872 elements[typeID.index] = getTermForDomain(allocator, table, domain);
873 }
874
875 auto *row = allocator.allocRow(elements);
876 table.setDomainAssociation(op.getResult(), row);
877 return success();
878}
879
880static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
881 DomainTable &table, DomainDefineOp op) {
882 auto src = op.getSrc();
883 auto dst = op.getDest();
884 auto *srcTerm = getTermForDomain(allocator, table, src);
885 auto *dstTerm = getTermForDomain(allocator, table, dst);
886 if (succeeded(unify(dstTerm, srcTerm)))
887 return success();
888
889 VariableIDTable idTable;
890 auto diag = op->emitOpError("failed to propagate source to destination");
891 auto &note1 = diag.attachNote();
892 note1 << "destination has underlying value: ";
893 render(info, note1, idTable, dstTerm);
894
895 auto &note2 = diag.attachNote(src.getLoc());
896 note2 << "source has underlying value: ";
897 render(info, note2, idTable, srcTerm);
898 return failure();
899}
900
901static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
902 DomainTable &table,
903 const ModuleUpdateTable &updateTable,
904 Operation *op) {
905 if (auto instance = dyn_cast<InstanceOp>(op))
906 return processOp(info, allocator, table, updateTable, instance);
907 if (auto instance = dyn_cast<InstanceChoiceOp>(op))
908 return processOp(info, allocator, table, updateTable, instance);
909 if (auto cast = dyn_cast<UnsafeDomainCastOp>(op))
910 return processOp(info, allocator, table, cast);
911 if (auto def = dyn_cast<DomainDefineOp>(op))
912 return processOp(info, allocator, table, def);
913
914 // For all other operations (including connections), propagate domains from
915 // operands to results. This is a conservative approach - all operands and
916 // results share the same domain associations.
917 Value lhs;
918 for (auto rhs : op->getOperands()) {
919 if (!isa<FIRRTLBaseType>(rhs.getType()))
920 continue;
921 if (auto *op = rhs.getDefiningOp();
922 op && op->hasTrait<OpTrait::ConstantLike>())
923 continue;
924 if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs)))
925 return failure();
926 lhs = rhs;
927 }
928 for (auto rhs : op->getResults()) {
929 if (!isa<FIRRTLBaseType>(rhs.getType()))
930 continue;
931 if (auto *op = rhs.getDefiningOp();
932 op && op->hasTrait<OpTrait::ConstantLike>())
933 continue;
934 if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs)))
935 return failure();
936 lhs = rhs;
937 }
938 return success();
939}
940
941static LogicalResult processModuleBody(const DomainInfo &info,
942 TermAllocator &allocator,
943 DomainTable &table,
944 const ModuleUpdateTable &updateTable,
945 FModuleOp moduleOp) {
946 auto result = moduleOp.getBody().walk([&](Operation *op) -> WalkResult {
947 return processOp(info, allocator, table, updateTable, op);
948 });
949 return failure(result.wasInterrupted());
950}
951
952/// Populate the domain table by processing the moduleOp. If the moduleOp has
953/// any domain crossing errors, return failure.
954static LogicalResult processModule(const DomainInfo &info,
955 TermAllocator &allocator, DomainTable &table,
956 const ModuleUpdateTable &updateTable,
957 FModuleOp moduleOp) {
958 if (failed(processModulePorts(info, allocator, table, moduleOp)))
959 return failure();
960 if (failed(processModuleBody(info, allocator, table, updateTable, moduleOp)))
961 return failure();
962 return success();
963}
964
965//===---------------------------------------------------------------------------
966// ExportTable
967//===---------------------------------------------------------------------------
968
/// A map from domain IR values to the domain ports of the module that alias
/// them. These ports make the domain usable as associations of ports, and we
/// say these are exporting ports.
using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
973
974/// Build a table of exported domains: a map from domains defined internally,
975/// to their set of aliasing output ports.
976static ExportTable initializeExportTable(const DomainTable &table,
977 FModuleOp moduleOp) {
978 ExportTable exports;
979 size_t numPorts = moduleOp.getNumPorts();
980 for (size_t i = 0; i < numPorts; ++i) {
981 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
982 if (!port)
983 continue;
984 auto value = table.getOptUnderlyingDomain(port);
985 if (value)
986 exports[value].push_back(port);
987 }
988
989 return exports;
990}
991
992//====--------------------------------------------------------------------------
993// Updating: write domains back to the IR.
994//====--------------------------------------------------------------------------
995
996/// A map from unsolved variables to a port index, where that port has not yet
997/// been created. Eventually we will have an input domain at the port index,
998/// which will be the solution to the recorded variable.
999using PendingSolutions = DenseMap<VariableTerm *, unsigned>;
1000
1001/// A map from local domains to an aliasing port index, where that port has not
1002/// yet been created. Eventually we will be exporting the domain value at the
1003/// port index.
1004using PendingExports = llvm::MapVector<DomainValue, unsigned>;
1005
1006namespace {
1007struct PendingUpdates {
1008 PortInsertions insertions;
1009 PendingSolutions solutions;
1010 PendingExports exports;
1011};
1012} // namespace
1013
1014/// If `var` is not solved, solve it by recording a pending input port at
1015/// the indicated insertion point.
1016static void ensureSolved(const DomainInfo &info, Namespace &ns,
1017 DomainTypeID typeID, size_t ip, LocationAttr loc,
1018 VariableTerm *var, PendingUpdates &pending) {
1019 if (pending.solutions.contains(var))
1020 return;
1021
1022 auto *context = loc.getContext();
1023 auto domainDecl = info.getDomain(typeID);
1024 auto domainName = domainDecl.getNameAttr();
1025
1026 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1027 auto portType = DomainType::get(loc.getContext());
1028 auto portDirection = Direction::In;
1029 auto portSym = StringAttr();
1030 auto portLoc = loc;
1031 auto portAnnos = std::nullopt;
1032 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1033 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1034 portAnnos, portDomainInfo);
1035
1036 pending.solutions[var] = pending.insertions.size() + ip;
1037 pending.insertions.push_back({ip, portInfo});
1038}
1039
1040/// Ensure that the domain value is available in the signature of the moduleOp,
1041/// so that subsequent hardware ports may be associated with this domain.
1042// If the domain is defined internally in the moduleOp, ensure it is aliased by
1043// an
1044/// output port.
1045static void ensureExported(const DomainInfo &info, Namespace &ns,
1046 const ExportTable &exports, DomainTypeID typeID,
1047 size_t ip, LocationAttr loc, ValueTerm *val,
1048 PendingUpdates &pending) {
1049 auto value = val->value;
1050 assert(isa<DomainType>(value.getType()));
1051 if (isPort(value) || exports.contains(value) ||
1052 pending.exports.contains(value))
1053 return;
1054
1055 auto *context = loc.getContext();
1056
1057 auto domainDecl = info.getDomain(typeID);
1058 auto domainName = domainDecl.getNameAttr();
1059
1060 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1061 auto portType = DomainType::get(loc.getContext());
1062 auto portDirection = Direction::Out;
1063 auto portSym = StringAttr();
1064 auto portLoc = value.getLoc();
1065 auto portAnnos = std::nullopt;
1066 auto portDomainInfo = FlatSymbolRefAttr::get(domainName);
1067 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1068 portAnnos, portDomainInfo);
1069 pending.exports[value] = pending.insertions.size() + ip;
1070 pending.insertions.push_back({ip, portInfo});
1071}
1072
1073static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info,
1074 Namespace &ns,
1075 PendingUpdates &pending,
1076 DomainTypeID typeID, size_t ip,
1077 LocationAttr loc, Term *term,
1078 const ExportTable &exports) {
1079 if (auto *var = dyn_cast<VariableTerm>(term)) {
1080 ensureSolved(info, ns, typeID, ip, loc, var, pending);
1081 return;
1082 }
1083 if (auto *val = dyn_cast<ValueTerm>(term)) {
1084 ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
1085 return;
1086 }
1087 llvm_unreachable("invalid domain association");
1088}
1089
1091 const DomainInfo &info, Namespace &ns, const ExportTable &exports,
1092 size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
1093 for (auto [index, term] : llvm::enumerate(row->elements))
1094 getUpdatesForDomainAssociationOfPort(info, ns, pending, DomainTypeID{index},
1095 ip, loc, find(term), exports);
1096}
1097
1098static void getUpdatesForModulePorts(const DomainInfo &info,
1099 TermAllocator &allocator,
1100 const ExportTable &exports,
1101 DomainTable &table, Namespace &ns,
1102 FModuleOp moduleOp,
1103 PendingUpdates &pending) {
1104 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1105 auto port = moduleOp.getArgument(i);
1106 auto type = port.getType();
1107 if (!isa<FIRRTLBaseType>(type))
1108 continue;
1110 info, ns, exports, i, moduleOp.getPortLocation(i),
1111 getDomainAssociationAsRow(info, allocator, table, port), pending);
1112 }
1113}
1114
1115static void getUpdatesForModule(const DomainInfo &info,
1116 TermAllocator &allocator,
1117 const ExportTable &exports, DomainTable &table,
1118 FModuleOp mod, PendingUpdates &pending) {
1119 Namespace ns;
1120 auto names = mod.getPortNamesAttr();
1121 for (auto name : names.getAsRange<StringAttr>())
1122 ns.add(name);
1123 getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending);
1124}
1125
1126static void applyUpdatesToModule(const DomainInfo &info,
1127 TermAllocator &allocator, ExportTable &exports,
1128 DomainTable &table, FModuleOp moduleOp,
1129 const PendingUpdates &pending) {
1130 // Put the domain ports in place.
1131 moduleOp.insertPorts(pending.insertions);
1132
1133 // Solve any variables and record them as "self-exporting".
1134 for (auto [var, portIndex] : pending.solutions) {
1135 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1136 auto *solution = allocator.allocVal(portValue);
1137 solve(var, solution);
1138 exports[portValue].push_back(portValue);
1139 }
1140
1141 // Drive the output ports, and record the export.
1142 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1143 for (auto [domainValue, portIndex] : pending.exports) {
1144 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1145 builder.setInsertionPointAfterValue(domainValue);
1146 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1147
1148 exports[domainValue].push_back(portValue);
1149 table.setTermForDomain(portValue, allocator.allocVal(domainValue));
1150 }
1151}
1152
1153/// Copy the domain associations from the moduleOp domain info attribute into a
1154/// small vector.
1155static SmallVector<Attribute>
1156copyPortDomainAssociations(const DomainInfo &info, ArrayAttr moduleDomainInfo,
1157 size_t portIndex) {
1158 SmallVector<Attribute> result(info.getNumDomains());
1159 auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex);
1160 for (auto domainPortIndexAttr : oldAssociations) {
1161
1162 auto domainPortIndex = domainPortIndexAttr.getUInt();
1163 auto domainTypeID = info.getDomainTypeID(moduleDomainInfo, domainPortIndex);
1164 result[domainTypeID.index] = domainPortIndexAttr;
1165 };
1166 return result;
1167}
1168
1169// If the port is an output domain, we may need to drive the output with
1170// a value. If we don't know what value to drive to the port, error.
1171static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info,
1172 const DomainTable &table,
1173 FModuleOp moduleOp) {
1174 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1175 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1176 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1177 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1178 isDriven(port))
1179 continue;
1180
1181 auto *term = table.getOptTermForDomain(port);
1182 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1183 if (!val) {
1184 emitDomainPortInferenceError(moduleOp, i);
1185 return failure();
1186 }
1187
1188 auto loc = port.getLoc();
1189 auto value = val->value;
1190 DomainDefineOp::create(builder, loc, port, value);
1191 }
1192
1193 return success();
1194}
1195
1196/// After generalizing the moduleOp, all domains should be solved. Reflect the
1197/// solved domain associations into the port domain info attribute.
1198static LogicalResult updateModuleDomainInfo(const DomainInfo &info,
1199 const DomainTable &table,
1200 const ExportTable &exportTable,
1201 ArrayAttr &result,
1202 FModuleOp moduleOp) {
1203 // At this point, all domain variables mentioned in ports have been
1204 // solved by generalizing the moduleOp (adding input domain ports). Now, we
1205 // have to form the new port domain information for the moduleOp by examining
1206 // the the associated domains of each port.
1207 auto *context = moduleOp.getContext();
1208 auto numDomains = info.getNumDomains();
1209 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1210 auto numPorts = moduleOp.getNumPorts();
1211 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1212
1213 for (size_t i = 0; i < numPorts; ++i) {
1214 auto port = moduleOp.getArgument(i);
1215 auto type = port.getType();
1216
1217 if (isa<DomainType>(type)) {
1218 newModuleDomainInfo[i] = oldModuleDomainInfo[i];
1219 continue;
1220 }
1221
1222 if (isa<FIRRTLBaseType>(type)) {
1223 auto associations =
1224 copyPortDomainAssociations(info, oldModuleDomainInfo, i);
1225 auto *row = cast<RowTerm>(table.getDomainAssociation(port));
1226 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1227 auto domainTypeID = DomainTypeID{domainIndex};
1228 if (associations[domainIndex])
1229 continue;
1230
1231 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1232 auto &exports = exportTable.at(domain);
1233 if (exports.empty()) {
1234 auto portName = moduleOp.getPortNameAttr(i);
1235 auto portLoc = moduleOp.getPortLocation(i);
1236 auto domainDecl = info.getDomain(domainTypeID);
1237 auto domainName = domainDecl.getNameAttr();
1238 auto diag = emitError(portLoc)
1239 << "private " << domainName << " association for port "
1240 << portName;
1241 diag.attachNote(domain.getLoc()) << "associated domain: " << domain;
1242 noteLocation(diag, moduleOp);
1243 return failure();
1244 }
1245
1246 if (exports.size() > 1) {
1247 emitAmbiguousPortDomainAssociation(info, moduleOp, exports,
1248 domainTypeID, i);
1249 return failure();
1250 }
1251
1252 auto argument = cast<BlockArgument>(exports[0]);
1253 auto domainPortIndex = argument.getArgNumber();
1254 associations[domainTypeID.index] = IntegerAttr::get(
1255 IntegerType::get(context, 32, IntegerType::Unsigned),
1256 domainPortIndex);
1257 }
1258
1259 newModuleDomainInfo[i] = ArrayAttr::get(context, associations);
1260 continue;
1261 }
1262
1263 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1264 }
1265
1266 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1267 moduleOp.setDomainInfoAttr(result);
1268 return success();
1269}
1270
1271template <typename T>
1272static LogicalResult updateInstance(const DomainInfo &info,
1273 TermAllocator &allocator,
1274 DomainTable &table, T op) {
1275 OpBuilder builder(op.getContext());
1276 builder.setInsertionPointAfter(op);
1277 auto numPorts = op->getNumResults();
1278 for (size_t i = 0; i < numPorts; ++i) {
1279 auto port = dyn_cast<DomainValue>(op.getResult(i));
1280 auto direction = op.getPortDirection(i);
1281
1282 // If the port is an input domain, we may need to drive the input with
1283 // a value. If we don't know what value to drive to the port, drive an
1284 // anonymous domain.
1285 if (port && direction == Direction::In && !isDriven(port)) {
1286 auto loc = port.getLoc();
1287 auto *term = getTermForDomain(allocator, table, port);
1288 if (auto *var = dyn_cast<VariableTerm>(term)) {
1289 auto name = getDomainPortTypeName(op.getDomainInfo(), i);
1290 auto anon = DomainCreateAnonOp::create(builder, loc, name);
1291 solve(var, allocator.allocVal(anon));
1292 DomainDefineOp::create(builder, loc, port, anon);
1293 continue;
1294 }
1295 if (auto *val = dyn_cast<ValueTerm>(term)) {
1296 auto value = val->value;
1297 DomainDefineOp::create(builder, loc, port, value);
1298 continue;
1299 }
1300 llvm_unreachable("unhandled domain term type");
1301 }
1302 }
1303
1304 return success();
1305}
1306
1307static LogicalResult updateOp(const DomainInfo &info, TermAllocator &allocator,
1308 DomainTable &table, Operation *op) {
1309 if (auto instance = dyn_cast<InstanceOp>(op))
1310 return updateInstance(info, allocator, table, instance);
1311 if (auto instance = dyn_cast<InstanceChoiceOp>(op))
1312 return updateInstance(info, allocator, table, instance);
1313 return success();
1314}
1315
1316/// After updating the port domain associations, walk the body of the moduleOp
1317/// to fix up any child instance modules.
1318static LogicalResult updateModuleBody(const DomainInfo &info,
1319 TermAllocator &allocator,
1320 DomainTable &table, FModuleOp moduleOp) {
1321 auto result = moduleOp.getBodyBlock()->walk([&](Operation *op) -> WalkResult {
1322 return updateOp(info, allocator, table, op);
1323 });
1324 return failure(result.wasInterrupted());
1325}
1326
1327/// Write the domain associations recorded in the domain table back to the IR.
1328static LogicalResult updateModule(const DomainInfo &info,
1329 TermAllocator &allocator, DomainTable &table,
1330 ModuleUpdateTable &updates, FModuleOp op) {
1331 auto exports = initializeExportTable(table, op);
1332 PendingUpdates pending;
1333 getUpdatesForModule(info, allocator, exports, table, op, pending);
1334 applyUpdatesToModule(info, allocator, exports, table, op, pending);
1335
1336 // Update the domain info for the moduleOp's ports.
1337 ArrayAttr portDomainInfo;
1338 if (failed(updateModuleDomainInfo(info, table, exports, portDomainInfo, op)))
1339 return failure();
1340
1341 // Drive output domain ports.
1342 if (failed(driveModuleOutputDomainPorts(info, table, op)))
1343 return failure();
1344
1345 // Record the updated interface change in the update table.
1346 auto &entry = updates[op.getModuleNameAttr()];
1347 entry.portDomainInfo = portDomainInfo;
1348 entry.portInsertions = std::move(pending.insertions);
1349
1350 if (failed(updateModuleBody(info, allocator, table, op)))
1351 return failure();
1352
1353 return success();
1354}
1355
1356//===---------------------------------------------------------------------------
1357// Checking: Check that a moduleOp has complete domain information.
1358//===---------------------------------------------------------------------------
1359
1360/// Check that a module's hardware ports have complete domain associations.
1361static LogicalResult checkModulePorts(const DomainInfo &info,
1362 FModuleLike moduleOp) {
1363 auto numDomains = info.getNumDomains();
1364 auto domainInfo = moduleOp.getDomainInfoAttr();
1365 auto numPorts = moduleOp.getNumPorts();
1366
1367 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1368 for (size_t i = 0; i < numPorts; ++i) {
1369 if (isa<DomainType>(moduleOp.getPortType(i)))
1370 domainTypeIDTable[i] = info.getDomainTypeID(domainInfo, i);
1371 }
1372
1373 for (size_t i = 0; i < numPorts; ++i) {
1374 auto type = type_dyn_cast<FIRRTLBaseType>(moduleOp.getPortType(i));
1375 if (!type)
1376 continue;
1377
1378 // Record the domain associations of this port.
1379 SmallVector<IntegerAttr> associations(numDomains);
1380 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1381 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1382 auto prevDomainPortIndex = associations[domainTypeID.index];
1383 if (prevDomainPortIndex) {
1384 emitDuplicatePortDomainError(info, moduleOp, i, domainTypeID,
1385 prevDomainPortIndex, domainPortIndex);
1386 return failure();
1387 }
1388 associations[domainTypeID.index] = domainPortIndex;
1389 }
1390
1391 // Check the associations for completeness.
1392 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1393 auto typeID = DomainTypeID{domainIndex};
1394 if (!associations[domainIndex]) {
1395 emitMissingPortDomainAssociationError(info, moduleOp, typeID, i);
1396 return failure();
1397 }
1398 }
1399 }
1400
1401 return success();
1402}
1403
1404/// Check that output domain ports are driven.
1405static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info,
1406 FModuleOp moduleOp) {
1407 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1408 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1409 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1410 isDriven(port))
1411 continue;
1412
1413 auto name = moduleOp.getPortNameAttr(i);
1414 auto diag = emitError(moduleOp.getPortLocation(i))
1415 << "undriven domain port " << name;
1416 noteLocation(diag, moduleOp);
1417 return failure();
1418 }
1419
1420 return success();
1421}
1422
1423/// Check that the input domain ports are driven.
1424template <typename T>
1425static LogicalResult checkInstanceDomainPortDrivers(T op) {
1426 for (size_t i = 0, e = op.getNumResults(); i < e; ++i) {
1427 auto port = dyn_cast<DomainValue>(op.getResult(i));
1428 auto type = port.getType();
1429 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1430 isDriven(port))
1431 continue;
1432
1433 auto name = op.getPortNameAttr(i);
1434 auto diag = emitError(op.getPortLocation(i))
1435 << "undriven domain port " << name;
1436 noteLocation(diag, op);
1437 return failure();
1438 }
1439
1440 return success();
1441}
1442
1443static LogicalResult checkOp(Operation *op) {
1444 if (auto inst = dyn_cast<InstanceOp>(op))
1445 return checkInstanceDomainPortDrivers(inst);
1446 if (auto inst = dyn_cast<InstanceChoiceOp>(op))
1447 return checkInstanceDomainPortDrivers(inst);
1448 return success();
1449}
1450
1451/// Check that instances under this module have driven domain input ports.
1452static LogicalResult checkModuleBody(FModuleOp moduleOp) {
1453 auto result = moduleOp.getBody().walk(
1454 [&](Operation *op) -> WalkResult { return checkOp(op); });
1455 return failure(result.wasInterrupted());
1456}
1457
1458//===---------------------------------------------------------------------------
1459// Domain Stripping.
1460//===---------------------------------------------------------------------------
1461
1462static LogicalResult stripModule(FModuleLike op) {
1463 WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
1464 [=](Operation *op) -> WalkResult {
1465 return TypeSwitch<Operation *, WalkResult>(op)
1466 .Case<FModuleLike>([](FModuleLike op) {
1467 auto n = op.getNumPorts();
1468 BitVector erasures(n);
1469 for (size_t i = 0; i < n; ++i)
1470 if (isa<DomainType>(op.getPortType(i)))
1471 erasures.set(i);
1472 op.erasePorts(erasures);
1473 return WalkResult::advance();
1474 })
1475 .Case<DomainDefineOp, DomainCreateAnonOp>([](Operation *op) {
1476 op->erase();
1477 return WalkResult::advance();
1478 })
1479 .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
1480 op.replaceAllUsesWith(op.getInput());
1481 op.erase();
1482 return WalkResult::advance();
1483 })
1484 .Case<WireOp>([](WireOp op) {
1485 if (isa<DomainType>(op.getType(0)))
1486 op->erase();
1487 return WalkResult::advance();
1488 })
1489 .Case<InstanceOp, InstanceChoiceOp>([](auto op) {
1490 auto n = op.getNumPorts();
1491 BitVector erasures(n);
1492 for (size_t i = 0; i < n; ++i)
1493 if (isa<DomainType>(op->getResult(i).getType()))
1494 erasures.set(i);
1495 op.cloneWithErasedPortsAndReplaceUses(erasures);
1496 op.erase();
1497 return WalkResult::advance();
1498 })
1499 .Default([](Operation *op) {
1500 for (auto type :
1501 concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
1502 if (isa<DomainType>(type)) {
1503 op->emitOpError("cannot be stripped");
1504 return WalkResult::interrupt();
1505 }
1506 }
1507 return WalkResult::advance();
1508 });
1509 });
1510 return failure(result.wasInterrupted());
1511}
1512
1513static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit) {
1514 llvm::SmallVector<FModuleLike> modules;
1515 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1516 TypeSwitch<Operation *, void>(&op)
1517 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1518 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1519 }
1520 return failableParallelForEach(context, modules, stripModule);
1521}
1522
1523//===---------------------------------------------------------------------------
1524// InferDomainsPass: Top-level pass implementation.
1525//===---------------------------------------------------------------------------
1526
1527/// Solve for domains and then write the domain associations back to the IR.
1528static LogicalResult inferModule(const DomainInfo &info,
1529 ModuleUpdateTable &updates,
1530 FModuleOp moduleOp) {
1531 TermAllocator allocator;
1532 DomainTable table;
1533
1534 if (failed(processModule(info, allocator, table, updates, moduleOp)))
1535 return failure();
1536
1537 return updateModule(info, allocator, table, updates, moduleOp);
1538}
1539
1540/// Check that a module's ports are fully annotated, before performing domain
1541/// inference on the module.
1542static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp) {
1543 if (failed(checkModulePorts(info, moduleOp)))
1544 return failure();
1545
1546 if (failed(checkModuleDomainPortDrivers(info, moduleOp)))
1547 return failure();
1548
1549 if (failed(checkModuleBody(moduleOp)))
1550 return failure();
1551
1552 TermAllocator allocator;
1553 DomainTable table;
1554 ModuleUpdateTable updateTable;
1555 return processModule(info, allocator, table, updateTable, moduleOp);
1556}
1557
1558/// Check that an extmodule's ports are fully annotated.
1559static LogicalResult checkModule(const DomainInfo &info,
1560 FExtModuleOp moduleOp) {
1561 return checkModulePorts(info, moduleOp);
1562}
1563
1564/// Check that a module's ports are fully annotated, before performing domain
1565/// inference on the module. We use this when private module interfaces are
1566/// inferred but public module interfaces are checked.
1567static LogicalResult checkAndInferModule(const DomainInfo &info,
1568 ModuleUpdateTable &updateTable,
1569 FModuleOp moduleOp) {
1570 if (failed(checkModulePorts(info, moduleOp)))
1571 return failure();
1572
1573 TermAllocator allocator;
1574 DomainTable table;
1575 if (failed(processModule(info, allocator, table, updateTable, moduleOp)))
1576 return failure();
1577
1578 if (failed(driveModuleOutputDomainPorts(info, table, moduleOp)))
1579 return failure();
1580
1581 return updateModuleBody(info, allocator, table, moduleOp);
1582}
1583
1584static LogicalResult runOnModuleLike(InferDomainsMode mode,
1585 const DomainInfo &info,
1586 ModuleUpdateTable &updateTable,
1587 Operation *op) {
1588 assert(mode != InferDomainsMode::Strip);
1589
1590 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
1591 if (mode == InferDomainsMode::Check)
1592 return checkModule(info, moduleOp);
1593
1594 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1595 return inferModule(info, updateTable, moduleOp);
1596
1597 return checkAndInferModule(info, updateTable, moduleOp);
1598 }
1599
1600 if (auto extModule = dyn_cast<FExtModuleOp>(op))
1601 return checkModule(info, extModule);
1602
1603 return success();
1604}
1605
1606namespace {
1607struct InferDomainsPass
1608 : public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1609 using Base::Base;
1610 void runOnOperation() override {
1612 auto circuit = getOperation();
1613
1614 if (mode == InferDomainsMode::Strip) {
1615 if (failed(stripCircuit(&getContext(), circuit)))
1616 signalPassFailure();
1617 return;
1618 }
1619
1620 auto &instanceGraph = getAnalysis<InstanceGraph>();
1621 DomainInfo info(circuit);
1622 ModuleUpdateTable updateTable;
1623 DenseSet<Operation *> errored;
1624 instanceGraph.walkPostOrder([&](auto &node) {
1625 auto moduleOp = node.getModule();
1626 for (auto *inst : node) {
1627 if (errored.contains(inst->getTarget()->getModule())) {
1628 errored.insert(moduleOp);
1629 return;
1630 }
1631 }
1632 if (failed(runOnModuleLike(mode, info, updateTable, node.getModule())))
1633 errored.insert(moduleOp);
1634 });
1635 if (errored.size())
1636 signalPassFailure();
1637 }
1638};
1639} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult checkOp(Operation *op)
static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, InstanceOp op)
static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op)
Write the domain associations recorded in the domain table back to the IR.
static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1, IntegerAttr domainPortIndexAttr2)
static ExportTable initializeExportTable(const DomainTable &table, FModuleOp moduleOp)
Build a table of exported domains: a map from domains defined internally, to their set of aliasing ou...
static LogicalResult processModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
static void emitAmbiguousPortDomainAssociation(const DomainInfo &info, T op, const llvm::TinyPtrVector< DomainValue > &exports, DomainTypeID typeID, size_t i)
static LogicalResult processModulePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
static LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updates, FModuleOp moduleOp)
Solve for domains and then write the domain associations back to the IR.
static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info, const DomainTable &table, FModuleOp moduleOp)
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
llvm::MapVector< DomainValue, unsigned > PendingExports
A map from local domains to an aliasing port index, where that port has not yet been created.
static LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op)
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
static RowTerm * getDomainAssociationAsRow(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Value value)
Get the row of domains that a hardware value in the IR is associated with.
static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, DomainTypeID typeID, size_t i)
static void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, Namespace &ns, FModuleOp moduleOp, PendingUpdates &pending)
static T fixInstancePorts(T op, const ModuleUpdateInfo &update)
Apply the port changes of a moduleOp onto an instance-like op.
static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
After updating the port domain associations, walk the body of the moduleOp to fix up any child instan...
static void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, Term *term)
static StringAttr getDomainPortTypeName(ArrayAttr info, size_t i)
From a domain info attribute, get the domain-type of a domain value at index i.
static void processDomainDefinition(TermAllocator &allocator, DomainTable &table, DomainValue domain)
static LogicalResult unify(Term *lhs, Term *rhs)
static LogicalResult updateModuleDomainInfo(const DomainInfo &info, const DomainTable &table, const ExportTable &exportTable, ArrayAttr &result, FModuleOp moduleOp)
After generalizing the moduleOp, all domains should be solved.
static LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op, Value lhs, Value rhs)
Unify the associated domain rows of two terms.
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, FModuleOp moduleOp)
Check that output domain ports are driven.
static SmallVector< Attribute > copyPortDomainAssociations(const DomainInfo &info, ArrayAttr moduleDomainInfo, size_t portIndex)
Copy the domain associations from the moduleOp domain info attribute into a small vector.
static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op)
static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, Term *term1, Term *term2)
static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, Term *term, const ExportTable &exports)
static void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator, ExportTable &exports, DomainTable &table, FModuleOp moduleOp, const PendingUpdates &pending)
static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, FModuleOp mod, PendingUpdates &pending)
static LogicalResult updateOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op)
static LogicalResult checkInstanceDomainPortDrivers(T op)
Check that the input domain ports are driven.
static void emitDomainPortInferenceError(T op, size_t i)
Emit an error when we fail to infer the concrete domain to drive to a domain port.
static void ensureExported(const DomainInfo &info, Namespace &ns, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending)
Ensure that the domain value is available in the signature of the moduleOp, so that subsequent hardwa...
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static Term * getTermForDomain(TermAllocator &allocator, DomainTable &table, DomainValue value)
Get the corresponding term for a domain in the IR.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
DenseMap< StringAttr, ModuleUpdateInfo > ModuleUpdateTable
static void ensureSolved(const DomainInfo &info, Namespace &ns, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending)
If var is not solved, solve it by recording a pending input port at the indicated insertion point.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static void solve(Term *lhs, Term *rhs)
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static LogicalResult checkModuleBody(FModuleOp moduleOp)
Check that instances under this module have driven domain input ports.
static LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike moduleOp)
Check that a module's hardware ports have complete domain associations.
static Term * find(Term *x)
static LogicalResult updateInstance(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult checkAndInferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult stripModule(FModuleLike op)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class represents a reference to a specific field or element of an aggregate value.
Definition FieldRef.h:28
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
InferDomainsMode
The mode for the InferDomains pass.
Definition Passes.h:73
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.