13 from __future__
import annotations
15 from .
import esiCppAccel
as cpp
17 from typing
import TYPE_CHECKING
19 from .accelerator
import HWModule
21 from concurrent.futures
import Future
22 from typing
import Any, Callable, Dict, List, Optional, Tuple, Type, Union
26 """Get the wrapper class for a C++ type."""
27 for cpp_type_cls, fn
in __esi_mapping.items():
28 if isinstance(cpp_type, cpp_type_cls):
35 __esi_mapping: Dict[Type, Callable] = {
36 cpp.ChannelType:
lambda cpp_type:
_get_esi_type(cpp_type.inner)
47 """Does this type support host communication via Python? Returns either
48 '(True, None)' if it is, or '(False, reason)' if it is not."""
51 return (
False,
"runtime only supports types with multiple of 8 bits")
54 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
55 """Is a Python object compatible with HW type? Returns either '(True,
56 None)' if it is, or '(False, reason)' if it is not."""
57 assert False,
"unimplemented"
61 """Size of this type, in bits. Negative for unbounded types."""
62 assert False,
"unimplemented"
66 """Maximum size of a value of this type, in bytes."""
67 bitwidth = int((self.
bit_widthbit_width + 7) / 8)
73 """Convert a Python object to a bytearray."""
74 assert False,
"unimplemented"
76 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
77 """Convert a bytearray to a Python object. Return the object and the
79 assert False,
"unimplemented"
87 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
89 return (
False, f
"void type cannot must represented by None, not {obj}")
100 def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
102 raise ValueError(f
"void type cannot be represented by {data}")
103 return (
None, data[1:])
106 __esi_mapping[cpp.VoidType] = VoidType
112 self.
cpp_typecpp_type: cpp.BitsType = cpp_type
114 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
115 if not isinstance(obj, (bytearray, bytes, list)):
116 return (
False, f
"invalid type: {type(obj)}")
117 if isinstance(obj, list)
and not all(
118 [isinstance(b, int)
and b.bit_length() <= 8
for b
in obj]):
119 return (
False, f
"list item too large: {obj}")
120 if len(obj) != self.
max_sizemax_size:
121 return (
False, f
"wrong size: {len(obj)}")
128 def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
129 if isinstance(obj, bytearray):
131 if isinstance(obj, bytes)
or isinstance(obj, list):
132 return bytearray(obj)
133 raise ValueError(f
"cannot convert {obj} to bytearray")
135 def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
139 __esi_mapping[cpp.BitsType] = BitsType
145 self.
cpp_typecpp_type: cpp.IntegerType = cpp_type
154 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
155 if not isinstance(obj, int):
156 return (
False, f
"must be an int, not {type(obj)}")
158 return (
False, f
"out of range: {obj}")
162 return f
"uint{self.bit_width}"
165 return bytearray(int.to_bytes(obj, self.
max_sizemax_size,
"little"))
168 return (int.from_bytes(data[0:self.
max_sizemax_size],
169 "little"), data[self.
max_sizemax_size:])
172 __esi_mapping[cpp.UIntType] = UIntType
177 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
178 if not isinstance(obj, int):
179 return (
False, f
"must be an int, not {type(obj)}")
182 return (
False, f
"out of range: {obj}")
185 return (
False, f
"out of range: {obj}")
189 return f
"sint{self.bit_width}"
192 return bytearray(int.to_bytes(obj, self.
max_sizemax_size,
"little", signed=
True))
195 return (int.from_bytes(data[0:self.
max_sizemax_size],
"little",
196 signed=
True), data[self.
max_sizemax_size:])
199 __esi_mapping[cpp.SIntType] = SIntType
206 self.fields: List[Tuple[str, ESIType]] = [
212 widths = [ty.bit_width
for (_, ty)
in self.fields]
213 if any([w < 0
for w
in widths]):
217 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
219 if not isinstance(obj, dict):
222 for (fname, ftype)
in self.fields:
224 return (
False, f
"missing field '{fname}'")
225 fvalid, reason = ftype.is_valid(obj[fname])
227 return (
False, f
"invalid field '{fname}': {reason}")
229 if fields_count != len(obj):
230 return (
False,
"missing fields")
235 for (fname, ftype)
in reversed(self.fields):
237 ret.extend(ftype.serialize(fval))
240 def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
242 for (fname, ftype)
in reversed(self.fields):
243 (fval, data) = ftype.deserialize(data)
248 __esi_mapping[cpp.StructType] = StructType
262 def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
263 if not isinstance(obj, list):
264 return (
False, f
"must be a list, not {type(obj)}")
265 if len(obj) != self.
sizesize:
266 return (
False, f
"wrong size: expected {self.size} not {len(obj)}")
267 for (idx, e)
in enumerate(obj):
270 return (
False, f
"invalid element {idx}: {reason}")
275 for e
in reversed(lst):
279 def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
281 for _
in range(self.
sizesize):
288 __esi_mapping[cpp.ArrayType] = ArrayType
292 """A unidirectional communication channel. This is the basic communication
293 method with an accelerator."""
295 def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
300 def connect(self, buffer_size: Optional[int] =
None):
301 (supports_host, reason) = self.
typetype.supports_host
302 if not supports_host:
303 raise TypeError(f
"unsupported type: {reason}")
313 """A unidirectional communication channel from the host to the accelerator."""
315 def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
317 self.
cpp_portcpp_port: cpp.WriteChannelPort = cpp_port
320 valid, reason = self.
typetype.is_valid(msg)
323 f
"'{msg}' cannot be converted to '{self.type}': {reason}")
324 msg_bytes: bytearray = self.
typetype.serialize(msg)
328 """Write a typed message to the channel. Attempts to serialize 'msg' to what
329 the accelerator expects, but will fail if the object is not convertible to
335 """Like 'write', but uses the non-blocking tryWrite method of the underlying
336 port. Returns True if the write was successful, False otherwise."""
341 """A unidirectional communication channel from the accelerator to the host."""
343 def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
345 self.
cpp_portcpp_port: cpp.ReadChannelPort = cpp_port
348 """Read a typed message from the channel. Returns a deserialized object of a
349 type defined by the port type."""
352 (msg, leftover) = self.
typetype.deserialize(buffer)
353 if len(leftover) != 0:
354 raise ValueError(f
"leftover bytes: {leftover}")
359 """A collections of named, unidirectional communication channels."""
363 def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
365 if isinstance(cpp_port, cpp.Function):
366 return super().
__new__(FunctionPort)
367 if isinstance(cpp_port, cpp.MMIORegion):
368 return super().
__new__(MMIORegion)
371 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
383 """A specialization of `Future` for ESI messages. Wraps the cpp object and
384 deserializes the result. Hopefully overrides all the methods necessary for
385 proper operation, which is assumed to be not all of them."""
387 def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
397 def result(self, timeout: Optional[Union[int, float]] =
None) -> Any:
401 (msg, leftover) = self.
result_typeresult_type.deserialize(result_bytes)
402 if len(leftover) != 0:
403 raise ValueError(f
"leftover bytes: {leftover}")
407 raise NotImplementedError(
"add_done_callback is not implemented")
411 """A region of memory-mapped I/O space. This is a collection of named
412 channels, which are either read or read-write. The channels are accessed
413 by name, and can be connected to the host."""
415 def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
421 return self.
regionregion.descriptor
423 def read(self, offset: int) -> bytearray:
424 """Read a value from the MMIO region at the given offset."""
427 def write(self, offset: int, data: bytearray) ->
None:
428 """Write a value to the MMIO region at the given offset."""
433 """A pair of channels which carry the input and output of a function."""
435 def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
445 def call(self, **kwargs: Any) -> Future:
446 """Call the function with the given argument and returns a future of the
448 valid, reason = self.
arg_typearg_type.is_valid(kwargs)
451 f
"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
452 arg_bytes: bytearray = self.
arg_typearg_type.serialize(kwargs)
456 def __call__(self, *args: Any, **kwds: Any) -> Future:
457 return self.
callcall(*args, **kwds)
WriteChannelPort & getWrite(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
ReadChannelPort & getRead(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
Tuple[bool, Optional[str]] is_valid(self, obj)
def __init__(self, cpp.ArrayType cpp_type)
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
bytearray serialize(self, list lst)
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
def __init__(self, cpp.BitsType cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
WritePort write_port(self, str channel_name)
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
ReadPort read_port(self, str channel_name)
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Tuple[bool, Optional[str]] supports_host(self)
Tuple[object, bytearray] deserialize(self, bytearray data)
def __init__(self, cpp.Type cpp_type)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Future call(self, **Any kwargs)
Future __call__(self, *Any args, **Any kwds)
def __init__(self, cpp.IntegerType cpp_type)
None write(self, int offset, bytearray data)
bytearray read(self, int offset)
def __init__(self, HWModule owner, cpp.MMIORegion cpp_port)
cpp.MMIORegionDesc descriptor(self)
Any result(self, Optional[Union[int, float]] timeout=None)
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
None add_done_callback(self, Callable[[Future], object] fn)
def connect(self, Optional[int] buffer_size=None)
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
bytearray serialize(self, obj)
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
def __init__(self, cpp.StructType cpp_type)
Tuple[int, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, int obj)
Tuple[object, bytearray] deserialize(self, bytearray data)
Tuple[bool, Optional[str]] is_valid(self, obj)
bytearray serialize(self, obj)
bool try_write(self, msg=None)
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
bytearray __serialize_msg(self, msg=None)
bool write(self, msg=None)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
def _get_esi_type(cpp.Type cpp_type)