13 from __future__
import annotations
15 from .
import esiCppAccel
as cpp
17 from typing
import TYPE_CHECKING
19 from .accelerator
import HWModule
21 from concurrent.futures
import Future
22 from typing
import Any, Callable, Dict, List, Optional, Tuple, Type, Union
26 """Get the wrapper class for a C++ type."""
27 for cpp_type_cls, fn
in __esi_mapping.items():
28 if isinstance(cpp_type, cpp_type_cls):
35 __esi_mapping: Dict[Type, Callable] = {
36 cpp.ChannelType:
lambda cpp_type:
_get_esi_type(cpp_type.inner)
47 """Does this type support host communication via Python? Returns either
48 '(True, None)' if it is, or '(False, reason)' if it is not."""
51 return (
False,
"runtime only supports types with multiple of 8 bits")
54 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
55 """Is a Python object compatible with HW type? Returns either '(True,
56 None)' if it is, or '(False, reason)' if it is not."""
57 assert False,
"unimplemented"
61 """Size of this type, in bits. Negative for unbounded types."""
62 assert False,
"unimplemented"
66 """Maximum size of a value of this type, in bytes."""
67 bitwidth = int((self.
bit_widthbit_width + 7) / 8)
73 """Convert a Python object to a bytearray."""
74 assert False,
"unimplemented"
76 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
77 """Convert a bytearray to a Python object. Return the object and the
79 assert False,
"unimplemented"
87 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
89 return (
False, f
"void type cannot must represented by None, not {obj}")
100 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
102 raise ValueError(f
"void type cannot be represented by {data}")
103 return (
None, data[1:])
106 __esi_mapping[cpp.VoidType] = VoidType
112 self.
cpp_typecpp_type: cpp.BitsType = cpp_type
114 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
115 if not isinstance(obj, (bytearray, bytes, list)):
116 return (
False, f
"invalid type: {type(obj)}")
117 if isinstance(obj, list)
and not all(
118 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
119 return (
False, f
"list item too large: {obj}")
120 if len(obj) != self.
max_sizemax_size:
121 return (
False, f
"wrong size: {len(obj)}")
128 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
129 if isinstance(obj, bytearray):
131 if isinstance(obj, bytes)
or isinstance(obj, list):
132 return bytearray(obj)
133 raise ValueError(f
"cannot convert {obj} to bytearray")
135 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
139 __esi_mapping[cpp.BitsType] = BitsType
145 self.
cpp_typecpp_type: cpp.IntegerType = cpp_type
154 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
155 if not isinstance(obj, int):
156 return (
False, f
"must be an int, not {type(obj)}")
158 return (
False, f
"out of range: {obj}")
162 return f
"uint{self.bit_width}"
165 return bytearray(int.to_bytes(obj, self.
max_sizemax_size,
"little"))
168 return (int.from_bytes(data[0:self.
max_sizemax_size],
169 "little"), data[self.
max_sizemax_size:])
172 __esi_mapping[cpp.UIntType] = UIntType
177 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
178 if not isinstance(obj, int):
179 return (
False, f
"must be an int, not {type(obj)}")
182 return (
False, f
"out of range: {obj}")
185 return (
False, f
"out of range: {obj}")
189 return f
"sint{self.bit_width}"
192 return bytearray(int.to_bytes(obj, self.
max_sizemax_size,
"little", signed=
True))
195 return (int.from_bytes(data[0:self.
max_sizemax_size],
"little",
196 signed=
True), data[self.
max_sizemax_size:])
199 __esi_mapping[cpp.SIntType] = SIntType
206 self.fields: List[Tuple[str, ESIType]] = [
212 widths = [ty.bit_width
for (_, ty)
in self.fields]
213 if any([w < 0
for w
in widths]):
217 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
219 if not isinstance(obj, dict):
222 for (fname, ftype)
in self.fields:
224 return (
False, f
"missing field '{fname}'")
225 fvalid, reason = ftype.is_valid(obj[fname])
227 return (
False, f
"invalid field '{fname}': {reason}")
229 if fields_count != len(obj):
230 return (
False,
"missing fields")
235 for (fname, ftype)
in reversed(self.fields):
237 ret.extend(ftype.serialize(fval))
240 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
242 for (fname, ftype)
in reversed(self.fields):
243 (fval, data) = ftype.deserialize(data)
248 __esi_mapping[cpp.StructType] = StructType
262 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
263 if not isinstance(obj, list):
264 return (
False, f
"must be a list, not {type(obj)}")
265 if len(obj) != self.
sizesize:
266 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
267 for (idx, e)
in enumerate(obj):
270 return (
False, f
"invalid element {idx}: {reason}")
275 for e
in reversed(lst):
279 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
281 for _
in range(self.
sizesize):
288 __esi_mapping[cpp.ArrayType] = ArrayType
292 """A unidirectional communication channel. This is the basic communication
293 method with an accelerator."""
295 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
299 (supports_host, reason) = self.
typetype.supports_host
300 if not supports_host:
301 raise TypeError(f
"unsupported type: {reason}")
309 """A unidirectional communication channel from the host to the accelerator."""
311 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
313 self.
cpp_portcpp_port: cpp.WriteChannelPort = cpp_port
316 """Write a typed message to the channel. Attempts to serialize 'msg' to what
317 the accelerator expects, but will fail if the object is not convertible to
320 valid, reason = self.
typetype.is_valid(msg)
323 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
324 msg_bytes: bytearray = self.
typetype.serialize(msg)
330 """A unidirectional communication channel from the accelerator to the host."""
332 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
334 self.
cpp_portcpp_port: cpp.ReadChannelPort = cpp_port
336 def read(self) -> Tuple[bool, Optional[object]]:
337 """Read a typed message from the channel. Returns a deserialized object of a
338 type defined by the port type."""
343 (msg, leftover) = self.
typetype.deserialize(buffer)
344 if len(leftover) != 0:
345 raise ValueError(f
"leftover bytes: {leftover}")
350 """A collections of named, unidirectional communication channels."""
354 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
356 if isinstance(cpp_port, cpp.Function):
357 return super().
__new__(FunctionPort)
360 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
372 """A specialization of `Future` for ESI messages. Wraps the cpp object and
373 deserializes the result. Hopefully overrides all the methods necessary for
374 proper operation, which is assumed to be not all of them."""
376 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
386 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
390 (msg, leftover) = self.
result_typeresult_type.deserialize(result_bytes)
391 if len(leftover) != 0:
392 raise ValueError(f
"leftover bytes: {leftover}")
396 raise NotImplementedError(
"add_done_callback is not implemented")
400 """A pair of channels which carry the input and output of a function."""
402 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
412 def call(self, **kwargs: Any) -> Future:
413 """Call the function with the given argument and returns a future of the
415 valid, reason = self.
arg_typearg_type.is_valid(kwargs)
418 f
"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
419 arg_bytes: bytearray = self.
arg_typearg_type.serialize(kwargs)
423 def __call__(self, *args: Any, **kwds: Any) -> Future:
424 return self.
callcall(*args, **kwds)
Tuple[bool, Optional[str]] is_valid(self, obj)
def __init__(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
def __init__(self, cpp.BitsType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
WritePort write_port(self, str channel_name)
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
ReadPort read_port(self, str channel_name)
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
def __init__(self, cpp.Type cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future call(self, **Any kwargs)
Future __call__(self, *Any args, **Any kwds)
def __init__(self, cpp.IntegerType cpp_type)
Any result(self, Optional[Union[int, float]] timeout=None)
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[bool, Optional[object]] read(self)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
def __init__(self, cpp.StructType cpp_type)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
bool write(self, msg=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
def _get_esi_type(cpp.Type cpp_type)