19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/StorageUniquerSupport.h"
24 #include "mlir/IR/Types.h"
25 #include "llvm/ADT/SmallSet.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
30 using namespace circt;
34 #define GET_TYPEDEF_CLASSES
35 #include "circt/Dialect/HW/HWTypes.cpp.inc"
// Fragment of the canonical-type helper: a TypeAliasType is unwrapped to the
// canonical type it records. NOTE(review): the enclosing function header and
// the remaining lines are not in this listing (embedded numbers jump 44 -> 55).
43 if (
auto typeAlias = type.dyn_cast<TypeAliasType>())
44 canonicalType = typeAlias.getCanonicalType();
// Apparently part of isHWIntegerType (header not visible — TODO confirm).
// An hw::IntType is checked first; otherwise the type must be a builtin
// IntegerType, and a missing or non-signless integer is singled out on
// line 59. The return statements (56-57, 60+) are not in this listing.
55 if (canonicalType.isa<hw::IntType>())
58 auto intType = canonicalType.dyn_cast<IntegerType>();
59 if (!intType || !intType.isSignless())
// Fragment of isHWValueType (header not visible). Builtin integers, hw.int
// and enum types are tested first (presumably accepted; line 75 is missing).
// Array and uarray branches exist, but their bodies (78-79, 81-82) are not
// visible. Struct and union types are value types only if every field type
// is, checked recursively. Type aliases are handled last.
74 if (type.isa<IntegerType, IntType, EnumType>())
77 if (
auto array = type.dyn_cast<ArrayType>())
80 if (
auto array = type.dyn_cast<UnpackedArrayType>())
// Struct: all fields must themselves be HW value types.
83 if (
auto t = type.dyn_cast<StructType>())
84 return llvm::all_of(t.getElements(),
85 [](
auto f) { return isHWValueType(f.type); });
// Union: same rule as struct, applied to each union member.
87 if (
auto t = type.dyn_cast<UnionType>())
88 return llvm::all_of(t.getElements(),
89 [](
auto f) { return isHWValueType(f.type); });
91 if (
auto t = type.dyn_cast<TypeAliasType>())
/// Body of getBitWidth: the hardware bit width of a type, or -1 when it
/// cannot be computed (see the Default case). The function header is not in
/// this listing.
/// NOTE(review): the TypeSwitch result type is size_t although the cases
/// yield int64_t/-1. The -1 sentinel therefore round-trips through an
/// unsigned type; consider TypeSwitch<::mlir::Type, int64_t>.
103 return llvm::TypeSwitch<::mlir::Type, size_t>(type)
105 [](IntegerType t) {
return t.getIntOrFloatBitWidth(); })
// Packed and unpacked arrays: element width times element count. An
// unknown (-1) element width is propagated unchanged.
106 .Case<ArrayType, UnpackedArrayType>([](
auto a) {
107 int64_t elementBitWidth =
getBitWidth(a.getElementType());
108 if (elementBitWidth < 0)
109 return elementBitWidth;
110 int64_t dimBitWidth = a.getNumElements();
// Guard condition for this -1 return (line 111) is not visible; presumably
// it rejects a negative dimension.
112 return static_cast<int64_t
>(-1L);
113 return (int64_t)a.getNumElements() * elementBitWidth;
// Struct: sums field widths (accumulation lines 116-124 are not visible).
115 .Case<StructType>([](StructType s) {
117 for (
auto field : s.getElements()) {
// Union: width is the largest (field width + field offset).
// NOTE(review): a field whose getBitWidth is -1 contributes offset - 1 here,
// with no visible check; confirm that lines 126/130-133 handle unknown widths.
125 .Case<UnionType>([](UnionType u) {
127 for (
auto field : u.getElements()) {
128 int64_t fieldSize =
getBitWidth(field.type) + field.offset;
129 if (fieldSize > maxSize)
134 .Case<EnumType>([](EnumType e) {
return e.getBitWidth(); })
135 .Case<TypeAliasType>(
136 [](TypeAliasType t) {
return getBitWidth(t.getCanonicalType()); })
// Any other type has no known hardware width.
137 .Default([](Type) {
return -1; });
// Fragment of hasHWInOutType: true if the type contains a marker type such
// as InOutType. Visible cases: arrays/uarrays (bodies not visible), structs
// (true if any field contains one, checked recursively), and type aliases.
// NOTE(review): no UnionType case appears in lines 144-155, unlike
// isHWValueType. Confirm union members are intentionally not inspected.
144 if (
auto array = type.dyn_cast<ArrayType>())
147 if (
auto array = type.dyn_cast<UnpackedArrayType>())
150 if (
auto t = type.dyn_cast<StructType>()) {
151 return std::any_of(t.getElements().begin(), t.getElements().end(),
152 [](
const auto &f) { return hasHWInOutType(f.type); });
155 if (
auto t = type.dyn_cast<TypeAliasType>())
// Fragment of parseHWElementType (header not visible). It peeks at the
// unparsed remainder of the symbol spec. If the text starts with one of the
// HW mnemonics listed below, it dispatches to the generated type parser
// (presumably so nested HW types can omit the dialect prefix). Otherwise it
// falls back to the generic p.parseType.
// NOTE(review): the static_cast below assumes `p` is really a DialectAsmParser.
167 auto fullString =
static_cast<DialectAsmParser &
>(p).getFullSymbolSpec();
168 auto *curPtr = p.getCurrentLocation().getPointer();
170 StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
// NOTE(review): "union<" is not in this list, so nested unions presumably
// go through p.parseType — confirm that is intended.
172 if (typeString.startswith(
"array<") || typeString.startswith(
"inout<") ||
173 typeString.startswith(
"uarray<") || typeString.startswith(
"struct<") ||
174 typeString.startswith(
"typealias<") || typeString.startswith(
"int<") ||
175 typeString.startswith(
"enum<")) {
176 llvm::StringRef mnemonic;
177 auto parseResult = generatedTypeParser(p, &mnemonic, result);
// NOTE(review): has_value() only says a mnemonic matched. If the matched
// sub-parser failed, this still returns success() (with `result` likely
// null). Returning *parseResult would propagate the failure.
178 return parseResult.has_value() ? success() : failure();
181 return p.parseType(result);
// Fragment of printHWElementType: try the generated HW type printer first.
// Any type it does not handle is printed with the generic printType. The
// early exit on line 186 is not visible in this listing.
185 if (succeeded(generatedTypePrinter(element, p)))
187 p.printType(element);
// Fragment of the hw.int builder (header not visible — presumably IntType::get).
// The width attribute must carry a 32-bit IntegerType; this is only asserted.
// A constant IntegerAttr width appears to be folded into a builtin integer
// of that width (line 202 is not visible — TODO confirm).
196 auto widthWidth =
width.getType().dyn_cast<IntegerType>();
197 assert(widthWidth && widthWidth.getWidth() == 32 &&
198 "!hw.int width must be 32-bits");
201 if (
auto cstWidth =
width.dyn_cast<IntegerAttr>())
203 cstWidth.getValue().getZExtValue());
/// Parse `<` width-attribute `>`, where the width is parsed as an i32-typed
/// attribute. The failure return and final construction (214-216) are not
/// visible.
208 Type IntType::parse(AsmParser &p) {
210 auto int32Type = p.getBuilder().getIntegerType(32);
212 mlir::TypedAttr
width;
213 if (p.parseLess() || p.parseAttribute(
width, int32Type) || p.parseGreater())
/// Print the width attribute without its type. The surrounding delimiters
/// (lines 219, 221) are not visible.
218 void IntType::print(AsmPrinter &p)
const {
220 p.printAttributeWithoutType(
getWidth());
224 //===----------------------------------------------------------------------===//
226 //===----------------------------------------------------------------------===//
231 bool operator==(const FieldInfo &a, const FieldInfo &b) {
232 return a.name == b.name && a.type == b.type;
234 llvm::hash_code hash_value(const FieldInfo &fi) {
235 return llvm::hash_combine(fi.name, fi.type);
237 } // namespace detail
/// Parse a LessGreater-delimited, comma-separated list of struct fields into
/// `parameters`. Each field is a keyword-or-string name followed by a colon
/// (the type parse on line 254 is not visible). A duplicate name is reported
/// at the field's location but parsing continues, so every duplicate is
/// diagnosed. `hasDuplicateName` is then checked after the list is parsed.
/// NOTE(review): `¶meters` on the next line is a mis-encoding of
/// `&parameters` (HTML `&para` entity corruption); restore the `&`.
243 static ParseResult parseFields(AsmParser &p,
244 SmallVectorImpl<FieldInfo> ¶meters) {
245 llvm::StringSet<> nameSet;
246 bool hasDuplicateName = false;
247 auto parseResult = p.parseCommaSeparatedList(
248 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
252 auto fieldLoc = p.getCurrentLocation();
253 if (p.parseKeywordOrString(&name) || p.parseColon() ||
257 if (!nameSet.insert(name).second) {
258 p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
259 // Continue parsing to print all duplicates, but make sure to error
261 hasDuplicateName = true;
264 parameters.push_back(
265 FieldInfo{StringAttr::get(p.getContext(), name), type});
// Presumably fails the whole parse (line 270 is not visible).
269 if (hasDuplicateName)
/// Print struct fields comma-separated as `name: type`. Names go through
/// printKeywordOrString, so names that need quoting round-trip with
/// parseFields. The delimiter lines (276, 280-281) are not visible.
275 static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
277 llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
278 p.printKeywordOrString(field.name.getValue());
279 p << ": " << field.type;
/// Parse a struct body with parseFields and build the uniqued StructType.
/// The failure return (line 287) is not visible.
284 Type StructType::parse(AsmParser &p) {
285 llvm::SmallVector<FieldInfo, 4> parameters;
286 if (parseFields(p, parameters))
288 return get(p.getContext(), parameters);
/// Verify that field names are unique. Each duplicate emits a diagnostic.
/// `result` is presumably set to failure on line 298 (not visible) so that
/// all duplicates are reported.
291 LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
292 ArrayRef<StructType::FieldInfo> elements) {
293 llvm::SmallDenseSet<StringAttr> fieldNameSet;
294 LogicalResult result = success();
295 fieldNameSet.reserve(elements.size());
296 for (const auto &elt : elements)
297 if (!fieldNameSet.insert(elt.name).second) {
299 emitError() << "duplicate field name '" << elt.name.getValue()
300 << "' in hw.struct type";
305 void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
/// Linear search for the field named `fieldName`. The matching return and the
/// not-found return (lines 310-311) are not visible.
307 Type StructType::getFieldType(mlir::StringRef fieldName) {
308 for (const auto &field : getElements())
309 if (field.name == fieldName)
/// Index of the field named `fieldName` (StringRef overload), by linear scan.
/// NOTE(review): UnionType's StringRef overload simply delegates to its
/// StringAttr overload; this one could do the same instead of duplicating
/// the loop.
314 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
315 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
316 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
317 if (elems[idx].name == fieldName)
/// Index of the field named `fieldName` (StringAttr overload), by linear scan.
322 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
323 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
324 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
325 if (elems[idx].name == fieldName)
/// Compute the base field ID of every element of `st`. Returns the final
/// running ID counter together with the per-element base IDs. Each element
/// advances the counter by its own getMaxFieldID. Line 338 (not visible)
/// presumably also advances it by one for the element itself.
330 static std::pair<uint64_t, SmallVector<uint64_t>>
331 getFieldIDsStruct(const StructType &st) {
332 uint64_t fieldID = 0;
333 auto elements = st.getElements();
334 SmallVector<uint64_t> fieldIDs;
335 fieldIDs.reserve(elements.size());
336 for (auto &element : elements) {
337 auto type = element.type;
339 fieldIDs.push_back(fieldID);
340 // Increment the field ID for the next field by the number of subfields.
341 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
343 return {fieldID, fieldIDs};
346 void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
347 for (const auto &field : getElements())
348 types.push_back(field.type);
/// Total number of field IDs nested under this struct: one per field plus
/// each field's own nested IDs. The return (line 355) is not visible.
351 uint64_t StructType::getMaxFieldID() const {
352 uint64_t fieldID = 0;
353 for (const auto &field : getElements())
354 fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
/// Map `fieldID` to the field that owns it and the ID relative to that field.
/// The owner is the last element whose base ID is <= fieldID.
/// NOTE(review): if fieldID is below the first base ID, std::prev steps
/// before begin(). Lines 360-361 (not visible) presumably guard fieldID == 0.
/// Each call also rebuilds the full base-ID vector via getFieldIDsStruct.
358 std::pair<Type, uint64_t>
359 StructType::getSubTypeByFieldID(uint64_t fieldID) const {
362 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
363 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
364 auto subfieldIndex = std::distance(fieldIDs.begin(), it);
365 auto subfieldType = getElements()[subfieldIndex].type;
366 auto subfieldID = fieldID - fieldIDs[subfieldIndex];
367 return {subfieldType, subfieldID};
/// Project `fieldID` into element `index`. Returns the ID relative to that
/// element, plus whether `fieldID` falls inside the element's ID range
/// [childRoot, rangeEnd]. For the last element the range ends at the total
/// ID count.
370 std::pair<uint64_t, bool>
371 StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
372 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
373 auto childRoot = fieldIDs[index];
375 index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
376 return std::make_pair(fieldID - childRoot,
377 fieldID >= childRoot && fieldID <= rangeEnd);
380 uint64_t StructType::getFieldID(uint64_t index) const {
381 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
382 return fieldIDs[index];
385 uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
386 assert(!getElements().empty() && "Bundle must have >0 fields");
387 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
388 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
389 return std::distance(fieldIDs.begin(), it);
392 std::pair<uint64_t, uint64_t>
393 StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
394 auto index = getIndexForFieldID(fieldID);
395 auto elementFieldID = getFieldID(index);
396 return {index, fieldID - elementFieldID};
399 //===----------------------------------------------------------------------===//
401 //===----------------------------------------------------------------------===//
406 bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) {
407 return a.name == b.name && a.type == b.type && a.offset == b.offset;
410 llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
411 return llvm::hash_combine(fi.name, fi.type, fi.offset);
413 } // namespace detail
/// Parse a LessGreater-delimited, comma-separated list of union members. Each
/// member is `name : type`, optionally followed by `offset <integer>`.
/// Duplicate names are diagnosed without stopping the parse.
/// NOTE(review): names are parsed with parseKeyword, while struct fields
/// (parseFields) accept parseKeywordOrString. Union member names that need
/// quoting therefore cannot be written.
417 Type UnionType::parse(AsmParser &p) {
418 llvm::SmallVector<FieldInfo, 4> parameters;
419 llvm::StringSet<> nameSet;
420 bool hasDuplicateName = false;
421 if (p.parseCommaSeparatedList(
422 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
426 auto fieldLoc = p.getCurrentLocation();
427 if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
430 if (!nameSet.insert(name).second) {
431 p.emitError(fieldLoc, "duplicate field name \'" + name +
432 "\' in hw.union type");
433 // Continue parsing to print all duplicates, but make sure to
435 hasDuplicateName = true;
// The offset clause is optional; the declaration/default of `offset` is on
// lines not visible here.
439 if (succeeded(p.parseOptionalKeyword("offset")))
440 if (p.parseInteger(offset))
442 parameters.push_back(UnionType::FieldInfo{
443 StringAttr::get(p.getContext(), name), type, offset});
448 if (hasDuplicateName)
451 return get(p.getContext(), parameters);
/// Print members comma-separated as `name: type`, with ` offset N` appended.
/// The condition on line 459 is not visible (presumably non-zero offsets only).
/// NOTE(review): the name is streamed raw, not via printKeywordOrString as in
/// printFields.
454 void UnionType::print(AsmPrinter &odsPrinter) const {
456 llvm::interleaveComma(
457 getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
458 odsPrinter << field.name.getValue() << ": " << field.type;
460 odsPrinter << " offset " << field.offset;
/// Verify that union member names are unique; each duplicate is diagnosed.
465 LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
466 ArrayRef<UnionType::FieldInfo> elements) {
467 llvm::SmallDenseSet<StringAttr> fieldNameSet;
468 LogicalResult result = success();
469 fieldNameSet.reserve(elements.size());
470 for (const auto &elt : elements)
471 if (!fieldNameSet.insert(elt.name).second) {
473 emitError() << "duplicate field name '" << elt.name.getValue()
474 << "' in hw.union type";
/// Index of the member named `fieldName`, by linear scan (StringAttr overload).
479 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
480 ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
481 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
482 if (elems[idx].name == fieldName)
487 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
488 return getFieldIndex(StringAttr::get(getContext(), fieldName));
/// Return the member info for `fieldName` when it exists. The not-found path
/// (lines 494-495) is not visible in this listing.
491 UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
492 if (auto fieldIndex = getFieldIndex(fieldName))
493 return getElements()[*fieldIndex];
497 Type UnionType::getFieldType(mlir::StringRef fieldName) {
498 return getFieldInfo(fieldName).type;
501 //===----------------------------------------------------------------------===//
503 //===----------------------------------------------------------------------===//
/// Parse a LessGreater-delimited, comma-separated list of bare-keyword
/// enumerators and build the enum from them as a StringAttr ArrayAttr.
/// NOTE(review): no duplicate-enumerator check is visible here.
505 Type EnumType::parse(AsmParser &p) {
506 llvm::SmallVector<Attribute> fields;
508 if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
510 if (p.parseKeyword(&name))
512 fields.push_back(StringAttr::get(p.getContext(), name));
517 return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
/// Print enumerators comma-separated. Each field attribute is cast to
/// StringAttr, which asserts if it is anything else.
520 void EnumType::print(AsmPrinter &p) const {
522 llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
523 p << enumerator.cast<StringAttr>().getValue();
528 bool EnumType::contains(mlir::StringRef field) {
529 return indexOf(field).has_value();
/// Position of enumerator `field`, found by linear scan. The return
/// statements (lines 535-537) are not visible.
532 std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
533 for (auto it : llvm::enumerate(getFields()))
534 if (it.value().cast<StringAttr>().getValue() == field)
/// Bits needed to encode the enumerators: ceil(log2(count)). The guard on
/// line 541 and the fallback on line 543 are not visible (presumably they
/// cover counts of 0 or 1).
539 size_t EnumType::getBitWidth() {
540 auto w = getFields().size();
542 return llvm::Log2_64_Ceil(getFields().size());
546 //===----------------------------------------------------------------------===//
548 //===----------------------------------------------------------------------===//
/// Parse an array dimension followed by `x` and an element type. The
/// dimension may be an integer literal (stored as an i64 attribute) or an i64
/// attribute. Only IntegerAttr, ParamExprAttr and ParamDeclRefAttr dimensions
/// are accepted.
/// NOTE(review): `res.has_value()` only says an integer token was present.
/// A present-but-failed integer parse (e.g. overflow) is not visibly
/// rejected before `dimLiteral` is used.
550 static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner) {
555 auto int64Type = p.getBuilder().getIntegerType(64);
557 if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value())
558 dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
559 else if (!p.parseOptionalAttribute(dim, int64Type).has_value())
562 if (!dim.isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>()) {
563 p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
567 if (p.parseXInDimensionList() || parseHWElementType(inner, p) ||
/// Parse via parseArray, run the type verifier at the current location, then
/// build the uniqued ArrayType.
574 Type ArrayType::parse(AsmParser &p) {
578 if (failed(parseArray(p, dim, inner)))
581 auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
582 if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim)))
585 return get(inner.getContext(), inner, dim);
/// Print the size attribute (without its type), then the element type.
588 void ArrayType::print(AsmPrinter &p) const {
590 p.printAttributeWithoutType(getSizeAttr());
592 printHWElementType(getElementType(), p);
/// Element count when the size is a constant IntegerAttr. The non-constant
/// path (lines 599-600) is not visible.
596 size_t ArrayType::getNumElements() const {
597 if (auto intAttr = getSizeAttr().dyn_cast<IntegerAttr>())
598 return intAttr.getInt();
/// Packed arrays may not contain InOut marker types.
602 LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
603 Type innerType, Attribute size) {
604 if (hasHWInOutType(innerType))
605 return emitError() << "hw.array cannot contain InOut types";
609 uint64_t ArrayType::getMaxFieldID() const {
610 return getNumElements() *
611 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
/// Every array element has the element type. The subfield ID comes from
/// getIndexAndSubfieldID. The fieldID == 0 handling (616-617) is not visible.
614 std::pair<Type, uint64_t>
615 ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
618 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
/// Project `fieldID` into element `index`: returns the relative ID and
/// whether it lies in [childRoot, rangeEnd].
/// NOTE(review): the end-of-range guard is `index >= getNumElements()`, while
/// StructType uses `index + 1 >= size`. For the last element this relies on
/// getFieldID(N) - 1 == getMaxFieldID() — TODO confirm.
621 std::pair<uint64_t, bool>
622 ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
623 auto childRoot = getFieldID(index);
625 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
626 return std::make_pair(fieldID - childRoot,
627 fieldID >= childRoot && fieldID <= rangeEnd);
630 uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
631 assert(fieldID && "fieldID must be at least 1");
632 // Divide the field ID by the number of fieldID's per element.
636 std::pair<uint64_t, uint64_t>
640 return {index, fieldID - elementFieldID};
/// Parse an unpacked array (the parsing lines 652-657 are not visible), verify
/// it at the current location, and build the uniqued type.
651 Type UnpackedArrayType::parse(AsmParser &p) {
658 auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
659 if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim)))
662 return get(inner.getContext(), inner, dim);
665 void UnpackedArrayType::print(AsmPrinter &p)
const {
667 p.printAttributeWithoutType(getSizeAttr());
/// Unpacked array elements must be HW value types.
674 UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
675 Type innerType, Attribute size) {
676 if (!isHWValueType(innerType))
677 return emitError() << "invalid element for uarray type";
/// Element count when the size is a constant IntegerAttr. The other path is
/// not visible.
681 size_t UnpackedArrayType::getNumElements() const {
682 if (auto intAttr = getSizeAttr().dyn_cast<IntegerAttr>())
683 return intAttr.getInt();
687 uint64_t UnpackedArrayType::getMaxFieldID() const {
688 return getNumElements() *
689 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
/// Same field-ID scheme as ArrayType; this is duplicated logic that could be
/// shared by a helper.
692 std::pair<Type, uint64_t>
693 UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
696 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
699 std::pair<uint64_t, bool>
700 UnpackedArrayType::projectToChildFieldID(uint64_t fieldID,
701 uint64_t index) const {
702 auto childRoot = getFieldID(index);
704 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
705 return std::make_pair(fieldID - childRoot,
706 fieldID >= childRoot && fieldID <= rangeEnd);
709 uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
710 assert(fieldID && "fieldID must be at least 1");
711 // Divide the field ID by the number of fieldID's per element.
715 std::pair<uint64_t, uint64_t>
719 return {index, fieldID - elementFieldID};
/// Parse the wrapped element type (lines 731-734 not visible), verify it at
/// the current location, and build the uniqued InOutType.
730 Type InOutType::parse(AsmParser &p) {
735 auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
736 if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner)))
739 return get(p.getContext(), inner);
742 void InOutType::print(AsmPrinter &p)
const {
/// An inout may only wrap an HW value type.
748 LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
750 if (!isHWValueType(innerType))
751 return emitError() << "invalid element for hw.inout type " << innerType;
755 //===----------------------------------------------------------------------===//
757 //===----------------------------------------------------------------------===//
/// Recursively strip type aliases. Aliases resolve to their canonical type.
/// Arrays, unpacked arrays and structs are rebuilt with canonicalized
/// element/field types; every other type is returned unchanged.
/// NOTE(review): there is no UnionType case, so aliases nested inside union
/// members are left in place — confirm this is intended.
759 static Type computeCanonicalType(Type type) {
760 return llvm::TypeSwitch<Type, Type>(type)
761 .Case([](TypeAliasType t) {
762 return computeCanonicalType(t.getCanonicalType());
764 .Case([](ArrayType t) {
765 return ArrayType::get(computeCanonicalType(t.getElementType()),
768 .Case([](UnpackedArrayType t) {
769 return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
772 .Case([](StructType t) {
773 SmallVector<StructType::FieldInfo> fieldInfo;
774 for (auto field : t.getElements())
775 fieldInfo.push_back(StructType::FieldInfo{
776 field.name, computeCanonicalType(field.type)});
777 return StructType::get(t.getContext(), fieldInfo);
779 .Default([](Type t) { return t; });
782 TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
783 return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
/// Parse `<` symbol-ref `,` inner-type `>` and build the alias through the
/// two-argument get, which computes the canonical type. The declarations of
/// `ref`/`type` and the failure return are not visible.
786 Type TypeAliasType::parse(AsmParser &p) {
789 if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
790 p.parseType(type) || p.parseGreater())
793 return get(ref, type);
796 void TypeAliasType::print(AsmPrinter &p) const {
797 p << "<" << getRef() << ", " << getInnerType() << ">";
/// Resolve the alias to its TypedeclOp. The root reference is looked up in
/// the symbol cache as a TypeScopeOp, then the leaf symbol is looked up in
/// that scope. The null-scope handling (lines 806-808) is not visible.
802 TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
803 SymbolRefAttr ref = getRef();
804 auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
805 cache.getDefinition(ref.getRootReference()));
809 return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
/// Reject any port whose type contains an InOut marker type. InOut ports
/// store their element type with an InOut direction instead (see fnToMod).
816 LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
817 ArrayRef<ModulePort> ports) {
818 if (llvm::any_of(ports, [](const ModulePort &port) {
819 return hasHWInOutType(port.type);
821 return emitError() << "Ports cannot be inout types";
/// Map the idx-th input (any non-output port) to its absolute port index.
/// NOTE(review): an out-of-range idx only trips assert(0). Release-build
/// behaviour depends on line 834, which is not visible.
825 size_t ModuleType::getPortIdForInputId(size_t idx) {
826 for (auto [i, p] : llvm::enumerate(getPorts())) {
827 if (p.dir != ModulePort::Direction::Output) {
833 assert(0 && "Out of bounds input port id");
/// Map the idx-th output port to its absolute port index.
837 size_t ModuleType::getPortIdForOutputId(size_t idx) {
838 for (auto [i, p] : llvm::enumerate(getPorts())) {
839 if (p.dir == ModulePort::Direction::Output) {
845 assert(0 && "Out of bounds output port id");
/// Inverse mapping: count the non-output ports before absolute index `idx`.
/// `idx` is not bounds-checked; it is only asserted to be a non-output port.
849 size_t ModuleType::getInputIdForPortId(size_t idx) {
850 auto ports = getPorts();
851 assert(ports[idx].dir != ModulePort::Direction::Output);
853 for (size_t i = 0; i < idx; ++i)
854 if (ports[i].dir != ModulePort::Direction::Output)
/// Inverse mapping: count the output ports before absolute index `idx`.
859 size_t ModuleType::getOutputIdForPortId(size_t idx) {
860 auto ports = getPorts();
861 assert(ports[idx].dir == ModulePort::Direction::Output);
863 for (size_t i = 0; i < idx; ++i)
864 if (ports[i].dir == ModulePort::Direction::Output)
869 size_t ModuleType::getNumInputs() {
870 return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
871 return p.dir != ModulePort::Direction::Output;
875 size_t ModuleType::getNumOutputs() {
876 return std::count_if(getPorts().begin(), getPorts().end(), [](auto &p) {
877 return p.dir == ModulePort::Direction::Output;
881 size_t ModuleType::getNumPorts() { return getPorts().size(); }
/// Types of all input-like ports, in port order. Inout ports are wrapped in
/// hw::InOutType. The loop tail and return are not visible.
883 SmallVector<Type> ModuleType::getInputTypes() {
884 SmallVector<Type> retval;
885 for (auto &p : getPorts()) {
886 if (p.dir == ModulePort::Direction::Input)
887 retval.push_back(p.type);
888 else if (p.dir == ModulePort::Direction::InOut) {
889 retval.push_back(hw::InOutType::get(p.type));
/// Types of the output ports, in port order.
895 SmallVector<Type> ModuleType::getOutputTypes() {
896 SmallVector<Type> retval;
897 for (auto &p : getPorts())
898 if (p.dir == ModulePort::Direction::Output)
899 retval.push_back(p.type);
/// Raw stored type of every port. Unlike getInputTypes, inout ports are NOT
/// wrapped in InOutType here.
903 SmallVector<Type> ModuleType::getPortTypes() {
904 SmallVector<Type> retval;
905 for (auto &p : getPorts())
906 retval.push_back(p.type);
910 Type ModuleType::getInputType(size_t idx) {
911 const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
912 if (portInfo.dir != ModulePort::InOut)
913 return portInfo.type;
914 return InOutType::get(portInfo.type);
917 Type ModuleType::getOutputType(size_t idx) {
918 return getPorts()[getPortIdForOutputId(idx)].type;
/// Names of the input-like (non-output) ports as StringAttrs, in port order.
/// The return lines are not visible for these four accessors.
921 SmallVector<StringAttr> ModuleType::getInputNamesStr() {
922 SmallVector<StringAttr> retval;
923 for (auto &p : getPorts())
924 if (p.dir != ModulePort::Direction::Output)
925 retval.push_back(p.name);
/// Names of the output ports as StringAttrs, in port order.
929 SmallVector<StringAttr> ModuleType::getOutputNamesStr() {
930 SmallVector<StringAttr> retval;
931 for (auto &p : getPorts())
932 if (p.dir == ModulePort::Direction::Output)
933 retval.push_back(p.name);
/// Same as getInputNamesStr, typed as generic Attributes.
937 SmallVector<Attribute> ModuleType::getInputNames() {
938 SmallVector<Attribute> retval;
939 for (auto &p : getPorts())
940 if (p.dir != ModulePort::Direction::Output)
941 retval.push_back(p.name);
/// Same as getOutputNamesStr, typed as generic Attributes.
945 SmallVector<Attribute> ModuleType::getOutputNames() {
946 SmallVector<Attribute> retval;
947 for (auto &p : getPorts())
948 if (p.dir == ModulePort::Direction::Output)
949 retval.push_back(p.name);
953 StringAttr ModuleType::getPortNameAttr(size_t idx) {
954 return getPorts()[idx].name;
/// Name of the port at absolute index `idx`. Ports can be unnamed (fnToMod
/// creates them with `{}` names). Line 959, not visible, presumably returns
/// an empty name in that case.
957 StringRef ModuleType::getPortName(size_t idx) {
958 auto sa = getPortNameAttr(idx);
960 return sa.getValue();
964 StringAttr ModuleType::getInputNameAttr(size_t idx) {
965 return getPorts()[getPortIdForInputId(idx)].name;
/// Name of the idx-th input-like port. The null-name guard on line 970 is
/// not visible.
968 StringRef ModuleType::getInputName(size_t idx) {
969 auto sa = getInputNameAttr(idx);
971 return sa.getValue();
975 StringAttr ModuleType::getOutputNameAttr(size_t idx) {
976 return getPorts()[getPortIdForOutputId(idx)].name;
/// Name of the idx-th output port. The null-name guard on line 981 is not
/// visible.
979 StringRef ModuleType::getOutputName(size_t idx) {
980 auto sa = getOutputNameAttr(idx);
982 return sa.getValue();
986 bool ModuleType::isOutput(size_t idx) {
987 auto &p = getPorts()[idx];
988 return p.dir == ModulePort::Direction::Output;
/// Convert to a FunctionType. Input ports become inputs, inout ports become
/// InOutType inputs, and the remaining ports (line 998, presumably `else`)
/// become results.
/// NOTE(review): `auto p` copies each ModulePort; `const auto &p` would avoid
/// that.
991 FunctionType ModuleType::getFuncType() {
992 SmallVector<Type> inputs, outputs;
993 for (auto p : getPorts())
994 if (p.dir == ModulePort::Input)
995 inputs.push_back(p.type);
996 else if (p.dir == ModulePort::InOut)
997 inputs.push_back(InOutType::get(p.type));
999 outputs.push_back(p.type);
1000 return FunctionType::get(getContext(), inputs, outputs);
/// Evaluate every port type against `parameters`. Ports are copied and their
/// type is replaced by the resolved type. The failure return (line 1011) is
/// not visible.
1003 FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
1006 SmallVector<ModulePort, 8> resolvedPorts;
1007 for (ModulePort port : getPorts()) {
1008 FailureOr<Type> resolvedType =
1009 evaluateParametricType(loc, parameters, port.type, emitErrors);
1010 if (failed(resolvedType))
1012 port.type = *resolvedType;
1013 resolvedPorts.push_back(port);
1015 return ModuleType::get(getContext(), resolvedPorts);
/// Keyword used for each port direction in the textual form. The returned
/// strings (lines 1021/1023/1025) are not visible.
1018 static StringRef dirToStr(ModulePort::Direction dir) {
1020 case ModulePort::Direction::Input:
1022 case ModulePort::Direction::Output:
1024 case ModulePort::Direction::InOut:
/// Inverse of dirToStr; aborts the process on an unknown keyword.
/// NOTE(review): parsePorts passes a user-supplied keyword here unchecked, so
/// malformed IR hits report_fatal_error instead of a parse diagnostic.
1029 static ModulePort::Direction strToDir(StringRef str) {
1031 return ModulePort::Direction::Input;
1032 if (str == "output")
1033 return ModulePort::Direction::Output;
1035 return ModulePort::Direction::InOut;
1036 llvm::report_fatal_error("invalid direction");
/// Parse a LessGreater-delimited list of `direction name : type` ports. The
/// name may be a keyword or a string.
1041 static ParseResult parsePorts(AsmParser &p,
1042 SmallVectorImpl<ModulePort> &ports) {
1043 return p.parseCommaSeparatedList(
1044 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
1048 if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
1049 p.parseColon() || p.parseType(type))
1052 {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
/// Print ports comma-separated as `direction name : type`.
1058 static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
1060 llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
1061 p << dirToStr(port.dir) << " ";
1062 p.printKeywordOrString(port.name.getValue());
1063 p << " : " << port.type;
/// Parse the port list and build the uniqued ModuleType. The failure return
/// (line 1071) is not visible.
1068 Type ModuleType::parse(AsmParser &odsParser) {
1069 llvm::SmallVector<ModulePort, 4> ports;
1070 if (parsePorts(odsParser, ports))
1072 return get(odsParser.getContext(), ports);
1075 void ModuleType::print(AsmPrinter &odsPrinter) const {
1076 printPorts(odsPrinter, getPorts());
1082 static bool operator==(const ModulePort &a, const ModulePort &b) {
1083 return a.dir == b.dir && a.name == b.name && a.type == b.type;
1085 static llvm::hash_code hash_value(const ModulePort &port) {
1086 return llvm::hash_combine(port.dir, port.name, port.type);
1089 } // namespace circt
/// Build a ModuleType from a FunctionOpInterface operation's function type.
/// This forwards to the FunctionType overload (the call on line 1094 is not
/// visible).
1091 ModuleType circt::hw::detail::fnToMod(Operation *op,
1092 ArrayRef<Attribute> inputNames,
1093 ArrayRef<Attribute> outputNames) {
1095 cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
1096 inputNames, outputNames);
/// Build a ModuleType from a FunctionType. InOutType inputs become InOut
/// ports that store the element type; other inputs become Input ports and
/// results become Output ports. When name lists are given they must match
/// the input/result counts (zip_equal) and be StringAttrs (cast). Otherwise
/// ports are created with null names.
1099 ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
1100 ArrayRef<Attribute> inputNames,
1101 ArrayRef<Attribute> outputNames) {
1102 SmallVector<ModulePort> ports;
1103 if (!inputNames.empty()) {
1104 for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
1105 if (auto iot = dyn_cast<hw::InOutType>(t))
1106 ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1107 ModulePort::Direction::InOut});
1109 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1111 for (auto t : fnty.getInputs())
1112 if (auto iot = dyn_cast<hw::InOutType>(t))
1114 {{}, iot.getElementType(), ModulePort::Direction::InOut});
1116 ports.push_back({{}, t, ModulePort::Direction::Input});
1118 if (!outputNames.empty()) {
1119 for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1120 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1122 for (auto t : fnty.getResults())
1123 ports.push_back({{}, t, ModulePort::Direction::Output});
1125 return ModuleType::get(fnty.getContext(), ports);
/// Register every HW type from the generated GET_TYPEDEF_LIST. The wrapping
/// addTypes<...>() lines (1133, 1136) are not visible in this listing.
1132 void HWDialect::registerTypes() {
1134 #define GET_TYPEDEF_LIST
1135 #include "circt/Dialect/HW/HWTypes.cpp.inc"
assert(baseType &&"element must be base type")
static void printHWElementType(Type element, AsmPrinter &p)
static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner)
static ParseResult parseHWElementType(Type &result, AsmParser &p)
Parse and print nested HW types nicely.
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
uint64_t getMaxFieldID(Type)
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
mlir::Type getCanonicalType(mlir::Type type)
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
circt::hw::InOutType InOutType
This file defines an intermediate representation for circuits acting as an abstraction for constraint...