CIRCT  19.0.0git
InferWidths.cpp
Go to the documentation of this file.
1 //===- InferWidths.cpp - Infer width of types -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the InferWidths pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
15 #include "mlir/Pass/Pass.h"
16 
22 #include "circt/Support/Debug.h"
23 #include "circt/Support/FieldRef.h"
24 #include "mlir/IR/ImplicitLocOpBuilder.h"
25 #include "mlir/IR/Threading.h"
26 #include "llvm/ADT/APSInt.h"
27 #include "llvm/ADT/DenseSet.h"
28 #include "llvm/ADT/GraphTraits.h"
29 #include "llvm/ADT/Hashing.h"
30 #include "llvm/ADT/MapVector.h"
31 #include "llvm/ADT/PostOrderIterator.h"
32 #include "llvm/ADT/SetVector.h"
33 #include "llvm/Support/Debug.h"
34 #include "llvm/Support/ErrorHandling.h"
35 
36 #define DEBUG_TYPE "infer-widths"
37 
38 namespace circt {
39 namespace firrtl {
40 #define GEN_PASS_DEF_INFERWIDTHS
41 #include "circt/Dialect/FIRRTL/Passes.h.inc"
42 } // namespace firrtl
43 } // namespace circt
44 
45 using mlir::InferTypeOpInterface;
46 using mlir::WalkOrder;
47 
48 using namespace circt;
49 using namespace firrtl;
50 
51 //===----------------------------------------------------------------------===//
52 // Helpers
53 //===----------------------------------------------------------------------===//
54 
55 static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t,
56  Twine str) {
57  auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
58  if (!basetype)
59  return;
60  if (!basetype.hasUninferredWidth())
61  return;
62 
63  if (basetype.isGround())
64  diag.attachNote() << "Field: \"" << str << "\"";
65  else if (auto vecType = type_dyn_cast<FVectorType>(basetype))
66  diagnoseUninferredType(diag, vecType.getElementType(), str + "[]");
67  else if (auto bundleType = type_dyn_cast<BundleType>(basetype))
68  for (auto &elem : bundleType.getElements())
69  diagnoseUninferredType(diag, elem.type, str + "." + elem.name.getValue());
70 }
71 
72 /// Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
73 static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type) {
74  uint64_t convertedFieldID = 0;
75 
76  auto curFID = fieldID;
77  Type curFType = type;
78  while (curFID != 0) {
79  auto [child, subID] =
80  hw::FieldIdImpl::getSubTypeByFieldID(curFType, curFID);
81  if (isa<FVectorType>(curFType))
82  convertedFieldID++; // Vector fieldID is 1.
83  else
84  convertedFieldID += curFID - subID; // Add consumed portion.
85  curFID = subID;
86  curFType = child;
87  }
88 
89  return convertedFieldID;
90 }
91 
92 //===----------------------------------------------------------------------===//
93 // Constraint Expressions
94 //===----------------------------------------------------------------------===//
95 
namespace {
// Forward declaration so that the `operator<<` and `hash_value` templates
// below can name `Expr` in their `std::is_base_of` constraints before the
// full definition appears.
struct Expr;
} // namespace
99 
100 /// Allow rvalue refs to `Expr` and subclasses to be printed to streams.
101 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
102  int>::type = 0>
103 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T &e) {
104  e.print(os);
105  return os;
106 }
107 
108 // Allow expression subclasses to be hashed.
109 namespace mlir {
110 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
111  int>::type = 0>
112 inline llvm::hash_code hash_value(const T &e) {
113  return e.hash_value();
114 }
115 } // namespace mlir
116 
117 namespace {
// Single list of all expression kinds. `EXPR_NAMES(x)` pastes the suffix `x`
// onto every kind name: `EXPR_KINDS` (empty suffix) yields the enumerators of
// `Expr::Kind`, and `EXPR_CLASSES` yields the matching `...Expr` struct names
// used for type-switch dispatch in `Expr::print`.
#define EXPR_NAMES(x)                                                          \
  Root##x, Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
#define EXPR_KINDS EXPR_NAMES()
#define EXPR_CLASSES EXPR_NAMES(Expr)
122 
123 /// An expression on the right-hand side of a constraint.
124 struct Expr {
125  enum class Kind { EXPR_KINDS };
126  std::optional<int32_t> solution;
127  Kind kind;
128 
129  /// Print a human-readable representation of this expr.
130  void print(llvm::raw_ostream &os) const;
131 
132 protected:
133  Expr(Kind kind) : kind(kind) {}
134  llvm::hash_code hash_value() const { return llvm::hash_value(kind); }
135 };
136 
/// Helper class to CRTP-derive common functions.
///
/// Tags the object with `DerivedKind` on construction and provides the
/// `classof` hook that `isa`/`dyn_cast` use to recognize `DerivedT` among
/// `Expr` pointers.
template <class DerivedT, Expr::Kind DerivedKind>
struct ExprBase : public Expr {
  ExprBase() : Expr(DerivedKind) {}
  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
  // NOTE(review): this operator looks ill-formed if it is ever instantiated:
  // `static_cast<DerivedT *>(this)` drops the `const` of `this`, and the
  // result of `dyn_cast` is compared without being dereferenced — TODO
  // confirm. The interned kinds in this file (`IdExpr`, `KnownExpr`) declare
  // their own `operator==`, which hides this one. Note that a kind without
  // its own `operator==` would, after a naive fix, recurse back into this
  // operator.
  bool operator==(const Expr &other) const {
    if (auto otherSame = dyn_cast<DerivedT>(other))
      return *static_cast<DerivedT *>(this) == otherSame;
    return false;
  }
};
148 
149 /// The root containing all expressions.
150 struct RootExpr : public ExprBase<RootExpr, Expr::Kind::Root> {
151  RootExpr(std::vector<Expr *> &exprs) : exprs(exprs) {}
152  void print(llvm::raw_ostream &os) const { os << "root"; }
153  std::vector<Expr *> &exprs;
154 };
155 
156 /// A free variable.
157 struct VarExpr : public ExprBase<VarExpr, Expr::Kind::Var> {
158  void print(llvm::raw_ostream &os) const {
159  // Hash the `this` pointer into something somewhat human readable. Since
160  // this is just for debug dumping, we wrap around at 65536 variables.
161  os << "var" << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFFF);
162  }
163 
164  /// The constraint expression this variable is supposed to be greater than or
165  /// equal to. This is not part of the variable's hash and equality property.
166  Expr *constraint = nullptr;
167 
168  /// The upper bound this variable is supposed to be smaller than or equal to.
169  Expr *upperBound = nullptr;
170  std::optional<int32_t> upperBoundSolution;
171 };
172 
173 /// A derived width.
174 ///
175 /// These are generated for `InvalidValueOp`s which want to derived their width
176 /// from connect operations that they are on the right hand side of.
177 struct DerivedExpr : public ExprBase<DerivedExpr, Expr::Kind::Derived> {
178  void print(llvm::raw_ostream &os) const {
179  // Hash the `this` pointer into something somewhat human readable.
180  os << "derive"
181  << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFF);
182  }
183 
184  /// The expression this derived width is equivalent to.
185  Expr *assigned = nullptr;
186 };
187 
188 /// An identity expression.
189 ///
190 /// This expression evaluates to its inner expression. It is used in a very
191 /// specific case of constraints on variables, in order to be able to track
192 /// where the constraint was imposed. Constraints on variables are represented
193 /// as `var >= <expr>`. When the first constraint `a` is imposed, it is stored
194 /// as the constraint expression (`var >= a`). When the second constraint `b` is
195 /// imposed, a *new* max expression is allocated (`var >= max(a, b)`).
196 /// Expressions are annotated with a location when they are created, which in
197 /// this case are connect ops. Since imposing the first constraint does not
198 /// create any new expression, the location information of that connect would be
199 /// lost. With an identity expression, imposing the first constraint becomes
200 /// `var >= identity(a)`, which is a *new* expression and properly tracks the
201 /// location info.
202 struct IdExpr : public ExprBase<IdExpr, Expr::Kind::Id> {
203  IdExpr(Expr *arg) : arg(arg) { assert(arg); }
204  void print(llvm::raw_ostream &os) const { os << "*" << *arg; }
205  bool operator==(const IdExpr &other) const {
206  return kind == other.kind && arg == other.arg;
207  }
208  llvm::hash_code hash_value() const {
209  return llvm::hash_combine(Expr::hash_value(), arg);
210  }
211 
212  /// The inner expression.
213  Expr *const arg;
214 };
215 
216 /// A known constant value.
217 struct KnownExpr : public ExprBase<KnownExpr, Expr::Kind::Known> {
218  KnownExpr(int32_t value) : ExprBase() { solution = value; }
219  void print(llvm::raw_ostream &os) const { os << *solution; }
220  bool operator==(const KnownExpr &other) const {
221  return *solution == *other.solution;
222  }
223  llvm::hash_code hash_value() const {
224  return llvm::hash_combine(Expr::hash_value(), *solution);
225  }
226 };
227 
228 /// A unary expression. Contains the actual data. Concrete subclasses are merely
229 /// there for show and ease of use.
230 struct UnaryExpr : public Expr {
231  bool operator==(const UnaryExpr &other) const {
232  return kind == other.kind && arg == other.arg;
233  }
234  llvm::hash_code hash_value() const {
235  return llvm::hash_combine(Expr::hash_value(), arg);
236  }
237 
238  /// The child expression.
239  Expr *const arg;
240 
241 protected:
242  UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) { assert(arg); }
243 };
244 
245 /// Helper class to CRTP-derive common functions.
246 template <class DerivedT, Expr::Kind DerivedKind>
247 struct UnaryExprBase : public UnaryExpr {
248  template <typename... Args>
249  UnaryExprBase(Args &&...args)
250  : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
251  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
252 };
253 
254 /// A power of two.
255 struct PowExpr : public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
256  using UnaryExprBase::UnaryExprBase;
257  void print(llvm::raw_ostream &os) const { os << "2^" << arg; }
258 };
259 
260 /// A binary expression. Contains the actual data. Concrete subclasses are
261 /// merely there for show and ease of use.
262 struct BinaryExpr : public Expr {
263  bool operator==(const BinaryExpr &other) const {
264  return kind == other.kind && lhs() == other.lhs() && rhs() == other.rhs();
265  }
266  llvm::hash_code hash_value() const {
267  return llvm::hash_combine(Expr::hash_value(), *args);
268  }
269  Expr *lhs() const { return args[0]; }
270  Expr *rhs() const { return args[1]; }
271 
272  /// The child expressions.
273  Expr *const args[2];
274 
275 protected:
276  BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
277  assert(lhs);
278  assert(rhs);
279  }
280 };
281 
282 /// Helper class to CRTP-derive common functions.
283 template <class DerivedT, Expr::Kind DerivedKind>
284 struct BinaryExprBase : public BinaryExpr {
285  template <typename... Args>
286  BinaryExprBase(Args &&...args)
287  : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
288  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
289 };
290 
291 /// An addition.
292 struct AddExpr : public BinaryExprBase<AddExpr, Expr::Kind::Add> {
293  using BinaryExprBase::BinaryExprBase;
294  void print(llvm::raw_ostream &os) const {
295  os << "(" << *lhs() << " + " << *rhs() << ")";
296  }
297 };
298 
299 /// The maximum of two expressions.
300 struct MaxExpr : public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
301  using BinaryExprBase::BinaryExprBase;
302  void print(llvm::raw_ostream &os) const {
303  os << "max(" << *lhs() << ", " << *rhs() << ")";
304  }
305 };
306 
307 /// The minimum of two expressions.
308 struct MinExpr : public BinaryExprBase<MinExpr, Expr::Kind::Min> {
309  using BinaryExprBase::BinaryExprBase;
310  void print(llvm::raw_ostream &os) const {
311  os << "min(" << *lhs() << ", " << *rhs() << ")";
312  }
313 };
314 
315 void Expr::print(llvm::raw_ostream &os) const {
316  TypeSwitch<const Expr *>(this).Case<EXPR_CLASSES>(
317  [&](auto *e) { e->print(os); });
318 }
319 
320 } // namespace
321 
322 //===----------------------------------------------------------------------===//
323 // Fast bump allocator with optional interning
324 //===----------------------------------------------------------------------===//
325 
326 namespace {
327 
/// One entry of an `InternedAllocator`'s lookup set: a non-owning pointer to
/// an interned object.
template <typename T>
struct InternedSlot {
  T *ptr;
  InternedSlot(T *p) : ptr(p) {}
};
334 
335 /// A simple bump allocator that ensures only ever one copy per object exists.
336 /// The allocated objects must not have a destructor.
337 template <typename T, typename std::enable_if_t<
338  std::is_trivially_destructible<T>::value, int> = 0>
339 class InternedAllocator {
340  using Slot = InternedSlot<T>;
341  llvm::DenseSet<Slot> interned;
342  llvm::BumpPtrAllocator &allocator;
343 
344 public:
345  InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
346 
347  /// Allocate a new object if it does not yet exist, or return a pointer to the
348  /// existing one. `R` is the type of the object to be allocated. `R` must be
349  /// derived from or be the type `T`.
350  template <typename R = T, typename... Args>
351  std::pair<R *, bool> alloc(Args &&...args) {
352  auto stack_value = R(std::forward<Args>(args)...);
353  auto stack_slot = Slot(&stack_value);
354  auto it = interned.find(stack_slot);
355  if (it != interned.end())
356  return std::make_pair(static_cast<R *>(it->ptr), false);
357  auto heap_value = new (allocator) R(std::move(stack_value));
358  interned.insert(Slot(heap_value));
359  return std::make_pair(heap_value, true);
360  }
361 };
362 
363 /// A simple bump allocator. The allocated objects must not have a destructor.
364 /// This allocator is mainly there for symmetry with the `InternedAllocator`.
365 template <typename T, typename std::enable_if_t<
366  std::is_trivially_destructible<T>::value, int> = 0>
367 class Allocator {
368  llvm::BumpPtrAllocator &allocator;
369 
370 public:
371  Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
372 
373  /// Allocate a new object. `R` is the type of the object to be allocated. `R`
374  /// must be derived from or be the type `T`.
375  template <typename R = T, typename... Args>
376  R *alloc(Args &&...args) {
377  return new (allocator) R(std::forward<Args>(args)...);
378  }
379 };
380 
381 } // namespace
382 
383 //===----------------------------------------------------------------------===//
384 // Constraint Solver
385 //===----------------------------------------------------------------------===//
386 
387 namespace {
/// A canonicalized linear inequality that maps a constraint on var `x` to the
/// linear inequality `x >= max(a*x+b, c) + (failed ? ∞ : 0)`.
///
/// The inequality separately tracks recursive (a, b) and non-recursive (c)
/// constraints on `x`. This allows it to properly identify the combination of
/// the two constraints `x >= x-1` and `x >= 4` to be satisfiable as
/// `x >= max(x-1, 4)`. If it only tracked inequality as `x >= a*x+b`, the
/// combination of these two constraints would be `x >= x+4` (due to max(-1,4) =
/// 4), which would be unsatisfiable.
///
/// The `failed` flag acts as an additional `∞` term that renders the inequality
/// unsatisfiable. It is used as a tombstone value in case an operation renders
/// the inequality unsatisfiable (e.g. `x >= 2**x` would be represented as the
/// inequality `x >= ∞`).
///
/// Inequalities represented in this form can easily be checked for
/// unsatisfiability in the presence of recursion by inspecting the coefficients
/// a and b. The `sat` function performs this action.
///
/// The constructors keep `rec_bias` at zero whenever `rec_scale` is zero: a
/// zero scale folds the bias into `nonrec_bias` instead.
///
/// NOTE(review): biases are combined with plain `int32_t` additions in `add`;
/// very large widths could overflow — confirm widths stay well below
/// INT32_MAX.
struct LinIneq {
  // x >= max(a*x+b, c) + (failed ? ∞ : 0)
  int32_t rec_scale = 0;   // a
  int32_t rec_bias = 0;    // b
  int32_t nonrec_bias = 0; // c
  bool failed = false;

  /// Create a new unsatisfiable inequality `x >= ∞`.
  static LinIneq unsat() { return LinIneq(true); }

  /// Create a new inequality `x >= (failed ? ∞ : 0)`.
  explicit LinIneq(bool failed = false) : failed(failed) {}

  /// Create a new inequality `x >= bias`.
  explicit LinIneq(int32_t bias) : nonrec_bias(bias) {}

  /// Create a new inequality `x >= scale*x+bias`. A zero `scale` yields the
  /// non-recursive inequality `x >= bias`.
  explicit LinIneq(int32_t scale, int32_t bias) {
    if (scale != 0) {
      rec_scale = scale;
      rec_bias = bias;
    } else {
      nonrec_bias = bias;
    }
  }

  /// Create a new inequality `x >= max(rec_scale*x+rec_bias, nonrec_bias) +
  /// (failed ? ∞ : 0)`. A zero `rec_scale` folds `rec_bias` into the
  /// non-recursive term as `max(rec_bias, nonrec_bias)`.
  explicit LinIneq(int32_t rec_scale, int32_t rec_bias, int32_t nonrec_bias,
                   bool failed = false)
      : failed(failed) {
    if (rec_scale != 0) {
      this->rec_scale = rec_scale;
      this->rec_bias = rec_bias;
      this->nonrec_bias = nonrec_bias;
    } else {
      this->nonrec_bias = std::max(rec_bias, nonrec_bias);
    }
  }

  /// Combine two inequalities by taking the maxima of corresponding
  /// coefficients.
  ///
  /// This essentially combines `x >= max(a1*x+b1, c1)` and `x >= max(a2*x+b2,
  /// c2)` into a new `x >= max(max(a1,a2)*x+max(b1,b2), max(c1,c2))`. This is
  /// a pessimistic upper bound, since e.g. `x >= 2x-10` and `x >= x-5` may both
  /// hold, but the resulting `x >= 2x-5` may pessimistically not hold.
  static LinIneq max(const LinIneq &lhs, const LinIneq &rhs) {
    return LinIneq(std::max(lhs.rec_scale, rhs.rec_scale),
                   std::max(lhs.rec_bias, rhs.rec_bias),
                   std::max(lhs.nonrec_bias, rhs.nonrec_bias),
                   lhs.failed || rhs.failed);
  }

  /// Combine two inequalities by summing up the two right hand sides.
  ///
  /// This is a tricky one, since the addition of the two max terms will lead to
  /// a maximum over four possible terms (similar to a binomial expansion). In
  /// order to shoehorn this back into a two-term maximum, we have to pick the
  /// recursive term that will grow the fastest.
  ///
  /// As an example for this problem, consider the following addition:
  ///
  ///   x >= max(a1*x+b1, c1) + max(a2*x+b2, c2)
  ///
  /// We would like to expand and rearrange this again into a maximum:
  ///
  ///   x >= max(a1*x+b1 + max(a2*x+b2, c2), c1 + max(a2*x+b2, c2))
  ///   x >= max(max(a1*x+b1 + a2*x+b2, a1*x+b1 + c2),
  ///            max(c1 + a2*x+b2, c1 + c2))
  ///   x >= max((a1+a2)*x+(b1+b2), a1*x+(b1+c2), a2*x+(b2+c1), c1+c2)
  ///
  /// Since we are combining two two-term maxima, there are four possible ways
  /// how the terms can combine, leading to the above four-term maximum. An easy
  /// upper bound of the form we want would be the following:
  ///
  ///   x >= max(max(a1+a2, a1, a2)*x + max(b1+b2, b1+c2, b2+c1), c1+c2)
  ///
  /// However, this is a very pessimistic upper-bound that will declare very
  /// common patterns in the IR as unbreakable cycles, despite them being very
  /// much breakable. For example (a1=1, b1=0, c1=42, a2=0, b2=0, c2=-3):
  ///
  ///   x >= max(x, 42) + max(0, -3)                        <-- breakable recursion
  ///   x >= max(max(1+0, 1, 0)*x + max(0+0, 0-3, 0+42), 42-3)
  ///   x >= max(x + 42, 39)                                <-- unbreakable recursion!
  ///
  /// A better approach is to take the expanded four-term maximum, retain the
  /// non-recursive term (c1+c2), and estimate which one of the recursive terms
  /// (first three) will become dominant as we choose greater values for x.
  /// Since x never is inferred to be negative, the recursive term in the
  /// maximum with the highest scaling factor for x will end up dominating as
  /// x tends to ∞:
  ///
  ///   x >= max({
  ///     (a1+a2)*x+(b1+b2) if a1+a2 >= max(a1+a2, a1, a2) and a1>0 and a2>0,
  ///     a1*x+(b1+c2)      if    a1 >= max(a1+a2, a1, a2) and a1>0,
  ///     a2*x+(b2+c1)      if    a2 >= max(a1+a2, a1, a2) and a2>0,
  ///     0                 otherwise
  ///   }, c1+c2)
  ///
  /// In case multiple cases apply, the highest bias of the recursive term is
  /// picked. With this, the above problematic example triggers the second case
  /// and becomes:
  ///
  ///   x >= max(1*x+(0-3), 42-3) = max(x-3, 39)
  ///
  /// Only the second case applies here: a2 == 0 disables the first and third.
  static LinIneq add(const LinIneq &lhs, const LinIneq &rhs) {
    // Determine the maximum scaling factor among the three possible recursive
    // terms.
    auto enable1 = lhs.rec_scale > 0 && rhs.rec_scale > 0;
    auto enable2 = lhs.rec_scale > 0;
    auto enable3 = rhs.rec_scale > 0;
    auto scale1 = lhs.rec_scale + rhs.rec_scale; // (a1+a2)
    auto scale2 = lhs.rec_scale;                 // a1
    auto scale3 = rhs.rec_scale;                 // a2
    auto bias1 = lhs.rec_bias + rhs.rec_bias;    // (b1+b2)
    auto bias2 = lhs.rec_bias + rhs.nonrec_bias; // (b1+c2)
    auto bias3 = rhs.rec_bias + lhs.nonrec_bias; // (b2+c1)
    auto maxScale = std::max(scale1, std::max(scale2, scale3));

    // Among those terms that have a maximum scaling factor, determine the
    // largest bias value.
    std::optional<int32_t> maxBias;
    if (enable1 && scale1 == maxScale)
      maxBias = bias1;
    if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
      maxBias = bias2;
    if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
      maxBias = bias3;

    // Pick from the recursive terms the one with maximum scaling factor and
    // maximum bias value. Each `*maxBias` below is only reached when the same
    // enable/scale condition that set `maxBias` above holds.
    auto nonrec_bias = lhs.nonrec_bias + rhs.nonrec_bias; // c1+c2
    auto failed = lhs.failed || rhs.failed;
    if (enable1 && scale1 == maxScale && bias1 == *maxBias)
      return LinIneq(scale1, bias1, nonrec_bias, failed);
    if (enable2 && scale2 == maxScale && bias2 == *maxBias)
      return LinIneq(scale2, bias2, nonrec_bias, failed);
    if (enable3 && scale3 == maxScale && bias3 == *maxBias)
      return LinIneq(scale3, bias3, nonrec_bias, failed);
    return LinIneq(0, 0, nonrec_bias, failed);
  }

  /// Check if the inequality is satisfiable.
  ///
  /// The inequality becomes unsatisfiable if the RHS is ∞, or a>1, or a==1 and
  /// b > 0 (`x >= x + b` with positive b). Otherwise there exists a solution
  /// for `x` that satisfies the inequality.
  bool sat() const {
    if (failed)
      return false;
    if (rec_scale > 1)
      return false;
    if (rec_scale == 1 && rec_bias > 0)
      return false;
    return true;
  }

  /// Dump the inequality in human-readable form, e.g. `x >= max(2*x - 3, 5)`.
  /// Zero terms are omitted, and an inequality with no terms prints as
  /// `x >= 0`.
  void print(llvm::raw_ostream &os) const {
    bool any = false;
    bool both = (rec_scale != 0 || rec_bias != 0) && nonrec_bias != 0;
    os << "x >= ";
    if (both)
      os << "max(";
    if (rec_scale != 0) {
      any = true;
      if (rec_scale != 1)
        os << rec_scale << "*";
      os << "x";
    }
    if (rec_bias != 0) {
      if (any) {
        if (rec_bias < 0)
          os << " - " << -rec_bias;
        else
          os << " + " << rec_bias;
      } else {
        any = true;
        os << rec_bias;
      }
    }
    if (both)
      os << ", ";
    if (nonrec_bias != 0) {
      any = true;
      os << nonrec_bias;
    }
    if (both)
      os << ")";
    if (failed) {
      if (any)
        os << " + ";
      os << "∞";
    }
    if (!any)
      os << "0";
  }
};
606 
607 /// A simple solver for width constraints.
608 class ConstraintSolver {
609 public:
610  ConstraintSolver() = default;
611 
612  VarExpr *var() {
613  auto v = vars.alloc();
614  exprs.push_back(v);
615  if (currentInfo)
616  info[v].insert(currentInfo);
617  if (currentLoc)
618  locs[v].insert(*currentLoc);
619  return v;
620  }
621  DerivedExpr *derived() {
622  auto *d = derivs.alloc();
623  exprs.push_back(d);
624  return d;
625  }
626  KnownExpr *known(int32_t value) { return alloc<KnownExpr>(knowns, value); }
627  IdExpr *id(Expr *arg) { return alloc<IdExpr>(ids, arg); }
628  PowExpr *pow(Expr *arg) { return alloc<PowExpr>(uns, arg); }
629  AddExpr *add(Expr *lhs, Expr *rhs) { return alloc<AddExpr>(bins, lhs, rhs); }
630  MaxExpr *max(Expr *lhs, Expr *rhs) { return alloc<MaxExpr>(bins, lhs, rhs); }
631  MinExpr *min(Expr *lhs, Expr *rhs) { return alloc<MinExpr>(bins, lhs, rhs); }
632 
633  /// Add a constraint `lhs >= rhs`. Multiple constraints on the same variable
634  /// are coalesced into a `max(a, b)` expr.
635  Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
636  if (lhs->constraint)
637  lhs->constraint = max(lhs->constraint, rhs);
638  else
639  lhs->constraint = id(rhs);
640  return lhs->constraint;
641  }
642 
643  /// Add a constraint `lhs <= rhs`. Multiple constraints on the same variable
644  /// are coalesced into a `min(a, b)` expr.
645  Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
646  if (lhs->upperBound)
647  lhs->upperBound = min(lhs->upperBound, rhs);
648  else
649  lhs->upperBound = id(rhs);
650  return lhs->upperBound;
651  }
652 
653  void dumpConstraints(llvm::raw_ostream &os);
654  LogicalResult solve();
655 
656  using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
657  const ContextInfo &getContextInfo() const { return info; }
658  void setCurrentContextInfo(FieldRef fieldRef) { currentInfo = fieldRef; }
659  void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
660 
661 private:
662  // Allocator for constraint expressions.
663  llvm::BumpPtrAllocator allocator;
664  Allocator<VarExpr> vars = {allocator};
665  Allocator<DerivedExpr> derivs = {allocator};
666  InternedAllocator<KnownExpr> knowns = {allocator};
667  InternedAllocator<IdExpr> ids = {allocator};
668  InternedAllocator<UnaryExpr> uns = {allocator};
669  InternedAllocator<BinaryExpr> bins = {allocator};
670 
671  /// A list of expressions in the order they were created.
672  std::vector<Expr *> exprs;
673  RootExpr root = {exprs};
674 
675  /// Add an allocated expression to the list above.
676  template <typename R, typename T, typename... Args>
677  R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
678  auto it = allocator.template alloc<R>(std::forward<Args>(args)...);
679  if (it.second)
680  exprs.push_back(it.first);
681  if (currentInfo)
682  info[it.first].insert(currentInfo);
683  if (currentLoc)
684  locs[it.first].insert(*currentLoc);
685  return it.first;
686  }
687 
688  /// Contextual information for each expression, indicating which values in the
689  /// IR lead to this expression.
690  ContextInfo info;
691  FieldRef currentInfo = {};
692  DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
693  std::optional<Location> currentLoc;
694 
695  // Forbid copyign or moving the solver, which would invalidate the refs to
696  // allocator held by the allocators.
697  ConstraintSolver(ConstraintSolver &&) = delete;
698  ConstraintSolver(const ConstraintSolver &) = delete;
699  ConstraintSolver &operator=(ConstraintSolver &&) = delete;
700  ConstraintSolver &operator=(const ConstraintSolver &) = delete;
701 
702  void emitUninferredWidthError(VarExpr *var);
703 
704  LinIneq checkCycles(VarExpr *var, Expr *expr,
705  SmallPtrSetImpl<Expr *> &seenVars,
706  InFlightDiagnostic *reportInto = nullptr,
707  unsigned indent = 1);
708 };
709 
710 } // namespace
711 
712 /// Print all constraints in the solver to an output stream.
713 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
714  for (auto *e : exprs) {
715  if (auto *v = dyn_cast<VarExpr>(e)) {
716  if (v->constraint)
717  os << "- " << *v << " >= " << *v->constraint << "\n";
718  else
719  os << "- " << *v << " unconstrained\n";
720  }
721  }
722 }
723 
724 #ifndef NDEBUG
725 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const LinIneq &l) {
726  l.print(os);
727  return os;
728 }
729 #endif
730 
731 /// Compute the canonicalized linear inequality expression starting at `expr`,
732 /// for the `var` as the left hand side `x` of the inequality. `seenVars` is
733 /// used as a recursion breaker. Occurrences of `var` itself within the
734 /// expression are mapped to the `a` coefficient in the inequality. Any other
735 /// variables are substituted and, in the presence of a recursion in a variable
736 /// other than `var`, treated as zero. `info` is a mapping from constraint
737 /// expressions to values and operations that produced the expression, and is
738 /// used during error reporting. If `reportInto` is present, the function will
739 /// additionally attach unsatisfiable inequalities as notes to the diagnostic as
740 /// it encounters them.
741 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
742  SmallPtrSetImpl<Expr *> &seenVars,
743  InFlightDiagnostic *reportInto,
744  unsigned indent) {
745  auto ineq =
746  TypeSwitch<Expr *, LinIneq>(expr)
747  .Case<KnownExpr>([&](auto *expr) { return LinIneq(*expr->solution); })
748  .Case<VarExpr>([&](auto *expr) {
749  if (expr == var)
750  return LinIneq(1, 0); // x >= 1*x + 0
751  if (!seenVars.insert(expr).second)
752  // Count recursions in other variables as 0. This is sane
753  // since the cycle is either breakable, in which case the
754  // recursion does not modify the resulting value of the
755  // variable, or it is not breakable and will be caught by
756  // this very function once it is called on that variable.
757  return LinIneq(0);
758  if (!expr->constraint)
759  // Count unconstrained variables as `x >= 0`.
760  return LinIneq(0);
761  auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
762  indent + 1);
763  seenVars.erase(expr);
764  return l;
765  })
766  .Case<IdExpr>([&](auto *expr) {
767  return checkCycles(var, expr->arg, seenVars, reportInto,
768  indent + 1);
769  })
770  .Case<PowExpr>([&](auto *expr) {
771  // If we can evaluate `2**arg` to a sensible constant, do
772  // so. This is the case if a == 0 and c < 31 such that 2**c is
773  // representable.
774  auto arg =
775  checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
776  if (arg.rec_scale != 0 || arg.nonrec_bias < 0 ||
777  arg.nonrec_bias >= 31)
778  return LinIneq::unsat();
779  return LinIneq(1 << arg.nonrec_bias); // x >= 2**arg
780  })
781  .Case<AddExpr>([&](auto *expr) {
782  return LinIneq::add(
783  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
784  checkCycles(var, expr->rhs(), seenVars, reportInto,
785  indent + 1));
786  })
787  .Case<MaxExpr, MinExpr>([&](auto *expr) {
788  // Combine the inequalities of the LHS and RHS into a single overly
789  // pessimistic inequality. We treat `MinExpr` the same as `MaxExpr`,
790  // since `max(a,b)` is an upper bound to `min(a,b)`.
791  return LinIneq::max(
792  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
793  checkCycles(var, expr->rhs(), seenVars, reportInto,
794  indent + 1));
795  })
796  .Default([](auto) { return LinIneq::unsat(); });
797 
798  // If we were passed an in-flight diagnostic and the current inequality is
799  // unsatisfiable, attach notes to the diagnostic indicating the values or
800  // operations that contributed to this part of the constraint expression.
801  if (reportInto && !ineq.sat()) {
802  auto report = [&](Location loc) {
803  auto &note = reportInto->attachNote(loc);
804  note << "constrained width W >= ";
805  if (ineq.rec_scale == -1)
806  note << "-";
807  if (ineq.rec_scale != 1)
808  note << ineq.rec_scale;
809  note << "W";
810  if (ineq.rec_bias < 0)
811  note << "-" << -ineq.rec_bias;
812  if (ineq.rec_bias > 0)
813  note << "+" << ineq.rec_bias;
814  note << " here:";
815  };
816  auto it = locs.find(expr);
817  if (it != locs.end())
818  for (auto loc : it->second)
819  report(loc);
820  }
821  if (!reportInto)
822  LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
823  << "- Visited " << *expr << ": " << ineq << "\n");
824 
825  return ineq;
826 }
827 
/// The result of solving an expression: the computed value, if any, paired
/// with a flag that is set when a recursion was detected (and broken) while
/// computing it.
using ExprSolution = std::pair<std::optional<int32_t>, bool>;
829 
830 static ExprSolution
831 computeUnary(ExprSolution arg, llvm::function_ref<int32_t(int32_t)> operation) {
832  if (arg.first)
833  arg.first = operation(*arg.first);
834  return arg;
835 }
836 
837 static ExprSolution
839  llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
840  auto result = ExprSolution{std::nullopt, lhs.second || rhs.second};
841  if (lhs.first && rhs.first)
842  result.first = operation(*lhs.first, *rhs.first);
843  else if (lhs.first)
844  result.first = lhs.first;
845  else if (rhs.first)
846  result.first = rhs.first;
847  return result;
848 }
849 
850 /// Compute the value of a constraint `expr`. `seenVars` is used as a recursion
851 /// breaker. Recursive variables are treated as zero. Returns the computed value
852 /// and a boolean indicating whether a recursion was detected. This may be used
853 /// to memoize the result of expressions in case they were not involved in a
854 /// cycle (which may alter their value from the perspective of a variable).
855 static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
856  unsigned defaultWorklistSize) {
857 
858  struct Frame {
859  Expr *expr;
860  unsigned indent;
861  };
862 
863  // indent only used for debug logs.
864  unsigned indent = 1;
865  std::vector<Frame> worklist({{expr, indent}});
866  llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
867  // Reserving the vector size, to avoid frequent reallocs. The worklist can be
868  // quite large.
869  worklist.reserve(defaultWorklistSize);
870 
871  while (!worklist.empty()) {
872  auto &frame = worklist.back();
873  auto setSolution = [&](ExprSolution solution) {
874  // Memoize the result.
875  if (solution.first && !solution.second)
876  frame.expr->solution = *solution.first;
877  solvedExprs[frame.expr] = solution;
878 
879  // Produce some useful debug prints.
880  LLVM_DEBUG({
881  if (!isa<KnownExpr>(frame.expr)) {
882  if (solution.first)
883  llvm::dbgs().indent(frame.indent * 2)
884  << "= Solved " << *frame.expr << " = " << *solution.first;
885  else
886  llvm::dbgs().indent(frame.indent * 2)
887  << "= Skipped " << *frame.expr;
888  llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
889  << ")\n";
890  }
891  });
892 
893  worklist.pop_back();
894  };
895 
896  // See if we have a memoized result we can return.
897  if (frame.expr->solution) {
898  LLVM_DEBUG({
899  if (!isa<KnownExpr>(frame.expr))
900  llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
901  << *frame.expr->solution << "\n";
902  });
903  setSolution(ExprSolution{*frame.expr->solution, false});
904  continue;
905  }
906 
907  // Otherwise compute the value of the expression.
908  LLVM_DEBUG({
909  if (!isa<KnownExpr>(frame.expr))
910  llvm::dbgs().indent(frame.indent * 2)
911  << "- Solving " << *frame.expr << "\n";
912  });
913 
914  TypeSwitch<Expr *>(frame.expr)
915  .Case<KnownExpr>([&](auto *expr) {
916  setSolution(ExprSolution{*expr->solution, false});
917  })
918  .Case<VarExpr>([&](auto *expr) {
919  if (solvedExprs.contains(expr->constraint)) {
920  auto solution = solvedExprs[expr->constraint];
921  // If we've solved the upper bound already, store the solution.
922  // This will be explicitly solved for later if not computed as
923  // part of the solving that resolved this constraint.
924  // This should only happen if somehow the constraint is
925  // solved before visiting this expression, so that our upperBound
926  // was not added to the worklist such that it was handled first.
927  if (expr->upperBound && solvedExprs.contains(expr->upperBound))
928  expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
929  seenVars.erase(expr);
930  // Constrain variables >= 0.
931  if (solution.first && *solution.first < 0)
932  solution.first = 0;
933  return setSolution(solution);
934  }
935 
936  // Unconstrained variables produce no solution.
937  if (!expr->constraint)
938  return setSolution(ExprSolution{std::nullopt, false});
939  // Return no solution for recursions in the variables. This is sane
940  // and will cause the expression to be ignored when computing the
941  // parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
942  if (!seenVars.insert(expr).second)
943  return setSolution(ExprSolution{std::nullopt, true});
944 
945  worklist.push_back({expr->constraint, indent + 1});
946  if (expr->upperBound)
947  worklist.push_back({expr->upperBound, indent + 1});
948  })
949  .Case<IdExpr>([&](auto *expr) {
950  if (solvedExprs.contains(expr->arg))
951  return setSolution(solvedExprs[expr->arg]);
952  worklist.push_back({expr->arg, indent + 1});
953  })
954  .Case<PowExpr>([&](auto *expr) {
955  if (solvedExprs.contains(expr->arg))
956  return setSolution(computeUnary(
957  solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
958 
959  worklist.push_back({expr->arg, indent + 1});
960  })
961  .Case<AddExpr>([&](auto *expr) {
962  if (solvedExprs.contains(expr->lhs()) &&
963  solvedExprs.contains(expr->rhs()))
964  return setSolution(computeBinary(
965  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
966  [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
967 
968  worklist.push_back({expr->lhs(), indent + 1});
969  worklist.push_back({expr->rhs(), indent + 1});
970  })
971  .Case<MaxExpr>([&](auto *expr) {
972  if (solvedExprs.contains(expr->lhs()) &&
973  solvedExprs.contains(expr->rhs()))
974  return setSolution(computeBinary(
975  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
976  [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
977 
978  worklist.push_back({expr->lhs(), indent + 1});
979  worklist.push_back({expr->rhs(), indent + 1});
980  })
981  .Case<MinExpr>([&](auto *expr) {
982  if (solvedExprs.contains(expr->lhs()) &&
983  solvedExprs.contains(expr->rhs()))
984  return setSolution(computeBinary(
985  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
986  [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
987 
988  worklist.push_back({expr->lhs(), indent + 1});
989  worklist.push_back({expr->rhs(), indent + 1});
990  })
991  .Default([&](auto) {
992  setSolution(ExprSolution{std::nullopt, false});
993  });
994  }
995 
996  return solvedExprs[expr];
997 }
998 
999 /// Solve the constraint problem. This is a very simple implementation that
1000 /// does not fully solve the problem if there are weird dependency cycles
1001 /// present.
1002 LogicalResult ConstraintSolver::solve() {
1003  LLVM_DEBUG({
1004  llvm::dbgs() << "\n";
1005  debugHeader("Constraints") << "\n\n";
1006  dumpConstraints(llvm::dbgs());
1007  });
1008 
1009  // Ensure that there are no adverse cycles around.
1010  LLVM_DEBUG({
1011  llvm::dbgs() << "\n";
1012  debugHeader("Checking for unbreakable loops") << "\n\n";
1013  });
1014  SmallPtrSet<Expr *, 16> seenVars;
1015  bool anyFailed = false;
1016 
1017  for (auto *expr : exprs) {
1018  // Only work on variables.
1019  auto *var = dyn_cast<VarExpr>(expr);
1020  if (!var || !var->constraint)
1021  continue;
1022  LLVM_DEBUG(llvm::dbgs()
1023  << "- Checking " << *var << " >= " << *var->constraint << "\n");
1024 
1025  // Canonicalize the variable's constraint expression into a form that allows
1026  // us to easily determine if any recursion leads to an unsatisfiable
1027  // constraint. The `seenVars` set acts as a recursion breaker.
1028  seenVars.insert(var);
1029  auto ineq = checkCycles(var, var->constraint, seenVars);
1030  seenVars.clear();
1031 
1032  // If the constraint is satisfiable, we're done.
1033  // TODO: It's possible that this result is already sufficient to arrive at a
1034  // solution for the constraint, and the second pass further down is not
1035  // necessary. This would require more proper handling of `MinExpr` in the
1036  // cycle checking code.
1037  if (ineq.sat()) {
1038  LLVM_DEBUG(llvm::dbgs()
1039  << " = Breakable since " << ineq << " satisfiable\n");
1040  continue;
1041  }
1042 
1043  // If we arrive here, the constraint is not satisfiable at all. To provide
1044  // some guidance to the user, we call the cycle checking code again, but
1045  // this time with an in-flight diagnostic to attach notes indicating
1046  // unsatisfiable paths in the cycle.
1047  LLVM_DEBUG(llvm::dbgs()
1048  << " = UNBREAKABLE since " << ineq << " unsatisfiable\n");
1049  anyFailed = true;
1050  for (auto fieldRef : info.find(var)->second) {
1051  // Depending on whether this value stems from an operation or not, create
1052  // an appropriate diagnostic identifying the value.
1053  auto op = fieldRef.getDefiningOp();
1054  auto diag = op ? op->emitOpError()
1055  : mlir::emitError(fieldRef.getValue().getLoc())
1056  << "value ";
1057  diag << "is constrained to be wider than itself";
1058 
1059  // Re-run the cycle checking, but this time reporting into the diagnostic.
1060  seenVars.insert(var);
1061  checkCycles(var, var->constraint, seenVars, &diag);
1062  seenVars.clear();
1063  }
1064  }
1065 
1066  // If there were cycles, return now to avoid complaining to the user about
1067  // dependent widths not being inferred.
1068  if (anyFailed)
1069  return failure();
1070 
1071  // Iterate over the constraint variables and solve each.
1072  LLVM_DEBUG({
1073  llvm::dbgs() << "\n";
1074  debugHeader("Solving constraints") << "\n\n";
1075  });
1076  unsigned defaultWorklistSize = exprs.size() / 2;
1077  for (auto *expr : exprs) {
1078  // Only work on variables.
1079  auto *var = dyn_cast<VarExpr>(expr);
1080  if (!var)
1081  continue;
1082 
1083  // Complain about unconstrained variables.
1084  if (!var->constraint) {
1085  LLVM_DEBUG(llvm::dbgs() << "- Unconstrained " << *var << "\n");
1086  emitUninferredWidthError(var);
1087  anyFailed = true;
1088  continue;
1089  }
1090 
1091  // Compute the value for the variable.
1092  LLVM_DEBUG(llvm::dbgs()
1093  << "- Solving " << *var << " >= " << *var->constraint << "\n");
1094  seenVars.insert(var);
1095  auto solution = solveExpr(var->constraint, seenVars, defaultWorklistSize);
1096  // Compute the upperBound if there is one and haven't already.
1097  if (var->upperBound && !var->upperBoundSolution)
1098  var->upperBoundSolution =
1099  solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
1100  seenVars.clear();
1101 
1102  // Constrain variables >= 0.
1103  if (solution.first && *solution.first < 0)
1104  solution.first = 0;
1105  var->solution = solution.first;
1106 
1107  // In case the width could not be inferred, complain to the user. This might
1108  // be the case if the width depends on an unconstrained variable.
1109  if (!solution.first) {
1110  LLVM_DEBUG(llvm::dbgs() << " - UNSOLVED " << *var << "\n");
1111  emitUninferredWidthError(var);
1112  anyFailed = true;
1113  continue;
1114  }
1115  LLVM_DEBUG(llvm::dbgs()
1116  << " = Solved " << *var << " = " << solution.first << " ("
1117  << (solution.second ? "cycle broken" : "unique") << ")\n");
1118 
1119  // Check if the solution we have found violates an upper bound.
1120  if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1121  LLVM_DEBUG(llvm::dbgs() << " ! Unsatisfiable " << *var
1122  << " <= " << var->upperBoundSolution << "\n");
1123  emitUninferredWidthError(var);
1124  anyFailed = true;
1125  }
1126  }
1127 
1128  // Copy over derived widths.
1129  for (auto *expr : exprs) {
1130  // Only work on derived values.
1131  auto *derived = dyn_cast<DerivedExpr>(expr);
1132  if (!derived)
1133  continue;
1134 
1135  auto *assigned = derived->assigned;
1136  if (!assigned || !assigned->solution) {
1137  LLVM_DEBUG(llvm::dbgs() << "- Unused " << *derived << " set to 0\n");
1138  derived->solution = 0;
1139  } else {
1140  LLVM_DEBUG(llvm::dbgs() << "- Deriving " << *derived << " = "
1141  << assigned->solution << "\n");
1142  derived->solution = *assigned->solution;
1143  }
1144  }
1145 
1146  return failure(anyFailed);
1147 }
1148 
1149 // Emits the diagnostic to inform the user about an uninferred width in the
1150 // design. Returns true if an error was reported, false otherwise.
1151 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1152  FieldRef fieldRef = info.find(var)->second.back();
1153  Value value = fieldRef.getValue();
1154 
1155  auto diag = mlir::emitError(value.getLoc(), "uninferred width:");
1156 
1157  // Try to hint the user at what kind of node this is.
1158  if (isa<BlockArgument>(value)) {
1159  diag << " port";
1160  } else if (auto op = value.getDefiningOp()) {
1161  TypeSwitch<Operation *>(op)
1162  .Case<WireOp>([&](auto) { diag << " wire"; })
1163  .Case<RegOp, RegResetOp>([&](auto) { diag << " reg"; })
1164  .Case<NodeOp>([&](auto) { diag << " node"; })
1165  .Default([&](auto) { diag << " value"; });
1166  } else {
1167  diag << " value";
1168  }
1169 
1170  // Actually print what the user can refer to.
1171  auto [fieldName, rootKnown] = getFieldName(fieldRef);
1172  if (!fieldName.empty()) {
1173  if (!rootKnown)
1174  diag << " field";
1175  diag << " \"" << fieldName << "\"";
1176  }
1177 
1178  if (!var->constraint) {
1179  diag << " is unconstrained";
1180  } else if (var->solution && var->upperBoundSolution &&
1181  var->solution > var->upperBoundSolution) {
1182  diag << " cannot satisfy all width requirements";
1183  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1184  LLVM_DEBUG(llvm::dbgs() << *var->upperBound << "\n");
1185  auto loc = locs.find(var->constraint)->second.back();
1186  diag.attachNote(loc) << "width is constrained to be at least "
1187  << *var->solution << " here:";
1188  loc = locs.find(var->upperBound)->second.back();
1189  diag.attachNote(loc) << "width is constrained to be at most "
1190  << *var->upperBoundSolution << " here:";
1191  } else {
1192  diag << " width cannot be determined";
1193  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1194  auto loc = locs.find(var->constraint)->second.back();
1195  diag.attachNote(loc) << "width is constrained by an uninferred width here:";
1196  }
1197 }
1198 
1199 //===----------------------------------------------------------------------===//
1200 // Inference Constraint Problem Mapping
1201 //===----------------------------------------------------------------------===//
1202 
1203 namespace {
1204 
1205 /// A helper class which maps the types and operations in a design to a set of
1206 /// variables and constraints to be solved later.
1207 class InferenceMapping {
1208 public:
1209  InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1210  hw::InnerSymbolTableCollection &istc)
1211  : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1212 
1213  LogicalResult map(CircuitOp op);
1214  LogicalResult mapOperation(Operation *op);
1215 
1216  /// Declare all the variables in the value. If the value is a ground type,
1217  /// there is a single variable declared. If the value is an aggregate type,
1218  /// it sets up variables for each unknown width.
1219  void declareVars(Value value, Location loc, bool isDerived = false);
1220 
1221  /// Declare a variable associated with a specific field of an aggregate.
1222  Expr *declareVar(FieldRef fieldRef, Location loc);
1223 
1224  /// Declarate a variable for a type with an unknown width. The type must be a
1225  /// non-aggregate.
1226  Expr *declareVar(FIRRTLType type, Location loc);
1227 
1228  /// Assign the constraint expressions of the fields in the `result` argument
1229  /// as the max of expressions in the `rhs` and `lhs` arguments. Both fields
1230  /// must be the same type.
1231  void maximumOfTypes(Value result, Value rhs, Value lhs);
1232 
1233  /// Constrain the value "larger" to be greater than or equal to "smaller".
1234  /// These may be aggregate values. This is used for regular connects.
1235  void constrainTypes(Value larger, Value smaller, bool equal = false);
1236 
1237  /// Constrain the expression "larger" to be greater than or equals to
1238  /// the expression "smaller".
1239  void constrainTypes(Expr *larger, Expr *smaller,
1240  bool imposeUpperBounds = false, bool equal = false);
1241 
1242  /// Assign the constraint expressions of the fields in the `src` argument as
1243  /// the expressions for the `dst` argument. Both fields must be of the given
1244  /// `type`.
1245  void unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type);
1246 
1247  /// Get the expr associated with the value. The value must be a non-aggregate
1248  /// type.
1249  Expr *getExpr(Value value) const;
1250 
1251  /// Get the expr associated with a specific field in a value.
1252  Expr *getExpr(FieldRef fieldRef) const;
1253 
1254  /// Get the expr associated with a specific field in a value. If value is
1255  /// NULL, then this returns NULL.
1256  Expr *getExprOrNull(FieldRef fieldRef) const;
1257 
1258  /// Set the expr associated with the value. The value must be a non-aggregate
1259  /// type.
1260  void setExpr(Value value, Expr *expr);
1261 
1262  /// Set the expr associated with a specific field in a value.
1263  void setExpr(FieldRef fieldRef, Expr *expr);
1264 
1265  /// Return whether a module was skipped due to being fully inferred already.
1266  bool isModuleSkipped(FModuleOp module) const {
1267  return skippedModules.count(module);
1268  }
1269 
1270  /// Return whether all modules in the mapping were fully inferred.
1271  bool areAllModulesSkipped() const { return allModulesSkipped; }
1272 
1273 private:
1274  /// The constraint solver into which we emit variables and constraints.
1275  ConstraintSolver &solver;
1276 
1277  /// The constraint exprs for each result type of an operation.
1278  DenseMap<FieldRef, Expr *> opExprs;
1279 
1280  /// The fully inferred modules that were skipped entirely.
1281  SmallPtrSet<Operation *, 16> skippedModules;
1282  bool allModulesSkipped = true;
1283 
1284  /// Cache of module symbols
1285  SymbolTable &symtbl;
1286 
1287  /// Full design inner symbol information.
1288  hw::InnerRefNamespace irn;
1289 };
1290 
1291 } // namespace
1292 
1293 /// Check if a type contains any FIRRTL type with uninferred widths.
1294 static bool hasUninferredWidth(Type type) {
1295  return TypeSwitch<Type, bool>(type)
1296  .Case<FIRRTLBaseType>([](auto base) { return base.hasUninferredWidth(); })
1297  .Case<RefType>(
1298  [](auto ref) { return ref.getType().hasUninferredWidth(); })
1299  .Default([](auto) { return false; });
1300 }
1301 
1302 LogicalResult InferenceMapping::map(CircuitOp op) {
1303  LLVM_DEBUG(llvm::dbgs()
1304  << "\n===----- Mapping ops to constraint exprs -----===\n\n");
1305 
1306  // Ensure we have constraint variables established for all module ports.
1307  for (auto module : op.getOps<FModuleOp>())
1308  for (auto arg : module.getArguments()) {
1309  solver.setCurrentContextInfo(FieldRef(arg, 0));
1310  declareVars(arg, module.getLoc());
1311  }
1312 
1313  for (auto module : op.getOps<FModuleOp>()) {
1314  // Check if the module contains *any* uninferred widths. This allows us to
1315  // do an early skip if the module is already fully inferred.
1316  bool anyUninferred = false;
1317  for (auto arg : module.getArguments()) {
1318  anyUninferred |= hasUninferredWidth(arg.getType());
1319  if (anyUninferred)
1320  break;
1321  }
1322  module.walk([&](Operation *op) {
1323  for (auto type : op->getResultTypes())
1324  anyUninferred |= hasUninferredWidth(type);
1325  if (anyUninferred)
1326  return WalkResult::interrupt();
1327  return WalkResult::advance();
1328  });
1329 
1330  if (!anyUninferred) {
1331  LLVM_DEBUG(llvm::dbgs() << "Skipping fully-inferred module '"
1332  << module.getName() << "'\n");
1333  skippedModules.insert(module);
1334  continue;
1335  }
1336 
1337  allModulesSkipped = false;
1338 
1339  // Go through operations in the module, creating type variables for results,
1340  // and generating constraints.
1341  auto result = module.getBodyBlock()->walk(
1342  [&](Operation *op) { return WalkResult(mapOperation(op)); });
1343  if (result.wasInterrupted())
1344  return failure();
1345  }
1346 
1347  return success();
1348 }
1349 
1350 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1351  // In case the operation result has a type without uninferred widths, don't
1352  // even bother to populate the constraint problem and treat that as a known
1353  // size directly. This is done in `declareVars`, which will generate
1354  // `KnownExpr` nodes for all known widths -- which are the only ones in this
1355  // case.
1356  bool allWidthsKnown = true;
1357  for (auto result : op->getResults()) {
1358  if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1359  if (hasUninferredWidth(op->getOperand(0).getType()))
1360  allWidthsKnown = false;
1361  // Only consider FIRRTL types for width constraints. Ignore any foreign
1362  // types as they don't participate in the width inference process.
1363  auto resultTy = type_dyn_cast<FIRRTLType>(result.getType());
1364  if (!resultTy)
1365  continue;
1366  if (!hasUninferredWidth(resultTy))
1367  declareVars(result, op->getLoc());
1368  else
1369  allWidthsKnown = false;
1370  }
1371  if (allWidthsKnown && !isa<FConnectLike, AttachOp>(op))
1372  return success();
1373 
1374  /// Ignore property assignments, no widths to infer.
1375  if (isa<PropAssignOp>(op))
1376  return success();
1377 
1378  // Actually generate the necessary constraint expressions.
1379  bool mappingFailed = false;
1380  solver.setCurrentContextInfo(
1381  op->getNumResults() > 0 ? FieldRef(op->getResults()[0], 0) : FieldRef());
1382  solver.setCurrentLocation(op->getLoc());
1383  TypeSwitch<Operation *>(op)
1384  .Case<ConstantOp>([&](auto op) {
1385  // If the constant has a known width, use that. Otherwise pick the
1386  // smallest number of bits necessary to represent the constant.
1387  Expr *e;
1388  if (auto width = op.getType().base().getWidth())
1389  e = solver.known(*width);
1390  else {
1391  auto v = op.getValue();
1392  auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1393  : v.countLeadingZeros());
1394  if (v.isSigned())
1395  w += 1;
1396  e = solver.known(std::max(w, 1u));
1397  }
1398  setExpr(op.getResult(), e);
1399  })
1400  .Case<SpecialConstantOp>([&](auto op) {
1401  // Nothing required.
1402  })
1403  .Case<InvalidValueOp>([&](auto op) {
1404  // We must duplicate the invalid value for each use, since each use can
1405  // be inferred to a different width.
1406  if (!hasUninferredWidth(op.getType()))
1407  return;
1408  declareVars(op.getResult(), op.getLoc(), /*isDerived=*/true);
1409  if (op.use_empty())
1410  return;
1411 
1412  auto type = op.getType();
1413  ImplicitLocOpBuilder builder(op->getLoc(), op);
1414  for (auto &use :
1415  llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
1416  // - `make_early_inc_range` since `getUses()` is invalidated upon
1417  // `use.set(...)`.
1418  // - `drop_begin` such that the first use can keep the original op.
1419  auto clone = builder.create<InvalidValueOp>(type);
1420  declareVars(clone.getResult(), clone.getLoc(),
1421  /*isDerived=*/true);
1422  use.set(clone);
1423  }
1424  })
1425  .Case<WireOp, RegOp>(
1426  [&](auto op) { declareVars(op.getResult(), op.getLoc()); })
1427  .Case<RegResetOp>([&](auto op) {
1428  // The original Scala code also constrains the reset signal to be at
1429  // least 1 bit wide. We don't do this here since the MLIR FIRRTL
1430  // dialect enforces the reset signal to be an async reset or a
1431  // `uint<1>`.
1432  declareVars(op.getResult(), op.getLoc());
1433  // Contrain the register to be greater than or equal to the reset
1434  // signal.
1435  constrainTypes(op.getResult(), op.getResetValue());
1436  })
1437  .Case<NodeOp>([&](auto op) {
1438  // Nodes have the same type as their input.
1439  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 0),
1440  op.getResult().getType());
1441  })
1442 
1443  // Aggregate Values
1444  .Case<SubfieldOp>([&](auto op) {
1445  BundleType bundleType = op.getInput().getType();
1446  auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1447  unifyTypes(FieldRef(op.getResult(), 0),
1448  FieldRef(op.getInput(), fieldID), op.getType());
1449  })
1450  .Case<SubindexOp, SubaccessOp>([&](auto op) {
1451  // All vec fields unify to the same thing. Always use the first element
1452  // of the vector, which has a field ID of 1.
1453  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
1454  op.getType());
1455  })
1456  .Case<SubtagOp>([&](auto op) {
1457  FEnumType enumType = op.getInput().getType();
1458  auto fieldID = enumType.getFieldID(op.getFieldIndex());
1459  unifyTypes(FieldRef(op.getResult(), 0),
1460  FieldRef(op.getInput(), fieldID), op.getType());
1461  })
1462 
1463  .Case<RefSubOp>([&](RefSubOp op) {
1464  uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1465  op.getInput().getType().getType())
1466  .Case<FVectorType>([](auto _) { return 1; })
1467  .Case<BundleType>([&](auto type) {
1468  return type.getFieldID(op.getIndex());
1469  });
1470  unifyTypes(FieldRef(op.getResult(), 0),
1471  FieldRef(op.getInput(), fieldID), op.getType());
1472  })
1473 
1474  // Arithmetic and Logical Binary Primitives
1475  .Case<AddPrimOp, SubPrimOp>([&](auto op) {
1476  auto lhs = getExpr(op.getLhs());
1477  auto rhs = getExpr(op.getRhs());
1478  auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1479  setExpr(op.getResult(), e);
1480  })
1481  .Case<MulPrimOp>([&](auto op) {
1482  auto lhs = getExpr(op.getLhs());
1483  auto rhs = getExpr(op.getRhs());
1484  auto e = solver.add(lhs, rhs);
1485  setExpr(op.getResult(), e);
1486  })
1487  .Case<DivPrimOp>([&](auto op) {
1488  auto lhs = getExpr(op.getLhs());
1489  Expr *e;
1490  if (op.getType().base().isSigned()) {
1491  e = solver.add(lhs, solver.known(1));
1492  } else {
1493  e = lhs;
1494  }
1495  setExpr(op.getResult(), e);
1496  })
1497  .Case<RemPrimOp>([&](auto op) {
1498  auto lhs = getExpr(op.getLhs());
1499  auto rhs = getExpr(op.getRhs());
1500  auto e = solver.min(lhs, rhs);
1501  setExpr(op.getResult(), e);
1502  })
1503  .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](auto op) {
1504  auto lhs = getExpr(op.getLhs());
1505  auto rhs = getExpr(op.getRhs());
1506  auto e = solver.max(lhs, rhs);
1507  setExpr(op.getResult(), e);
1508  })
1509 
1510  // Misc Binary Primitives
1511  .Case<CatPrimOp>([&](auto op) {
1512  auto lhs = getExpr(op.getLhs());
1513  auto rhs = getExpr(op.getRhs());
1514  auto e = solver.add(lhs, rhs);
1515  setExpr(op.getResult(), e);
1516  })
1517  .Case<DShlPrimOp>([&](auto op) {
1518  auto lhs = getExpr(op.getLhs());
1519  auto rhs = getExpr(op.getRhs());
1520  auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1521  setExpr(op.getResult(), e);
1522  })
1523  .Case<DShlwPrimOp, DShrPrimOp>([&](auto op) {
1524  auto e = getExpr(op.getLhs());
1525  setExpr(op.getResult(), e);
1526  })
1527 
1528  // Unary operators
1529  .Case<NegPrimOp>([&](auto op) {
1530  auto input = getExpr(op.getInput());
1531  auto e = solver.add(input, solver.known(1));
1532  setExpr(op.getResult(), e);
1533  })
1534  .Case<CvtPrimOp>([&](auto op) {
1535  auto input = getExpr(op.getInput());
1536  auto e = op.getInput().getType().base().isSigned()
1537  ? input
1538  : solver.add(input, solver.known(1));
1539  setExpr(op.getResult(), e);
1540  })
1541 
1542  // Miscellaneous
1543  .Case<BitsPrimOp>([&](auto op) {
1544  setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1545  })
1546  .Case<HeadPrimOp>([&](auto op) {
1547  setExpr(op.getResult(), solver.known(op.getAmount()));
1548  })
1549  .Case<TailPrimOp>([&](auto op) {
1550  auto input = getExpr(op.getInput());
1551  auto e = solver.add(input, solver.known(-op.getAmount()));
1552  setExpr(op.getResult(), e);
1553  })
1554  .Case<PadPrimOp>([&](auto op) {
1555  auto input = getExpr(op.getInput());
1556  auto e = solver.max(input, solver.known(op.getAmount()));
1557  setExpr(op.getResult(), e);
1558  })
1559  .Case<ShlPrimOp>([&](auto op) {
1560  auto input = getExpr(op.getInput());
1561  auto e = solver.add(input, solver.known(op.getAmount()));
1562  setExpr(op.getResult(), e);
1563  })
1564  .Case<ShrPrimOp>([&](auto op) {
1565  auto input = getExpr(op.getInput());
1566  // UInt saturates at 0 bits, SInt at 1 bit
1567  auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1568  auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1569  solver.known(minWidth));
1570  setExpr(op.getResult(), e);
1571  })
1572 
1573  // Handle operations whose output width matches the input width.
1574  .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1575  [&](auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1576  .Case<mlir::UnrealizedConversionCastOp>(
1577  [&](auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1578 
1579  // Handle operations with a single result type that always has a
1580  // well-known width.
1581  .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1582  AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1583  XorRPrimOp>([&](auto op) {
1584  auto width = op.getType().getBitWidthOrSentinel();
1585  assert(width > 0 && "width should have been checked by verifier");
1586  setExpr(op.getResult(), solver.known(width));
1587  })
1588  .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
1589  auto *sel = getExpr(op.getSel());
1590  constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
1591  maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1592  })
1593  .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1594  auto *sel = getExpr(op.getSel());
1595  constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
1596  maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1597  maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1598  maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1599  })
1600 
1601  .Case<ConnectOp, MatchingConnectOp>(
1602  [&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1603  .Case<RefDefineOp>([&](auto op) {
1604  // Dest >= Src, but also check Src <= Dest for correctness
1605  // (but don't solve to make this true, don't back-propagate)
1606  constrainTypes(op.getDest(), op.getSrc(), true);
1607  })
1608  // MatchingConnect is an identify constraint
1609  .Case<MatchingConnectOp>([&](auto op) {
1610  // This back-propagates width from destination to source,
1611  // causing source to sometimes be inferred wider than
1612  // it should be (https://github.com/llvm/circt/issues/5391).
1613  // This is needed to push the width into feeding widthCast?
1614  constrainTypes(op.getDest(), op.getSrc());
1615  constrainTypes(op.getSrc(), op.getDest());
1616  })
1617  .Case<AttachOp>([&](auto op) {
1618  // Attach connects multiple analog signals together. All signals must
1619  // have the same bit width. Signals without bit width inherit from the
1620  // other signals.
1621  if (op.getAttached().empty())
1622  return;
1623  auto prev = op.getAttached()[0];
1624  for (auto operand : op.getAttached().drop_front()) {
1625  auto e1 = getExpr(prev);
1626  auto e2 = getExpr(operand);
1627  constrainTypes(e1, e2, /*imposeUpperBounds=*/true);
1628  constrainTypes(e2, e1, /*imposeUpperBounds=*/true);
1629  prev = operand;
1630  }
1631  })
1632 
1633  // Handle the no-ops that don't interact with width inference.
1634  .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1635  UnclockedAssumeIntrinsicOp, CoverOp>([&](auto) {})
1636 
1637  // Handle instances of other modules.
1638  .Case<InstanceOp>([&](auto op) {
1639  auto refdModule = op.getReferencedOperation(symtbl);
1640  auto module = dyn_cast<FModuleOp>(&*refdModule);
1641  if (!module) {
1642  auto diag = mlir::emitError(op.getLoc());
1643  diag << "extern module `" << op.getModuleName()
1644  << "` has ports of uninferred width";
1645 
1646  auto fml = cast<FModuleLike>(&*refdModule);
1647  auto ports = fml.getPorts();
1648  for (auto &port : ports) {
1649  auto baseType = getBaseType(port.type);
1650  if (baseType && baseType.hasUninferredWidth()) {
1651  diag.attachNote(op.getLoc()) << "Port: " << port.name;
1652  if (!baseType.isGround())
1653  diagnoseUninferredType(diag, baseType, port.name.getValue());
1654  }
1655  }
1656 
1657  diag.attachNote(op.getLoc())
1658  << "Only non-extern FIRRTL modules may contain unspecified "
1659  "widths to be inferred automatically.";
1660  diag.attachNote(refdModule->getLoc())
1661  << "Module `" << op.getModuleName() << "` defined here:";
1662  mappingFailed = true;
1663  return;
1664  }
1665  // Simply look up the free variables created for the instantiated
1666  // module's ports, and use them for instance port wires. This way,
1667  // constraints imposed onto the ports of the instance will transparently
1668  // apply to the ports of the instantiated module.
1669  for (auto it : llvm::zip(op->getResults(), module.getArguments())) {
1670  unifyTypes(FieldRef(std::get<0>(it), 0), FieldRef(std::get<1>(it), 0),
1671  type_cast<FIRRTLType>(std::get<0>(it).getType()));
1672  }
1673  })
1674 
1675  // Handle memories.
1676  .Case<MemOp>([&](MemOp op) {
1677  // Create constraint variables for all ports.
1678  unsigned nonDebugPort = 0;
1679  for (const auto &result : llvm::enumerate(op.getResults())) {
1680  declareVars(result.value(), op.getLoc());
1681  if (!type_isa<RefType>(result.value().getType()))
1682  nonDebugPort = result.index();
1683  }
1684 
1685  // A helper function that returns the indeces of the "data", "rdata",
1686  // and "wdata" fields in the bundle corresponding to a memory port.
1687  auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1688  static const unsigned indices[] = {3, 5};
1689  static const unsigned debug[] = {0};
1690  switch (kind) {
1691  case MemOp::PortKind::Read:
1692  case MemOp::PortKind::Write:
1693  return ArrayRef<unsigned>(indices, 1); // {3}
1694  case MemOp::PortKind::ReadWrite:
1695  return ArrayRef<unsigned>(indices); // {3, 5}
1696  case MemOp::PortKind::Debug:
1697  return ArrayRef<unsigned>(debug);
1698  }
1699  llvm_unreachable("Imposible PortKind");
1700  };
1701 
1702  // This creates independent variables for every data port. Yet, what we
1703  // actually want is for all data ports to share the same variable. To do
1704  // this, we find the first data port declared, and use that port's vars
1705  // for all the other ports.
1706  unsigned firstFieldIndex =
1707  dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1708  FieldRef firstData(
1709  op.getResult(nonDebugPort),
1710  type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1711  .getFieldID(firstFieldIndex));
1712  LLVM_DEBUG(llvm::dbgs() << "Adjusting memory port variables:\n");
1713 
1714  // Reuse data port variables.
1715  auto dataType = op.getDataType();
1716  for (unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1717  auto result = op.getResult(i);
1718  if (type_isa<RefType>(result.getType())) {
1719  // Debug ports are firrtl.ref<vector<data-type, depth>>
1720  // Use FieldRef of 1, to indicate the first vector element must be
1721  // of the dataType.
1722  unifyTypes(firstData, FieldRef(result, 1), dataType);
1723  continue;
1724  }
1725 
1726  auto portType =
1727  type_cast<BundleType>(op.getPortType(i).getPassiveType());
1728  for (auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1729  unifyTypes(FieldRef(result, portType.getFieldID(fieldIndex)),
1730  firstData, dataType);
1731  }
1732  })
1733 
1734  .Case<RefSendOp>([&](auto op) {
1735  declareVars(op.getResult(), op.getLoc());
1736  constrainTypes(op.getResult(), op.getBase(), true);
1737  })
1738  .Case<RefResolveOp>([&](auto op) {
1739  declareVars(op.getResult(), op.getLoc());
1740  constrainTypes(op.getResult(), op.getRef(), true);
1741  })
1742  .Case<RefCastOp>([&](auto op) {
1743  declareVars(op.getResult(), op.getLoc());
1744  constrainTypes(op.getResult(), op.getInput(), true);
1745  })
1746  .Case<RWProbeOp>([&](auto op) {
1747  declareVars(op.getResult(), op.getLoc());
1748  auto ist = irn.lookup(op.getTarget());
1749  if (!ist) {
1750  op->emitError("target of rwprobe could not be resolved");
1751  mappingFailed = true;
1752  return;
1753  }
1754  auto ref = getFieldRefForTarget(ist);
1755  if (!ref) {
1756  op->emitError("target of rwprobe resolved to unsupported target");
1757  mappingFailed = true;
1758  return;
1759  }
1760  auto newFID = convertFieldIDToOurVersion(
1761  ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1762  unifyTypes(FieldRef(op.getResult(), 0),
1763  FieldRef(ref.getValue(), newFID), op.getType());
1764  })
1765  .Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
1766  for (Value result : op.getResults()) {
1767  auto ty = result.getType();
1768  if (type_isa<FIRRTLType>(ty))
1769  declareVars(result, op.getLoc());
1770  }
1771  })
1772  .Default([&](auto op) {
1773  op->emitOpError("not supported in width inference");
1774  mappingFailed = true;
1775  });
1776 
1777  // Forceable declarations should have the ref constrained to data result.
1778  if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1779  unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
1780  fop.getDataType());
1781 
1782  return failure(mappingFailed);
1783 }
1784 
1785 /// Declare free variables for the type of a value, and associate the resulting
1786 /// set of variables with that value.
1787 void InferenceMapping::declareVars(Value value, Location loc, bool isDerived) {
1788  // Declare a variable for every unknown width in the type. If this is a Bundle
1789  // type or a FVector type, we will have to potentially create many variables.
1790  unsigned fieldID = 0;
1791  std::function<void(FIRRTLBaseType)> declare = [&](FIRRTLBaseType type) {
1792  auto width = type.getBitWidthOrSentinel();
1793  if (width >= 0) {
1794  // Known width integer create a known expression.
1795  setExpr(FieldRef(value, fieldID), solver.known(width));
1796  fieldID++;
1797  } else if (width == -1) {
1798  // Unknown width integers create a variable.
1799  FieldRef field(value, fieldID);
1800  solver.setCurrentContextInfo(field);
1801  if (isDerived)
1802  setExpr(field, solver.derived());
1803  else
1804  setExpr(field, solver.var());
1805  fieldID++;
1806  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1807  // Bundle types recursively declare all bundle elements.
1808  fieldID++;
1809  for (auto &element : bundleType) {
1810  declare(element.type);
1811  }
1812  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1813  fieldID++;
1814  auto save = fieldID;
1815  declare(vecType.getElementType());
1816  // Skip past the rest of the elements
1817  fieldID = save + vecType.getMaxFieldID();
1818  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1819  fieldID++;
1820  for (auto &element : enumType.getElements())
1821  declare(element.type);
1822  } else {
1823  llvm_unreachable("Unknown type inside a bundle!");
1824  }
1825  };
1826  if (auto type = getBaseType(value.getType()))
1827  declare(type);
1828 }
1829 
1830 /// Assign the constraint expressions of the fields in the `result` argument as
1831 /// the max of expressions in the `rhs` and `lhs` arguments. Both fields must be
1832 /// the same type.
1833 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1834  // Recurse to every leaf element and set larger >= smaller.
1835  auto fieldID = 0;
1836  std::function<void(FIRRTLBaseType)> maximize = [&](FIRRTLBaseType type) {
1837  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1838  fieldID++;
1839  for (auto &element : bundleType.getElements())
1840  maximize(element.type);
1841  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1842  fieldID++;
1843  auto save = fieldID;
1844  // Skip 0 length vectors.
1845  if (vecType.getNumElements() > 0)
1846  maximize(vecType.getElementType());
1847  fieldID = save + vecType.getMaxFieldID();
1848  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1849  fieldID++;
1850  for (auto &element : enumType.getElements())
1851  maximize(element.type);
1852  } else if (type.isGround()) {
1853  auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
1854  getExpr(FieldRef(lhs, fieldID)));
1855  setExpr(FieldRef(result, fieldID), e);
1856  fieldID++;
1857  } else {
1858  llvm_unreachable("Unknown type inside a bundle!");
1859  }
1860  };
1861  if (auto type = getBaseType(result.getType()))
1862  maximize(type);
1863 }
1864 
1865 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1866 /// than or equal to the sizes in the `smaller` type. Types have to be
1867 /// compatible in the sense that they may only differ in the presence or absence
1868 /// of bit widths.
1869 ///
1870 /// This function is used to apply regular connects.
1871 /// Set `equal` for constraining larger <= smaller for correctness but not
1872 /// solving.
1873 void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
1874  // Recurse to every leaf element and set larger >= smaller. Ignore foreign
1875  // types as these do not participate in width inference.
1876 
1877  auto fieldID = 0;
1878  std::function<void(FIRRTLBaseType, Value, Value)> constrain =
1879  [&](FIRRTLBaseType type, Value larger, Value smaller) {
1880  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1881  fieldID++;
1882  for (auto &element : bundleType.getElements()) {
1883  if (element.isFlip)
1884  constrain(element.type, smaller, larger);
1885  else
1886  constrain(element.type, larger, smaller);
1887  }
1888  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1889  fieldID++;
1890  auto save = fieldID;
1891  // Skip 0 length vectors.
1892  if (vecType.getNumElements() > 0) {
1893  constrain(vecType.getElementType(), larger, smaller);
1894  }
1895  fieldID = save + vecType.getMaxFieldID();
1896  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1897  fieldID++;
1898  for (auto &element : enumType.getElements())
1899  constrain(element.type, larger, smaller);
1900  } else if (type.isGround()) {
1901  // Leaf element, look up their expressions, and create the constraint.
1902  constrainTypes(getExpr(FieldRef(larger, fieldID)),
1903  getExpr(FieldRef(smaller, fieldID)), false, equal);
1904  fieldID++;
1905  } else {
1906  llvm_unreachable("Unknown type inside a bundle!");
1907  }
1908  };
1909 
1910  if (auto type = getBaseType(larger.getType()))
1911  constrain(type, larger, smaller);
1912 }
1913 
1914 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1915 /// than or equal to the sizes in the `smaller` type.
1916 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1917  bool imposeUpperBounds, bool equal) {
1918  assert(larger && "Larger expression should be specified");
1919  assert(smaller && "Smaller expression should be specified");
1920 
1921  // If one of the sides is `DerivedExpr`, simply assign the other side as the
1922  // derived width. This allows `InvalidValueOp`s to properly infer their width
1923  // from the connects they are used in, but also be inferred to something
1924  // useful on their own.
1925  if (auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1926  largerDerived->assigned = smaller;
1927  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *largerDerived << " from "
1928  << *smaller << "\n");
1929  return;
1930  }
1931  if (auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1932  smallerDerived->assigned = larger;
1933  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *smallerDerived << " from "
1934  << *larger << "\n");
1935  return;
1936  }
1937 
1938  // If the larger expr is a free variable, create a `expr >= x` constraint for
1939  // it that we can try to satisfy with the smallest width.
1940  if (auto largerVar = dyn_cast<VarExpr>(larger)) {
1941  [[maybe_unused]] auto *c = solver.addGeqConstraint(largerVar, smaller);
1942  LLVM_DEBUG(llvm::dbgs()
1943  << "Constrained " << *largerVar << " >= " << *c << "\n");
1944  // If we're constraining larger == smaller, add the LEQ contraint as well.
1945  // Solve for GEQ but check that LEQ is true.
1946  // Used for matchingconnect, some reference operations, and anywhere the
1947  // widths should be inferred strictly in one direction but are required to
1948  // also be equal for correctness.
1949  if (equal) {
1950  [[maybe_unused]] auto *leq = solver.addLeqConstraint(largerVar, smaller);
1951  LLVM_DEBUG(llvm::dbgs()
1952  << "Constrained " << *largerVar << " <= " << *leq << "\n");
1953  }
1954  return;
1955  }
1956 
1957  // If the smaller expr is a free variable but the larger one is not, create a
1958  // `expr <= k` upper bound that we can verify once all lower bounds have been
1959  // satisfied. Since we are always picking the smallest width to satisfy all
1960  // `>=` constraints, any `<=` constraints have no effect on the solution
1961  // besides indicating that a width is unsatisfiable.
1962  if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1963  if (imposeUpperBounds || equal) {
1964  [[maybe_unused]] auto *c = solver.addLeqConstraint(smallerVar, larger);
1965  LLVM_DEBUG(llvm::dbgs()
1966  << "Constrained " << *smallerVar << " <= " << *c << "\n");
1967  }
1968  }
1969 }
1970 
1971 /// Assign the constraint expressions of the fields in the `src` argument as the
1972 /// expressions for the `dst` argument. Both fields must be of the given `type`.
1973 void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
1974  // Fast path for `unifyTypes(x, x, _)`.
1975  if (lhs == rhs)
1976  return;
1977 
1978  // Co-iterate the two field refs, recurring into every leaf element and set
1979  // them equal.
1980  auto fieldID = 0;
1981  std::function<void(FIRRTLBaseType)> unify = [&](FIRRTLBaseType type) {
1982  if (type.isGround()) {
1983  // Leaf element, unify the fields!
1984  FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
1985  FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
1986  LLVM_DEBUG(llvm::dbgs()
1987  << "Unify " << getFieldName(lhsFieldRef).first << " = "
1988  << getFieldName(rhsFieldRef).first << "\n");
1989  // Abandon variables becoming unconstrainable by the unification.
1990  if (auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1991  solver.addGeqConstraint(var, solver.known(0));
1992  setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1993  fieldID++;
1994  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1995  fieldID++;
1996  for (auto &element : bundleType) {
1997  unify(element.type);
1998  }
1999  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2000  fieldID++;
2001  auto save = fieldID;
2002  // Skip 0 length vectors.
2003  if (vecType.getNumElements() > 0) {
2004  unify(vecType.getElementType());
2005  }
2006  fieldID = save + vecType.getMaxFieldID();
2007  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2008  fieldID++;
2009  for (auto &element : enumType.getElements())
2010  unify(element.type);
2011  } else {
2012  llvm_unreachable("Unknown type inside a bundle!");
2013  }
2014  };
2015  if (auto ftype = getBaseType(type))
2016  unify(ftype);
2017 }
2018 
2019 /// Get the constraint expression for a value.
2020 Expr *InferenceMapping::getExpr(Value value) const {
2021  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2022  // A field ID of 0 indicates the entire value.
2023  return getExpr(FieldRef(value, 0));
2024 }
2025 
2026 /// Get the constraint expression for a value.
2027 Expr *InferenceMapping::getExpr(FieldRef fieldRef) const {
2028  auto expr = getExprOrNull(fieldRef);
2029  assert(expr && "constraint expr should have been constructed for value");
2030  return expr;
2031 }
2032 
2033 Expr *InferenceMapping::getExprOrNull(FieldRef fieldRef) const {
2034  auto it = opExprs.find(fieldRef);
2035  return it != opExprs.end() ? it->second : nullptr;
2036 }
2037 
2038 /// Associate a constraint expression with a value.
2039 void InferenceMapping::setExpr(Value value, Expr *expr) {
2040  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2041  // A field ID of 0 indicates the entire value.
2042  setExpr(FieldRef(value, 0), expr);
2043 }
2044 
2045 /// Associate a constraint expression with a value.
2046 void InferenceMapping::setExpr(FieldRef fieldRef, Expr *expr) {
2047  LLVM_DEBUG({
2048  llvm::dbgs() << "Expr " << *expr << " for " << fieldRef.getValue();
2049  if (fieldRef.getFieldID())
2050  llvm::dbgs() << " '" << getFieldName(fieldRef).first << "'";
2051  auto fieldName = getFieldName(fieldRef);
2052  if (fieldName.second)
2053  llvm::dbgs() << " (\"" << fieldName.first << "\")";
2054  llvm::dbgs() << "\n";
2055  });
2056  opExprs[fieldRef] = expr;
2057 }
2058 
2059 //===----------------------------------------------------------------------===//
2060 // Inference Result Application
2061 //===----------------------------------------------------------------------===//
2062 
2063 namespace {
2064 /// A helper class which maps the types and operations in a design to a set
2065 /// of variables and constraints to be solved later.
2066 class InferenceTypeUpdate {
2067 public:
2068  InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2069 
2070  LogicalResult update(CircuitOp op);
2071  FailureOr<bool> updateOperation(Operation *op);
2072  FailureOr<bool> updateValue(Value value);
2074 
2075 private:
2076  const InferenceMapping &mapping;
2077 };
2078 
2079 } // namespace
2080 
2081 /// Update the types throughout a circuit.
2082 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2083  LLVM_DEBUG({
2084  llvm::dbgs() << "\n";
2085  debugHeader("Update types") << "\n\n";
2086  });
2087  return mlir::failableParallelForEach(
2088  op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2089  // Skip this module if it had no widths to be
2090  // inferred at all.
2091  if (mapping.isModuleSkipped(op))
2092  return success();
2093  auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2094  if (failed(updateOperation(op)))
2095  return WalkResult::interrupt();
2096  return WalkResult::advance();
2097  }).wasInterrupted();
2098  return failure(isFailed);
2099  });
2100 }
2101 
2102 /// Update the result types of an operation.
2103 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2104  bool anyChanged = false;
2105 
2106  for (Value v : op->getResults()) {
2107  auto result = updateValue(v);
2108  if (failed(result))
2109  return result;
2110  anyChanged |= *result;
2111  }
2112 
2113  // If this is a connect operation, width inference might have inferred a RHS
2114  // that is wider than the LHS, in which case an additional BitsPrimOp is
2115  // necessary to truncate the value.
2116  if (auto con = dyn_cast<ConnectOp>(op)) {
2117  auto lhs = con.getDest();
2118  auto rhs = con.getSrc();
2119  auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2120  auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2121 
2122  // Nothing to do if not base types.
2123  if (!lhsType || !rhsType)
2124  return anyChanged;
2125 
2126  auto lhsWidth = lhsType.getBitWidthOrSentinel();
2127  auto rhsWidth = rhsType.getBitWidthOrSentinel();
2128  if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2129  OpBuilder builder(op);
2130  auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2131  rhsWidth - lhsWidth);
2132  if (type_isa<SIntType>(rhsType))
2133  trunc =
2134  builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2135 
2136  LLVM_DEBUG(llvm::dbgs()
2137  << "Truncating RHS to " << lhsType << " in " << con << "\n");
2138  con->replaceUsesOfWith(con.getSrc(), trunc);
2139  }
2140  return anyChanged;
2141  }
2142 
2143  // If this is a module, update its ports.
2144  if (auto module = dyn_cast<FModuleOp>(op)) {
2145  // Update the block argument types.
2146  bool argsChanged = false;
2147  SmallVector<Attribute> argTypes;
2148  argTypes.reserve(module.getNumPorts());
2149  for (auto arg : module.getArguments()) {
2150  auto result = updateValue(arg);
2151  if (failed(result))
2152  return result;
2153  argsChanged |= *result;
2154  argTypes.push_back(TypeAttr::get(arg.getType()));
2155  }
2156 
2157  // Update the module function type if needed.
2158  if (argsChanged) {
2159  module->setAttr(FModuleLike::getPortTypesAttrName(),
2160  ArrayAttr::get(module.getContext(), argTypes));
2161  anyChanged = true;
2162  }
2163  }
2164  return anyChanged;
2165 }
2166 
2167 /// Resize a `uint`, `sint`, or `analog` type to a specific width.
2168 static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth) {
2169  auto *context = type.getContext();
2171  .Case<UIntType>([&](auto type) {
2172  return UIntType::get(context, newWidth, type.isConst());
2173  })
2174  .Case<SIntType>([&](auto type) {
2175  return SIntType::get(context, newWidth, type.isConst());
2176  })
2177  .Case<AnalogType>([&](auto type) {
2178  return AnalogType::get(context, newWidth, type.isConst());
2179  })
2180  .Default([&](auto type) { return type; });
2181 }
2182 
2183 /// Update the type of a value.
2184 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2185  // Check if the value has a type which we can update.
2186  auto type = type_dyn_cast<FIRRTLType>(value.getType());
2187  if (!type)
2188  return false;
2189 
2190  // Fast path for types that have fully inferred widths.
2191  if (!hasUninferredWidth(type))
2192  return false;
2193 
2194  // If this is an operation that does not generate any free variables that
2195  // are determined during width inference, simply update the value type based
2196  // on the operation arguments.
2197  if (auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2198  SmallVector<Type, 2> types;
2199  auto res =
2200  op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2201  op->getAttrDictionary(), op->getPropertiesStorage(),
2202  op->getRegions(), types);
2203  if (failed(res))
2204  return failure();
2205 
2206  assert(types.size() == op->getNumResults());
2207  for (auto it : llvm::zip(op->getResults(), types)) {
2208  LLVM_DEBUG(llvm::dbgs() << "Inferring " << std::get<0>(it) << " as "
2209  << std::get<1>(it) << "\n");
2210  std::get<0>(it).setType(std::get<1>(it));
2211  }
2212  return true;
2213  }
2214 
2215  // Recreate the type, substituting the solved widths.
2216  auto context = type.getContext();
2217  unsigned fieldID = 0;
2218  std::function<FIRRTLBaseType(FIRRTLBaseType)> updateBase =
2219  [&](FIRRTLBaseType type) -> FIRRTLBaseType {
2220  auto width = type.getBitWidthOrSentinel();
2221  if (width >= 0) {
2222  // Known width integers return themselves.
2223  fieldID++;
2224  return type;
2225  } else if (width == -1) {
2226  // Unknown width integers return the solved type.
2227  auto newType = updateType(FieldRef(value, fieldID), type);
2228  fieldID++;
2229  return newType;
2230  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2231  // Bundle types recursively update all bundle elements.
2232  fieldID++;
2233  llvm::SmallVector<BundleType::BundleElement, 3> elements;
2234  for (auto &element : bundleType) {
2235  auto updatedBase = updateBase(element.type);
2236  if (!updatedBase)
2237  return {};
2238  elements.emplace_back(element.name, element.isFlip, updatedBase);
2239  }
2240  return BundleType::get(context, elements, bundleType.isConst());
2241  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2242  fieldID++;
2243  auto save = fieldID;
2244  // TODO: this should recurse into the element type of 0 length vectors and
2245  // set any unknown width to 0.
2246  if (vecType.getNumElements() > 0) {
2247  auto updatedBase = updateBase(vecType.getElementType());
2248  if (!updatedBase)
2249  return {};
2250  auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2251  vecType.isConst());
2252  fieldID = save + vecType.getMaxFieldID();
2253  return newType;
2254  }
2255  // If this is a 0 length vector return the original type.
2256  return type;
2257  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2258  fieldID++;
2259  llvm::SmallVector<FEnumType::EnumElement> elements;
2260  for (auto &element : enumType.getElements()) {
2261  auto updatedBase = updateBase(element.type);
2262  if (!updatedBase)
2263  return {};
2264  elements.emplace_back(element.name, updatedBase);
2265  }
2266  return FEnumType::get(context, elements, enumType.isConst());
2267  }
2268  llvm_unreachable("Unknown type inside a bundle!");
2269  };
2270 
2271  // Update the type.
2272  auto newType = mapBaseTypeNullable(type, updateBase);
2273  if (!newType)
2274  return failure();
2275  LLVM_DEBUG(llvm::dbgs() << "Update " << value << " to " << newType << "\n");
2276  value.setType(newType);
2277 
2278  // If this is a ConstantOp, adjust the width of the underlying APInt.
2279  // Unsized constants have APInts which are *at least* wide enough to hold
2280  // the value, but may be larger. This can trip up the verifier.
2281  if (auto op = value.getDefiningOp<ConstantOp>()) {
2282  auto k = op.getValue();
2283  auto bitwidth = op.getType().getBitWidthOrSentinel();
2284  if (k.getBitWidth() > unsigned(bitwidth))
2285  k = k.trunc(bitwidth);
2286  op->setAttr("value", IntegerAttr::get(op.getContext(), k));
2287  }
2288 
2289  return newType != type;
2290 }
2291 
2292 /// Update a type.
2294  FIRRTLBaseType type) {
2295  assert(type.isGround() && "Can only pass in ground types.");
2296  auto value = fieldRef.getValue();
2297  // Get the inferred width.
2298  Expr *expr = mapping.getExprOrNull(fieldRef);
2299  if (!expr || !expr->solution) {
2300  // It should not be possible to arrive at an uninferred width at this point.
2301  // In case the constraints are not resolvable, checks before the calls to
2302  // `updateType` must have already caught the issues and aborted the pass
2303  // early. Might turn this into an assert later.
2304  mlir::emitError(value.getLoc(), "width should have been inferred");
2305  return {};
2306  }
2307  int32_t solution = *expr->solution;
2308  assert(solution >= 0); // The solver infers variables to be 0 or greater.
2309  return resizeType(type, solution);
2310 }
2311 
2312 //===----------------------------------------------------------------------===//
2313 // Hashing
2314 //===----------------------------------------------------------------------===//
2315 
2316 // Hash slots in the interned allocator as if they were the pointed-to value
2317 // itself.
2318 namespace llvm {
2319 template <typename T>
2320 struct DenseMapInfo<InternedSlot<T>> {
2321  using Slot = InternedSlot<T>;
2322  static Slot getEmptyKey() {
2323  auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2324  return Slot(static_cast<T *>(pointer));
2325  }
2327  auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2328  return Slot(static_cast<T *>(pointer));
2329  }
2330  static unsigned getHashValue(Slot val) { return mlir::hash_value(*val.ptr); }
2331  static bool isEqual(Slot LHS, Slot RHS) {
2332  auto empty = getEmptyKey().ptr;
2333  auto tombstone = getTombstoneKey().ptr;
2334  if (LHS.ptr == empty || RHS.ptr == empty || LHS.ptr == tombstone ||
2335  RHS.ptr == tombstone)
2336  return LHS.ptr == RHS.ptr;
2337  return *LHS.ptr == *RHS.ptr;
2338  }
2339 };
2340 } // namespace llvm
2341 
2342 //===----------------------------------------------------------------------===//
2343 // Pass Infrastructure
2344 //===----------------------------------------------------------------------===//
2345 
2346 namespace {
2347 class InferWidthsPass
2348  : public circt::firrtl::impl::InferWidthsBase<InferWidthsPass> {
2349  void runOnOperation() override;
2350 };
2351 } // namespace
2352 
2353 void InferWidthsPass::runOnOperation() {
2354  // Collect variables and constraints
2355  ConstraintSolver solver;
2356  InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2357  getAnalysis<hw::InnerSymbolTableCollection>());
2358  if (failed(mapping.map(getOperation()))) {
2359  signalPassFailure();
2360  return;
2361  }
2362  if (mapping.areAllModulesSkipped()) {
2363  markAllAnalysesPreserved();
2364  return; // fast path if no inferrable widths are around
2365  }
2366 
2367  // Solve the constraints.
2368  if (failed(solver.solve())) {
2369  signalPassFailure();
2370  return;
2371  }
2372 
2373  // Update the types with the inferred widths.
2374  if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2375  signalPassFailure();
2376 }
2377 
2378 std::unique_ptr<mlir::Pass> circt::firrtl::createInferWidthsPass() {
2379  return std::make_unique<InferWidthsPass>();
2380 }
assert(baseType &&"element must be base type")
int32_t width
Definition: FIRRTL.cpp:36
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:78
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
Definition: InferWidths.cpp:73
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
#define EXPR_KINDS
#define EXPR_CLASSES
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
Definition: InferWidths.cpp:55
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, unsigned defaultWorklistSize)
Compute the value of a constraint expr.
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static InstancePath empty
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:520
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:530
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
Definition: FIRRTLUtils.h:252
FieldRef getFieldRefForTarget(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:222
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:119
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
Definition: FIRRTLTypes.h:368
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition: Debug.cpp:18
Definition: debug.py:1
llvm::hash_code hash_value(const T &e)
static unsigned getHashValue(Slot val)
static bool isEqual(Slot LHS, Slot RHS)