CIRCT  19.0.0git
types.py
Go to the documentation of this file.
1 # ===-----------------------------------------------------------------------===#
2 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3 # See https://llvm.org/LICENSE.txt for license information.
4 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 # ===-----------------------------------------------------------------------===#
6 #
7 # The structure of the Python classes and hierarchy roughly mirrors the C++
8 # side, but wraps the C++ objects. The wrapper classes sometimes add convenience
9 # functionality and serve to return wrapped versions of the returned objects.
10 #
11 # ===-----------------------------------------------------------------------===#
12 
13 from __future__ import annotations
14 
15 from . import esiCppAccel as cpp
16 
17 from typing import TYPE_CHECKING
18 if TYPE_CHECKING:
19  from .accelerator import HWModule
20 
21 from concurrent.futures import Future
22 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
23 
24 
def _get_esi_type(cpp_type: cpp.Type):
  """Return the Python wrapper object for a C++ ESI type.

  Registered C++ classes in `__esi_mapping` are tried in insertion order; the
  factory of the first class `cpp_type` is an instance of builds the wrapper.
  C++ types with no registered class are wrapped in a plain `ESIType`.
  """
  factory = next(
      (ctor for cls, ctor in __esi_mapping.items() if isinstance(cpp_type, cls)),
      ESIType)
  return factory(cpp_type)
31 
32 
# Registry from C++ type classes to callables that construct the corresponding
# Python wrapper object. Further entries are added after each wrapper class is
# defined. Channel types are unwrapped: the wrapper of their `inner` type is
# used.
__esi_mapping: Dict[Type, Callable] = {}
__esi_mapping[cpp.ChannelType] = (
    lambda cpp_type: _get_esi_type(cpp_type.inner))
38 
39 
class ESIType:
  """Base class for the Python wrappers around C++ ESI types.

  Subclasses provide `is_valid`, `bit_width`, `serialize` and `deserialize`;
  the base versions raise `NotImplementedError`. A plain `ESIType` wraps C++
  types that have no more specific Python class.
  """

  def __init__(self, cpp_type: cpp.Type):
    # The wrapped C++ type object; `__str__` delegates to it.
    self.cpp_type = cpp_type

  @property
  def supports_host(self) -> Tuple[bool, Optional[str]]:
    """Does this type support host communication via Python? Returns either
    '(True, None)' if it does, or '(False, reason)' if it does not."""
    # Serialized values are whole bytes, so only byte-aligned widths work.
    if self.bit_width % 8 != 0:
      return (False, "runtime only supports types with multiple of 8 bits")
    return (True, None)

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Is a Python object compatible with this HW type? Returns either
    '(True, None)' if it is, or '(False, reason)' if it is not."""
    # Raise rather than `assert`: assertions are stripped under `python -O`,
    # which would make this silently return None.
    raise NotImplementedError("unimplemented")

  @property
  def bit_width(self) -> int:
    """Size of this type, in bits. Negative for unbounded types."""
    raise NotImplementedError("unimplemented")

  @property
  def max_size(self) -> int:
    """Maximum size of a value of this type, in bytes. Returns -1 for
    unbounded types (those with a negative `bit_width`)."""
    bit_width = self.bit_width
    if bit_width < 0:
      return -1
    # Round up to whole bytes with exact integer arithmetic.
    return (bit_width + 7) // 8

  def serialize(self, obj) -> bytearray:
    """Convert a Python object to a bytearray."""
    raise NotImplementedError("unimplemented")

  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
    """Convert a bytearray to a Python object. Return the object and the
    leftover bytes."""
    raise NotImplementedError("unimplemented")

  def __str__(self) -> str:
    return str(self.cpp_type)
83 
84 
class VoidType(ESIType):
  """Wrapper for the ESI void type; its only valid Python value is None."""

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    if obj is not None:
      return (False, f"void type must be represented by None, not {obj}")
    return (True, None)

  @property
  def bit_width(self) -> int:
    # Void is still serialized as one byte (see `serialize`).
    return 8

  def serialize(self, obj) -> bytearray:
    # By convention, void is represented by a single byte of value 0.
    return bytearray([0])

  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
    # Consume the single placeholder byte; its value is not checked.
    if len(data) == 0:
      raise ValueError(f"void type cannot be represented by {data}")
    return (None, data[1:])


__esi_mapping[cpp.VoidType] = VoidType
107 
108 
class BitsType(ESIType):
  """Wrapper for raw bit-vector types. Values are `bytearray`, `bytes`, or a
  list of ints in 0..255, and must be exactly `max_size` bytes long."""

  def __init__(self, cpp_type: cpp.BitsType):
    self.cpp_type: cpp.BitsType = cpp_type

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    if not isinstance(obj, (bytearray, bytes, list)):
      return (False, f"invalid type: {type(obj)}")
    # List items must be values `bytearray()` accepts in `serialize`: ints in
    # 0..255. A `bit_length() <= 8` test alone would let negatives through.
    if isinstance(obj, list) and not all(
        isinstance(b, int) and 0 <= b <= 255 for b in obj):
      return (False, f"list items must be ints in [0, 255]: {obj}")
    if len(obj) != self.max_size:
      return (False, f"wrong size: {len(obj)}")
    return (True, None)

  @property
  def bit_width(self) -> int:
    return self.cpp_type.width

  def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
    # A bytearray is passed through as-is (not copied).
    if isinstance(obj, bytearray):
      return obj
    if isinstance(obj, (bytes, list)):
      return bytearray(obj)
    raise ValueError(f"cannot convert {obj} to bytearray")

  def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
    # Split off exactly `max_size` bytes; the remainder is the leftover.
    return (data[0:self.max_size], data[self.max_size:])


__esi_mapping[cpp.BitsType] = BitsType
140 
141 
class IntType(ESIType):
  """Wrapper for C++ integer types; the width is read from the wrapped
  `cpp.IntegerType`."""

  def __init__(self, cpp_type: cpp.IntegerType):
    self.cpp_type: cpp.IntegerType = cpp_type

  @property
  def bit_width(self) -> int:
    wrapped = self.cpp_type
    return wrapped.width
150 
151 
class UIntType(IntType):
  """Unsigned integer wrapper. Values are Python ints in [0, 2**bit_width),
  encoded little-endian in `max_size` bytes."""

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    if isinstance(obj, int):
      fits = obj >= 0 and obj.bit_length() <= self.bit_width
      return (True, None) if fits else (False, f"out of range: {obj}")
    return (False, f"must be an int, not {type(obj)}")

  def __str__(self) -> str:
    return "uint" + str(self.bit_width)

  def serialize(self, obj: int) -> bytearray:
    raw = obj.to_bytes(self.max_size, "little")
    return bytearray(raw)

  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
    size = self.max_size
    head, rest = data[:size], data[size:]
    return (int.from_bytes(head, "little"), rest)


__esi_mapping[cpp.UIntType] = UIntType
173 
174 
class SIntType(IntType):
  """Signed integer wrapper. Values are Python ints in the two's-complement
  range [-2**(bit_width - 1), 2**(bit_width - 1) - 1], encoded little-endian
  in `max_size` bytes."""

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    if not isinstance(obj, int):
      return (False, f"must be an int, not {type(obj)}")
    width = self.bit_width
    # Check both bounds of the two's-complement range (positive values must
    # be range-checked too). A zero-width integer can only hold 0.
    lo = -(1 << (width - 1)) if width > 0 else 0
    hi = (1 << (width - 1)) - 1 if width > 0 else 0
    if not lo <= obj <= hi:
      return (False, f"out of range: {obj}")
    return (True, None)

  def __str__(self) -> str:
    return f"sint{self.bit_width}"

  def serialize(self, obj: int) -> bytearray:
    return bytearray(int.to_bytes(obj, self.max_size, "little", signed=True))

  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
    return (int.from_bytes(data[0:self.max_size], "little",
                           signed=True), data[self.max_size:])


__esi_mapping[cpp.SIntType] = SIntType
200 
201 
class StructType(ESIType):
  """Wrapper for ESI struct types. Values are dicts keyed by field name, or
  objects whose `__dict__` holds exactly the struct's fields."""

  def __init__(self, cpp_type: cpp.StructType):
    self.cpp_type = cpp_type
    # (name, wrapper) pairs, in the order the C++ type lists its fields.
    self.fields: List[Tuple[str, ESIType]] = [
        (name, _get_esi_type(ty)) for (name, ty) in cpp_type.fields
    ]

  @property
  def bit_width(self) -> int:
    widths = [ty.bit_width for (_, ty) in self.fields]
    # Any unbounded field makes the whole struct unbounded.
    if any(w < 0 for w in widths):
      return -1
    return sum(widths)

  @staticmethod
  def _as_field_dict(obj) -> Optional[Dict[str, Any]]:
    """Return `obj` if it is a dict, else its `__dict__`, or None if it has
    neither."""
    if isinstance(obj, dict):
      return obj
    return getattr(obj, "__dict__", None)

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    values = self._as_field_dict(obj)
    if values is None:
      return (False, f"must be a dict or an object with fields, not {type(obj)}")

    for (fname, ftype) in self.fields:
      if fname not in values:
        return (False, f"missing field '{fname}'")
      fvalid, reason = ftype.is_valid(values[fname])
      if not fvalid:
        return (False, f"invalid field '{fname}': {reason}")
    # Every declared field is present at this point; reject extra keys.
    names = {fname for (fname, _) in self.fields}
    extra = [key for key in values if key not in names]
    if extra:
      return (False, f"unexpected fields: {extra}")
    return (True, None)

  def serialize(self, obj) -> bytearray:
    # Accept the same inputs as `is_valid` (dicts or plain objects).
    values = self._as_field_dict(obj)
    if values is None:
      raise TypeError(f"cannot serialize {type(obj)} as a struct")
    ret = bytearray()
    # Fields are emitted in reverse declaration order; `deserialize` mirrors
    # this.
    for (fname, ftype) in reversed(self.fields):
      ret.extend(ftype.serialize(values[fname]))
    return ret

  def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
    ret = {}
    for (fname, ftype) in reversed(self.fields):
      (fval, data) = ftype.deserialize(data)
      ret[fname] = fval
    return (ret, data)


__esi_mapping[cpp.StructType] = StructType
249 
250 
class ArrayType(ESIType):
  """Fixed-size array wrapper. Values are Python lists of exactly `size`
  elements, each valid for `element_type`."""

  def __init__(self, cpp_type: cpp.ArrayType):
    self.cpp_type = cpp_type
    self.element_type = _get_esi_type(cpp_type.element)
    self.size = cpp_type.size

  @property
  def bit_width(self) -> int:
    per_element = self.element_type.bit_width
    return per_element * self.size

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    if not isinstance(obj, list):
      return (False, f"must be a list, not {type(obj)}")
    if len(obj) != self.size:
      return (False, f"wrong size: expected {self.size} not {len(obj)}")
    checks = ((idx, self.element_type.is_valid(e)) for idx, e in enumerate(obj))
    for idx, (ok, why) in checks:
      if not ok:
        return (False, f"invalid element {idx}: {why}")
    return (True, None)

  def serialize(self, lst: list) -> bytearray:
    # Elements are written last-to-first; `deserialize` undoes the reversal.
    out = bytearray()
    for chunk in map(self.element_type.serialize, reversed(lst)):
      out.extend(chunk)
    return out

  def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
    items = []
    for _ in range(self.size):
      (item, data) = self.element_type.deserialize(data)
      items.append(item)
    return (items[::-1], data)


__esi_mapping[cpp.ArrayType] = ArrayType
289 
290 
class Port:
  """A typed, unidirectional communication channel. This is the basic means
  of communicating with an accelerator."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
    self.owner = owner
    self.cpp_port = cpp_port
    self.type = _get_esi_type(cpp_port.type)
    # Refuse channels whose type does not support host communication.
    ok, why = self.type.supports_host
    if not ok:
      raise TypeError(f"unsupported type: {why}")

  def connect(self):
    """Connect the underlying C++ channel; returns `self` for chaining."""
    self.cpp_port.connect()
    return self
306 
307 
class WritePort(Port):
  """A unidirectional communication channel from the host to the accelerator."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
    super().__init__(owner, cpp_port)
    self.cpp_port: cpp.WriteChannelPort = cpp_port

  def write(self, msg=None) -> bool:
    """Serialize `msg` according to the port type and write it to the channel.

    Raises ValueError if `msg` is not valid for the port type. Returns True
    after the bytes have been passed to the C++ port's `write`."""
    ok, why = self.type.is_valid(msg)
    if ok:
      self.cpp_port.write(self.type.serialize(msg))
      return True
    raise ValueError(f"'{msg}' cannot be converted to '{self.type}': {why}")
327 
328 
class ReadPort(Port):
  """A unidirectional communication channel from the accelerator to the host."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
    super().__init__(owner, cpp_port)
    self.cpp_port: cpp.ReadChannelPort = cpp_port

  def read(self) -> Tuple[bool, Optional[object]]:
    """Read one message from the channel.

    Returns `(False, None)` when the C++ port's `read` yields None, otherwise
    `(True, msg)` with `msg` deserialized according to the port type. Raises
    ValueError if bytes remain after deserialization."""
    raw = self.cpp_port.read()
    if raw is not None:
      msg, rest = self.type.deserialize(raw)
      if rest:
        raise ValueError(f"leftover bytes: {rest}")
      return (True, msg)
    return (False, None)
347 
348 
class BundlePort:
  """A collection of named, unidirectional communication channels."""

  # Construction may produce a more specific subclass: service ports
  # (currently only functions) get their own wrapper class.
  def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
    # TODO: add a proper registration mechanism for service ports.
    target = FunctionPort if isinstance(cpp_port, cpp.Function) else cls
    return super().__new__(target)

  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    self.owner = owner
    self.cpp_port = cpp_port

  def write_port(self, channel_name: str) -> WritePort:
    """Return a new `WritePort` wrapping the write channel `channel_name`."""
    channel = self.cpp_port.getWrite(channel_name)
    return WritePort(self, channel)

  def read_port(self, channel_name: str) -> ReadPort:
    """Return a new `ReadPort` wrapping the read channel `channel_name`."""
    channel = self.cpp_port.getRead(channel_name)
    return ReadPort(self, channel)
369 
370 
class MessageFuture(Future):
  """A specialization of `Future` for ESI messages. Wraps the C++ message
  future and deserializes its bytes with `result_type` in `result()`.

  The base-class state is initialized so inherited helpers such as
  `__repr__` and `cancelled()` work, but that state never transitions to
  finished, so `concurrent.futures.wait()` / `as_completed()` are not
  supported.

  NOTE(review): `done()` reports `cpp_future.valid()`. If the binding mirrors
  `std::future::valid()`, that means "has shared state", not "result ready"
  -- confirm against the C++ side.
  """

  def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
    # Initialize Future's bookkeeping; without it inherited methods such as
    # `__repr__` fail with AttributeError.
    super().__init__()
    self.result_type = result_type
    self.cpp_future = cpp_future

  def running(self) -> bool:
    return True

  def cancel(self) -> bool:
    # This wrapper has no way to cancel the C++ call; report failure instead
    # of letting the base class mark the future cancelled.
    return False

  def done(self) -> bool:
    return self.cpp_future.valid()

  def result(self, timeout: Optional[Union[int, float]] = None) -> Any:
    """Block until the C++ future completes and return the deserialized
    message. Raises ValueError if bytes remain after deserialization."""
    # TODO: respect timeout
    self.cpp_future.wait()
    result_bytes = self.cpp_future.get()
    (msg, leftover) = self.result_type.deserialize(result_bytes)
    if len(leftover) != 0:
      raise ValueError(f"leftover bytes: {leftover}")
    return msg

  def exception(
      self,
      timeout: Optional[Union[int, float]] = None) -> Optional[BaseException]:
    """Wait for the result; return the exception raised while producing it,
    or None on success. Overridden because the base implementation waits on
    state this class never updates."""
    try:
      self.result(timeout)
    except Exception as err:
      return err
    return None

  def add_done_callback(self, fn: Callable[[Future], object]) -> None:
    raise NotImplementedError("add_done_callback is not implemented")
397 
398 
class FunctionPort(BundlePort):
  """A pair of channels which carry the input and output of a function."""

  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    super().__init__(owner, cpp_port)
    # The "arg" and "result" channels are wrapped only to learn their Python
    # types. Wrapping them also raises TypeError for types without host
    # support (Port.__init__).
    self.arg_type = self.write_port("arg").type
    self.result_type = self.read_port("result").type
    self.connected = False

  def connect(self):
    """Connect the underlying C++ function port. Returns `self`, matching
    `Port.connect`, so calls can be chained."""
    self.cpp_port.connect()
    self.connected = True
    return self

  def call(self, **kwargs: Any) -> Future:
    """Call the function and return a future of the result.

    The keyword arguments are validated and serialized as one dict by the
    argument type (presumably a struct keyed by field name -- confirm).
    Raises ValueError if they are not valid for the argument type."""
    valid, reason = self.arg_type.is_valid(kwargs)
    if not valid:
      raise ValueError(
          f"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
    arg_bytes: bytearray = self.arg_type.serialize(kwargs)
    cpp_future = self.cpp_port.call(arg_bytes)
    return MessageFuture(self.result_type, cpp_future)

  def __call__(self, *args: Any, **kwds: Any) -> Future:
    """Shorthand for `call`."""
    return self.call(*args, **kwds)
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:262
def __init__(self, cpp.ArrayType cpp_type)
Definition: types.py:253
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:279
int bit_width(self)
Definition: types.py:259
bytearray serialize(self, list lst)
Definition: types.py:273
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Definition: types.py:128
int bit_width(self)
Definition: types.py:125
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Definition: types.py:135
def __init__(self, cpp.BitsType cpp_type)
Definition: types.py:111
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:114
WritePort write_port(self, str channel_name)
Definition: types.py:364
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:354
ReadPort read_port(self, str channel_name)
Definition: types.py:367
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:360
int bit_width(self)
Definition: types.py:60
Tuple[bool, Optional[str]] supports_host(self)
Definition: types.py:46
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:76
def __init__(self, cpp.Type cpp_type)
Definition: types.py:42
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:54
int max_size(self)
Definition: types.py:65
str __str__(self)
Definition: types.py:81
bytearray serialize(self, obj)
Definition: types.py:72
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:402
Future call(self, **Any kwargs)
Definition: types.py:412
Future __call__(self, *Any args, **Any kwds)
Definition: types.py:423
def __init__(self, cpp.IntegerType cpp_type)
Definition: types.py:144
int bit_width(self)
Definition: types.py:148
Any result(self, Optional[Union[int, float]] timeout=None)
Definition: types.py:386
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
Definition: types.py:376
None add_done_callback(self, Callable[[Future], object] fn)
Definition: types.py:395
def connect(self)
Definition: types.py:303
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
Definition: types.py:295
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Definition: types.py:332
Tuple[bool, Optional[object]] read(self)
Definition: types.py:336
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:194
str __str__(self)
Definition: types.py:188
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:177
bytearray serialize(self, int obj)
Definition: types.py:191
bytearray serialize(self, obj)
Definition: types.py:233
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:240
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:217
def __init__(self, cpp.StructType cpp_type)
Definition: types.py:204
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:167
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:154
str __str__(self)
Definition: types.py:161
bytearray serialize(self, int obj)
Definition: types.py:164
int bit_width(self)
Definition: types.py:93
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:100
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:87
bytearray serialize(self, obj)
Definition: types.py:96
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
Definition: types.py:311
bool write(self, msg=None)
Definition: types.py:315
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
def _get_esi_type(cpp.Type cpp_type)
Definition: types.py:25