13from __future__
import annotations
15from .
import esiCppAccel
as cpp
17from typing
import TYPE_CHECKING
19 from .accelerator
import HWModule
21from concurrent.futures
import Future
22from typing
import Any, Callable, Dict, List, Optional, Tuple, Type, Union
26 """Get the wrapper class for a C++ type."""
27 for cpp_type_cls, fn
in __esi_mapping.items():
28 if isinstance(cpp_type, cpp_type_cls):
35__esi_mapping: Dict[Type, Callable] = {
36 cpp.ChannelType:
lambda cpp_type:
_get_esi_type(cpp_type.inner)
47 """Does this type support host communication via Python? Returns either
48 '(True, None)' if it is, or '(False, reason)' if it is not."""
51 return (
False,
"runtime only supports types with multiple of 8 bits")
54 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
55 """Is a Python object compatible with HW type? Returns either '(True,
56 None)' if it is, or '(False, reason)' if it is not."""
57 assert False,
"unimplemented"
61 """Size of this type, in bits. Negative for unbounded types."""
62 assert False,
"unimplemented"
66 """Maximum size of a value of this type, in bytes."""
73 """Convert a Python object to a bytearray."""
74 assert False,
"unimplemented"
76 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
77 """Convert a bytearray to a Python object. Return the object and the
79 assert False,
"unimplemented"
87 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
89 return (
False, f
"void type cannot must represented by None, not {obj}")
100 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
102 raise ValueError(f
"void type cannot be represented by {data}")
103 return (
None, data[1:])
106__esi_mapping[cpp.VoidType] = VoidType
112 self.
cpp_type: cpp.BitsType = cpp_type
114 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
115 if not isinstance(obj, (bytearray, bytes, list)):
116 return (
False, f
"invalid type: {type(obj)}")
117 if isinstance(obj, list)
and not all(
118 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
119 return (
False, f
"list item too large: {obj}")
121 return (
False, f
"wrong size: {len(obj)}")
128 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
129 if isinstance(obj, bytearray):
131 if isinstance(obj, bytes)
or isinstance(obj, list):
132 return bytearray(obj)
133 raise ValueError(f
"cannot convert {obj} to bytearray")
135 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
139__esi_mapping[cpp.BitsType] = BitsType
145 self.
cpp_type: cpp.IntegerType = cpp_type
154 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
155 if not isinstance(obj, int):
156 return (
False, f
"must be an int, not {type(obj)}")
158 return (
False, f
"out of range: {obj}")
162 return f
"uint{self.bit_width}"
172__esi_mapping[cpp.UIntType] = UIntType
177 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
178 if not isinstance(obj, int):
179 return (
False, f
"must be an int, not {type(obj)}")
182 return (
False, f
"out of range: {obj}")
185 return (
False, f
"out of range: {obj}")
189 return f
"sint{self.bit_width}"
199__esi_mapping[cpp.SIntType] = SIntType
206 self.
fields: List[Tuple[str, ESIType]] = [
212 widths = [ty.bit_width
for (_, ty)
in self.
fields]
213 if any([w < 0
for w
in widths]):
217 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
219 if not isinstance(obj, dict):
222 for (fname, ftype)
in self.
fields:
224 return (
False, f
"missing field '{fname}'")
225 fvalid, reason = ftype.is_valid(obj[fname])
227 return (
False, f
"invalid field '{fname}': {reason}")
229 if fields_count != len(obj):
230 return (
False,
"missing fields")
235 for (fname, ftype)
in reversed(self.
fields):
237 ret.extend(ftype.serialize(fval))
240 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
242 for (fname, ftype)
in reversed(self.
fields):
243 (fval, data) = ftype.deserialize(data)
248__esi_mapping[cpp.StructType] = StructType
262 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
263 if not isinstance(obj, list):
264 return (
False, f
"must be a list, not {type(obj)}")
265 if len(obj) != self.
size:
266 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
267 for (idx, e)
in enumerate(obj):
270 return (
False, f
"invalid element {idx}: {reason}")
275 for e
in reversed(lst):
279 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
281 for _
in range(self.
size):
288__esi_mapping[cpp.ArrayType] = ArrayType
292 """A unidirectional communication channel. This is the basic communication
293 method with an accelerator."""
295 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
300 def connect(self, buffer_size: Optional[int] =
None):
301 (supports_host, reason) = self.
type.supports_host
302 if not supports_host:
303 raise TypeError(f
"unsupported type: {reason}")
313 """A unidirectional communication channel from the host to the accelerator."""
315 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
317 self.
cpp_port: cpp.WriteChannelPort = cpp_port
320 valid, reason = self.
type.is_valid(msg)
323 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
324 msg_bytes: bytearray = self.
type.serialize(msg)
328 """Write a typed message to the channel. Attempts to serialize 'msg' to what
329 the accelerator expects, but will fail if the object is not convertible to
335 """Like 'write', but uses the non-blocking tryWrite method of the underlying
336 port. Returns True if the write was successful, False otherwise."""
341 """A unidirectional communication channel from the accelerator to the host."""
343 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
345 self.
cpp_port: cpp.ReadChannelPort = cpp_port
348 """Read a typed message from the channel. Returns a deserialized object of a
349 type defined by the port type."""
352 (msg, leftover) = self.
type.deserialize(buffer)
353 if len(leftover) != 0:
354 raise ValueError(f
"leftover bytes: {leftover}")
359 """A collections of named, unidirectional communication channels."""
363 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
365 if isinstance(cpp_port, cpp.Function):
366 return super().
__new__(FunctionPort)
367 if isinstance(cpp_port, cpp.MMIORegion):
368 return super().
__new__(MMIORegion)
371 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
383 """A specialization of `Future` for ESI messages. Wraps the cpp object and
384 deserializes the result. Hopefully overrides all the methods necessary for
385 proper operation, which is assumed to be not all of them."""
387 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
397 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
401 (msg, leftover) = self.
result_type.deserialize(result_bytes)
402 if len(leftover) != 0:
403 raise ValueError(f
"leftover bytes: {leftover}")
407 raise NotImplementedError(
"add_done_callback is not implemented")
411 """A region of memory-mapped I/O space. This is a collection of named
412 channels, which are either read or read-write. The channels are accessed
413 by name, and can be connected to the host."""
415 def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
421 return self.
region.descriptor
423 def read(self, offset: int) -> bytearray:
424 """Read a value from the MMIO region at the given offset."""
427 def write(self, offset: int, data: bytearray) ->
None:
428 """Write a value to the MMIO region at the given offset."""
433 """A pair of channels which carry the input and output of a function."""
435 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
445 def call(self, **kwargs: Any) -> Future:
446 """Call the function with the given argument and returns a future of the
448 valid, reason = self.
arg_type.is_valid(kwargs)
451 f
"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
452 arg_bytes: bytearray = self.
arg_type.serialize(kwargs)
456 def __call__(self, *args: Any, **kwds: Any) -> Future:
457 return self.
call(*args, **kwds)
Tuple[bool, Optional[str]] is_valid(self, obj)
__init__(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
__init__(self, cpp.BitsType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
__new__(cls, HWModule owner, cpp.BundlePort cpp_port)
WritePort write_port(self, str channel_name)
ReadPort read_port(self, str channel_name)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
__init__(self, cpp.Type cpp_type)
bytearray serialize(self, obj)
Future call(self, **Any kwargs)
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future __call__(self, *Any args, **Any kwds)
__init__(self, cpp.IntegerType cpp_type)
None write(self, int offset, bytearray data)
bytearray read(self, int offset)
__init__(self, HWModule owner, cpp.MMIORegion cpp_port)
cpp.MMIORegionDesc descriptor(self)
Any result(self, Optional[Union[int, float]] timeout=None)
__init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
__init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
connect(self, Optional[int] buffer_size=None)
__init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
__init__(self, cpp.StructType cpp_type)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
bool try_write(self, msg=None)
bytearray __serialize_msg(self, msg=None)
bool write(self, msg=None)
__init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
_get_esi_type(cpp.Type cpp_type)