CIRCT  19.0.0git
FIRRTLTypes.cpp
Go to the documentation of this file.
1 //===- FIRRTLTypes.cpp - Implement the FIRRTL dialect type system ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the FIRRTL dialect type system.
10 //
11 //===----------------------------------------------------------------------===//
12 
17 #include "mlir/IR/DialectImplementation.h"
18 #include "llvm/ADT/StringExtras.h"
19 #include "llvm/ADT/StringSwitch.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 
22 using namespace circt;
23 using namespace firrtl;
24 
25 using mlir::OptionalParseResult;
26 using mlir::TypeStorageAllocator;
27 
28 //===----------------------------------------------------------------------===//
29 // TableGen generated logic.
30 //===----------------------------------------------------------------------===//
31 
32 // Provide the autogenerated implementation for types.
33 #define GET_TYPEDEF_CLASSES
34 #include "circt/Dialect/FIRRTL/FIRRTLTypes.cpp.inc"
35 
36 //===----------------------------------------------------------------------===//
37 // Type Printing
38 //===----------------------------------------------------------------------===//
39 
40 // NOLINTBEGIN(misc-no-recursion)
41 /// Print a type with a custom printer implementation.
42 ///
43 /// This only prints a subset of all types in the dialect. Use `printNestedType`
44 /// instead, which will call this function in turn, as appropriate.
45 static LogicalResult customTypePrinter(Type type, AsmPrinter &os) {
46  if (isConst(type)) {
47  os << "const.";
48  }
49 
50  auto printWidthQualifier = [&](std::optional<int32_t> width) {
51  if (width)
52  os << '<' << *width << '>';
53  };
54  bool anyFailed = false;
55  TypeSwitch<Type>(type)
56  .Case<ClockType>([&](auto) { os << "clock"; })
57  .Case<ResetType>([&](auto) { os << "reset"; })
58  .Case<AsyncResetType>([&](auto) { os << "asyncreset"; })
59  .Case<SIntType>([&](auto sIntType) {
60  os << "sint";
61  printWidthQualifier(sIntType.getWidth());
62  })
63  .Case<UIntType>([&](auto uIntType) {
64  os << "uint";
65  printWidthQualifier(uIntType.getWidth());
66  })
67  .Case<AnalogType>([&](auto analogType) {
68  os << "analog";
69  printWidthQualifier(analogType.getWidth());
70  })
71  .Case<BundleType, OpenBundleType>([&](auto bundleType) {
72  if (firrtl::type_isa<OpenBundleType>(bundleType))
73  os << "open";
74  os << "bundle<";
75  llvm::interleaveComma(bundleType, os, [&](auto element) {
76  StringRef fieldName = element.name.getValue();
77  bool isLiteralIdentifier =
78  !fieldName.empty() && llvm::isDigit(fieldName.front());
79  if (isLiteralIdentifier)
80  os << "\"";
81  os << element.name.getValue();
82  if (isLiteralIdentifier)
83  os << "\"";
84  if (element.isFlip)
85  os << " flip";
86  os << ": ";
87  printNestedType(element.type, os);
88  });
89  os << '>';
90  })
91  .Case<FEnumType>([&](auto fenumType) {
92  os << "enum<";
93  llvm::interleaveComma(fenumType, os,
94  [&](FEnumType::EnumElement element) {
95  os << element.name.getValue();
96  os << ": ";
97  printNestedType(element.type, os);
98  });
99  os << '>';
100  })
101  .Case<FVectorType, OpenVectorType>([&](auto vectorType) {
102  if (firrtl::type_isa<OpenVectorType>(vectorType))
103  os << "open";
104  os << "vector<";
105  printNestedType(vectorType.getElementType(), os);
106  os << ", " << vectorType.getNumElements() << '>';
107  })
108  .Case<RefType>([&](RefType refType) {
109  if (refType.getForceable())
110  os << "rw";
111  os << "probe<";
112  printNestedType(refType.getType(), os);
113  if (auto layer = refType.getLayer())
114  os << ", " << layer;
115  os << '>';
116  })
117  .Case<StringType>([&](auto stringType) { os << "string"; })
118  .Case<FIntegerType>([&](auto integerType) { os << "integer"; })
119  .Case<BoolType>([&](auto boolType) { os << "bool"; })
120  .Case<DoubleType>([&](auto doubleType) { os << "double"; })
121  .Case<ListType>([&](auto listType) {
122  os << "list<";
123  printNestedType(listType.getElementType(), os);
124  os << '>';
125  })
126  .Case<PathType>([&](auto pathType) { os << "path"; })
127  .Case<BaseTypeAliasType>([&](BaseTypeAliasType alias) {
128  os << "alias<" << alias.getName().getValue() << ", ";
129  printNestedType(alias.getInnerType(), os);
130  os << '>';
131  })
132  .Case<ClassType>([&](ClassType type) {
133  os << "class<";
134  type.printInterface(os);
135  os << ">";
136  })
137  .Case<AnyRefType>([&](AnyRefType type) { os << "anyref"; })
138  .Default([&](auto) { anyFailed = true; });
139  return failure(anyFailed);
140 }
141 // NOLINTEND(misc-no-recursion)
142 
143 /// Print a type defined by this dialect.
144 void circt::firrtl::printNestedType(Type type, AsmPrinter &os) {
145  // Try the custom type printer.
146  if (succeeded(customTypePrinter(type, os)))
147  return;
148 
149  // None of the above recognized the type, so we bail.
150  assert(false && "type to print unknown to FIRRTL dialect");
151 }
152 
153 //===----------------------------------------------------------------------===//
154 // Type Parsing
155 //===----------------------------------------------------------------------===//
156 
157 /// Parse a type with a custom parser implementation.
158 ///
159 /// This only accepts a subset of all types in the dialect. Use `parseType`
160 /// instead, which will call this function in turn, as appropriate.
161 ///
162 /// Returns `std::nullopt` if the type `name` is not covered by the custom
163 /// parsers. Otherwise returns success or failure as appropriate. On success,
164 /// `result` is set to the resulting type.
165 ///
166 /// ```plain
167 /// firrtl-type
168 /// ::= clock
169 /// ::= reset
170 /// ::= asyncreset
171 /// ::= sint ('<' int '>')?
172 /// ::= uint ('<' int '>')?
173 /// ::= analog ('<' int '>')?
174 /// ::= bundle '<' (bundle-elt (',' bundle-elt)*)? '>'
175 /// ::= enum '<' (enum-elt (',' enum-elt)*)? '>'
176 /// ::= vector '<' type ',' int '>'
177 /// ::= const '.' type
178 /// ::= 'property.' firrtl-phased-type
179 /// bundle-elt ::= identifier flip? ':' type
180 /// enum-elt ::= identifier ':' type
181 /// ```
182 static OptionalParseResult customTypeParser(AsmParser &parser, StringRef name,
183  Type &result) {
184  bool isConst = false;
185  const char constPrefix[] = "const.";
186  if (name.starts_with(constPrefix)) {
187  isConst = true;
188  name = name.drop_front(std::size(constPrefix) - 1);
189  }
190 
191  auto *context = parser.getContext();
192  if (name.equals("clock"))
193  return result = ClockType::get(context, isConst), success();
194  if (name.equals("reset"))
195  return result = ResetType::get(context, isConst), success();
196  if (name.equals("asyncreset"))
197  return result = AsyncResetType::get(context, isConst), success();
198 
199  if (name.equals("sint") || name.equals("uint") || name.equals("analog")) {
200  // Parse the width specifier if it exists.
201  int32_t width = -1;
202  if (!parser.parseOptionalLess()) {
203  if (parser.parseInteger(width) || parser.parseGreater())
204  return failure();
205 
206  if (width < 0)
207  return parser.emitError(parser.getNameLoc(), "unknown width"),
208  failure();
209  }
210 
211  if (name.equals("sint"))
212  result = SIntType::get(context, width, isConst);
213  else if (name.equals("uint"))
214  result = UIntType::get(context, width, isConst);
215  else {
216  assert(name.equals("analog"));
217  result = AnalogType::get(context, width, isConst);
218  }
219  return success();
220  }
221 
222  if (name.equals("bundle")) {
223  SmallVector<BundleType::BundleElement, 4> elements;
224 
225  auto parseBundleElement = [&]() -> ParseResult {
226  std::string nameStr;
227  StringRef name;
228  FIRRTLBaseType type;
229 
230  if (failed(parser.parseKeywordOrString(&nameStr)))
231  return failure();
232  name = nameStr;
233 
234  bool isFlip = succeeded(parser.parseOptionalKeyword("flip"));
235  if (parser.parseColon() || parseNestedBaseType(type, parser))
236  return failure();
237 
238  elements.push_back({StringAttr::get(context, name), isFlip, type});
239  return success();
240  };
241 
242  if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
243  parseBundleElement))
244  return failure();
245 
246  return result = BundleType::get(context, elements, isConst), success();
247  }
248  if (name.equals("openbundle")) {
249  SmallVector<OpenBundleType::BundleElement, 4> elements;
250 
251  auto parseBundleElement = [&]() -> ParseResult {
252  std::string nameStr;
253  StringRef name;
254  FIRRTLType type;
255 
256  if (failed(parser.parseKeywordOrString(&nameStr)))
257  return failure();
258  name = nameStr;
259 
260  bool isFlip = succeeded(parser.parseOptionalKeyword("flip"));
261  if (parser.parseColon() || parseNestedType(type, parser))
262  return failure();
263 
264  elements.push_back({StringAttr::get(context, name), isFlip, type});
265  return success();
266  };
267 
268  if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
269  parseBundleElement))
270  return failure();
271 
272  result = parser.getChecked<OpenBundleType>(context, elements, isConst);
273  return failure(!result);
274  }
275 
276  if (name.equals("enum")) {
277  SmallVector<FEnumType::EnumElement, 4> elements;
278 
279  auto parseEnumElement = [&]() -> ParseResult {
280  std::string nameStr;
281  StringRef name;
282  FIRRTLBaseType type;
283 
284  if (failed(parser.parseKeywordOrString(&nameStr)))
285  return failure();
286  name = nameStr;
287 
288  if (parser.parseColon() || parseNestedBaseType(type, parser))
289  return failure();
290 
291  elements.push_back({StringAttr::get(context, name), type});
292  return success();
293  };
294 
295  if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
296  parseEnumElement))
297  return failure();
298  if (failed(FEnumType::verify(
299  [&]() { return parser.emitError(parser.getNameLoc()); }, elements,
300  isConst)))
301  return failure();
302 
303  return result = FEnumType::get(context, elements, isConst), success();
304  }
305 
306  if (name.equals("vector")) {
307  FIRRTLBaseType elementType;
308  uint64_t width = 0;
309 
310  if (parser.parseLess() || parseNestedBaseType(elementType, parser) ||
311  parser.parseComma() || parser.parseInteger(width) ||
312  parser.parseGreater())
313  return failure();
314 
315  return result = FVectorType::get(elementType, width, isConst), success();
316  }
317  if (name.equals("openvector")) {
318  FIRRTLType elementType;
319  uint64_t width = 0;
320 
321  if (parser.parseLess() || parseNestedType(elementType, parser) ||
322  parser.parseComma() || parser.parseInteger(width) ||
323  parser.parseGreater())
324  return failure();
325 
326  result =
327  parser.getChecked<OpenVectorType>(context, elementType, width, isConst);
328  return failure(!result);
329  }
330 
331  // For now, support both firrtl.ref and firrtl.probe.
332  if (name.equals("ref") || name.equals("probe")) {
333  FIRRTLBaseType type;
334  SymbolRefAttr layer;
335  // Don't pass `isConst` to `parseNestedBaseType since `ref` can point to
336  // either `const` or non-`const` types
337  if (parser.parseLess() || parseNestedBaseType(type, parser))
338  return failure();
339  if (parser.parseOptionalComma().succeeded())
340  if (parser.parseOptionalAttribute(layer).value())
341  return parser.emitError(parser.getNameLoc(),
342  "expected symbol reference");
343  if (parser.parseGreater())
344  return failure();
345 
346  if (failed(RefType::verify(
347  [&]() { return parser.emitError(parser.getNameLoc()); }, type,
348  false, layer)))
349  return failure();
350 
351  return result = RefType::get(type, false, layer), success();
352  }
353  if (name.equals("rwprobe")) {
354  FIRRTLBaseType type;
355  SymbolRefAttr layer;
356  if (parser.parseLess() || parseNestedBaseType(type, parser))
357  return failure();
358  if (parser.parseOptionalComma().succeeded())
359  if (parser.parseOptionalAttribute(layer).value())
360  return parser.emitError(parser.getNameLoc(),
361  "expected symbol reference");
362  if (parser.parseGreater())
363  return failure();
364 
365  if (failed(RefType::verify(
366  [&]() { return parser.emitError(parser.getNameLoc()); }, type, true,
367  layer)))
368  return failure();
369 
370  return result = RefType::get(type, true, layer), success();
371  }
372  if (name.equals("class")) {
373  if (isConst)
374  return parser.emitError(parser.getNameLoc(), "classes cannot be const");
375  ClassType classType;
376  if (parser.parseLess() || ClassType::parseInterface(parser, classType) ||
377  parser.parseGreater())
378  return failure();
379  result = classType;
380  return success();
381  }
382  if (name.equals("anyref")) {
383  if (isConst)
384  return parser.emitError(parser.getNameLoc(), "any refs cannot be const");
385 
386  result = AnyRefType::get(parser.getContext());
387  return success();
388  }
389  if (name.equals("string")) {
390  if (isConst) {
391  parser.emitError(parser.getNameLoc(), "strings cannot be const");
392  return failure();
393  }
394  result = StringType::get(parser.getContext());
395  return success();
396  }
397  if (name.equals("integer")) {
398  if (isConst) {
399  parser.emitError(parser.getNameLoc(), "bigints cannot be const");
400  return failure();
401  }
402  result = FIntegerType::get(parser.getContext());
403  return success();
404  }
405  if (name.equals("bool")) {
406  if (isConst) {
407  parser.emitError(parser.getNameLoc(), "bools cannot be const");
408  return failure();
409  }
410  result = BoolType::get(parser.getContext());
411  return success();
412  }
413  if (name.equals("double")) {
414  if (isConst) {
415  parser.emitError(parser.getNameLoc(), "doubles cannot be const");
416  return failure();
417  }
418  result = DoubleType::get(parser.getContext());
419  return success();
420  }
421  if (name.equals("list")) {
422  if (isConst) {
423  parser.emitError(parser.getNameLoc(), "lists cannot be const");
424  return failure();
425  }
427  if (parser.parseLess() || parseNestedPropertyType(elementType, parser) ||
428  parser.parseGreater())
429  return failure();
430  result = parser.getChecked<ListType>(context, elementType);
431  if (!result)
432  return failure();
433  return success();
434  }
435  if (name.equals("path")) {
436  if (isConst) {
437  parser.emitError(parser.getNameLoc(), "path cannot be const");
438  return failure();
439  }
440  result = PathType::get(parser.getContext());
441  return success();
442  }
443  if (name.equals("alias")) {
444  FIRRTLBaseType type;
445  StringRef name;
446  if (parser.parseLess() || parser.parseKeyword(&name) ||
447  parser.parseComma() || parseNestedBaseType(type, parser) ||
448  parser.parseGreater())
449  return failure();
450 
451  return result =
452  BaseTypeAliasType::get(StringAttr::get(context, name), type),
453  success();
454  }
455 
456  return {};
457 }
458 
459 /// Parse a type defined by this dialect.
460 ///
461 /// This will first try the generated type parsers and then resort to the custom
462 /// parser implementation. Emits an error and returns failure if `name` does not
463 /// refer to a type defined in this dialect.
464 static ParseResult parseType(Type &result, StringRef name, AsmParser &parser) {
465  // Try the custom type parser.
466  OptionalParseResult parseResult = customTypeParser(parser, name, result);
467  if (parseResult.has_value())
468  return parseResult.value();
469 
470  // None of the above recognized the type, so we bail.
471  parser.emitError(parser.getNameLoc(), "unknown FIRRTL dialect type: \"")
472  << name << "\"";
473  return failure();
474 }
475 
476 /// Parse a `FIRRTLType` with a `name` that has already been parsed.
477 ///
478 /// Note that only a subset of types defined in the FIRRTL dialect inherit from
479 /// `FIRRTLType`. Use `parseType` to parse *any* of the defined types.
480 static ParseResult parseFIRRTLType(FIRRTLType &result, StringRef name,
481  AsmParser &parser) {
482  Type type;
483  if (failed(parseType(type, name, parser)))
484  return failure();
485  result = type_dyn_cast<FIRRTLType>(type);
486  if (result)
487  return success();
488  parser.emitError(parser.getNameLoc(), "unknown FIRRTL type: \"")
489  << name << "\"";
490  return failure();
491 }
492 
493 static ParseResult parseFIRRTLBaseType(FIRRTLBaseType &result, StringRef name,
494  AsmParser &parser) {
495  FIRRTLType type;
496  if (failed(parseFIRRTLType(type, name, parser)))
497  return failure();
498  if (auto base = type_dyn_cast<FIRRTLBaseType>(type)) {
499  result = base;
500  return success();
501  }
502  parser.emitError(parser.getNameLoc(), "expected base type, found ") << type;
503  return failure();
504 }
505 
506 static ParseResult parseFIRRTLPropertyType(PropertyType &result, StringRef name,
507  AsmParser &parser) {
508  FIRRTLType type;
509  if (failed(parseFIRRTLType(type, name, parser)))
510  return failure();
511  if (auto prop = type_dyn_cast<PropertyType>(type)) {
512  result = prop;
513  return success();
514  }
515  parser.emitError(parser.getNameLoc(), "expected property type, found ")
516  << type;
517  return failure();
518 }
519 
520 // NOLINTBEGIN(misc-no-recursion)
521 /// Parse a `FIRRTLType`, including its leading type keyword.
522 ///
523 /// Note that only a subset of types defined in the FIRRTL dialect inherit from
524 /// `FIRRTLType`. Use `parseType` to parse *any* of the defined types.
/// Returns failure if no keyword can be read or if the parsed type is not a
/// `FIRRTLType`.
526 AsmParser &parser) {
// Read the type keyword (e.g. `uint`, `bundle`) and dispatch on it.
527 StringRef name;
528 if (parser.parseKeyword(&name))
529 return failure();
530 return parseFIRRTLType(result, name, parser);
531 }
532 // NOLINTEND(misc-no-recursion)
533 
534 // NOLINTBEGIN(misc-no-recursion)
/// Parse a `FIRRTLBaseType`, including its leading type keyword.
/// Returns failure if no keyword can be read or if the parsed type is not a
/// base type.
536 AsmParser &parser) {
// Read the type keyword, then require the resulting type to be a base type.
537 StringRef name;
538 if (parser.parseKeyword(&name))
539 return failure();
540 return parseFIRRTLBaseType(result, name, parser);
541 }
542 // NOLINTEND(misc-no-recursion)
543 
544 // NOLINTBEGIN(misc-no-recursion)
/// Parse a `PropertyType`, including its leading type keyword.
/// Returns failure if no keyword can be read or if the parsed type is not a
/// property type.
546 AsmParser &parser) {
// Read the type keyword, then require the resulting type to be a property.
547 StringRef name;
548 if (parser.parseKeyword(&name))
549 return failure();
550 return parseFIRRTLPropertyType(result, name, parser);
551 }
552 // NOLINTEND(misc-no-recursion)
553 
554 //===---------------------------------------------------------------------===//
555 // Dialect Type Parsing and Printing
556 //===----------------------------------------------------------------------===//
557 
558 /// Print a type registered to this dialect.
559 void FIRRTLDialect::printType(Type type, DialectAsmPrinter &os) const {
560  printNestedType(type, os);
561 }
562 
563 /// Parse a type registered to this dialect.
564 Type FIRRTLDialect::parseType(DialectAsmParser &parser) const {
565  StringRef name;
566  Type result;
567  if (parser.parseKeyword(&name) || ::parseType(result, name, parser))
568  return Type();
569  return result;
570 }
571 
572 //===----------------------------------------------------------------------===//
573 // Recursive Type Properties
574 //===----------------------------------------------------------------------===//
575 
// Bit flags describing recursive type properties.
576 enum {
577 /// Bit set if the type only contains passive elements.
579 /// Bit set if the type contains an analog type.
581 /// Bit set if the type has any uninferred bit widths.
583 };
584 
585 //===----------------------------------------------------------------------===//
586 // FIRRTLBaseType Implementation
587 //===----------------------------------------------------------------------===//
588 
// Uniqued storage read by `FIRRTLBaseType::isConst`; its key is just the
// `const` flag.
590 // Use `char` instead of `bool` since llvm already provides a
591 // DenseMapInfo<char> specialization.
592 using KeyTy = char;
593
// Record whether the type is `const`, stored as a `char` to match KeyTy.
594 FIRRTLBaseTypeStorage(bool isConst) : isConst(static_cast<char>(isConst)) {}
595
// Two storages are equal exactly when their `const` flags match.
596 bool operator==(const KeyTy &key) const { return key == isConst; }
597
// Reconstruct the uniquing key from the stored flag.
598 KeyTy getAsKey() const { return isConst; }
599
// Allocate the storage with the uniquer's allocator and construct it in place.
600 static FIRRTLBaseTypeStorage *construct(TypeStorageAllocator &allocator,
601 KeyTy key) {
602 return new (allocator.allocate<FIRRTLBaseTypeStorage>())
604 }
605
// Nonzero when the type is `const`.
606 char isConst;
607 };
608 
609 /// Return true if this is a 'ground' type, aka a non-aggregate type.
610 bool FIRRTLType::isGround() {
611  return TypeSwitch<FIRRTLType, bool>(*this)
612  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
613  AnalogType>([](Type) { return true; })
614  .Case<BundleType, FVectorType, FEnumType, OpenBundleType, OpenVectorType>(
615  [](Type) { return false; })
616  .Case<BaseTypeAliasType>([](BaseTypeAliasType alias) {
617  return alias.getAnonymousType().isGround();
618  })
619  // Not ground per spec, but leaf of aggregate.
620  .Case<PropertyType, RefType>([](Type) { return false; })
621  .Default([](Type) {
622  llvm_unreachable("unknown FIRRTL type");
623  return false;
624  });
625 }
626 
627 bool FIRRTLType::isConst() {
628  return TypeSwitch<FIRRTLType, bool>(*this)
629  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
630  [](auto type) { return type.isConst(); })
631  .Default(false);
632 }
633 
634 bool FIRRTLBaseType::isConst() { return getImpl()->isConst; }
635 
// Compute the recursive properties of this type by dispatching on its kind.
// NOTE(review): RecursiveTypeProperties is built positionally here, and its
// field meanings are defined outside this view. The per-field notes below are
// inferred and should be confirmed against its declaration.
637 return TypeSwitch<FIRRTLType, RecursiveTypeProperties>(*this)
// Clock and reset types: the last field is true only for the abstract
// `reset` type (presumably "has uninferred reset" -- confirm).
638 .Case<ClockType, ResetType, AsyncResetType>([](FIRRTLBaseType type) {
639 return RecursiveTypeProperties{true,
640 false,
641 false,
642 type.isConst(),
643 false,
644 false,
645 firrtl::type_isa<ResetType>(type)};
646 })
// Integers: the sixth field records a missing width (presumably "has
// uninferred widths" -- confirm).
647 .Case<SIntType, UIntType>([](auto type) {
649 true, false, false, type.isConst(), false, !type.hasWidth(), false};
650 })
// Analog: same as the integer case except the third field is set
// (presumably "contains analog" -- confirm).
651 .Case<AnalogType>([](auto type) {
653 true, false, true, type.isConst(), false, !type.hasWidth(), false};
654 })
// Aggregates, references and aliases compute their own properties.
655 .Case<BundleType, FVectorType, FEnumType, OpenBundleType, OpenVectorType,
656 RefType, BaseTypeAliasType>(
657 [](auto type) { return type.getRecursiveTypeProperties(); })
// Property types: only the first field is set.
658 .Case<PropertyType>([](auto type) {
659 return RecursiveTypeProperties{true, false, false, false,
660 false, false, false};
661 })
// Any other FIRRTLType is a programming error.
662 .Default([](Type) {
663 llvm_unreachable("unknown FIRRTL type");
664 return RecursiveTypeProperties{};
665 });
666 }
667 
668 /// Return this type with any type aliases recursively removed from itself.
///
/// Ground types are returned unchanged. Aggregates and aliases compute their
/// anonymous form themselves.
670 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
671 .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
672 AnalogType>([&](Type) { return *this; })
673 .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
674 [](auto type) { return type.getAnonymousType(); })
// Any other base type is a programming error.
675 .Default([](Type) {
676 llvm_unreachable("unknown FIRRTL type");
677 return FIRRTLBaseType();
678 });
679 }
680 
681 /// Return this type with any flip types recursively removed from itself.
683  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
684  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
685  AnalogType, FEnumType>([&](Type) { return *this; })
686  .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
687  [](auto type) { return type.getPassiveType(); })
688  .Default([](Type) {
689  llvm_unreachable("unknown FIRRTL type");
690  return FIRRTLBaseType();
691  });
692 }
693 
694 /// Return a 'const' or non-'const' version of this type.
///
/// Every concrete base type handled here implements `getConstType` itself,
/// so the requested `isConst` is forwarded unchanged.
696 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
697 .Case<ClockType, ResetType, AsyncResetType, AnalogType, SIntType,
698 UIntType, BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
699 [&](auto type) { return type.getConstType(isConst); })
700 .Default([](Type) {
701 llvm_unreachable("unknown FIRRTL type");
702 return FIRRTLBaseType();
703 });
704 }
705 
706 /// Return this type with all 'const' modifiers dropped, recursively.
// Ground types simply take their non-const form; aggregates and aliases
// strip `const` from their contents themselves.
708 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
709 .Case<ClockType, ResetType, AsyncResetType, AnalogType, SIntType,
710 UIntType>([&](auto type) { return type.getConstType(false); })
711 .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
712 [&](auto type) { return type.getAllConstDroppedType(); })
713 .Default([](Type) {
714 llvm_unreachable("unknown FIRRTL type");
715 return FIRRTLBaseType();
716 });
717 }
718 
719 /// Return this type with all ground types replaced with UInt<1>. This is
720 /// used for `mem` operations.
722  return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
723  .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
724  AnalogType>([&](Type) {
725  return UIntType::get(this->getContext(), 1, this->isConst());
726  })
727  .Case<BundleType>([&](BundleType bundleType) {
728  SmallVector<BundleType::BundleElement, 4> newElements;
729  newElements.reserve(bundleType.getElements().size());
730  for (auto elt : bundleType)
731  newElements.push_back(
732  {elt.name, false /* FIXME */, elt.type.getMaskType()});
733  return BundleType::get(this->getContext(), newElements,
734  bundleType.isConst());
735  })
736  .Case<FVectorType>([](FVectorType vectorType) {
737  return FVectorType::get(vectorType.getElementType().getMaskType(),
738  vectorType.getNumElements(),
739  vectorType.isConst());
740  })
741  .Case<BaseTypeAliasType>([](BaseTypeAliasType base) {
742  return base.getModifiedType(base.getInnerType().getMaskType());
743  })
744  .Default([](Type) {
745  llvm_unreachable("unknown FIRRTL type");
746  return FIRRTLBaseType();
747  });
748 }
749 
750 /// Remove the widths from this type. All widths are replaced with an
751 /// unknown width.
753 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
// Clock and reset types take no width parameter; return them as-is.
754 .Case<ClockType, ResetType, AsyncResetType>([](auto a) { return a; })
// -1 requests an unknown width; constness is preserved.
755 .Case<UIntType, SIntType, AnalogType>(
756 [&](auto a) { return a.get(this->getContext(), -1, a.isConst()); })
// Aggregates are rebuilt element-wise, keeping names, flips and constness.
757 .Case<BundleType>([&](auto a) {
758 SmallVector<BundleType::BundleElement, 4> newElements;
759 newElements.reserve(a.getElements().size());
760 for (auto elt : a)
761 newElements.push_back(
762 {elt.name, elt.isFlip, elt.type.getWidthlessType()});
763 return BundleType::get(this->getContext(), newElements, a.isConst());
764 })
765 .Case<FVectorType>([](auto a) {
766 return FVectorType::get(a.getElementType().getWidthlessType(),
767 a.getNumElements(), a.isConst());
768 })
769 .Case<FEnumType>([&](FEnumType a) {
770 SmallVector<FEnumType::EnumElement, 4> newElements;
771 newElements.reserve(a.getNumElements());
772 for (auto elt : a)
773 newElements.push_back({elt.name, elt.type.getWidthlessType()});
774 return FEnumType::get(this->getContext(), newElements, a.isConst());
775 })
// Presumably re-wraps the widthless inner type in this alias -- confirm
// getModifiedType's semantics.
776 .Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
777 return type.getModifiedType(type.getInnerType().getWidthlessType());
778 })
779 .Default([](auto) {
780 llvm_unreachable("unknown FIRRTL type");
781 return FIRRTLBaseType();
782 });
783 }
784 
785 /// If this is an IntType, AnalogType, or sugar type for a single bit (Clock,
786 /// Reset, AsyncReset) then return the bitwidth. Return -1 if this is one of
787 /// these types but without a specified bitwidth. Return -2 if this isn't a
788 /// simple type (bundles, vectors, enums).
790 return TypeSwitch<FIRRTLBaseType, int32_t>(*this)
791 .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
792 .Case<SIntType, UIntType>(
793 [&](IntType intType) { return intType.getWidthOrSentinel(); })
794 .Case<AnalogType>(
795 [](AnalogType analogType) { return analogType.getWidthOrSentinel(); })
796 .Case<BundleType, FVectorType, FEnumType>([](Type) { return -2; })
// Aliases report the value of their underlying anonymous type.
797 .Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
798 // It's faster to use its anonymous type.
799 return type.getAnonymousType().getBitWidthOrSentinel();
800 })
801 .Default([](Type) {
802 llvm_unreachable("unknown FIRRTL type");
803 return -2;
804 });
805 }
806 
807 /// Return true if this is a type usable as a reset. This must be
808 /// either an abstract reset, a concrete 1-bit UInt, an
809 /// asynchronous reset, or an uninferred width UInt.
811 return TypeSwitch<FIRRTLType, bool>(*this)
812 .Case<ResetType, AsyncResetType>([](Type) { return true; })
813 .Case<UIntType>(
814 [](UIntType a) { return !a.hasWidth() || a.getWidth() == 1; })
// Aliases defer to the type they wrap.
815 .Case<BaseTypeAliasType>(
816 [](auto type) { return type.getInnerType().isResetType(); })
// Every other type, including non-base types, is not a reset.
817 .Default([](Type) { return false; });
818 }
819 
820 bool firrtl::isConst(Type type) {
821  return TypeSwitch<Type, bool>(type)
822  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
823  [](auto base) { return base.isConst(); })
824  .Default(false);
825 }
826 
827 bool firrtl::containsConst(Type type) {
828  return TypeSwitch<Type, bool>(type)
829  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
830  [](auto base) { return base.containsConst(); })
831  .Default(false);
832 }
833 
834 // NOLINTBEGIN(misc-no-recursion)
/// Return true if the type has zero bit width or contains a zero-width part.
/// Bundles qualify if any field does or if they have no fields. Vectors
/// qualify if they are empty or their element type qualifies. References are
/// judged by the type they refer to.
// Bundles: any zero-width field, or no fields at all.
837 .Case<BundleType>([&](auto bundle) {
838 for (size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
839 auto elt = bundle.getElement(i);
840 if (hasZeroBitWidth(elt.type))
841 return true;
842 }
843 return bundle.getNumElements() == 0;
844 })
// Vectors: empty, or a zero-width element type.
845 .Case<FVectorType>([&](auto vector) {
846 if (vector.getNumElements() == 0)
847 return true;
848 return hasZeroBitWidth(vector.getElementType());
849 })
// Any other base type (including enums and aliases, which are not
// special-cased): a width of 0, or a width `getBitWidth` cannot compute,
// counts as zero because of `value_or(0)`.
850 .Case<FIRRTLBaseType>([](auto groundType) {
851 return firrtl::getBitWidth(groundType).value_or(0) == 0;
852 })
853 .Case<RefType>([](auto ref) { return hasZeroBitWidth(ref.getType()); })
// Any remaining type does not qualify.
854 .Default([](auto) { return false; });
855 }
856 // NOLINTEND(misc-no-recursion)
857 
858 /// Helper to implement the equivalence logic for a pair of bundle elements.
859 /// Note that the FIRRTL spec requires bundle elements to have the same
860 /// orientation, but this only compares their passive types. The FIRRTL dialect
861 /// differs from the spec in how it uses flip types for module output ports and
862 /// canonicalizes flips in bundles, so only passive types can be compared here.
863 static bool areBundleElementsEquivalent(BundleType::BundleElement destElement,
864  BundleType::BundleElement srcElement,
865  bool destOuterTypeIsConst,
866  bool srcOuterTypeIsConst,
867  bool requiresSameWidth) {
868  if (destElement.name != srcElement.name)
869  return false;
870  if (destElement.isFlip != srcElement.isFlip)
871  return false;
872 
873  if (destElement.isFlip) {
874  std::swap(destElement, srcElement);
875  std::swap(destOuterTypeIsConst, srcOuterTypeIsConst);
876  }
877 
878  return areTypesEquivalent(destElement.type, srcElement.type,
879  destOuterTypeIsConst, srcOuterTypeIsConst,
880  requiresSameWidth);
881 }
882 
883 /// Returns whether the two types are equivalent. This implements the exact
884 /// definition of type equivalence in the FIRRTL spec. If the types being
885 /// compared have any outer flips that encode FIRRTL module directions (input or
886 /// output), these should be stripped before using this method.
888  bool destOuterTypeIsConst,
889  bool srcOuterTypeIsConst,
890  bool requireSameWidths) {
891  auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
892  auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
893 
894  // For non-base types, only equivalent if identical.
895  if (!destType || !srcType)
896  return destFType == srcFType;
897 
898  bool srcIsConst = srcOuterTypeIsConst || srcFType.isConst();
899  bool destIsConst = destOuterTypeIsConst || destFType.isConst();
900 
901  // Vector types can be connected if they have the same size and element type.
902  auto destVectorType = type_dyn_cast<FVectorType>(destType);
903  auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
904  if (destVectorType && srcVectorType)
905  return destVectorType.getNumElements() == srcVectorType.getNumElements() &&
906  areTypesEquivalent(destVectorType.getElementType(),
907  srcVectorType.getElementType(), destIsConst,
908  srcIsConst, requireSameWidths);
909 
910  // Bundle types can be connected if they have the same size, element names,
911  // and element types.
912  auto destBundleType = type_dyn_cast<BundleType>(destType);
913  auto srcBundleType = type_dyn_cast<BundleType>(srcType);
914  if (destBundleType && srcBundleType) {
915  auto destElements = destBundleType.getElements();
916  auto srcElements = srcBundleType.getElements();
917  size_t numDestElements = destElements.size();
918  if (numDestElements != srcElements.size())
919  return false;
920 
921  for (size_t i = 0; i < numDestElements; ++i) {
922  auto destElement = destElements[i];
923  auto srcElement = srcElements[i];
924  if (!areBundleElementsEquivalent(destElement, srcElement, destIsConst,
925  srcIsConst, requireSameWidths))
926  return false;
927  }
928  return true;
929  }
930 
931  // Enum types can be connected if they have the same size, element names, and
932  // element types.
933  auto dstEnumType = type_dyn_cast<FEnumType>(destType);
934  auto srcEnumType = type_dyn_cast<FEnumType>(srcType);
935 
936  if (dstEnumType && srcEnumType) {
937  if (dstEnumType.getNumElements() != srcEnumType.getNumElements())
938  return false;
939  // Enums requires the types to match exactly.
940  for (const auto &[dst, src] : llvm::zip(dstEnumType, srcEnumType)) {
941  // The variant names must match.
942  if (dst.name != src.name)
943  return false;
944  // Enumeration types can only be connected if the inner types have the
945  // same width.
946  if (!areTypesEquivalent(dst.type, src.type, destIsConst, srcIsConst,
947  true))
948  return false;
949  }
950  return true;
951  }
952 
953  // Ground type connections must be const compatible.
954  if (destIsConst && !srcIsConst)
955  return false;
956 
957  // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
958  if (firrtl::type_isa<ResetType>(destType))
959  return srcType.isResetType();
960 
961  // Reset types can drive UInt<1>, AsyncReset, or Reset types.
962  if (firrtl::type_isa<ResetType>(srcType))
963  return destType.isResetType();
964 
965  // If we can implicitly truncate or extend the bitwidth, or either width is
966  // currently uninferred, then compare the widthless version of these types.
967  if (!requireSameWidths || destType.getBitWidthOrSentinel() == -1)
968  srcType = srcType.getWidthlessType();
969  if (!requireSameWidths || srcType.getBitWidthOrSentinel() == -1)
970  destType = destType.getWidthlessType();
971 
972  // Ground types can be connected if their constless types are the same
973  return destType.getConstType(false) == srcType.getConstType(false);
974 }
975 
976 /// Returns whether the two types are weakly equivalent.
978  bool destFlip, bool srcFlip,
979  bool destOuterTypeIsConst,
980  bool srcOuterTypeIsConst) {
981  auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
982  auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
983 
984  // For non-base types, only equivalent if identical.
985  if (!destType || !srcType)
986  return destFType == srcFType;
987 
988  bool srcIsConst = srcOuterTypeIsConst || srcFType.isConst();
989  bool destIsConst = destOuterTypeIsConst || destFType.isConst();
990 
991  // Vector types can be connected if their element types are weakly equivalent.
992  // Size doesn't matter.
993  auto destVectorType = type_dyn_cast<FVectorType>(destType);
994  auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
995  if (destVectorType && srcVectorType)
996  return areTypesWeaklyEquivalent(destVectorType.getElementType(),
997  srcVectorType.getElementType(), destFlip,
998  srcFlip, destIsConst, srcIsConst);
999 
1000  // Bundle types are weakly equivalent if all common elements are weakly
1001  // equivalent. Non-matching fields are ignored. Flips are "pushed" into
1002  // recursive weak type equivalence checks.
1003  auto destBundleType = type_dyn_cast<BundleType>(destType);
1004  auto srcBundleType = type_dyn_cast<BundleType>(srcType);
1005  if (destBundleType && srcBundleType)
1006  return llvm::all_of(destBundleType, [&](auto destElt) -> bool {
1007  auto destField = destElt.name.getValue();
1008  auto srcElt = srcBundleType.getElement(destField);
1009  // If the src doesn't contain the destination's field, that's okay.
1010  if (!srcElt)
1011  return true;
1012 
1013  return areTypesWeaklyEquivalent(
1014  destElt.type, srcElt->type, destFlip ^ destElt.isFlip,
1015  srcFlip ^ srcElt->isFlip, destOuterTypeIsConst, srcOuterTypeIsConst);
1016  });
1017 
1018  // Ground types require leaf flippedness and const compatibility
1019  if (destFlip != srcFlip)
1020  return false;
1021  if (destFlip && srcIsConst && !destIsConst)
1022  return false;
1023  if (srcFlip && destIsConst && !srcIsConst)
1024  return false;
1025 
1026  // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
1027  if (type_isa<ResetType>(destType))
1028  return srcType.isResetType();
1029 
1030  // Reset types can drive UInt<1>, AsyncReset, or Reset types.
1031  if (type_isa<ResetType>(srcType))
1032  return destType.isResetType();
1033 
1034  // Ground types can be connected if their passive, widthless versions
1035  // are equal and are const and flip compatible
1036  auto widthlessDestType = destType.getWidthlessType();
1037  auto widthlessSrcType = srcType.getWidthlessType();
1038  return widthlessDestType.getConstType(false) ==
1039  widthlessSrcType.getConstType(false);
1040 }
1041 
1042 /// Returns whether the srcType can be const-casted to the destType.
1044  bool srcOuterTypeIsConst) {
1045  // Identical types are always castable
1046  if (destFType == srcFType)
1047  return true;
1048 
1049  auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
1050  auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
1051 
1052  // For non-base types, only castable if identical.
1053  if (!destType || !srcType)
1054  return false;
1055 
1056  // Types must be passive
1057  if (!destType.isPassive() || !srcType.isPassive())
1058  return false;
1059 
1060  bool srcIsConst = srcType.isConst() || srcOuterTypeIsConst;
1061 
1062  // Cannot cast non-'const' src to 'const' dest
1063  if (destType.isConst() && !srcIsConst)
1064  return false;
1065 
1066  // Vector types can be casted if they have the same size and castable element
1067  // type.
1068  auto destVectorType = type_dyn_cast<FVectorType>(destType);
1069  auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
1070  if (destVectorType && srcVectorType)
1071  return destVectorType.getNumElements() == srcVectorType.getNumElements() &&
1072  areTypesConstCastable(destVectorType.getElementType(),
1073  srcVectorType.getElementType(), srcIsConst);
1074  if (destVectorType != srcVectorType)
1075  return false;
1076 
1077  // Bundle types can be casted if they have the same size, element names,
1078  // and castable element types.
1079  auto destBundleType = type_dyn_cast<BundleType>(destType);
1080  auto srcBundleType = type_dyn_cast<BundleType>(srcType);
1081  if (destBundleType && srcBundleType) {
1082  auto destElements = destBundleType.getElements();
1083  auto srcElements = srcBundleType.getElements();
1084  size_t numDestElements = destElements.size();
1085  if (numDestElements != srcElements.size())
1086  return false;
1087 
1088  return llvm::all_of_zip(
1089  destElements, srcElements,
1090  [&](const auto &destElement, const auto &srcElement) {
1091  return destElement.name == srcElement.name &&
1092  areTypesConstCastable(destElement.type, srcElement.type,
1093  srcIsConst);
1094  });
1095  }
1096  if (destBundleType != srcBundleType)
1097  return false;
1098 
1099  // Ground types can be casted if the source type is a const
1100  // version of the destination type
1101  return destType == srcType.getConstType(destType.isConst());
1102 }
1103 
1104 bool firrtl::areTypesRefCastable(Type dstType, Type srcType) {
1105  auto dstRefType = type_dyn_cast<RefType>(dstType);
1106  auto srcRefType = type_dyn_cast<RefType>(srcType);
1107  if (!dstRefType || !srcRefType)
1108  return false;
1109  if (dstRefType == srcRefType)
1110  return true;
1111  if (dstRefType.getForceable() && !srcRefType.getForceable())
1112  return false;
1113 
1114  // Okay walk the types recursively. They must be identical "structurally"
1115  // with exception leaf (ground) types of destination can be uninferred
1116  // versions of the corresponding source type. (can lose width information or
1117  // become a more general reset type)
1118  // In addition, while not explicitly in spec its useful to allow probes
1119  // to have const cast away, especially for probes of literals and expressions
1120  // derived from them. Check const as with const cast.
1121  // NOLINTBEGIN(misc-no-recursion)
1122  auto recurse = [&](auto &&f, FIRRTLBaseType dest, FIRRTLBaseType src,
1123  bool srcOuterTypeIsConst) -> bool {
1124  // Fast-path for identical types.
1125  if (dest == src)
1126  return true;
1127 
1128  // Always passive inside probes, but for sanity assert this.
1129  assert(dest.isPassive() && src.isPassive());
1130 
1131  bool srcIsConst = src.isConst() || srcOuterTypeIsConst;
1132 
1133  // Cannot cast non-'const' src to 'const' dest
1134  if (dest.isConst() && !srcIsConst)
1135  return false;
1136 
1137  // Recurse through aggregates to get the leaves, checking
1138  // structural equivalence re:element count + names.
1139 
1140  if (auto destVectorType = type_dyn_cast<FVectorType>(dest)) {
1141  auto srcVectorType = type_dyn_cast<FVectorType>(src);
1142  return srcVectorType &&
1143  destVectorType.getNumElements() ==
1144  srcVectorType.getNumElements() &&
1145  f(f, destVectorType.getElementType(),
1146  srcVectorType.getElementType(), srcIsConst);
1147  }
1148 
1149  if (auto destBundleType = type_dyn_cast<BundleType>(dest)) {
1150  auto srcBundleType = type_dyn_cast<BundleType>(src);
1151  if (!srcBundleType)
1152  return false;
1153  // (no need to check orientation, these are always passive)
1154  auto destElements = destBundleType.getElements();
1155  auto srcElements = srcBundleType.getElements();
1156 
1157  return destElements.size() == srcElements.size() &&
1158  llvm::all_of_zip(
1159  destElements, srcElements,
1160  [&](const auto &destElement, const auto &srcElement) {
1161  return destElement.name == srcElement.name &&
1162  f(f, destElement.type, srcElement.type, srcIsConst);
1163  });
1164  }
1165 
1166  if (auto destEnumType = type_dyn_cast<FEnumType>(dest)) {
1167  auto srcEnumType = type_dyn_cast<FEnumType>(src);
1168  if (!srcEnumType)
1169  return false;
1170  auto destElements = destEnumType.getElements();
1171  auto srcElements = srcEnumType.getElements();
1172 
1173  return destElements.size() == srcElements.size() &&
1174  llvm::all_of_zip(
1175  destElements, srcElements,
1176  [&](const auto &destElement, const auto &srcElement) {
1177  return destElement.name == srcElement.name &&
1178  f(f, destElement.type, srcElement.type, srcIsConst);
1179  });
1180  }
1181 
1182  // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
1183  if (type_isa<ResetType>(dest))
1184  return src.isResetType();
1185  // (but don't allow the other direction, can only become more general)
1186 
1187  // Compare against const src if dest is const.
1188  src = src.getConstType(dest.isConst());
1189 
1190  // Compare against widthless src if dest is widthless.
1191  if (dest.getBitWidthOrSentinel() == -1)
1192  src = src.getWidthlessType();
1193 
1194  return dest == src;
1195  };
1196 
1197  return recurse(recurse, dstRefType.getType(), srcRefType.getType(), false);
1198  // NOLINTEND(misc-no-recursion)
1199 }
1200 
1201 // NOLINTBEGIN(misc-no-recursion)
1202 /// Returns true if the destination is at least as wide as an equivalent source.
1204  return TypeSwitch<FIRRTLBaseType, bool>(dstType)
1205  .Case<BundleType>([&](auto dstBundle) {
1206  auto srcBundle = type_cast<BundleType>(srcType);
1207  for (size_t i = 0, n = dstBundle.getNumElements(); i < n; ++i) {
1208  auto srcElem = srcBundle.getElement(i);
1209  auto dstElem = dstBundle.getElement(i);
1210  if (dstElem.isFlip) {
1211  if (!isTypeLarger(srcElem.type, dstElem.type))
1212  return false;
1213  } else {
1214  if (!isTypeLarger(dstElem.type, srcElem.type))
1215  return false;
1216  }
1217  }
1218  return true;
1219  })
1220  .Case<FVectorType>([&](auto vector) {
1221  return isTypeLarger(vector.getElementType(),
1222  type_cast<FVectorType>(srcType).getElementType());
1223  })
1224  .Default([&](auto dstGround) {
1225  int32_t destWidth = dstType.getPassiveType().getBitWidthOrSentinel();
1226  int32_t srcWidth = srcType.getPassiveType().getBitWidthOrSentinel();
1227  return destWidth <= -1 || srcWidth <= -1 || destWidth >= srcWidth;
1228  });
1229 }
1230 // NOLINTEND(misc-no-recursion)
1231 
1233  FIRRTLBaseType rhs) {
1234  return lhs.getAnonymousType() == rhs.getAnonymousType();
1235 }
1236 
1237 bool firrtl::areAnonymousTypesEquivalent(mlir::Type lhs, mlir::Type rhs) {
1238  if (auto destBaseType = type_dyn_cast<FIRRTLBaseType>(lhs))
1239  if (auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(rhs))
1240  return areAnonymousTypesEquivalent(destBaseType, srcBaseType);
1241 
1242  if (auto destRefType = type_dyn_cast<RefType>(lhs))
1243  if (auto srcRefType = type_dyn_cast<RefType>(rhs))
1244  return areAnonymousTypesEquivalent(destRefType.getType(),
1245  srcRefType.getType());
1246 
1247  return lhs == rhs;
1248 }
1249 
1250 /// Return the passive version of a firrtl type
1251 /// top level for ODS constraint usage
1252 Type firrtl::getPassiveType(Type anyBaseFIRRTLType) {
1253  return type_cast<FIRRTLBaseType>(anyBaseFIRRTLType).getPassiveType();
1254 }
1255 
1256 bool firrtl::isTypeInOut(Type type) {
1257  return llvm::TypeSwitch<Type, bool>(type)
1258  .Case<FIRRTLBaseType>([](auto type) {
1259  return !type.containsReference() &&
1260  (!type.isPassive() || type.containsAnalog());
1261  })
1262  .Default(false);
1263 }
1264 
1265 //===----------------------------------------------------------------------===//
1266 // IntType
1267 //===----------------------------------------------------------------------===//
1268 
1269 /// Return a SIntType or UIntType with the specified signedness, width, and
1270 /// constness
1271 IntType IntType::get(MLIRContext *context, bool isSigned,
1272  int32_t widthOrSentinel, bool isConst) {
1273  if (isSigned)
1274  return SIntType::get(context, widthOrSentinel, isConst);
1275  return UIntType::get(context, widthOrSentinel, isConst);
1276 }
1277 
1279  if (auto sintType = type_dyn_cast<SIntType>(*this))
1280  return sintType.getWidthOrSentinel();
1281  if (auto uintType = type_dyn_cast<UIntType>(*this))
1282  return uintType.getWidthOrSentinel();
1283  return -1;
1284 }
1285 
1286 //===----------------------------------------------------------------------===//
1287 // WidthTypeStorage
1288 //===----------------------------------------------------------------------===//
1289 
1293  using KeyTy = std::pair<int32_t, char>;
1294 
1295  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1296 
1297  KeyTy getAsKey() const { return KeyTy(width, isConst); }
1298 
1299  static WidthTypeStorage *construct(TypeStorageAllocator &allocator,
1300  const KeyTy &key) {
1301  return new (allocator.allocate<WidthTypeStorage>())
1302  WidthTypeStorage(key.first, key.second);
1303  }
1304 
1305  int32_t width;
1306 };
1307 
1309 
1310  if (auto sIntType = type_dyn_cast<SIntType>(*this))
1311  return sIntType.getConstType(isConst);
1312  return type_cast<UIntType>(*this).getConstType(isConst);
1313 }
1314 
1315 //===----------------------------------------------------------------------===//
1316 // SIntType
1317 //===----------------------------------------------------------------------===//
1318 
1319 SIntType SIntType::get(MLIRContext *context) { return get(context, -1, false); }
1320 
1321 SIntType SIntType::get(MLIRContext *context, std::optional<int32_t> width,
1322  bool isConst) {
1323  return get(context, width ? *width : -1, isConst);
1324 }
1325 
1326 LogicalResult SIntType::verify(function_ref<InFlightDiagnostic()> emitError,
1327  int32_t widthOrSentinel, bool isConst) {
1328  if (widthOrSentinel < -1)
1329  return emitError() << "invalid width";
1330  return success();
1331 }
1332 
1333 int32_t SIntType::getWidthOrSentinel() const { return getImpl()->width; }
1334 
1335 SIntType SIntType::getConstType(bool isConst) {
1336  if (isConst == this->isConst())
1337  return *this;
1338  return get(getContext(), getWidthOrSentinel(), isConst);
1339 }
1340 
1341 //===----------------------------------------------------------------------===//
1342 // UIntType
1343 //===----------------------------------------------------------------------===//
1344 
1345 UIntType UIntType::get(MLIRContext *context) { return get(context, -1, false); }
1346 
1347 UIntType UIntType::get(MLIRContext *context, std::optional<int32_t> width,
1348  bool isConst) {
1349  return get(context, width ? *width : -1, isConst);
1350 }
1351 
1352 LogicalResult UIntType::verify(function_ref<InFlightDiagnostic()> emitError,
1353  int32_t widthOrSentinel, bool isConst) {
1354  if (widthOrSentinel < -1)
1355  return emitError() << "invalid width";
1356  return success();
1357 }
1358 
1359 int32_t UIntType::getWidthOrSentinel() const { return getImpl()->width; }
1360 
1361 UIntType UIntType::getConstType(bool isConst) {
1362  if (isConst == this->isConst())
1363  return *this;
1364  return get(getContext(), getWidthOrSentinel(), isConst);
1365 }
1366 
1367 //===----------------------------------------------------------------------===//
1368 // Bundle Type
1369 //===----------------------------------------------------------------------===//
1370 
1373  using KeyTy = std::pair<ArrayRef<BundleType::BundleElement>, char>;
1374 
1375  BundleTypeStorage(ArrayRef<BundleType::BundleElement> elements, bool isConst)
1377  elements(elements.begin(), elements.end()), props{true, false, false,
1378  isConst, false, false,
1379  false} {
1380  uint64_t fieldID = 0;
1381  fieldIDs.reserve(elements.size());
1382  for (auto &element : elements) {
1383  auto type = element.type;
1384  auto eltInfo = type.getRecursiveTypeProperties();
1385  props.isPassive &= eltInfo.isPassive & !element.isFlip;
1386  props.containsAnalog |= eltInfo.containsAnalog;
1387  props.containsReference |= eltInfo.containsReference;
1388  props.containsConst |= eltInfo.containsConst;
1389  props.containsTypeAlias |= eltInfo.containsTypeAlias;
1390  props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
1391  props.hasUninferredReset |= eltInfo.hasUninferredReset;
1392  fieldID += 1;
1393  fieldIDs.push_back(fieldID);
1394  // Increment the field ID for the next field by the number of subfields.
1395  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
1396  }
1397  maxFieldID = fieldID;
1398  }
1399 
1400  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1401 
1402  KeyTy getAsKey() const { return KeyTy(elements, isConst); }
1403 
1404  static llvm::hash_code hashKey(const KeyTy &key) {
1405  return llvm::hash_combine(
1406  llvm::hash_combine_range(key.first.begin(), key.first.end()),
1407  key.second);
1408  }
1409 
1410  static BundleTypeStorage *construct(TypeStorageAllocator &allocator,
1411  KeyTy key) {
1412  return new (allocator.allocate<BundleTypeStorage>())
1413  BundleTypeStorage(key.first, static_cast<bool>(key.second));
1414  }
1415 
1416  SmallVector<BundleType::BundleElement, 4> elements;
1417  SmallVector<uint64_t, 4> fieldIDs;
1418  uint64_t maxFieldID;
1419 
1420  /// This holds the bits for the type's recursive properties, and can hold a
1421  /// pointer to a passive version of the type.
1423  BundleType passiveType;
1424  BundleType anonymousType;
1425 };
1426 
1427 BundleType BundleType::get(MLIRContext *context,
1428  ArrayRef<BundleElement> elements, bool isConst) {
1429  return Base::get(context, elements, isConst);
1430 }
1431 
1432 auto BundleType::getElements() const -> ArrayRef<BundleElement> {
1433  return getImpl()->elements;
1434 }
1435 
1436 /// Return a pair with the 'isPassive' and 'containsAnalog' bits.
1437 RecursiveTypeProperties BundleType::getRecursiveTypeProperties() const {
1438  return getImpl()->props;
1439 }
1440 
1441 /// Return this type with any flip types recursively removed from itself.
1443  auto *impl = getImpl();
1444 
1445  // If we've already determined and cached the passive type, use it.
1446  if (impl->passiveType)
1447  return impl->passiveType;
1448 
1449  // If this type is already passive, use it and remember for next time.
1450  if (impl->props.isPassive) {
1451  impl->passiveType = *this;
1452  return *this;
1453  }
1454 
1455  // Otherwise at least one element is non-passive, rebuild a passive version.
1456  SmallVector<BundleType::BundleElement, 16> newElements;
1457  newElements.reserve(impl->elements.size());
1458  for (auto &elt : impl->elements) {
1459  newElements.push_back({elt.name, false, elt.type.getPassiveType()});
1460  }
1461 
1462  auto passiveType = BundleType::get(getContext(), newElements, isConst());
1463  impl->passiveType = passiveType;
1464  return passiveType;
1465 }
1466 
1467 BundleType BundleType::getConstType(bool isConst) {
1468  if (isConst == this->isConst())
1469  return *this;
1470  return get(getContext(), getElements(), isConst);
1471 }
1472 
1473 BundleType BundleType::getAllConstDroppedType() {
1474  if (!containsConst())
1475  return *this;
1476 
1477  SmallVector<BundleElement> constDroppedElements(
1478  llvm::map_range(getElements(), [](BundleElement element) {
1479  element.type = element.type.getAllConstDroppedType();
1480  return element;
1481  }));
1482  return get(getContext(), constDroppedElements, false);
1483 }
1484 
1485 std::optional<unsigned> BundleType::getElementIndex(StringAttr name) {
1486  for (const auto &it : llvm::enumerate(getElements())) {
1487  auto element = it.value();
1488  if (element.name == name) {
1489  return unsigned(it.index());
1490  }
1491  }
1492  return std::nullopt;
1493 }
1494 
1495 std::optional<unsigned> BundleType::getElementIndex(StringRef name) {
1496  for (const auto &it : llvm::enumerate(getElements())) {
1497  auto element = it.value();
1498  if (element.name.getValue() == name) {
1499  return unsigned(it.index());
1500  }
1501  }
1502  return std::nullopt;
1503 }
1504 
1505 StringAttr BundleType::getElementNameAttr(size_t index) {
1506  assert(index < getNumElements() &&
1507  "index must be less than number of fields in bundle");
1508  return getElements()[index].name;
1509 }
1510 
1511 StringRef BundleType::getElementName(size_t index) {
1512  return getElementNameAttr(index).getValue();
1513 }
1514 
1515 std::optional<BundleType::BundleElement>
1516 BundleType::getElement(StringAttr name) {
1517  if (auto maybeIndex = getElementIndex(name))
1518  return getElements()[*maybeIndex];
1519  return std::nullopt;
1520 }
1521 
1522 std::optional<BundleType::BundleElement>
1523 BundleType::getElement(StringRef name) {
1524  if (auto maybeIndex = getElementIndex(name))
1525  return getElements()[*maybeIndex];
1526  return std::nullopt;
1527 }
1528 
1529 /// Look up an element by index.
1530 BundleType::BundleElement BundleType::getElement(size_t index) {
1531  assert(index < getNumElements() &&
1532  "index must be less than number of fields in bundle");
1533  return getElements()[index];
1534 }
1535 
1536 FIRRTLBaseType BundleType::getElementType(StringAttr name) {
1537  auto element = getElement(name);
1538  return element ? element->type : FIRRTLBaseType();
1539 }
1540 
1541 FIRRTLBaseType BundleType::getElementType(StringRef name) {
1542  auto element = getElement(name);
1543  return element ? element->type : FIRRTLBaseType();
1544 }
1545 
1546 FIRRTLBaseType BundleType::getElementType(size_t index) const {
1547  assert(index < getNumElements() &&
1548  "index must be less than number of fields in bundle");
1549  return getElements()[index].type;
1550 }
1551 
1552 uint64_t BundleType::getFieldID(uint64_t index) const {
1553  return getImpl()->fieldIDs[index];
1554 }
1555 
1556 uint64_t BundleType::getIndexForFieldID(uint64_t fieldID) const {
1557  assert(!getElements().empty() && "Bundle must have >0 fields");
1558  auto fieldIDs = getImpl()->fieldIDs;
1559  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
1560  return std::distance(fieldIDs.begin(), it);
1561 }
1562 
1563 std::pair<uint64_t, uint64_t>
1564 BundleType::getIndexAndSubfieldID(uint64_t fieldID) const {
1565  auto index = getIndexForFieldID(fieldID);
1566  auto elementFieldID = getFieldID(index);
1567  return {index, fieldID - elementFieldID};
1568 }
1569 
1570 std::pair<Type, uint64_t>
1571 BundleType::getSubTypeByFieldID(uint64_t fieldID) const {
1572  if (fieldID == 0)
1573  return {*this, 0};
1574  auto subfieldIndex = getIndexForFieldID(fieldID);
1575  auto subfieldType = getElementType(subfieldIndex);
1576  auto subfieldID = fieldID - getFieldID(subfieldIndex);
1577  return {subfieldType, subfieldID};
1578 }
1579 
1580 uint64_t BundleType::getMaxFieldID() const { return getImpl()->maxFieldID; }
1581 
1582 std::pair<uint64_t, bool>
1583 BundleType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1584  auto childRoot = getFieldID(index);
1585  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
1586  : (getFieldID(index + 1) - 1);
1587  return std::make_pair(fieldID - childRoot,
1588  fieldID >= childRoot && fieldID <= rangeEnd);
1589 }
1590 
1591 bool BundleType::isConst() { return getImpl()->isConst; }
1592 
1593 BundleType::ElementType
1594 BundleType::getElementTypePreservingConst(size_t index) {
1595  auto type = getElementType(index);
1596  return type.getConstType(type.isConst() || isConst());
1597 }
1598 
1599 /// Return this type with any type aliases recursively removed from itself.
1600 FIRRTLBaseType BundleType::getAnonymousType() {
1601  auto *impl = getImpl();
1602 
1603  // If we've already determined and cached the anonymous type, use it.
1604  if (impl->anonymousType)
1605  return impl->anonymousType;
1606 
1607  // If this type is already anonymous, use it and remember for next time.
1608  if (!impl->props.containsTypeAlias) {
1609  impl->anonymousType = *this;
1610  return *this;
1611  }
1612 
1613  // Otherwise at least one element has an alias type, rebuild an anonymous
1614  // version.
1615  SmallVector<BundleType::BundleElement, 16> newElements;
1616  newElements.reserve(impl->elements.size());
1617  for (auto &elt : impl->elements)
1618  newElements.push_back({elt.name, elt.isFlip, elt.type.getAnonymousType()});
1619 
1620  auto anonymousType = BundleType::get(getContext(), newElements, isConst());
1621  impl->anonymousType = anonymousType;
1622  return anonymousType;
1623 }
1624 
1625 //===----------------------------------------------------------------------===//
1626 // OpenBundle Type
1627 //===----------------------------------------------------------------------===//
1628 
1630  using KeyTy = std::pair<ArrayRef<OpenBundleType::BundleElement>, char>;
1631 
1632  OpenBundleTypeStorage(ArrayRef<OpenBundleType::BundleElement> elements,
1633  bool isConst)
1634  : elements(elements.begin(), elements.end()), props{true, false, false,
1635  isConst, false, false,
1636  false},
1637  isConst(static_cast<char>(isConst)) {
1638  uint64_t fieldID = 0;
1639  fieldIDs.reserve(elements.size());
1640  for (auto &element : elements) {
1641  auto type = element.type;
1642  auto eltInfo = type.getRecursiveTypeProperties();
1643  props.isPassive &= eltInfo.isPassive & !element.isFlip;
1644  props.containsAnalog |= eltInfo.containsAnalog;
1645  props.containsReference |= eltInfo.containsReference;
1646  props.containsConst |= eltInfo.containsConst;
1647  props.containsTypeAlias |= eltInfo.containsTypeAlias;
1648  props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
1649  props.hasUninferredReset |= eltInfo.hasUninferredReset;
1650  fieldID += 1;
1651  fieldIDs.push_back(fieldID);
1652  // Increment the field ID for the next field by the number of subfields.
1653  // TODO: Maybe just have elementType be FieldIDTypeInterface ?
1654  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
1655  }
1656  maxFieldID = fieldID;
1657  }
1658 
1659  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1660 
1661  static llvm::hash_code hashKey(const KeyTy &key) {
1662  return llvm::hash_combine(
1663  llvm::hash_combine_range(key.first.begin(), key.first.end()),
1664  key.second);
1665  }
1666 
1667  KeyTy getAsKey() const { return KeyTy(elements, isConst); }
1668 
1669  static OpenBundleTypeStorage *construct(TypeStorageAllocator &allocator,
1670  KeyTy key) {
1671  return new (allocator.allocate<OpenBundleTypeStorage>())
1672  OpenBundleTypeStorage(key.first, static_cast<bool>(key.second));
1673  }
1674 
1675  SmallVector<OpenBundleType::BundleElement, 4> elements;
1676  SmallVector<uint64_t, 4> fieldIDs;
1677  uint64_t maxFieldID;
1678 
1679  /// This holds the bits for the type's recursive properties, and can hold a
1680  /// pointer to a passive version of the type.
1682 
1683  // Whether this is 'const'.
1684  char isConst;
1685 };
1686 
1687 OpenBundleType OpenBundleType::get(MLIRContext *context,
1688  ArrayRef<BundleElement> elements,
1689  bool isConst) {
1690  return Base::get(context, elements, isConst);
1691 }
1692 
1693 auto OpenBundleType::getElements() const -> ArrayRef<BundleElement> {
1694  return getImpl()->elements;
1695 }
1696 
1697 /// Return a pair with the 'isPassive' and 'containsAnalog' bits.
1698 RecursiveTypeProperties OpenBundleType::getRecursiveTypeProperties() const {
1699  return getImpl()->props;
1700 }
1701 
1702 OpenBundleType OpenBundleType::getConstType(bool isConst) {
1703  if (isConst == this->isConst())
1704  return *this;
1705  return get(getContext(), getElements(), isConst);
1706 }
1707 
1708 std::optional<unsigned> OpenBundleType::getElementIndex(StringAttr name) {
1709  for (const auto &it : llvm::enumerate(getElements())) {
1710  auto element = it.value();
1711  if (element.name == name) {
1712  return unsigned(it.index());
1713  }
1714  }
1715  return std::nullopt;
1716 }
1717 
1718 std::optional<unsigned> OpenBundleType::getElementIndex(StringRef name) {
1719  for (const auto &it : llvm::enumerate(getElements())) {
1720  auto element = it.value();
1721  if (element.name.getValue() == name) {
1722  return unsigned(it.index());
1723  }
1724  }
1725  return std::nullopt;
1726 }
1727 
1728 StringAttr OpenBundleType::getElementNameAttr(size_t index) {
1729  assert(index < getNumElements() &&
1730  "index must be less than number of fields in bundle");
1731  return getElements()[index].name;
1732 }
1733 
1734 StringRef OpenBundleType::getElementName(size_t index) {
1735  return getElementNameAttr(index).getValue();
1736 }
1737 
1738 std::optional<OpenBundleType::BundleElement>
1739 OpenBundleType::getElement(StringAttr name) {
1740  if (auto maybeIndex = getElementIndex(name))
1741  return getElements()[*maybeIndex];
1742  return std::nullopt;
1743 }
1744 
1745 std::optional<OpenBundleType::BundleElement>
1746 OpenBundleType::getElement(StringRef name) {
1747  if (auto maybeIndex = getElementIndex(name))
1748  return getElements()[*maybeIndex];
1749  return std::nullopt;
1750 }
1751 
1752 /// Look up an element by index.
1753 OpenBundleType::BundleElement OpenBundleType::getElement(size_t index) {
1754  assert(index < getNumElements() &&
1755  "index must be less than number of fields in bundle");
1756  return getElements()[index];
1757 }
1758 
1759 OpenBundleType::ElementType OpenBundleType::getElementType(StringAttr name) {
1760  auto element = getElement(name);
1761  return element ? element->type : FIRRTLBaseType();
1762 }
1763 
1764 OpenBundleType::ElementType OpenBundleType::getElementType(StringRef name) {
1765  auto element = getElement(name);
1766  return element ? element->type : FIRRTLBaseType();
1767 }
1768 
1769 OpenBundleType::ElementType OpenBundleType::getElementType(size_t index) const {
1770  assert(index < getNumElements() &&
1771  "index must be less than number of fields in bundle");
1772  return getElements()[index].type;
1773 }
1774 
1775 uint64_t OpenBundleType::getFieldID(uint64_t index) const {
1776  return getImpl()->fieldIDs[index];
1777 }
1778 
1779 uint64_t OpenBundleType::getIndexForFieldID(uint64_t fieldID) const {
1780  assert(!getElements().empty() && "Bundle must have >0 fields");
1781  auto fieldIDs = getImpl()->fieldIDs;
1782  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
1783  return std::distance(fieldIDs.begin(), it);
1784 }
1785 
1786 std::pair<uint64_t, uint64_t>
1787 OpenBundleType::getIndexAndSubfieldID(uint64_t fieldID) const {
1788  auto index = getIndexForFieldID(fieldID);
1789  auto elementFieldID = getFieldID(index);
1790  return {index, fieldID - elementFieldID};
1791 }
1792 
1793 std::pair<Type, uint64_t>
1794 OpenBundleType::getSubTypeByFieldID(uint64_t fieldID) const {
1795  if (fieldID == 0)
1796  return {*this, 0};
1797  auto subfieldIndex = getIndexForFieldID(fieldID);
1798  auto subfieldType = getElementType(subfieldIndex);
1799  auto subfieldID = fieldID - getFieldID(subfieldIndex);
1800  return {subfieldType, subfieldID};
1801 }
1802 
1803 uint64_t OpenBundleType::getMaxFieldID() const { return getImpl()->maxFieldID; }
1804 
1805 std::pair<uint64_t, bool>
1806 OpenBundleType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1807  auto childRoot = getFieldID(index);
1808  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
1809  : (getFieldID(index + 1) - 1);
1810  return std::make_pair(fieldID - childRoot,
1811  fieldID >= childRoot && fieldID <= rangeEnd);
1812 }
1813 
1814 bool OpenBundleType::isConst() { return getImpl()->isConst; }
1815 
1816 OpenBundleType::ElementType
1817 OpenBundleType::getElementTypePreservingConst(size_t index) {
1818  auto type = getElementType(index);
1819  // TODO: ConstTypeInterface / Trait ?
1820  return TypeSwitch<FIRRTLType, ElementType>(type)
1821  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>([&](auto type) {
1822  return type.getConstType(type.isConst() || isConst());
1823  })
1824  .Default(type);
1825 }
1826 
1827 LogicalResult
1828 OpenBundleType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
1829  ArrayRef<BundleElement> elements, bool isConst) {
1830  for (auto &element : elements) {
1831  if (FIRRTLType(element.type).containsReference() && isConst)
1832  return emitErrorFn()
1833  << "'const' bundle cannot have references, but element "
1834  << element.name << " has type " << element.type;
1835  }
1836 
1837  return success();
1838 }
1839 
1840 //===----------------------------------------------------------------------===//
1841 // FVectorType
1842 //===----------------------------------------------------------------------===//
1843 
  // Uniquing key for FVectorType: (element type, number of elements, const
  // flag stored as a char; construct() converts it back to bool).
  // NOTE(review): this listing omits several source lines (the struct header,
  // the start of the constructor and its initializers, getAsKey, and member
  // declarations such as elementType, isConst, props, passiveType and
  // anonymousType, which the FVectorType methods read) -- confirm against the
  // real source.
1846  using KeyTy = std::tuple<FIRRTLBaseType, size_t, char>;
1847 
1849  bool isConst)
1852  props(elementType.getRecursiveTypeProperties()) {
1854  }
1855 
1856  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1857 
1859 
  // Placement-new the storage into the uniquer's allocator from a key.
1860  static FVectorTypeStorage *construct(TypeStorageAllocator &allocator,
1861  KeyTy key) {
1862  return new (allocator.allocate<FVectorTypeStorage>())
1863  FVectorTypeStorage(std::get<0>(key), std::get<1>(key),
1864  static_cast<bool>(std::get<2>(key)));
1865  }
1866 
1868  size_t numElements;
1869 
1870  /// This holds the bits for the type's recursive properties, and can hold a
1871  /// pointer to a passive version of the type.
1875 };
1876 
1878  bool isConst) {
  // Unique the vector in the element type's context, keyed on (element type,
  // element count, const flag).
  // NOTE(review): the first line of this signature is missing from the
  // listing; presumably FVectorType::get(FIRRTLBaseType elementType,
  // size_t numElements, ...) -- confirm.
1879  return Base::get(elementType.getContext(), elementType, numElements, isConst);
1880 }
1881 
1882 FIRRTLBaseType FVectorType::getElementType() const {
1883  return getImpl()->elementType;
1884 }
1885 
1886 size_t FVectorType::getNumElements() const { return getImpl()->numElements; }
1887 
1888 /// Return the recursive properties of the type.
1889 RecursiveTypeProperties FVectorType::getRecursiveTypeProperties() const {
1890  return getImpl()->props;
1891 }
1892 
1893 /// Return this type with any flip types recursively removed from itself.
1895  auto *impl = getImpl();
  // NOTE(review): the signature line is missing from this listing; presumably
  // `FIRRTLBaseType FVectorType::getPassiveType() {` -- confirm.
  // The answer is memoized in impl->passiveType after the first call.
1896 
1897  // If we've already determined and cached the passive type, use it.
1898  if (impl->passiveType)
1899  return impl->passiveType;
1900 
1901  // If this type is already passive, return it and remember for next time.
  // Passivity of the vector is decided solely by its element type here.
1902  if (impl->elementType.getRecursiveTypeProperties().isPassive)
1903  return impl->passiveType = *this;
1904 
1905  // Otherwise, rebuild a passive version.
  // The element count and 'const' flag are carried over unchanged.
1906  auto passiveType = FVectorType::get(getElementType().getPassiveType(),
1907  getNumElements(), isConst());
1908  impl->passiveType = passiveType;
1909  return passiveType;
1910 }
1911 
1912 FVectorType FVectorType::getConstType(bool isConst) {
1913  if (isConst == this->isConst())
1914  return *this;
1915  return get(getElementType(), getNumElements(), isConst);
1916 }
1917 
1918 FVectorType FVectorType::getAllConstDroppedType() {
1919  if (!containsConst())
1920  return *this;
1921  return get(getElementType().getAllConstDroppedType(), getNumElements(),
1922  false);
1923 }
1924 
1925 /// Return this type with any type aliases recursively removed from itself.
1926 FIRRTLBaseType FVectorType::getAnonymousType() {
1927  auto *impl = getImpl();
1928 
1929  if (impl->anonymousType)
1930  return impl->anonymousType;
1931 
1932  // If this type is already anonymous, return it and remember for next time.
1933  if (!impl->props.containsTypeAlias)
1934  return impl->anonymousType = *this;
1935 
1936  // Otherwise, rebuild an anonymous version.
1937  auto anonymousType = FVectorType::get(getElementType().getAnonymousType(),
1938  getNumElements(), isConst());
1939  impl->anonymousType = anonymousType;
1940  return anonymousType;
1941 }
1942 
1943 uint64_t FVectorType::getFieldID(uint64_t index) const {
1944  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1945 }
1946 
1947 uint64_t FVectorType::getIndexForFieldID(uint64_t fieldID) const {
1948  assert(fieldID && "fieldID must be at least 1");
1949  // Divide the field ID by the number of fieldID's per element.
1950  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1951 }
1952 
1953 std::pair<uint64_t, uint64_t>
1954 FVectorType::getIndexAndSubfieldID(uint64_t fieldID) const {
1955  auto index = getIndexForFieldID(fieldID);
1956  auto elementFieldID = getFieldID(index);
1957  return {index, fieldID - elementFieldID};
1958 }
1959 
1960 std::pair<Type, uint64_t>
1961 FVectorType::getSubTypeByFieldID(uint64_t fieldID) const {
1962  if (fieldID == 0)
1963  return {*this, 0};
1964  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
1965 }
1966 
1967 uint64_t FVectorType::getMaxFieldID() const {
1968  return getNumElements() *
1969  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1970 }
1971 
1972 std::pair<uint64_t, bool>
1973 FVectorType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1974  auto childRoot = getFieldID(index);
1975  auto rangeEnd =
1976  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
1977  return std::make_pair(fieldID - childRoot,
1978  fieldID >= childRoot && fieldID <= rangeEnd);
1979 }
1980 
1981 bool FVectorType::isConst() { return getImpl()->isConst; }
1982 
1983 FVectorType::ElementType FVectorType::getElementTypePreservingConst() {
1984  auto type = getElementType();
1985  return type.getConstType(type.isConst() || isConst());
1986 }
1987 
1988 //===----------------------------------------------------------------------===//
1989 // OpenVectorType
1990 //===----------------------------------------------------------------------===//
1991 
  // Uniquing key for OpenVectorType: (element type, number of elements, const
  // flag stored as a char; construct() converts it back to bool).
  // NOTE(review): this listing omits several source lines (struct header,
  // constructor start and initializers, getAsKey, and member declarations
  // such as elementType and props, which OpenVectorType methods read) --
  // confirm against the real source.
1993  using KeyTy = std::tuple<FIRRTLType, size_t, char>;
1994 
1996  bool isConst)
1998  isConst(static_cast<char>(isConst)) {
2001  }
2002 
2003  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2004 
2006 
  // Placement-new the storage into the uniquer's allocator from a key.
2007  static OpenVectorTypeStorage *construct(TypeStorageAllocator &allocator,
2008  KeyTy key) {
2009  return new (allocator.allocate<OpenVectorTypeStorage>())
2010  OpenVectorTypeStorage(std::get<0>(key), std::get<1>(key),
2011  static_cast<bool>(std::get<2>(key)));
2012  }
2013 
2015  size_t numElements;
2016 
  // The 'const' flag, kept as a char to match KeyTy.
2018  char isConst;
2019 };
2020 
2021 OpenVectorType OpenVectorType::get(FIRRTLType elementType, size_t numElements,
2022  bool isConst) {
2023  return Base::get(elementType.getContext(), elementType, numElements, isConst);
2024 }
2025 
2026 FIRRTLType OpenVectorType::getElementType() const {
2027  return getImpl()->elementType;
2028 }
2029 
2030 size_t OpenVectorType::getNumElements() const { return getImpl()->numElements; }
2031 
2032 /// Return the recursive properties of the type.
2033 RecursiveTypeProperties OpenVectorType::getRecursiveTypeProperties() const {
2034  return getImpl()->props;
2035 }
2036 
2037 OpenVectorType OpenVectorType::getConstType(bool isConst) {
2038  if (isConst == this->isConst())
2039  return *this;
2040  return get(getElementType(), getNumElements(), isConst);
2041 }
2042 
2043 uint64_t OpenVectorType::getFieldID(uint64_t index) const {
2044  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2045 }
2046 
2047 uint64_t OpenVectorType::getIndexForFieldID(uint64_t fieldID) const {
2048  assert(fieldID && "fieldID must be at least 1");
2049  // Divide the field ID by the number of fieldID's per element.
2050  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2051 }
2052 
2053 std::pair<uint64_t, uint64_t>
2054 OpenVectorType::getIndexAndSubfieldID(uint64_t fieldID) const {
2055  auto index = getIndexForFieldID(fieldID);
2056  auto elementFieldID = getFieldID(index);
2057  return {index, fieldID - elementFieldID};
2058 }
2059 
2060 std::pair<Type, uint64_t>
2061 OpenVectorType::getSubTypeByFieldID(uint64_t fieldID) const {
2062  if (fieldID == 0)
2063  return {*this, 0};
2064  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
2065 }
2066 
2067 uint64_t OpenVectorType::getMaxFieldID() const {
2068  // If this is requirement, make ODS constraint or actual elementType.
2069  return getNumElements() *
2070  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2071 }
2072 
2073 std::pair<uint64_t, bool>
2074 OpenVectorType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2075  auto childRoot = getFieldID(index);
2076  auto rangeEnd =
2077  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
2078  return std::make_pair(fieldID - childRoot,
2079  fieldID >= childRoot && fieldID <= rangeEnd);
2080 }
2081 
2082 bool OpenVectorType::isConst() { return getImpl()->isConst; }
2083 
2084 OpenVectorType::ElementType OpenVectorType::getElementTypePreservingConst() {
2085  auto type = getElementType();
2086  // TODO: ConstTypeInterface / Trait ?
2087  return TypeSwitch<FIRRTLType, ElementType>(type)
2088  .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>([&](auto type) {
2089  return type.getConstType(type.isConst() || isConst());
2090  })
2091  .Default(type);
2092 }
2093 
2094 LogicalResult
2095 OpenVectorType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2097  bool isConst) {
  // Reject 'const' vectors whose element type contains references.
  // NOTE(review): one parameter line is missing from this listing; it
  // presumably declares `elementType` (used below) and the element count.
2098  if (elementType.containsReference() && isConst)
2099  return emitErrorFn() << "vector cannot be const with references";
2100  return success();
2101 }
2102 
2103 //===----------------------------------------------------------------------===//
2104 // FEnum Type
2105 //===----------------------------------------------------------------------===//
2106 
  // Uniquing key for FEnumType: (variant list, const flag stored as a char).
  // NOTE(review): this listing omits the struct header, one constructor
  // initializer, and two trailing member declarations; `recProps` (assigned
  // in the constructor) and `anonymousType` (read by getAnonymousType) are
  // presumably declared there -- confirm.
2108  using KeyTy = std::pair<ArrayRef<FEnumType::EnumElement>, char>;
2109 
2110  FEnumTypeStorage(ArrayRef<FEnumType::EnumElement> elements, bool isConst)
2112  elements(elements.begin(), elements.end()) {
2113  RecursiveTypeProperties props{true, false, false, isConst,
2114  false, false, false};
2115  uint64_t fieldID = 0;
2116  fieldIDs.reserve(elements.size());
  // For each variant: fold its recursive properties into the enum's
  // (passivity is AND-ed, the other bits are OR-ed) and assign its field ID.
2117  for (auto &element : elements) {
2118  auto type = element.type;
2119  auto eltInfo = type.getRecursiveTypeProperties();
2120  props.isPassive &= eltInfo.isPassive;
2121  props.containsAnalog |= eltInfo.containsAnalog;
2122  props.containsConst |= eltInfo.containsConst;
2123  props.containsReference |= eltInfo.containsReference;
2124  props.containsTypeAlias |= eltInfo.containsTypeAlias;
2125  props.hasUninferredReset |= eltInfo.hasUninferredReset;
2126  props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
  // The first variant gets field ID 1; ID 0 is the enum itself.
2127  fieldID += 1;
2128  fieldIDs.push_back(fieldID);
2129  // Increment the field ID for the next field by the number of subfields.
2130  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
2131  }
2132  maxFieldID = fieldID;
2133  recProps = props;
2134  }
2135 
2136  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2137 
2138  KeyTy getAsKey() const { return KeyTy(elements, isConst); }
2139 
  // Hash every variant plus the const flag.
2140  static llvm::hash_code hashKey(const KeyTy &key) {
2141  return llvm::hash_combine(
2142  llvm::hash_combine_range(key.first.begin(), key.first.end()),
2143  key.second);
2144  }
2145 
  // Placement-new the storage into the uniquer's allocator from a key.
2146  static FEnumTypeStorage *construct(TypeStorageAllocator &allocator,
2147  KeyTy key) {
2148  return new (allocator.allocate<FEnumTypeStorage>())
2149  FEnumTypeStorage(key.first, static_cast<bool>(key.second));
2150  }
2151 
  // Owned copy of the variants, per-variant starting field IDs, and the
  // largest field ID used by the enum.
2152  SmallVector<FEnumType::EnumElement, 4> elements;
2153  SmallVector<uint64_t, 4> fieldIDs;
2154  uint64_t maxFieldID;
2155 
2158 };
2159 
2160 FEnumType FEnumType::get(::mlir::MLIRContext *context,
2161  ArrayRef<EnumElement> elements, bool isConst) {
2162  return Base::get(context, elements, isConst);
2163 }
2164 
2165 ArrayRef<FEnumType::EnumElement> FEnumType::getElements() const {
2166  return getImpl()->elements;
2167 }
2168 
2169 FEnumType FEnumType::getConstType(bool isConst) {
2170  return get(getContext(), getElements(), isConst);
2171 }
2172 
2173 FEnumType FEnumType::getAllConstDroppedType() {
2174  if (!containsConst())
2175  return *this;
2176 
2177  SmallVector<EnumElement> constDroppedElements(
2178  llvm::map_range(getElements(), [](EnumElement element) {
2179  element.type = element.type.getAllConstDroppedType();
2180  return element;
2181  }));
2182  return get(getContext(), constDroppedElements, false);
2183 }
2184 
2185 /// Return a pair with the 'isPassive' and 'containsAnalog' bits.
2186 RecursiveTypeProperties FEnumType::getRecursiveTypeProperties() const {
2187  return getImpl()->recProps;
2188 }
2189 
2190 std::optional<unsigned> FEnumType::getElementIndex(StringAttr name) {
2191  for (const auto &it : llvm::enumerate(getElements())) {
2192  auto element = it.value();
2193  if (element.name == name) {
2194  return unsigned(it.index());
2195  }
2196  }
2197  return std::nullopt;
2198 }
2199 
2200 std::optional<unsigned> FEnumType::getElementIndex(StringRef name) {
2201  for (const auto &it : llvm::enumerate(getElements())) {
2202  auto element = it.value();
2203  if (element.name.getValue() == name) {
2204  return unsigned(it.index());
2205  }
2206  }
2207  return std::nullopt;
2208 }
2209 
2210 StringAttr FEnumType::getElementNameAttr(size_t index) {
2211  assert(index < getNumElements() &&
2212  "index must be less than number of fields in enum");
2213  return getElements()[index].name;
2214 }
2215 
2216 StringRef FEnumType::getElementName(size_t index) {
2217  return getElementNameAttr(index).getValue();
2218 }
2219 
2220 std::optional<FEnumType::EnumElement> FEnumType::getElement(StringAttr name) {
2221  if (auto maybeIndex = getElementIndex(name))
2222  return getElements()[*maybeIndex];
2223  return std::nullopt;
2224 }
2225 
2226 std::optional<FEnumType::EnumElement> FEnumType::getElement(StringRef name) {
2227  if (auto maybeIndex = getElementIndex(name))
2228  return getElements()[*maybeIndex];
2229  return std::nullopt;
2230 }
2231 
2232 /// Look up an element by index.
2233 FEnumType::EnumElement FEnumType::getElement(size_t index) {
2234  assert(index < getNumElements() &&
2235  "index must be less than number of fields in enum");
2236  return getElements()[index];
2237 }
2238 
2239 FIRRTLBaseType FEnumType::getElementType(StringAttr name) {
2240  auto element = getElement(name);
2241  return element ? element->type : FIRRTLBaseType();
2242 }
2243 
2244 FIRRTLBaseType FEnumType::getElementType(StringRef name) {
2245  auto element = getElement(name);
2246  return element ? element->type : FIRRTLBaseType();
2247 }
2248 
2249 FIRRTLBaseType FEnumType::getElementType(size_t index) const {
2250  assert(index < getNumElements() &&
2251  "index must be less than number of fields in enum");
2252  return getElements()[index].type;
2253 }
2254 
2255 FIRRTLBaseType FEnumType::getElementTypePreservingConst(size_t index) {
2256  auto type = getElementType(index);
2257  return type.getConstType(type.isConst() || isConst());
2258 }
2259 
2260 uint64_t FEnumType::getFieldID(uint64_t index) const {
2261  return getImpl()->fieldIDs[index];
2262 }
2263 
2264 uint64_t FEnumType::getIndexForFieldID(uint64_t fieldID) const {
2265  assert(!getElements().empty() && "Enum must have >0 fields");
2266  auto fieldIDs = getImpl()->fieldIDs;
2267  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
2268  return std::distance(fieldIDs.begin(), it);
2269 }
2270 
2271 std::pair<uint64_t, uint64_t>
2272 FEnumType::getIndexAndSubfieldID(uint64_t fieldID) const {
2273  auto index = getIndexForFieldID(fieldID);
2274  auto elementFieldID = getFieldID(index);
2275  return {index, fieldID - elementFieldID};
2276 }
2277 
2278 std::pair<Type, uint64_t>
2279 FEnumType::getSubTypeByFieldID(uint64_t fieldID) const {
2280  if (fieldID == 0)
2281  return {*this, 0};
2282  auto subfieldIndex = getIndexForFieldID(fieldID);
2283  auto subfieldType = getElementType(subfieldIndex);
2284  auto subfieldID = fieldID - getFieldID(subfieldIndex);
2285  return {subfieldType, subfieldID};
2286 }
2287 
2288 uint64_t FEnumType::getMaxFieldID() const { return getImpl()->maxFieldID; }
2289 
2290 std::pair<uint64_t, bool>
2291 FEnumType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2292  auto childRoot = getFieldID(index);
2293  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
2294  : (getFieldID(index + 1) - 1);
2295  return std::make_pair(fieldID - childRoot,
2296  fieldID >= childRoot && fieldID <= rangeEnd);
2297 }
2298 
2299 auto FEnumType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2300  ArrayRef<EnumElement> elements, bool isConst)
2301  -> LogicalResult {
2302  for (auto &elt : elements) {
2303  auto r = elt.type.getRecursiveTypeProperties();
2304  if (!r.isPassive)
2305  return emitErrorFn() << "enum field '" << elt.name << "' not passive";
2306  if (r.containsAnalog)
2307  return emitErrorFn() << "enum field '" << elt.name << "' contains analog";
2308  if (r.containsConst && !isConst)
2309  return emitErrorFn() << "enum with 'const' elements must be 'const'";
2310  // TODO: exclude reference containing
2311  }
2312  return success();
2313 }
2314 
2315 /// Return this type with any type aliases recursively removed from itself.
2316 FIRRTLBaseType FEnumType::getAnonymousType() {
2317  auto *impl = getImpl();
2318 
2319  if (impl->anonymousType)
2320  return impl->anonymousType;
2321 
2322  if (!impl->recProps.containsTypeAlias)
2323  return impl->anonymousType = *this;
2324 
2325  SmallVector<FEnumType::EnumElement, 4> elements;
2326 
2327  for (auto element : getElements())
2328  elements.push_back({element.name, element.type.getAnonymousType()});
2329  return impl->anonymousType = FEnumType::get(getContext(), elements);
2330 }
2331 
2332 //===----------------------------------------------------------------------===//
2333 // BaseTypeAliasType
2334 //===----------------------------------------------------------------------===//
2335 
  // Uniquing key for BaseTypeAliasType: (alias name, aliased inner type).
  // NOTE(review): this listing omits the struct header, the start of the
  // constructor, and trailing member declarations (e.g. `innerType` and
  // `anonymousType`, which BaseTypeAliasType methods read) -- confirm.
2338  using KeyTy = std::tuple<StringAttr, FIRRTLBaseType>;
2339 
2342  innerType(innerType) {}
2343 
2344  bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2345 
2346  KeyTy getAsKey() const { return KeyTy(name, innerType); }
2347 
  // Hash the whole (name, inner type) tuple.
2348  static llvm::hash_code hashKey(const KeyTy &key) {
2349  return llvm::hash_combine(key);
2350  }
2351 
  // Placement-new the storage into the uniquer's allocator from a key.
2352  static BaseTypeAliasStorage *construct(TypeStorageAllocator &allocator,
2353  KeyTy key) {
2354  return new (allocator.allocate<BaseTypeAliasStorage>())
2355  BaseTypeAliasStorage(std::get<0>(key), std::get<1>(key));
2356  }
2357  StringAttr name;
2360 };
2361 
2362 auto BaseTypeAliasType::get(StringAttr name, FIRRTLBaseType innerType)
2363  -> BaseTypeAliasType {
2364  return Base::get(name.getContext(), name, innerType);
2365 }
2366 
2367 auto BaseTypeAliasType::getName() const -> StringAttr {
2368  return getImpl()->name;
2369 }
2370 
2371 auto BaseTypeAliasType::getInnerType() const -> FIRRTLBaseType {
2372  return getImpl()->innerType;
2373 }
2374 
2375 FIRRTLBaseType BaseTypeAliasType::getAnonymousType() {
2376  auto *impl = getImpl();
2377  if (impl->anonymousType)
2378  return impl->anonymousType;
2379  return impl->anonymousType = getInnerType().getAnonymousType();
2380 }
2381 
  // Reuse this alias when the inner type is already passive (getModifiedType
  // returns *this if unchanged); otherwise return the bare passive inner type.
  // NOTE(review): the signature line is missing from this listing.
2383  return getModifiedType(getInnerType().getPassiveType());
2384 }
2385 
2386 RecursiveTypeProperties BaseTypeAliasType::getRecursiveTypeProperties() const {
2387  auto rtp = getInnerType().getRecursiveTypeProperties();
2388  rtp.containsTypeAlias = true;
2389  return rtp;
2390 }
2391 
2392 // If a given `newInnerType` is identical to innerType, return `*this`
2393 // because we can reuse the type alias. Otherwise return `newInnerType`.
2394 FIRRTLBaseType BaseTypeAliasType::getModifiedType(FIRRTLBaseType newInnerType) {
2395  if (newInnerType == getInnerType())
2396  return *this;
2397  return newInnerType;
2398 }
2399 
2400 // FieldIDTypeInterface implementation.
2401 FIRRTLBaseType BaseTypeAliasType::getAllConstDroppedType() {
2402  return getModifiedType(getInnerType().getAllConstDroppedType());
2403 }
2404 
2405 FIRRTLBaseType BaseTypeAliasType::getConstType(bool isConst) {
2406  return getModifiedType(getInnerType().getConstType(isConst));
2407 }
2408 
2409 std::pair<Type, uint64_t>
2410 BaseTypeAliasType::getSubTypeByFieldID(uint64_t fieldID) const {
2411  return hw::FieldIdImpl::getSubTypeByFieldID(getInnerType(), fieldID);
2412 }
2413 
2414 uint64_t BaseTypeAliasType::getMaxFieldID() const {
2415  return hw::FieldIdImpl::getMaxFieldID(getInnerType());
2416 }
2417 
2418 std::pair<uint64_t, bool>
2420  uint64_t index) const {
  // Forward to the inner type; the alias's field-ID layout is the inner
  // type's (getMaxFieldID also forwards).
  // NOTE(review): the line naming this function is missing from the listing.
2421  return hw::FieldIdImpl::projectToChildFieldID(getInnerType(), fieldID, index);
2422 }
2423 
2424 uint64_t BaseTypeAliasType::getIndexForFieldID(uint64_t fieldID) const {
2425  return hw::FieldIdImpl::getIndexForFieldID(getInnerType(), fieldID);
2426 }
2427 
2428 uint64_t BaseTypeAliasType::getFieldID(uint64_t index) const {
2429  return hw::FieldIdImpl::getFieldID(getInnerType(), index);
2430 }
2431 
2432 std::pair<uint64_t, uint64_t>
2433 BaseTypeAliasType::getIndexAndSubfieldID(uint64_t fieldID) const {
2434  return hw::FieldIdImpl::getIndexAndSubfieldID(getInnerType(), fieldID);
2435 }
2436 
2437 //===----------------------------------------------------------------------===//
2438 // RefType
2439 //===----------------------------------------------------------------------===//
2440 
2441 auto RefType::get(FIRRTLBaseType type, bool forceable, SymbolRefAttr layer)
2442  -> RefType {
2443  return Base::get(type.getContext(), type, forceable, layer);
2444 }
2445 
2446 auto RefType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2447  FIRRTLBaseType base, bool forceable, SymbolRefAttr layer)
2448  -> LogicalResult {
2449  if (!base.isPassive())
2450  return emitErrorFn() << "reference base type must be passive";
2451  if (forceable && base.containsConst())
2452  return emitErrorFn()
2453  << "forceable reference base type cannot contain const";
2454  return success();
2455 }
2456 
2457 RecursiveTypeProperties RefType::getRecursiveTypeProperties() const {
2458  auto rtp = getType().getRecursiveTypeProperties();
2459  rtp.containsReference = true;
2460  // References are not "passive", per FIRRTL spec.
2461  rtp.isPassive = false;
2462  return rtp;
2463 }
2464 
2465 //===----------------------------------------------------------------------===//
2466 // AnalogType
2467 //===----------------------------------------------------------------------===//
2468 
2469 AnalogType AnalogType::get(mlir::MLIRContext *context) {
2470  return AnalogType::get(context, -1, false);
2471 }
2472 
2473 AnalogType AnalogType::get(mlir::MLIRContext *context,
2474  std::optional<int32_t> width, bool isConst) {
2475  return AnalogType::get(context, width ? *width : -1, isConst);
2476 }
2477 
2478 LogicalResult AnalogType::verify(function_ref<InFlightDiagnostic()> emitError,
2479  int32_t widthOrSentinel, bool isConst) {
2480  if (widthOrSentinel < -1)
2481  return emitError() << "invalid width";
2482  return success();
2483 }
2484 
2485 int32_t AnalogType::getWidthOrSentinel() const { return getImpl()->width; }
2486 
2487 AnalogType AnalogType::getConstType(bool isConst) {
2488  if (isConst == this->isConst())
2489  return *this;
2490  return get(getContext(), getWidthOrSentinel(), isConst);
2491 }
2492 
2493 //===----------------------------------------------------------------------===//
2494 // ClockType
2495 //===----------------------------------------------------------------------===//
2496 
2497 ClockType ClockType::getConstType(bool isConst) {
2498  if (isConst == this->isConst())
2499  return *this;
2500  return get(getContext(), isConst);
2501 }
2502 
2503 //===----------------------------------------------------------------------===//
2504 // ResetType
2505 //===----------------------------------------------------------------------===//
2506 
2507 ResetType ResetType::getConstType(bool isConst) {
2508  if (isConst == this->isConst())
2509  return *this;
2510  return get(getContext(), isConst);
2511 }
2512 
2513 //===----------------------------------------------------------------------===//
2514 // AsyncResetType
2515 //===----------------------------------------------------------------------===//
2516 
2517 AsyncResetType AsyncResetType::getConstType(bool isConst) {
2518  if (isConst == this->isConst())
2519  return *this;
2520  return get(getContext(), isConst);
2521 }
2522 
2523 //===----------------------------------------------------------------------===//
2524 // ClassType
2525 //===----------------------------------------------------------------------===//
2526 
  // Uniqued storage for ClassType, keyed on (class symbol, element list).
  // NOTE(review): this listing omits the placement-new constructor arguments
  // and the constructor's initializer list; confirm against the real source.
2527 struct circt::firrtl::detail::ClassTypeStorage : mlir::TypeStorage {
2528  using KeyTy = std::pair<FlatSymbolRefAttr, ArrayRef<ClassElement>>;
2529 
  // Build storage for `key`. The element list and the computed field-ID table
  // are copied into the storage allocator rather than referencing the key's
  // arrays.
2530  static ClassTypeStorage *construct(TypeStorageAllocator &allocator,
2531  KeyTy key) {
2532  auto name = key.first;
2533  auto elements = allocator.copyInto(key.second);
2534 
2535  // build the field ID table
  // Element i gets the next ID after everything used by element i-1; ID 0 is
  // the class itself, so the first element gets 1.
2536  SmallVector<uint64_t, 4> ids;
2537  uint64_t id = 0;
2538  ids.reserve(elements.size());
2539  for (auto &element : elements) {
2540  id += 1;
2541  ids.push_back(id);
2542  id += hw::FieldIdImpl::getMaxFieldID(element.type);
2543  }
2544 
2545  auto fieldIDs = allocator.copyInto(ArrayRef(ids));
2546  auto maxFieldID = id;
2547 
2548  return new (allocator.allocate<ClassTypeStorage>())
2550  }
2551 
2552  ClassTypeStorage(FlatSymbolRefAttr name, ArrayRef<ClassElement> elements,
2553  ArrayRef<uint64_t> fieldIDs, uint64_t maxFieldID)
2556 
  // Only the key parts (name and elements) participate in equality.
2557  bool operator==(const KeyTy &key) const {
2558  return name == key.first && elements == key.second;
2559  }
2560 
2561  FlatSymbolRefAttr name;
2562  ArrayRef<ClassElement> elements;
2563  ArrayRef<uint64_t> fieldIDs;
2564  uint64_t maxFieldID;
2565 };
2566 
2567 ClassType ClassType::get(FlatSymbolRefAttr name,
2568  ArrayRef<ClassElement> elements) {
2569  return get(name.getContext(), name, elements);
2570 }
2571 
2572 StringRef ClassType::getName() const {
2573  return getNameAttr().getAttr().getValue();
2574 }
2575 
2576 FlatSymbolRefAttr ClassType::getNameAttr() const { return getImpl()->name; }
2577 
2578 ArrayRef<ClassElement> ClassType::getElements() const {
2579  return getImpl()->elements;
2580 }
2581 
2582 const ClassElement &ClassType::getElement(IntegerAttr index) const {
2583  return getElement(index.getValue().getZExtValue());
2584 }
2585 
2586 const ClassElement &ClassType::getElement(size_t index) const {
2587  return getElements()[index];
2588 }
2589 
2590 std::optional<uint64_t> ClassType::getElementIndex(StringRef fieldName) const {
2591  for (const auto [i, e] : llvm::enumerate(getElements()))
2592  if (fieldName == e.name)
2593  return i;
2594  return {};
2595 }
2596 
2597 void ClassType::printInterface(AsmPrinter &p) const {
2598  p.printSymbolName(getName());
2599  p << "(";
2600  bool first = true;
2601  for (const auto &element : getElements()) {
2602  if (!first)
2603  p << ", ";
2604  p << element.direction << " ";
2605  p.printKeywordOrString(element.name);
2606  p << ": " << element.type;
2607  first = false;
2608  }
2609  p << ")";
2610 }
2611 
2612 uint64_t ClassType::getFieldID(uint64_t index) const {
2613  return getImpl()->fieldIDs[index];
2614 }
2615 
2616 uint64_t ClassType::getIndexForFieldID(uint64_t fieldID) const {
2617  assert(!getElements().empty() && "Class must have >0 fields");
2618  auto fieldIDs = getImpl()->fieldIDs;
2619  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
2620  return std::distance(fieldIDs.begin(), it);
2621 }
2622 
2623 std::pair<uint64_t, uint64_t>
2624 ClassType::getIndexAndSubfieldID(uint64_t fieldID) const {
2625  auto index = getIndexForFieldID(fieldID);
2626  auto elementFieldID = getFieldID(index);
2627  return {index, fieldID - elementFieldID};
2628 }
2629 
2630 std::pair<Type, uint64_t>
2631 ClassType::getSubTypeByFieldID(uint64_t fieldID) const {
2632  if (fieldID == 0)
2633  return {*this, 0};
2634  auto subfieldIndex = getIndexForFieldID(fieldID);
2635  auto subfieldType = getElement(subfieldIndex).type;
2636  auto subfieldID = fieldID - getFieldID(subfieldIndex);
2637  return {subfieldType, subfieldID};
2638 }
2639 
2640 uint64_t ClassType::getMaxFieldID() const { return getImpl()->maxFieldID; }
2641 
2642 std::pair<uint64_t, bool>
2643 ClassType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2644  auto childRoot = getFieldID(index);
2645  auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
2646  : (getFieldID(index + 1) - 1);
2647  return std::make_pair(fieldID - childRoot,
2648  fieldID >= childRoot && fieldID <= rangeEnd);
2649 }
2650 
2651 ParseResult ClassType::parseInterface(AsmParser &parser, ClassType &result) {
2652  StringAttr className;
2653  if (parser.parseSymbolName(className))
2654  return failure();
2655 
2656  SmallVector<ClassElement> elements;
2657  if (parser.parseCommaSeparatedList(
2658  OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
2659  // Parse port direction.
2660  Direction direction;
2661  if (succeeded(parser.parseOptionalKeyword("out")))
2662  direction = Direction::Out;
2663  else if (succeeded(parser.parseKeyword("in", "or 'out'")))
2664  direction = Direction::In;
2665  else
2666  return failure();
2667 
2668  // Parse port name.
2669  std::string keyword;
2670  if (parser.parseKeywordOrString(&keyword))
2671  return failure();
2672  StringAttr name = StringAttr::get(parser.getContext(), keyword);
2673 
2674  // Parse port type.
2675  Type type;
2676  if (parser.parseColonType(type))
2677  return failure();
2678 
2679  elements.emplace_back(name, type, direction);
2680  return success();
2681  }))
2682  return failure();
2683 
2684  result = ClassType::get(FlatSymbolRefAttr::get(className), elements);
2685  return success();
2686 }
2687 
2688 //===----------------------------------------------------------------------===//
2689 // FIRRTLDialect
2690 //===----------------------------------------------------------------------===//
2691 
2692 void FIRRTLDialect::registerTypes() {
2693  addTypes<
2694 #define GET_TYPEDEF_LIST
2695 #include "circt/Dialect/FIRRTL/FIRRTLTypes.cpp.inc"
2696  >();
2697 }
2698 
2699 // Get the bit width for this type, return None if unknown. Unlike
2700 // getBitWidthOrSentinel(), this can recursively compute the bitwidth of
2701 // aggregate types. For bundle and vectors, recursively get the width of each
2702 // field element and return the total bit width of the aggregate type. This
2703 // returns None, if any of the bundle fields is a flip type, or ground type with
2704 // unknown bit width.
2705 std::optional<int64_t> firrtl::getBitWidth(FIRRTLBaseType type,
2706  bool ignoreFlip) {
2707  std::function<std::optional<int64_t>(FIRRTLBaseType)> getWidth =
2708  [&](FIRRTLBaseType type) -> std::optional<int64_t> {
2709  return TypeSwitch<FIRRTLBaseType, std::optional<int64_t>>(type)
2710  .Case<BundleType>([&](BundleType bundle) -> std::optional<int64_t> {
2711  int64_t width = 0;
2712  for (auto &elt : bundle) {
2713  if (elt.isFlip && !ignoreFlip)
2714  return std::nullopt;
2715  auto w = getBitWidth(elt.type);
2716  if (!w.has_value())
2717  return std::nullopt;
2718  width += *w;
2719  }
2720  return width;
2721  })
2722  .Case<FEnumType>([&](FEnumType fenum) -> std::optional<int64_t> {
2723  int64_t width = 0;
2724  for (auto &elt : fenum) {
2725  auto w = getBitWidth(elt.type);
2726  if (!w.has_value())
2727  return std::nullopt;
2728  width = std::max(width, *w);
2729  }
2730  return width + llvm::Log2_32_Ceil(fenum.getNumElements());
2731  })
2732  .Case<FVectorType>([&](auto vector) -> std::optional<int64_t> {
2733  auto w = getBitWidth(vector.getElementType());
2734  if (!w.has_value())
2735  return std::nullopt;
2736  return *w * vector.getNumElements();
2737  })
2738  .Case<IntType>([&](IntType iType) { return iType.getWidth(); })
2739  .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
2740  .Default([&](auto t) { return std::nullopt; });
2741  };
2742  return getWidth(type);
2743 }
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition: CHIRRTL.cpp:30
MlirType elementType
Definition: CHIRRTL.cpp:29
static ParseResult parseFIRRTLBaseType(FIRRTLBaseType &result, StringRef name, AsmParser &parser)
static ParseResult parseFIRRTLPropertyType(PropertyType &result, StringRef name, AsmParser &parser)
static LogicalResult customTypePrinter(Type type, AsmPrinter &os)
Print a type with a custom printer implementation.
Definition: FIRRTLTypes.cpp:45
static OptionalParseResult customTypeParser(AsmParser &parser, StringRef name, Type &result)
Parse a type with a custom parser implementation.
static ParseResult parseType(Type &result, StringRef name, AsmParser &parser)
Parse a type defined by this dialect.
@ ContainsAnalogBitMask
Bit set if the type contains an analog type.
@ HasUninferredWidthBitMask
Bit set if the type has any uninferred bit widths.
@ IsPassiveBitMask
Bit set if the type only contains passive elements.
static bool areBundleElementsEquivalent(BundleType::BundleElement destElement, BundleType::BundleElement srcElement, bool destOuterTypeIsConst, bool srcOuterTypeIsConst, bool requiresSameWidth)
Helper to implement the equivalence logic for a pair of bundle elements.
static ParseResult parseFIRRTLType(FIRRTLType &result, StringRef name, AsmParser &parser)
Parse a FIRRTLType with a name that has already been parsed.
int32_t width
Definition: FIRRTL.cpp:36
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
FIRRTLBaseType getAnonymousType()
Return this type with any type alias types recursively removed from itself.
bool isResetType()
Return true if this is a valid "reset" type.
FIRRTLBaseType getMaskType()
Return this type with all ground types replaced with UInt<1>.
FIRRTLBaseType getPassiveType()
Return this type with any flip types recursively removed from itself.
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getAllConstDroppedType()
Return this type with a 'const' modifiers dropped.
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
FIRRTLBaseType getWidthlessType()
Return this type with widths of all ground types removed.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
Definition: FIRRTLTypes.h:518
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
Definition: FIRRTLTypes.h:528
bool containsReference()
Return true if this is or contains a Reference type.
Definition: FIRRTLTypes.h:105
bool isGround()
Return true if this is a 'ground' type, aka a non-aggregate type.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
RecursiveTypeProperties getRecursiveTypeProperties() const
Return the recursive properties of the type, containing the isPassive, containsAnalog,...
This is the common base class between SIntType and UIntType.
Definition: FIRRTLTypes.h:294
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
IntType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Definition: FIRRTLTypes.h:273
Represents a limited word-length unsigned integer in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:141
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
ParseResult parseNestedType(FIRRTLType &result, AsmParser &parser)
Parse a FIRRTLType.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
ParseResult parseNestedBaseType(FIRRTLBaseType &result, AsmParser &parser)
bool isTypeInOut(mlir::Type type)
Returns true if the given type has some flipped (aka unaligned) dataflow.
bool areTypesRefCastable(Type dstType, Type srcType)
Return true if destination ref type can be cast from source ref type, per FIRRTL spec rules they must...
bool areTypesEquivalent(FIRRTLType destType, FIRRTLType srcType, bool destOuterTypeIsConst=false, bool srcOuterTypeIsConst=false, bool requireSameWidths=false)
Returns whether the two types are equivalent.
bool areTypesWeaklyEquivalent(FIRRTLType destType, FIRRTLType srcType, bool destFlip=false, bool srcFlip=false, bool destOuterTypeIsConst=false, bool srcOuterTypeIsConst=false)
Returns true if two types are weakly equivalent.
mlir::Type getPassiveType(mlir::Type anyBaseFIRRTLType)
bool isTypeLarger(FIRRTLBaseType dstType, FIRRTLBaseType srcType)
Returns true if the destination is at least as wide as a source.
bool containsConst(Type type)
Returns true if the type is or contains a 'const' type whose value is guaranteed to be unchanging at ...
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
void printNestedType(Type type, AsmPrinter &os)
Print a type defined by this dialect.
bool isConst(Type type)
Returns true if this is a 'const' type whose value is guaranteed to be unchanging at circuit executio...
bool areTypesConstCastable(FIRRTLType destType, FIRRTLType srcType, bool srcOuterTypeIsConst=false)
Returns whether the srcType can be const-casted to the destType.
ParseResult parseNestedPropertyType(PropertyType &result, AsmParser &parser)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
std::pair< uint64_t, bool > projectToChildFieldID(Type, uint64_t fieldID, uint64_t index)
uint64_t getIndexForFieldID(Type type, uint64_t fieldID)
uint64_t getMaxFieldID(Type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
A collection of bits indicating the recursive properties of a type.
Definition: FIRRTLTypes.h:65
bool containsReference
Whether the type contains a reference type.
Definition: FIRRTLTypes.h:69
bool isPassive
Whether the type only contains passive elements.
Definition: FIRRTLTypes.h:67
bool containsAnalog
Whether the type contains an analog type.
Definition: FIRRTLTypes.h:71
bool hasUninferredReset
Whether the type has any uninferred reset.
Definition: FIRRTLTypes.h:79
bool containsTypeAlias
Whether the type contains a type alias.
Definition: FIRRTLTypes.h:75
bool containsConst
Whether the type contains a const type.
Definition: FIRRTLTypes.h:73
bool hasUninferredWidth
Whether the type has any uninferred bit widths.
Definition: FIRRTLTypes.h:77
bool operator==(const KeyTy &key) const
static BaseTypeAliasStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
BaseTypeAliasStorage(StringAttr name, FIRRTLBaseType innerType)
std::tuple< StringAttr, FIRRTLBaseType > KeyTy
static llvm::hash_code hashKey(const KeyTy &key)
SmallVector< BundleType::BundleElement, 4 > elements
static BundleTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
std::pair< ArrayRef< BundleType::BundleElement >, char > KeyTy
static llvm::hash_code hashKey(const KeyTy &key)
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
BundleTypeStorage(ArrayRef< BundleType::BundleElement > elements, bool isConst)
bool operator==(const KeyTy &key) const
SmallVector< uint64_t, 4 > fieldIDs
std::pair< FlatSymbolRefAttr, ArrayRef< ClassElement > > KeyTy
bool operator==(const KeyTy &key) const
static ClassTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
ClassTypeStorage(FlatSymbolRefAttr name, ArrayRef< ClassElement > elements, ArrayRef< uint64_t > fieldIDs, uint64_t maxFieldID)
SmallVector< FEnumType::EnumElement, 4 > elements
static llvm::hash_code hashKey(const KeyTy &key)
bool operator==(const KeyTy &key) const
FEnumTypeStorage(ArrayRef< FEnumType::EnumElement > elements, bool isConst)
std::pair< ArrayRef< FEnumType::EnumElement >, char > KeyTy
static FEnumTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
SmallVector< uint64_t, 4 > fieldIDs
static FIRRTLBaseTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
bool operator==(const KeyTy &key) const
bool operator==(const KeyTy &key) const
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
static FVectorTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
std::tuple< FIRRTLBaseType, size_t, char > KeyTy
FVectorTypeStorage(FIRRTLBaseType elementType, size_t numElements, bool isConst)
SmallVector< OpenBundleType::BundleElement, 4 > elements
static OpenBundleTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
bool operator==(const KeyTy &key) const
static llvm::hash_code hashKey(const KeyTy &key)
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
OpenBundleTypeStorage(ArrayRef< OpenBundleType::BundleElement > elements, bool isConst)
std::pair< ArrayRef< OpenBundleType::BundleElement >, char > KeyTy
std::tuple< FIRRTLType, size_t, char > KeyTy
bool operator==(const KeyTy &key) const
OpenVectorTypeStorage(FIRRTLType elementType, size_t numElements, bool isConst)
static OpenVectorTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
WidthTypeStorage(int32_t width, bool isConst)
static WidthTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)
bool operator==(const KeyTy &key) const
std::pair< int32_t, char > KeyTy