17#include "mlir/IR/Attributes.h"
18#include "mlir/IR/DialectImplementation.h"
19#include "llvm/ADT/ArrayRef.h"
20#include "llvm/ADT/DenseMap.h"
21#include "llvm/ADT/TypeSwitch.h"
26AnyType AnyType::get(MLIRContext *
context) {
return Base::get(
context); }
29WindowFieldType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
30 StringAttr fieldName, uint64_t numItems,
31 uint64_t bulkCountWidth) {
32 if (numItems > 0 && bulkCountWidth > 0)
33 return emitError() <<
"cannot specify both numItems and countWidth for "
35 << fieldName.getValue() <<
"'";
41 case ChannelSignaling::ValidReady:
43 case ChannelSignaling::FIFO:
46 llvm_unreachable(
"Unhandled ChannelSignaling");
48std::optional<int64_t> ChannelType::getBitWidth()
const {
58 return llvm::make_filter_range(chan.getUses(), [](
auto &use) {
59 return !isa<SnoopValidReadyOp, SnoopTransactionOp>(use.getOwner());
62SmallVector<std::reference_wrapper<OpOperand>, 4>
63ChannelType::getConsumers(mlir::TypedValue<ChannelType> chan) {
64 return SmallVector<std::reference_wrapper<OpOperand>, 4>(
67bool ChannelType::hasOneConsumer(mlir::TypedValue<ChannelType> chan) {
69 if (consumers.empty())
71 return ++consumers.begin() == consumers.end();
73bool ChannelType::hasNoConsumers(mlir::TypedValue<ChannelType> chan) {
76OpOperand *ChannelType::getSingleConsumer(mlir::TypedValue<ChannelType> chan) {
78 auto iter = consumers.begin();
79 if (iter == consumers.end())
81 OpOperand *result = &*iter;
82 if (++iter != consumers.end())
86LogicalResult ChannelType::verifyChannel(mlir::TypedValue<ChannelType> chan) {
88 if (consumers.empty() || ++consumers.begin() == consumers.end())
90 auto err = chan.getDefiningOp()->emitOpError(
91 "channels must have at most one consumer");
92 for (
auto &consumer : consumers)
93 err.attachNote(consumer.getOwner()->
getLoc()) <<
"channel used here";
97std::optional<int64_t> WindowType::getBitWidth()
const {
98 return hw::getBitWidth(getLoweredType());
102WindowType::verify(llvm::function_ref<InFlightDiagnostic()> emitError,
103 StringAttr name, Type into,
104 ArrayRef<WindowFrameType> frames) {
105 auto structInto = hw::type_dyn_cast<hw::StructType>(into);
107 return emitError() <<
"only windows into structs are currently supported";
111 for (hw::StructType::FieldInfo field : structInto.getElements())
112 fieldTypes[field.name] = field.type;
115 DenseSet<StringAttr> consumedFields;
118 DenseSet<StringAttr> bulkTransferFields;
120 for (
auto frame : frames) {
121 bool encounteredArrayOrListWithNumItems =
false;
123 for (WindowFieldType field : frame.getMembers()) {
124 auto fieldTypeIter = fieldTypes.find(field.getFieldName());
125 if (fieldTypeIter == fieldTypes.end())
126 return emitError() <<
"invalid field name: " << field.getFieldName();
128 Type fieldType = fieldTypeIter->getSecond();
134 bool isBulkTransferHeader = field.getBulkCountWidth() > 0;
135 bool isBulkTransferData =
136 bulkTransferFields.contains(field.getFieldName());
138 if (consumedFields.contains(field.getFieldName())) {
141 return emitError() <<
"field '" << field.getFieldName()
142 <<
"' already consumed by a previous frame";
147 bool isArrayOrListWithNumItems =
148 hw::type_isa<hw::ArrayType, esi::ListType>(fieldType) &&
149 field.getNumItems() > 0;
150 if (isArrayOrListWithNumItems) {
151 if (encounteredArrayOrListWithNumItems)
153 <<
"cannot have two array or list fields with num items (in "
154 << field.getFieldName() <<
")";
155 encounteredArrayOrListWithNumItems =
true;
159 uint64_t numItems = field.getNumItems();
161 if (
auto arrField = hw::type_dyn_cast<hw::ArrayType>(fieldType)) {
162 if (numItems > arrField.getNumElements())
164 <<
"num items is larger than array size in field "
165 << field.getFieldName();
166 }
else if (!hw::type_isa<esi::ListType>(fieldType)) {
167 return emitError() <<
"specification of num items only allowed on "
168 "array or list fields (in "
169 << field.getFieldName() <<
")";
175 uint64_t bulkCountWidth = field.getBulkCountWidth();
176 if (bulkCountWidth > 0) {
177 if (!hw::type_isa<esi::ListType>(fieldType))
178 return emitError() <<
"bulk transfer (countWidth) only allowed on "
180 << field.getFieldName() <<
")";
183 if (bulkTransferFields.contains(field.getFieldName()))
184 return emitError() <<
"field '" << field.getFieldName()
185 <<
"' already has countWidth specified";
187 bulkTransferFields.insert(field.getFieldName());
191 if (isBulkTransferData) {
194 consumedFields.insert(field.getFieldName());
195 }
else if (!isBulkTransferHeader) {
197 consumedFields.insert(field.getFieldName());
204Type WindowType::getLoweredType()
const {
206 auto into = hw::type_cast<hw::StructType>(getInto());
208 for (hw::StructType::FieldInfo field : into.getElements())
209 intoFields[field.name] = field.type;
211 auto getInnerTypeOrSelf = [&](Type t) {
212 return TypeSwitch<Type, Type>(t)
213 .Case<hw::ArrayType>(
214 [](hw::ArrayType arr) {
return arr.getElementType(); })
215 .Case<esi::ListType>(
217 .Default([&](Type t) {
return t; });
222 auto wrapInTypeAliasIfNeeded = [&](Type loweredType) -> Type {
223 if (
auto intoAlias = dyn_cast<hw::TypeAliasType>(getInto())) {
224 auto intoRef = intoAlias.getRef();
225 std::string aliasName = (Twine(intoRef.getLeafReference().getValue()) +
228 auto newRef = SymbolRefAttr::get(
229 intoRef.getRootReference(),
230 {FlatSymbolRefAttr::get(StringAttr::get(getContext(), aliasName))});
231 return hw::TypeAliasType::get(newRef, loweredType);
238 DenseSet<StringAttr> bulkTransferFields;
239 for (WindowFrameType frame : getFrames()) {
240 for (WindowFieldType field : frame.getMembers()) {
241 if (field.getBulkCountWidth() > 0)
242 bulkTransferFields.insert(field.getFieldName());
247 SmallVector<hw::UnionType::FieldInfo, 4> unionFields;
248 for (WindowFrameType frame : getFrames()) {
251 SmallVector<hw::StructType::FieldInfo, 4> fields;
252 SmallVector<hw::StructType::FieldInfo, 4> leftOverFields;
253 bool hasLeftOver =
false;
254 StringAttr leftOverName;
256 for (WindowFieldType field : frame.getMembers()) {
257 auto fieldTypeIter = intoFields.find(field.getFieldName());
258 assert(fieldTypeIter != intoFields.end());
259 auto fieldType = fieldTypeIter->getSecond();
262 uint64_t bulkCountWidth = field.getBulkCountWidth();
263 if (bulkCountWidth > 0) {
268 {StringAttr::get(getContext(),
269 Twine(field.getFieldName().getValue()) +
"_count"),
270 IntegerType::get(getContext(), bulkCountWidth)});
277 bool isBulkTransferData =
278 bulkTransferFields.contains(field.getFieldName());
281 if (field.getNumItems() == 0) {
284 auto type = getInnerTypeOrSelf(fieldType);
285 fields.push_back({field.getFieldName(), type});
286 leftOverFields.push_back({field.getFieldName(), type});
288 if (hw::type_isa<esi::ListType>(fieldType) && !isBulkTransferData) {
291 auto lastType = IntegerType::get(getContext(), 1);
292 auto lastField = StringAttr::get(getContext(),
"last");
293 fields.push_back({lastField, lastType});
294 leftOverFields.push_back({lastField, lastType});
298 hw::type_dyn_cast<hw::ArrayType>(fieldTypeIter->getSecond())) {
301 {field.getFieldName(), hw::ArrayType::get(array.getElementType(),
302 field.getNumItems())});
306 size_t leftOver = array.getNumElements() % field.getNumItems();
312 leftOverFields.push_back(
313 {field.getFieldName(),
314 hw::ArrayType::get(array.getElementType(), leftOver)});
316 leftOverName = StringAttr::get(
317 getContext(), Twine(frame.getName().getValue(),
"_leftOver"));
319 }
else if (
auto list = hw::type_cast<esi::ListType>(
320 fieldTypeIter->getSecond())) {
323 {field.getFieldName(),
326 if (!isBulkTransferData) {
332 Twine(field.getFieldName().getValue(),
"_size")),
333 IntegerType::get(getContext(),
334 llvm::Log2_64_Ceil(field.getNumItems()))});
336 fields.push_back({StringAttr::get(getContext(),
"last"),
337 IntegerType::get(getContext(), 1)});
340 llvm_unreachable(
"numItems specified on non-array/list field");
347 if (getFrames().size() == 1 && frame.getName().getValue().empty() &&
349 auto loweredStruct = hw::StructType::get(getContext(), fields);
350 return wrapInTypeAliasIfNeeded(loweredStruct);
354 unionFields.push_back(
355 {frame.getName(), hw::StructType::get(getContext(), fields), 0});
358 unionFields.push_back(
359 {leftOverName, hw::StructType::get(getContext(), leftOverFields), 0});
362 auto unionType = hw::UnionType::get(getContext(), unionFields);
363 return wrapInTypeAliasIfNeeded(unionType);
369 static FailureOr<::BundledChannel>
parse(AsmParser &p) {
372 if (p.parseType(type))
374 auto dir = FieldParser<::ChannelDirection>::parse(p);
377 if (p.parseKeywordOrString(&name))
379 return BundledChannel{StringAttr::get(p.getContext(), name), *dir, type};
385inline ::llvm::raw_ostream &
operator<<(::llvm::raw_ostream &p,
392ChannelBundleType ChannelBundleType::getReversed()
const {
393 SmallVector<BundledChannel, 4> reversed;
394 for (
auto channel : getChannels())
395 reversed.push_back({channel.name,
flip(channel.direction), channel.type});
396 return ChannelBundleType::get(getContext(), reversed, getResettable());
399std::optional<int64_t> ChannelBundleType::getBitWidth()
const {
400 int64_t totalWidth = 0;
401 for (
auto channel : getChannels()) {
402 std::optional<int64_t> channelWidth = channel.type.getBitWidth();
405 totalWidth += *channelWidth;
410#define GET_TYPEDEF_CLASSES
411#include "circt/Dialect/ESI/ESITypes.cpp.inc"
413void ESIDialect::registerTypes() {
415#define GET_TYPEDEF_LIST
416#include "circt/Dialect/ESI/ESITypes.cpp.inc"
421 circt::esi::ChannelType chan =
422 dyn_cast_or_null<circt::esi::ChannelType>(type);
424 type = chan.getInner();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static unsigned getSignalingBitWidth(ChannelSignaling signaling)
static auto getChannelConsumers(mlir::TypedValue< ChannelType > chan)
Get the list of users with snoops filtered out.
static Location getLoc(DefSlot slot)
Lists represent variable-length sequences of elements of a single type.
const Type * getElementType() const
mlir::Type innerType(mlir::Type type)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
inline ::llvm::raw_ostream & operator<<(::llvm::raw_ostream &p, ::BundledChannel channel)
ChannelDirection direction
static FailureOr<::BundledChannel > parse(AsmParser &p)