CIRCT  20.0.0git
InferWidths.cpp
Go to the documentation of this file.
1 //===- InferWidths.cpp - Infer width of types -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the InferWidths pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "circt/Support/Debug.h"
18 #include "circt/Support/FieldRef.h"
19 #include "mlir/IR/Threading.h"
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/APSInt.h"
22 #include "llvm/ADT/DenseSet.h"
23 #include "llvm/ADT/Hashing.h"
24 #include "llvm/ADT/SetVector.h"
25 #include "llvm/Support/Debug.h"
26 #include "llvm/Support/ErrorHandling.h"
27 
28 #define DEBUG_TYPE "infer-widths"
29 
30 namespace circt {
31 namespace firrtl {
32 #define GEN_PASS_DEF_INFERWIDTHS
33 #include "circt/Dialect/FIRRTL/Passes.h.inc"
34 } // namespace firrtl
35 } // namespace circt
36 
37 using mlir::InferTypeOpInterface;
38 using mlir::WalkOrder;
39 
40 using namespace circt;
41 using namespace firrtl;
42 
43 //===----------------------------------------------------------------------===//
44 // Helpers
45 //===----------------------------------------------------------------------===//
46 
47 static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t,
48  Twine str) {
49  auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
50  if (!basetype)
51  return;
52  if (!basetype.hasUninferredWidth())
53  return;
54 
55  if (basetype.isGround())
56  diag.attachNote() << "Field: \"" << str << "\"";
57  else if (auto vecType = type_dyn_cast<FVectorType>(basetype))
58  diagnoseUninferredType(diag, vecType.getElementType(), str + "[]");
59  else if (auto bundleType = type_dyn_cast<BundleType>(basetype))
60  for (auto &elem : bundleType.getElements())
61  diagnoseUninferredType(diag, elem.type, str + "." + elem.name.getValue());
62 }
63 
64 /// Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
65 static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type) {
66  uint64_t convertedFieldID = 0;
67 
68  auto curFID = fieldID;
69  Type curFType = type;
70  while (curFID != 0) {
71  auto [child, subID] =
72  hw::FieldIdImpl::getSubTypeByFieldID(curFType, curFID);
73  if (isa<FVectorType>(curFType))
74  convertedFieldID++; // Vector fieldID is 1.
75  else
76  convertedFieldID += curFID - subID; // Add consumed portion.
77  curFID = subID;
78  curFType = child;
79  }
80 
81  return convertedFieldID;
82 }
83 
84 //===----------------------------------------------------------------------===//
85 // Constraint Expressions
86 //===----------------------------------------------------------------------===//
87 
88 namespace {
89 struct Expr;
90 } // namespace
91 
92 /// Allow rvalue refs to `Expr` and subclasses to be printed to streams.
93 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
94  int>::type = 0>
95 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T &e) {
96  e.print(os);
97  return os;
98 }
99 
100 // Allow expression subclasses to be hashed.
101 namespace mlir {
102 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
103  int>::type = 0>
104 inline llvm::hash_code hash_value(const T &e) {
105  return e.hash_value();
106 }
107 } // namespace mlir
108 
109 namespace {
110 #define EXPR_NAMES(x) \
111  Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
112 #define EXPR_KINDS EXPR_NAMES()
113 #define EXPR_CLASSES EXPR_NAMES(Expr)
114 
115 /// An expression on the right-hand side of a constraint.
116 struct Expr {
117  enum class Kind : uint8_t { EXPR_KINDS };
118 
119  /// Print a human-readable representation of this expr.
120  void print(llvm::raw_ostream &os) const;
121 
122  std::optional<int32_t> getSolution() const {
123  if (hasSolution)
124  return solution;
125  return std::nullopt;
126  }
127 
128  void setSolution(int32_t solution) {
129  hasSolution = true;
130  this->solution = solution;
131  }
132 
133  Kind getKind() const { return kind; }
134 
135 protected:
136  Expr(Kind kind) : kind(kind) {}
137  llvm::hash_code hash_value() const { return llvm::hash_value(kind); }
138 
139 private:
140  int32_t solution;
141  Kind kind;
142  bool hasSolution = false;
143 };
144 
145 /// Helper class to CRTP-derive common functions.
146 template <class DerivedT, Expr::Kind DerivedKind>
147 struct ExprBase : public Expr {
148  ExprBase() : Expr(DerivedKind) {}
149  static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
150  bool operator==(const Expr &other) const {
151  if (auto otherSame = dyn_cast<DerivedT>(other))
152  return *static_cast<DerivedT *>(this) == otherSame;
153  return false;
154  }
155 };
156 
157 /// A free variable.
158 struct VarExpr : public ExprBase<VarExpr, Expr::Kind::Var> {
159  void print(llvm::raw_ostream &os) const {
160  // Hash the `this` pointer into something somewhat human readable. Since
161  // this is just for debug dumping, we wrap around at 65536 variables.
162  os << "var" << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFFF);
163  }
164 
165  /// The constraint expression this variable is supposed to be greater than or
166  /// equal to. This is not part of the variable's hash and equality property.
167  Expr *constraint = nullptr;
168 
169  /// The upper bound this variable is supposed to be smaller than or equal to.
170  Expr *upperBound = nullptr;
171  std::optional<int32_t> upperBoundSolution;
172 };
173 
174 /// A derived width.
175 ///
176 /// These are generated for `InvalidValueOp`s which want to derived their width
177 /// from connect operations that they are on the right hand side of.
178 struct DerivedExpr : public ExprBase<DerivedExpr, Expr::Kind::Derived> {
179  void print(llvm::raw_ostream &os) const {
180  // Hash the `this` pointer into something somewhat human readable.
181  os << "derive"
182  << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFF);
183  }
184 
185  /// The expression this derived width is equivalent to.
186  Expr *assigned = nullptr;
187 };
188 
189 /// An identity expression.
190 ///
191 /// This expression evaluates to its inner expression. It is used in a very
192 /// specific case of constraints on variables, in order to be able to track
193 /// where the constraint was imposed. Constraints on variables are represented
194 /// as `var >= <expr>`. When the first constraint `a` is imposed, it is stored
195 /// as the constraint expression (`var >= a`). When the second constraint `b` is
196 /// imposed, a *new* max expression is allocated (`var >= max(a, b)`).
197 /// Expressions are annotated with a location when they are created, which in
198 /// this case are connect ops. Since imposing the first constraint does not
199 /// create any new expression, the location information of that connect would be
200 /// lost. With an identity expression, imposing the first constraint becomes
201 /// `var >= identity(a)`, which is a *new* expression and properly tracks the
202 /// location info.
203 struct IdExpr : public ExprBase<IdExpr, Expr::Kind::Id> {
204  IdExpr(Expr *arg) : arg(arg) { assert(arg); }
205  void print(llvm::raw_ostream &os) const { os << "*" << *arg; }
206  bool operator==(const IdExpr &other) const {
207  return getKind() == other.getKind() && arg == other.arg;
208  }
209  llvm::hash_code hash_value() const {
210  return llvm::hash_combine(Expr::hash_value(), arg);
211  }
212 
213  /// The inner expression.
214  Expr *const arg;
215 };
216 
217 /// A known constant value.
218 struct KnownExpr : public ExprBase<KnownExpr, Expr::Kind::Known> {
219  KnownExpr(int32_t value) : ExprBase() { setSolution(value); }
220  void print(llvm::raw_ostream &os) const { os << *getSolution(); }
221  bool operator==(const KnownExpr &other) const {
222  return *getSolution() == *other.getSolution();
223  }
224  llvm::hash_code hash_value() const {
225  return llvm::hash_combine(Expr::hash_value(), *getSolution());
226  }
227  int32_t getValue() const { return *getSolution(); }
228 };
229 
230 /// A unary expression. Contains the actual data. Concrete subclasses are merely
231 /// there for show and ease of use.
232 struct UnaryExpr : public Expr {
233  bool operator==(const UnaryExpr &other) const {
234  return getKind() == other.getKind() && arg == other.arg;
235  }
236  llvm::hash_code hash_value() const {
237  return llvm::hash_combine(Expr::hash_value(), arg);
238  }
239 
240  /// The child expression.
241  Expr *const arg;
242 
243 protected:
244  UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) { assert(arg); }
245 };
246 
247 /// Helper class to CRTP-derive common functions.
248 template <class DerivedT, Expr::Kind DerivedKind>
249 struct UnaryExprBase : public UnaryExpr {
250  template <typename... Args>
251  UnaryExprBase(Args &&...args)
252  : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
253  static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
254 };
255 
256 /// A power of two.
257 struct PowExpr : public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
258  using UnaryExprBase::UnaryExprBase;
259  void print(llvm::raw_ostream &os) const { os << "2^" << arg; }
260 };
261 
262 /// A binary expression. Contains the actual data. Concrete subclasses are
263 /// merely there for show and ease of use.
264 struct BinaryExpr : public Expr {
265  bool operator==(const BinaryExpr &other) const {
266  return getKind() == other.getKind() && lhs() == other.lhs() &&
267  rhs() == other.rhs();
268  }
269  llvm::hash_code hash_value() const {
270  return llvm::hash_combine(Expr::hash_value(), *args);
271  }
272  Expr *lhs() const { return args[0]; }
273  Expr *rhs() const { return args[1]; }
274 
275  /// The child expressions.
276  Expr *const args[2];
277 
278 protected:
279  BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
280  assert(lhs);
281  assert(rhs);
282  }
283 };
284 
285 /// Helper class to CRTP-derive common functions.
286 template <class DerivedT, Expr::Kind DerivedKind>
287 struct BinaryExprBase : public BinaryExpr {
288  template <typename... Args>
289  BinaryExprBase(Args &&...args)
290  : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
291  static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
292 };
293 
294 /// An addition.
295 struct AddExpr : public BinaryExprBase<AddExpr, Expr::Kind::Add> {
296  using BinaryExprBase::BinaryExprBase;
297  void print(llvm::raw_ostream &os) const {
298  os << "(" << *lhs() << " + " << *rhs() << ")";
299  }
300 };
301 
302 /// The maximum of two expressions.
303 struct MaxExpr : public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
304  using BinaryExprBase::BinaryExprBase;
305  void print(llvm::raw_ostream &os) const {
306  os << "max(" << *lhs() << ", " << *rhs() << ")";
307  }
308 };
309 
310 /// The minimum of two expressions.
311 struct MinExpr : public BinaryExprBase<MinExpr, Expr::Kind::Min> {
312  using BinaryExprBase::BinaryExprBase;
313  void print(llvm::raw_ostream &os) const {
314  os << "min(" << *lhs() << ", " << *rhs() << ")";
315  }
316 };
317 
318 void Expr::print(llvm::raw_ostream &os) const {
319  TypeSwitch<const Expr *>(this).Case<EXPR_CLASSES>(
320  [&](auto *e) { e->print(os); });
321 }
322 
323 } // namespace
324 
325 //===----------------------------------------------------------------------===//
326 // Fast bump allocator with optional interning
327 //===----------------------------------------------------------------------===//
328 
329 namespace {
330 
331 // Hash slots in the interned allocator as if they were the pointed-to value
332 // itself.
333 template <typename T>
334 struct InternedSlotInfo : DenseMapInfo<T *> {
335  static T *getEmptyKey() {
336  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
337  return static_cast<T *>(pointer);
338  }
339  static T *getTombstoneKey() {
341  return static_cast<T *>(pointer);
342  }
343  static unsigned getHashValue(const T *val) { return mlir::hash_value(*val); }
344  static bool isEqual(const T *lhs, const T *rhs) {
345  auto empty = getEmptyKey();
346  auto tombstone = getTombstoneKey();
347  if (lhs == empty || rhs == empty || lhs == tombstone || rhs == tombstone)
348  return lhs == rhs;
349  return *lhs == *rhs;
350  }
351 };
352 
353 /// A simple bump allocator that ensures only ever one copy per object exists.
354 /// The allocated objects must not have a destructor.
355 template <typename T, typename std::enable_if_t<
356  std::is_trivially_destructible<T>::value, int> = 0>
357 class InternedAllocator {
358  llvm::DenseSet<T *, InternedSlotInfo<T>> interned;
359  llvm::BumpPtrAllocator &allocator;
360 
361 public:
362  InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
363 
364  /// Allocate a new object if it does not yet exist, or return a pointer to the
365  /// existing one. `R` is the type of the object to be allocated. `R` must be
366  /// derived from or be the type `T`.
367  template <typename R = T, typename... Args>
368  std::pair<R *, bool> alloc(Args &&...args) {
369  auto stackValue = R(std::forward<Args>(args)...);
370  auto *stackSlot = &stackValue;
371  auto it = interned.find(stackSlot);
372  if (it != interned.end())
373  return std::make_pair(static_cast<R *>(*it), false);
374  auto heapValue = new (allocator) R(std::move(stackValue));
375  interned.insert(heapValue);
376  return std::make_pair(heapValue, true);
377  }
378 };
379 
380 /// A simple bump allocator. The allocated objects must not have a destructor.
381 /// This allocator is mainly there for symmetry with the `InternedAllocator`.
382 template <typename T, typename std::enable_if_t<
383  std::is_trivially_destructible<T>::value, int> = 0>
384 class Allocator {
385  llvm::BumpPtrAllocator &allocator;
386 
387 public:
388  Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
389 
390  /// Allocate a new object. `R` is the type of the object to be allocated. `R`
391  /// must be derived from or be the type `T`.
392  template <typename R = T, typename... Args>
393  R *alloc(Args &&...args) {
394  return new (allocator) R(std::forward<Args>(args)...);
395  }
396 };
397 
398 } // namespace
399 
400 //===----------------------------------------------------------------------===//
401 // Constraint Solver
402 //===----------------------------------------------------------------------===//
403 
404 namespace {
405 /// A canonicalized linear inequality that maps a constraint on var `x` to the
406 /// linear inequality `x >= max(a*x+b, c) + (failed ? ∞ : 0)`.
407 ///
408 /// The inequality separately tracks recursive (a, b) and non-recursive (c)
409 /// constraints on `x`. This allows it to properly identify the combination of
410 /// the two constraints `x >= x-1` and `x >= 4` to be satisfiable as
411 /// `x >= max(x-1, 4)`. If it only tracked inequality as `x >= a*x+b`, the
412 /// combination of these two constraints would be `x >= x+4` (due to max(-1,4) =
413 /// 4), which would be unsatisfiable.
414 ///
415 /// The `failed` flag acts as an additional `∞` term that renders the inequality
416 /// unsatisfiable. It is used as a tombstone value in case an operation renders
417 /// the equality unsatisfiable (e.g. `x >= 2**x` would be represented as the
418 /// inequality `x >= ∞`).
419 ///
420 /// Inequalities represented in this form can easily be checked for
421 /// unsatisfiability in the presence of recursion by inspecting the coefficients
422 /// a and b. The `sat` function performs this action.
423 struct LinIneq {
424  // x >= max(a*x+b, c) + (failed ? ∞ : 0)
425  int32_t recScale = 0; // a
426  int32_t recBias = 0; // b
427  int32_t nonrecBias = 0; // c
428  bool failed = false;
429 
430  /// Create a new unsatisfiable inequality `x >= ∞`.
431  static LinIneq unsat() { return LinIneq(true); }
432 
433  /// Create a new inequality `x >= (failed ? ∞ : 0)`.
434  explicit LinIneq(bool failed = false) : failed(failed) {}
435 
436  /// Create a new inequality `x >= bias`.
437  explicit LinIneq(int32_t bias) : nonrecBias(bias) {}
438 
439  /// Create a new inequality `x >= scale*x+bias`.
440  explicit LinIneq(int32_t scale, int32_t bias) {
441  if (scale != 0) {
442  recScale = scale;
443  recBias = bias;
444  } else {
445  nonrecBias = bias;
446  }
447  }
448 
449  /// Create a new inequality `x >= max(recScale*x+recBias, nonrecBias) +
450  /// (failed ? ∞ : 0)`.
451  explicit LinIneq(int32_t recScale, int32_t recBias, int32_t nonrecBias,
452  bool failed = false)
453  : failed(failed) {
454  if (recScale != 0) {
455  this->recScale = recScale;
456  this->recBias = recBias;
457  this->nonrecBias = nonrecBias;
458  } else {
459  this->nonrecBias = std::max(recBias, nonrecBias);
460  }
461  }
462 
463  /// Combine two inequalities by taking the maxima of corresponding
464  /// coefficients.
465  ///
466  /// This essentially combines `x >= max(a1*x+b1, c1)` and `x >= max(a2*x+b2,
467  /// c2)` into a new `x >= max(max(a1,a2)*x+max(b1,b2), max(c1,c2))`. This is
468  /// a pessimistic upper bound, since e.g. `x >= 2x-10` and `x >= x-5` may both
469  /// hold, but the resulting `x >= 2x-5` may pessimistically not hold.
470  static LinIneq max(const LinIneq &lhs, const LinIneq &rhs) {
471  return LinIneq(std::max(lhs.recScale, rhs.recScale),
472  std::max(lhs.recBias, rhs.recBias),
473  std::max(lhs.nonrecBias, rhs.nonrecBias),
474  lhs.failed || rhs.failed);
475  }
476 
477  /// Combine two inequalities by summing up the two right hand sides.
478  ///
479  /// This is a tricky one, since the addition of the two max terms will lead to
480  /// a maximum over four possible terms (similar to a binomial expansion). In
481  /// order to shoehorn this back into a two-term maximum, we have to pick the
482  /// recursive term that will grow the fastest.
483  ///
484  /// As an example for this problem, consider the following addition:
485  ///
486  /// x >= max(a1*x+b1, c1) + max(a2*x+b2, c2)
487  ///
488  /// We would like to expand and rearrange this again into a maximum:
489  ///
490  /// x >= max(a1*x+b1 + max(a2*x+b2, c2), c1 + max(a2*x+b2, c2))
491  /// x >= max(max(a1*x+b1 + a2*x+b2, a1*x+b1 + c2),
492  /// max(c1 + a2*x+b2, c1 + c2))
493  /// x >= max((a1+a2)*x+(b1+b2), a1*x+(b1+c2), a2*x+(b2+c1), c1+c2)
494  ///
495  /// Since we are combining two two-term maxima, there are four possible ways
496  /// how the terms can combine, leading to the above four-term maximum. An easy
497  /// upper bound of the form we want would be the following:
498  ///
499  /// x >= max(max(a1+a2, a1, a2)*x + max(b1+b2, b1+c2, b2+c1), c1+c2)
500  ///
501  /// However, this is a very pessimistic upper-bound that will declare very
502  /// common patterns in the IR as unbreakable cycles, despite them being very
503  /// much breakable. For example:
504  ///
505  /// x >= max(x, 42) + max(0, -3) <-- breakable recursion
506  /// x >= max(max(1+0, 1, 0)*x + max(42+0, -3, 42), 42-2)
507  /// x >= max(x + 42, 39) <-- unbreakable recursion!
508  ///
509  /// A better approach is to take the expanded four-term maximum, retain the
510  /// non-recursive term (c1+c2), and estimate which one of the recursive terms
511  /// (first three) will become dominant as we choose greater values for x.
512  /// Since x never is inferred to be negative, the recursive term in the
513  /// maximum with the highest scaling factor for x will end up dominating as
514  /// x tends to ∞:
515  ///
516  /// x >= max({
517  /// (a1+a2)*x+(b1+b2) if a1+a2 >= max(a1+a2, a1, a2) and a1>0 and a2>0,
518  /// a1*x+(b1+c2) if a1 >= max(a1+a2, a1, a2) and a1>0,
519  /// a2*x+(b2+c1) if a2 >= max(a1+a2, a1, a2) and a2>0,
520  /// 0 otherwise
521  /// }, c1+c2)
522  ///
523  /// In case multiple cases apply, the highest bias of the recursive term is
524  /// picked. With this, the above problematic example triggers the second case
525  /// and becomes:
526  ///
527  /// x >= max(1*x+(0-3), 42-3) = max(x-3, 39)
528  ///
529  /// Of which the first case is chosen, as it has the lower bias value.
530  static LinIneq add(const LinIneq &lhs, const LinIneq &rhs) {
531  // Determine the maximum scaling factor among the three possible recursive
532  // terms.
533  auto enable1 = lhs.recScale > 0 && rhs.recScale > 0;
534  auto enable2 = lhs.recScale > 0;
535  auto enable3 = rhs.recScale > 0;
536  auto scale1 = lhs.recScale + rhs.recScale; // (a1+a2)
537  auto scale2 = lhs.recScale; // a1
538  auto scale3 = rhs.recScale; // a2
539  auto bias1 = lhs.recBias + rhs.recBias; // (b1+b2)
540  auto bias2 = lhs.recBias + rhs.nonrecBias; // (b1+c2)
541  auto bias3 = rhs.recBias + lhs.nonrecBias; // (b2+c1)
542  auto maxScale = std::max(scale1, std::max(scale2, scale3));
543 
544  // Among those terms that have a maximum scaling factor, determine the
545  // largest bias value.
546  std::optional<int32_t> maxBias;
547  if (enable1 && scale1 == maxScale)
548  maxBias = bias1;
549  if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
550  maxBias = bias2;
551  if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
552  maxBias = bias3;
553 
554  // Pick from the recursive terms the one with maximum scaling factor and
555  // minimum bias value.
556  auto nonrecBias = lhs.nonrecBias + rhs.nonrecBias; // c1+c2
557  auto failed = lhs.failed || rhs.failed;
558  if (enable1 && scale1 == maxScale && bias1 == *maxBias)
559  return LinIneq(scale1, bias1, nonrecBias, failed);
560  if (enable2 && scale2 == maxScale && bias2 == *maxBias)
561  return LinIneq(scale2, bias2, nonrecBias, failed);
562  if (enable3 && scale3 == maxScale && bias3 == *maxBias)
563  return LinIneq(scale3, bias3, nonrecBias, failed);
564  return LinIneq(0, 0, nonrecBias, failed);
565  }
566 
567  /// Check if the inequality is satisfiable.
568  ///
569  /// The inequality becomes unsatisfiable if the RHS is ∞, or a>1, or a==1 and
570  /// b <= 0. Otherwise there exists as solution for `x` that satisfies the
571  /// inequality.
572  bool sat() const {
573  if (failed)
574  return false;
575  if (recScale > 1)
576  return false;
577  if (recScale == 1 && recBias > 0)
578  return false;
579  return true;
580  }
581 
582  /// Dump the inequality in human-readable form.
583  void print(llvm::raw_ostream &os) const {
584  bool any = false;
585  bool both = (recScale != 0 || recBias != 0) && nonrecBias != 0;
586  os << "x >= ";
587  if (both)
588  os << "max(";
589  if (recScale != 0) {
590  any = true;
591  if (recScale != 1)
592  os << recScale << "*";
593  os << "x";
594  }
595  if (recBias != 0) {
596  if (any) {
597  if (recBias < 0)
598  os << " - " << -recBias;
599  else
600  os << " + " << recBias;
601  } else {
602  any = true;
603  os << recBias;
604  }
605  }
606  if (both)
607  os << ", ";
608  if (nonrecBias != 0) {
609  any = true;
610  os << nonrecBias;
611  }
612  if (both)
613  os << ")";
614  if (failed) {
615  if (any)
616  os << " + ";
617  os << "∞";
618  }
619  if (!any)
620  os << "0";
621  }
622 };
623 
624 /// A simple solver for width constraints.
625 class ConstraintSolver {
626 public:
627  ConstraintSolver() = default;
628 
629  VarExpr *var() {
630  auto *v = vars.alloc();
631  varExprs.push_back(v);
632  if (currentInfo)
633  info[v].insert(currentInfo);
634  if (currentLoc)
635  locs[v].insert(*currentLoc);
636  return v;
637  }
638  DerivedExpr *derived() {
639  auto *d = derivs.alloc();
640  derivedExprs.push_back(d);
641  return d;
642  }
643  KnownExpr *known(int32_t value) { return alloc<KnownExpr>(knowns, value); }
644  IdExpr *id(Expr *arg) { return alloc<IdExpr>(ids, arg); }
645  PowExpr *pow(Expr *arg) { return alloc<PowExpr>(uns, arg); }
646  AddExpr *add(Expr *lhs, Expr *rhs) { return alloc<AddExpr>(bins, lhs, rhs); }
647  MaxExpr *max(Expr *lhs, Expr *rhs) { return alloc<MaxExpr>(bins, lhs, rhs); }
648  MinExpr *min(Expr *lhs, Expr *rhs) { return alloc<MinExpr>(bins, lhs, rhs); }
649 
650  /// Add a constraint `lhs >= rhs`. Multiple constraints on the same variable
651  /// are coalesced into a `max(a, b)` expr.
652  Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
653  if (lhs->constraint)
654  lhs->constraint = max(lhs->constraint, rhs);
655  else
656  lhs->constraint = id(rhs);
657  return lhs->constraint;
658  }
659 
660  /// Add a constraint `lhs <= rhs`. Multiple constraints on the same variable
661  /// are coalesced into a `min(a, b)` expr.
662  Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
663  if (lhs->upperBound)
664  lhs->upperBound = min(lhs->upperBound, rhs);
665  else
666  lhs->upperBound = id(rhs);
667  return lhs->upperBound;
668  }
669 
670  void dumpConstraints(llvm::raw_ostream &os);
671  LogicalResult solve();
672 
673  using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
674  const ContextInfo &getContextInfo() const { return info; }
675  void setCurrentContextInfo(FieldRef fieldRef) { currentInfo = fieldRef; }
676  void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
677 
678 private:
679  // Allocator for constraint expressions.
680  llvm::BumpPtrAllocator allocator;
681  Allocator<VarExpr> vars = {allocator};
682  Allocator<DerivedExpr> derivs = {allocator};
683  InternedAllocator<KnownExpr> knowns = {allocator};
684  InternedAllocator<IdExpr> ids = {allocator};
685  InternedAllocator<UnaryExpr> uns = {allocator};
686  InternedAllocator<BinaryExpr> bins = {allocator};
687 
688  /// A list of expressions in the order they were created.
689  std::vector<VarExpr *> varExprs;
690  std::vector<DerivedExpr *> derivedExprs;
691 
692  /// Add an allocated expression to the list above.
693  template <typename R, typename T, typename... Args>
694  R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
695  auto [expr, inserted] =
696  allocator.template alloc<R>(std::forward<Args>(args)...);
697  if (currentInfo)
698  info[expr].insert(currentInfo);
699  if (currentLoc)
700  locs[expr].insert(*currentLoc);
701  return expr;
702  }
703 
704  /// Contextual information for each expression, indicating which values in the
705  /// IR lead to this expression.
706  ContextInfo info;
707  FieldRef currentInfo = {};
708  DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
709  std::optional<Location> currentLoc;
710 
711  // Forbid copyign or moving the solver, which would invalidate the refs to
712  // allocator held by the allocators.
713  ConstraintSolver(ConstraintSolver &&) = delete;
714  ConstraintSolver(const ConstraintSolver &) = delete;
715  ConstraintSolver &operator=(ConstraintSolver &&) = delete;
716  ConstraintSolver &operator=(const ConstraintSolver &) = delete;
717 
718  void emitUninferredWidthError(VarExpr *var);
719 
720  LinIneq checkCycles(VarExpr *var, Expr *expr,
721  SmallPtrSetImpl<Expr *> &seenVars,
722  InFlightDiagnostic *reportInto = nullptr,
723  unsigned indent = 1);
724 };
725 
726 } // namespace
727 
728 /// Print all constraints in the solver to an output stream.
729 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
730  for (auto *v : varExprs) {
731  if (v->constraint)
732  os << "- " << *v << " >= " << *v->constraint << "\n";
733  else
734  os << "- " << *v << " unconstrained\n";
735  }
736 }
737 
738 #ifndef NDEBUG
739 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const LinIneq &l) {
740  l.print(os);
741  return os;
742 }
743 #endif
744 
745 /// Compute the canonicalized linear inequality expression starting at `expr`,
746 /// for the `var` as the left hand side `x` of the inequality. `seenVars` is
747 /// used as a recursion breaker. Occurrences of `var` itself within the
748 /// expression are mapped to the `a` coefficient in the inequality. Any other
749 /// variables are substituted and, in the presence of a recursion in a variable
750 /// other than `var`, treated as zero. `info` is a mapping from constraint
751 /// expressions to values and operations that produced the expression, and is
752 /// used during error reporting. If `reportInto` is present, the function will
753 /// additionally attach unsatisfiable inequalities as notes to the diagnostic as
754 /// it encounters them.
755 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
756  SmallPtrSetImpl<Expr *> &seenVars,
757  InFlightDiagnostic *reportInto,
758  unsigned indent) {
759  auto ineq =
760  TypeSwitch<Expr *, LinIneq>(expr)
761  .Case<KnownExpr>(
762  [&](auto *expr) { return LinIneq(expr->getValue()); })
763  .Case<VarExpr>([&](auto *expr) {
764  if (expr == var)
765  return LinIneq(1, 0); // x >= 1*x + 0
766  if (!seenVars.insert(expr).second)
767  // Count recursions in other variables as 0. This is sane
768  // since the cycle is either breakable, in which case the
769  // recursion does not modify the resulting value of the
770  // variable, or it is not breakable and will be caught by
771  // this very function once it is called on that variable.
772  return LinIneq(0);
773  if (!expr->constraint)
774  // Count unconstrained variables as `x >= 0`.
775  return LinIneq(0);
776  auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
777  indent + 1);
778  seenVars.erase(expr);
779  return l;
780  })
781  .Case<IdExpr>([&](auto *expr) {
782  return checkCycles(var, expr->arg, seenVars, reportInto,
783  indent + 1);
784  })
785  .Case<PowExpr>([&](auto *expr) {
786  // If we can evaluate `2**arg` to a sensible constant, do
787  // so. This is the case if a == 0 and c < 31 such that 2**c is
788  // representable.
789  auto arg =
790  checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
791  if (arg.recScale != 0 || arg.nonrecBias < 0 || arg.nonrecBias >= 31)
792  return LinIneq::unsat();
793  return LinIneq(1 << arg.nonrecBias); // x >= 2**arg
794  })
795  .Case<AddExpr>([&](auto *expr) {
796  return LinIneq::add(
797  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
798  checkCycles(var, expr->rhs(), seenVars, reportInto,
799  indent + 1));
800  })
801  .Case<MaxExpr, MinExpr>([&](auto *expr) {
802  // Combine the inequalities of the LHS and RHS into a single overly
803  // pessimistic inequality. We treat `MinExpr` the same as `MaxExpr`,
804  // since `max(a,b)` is an upper bound to `min(a,b)`.
805  return LinIneq::max(
806  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
807  checkCycles(var, expr->rhs(), seenVars, reportInto,
808  indent + 1));
809  })
810  .Default([](auto) { return LinIneq::unsat(); });
811 
812  // If we were passed an in-flight diagnostic and the current inequality is
813  // unsatisfiable, attach notes to the diagnostic indicating the values or
814  // operations that contributed to this part of the constraint expression.
815  if (reportInto && !ineq.sat()) {
816  auto report = [&](Location loc) {
817  auto &note = reportInto->attachNote(loc);
818  note << "constrained width W >= ";
819  if (ineq.recScale == -1)
820  note << "-";
821  if (ineq.recScale != 1)
822  note << ineq.recScale;
823  note << "W";
824  if (ineq.recBias < 0)
825  note << "-" << -ineq.recBias;
826  if (ineq.recBias > 0)
827  note << "+" << ineq.recBias;
828  note << " here:";
829  };
830  auto it = locs.find(expr);
831  if (it != locs.end())
832  for (auto loc : it->second)
833  report(loc);
834  }
835  if (!reportInto)
836  LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
837  << "- Visited " << *expr << ": " << ineq << "\n");
838 
839  return ineq;
840 }
841 
842 using ExprSolution = std::pair<std::optional<int32_t>, bool>;
843 
844 static ExprSolution
845 computeUnary(ExprSolution arg, llvm::function_ref<int32_t(int32_t)> operation) {
846  if (arg.first)
847  arg.first = operation(*arg.first);
848  return arg;
849 }
850 
851 static ExprSolution
853  llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
854  auto result = ExprSolution{std::nullopt, lhs.second || rhs.second};
855  if (lhs.first && rhs.first)
856  result.first = operation(*lhs.first, *rhs.first);
857  else if (lhs.first)
858  result.first = lhs.first;
859  else if (rhs.first)
860  result.first = rhs.first;
861  return result;
862 }
863 
864 namespace {
865 struct Frame {
866  Frame(Expr *expr, unsigned indent) : expr(expr), indent(indent) {}
867  Expr *expr;
868  // Indent is only used for debug logs.
869  unsigned indent;
870 };
871 } // namespace
872 
873 /// Compute the value of a constraint `expr`. `seenVars` is used as a recursion
874 /// breaker. Recursive variables are treated as zero. Returns the computed value
875 /// and a boolean indicating whether a recursion was detected. This may be used
876 /// to memoize the result of expressions in case they were not involved in a
877 /// cycle (which may alter their value from the perspective of a variable).
878 static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
879  std::vector<Frame> &worklist) {
880  worklist.clear();
881  worklist.emplace_back(expr, 1);
882  llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
883 
884  while (!worklist.empty()) {
885  auto &frame = worklist.back();
886  auto indent = frame.indent;
887  auto setSolution = [&](ExprSolution solution) {
888  // Memoize the result.
889  if (solution.first && !solution.second)
890  frame.expr->setSolution(*solution.first);
891  solvedExprs[frame.expr] = solution;
892 
893  // Produce some useful debug prints.
894  LLVM_DEBUG({
895  if (!isa<KnownExpr>(frame.expr)) {
896  if (solution.first)
897  llvm::dbgs().indent(indent * 2)
898  << "= Solved " << *frame.expr << " = " << *solution.first;
899  else
900  llvm::dbgs().indent(indent * 2) << "= Skipped " << *frame.expr;
901  llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
902  << ")\n";
903  }
904  });
905 
906  worklist.pop_back();
907  };
908 
909  // See if we have a memoized result we can return.
910  if (frame.expr->getSolution()) {
911  LLVM_DEBUG({
912  if (!isa<KnownExpr>(frame.expr))
913  llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
914  << *frame.expr->getSolution() << "\n";
915  });
916  setSolution(ExprSolution{*frame.expr->getSolution(), false});
917  continue;
918  }
919 
920  // Otherwise compute the value of the expression.
921  LLVM_DEBUG({
922  if (!isa<KnownExpr>(frame.expr))
923  llvm::dbgs().indent(indent * 2) << "- Solving " << *frame.expr << "\n";
924  });
925 
926  TypeSwitch<Expr *>(frame.expr)
927  .Case<KnownExpr>([&](auto *expr) {
928  setSolution(ExprSolution{expr->getValue(), false});
929  })
930  .Case<VarExpr>([&](auto *expr) {
931  if (solvedExprs.contains(expr->constraint)) {
932  auto solution = solvedExprs[expr->constraint];
933  // If we've solved the upper bound already, store the solution.
934  // This will be explicitly solved for later if not computed as
935  // part of the solving that resolved this constraint.
936  // This should only happen if somehow the constraint is
937  // solved before visiting this expression, so that our upperBound
938  // was not added to the worklist such that it was handled first.
939  if (expr->upperBound && solvedExprs.contains(expr->upperBound))
940  expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
941  seenVars.erase(expr);
942  // Constrain variables >= 0.
943  if (solution.first && *solution.first < 0)
944  solution.first = 0;
945  return setSolution(solution);
946  }
947 
948  // Unconstrained variables produce no solution.
949  if (!expr->constraint)
950  return setSolution(ExprSolution{std::nullopt, false});
951  // Return no solution for recursions in the variables. This is sane
952  // and will cause the expression to be ignored when computing the
953  // parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
954  if (!seenVars.insert(expr).second)
955  return setSolution(ExprSolution{std::nullopt, true});
956 
957  worklist.emplace_back(expr->constraint, indent + 1);
958  if (expr->upperBound)
959  worklist.emplace_back(expr->upperBound, indent + 1);
960  })
961  .Case<IdExpr>([&](auto *expr) {
962  if (solvedExprs.contains(expr->arg))
963  return setSolution(solvedExprs[expr->arg]);
964  worklist.emplace_back(expr->arg, indent + 1);
965  })
966  .Case<PowExpr>([&](auto *expr) {
967  if (solvedExprs.contains(expr->arg))
968  return setSolution(computeUnary(
969  solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
970 
971  worklist.emplace_back(expr->arg, indent + 1);
972  })
973  .Case<AddExpr>([&](auto *expr) {
974  if (solvedExprs.contains(expr->lhs()) &&
975  solvedExprs.contains(expr->rhs()))
976  return setSolution(computeBinary(
977  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
978  [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
979 
980  worklist.emplace_back(expr->lhs(), indent + 1);
981  worklist.emplace_back(expr->rhs(), indent + 1);
982  })
983  .Case<MaxExpr>([&](auto *expr) {
984  if (solvedExprs.contains(expr->lhs()) &&
985  solvedExprs.contains(expr->rhs()))
986  return setSolution(computeBinary(
987  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
988  [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
989 
990  worklist.emplace_back(expr->lhs(), indent + 1);
991  worklist.emplace_back(expr->rhs(), indent + 1);
992  })
993  .Case<MinExpr>([&](auto *expr) {
994  if (solvedExprs.contains(expr->lhs()) &&
995  solvedExprs.contains(expr->rhs()))
996  return setSolution(computeBinary(
997  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
998  [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
999 
1000  worklist.emplace_back(expr->lhs(), indent + 1);
1001  worklist.emplace_back(expr->rhs(), indent + 1);
1002  })
1003  .Default([&](auto) {
1004  setSolution(ExprSolution{std::nullopt, false});
1005  });
1006  }
1007 
1008  return solvedExprs[expr];
1009 }
1010 
1011 /// Solve the constraint problem. This is a very simple implementation that
1012 /// does not fully solve the problem if there are weird dependency cycles
1013 /// present.
1014 LogicalResult ConstraintSolver::solve() {
1015  LLVM_DEBUG({
1016  llvm::dbgs() << "\n";
1017  debugHeader("Constraints") << "\n\n";
1018  dumpConstraints(llvm::dbgs());
1019  });
1020 
1021  // Ensure that there are no adverse cycles around.
1022  LLVM_DEBUG({
1023  llvm::dbgs() << "\n";
1024  debugHeader("Checking for unbreakable loops") << "\n\n";
1025  });
1026  SmallPtrSet<Expr *, 16> seenVars;
1027  bool anyFailed = false;
1028 
1029  for (auto *var : varExprs) {
1030  if (!var->constraint)
1031  continue;
1032  LLVM_DEBUG(llvm::dbgs()
1033  << "- Checking " << *var << " >= " << *var->constraint << "\n");
1034 
1035  // Canonicalize the variable's constraint expression into a form that allows
1036  // us to easily determine if any recursion leads to an unsatisfiable
1037  // constraint. The `seenVars` set acts as a recursion breaker.
1038  seenVars.insert(var);
1039  auto ineq = checkCycles(var, var->constraint, seenVars);
1040  seenVars.clear();
1041 
1042  // If the constraint is satisfiable, we're done.
1043  // TODO: It's possible that this result is already sufficient to arrive at a
1044  // solution for the constraint, and the second pass further down is not
1045  // necessary. This would require more proper handling of `MinExpr` in the
1046  // cycle checking code.
1047  if (ineq.sat()) {
1048  LLVM_DEBUG(llvm::dbgs()
1049  << " = Breakable since " << ineq << " satisfiable\n");
1050  continue;
1051  }
1052 
1053  // If we arrive here, the constraint is not satisfiable at all. To provide
1054  // some guidance to the user, we call the cycle checking code again, but
1055  // this time with an in-flight diagnostic to attach notes indicating
1056  // unsatisfiable paths in the cycle.
1057  LLVM_DEBUG(llvm::dbgs()
1058  << " = UNBREAKABLE since " << ineq << " unsatisfiable\n");
1059  anyFailed = true;
1060  for (auto fieldRef : info.find(var)->second) {
1061  // Depending on whether this value stems from an operation or not, create
1062  // an appropriate diagnostic identifying the value.
1063  auto *op = fieldRef.getDefiningOp();
1064  auto diag = op ? op->emitOpError()
1065  : mlir::emitError(fieldRef.getValue().getLoc())
1066  << "value ";
1067  diag << "is constrained to be wider than itself";
1068 
1069  // Re-run the cycle checking, but this time reporting into the diagnostic.
1070  seenVars.insert(var);
1071  checkCycles(var, var->constraint, seenVars, &diag);
1072  seenVars.clear();
1073  }
1074  }
1075 
1076  // If there were cycles, return now to avoid complaining to the user about
1077  // dependent widths not being inferred.
1078  if (anyFailed)
1079  return failure();
1080 
1081  // Iterate over the constraint variables and solve each.
1082  LLVM_DEBUG({
1083  llvm::dbgs() << "\n";
1084  debugHeader("Solving constraints") << "\n\n";
1085  });
1086  std::vector<Frame> worklist;
1087  for (auto *var : varExprs) {
1088  // Complain about unconstrained variables.
1089  if (!var->constraint) {
1090  LLVM_DEBUG(llvm::dbgs() << "- Unconstrained " << *var << "\n");
1091  emitUninferredWidthError(var);
1092  anyFailed = true;
1093  continue;
1094  }
1095 
1096  // Compute the value for the variable.
1097  LLVM_DEBUG(llvm::dbgs()
1098  << "- Solving " << *var << " >= " << *var->constraint << "\n");
1099  seenVars.insert(var);
1100  auto solution = solveExpr(var->constraint, seenVars, worklist);
1101  // Compute the upperBound if there is one and haven't already.
1102  if (var->upperBound && !var->upperBoundSolution)
1103  var->upperBoundSolution =
1104  solveExpr(var->upperBound, seenVars, worklist).first;
1105  seenVars.clear();
1106 
1107  // Constrain variables >= 0.
1108  if (solution.first) {
1109  if (*solution.first < 0)
1110  solution.first = 0;
1111  var->setSolution(*solution.first);
1112  }
1113 
1114  // In case the width could not be inferred, complain to the user. This might
1115  // be the case if the width depends on an unconstrained variable.
1116  if (!solution.first) {
1117  LLVM_DEBUG(llvm::dbgs() << " - UNSOLVED " << *var << "\n");
1118  emitUninferredWidthError(var);
1119  anyFailed = true;
1120  continue;
1121  }
1122  LLVM_DEBUG(llvm::dbgs()
1123  << " = Solved " << *var << " = " << solution.first << " ("
1124  << (solution.second ? "cycle broken" : "unique") << ")\n");
1125 
1126  // Check if the solution we have found violates an upper bound.
1127  if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1128  LLVM_DEBUG(llvm::dbgs() << " ! Unsatisfiable " << *var
1129  << " <= " << var->upperBoundSolution << "\n");
1130  emitUninferredWidthError(var);
1131  anyFailed = true;
1132  }
1133  }
1134 
1135  // Copy over derived widths.
1136  for (auto *derived : derivedExprs) {
1137  auto *assigned = derived->assigned;
1138  if (!assigned || !assigned->getSolution()) {
1139  LLVM_DEBUG(llvm::dbgs() << "- Unused " << *derived << " set to 0\n");
1140  derived->setSolution(0);
1141  } else {
1142  LLVM_DEBUG(llvm::dbgs() << "- Deriving " << *derived << " = "
1143  << assigned->getSolution() << "\n");
1144  derived->setSolution(*assigned->getSolution());
1145  }
1146  }
1147 
1148  return failure(anyFailed);
1149 }
1150 
1151 // Emits the diagnostic to inform the user about an uninferred width in the
1152 // design. Returns true if an error was reported, false otherwise.
1153 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1154  FieldRef fieldRef = info.find(var)->second.back();
1155  Value value = fieldRef.getValue();
1156 
1157  auto diag = mlir::emitError(value.getLoc(), "uninferred width:");
1158 
1159  // Try to hint the user at what kind of node this is.
1160  if (isa<BlockArgument>(value)) {
1161  diag << " port";
1162  } else if (auto *op = value.getDefiningOp()) {
1163  TypeSwitch<Operation *>(op)
1164  .Case<WireOp>([&](auto) { diag << " wire"; })
1165  .Case<RegOp, RegResetOp>([&](auto) { diag << " reg"; })
1166  .Case<NodeOp>([&](auto) { diag << " node"; })
1167  .Default([&](auto) { diag << " value"; });
1168  } else {
1169  diag << " value";
1170  }
1171 
1172  // Actually print what the user can refer to.
1173  auto [fieldName, rootKnown] = getFieldName(fieldRef);
1174  if (!fieldName.empty()) {
1175  if (!rootKnown)
1176  diag << " field";
1177  diag << " \"" << fieldName << "\"";
1178  }
1179 
1180  if (!var->constraint) {
1181  diag << " is unconstrained";
1182  } else if (var->getSolution() && var->upperBoundSolution &&
1183  var->getSolution() > var->upperBoundSolution) {
1184  diag << " cannot satisfy all width requirements";
1185  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1186  LLVM_DEBUG(llvm::dbgs() << *var->upperBound << "\n");
1187  auto loc = locs.find(var->constraint)->second.back();
1188  diag.attachNote(loc) << "width is constrained to be at least "
1189  << *var->getSolution() << " here:";
1190  loc = locs.find(var->upperBound)->second.back();
1191  diag.attachNote(loc) << "width is constrained to be at most "
1192  << *var->upperBoundSolution << " here:";
1193  } else {
1194  diag << " width cannot be determined";
1195  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1196  auto loc = locs.find(var->constraint)->second.back();
1197  diag.attachNote(loc) << "width is constrained by an uninferred width here:";
1198  }
1199 }
1200 
1201 //===----------------------------------------------------------------------===//
1202 // Inference Constraint Problem Mapping
1203 //===----------------------------------------------------------------------===//
1204 
1205 namespace {
1206 
1207 /// A helper class which maps the types and operations in a design to a set of
1208 /// variables and constraints to be solved later.
1209 class InferenceMapping {
1210 public:
1211  InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1212  hw::InnerSymbolTableCollection &istc)
1213  : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1214 
1215  LogicalResult map(CircuitOp op);
1216  bool allWidthsKnown(Operation *op);
1217  LogicalResult mapOperation(Operation *op);
1218 
1219  /// Declare all the variables in the value. If the value is a ground type,
1220  /// there is a single variable declared. If the value is an aggregate type,
1221  /// it sets up variables for each unknown width.
1222  void declareVars(Value value, bool isDerived = false);
1223 
1224  /// Assign the constraint expressions of the fields in the `result` argument
1225  /// as the max of expressions in the `rhs` and `lhs` arguments. Both fields
1226  /// must be the same type.
1227  void maximumOfTypes(Value result, Value rhs, Value lhs);
1228 
1229  /// Constrain the value "larger" to be greater than or equal to "smaller".
1230  /// These may be aggregate values. This is used for regular connects.
1231  void constrainTypes(Value larger, Value smaller, bool equal = false);
1232 
1233  /// Constrain the expression "larger" to be greater than or equals to
1234  /// the expression "smaller".
1235  void constrainTypes(Expr *larger, Expr *smaller,
1236  bool imposeUpperBounds = false, bool equal = false);
1237 
1238  /// Assign the constraint expressions of the fields in the `src` argument as
1239  /// the expressions for the `dst` argument. Both fields must be of the given
1240  /// `type`.
1241  void unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type);
1242 
1243  /// Get the expr associated with the value. The value must be a non-aggregate
1244  /// type
1245  Expr *getExpr(Value value) const;
1246 
1247  /// Get the expr associated with a specific field in a value.
1248  Expr *getExpr(FieldRef fieldRef) const;
1249 
1250  /// Get the expr associated with a specific field in a value. If value is
1251  /// NULL, then this returns NULL.
1252  Expr *getExprOrNull(FieldRef fieldRef) const;
1253 
1254  /// Set the expr associated with the value. The value must be a non-aggregate
1255  /// type.
1256  void setExpr(Value value, Expr *expr);
1257 
1258  /// Set the expr associated with a specific field in a value.
1259  void setExpr(FieldRef fieldRef, Expr *expr);
1260 
1261  /// Return whether a module was skipped due to being fully inferred already.
1262  bool isModuleSkipped(FModuleOp module) const {
1263  return skippedModules.count(module);
1264  }
1265 
1266  /// Return whether all modules in the mapping were fully inferred.
1267  bool areAllModulesSkipped() const { return allModulesSkipped; }
1268 
1269 private:
1270  /// The constraint solver into which we emit variables and constraints.
1271  ConstraintSolver &solver;
1272 
1273  /// The constraint exprs for each result type of an operation.
1274  DenseMap<FieldRef, Expr *> opExprs;
1275 
1276  /// The fully inferred modules that were skipped entirely.
1277  SmallPtrSet<Operation *, 16> skippedModules;
1278  bool allModulesSkipped = true;
1279 
1280  /// Cache of module symbols
1281  SymbolTable &symtbl;
1282 
1283  /// Full design inner symbol information.
1284  hw::InnerRefNamespace irn;
1285 };
1286 
1287 } // namespace
1288 
1289 /// Check if a type contains any FIRRTL type with uninferred widths.
1290 static bool hasUninferredWidth(Type type) {
1291  return TypeSwitch<Type, bool>(type)
1292  .Case<FIRRTLBaseType>([](auto base) { return base.hasUninferredWidth(); })
1293  .Case<RefType>(
1294  [](auto ref) { return ref.getType().hasUninferredWidth(); })
1295  .Default([](auto) { return false; });
1296 }
1297 
1298 LogicalResult InferenceMapping::map(CircuitOp op) {
1299  LLVM_DEBUG(llvm::dbgs()
1300  << "\n===----- Mapping ops to constraint exprs -----===\n\n");
1301 
1302  // Ensure we have constraint variables established for all module ports.
1303  for (auto module : op.getOps<FModuleOp>())
1304  for (auto arg : module.getArguments()) {
1305  solver.setCurrentContextInfo(FieldRef(arg, 0));
1306  declareVars(arg);
1307  }
1308 
1309  for (auto module : op.getOps<FModuleOp>()) {
1310  // Check if the module contains *any* uninferred widths. This allows us to
1311  // do an early skip if the module is already fully inferred.
1312  bool anyUninferred = false;
1313  for (auto arg : module.getArguments()) {
1314  anyUninferred |= hasUninferredWidth(arg.getType());
1315  if (anyUninferred)
1316  break;
1317  }
1318  module.walk([&](Operation *op) {
1319  for (auto type : op->getResultTypes())
1320  anyUninferred |= hasUninferredWidth(type);
1321  if (anyUninferred)
1322  return WalkResult::interrupt();
1323  return WalkResult::advance();
1324  });
1325 
1326  if (!anyUninferred) {
1327  LLVM_DEBUG(llvm::dbgs() << "Skipping fully-inferred module '"
1328  << module.getName() << "'\n");
1329  skippedModules.insert(module);
1330  continue;
1331  }
1332 
1333  allModulesSkipped = false;
1334 
1335  // Go through operations in the module, creating type variables for results,
1336  // and generating constraints.
1337  auto result = module.getBodyBlock()->walk(
1338  [&](Operation *op) { return WalkResult(mapOperation(op)); });
1339  if (result.wasInterrupted())
1340  return failure();
1341  }
1342 
1343  return success();
1344 }
1345 
1346 bool InferenceMapping::allWidthsKnown(Operation *op) {
1347  /// Ignore property assignments, no widths to infer.
1348  if (isa<PropAssignOp>(op))
1349  return true;
1350 
1351  // If this is a mux, and the select signal is uninferred, we need to set an
1352  // upperbound limit on it.
1353  if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1354  if (hasUninferredWidth(op->getOperand(0).getType()))
1355  return false;
1356 
1357  // We need to propagate through connects.
1358  if (isa<FConnectLike, AttachOp>(op))
1359  return false;
1360 
1361  // Check if we know the width of every result of this operation.
1362  return llvm::all_of(op->getResults(), [&](auto result) {
1363  // Only consider FIRRTL types for width constraints. Ignore any foreign
1364  // types as they don't participate in the width inference process.
1365  if (auto type = type_dyn_cast<FIRRTLType>(result.getType()))
1366  if (hasUninferredWidth(type))
1367  return false;
1368  return true;
1369  });
1370 }
1371 
1372 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1373  if (allWidthsKnown(op))
1374  return success();
1375 
1376  // Actually generate the necessary constraint expressions.
1377  bool mappingFailed = false;
1378  solver.setCurrentContextInfo(
1379  op->getNumResults() > 0 ? FieldRef(op->getResults()[0], 0) : FieldRef());
1380  solver.setCurrentLocation(op->getLoc());
1381  TypeSwitch<Operation *>(op)
1382  .Case<ConstantOp>([&](auto op) {
1383  // If the constant has a known width, use that. Otherwise pick the
1384  // smallest number of bits necessary to represent the constant.
1385  auto v = op.getValue();
1386  auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1387  : v.countLeadingZeros());
1388  if (v.isSigned())
1389  w += 1;
1390  setExpr(op.getResult(), solver.known(std::max(w, 1u)));
1391  })
1392  .Case<SpecialConstantOp>([&](auto op) {
1393  // Nothing required.
1394  })
1395  .Case<InvalidValueOp>([&](auto op) {
1396  // We must duplicate the invalid value for each use, since each use can
1397  // be inferred to a different width.
1398  declareVars(op.getResult(), /*isDerived=*/true);
1399  })
1400  .Case<WireOp, RegOp>([&](auto op) { declareVars(op.getResult()); })
1401  .Case<RegResetOp>([&](auto op) {
1402  // The original Scala code also constrains the reset signal to be at
1403  // least 1 bit wide. We don't do this here since the MLIR FIRRTL
1404  // dialect enforces the reset signal to be an async reset or a
1405  // `uint<1>`.
1406  declareVars(op.getResult());
1407  // Contrain the register to be greater than or equal to the reset
1408  // signal.
1409  constrainTypes(op.getResult(), op.getResetValue());
1410  })
1411  .Case<NodeOp>([&](auto op) {
1412  // Nodes have the same type as their input.
1413  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 0),
1414  op.getResult().getType());
1415  })
1416 
1417  // Aggregate Values
1418  .Case<SubfieldOp>([&](auto op) {
1419  BundleType bundleType = op.getInput().getType();
1420  auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1421  unifyTypes(FieldRef(op.getResult(), 0),
1422  FieldRef(op.getInput(), fieldID), op.getType());
1423  })
1424  .Case<SubindexOp, SubaccessOp>([&](auto op) {
1425  // All vec fields unify to the same thing. Always use the first element
1426  // of the vector, which has a field ID of 1.
1427  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
1428  op.getType());
1429  })
1430  .Case<SubtagOp>([&](auto op) {
1431  FEnumType enumType = op.getInput().getType();
1432  auto fieldID = enumType.getFieldID(op.getFieldIndex());
1433  unifyTypes(FieldRef(op.getResult(), 0),
1434  FieldRef(op.getInput(), fieldID), op.getType());
1435  })
1436 
1437  .Case<RefSubOp>([&](RefSubOp op) {
1438  uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1439  op.getInput().getType().getType())
1440  .Case<FVectorType>([](auto _) { return 1; })
1441  .Case<BundleType>([&](auto type) {
1442  return type.getFieldID(op.getIndex());
1443  });
1444  unifyTypes(FieldRef(op.getResult(), 0),
1445  FieldRef(op.getInput(), fieldID), op.getType());
1446  })
1447 
1448  // Arithmetic and Logical Binary Primitives
1449  .Case<AddPrimOp, SubPrimOp>([&](auto op) {
1450  auto lhs = getExpr(op.getLhs());
1451  auto rhs = getExpr(op.getRhs());
1452  auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1453  setExpr(op.getResult(), e);
1454  })
1455  .Case<MulPrimOp>([&](auto op) {
1456  auto lhs = getExpr(op.getLhs());
1457  auto rhs = getExpr(op.getRhs());
1458  auto e = solver.add(lhs, rhs);
1459  setExpr(op.getResult(), e);
1460  })
1461  .Case<DivPrimOp>([&](auto op) {
1462  auto lhs = getExpr(op.getLhs());
1463  Expr *e;
1464  if (op.getType().base().isSigned()) {
1465  e = solver.add(lhs, solver.known(1));
1466  } else {
1467  e = lhs;
1468  }
1469  setExpr(op.getResult(), e);
1470  })
1471  .Case<RemPrimOp>([&](auto op) {
1472  auto lhs = getExpr(op.getLhs());
1473  auto rhs = getExpr(op.getRhs());
1474  auto e = solver.min(lhs, rhs);
1475  setExpr(op.getResult(), e);
1476  })
1477  .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](auto op) {
1478  auto lhs = getExpr(op.getLhs());
1479  auto rhs = getExpr(op.getRhs());
1480  auto e = solver.max(lhs, rhs);
1481  setExpr(op.getResult(), e);
1482  })
1483 
1484  // Misc Binary Primitives
1485  .Case<CatPrimOp>([&](auto op) {
1486  auto lhs = getExpr(op.getLhs());
1487  auto rhs = getExpr(op.getRhs());
1488  auto e = solver.add(lhs, rhs);
1489  setExpr(op.getResult(), e);
1490  })
1491  .Case<DShlPrimOp>([&](auto op) {
1492  auto lhs = getExpr(op.getLhs());
1493  auto rhs = getExpr(op.getRhs());
1494  auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1495  setExpr(op.getResult(), e);
1496  })
1497  .Case<DShlwPrimOp, DShrPrimOp>([&](auto op) {
1498  auto e = getExpr(op.getLhs());
1499  setExpr(op.getResult(), e);
1500  })
1501 
1502  // Unary operators
1503  .Case<NegPrimOp>([&](auto op) {
1504  auto input = getExpr(op.getInput());
1505  auto e = solver.add(input, solver.known(1));
1506  setExpr(op.getResult(), e);
1507  })
1508  .Case<CvtPrimOp>([&](auto op) {
1509  auto input = getExpr(op.getInput());
1510  auto e = op.getInput().getType().base().isSigned()
1511  ? input
1512  : solver.add(input, solver.known(1));
1513  setExpr(op.getResult(), e);
1514  })
1515 
1516  // Miscellaneous
1517  .Case<BitsPrimOp>([&](auto op) {
1518  setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1519  })
1520  .Case<HeadPrimOp>([&](auto op) {
1521  setExpr(op.getResult(), solver.known(op.getAmount()));
1522  })
1523  .Case<TailPrimOp>([&](auto op) {
1524  auto input = getExpr(op.getInput());
1525  auto e = solver.add(input, solver.known(-op.getAmount()));
1526  setExpr(op.getResult(), e);
1527  })
1528  .Case<PadPrimOp>([&](auto op) {
1529  auto input = getExpr(op.getInput());
1530  auto e = solver.max(input, solver.known(op.getAmount()));
1531  setExpr(op.getResult(), e);
1532  })
1533  .Case<ShlPrimOp>([&](auto op) {
1534  auto input = getExpr(op.getInput());
1535  auto e = solver.add(input, solver.known(op.getAmount()));
1536  setExpr(op.getResult(), e);
1537  })
1538  .Case<ShrPrimOp>([&](auto op) {
1539  auto input = getExpr(op.getInput());
1540  // UInt saturates at 0 bits, SInt at 1 bit
1541  auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1542  auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1543  solver.known(minWidth));
1544  setExpr(op.getResult(), e);
1545  })
1546 
1547  // Handle operations whose output width matches the input width.
1548  .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1549  [&](auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1550  .Case<mlir::UnrealizedConversionCastOp>(
1551  [&](auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1552 
1553  // Handle operations with a single result type that always has a
1554  // well-known width.
1555  .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1556  AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1557  XorRPrimOp>([&](auto op) {
1558  auto width = op.getType().getBitWidthOrSentinel();
1559  assert(width > 0 && "width should have been checked by verifier");
1560  setExpr(op.getResult(), solver.known(width));
1561  })
1562  .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
1563  auto *sel = getExpr(op.getSel());
1564  constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
1565  maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1566  })
1567  .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1568  auto *sel = getExpr(op.getSel());
1569  constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
1570  maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1571  maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1572  maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1573  })
1574 
1575  .Case<ConnectOp, MatchingConnectOp>(
1576  [&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1577  .Case<RefDefineOp>([&](auto op) {
1578  // Dest >= Src, but also check Src <= Dest for correctness
1579  // (but don't solve to make this true, don't back-propagate)
1580  constrainTypes(op.getDest(), op.getSrc(), true);
1581  })
1582  .Case<AttachOp>([&](auto op) {
1583  // Attach connects multiple analog signals together. All signals must
1584  // have the same bit width. Signals without bit width inherit from the
1585  // other signals.
1586  if (op.getAttached().empty())
1587  return;
1588  auto prev = op.getAttached()[0];
1589  for (auto operand : op.getAttached().drop_front()) {
1590  auto e1 = getExpr(prev);
1591  auto e2 = getExpr(operand);
1592  constrainTypes(e1, e2, /*imposeUpperBounds=*/true);
1593  constrainTypes(e2, e1, /*imposeUpperBounds=*/true);
1594  prev = operand;
1595  }
1596  })
1597 
1598  // Handle the no-ops that don't interact with width inference.
1599  .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1600  UnclockedAssumeIntrinsicOp, CoverOp>([&](auto) {})
1601 
1602  // Handle instances of other modules.
1603  .Case<InstanceOp>([&](auto op) {
1604  auto refdModule = op.getReferencedOperation(symtbl);
1605  auto module = dyn_cast<FModuleOp>(&*refdModule);
1606  if (!module) {
1607  auto diag = mlir::emitError(op.getLoc());
1608  diag << "extern module `" << op.getModuleName()
1609  << "` has ports of uninferred width";
1610 
1611  auto fml = cast<FModuleLike>(&*refdModule);
1612  auto ports = fml.getPorts();
1613  for (auto &port : ports) {
1614  auto baseType = getBaseType(port.type);
1615  if (baseType && baseType.hasUninferredWidth()) {
1616  diag.attachNote(op.getLoc()) << "Port: " << port.name;
1617  if (!baseType.isGround())
1618  diagnoseUninferredType(diag, baseType, port.name.getValue());
1619  }
1620  }
1621 
1622  diag.attachNote(op.getLoc())
1623  << "Only non-extern FIRRTL modules may contain unspecified "
1624  "widths to be inferred automatically.";
1625  diag.attachNote(refdModule->getLoc())
1626  << "Module `" << op.getModuleName() << "` defined here:";
1627  mappingFailed = true;
1628  return;
1629  }
1630  // Simply look up the free variables created for the instantiated
1631  // module's ports, and use them for instance port wires. This way,
1632  // constraints imposed onto the ports of the instance will transparently
1633  // apply to the ports of the instantiated module.
1634  for (auto [result, arg] :
1635  llvm::zip(op->getResults(), module.getArguments()))
1636  unifyTypes({result, 0}, {arg, 0},
1637  type_cast<FIRRTLType>(result.getType()));
1638  })
1639 
1640  // Handle memories.
1641  .Case<MemOp>([&](MemOp op) {
1642  // Create constraint variables for all ports.
1643  unsigned nonDebugPort = 0;
1644  for (const auto &result : llvm::enumerate(op.getResults())) {
1645  declareVars(result.value());
1646  if (!type_isa<RefType>(result.value().getType()))
1647  nonDebugPort = result.index();
1648  }
1649 
1650  // A helper function that returns the indeces of the "data", "rdata",
1651  // and "wdata" fields in the bundle corresponding to a memory port.
1652  auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1653  static const unsigned indices[] = {3, 5};
1654  static const unsigned debug[] = {0};
1655  switch (kind) {
1656  case MemOp::PortKind::Read:
1657  case MemOp::PortKind::Write:
1658  return ArrayRef<unsigned>(indices, 1); // {3}
1659  case MemOp::PortKind::ReadWrite:
1660  return ArrayRef<unsigned>(indices); // {3, 5}
1661  case MemOp::PortKind::Debug:
1662  return ArrayRef<unsigned>(debug);
1663  }
1664  llvm_unreachable("Imposible PortKind");
1665  };
1666 
1667  // This creates independent variables for every data port. Yet, what we
1668  // actually want is for all data ports to share the same variable. To do
1669  // this, we find the first data port declared, and use that port's vars
1670  // for all the other ports.
1671  unsigned firstFieldIndex =
1672  dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1673  FieldRef firstData(
1674  op.getResult(nonDebugPort),
1675  type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1676  .getFieldID(firstFieldIndex));
1677  LLVM_DEBUG(llvm::dbgs() << "Adjusting memory port variables:\n");
1678 
1679  // Reuse data port variables.
1680  auto dataType = op.getDataType();
1681  for (unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1682  auto result = op.getResult(i);
1683  if (type_isa<RefType>(result.getType())) {
1684  // Debug ports are firrtl.ref<vector<data-type, depth>>
1685  // Use FieldRef of 1, to indicate the first vector element must be
1686  // of the dataType.
1687  unifyTypes(firstData, FieldRef(result, 1), dataType);
1688  continue;
1689  }
1690 
1691  auto portType =
1692  type_cast<BundleType>(op.getPortType(i).getPassiveType());
1693  for (auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1694  unifyTypes(FieldRef(result, portType.getFieldID(fieldIndex)),
1695  firstData, dataType);
1696  }
1697  })
1698 
1699  .Case<RefSendOp>([&](auto op) {
1700  declareVars(op.getResult());
1701  constrainTypes(op.getResult(), op.getBase(), true);
1702  })
1703  .Case<RefResolveOp>([&](auto op) {
1704  declareVars(op.getResult());
1705  constrainTypes(op.getResult(), op.getRef(), true);
1706  })
1707  .Case<RefCastOp>([&](auto op) {
1708  declareVars(op.getResult());
1709  constrainTypes(op.getResult(), op.getInput(), true);
1710  })
1711  .Case<RWProbeOp>([&](auto op) {
1712  auto ist = irn.lookup(op.getTarget());
1713  if (!ist) {
1714  op->emitError("target of rwprobe could not be resolved");
1715  mappingFailed = true;
1716  return;
1717  }
1718  auto ref = getFieldRefForTarget(ist);
1719  if (!ref) {
1720  op->emitError("target of rwprobe resolved to unsupported target");
1721  mappingFailed = true;
1722  return;
1723  }
1724  auto newFID = convertFieldIDToOurVersion(
1725  ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1726  unifyTypes(FieldRef(op.getResult(), 0),
1727  FieldRef(ref.getValue(), newFID), op.getType());
1728  })
1729  .Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
1730  for (Value result : op.getResults()) {
1731  auto ty = result.getType();
1732  if (type_isa<FIRRTLType>(ty))
1733  declareVars(result);
1734  }
1735  })
1736  .Default([&](auto op) {
1737  op->emitOpError("not supported in width inference");
1738  mappingFailed = true;
1739  });
1740 
1741  // Forceable declarations should have the ref constrained to data result.
1742  if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1743  unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
1744  fop.getDataType());
1745 
1746  return failure(mappingFailed);
1747 }
1748 
1749 /// Declare free variables for the type of a value, and associate the resulting
1750 /// set of variables with that value.
1751 void InferenceMapping::declareVars(Value value, bool isDerived) {
1752  // Declare a variable for every unknown width in the type. If this is a Bundle
1753  // type or a FVector type, we will have to potentially create many variables.
1754  unsigned fieldID = 0;
1755  std::function<void(FIRRTLBaseType)> declare = [&](FIRRTLBaseType type) {
1756  auto width = type.getBitWidthOrSentinel();
1757  if (width >= 0) {
1758  fieldID++;
1759  } else if (width == -1) {
1760  // Unknown width integers create a variable.
1761  FieldRef field(value, fieldID);
1762  solver.setCurrentContextInfo(field);
1763  if (isDerived)
1764  setExpr(field, solver.derived());
1765  else
1766  setExpr(field, solver.var());
1767  fieldID++;
1768  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1769  // Bundle types recursively declare all bundle elements.
1770  fieldID++;
1771  for (auto &element : bundleType)
1772  declare(element.type);
1773  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1774  fieldID++;
1775  auto save = fieldID;
1776  declare(vecType.getElementType());
1777  // Skip past the rest of the elements
1778  fieldID = save + vecType.getMaxFieldID();
1779  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1780  fieldID++;
1781  for (auto &element : enumType.getElements())
1782  declare(element.type);
1783  } else {
1784  llvm_unreachable("Unknown type inside a bundle!");
1785  }
1786  };
1787  if (auto type = getBaseType(value.getType()))
1788  declare(type);
1789 }
1790 
1791 /// Assign the constraint expressions of the fields in the `result` argument as
1792 /// the max of expressions in the `rhs` and `lhs` arguments. Both fields must be
1793 /// the same type.
1794 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1795  // Recurse to every leaf element and set larger >= smaller.
1796  auto fieldID = 0;
1797  std::function<void(FIRRTLBaseType)> maximize = [&](FIRRTLBaseType type) {
1798  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1799  fieldID++;
1800  for (auto &element : bundleType.getElements())
1801  maximize(element.type);
1802  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1803  fieldID++;
1804  auto save = fieldID;
1805  // Skip 0 length vectors.
1806  if (vecType.getNumElements() > 0)
1807  maximize(vecType.getElementType());
1808  fieldID = save + vecType.getMaxFieldID();
1809  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1810  fieldID++;
1811  for (auto &element : enumType.getElements())
1812  maximize(element.type);
1813  } else if (type.isGround()) {
1814  auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
1815  getExpr(FieldRef(lhs, fieldID)));
1816  setExpr(FieldRef(result, fieldID), e);
1817  fieldID++;
1818  } else {
1819  llvm_unreachable("Unknown type inside a bundle!");
1820  }
1821  };
1822  if (auto type = getBaseType(result.getType()))
1823  maximize(type);
1824 }
1825 
1826 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1827 /// than or equal to the sizes in the `smaller` type. Types have to be
1828 /// compatible in the sense that they may only differ in the presence or absence
1829 /// of bit widths.
1830 ///
1831 /// This function is used to apply regular connects.
1832 /// Set `equal` for constraining larger <= smaller for correctness but not
1833 /// solving.
1834 void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
1835  // Recurse to every leaf element and set larger >= smaller. Ignore foreign
1836  // types as these do not participate in width inference.
1837 
1838  auto fieldID = 0;
1839  std::function<void(FIRRTLBaseType, Value, Value)> constrain =
1840  [&](FIRRTLBaseType type, Value larger, Value smaller) {
1841  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1842  fieldID++;
1843  for (auto &element : bundleType.getElements()) {
1844  if (element.isFlip)
1845  constrain(element.type, smaller, larger);
1846  else
1847  constrain(element.type, larger, smaller);
1848  }
1849  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1850  fieldID++;
1851  auto save = fieldID;
1852  // Skip 0 length vectors.
1853  if (vecType.getNumElements() > 0) {
1854  constrain(vecType.getElementType(), larger, smaller);
1855  }
1856  fieldID = save + vecType.getMaxFieldID();
1857  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1858  fieldID++;
1859  for (auto &element : enumType.getElements())
1860  constrain(element.type, larger, smaller);
1861  } else if (type.isGround()) {
1862  // Leaf element, look up their expressions, and create the constraint.
1863  constrainTypes(getExpr(FieldRef(larger, fieldID)),
1864  getExpr(FieldRef(smaller, fieldID)), false, equal);
1865  fieldID++;
1866  } else {
1867  llvm_unreachable("Unknown type inside a bundle!");
1868  }
1869  };
1870 
1871  if (auto type = getBaseType(larger.getType()))
1872  constrain(type, larger, smaller);
1873 }
1874 
1875 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1876 /// than or equal to the sizes in the `smaller` type.
1877 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1878  bool imposeUpperBounds, bool equal) {
1879  assert(larger && "Larger expression should be specified");
1880  assert(smaller && "Smaller expression should be specified");
1881 
1882  // If one of the sides is `DerivedExpr`, simply assign the other side as the
1883  // derived width. This allows `InvalidValueOp`s to properly infer their width
1884  // from the connects they are used in, but also be inferred to something
1885  // useful on their own.
1886  if (auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1887  largerDerived->assigned = smaller;
1888  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *largerDerived << " from "
1889  << *smaller << "\n");
1890  return;
1891  }
1892  if (auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1893  smallerDerived->assigned = larger;
1894  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *smallerDerived << " from "
1895  << *larger << "\n");
1896  return;
1897  }
1898 
1899  // If the larger expr is a free variable, create a `expr >= x` constraint for
1900  // it that we can try to satisfy with the smallest width.
1901  if (auto *largerVar = dyn_cast<VarExpr>(larger)) {
1902  [[maybe_unused]] auto *c = solver.addGeqConstraint(largerVar, smaller);
1903  LLVM_DEBUG(llvm::dbgs()
1904  << "Constrained " << *largerVar << " >= " << *c << "\n");
1905  // If we're constraining larger == smaller, add the LEQ contraint as well.
1906  // Solve for GEQ but check that LEQ is true.
1907  // Used for matchingconnect, some reference operations, and anywhere the
1908  // widths should be inferred strictly in one direction but are required to
1909  // also be equal for correctness.
1910  if (equal) {
1911  [[maybe_unused]] auto *leq = solver.addLeqConstraint(largerVar, smaller);
1912  LLVM_DEBUG(llvm::dbgs()
1913  << "Constrained " << *largerVar << " <= " << *leq << "\n");
1914  }
1915  return;
1916  }
1917 
1918  // If the smaller expr is a free variable but the larger one is not, create a
1919  // `expr <= k` upper bound that we can verify once all lower bounds have been
1920  // satisfied. Since we are always picking the smallest width to satisfy all
1921  // `>=` constraints, any `<=` constraints have no effect on the solution
1922  // besides indicating that a width is unsatisfiable.
1923  if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1924  if (imposeUpperBounds || equal) {
1925  [[maybe_unused]] auto *c = solver.addLeqConstraint(smallerVar, larger);
1926  LLVM_DEBUG(llvm::dbgs()
1927  << "Constrained " << *smallerVar << " <= " << *c << "\n");
1928  }
1929  }
1930 }
1931 
1932 /// Assign the constraint expressions of the fields in the `src` argument as the
1933 /// expressions for the `dst` argument. Both fields must be of the given `type`.
1934 void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
1935  // Fast path for `unifyTypes(x, x, _)`.
1936  if (lhs == rhs)
1937  return;
1938 
1939  // Co-iterate the two field refs, recurring into every leaf element and set
1940  // them equal.
1941  auto fieldID = 0;
1942  std::function<void(FIRRTLBaseType)> unify = [&](FIRRTLBaseType type) {
1943  if (type.isGround()) {
1944  // Leaf element, unify the fields!
1945  FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
1946  FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
1947  LLVM_DEBUG(llvm::dbgs()
1948  << "Unify " << getFieldName(lhsFieldRef).first << " = "
1949  << getFieldName(rhsFieldRef).first << "\n");
1950  // Abandon variables becoming unconstrainable by the unification.
1951  if (auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1952  solver.addGeqConstraint(var, solver.known(0));
1953  setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1954  fieldID++;
1955  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1956  fieldID++;
1957  for (auto &element : bundleType) {
1958  unify(element.type);
1959  }
1960  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1961  fieldID++;
1962  auto save = fieldID;
1963  // Skip 0 length vectors.
1964  if (vecType.getNumElements() > 0) {
1965  unify(vecType.getElementType());
1966  }
1967  fieldID = save + vecType.getMaxFieldID();
1968  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1969  fieldID++;
1970  for (auto &element : enumType.getElements())
1971  unify(element.type);
1972  } else {
1973  llvm_unreachable("Unknown type inside a bundle!");
1974  }
1975  };
1976  if (auto ftype = getBaseType(type))
1977  unify(ftype);
1978 }
1979 
1980 /// Get the constraint expression for a value.
1981 Expr *InferenceMapping::getExpr(Value value) const {
1982  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
1983  // A field ID of 0 indicates the entire value.
1984  return getExpr(FieldRef(value, 0));
1985 }
1986 
1987 /// Get the constraint expression for a value.
1988 Expr *InferenceMapping::getExpr(FieldRef fieldRef) const {
1989  auto *expr = getExprOrNull(fieldRef);
1990  assert(expr && "constraint expr should have been constructed for value");
1991  return expr;
1992 }
1993 
1994 Expr *InferenceMapping::getExprOrNull(FieldRef fieldRef) const {
1995  auto it = opExprs.find(fieldRef);
1996  if (it != opExprs.end())
1997  return it->second;
1998  // If we don't have an expression for this fieldRef, it should have a
1999  // constant width.
2000  auto baseType = getBaseType(fieldRef.getValue().getType());
2001  auto type =
2002  hw::FieldIdImpl::getFinalTypeByFieldID(baseType, fieldRef.getFieldID());
2003  auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
2004  if (width < 0)
2005  return nullptr;
2006  return solver.known(width);
2007 }
2008 
2009 /// Associate a constraint expression with a value.
2010 void InferenceMapping::setExpr(Value value, Expr *expr) {
2011  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2012  // A field ID of 0 indicates the entire value.
2013  setExpr(FieldRef(value, 0), expr);
2014 }
2015 
2016 /// Associate a constraint expression with a value.
2017 void InferenceMapping::setExpr(FieldRef fieldRef, Expr *expr) {
2018  LLVM_DEBUG({
2019  llvm::dbgs() << "Expr " << *expr << " for " << fieldRef.getValue();
2020  if (fieldRef.getFieldID())
2021  llvm::dbgs() << " '" << getFieldName(fieldRef).first << "'";
2022  auto fieldName = getFieldName(fieldRef);
2023  if (fieldName.second)
2024  llvm::dbgs() << " (\"" << fieldName.first << "\")";
2025  llvm::dbgs() << "\n";
2026  });
2027  opExprs[fieldRef] = expr;
2028 }
2029 
2030 //===----------------------------------------------------------------------===//
2031 // Inference Result Application
2032 //===----------------------------------------------------------------------===//
2033 
2034 namespace {
2035 /// A helper class which maps the types and operations in a design to a set
2036 /// of variables and constraints to be solved later.
2037 class InferenceTypeUpdate {
2038 public:
2039  InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2040 
2041  LogicalResult update(CircuitOp op);
2042  FailureOr<bool> updateOperation(Operation *op);
2043  FailureOr<bool> updateValue(Value value);
2045 
2046 private:
2047  const InferenceMapping &mapping;
2048 };
2049 
2050 } // namespace
2051 
2052 /// Update the types throughout a circuit.
2053 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2054  LLVM_DEBUG({
2055  llvm::dbgs() << "\n";
2056  debugHeader("Update types") << "\n\n";
2057  });
2058  return mlir::failableParallelForEach(
2059  op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2060  // Skip this module if it had no widths to be
2061  // inferred at all.
2062  if (mapping.isModuleSkipped(op))
2063  return success();
2064  auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2065  if (failed(updateOperation(op)))
2066  return WalkResult::interrupt();
2067  return WalkResult::advance();
2068  }).wasInterrupted();
2069  return failure(isFailed);
2070  });
2071 }
2072 
2073 /// Update the result types of an operation.
2074 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2075  bool anyChanged = false;
2076 
2077  for (Value v : op->getResults()) {
2078  auto result = updateValue(v);
2079  if (failed(result))
2080  return result;
2081  anyChanged |= *result;
2082  }
2083 
2084  // If this is a connect operation, width inference might have inferred a RHS
2085  // that is wider than the LHS, in which case an additional BitsPrimOp is
2086  // necessary to truncate the value.
2087  if (auto con = dyn_cast<ConnectOp>(op)) {
2088  auto lhs = con.getDest();
2089  auto rhs = con.getSrc();
2090  auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2091  auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2092 
2093  // Nothing to do if not base types.
2094  if (!lhsType || !rhsType)
2095  return anyChanged;
2096 
2097  auto lhsWidth = lhsType.getBitWidthOrSentinel();
2098  auto rhsWidth = rhsType.getBitWidthOrSentinel();
2099  if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2100  OpBuilder builder(op);
2101  auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2102  rhsWidth - lhsWidth);
2103  if (type_isa<SIntType>(rhsType))
2104  trunc =
2105  builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2106 
2107  LLVM_DEBUG(llvm::dbgs()
2108  << "Truncating RHS to " << lhsType << " in " << con << "\n");
2109  con->replaceUsesOfWith(con.getSrc(), trunc);
2110  }
2111  return anyChanged;
2112  }
2113 
2114  // If this is a module, update its ports.
2115  if (auto module = dyn_cast<FModuleOp>(op)) {
2116  // Update the block argument types.
2117  bool argsChanged = false;
2118  SmallVector<Attribute> argTypes;
2119  argTypes.reserve(module.getNumPorts());
2120  for (auto arg : module.getArguments()) {
2121  auto result = updateValue(arg);
2122  if (failed(result))
2123  return result;
2124  argsChanged |= *result;
2125  argTypes.push_back(TypeAttr::get(arg.getType()));
2126  }
2127 
2128  // Update the module function type if needed.
2129  if (argsChanged) {
2130  module.setPortTypesAttr(ArrayAttr::get(module.getContext(), argTypes));
2131  anyChanged = true;
2132  }
2133  }
2134  return anyChanged;
2135 }
2136 
2137 /// Resize a `uint`, `sint`, or `analog` type to a specific width.
2138 static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth) {
2139  auto *context = type.getContext();
2141  .Case<UIntType>([&](auto type) {
2142  return UIntType::get(context, newWidth, type.isConst());
2143  })
2144  .Case<SIntType>([&](auto type) {
2145  return SIntType::get(context, newWidth, type.isConst());
2146  })
2147  .Case<AnalogType>([&](auto type) {
2148  return AnalogType::get(context, newWidth, type.isConst());
2149  })
2150  .Default([&](auto type) { return type; });
2151 }
2152 
2153 /// Update the type of a value.
2154 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2155  // Check if the value has a type which we can update.
2156  auto type = type_dyn_cast<FIRRTLType>(value.getType());
2157  if (!type)
2158  return false;
2159 
2160  // Fast path for types that have fully inferred widths.
2161  if (!hasUninferredWidth(type))
2162  return false;
2163 
2164  // If this is an operation that does not generate any free variables that
2165  // are determined during width inference, simply update the value type based
2166  // on the operation arguments.
2167  if (auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2168  SmallVector<Type, 2> types;
2169  auto res =
2170  op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2171  op->getAttrDictionary(), op->getPropertiesStorage(),
2172  op->getRegions(), types);
2173  if (failed(res))
2174  return failure();
2175 
2176  assert(types.size() == op->getNumResults());
2177  for (auto [result, type] : llvm::zip(op->getResults(), types)) {
2178  LLVM_DEBUG(llvm::dbgs()
2179  << "Inferring " << result << " as " << type << "\n");
2180  result.setType(type);
2181  }
2182  return true;
2183  }
2184 
2185  // Recreate the type, substituting the solved widths.
2186  auto *context = type.getContext();
2187  unsigned fieldID = 0;
2188  std::function<FIRRTLBaseType(FIRRTLBaseType)> updateBase =
2189  [&](FIRRTLBaseType type) -> FIRRTLBaseType {
2190  auto width = type.getBitWidthOrSentinel();
2191  if (width >= 0) {
2192  // Known width integers return themselves.
2193  fieldID++;
2194  return type;
2195  }
2196  if (width == -1) {
2197  // Unknown width integers return the solved type.
2198  auto newType = updateType(FieldRef(value, fieldID), type);
2199  fieldID++;
2200  return newType;
2201  }
2202  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2203  // Bundle types recursively update all bundle elements.
2204  fieldID++;
2205  llvm::SmallVector<BundleType::BundleElement, 3> elements;
2206  for (auto &element : bundleType) {
2207  auto updatedBase = updateBase(element.type);
2208  if (!updatedBase)
2209  return {};
2210  elements.emplace_back(element.name, element.isFlip, updatedBase);
2211  }
2212  return BundleType::get(context, elements, bundleType.isConst());
2213  }
2214  if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2215  fieldID++;
2216  auto save = fieldID;
2217  // TODO: this should recurse into the element type of 0 length vectors and
2218  // set any unknown width to 0.
2219  if (vecType.getNumElements() > 0) {
2220  auto updatedBase = updateBase(vecType.getElementType());
2221  if (!updatedBase)
2222  return {};
2223  auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2224  vecType.isConst());
2225  fieldID = save + vecType.getMaxFieldID();
2226  return newType;
2227  }
2228  // If this is a 0 length vector return the original type.
2229  return type;
2230  }
2231  if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2232  fieldID++;
2233  llvm::SmallVector<FEnumType::EnumElement> elements;
2234  for (auto &element : enumType.getElements()) {
2235  auto updatedBase = updateBase(element.type);
2236  if (!updatedBase)
2237  return {};
2238  elements.emplace_back(element.name, updatedBase);
2239  }
2240  return FEnumType::get(context, elements, enumType.isConst());
2241  }
2242  llvm_unreachable("Unknown type inside a bundle!");
2243  };
2244 
2245  // Update the type.
2246  auto newType = mapBaseTypeNullable(type, updateBase);
2247  if (!newType)
2248  return failure();
2249  LLVM_DEBUG(llvm::dbgs() << "Update " << value << " to " << newType << "\n");
2250  value.setType(newType);
2251 
2252  // If this is a ConstantOp, adjust the width of the underlying APInt.
2253  // Unsized constants have APInts which are *at least* wide enough to hold
2254  // the value, but may be larger. This can trip up the verifier.
2255  if (auto op = value.getDefiningOp<ConstantOp>()) {
2256  auto k = op.getValue();
2257  auto bitwidth = op.getType().getBitWidthOrSentinel();
2258  if (k.getBitWidth() > unsigned(bitwidth))
2259  k = k.trunc(bitwidth);
2260  op->setAttr("value", IntegerAttr::get(op.getContext(), k));
2261  }
2262 
2263  return newType != type;
2264 }
2265 
2266 /// Update a type.
2268  FIRRTLBaseType type) {
2269  assert(type.isGround() && "Can only pass in ground types.");
2270  auto value = fieldRef.getValue();
2271  // Get the inferred width.
2272  Expr *expr = mapping.getExprOrNull(fieldRef);
2273  if (!expr || !expr->getSolution()) {
2274  // It should not be possible to arrive at an uninferred width at this point.
2275  // In case the constraints are not resolvable, checks before the calls to
2276  // `updateType` must have already caught the issues and aborted the pass
2277  // early. Might turn this into an assert later.
2278  mlir::emitError(value.getLoc(), "width should have been inferred");
2279  return {};
2280  }
2281  int32_t solution = *expr->getSolution();
2282  assert(solution >= 0); // The solver infers variables to be 0 or greater.
2283  return resizeType(type, solution);
2284 }
2285 
2286 //===----------------------------------------------------------------------===//
2287 // Pass Infrastructure
2288 //===----------------------------------------------------------------------===//
2289 
2290 namespace {
2291 class InferWidthsPass
2292  : public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2293  void runOnOperation() override;
2294 };
2295 } // namespace
2296 
2297 void InferWidthsPass::runOnOperation() {
2298  // Collect variables and constraints
2299  ConstraintSolver solver;
2300  InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2301  getAnalysis<hw::InnerSymbolTableCollection>());
2302  if (failed(mapping.map(getOperation())))
2303  return signalPassFailure();
2304 
2305  // fast path if no inferrable widths are around
2306  if (mapping.areAllModulesSkipped())
2307  return markAllAnalysesPreserved();
2308 
2309  // Solve the constraints.
2310  if (failed(solver.solve()))
2311  return signalPassFailure();
2312 
2313  // Update the types with the inferred widths.
2314  if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2315  return signalPassFailure();
2316 }
2317 
2318 std::unique_ptr<mlir::Pass> circt::firrtl::createInferWidthsPass() {
2319  return std::make_unique<InferWidthsPass>();
2320 }
assert(baseType &&"element must be base type")
static llvm::hash_code hash_value(const ElaboratorValue &val)
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
Definition: InferWidths.cpp:65
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
#define EXPR_KINDS
#define EXPR_CLASSES
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
Definition: InferWidths.cpp:47
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, std::vector< Frame > &worklist)
Compute the value of a constraint expr.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
Definition: FIRRTLUtils.h:252
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:222
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
Definition: FIRRTLTypes.h:368
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition: Debug.cpp:18
bool operator==(uint64_t a, const FVInt &b)
Definition: FVInt.h:650
Definition: debug.py:1
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition: Utils.h:32
llvm::hash_code hash_value(const T &e)