CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWTypes.cpp
Go to the documentation of this file.
1//===- HWTypes.cpp - HW types code defs -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation logic for HW data types.
10//
11//===----------------------------------------------------------------------===//
12
18#include "circt/Support/LLVM.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/IR/StorageUniquerSupport.h"
24#include "mlir/IR/Types.h"
25#include "mlir/Interfaces/MemorySlotInterfaces.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/ADT/StringExtras.h"
28#include "llvm/ADT/StringSet.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31using namespace circt;
32using namespace circt::hw;
33using namespace circt::hw::detail;
34
35static ParseResult parseHWArray(AsmParser &parser, Attribute &dim,
36 Type &elementType);
37static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType);
38
39static ParseResult parseHWElementType(AsmParser &parser, Type &elementType);
40static void printHWElementType(AsmPrinter &printer, Type dim);
41
42#define GET_TYPEDEF_CLASSES
43#include "circt/Dialect/HW/HWTypes.cpp.inc"
44
45//===----------------------------------------------------------------------===//
46// Type Helpers
47//===----------------------------------------------------------------------===//
48
49mlir::Type circt::hw::getCanonicalType(mlir::Type type) {
50 Type canonicalType;
51 if (auto typeAlias = dyn_cast<TypeAliasType>(type))
52 canonicalType = typeAlias.getCanonicalType();
53 else
54 canonicalType = type;
55 return canonicalType;
56}
57
58/// Return true if the specified type is a value HW Integer type. This checks
59/// that it is a signless standard dialect type or a hw::IntType.
60bool circt::hw::isHWIntegerType(mlir::Type type) {
61 Type canonicalType = getCanonicalType(type);
62
63 if (isa<hw::IntType>(canonicalType))
64 return true;
65
66 auto intType = dyn_cast<IntegerType>(canonicalType);
67 if (!intType || !intType.isSignless())
68 return false;
69
70 return true;
71}
72
73bool circt::hw::isHWEnumType(mlir::Type type) {
74 return isa<hw::EnumType>(getCanonicalType(type));
75}
76
77/// Return true if the specified type can be used as an HW value type, that is
78/// the set of types that can be composed together to represent synthesized,
79/// hardware but not marker types like InOutType.
80bool circt::hw::isHWValueType(Type type) {
81 // Signless and signed integer types are both valid.
82 if (isa<IntegerType, IntType, EnumType>(type))
83 return true;
84
85 if (auto array = dyn_cast<ArrayType>(type))
86 return isHWValueType(array.getElementType());
87
88 if (auto array = dyn_cast<UnpackedArrayType>(type))
89 return isHWValueType(array.getElementType());
90
91 if (auto t = dyn_cast<StructType>(type))
92 return llvm::all_of(t.getElements(),
93 [](auto f) { return isHWValueType(f.type); });
94
95 if (auto t = dyn_cast<UnionType>(type))
96 return llvm::all_of(t.getElements(),
97 [](auto f) { return isHWValueType(f.type); });
98
99 if (auto t = dyn_cast<TypeAliasType>(type))
100 return isHWValueType(t.getCanonicalType());
101
102 return false;
103}
104
105/// Return the hardware bit width of a type. Does not reflect any encoding,
106/// padding, or storage scheme, just the bit (and wire width) of a
107/// statically-size type. Reflects the number of wires needed to transmit a
108/// value of this type. Returns -1 if the type is not known or cannot be
109/// statically computed.
110int64_t circt::hw::getBitWidth(mlir::Type type) {
111 return llvm::TypeSwitch<::mlir::Type, size_t>(type)
112 .Case<IntegerType>(
113 [](IntegerType t) { return t.getIntOrFloatBitWidth(); })
114 .Case<ArrayType, UnpackedArrayType>([](auto a) {
115 int64_t elementBitWidth = getBitWidth(a.getElementType());
116 if (elementBitWidth < 0)
117 return elementBitWidth;
118 int64_t dimBitWidth = a.getNumElements();
119 if (dimBitWidth < 0)
120 return static_cast<int64_t>(-1L);
121 return (int64_t)a.getNumElements() * elementBitWidth;
122 })
123 .Case<StructType>([](StructType s) {
124 int64_t total = 0;
125 for (auto field : s.getElements()) {
126 int64_t fieldSize = getBitWidth(field.type);
127 if (fieldSize < 0)
128 return fieldSize;
129 total += fieldSize;
130 }
131 return total;
132 })
133 .Case<UnionType>([](UnionType u) {
134 int64_t maxSize = 0;
135 for (auto field : u.getElements()) {
136 int64_t fieldSize = getBitWidth(field.type) + field.offset;
137 if (fieldSize > maxSize)
138 maxSize = fieldSize;
139 }
140 return maxSize;
141 })
142 .Case<EnumType>([](EnumType e) { return e.getBitWidth(); })
143 .Case<TypeAliasType>(
144 [](TypeAliasType t) { return getBitWidth(t.getCanonicalType()); })
145 .Default([](Type) { return -1; });
146}
147
148/// Return true if the specified type contains known marker types like
149/// InOutType. Unlike isHWValueType, this is not conservative, it only returns
150/// false on known InOut types, rather than any unknown types.
151bool circt::hw::hasHWInOutType(Type type) {
152 if (auto array = dyn_cast<ArrayType>(type))
153 return hasHWInOutType(array.getElementType());
154
155 if (auto array = dyn_cast<UnpackedArrayType>(type))
156 return hasHWInOutType(array.getElementType());
157
158 if (auto t = dyn_cast<StructType>(type)) {
159 return std::any_of(t.getElements().begin(), t.getElements().end(),
160 [](const auto &f) { return hasHWInOutType(f.type); });
161 }
162
163 if (auto t = dyn_cast<TypeAliasType>(type))
164 return hasHWInOutType(t.getCanonicalType());
165
166 return isa<InOutType>(type);
167}
168
169/// Parse and print nested HW types nicely. These helper methods allow eliding
170/// the "hw." prefix on array, inout, and other types when in a context that
171/// expects HW subelement types.
172static ParseResult parseHWElementType(AsmParser &p, Type &result) {
173 // If this is an HW dialect type, then we don't need/want the !hw. prefix
174 // redundantly specified.
175 auto fullString = static_cast<DialectAsmParser &>(p).getFullSymbolSpec();
176 auto *curPtr = p.getCurrentLocation().getPointer();
177 auto typeString =
178 StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
179
180 if (typeString.starts_with("array<") || typeString.starts_with("inout<") ||
181 typeString.starts_with("uarray<") || typeString.starts_with("struct<") ||
182 typeString.starts_with("typealias<") || typeString.starts_with("int<") ||
183 typeString.starts_with("enum<")) {
184 llvm::StringRef mnemonic;
185 if (auto parseResult = generatedTypeParser(p, &mnemonic, result);
186 parseResult.has_value())
187 return *parseResult;
188 return p.emitError(p.getNameLoc(), "invalid type `") << typeString << "`";
189 }
190
191 return p.parseType(result);
192}
193
194static void printHWElementType(AsmPrinter &p, Type element) {
195 if (succeeded(generatedTypePrinter(element, p)))
196 return;
197 p.printType(element);
198}
199
200//===----------------------------------------------------------------------===//
201// Int Type
202//===----------------------------------------------------------------------===//
203
204Type IntType::get(mlir::TypedAttr width) {
205 // The width expression must always be a 32-bit wide integer type itself.
206 auto widthWidth = llvm::dyn_cast<IntegerType>(width.getType());
207 assert(widthWidth && widthWidth.getWidth() == 32 &&
208 "!hw.int width must be 32-bits");
209 (void)widthWidth;
210
211 if (auto cstWidth = llvm::dyn_cast<IntegerAttr>(width))
212 return IntegerType::get(width.getContext(),
213 cstWidth.getValue().getZExtValue());
214
215 return Base::get(width.getContext(), width);
216}
217
218Type IntType::parse(AsmParser &p) {
219 // The bitwidth of the parameter size is always 32 bits.
220 auto int32Type = p.getBuilder().getIntegerType(32);
221
222 mlir::TypedAttr width;
223 if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater())
224 return Type();
225 return get(width);
226}
227
228void IntType::print(AsmPrinter &p) const {
229 p << "<";
230 p.printAttributeWithoutType(getWidth());
231 p << '>';
232}
233
234//===----------------------------------------------------------------------===//
235// Struct Type
236//===----------------------------------------------------------------------===//
237
238namespace circt {
239namespace hw {
240namespace detail {
241bool operator==(const FieldInfo &a, const FieldInfo &b) {
242 return a.name == b.name && a.type == b.type;
243}
244llvm::hash_code hash_value(const FieldInfo &fi) {
245 return llvm::hash_combine(fi.name, fi.type);
246}
247} // namespace detail
248} // namespace hw
249} // namespace circt
250
251/// Parse a list of unique field names and types within <>. E.g.:
252/// <foo: i7, bar: i8>
253static ParseResult parseFields(AsmParser &p,
254 SmallVectorImpl<FieldInfo> &parameters) {
255 llvm::StringSet<> nameSet;
256 bool hasDuplicateName = false;
257 auto parseResult = p.parseCommaSeparatedList(
258 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
259 std::string name;
260 Type type;
261
262 auto fieldLoc = p.getCurrentLocation();
263 if (p.parseKeywordOrString(&name) || p.parseColon() ||
264 p.parseType(type))
265 return failure();
266
267 if (!nameSet.insert(name).second) {
268 p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
269 // Continue parsing to print all duplicates, but make sure to error
270 // eventually
271 hasDuplicateName = true;
272 }
273
274 parameters.push_back(
275 FieldInfo{StringAttr::get(p.getContext(), name), type});
276 return success();
277 });
278
279 if (hasDuplicateName)
280 return failure();
281 return parseResult;
282}
283
284/// Print out a list of named fields surrounded by <>.
285static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
286 p << '<';
287 llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
288 p.printKeywordOrString(field.name.getValue());
289 p << ": " << field.type;
290 });
291 p << ">";
292}
293
294Type StructType::parse(AsmParser &p) {
295 llvm::SmallVector<FieldInfo, 4> parameters;
296 if (parseFields(p, parameters))
297 return Type();
298 return get(p.getContext(), parameters);
299}
300
301LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
302 ArrayRef<StructType::FieldInfo> elements) {
303 llvm::SmallDenseSet<StringAttr> fieldNameSet;
304 LogicalResult result = success();
305 fieldNameSet.reserve(elements.size());
306 for (const auto &elt : elements)
307 if (!fieldNameSet.insert(elt.name).second) {
308 result = failure();
309 emitError() << "duplicate field name '" << elt.name.getValue()
310 << "' in hw.struct type";
311 }
312 return result;
313}
314
315void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
316
317Type StructType::getFieldType(mlir::StringRef fieldName) {
318 for (const auto &field : getElements())
319 if (field.name == fieldName)
320 return field.type;
321 return Type();
322}
323
324std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
325 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
326 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
327 if (elems[idx].name == fieldName)
328 return idx;
329 return {};
330}
331
332std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
333 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
334 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
335 if (elems[idx].name == fieldName)
336 return idx;
337 return {};
338}
339
340static std::pair<uint64_t, SmallVector<uint64_t>>
341getFieldIDsStruct(const StructType &st) {
342 uint64_t fieldID = 0;
343 auto elements = st.getElements();
344 SmallVector<uint64_t> fieldIDs;
345 fieldIDs.reserve(elements.size());
346 for (auto &element : elements) {
347 auto type = element.type;
348 fieldID += 1;
349 fieldIDs.push_back(fieldID);
350 // Increment the field ID for the next field by the number of subfields.
351 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
352 }
353 return {fieldID, fieldIDs};
354}
355
356void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
357 for (const auto &field : getElements())
358 types.push_back(field.type);
359}
360
361uint64_t StructType::getMaxFieldID() const {
362 uint64_t fieldID = 0;
363 for (const auto &field : getElements())
364 fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
365 return fieldID;
366}
367
368std::pair<Type, uint64_t>
369StructType::getSubTypeByFieldID(uint64_t fieldID) const {
370 if (fieldID == 0)
371 return {*this, 0};
372 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
373 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
374 auto subfieldIndex = std::distance(fieldIDs.begin(), it);
375 auto subfieldType = getElements()[subfieldIndex].type;
376 auto subfieldID = fieldID - fieldIDs[subfieldIndex];
377 return {subfieldType, subfieldID};
378}
379
380std::pair<uint64_t, bool>
381StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
382 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
383 auto childRoot = fieldIDs[index];
384 auto rangeEnd =
385 index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
386 return std::make_pair(fieldID - childRoot,
387 fieldID >= childRoot && fieldID <= rangeEnd);
388}
389
390uint64_t StructType::getFieldID(uint64_t index) const {
391 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
392 return fieldIDs[index];
393}
394
395uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
396 assert(!getElements().empty() && "Bundle must have >0 fields");
397 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
398 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
399 return std::distance(fieldIDs.begin(), it);
400}
401
402std::pair<uint64_t, uint64_t>
403StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
404 auto index = getIndexForFieldID(fieldID);
405 auto elementFieldID = getFieldID(index);
406 return {index, fieldID - elementFieldID};
407}
408
409std::optional<DenseMap<Attribute, Type>>
410hw::StructType::getSubelementIndexMap() const {
411 DenseMap<Attribute, Type> destructured;
412 for (auto [i, field] : llvm::enumerate(getElements()))
413 destructured.insert(
414 {IntegerAttr::get(IndexType::get(getContext()), i), field.type});
415 return destructured;
416}
417
418Type hw::StructType::getTypeAtIndex(Attribute index) const {
419 auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
420 if (!indexAttr)
421 return {};
422
423 return getSubTypeByFieldID(indexAttr.getInt()).first;
424}
425
426//===----------------------------------------------------------------------===//
427// Union Type
428//===----------------------------------------------------------------------===//
429
430namespace circt {
431namespace hw {
432namespace detail {
434 return a.name == b.name && a.type == b.type && a.offset == b.offset;
435}
436// NOLINTNEXTLINE
437llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
438 return llvm::hash_combine(fi.name, fi.type, fi.offset);
439}
440} // namespace detail
441} // namespace hw
442} // namespace circt
443
444Type UnionType::parse(AsmParser &p) {
445 llvm::SmallVector<FieldInfo, 4> parameters;
446 llvm::StringSet<> nameSet;
447 bool hasDuplicateName = false;
448 if (p.parseCommaSeparatedList(
449 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
450 StringRef name;
451 Type type;
452
453 auto fieldLoc = p.getCurrentLocation();
454 if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
455 return failure();
456
457 if (!nameSet.insert(name).second) {
458 p.emitError(fieldLoc, "duplicate field name \'" + name +
459 "\' in hw.union type");
460 // Continue parsing to print all duplicates, but make sure to
461 // error eventually
462 hasDuplicateName = true;
463 }
464
465 size_t offset = 0;
466 if (succeeded(p.parseOptionalKeyword("offset")))
467 if (p.parseInteger(offset))
468 return failure();
469 parameters.push_back(UnionType::FieldInfo{
470 StringAttr::get(p.getContext(), name), type, offset});
471 return success();
472 }))
473 return Type();
474
475 if (hasDuplicateName)
476 return Type();
477
478 return get(p.getContext(), parameters);
479}
480
481void UnionType::print(AsmPrinter &odsPrinter) const {
482 odsPrinter << '<';
483 llvm::interleaveComma(
484 getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
485 odsPrinter << field.name.getValue() << ": " << field.type;
486 if (field.offset)
487 odsPrinter << " offset " << field.offset;
488 });
489 odsPrinter << ">";
490}
491
492LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
493 ArrayRef<UnionType::FieldInfo> elements) {
494 llvm::SmallDenseSet<StringAttr> fieldNameSet;
495 LogicalResult result = success();
496 fieldNameSet.reserve(elements.size());
497 for (const auto &elt : elements)
498 if (!fieldNameSet.insert(elt.name).second) {
499 result = failure();
500 emitError() << "duplicate field name '" << elt.name.getValue()
501 << "' in hw.union type";
502 }
503 return result;
504}
505
506std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
507 ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
508 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
509 if (elems[idx].name == fieldName)
510 return idx;
511 return {};
512}
513
514std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
515 return getFieldIndex(StringAttr::get(getContext(), fieldName));
516}
517
518UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
519 if (auto fieldIndex = getFieldIndex(fieldName))
520 return getElements()[*fieldIndex];
521 return FieldInfo();
522}
523
524Type UnionType::getFieldType(mlir::StringRef fieldName) {
525 return getFieldInfo(fieldName).type;
526}
527
528//===----------------------------------------------------------------------===//
529// Enum Type
530//===----------------------------------------------------------------------===//
531
532Type EnumType::parse(AsmParser &p) {
533 llvm::SmallVector<Attribute> fields;
534
535 if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
536 StringRef name;
537 if (p.parseKeyword(&name))
538 return failure();
539 fields.push_back(StringAttr::get(p.getContext(), name));
540 return success();
541 }))
542 return Type();
543
544 return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
545}
546
547void EnumType::print(AsmPrinter &p) const {
548 p << '<';
549 llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
550 p << llvm::cast<StringAttr>(enumerator).getValue();
551 });
552 p << ">";
553}
554
555bool EnumType::contains(mlir::StringRef field) {
556 return indexOf(field).has_value();
557}
558
559std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
560 for (auto it : llvm::enumerate(getFields()))
561 if (llvm::cast<StringAttr>(it.value()).getValue() == field)
562 return it.index();
563 return {};
564}
565
566size_t EnumType::getBitWidth() {
567 auto w = getFields().size();
568 if (w > 1)
569 return llvm::Log2_64_Ceil(getFields().size());
570 return 1;
571}
572
573//===----------------------------------------------------------------------===//
574// ArrayType
575//===----------------------------------------------------------------------===//
576
577static ParseResult parseHWArray(AsmParser &p, Attribute &dim, Type &inner) {
578 uint64_t dimLiteral;
579 auto int64Type = p.getBuilder().getIntegerType(64);
580
581 if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value()) {
582 if (failed(*res))
583 return failure();
584 dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
585 } else if (auto res64 = p.parseOptionalAttribute(dim, int64Type);
586 res64.has_value()) {
587 if (failed(*res64))
588 return failure();
589 } else
590 return p.emitError(p.getNameLoc(), "expected integer");
591
592 if (!isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>(dim)) {
593 p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
594 return failure();
595 }
596
597 if (p.parseXInDimensionList() || parseHWElementType(p, inner))
598 return failure();
599
600 return success();
601}
602
603static void printHWArray(AsmPrinter &p, Attribute dim, Type elementType) {
604 p.printAttributeWithoutType(dim);
605 p << "x";
607}
608
609size_t ArrayType::getNumElements() const {
610 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
611 return intAttr.getInt();
612 return -1;
613}
614
615LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
616 Type innerType, Attribute size) {
617 if (hasHWInOutType(innerType))
618 return emitError() << "hw.array cannot contain InOut types";
619 return success();
620}
621
622uint64_t ArrayType::getMaxFieldID() const {
623 return getNumElements() *
624 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
625}
626
627std::pair<Type, uint64_t>
628ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
629 if (fieldID == 0)
630 return {*this, 0};
631 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
632}
633
634std::pair<uint64_t, bool>
635ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
636 auto childRoot = getFieldID(index);
637 auto rangeEnd =
638 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
639 return std::make_pair(fieldID - childRoot,
640 fieldID >= childRoot && fieldID <= rangeEnd);
641}
642
643uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
644 assert(fieldID && "fieldID must be at least 1");
645 // Divide the field ID by the number of fieldID's per element.
646 return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
647}
648
649std::pair<uint64_t, uint64_t>
650ArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
651 auto index = getIndexForFieldID(fieldID);
652 auto elementFieldID = getFieldID(index);
653 return {index, fieldID - elementFieldID};
654}
655
656uint64_t ArrayType::getFieldID(uint64_t index) const {
657 return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
658}
659
660std::optional<DenseMap<Attribute, Type>>
661hw::ArrayType::getSubelementIndexMap() const {
662 DenseMap<Attribute, Type> destructured;
663 for (unsigned i = 0; i < getNumElements(); ++i)
664 destructured.insert(
665 {IntegerAttr::get(IndexType::get(getContext()), i), getElementType()});
666 return destructured;
667}
668
669Type hw::ArrayType::getTypeAtIndex(Attribute index) const {
670 return getElementType();
671}
672
673//===----------------------------------------------------------------------===//
674// UnpackedArrayType
675//===----------------------------------------------------------------------===//
676
677LogicalResult
678UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
679 Type innerType, Attribute size) {
680 if (!isHWValueType(innerType))
681 return emitError() << "invalid element for uarray type";
682 return success();
683}
684
685size_t UnpackedArrayType::getNumElements() const {
686 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
687 return intAttr.getInt();
688 return -1;
689}
690
691uint64_t UnpackedArrayType::getMaxFieldID() const {
692 return getNumElements() *
693 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
694}
695
696std::pair<Type, uint64_t>
697UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
698 if (fieldID == 0)
699 return {*this, 0};
700 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
701}
702
703std::pair<uint64_t, bool>
704UnpackedArrayType::projectToChildFieldID(uint64_t fieldID,
705 uint64_t index) const {
706 auto childRoot = getFieldID(index);
707 auto rangeEnd =
708 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
709 return std::make_pair(fieldID - childRoot,
710 fieldID >= childRoot && fieldID <= rangeEnd);
711}
712
713uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
714 assert(fieldID && "fieldID must be at least 1");
715 // Divide the field ID by the number of fieldID's per element.
716 return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
717}
718
719std::pair<uint64_t, uint64_t>
720UnpackedArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
721 auto index = getIndexForFieldID(fieldID);
722 auto elementFieldID = getFieldID(index);
723 return {index, fieldID - elementFieldID};
724}
725
726uint64_t UnpackedArrayType::getFieldID(uint64_t index) const {
727 return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
728}
729
730//===----------------------------------------------------------------------===//
731// InOutType
732//===----------------------------------------------------------------------===//
733
734LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
735 Type innerType) {
736 if (!isHWValueType(innerType))
737 return emitError() << "invalid element for hw.inout type " << innerType;
738 return success();
739}
740
741//===----------------------------------------------------------------------===//
742// TypeAliasType
743//===----------------------------------------------------------------------===//
744
745static Type computeCanonicalType(Type type) {
746 return llvm::TypeSwitch<Type, Type>(type)
747 .Case([](TypeAliasType t) {
748 return computeCanonicalType(t.getCanonicalType());
749 })
750 .Case([](ArrayType t) {
751 return ArrayType::get(computeCanonicalType(t.getElementType()),
752 t.getNumElements());
753 })
754 .Case([](UnpackedArrayType t) {
755 return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
756 t.getNumElements());
757 })
758 .Case([](StructType t) {
759 SmallVector<StructType::FieldInfo> fieldInfo;
760 for (auto field : t.getElements())
761 fieldInfo.push_back(StructType::FieldInfo{
762 field.name, computeCanonicalType(field.type)});
763 return StructType::get(t.getContext(), fieldInfo);
764 })
765 .Default([](Type t) { return t; });
766}
767
768TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
769 return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
770}
771
772Type TypeAliasType::parse(AsmParser &p) {
773 SymbolRefAttr ref;
774 Type type;
775 if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
776 p.parseType(type) || p.parseGreater())
777 return Type();
778
779 return get(ref, type);
780}
781
782void TypeAliasType::print(AsmPrinter &p) const {
783 p << "<" << getRef() << ", " << getInnerType() << ">";
784}
785
786/// Return the Typedecl referenced by this TypeAlias, given the module to look
787/// in. This returns null when the IR is malformed.
788TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
789 SymbolRefAttr ref = getRef();
790 auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
791 cache.getDefinition(ref.getRootReference()));
792 if (!typeScope)
793 return {};
794
795 return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
796}
797
798//===----------------------------------------------------------------------===//
799// ModuleType
800//===----------------------------------------------------------------------===//
801
802LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
803 ArrayRef<ModulePort> ports) {
804 if (llvm::any_of(ports, [](const ModulePort &port) {
805 return hasHWInOutType(port.type);
806 }))
807 return emitError() << "Ports cannot be inout types";
808 return success();
809}
810
811size_t ModuleType::getPortIdForInputId(size_t idx) {
812 assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
813 return getImpl()->inputToAbs[idx];
814}
815
816size_t ModuleType::getPortIdForOutputId(size_t idx) {
817 assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
818 return getImpl()->outputToAbs[idx];
819}
820
821size_t ModuleType::getInputIdForPortId(size_t idx) {
822 auto nIdx = getImpl()->absToInput[idx];
823 assert(nIdx != ~0ULL);
824 return nIdx;
825}
826
827size_t ModuleType::getOutputIdForPortId(size_t idx) {
828 auto nIdx = getImpl()->absToOutput[idx];
829 assert(nIdx != ~0ULL);
830 return nIdx;
831}
832
833size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
834
835size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
836
837size_t ModuleType::getNumPorts() { return getPorts().size(); }
838
839SmallVector<Type> ModuleType::getInputTypes() {
840 SmallVector<Type> retval;
841 for (auto &p : getPorts()) {
842 if (p.dir == ModulePort::Direction::Input)
843 retval.push_back(p.type);
844 else if (p.dir == ModulePort::Direction::InOut) {
845 retval.push_back(hw::InOutType::get(p.type));
846 }
847 }
848 return retval;
849}
850
851SmallVector<Type> ModuleType::getOutputTypes() {
852 SmallVector<Type> retval;
853 for (auto &p : getPorts())
854 if (p.dir == ModulePort::Direction::Output)
855 retval.push_back(p.type);
856 return retval;
857}
858
859SmallVector<Type> ModuleType::getPortTypes() {
860 SmallVector<Type> retval;
861 for (auto &p : getPorts())
862 retval.push_back(p.type);
863 return retval;
864}
865
866Type ModuleType::getInputType(size_t idx) {
867 const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
868 if (portInfo.dir != ModulePort::InOut)
869 return portInfo.type;
870 return InOutType::get(portInfo.type);
871}
872
873Type ModuleType::getOutputType(size_t idx) {
874 return getPorts()[getPortIdForOutputId(idx)].type;
875}
876
877SmallVector<Attribute> ModuleType::getInputNames() {
878 SmallVector<Attribute> retval;
879 for (auto &p : getPorts())
880 if (p.dir != ModulePort::Direction::Output)
881 retval.push_back(p.name);
882 return retval;
883}
884
885SmallVector<Attribute> ModuleType::getOutputNames() {
886 SmallVector<Attribute> retval;
887 for (auto &p : getPorts())
888 if (p.dir == ModulePort::Direction::Output)
889 retval.push_back(p.name);
890 return retval;
891}
892
893StringAttr ModuleType::getPortNameAttr(size_t idx) {
894 return getPorts()[idx].name;
895}
896
897StringRef ModuleType::getPortName(size_t idx) {
898 auto sa = getPortNameAttr(idx);
899 if (sa)
900 return sa.getValue();
901 return {};
902}
903
904StringAttr ModuleType::getInputNameAttr(size_t idx) {
905 return getPorts()[getPortIdForInputId(idx)].name;
906}
907
908StringRef ModuleType::getInputName(size_t idx) {
909 auto sa = getInputNameAttr(idx);
910 if (sa)
911 return sa.getValue();
912 return {};
913}
914
915StringAttr ModuleType::getOutputNameAttr(size_t idx) {
916 return getPorts()[getPortIdForOutputId(idx)].name;
917}
918
919StringRef ModuleType::getOutputName(size_t idx) {
920 auto sa = getOutputNameAttr(idx);
921 if (sa)
922 return sa.getValue();
923 return {};
924}
925
926bool ModuleType::isOutput(size_t idx) {
927 auto &p = getPorts()[idx];
928 return p.dir == ModulePort::Direction::Output;
929}
930
931FunctionType ModuleType::getFuncType() {
932 SmallVector<Type> inputs, outputs;
933 for (auto p : getPorts())
934 if (p.dir == ModulePort::Input)
935 inputs.push_back(p.type);
936 else if (p.dir == ModulePort::InOut)
937 inputs.push_back(InOutType::get(p.type));
938 else
939 outputs.push_back(p.type);
940 return FunctionType::get(getContext(), inputs, outputs);
941}
942
943ArrayRef<ModulePort> ModuleType::getPorts() const {
944 return getImpl()->getPorts();
945}
946
947FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
948 LocationAttr loc,
949 bool emitErrors) {
950 SmallVector<ModulePort, 8> resolvedPorts;
951 for (ModulePort port : getPorts()) {
952 FailureOr<Type> resolvedType =
953 evaluateParametricType(loc, parameters, port.type, emitErrors);
954 if (failed(resolvedType))
955 return failure();
956 port.type = *resolvedType;
957 resolvedPorts.push_back(port);
958 }
959 return ModuleType::get(getContext(), resolvedPorts);
960}
961
962static StringRef dirToStr(ModulePort::Direction dir) {
963 switch (dir) {
964 case ModulePort::Direction::Input:
965 return "input";
966 case ModulePort::Direction::Output:
967 return "output";
968 case ModulePort::Direction::InOut:
969 return "inout";
970 }
971}
972
973static ModulePort::Direction strToDir(StringRef str) {
974 if (str == "input")
975 return ModulePort::Direction::Input;
976 if (str == "output")
977 return ModulePort::Direction::Output;
978 if (str == "inout")
979 return ModulePort::Direction::InOut;
980 llvm::report_fatal_error("invalid direction");
981}
982
983/// Parse a list of field names and types within <>. E.g.:
984/// <input foo: i7, output bar: i8>
985static ParseResult parsePorts(AsmParser &p,
986 SmallVectorImpl<ModulePort> &ports) {
987 return p.parseCommaSeparatedList(
988 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
989 StringRef dir;
990 std::string name;
991 Type type;
992 if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
993 p.parseColon() || p.parseType(type))
994 return failure();
995 ports.push_back(
996 {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
997 return success();
998 });
999}
1000
1001/// Print out a list of named fields surrounded by <>.
1002static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
1003 p << '<';
1004 llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
1005 p << dirToStr(port.dir) << " ";
1006 p.printKeywordOrString(port.name.getValue());
1007 p << " : " << port.type;
1008 });
1009 p << ">";
1010}
1011
1012Type ModuleType::parse(AsmParser &odsParser) {
1013 llvm::SmallVector<ModulePort, 4> ports;
1014 if (parsePorts(odsParser, ports))
1015 return Type();
1016 return get(odsParser.getContext(), ports);
1017}
1018
1019void ModuleType::print(AsmPrinter &odsPrinter) const {
1020 printPorts(odsPrinter, getPorts());
1021}
1022
1023ModuleType circt::hw::detail::fnToMod(Operation *op,
1024 ArrayRef<Attribute> inputNames,
1025 ArrayRef<Attribute> outputNames) {
1026 return fnToMod(
1027 cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
1028 inputNames, outputNames);
1029}
1030
1031ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
1032 ArrayRef<Attribute> inputNames,
1033 ArrayRef<Attribute> outputNames) {
1034 SmallVector<ModulePort> ports;
1035 if (!inputNames.empty()) {
1036 for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
1037 if (auto iot = dyn_cast<hw::InOutType>(t))
1038 ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1039 ModulePort::Direction::InOut});
1040 else
1041 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1042 } else {
1043 for (auto t : fnty.getInputs())
1044 if (auto iot = dyn_cast<hw::InOutType>(t))
1045 ports.push_back(
1046 {{}, iot.getElementType(), ModulePort::Direction::InOut});
1047 else
1048 ports.push_back({{}, t, ModulePort::Direction::Input});
1049 }
1050 if (!outputNames.empty()) {
1051 for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1052 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1053 } else {
1054 for (auto t : fnty.getResults())
1055 ports.push_back({{}, t, ModulePort::Direction::Output});
1056 }
1057 return ModuleType::get(fnty.getContext(), ports);
1058}
1059
1061 : ports(inPorts) {
1062 size_t nextInput = 0;
1063 size_t nextOutput = 0;
1064 for (auto [idx, p] : llvm::enumerate(ports)) {
1065 if (p.dir == ModulePort::Direction::Output) {
1066 outputToAbs.push_back(idx);
1067 absToOutput.push_back(nextOutput);
1068 absToInput.push_back(~0ULL);
1069 ++nextOutput;
1070 } else {
1071 inputToAbs.push_back(idx);
1072 absToInput.push_back(nextInput);
1073 absToOutput.push_back(~0ULL);
1074 ++nextInput;
1075 }
1076 }
1077}
1078
1079//===----------------------------------------------------------------------===//
1080// BoilerPlate
1081//===----------------------------------------------------------------------===//
1082
1083void HWDialect::registerTypes() {
1084 addTypes<
1085#define GET_TYPEDEF_LIST
1086#include "circt/Dialect/HW/HWTypes.cpp.inc"
1087 >();
1088}
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static ModulePort::Direction strToDir(StringRef str)
Definition HWTypes.cpp:973
static void printPorts(AsmPrinter &p, ArrayRef< ModulePort > ports)
Print out a list of named fields surrounded by <>.
Definition HWTypes.cpp:1002
static void printFields(AsmPrinter &p, ArrayRef< FieldInfo > fields)
Print out a list of named fields surrounded by <>.
Definition HWTypes.cpp:285
static StringRef dirToStr(ModulePort::Direction dir)
Definition HWTypes.cpp:962
static ParseResult parseHWArray(AsmParser &parser, Attribute &dim, Type &elementType)
Definition HWTypes.cpp:577
static ParseResult parseHWElementType(AsmParser &parser, Type &elementType)
Parse and print nested HW types nicely.
Definition HWTypes.cpp:172
static ParseResult parsePorts(AsmParser &p, SmallVectorImpl< ModulePort > &ports)
Parse a list of field names and types within <>.
Definition HWTypes.cpp:985
static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType)
Definition HWTypes.cpp:603
static std::pair< uint64_t, SmallVector< uint64_t > > getFieldIDsStruct(const StructType &st)
Definition HWTypes.cpp:341
static ParseResult parseFields(AsmParser &p, SmallVectorImpl< FieldInfo > &parameters)
Parse a list of unique field names and types within <>.
Definition HWTypes.cpp:253
static Type computeCanonicalType(Type type)
Definition HWTypes.cpp:745
static void printHWElementType(AsmPrinter &printer, Type dim)
Definition HWTypes.cpp:194
@ Input
Definition HW.h:35
@ Output
Definition HW.h:35
static unsigned getFieldID(BundleType type, unsigned index)
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:227
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
llvm::hash_code hash_value(const FieldInfo &fi)
Definition HWTypes.cpp:244
bool operator==(const FieldInfo &a, const FieldInfo &b)
Definition HWTypes.cpp:241
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition HWTypes.cpp:1023
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition HWTypes.cpp:60
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition HWTypes.cpp:73
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
Struct defining a field. Used in structs.
Definition HWTypes.h:92
mlir::StringAttr name
Definition HWTypes.h:93
SmallVector< ModulePort > ports
The parametric data held by the storage class.
Definition HWTypes.h:70
ModuleTypeStorage(ArrayRef< ModulePort > inPorts)
Definition HWTypes.cpp:1060
SmallVector< size_t > absToInput
Definition HWTypes.h:74
SmallVector< size_t > outputToAbs
Definition HWTypes.h:73
SmallVector< size_t > inputToAbs
Definition HWTypes.h:72
SmallVector< size_t > absToOutput
Definition HWTypes.h:75
Struct defining a field with an offset. Used in unions.
Definition HWTypes.h:98