CIRCT  19.0.0git
HWTypes.cpp
Go to the documentation of this file.
1 //===- HWTypes.cpp - HW types code defs -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation logic for HW data types.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
18 #include "circt/Support/LLVM.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/StorageUniquerSupport.h"
24 #include "mlir/IR/Types.h"
25 #include "llvm/ADT/SmallSet.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace circt;
31 using namespace circt::hw;
32 using namespace circt::hw::detail;
33 
34 static ParseResult parseHWArray(AsmParser &parser, Attribute &dim,
35  Type &elementType);
36 static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType);
37 
38 static ParseResult parseHWElementType(AsmParser &parser, Type &elementType);
39 static void printHWElementType(AsmPrinter &printer, Type dim);
40 
41 #define GET_TYPEDEF_CLASSES
42 #include "circt/Dialect/HW/HWTypes.cpp.inc"
43 
44 //===----------------------------------------------------------------------===//
45 // Type Helpers
//===----------------------------------------------------------------------===//
47 
48 mlir::Type circt::hw::getCanonicalType(mlir::Type type) {
49  Type canonicalType;
50  if (auto typeAlias = dyn_cast<TypeAliasType>(type))
51  canonicalType = typeAlias.getCanonicalType();
52  else
53  canonicalType = type;
54  return canonicalType;
55 }
56 
57 /// Return true if the specified type is a value HW Integer type. This checks
58 /// that it is a signless standard dialect type or a hw::IntType.
59 bool circt::hw::isHWIntegerType(mlir::Type type) {
60  Type canonicalType = getCanonicalType(type);
61 
62  if (isa<hw::IntType>(canonicalType))
63  return true;
64 
65  auto intType = dyn_cast<IntegerType>(canonicalType);
66  if (!intType || !intType.isSignless())
67  return false;
68 
69  return true;
70 }
71 
72 bool circt::hw::isHWEnumType(mlir::Type type) {
73  return isa<hw::EnumType>(getCanonicalType(type));
74 }
75 
76 /// Return true if the specified type can be used as an HW value type, that is
77 /// the set of types that can be composed together to represent synthesized,
78 /// hardware but not marker types like InOutType.
79 bool circt::hw::isHWValueType(Type type) {
80  // Signless and signed integer types are both valid.
81  if (isa<IntegerType, IntType, EnumType>(type))
82  return true;
83 
84  if (auto array = dyn_cast<ArrayType>(type))
85  return isHWValueType(array.getElementType());
86 
87  if (auto array = dyn_cast<UnpackedArrayType>(type))
88  return isHWValueType(array.getElementType());
89 
90  if (auto t = dyn_cast<StructType>(type))
91  return llvm::all_of(t.getElements(),
92  [](auto f) { return isHWValueType(f.type); });
93 
94  if (auto t = dyn_cast<UnionType>(type))
95  return llvm::all_of(t.getElements(),
96  [](auto f) { return isHWValueType(f.type); });
97 
98  if (auto t = dyn_cast<TypeAliasType>(type))
99  return isHWValueType(t.getCanonicalType());
100 
101  return false;
102 }
103 
104 /// Return the hardware bit width of a type. Does not reflect any encoding,
105 /// padding, or storage scheme, just the bit (and wire width) of a
106 /// statically-size type. Reflects the number of wires needed to transmit a
107 /// value of this type. Returns -1 if the type is not known or cannot be
108 /// statically computed.
109 int64_t circt::hw::getBitWidth(mlir::Type type) {
110  return llvm::TypeSwitch<::mlir::Type, size_t>(type)
111  .Case<IntegerType>(
112  [](IntegerType t) { return t.getIntOrFloatBitWidth(); })
113  .Case<ArrayType, UnpackedArrayType>([](auto a) {
114  int64_t elementBitWidth = getBitWidth(a.getElementType());
115  if (elementBitWidth < 0)
116  return elementBitWidth;
117  int64_t dimBitWidth = a.getNumElements();
118  if (dimBitWidth < 0)
119  return static_cast<int64_t>(-1L);
120  return (int64_t)a.getNumElements() * elementBitWidth;
121  })
122  .Case<StructType>([](StructType s) {
123  int64_t total = 0;
124  for (auto field : s.getElements()) {
125  int64_t fieldSize = getBitWidth(field.type);
126  if (fieldSize < 0)
127  return fieldSize;
128  total += fieldSize;
129  }
130  return total;
131  })
132  .Case<UnionType>([](UnionType u) {
133  int64_t maxSize = 0;
134  for (auto field : u.getElements()) {
135  int64_t fieldSize = getBitWidth(field.type) + field.offset;
136  if (fieldSize > maxSize)
137  maxSize = fieldSize;
138  }
139  return maxSize;
140  })
141  .Case<EnumType>([](EnumType e) { return e.getBitWidth(); })
142  .Case<TypeAliasType>(
143  [](TypeAliasType t) { return getBitWidth(t.getCanonicalType()); })
144  .Default([](Type) { return -1; });
145 }
146 
147 /// Return true if the specified type contains known marker types like
148 /// InOutType. Unlike isHWValueType, this is not conservative, it only returns
149 /// false on known InOut types, rather than any unknown types.
150 bool circt::hw::hasHWInOutType(Type type) {
151  if (auto array = dyn_cast<ArrayType>(type))
152  return hasHWInOutType(array.getElementType());
153 
154  if (auto array = dyn_cast<UnpackedArrayType>(type))
155  return hasHWInOutType(array.getElementType());
156 
157  if (auto t = dyn_cast<StructType>(type)) {
158  return std::any_of(t.getElements().begin(), t.getElements().end(),
159  [](const auto &f) { return hasHWInOutType(f.type); });
160  }
161 
162  if (auto t = dyn_cast<TypeAliasType>(type))
163  return hasHWInOutType(t.getCanonicalType());
164 
165  return isa<InOutType>(type);
166 }
167 
168 /// Parse and print nested HW types nicely. These helper methods allow eliding
169 /// the "hw." prefix on array, inout, and other types when in a context that
170 /// expects HW subelement types.
171 static ParseResult parseHWElementType(AsmParser &p, Type &result) {
172  // If this is an HW dialect type, then we don't need/want the !hw. prefix
173  // redundantly specified.
174  auto fullString = static_cast<DialectAsmParser &>(p).getFullSymbolSpec();
175  auto *curPtr = p.getCurrentLocation().getPointer();
176  auto typeString =
177  StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
178 
179  if (typeString.starts_with("array<") || typeString.starts_with("inout<") ||
180  typeString.starts_with("uarray<") || typeString.starts_with("struct<") ||
181  typeString.starts_with("typealias<") || typeString.starts_with("int<") ||
182  typeString.starts_with("enum<")) {
183  llvm::StringRef mnemonic;
184  auto parseResult = generatedTypeParser(p, &mnemonic, result);
185  return parseResult.has_value() ? success() : failure();
186  }
187 
188  return p.parseType(result);
189 }
190 
191 static void printHWElementType(AsmPrinter &p, Type element) {
192  if (succeeded(generatedTypePrinter(element, p)))
193  return;
194  p.printType(element);
195 }
196 
197 //===----------------------------------------------------------------------===//
198 // Int Type
199 //===----------------------------------------------------------------------===//
200 
201 Type IntType::get(mlir::TypedAttr width) {
202  // The width expression must always be a 32-bit wide integer type itself.
203  auto widthWidth = llvm::dyn_cast<IntegerType>(width.getType());
204  assert(widthWidth && widthWidth.getWidth() == 32 &&
205  "!hw.int width must be 32-bits");
206  (void)widthWidth;
207 
208  if (auto cstWidth = llvm::dyn_cast<IntegerAttr>(width))
209  return IntegerType::get(width.getContext(),
210  cstWidth.getValue().getZExtValue());
211 
212  return Base::get(width.getContext(), width);
213 }
214 
215 Type IntType::parse(AsmParser &p) {
216  // The bitwidth of the parameter size is always 32 bits.
217  auto int32Type = p.getBuilder().getIntegerType(32);
218 
219  mlir::TypedAttr width;
220  if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater())
221  return Type();
222  return get(width);
223 }
224 
225 void IntType::print(AsmPrinter &p) const {
226  p << "<";
227  p.printAttributeWithoutType(getWidth());
228  p << '>';
229 }
230 
231 //===----------------------------------------------------------------------===//
232 // Struct Type
233 //===----------------------------------------------------------------------===//
234 
235 namespace circt {
236 namespace hw {
237 namespace detail {
238 bool operator==(const FieldInfo &a, const FieldInfo &b) {
239  return a.name == b.name && a.type == b.type;
240 }
241 llvm::hash_code hash_value(const FieldInfo &fi) {
242  return llvm::hash_combine(fi.name, fi.type);
243 }
244 } // namespace detail
245 } // namespace hw
246 } // namespace circt
247 
248 /// Parse a list of unique field names and types within <>. E.g.:
249 /// <foo: i7, bar: i8>
250 static ParseResult parseFields(AsmParser &p,
251  SmallVectorImpl<FieldInfo> &parameters) {
252  llvm::StringSet<> nameSet;
253  bool hasDuplicateName = false;
254  auto parseResult = p.parseCommaSeparatedList(
255  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
256  std::string name;
257  Type type;
258 
259  auto fieldLoc = p.getCurrentLocation();
260  if (p.parseKeywordOrString(&name) || p.parseColon() ||
261  p.parseType(type))
262  return failure();
263 
264  if (!nameSet.insert(name).second) {
265  p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
266  // Continue parsing to print all duplicates, but make sure to error
267  // eventually
268  hasDuplicateName = true;
269  }
270 
271  parameters.push_back(
272  FieldInfo{StringAttr::get(p.getContext(), name), type});
273  return success();
274  });
275 
276  if (hasDuplicateName)
277  return failure();
278  return parseResult;
279 }
280 
281 /// Print out a list of named fields surrounded by <>.
282 static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
283  p << '<';
284  llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
285  p.printKeywordOrString(field.name.getValue());
286  p << ": " << field.type;
287  });
288  p << ">";
289 }
290 
291 Type StructType::parse(AsmParser &p) {
292  llvm::SmallVector<FieldInfo, 4> parameters;
293  if (parseFields(p, parameters))
294  return Type();
295  return get(p.getContext(), parameters);
296 }
297 
298 LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
299  ArrayRef<StructType::FieldInfo> elements) {
300  llvm::SmallDenseSet<StringAttr> fieldNameSet;
301  LogicalResult result = success();
302  fieldNameSet.reserve(elements.size());
303  for (const auto &elt : elements)
304  if (!fieldNameSet.insert(elt.name).second) {
305  result = failure();
306  emitError() << "duplicate field name '" << elt.name.getValue()
307  << "' in hw.struct type";
308  }
309  return result;
310 }
311 
312 void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
313 
314 Type StructType::getFieldType(mlir::StringRef fieldName) {
315  for (const auto &field : getElements())
316  if (field.name == fieldName)
317  return field.type;
318  return Type();
319 }
320 
321 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
322  ArrayRef<hw::StructType::FieldInfo> elems = getElements();
323  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
324  if (elems[idx].name == fieldName)
325  return idx;
326  return {};
327 }
328 
329 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
330  ArrayRef<hw::StructType::FieldInfo> elems = getElements();
331  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
332  if (elems[idx].name == fieldName)
333  return idx;
334  return {};
335 }
336 
337 static std::pair<uint64_t, SmallVector<uint64_t>>
338 getFieldIDsStruct(const StructType &st) {
339  uint64_t fieldID = 0;
340  auto elements = st.getElements();
341  SmallVector<uint64_t> fieldIDs;
342  fieldIDs.reserve(elements.size());
343  for (auto &element : elements) {
344  auto type = element.type;
345  fieldID += 1;
346  fieldIDs.push_back(fieldID);
347  // Increment the field ID for the next field by the number of subfields.
348  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
349  }
350  return {fieldID, fieldIDs};
351 }
352 
353 void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
354  for (const auto &field : getElements())
355  types.push_back(field.type);
356 }
357 
358 uint64_t StructType::getMaxFieldID() const {
359  uint64_t fieldID = 0;
360  for (const auto &field : getElements())
361  fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
362  return fieldID;
363 }
364 
365 std::pair<Type, uint64_t>
366 StructType::getSubTypeByFieldID(uint64_t fieldID) const {
367  if (fieldID == 0)
368  return {*this, 0};
369  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
370  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
371  auto subfieldIndex = std::distance(fieldIDs.begin(), it);
372  auto subfieldType = getElements()[subfieldIndex].type;
373  auto subfieldID = fieldID - fieldIDs[subfieldIndex];
374  return {subfieldType, subfieldID};
375 }
376 
377 std::pair<uint64_t, bool>
378 StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
379  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
380  auto childRoot = fieldIDs[index];
381  auto rangeEnd =
382  index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
383  return std::make_pair(fieldID - childRoot,
384  fieldID >= childRoot && fieldID <= rangeEnd);
385 }
386 
387 uint64_t StructType::getFieldID(uint64_t index) const {
388  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
389  return fieldIDs[index];
390 }
391 
392 uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
393  assert(!getElements().empty() && "Bundle must have >0 fields");
394  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
395  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
396  return std::distance(fieldIDs.begin(), it);
397 }
398 
399 std::pair<uint64_t, uint64_t>
400 StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
401  auto index = getIndexForFieldID(fieldID);
402  auto elementFieldID = getFieldID(index);
403  return {index, fieldID - elementFieldID};
404 }
405 
406 //===----------------------------------------------------------------------===//
407 // Union Type
408 //===----------------------------------------------------------------------===//
409 
410 namespace circt {
411 namespace hw {
412 namespace detail {
413 bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) {
414  return a.name == b.name && a.type == b.type && a.offset == b.offset;
415 }
416 // NOLINTNEXTLINE
417 llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
418  return llvm::hash_combine(fi.name, fi.type, fi.offset);
419 }
420 } // namespace detail
421 } // namespace hw
422 } // namespace circt
423 
424 Type UnionType::parse(AsmParser &p) {
425  llvm::SmallVector<FieldInfo, 4> parameters;
426  llvm::StringSet<> nameSet;
427  bool hasDuplicateName = false;
428  if (p.parseCommaSeparatedList(
429  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
430  StringRef name;
431  Type type;
432 
433  auto fieldLoc = p.getCurrentLocation();
434  if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
435  return failure();
436 
437  if (!nameSet.insert(name).second) {
438  p.emitError(fieldLoc, "duplicate field name \'" + name +
439  "\' in hw.union type");
440  // Continue parsing to print all duplicates, but make sure to
441  // error eventually
442  hasDuplicateName = true;
443  }
444 
445  size_t offset = 0;
446  if (succeeded(p.parseOptionalKeyword("offset")))
447  if (p.parseInteger(offset))
448  return failure();
449  parameters.push_back(UnionType::FieldInfo{
450  StringAttr::get(p.getContext(), name), type, offset});
451  return success();
452  }))
453  return Type();
454 
455  if (hasDuplicateName)
456  return Type();
457 
458  return get(p.getContext(), parameters);
459 }
460 
461 void UnionType::print(AsmPrinter &odsPrinter) const {
462  odsPrinter << '<';
463  llvm::interleaveComma(
464  getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
465  odsPrinter << field.name.getValue() << ": " << field.type;
466  if (field.offset)
467  odsPrinter << " offset " << field.offset;
468  });
469  odsPrinter << ">";
470 }
471 
472 LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
473  ArrayRef<UnionType::FieldInfo> elements) {
474  llvm::SmallDenseSet<StringAttr> fieldNameSet;
475  LogicalResult result = success();
476  fieldNameSet.reserve(elements.size());
477  for (const auto &elt : elements)
478  if (!fieldNameSet.insert(elt.name).second) {
479  result = failure();
480  emitError() << "duplicate field name '" << elt.name.getValue()
481  << "' in hw.union type";
482  }
483  return result;
484 }
485 
486 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
487  ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
488  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
489  if (elems[idx].name == fieldName)
490  return idx;
491  return {};
492 }
493 
494 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
495  return getFieldIndex(StringAttr::get(getContext(), fieldName));
496 }
497 
498 UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
499  if (auto fieldIndex = getFieldIndex(fieldName))
500  return getElements()[*fieldIndex];
501  return FieldInfo();
502 }
503 
504 Type UnionType::getFieldType(mlir::StringRef fieldName) {
505  return getFieldInfo(fieldName).type;
506 }
507 
508 //===----------------------------------------------------------------------===//
509 // Enum Type
510 //===----------------------------------------------------------------------===//
511 
512 Type EnumType::parse(AsmParser &p) {
513  llvm::SmallVector<Attribute> fields;
514 
515  if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
516  StringRef name;
517  if (p.parseKeyword(&name))
518  return failure();
519  fields.push_back(StringAttr::get(p.getContext(), name));
520  return success();
521  }))
522  return Type();
523 
524  return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
525 }
526 
527 void EnumType::print(AsmPrinter &p) const {
528  p << '<';
529  llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
530  p << llvm::cast<StringAttr>(enumerator).getValue();
531  });
532  p << ">";
533 }
534 
535 bool EnumType::contains(mlir::StringRef field) {
536  return indexOf(field).has_value();
537 }
538 
539 std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
540  for (auto it : llvm::enumerate(getFields()))
541  if (llvm::cast<StringAttr>(it.value()).getValue() == field)
542  return it.index();
543  return {};
544 }
545 
546 size_t EnumType::getBitWidth() {
547  auto w = getFields().size();
548  if (w > 1)
549  return llvm::Log2_64_Ceil(getFields().size());
550  return 1;
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // ArrayType
555 //===----------------------------------------------------------------------===//
556 
557 static ParseResult parseHWArray(AsmParser &p, Attribute &dim, Type &inner) {
558  uint64_t dimLiteral;
559  auto int64Type = p.getBuilder().getIntegerType(64);
560 
561  if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value())
562  dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
563  else if (!p.parseOptionalAttribute(dim, int64Type).has_value())
564  return failure();
565 
566  if (!isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>(dim)) {
567  p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
568  return failure();
569  }
570 
571  if (p.parseXInDimensionList() || parseHWElementType(p, inner))
572  return failure();
573 
574  return success();
575 }
576 
577 static void printHWArray(AsmPrinter &p, Attribute dim, Type elementType) {
578  p.printAttributeWithoutType(dim);
579  p << "x";
580  printHWElementType(p, elementType);
581 }
582 
583 size_t ArrayType::getNumElements() const {
584  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
585  return intAttr.getInt();
586  return -1;
587 }
588 
589 LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
590  Type innerType, Attribute size) {
591  if (hasHWInOutType(innerType))
592  return emitError() << "hw.array cannot contain InOut types";
593  return success();
594 }
595 
596 uint64_t ArrayType::getMaxFieldID() const {
597  return getNumElements() *
598  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
599 }
600 
601 std::pair<Type, uint64_t>
602 ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
603  if (fieldID == 0)
604  return {*this, 0};
605  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
606 }
607 
608 std::pair<uint64_t, bool>
609 ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
610  auto childRoot = getFieldID(index);
611  auto rangeEnd =
612  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
613  return std::make_pair(fieldID - childRoot,
614  fieldID >= childRoot && fieldID <= rangeEnd);
615 }
616 
617 uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
618  assert(fieldID && "fieldID must be at least 1");
619  // Divide the field ID by the number of fieldID's per element.
620  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
621 }
622 
623 std::pair<uint64_t, uint64_t>
624 ArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
625  auto index = getIndexForFieldID(fieldID);
626  auto elementFieldID = getFieldID(index);
627  return {index, fieldID - elementFieldID};
628 }
629 
630 uint64_t ArrayType::getFieldID(uint64_t index) const {
631  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
632 }
633 
634 //===----------------------------------------------------------------------===//
635 // UnpackedArrayType
636 //===----------------------------------------------------------------------===//
637 
638 LogicalResult
639 UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
640  Type innerType, Attribute size) {
641  if (!isHWValueType(innerType))
642  return emitError() << "invalid element for uarray type";
643  return success();
644 }
645 
646 size_t UnpackedArrayType::getNumElements() const {
647  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
648  return intAttr.getInt();
649  return -1;
650 }
651 
652 uint64_t UnpackedArrayType::getMaxFieldID() const {
653  return getNumElements() *
654  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
655 }
656 
657 std::pair<Type, uint64_t>
658 UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
659  if (fieldID == 0)
660  return {*this, 0};
661  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
662 }
663 
664 std::pair<uint64_t, bool>
666  uint64_t index) const {
667  auto childRoot = getFieldID(index);
668  auto rangeEnd =
669  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
670  return std::make_pair(fieldID - childRoot,
671  fieldID >= childRoot && fieldID <= rangeEnd);
672 }
673 
674 uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
675  assert(fieldID && "fieldID must be at least 1");
676  // Divide the field ID by the number of fieldID's per element.
677  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
678 }
679 
680 std::pair<uint64_t, uint64_t>
681 UnpackedArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
682  auto index = getIndexForFieldID(fieldID);
683  auto elementFieldID = getFieldID(index);
684  return {index, fieldID - elementFieldID};
685 }
686 
687 uint64_t UnpackedArrayType::getFieldID(uint64_t index) const {
688  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
689 }
690 
691 //===----------------------------------------------------------------------===//
692 // InOutType
693 //===----------------------------------------------------------------------===//
694 
695 LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
696  Type innerType) {
697  if (!isHWValueType(innerType))
698  return emitError() << "invalid element for hw.inout type " << innerType;
699  return success();
700 }
701 
702 //===----------------------------------------------------------------------===//
703 // TypeAliasType
704 //===----------------------------------------------------------------------===//
705 
706 static Type computeCanonicalType(Type type) {
707  return llvm::TypeSwitch<Type, Type>(type)
708  .Case([](TypeAliasType t) {
709  return computeCanonicalType(t.getCanonicalType());
710  })
711  .Case([](ArrayType t) {
712  return ArrayType::get(computeCanonicalType(t.getElementType()),
713  t.getNumElements());
714  })
715  .Case([](UnpackedArrayType t) {
716  return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
717  t.getNumElements());
718  })
719  .Case([](StructType t) {
720  SmallVector<StructType::FieldInfo> fieldInfo;
721  for (auto field : t.getElements())
722  fieldInfo.push_back(StructType::FieldInfo{
723  field.name, computeCanonicalType(field.type)});
724  return StructType::get(t.getContext(), fieldInfo);
725  })
726  .Default([](Type t) { return t; });
727 }
728 
729 TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
730  return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
731 }
732 
733 Type TypeAliasType::parse(AsmParser &p) {
734  SymbolRefAttr ref;
735  Type type;
736  if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
737  p.parseType(type) || p.parseGreater())
738  return Type();
739 
740  return get(ref, type);
741 }
742 
743 void TypeAliasType::print(AsmPrinter &p) const {
744  p << "<" << getRef() << ", " << getInnerType() << ">";
745 }
746 
747 /// Return the Typedecl referenced by this TypeAlias, given the module to look
748 /// in. This returns null when the IR is malformed.
749 TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
750  SymbolRefAttr ref = getRef();
751  auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
752  cache.getDefinition(ref.getRootReference()));
753  if (!typeScope)
754  return {};
755 
756  return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
757 }
758 
759 //===----------------------------------------------------------------------===//
760 // ModuleType
761 //===----------------------------------------------------------------------===//
762 
763 LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
764  ArrayRef<ModulePort> ports) {
765  if (llvm::any_of(ports, [](const ModulePort &port) {
766  return hasHWInOutType(port.type);
767  }))
768  return emitError() << "Ports cannot be inout types";
769  return success();
770 }
771 
772 size_t ModuleType::getPortIdForInputId(size_t idx) {
773  assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
774  return getImpl()->inputToAbs[idx];
775 }
776 
777 size_t ModuleType::getPortIdForOutputId(size_t idx) {
778  assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
779  return getImpl()->outputToAbs[idx];
780 }
781 
782 size_t ModuleType::getInputIdForPortId(size_t idx) {
783  auto nIdx = getImpl()->absToInput[idx];
784  assert(nIdx != ~0ULL);
785  return nIdx;
786 }
787 
788 size_t ModuleType::getOutputIdForPortId(size_t idx) {
789  auto nIdx = getImpl()->absToOutput[idx];
790  assert(nIdx != ~0ULL);
791  return nIdx;
792 }
793 
794 size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
795 
796 size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
797 
798 size_t ModuleType::getNumPorts() { return getPorts().size(); }
799 
800 SmallVector<Type> ModuleType::getInputTypes() {
801  SmallVector<Type> retval;
802  for (auto &p : getPorts()) {
803  if (p.dir == ModulePort::Direction::Input)
804  retval.push_back(p.type);
805  else if (p.dir == ModulePort::Direction::InOut) {
806  retval.push_back(hw::InOutType::get(p.type));
807  }
808  }
809  return retval;
810 }
811 
812 SmallVector<Type> ModuleType::getOutputTypes() {
813  SmallVector<Type> retval;
814  for (auto &p : getPorts())
815  if (p.dir == ModulePort::Direction::Output)
816  retval.push_back(p.type);
817  return retval;
818 }
819 
820 SmallVector<Type> ModuleType::getPortTypes() {
821  SmallVector<Type> retval;
822  for (auto &p : getPorts())
823  retval.push_back(p.type);
824  return retval;
825 }
826 
827 Type ModuleType::getInputType(size_t idx) {
828  const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
829  if (portInfo.dir != ModulePort::InOut)
830  return portInfo.type;
831  return InOutType::get(portInfo.type);
832 }
833 
834 Type ModuleType::getOutputType(size_t idx) {
835  return getPorts()[getPortIdForOutputId(idx)].type;
836 }
837 
838 SmallVector<Attribute> ModuleType::getInputNames() {
839  SmallVector<Attribute> retval;
840  for (auto &p : getPorts())
841  if (p.dir != ModulePort::Direction::Output)
842  retval.push_back(p.name);
843  return retval;
844 }
845 
846 SmallVector<Attribute> ModuleType::getOutputNames() {
847  SmallVector<Attribute> retval;
848  for (auto &p : getPorts())
849  if (p.dir == ModulePort::Direction::Output)
850  retval.push_back(p.name);
851  return retval;
852 }
853 
854 StringAttr ModuleType::getPortNameAttr(size_t idx) {
855  return getPorts()[idx].name;
856 }
857 
858 StringRef ModuleType::getPortName(size_t idx) {
859  auto sa = getPortNameAttr(idx);
860  if (sa)
861  return sa.getValue();
862  return {};
863 }
864 
865 StringAttr ModuleType::getInputNameAttr(size_t idx) {
866  return getPorts()[getPortIdForInputId(idx)].name;
867 }
868 
869 StringRef ModuleType::getInputName(size_t idx) {
870  auto sa = getInputNameAttr(idx);
871  if (sa)
872  return sa.getValue();
873  return {};
874 }
875 
876 StringAttr ModuleType::getOutputNameAttr(size_t idx) {
877  return getPorts()[getPortIdForOutputId(idx)].name;
878 }
879 
880 StringRef ModuleType::getOutputName(size_t idx) {
881  auto sa = getOutputNameAttr(idx);
882  if (sa)
883  return sa.getValue();
884  return {};
885 }
886 
887 bool ModuleType::isOutput(size_t idx) {
888  auto &p = getPorts()[idx];
889  return p.dir == ModulePort::Direction::Output;
890 }
891 
892 FunctionType ModuleType::getFuncType() {
893  SmallVector<Type> inputs, outputs;
894  for (auto p : getPorts())
895  if (p.dir == ModulePort::Input)
896  inputs.push_back(p.type);
897  else if (p.dir == ModulePort::InOut)
898  inputs.push_back(InOutType::get(p.type));
899  else
900  outputs.push_back(p.type);
901  return FunctionType::get(getContext(), inputs, outputs);
902 }
903 
904 ArrayRef<ModulePort> ModuleType::getPorts() const {
905  return getImpl()->getPorts();
906 }
907 
908 FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
909  LocationAttr loc,
910  bool emitErrors) {
911  SmallVector<ModulePort, 8> resolvedPorts;
912  for (ModulePort port : getPorts()) {
913  FailureOr<Type> resolvedType =
914  evaluateParametricType(loc, parameters, port.type, emitErrors);
915  if (failed(resolvedType))
916  return failure();
917  port.type = *resolvedType;
918  resolvedPorts.push_back(port);
919  }
920  return ModuleType::get(getContext(), resolvedPorts);
921 }
922 
923 static StringRef dirToStr(ModulePort::Direction dir) {
924  switch (dir) {
926  return "input";
928  return "output";
930  return "inout";
931  }
932 }
933 
934 static ModulePort::Direction strToDir(StringRef str) {
935  if (str == "input")
937  if (str == "output")
939  if (str == "inout")
941  llvm::report_fatal_error("invalid direction");
942 }
943 
944 /// Parse a list of field names and types within <>. E.g.:
945 /// <input foo: i7, output bar: i8>
946 static ParseResult parsePorts(AsmParser &p,
947  SmallVectorImpl<ModulePort> &ports) {
948  return p.parseCommaSeparatedList(
949  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
950  StringRef dir;
951  std::string name;
952  Type type;
953  if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
954  p.parseColon() || p.parseType(type))
955  return failure();
956  ports.push_back(
957  {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
958  return success();
959  });
960 }
961 
962 /// Print out a list of named fields surrounded by <>.
963 static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
964  p << '<';
965  llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
966  p << dirToStr(port.dir) << " ";
967  p.printKeywordOrString(port.name.getValue());
968  p << " : " << port.type;
969  });
970  p << ">";
971 }
972 
973 Type ModuleType::parse(AsmParser &odsParser) {
974  llvm::SmallVector<ModulePort, 4> ports;
975  if (parsePorts(odsParser, ports))
976  return Type();
977  return get(odsParser.getContext(), ports);
978 }
979 
980 void ModuleType::print(AsmPrinter &odsPrinter) const {
981  printPorts(odsPrinter, getPorts());
982 }
983 
984 ModuleType circt::hw::detail::fnToMod(Operation *op,
985  ArrayRef<Attribute> inputNames,
986  ArrayRef<Attribute> outputNames) {
987  return fnToMod(
988  cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
989  inputNames, outputNames);
990 }
991 
992 ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
993  ArrayRef<Attribute> inputNames,
994  ArrayRef<Attribute> outputNames) {
995  SmallVector<ModulePort> ports;
996  if (!inputNames.empty()) {
997  for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
998  if (auto iot = dyn_cast<hw::InOutType>(t))
999  ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1001  else
1002  ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1003  } else {
1004  for (auto t : fnty.getInputs())
1005  if (auto iot = dyn_cast<hw::InOutType>(t))
1006  ports.push_back(
1007  {{}, iot.getElementType(), ModulePort::Direction::InOut});
1008  else
1009  ports.push_back({{}, t, ModulePort::Direction::Input});
1010  }
1011  if (!outputNames.empty()) {
1012  for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1013  ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1014  } else {
1015  for (auto t : fnty.getResults())
1016  ports.push_back({{}, t, ModulePort::Direction::Output});
1017  }
1018  return ModuleType::get(fnty.getContext(), ports);
1019 }
1020 
1021 detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef<ModulePort> inPorts)
1022  : ports(inPorts) {
1023  size_t nextInput = 0;
1024  size_t nextOutput = 0;
1025  for (auto [idx, p] : llvm::enumerate(ports)) {
1026  if (p.dir == ModulePort::Direction::Output) {
1027  outputToAbs.push_back(idx);
1028  absToOutput.push_back(nextOutput);
1029  absToInput.push_back(~0ULL);
1030  ++nextOutput;
1031  } else {
1032  inputToAbs.push_back(idx);
1033  absToInput.push_back(nextInput);
1034  absToOutput.push_back(~0ULL);
1035  ++nextInput;
1036  }
1037  }
1038 }
1039 
1040 //===----------------------------------------------------------------------===//
1041 // BoilerPlate
1042 //===----------------------------------------------------------------------===//
1043 
/// Registers the HW dialect's TableGen-generated type classes with the
/// dialect.
void HWDialect::registerTypes() {
  // With GET_TYPEDEF_LIST defined, the generated HWTypes.cpp.inc is included
  // directly inside the template argument list of addTypes. Presumably it
  // expands to the comma-separated list of generated type class names.
  addTypes<
#define GET_TYPEDEF_LIST
#include "circt/Dialect/HW/HWTypes.cpp.inc"
      >();
}
assert(baseType &&"element must be base type")
MlirType elementType
Definition: CHIRRTL.cpp:29
int32_t width
Definition: FIRRTL.cpp:36
static ModulePort::Direction strToDir(StringRef str)
Definition: HWTypes.cpp:934
static void printPorts(AsmPrinter &p, ArrayRef< ModulePort > ports)
Print out a list of named fields surrounded by <>.
Definition: HWTypes.cpp:963
static StringRef dirToStr(ModulePort::Direction dir)
Definition: HWTypes.cpp:923
static ParseResult parseHWArray(AsmParser &parser, Attribute &dim, Type &elementType)
Definition: HWTypes.cpp:557
static ParseResult parseHWElementType(AsmParser &parser, Type &elementType)
Parse and print nested HW types nicely.
Definition: HWTypes.cpp:171
static ParseResult parsePorts(AsmParser &p, SmallVectorImpl< ModulePort > &ports)
Parse a list of field names and types within <>.
Definition: HWTypes.cpp:946
static void printHWArray(AsmPrinter &printer, Attribute dim, Type elementType)
Definition: HWTypes.cpp:577
static Type computeCanonicalType(Type type)
Definition: HWTypes.cpp:706
static void printHWElementType(AsmPrinter &printer, Type dim)
Definition: HWTypes.cpp:191
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
static unsigned getMaxFieldID(FIRRTLBaseType type)
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition: HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition: HWSymCache.h:56
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2443
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:32
mlir::Type innerType(mlir::Type type)
Definition: ESITypes.cpp:184
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
Definition: FIRRTLOps.cpp:298
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
std::pair<::mlir::Type, uint64_t > getSubTypeByFieldID(Type, uint64_t fieldID)
std::pair< uint64_t, bool > projectToChildFieldID(Type, uint64_t fieldID, uint64_t index)
uint64_t getMaxFieldID(Type)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
Definition: HWTypes.cpp:984
LogicalResult resolveParametricTypes(Location loc, ArrayAttr parameters, ArrayRef< Type > types, SmallVectorImpl< Type > &resolvedTypes, const EmitErrorFn &emitError)
Stores a resolved version of each type in.
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition: HWTypes.cpp:59
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
enum PEO uint32_t mlir::FailureOr< mlir::Type > evaluateParametricType(mlir::Location loc, mlir::ArrayAttr parameters, mlir::Type type, bool emitErrors=true)
Returns a resolved version of 'type' wherein any parameter reference has been evaluated based on the ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:109
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition: HWTypes.cpp:72
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:48
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
mlir::Type type
Definition: HWTypes.h:30
mlir::StringAttr name
Definition: HWTypes.h:29
SmallVector< ModulePort > ports
The parametric data held by the storage class.
Definition: HWTypes.h:69
SmallVector< size_t > absToInput
Definition: HWTypes.h:73
SmallVector< size_t > outputToAbs
Definition: HWTypes.h:72
SmallVector< size_t > inputToAbs
Definition: HWTypes.h:71
SmallVector< size_t > absToOutput
Definition: HWTypes.h:74