13 from __future__
import annotations
15 from .
import esiCppAccel
as cpp
17 from typing
import TYPE_CHECKING
19 from .accelerator
import HWModule
21 from concurrent.futures
import Future
22 from typing
import Any, Callable, Dict, List, Optional, Tuple, Type, Union
26 """Get the wrapper class for a C++ type."""
27 for cpp_type_cls, fn
in __esi_mapping.items():
28 if isinstance(cpp_type, cpp_type_cls):
35 __esi_mapping: Dict[Type, Callable] = {
36 cpp.ChannelType:
lambda cpp_type:
_get_esi_type(cpp_type.inner)
47 """Does this type support host communication via Python? Returns either
48 '(True, None)' if it is, or '(False, reason)' if it is not."""
50 print(f
"supports_host: {self.cpp_type} {type(self)}")
52 return (
False,
"runtime only supports types with multiple of 8 bits")
55 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
56 """Is a Python object compatible with HW type? Returns either '(True,
57 None)' if it is, or '(False, reason)' if it is not."""
58 assert False,
"unimplemented"
62 """Size of this type, in bits. Negative for unbounded types."""
63 assert False,
"unimplemented"
67 """Maximum size of a value of this type, in bytes."""
68 bitwidth = int((self.
bit_widthbit_width + 7) / 8)
74 """Convert a Python object to a bytearray."""
75 assert False,
"unimplemented"
77 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
78 """Convert a bytearray to a Python object. Return the object and the
80 assert False,
"unimplemented"
88 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
90 return (
False, f
"void type cannot must represented by None, not {obj}")
101 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
103 raise ValueError(f
"void type cannot be represented by {data}")
104 return (
None, data[1:])
107 __esi_mapping[cpp.VoidType] = VoidType
113 self.
cpp_typecpp_type: cpp.BitsType = cpp_type
115 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
116 if not isinstance(obj, (bytearray, bytes, list)):
117 return (
False, f
"invalid type: {type(obj)}")
118 if isinstance(obj, list)
and not all(
119 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
120 return (
False, f
"list item too large: {obj}")
121 if len(obj) != self.
max_sizemax_size:
122 return (
False, f
"wrong size: {len(obj)}")
129 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
130 if isinstance(obj, bytearray):
132 if isinstance(obj, bytes)
or isinstance(obj, list):
133 return bytearray(obj)
134 raise ValueError(f
"cannot convert {obj} to bytearray")
136 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
140 __esi_mapping[cpp.BitsType] = BitsType
146 self.
cpp_typecpp_type: cpp.IntegerType = cpp_type
155 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
156 if not isinstance(obj, int):
157 return (
False, f
"must be an int, not {type(obj)}")
159 return (
False, f
"out of range: {obj}")
163 return f
"uint{self.bit_width}"
166 return bytearray(int.to_bytes(obj, self.
max_sizemax_size,
"little"))
169 return (int.from_bytes(data[0:self.
max_sizemax_size],
170 "little"), data[self.
max_sizemax_size:])
173 __esi_mapping[cpp.UIntType] = UIntType
178 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
179 if not isinstance(obj, int):
180 return (
False, f
"must be an int, not {type(obj)}")
183 return (
False, f
"out of range: {obj}")
186 return (
False, f
"out of range: {obj}")
190 return f
"sint{self.bit_width}"
193 return bytearray(int.to_bytes(obj, self.
max_sizemax_size,
"little", signed=
True))
196 return (int.from_bytes(data[0:self.
max_sizemax_size],
"little",
197 signed=
True), data[self.
max_sizemax_size:])
200 __esi_mapping[cpp.SIntType] = SIntType
207 self.fields: List[Tuple[str, ESIType]] = [
213 widths = [ty.bit_width
for (_, ty)
in self.fields]
214 if any([w < 0
for w
in widths]):
218 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
220 if not isinstance(obj, dict):
223 for (fname, ftype)
in self.fields:
225 return (
False, f
"missing field '{fname}'")
226 fvalid, reason = ftype.is_valid(obj[fname])
228 return (
False, f
"invalid field '{fname}': {reason}")
230 if fields_count != len(obj):
231 return (
False,
"missing fields")
236 for (fname, ftype)
in reversed(self.fields):
238 ret.extend(ftype.serialize(fval))
241 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
243 for (fname, ftype)
in reversed(self.fields):
244 (fval, data) = ftype.deserialize(data)
249 __esi_mapping[cpp.StructType] = StructType
263 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
264 if not isinstance(obj, list):
265 return (
False, f
"must be a list, not {type(obj)}")
266 if len(obj) != self.
sizesize:
267 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
268 for (idx, e)
in enumerate(obj):
271 return (
False, f
"invalid element {idx}: {reason}")
276 for e
in reversed(lst):
280 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
282 for _
in range(self.
sizesize):
289 __esi_mapping[cpp.ArrayType] = ArrayType
293 """A unidirectional communication channel. This is the basic communication
294 method with an accelerator."""
296 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
301 def connect(self, buffer_size: Optional[int] =
None):
302 (supports_host, reason) = self.
typetype.supports_host
303 if not supports_host:
304 raise TypeError(f
"unsupported type: {reason}")
314 """A unidirectional communication channel from the host to the accelerator."""
316 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
318 self.
cpp_portcpp_port: cpp.WriteChannelPort = cpp_port
321 valid, reason = self.
typetype.is_valid(msg)
324 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
325 msg_bytes: bytearray = self.
typetype.serialize(msg)
329 """Write a typed message to the channel. Attempts to serialize 'msg' to what
330 the accelerator expects, but will fail if the object is not convertible to
336 """Like 'write', but uses the non-blocking tryWrite method of the underlying
337 port. Returns True if the write was successful, False otherwise."""
342 """A unidirectional communication channel from the accelerator to the host."""
344 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
346 self.
cpp_portcpp_port: cpp.ReadChannelPort = cpp_port
349 """Read a typed message from the channel. Returns a deserialized object of a
350 type defined by the port type."""
353 (msg, leftover) = self.
typetype.deserialize(buffer)
354 if len(leftover) != 0:
355 raise ValueError(f
"leftover bytes: {leftover}")
360 """A collections of named, unidirectional communication channels."""
364 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
366 if isinstance(cpp_port, cpp.Function):
367 return super().
__new__(FunctionPort)
368 if isinstance(cpp_port, cpp.MMIORegion):
369 return super().
__new__(MMIORegion)
372 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
384 """A specialization of `Future` for ESI messages. Wraps the cpp object and
385 deserializes the result. Hopefully overrides all the methods necessary for
386 proper operation, which is assumed to be not all of them."""
388 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
398 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
402 (msg, leftover) = self.
result_typeresult_type.deserialize(result_bytes)
403 if len(leftover) != 0:
404 raise ValueError(f
"leftover bytes: {leftover}")
408 raise NotImplementedError(
"add_done_callback is not implemented")
412 """A region of memory-mapped I/O space. This is a collection of named
413 channels, which are either read or read-write. The channels are accessed
414 by name, and can be connected to the host."""
416 def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
422 return self.
regionregion.descriptor
424 def read(self, offset: int) -> bytearray:
425 """Read a value from the MMIO region at the given offset."""
428 def write(self, offset: int, data: bytearray) ->
None:
429 """Write a value to the MMIO region at the given offset."""
434 """A pair of channels which carry the input and output of a function."""
436 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
446 def call(self, **kwargs: Any) -> Future:
447 """Call the function with the given argument and returns a future of the
449 valid, reason = self.
arg_typearg_type.is_valid(kwargs)
452 f
"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
453 arg_bytes: bytearray = self.
arg_typearg_type.serialize(kwargs)
457 def __call__(self, *args: Any, **kwds: Any) -> Future:
458 return self.
callcall(*args, **kwds)
WriteChannelPort & getWrite(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
ReadChannelPort & getRead(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
Tuple[bool, Optional[str]] is_valid(self, obj)
def __init__(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
def __init__(self, cpp.BitsType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
WritePort write_port(self, str channel_name)
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
ReadPort read_port(self, str channel_name)
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
def __init__(self, cpp.Type cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future call(self, **Any kwargs)
Future __call__(self, *Any args, **Any kwds)
def __init__(self, cpp.IntegerType cpp_type)
None write(self, int offset, bytearray data)
bytearray read(self, int offset)
def __init__(self, HWModule owner, cpp.MMIORegion cpp_port)
cpp.MMIORegionDesc descriptor(self)
Any result(self, Optional[Union[int, float]] timeout=None)
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
def connect(self, Optional[int] buffer_size=None)
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
def __init__(self, cpp.StructType cpp_type)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
bool try_write(self, msg=None)
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
bytearray __serialize_msg(self, msg=None)
bool write(self, msg=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
def _get_esi_type(cpp.Type cpp_type)