CIRCT  18.0.0git
FIRRTLTypes.cpp
Go to the documentation of this file.
1 //===- FIRRTLTypes.cpp - Implement the FIRRTL dialect type system ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the FIRRTL dialect type system.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/DialectImplementation.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace circt;
23 using namespace firrtl;
24 
25 using mlir::OptionalParseResult;
26 using mlir::TypeStorageAllocator;
27 
28 //===----------------------------------------------------------------------===//
29 // TableGen generated logic.
30 //===----------------------------------------------------------------------===//
31 
32 // Provide the autogenerated implementation for types.
33 #define GET_TYPEDEF_CLASSES
34 #include "circt/Dialect/FIRRTL/FIRRTLTypes.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // Type Printing
38 //===----------------------------------------------------------------------===//
39 
40 // NOLINTBEGIN(misc-no-recursion)
41 /// Print a type with a custom printer implementation.
42 ///
43 /// This only prints a subset of all types in the dialect. Use `printNestedType`
44 /// instead, which will call this function in turn, as appropriate.
45 static LogicalResult customTypePrinter(Type type, AsmPrinter &os) {
46  if (isConst(type)) {
47  os << "const.";
48  }
49 
50  auto printWidthQualifier = [&](std::optional<int32_t> width) {
51  if (width)
52  os << '<' << *width << '>';
53  };
54  bool anyFailed = false;
55  TypeSwitch<Type>(type)
56  .Case<ClockType>([&](auto) { os << "clock"; })
57  .Case<ResetType>([&](auto) { os << "reset"; })
58  .Case<AsyncResetType>([&](auto) { os << "asyncreset"; })
59  .Case<SIntType>([&](auto sIntType) {
60  os << "sint";
61  printWidthQualifier(sIntType.getWidth());
62  })
63  .Case<UIntType>([&](auto uIntType) {
64  os << "uint";
65  printWidthQualifier(uIntType.getWidth());
66  })
67  .Case<AnalogType>([&](auto analogType) {
68  os << "analog";
69  printWidthQualifier(analogType.getWidth());
70  })
71  .Case<BundleType, OpenBundleType>([&](auto bundleType) {
72  if (firrtl::type_isa<OpenBundleType>(bundleType))
73  os << "open";
74  os << "bundle<";
75  llvm::interleaveComma(bundleType, os, [&](auto element) {
76  StringRef fieldName = element.name.getValue();
77  bool isLiteralIdentifier =
78  !fieldName.empty() && llvm::isDigit(fieldName.front());
79  if (isLiteralIdentifier)
80  os << "\"";
81  os << element.name.getValue();
82  if (isLiteralIdentifier)
83  os << "\"";
84  if (element.isFlip)
85  os << " flip";
86  os << ": ";
87  printNestedType(element.type, os);
88  });
89  os << '>';
90  })
91  .Case<FEnumType>([&](auto fenumType) {
92  os << "enum<";
93  llvm::interleaveComma(fenumType, os,
94  [&](FEnumType::EnumElement element) {
95  os << element.name.getValue();
96  os << ": ";
97  printNestedType(element.type, os);
98  });
99  os << '>';
100  })
101  .Case<FVectorType, OpenVectorType>([&](auto vectorType) {
102  if (firrtl::type_isa<OpenVectorType>(vectorType))
103  os << "open";
104  os << "vector<";
105  printNestedType(vectorType.getElementType(), os);
106  os << ", " << vectorType.getNumElements() << '>';
107  })
108  .Case<RefType>([&](auto refType) {
109  if (refType.getForceable())
110  os << "rw";
111  os << "probe<";
112  printNestedType(refType.getType(), os);
113  os << '>';
114  })
115  .Case<StringType>([&](auto stringType) { os << "string"; })
116  .Case<FIntegerType>([&](auto integerType) { os << "integer"; })
117  .Case<BoolType>([&](auto boolType) { os << "bool"; })
118  .Case<DoubleType>([&](auto doubleType) { os << "double"; })
119  .Case<ListType>([&](auto listType) {
120  os << "list<";
121  printNestedType(listType.getElementType(), os);
122  os << '>';
123  })
124  .Case<PathType>([&](auto pathType) { os << "path"; })
125  .Case<BaseTypeAliasType>([&](BaseTypeAliasType alias) {
126  os << "alias<" << alias.getName().getValue() << ", ";
127  printNestedType(alias.getInnerType(), os);
128  os << '>';
129  })
130  .Case<ClassType>([&](ClassType type) {
131  os << "class<";
132  type.printInterface(os);
133  os << ">";
134  })
135  .Case<AnyRefType>([&](AnyRefType type) { os << "anyref"; })
136  .Default([&](auto) { anyFailed = true; });
137  return failure(anyFailed);
138 }
139 // NOLINTEND(misc-no-recursion)
140 
141 /// Print a type defined by this dialect.
142 void circt::firrtl::printNestedType(Type type, AsmPrinter &os) {
143  // Try the custom type printer.
144  if (succeeded(customTypePrinter(type, os)))
145  return;
146 
147  // None of the above recognized the type, so we bail.
148  assert(false && "type to print unknown to FIRRTL dialect");
149 }
150 
151 //===----------------------------------------------------------------------===//
152 // Type Parsing
153 //===----------------------------------------------------------------------===//
154 
155 /// Parse a type with a custom parser implementation.
156 ///
157 /// This only accepts a subset of all types in the dialect. Use `parseType`
158 /// instead, which will call this function in turn, as appropriate.
159 ///
160 /// Returns `std::nullopt` if the type `name` is not covered by the custom
161 /// parsers. Otherwise returns success or failure as appropriate. On success,
162 /// `result` is set to the resulting type.
163 ///
164 /// ```plain
165 /// firrtl-type
166 /// ::= clock
167 /// ::= reset
168 /// ::= asyncreset
169 /// ::= sint ('<' int '>')?
170 /// ::= uint ('<' int '>')?
171 /// ::= analog ('<' int '>')?
172 /// ::= bundle '<' (bundle-elt (',' bundle-elt)*)? '>'
173 /// ::= enum '<' (enum-elt (',' enum-elt)*)? '>'
174 /// ::= vector '<' type ',' int '>'
/// ::= openbundle '<' (bundle-elt (',' bundle-elt)*)? '>'
/// ::= openvector '<' type ',' int '>'
/// ::= (ref | probe | rwprobe) '<' type '>'
/// ::= class '<' class-interface '>'
/// ::= anyref | string | integer | bool | double | path
/// ::= list '<' property-type '>'
/// ::= alias '<' identifier ',' type '>'
175 /// ::= const '.' type
176 /// ::= 'property.' firrtl-phased-type
177 /// bundle-elt ::= identifier flip? ':' type
178 /// enum-elt ::= identifier ':' type
179 /// ```
///
/// Bundle field and enum variant names are read with `parseKeywordOrString`,
/// so they may be written either as bare keywords or as quoted strings.
///
/// NOTE(review): no branch below handles a `property.` prefix, so the
/// `'property.' firrtl-phased-type` production above looks stale. Confirm.
180 static OptionalParseResult customTypeParser(AsmParser &parser, StringRef name,
181  Type &result) {
  // A leading `const.` marks the type as const. Strip it so the remaining
  // keyword can be matched below, and remember it in `isConst`.
182  bool isConst = false;
183  const char constPrefix[] = "const.";
184  if (name.starts_with(constPrefix)) {
185  isConst = true;
186  name = name.drop_front(std::size(constPrefix) - 1);
187  }
188 
189  auto *context = parser.getContext();
190  if (name.equals("clock"))
191  return result = ClockType::get(context, isConst), success();
192  if (name.equals("reset"))
193  return result = ResetType::get(context, isConst), success();
194  if (name.equals("asyncreset"))
195  return result = AsyncResetType::get(context, isConst), success();
196 
  // Integer and analog types. An omitted width is encoded as -1.
197  if (name.equals("sint") || name.equals("uint") || name.equals("analog")) {
198  // Parse the width specifier if it exists.
199  int32_t width = -1;
200  if (!parser.parseOptionalLess()) {
201  if (parser.parseInteger(width) || parser.parseGreater())
202  return failure();
203 
  // An explicit width must be non-negative; -1 is reserved for "no width".
204  if (width < 0)
205  return parser.emitError(parser.getNameLoc(), "unknown width"),
206  failure();
207  }
208 
209  if (name.equals("sint"))
210  result = SIntType::get(context, width, isConst);
211  else if (name.equals("uint"))
212  result = UIntType::get(context, width, isConst);
213  else {
214  assert(name.equals("analog"));
215  result = AnalogType::get(context, width, isConst);
216  }
217  return success();
218  }
219 
  // Closed bundles: every field must be a base type.
220  if (name.equals("bundle")) {
221  SmallVector<BundleType::BundleElement, 4> elements;
222 
223  auto parseBundleElement = [&]() -> ParseResult {
224  std::string nameStr;
225  StringRef name;
226  FIRRTLBaseType type;
227 
228  if (failed(parser.parseKeywordOrString(&nameStr)))
229  return failure();
230  name = nameStr;
231 
232  bool isFlip = succeeded(parser.parseOptionalKeyword("flip"));
233  if (parser.parseColon() || parseNestedBaseType(type, parser))
234  return failure();
235 
236  elements.push_back({StringAttr::get(context, name), isFlip, type});
237  return success();
238  };
239 
240  if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
241  parseBundleElement))
242  return failure();
243 
244  return result = BundleType::get(context, elements, isConst), success();
245  }
  // Open bundles: fields may be any `FIRRTLType`. `getChecked` yields a null
  // type when verification fails, which is turned into failure below.
246  if (name.equals("openbundle")) {
247  SmallVector<OpenBundleType::BundleElement, 4> elements;
248 
249  auto parseBundleElement = [&]() -> ParseResult {
250  std::string nameStr;
251  StringRef name;
252  FIRRTLType type;
253 
254  if (failed(parser.parseKeywordOrString(&nameStr)))
255  return failure();
256  name = nameStr;
257 
258  bool isFlip = succeeded(parser.parseOptionalKeyword("flip"));
259  if (parser.parseColon() || parseNestedType(type, parser))
260  return failure();
261 
262  elements.push_back({StringAttr::get(context, name), isFlip, type});
263  return success();
264  };
265 
266  if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
267  parseBundleElement))
268  return failure();
269 
270  result = parser.getChecked<OpenBundleType>(context, elements, isConst);
271  return failure(!result);
272  }
273 
  // Enums: variants hold base types. The element list is verified explicitly
  // before construction, with errors reported at the type-name location.
274  if (name.equals("enum")) {
275  SmallVector<FEnumType::EnumElement, 4> elements;
276 
277  auto parseEnumElement = [&]() -> ParseResult {
278  std::string nameStr;
279  StringRef name;
280  FIRRTLBaseType type;
281 
282  if (failed(parser.parseKeywordOrString(&nameStr)))
283  return failure();
284  name = nameStr;
285 
286  if (parser.parseColon() || parseNestedBaseType(type, parser))
287  return failure();
288 
289  elements.push_back({StringAttr::get(context, name), type});
290  return success();
291  };
292 
293  if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
294  parseEnumElement))
295  return failure();
296  if (failed(FEnumType::verify(
297  [&]() { return parser.emitError(parser.getNameLoc()); }, elements,
298  isConst)))
299  return failure();
300 
301  return result = FEnumType::get(context, elements, isConst), success();
302  }
303 
  // Closed vectors require a base element type. `openvector` accepts any
  // `FIRRTLType` and is verified through `getChecked`.
304  if (name.equals("vector")) {
305  FIRRTLBaseType elementType;
306  uint64_t width = 0;
307 
308  if (parser.parseLess() || parseNestedBaseType(elementType, parser) ||
309  parser.parseComma() || parser.parseInteger(width) ||
310  parser.parseGreater())
311  return failure();
312 
313  return result = FVectorType::get(elementType, width, isConst), success();
314  }
315  if (name.equals("openvector")) {
316  FIRRTLType elementType;
317  uint64_t width = 0;
318 
319  if (parser.parseLess() || parseNestedType(elementType, parser) ||
320  parser.parseComma() || parser.parseInteger(width) ||
321  parser.parseGreater())
322  return failure();
323 
324  result =
325  parser.getChecked<OpenVectorType>(context, elementType, width, isConst);
326  return failure(!result);
327  }
328 
329  // For now, support both firrtl.ref and firrtl.probe.
  // NOTE(review): a `const.` prefix on ref/probe/rwprobe is accepted, but
  // `isConst` is never used in these branches, so it is silently dropped.
  // Confirm this is intended.
330  if (name.equals("ref") || name.equals("probe")) {
331  FIRRTLBaseType type;
332  // Don't pass `isConst` to `parseNestedBaseType` since `ref` can point to
333  // either `const` or non-`const` types.
334  if (parser.parseLess() || parseNestedBaseType(type, parser) ||
335  parser.parseGreater())
336  return failure();
337 
338  if (failed(RefType::verify(
339  [&]() { return parser.emitError(parser.getNameLoc()); }, type,
340  false)))
341  return failure();
342 
343  return result = RefType::get(type, false), success();
344  }
  // `rwprobe` is the forceable variant: the same as `probe`, but with the
  // forceable flag set to true.
345  if (name.equals("rwprobe")) {
346  FIRRTLBaseType type;
347  if (parser.parseLess() || parseNestedBaseType(type, parser) ||
348  parser.parseGreater())
349  return failure();
350 
351  if (failed(RefType::verify(
352  [&]() { return parser.emitError(parser.getNameLoc()); }, type,
353  true)))
354  return failure();
355 
356  return result = RefType::get(type, true), success();
357  }
  // The remaining non-hardware types all reject the `const.` prefix.
358  if (name.equals("class")) {
359  if (isConst)
360  return parser.emitError(parser.getNameLoc(), "classes cannot be const");
361  ClassType classType;
362  if (parser.parseLess() || ClassType::parseInterface(parser, classType) ||
363  parser.parseGreater())
364  return failure();
365  result = classType;
366  return success();
367  }
368  if (name.equals("anyref")) {
369  if (isConst)
370  return parser.emitError(parser.getNameLoc(), "any refs cannot be const");
371 
372  result = AnyRefType::get(parser.getContext());
373  return success();
374  }
375  if (name.equals("string")) {
376  if (isConst) {
377  parser.emitError(parser.getNameLoc(), "strings cannot be const");
378  return failure();
379  }
380  result = StringType::get(parser.getContext());
381  return success();
382  }
383  if (name.equals("integer")) {
384  if (isConst) {
385  parser.emitError(parser.getNameLoc(), "bigints cannot be const");
386  return failure();
387  }
388  result = FIntegerType::get(parser.getContext());
389  return success();
390  }
391  if (name.equals("bool")) {
392  if (isConst) {
393  parser.emitError(parser.getNameLoc(), "bools cannot be const");
394  return failure();
395  }
396  result = BoolType::get(parser.getContext());
397  return success();
398  }
399  if (name.equals("double")) {
400  if (isConst) {
401  parser.emitError(parser.getNameLoc(), "doubles cannot be const");
402  return failure();
403  }
404  result = DoubleType::get(parser.getContext());
405  return success();
406  }
  // Lists hold a property element type and are verified via `getChecked`.
407  if (name.equals("list")) {
408  if (isConst) {
409  parser.emitError(parser.getNameLoc(), "lists cannot be const");
410  return failure();
411  }
  // NOTE(review): the declaration of `elementType` used below is missing from
  // this copy of the file. It is presumably a `PropertyType`, to match
  // `parseNestedPropertyType`. Confirm.
413  if (parser.parseLess() || parseNestedPropertyType(elementType, parser) ||
414  parser.parseGreater())
415  return failure();
416  result = parser.getChecked<ListType>(context, elementType);
417  if (!result)
418  return failure();
419  return success();
420  }
421  if (name.equals("path")) {
422  if (isConst) {
423  parser.emitError(parser.getNameLoc(), "path cannot be const");
424  return failure();
425  }
426  result = PathType::get(parser.getContext());
427  return success();
428  }
  // Type alias: `alias<Name, inner-base-type>`.
  // NOTE(review): the local `name` below shadows the function parameter, and
  // `isConst` is not used for aliases (a `const.` prefix is silently ignored).
  // Confirm both are intended.
429  if (name.equals("alias")) {
430  FIRRTLBaseType type;
431  StringRef name;
432  if (parser.parseLess() || parser.parseKeyword(&name) ||
433  parser.parseComma() || parseNestedBaseType(type, parser) ||
434  parser.parseGreater())
435  return failure();
436 
437  return result =
438  BaseTypeAliasType::get(StringAttr::get(context, name), type),
439  success();
440  }
441 
  // `name` is not a keyword handled here; let the caller report the error.
442  return {};
443 }
444 
445 /// Parse a type defined by this dialect.
446 ///
447 /// This will first try the generated type parsers and then resort to the custom
448 /// parser implementation. Emits an error and returns failure if `name` does not
449 /// refer to a type defined in this dialect.
450 static ParseResult parseType(Type &result, StringRef name, AsmParser &parser) {
451  // Try the custom type parser.
452  OptionalParseResult parseResult = customTypeParser(parser, name, result);
453  if (parseResult.has_value())
454  return parseResult.value();
455 
456  // None of the above recognized the type, so we bail.
457  parser.emitError(parser.getNameLoc(), "unknown FIRRTL dialect type: \"")
458  << name << "\"";
459  return failure();
460 }
461 
462 /// Parse a `FIRRTLType` with a `name` that has already been parsed.
463 ///
464 /// Note that only a subset of types defined in the FIRRTL dialect inherit from
465 /// `FIRRTLType`. Use `parseType` to parse *any* of the defined types.
466 static ParseResult parseFIRRTLType(FIRRTLType &result, StringRef name,
467  AsmParser &parser) {
468  Type type;
469  if (failed(parseType(type, name, parser)))
470  return failure();
471  result = type_dyn_cast<FIRRTLType>(type);
472  if (result)
473  return success();
474  parser.emitError(parser.getNameLoc(), "unknown FIRRTL type: \"")
475  << name << "\"";
476  return failure();
477 }
478 
479 static ParseResult parseFIRRTLBaseType(FIRRTLBaseType &result, StringRef name,
480  AsmParser &parser) {
481  FIRRTLType type;
482  if (failed(parseFIRRTLType(type, name, parser)))
483  return failure();
484  if (auto base = type_dyn_cast<FIRRTLBaseType>(type)) {
485  result = base;
486  return success();
487  }
488  parser.emitError(parser.getNameLoc(), "expected base type, found ") << type;
489  return failure();
490 }
491 
492 static ParseResult parseFIRRTLPropertyType(PropertyType &result, StringRef name,
493  AsmParser &parser) {
494  FIRRTLType type;
495  if (failed(parseFIRRTLType(type, name, parser)))
496  return failure();
497  if (auto prop = type_dyn_cast<PropertyType>(type)) {
498  result = prop;
499  return success();
500  }
501  parser.emitError(parser.getNameLoc(), "expected property type, found ")
502  << type;
503  return failure();
504 }
505 
506 // NOLINTBEGIN(misc-no-recursion)
507 /// Parse a `FIRRTLType`.
508 ///
509 /// Note that only a subset of types defined in the FIRRTL dialect inherit from
510 /// `FIRRTLType`. Use `parseType` to parse *any* of the defined types.
///
/// Reads the leading type keyword, then hands it to `parseFIRRTLType`, which
/// emits an error if the keyword does not name a `FIRRTLType`. These nested
/// parsers are reached recursively from `customTypeParser` (for element and
/// field types), hence the NOLINT(misc-no-recursion) markers.
512  AsmParser &parser) {
513  StringRef name;
514  if (parser.parseKeyword(&name))
515  return failure();
516  return parseFIRRTLType(result, name, parser);
517 }
518 // NOLINTEND(misc-no-recursion)
519 
520 // NOLINTBEGIN(misc-no-recursion)
/// Parse a `FIRRTLBaseType`: read the leading type keyword, then defer to
/// `parseFIRRTLBaseType`, which emits an error for non-base types.
522  AsmParser &parser) {
523  StringRef name;
524  if (parser.parseKeyword(&name))
525  return failure();
526  return parseFIRRTLBaseType(result, name, parser);
527 }
528 // NOLINTEND(misc-no-recursion)
529 
530 // NOLINTBEGIN(misc-no-recursion)
/// Parse a `PropertyType`: read the leading type keyword, then defer to
/// `parseFIRRTLPropertyType`, which emits an error for non-property types.
532  AsmParser &parser) {
533  StringRef name;
534  if (parser.parseKeyword(&name))
535  return failure();
536  return parseFIRRTLPropertyType(result, name, parser);
537 }
538 // NOLINTEND(misc-no-recursion)
539 
540 //===---------------------------------------------------------------------===//
541 // Dialect Type Parsing and Printing
542 //===----------------------------------------------------------------------===//
543 
544 /// Print a type registered to this dialect.
545 void FIRRTLDialect::printType(Type type, DialectAsmPrinter &os) const {
546  printNestedType(type, os);
547 }
548 
549 /// Parse a type registered to this dialect.
550 Type FIRRTLDialect::parseType(DialectAsmParser &parser) const {
551  StringRef name;
552  Type result;
553  if (parser.parseKeyword(&name) || ::parseType(result, name, parser))
554  return Type();
555  return result;
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // Recursive Type Properties
560 //===----------------------------------------------------------------------===//
561 
/// Bit flags for type properties that are computed recursively over a type.
562 enum {
563  /// Bit set if the type only contains passive elements.
565  /// Bit set if the type contains an analog type.
567  /// Bit set if the type has any uninferred bit widths.
569 };
570 
571 //===----------------------------------------------------------------------===//
572 // FIRRTLBaseType Implementation
573 //===----------------------------------------------------------------------===//
574 
  // This storage holds only the type's `const` flag, which also serves as the
  // uniquing key (see `operator==` and `getAsKey`).
576  // Use `char` instead of `bool` since llvm already provides a
577  // DenseMapInfo<char> specialization
578  using KeyTy = char;
579 
  // In the member initializer, `isConst(...)` names the member and the
  // argument expression refers to the constructor parameter.
580  FIRRTLBaseTypeStorage(bool isConst) : isConst(static_cast<char>(isConst)) {}
581 
582  bool operator==(const KeyTy &key) const { return key == isConst; }
583 
584  KeyTy getAsKey() const { return isConst; }
585 
  // Construct a new storage instance with placement new, into memory obtained
  // from the type uniquer's allocator.
586  static FIRRTLBaseTypeStorage *construct(TypeStorageAllocator &allocator,
587  KeyTy key) {
588  return new (allocator.allocate<FIRRTLBaseTypeStorage>())
590  }
591 
  // Stored as `char` to match `KeyTy`; nonzero means the type is const.
592  char isConst;
593 };
594 
595 /// Return true if this is a 'ground' type, aka a non-aggregate type.
596 bool FIRRTLType::isGround() {
597  return TypeSwitch<FIRRTLType, bool>(*this)
598  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
599  AnalogType>([](Type) { return true; })
600  .Case<BundleType, FVectorType, FEnumType, OpenBundleType, OpenVectorType>(
601  [](Type) { return false; })
602  .Case<BaseTypeAliasType>([](BaseTypeAliasType alias) {
603  return alias.getAnonymousType().isGround();
604  })
605  // Not ground per spec, but leaf of aggregate.
606  .Case<PropertyType, RefType>([](Type) { return false; })
607  .Default([](Type) {
608  llvm_unreachable("unknown FIRRTL type");
609  return false;
610  });
611 }
612 
613 bool FIRRTLType::isConst() {
614  return TypeSwitch<FIRRTLType, bool>(*this)
615  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
616  [](auto type) { return type.isConst(); })
617  .Default(false);
618 }
619 
620 bool FIRRTLBaseType::isConst() { return getImpl()->isConst; }
621 
  // Ground types build their properties directly from their own attributes
  // (constness, whether a width is present). Aggregates, references, and
  // aliases delegate to their own getRecursiveTypeProperties().
623  return TypeSwitch<FIRRTLType, RecursiveTypeProperties>(*this)
  // Clock/reset types: the last field is set only for the abstract ResetType.
624  .Case<ClockType, ResetType, AsyncResetType>([](FIRRTLBaseType type) {
625  return RecursiveTypeProperties{true,
626  false,
627  false,
628  type.isConst(),
629  false,
630  false,
631  firrtl::type_isa<ResetType>(type)};
632  })
  // Integer types: the sixth field records a missing (uninferred) width.
633  .Case<SIntType, UIntType>([](auto type) {
635  true, false, false, type.isConst(), false, !type.hasWidth(), false};
636  })
  // Analog types: same as integers except that the third field is set
  // (presumably the "contains analog" flag; confirm against the struct).
637  .Case<AnalogType>([](auto type) {
639  true, false, true, type.isConst(), false, !type.hasWidth(), false};
640  })
641  .Case<BundleType, FVectorType, FEnumType, OpenBundleType, OpenVectorType,
642  RefType, BaseTypeAliasType>(
643  [](auto type) { return type.getRecursiveTypeProperties(); })
  // Property types report a fixed property set, independent of the type.
644  .Case<PropertyType>([](auto type) {
645  return RecursiveTypeProperties{true, false, false, false,
646  false, false, false};
647  })
648  .Default([](Type) {
649  llvm_unreachable("unknown FIRRTL type");
650  return RecursiveTypeProperties{};
651  });
652 }
653 
654 /// Return this type with any type aliases recursively removed from itself.
///
/// Ground types contain no aliases and are returned unchanged. Aggregates and
/// aliases compute their own anonymous form.
656  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
657  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
658  AnalogType>([&](Type) { return *this; })
659  .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
660  [](auto type) { return type.getAnonymousType(); })
661  .Default([](Type) {
662  llvm_unreachable("unknown FIRRTL type");
663  return FIRRTLBaseType();
664  });
665 }
666 
667 /// Return this type with any flip types recursively removed from itself.
///
/// Ground types and enums are returned unchanged. Bundles, vectors, and
/// aliases compute their own passive form.
669  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
670  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
671  AnalogType, FEnumType>([&](Type) { return *this; })
  // NOTE(review): FEnumType is also listed in the Case below, but TypeSwitch
  // dispatches to the first matching Case, so that entry is unreachable.
  // Consider removing it.
672  .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
673  [](auto type) { return type.getPassiveType(); })
674  .Default([](Type) {
675  llvm_unreachable("unknown FIRRTL type");
676  return FIRRTLBaseType();
677  });
678 }
679 
680 /// Return a 'const' or non-'const' version of this type.
///
/// Every concrete base type implements getConstType itself, so this simply
/// dispatches to the concrete type.
682  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
683  .Case<ClockType, ResetType, AsyncResetType, AnalogType, SIntType,
684  UIntType, BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
685  [&](auto type) { return type.getConstType(isConst); })
686  .Default([](Type) {
687  llvm_unreachable("unknown FIRRTL type");
688  return FIRRTLBaseType();
689  });
690 }
691 
692 /// Return this type with all 'const' modifiers dropped, including those on
/// nested element types. Ground types just drop their own const flag;
/// aggregates and aliases recurse through their own getAllConstDroppedType().
694  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
695  .Case<ClockType, ResetType, AsyncResetType, AnalogType, SIntType,
696  UIntType>([&](auto type) { return type.getConstType(false); })
697  .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
698  [&](auto type) { return type.getAllConstDroppedType(); })
699  .Default([](Type) {
700  llvm_unreachable("unknown FIRRTL type");
701  return FIRRTLBaseType();
702  });
703 }
704 
705 /// Return this type with all ground types replaced with UInt<1>. This is
706 /// used for `mem` operations.
///
/// The constness of each ground type and aggregate is preserved. Bundle fields
/// are always emitted unflipped (see the FIXME below).
/// NOTE(review): FEnumType has no Case here and would reach llvm_unreachable.
/// Confirm that enums can never reach this function.
708  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
709  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
710  AnalogType>([&](Type) {
711  return UIntType::get(this->getContext(), 1, this->isConst());
712  })
713  .Case<BundleType>([&](BundleType bundleType) {
714  SmallVector<BundleType::BundleElement, 4> newElements;
715  newElements.reserve(bundleType.getElements().size());
716  for (auto elt : bundleType)
717  newElements.push_back(
718  {elt.name, false /* FIXME */, elt.type.getMaskType()});
719  return BundleType::get(this->getContext(), newElements,
720  bundleType.isConst());
721  })
722  .Case<FVectorType>([](FVectorType vectorType) {
723  return FVectorType::get(vectorType.getElementType().getMaskType(),
724  vectorType.getNumElements(),
725  vectorType.isConst());
726  })
727  .Case<BaseTypeAliasType>([](BaseTypeAliasType base) {
728  return base.getModifiedType(base.getInnerType().getMaskType());
729  })
730  .Default([](Type) {
731  llvm_unreachable("unknown FIRRTL type");
732  return FIRRTLBaseType();
733  });
734 }
735 
736 /// Remove the widths from this type. All widths are replaced with an
737 /// unknown width.
///
/// The unknown width is encoded as -1. Clock and reset types have no width
/// and are returned as-is. Aggregates rebuild themselves from widthless
/// elements, preserving field names, flips, and constness.
739  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
740  .Case<ClockType, ResetType, AsyncResetType>([](auto a) { return a; })
741  .Case<UIntType, SIntType, AnalogType>(
742  [&](auto a) { return a.get(this->getContext(), -1, a.isConst()); })
743  .Case<BundleType>([&](auto a) {
744  SmallVector<BundleType::BundleElement, 4> newElements;
745  newElements.reserve(a.getElements().size());
746  for (auto elt : a)
747  newElements.push_back(
748  {elt.name, elt.isFlip, elt.type.getWidthlessType()});
749  return BundleType::get(this->getContext(), newElements, a.isConst());
750  })
751  .Case<FVectorType>([](auto a) {
752  return FVectorType::get(a.getElementType().getWidthlessType(),
753  a.getNumElements(), a.isConst());
754  })
755  .Case<FEnumType>([&](FEnumType a) {
756  SmallVector<FEnumType::EnumElement, 4> newElements;
757  newElements.reserve(a.getNumElements());
758  for (auto elt : a)
759  newElements.push_back({elt.name, elt.type.getWidthlessType()});
760  return FEnumType::get(this->getContext(), newElements, a.isConst());
761  })
762  .Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
763  return type.getModifiedType(type.getInnerType().getWidthlessType());
764  })
765  .Default([](auto) {
766  llvm_unreachable("unknown FIRRTL type");
767  return FIRRTLBaseType();
768  });
769 }
770 
771 /// If this is an IntType, AnalogType, or sugar type for a single bit (Clock,
772 /// Reset, etc) then return the bitwidth. Return -1 if this is one of these
773 /// types but without a specified bitwidth. Return -2 if this isn't a simple
774 /// type.
///
/// Bundles, vectors, and enums all report -2. Aliases answer through their
/// anonymous type.
776  return TypeSwitch<FIRRTLBaseType, int32_t>(*this)
777  .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
778  .Case<SIntType, UIntType>(
779  [&](IntType intType) { return intType.getWidthOrSentinel(); })
780  .Case<AnalogType>(
781  [](AnalogType analogType) { return analogType.getWidthOrSentinel(); })
782  .Case<BundleType, FVectorType, FEnumType>([](Type) { return -2; })
783  .Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
784  // It's faster to use its anonymous type.
785  return type.getAnonymousType().getBitWidthOrSentinel();
786  })
787  .Default([](Type) {
788  llvm_unreachable("unknown FIRRTL type");
789  return -2;
790  });
791 }
792 
793 /// Return true if this is a type usable as a reset. This must be
794 /// either an abstract reset, a concrete 1-bit UInt, an
795 /// asynchronous reset, or an uninferred width UInt.
///
/// Aliases are looked through to their inner type. Every other type (including
/// SInt and non-base types) returns false.
797  return TypeSwitch<FIRRTLType, bool>(*this)
798  .Case<ResetType, AsyncResetType>([](Type) { return true; })
799  .Case<UIntType>(
800  [](UIntType a) { return !a.hasWidth() || a.getWidth() == 1; })
801  .Case<BaseTypeAliasType>(
802  [](auto type) { return type.getInnerType().isResetType(); })
803  .Default([](Type) { return false; });
804 }
805 
806 bool firrtl::isConst(Type type) {
807  return TypeSwitch<Type, bool>(type)
808  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
809  [](auto base) { return base.isConst(); })
810  .Default(false);
811 }
812 
813 bool firrtl::containsConst(Type type) {
814  return TypeSwitch<Type, bool>(type)
815  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
816  [](auto base) { return base.containsConst(); })
817  .Default(false);
818 }
819 
820 /// Helper to implement the equivalence logic for a pair of bundle elements.
821 /// Note that the FIRRTL spec requires bundle elements to have the same
822 /// orientation, but this only compares their passive types. The FIRRTL dialect
823 /// differs from the spec in how it uses flip types for module output ports and
824 /// canonicalizes flips in bundles, so only passive types can be compared here.
825 static bool areBundleElementsEquivalent(BundleType::BundleElement destElement,
826  BundleType::BundleElement srcElement,
827  bool destOuterTypeIsConst,
828  bool srcOuterTypeIsConst,
829  bool requiresSameWidth) {
830  if (destElement.name != srcElement.name)
831  return false;
832  if (destElement.isFlip != srcElement.isFlip)
833  return false;
834 
835  if (destElement.isFlip) {
836  std::swap(destElement, srcElement);
837  std::swap(destOuterTypeIsConst, srcOuterTypeIsConst);
838  }
839 
840  return areTypesEquivalent(destElement.type, srcElement.type,
841  destOuterTypeIsConst, srcOuterTypeIsConst,
842  requiresSameWidth);
843 }
844 
845 /// Returns whether the two types are equivalent. This implements the exact
846 /// definition of type equivalence in the FIRRTL spec. If the types being
847 /// compared have any outer flips that encode FIRRTL module directions (input or
848 /// output), these should be stripped before using this method.
///
/// `destOuterTypeIsConst` and `srcOuterTypeIsConst` state whether an enclosing
/// aggregate is const. That constness is combined with each type's own const
/// flag and passed down into nested elements.
///
/// When `requireSameWidths` is false, ground types are compared without their
/// widths. Enum variants are always compared with widths required.
///
/// Non-base FIRRTL types are equivalent only if they are identical.
850  bool destOuterTypeIsConst,
851  bool srcOuterTypeIsConst,
852  bool requireSameWidths) {
853  auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
854  auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
855 
856  // For non-base types, only equivalent if identical.
857  if (!destType || !srcType)
858  return destFType == srcFType;
859 
  // Constness of an enclosing aggregate applies to this type as well.
860  bool srcIsConst = srcOuterTypeIsConst || srcFType.isConst();
861  bool destIsConst = destOuterTypeIsConst || destFType.isConst();
862 
863  // Vector types can be connected if they have the same size and element type.
864  auto destVectorType = type_dyn_cast<FVectorType>(destType);
865  auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
866  if (destVectorType && srcVectorType)
867  return destVectorType.getNumElements() == srcVectorType.getNumElements() &&
868  areTypesEquivalent(destVectorType.getElementType(),
869  srcVectorType.getElementType(), destIsConst,
870  srcIsConst, requireSameWidths);
871 
872  // Bundle types can be connected if they have the same size, element names,
873  // and element types.
874  auto destBundleType = type_dyn_cast<BundleType>(destType);
875  auto srcBundleType = type_dyn_cast<BundleType>(srcType);
876  if (destBundleType && srcBundleType) {
877  auto destElements = destBundleType.getElements();
878  auto srcElements = srcBundleType.getElements();
879  size_t numDestElements = destElements.size();
880  if (numDestElements != srcElements.size())
881  return false;
882 
  // Elements are compared pairwise in declaration order.
883  for (size_t i = 0; i < numDestElements; ++i) {
884  auto destElement = destElements[i];
885  auto srcElement = srcElements[i];
886  if (!areBundleElementsEquivalent(destElement, srcElement, destIsConst,
887  srcIsConst, requireSameWidths))
888  return false;
889  }
890  return true;
891  }
892 
893  // Enum types can be connected if they have the same size, element names, and
894  // element types.
895  auto dstEnumType = type_dyn_cast<FEnumType>(destType);
896  auto srcEnumType = type_dyn_cast<FEnumType>(srcType);
897 
898  if (dstEnumType && srcEnumType) {
899  if (dstEnumType.getNumElements() != srcEnumType.getNumElements())
900  return false;
901  // Enums require the types to match exactly.
902  for (const auto &[dst, src] : llvm::zip(dstEnumType, srcEnumType)) {
903  // The variant names must match.
904  if (dst.name != src.name)
905  return false;
906  // Enumeration types can only be connected if the inner types have the
907  // same width.
908  if (!areTypesEquivalent(dst.type, src.type, destIsConst, srcIsConst,
909  true))
910  return false;
911  }
912  return true;
913  }
914 
915  // Ground type connections must be const compatible.
916  if (destIsConst && !srcIsConst)
917  return false;
918 
919  // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
920  if (firrtl::type_isa<ResetType>(destType))
921  return srcType.isResetType();
922 
923  // Reset types can drive UInt<1>, AsyncReset, or Reset types.
924  if (firrtl::type_isa<ResetType>(srcType))
925  return destType.isResetType();
926 
927  // If we can implicitly truncate or extend the bitwidth, or either width is
928  // currently uninferred, then compare the widthless version of these types.
  // The second check sees `srcType` after it may already have been made
  // widthless by the first, so an uninferred width on either side strips the
  // widths from both.
929  if (!requireSameWidths || destType.getBitWidthOrSentinel() == -1)
930  srcType = srcType.getWidthlessType();
931  if (!requireSameWidths || srcType.getBitWidthOrSentinel() == -1)
932  destType = destType.getWidthlessType();
933 
934  // Ground types can be connected if their constless types are the same
935  return destType.getConstType(false) == srcType.getConstType(false);
936 }
937 
938 /// Returns whether the two types are weakly equivalent.
940  bool destFlip, bool srcFlip,
941  bool destOuterTypeIsConst,
942  bool srcOuterTypeIsConst) {
943  auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
944  auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
945 
946  // For non-base types, only equivalent if identical.
947  if (!destType || !srcType)
948  return destFType == srcFType;
949 
950  bool srcIsConst = srcOuterTypeIsConst || srcFType.isConst();
951  bool destIsConst = destOuterTypeIsConst || destFType.isConst();
952 
953  // Vector types can be connected if their element types are weakly equivalent.
954  // Size doesn't matter.
955  auto destVectorType = type_dyn_cast<FVectorType>(destType);
956  auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
957  if (destVectorType && srcVectorType)
958  return areTypesWeaklyEquivalent(destVectorType.getElementType(),
959  srcVectorType.getElementType(), destFlip,
960  srcFlip, destIsConst, srcIsConst);
961 
962  // Bundle types are weakly equivalent if all common elements are weakly
963  // equivalent. Non-matching fields are ignored. Flips are "pushed" into
964  // recursive weak type equivalence checks.
965  auto destBundleType = type_dyn_cast<BundleType>(destType);
966  auto srcBundleType = type_dyn_cast<BundleType>(srcType);
967  if (destBundleType && srcBundleType)
968  return llvm::all_of(destBundleType, [&](auto destElt) -> bool {
969  auto destField = destElt.name.getValue();
970  auto srcElt = srcBundleType.getElement(destField);
971  // If the src doesn't contain the destination's field, that's okay.
972  if (!srcElt)
973  return true;
974 
976  destElt.type, srcElt->type, destFlip ^ destElt.isFlip,
977  srcFlip ^ srcElt->isFlip, destOuterTypeIsConst, srcOuterTypeIsConst);
978  });
979 
980  // Ground types require leaf flippedness and const compatibility
981  if (destFlip != srcFlip)
982  return false;
983  if (destFlip && srcIsConst && !destIsConst)
984  return false;
985  if (srcFlip && destIsConst && !srcIsConst)
986  return false;
987 
988  // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
989  if (type_isa<ResetType>(destType))
990  return srcType.isResetType();
991 
992  // Reset types can drive UInt<1>, AsyncReset, or Reset types.
993  if (type_isa<ResetType>(srcType))
994  return destType.isResetType();
995 
996  // Ground types can be connected if their passive, widthless versions
997  // are equal and are const and flip compatible
998  auto widthlessDestType = destType.getWidthlessType();
999  auto widthlessSrcType = srcType.getWidthlessType();
1000  return widthlessDestType.getConstType(false) ==
1001  widthlessSrcType.getConstType(false);
1002 }
1003 
1004 /// Returns whether the srcType can be const-casted to the destType.
1006  bool srcOuterTypeIsConst) {
1007  // Identical types are always castable
1008  if (destFType == srcFType)
1009  return true;
1010 
1011  auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
1012  auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
1013 
1014  // For non-base types, only castable if identical.
1015  if (!destType || !srcType)
1016  return false;
1017 
1018  // Types must be passive
1019  if (!destType.isPassive() || !srcType.isPassive())
1020  return false;
1021 
1022  bool srcIsConst = srcType.isConst() || srcOuterTypeIsConst;
1023 
1024  // Cannot cast non-'const' src to 'const' dest
1025  if (destType.isConst() && !srcIsConst)
1026  return false;
1027 
1028  // Vector types can be casted if they have the same size and castable element
1029  // type.
1030  auto destVectorType = type_dyn_cast<FVectorType>(destType);
1031  auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
1032  if (destVectorType && srcVectorType)
1033  return destVectorType.getNumElements() == srcVectorType.getNumElements() &&
1034  areTypesConstCastable(destVectorType.getElementType(),
1035  srcVectorType.getElementType(), srcIsConst);
1036  if (destVectorType != srcVectorType)
1037  return false;
1038 
1039  // Bundle types can be casted if they have the same size, element names,
1040  // and castable element types.
1041  auto destBundleType = type_dyn_cast<BundleType>(destType);
1042  auto srcBundleType = type_dyn_cast<BundleType>(srcType);
1043  if (destBundleType && srcBundleType) {
1044  auto destElements = destBundleType.getElements();
1045  auto srcElements = srcBundleType.getElements();
1046  size_t numDestElements = destElements.size();
1047  if (numDestElements != srcElements.size())
1048  return false;
1049 
1050  return llvm::all_of_zip(
1051  destElements, srcElements,
1052  [&](const auto &destElement, const auto &srcElement) {
1053  return destElement.name == srcElement.name &&
1054  areTypesConstCastable(destElement.type, srcElement.type,
1055  srcIsConst);
1056  });
1057  }
1058  if (destBundleType != srcBundleType)
1059  return false;
1060 
1061  // Ground types can be casted if the source type is a const
1062  // version of the destination type
1063  return destType == srcType.getConstType(destType.isConst());
1064 }
1065 
1066 bool firrtl::areTypesRefCastable(Type dstType, Type srcType) {
1067  auto dstRefType = type_dyn_cast<RefType>(dstType);
1068  auto srcRefType = type_dyn_cast<RefType>(srcType);
1069  if (!dstRefType || !srcRefType)
1070  return false;
1071  if (dstRefType == srcRefType)
1072  return true;
1073  if (dstRefType.getForceable() && !srcRefType.getForceable())
1074  return false;
1075 
1076  // Okay walk the types recursively. They must be identical "structurally"
1077  // with exception leaf (ground) types of destination can be uninferred
1078  // versions of the corresponding source type. (can lose width information or
1079  // become a more general reset type)
1080  // In addition, while not explicitly in spec its useful to allow probes
1081  // to have const cast away, especially for probes of literals and expressions
1082  // derived from them. Check const as with const cast.
1083  // NOLINTBEGIN(misc-no-recursion)
1084  auto recurse = [&](auto &&f, FIRRTLBaseType dest, FIRRTLBaseType src,
1085  bool srcOuterTypeIsConst) -> bool {
1086  // Fast-path for identical types.
1087  if (dest == src)
1088  return true;
1089 
1090  // Always passive inside probes, but for sanity assert this.
1091  assert(dest.isPassive() && src.isPassive());
1092 
1093  bool srcIsConst = src.isConst() || srcOuterTypeIsConst;
1094 
1095  // Cannot cast non-'const' src to 'const' dest
1096  if (dest.isConst() && !srcIsConst)
1097  return false;
1098 
1099  // Recurse through aggregates to get the leaves, checking
1100  // structural equivalence re:element count + names.
1101 
1102  if (auto destVectorType = type_dyn_cast<FVectorType>(dest)) {
1103  auto srcVectorType = type_dyn_cast<FVectorType>(src);
1104  return srcVectorType &&
1105  destVectorType.getNumElements() ==
1106  srcVectorType.getNumElements() &&
1107  f(f, destVectorType.getElementType(),
1108  srcVectorType.getElementType(), srcIsConst);
1109  }
1110 
1111  if (auto destBundleType = type_dyn_cast<BundleType>(dest)) {
1112  auto srcBundleType = type_dyn_cast<BundleType>(src);
1113  if (!srcBundleType)
1114  return false;
1115  // (no need to check orientation, these are always passive)
1116  auto destElements = destBundleType.getElements();
1117  auto srcElements = srcBundleType.getElements();
1118 
1119  return destElements.size() == srcElements.size() &&
1120  llvm::all_of_zip(
1121  destElements, srcElements,
1122  [&](const auto &destElement, const auto &srcElement) {
1123  return destElement.name == srcElement.name &&
1124  f(f, destElement.type, srcElement.type, srcIsConst);
1125  });
1126  }
1127 
1128  if (auto destEnumType = type_dyn_cast<FEnumType>(dest)) {
1129  auto srcEnumType = type_dyn_cast<FEnumType>(src);
1130  if (!srcEnumType)
1131  return false;
1132  auto destElements = destEnumType.getElements();
1133  auto srcElements = srcEnumType.getElements();
1134 
1135  return destElements.size() == srcElements.size() &&
1136  llvm::all_of_zip(
1137  destElements, srcElements,
1138  [&](const auto &destElement, const auto &srcElement) {
1139  return destElement.name == srcElement.name &&
1140  f(f, destElement.type, srcElement.type, srcIsConst);
1141  });
1142  }
1143 
1144  // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
1145  if (type_isa<ResetType>(dest))
1146  return src.isResetType();
1147  // (but don't allow the other direction, can only become more general)
1148 
1149  // Compare against const src if dest is const.
1150  src = src.getConstType(dest.isConst());
1151 
1152  // Compare against widthless src if dest is widthless.
1153  if (dest.getBitWidthOrSentinel() == -1)
1154  src = src.getWidthlessType();
1155 
1156  return dest == src;
1157  };
1158 
1159  return recurse(recurse, dstRefType.getType(), srcRefType.getType(), false);
1160  // NOLINTEND(misc-no-recursion)
1161 }
1162 
1163 // NOLINTBEGIN(misc-no-recursion)
1164 /// Returns true if the destination is at least as wide as an equivalent source.
1166  return TypeSwitch<FIRRTLBaseType, bool>(dstType)
1167  .Case<BundleType>([&](auto dstBundle) {
1168  auto srcBundle = type_cast<BundleType>(srcType);
1169  for (size_t i = 0, n = dstBundle.getNumElements(); i < n; ++i) {
1170  auto srcElem = srcBundle.getElement(i);
1171  auto dstElem = dstBundle.getElement(i);
1172  if (dstElem.isFlip) {
1173  if (!isTypeLarger(srcElem.type, dstElem.type))
1174  return false;
1175  } else {
1176  if (!isTypeLarger(dstElem.type, srcElem.type))
1177  return false;
1178  }
1179  }
1180  return true;
1181  })
1182  .Case<FVectorType>([&](auto vector) {
1183  return isTypeLarger(vector.getElementType(),
1184  type_cast<FVectorType>(srcType).getElementType());
1185  })
1186  .Default([&](auto dstGround) {
1187  int32_t destWidth = dstType.getPassiveType().getBitWidthOrSentinel();
1188  int32_t srcWidth = srcType.getPassiveType().getBitWidthOrSentinel();
1189  return destWidth <= -1 || srcWidth <= -1 || destWidth >= srcWidth;
1190  });
1191 }
1192 // NOLINTEND(misc-no-recursion)
1193 
1195  FIRRTLBaseType rhs) {
1196  return lhs.getAnonymousType() == rhs.getAnonymousType();
1197 }
1198 
1199 bool firrtl::areAnonymousTypesEquivalent(mlir::Type lhs, mlir::Type rhs) {
1200  if (auto destBaseType = type_dyn_cast<FIRRTLBaseType>(lhs))
1201  if (auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(rhs))
1202  return areAnonymousTypesEquivalent(destBaseType, srcBaseType);
1203 
1204  if (auto destRefType = type_dyn_cast<RefType>(lhs))
1205  if (auto srcRefType = type_dyn_cast<RefType>(rhs))
1206  return areAnonymousTypesEquivalent(destRefType.getType(),
1207  srcRefType.getType());
1208 
1209  return lhs == rhs;
1210 }
1211 
1212 /// Return the passive version of a firrtl type
1213 /// top level for ODS constraint usage
1214 Type firrtl::getPassiveType(Type anyBaseFIRRTLType) {
1215  return type_cast<FIRRTLBaseType>(anyBaseFIRRTLType).getPassiveType();
1216 }
1217 
1218 bool firrtl::isTypeInOut(Type type) {
1219  return llvm::TypeSwitch<Type, bool>(type)
1220  .Case<FIRRTLBaseType>([](auto type) {
1221  return !type.containsReference() &&
1222  (!type.isPassive() || type.containsAnalog());
1223  })
1224  .Default(false);
1225 }
1226 
1227 //===----------------------------------------------------------------------===//
1228 // IntType
1229 //===----------------------------------------------------------------------===//
1230 
1231 /// Return a SIntType or UIntType with the specified signedness, width, and
1232 /// constness
1233 IntType IntType::get(MLIRContext *context, bool isSigned,
1234  int32_t widthOrSentinel, bool isConst) {
1235  if (isSigned)
1236  return SIntType::get(context, widthOrSentinel, isConst);
1237  return UIntType::get(context, widthOrSentinel, isConst);
1238 }
1239 
1241  if (auto sintType = type_dyn_cast<SIntType>(*this))
1242  return sintType.getWidthOrSentinel();
1243  if (auto uintType = type_dyn_cast<UIntType>(*this))
1244  return uintType.getWidthOrSentinel();
1245  return -1;
1246 }
1247 
1248 //===----------------------------------------------------------------------===//
1249 // WidthTypeStorage
1250 //===----------------------------------------------------------------------===//
1251 
1255  using KeyTy = std::pair<int32_t, char>;
1256 
1257  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1258 
1259  KeyTy getAsKey() const { return KeyTy(width, isConst); }
1260 
1261  static WidthTypeStorage *construct(TypeStorageAllocator &allocator,
1262  const KeyTy &key) {
1263  return new (allocator.allocate<WidthTypeStorage>())
1264  WidthTypeStorage(key.first, key.second);
1265  }
1266 
1267  int32_t width;
1268 };
1269 
1271 
1272  if (auto sIntType = type_dyn_cast<SIntType>(*this))
1273  return sIntType.getConstType(isConst);
1274  return type_cast<UIntType>(*this).getConstType(isConst);
1275 }
1276 
1277 //===----------------------------------------------------------------------===//
1278 // SIntType
1279 //===----------------------------------------------------------------------===//
1280 
1281 SIntType SIntType::get(MLIRContext *context) { return get(context, -1, false); }
1282 
1283 SIntType SIntType::get(MLIRContext *context, std::optional<int32_t> width,
1284  bool isConst) {
1285  return get(context, width ? *width : -1, isConst);
1286 }
1287 
1288 LogicalResult SIntType::verify(function_ref<InFlightDiagnostic()> emitError,
1289  int32_t widthOrSentinel, bool isConst) {
1290  if (widthOrSentinel < -1)
1291  return emitError() << "invalid width";
1292  return success();
1293 }
1294 
1295 int32_t SIntType::getWidthOrSentinel() const { return getImpl()->width; }
1296 
1297 SIntType SIntType::getConstType(bool isConst) {
1298  if (isConst == this->isConst())
1299  return *this;
1300  return get(getContext(), getWidthOrSentinel(), isConst);
1301 }
1302 
1303 //===----------------------------------------------------------------------===//
1304 // UIntType
1305 //===----------------------------------------------------------------------===//
1306 
1307 UIntType UIntType::get(MLIRContext *context) { return get(context, -1, false); }
1308 
1309 UIntType UIntType::get(MLIRContext *context, std::optional<int32_t> width,
1310  bool isConst) {
1311  return get(context, width ? *width : -1, isConst);
1312 }
1313 
1314 LogicalResult UIntType::verify(function_ref<InFlightDiagnostic()> emitError,
1315  int32_t widthOrSentinel, bool isConst) {
1316  if (widthOrSentinel < -1)
1317  return emitError() << "invalid width";
1318  return success();
1319 }
1320 
1321 int32_t UIntType::getWidthOrSentinel() const { return getImpl()->width; }
1322 
1323 UIntType UIntType::getConstType(bool isConst) {
1324  if (isConst == this->isConst())
1325  return *this;
1326  return get(getContext(), getWidthOrSentinel(), isConst);
1327 }
1328 
1329 //===----------------------------------------------------------------------===//
1330 // Bundle Type
1331 //===----------------------------------------------------------------------===//
1332 
1335  using KeyTy = std::pair<ArrayRef<BundleType::BundleElement>, char>;
1336 
1337  BundleTypeStorage(ArrayRef<BundleType::BundleElement> elements, bool isConst)
1339  elements(elements.begin(), elements.end()), props{true, false, false,
1340  isConst, false, false,
1341  false} {
1342  uint64_t fieldID = 0;
1343  fieldIDs.reserve(elements.size());
1344  for (auto &element : elements) {
1345  auto type = element.type;
1346  auto eltInfo = type.getRecursiveTypeProperties();
1347  props.isPassive &= eltInfo.isPassive & !element.isFlip;
1348  props.containsAnalog |= eltInfo.containsAnalog;
1349  props.containsReference |= eltInfo.containsReference;
1350  props.containsConst |= eltInfo.containsConst;
1351  props.containsTypeAlias |= eltInfo.containsTypeAlias;
1352  props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
1353  props.hasUninferredReset |= eltInfo.hasUninferredReset;
1354  fieldID += 1;
1355  fieldIDs.push_back(fieldID);
1356  // Increment the field ID for the next field by the number of subfields.
1357  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
1358  }
1359  maxFieldID = fieldID;
1360  }
1361 
1362  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1363 
1364  KeyTy getAsKey() const { return KeyTy(elements, isConst); }
1365 
1366  static llvm::hash_code hashKey(const KeyTy &key) {
1367  return llvm::hash_combine(
1368  llvm::hash_combine_range(key.first.begin(), key.first.end()),
1369  key.second);
1370  }
1371 
1372  static BundleTypeStorage *construct(TypeStorageAllocator &allocator,
1373  KeyTy key) {
1374  return new (allocator.allocate<BundleTypeStorage>())
1375  BundleTypeStorage(key.first, static_cast<bool>(key.second));
1376  }
1377 
1378  SmallVector<BundleType::BundleElement, 4> elements;
1379  SmallVector<uint64_t, 4> fieldIDs;
1380  uint64_t maxFieldID;
1381 
1382  /// This holds the bits for the type's recursive properties, and can hold a
1383  /// pointer to a passive version of the type.
1385  BundleType passiveType;
1386  BundleType anonymousType;
1387 };
1388 
1389 BundleType BundleType::get(MLIRContext *context,
1390  ArrayRef<BundleElement> elements, bool isConst) {
1391  return Base::get(context, elements, isConst);
1392 }
1393 
1394 auto BundleType::getElements() const -> ArrayRef<BundleElement> {
1395  return getImpl()->elements;
1396 }
1397 
1398 /// Return a pair with the 'isPassive' and 'containsAnalog' bits.
1399 RecursiveTypeProperties BundleType::getRecursiveTypeProperties() const {
1400  return getImpl()->props;
1401 }
1402 
1403 /// Return this type with any flip types recursively removed from itself.
1405  auto *impl = getImpl();
1406 
1407  // If we've already determined and cached the passive type, use it.
1408  if (impl->passiveType)
1409  return impl->passiveType;
1410 
1411  // If this type is already passive, use it and remember for next time.
1412  if (impl->props.isPassive) {
1413  impl->passiveType = *this;
1414  return *this;
1415  }
1416 
1417  // Otherwise at least one element is non-passive, rebuild a passive version.
1418  SmallVector<BundleType::BundleElement, 16> newElements;
1419  newElements.reserve(impl->elements.size());
1420  for (auto &elt : impl->elements) {
1421  newElements.push_back({elt.name, false, elt.type.getPassiveType()});
1422  }
1423 
1424  auto passiveType = BundleType::get(getContext(), newElements, isConst());
1425  impl->passiveType = passiveType;
1426  return passiveType;
1427 }
1428 
1429 BundleType BundleType::getConstType(bool isConst) {
1430  if (isConst == this->isConst())
1431  return *this;
1432  return get(getContext(), getElements(), isConst);
1433 }
1434 
1435 BundleType BundleType::getAllConstDroppedType() {
1436  if (!containsConst())
1437  return *this;
1438 
1439  SmallVector<BundleElement> constDroppedElements(
1440  llvm::map_range(getElements(), [](BundleElement element) {
1441  element.type = element.type.getAllConstDroppedType();
1442  return element;
1443  }));
1444  return get(getContext(), constDroppedElements, false);
1445 }
1446 
1447 std::optional<unsigned> BundleType::getElementIndex(StringAttr name) {
1448  for (const auto &it : llvm::enumerate(getElements())) {
1449  auto element = it.value();
1450  if (element.name == name) {
1451  return unsigned(it.index());
1452  }
1453  }
1454  return std::nullopt;
1455 }
1456 
1457 std::optional<unsigned> BundleType::getElementIndex(StringRef name) {
1458  for (const auto &it : llvm::enumerate(getElements())) {
1459  auto element = it.value();
1460  if (element.name.getValue() == name) {
1461  return unsigned(it.index());
1462  }
1463  }
1464  return std::nullopt;
1465 }
1466 
1467 StringAttr BundleType::getElementNameAttr(size_t index) {
1468  assert(index < getNumElements() &&
1469  "index must be less than number of fields in bundle");
1470  return getElements()[index].name;
1471 }
1472 
1473 StringRef BundleType::getElementName(size_t index) {
1474  return getElementNameAttr(index).getValue();
1475 }
1476 
1477 std::optional<BundleType::BundleElement>
1478 BundleType::getElement(StringAttr name) {
1479  if (auto maybeIndex = getElementIndex(name))
1480  return getElements()[*maybeIndex];
1481  return std::nullopt;
1482 }
1483 
1484 std::optional<BundleType::BundleElement>
1485 BundleType::getElement(StringRef name) {
1486  if (auto maybeIndex = getElementIndex(name))
1487  return getElements()[*maybeIndex];
1488  return std::nullopt;
1489 }
1490 
1491 /// Look up an element by index.
1492 BundleType::BundleElement BundleType::getElement(size_t index) {
1493  assert(index < getNumElements() &&
1494  "index must be less than number of fields in bundle");
1495  return getElements()[index];
1496 }
1497 
1498 FIRRTLBaseType BundleType::getElementType(StringAttr name) {
1499  auto element = getElement(name);
1500  return element ? element->type : FIRRTLBaseType();
1501 }
1502 
1503 FIRRTLBaseType BundleType::getElementType(StringRef name) {
1504  auto element = getElement(name);
1505  return element ? element->type : FIRRTLBaseType();
1506 }
1507 
1508 FIRRTLBaseType BundleType::getElementType(size_t index) const {
1509  assert(index < getNumElements() &&
1510  "index must be less than number of fields in bundle");
1511  return getElements()[index].type;
1512 }
1513 
1514 uint64_t BundleType::getFieldID(uint64_t index) const {
1515  return getImpl()->fieldIDs[index];
1516 }
1517 
1518 uint64_t BundleType::getIndexForFieldID(uint64_t fieldID) const {
1519  assert(!getElements().empty() && "Bundle must have >0 fields");
1520  auto fieldIDs = getImpl()->fieldIDs;
1521  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
1522  return std::distance(fieldIDs.begin(), it);
1523 }
1524 
1525 std::pair<uint64_t, uint64_t>
1526 BundleType::getIndexAndSubfieldID(uint64_t fieldID) const {
1527  auto index = getIndexForFieldID(fieldID);
1528  auto elementFieldID = getFieldID(index);
1529  return {index, fieldID - elementFieldID};
1530 }
1531 
1532 std::pair<Type, uint64_t>
1533 BundleType::getSubTypeByFieldID(uint64_t fieldID) const {
1534  if (fieldID == 0)
1535  return {*this, 0};
1536  auto subfieldIndex = getIndexForFieldID(fieldID);
1537  auto subfieldType = getElementType(subfieldIndex);
1538  auto subfieldID = fieldID - getFieldID(subfieldIndex);
1539  return {subfieldType, subfieldID};
1540 }
1541 
1542 uint64_t BundleType::getMaxFieldID() const { return getImpl()->maxFieldID; }
1543 
1544 std::pair<uint64_t, bool>
1545 BundleType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1546  auto childRoot = getFieldID(index);
1547  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
1548  : (getFieldID(index + 1) - 1);
1549  return std::make_pair(fieldID - childRoot,
1550  fieldID >= childRoot && fieldID <= rangeEnd);
1551 }
1552 
1553 bool BundleType::isConst() { return getImpl()->isConst; }
1554 
1555 BundleType::ElementType
1556 BundleType::getElementTypePreservingConst(size_t index) {
1557  auto type = getElementType(index);
1558  return type.getConstType(type.isConst() || isConst());
1559 }
1560 
1561 /// Return this type with any type aliases recursively removed from itself.
1562 FIRRTLBaseType BundleType::getAnonymousType() {
1563  auto *impl = getImpl();
1564 
1565  // If we've already determined and cached the anonymous type, use it.
1566  if (impl->anonymousType)
1567  return impl->anonymousType;
1568 
1569  // If this type is already anonymous, use it and remember for next time.
1570  if (!impl->props.containsTypeAlias) {
1571  impl->anonymousType = *this;
1572  return *this;
1573  }
1574 
1575  // Otherwise at least one element has an alias type, rebuild an anonymous
1576  // version.
1577  SmallVector<BundleType::BundleElement, 16> newElements;
1578  newElements.reserve(impl->elements.size());
1579  for (auto &elt : impl->elements)
1580  newElements.push_back({elt.name, elt.isFlip, elt.type.getAnonymousType()});
1581 
1582  auto anonymousType = BundleType::get(getContext(), newElements, isConst());
1583  impl->anonymousType = anonymousType;
1584  return anonymousType;
1585 }
1586 
1587 //===----------------------------------------------------------------------===//
1588 // OpenBundle Type
1589 //===----------------------------------------------------------------------===//
1590 
1592  using KeyTy = std::pair<ArrayRef<OpenBundleType::BundleElement>, char>;
1593 
1594  OpenBundleTypeStorage(ArrayRef<OpenBundleType::BundleElement> elements,
1595  bool isConst)
1596  : elements(elements.begin(), elements.end()), props{true, false, false,
1597  isConst, false, false,
1598  false},
1599  isConst(static_cast<char>(isConst)) {
1600  uint64_t fieldID = 0;
1601  fieldIDs.reserve(elements.size());
1602  for (auto &element : elements) {
1603  auto type = element.type;
1604  auto eltInfo = type.getRecursiveTypeProperties();
1605  props.isPassive &= eltInfo.isPassive & !element.isFlip;
1606  props.containsAnalog |= eltInfo.containsAnalog;
1607  props.containsReference |= eltInfo.containsReference;
1608  props.containsConst |= eltInfo.containsConst;
1609  props.containsTypeAlias |= eltInfo.containsTypeAlias;
1610  props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
1611  props.hasUninferredReset |= eltInfo.hasUninferredReset;
1612  fieldID += 1;
1613  fieldIDs.push_back(fieldID);
1614  // Increment the field ID for the next field by the number of subfields.
1615  // TODO: Maybe just have elementType be FieldIDTypeInterface ?
1616  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
1617  }
1618  maxFieldID = fieldID;
1619  }
1620 
1621  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1622 
1623  static llvm::hash_code hashKey(const KeyTy &key) {
1624  return llvm::hash_combine(
1625  llvm::hash_combine_range(key.first.begin(), key.first.end()),
1626  key.second);
1627  }
1628 
1629  KeyTy getAsKey() const { return KeyTy(elements, isConst); }
1630 
1631  static OpenBundleTypeStorage *construct(TypeStorageAllocator &allocator,
1632  KeyTy key) {
1633  return new (allocator.allocate<OpenBundleTypeStorage>())
1634  OpenBundleTypeStorage(key.first, static_cast<bool>(key.second));
1635  }
1636 
1637  SmallVector<OpenBundleType::BundleElement, 4> elements;
1638  SmallVector<uint64_t, 4> fieldIDs;
1639  uint64_t maxFieldID;
1640 
1641  /// This holds the bits for the type's recursive properties, and can hold a
1642  /// pointer to a passive version of the type.
1644 
1645  // Whether this is 'const'.
1646  char isConst;
1647 };
1648 
1649 OpenBundleType OpenBundleType::get(MLIRContext *context,
1650  ArrayRef<BundleElement> elements,
1651  bool isConst) {
1652  return Base::get(context, elements, isConst);
1653 }
1654 
1655 auto OpenBundleType::getElements() const -> ArrayRef<BundleElement> {
1656  return getImpl()->elements;
1657 }
1658 
1659 /// Return a pair with the 'isPassive' and 'containsAnalog' bits.
1660 RecursiveTypeProperties OpenBundleType::getRecursiveTypeProperties() const {
1661  return getImpl()->props;
1662 }
1663 
1664 OpenBundleType OpenBundleType::getConstType(bool isConst) {
1665  if (isConst == this->isConst())
1666  return *this;
1667  return get(getContext(), getElements(), isConst);
1668 }
1669 
1670 std::optional<unsigned> OpenBundleType::getElementIndex(StringAttr name) {
1671  for (const auto &it : llvm::enumerate(getElements())) {
1672  auto element = it.value();
1673  if (element.name == name) {
1674  return unsigned(it.index());
1675  }
1676  }
1677  return std::nullopt;
1678 }
1679 
1680 std::optional<unsigned> OpenBundleType::getElementIndex(StringRef name) {
1681  for (const auto &it : llvm::enumerate(getElements())) {
1682  auto element = it.value();
1683  if (element.name.getValue() == name) {
1684  return unsigned(it.index());
1685  }
1686  }
1687  return std::nullopt;
1688 }
1689 
1690 StringAttr OpenBundleType::getElementNameAttr(size_t index) {
1691  assert(index < getNumElements() &&
1692  "index must be less than number of fields in bundle");
1693  return getElements()[index].name;
1694 }
1695 
1696 StringRef OpenBundleType::getElementName(size_t index) {
1697  return getElementNameAttr(index).getValue();
1698 }
1699 
1700 std::optional<OpenBundleType::BundleElement>
1701 OpenBundleType::getElement(StringAttr name) {
1702  if (auto maybeIndex = getElementIndex(name))
1703  return getElements()[*maybeIndex];
1704  return std::nullopt;
1705 }
1706 
1707 std::optional<OpenBundleType::BundleElement>
1708 OpenBundleType::getElement(StringRef name) {
1709  if (auto maybeIndex = getElementIndex(name))
1710  return getElements()[*maybeIndex];
1711  return std::nullopt;
1712 }
1713 
1714 /// Look up an element by index.
1715 OpenBundleType::BundleElement OpenBundleType::getElement(size_t index) {
1716  assert(index < getNumElements() &&
1717  "index must be less than number of fields in bundle");
1718  return getElements()[index];
1719 }
1720 
1721 OpenBundleType::ElementType OpenBundleType::getElementType(StringAttr name) {
1722  auto element = getElement(name);
1723  return element ? element->type : FIRRTLBaseType();
1724 }
1725 
1726 OpenBundleType::ElementType OpenBundleType::getElementType(StringRef name) {
1727  auto element = getElement(name);
1728  return element ? element->type : FIRRTLBaseType();
1729 }
1730 
1731 OpenBundleType::ElementType OpenBundleType::getElementType(size_t index) const {
1732  assert(index < getNumElements() &&
1733  "index must be less than number of fields in bundle");
1734  return getElements()[index].type;
1735 }
1736 
1737 uint64_t OpenBundleType::getFieldID(uint64_t index) const {
1738  return getImpl()->fieldIDs[index];
1739 }
1740 
1741 uint64_t OpenBundleType::getIndexForFieldID(uint64_t fieldID) const {
1742  assert(!getElements().empty() && "Bundle must have >0 fields");
1743  auto fieldIDs = getImpl()->fieldIDs;
1744  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
1745  return std::distance(fieldIDs.begin(), it);
1746 }
1747 
1748 std::pair<uint64_t, uint64_t>
1749 OpenBundleType::getIndexAndSubfieldID(uint64_t fieldID) const {
1750  auto index = getIndexForFieldID(fieldID);
1751  auto elementFieldID = getFieldID(index);
1752  return {index, fieldID - elementFieldID};
1753 }
1754 
1755 std::pair<Type, uint64_t>
1756 OpenBundleType::getSubTypeByFieldID(uint64_t fieldID) const {
1757  if (fieldID == 0)
1758  return {*this, 0};
1759  auto subfieldIndex = getIndexForFieldID(fieldID);
1760  auto subfieldType = getElementType(subfieldIndex);
1761  auto subfieldID = fieldID - getFieldID(subfieldIndex);
1762  return {subfieldType, subfieldID};
1763 }
1764 
1765 uint64_t OpenBundleType::getMaxFieldID() const { return getImpl()->maxFieldID; }
1766 
1767 std::pair<uint64_t, bool>
1768 OpenBundleType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1769  auto childRoot = getFieldID(index);
1770  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
1771  : (getFieldID(index + 1) - 1);
1772  return std::make_pair(fieldID - childRoot,
1773  fieldID >= childRoot && fieldID <= rangeEnd);
1774 }
1775 
1776 bool OpenBundleType::isConst() { return getImpl()->isConst; }
1777 
1778 OpenBundleType::ElementType
1779 OpenBundleType::getElementTypePreservingConst(size_t index) {
1780  auto type = getElementType(index);
1781  // TODO: ConstTypeInterface / Trait ?
1782  return TypeSwitch<FIRRTLType, ElementType>(type)
1783  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>([&](auto type) {
1784  return type.getConstType(type.isConst() || isConst());
1785  })
1786  .Default(type);
1787 }
1788 
1789 LogicalResult
1790 OpenBundleType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
1791  ArrayRef<BundleElement> elements, bool isConst) {
1792  for (auto &element : elements) {
1793  if (FIRRTLType(element.type).containsReference() && isConst)
1794  return emitErrorFn()
1795  << "'const' bundle cannot have references, but element "
1796  << element.name << " has type " << element.type;
1797  }
1798 
1799  return success();
1800 }
1801 
1802 //===----------------------------------------------------------------------===//
1803 // FVectorType
1804 //===----------------------------------------------------------------------===//
1805 
1808  using KeyTy = std::tuple<FIRRTLBaseType, size_t, char>;
1809 
1811  bool isConst)
1814  props(elementType.getRecursiveTypeProperties()) {
1816  }
1817 
1818  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1819 
1821 
1822  static FVectorTypeStorage *construct(TypeStorageAllocator &allocator,
1823  KeyTy key) {
1824  return new (allocator.allocate<FVectorTypeStorage>())
1825  FVectorTypeStorage(std::get<0>(key), std::get<1>(key),
1826  static_cast<bool>(std::get<2>(key)));
1827  }
1828 
1830  size_t numElements;
1831 
1832  /// This holds the bits for the type's recursive properties, and can hold a
1833  /// pointer to a passive version of the type.
1837 };
1838 
1840  bool isConst) {
1841  return Base::get(elementType.getContext(), elementType, numElements, isConst);
1842 }
1843 
1844 FIRRTLBaseType FVectorType::getElementType() const {
1845  return getImpl()->elementType;
1846 }
1847 
1848 size_t FVectorType::getNumElements() const { return getImpl()->numElements; }
1849 
1850 /// Return the recursive properties of the type.
1851 RecursiveTypeProperties FVectorType::getRecursiveTypeProperties() const {
1852  return getImpl()->props;
1853 }
1854 
1855 /// Return this type with any flip types recursively removed from itself.
1857  auto *impl = getImpl();
1858 
1859  // If we've already determined and cached the passive type, use it.
1860  if (impl->passiveType)
1861  return impl->passiveType;
1862 
1863  // If this type is already passive, return it and remember for next time.
1864  if (impl->elementType.getRecursiveTypeProperties().isPassive)
1865  return impl->passiveType = *this;
1866 
1867  // Otherwise, rebuild a passive version.
1868  auto passiveType = FVectorType::get(getElementType().getPassiveType(),
1869  getNumElements(), isConst());
1870  impl->passiveType = passiveType;
1871  return passiveType;
1872 }
1873 
1874 FVectorType FVectorType::getConstType(bool isConst) {
1875  if (isConst == this->isConst())
1876  return *this;
1877  return get(getElementType(), getNumElements(), isConst);
1878 }
1879 
1880 FVectorType FVectorType::getAllConstDroppedType() {
1881  if (!containsConst())
1882  return *this;
1883  return get(getElementType().getAllConstDroppedType(), getNumElements(),
1884  false);
1885 }
1886 
1887 /// Return this type with any type aliases recursively removed from itself.
1888 FIRRTLBaseType FVectorType::getAnonymousType() {
1889  auto *impl = getImpl();
1890 
1891  if (impl->anonymousType)
1892  return impl->anonymousType;
1893 
1894  // If this type is already anonymous, return it and remember for next time.
1895  if (!impl->props.containsTypeAlias)
1896  return impl->anonymousType = *this;
1897 
1898  // Otherwise, rebuild an anonymous version.
1899  auto anonymousType = FVectorType::get(getElementType().getAnonymousType(),
1900  getNumElements(), isConst());
1901  impl->anonymousType = anonymousType;
1902  return anonymousType;
1903 }
1904 
1905 uint64_t FVectorType::getFieldID(uint64_t index) const {
1906  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1907 }
1908 
1909 uint64_t FVectorType::getIndexForFieldID(uint64_t fieldID) const {
1910  assert(fieldID && "fieldID must be at least 1");
1911  // Divide the field ID by the number of fieldID's per element.
1912  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1913 }
1914 
1915 std::pair<uint64_t, uint64_t>
1916 FVectorType::getIndexAndSubfieldID(uint64_t fieldID) const {
1917  auto index = getIndexForFieldID(fieldID);
1918  auto elementFieldID = getFieldID(index);
1919  return {index, fieldID - elementFieldID};
1920 }
1921 
1922 std::pair<Type, uint64_t>
1923 FVectorType::getSubTypeByFieldID(uint64_t fieldID) const {
1924  if (fieldID == 0)
1925  return {*this, 0};
1926  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
1927 }
1928 
1929 uint64_t FVectorType::getMaxFieldID() const {
1930  return getNumElements() *
1931  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1932 }
1933 
1934 std::pair<uint64_t, bool>
1935 FVectorType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1936  auto childRoot = getFieldID(index);
1937  auto rangeEnd =
1938  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
1939  return std::make_pair(fieldID - childRoot,
1940  fieldID >= childRoot && fieldID <= rangeEnd);
1941 }
1942 
1943 bool FVectorType::isConst() { return getImpl()->isConst; }
1944 
1945 FVectorType::ElementType FVectorType::getElementTypePreservingConst() {
1946  auto type = getElementType();
1947  return type.getConstType(type.isConst() || isConst());
1948 }
1949 
1950 //===----------------------------------------------------------------------===//
1951 // OpenVectorType
1952 //===----------------------------------------------------------------------===//
1953 
1955  using KeyTy = std::tuple<FIRRTLType, size_t, char>;
1956 
1958  bool isConst)
1960  isConst(static_cast<char>(isConst)) {
1963  }
1964 
1965  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1966 
1968 
1969  static OpenVectorTypeStorage *construct(TypeStorageAllocator &allocator,
1970  KeyTy key) {
1971  return new (allocator.allocate<OpenVectorTypeStorage>())
1972  OpenVectorTypeStorage(std::get<0>(key), std::get<1>(key),
1973  static_cast<bool>(std::get<2>(key)));
1974  }
1975 
1977  size_t numElements;
1978 
1980  char isConst;
1981 };
1982 
1983 OpenVectorType OpenVectorType::get(FIRRTLType elementType, size_t numElements,
1984  bool isConst) {
1985  return Base::get(elementType.getContext(), elementType, numElements, isConst);
1986 }
1987 
1988 FIRRTLType OpenVectorType::getElementType() const {
1989  return getImpl()->elementType;
1990 }
1991 
1992 size_t OpenVectorType::getNumElements() const { return getImpl()->numElements; }
1993 
1994 /// Return the recursive properties of the type.
1995 RecursiveTypeProperties OpenVectorType::getRecursiveTypeProperties() const {
1996  return getImpl()->props;
1997 }
1998 
1999 OpenVectorType OpenVectorType::getConstType(bool isConst) {
2000  if (isConst == this->isConst())
2001  return *this;
2002  return get(getElementType(), getNumElements(), isConst);
2003 }
2004 
2005 uint64_t OpenVectorType::getFieldID(uint64_t index) const {
2006  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2007 }
2008 
2009 uint64_t OpenVectorType::getIndexForFieldID(uint64_t fieldID) const {
2010  assert(fieldID && "fieldID must be at least 1");
2011  // Divide the field ID by the number of fieldID's per element.
2012  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2013 }
2014 
2015 std::pair<uint64_t, uint64_t>
2016 OpenVectorType::getIndexAndSubfieldID(uint64_t fieldID) const {
2017  auto index = getIndexForFieldID(fieldID);
2018  auto elementFieldID = getFieldID(index);
2019  return {index, fieldID - elementFieldID};
2020 }
2021 
2022 std::pair<Type, uint64_t>
2023 OpenVectorType::getSubTypeByFieldID(uint64_t fieldID) const {
2024  if (fieldID == 0)
2025  return {*this, 0};
2026  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
2027 }
2028 
2029 uint64_t OpenVectorType::getMaxFieldID() const {
2030  // If this is requirement, make ODS constraint or actual elementType.
2031  return getNumElements() *
2032  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2033 }
2034 
2035 std::pair<uint64_t, bool>
2036 OpenVectorType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2037  auto childRoot = getFieldID(index);
2038  auto rangeEnd =
2039  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
2040  return std::make_pair(fieldID - childRoot,
2041  fieldID >= childRoot && fieldID <= rangeEnd);
2042 }
2043 
2044 bool OpenVectorType::isConst() { return getImpl()->isConst; }
2045 
2046 OpenVectorType::ElementType OpenVectorType::getElementTypePreservingConst() {
2047  auto type = getElementType();
2048  // TODO: ConstTypeInterface / Trait ?
2049  return TypeSwitch<FIRRTLType, ElementType>(type)
2050  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>([&](auto type) {
2051  return type.getConstType(type.isConst() || isConst());
2052  })
2053  .Default(type);
2054 }
2055 
2056 LogicalResult
2057 OpenVectorType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2059  bool isConst) {
2060  if (elementType.containsReference() && isConst)
2061  return emitErrorFn() << "vector cannot be const with references";
2062  return success();
2063 }
2064 
2065 //===----------------------------------------------------------------------===//
2066 // FEnum Type
2067 //===----------------------------------------------------------------------===//
2068 
2070  using KeyTy = std::pair<ArrayRef<FEnumType::EnumElement>, char>;
2071 
2072  FEnumTypeStorage(ArrayRef<FEnumType::EnumElement> elements, bool isConst)
2074  elements(elements.begin(), elements.end()) {
2075  RecursiveTypeProperties props{true, false, false, isConst,
2076  false, false, false};
2077  uint64_t fieldID = 0;
2078  fieldIDs.reserve(elements.size());
2079  for (auto &element : elements) {
2080  auto type = element.type;
2081  auto eltInfo = type.getRecursiveTypeProperties();
2082  props.isPassive &= eltInfo.isPassive;
2083  props.containsAnalog |= eltInfo.containsAnalog;
2084  props.containsConst |= eltInfo.containsConst;
2085  props.containsReference |= eltInfo.containsReference;
2086  props.containsTypeAlias |= eltInfo.containsTypeAlias;
2087  props.hasUninferredReset |= eltInfo.hasUninferredReset;
2088  props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
2089  fieldID += 1;
2090  fieldIDs.push_back(fieldID);
2091  // Increment the field ID for the next field by the number of subfields.
2092  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
2093  }
2094  maxFieldID = fieldID;
2095  recProps = props;
2096  }
2097 
2098  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2099 
2100  KeyTy getAsKey() const { return KeyTy(elements, isConst); }
2101 
2102  static llvm::hash_code hashKey(const KeyTy &key) {
2103  return llvm::hash_combine(
2104  llvm::hash_combine_range(key.first.begin(), key.first.end()),
2105  key.second);
2106  }
2107 
2108  static FEnumTypeStorage *construct(TypeStorageAllocator &allocator,
2109  KeyTy key) {
2110  return new (allocator.allocate<FEnumTypeStorage>())
2111  FEnumTypeStorage(key.first, static_cast<bool>(key.second));
2112  }
2113 
2114  SmallVector<FEnumType::EnumElement, 4> elements;
2115  SmallVector<uint64_t, 4> fieldIDs;
2116  uint64_t maxFieldID;
2117 
2120 };
2121 
2122 FEnumType FEnumType::get(::mlir::MLIRContext *context,
2123  ArrayRef<EnumElement> elements, bool isConst) {
2124  return Base::get(context, elements, isConst);
2125 }
2126 
2127 ArrayRef<FEnumType::EnumElement> FEnumType::getElements() const {
2128  return getImpl()->elements;
2129 }
2130 
2131 FEnumType FEnumType::getConstType(bool isConst) {
2132  return get(getContext(), getElements(), isConst);
2133 }
2134 
2135 FEnumType FEnumType::getAllConstDroppedType() {
2136  if (!containsConst())
2137  return *this;
2138 
2139  SmallVector<EnumElement> constDroppedElements(
2140  llvm::map_range(getElements(), [](EnumElement element) {
2141  element.type = element.type.getAllConstDroppedType();
2142  return element;
2143  }));
2144  return get(getContext(), constDroppedElements, false);
2145 }
2146 
2147 /// Return a pair with the 'isPassive' and 'containsAnalog' bits.
2148 RecursiveTypeProperties FEnumType::getRecursiveTypeProperties() const {
2149  return getImpl()->recProps;
2150 }
2151 
2152 std::optional<unsigned> FEnumType::getElementIndex(StringAttr name) {
2153  for (const auto &it : llvm::enumerate(getElements())) {
2154  auto element = it.value();
2155  if (element.name == name) {
2156  return unsigned(it.index());
2157  }
2158  }
2159  return std::nullopt;
2160 }
2161 
2162 std::optional<unsigned> FEnumType::getElementIndex(StringRef name) {
2163  for (const auto &it : llvm::enumerate(getElements())) {
2164  auto element = it.value();
2165  if (element.name.getValue() == name) {
2166  return unsigned(it.index());
2167  }
2168  }
2169  return std::nullopt;
2170 }
2171 
2172 StringAttr FEnumType::getElementNameAttr(size_t index) {
2173  assert(index < getNumElements() &&
2174  "index must be less than number of fields in enum");
2175  return getElements()[index].name;
2176 }
2177 
2178 StringRef FEnumType::getElementName(size_t index) {
2179  return getElementNameAttr(index).getValue();
2180 }
2181 
2182 std::optional<FEnumType::EnumElement> FEnumType::getElement(StringAttr name) {
2183  if (auto maybeIndex = getElementIndex(name))
2184  return getElements()[*maybeIndex];
2185  return std::nullopt;
2186 }
2187 
2188 std::optional<FEnumType::EnumElement> FEnumType::getElement(StringRef name) {
2189  if (auto maybeIndex = getElementIndex(name))
2190  return getElements()[*maybeIndex];
2191  return std::nullopt;
2192 }
2193 
2194 /// Look up an element by index.
2195 FEnumType::EnumElement FEnumType::getElement(size_t index) {
2196  assert(index < getNumElements() &&
2197  "index must be less than number of fields in enum");
2198  return getElements()[index];
2199 }
2200 
2201 FIRRTLBaseType FEnumType::getElementType(StringAttr name) {
2202  auto element = getElement(name);
2203  return element ? element->type : FIRRTLBaseType();
2204 }
2205 
2206 FIRRTLBaseType FEnumType::getElementType(StringRef name) {
2207  auto element = getElement(name);
2208  return element ? element->type : FIRRTLBaseType();
2209 }
2210 
2211 FIRRTLBaseType FEnumType::getElementType(size_t index) const {
2212  assert(index < getNumElements() &&
2213  "index must be less than number of fields in enum");
2214  return getElements()[index].type;
2215 }
2216 
2217 FIRRTLBaseType FEnumType::getElementTypePreservingConst(size_t index) {
2218  auto type = getElementType(index);
2219  return type.getConstType(type.isConst() || isConst());
2220 }
2221 
2222 uint64_t FEnumType::getFieldID(uint64_t index) const {
2223  return getImpl()->fieldIDs[index];
2224 }
2225 
2226 uint64_t FEnumType::getIndexForFieldID(uint64_t fieldID) const {
2227  assert(!getElements().empty() && "Enum must have >0 fields");
2228  auto fieldIDs = getImpl()->fieldIDs;
2229  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
2230  return std::distance(fieldIDs.begin(), it);
2231 }
2232 
2233 std::pair<uint64_t, uint64_t>
2234 FEnumType::getIndexAndSubfieldID(uint64_t fieldID) const {
2235  auto index = getIndexForFieldID(fieldID);
2236  auto elementFieldID = getFieldID(index);
2237  return {index, fieldID - elementFieldID};
2238 }
2239 
2240 std::pair<Type, uint64_t>
2241 FEnumType::getSubTypeByFieldID(uint64_t fieldID) const {
2242  if (fieldID == 0)
2243  return {*this, 0};
2244  auto subfieldIndex = getIndexForFieldID(fieldID);
2245  auto subfieldType = getElementType(subfieldIndex);
2246  auto subfieldID = fieldID - getFieldID(subfieldIndex);
2247  return {subfieldType, subfieldID};
2248 }
2249 
2250 uint64_t FEnumType::getMaxFieldID() const { return getImpl()->maxFieldID; }
2251 
2252 std::pair<uint64_t, bool>
2253 FEnumType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2254  auto childRoot = getFieldID(index);
2255  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
2256  : (getFieldID(index + 1) - 1);
2257  return std::make_pair(fieldID - childRoot,
2258  fieldID >= childRoot && fieldID <= rangeEnd);
2259 }
2260 
2261 auto FEnumType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2262  ArrayRef<EnumElement> elements, bool isConst)
2263  -> LogicalResult {
2264  for (auto &elt : elements) {
2265  auto r = elt.type.getRecursiveTypeProperties();
2266  if (!r.isPassive)
2267  return emitErrorFn() << "enum field '" << elt.name << "' not passive";
2268  if (r.containsAnalog)
2269  return emitErrorFn() << "enum field '" << elt.name << "' contains analog";
2270  if (r.containsConst && !isConst)
2271  return emitErrorFn() << "enum with 'const' elements must be 'const'";
2272  // TODO: exclude reference containing
2273  }
2274  return success();
2275 }
2276 
2277 /// Return this type with any type aliases recursively removed from itself.
2278 FIRRTLBaseType FEnumType::getAnonymousType() {
2279  auto *impl = getImpl();
2280 
2281  if (impl->anonymousType)
2282  return impl->anonymousType;
2283 
2284  if (!impl->recProps.containsTypeAlias)
2285  return impl->anonymousType = *this;
2286 
2287  SmallVector<FEnumType::EnumElement, 4> elements;
2288 
2289  for (auto element : getElements())
2290  elements.push_back({element.name, element.type.getAnonymousType()});
2291  return impl->anonymousType = FEnumType::get(getContext(), elements);
2292 }
2293 
2294 //===----------------------------------------------------------------------===//
2295 // BaseTypeAliasType
2296 //===----------------------------------------------------------------------===//
2297 
2300  using KeyTy = std::tuple<StringAttr, FIRRTLBaseType>;
2301 
2304  innerType(innerType) {}
2305 
2306  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2307 
2308  KeyTy getAsKey() const { return KeyTy(name, innerType); }
2309 
2310  static llvm::hash_code hashKey(const KeyTy &key) {
2311  return llvm::hash_combine(key);
2312  }
2313 
2314  static BaseTypeAliasStorage *construct(TypeStorageAllocator &allocator,
2315  KeyTy key) {
2316  return new (allocator.allocate<BaseTypeAliasStorage>())
2317  BaseTypeAliasStorage(std::get<0>(key), std::get<1>(key));
2318  }
2319  StringAttr name;
2322 };
2323 
2324 auto BaseTypeAliasType::get(StringAttr name, FIRRTLBaseType innerType)
2325  -> BaseTypeAliasType {
2326  return Base::get(name.getContext(), name, innerType);
2327 }
2328 
2329 auto BaseTypeAliasType::getName() const -> StringAttr {
2330  return getImpl()->name;
2331 }
2332 
2333 auto BaseTypeAliasType::getInnerType() const -> FIRRTLBaseType {
2334  return getImpl()->innerType;
2335 }
2336 
2337 FIRRTLBaseType BaseTypeAliasType::getAnonymousType() {
2338  auto *impl = getImpl();
2339  if (impl->anonymousType)
2340  return impl->anonymousType;
2341  return impl->anonymousType = getInnerType().getAnonymousType();
2342 }
2343 
2345  return getModifiedType(getInnerType().getPassiveType());
2346 }
2347 
2348 RecursiveTypeProperties BaseTypeAliasType::getRecursiveTypeProperties() const {
2349  auto rtp = getInnerType().getRecursiveTypeProperties();
2350  rtp.containsTypeAlias = true;
2351  return rtp;
2352 }
2353 
2354 // If a given `newInnerType` is identical to innerType, return `*this`
2355 // because we can reuse the type alias. Otherwise return `newInnerType`.
2356 FIRRTLBaseType BaseTypeAliasType::getModifiedType(FIRRTLBaseType newInnerType) {
2357  if (newInnerType == getInnerType())
2358  return *this;
2359  return newInnerType;
2360 }
2361 
2362 // FieldIDTypeInterface implementation.
2363 FIRRTLBaseType BaseTypeAliasType::getAllConstDroppedType() {
2364  return getModifiedType(getInnerType().getAllConstDroppedType());
2365 }
2366 
2367 FIRRTLBaseType BaseTypeAliasType::getConstType(bool isConst) {
2368  return getModifiedType(getInnerType().getConstType(isConst));
2369 }
2370 
2371 std::pair<Type, uint64_t>
2372 BaseTypeAliasType::getSubTypeByFieldID(uint64_t fieldID) const {
2373  return hw::FieldIdImpl::getSubTypeByFieldID(getInnerType(), fieldID);
2374 }
2375 
2376 uint64_t BaseTypeAliasType::getMaxFieldID() const {
2377  return hw::FieldIdImpl::getMaxFieldID(getInnerType());
2378 }
2379 
2380 std::pair<uint64_t, bool>
2382  uint64_t index) const {
2383  return hw::FieldIdImpl::projectToChildFieldID(getInnerType(), fieldID, index);
2384 }
2385 
2386 uint64_t BaseTypeAliasType::getIndexForFieldID(uint64_t fieldID) const {
2387  return hw::FieldIdImpl::getIndexForFieldID(getInnerType(), fieldID);
2388 }
2389 
2390 uint64_t BaseTypeAliasType::getFieldID(uint64_t index) const {
2391  return hw::FieldIdImpl::getFieldID(getInnerType(), index);
2392 }
2393 
2394 std::pair<uint64_t, uint64_t>
2395 BaseTypeAliasType::getIndexAndSubfieldID(uint64_t fieldID) const {
2396  return hw::FieldIdImpl::getIndexAndSubfieldID(getInnerType(), fieldID);
2397 }
2398 
2399 //===----------------------------------------------------------------------===//
2400 // RefType
2401 //===----------------------------------------------------------------------===//
2402 
2403 auto RefType::get(FIRRTLBaseType type, bool forceable) -> RefType {
2404  return Base::get(type.getContext(), type, forceable);
2405 }
2406 
2407 auto RefType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2408  FIRRTLBaseType base, bool forceable) -> LogicalResult {
2409  if (!base.isPassive())
2410  return emitErrorFn() << "reference base type must be passive";
2411  if (forceable && base.containsConst())
2412  return emitErrorFn()
2413  << "forceable reference base type cannot contain const";
2414  return success();
2415 }
2416 
2417 RecursiveTypeProperties RefType::getRecursiveTypeProperties() const {
2418  auto rtp = getType().getRecursiveTypeProperties();
2419  rtp.containsReference = true;
2420  // References are not "passive", per FIRRTL spec.
2421  rtp.isPassive = false;
2422  return rtp;
2423 }
2424 
2425 //===----------------------------------------------------------------------===//
2426 // AnalogType
2427 //===----------------------------------------------------------------------===//
2428 
2429 AnalogType AnalogType::get(mlir::MLIRContext *context) {
2430  return AnalogType::get(context, -1, false);
2431 }
2432 
2433 AnalogType AnalogType::get(mlir::MLIRContext *context,
2434  std::optional<int32_t> width, bool isConst) {
2435  return AnalogType::get(context, width ? *width : -1, isConst);
2436 }
2437 
2438 LogicalResult AnalogType::verify(function_ref<InFlightDiagnostic()> emitError,
2439  int32_t widthOrSentinel, bool isConst) {
2440  if (widthOrSentinel < -1)
2441  return emitError() << "invalid width";
2442  return success();
2443 }
2444 
2445 int32_t AnalogType::getWidthOrSentinel() const { return getImpl()->width; }
2446 
2447 AnalogType AnalogType::getConstType(bool isConst) {
2448  if (isConst == this->isConst())
2449  return *this;
2450  return get(getContext(), getWidthOrSentinel(), isConst);
2451 }
2452 
2453 //===----------------------------------------------------------------------===//
2454 // ClockType
2455 //===----------------------------------------------------------------------===//
2456 
2457 ClockType ClockType::getConstType(bool isConst) {
2458  if (isConst == this->isConst())
2459  return *this;
2460  return get(getContext(), isConst);
2461 }
2462 
2463 //===----------------------------------------------------------------------===//
2464 // ResetType
2465 //===----------------------------------------------------------------------===//
2466 
2467 ResetType ResetType::getConstType(bool isConst) {
2468  if (isConst == this->isConst())
2469  return *this;
2470  return get(getContext(), isConst);
2471 }
2472 
2473 //===----------------------------------------------------------------------===//
2474 // AsyncResetType
2475 //===----------------------------------------------------------------------===//
2476 
2477 AsyncResetType AsyncResetType::getConstType(bool isConst) {
2478  if (isConst == this->isConst())
2479  return *this;
2480  return get(getContext(), isConst);
2481 }
2482 
2483 //===----------------------------------------------------------------------===//
2484 // ClassType
2485 //===----------------------------------------------------------------------===//
2486 
2487 struct circt::firrtl::detail::ClassTypeStorage : mlir::TypeStorage {
2488  using KeyTy = std::pair<FlatSymbolRefAttr, ArrayRef<ClassElement>>;
2489 
2490  static ClassTypeStorage *construct(TypeStorageAllocator &allocator,
2491  KeyTy key) {
2492  auto name = key.first;
2493  auto elements = allocator.copyInto(key.second);
2494 
2495  // build the field ID table
2496  SmallVector<uint64_t, 4> ids;
2497  uint64_t id = 0;
2498  ids.reserve(elements.size());
2499  for (auto &element : elements) {
2500  id += 1;
2501  ids.push_back(id);
2502  id += hw::FieldIdImpl::getMaxFieldID(element.type);
2503  }
2504 
2505  auto fieldIDs = allocator.copyInto(ArrayRef(ids));
2506  auto maxFieldID = id;
2507 
2508  return new (allocator.allocate<ClassTypeStorage>())
2510  }
2511 
2512  ClassTypeStorage(FlatSymbolRefAttr name, ArrayRef<ClassElement> elements,
2513  ArrayRef<uint64_t> fieldIDs, uint64_t maxFieldID)
2516 
2517  bool operator==(const KeyTy &key) const {
2518  return name == key.first && elements == key.second;
2519  }
2520 
2521  FlatSymbolRefAttr name;
2522  ArrayRef<ClassElement> elements;
2523  ArrayRef<uint64_t> fieldIDs;
2524  uint64_t maxFieldID;
2525 };
2526 
2527 ClassType ClassType::get(FlatSymbolRefAttr name,
2528  ArrayRef<ClassElement> elements) {
2529  return get(name.getContext(), name, elements);
2530 }
2531 
2532 StringRef ClassType::getName() const {
2533  return getNameAttr().getAttr().getValue();
2534 }
2535 
2536 FlatSymbolRefAttr ClassType::getNameAttr() const { return getImpl()->name; }
2537 
2538 ArrayRef<ClassElement> ClassType::getElements() const {
2539  return getImpl()->elements;
2540 }
2541 
2542 const ClassElement &ClassType::getElement(IntegerAttr index) const {
2543  return getElement(index.getValue().getZExtValue());
2544 }
2545 
2546 const ClassElement &ClassType::getElement(size_t index) const {
2547  return getElements()[index];
2548 }
2549 
2550 std::optional<uint64_t> ClassType::getElementIndex(StringRef fieldName) const {
2551  for (const auto [i, e] : llvm::enumerate(getElements()))
2552  if (fieldName == e.name)
2553  return i;
2554  return {};
2555 }
2556 
2557 void ClassType::printInterface(AsmPrinter &p) const {
2558  p.printSymbolName(getName());
2559  p << "(";
2560  bool first = true;
2561  for (const auto &element : getElements()) {
2562  if (!first)
2563  p << ", ";
2564  p << element.direction << " ";
2565  p.printKeywordOrString(element.name);
2566  p << ": " << element.type;
2567  first = false;
2568  }
2569  p << ")";
2570 }
2571 
2572 uint64_t ClassType::getFieldID(uint64_t index) const {
2573  return getImpl()->fieldIDs[index];
2574 }
2575 
2576 uint64_t ClassType::getIndexForFieldID(uint64_t fieldID) const {
2577  assert(!getElements().empty() && "Class must have >0 fields");
2578  auto fieldIDs = getImpl()->fieldIDs;
2579  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
2580  return std::distance(fieldIDs.begin(), it);
2581 }
2582 
2583 std::pair<uint64_t, uint64_t>
2584 ClassType::getIndexAndSubfieldID(uint64_t fieldID) const {
2585  auto index = getIndexForFieldID(fieldID);
2586  auto elementFieldID = getFieldID(index);
2587  return {index, fieldID - elementFieldID};
2588 }
2589 
2590 std::pair<Type, uint64_t>
2591 ClassType::getSubTypeByFieldID(uint64_t fieldID) const {
2592  if (fieldID == 0)
2593  return {*this, 0};
2594  auto subfieldIndex = getIndexForFieldID(fieldID);
2595  auto subfieldType = getElement(subfieldIndex).type;
2596  auto subfieldID = fieldID - getFieldID(subfieldIndex);
2597  return {subfieldType, subfieldID};
2598 }
2599 
2600 uint64_t ClassType::getMaxFieldID() const { return getImpl()->maxFieldID; }
2601 
2602 std::pair<uint64_t, bool>
2603 ClassType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2604  auto childRoot = getFieldID(index);
2605  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
2606  : (getFieldID(index + 1) - 1);
2607  return std::make_pair(fieldID - childRoot,
2608  fieldID >= childRoot && fieldID <= rangeEnd);
2609 }
2610 
2611 ParseResult ClassType::parseInterface(AsmParser &parser, ClassType &result) {
2612  StringAttr className;
2613  if (parser.parseSymbolName(className))
2614  return failure();
2615 
2616  SmallVector<ClassElement> elements;
2617  if (parser.parseCommaSeparatedList(
2618  OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
2619  // Parse port direction.
2620  Direction direction;
2621  if (succeeded(parser.parseOptionalKeyword("out")))
2622  direction = Direction::Out;
2623  else if (succeeded(parser.parseKeyword("in", "or 'out'")))
2624  direction = Direction::In;
2625  else
2626  return failure();
2627 
2628  // Parse port name.
2629  std::string keyword;
2630  if (parser.parseKeywordOrString(&keyword))
2631  return failure();
2632  StringAttr name = StringAttr::get(parser.getContext(), keyword);
2633 
2634  // Parse port type.
2635  Type type;
2636  if (parser.parseColonType(type))
2637  return failure();
2638 
2639  elements.emplace_back(name, type, direction);
2640  return success();
2641  }))
2642  return failure();
2643 
2644  result = ClassType::get(FlatSymbolRefAttr::get(className), elements);
2645  return success();
2646 }
2647 
2648 //===----------------------------------------------------------------------===//
2649 // FIRRTLDialect
2650 //===----------------------------------------------------------------------===//
2651 
2652 void FIRRTLDialect::registerTypes() {
2653  addTypes<
2654 #define GET_TYPEDEF_LIST
2655 #include "circt/Dialect/FIRRTL/FIRRTLTypes.cpp.inc"
2656  >();
2657 }
2658 
2659 // Get the bit width for this type, return None if unknown. Unlike
2660 // getBitWidthOrSentinel(), this can recursively compute the bitwidth of
2661 // aggregate types. For bundle and vectors, recursively get the width of each
2662 // field element and return the total bit width of the aggregate type. This
2663 // returns None, if any of the bundle fields is a flip type, or ground type with
2664 // unknown bit width.
2665 std::optional<int64_t> firrtl::getBitWidth(FIRRTLBaseType type,
2666  bool ignoreFlip) {
2667  std::function<std::optional<int64_t>(FIRRTLBaseType)> getWidth =
2668  [&](FIRRTLBaseType type) -> std::optional<int64_t> {
2669  return TypeSwitch<FIRRTLBaseType, std::optional<int64_t>>(type)
2670  .Case<BundleType>([&](BundleType bundle) -> std::optional<int64_t> {
2671  int64_t width = 0;
2672  for (auto &elt : bundle) {
2673  if (elt.isFlip && !ignoreFlip)
2674  return std::nullopt;
2675  auto w = getBitWidth(elt.type);
2676  if (!w.has_value())
2677  return std::nullopt;
2678  width += *w;
2679  }
2680  return width;
2681  })
2682  .Case<FEnumType>([&](FEnumType fenum) -> std::optional<int64_t> {
2683  int64_t width = 0;
2684  for (auto &elt : fenum) {
2685  auto w = getBitWidth(elt.type);
2686  if (!w.has_value())
2687  return std::nullopt;
2688  width = std::max(width, *w);
2689  }
2690  return width + llvm::Log2_32_Ceil(fenum.getNumElements());
2691  })
2692  .Case<FVectorType>([&](auto vector) -> std::optional<int64_t> {
2693  auto w = getBitWidth(vector.getElementType());
2694  if (!w.has_value())
2695  return std::nullopt;
2696  return *w * vector.getNumElements();
2697  })
2698  .Case<IntType>([&](IntType iType) { return iType.getWidth(); })
2699  .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
2700  .Default([&](auto t) { return std::nullopt; });
2701  };
2702  return getWidth(type);
2703 }
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition: CHIRRTL.cpp:26
MlirType elementType
Definition: CHIRRTL.cpp:25
static ParseResult parseFIRRTLBaseType(FIRRTLBaseType &result, StringRef name, AsmParser &parser)
static ParseResult parseFIRRTLPropertyType(PropertyType &result, StringRef name, AsmParser &parser)
static LogicalResult customTypePrinter(Type type, AsmPrinter &os)
Print a type with a custom printer implementation.
Definition: FIRRTLTypes.cpp:45
static OptionalParseResult customTypeParser(AsmParser &parser, StringRef name, Type &result)
Parse a type with a custom parser implementation.
static ParseResult parseType(Type &result, StringRef name, AsmParser &parser)
Parse a type defined by this dialect.
@ ContainsAnalogBitMask
Bit set if the type contains an analog type.
@ HasUninferredWidthBitMask
Bit set if the type has any uninferred bit widths.
@ IsPassiveBitMask
Bit set if the type only contains passive elements.
static bool areBundleElementsEquivalent(BundleType::BundleElement destElement, BundleType::BundleElement srcElement, bool destOuterTypeIsConst, bool srcOuterTypeIsConst, bool requiresSameWidth)
Helper to implement the equivalence logic for a pair of bundle elements.
static ParseResult parseFIRRTLType(FIRRTLType &result, StringRef name, AsmParser &parser)
Parse a FIRRTLType with a name that has already been parsed.
int32_t width
Definition: FIRRTL.cpp:27
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
FIRRTLBaseType getAnonymousType()
Return this type with any type alias types recursively removed from itself.
bool isResetType()
Return true if this is a valid "reset" type.
FIRRTLBaseType getMaskType()
Return this type with all ground types replaced with UInt<1>.
FIRRTLBaseType getPassiveType()
Return this type with any flip types recursively removed from itself.
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getAllConstDroppedType()
Return this type with a 'const' modifiers dropped.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
FIRRTLBaseType getWidthlessType()
Return this type with widths of all ground types removed.
bool containsReference()
Return true if this is or contains a Reference type.
Definition: FIRRTLTypes.h:105
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
RecursiveTypeProperties getRecursiveTypeProperties() const
Return the recursive properties of the type, containing the isPassive, containsAnalog,...
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:291
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
IntType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:270
Represents a limited word-length unsigned integer in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:135
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
ParseResult parseNestedType(FIRRTLType &result, AsmParser &parser)
Parse a FIRRTLType.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
ParseResult parseNestedBaseType(FIRRTLBaseType &result, AsmParser &parser)
bool isTypeInOut(mlir::Type type)
Returns true if the given type has some flipped (aka unaligned) dataflow.
bool areTypesRefCastable(Type dstType, Type srcType)
Return true if destination ref type can be cast from source ref type, per FIRRTL spec rules they must...
bool areTypesEquivalent(FIRRTLType destType, FIRRTLType srcType, bool destOuterTypeIsConst=false, bool srcOuterTypeIsConst=false, bool requireSameWidths=false)
Returns whether the two types are equivalent.
bool areTypesWeaklyEquivalent(FIRRTLType destType, FIRRTLType srcType, bool destFlip=false, bool srcFlip=false, bool destOuterTypeIsConst=false, bool srcOuterTypeIsConst=false)
Returns true if two types are weakly equivalent.
mlir::Type getPassiveType(mlir::Type anyBaseFIRRTLType)
bool isTypeLarger(FIRRTLBaseType dstType, FIRRTLBaseType srcType)
Returns true if the destination is at least as wide as a source.
bool containsConst(Type type)
Returns true if the type is or contains a 'const' type whose value is guaranteed to be unchanging at ...
void printNestedType(Type type, AsmPrinter &os)
Print a type defined by this dialect.
bool isConst(Type type)
Returns true if this is a 'const' type whose value is guaranteed to be unchanging at circuit executio...
bool areTypesConstCastable(FIRRTLType destType, FIRRTLType srcType, bool srcOuterTypeIsConst=false)
Returns whether the srcType can be const-casted to the destType.
ParseResult parseNestedPropertyType(PropertyType &result, AsmParser &parser)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
std::pair< uint64_t, bool > projectToChildFieldID(Type, uint64_t fieldID, uint64_t index)
uint64_t getIndexForFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
A collection of bits indicating the recursive properties of a type.
Definition: FIRRTLTypes.h:65
bool containsReference
Whether the type contains a reference type.
Definition: FIRRTLTypes.h:69
bool isPassive
Whether the type only contains passive elements.
Definition: FIRRTLTypes.h:67
bool containsAnalog
Whether the type contains an analog type.
Definition: FIRRTLTypes.h:71
bool hasUninferredReset
Whether the type has any uninferred reset.
Definition: FIRRTLTypes.h:79
bool containsTypeAlias
Whether the type contains a type alias.
Definition: FIRRTLTypes.h:75
bool containsConst
Whether the type contains a const type.
Definition: FIRRTLTypes.h:73
bool hasUninferredWidth
Whether the type has any uninferred bit widths.
Definition: FIRRTLTypes.h:77
bool operator==(const KeyTy &key) const
static BaseTypeAliasStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
BaseTypeAliasStorage(StringAttr name, FIRRTLBaseType innerType)
std::tuple< StringAttr, FIRRTLBaseType > KeyTy
static llvm::hash_code hashKey(const KeyTy &key)
SmallVector< BundleType::BundleElement, 4 > elements
static BundleTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
std::pair< ArrayRef< BundleType::BundleElement >, char > KeyTy
static llvm::hash_code hashKey(const KeyTy &key)
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
BundleTypeStorage(ArrayRef< BundleType::BundleElement > elements, bool isConst)
bool operator==(const KeyTy &key) const
SmallVector< uint64_t, 4 > fieldIDs
std::pair< FlatSymbolRefAttr, ArrayRef< ClassElement > > KeyTy
bool operator==(const KeyTy &key) const
static ClassTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
ClassTypeStorage(FlatSymbolRefAttr name, ArrayRef< ClassElement > elements, ArrayRef< uint64_t > fieldIDs, uint64_t maxFieldID)
SmallVector< FEnumType::EnumElement, 4 > elements
static llvm::hash_code hashKey(const KeyTy &key)
bool operator==(const KeyTy &key) const
FEnumTypeStorage(ArrayRef< FEnumType::EnumElement > elements, bool isConst)
std::pair< ArrayRef< FEnumType::EnumElement >, char > KeyTy
static FEnumTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
SmallVector< uint64_t, 4 > fieldIDs
static FIRRTLBaseTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
static FVectorTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
std::tuple< FIRRTLBaseType, size_t, char > KeyTy
FVectorTypeStorage(FIRRTLBaseType elementType, size_t numElements, bool isConst)
SmallVector< OpenBundleType::BundleElement, 4 > elements
static OpenBundleTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
bool operator==(const KeyTy &key) const
static llvm::hash_code hashKey(const KeyTy &key)
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
OpenBundleTypeStorage(ArrayRef< OpenBundleType::BundleElement > elements, bool isConst)
std::pair< ArrayRef< OpenBundleType::BundleElement >, char > KeyTy
std::tuple< FIRRTLType, size_t, char > KeyTy
bool operator==(const KeyTy &key) const
OpenVectorTypeStorage(FIRRTLType elementType, size_t numElements, bool isConst)
static OpenVectorTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
WidthTypeStorage(int32_t width, bool isConst)
static WidthTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< int32_t, char > KeyTy