CIRCT 20.0.0git
Loading...
Searching...
No Matches
InferWidths.cpp
Go to the documentation of this file.
1//===- InferWidths.cpp - Infer width of types -------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines the InferWidths pass.
10//
11//===----------------------------------------------------------------------===//
12
17#include "circt/Support/Debug.h"
19#include "mlir/IR/Threading.h"
20#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/APSInt.h"
22#include "llvm/ADT/DenseSet.h"
23#include "llvm/ADT/Hashing.h"
24#include "llvm/ADT/SetVector.h"
25#include "llvm/Support/Debug.h"
26#include "llvm/Support/ErrorHandling.h"
27
28#define DEBUG_TYPE "infer-widths"
29
30namespace circt {
31namespace firrtl {
32#define GEN_PASS_DEF_INFERWIDTHS
33#include "circt/Dialect/FIRRTL/Passes.h.inc"
34} // namespace firrtl
35} // namespace circt
36
37using mlir::InferTypeOpInterface;
38using mlir::WalkOrder;
39
40using namespace circt;
41using namespace firrtl;
42
43//===----------------------------------------------------------------------===//
44// Helpers
45//===----------------------------------------------------------------------===//
46
47static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t,
48 Twine str) {
49 auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
50 if (!basetype)
51 return;
52 if (!basetype.hasUninferredWidth())
53 return;
54
55 if (basetype.isGround())
56 diag.attachNote() << "Field: \"" << str << "\"";
57 else if (auto vecType = type_dyn_cast<FVectorType>(basetype))
58 diagnoseUninferredType(diag, vecType.getElementType(), str + "[]");
59 else if (auto bundleType = type_dyn_cast<BundleType>(basetype))
60 for (auto &elem : bundleType.getElements())
61 diagnoseUninferredType(diag, elem.type, str + "." + elem.name.getValue());
62}
63
64/// Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
65static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type) {
66 uint64_t convertedFieldID = 0;
67
68 auto curFID = fieldID;
69 Type curFType = type;
70 while (curFID != 0) {
71 auto [child, subID] =
73 if (isa<FVectorType>(curFType))
74 convertedFieldID++; // Vector fieldID is 1.
75 else
76 convertedFieldID += curFID - subID; // Add consumed portion.
77 curFID = subID;
78 curFType = child;
79 }
80
81 return convertedFieldID;
82}
83
84//===----------------------------------------------------------------------===//
85// Constraint Expressions
86//===----------------------------------------------------------------------===//
87
88namespace {
89struct Expr;
90} // namespace
91
92/// Allow rvalue refs to `Expr` and subclasses to be printed to streams.
93template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
94 int>::type = 0>
95inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T &e) {
96 e.print(os);
97 return os;
98}
99
100// Allow expression subclasses to be hashed.
101namespace mlir {
102template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
103 int>::type = 0>
104inline llvm::hash_code hash_value(const T &e) {
105 return e.hash_value();
106}
107} // namespace mlir
108
109namespace {
110#define EXPR_NAMES(x) \
111 Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
112#define EXPR_KINDS EXPR_NAMES()
113#define EXPR_CLASSES EXPR_NAMES(Expr)
114
115/// An expression on the right-hand side of a constraint.
116struct Expr {
117 enum class Kind : uint8_t { EXPR_KINDS };
118
119 /// Print a human-readable representation of this expr.
120 void print(llvm::raw_ostream &os) const;
121
122 std::optional<int32_t> getSolution() const {
123 if (hasSolution)
124 return solution;
125 return std::nullopt;
126 }
127
128 void setSolution(int32_t solution) {
129 hasSolution = true;
130 this->solution = solution;
131 }
132
133 Kind getKind() const { return kind; }
134
135protected:
136 Expr(Kind kind) : kind(kind) {}
137 llvm::hash_code hash_value() const { return llvm::hash_value(kind); }
138
139private:
140 int32_t solution;
141 Kind kind;
142 bool hasSolution = false;
143};
144
145/// Helper class to CRTP-derive common functions.
146template <class DerivedT, Expr::Kind DerivedKind>
147struct ExprBase : public Expr {
148 ExprBase() : Expr(DerivedKind) {}
149 static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
150 bool operator==(const Expr &other) const {
151 if (auto otherSame = dyn_cast<DerivedT>(other))
152 return *static_cast<DerivedT *>(this) == otherSame;
153 return false;
154 }
155};
156
157/// A free variable.
158struct VarExpr : public ExprBase<VarExpr, Expr::Kind::Var> {
159 void print(llvm::raw_ostream &os) const {
160 // Hash the `this` pointer into something somewhat human readable. Since
161 // this is just for debug dumping, we wrap around at 65536 variables.
162 os << "var" << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFFF);
163 }
164
165 /// The constraint expression this variable is supposed to be greater than or
166 /// equal to. This is not part of the variable's hash and equality property.
167 Expr *constraint = nullptr;
168
169 /// The upper bound this variable is supposed to be smaller than or equal to.
170 Expr *upperBound = nullptr;
171 std::optional<int32_t> upperBoundSolution;
172};
173
174/// A derived width.
175///
176/// These are generated for `InvalidValueOp`s which want to derived their width
177/// from connect operations that they are on the right hand side of.
178struct DerivedExpr : public ExprBase<DerivedExpr, Expr::Kind::Derived> {
179 void print(llvm::raw_ostream &os) const {
180 // Hash the `this` pointer into something somewhat human readable.
181 os << "derive"
182 << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFF);
183 }
184
185 /// The expression this derived width is equivalent to.
186 Expr *assigned = nullptr;
187};
188
189/// An identity expression.
190///
191/// This expression evaluates to its inner expression. It is used in a very
192/// specific case of constraints on variables, in order to be able to track
193/// where the constraint was imposed. Constraints on variables are represented
194/// as `var >= <expr>`. When the first constraint `a` is imposed, it is stored
195/// as the constraint expression (`var >= a`). When the second constraint `b` is
196/// imposed, a *new* max expression is allocated (`var >= max(a, b)`).
197/// Expressions are annotated with a location when they are created, which in
198/// this case are connect ops. Since imposing the first constraint does not
199/// create any new expression, the location information of that connect would be
200/// lost. With an identity expression, imposing the first constraint becomes
201/// `var >= identity(a)`, which is a *new* expression and properly tracks the
202/// location info.
203struct IdExpr : public ExprBase<IdExpr, Expr::Kind::Id> {
204 IdExpr(Expr *arg) : arg(arg) { assert(arg); }
205 void print(llvm::raw_ostream &os) const { os << "*" << *arg; }
206 bool operator==(const IdExpr &other) const {
207 return getKind() == other.getKind() && arg == other.arg;
208 }
209 llvm::hash_code hash_value() const {
210 return llvm::hash_combine(Expr::hash_value(), arg);
211 }
212
213 /// The inner expression.
214 Expr *const arg;
215};
216
217/// A known constant value.
218struct KnownExpr : public ExprBase<KnownExpr, Expr::Kind::Known> {
219 KnownExpr(int32_t value) : ExprBase() { setSolution(value); }
220 void print(llvm::raw_ostream &os) const { os << *getSolution(); }
221 bool operator==(const KnownExpr &other) const {
222 return *getSolution() == *other.getSolution();
223 }
224 llvm::hash_code hash_value() const {
225 return llvm::hash_combine(Expr::hash_value(), *getSolution());
226 }
227 int32_t getValue() const { return *getSolution(); }
228};
229
230/// A unary expression. Contains the actual data. Concrete subclasses are merely
231/// there for show and ease of use.
232struct UnaryExpr : public Expr {
233 bool operator==(const UnaryExpr &other) const {
234 return getKind() == other.getKind() && arg == other.arg;
235 }
236 llvm::hash_code hash_value() const {
237 return llvm::hash_combine(Expr::hash_value(), arg);
238 }
239
240 /// The child expression.
241 Expr *const arg;
242
243protected:
244 UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) { assert(arg); }
245};
246
247/// Helper class to CRTP-derive common functions.
248template <class DerivedT, Expr::Kind DerivedKind>
249struct UnaryExprBase : public UnaryExpr {
250 template <typename... Args>
251 UnaryExprBase(Args &&...args)
252 : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
253 static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
254};
255
256/// A power of two.
257struct PowExpr : public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
258 using UnaryExprBase::UnaryExprBase;
259 void print(llvm::raw_ostream &os) const { os << "2^" << arg; }
260};
261
262/// A binary expression. Contains the actual data. Concrete subclasses are
263/// merely there for show and ease of use.
264struct BinaryExpr : public Expr {
265 bool operator==(const BinaryExpr &other) const {
266 return getKind() == other.getKind() && lhs() == other.lhs() &&
267 rhs() == other.rhs();
268 }
269 llvm::hash_code hash_value() const {
270 return llvm::hash_combine(Expr::hash_value(), *args);
271 }
272 Expr *lhs() const { return args[0]; }
273 Expr *rhs() const { return args[1]; }
274
275 /// The child expressions.
276 Expr *const args[2];
277
278protected:
279 BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
280 assert(lhs);
281 assert(rhs);
282 }
283};
284
285/// Helper class to CRTP-derive common functions.
286template <class DerivedT, Expr::Kind DerivedKind>
287struct BinaryExprBase : public BinaryExpr {
288 template <typename... Args>
289 BinaryExprBase(Args &&...args)
290 : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
291 static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
292};
293
294/// An addition.
295struct AddExpr : public BinaryExprBase<AddExpr, Expr::Kind::Add> {
296 using BinaryExprBase::BinaryExprBase;
297 void print(llvm::raw_ostream &os) const {
298 os << "(" << *lhs() << " + " << *rhs() << ")";
299 }
300};
301
302/// The maximum of two expressions.
303struct MaxExpr : public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
304 using BinaryExprBase::BinaryExprBase;
305 void print(llvm::raw_ostream &os) const {
306 os << "max(" << *lhs() << ", " << *rhs() << ")";
307 }
308};
309
310/// The minimum of two expressions.
311struct MinExpr : public BinaryExprBase<MinExpr, Expr::Kind::Min> {
312 using BinaryExprBase::BinaryExprBase;
313 void print(llvm::raw_ostream &os) const {
314 os << "min(" << *lhs() << ", " << *rhs() << ")";
315 }
316};
317
318void Expr::print(llvm::raw_ostream &os) const {
319 TypeSwitch<const Expr *>(this).Case<EXPR_CLASSES>(
320 [&](auto *e) { e->print(os); });
321}
322
323} // namespace
324
325//===----------------------------------------------------------------------===//
326// Fast bump allocator with optional interning
327//===----------------------------------------------------------------------===//
328
329namespace {
330
331// Hash slots in the interned allocator as if they were the pointed-to value
332// itself.
333template <typename T>
334struct InternedSlotInfo : DenseMapInfo<T *> {
335 static T *getEmptyKey() {
337 return static_cast<T *>(pointer);
338 }
339 static T *getTombstoneKey() {
341 return static_cast<T *>(pointer);
342 }
343 static unsigned getHashValue(const T *val) { return mlir::hash_value(*val); }
344 static bool isEqual(const T *lhs, const T *rhs) {
345 auto empty = getEmptyKey();
346 auto tombstone = getTombstoneKey();
347 if (lhs == empty || rhs == empty || lhs == tombstone || rhs == tombstone)
348 return lhs == rhs;
349 return *lhs == *rhs;
350 }
351};
352
353/// A simple bump allocator that ensures only ever one copy per object exists.
354/// The allocated objects must not have a destructor.
355template <typename T, typename std::enable_if_t<
356 std::is_trivially_destructible<T>::value, int> = 0>
357class InternedAllocator {
358 llvm::DenseSet<T *, InternedSlotInfo<T>> interned;
359 llvm::BumpPtrAllocator &allocator;
360
361public:
362 InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
363
364 /// Allocate a new object if it does not yet exist, or return a pointer to the
365 /// existing one. `R` is the type of the object to be allocated. `R` must be
366 /// derived from or be the type `T`.
367 template <typename R = T, typename... Args>
368 std::pair<R *, bool> alloc(Args &&...args) {
369 auto stackValue = R(std::forward<Args>(args)...);
370 auto *stackSlot = &stackValue;
371 auto it = interned.find(stackSlot);
372 if (it != interned.end())
373 return std::make_pair(static_cast<R *>(*it), false);
374 auto heapValue = new (allocator) R(std::move(stackValue));
375 interned.insert(heapValue);
376 return std::make_pair(heapValue, true);
377 }
378};
379
380/// A simple bump allocator. The allocated objects must not have a destructor.
381/// This allocator is mainly there for symmetry with the `InternedAllocator`.
382template <typename T, typename std::enable_if_t<
383 std::is_trivially_destructible<T>::value, int> = 0>
384class Allocator {
385 llvm::BumpPtrAllocator &allocator;
386
387public:
388 Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
389
390 /// Allocate a new object. `R` is the type of the object to be allocated. `R`
391 /// must be derived from or be the type `T`.
392 template <typename R = T, typename... Args>
393 R *alloc(Args &&...args) {
394 return new (allocator) R(std::forward<Args>(args)...);
395 }
396};
397
398} // namespace
399
400//===----------------------------------------------------------------------===//
401// Constraint Solver
402//===----------------------------------------------------------------------===//
403
404namespace {
405/// A canonicalized linear inequality that maps a constraint on var `x` to the
406/// linear inequality `x >= max(a*x+b, c) + (failed ? ∞ : 0)`.
407///
408/// The inequality separately tracks recursive (a, b) and non-recursive (c)
409/// constraints on `x`. This allows it to properly identify the combination of
410/// the two constraints `x >= x-1` and `x >= 4` to be satisfiable as
411/// `x >= max(x-1, 4)`. If it only tracked inequality as `x >= a*x+b`, the
412/// combination of these two constraints would be `x >= x+4` (due to max(-1,4) =
413/// 4), which would be unsatisfiable.
414///
415/// The `failed` flag acts as an additional `∞` term that renders the inequality
416/// unsatisfiable. It is used as a tombstone value in case an operation renders
417/// the equality unsatisfiable (e.g. `x >= 2**x` would be represented as the
418/// inequality `x >= ∞`).
419///
420/// Inequalities represented in this form can easily be checked for
421/// unsatisfiability in the presence of recursion by inspecting the coefficients
422/// a and b. The `sat` function performs this action.
423struct LinIneq {
424 // x >= max(a*x+b, c) + (failed ? ∞ : 0)
425 int32_t recScale = 0; // a
426 int32_t recBias = 0; // b
427 int32_t nonrecBias = 0; // c
428 bool failed = false;
429
430 /// Create a new unsatisfiable inequality `x >= ∞`.
431 static LinIneq unsat() { return LinIneq(true); }
432
433 /// Create a new inequality `x >= (failed ? ∞ : 0)`.
434 explicit LinIneq(bool failed = false) : failed(failed) {}
435
436 /// Create a new inequality `x >= bias`.
437 explicit LinIneq(int32_t bias) : nonrecBias(bias) {}
438
439 /// Create a new inequality `x >= scale*x+bias`.
440 explicit LinIneq(int32_t scale, int32_t bias) {
441 if (scale != 0) {
442 recScale = scale;
443 recBias = bias;
444 } else {
445 nonrecBias = bias;
446 }
447 }
448
449 /// Create a new inequality `x >= max(recScale*x+recBias, nonrecBias) +
450 /// (failed ? ∞ : 0)`.
451 explicit LinIneq(int32_t recScale, int32_t recBias, int32_t nonrecBias,
452 bool failed = false)
453 : failed(failed) {
454 if (recScale != 0) {
455 this->recScale = recScale;
456 this->recBias = recBias;
457 this->nonrecBias = nonrecBias;
458 } else {
459 this->nonrecBias = std::max(recBias, nonrecBias);
460 }
461 }
462
463 /// Combine two inequalities by taking the maxima of corresponding
464 /// coefficients.
465 ///
466 /// This essentially combines `x >= max(a1*x+b1, c1)` and `x >= max(a2*x+b2,
467 /// c2)` into a new `x >= max(max(a1,a2)*x+max(b1,b2), max(c1,c2))`. This is
468 /// a pessimistic upper bound, since e.g. `x >= 2x-10` and `x >= x-5` may both
469 /// hold, but the resulting `x >= 2x-5` may pessimistically not hold.
470 static LinIneq max(const LinIneq &lhs, const LinIneq &rhs) {
471 return LinIneq(std::max(lhs.recScale, rhs.recScale),
472 std::max(lhs.recBias, rhs.recBias),
473 std::max(lhs.nonrecBias, rhs.nonrecBias),
474 lhs.failed || rhs.failed);
475 }
476
477 /// Combine two inequalities by summing up the two right hand sides.
478 ///
479 /// This is a tricky one, since the addition of the two max terms will lead to
480 /// a maximum over four possible terms (similar to a binomial expansion). In
481 /// order to shoehorn this back into a two-term maximum, we have to pick the
482 /// recursive term that will grow the fastest.
483 ///
484 /// As an example for this problem, consider the following addition:
485 ///
486 /// x >= max(a1*x+b1, c1) + max(a2*x+b2, c2)
487 ///
488 /// We would like to expand and rearrange this again into a maximum:
489 ///
490 /// x >= max(a1*x+b1 + max(a2*x+b2, c2), c1 + max(a2*x+b2, c2))
491 /// x >= max(max(a1*x+b1 + a2*x+b2, a1*x+b1 + c2),
492 /// max(c1 + a2*x+b2, c1 + c2))
493 /// x >= max((a1+a2)*x+(b1+b2), a1*x+(b1+c2), a2*x+(b2+c1), c1+c2)
494 ///
495 /// Since we are combining two two-term maxima, there are four possible ways
496 /// how the terms can combine, leading to the above four-term maximum. An easy
497 /// upper bound of the form we want would be the following:
498 ///
499 /// x >= max(max(a1+a2, a1, a2)*x + max(b1+b2, b1+c2, b2+c1), c1+c2)
500 ///
501 /// However, this is a very pessimistic upper-bound that will declare very
502 /// common patterns in the IR as unbreakable cycles, despite them being very
503 /// much breakable. For example:
504 ///
505 /// x >= max(x, 42) + max(0, -3) <-- breakable recursion
506 /// x >= max(max(1+0, 1, 0)*x + max(42+0, -3, 42), 42-2)
507 /// x >= max(x + 42, 39) <-- unbreakable recursion!
508 ///
509 /// A better approach is to take the expanded four-term maximum, retain the
510 /// non-recursive term (c1+c2), and estimate which one of the recursive terms
511 /// (first three) will become dominant as we choose greater values for x.
512 /// Since x never is inferred to be negative, the recursive term in the
513 /// maximum with the highest scaling factor for x will end up dominating as
514 /// x tends to ∞:
515 ///
516 /// x >= max({
517 /// (a1+a2)*x+(b1+b2) if a1+a2 >= max(a1+a2, a1, a2) and a1>0 and a2>0,
518 /// a1*x+(b1+c2) if a1 >= max(a1+a2, a1, a2) and a1>0,
519 /// a2*x+(b2+c1) if a2 >= max(a1+a2, a1, a2) and a2>0,
520 /// 0 otherwise
521 /// }, c1+c2)
522 ///
523 /// In case multiple cases apply, the highest bias of the recursive term is
524 /// picked. With this, the above problematic example triggers the second case
525 /// and becomes:
526 ///
527 /// x >= max(1*x+(0-3), 42-3) = max(x-3, 39)
528 ///
529 /// Of which the first case is chosen, as it has the lower bias value.
530 static LinIneq add(const LinIneq &lhs, const LinIneq &rhs) {
531 // Determine the maximum scaling factor among the three possible recursive
532 // terms.
533 auto enable1 = lhs.recScale > 0 && rhs.recScale > 0;
534 auto enable2 = lhs.recScale > 0;
535 auto enable3 = rhs.recScale > 0;
536 auto scale1 = lhs.recScale + rhs.recScale; // (a1+a2)
537 auto scale2 = lhs.recScale; // a1
538 auto scale3 = rhs.recScale; // a2
539 auto bias1 = lhs.recBias + rhs.recBias; // (b1+b2)
540 auto bias2 = lhs.recBias + rhs.nonrecBias; // (b1+c2)
541 auto bias3 = rhs.recBias + lhs.nonrecBias; // (b2+c1)
542 auto maxScale = std::max(scale1, std::max(scale2, scale3));
543
544 // Among those terms that have a maximum scaling factor, determine the
545 // largest bias value.
546 std::optional<int32_t> maxBias;
547 if (enable1 && scale1 == maxScale)
548 maxBias = bias1;
549 if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
550 maxBias = bias2;
551 if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
552 maxBias = bias3;
553
554 // Pick from the recursive terms the one with maximum scaling factor and
555 // minimum bias value.
556 auto nonrecBias = lhs.nonrecBias + rhs.nonrecBias; // c1+c2
557 auto failed = lhs.failed || rhs.failed;
558 if (enable1 && scale1 == maxScale && bias1 == *maxBias)
559 return LinIneq(scale1, bias1, nonrecBias, failed);
560 if (enable2 && scale2 == maxScale && bias2 == *maxBias)
561 return LinIneq(scale2, bias2, nonrecBias, failed);
562 if (enable3 && scale3 == maxScale && bias3 == *maxBias)
563 return LinIneq(scale3, bias3, nonrecBias, failed);
564 return LinIneq(0, 0, nonrecBias, failed);
565 }
566
567 /// Check if the inequality is satisfiable.
568 ///
569 /// The inequality becomes unsatisfiable if the RHS is ∞, or a>1, or a==1 and
570 /// b <= 0. Otherwise there exists as solution for `x` that satisfies the
571 /// inequality.
572 bool sat() const {
573 if (failed)
574 return false;
575 if (recScale > 1)
576 return false;
577 if (recScale == 1 && recBias > 0)
578 return false;
579 return true;
580 }
581
582 /// Dump the inequality in human-readable form.
583 void print(llvm::raw_ostream &os) const {
584 bool any = false;
585 bool both = (recScale != 0 || recBias != 0) && nonrecBias != 0;
586 os << "x >= ";
587 if (both)
588 os << "max(";
589 if (recScale != 0) {
590 any = true;
591 if (recScale != 1)
592 os << recScale << "*";
593 os << "x";
594 }
595 if (recBias != 0) {
596 if (any) {
597 if (recBias < 0)
598 os << " - " << -recBias;
599 else
600 os << " + " << recBias;
601 } else {
602 any = true;
603 os << recBias;
604 }
605 }
606 if (both)
607 os << ", ";
608 if (nonrecBias != 0) {
609 any = true;
610 os << nonrecBias;
611 }
612 if (both)
613 os << ")";
614 if (failed) {
615 if (any)
616 os << " + ";
617 os << "∞";
618 }
619 if (!any)
620 os << "0";
621 }
622};
623
624/// A simple solver for width constraints.
625class ConstraintSolver {
626public:
627 ConstraintSolver() = default;
628
629 VarExpr *var() {
630 auto *v = vars.alloc();
631 varExprs.push_back(v);
632 if (currentInfo)
633 info[v].insert(currentInfo);
634 if (currentLoc)
635 locs[v].insert(*currentLoc);
636 return v;
637 }
638 DerivedExpr *derived() {
639 auto *d = derivs.alloc();
640 derivedExprs.push_back(d);
641 return d;
642 }
643 KnownExpr *known(int32_t value) { return alloc<KnownExpr>(knowns, value); }
644 IdExpr *id(Expr *arg) { return alloc<IdExpr>(ids, arg); }
645 PowExpr *pow(Expr *arg) { return alloc<PowExpr>(uns, arg); }
646 AddExpr *add(Expr *lhs, Expr *rhs) { return alloc<AddExpr>(bins, lhs, rhs); }
647 MaxExpr *max(Expr *lhs, Expr *rhs) { return alloc<MaxExpr>(bins, lhs, rhs); }
648 MinExpr *min(Expr *lhs, Expr *rhs) { return alloc<MinExpr>(bins, lhs, rhs); }
649
650 /// Add a constraint `lhs >= rhs`. Multiple constraints on the same variable
651 /// are coalesced into a `max(a, b)` expr.
652 Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
653 if (lhs->constraint)
654 lhs->constraint = max(lhs->constraint, rhs);
655 else
656 lhs->constraint = id(rhs);
657 return lhs->constraint;
658 }
659
660 /// Add a constraint `lhs <= rhs`. Multiple constraints on the same variable
661 /// are coalesced into a `min(a, b)` expr.
662 Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
663 if (lhs->upperBound)
664 lhs->upperBound = min(lhs->upperBound, rhs);
665 else
666 lhs->upperBound = id(rhs);
667 return lhs->upperBound;
668 }
669
670 void dumpConstraints(llvm::raw_ostream &os);
671 LogicalResult solve();
672
673 using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
674 const ContextInfo &getContextInfo() const { return info; }
675 void setCurrentContextInfo(FieldRef fieldRef) { currentInfo = fieldRef; }
676 void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
677
678private:
679 // Allocator for constraint expressions.
680 llvm::BumpPtrAllocator allocator;
681 Allocator<VarExpr> vars = {allocator};
682 Allocator<DerivedExpr> derivs = {allocator};
683 InternedAllocator<KnownExpr> knowns = {allocator};
684 InternedAllocator<IdExpr> ids = {allocator};
685 InternedAllocator<UnaryExpr> uns = {allocator};
686 InternedAllocator<BinaryExpr> bins = {allocator};
687
688 /// A list of expressions in the order they were created.
689 std::vector<VarExpr *> varExprs;
690 std::vector<DerivedExpr *> derivedExprs;
691
692 /// Add an allocated expression to the list above.
693 template <typename R, typename T, typename... Args>
694 R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
695 auto [expr, inserted] =
696 allocator.template alloc<R>(std::forward<Args>(args)...);
697 if (currentInfo)
698 info[expr].insert(currentInfo);
699 if (currentLoc)
700 locs[expr].insert(*currentLoc);
701 return expr;
702 }
703
704 /// Contextual information for each expression, indicating which values in the
705 /// IR lead to this expression.
706 ContextInfo info;
707 FieldRef currentInfo = {};
708 DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
709 std::optional<Location> currentLoc;
710
711 // Forbid copyign or moving the solver, which would invalidate the refs to
712 // allocator held by the allocators.
713 ConstraintSolver(ConstraintSolver &&) = delete;
714 ConstraintSolver(const ConstraintSolver &) = delete;
715 ConstraintSolver &operator=(ConstraintSolver &&) = delete;
716 ConstraintSolver &operator=(const ConstraintSolver &) = delete;
717
718 void emitUninferredWidthError(VarExpr *var);
719
720 LinIneq checkCycles(VarExpr *var, Expr *expr,
721 SmallPtrSetImpl<Expr *> &seenVars,
722 InFlightDiagnostic *reportInto = nullptr,
723 unsigned indent = 1);
724};
725
726} // namespace
727
728/// Print all constraints in the solver to an output stream.
729void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
730 for (auto *v : varExprs) {
731 if (v->constraint)
732 os << "- " << *v << " >= " << *v->constraint << "\n";
733 else
734 os << "- " << *v << " unconstrained\n";
735 }
736}
737
738#ifndef NDEBUG
739inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const LinIneq &l) {
740 l.print(os);
741 return os;
742}
743#endif
744
745/// Compute the canonicalized linear inequality expression starting at `expr`,
746/// for the `var` as the left hand side `x` of the inequality. `seenVars` is
747/// used as a recursion breaker. Occurrences of `var` itself within the
748/// expression are mapped to the `a` coefficient in the inequality. Any other
749/// variables are substituted and, in the presence of a recursion in a variable
750/// other than `var`, treated as zero. `info` is a mapping from constraint
751/// expressions to values and operations that produced the expression, and is
752/// used during error reporting. If `reportInto` is present, the function will
753/// additionally attach unsatisfiable inequalities as notes to the diagnostic as
754/// it encounters them.
755LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
756 SmallPtrSetImpl<Expr *> &seenVars,
757 InFlightDiagnostic *reportInto,
758 unsigned indent) {
759 auto ineq =
760 TypeSwitch<Expr *, LinIneq>(expr)
761 .Case<KnownExpr>(
762 [&](auto *expr) { return LinIneq(expr->getValue()); })
763 .Case<VarExpr>([&](auto *expr) {
764 if (expr == var)
765 return LinIneq(1, 0); // x >= 1*x + 0
766 if (!seenVars.insert(expr).second)
767 // Count recursions in other variables as 0. This is sane
768 // since the cycle is either breakable, in which case the
769 // recursion does not modify the resulting value of the
770 // variable, or it is not breakable and will be caught by
771 // this very function once it is called on that variable.
772 return LinIneq(0);
773 if (!expr->constraint)
774 // Count unconstrained variables as `x >= 0`.
775 return LinIneq(0);
776 auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
777 indent + 1);
778 seenVars.erase(expr);
779 return l;
780 })
781 .Case<IdExpr>([&](auto *expr) {
782 return checkCycles(var, expr->arg, seenVars, reportInto,
783 indent + 1);
784 })
785 .Case<PowExpr>([&](auto *expr) {
786 // If we can evaluate `2**arg` to a sensible constant, do
787 // so. This is the case if a == 0 and c < 31 such that 2**c is
788 // representable.
789 auto arg =
790 checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
791 if (arg.recScale != 0 || arg.nonrecBias < 0 || arg.nonrecBias >= 31)
792 return LinIneq::unsat();
793 return LinIneq(1 << arg.nonrecBias); // x >= 2**arg
794 })
795 .Case<AddExpr>([&](auto *expr) {
796 return LinIneq::add(
797 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
798 checkCycles(var, expr->rhs(), seenVars, reportInto,
799 indent + 1));
800 })
801 .Case<MaxExpr, MinExpr>([&](auto *expr) {
802 // Combine the inequalities of the LHS and RHS into a single overly
803 // pessimistic inequality. We treat `MinExpr` the same as `MaxExpr`,
804 // since `max(a,b)` is an upper bound to `min(a,b)`.
805 return LinIneq::max(
806 checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
807 checkCycles(var, expr->rhs(), seenVars, reportInto,
808 indent + 1));
809 })
810 .Default([](auto) { return LinIneq::unsat(); });
811
812 // If we were passed an in-flight diagnostic and the current inequality is
813 // unsatisfiable, attach notes to the diagnostic indicating the values or
814 // operations that contributed to this part of the constraint expression.
815 if (reportInto && !ineq.sat()) {
816 auto report = [&](Location loc) {
817 auto &note = reportInto->attachNote(loc);
818 note << "constrained width W >= ";
819 if (ineq.recScale == -1)
820 note << "-";
821 if (ineq.recScale != 1)
822 note << ineq.recScale;
823 note << "W";
824 if (ineq.recBias < 0)
825 note << "-" << -ineq.recBias;
826 if (ineq.recBias > 0)
827 note << "+" << ineq.recBias;
828 note << " here:";
829 };
830 auto it = locs.find(expr);
831 if (it != locs.end())
832 for (auto loc : it->second)
833 report(loc);
834 }
835 if (!reportInto)
836 LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
837 << "- Visited " << *expr << ": " << ineq << "\n");
838
839 return ineq;
840}
841
842using ExprSolution = std::pair<std::optional<int32_t>, bool>;
843
844static ExprSolution
845computeUnary(ExprSolution arg, llvm::function_ref<int32_t(int32_t)> operation) {
846 if (arg.first)
847 arg.first = operation(*arg.first);
848 return arg;
849}
850
851static ExprSolution
853 llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
854 auto result = ExprSolution{std::nullopt, lhs.second || rhs.second};
855 if (lhs.first && rhs.first)
856 result.first = operation(*lhs.first, *rhs.first);
857 else if (lhs.first)
858 result.first = lhs.first;
859 else if (rhs.first)
860 result.first = rhs.first;
861 return result;
862}
863
864namespace {
865struct Frame {
866 Frame(Expr *expr, unsigned indent) : expr(expr), indent(indent) {}
867 Expr *expr;
868 // Indent is only used for debug logs.
869 unsigned indent;
870};
871} // namespace
872
873/// Compute the value of a constraint `expr`. `seenVars` is used as a recursion
874/// breaker. Recursive variables are treated as zero. Returns the computed value
875/// and a boolean indicating whether a recursion was detected. This may be used
876/// to memoize the result of expressions in case they were not involved in a
877/// cycle (which may alter their value from the perspective of a variable).
878static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
879 std::vector<Frame> &worklist) {
880 worklist.clear();
881 worklist.emplace_back(expr, 1);
882 llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
883
884 while (!worklist.empty()) {
885 auto &frame = worklist.back();
886 auto indent = frame.indent;
887 auto setSolution = [&](ExprSolution solution) {
888 // Memoize the result.
889 if (solution.first && !solution.second)
890 frame.expr->setSolution(*solution.first);
891 solvedExprs[frame.expr] = solution;
892
893 // Produce some useful debug prints.
894 LLVM_DEBUG({
895 if (!isa<KnownExpr>(frame.expr)) {
896 if (solution.first)
897 llvm::dbgs().indent(indent * 2)
898 << "= Solved " << *frame.expr << " = " << *solution.first;
899 else
900 llvm::dbgs().indent(indent * 2) << "= Skipped " << *frame.expr;
901 llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
902 << ")\n";
903 }
904 });
905
906 worklist.pop_back();
907 };
908
909 // See if we have a memoized result we can return.
910 if (frame.expr->getSolution()) {
911 LLVM_DEBUG({
912 if (!isa<KnownExpr>(frame.expr))
913 llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
914 << *frame.expr->getSolution() << "\n";
915 });
916 setSolution(ExprSolution{*frame.expr->getSolution(), false});
917 continue;
918 }
919
920 // Otherwise compute the value of the expression.
921 LLVM_DEBUG({
922 if (!isa<KnownExpr>(frame.expr))
923 llvm::dbgs().indent(indent * 2) << "- Solving " << *frame.expr << "\n";
924 });
925
926 TypeSwitch<Expr *>(frame.expr)
927 .Case<KnownExpr>([&](auto *expr) {
928 setSolution(ExprSolution{expr->getValue(), false});
929 })
930 .Case<VarExpr>([&](auto *expr) {
931 if (solvedExprs.contains(expr->constraint)) {
932 auto solution = solvedExprs[expr->constraint];
933 // If we've solved the upper bound already, store the solution.
934 // This will be explicitly solved for later if not computed as
935 // part of the solving that resolved this constraint.
936 // This should only happen if somehow the constraint is
937 // solved before visiting this expression, so that our upperBound
938 // was not added to the worklist such that it was handled first.
939 if (expr->upperBound && solvedExprs.contains(expr->upperBound))
940 expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
941 seenVars.erase(expr);
942 // Constrain variables >= 0.
943 if (solution.first && *solution.first < 0)
944 solution.first = 0;
945 return setSolution(solution);
946 }
947
948 // Unconstrained variables produce no solution.
949 if (!expr->constraint)
950 return setSolution(ExprSolution{std::nullopt, false});
951 // Return no solution for recursions in the variables. This is sane
952 // and will cause the expression to be ignored when computing the
953 // parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
954 if (!seenVars.insert(expr).second)
955 return setSolution(ExprSolution{std::nullopt, true});
956
957 worklist.emplace_back(expr->constraint, indent + 1);
958 if (expr->upperBound)
959 worklist.emplace_back(expr->upperBound, indent + 1);
960 })
961 .Case<IdExpr>([&](auto *expr) {
962 if (solvedExprs.contains(expr->arg))
963 return setSolution(solvedExprs[expr->arg]);
964 worklist.emplace_back(expr->arg, indent + 1);
965 })
966 .Case<PowExpr>([&](auto *expr) {
967 if (solvedExprs.contains(expr->arg))
968 return setSolution(computeUnary(
969 solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
970
971 worklist.emplace_back(expr->arg, indent + 1);
972 })
973 .Case<AddExpr>([&](auto *expr) {
974 if (solvedExprs.contains(expr->lhs()) &&
975 solvedExprs.contains(expr->rhs()))
976 return setSolution(computeBinary(
977 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
978 [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
979
980 worklist.emplace_back(expr->lhs(), indent + 1);
981 worklist.emplace_back(expr->rhs(), indent + 1);
982 })
983 .Case<MaxExpr>([&](auto *expr) {
984 if (solvedExprs.contains(expr->lhs()) &&
985 solvedExprs.contains(expr->rhs()))
986 return setSolution(computeBinary(
987 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
988 [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
989
990 worklist.emplace_back(expr->lhs(), indent + 1);
991 worklist.emplace_back(expr->rhs(), indent + 1);
992 })
993 .Case<MinExpr>([&](auto *expr) {
994 if (solvedExprs.contains(expr->lhs()) &&
995 solvedExprs.contains(expr->rhs()))
996 return setSolution(computeBinary(
997 solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
998 [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
999
1000 worklist.emplace_back(expr->lhs(), indent + 1);
1001 worklist.emplace_back(expr->rhs(), indent + 1);
1002 })
1003 .Default([&](auto) {
1004 setSolution(ExprSolution{std::nullopt, false});
1005 });
1006 }
1007
1008 return solvedExprs[expr];
1009}
1010
1011/// Solve the constraint problem. This is a very simple implementation that
1012/// does not fully solve the problem if there are weird dependency cycles
1013/// present.
1014LogicalResult ConstraintSolver::solve() {
1015 LLVM_DEBUG({
1016 llvm::dbgs() << "\n";
1017 debugHeader("Constraints") << "\n\n";
1018 dumpConstraints(llvm::dbgs());
1019 });
1020
1021 // Ensure that there are no adverse cycles around.
1022 LLVM_DEBUG({
1023 llvm::dbgs() << "\n";
1024 debugHeader("Checking for unbreakable loops") << "\n\n";
1025 });
1026 SmallPtrSet<Expr *, 16> seenVars;
1027 bool anyFailed = false;
1028
1029 for (auto *var : varExprs) {
1030 if (!var->constraint)
1031 continue;
1032 LLVM_DEBUG(llvm::dbgs()
1033 << "- Checking " << *var << " >= " << *var->constraint << "\n");
1034
1035 // Canonicalize the variable's constraint expression into a form that allows
1036 // us to easily determine if any recursion leads to an unsatisfiable
1037 // constraint. The `seenVars` set acts as a recursion breaker.
1038 seenVars.insert(var);
1039 auto ineq = checkCycles(var, var->constraint, seenVars);
1040 seenVars.clear();
1041
1042 // If the constraint is satisfiable, we're done.
1043 // TODO: It's possible that this result is already sufficient to arrive at a
1044 // solution for the constraint, and the second pass further down is not
1045 // necessary. This would require more proper handling of `MinExpr` in the
1046 // cycle checking code.
1047 if (ineq.sat()) {
1048 LLVM_DEBUG(llvm::dbgs()
1049 << " = Breakable since " << ineq << " satisfiable\n");
1050 continue;
1051 }
1052
1053 // If we arrive here, the constraint is not satisfiable at all. To provide
1054 // some guidance to the user, we call the cycle checking code again, but
1055 // this time with an in-flight diagnostic to attach notes indicating
1056 // unsatisfiable paths in the cycle.
1057 LLVM_DEBUG(llvm::dbgs()
1058 << " = UNBREAKABLE since " << ineq << " unsatisfiable\n");
1059 anyFailed = true;
1060 for (auto fieldRef : info.find(var)->second) {
1061 // Depending on whether this value stems from an operation or not, create
1062 // an appropriate diagnostic identifying the value.
1063 auto *op = fieldRef.getDefiningOp();
1064 auto diag = op ? op->emitOpError()
1065 : mlir::emitError(fieldRef.getValue().getLoc())
1066 << "value ";
1067 diag << "is constrained to be wider than itself";
1068
1069 // Re-run the cycle checking, but this time reporting into the diagnostic.
1070 seenVars.insert(var);
1071 checkCycles(var, var->constraint, seenVars, &diag);
1072 seenVars.clear();
1073 }
1074 }
1075
1076 // If there were cycles, return now to avoid complaining to the user about
1077 // dependent widths not being inferred.
1078 if (anyFailed)
1079 return failure();
1080
1081 // Iterate over the constraint variables and solve each.
1082 LLVM_DEBUG({
1083 llvm::dbgs() << "\n";
1084 debugHeader("Solving constraints") << "\n\n";
1085 });
1086 std::vector<Frame> worklist;
1087 for (auto *var : varExprs) {
1088 // Complain about unconstrained variables.
1089 if (!var->constraint) {
1090 LLVM_DEBUG(llvm::dbgs() << "- Unconstrained " << *var << "\n");
1091 emitUninferredWidthError(var);
1092 anyFailed = true;
1093 continue;
1094 }
1095
1096 // Compute the value for the variable.
1097 LLVM_DEBUG(llvm::dbgs()
1098 << "- Solving " << *var << " >= " << *var->constraint << "\n");
1099 seenVars.insert(var);
1100 auto solution = solveExpr(var->constraint, seenVars, worklist);
1101 // Compute the upperBound if there is one and haven't already.
1102 if (var->upperBound && !var->upperBoundSolution)
1103 var->upperBoundSolution =
1104 solveExpr(var->upperBound, seenVars, worklist).first;
1105 seenVars.clear();
1106
1107 // Constrain variables >= 0.
1108 if (solution.first) {
1109 if (*solution.first < 0)
1110 solution.first = 0;
1111 var->setSolution(*solution.first);
1112 }
1113
1114 // In case the width could not be inferred, complain to the user. This might
1115 // be the case if the width depends on an unconstrained variable.
1116 if (!solution.first) {
1117 LLVM_DEBUG(llvm::dbgs() << " - UNSOLVED " << *var << "\n");
1118 emitUninferredWidthError(var);
1119 anyFailed = true;
1120 continue;
1121 }
1122 LLVM_DEBUG(llvm::dbgs()
1123 << " = Solved " << *var << " = " << solution.first << " ("
1124 << (solution.second ? "cycle broken" : "unique") << ")\n");
1125
1126 // Check if the solution we have found violates an upper bound.
1127 if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1128 LLVM_DEBUG(llvm::dbgs() << " ! Unsatisfiable " << *var
1129 << " <= " << var->upperBoundSolution << "\n");
1130 emitUninferredWidthError(var);
1131 anyFailed = true;
1132 }
1133 }
1134
1135 // Copy over derived widths.
1136 for (auto *derived : derivedExprs) {
1137 auto *assigned = derived->assigned;
1138 if (!assigned || !assigned->getSolution()) {
1139 LLVM_DEBUG(llvm::dbgs() << "- Unused " << *derived << " set to 0\n");
1140 derived->setSolution(0);
1141 } else {
1142 LLVM_DEBUG(llvm::dbgs() << "- Deriving " << *derived << " = "
1143 << assigned->getSolution() << "\n");
1144 derived->setSolution(*assigned->getSolution());
1145 }
1146 }
1147
1148 return failure(anyFailed);
1149}
1150
1151// Emits the diagnostic to inform the user about an uninferred width in the
1152// design. Returns true if an error was reported, false otherwise.
1153void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1154 FieldRef fieldRef = info.find(var)->second.back();
1155 Value value = fieldRef.getValue();
1156
1157 auto diag = mlir::emitError(value.getLoc(), "uninferred width:");
1158
1159 // Try to hint the user at what kind of node this is.
1160 if (isa<BlockArgument>(value)) {
1161 diag << " port";
1162 } else if (auto *op = value.getDefiningOp()) {
1163 TypeSwitch<Operation *>(op)
1164 .Case<WireOp>([&](auto) { diag << " wire"; })
1165 .Case<RegOp, RegResetOp>([&](auto) { diag << " reg"; })
1166 .Case<NodeOp>([&](auto) { diag << " node"; })
1167 .Default([&](auto) { diag << " value"; });
1168 } else {
1169 diag << " value";
1170 }
1171
1172 // Actually print what the user can refer to.
1173 auto [fieldName, rootKnown] = getFieldName(fieldRef);
1174 if (!fieldName.empty()) {
1175 if (!rootKnown)
1176 diag << " field";
1177 diag << " \"" << fieldName << "\"";
1178 }
1179
1180 if (!var->constraint) {
1181 diag << " is unconstrained";
1182 } else if (var->getSolution() && var->upperBoundSolution &&
1183 var->getSolution() > var->upperBoundSolution) {
1184 diag << " cannot satisfy all width requirements";
1185 LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1186 LLVM_DEBUG(llvm::dbgs() << *var->upperBound << "\n");
1187 auto loc = locs.find(var->constraint)->second.back();
1188 diag.attachNote(loc) << "width is constrained to be at least "
1189 << *var->getSolution() << " here:";
1190 loc = locs.find(var->upperBound)->second.back();
1191 diag.attachNote(loc) << "width is constrained to be at most "
1192 << *var->upperBoundSolution << " here:";
1193 } else {
1194 diag << " width cannot be determined";
1195 LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1196 auto loc = locs.find(var->constraint)->second.back();
1197 diag.attachNote(loc) << "width is constrained by an uninferred width here:";
1198 }
1199}
1200
1201//===----------------------------------------------------------------------===//
1202// Inference Constraint Problem Mapping
1203//===----------------------------------------------------------------------===//
1204
1205namespace {
1206
1207/// A helper class which maps the types and operations in a design to a set of
1208/// variables and constraints to be solved later.
1209class InferenceMapping {
1210public:
1211 InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1213 : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1214
1215 LogicalResult map(CircuitOp op);
1216 bool allWidthsKnown(Operation *op);
1217 LogicalResult mapOperation(Operation *op);
1218
1219 /// Declare all the variables in the value. If the value is a ground type,
1220 /// there is a single variable declared. If the value is an aggregate type,
1221 /// it sets up variables for each unknown width.
1222 void declareVars(Value value, bool isDerived = false);
1223
1224 /// Assign the constraint expressions of the fields in the `result` argument
1225 /// as the max of expressions in the `rhs` and `lhs` arguments. Both fields
1226 /// must be the same type.
1227 void maximumOfTypes(Value result, Value rhs, Value lhs);
1228
1229 /// Constrain the value "larger" to be greater than or equal to "smaller".
1230 /// These may be aggregate values. This is used for regular connects.
1231 void constrainTypes(Value larger, Value smaller, bool equal = false);
1232
1233 /// Constrain the expression "larger" to be greater than or equals to
1234 /// the expression "smaller".
1235 void constrainTypes(Expr *larger, Expr *smaller,
1236 bool imposeUpperBounds = false, bool equal = false);
1237
1238 /// Assign the constraint expressions of the fields in the `src` argument as
1239 /// the expressions for the `dst` argument. Both fields must be of the given
1240 /// `type`.
1241 void unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type);
1242
1243 /// Get the expr associated with the value. The value must be a non-aggregate
1244 /// type
1245 Expr *getExpr(Value value) const;
1246
1247 /// Get the expr associated with a specific field in a value.
1248 Expr *getExpr(FieldRef fieldRef) const;
1249
1250 /// Get the expr associated with a specific field in a value. If value is
1251 /// NULL, then this returns NULL.
1252 Expr *getExprOrNull(FieldRef fieldRef) const;
1253
1254 /// Set the expr associated with the value. The value must be a non-aggregate
1255 /// type.
1256 void setExpr(Value value, Expr *expr);
1257
1258 /// Set the expr associated with a specific field in a value.
1259 void setExpr(FieldRef fieldRef, Expr *expr);
1260
1261 /// Return whether a module was skipped due to being fully inferred already.
1262 bool isModuleSkipped(FModuleOp module) const {
1263 return skippedModules.count(module);
1264 }
1265
1266 /// Return whether all modules in the mapping were fully inferred.
1267 bool areAllModulesSkipped() const { return allModulesSkipped; }
1268
1269private:
1270 /// The constraint solver into which we emit variables and constraints.
1271 ConstraintSolver &solver;
1272
1273 /// The constraint exprs for each result type of an operation.
1274 DenseMap<FieldRef, Expr *> opExprs;
1275
1276 /// The fully inferred modules that were skipped entirely.
1277 SmallPtrSet<Operation *, 16> skippedModules;
1278 bool allModulesSkipped = true;
1279
1280 /// Cache of module symbols
1281 SymbolTable &symtbl;
1282
1283 /// Full design inner symbol information.
1285};
1286
1287} // namespace
1288
1289/// Check if a type contains any FIRRTL type with uninferred widths.
1290static bool hasUninferredWidth(Type type) {
1291 return TypeSwitch<Type, bool>(type)
1292 .Case<FIRRTLBaseType>([](auto base) { return base.hasUninferredWidth(); })
1293 .Case<RefType>(
1294 [](auto ref) { return ref.getType().hasUninferredWidth(); })
1295 .Default([](auto) { return false; });
1296}
1297
1298LogicalResult InferenceMapping::map(CircuitOp op) {
1299 LLVM_DEBUG(llvm::dbgs()
1300 << "\n===----- Mapping ops to constraint exprs -----===\n\n");
1301
1302 // Ensure we have constraint variables established for all module ports.
1303 for (auto module : op.getOps<FModuleOp>())
1304 for (auto arg : module.getArguments()) {
1305 solver.setCurrentContextInfo(FieldRef(arg, 0));
1306 declareVars(arg);
1307 }
1308
1309 for (auto module : op.getOps<FModuleOp>()) {
1310 // Check if the module contains *any* uninferred widths. This allows us to
1311 // do an early skip if the module is already fully inferred.
1312 bool anyUninferred = false;
1313 for (auto arg : module.getArguments()) {
1314 anyUninferred |= hasUninferredWidth(arg.getType());
1315 if (anyUninferred)
1316 break;
1317 }
1318 module.walk([&](Operation *op) {
1319 for (auto type : op->getResultTypes())
1320 anyUninferred |= hasUninferredWidth(type);
1321 if (anyUninferred)
1322 return WalkResult::interrupt();
1323 return WalkResult::advance();
1324 });
1325
1326 if (!anyUninferred) {
1327 LLVM_DEBUG(llvm::dbgs() << "Skipping fully-inferred module '"
1328 << module.getName() << "'\n");
1329 skippedModules.insert(module);
1330 continue;
1331 }
1332
1333 allModulesSkipped = false;
1334
1335 // Go through operations in the module, creating type variables for results,
1336 // and generating constraints.
1337 auto result = module.getBodyBlock()->walk(
1338 [&](Operation *op) { return WalkResult(mapOperation(op)); });
1339 if (result.wasInterrupted())
1340 return failure();
1341 }
1342
1343 return success();
1344}
1345
1346bool InferenceMapping::allWidthsKnown(Operation *op) {
1347 /// Ignore property assignments, no widths to infer.
1348 if (isa<PropAssignOp>(op))
1349 return true;
1350
1351 // If this is a mux, and the select signal is uninferred, we need to set an
1352 // upperbound limit on it.
1353 if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1354 if (hasUninferredWidth(op->getOperand(0).getType()))
1355 return false;
1356
1357 // We need to propagate through connects.
1358 if (isa<FConnectLike, AttachOp>(op))
1359 return false;
1360
1361 // Check if we know the width of every result of this operation.
1362 return llvm::all_of(op->getResults(), [&](auto result) {
1363 // Only consider FIRRTL types for width constraints. Ignore any foreign
1364 // types as they don't participate in the width inference process.
1365 if (auto type = type_dyn_cast<FIRRTLType>(result.getType()))
1366 if (hasUninferredWidth(type))
1367 return false;
1368 return true;
1369 });
1370}
1371
1372LogicalResult InferenceMapping::mapOperation(Operation *op) {
1373 if (allWidthsKnown(op))
1374 return success();
1375
1376 // Actually generate the necessary constraint expressions.
1377 bool mappingFailed = false;
1378 solver.setCurrentContextInfo(
1379 op->getNumResults() > 0 ? FieldRef(op->getResults()[0], 0) : FieldRef());
1380 solver.setCurrentLocation(op->getLoc());
1381 TypeSwitch<Operation *>(op)
1382 .Case<ConstantOp>([&](auto op) {
1383 // If the constant has a known width, use that. Otherwise pick the
1384 // smallest number of bits necessary to represent the constant.
1385 auto v = op.getValue();
1386 auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1387 : v.countLeadingZeros());
1388 if (v.isSigned())
1389 w += 1;
1390 setExpr(op.getResult(), solver.known(std::max(w, 1u)));
1391 })
1392 .Case<SpecialConstantOp>([&](auto op) {
1393 // Nothing required.
1394 })
1395 .Case<InvalidValueOp>([&](auto op) {
1396 // We must duplicate the invalid value for each use, since each use can
1397 // be inferred to a different width.
1398 declareVars(op.getResult(), /*isDerived=*/true);
1399 })
1400 .Case<WireOp, RegOp>([&](auto op) { declareVars(op.getResult()); })
1401 .Case<RegResetOp>([&](auto op) {
1402 // The original Scala code also constrains the reset signal to be at
1403 // least 1 bit wide. We don't do this here since the MLIR FIRRTL
1404 // dialect enforces the reset signal to be an async reset or a
1405 // `uint<1>`.
1406 declareVars(op.getResult());
1407 // Contrain the register to be greater than or equal to the reset
1408 // signal.
1409 constrainTypes(op.getResult(), op.getResetValue());
1410 })
1411 .Case<NodeOp>([&](auto op) {
1412 // Nodes have the same type as their input.
1413 unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 0),
1414 op.getResult().getType());
1415 })
1416
1417 // Aggregate Values
1418 .Case<SubfieldOp>([&](auto op) {
1419 BundleType bundleType = op.getInput().getType();
1420 auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1421 unifyTypes(FieldRef(op.getResult(), 0),
1422 FieldRef(op.getInput(), fieldID), op.getType());
1423 })
1424 .Case<SubindexOp, SubaccessOp>([&](auto op) {
1425 // All vec fields unify to the same thing. Always use the first element
1426 // of the vector, which has a field ID of 1.
1427 unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
1428 op.getType());
1429 })
1430 .Case<SubtagOp>([&](auto op) {
1431 FEnumType enumType = op.getInput().getType();
1432 auto fieldID = enumType.getFieldID(op.getFieldIndex());
1433 unifyTypes(FieldRef(op.getResult(), 0),
1434 FieldRef(op.getInput(), fieldID), op.getType());
1435 })
1436
1437 .Case<RefSubOp>([&](RefSubOp op) {
1438 uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1439 op.getInput().getType().getType())
1440 .Case<FVectorType>([](auto _) { return 1; })
1441 .Case<BundleType>([&](auto type) {
1442 return type.getFieldID(op.getIndex());
1443 });
1444 unifyTypes(FieldRef(op.getResult(), 0),
1445 FieldRef(op.getInput(), fieldID), op.getType());
1446 })
1447
1448 // Arithmetic and Logical Binary Primitives
1449 .Case<AddPrimOp, SubPrimOp>([&](auto op) {
1450 auto lhs = getExpr(op.getLhs());
1451 auto rhs = getExpr(op.getRhs());
1452 auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1453 setExpr(op.getResult(), e);
1454 })
1455 .Case<MulPrimOp>([&](auto op) {
1456 auto lhs = getExpr(op.getLhs());
1457 auto rhs = getExpr(op.getRhs());
1458 auto e = solver.add(lhs, rhs);
1459 setExpr(op.getResult(), e);
1460 })
1461 .Case<DivPrimOp>([&](auto op) {
1462 auto lhs = getExpr(op.getLhs());
1463 Expr *e;
1464 if (op.getType().base().isSigned()) {
1465 e = solver.add(lhs, solver.known(1));
1466 } else {
1467 e = lhs;
1468 }
1469 setExpr(op.getResult(), e);
1470 })
1471 .Case<RemPrimOp>([&](auto op) {
1472 auto lhs = getExpr(op.getLhs());
1473 auto rhs = getExpr(op.getRhs());
1474 auto e = solver.min(lhs, rhs);
1475 setExpr(op.getResult(), e);
1476 })
1477 .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](auto op) {
1478 auto lhs = getExpr(op.getLhs());
1479 auto rhs = getExpr(op.getRhs());
1480 auto e = solver.max(lhs, rhs);
1481 setExpr(op.getResult(), e);
1482 })
1483
1484 // Misc Binary Primitives
1485 .Case<CatPrimOp>([&](auto op) {
1486 auto lhs = getExpr(op.getLhs());
1487 auto rhs = getExpr(op.getRhs());
1488 auto e = solver.add(lhs, rhs);
1489 setExpr(op.getResult(), e);
1490 })
1491 .Case<DShlPrimOp>([&](auto op) {
1492 auto lhs = getExpr(op.getLhs());
1493 auto rhs = getExpr(op.getRhs());
1494 auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1495 setExpr(op.getResult(), e);
1496 })
1497 .Case<DShlwPrimOp, DShrPrimOp>([&](auto op) {
1498 auto e = getExpr(op.getLhs());
1499 setExpr(op.getResult(), e);
1500 })
1501
1502 // Unary operators
1503 .Case<NegPrimOp>([&](auto op) {
1504 auto input = getExpr(op.getInput());
1505 auto e = solver.add(input, solver.known(1));
1506 setExpr(op.getResult(), e);
1507 })
1508 .Case<CvtPrimOp>([&](auto op) {
1509 auto input = getExpr(op.getInput());
1510 auto e = op.getInput().getType().base().isSigned()
1511 ? input
1512 : solver.add(input, solver.known(1));
1513 setExpr(op.getResult(), e);
1514 })
1515
1516 // Miscellaneous
1517 .Case<BitsPrimOp>([&](auto op) {
1518 setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1519 })
1520 .Case<HeadPrimOp>([&](auto op) {
1521 setExpr(op.getResult(), solver.known(op.getAmount()));
1522 })
1523 .Case<TailPrimOp>([&](auto op) {
1524 auto input = getExpr(op.getInput());
1525 auto e = solver.add(input, solver.known(-op.getAmount()));
1526 setExpr(op.getResult(), e);
1527 })
1528 .Case<PadPrimOp>([&](auto op) {
1529 auto input = getExpr(op.getInput());
1530 auto e = solver.max(input, solver.known(op.getAmount()));
1531 setExpr(op.getResult(), e);
1532 })
1533 .Case<ShlPrimOp>([&](auto op) {
1534 auto input = getExpr(op.getInput());
1535 auto e = solver.add(input, solver.known(op.getAmount()));
1536 setExpr(op.getResult(), e);
1537 })
1538 .Case<ShrPrimOp>([&](auto op) {
1539 auto input = getExpr(op.getInput());
1540 // UInt saturates at 0 bits, SInt at 1 bit
1541 auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1542 auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1543 solver.known(minWidth));
1544 setExpr(op.getResult(), e);
1545 })
1546
1547 // Handle operations whose output width matches the input width.
1548 .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1549 [&](auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1550 .Case<mlir::UnrealizedConversionCastOp>(
1551 [&](auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1552
1553 // Handle operations with a single result type that always has a
1554 // well-known width.
1555 .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1556 AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1557 XorRPrimOp>([&](auto op) {
1558 auto width = op.getType().getBitWidthOrSentinel();
1559 assert(width > 0 && "width should have been checked by verifier");
1560 setExpr(op.getResult(), solver.known(width));
1561 })
1562 .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
1563 auto *sel = getExpr(op.getSel());
1564 constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
1565 maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1566 })
1567 .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1568 auto *sel = getExpr(op.getSel());
1569 constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
1570 maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1571 maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1572 maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1573 })
1574
1575 .Case<ConnectOp, MatchingConnectOp>(
1576 [&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1577 .Case<RefDefineOp>([&](auto op) {
1578 // Dest >= Src, but also check Src <= Dest for correctness
1579 // (but don't solve to make this true, don't back-propagate)
1580 constrainTypes(op.getDest(), op.getSrc(), true);
1581 })
1582 .Case<AttachOp>([&](auto op) {
1583 // Attach connects multiple analog signals together. All signals must
1584 // have the same bit width. Signals without bit width inherit from the
1585 // other signals.
1586 if (op.getAttached().empty())
1587 return;
1588 auto prev = op.getAttached()[0];
1589 for (auto operand : op.getAttached().drop_front()) {
1590 auto e1 = getExpr(prev);
1591 auto e2 = getExpr(operand);
1592 constrainTypes(e1, e2, /*imposeUpperBounds=*/true);
1593 constrainTypes(e2, e1, /*imposeUpperBounds=*/true);
1594 prev = operand;
1595 }
1596 })
1597
1598 // Handle the no-ops that don't interact with width inference.
1599 .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1600 UnclockedAssumeIntrinsicOp, CoverOp>([&](auto) {})
1601
1602 // Handle instances of other modules.
1603 .Case<InstanceOp>([&](auto op) {
1604 auto refdModule = op.getReferencedOperation(symtbl);
1605 auto module = dyn_cast<FModuleOp>(&*refdModule);
1606 if (!module) {
1607 auto diag = mlir::emitError(op.getLoc());
1608 diag << "extern module `" << op.getModuleName()
1609 << "` has ports of uninferred width";
1610
1611 auto fml = cast<FModuleLike>(&*refdModule);
1612 auto ports = fml.getPorts();
1613 for (auto &port : ports) {
1614 auto baseType = getBaseType(port.type);
1615 if (baseType && baseType.hasUninferredWidth()) {
1616 diag.attachNote(op.getLoc()) << "Port: " << port.name;
1617 if (!baseType.isGround())
1618 diagnoseUninferredType(diag, baseType, port.name.getValue());
1619 }
1620 }
1621
1622 diag.attachNote(op.getLoc())
1623 << "Only non-extern FIRRTL modules may contain unspecified "
1624 "widths to be inferred automatically.";
1625 diag.attachNote(refdModule->getLoc())
1626 << "Module `" << op.getModuleName() << "` defined here:";
1627 mappingFailed = true;
1628 return;
1629 }
1630 // Simply look up the free variables created for the instantiated
1631 // module's ports, and use them for instance port wires. This way,
1632 // constraints imposed onto the ports of the instance will transparently
1633 // apply to the ports of the instantiated module.
1634 for (auto [result, arg] :
1635 llvm::zip(op->getResults(), module.getArguments()))
1636 unifyTypes({result, 0}, {arg, 0},
1637 type_cast<FIRRTLType>(result.getType()));
1638 })
1639
1640 // Handle memories.
1641 .Case<MemOp>([&](MemOp op) {
1642 // Create constraint variables for all ports.
1643 unsigned nonDebugPort = 0;
1644 for (const auto &result : llvm::enumerate(op.getResults())) {
1645 declareVars(result.value());
1646 if (!type_isa<RefType>(result.value().getType()))
1647 nonDebugPort = result.index();
1648 }
1649
1650 // A helper function that returns the indeces of the "data", "rdata",
1651 // and "wdata" fields in the bundle corresponding to a memory port.
1652 auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1653 static const unsigned indices[] = {3, 5};
1654 static const unsigned debug[] = {0};
1655 switch (kind) {
1656 case MemOp::PortKind::Read:
1657 case MemOp::PortKind::Write:
1658 return ArrayRef<unsigned>(indices, 1); // {3}
1659 case MemOp::PortKind::ReadWrite:
1660 return ArrayRef<unsigned>(indices); // {3, 5}
1661 case MemOp::PortKind::Debug:
1662 return ArrayRef<unsigned>(debug);
1663 }
1664 llvm_unreachable("Imposible PortKind");
1665 };
1666
1667 // This creates independent variables for every data port. Yet, what we
1668 // actually want is for all data ports to share the same variable. To do
1669 // this, we find the first data port declared, and use that port's vars
1670 // for all the other ports.
1671 unsigned firstFieldIndex =
1672 dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1673 FieldRef firstData(
1674 op.getResult(nonDebugPort),
1675 type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1676 .getFieldID(firstFieldIndex));
1677 LLVM_DEBUG(llvm::dbgs() << "Adjusting memory port variables:\n");
1678
1679 // Reuse data port variables.
1680 auto dataType = op.getDataType();
1681 for (unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1682 auto result = op.getResult(i);
1683 if (type_isa<RefType>(result.getType())) {
1684 // Debug ports are firrtl.ref<vector<data-type, depth>>
1685 // Use FieldRef of 1, to indicate the first vector element must be
1686 // of the dataType.
1687 unifyTypes(firstData, FieldRef(result, 1), dataType);
1688 continue;
1689 }
1690
1691 auto portType =
1692 type_cast<BundleType>(op.getPortType(i).getPassiveType());
1693 for (auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1694 unifyTypes(FieldRef(result, portType.getFieldID(fieldIndex)),
1695 firstData, dataType);
1696 }
1697 })
1698
1699 .Case<RefSendOp>([&](auto op) {
1700 declareVars(op.getResult());
1701 constrainTypes(op.getResult(), op.getBase(), true);
1702 })
1703 .Case<RefResolveOp>([&](auto op) {
1704 declareVars(op.getResult());
1705 constrainTypes(op.getResult(), op.getRef(), true);
1706 })
1707 .Case<RefCastOp>([&](auto op) {
1708 declareVars(op.getResult());
1709 constrainTypes(op.getResult(), op.getInput(), true);
1710 })
1711 .Case<RWProbeOp>([&](auto op) {
1712 auto ist = irn.lookup(op.getTarget());
1713 if (!ist) {
1714 op->emitError("target of rwprobe could not be resolved");
1715 mappingFailed = true;
1716 return;
1717 }
1718 auto ref = getFieldRefForTarget(ist);
1719 if (!ref) {
1720 op->emitError("target of rwprobe resolved to unsupported target");
1721 mappingFailed = true;
1722 return;
1723 }
1724 auto newFID = convertFieldIDToOurVersion(
1725 ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1726 unifyTypes(FieldRef(op.getResult(), 0),
1727 FieldRef(ref.getValue(), newFID), op.getType());
1728 })
1729 .Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
1730 for (Value result : op.getResults()) {
1731 auto ty = result.getType();
1732 if (type_isa<FIRRTLType>(ty))
1733 declareVars(result);
1734 }
1735 })
1736 .Default([&](auto op) {
1737 op->emitOpError("not supported in width inference");
1738 mappingFailed = true;
1739 });
1740
1741 // Forceable declarations should have the ref constrained to data result.
1742 if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1743 unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
1744 fop.getDataType());
1745
1746 return failure(mappingFailed);
1747}
1748
1749/// Declare free variables for the type of a value, and associate the resulting
1750/// set of variables with that value.
1751void InferenceMapping::declareVars(Value value, bool isDerived) {
1752 // Declare a variable for every unknown width in the type. If this is a Bundle
1753 // type or a FVector type, we will have to potentially create many variables.
1754 unsigned fieldID = 0;
1755 std::function<void(FIRRTLBaseType)> declare = [&](FIRRTLBaseType type) {
1756 auto width = type.getBitWidthOrSentinel();
1757 if (width >= 0) {
1758 fieldID++;
1759 } else if (width == -1) {
1760 // Unknown width integers create a variable.
1761 FieldRef field(value, fieldID);
1762 solver.setCurrentContextInfo(field);
1763 if (isDerived)
1764 setExpr(field, solver.derived());
1765 else
1766 setExpr(field, solver.var());
1767 fieldID++;
1768 } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1769 // Bundle types recursively declare all bundle elements.
1770 fieldID++;
1771 for (auto &element : bundleType)
1772 declare(element.type);
1773 } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1774 fieldID++;
1775 auto save = fieldID;
1776 declare(vecType.getElementType());
1777 // Skip past the rest of the elements
1778 fieldID = save + vecType.getMaxFieldID();
1779 } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1780 fieldID++;
1781 for (auto &element : enumType.getElements())
1782 declare(element.type);
1783 } else {
1784 llvm_unreachable("Unknown type inside a bundle!");
1785 }
1786 };
1787 if (auto type = getBaseType(value.getType()))
1788 declare(type);
1789}
1790
1791/// Assign the constraint expressions of the fields in the `result` argument as
1792/// the max of expressions in the `rhs` and `lhs` arguments. Both fields must be
1793/// the same type.
1794void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1795 // Recurse to every leaf element and set larger >= smaller.
1796 auto fieldID = 0;
1797 std::function<void(FIRRTLBaseType)> maximize = [&](FIRRTLBaseType type) {
1798 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1799 fieldID++;
1800 for (auto &element : bundleType.getElements())
1801 maximize(element.type);
1802 } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1803 fieldID++;
1804 auto save = fieldID;
1805 // Skip 0 length vectors.
1806 if (vecType.getNumElements() > 0)
1807 maximize(vecType.getElementType());
1808 fieldID = save + vecType.getMaxFieldID();
1809 } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1810 fieldID++;
1811 for (auto &element : enumType.getElements())
1812 maximize(element.type);
1813 } else if (type.isGround()) {
1814 auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
1815 getExpr(FieldRef(lhs, fieldID)));
1816 setExpr(FieldRef(result, fieldID), e);
1817 fieldID++;
1818 } else {
1819 llvm_unreachable("Unknown type inside a bundle!");
1820 }
1821 };
1822 if (auto type = getBaseType(result.getType()))
1823 maximize(type);
1824}
1825
1826/// Establishes constraints to ensure the sizes in the `larger` type are greater
1827/// than or equal to the sizes in the `smaller` type. Types have to be
1828/// compatible in the sense that they may only differ in the presence or absence
1829/// of bit widths.
1830///
1831/// This function is used to apply regular connects.
1832/// Set `equal` for constraining larger <= smaller for correctness but not
1833/// solving.
1834void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
1835 // Recurse to every leaf element and set larger >= smaller. Ignore foreign
1836 // types as these do not participate in width inference.
1837
1838 auto fieldID = 0;
1839 std::function<void(FIRRTLBaseType, Value, Value)> constrain =
1840 [&](FIRRTLBaseType type, Value larger, Value smaller) {
1841 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1842 fieldID++;
1843 for (auto &element : bundleType.getElements()) {
1844 if (element.isFlip)
1845 constrain(element.type, smaller, larger);
1846 else
1847 constrain(element.type, larger, smaller);
1848 }
1849 } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1850 fieldID++;
1851 auto save = fieldID;
1852 // Skip 0 length vectors.
1853 if (vecType.getNumElements() > 0) {
1854 constrain(vecType.getElementType(), larger, smaller);
1855 }
1856 fieldID = save + vecType.getMaxFieldID();
1857 } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1858 fieldID++;
1859 for (auto &element : enumType.getElements())
1860 constrain(element.type, larger, smaller);
1861 } else if (type.isGround()) {
1862 // Leaf element, look up their expressions, and create the constraint.
1863 constrainTypes(getExpr(FieldRef(larger, fieldID)),
1864 getExpr(FieldRef(smaller, fieldID)), false, equal);
1865 fieldID++;
1866 } else {
1867 llvm_unreachable("Unknown type inside a bundle!");
1868 }
1869 };
1870
1871 if (auto type = getBaseType(larger.getType()))
1872 constrain(type, larger, smaller);
1873}
1874
1875/// Establishes constraints to ensure the sizes in the `larger` type are greater
1876/// than or equal to the sizes in the `smaller` type.
1877void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1878 bool imposeUpperBounds, bool equal) {
1879 assert(larger && "Larger expression should be specified");
1880 assert(smaller && "Smaller expression should be specified");
1881
1882 // If one of the sides is `DerivedExpr`, simply assign the other side as the
1883 // derived width. This allows `InvalidValueOp`s to properly infer their width
1884 // from the connects they are used in, but also be inferred to something
1885 // useful on their own.
1886 if (auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1887 largerDerived->assigned = smaller;
1888 LLVM_DEBUG(llvm::dbgs() << "Deriving " << *largerDerived << " from "
1889 << *smaller << "\n");
1890 return;
1891 }
1892 if (auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1893 smallerDerived->assigned = larger;
1894 LLVM_DEBUG(llvm::dbgs() << "Deriving " << *smallerDerived << " from "
1895 << *larger << "\n");
1896 return;
1897 }
1898
1899 // If the larger expr is a free variable, create a `expr >= x` constraint for
1900 // it that we can try to satisfy with the smallest width.
1901 if (auto *largerVar = dyn_cast<VarExpr>(larger)) {
1902 [[maybe_unused]] auto *c = solver.addGeqConstraint(largerVar, smaller);
1903 LLVM_DEBUG(llvm::dbgs()
1904 << "Constrained " << *largerVar << " >= " << *c << "\n");
1905 // If we're constraining larger == smaller, add the LEQ contraint as well.
1906 // Solve for GEQ but check that LEQ is true.
1907 // Used for matchingconnect, some reference operations, and anywhere the
1908 // widths should be inferred strictly in one direction but are required to
1909 // also be equal for correctness.
1910 if (equal) {
1911 [[maybe_unused]] auto *leq = solver.addLeqConstraint(largerVar, smaller);
1912 LLVM_DEBUG(llvm::dbgs()
1913 << "Constrained " << *largerVar << " <= " << *leq << "\n");
1914 }
1915 return;
1916 }
1917
1918 // If the smaller expr is a free variable but the larger one is not, create a
1919 // `expr <= k` upper bound that we can verify once all lower bounds have been
1920 // satisfied. Since we are always picking the smallest width to satisfy all
1921 // `>=` constraints, any `<=` constraints have no effect on the solution
1922 // besides indicating that a width is unsatisfiable.
1923 if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1924 if (imposeUpperBounds || equal) {
1925 [[maybe_unused]] auto *c = solver.addLeqConstraint(smallerVar, larger);
1926 LLVM_DEBUG(llvm::dbgs()
1927 << "Constrained " << *smallerVar << " <= " << *c << "\n");
1928 }
1929 }
1930}
1931
1932/// Assign the constraint expressions of the fields in the `src` argument as the
1933/// expressions for the `dst` argument. Both fields must be of the given `type`.
1934void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
1935 // Fast path for `unifyTypes(x, x, _)`.
1936 if (lhs == rhs)
1937 return;
1938
1939 // Co-iterate the two field refs, recurring into every leaf element and set
1940 // them equal.
1941 auto fieldID = 0;
1942 std::function<void(FIRRTLBaseType)> unify = [&](FIRRTLBaseType type) {
1943 if (type.isGround()) {
1944 // Leaf element, unify the fields!
1945 FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
1946 FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
1947 LLVM_DEBUG(llvm::dbgs()
1948 << "Unify " << getFieldName(lhsFieldRef).first << " = "
1949 << getFieldName(rhsFieldRef).first << "\n");
1950 // Abandon variables becoming unconstrainable by the unification.
1951 if (auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1952 solver.addGeqConstraint(var, solver.known(0));
1953 setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1954 fieldID++;
1955 } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1956 fieldID++;
1957 for (auto &element : bundleType) {
1958 unify(element.type);
1959 }
1960 } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1961 fieldID++;
1962 auto save = fieldID;
1963 // Skip 0 length vectors.
1964 if (vecType.getNumElements() > 0) {
1965 unify(vecType.getElementType());
1966 }
1967 fieldID = save + vecType.getMaxFieldID();
1968 } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1969 fieldID++;
1970 for (auto &element : enumType.getElements())
1971 unify(element.type);
1972 } else {
1973 llvm_unreachable("Unknown type inside a bundle!");
1974 }
1975 };
1976 if (auto ftype = getBaseType(type))
1977 unify(ftype);
1978}
1979
1980/// Get the constraint expression for a value.
1981Expr *InferenceMapping::getExpr(Value value) const {
1982 assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
1983 // A field ID of 0 indicates the entire value.
1984 return getExpr(FieldRef(value, 0));
1985}
1986
1987/// Get the constraint expression for a value.
1988Expr *InferenceMapping::getExpr(FieldRef fieldRef) const {
1989 auto *expr = getExprOrNull(fieldRef);
1990 assert(expr && "constraint expr should have been constructed for value");
1991 return expr;
1992}
1993
1994Expr *InferenceMapping::getExprOrNull(FieldRef fieldRef) const {
1995 auto it = opExprs.find(fieldRef);
1996 if (it != opExprs.end())
1997 return it->second;
1998 // If we don't have an expression for this fieldRef, it should have a
1999 // constant width.
2000 auto baseType = getBaseType(fieldRef.getValue().getType());
2001 auto type =
2003 auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
2004 if (width < 0)
2005 return nullptr;
2006 return solver.known(width);
2007}
2008
2009/// Associate a constraint expression with a value.
2010void InferenceMapping::setExpr(Value value, Expr *expr) {
2011 assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2012 // A field ID of 0 indicates the entire value.
2013 setExpr(FieldRef(value, 0), expr);
2014}
2015
2016/// Associate a constraint expression with a value.
2017void InferenceMapping::setExpr(FieldRef fieldRef, Expr *expr) {
2018 LLVM_DEBUG({
2019 llvm::dbgs() << "Expr " << *expr << " for " << fieldRef.getValue();
2020 if (fieldRef.getFieldID())
2021 llvm::dbgs() << " '" << getFieldName(fieldRef).first << "'";
2022 auto fieldName = getFieldName(fieldRef);
2023 if (fieldName.second)
2024 llvm::dbgs() << " (\"" << fieldName.first << "\")";
2025 llvm::dbgs() << "\n";
2026 });
2027 opExprs[fieldRef] = expr;
2028}
2029
2030//===----------------------------------------------------------------------===//
2031// Inference Result Application
2032//===----------------------------------------------------------------------===//
2033
2034namespace {
2035/// A helper class which maps the types and operations in a design to a set
2036/// of variables and constraints to be solved later.
2037class InferenceTypeUpdate {
2038public:
2039 InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2040
2041 LogicalResult update(CircuitOp op);
2042 FailureOr<bool> updateOperation(Operation *op);
2043 FailureOr<bool> updateValue(Value value);
2045
2046private:
2047 const InferenceMapping &mapping;
2048};
2049
2050} // namespace
2051
2052/// Update the types throughout a circuit.
2053LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2054 LLVM_DEBUG({
2055 llvm::dbgs() << "\n";
2056 debugHeader("Update types") << "\n\n";
2057 });
2058 return mlir::failableParallelForEach(
2059 op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2060 // Skip this module if it had no widths to be
2061 // inferred at all.
2062 if (mapping.isModuleSkipped(op))
2063 return success();
2064 auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2065 if (failed(updateOperation(op)))
2066 return WalkResult::interrupt();
2067 return WalkResult::advance();
2068 }).wasInterrupted();
2069 return failure(isFailed);
2070 });
2071}
2072
2073/// Update the result types of an operation.
2074FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2075 bool anyChanged = false;
2076
2077 for (Value v : op->getResults()) {
2078 auto result = updateValue(v);
2079 if (failed(result))
2080 return result;
2081 anyChanged |= *result;
2082 }
2083
2084 // If this is a connect operation, width inference might have inferred a RHS
2085 // that is wider than the LHS, in which case an additional BitsPrimOp is
2086 // necessary to truncate the value.
2087 if (auto con = dyn_cast<ConnectOp>(op)) {
2088 auto lhs = con.getDest();
2089 auto rhs = con.getSrc();
2090 auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2091 auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2092
2093 // Nothing to do if not base types.
2094 if (!lhsType || !rhsType)
2095 return anyChanged;
2096
2097 auto lhsWidth = lhsType.getBitWidthOrSentinel();
2098 auto rhsWidth = rhsType.getBitWidthOrSentinel();
2099 if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2100 OpBuilder builder(op);
2101 auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2102 rhsWidth - lhsWidth);
2103 if (type_isa<SIntType>(rhsType))
2104 trunc =
2105 builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2106
2107 LLVM_DEBUG(llvm::dbgs()
2108 << "Truncating RHS to " << lhsType << " in " << con << "\n");
2109 con->replaceUsesOfWith(con.getSrc(), trunc);
2110 }
2111 return anyChanged;
2112 }
2113
2114 // If this is a module, update its ports.
2115 if (auto module = dyn_cast<FModuleOp>(op)) {
2116 // Update the block argument types.
2117 bool argsChanged = false;
2118 SmallVector<Attribute> argTypes;
2119 argTypes.reserve(module.getNumPorts());
2120 for (auto arg : module.getArguments()) {
2121 auto result = updateValue(arg);
2122 if (failed(result))
2123 return result;
2124 argsChanged |= *result;
2125 argTypes.push_back(TypeAttr::get(arg.getType()));
2126 }
2127
2128 // Update the module function type if needed.
2129 if (argsChanged) {
2130 module.setPortTypesAttr(ArrayAttr::get(module.getContext(), argTypes));
2131 anyChanged = true;
2132 }
2133 }
2134 return anyChanged;
2135}
2136
2137/// Resize a `uint`, `sint`, or `analog` type to a specific width.
2138static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth) {
2139 auto *context = type.getContext();
2141 .Case<UIntType>([&](auto type) {
2142 return UIntType::get(context, newWidth, type.isConst());
2143 })
2144 .Case<SIntType>([&](auto type) {
2145 return SIntType::get(context, newWidth, type.isConst());
2146 })
2147 .Case<AnalogType>([&](auto type) {
2148 return AnalogType::get(context, newWidth, type.isConst());
2149 })
2150 .Default([&](auto type) { return type; });
2151}
2152
2153/// Update the type of a value.
2154FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2155 // Check if the value has a type which we can update.
2156 auto type = type_dyn_cast<FIRRTLType>(value.getType());
2157 if (!type)
2158 return false;
2159
2160 // Fast path for types that have fully inferred widths.
2161 if (!hasUninferredWidth(type))
2162 return false;
2163
2164 // If this is an operation that does not generate any free variables that
2165 // are determined during width inference, simply update the value type based
2166 // on the operation arguments.
2167 if (auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2168 SmallVector<Type, 2> types;
2169 auto res =
2170 op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2171 op->getAttrDictionary(), op->getPropertiesStorage(),
2172 op->getRegions(), types);
2173 if (failed(res))
2174 return failure();
2175
2176 assert(types.size() == op->getNumResults());
2177 for (auto [result, type] : llvm::zip(op->getResults(), types)) {
2178 LLVM_DEBUG(llvm::dbgs()
2179 << "Inferring " << result << " as " << type << "\n");
2180 result.setType(type);
2181 }
2182 return true;
2183 }
2184
2185 // Recreate the type, substituting the solved widths.
2186 auto *context = type.getContext();
2187 unsigned fieldID = 0;
2188 std::function<FIRRTLBaseType(FIRRTLBaseType)> updateBase =
2189 [&](FIRRTLBaseType type) -> FIRRTLBaseType {
2190 auto width = type.getBitWidthOrSentinel();
2191 if (width >= 0) {
2192 // Known width integers return themselves.
2193 fieldID++;
2194 return type;
2195 }
2196 if (width == -1) {
2197 // Unknown width integers return the solved type.
2198 auto newType = updateType(FieldRef(value, fieldID), type);
2199 fieldID++;
2200 return newType;
2201 }
2202 if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2203 // Bundle types recursively update all bundle elements.
2204 fieldID++;
2205 llvm::SmallVector<BundleType::BundleElement, 3> elements;
2206 for (auto &element : bundleType) {
2207 auto updatedBase = updateBase(element.type);
2208 if (!updatedBase)
2209 return {};
2210 elements.emplace_back(element.name, element.isFlip, updatedBase);
2211 }
2212 return BundleType::get(context, elements, bundleType.isConst());
2213 }
2214 if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2215 fieldID++;
2216 auto save = fieldID;
2217 // TODO: this should recurse into the element type of 0 length vectors and
2218 // set any unknown width to 0.
2219 if (vecType.getNumElements() > 0) {
2220 auto updatedBase = updateBase(vecType.getElementType());
2221 if (!updatedBase)
2222 return {};
2223 auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2224 vecType.isConst());
2225 fieldID = save + vecType.getMaxFieldID();
2226 return newType;
2227 }
2228 // If this is a 0 length vector return the original type.
2229 return type;
2230 }
2231 if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2232 fieldID++;
2233 llvm::SmallVector<FEnumType::EnumElement> elements;
2234 for (auto &element : enumType.getElements()) {
2235 auto updatedBase = updateBase(element.type);
2236 if (!updatedBase)
2237 return {};
2238 elements.emplace_back(element.name, updatedBase);
2239 }
2240 return FEnumType::get(context, elements, enumType.isConst());
2241 }
2242 llvm_unreachable("Unknown type inside a bundle!");
2243 };
2244
2245 // Update the type.
2246 auto newType = mapBaseTypeNullable(type, updateBase);
2247 if (!newType)
2248 return failure();
2249 LLVM_DEBUG(llvm::dbgs() << "Update " << value << " to " << newType << "\n");
2250 value.setType(newType);
2251
2252 // If this is a ConstantOp, adjust the width of the underlying APInt.
2253 // Unsized constants have APInts which are *at least* wide enough to hold
2254 // the value, but may be larger. This can trip up the verifier.
2255 if (auto op = value.getDefiningOp<ConstantOp>()) {
2256 auto k = op.getValue();
2257 auto bitwidth = op.getType().getBitWidthOrSentinel();
2258 if (k.getBitWidth() > unsigned(bitwidth))
2259 k = k.trunc(bitwidth);
2260 op->setAttr("value", IntegerAttr::get(op.getContext(), k));
2261 }
2262
2263 return newType != type;
2264}
2265
2266/// Update a type.
2267FIRRTLBaseType InferenceTypeUpdate::updateType(FieldRef fieldRef,
2268 FIRRTLBaseType type) {
2269 assert(type.isGround() && "Can only pass in ground types.");
2270 auto value = fieldRef.getValue();
2271 // Get the inferred width.
2272 Expr *expr = mapping.getExprOrNull(fieldRef);
2273 if (!expr || !expr->getSolution()) {
2274 // It should not be possible to arrive at an uninferred width at this point.
2275 // In case the constraints are not resolvable, checks before the calls to
2276 // `updateType` must have already caught the issues and aborted the pass
2277 // early. Might turn this into an assert later.
2278 mlir::emitError(value.getLoc(), "width should have been inferred");
2279 return {};
2280 }
2281 int32_t solution = *expr->getSolution();
2282 assert(solution >= 0); // The solver infers variables to be 0 or greater.
2283 return resizeType(type, solution);
2284}
2285
2286//===----------------------------------------------------------------------===//
2287// Pass Infrastructure
2288//===----------------------------------------------------------------------===//
2289
2290namespace {
2291class InferWidthsPass
2292 : public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2293 void runOnOperation() override;
2294};
2295} // namespace
2296
2297void InferWidthsPass::runOnOperation() {
2298 // Collect variables and constraints
2299 ConstraintSolver solver;
2300 InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2301 getAnalysis<hw::InnerSymbolTableCollection>());
2302 if (failed(mapping.map(getOperation())))
2303 return signalPassFailure();
2304
2305 // fast path if no inferrable widths are around
2306 if (mapping.areAllModulesSkipped())
2307 return markAllAnalysesPreserved();
2308
2309 // Solve the constraints.
2310 if (failed(solver.solve()))
2311 return signalPassFailure();
2312
2313 // Update the types with the inferred widths.
2314 if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2315 return signalPassFailure();
2316}
2317
2318std::unique_ptr<mlir::Pass> circt::firrtl::createInferWidthsPass() {
2319 return std::make_unique<InferWidthsPass>();
2320}
assert(baseType &&"element must be base type")
static unsigned getFieldID(BundleType type, unsigned index)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
#define EXPR_KINDS
#define EXPR_CLASSES
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, std::vector< Frame > &worklist)
Compute the value of a constraint expr.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
Definition FieldRef.h:28
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition FieldRef.h:37
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
This class represents a collection of InnerSymbolTable's.
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
static bool operator==(const ModulePort &a, const ModulePort &b)
Definition HWTypes.h:35
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition Debug.cpp:18
Definition debug.py:1
llvm::hash_code hash_value(const T &e)
This class represents the namespace in which InnerRef's can be resolved.