CIRCT 23.0.0git
Loading...
Searching...
No Matches
InferDomains.cpp
Go to the documentation of this file.
1//===- InferDomains.cpp - Infer and Check FIRRTL Domains ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
// InferDomains implements FIRRTL domain inference and checking. This pass is
// a bottom-up transform acting on modules. For each module, we ensure there
// are no domain crossings, and we make explicit the domain associations of
// ports.
13//
// This pass does not require that ExpandWhens has run, but it is expected to
// run first: if ExpandWhens has not been run, then duplicate connections will
// influence domain inference, and this can result in errors.
17//
18//===----------------------------------------------------------------------===//
19
24#include "circt/Support/Debug.h"
26#include "mlir/IR/Iterators.h"
27#include "mlir/IR/Threading.h"
28#include "llvm/ADT/DenseMap.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/SmallVector.h"
31#include "llvm/ADT/TinyPtrVector.h"
32
33#define DEBUG_TYPE "firrtl-infer-domains"
34
35namespace circt {
36namespace firrtl {
37#define GEN_PASS_DEF_INFERDOMAINS
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
39} // namespace firrtl
40} // namespace circt
41
42using namespace circt;
43using namespace firrtl;
44
45using llvm::concat;
46using mlir::ReverseIterator;
47
48//====--------------------------------------------------------------------------
49// Helpers.
50//====--------------------------------------------------------------------------
51
52using DomainValue = mlir::TypedValue<DomainType>;
53
54using PortInsertions = SmallVector<std::pair<unsigned, PortInfo>>;
55
56/// From a domain info attribute, get the row of associated domains for a
57/// hardware value at index i.
58static auto getPortDomainAssociation(ArrayAttr info, size_t i) {
59 if (info.empty())
60 return info.getAsRange<IntegerAttr>();
61 return cast<ArrayAttr>(info[i]).getAsRange<IntegerAttr>();
62}
63
64/// Return true if the value is a port on the module.
65static bool isPort(BlockArgument arg) {
66 return isa<FModuleOp>(arg.getOwner()->getParentOp());
67}
68
69/// Return true if the value is a port on the module.
70static bool isPort(Value value) {
71 auto arg = dyn_cast<BlockArgument>(value);
72 if (!arg)
73 return false;
74 return isPort(arg);
75}
76
77/// Returns true if the value is driven by a connect op.
78static bool isDriven(DomainValue port) {
79 for (auto *user : port.getUsers())
80 if (auto connect = dyn_cast<FConnectLike>(user))
81 if (connect.getDest() == port)
82 return true;
83 return false;
84}
85
86//====--------------------------------------------------------------------------
87// Global State.
88//====--------------------------------------------------------------------------
89
/// Domain types are numbered in the order they are declared in the circuit;
/// that number is the domain's type-id. The domain associations of a hardware
/// value form a row of domains, ordered by the type-id of each domain's type.
namespace {
struct DomainTypeID {
  /// Position of the domain declaration in circuit declaration order.
  size_t index;
};
} // namespace
99
100/// Information about the domains in the circuit. Able to map domains to their
101/// type ID, which in this pass is the canonical way to reference the type
102/// of a domain, as well as provide fast access to domain ops.
103namespace {
104class DomainInfo {
105public:
106 DomainInfo(CircuitOp circuit) { processCircuit(circuit); }
107
108 ArrayRef<DomainOp> getDomains() const { return domainTable; }
109 size_t getNumDomains() const { return domainTable.size(); }
110 DomainOp getDomain(DomainTypeID id) const { return domainTable[id.index]; }
111
112 DomainTypeID getDomainTypeID(Type type) const { return typeIDTable.at(type); }
113
114 DomainTypeID getDomainTypeID(FModuleLike module, size_t i) const {
115 return getDomainTypeID(module.getPortType(i));
116 }
117
118 DomainTypeID getDomainTypeID(FInstanceLike op, size_t i) const {
119 return getDomainTypeID(op->getResult(i).getType());
120 }
121
122 DomainTypeID getDomainTypeID(DomainValue value) const {
123 return getDomainTypeID(value.getType());
124 }
125
126private:
127 void processDomain(DomainOp op) {
128 auto index = domainTable.size();
129 auto domainType = DomainType::getFromDomainOp(op);
130 domainTable.push_back(op);
131 typeIDTable.insert({domainType, {index}});
132 }
133
134 void processCircuit(CircuitOp circuit) {
135 for (auto decl : circuit.getOps<DomainOp>())
136 processDomain(decl);
137 }
138
139 /// A map from domain type ID to op.
140 SmallVector<DomainOp> domainTable;
141
142 /// A map from domain type to type ID.
143 DenseMap<Type, DomainTypeID> typeIDTable;
144};
145
146/// Information about the changes made to the interface of a moduleOp, which can
147/// be replayed onto an instance.
148struct ModuleUpdateInfo {
149 /// The updated domain information for a moduleOp.
150 ArrayAttr portDomainInfo;
151 /// The domain ports which have been inserted into a moduleOp.
152 PortInsertions portInsertions;
153};
154} // namespace
155
156using ModuleUpdateTable = DenseMap<StringAttr, ModuleUpdateInfo>;
157
158/// Apply the port changes of a moduleOp onto an instance-like op.
159static FInstanceLike fixInstancePorts(FInstanceLike op,
160 const ModuleUpdateInfo &update) {
161 auto clone = op.cloneWithInsertedPortsAndReplaceUses(update.portInsertions);
162 clone.setDomainInfoAttr(update.portDomainInfo);
163 op->erase();
164 return clone;
165}
166
167//====--------------------------------------------------------------------------
168// Terms: Syntax for unifying domain and domain-rows.
169//====--------------------------------------------------------------------------
170
namespace {
/// Discriminator for the kinds of term handled by the unification engine.
enum class TermKind {
  /// An unknown domain or row.
  Variable,
  /// A concrete domain value from the IR.
  Value,
  /// A row of domains.
  Row,
};
} // namespace
179
180/// A term in the unification engine.
181namespace {
182struct Term {
183 constexpr Term(TermKind kind) : kind(kind) {}
184 TermKind kind;
185};
186} // namespace
187
188/// Helper to define a term kind.
189namespace {
190template <TermKind K>
191struct TermBase : Term {
192 static bool classof(const Term *term) { return term->kind == K; }
193 TermBase() : Term(K) {}
194};
195} // namespace
196
197/// An unknown value.
198namespace {
199struct VariableTerm : public TermBase<TermKind::Variable> {
200 VariableTerm() : leader(nullptr) {}
201 VariableTerm(Term *leader) : leader(leader) {}
202 Term *leader;
203};
204} // namespace
205
206/// A concrete value defined in the IR.
207namespace {
208struct ValueTerm : public TermBase<TermKind::Value> {
209 ValueTerm(DomainValue value) : value(value) {}
210 DomainValue value;
211};
212} // namespace
213
214/// A row of domains.
215namespace {
216struct RowTerm : public TermBase<TermKind::Row> {
217 RowTerm(ArrayRef<Term *> elements) : elements(elements) {}
218 ArrayRef<Term *> elements;
219};
220} // namespace
221
222// NOLINTNEXTLINE(misc-no-recursion)
223static Term *find(Term *x) {
224 if (!x)
225 return nullptr;
226
227 if (auto *var = dyn_cast<VariableTerm>(x)) {
228 if (var->leader == nullptr)
229 return var;
230
231 auto *leader = find(var->leader);
232 if (leader != var->leader)
233 var->leader = leader;
234 return leader;
235 }
236
237 return x;
238}
239
240/// A helper for assigning low numeric IDs to variables for user-facing output.
241namespace {
242class VariableIDTable {
243public:
244 size_t get(VariableTerm *term) {
245 return table.insert({term, table.size() + 1}).first->second;
246 }
247
248private:
249 DenseMap<VariableTerm *, size_t> table;
250};
251} // namespace
252
253// NOLINTNEXTLINE(misc-no-recursion)
254static void render(const DomainInfo &info, Diagnostic &out,
255 VariableIDTable &idTable, Term *term) {
256 term = find(term);
257 if (auto *var = dyn_cast<VariableTerm>(term)) {
258 out << "?" << idTable.get(var);
259 return;
260 }
261 if (auto *val = dyn_cast<ValueTerm>(term)) {
262 auto value = val->value;
263 auto [name, _] = getFieldName(FieldRef(value, 0), false);
264 out << name;
265 return;
266 }
267 if (auto *row = dyn_cast<RowTerm>(term)) {
268 bool first = true;
269 out << "[";
270 for (size_t i = 0, e = info.getNumDomains(); i < e; ++i) {
271 auto domainOp = info.getDomain(DomainTypeID{i});
272 if (!first) {
273 out << ", ";
274 first = false;
275 }
276 out << domainOp.getName() << ": ";
277 render(info, out, idTable, row->elements[i]);
278 }
279 out << "]";
280 return;
281 }
282}
283
284static LogicalResult unify(Term *lhs, Term *rhs);
285
286static LogicalResult unify(VariableTerm *x, Term *y) {
287 assert(!x->leader);
288 x->leader = y;
289 return success();
290}
291
292static LogicalResult unify(ValueTerm *xv, Term *y) {
293 if (auto *yv = dyn_cast<VariableTerm>(y)) {
294 yv->leader = xv;
295 return success();
296 }
297
298 if (auto *yv = dyn_cast<ValueTerm>(y))
299 return success(xv == yv);
300
301 return failure();
302}
303
304// NOLINTNEXTLINE(misc-no-recursion)
305static LogicalResult unify(RowTerm *lhsRow, Term *rhs) {
306 if (auto *rhsVar = dyn_cast<VariableTerm>(rhs)) {
307 rhsVar->leader = lhsRow;
308 return success();
309 }
310 if (auto *rhsRow = dyn_cast<RowTerm>(rhs)) {
311 for (auto [x, y] : llvm::zip_equal(lhsRow->elements, rhsRow->elements))
312 if (failed(unify(x, y)))
313 return failure();
314 return success();
315 }
316 return failure();
317}
318
319// NOLINTNEXTLINE(misc-no-recursion)
320static LogicalResult unify(Term *lhs, Term *rhs) {
321 if (!lhs || !rhs)
322 return success();
323 lhs = find(lhs);
324 rhs = find(rhs);
325 if (lhs == rhs)
326 return success();
327 if (auto *lhsVar = dyn_cast<VariableTerm>(lhs))
328 return unify(lhsVar, rhs);
329 if (auto *lhsVal = dyn_cast<ValueTerm>(lhs))
330 return unify(lhsVal, rhs);
331 if (auto *lhsRow = dyn_cast<RowTerm>(lhs))
332 return unify(lhsRow, rhs);
333 return failure();
334}
335
336static void solve(Term *lhs, Term *rhs) {
337 [[maybe_unused]] auto result = unify(lhs, rhs);
338 assert(result.succeeded());
339}
340
341namespace {
342class TermAllocator {
343public:
344 /// Allocate a row of fresh domain variables.
345 [[nodiscard]] RowTerm *allocRow(size_t size) {
346 SmallVector<Term *> elements;
347 elements.resize(size);
348 return allocRow(elements);
349 }
350
351 /// Allocate a row of terms.
352 [[nodiscard]] RowTerm *allocRow(ArrayRef<Term *> elements) {
353 auto ds = allocArray(elements);
354 return alloc<RowTerm>(ds);
355 }
356
357 /// Allocate a fresh variable.
358 [[nodiscard]] VariableTerm *allocVar() { return alloc<VariableTerm>(); }
359
360 /// Allocate a concrete domain.
361 [[nodiscard]] ValueTerm *allocVal(DomainValue value) {
362 return alloc<ValueTerm>(value);
363 }
364
365private:
366 template <typename T, typename... Args>
367 [[nodiscard]] T *alloc(Args &&...args) {
368 static_assert(std::is_base_of_v<Term, T>, "T must be a term");
369 return new (allocator) T(std::forward<Args>(args)...);
370 }
371
372 [[nodiscard]] ArrayRef<Term *> allocArray(ArrayRef<Term *> elements) {
373 auto size = elements.size();
374 if (size == 0)
375 return {};
376
377 auto *result = allocator.Allocate<Term *>(size);
378 llvm::uninitialized_copy(elements, result);
379 for (size_t i = 0; i < size; ++i)
380 if (!result[i])
381 result[i] = alloc<VariableTerm>();
382
383 return ArrayRef(result, size);
384 }
385
386 llvm::BumpPtrAllocator allocator;
387};
388} // namespace
389
390//====--------------------------------------------------------------------------
391// DomainTable: A mapping from IR to terms.
392//====--------------------------------------------------------------------------
393
394namespace {
395/// Tracks domain infomation for IR values.
396class DomainTable {
397public:
398 /// If the domain value is an alias, returns the domain it aliases.
399 DomainValue getOptUnderlyingDomain(DomainValue value) const {
400 auto *term = getOptTermForDomain(value);
401 if (auto *val = llvm::dyn_cast_if_present<ValueTerm>(term))
402 return val->value;
403 return nullptr;
404 }
405
406 /// Get the corresponding term for a domain in the IR, or null if unset.
407 Term *getOptTermForDomain(DomainValue value) const {
408 assert(isa<DomainType>(value.getType()));
409 auto it = termTable.find(value);
410 if (it == termTable.end())
411 return nullptr;
412 return find(it->second);
413 }
414
415 /// Get the corresponding term for a domain in the IR.
416 Term *getTermForDomain(DomainValue value) const {
417 auto *term = getOptTermForDomain(value);
418 assert(term);
419 return term;
420 }
421
422 /// Record a mapping from domain in the IR to its corresponding term.
423 void setTermForDomain(DomainValue value, Term *term) {
424 assert(term);
425 assert(!termTable.contains(value));
426 termTable.insert({value, term});
427 }
428
429 /// For a hardware value, get the term which represents the row of associated
430 /// domains. If no mapping has been defined, returns nullptr.
431 Term *getOptDomainAssociation(Value value) const {
432 assert(isa<FIRRTLBaseType>(value.getType()));
433 auto it = associationTable.find(value);
434 if (it == associationTable.end())
435 return nullptr;
436 return find(it->second);
437 }
438
439 /// For a hardware value, get the term which represents the row of associated
440 /// domains.
441 Term *getDomainAssociation(Value value) const {
442 auto *term = getOptDomainAssociation(value);
443 assert(term);
444 return term;
445 }
446
447 /// Record a mapping from a hardware value in the IR to a term which
448 /// represents the row of domains it is associated with.
449 void setDomainAssociation(Value value, Term *term) {
450 assert(isa<FIRRTLBaseType>(value.getType()));
451 assert(term);
452 term = find(term);
453 associationTable.insert({value, term});
454 }
455
456private:
457 /// Map from domains in the IR to their underlying term.
458 DenseMap<Value, Term *> termTable;
459
460 /// A map from hardware values to their associated row of domains, as a term.
461 DenseMap<Value, Term *> associationTable;
462};
463} // namespace
464
465//====--------------------------------------------------------------------------
466// Module processing: solve for the domain associations of hardware.
467//====--------------------------------------------------------------------------
468
469/// Get the corresponding term for a domain in the IR. If we don't know what the
470/// term is, then map the domain in the IR to a variable term.
471static Term *getTermForDomain(TermAllocator &allocator, DomainTable &table,
472 DomainValue value) {
473 assert(isa<DomainType>(value.getType()));
474 if (auto *term = table.getOptTermForDomain(value))
475 return term;
476 auto *term = allocator.allocVar();
477 table.setTermForDomain(value, term);
478 return term;
479}
480
481static void processDomainDefinition(TermAllocator &allocator,
482 DomainTable &table, DomainValue domain) {
483 assert(isa<DomainType>(domain.getType()));
484 auto *newTerm = allocator.allocVal(domain);
485 auto *oldTerm = table.getOptTermForDomain(domain);
486 if (!oldTerm) {
487 table.setTermForDomain(domain, newTerm);
488 return;
489 }
490
491 [[maybe_unused]] auto result = unify(oldTerm, newTerm);
492 assert(result.succeeded());
493}
494
495/// Get the row of domains that a hardware value in the IR is associated with.
496/// The returned term is forced to be at least a row.
497static RowTerm *getDomainAssociationAsRow(const DomainInfo &info,
498 TermAllocator &allocator,
499 DomainTable &table, Value value) {
500 assert(isa<FIRRTLBaseType>(value.getType()));
501 auto *term = table.getOptDomainAssociation(value);
502
503 // If the term is unknown, allocate a fresh row and set the association.
504 if (!term) {
505 auto *row = allocator.allocRow(info.getNumDomains());
506 table.setDomainAssociation(value, row);
507 return row;
508 }
509
510 // If the term is already a row, return it.
511 if (auto *row = dyn_cast<RowTerm>(term))
512 return row;
513
514 // Otherwise, unify the term with a fresh row of domains.
515 if (auto *var = dyn_cast<VariableTerm>(term)) {
516 auto *row = allocator.allocRow(info.getNumDomains());
517 solve(var, row);
518 return row;
519 }
520
521 assert(false && "unhandled term type");
522 return nullptr;
523}
524
525static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op) {
526 auto &note = diag.attachNote(op->getLoc());
527 if (auto mod = dyn_cast<FModuleOp>(op)) {
528 note << "in module " << mod.getModuleNameAttr();
529 return;
530 }
531 if (auto mod = dyn_cast<FExtModuleOp>(op)) {
532 note << "in extmodule " << mod.getModuleNameAttr();
533 return;
534 }
535 if (auto inst = dyn_cast<InstanceOp>(op)) {
536 note << "in instance " << inst.getInstanceNameAttr();
537 return;
538 }
539 if (auto inst = dyn_cast<InstanceChoiceOp>(op)) {
540 note << "in instance_choice " << inst.getNameAttr();
541 return;
542 }
543
544 note << "here";
545}
546
547template <typename T>
548static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i,
549 DomainTypeID domainTypeID, Term *term1,
550 Term *term2) {
551 VariableIDTable idTable;
552
553 auto portName = op.getPortNameAttr(i);
554 auto portLoc = op.getPortLocation(i);
555 auto domainDecl = info.getDomain(domainTypeID);
556 auto domainName = domainDecl.getNameAttr();
557
558 auto diag = emitError(portLoc);
559 diag << "illegal " << domainName << " crossing in port " << portName;
560
561 auto &note1 = diag.attachNote();
562 note1 << "1st instance: ";
563 render(info, note1, idTable, term1);
564
565 auto &note2 = diag.attachNote();
566 note2 << "2nd instance: ";
567 render(info, note2, idTable, term2);
568
569 noteLocation(diag, op);
570}
571
572template <typename T>
573static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i,
574 DomainTypeID domainTypeID,
575 IntegerAttr domainPortIndexAttr1,
576 IntegerAttr domainPortIndexAttr2) {
577 VariableIDTable idTable;
578 auto portName = op.getPortNameAttr(i);
579 auto portLoc = op.getPortLocation(i);
580 auto domainDecl = info.getDomain(domainTypeID);
581 auto domainName = domainDecl.getNameAttr();
582 auto domainPortIndex1 = domainPortIndexAttr1.getUInt();
583 auto domainPortIndex2 = domainPortIndexAttr2.getUInt();
584 auto domainPortName1 = op.getPortNameAttr(domainPortIndex1);
585 auto domainPortName2 = op.getPortNameAttr(domainPortIndex2);
586 auto domainPortLoc1 = op.getPortLocation(domainPortIndex1);
587 auto domainPortLoc2 = op.getPortLocation(domainPortIndex2);
588 auto diag = emitError(portLoc);
589 diag << "duplicate " << domainName << " association for port " << portName;
590 auto &note1 = diag.attachNote(domainPortLoc1);
591 note1 << "associated with " << domainName << " port " << domainPortName1;
592 auto &note2 = diag.attachNote(domainPortLoc2);
593 note2 << "associated with " << domainName << " port " << domainPortName2;
594 noteLocation(diag, op);
595}
596
597/// Emit an error when we fail to infer the concrete domain to drive to a
598/// domain port.
599template <typename T>
600static void emitDomainPortInferenceError(T op, size_t i) {
601 auto name = op.getPortNameAttr(i);
602 auto diag = emitError(op->getLoc());
603 auto info = op.getDomainInfo();
604 diag << "unable to infer value for undriven domain port " << name;
605 for (size_t j = 0, e = op.getNumPorts(); j < e; ++j) {
606 if (auto assocs = dyn_cast<ArrayAttr>(info[j])) {
607 for (auto assoc : assocs) {
608 if (i == cast<IntegerAttr>(assoc).getValue()) {
609 auto name = op.getPortNameAttr(j);
610 auto loc = op.getPortLocation(j);
611 diag.attachNote(loc) << "associated with hardware port " << name;
612 break;
613 }
614 }
615 }
616 }
617 noteLocation(diag, op);
618}
619
620template <typename T>
622 const DomainInfo &info, T op,
623 const llvm::TinyPtrVector<DomainValue> &exports, DomainTypeID typeID,
624 size_t i) {
625 auto portName = op.getPortNameAttr(i);
626 auto portLoc = op.getPortLocation(i);
627 auto domainDecl = info.getDomain(typeID);
628 auto domainName = domainDecl.getNameAttr();
629 auto diag = emitError(portLoc) << "ambiguous " << domainName
630 << " association for port " << portName;
631 for (auto e : exports) {
632 auto arg = cast<BlockArgument>(e);
633 auto name = op.getPortNameAttr(arg.getArgNumber());
634 auto loc = op.getPortLocation(arg.getArgNumber());
635 diag.attachNote(loc) << "candidate association " << name;
636 }
637 noteLocation(diag, op);
638}
639
640template <typename T>
641static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op,
642 DomainTypeID typeID,
643 size_t i) {
644 auto domainName = info.getDomain(typeID).getNameAttr();
645 auto portName = op.getPortNameAttr(i);
646 auto diag = emitError(op.getPortLocation(i))
647 << "missing " << domainName << " association for port "
648 << portName;
649 noteLocation(diag, op);
650}
651
652/// Unify the associated domain rows of two terms.
653static LogicalResult unifyAssociations(const DomainInfo &info,
654 TermAllocator &allocator,
655 DomainTable &table, Operation *op,
656 Value lhs, Value rhs) {
657 if (!lhs || !rhs)
658 return success();
659
660 if (lhs == rhs)
661 return success();
662
663 auto *lhsTerm = table.getOptDomainAssociation(lhs);
664 auto *rhsTerm = table.getOptDomainAssociation(rhs);
665
666 if (lhsTerm) {
667 if (rhsTerm) {
668 if (failed(unify(lhsTerm, rhsTerm))) {
669 auto diag = op->emitOpError("illegal domain crossing in operation");
670 auto &note1 = diag.attachNote(lhs.getLoc());
671
672 note1 << "1st operand has domains: ";
673 VariableIDTable idTable;
674 render(info, note1, idTable, lhsTerm);
675
676 auto &note2 = diag.attachNote(rhs.getLoc());
677 note2 << "2nd operand has domains: ";
678 render(info, note2, idTable, rhsTerm);
679
680 return failure();
681 }
682 }
683 table.setDomainAssociation(rhs, lhsTerm);
684 return success();
685 }
686
687 if (rhsTerm) {
688 table.setDomainAssociation(lhs, rhsTerm);
689 return success();
690 }
691
692 auto *var = allocator.allocVar();
693 table.setDomainAssociation(lhs, var);
694 table.setDomainAssociation(rhs, var);
695 return success();
696}
697
698static LogicalResult processModulePorts(const DomainInfo &info,
699 TermAllocator &allocator,
700 DomainTable &table,
701 FModuleOp moduleOp) {
702 auto numDomains = info.getNumDomains();
703 auto domainInfo = moduleOp.getDomainInfoAttr();
704 auto numPorts = moduleOp.getNumPorts();
705
706 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
707 for (size_t i = 0; i < numPorts; ++i) {
708 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
709 if (!port)
710 continue;
711
712 if (moduleOp.getPortDirection(i) == Direction::In)
713 processDomainDefinition(allocator, table, port);
714
715 domainTypeIDTable[i] = info.getDomainTypeID(moduleOp, i);
716 }
717
718 for (size_t i = 0; i < numPorts; ++i) {
719 BlockArgument port = moduleOp.getArgument(i);
720 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
721 if (!type)
722 continue;
723
724 SmallVector<IntegerAttr> associations(numDomains);
725 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
726 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
727 auto prevDomainPortIndex = associations[domainTypeID.index];
728 if (prevDomainPortIndex) {
729 emitDuplicatePortDomainError(info, moduleOp, i, domainTypeID,
730 prevDomainPortIndex, domainPortIndex);
731 return failure();
732 }
733 associations[domainTypeID.index] = domainPortIndex;
734 }
735
736 SmallVector<Term *> elements(numDomains);
737 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
738 ++domainTypeIndex) {
739 auto domainPortIndex = associations[domainTypeIndex];
740 if (!domainPortIndex)
741 continue;
742 auto domainPortValue =
743 cast<DomainValue>(moduleOp.getArgument(domainPortIndex.getUInt()));
744 elements[domainTypeIndex] =
745 getTermForDomain(allocator, table, domainPortValue);
746 }
747
748 auto *domainAssociations = allocator.allocRow(elements);
749 table.setDomainAssociation(port, domainAssociations);
750 }
751
752 return success();
753}
754
755template <typename T>
756static LogicalResult processInstancePorts(const DomainInfo &info,
757 TermAllocator &allocator,
758 DomainTable &table, T op) {
759 auto numDomains = info.getNumDomains();
760 auto domainInfo = op.getDomainInfoAttr();
761 auto numPorts = op.getNumPorts();
762
763 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
764 for (size_t i = 0; i < numPorts; ++i) {
765 auto port = dyn_cast<DomainValue>(op->getResult(i));
766 if (!port)
767 continue;
768
769 if (op.getPortDirection(i) == Direction::Out)
770 processDomainDefinition(allocator, table, port);
771
772 domainTypeIDTable[i] = info.getDomainTypeID(op, i);
773 }
774
775 for (size_t i = 0; i < numPorts; ++i) {
776 Value port = op->getResult(i);
777 auto type = type_dyn_cast<FIRRTLBaseType>(port.getType());
778 if (!type)
779 continue;
780
781 SmallVector<IntegerAttr> associations(numDomains);
782 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
783 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
784 auto prevDomainPortIndex = associations[domainTypeID.index];
785 if (prevDomainPortIndex) {
786 emitDuplicatePortDomainError(info, op, i, domainTypeID,
787 prevDomainPortIndex, domainPortIndex);
788 return failure();
789 }
790 associations[domainTypeID.index] = domainPortIndex;
791 }
792
793 SmallVector<Term *> elements(numDomains);
794 for (size_t domainTypeIndex = 0; domainTypeIndex < numDomains;
795 ++domainTypeIndex) {
796 auto domainPortIndex = associations[domainTypeIndex];
797 if (!domainPortIndex)
798 continue;
799 auto domainPortValue =
800 cast<DomainValue>(op->getResult(domainPortIndex.getUInt()));
801 elements[domainTypeIndex] =
802 getTermForDomain(allocator, table, domainPortValue);
803 }
804
805 auto *domainAssociations = allocator.allocRow(elements);
806 table.setDomainAssociation(port, domainAssociations);
807 }
808
809 return success();
810}
811
812static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
813 DomainTable &table,
814 const ModuleUpdateTable &updateTable,
815 FInstanceLike op) {
816 auto moduleName =
817 cast<StringAttr>(cast<ArrayAttr>(op.getReferencedModuleNamesAttr())[0]);
818 auto lookup = updateTable.find(moduleName);
819 if (lookup != updateTable.end())
820 op = fixInstancePorts(op, lookup->second);
821 return processInstancePorts(info, allocator, table, op);
822}
823
824static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
825 DomainTable &table, UnsafeDomainCastOp op) {
826 auto domains = op.getDomains();
827 if (domains.empty())
828 return unifyAssociations(info, allocator, table, op, op.getInput(),
829 op.getResult());
830
831 auto input = op.getInput();
832 RowTerm *inputRow = getDomainAssociationAsRow(info, allocator, table, input);
833 SmallVector<Term *> elements(inputRow->elements);
834 for (auto value : op.getDomains()) {
835 auto domain = cast<DomainValue>(value);
836 auto typeID = info.getDomainTypeID(domain);
837 elements[typeID.index] = getTermForDomain(allocator, table, domain);
838 }
839
840 auto *row = allocator.allocRow(elements);
841 table.setDomainAssociation(op.getResult(), row);
842 return success();
843}
844
845static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
846 DomainTable &table, DomainDefineOp op) {
847 auto src = op.getSrc();
848 auto dst = op.getDest();
849 auto *srcTerm = getTermForDomain(allocator, table, src);
850 auto *dstTerm = getTermForDomain(allocator, table, dst);
851 if (succeeded(unify(dstTerm, srcTerm)))
852 return success();
853
854 auto diag =
855 op->emitOpError()
856 << "defines a domain value that was inferred to be a different domain '";
857 VariableIDTable idTable;
858 auto dstName = getFieldName(FieldRef(dst, 0), false).first;
859 render(info, *diag.getUnderlyingDiagnostic(), idTable, dstTerm);
860 diag << "'";
861
862 return failure();
863}
864
865static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
866 DomainTable &table, WireOp op) {
867 // If the wire has explicit domain operands, seed the domain table with them
868 // as constraints. When this op is visited, connections have not yet been
869 // processed (wire declarations precede their uses), so the existing row
870 // contains only fresh variables that unify unconditionally. Any conflict
871 // between an explicit wire domain and a connection's inferred domain is
872 // caught later by the connection's own processOp.
873 if (op.getDomains().empty())
874 return success();
875
876 // Build a row with the explicitly-specified domain slots filled in and set
877 // it as the association for this wire result.
878 SmallVector<Term *> elements(info.getNumDomains());
879 for (auto domain : op.getDomains()) {
880 auto domainValue = cast<DomainValue>(domain);
881 auto typeID = info.getDomainTypeID(domainValue);
882 elements[typeID.index] = getTermForDomain(allocator, table, domainValue);
883 }
884 table.setDomainAssociation(op.getResult(), allocator.allocRow(elements));
885
886 return success();
887}
888
889static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator,
890 DomainTable &table,
891 const ModuleUpdateTable &updateTable,
892 Operation *op) {
893 if (auto instance = dyn_cast<FInstanceLike>(op))
894 return processOp(info, allocator, table, updateTable, instance);
895 if (auto wireOp = dyn_cast<WireOp>(op))
896 return processOp(info, allocator, table, wireOp);
897 if (auto cast = dyn_cast<UnsafeDomainCastOp>(op))
898 return processOp(info, allocator, table, cast);
899 if (auto def = dyn_cast<DomainDefineOp>(op))
900 return processOp(info, allocator, table, def);
901 if (auto create = dyn_cast<DomainCreateOp>(op)) {
902 processDomainDefinition(allocator, table, create);
903 return success();
904 }
905 if (auto createAnon = dyn_cast<DomainCreateAnonOp>(op)) {
906 processDomainDefinition(allocator, table, createAnon);
907 return success();
908 }
909
910 // For all other operations (including connections), propagate domains from
911 // operands to results. This is a conservative approach - all operands and
912 // results share the same domain associations.
913 Value lhs;
914 for (auto rhs : op->getOperands()) {
915 if (!isa<FIRRTLBaseType>(rhs.getType()))
916 continue;
917 if (auto *op = rhs.getDefiningOp();
918 op && op->hasTrait<OpTrait::ConstantLike>())
919 continue;
920 if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs)))
921 return failure();
922 lhs = rhs;
923 }
924 for (auto rhs : op->getResults()) {
925 if (!isa<FIRRTLBaseType>(rhs.getType()))
926 continue;
927 if (auto *op = rhs.getDefiningOp();
928 op && op->hasTrait<OpTrait::ConstantLike>())
929 continue;
930 if (failed(unifyAssociations(info, allocator, table, op, lhs, rhs)))
931 return failure();
932 lhs = rhs;
933 }
934 return success();
935}
936
937static LogicalResult processModuleBody(const DomainInfo &info,
938 TermAllocator &allocator,
939 DomainTable &table,
940 const ModuleUpdateTable &updateTable,
941 FModuleOp moduleOp) {
942 return failure(moduleOp.getBody()
943 .walk([&](Operation *op) -> WalkResult {
944 return processOp(info, allocator, table, updateTable,
945 op);
946 })
947 .wasInterrupted());
948}
949
950/// Populate the domain table by processing the moduleOp. If the moduleOp has
951/// any domain crossing errors, return failure.
952static LogicalResult processModule(const DomainInfo &info,
953 TermAllocator &allocator, DomainTable &table,
954 const ModuleUpdateTable &updateTable,
955 FModuleOp moduleOp) {
956 if (failed(processModulePorts(info, allocator, table, moduleOp)))
957 return failure();
958 if (failed(processModuleBody(info, allocator, table, updateTable, moduleOp)))
959 return failure();
960 return success();
961}
962
963//===---------------------------------------------------------------------------
964// ExportTable
965//===---------------------------------------------------------------------------
966
/// A map from domain IR values defined internal to the moduleOp, to the ports
/// that alias that domain. These ports make the domain usable as associations
/// of other ports, and we say these are exporting ports. An input domain port
/// that was added to solve a domain variable is recorded as exporting itself.
using ExportTable = DenseMap<DomainValue, TinyPtrVector<DomainValue>>;
971
972/// Build a table of exported domains: a map from domains defined internally,
973/// to their set of aliasing output ports.
974static ExportTable initializeExportTable(const DomainTable &table,
975 FModuleOp moduleOp) {
976 ExportTable exports;
977 size_t numPorts = moduleOp.getNumPorts();
978 for (size_t i = 0; i < numPorts; ++i) {
979 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
980 if (!port)
981 continue;
982 auto value = table.getOptUnderlyingDomain(port);
983 if (value)
984 exports[value].push_back(port);
985 }
986
987 return exports;
988}
989
990//====--------------------------------------------------------------------------
991// Updating: write domains back to the IR.
992//====--------------------------------------------------------------------------
993
/// A map from unsolved variables to a port index, where that port has not yet
/// been created. Eventually we will have an input domain at the port index,
/// which will be the solution to the recorded variable. The index is the
/// port's position after all pending insertions have been applied.
using PendingSolutions = DenseMap<VariableTerm *, unsigned>;

/// A map from local domains to an aliasing port index, where that port has not
/// yet been created. Eventually we will be exporting the domain value at the
/// port index. As above, the index is the port's position after all pending
/// insertions have been applied.
using PendingExports = llvm::MapVector<DomainValue, unsigned>;

namespace {
/// Interface changes computed for a module, before they are written to the IR.
struct PendingUpdates {
  /// New ports to add, as (insertion point, port info) pairs, in the form
  /// passed to the module's insertPorts.
  PortInsertions insertions;
  /// Variables that will be solved by newly inserted input domain ports.
  PendingSolutions solutions;
  /// Local domains that will be exported through newly inserted output ports.
  PendingExports exports;
};
} // namespace
1011
1012/// If `var` is not solved, solve it by recording a pending input port at
1013/// the indicated insertion point.
1014static void ensureSolved(const DomainInfo &info, Namespace &ns,
1015 DomainTypeID typeID, size_t ip, LocationAttr loc,
1016 VariableTerm *var, PendingUpdates &pending) {
1017 if (pending.solutions.contains(var))
1018 return;
1019
1020 auto *context = loc.getContext();
1021 auto domainDecl = info.getDomain(typeID);
1022 auto domainName = domainDecl.getNameAttr();
1023
1024 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1025 auto portType = DomainType::getFromDomainOp(domainDecl);
1026 auto portDirection = Direction::In;
1027 auto portSym = StringAttr();
1028 auto portLoc = loc;
1029 auto portAnnos = std::nullopt;
1030 // Domain type ports have no associations (domain info is in the type).
1031 auto portDomainInfo = ArrayAttr::get(context, {});
1032 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1033 portAnnos, portDomainInfo);
1034
1035 pending.solutions[var] = pending.insertions.size() + ip;
1036 pending.insertions.push_back({ip, portInfo});
1037}
1038
1039/// Ensure that the domain value is available in the signature of the moduleOp,
1040/// so that subsequent hardware ports may be associated with this domain.
1041// If the domain is defined internally in the moduleOp, ensure it is aliased by
1042// an
1043/// output port.
1044static void ensureExported(const DomainInfo &info, Namespace &ns,
1045 const ExportTable &exports, DomainTypeID typeID,
1046 size_t ip, LocationAttr loc, ValueTerm *val,
1047 PendingUpdates &pending) {
1048 auto value = val->value;
1049 assert(isa<DomainType>(value.getType()));
1050 if (isPort(value) || exports.contains(value) ||
1051 pending.exports.contains(value))
1052 return;
1053
1054 auto *context = loc.getContext();
1055
1056 auto domainDecl = info.getDomain(typeID);
1057 auto domainName = domainDecl.getNameAttr();
1058
1059 auto portName = StringAttr::get(context, ns.newName(domainName.getValue()));
1060 auto portType = DomainType::getFromDomainOp(domainDecl);
1061 auto portDirection = Direction::Out;
1062 auto portSym = StringAttr();
1063 auto portLoc = value.getLoc();
1064 auto portAnnos = std::nullopt;
1065 // Domain type ports have no associations (domain info is in the type).
1066 auto portDomainInfo = ArrayAttr::get(context, {});
1067 PortInfo portInfo(portName, portType, portDirection, portSym, portLoc,
1068 portAnnos, portDomainInfo);
1069 pending.exports[value] = pending.insertions.size() + ip;
1070 pending.insertions.push_back({ip, portInfo});
1071}
1072
1073static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info,
1074 Namespace &ns,
1075 PendingUpdates &pending,
1076 DomainTypeID typeID, size_t ip,
1077 LocationAttr loc, Term *term,
1078 const ExportTable &exports) {
1079 if (auto *var = dyn_cast<VariableTerm>(term)) {
1080 ensureSolved(info, ns, typeID, ip, loc, var, pending);
1081 return;
1082 }
1083 if (auto *val = dyn_cast<ValueTerm>(term)) {
1084 ensureExported(info, ns, exports, typeID, ip, loc, val, pending);
1085 return;
1086 }
1087 llvm_unreachable("invalid domain association");
1088}
1089
1091 const DomainInfo &info, Namespace &ns, const ExportTable &exports,
1092 size_t ip, LocationAttr loc, RowTerm *row, PendingUpdates &pending) {
1093 for (auto [index, term] : llvm::enumerate(row->elements))
1094 getUpdatesForDomainAssociationOfPort(info, ns, pending, DomainTypeID{index},
1095 ip, loc, find(term), exports);
1096}
1097
1098static void getUpdatesForModulePorts(const DomainInfo &info,
1099 TermAllocator &allocator,
1100 const ExportTable &exports,
1101 DomainTable &table, Namespace &ns,
1102 FModuleOp moduleOp,
1103 PendingUpdates &pending) {
1104 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1105 auto port = moduleOp.getArgument(i);
1106 auto type = port.getType();
1107 if (!isa<FIRRTLBaseType>(type))
1108 continue;
1110 info, ns, exports, i, moduleOp.getPortLocation(i),
1111 getDomainAssociationAsRow(info, allocator, table, port), pending);
1112 }
1113}
1114
1115static void getUpdatesForModule(const DomainInfo &info,
1116 TermAllocator &allocator,
1117 const ExportTable &exports, DomainTable &table,
1118 FModuleOp mod, PendingUpdates &pending) {
1119 Namespace ns;
1120 auto names = mod.getPortNamesAttr();
1121 for (auto name : names.getAsRange<StringAttr>())
1122 ns.add(name);
1123 getUpdatesForModulePorts(info, allocator, exports, table, ns, mod, pending);
1124}
1125
1126static void applyUpdatesToModule(const DomainInfo &info,
1127 TermAllocator &allocator, ExportTable &exports,
1128 DomainTable &table, FModuleOp moduleOp,
1129 const PendingUpdates &pending) {
1130 // Put the domain ports in place.
1131 moduleOp.insertPorts(pending.insertions);
1132
1133 // Solve any variables and record them as "self-exporting".
1134 for (auto [var, portIndex] : pending.solutions) {
1135 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1136 auto *solution = allocator.allocVal(portValue);
1137 solve(var, solution);
1138 exports[portValue].push_back(portValue);
1139 }
1140
1141 // Drive the output ports, and record the export.
1142 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1143 for (auto [domainValue, portIndex] : pending.exports) {
1144 auto portValue = cast<DomainValue>(moduleOp.getArgument(portIndex));
1145 builder.setInsertionPointAfterValue(domainValue);
1146 DomainDefineOp::create(builder, portValue.getLoc(), portValue, domainValue);
1147
1148 exports[domainValue].push_back(portValue);
1149 table.setTermForDomain(portValue, allocator.allocVal(domainValue));
1150 }
1151}
1152
1153/// Copy the domain associations from the moduleOp domain info attribute into a
1154/// small vector.
1155static SmallVector<Attribute>
1156copyPortDomainAssociations(const DomainInfo &info, FModuleLike moduleOp,
1157 ArrayAttr moduleDomainInfo, size_t portIndex) {
1158 SmallVector<Attribute> result(info.getNumDomains());
1159 auto oldAssociations = getPortDomainAssociation(moduleDomainInfo, portIndex);
1160 for (auto domainPortIndexAttr : oldAssociations) {
1161
1162 auto domainPortIndex = domainPortIndexAttr.getUInt();
1163 auto domainTypeID = info.getDomainTypeID(moduleOp, domainPortIndex);
1164 result[domainTypeID.index] = domainPortIndexAttr;
1165 };
1166 return result;
1167}
1168
1169// If the port is an output domain, we may need to drive the output with
1170// a value. If we don't know what value to drive to the port, error.
1171static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info,
1172 const DomainTable &table,
1173 FModuleOp moduleOp) {
1174 auto builder = OpBuilder::atBlockEnd(moduleOp.getBodyBlock());
1175 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1176 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1177 if (!port || moduleOp.getPortDirection(i) == Direction::In ||
1178 isDriven(port))
1179 continue;
1180
1181 auto *term = table.getOptTermForDomain(port);
1182 auto *val = llvm::dyn_cast_if_present<ValueTerm>(term);
1183 if (!val) {
1184 emitDomainPortInferenceError(moduleOp, i);
1185 return failure();
1186 }
1187
1188 auto loc = port.getLoc();
1189 auto value = val->value;
1190 DomainDefineOp::create(builder, loc, port, value);
1191 }
1192
1193 return success();
1194}
1195
1196/// After generalizing the moduleOp, all domains should be solved. Reflect the
1197/// solved domain associations into the port domain info attribute.
1198static LogicalResult updateModuleDomainInfo(const DomainInfo &info,
1199 const DomainTable &table,
1200 const ExportTable &exportTable,
1201 ArrayAttr &result,
1202 FModuleOp moduleOp) {
1203 // At this point, all domain variables mentioned in ports have been
1204 // solved by generalizing the moduleOp (adding input domain ports). Now, we
1205 // have to form the new port domain information for the moduleOp by examining
1206 // the the associated domains of each port.
1207 auto *context = moduleOp.getContext();
1208 auto numDomains = info.getNumDomains();
1209 auto oldModuleDomainInfo = moduleOp.getDomainInfoAttr();
1210 auto numPorts = moduleOp.getNumPorts();
1211 SmallVector<Attribute> newModuleDomainInfo(numPorts);
1212
1213 for (size_t i = 0; i < numPorts; ++i) {
1214 auto port = moduleOp.getArgument(i);
1215 auto type = port.getType();
1216
1217 if (isa<DomainType>(type)) {
1218 // Domain type ports have no associations (domain info is in the type).
1219 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1220 continue;
1221 }
1222
1223 if (isa<FIRRTLBaseType>(type)) {
1224 auto associations =
1225 copyPortDomainAssociations(info, moduleOp, oldModuleDomainInfo, i);
1226 auto *row = cast<RowTerm>(table.getDomainAssociation(port));
1227 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1228 auto domainTypeID = DomainTypeID{domainIndex};
1229 if (associations[domainIndex])
1230 continue;
1231
1232 auto domain = cast<ValueTerm>(find(row->elements[domainIndex]))->value;
1233 auto &exports = exportTable.at(domain);
1234 if (exports.empty()) {
1235 auto portName = moduleOp.getPortNameAttr(i);
1236 auto portLoc = moduleOp.getPortLocation(i);
1237 auto domainDecl = info.getDomain(domainTypeID);
1238 auto domainName = domainDecl.getNameAttr();
1239 auto diag = emitError(portLoc)
1240 << "private " << domainName << " association for port "
1241 << portName;
1242 diag.attachNote(domain.getLoc()) << "associated domain: " << domain;
1243 noteLocation(diag, moduleOp);
1244 return failure();
1245 }
1246
1247 if (exports.size() > 1) {
1248 emitAmbiguousPortDomainAssociation(info, moduleOp, exports,
1249 domainTypeID, i);
1250 return failure();
1251 }
1252
1253 auto argument = cast<BlockArgument>(exports[0]);
1254 auto domainPortIndex = argument.getArgNumber();
1255 associations[domainTypeID.index] = IntegerAttr::get(
1256 IntegerType::get(context, 32, IntegerType::Unsigned),
1257 domainPortIndex);
1258 }
1259
1260 newModuleDomainInfo[i] = ArrayAttr::get(context, associations);
1261 continue;
1262 }
1263
1264 newModuleDomainInfo[i] = ArrayAttr::get(context, {});
1265 }
1266
1267 result = ArrayAttr::get(moduleOp.getContext(), newModuleDomainInfo);
1268 moduleOp.setDomainInfoAttr(result);
1269 return success();
1270}
1271
1272static LogicalResult updateInstance(const DomainInfo &info,
1273 TermAllocator &allocator,
1274 DomainTable &table, FInstanceLike op,
1275 OpBuilder &builder) {
1276 auto numPorts = op->getNumResults();
1277 for (size_t i = 0; i < numPorts; ++i) {
1278 auto port = dyn_cast<DomainValue>(op->getResult(i));
1279 auto direction = op.getPortDirection(i);
1280
1281 // If the port is an input domain, we may need to drive the input with
1282 // a value. If we don't know what value to drive to the port, drive an
1283 // anonymous domain.
1284 if (port && direction == Direction::In && !isDriven(port)) {
1285 auto loc = port.getLoc();
1286 auto *term = getTermForDomain(allocator, table, port);
1287 if (auto *var = dyn_cast<VariableTerm>(term)) {
1288 auto domainType = cast<DomainType>(op->getResult(i).getType());
1289 auto domainTypeID = info.getDomainTypeID(domainType);
1290 auto domainDecl = info.getDomain(domainTypeID);
1291 auto name = domainDecl.getNameAttr();
1292 DomainValue anon;
1293 {
1294 OpBuilder::InsertionGuard guard(builder);
1295 builder.setInsertionPointAfter(op);
1296 anon = DomainCreateAnonOp::create(builder, loc, domainType, name);
1297 }
1298 solve(var, allocator.allocVal(anon));
1299 // Create domain.define at the end of the block to avoid use-before-def.
1300 DomainDefineOp::create(builder, loc, port, anon);
1301 continue;
1302 }
1303 if (auto *val = dyn_cast<ValueTerm>(term)) {
1304 auto value = val->value;
1305 // Create domain.define at the end of the block to avoid use-before-def.
1306 DomainDefineOp::create(builder, loc, port, value);
1307 continue;
1308 }
1309 llvm_unreachable("unhandled domain term type");
1310 }
1311 }
1312
1313 return success();
1314}
1315
1316/// Update a wire operation with inferred domain associations.
1317static LogicalResult updateWire(const DomainInfo &info,
1318 TermAllocator &allocator, DomainTable &table,
1319 WireOp wireOp) {
1320 auto result = wireOp.getResult();
1321 if (!isa<FIRRTLBaseType>(result.getType()))
1322 return success();
1323
1324 // Get the inferred domain associations for this wire.
1325 auto *term = table.getOptDomainAssociation(result);
1326 if (!term)
1327 return success();
1328
1329 auto *row = dyn_cast<RowTerm>(find(term));
1330 if (!row)
1331 return success();
1332
1333 // Collect the domain values to add as operands.
1334 SmallVector<Value> domainOperands;
1335 for (auto *element : llvm::map_range(row->elements, find))
1336 if (auto *val = dyn_cast_or_null<ValueTerm>(element))
1337 domainOperands.push_back(val->value);
1338
1339 // Update the wire's domain operands in place. $domains is the only operand
1340 // group on WireOp, so setOperands replaces exactly the domain list.
1341 if (!domainOperands.empty() && wireOp.getDomains().empty())
1342 wireOp->setOperands(domainOperands);
1343
1344 return success();
1345}
1346
1347/// After updating the port domain associations, walk the body of the moduleOp
1348/// to fix up any child instance modules and update wires with inferred domains.
1349static LogicalResult updateModuleBody(const DomainInfo &info,
1350 TermAllocator &allocator,
1351 DomainTable &table, FModuleOp moduleOp) {
1352 // Set insertion point to end of block so all domain.define operations are
1353 // created there, avoiding use-before-def issues.
1354 OpBuilder builder(moduleOp.getContext());
1355 builder.setInsertionPointToEnd(moduleOp.getBodyBlock());
1356
1357 // Update instances
1358 auto instanceResult =
1359 moduleOp.getBodyBlock()->walk([&](FInstanceLike op) -> WalkResult {
1360 return updateInstance(info, allocator, table, op, builder);
1361 });
1362 if (instanceResult.wasInterrupted())
1363 return failure();
1364
1365 // Update wires with inferred domain associations
1366 auto wireResult = moduleOp.getBodyBlock()->walk([&](WireOp op) -> WalkResult {
1367 return updateWire(info, allocator, table, op);
1368 });
1369 return failure(wireResult.wasInterrupted());
1370}
1371
1372/// Write the domain associations recorded in the domain table back to the IR.
1373static LogicalResult updateModule(const DomainInfo &info,
1374 TermAllocator &allocator, DomainTable &table,
1375 ModuleUpdateTable &updates, FModuleOp op) {
1376 auto exports = initializeExportTable(table, op);
1377 PendingUpdates pending;
1378 getUpdatesForModule(info, allocator, exports, table, op, pending);
1379 applyUpdatesToModule(info, allocator, exports, table, op, pending);
1380
1381 // Update the domain info for the moduleOp's ports.
1382 ArrayAttr portDomainInfo;
1383 if (failed(updateModuleDomainInfo(info, table, exports, portDomainInfo, op)))
1384 return failure();
1385
1386 // Drive output domain ports.
1387 if (failed(driveModuleOutputDomainPorts(info, table, op)))
1388 return failure();
1389
1390 // Record the updated interface change in the update table.
1391 auto &entry = updates[op.getModuleNameAttr()];
1392 entry.portDomainInfo = portDomainInfo;
1393 entry.portInsertions = std::move(pending.insertions);
1394
1395 if (failed(updateModuleBody(info, allocator, table, op)))
1396 return failure();
1397
1398 return success();
1399}
1400
1401//===---------------------------------------------------------------------------
1402// Checking: Check that a moduleOp has complete domain information.
1403//===---------------------------------------------------------------------------
1404
1405/// Check that a module's hardware ports have complete domain associations.
1406static LogicalResult checkModulePorts(const DomainInfo &info,
1407 FModuleLike moduleOp) {
1408 auto numDomains = info.getNumDomains();
1409 auto domainInfo = moduleOp.getDomainInfoAttr();
1410 auto numPorts = moduleOp.getNumPorts();
1411
1412 DenseMap<unsigned, DomainTypeID> domainTypeIDTable;
1413 for (size_t i = 0; i < numPorts; ++i) {
1414 if (isa<DomainType>(moduleOp.getPortType(i)))
1415 domainTypeIDTable[i] = info.getDomainTypeID(moduleOp, i);
1416 }
1417
1418 for (size_t i = 0; i < numPorts; ++i) {
1419 auto type = type_dyn_cast<FIRRTLBaseType>(moduleOp.getPortType(i));
1420 if (!type)
1421 continue;
1422
1423 // Record the domain associations of this port.
1424 SmallVector<IntegerAttr> associations(numDomains);
1425 for (auto domainPortIndex : getPortDomainAssociation(domainInfo, i)) {
1426 auto domainTypeID = domainTypeIDTable.at(domainPortIndex.getUInt());
1427 auto prevDomainPortIndex = associations[domainTypeID.index];
1428 if (prevDomainPortIndex) {
1429 emitDuplicatePortDomainError(info, moduleOp, i, domainTypeID,
1430 prevDomainPortIndex, domainPortIndex);
1431 return failure();
1432 }
1433 associations[domainTypeID.index] = domainPortIndex;
1434 }
1435
1436 // Check the associations for completeness.
1437 for (size_t domainIndex = 0; domainIndex < numDomains; ++domainIndex) {
1438 auto typeID = DomainTypeID{domainIndex};
1439 if (!associations[domainIndex]) {
1440 emitMissingPortDomainAssociationError(info, moduleOp, typeID, i);
1441 return failure();
1442 }
1443 }
1444 }
1445
1446 return success();
1447}
1448
1449/// Check that output domain ports are driven.
1450static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info,
1451 FModuleOp moduleOp) {
1452 for (size_t i = 0, e = moduleOp.getNumPorts(); i < e; ++i) {
1453 auto port = dyn_cast<DomainValue>(moduleOp.getArgument(i));
1454 if (!port || moduleOp.getPortDirection(i) != Direction::Out ||
1455 isDriven(port))
1456 continue;
1457
1458 auto name = moduleOp.getPortNameAttr(i);
1459 auto diag = emitError(moduleOp.getPortLocation(i))
1460 << "undriven domain port " << name;
1461 noteLocation(diag, moduleOp);
1462 return failure();
1463 }
1464
1465 return success();
1466}
1467
1468/// Check that the input domain ports are driven.
1469static LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op) {
1470 for (size_t i = 0, e = op->getNumResults(); i < e; ++i) {
1471 auto port = dyn_cast<DomainValue>(op->getResult(i));
1472
1473 auto type = port.getType();
1474 if (!isa<DomainType>(type) || op.getPortDirection(i) != Direction::In ||
1475 isDriven(port))
1476 continue;
1477
1478 auto name = op.getPortNameAttr(i);
1479 auto diag = emitError(op.getPortLocation(i))
1480 << "undriven domain port " << name;
1481 noteLocation(diag, op);
1482 return failure();
1483 }
1484
1485 return success();
1486}
1487
1488/// Check that instances under this module have driven domain input ports.
1489static LogicalResult checkModuleBody(FModuleOp moduleOp) {
1490 auto result = moduleOp.getBody().walk([](FInstanceLike op) -> WalkResult {
1492 });
1493 return failure(result.wasInterrupted());
1494}
1495
1496//===---------------------------------------------------------------------------
1497// Domain Stripping.
1498//===---------------------------------------------------------------------------
1499
/// Remove all domain information from the module-like `op`. This erases:
/// - domain-typed ports;
/// - domain creation and define ops;
/// - domain subfield ops and unsafe domain casts;
/// - domain-typed wires and the domain operands of other wires;
/// - domain-typed ports of instances.
///
/// Fails, emitting an error on the offending op, if any other operation still
/// has a domain-typed operand or result.
static LogicalResult stripModule(FModuleLike op) {
  // Post-order with reversed iteration means:
  // - nested ops are visited before their parent, so the module itself (and
  //   its ports) is handled last;
  // - within a block, later ops are visited first.
  // Presumably this ensures users are erased before the values they use.
  WalkResult result = op->walk<mlir::WalkOrder::PostOrder, ReverseIterator>(
      [=](Operation *op) -> WalkResult {
        return TypeSwitch<Operation *, WalkResult>(op)
            // The module itself: erase all of its domain-typed ports.
            .Case<FModuleLike>([](FModuleLike op) {
              auto n = op.getNumPorts();
              BitVector erasures(n);
              for (size_t i = 0; i < n; ++i)
                if (isa<DomainType>(op.getPortType(i)))
                  erasures.set(i);
              op.erasePorts(erasures);
              return WalkResult::advance();
            })
            // Domain creation and driving ops are erased outright.
            .Case<DomainDefineOp, DomainCreateAnonOp, DomainCreateOp>(
                [](Operation *op) {
                  op->erase();
                  return WalkResult::advance();
                })
            // Remaining uses of a domain subfield are replaced with an unknown
            // value of the same type.
            .Case<DomainSubfieldOp>([](DomainSubfieldOp op) {
              if (!op->use_empty()) {
                OpBuilder builder(op);
                op.replaceAllUsesWith(
                    UnknownValueOp::create(builder, op.getLoc(), op.getType())
                        .getResult());
              }
              op.erase();
              return WalkResult::advance();
            })
            // An unsafe domain cast is replaced with its input.
            .Case<UnsafeDomainCastOp>([](UnsafeDomainCastOp op) {
              op.replaceAllUsesWith(op.getInput());
              op.erase();
              return WalkResult::advance();
            })
            .Case<WireOp>([](WireOp op) {
              // Erase wires of DomainType
              if (isa<DomainType>(op.getType(0))) {
                op->erase();
                return WalkResult::advance();
              }
              // Erase domain operands from regular wires
              if (!op.getDomains().empty()) {
                op->eraseOperands(0, op.getNumOperands());
              }
              return WalkResult::advance();
            })
            // Instances are rebuilt without their domain-typed ports, and the
            // original instance is erased.
            .Case<FInstanceLike>([](auto op) {
              auto n = op.getNumPorts();
              BitVector erasures(n);
              for (size_t i = 0; i < n; ++i)
                if (isa<DomainType>(op->getResult(i).getType()))
                  erasures.set(i);
              op.cloneWithErasedPortsAndReplaceUses(erasures);
              op.erase();
              return WalkResult::advance();
            })
            // Any other op that still touches a domain type cannot be
            // stripped; report it and stop the walk.
            .Default([](Operation *op) {
              for (auto type :
                   concat<Type>(op->getOperandTypes(), op->getResultTypes())) {
                if (isa<DomainType>(type)) {
                  op->emitOpError("cannot be stripped");
                  return WalkResult::interrupt();
                }
              }
              return WalkResult::advance();
            });
      });
  return failure(result.wasInterrupted());
}
1568
1569static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit) {
1570 llvm::SmallVector<FModuleLike> modules;
1571 for (Operation &op : make_early_inc_range(*circuit.getBodyBlock())) {
1572 TypeSwitch<Operation *, void>(&op)
1573 .Case<FModuleLike>([&](FModuleLike op) { modules.push_back(op); })
1574 .Case<DomainOp>([](DomainOp op) { op.erase(); });
1575 }
1576 return failableParallelForEach(context, modules, stripModule);
1577}
1578
1579//===---------------------------------------------------------------------------
1580// InferDomainsPass: Top-level pass implementation.
1581//===---------------------------------------------------------------------------
1582
1583/// Solve for domains and then write the domain associations back to the IR.
1584static LogicalResult inferModule(const DomainInfo &info,
1585 ModuleUpdateTable &updates,
1586 FModuleOp moduleOp) {
1587 TermAllocator allocator;
1588 DomainTable table;
1589
1590 if (failed(processModule(info, allocator, table, updates, moduleOp)))
1591 return failure();
1592
1593 return updateModule(info, allocator, table, updates, moduleOp);
1594}
1595
1596/// Check that a module's ports are fully annotated, before performing domain
1597/// inference on the module.
1598static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp) {
1599 if (failed(checkModulePorts(info, moduleOp)))
1600 return failure();
1601
1602 if (failed(checkModuleDomainPortDrivers(info, moduleOp)))
1603 return failure();
1604
1605 if (failed(checkModuleBody(moduleOp)))
1606 return failure();
1607
1608 TermAllocator allocator;
1609 DomainTable table;
1610 ModuleUpdateTable updateTable;
1611 return processModule(info, allocator, table, updateTable, moduleOp);
1612}
1613
1614/// Check that an extmodule's ports are fully annotated.
1615static LogicalResult checkModule(const DomainInfo &info,
1616 FExtModuleOp moduleOp) {
1617 return checkModulePorts(info, moduleOp);
1618}
1619
1620/// Check that a module's ports are fully annotated, before performing domain
1621/// inference on the module. We use this when private module interfaces are
1622/// inferred but public module interfaces are checked.
1623static LogicalResult checkAndInferModule(const DomainInfo &info,
1624 ModuleUpdateTable &updateTable,
1625 FModuleOp moduleOp) {
1626 if (failed(checkModulePorts(info, moduleOp)))
1627 return failure();
1628
1629 TermAllocator allocator;
1630 DomainTable table;
1631 if (failed(processModule(info, allocator, table, updateTable, moduleOp)))
1632 return failure();
1633
1634 if (failed(driveModuleOutputDomainPorts(info, table, moduleOp)))
1635 return failure();
1636
1637 return updateModuleBody(info, allocator, table, moduleOp);
1638}
1639
1640static LogicalResult runOnModuleLike(InferDomainsMode mode,
1641 const DomainInfo &info,
1642 ModuleUpdateTable &updateTable,
1643 Operation *op) {
1644 assert(mode != InferDomainsMode::Strip);
1645
1646 if (auto moduleOp = dyn_cast<FModuleOp>(op)) {
1647 if (mode == InferDomainsMode::Check)
1648 return checkModule(info, moduleOp);
1649
1650 if (mode == InferDomainsMode::InferAll || moduleOp.isPrivate())
1651 return inferModule(info, updateTable, moduleOp);
1652
1653 return checkAndInferModule(info, updateTable, moduleOp);
1654 }
1655
1656 if (auto extModule = dyn_cast<FExtModuleOp>(op))
1657 return checkModule(info, extModule);
1658
1659 return success();
1660}
1661
1662namespace {
1663struct InferDomainsPass
1664 : public circt::firrtl::impl::InferDomainsBase<InferDomainsPass> {
1665 using Base::Base;
1666 void runOnOperation() override {
1668 auto circuit = getOperation();
1669
1670 if (mode == InferDomainsMode::Strip) {
1671 if (failed(stripCircuit(&getContext(), circuit)))
1672 signalPassFailure();
1673 return;
1674 }
1675
1676 auto &instanceGraph = getAnalysis<InstanceGraph>();
1677 DomainInfo info(circuit);
1678 ModuleUpdateTable updateTable;
1679 DenseSet<Operation *> errored;
1680 instanceGraph.walkPostOrder([&](auto &node) {
1681 auto moduleOp = node.getModule();
1682 for (auto *inst : node) {
1683 if (errored.contains(inst->getTarget()->getModule())) {
1684 errored.insert(moduleOp);
1685 return;
1686 }
1687 }
1688 if (failed(runOnModuleLike(mode, info, updateTable, node.getModule())))
1689 errored.insert(moduleOp);
1690 });
1691 if (errored.size())
1692 signalPassFailure();
1693 }
1694};
1695} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static LogicalResult updateModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, ModuleUpdateTable &updates, FModuleOp op)
Write the domain associations recorded in the domain table back to the IR.
static void emitDuplicatePortDomainError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, IntegerAttr domainPortIndexAttr1, IntegerAttr domainPortIndexAttr2)
static ExportTable initializeExportTable(const DomainTable &table, FModuleOp moduleOp)
Build a table of exported domains: a map from domains defined internally, to their set of aliasing ou...
static LogicalResult processModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
static LogicalResult updateInstance(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FInstanceLike op, OpBuilder &builder)
static LogicalResult updateWire(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, WireOp wireOp)
Update a wire operation with inferred domain associations.
static void emitAmbiguousPortDomainAssociation(const DomainInfo &info, T op, const llvm::TinyPtrVector< DomainValue > &exports, DomainTypeID typeID, size_t i)
static LogicalResult processModulePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
static LogicalResult inferModule(const DomainInfo &info, ModuleUpdateTable &updates, FModuleOp moduleOp)
Solve for domains and then write the domain associations back to the IR.
static LogicalResult driveModuleOutputDomainPorts(const DomainInfo &info, const DomainTable &table, FModuleOp moduleOp)
SmallVector< std::pair< unsigned, PortInfo > > PortInsertions
llvm::MapVector< DomainValue, unsigned > PendingExports
A map from local domains to an aliasing port index, where that port has not yet been created.
static LogicalResult runOnModuleLike(InferDomainsMode mode, const DomainInfo &info, ModuleUpdateTable &updateTable, Operation *op)
mlir::TypedValue< DomainType > DomainValue
static LogicalResult stripCircuit(MLIRContext *context, CircuitOp circuit)
static RowTerm * getDomainAssociationAsRow(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Value value)
Get the row of domains that a hardware value in the IR is associated with.
static void emitMissingPortDomainAssociationError(const DomainInfo &info, T op, DomainTypeID typeID, size_t i)
static void getUpdatesForModulePorts(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, Namespace &ns, FModuleOp moduleOp, PendingUpdates &pending)
static LogicalResult checkInstanceDomainPortDrivers(FInstanceLike op)
Check that the input domain ports are driven.
static LogicalResult updateModuleBody(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, FModuleOp moduleOp)
After updating the port domain associations, walk the body of the moduleOp to fix up any child instan...
static FInstanceLike fixInstancePorts(FInstanceLike op, const ModuleUpdateInfo &update)
Apply the port changes of a moduleOp onto an instance-like op.
static void render(const DomainInfo &info, Diagnostic &out, VariableIDTable &idTable, Term *term)
static void processDomainDefinition(TermAllocator &allocator, DomainTable &table, DomainValue domain)
static LogicalResult unify(Term *lhs, Term *rhs)
static LogicalResult updateModuleDomainInfo(const DomainInfo &info, const DomainTable &table, const ExportTable &exportTable, ArrayAttr &result, FModuleOp moduleOp)
After generalizing the moduleOp, all domains should be solved.
static LogicalResult unifyAssociations(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, Operation *op, Value lhs, Value rhs)
Unify the associated domain rows of two terms.
DenseMap< VariableTerm *, unsigned > PendingSolutions
A map from unsolved variables to a port index, where that port has not yet been created.
static LogicalResult checkModule(const DomainInfo &info, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult processOp(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FInstanceLike op)
static LogicalResult checkModuleDomainPortDrivers(const DomainInfo &info, FModuleOp moduleOp)
Check that output domain ports are driven.
static void noteLocation(mlir::InFlightDiagnostic &diag, Operation *op)
static void emitPortDomainCrossingError(const DomainInfo &info, T op, size_t i, DomainTypeID domainTypeID, Term *term1, Term *term2)
static void getUpdatesForDomainAssociationOfPort(const DomainInfo &info, Namespace &ns, PendingUpdates &pending, DomainTypeID typeID, size_t ip, LocationAttr loc, Term *term, const ExportTable &exports)
static void applyUpdatesToModule(const DomainInfo &info, TermAllocator &allocator, ExportTable &exports, DomainTable &table, FModuleOp moduleOp, const PendingUpdates &pending)
static void getUpdatesForModule(const DomainInfo &info, TermAllocator &allocator, const ExportTable &exports, DomainTable &table, FModuleOp mod, PendingUpdates &pending)
static SmallVector< Attribute > copyPortDomainAssociations(const DomainInfo &info, FModuleLike moduleOp, ArrayAttr moduleDomainInfo, size_t portIndex)
Copy the domain associations from the moduleOp domain info attribute into a small vector.
static void emitDomainPortInferenceError(T op, size_t i)
Emit an error when we fail to infer the concrete domain to drive to a domain port.
static void ensureExported(const DomainInfo &info, Namespace &ns, const ExportTable &exports, DomainTypeID typeID, size_t ip, LocationAttr loc, ValueTerm *val, PendingUpdates &pending)
Ensure that the domain value is available in the signature of the moduleOp, so that subsequent hardwa...
static bool isPort(BlockArgument arg)
Return true if the value is a port on the module.
DenseMap< DomainValue, TinyPtrVector< DomainValue > > ExportTable
A map from domain IR values defined internal to the moduleOp, to ports that alias that domain.
static Term * getTermForDomain(TermAllocator &allocator, DomainTable &table, DomainValue value)
Get the corresponding term for a domain in the IR.
static auto getPortDomainAssociation(ArrayAttr info, size_t i)
From a domain info attribute, get the row of associated domains for a hardware value at index i.
DenseMap< StringAttr, ModuleUpdateInfo > ModuleUpdateTable
static void ensureSolved(const DomainInfo &info, Namespace &ns, DomainTypeID typeID, size_t ip, LocationAttr loc, VariableTerm *var, PendingUpdates &pending)
If var is not solved, solve it by recording a pending input port at the indicated insertion point.
static bool isDriven(DomainValue port)
Returns true if the value is driven by a connect op.
static void solve(Term *lhs, Term *rhs)
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static LogicalResult checkModuleBody(FModuleOp moduleOp)
Check that instances under this module have driven domain input ports.
static LogicalResult checkModulePorts(const DomainInfo &info, FModuleLike moduleOp)
Check that a module's hardware ports have complete domain associations.
static Term * find(Term *x)
static LogicalResult processInstancePorts(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, T op)
static LogicalResult checkAndInferModule(const DomainInfo &info, ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Check that a module's ports are fully annotated, before performing domain inference on the module.
static LogicalResult stripModule(FModuleLike op)
#define CIRCT_DEBUG_SCOPED_PASS_LOGGER(PASS)
Definition Debug.h:70
This class represents a reference to a specific field or element of an aggregate value.
Definition FieldRef.h:28
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void add(mlir::ModuleOp module)
Definition Namespace.h:48
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
Definition Namespace.h:87
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
InferDomainsMode
The mode for the InferDomains pass.
Definition Passes.h:73
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
void info(Twine message)
Definition LSPUtils.cpp:20
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name and type that describes the module's ports.