CIRCT  20.0.0git
InferWidths.cpp
Go to the documentation of this file.
1 //===- InferWidths.cpp - Infer width of types -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the InferWidths pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/Pass/Pass.h"
16 
21 #include "circt/Support/Debug.h"
22 #include "circt/Support/FieldRef.h"
23 #include "mlir/IR/ImplicitLocOpBuilder.h"
24 #include "mlir/IR/Threading.h"
25 #include "llvm/ADT/APSInt.h"
26 #include "llvm/ADT/DenseSet.h"
27 #include "llvm/ADT/Hashing.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 
32 #define DEBUG_TYPE "infer-widths"
33 
34 namespace circt {
35 namespace firrtl {
36 #define GEN_PASS_DEF_INFERWIDTHS
37 #include "circt/Dialect/FIRRTL/Passes.h.inc"
38 } // namespace firrtl
39 } // namespace circt
40 
41 using mlir::InferTypeOpInterface;
42 using mlir::WalkOrder;
43 
44 using namespace circt;
45 using namespace firrtl;
46 
47 //===----------------------------------------------------------------------===//
48 // Helpers
49 //===----------------------------------------------------------------------===//
50 
51 static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t,
52  Twine str) {
53  auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
54  if (!basetype)
55  return;
56  if (!basetype.hasUninferredWidth())
57  return;
58 
59  if (basetype.isGround())
60  diag.attachNote() << "Field: \"" << str << "\"";
61  else if (auto vecType = type_dyn_cast<FVectorType>(basetype))
62  diagnoseUninferredType(diag, vecType.getElementType(), str + "[]");
63  else if (auto bundleType = type_dyn_cast<BundleType>(basetype))
64  for (auto &elem : bundleType.getElements())
65  diagnoseUninferredType(diag, elem.type, str + "." + elem.name.getValue());
66 }
67 
68 /// Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
69 static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type) {
70  uint64_t convertedFieldID = 0;
71 
72  auto curFID = fieldID;
73  Type curFType = type;
74  while (curFID != 0) {
75  auto [child, subID] =
76  hw::FieldIdImpl::getSubTypeByFieldID(curFType, curFID);
77  if (isa<FVectorType>(curFType))
78  convertedFieldID++; // Vector fieldID is 1.
79  else
80  convertedFieldID += curFID - subID; // Add consumed portion.
81  curFID = subID;
82  curFType = child;
83  }
84 
85  return convertedFieldID;
86 }
87 
88 //===----------------------------------------------------------------------===//
89 // Constraint Expressions
90 //===----------------------------------------------------------------------===//
91 
92 namespace {
93 struct Expr;
94 } // namespace
95 
96 /// Allow rvalue refs to `Expr` and subclasses to be printed to streams.
97 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
98  int>::type = 0>
99 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T &e) {
100  e.print(os);
101  return os;
102 }
103 
104 // Allow expression subclasses to be hashed.
105 namespace mlir {
106 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
107  int>::type = 0>
108 inline llvm::hash_code hash_value(const T &e) {
109  return e.hash_value();
110 }
111 } // namespace mlir
112 
113 namespace {
114 #define EXPR_NAMES(x) \
115  Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
116 #define EXPR_KINDS EXPR_NAMES()
117 #define EXPR_CLASSES EXPR_NAMES(Expr)
118 
119 /// An expression on the right-hand side of a constraint.
120 struct Expr {
121  enum class Kind : uint8_t { EXPR_KINDS };
122 
123  /// Print a human-readable representation of this expr.
124  void print(llvm::raw_ostream &os) const;
125 
126  std::optional<int32_t> getSolution() const {
127  if (hasSolution)
128  return solution;
129  return std::nullopt;
130  }
131 
132  void setSolution(int32_t solution) {
133  hasSolution = true;
134  this->solution = solution;
135  }
136 
137  Kind getKind() const { return kind; }
138 
139 protected:
140  Expr(Kind kind) : kind(kind) {}
141  llvm::hash_code hash_value() const { return llvm::hash_value(kind); }
142 
143 private:
144  int32_t solution;
145  Kind kind;
146  bool hasSolution = false;
147 };
148 
149 /// Helper class to CRTP-derive common functions.
150 template <class DerivedT, Expr::Kind DerivedKind>
151 struct ExprBase : public Expr {
152  ExprBase() : Expr(DerivedKind) {}
153  static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
154  bool operator==(const Expr &other) const {
155  if (auto otherSame = dyn_cast<DerivedT>(other))
156  return *static_cast<DerivedT *>(this) == otherSame;
157  return false;
158  }
159 };
160 
161 /// A free variable.
162 struct VarExpr : public ExprBase<VarExpr, Expr::Kind::Var> {
163  void print(llvm::raw_ostream &os) const {
164  // Hash the `this` pointer into something somewhat human readable. Since
165  // this is just for debug dumping, we wrap around at 65536 variables.
166  os << "var" << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFFF);
167  }
168 
169  /// The constraint expression this variable is supposed to be greater than or
170  /// equal to. This is not part of the variable's hash and equality property.
171  Expr *constraint = nullptr;
172 
173  /// The upper bound this variable is supposed to be smaller than or equal to.
174  Expr *upperBound = nullptr;
175  std::optional<int32_t> upperBoundSolution;
176 };
177 
178 /// A derived width.
179 ///
180 /// These are generated for `InvalidValueOp`s which want to derived their width
181 /// from connect operations that they are on the right hand side of.
182 struct DerivedExpr : public ExprBase<DerivedExpr, Expr::Kind::Derived> {
183  void print(llvm::raw_ostream &os) const {
184  // Hash the `this` pointer into something somewhat human readable.
185  os << "derive"
186  << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFF);
187  }
188 
189  /// The expression this derived width is equivalent to.
190  Expr *assigned = nullptr;
191 };
192 
193 /// An identity expression.
194 ///
195 /// This expression evaluates to its inner expression. It is used in a very
196 /// specific case of constraints on variables, in order to be able to track
197 /// where the constraint was imposed. Constraints on variables are represented
198 /// as `var >= <expr>`. When the first constraint `a` is imposed, it is stored
199 /// as the constraint expression (`var >= a`). When the second constraint `b` is
200 /// imposed, a *new* max expression is allocated (`var >= max(a, b)`).
201 /// Expressions are annotated with a location when they are created, which in
202 /// this case are connect ops. Since imposing the first constraint does not
203 /// create any new expression, the location information of that connect would be
204 /// lost. With an identity expression, imposing the first constraint becomes
205 /// `var >= identity(a)`, which is a *new* expression and properly tracks the
206 /// location info.
207 struct IdExpr : public ExprBase<IdExpr, Expr::Kind::Id> {
208  IdExpr(Expr *arg) : arg(arg) { assert(arg); }
209  void print(llvm::raw_ostream &os) const { os << "*" << *arg; }
210  bool operator==(const IdExpr &other) const {
211  return getKind() == other.getKind() && arg == other.arg;
212  }
213  llvm::hash_code hash_value() const {
214  return llvm::hash_combine(Expr::hash_value(), arg);
215  }
216 
217  /// The inner expression.
218  Expr *const arg;
219 };
220 
221 /// A known constant value.
222 struct KnownExpr : public ExprBase<KnownExpr, Expr::Kind::Known> {
223  KnownExpr(int32_t value) : ExprBase() { setSolution(value); }
224  void print(llvm::raw_ostream &os) const { os << *getSolution(); }
225  bool operator==(const KnownExpr &other) const {
226  return *getSolution() == *other.getSolution();
227  }
228  llvm::hash_code hash_value() const {
229  return llvm::hash_combine(Expr::hash_value(), *getSolution());
230  }
231  int32_t getValue() const { return *getSolution(); }
232 };
233 
234 /// A unary expression. Contains the actual data. Concrete subclasses are merely
235 /// there for show and ease of use.
236 struct UnaryExpr : public Expr {
237  bool operator==(const UnaryExpr &other) const {
238  return getKind() == other.getKind() && arg == other.arg;
239  }
240  llvm::hash_code hash_value() const {
241  return llvm::hash_combine(Expr::hash_value(), arg);
242  }
243 
244  /// The child expression.
245  Expr *const arg;
246 
247 protected:
248  UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) { assert(arg); }
249 };
250 
251 /// Helper class to CRTP-derive common functions.
252 template <class DerivedT, Expr::Kind DerivedKind>
253 struct UnaryExprBase : public UnaryExpr {
254  template <typename... Args>
255  UnaryExprBase(Args &&...args)
256  : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
257  static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
258 };
259 
260 /// A power of two.
261 struct PowExpr : public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
262  using UnaryExprBase::UnaryExprBase;
263  void print(llvm::raw_ostream &os) const { os << "2^" << arg; }
264 };
265 
266 /// A binary expression. Contains the actual data. Concrete subclasses are
267 /// merely there for show and ease of use.
268 struct BinaryExpr : public Expr {
269  bool operator==(const BinaryExpr &other) const {
270  return getKind() == other.getKind() && lhs() == other.lhs() &&
271  rhs() == other.rhs();
272  }
273  llvm::hash_code hash_value() const {
274  return llvm::hash_combine(Expr::hash_value(), *args);
275  }
276  Expr *lhs() const { return args[0]; }
277  Expr *rhs() const { return args[1]; }
278 
279  /// The child expressions.
280  Expr *const args[2];
281 
282 protected:
283  BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
284  assert(lhs);
285  assert(rhs);
286  }
287 };
288 
289 /// Helper class to CRTP-derive common functions.
290 template <class DerivedT, Expr::Kind DerivedKind>
291 struct BinaryExprBase : public BinaryExpr {
292  template <typename... Args>
293  BinaryExprBase(Args &&...args)
294  : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
295  static bool classof(const Expr *e) { return e->getKind() == DerivedKind; }
296 };
297 
298 /// An addition.
299 struct AddExpr : public BinaryExprBase<AddExpr, Expr::Kind::Add> {
300  using BinaryExprBase::BinaryExprBase;
301  void print(llvm::raw_ostream &os) const {
302  os << "(" << *lhs() << " + " << *rhs() << ")";
303  }
304 };
305 
306 /// The maximum of two expressions.
307 struct MaxExpr : public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
308  using BinaryExprBase::BinaryExprBase;
309  void print(llvm::raw_ostream &os) const {
310  os << "max(" << *lhs() << ", " << *rhs() << ")";
311  }
312 };
313 
314 /// The minimum of two expressions.
315 struct MinExpr : public BinaryExprBase<MinExpr, Expr::Kind::Min> {
316  using BinaryExprBase::BinaryExprBase;
317  void print(llvm::raw_ostream &os) const {
318  os << "min(" << *lhs() << ", " << *rhs() << ")";
319  }
320 };
321 
322 void Expr::print(llvm::raw_ostream &os) const {
323  TypeSwitch<const Expr *>(this).Case<EXPR_CLASSES>(
324  [&](auto *e) { e->print(os); });
325 }
326 
327 } // namespace
328 
329 //===----------------------------------------------------------------------===//
330 // Fast bump allocator with optional interning
331 //===----------------------------------------------------------------------===//
332 
333 namespace {
334 
335 // Hash slots in the interned allocator as if they were the pointed-to value
336 // itself.
337 template <typename T>
338 struct InternedSlotInfo : DenseMapInfo<T *> {
339  static T *getEmptyKey() {
340  auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
341  return static_cast<T *>(pointer);
342  }
343  static T *getTombstoneKey() {
345  return static_cast<T *>(pointer);
346  }
347  static unsigned getHashValue(const T *val) { return mlir::hash_value(*val); }
348  static bool isEqual(const T *lhs, const T *rhs) {
349  auto empty = getEmptyKey();
350  auto tombstone = getTombstoneKey();
351  if (lhs == empty || rhs == empty || lhs == tombstone || rhs == tombstone)
352  return lhs == rhs;
353  return *lhs == *rhs;
354  }
355 };
356 
357 /// A simple bump allocator that ensures only ever one copy per object exists.
358 /// The allocated objects must not have a destructor.
359 template <typename T, typename std::enable_if_t<
360  std::is_trivially_destructible<T>::value, int> = 0>
361 class InternedAllocator {
362  llvm::DenseSet<T *, InternedSlotInfo<T>> interned;
363  llvm::BumpPtrAllocator &allocator;
364 
365 public:
366  InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
367 
368  /// Allocate a new object if it does not yet exist, or return a pointer to the
369  /// existing one. `R` is the type of the object to be allocated. `R` must be
370  /// derived from or be the type `T`.
371  template <typename R = T, typename... Args>
372  std::pair<R *, bool> alloc(Args &&...args) {
373  auto stackValue = R(std::forward<Args>(args)...);
374  auto *stackSlot = &stackValue;
375  auto it = interned.find(stackSlot);
376  if (it != interned.end())
377  return std::make_pair(static_cast<R *>(*it), false);
378  auto heapValue = new (allocator) R(std::move(stackValue));
379  interned.insert(heapValue);
380  return std::make_pair(heapValue, true);
381  }
382 };
383 
384 /// A simple bump allocator. The allocated objects must not have a destructor.
385 /// This allocator is mainly there for symmetry with the `InternedAllocator`.
386 template <typename T, typename std::enable_if_t<
387  std::is_trivially_destructible<T>::value, int> = 0>
388 class Allocator {
389  llvm::BumpPtrAllocator &allocator;
390 
391 public:
392  Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
393 
394  /// Allocate a new object. `R` is the type of the object to be allocated. `R`
395  /// must be derived from or be the type `T`.
396  template <typename R = T, typename... Args>
397  R *alloc(Args &&...args) {
398  return new (allocator) R(std::forward<Args>(args)...);
399  }
400 };
401 
402 } // namespace
403 
404 //===----------------------------------------------------------------------===//
405 // Constraint Solver
406 //===----------------------------------------------------------------------===//
407 
408 namespace {
409 /// A canonicalized linear inequality that maps a constraint on var `x` to the
410 /// linear inequality `x >= max(a*x+b, c) + (failed ? ∞ : 0)`.
411 ///
412 /// The inequality separately tracks recursive (a, b) and non-recursive (c)
413 /// constraints on `x`. This allows it to properly identify the combination of
414 /// the two constraints `x >= x-1` and `x >= 4` to be satisfiable as
415 /// `x >= max(x-1, 4)`. If it only tracked inequality as `x >= a*x+b`, the
416 /// combination of these two constraints would be `x >= x+4` (due to max(-1,4) =
417 /// 4), which would be unsatisfiable.
418 ///
419 /// The `failed` flag acts as an additional `∞` term that renders the inequality
420 /// unsatisfiable. It is used as a tombstone value in case an operation renders
421 /// the equality unsatisfiable (e.g. `x >= 2**x` would be represented as the
422 /// inequality `x >= ∞`).
423 ///
424 /// Inequalities represented in this form can easily be checked for
425 /// unsatisfiability in the presence of recursion by inspecting the coefficients
426 /// a and b. The `sat` function performs this action.
427 struct LinIneq {
428  // x >= max(a*x+b, c) + (failed ? ∞ : 0)
429  int32_t recScale = 0; // a
430  int32_t recBias = 0; // b
431  int32_t nonrecBias = 0; // c
432  bool failed = false;
433 
434  /// Create a new unsatisfiable inequality `x >= ∞`.
435  static LinIneq unsat() { return LinIneq(true); }
436 
437  /// Create a new inequality `x >= (failed ? ∞ : 0)`.
438  explicit LinIneq(bool failed = false) : failed(failed) {}
439 
440  /// Create a new inequality `x >= bias`.
441  explicit LinIneq(int32_t bias) : nonrecBias(bias) {}
442 
443  /// Create a new inequality `x >= scale*x+bias`.
444  explicit LinIneq(int32_t scale, int32_t bias) {
445  if (scale != 0) {
446  recScale = scale;
447  recBias = bias;
448  } else {
449  nonrecBias = bias;
450  }
451  }
452 
453  /// Create a new inequality `x >= max(recScale*x+recBias, nonrecBias) +
454  /// (failed ? ∞ : 0)`.
455  explicit LinIneq(int32_t recScale, int32_t recBias, int32_t nonrecBias,
456  bool failed = false)
457  : failed(failed) {
458  if (recScale != 0) {
459  this->recScale = recScale;
460  this->recBias = recBias;
461  this->nonrecBias = nonrecBias;
462  } else {
463  this->nonrecBias = std::max(recBias, nonrecBias);
464  }
465  }
466 
467  /// Combine two inequalities by taking the maxima of corresponding
468  /// coefficients.
469  ///
470  /// This essentially combines `x >= max(a1*x+b1, c1)` and `x >= max(a2*x+b2,
471  /// c2)` into a new `x >= max(max(a1,a2)*x+max(b1,b2), max(c1,c2))`. This is
472  /// a pessimistic upper bound, since e.g. `x >= 2x-10` and `x >= x-5` may both
473  /// hold, but the resulting `x >= 2x-5` may pessimistically not hold.
474  static LinIneq max(const LinIneq &lhs, const LinIneq &rhs) {
475  return LinIneq(std::max(lhs.recScale, rhs.recScale),
476  std::max(lhs.recBias, rhs.recBias),
477  std::max(lhs.nonrecBias, rhs.nonrecBias),
478  lhs.failed || rhs.failed);
479  }
480 
481  /// Combine two inequalities by summing up the two right hand sides.
482  ///
483  /// This is a tricky one, since the addition of the two max terms will lead to
484  /// a maximum over four possible terms (similar to a binomial expansion). In
485  /// order to shoehorn this back into a two-term maximum, we have to pick the
486  /// recursive term that will grow the fastest.
487  ///
488  /// As an example for this problem, consider the following addition:
489  ///
490  /// x >= max(a1*x+b1, c1) + max(a2*x+b2, c2)
491  ///
492  /// We would like to expand and rearrange this again into a maximum:
493  ///
494  /// x >= max(a1*x+b1 + max(a2*x+b2, c2), c1 + max(a2*x+b2, c2))
495  /// x >= max(max(a1*x+b1 + a2*x+b2, a1*x+b1 + c2),
496  /// max(c1 + a2*x+b2, c1 + c2))
497  /// x >= max((a1+a2)*x+(b1+b2), a1*x+(b1+c2), a2*x+(b2+c1), c1+c2)
498  ///
499  /// Since we are combining two two-term maxima, there are four possible ways
500  /// how the terms can combine, leading to the above four-term maximum. An easy
501  /// upper bound of the form we want would be the following:
502  ///
503  /// x >= max(max(a1+a2, a1, a2)*x + max(b1+b2, b1+c2, b2+c1), c1+c2)
504  ///
505  /// However, this is a very pessimistic upper-bound that will declare very
506  /// common patterns in the IR as unbreakable cycles, despite them being very
507  /// much breakable. For example:
508  ///
509  /// x >= max(x, 42) + max(0, -3) <-- breakable recursion
510  /// x >= max(max(1+0, 1, 0)*x + max(42+0, -3, 42), 42-2)
511  /// x >= max(x + 42, 39) <-- unbreakable recursion!
512  ///
513  /// A better approach is to take the expanded four-term maximum, retain the
514  /// non-recursive term (c1+c2), and estimate which one of the recursive terms
515  /// (first three) will become dominant as we choose greater values for x.
516  /// Since x never is inferred to be negative, the recursive term in the
517  /// maximum with the highest scaling factor for x will end up dominating as
518  /// x tends to ∞:
519  ///
520  /// x >= max({
521  /// (a1+a2)*x+(b1+b2) if a1+a2 >= max(a1+a2, a1, a2) and a1>0 and a2>0,
522  /// a1*x+(b1+c2) if a1 >= max(a1+a2, a1, a2) and a1>0,
523  /// a2*x+(b2+c1) if a2 >= max(a1+a2, a1, a2) and a2>0,
524  /// 0 otherwise
525  /// }, c1+c2)
526  ///
527  /// In case multiple cases apply, the highest bias of the recursive term is
528  /// picked. With this, the above problematic example triggers the second case
529  /// and becomes:
530  ///
531  /// x >= max(1*x+(0-3), 42-3) = max(x-3, 39)
532  ///
533  /// Of which the first case is chosen, as it has the lower bias value.
534  static LinIneq add(const LinIneq &lhs, const LinIneq &rhs) {
535  // Determine the maximum scaling factor among the three possible recursive
536  // terms.
537  auto enable1 = lhs.recScale > 0 && rhs.recScale > 0;
538  auto enable2 = lhs.recScale > 0;
539  auto enable3 = rhs.recScale > 0;
540  auto scale1 = lhs.recScale + rhs.recScale; // (a1+a2)
541  auto scale2 = lhs.recScale; // a1
542  auto scale3 = rhs.recScale; // a2
543  auto bias1 = lhs.recBias + rhs.recBias; // (b1+b2)
544  auto bias2 = lhs.recBias + rhs.nonrecBias; // (b1+c2)
545  auto bias3 = rhs.recBias + lhs.nonrecBias; // (b2+c1)
546  auto maxScale = std::max(scale1, std::max(scale2, scale3));
547 
548  // Among those terms that have a maximum scaling factor, determine the
549  // largest bias value.
550  std::optional<int32_t> maxBias;
551  if (enable1 && scale1 == maxScale)
552  maxBias = bias1;
553  if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
554  maxBias = bias2;
555  if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
556  maxBias = bias3;
557 
558  // Pick from the recursive terms the one with maximum scaling factor and
559  // minimum bias value.
560  auto nonrecBias = lhs.nonrecBias + rhs.nonrecBias; // c1+c2
561  auto failed = lhs.failed || rhs.failed;
562  if (enable1 && scale1 == maxScale && bias1 == *maxBias)
563  return LinIneq(scale1, bias1, nonrecBias, failed);
564  if (enable2 && scale2 == maxScale && bias2 == *maxBias)
565  return LinIneq(scale2, bias2, nonrecBias, failed);
566  if (enable3 && scale3 == maxScale && bias3 == *maxBias)
567  return LinIneq(scale3, bias3, nonrecBias, failed);
568  return LinIneq(0, 0, nonrecBias, failed);
569  }
570 
571  /// Check if the inequality is satisfiable.
572  ///
573  /// The inequality becomes unsatisfiable if the RHS is ∞, or a>1, or a==1 and
574  /// b <= 0. Otherwise there exists as solution for `x` that satisfies the
575  /// inequality.
576  bool sat() const {
577  if (failed)
578  return false;
579  if (recScale > 1)
580  return false;
581  if (recScale == 1 && recBias > 0)
582  return false;
583  return true;
584  }
585 
586  /// Dump the inequality in human-readable form.
587  void print(llvm::raw_ostream &os) const {
588  bool any = false;
589  bool both = (recScale != 0 || recBias != 0) && nonrecBias != 0;
590  os << "x >= ";
591  if (both)
592  os << "max(";
593  if (recScale != 0) {
594  any = true;
595  if (recScale != 1)
596  os << recScale << "*";
597  os << "x";
598  }
599  if (recBias != 0) {
600  if (any) {
601  if (recBias < 0)
602  os << " - " << -recBias;
603  else
604  os << " + " << recBias;
605  } else {
606  any = true;
607  os << recBias;
608  }
609  }
610  if (both)
611  os << ", ";
612  if (nonrecBias != 0) {
613  any = true;
614  os << nonrecBias;
615  }
616  if (both)
617  os << ")";
618  if (failed) {
619  if (any)
620  os << " + ";
621  os << "∞";
622  }
623  if (!any)
624  os << "0";
625  }
626 };
627 
628 /// A simple solver for width constraints.
629 class ConstraintSolver {
630 public:
631  ConstraintSolver() = default;
632 
633  VarExpr *var() {
634  auto *v = vars.alloc();
635  varExprs.push_back(v);
636  if (currentInfo)
637  info[v].insert(currentInfo);
638  if (currentLoc)
639  locs[v].insert(*currentLoc);
640  return v;
641  }
642  DerivedExpr *derived() {
643  auto *d = derivs.alloc();
644  derivedExprs.push_back(d);
645  return d;
646  }
647  KnownExpr *known(int32_t value) { return alloc<KnownExpr>(knowns, value); }
648  IdExpr *id(Expr *arg) { return alloc<IdExpr>(ids, arg); }
649  PowExpr *pow(Expr *arg) { return alloc<PowExpr>(uns, arg); }
650  AddExpr *add(Expr *lhs, Expr *rhs) { return alloc<AddExpr>(bins, lhs, rhs); }
651  MaxExpr *max(Expr *lhs, Expr *rhs) { return alloc<MaxExpr>(bins, lhs, rhs); }
652  MinExpr *min(Expr *lhs, Expr *rhs) { return alloc<MinExpr>(bins, lhs, rhs); }
653 
654  /// Add a constraint `lhs >= rhs`. Multiple constraints on the same variable
655  /// are coalesced into a `max(a, b)` expr.
656  Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
657  if (lhs->constraint)
658  lhs->constraint = max(lhs->constraint, rhs);
659  else
660  lhs->constraint = id(rhs);
661  return lhs->constraint;
662  }
663 
664  /// Add a constraint `lhs <= rhs`. Multiple constraints on the same variable
665  /// are coalesced into a `min(a, b)` expr.
666  Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
667  if (lhs->upperBound)
668  lhs->upperBound = min(lhs->upperBound, rhs);
669  else
670  lhs->upperBound = id(rhs);
671  return lhs->upperBound;
672  }
673 
674  void dumpConstraints(llvm::raw_ostream &os);
675  LogicalResult solve();
676 
677  using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
678  const ContextInfo &getContextInfo() const { return info; }
679  void setCurrentContextInfo(FieldRef fieldRef) { currentInfo = fieldRef; }
680  void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
681 
682 private:
683  // Allocator for constraint expressions.
684  llvm::BumpPtrAllocator allocator;
685  Allocator<VarExpr> vars = {allocator};
686  Allocator<DerivedExpr> derivs = {allocator};
687  InternedAllocator<KnownExpr> knowns = {allocator};
688  InternedAllocator<IdExpr> ids = {allocator};
689  InternedAllocator<UnaryExpr> uns = {allocator};
690  InternedAllocator<BinaryExpr> bins = {allocator};
691 
692  /// A list of expressions in the order they were created.
693  std::vector<VarExpr *> varExprs;
694  std::vector<DerivedExpr *> derivedExprs;
695 
696  /// Add an allocated expression to the list above.
697  template <typename R, typename T, typename... Args>
698  R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
699  auto [expr, inserted] =
700  allocator.template alloc<R>(std::forward<Args>(args)...);
701  if (currentInfo)
702  info[expr].insert(currentInfo);
703  if (currentLoc)
704  locs[expr].insert(*currentLoc);
705  return expr;
706  }
707 
708  /// Contextual information for each expression, indicating which values in the
709  /// IR lead to this expression.
710  ContextInfo info;
711  FieldRef currentInfo = {};
712  DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
713  std::optional<Location> currentLoc;
714 
715  // Forbid copyign or moving the solver, which would invalidate the refs to
716  // allocator held by the allocators.
717  ConstraintSolver(ConstraintSolver &&) = delete;
718  ConstraintSolver(const ConstraintSolver &) = delete;
719  ConstraintSolver &operator=(ConstraintSolver &&) = delete;
720  ConstraintSolver &operator=(const ConstraintSolver &) = delete;
721 
722  void emitUninferredWidthError(VarExpr *var);
723 
724  LinIneq checkCycles(VarExpr *var, Expr *expr,
725  SmallPtrSetImpl<Expr *> &seenVars,
726  InFlightDiagnostic *reportInto = nullptr,
727  unsigned indent = 1);
728 };
729 
730 } // namespace
731 
732 /// Print all constraints in the solver to an output stream.
733 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
734  for (auto *v : varExprs) {
735  if (v->constraint)
736  os << "- " << *v << " >= " << *v->constraint << "\n";
737  else
738  os << "- " << *v << " unconstrained\n";
739  }
740 }
741 
742 #ifndef NDEBUG
743 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const LinIneq &l) {
744  l.print(os);
745  return os;
746 }
747 #endif
748 
749 /// Compute the canonicalized linear inequality expression starting at `expr`,
750 /// for the `var` as the left hand side `x` of the inequality. `seenVars` is
751 /// used as a recursion breaker. Occurrences of `var` itself within the
752 /// expression are mapped to the `a` coefficient in the inequality. Any other
753 /// variables are substituted and, in the presence of a recursion in a variable
754 /// other than `var`, treated as zero. `info` is a mapping from constraint
755 /// expressions to values and operations that produced the expression, and is
756 /// used during error reporting. If `reportInto` is present, the function will
757 /// additionally attach unsatisfiable inequalities as notes to the diagnostic as
758 /// it encounters them.
759 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
760  SmallPtrSetImpl<Expr *> &seenVars,
761  InFlightDiagnostic *reportInto,
762  unsigned indent) {
763  auto ineq =
764  TypeSwitch<Expr *, LinIneq>(expr)
765  .Case<KnownExpr>(
766  [&](auto *expr) { return LinIneq(expr->getValue()); })
767  .Case<VarExpr>([&](auto *expr) {
768  if (expr == var)
769  return LinIneq(1, 0); // x >= 1*x + 0
770  if (!seenVars.insert(expr).second)
771  // Count recursions in other variables as 0. This is sane
772  // since the cycle is either breakable, in which case the
773  // recursion does not modify the resulting value of the
774  // variable, or it is not breakable and will be caught by
775  // this very function once it is called on that variable.
776  return LinIneq(0);
777  if (!expr->constraint)
778  // Count unconstrained variables as `x >= 0`.
779  return LinIneq(0);
780  auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
781  indent + 1);
782  seenVars.erase(expr);
783  return l;
784  })
785  .Case<IdExpr>([&](auto *expr) {
786  return checkCycles(var, expr->arg, seenVars, reportInto,
787  indent + 1);
788  })
789  .Case<PowExpr>([&](auto *expr) {
790  // If we can evaluate `2**arg` to a sensible constant, do
791  // so. This is the case if a == 0 and c < 31 such that 2**c is
792  // representable.
793  auto arg =
794  checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
795  if (arg.recScale != 0 || arg.nonrecBias < 0 || arg.nonrecBias >= 31)
796  return LinIneq::unsat();
797  return LinIneq(1 << arg.nonrecBias); // x >= 2**arg
798  })
799  .Case<AddExpr>([&](auto *expr) {
800  return LinIneq::add(
801  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
802  checkCycles(var, expr->rhs(), seenVars, reportInto,
803  indent + 1));
804  })
805  .Case<MaxExpr, MinExpr>([&](auto *expr) {
806  // Combine the inequalities of the LHS and RHS into a single overly
807  // pessimistic inequality. We treat `MinExpr` the same as `MaxExpr`,
808  // since `max(a,b)` is an upper bound to `min(a,b)`.
809  return LinIneq::max(
810  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
811  checkCycles(var, expr->rhs(), seenVars, reportInto,
812  indent + 1));
813  })
814  .Default([](auto) { return LinIneq::unsat(); });
815 
816  // If we were passed an in-flight diagnostic and the current inequality is
817  // unsatisfiable, attach notes to the diagnostic indicating the values or
818  // operations that contributed to this part of the constraint expression.
819  if (reportInto && !ineq.sat()) {
820  auto report = [&](Location loc) {
821  auto &note = reportInto->attachNote(loc);
822  note << "constrained width W >= ";
823  if (ineq.recScale == -1)
824  note << "-";
825  if (ineq.recScale != 1)
826  note << ineq.recScale;
827  note << "W";
828  if (ineq.recBias < 0)
829  note << "-" << -ineq.recBias;
830  if (ineq.recBias > 0)
831  note << "+" << ineq.recBias;
832  note << " here:";
833  };
834  auto it = locs.find(expr);
835  if (it != locs.end())
836  for (auto loc : it->second)
837  report(loc);
838  }
839  if (!reportInto)
840  LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
841  << "- Visited " << *expr << ": " << ineq << "\n");
842 
843  return ineq;
844 }
845 
846 using ExprSolution = std::pair<std::optional<int32_t>, bool>;
847 
848 static ExprSolution
849 computeUnary(ExprSolution arg, llvm::function_ref<int32_t(int32_t)> operation) {
850  if (arg.first)
851  arg.first = operation(*arg.first);
852  return arg;
853 }
854 
855 static ExprSolution
857  llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
858  auto result = ExprSolution{std::nullopt, lhs.second || rhs.second};
859  if (lhs.first && rhs.first)
860  result.first = operation(*lhs.first, *rhs.first);
861  else if (lhs.first)
862  result.first = lhs.first;
863  else if (rhs.first)
864  result.first = rhs.first;
865  return result;
866 }
867 
868 namespace {
869 struct Frame {
870  Frame(Expr *expr, unsigned indent) : expr(expr), indent(indent) {}
871  Expr *expr;
872  // Indent is only used for debug logs.
873  unsigned indent;
874 };
875 } // namespace
876 
877 /// Compute the value of a constraint `expr`. `seenVars` is used as a recursion
878 /// breaker. Recursive variables are treated as zero. Returns the computed value
879 /// and a boolean indicating whether a recursion was detected. This may be used
880 /// to memoize the result of expressions in case they were not involved in a
881 /// cycle (which may alter their value from the perspective of a variable).
882 static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
883  std::vector<Frame> &worklist) {
884  worklist.clear();
885  worklist.emplace_back(expr, 1);
886  llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
887 
888  while (!worklist.empty()) {
889  auto &frame = worklist.back();
890  auto indent = frame.indent;
891  auto setSolution = [&](ExprSolution solution) {
892  // Memoize the result.
893  if (solution.first && !solution.second)
894  frame.expr->setSolution(*solution.first);
895  solvedExprs[frame.expr] = solution;
896 
897  // Produce some useful debug prints.
898  LLVM_DEBUG({
899  if (!isa<KnownExpr>(frame.expr)) {
900  if (solution.first)
901  llvm::dbgs().indent(indent * 2)
902  << "= Solved " << *frame.expr << " = " << *solution.first;
903  else
904  llvm::dbgs().indent(indent * 2) << "= Skipped " << *frame.expr;
905  llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
906  << ")\n";
907  }
908  });
909 
910  worklist.pop_back();
911  };
912 
913  // See if we have a memoized result we can return.
914  if (frame.expr->getSolution()) {
915  LLVM_DEBUG({
916  if (!isa<KnownExpr>(frame.expr))
917  llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
918  << *frame.expr->getSolution() << "\n";
919  });
920  setSolution(ExprSolution{*frame.expr->getSolution(), false});
921  continue;
922  }
923 
924  // Otherwise compute the value of the expression.
925  LLVM_DEBUG({
926  if (!isa<KnownExpr>(frame.expr))
927  llvm::dbgs().indent(indent * 2) << "- Solving " << *frame.expr << "\n";
928  });
929 
930  TypeSwitch<Expr *>(frame.expr)
931  .Case<KnownExpr>([&](auto *expr) {
932  setSolution(ExprSolution{expr->getValue(), false});
933  })
934  .Case<VarExpr>([&](auto *expr) {
935  if (solvedExprs.contains(expr->constraint)) {
936  auto solution = solvedExprs[expr->constraint];
937  // If we've solved the upper bound already, store the solution.
938  // This will be explicitly solved for later if not computed as
939  // part of the solving that resolved this constraint.
940  // This should only happen if somehow the constraint is
941  // solved before visiting this expression, so that our upperBound
942  // was not added to the worklist such that it was handled first.
943  if (expr->upperBound && solvedExprs.contains(expr->upperBound))
944  expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
945  seenVars.erase(expr);
946  // Constrain variables >= 0.
947  if (solution.first && *solution.first < 0)
948  solution.first = 0;
949  return setSolution(solution);
950  }
951 
952  // Unconstrained variables produce no solution.
953  if (!expr->constraint)
954  return setSolution(ExprSolution{std::nullopt, false});
955  // Return no solution for recursions in the variables. This is sane
956  // and will cause the expression to be ignored when computing the
957  // parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
958  if (!seenVars.insert(expr).second)
959  return setSolution(ExprSolution{std::nullopt, true});
960 
961  worklist.emplace_back(expr->constraint, indent + 1);
962  if (expr->upperBound)
963  worklist.emplace_back(expr->upperBound, indent + 1);
964  })
965  .Case<IdExpr>([&](auto *expr) {
966  if (solvedExprs.contains(expr->arg))
967  return setSolution(solvedExprs[expr->arg]);
968  worklist.emplace_back(expr->arg, indent + 1);
969  })
970  .Case<PowExpr>([&](auto *expr) {
971  if (solvedExprs.contains(expr->arg))
972  return setSolution(computeUnary(
973  solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
974 
975  worklist.emplace_back(expr->arg, indent + 1);
976  })
977  .Case<AddExpr>([&](auto *expr) {
978  if (solvedExprs.contains(expr->lhs()) &&
979  solvedExprs.contains(expr->rhs()))
980  return setSolution(computeBinary(
981  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
982  [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
983 
984  worklist.emplace_back(expr->lhs(), indent + 1);
985  worklist.emplace_back(expr->rhs(), indent + 1);
986  })
987  .Case<MaxExpr>([&](auto *expr) {
988  if (solvedExprs.contains(expr->lhs()) &&
989  solvedExprs.contains(expr->rhs()))
990  return setSolution(computeBinary(
991  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
992  [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
993 
994  worklist.emplace_back(expr->lhs(), indent + 1);
995  worklist.emplace_back(expr->rhs(), indent + 1);
996  })
997  .Case<MinExpr>([&](auto *expr) {
998  if (solvedExprs.contains(expr->lhs()) &&
999  solvedExprs.contains(expr->rhs()))
1000  return setSolution(computeBinary(
1001  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
1002  [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
1003 
1004  worklist.emplace_back(expr->lhs(), indent + 1);
1005  worklist.emplace_back(expr->rhs(), indent + 1);
1006  })
1007  .Default([&](auto) {
1008  setSolution(ExprSolution{std::nullopt, false});
1009  });
1010  }
1011 
1012  return solvedExprs[expr];
1013 }
1014 
1015 /// Solve the constraint problem. This is a very simple implementation that
1016 /// does not fully solve the problem if there are weird dependency cycles
1017 /// present.
1018 LogicalResult ConstraintSolver::solve() {
1019  LLVM_DEBUG({
1020  llvm::dbgs() << "\n";
1021  debugHeader("Constraints") << "\n\n";
1022  dumpConstraints(llvm::dbgs());
1023  });
1024 
1025  // Ensure that there are no adverse cycles around.
1026  LLVM_DEBUG({
1027  llvm::dbgs() << "\n";
1028  debugHeader("Checking for unbreakable loops") << "\n\n";
1029  });
1030  SmallPtrSet<Expr *, 16> seenVars;
1031  bool anyFailed = false;
1032 
1033  for (auto *var : varExprs) {
1034  if (!var->constraint)
1035  continue;
1036  LLVM_DEBUG(llvm::dbgs()
1037  << "- Checking " << *var << " >= " << *var->constraint << "\n");
1038 
1039  // Canonicalize the variable's constraint expression into a form that allows
1040  // us to easily determine if any recursion leads to an unsatisfiable
1041  // constraint. The `seenVars` set acts as a recursion breaker.
1042  seenVars.insert(var);
1043  auto ineq = checkCycles(var, var->constraint, seenVars);
1044  seenVars.clear();
1045 
1046  // If the constraint is satisfiable, we're done.
1047  // TODO: It's possible that this result is already sufficient to arrive at a
1048  // solution for the constraint, and the second pass further down is not
1049  // necessary. This would require more proper handling of `MinExpr` in the
1050  // cycle checking code.
1051  if (ineq.sat()) {
1052  LLVM_DEBUG(llvm::dbgs()
1053  << " = Breakable since " << ineq << " satisfiable\n");
1054  continue;
1055  }
1056 
1057  // If we arrive here, the constraint is not satisfiable at all. To provide
1058  // some guidance to the user, we call the cycle checking code again, but
1059  // this time with an in-flight diagnostic to attach notes indicating
1060  // unsatisfiable paths in the cycle.
1061  LLVM_DEBUG(llvm::dbgs()
1062  << " = UNBREAKABLE since " << ineq << " unsatisfiable\n");
1063  anyFailed = true;
1064  for (auto fieldRef : info.find(var)->second) {
1065  // Depending on whether this value stems from an operation or not, create
1066  // an appropriate diagnostic identifying the value.
1067  auto *op = fieldRef.getDefiningOp();
1068  auto diag = op ? op->emitOpError()
1069  : mlir::emitError(fieldRef.getValue().getLoc())
1070  << "value ";
1071  diag << "is constrained to be wider than itself";
1072 
1073  // Re-run the cycle checking, but this time reporting into the diagnostic.
1074  seenVars.insert(var);
1075  checkCycles(var, var->constraint, seenVars, &diag);
1076  seenVars.clear();
1077  }
1078  }
1079 
1080  // If there were cycles, return now to avoid complaining to the user about
1081  // dependent widths not being inferred.
1082  if (anyFailed)
1083  return failure();
1084 
1085  // Iterate over the constraint variables and solve each.
1086  LLVM_DEBUG({
1087  llvm::dbgs() << "\n";
1088  debugHeader("Solving constraints") << "\n\n";
1089  });
1090  std::vector<Frame> worklist;
1091  for (auto *var : varExprs) {
1092  // Complain about unconstrained variables.
1093  if (!var->constraint) {
1094  LLVM_DEBUG(llvm::dbgs() << "- Unconstrained " << *var << "\n");
1095  emitUninferredWidthError(var);
1096  anyFailed = true;
1097  continue;
1098  }
1099 
1100  // Compute the value for the variable.
1101  LLVM_DEBUG(llvm::dbgs()
1102  << "- Solving " << *var << " >= " << *var->constraint << "\n");
1103  seenVars.insert(var);
1104  auto solution = solveExpr(var->constraint, seenVars, worklist);
1105  // Compute the upperBound if there is one and haven't already.
1106  if (var->upperBound && !var->upperBoundSolution)
1107  var->upperBoundSolution =
1108  solveExpr(var->upperBound, seenVars, worklist).first;
1109  seenVars.clear();
1110 
1111  // Constrain variables >= 0.
1112  if (solution.first) {
1113  if (*solution.first < 0)
1114  solution.first = 0;
1115  var->setSolution(*solution.first);
1116  }
1117 
1118  // In case the width could not be inferred, complain to the user. This might
1119  // be the case if the width depends on an unconstrained variable.
1120  if (!solution.first) {
1121  LLVM_DEBUG(llvm::dbgs() << " - UNSOLVED " << *var << "\n");
1122  emitUninferredWidthError(var);
1123  anyFailed = true;
1124  continue;
1125  }
1126  LLVM_DEBUG(llvm::dbgs()
1127  << " = Solved " << *var << " = " << solution.first << " ("
1128  << (solution.second ? "cycle broken" : "unique") << ")\n");
1129 
1130  // Check if the solution we have found violates an upper bound.
1131  if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1132  LLVM_DEBUG(llvm::dbgs() << " ! Unsatisfiable " << *var
1133  << " <= " << var->upperBoundSolution << "\n");
1134  emitUninferredWidthError(var);
1135  anyFailed = true;
1136  }
1137  }
1138 
1139  // Copy over derived widths.
1140  for (auto *derived : derivedExprs) {
1141  auto *assigned = derived->assigned;
1142  if (!assigned || !assigned->getSolution()) {
1143  LLVM_DEBUG(llvm::dbgs() << "- Unused " << *derived << " set to 0\n");
1144  derived->setSolution(0);
1145  } else {
1146  LLVM_DEBUG(llvm::dbgs() << "- Deriving " << *derived << " = "
1147  << assigned->getSolution() << "\n");
1148  derived->setSolution(*assigned->getSolution());
1149  }
1150  }
1151 
1152  return failure(anyFailed);
1153 }
1154 
1155 // Emits the diagnostic to inform the user about an uninferred width in the
1156 // design. Returns true if an error was reported, false otherwise.
1157 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1158  FieldRef fieldRef = info.find(var)->second.back();
1159  Value value = fieldRef.getValue();
1160 
1161  auto diag = mlir::emitError(value.getLoc(), "uninferred width:");
1162 
1163  // Try to hint the user at what kind of node this is.
1164  if (isa<BlockArgument>(value)) {
1165  diag << " port";
1166  } else if (auto *op = value.getDefiningOp()) {
1167  TypeSwitch<Operation *>(op)
1168  .Case<WireOp>([&](auto) { diag << " wire"; })
1169  .Case<RegOp, RegResetOp>([&](auto) { diag << " reg"; })
1170  .Case<NodeOp>([&](auto) { diag << " node"; })
1171  .Default([&](auto) { diag << " value"; });
1172  } else {
1173  diag << " value";
1174  }
1175 
1176  // Actually print what the user can refer to.
1177  auto [fieldName, rootKnown] = getFieldName(fieldRef);
1178  if (!fieldName.empty()) {
1179  if (!rootKnown)
1180  diag << " field";
1181  diag << " \"" << fieldName << "\"";
1182  }
1183 
1184  if (!var->constraint) {
1185  diag << " is unconstrained";
1186  } else if (var->getSolution() && var->upperBoundSolution &&
1187  var->getSolution() > var->upperBoundSolution) {
1188  diag << " cannot satisfy all width requirements";
1189  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1190  LLVM_DEBUG(llvm::dbgs() << *var->upperBound << "\n");
1191  auto loc = locs.find(var->constraint)->second.back();
1192  diag.attachNote(loc) << "width is constrained to be at least "
1193  << *var->getSolution() << " here:";
1194  loc = locs.find(var->upperBound)->second.back();
1195  diag.attachNote(loc) << "width is constrained to be at most "
1196  << *var->upperBoundSolution << " here:";
1197  } else {
1198  diag << " width cannot be determined";
1199  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1200  auto loc = locs.find(var->constraint)->second.back();
1201  diag.attachNote(loc) << "width is constrained by an uninferred width here:";
1202  }
1203 }
1204 
1205 //===----------------------------------------------------------------------===//
1206 // Inference Constraint Problem Mapping
1207 //===----------------------------------------------------------------------===//
1208 
1209 namespace {
1210 
1211 /// A helper class which maps the types and operations in a design to a set of
1212 /// variables and constraints to be solved later.
1213 class InferenceMapping {
1214 public:
1215  InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1216  hw::InnerSymbolTableCollection &istc)
1217  : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1218 
1219  LogicalResult map(CircuitOp op);
1220  bool allWidthsKnown(Operation *op);
1221  LogicalResult mapOperation(Operation *op);
1222 
1223  /// Declare all the variables in the value. If the value is a ground type,
1224  /// there is a single variable declared. If the value is an aggregate type,
1225  /// it sets up variables for each unknown width.
1226  void declareVars(Value value, bool isDerived = false);
1227 
1228  /// Assign the constraint expressions of the fields in the `result` argument
1229  /// as the max of expressions in the `rhs` and `lhs` arguments. Both fields
1230  /// must be the same type.
1231  void maximumOfTypes(Value result, Value rhs, Value lhs);
1232 
1233  /// Constrain the value "larger" to be greater than or equal to "smaller".
1234  /// These may be aggregate values. This is used for regular connects.
1235  void constrainTypes(Value larger, Value smaller, bool equal = false);
1236 
1237  /// Constrain the expression "larger" to be greater than or equals to
1238  /// the expression "smaller".
1239  void constrainTypes(Expr *larger, Expr *smaller,
1240  bool imposeUpperBounds = false, bool equal = false);
1241 
1242  /// Assign the constraint expressions of the fields in the `src` argument as
1243  /// the expressions for the `dst` argument. Both fields must be of the given
1244  /// `type`.
1245  void unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type);
1246 
1247  /// Get the expr associated with the value. The value must be a non-aggregate
1248  /// type
1249  Expr *getExpr(Value value) const;
1250 
1251  /// Get the expr associated with a specific field in a value.
1252  Expr *getExpr(FieldRef fieldRef) const;
1253 
1254  /// Get the expr associated with a specific field in a value. If value is
1255  /// NULL, then this returns NULL.
1256  Expr *getExprOrNull(FieldRef fieldRef) const;
1257 
1258  /// Set the expr associated with the value. The value must be a non-aggregate
1259  /// type.
1260  void setExpr(Value value, Expr *expr);
1261 
1262  /// Set the expr associated with a specific field in a value.
1263  void setExpr(FieldRef fieldRef, Expr *expr);
1264 
1265  /// Return whether a module was skipped due to being fully inferred already.
1266  bool isModuleSkipped(FModuleOp module) const {
1267  return skippedModules.count(module);
1268  }
1269 
1270  /// Return whether all modules in the mapping were fully inferred.
1271  bool areAllModulesSkipped() const { return allModulesSkipped; }
1272 
1273 private:
1274  /// The constraint solver into which we emit variables and constraints.
1275  ConstraintSolver &solver;
1276 
1277  /// The constraint exprs for each result type of an operation.
1278  DenseMap<FieldRef, Expr *> opExprs;
1279 
1280  /// The fully inferred modules that were skipped entirely.
1281  SmallPtrSet<Operation *, 16> skippedModules;
1282  bool allModulesSkipped = true;
1283 
1284  /// Cache of module symbols
1285  SymbolTable &symtbl;
1286 
1287  /// Full design inner symbol information.
1288  hw::InnerRefNamespace irn;
1289 };
1290 
1291 } // namespace
1292 
1293 /// Check if a type contains any FIRRTL type with uninferred widths.
1294 static bool hasUninferredWidth(Type type) {
1295  return TypeSwitch<Type, bool>(type)
1296  .Case<FIRRTLBaseType>([](auto base) { return base.hasUninferredWidth(); })
1297  .Case<RefType>(
1298  [](auto ref) { return ref.getType().hasUninferredWidth(); })
1299  .Default([](auto) { return false; });
1300 }
1301 
1302 LogicalResult InferenceMapping::map(CircuitOp op) {
1303  LLVM_DEBUG(llvm::dbgs()
1304  << "\n===----- Mapping ops to constraint exprs -----===\n\n");
1305 
1306  // Ensure we have constraint variables established for all module ports.
1307  for (auto module : op.getOps<FModuleOp>())
1308  for (auto arg : module.getArguments()) {
1309  solver.setCurrentContextInfo(FieldRef(arg, 0));
1310  declareVars(arg);
1311  }
1312 
1313  for (auto module : op.getOps<FModuleOp>()) {
1314  // Check if the module contains *any* uninferred widths. This allows us to
1315  // do an early skip if the module is already fully inferred.
1316  bool anyUninferred = false;
1317  for (auto arg : module.getArguments()) {
1318  anyUninferred |= hasUninferredWidth(arg.getType());
1319  if (anyUninferred)
1320  break;
1321  }
1322  module.walk([&](Operation *op) {
1323  for (auto type : op->getResultTypes())
1324  anyUninferred |= hasUninferredWidth(type);
1325  if (anyUninferred)
1326  return WalkResult::interrupt();
1327  return WalkResult::advance();
1328  });
1329 
1330  if (!anyUninferred) {
1331  LLVM_DEBUG(llvm::dbgs() << "Skipping fully-inferred module '"
1332  << module.getName() << "'\n");
1333  skippedModules.insert(module);
1334  continue;
1335  }
1336 
1337  allModulesSkipped = false;
1338 
1339  // Go through operations in the module, creating type variables for results,
1340  // and generating constraints.
1341  auto result = module.getBodyBlock()->walk(
1342  [&](Operation *op) { return WalkResult(mapOperation(op)); });
1343  if (result.wasInterrupted())
1344  return failure();
1345  }
1346 
1347  return success();
1348 }
1349 
1350 bool InferenceMapping::allWidthsKnown(Operation *op) {
1351  /// Ignore property assignments, no widths to infer.
1352  if (isa<PropAssignOp>(op))
1353  return true;
1354 
1355  // If this is a mux, and the select signal is uninferred, we need to set an
1356  // upperbound limit on it.
1357  if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1358  if (hasUninferredWidth(op->getOperand(0).getType()))
1359  return false;
1360 
1361  // We need to propagate through connects.
1362  if (isa<FConnectLike, AttachOp>(op))
1363  return false;
1364 
1365  // Check if we know the width of every result of this operation.
1366  return llvm::all_of(op->getResults(), [&](auto result) {
1367  // Only consider FIRRTL types for width constraints. Ignore any foreign
1368  // types as they don't participate in the width inference process.
1369  if (auto type = type_dyn_cast<FIRRTLType>(result.getType()))
1370  if (hasUninferredWidth(type))
1371  return false;
1372  return true;
1373  });
1374 }
1375 
1376 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1377  if (allWidthsKnown(op))
1378  return success();
1379 
1380  // Actually generate the necessary constraint expressions.
1381  bool mappingFailed = false;
1382  solver.setCurrentContextInfo(
1383  op->getNumResults() > 0 ? FieldRef(op->getResults()[0], 0) : FieldRef());
1384  solver.setCurrentLocation(op->getLoc());
1385  TypeSwitch<Operation *>(op)
1386  .Case<ConstantOp>([&](auto op) {
1387  // If the constant has a known width, use that. Otherwise pick the
1388  // smallest number of bits necessary to represent the constant.
1389  auto v = op.getValue();
1390  auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1391  : v.countLeadingZeros());
1392  if (v.isSigned())
1393  w += 1;
1394  setExpr(op.getResult(), solver.known(std::max(w, 1u)));
1395  })
1396  .Case<SpecialConstantOp>([&](auto op) {
1397  // Nothing required.
1398  })
1399  .Case<InvalidValueOp>([&](auto op) {
1400  // We must duplicate the invalid value for each use, since each use can
1401  // be inferred to a different width.
1402  declareVars(op.getResult(), /*isDerived=*/true);
1403  })
1404  .Case<WireOp, RegOp>([&](auto op) { declareVars(op.getResult()); })
1405  .Case<RegResetOp>([&](auto op) {
1406  // The original Scala code also constrains the reset signal to be at
1407  // least 1 bit wide. We don't do this here since the MLIR FIRRTL
1408  // dialect enforces the reset signal to be an async reset or a
1409  // `uint<1>`.
1410  declareVars(op.getResult());
1411  // Contrain the register to be greater than or equal to the reset
1412  // signal.
1413  constrainTypes(op.getResult(), op.getResetValue());
1414  })
1415  .Case<NodeOp>([&](auto op) {
1416  // Nodes have the same type as their input.
1417  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 0),
1418  op.getResult().getType());
1419  })
1420 
1421  // Aggregate Values
1422  .Case<SubfieldOp>([&](auto op) {
1423  BundleType bundleType = op.getInput().getType();
1424  auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1425  unifyTypes(FieldRef(op.getResult(), 0),
1426  FieldRef(op.getInput(), fieldID), op.getType());
1427  })
1428  .Case<SubindexOp, SubaccessOp>([&](auto op) {
1429  // All vec fields unify to the same thing. Always use the first element
1430  // of the vector, which has a field ID of 1.
1431  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
1432  op.getType());
1433  })
1434  .Case<SubtagOp>([&](auto op) {
1435  FEnumType enumType = op.getInput().getType();
1436  auto fieldID = enumType.getFieldID(op.getFieldIndex());
1437  unifyTypes(FieldRef(op.getResult(), 0),
1438  FieldRef(op.getInput(), fieldID), op.getType());
1439  })
1440 
1441  .Case<RefSubOp>([&](RefSubOp op) {
1442  uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1443  op.getInput().getType().getType())
1444  .Case<FVectorType>([](auto _) { return 1; })
1445  .Case<BundleType>([&](auto type) {
1446  return type.getFieldID(op.getIndex());
1447  });
1448  unifyTypes(FieldRef(op.getResult(), 0),
1449  FieldRef(op.getInput(), fieldID), op.getType());
1450  })
1451 
1452  // Arithmetic and Logical Binary Primitives
1453  .Case<AddPrimOp, SubPrimOp>([&](auto op) {
1454  auto lhs = getExpr(op.getLhs());
1455  auto rhs = getExpr(op.getRhs());
1456  auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1457  setExpr(op.getResult(), e);
1458  })
1459  .Case<MulPrimOp>([&](auto op) {
1460  auto lhs = getExpr(op.getLhs());
1461  auto rhs = getExpr(op.getRhs());
1462  auto e = solver.add(lhs, rhs);
1463  setExpr(op.getResult(), e);
1464  })
1465  .Case<DivPrimOp>([&](auto op) {
1466  auto lhs = getExpr(op.getLhs());
1467  Expr *e;
1468  if (op.getType().base().isSigned()) {
1469  e = solver.add(lhs, solver.known(1));
1470  } else {
1471  e = lhs;
1472  }
1473  setExpr(op.getResult(), e);
1474  })
1475  .Case<RemPrimOp>([&](auto op) {
1476  auto lhs = getExpr(op.getLhs());
1477  auto rhs = getExpr(op.getRhs());
1478  auto e = solver.min(lhs, rhs);
1479  setExpr(op.getResult(), e);
1480  })
1481  .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](auto op) {
1482  auto lhs = getExpr(op.getLhs());
1483  auto rhs = getExpr(op.getRhs());
1484  auto e = solver.max(lhs, rhs);
1485  setExpr(op.getResult(), e);
1486  })
1487 
1488  // Misc Binary Primitives
1489  .Case<CatPrimOp>([&](auto op) {
1490  auto lhs = getExpr(op.getLhs());
1491  auto rhs = getExpr(op.getRhs());
1492  auto e = solver.add(lhs, rhs);
1493  setExpr(op.getResult(), e);
1494  })
1495  .Case<DShlPrimOp>([&](auto op) {
1496  auto lhs = getExpr(op.getLhs());
1497  auto rhs = getExpr(op.getRhs());
1498  auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1499  setExpr(op.getResult(), e);
1500  })
1501  .Case<DShlwPrimOp, DShrPrimOp>([&](auto op) {
1502  auto e = getExpr(op.getLhs());
1503  setExpr(op.getResult(), e);
1504  })
1505 
1506  // Unary operators
1507  .Case<NegPrimOp>([&](auto op) {
1508  auto input = getExpr(op.getInput());
1509  auto e = solver.add(input, solver.known(1));
1510  setExpr(op.getResult(), e);
1511  })
1512  .Case<CvtPrimOp>([&](auto op) {
1513  auto input = getExpr(op.getInput());
1514  auto e = op.getInput().getType().base().isSigned()
1515  ? input
1516  : solver.add(input, solver.known(1));
1517  setExpr(op.getResult(), e);
1518  })
1519 
1520  // Miscellaneous
1521  .Case<BitsPrimOp>([&](auto op) {
1522  setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1523  })
1524  .Case<HeadPrimOp>([&](auto op) {
1525  setExpr(op.getResult(), solver.known(op.getAmount()));
1526  })
1527  .Case<TailPrimOp>([&](auto op) {
1528  auto input = getExpr(op.getInput());
1529  auto e = solver.add(input, solver.known(-op.getAmount()));
1530  setExpr(op.getResult(), e);
1531  })
1532  .Case<PadPrimOp>([&](auto op) {
1533  auto input = getExpr(op.getInput());
1534  auto e = solver.max(input, solver.known(op.getAmount()));
1535  setExpr(op.getResult(), e);
1536  })
1537  .Case<ShlPrimOp>([&](auto op) {
1538  auto input = getExpr(op.getInput());
1539  auto e = solver.add(input, solver.known(op.getAmount()));
1540  setExpr(op.getResult(), e);
1541  })
1542  .Case<ShrPrimOp>([&](auto op) {
1543  auto input = getExpr(op.getInput());
1544  // UInt saturates at 0 bits, SInt at 1 bit
1545  auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1546  auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1547  solver.known(minWidth));
1548  setExpr(op.getResult(), e);
1549  })
1550 
1551  // Handle operations whose output width matches the input width.
1552  .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1553  [&](auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1554  .Case<mlir::UnrealizedConversionCastOp>(
1555  [&](auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1556 
1557  // Handle operations with a single result type that always has a
1558  // well-known width.
1559  .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1560  AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1561  XorRPrimOp>([&](auto op) {
1562  auto width = op.getType().getBitWidthOrSentinel();
1563  assert(width > 0 && "width should have been checked by verifier");
1564  setExpr(op.getResult(), solver.known(width));
1565  })
1566  .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
1567  auto *sel = getExpr(op.getSel());
1568  constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
1569  maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1570  })
1571  .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1572  auto *sel = getExpr(op.getSel());
1573  constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
1574  maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1575  maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1576  maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1577  })
1578 
1579  .Case<ConnectOp, MatchingConnectOp>(
1580  [&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1581  .Case<RefDefineOp>([&](auto op) {
1582  // Dest >= Src, but also check Src <= Dest for correctness
1583  // (but don't solve to make this true, don't back-propagate)
1584  constrainTypes(op.getDest(), op.getSrc(), true);
1585  })
1586  .Case<AttachOp>([&](auto op) {
1587  // Attach connects multiple analog signals together. All signals must
1588  // have the same bit width. Signals without bit width inherit from the
1589  // other signals.
1590  if (op.getAttached().empty())
1591  return;
1592  auto prev = op.getAttached()[0];
1593  for (auto operand : op.getAttached().drop_front()) {
1594  auto e1 = getExpr(prev);
1595  auto e2 = getExpr(operand);
1596  constrainTypes(e1, e2, /*imposeUpperBounds=*/true);
1597  constrainTypes(e2, e1, /*imposeUpperBounds=*/true);
1598  prev = operand;
1599  }
1600  })
1601 
1602  // Handle the no-ops that don't interact with width inference.
1603  .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1604  UnclockedAssumeIntrinsicOp, CoverOp>([&](auto) {})
1605 
1606  // Handle instances of other modules.
1607  .Case<InstanceOp>([&](auto op) {
1608  auto refdModule = op.getReferencedOperation(symtbl);
1609  auto module = dyn_cast<FModuleOp>(&*refdModule);
1610  if (!module) {
1611  auto diag = mlir::emitError(op.getLoc());
1612  diag << "extern module `" << op.getModuleName()
1613  << "` has ports of uninferred width";
1614 
1615  auto fml = cast<FModuleLike>(&*refdModule);
1616  auto ports = fml.getPorts();
1617  for (auto &port : ports) {
1618  auto baseType = getBaseType(port.type);
1619  if (baseType && baseType.hasUninferredWidth()) {
1620  diag.attachNote(op.getLoc()) << "Port: " << port.name;
1621  if (!baseType.isGround())
1622  diagnoseUninferredType(diag, baseType, port.name.getValue());
1623  }
1624  }
1625 
1626  diag.attachNote(op.getLoc())
1627  << "Only non-extern FIRRTL modules may contain unspecified "
1628  "widths to be inferred automatically.";
1629  diag.attachNote(refdModule->getLoc())
1630  << "Module `" << op.getModuleName() << "` defined here:";
1631  mappingFailed = true;
1632  return;
1633  }
1634  // Simply look up the free variables created for the instantiated
1635  // module's ports, and use them for instance port wires. This way,
1636  // constraints imposed onto the ports of the instance will transparently
1637  // apply to the ports of the instantiated module.
1638  for (auto [result, arg] :
1639  llvm::zip(op->getResults(), module.getArguments()))
1640  unifyTypes({result, 0}, {arg, 0},
1641  type_cast<FIRRTLType>(result.getType()));
1642  })
1643 
1644  // Handle memories.
1645  .Case<MemOp>([&](MemOp op) {
1646  // Create constraint variables for all ports.
1647  unsigned nonDebugPort = 0;
1648  for (const auto &result : llvm::enumerate(op.getResults())) {
1649  declareVars(result.value());
1650  if (!type_isa<RefType>(result.value().getType()))
1651  nonDebugPort = result.index();
1652  }
1653 
1654  // A helper function that returns the indeces of the "data", "rdata",
1655  // and "wdata" fields in the bundle corresponding to a memory port.
1656  auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1657  static const unsigned indices[] = {3, 5};
1658  static const unsigned debug[] = {0};
1659  switch (kind) {
1660  case MemOp::PortKind::Read:
1661  case MemOp::PortKind::Write:
1662  return ArrayRef<unsigned>(indices, 1); // {3}
1663  case MemOp::PortKind::ReadWrite:
1664  return ArrayRef<unsigned>(indices); // {3, 5}
1665  case MemOp::PortKind::Debug:
1666  return ArrayRef<unsigned>(debug);
1667  }
1668  llvm_unreachable("Imposible PortKind");
1669  };
1670 
1671  // This creates independent variables for every data port. Yet, what we
1672  // actually want is for all data ports to share the same variable. To do
1673  // this, we find the first data port declared, and use that port's vars
1674  // for all the other ports.
1675  unsigned firstFieldIndex =
1676  dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1677  FieldRef firstData(
1678  op.getResult(nonDebugPort),
1679  type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1680  .getFieldID(firstFieldIndex));
1681  LLVM_DEBUG(llvm::dbgs() << "Adjusting memory port variables:\n");
1682 
1683  // Reuse data port variables.
1684  auto dataType = op.getDataType();
1685  for (unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1686  auto result = op.getResult(i);
1687  if (type_isa<RefType>(result.getType())) {
1688  // Debug ports are firrtl.ref<vector<data-type, depth>>
1689  // Use FieldRef of 1, to indicate the first vector element must be
1690  // of the dataType.
1691  unifyTypes(firstData, FieldRef(result, 1), dataType);
1692  continue;
1693  }
1694 
1695  auto portType =
1696  type_cast<BundleType>(op.getPortType(i).getPassiveType());
1697  for (auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1698  unifyTypes(FieldRef(result, portType.getFieldID(fieldIndex)),
1699  firstData, dataType);
1700  }
1701  })
1702 
1703  .Case<RefSendOp>([&](auto op) {
1704  declareVars(op.getResult());
1705  constrainTypes(op.getResult(), op.getBase(), true);
1706  })
1707  .Case<RefResolveOp>([&](auto op) {
1708  declareVars(op.getResult());
1709  constrainTypes(op.getResult(), op.getRef(), true);
1710  })
1711  .Case<RefCastOp>([&](auto op) {
1712  declareVars(op.getResult());
1713  constrainTypes(op.getResult(), op.getInput(), true);
1714  })
1715  .Case<RWProbeOp>([&](auto op) {
1716  auto ist = irn.lookup(op.getTarget());
1717  if (!ist) {
1718  op->emitError("target of rwprobe could not be resolved");
1719  mappingFailed = true;
1720  return;
1721  }
1722  auto ref = getFieldRefForTarget(ist);
1723  if (!ref) {
1724  op->emitError("target of rwprobe resolved to unsupported target");
1725  mappingFailed = true;
1726  return;
1727  }
1728  auto newFID = convertFieldIDToOurVersion(
1729  ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1730  unifyTypes(FieldRef(op.getResult(), 0),
1731  FieldRef(ref.getValue(), newFID), op.getType());
1732  })
1733  .Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
1734  for (Value result : op.getResults()) {
1735  auto ty = result.getType();
1736  if (type_isa<FIRRTLType>(ty))
1737  declareVars(result);
1738  }
1739  })
1740  .Default([&](auto op) {
1741  op->emitOpError("not supported in width inference");
1742  mappingFailed = true;
1743  });
1744 
1745  // Forceable declarations should have the ref constrained to data result.
1746  if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1747  unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
1748  fop.getDataType());
1749 
1750  return failure(mappingFailed);
1751 }
1752 
1753 /// Declare free variables for the type of a value, and associate the resulting
1754 /// set of variables with that value.
1755 void InferenceMapping::declareVars(Value value, bool isDerived) {
1756  // Declare a variable for every unknown width in the type. If this is a Bundle
1757  // type or a FVector type, we will have to potentially create many variables.
1758  unsigned fieldID = 0;
1759  std::function<void(FIRRTLBaseType)> declare = [&](FIRRTLBaseType type) {
1760  auto width = type.getBitWidthOrSentinel();
1761  if (width >= 0) {
1762  fieldID++;
1763  } else if (width == -1) {
1764  // Unknown width integers create a variable.
1765  FieldRef field(value, fieldID);
1766  solver.setCurrentContextInfo(field);
1767  if (isDerived)
1768  setExpr(field, solver.derived());
1769  else
1770  setExpr(field, solver.var());
1771  fieldID++;
1772  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1773  // Bundle types recursively declare all bundle elements.
1774  fieldID++;
1775  for (auto &element : bundleType)
1776  declare(element.type);
1777  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1778  fieldID++;
1779  auto save = fieldID;
1780  declare(vecType.getElementType());
1781  // Skip past the rest of the elements
1782  fieldID = save + vecType.getMaxFieldID();
1783  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1784  fieldID++;
1785  for (auto &element : enumType.getElements())
1786  declare(element.type);
1787  } else {
1788  llvm_unreachable("Unknown type inside a bundle!");
1789  }
1790  };
1791  if (auto type = getBaseType(value.getType()))
1792  declare(type);
1793 }
1794 
1795 /// Assign the constraint expressions of the fields in the `result` argument as
1796 /// the max of expressions in the `rhs` and `lhs` arguments. Both fields must be
1797 /// the same type.
1798 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1799  // Recurse to every leaf element and set larger >= smaller.
1800  auto fieldID = 0;
1801  std::function<void(FIRRTLBaseType)> maximize = [&](FIRRTLBaseType type) {
1802  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1803  fieldID++;
1804  for (auto &element : bundleType.getElements())
1805  maximize(element.type);
1806  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1807  fieldID++;
1808  auto save = fieldID;
1809  // Skip 0 length vectors.
1810  if (vecType.getNumElements() > 0)
1811  maximize(vecType.getElementType());
1812  fieldID = save + vecType.getMaxFieldID();
1813  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1814  fieldID++;
1815  for (auto &element : enumType.getElements())
1816  maximize(element.type);
1817  } else if (type.isGround()) {
1818  auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
1819  getExpr(FieldRef(lhs, fieldID)));
1820  setExpr(FieldRef(result, fieldID), e);
1821  fieldID++;
1822  } else {
1823  llvm_unreachable("Unknown type inside a bundle!");
1824  }
1825  };
1826  if (auto type = getBaseType(result.getType()))
1827  maximize(type);
1828 }
1829 
1830 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1831 /// than or equal to the sizes in the `smaller` type. Types have to be
1832 /// compatible in the sense that they may only differ in the presence or absence
1833 /// of bit widths.
1834 ///
1835 /// This function is used to apply regular connects.
1836 /// Set `equal` for constraining larger <= smaller for correctness but not
1837 /// solving.
1838 void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
1839  // Recurse to every leaf element and set larger >= smaller. Ignore foreign
1840  // types as these do not participate in width inference.
1841 
1842  auto fieldID = 0;
1843  std::function<void(FIRRTLBaseType, Value, Value)> constrain =
1844  [&](FIRRTLBaseType type, Value larger, Value smaller) {
1845  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1846  fieldID++;
1847  for (auto &element : bundleType.getElements()) {
1848  if (element.isFlip)
1849  constrain(element.type, smaller, larger);
1850  else
1851  constrain(element.type, larger, smaller);
1852  }
1853  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1854  fieldID++;
1855  auto save = fieldID;
1856  // Skip 0 length vectors.
1857  if (vecType.getNumElements() > 0) {
1858  constrain(vecType.getElementType(), larger, smaller);
1859  }
1860  fieldID = save + vecType.getMaxFieldID();
1861  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1862  fieldID++;
1863  for (auto &element : enumType.getElements())
1864  constrain(element.type, larger, smaller);
1865  } else if (type.isGround()) {
1866  // Leaf element, look up their expressions, and create the constraint.
1867  constrainTypes(getExpr(FieldRef(larger, fieldID)),
1868  getExpr(FieldRef(smaller, fieldID)), false, equal);
1869  fieldID++;
1870  } else {
1871  llvm_unreachable("Unknown type inside a bundle!");
1872  }
1873  };
1874 
1875  if (auto type = getBaseType(larger.getType()))
1876  constrain(type, larger, smaller);
1877 }
1878 
1879 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1880 /// than or equal to the sizes in the `smaller` type.
1881 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1882  bool imposeUpperBounds, bool equal) {
1883  assert(larger && "Larger expression should be specified");
1884  assert(smaller && "Smaller expression should be specified");
1885 
1886  // If one of the sides is `DerivedExpr`, simply assign the other side as the
1887  // derived width. This allows `InvalidValueOp`s to properly infer their width
1888  // from the connects they are used in, but also be inferred to something
1889  // useful on their own.
1890  if (auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1891  largerDerived->assigned = smaller;
1892  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *largerDerived << " from "
1893  << *smaller << "\n");
1894  return;
1895  }
1896  if (auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1897  smallerDerived->assigned = larger;
1898  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *smallerDerived << " from "
1899  << *larger << "\n");
1900  return;
1901  }
1902 
1903  // If the larger expr is a free variable, create a `expr >= x` constraint for
1904  // it that we can try to satisfy with the smallest width.
1905  if (auto *largerVar = dyn_cast<VarExpr>(larger)) {
1906  [[maybe_unused]] auto *c = solver.addGeqConstraint(largerVar, smaller);
1907  LLVM_DEBUG(llvm::dbgs()
1908  << "Constrained " << *largerVar << " >= " << *c << "\n");
1909  // If we're constraining larger == smaller, add the LEQ contraint as well.
1910  // Solve for GEQ but check that LEQ is true.
1911  // Used for matchingconnect, some reference operations, and anywhere the
1912  // widths should be inferred strictly in one direction but are required to
1913  // also be equal for correctness.
1914  if (equal) {
1915  [[maybe_unused]] auto *leq = solver.addLeqConstraint(largerVar, smaller);
1916  LLVM_DEBUG(llvm::dbgs()
1917  << "Constrained " << *largerVar << " <= " << *leq << "\n");
1918  }
1919  return;
1920  }
1921 
1922  // If the smaller expr is a free variable but the larger one is not, create a
1923  // `expr <= k` upper bound that we can verify once all lower bounds have been
1924  // satisfied. Since we are always picking the smallest width to satisfy all
1925  // `>=` constraints, any `<=` constraints have no effect on the solution
1926  // besides indicating that a width is unsatisfiable.
1927  if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1928  if (imposeUpperBounds || equal) {
1929  [[maybe_unused]] auto *c = solver.addLeqConstraint(smallerVar, larger);
1930  LLVM_DEBUG(llvm::dbgs()
1931  << "Constrained " << *smallerVar << " <= " << *c << "\n");
1932  }
1933  }
1934 }
1935 
1936 /// Assign the constraint expressions of the fields in the `src` argument as the
1937 /// expressions for the `dst` argument. Both fields must be of the given `type`.
1938 void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
1939  // Fast path for `unifyTypes(x, x, _)`.
1940  if (lhs == rhs)
1941  return;
1942 
1943  // Co-iterate the two field refs, recurring into every leaf element and set
1944  // them equal.
1945  auto fieldID = 0;
1946  std::function<void(FIRRTLBaseType)> unify = [&](FIRRTLBaseType type) {
1947  if (type.isGround()) {
1948  // Leaf element, unify the fields!
1949  FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
1950  FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
1951  LLVM_DEBUG(llvm::dbgs()
1952  << "Unify " << getFieldName(lhsFieldRef).first << " = "
1953  << getFieldName(rhsFieldRef).first << "\n");
1954  // Abandon variables becoming unconstrainable by the unification.
1955  if (auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1956  solver.addGeqConstraint(var, solver.known(0));
1957  setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1958  fieldID++;
1959  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1960  fieldID++;
1961  for (auto &element : bundleType) {
1962  unify(element.type);
1963  }
1964  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1965  fieldID++;
1966  auto save = fieldID;
1967  // Skip 0 length vectors.
1968  if (vecType.getNumElements() > 0) {
1969  unify(vecType.getElementType());
1970  }
1971  fieldID = save + vecType.getMaxFieldID();
1972  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1973  fieldID++;
1974  for (auto &element : enumType.getElements())
1975  unify(element.type);
1976  } else {
1977  llvm_unreachable("Unknown type inside a bundle!");
1978  }
1979  };
1980  if (auto ftype = getBaseType(type))
1981  unify(ftype);
1982 }
1983 
1984 /// Get the constraint expression for a value.
1985 Expr *InferenceMapping::getExpr(Value value) const {
1986  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
1987  // A field ID of 0 indicates the entire value.
1988  return getExpr(FieldRef(value, 0));
1989 }
1990 
1991 /// Get the constraint expression for a value.
1992 Expr *InferenceMapping::getExpr(FieldRef fieldRef) const {
1993  auto *expr = getExprOrNull(fieldRef);
1994  assert(expr && "constraint expr should have been constructed for value");
1995  return expr;
1996 }
1997 
1998 Expr *InferenceMapping::getExprOrNull(FieldRef fieldRef) const {
1999  auto it = opExprs.find(fieldRef);
2000  if (it != opExprs.end())
2001  return it->second;
2002  // If we don't have an expression for this fieldRef, it should have a
2003  // constant width.
2004  auto baseType = getBaseType(fieldRef.getValue().getType());
2005  auto type =
2006  hw::FieldIdImpl::getFinalTypeByFieldID(baseType, fieldRef.getFieldID());
2007  auto width = cast<FIRRTLBaseType>(type).getBitWidthOrSentinel();
2008  if (width < 0)
2009  return nullptr;
2010  return solver.known(width);
2011 }
2012 
2013 /// Associate a constraint expression with a value.
2014 void InferenceMapping::setExpr(Value value, Expr *expr) {
2015  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2016  // A field ID of 0 indicates the entire value.
2017  setExpr(FieldRef(value, 0), expr);
2018 }
2019 
2020 /// Associate a constraint expression with a value.
2021 void InferenceMapping::setExpr(FieldRef fieldRef, Expr *expr) {
2022  LLVM_DEBUG({
2023  llvm::dbgs() << "Expr " << *expr << " for " << fieldRef.getValue();
2024  if (fieldRef.getFieldID())
2025  llvm::dbgs() << " '" << getFieldName(fieldRef).first << "'";
2026  auto fieldName = getFieldName(fieldRef);
2027  if (fieldName.second)
2028  llvm::dbgs() << " (\"" << fieldName.first << "\")";
2029  llvm::dbgs() << "\n";
2030  });
2031  opExprs[fieldRef] = expr;
2032 }
2033 
2034 //===----------------------------------------------------------------------===//
2035 // Inference Result Application
2036 //===----------------------------------------------------------------------===//
2037 
2038 namespace {
2039 /// A helper class which maps the types and operations in a design to a set
2040 /// of variables and constraints to be solved later.
2041 class InferenceTypeUpdate {
2042 public:
2043  InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2044 
2045  LogicalResult update(CircuitOp op);
2046  FailureOr<bool> updateOperation(Operation *op);
2047  FailureOr<bool> updateValue(Value value);
2049 
2050 private:
2051  const InferenceMapping &mapping;
2052 };
2053 
2054 } // namespace
2055 
2056 /// Update the types throughout a circuit.
2057 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2058  LLVM_DEBUG({
2059  llvm::dbgs() << "\n";
2060  debugHeader("Update types") << "\n\n";
2061  });
2062  return mlir::failableParallelForEach(
2063  op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2064  // Skip this module if it had no widths to be
2065  // inferred at all.
2066  if (mapping.isModuleSkipped(op))
2067  return success();
2068  auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2069  if (failed(updateOperation(op)))
2070  return WalkResult::interrupt();
2071  return WalkResult::advance();
2072  }).wasInterrupted();
2073  return failure(isFailed);
2074  });
2075 }
2076 
2077 /// Update the result types of an operation.
2078 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2079  bool anyChanged = false;
2080 
2081  for (Value v : op->getResults()) {
2082  auto result = updateValue(v);
2083  if (failed(result))
2084  return result;
2085  anyChanged |= *result;
2086  }
2087 
2088  // If this is a connect operation, width inference might have inferred a RHS
2089  // that is wider than the LHS, in which case an additional BitsPrimOp is
2090  // necessary to truncate the value.
2091  if (auto con = dyn_cast<ConnectOp>(op)) {
2092  auto lhs = con.getDest();
2093  auto rhs = con.getSrc();
2094  auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2095  auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2096 
2097  // Nothing to do if not base types.
2098  if (!lhsType || !rhsType)
2099  return anyChanged;
2100 
2101  auto lhsWidth = lhsType.getBitWidthOrSentinel();
2102  auto rhsWidth = rhsType.getBitWidthOrSentinel();
2103  if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2104  OpBuilder builder(op);
2105  auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2106  rhsWidth - lhsWidth);
2107  if (type_isa<SIntType>(rhsType))
2108  trunc =
2109  builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2110 
2111  LLVM_DEBUG(llvm::dbgs()
2112  << "Truncating RHS to " << lhsType << " in " << con << "\n");
2113  con->replaceUsesOfWith(con.getSrc(), trunc);
2114  }
2115  return anyChanged;
2116  }
2117 
2118  // If this is a module, update its ports.
2119  if (auto module = dyn_cast<FModuleOp>(op)) {
2120  // Update the block argument types.
2121  bool argsChanged = false;
2122  SmallVector<Attribute> argTypes;
2123  argTypes.reserve(module.getNumPorts());
2124  for (auto arg : module.getArguments()) {
2125  auto result = updateValue(arg);
2126  if (failed(result))
2127  return result;
2128  argsChanged |= *result;
2129  argTypes.push_back(TypeAttr::get(arg.getType()));
2130  }
2131 
2132  // Update the module function type if needed.
2133  if (argsChanged) {
2134  module.setPortTypesAttr(ArrayAttr::get(module.getContext(), argTypes));
2135  anyChanged = true;
2136  }
2137  }
2138  return anyChanged;
2139 }
2140 
2141 /// Resize a `uint`, `sint`, or `analog` type to a specific width.
2142 static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth) {
2143  auto *context = type.getContext();
2145  .Case<UIntType>([&](auto type) {
2146  return UIntType::get(context, newWidth, type.isConst());
2147  })
2148  .Case<SIntType>([&](auto type) {
2149  return SIntType::get(context, newWidth, type.isConst());
2150  })
2151  .Case<AnalogType>([&](auto type) {
2152  return AnalogType::get(context, newWidth, type.isConst());
2153  })
2154  .Default([&](auto type) { return type; });
2155 }
2156 
2157 /// Update the type of a value.
2158 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2159  // Check if the value has a type which we can update.
2160  auto type = type_dyn_cast<FIRRTLType>(value.getType());
2161  if (!type)
2162  return false;
2163 
2164  // Fast path for types that have fully inferred widths.
2165  if (!hasUninferredWidth(type))
2166  return false;
2167 
2168  // If this is an operation that does not generate any free variables that
2169  // are determined during width inference, simply update the value type based
2170  // on the operation arguments.
2171  if (auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2172  SmallVector<Type, 2> types;
2173  auto res =
2174  op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2175  op->getAttrDictionary(), op->getPropertiesStorage(),
2176  op->getRegions(), types);
2177  if (failed(res))
2178  return failure();
2179 
2180  assert(types.size() == op->getNumResults());
2181  for (auto [result, type] : llvm::zip(op->getResults(), types)) {
2182  LLVM_DEBUG(llvm::dbgs()
2183  << "Inferring " << result << " as " << type << "\n");
2184  result.setType(type);
2185  }
2186  return true;
2187  }
2188 
2189  // Recreate the type, substituting the solved widths.
2190  auto *context = type.getContext();
2191  unsigned fieldID = 0;
2192  std::function<FIRRTLBaseType(FIRRTLBaseType)> updateBase =
2193  [&](FIRRTLBaseType type) -> FIRRTLBaseType {
2194  auto width = type.getBitWidthOrSentinel();
2195  if (width >= 0) {
2196  // Known width integers return themselves.
2197  fieldID++;
2198  return type;
2199  }
2200  if (width == -1) {
2201  // Unknown width integers return the solved type.
2202  auto newType = updateType(FieldRef(value, fieldID), type);
2203  fieldID++;
2204  return newType;
2205  }
2206  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2207  // Bundle types recursively update all bundle elements.
2208  fieldID++;
2209  llvm::SmallVector<BundleType::BundleElement, 3> elements;
2210  for (auto &element : bundleType) {
2211  auto updatedBase = updateBase(element.type);
2212  if (!updatedBase)
2213  return {};
2214  elements.emplace_back(element.name, element.isFlip, updatedBase);
2215  }
2216  return BundleType::get(context, elements, bundleType.isConst());
2217  }
2218  if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2219  fieldID++;
2220  auto save = fieldID;
2221  // TODO: this should recurse into the element type of 0 length vectors and
2222  // set any unknown width to 0.
2223  if (vecType.getNumElements() > 0) {
2224  auto updatedBase = updateBase(vecType.getElementType());
2225  if (!updatedBase)
2226  return {};
2227  auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2228  vecType.isConst());
2229  fieldID = save + vecType.getMaxFieldID();
2230  return newType;
2231  }
2232  // If this is a 0 length vector return the original type.
2233  return type;
2234  }
2235  if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2236  fieldID++;
2237  llvm::SmallVector<FEnumType::EnumElement> elements;
2238  for (auto &element : enumType.getElements()) {
2239  auto updatedBase = updateBase(element.type);
2240  if (!updatedBase)
2241  return {};
2242  elements.emplace_back(element.name, updatedBase);
2243  }
2244  return FEnumType::get(context, elements, enumType.isConst());
2245  }
2246  llvm_unreachable("Unknown type inside a bundle!");
2247  };
2248 
2249  // Update the type.
2250  auto newType = mapBaseTypeNullable(type, updateBase);
2251  if (!newType)
2252  return failure();
2253  LLVM_DEBUG(llvm::dbgs() << "Update " << value << " to " << newType << "\n");
2254  value.setType(newType);
2255 
2256  // If this is a ConstantOp, adjust the width of the underlying APInt.
2257  // Unsized constants have APInts which are *at least* wide enough to hold
2258  // the value, but may be larger. This can trip up the verifier.
2259  if (auto op = value.getDefiningOp<ConstantOp>()) {
2260  auto k = op.getValue();
2261  auto bitwidth = op.getType().getBitWidthOrSentinel();
2262  if (k.getBitWidth() > unsigned(bitwidth))
2263  k = k.trunc(bitwidth);
2264  op->setAttr("value", IntegerAttr::get(op.getContext(), k));
2265  }
2266 
2267  return newType != type;
2268 }
2269 
2270 /// Update a type.
2272  FIRRTLBaseType type) {
2273  assert(type.isGround() && "Can only pass in ground types.");
2274  auto value = fieldRef.getValue();
2275  // Get the inferred width.
2276  Expr *expr = mapping.getExprOrNull(fieldRef);
2277  if (!expr || !expr->getSolution()) {
2278  // It should not be possible to arrive at an uninferred width at this point.
2279  // In case the constraints are not resolvable, checks before the calls to
2280  // `updateType` must have already caught the issues and aborted the pass
2281  // early. Might turn this into an assert later.
2282  mlir::emitError(value.getLoc(), "width should have been inferred");
2283  return {};
2284  }
2285  int32_t solution = *expr->getSolution();
2286  assert(solution >= 0); // The solver infers variables to be 0 or greater.
2287  return resizeType(type, solution);
2288 }
2289 
2290 //===----------------------------------------------------------------------===//
2291 // Pass Infrastructure
2292 //===----------------------------------------------------------------------===//
2293 
2294 namespace {
2295 class InferWidthsPass
2296  : public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2297  void runOnOperation() override;
2298 };
2299 } // namespace
2300 
2301 void InferWidthsPass::runOnOperation() {
2302  // Collect variables and constraints
2303  ConstraintSolver solver;
2304  InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2305  getAnalysis<hw::InnerSymbolTableCollection>());
2306  if (failed(mapping.map(getOperation())))
2307  return signalPassFailure();
2308 
2309  // fast path if no inferrable widths are around
2310  if (mapping.areAllModulesSkipped())
2311  return markAllAnalysesPreserved();
2312 
2313  // Solve the constraints.
2314  if (failed(solver.solve()))
2315  return signalPassFailure();
2316 
2317  // Update the types with the inferred widths.
2318  if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2319  return signalPassFailure();
2320 }
2321 
2322 std::unique_ptr<mlir::Pass> circt::firrtl::createInferWidthsPass() {
2323  return std::make_unique<InferWidthsPass>();
2324 }
assert(baseType &&"element must be base type")
int32_t width
Definition: FIRRTL.cpp:36
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
Definition: InferWidths.cpp:69
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
#define EXPR_KINDS
#define EXPR_CLASSES
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
Definition: InferWidths.cpp:51
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, std::vector< Frame > &worklist)
Compute the value of a constraint expr.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
Definition: FIRRTLUtils.h:252
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:222
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
llvm::raw_ostream & operator<<(llvm::raw_ostream &os, const InstanceInfo::LatticeValue &value)
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
Definition: FIRRTLTypes.h:368
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
::mlir::Type getFinalTypeByFieldID(Type type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition: Debug.cpp:18
bool operator==(uint64_t a, const FVInt &b)
Definition: FVInt.h:649
Definition: debug.py:1
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition: Utils.h:32
llvm::hash_code hash_value(const T &e)