Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
FIRRTLTypes.cpp
Go to the documentation of this file.
1//===- FIRRTLTypes.cpp - Implement the FIRRTL dialect type system ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the FIRRTL dialect type system.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/IR/DialectImplementation.h"
18#include "llvm/ADT/StringExtras.h"
19#include "llvm/ADT/StringSwitch.h"
20#include "llvm/ADT/TypeSwitch.h"
21
22using namespace circt;
23using namespace firrtl;
24
25using mlir::OptionalParseResult;
26using mlir::TypeStorageAllocator;
27
28//===----------------------------------------------------------------------===//
29// TableGen generated logic.
30//===----------------------------------------------------------------------===//
31
32// Provide the autogenerated implementation for types.
33#define GET_TYPEDEF_CLASSES
34#include "circt/Dialect/FIRRTL/FIRRTLTypes.cpp.inc"
35
36//===----------------------------------------------------------------------===//
37// Type Printing
38//===----------------------------------------------------------------------===//
39
40// NOLINTBEGIN(misc-no-recursion)
41/// Print a type with a custom printer implementation.
42///
43/// This only prints a subset of all types in the dialect. Use `printNestedType`
44/// instead, which will call this function in turn, as appropriate.
45static LogicalResult customTypePrinter(Type type, AsmPrinter &os) {
46 if (isConst(type))
47 os << "const.";
48
49 auto printWidthQualifier = [&](std::optional<int32_t> width) {
50 if (width)
51 os << '<' << *width << '>';
52 };
53 bool anyFailed = false;
54 TypeSwitch<Type>(type)
55 .Case<ClockType>([&](auto) { os << "clock"; })
56 .Case<ResetType>([&](auto) { os << "reset"; })
57 .Case<AsyncResetType>([&](auto) { os << "asyncreset"; })
58 .Case<SIntType>([&](auto sIntType) {
59 os << "sint";
60 printWidthQualifier(sIntType.getWidth());
61 })
62 .Case<UIntType>([&](auto uIntType) {
63 os << "uint";
64 printWidthQualifier(uIntType.getWidth());
65 })
66 .Case<AnalogType>([&](auto analogType) {
67 os << "analog";
68 printWidthQualifier(analogType.getWidth());
69 })
70 .Case<BundleType, OpenBundleType>([&](auto bundleType) {
71 if (firrtl::type_isa<OpenBundleType>(bundleType))
72 os << "open";
73 os << "bundle<";
74 llvm::interleaveComma(bundleType, os, [&](auto element) {
75 StringRef fieldName = element.name.getValue();
76 bool isLiteralIdentifier =
77 !fieldName.empty() && llvm::isDigit(fieldName.front());
78 if (isLiteralIdentifier)
79 os << "\"";
80 os << element.name.getValue();
81 if (isLiteralIdentifier)
82 os << "\"";
83 if (element.isFlip)
84 os << " flip";
85 os << ": ";
86 printNestedType(element.type, os);
87 });
88 os << '>';
89 })
90 .Case<FEnumType>([&](auto fenumType) {
91 os << "enum<";
92 llvm::interleaveComma(fenumType, os,
93 [&](FEnumType::EnumElement element) {
94 os << element.name.getValue();
95 os << ": ";
96 printNestedType(element.type, os);
97 });
98 os << '>';
99 })
100 .Case<FVectorType, OpenVectorType>([&](auto vectorType) {
101 if (firrtl::type_isa<OpenVectorType>(vectorType))
102 os << "open";
103 os << "vector<";
104 printNestedType(vectorType.getElementType(), os);
105 os << ", " << vectorType.getNumElements() << '>';
106 })
107 .Case<RefType>([&](RefType refType) {
108 if (refType.getForceable())
109 os << "rw";
110 os << "probe<";
111 printNestedType(refType.getType(), os);
112 if (auto layer = refType.getLayer())
113 os << ", " << layer;
114 os << '>';
115 })
116 .Case<LHSType>([&](LHSType lhstype) {
117 os << "lhs<";
118 printNestedType(lhstype.getType(), os);
119 os << ">";
120 })
121 .Case<StringType>([&](auto stringType) { os << "string"; })
122 .Case<FIntegerType>([&](auto integerType) { os << "integer"; })
123 .Case<BoolType>([&](auto boolType) { os << "bool"; })
124 .Case<DoubleType>([&](auto doubleType) { os << "double"; })
125 .Case<ListType>([&](auto listType) {
126 os << "list<";
127 printNestedType(listType.getElementType(), os);
128 os << '>';
129 })
130 .Case<PathType>([&](auto pathType) { os << "path"; })
131 .Case<BaseTypeAliasType>([&](BaseTypeAliasType alias) {
132 os << "alias<" << alias.getName().getValue() << ", ";
133 printNestedType(alias.getInnerType(), os);
134 os << '>';
135 })
136 .Case<ClassType>([&](ClassType type) {
137 os << "class<";
138 type.printInterface(os);
139 os << ">";
140 })
141 .Case<AnyRefType>([&](AnyRefType type) { os << "anyref"; })
142 .Case<FStringType>([&](auto) { os << "fstring"; })
143 .Default([&](auto) { anyFailed = true; });
144 return failure(anyFailed);
145}
146// NOLINTEND(misc-no-recursion)
147
148/// Print a type defined by this dialect.
149void circt::firrtl::printNestedType(Type type, AsmPrinter &os) {
150 // Try the custom type printer.
151 if (succeeded(customTypePrinter(type, os)))
152 return;
153
154 // None of the above recognized the type, so we bail.
155 assert(false && "type to print unknown to FIRRTL dialect");
156}
157
158//===----------------------------------------------------------------------===//
159// Type Parsing
160//===----------------------------------------------------------------------===//
161
162/// Parse a type with a custom parser implementation.
163///
164/// This only accepts a subset of all types in the dialect. Use `parseType`
165/// instead, which will call this function in turn, as appropriate.
166///
167/// Returns `std::nullopt` if the type `name` is not covered by the custom
168/// parsers. Otherwise returns success or failure as appropriate. On success,
169/// `result` is set to the resulting type.
170///
171/// ```plain
172/// firrtl-type
173/// ::= clock
174/// ::= reset
175/// ::= asyncreset
176/// ::= sint ('<' int '>')?
177/// ::= uint ('<' int '>')?
178/// ::= analog ('<' int '>')?
179/// ::= bundle '<' (bundle-elt (',' bundle-elt)*)? '>'
180/// ::= enum '<' (enum-elt (',' enum-elt)*)? '>'
181/// ::= vector '<' type ',' int '>'
182/// ::= const '.' type
183/// ::= 'property.' firrtl-phased-type
184/// bundle-elt ::= identifier flip? ':' type
185/// enum-elt ::= identifier ':' type
186/// ```
187static OptionalParseResult customTypeParser(AsmParser &parser, StringRef name,
188 Type &result) {
189 bool isConst = false;
190 const char constPrefix[] = "const.";
191 if (name.starts_with(constPrefix)) {
192 isConst = true;
193 name = name.drop_front(std::size(constPrefix) - 1);
194 }
195
196 auto *context = parser.getContext();
197 if (name == "clock")
198 return result = ClockType::get(context, isConst), success();
199 if (name == "reset")
200 return result = ResetType::get(context, isConst), success();
201 if (name == "asyncreset")
202 return result = AsyncResetType::get(context, isConst), success();
203
204 if (name == "sint" || name == "uint" || name == "analog") {
205 // Parse the width specifier if it exists.
206 int32_t width = -1;
207 if (!parser.parseOptionalLess()) {
208 if (parser.parseInteger(width) || parser.parseGreater())
209 return failure();
210
211 if (width < 0)
212 return parser.emitError(parser.getNameLoc(), "unknown width"),
213 failure();
214 }
215
216 if (name == "sint")
217 result = SIntType::get(context, width, isConst);
218 else if (name == "uint")
219 result = UIntType::get(context, width, isConst);
220 else {
221 assert(name == "analog");
222 result = AnalogType::get(context, width, isConst);
223 }
224 return success();
225 }
226
227 if (name == "bundle") {
228 SmallVector<BundleType::BundleElement, 4> elements;
229
230 auto parseBundleElement = [&]() -> ParseResult {
231 std::string nameStr;
232 StringRef name;
233 FIRRTLBaseType type;
234
235 if (failed(parser.parseKeywordOrString(&nameStr)))
236 return failure();
237 name = nameStr;
238
239 bool isFlip = succeeded(parser.parseOptionalKeyword("flip"));
240 if (parser.parseColon() || parseNestedBaseType(type, parser))
241 return failure();
242
243 elements.push_back({StringAttr::get(context, name), isFlip, type});
244 return success();
245 };
246
247 if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
248 parseBundleElement))
249 return failure();
250
251 return result = BundleType::get(context, elements, isConst), success();
252 }
253 if (name == "openbundle") {
254 SmallVector<OpenBundleType::BundleElement, 4> elements;
255
256 auto parseBundleElement = [&]() -> ParseResult {
257 std::string nameStr;
258 StringRef name;
259 FIRRTLType type;
260
261 if (failed(parser.parseKeywordOrString(&nameStr)))
262 return failure();
263 name = nameStr;
264
265 bool isFlip = succeeded(parser.parseOptionalKeyword("flip"));
266 if (parser.parseColon() || parseNestedType(type, parser))
267 return failure();
268
269 elements.push_back({StringAttr::get(context, name), isFlip, type});
270 return success();
271 };
272
273 if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
274 parseBundleElement))
275 return failure();
276
277 result = parser.getChecked<OpenBundleType>(context, elements, isConst);
278 return failure(!result);
279 }
280
281 if (name == "enum") {
282 SmallVector<FEnumType::EnumElement, 4> elements;
283
284 auto parseEnumElement = [&]() -> ParseResult {
285 std::string nameStr;
286 StringRef name;
287 FIRRTLBaseType type;
288
289 if (failed(parser.parseKeywordOrString(&nameStr)))
290 return failure();
291 name = nameStr;
292
293 if (parser.parseColon() || parseNestedBaseType(type, parser))
294 return failure();
295
296 elements.push_back({StringAttr::get(context, name), type});
297 return success();
298 };
299
300 if (parser.parseCommaSeparatedList(mlir::AsmParser::Delimiter::LessGreater,
301 parseEnumElement))
302 return failure();
303 if (failed(FEnumType::verify(
304 [&]() { return parser.emitError(parser.getNameLoc()); }, elements,
305 isConst)))
306 return failure();
307
308 return result = FEnumType::get(context, elements, isConst), success();
309 }
310
311 if (name == "vector") {
313 uint64_t width = 0;
314
315 if (parser.parseLess() || parseNestedBaseType(elementType, parser) ||
316 parser.parseComma() || parser.parseInteger(width) ||
317 parser.parseGreater())
318 return failure();
319
320 return result = FVectorType::get(elementType, width, isConst), success();
321 }
322 if (name == "openvector") {
324 uint64_t width = 0;
325
326 if (parser.parseLess() || parseNestedType(elementType, parser) ||
327 parser.parseComma() || parser.parseInteger(width) ||
328 parser.parseGreater())
329 return failure();
330
331 result =
332 parser.getChecked<OpenVectorType>(context, elementType, width, isConst);
333 return failure(!result);
334 }
335
336 // For now, support both firrtl.ref and firrtl.probe.
337 if (name == "ref" || name == "probe") {
338 FIRRTLBaseType type;
339 SymbolRefAttr layer;
340 // Don't pass `isConst` to `parseNestedBaseType since `ref` can point to
341 // either `const` or non-`const` types
342 if (parser.parseLess() || parseNestedBaseType(type, parser))
343 return failure();
344 if (parser.parseOptionalComma().succeeded())
345 if (parser.parseOptionalAttribute(layer).value())
346 return parser.emitError(parser.getNameLoc(),
347 "expected symbol reference");
348 if (parser.parseGreater())
349 return failure();
350
351 if (failed(RefType::verify(
352 [&]() { return parser.emitError(parser.getNameLoc()); }, type,
353 false, layer)))
354 return failure();
355
356 return result = RefType::get(type, false, layer), success();
357 }
358 if (name == "lhs") {
359 FIRRTLType type;
360 if (parser.parseLess() || parseNestedType(type, parser) ||
361 parser.parseGreater())
362 return failure();
363 if (!isa<FIRRTLBaseType>(type))
364 return parser.emitError(parser.getNameLoc(), "expected base type");
365 result = parser.getChecked<LHSType>(context, cast<FIRRTLBaseType>(type));
366 return failure(!result);
367 }
368 if (name == "rwprobe") {
369 FIRRTLBaseType type;
370 SymbolRefAttr layer;
371 if (parser.parseLess() || parseNestedBaseType(type, parser))
372 return failure();
373 if (parser.parseOptionalComma().succeeded())
374 if (parser.parseOptionalAttribute(layer).value())
375 return parser.emitError(parser.getNameLoc(),
376 "expected symbol reference");
377 if (parser.parseGreater())
378 return failure();
379
380 if (failed(RefType::verify(
381 [&]() { return parser.emitError(parser.getNameLoc()); }, type, true,
382 layer)))
383 return failure();
384
385 return result = RefType::get(type, true, layer), success();
386 }
387 if (name == "class") {
388 if (isConst)
389 return parser.emitError(parser.getNameLoc(), "classes cannot be const");
390 ClassType classType;
391 if (parser.parseLess() || ClassType::parseInterface(parser, classType) ||
392 parser.parseGreater())
393 return failure();
394 result = classType;
395 return success();
396 }
397 if (name == "anyref") {
398 if (isConst)
399 return parser.emitError(parser.getNameLoc(), "any refs cannot be const");
400
401 result = AnyRefType::get(parser.getContext());
402 return success();
403 }
404 if (name == "string") {
405 if (isConst) {
406 parser.emitError(parser.getNameLoc(), "strings cannot be const");
407 return failure();
408 }
409 result = StringType::get(parser.getContext());
410 return success();
411 }
412 if (name == "integer") {
413 if (isConst) {
414 parser.emitError(parser.getNameLoc(), "bigints cannot be const");
415 return failure();
416 }
417 result = FIntegerType::get(parser.getContext());
418 return success();
419 }
420 if (name == "bool") {
421 if (isConst) {
422 parser.emitError(parser.getNameLoc(), "bools cannot be const");
423 return failure();
424 }
425 result = BoolType::get(parser.getContext());
426 return success();
427 }
428 if (name == "double") {
429 if (isConst) {
430 parser.emitError(parser.getNameLoc(), "doubles cannot be const");
431 return failure();
432 }
433 result = DoubleType::get(parser.getContext());
434 return success();
435 }
436 if (name == "list") {
437 if (isConst) {
438 parser.emitError(parser.getNameLoc(), "lists cannot be const");
439 return failure();
440 }
442 if (parser.parseLess() || parseNestedPropertyType(elementType, parser) ||
443 parser.parseGreater())
444 return failure();
445 result = parser.getChecked<ListType>(context, elementType);
446 if (!result)
447 return failure();
448 return success();
449 }
450 if (name == "path") {
451 if (isConst) {
452 parser.emitError(parser.getNameLoc(), "path cannot be const");
453 return failure();
454 }
455 result = PathType::get(parser.getContext());
456 return success();
457 }
458 if (name == "alias") {
459 FIRRTLBaseType type;
460 StringRef name;
461 if (parser.parseLess() || parser.parseKeyword(&name) ||
462 parser.parseComma() || parseNestedBaseType(type, parser) ||
463 parser.parseGreater())
464 return failure();
465
466 return result =
467 BaseTypeAliasType::get(StringAttr::get(context, name), type),
468 success();
469 }
470 if (name == "fstring") {
471 return result = FStringType::get(context), success();
472 }
473
474 return {};
475}
476
477/// Parse a type defined by this dialect.
478///
479/// This will first try the generated type parsers and then resort to the custom
480/// parser implementation. Emits an error and returns failure if `name` does not
481/// refer to a type defined in this dialect.
482static ParseResult parseType(Type &result, StringRef name, AsmParser &parser) {
483 // Try the custom type parser.
484 OptionalParseResult parseResult = customTypeParser(parser, name, result);
485 if (parseResult.has_value())
486 return parseResult.value();
487
488 // None of the above recognized the type, so we bail.
489 parser.emitError(parser.getNameLoc(), "unknown FIRRTL dialect type: \"")
490 << name << "\"";
491 return failure();
492}
493
494/// Parse a `FIRRTLType` with a `name` that has already been parsed.
495///
496/// Note that only a subset of types defined in the FIRRTL dialect inherit from
497/// `FIRRTLType`. Use `parseType` to parse *any* of the defined types.
498static ParseResult parseFIRRTLType(FIRRTLType &result, StringRef name,
499 AsmParser &parser) {
500 Type type;
501 if (failed(parseType(type, name, parser)))
502 return failure();
503 result = type_dyn_cast<FIRRTLType>(type);
504 if (result)
505 return success();
506 parser.emitError(parser.getNameLoc(), "unknown FIRRTL type: \"")
507 << name << "\"";
508 return failure();
509}
510
511static ParseResult parseFIRRTLBaseType(FIRRTLBaseType &result, StringRef name,
512 AsmParser &parser) {
513 FIRRTLType type;
514 if (failed(parseFIRRTLType(type, name, parser)))
515 return failure();
516 if (auto base = type_dyn_cast<FIRRTLBaseType>(type)) {
517 result = base;
518 return success();
519 }
520 parser.emitError(parser.getNameLoc(), "expected base type, found ") << type;
521 return failure();
522}
523
524static ParseResult parseFIRRTLPropertyType(PropertyType &result, StringRef name,
525 AsmParser &parser) {
526 FIRRTLType type;
527 if (failed(parseFIRRTLType(type, name, parser)))
528 return failure();
529 if (auto prop = type_dyn_cast<PropertyType>(type)) {
530 result = prop;
531 return success();
532 }
533 parser.emitError(parser.getNameLoc(), "expected property type, found ")
534 << type;
535 return failure();
536}
537
538// NOLINTBEGIN(misc-no-recursion)
539/// Parse a `FIRRTLType`.
540///
541/// Note that only a subset of types defined in the FIRRTL dialect inherit from
542/// `FIRRTLType`. Use `parseType` to parse *any* of the defined types.
544 AsmParser &parser) {
545 StringRef name;
546 if (parser.parseKeyword(&name))
547 return failure();
548 return parseFIRRTLType(result, name, parser);
549}
550// NOLINTEND(misc-no-recursion)
551
552// NOLINTBEGIN(misc-no-recursion)
554 AsmParser &parser) {
555 StringRef name;
556 if (parser.parseKeyword(&name))
557 return failure();
558 return parseFIRRTLBaseType(result, name, parser);
559}
560// NOLINTEND(misc-no-recursion)
561
562// NOLINTBEGIN(misc-no-recursion)
564 AsmParser &parser) {
565 StringRef name;
566 if (parser.parseKeyword(&name))
567 return failure();
568 return parseFIRRTLPropertyType(result, name, parser);
569}
570// NOLINTEND(misc-no-recursion)
571
572//===----------------------------------------------------------------------===//
573// Dialect Type Parsing and Printing
574//===----------------------------------------------------------------------===//
575
576/// Print a type registered to this dialect.
577void FIRRTLDialect::printType(Type type, DialectAsmPrinter &os) const {
578 printNestedType(type, os);
579}
580
581/// Parse a type registered to this dialect.
582Type FIRRTLDialect::parseType(DialectAsmParser &parser) const {
583 StringRef name;
584 Type result;
585 if (parser.parseKeyword(&name) || ::parseType(result, name, parser))
586 return Type();
587 return result;
588}
589
590//===----------------------------------------------------------------------===//
591// Recursive Type Properties
592//===----------------------------------------------------------------------===//
593
594enum {
595 /// Bit set if the type only contains passive elements.
597 /// Bit set if the type contains an analog type.
599 /// Bit set if the type has any uninferred bit widths.
601};
602
603//===----------------------------------------------------------------------===//
604// FIRRTLBaseType Implementation
605//===----------------------------------------------------------------------===//
606
608 // Use `char` instead of `bool` since llvm already provides a
609 // DenseMapInfo<char> specialization
610 using KeyTy = char;
611
612 FIRRTLBaseTypeStorage(bool isConst) : isConst(static_cast<char>(isConst)) {}
613
614 bool operator==(const KeyTy &key) const { return key == isConst; }
615
616 KeyTy getAsKey() const { return isConst; }
617
618 static FIRRTLBaseTypeStorage *construct(TypeStorageAllocator &allocator,
619 KeyTy key) {
620 return new (allocator.allocate<FIRRTLBaseTypeStorage>())
622 }
623
625};
626
627/// Return true if this is a 'ground' type, aka a non-aggregate type.
628bool FIRRTLType::isGround() {
629 return TypeSwitch<FIRRTLType, bool>(*this)
630 .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
631 AnalogType>([](Type) { return true; })
632 .Case<BundleType, FVectorType, FEnumType, OpenBundleType, OpenVectorType>(
633 [](Type) { return false; })
634 .Case<BaseTypeAliasType>([](BaseTypeAliasType alias) {
635 return alias.getAnonymousType().isGround();
636 })
637 // Not ground per spec, but leaf of aggregate.
638 .Case<PropertyType, RefType>([](Type) { return false; })
639 .Default([](Type) {
640 llvm_unreachable("unknown FIRRTL type");
641 return false;
642 });
643}
644
645bool FIRRTLType::isConst() {
646 return TypeSwitch<FIRRTLType, bool>(*this)
647 .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
648 [](auto type) { return type.isConst(); })
649 .Default(false);
650}
651
652bool FIRRTLBaseType::isConst() { return getImpl()->isConst; }
653
655 return TypeSwitch<FIRRTLType, RecursiveTypeProperties>(*this)
656 .Case<ClockType, ResetType, AsyncResetType>([](FIRRTLBaseType type) {
657 return RecursiveTypeProperties{true,
658 false,
659 false,
660 type.isConst(),
661 false,
662 false,
663 firrtl::type_isa<ResetType>(type)};
664 })
665 .Case<SIntType, UIntType>([](auto type) {
667 true, false, false, type.isConst(), false, !type.hasWidth(), false};
668 })
669 .Case<AnalogType>([](auto type) {
671 true, false, true, type.isConst(), false, !type.hasWidth(), false};
672 })
673 .Case<BundleType, FVectorType, FEnumType, OpenBundleType, OpenVectorType,
674 RefType, BaseTypeAliasType>(
675 [](auto type) { return type.getRecursiveTypeProperties(); })
676 .Case<PropertyType>([](auto type) {
677 return RecursiveTypeProperties{true, false, false, false,
678 false, false, false};
679 })
680 .Case<LHSType>(
681 [](auto type) { return type.getType().getRecursiveTypeProperties(); })
682 .Case<FStringType>([](auto type) {
683 return RecursiveTypeProperties{true, false, false, false,
684 false, false, false};
685 })
686 .Default([](Type) {
687 llvm_unreachable("unknown FIRRTL type");
689 });
690}
691
692/// Return this type with any type aliases recursively removed from itself.
694 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
695 .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
696 AnalogType>([&](Type) { return *this; })
697 .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
698 [](auto type) { return type.getAnonymousType(); })
699 .Default([](Type) {
700 llvm_unreachable("unknown FIRRTL type");
701 return FIRRTLBaseType();
702 });
703}
704
705/// Return this type with any flip types recursively removed from itself.
707 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
708 .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
709 AnalogType, FEnumType>([&](Type) { return *this; })
710 .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
711 [](auto type) { return type.getPassiveType(); })
712 .Default([](Type) {
713 llvm_unreachable("unknown FIRRTL type");
714 return FIRRTLBaseType();
715 });
716}
717
718/// Return a 'const' or non-'const' version of this type.
720 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
721 .Case<ClockType, ResetType, AsyncResetType, AnalogType, SIntType,
722 UIntType, BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
723 [&](auto type) { return type.getConstType(isConst); })
724 .Default([](Type) {
725 llvm_unreachable("unknown FIRRTL type");
726 return FIRRTLBaseType();
727 });
728}
729
730/// Return this type with all 'const' modifiers dropped.
732 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
733 .Case<ClockType, ResetType, AsyncResetType, AnalogType, SIntType,
734 UIntType>([&](auto type) { return type.getConstType(false); })
735 .Case<BundleType, FVectorType, FEnumType, BaseTypeAliasType>(
736 [&](auto type) { return type.getAllConstDroppedType(); })
737 .Default([](Type) {
738 llvm_unreachable("unknown FIRRTL type");
739 return FIRRTLBaseType();
740 });
741}
742
743/// Return this type with all ground types replaced with UInt<1>. This is
744/// used for `mem` operations.
746 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
747 .Case<ClockType, ResetType, AsyncResetType, SIntType, UIntType,
748 AnalogType>([&](Type) {
749 return UIntType::get(this->getContext(), 1, this->isConst());
750 })
751 .Case<BundleType>([&](BundleType bundleType) {
752 SmallVector<BundleType::BundleElement, 4> newElements;
753 newElements.reserve(bundleType.getElements().size());
754 for (auto elt : bundleType)
755 newElements.push_back(
756 {elt.name, false /* FIXME */, elt.type.getMaskType()});
757 return BundleType::get(this->getContext(), newElements,
758 bundleType.isConst());
759 })
760 .Case<FVectorType>([](FVectorType vectorType) {
761 return FVectorType::get(vectorType.getElementType().getMaskType(),
762 vectorType.getNumElements(),
763 vectorType.isConst());
764 })
765 .Case<BaseTypeAliasType>([](BaseTypeAliasType base) {
766 return base.getModifiedType(base.getInnerType().getMaskType());
767 })
768 .Default([](Type) {
769 llvm_unreachable("unknown FIRRTL type");
770 return FIRRTLBaseType();
771 });
772}
773
774/// Remove the widths from this type. All widths are replaced with an
775/// unknown width.
777 return TypeSwitch<FIRRTLBaseType, FIRRTLBaseType>(*this)
778 .Case<ClockType, ResetType, AsyncResetType>([](auto a) { return a; })
779 .Case<UIntType, SIntType, AnalogType>(
780 [&](auto a) { return a.get(this->getContext(), -1, a.isConst()); })
781 .Case<BundleType>([&](auto a) {
782 SmallVector<BundleType::BundleElement, 4> newElements;
783 newElements.reserve(a.getElements().size());
784 for (auto elt : a)
785 newElements.push_back(
786 {elt.name, elt.isFlip, elt.type.getWidthlessType()});
787 return BundleType::get(this->getContext(), newElements, a.isConst());
788 })
789 .Case<FVectorType>([](auto a) {
790 return FVectorType::get(a.getElementType().getWidthlessType(),
791 a.getNumElements(), a.isConst());
792 })
793 .Case<FEnumType>([&](FEnumType a) {
794 SmallVector<FEnumType::EnumElement, 4> newElements;
795 newElements.reserve(a.getNumElements());
796 for (auto elt : a)
797 newElements.push_back({elt.name, elt.type.getWidthlessType()});
798 return FEnumType::get(this->getContext(), newElements, a.isConst());
799 })
800 .Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
801 return type.getModifiedType(type.getInnerType().getWidthlessType());
802 })
803 .Default([](auto) {
804 llvm_unreachable("unknown FIRRTL type");
805 return FIRRTLBaseType();
806 });
807}
808
809/// If this is an IntType, AnalogType, or sugar type for a single bit (Clock,
810/// Reset, etc) then return the bitwidth. Return -1 if this is one of these
811/// types but without a specified bitwidth. Return -2 if this isn't a simple
812/// type.
814 return TypeSwitch<FIRRTLBaseType, int32_t>(*this)
815 .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
816 .Case<SIntType, UIntType>(
817 [&](IntType intType) { return intType.getWidthOrSentinel(); })
818 .Case<AnalogType>(
819 [](AnalogType analogType) { return analogType.getWidthOrSentinel(); })
820 .Case<BundleType, FVectorType, FEnumType>([](Type) { return -2; })
821 .Case<BaseTypeAliasType>([](BaseTypeAliasType type) {
822 // It's faster to use its anonymous type.
823 return type.getAnonymousType().getBitWidthOrSentinel();
824 })
825 .Default([](Type) {
826 llvm_unreachable("unknown FIRRTL type");
827 return -2;
828 });
829}
830
831/// Return true if this is a type usable as a reset. This must be
832/// either an abstract reset, a concrete 1-bit UInt, an
833/// asynchronous reset, or an uninferred width UInt.
835 return TypeSwitch<FIRRTLType, bool>(*this)
836 .Case<ResetType, AsyncResetType>([](Type) { return true; })
837 .Case<UIntType>(
838 [](UIntType a) { return !a.hasWidth() || a.getWidth() == 1; })
839 .Case<BaseTypeAliasType>(
840 [](auto type) { return type.getInnerType().isResetType(); })
841 .Default([](Type) { return false; });
842}
843
844bool firrtl::isConst(Type type) {
845 return TypeSwitch<Type, bool>(type)
846 .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
847 [](auto base) { return base.isConst(); })
848 .Default(false);
849}
850
851bool firrtl::containsConst(Type type) {
852 return TypeSwitch<Type, bool>(type)
853 .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>(
854 [](auto base) { return base.containsConst(); })
855 .Default(false);
856}
857
858// NOLINTBEGIN(misc-no-recursion)
861 .Case<BundleType>([&](auto bundle) {
862 for (size_t i = 0, e = bundle.getNumElements(); i < e; ++i) {
863 auto elt = bundle.getElement(i);
864 if (hasZeroBitWidth(elt.type))
865 return true;
866 }
867 return bundle.getNumElements() == 0;
868 })
869 .Case<FVectorType>([&](auto vector) {
870 if (vector.getNumElements() == 0)
871 return true;
872 return hasZeroBitWidth(vector.getElementType());
873 })
874 .Case<FIRRTLBaseType>([](auto groundType) {
875 return firrtl::getBitWidth(groundType).value_or(0) == 0;
876 })
877 .Case<RefType>([](auto ref) { return hasZeroBitWidth(ref.getType()); })
878 .Default([](auto) { return false; });
879}
880// NOLINTEND(misc-no-recursion)
881
882/// Helper to implement the equivalence logic for a pair of bundle elements.
883/// Note that the FIRRTL spec requires bundle elements to have the same
884/// orientation, but this only compares their passive types. The FIRRTL dialect
885/// differs from the spec in how it uses flip types for module output ports and
886/// canonicalizes flips in bundles, so only passive types can be compared here.
887static bool areBundleElementsEquivalent(BundleType::BundleElement destElement,
888 BundleType::BundleElement srcElement,
889 bool destOuterTypeIsConst,
890 bool srcOuterTypeIsConst,
891 bool requiresSameWidth) {
892 if (destElement.name != srcElement.name)
893 return false;
894 if (destElement.isFlip != srcElement.isFlip)
895 return false;
896
897 if (destElement.isFlip) {
898 std::swap(destElement, srcElement);
899 std::swap(destOuterTypeIsConst, srcOuterTypeIsConst);
900 }
901
902 return areTypesEquivalent(destElement.type, srcElement.type,
903 destOuterTypeIsConst, srcOuterTypeIsConst,
904 requiresSameWidth);
905}
906
907/// Returns whether the two types are equivalent. This implements the exact
908/// definition of type equivalence in the FIRRTL spec. If the types being
909/// compared have any outer flips that encode FIRRTL module directions (input or
910/// output), these should be stripped before using this method.
912 bool destOuterTypeIsConst,
913 bool srcOuterTypeIsConst,
914 bool requireSameWidths) {
915 auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
916 auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
917
918 // For non-base types, only equivalent if identical.
919 if (!destType || !srcType)
920 return destFType == srcFType;
921
922 bool srcIsConst = srcOuterTypeIsConst || srcFType.isConst();
923 bool destIsConst = destOuterTypeIsConst || destFType.isConst();
924
925 // Vector types can be connected if they have the same size and element type.
926 auto destVectorType = type_dyn_cast<FVectorType>(destType);
927 auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
928 if (destVectorType && srcVectorType)
929 return destVectorType.getNumElements() == srcVectorType.getNumElements() &&
930 areTypesEquivalent(destVectorType.getElementType(),
931 srcVectorType.getElementType(), destIsConst,
932 srcIsConst, requireSameWidths);
933
934 // Bundle types can be connected if they have the same size, element names,
935 // and element types.
936 auto destBundleType = type_dyn_cast<BundleType>(destType);
937 auto srcBundleType = type_dyn_cast<BundleType>(srcType);
938 if (destBundleType && srcBundleType) {
939 auto destElements = destBundleType.getElements();
940 auto srcElements = srcBundleType.getElements();
941 size_t numDestElements = destElements.size();
942 if (numDestElements != srcElements.size())
943 return false;
944
945 for (size_t i = 0; i < numDestElements; ++i) {
946 auto destElement = destElements[i];
947 auto srcElement = srcElements[i];
948 if (!areBundleElementsEquivalent(destElement, srcElement, destIsConst,
949 srcIsConst, requireSameWidths))
950 return false;
951 }
952 return true;
953 }
954
955 // Enum types can be connected if they have the same size, element names, and
956 // element types.
957 auto dstEnumType = type_dyn_cast<FEnumType>(destType);
958 auto srcEnumType = type_dyn_cast<FEnumType>(srcType);
959
960 if (dstEnumType && srcEnumType) {
961 if (dstEnumType.getNumElements() != srcEnumType.getNumElements())
962 return false;
963 // Enums requires the types to match exactly.
964 for (const auto &[dst, src] : llvm::zip(dstEnumType, srcEnumType)) {
965 // The variant names must match.
966 if (dst.name != src.name)
967 return false;
968 // Enumeration types can only be connected if the inner types have the
969 // same width.
970 if (!areTypesEquivalent(dst.type, src.type, destIsConst, srcIsConst,
971 true))
972 return false;
973 }
974 return true;
975 }
976
977 // Ground type connections must be const compatible.
978 if (destIsConst && !srcIsConst)
979 return false;
980
981 // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
982 if (firrtl::type_isa<ResetType>(destType))
983 return srcType.isResetType();
984
985 // Reset types can drive UInt<1>, AsyncReset, or Reset types.
986 if (firrtl::type_isa<ResetType>(srcType))
987 return destType.isResetType();
988
989 // If we can implicitly truncate or extend the bitwidth, or either width is
990 // currently uninferred, then compare the widthless version of these types.
991 if (!requireSameWidths || destType.getBitWidthOrSentinel() == -1)
992 srcType = srcType.getWidthlessType();
993 if (!requireSameWidths || srcType.getBitWidthOrSentinel() == -1)
994 destType = destType.getWidthlessType();
995
996 // Ground types can be connected if their constless types are the same
997 return destType.getConstType(false) == srcType.getConstType(false);
998}
999
1000/// Returns whether the srcType can be const-casted to the destType.
1002 bool srcOuterTypeIsConst) {
1003 // Identical types are always castable
1004 if (destFType == srcFType)
1005 return true;
1006
1007 auto destType = type_dyn_cast<FIRRTLBaseType>(destFType);
1008 auto srcType = type_dyn_cast<FIRRTLBaseType>(srcFType);
1009
1010 // For non-base types, only castable if identical.
1011 if (!destType || !srcType)
1012 return false;
1013
1014 // Types must be passive
1015 if (!destType.isPassive() || !srcType.isPassive())
1016 return false;
1017
1018 bool srcIsConst = srcType.isConst() || srcOuterTypeIsConst;
1019
1020 // Cannot cast non-'const' src to 'const' dest
1021 if (destType.isConst() && !srcIsConst)
1022 return false;
1023
1024 // Vector types can be casted if they have the same size and castable element
1025 // type.
1026 auto destVectorType = type_dyn_cast<FVectorType>(destType);
1027 auto srcVectorType = type_dyn_cast<FVectorType>(srcType);
1028 if (destVectorType && srcVectorType)
1029 return destVectorType.getNumElements() == srcVectorType.getNumElements() &&
1030 areTypesConstCastable(destVectorType.getElementType(),
1031 srcVectorType.getElementType(), srcIsConst);
1032 if (destVectorType != srcVectorType)
1033 return false;
1034
1035 // Bundle types can be casted if they have the same size, element names,
1036 // and castable element types.
1037 auto destBundleType = type_dyn_cast<BundleType>(destType);
1038 auto srcBundleType = type_dyn_cast<BundleType>(srcType);
1039 if (destBundleType && srcBundleType) {
1040 auto destElements = destBundleType.getElements();
1041 auto srcElements = srcBundleType.getElements();
1042 size_t numDestElements = destElements.size();
1043 if (numDestElements != srcElements.size())
1044 return false;
1045
1046 return llvm::all_of_zip(
1047 destElements, srcElements,
1048 [&](const auto &destElement, const auto &srcElement) {
1049 return destElement.name == srcElement.name &&
1050 areTypesConstCastable(destElement.type, srcElement.type,
1051 srcIsConst);
1052 });
1053 }
1054 if (destBundleType != srcBundleType)
1055 return false;
1056
1057 // Ground types can be casted if the source type is a const
1058 // version of the destination type
1059 return destType == srcType.getConstType(destType.isConst());
1060}
1061
1062bool firrtl::areTypesRefCastable(Type dstType, Type srcType) {
1063 auto dstRefType = type_dyn_cast<RefType>(dstType);
1064 auto srcRefType = type_dyn_cast<RefType>(srcType);
1065 if (!dstRefType || !srcRefType)
1066 return false;
1067 if (dstRefType == srcRefType)
1068 return true;
1069 if (dstRefType.getForceable() && !srcRefType.getForceable())
1070 return false;
1071
1072 // Okay walk the types recursively. They must be identical "structurally"
1073 // with exception leaf (ground) types of destination can be uninferred
1074 // versions of the corresponding source type. (can lose width information or
1075 // become a more general reset type)
1076 // In addition, while not explicitly in spec its useful to allow probes
1077 // to have const cast away, especially for probes of literals and expressions
1078 // derived from them. Check const as with const cast.
1079 // NOLINTBEGIN(misc-no-recursion)
1080 auto recurse = [&](auto &&f, FIRRTLBaseType dest, FIRRTLBaseType src,
1081 bool srcOuterTypeIsConst) -> bool {
1082 // Fast-path for identical types.
1083 if (dest == src)
1084 return true;
1085
1086 // Always passive inside probes, but for sanity assert this.
1087 assert(dest.isPassive() && src.isPassive());
1088
1089 bool srcIsConst = src.isConst() || srcOuterTypeIsConst;
1090
1091 // Cannot cast non-'const' src to 'const' dest
1092 if (dest.isConst() && !srcIsConst)
1093 return false;
1094
1095 // Recurse through aggregates to get the leaves, checking
1096 // structural equivalence re:element count + names.
1097
1098 if (auto destVectorType = type_dyn_cast<FVectorType>(dest)) {
1099 auto srcVectorType = type_dyn_cast<FVectorType>(src);
1100 return srcVectorType &&
1101 destVectorType.getNumElements() ==
1102 srcVectorType.getNumElements() &&
1103 f(f, destVectorType.getElementType(),
1104 srcVectorType.getElementType(), srcIsConst);
1105 }
1106
1107 if (auto destBundleType = type_dyn_cast<BundleType>(dest)) {
1108 auto srcBundleType = type_dyn_cast<BundleType>(src);
1109 if (!srcBundleType)
1110 return false;
1111 // (no need to check orientation, these are always passive)
1112 auto destElements = destBundleType.getElements();
1113 auto srcElements = srcBundleType.getElements();
1114
1115 return destElements.size() == srcElements.size() &&
1116 llvm::all_of_zip(
1117 destElements, srcElements,
1118 [&](const auto &destElement, const auto &srcElement) {
1119 return destElement.name == srcElement.name &&
1120 f(f, destElement.type, srcElement.type, srcIsConst);
1121 });
1122 }
1123
1124 if (auto destEnumType = type_dyn_cast<FEnumType>(dest)) {
1125 auto srcEnumType = type_dyn_cast<FEnumType>(src);
1126 if (!srcEnumType)
1127 return false;
1128 auto destElements = destEnumType.getElements();
1129 auto srcElements = srcEnumType.getElements();
1130
1131 return destElements.size() == srcElements.size() &&
1132 llvm::all_of_zip(
1133 destElements, srcElements,
1134 [&](const auto &destElement, const auto &srcElement) {
1135 return destElement.name == srcElement.name &&
1136 f(f, destElement.type, srcElement.type, srcIsConst);
1137 });
1138 }
1139
1140 // Reset types can be driven by UInt<1>, AsyncReset, or Reset types.
1141 if (type_isa<ResetType>(dest))
1142 return src.isResetType();
1143 // (but don't allow the other direction, can only become more general)
1144
1145 // Compare against const src if dest is const.
1146 src = src.getConstType(dest.isConst());
1147
1148 // Compare against widthless src if dest is widthless.
1149 if (dest.getBitWidthOrSentinel() == -1)
1150 src = src.getWidthlessType();
1151
1152 return dest == src;
1153 };
1154
1155 return recurse(recurse, dstRefType.getType(), srcRefType.getType(), false);
1156 // NOLINTEND(misc-no-recursion)
1157}
1158
1159// NOLINTBEGIN(misc-no-recursion)
1160/// Returns true if the destination is at least as wide as an equivalent source.
1162 return TypeSwitch<FIRRTLBaseType, bool>(dstType)
1163 .Case<BundleType>([&](auto dstBundle) {
1164 auto srcBundle = type_cast<BundleType>(srcType);
1165 for (size_t i = 0, n = dstBundle.getNumElements(); i < n; ++i) {
1166 auto srcElem = srcBundle.getElement(i);
1167 auto dstElem = dstBundle.getElement(i);
1168 if (dstElem.isFlip) {
1169 if (!isTypeLarger(srcElem.type, dstElem.type))
1170 return false;
1171 } else {
1172 if (!isTypeLarger(dstElem.type, srcElem.type))
1173 return false;
1174 }
1175 }
1176 return true;
1177 })
1178 .Case<FVectorType>([&](auto vector) {
1179 return isTypeLarger(vector.getElementType(),
1180 type_cast<FVectorType>(srcType).getElementType());
1181 })
1182 .Default([&](auto dstGround) {
1183 int32_t destWidth = dstType.getPassiveType().getBitWidthOrSentinel();
1184 int32_t srcWidth = srcType.getPassiveType().getBitWidthOrSentinel();
1185 return destWidth <= -1 || srcWidth <= -1 || destWidth >= srcWidth;
1186 });
1187}
1188// NOLINTEND(misc-no-recursion)
1189
1194
1195bool firrtl::areAnonymousTypesEquivalent(mlir::Type lhs, mlir::Type rhs) {
1196 if (auto destBaseType = type_dyn_cast<FIRRTLBaseType>(lhs))
1197 if (auto srcBaseType = type_dyn_cast<FIRRTLBaseType>(rhs))
1198 return areAnonymousTypesEquivalent(destBaseType, srcBaseType);
1199
1200 if (auto destRefType = type_dyn_cast<RefType>(lhs))
1201 if (auto srcRefType = type_dyn_cast<RefType>(rhs))
1202 return areAnonymousTypesEquivalent(destRefType.getType(),
1203 srcRefType.getType());
1204
1205 return lhs == rhs;
1206}
1207
1208/// Return the passive version of a firrtl type
1209/// top level for ODS constraint usage
1210Type firrtl::getPassiveType(Type anyBaseFIRRTLType) {
1211 return type_cast<FIRRTLBaseType>(anyBaseFIRRTLType).getPassiveType();
1212}
1213
1214bool firrtl::isTypeInOut(Type type) {
1215 return llvm::TypeSwitch<Type, bool>(type)
1216 .Case<FIRRTLBaseType>([](auto type) {
1217 return !type.containsReference() &&
1218 (!type.isPassive() || type.containsAnalog());
1219 })
1220 .Default(false);
1221}
1222
1223//===----------------------------------------------------------------------===//
1224// IntType
1225//===----------------------------------------------------------------------===//
1226
1227/// Return a SIntType or UIntType with the specified signedness, width, and
1228/// constness
1229IntType IntType::get(MLIRContext *context, bool isSigned,
1230 int32_t widthOrSentinel, bool isConst) {
1231 if (isSigned)
1232 return SIntType::get(context, widthOrSentinel, isConst);
1233 return UIntType::get(context, widthOrSentinel, isConst);
1234}
1235
1237 if (auto sintType = type_dyn_cast<SIntType>(*this))
1238 return sintType.getWidthOrSentinel();
1239 if (auto uintType = type_dyn_cast<UIntType>(*this))
1240 return uintType.getWidthOrSentinel();
1241 return -1;
1242}
1243
1244//===----------------------------------------------------------------------===//
1245// WidthTypeStorage
1246//===----------------------------------------------------------------------===//
1247
1251 using KeyTy = std::pair<int32_t, char>;
1252
1253 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1254
1255 KeyTy getAsKey() const { return KeyTy(width, isConst); }
1256
1257 static WidthTypeStorage *construct(TypeStorageAllocator &allocator,
1258 const KeyTy &key) {
1259 return new (allocator.allocate<WidthTypeStorage>())
1260 WidthTypeStorage(key.first, key.second);
1261 }
1262
1263 int32_t width;
1264};
1265
1267
1268 if (auto sIntType = type_dyn_cast<SIntType>(*this))
1269 return sIntType.getConstType(isConst);
1270 return type_cast<UIntType>(*this).getConstType(isConst);
1271}
1272
1273//===----------------------------------------------------------------------===//
1274// SIntType
1275//===----------------------------------------------------------------------===//
1276
1277SIntType SIntType::get(MLIRContext *context) { return get(context, -1, false); }
1278
1279SIntType SIntType::get(MLIRContext *context, std::optional<int32_t> width,
1280 bool isConst) {
1281 return get(context, width ? *width : -1, isConst);
1282}
1283
1284LogicalResult SIntType::verify(function_ref<InFlightDiagnostic()> emitError,
1285 int32_t widthOrSentinel, bool isConst) {
1286 if (widthOrSentinel < -1)
1287 return emitError() << "invalid width";
1288 return success();
1289}
1290
1291int32_t SIntType::getWidthOrSentinel() const { return getImpl()->width; }
1292
1293SIntType SIntType::getConstType(bool isConst) {
1294 if (isConst == this->isConst())
1295 return *this;
1296 return get(getContext(), getWidthOrSentinel(), isConst);
1297}
1298
1299//===----------------------------------------------------------------------===//
1300// UIntType
1301//===----------------------------------------------------------------------===//
1302
1303UIntType UIntType::get(MLIRContext *context) { return get(context, -1, false); }
1304
1305UIntType UIntType::get(MLIRContext *context, std::optional<int32_t> width,
1306 bool isConst) {
1307 return get(context, width ? *width : -1, isConst);
1308}
1309
1310LogicalResult UIntType::verify(function_ref<InFlightDiagnostic()> emitError,
1311 int32_t widthOrSentinel, bool isConst) {
1312 if (widthOrSentinel < -1)
1313 return emitError() << "invalid width";
1314 return success();
1315}
1316
1317int32_t UIntType::getWidthOrSentinel() const { return getImpl()->width; }
1318
1319UIntType UIntType::getConstType(bool isConst) {
1320 if (isConst == this->isConst())
1321 return *this;
1322 return get(getContext(), getWidthOrSentinel(), isConst);
1323}
1324
1325//===----------------------------------------------------------------------===//
1326// Bundle Type
1327//===----------------------------------------------------------------------===//
1328
1331 using KeyTy = std::pair<ArrayRef<BundleType::BundleElement>, char>;
1332
1333 BundleTypeStorage(ArrayRef<BundleType::BundleElement> elements, bool isConst)
1335 elements(elements.begin(), elements.end()), props{true, false, false,
1336 isConst, false, false,
1337 false} {
1338 uint64_t fieldID = 0;
1339 fieldIDs.reserve(elements.size());
1340 for (auto &element : elements) {
1341 auto type = element.type;
1342 auto eltInfo = type.getRecursiveTypeProperties();
1343 props.isPassive &= eltInfo.isPassive & !element.isFlip;
1344 props.containsAnalog |= eltInfo.containsAnalog;
1345 props.containsReference |= eltInfo.containsReference;
1346 props.containsConst |= eltInfo.containsConst;
1347 props.containsTypeAlias |= eltInfo.containsTypeAlias;
1348 props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
1349 props.hasUninferredReset |= eltInfo.hasUninferredReset;
1350 fieldID += 1;
1351 fieldIDs.push_back(fieldID);
1352 // Increment the field ID for the next field by the number of subfields.
1353 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
1354 }
1355 maxFieldID = fieldID;
1356 }
1357
1358 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1359
1360 KeyTy getAsKey() const { return KeyTy(elements, isConst); }
1361
1362 static llvm::hash_code hashKey(const KeyTy &key) {
1363 return llvm::hash_combine(
1364 llvm::hash_combine_range(key.first.begin(), key.first.end()),
1365 key.second);
1366 }
1367
1368 static BundleTypeStorage *construct(TypeStorageAllocator &allocator,
1369 KeyTy key) {
1370 return new (allocator.allocate<BundleTypeStorage>())
1371 BundleTypeStorage(key.first, static_cast<bool>(key.second));
1372 }
1373
1374 SmallVector<BundleType::BundleElement, 4> elements;
1375 SmallVector<uint64_t, 4> fieldIDs;
1376 uint64_t maxFieldID;
1377
1378 /// This holds the bits for the type's recursive properties, and can hold a
1379 /// pointer to a passive version of the type.
1381 BundleType passiveType;
1382 BundleType anonymousType;
1383};
1384
1385BundleType BundleType::get(MLIRContext *context,
1386 ArrayRef<BundleElement> elements, bool isConst) {
1387 return Base::get(context, elements, isConst);
1388}
1389
1390auto BundleType::getElements() const -> ArrayRef<BundleElement> {
1391 return getImpl()->elements;
1392}
1393
1394/// Return a pair with the 'isPassive' and 'containsAnalog' bits.
1395RecursiveTypeProperties BundleType::getRecursiveTypeProperties() const {
1396 return getImpl()->props;
1397}
1398
1399/// Return this type with any flip types recursively removed from itself.
1400FIRRTLBaseType BundleType::getPassiveType() {
1401 auto *impl = getImpl();
1402
1403 // If we've already determined and cached the passive type, use it.
1404 if (impl->passiveType)
1405 return impl->passiveType;
1406
1407 // If this type is already passive, use it and remember for next time.
1408 if (impl->props.isPassive) {
1409 impl->passiveType = *this;
1410 return *this;
1411 }
1412
1413 // Otherwise at least one element is non-passive, rebuild a passive version.
1414 SmallVector<BundleType::BundleElement, 16> newElements;
1415 newElements.reserve(impl->elements.size());
1416 for (auto &elt : impl->elements) {
1417 newElements.push_back({elt.name, false, elt.type.getPassiveType()});
1418 }
1419
1420 auto passiveType = BundleType::get(getContext(), newElements, isConst());
1421 impl->passiveType = passiveType;
1422 return passiveType;
1423}
1424
1425BundleType BundleType::getConstType(bool isConst) {
1426 if (isConst == this->isConst())
1427 return *this;
1428 return get(getContext(), getElements(), isConst);
1429}
1430
1431BundleType BundleType::getAllConstDroppedType() {
1432 if (!containsConst())
1433 return *this;
1434
1435 SmallVector<BundleElement> constDroppedElements(
1436 llvm::map_range(getElements(), [](BundleElement element) {
1437 element.type = element.type.getAllConstDroppedType();
1438 return element;
1439 }));
1440 return get(getContext(), constDroppedElements, false);
1441}
1442
1443std::optional<unsigned> BundleType::getElementIndex(StringAttr name) {
1444 for (const auto &it : llvm::enumerate(getElements())) {
1445 auto element = it.value();
1446 if (element.name == name) {
1447 return unsigned(it.index());
1448 }
1449 }
1450 return std::nullopt;
1451}
1452
1453std::optional<unsigned> BundleType::getElementIndex(StringRef name) {
1454 for (const auto &it : llvm::enumerate(getElements())) {
1455 auto element = it.value();
1456 if (element.name.getValue() == name) {
1457 return unsigned(it.index());
1458 }
1459 }
1460 return std::nullopt;
1461}
1462
1463StringAttr BundleType::getElementNameAttr(size_t index) {
1464 assert(index < getNumElements() &&
1465 "index must be less than number of fields in bundle");
1466 return getElements()[index].name;
1467}
1468
1469StringRef BundleType::getElementName(size_t index) {
1470 return getElementNameAttr(index).getValue();
1471}
1472
1473std::optional<BundleType::BundleElement>
1474BundleType::getElement(StringAttr name) {
1475 if (auto maybeIndex = getElementIndex(name))
1476 return getElements()[*maybeIndex];
1477 return std::nullopt;
1478}
1479
1480std::optional<BundleType::BundleElement>
1481BundleType::getElement(StringRef name) {
1482 if (auto maybeIndex = getElementIndex(name))
1483 return getElements()[*maybeIndex];
1484 return std::nullopt;
1485}
1486
1487/// Look up an element by index.
1488BundleType::BundleElement BundleType::getElement(size_t index) {
1489 assert(index < getNumElements() &&
1490 "index must be less than number of fields in bundle");
1491 return getElements()[index];
1492}
1493
1494FIRRTLBaseType BundleType::getElementType(StringAttr name) {
1495 auto element = getElement(name);
1496 return element ? element->type : FIRRTLBaseType();
1497}
1498
1499FIRRTLBaseType BundleType::getElementType(StringRef name) {
1500 auto element = getElement(name);
1501 return element ? element->type : FIRRTLBaseType();
1502}
1503
1504FIRRTLBaseType BundleType::getElementType(size_t index) const {
1505 assert(index < getNumElements() &&
1506 "index must be less than number of fields in bundle");
1507 return getElements()[index].type;
1508}
1509
1510uint64_t BundleType::getFieldID(uint64_t index) const {
1511 return getImpl()->fieldIDs[index];
1512}
1513
1514uint64_t BundleType::getIndexForFieldID(uint64_t fieldID) const {
1515 assert(!getElements().empty() && "Bundle must have >0 fields");
1516 auto fieldIDs = getImpl()->fieldIDs;
1517 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
1518 return std::distance(fieldIDs.begin(), it);
1519}
1520
1521std::pair<uint64_t, uint64_t>
1522BundleType::getIndexAndSubfieldID(uint64_t fieldID) const {
1523 auto index = getIndexForFieldID(fieldID);
1524 auto elementFieldID = getFieldID(index);
1525 return {index, fieldID - elementFieldID};
1526}
1527
1528std::pair<Type, uint64_t>
1529BundleType::getSubTypeByFieldID(uint64_t fieldID) const {
1530 if (fieldID == 0)
1531 return {*this, 0};
1532 auto subfieldIndex = getIndexForFieldID(fieldID);
1533 auto subfieldType = getElementType(subfieldIndex);
1534 auto subfieldID = fieldID - getFieldID(subfieldIndex);
1535 return {subfieldType, subfieldID};
1536}
1537
1538uint64_t BundleType::getMaxFieldID() const { return getImpl()->maxFieldID; }
1539
1540std::pair<uint64_t, bool>
1541BundleType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1542 auto childRoot = getFieldID(index);
1543 auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
1544 : (getFieldID(index + 1) - 1);
1545 return std::make_pair(fieldID - childRoot,
1546 fieldID >= childRoot && fieldID <= rangeEnd);
1547}
1548
1549bool BundleType::isConst() { return getImpl()->isConst; }
1550
1551BundleType::ElementType
1552BundleType::getElementTypePreservingConst(size_t index) {
1553 auto type = getElementType(index);
1554 return type.getConstType(type.isConst() || isConst());
1555}
1556
1557/// Return this type with any type aliases recursively removed from itself.
1558FIRRTLBaseType BundleType::getAnonymousType() {
1559 auto *impl = getImpl();
1560
1561 // If we've already determined and cached the anonymous type, use it.
1562 if (impl->anonymousType)
1563 return impl->anonymousType;
1564
1565 // If this type is already anonymous, use it and remember for next time.
1566 if (!impl->props.containsTypeAlias) {
1567 impl->anonymousType = *this;
1568 return *this;
1569 }
1570
1571 // Otherwise at least one element has an alias type, rebuild an anonymous
1572 // version.
1573 SmallVector<BundleType::BundleElement, 16> newElements;
1574 newElements.reserve(impl->elements.size());
1575 for (auto &elt : impl->elements)
1576 newElements.push_back({elt.name, elt.isFlip, elt.type.getAnonymousType()});
1577
1578 auto anonymousType = BundleType::get(getContext(), newElements, isConst());
1579 impl->anonymousType = anonymousType;
1580 return anonymousType;
1581}
1582
1583//===----------------------------------------------------------------------===//
1584// OpenBundle Type
1585//===----------------------------------------------------------------------===//
1586
1588 using KeyTy = std::pair<ArrayRef<OpenBundleType::BundleElement>, char>;
1589
1590 OpenBundleTypeStorage(ArrayRef<OpenBundleType::BundleElement> elements,
1591 bool isConst)
1592 : elements(elements.begin(), elements.end()), props{true, false, false,
1593 isConst, false, false,
1594 false},
1595 isConst(static_cast<char>(isConst)) {
1596 uint64_t fieldID = 0;
1597 fieldIDs.reserve(elements.size());
1598 for (auto &element : elements) {
1599 auto type = element.type;
1600 auto eltInfo = type.getRecursiveTypeProperties();
1601 props.isPassive &= eltInfo.isPassive & !element.isFlip;
1602 props.containsAnalog |= eltInfo.containsAnalog;
1603 props.containsReference |= eltInfo.containsReference;
1604 props.containsConst |= eltInfo.containsConst;
1605 props.containsTypeAlias |= eltInfo.containsTypeAlias;
1606 props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
1607 props.hasUninferredReset |= eltInfo.hasUninferredReset;
1608 fieldID += 1;
1609 fieldIDs.push_back(fieldID);
1610 // Increment the field ID for the next field by the number of subfields.
1611 // TODO: Maybe just have elementType be FieldIDTypeInterface ?
1612 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
1613 }
1614 maxFieldID = fieldID;
1615 }
1616
1617 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1618
1619 static llvm::hash_code hashKey(const KeyTy &key) {
1620 return llvm::hash_combine(
1621 llvm::hash_combine_range(key.first.begin(), key.first.end()),
1622 key.second);
1623 }
1624
1625 KeyTy getAsKey() const { return KeyTy(elements, isConst); }
1626
1627 static OpenBundleTypeStorage *construct(TypeStorageAllocator &allocator,
1628 KeyTy key) {
1629 return new (allocator.allocate<OpenBundleTypeStorage>())
1630 OpenBundleTypeStorage(key.first, static_cast<bool>(key.second));
1631 }
1632
1633 SmallVector<OpenBundleType::BundleElement, 4> elements;
1634 SmallVector<uint64_t, 4> fieldIDs;
1635 uint64_t maxFieldID;
1636
1637 /// This holds the bits for the type's recursive properties, and can hold a
1638 /// pointer to a passive version of the type.
1640
1641 // Whether this is 'const'.
1643};
1644
1645OpenBundleType OpenBundleType::get(MLIRContext *context,
1646 ArrayRef<BundleElement> elements,
1647 bool isConst) {
1648 return Base::get(context, elements, isConst);
1649}
1650
1651auto OpenBundleType::getElements() const -> ArrayRef<BundleElement> {
1652 return getImpl()->elements;
1653}
1654
1655/// Return a pair with the 'isPassive' and 'containsAnalog' bits.
1656RecursiveTypeProperties OpenBundleType::getRecursiveTypeProperties() const {
1657 return getImpl()->props;
1658}
1659
1660OpenBundleType OpenBundleType::getConstType(bool isConst) {
1661 if (isConst == this->isConst())
1662 return *this;
1663 return get(getContext(), getElements(), isConst);
1664}
1665
1666std::optional<unsigned> OpenBundleType::getElementIndex(StringAttr name) {
1667 for (const auto &it : llvm::enumerate(getElements())) {
1668 auto element = it.value();
1669 if (element.name == name) {
1670 return unsigned(it.index());
1671 }
1672 }
1673 return std::nullopt;
1674}
1675
1676std::optional<unsigned> OpenBundleType::getElementIndex(StringRef name) {
1677 for (const auto &it : llvm::enumerate(getElements())) {
1678 auto element = it.value();
1679 if (element.name.getValue() == name) {
1680 return unsigned(it.index());
1681 }
1682 }
1683 return std::nullopt;
1684}
1685
1686StringAttr OpenBundleType::getElementNameAttr(size_t index) {
1687 assert(index < getNumElements() &&
1688 "index must be less than number of fields in bundle");
1689 return getElements()[index].name;
1690}
1691
1692StringRef OpenBundleType::getElementName(size_t index) {
1693 return getElementNameAttr(index).getValue();
1694}
1695
1696std::optional<OpenBundleType::BundleElement>
1697OpenBundleType::getElement(StringAttr name) {
1698 if (auto maybeIndex = getElementIndex(name))
1699 return getElements()[*maybeIndex];
1700 return std::nullopt;
1701}
1702
1703std::optional<OpenBundleType::BundleElement>
1704OpenBundleType::getElement(StringRef name) {
1705 if (auto maybeIndex = getElementIndex(name))
1706 return getElements()[*maybeIndex];
1707 return std::nullopt;
1708}
1709
1710/// Look up an element by index.
1711OpenBundleType::BundleElement OpenBundleType::getElement(size_t index) {
1712 assert(index < getNumElements() &&
1713 "index must be less than number of fields in bundle");
1714 return getElements()[index];
1715}
1716
1717OpenBundleType::ElementType OpenBundleType::getElementType(StringAttr name) {
1718 auto element = getElement(name);
1719 return element ? element->type : FIRRTLBaseType();
1720}
1721
1722OpenBundleType::ElementType OpenBundleType::getElementType(StringRef name) {
1723 auto element = getElement(name);
1724 return element ? element->type : FIRRTLBaseType();
1725}
1726
1727OpenBundleType::ElementType OpenBundleType::getElementType(size_t index) const {
1728 assert(index < getNumElements() &&
1729 "index must be less than number of fields in bundle");
1730 return getElements()[index].type;
1731}
1732
1733uint64_t OpenBundleType::getFieldID(uint64_t index) const {
1734 return getImpl()->fieldIDs[index];
1735}
1736
1737uint64_t OpenBundleType::getIndexForFieldID(uint64_t fieldID) const {
1738 assert(!getElements().empty() && "Bundle must have >0 fields");
1739 auto fieldIDs = getImpl()->fieldIDs;
1740 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
1741 return std::distance(fieldIDs.begin(), it);
1742}
1743
1744std::pair<uint64_t, uint64_t>
1745OpenBundleType::getIndexAndSubfieldID(uint64_t fieldID) const {
1746 auto index = getIndexForFieldID(fieldID);
1747 auto elementFieldID = getFieldID(index);
1748 return {index, fieldID - elementFieldID};
1749}
1750
1751std::pair<Type, uint64_t>
1752OpenBundleType::getSubTypeByFieldID(uint64_t fieldID) const {
1753 if (fieldID == 0)
1754 return {*this, 0};
1755 auto subfieldIndex = getIndexForFieldID(fieldID);
1756 auto subfieldType = getElementType(subfieldIndex);
1757 auto subfieldID = fieldID - getFieldID(subfieldIndex);
1758 return {subfieldType, subfieldID};
1759}
1760
1761uint64_t OpenBundleType::getMaxFieldID() const { return getImpl()->maxFieldID; }
1762
1763std::pair<uint64_t, bool>
1764OpenBundleType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1765 auto childRoot = getFieldID(index);
1766 auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
1767 : (getFieldID(index + 1) - 1);
1768 return std::make_pair(fieldID - childRoot,
1769 fieldID >= childRoot && fieldID <= rangeEnd);
1770}
1771
1772bool OpenBundleType::isConst() { return getImpl()->isConst; }
1773
1774OpenBundleType::ElementType
1775OpenBundleType::getElementTypePreservingConst(size_t index) {
1776 auto type = getElementType(index);
1777 // TODO: ConstTypeInterface / Trait ?
1778 return TypeSwitch<FIRRTLType, ElementType>(type)
1779 .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>([&](auto type) {
1780 return type.getConstType(type.isConst() || isConst());
1781 })
1782 .Default(type);
1783}
1784
1785LogicalResult
1786OpenBundleType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
1787 ArrayRef<BundleElement> elements, bool isConst) {
1788 for (auto &element : elements) {
1789 if (FIRRTLType(element.type).containsReference() && isConst)
1790 return emitErrorFn()
1791 << "'const' bundle cannot have references, but element "
1792 << element.name << " has type " << element.type;
1793 if (type_isa<LHSType>(element.type))
1794 return emitErrorFn() << "bundle element " << element.name
1795 << " cannot have a left-hand side type";
1796 }
1797
1798 return success();
1799}
1800
1801//===----------------------------------------------------------------------===//
1802// FVectorType
1803//===----------------------------------------------------------------------===//
1804
1807 using KeyTy = std::tuple<FIRRTLBaseType, size_t, char>;
1808
1816
1817 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1818
1820
1821 static FVectorTypeStorage *construct(TypeStorageAllocator &allocator,
1822 KeyTy key) {
1823 return new (allocator.allocate<FVectorTypeStorage>())
1824 FVectorTypeStorage(std::get<0>(key), std::get<1>(key),
1825 static_cast<bool>(std::get<2>(key)));
1826 }
1827
1830
1831 /// This holds the bits for the type's recursive properties, and can hold a
1832 /// pointer to a passive version of the type.
1836};
1837
1838FVectorType FVectorType::get(FIRRTLBaseType elementType, size_t numElements,
1839 bool isConst) {
1840 return Base::get(elementType.getContext(), elementType, numElements, isConst);
1841}
1842
1843FIRRTLBaseType FVectorType::getElementType() const {
1844 return getImpl()->elementType;
1845}
1846
1847size_t FVectorType::getNumElements() const { return getImpl()->numElements; }
1848
1849/// Return the recursive properties of the type.
1850RecursiveTypeProperties FVectorType::getRecursiveTypeProperties() const {
1851 return getImpl()->props;
1852}
1853
1854/// Return this type with any flip types recursively removed from itself.
1855FIRRTLBaseType FVectorType::getPassiveType() {
1856 auto *impl = getImpl();
1857
1858 // If we've already determined and cached the passive type, use it.
1859 if (impl->passiveType)
1860 return impl->passiveType;
1861
1862 // If this type is already passive, return it and remember for next time.
1863 if (impl->elementType.getRecursiveTypeProperties().isPassive)
1864 return impl->passiveType = *this;
1865
1866 // Otherwise, rebuild a passive version.
1867 auto passiveType = FVectorType::get(getElementType().getPassiveType(),
1868 getNumElements(), isConst());
1869 impl->passiveType = passiveType;
1870 return passiveType;
1871}
1872
1873FVectorType FVectorType::getConstType(bool isConst) {
1874 if (isConst == this->isConst())
1875 return *this;
1876 return get(getElementType(), getNumElements(), isConst);
1877}
1878
1879FVectorType FVectorType::getAllConstDroppedType() {
1880 if (!containsConst())
1881 return *this;
1882 return get(getElementType().getAllConstDroppedType(), getNumElements(),
1883 false);
1884}
1885
1886/// Return this type with any type aliases recursively removed from itself.
1887FIRRTLBaseType FVectorType::getAnonymousType() {
1888 auto *impl = getImpl();
1889
1890 if (impl->anonymousType)
1891 return impl->anonymousType;
1892
1893 // If this type is already anonymous, return it and remember for next time.
1894 if (!impl->props.containsTypeAlias)
1895 return impl->anonymousType = *this;
1896
1897 // Otherwise, rebuild an anonymous version.
1898 auto anonymousType = FVectorType::get(getElementType().getAnonymousType(),
1899 getNumElements(), isConst());
1900 impl->anonymousType = anonymousType;
1901 return anonymousType;
1902}
1903
1904uint64_t FVectorType::getFieldID(uint64_t index) const {
1905 return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1906}
1907
1908uint64_t FVectorType::getIndexForFieldID(uint64_t fieldID) const {
1909 assert(fieldID && "fieldID must be at least 1");
1910 // Divide the field ID by the number of fieldID's per element.
1911 return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1912}
1913
1914std::pair<uint64_t, uint64_t>
1915FVectorType::getIndexAndSubfieldID(uint64_t fieldID) const {
1916 auto index = getIndexForFieldID(fieldID);
1917 auto elementFieldID = getFieldID(index);
1918 return {index, fieldID - elementFieldID};
1919}
1920
1921std::pair<Type, uint64_t>
1922FVectorType::getSubTypeByFieldID(uint64_t fieldID) const {
1923 if (fieldID == 0)
1924 return {*this, 0};
1925 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
1926}
1927
1928uint64_t FVectorType::getMaxFieldID() const {
1929 return getNumElements() *
1930 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
1931}
1932
1933std::pair<uint64_t, bool>
1934FVectorType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
1935 auto childRoot = getFieldID(index);
1936 auto rangeEnd =
1937 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
1938 return std::make_pair(fieldID - childRoot,
1939 fieldID >= childRoot && fieldID <= rangeEnd);
1940}
1941
1942bool FVectorType::isConst() { return getImpl()->isConst; }
1943
1944FVectorType::ElementType FVectorType::getElementTypePreservingConst() {
1945 auto type = getElementType();
1946 return type.getConstType(type.isConst() || isConst());
1947}
1948
1949//===----------------------------------------------------------------------===//
1950// OpenVectorType
1951//===----------------------------------------------------------------------===//
1952
1954 using KeyTy = std::tuple<FIRRTLType, size_t, char>;
1955
1963
1964 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
1965
1967
1968 static OpenVectorTypeStorage *construct(TypeStorageAllocator &allocator,
1969 KeyTy key) {
1970 return new (allocator.allocate<OpenVectorTypeStorage>())
1971 OpenVectorTypeStorage(std::get<0>(key), std::get<1>(key),
1972 static_cast<bool>(std::get<2>(key)));
1973 }
1974
1977
1980};
1981
1982OpenVectorType OpenVectorType::get(FIRRTLType elementType, size_t numElements,
1983 bool isConst) {
1984 return Base::get(elementType.getContext(), elementType, numElements, isConst);
1985}
1986
1987FIRRTLType OpenVectorType::getElementType() const {
1988 return getImpl()->elementType;
1989}
1990
1991size_t OpenVectorType::getNumElements() const { return getImpl()->numElements; }
1992
1993/// Return the recursive properties of the type.
1994RecursiveTypeProperties OpenVectorType::getRecursiveTypeProperties() const {
1995 return getImpl()->props;
1996}
1997
1998OpenVectorType OpenVectorType::getConstType(bool isConst) {
1999 if (isConst == this->isConst())
2000 return *this;
2001 return get(getElementType(), getNumElements(), isConst);
2002}
2003
2004uint64_t OpenVectorType::getFieldID(uint64_t index) const {
2005 return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2006}
2007
2008uint64_t OpenVectorType::getIndexForFieldID(uint64_t fieldID) const {
2009 assert(fieldID && "fieldID must be at least 1");
2010 // Divide the field ID by the number of fieldID's per element.
2011 return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2012}
2013
2014std::pair<uint64_t, uint64_t>
2015OpenVectorType::getIndexAndSubfieldID(uint64_t fieldID) const {
2016 auto index = getIndexForFieldID(fieldID);
2017 auto elementFieldID = getFieldID(index);
2018 return {index, fieldID - elementFieldID};
2019}
2020
2021std::pair<Type, uint64_t>
2022OpenVectorType::getSubTypeByFieldID(uint64_t fieldID) const {
2023 if (fieldID == 0)
2024 return {*this, 0};
2025 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
2026}
2027
2028uint64_t OpenVectorType::getMaxFieldID() const {
2029 // If this is requirement, make ODS constraint or actual elementType.
2030 return getNumElements() *
2031 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
2032}
2033
2034std::pair<uint64_t, bool>
2035OpenVectorType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2036 auto childRoot = getFieldID(index);
2037 auto rangeEnd =
2038 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
2039 return std::make_pair(fieldID - childRoot,
2040 fieldID >= childRoot && fieldID <= rangeEnd);
2041}
2042
2043bool OpenVectorType::isConst() { return getImpl()->isConst; }
2044
2045OpenVectorType::ElementType OpenVectorType::getElementTypePreservingConst() {
2046 auto type = getElementType();
2047 // TODO: ConstTypeInterface / Trait ?
2048 return TypeSwitch<FIRRTLType, ElementType>(type)
2049 .Case<FIRRTLBaseType, OpenBundleType, OpenVectorType>([&](auto type) {
2050 return type.getConstType(type.isConst() || isConst());
2051 })
2052 .Default(type);
2053}
2054
2055LogicalResult
2056OpenVectorType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2058 bool isConst) {
2059 if (elementType.containsReference() && isConst)
2060 return emitErrorFn() << "vector cannot be const with references";
2061 if (type_isa<LHSType>(elementType))
2062 return emitErrorFn() << "vector cannot have a left-hand side type";
2063 return success();
2064}
2065
2066//===----------------------------------------------------------------------===//
2067// FEnum Type
2068//===----------------------------------------------------------------------===//
2069
2071 using KeyTy = std::pair<ArrayRef<FEnumType::EnumElement>, char>;
2072
2073 FEnumTypeStorage(ArrayRef<FEnumType::EnumElement> elements, bool isConst)
2075 elements(elements.begin(), elements.end()) {
2076 RecursiveTypeProperties props{true, false, false, isConst,
2077 false, false, false};
2078 uint64_t fieldID = 0;
2079 fieldIDs.reserve(elements.size());
2080 for (auto &element : elements) {
2081 auto type = element.type;
2082 auto eltInfo = type.getRecursiveTypeProperties();
2083 props.isPassive &= eltInfo.isPassive;
2084 props.containsAnalog |= eltInfo.containsAnalog;
2085 props.containsConst |= eltInfo.containsConst;
2086 props.containsReference |= eltInfo.containsReference;
2087 props.containsTypeAlias |= eltInfo.containsTypeAlias;
2088 props.hasUninferredReset |= eltInfo.hasUninferredReset;
2089 props.hasUninferredWidth |= eltInfo.hasUninferredWidth;
2090 fieldID += 1;
2091 fieldIDs.push_back(fieldID);
2092 // Increment the field ID for the next field by the number of subfields.
2093 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
2094 }
2095 maxFieldID = fieldID;
2096 recProps = props;
2097 }
2098
2099 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2100
2101 KeyTy getAsKey() const { return KeyTy(elements, isConst); }
2102
2103 static llvm::hash_code hashKey(const KeyTy &key) {
2104 return llvm::hash_combine(
2105 llvm::hash_combine_range(key.first.begin(), key.first.end()),
2106 key.second);
2107 }
2108
2109 static FEnumTypeStorage *construct(TypeStorageAllocator &allocator,
2110 KeyTy key) {
2111 return new (allocator.allocate<FEnumTypeStorage>())
2112 FEnumTypeStorage(key.first, static_cast<bool>(key.second));
2113 }
2114
2115 SmallVector<FEnumType::EnumElement, 4> elements;
2116 SmallVector<uint64_t, 4> fieldIDs;
2117 uint64_t maxFieldID;
2118
2121};
2122
2123FEnumType FEnumType::get(::mlir::MLIRContext *context,
2124 ArrayRef<EnumElement> elements, bool isConst) {
2125 return Base::get(context, elements, isConst);
2126}
2127
2128ArrayRef<FEnumType::EnumElement> FEnumType::getElements() const {
2129 return getImpl()->elements;
2130}
2131
2132FEnumType FEnumType::getConstType(bool isConst) {
2133 return get(getContext(), getElements(), isConst);
2134}
2135
2136FEnumType FEnumType::getAllConstDroppedType() {
2137 if (!containsConst())
2138 return *this;
2139
2140 SmallVector<EnumElement> constDroppedElements(
2141 llvm::map_range(getElements(), [](EnumElement element) {
2142 element.type = element.type.getAllConstDroppedType();
2143 return element;
2144 }));
2145 return get(getContext(), constDroppedElements, false);
2146}
2147
2148/// Return a pair with the 'isPassive' and 'containsAnalog' bits.
2149RecursiveTypeProperties FEnumType::getRecursiveTypeProperties() const {
2150 return getImpl()->recProps;
2151}
2152
2153std::optional<unsigned> FEnumType::getElementIndex(StringAttr name) {
2154 for (const auto &it : llvm::enumerate(getElements())) {
2155 auto element = it.value();
2156 if (element.name == name) {
2157 return unsigned(it.index());
2158 }
2159 }
2160 return std::nullopt;
2161}
2162
2163std::optional<unsigned> FEnumType::getElementIndex(StringRef name) {
2164 for (const auto &it : llvm::enumerate(getElements())) {
2165 auto element = it.value();
2166 if (element.name.getValue() == name) {
2167 return unsigned(it.index());
2168 }
2169 }
2170 return std::nullopt;
2171}
2172
2173StringAttr FEnumType::getElementNameAttr(size_t index) {
2174 assert(index < getNumElements() &&
2175 "index must be less than number of fields in enum");
2176 return getElements()[index].name;
2177}
2178
2179StringRef FEnumType::getElementName(size_t index) {
2180 return getElementNameAttr(index).getValue();
2181}
2182
2183std::optional<FEnumType::EnumElement> FEnumType::getElement(StringAttr name) {
2184 if (auto maybeIndex = getElementIndex(name))
2185 return getElements()[*maybeIndex];
2186 return std::nullopt;
2187}
2188
2189std::optional<FEnumType::EnumElement> FEnumType::getElement(StringRef name) {
2190 if (auto maybeIndex = getElementIndex(name))
2191 return getElements()[*maybeIndex];
2192 return std::nullopt;
2193}
2194
2195/// Look up an element by index.
2196FEnumType::EnumElement FEnumType::getElement(size_t index) {
2197 assert(index < getNumElements() &&
2198 "index must be less than number of fields in enum");
2199 return getElements()[index];
2200}
2201
2202FIRRTLBaseType FEnumType::getElementType(StringAttr name) {
2203 auto element = getElement(name);
2204 return element ? element->type : FIRRTLBaseType();
2205}
2206
2207FIRRTLBaseType FEnumType::getElementType(StringRef name) {
2208 auto element = getElement(name);
2209 return element ? element->type : FIRRTLBaseType();
2210}
2211
2212FIRRTLBaseType FEnumType::getElementType(size_t index) const {
2213 assert(index < getNumElements() &&
2214 "index must be less than number of fields in enum");
2215 return getElements()[index].type;
2216}
2217
2218FIRRTLBaseType FEnumType::getElementTypePreservingConst(size_t index) {
2219 auto type = getElementType(index);
2220 return type.getConstType(type.isConst() || isConst());
2221}
2222
2223uint64_t FEnumType::getFieldID(uint64_t index) const {
2224 return getImpl()->fieldIDs[index];
2225}
2226
2227uint64_t FEnumType::getIndexForFieldID(uint64_t fieldID) const {
2228 assert(!getElements().empty() && "Enum must have >0 fields");
2229 auto fieldIDs = getImpl()->fieldIDs;
2230 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
2231 return std::distance(fieldIDs.begin(), it);
2232}
2233
2234std::pair<uint64_t, uint64_t>
2235FEnumType::getIndexAndSubfieldID(uint64_t fieldID) const {
2236 auto index = getIndexForFieldID(fieldID);
2237 auto elementFieldID = getFieldID(index);
2238 return {index, fieldID - elementFieldID};
2239}
2240
2241std::pair<Type, uint64_t>
2242FEnumType::getSubTypeByFieldID(uint64_t fieldID) const {
2243 if (fieldID == 0)
2244 return {*this, 0};
2245 auto subfieldIndex = getIndexForFieldID(fieldID);
2246 auto subfieldType = getElementType(subfieldIndex);
2247 auto subfieldID = fieldID - getFieldID(subfieldIndex);
2248 return {subfieldType, subfieldID};
2249}
2250
2251uint64_t FEnumType::getMaxFieldID() const { return getImpl()->maxFieldID; }
2252
2253std::pair<uint64_t, bool>
2254FEnumType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2255 auto childRoot = getFieldID(index);
2256 auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
2257 : (getFieldID(index + 1) - 1);
2258 return std::make_pair(fieldID - childRoot,
2259 fieldID >= childRoot && fieldID <= rangeEnd);
2260}
2261
2262auto FEnumType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2263 ArrayRef<EnumElement> elements, bool isConst)
2264 -> LogicalResult {
2265 for (auto &elt : elements) {
2266 auto r = elt.type.getRecursiveTypeProperties();
2267 if (!r.isPassive)
2268 return emitErrorFn() << "enum field '" << elt.name << "' not passive";
2269 if (r.containsAnalog)
2270 return emitErrorFn() << "enum field '" << elt.name << "' contains analog";
2271 if (r.containsConst && !isConst)
2272 return emitErrorFn() << "enum with 'const' elements must be 'const'";
2273 // TODO: exclude reference containing
2274 }
2275 return success();
2276}
2277
2278/// Return this type with any type aliases recursively removed from itself.
2279FIRRTLBaseType FEnumType::getAnonymousType() {
2280 auto *impl = getImpl();
2281
2282 if (impl->anonymousType)
2283 return impl->anonymousType;
2284
2285 if (!impl->recProps.containsTypeAlias)
2286 return impl->anonymousType = *this;
2287
2288 SmallVector<FEnumType::EnumElement, 4> elements;
2289
2290 for (auto element : getElements())
2291 elements.push_back({element.name, element.type.getAnonymousType()});
2292 return impl->anonymousType = FEnumType::get(getContext(), elements);
2293}
2294
2295//===----------------------------------------------------------------------===//
2296// BaseTypeAliasType
2297//===----------------------------------------------------------------------===//
2298
2301 using KeyTy = std::tuple<StringAttr, FIRRTLBaseType>;
2302
2306
2307 bool operator==(const KeyTy &key) const { return key == getAsKey(); }
2308
2309 KeyTy getAsKey() const { return KeyTy(name, innerType); }
2310
2311 static llvm::hash_code hashKey(const KeyTy &key) {
2312 return llvm::hash_combine(key);
2313 }
2314
2315 static BaseTypeAliasStorage *construct(TypeStorageAllocator &allocator,
2316 KeyTy key) {
2317 return new (allocator.allocate<BaseTypeAliasStorage>())
2318 BaseTypeAliasStorage(std::get<0>(key), std::get<1>(key));
2319 }
2320 StringAttr name;
2323};
2324
2325auto BaseTypeAliasType::get(StringAttr name, FIRRTLBaseType innerType)
2326 -> BaseTypeAliasType {
2327 return Base::get(name.getContext(), name, innerType);
2328}
2329
2330auto BaseTypeAliasType::getName() const -> StringAttr {
2331 return getImpl()->name;
2332}
2333
2334auto BaseTypeAliasType::getInnerType() const -> FIRRTLBaseType {
2335 return getImpl()->innerType;
2336}
2337
2338FIRRTLBaseType BaseTypeAliasType::getAnonymousType() {
2339 auto *impl = getImpl();
2340 if (impl->anonymousType)
2341 return impl->anonymousType;
2342 return impl->anonymousType = getInnerType().getAnonymousType();
2343}
2344
2345FIRRTLBaseType BaseTypeAliasType::getPassiveType() {
2346 return getModifiedType(getInnerType().getPassiveType());
2347}
2348
2349RecursiveTypeProperties BaseTypeAliasType::getRecursiveTypeProperties() const {
2350 auto rtp = getInnerType().getRecursiveTypeProperties();
2351 rtp.containsTypeAlias = true;
2352 return rtp;
2353}
2354
2355// If a given `newInnerType` is identical to innerType, return `*this`
2356// because we can reuse the type alias. Otherwise return `newInnerType`.
2357FIRRTLBaseType BaseTypeAliasType::getModifiedType(FIRRTLBaseType newInnerType) {
2358 if (newInnerType == getInnerType())
2359 return *this;
2360 return newInnerType;
2361}
2362
2363// FieldIDTypeInterface implementation.
2364FIRRTLBaseType BaseTypeAliasType::getAllConstDroppedType() {
2365 return getModifiedType(getInnerType().getAllConstDroppedType());
2366}
2367
2368FIRRTLBaseType BaseTypeAliasType::getConstType(bool isConst) {
2369 return getModifiedType(getInnerType().getConstType(isConst));
2370}
2371
2372std::pair<Type, uint64_t>
2373BaseTypeAliasType::getSubTypeByFieldID(uint64_t fieldID) const {
2374 return hw::FieldIdImpl::getSubTypeByFieldID(getInnerType(), fieldID);
2375}
2376
2377uint64_t BaseTypeAliasType::getMaxFieldID() const {
2378 return hw::FieldIdImpl::getMaxFieldID(getInnerType());
2379}
2380
2381std::pair<uint64_t, bool>
2382BaseTypeAliasType::projectToChildFieldID(uint64_t fieldID,
2383 uint64_t index) const {
2384 return hw::FieldIdImpl::projectToChildFieldID(getInnerType(), fieldID, index);
2385}
2386
2387uint64_t BaseTypeAliasType::getIndexForFieldID(uint64_t fieldID) const {
2388 return hw::FieldIdImpl::getIndexForFieldID(getInnerType(), fieldID);
2389}
2390
2391uint64_t BaseTypeAliasType::getFieldID(uint64_t index) const {
2392 return hw::FieldIdImpl::getFieldID(getInnerType(), index);
2393}
2394
2395std::pair<uint64_t, uint64_t>
2396BaseTypeAliasType::getIndexAndSubfieldID(uint64_t fieldID) const {
2397 return hw::FieldIdImpl::getIndexAndSubfieldID(getInnerType(), fieldID);
2398}
2399
2400//===----------------------------------------------------------------------===//
2401// LHSType
2402//===----------------------------------------------------------------------===//
2403
2404LHSType LHSType::get(FIRRTLBaseType type) {
2405 return LHSType::get(type.getContext(), type);
2406}
2407
2408LogicalResult LHSType::verify(function_ref<InFlightDiagnostic()> emitError,
2409 FIRRTLBaseType type) {
2410 if (type.containsAnalog())
2411 return emitError() << "lhs type cannot contain an AnalogType";
2412 if (!type.isPassive())
2413 return emitError() << "lhs type cannot contain a non-passive type";
2414 if (type.containsReference())
2415 return emitError() << "lhs type cannot contain a reference";
2416 if (type_isa<LHSType>(type))
2417 return emitError() << "lhs type cannot contain a lhs type";
2418
2419 return success();
2420}
2421
2422//===----------------------------------------------------------------------===//
2423// RefType
2424//===----------------------------------------------------------------------===//
2425
2426auto RefType::get(FIRRTLBaseType type, bool forceable, SymbolRefAttr layer)
2427 -> RefType {
2428 return Base::get(type.getContext(), type, forceable, layer);
2429}
2430
2431auto RefType::verify(function_ref<InFlightDiagnostic()> emitErrorFn,
2432 FIRRTLBaseType base, bool forceable, SymbolRefAttr layer)
2433 -> LogicalResult {
2434 if (!base.isPassive())
2435 return emitErrorFn() << "reference base type must be passive";
2436 if (forceable && base.containsConst())
2437 return emitErrorFn()
2438 << "forceable reference base type cannot contain const";
2439 return success();
2440}
2441
2442RecursiveTypeProperties RefType::getRecursiveTypeProperties() const {
2443 auto rtp = getType().getRecursiveTypeProperties();
2444 rtp.containsReference = true;
2445 // References are not "passive", per FIRRTL spec.
2446 rtp.isPassive = false;
2447 return rtp;
2448}
2449
2450//===----------------------------------------------------------------------===//
2451// AnalogType
2452//===----------------------------------------------------------------------===//
2453
2454AnalogType AnalogType::get(mlir::MLIRContext *context) {
2455 return AnalogType::get(context, -1, false);
2456}
2457
2458AnalogType AnalogType::get(mlir::MLIRContext *context,
2459 std::optional<int32_t> width, bool isConst) {
2460 return AnalogType::get(context, width ? *width : -1, isConst);
2461}
2462
2463LogicalResult AnalogType::verify(function_ref<InFlightDiagnostic()> emitError,
2464 int32_t widthOrSentinel, bool isConst) {
2465 if (widthOrSentinel < -1)
2466 return emitError() << "invalid width";
2467 return success();
2468}
2469
2470int32_t AnalogType::getWidthOrSentinel() const { return getImpl()->width; }
2471
2472AnalogType AnalogType::getConstType(bool isConst) {
2473 if (isConst == this->isConst())
2474 return *this;
2475 return get(getContext(), getWidthOrSentinel(), isConst);
2476}
2477
2478//===----------------------------------------------------------------------===//
2479// ClockType
2480//===----------------------------------------------------------------------===//
2481
2482ClockType ClockType::getConstType(bool isConst) {
2483 if (isConst == this->isConst())
2484 return *this;
2485 return get(getContext(), isConst);
2486}
2487
2488//===----------------------------------------------------------------------===//
2489// ResetType
2490//===----------------------------------------------------------------------===//
2491
2492ResetType ResetType::getConstType(bool isConst) {
2493 if (isConst == this->isConst())
2494 return *this;
2495 return get(getContext(), isConst);
2496}
2497
2498//===----------------------------------------------------------------------===//
2499// AsyncResetType
2500//===----------------------------------------------------------------------===//
2501
2502AsyncResetType AsyncResetType::getConstType(bool isConst) {
2503 if (isConst == this->isConst())
2504 return *this;
2505 return get(getContext(), isConst);
2506}
2507
2508//===----------------------------------------------------------------------===//
2509// ClassType
2510//===----------------------------------------------------------------------===//
2511
2513 using KeyTy = std::pair<FlatSymbolRefAttr, ArrayRef<ClassElement>>;
2514
2515 static ClassTypeStorage *construct(TypeStorageAllocator &allocator,
2516 KeyTy key) {
2517 auto name = key.first;
2518 auto elements = allocator.copyInto(key.second);
2519
2520 // build the field ID table
2521 SmallVector<uint64_t, 4> ids;
2522 uint64_t id = 0;
2523 ids.reserve(elements.size());
2524 for (auto &element : elements) {
2525 id += 1;
2526 ids.push_back(id);
2527 id += hw::FieldIdImpl::getMaxFieldID(element.type);
2528 }
2529
2530 auto fieldIDs = allocator.copyInto(ArrayRef(ids));
2531 auto maxFieldID = id;
2532
2533 return new (allocator.allocate<ClassTypeStorage>())
2535 }
2536
2537 ClassTypeStorage(FlatSymbolRefAttr name, ArrayRef<ClassElement> elements,
2538 ArrayRef<uint64_t> fieldIDs, uint64_t maxFieldID)
2541
2542 bool operator==(const KeyTy &key) const {
2543 return name == key.first && elements == key.second;
2544 }
2545
2546 FlatSymbolRefAttr name;
2547 ArrayRef<ClassElement> elements;
2548 ArrayRef<uint64_t> fieldIDs;
2549 uint64_t maxFieldID;
2550};
2551
2552ClassType ClassType::get(FlatSymbolRefAttr name,
2553 ArrayRef<ClassElement> elements) {
2554 return get(name.getContext(), name, elements);
2555}
2556
2557StringRef ClassType::getName() const {
2558 return getNameAttr().getAttr().getValue();
2559}
2560
2561FlatSymbolRefAttr ClassType::getNameAttr() const { return getImpl()->name; }
2562
2563ArrayRef<ClassElement> ClassType::getElements() const {
2564 return getImpl()->elements;
2565}
2566
2567const ClassElement &ClassType::getElement(IntegerAttr index) const {
2568 return getElement(index.getValue().getZExtValue());
2569}
2570
2571const ClassElement &ClassType::getElement(size_t index) const {
2572 return getElements()[index];
2573}
2574
2575std::optional<uint64_t> ClassType::getElementIndex(StringRef fieldName) const {
2576 for (const auto [i, e] : llvm::enumerate(getElements()))
2577 if (fieldName == e.name)
2578 return i;
2579 return {};
2580}
2581
2582void ClassType::printInterface(AsmPrinter &p) const {
2583 p.printSymbolName(getName());
2584 p << "(";
2585 bool first = true;
2586 for (const auto &element : getElements()) {
2587 if (!first)
2588 p << ", ";
2589 p << element.direction << " ";
2590 p.printKeywordOrString(element.name);
2591 p << ": " << element.type;
2592 first = false;
2593 }
2594 p << ")";
2595}
2596
2597uint64_t ClassType::getFieldID(uint64_t index) const {
2598 return getImpl()->fieldIDs[index];
2599}
2600
2601uint64_t ClassType::getIndexForFieldID(uint64_t fieldID) const {
2602 assert(!getElements().empty() && "Class must have >0 fields");
2603 auto fieldIDs = getImpl()->fieldIDs;
2604 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
2605 return std::distance(fieldIDs.begin(), it);
2606}
2607
2608std::pair<uint64_t, uint64_t>
2609ClassType::getIndexAndSubfieldID(uint64_t fieldID) const {
2610 auto index = getIndexForFieldID(fieldID);
2611 auto elementFieldID = getFieldID(index);
2612 return {index, fieldID - elementFieldID};
2613}
2614
2615std::pair<Type, uint64_t>
2616ClassType::getSubTypeByFieldID(uint64_t fieldID) const {
2617 if (fieldID == 0)
2618 return {*this, 0};
2619 auto subfieldIndex = getIndexForFieldID(fieldID);
2620 auto subfieldType = getElement(subfieldIndex).type;
2621 auto subfieldID = fieldID - getFieldID(subfieldIndex);
2622 return {subfieldType, subfieldID};
2623}
2624
2625uint64_t ClassType::getMaxFieldID() const { return getImpl()->maxFieldID; }
2626
2627std::pair<uint64_t, bool>
2628ClassType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
2629 auto childRoot = getFieldID(index);
2630 auto rangeEnd = index + 1 >= getNumElements() ? getMaxFieldID()
2631 : (getFieldID(index + 1) - 1);
2632 return std::make_pair(fieldID - childRoot,
2633 fieldID >= childRoot && fieldID <= rangeEnd);
2634}
2635
2636ParseResult ClassType::parseInterface(AsmParser &parser, ClassType &result) {
2637 StringAttr className;
2638 if (parser.parseSymbolName(className))
2639 return failure();
2640
2641 SmallVector<ClassElement> elements;
2642 if (parser.parseCommaSeparatedList(
2643 OpAsmParser::Delimiter::Paren, [&]() -> ParseResult {
2644 // Parse port direction.
2645 Direction direction;
2646 if (succeeded(parser.parseOptionalKeyword("out")))
2647 direction = Direction::Out;
2648 else if (succeeded(parser.parseKeyword("in", "or 'out'")))
2649 direction = Direction::In;
2650 else
2651 return failure();
2652
2653 // Parse port name.
2654 std::string keyword;
2655 if (parser.parseKeywordOrString(&keyword))
2656 return failure();
2657 StringAttr name = StringAttr::get(parser.getContext(), keyword);
2658
2659 // Parse port type.
2660 Type type;
2661 if (parser.parseColonType(type))
2662 return failure();
2663
2664 elements.emplace_back(name, type, direction);
2665 return success();
2666 }))
2667 return failure();
2668
2669 result = ClassType::get(FlatSymbolRefAttr::get(className), elements);
2670 return success();
2671}
2672
2673//===----------------------------------------------------------------------===//
2674// FIRRTLDialect
2675//===----------------------------------------------------------------------===//
2676
2677void FIRRTLDialect::registerTypes() {
2678 addTypes<
2679#define GET_TYPEDEF_LIST
2680#include "circt/Dialect/FIRRTL/FIRRTLTypes.cpp.inc"
2681 >();
2682}
2683
2684// Get the bit width for this type, return None if unknown. Unlike
2685// getBitWidthOrSentinel(), this can recursively compute the bitwidth of
2686// aggregate types. For bundle and vectors, recursively get the width of each
2687// field element and return the total bit width of the aggregate type. This
2688// returns None, if any of the bundle fields is a flip type, or ground type with
2689// unknown bit width.
2690std::optional<int64_t> firrtl::getBitWidth(FIRRTLBaseType type,
2691 bool ignoreFlip) {
2692 std::function<std::optional<int64_t>(FIRRTLBaseType)> getWidth =
2693 [&](FIRRTLBaseType type) -> std::optional<int64_t> {
2694 return TypeSwitch<FIRRTLBaseType, std::optional<int64_t>>(type)
2695 .Case<BundleType>([&](BundleType bundle) -> std::optional<int64_t> {
2696 int64_t width = 0;
2697 for (auto &elt : bundle) {
2698 if (elt.isFlip && !ignoreFlip)
2699 return std::nullopt;
2700 auto w = getBitWidth(elt.type);
2701 if (!w.has_value())
2702 return std::nullopt;
2703 width += *w;
2704 }
2705 return width;
2706 })
2707 .Case<FEnumType>([&](FEnumType fenum) -> std::optional<int64_t> {
2708 int64_t width = 0;
2709 for (auto &elt : fenum) {
2710 auto w = getBitWidth(elt.type);
2711 if (!w.has_value())
2712 return std::nullopt;
2713 width = std::max(width, *w);
2714 }
2715 return width + llvm::Log2_32_Ceil(fenum.getNumElements());
2716 })
2717 .Case<FVectorType>([&](auto vector) -> std::optional<int64_t> {
2718 auto w = getBitWidth(vector.getElementType());
2719 if (!w.has_value())
2720 return std::nullopt;
2721 return *w * vector.getNumElements();
2722 })
2723 .Case<IntType>([&](IntType iType) { return iType.getWidth(); })
2724 .Case<ClockType, ResetType, AsyncResetType>([](Type) { return 1; })
2725 .Default([&](auto t) { return std::nullopt; });
2726 };
2727 return getWidth(type);
2728}
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
MlirType elementType
Definition CHIRRTL.cpp:29
static ParseResult parseFIRRTLBaseType(FIRRTLBaseType &result, StringRef name, AsmParser &parser)
static ParseResult parseFIRRTLPropertyType(PropertyType &result, StringRef name, AsmParser &parser)
static LogicalResult customTypePrinter(Type type, AsmPrinter &os)
Print a type with a custom printer implementation.
static OptionalParseResult customTypeParser(AsmParser &parser, StringRef name, Type &result)
Parse a type with a custom parser implementation.
static ParseResult parseType(Type &result, StringRef name, AsmParser &parser)
Parse a type defined by this dialect.
@ ContainsAnalogBitMask
Bit set if the type contains an analog type.
@ HasUninferredWidthBitMask
Bit set if the type has any uninferred bit widths.
@ IsPassiveBitMask
Bit set if the type only contains passive elements.
static bool areBundleElementsEquivalent(BundleType::BundleElement destElement, BundleType::BundleElement srcElement, bool destOuterTypeIsConst, bool srcOuterTypeIsConst, bool requiresSameWidth)
Helper to implement the equivalence logic for a pair of bundle elements.
static ParseResult parseFIRRTLType(FIRRTLType &result, StringRef name, AsmParser &parser)
Parse a FIRRTLType with a name that has already been parsed.
static unsigned getFieldID(BundleType type, unsigned index)
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
FIRRTLBaseType getAnonymousType()
Return this type with any type alias types recursively removed from itself.
bool isResetType()
Return true if this is a valid "reset" type.
FIRRTLBaseType getMaskType()
Return this type with all ground types replaced with UInt<1>.
FIRRTLBaseType getPassiveType()
Return this type with any flip types recursively removed from itself.
int32_t getBitWidthOrSentinel()
If this is an IntType, AnalogType, or sugar type for a single bit (Clock, Reset, etc) then return the...
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
FIRRTLBaseType getAllConstDroppedType()
Return this type with all 'const' modifiers dropped.
bool isPassive() const
Return true if this is a "passive" type - one that contains no "flip" types recursively within itself...
FIRRTLBaseType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
FIRRTLBaseType getWidthlessType()
Return this type with widths of all ground types removed.
This class implements the same functionality as TypeSwitch except that it uses firrtl::type_dyn_cast ...
FIRRTLTypeSwitch< T, ResultT > & Case(CallableT &&caseFn)
Add a case on the given type.
bool containsReference()
Return true if this is or contains a Reference type.
bool isConst()
Returns true if this is a 'const' type that can only hold compile-time constant values.
RecursiveTypeProperties getRecursiveTypeProperties() const
Return the recursive properties of the type, containing the isPassive, containsAnalog,...
This is the common base class between SIntType and UIntType.
int32_t getWidthOrSentinel() const
Return the width of this type, or -1 if it has none specified.
static IntType get(MLIRContext *context, bool isSigned, int32_t widthOrSentinel=-1, bool isConst=false)
Return an SIntType or UIntType with the specified signedness, width, and constness.
IntType getConstType(bool isConst)
Return a 'const' or non-'const' version of this type.
std::optional< int32_t > getWidth() const
Return an optional containing the width, if the width is known (or empty if width is unknown).
Represents a limited word-length unsigned integer in SystemC as described in IEEE 1666-2011 §7....
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
ParseResult parseNestedType(FIRRTLType &result, AsmParser &parser)
Parse a FIRRTLType.
bool areAnonymousTypesEquivalent(FIRRTLBaseType lhs, FIRRTLBaseType rhs)
Return true if anonymous types of given arguments are equivalent by pointer comparison.
ParseResult parseNestedBaseType(FIRRTLBaseType &result, AsmParser &parser)
bool isTypeInOut(mlir::Type type)
Returns true if the given type has some flipped (aka unaligned) dataflow.
bool areTypesRefCastable(Type dstType, Type srcType)
Return true if destination ref type can be cast from source ref type, per FIRRTL spec rules they must...
bool areTypesEquivalent(FIRRTLType destType, FIRRTLType srcType, bool destOuterTypeIsConst=false, bool srcOuterTypeIsConst=false, bool requireSameWidths=false)
Returns whether the two types are equivalent.
mlir::Type getPassiveType(mlir::Type anyBaseFIRRTLType)
bool isTypeLarger(FIRRTLBaseType dstType, FIRRTLBaseType srcType)
Returns true if the destination is at least as wide as the source.
bool containsConst(Type type)
Returns true if the type is or contains a 'const' type whose value is guaranteed to be unchanging at ...
bool hasZeroBitWidth(FIRRTLType type)
Return true if the type has zero bit width.
void printNestedType(Type type, AsmPrinter &os)
Print a type defined by this dialect.
bool isConst(Type type)
Returns true if this is a 'const' type whose value is guaranteed to be unchanging at circuit executio...
bool areTypesConstCastable(FIRRTLType destType, FIRRTLType srcType, bool srcOuterTypeIsConst=false)
Returns whether the srcType can be const-casted to the destType.
ParseResult parseNestedPropertyType(PropertyType &result, AsmParser &parser)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
std::pair< uint64_t, bool > projectToChildFieldID(Type, uint64_t fieldID, uint64_t index)
uint64_t getIndexForFieldID(Type type, uint64_t fieldID)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A collection of bits indicating the recursive properties of a type.
Definition FIRRTLTypes.h:68
bool containsReference
Whether the type contains a reference type.
Definition FIRRTLTypes.h:72
bool isPassive
Whether the type only contains passive elements.
Definition FIRRTLTypes.h:70
bool containsAnalog
Whether the type contains an analog type.
Definition FIRRTLTypes.h:74
bool hasUninferredReset
Whether the type has any uninferred reset.
Definition FIRRTLTypes.h:82
bool containsTypeAlias
Whether the type contains a type alias.
Definition FIRRTLTypes.h:78
bool containsConst
Whether the type contains a const type.
Definition FIRRTLTypes.h:76
bool hasUninferredWidth
Whether the type has any uninferred bit widths.
Definition FIRRTLTypes.h:80
bool operator==(const KeyTy &key) const
static BaseTypeAliasStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
BaseTypeAliasStorage(StringAttr name, FIRRTLBaseType innerType)
std::tuple< StringAttr, FIRRTLBaseType > KeyTy
static llvm::hash_code hashKey(const KeyTy &key)
static BundleTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
SmallVector< BundleType::BundleElement, 4 > elements
std::pair< ArrayRef< BundleType::BundleElement >, char > KeyTy
static llvm::hash_code hashKey(const KeyTy &key)
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
BundleTypeStorage(ArrayRef< BundleType::BundleElement > elements, bool isConst)
bool operator==(const KeyTy &key) const
std::pair< FlatSymbolRefAttr, ArrayRef< ClassElement > > KeyTy
bool operator==(const KeyTy &key) const
static ClassTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
ClassTypeStorage(FlatSymbolRefAttr name, ArrayRef< ClassElement > elements, ArrayRef< uint64_t > fieldIDs, uint64_t maxFieldID)
SmallVector< FEnumType::EnumElement, 4 > elements
static llvm::hash_code hashKey(const KeyTy &key)
bool operator==(const KeyTy &key) const
static FEnumTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
FEnumTypeStorage(ArrayRef< FEnumType::EnumElement > elements, bool isConst)
std::pair< ArrayRef< FEnumType::EnumElement >, char > KeyTy
SmallVector< uint64_t, 4 > fieldIDs
bool operator==(const KeyTy &key) const
static FIRRTLBaseTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
bool operator==(const KeyTy &key) const
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
static FVectorTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
std::tuple< FIRRTLBaseType, size_t, char > KeyTy
FVectorTypeStorage(FIRRTLBaseType elementType, size_t numElements, bool isConst)
SmallVector< OpenBundleType::BundleElement, 4 > elements
static OpenBundleTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
static llvm::hash_code hashKey(const KeyTy &key)
RecursiveTypeProperties props
This holds the bits for the type's recursive properties, and can hold a pointer to a passive version ...
OpenBundleTypeStorage(ArrayRef< OpenBundleType::BundleElement > elements, bool isConst)
std::pair< ArrayRef< OpenBundleType::BundleElement >, char > KeyTy
std::tuple< FIRRTLType, size_t, char > KeyTy
static OpenVectorTypeStorage * construct(TypeStorageAllocator &allocator, KeyTy key)
OpenVectorTypeStorage(FIRRTLType elementType, size_t numElements, bool isConst)
WidthTypeStorage(int32_t width, bool isConst)
bool operator==(const KeyTy &key) const
static WidthTypeStorage * construct(TypeStorageAllocator &allocator, const KeyTy &key)