CIRCT  19.0.0git
InferWidths.cpp
Go to the documentation of this file.
1 //===- InferWidths.cpp - Infer width of types -------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines the InferWidths pass.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "PassDetails.h"
14 
20 #include "circt/Support/Debug.h"
21 #include "circt/Support/FieldRef.h"
22 #include "mlir/IR/ImplicitLocOpBuilder.h"
23 #include "mlir/IR/Threading.h"
24 #include "llvm/ADT/APSInt.h"
25 #include "llvm/ADT/DenseSet.h"
26 #include "llvm/ADT/GraphTraits.h"
27 #include "llvm/ADT/Hashing.h"
28 #include "llvm/ADT/MapVector.h"
29 #include "llvm/ADT/PostOrderIterator.h"
30 #include "llvm/ADT/SetVector.h"
31 #include "llvm/Support/Debug.h"
32 #include "llvm/Support/ErrorHandling.h"
33 
34 #define DEBUG_TYPE "infer-widths"
35 
36 using mlir::InferTypeOpInterface;
37 using mlir::WalkOrder;
38 
39 using namespace circt;
40 using namespace firrtl;
41 
42 //===----------------------------------------------------------------------===//
43 // Helpers
44 //===----------------------------------------------------------------------===//
45 
46 static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t,
47  Twine str) {
48  auto basetype = type_dyn_cast<FIRRTLBaseType>(t);
49  if (!basetype)
50  return;
51  if (!basetype.hasUninferredWidth())
52  return;
53 
54  if (basetype.isGround())
55  diag.attachNote() << "Field: \"" << str << "\"";
56  else if (auto vecType = type_dyn_cast<FVectorType>(basetype))
57  diagnoseUninferredType(diag, vecType.getElementType(), str + "[]");
58  else if (auto bundleType = type_dyn_cast<BundleType>(basetype))
59  for (auto &elem : bundleType.getElements())
60  diagnoseUninferredType(diag, elem.type, str + "." + elem.name.getValue());
61 }
62 
63 /// Get FieldRef pointing to the specified inner symbol target, which must be
64 /// valid. Returns null FieldRef if target points to something with no value,
65 /// such as a port of an external module.
66 static FieldRef getRefForIST(const hw::InnerSymTarget &ist) {
67  if (ist.isPort()) {
68  return TypeSwitch<Operation *, FieldRef>(ist.getOp())
69  .Case<FModuleOp>([&](auto fmod) {
70  return FieldRef(fmod.getArgument(ist.getPort()), ist.getField());
71  })
72  .Default({});
73  }
74 
75  auto symOp = dyn_cast<hw::InnerSymbolOpInterface>(ist.getOp());
76  assert(symOp && symOp.getTargetResultIndex() &&
77  (symOp.supportsPerFieldSymbols() || ist.getField() == 0));
78  return FieldRef(symOp.getTargetResult(), ist.getField());
79 }
80 
81 /// Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
82 static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type) {
83  uint64_t convertedFieldID = 0;
84 
85  auto curFID = fieldID;
86  Type curFType = type;
87  while (curFID != 0) {
88  auto [child, subID] =
89  hw::FieldIdImpl::getSubTypeByFieldID(curFType, curFID);
90  if (isa<FVectorType>(curFType))
91  convertedFieldID++; // Vector fieldID is 1.
92  else
93  convertedFieldID += curFID - subID; // Add consumed portion.
94  curFID = subID;
95  curFType = child;
96  }
97 
98  return convertedFieldID;
99 }
100 
101 //===----------------------------------------------------------------------===//
102 // Constraint Expressions
103 //===----------------------------------------------------------------------===//
104 
105 namespace {
106 struct Expr;
107 } // namespace
108 
109 /// Allow rvalue refs to `Expr` and subclasses to be printed to streams.
110 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
111  int>::type = 0>
112 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const T &e) {
113  e.print(os);
114  return os;
115 }
116 
117 // Allow expression subclasses to be hashed.
118 namespace mlir {
119 template <typename T, typename std::enable_if<std::is_base_of<Expr, T>::value,
120  int>::type = 0>
121 inline llvm::hash_code hash_value(const T &e) {
122  return e.hash_value();
123 }
124 } // namespace mlir
125 
126 namespace {
127 #define EXPR_NAMES(x) \
128  Root##x, Var##x, Derived##x, Id##x, Known##x, Add##x, Pow##x, Max##x, Min##x
129 #define EXPR_KINDS EXPR_NAMES()
130 #define EXPR_CLASSES EXPR_NAMES(Expr)
131 
132 /// An expression on the right-hand side of a constraint.
133 struct Expr {
134  enum class Kind { EXPR_KINDS };
135  std::optional<int32_t> solution;
136  Kind kind;
137 
138  /// Print a human-readable representation of this expr.
139  void print(llvm::raw_ostream &os) const;
140 
141 protected:
142  Expr(Kind kind) : kind(kind) {}
143  llvm::hash_code hash_value() const { return llvm::hash_value(kind); }
144 };
145 
146 /// Helper class to CRTP-derive common functions.
147 template <class DerivedT, Expr::Kind DerivedKind>
148 struct ExprBase : public Expr {
149  ExprBase() : Expr(DerivedKind) {}
150  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
151  bool operator==(const Expr &other) const {
152  if (auto otherSame = dyn_cast<DerivedT>(other))
153  return *static_cast<DerivedT *>(this) == otherSame;
154  return false;
155  }
156 };
157 
158 /// The root containing all expressions.
159 struct RootExpr : public ExprBase<RootExpr, Expr::Kind::Root> {
160  RootExpr(std::vector<Expr *> &exprs) : exprs(exprs) {}
161  void print(llvm::raw_ostream &os) const { os << "root"; }
162  std::vector<Expr *> &exprs;
163 };
164 
165 /// A free variable.
166 struct VarExpr : public ExprBase<VarExpr, Expr::Kind::Var> {
167  void print(llvm::raw_ostream &os) const {
168  // Hash the `this` pointer into something somewhat human readable. Since
169  // this is just for debug dumping, we wrap around at 65536 variables.
170  os << "var" << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFFF);
171  }
172 
173  /// The constraint expression this variable is supposed to be greater than or
174  /// equal to. This is not part of the variable's hash and equality property.
175  Expr *constraint = nullptr;
176 
177  /// The upper bound this variable is supposed to be smaller than or equal to.
178  Expr *upperBound = nullptr;
179  std::optional<int32_t> upperBoundSolution;
180 };
181 
182 /// A derived width.
183 ///
184 /// These are generated for `InvalidValueOp`s which want to derived their width
185 /// from connect operations that they are on the right hand side of.
186 struct DerivedExpr : public ExprBase<DerivedExpr, Expr::Kind::Derived> {
187  void print(llvm::raw_ostream &os) const {
188  // Hash the `this` pointer into something somewhat human readable.
189  os << "derive"
190  << ((size_t)this / llvm::PowerOf2Ceil(sizeof(*this)) & 0xFFF);
191  }
192 
193  /// The expression this derived width is equivalent to.
194  Expr *assigned = nullptr;
195 };
196 
197 /// An identity expression.
198 ///
199 /// This expression evaluates to its inner expression. It is used in a very
200 /// specific case of constraints on variables, in order to be able to track
201 /// where the constraint was imposed. Constraints on variables are represented
202 /// as `var >= <expr>`. When the first constraint `a` is imposed, it is stored
203 /// as the constraint expression (`var >= a`). When the second constraint `b` is
204 /// imposed, a *new* max expression is allocated (`var >= max(a, b)`).
205 /// Expressions are annotated with a location when they are created, which in
206 /// this case are connect ops. Since imposing the first constraint does not
207 /// create any new expression, the location information of that connect would be
208 /// lost. With an identity expression, imposing the first constraint becomes
209 /// `var >= identity(a)`, which is a *new* expression and properly tracks the
210 /// location info.
211 struct IdExpr : public ExprBase<IdExpr, Expr::Kind::Id> {
212  IdExpr(Expr *arg) : arg(arg) { assert(arg); }
213  void print(llvm::raw_ostream &os) const { os << "*" << *arg; }
214  bool operator==(const IdExpr &other) const {
215  return kind == other.kind && arg == other.arg;
216  }
217  llvm::hash_code hash_value() const {
218  return llvm::hash_combine(Expr::hash_value(), arg);
219  }
220 
221  /// The inner expression.
222  Expr *const arg;
223 };
224 
225 /// A known constant value.
226 struct KnownExpr : public ExprBase<KnownExpr, Expr::Kind::Known> {
227  KnownExpr(int32_t value) : ExprBase() { solution = value; }
228  void print(llvm::raw_ostream &os) const { os << *solution; }
229  bool operator==(const KnownExpr &other) const {
230  return *solution == *other.solution;
231  }
232  llvm::hash_code hash_value() const {
233  return llvm::hash_combine(Expr::hash_value(), *solution);
234  }
235 };
236 
237 /// A unary expression. Contains the actual data. Concrete subclasses are merely
238 /// there for show and ease of use.
239 struct UnaryExpr : public Expr {
240  bool operator==(const UnaryExpr &other) const {
241  return kind == other.kind && arg == other.arg;
242  }
243  llvm::hash_code hash_value() const {
244  return llvm::hash_combine(Expr::hash_value(), arg);
245  }
246 
247  /// The child expression.
248  Expr *const arg;
249 
250 protected:
251  UnaryExpr(Kind kind, Expr *arg) : Expr(kind), arg(arg) { assert(arg); }
252 };
253 
254 /// Helper class to CRTP-derive common functions.
255 template <class DerivedT, Expr::Kind DerivedKind>
256 struct UnaryExprBase : public UnaryExpr {
257  template <typename... Args>
258  UnaryExprBase(Args &&...args)
259  : UnaryExpr(DerivedKind, std::forward<Args>(args)...) {}
260  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
261 };
262 
263 /// A power of two.
264 struct PowExpr : public UnaryExprBase<PowExpr, Expr::Kind::Pow> {
265  using UnaryExprBase::UnaryExprBase;
266  void print(llvm::raw_ostream &os) const { os << "2^" << arg; }
267 };
268 
269 /// A binary expression. Contains the actual data. Concrete subclasses are
270 /// merely there for show and ease of use.
271 struct BinaryExpr : public Expr {
272  bool operator==(const BinaryExpr &other) const {
273  return kind == other.kind && lhs() == other.lhs() && rhs() == other.rhs();
274  }
275  llvm::hash_code hash_value() const {
276  return llvm::hash_combine(Expr::hash_value(), *args);
277  }
278  Expr *lhs() const { return args[0]; }
279  Expr *rhs() const { return args[1]; }
280 
281  /// The child expressions.
282  Expr *const args[2];
283 
284 protected:
285  BinaryExpr(Kind kind, Expr *lhs, Expr *rhs) : Expr(kind), args{lhs, rhs} {
286  assert(lhs);
287  assert(rhs);
288  }
289 };
290 
291 /// Helper class to CRTP-derive common functions.
292 template <class DerivedT, Expr::Kind DerivedKind>
293 struct BinaryExprBase : public BinaryExpr {
294  template <typename... Args>
295  BinaryExprBase(Args &&...args)
296  : BinaryExpr(DerivedKind, std::forward<Args>(args)...) {}
297  static bool classof(const Expr *e) { return e->kind == DerivedKind; }
298 };
299 
300 /// An addition.
301 struct AddExpr : public BinaryExprBase<AddExpr, Expr::Kind::Add> {
302  using BinaryExprBase::BinaryExprBase;
303  void print(llvm::raw_ostream &os) const {
304  os << "(" << *lhs() << " + " << *rhs() << ")";
305  }
306 };
307 
308 /// The maximum of two expressions.
309 struct MaxExpr : public BinaryExprBase<MaxExpr, Expr::Kind::Max> {
310  using BinaryExprBase::BinaryExprBase;
311  void print(llvm::raw_ostream &os) const {
312  os << "max(" << *lhs() << ", " << *rhs() << ")";
313  }
314 };
315 
316 /// The minimum of two expressions.
317 struct MinExpr : public BinaryExprBase<MinExpr, Expr::Kind::Min> {
318  using BinaryExprBase::BinaryExprBase;
319  void print(llvm::raw_ostream &os) const {
320  os << "min(" << *lhs() << ", " << *rhs() << ")";
321  }
322 };
323 
324 void Expr::print(llvm::raw_ostream &os) const {
325  TypeSwitch<const Expr *>(this).Case<EXPR_CLASSES>(
326  [&](auto *e) { e->print(os); });
327 }
328 
329 } // namespace
330 
331 //===----------------------------------------------------------------------===//
332 // Fast bump allocator with optional interning
333 //===----------------------------------------------------------------------===//
334 
335 namespace {
336 
/// One entry of the `InternedAllocator`'s set: a thin wrapper around a
/// pointer to the interned object.
template <typename T>
struct InternedSlot {
  /// The object this slot refers to.
  T *ptr;

  InternedSlot(T *ptr) : ptr(ptr) {}
};
343 
344 /// A simple bump allocator that ensures only ever one copy per object exists.
345 /// The allocated objects must not have a destructor.
346 template <typename T, typename std::enable_if_t<
347  std::is_trivially_destructible<T>::value, int> = 0>
348 class InternedAllocator {
349  using Slot = InternedSlot<T>;
350  llvm::DenseSet<Slot> interned;
351  llvm::BumpPtrAllocator &allocator;
352 
353 public:
354  InternedAllocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
355 
356  /// Allocate a new object if it does not yet exist, or return a pointer to the
357  /// existing one. `R` is the type of the object to be allocated. `R` must be
358  /// derived from or be the type `T`.
359  template <typename R = T, typename... Args>
360  std::pair<R *, bool> alloc(Args &&...args) {
361  auto stack_value = R(std::forward<Args>(args)...);
362  auto stack_slot = Slot(&stack_value);
363  auto it = interned.find(stack_slot);
364  if (it != interned.end())
365  return std::make_pair(static_cast<R *>(it->ptr), false);
366  auto heap_value = new (allocator) R(std::move(stack_value));
367  interned.insert(Slot(heap_value));
368  return std::make_pair(heap_value, true);
369  }
370 };
371 
372 /// A simple bump allocator. The allocated objects must not have a destructor.
373 /// This allocator is mainly there for symmetry with the `InternedAllocator`.
374 template <typename T, typename std::enable_if_t<
375  std::is_trivially_destructible<T>::value, int> = 0>
376 class Allocator {
377  llvm::BumpPtrAllocator &allocator;
378 
379 public:
380  Allocator(llvm::BumpPtrAllocator &allocator) : allocator(allocator) {}
381 
382  /// Allocate a new object. `R` is the type of the object to be allocated. `R`
383  /// must be derived from or be the type `T`.
384  template <typename R = T, typename... Args>
385  R *alloc(Args &&...args) {
386  return new (allocator) R(std::forward<Args>(args)...);
387  }
388 };
389 
390 } // namespace
391 
392 //===----------------------------------------------------------------------===//
393 // Constraint Solver
394 //===----------------------------------------------------------------------===//
395 
396 namespace {
397 /// A canonicalized linear inequality that maps a constraint on var `x` to the
398 /// linear inequality `x >= max(a*x+b, c) + (failed ? ∞ : 0)`.
399 ///
400 /// The inequality separately tracks recursive (a, b) and non-recursive (c)
401 /// constraints on `x`. This allows it to properly identify the combination of
402 /// the two constraints `x >= x-1` and `x >= 4` to be satisfiable as
403 /// `x >= max(x-1, 4)`. If it only tracked inequality as `x >= a*x+b`, the
404 /// combination of these two constraints would be `x >= x+4` (due to max(-1,4) =
405 /// 4), which would be unsatisfiable.
406 ///
407 /// The `failed` flag acts as an additional `∞` term that renders the inequality
408 /// unsatisfiable. It is used as a tombstone value in case an operation renders
409 /// the equality unsatisfiable (e.g. `x >= 2**x` would be represented as the
410 /// inequality `x >= ∞`).
411 ///
412 /// Inequalities represented in this form can easily be checked for
413 /// unsatisfiability in the presence of recursion by inspecting the coefficients
414 /// a and b. The `sat` function performs this action.
415 struct LinIneq {
416  // x >= max(a*x+b, c) + (failed ? ∞ : 0)
417  int32_t rec_scale = 0; // a
418  int32_t rec_bias = 0; // b
419  int32_t nonrec_bias = 0; // c
420  bool failed = false;
421 
422  /// Create a new unsatisfiable inequality `x >= ∞`.
423  static LinIneq unsat() { return LinIneq(true); }
424 
425  /// Create a new inequality `x >= (failed ? ∞ : 0)`.
426  explicit LinIneq(bool failed = false) : failed(failed) {}
427 
428  /// Create a new inequality `x >= bias`.
429  explicit LinIneq(int32_t bias) : nonrec_bias(bias) {}
430 
431  /// Create a new inequality `x >= scale*x+bias`.
432  explicit LinIneq(int32_t scale, int32_t bias) {
433  if (scale != 0) {
434  rec_scale = scale;
435  rec_bias = bias;
436  } else {
437  nonrec_bias = bias;
438  }
439  }
440 
441  /// Create a new inequality `x >= max(rec_scale*x+rec_bias, nonrec_bias) +
442  /// (failed ? ∞ : 0)`.
443  explicit LinIneq(int32_t rec_scale, int32_t rec_bias, int32_t nonrec_bias,
444  bool failed = false)
445  : failed(failed) {
446  if (rec_scale != 0) {
447  this->rec_scale = rec_scale;
448  this->rec_bias = rec_bias;
449  this->nonrec_bias = nonrec_bias;
450  } else {
451  this->nonrec_bias = std::max(rec_bias, nonrec_bias);
452  }
453  }
454 
455  /// Combine two inequalities by taking the maxima of corresponding
456  /// coefficients.
457  ///
458  /// This essentially combines `x >= max(a1*x+b1, c1)` and `x >= max(a2*x+b2,
459  /// c2)` into a new `x >= max(max(a1,a2)*x+max(b1,b2), max(c1,c2))`. This is
460  /// a pessimistic upper bound, since e.g. `x >= 2x-10` and `x >= x-5` may both
461  /// hold, but the resulting `x >= 2x-5` may pessimistically not hold.
462  static LinIneq max(const LinIneq &lhs, const LinIneq &rhs) {
463  return LinIneq(std::max(lhs.rec_scale, rhs.rec_scale),
464  std::max(lhs.rec_bias, rhs.rec_bias),
465  std::max(lhs.nonrec_bias, rhs.nonrec_bias),
466  lhs.failed || rhs.failed);
467  }
468 
469  /// Combine two inequalities by summing up the two right hand sides.
470  ///
471  /// This is a tricky one, since the addition of the two max terms will lead to
472  /// a maximum over four possible terms (similar to a binomial expansion). In
473  /// order to shoehorn this back into a two-term maximum, we have to pick the
474  /// recursive term that will grow the fastest.
475  ///
476  /// As an example for this problem, consider the following addition:
477  ///
478  /// x >= max(a1*x+b1, c1) + max(a2*x+b2, c2)
479  ///
480  /// We would like to expand and rearrange this again into a maximum:
481  ///
482  /// x >= max(a1*x+b1 + max(a2*x+b2, c2), c1 + max(a2*x+b2, c2))
483  /// x >= max(max(a1*x+b1 + a2*x+b2, a1*x+b1 + c2),
484  /// max(c1 + a2*x+b2, c1 + c2))
485  /// x >= max((a1+a2)*x+(b1+b2), a1*x+(b1+c2), a2*x+(b2+c1), c1+c2)
486  ///
487  /// Since we are combining two two-term maxima, there are four possible ways
488  /// how the terms can combine, leading to the above four-term maximum. An easy
489  /// upper bound of the form we want would be the following:
490  ///
491  /// x >= max(max(a1+a2, a1, a2)*x + max(b1+b2, b1+c2, b2+c1), c1+c2)
492  ///
493  /// However, this is a very pessimistic upper-bound that will declare very
494  /// common patterns in the IR as unbreakable cycles, despite them being very
495  /// much breakable. For example:
496  ///
497  /// x >= max(x, 42) + max(0, -3) <-- breakable recursion
498  /// x >= max(max(1+0, 1, 0)*x + max(42+0, -3, 42), 42-2)
499  /// x >= max(x + 42, 39) <-- unbreakable recursion!
500  ///
501  /// A better approach is to take the expanded four-term maximum, retain the
502  /// non-recursive term (c1+c2), and estimate which one of the recursive terms
503  /// (first three) will become dominant as we choose greater values for x.
504  /// Since x never is inferred to be negative, the recursive term in the
505  /// maximum with the highest scaling factor for x will end up dominating as
506  /// x tends to ∞:
507  ///
508  /// x >= max({
509  /// (a1+a2)*x+(b1+b2) if a1+a2 >= max(a1+a2, a1, a2) and a1>0 and a2>0,
510  /// a1*x+(b1+c2) if a1 >= max(a1+a2, a1, a2) and a1>0,
511  /// a2*x+(b2+c1) if a2 >= max(a1+a2, a1, a2) and a2>0,
512  /// 0 otherwise
513  /// }, c1+c2)
514  ///
515  /// In case multiple cases apply, the highest bias of the recursive term is
516  /// picked. With this, the above problematic example triggers the second case
517  /// and becomes:
518  ///
519  /// x >= max(1*x+(0-3), 42-3) = max(x-3, 39)
520  ///
521  /// Of which the first case is chosen, as it has the lower bias value.
522  static LinIneq add(const LinIneq &lhs, const LinIneq &rhs) {
523  // Determine the maximum scaling factor among the three possible recursive
524  // terms.
525  auto enable1 = lhs.rec_scale > 0 && rhs.rec_scale > 0;
526  auto enable2 = lhs.rec_scale > 0;
527  auto enable3 = rhs.rec_scale > 0;
528  auto scale1 = lhs.rec_scale + rhs.rec_scale; // (a1+a2)
529  auto scale2 = lhs.rec_scale; // a1
530  auto scale3 = rhs.rec_scale; // a2
531  auto bias1 = lhs.rec_bias + rhs.rec_bias; // (b1+b2)
532  auto bias2 = lhs.rec_bias + rhs.nonrec_bias; // (b1+c2)
533  auto bias3 = rhs.rec_bias + lhs.nonrec_bias; // (b2+c1)
534  auto maxScale = std::max(scale1, std::max(scale2, scale3));
535 
536  // Among those terms that have a maximum scaling factor, determine the
537  // largest bias value.
538  std::optional<int32_t> maxBias;
539  if (enable1 && scale1 == maxScale)
540  maxBias = bias1;
541  if (enable2 && scale2 == maxScale && (!maxBias || bias2 > *maxBias))
542  maxBias = bias2;
543  if (enable3 && scale3 == maxScale && (!maxBias || bias3 > *maxBias))
544  maxBias = bias3;
545 
546  // Pick from the recursive terms the one with maximum scaling factor and
547  // minimum bias value.
548  auto nonrec_bias = lhs.nonrec_bias + rhs.nonrec_bias; // c1+c2
549  auto failed = lhs.failed || rhs.failed;
550  if (enable1 && scale1 == maxScale && bias1 == *maxBias)
551  return LinIneq(scale1, bias1, nonrec_bias, failed);
552  if (enable2 && scale2 == maxScale && bias2 == *maxBias)
553  return LinIneq(scale2, bias2, nonrec_bias, failed);
554  if (enable3 && scale3 == maxScale && bias3 == *maxBias)
555  return LinIneq(scale3, bias3, nonrec_bias, failed);
556  return LinIneq(0, 0, nonrec_bias, failed);
557  }
558 
559  /// Check if the inequality is satisfiable.
560  ///
561  /// The inequality becomes unsatisfiable if the RHS is ∞, or a>1, or a==1 and
562  /// b <= 0. Otherwise there exists as solution for `x` that satisfies the
563  /// inequality.
564  bool sat() const {
565  if (failed)
566  return false;
567  if (rec_scale > 1)
568  return false;
569  if (rec_scale == 1 && rec_bias > 0)
570  return false;
571  return true;
572  }
573 
574  /// Dump the inequality in human-readable form.
575  void print(llvm::raw_ostream &os) const {
576  bool any = false;
577  bool both = (rec_scale != 0 || rec_bias != 0) && nonrec_bias != 0;
578  os << "x >= ";
579  if (both)
580  os << "max(";
581  if (rec_scale != 0) {
582  any = true;
583  if (rec_scale != 1)
584  os << rec_scale << "*";
585  os << "x";
586  }
587  if (rec_bias != 0) {
588  if (any) {
589  if (rec_bias < 0)
590  os << " - " << -rec_bias;
591  else
592  os << " + " << rec_bias;
593  } else {
594  any = true;
595  os << rec_bias;
596  }
597  }
598  if (both)
599  os << ", ";
600  if (nonrec_bias != 0) {
601  any = true;
602  os << nonrec_bias;
603  }
604  if (both)
605  os << ")";
606  if (failed) {
607  if (any)
608  os << " + ";
609  os << "∞";
610  }
611  if (!any)
612  os << "0";
613  }
614 };
615 
616 /// A simple solver for width constraints.
617 class ConstraintSolver {
618 public:
619  ConstraintSolver() = default;
620 
621  VarExpr *var() {
622  auto v = vars.alloc();
623  exprs.push_back(v);
624  if (currentInfo)
625  info[v].insert(currentInfo);
626  if (currentLoc)
627  locs[v].insert(*currentLoc);
628  return v;
629  }
630  DerivedExpr *derived() {
631  auto *d = derivs.alloc();
632  exprs.push_back(d);
633  return d;
634  }
635  KnownExpr *known(int32_t value) { return alloc<KnownExpr>(knowns, value); }
636  IdExpr *id(Expr *arg) { return alloc<IdExpr>(ids, arg); }
637  PowExpr *pow(Expr *arg) { return alloc<PowExpr>(uns, arg); }
638  AddExpr *add(Expr *lhs, Expr *rhs) { return alloc<AddExpr>(bins, lhs, rhs); }
639  MaxExpr *max(Expr *lhs, Expr *rhs) { return alloc<MaxExpr>(bins, lhs, rhs); }
640  MinExpr *min(Expr *lhs, Expr *rhs) { return alloc<MinExpr>(bins, lhs, rhs); }
641 
642  /// Add a constraint `lhs >= rhs`. Multiple constraints on the same variable
643  /// are coalesced into a `max(a, b)` expr.
644  Expr *addGeqConstraint(VarExpr *lhs, Expr *rhs) {
645  if (lhs->constraint)
646  lhs->constraint = max(lhs->constraint, rhs);
647  else
648  lhs->constraint = id(rhs);
649  return lhs->constraint;
650  }
651 
652  /// Add a constraint `lhs <= rhs`. Multiple constraints on the same variable
653  /// are coalesced into a `min(a, b)` expr.
654  Expr *addLeqConstraint(VarExpr *lhs, Expr *rhs) {
655  if (lhs->upperBound)
656  lhs->upperBound = min(lhs->upperBound, rhs);
657  else
658  lhs->upperBound = id(rhs);
659  return lhs->upperBound;
660  }
661 
662  void dumpConstraints(llvm::raw_ostream &os);
663  LogicalResult solve();
664 
665  using ContextInfo = DenseMap<Expr *, llvm::SmallSetVector<FieldRef, 1>>;
666  const ContextInfo &getContextInfo() const { return info; }
667  void setCurrentContextInfo(FieldRef fieldRef) { currentInfo = fieldRef; }
668  void setCurrentLocation(std::optional<Location> loc) { currentLoc = loc; }
669 
670 private:
671  // Allocator for constraint expressions.
672  llvm::BumpPtrAllocator allocator;
673  Allocator<VarExpr> vars = {allocator};
674  Allocator<DerivedExpr> derivs = {allocator};
675  InternedAllocator<KnownExpr> knowns = {allocator};
676  InternedAllocator<IdExpr> ids = {allocator};
677  InternedAllocator<UnaryExpr> uns = {allocator};
678  InternedAllocator<BinaryExpr> bins = {allocator};
679 
680  /// A list of expressions in the order they were created.
681  std::vector<Expr *> exprs;
682  RootExpr root = {exprs};
683 
684  /// Add an allocated expression to the list above.
685  template <typename R, typename T, typename... Args>
686  R *alloc(InternedAllocator<T> &allocator, Args &&...args) {
687  auto it = allocator.template alloc<R>(std::forward<Args>(args)...);
688  if (it.second)
689  exprs.push_back(it.first);
690  if (currentInfo)
691  info[it.first].insert(currentInfo);
692  if (currentLoc)
693  locs[it.first].insert(*currentLoc);
694  return it.first;
695  }
696 
697  /// Contextual information for each expression, indicating which values in the
698  /// IR lead to this expression.
699  ContextInfo info;
700  FieldRef currentInfo = {};
701  DenseMap<Expr *, llvm::SmallSetVector<Location, 1>> locs;
702  std::optional<Location> currentLoc;
703 
704  // Forbid copyign or moving the solver, which would invalidate the refs to
705  // allocator held by the allocators.
706  ConstraintSolver(ConstraintSolver &&) = delete;
707  ConstraintSolver(const ConstraintSolver &) = delete;
708  ConstraintSolver &operator=(ConstraintSolver &&) = delete;
709  ConstraintSolver &operator=(const ConstraintSolver &) = delete;
710 
711  void emitUninferredWidthError(VarExpr *var);
712 
713  LinIneq checkCycles(VarExpr *var, Expr *expr,
714  SmallPtrSetImpl<Expr *> &seenVars,
715  InFlightDiagnostic *reportInto = nullptr,
716  unsigned indent = 1);
717 };
718 
719 } // namespace
720 
721 /// Print all constraints in the solver to an output stream.
722 void ConstraintSolver::dumpConstraints(llvm::raw_ostream &os) {
723  for (auto *e : exprs) {
724  if (auto *v = dyn_cast<VarExpr>(e)) {
725  if (v->constraint)
726  os << "- " << *v << " >= " << *v->constraint << "\n";
727  else
728  os << "- " << *v << " unconstrained\n";
729  }
730  }
731 }
732 
733 #ifndef NDEBUG
734 inline llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const LinIneq &l) {
735  l.print(os);
736  return os;
737 }
738 #endif
739 
740 /// Compute the canonicalized linear inequality expression starting at `expr`,
741 /// for the `var` as the left hand side `x` of the inequality. `seenVars` is
742 /// used as a recursion breaker. Occurrences of `var` itself within the
743 /// expression are mapped to the `a` coefficient in the inequality. Any other
744 /// variables are substituted and, in the presence of a recursion in a variable
745 /// other than `var`, treated as zero. `info` is a mapping from constraint
746 /// expressions to values and operations that produced the expression, and is
747 /// used during error reporting. If `reportInto` is present, the function will
748 /// additionally attach unsatisfiable inequalities as notes to the diagnostic as
749 /// it encounters them.
750 LinIneq ConstraintSolver::checkCycles(VarExpr *var, Expr *expr,
751  SmallPtrSetImpl<Expr *> &seenVars,
752  InFlightDiagnostic *reportInto,
753  unsigned indent) {
754  auto ineq =
755  TypeSwitch<Expr *, LinIneq>(expr)
756  .Case<KnownExpr>([&](auto *expr) { return LinIneq(*expr->solution); })
757  .Case<VarExpr>([&](auto *expr) {
758  if (expr == var)
759  return LinIneq(1, 0); // x >= 1*x + 0
760  if (!seenVars.insert(expr).second)
761  // Count recursions in other variables as 0. This is sane
762  // since the cycle is either breakable, in which case the
763  // recursion does not modify the resulting value of the
764  // variable, or it is not breakable and will be caught by
765  // this very function once it is called on that variable.
766  return LinIneq(0);
767  if (!expr->constraint)
768  // Count unconstrained variables as `x >= 0`.
769  return LinIneq(0);
770  auto l = checkCycles(var, expr->constraint, seenVars, reportInto,
771  indent + 1);
772  seenVars.erase(expr);
773  return l;
774  })
775  .Case<IdExpr>([&](auto *expr) {
776  return checkCycles(var, expr->arg, seenVars, reportInto,
777  indent + 1);
778  })
779  .Case<PowExpr>([&](auto *expr) {
780  // If we can evaluate `2**arg` to a sensible constant, do
781  // so. This is the case if a == 0 and c < 31 such that 2**c is
782  // representable.
783  auto arg =
784  checkCycles(var, expr->arg, seenVars, reportInto, indent + 1);
785  if (arg.rec_scale != 0 || arg.nonrec_bias < 0 ||
786  arg.nonrec_bias >= 31)
787  return LinIneq::unsat();
788  return LinIneq(1 << arg.nonrec_bias); // x >= 2**arg
789  })
790  .Case<AddExpr>([&](auto *expr) {
791  return LinIneq::add(
792  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
793  checkCycles(var, expr->rhs(), seenVars, reportInto,
794  indent + 1));
795  })
796  .Case<MaxExpr, MinExpr>([&](auto *expr) {
797  // Combine the inequalities of the LHS and RHS into a single overly
798  // pessimistic inequality. We treat `MinExpr` the same as `MaxExpr`,
799  // since `max(a,b)` is an upper bound to `min(a,b)`.
800  return LinIneq::max(
801  checkCycles(var, expr->lhs(), seenVars, reportInto, indent + 1),
802  checkCycles(var, expr->rhs(), seenVars, reportInto,
803  indent + 1));
804  })
805  .Default([](auto) { return LinIneq::unsat(); });
806 
807  // If we were passed an in-flight diagnostic and the current inequality is
808  // unsatisfiable, attach notes to the diagnostic indicating the values or
809  // operations that contributed to this part of the constraint expression.
810  if (reportInto && !ineq.sat()) {
811  auto report = [&](Location loc) {
812  auto &note = reportInto->attachNote(loc);
813  note << "constrained width W >= ";
814  if (ineq.rec_scale == -1)
815  note << "-";
816  if (ineq.rec_scale != 1)
817  note << ineq.rec_scale;
818  note << "W";
819  if (ineq.rec_bias < 0)
820  note << "-" << -ineq.rec_bias;
821  if (ineq.rec_bias > 0)
822  note << "+" << ineq.rec_bias;
823  note << " here:";
824  };
825  auto it = locs.find(expr);
826  if (it != locs.end())
827  for (auto loc : it->second)
828  report(loc);
829  }
830  if (!reportInto)
831  LLVM_DEBUG(llvm::dbgs().indent(indent * 2)
832  << "- Visited " << *expr << ": " << ineq << "\n");
833 
834  return ineq;
835 }
836 
837 using ExprSolution = std::pair<std::optional<int32_t>, bool>;
838 
839 static ExprSolution
840 computeUnary(ExprSolution arg, llvm::function_ref<int32_t(int32_t)> operation) {
841  if (arg.first)
842  arg.first = operation(*arg.first);
843  return arg;
844 }
845 
846 static ExprSolution
848  llvm::function_ref<int32_t(int32_t, int32_t)> operation) {
849  auto result = ExprSolution{std::nullopt, lhs.second || rhs.second};
850  if (lhs.first && rhs.first)
851  result.first = operation(*lhs.first, *rhs.first);
852  else if (lhs.first)
853  result.first = lhs.first;
854  else if (rhs.first)
855  result.first = rhs.first;
856  return result;
857 }
858 
859 /// Compute the value of a constraint `expr`. `seenVars` is used as a recursion
860 /// breaker. Recursive variables are treated as zero. Returns the computed value
861 /// and a boolean indicating whether a recursion was detected. This may be used
862 /// to memoize the result of expressions in case they were not involved in a
863 /// cycle (which may alter their value from the perspective of a variable).
864 static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
865  unsigned defaultWorklistSize) {
866 
867  struct Frame {
868  Expr *expr;
869  unsigned indent;
870  };
871 
872  // indent only used for debug logs.
873  unsigned indent = 1;
874  std::vector<Frame> worklist({{expr, indent}});
875  llvm::DenseMap<Expr *, ExprSolution> solvedExprs;
876  // Reserving the vector size, to avoid frequent reallocs. The worklist can be
877  // quite large.
878  worklist.reserve(defaultWorklistSize);
879 
880  while (!worklist.empty()) {
881  auto &frame = worklist.back();
882  auto setSolution = [&](ExprSolution solution) {
883  // Memoize the result.
884  if (solution.first && !solution.second)
885  frame.expr->solution = *solution.first;
886  solvedExprs[frame.expr] = solution;
887 
888  // Produce some useful debug prints.
889  LLVM_DEBUG({
890  if (!isa<KnownExpr>(frame.expr)) {
891  if (solution.first)
892  llvm::dbgs().indent(frame.indent * 2)
893  << "= Solved " << *frame.expr << " = " << *solution.first;
894  else
895  llvm::dbgs().indent(frame.indent * 2)
896  << "= Skipped " << *frame.expr;
897  llvm::dbgs() << " (" << (solution.second ? "cycle broken" : "unique")
898  << ")\n";
899  }
900  });
901 
902  worklist.pop_back();
903  };
904 
905  // See if we have a memoized result we can return.
906  if (frame.expr->solution) {
907  LLVM_DEBUG({
908  if (!isa<KnownExpr>(frame.expr))
909  llvm::dbgs().indent(indent * 2) << "- Cached " << *frame.expr << " = "
910  << *frame.expr->solution << "\n";
911  });
912  setSolution(ExprSolution{*frame.expr->solution, false});
913  continue;
914  }
915 
916  // Otherwise compute the value of the expression.
917  LLVM_DEBUG({
918  if (!isa<KnownExpr>(frame.expr))
919  llvm::dbgs().indent(frame.indent * 2)
920  << "- Solving " << *frame.expr << "\n";
921  });
922 
923  TypeSwitch<Expr *>(frame.expr)
924  .Case<KnownExpr>([&](auto *expr) {
925  setSolution(ExprSolution{*expr->solution, false});
926  })
927  .Case<VarExpr>([&](auto *expr) {
928  if (solvedExprs.contains(expr->constraint)) {
929  auto solution = solvedExprs[expr->constraint];
930  // If we've solved the upper bound already, store the solution.
931  // This will be explicitly solved for later if not computed as
932  // part of the solving that resolved this constraint.
933  // This should only happen if somehow the constraint is
934  // solved before visiting this expression, so that our upperBound
935  // was not added to the worklist such that it was handled first.
936  if (expr->upperBound && solvedExprs.contains(expr->upperBound))
937  expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
938  seenVars.erase(expr);
939  // Constrain variables >= 0.
940  if (solution.first && *solution.first < 0)
941  solution.first = 0;
942  return setSolution(solution);
943  }
944 
945  // Unconstrained variables produce no solution.
946  if (!expr->constraint)
947  return setSolution(ExprSolution{std::nullopt, false});
948  // Return no solution for recursions in the variables. This is sane
949  // and will cause the expression to be ignored when computing the
950  // parent, e.g. `a >= max(a, 1)` will become just `a >= 1`.
951  if (!seenVars.insert(expr).second)
952  return setSolution(ExprSolution{std::nullopt, true});
953 
954  worklist.push_back({expr->constraint, indent + 1});
955  if (expr->upperBound)
956  worklist.push_back({expr->upperBound, indent + 1});
957  })
958  .Case<IdExpr>([&](auto *expr) {
959  if (solvedExprs.contains(expr->arg))
960  return setSolution(solvedExprs[expr->arg]);
961  worklist.push_back({expr->arg, indent + 1});
962  })
963  .Case<PowExpr>([&](auto *expr) {
964  if (solvedExprs.contains(expr->arg))
965  return setSolution(computeUnary(
966  solvedExprs[expr->arg], [](int32_t arg) { return 1 << arg; }));
967 
968  worklist.push_back({expr->arg, indent + 1});
969  })
970  .Case<AddExpr>([&](auto *expr) {
971  if (solvedExprs.contains(expr->lhs()) &&
972  solvedExprs.contains(expr->rhs()))
973  return setSolution(computeBinary(
974  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
975  [](int32_t lhs, int32_t rhs) { return lhs + rhs; }));
976 
977  worklist.push_back({expr->lhs(), indent + 1});
978  worklist.push_back({expr->rhs(), indent + 1});
979  })
980  .Case<MaxExpr>([&](auto *expr) {
981  if (solvedExprs.contains(expr->lhs()) &&
982  solvedExprs.contains(expr->rhs()))
983  return setSolution(computeBinary(
984  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
985  [](int32_t lhs, int32_t rhs) { return std::max(lhs, rhs); }));
986 
987  worklist.push_back({expr->lhs(), indent + 1});
988  worklist.push_back({expr->rhs(), indent + 1});
989  })
990  .Case<MinExpr>([&](auto *expr) {
991  if (solvedExprs.contains(expr->lhs()) &&
992  solvedExprs.contains(expr->rhs()))
993  return setSolution(computeBinary(
994  solvedExprs[expr->lhs()], solvedExprs[expr->rhs()],
995  [](int32_t lhs, int32_t rhs) { return std::min(lhs, rhs); }));
996 
997  worklist.push_back({expr->lhs(), indent + 1});
998  worklist.push_back({expr->rhs(), indent + 1});
999  })
1000  .Default([&](auto) {
1001  setSolution(ExprSolution{std::nullopt, false});
1002  });
1003  }
1004 
1005  return solvedExprs[expr];
1006 }
1007 
1008 /// Solve the constraint problem. This is a very simple implementation that
1009 /// does not fully solve the problem if there are weird dependency cycles
1010 /// present.
1011 LogicalResult ConstraintSolver::solve() {
1012  LLVM_DEBUG({
1013  llvm::dbgs() << "\n";
1014  debugHeader("Constraints") << "\n\n";
1015  dumpConstraints(llvm::dbgs());
1016  });
1017 
1018  // Ensure that there are no adverse cycles around.
1019  LLVM_DEBUG({
1020  llvm::dbgs() << "\n";
1021  debugHeader("Checking for unbreakable loops") << "\n\n";
1022  });
1023  SmallPtrSet<Expr *, 16> seenVars;
1024  bool anyFailed = false;
1025 
1026  for (auto *expr : exprs) {
1027  // Only work on variables.
1028  auto *var = dyn_cast<VarExpr>(expr);
1029  if (!var || !var->constraint)
1030  continue;
1031  LLVM_DEBUG(llvm::dbgs()
1032  << "- Checking " << *var << " >= " << *var->constraint << "\n");
1033 
1034  // Canonicalize the variable's constraint expression into a form that allows
1035  // us to easily determine if any recursion leads to an unsatisfiable
1036  // constraint. The `seenVars` set acts as a recursion breaker.
1037  seenVars.insert(var);
1038  auto ineq = checkCycles(var, var->constraint, seenVars);
1039  seenVars.clear();
1040 
1041  // If the constraint is satisfiable, we're done.
1042  // TODO: It's possible that this result is already sufficient to arrive at a
1043  // solution for the constraint, and the second pass further down is not
1044  // necessary. This would require more proper handling of `MinExpr` in the
1045  // cycle checking code.
1046  if (ineq.sat()) {
1047  LLVM_DEBUG(llvm::dbgs()
1048  << " = Breakable since " << ineq << " satisfiable\n");
1049  continue;
1050  }
1051 
1052  // If we arrive here, the constraint is not satisfiable at all. To provide
1053  // some guidance to the user, we call the cycle checking code again, but
1054  // this time with an in-flight diagnostic to attach notes indicating
1055  // unsatisfiable paths in the cycle.
1056  LLVM_DEBUG(llvm::dbgs()
1057  << " = UNBREAKABLE since " << ineq << " unsatisfiable\n");
1058  anyFailed = true;
1059  for (auto fieldRef : info.find(var)->second) {
1060  // Depending on whether this value stems from an operation or not, create
1061  // an appropriate diagnostic identifying the value.
1062  auto op = fieldRef.getDefiningOp();
1063  auto diag = op ? op->emitOpError()
1064  : mlir::emitError(fieldRef.getValue().getLoc())
1065  << "value ";
1066  diag << "is constrained to be wider than itself";
1067 
1068  // Re-run the cycle checking, but this time reporting into the diagnostic.
1069  seenVars.insert(var);
1070  checkCycles(var, var->constraint, seenVars, &diag);
1071  seenVars.clear();
1072  }
1073  }
1074 
1075  // If there were cycles, return now to avoid complaining to the user about
1076  // dependent widths not being inferred.
1077  if (anyFailed)
1078  return failure();
1079 
1080  // Iterate over the constraint variables and solve each.
1081  LLVM_DEBUG({
1082  llvm::dbgs() << "\n";
1083  debugHeader("Solving constraints") << "\n\n";
1084  });
1085  unsigned defaultWorklistSize = exprs.size() / 2;
1086  for (auto *expr : exprs) {
1087  // Only work on variables.
1088  auto *var = dyn_cast<VarExpr>(expr);
1089  if (!var)
1090  continue;
1091 
1092  // Complain about unconstrained variables.
1093  if (!var->constraint) {
1094  LLVM_DEBUG(llvm::dbgs() << "- Unconstrained " << *var << "\n");
1095  emitUninferredWidthError(var);
1096  anyFailed = true;
1097  continue;
1098  }
1099 
1100  // Compute the value for the variable.
1101  LLVM_DEBUG(llvm::dbgs()
1102  << "- Solving " << *var << " >= " << *var->constraint << "\n");
1103  seenVars.insert(var);
1104  auto solution = solveExpr(var->constraint, seenVars, defaultWorklistSize);
1105  // Compute the upperBound if there is one and haven't already.
1106  if (var->upperBound && !var->upperBoundSolution)
1107  var->upperBoundSolution =
1108  solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
1109  seenVars.clear();
1110 
1111  // Constrain variables >= 0.
1112  if (solution.first && *solution.first < 0)
1113  solution.first = 0;
1114  var->solution = solution.first;
1115 
1116  // In case the width could not be inferred, complain to the user. This might
1117  // be the case if the width depends on an unconstrained variable.
1118  if (!solution.first) {
1119  LLVM_DEBUG(llvm::dbgs() << " - UNSOLVED " << *var << "\n");
1120  emitUninferredWidthError(var);
1121  anyFailed = true;
1122  continue;
1123  }
1124  LLVM_DEBUG(llvm::dbgs()
1125  << " = Solved " << *var << " = " << solution.first << " ("
1126  << (solution.second ? "cycle broken" : "unique") << ")\n");
1127 
1128  // Check if the solution we have found violates an upper bound.
1129  if (var->upperBoundSolution && var->upperBoundSolution < *solution.first) {
1130  LLVM_DEBUG(llvm::dbgs() << " ! Unsatisfiable " << *var
1131  << " <= " << var->upperBoundSolution << "\n");
1132  emitUninferredWidthError(var);
1133  anyFailed = true;
1134  }
1135  }
1136 
1137  // Copy over derived widths.
1138  for (auto *expr : exprs) {
1139  // Only work on derived values.
1140  auto *derived = dyn_cast<DerivedExpr>(expr);
1141  if (!derived)
1142  continue;
1143 
1144  auto *assigned = derived->assigned;
1145  if (!assigned || !assigned->solution) {
1146  LLVM_DEBUG(llvm::dbgs() << "- Unused " << *derived << " set to 0\n");
1147  derived->solution = 0;
1148  } else {
1149  LLVM_DEBUG(llvm::dbgs() << "- Deriving " << *derived << " = "
1150  << assigned->solution << "\n");
1151  derived->solution = *assigned->solution;
1152  }
1153  }
1154 
1155  return failure(anyFailed);
1156 }
1157 
1158 // Emits the diagnostic to inform the user about an uninferred width in the
1159 // design. Returns true if an error was reported, false otherwise.
1160 void ConstraintSolver::emitUninferredWidthError(VarExpr *var) {
1161  FieldRef fieldRef = info.find(var)->second.back();
1162  Value value = fieldRef.getValue();
1163 
1164  auto diag = mlir::emitError(value.getLoc(), "uninferred width:");
1165 
1166  // Try to hint the user at what kind of node this is.
1167  if (isa<BlockArgument>(value)) {
1168  diag << " port";
1169  } else if (auto op = value.getDefiningOp()) {
1170  TypeSwitch<Operation *>(op)
1171  .Case<WireOp>([&](auto) { diag << " wire"; })
1172  .Case<RegOp, RegResetOp>([&](auto) { diag << " reg"; })
1173  .Case<NodeOp>([&](auto) { diag << " node"; })
1174  .Default([&](auto) { diag << " value"; });
1175  } else {
1176  diag << " value";
1177  }
1178 
1179  // Actually print what the user can refer to.
1180  auto [fieldName, rootKnown] = getFieldName(fieldRef);
1181  if (!fieldName.empty()) {
1182  if (!rootKnown)
1183  diag << " field";
1184  diag << " \"" << fieldName << "\"";
1185  }
1186 
1187  if (!var->constraint) {
1188  diag << " is unconstrained";
1189  } else if (var->solution && var->upperBoundSolution &&
1190  var->solution > var->upperBoundSolution) {
1191  diag << " cannot satisfy all width requirements";
1192  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1193  LLVM_DEBUG(llvm::dbgs() << *var->upperBound << "\n");
1194  auto loc = locs.find(var->constraint)->second.back();
1195  diag.attachNote(loc) << "width is constrained to be at least "
1196  << *var->solution << " here:";
1197  loc = locs.find(var->upperBound)->second.back();
1198  diag.attachNote(loc) << "width is constrained to be at most "
1199  << *var->upperBoundSolution << " here:";
1200  } else {
1201  diag << " width cannot be determined";
1202  LLVM_DEBUG(llvm::dbgs() << *var->constraint << "\n");
1203  auto loc = locs.find(var->constraint)->second.back();
1204  diag.attachNote(loc) << "width is constrained by an uninferred width here:";
1205  }
1206 }
1207 
1208 //===----------------------------------------------------------------------===//
1209 // Inference Constraint Problem Mapping
1210 //===----------------------------------------------------------------------===//
1211 
1212 namespace {
1213 
1214 /// A helper class which maps the types and operations in a design to a set of
1215 /// variables and constraints to be solved later.
1216 class InferenceMapping {
1217 public:
1218  InferenceMapping(ConstraintSolver &solver, SymbolTable &symtbl,
1219  hw::InnerSymbolTableCollection &istc)
1220  : solver(solver), symtbl(symtbl), irn{symtbl, istc} {}
1221 
1222  LogicalResult map(CircuitOp op);
1223  LogicalResult mapOperation(Operation *op);
1224 
1225  /// Declare all the variables in the value. If the value is a ground type,
1226  /// there is a single variable declared. If the value is an aggregate type,
1227  /// it sets up variables for each unknown width.
1228  void declareVars(Value value, Location loc, bool isDerived = false);
1229 
1230  /// Declare a variable associated with a specific field of an aggregate.
1231  Expr *declareVar(FieldRef fieldRef, Location loc);
1232 
1233  /// Declarate a variable for a type with an unknown width. The type must be a
1234  /// non-aggregate.
1235  Expr *declareVar(FIRRTLType type, Location loc);
1236 
1237  /// Assign the constraint expressions of the fields in the `result` argument
1238  /// as the max of expressions in the `rhs` and `lhs` arguments. Both fields
1239  /// must be the same type.
1240  void maximumOfTypes(Value result, Value rhs, Value lhs);
1241 
1242  /// Constrain the value "larger" to be greater than or equal to "smaller".
1243  /// These may be aggregate values. This is used for regular connects.
1244  void constrainTypes(Value larger, Value smaller, bool equal = false);
1245 
1246  /// Constrain the expression "larger" to be greater than or equals to
1247  /// the expression "smaller".
1248  void constrainTypes(Expr *larger, Expr *smaller,
1249  bool imposeUpperBounds = false, bool equal = false);
1250 
1251  /// Assign the constraint expressions of the fields in the `src` argument as
1252  /// the expressions for the `dst` argument. Both fields must be of the given
1253  /// `type`.
1254  void unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type);
1255 
1256  /// Get the expr associated with the value. The value must be a non-aggregate
1257  /// type.
1258  Expr *getExpr(Value value) const;
1259 
1260  /// Get the expr associated with a specific field in a value.
1261  Expr *getExpr(FieldRef fieldRef) const;
1262 
1263  /// Get the expr associated with a specific field in a value. If value is
1264  /// NULL, then this returns NULL.
1265  Expr *getExprOrNull(FieldRef fieldRef) const;
1266 
1267  /// Set the expr associated with the value. The value must be a non-aggregate
1268  /// type.
1269  void setExpr(Value value, Expr *expr);
1270 
1271  /// Set the expr associated with a specific field in a value.
1272  void setExpr(FieldRef fieldRef, Expr *expr);
1273 
1274  /// Return whether a module was skipped due to being fully inferred already.
1275  bool isModuleSkipped(FModuleOp module) const {
1276  return skippedModules.count(module);
1277  }
1278 
1279  /// Return whether all modules in the mapping were fully inferred.
1280  bool areAllModulesSkipped() const { return allModulesSkipped; }
1281 
1282 private:
1283  /// The constraint solver into which we emit variables and constraints.
1284  ConstraintSolver &solver;
1285 
1286  /// The constraint exprs for each result type of an operation.
1287  DenseMap<FieldRef, Expr *> opExprs;
1288 
1289  /// The fully inferred modules that were skipped entirely.
1290  SmallPtrSet<Operation *, 16> skippedModules;
1291  bool allModulesSkipped = true;
1292 
1293  /// Cache of module symbols
1294  SymbolTable &symtbl;
1295 
1296  /// Full design inner symbol information.
1297  hw::InnerRefNamespace irn;
1298 };
1299 
1300 } // namespace
1301 
1302 /// Check if a type contains any FIRRTL type with uninferred widths.
1303 static bool hasUninferredWidth(Type type) {
1304  return TypeSwitch<Type, bool>(type)
1305  .Case<FIRRTLBaseType>([](auto base) { return base.hasUninferredWidth(); })
1306  .Case<RefType>(
1307  [](auto ref) { return ref.getType().hasUninferredWidth(); })
1308  .Default([](auto) { return false; });
1309 }
1310 
1311 LogicalResult InferenceMapping::map(CircuitOp op) {
1312  LLVM_DEBUG(llvm::dbgs()
1313  << "\n===----- Mapping ops to constraint exprs -----===\n\n");
1314 
1315  // Ensure we have constraint variables established for all module ports.
1316  for (auto module : op.getOps<FModuleOp>())
1317  for (auto arg : module.getArguments()) {
1318  solver.setCurrentContextInfo(FieldRef(arg, 0));
1319  declareVars(arg, module.getLoc());
1320  }
1321 
1322  for (auto module : op.getOps<FModuleOp>()) {
1323  // Check if the module contains *any* uninferred widths. This allows us to
1324  // do an early skip if the module is already fully inferred.
1325  bool anyUninferred = false;
1326  for (auto arg : module.getArguments()) {
1327  anyUninferred |= hasUninferredWidth(arg.getType());
1328  if (anyUninferred)
1329  break;
1330  }
1331  module.walk([&](Operation *op) {
1332  for (auto type : op->getResultTypes())
1333  anyUninferred |= hasUninferredWidth(type);
1334  if (anyUninferred)
1335  return WalkResult::interrupt();
1336  return WalkResult::advance();
1337  });
1338 
1339  if (!anyUninferred) {
1340  LLVM_DEBUG(llvm::dbgs() << "Skipping fully-inferred module '"
1341  << module.getName() << "'\n");
1342  skippedModules.insert(module);
1343  continue;
1344  }
1345 
1346  allModulesSkipped = false;
1347 
1348  // Go through operations in the module, creating type variables for results,
1349  // and generating constraints.
1350  auto result = module.getBodyBlock()->walk(
1351  [&](Operation *op) { return WalkResult(mapOperation(op)); });
1352  if (result.wasInterrupted())
1353  return failure();
1354  }
1355 
1356  return success();
1357 }
1358 
1359 LogicalResult InferenceMapping::mapOperation(Operation *op) {
1360  // In case the operation result has a type without uninferred widths, don't
1361  // even bother to populate the constraint problem and treat that as a known
1362  // size directly. This is done in `declareVars`, which will generate
1363  // `KnownExpr` nodes for all known widths -- which are the only ones in this
1364  // case.
1365  bool allWidthsKnown = true;
1366  for (auto result : op->getResults()) {
1367  if (isa<MuxPrimOp, Mux4CellIntrinsicOp, Mux2CellIntrinsicOp>(op))
1368  if (hasUninferredWidth(op->getOperand(0).getType()))
1369  allWidthsKnown = false;
1370  // Only consider FIRRTL types for width constraints. Ignore any foreign
1371  // types as they don't participate in the width inference process.
1372  auto resultTy = type_dyn_cast<FIRRTLType>(result.getType());
1373  if (!resultTy)
1374  continue;
1375  if (!hasUninferredWidth(resultTy))
1376  declareVars(result, op->getLoc());
1377  else
1378  allWidthsKnown = false;
1379  }
1380  if (allWidthsKnown && !isa<FConnectLike, AttachOp>(op))
1381  return success();
1382 
1383  /// Ignore property assignments, no widths to infer.
1384  if (isa<PropAssignOp>(op))
1385  return success();
1386 
1387  // Actually generate the necessary constraint expressions.
1388  bool mappingFailed = false;
1389  solver.setCurrentContextInfo(
1390  op->getNumResults() > 0 ? FieldRef(op->getResults()[0], 0) : FieldRef());
1391  solver.setCurrentLocation(op->getLoc());
1392  TypeSwitch<Operation *>(op)
1393  .Case<ConstantOp>([&](auto op) {
1394  // If the constant has a known width, use that. Otherwise pick the
1395  // smallest number of bits necessary to represent the constant.
1396  Expr *e;
1397  if (auto width = op.getType().base().getWidth())
1398  e = solver.known(*width);
1399  else {
1400  auto v = op.getValue();
1401  auto w = v.getBitWidth() - (v.isNegative() ? v.countLeadingOnes()
1402  : v.countLeadingZeros());
1403  if (v.isSigned())
1404  w += 1;
1405  e = solver.known(std::max(w, 1u));
1406  }
1407  setExpr(op.getResult(), e);
1408  })
1409  .Case<SpecialConstantOp>([&](auto op) {
1410  // Nothing required.
1411  })
1412  .Case<InvalidValueOp>([&](auto op) {
1413  // We must duplicate the invalid value for each use, since each use can
1414  // be inferred to a different width.
1415  if (!hasUninferredWidth(op.getType()))
1416  return;
1417  declareVars(op.getResult(), op.getLoc(), /*isDerived=*/true);
1418  if (op.use_empty())
1419  return;
1420 
1421  auto type = op.getType();
1422  ImplicitLocOpBuilder builder(op->getLoc(), op);
1423  for (auto &use :
1424  llvm::make_early_inc_range(llvm::drop_begin(op->getUses()))) {
1425  // - `make_early_inc_range` since `getUses()` is invalidated upon
1426  // `use.set(...)`.
1427  // - `drop_begin` such that the first use can keep the original op.
1428  auto clone = builder.create<InvalidValueOp>(type);
1429  declareVars(clone.getResult(), clone.getLoc(),
1430  /*isDerived=*/true);
1431  use.set(clone);
1432  }
1433  })
1434  .Case<WireOp, RegOp>(
1435  [&](auto op) { declareVars(op.getResult(), op.getLoc()); })
1436  .Case<RegResetOp>([&](auto op) {
1437  // The original Scala code also constrains the reset signal to be at
1438  // least 1 bit wide. We don't do this here since the MLIR FIRRTL
1439  // dialect enforces the reset signal to be an async reset or a
1440  // `uint<1>`.
1441  declareVars(op.getResult(), op.getLoc());
1442  // Contrain the register to be greater than or equal to the reset
1443  // signal.
1444  constrainTypes(op.getResult(), op.getResetValue());
1445  })
1446  .Case<NodeOp>([&](auto op) {
1447  // Nodes have the same type as their input.
1448  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 0),
1449  op.getResult().getType());
1450  })
1451 
1452  // Aggregate Values
1453  .Case<SubfieldOp>([&](auto op) {
1454  BundleType bundleType = op.getInput().getType();
1455  auto fieldID = bundleType.getFieldID(op.getFieldIndex());
1456  unifyTypes(FieldRef(op.getResult(), 0),
1457  FieldRef(op.getInput(), fieldID), op.getType());
1458  })
1459  .Case<SubindexOp, SubaccessOp>([&](auto op) {
1460  // All vec fields unify to the same thing. Always use the first element
1461  // of the vector, which has a field ID of 1.
1462  unifyTypes(FieldRef(op.getResult(), 0), FieldRef(op.getInput(), 1),
1463  op.getType());
1464  })
1465  .Case<SubtagOp>([&](auto op) {
1466  FEnumType enumType = op.getInput().getType();
1467  auto fieldID = enumType.getFieldID(op.getFieldIndex());
1468  unifyTypes(FieldRef(op.getResult(), 0),
1469  FieldRef(op.getInput(), fieldID), op.getType());
1470  })
1471 
1472  .Case<RefSubOp>([&](RefSubOp op) {
1473  uint64_t fieldID = TypeSwitch<FIRRTLBaseType, uint64_t>(
1474  op.getInput().getType().getType())
1475  .Case<FVectorType>([](auto _) { return 1; })
1476  .Case<BundleType>([&](auto type) {
1477  return type.getFieldID(op.getIndex());
1478  });
1479  unifyTypes(FieldRef(op.getResult(), 0),
1480  FieldRef(op.getInput(), fieldID), op.getType());
1481  })
1482 
1483  // Arithmetic and Logical Binary Primitives
1484  .Case<AddPrimOp, SubPrimOp>([&](auto op) {
1485  auto lhs = getExpr(op.getLhs());
1486  auto rhs = getExpr(op.getRhs());
1487  auto e = solver.add(solver.max(lhs, rhs), solver.known(1));
1488  setExpr(op.getResult(), e);
1489  })
1490  .Case<MulPrimOp>([&](auto op) {
1491  auto lhs = getExpr(op.getLhs());
1492  auto rhs = getExpr(op.getRhs());
1493  auto e = solver.add(lhs, rhs);
1494  setExpr(op.getResult(), e);
1495  })
1496  .Case<DivPrimOp>([&](auto op) {
1497  auto lhs = getExpr(op.getLhs());
1498  Expr *e;
1499  if (op.getType().base().isSigned()) {
1500  e = solver.add(lhs, solver.known(1));
1501  } else {
1502  e = lhs;
1503  }
1504  setExpr(op.getResult(), e);
1505  })
1506  .Case<RemPrimOp>([&](auto op) {
1507  auto lhs = getExpr(op.getLhs());
1508  auto rhs = getExpr(op.getRhs());
1509  auto e = solver.min(lhs, rhs);
1510  setExpr(op.getResult(), e);
1511  })
1512  .Case<AndPrimOp, OrPrimOp, XorPrimOp>([&](auto op) {
1513  auto lhs = getExpr(op.getLhs());
1514  auto rhs = getExpr(op.getRhs());
1515  auto e = solver.max(lhs, rhs);
1516  setExpr(op.getResult(), e);
1517  })
1518 
1519  // Misc Binary Primitives
1520  .Case<CatPrimOp>([&](auto op) {
1521  auto lhs = getExpr(op.getLhs());
1522  auto rhs = getExpr(op.getRhs());
1523  auto e = solver.add(lhs, rhs);
1524  setExpr(op.getResult(), e);
1525  })
1526  .Case<DShlPrimOp>([&](auto op) {
1527  auto lhs = getExpr(op.getLhs());
1528  auto rhs = getExpr(op.getRhs());
1529  auto e = solver.add(lhs, solver.add(solver.pow(rhs), solver.known(-1)));
1530  setExpr(op.getResult(), e);
1531  })
1532  .Case<DShlwPrimOp, DShrPrimOp>([&](auto op) {
1533  auto e = getExpr(op.getLhs());
1534  setExpr(op.getResult(), e);
1535  })
1536 
1537  // Unary operators
1538  .Case<NegPrimOp>([&](auto op) {
1539  auto input = getExpr(op.getInput());
1540  auto e = solver.add(input, solver.known(1));
1541  setExpr(op.getResult(), e);
1542  })
1543  .Case<CvtPrimOp>([&](auto op) {
1544  auto input = getExpr(op.getInput());
1545  auto e = op.getInput().getType().base().isSigned()
1546  ? input
1547  : solver.add(input, solver.known(1));
1548  setExpr(op.getResult(), e);
1549  })
1550 
1551  // Miscellaneous
1552  .Case<BitsPrimOp>([&](auto op) {
1553  setExpr(op.getResult(), solver.known(op.getHi() - op.getLo() + 1));
1554  })
1555  .Case<HeadPrimOp>([&](auto op) {
1556  setExpr(op.getResult(), solver.known(op.getAmount()));
1557  })
1558  .Case<TailPrimOp>([&](auto op) {
1559  auto input = getExpr(op.getInput());
1560  auto e = solver.add(input, solver.known(-op.getAmount()));
1561  setExpr(op.getResult(), e);
1562  })
1563  .Case<PadPrimOp>([&](auto op) {
1564  auto input = getExpr(op.getInput());
1565  auto e = solver.max(input, solver.known(op.getAmount()));
1566  setExpr(op.getResult(), e);
1567  })
1568  .Case<ShlPrimOp>([&](auto op) {
1569  auto input = getExpr(op.getInput());
1570  auto e = solver.add(input, solver.known(op.getAmount()));
1571  setExpr(op.getResult(), e);
1572  })
1573  .Case<ShrPrimOp>([&](auto op) {
1574  auto input = getExpr(op.getInput());
1575  // UInt saturates at 0 bits, SInt at 1 bit
1576  auto minWidth = op.getInput().getType().base().isUnsigned() ? 0 : 1;
1577  auto e = solver.max(solver.add(input, solver.known(-op.getAmount())),
1578  solver.known(minWidth));
1579  setExpr(op.getResult(), e);
1580  })
1581 
1582  // Handle operations whose output width matches the input width.
1583  .Case<NotPrimOp, AsSIntPrimOp, AsUIntPrimOp, ConstCastOp>(
1584  [&](auto op) { setExpr(op.getResult(), getExpr(op.getInput())); })
1585  .Case<mlir::UnrealizedConversionCastOp>(
1586  [&](auto op) { setExpr(op.getResult(0), getExpr(op.getOperand(0))); })
1587 
1588  // Handle operations with a single result type that always has a
1589  // well-known width.
1590  .Case<LEQPrimOp, LTPrimOp, GEQPrimOp, GTPrimOp, EQPrimOp, NEQPrimOp,
1591  AsClockPrimOp, AsAsyncResetPrimOp, AndRPrimOp, OrRPrimOp,
1592  XorRPrimOp>([&](auto op) {
1593  auto width = op.getType().getBitWidthOrSentinel();
1594  assert(width > 0 && "width should have been checked by verifier");
1595  setExpr(op.getResult(), solver.known(width));
1596  })
1597  .Case<MuxPrimOp, Mux2CellIntrinsicOp>([&](auto op) {
1598  auto *sel = getExpr(op.getSel());
1599  constrainTypes(solver.known(1), sel, /*imposeUpperBounds=*/true);
1600  maximumOfTypes(op.getResult(), op.getHigh(), op.getLow());
1601  })
1602  .Case<Mux4CellIntrinsicOp>([&](Mux4CellIntrinsicOp op) {
1603  auto *sel = getExpr(op.getSel());
1604  constrainTypes(solver.known(2), sel, /*imposeUpperBounds=*/true);
1605  maximumOfTypes(op.getResult(), op.getV3(), op.getV2());
1606  maximumOfTypes(op.getResult(), op.getResult(), op.getV1());
1607  maximumOfTypes(op.getResult(), op.getResult(), op.getV0());
1608  })
1609 
1610  .Case<ConnectOp, StrictConnectOp>(
1611  [&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
1612  .Case<RefDefineOp>([&](auto op) {
1613  // Dest >= Src, but also check Src <= Dest for correctness
1614  // (but don't solve to make this true, don't back-propagate)
1615  constrainTypes(op.getDest(), op.getSrc(), true);
1616  })
1617  // StrictConnect is an identify constraint
1618  .Case<StrictConnectOp>([&](auto op) {
1619  // This back-propagates width from destination to source,
1620  // causing source to sometimes be inferred wider than
1621  // it should be (https://github.com/llvm/circt/issues/5391).
1622  // This is needed to push the width into feeding widthCast?
1623  constrainTypes(op.getDest(), op.getSrc());
1624  constrainTypes(op.getSrc(), op.getDest());
1625  })
1626  .Case<AttachOp>([&](auto op) {
1627  // Attach connects multiple analog signals together. All signals must
1628  // have the same bit width. Signals without bit width inherit from the
1629  // other signals.
1630  if (op.getAttached().empty())
1631  return;
1632  auto prev = op.getAttached()[0];
1633  for (auto operand : op.getAttached().drop_front()) {
1634  auto e1 = getExpr(prev);
1635  auto e2 = getExpr(operand);
1636  constrainTypes(e1, e2, /*imposeUpperBounds=*/true);
1637  constrainTypes(e2, e1, /*imposeUpperBounds=*/true);
1638  prev = operand;
1639  }
1640  })
1641 
1642  // Handle the no-ops that don't interact with width inference.
1643  .Case<PrintFOp, SkipOp, StopOp, WhenOp, AssertOp, AssumeOp,
1644  UnclockedAssumeIntrinsicOp, CoverOp>([&](auto) {})
1645 
1646  // Handle instances of other modules.
1647  .Case<InstanceOp>([&](auto op) {
1648  auto refdModule = op.getReferencedOperation(symtbl);
1649  auto module = dyn_cast<FModuleOp>(&*refdModule);
1650  if (!module) {
1651  auto diag = mlir::emitError(op.getLoc());
1652  diag << "extern module `" << op.getModuleName()
1653  << "` has ports of uninferred width";
1654 
1655  auto fml = cast<FModuleLike>(&*refdModule);
1656  auto ports = fml.getPorts();
1657  for (auto &port : ports) {
1658  auto baseType = getBaseType(port.type);
1659  if (baseType && baseType.hasUninferredWidth()) {
1660  diag.attachNote(op.getLoc()) << "Port: " << port.name;
1661  if (!baseType.isGround())
1662  diagnoseUninferredType(diag, baseType, port.name.getValue());
1663  }
1664  }
1665 
1666  diag.attachNote(op.getLoc())
1667  << "Only non-extern FIRRTL modules may contain unspecified "
1668  "widths to be inferred automatically.";
1669  diag.attachNote(refdModule->getLoc())
1670  << "Module `" << op.getModuleName() << "` defined here:";
1671  mappingFailed = true;
1672  return;
1673  }
1674  // Simply look up the free variables created for the instantiated
1675  // module's ports, and use them for instance port wires. This way,
1676  // constraints imposed onto the ports of the instance will transparently
1677  // apply to the ports of the instantiated module.
1678  for (auto it : llvm::zip(op->getResults(), module.getArguments())) {
1679  unifyTypes(FieldRef(std::get<0>(it), 0), FieldRef(std::get<1>(it), 0),
1680  type_cast<FIRRTLType>(std::get<0>(it).getType()));
1681  }
1682  })
1683 
1684  // Handle memories.
1685  .Case<MemOp>([&](MemOp op) {
1686  // Create constraint variables for all ports.
1687  unsigned nonDebugPort = 0;
1688  for (const auto &result : llvm::enumerate(op.getResults())) {
1689  declareVars(result.value(), op.getLoc());
1690  if (!type_isa<RefType>(result.value().getType()))
1691  nonDebugPort = result.index();
1692  }
1693 
1694  // A helper function that returns the indeces of the "data", "rdata",
1695  // and "wdata" fields in the bundle corresponding to a memory port.
1696  auto dataFieldIndices = [](MemOp::PortKind kind) -> ArrayRef<unsigned> {
1697  static const unsigned indices[] = {3, 5};
1698  static const unsigned debug[] = {0};
1699  switch (kind) {
1700  case MemOp::PortKind::Read:
1701  case MemOp::PortKind::Write:
1702  return ArrayRef<unsigned>(indices, 1); // {3}
1703  case MemOp::PortKind::ReadWrite:
1704  return ArrayRef<unsigned>(indices); // {3, 5}
1705  case MemOp::PortKind::Debug:
1706  return ArrayRef<unsigned>(debug);
1707  }
1708  llvm_unreachable("Imposible PortKind");
1709  };
1710 
1711  // This creates independent variables for every data port. Yet, what we
1712  // actually want is for all data ports to share the same variable. To do
1713  // this, we find the first data port declared, and use that port's vars
1714  // for all the other ports.
1715  unsigned firstFieldIndex =
1716  dataFieldIndices(op.getPortKind(nonDebugPort))[0];
1717  FieldRef firstData(
1718  op.getResult(nonDebugPort),
1719  type_cast<BundleType>(op.getPortType(nonDebugPort).getPassiveType())
1720  .getFieldID(firstFieldIndex));
1721  LLVM_DEBUG(llvm::dbgs() << "Adjusting memory port variables:\n");
1722 
1723  // Reuse data port variables.
1724  auto dataType = op.getDataType();
1725  for (unsigned i = 0, e = op.getResults().size(); i < e; ++i) {
1726  auto result = op.getResult(i);
1727  if (type_isa<RefType>(result.getType())) {
1728  // Debug ports are firrtl.ref<vector<data-type, depth>>
1729  // Use FieldRef of 1, to indicate the first vector element must be
1730  // of the dataType.
1731  unifyTypes(firstData, FieldRef(result, 1), dataType);
1732  continue;
1733  }
1734 
1735  auto portType =
1736  type_cast<BundleType>(op.getPortType(i).getPassiveType());
1737  for (auto fieldIndex : dataFieldIndices(op.getPortKind(i)))
1738  unifyTypes(FieldRef(result, portType.getFieldID(fieldIndex)),
1739  firstData, dataType);
1740  }
1741  })
1742 
1743  .Case<RefSendOp>([&](auto op) {
1744  declareVars(op.getResult(), op.getLoc());
1745  constrainTypes(op.getResult(), op.getBase(), true);
1746  })
1747  .Case<RefResolveOp>([&](auto op) {
1748  declareVars(op.getResult(), op.getLoc());
1749  constrainTypes(op.getResult(), op.getRef(), true);
1750  })
1751  .Case<RefCastOp>([&](auto op) {
1752  declareVars(op.getResult(), op.getLoc());
1753  constrainTypes(op.getResult(), op.getInput(), true);
1754  })
1755  .Case<RWProbeOp>([&](auto op) {
1756  declareVars(op.getResult(), op.getLoc());
1757  auto ist = irn.lookup(op.getTarget());
1758  if (!ist) {
1759  op->emitError("target of rwprobe could not be resolved");
1760  mappingFailed = true;
1761  return;
1762  }
1763  auto ref = getRefForIST(ist);
1764  if (!ref) {
1765  op->emitError("target of rwprobe resolved to unsupported target");
1766  mappingFailed = true;
1767  return;
1768  }
1769  auto newFID = convertFieldIDToOurVersion(
1770  ref.getFieldID(), type_cast<FIRRTLType>(ref.getValue().getType()));
1771  unifyTypes(FieldRef(op.getResult(), 0),
1772  FieldRef(ref.getValue(), newFID), op.getType());
1773  })
1774  .Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
1775  for (Value result : op.getResults()) {
1776  auto ty = result.getType();
1777  if (type_isa<FIRRTLType>(ty))
1778  declareVars(result, op.getLoc());
1779  }
1780  })
1781  .Default([&](auto op) {
1782  op->emitOpError("not supported in width inference");
1783  mappingFailed = true;
1784  });
1785 
1786  // Forceable declarations should have the ref constrained to data result.
1787  if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
1788  unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
1789  fop.getDataType());
1790 
1791  return failure(mappingFailed);
1792 }
1793 
1794 /// Declare free variables for the type of a value, and associate the resulting
1795 /// set of variables with that value.
1796 void InferenceMapping::declareVars(Value value, Location loc, bool isDerived) {
1797  // Declare a variable for every unknown width in the type. If this is a Bundle
1798  // type or a FVector type, we will have to potentially create many variables.
1799  unsigned fieldID = 0;
1800  std::function<void(FIRRTLBaseType)> declare = [&](FIRRTLBaseType type) {
1801  auto width = type.getBitWidthOrSentinel();
1802  if (width >= 0) {
1803  // Known width integer create a known expression.
1804  setExpr(FieldRef(value, fieldID), solver.known(width));
1805  fieldID++;
1806  } else if (width == -1) {
1807  // Unknown width integers create a variable.
1808  FieldRef field(value, fieldID);
1809  solver.setCurrentContextInfo(field);
1810  if (isDerived)
1811  setExpr(field, solver.derived());
1812  else
1813  setExpr(field, solver.var());
1814  fieldID++;
1815  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1816  // Bundle types recursively declare all bundle elements.
1817  fieldID++;
1818  for (auto &element : bundleType) {
1819  declare(element.type);
1820  }
1821  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1822  fieldID++;
1823  auto save = fieldID;
1824  declare(vecType.getElementType());
1825  // Skip past the rest of the elements
1826  fieldID = save + vecType.getMaxFieldID();
1827  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1828  fieldID++;
1829  for (auto &element : enumType.getElements())
1830  declare(element.type);
1831  } else {
1832  llvm_unreachable("Unknown type inside a bundle!");
1833  }
1834  };
1835  if (auto type = getBaseType(value.getType()))
1836  declare(type);
1837 }
1838 
1839 /// Assign the constraint expressions of the fields in the `result` argument as
1840 /// the max of expressions in the `rhs` and `lhs` arguments. Both fields must be
1841 /// the same type.
1842 void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
1843  // Recurse to every leaf element and set larger >= smaller.
1844  auto fieldID = 0;
1845  std::function<void(FIRRTLBaseType)> maximize = [&](FIRRTLBaseType type) {
1846  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1847  fieldID++;
1848  for (auto &element : bundleType.getElements())
1849  maximize(element.type);
1850  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1851  fieldID++;
1852  auto save = fieldID;
1853  // Skip 0 length vectors.
1854  if (vecType.getNumElements() > 0)
1855  maximize(vecType.getElementType());
1856  fieldID = save + vecType.getMaxFieldID();
1857  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1858  fieldID++;
1859  for (auto &element : enumType.getElements())
1860  maximize(element.type);
1861  } else if (type.isGround()) {
1862  auto *e = solver.max(getExpr(FieldRef(rhs, fieldID)),
1863  getExpr(FieldRef(lhs, fieldID)));
1864  setExpr(FieldRef(result, fieldID), e);
1865  fieldID++;
1866  } else {
1867  llvm_unreachable("Unknown type inside a bundle!");
1868  }
1869  };
1870  if (auto type = getBaseType(result.getType()))
1871  maximize(type);
1872 }
1873 
1874 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1875 /// than or equal to the sizes in the `smaller` type. Types have to be
1876 /// compatible in the sense that they may only differ in the presence or absence
1877 /// of bit widths.
1878 ///
1879 /// This function is used to apply regular connects.
1880 /// Set `equal` for constraining larger <= smaller for correctness but not
1881 /// solving.
1882 void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
1883  // Recurse to every leaf element and set larger >= smaller. Ignore foreign
1884  // types as these do not participate in width inference.
1885 
1886  auto fieldID = 0;
1887  std::function<void(FIRRTLBaseType, Value, Value)> constrain =
1888  [&](FIRRTLBaseType type, Value larger, Value smaller) {
1889  if (auto bundleType = type_dyn_cast<BundleType>(type)) {
1890  fieldID++;
1891  for (auto &element : bundleType.getElements()) {
1892  if (element.isFlip)
1893  constrain(element.type, smaller, larger);
1894  else
1895  constrain(element.type, larger, smaller);
1896  }
1897  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
1898  fieldID++;
1899  auto save = fieldID;
1900  // Skip 0 length vectors.
1901  if (vecType.getNumElements() > 0) {
1902  constrain(vecType.getElementType(), larger, smaller);
1903  }
1904  fieldID = save + vecType.getMaxFieldID();
1905  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
1906  fieldID++;
1907  for (auto &element : enumType.getElements())
1908  constrain(element.type, larger, smaller);
1909  } else if (type.isGround()) {
1910  // Leaf element, look up their expressions, and create the constraint.
1911  constrainTypes(getExpr(FieldRef(larger, fieldID)),
1912  getExpr(FieldRef(smaller, fieldID)), false, equal);
1913  fieldID++;
1914  } else {
1915  llvm_unreachable("Unknown type inside a bundle!");
1916  }
1917  };
1918 
1919  if (auto type = getBaseType(larger.getType()))
1920  constrain(type, larger, smaller);
1921 }
1922 
1923 /// Establishes constraints to ensure the sizes in the `larger` type are greater
1924 /// than or equal to the sizes in the `smaller` type.
1925 void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
1926  bool imposeUpperBounds, bool equal) {
1927  assert(larger && "Larger expression should be specified");
1928  assert(smaller && "Smaller expression should be specified");
1929 
1930  // If one of the sides is `DerivedExpr`, simply assign the other side as the
1931  // derived width. This allows `InvalidValueOp`s to properly infer their width
1932  // from the connects they are used in, but also be inferred to something
1933  // useful on their own.
1934  if (auto *largerDerived = dyn_cast<DerivedExpr>(larger)) {
1935  largerDerived->assigned = smaller;
1936  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *largerDerived << " from "
1937  << *smaller << "\n");
1938  return;
1939  }
1940  if (auto *smallerDerived = dyn_cast<DerivedExpr>(smaller)) {
1941  smallerDerived->assigned = larger;
1942  LLVM_DEBUG(llvm::dbgs() << "Deriving " << *smallerDerived << " from "
1943  << *larger << "\n");
1944  return;
1945  }
1946 
1947  // If the larger expr is a free variable, create a `expr >= x` constraint for
1948  // it that we can try to satisfy with the smallest width.
1949  if (auto largerVar = dyn_cast<VarExpr>(larger)) {
1950  [[maybe_unused]] auto *c = solver.addGeqConstraint(largerVar, smaller);
1951  LLVM_DEBUG(llvm::dbgs()
1952  << "Constrained " << *largerVar << " >= " << *c << "\n");
1953  // If we're constraining larger == smaller, add the LEQ contraint as well.
1954  // Solve for GEQ but check that LEQ is true.
1955  // Used for strictconnect, some reference operations, and anywhere the
1956  // widths should be inferred strictly in one direction but are required to
1957  // also be equal for correctness.
1958  if (equal) {
1959  [[maybe_unused]] auto *leq = solver.addLeqConstraint(largerVar, smaller);
1960  LLVM_DEBUG(llvm::dbgs()
1961  << "Constrained " << *largerVar << " <= " << *leq << "\n");
1962  }
1963  return;
1964  }
1965 
1966  // If the smaller expr is a free variable but the larger one is not, create a
1967  // `expr <= k` upper bound that we can verify once all lower bounds have been
1968  // satisfied. Since we are always picking the smallest width to satisfy all
1969  // `>=` constraints, any `<=` constraints have no effect on the solution
1970  // besides indicating that a width is unsatisfiable.
1971  if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
1972  if (imposeUpperBounds || equal) {
1973  [[maybe_unused]] auto *c = solver.addLeqConstraint(smallerVar, larger);
1974  LLVM_DEBUG(llvm::dbgs()
1975  << "Constrained " << *smallerVar << " <= " << *c << "\n");
1976  }
1977  }
1978 }
1979 
1980 /// Assign the constraint expressions of the fields in the `src` argument as the
1981 /// expressions for the `dst` argument. Both fields must be of the given `type`.
1982 void InferenceMapping::unifyTypes(FieldRef lhs, FieldRef rhs, FIRRTLType type) {
1983  // Fast path for `unifyTypes(x, x, _)`.
1984  if (lhs == rhs)
1985  return;
1986 
1987  // Co-iterate the two field refs, recurring into every leaf element and set
1988  // them equal.
1989  auto fieldID = 0;
1990  std::function<void(FIRRTLBaseType)> unify = [&](FIRRTLBaseType type) {
1991  if (type.isGround()) {
1992  // Leaf element, unify the fields!
1993  FieldRef lhsFieldRef(lhs.getValue(), lhs.getFieldID() + fieldID);
1994  FieldRef rhsFieldRef(rhs.getValue(), rhs.getFieldID() + fieldID);
1995  LLVM_DEBUG(llvm::dbgs()
1996  << "Unify " << getFieldName(lhsFieldRef).first << " = "
1997  << getFieldName(rhsFieldRef).first << "\n");
1998  // Abandon variables becoming unconstrainable by the unification.
1999  if (auto *var = dyn_cast_or_null<VarExpr>(getExprOrNull(lhsFieldRef)))
2000  solver.addGeqConstraint(var, solver.known(0));
2001  setExpr(lhsFieldRef, getExpr(rhsFieldRef));
2002  fieldID++;
2003  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2004  fieldID++;
2005  for (auto &element : bundleType) {
2006  unify(element.type);
2007  }
2008  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2009  fieldID++;
2010  auto save = fieldID;
2011  // Skip 0 length vectors.
2012  if (vecType.getNumElements() > 0) {
2013  unify(vecType.getElementType());
2014  }
2015  fieldID = save + vecType.getMaxFieldID();
2016  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2017  fieldID++;
2018  for (auto &element : enumType.getElements())
2019  unify(element.type);
2020  } else {
2021  llvm_unreachable("Unknown type inside a bundle!");
2022  }
2023  };
2024  if (auto ftype = getBaseType(type))
2025  unify(ftype);
2026 }
2027 
2028 /// Get the constraint expression for a value.
2029 Expr *InferenceMapping::getExpr(Value value) const {
2030  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2031  // A field ID of 0 indicates the entire value.
2032  return getExpr(FieldRef(value, 0));
2033 }
2034 
2035 /// Get the constraint expression for a value.
2036 Expr *InferenceMapping::getExpr(FieldRef fieldRef) const {
2037  auto expr = getExprOrNull(fieldRef);
2038  assert(expr && "constraint expr should have been constructed for value");
2039  return expr;
2040 }
2041 
2042 Expr *InferenceMapping::getExprOrNull(FieldRef fieldRef) const {
2043  auto it = opExprs.find(fieldRef);
2044  return it != opExprs.end() ? it->second : nullptr;
2045 }
2046 
2047 /// Associate a constraint expression with a value.
2048 void InferenceMapping::setExpr(Value value, Expr *expr) {
2049  assert(type_cast<FIRRTLType>(getBaseType(value.getType())).isGround());
2050  // A field ID of 0 indicates the entire value.
2051  setExpr(FieldRef(value, 0), expr);
2052 }
2053 
2054 /// Associate a constraint expression with a value.
2055 void InferenceMapping::setExpr(FieldRef fieldRef, Expr *expr) {
2056  LLVM_DEBUG({
2057  llvm::dbgs() << "Expr " << *expr << " for " << fieldRef.getValue();
2058  if (fieldRef.getFieldID())
2059  llvm::dbgs() << " '" << getFieldName(fieldRef).first << "'";
2060  auto fieldName = getFieldName(fieldRef);
2061  if (fieldName.second)
2062  llvm::dbgs() << " (\"" << fieldName.first << "\")";
2063  llvm::dbgs() << "\n";
2064  });
2065  opExprs[fieldRef] = expr;
2066 }
2067 
2068 //===----------------------------------------------------------------------===//
2069 // Inference Result Application
2070 //===----------------------------------------------------------------------===//
2071 
2072 namespace {
2073 /// A helper class which maps the types and operations in a design to a set
2074 /// of variables and constraints to be solved later.
2075 class InferenceTypeUpdate {
2076 public:
2077  InferenceTypeUpdate(InferenceMapping &mapping) : mapping(mapping) {}
2078 
2079  LogicalResult update(CircuitOp op);
2080  FailureOr<bool> updateOperation(Operation *op);
2081  FailureOr<bool> updateValue(Value value);
2083 
2084 private:
2085  const InferenceMapping &mapping;
2086 };
2087 
2088 } // namespace
2089 
2090 /// Update the types throughout a circuit.
2091 LogicalResult InferenceTypeUpdate::update(CircuitOp op) {
2092  LLVM_DEBUG({
2093  llvm::dbgs() << "\n";
2094  debugHeader("Update types") << "\n\n";
2095  });
2096  return mlir::failableParallelForEach(
2097  op.getContext(), op.getOps<FModuleOp>(), [&](FModuleOp op) {
2098  // Skip this module if it had no widths to be
2099  // inferred at all.
2100  if (mapping.isModuleSkipped(op))
2101  return success();
2102  auto isFailed = op.walk<WalkOrder::PreOrder>([&](Operation *op) {
2103  if (failed(updateOperation(op)))
2104  return WalkResult::interrupt();
2105  return WalkResult::advance();
2106  }).wasInterrupted();
2107  return failure(isFailed);
2108  });
2109 }
2110 
2111 /// Update the result types of an operation.
2112 FailureOr<bool> InferenceTypeUpdate::updateOperation(Operation *op) {
2113  bool anyChanged = false;
2114 
2115  for (Value v : op->getResults()) {
2116  auto result = updateValue(v);
2117  if (failed(result))
2118  return result;
2119  anyChanged |= *result;
2120  }
2121 
2122  // If this is a connect operation, width inference might have inferred a RHS
2123  // that is wider than the LHS, in which case an additional BitsPrimOp is
2124  // necessary to truncate the value.
2125  if (auto con = dyn_cast<ConnectOp>(op)) {
2126  auto lhs = con.getDest();
2127  auto rhs = con.getSrc();
2128  auto lhsType = type_dyn_cast<FIRRTLBaseType>(lhs.getType());
2129  auto rhsType = type_dyn_cast<FIRRTLBaseType>(rhs.getType());
2130 
2131  // Nothing to do if not base types.
2132  if (!lhsType || !rhsType)
2133  return anyChanged;
2134 
2135  auto lhsWidth = lhsType.getBitWidthOrSentinel();
2136  auto rhsWidth = rhsType.getBitWidthOrSentinel();
2137  if (lhsWidth >= 0 && rhsWidth >= 0 && lhsWidth < rhsWidth) {
2138  OpBuilder builder(op);
2139  auto trunc = builder.createOrFold<TailPrimOp>(con.getLoc(), con.getSrc(),
2140  rhsWidth - lhsWidth);
2141  if (type_isa<SIntType>(rhsType))
2142  trunc =
2143  builder.createOrFold<AsSIntPrimOp>(con.getLoc(), lhsType, trunc);
2144 
2145  LLVM_DEBUG(llvm::dbgs()
2146  << "Truncating RHS to " << lhsType << " in " << con << "\n");
2147  con->replaceUsesOfWith(con.getSrc(), trunc);
2148  }
2149  return anyChanged;
2150  }
2151 
2152  // If this is a module, update its ports.
2153  if (auto module = dyn_cast<FModuleOp>(op)) {
2154  // Update the block argument types.
2155  bool argsChanged = false;
2156  SmallVector<Attribute> argTypes;
2157  argTypes.reserve(module.getNumPorts());
2158  for (auto arg : module.getArguments()) {
2159  auto result = updateValue(arg);
2160  if (failed(result))
2161  return result;
2162  argsChanged |= *result;
2163  argTypes.push_back(TypeAttr::get(arg.getType()));
2164  }
2165 
2166  // Update the module function type if needed.
2167  if (argsChanged) {
2168  module->setAttr(FModuleLike::getPortTypesAttrName(),
2169  ArrayAttr::get(module.getContext(), argTypes));
2170  anyChanged = true;
2171  }
2172  }
2173  return anyChanged;
2174 }
2175 
2176 /// Resize a `uint`, `sint`, or `analog` type to a specific width.
2177 static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth) {
2178  auto *context = type.getContext();
2180  .Case<UIntType>([&](auto type) {
2181  return UIntType::get(context, newWidth, type.isConst());
2182  })
2183  .Case<SIntType>([&](auto type) {
2184  return SIntType::get(context, newWidth, type.isConst());
2185  })
2186  .Case<AnalogType>([&](auto type) {
2187  return AnalogType::get(context, newWidth, type.isConst());
2188  })
2189  .Default([&](auto type) { return type; });
2190 }
2191 
2192 /// Update the type of a value.
2193 FailureOr<bool> InferenceTypeUpdate::updateValue(Value value) {
2194  // Check if the value has a type which we can update.
2195  auto type = type_dyn_cast<FIRRTLType>(value.getType());
2196  if (!type)
2197  return false;
2198 
2199  // Fast path for types that have fully inferred widths.
2200  if (!hasUninferredWidth(type))
2201  return false;
2202 
2203  // If this is an operation that does not generate any free variables that
2204  // are determined during width inference, simply update the value type based
2205  // on the operation arguments.
2206  if (auto op = dyn_cast_or_null<InferTypeOpInterface>(value.getDefiningOp())) {
2207  SmallVector<Type, 2> types;
2208  auto res =
2209  op.inferReturnTypes(op->getContext(), op->getLoc(), op->getOperands(),
2210  op->getAttrDictionary(), op->getPropertiesStorage(),
2211  op->getRegions(), types);
2212  if (failed(res))
2213  return failure();
2214 
2215  assert(types.size() == op->getNumResults());
2216  for (auto it : llvm::zip(op->getResults(), types)) {
2217  LLVM_DEBUG(llvm::dbgs() << "Inferring " << std::get<0>(it) << " as "
2218  << std::get<1>(it) << "\n");
2219  std::get<0>(it).setType(std::get<1>(it));
2220  }
2221  return true;
2222  }
2223 
2224  // Recreate the type, substituting the solved widths.
2225  auto context = type.getContext();
2226  unsigned fieldID = 0;
2227  std::function<FIRRTLBaseType(FIRRTLBaseType)> updateBase =
2228  [&](FIRRTLBaseType type) -> FIRRTLBaseType {
2229  auto width = type.getBitWidthOrSentinel();
2230  if (width >= 0) {
2231  // Known width integers return themselves.
2232  fieldID++;
2233  return type;
2234  } else if (width == -1) {
2235  // Unknown width integers return the solved type.
2236  auto newType = updateType(FieldRef(value, fieldID), type);
2237  fieldID++;
2238  return newType;
2239  } else if (auto bundleType = type_dyn_cast<BundleType>(type)) {
2240  // Bundle types recursively update all bundle elements.
2241  fieldID++;
2242  llvm::SmallVector<BundleType::BundleElement, 3> elements;
2243  for (auto &element : bundleType) {
2244  auto updatedBase = updateBase(element.type);
2245  if (!updatedBase)
2246  return {};
2247  elements.emplace_back(element.name, element.isFlip, updatedBase);
2248  }
2249  return BundleType::get(context, elements, bundleType.isConst());
2250  } else if (auto vecType = type_dyn_cast<FVectorType>(type)) {
2251  fieldID++;
2252  auto save = fieldID;
2253  // TODO: this should recurse into the element type of 0 length vectors and
2254  // set any unknown width to 0.
2255  if (vecType.getNumElements() > 0) {
2256  auto updatedBase = updateBase(vecType.getElementType());
2257  if (!updatedBase)
2258  return {};
2259  auto newType = FVectorType::get(updatedBase, vecType.getNumElements(),
2260  vecType.isConst());
2261  fieldID = save + vecType.getMaxFieldID();
2262  return newType;
2263  }
2264  // If this is a 0 length vector return the original type.
2265  return type;
2266  } else if (auto enumType = type_dyn_cast<FEnumType>(type)) {
2267  fieldID++;
2268  llvm::SmallVector<FEnumType::EnumElement> elements;
2269  for (auto &element : enumType.getElements()) {
2270  auto updatedBase = updateBase(element.type);
2271  if (!updatedBase)
2272  return {};
2273  elements.emplace_back(element.name, updatedBase);
2274  }
2275  return FEnumType::get(context, elements, enumType.isConst());
2276  }
2277  llvm_unreachable("Unknown type inside a bundle!");
2278  };
2279 
2280  // Update the type.
2281  auto newType = mapBaseTypeNullable(type, updateBase);
2282  if (!newType)
2283  return failure();
2284  LLVM_DEBUG(llvm::dbgs() << "Update " << value << " to " << newType << "\n");
2285  value.setType(newType);
2286 
2287  // If this is a ConstantOp, adjust the width of the underlying APInt.
2288  // Unsized constants have APInts which are *at least* wide enough to hold
2289  // the value, but may be larger. This can trip up the verifier.
2290  if (auto op = value.getDefiningOp<ConstantOp>()) {
2291  auto k = op.getValue();
2292  auto bitwidth = op.getType().getBitWidthOrSentinel();
2293  if (k.getBitWidth() > unsigned(bitwidth))
2294  k = k.trunc(bitwidth);
2295  op->setAttr("value", IntegerAttr::get(op.getContext(), k));
2296  }
2297 
2298  return newType != type;
2299 }
2300 
2301 /// Update a type.
2303  FIRRTLBaseType type) {
2304  assert(type.isGround() && "Can only pass in ground types.");
2305  auto value = fieldRef.getValue();
2306  // Get the inferred width.
2307  Expr *expr = mapping.getExprOrNull(fieldRef);
2308  if (!expr || !expr->solution) {
2309  // It should not be possible to arrive at an uninferred width at this point.
2310  // In case the constraints are not resolvable, checks before the calls to
2311  // `updateType` must have already caught the issues and aborted the pass
2312  // early. Might turn this into an assert later.
2313  mlir::emitError(value.getLoc(), "width should have been inferred");
2314  return {};
2315  }
2316  int32_t solution = *expr->solution;
2317  assert(solution >= 0); // The solver infers variables to be 0 or greater.
2318  return resizeType(type, solution);
2319 }
2320 
2321 //===----------------------------------------------------------------------===//
2322 // Hashing
2323 //===----------------------------------------------------------------------===//
2324 
2325 // Hash slots in the interned allocator as if they were the pointed-to value
2326 // itself.
2327 namespace llvm {
2328 template <typename T>
2329 struct DenseMapInfo<InternedSlot<T>> {
2330  using Slot = InternedSlot<T>;
2331  static Slot getEmptyKey() {
2332  auto pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
2333  return Slot(static_cast<T *>(pointer));
2334  }
2336  auto pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
2337  return Slot(static_cast<T *>(pointer));
2338  }
2339  static unsigned getHashValue(Slot val) { return mlir::hash_value(*val.ptr); }
2340  static bool isEqual(Slot LHS, Slot RHS) {
2341  auto empty = getEmptyKey().ptr;
2342  auto tombstone = getTombstoneKey().ptr;
2343  if (LHS.ptr == empty || RHS.ptr == empty || LHS.ptr == tombstone ||
2344  RHS.ptr == tombstone)
2345  return LHS.ptr == RHS.ptr;
2346  return *LHS.ptr == *RHS.ptr;
2347  }
2348 };
2349 } // namespace llvm
2350 
2351 //===----------------------------------------------------------------------===//
2352 // Pass Infrastructure
2353 //===----------------------------------------------------------------------===//
2354 
2355 namespace {
2356 class InferWidthsPass : public InferWidthsBase<InferWidthsPass> {
2357  void runOnOperation() override;
2358 };
2359 } // namespace
2360 
2361 void InferWidthsPass::runOnOperation() {
2362  // Collect variables and constraints
2363  ConstraintSolver solver;
2364  InferenceMapping mapping(solver, getAnalysis<SymbolTable>(),
2365  getAnalysis<hw::InnerSymbolTableCollection>());
2366  if (failed(mapping.map(getOperation()))) {
2367  signalPassFailure();
2368  return;
2369  }
2370  if (mapping.areAllModulesSkipped()) {
2371  markAllAnalysesPreserved();
2372  return; // fast path if no inferrable widths are around
2373  }
2374 
2375  // Solve the constraints.
2376  if (failed(solver.solve())) {
2377  signalPassFailure();
2378  return;
2379  }
2380 
2381  // Update the types with the inferred widths.
2382  if (failed(InferenceTypeUpdate(mapping).update(getOperation())))
2383  signalPassFailure();
2384 }
2385 
2386 std::unique_ptr<mlir::Pass> circt::firrtl::createInferWidthsPass() {
2387  return std::make_unique<InferWidthsPass>();
2388 }
assert(baseType &&"element must be base type")
int32_t width
Definition: FIRRTL.cpp:36
static FIRRTLBaseType updateType(FIRRTLBaseType oldType, unsigned fieldID, FIRRTLBaseType fieldType)
Update the type of a single field within a type.
bool operator==(const ResetDomain &a, const ResetDomain &b)
Definition: InferResets.cpp:71
static FieldRef getRefForIST(const hw::InnerSymTarget &ist)
Get FieldRef pointing to the specified inner symbol target, which must be valid.
Definition: InferWidths.cpp:66
std::pair< std::optional< int32_t >, bool > ExprSolution
static ExprSolution computeUnary(ExprSolution arg, llvm::function_ref< int32_t(int32_t)> operation)
static uint64_t convertFieldIDToOurVersion(uint64_t fieldID, FIRRTLType type)
Calculate the "InferWidths-fieldID" equivalent for the given fieldID + type.
Definition: InferWidths.cpp:82
static ExprSolution computeBinary(ExprSolution lhs, ExprSolution rhs, llvm::function_ref< int32_t(int32_t, int32_t)> operation)
#define EXPR_KINDS
#define EXPR_CLASSES
static FIRRTLBaseType resizeType(FIRRTLBaseType type, uint32_t newWidth)
Resize a uint, sint, or analog type to a specific width.
static void diagnoseUninferredType(InFlightDiagnostic &diag, Type t, Twine str)
Definition: InferWidths.cpp:46
static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl< Expr * > &seenVars, unsigned defaultWorklistSize)
Compute the value of a constraint expr.
static bool hasUninferredWidth(Type type)
Check if a type contains any FIRRTL type with uninferred widths.
static InstancePath empty
Builder builder
This class represents a reference to a specific field or element of an aggregate value.
Definition: FieldRef.h:28
unsigned getFieldID() const
Get the field ID of this FieldRef, which is a unique identifier mapped to a specific field in a bundl...
Definition: FieldRef.h:59
Value getValue() const
Get the Value which created this location.
Definition: FieldRef.h:37
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:518
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:528
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
FIRRTLType mapBaseTypeNullable(FIRRTLType type, function_ref< FIRRTLBaseType(FIRRTLBaseType)> fn)
Return a FIRRTLType with its base type component mutated by the given function.
Definition: FIRRTLUtils.h:246
FIRRTLBaseType getBaseType(Type type)
If it is a base type, return it as is.
Definition: FIRRTLUtils.h:217
T & operator<<(T &os, FIRVersion version)
Definition: FIRParser.h:116
std::pair< std::string, bool > getFieldName(const FieldRef &fieldRef, bool nameSafe=false)
Get a string identifier representing the FieldRef.
std::unique_ptr< mlir::Pass > createInferWidthsPass()
llvm::hash_code hash_value(const ClassElement &element)
Definition: FIRRTLTypes.h:366
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugHeader(llvm::StringRef str, int width=80)
Write a "header"-like string to the debug stream with a certain width.
Definition: Debug.cpp:18
Definition: debug.py:1
llvm::hash_code hash_value(const T &e)
static unsigned getHashValue(Slot val)
static bool isEqual(Slot LHS, Slot RHS)