CIRCT  18.0.0git
InferWidths.cpp
Go to the documentation of this file.
1 //===- InferWidths.cpp - Infer width of types -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the InferWidths pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
19 #include "circt/Support/FieldRef.h"
20 #include "mlir/IR/ImplicitLocOpBuilder.h"
21 #include "mlir/IR/Threading.h"
22 #include "llvm/ADT/APSInt.h"
23 #include "llvm/ADT/DenseSet.h"
24 #include "llvm/ADT/GraphTraits.h"
25 #include "llvm/ADT/Hashing.h"
26 #include "llvm/ADT/MapVector.h"
27 #include "llvm/ADT/PostOrderIterator.h"
28 #include "llvm/ADT/SetVector.h"
29 #include "llvm/Support/Debug.h"
30 #include "llvm/Support/ErrorHandling.h"
31 
32 #define DEBUG_TYPE "infer-widths"
33 
34 using mlir::InferTypeOpInterface;
35 using mlir::WalkOrder;
36 
37 using namespace circt;
38 using namespace firrtl;
39 
40 //===----------------------------------------------------------------------===//
41 // Helpers
42 //===----------------------------------------------------------------------===//
43 
44 static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t,
45  Twine str) {
46  auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
47  if (!basetype)
48  return;
49  if (!basetype.hasUninferredWidth())
50  return;
51 
52  if (basetype.isGround())
53  diag.attachNote() << "Field: \"" << str << "\"";
54  else if (auto vecType = type_dyn_cast<FVectorType>(basetype))
55  diagnoseUninferredType(diag, vecType.getElementType(), str + "[]");
56  else if (auto bundleType = type_dyn_cast<BundleType>(basetype))
57  for (auto &elem : bundleType.getElements())
58  diagnoseUninferredType(diag, elem.type, str + "." + elem.name.getValue());
59 }
60 
61 /// Get FieldRef pointing to the specified inner symbol target, which must be
62 /// valid. Returns null FieldRef if target points to something with no value,
63 /// such as a port of an external module.
64 static FieldRef getRefForIST(const hw::InnerSymTarget &ist) {
65  if (ist.isPort()) {
66  return TypeSwitch<Operation *, FieldRef>(ist.getOp())
67  .Case<FModuleOp>([&](auto fmod) {
68  return FieldRef(fmod.getArgument(ist.getPort()), ist.getField());
69  })
70  .Default({});
71  }
72 
73  auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(ist.getOp());
74  assert(symOp && symOp.getTargetResultIndex() &&
75  (symOp.supportsPerFieldSymbols() || ist.getField() == 0));
76  return FieldRef(symOp.getTargetResult(), ist.getField());
77 }
78 
79 /// Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
80 static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type) {
81  uint64_t convertedFieldID = 0;
82 
83  auto curFID = fieldID;
84  Type curFType = type;
85  while (curFID != 0) {
86  auto [child, subID] =
87  hw::FieldIdImpl::getSubTypeByFieldID(curFType, curFID);
88  if (isa<FVectorType>(curFType))
89  convertedFieldID++; // Vector fieldID is 1.
90  else
91  convertedFieldID += curFID - subID; // Add consumed portion.
92  curFID = subID;
93  curFType = child;
94  }
95 
96  return convertedFieldID;
97 }
98 
99 //===----------------------------------------------------------------------===//
100 // Constraint Expressions
101 //===----------------------------------------------------------------------===//
102 
103 namespace {
104 struct Expr;
105 } // namespace
106 
107 /// Allow rvalue refs to `Expr` and subclasses to be printed to streams.
109  int>::type = 0>
110 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T &e) {
111  e.print(os);
112  return os;
113 }
114 
115 // Allow expression subclasses to be hashed.
116 namespace mlir {
118  int>::type = 0>
119 inline llvm::hash_code hash_value(const T &e) {
120  return e.hash_value();
121 }
122 } // namespace mlir
123 
124 namespace {
125 #define EXPR_NAMES(x) \
126  Root##x, Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
127 #define EXPR_KINDS EXPR_NAMES()
128 #define EXPR_CLASSES EXPR_NAMES(Expr)
129 
130 /// An expression on the right-hand side of a constraint.
131 struct Expr {
132  enum class Kind { EXPR_KINDS };
133  std::optional<int32_t> solution;
134  Kind kind;
135 
136  /// Print a human-readable representation of this expr.
137  void print(llvm::raw_ostream &os) const;
138 
139 protected:
140  Expr(Kind kind) : kind(kind) {}
141  llvm::hash_code hash_value() const { return llvm::hash_value(kind); }
142 };
143 
144 /// Helper class to CRTP-derive common functions.
145 template <class DerivedT, Expr::Kind DerivedKind>
146 struct ExprBase : public Expr {
147  ExprBase() : Expr(DerivedKind) {}
148  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
149  bool operator==(const Expr &other) const {
150  if (auto otherSame = dyn_cast<DerivedT>(other))
151  return *static_cast<DerivedT *>(this) == otherSame;
152  return false;
153  }
154 };
155 
156 /// The root containing all expressions.
157 struct RootExpr : public ExprBase<RootExpr, Expr::Kind::Root> {
158  RootExpr(std::vector<Expr *> &exprs) : exprs(exprs) {}
159  void print(llvm::raw_ostream &os) const { os << "root"; }
160  std::vector<Expr *> &exprs;
161 };
162 
163 /// A free variable.
164 struct VarExpr : public ExprBase<VarExpr, Expr::Kind::Var> {
165  void print(llvm::raw_ostream &os) const {
166  // Hash the `this` pointer into something somewhat human readable. Since
167  // this is just for debug dumping, we wrap around at 65536 variables.
168  os << "var" << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFFF);
169  }
170 
171  /// The constraint expression this variable is supposed to be greater than or
172  /// equal to. This is not part of the variable's hash and equality property.
173  Expr *constraint = nullptr;
174 
175  /// The upper bound this variable is supposed to be smaller than or equal to.
176  Expr *upperBound = nullptr;
177  std::optional<int32_t> upperBoundSolution;
178 };
179 
180 /// A derived width.
181 ///
182 /// These are generated for `InvalidValueOp`s which want to derived their width
183 /// from connect operations that they are on the right hand side of.
184 struct DerivedExpr : public ExprBase<DerivedExpr, Expr::Kind::Derived> {
185  void print(llvm::raw_ostream &os) const {
186  // Hash the `this` pointer into something somewhat human readable.
187  os << "derive"
188  << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFF);
189  }
190 
191  /// The expression this derived width is equivalent to.
192  Expr *assigned = nullptr;
193 };
194 
195 /// An identity expression.
196 ///
197 /// This expression evaluates to its inner expression. It is used in a very
198 /// specific case of constraints on variables, in order to be able to track
199 /// where the constraint was imposed. Constraints on variables are represented
200 /// as `var >= <expr>`. When the first constraint `a` is imposed, it is stored
201 /// as the constraint expression (`var >= a`). When the second constraint `b` is
202 /// imposed, a *new* max expression is allocated (`var >= max(a, b)`).
203 /// Expressions are annotated with a location when they are created, which in
204 /// this case are connect ops. Since imposing the first constraint does not
205 /// create any new expression, the location information of that connect would be
206 /// lost. With an identity expression, imposing the first constraint becomes
207 /// `var >= identity(a)`, which is a *new* expression and properly tracks the
208 /// location info.
209 struct IdExpr : public ExprBase<IdExpr, Expr::Kind::Id> {
210  IdExpr(Expr *arg) : arg(arg) { assert(arg); }
211  void print(llvm::raw_ostream &os) const { os << "*" << *arg; }
212  bool operator==(const IdExpr &other) const {
213  return kind == other.kind && arg == other.arg;
214  }
215  llvm::hash_code hash_value() const {
216  return llvm::hash_combine(Expr::hash_value(), arg);
217  }
218 
219  /// The inner expression.
220  Expr *const arg;
221 };
222 
223 /// A known constant value.
224 struct KnownExpr : public ExprBase<KnownExpr, Expr::Kind::Known> {
225  KnownExpr(int32_t value) : ExprBase() { solution = value; }
226  void print(llvm::raw_ostream &os) const { os << *solution; }
227  bool operator==(const KnownExpr &other) const {
228  return *solution == *other.solution;
229  }
230  llvm::hash_code hash_value() const {
231  return llvm::hash_combine(Expr::hash_value(), *solution);
232  }
233 };
234 
235 /// A unary expression. Contains the actual data. Concrete subclasses are merely
236 /// there for show and ease of use.
237 struct UnaryExpr : public Expr {
238  bool operator==(const UnaryExpr &other) const {
239  return kind == other.kind && arg == other.arg;
240  }
241  llvm::hash_code hash_value() const {
242  return llvm::hash_combine(Expr::hash_value(), arg);
243  }
244 
245  /// The child expression.
246  Expr *const arg;
247 
248 protected:
249  UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) { assert(arg); }
250 };
251 
252 /// Helper class to CRTP-derive common functions.
253 template <class DerivedT, Expr::Kind DerivedKind>
254 struct UnaryExprBase : public UnaryExpr {
255  template <typename... Args>
256  UnaryExprBase(Args &&...args)
257  : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
258  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
259 };
260 
261 /// A power of two.
262 struct PowExpr : public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
263  using UnaryExprBase::UnaryExprBase;
264  void print(llvm::raw_ostream &os) const { os << "2^" << arg; }
265 };
266 
267 /// A binary expression. Contains the actual data. Concrete subclasses are
268 /// merely there for show and ease of use.
269 struct BinaryExpr : public Expr {
270  bool operator==(const BinaryExpr &other) const {
271  return kind == other.kind && lhs() == other.lhs() && rhs() == other.rhs();
272  }
273  llvm::hash_code hash_value() const {
274  return llvm::hash_combine(Expr::hash_value(), *args);
275  }
276  Expr *lhs() const { return args[0]; }
277  Expr *rhs() const { return args[1]; }
278 
279  /// The child expressions.
280  Expr *const args[2];
281 
282 protected:
283  BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
284  assert(lhs);
285  assert(rhs);
286  }
287 };
288 
289 /// Helper class to CRTP-derive common functions.
290 template <class DerivedT, Expr::Kind DerivedKind>
291 struct BinaryExprBase : public BinaryExpr {
292  template <typename... Args>
293  BinaryExprBase(Args &&...args)
294  : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
295  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
296 };
297 
298 /// An addition.
299 struct AddExpr : public BinaryExprBase<AddExpr, Expr::Kind::Add> {
300  using BinaryExprBase::BinaryExprBase;
301  void print(llvm::raw_ostream &os) const {
302  os << "(" << *lhs() << " + " << *rhs() << ")";
303  }
304 };
305 
306 /// The maximum of two expressions.
307 struct MaxExpr : public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
308  using BinaryExprBase::BinaryExprBase;
309  void print(llvm::raw_ostream &os) const {
310  os << "max(" << *lhs() << ", " << *rhs() << ")";
311  }
312 };
313 
314 /// The minimum of two expressions.
315 struct MinExpr : public BinaryExprBase<MinExpr, Expr::Kind::Min> {
316  using BinaryExprBase::BinaryExprBase;
317  void print(llvm::raw_ostream &os) const {
318  os << "min(" << *lhs() << ", " << *rhs() << ")";
319  }
320 };
321 
322 void Expr::print(llvm::raw_ostream &os) const {
323  TypeSwitch<const Expr *>(this).Case<EXPR_CLASSES>(
324  [&](auto *e) { e->print(os); });
325 }
326 
327 } // namespace
328 
329 //===----------------------------------------------------------------------===//
330 // Fast bump allocator with optional interning
331 //===----------------------------------------------------------------------===//
332 
333 namespace {
334 
/// An allocation slot in the `InternedAllocator`: a thin wrapper around a
/// pointer to the stored object.
template <typename T>
struct InternedSlot {
  /// The wrapped object pointer.
  T *ptr;

  InternedSlot(T *object) : ptr(object) {}
};
341 
342 /// A simple bump allocator that ensures only ever one copy per object exists.
343 /// The allocated objects must not have a destructor.
344 template <typename T, typename std::enable_if_t<
346 class InternedAllocator {
347  using Slot = InternedSlot<T>;
348  llvm::DenseSet<Slot> interned;
349  llvm::BumpPtrAllocator &allocator;
350 
351 public:
352  InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
353 
354  /// Allocate a new object if it does not yet exist, or return a pointer to the
355  /// existing one. `R` is the type of the object to be allocated. `R` must be
356  /// derived from or be the type `T`.
357  template <typename R = T, typename... Args>
358  std::pair<R *, bool> alloc(Args &&...args) {
359  auto stack_value = R(std::forward<Args>(args)...);
360  auto stack_slot = Slot(&stack_value);
361  auto it = interned.find(stack_slot);
362  if (it != interned.end())
363  return std::make_pair(static_cast<R *>(it->ptr), false);
364  auto heap_value = new (allocator) R(std::move(stack_value));
365  interned.insert(Slot(heap_value));
366  return std::make_pair(heap_value, true);
367  }
368 };
369 
370 /// A simple bump allocator. The allocated objects must not have a destructor.
371 /// This allocator is mainly there for symmetry with the `InternedAllocator`.
372 template <typename T, typename std::enable_if_t<
374 class Allocator {
375  llvm::BumpPtrAllocator &allocator;
376 
377 public:
378  Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
379 
380  /// Allocate a new object. `R` is the type of the object to be allocated. `R`
381  /// must be derived from or be the type `T`.
382  template <typename R = T, typename... Args>
383  R *alloc(Args &&...args) {
384  return new (allocator) R(std::forward<Args>(args)...);
385  }
386 };
387 
388 } // namespace
389 
390 //===----------------------------------------------------------------------===//
391 // Constraint Solver
392 //===----------------------------------------------------------------------===//
393 
394 namespace {
395 /// A canonicalized linear inequality that maps a constraint on var `x` to the
396 /// linear inequality `x >= max(a*x+b, c) + (failed ? ∞ : 0)`.
397 ///
398 /// The inequality separately tracks recursive (a, b) and non-recursive (c)
399 /// constraints on `x`. This allows it to properly identify the combination of
400 /// the two constraints `x >= x-1` and `x >= 4` to be satisfiable as
401 /// `x >= max(x-1, 4)`. If it only tracked inequality as `x >= a*x+b`, the
402 /// combination of these two constraints would be `x >= x+4` (due to max(-1,4) =
403 /// 4), which would be unsatisfiable.
404 ///
405 /// The `failed` flag acts as an additional `∞` term that renders the inequality
406 /// unsatisfiable. It is used as a tombstone value in case an operation renders
407 /// the equality unsatisfiable (e.g. `x >= 2**x` would be represented as the
408 /// inequality `x >= ∞`).
409 ///
410 /// Inequalities represented in this form can easily be checked for
411 /// unsatisfiability in the presence of recursion by inspecting the coefficients
412 /// a and b. The `sat` function performs this action.
413 struct LinIneq {
414  // x >= max(a*x+b, c) + (failed ? ∞ : 0)
415  int32_t rec_scale = 0; // a
416  int32_t rec_bias = 0; // b
417  int32_t nonrec_bias = 0; // c
418  bool failed = false;
419 
420  /// Create a new unsatisfiable inequality `x >= ∞`.
421  static LinIneq unsat() { return LinIneq(true); }
422 
423  /// Create a new inequality `x >= (failed ? ∞ : 0)`.
424  explicit LinIneq(bool failed = false) : failed(failed) {}
425 
426  /// Create a new inequality `x >= bias`.
427  explicit LinIneq(int32_t bias) : nonrec_bias(bias) {}
428 
429  /// Create a new inequality `x >= scale*x+bias`.
430  explicit LinIneq(int32_t scale, int32_t bias) {
431  if (scale != 0) {
432  rec_scale = scale;
433  rec_bias = bias;
434  } else {
435  nonrec_bias = bias;
436  }
437  }
438 
439  /// Create a new inequality `x >= max(rec_scale*x+rec_bias, nonrec_bias) +
440  /// (failed ? ∞ : 0)`.
441  explicit LinIneq(int32_t rec_scale, int32_t rec_bias, int32_t nonrec_bias,
442  bool failed = false)
443  : failed(failed) {
444  if (rec_scale != 0) {
445  this->rec_scale = rec_scale;
446  this->rec_bias = rec_bias;
447  this->nonrec_bias = nonrec_bias;
448  } else {
449  this->nonrec_bias = std::max(rec_bias, nonrec_bias);
450  }
451  }
452 
453  /// Combine two inequalities by taking the maxima of corresponding
454  /// coefficients.
455  ///
456  /// This essentially combines `x >= max(a1*x+b1, c1)` and `x >= max(a2*x+b2,
457  /// c2)` into a new `x >= max(max(a1,a2)*x+max(b1,b2), max(c1,c2))`. This is
458  /// a pessimistic upper bound, since e.g. `x >= 2x-10` and `x >= x-5` may both
459  /// hold, but the resulting `x >= 2x-5` may pessimistically not hold.
460  static LinIneq max(const LinIneq &lhs, const LinIneq &rhs) {
461  return LinIneq(std::max(lhs.rec_scale, rhs.rec_scale),
462  std::max(lhs.rec_bias, rhs.rec_bias),
463  std::max(lhs.nonrec_bias, rhs.nonrec_bias),
464  lhs.failed || rhs.failed);
465  }
466 
467  /// Combine two inequalities by summing up the two right hand sides.
468  ///
469  /// This is a tricky one, since the addition of the two max terms will lead to
470  /// a maximum over four possible terms (similar to a binomial expansion). In
471  /// order to shoehorn this back into a two-term maximum, we have to pick the
472  /// recursive term that will grow the fastest.
473  ///
474  /// As an example for this problem, consider the following addition:
475  ///
476  /// x >= max(a1*x+b1, c1) + max(a2*x+b2, c2)
477  ///
478  /// We would like to expand and rearrange this again into a maximum:
479  ///
480  /// x >= max(a1*x+b1 + max(a2*x+b2, c2), c1 + max(a2*x+b2, c2))
481  /// x >= max(max(a1*x+b1 + a2*x+b2, a1*x+b1 + c2),
482  /// max(c1 + a2*x+b2, c1 + c2))
483  /// x >= max((a1+a2)*x+(b1+b2), a1*x+(b1+c2), a2*x+(b2+c1), c1+c2)
484  ///
485  /// Since we are combining two two-term maxima, there are four possible ways
486  /// how the terms can combine, leading to the above four-term maximum. An easy
487  /// upper bound of the form we want would be the following:
488  ///
489  /// x >= max(max(a1+a2, a1, a2)*x + max(b1+b2, b1+c2, b2+c1), c1+c2)
490  ///
491  /// However, this is a very pessimistic upper-bound that will declare very
492  /// common patterns in the IR as unbreakable cycles, despite them being very
493  /// much breakable. For example:
494  ///
495  /// x >= max(x, 42) + max(0, -3) <-- breakable recursion
496  /// x >= max(max(1+0, 1, 0)*x + max(42+0, -3, 42), 42-2)
497  /// x >= max(x + 42, 39) <-- unbreakable recursion!
498  ///
499  /// A better approach is to take the expanded four-term maximum, retain the
500  /// non-recursive term (c1+c2), and estimate which one of the recursive terms
501  /// (first three) will become dominant as we choose greater values for x.
502  /// Since x never is inferred to be negative, the recursive term in the
503  /// maximum with the highest scaling factor for x will end up dominating as
504  /// x tends to ∞:
505  ///
506  /// x >= max({
507  /// (a1+a2)*x+(b1+b2) if a1+a2 >= max(a1+a2, a1, a2) and a1>0 and a2>0,
508  /// a1*x+(b1+c2) if a1 >= max(a1+a2, a1, a2) and a1>0,
509  /// a2*x+(b2+c1) if a2 >= max(a1+a2, a1, a2) and a2>0,
510  /// 0 otherwise
511  /// }, c1+c2)
512  ///
513  /// In case multiple cases apply, the highest bias of the recursive term is
514  /// picked. With this, the above problematic example triggers the second case
515  /// and becomes:
516  ///
517  /// x >= max(1*x+(0-3), 42-3) = max(x-3, 39)
518  ///
519  /// Of which the first case is chosen, as it has the lower bias value.
520  static LinIneq add(const LinIneq &lhs, const LinIneq &rhs) {
521  // Determine the maximum scaling factor among the three possible recursive
522  // terms.
523  auto enable1 = lhs.rec_scale > 0 && rhs.rec_scale > 0;
524  auto enable2 = lhs.rec_scale > 0;
525  auto enable3 = rhs.rec_scale > 0;
526  auto scale1 = lhs.rec_scale + rhs.rec_scale; // (a1+a2)
527  auto scale2 = lhs.rec_scale; // a1
528  auto scale3 = rhs.rec_scale; // a2
529  auto bias1 = lhs.rec_bias + rhs.rec_bias; // (b1+b2)
530  auto bias2 = lhs.rec_bias + rhs.nonrec_bias; // (b1+c2)
531  auto bias3 = rhs.rec_bias + lhs.nonrec_bias; // (b2+c1)
532  auto maxScale = std::max(scale1, std::max(scale2, scale3));
533 
534  // Among those terms that have a maximum scaling factor, determine the
535  // largest bias value.
536  std::optional<int32_t> maxBias;
537  if (enable1 && scale1 == maxScale)
538  maxBias = bias1;
539  if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
540  maxBias = bias2;
541  if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
542  maxBias = bias3;
543 
544  // Pick from the recursive terms the one with maximum scaling factor and
545  // minimum bias value.
546  auto nonrec_bias = lhs.nonrec_bias + rhs.nonrec_bias; // c1+c2
547  auto failed = lhs.failed || rhs.failed;
548  if (enable1 && scale1 == maxScale && bias1 == *maxBias)
549  return LinIneq(scale1, bias1, nonrec_bias, failed);
550  if (enable2 && scale2 == maxScale && bias2 == *maxBias)
551  return LinIneq(scale2, bias2, nonrec_bias, failed);
552  if (enable3 && scale3 == maxScale && bias3 == *maxBias)
553  return LinIneq(scale3, bias3, nonrec_bias, failed);
554  return LinIneq(0, 0, nonrec_bias, failed);
555  }
556 
557  /// Check if the inequality is satisfiable.
558  ///
559  /// The inequality becomes unsatisfiable if the RHS is ∞, or a>1, or a==1 and
560  /// b <= 0. Otherwise there exists as solution for `x` that satisfies the
561  /// inequality.
562  bool sat() const {
563  if (failed)
564  return false;
565  if (rec_scale > 1)
566  return false;
567  if (rec_scale == 1 && rec_bias > 0)
568  return false;
569  return true;
570  }
571 
572  /// Dump the inequality in human-readable form.
573  void print(llvm::raw_ostream &os) const {
574  bool any = false;
575  bool both = (rec_scale != 0 || rec_bias != 0) && nonrec_bias != 0;
576  os << "x >= ";
577  if (both)
578  os << "max(";
579  if (rec_scale != 0) {
580  any = true;
581  if (rec_scale != 1)
582  os << rec_scale << "*";
583  os << "x";
584  }
585  if (rec_bias != 0) {
586  if (any) {
587  if (rec_bias < 0)
588  os << " - " << -rec_bias;
589  else
590  os << " + " << rec_bias;
591  } else {
592  any = true;
593  os << rec_bias;
594  }
595  }
596  if (both)
597  os << ", ";
598  if (nonrec_bias != 0) {
599  any = true;
600  os << nonrec_bias;
601  }
602  if (both)
603  os << ")";
604  if (failed) {
605  if (any)
606  os << " + ";
607  os << "∞";
608  }
609  if (!any)
610  os << "0";
611  }
612 };
613 
614 /// A simple solver for width constraints.
615 class ConstraintSolver {
616 public:
617  ConstraintSolver() = default;
618 
619  VarExpr *var() {
620  auto v = vars.alloc();
621  exprs.push_back(v);
622  if (currentInfo)
623  info[v].insert(currentInfo);
624  if (currentLoc)
625  locs[v].insert(*currentLoc);
626  return v;
627  }
628  DerivedExpr *derived() {
629  auto *d = derivs.alloc();
630  exprs.push_back(d);
631  return d;
632  }
633  KnownExpr *known(int32_t value) { return alloc<KnownExpr>(knowns, value); }
634  IdExpr *id(Expr *arg) { return alloc<IdExpr>(ids, arg); }
635  PowExpr *pow(Expr *arg) { return alloc<PowExpr>(uns, arg); }
636  AddExpr *add(Expr *lhs, Expr *rhs) { return alloc<AddExpr>(bins, lhs, rhs); }
637  MaxExpr *max(Expr *lhs, Expr *rhs) { return alloc<MaxExpr>(bins, lhs, rhs); }
638  MinExpr *min(Expr *lhs, Expr *rhs) { return alloc<MinExpr>(bins, lhs, rhs); }
639 
640  /// Add a constraint `lhs >= rhs`. Multiple constraints on the same variable
641  /// are coalesced into a `max(a, b)` expr.
642  Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
643  if (lhs->constraint)
644  lhs->constraint = max(lhs->constraint, rhs);
645  else
646  lhs->constraint = id(rhs);
647  return lhs->constraint;
648  }
649 
650  /// Add a constraint `lhs <= rhs`. Multiple constraints on the same variable
651  /// are coalesced into a `min(a, b)` expr.
652  Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
653  if (lhs->upperBound)
654  lhs->upperBound = min(lhs->upperBound, rhs);
655  else
656  lhs->upperBound = id(rhs);
657  return lhs->upperBound;
658  }
659 
660  void dumpConstraints(llvm::raw_ostream &os);
661  LogicalResult solve();
662 
663  using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
664  const ContextInfo &getContextInfo() const { return info; }
665  void setCurrentContextInfo(FieldRef fieldRef) { currentInfo = fieldRef; }
666  void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
667 
668 private:
669  // Allocator for constraint expressions.
670  llvm::BumpPtrAllocator allocator;
671  Allocator<VarExpr> vars = {allocator};
672  Allocator<DerivedExpr> derivs = {allocator};
673  InternedAllocator<KnownExpr> knowns = {allocator};
674  InternedAllocator<IdExpr> ids = {allocator};
675  InternedAllocator<UnaryExpr> uns = {allocator};
676  InternedAllocator<BinaryExpr> bins = {allocator};
677 
678  /// A list of expressions in the order they were created.
679  std::vector<Expr *> exprs;
680  RootExpr root = {exprs};
681 
682  /// Add an allocated expression to the list above.
683  template <typename R, typename T, typename... Args>
684  R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
685  auto it = allocator.template alloc<R>(std::forward<Args>(args)...);
686  if (it.second)
687  exprs.push_back(it.first);
688  if (currentInfo)
689  info[it.first].insert(currentInfo);
690  if (currentLoc)
691  locs[it.first].insert(*currentLoc);
692  return it.first;
693  }
694 
695  /// Contextual information for each expression, indicating which values in the
696  /// IR lead to this expression.
697  ContextInfo info;
698  FieldRef currentInfo = {};
699  DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
700  std::optional<Location> currentLoc;
701 
702  // Forbid copyign or moving the solver, which would invalidate the refs to
703  // allocator held by the allocators.
704  ConstraintSolver(ConstraintSolver &&) = delete;
705  ConstraintSolver(const ConstraintSolver &) = delete;
706  ConstraintSolver &operator=(ConstraintSolver &&) = delete;
707  ConstraintSolver &operator=(const ConstraintSolver &) = delete;
708 
709  void emitUninferredWidthError(VarExpr *var);
710 
711  LinIneq checkCycles(VarExpr *var, Expr *expr,
712  SmallPtrSetImpl<Expr *> &seenVars,
713  InFlightDiagnostic *reportInto = nullptr,
714  unsigned indent = 1);
715 };
716 
717 } // namespace
718 
719 /// Print all constraints in the solver to an output stream.
720 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
721  for (auto *e : exprs) {
722  if (auto *v = dyn_cast<VarExpr>(e)) {
723  if (v->constraint)
724  os << "- " << *v << " >= " << *v->constraint << "\n";
725  else
726  os << "- " << *v << " unconstrained\n";
727  }
728  }
729 }
730 
731 #ifndef NDEBUG
732 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const LinIneq &l) {
733  l.print(os);
734  return os;
735 }
736 #endif
737 
738 /// Compute the canonicalized linear inequality expression starting at `expr`,
739 /// for the `var` as the left hand side `x` of the inequality. `seenVars` is
740 /// used as a recursion breaker. Occurrences of `var` itself within the
741 /// expression are mapped to the `a` coefficient in the inequality. Any other
742 /// variables are substituted and, in the presence of a recursion in a variable
743 /// other than `var`, treated as zero. `info` is a mapping from constraint
744 /// expressions to values and operations that produced the expression, and is
745 /// used during error reporting. If `reportInto` is present, the function will
746 /// additionally attach unsatisfiable inequalities as notes to the diagnostic as
747 /// it encounters them.
748 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
749  SmallPtrSetImpl<Expr *> &seenVars,
750  InFlightDiagnostic *reportInto,
751  unsigned indent) {
752  auto ineq =
753  TypeSwitch<Expr *, LinIneq>(expr)
754  .Case<KnownExpr>([&](auto *expr) { return LinIneq(*expr->solution); })
755  .Case<VarExpr>([&](auto *expr) {
756  if (expr == var)
757  return LinIneq(1, 0); // x >= 1*x + 0
758  if (!seenVars.insert(expr).second)
759  // Count recursions in other variables as 0. This is sane
760  // since the cycle is either breakable, in which case the
761  // recursion does not modify the resulting value of the
762  // variable, or it is not breakable and will be caught by
763  // this very function once it is called on that variable.
764  return LinIneq(0);
765  if (!expr->constraint)
766  // Count unconstrained variables as `x >= 0`.
767  return LinIneq(0);
768  auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
769  indent + 1);
770  seenVars.erase(expr);
771  return l;
772  })
773  .Case<IdExpr>([&](auto *expr) {
774  return checkCycles(var, expr->arg, seenVars, reportInto,
775  indent + 1);
776  })
777  .Case<PowExpr>([&](auto *expr) {
778  // If we can evaluate `2**arg` to a sensible constant, do
779  // so. This is the case if a == 0 and c < 31 such that 2**c is
780  // representable.
781  auto arg =
782  checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
783  if (arg.rec_scale != 0 || arg.nonrec_bias < 0 ||
784  arg.nonrec_bias >= 31)
785  return LinIneq::unsat();
786  return LinIneq(1 << arg.nonrec_bias); // x >= 2**arg
787  })
788  .Case<AddExpr>([&](auto *expr) {
789  return LinIneq::add(
790  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
791  checkCycles(var, expr->rhs(), seenVars, reportInto,
792  indent + 1));
793  })
794  .Case<MaxExpr, MinExpr>([&](auto *expr) {
795  // Combine the inequalities of the LHS and RHS into a single overly
796  // pessimistic inequality. We treat `MinExpr` the same as `MaxExpr`,
797  // since `max(a,b)` is an upper bound to `min(a,b)`.
798  return LinIneq::max(
799  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
800  checkCycles(var, expr->rhs(), seenVars, reportInto,
801  indent + 1));
802  })
803  .Default([](auto) { return LinIneq::unsat(); });
804 
805  // If we were passed an in-flight diagnostic and the current inequality is
806  // unsatisfiable, attach notes to the diagnostic indicating the values or
807  // operations that contributed to this part of the constraint expression.
808  if (reportInto && !ineq.sat()) {
809  auto report = [&](Location loc) {
810  auto &note = reportInto->attachNote(loc);
811  note << "constrained width W >= ";
812  if (ineq.rec_scale == -1)
813  note << "-";
814  if (ineq.rec_scale != 1)
815  note << ineq.rec_scale;
816  note << "W";
817  if (ineq.rec_bias < 0)
818  note << "-" << -ineq.rec_bias;
819  if (ineq.rec_bias > 0)
820  note << "+" << ineq.rec_bias;
821  note << " here:";
822  };
823  auto it = locs.find(expr);
824  if (it != locs.end())
825  for (auto loc : it->second)
826  report(loc);
827  }
828  if (!reportInto)
829  LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
830  << "- Visited " << *expr << ": " << ineq << "\n");
831 
832  return ineq;
833 }
834 
835 using ExprSolution = std::pair<std::optional<int32_t>, bool>;
836 
837 static ExprSolution
838 computeUnary(ExprSolution arg, llvm::function_ref<int32_t(int32_t)> operation) {
839  if (arg.first)
840  arg.first = operation(*arg.first);
841  return arg;
842 }
843 
844 static ExprSolution
846  llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
847  auto result = ExprSolution{std::nullopt, lhs.second || rhs.second};
848  if (lhs.first && rhs.first)
849  result.first = operation(*lhs.first, *rhs.first);
850  else if (lhs.first)
851  result.first = lhs.first;
852  else if (rhs.first)
853  result.first = rhs.first;
854  return result;
855 }
856 
857 /// Compute the value of a constraint `expr`. `seenVars` is used as a recursion
858 /// breaker. Recursive variables are treated as zero. Returns the computed value
859 /// and a boolean indicating whether a recursion was detected. This may be used
860 /// to memoize the result of expressions in case they were not involved in a
861 /// cycle (which may alter their value from the perspective of a variable).
862 static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
863  unsigned defaultWorklistSize) {
864 
865  struct Frame {
866  Expr *expr;
867  unsigned indent;
868  };
869 
870  // indent only used for debug logs.
871  unsigned indent = 1;
872  std::vector<Frame> worklist({{expr, indent}});
873  llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
874  // Reserving the vector size, to avoid frequent reallocs. The worklist can be
875  // quite large.
876  worklist.reserve(defaultWorklistSize);
877 
878  while (!worklist.empty()) {
879  auto &frame = worklist.back();
880  auto setSolution = [&](ExprSolution solution) {
881  // Memoize the result.
882  if (solution.first && !solution.second)
883  frame.expr->solution = *solution.first;
884  solvedExprs[frame.expr] = solution;
885 
886  // Produce some useful debug prints.
887  LLVM_DEBUG({
888  if (!isa<KnownExpr>(frame.expr)) {
889  if (solution.first)
890  llvm::dbgs().indent(frame.indent * 2)
891  << "= Solved " << *frame.expr << " = " << *solution.first;
892  else
893  llvm::dbgs().indent(frame.indent * 2)
894  << "= Skipped " << *frame.expr;
895  llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
896  << ")\n";
897  }
898  });
899 
900  worklist.pop_back();
901  };
902 
903  // See if we have a memoized result we can return.
904  if (frame.expr->solution) {
905  LLVM_DEBUG({
906  if (!isa<KnownExpr>(frame.expr))
907  llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
908  << *frame.expr->solution << "\n";
909  });
910  setSolution(ExprSolution{*frame.expr->solution, false});
911  continue;
912  }
913 
914  // Otherwise compute the value of the expression.
915  LLVM_DEBUG({
916  if (!isa<KnownExpr>(frame.expr))
917  llvm::dbgs().indent(frame.indent * 2)
918  << "- Solving " << *frame.expr << "\n";
919  });
920 
921  TypeSwitch<Expr *>(frame.expr)
922  .Case<KnownExpr>([&](auto *expr) {
923  setSolution(ExprSolution{*expr->solution, false});
924  })
925  .Case<VarExpr>([&](auto *expr) {
926  if (solvedExprs.contains(expr->constraint)) {
927  auto solution = solvedExprs[expr->constraint];
928  // If we've solved the upper bound already, store the solution.
929  // This will be explicitly solved for later if not computed as
930  // part of the solving that resolved this constraint.
931  // This should only happen if somehow the constraint is
932  // solved before visiting this expression, so that our upperBound
933  // was not added to the worklist such that it was handled first.
934  if (expr->upperBound && solvedExprs.contains(expr->upperBound))
935  expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
936  seenVars.erase(expr);
937  // Constrain variables >= 0.
938  if (solution.first && *solution.first < 0)
939  solution.first = 0;
940  return setSolution(solution);
941  }
942 
943  // Unconstrained variables produce no solution.
944  if (!expr->constraint)
945  return setSolution(ExprSolution{std::nullopt, false});
946  // Return no solution for recursions in the variables. This is sane
947  // and will cause the expression to be ignored when computing the
948  // parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
949  if (!seenVars.insert(expr).second)
950  return setSolution(ExprSolution{std::nullopt, true});
951 
952  worklist.push_back({expr->constraint, indent + 1});
953  if (expr->upperBound)
954  worklist.push_back({expr->upperBound, indent + 1});
955  })
956  .Case<IdExpr>([&](auto *expr) {
957  if (solvedExprs.contains(expr->arg))
958  return setSolution(solvedExprs[expr->arg]);
959  worklist.push_back({expr->arg, indent + 1});
960  })
961  .Case<PowExpr>([&](auto *expr) {
962  if (solvedExprs.contains(expr->arg))
963  return setSolution(computeUnary(
964  solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
965 
966  worklist.push_back({expr->arg, indent + 1});
967  })
968  .Case<AddExpr>([&](auto *expr) {
969  if (solvedExprs.contains(expr->lhs()) &&
970  solvedExprs.contains(expr->rhs()))
971  return setSolution(computeBinary(
972  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
973  [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
974 
975  worklist.push_back({expr->lhs(), indent + 1});
976  worklist.push_back({expr->rhs(), indent + 1});
977  })
978  .Case<MaxExpr>([&](auto *expr) {
979  if (solvedExprs.contains(expr->lhs()) &&
980  solvedExprs.contains(expr->rhs()))
981  return setSolution(computeBinary(
982  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
983  [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
984 
985  worklist.push_back({expr->lhs(), indent + 1});
986  worklist.push_back({expr->rhs(), indent + 1});
987  })
988  .Case<MinExpr>([&](auto *expr) {
989  if (solvedExprs.contains(expr->lhs()) &&
990  solvedExprs.contains(expr->rhs()))
991  return setSolution(computeBinary(
992  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
993  [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
994 
995  worklist.push_back({expr->lhs(), indent + 1});
996  worklist.push_back({expr->rhs(), indent + 1});
997  })
998  .Default([&](auto) {
999  setSolution(ExprSolution{std::nullopt, false});
1000  });
1001  }
1002 
1003  return solvedExprs[expr];
1004 }
1005 
1006 /// Solve the constraint problem. This is a very simple implementation that
1007 /// does not fully solve the problem if there are weird dependency cycles
1008 /// present.
1009 LogicalResult ConstraintSolver::solve() {
1010  LLVM_DEBUG({
1011  llvm::dbgs() << "\n===----- Constraints -----===\n\n";
1012  dumpConstraints(llvm::dbgs());
1013  });
1014 
1015  // Ensure that there are no adverse cycles around.
1016  LLVM_DEBUG(
1017  llvm::dbgs() << "\n===----- Checking for unbreakable loops -----===\n\n");
1018  SmallPtrSet<Expr *, 16> seenVars;
1019  bool anyFailed = false;
1020 
1021  for (auto *expr : exprs) {
1022  // Only work on variables.
1023  auto *var = dyn_cast<VarExpr>(expr);
1024  if (!var || !var->constraint)
1025  continue;
1026  LLVM_DEBUG(llvm::dbgs()
1027  << "- Checking " << *var << " >= " << *var->constraint << "\n");
1028 
1029  // Canonicalize the variable's constraint expression into a form that allows
1030  // us to easily determine if any recursion leads to an unsatisfiable
1031  // constraint. The `seenVars` set acts as a recursion breaker.
1032  seenVars.insert(var);
1033  auto ineq = checkCycles(var, var->constraint, seenVars);
1034  seenVars.clear();
1035 
1036  // If the constraint is satisfiable, we're done.
1037  // TODO: It's possible that this result is already sufficient to arrive at a
1038  // solution for the constraint, and the second pass further down is not
1039  // necessary. This would require more proper handling of `MinExpr` in the
1040  // cycle checking code.
1041  if (ineq.sat()) {
1042  LLVM_DEBUG(llvm::dbgs()
1043  << " = Breakable since " << ineq << " satisfiable\n");
1044  continue;
1045  }
1046 
1047  // If we arrive here, the constraint is not satisfiable at all. To provide
1048  // some guidance to the user, we call the cycle checking code again, but
1049  // this time with an in-flight diagnostic to attach notes indicating
1050  // unsatisfiable paths in the cycle.
1051  LLVM_DEBUG(llvm::dbgs()
1052  << " = UNBREAKABLE since " << ineq << " unsatisfiable\n");
1053  anyFailed = true;
1054  for (auto fieldRef : info.find(var)->second) {
1055  // Depending on whether this value stems from an operation or not, create
1056  // an appropriate diagnostic identifying the value.
1057  auto op = fieldRef.getDefiningOp();
1058  auto diag = op ? op->emitOpError()
1059  : mlir::emitError(fieldRef.getValue().getLoc())
1060  << "value ";
1061  diag << "is constrained to be wider than itself";
1062 
1063  // Re-run the cycle checking, but this time reporting into the diagnostic.
1064  seenVars.insert(var);
1065  checkCycles(var, var->constraint, seenVars, &diag);
1066  seenVars.clear();
1067  }
1068  }
1069 
1070  // If there were cycles, return now to avoid complaining to the user about
1071  // dependent widths not being inferred.
1072  if (anyFailed)
1073  return failure();
1074 
1075  // Iterate over the constraint variables and solve each.
1076  LLVM_DEBUG(llvm::dbgs() << "\n===----- Solving constraints -----===\n\n");
1077  unsigned defaultWorklistSize = exprs.size() / 2;
1078  for (auto *expr : exprs) {
1079  // Only work on variables.
1080  auto *var = dyn_cast<VarExpr>(expr);
1081  if (!var)
1082  continue;
1083 
1084  // Complain about unconstrained variables.
1085  if (!var->constraint) {
1086  LLVM_DEBUG(llvm::dbgs() << "- Unconstrained " << *var << "\n");
1087  emitUninferredWidthError(var);
1088  anyFailed = true;
1089  continue;
1090  }
1091 
1092  // Compute the value for the variable.
1093  LLVM_DEBUG(llvm::dbgs()
1094  << "- Solving " << *var << " >= " << *var->constraint << "\n");
1095  seenVars.insert(var);
1096  auto solution = solveExpr(var->constraint, seenVars, defaultWorklistSize);
1097  // Compute the upperBound if there is one and haven't already.
1098  if (var->upperBound && !var->upperBoundSolution)
1099  var->upperBoundSolution =
1100  solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
1101  seenVars.clear();
1102 
1103  // Constrain variables >= 0.
1104  if (solution.first && *solution.first < 0)
1105  solution.first = 0;
1106  var->solution = solution.first;
1107 
1108  // In case the width could not be inferred, complain to the user. This might
1109  // be the case if the width depends on an unconstrained variable.
1110  if (!solution.first) {
1111  LLVM_DEBUG(llvm::dbgs() << " - UNSOLVED " << *var << "\n");
1112  emitUninferredWidthError(var);
1113  anyFailed = true;
1114  continue;
1115  }
1116  LLVM_DEBUG(llvm::dbgs()
1117  << " = Solved " << *var << " = " << solution.first << " ("
1118  << (solution.second ? "cycle broken" : "unique") << ")\n");
1119 
1120  // Check if the solution we have found violates an upper bound.
1121  if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1122  LLVM_DEBUG(llvm::dbgs() << " ! Unsatisfiable " << *var
1123  << " <= " << var->upperBoundSolution << "\n");
1124  emitUninferredWidthError(var);
1125  anyFailed = true;
1126  }
1127  }
1128 
1129  // Copy over derived widths.
1130  for (auto *expr : exprs) {
1131  // Only work on derived values.
1132  auto *derived = dyn_cast<DerivedExpr>(expr);
1133  if (!derived)
1134  continue;
1135 
1136  auto *assigned = derived->assigned;
1137  if (!assigned || !assigned->solution) {
1138  LLVM_DEBUG(llvm::dbgs() << "- Unused " << *derived << " set to 0\n");
1139  derived->solution = 0;
1140  } else {
1141  LLVM_DEBUG(llvm::dbgs() << "- Deriving " << *derived << " = "
1142  << assigned->solution << "\n");
1143  derived->solution = *assigned->solution;
1144  }
1145  }
1146 
1147  return failure(anyFailed);
1148 }
1149 
1150 // Emits the diagnostic to inform the user about an uninferred width in the
1151 // design. Returns true if an error was reported, false otherwise.
1152 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1153  FieldRef fieldRef = info.find(var)->second.back();
1154  Value value = fieldRef.getValue();
1155 
1156  auto diag = mlir::emitError(value.getLoc(), "uninferred width:");
1157 
1158  // Try to hint the user at what kind of node this is.
1159  if (isa<BlockArgument>(value)) {
1160  diag << " port";
1161  } else if (auto op = value.getDefiningOp()) {
1162  TypeSwitch<Operation *>(op)
1163  .Case<WireOp>([&](auto) { diag << " wire"; })
1164  .Case<RegOp, RegResetOp>([&](auto) { diag << " reg"; })
1165  .Case<NodeOp>([&](auto) { diag << " node"; })
1166  .Default([&](auto) { diag << " value"; });
1167  } else {
1168  diag << " value";
1169  }
1170 
1171  // Actually print what the user can refer to.
1172  auto [fieldName, rootKnown] = getFieldName(fieldRef);
1173  if (!fieldName.empty()) {
1174  if (!rootKnown)
1175  diag << " field";
1176  diag << " \"" << fieldName << "\"";
1177  }
1178 
1179  if (!var->constraint) {
1180  diag << " is unconstrained";
1181  } else if (var->solution && var->upperBoundSolution &&
1182  var->solution > var->upperBoundSolution) {
1183  diag << " cannot satisfy all width requirements";
1184  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1185  LLVM_DEBUG(llvm::dbgs() << *var->upperBound << "\n");
1186  auto loc = locs.find(var->constraint)->second.back();
1187  diag.attachNote(loc) << "width is constrained to be at least "
1188  << *var->solution << " here:";
1189  loc = locs.find(var->upperBound)->second.back();
1190  diag.attachNote(loc) << "width is constrained to be at most "
1191  << *var->upperBoundSolution << " here:";
1192  } else {
1193  diag << " width cannot be determined";
1194  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1195  auto loc = locs.find(var->constraint)->second.back();
1196  diag.attachNote(loc) << "width is constrained by an uninferred width here:";
1197  }
1198 }
1199 
1200 //===----------------------------------------------------------------------===//
1201 // Inference Constraint Problem Mapping
1202 //===----------------------------------------------------------------------===//
1203 
1204 namespace {
1205 
1206 /// A helper class which maps the types and operations in a design to a set of
1207 /// variables and constraints to be solved later.
1208 class InferenceMapping {
1209 public:
1210  InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1211  hw::InnerSymbolTableCollection &istc)
1212  : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1213 
1214  LogicalResult map(CircuitOp op);
1215  LogicalResult mapOperation(Operation *op);
1216 
1217  /// Declare all the variables in the value. If the value is a ground type,
1218  /// there is a single variable declared. If the value is an aggregate type,
1219  /// it sets up variables for each unknown width.
1220  void declareVars(Value value, Location loc, bool isDerived = false);
1221 
1222  /// Declare a variable associated with a specific field of an aggregate.
1223  Expr *declareVar(FieldRef fieldRef, Location loc);
1224 
1225  /// Declarate a variable for a type with an unknown width. The type must be a
1226  /// non-aggregate.
1227  Expr *declareVar(FIRRTLType type, Location loc);
1228 
1229  /// Assign the constraint expressions of the fields in the `result` argument
1230  /// as the max of expressions in the `rhs` and `lhs` arguments. Both fields
1231  /// must be the same type.
1232  void maximumOfTypes(Value result, Value rhs, Value lhs);
1233 
1234  /// Constrain the value "larger" to be greater than or equal to "smaller".
1235  /// These may be aggregate values. This is used for regular connects.
1236  void constrainTypes(Value larger, Value smaller, bool equal = false);
1237 
1238  /// Constrain the expression "larger" to be greater than or equals to
1239  /// the expression "smaller".
1240  void constrainTypes(Expr *larger, Expr *smaller,
1241  bool imposeUpperBounds = false, bool equal = false);
1242 
1243  /// Assign the constraint expressions of the fields in the `src` argument as
1244  /// the expressions for the `dst` argument. Both fields must be of the given
1245  /// `type`.
1246  void unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type);
1247 
1248  /// Get the expr associated with the value. The value must be a non-aggregate
1249  /// type.
1250  Expr *getExpr(Value value) const;
1251 
1252  /// Get the expr associated with a specific field in a value.
1253  Expr *getExpr(FieldRef fieldRef) const;
1254 
1255  /// Get the expr associated with a specific field in a value. If value is
1256  /// NULL, then this returns NULL.
1257  Expr *getExprOrNull(FieldRef fieldRef) const;
1258 
1259  /// Set the expr associated with the value. The value must be a non-aggregate
1260  /// type.
1261  void setExpr(Value value, Expr *expr);
1262 
1263  /// Set the expr associated with a specific field in a value.
1264  void setExpr(FieldRef fieldRef, Expr *expr);
1265 
1266  /// Return whether a module was skipped due to being fully inferred already.
1267  bool isModuleSkipped(FModuleOp module) const {
1268  return skippedModules.count(module);
1269  }
1270 
1271  /// Return whether all modules in the mapping were fully inferred.
1272  bool areAllModulesSkipped() const { return allModulesSkipped; }
1273 
1274 private:
1275  /// The constraint solver into which we emit variables and constraints.
1276  ConstraintSolver &solver;
1277 
1278  /// The constraint exprs for each result type of an operation.
1279  DenseMap<FieldRef, Expr *> opExprs;
1280 
1281  /// The fully inferred modules that were skipped entirely.
1282  SmallPtrSet<Operation *, 16> skippedModules;
1283  bool allModulesSkipped = true;
1284 
1285  /// Cache of module symbols
1286  SymbolTable &symtbl;
1287 
1288  /// Full design inner symbol information.
1289  hw::InnerRefNamespace irn;
1290 };
1291 
1292 } // namespace
1293 
1294 /// Check if a type contains any FIRRTL type with uninferred widths.
1295 static bool hasUninferredWidth(Type type) {
1296  return TypeSwitch<Type, bool>(type)
1297  .Case<FIRRTLBaseType>([](auto base) { return base.hasUninferredWidth(); })
1298  .Case<RefType>(
1299  [](auto ref) { return ref.getType().hasUninferredWidth(); })
1300  .Default([](auto) { return false; });
1301 }
1302 
1303 LogicalResult InferenceMapping::map(CircuitOp op) {
1304  LLVM_DEBUG(llvm::dbgs()
1305  << "\n===----- Mapping ops to constraint exprs -----===\n\n");
1306 
1307  // Ensure we have constraint variables established for all module ports.
1308  for (auto module : op.getOps<FModuleOp>())
1309  for (auto arg : module.getArguments()) {
1310  solver.setCurrentContextInfo(FieldRef(arg, 0));
1311  declareVars(arg, module.getLoc());
1312  }
1313 
1314  for (auto module : op.getOps<FModuleOp>()) {
1315  // Check if the module contains *any* uninferred widths. This allows us to
1316  // do an early skip if the module is already fully inferred.
1317  bool anyUninferred = false;
1318  for (auto arg : module.getArguments()) {
1319  anyUninferred |= hasUninferredWidth(arg.getType());
1320  if (anyUninferred)
1321  break;
1322  }
1323  module.walk([&](Operation *op) {
1324  for (auto type : op->getResultTypes())
1325  anyUninferred |= hasUninferredWidth(type);
1326  if (anyUninferred)
1327  return WalkResult::interrupt();
1328  return WalkResult::advance();
1329  });
1330 
1331  if (!anyUninferred) {
1332  LLVM_DEBUG(llvm::dbgs() << "Skipping fully-inferred module '"
1333  << module.getName() << "'\n");
1334  skippedModules.insert(module);
1335  continue;
1336  }
1337 
1338  allModulesSkipped = false;
1339 
1340  // Go through operations in the module, creating type variables for results,
1341  // and generating constraints.
1342  auto result = module.getBodyBlock()->walk(
1343  [&](Operation *op) { return WalkResult(mapOperation(op)); });
1344  if (result.wasInterrupted())
1345  return failure();
1346  }
1347 
1348  return success();
1349 }
1350 
1351 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1352  // In case the operation result has a type without uninferred widths, don't
1353  // even bother to populate the constraint problem and treat that as a known
1354  // size directly. This is done in `declareVars`, which will generate
1355  // `KnownExpr` nodes for all known widths -- which are the only ones in this
1356  // case.
1357  bool allWidthsKnown = true;
1358  for (auto result : op->getResults()) {
1359  if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1360  if (hasUninferredWidth(op->getOperand(0).getType()))
1361  allWidthsKnown = false;
1362  // Only consider FIRRTL types for width constraints. Ignore any foreign
1363  // types as they don't participate in the width inference process.
1364  auto resultTy = type_dyn_cast<FIRRTLType>(result.getType());
1365  if (!resultTy)
1366  continue;
1367  if (!hasUninferredWidth(resultTy))
1368  declareVars(result, op->getLoc());
1369  else
1370  allWidthsKnown = false;
1371  }
1372  if (allWidthsKnown && !isa<FConnectLike, AttachOp>(op))
1373  return success();
1374 
1375  /// Ignore property assignments, no widths to infer.
1376  if (isa<PropAssignOp>(op))
1377  return success();
1378 
1379  // Actually generate the necessary constraint expressions.
1380  bool mappingFailed = false;
1381  solver.setCurrentContextInfo(
1382  op->getNumResults() > 0 ? FieldRef(op->getResults()[0], 0) : FieldRef());
1383  solver.setCurrentLocation(op->getLoc());
1384  TypeSwitch<Operation *>(op)
1385  .Case<ConstantOp>([&](auto op) {
1386  // If the constant has a known width, use that. Otherwise pick the
1387  // smallest number of bits necessary to represent the constant.
1388  Expr *e;
1389  if (auto width = op.getType().get().getWidth())
1390  e = solver.known(*width);
1391  else {
1392  auto v = op.getValue();
1393  auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1394  : v.countLeadingZeros());
1395  if (v.isSigned())
1396  w += 1;
1397  e = solver.known(std::max(w, 1u));
1398  }
1399  setExpr(op.getResult(), e);
1400  })
1401  .Case<SpecialConstantOp>([&](auto op) {
1402  // Nothing required.
1403  })
1404  .Case<InvalidValueOp>([&](auto op) {
1405  // We must duplicate the invalid value for each use, since each use can
1406  // be inferred to a different width.
1407  if (!hasUninferredWidth(op.getType()))
1408  return;
1409  declareVars(op.getResult(), op.getLoc(), /*isDerived=*/true);
1410  if (op.use_empty())
1411  return;
1412 
1413  auto type = op.getType();
1414  ImplicitLocOpBuilder builder(op->getLoc(), op);
1415  for (auto &use :
1416  llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
1417  // - `make_early_inc_range` since `getUses()` is invalidated upon
1418  // `use.set(...)`.
1419  // - `drop_begin` such that the first use can keep the original op.
1420  auto clone = builder.create<InvalidValueOp>(type);
1421  declareVars(clone.getResult(), clone.getLoc(),
1422  /*isDerived=*/true);
1423  use.set(clone);
1424  }
1425  })
1426  .Case<WireOp, RegOp>(
1427  [&](auto op) { declareVars(op.getResult(), op.getLoc()); })
1428  .Case<RegResetOp>([&](auto op) {
1429  // The original Scala code also constrains the reset signal to be at
1430  // least 1 bit wide. We don't do this here since the MLIR FIRRTL
1431  // dialect enforces the reset signal to be an async reset or a
1432  // `uint<1>`.
1433  declareVars(op.getResult(), op.getLoc());
1434  // Contrain the register to be greater than or equal to the reset
1435  // signal.
1436  constrainTypes(op.getResult(), op.getResetValue());
1437  })
1438  .Case<NodeOp>([&](auto op) {
1439  // Nodes have the same type as their input.
1440  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 0),
1441  op.getResult().getType());
1442  })
1443 
1444  // Aggregate Values
1445  .Case<SubfieldOp>([&](auto op) {
1446  BundleType bundleType = op.getInput().getType();
1447  auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1448  unifyTypes(FieldRef(op.getResult(), 0),
1449  FieldRef(op.getInput(), fieldID), op.getType());
1450  })
1451  .Case<SubindexOp, SubaccessOp>([&](auto op) {
1452  // All vec fields unify to the same thing. Always use the first element
1453  // of the vector, which has a field ID of 1.
1454  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
1455  op.getType());
1456  })
1457  .Case<SubtagOp>([&](auto op) {
1458  FEnumType enumType = op.getInput().getType();
1459  auto fieldID = enumType.getFieldID(op.getFieldIndex());
1460  unifyTypes(FieldRef(op.getResult(), 0),
1461  FieldRef(op.getInput(), fieldID), op.getType());
1462  })
1463 
1464  .Case<RefSubOp>([&](RefSubOp op) {
1465  uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1466  op.getInput().getType().getType())
1467  .Case<FVectorType>([](auto _) { return 1; })
1468  .Case<BundleType>([&](auto type) {
1469  return type.getFieldID(op.getIndex());
1470  });
1471  unifyTypes(FieldRef(op.getResult(), 0),
1472  FieldRef(op.getInput(), fieldID), op.getType());
1473  })
1474 
1475  // Arithmetic and Logical Binary Primitives
1476  .Case<AddPrimOp, SubPrimOp>([&](auto op) {
1477  auto lhs = getExpr(op.getLhs());
1478  auto rhs = getExpr(op.getRhs());
1479  auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1480  setExpr(op.getResult(), e);
1481  })
1482  .Case<MulPrimOp>([&](auto op) {
1483  auto lhs = getExpr(op.getLhs());
1484  auto rhs = getExpr(op.getRhs());
1485  auto e = solver.add(lhs, rhs);
1486  setExpr(op.getResult(), e);
1487  })
1488  .Case<DivPrimOp>([&](auto op) {
1489  auto lhs = getExpr(op.getLhs());
1490  Expr *e;
1491  if (op.getType().get().isSigned()) {
1492  e = solver.add(lhs, solver.known(1));
1493  } else {
1494  e = lhs;
1495  }
1496  setExpr(op.getResult(), e);
1497  })
1498  .Case<RemPrimOp>([&](auto op) {
1499  auto lhs = getExpr(op.getLhs());
1500  auto rhs = getExpr(op.getRhs());
1501  auto e = solver.min(lhs, rhs);
1502  setExpr(op.getResult(), e);
1503  })
1504  .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](auto op) {
1505  auto lhs = getExpr(op.getLhs());
1506  auto rhs = getExpr(op.getRhs());
1507  auto e = solver.max(lhs, rhs);
1508  setExpr(op.getResult(), e);
1509  })
1510 
1511  // Misc Binary Primitives
1512  .Case<CatPrimOp>([&](auto op) {
1513  auto lhs = getExpr(op.getLhs());
1514  auto rhs = getExpr(op.getRhs());
1515  auto e = solver.add(lhs, rhs);
1516  setExpr(op.getResult(), e);
1517  })
1518  .Case<DShlPrimOp>([&](auto op) {
1519  auto lhs = getExpr(op.getLhs());
1520  auto rhs = getExpr(op.getRhs());
1521  auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1522  setExpr(op.getResult(), e);
1523  })
1524  .Case<DShlwPrimOp, DShrPrimOp>([&](auto op) {
1525  auto e = getExpr(op.getLhs());
1526  setExpr(op.getResult(), e);
1527  })
1528 
1529  // Unary operators
1530  .Case<NegPrimOp>([&](auto op) {
1531  auto input = getExpr(op.getInput());
1532  auto e = solver.add(input, solver.known(1));
1533  setExpr(op.getResult(), e);
1534  })
1535  .Case<CvtPrimOp>([&](auto op) {
1536  auto input = getExpr(op.getInput());
1537  auto e = op.getInput().getType().get().isSigned()
1538  ? input
1539  : solver.add(input, solver.known(1));
1540  setExpr(op.getResult(), e);
1541  })
1542 
1543  // Miscellaneous
1544  .Case<BitsPrimOp>([&](auto op) {
1545  setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1546  })
1547  .Case<HeadPrimOp>([&](auto op) {
1548  setExpr(op.getResult(), solver.known(op.getAmount()));
1549  })
1550  .Case<TailPrimOp>([&](auto op) {
1551  auto input = getExpr(op.getInput());
1552  auto e = solver.add(input, solver.known(-op.getAmount()));
1553  setExpr(op.getResult(), e);
1554  })
1555  .Case<PadPrimOp>([&](auto op) {
1556  auto input = getExpr(op.getInput());
1557  auto e = solver.max(input, solver.known(op.getAmount()));
1558  setExpr(op.getResult(), e);
1559  })
1560  .Case<ShlPrimOp>([&](auto op) {
1561  auto input = getExpr(op.getInput());
1562  auto e = solver.add(input, solver.known(op.getAmount()));
1563  setExpr(op.getResult(), e);
1564  })
1565  .Case<ShrPrimOp>([&](auto op) {
1566  auto input = getExpr(op.getInput());
1567  auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1568  solver.known(1));
1569  setExpr(op.getResult(), e);
1570  })
1571 
1572  // Handle operations whose output width matches the input width.
1573  .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1574  [&](auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1575  .Case<mlir::UnrealizedConversionCastOp>(
1576  [&](auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1577 
1578  // Handle operations with a single result type that always has a
1579  // well-known width.
1580  .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1581  AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1582  XorRPrimOp>([&](auto op) {
1583  auto width = op.getType().getBitWidthOrSentinel();
1584  assert(width > 0 && "width should have been checked by verifier");
1585  setExpr(op.getResult(), solver.known(width));
1586  })
1587  .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
1588  auto *sel = getExpr(op.getSel());
1589  constrainTypes(sel, solver.known(1));
1590  maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1591  })
1592  .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1593  auto *sel = getExpr(op.getSel());
1594  constrainTypes(sel, solver.known(2));
1595  maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1596  maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1597  maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1598  })
1599 
1600  .Case<ConnectOp, StrictConnectOp>(
1601  [&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1602  .Case<RefDefineOp>([&](auto op) {
1603  // Dest >= Src, but also check Src <= Dest for correctness
1604  // (but don't solve to make this true, don't back-propagate)
1605  constrainTypes(op.getDest(), op.getSrc(), true);
1606  })
1607  // StrictConnect is an identify constraint
1608  .Case<StrictConnectOp>([&](auto op) {
1609  // This back-propagates width from destination to source,
1610  // causing source to sometimes be inferred wider than
1611  // it should be (https://github.com/llvm/circt/issues/5391).
1612  // This is needed to push the width into feeding widthCast?
1613  constrainTypes(op.getDest(), op.getSrc());
1614  constrainTypes(op.getSrc(), op.getDest());
1615  })
1616  .Case<AttachOp>([&](auto op) {
1617  // Attach connects multiple analog signals together. All signals must
1618  // have the same bit width. Signals without bit width inherit from the
1619  // other signals.
1620  if (op.getAttached().empty())
1621  return;
1622  auto prev = op.getAttached()[0];
1623  for (auto operand : op.getAttached().drop_front()) {
1624  auto e1 = getExpr(prev);
1625  auto e2 = getExpr(operand);
1626  constrainTypes(e1, e2, /*imposeUpperBounds=*/true);
1627  constrainTypes(e2, e1, /*imposeUpperBounds=*/true);
1628  prev = operand;
1629  }
1630  })
1631 
1632  // Handle the no-ops that don't interact with width inference.
1633  .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp, CoverOp>(
1634  [&](auto) {})
1635 
1636  // Handle instances of other modules.
1637  .Case<InstanceOp>([&](auto op) {
1638  auto refdModule = op.getReferencedModule(symtbl);
1639  auto module = dyn_cast<FModuleOp>(&*refdModule);
1640  if (!module) {
1641  auto diag = mlir::emitError(op.getLoc());
1642  diag << "extern module `" << op.getModuleName()
1643  << "` has ports of uninferred width";
1644 
1645  auto fml = cast<FModuleLike>(&*refdModule);
1646  auto ports = fml.getPorts();
1647  for (auto &port : ports) {
1648  auto baseType = getBaseType(port.type);
1649  if (baseType && baseType.hasUninferredWidth()) {
1650  diag.attachNote(op.getLoc()) << "Port: " << port.name;
1651  if (!baseType.isGround())
1652  diagnoseUninferredType(diag, baseType, port.name.getValue());
1653  }
1654  }
1655 
1656  diag.attachNote(op.getLoc())
1657  << "Only non-extern FIRRTL modules may contain unspecified "
1658  "widths to be inferred automatically.";
1659  diag.attachNote(refdModule->getLoc())
1660  << "Module `" << op.getModuleName() << "` defined here:";
1661  mappingFailed = true;
1662  return;
1663  }
1664  // Simply look up the free variables created for the instantiated
1665  // module's ports, and use them for instance port wires. This way,
1666  // constraints imposed onto the ports of the instance will transparently
1667  // apply to the ports of the instantiated module.
1668  for (auto it : llvm::zip(op->getResults(), module.getArguments())) {
1669  unifyTypes(FieldRef(std::get<0>(it), 0), FieldRef(std::get<1>(it), 0),
1670  type_cast<FIRRTLType>(std::get<0>(it).getType()));
1671  }
1672  })
1673 
1674  // Handle memories.
1675  .Case<MemOp>([&](MemOp op) {
1676  // Create constraint variables for all ports.
1677  unsigned nonDebugPort = 0;
1678  for (const auto &result : llvm::enumerate(op.getResults())) {
1679  declareVars(result.value(), op.getLoc());
1680  if (!type_isa<RefType>(result.value().getType()))
1681  nonDebugPort = result.index();
1682  }
1683 
1684  // A helper function that returns the indeces of the "data", "rdata",
1685  // and "wdata" fields in the bundle corresponding to a memory port.
1686  auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1687  static const unsigned indices[] = {3, 5};
1688  static const unsigned debug[] = {0};
1689  switch (kind) {
1690  case MemOp::PortKind::Read:
1691  case MemOp::PortKind::Write:
1692  return ArrayRef<unsigned>(indices, 1); // {3}
1693  case MemOp::PortKind::ReadWrite:
1694  return ArrayRef<unsigned>(indices); // {3, 5}
1695  case MemOp::PortKind::Debug:
1696  return ArrayRef<unsigned>(debug);
1697  }
1698  llvm_unreachable("Imposible PortKind");
1699  };
1700 
1701  // This creates independent variables for every data port. Yet, what we
1702  // actually want is for all data ports to share the same variable. To do
1703  // this, we find the first data port declared, and use that port's vars
1704  // for all the other ports.
1705  unsigned firstFieldIndex =
1706  dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1707  FieldRef firstData(
1708  op.getResult(nonDebugPort),
1709  type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1710  .getFieldID(firstFieldIndex));
1711  LLVM_DEBUG(llvm::dbgs() << "Adjusting memory port variables:\n");
1712 
1713  // Reuse data port variables.
1714  auto dataType = op.getDataType();
1715  for (unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1716  auto result = op.getResult(i);
1717  if (type_isa<RefType>(result.getType())) {
1718  // Debug ports are firrtl.ref<vector<data-type, depth>>
1719  // Use FieldRef of 1, to indicate the first vector element must be
1720  // of the dataType.
1721  unifyTypes(firstData, FieldRef(result, 1), dataType);
1722  continue;
1723  }
1724 
1725  auto portType =
1726  type_cast<BundleType>(op.getPortType(i).getPassiveType());
1727  for (auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1728  unifyTypes(FieldRef(result, portType.getFieldID(fieldIndex)),
1729  firstData, dataType);
1730  }
1731  })
1732 
1733  .Case<RefSendOp>([&](auto op) {
1734  declareVars(op.getResult(), op.getLoc());
1735  constrainTypes(op.getResult(), op.getBase(), true);
1736  })
1737  .Case<RefResolveOp>([&](auto op) {
1738  declareVars(op.getResult(), op.getLoc());
1739  constrainTypes(op.getResult(), op.getRef(), true);
1740  })
1741  .Case<RefCastOp>([&](auto op) {
1742  declareVars(op.getResult(), op.getLoc());
1743  constrainTypes(op.getResult(), op.getInput(), true);
1744  })
1745  .Case<RWProbeOp>([&](auto op) {
1746  declareVars(op.getResult(), op.getLoc());
1747  auto ist = irn.lookup(op.getTarget());
1748  if (!ist) {
1749  op->emitError("target of rwprobe could not be resolved");
1750  mappingFailed = true;
1751  return;
1752  }
1753  auto ref = getRefForIST(ist);
1754  if (!ref) {
1755  op->emitError("target of rwprobe resolved to unsupported target");
1756  mappingFailed = true;
1757  return;
1758  }
1759  auto newFID = convertFieldIDToOurVersion(
1760  ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1761  unifyTypes(FieldRef(op.getResult(), 0),
1762  FieldRef(ref.getValue(), newFID), op.getType());
1763  })
1764  .Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
1765  for (Value result : op.getResults()) {
1766  auto ty = result.getType();
1767  if (type_isa<FIRRTLType>(ty))
1768  declareVars(result, op.getLoc());
1769  }
1770  })
1771  .Default([&](auto op) {
1772  op->emitOpError("not supported in width inference");
1773  mappingFailed = true;
1774  });
1775 
1776  // Forceable declarations should have the ref constrained to data result.
1777  if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1778  unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
1779  fop.getDataType());
1780 
1781  return failure(mappingFailed);
1782 }
1783 
1784 /// Declare free variables for the type of a value, and associate the resulting
1785 /// set of variables with that value.
1786 void InferenceMapping::declareVars(Value value, Location loc, bool isDerived) {
1787  // Declare a variable for every unknown width in the type. If this is a Bundle
1788  // type or a FVector type, we will have to potentially create many variables.
1789  unsigned fieldID = 0;
1790  std::function<void(FIRRTLBaseType)> declare = [&](FIRRTLBaseType type) {
1791  auto width = type.getBitWidthOrSentinel();
1792  if (width >= 0) {
1793  // Known width integer create a known expression.
1794  setExpr(FieldRef(value, fieldID), solver.known(width));
1795  fieldID++;
1796  } else if (width == -1) {
1797  // Unknown width integers create a variable.
1798  FieldRef field(value, fieldID);
1799  solver.setCurrentContextInfo(field);
1800  if (isDerived)
1801  setExpr(field, solver.derived());
1802  else
1803  setExpr(field, solver.var());
1804  fieldID++;
1805  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1806  // Bundle types recursively declare all bundle elements.
1807  fieldID++;
1808  for (auto &element : bundleType) {
1809  declare(element.type);
1810  }
1811  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1812  fieldID++;
1813  auto save = fieldID;
1814  declare(vecType.getElementType());
1815  // Skip past the rest of the elements
1816  fieldID = save + vecType.getMaxFieldID();
1817  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1818  fieldID++;
1819  for (auto &element : enumType.getElements())
1820  declare(element.type);
1821  } else {
1822  llvm_unreachable("Unknown type inside a bundle!");
1823  }
1824  };
1825  if (auto type = getBaseType(value.getType()))
1826  declare(type);
1827 }
1828 
1829 /// Assign the constraint expressions of the fields in the `result` argument as
1830 /// the max of expressions in the `rhs` and `lhs` arguments. Both fields must be
1831 /// the same type.
1832 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1833  // Recurse to every leaf element and set larger >= smaller.
1834  auto fieldID = 0;
1835  std::function<void(FIRRTLBaseType)> maximize = [&](FIRRTLBaseType type) {
1836  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1837  fieldID++;
1838  for (auto &element : bundleType.getElements())
1839  maximize(element.type);
1840  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1841  fieldID++;
1842  auto save = fieldID;
1843  // Skip 0 length vectors.
1844  if (vecType.getNumElements() > 0)
1845  maximize(vecType.getElementType());
1846  fieldID = save + vecType.getMaxFieldID();
1847  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1848  fieldID++;
1849  for (auto &element : enumType.getElements())
1850  maximize(element.type);
1851  } else if (type.isGround()) {
1852  auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
1853  getExpr(FieldRef(lhs, fieldID)));
1854  setExpr(FieldRef(result, fieldID), e);
1855  fieldID++;
1856  } else {
1857  llvm_unreachable("Unknown type inside a bundle!");
1858  }
1859  };
1860  if (auto type = getBaseType(result.getType()))
1861  maximize(type);
1862 }
1863 
1864 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1865 /// than or equal to the sizes in the `smaller` type. Types have to be
1866 /// compatible in the sense that they may only differ in the presence or absence
1867 /// of bit widths.
1868 ///
1869 /// This function is used to apply regular connects.
1870 /// Set `equal` for constraining larger <= smaller for correctness but not
1871 /// solving.
1872 void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
1873  // Recurse to every leaf element and set larger >= smaller. Ignore foreign
1874  // types as these do not participate in width inference.
1875 
1876  auto fieldID = 0;
1877  std::function<void(FIRRTLBaseType, Value, Value)> constrain =
1878  [&](FIRRTLBaseType type, Value larger, Value smaller) {
1879  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1880  fieldID++;
1881  for (auto &element : bundleType.getElements()) {
1882  if (element.isFlip)
1883  constrain(element.type, smaller, larger);
1884  else
1885  constrain(element.type, larger, smaller);
1886  }
1887  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1888  fieldID++;
1889  auto save = fieldID;
1890  // Skip 0 length vectors.
1891  if (vecType.getNumElements() > 0) {
1892  constrain(vecType.getElementType(), larger, smaller);
1893  }
1894  fieldID = save + vecType.getMaxFieldID();
1895  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1896  fieldID++;
1897  for (auto &element : enumType.getElements())
1898  constrain(element.type, larger, smaller);
1899  } else if (type.isGround()) {
1900  // Leaf element, look up their expressions, and create the constraint.
1901  constrainTypes(getExpr(FieldRef(larger, fieldID)),
1902  getExpr(FieldRef(smaller, fieldID)), false, equal);
1903  fieldID++;
1904  } else {
1905  llvm_unreachable("Unknown type inside a bundle!");
1906  }
1907  };
1908 
1909  if (auto type = getBaseType(larger.getType()))
1910  constrain(type, larger, smaller);
1911 }
1912 
1913 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1914 /// than or equal to the sizes in the `smaller` type.
1915 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1916  bool imposeUpperBounds, bool equal) {
1917  assert(larger && "Larger expression should be specified");
1918  assert(smaller && "Smaller expression should be specified");
1919 
1920  // If one of the sides is `DerivedExpr`, simply assign the other side as the
1921  // derived width. This allows `InvalidValueOp`s to properly infer their width
1922  // from the connects they are used in, but also be inferred to something
1923  // useful on their own.
1924  if (auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1925  largerDerived->assigned = smaller;
1926  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *largerDerived << " from "
1927  << *smaller << "\n");
1928  return;
1929  }
1930  if (auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1931  smallerDerived->assigned = larger;
1932  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *smallerDerived << " from "
1933  << *larger << "\n");
1934  return;
1935  }
1936 
1937  // If the larger expr is a free variable, create a `expr >= x` constraint for
1938  // it that we can try to satisfy with the smallest width.
1939  if (auto largerVar = dyn_cast<VarExpr>(larger)) {
1940  [[maybe_unused]] auto *c = solver.addGeqConstraint(largerVar, smaller);
1941  LLVM_DEBUG(llvm::dbgs()
1942  << "Constrained " << *largerVar << " >= " << *c << "\n");
1943  // If we're constraining larger == smaller, add the LEQ contraint as well.
1944  // Solve for GEQ but check that LEQ is true.
1945  // Used for strictconnect, some reference operations, and anywhere the
1946  // widths should be inferred strictly in one direction but are required to
1947  // also be equal for correctness.
1948  if (equal) {
1949  [[maybe_unused]] auto *leq = solver.addLeqConstraint(largerVar, smaller);
1950  LLVM_DEBUG(llvm::dbgs()
1951  << "Constrained " << *largerVar << " <= " << *leq << "\n");
1952  }
1953  return;
1954  }
1955 
1956  // If the smaller expr is a free variable but the larger one is not, create a
1957  // `expr <= k` upper bound that we can verify once all lower bounds have been
1958  // satisfied. Since we are always picking the smallest width to satisfy all
1959  // `>=` constraints, any `<=` constraints have no effect on the solution
1960  // besides indicating that a width is unsatisfiable.
1961  if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1962  if (imposeUpperBounds || equal) {
1963  [[maybe_unused]] auto *c = solver.addLeqConstraint(smallerVar, larger);
1964  LLVM_DEBUG(llvm::dbgs()
1965  << "Constrained " << *smallerVar << " <= " << *c << "\n");
1966  }
1967  }
1968 }
1969 
1970 /// Assign the constraint expressions of the fields in the `src` argument as the
1971 /// expressions for the `dst` argument. Both fields must be of the given `type`.
1972 void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
1973  // Fast path for `unifyTypes(x, x, _)`.
1974  if (lhs == rhs)
1975  return;
1976 
1977  // Co-iterate the two field refs, recurring into every leaf element and set
1978  // them equal.
1979  auto fieldID = 0;
1980  std::function<void(FIRRTLBaseType)> unify = [&](FIRRTLBaseType type) {
1981  if (type.isGround()) {
1982  // Leaf element, unify the fields!
1983  FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
1984  FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
1985  LLVM_DEBUG(llvm::dbgs()
1986  << "Unify " << getFieldName(lhsFieldRef).first << " = "
1987  << getFieldName(rhsFieldRef).first << "\n");
1988  // Abandon variables becoming unconstrainable by the unification.
1989  if (auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
1990  solver.addGeqConstraint(var, solver.known(0));
1991  setExpr(lhsFieldRef, getExpr(rhsFieldRef));
1992  fieldID++;
1993  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1994  fieldID++;
1995  for (auto &element : bundleType) {
1996  unify(element.type);
1997  }
1998  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1999  fieldID++;
2000  auto save = fieldID;
2001  // Skip 0 length vectors.
2002  if (vecType.getNumElements() > 0) {
2003  unify(vecType.getElementType());
2004  }
2005  fieldID = save + vecType.getMaxFieldID();
2006  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2007  fieldID++;
2008  for (auto &element : enumType.getElements())
2009  unify(element.type);
2010  } else {
2011  llvm_unreachable("Unknown type inside a bundle!");
2012  }
2013  };
2014  if (auto ftype = getBaseType(type))
2015  unify(ftype);
2016 }
2017 
2018 /// Get the constraint expression for a value.
2019 Expr *InferenceMapping::getExpr(Value value) const {
2020  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2021  // A field ID of 0 indicates the entire value.
2022  return getExpr(FieldRef(value, 0));
2023 }
2024 
2025 /// Get the constraint expression for a value.
2026 Expr *InferenceMapping::getExpr(FieldRef fieldRef) const {
2027  auto expr = getExprOrNull(fieldRef);
2028  assert(expr && "constraint expr should have been constructed for value");
2029  return expr;
2030 }
2031 
2032 Expr *InferenceMapping::getExprOrNull(FieldRef fieldRef) const {
2033  auto it = opExprs.find(fieldRef);
2034  return it != opExprs.end() ? it->second : nullptr;
2035 }
2036 
2037 /// Associate a constraint expression with a value.
2038 void InferenceMapping::setExpr(Value value, Expr *expr) {
2039  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2040  // A field ID of 0 indicates the entire value.
2041  setExpr(FieldRef(value, 0), expr);
2042 }
2043 
2044 /// Associate a constraint expression with a value.
2045 void InferenceMapping::setExpr(FieldRef fieldRef, Expr *expr) {
2046  LLVM_DEBUG({
2047  llvm::dbgs() << "Expr " << *expr << " for " << fieldRef.getValue();
2048  if (fieldRef.getFieldID())
2049  llvm::dbgs() << " '" << getFieldName(fieldRef).first << "'";
2050  auto fieldName = getFieldName(fieldRef);
2051  if (fieldName.second)
2052  llvm::dbgs() << " (\"" << fieldName.first << "\")";
2053  llvm::dbgs() << "\n";
2054  });
2055  opExprs[fieldRef] = expr;
2056 }
2057 
2058 //===----------------------------------------------------------------------===//
2059 // Inference Result Application
2060 //===----------------------------------------------------------------------===//
2061 
2062 namespace {
2063 /// A helper class which maps the types and operations in a design to a set
2064 /// of variables and constraints to be solved later.
2065 class InferenceTypeUpdate {
2066 public:
2067  InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2068 
2069  LogicalResult update(CircuitOp op);
2070  FailureOr<bool> updateOperation(Operation *op);
2071  FailureOr<bool> updateValue(Value value);
2073 
2074 private:
2075  const InferenceMapping &mapping;
2076 };
2077 
2078 } // namespace
2079 
2080 /// Update the types throughout a circuit.
2081 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2082  LLVM_DEBUG(llvm::dbgs() << "\n===----- Update types -----===\n\n");
2083  return mlir::failableParallelForEach(
2084  op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2085  // Skip this module if it had no widths to be
2086  // inferred at all.
2087  if (mapping.isModuleSkipped(op))
2088  return success();
2089  auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2090  if (failed(updateOperation(op)))
2091  return WalkResult::interrupt();
2092  return WalkResult::advance();
2093  }).wasInterrupted();
2094  return failure(isFailed);
2095  });
2096 }
2097 
2098 /// Update the result types of an operation.
2099 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2100  bool anyChanged = false;
2101 
2102  for (Value v : op->getResults()) {
2103  auto result = updateValue(v);
2104  if (failed(result))
2105  return result;
2106  anyChanged |= *result;
2107  }
2108 
2109  // If this is a connect operation, width inference might have inferred a RHS
2110  // that is wider than the LHS, in which case an additional BitsPrimOp is
2111  // necessary to truncate the value.
2112  if (auto con = dyn_cast<ConnectOp>(op)) {
2113  auto lhs = con.getDest();
2114  auto rhs = con.getSrc();
2115  auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2116  auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2117 
2118  // Nothing to do if not base types.
2119  if (!lhsType || !rhsType)
2120  return anyChanged;
2121 
2122  auto lhsWidth = lhsType.getBitWidthOrSentinel();
2123  auto rhsWidth = rhsType.getBitWidthOrSentinel();
2124  if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2125  OpBuilder builder(op);
2126  auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2127  rhsWidth - lhsWidth);
2128  if (type_isa<SIntType>(rhsType))
2129  trunc =
2130  builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2131 
2132  LLVM_DEBUG(llvm::dbgs()
2133  << "Truncating RHS to " << lhsType << " in " << con << "\n");
2134  con->replaceUsesOfWith(con.getSrc(), trunc);
2135  }
2136  return anyChanged;
2137  }
2138 
2139  // If this is a module, update its ports.
2140  if (auto module = dyn_cast<FModuleOp>(op)) {
2141  // Update the block argument types.
2142  bool argsChanged = false;
2143  SmallVector<Attribute> argTypes;
2144  argTypes.reserve(module.getNumPorts());
2145  for (auto arg : module.getArguments()) {
2146  auto result = updateValue(arg);
2147  if (failed(result))
2148  return result;
2149  argsChanged |= *result;
2150  argTypes.push_back(TypeAttr::get(arg.getType()));
2151  }
2152 
2153  // Update the module function type if needed.
2154  if (argsChanged) {
2155  module->setAttr(FModuleLike::getPortTypesAttrName(),
2156  ArrayAttr::get(module.getContext(), argTypes));
2157  anyChanged = true;
2158  }
2159  }
2160  return anyChanged;
2161 }
2162 
2163 /// Resize a `uint`, `sint`, or `analog` type to a specific width.
2164 static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth) {
2165  auto *context = type.getContext();
2167  .Case<UIntType>([&](auto type) {
2168  return UIntType::get(context, newWidth, type.isConst());
2169  })
2170  .Case<SIntType>([&](auto type) {
2171  return SIntType::get(context, newWidth, type.isConst());
2172  })
2173  .Case<AnalogType>([&](auto type) {
2174  return AnalogType::get(context, newWidth, type.isConst());
2175  })
2176  .Default([&](auto type) { return type; });
2177 }
2178 
2179 /// Update the type of a value.
2180 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2181  // Check if the value has a type which we can update.
2182  auto type = type_dyn_cast<FIRRTLType>(value.getType());
2183  if (!type)
2184  return false;
2185 
2186  // Fast path for types that have fully inferred widths.
2187  if (!hasUninferredWidth(type))
2188  return false;
2189 
2190  // If this is an operation that does not generate any free variables that
2191  // are determined during width inference, simply update the value type based
2192  // on the operation arguments.
2193  if (auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2194  SmallVector<Type, 2> types;
2195  auto res =
2196  op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2197  op->getAttrDictionary(), op->getPropertiesStorage(),
2198  op->getRegions(), types);
2199  if (failed(res))
2200  return failure();
2201 
2202  assert(types.size() == op->getNumResults());
2203  for (auto it : llvm::zip(op->getResults(), types)) {
2204  LLVM_DEBUG(llvm::dbgs() << "Inferring " << std::get<0>(it) << " as "
2205  << std::get<1>(it) << "\n");
2206  std::get<0>(it).setType(std::get<1>(it));
2207  }
2208  return true;
2209  }
2210 
2211  // Recreate the type, substituting the solved widths.
2212  auto context = type.getContext();
2213  unsigned fieldID = 0;
2214  std::function<FIRRTLBaseType(FIRRTLBaseType)> updateBase =
2215  [&](FIRRTLBaseType type) -> FIRRTLBaseType {
2216  auto width = type.getBitWidthOrSentinel();
2217  if (width >= 0) {
2218  // Known width integers return themselves.
2219  fieldID++;
2220  return type;
2221  } else if (width == -1) {
2222  // Unknown width integers return the solved type.
2223  auto newType = updateType(FieldRef(value, fieldID), type);
2224  fieldID++;
2225  return newType;
2226  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2227  // Bundle types recursively update all bundle elements.
2228  fieldID++;
2229  llvm::SmallVector<BundleType::BundleElement, 3> elements;
2230  for (auto &element : bundleType) {
2231  auto updatedBase = updateBase(element.type);
2232  if (!updatedBase)
2233  return {};
2234  elements.emplace_back(element.name, element.isFlip, updatedBase);
2235  }
2236  return BundleType::get(context, elements, bundleType.isConst());
2237  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2238  fieldID++;
2239  auto save = fieldID;
2240  // TODO: this should recurse into the element type of 0 length vectors and
2241  // set any unknown width to 0.
2242  if (vecType.getNumElements() > 0) {
2243  auto updatedBase = updateBase(vecType.getElementType());
2244  if (!updatedBase)
2245  return {};
2246  auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2247  vecType.isConst());
2248  fieldID = save + vecType.getMaxFieldID();
2249  return newType;
2250  }
2251  // If this is a 0 length vector return the original type.
2252  return type;
2253  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2254  fieldID++;
2255  llvm::SmallVector<FEnumType::EnumElement> elements;
2256  for (auto &element : enumType.getElements()) {
2257  auto updatedBase = updateBase(element.type);
2258  if (!updatedBase)
2259  return {};
2260  elements.emplace_back(element.name, updatedBase);
2261  }
2262  return FEnumType::get(context, elements, enumType.isConst());
2263  }
2264  llvm_unreachable("Unknown type inside a bundle!");
2265  };
2266 
2267  // Update the type.
2268  auto newType = mapBaseTypeNullable(type, updateBase);
2269  if (!newType)
2270  return failure();
2271  LLVM_DEBUG(llvm::dbgs() << "Update " << value << " to " << newType << "\n");
2272  value.setType(newType);
2273 
2274  // If this is a ConstantOp, adjust the width of the underlying APInt.
2275  // Unsized constants have APInts which are *at least* wide enough to hold
2276  // the value, but may be larger. This can trip up the verifier.
2277  if (auto op = value.getDefiningOp<ConstantOp>()) {
2278  auto k = op.getValue();
2279  auto bitwidth = op.getType().getBitWidthOrSentinel();
2280  if (k.getBitWidth() > unsigned(bitwidth))
2281  k = k.trunc(bitwidth);
2282  op->setAttr("value", IntegerAttr::get(op.getContext(), k));
2283  }
2284 
2285  return newType != type;
2286 }
2287 
2288 /// Update a type.
2290  FIRRTLBaseType type) {
2291  assert(type.isGround() && "Can only pass in ground types.");
2292  auto value = fieldRef.getValue();
2293  // Get the inferred width.
2294  Expr *expr = mapping.getExprOrNull(fieldRef);
2295  if (!expr || !expr->solution) {
2296  // It should not be possible to arrive at an uninferred width at this point.
2297  // In case the constraints are not resolvable, checks before the calls to
2298  // `updateType` must have already caught the issues and aborted the pass
2299  // early. Might turn this into an assert later.
2300  mlir::emitError(value.getLoc(), "width should have been inferred");
2301  return {};
2302  }
2303  int32_t solution = *expr->solution;
2304  assert(solution >= 0); // The solver infers variables to be 0 or greater.
2305  return resizeType(type, solution);
2306 }
2307 
2308 //===----------------------------------------------------------------------===//
2309 // Hashing
2310 //===----------------------------------------------------------------------===//
2311 
2312 // Hash slots in the interned allocator as if they were the pointed-to value
2313 // itself.
2314 namespace llvm {
2315 template <typename T>
2316 struct DenseMapInfo<InternedSlot<T>> {
2317  using Slot = InternedSlot<T>;
2318  static Slot getEmptyKey() {
2319  auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2320  return Slot(static_cast<T *>(pointer));
2321  }
2323  auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2324  return Slot(static_cast<T *>(pointer));
2325  }
2326  static unsigned getHashValue(Slot val) { return mlir::hash_value(*val.ptr); }
2327  static bool isEqual(Slot LHS, Slot RHS) {
2328  auto empty = getEmptyKey().ptr;
2329  auto tombstone = getTombstoneKey().ptr;
2330  if (LHS.ptr == empty || RHS.ptr == empty || LHS.ptr == tombstone ||
2331  RHS.ptr == tombstone)
2332  return LHS.ptr == RHS.ptr;
2333  return *LHS.ptr == *RHS.ptr;
2334  }
2335 };
2336 } // namespace llvm
2337 
2338 //===----------------------------------------------------------------------===//
2339 // Pass Infrastructure
2340 //===----------------------------------------------------------------------===//
2341 
2342 namespace {
2343 class InferWidthsPass : public InferWidthsBase<InferWidthsPass> {
2344  void runOnOperation() override;
2345 };
2346 } // namespace
2347 
2348 void InferWidthsPass::runOnOperation() {
2349  // Collect variables and constraints
2350  ConstraintSolver solver;
2351  InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2352  getAnalysis<hw::InnerSymbolTableCollection>());
2353  if (failed(mapping.map(getOperation()))) {
2354  signalPassFailure();
2355  return;
2356  }
2357  if (mapping.areAllModulesSkipped()) {
2358  markAllAnalysesPreserved();
2359  return; // fast path if no inferrable widths are around
2360  }
2361 
2362  // Solve the constraints.
2363  if (failed(solver.solve())) {
2364  signalPassFailure();
2365  return;
2366  }
2367 
2368  // Update the types with the inferred widths.
2369  if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2370  signalPassFailure();
2371 }
2372 
2373 std::unique_ptr<mlir::Pass> circt::firrtl::createInferWidthsPass() {
2374  return std::make_unique<InferWidthsPass>();
2375 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
int32_t width
Definition: FIRRTL.cpp:27
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:89
static FieldRef getRefForIST(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
Definition: InferWidths.cpp:64
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
Definition: InferWidths.cpp:80
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
#define EXPR_KINDS
#define EXPR_CLASSES
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
Definition: InferWidths.cpp:44
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, unsigned defaultWorklistSize)
Compute the value of a constraint expr.
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static InstancePath empty
Builder builder
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:515
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:525
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
Definition: FIRRTLUtils.h:245
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:217
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:115
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
Definition: FIRRTLTypes.h:363
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28
llvm::hash_code hash_value(const T &e)
static unsigned getHashValue(Slot val)
static bool isEqual(Slot LHS, Slot RHS)