13from __future__
import annotations
15from .
import esiCppAccel
as cpp
17from typing
import TYPE_CHECKING
20 from .accelerator
import HWModule
22from concurrent.futures
import Future
23from typing
import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union
29 """Get the wrapper class for a C++ type."""
30 for cpp_type_cls, wrapper_cls
in __esi_mapping.items():
31 if isinstance(cpp_type, cpp_type_cls):
32 return wrapper_cls.wrap_cpp(cpp_type)
33 return ESIType.wrap_cpp(cpp_type)
37__esi_mapping: Dict[Type, Type] = {}
47 """Wrap a C++ ESI type with its corresponding Python ESI Type."""
48 instance = cls.__new__(cls)
49 instance._init_from_cpp(cpp_type)
53 """Initialize instance attributes from a C++ type object."""
58 """Get the stable id of this type."""
63 """Does this type support host communication via Python? Returns either
64 '(True, None)' if it is, or '(False, reason)' if it is not."""
67 return (
False,
"runtime only supports types with multiple of 8 bits")
70 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
71 """Is a Python object compatible with HW type? Returns either '(True,
72 None)' if it is, or '(False, reason)' if it is not."""
73 assert False,
"unimplemented"
77 """Size of this type, in bits. Negative for unbounded types."""
82 """Maximum size of a value of this type, in bytes."""
89 """Convert a Python object to a bytearray."""
90 assert False,
"unimplemented"
92 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
93 """Convert a bytearray to a Python object. Return the object and the
95 assert False,
"unimplemented"
101 return isinstance(other, ESIType)
and self.
idid == other.id
124 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
130 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
134__esi_mapping[cpp.ChannelType] = ChannelType
141 direction: cpp.BundleType.Direction
144 def __init__(self, id: str, channels: List[Channel]):
145 cpp_channels = [(name, direction, channel_type.cpp_type)
146 for name, direction, channel_type
in channels]
153 for name, direction, channel_type
in cpp_type.channels
161__esi_mapping[cpp.BundleType] = BundleType
169 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
171 return (
False, f
"void type must be represented by None, not {obj}")
179 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
185__esi_mapping[cpp.VoidType] = VoidType
193 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
194 return (
False,
"any type is not supported for host communication")
197 raise ValueError(
"any type cannot be serialized")
199 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
200 raise ValueError(
"any type cannot be deserialized")
203__esi_mapping[cpp.AnyType] = AnyType
211 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
212 if not isinstance(obj, (bytearray, bytes, list)):
213 return (
False, f
"invalid type: {type(obj)}")
214 if isinstance(obj, list)
and not all(
215 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
216 return (
False, f
"list item too large: {obj}")
218 return (
False, f
"wrong size: {len(obj)}")
221 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
222 if isinstance(obj, bytearray):
224 if isinstance(obj, bytes)
or isinstance(obj, list):
225 return bytearray(obj)
226 raise ValueError(f
"cannot convert {obj} to bytearray")
228 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
232__esi_mapping[cpp.BitsType] = BitsType
246 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
247 if not isinstance(obj, int):
248 return (
False, f
"must be an int, not {type(obj)}")
249 if obj < 0
or obj.bit_length() > self.
bit_width:
250 return (
False, f
"out of range: {obj}")
254 return f
"uint{self.bit_width}"
264__esi_mapping[cpp.UIntType] = UIntType
272 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
273 if not isinstance(obj, int):
274 return (
False, f
"must be an int, not {type(obj)}")
277 return (
False, f
"out of range: {obj}")
280 return (
False, f
"out of range: {obj}")
284 return f
"sint{self.bit_width}"
294__esi_mapping[cpp.SIntType] = SIntType
299 def __init__(self, id: str, fields: List[Tuple[str,
"ESIType"]]):
301 cpp_fields = [(name, field_type.cpp_type)
for name, field_type
in fields]
305 """Initialize instance attributes from a C++ type object."""
310 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
312 if not isinstance(obj, dict):
313 if not hasattr(obj,
"__dict__"):
314 return (
False,
"must be a dict or have __dict__ attribute")
317 for (fname, ftype)
in self.
fields:
319 return (
False, f
"missing field '{fname}'")
320 fvalid, reason = ftype.is_valid(obj[fname])
322 return (
False, f
"invalid field '{fname}': {reason}")
324 if fields_count != len(obj):
325 return (
False,
"missing fields")
330 if not isinstance(obj, dict):
332 ordered_fields = reversed(
334 for (fname, ftype)
in ordered_fields:
336 ret.extend(ftype.serialize(fval))
339 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
341 ordered_fields = reversed(
343 for (fname, ftype)
in ordered_fields:
344 (fval, data) = ftype.deserialize(data)
349__esi_mapping[cpp.StructType] = StructType
354 def __init__(self, id: str, element_type:
"ESIType", size: int):
358 """Initialize instance attributes from a C++ type object."""
363 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
364 if not isinstance(obj, list):
365 return (
False, f
"must be a list, not {type(obj)}")
366 if len(obj) != self.
size:
367 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
368 for (idx, e)
in enumerate(obj):
371 return (
False, f
"invalid element {idx}: {reason}")
376 for e
in reversed(lst):
380 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
382 for _
in range(self.
size):
389__esi_mapping[cpp.ArrayType] = ArrayType
394 def __init__(self, id: str, element_type:
"ESIType"):
403 return (
False,
"list types require an enclosing window encoding")
405 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
406 if not isinstance(obj, list):
407 return (
False, f
"must be a list, not {type(obj)}")
408 for (idx, element)
in enumerate(obj):
411 return (
False, f
"invalid element {idx}: {reason}")
415 raise ValueError(
"list type cannot be serialized without a window")
417 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
418 raise ValueError(
"list type cannot be deserialized without a window")
421__esi_mapping[cpp.ListType] = ListType
426 _HOST_UNSUPPORTED_REASON = (
427 "window types require into/lowered translation and are not yet "
428 "supported for host communication")
433 bulk_count_width: int
437 fields: List[
"WindowType.Field"]
439 def __init__(self, id: str, name: str, into_type:
"ESIType",
440 lowered_type:
"ESIType", frames: List[
"WindowType.Frame"]):
442 cpp.WindowFrame(frame.name, [
443 cpp.WindowField(field.name, field.num_items, field.bulk_count_width)
444 for field
in frame.fields
449 cpp.WindowType(id, name, into_type.cpp_type, lowered_type.cpp_type,
460 field.bulk_count_width)
for field
in frame.fields
461 ])
for frame
in cpp_type.frames
468 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
474 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
478__esi_mapping[cpp.WindowType] = WindowType
483 def __init__(self, id: str, fields: List[Tuple[str,
"ESIType"]]):
484 cpp_fields = [(name, field_type.cpp_type)
for name, field_type
in fields]
488 """Initialize instance attributes from a C++ type object."""
492 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
493 if not isinstance(obj, dict):
494 return (
False,
"must be a dict with exactly one field")
496 return (
False, f
"union must have exactly 1 active field, got {len(obj)}")
497 field_names = {name
for name, _
in self.
fields}
498 active_name = next(iter(obj))
499 if active_name
not in field_names:
500 return (
False, f
"unknown field '{active_name}' in union")
501 for (fname, ftype)
in self.
fields:
502 if fname == active_name:
503 return ftype.is_valid(obj[active_name])
504 return (
False, f
"unknown field '{active_name}' in union")
507 if not isinstance(obj, dict)
or len(obj) != 1:
508 raise ValueError(
"union value must be a dict with exactly one field")
509 active_name = next(iter(obj))
510 for (fname, ftype)
in self.
fields:
511 if fname == active_name:
512 field_bytes = ftype.serialize(obj[active_name])
516 pad_len = union_bytes - len(field_bytes)
518 return bytearray(pad_len) + field_bytes
520 raise ValueError(f
"unknown field '{active_name}' in union")
522 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
524 union_data = data[:union_bytes]
525 remaining = data[union_bytes:]
527 for (fname, ftype)
in self.
fields:
530 field_bytes = (ftype.bit_width + 7) // 8
531 pad_len = union_bytes - field_bytes
532 (fval, _) = ftype.deserialize(bytearray(union_data[pad_len:]))
534 return (result, remaining)
537__esi_mapping[cpp.UnionType] = UnionType
542 def __init__(self, id: str, name: str, inner_type:
"ESIType"):
550 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
556 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
563__esi_mapping[cpp.TypeAliasType] = TypeAlias
567 """A unidirectional communication channel. This is the basic communication
568 method with an accelerator."""
570 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
574 win = cpp_port.windowType
577 def connect(self, buffer_size: Optional[int] =
None):
578 (supports_host, reason) = self.
type.supports_host
579 if not supports_host:
580 raise TypeError(f
"unsupported type: {reason}")
582 opts = cpp.ConnectOptions()
583 opts.buffer_size = buffer_size
592 """A unidirectional communication channel from the host to the accelerator."""
594 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
596 self.
cpp_port: cpp.WriteChannelPort = cpp_port
599 valid, reason = self.
type.is_valid(msg)
602 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
603 msg_bytes: bytearray = self.
type.serialize(msg)
607 """Write a typed message to the channel. Attempts to serialize 'msg' to what
608 the accelerator expects, but will fail if the object is not convertible to
614 """Like 'write', but uses the non-blocking tryWrite method of the underlying
615 port. Returns True if the write was successful, False otherwise."""
620 """A unidirectional communication channel from the accelerator to the host."""
622 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
624 self.
cpp_port: cpp.ReadChannelPort = cpp_port
627 """Read a typed message from the channel. Returns a deserialized object of a
628 type defined by the port type."""
631 (msg, leftover) = self.
type.deserialize(buffer)
632 if len(leftover) != 0:
633 raise ValueError(f
"leftover bytes: {leftover}")
638 """A collection of named, unidirectional communication channels."""
642 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
644 if isinstance(cpp_port, cpp.Function):
645 return super().
__new__(FunctionPort)
646 if isinstance(cpp_port, cpp.Callback):
647 return super().
__new__(CallbackPort)
648 if isinstance(cpp_port, cpp.MMIORegion):
649 return super().
__new__(MMIORegion)
650 if isinstance(cpp_port, cpp.Metric):
651 return super().
__new__(MetricPort)
652 if isinstance(cpp_port, cpp.ToHostChannel):
653 return super().
__new__(ToHostPort)
654 if isinstance(cpp_port, cpp.FromHostChannel):
655 return super().
__new__(FromHostPort)
658 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
670 """A specialization of `Future` for ESI messages. Wraps the cpp object and
671 deserializes the result. Hopefully overrides all the methods necessary for
672 proper operation, which is assumed to be not all of them."""
674 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
684 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
688 (msg, leftover) = self.
result_type.deserialize(result_bytes)
689 if len(leftover) != 0:
690 raise ValueError(f
"leftover bytes: {leftover}")
694 raise NotImplementedError(
"add_done_callback is not implemented")
698 """A region of memory-mapped I/O space. This is a collection of named
699 channels, which are either read or read-write. The channels are accessed
700 by name, and can be connected to the host."""
702 def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
708 return self.
region.descriptor
710 def read(self, offset: int) -> bytearray:
711 """Read a value from the MMIO region at the given offset."""
714 def write(self, offset: int, data: bytearray) ->
None:
715 """Write a value to the MMIO region at the given offset."""
720 """A pair of channels which carry the input and output of a function."""
722 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
736 def call(self, *args: Any, **kwargs: Any) -> Future:
737 """Call the function with the given argument and returns a future of the
741 if len(args) > 0
and len(kwargs) > 0:
742 raise ValueError(
"cannot use both positional and keyword arguments")
752 valid, reason = self.
arg_type.is_valid(selected)
755 f
"'{selected}' cannot be converted to '{self.arg_type}': {reason}")
756 arg_bytes: bytearray = self.
arg_type.serialize(selected)
760 def __call__(self, *args: Any, **kwds: Any) -> Future:
761 return self.
call(*args, **kwds)
765 """Callback ports are the inverse of function ports -- instead of calls to the
766 accelerator, they get called from the accelerator. Specify the function which
767 you'd like the accelerator to call when you call `connect`."""
769 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
781 def type_convert_wrapper(cb: Callable[[Any], Any],
782 msg: bytearray) -> Optional[bytearray]:
784 (obj, leftover) = self.
arg_type.deserialize(msg)
785 if len(leftover) != 0:
786 raise ValueError(f
"leftover bytes: {leftover}")
791 except Exception
as e:
792 traceback.print_exception(e)
800 """Telemetry ports report an individual piece of information from the
801 accelerator. The method of accessing telemetry will likely change in the
804 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
818 """A channel which reads data from the accelerator (to_host)."""
820 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
832 """Read a value from the channel. Returns a future."""
838 """A channel which writes data to the accelerator (from_host)."""
840 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
852 """Write a value to the channel."""
853 valid, reason = self.
data_type.is_valid(data)
856 f
"'{data}' cannot be converted to '{self.data_type}': {reason}")
Tuple[object, bytearray] deserialize(self, bytearray data)
bytearray serialize(self, obj)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[bool, Optional[str]] is_valid(self, obj)
_init_from_cpp(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
__init__(self, str id, "ESIType" element_type, int size)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
__init__(self, str id, int width)
__new__(cls, HWModule owner, cpp.BundlePort cpp_port)
WritePort write_port(self, str channel_name)
ReadPort read_port(self, str channel_name)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
List["BundleType.Channel"] channels(self)
_init_from_cpp(self, cpp.BundleType cpp_type)
__init__(self, str id, List[Channel] channels)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
connect(self, Callable[[Any], Any] cb)
bytearray serialize(self, obj)
_init_from_cpp(self, cpp.ChannelType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
__init__(self, str id, "ESIType" inner)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
_init_from_cpp(self, cpp.Type cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
wrap_cpp(cls, cpp.Type cpp_type)
bytearray serialize(self, obj)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
None write(self, Any data)
Future call(self, *Any args, **Any kwargs)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future __call__(self, *Any args, **Any kwds)
__init__(self, str id, int width)
bytearray serialize(self, obj)
__init__(self, str id, "ESIType" element_type)
_init_from_cpp(self, cpp.ListType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] supports_host(self)
None write(self, int offset, bytearray data)
bytearray read(self, int offset)
__init__(self, HWModule owner, cpp.MMIORegion cpp_port)
cpp.MMIORegionDesc descriptor(self)
Any result(self, Optional[Union[int, float]] timeout=None)
__init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
__init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
connect(self, Optional[int] buffer_size=None)
__init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
__init__(self, str id, int width)
__init__(self, str id, List[Tuple[str, "ESIType"]] fields)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
_init_from_cpp(self, cpp.StructType cpp_type)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Tuple[object, bytearray] deserialize(self, bytearray data)
_init_from_cpp(self, cpp.TypeAliasType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
__init__(self, str id, str name, "ESIType" inner_type)
__init__(self, str id, int width)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
_init_from_cpp(self, cpp.UnionType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
__init__(self, str id, List[Tuple[str, "ESIType"]] fields)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
tuple _HOST_UNSUPPORTED_REASON
__init__(self, str id, str name, "ESIType" into_type, "ESIType" lowered_type, List["WindowType.Frame"] frames)
_init_from_cpp(self, cpp.WindowType cpp_type)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
bool try_write(self, msg=None)
bytearray __serialize_msg(self, msg=None)
bool write(self, msg=None)
__init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
_get_esi_type(cpp.Type cpp_type)