13from __future__
import annotations
15from .
import esiCppAccel
as cpp
17from typing
import TYPE_CHECKING
20 from .accelerator
import HWModule
22from concurrent.futures
import Future
23from typing
import Any, Callable, Dict, List, NamedTuple, Optional, Tuple, Type, Union
29 """Get the wrapper class for a C++ type."""
30 for cpp_type_cls, wrapper_cls
in __esi_mapping.items():
31 if isinstance(cpp_type, cpp_type_cls):
32 return wrapper_cls.wrap_cpp(cpp_type)
33 return ESIType.wrap_cpp(cpp_type)
37__esi_mapping: Dict[Type, Type] = {}
47 """Wrap a C++ ESI type with its corresponding Python ESI Type."""
48 instance = cls.__new__(cls)
49 instance._init_from_cpp(cpp_type)
53 """Initialize instance attributes from a C++ type object."""
58 """Get the stable id of this type."""
63 """Does this type support host communication via Python? Returns either
64 '(True, None)' if it is, or '(False, reason)' if it is not."""
67 return (
False,
"runtime only supports types with multiple of 8 bits")
70 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
71 """Is a Python object compatible with HW type? Returns either '(True,
72 None)' if it is, or '(False, reason)' if it is not."""
73 assert False,
"unimplemented"
77 """Size of this type, in bits. Negative for unbounded types."""
78 assert False,
"unimplemented"
82 """Maximum size of a value of this type, in bytes."""
89 """Convert a Python object to a bytearray."""
90 assert False,
"unimplemented"
92 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
93 """Convert a bytearray to a Python object. Return the object and the
95 assert False,
"unimplemented"
101 return isinstance(other, ESIType)
and self.
idid == other.id
128 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
134 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
138__esi_mapping[cpp.ChannelType] = ChannelType
145 direction: cpp.BundleType.Direction
148 def __init__(self, id: str, channels: List[Channel]):
149 cpp_channels = [(name, direction, channel_type.cpp_type)
150 for name, direction, channel_type
in channels]
157 for name, direction, channel_type
in cpp_type.channels
165__esi_mapping[cpp.BundleType] = BundleType
173 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
175 return (
False, f
"void type cannot must represented by None, not {obj}")
184 return bytearray([0])
186 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
188 raise ValueError(f
"void type cannot be represented by {data}")
189 return (
None, data[1:])
192__esi_mapping[cpp.VoidType] = VoidType
200 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
201 return (
False,
"any type is not supported for host communication")
208 raise ValueError(
"any type cannot be serialized")
210 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
211 raise ValueError(
"any type cannot be deserialized")
214__esi_mapping[cpp.AnyType] = AnyType
222 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
223 if not isinstance(obj, (bytearray, bytes, list)):
224 return (
False, f
"invalid type: {type(obj)}")
225 if isinstance(obj, list)
and not all(
226 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
227 return (
False, f
"list item too large: {obj}")
229 return (
False, f
"wrong size: {len(obj)}")
236 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
237 if isinstance(obj, bytearray):
239 if isinstance(obj, bytes)
or isinstance(obj, list):
240 return bytearray(obj)
241 raise ValueError(f
"cannot convert {obj} to bytearray")
243 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
247__esi_mapping[cpp.BitsType] = BitsType
265 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
266 if not isinstance(obj, int):
267 return (
False, f
"must be an int, not {type(obj)}")
269 return (
False, f
"out of range: {obj}")
273 return f
"uint{self.bit_width}"
283__esi_mapping[cpp.UIntType] = UIntType
291 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
292 if not isinstance(obj, int):
293 return (
False, f
"must be an int, not {type(obj)}")
296 return (
False, f
"out of range: {obj}")
299 return (
False, f
"out of range: {obj}")
303 return f
"sint{self.bit_width}"
313__esi_mapping[cpp.SIntType] = SIntType
318 def __init__(self, id: str, fields: List[Tuple[str,
"ESIType"]]):
320 cpp_fields = [(name, field_type.cpp_type)
for name, field_type
in fields]
324 """Initialize instance attributes from a C++ type object."""
331 widths = [ty.bit_width
for (_, ty)
in self.
fields]
332 if any([w < 0
for w
in widths]):
336 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
338 if not isinstance(obj, dict):
339 if not hasattr(obj,
"__dict__"):
340 return (
False,
"must be a dict or have __dict__ attribute")
343 for (fname, ftype)
in self.
fields:
345 return (
False, f
"missing field '{fname}'")
346 fvalid, reason = ftype.is_valid(obj[fname])
348 return (
False, f
"invalid field '{fname}': {reason}")
350 if fields_count != len(obj):
351 return (
False,
"missing fields")
356 if not isinstance(obj, dict):
358 ordered_fields = reversed(
360 for (fname, ftype)
in ordered_fields:
362 ret.extend(ftype.serialize(fval))
365 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
367 ordered_fields = reversed(
369 for (fname, ftype)
in ordered_fields:
370 (fval, data) = ftype.deserialize(data)
375__esi_mapping[cpp.StructType] = StructType
380 def __init__(self, id: str, element_type:
"ESIType", size: int):
384 """Initialize instance attributes from a C++ type object."""
393 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
394 if not isinstance(obj, list):
395 return (
False, f
"must be a list, not {type(obj)}")
396 if len(obj) != self.
size:
397 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
398 for (idx, e)
in enumerate(obj):
401 return (
False, f
"invalid element {idx}: {reason}")
406 for e
in reversed(lst):
410 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
412 for _
in range(self.
size):
419__esi_mapping[cpp.ArrayType] = ArrayType
424 def __init__(self, id: str, name: str, inner_type:
"ESIType"):
436 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
442 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
449__esi_mapping[cpp.TypeAliasType] = TypeAlias
453 """A unidirectional communication channel. This is the basic communication
454 method with an accelerator."""
456 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
461 def connect(self, buffer_size: Optional[int] =
None):
462 (supports_host, reason) = self.
type.supports_host
463 if not supports_host:
464 raise TypeError(f
"unsupported type: {reason}")
466 opts = cpp.ConnectOptions()
467 opts.buffer_size = buffer_size
476 """A unidirectional communication channel from the host to the accelerator."""
478 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
480 self.
cpp_port: cpp.WriteChannelPort = cpp_port
483 valid, reason = self.
type.is_valid(msg)
486 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
487 msg_bytes: bytearray = self.
type.serialize(msg)
491 """Write a typed message to the channel. Attempts to serialize 'msg' to what
492 the accelerator expects, but will fail if the object is not convertible to
498 """Like 'write', but uses the non-blocking tryWrite method of the underlying
499 port. Returns True if the write was successful, False otherwise."""
504 """A unidirectional communication channel from the accelerator to the host."""
506 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
508 self.
cpp_port: cpp.ReadChannelPort = cpp_port
511 """Read a typed message from the channel. Returns a deserialized object of a
512 type defined by the port type."""
515 (msg, leftover) = self.
type.deserialize(buffer)
516 if len(leftover) != 0:
517 raise ValueError(f
"leftover bytes: {leftover}")
522 """A collections of named, unidirectional communication channels."""
526 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
528 if isinstance(cpp_port, cpp.Function):
529 return super().
__new__(FunctionPort)
530 if isinstance(cpp_port, cpp.Callback):
531 return super().
__new__(CallbackPort)
532 if isinstance(cpp_port, cpp.MMIORegion):
533 return super().
__new__(MMIORegion)
534 if isinstance(cpp_port, cpp.Metric):
535 return super().
__new__(MetricPort)
538 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
550 """A specialization of `Future` for ESI messages. Wraps the cpp object and
551 deserializes the result. Hopefully overrides all the methods necessary for
552 proper operation, which is assumed to be not all of them."""
554 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
564 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
568 (msg, leftover) = self.
result_type.deserialize(result_bytes)
569 if len(leftover) != 0:
570 raise ValueError(f
"leftover bytes: {leftover}")
574 raise NotImplementedError(
"add_done_callback is not implemented")
578 """A region of memory-mapped I/O space. This is a collection of named
579 channels, which are either read or read-write. The channels are accessed
580 by name, and can be connected to the host."""
582 def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
588 return self.
region.descriptor
590 def read(self, offset: int) -> bytearray:
591 """Read a value from the MMIO region at the given offset."""
594 def write(self, offset: int, data: bytearray) ->
None:
595 """Write a value to the MMIO region at the given offset."""
600 """A pair of channels which carry the input and output of a function."""
602 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
612 def call(self, *args: Any, **kwargs: Any) -> Future:
613 """Call the function with the given argument and returns a future of the
617 if len(args) > 0
and len(kwargs) > 0:
618 raise ValueError(
"cannot use both positional and keyword arguments")
628 valid, reason = self.
arg_type.is_valid(selected)
631 f
"'{selected}' cannot be converted to '{self.arg_type}': {reason}")
632 arg_bytes: bytearray = self.
arg_type.serialize(selected)
636 def __call__(self, *args: Any, **kwds: Any) -> Future:
637 return self.
call(*args, **kwds)
641 """Callback ports are the inverse of function ports -- instead of calls to the
642 accelerator, they get called from the accelerator. Specify the function which
643 you'd like the accelerator to call when you call `connect`."""
645 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
653 def type_convert_wrapper(cb: Callable[[Any], Any],
654 msg: bytearray) -> Optional[bytearray]:
656 (obj, leftover) = self.
arg_type.deserialize(msg)
657 if len(leftover) != 0:
658 raise ValueError(f
"leftover bytes: {leftover}")
663 except Exception
as e:
664 traceback.print_exception(e)
672 """Telemetry ports report an individual piece of information from the
673 acceelerator. The method of accessing telemetry will likely change in the
676 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
Tuple[object, bytearray] deserialize(self, bytearray data)
bytearray serialize(self, obj)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[bool, Optional[str]] is_valid(self, obj)
_init_from_cpp(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
__init__(self, str id, "ESIType" element_type, int size)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
__init__(self, str id, int width)
__new__(cls, HWModule owner, cpp.BundlePort cpp_port)
WritePort write_port(self, str channel_name)
ReadPort read_port(self, str channel_name)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
List["BundleType.Channel"] channels(self)
_init_from_cpp(self, cpp.BundleType cpp_type)
__init__(self, str id, List[Channel] channels)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
connect(self, Callable[[Any], Any] cb)
bytearray serialize(self, obj)
_init_from_cpp(self, cpp.ChannelType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
__init__(self, str id, "ESIType" inner)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
_init_from_cpp(self, cpp.Type cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
wrap_cpp(cls, cpp.Type cpp_type)
bytearray serialize(self, obj)
Future call(self, *Any args, **Any kwargs)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future __call__(self, *Any args, **Any kwds)
__init__(self, str id, int width)
None write(self, int offset, bytearray data)
bytearray read(self, int offset)
__init__(self, HWModule owner, cpp.MMIORegion cpp_port)
cpp.MMIORegionDesc descriptor(self)
Any result(self, Optional[Union[int, float]] timeout=None)
__init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
__init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
connect(self, Optional[int] buffer_size=None)
__init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
__init__(self, str id, int width)
__init__(self, str id, List[Tuple[str, "ESIType"]] fields)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
_init_from_cpp(self, cpp.StructType cpp_type)
Tuple[object, bytearray] deserialize(self, bytearray data)
_init_from_cpp(self, cpp.TypeAliasType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
__init__(self, str id, str name, "ESIType" inner_type)
__init__(self, str id, int width)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
bool try_write(self, msg=None)
bytearray __serialize_msg(self, msg=None)
bool write(self, msg=None)
__init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
_get_esi_type(cpp.Type cpp_type)