19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/StorageUniquerSupport.h"
24 #include "mlir/IR/Types.h"
25 #include "llvm/ADT/SmallSet.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
30 using namespace circt;
34 #define GET_TYPEDEF_CLASSES
35 #include "circt/Dialect/HW/HWTypes.cpp.inc"
43 if (
auto typeAlias = dyn_cast<TypeAliasType>(type))
44 canonicalType = typeAlias.getCanonicalType();
55 if (isa<hw::IntType>(canonicalType))
58 auto intType = dyn_cast<IntegerType>(canonicalType);
59 if (!intType || !intType.isSignless())
74 if (isa<IntegerType, IntType, EnumType>(type))
77 if (
auto array = dyn_cast<ArrayType>(type))
80 if (
auto array = dyn_cast<UnpackedArrayType>(type))
83 if (
auto t = dyn_cast<StructType>(type))
84 return llvm::all_of(t.getElements(),
85 [](
auto f) { return isHWValueType(f.type); });
87 if (
auto t = dyn_cast<UnionType>(type))
88 return llvm::all_of(t.getElements(),
89 [](
auto f) { return isHWValueType(f.type); });
91 if (
auto t = dyn_cast<TypeAliasType>(type))
103 return llvm::TypeSwitch<::mlir::Type, size_t>(type)
105 [](IntegerType t) {
return t.getIntOrFloatBitWidth(); })
106 .Case<ArrayType, UnpackedArrayType>([](
auto a) {
107 int64_t elementBitWidth =
getBitWidth(a.getElementType());
108 if (elementBitWidth < 0)
109 return elementBitWidth;
110 int64_t dimBitWidth = a.getNumElements();
112 return static_cast<int64_t
>(-1L);
113 return (int64_t)a.getNumElements() * elementBitWidth;
115 .Case<StructType>([](StructType s) {
117 for (
auto field : s.getElements()) {
125 .Case<UnionType>([](UnionType u) {
127 for (
auto field : u.getElements()) {
128 int64_t fieldSize =
getBitWidth(field.type) + field.offset;
129 if (fieldSize > maxSize)
134 .Case<EnumType>([](EnumType e) {
return e.getBitWidth(); })
135 .Case<TypeAliasType>(
136 [](TypeAliasType t) {
return getBitWidth(t.getCanonicalType()); })
137 .Default([](Type) {
return -1; });
144 if (
auto array = dyn_cast<ArrayType>(type))
147 if (
auto array = dyn_cast<UnpackedArrayType>(type))
150 if (
auto t = dyn_cast<StructType>(type)) {
151 return std::any_of(t.getElements().begin(), t.getElements().end(),
152 [](
const auto &f) { return hasHWInOutType(f.type); });
155 if (
auto t = dyn_cast<TypeAliasType>(type))
158 return isa<InOutType>(type);
167 auto fullString =
static_cast<DialectAsmParser &
>(p).getFullSymbolSpec();
168 auto *curPtr = p.getCurrentLocation().getPointer();
170 StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
172 if (typeString.starts_with(
"array<") || typeString.starts_with(
"inout<") ||
173 typeString.starts_with(
"uarray<") || typeString.starts_with(
"struct<") ||
174 typeString.starts_with(
"typealias<") || typeString.starts_with(
"int<") ||
175 typeString.starts_with(
"enum<")) {
176 llvm::StringRef mnemonic;
177 auto parseResult = generatedTypeParser(p, &mnemonic, result);
178 return parseResult.has_value() ? success() : failure();
181 return p.parseType(result);
185 if (succeeded(generatedTypePrinter(element, p)))
187 p.printType(element);
196 auto widthWidth = llvm::dyn_cast<IntegerType>(
width.getType());
197 assert(widthWidth && widthWidth.getWidth() == 32 &&
198 "!hw.int width must be 32-bits");
201 if (
auto cstWidth = llvm::dyn_cast<IntegerAttr>(
width))
203 cstWidth.getValue().getZExtValue());
208 Type IntType::parse(AsmParser &p) {
210 auto int32Type = p.getBuilder().getIntegerType(32);
212 mlir::TypedAttr
width;
213 if (p.parseLess() || p.parseAttribute(
width, int32Type) || p.parseGreater())
218 void IntType::print(AsmPrinter &p)
const {
220 p.printAttributeWithoutType(
getWidth());
224 //===----------------------------------------------------------------------===//
226 //===----------------------------------------------------------------------===//
231 bool operator==(const FieldInfo &a, const FieldInfo &b) {
232 return a.name == b.name && a.type == b.type;
234 llvm::hash_code hash_value(const FieldInfo &fi) {
235 return llvm::hash_combine(fi.name, fi.type);
237 } // namespace detail
243 static ParseResult parseFields(AsmParser &p,
244 SmallVectorImpl<FieldInfo> ¶meters) {
245 llvm::StringSet<> nameSet;
246 bool hasDuplicateName = false;
247 auto parseResult = p.parseCommaSeparatedList(
248 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
252 auto fieldLoc = p.getCurrentLocation();
253 if (p.parseKeywordOrString(&name) || p.parseColon() ||
257 if (!nameSet.insert(name).second) {
258 p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
259 // Continue parsing to print all duplicates, but make sure to error
261 hasDuplicateName = true;
264 parameters.push_back(
265 FieldInfo{StringAttr::get(p.getContext(), name), type});
269 if (hasDuplicateName)
275 static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
277 llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
278 p.printKeywordOrString(field.name.getValue());
279 p << ": " << field.type;
284 Type StructType::parse(AsmParser &p) {
285 llvm::SmallVector<FieldInfo, 4> parameters;
286 if (parseFields(p, parameters))
288 return get(p.getContext(), parameters);
291 LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
292 ArrayRef<StructType::FieldInfo> elements) {
293 llvm::SmallDenseSet<StringAttr> fieldNameSet;
294 LogicalResult result = success();
295 fieldNameSet.reserve(elements.size());
296 for (const auto &elt : elements)
297 if (!fieldNameSet.insert(elt.name).second) {
299 emitError() << "duplicate field name '" << elt.name.getValue()
300 << "' in hw.struct type";
305 void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
307 Type StructType::getFieldType(mlir::StringRef fieldName) {
308 for (const auto &field : getElements())
309 if (field.name == fieldName)
314 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
315 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
316 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
317 if (elems[idx].name == fieldName)
322 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
323 ArrayRef<hw::StructType::FieldInfo> elems = getElements();
324 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
325 if (elems[idx].name == fieldName)
330 static std::pair<uint64_t, SmallVector<uint64_t>>
331 getFieldIDsStruct(const StructType &st) {
332 uint64_t fieldID = 0;
333 auto elements = st.getElements();
334 SmallVector<uint64_t> fieldIDs;
335 fieldIDs.reserve(elements.size());
336 for (auto &element : elements) {
337 auto type = element.type;
339 fieldIDs.push_back(fieldID);
340 // Increment the field ID for the next field by the number of subfields.
341 fieldID += hw::FieldIdImpl::getMaxFieldID(type);
343 return {fieldID, fieldIDs};
346 void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
347 for (const auto &field : getElements())
348 types.push_back(field.type);
351 uint64_t StructType::getMaxFieldID() const {
352 uint64_t fieldID = 0;
353 for (const auto &field : getElements())
354 fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
358 std::pair<Type, uint64_t>
359 StructType::getSubTypeByFieldID(uint64_t fieldID) const {
362 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
363 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
364 auto subfieldIndex = std::distance(fieldIDs.begin(), it);
365 auto subfieldType = getElements()[subfieldIndex].type;
366 auto subfieldID = fieldID - fieldIDs[subfieldIndex];
367 return {subfieldType, subfieldID};
370 std::pair<uint64_t, bool>
371 StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
372 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
373 auto childRoot = fieldIDs[index];
375 index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
376 return std::make_pair(fieldID - childRoot,
377 fieldID >= childRoot && fieldID <= rangeEnd);
380 uint64_t StructType::getFieldID(uint64_t index) const {
381 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
382 return fieldIDs[index];
385 uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
386 assert(!getElements().empty() && "Bundle must have >0 fields");
387 auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
388 auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
389 return std::distance(fieldIDs.begin(), it);
392 std::pair<uint64_t, uint64_t>
393 StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
394 auto index = getIndexForFieldID(fieldID);
395 auto elementFieldID = getFieldID(index);
396 return {index, fieldID - elementFieldID};
399 //===----------------------------------------------------------------------===//
401 //===----------------------------------------------------------------------===//
406 bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) {
407 return a.name == b.name && a.type == b.type && a.offset == b.offset;
410 llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
411 return llvm::hash_combine(fi.name, fi.type, fi.offset);
413 } // namespace detail
417 Type UnionType::parse(AsmParser &p) {
418 llvm::SmallVector<FieldInfo, 4> parameters;
419 llvm::StringSet<> nameSet;
420 bool hasDuplicateName = false;
421 if (p.parseCommaSeparatedList(
422 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
426 auto fieldLoc = p.getCurrentLocation();
427 if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
430 if (!nameSet.insert(name).second) {
431 p.emitError(fieldLoc, "duplicate field name \'" + name +
432 "\' in hw.union type");
433 // Continue parsing to print all duplicates, but make sure to
435 hasDuplicateName = true;
439 if (succeeded(p.parseOptionalKeyword("offset")))
440 if (p.parseInteger(offset))
442 parameters.push_back(UnionType::FieldInfo{
443 StringAttr::get(p.getContext(), name), type, offset});
448 if (hasDuplicateName)
451 return get(p.getContext(), parameters);
454 void UnionType::print(AsmPrinter &odsPrinter) const {
456 llvm::interleaveComma(
457 getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
458 odsPrinter << field.name.getValue() << ": " << field.type;
460 odsPrinter << " offset " << field.offset;
465 LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
466 ArrayRef<UnionType::FieldInfo> elements) {
467 llvm::SmallDenseSet<StringAttr> fieldNameSet;
468 LogicalResult result = success();
469 fieldNameSet.reserve(elements.size());
470 for (const auto &elt : elements)
471 if (!fieldNameSet.insert(elt.name).second) {
473 emitError() << "duplicate field name '" << elt.name.getValue()
474 << "' in hw.union type";
479 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
480 ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
481 for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
482 if (elems[idx].name == fieldName)
487 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
488 return getFieldIndex(StringAttr::get(getContext(), fieldName));
491 UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
492 if (auto fieldIndex = getFieldIndex(fieldName))
493 return getElements()[*fieldIndex];
497 Type UnionType::getFieldType(mlir::StringRef fieldName) {
498 return getFieldInfo(fieldName).type;
501 //===----------------------------------------------------------------------===//
503 //===----------------------------------------------------------------------===//
505 Type EnumType::parse(AsmParser &p) {
506 llvm::SmallVector<Attribute> fields;
508 if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
510 if (p.parseKeyword(&name))
512 fields.push_back(StringAttr::get(p.getContext(), name));
517 return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
520 void EnumType::print(AsmPrinter &p) const {
522 llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
523 p << llvm::cast<StringAttr>(enumerator).getValue();
528 bool EnumType::contains(mlir::StringRef field) {
529 return indexOf(field).has_value();
532 std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
533 for (auto it : llvm::enumerate(getFields()))
534 if (llvm::cast<StringAttr>(it.value()).getValue() == field)
539 size_t EnumType::getBitWidth() {
540 auto w = getFields().size();
542 return llvm::Log2_64_Ceil(getFields().size());
546 //===----------------------------------------------------------------------===//
548 //===----------------------------------------------------------------------===//
550 static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner) {
555 auto int64Type = p.getBuilder().getIntegerType(64);
557 if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value())
558 dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
559 else if (!p.parseOptionalAttribute(dim, int64Type).has_value())
562 if (!isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>(dim)) {
563 p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
567 if (p.parseXInDimensionList() || parseHWElementType(inner, p) ||
574 Type ArrayType::parse(AsmParser &p) {
578 if (failed(parseArray(p, dim, inner)))
581 auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
582 if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim)))
585 return get(inner.getContext(), inner, dim);
588 void ArrayType::print(AsmPrinter &p) const {
590 p.printAttributeWithoutType(getSizeAttr());
592 printHWElementType(getElementType(), p);
596 size_t ArrayType::getNumElements() const {
597 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
598 return intAttr.getInt();
602 LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
603 Type innerType, Attribute size) {
604 if (hasHWInOutType(innerType))
605 return emitError() << "hw.array cannot contain InOut types";
609 uint64_t ArrayType::getMaxFieldID() const {
610 return getNumElements() *
611 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
614 std::pair<Type, uint64_t>
615 ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
618 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
621 std::pair<uint64_t, bool>
622 ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
623 auto childRoot = getFieldID(index);
625 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
626 return std::make_pair(fieldID - childRoot,
627 fieldID >= childRoot && fieldID <= rangeEnd);
630 uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
631 assert(fieldID && "fieldID must be at least 1");
632 // Divide the field ID by the number of fieldID's per element.
636 std::pair<uint64_t, uint64_t>
640 return {index, fieldID - elementFieldID};
651 Type UnpackedArrayType::parse(AsmParser &p) {
658 auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
659 if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim)))
662 return get(inner.getContext(), inner, dim);
665 void UnpackedArrayType::print(AsmPrinter &p)
const {
667 p.printAttributeWithoutType(getSizeAttr());
674 UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
675 Type innerType, Attribute size) {
676 if (!isHWValueType(innerType))
677 return emitError() << "invalid element for uarray type";
681 size_t UnpackedArrayType::getNumElements() const {
682 if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
683 return intAttr.getInt();
687 uint64_t UnpackedArrayType::getMaxFieldID() const {
688 return getNumElements() *
689 (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
692 std::pair<Type, uint64_t>
693 UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
696 return {getElementType(), getIndexAndSubfieldID(fieldID).second};
699 std::pair<uint64_t, bool>
700 UnpackedArrayType::projectToChildFieldID(uint64_t fieldID,
701 uint64_t index) const {
702 auto childRoot = getFieldID(index);
704 index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
705 return std::make_pair(fieldID - childRoot,
706 fieldID >= childRoot && fieldID <= rangeEnd);
709 uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
710 assert(fieldID && "fieldID must be at least 1");
711 // Divide the field ID by the number of fieldID's per element.
715 std::pair<uint64_t, uint64_t>
719 return {index, fieldID - elementFieldID};
730 Type InOutType::parse(AsmParser &p) {
735 auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
736 if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner)))
739 return get(p.getContext(), inner);
742 void InOutType::print(AsmPrinter &p)
const {
748 LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
750 if (!isHWValueType(innerType))
751 return emitError() << "invalid element for hw.inout type " << innerType;
755 //===----------------------------------------------------------------------===//
757 //===----------------------------------------------------------------------===//
759 static Type computeCanonicalType(Type type) {
760 return llvm::TypeSwitch<Type, Type>(type)
761 .Case([](TypeAliasType t) {
762 return computeCanonicalType(t.getCanonicalType());
764 .Case([](ArrayType t) {
765 return ArrayType::get(computeCanonicalType(t.getElementType()),
768 .Case([](UnpackedArrayType t) {
769 return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
772 .Case([](StructType t) {
773 SmallVector<StructType::FieldInfo> fieldInfo;
774 for (auto field : t.getElements())
775 fieldInfo.push_back(StructType::FieldInfo{
776 field.name, computeCanonicalType(field.type)});
777 return StructType::get(t.getContext(), fieldInfo);
779 .Default([](Type t) { return t; });
782 TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
783 return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
786 Type TypeAliasType::parse(AsmParser &p) {
789 if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
790 p.parseType(type) || p.parseGreater())
793 return get(ref, type);
796 void TypeAliasType::print(AsmPrinter &p) const {
797 p << "<" << getRef() << ", " << getInnerType() << ">";
802 TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
803 SymbolRefAttr ref = getRef();
804 auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
805 cache.getDefinition(ref.getRootReference()));
809 return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
816 LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
817 ArrayRef<ModulePort> ports) {
818 if (llvm::any_of(ports, [](const ModulePort &port) {
819 return hasHWInOutType(port.type);
821 return emitError() << "Ports cannot be inout types";
825 size_t ModuleType::getPortIdForInputId(size_t idx) {
826 assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
827 return getImpl()->inputToAbs[idx];
830 size_t ModuleType::getPortIdForOutputId(size_t idx) {
831 assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
832 return getImpl()->outputToAbs[idx];
835 size_t ModuleType::getInputIdForPortId(size_t idx) {
836 auto nIdx = getImpl()->absToInput[idx];
837 assert(nIdx != ~0ULL);
841 size_t ModuleType::getOutputIdForPortId(size_t idx) {
842 auto nIdx = getImpl()->absToOutput[idx];
843 assert(nIdx != ~0ULL);
847 size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
849 size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
851 size_t ModuleType::getNumPorts() { return getPorts().size(); }
853 SmallVector<Type> ModuleType::getInputTypes() {
854 SmallVector<Type> retval;
855 for (auto &p : getPorts()) {
856 if (p.dir == ModulePort::Direction::Input)
857 retval.push_back(p.type);
858 else if (p.dir == ModulePort::Direction::InOut) {
859 retval.push_back(hw::InOutType::get(p.type));
865 SmallVector<Type> ModuleType::getOutputTypes() {
866 SmallVector<Type> retval;
867 for (auto &p : getPorts())
868 if (p.dir == ModulePort::Direction::Output)
869 retval.push_back(p.type);
873 SmallVector<Type> ModuleType::getPortTypes() {
874 SmallVector<Type> retval;
875 for (auto &p : getPorts())
876 retval.push_back(p.type);
880 Type ModuleType::getInputType(size_t idx) {
881 const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
882 if (portInfo.dir != ModulePort::InOut)
883 return portInfo.type;
884 return InOutType::get(portInfo.type);
887 Type ModuleType::getOutputType(size_t idx) {
888 return getPorts()[getPortIdForOutputId(idx)].type;
891 SmallVector<Attribute> ModuleType::getInputNames() {
892 SmallVector<Attribute> retval;
893 for (auto &p : getPorts())
894 if (p.dir != ModulePort::Direction::Output)
895 retval.push_back(p.name);
899 SmallVector<Attribute> ModuleType::getOutputNames() {
900 SmallVector<Attribute> retval;
901 for (auto &p : getPorts())
902 if (p.dir == ModulePort::Direction::Output)
903 retval.push_back(p.name);
907 StringAttr ModuleType::getPortNameAttr(size_t idx) {
908 return getPorts()[idx].name;
911 StringRef ModuleType::getPortName(size_t idx) {
912 auto sa = getPortNameAttr(idx);
914 return sa.getValue();
918 StringAttr ModuleType::getInputNameAttr(size_t idx) {
919 return getPorts()[getPortIdForInputId(idx)].name;
922 StringRef ModuleType::getInputName(size_t idx) {
923 auto sa = getInputNameAttr(idx);
925 return sa.getValue();
929 StringAttr ModuleType::getOutputNameAttr(size_t idx) {
930 return getPorts()[getPortIdForOutputId(idx)].name;
933 StringRef ModuleType::getOutputName(size_t idx) {
934 auto sa = getOutputNameAttr(idx);
936 return sa.getValue();
940 bool ModuleType::isOutput(size_t idx) {
941 auto &p = getPorts()[idx];
942 return p.dir == ModulePort::Direction::Output;
945 FunctionType ModuleType::getFuncType() {
946 SmallVector<Type> inputs, outputs;
947 for (auto p : getPorts())
948 if (p.dir == ModulePort::Input)
949 inputs.push_back(p.type);
950 else if (p.dir == ModulePort::InOut)
951 inputs.push_back(InOutType::get(p.type));
953 outputs.push_back(p.type);
954 return FunctionType::get(getContext(), inputs, outputs);
957 ArrayRef<ModulePort> ModuleType::getPorts() const {
958 return getImpl()->getPorts();
961 FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
964 SmallVector<ModulePort, 8> resolvedPorts;
965 for (ModulePort port : getPorts()) {
966 FailureOr<Type> resolvedType =
967 evaluateParametricType(loc, parameters, port.type, emitErrors);
968 if (failed(resolvedType))
970 port.type = *resolvedType;
971 resolvedPorts.push_back(port);
973 return ModuleType::get(getContext(), resolvedPorts);
976 static StringRef dirToStr(ModulePort::Direction dir) {
978 case ModulePort::Direction::Input:
980 case ModulePort::Direction::Output:
982 case ModulePort::Direction::InOut:
987 static ModulePort::Direction strToDir(StringRef str) {
989 return ModulePort::Direction::Input;
991 return ModulePort::Direction::Output;
993 return ModulePort::Direction::InOut;
994 llvm::report_fatal_error("invalid direction");
999 static ParseResult parsePorts(AsmParser &p,
1000 SmallVectorImpl<ModulePort> &ports) {
1001 return p.parseCommaSeparatedList(
1002 mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
1006 if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
1007 p.parseColon() || p.parseType(type))
1010 {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
1016 static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
1018 llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
1019 p << dirToStr(port.dir) << " ";
1020 p.printKeywordOrString(port.name.getValue());
1021 p << " : " << port.type;
1026 Type ModuleType::parse(AsmParser &odsParser) {
1027 llvm::SmallVector<ModulePort, 4> ports;
1028 if (parsePorts(odsParser, ports))
1030 return get(odsParser.getContext(), ports);
1033 void ModuleType::print(AsmPrinter &odsPrinter) const {
1034 printPorts(odsPrinter, getPorts());
1037 ModuleType circt::hw::detail::fnToMod(Operation *op,
1038 ArrayRef<Attribute> inputNames,
1039 ArrayRef<Attribute> outputNames) {
1041 cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
1042 inputNames, outputNames);
1045 ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
1046 ArrayRef<Attribute> inputNames,
1047 ArrayRef<Attribute> outputNames) {
1048 SmallVector<ModulePort> ports;
1049 if (!inputNames.empty()) {
1050 for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
1051 if (auto iot = dyn_cast<hw::InOutType>(t))
1052 ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1053 ModulePort::Direction::InOut});
1055 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1057 for (auto t : fnty.getInputs())
1058 if (auto iot = dyn_cast<hw::InOutType>(t))
1060 {{}, iot.getElementType(), ModulePort::Direction::InOut});
1062 ports.push_back({{}, t, ModulePort::Direction::Input});
1064 if (!outputNames.empty()) {
1065 for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1066 ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1068 for (auto t : fnty.getResults())
1069 ports.push_back({{}, t, ModulePort::Direction::Output});
1071 return ModuleType::get(fnty.getContext(), ports);
1074 detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef<ModulePort> inPorts)
1076 size_t nextInput = 0;
1077 size_t nextOutput = 0;
1078 for (auto [idx, p] : llvm::enumerate(ports)) {
1079 if (p.dir == ModulePort::Direction::Output) {
1080 outputToAbs.push_back(idx);
1081 absToOutput.push_back(nextOutput);
1082 absToInput.push_back(~0ULL);
1085 inputToAbs.push_back(idx);
1086 absToInput.push_back(nextInput);
1087 absToOutput.push_back(~0ULL);
1097 void HWDialect::registerTypes() {
1099 #define GET_TYPEDEF_LIST
1100 #include "circt/Dialect/HW/HWTypes.cpp.inc"
assert(baseType &&"element must be base type")
static void printHWElementType(Type element, AsmPrinter &p)
static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner)
static ParseResult parseHWElementType(Type &result, AsmParser &p)
Parse and print nested HW types nicely.
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
uint64_t getMaxFieldID(Type)
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
mlir::Type getCanonicalType(mlir::Type type)
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.