CIRCT  19.0.0git
HWTypes.cpp
Go to the documentation of this file.
1 //===- HWTypes.cpp - HW types code defs -----------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // Implementation logic for HW data types.
10 //
11 //===----------------------------------------------------------------------===//
12 
16 #include "circt/Dialect/HW/HWOps.h"
18 #include "circt/Support/LLVM.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Diagnostics.h"
22 #include "mlir/IR/DialectImplementation.h"
23 #include "mlir/IR/StorageUniquerSupport.h"
24 #include "mlir/IR/Types.h"
25 #include "llvm/ADT/SmallSet.h"
26 #include "llvm/ADT/StringExtras.h"
27 #include "llvm/ADT/StringSet.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 
30 using namespace circt;
31 using namespace circt::hw;
32 using namespace circt::hw::detail;
33 
34 #define GET_TYPEDEF_CLASSES
35 #include "circt/Dialect/HW/HWTypes.cpp.inc"
36 
//===----------------------------------------------------------------------===//
// Type Helpers
//===----------------------------------------------------------------------===//
40 
41 mlir::Type circt::hw::getCanonicalType(mlir::Type type) {
42  Type canonicalType;
43  if (auto typeAlias = dyn_cast<TypeAliasType>(type))
44  canonicalType = typeAlias.getCanonicalType();
45  else
46  canonicalType = type;
47  return canonicalType;
48 }
49 
50 /// Return true if the specified type is a value HW Integer type. This checks
51 /// that it is a signless standard dialect type or a hw::IntType.
52 bool circt::hw::isHWIntegerType(mlir::Type type) {
53  Type canonicalType = getCanonicalType(type);
54 
55  if (isa<hw::IntType>(canonicalType))
56  return true;
57 
58  auto intType = dyn_cast<IntegerType>(canonicalType);
59  if (!intType || !intType.isSignless())
60  return false;
61 
62  return true;
63 }
64 
65 bool circt::hw::isHWEnumType(mlir::Type type) {
66  return isa<hw::EnumType>(getCanonicalType(type));
67 }
68 
69 /// Return true if the specified type can be used as an HW value type, that is
70 /// the set of types that can be composed together to represent synthesized,
71 /// hardware but not marker types like InOutType.
72 bool circt::hw::isHWValueType(Type type) {
73  // Signless and signed integer types are both valid.
74  if (isa<IntegerType, IntType, EnumType>(type))
75  return true;
76 
77  if (auto array = dyn_cast<ArrayType>(type))
78  return isHWValueType(array.getElementType());
79 
80  if (auto array = dyn_cast<UnpackedArrayType>(type))
81  return isHWValueType(array.getElementType());
82 
83  if (auto t = dyn_cast<StructType>(type))
84  return llvm::all_of(t.getElements(),
85  [](auto f) { return isHWValueType(f.type); });
86 
87  if (auto t = dyn_cast<UnionType>(type))
88  return llvm::all_of(t.getElements(),
89  [](auto f) { return isHWValueType(f.type); });
90 
91  if (auto t = dyn_cast<TypeAliasType>(type))
92  return isHWValueType(t.getCanonicalType());
93 
94  return false;
95 }
96 
97 /// Return the hardware bit width of a type. Does not reflect any encoding,
98 /// padding, or storage scheme, just the bit (and wire width) of a
99 /// statically-size type. Reflects the number of wires needed to transmit a
100 /// value of this type. Returns -1 if the type is not known or cannot be
101 /// statically computed.
102 int64_t circt::hw::getBitWidth(mlir::Type type) {
103  return llvm::TypeSwitch<::mlir::Type, size_t>(type)
104  .Case<IntegerType>(
105  [](IntegerType t) { return t.getIntOrFloatBitWidth(); })
106  .Case<ArrayType, UnpackedArrayType>([](auto a) {
107  int64_t elementBitWidth = getBitWidth(a.getElementType());
108  if (elementBitWidth < 0)
109  return elementBitWidth;
110  int64_t dimBitWidth = a.getNumElements();
111  if (dimBitWidth < 0)
112  return static_cast<int64_t>(-1L);
113  return (int64_t)a.getNumElements() * elementBitWidth;
114  })
115  .Case<StructType>([](StructType s) {
116  int64_t total = 0;
117  for (auto field : s.getElements()) {
118  int64_t fieldSize = getBitWidth(field.type);
119  if (fieldSize < 0)
120  return fieldSize;
121  total += fieldSize;
122  }
123  return total;
124  })
125  .Case<UnionType>([](UnionType u) {
126  int64_t maxSize = 0;
127  for (auto field : u.getElements()) {
128  int64_t fieldSize = getBitWidth(field.type) + field.offset;
129  if (fieldSize > maxSize)
130  maxSize = fieldSize;
131  }
132  return maxSize;
133  })
134  .Case<EnumType>([](EnumType e) { return e.getBitWidth(); })
135  .Case<TypeAliasType>(
136  [](TypeAliasType t) { return getBitWidth(t.getCanonicalType()); })
137  .Default([](Type) { return -1; });
138 }
139 
140 /// Return true if the specified type contains known marker types like
141 /// InOutType. Unlike isHWValueType, this is not conservative, it only returns
142 /// false on known InOut types, rather than any unknown types.
143 bool circt::hw::hasHWInOutType(Type type) {
144  if (auto array = dyn_cast<ArrayType>(type))
145  return hasHWInOutType(array.getElementType());
146 
147  if (auto array = dyn_cast<UnpackedArrayType>(type))
148  return hasHWInOutType(array.getElementType());
149 
150  if (auto t = dyn_cast<StructType>(type)) {
151  return std::any_of(t.getElements().begin(), t.getElements().end(),
152  [](const auto &f) { return hasHWInOutType(f.type); });
153  }
154 
155  if (auto t = dyn_cast<TypeAliasType>(type))
156  return hasHWInOutType(t.getCanonicalType());
157 
158  return isa<InOutType>(type);
159 }
160 
161 /// Parse and print nested HW types nicely. These helper methods allow eliding
162 /// the "hw." prefix on array, inout, and other types when in a context that
163 /// expects HW subelement types.
164 static ParseResult parseHWElementType(Type &result, AsmParser &p) {
165  // If this is an HW dialect type, then we don't need/want the !hw. prefix
166  // redundantly specified.
167  auto fullString = static_cast<DialectAsmParser &>(p).getFullSymbolSpec();
168  auto *curPtr = p.getCurrentLocation().getPointer();
169  auto typeString =
170  StringRef(curPtr, fullString.size() - (curPtr - fullString.data()));
171 
172  if (typeString.starts_with("array<") || typeString.starts_with("inout<") ||
173  typeString.starts_with("uarray<") || typeString.starts_with("struct<") ||
174  typeString.starts_with("typealias<") || typeString.starts_with("int<") ||
175  typeString.starts_with("enum<")) {
176  llvm::StringRef mnemonic;
177  auto parseResult = generatedTypeParser(p, &mnemonic, result);
178  return parseResult.has_value() ? success() : failure();
179  }
180 
181  return p.parseType(result);
182 }
183 
184 static void printHWElementType(Type element, AsmPrinter &p) {
185  if (succeeded(generatedTypePrinter(element, p)))
186  return;
187  p.printType(element);
188 }
189 
190 //===----------------------------------------------------------------------===//
191 // Int Type
192 //===----------------------------------------------------------------------===//
193 
194 Type IntType::get(mlir::TypedAttr width) {
195  // The width expression must always be a 32-bit wide integer type itself.
196  auto widthWidth = llvm::dyn_cast<IntegerType>(width.getType());
197  assert(widthWidth && widthWidth.getWidth() == 32 &&
198  "!hw.int width must be 32-bits");
199  (void)widthWidth;
200 
201  if (auto cstWidth = llvm::dyn_cast<IntegerAttr>(width))
202  return IntegerType::get(width.getContext(),
203  cstWidth.getValue().getZExtValue());
204 
205  return Base::get(width.getContext(), width);
206 }
207 
208 Type IntType::parse(AsmParser &p) {
209  // The bitwidth of the parameter size is always 32 bits.
210  auto int32Type = p.getBuilder().getIntegerType(32);
211 
212  mlir::TypedAttr width;
213  if (p.parseLess() || p.parseAttribute(width, int32Type) || p.parseGreater())
214  return Type();
215  return get(width);
216 }
217 
218 void IntType::print(AsmPrinter &p) const {
219  p << "<";
220  p.printAttributeWithoutType(getWidth());
221  p << '>';
222 }
223 
224 //===----------------------------------------------------------------------===//
225 // Struct Type
226 //===----------------------------------------------------------------------===//
227 
228 namespace circt {
229 namespace hw {
230 namespace detail {
231 bool operator==(const FieldInfo &a, const FieldInfo &b) {
232  return a.name == b.name && a.type == b.type;
233 }
234 llvm::hash_code hash_value(const FieldInfo &fi) {
235  return llvm::hash_combine(fi.name, fi.type);
236 }
237 } // namespace detail
238 } // namespace hw
239 } // namespace circt
240 
241 /// Parse a list of unique field names and types within <>. E.g.:
242 /// <foo: i7, bar: i8>
243 static ParseResult parseFields(AsmParser &p,
244  SmallVectorImpl<FieldInfo> &parameters) {
245  llvm::StringSet<> nameSet;
246  bool hasDuplicateName = false;
247  auto parseResult = p.parseCommaSeparatedList(
248  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
249  std::string name;
250  Type type;
251 
252  auto fieldLoc = p.getCurrentLocation();
253  if (p.parseKeywordOrString(&name) || p.parseColon() ||
254  p.parseType(type))
255  return failure();
256 
257  if (!nameSet.insert(name).second) {
258  p.emitError(fieldLoc, "duplicate field name \'" + name + "\'");
259  // Continue parsing to print all duplicates, but make sure to error
260  // eventually
261  hasDuplicateName = true;
262  }
263 
264  parameters.push_back(
265  FieldInfo{StringAttr::get(p.getContext(), name), type});
266  return success();
267  });
268 
269  if (hasDuplicateName)
270  return failure();
271  return parseResult;
272 }
273 
274 /// Print out a list of named fields surrounded by <>.
275 static void printFields(AsmPrinter &p, ArrayRef<FieldInfo> fields) {
276  p << '<';
277  llvm::interleaveComma(fields, p, [&](const FieldInfo &field) {
278  p.printKeywordOrString(field.name.getValue());
279  p << ": " << field.type;
280  });
281  p << ">";
282 }
283 
284 Type StructType::parse(AsmParser &p) {
285  llvm::SmallVector<FieldInfo, 4> parameters;
286  if (parseFields(p, parameters))
287  return Type();
288  return get(p.getContext(), parameters);
289 }
290 
291 LogicalResult StructType::verify(function_ref<InFlightDiagnostic()> emitError,
292  ArrayRef<StructType::FieldInfo> elements) {
293  llvm::SmallDenseSet<StringAttr> fieldNameSet;
294  LogicalResult result = success();
295  fieldNameSet.reserve(elements.size());
296  for (const auto &elt : elements)
297  if (!fieldNameSet.insert(elt.name).second) {
298  result = failure();
299  emitError() << "duplicate field name '" << elt.name.getValue()
300  << "' in hw.struct type";
301  }
302  return result;
303 }
304 
305 void StructType::print(AsmPrinter &p) const { printFields(p, getElements()); }
306 
307 Type StructType::getFieldType(mlir::StringRef fieldName) {
308  for (const auto &field : getElements())
309  if (field.name == fieldName)
310  return field.type;
311  return Type();
312 }
313 
314 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringRef fieldName) {
315  ArrayRef<hw::StructType::FieldInfo> elems = getElements();
316  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
317  if (elems[idx].name == fieldName)
318  return idx;
319  return {};
320 }
321 
322 std::optional<uint32_t> StructType::getFieldIndex(mlir::StringAttr fieldName) {
323  ArrayRef<hw::StructType::FieldInfo> elems = getElements();
324  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
325  if (elems[idx].name == fieldName)
326  return idx;
327  return {};
328 }
329 
330 static std::pair<uint64_t, SmallVector<uint64_t>>
331 getFieldIDsStruct(const StructType &st) {
332  uint64_t fieldID = 0;
333  auto elements = st.getElements();
334  SmallVector<uint64_t> fieldIDs;
335  fieldIDs.reserve(elements.size());
336  for (auto &element : elements) {
337  auto type = element.type;
338  fieldID += 1;
339  fieldIDs.push_back(fieldID);
340  // Increment the field ID for the next field by the number of subfields.
341  fieldID += hw::FieldIdImpl::getMaxFieldID(type);
342  }
343  return {fieldID, fieldIDs};
344 }
345 
346 void StructType::getInnerTypes(SmallVectorImpl<Type> &types) {
347  for (const auto &field : getElements())
348  types.push_back(field.type);
349 }
350 
351 uint64_t StructType::getMaxFieldID() const {
352  uint64_t fieldID = 0;
353  for (const auto &field : getElements())
354  fieldID += 1 + hw::FieldIdImpl::getMaxFieldID(field.type);
355  return fieldID;
356 }
357 
358 std::pair<Type, uint64_t>
359 StructType::getSubTypeByFieldID(uint64_t fieldID) const {
360  if (fieldID == 0)
361  return {*this, 0};
362  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
363  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
364  auto subfieldIndex = std::distance(fieldIDs.begin(), it);
365  auto subfieldType = getElements()[subfieldIndex].type;
366  auto subfieldID = fieldID - fieldIDs[subfieldIndex];
367  return {subfieldType, subfieldID};
368 }
369 
370 std::pair<uint64_t, bool>
371 StructType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
372  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
373  auto childRoot = fieldIDs[index];
374  auto rangeEnd =
375  index + 1 >= getElements().size() ? maxId : (fieldIDs[index + 1] - 1);
376  return std::make_pair(fieldID - childRoot,
377  fieldID >= childRoot && fieldID <= rangeEnd);
378 }
379 
380 uint64_t StructType::getFieldID(uint64_t index) const {
381  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
382  return fieldIDs[index];
383 }
384 
385 uint64_t StructType::getIndexForFieldID(uint64_t fieldID) const {
386  assert(!getElements().empty() && "Bundle must have >0 fields");
387  auto [maxId, fieldIDs] = getFieldIDsStruct(*this);
388  auto *it = std::prev(llvm::upper_bound(fieldIDs, fieldID));
389  return std::distance(fieldIDs.begin(), it);
390 }
391 
392 std::pair<uint64_t, uint64_t>
393 StructType::getIndexAndSubfieldID(uint64_t fieldID) const {
394  auto index = getIndexForFieldID(fieldID);
395  auto elementFieldID = getFieldID(index);
396  return {index, fieldID - elementFieldID};
397 }
398 
399 //===----------------------------------------------------------------------===//
400 // Union Type
401 //===----------------------------------------------------------------------===//
402 
403 namespace circt {
404 namespace hw {
405 namespace detail {
406 bool operator==(const OffsetFieldInfo &a, const OffsetFieldInfo &b) {
407  return a.name == b.name && a.type == b.type && a.offset == b.offset;
408 }
409 // NOLINTNEXTLINE
410 llvm::hash_code hash_value(const OffsetFieldInfo &fi) {
411  return llvm::hash_combine(fi.name, fi.type, fi.offset);
412 }
413 } // namespace detail
414 } // namespace hw
415 } // namespace circt
416 
417 Type UnionType::parse(AsmParser &p) {
418  llvm::SmallVector<FieldInfo, 4> parameters;
419  llvm::StringSet<> nameSet;
420  bool hasDuplicateName = false;
421  if (p.parseCommaSeparatedList(
422  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
423  StringRef name;
424  Type type;
425 
426  auto fieldLoc = p.getCurrentLocation();
427  if (p.parseKeyword(&name) || p.parseColon() || p.parseType(type))
428  return failure();
429 
430  if (!nameSet.insert(name).second) {
431  p.emitError(fieldLoc, "duplicate field name \'" + name +
432  "\' in hw.union type");
433  // Continue parsing to print all duplicates, but make sure to
434  // error eventually
435  hasDuplicateName = true;
436  }
437 
438  size_t offset = 0;
439  if (succeeded(p.parseOptionalKeyword("offset")))
440  if (p.parseInteger(offset))
441  return failure();
442  parameters.push_back(UnionType::FieldInfo{
443  StringAttr::get(p.getContext(), name), type, offset});
444  return success();
445  }))
446  return Type();
447 
448  if (hasDuplicateName)
449  return Type();
450 
451  return get(p.getContext(), parameters);
452 }
453 
454 void UnionType::print(AsmPrinter &odsPrinter) const {
455  odsPrinter << '<';
456  llvm::interleaveComma(
457  getElements(), odsPrinter, [&](const UnionType::FieldInfo &field) {
458  odsPrinter << field.name.getValue() << ": " << field.type;
459  if (field.offset)
460  odsPrinter << " offset " << field.offset;
461  });
462  odsPrinter << ">";
463 }
464 
465 LogicalResult UnionType::verify(function_ref<InFlightDiagnostic()> emitError,
466  ArrayRef<UnionType::FieldInfo> elements) {
467  llvm::SmallDenseSet<StringAttr> fieldNameSet;
468  LogicalResult result = success();
469  fieldNameSet.reserve(elements.size());
470  for (const auto &elt : elements)
471  if (!fieldNameSet.insert(elt.name).second) {
472  result = failure();
473  emitError() << "duplicate field name '" << elt.name.getValue()
474  << "' in hw.union type";
475  }
476  return result;
477 }
478 
479 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringAttr fieldName) {
480  ArrayRef<hw::UnionType::FieldInfo> elems = getElements();
481  for (size_t idx = 0, numElems = elems.size(); idx < numElems; ++idx)
482  if (elems[idx].name == fieldName)
483  return idx;
484  return {};
485 }
486 
487 std::optional<uint32_t> UnionType::getFieldIndex(mlir::StringRef fieldName) {
488  return getFieldIndex(StringAttr::get(getContext(), fieldName));
489 }
490 
491 UnionType::FieldInfo UnionType::getFieldInfo(::mlir::StringRef fieldName) {
492  if (auto fieldIndex = getFieldIndex(fieldName))
493  return getElements()[*fieldIndex];
494  return FieldInfo();
495 }
496 
497 Type UnionType::getFieldType(mlir::StringRef fieldName) {
498  return getFieldInfo(fieldName).type;
499 }
500 
501 //===----------------------------------------------------------------------===//
502 // Enum Type
503 //===----------------------------------------------------------------------===//
504 
505 Type EnumType::parse(AsmParser &p) {
506  llvm::SmallVector<Attribute> fields;
507 
508  if (p.parseCommaSeparatedList(AsmParser::Delimiter::LessGreater, [&]() {
509  StringRef name;
510  if (p.parseKeyword(&name))
511  return failure();
512  fields.push_back(StringAttr::get(p.getContext(), name));
513  return success();
514  }))
515  return Type();
516 
517  return get(p.getContext(), ArrayAttr::get(p.getContext(), fields));
518 }
519 
520 void EnumType::print(AsmPrinter &p) const {
521  p << '<';
522  llvm::interleaveComma(getFields(), p, [&](Attribute enumerator) {
523  p << llvm::cast<StringAttr>(enumerator).getValue();
524  });
525  p << ">";
526 }
527 
528 bool EnumType::contains(mlir::StringRef field) {
529  return indexOf(field).has_value();
530 }
531 
532 std::optional<size_t> EnumType::indexOf(mlir::StringRef field) {
533  for (auto it : llvm::enumerate(getFields()))
534  if (llvm::cast<StringAttr>(it.value()).getValue() == field)
535  return it.index();
536  return {};
537 }
538 
539 size_t EnumType::getBitWidth() {
540  auto w = getFields().size();
541  if (w > 1)
542  return llvm::Log2_64_Ceil(getFields().size());
543  return 1;
544 }
545 
546 //===----------------------------------------------------------------------===//
547 // ArrayType
548 //===----------------------------------------------------------------------===//
549 
550 static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner) {
551  if (p.parseLess())
552  return failure();
553 
554  uint64_t dimLiteral;
555  auto int64Type = p.getBuilder().getIntegerType(64);
556 
557  if (auto res = p.parseOptionalInteger(dimLiteral); res.has_value())
558  dim = p.getBuilder().getI64IntegerAttr(dimLiteral);
559  else if (!p.parseOptionalAttribute(dim, int64Type).has_value())
560  return failure();
561 
562  if (!isa<IntegerAttr, ParamExprAttr, ParamDeclRefAttr>(dim)) {
563  p.emitError(p.getNameLoc(), "unsupported dimension kind in hw.array");
564  return failure();
565  }
566 
567  if (p.parseXInDimensionList() || parseHWElementType(inner, p) ||
568  p.parseGreater())
569  return failure();
570 
571  return success();
572 }
573 
574 Type ArrayType::parse(AsmParser &p) {
575  Attribute dim;
576  Type inner;
577 
578  if (failed(parseArray(p, dim, inner)))
579  return Type();
580 
581  auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
582  if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim)))
583  return Type();
584 
585  return get(inner.getContext(), inner, dim);
586 }
587 
588 void ArrayType::print(AsmPrinter &p) const {
589  p << "<";
590  p.printAttributeWithoutType(getSizeAttr());
591  p << "x";
592  printHWElementType(getElementType(), p);
593  p << '>';
594 }
595 
596 size_t ArrayType::getNumElements() const {
597  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
598  return intAttr.getInt();
599  return -1;
600 }
601 
602 LogicalResult ArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
603  Type innerType, Attribute size) {
604  if (hasHWInOutType(innerType))
605  return emitError() << "hw.array cannot contain InOut types";
606  return success();
607 }
608 
609 uint64_t ArrayType::getMaxFieldID() const {
610  return getNumElements() *
611  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
612 }
613 
614 std::pair<Type, uint64_t>
615 ArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
616  if (fieldID == 0)
617  return {*this, 0};
618  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
619 }
620 
621 std::pair<uint64_t, bool>
622 ArrayType::projectToChildFieldID(uint64_t fieldID, uint64_t index) const {
623  auto childRoot = getFieldID(index);
624  auto rangeEnd =
625  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
626  return std::make_pair(fieldID - childRoot,
627  fieldID >= childRoot && fieldID <= rangeEnd);
628 }
629 
630 uint64_t ArrayType::getIndexForFieldID(uint64_t fieldID) const {
631  assert(fieldID && "fieldID must be at least 1");
632  // Divide the field ID by the number of fieldID's per element.
633  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
634 }
635 
636 std::pair<uint64_t, uint64_t>
637 ArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
638  auto index = getIndexForFieldID(fieldID);
639  auto elementFieldID = getFieldID(index);
640  return {index, fieldID - elementFieldID};
641 }
642 
643 uint64_t ArrayType::getFieldID(uint64_t index) const {
644  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // UnpackedArrayType
649 //===----------------------------------------------------------------------===//
650 
651 Type UnpackedArrayType::parse(AsmParser &p) {
652  Attribute dim;
653  Type inner;
654 
655  if (failed(parseArray(p, dim, inner)))
656  return Type();
657 
658  auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
659  if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner, dim)))
660  return Type();
661 
662  return get(inner.getContext(), inner, dim);
663 }
664 
665 void UnpackedArrayType::print(AsmPrinter &p) const {
666  p << "<";
667  p.printAttributeWithoutType(getSizeAttr());
668  p << "x";
669  printHWElementType(getElementType(), p);
670  p << '>';
671 }
672 
673 LogicalResult
674 UnpackedArrayType::verify(function_ref<InFlightDiagnostic()> emitError,
675  Type innerType, Attribute size) {
676  if (!isHWValueType(innerType))
677  return emitError() << "invalid element for uarray type";
678  return success();
679 }
680 
681 size_t UnpackedArrayType::getNumElements() const {
682  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(getSizeAttr()))
683  return intAttr.getInt();
684  return -1;
685 }
686 
687 uint64_t UnpackedArrayType::getMaxFieldID() const {
688  return getNumElements() *
689  (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
690 }
691 
692 std::pair<Type, uint64_t>
693 UnpackedArrayType::getSubTypeByFieldID(uint64_t fieldID) const {
694  if (fieldID == 0)
695  return {*this, 0};
696  return {getElementType(), getIndexAndSubfieldID(fieldID).second};
697 }
698 
699 std::pair<uint64_t, bool>
700 UnpackedArrayType::projectToChildFieldID(uint64_t fieldID,
701  uint64_t index) const {
702  auto childRoot = getFieldID(index);
703  auto rangeEnd =
704  index >= getNumElements() ? getMaxFieldID() : (getFieldID(index + 1) - 1);
705  return std::make_pair(fieldID - childRoot,
706  fieldID >= childRoot && fieldID <= rangeEnd);
707 }
708 
709 uint64_t UnpackedArrayType::getIndexForFieldID(uint64_t fieldID) const {
710  assert(fieldID && "fieldID must be at least 1");
711  // Divide the field ID by the number of fieldID's per element.
712  return (fieldID - 1) / (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
713 }
714 
715 std::pair<uint64_t, uint64_t>
716 UnpackedArrayType::getIndexAndSubfieldID(uint64_t fieldID) const {
717  auto index = getIndexForFieldID(fieldID);
718  auto elementFieldID = getFieldID(index);
719  return {index, fieldID - elementFieldID};
720 }
721 
722 uint64_t UnpackedArrayType::getFieldID(uint64_t index) const {
723  return 1 + index * (hw::FieldIdImpl::getMaxFieldID(getElementType()) + 1);
724 }
725 
726 //===----------------------------------------------------------------------===//
727 // InOutType
728 //===----------------------------------------------------------------------===//
729 
730 Type InOutType::parse(AsmParser &p) {
731  Type inner;
732  if (p.parseLess() || parseHWElementType(inner, p) || p.parseGreater())
733  return Type();
734 
735  auto loc = p.getEncodedSourceLoc(p.getCurrentLocation());
736  if (failed(verify(mlir::detail::getDefaultDiagnosticEmitFn(loc), inner)))
737  return Type();
738 
739  return get(p.getContext(), inner);
740 }
741 
742 void InOutType::print(AsmPrinter &p) const {
743  p << "<";
744  printHWElementType(getElementType(), p);
745  p << '>';
746 }
747 
748 LogicalResult InOutType::verify(function_ref<InFlightDiagnostic()> emitError,
749  Type innerType) {
750  if (!isHWValueType(innerType))
751  return emitError() << "invalid element for hw.inout type " << innerType;
752  return success();
753 }
754 
755 //===----------------------------------------------------------------------===//
756 // TypeAliasType
757 //===----------------------------------------------------------------------===//
758 
759 static Type computeCanonicalType(Type type) {
760  return llvm::TypeSwitch<Type, Type>(type)
761  .Case([](TypeAliasType t) {
762  return computeCanonicalType(t.getCanonicalType());
763  })
764  .Case([](ArrayType t) {
765  return ArrayType::get(computeCanonicalType(t.getElementType()),
766  t.getNumElements());
767  })
768  .Case([](UnpackedArrayType t) {
769  return UnpackedArrayType::get(computeCanonicalType(t.getElementType()),
770  t.getNumElements());
771  })
772  .Case([](StructType t) {
773  SmallVector<StructType::FieldInfo> fieldInfo;
774  for (auto field : t.getElements())
775  fieldInfo.push_back(StructType::FieldInfo{
776  field.name, computeCanonicalType(field.type)});
777  return StructType::get(t.getContext(), fieldInfo);
778  })
779  .Default([](Type t) { return t; });
780 }
781 
782 TypeAliasType TypeAliasType::get(SymbolRefAttr ref, Type innerType) {
783  return get(ref.getContext(), ref, innerType, computeCanonicalType(innerType));
784 }
785 
786 Type TypeAliasType::parse(AsmParser &p) {
787  SymbolRefAttr ref;
788  Type type;
789  if (p.parseLess() || p.parseAttribute(ref) || p.parseComma() ||
790  p.parseType(type) || p.parseGreater())
791  return Type();
792 
793  return get(ref, type);
794 }
795 
796 void TypeAliasType::print(AsmPrinter &p) const {
797  p << "<" << getRef() << ", " << getInnerType() << ">";
798 }
799 
800 /// Return the Typedecl referenced by this TypeAlias, given the module to look
801 /// in. This returns null when the IR is malformed.
802 TypedeclOp TypeAliasType::getTypeDecl(const HWSymbolCache &cache) {
803  SymbolRefAttr ref = getRef();
804  auto typeScope = ::dyn_cast_or_null<TypeScopeOp>(
805  cache.getDefinition(ref.getRootReference()));
806  if (!typeScope)
807  return {};
808 
809  return typeScope.lookupSymbol<TypedeclOp>(ref.getLeafReference());
810 }
811 
812 ////////////////////////////////////////////////////////////////////////////////
813 // ModuleType
814 ////////////////////////////////////////////////////////////////////////////////
815 
816 LogicalResult ModuleType::verify(function_ref<InFlightDiagnostic()> emitError,
817  ArrayRef<ModulePort> ports) {
818  if (llvm::any_of(ports, [](const ModulePort &port) {
819  return hasHWInOutType(port.type);
820  }))
821  return emitError() << "Ports cannot be inout types";
822  return success();
823 }
824 
825 size_t ModuleType::getPortIdForInputId(size_t idx) {
826  assert(idx < getImpl()->inputToAbs.size() && "input port out of range");
827  return getImpl()->inputToAbs[idx];
828 }
829 
830 size_t ModuleType::getPortIdForOutputId(size_t idx) {
831  assert(idx < getImpl()->outputToAbs.size() && " output port out of range");
832  return getImpl()->outputToAbs[idx];
833 }
834 
835 size_t ModuleType::getInputIdForPortId(size_t idx) {
836  auto nIdx = getImpl()->absToInput[idx];
837  assert(nIdx != ~0ULL);
838  return nIdx;
839 }
840 
841 size_t ModuleType::getOutputIdForPortId(size_t idx) {
842  auto nIdx = getImpl()->absToOutput[idx];
843  assert(nIdx != ~0ULL);
844  return nIdx;
845 }
846 
847 size_t ModuleType::getNumInputs() { return getImpl()->inputToAbs.size(); }
848 
849 size_t ModuleType::getNumOutputs() { return getImpl()->outputToAbs.size(); }
850 
851 size_t ModuleType::getNumPorts() { return getPorts().size(); }
852 
853 SmallVector<Type> ModuleType::getInputTypes() {
854  SmallVector<Type> retval;
855  for (auto &p : getPorts()) {
856  if (p.dir == ModulePort::Direction::Input)
857  retval.push_back(p.type);
858  else if (p.dir == ModulePort::Direction::InOut) {
859  retval.push_back(hw::InOutType::get(p.type));
860  }
861  }
862  return retval;
863 }
864 
865 SmallVector<Type> ModuleType::getOutputTypes() {
866  SmallVector<Type> retval;
867  for (auto &p : getPorts())
868  if (p.dir == ModulePort::Direction::Output)
869  retval.push_back(p.type);
870  return retval;
871 }
872 
873 SmallVector<Type> ModuleType::getPortTypes() {
874  SmallVector<Type> retval;
875  for (auto &p : getPorts())
876  retval.push_back(p.type);
877  return retval;
878 }
879 
880 Type ModuleType::getInputType(size_t idx) {
881  const auto &portInfo = getPorts()[getPortIdForInputId(idx)];
882  if (portInfo.dir != ModulePort::InOut)
883  return portInfo.type;
884  return InOutType::get(portInfo.type);
885 }
886 
887 Type ModuleType::getOutputType(size_t idx) {
888  return getPorts()[getPortIdForOutputId(idx)].type;
889 }
890 
891 SmallVector<Attribute> ModuleType::getInputNames() {
892  SmallVector<Attribute> retval;
893  for (auto &p : getPorts())
894  if (p.dir != ModulePort::Direction::Output)
895  retval.push_back(p.name);
896  return retval;
897 }
898 
899 SmallVector<Attribute> ModuleType::getOutputNames() {
900  SmallVector<Attribute> retval;
901  for (auto &p : getPorts())
902  if (p.dir == ModulePort::Direction::Output)
903  retval.push_back(p.name);
904  return retval;
905 }
906 
907 StringAttr ModuleType::getPortNameAttr(size_t idx) {
908  return getPorts()[idx].name;
909 }
910 
911 StringRef ModuleType::getPortName(size_t idx) {
912  auto sa = getPortNameAttr(idx);
913  if (sa)
914  return sa.getValue();
915  return {};
916 }
917 
918 StringAttr ModuleType::getInputNameAttr(size_t idx) {
919  return getPorts()[getPortIdForInputId(idx)].name;
920 }
921 
922 StringRef ModuleType::getInputName(size_t idx) {
923  auto sa = getInputNameAttr(idx);
924  if (sa)
925  return sa.getValue();
926  return {};
927 }
928 
929 StringAttr ModuleType::getOutputNameAttr(size_t idx) {
930  return getPorts()[getPortIdForOutputId(idx)].name;
931 }
932 
933 StringRef ModuleType::getOutputName(size_t idx) {
934  auto sa = getOutputNameAttr(idx);
935  if (sa)
936  return sa.getValue();
937  return {};
938 }
939 
940 bool ModuleType::isOutput(size_t idx) {
941  auto &p = getPorts()[idx];
942  return p.dir == ModulePort::Direction::Output;
943 }
944 
945 FunctionType ModuleType::getFuncType() {
946  SmallVector<Type> inputs, outputs;
947  for (auto p : getPorts())
948  if (p.dir == ModulePort::Input)
949  inputs.push_back(p.type);
950  else if (p.dir == ModulePort::InOut)
951  inputs.push_back(InOutType::get(p.type));
952  else
953  outputs.push_back(p.type);
954  return FunctionType::get(getContext(), inputs, outputs);
955 }
956 
957 ArrayRef<ModulePort> ModuleType::getPorts() const {
958  return getImpl()->getPorts();
959 }
960 
961 FailureOr<ModuleType> ModuleType::resolveParametricTypes(ArrayAttr parameters,
962  LocationAttr loc,
963  bool emitErrors) {
964  SmallVector<ModulePort, 8> resolvedPorts;
965  for (ModulePort port : getPorts()) {
966  FailureOr<Type> resolvedType =
967  evaluateParametricType(loc, parameters, port.type, emitErrors);
968  if (failed(resolvedType))
969  return failure();
970  port.type = *resolvedType;
971  resolvedPorts.push_back(port);
972  }
973  return ModuleType::get(getContext(), resolvedPorts);
974 }
975 
976 static StringRef dirToStr(ModulePort::Direction dir) {
977  switch (dir) {
978  case ModulePort::Direction::Input:
979  return "input";
980  case ModulePort::Direction::Output:
981  return "output";
982  case ModulePort::Direction::InOut:
983  return "inout";
984  }
985 }
986 
987 static ModulePort::Direction strToDir(StringRef str) {
988  if (str == "input")
989  return ModulePort::Direction::Input;
990  if (str == "output")
991  return ModulePort::Direction::Output;
992  if (str == "inout")
993  return ModulePort::Direction::InOut;
994  llvm::report_fatal_error("invalid direction");
995 }
996 
997 /// Parse a list of field names and types within <>. E.g.:
998 /// <input foo: i7, output bar: i8>
999 static ParseResult parsePorts(AsmParser &p,
1000  SmallVectorImpl<ModulePort> &ports) {
1001  return p.parseCommaSeparatedList(
1002  mlir::AsmParser::Delimiter::LessGreater, [&]() -> ParseResult {
1003  StringRef dir;
1004  std::string name;
1005  Type type;
1006  if (p.parseKeyword(&dir) || p.parseKeywordOrString(&name) ||
1007  p.parseColon() || p.parseType(type))
1008  return failure();
1009  ports.push_back(
1010  {StringAttr::get(p.getContext(), name), type, strToDir(dir)});
1011  return success();
1012  });
1013 }
1014 
1015 /// Print out a list of named fields surrounded by <>.
1016 static void printPorts(AsmPrinter &p, ArrayRef<ModulePort> ports) {
1017  p << '<';
1018  llvm::interleaveComma(ports, p, [&](const ModulePort &port) {
1019  p << dirToStr(port.dir) << " ";
1020  p.printKeywordOrString(port.name.getValue());
1021  p << " : " << port.type;
1022  });
1023  p << ">";
1024 }
1025 
1026 Type ModuleType::parse(AsmParser &odsParser) {
1027  llvm::SmallVector<ModulePort, 4> ports;
1028  if (parsePorts(odsParser, ports))
1029  return Type();
1030  return get(odsParser.getContext(), ports);
1031 }
1032 
1033 void ModuleType::print(AsmPrinter &odsPrinter) const {
1034  printPorts(odsPrinter, getPorts());
1035 }
1036 
1037 ModuleType circt::hw::detail::fnToMod(Operation *op,
1038  ArrayRef<Attribute> inputNames,
1039  ArrayRef<Attribute> outputNames) {
1040  return fnToMod(
1041  cast<FunctionType>(cast<mlir::FunctionOpInterface>(op).getFunctionType()),
1042  inputNames, outputNames);
1043 }
1044 
1045 ModuleType circt::hw::detail::fnToMod(FunctionType fnty,
1046  ArrayRef<Attribute> inputNames,
1047  ArrayRef<Attribute> outputNames) {
1048  SmallVector<ModulePort> ports;
1049  if (!inputNames.empty()) {
1050  for (auto [t, n] : llvm::zip_equal(fnty.getInputs(), inputNames))
1051  if (auto iot = dyn_cast<hw::InOutType>(t))
1052  ports.push_back({cast<StringAttr>(n), iot.getElementType(),
1053  ModulePort::Direction::InOut});
1054  else
1055  ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Input});
1056  } else {
1057  for (auto t : fnty.getInputs())
1058  if (auto iot = dyn_cast<hw::InOutType>(t))
1059  ports.push_back(
1060  {{}, iot.getElementType(), ModulePort::Direction::InOut});
1061  else
1062  ports.push_back({{}, t, ModulePort::Direction::Input});
1063  }
1064  if (!outputNames.empty()) {
1065  for (auto [t, n] : llvm::zip_equal(fnty.getResults(), outputNames))
1066  ports.push_back({cast<StringAttr>(n), t, ModulePort::Direction::Output});
1067  } else {
1068  for (auto t : fnty.getResults())
1069  ports.push_back({{}, t, ModulePort::Direction::Output});
1070  }
1071  return ModuleType::get(fnty.getContext(), ports);
1072 }
1073 
1074 detail::ModuleTypeStorage::ModuleTypeStorage(ArrayRef<ModulePort> inPorts)
1075  : ports(inPorts) {
1076  size_t nextInput = 0;
1077  size_t nextOutput = 0;
1078  for (auto [idx, p] : llvm::enumerate(ports)) {
1079  if (p.dir == ModulePort::Direction::Output) {
1080  outputToAbs.push_back(idx);
1081  absToOutput.push_back(nextOutput);
1082  absToInput.push_back(~0ULL);
1083  ++nextOutput;
1084  } else {
1085  inputToAbs.push_back(idx);
1086  absToInput.push_back(nextInput);
1087  absToOutput.push_back(~0ULL);
1088  ++nextInput;
1089  }
1090  }
1091 }
1092 
1093 ////////////////////////////////////////////////////////////////////////////////
1094 // BoilerPlate
1095 ////////////////////////////////////////////////////////////////////////////////
1096 
/// Register the HW dialect's type classes. The list of classes is not written
/// out here; it is expanded from the generated HWTypes.cpp.inc by defining
/// GET_TYPEDEF_LIST before the include.
void HWDialect::registerTypes() {
  addTypes<
#define GET_TYPEDEF_LIST
#include "circt/Dialect/HW/HWTypes.cpp.inc"
      >();
}
assert(baseType &&"element must be base type")
int32_t width
Definition: FIRRTL.cpp:36
static void printHWElementType(Type element, AsmPrinter &p)
Definition: HWTypes.cpp:184
static LogicalResult parseArray(AsmParser &p, Attribute &dim, Type &inner)
Definition: HWTypes.cpp:550
static ParseResult parseHWElementType(Type &result, AsmParser &p)
Parse and print nested HW types nicely.
Definition: HWTypes.cpp:164
static unsigned getIndexForFieldID(BundleType type, unsigned fieldID)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
uint64_t getWidth(Type t)
Definition: ESIPasses.cpp:34
std::pair< uint64_t, uint64_t > getIndexAndSubfieldID(Type type, uint64_t fieldID)
uint64_t getFieldID(Type type, uint64_t index)
uint64_t getMaxFieldID(Type)
bool isHWIntegerType(mlir::Type type)
Return true if the specified type is a value HW Integer type.
Definition: HWTypes.cpp:52
bool isHWValueType(mlir::Type type)
Return true if the specified type can be used as an HW value type, that is the set of types that can ...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition: HWTypes.cpp:102
bool isHWEnumType(mlir::Type type)
Return true if the specified type is a HW Enum type.
Definition: HWTypes.cpp:65
mlir::Type getCanonicalType(mlir::Type type)
Definition: HWTypes.cpp:41
bool hasHWInOutType(mlir::Type type)
Return true if the specified type contains known marker types like InOutType.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21