CIRCT  19.0.0git
types.py
Go to the documentation of this file.
1 # ===-----------------------------------------------------------------------===#
2 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3 # See https://llvm.org/LICENSE.txt for license information.
4 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 # ===-----------------------------------------------------------------------===#
6 #
7 # The structure of the Python classes and hierarchy roughly mirrors the C++
8 # side, but wraps the C++ objects. The wrapper classes sometimes add convenience
9 # functionality and serve to return wrapped versions of the returned objects.
10 #
11 # ===-----------------------------------------------------------------------===#
12 
13 from __future__ import annotations
14 
15 from . import esiCppAccel as cpp
16 
17 from typing import TYPE_CHECKING
18 if TYPE_CHECKING:
19  from .accelerator import HWModule
20 
21 from concurrent.futures import Future
22 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
23 
24 
def _get_esi_type(cpp_type: cpp.Type):
    """Wrap a C++ type object in the matching Python wrapper.

    The registry is scanned in insertion order and the first entry whose C++
    class `cpp_type` is an instance of supplies the constructor. If nothing
    matches, a plain `ESIType` is returned.
    """
    factory = next(
        (make for klass, make in __esi_mapping.items()
         if isinstance(cpp_type, klass)),
        ESIType,
    )
    return factory(cpp_type)
31 
32 
# Registry: C++ type class -> callable which builds the Python object that
# wraps an instance of that class. Wrapper classes add themselves below.
__esi_mapping: Dict[Type, Callable] = {
    # A channel is represented by the wrapper of its inner type.
    cpp.ChannelType: lambda cpp_type: _get_esi_type(cpp_type.inner),
}
38 
39 
class ESIType:
    """Base Python wrapper around a C++ ESI type object.

    Subclasses are expected to override `bit_width`, `is_valid`, `serialize`
    and `deserialize`; the base versions raise `NotImplementedError`.
    """

    def __init__(self, cpp_type: cpp.Type):
        self.cpp_type = cpp_type

    @property
    def supports_host(self) -> Tuple[bool, Optional[str]]:
        """Does this type support host communication via Python? Returns either
        '(True, None)' if it is, or '(False, reason)' if it is not."""

        if self.bit_width % 8 != 0:
            return (False, "runtime only supports types with multiple of 8 bits")
        return (True, None)

    def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
        """Is a Python object compatible with HW type? Returns either '(True,
        None)' if it is, or '(False, reason)' if it is not."""
        # Raise explicitly: an `assert False` placeholder is stripped under -O
        # and would silently return None.
        raise NotImplementedError("unimplemented")

    @property
    def bit_width(self) -> int:
        """Size of this type, in bits. Negative for unbounded types."""
        raise NotImplementedError("unimplemented")

    @property
    def max_size(self) -> int:
        """Maximum size of a value of this type, in bytes. Negative when the
        bit width is negative (unbounded type)."""
        width = self.bit_width
        if width < 0:
            # Propagate the "unbounded" marker instead of rounding it to 0.
            return width
        # Integer ceiling division; avoids float rounding on large widths.
        return (width + 7) // 8

    def serialize(self, obj) -> bytearray:
        """Convert a Python object to a bytearray."""
        raise NotImplementedError("unimplemented")

    def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
        """Convert a bytearray to a Python object. Return the object and the
        leftover bytes."""
        raise NotImplementedError("unimplemented")

    def __str__(self) -> str:
        return str(self.cpp_type)
83 
84 
86 
87  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
88  if obj is not None:
89  return (False, f"void type cannot must represented by None, not {obj}")
90  return (True, None)
91 
92  @property
93  def bit_width(self) -> int:
94  return 8
95 
96  def serialize(self, obj) -> bytearray:
97  # By convention, void is represented by a single byte of value 0.
98  return bytearray([0])
99 
100  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
101  if len(data) == 0:
102  raise ValueError(f"void type cannot be represented by {data}")
103  return (None, data[1:])
104 
105 
106 __esi_mapping[cpp.VoidType] = VoidType
107 
108 
110 
111  def __init__(self, cpp_type: cpp.BitsType):
112  self.cpp_typecpp_type: cpp.BitsType = cpp_type
113 
114  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
115  if not isinstance(obj, (bytearray, bytes, list)):
116  return (False, f"invalid type: {type(obj)}")
117  if isinstance(obj, list) and not all(
118  [isinstance(b, int) and b.bit_length() <= 8 for b in obj]):
119  return (False, f"list item too large: {obj}")
120  if len(obj) != self.max_sizemax_size:
121  return (False, f"wrong size: {len(obj)}")
122  return (True, None)
123 
124  @property
125  def bit_width(self) -> int:
126  return self.cpp_typecpp_type.width
127 
128  def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
129  if isinstance(obj, bytearray):
130  return obj
131  if isinstance(obj, bytes) or isinstance(obj, list):
132  return bytearray(obj)
133  raise ValueError(f"cannot convert {obj} to bytearray")
134 
135  def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
136  return (data[0:self.max_sizemax_size], data[self.max_sizemax_size:])
137 
138 
139 __esi_mapping[cpp.BitsType] = BitsType
140 
141 
143 
144  def __init__(self, cpp_type: cpp.IntegerType):
145  self.cpp_typecpp_type: cpp.IntegerType = cpp_type
146 
147  @property
148  def bit_width(self) -> int:
149  return self.cpp_typecpp_type.width
150 
151 
153 
154  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
155  if not isinstance(obj, int):
156  return (False, f"must be an int, not {type(obj)}")
157  if obj < 0 or obj.bit_length() > self.bit_widthbit_widthbit_width:
158  return (False, f"out of range: {obj}")
159  return (True, None)
160 
161  def __str__(self) -> str:
162  return f"uint{self.bit_width}"
163 
164  def serialize(self, obj: int) -> bytearray:
165  return bytearray(int.to_bytes(obj, self.max_sizemax_size, "little"))
166 
167  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
168  return (int.from_bytes(data[0:self.max_sizemax_size],
169  "little"), data[self.max_sizemax_size:])
170 
171 
172 __esi_mapping[cpp.UIntType] = UIntType
173 
174 
176 
177  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
178  if not isinstance(obj, int):
179  return (False, f"must be an int, not {type(obj)}")
180  if obj < 0:
181  if (-1 * obj) > 2**(self.bit_widthbit_widthbit_width - 1):
182  return (False, f"out of range: {obj}")
183  elif obj < 0:
184  if obj >= 2**(self.bit_widthbit_widthbit_width - 1) - 1:
185  return (False, f"out of range: {obj}")
186  return (True, None)
187 
188  def __str__(self) -> str:
189  return f"sint{self.bit_width}"
190 
191  def serialize(self, obj: int) -> bytearray:
192  return bytearray(int.to_bytes(obj, self.max_sizemax_size, "little", signed=True))
193 
194  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
195  return (int.from_bytes(data[0:self.max_sizemax_size], "little",
196  signed=True), data[self.max_sizemax_size:])
197 
198 
199 __esi_mapping[cpp.SIntType] = SIntType
200 
201 
203 
204  def __init__(self, cpp_type: cpp.StructType):
205  self.cpp_typecpp_typecpp_type = cpp_type
206  self.fields: List[Tuple[str, ESIType]] = [
207  (name, _get_esi_type(ty)) for (name, ty) in cpp_type.fields
208  ]
209 
210  @property
211  def bit_width(self) -> int:
212  widths = [ty.bit_width for (_, ty) in self.fields]
213  if any([w < 0 for w in widths]):
214  return -1
215  return sum(widths)
216 
217  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
218  fields_count = 0
219  if not isinstance(obj, dict):
220  obj = obj.__dict__
221 
222  for (fname, ftype) in self.fields:
223  if fname not in obj:
224  return (False, f"missing field '{fname}'")
225  fvalid, reason = ftype.is_valid(obj[fname])
226  if not fvalid:
227  return (False, f"invalid field '{fname}': {reason}")
228  fields_count += 1
229  if fields_count != len(obj):
230  return (False, "missing fields")
231  return (True, None)
232 
233  def serialize(self, obj) -> bytearray:
234  ret = bytearray()
235  for (fname, ftype) in reversed(self.fields):
236  fval = obj[fname]
237  ret.extend(ftype.serialize(fval))
238  return ret
239 
240  def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
241  ret = {}
242  for (fname, ftype) in reversed(self.fields):
243  (fval, data) = ftype.deserialize(data)
244  ret[fname] = fval
245  return (ret, data)
246 
247 
248 __esi_mapping[cpp.StructType] = StructType
249 
250 
252 
253  def __init__(self, cpp_type: cpp.ArrayType):
254  self.cpp_typecpp_typecpp_type = cpp_type
255  self.element_typeelement_type = _get_esi_type(cpp_type.element)
256  self.sizesize = cpp_type.size
257 
258  @property
259  def bit_width(self) -> int:
260  return self.element_typeelement_type.bit_width * self.sizesize
261 
262  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
263  if not isinstance(obj, list):
264  return (False, f"must be a list, not {type(obj)}")
265  if len(obj) != self.sizesize:
266  return (False, f"wrong size: expected {self.size} not {len(obj)}")
267  for (idx, e) in enumerate(obj):
268  evalid, reason = self.element_typeelement_type.is_valid(e)
269  if not evalid:
270  return (False, f"invalid element {idx}: {reason}")
271  return (True, None)
272 
273  def serialize(self, lst: list) -> bytearray:
274  ret = bytearray()
275  for e in reversed(lst):
276  ret.extend(self.element_typeelement_type.serialize(e))
277  return ret
278 
279  def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
280  ret = []
281  for _ in range(self.sizesize):
282  (obj, data) = self.element_typeelement_type.deserialize(data)
283  ret.append(obj)
284  ret.reverse()
285  return (ret, data)
286 
287 
288 __esi_mapping[cpp.ArrayType] = ArrayType
289 
290 
class Port:
    """A unidirectional communication channel. This is the basic communication
    method with an accelerator."""

    def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
        """Wrap `cpp_port` and resolve the Python wrapper for its type.

        Raises TypeError when that type reports (via `supports_host`) that it
        cannot be exchanged with the host.
        """
        self.owner = owner
        self.cpp_port = cpp_port
        self.type = _get_esi_type(cpp_port.type)
        (supports_host, reason) = self.type.supports_host
        if not supports_host:
            raise TypeError(f"unsupported type: {reason}")

    def connect(self):
        """Connect the underlying C++ port. Returns `self` so calls can be
        chained."""
        self.cpp_port.connect()
        return self
306 
307 
309  """A unidirectional communication channel from the host to the accelerator."""
310 
311  def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
312  super().__init__(owner, cpp_port)
313  self.cpp_portcpp_port: cpp.WriteChannelPort = cpp_port
314 
315  def write(self, msg=None) -> bool:
316  """Write a typed message to the channel. Attempts to serialize 'msg' to what
317  the accelerator expects, but will fail if the object is not convertible to
318  the port type."""
319 
320  valid, reason = self.typetype.is_valid(msg)
321  if not valid:
322  raise ValueError(
323  f"'{msg}' cannot be converted to '{self.type}': {reason}")
324  msg_bytes: bytearray = self.typetype.serialize(msg)
325  self.cpp_portcpp_port.write(msg_bytes)
326  return True
327 
328 
class ReadPort(Port):
    """A unidirectional communication channel from the accelerator to the host."""

    def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
        """Initialize via `Port.__init__`, then re-bind `cpp_port` with the
        narrower read-port annotation."""
        super().__init__(owner, cpp_port)
        self.cpp_port: cpp.ReadChannelPort = cpp_port

    def read(self) -> object:
        """Read a typed message from the channel. Returns a deserialized object
        of a type defined by the port type.

        Raises ValueError if bytes remain after one message was decoded.
        """

        buffer = self.cpp_port.read()
        (msg, leftover) = self.type.deserialize(buffer)
        if len(leftover) != 0:
            raise ValueError(f"leftover bytes: {leftover}")
        return msg
345 
346 
348  """A collections of named, unidirectional communication channels."""
349 
350  # When creating a new port, we need to determine if it is a service port and
351  # instantiate it correctly.
352  def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
353  # TODO: add a proper registration mechanism for service ports.
354  if isinstance(cpp_port, cpp.Function):
355  return super().__new__(FunctionPort)
356  return super().__new__(cls)
357 
358  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
359  self.ownerowner = owner
360  self.cpp_portcpp_port = cpp_port
361 
362  def write_port(self, channel_name: str) -> WritePort:
363  return WritePort(self, self.cpp_portcpp_port.getWrite(channel_name))
364 
365  def read_port(self, channel_name: str) -> ReadPort:
366  return ReadPort(self, self.cpp_portcpp_port.getRead(channel_name))
367 
368 
class MessageFuture(Future):
    """A specialization of `Future` for ESI messages. Wraps the cpp object and
    deserializes the result. Hopefully overrides all the methods necessary for
    proper operation, which is assumed to be not all of them."""

    def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
        # NOTE(review): `Future.__init__` is not invoked here, so inherited
        # `Future` methods which are not overridden below should not be relied
        # upon.
        self.result_type = result_type
        self.cpp_future = cpp_future

    def running(self) -> bool:
        """Always reports True."""
        return True

    def done(self) -> bool:
        """Report whatever the wrapped C++ future's `valid()` returns."""
        return self.cpp_future.valid()

    def result(self, timeout: Optional[Union[int, float]] = None) -> Any:
        """Wait for the C++ future, then deserialize and return its message.

        `timeout` is currently ignored. Raises ValueError if bytes remain
        after one message was decoded.
        """
        # TODO: respect timeout
        self.cpp_future.wait()
        result_bytes = self.cpp_future.get()
        (msg, leftover) = self.result_type.deserialize(result_bytes)
        if len(leftover) != 0:
            raise ValueError(f"leftover bytes: {leftover}")
        return msg

    def add_done_callback(self, fn: Callable[[Future], object]) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("add_done_callback is not implemented")
395 
396 
398  """A pair of channels which carry the input and output of a function."""
399 
400  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
401  super().__init__(owner, cpp_port)
402  self.arg_typearg_type = self.write_portwrite_port("arg").type
403  self.result_typeresult_type = self.read_portread_port("result").type
404  self.connectedconnected = False
405 
406  def connect(self):
407  self.cpp_portcpp_port.connect()
408  self.connectedconnected = True
409 
410  def call(self, **kwargs: Any) -> Future:
411  """Call the function with the given argument and returns a future of the
412  result."""
413  valid, reason = self.arg_typearg_type.is_valid(kwargs)
414  if not valid:
415  raise ValueError(
416  f"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
417  arg_bytes: bytearray = self.arg_typearg_type.serialize(kwargs)
418  cpp_future = self.cpp_portcpp_port.call(arg_bytes)
419  return MessageFuture(self.result_typeresult_type, cpp_future)
420 
421  def __call__(self, *args: Any, **kwds: Any) -> Future:
422  return self.callcall(*args, **kwds)
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:262
def __init__(self, cpp.ArrayType cpp_type)
Definition: types.py:253
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:279
int bit_width(self)
Definition: types.py:259
bytearray serialize(self, list lst)
Definition: types.py:273
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Definition: types.py:128
int bit_width(self)
Definition: types.py:125
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Definition: types.py:135
def __init__(self, cpp.BitsType cpp_type)
Definition: types.py:111
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:114
WritePort write_port(self, str channel_name)
Definition: types.py:362
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:352
ReadPort read_port(self, str channel_name)
Definition: types.py:365
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:358
int bit_width(self)
Definition: types.py:60
Tuple[bool, Optional[str]] supports_host(self)
Definition: types.py:46
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:76
def __init__(self, cpp.Type cpp_type)
Definition: types.py:42
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:54
int max_size(self)
Definition: types.py:65
str __str__(self)
Definition: types.py:81
bytearray serialize(self, obj)
Definition: types.py:72
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:400
Future call(self, **Any kwargs)
Definition: types.py:410
Future __call__(self, *Any args, **Any kwds)
Definition: types.py:421
def __init__(self, cpp.IntegerType cpp_type)
Definition: types.py:144
int bit_width(self)
Definition: types.py:148
Any result(self, Optional[Union[int, float]] timeout=None)
Definition: types.py:384
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
Definition: types.py:374
None add_done_callback(self, Callable[[Future], object] fn)
Definition: types.py:393
def connect(self)
Definition: types.py:303
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
Definition: types.py:295
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Definition: types.py:332
object read(self)
Definition: types.py:336
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:194
str __str__(self)
Definition: types.py:188
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:177
bytearray serialize(self, int obj)
Definition: types.py:191
bytearray serialize(self, obj)
Definition: types.py:233
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:240
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:217
def __init__(self, cpp.StructType cpp_type)
Definition: types.py:204
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:167
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:154
str __str__(self)
Definition: types.py:161
bytearray serialize(self, int obj)
Definition: types.py:164
int bit_width(self)
Definition: types.py:93
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:100
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:87
bytearray serialize(self, obj)
Definition: types.py:96
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
Definition: types.py:311
bool write(self, msg=None)
Definition: types.py:315
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
def _get_esi_type(cpp.Type cpp_type)
Definition: types.py:25