13from __future__
import annotations
15from .
import esiCppAccel
as cpp
17from typing
import TYPE_CHECKING
20 from .accelerator
import HWModule
22from concurrent.futures
import Future
23from typing
import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union
29 """Get the wrapper class for a C++ type."""
30 for cpp_type_cls, wrapper_cls
in __esi_mapping.items():
31 if isinstance(cpp_type, cpp_type_cls):
32 return wrapper_cls.wrap_cpp(cpp_type)
33 return ESIType.wrap_cpp(cpp_type)
37__esi_mapping: Dict[Type, Type] = {}
47 """Wrap a C++ ESI type with its corresponding Python ESI Type."""
48 instance = cls.__new__(cls)
49 instance._init_from_cpp(cpp_type)
53 """Initialize instance attributes from a C++ type object."""
58 """Get the stable id of this type."""
63 """Does this type support host communication via Python? Returns either
64 '(True, None)' if it is, or '(False, reason)' if it is not."""
67 return (
False,
"runtime only supports types with multiple of 8 bits")
70 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
71 """Is a Python object compatible with HW type? Returns either '(True,
72 None)' if it is, or '(False, reason)' if it is not."""
73 assert False,
"unimplemented"
77 """Size of this type, in bits. Negative for unbounded types."""
78 assert False,
"unimplemented"
82 """Maximum size of a value of this type, in bytes."""
89 """Convert a Python object to a bytearray."""
90 assert False,
"unimplemented"
92 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
93 """Convert a bytearray to a Python object. Return the object and the
95 assert False,
"unimplemented"
101 return isinstance(other, ESIType)
and self.
idid == other.id
128 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
134 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
138__esi_mapping[cpp.ChannelType] = ChannelType
145 direction: cpp.BundleType.Direction
148 def __init__(self, id: str, channels: List[Channel]):
149 cpp_channels = [(name, direction, channel_type.cpp_type)
150 for name, direction, channel_type
in channels]
157 for name, direction, channel_type
in cpp_type.channels
165__esi_mapping[cpp.BundleType] = BundleType
173 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
175 return (
False, f
"void type cannot must represented by None, not {obj}")
184 return bytearray([0])
186 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
188 raise ValueError(f
"void type cannot be represented by {data}")
189 return (
None, data[1:])
192__esi_mapping[cpp.VoidType] = VoidType
200 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
201 return (
False,
"any type is not supported for host communication")
208 raise ValueError(
"any type cannot be serialized")
210 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
211 raise ValueError(
"any type cannot be deserialized")
214__esi_mapping[cpp.AnyType] = AnyType
222 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
223 if not isinstance(obj, (bytearray, bytes, list)):
224 return (
False, f
"invalid type: {type(obj)}")
225 if isinstance(obj, list)
and not all(
226 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
227 return (
False, f
"list item too large: {obj}")
229 return (
False, f
"wrong size: {len(obj)}")
236 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
237 if isinstance(obj, bytearray):
239 if isinstance(obj, bytes)
or isinstance(obj, list):
240 return bytearray(obj)
241 raise ValueError(f
"cannot convert {obj} to bytearray")
243 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
247__esi_mapping[cpp.BitsType] = BitsType
265 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
266 if not isinstance(obj, int):
267 return (
False, f
"must be an int, not {type(obj)}")
269 return (
False, f
"out of range: {obj}")
273 return f
"uint{self.bit_width}"
283__esi_mapping[cpp.UIntType] = UIntType
291 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
292 if not isinstance(obj, int):
293 return (
False, f
"must be an int, not {type(obj)}")
296 return (
False, f
"out of range: {obj}")
299 return (
False, f
"out of range: {obj}")
303 return f
"sint{self.bit_width}"
313__esi_mapping[cpp.SIntType] = SIntType
318 def __init__(self, id: str, fields: List[Tuple[str,
"ESIType"]]):
320 cpp_fields = [(name, field_type.cpp_type)
for name, field_type
in fields]
324 """Initialize instance attributes from a C++ type object."""
331 widths = [ty.bit_width
for (_, ty)
in self.
fields]
332 if any([w < 0
for w
in widths]):
336 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
338 if not isinstance(obj, dict):
339 if not hasattr(obj,
"__dict__"):
340 return (
False,
"must be a dict or have __dict__ attribute")
343 for (fname, ftype)
in self.
fields:
345 return (
False, f
"missing field '{fname}'")
346 fvalid, reason = ftype.is_valid(obj[fname])
348 return (
False, f
"invalid field '{fname}': {reason}")
350 if fields_count != len(obj):
351 return (
False,
"missing fields")
356 if not isinstance(obj, dict):
358 ordered_fields = reversed(
360 for (fname, ftype)
in ordered_fields:
362 ret.extend(ftype.serialize(fval))
365 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
367 ordered_fields = reversed(
369 for (fname, ftype)
in ordered_fields:
370 (fval, data) = ftype.deserialize(data)
375__esi_mapping[cpp.StructType] = StructType
380 def __init__(self, id: str, element_type:
"ESIType", size: int):
384 """Initialize instance attributes from a C++ type object."""
393 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
394 if not isinstance(obj, list):
395 return (
False, f
"must be a list, not {type(obj)}")
396 if len(obj) != self.
size:
397 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
398 for (idx, e)
in enumerate(obj):
401 return (
False, f
"invalid element {idx}: {reason}")
406 for e
in reversed(lst):
410 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
412 for _
in range(self.
size):
419__esi_mapping[cpp.ArrayType] = ArrayType
424 def __init__(self, id: str, element_type:
"ESIType"):
433 return (
False,
"list types require an enclosing window encoding")
439 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
440 if not isinstance(obj, list):
441 return (
False, f
"must be a list, not {type(obj)}")
442 for (idx, element)
in enumerate(obj):
445 return (
False, f
"invalid element {idx}: {reason}")
449 raise ValueError(
"list type cannot be serialized without a window")
451 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
452 raise ValueError(
"list type cannot be deserialized without a window")
455__esi_mapping[cpp.ListType] = ListType
460 _HOST_UNSUPPORTED_REASON = (
461 "window types require into/lowered translation and are not yet "
462 "supported for host communication")
467 bulk_count_width: int
471 fields: List[
"WindowType.Field"]
473 def __init__(self, id: str, name: str, into_type:
"ESIType",
474 lowered_type:
"ESIType", frames: List[
"WindowType.Frame"]):
476 cpp.WindowFrame(frame.name, [
477 cpp.WindowField(field.name, field.num_items, field.bulk_count_width)
478 for field
in frame.fields
483 cpp.WindowType(id, name, into_type.cpp_type, lowered_type.cpp_type,
494 field.bulk_count_width)
for field
in frame.fields
495 ])
for frame
in cpp_type.frames
506 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
512 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
516__esi_mapping[cpp.WindowType] = WindowType
521 def __init__(self, id: str, fields: List[Tuple[str,
"ESIType"]]):
522 cpp_fields = [(name, field_type.cpp_type)
for name, field_type
in fields]
526 """Initialize instance attributes from a C++ type object."""
532 widths = [ty.bit_width
for (_, ty)
in self.
fields]
533 if any([w < 0
for w
in widths]):
535 return max(widths)
if widths
else 0
537 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
538 if not isinstance(obj, dict):
539 return (
False,
"must be a dict with exactly one field")
541 return (
False, f
"union must have exactly 1 active field, got {len(obj)}")
542 field_names = {name
for name, _
in self.
fields}
543 active_name = next(iter(obj))
544 if active_name
not in field_names:
545 return (
False, f
"unknown field '{active_name}' in union")
546 for (fname, ftype)
in self.
fields:
547 if fname == active_name:
548 return ftype.is_valid(obj[active_name])
549 return (
False, f
"unknown field '{active_name}' in union")
552 if not isinstance(obj, dict)
or len(obj) != 1:
553 raise ValueError(
"union value must be a dict with exactly one field")
554 active_name = next(iter(obj))
555 for (fname, ftype)
in self.
fields:
556 if fname == active_name:
557 field_bytes = ftype.serialize(obj[active_name])
561 pad_len = union_bytes - len(field_bytes)
563 return bytearray(pad_len) + field_bytes
565 raise ValueError(f
"unknown field '{active_name}' in union")
567 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
569 union_data = data[:union_bytes]
570 remaining = data[union_bytes:]
572 for (fname, ftype)
in self.
fields:
575 field_bytes = (ftype.bit_width + 7) // 8
576 pad_len = union_bytes - field_bytes
577 (fval, _) = ftype.deserialize(bytearray(union_data[pad_len:]))
579 return (result, remaining)
582__esi_mapping[cpp.UnionType] = UnionType
587 def __init__(self, id: str, name: str, inner_type:
"ESIType"):
599 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
605 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
612__esi_mapping[cpp.TypeAliasType] = TypeAlias
616 """A unidirectional communication channel. This is the basic communication
617 method with an accelerator."""
619 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
624 def connect(self, buffer_size: Optional[int] =
None):
625 (supports_host, reason) = self.
type.supports_host
626 if not supports_host:
627 raise TypeError(f
"unsupported type: {reason}")
629 opts = cpp.ConnectOptions()
630 opts.buffer_size = buffer_size
639 """A unidirectional communication channel from the host to the accelerator."""
641 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
643 self.
cpp_port: cpp.WriteChannelPort = cpp_port
646 valid, reason = self.
type.is_valid(msg)
649 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
650 msg_bytes: bytearray = self.
type.serialize(msg)
654 """Write a typed message to the channel. Attempts to serialize 'msg' to what
655 the accelerator expects, but will fail if the object is not convertible to
661 """Like 'write', but uses the non-blocking tryWrite method of the underlying
662 port. Returns True if the write was successful, False otherwise."""
667 """A unidirectional communication channel from the accelerator to the host."""
669 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
671 self.
cpp_port: cpp.ReadChannelPort = cpp_port
674 """Read a typed message from the channel. Returns a deserialized object of a
675 type defined by the port type."""
678 (msg, leftover) = self.
type.deserialize(buffer)
679 if len(leftover) != 0:
680 raise ValueError(f
"leftover bytes: {leftover}")
685 """A collections of named, unidirectional communication channels."""
689 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
691 if isinstance(cpp_port, cpp.Function):
692 return super().
__new__(FunctionPort)
693 if isinstance(cpp_port, cpp.Callback):
694 return super().
__new__(CallbackPort)
695 if isinstance(cpp_port, cpp.MMIORegion):
696 return super().
__new__(MMIORegion)
697 if isinstance(cpp_port, cpp.Metric):
698 return super().
__new__(MetricPort)
699 if isinstance(cpp_port, cpp.ToHostChannel):
700 return super().
__new__(ToHostPort)
701 if isinstance(cpp_port, cpp.FromHostChannel):
702 return super().
__new__(FromHostPort)
705 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
717 """A specialization of `Future` for ESI messages. Wraps the cpp object and
718 deserializes the result. Hopefully overrides all the methods necessary for
719 proper operation, which is assumed to be not all of them."""
721 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
731 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
735 (msg, leftover) = self.
result_type.deserialize(result_bytes)
736 if len(leftover) != 0:
737 raise ValueError(f
"leftover bytes: {leftover}")
741 raise NotImplementedError(
"add_done_callback is not implemented")
745 """A region of memory-mapped I/O space. This is a collection of named
746 channels, which are either read or read-write. The channels are accessed
747 by name, and can be connected to the host."""
749 def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
755 return self.
region.descriptor
757 def read(self, offset: int) -> bytearray:
758 """Read a value from the MMIO region at the given offset."""
761 def write(self, offset: int, data: bytearray) ->
None:
762 """Write a value to the MMIO region at the given offset."""
767 """A pair of channels which carry the input and output of a function."""
769 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
779 def call(self, *args: Any, **kwargs: Any) -> Future:
780 """Call the function with the given argument and returns a future of the
784 if len(args) > 0
and len(kwargs) > 0:
785 raise ValueError(
"cannot use both positional and keyword arguments")
795 valid, reason = self.
arg_type.is_valid(selected)
798 f
"'{selected}' cannot be converted to '{self.arg_type}': {reason}")
799 arg_bytes: bytearray = self.
arg_type.serialize(selected)
803 def __call__(self, *args: Any, **kwds: Any) -> Future:
804 return self.
call(*args, **kwds)
808 """Callback ports are the inverse of function ports -- instead of calls to the
809 accelerator, they get called from the accelerator. Specify the function which
810 you'd like the accelerator to call when you call `connect`."""
812 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
820 def type_convert_wrapper(cb: Callable[[Any], Any],
821 msg: bytearray) -> Optional[bytearray]:
823 (obj, leftover) = self.
arg_type.deserialize(msg)
824 if len(leftover) != 0:
825 raise ValueError(f
"leftover bytes: {leftover}")
830 except Exception
as e:
831 traceback.print_exception(e)
839 """Telemetry ports report an individual piece of information from the
840 acceelerator. The method of accessing telemetry will likely change in the
843 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
857 """A channel which reads data from the accelerator (to_host)."""
859 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
869 """Read a value from the channel. Returns a future."""
875 """A channel which writes data to the accelerator (from_host)."""
877 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
887 """Write a value to the channel."""
888 valid, reason = self.
data_type.is_valid(data)
891 f
"'{data}' cannot be converted to '{self.data_type}': {reason}")
Tuple[object, bytearray] deserialize(self, bytearray data)
bytearray serialize(self, obj)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[bool, Optional[str]] is_valid(self, obj)
_init_from_cpp(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
__init__(self, str id, "ESIType" element_type, int size)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
__init__(self, str id, int width)
__new__(cls, HWModule owner, cpp.BundlePort cpp_port)
WritePort write_port(self, str channel_name)
ReadPort read_port(self, str channel_name)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
List["BundleType.Channel"] channels(self)
_init_from_cpp(self, cpp.BundleType cpp_type)
__init__(self, str id, List[Channel] channels)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
connect(self, Callable[[Any], Any] cb)
bytearray serialize(self, obj)
_init_from_cpp(self, cpp.ChannelType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
__init__(self, str id, "ESIType" inner)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
_init_from_cpp(self, cpp.Type cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
wrap_cpp(cls, cpp.Type cpp_type)
bytearray serialize(self, obj)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
None write(self, Any data)
Future call(self, *Any args, **Any kwargs)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future __call__(self, *Any args, **Any kwds)
__init__(self, str id, int width)
bytearray serialize(self, obj)
__init__(self, str id, "ESIType" element_type)
_init_from_cpp(self, cpp.ListType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] supports_host(self)
None write(self, int offset, bytearray data)
bytearray read(self, int offset)
__init__(self, HWModule owner, cpp.MMIORegion cpp_port)
cpp.MMIORegionDesc descriptor(self)
Any result(self, Optional[Union[int, float]] timeout=None)
__init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
__init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
connect(self, Optional[int] buffer_size=None)
__init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
__init__(self, str id, int width)
__init__(self, str id, List[Tuple[str, "ESIType"]] fields)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
_init_from_cpp(self, cpp.StructType cpp_type)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Tuple[object, bytearray] deserialize(self, bytearray data)
_init_from_cpp(self, cpp.TypeAliasType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
__init__(self, str id, str name, "ESIType" inner_type)
__init__(self, str id, int width)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
_init_from_cpp(self, cpp.UnionType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
__init__(self, str id, List[Tuple[str, "ESIType"]] fields)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
tuple _HOST_UNSUPPORTED_REASON
__init__(self, str id, str name, "ESIType" into_type, "ESIType" lowered_type, List["WindowType.Frame"] frames)
_init_from_cpp(self, cpp.WindowType cpp_type)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
bool try_write(self, msg=None)
bytearray __serialize_msg(self, msg=None)
bool write(self, msg=None)
__init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
_get_esi_type(cpp.Type cpp_type)