CIRCT 22.0.0git
Loading...
Searching...
No Matches
HWTypes.cpp
Go to the documentation of this file.
1//===- HWTypes.cpp - HW types code defs -----------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Implementation logic for HW data types.
10//
11//===----------------------------------------------------------------------===//
12
18#include "circt/Support/LLVM.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/BuiltinTypes.h"
21#include "mlir/IR/Diagnostics.h"
22#include "mlir/IR/DialectImplementation.h"
23#include "mlir/IR/StorageUniquerSupport.h"
24#include "mlir/IR/Types.h"
25#include "mlir/Interfaces/MemorySlotInterfaces.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/ADT/StringExtras.h"
28#include "llvm/ADT/StringSet.h"
29#include "llvm/ADT/TypeSwitch.h"
30
31using namespace circt;
32using namespace circt::hw;
33using namespace circt::hw::detail;
34
35static ParseResult parseHWArray(AsmParser &parser, Attribute &dim,
36 Type &elementType);
37static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType);
38
39static ParseResult parseHWElementType(AsmParser &parser, Type &elementType);
40static void printHWElementType(AsmPrinter &printer, Type dim);
41
42#define GET_TYPEDEF_CLASSES
43#include "circt/Dialect/HW/HWTypes.cpp.inc"
44
45//===----------------------------------------------------------------------===//
46// Type Helpers
47//===----------------------------------------------------------------------===//
48
49mlir::Type circt::hw::getCanonicalType(mlir::Type type) {
50 Type canonicalType;
51 if (auto typeAlias = dyn_cast<TypeAliasType>(type))
52 canonicalType = typeAlias.getCanonicalType();
53 else
54 canonicalType = type;
55 return canonicalType;
56}
57
58/// Return true if the specified type is a value HW Integer type. This checks
59/// that it is a signless standard dialect type or a hw::IntType.
60bool circt::hw::isHWIntegerType(mlir::Type type) {
61 Type canonicalType = getCanonicalType(type);
62
63 if (isa<hw::IntType>(canonicalType))
64 return true;
65
66 auto intType = dyn_cast<IntegerType>(canonicalType);
67 if (!intType || !intType.isSignless())
68 return false;
69
70 return true;
71}
72
73bool circt::hw::isHWEnumType(mlir::Type type) {
74 return isa<hw::EnumType>(getCanonicalType(type));
75}
76
77/// Return true if the specified type can be used as an HW value type, that is
78/// the set of types that can be composed together to represent synthesized,
79/// hardware but not marker types like InOutType.
80bool circt::hw::isHWValueType(Type type) {
81 // Signless and signed integer types are both valid.
82 if (isa<IntegerType, IntType, EnumType>(type))
83 return true;
84
85 if (auto array = dyn_cast<ArrayType>(type))
86 return isHWValueType(array.getElementType());
87
88 if (auto array = dyn_cast<UnpackedArrayType>(type))
89 return isHWValueType(array.getElementType());
90
91 if (auto t = dyn_cast<StructType>(type))
92 return llvm::all_of(t.getElements(),
93 [](auto f) { return isHWValueType(f.type); });
94
95 if (auto t = dyn_cast<UnionType>(type))
96 return llvm::all_of(t.getElements(),
97 [](auto f) { return isHWValueType(f.type); });
98
99 if (auto t = dyn_cast<TypeAliasType>(type))
100 return isHWValueType(t.getCanonicalType());
101
102 return false;
103}
104
105/// Return the hardware bit width of a type. Does not reflect any encoding,
106/// padding, or storage scheme, just the bit (and wire width) of a
107/// statically-size type. Reflects the number of wires needed to transmit a
108/// value of this type. Returns -1 if the type is not known or cannot be
109/// statically computed.
110int64_t circt::hw::getBitWidth(mlir::Type type) {
111 // Handle built-in types that don't implement the interface. Do this first
112 // since it is faster than downcasting to an interface.
113 return llvm::TypeSwitch<::mlir::Type, int64_t>(type)
114 .Case<IntegerType>(
115 [](IntegerType t) { return t.getIntOrFloatBitWidth(); })
116 .Default([](Type type) -> int64_t {
117 // If type implements the BitWidthTypeInterface, use it.
118 if (auto iface = dyn_cast<BitWidthTypeInterface>(type)) {
119 std::optional<int64_t> width = iface.getBitWidth();
120 return width.has_value() ? *width : -1;
121 }
122 return -1;
123 });
124}
125
126/// Return true if the specified type contains known marker types like
127/// InOutType. Unlike isHWValueType, this is not conservative, it only returns
128/// false on known InOut types, rather than any unknown types.
129bool circt::hw::hasHWInOutType(Type type) {
130 if (auto array = dyn_cast<ArrayType>(type))
131 return hasHWInOutType(array.getElementType());
132
133 if (auto array = dyn_cast<UnpackedArrayType>(type))
134 return hasHWInOutType(array.getElementType());
135
136 if (auto t = dyn_cast<StructType>(type)) {
137 return std::any_of(t.getElements().begin(), t.getElements().end(),
138 [](const auto &f) { return hasHWInOutType(f.type); });
139 }
140
141 if (auto t = dyn_cast<TypeAliasType>(type))
142 return hasHWInOutType(t.getCanonicalType());
143
144 return isa<InOutType>(type);
145}
146
147/// Parse and print nested HW types nicely. These helper methods allow eliding
148/// the "hw." prefix on array, inout, and other types when in a context that
149/// expects HW subelement types.
150static ParseResult parseHWElementType(AsmParser &p, Type &result) {
151 // If this is an HW dialect type, then we don't need/want the !hw. prefix
152 // redundantly specified.
153 auto fullString = static_cast<DialectAsmParser &>(p).getFullSymbolSpec();
154 auto *curPtr = p.getCurrentLocation().getPointer();
155 auto typeString =
156 StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
157
158 if (typeString.starts_with("array<") || typeString.starts_with("inout<") ||
159 typeString.starts_with("uarray<") || typeString.starts_with("struct<") ||
160 typeString.starts_with("typealias<") || typeString.starts_with("int<") ||
161 typeString.starts_with("enum<") || typeString.starts_with("union<")) {
162 llvm::StringRef mnemonic;
163 if (auto parseResult = generatedTypeParser(p, &mnemonic, result);
164 parseResult.has_value())
165 return *parseResult;
166 return p.emitError(p.getNameLoc(), "invalid type `") << typeString << "`";
167 }
168
169 return p.parseType(result);
170}
171
172static void printHWElementType(AsmPrinter &p, Type element) {
173 if (succeeded(generatedTypePrinter(element, p)))
174 return;
175 p.printType(element);
176}
177
178//===----------------------------------------------------------------------===//
179// Int Type
180//===----------------------------------------------------------------------===//
181
182Type IntType::get(mlir::TypedAttr width) {
183 // The width expression must always be a 32-bit wide integer type itself.
184 auto widthWidth = llvm::dyn_cast<IntegerType>(width.getType());
185 assert(widthWidth && widthWidth.getWidth() == 32 &&
186 "!hw.int width must be 32-bits");
187 (void)widthWidth;
188
189 if (auto cstWidth = llvm::dyn_cast<IntegerAttr>(width))
190 return IntegerType::get(width.getContext(),
191 cstWidth.getValue().getZExtValue());
192
193 return Base::get(width.getContext(), width);
194}
195
196Type IntType::parse(AsmParser &p) {
197 // The bitwidth of the parameter size is always 32 bits.
198 auto int32Type = p.getBuilder().getIntegerType(32);
199
200 mlir::TypedAttr width;
201 if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater())
202 return Type();
203 return get(width);
204}
205
206void IntType::print(AsmPrinter &p) const {
207 p << "<";
208 p.printAttributeWithoutType(getWidth());
209 p << '>';
210}
211
212//===----------------------------------------------------------------------===//
213// Struct Type
214//===----------------------------------------------------------------------===//
215
216namespace circt {
217namespace hw {
218namespace detail {
219bool operator==(const FieldInfo &a, const FieldInfo &b) {
220 return a.name == b.name && a.type == b.type;
221}
222llvm::hash_code hash_value(const FieldInfo &fi) {
223 return llvm::hash_combine(fi.name, fi.type);
224}
225} // namespace detail
226} // namespace hw
227} // namespace circt
228
229/// Parse a list of unique field names and types within <>. E.g.:
230/// <foo: i7, bar: i8>
231static ParseResult parseFields(AsmParser &p,
232 SmallVectorImpl<FieldInfo> &parameters) {
233 llvm::StringSet<> nameSet;
234 bool hasDuplicateName = false;
235 auto parseResult = p.parseCommaSeparatedList(
236 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
237 std::string name;
238 Type type;
239
240 auto fieldLoc = p.getCurrentLocation();
241 if (p.parseKeywordOrString(&name) || p.parseColon() ||
242 p.parseType(type))
243 return failure();
244
245 if (!nameSet.insert(name).second) {
246 p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
247 // Continue parsing to print all duplicates, but make sure to error
248 // eventually
249 hasDuplicateName = true;
250 }
251
252 parameters.push_back(
253 FieldInfo{StringAttr::get(p.getContext(), name), type});
254 return success();
255 });
256
257 if (hasDuplicateName)
258 return failure();
259 return parseResult;
260}
261
262/// Print out a list of named fields surrounded by <>.
263static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
264 p << '<';
265 llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
266 p.printKeywordOrString(field.name.getValue());
267 p << ": " << field.type;
268 });
269 p << ">";
270}
271
272Type StructType::parse(AsmParser &p) {
273 llvm::SmallVector<FieldInfo, 4> parameters;
274 if (parseFields(p, parameters))
275 return Type();
276 return get(p.getContext(), parameters);
277}
278
279LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
280 ArrayRef<StructType::FieldInfo> elements) {
281 llvm::SmallDenseSet<StringAttr> fieldNameSet;
282 LogicalResult result = success();
283 fieldNameSet.reserve(elements.size());
284 for (const auto &elt : elements)
285 if (!fieldNameSet.insert(elt.name).second) {
286 result = failure();
287 emitError() << "duplicate field name '" << elt.name.getValue()
288 << "' in hw.struct type";
289 }
290 return result;
291}
292
293void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
294
295Type StructType::getFieldType(mlir::StringRef fieldName) {
296 for (const auto &field : getElements())
297 if (field.name == fieldName)
298 return field.type;
299 return Type();
300}
301
302std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
303 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
304 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
305 if (elems[idx].name == fieldName)
306 return idx;
307 return {};
308}
309
310std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
311 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
312 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
313 if (elems[idx].name == fieldName)
314 return idx;
315 return {};
316}
317
318static std::pair<uint64_t, SmallVector<uint64_t>>
319getFieldIDsStruct(const StructType &st) {
320 uint64_t fieldID = 0;
321 auto elements = st.getElements();
322 SmallVector<uint64_t> fieldIDs;
323 fieldIDs.reserve(elements.size());
324 for (auto &element : elements) {
325 auto type = element.type;
326 fieldID += 1;
327 fieldIDs.push_back(fieldID);
328 // Increment the field ID for the next field by the number of subfields.
329 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
330 }
331 return {fieldID, fieldIDs};
332}
333
334void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
335 for (const auto &field : getElements())
336 types.push_back(field.type);
337}
338
339uint64_t StructType::getMaxFieldID() const {
340 uint64_t fieldID = 0;
341 for (const auto &field : getElements())
342 fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
343 return fieldID;
344}
345
346std::pair<Type, uint64_t>
347StructType::getSubTypeByFieldID(uint64_t fieldID) const {
348 if (fieldID == 0)
349 return {*this, 0};
350 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
351 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
352 auto subfieldIndex = std::distance(fieldIDs.begin(), it);
353 auto subfieldType = getElements()[subfieldIndex].type;
354 auto subfieldID = fieldID - fieldIDs[subfieldIndex];
355 return {subfieldType, subfieldID};
356}
357
358std::pair<uint64_t, bool>
359StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
360 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
361 auto childRoot = fieldIDs[index];
362 auto rangeEnd =
363 index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
364 return std::make_pair(fieldID - childRoot,
365 fieldID >= childRoot && fieldID <= rangeEnd);
366}
367
368uint64_t StructType::getFieldID(uint64_t index) const {
369 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
370 return fieldIDs[index];
371}
372
373uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
374 assert(!getElements().empty() && "Bundle must have >0 fields");
375 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
376 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
377 return std::distance(fieldIDs.begin(), it);
378}
379
380std::pair<uint64_t, uint64_t>
381StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
382 auto index = getIndexForFieldID(fieldID);
383 auto elementFieldID = getFieldID(index);
384 return {index, fieldID - elementFieldID};
385}
386
387std::optional<DenseMap<Attribute, Type>>
388hw::StructType::getSubelementIndexMap() const {
389 DenseMap<Attribute, Type> destructured;
390 for (auto [i, field] : llvm::enumerate(getElements()))
391 destructured.insert(
392 {IntegerAttr::get(IndexType::get(getContext()), i), field.type});
393 return destructured;
394}
395
396Type hw::StructType::getTypeAtIndex(Attribute index) const {
397 auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
398 if (!indexAttr)
399 return {};
400
401 return getSubTypeByFieldID(indexAttr.getInt()).first;
402}
403
404std::optional<int64_t> StructType::getBitWidth() const {
405 int64_t total = 0;
406 for (auto field : getElements()) {
407 int64_t fieldSize = hw::getBitWidth(field.type);
408 if (fieldSize < 0)
409 return std::nullopt;
410 total += fieldSize;
411 }
412 return total;
413}
414
415//===----------------------------------------------------------------------===//
416// Union Type
417//===----------------------------------------------------------------------===//
418
419namespace circt {
420namespace hw {
421namespace detail {
423 return a.name == b.name && a.type == b.type && a.offset == b.offset;
424}
425// NOLINTNEXTLINE
426llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
427 return llvm::hash_combine(fi.name, fi.type, fi.offset);
428}
429} // namespace detail
430} // namespace hw
431} // namespace circt
432
433Type UnionType::parse(AsmParser &p) {
434 llvm::SmallVector<FieldInfo, 4> parameters;
435 llvm::StringSet<> nameSet;
436 bool hasDuplicateName = false;
437 if (p.parseCommaSeparatedList(
438 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
439 StringRef name;
440 Type type;
441
442 auto fieldLoc = p.getCurrentLocation();
443 if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
444 return failure();
445
446 if (!nameSet.insert(name).second) {
447 p.emitError(fieldLoc, "duplicate field name \'" + name +
448 "\' in hw.union type");
449 // Continue parsing to print all duplicates, but make sure to
450 // error eventually
451 hasDuplicateName = true;
452 }
453
454 size_t offset = 0;
455 if (succeeded(p.parseOptionalKeyword("offset")))
456 if (p.parseInteger(offset))
457 return failure();
458 parameters.push_back(UnionType::FieldInfo{
459 StringAttr::get(p.getContext(), name), type, offset});
460 return success();
461 }))
462 return Type();
463
464 if (hasDuplicateName)
465 return Type();
466
467 return get(p.getContext(), parameters);
468}
469
470void UnionType::print(AsmPrinter &odsPrinter) const {
471 odsPrinter << '<';
472 llvm::interleaveComma(
473 getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
474 odsPrinter << field.name.getValue() << ": " << field.type;
475 if (field.offset)
476 odsPrinter << " offset " << field.offset;
477 });
478 odsPrinter << ">";
479}
480
481LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
482 ArrayRef<UnionType::FieldInfo> elements) {
483 llvm::SmallDenseSet<StringAttr> fieldNameSet;
484 LogicalResult result = success();
485 fieldNameSet.reserve(elements.size());
486 for (const auto &elt : elements)
487 if (!fieldNameSet.insert(elt.name).second) {
488 result = failure();
489 emitError() << "duplicate field name '" << elt.name.getValue()
490 << "' in hw.union type";
491 }
492 return result;
493}
494
495std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
496 ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
497 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
498 if (elems[idx].name == fieldName)
499 return idx;
500 return {};
501}
502
503std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
504 return getFieldIndex(StringAttr::get(getContext(), fieldName));
505}
506
507UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
508 if (auto fieldIndex = getFieldIndex(fieldName))
509 return getElements()[*fieldIndex];
510 return FieldInfo();
511}
512
513Type UnionType::getFieldType(mlir::StringRef fieldName) {
514 return getFieldInfo(fieldName).type;
515}
516
517std::optional<int64_t> UnionType::getBitWidth() const {
518 int64_t maxSize = 0;
519 for (auto field : getElements()) {
520 int64_t fieldSize = hw::getBitWidth(field.type);
521 if (fieldSize < 0)
522 return std::nullopt;
523 fieldSize += field.offset;
524 if (fieldSize > maxSize)
525 maxSize = fieldSize;
526 }
527 return maxSize;
528}
529
530//===----------------------------------------------------------------------===//
531// Enum Type
532//===----------------------------------------------------------------------===//
533
534Type EnumType::parse(AsmParser &p) {
535 llvm::SmallVector<Attribute> fields;
536
537 if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
538 StringRef name;
539 if (p.parseKeyword(&name))
540 return failure();
541 fields.push_back(StringAttr::get(p.getContext(), name));
542 return success();
543 }))
544 return Type();
545
546 return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
547}
548
549void EnumType::print(AsmPrinter &p) const {
550 p << '<';
551 llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
552 p << llvm::cast<StringAttr>(enumerator).getValue();
553 });
554 p << ">";
555}
556
557bool EnumType::contains(mlir::StringRef field) {
558 return indexOf(field).has_value();
559}
560
561std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
562 for (auto it : llvm::enumerate(getFields()))
563 if (llvm::cast<StringAttr>(it.value()).getValue() == field)
564 return it.index();
565 return {};
566}
567
568std::optional<int64_t> EnumType::getBitWidth() const {
569 auto w = getFields().size();
570 if (w > 1)
571 return llvm::Log2_64_Ceil(w);
572 return 1;
573}
574
575//===----------------------------------------------------------------------===//
576// ArrayType
577//===----------------------------------------------------------------------===//
578
579static ParseResult parseHWArray(AsmParser &p, Attribute &dim, Type &inner) {
580 uint64_t dimLiteral;
581 auto int64Type = p.getBuilder().getIntegerType(64);
582
583 if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value()) {
584 if (failed(*res))
585 return failure();
586 dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
587 } else if (auto res64 = p.parseOptionalAttribute(dim, int64Type);
588 res64.has_value()) {
589 if (failed(*res64))
590 return failure();
591 } else
592 return p.emitError(p.getNameLoc(), "expected integer");
593
594 if (!isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>(dim)) {
595 p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
596 return failure();
597 }
598
599 if (p.parseXInDimensionList() || parseHWElementType(p, inner))
600 return failure();
601
602 return success();
603}
604
605static void printHWArray(AsmPrinter &p, Attribute dim, Type elementType) {
606 p.printAttributeWithoutType(dim);
607 p << "x";
609}
610
611size_t ArrayType::getNumElements() const {
612 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
613 return intAttr.getInt();
614 return -1;
615}
616
617LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
618 Type innerType, Attribute size) {
619 if (hasHWInOutType(innerType))
620 return emitError() << "hw.array cannot contain InOut types";
621 return success();
622}
623
624uint64_t ArrayType::getMaxFieldID() const {
625 return getNumElements() *
626 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
627}
628
629std::pair<Type, uint64_t>
630ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
631 if (fieldID == 0)
632 return {*this, 0};
633 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
634}
635
636std::pair<uint64_t, bool>
637ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
638 auto childRoot = getFieldID(index);
639 auto rangeEnd =
640 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
641 return std::make_pair(fieldID - childRoot,
642 fieldID >= childRoot && fieldID <= rangeEnd);
643}
644
645uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
646 assert(fieldID && "fieldID must be at least 1");
647 // Divide the field ID by the number of fieldID's per element.
648 return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
649}
650
651std::pair<uint64_t, uint64_t>
652ArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
653 auto index = getIndexForFieldID(fieldID);
654 auto elementFieldID = getFieldID(index);
655 return {index, fieldID - elementFieldID};
656}
657
658uint64_t ArrayType::getFieldID(uint64_t index) const {
659 return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
660}
661
662std::optional<DenseMap<Attribute, Type>>
663hw::ArrayType::getSubelementIndexMap() const {
664 DenseMap<Attribute, Type> destructured;
665 for (unsigned i = 0; i < getNumElements(); ++i)
666 destructured.insert(
667 {IntegerAttr::get(IndexType::get(getContext()), i), getElementType()});
668 return destructured;
669}
670
671Type hw::ArrayType::getTypeAtIndex(Attribute index) const {
672 return getElementType();
673}
674
675std::optional<int64_t> hw::ArrayType::getBitWidth() const {
676 auto elementBitWidth = hw::getBitWidth(getElementType());
677 if (elementBitWidth < 0)
678 return std::nullopt;
679 int64_t numElements = getNumElements();
680 if (numElements < 0)
681 return std::nullopt;
682 return numElements * elementBitWidth;
683}
684
685//===----------------------------------------------------------------------===//
686// UnpackedArrayType
687//===----------------------------------------------------------------------===//
688
689LogicalResult
690UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
691 Type innerType, Attribute size) {
692 if (!isHWValueType(innerType))
693 return emitError() << "invalid element for uarray type";
694 return success();
695}
696
697size_t UnpackedArrayType::getNumElements() const {
698 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
699 return intAttr.getInt();
700 return -1;
701}
702
703uint64_t UnpackedArrayType::getMaxFieldID() const {
704 return getNumElements() *
705 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
706}
707
708std::pair<Type, uint64_t>
709UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
710 if (fieldID == 0)
711 return {*this, 0};
712 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
713}
714
715std::pair<uint64_t, bool>
716UnpackedArrayType::projectToChildFieldID(uint64_t fieldID,
717 uint64_t index) const {
718 auto childRoot = getFieldID(index);
719 auto rangeEnd =
720 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
721 return std::make_pair(fieldID - childRoot,
722 fieldID >= childRoot && fieldID <= rangeEnd);
723}
724
725uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
726 assert(fieldID && "fieldID must be at least 1");
727 // Divide the field ID by the number of fieldID's per element.
728 return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
729}
730
731std::pair<uint64_t, uint64_t>
732UnpackedArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
733 auto index = getIndexForFieldID(fieldID);
734 auto elementFieldID = getFieldID(index);
735 return {index, fieldID - elementFieldID};
736}
737
738uint64_t UnpackedArrayType::getFieldID(uint64_t index) const {
739 return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
740}
741
742std::optional<int64_t> UnpackedArrayType::getBitWidth() const {
743 auto elementBitWidth = hw::getBitWidth(getElementType());
744 if (elementBitWidth < 0)
745 return std::nullopt;
746 int64_t dimBitWidth = getNumElements();
747 if (dimBitWidth < 0)
748 return std::nullopt;
749 return (int64_t)getNumElements() * elementBitWidth;
750}
751
752//===----------------------------------------------------------------------===//
753// InOutType
754//===----------------------------------------------------------------------===//
755
756LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
757 Type innerType) {
758 if (!isHWValueType(innerType))
759 return emitError() << "invalid element for hw.inout type " << innerType;
760 return success();
761}
762
763//===----------------------------------------------------------------------===//
764// TypeAliasType
765//===----------------------------------------------------------------------===//
766
767static Type computeCanonicalType(Type type) {
768 return llvm::TypeSwitch<Type, Type>(type)
769 .Case([](TypeAliasType t) {
770 return computeCanonicalType(t.getCanonicalType());
771 })
772 .Case([](ArrayType t) {
773 return ArrayType::get(computeCanonicalType(t.getElementType()),
774 t.getNumElements());
775 })
776 .Case([](UnpackedArrayType t) {
777 return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
778 t.getNumElements());
779 })
780 .Case([](StructType t) {
781 SmallVector<StructType::FieldInfo> fieldInfo;
782 for (auto field : t.getElements())
783 fieldInfo.push_back(StructType::FieldInfo{
784 field.name, computeCanonicalType(field.type)});
785 return StructType::get(t.getContext(), fieldInfo);
786 })
787 .Default([](Type t) { return t; });
788}
789
790TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
791 return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
792}
793
794Type TypeAliasType::parse(AsmParser &p) {
795 SymbolRefAttr ref;
796 Type type;
797 if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
798 p.parseType(type) || p.parseGreater())
799 return Type();
800
801 return get(ref, type);
802}
803
804void TypeAliasType::print(AsmPrinter &p) const {
805 p << "<" << getRef() << ", " << getInnerType() << ">";
806}
807
808/// Return the Typedecl referenced by this TypeAlias, given the module to look
809/// in. This returns null when the IR is malformed.
810TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
811 SymbolRefAttr ref = getRef();
812 auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
813 cache.getDefinition(ref.getRootReference()));
814 if (!typeScope)
815 return {};
816
817 return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
818}
819
820std::optional<int64_t> TypeAliasType::getBitWidth() const {
821 auto width = hw::getBitWidth(getCanonicalType());
822 if (width < 0)
823 return std::nullopt;
824 return width;
825}
826
827//===----------------------------------------------------------------------===//
828// ModuleType
829//===----------------------------------------------------------------------===//
830
831LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
832 ArrayRef<ModulePort> ports) {
833 if (llvm::any_of(ports, [](const ModulePort &port) {
834 return hasHWInOutType(port.type);
835 }))
836 return emitError() << "Ports cannot be inout types";
837 return success();
838}
839
840size_t ModuleType::getPortIdForInputId(size_t idx) {
841 assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
842 return getImpl()->inputToAbs[idx];
843}
844
845size_t ModuleType::getPortIdForOutputId(size_t idx) {
846 assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
847 return getImpl()->outputToAbs[idx];
848}
849
850size_t ModuleType::getInputIdForPortId(size_t idx) {
851 auto nIdx = getImpl()->absToInput[idx];
852 assert(nIdx != ~0ULL);
853 return nIdx;
854}
855
856size_t ModuleType::getOutputIdForPortId(size_t idx) {
857 auto nIdx = getImpl()->absToOutput[idx];
858 assert(nIdx != ~0ULL);
859 return nIdx;
860}
861
862size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
863
864size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
865
866size_t ModuleType::getNumPorts() { return getPorts().size(); }
867
868SmallVector<Type> ModuleType::getInputTypes() {
869 SmallVector<Type> retval;
870 for (auto &p : getPorts()) {
871 if (p.dir == ModulePort::Direction::Input)
872 retval.push_back(p.type);
873 else if (p.dir == ModulePort::Direction::InOut) {
874 retval.push_back(hw::InOutType::get(p.type));
875 }
876 }
877 return retval;
878}
879
880SmallVector<Type> ModuleType::getOutputTypes() {
881 SmallVector<Type> retval;
882 for (auto &p : getPorts())
883 if (p.dir == ModulePort::Direction::Output)
884 retval.push_back(p.type);
885 return retval;
886}
887
888SmallVector<Type> ModuleType::getPortTypes() {
889 SmallVector<Type> retval;
890 for (auto &p : getPorts())
891 retval.push_back(p.type);
892 return retval;
893}
894
895Type ModuleType::getInputType(size_t idx) {
896 const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
897 if (portInfo.dir != ModulePort::InOut)
898 return portInfo.type;
899 return InOutType::get(portInfo.type);
900}
901
902Type ModuleType::getOutputType(size_t idx) {
903 return getPorts()[getPortIdForOutputId(idx)].type;
904}
905
906SmallVector<Attribute> ModuleType::getInputNames() {
907 SmallVector<Attribute> retval;
908 for (auto &p : getPorts())
909 if (p.dir != ModulePort::Direction::Output)
910 retval.push_back(p.name);
911 return retval;
912}
913
914SmallVector<Attribute> ModuleType::getOutputNames() {
915 SmallVector<Attribute> retval;
916 for (auto &p : getPorts())
917 if (p.dir == ModulePort::Direction::Output)
918 retval.push_back(p.name);
919 return retval;
920}
921
922StringAttr ModuleType::getPortNameAttr(size_t idx) {
923 return getPorts()[idx].name;
924}
925
926StringRef ModuleType::getPortName(size_t idx) {
927 auto sa = getPortNameAttr(idx);
928 if (sa)
929 return sa.getValue();
930 return {};
931}
932
933StringAttr ModuleType::getInputNameAttr(size_t idx) {
934 return getPorts()[getPortIdForInputId(idx)].name;
935}
936
937StringRef ModuleType::getInputName(size_t idx) {
938 auto sa = getInputNameAttr(idx);
939 if (sa)
940 return sa.getValue();
941 return {};
942}
943
944StringAttr ModuleType::getOutputNameAttr(size_t idx) {
945 return getPorts()[getPortIdForOutputId(idx)].name;
946}
947
948StringRef ModuleType::getOutputName(size_t idx) {
949 auto sa = getOutputNameAttr(idx);
950 if (sa)
951 return sa.getValue();
952 return {};
953}
954
955bool ModuleType::isOutput(size_t idx) {
956 auto &p = getPorts()[idx];
957 return p.dir == ModulePort::Direction::Output;
958}
959
960FunctionType ModuleType::getFuncType() {
961 SmallVector<Type> inputs, outputs;
962 for (auto p : getPorts())
963 if (p.dir == ModulePort::Input)
964 inputs.push_back(p.type);
965 else if (p.dir == ModulePort::InOut)
966 inputs.push_back(InOutType::get(p.type));
967 else
968 outputs.push_back(p.type);
969 return FunctionType::get(getContext(), inputs, outputs);
970}
971
972ArrayRef<ModulePort> ModuleType::getPorts() const {
973 return getImpl()->getPorts();
974}
975
976FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
977 LocationAttr loc,
978 bool emitErrors) {
979 SmallVector<ModulePort, 8> resolvedPorts;
980 for (ModulePort port : getPorts()) {
981 FailureOr<Type> resolvedType =
982 evaluateParametricType(loc, parameters, port.type, emitErrors);
983 if (failed(resolvedType))
984 return failure();
985 port.type = *resolvedType;
986 resolvedPorts.push_back(port);
987 }
988 return ModuleType::get(getContext(), resolvedPorts);
989}
990
991static StringRef dirToStr(ModulePort::Direction dir) {
992 switch (dir) {
993 case ModulePort::Direction::Input:
994 return "input";
995 case ModulePort::Direction::Output:
996 return "output";
997 case ModulePort::Direction::InOut:
998 return "inout";
999 }
1000}
1001
1002static ModulePort::Direction strToDir(StringRef str) {
1003 if (str == "input")
1004 return ModulePort::Direction::Input;
1005 if (str == "output")
1006 return ModulePort::Direction::Output;
1007 if (str == "inout")
1008 return ModulePort::Direction::InOut;
1009 llvm::report_fatal_error("invalid direction");
1010}
1011
1012/// Parse a list of field names and types within <>. E.g.:
1013/// <input foo: i7, output bar: i8>
1014static ParseResult parsePorts(AsmParser &p,
1015 SmallVectorImpl<ModulePort> &ports) {
1016 return p.parseCommaSeparatedList(
1017 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
1018 StringRef dir;
1019 std::string name;
1020 Type type;
1021 if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
1022 p.parseColon() || p.parseType(type))
1023 return failure();
1024 ports.push_back(
1025 {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
1026 return success();
1027 });
1028}
1029
1030/// Print out a list of named fields surrounded by <>.
1031static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
1032 p << '<';
1033 llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
1034 p << dirToStr(port.dir) << " ";
1035 p.printKeywordOrString(port.name.getValue());
1036 p << " : " << port.type;
1037 });
1038 p << ">";
1039}
1040
1041Type ModuleType::parse(AsmParser &odsParser) {
1042 llvm::SmallVector<ModulePort, 4> ports;
1043 if (parsePorts(odsParser, ports))
1044 return Type();
1045 return get(odsParser.getContext(), ports);
1046}
1047
1048void ModuleType::print(AsmPrinter &odsPrinter) const {
1049 printPorts(odsPrinter, getPorts());
1050}
1051
1052ModuleType circt::hw::detail::fnToMod(Operation *op,
1053 ArrayRef<Attribute> inputNames,
1054 ArrayRef<Attribute> outputNames) {
1055 return fnToMod(
1056 cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
1057 inputNames, outputNames);
1058}
1059
1060ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
1061 ArrayRef<Attribute> inputNames,
1062 ArrayRef<Attribute> outputNames) {
1063 SmallVector<ModulePort> ports;
1064 if (!inputNames.empty()) {
1065 for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
1066 if (auto iot = dyn_cast<hw::InOutType>(t))
1067 ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1068 ModulePort::Direction::InOut});
1069 else
1070 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1071 } else {
1072 for (auto t : fnty.getInputs())
1073 if (auto iot = dyn_cast<hw::InOutType>(t))
1074 ports.push_back(
1075 {{}, iot.getElementType(), ModulePort::Direction::InOut});
1076 else
1077 ports.push_back({{}, t, ModulePort::Direction::Input});
1078 }
1079 if (!outputNames.empty()) {
1080 for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1081 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1082 } else {
1083 for (auto t : fnty.getResults())
1084 ports.push_back({{}, t, ModulePort::Direction::Output});
1085 }
1086 return ModuleType::get(fnty.getContext(), ports);
1087}
1088
    : ports(inPorts) {
  // Build index maps between absolute port positions and the two separate
  // numberings. Output ports are numbered among outputs. Every other port
  // (input or inout) is numbered among inputs.
  size_t nextInput = 0;
  size_t nextOutput = 0;
  for (auto [idx, p] : llvm::enumerate(ports)) {
    if (p.dir == ModulePort::Direction::Output) {
      outputToAbs.push_back(idx);
      absToOutput.push_back(nextOutput);
      // ~0ULL marks "this absolute port has no input index".
      absToInput.push_back(~0ULL);
      ++nextOutput;
    } else {
      inputToAbs.push_back(idx);
      absToInput.push_back(nextInput);
      // ~0ULL marks "this absolute port has no output index".
      absToOutput.push_back(~0ULL);
      ++nextInput;
    }
  }
}
1107
1108//===----------------------------------------------------------------------===//
1109// BoilerPlate
1110//===----------------------------------------------------------------------===//
1111
/// Register all TableGen-generated HW types with the dialect.
/// Re-including the generated HWTypes.cpp.inc with GET_TYPEDEF_LIST defined
/// splices the list of generated type classes into the addTypes<> template
/// argument list. The #define/#include must stay inside the angle brackets.
void HWDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "circt/Dialect/HW/HWTypes.cpp.inc"
      >();
}
assert(baseType &&"element must be base type")
MlirType uint64_t numElements
Definition CHIRRTL.cpp:30
MlirType elementType
Definition CHIRRTL.cpp:29
static ModulePort::Direction strToDir(StringRef str)
Definition HWTypes.cpp:1002
static void printPorts(AsmPrinter &p, ArrayRef< ModulePort > ports)
Print out a list of named fields surrounded by <>.
Definition HWTypes.cpp:1031
static void printFields(AsmPrinter &p, ArrayRef< FieldInfo > fields)
Print out a list of named fields surrounded by <>.
Definition HWTypes.cpp:263
static StringRef dirToStr(ModulePort::Direction dir)
Definition HWTypes.cpp:991
static ParseResult parseHWArray(AsmParser &parser, Attribute &dim, Type &elementType)
Definition HWTypes.cpp:579
static ParseResult parseHWElementType(AsmParser &parser, Type &elementType)
Parse and print nested HW types nicely.
Definition HWTypes.cpp:150
static ParseResult parsePorts(AsmParser &p, SmallVectorImpl< ModulePort > &ports)
Parse a list of field names and types within <>.
Definition HWTypes.cpp:1014
static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType)
Definition HWTypes.cpp:605
static std::pair< uint64_t, SmallVector< uint64_t > > getFieldIDsStruct(const StructType &st)
Definition HWTypes.cpp:319
static ParseResult parseFields(AsmParser &p, SmallVectorImpl< FieldInfo > &parameters)
Parse a list of unique field names and types within <>.
Definition HWTypes.cpp:231
static Type computeCanonicalType(Type type)
Definition HWTypes.cpp:767
static void printHWElementType(AsmPrinter &printer, Type dim)
Definition HWTypes.cpp:172
@ Input
Definition HW.h:42
@ Output
Definition HW.h:42
static unsigned getFieldID(BundleType type, unsigned index)
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
static InstancePath empty
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
uint64_t getWidth(Type t)
Definition ESIPasses.cpp:32
mlir::Type innerType(mlir::Type type)
Definition ESITypes.cpp:420
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
llvm::hash_code hash_value(const FieldInfo &fi)
Definition HWTypes.cpp:222
bool operator==(const FieldInfo &a, const FieldInfo &b)
Definition HWTypes.cpp:219
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition HWTypes.cpp:1052
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition HWTypes.cpp:60
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition HWTypes.cpp:73
mlir::Type getCanonicalType(mlir::Type type)
Definition HWTypes.cpp:49
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
Struct defining a field. Used in structs.
Definition HWTypes.h:92
mlir::StringAttr name
Definition HWTypes.h:93
SmallVector< ModulePort > ports
The parametric data held by the storage class.
Definition HWTypes.h:70
ModuleTypeStorage(ArrayRef< ModulePort > inPorts)
Definition HWTypes.cpp:1089
SmallVector< size_t > absToInput
Definition HWTypes.h:74
SmallVector< size_t > outputToAbs
Definition HWTypes.h:73
SmallVector< size_t > inputToAbs
Definition HWTypes.h:72
SmallVector< size_t > absToOutput
Definition HWTypes.h:75
Struct defining a field with an offset. Used in unions.
Definition HWTypes.h:98