CIRCT  20.0.0git
HWTypes.cpp
Go to the documentation of this file.
1 //===- HWTypes.cpp - HW types code defs -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation logic for HW data types.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
18 #include "circt/Support/LLVM.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/StorageUniquerSupport.h"
24 #include "mlir/IR/Types.h"
25 #include "mlir/Interfaces/MemorySlotInterfaces.h"
26 #include "llvm/ADT/SmallSet.h"
27 #include "llvm/ADT/StringExtras.h"
28 #include "llvm/ADT/StringSet.h"
29 #include "llvm/ADT/TypeSwitch.h"
30 
31 using namespace circt;
32 using namespace circt::hw;
33 using namespace circt::hw::detail;
34 
35 static ParseResult parseHWArray(AsmParser &parser, Attribute &dim,
36  Type &elementType);
37 static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType);
38 
39 static ParseResult parseHWElementType(AsmParser &parser, Type &elementType);
40 static void printHWElementType(AsmPrinter &printer, Type dim);
41 
42 #define GET_TYPEDEF_CLASSES
43 #include "circt/Dialect/HW/HWTypes.cpp.inc"
44 
45 //===----------------------------------------------------------------------===//
46 // Type Helpers
//===----------------------------------------------------------------------===//
48 
49 mlir::Type circt::hw::getCanonicalType(mlir::Type type) {
50  Type canonicalType;
51  if (auto typeAlias = dyn_cast<TypeAliasType>(type))
52  canonicalType = typeAlias.getCanonicalType();
53  else
54  canonicalType = type;
55  return canonicalType;
56 }
57 
58 /// Return true if the specified type is a value HW Integer type. This checks
59 /// that it is a signless standard dialect type or a hw::IntType.
60 bool circt::hw::isHWIntegerType(mlir::Type type) {
61  Type canonicalType = getCanonicalType(type);
62 
63  if (isa<hw::IntType>(canonicalType))
64  return true;
65 
66  auto intType = dyn_cast<IntegerType>(canonicalType);
67  if (!intType || !intType.isSignless())
68  return false;
69 
70  return true;
71 }
72 
73 bool circt::hw::isHWEnumType(mlir::Type type) {
74  return isa<hw::EnumType>(getCanonicalType(type));
75 }
76 
77 /// Return true if the specified type can be used as an HW value type, that is
78 /// the set of types that can be composed together to represent synthesized,
79 /// hardware but not marker types like InOutType.
80 bool circt::hw::isHWValueType(Type type) {
81  // Signless and signed integer types are both valid.
82  if (isa<IntegerType, IntType, EnumType>(type))
83  return true;
84 
85  if (auto array = dyn_cast<ArrayType>(type))
86  return isHWValueType(array.getElementType());
87 
88  if (auto array = dyn_cast<UnpackedArrayType>(type))
89  return isHWValueType(array.getElementType());
90 
91  if (auto t = dyn_cast<StructType>(type))
92  return llvm::all_of(t.getElements(),
93  [](auto f) { return isHWValueType(f.type); });
94 
95  if (auto t = dyn_cast<UnionType>(type))
96  return llvm::all_of(t.getElements(),
97  [](auto f) { return isHWValueType(f.type); });
98 
99  if (auto t = dyn_cast<TypeAliasType>(type))
100  return isHWValueType(t.getCanonicalType());
101 
102  return false;
103 }
104 
105 /// Return the hardware bit width of a type. Does not reflect any encoding,
106 /// padding, or storage scheme, just the bit (and wire width) of a
107 /// statically-size type. Reflects the number of wires needed to transmit a
108 /// value of this type. Returns -1 if the type is not known or cannot be
109 /// statically computed.
110 int64_t circt::hw::getBitWidth(mlir::Type type) {
111  return llvm::TypeSwitch<::mlir::Type, size_t>(type)
112  .Case<IntegerType>(
113  [](IntegerType t) { return t.getIntOrFloatBitWidth(); })
114  .Case<ArrayType, UnpackedArrayType>([](auto a) {
115  int64_t elementBitWidth = getBitWidth(a.getElementType());
116  if (elementBitWidth < 0)
117  return elementBitWidth;
118  int64_t dimBitWidth = a.getNumElements();
119  if (dimBitWidth < 0)
120  return static_cast<int64_t>(-1L);
121  return (int64_t)a.getNumElements() * elementBitWidth;
122  })
123  .Case<StructType>([](StructType s) {
124  int64_t total = 0;
125  for (auto field : s.getElements()) {
126  int64_t fieldSize = getBitWidth(field.type);
127  if (fieldSize < 0)
128  return fieldSize;
129  total += fieldSize;
130  }
131  return total;
132  })
133  .Case<UnionType>([](UnionType u) {
134  int64_t maxSize = 0;
135  for (auto field : u.getElements()) {
136  int64_t fieldSize = getBitWidth(field.type) + field.offset;
137  if (fieldSize > maxSize)
138  maxSize = fieldSize;
139  }
140  return maxSize;
141  })
142  .Case<EnumType>([](EnumType e) { return e.getBitWidth(); })
143  .Case<TypeAliasType>(
144  [](TypeAliasType t) { return getBitWidth(t.getCanonicalType()); })
145  .Default([](Type) { return -1; });
146 }
147 
148 /// Return true if the specified type contains known marker types like
149 /// InOutType. Unlike isHWValueType, this is not conservative, it only returns
150 /// false on known InOut types, rather than any unknown types.
151 bool circt::hw::hasHWInOutType(Type type) {
152  if (auto array = dyn_cast<ArrayType>(type))
153  return hasHWInOutType(array.getElementType());
154 
155  if (auto array = dyn_cast<UnpackedArrayType>(type))
156  return hasHWInOutType(array.getElementType());
157 
158  if (auto t = dyn_cast<StructType>(type)) {
159  return std::any_of(t.getElements().begin(), t.getElements().end(),
160  [](const auto &f) { return hasHWInOutType(f.type); });
161  }
162 
163  if (auto t = dyn_cast<TypeAliasType>(type))
164  return hasHWInOutType(t.getCanonicalType());
165 
166  return isa<InOutType>(type);
167 }
168 
169 /// Parse and print nested HW types nicely. These helper methods allow eliding
170 /// the "hw." prefix on array, inout, and other types when in a context that
171 /// expects HW subelement types.
172 static ParseResult parseHWElementType(AsmParser &p, Type &result) {
173  // If this is an HW dialect type, then we don't need/want the !hw. prefix
174  // redundantly specified.
175  auto fullString = static_cast<DialectAsmParser &>(p).getFullSymbolSpec();
176  auto *curPtr = p.getCurrentLocation().getPointer();
177  auto typeString =
178  StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
179 
180  if (typeString.starts_with("array<") || typeString.starts_with("inout<") ||
181  typeString.starts_with("uarray<") || typeString.starts_with("struct<") ||
182  typeString.starts_with("typealias<") || typeString.starts_with("int<") ||
183  typeString.starts_with("enum<")) {
184  llvm::StringRef mnemonic;
185  if (auto parseResult = generatedTypeParser(p, &mnemonic, result);
186  parseResult.has_value())
187  return *parseResult;
188  return p.emitError(p.getNameLoc(), "invalid type `") << typeString << "`";
189  }
190 
191  return p.parseType(result);
192 }
193 
194 static void printHWElementType(AsmPrinter &p, Type element) {
195  if (succeeded(generatedTypePrinter(element, p)))
196  return;
197  p.printType(element);
198 }
199 
200 //===----------------------------------------------------------------------===//
201 // Int Type
202 //===----------------------------------------------------------------------===//
203 
204 Type IntType::get(mlir::TypedAttr width) {
205  // The width expression must always be a 32-bit wide integer type itself.
206  auto widthWidth = llvm::dyn_cast<IntegerType>(width.getType());
207  assert(widthWidth && widthWidth.getWidth() == 32 &&
208  "!hw.int width must be 32-bits");
209  (void)widthWidth;
210 
211  if (auto cstWidth = llvm::dyn_cast<IntegerAttr>(width))
212  return IntegerType::get(width.getContext(),
213  cstWidth.getValue().getZExtValue());
214 
215  return Base::get(width.getContext(), width);
216 }
217 
218 Type IntType::parse(AsmParser &p) {
219  // The bitwidth of the parameter size is always 32 bits.
220  auto int32Type = p.getBuilder().getIntegerType(32);
221 
222  mlir::TypedAttr width;
223  if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater())
224  return Type();
225  return get(width);
226 }
227 
228 void IntType::print(AsmPrinter &p) const {
229  p << "<";
230  p.printAttributeWithoutType(getWidth());
231  p << '>';
232 }
233 
234 //===----------------------------------------------------------------------===//
235 // Struct Type
236 //===----------------------------------------------------------------------===//
237 
238 namespace circt {
239 namespace hw {
240 namespace detail {
241 bool operator==(const FieldInfo &a, const FieldInfo &b) {
242  return a.name == b.name && a.type == b.type;
243 }
244 llvm::hash_code hash_value(const FieldInfo &fi) {
245  return llvm::hash_combine(fi.name, fi.type);
246 }
247 } // namespace detail
248 } // namespace hw
249 } // namespace circt
250 
251 /// Parse a list of unique field names and types within <>. E.g.:
252 /// <foo: i7, bar: i8>
253 static ParseResult parseFields(AsmParser &p,
254  SmallVectorImpl<FieldInfo> &parameters) {
255  llvm::StringSet<> nameSet;
256  bool hasDuplicateName = false;
257  auto parseResult = p.parseCommaSeparatedList(
258  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
259  std::string name;
260  Type type;
261 
262  auto fieldLoc = p.getCurrentLocation();
263  if (p.parseKeywordOrString(&name) || p.parseColon() ||
264  p.parseType(type))
265  return failure();
266 
267  if (!nameSet.insert(name).second) {
268  p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
269  // Continue parsing to print all duplicates, but make sure to error
270  // eventually
271  hasDuplicateName = true;
272  }
273 
274  parameters.push_back(
275  FieldInfo{StringAttr::get(p.getContext(), name), type});
276  return success();
277  });
278 
279  if (hasDuplicateName)
280  return failure();
281  return parseResult;
282 }
283 
284 /// Print out a list of named fields surrounded by <>.
285 static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
286  p << '<';
287  llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
288  p.printKeywordOrString(field.name.getValue());
289  p << ": " << field.type;
290  });
291  p << ">";
292 }
293 
294 Type StructType::parse(AsmParser &p) {
295  llvm::SmallVector<FieldInfo, 4> parameters;
296  if (parseFields(p, parameters))
297  return Type();
298  return get(p.getContext(), parameters);
299 }
300 
301 LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
302  ArrayRef<StructType::FieldInfo> elements) {
303  llvm::SmallDenseSet<StringAttr> fieldNameSet;
304  LogicalResult result = success();
305  fieldNameSet.reserve(elements.size());
306  for (const auto &elt : elements)
307  if (!fieldNameSet.insert(elt.name).second) {
308  result = failure();
309  emitError() << "duplicate field name '" << elt.name.getValue()
310  << "' in hw.struct type";
311  }
312  return result;
313 }
314 
315 void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
316 
317 Type StructType::getFieldType(mlir::StringRef fieldName) {
318  for (const auto &field : getElements())
319  if (field.name == fieldName)
320  return field.type;
321  return Type();
322 }
323 
324 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
325  ArrayRef<hw::StructType::FieldInfo> elems = getElements();
326  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
327  if (elems[idx].name == fieldName)
328  return idx;
329  return {};
330 }
331 
332 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
333  ArrayRef<hw::StructType::FieldInfo> elems = getElements();
334  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
335  if (elems[idx].name == fieldName)
336  return idx;
337  return {};
338 }
339 
340 static std::pair<uint64_t, SmallVector<uint64_t>>
341 getFieldIDsStruct(const StructType &st) {
342  uint64_t fieldID = 0;
343  auto elements = st.getElements();
344  SmallVector<uint64_t> fieldIDs;
345  fieldIDs.reserve(elements.size());
346  for (auto &element : elements) {
347  auto type = element.type;
348  fieldID += 1;
349  fieldIDs.push_back(fieldID);
350  // Increment the field ID for the next field by the number of subfields.
351  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
352  }
353  return {fieldID, fieldIDs};
354 }
355 
356 void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
357  for (const auto &field : getElements())
358  types.push_back(field.type);
359 }
360 
361 uint64_t StructType::getMaxFieldID() const {
362  uint64_t fieldID = 0;
363  for (const auto &field : getElements())
364  fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
365  return fieldID;
366 }
367 
368 std::pair<Type, uint64_t>
369 StructType::getSubTypeByFieldID(uint64_t fieldID) const {
370  if (fieldID == 0)
371  return {*this, 0};
372  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
373  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
374  auto subfieldIndex = std::distance(fieldIDs.begin(), it);
375  auto subfieldType = getElements()[subfieldIndex].type;
376  auto subfieldID = fieldID - fieldIDs[subfieldIndex];
377  return {subfieldType, subfieldID};
378 }
379 
380 std::pair<uint64_t, bool>
381 StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
382  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
383  auto childRoot = fieldIDs[index];
384  auto rangeEnd =
385  index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
386  return std::make_pair(fieldID - childRoot,
387  fieldID >= childRoot && fieldID <= rangeEnd);
388 }
389 
390 uint64_t StructType::getFieldID(uint64_t index) const {
391  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
392  return fieldIDs[index];
393 }
394 
395 uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
396  assert(!getElements().empty() && "Bundle must have >0 fields");
397  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
398  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
399  return std::distance(fieldIDs.begin(), it);
400 }
401 
402 std::pair<uint64_t, uint64_t>
403 StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
404  auto index = getIndexForFieldID(fieldID);
405  auto elementFieldID = getFieldID(index);
406  return {index, fieldID - elementFieldID};
407 }
408 
409 std::optional<DenseMap<Attribute, Type>>
410 hw::StructType::getSubelementIndexMap() const {
411  DenseMap<Attribute, Type> destructured;
412  for (auto [i, field] : llvm::enumerate(getElements()))
413  destructured.insert(
414  {IntegerAttr::get(IndexType::get(getContext()), i), field.type});
415  return destructured;
416 }
417 
418 Type hw::StructType::getTypeAtIndex(Attribute index) const {
419  auto indexAttr = llvm::dyn_cast<IntegerAttr>(index);
420  if (!indexAttr)
421  return {};
422 
423  return getSubTypeByFieldID(indexAttr.getInt()).first;
424 }
425 
426 //===----------------------------------------------------------------------===//
427 // Union Type
428 //===----------------------------------------------------------------------===//
429 
430 namespace circt {
431 namespace hw {
432 namespace detail {
433 bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) {
434  return a.name == b.name && a.type == b.type && a.offset == b.offset;
435 }
436 // NOLINTNEXTLINE
437 llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
438  return llvm::hash_combine(fi.name, fi.type, fi.offset);
439 }
440 } // namespace detail
441 } // namespace hw
442 } // namespace circt
443 
444 Type UnionType::parse(AsmParser &p) {
445  llvm::SmallVector<FieldInfo, 4> parameters;
446  llvm::StringSet<> nameSet;
447  bool hasDuplicateName = false;
448  if (p.parseCommaSeparatedList(
449  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
450  StringRef name;
451  Type type;
452 
453  auto fieldLoc = p.getCurrentLocation();
454  if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
455  return failure();
456 
457  if (!nameSet.insert(name).second) {
458  p.emitError(fieldLoc, "duplicate field name \'" + name +
459  "\' in hw.union type");
460  // Continue parsing to print all duplicates, but make sure to
461  // error eventually
462  hasDuplicateName = true;
463  }
464 
465  size_t offset = 0;
466  if (succeeded(p.parseOptionalKeyword("offset")))
467  if (p.parseInteger(offset))
468  return failure();
469  parameters.push_back(UnionType::FieldInfo{
470  StringAttr::get(p.getContext(), name), type, offset});
471  return success();
472  }))
473  return Type();
474 
475  if (hasDuplicateName)
476  return Type();
477 
478  return get(p.getContext(), parameters);
479 }
480 
481 void UnionType::print(AsmPrinter &odsPrinter) const {
482  odsPrinter << '<';
483  llvm::interleaveComma(
484  getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
485  odsPrinter << field.name.getValue() << ": " << field.type;
486  if (field.offset)
487  odsPrinter << " offset " << field.offset;
488  });
489  odsPrinter << ">";
490 }
491 
492 LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
493  ArrayRef<UnionType::FieldInfo> elements) {
494  llvm::SmallDenseSet<StringAttr> fieldNameSet;
495  LogicalResult result = success();
496  fieldNameSet.reserve(elements.size());
497  for (const auto &elt : elements)
498  if (!fieldNameSet.insert(elt.name).second) {
499  result = failure();
500  emitError() << "duplicate field name '" << elt.name.getValue()
501  << "' in hw.union type";
502  }
503  return result;
504 }
505 
506 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
507  ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
508  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
509  if (elems[idx].name == fieldName)
510  return idx;
511  return {};
512 }
513 
514 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
515  return getFieldIndex(StringAttr::get(getContext(), fieldName));
516 }
517 
518 UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
519  if (auto fieldIndex = getFieldIndex(fieldName))
520  return getElements()[*fieldIndex];
521  return FieldInfo();
522 }
523 
524 Type UnionType::getFieldType(mlir::StringRef fieldName) {
525  return getFieldInfo(fieldName).type;
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // Enum Type
530 //===----------------------------------------------------------------------===//
531 
532 Type EnumType::parse(AsmParser &p) {
533  llvm::SmallVector<Attribute> fields;
534 
535  if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
536  StringRef name;
537  if (p.parseKeyword(&name))
538  return failure();
539  fields.push_back(StringAttr::get(p.getContext(), name));
540  return success();
541  }))
542  return Type();
543 
544  return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
545 }
546 
547 void EnumType::print(AsmPrinter &p) const {
548  p << '<';
549  llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
550  p << llvm::cast<StringAttr>(enumerator).getValue();
551  });
552  p << ">";
553 }
554 
555 bool EnumType::contains(mlir::StringRef field) {
556  return indexOf(field).has_value();
557 }
558 
559 std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
560  for (auto it : llvm::enumerate(getFields()))
561  if (llvm::cast<StringAttr>(it.value()).getValue() == field)
562  return it.index();
563  return {};
564 }
565 
566 size_t EnumType::getBitWidth() {
567  auto w = getFields().size();
568  if (w > 1)
569  return llvm::Log2_64_Ceil(getFields().size());
570  return 1;
571 }
572 
573 //===----------------------------------------------------------------------===//
574 // ArrayType
575 //===----------------------------------------------------------------------===//
576 
577 static ParseResult parseHWArray(AsmParser &p, Attribute &dim, Type &inner) {
578  uint64_t dimLiteral;
579  auto int64Type = p.getBuilder().getIntegerType(64);
580 
581  if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value()) {
582  if (failed(*res))
583  return failure();
584  dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
585  } else if (auto res64 = p.parseOptionalAttribute(dim, int64Type);
586  res64.has_value()) {
587  if (failed(*res64))
588  return failure();
589  } else
590  return p.emitError(p.getNameLoc(), "expected integer");
591 
592  if (!isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>(dim)) {
593  p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
594  return failure();
595  }
596 
597  if (p.parseXInDimensionList() || parseHWElementType(p, inner))
598  return failure();
599 
600  return success();
601 }
602 
603 static void printHWArray(AsmPrinter &p, Attribute dim, Type elementType) {
604  p.printAttributeWithoutType(dim);
605  p << "x";
606  printHWElementType(p, elementType);
607 }
608 
609 size_t ArrayType::getNumElements() const {
610  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
611  return intAttr.getInt();
612  return -1;
613 }
614 
615 LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
616  Type innerType, Attribute size) {
617  if (hasHWInOutType(innerType))
618  return emitError() << "hw.array cannot contain InOut types";
619  return success();
620 }
621 
622 uint64_t ArrayType::getMaxFieldID() const {
623  return getNumElements() *
624  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
625 }
626 
627 std::pair<Type, uint64_t>
628 ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
629  if (fieldID == 0)
630  return {*this, 0};
631  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
632 }
633 
634 std::pair<uint64_t, bool>
635 ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
636  auto childRoot = getFieldID(index);
637  auto rangeEnd =
638  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
639  return std::make_pair(fieldID - childRoot,
640  fieldID >= childRoot && fieldID <= rangeEnd);
641 }
642 
643 uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
644  assert(fieldID && "fieldID must be at least 1");
645  // Divide the field ID by the number of fieldID's per element.
646  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
647 }
648 
649 std::pair<uint64_t, uint64_t>
650 ArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
651  auto index = getIndexForFieldID(fieldID);
652  auto elementFieldID = getFieldID(index);
653  return {index, fieldID - elementFieldID};
654 }
655 
656 uint64_t ArrayType::getFieldID(uint64_t index) const {
657  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
658 }
659 
660 std::optional<DenseMap<Attribute, Type>>
661 hw::ArrayType::getSubelementIndexMap() const {
662  DenseMap<Attribute, Type> destructured;
663  for (unsigned i = 0; i < getNumElements(); ++i)
664  destructured.insert(
665  {IntegerAttr::get(IndexType::get(getContext()), i), getElementType()});
666  return destructured;
667 }
668 
669 Type hw::ArrayType::getTypeAtIndex(Attribute index) const {
670  return getElementType();
671 }
672 
673 //===----------------------------------------------------------------------===//
674 // UnpackedArrayType
675 //===----------------------------------------------------------------------===//
676 
677 LogicalResult
678 UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
679  Type innerType, Attribute size) {
680  if (!isHWValueType(innerType))
681  return emitError() << "invalid element for uarray type";
682  return success();
683 }
684 
685 size_t UnpackedArrayType::getNumElements() const {
686  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
687  return intAttr.getInt();
688  return -1;
689 }
690 
691 uint64_t UnpackedArrayType::getMaxFieldID() const {
692  return getNumElements() *
693  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
694 }
695 
696 std::pair<Type, uint64_t>
697 UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
698  if (fieldID == 0)
699  return {*this, 0};
700  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
701 }
702 
703 std::pair<uint64_t, bool>
705  uint64_t index) const {
706  auto childRoot = getFieldID(index);
707  auto rangeEnd =
708  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
709  return std::make_pair(fieldID - childRoot,
710  fieldID >= childRoot && fieldID <= rangeEnd);
711 }
712 
713 uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
714  assert(fieldID && "fieldID must be at least 1");
715  // Divide the field ID by the number of fieldID's per element.
716  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
717 }
718 
719 std::pair<uint64_t, uint64_t>
720 UnpackedArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
721  auto index = getIndexForFieldID(fieldID);
722  auto elementFieldID = getFieldID(index);
723  return {index, fieldID - elementFieldID};
724 }
725 
726 uint64_t UnpackedArrayType::getFieldID(uint64_t index) const {
727  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
728 }
729 
730 //===----------------------------------------------------------------------===//
731 // InOutType
732 //===----------------------------------------------------------------------===//
733 
734 LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
735  Type innerType) {
736  if (!isHWValueType(innerType))
737  return emitError() << "invalid element for hw.inout type " << innerType;
738  return success();
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // TypeAliasType
743 //===----------------------------------------------------------------------===//
744 
745 static Type computeCanonicalType(Type type) {
746  return llvm::TypeSwitch<Type, Type>(type)
747  .Case([](TypeAliasType t) {
748  return computeCanonicalType(t.getCanonicalType());
749  })
750  .Case([](ArrayType t) {
751  return ArrayType::get(computeCanonicalType(t.getElementType()),
752  t.getNumElements());
753  })
754  .Case([](UnpackedArrayType t) {
755  return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
756  t.getNumElements());
757  })
758  .Case([](StructType t) {
759  SmallVector<StructType::FieldInfo> fieldInfo;
760  for (auto field : t.getElements())
761  fieldInfo.push_back(StructType::FieldInfo{
762  field.name, computeCanonicalType(field.type)});
763  return StructType::get(t.getContext(), fieldInfo);
764  })
765  .Default([](Type t) { return t; });
766 }
767 
768 TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
769  return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
770 }
771 
772 Type TypeAliasType::parse(AsmParser &p) {
773  SymbolRefAttr ref;
774  Type type;
775  if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
776  p.parseType(type) || p.parseGreater())
777  return Type();
778 
779  return get(ref, type);
780 }
781 
782 void TypeAliasType::print(AsmPrinter &p) const {
783  p << "<" << getRef() << ", " << getInnerType() << ">";
784 }
785 
786 /// Return the Typedecl referenced by this TypeAlias, given the module to look
787 /// in. This returns null when the IR is malformed.
788 TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
789  SymbolRefAttr ref = getRef();
790  auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
791  cache.getDefinition(ref.getRootReference()));
792  if (!typeScope)
793  return {};
794 
795  return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
796 }
797 
798 //===----------------------------------------------------------------------===//
799 // ModuleType
800 //===----------------------------------------------------------------------===//
801 
802 LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
803  ArrayRef<ModulePort> ports) {
804  if (llvm::any_of(ports, [](const ModulePort &port) {
805  return hasHWInOutType(port.type);
806  }))
807  return emitError() << "Ports cannot be inout types";
808  return success();
809 }
810 
811 size_t ModuleType::getPortIdForInputId(size_t idx) {
812  assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
813  return getImpl()->inputToAbs[idx];
814 }
815 
816 size_t ModuleType::getPortIdForOutputId(size_t idx) {
817  assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
818  return getImpl()->outputToAbs[idx];
819 }
820 
821 size_t ModuleType::getInputIdForPortId(size_t idx) {
822  auto nIdx = getImpl()->absToInput[idx];
823  assert(nIdx != ~0ULL);
824  return nIdx;
825 }
826 
827 size_t ModuleType::getOutputIdForPortId(size_t idx) {
828  auto nIdx = getImpl()->absToOutput[idx];
829  assert(nIdx != ~0ULL);
830  return nIdx;
831 }
832 
833 size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
834 
835 size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
836 
837 size_t ModuleType::getNumPorts() { return getPorts().size(); }
838 
839 SmallVector<Type> ModuleType::getInputTypes() {
840  SmallVector<Type> retval;
841  for (auto &p : getPorts()) {
842  if (p.dir == ModulePort::Direction::Input)
843  retval.push_back(p.type);
844  else if (p.dir == ModulePort::Direction::InOut) {
845  retval.push_back(hw::InOutType::get(p.type));
846  }
847  }
848  return retval;
849 }
850 
851 SmallVector<Type> ModuleType::getOutputTypes() {
852  SmallVector<Type> retval;
853  for (auto &p : getPorts())
854  if (p.dir == ModulePort::Direction::Output)
855  retval.push_back(p.type);
856  return retval;
857 }
858 
859 SmallVector<Type> ModuleType::getPortTypes() {
860  SmallVector<Type> retval;
861  for (auto &p : getPorts())
862  retval.push_back(p.type);
863  return retval;
864 }
865 
866 Type ModuleType::getInputType(size_t idx) {
867  const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
868  if (portInfo.dir != ModulePort::InOut)
869  return portInfo.type;
870  return InOutType::get(portInfo.type);
871 }
872 
873 Type ModuleType::getOutputType(size_t idx) {
874  return getPorts()[getPortIdForOutputId(idx)].type;
875 }
876 
877 SmallVector<Attribute> ModuleType::getInputNames() {
878  SmallVector<Attribute> retval;
879  for (auto &p : getPorts())
880  if (p.dir != ModulePort::Direction::Output)
881  retval.push_back(p.name);
882  return retval;
883 }
884 
885 SmallVector<Attribute> ModuleType::getOutputNames() {
886  SmallVector<Attribute> retval;
887  for (auto &p : getPorts())
888  if (p.dir == ModulePort::Direction::Output)
889  retval.push_back(p.name);
890  return retval;
891 }
892 
893 StringAttr ModuleType::getPortNameAttr(size_t idx) {
894  return getPorts()[idx].name;
895 }
896 
897 StringRef ModuleType::getPortName(size_t idx) {
898  auto sa = getPortNameAttr(idx);
899  if (sa)
900  return sa.getValue();
901  return {};
902 }
903 
904 StringAttr ModuleType::getInputNameAttr(size_t idx) {
905  return getPorts()[getPortIdForInputId(idx)].name;
906 }
907 
908 StringRef ModuleType::getInputName(size_t idx) {
909  auto sa = getInputNameAttr(idx);
910  if (sa)
911  return sa.getValue();
912  return {};
913 }
914 
915 StringAttr ModuleType::getOutputNameAttr(size_t idx) {
916  return getPorts()[getPortIdForOutputId(idx)].name;
917 }
918 
919 StringRef ModuleType::getOutputName(size_t idx) {
920  auto sa = getOutputNameAttr(idx);
921  if (sa)
922  return sa.getValue();
923  return {};
924 }
925 
926 bool ModuleType::isOutput(size_t idx) {
927  auto &p = getPorts()[idx];
928  return p.dir == ModulePort::Direction::Output;
929 }
930 
931 FunctionType ModuleType::getFuncType() {
932  SmallVector<Type> inputs, outputs;
933  for (auto p : getPorts())
934  if (p.dir == ModulePort::Input)
935  inputs.push_back(p.type);
936  else if (p.dir == ModulePort::InOut)
937  inputs.push_back(InOutType::get(p.type));
938  else
939  outputs.push_back(p.type);
940  return FunctionType::get(getContext(), inputs, outputs);
941 }
942 
943 ArrayRef<ModulePort> ModuleType::getPorts() const {
944  return getImpl()->getPorts();
945 }
946 
947 FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
948  LocationAttr loc,
949  bool emitErrors) {
950  SmallVector<ModulePort, 8> resolvedPorts;
951  for (ModulePort port : getPorts()) {
952  FailureOr<Type> resolvedType =
953  evaluateParametricType(loc, parameters, port.type, emitErrors);
954  if (failed(resolvedType))
955  return failure();
956  port.type = *resolvedType;
957  resolvedPorts.push_back(port);
958  }
959  return ModuleType::get(getContext(), resolvedPorts);
960 }
961 
962 static StringRef dirToStr(ModulePort::Direction dir) {
963  switch (dir) {
965  return "input";
967  return "output";
969  return "inout";
970  }
971 }
972 
973 static ModulePort::Direction strToDir(StringRef str) {
974  if (str == "input")
976  if (str == "output")
978  if (str == "inout")
980  llvm::report_fatal_error("invalid direction");
981 }
982 
983 /// Parse a list of field names and types within <>. E.g.:
984 /// <input foo: i7, output bar: i8>
985 static ParseResult parsePorts(AsmParser &p,
986  SmallVectorImpl<ModulePort> &ports) {
987  return p.parseCommaSeparatedList(
988  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
989  StringRef dir;
990  std::string name;
991  Type type;
992  if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
993  p.parseColon() || p.parseType(type))
994  return failure();
995  ports.push_back(
996  {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
997  return success();
998  });
999 }
1000 
1001 /// Print out a list of named fields surrounded by <>.
1002 static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
1003  p << '<';
1004  llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
1005  p << dirToStr(port.dir) << " ";
1006  p.printKeywordOrString(port.name.getValue());
1007  p << " : " << port.type;
1008  });
1009  p << ">";
1010 }
1011 
1012 Type ModuleType::parse(AsmParser &odsParser) {
1013  llvm::SmallVector<ModulePort, 4> ports;
1014  if (parsePorts(odsParser, ports))
1015  return Type();
1016  return get(odsParser.getContext(), ports);
1017 }
1018 
1019 void ModuleType::print(AsmPrinter &odsPrinter) const {
1020  printPorts(odsPrinter, getPorts());
1021 }
1022 
1023 ModuleType circt::hw::detail::fnToMod(Operation *op,
1024  ArrayRef<Attribute> inputNames,
1025  ArrayRef<Attribute> outputNames) {
1026  return fnToMod(
1027  cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
1028  inputNames, outputNames);
1029 }
1030 
1031 ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
1032  ArrayRef<Attribute> inputNames,
1033  ArrayRef<Attribute> outputNames) {
1034  SmallVector<ModulePort> ports;
1035  if (!inputNames.empty()) {
1036  for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
1037  if (auto iot = dyn_cast<hw::InOutType>(t))
1038  ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1040  else
1041  ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1042  } else {
1043  for (auto t : fnty.getInputs())
1044  if (auto iot = dyn_cast<hw::InOutType>(t))
1045  ports.push_back(
1046  {{}, iot.getElementType(), ModulePort::Direction::InOut});
1047  else
1048  ports.push_back({{}, t, ModulePort::Direction::Input});
1049  }
1050  if (!outputNames.empty()) {
1051  for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1052  ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1053  } else {
1054  for (auto t : fnty.getResults())
1055  ports.push_back({{}, t, ModulePort::Direction::Output});
1056  }
1057  return ModuleType::get(fnty.getContext(), ports);
1058 }
1059 
1060 detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef<ModulePort> inPorts)
1061  : ports(inPorts) {
1062  size_t nextInput = 0;
1063  size_t nextOutput = 0;
1064  for (auto [idx, p] : llvm::enumerate(ports)) {
1065  if (p.dir == ModulePort::Direction::Output) {
1066  outputToAbs.push_back(idx);
1067  absToOutput.push_back(nextOutput);
1068  absToInput.push_back(~0ULL);
1069  ++nextOutput;
1070  } else {
1071  inputToAbs.push_back(idx);
1072  absToInput.push_back(nextInput);
1073  absToOutput.push_back(~0ULL);
1074  ++nextInput;
1075  }
1076  }
1077 }
1078 
1079 //===----------------------------------------------------------------------===//
1080 // BoilerPlate
1081 //===----------------------------------------------------------------------===//
1082 
1083 void HWDialect::registerTypes() {
1084  addTypes<
1085 #define GET_TYPEDEF_LIST
1086 #include "circt/Dialect/HW/HWTypes.cpp.inc"
1087  >();
1088 }
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
static ModulePort::Direction strToDir(StringRef str)
Definition: HWTypes.cpp:973
static void printPorts(AsmPrinter &p, ArrayRef< ModulePort > ports)
Print out a list of named fields surrounded by <>.
Definition: HWTypes.cpp:1002
static StringRef dirToStr(ModulePort::Direction dir)
Definition: HWTypes.cpp:962
static ParseResult parseHWArray(AsmParser &parser, Attribute &dim, Type &elementType)
Definition: HWTypes.cpp:577
static ParseResult parseHWElementType(AsmParser &parser, Type &elementType)
Parse and print nested HW types nicely.
Definition: HWTypes.cpp:172
static ParseResult parsePorts(AsmParser &p, SmallVectorImpl< ModulePort > &ports)
Parse a list of field names and types within <>.
Definition: HWTypes.cpp:985
static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType)
Definition: HWTypes.cpp:603
static Type computeCanonicalType(Type type)
Definition: HWTypes.cpp:745
static void printHWElementType(AsmPrinter &printer, Type dim)
Definition: HWTypes.cpp:194
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition: HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition: HWSymCache.h:56
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2467
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:301
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
std::pair< uint64_t, bool > projectToChildFieldID(Type, uint64_t fieldID, uint64_t index)
uint64_t getMaxFieldID(Type)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition: HWTypes.cpp:1023
LogicalResult resolveParametricTypes(Location loc, ArrayAttr parameters, ArrayRef< Type > types, SmallVectorImpl< Type > &resolvedTypes, const EmitErrorFn &emitError)
Stores a resolved version of each type in.
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition: HWTypes.cpp:60
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
enum PEO uint32_t mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:110
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition: HWTypes.cpp:73
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:49
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
mlir::Type type
Definition: HWTypes.h:31
mlir::StringAttr name
Definition: HWTypes.h:30
SmallVector< ModulePort > ports
The parametric data held by the storage class.
Definition: HWTypes.h:70
SmallVector< size_t > absToInput
Definition: HWTypes.h:74
SmallVector< size_t > outputToAbs
Definition: HWTypes.h:73
SmallVector< size_t > inputToAbs
Definition: HWTypes.h:72
SmallVector< size_t > absToOutput
Definition: HWTypes.h:75