CIRCT  20.0.0git
types.py
Go to the documentation of this file.
1 # ===-----------------------------------------------------------------------===#
2 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3 # See https://llvm.org/LICENSE.txt for license information.
4 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 # ===-----------------------------------------------------------------------===#
6 #
7 # The structure of the Python classes and hierarchy roughly mirrors the C++
8 # side, but wraps the C++ objects. The wrapper classes sometimes add convenience
9 # functionality and serve to return wrapped versions of the returned objects.
10 #
11 # ===-----------------------------------------------------------------------===#
12 
13 from __future__ import annotations
14 
15 from . import esiCppAccel as cpp
16 
17 from typing import TYPE_CHECKING
18 if TYPE_CHECKING:
19  from .accelerator import HWModule
20 
21 from concurrent.futures import Future
22 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
23 
24 
def _get_esi_type(cpp_type: cpp.Type):
    """Wrap a C++ type object in the matching Python `ESIType` subclass.

    `__esi_mapping` is scanned in insertion order; the first registered C++
    class that `cpp_type` is an instance of supplies the constructor. Types
    with no registered entry fall back to the generic `ESIType` wrapper.
    """
    factory = next(
        (build for cls, build in __esi_mapping.items()
         if isinstance(cpp_type, cls)),
        None,
    )
    if factory is None:
        return ESIType(cpp_type)
    return factory(cpp_type)
31 
32 
# Registry consulted by `_get_esi_type`: maps a C++ type class to a callable
# that builds the corresponding Python wrapper. A channel type resolves to the
# wrapper of the type it carries (`.inner`). More entries are registered
# below, right after each wrapper class is defined.
__esi_mapping: Dict[Type, Callable] = dict()
__esi_mapping[cpp.ChannelType] = lambda cpp_type: _get_esi_type(cpp_type.inner)
38 
39 
class ESIType:
    """Python-side wrapper around a C++ ESI type.

    The base class stores the wrapped C++ object and derives `supports_host`
    and `max_size` from `bit_width`. Concrete subclasses provide `bit_width`,
    `is_valid`, `serialize` and `deserialize`.
    """

    def __init__(self, cpp_type: cpp.Type):
        self.cpp_type = cpp_type

    @property
    def supports_host(self) -> Tuple[bool, Optional[str]]:
        """Does this type support host communication via Python? Returns either
        '(True, None)' if it does, or '(False, reason)' if it does not."""
        if self.bit_width % 8 != 0:
            return (False, "runtime only supports types with multiple of 8 bits")
        return (True, None)

    def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
        """Is a Python object compatible with HW type? Returns either '(True,
        None)' if it is, or '(False, reason)' if it is not.

        Raises NotImplementedError unless overridden by a subclass.
        """
        # `raise` rather than `assert False`: asserts are stripped under -O,
        # which would make this silently return None.
        raise NotImplementedError("unimplemented")

    @property
    def bit_width(self) -> int:
        """Size of this type, in bits. Negative for unbounded types."""
        raise NotImplementedError("unimplemented")

    @property
    def max_size(self) -> int:
        """Maximum size of a value of this type, in bytes. Negative for
        unbounded types, matching `bit_width`."""
        bit_width = self.bit_width
        if bit_width < 0:
            # Propagate the "unbounded" marker; rounding it up would yield 0.
            return bit_width
        # Integer ceil-division: exact for any width, unlike float division.
        return (bit_width + 7) // 8

    def serialize(self, obj) -> bytearray:
        """Convert a Python object to a bytearray."""
        raise NotImplementedError("unimplemented")

    def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
        """Convert a bytearray to a Python object. Return the object and the
        leftover bytes."""
        raise NotImplementedError("unimplemented")

    def __str__(self) -> str:
        return str(self.cpp_type)
84 
85 
87 
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check that `obj` can stand for a void value.

    Only `None` is accepted. Returns '(True, None)' on success or
    '(False, reason)' otherwise.
    """
    if obj is not None:
        return (False, f"void type must be represented by None, not {obj}")
    return (True, None)
92 
@property
def bit_width(self) -> int:
    """Width of a void value on the wire: one byte (8 bits)."""
    one_byte_in_bits = 8
    return one_byte_in_bits
96 
def serialize(self, obj) -> bytearray:
    """Encode a void value; `obj` itself is ignored."""
    # Convention: void travels as a single zero byte.
    return bytearray(1)
100 
def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
    """Consume the single placeholder byte of a void value.

    Returns `(None, rest)`, where `rest` is `data` minus its first byte.
    Raises ValueError when `data` is empty.
    """
    if len(data) < 1:
        raise ValueError(f"void type cannot be represented by {data}")
    rest = data[1:]
    return (None, rest)
105 
106 
# Register the void wrapper in the C++ -> Python type registry.
__esi_mapping.update({cpp.VoidType: VoidType})
108 
109 
111 
def __init__(self, cpp_type: cpp.BitsType):
    """Wrap a C++ `BitsType`; its `width` is what `bit_width` reports."""
    self.cpp_type = cpp_type
114 
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check that `obj` is a raw byte buffer of exactly `max_size` bytes.

    Accepts `bytearray`, `bytes`, or a list of ints. Each list item must fit
    in an unsigned byte (0..255), because `serialize` hands lists to
    `bytearray()`, which rejects anything else. Returns '(True, None)' or
    '(False, reason)'.
    """
    if not isinstance(obj, (bytearray, bytes, list)):
        return (False, f"invalid type: {type(obj)}")
    # `bit_length() <= 8` alone would let negative ints through.
    if isinstance(obj, list) and not all(
            isinstance(b, int) and 0 <= b <= 0xFF for b in obj):
        return (False, f"list item out of byte range (0..255): {obj}")
    if len(obj) != self.max_size:
        return (False, f"wrong size: {len(obj)}")
    return (True, None)
124 
@property
def bit_width(self) -> int:
    """Bit width as reported by the wrapped C++ type's `width`."""
    wrapped = self.cpp_type
    return wrapped.width
128 
def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
    """Return `obj` as a bytearray.

    A `bytearray` is returned as-is (same object). `bytes` and lists of ints
    are copied into a new bytearray. Anything else raises ValueError.
    """
    if isinstance(obj, (bytes, list)):
        return bytearray(obj)
    if not isinstance(obj, bytearray):
        raise ValueError(f"cannot convert {obj} to bytearray")
    return obj
135 
def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
    """Split off the first `max_size` bytes as the value; return
    `(value, rest)`."""
    cut = self.max_size
    return (data[:cut], data[cut:])
138 
139 
# Register the raw-bits wrapper in the C++ -> Python type registry.
__esi_mapping.update({cpp.BitsType: BitsType})
141 
142 
144 
def __init__(self, cpp_type: cpp.IntegerType):
    """Wrap a C++ integer type; its `width` is what `bit_width` reports."""
    self.cpp_type = cpp_type
147 
@property
def bit_width(self) -> int:
    """Integer width in bits, read from the wrapped C++ type."""
    wrapped = self.cpp_type
    return wrapped.width
151 
152 
154 
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check that `obj` is a non-negative int fitting in `bit_width` bits.

    Returns '(True, None)' or '(False, reason)'.
    """
    if isinstance(obj, int):
        in_range = obj >= 0 and obj.bit_length() <= self.bit_width
        return (True, None) if in_range else (False, f"out of range: {obj}")
    return (False, f"must be an int, not {type(obj)}")
161 
def __str__(self) -> str:
    """Render as `uint<width>`, e.g. `uint8`."""
    return "uint" + str(self.bit_width)
164 
def serialize(self, obj: int) -> bytearray:
    """Encode `obj` as an unsigned little-endian integer of `max_size` bytes."""
    size = self.max_size
    encoded = int.to_bytes(obj, size, "little")
    return bytearray(encoded)
167 
def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
    """Decode an unsigned little-endian integer from the first `max_size`
    bytes of `data`; return it with the remaining bytes."""
    size = self.max_size
    head, tail = data[:size], data[size:]
    return (int.from_bytes(head, "little"), tail)
171 
172 
# Register the unsigned-integer wrapper in the C++ -> Python type registry.
__esi_mapping.update({cpp.UIntType: UIntType})
174 
175 
177 
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check that `obj` is an int representable in two's complement with
    `bit_width` bits, i.e. in [-2**(w-1), 2**(w-1) - 1].

    Returns '(True, None)' or '(False, reason)'.
    """
    if not isinstance(obj, int):
        return (False, f"must be an int, not {type(obj)}")
    half_range = 2**(self.bit_width - 1)
    # Check both bounds; values past the positive bound would otherwise only
    # fail later, inside `int.to_bytes`.
    if not -half_range <= obj < half_range:
        return (False, f"out of range: {obj}")
    return (True, None)
188 
def __str__(self) -> str:
    """Render as `sint<width>`, e.g. `sint16`."""
    return "sint" + str(self.bit_width)
191 
def serialize(self, obj: int) -> bytearray:
    """Encode `obj` as a signed (two's complement) little-endian integer of
    `max_size` bytes."""
    size = self.max_size
    encoded = int.to_bytes(obj, size, "little", signed=True)
    return bytearray(encoded)
194 
def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
    """Decode a signed little-endian integer from the first `max_size` bytes
    of `data`; return it with the remaining bytes."""
    size = self.max_size
    head, tail = data[:size], data[size:]
    return (int.from_bytes(head, "little", signed=True), tail)
198 
199 
# Register the signed-integer wrapper in the C++ -> Python type registry.
__esi_mapping.update({cpp.SIntType: SIntType})
201 
202 
204 
def __init__(self, cpp_type: cpp.StructType):
    """Wrap a C++ struct type and its fields.

    `fields` holds `(name, ESIType)` pairs in the order the C++ type lists
    them; each field's C++ type is wrapped via `_get_esi_type`.
    """
    self.cpp_type = cpp_type
    wrapped_fields: List[Tuple[str, ESIType]] = []
    for field_name, field_cpp_type in cpp_type.fields:
        wrapped_fields.append((field_name, _get_esi_type(field_cpp_type)))
    self.fields = wrapped_fields
210 
@property
def bit_width(self) -> int:
    """Sum of the field widths, or -1 if any field is unbounded (negative)."""
    total = 0
    unbounded = False
    for _, field_type in self.fields:
        width = field_type.bit_width
        if width < 0:
            unbounded = True
        total += width
    return -1 if unbounded else total
217 
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check that `obj` provides exactly this struct's fields, each valid.

    `obj` may be a dict keyed by field name or an object whose attribute
    dict (`vars(obj)`) holds the fields. Returns '(True, None)' or
    '(False, reason)'.
    """
    if not isinstance(obj, dict):
        try:
            obj = vars(obj)
        except TypeError:
            # e.g. ints or __slots__ objects: no attribute dict to inspect.
            return (False,
                    f"must be a dict or an object with fields, not {type(obj)}")
    for fname, ftype in self.fields:
        if fname not in obj:
            return (False, f"missing field '{fname}'")
        fvalid, reason = ftype.is_valid(obj[fname])
        if not fvalid:
            return (False, f"invalid field '{fname}': {reason}")
    # Every declared field is present, so any remaining keys are extras.
    extra = set(obj) - {fname for fname, _ in self.fields}
    if extra:
        return (False, f"unexpected fields: {sorted(extra, key=str)}")
    return (True, None)
233 
def serialize(self, obj) -> bytearray:
    """Serialize a struct value.

    `obj` may be a dict keyed by field name or an object whose attributes
    hold the fields. `is_valid` accepts both forms, so both must serialize.
    Fields are emitted in reverse declaration order. Presumably this matches
    the hardware layout where the first field is most significant — confirm.
    """
    values = obj if isinstance(obj, dict) else vars(obj)
    ret = bytearray()
    for fname, ftype in reversed(self.fields):
        ret.extend(ftype.serialize(values[fname]))
    return ret
240 
def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
    """Decode a struct value, consuming fields in reverse declaration order
    (the order `serialize` writes them). Returns the field dict and the
    leftover bytes."""
    values: Dict[str, Any] = {}
    remaining = data
    for name, field_type in self.fields[::-1]:
        values[name], remaining = field_type.deserialize(remaining)
    return (values, remaining)
247 
248 
# Register the struct wrapper in the C++ -> Python type registry.
__esi_mapping.update({cpp.StructType: StructType})
250 
251 
253 
def __init__(self, cpp_type: cpp.ArrayType):
    """Wrap a C++ array type: `size` elements of type `element_type`."""
    self.cpp_type = cpp_type
    element_cpp_type = cpp_type.element
    self.element_type = _get_esi_type(element_cpp_type)
    self.size = cpp_type.size
258 
@property
def bit_width(self) -> int:
    """Total width: element width times element count."""
    per_element = self.element_type.bit_width
    return per_element * self.size
262 
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check that `obj` is a list of exactly `size` valid elements.

    Returns '(True, None)' or '(False, reason)'; the reason names the first
    offending element index.
    """
    if not isinstance(obj, list):
        return (False, f"must be a list, not {type(obj)}")
    if len(obj) != self.size:
        return (False, f"wrong size: expected {self.size} not {len(obj)}")
    for position, element in enumerate(obj):
        element_ok, element_reason = self.element_type.is_valid(element)
        if element_ok:
            continue
        return (False, f"invalid element {position}: {element_reason}")
    return (True, None)
273 
def serialize(self, lst: list) -> bytearray:
    """Serialize the elements back-to-front (last element's bytes first)."""
    out = bytearray()
    for element in reversed(lst):
        chunk = self.element_type.serialize(element)
        out.extend(chunk)
    return out
279 
def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
    """Decode `size` elements (stored last-first, as `serialize` writes them)
    and return them in index order along with the leftover bytes."""
    decoded: List[Any] = []
    remaining = data
    for _ in range(self.size):
        element, remaining = self.element_type.deserialize(remaining)
        decoded.append(element)
    return (decoded[::-1], remaining)
287 
288 
# Register the array wrapper in the C++ -> Python type registry.
__esi_mapping.update({cpp.ArrayType: ArrayType})
290 
291 
class Port:
    """A unidirectional communication channel. This is the basic communication
    method with an accelerator.

    Wraps a C++ channel port together with the Python wrapper of its message
    type.
    """

    def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
        self.owner = owner
        self.cpp_port = cpp_port
        self.type = _get_esi_type(cpp_port.type)

    def connect(self, buffer_size: Optional[int] = None):
        """Connect the underlying C++ port, passing `buffer_size` through.

        Raises TypeError if the message type reports it cannot be handled
        from Python (`supports_host`). Returns `self` so calls can be chained.
        """
        host_ok, why_not = self.type.supports_host
        if host_ok:
            self.cpp_port.connect(buffer_size)
            return self
        raise TypeError(f"unsupported type: {why_not}")

    def disconnect(self):
        """Disconnect the underlying C++ port."""
        self.cpp_port.disconnect()
311 
312 
314  """A unidirectional communication channel from the host to the accelerator."""
315 
def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
    """Bind to a host-to-accelerator C++ port."""
    super().__init__(owner, cpp_port)
    # Re-bind with the narrower annotation so the write API is visible to
    # type checkers.
    self.cpp_port: cpp.WriteChannelPort = cpp_port
319 
def __serialize_msg(self, msg=None) -> bytearray:
    """Validate `msg` against the port type and serialize it.

    Raises ValueError (naming the message, the port type and the reason) if
    the type rejects `msg`.
    """
    msg_ok, why_not = self.type.is_valid(msg)
    if msg_ok:
        return self.type.serialize(msg)
    raise ValueError(
        f"'{msg}' cannot be converted to '{self.type}': {why_not}")
327 
def write(self, msg=None) -> bool:
    """Blocking write of a typed message.

    `msg` is first validated and serialized for the port type (ValueError
    if it does not fit), then handed to the C++ port. Always returns True.
    """
    payload = self.__serialize_msg(msg)
    self.cpp_port.write(payload)
    return True
334 
def try_write(self, msg=None) -> bool:
    """Non-blocking variant of `write`.

    Serializes `msg` like `write`, then returns whatever the C++ port's
    `tryWrite` reports (True if the write was accepted).
    """
    payload = self.__serialize_msg(msg)
    accepted = self.cpp_port.tryWrite(payload)
    return accepted
339 
340 
class ReadPort(Port):
    """A unidirectional communication channel from the accelerator to the
    host, returning deserialized messages."""

    def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
        super().__init__(owner, cpp_port)
        # Narrower annotation: this port exposes the read API.
        self.cpp_port: cpp.ReadChannelPort = cpp_port

    def read(self) -> object:
        """Read one message and decode it with the port type.

        Raises ValueError if decoding leaves unconsumed bytes.
        """
        raw = self.cpp_port.read()
        msg, extra = self.type.deserialize(raw)
        if len(extra) > 0:
            raise ValueError(f"leftover bytes: {extra}")
        return msg
357 
358 
"""A collection of named, unidirectional communication channels."""
361 
# Choose the Python class to instantiate from the kind of C++ bundle port, so
# service ports come back as their specialized wrapper.
def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
    # TODO: add a proper registration mechanism for service ports.
    if isinstance(cpp_port, cpp.Function):
        port_cls = FunctionPort
    elif isinstance(cpp_port, cpp.MMIORegion):
        port_cls = MMIORegion
    else:
        port_cls = cls
    return super().__new__(port_cls)
371 
def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    """Record the owning module and the wrapped C++ bundle port."""
    self.cpp_port = cpp_port
    self.owner = owner
375 
def write_port(self, channel_name: str) -> WritePort:
    """Return a `WritePort` over the bundle's write channel `channel_name`."""
    cpp_channel = self.cpp_port.getWrite(channel_name)
    return WritePort(self, cpp_channel)
378 
def read_port(self, channel_name: str) -> ReadPort:
    """Return a `ReadPort` over the bundle's read channel `channel_name`."""
    cpp_channel = self.cpp_port.getRead(channel_name)
    return ReadPort(self, cpp_channel)
381 
382 
class MessageFuture(Future):
    """A `Future` for ESI messages, backed by a C++ future.

    `result()` waits on the C++ future, fetches the raw bytes and decodes
    them with `result_type`. Only the methods defined here are adapted; the
    rest of the `Future` API is inherited unchanged.
    """

    # NOTE(review): `Future.__init__` is never called, so inherited methods
    # that use Future's internal state (e.g. `cancel`, `exception`, `repr`)
    # presumably fail with AttributeError -- confirm before relying on them.

    def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
        self.result_type = result_type
        self.cpp_future = cpp_future

    def running(self) -> bool:
        """Unconditionally reports True."""
        return True

    def done(self) -> bool:
        """Forward to the C++ future's `valid()`."""
        # NOTE(review): if `valid()` binds std::future::valid, it means "has
        # shared state" rather than "result ready" -- verify the binding.
        return self.cpp_future.valid()

    def result(self, timeout: Optional[Union[int, float]] = None) -> Any:
        """Block until the message arrives and return it deserialized.

        `timeout` is currently ignored. Raises ValueError if deserialization
        leaves unconsumed bytes.
        """
        # TODO: respect timeout
        cpp_future = self.cpp_future
        cpp_future.wait()
        raw = cpp_future.get()
        msg, extra = self.result_type.deserialize(raw)
        if len(extra) > 0:
            raise ValueError(f"leftover bytes: {extra}")
        return msg

    def add_done_callback(self, fn: Callable[[Future], object]) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("add_done_callback is not implemented")
409 
410 
412  """A region of memory-mapped I/O space. This is a collection of named
413  channels, which are either read or read-write. The channels are accessed
414  by name, and can be connected to the host."""
415 
def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
    """Wrap an MMIO bundle port; the same C++ object is kept as `region`."""
    super().__init__(owner, cpp_port)
    self.region = cpp_port
419 
@property
def descriptor(self) -> cpp.MMIORegionDesc:
    """The region's descriptor, as reported by the C++ object."""
    region = self.region
    return region.descriptor
423 
def read(self, offset: int) -> bytearray:
    """Read the value at `offset` within the region (delegates to C++)."""
    region = self.region
    return region.read(offset)
427 
def write(self, offset: int, data: bytearray) -> None:
    """Write `data` at `offset` within the region (delegates to C++)."""
    region = self.region
    region.write(offset, data)
431 
432 
434  """A pair of channels which carry the input and output of a function."""
435 
def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    """Wrap a function bundle built from an "arg" write channel and a
    "result" read channel. Starts out not connected."""
    super().__init__(owner, cpp_port)
    arg_channel = self.write_port("arg")
    self.arg_type = arg_channel.type
    result_channel = self.read_port("result")
    self.result_type = result_channel.type
    self.connected = False
441 
def connect(self):
    """Connect the C++ function port, then flag this wrapper as connected."""
    port = self.cpp_port
    port.connect()
    self.connected = True
445 
def call(self, *args: Any, **kwargs: Any) -> Future:
    """Call the function and return a future of its deserialized result.

    The argument is given either as keyword arguments, collected into a dict
    (the form struct argument types accept, keyed by field name), or as a
    single positional value for argument types that are not structs.

    Raises TypeError if more than one positional argument is given or
    positional and keyword arguments are mixed. Raises ValueError if the
    argument is rejected by `arg_type`.
    """
    if args:
        if len(args) > 1 or kwargs:
            raise TypeError(
                "call() takes either a single positional argument or keyword "
                f"arguments, got {len(args)} positional and {len(kwargs)} "
                "keyword")
        arg = args[0]
    else:
        arg = kwargs
    valid, reason = self.arg_type.is_valid(arg)
    if not valid:
        raise ValueError(
            f"'{arg}' cannot be converted to '{self.arg_type}': {reason}")
    arg_bytes: bytearray = self.arg_type.serialize(arg)
    cpp_future = self.cpp_port.call(arg_bytes)
    return MessageFuture(self.result_type, cpp_future)
456 
def __call__(self, *args: Any, **kwds: Any) -> Future:
    """Calling the port directly is shorthand for `call`."""
    invoke = self.call
    return invoke(*args, **kwds)
WriteChannelPort & getWrite(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
Definition: Services.cpp:247
ReadChannelPort & getRead(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
Definition: Services.cpp:239
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:263
def __init__(self, cpp.ArrayType cpp_type)
Definition: types.py:254
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:280
int bit_width(self)
Definition: types.py:260
bytearray serialize(self, list lst)
Definition: types.py:274
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Definition: types.py:129
int bit_width(self)
Definition: types.py:126
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Definition: types.py:136
def __init__(self, cpp.BitsType cpp_type)
Definition: types.py:112
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:115
WritePort write_port(self, str channel_name)
Definition: types.py:376
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:364
ReadPort read_port(self, str channel_name)
Definition: types.py:379
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:372
int bit_width(self)
Definition: types.py:61
Tuple[bool, Optional[str]] supports_host(self)
Definition: types.py:46
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:77
def __init__(self, cpp.Type cpp_type)
Definition: types.py:42
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:55
int max_size(self)
Definition: types.py:66
str __str__(self)
Definition: types.py:82
bytearray serialize(self, obj)
Definition: types.py:73
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:436
Future call(self, **Any kwargs)
Definition: types.py:446
Future __call__(self, *Any args, **Any kwds)
Definition: types.py:457
def __init__(self, cpp.IntegerType cpp_type)
Definition: types.py:145
int bit_width(self)
Definition: types.py:149
None write(self, int offset, bytearray data)
Definition: types.py:428
bytearray read(self, int offset)
Definition: types.py:424
def __init__(self, HWModule owner, cpp.MMIORegion cpp_port)
Definition: types.py:416
cpp.MMIORegionDesc descriptor(self)
Definition: types.py:421
Any result(self, Optional[Union[int, float]] timeout=None)
Definition: types.py:398
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
Definition: types.py:388
None add_done_callback(self, Callable[[Future], object] fn)
Definition: types.py:407
def connect(self, Optional[int] buffer_size=None)
Definition: types.py:301
def disconnect(self)
Definition: types.py:309
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
Definition: types.py:296
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Definition: types.py:344
object read(self)
Definition: types.py:348
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:195
str __str__(self)
Definition: types.py:189
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:178
bytearray serialize(self, int obj)
Definition: types.py:192
bytearray serialize(self, obj)
Definition: types.py:234
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:241
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:218
def __init__(self, cpp.StructType cpp_type)
Definition: types.py:205
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:168
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:155
str __str__(self)
Definition: types.py:162
bytearray serialize(self, int obj)
Definition: types.py:165
int bit_width(self)
Definition: types.py:94
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:101
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:88
bytearray serialize(self, obj)
Definition: types.py:97
bool try_write(self, msg=None)
Definition: types.py:335
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
Definition: types.py:316
bytearray __serialize_msg(self, msg=None)
Definition: types.py:320
bool write(self, msg=None)
Definition: types.py:328
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
def _get_esi_type(cpp.Type cpp_type)
Definition: types.py:25