CIRCT  20.0.0git
types.py
Go to the documentation of this file.
1 # ===-----------------------------------------------------------------------===#
2 # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3 # See https://llvm.org/LICENSE.txt for license information.
4 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5 # ===-----------------------------------------------------------------------===#
6 #
7 # The structure of the Python classes and hierarchy roughly mirrors the C++
8 # side, but wraps the C++ objects. The wrapper classes sometimes add convenience
9 # functionality and serve to return wrapped versions of the returned objects.
10 #
11 # ===-----------------------------------------------------------------------===#
12 
13 from __future__ import annotations
14 
15 from . import esiCppAccel as cpp
16 
17 from typing import TYPE_CHECKING
18 if TYPE_CHECKING:
19  from .accelerator import HWModule
20 
21 from concurrent.futures import Future
22 from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
23 
24 
def _get_esi_type(cpp_type: cpp.Type):
    """Build the Python wrapper object for a C++ type.

    The registry is scanned in insertion order; the first entry whose key
    class matches `cpp_type` (via `isinstance`) constructs the result. When
    nothing matches, a plain `ESIType` wrapper is returned.
    """
    builder = next(
        (make for known_cls, make in __esi_mapping.items()
         if isinstance(cpp_type, known_cls)),
        None,
    )
    if builder is None:
        return ESIType(cpp_type)
    return builder(cpp_type)
31 
32 
# Registry keyed by C++ type class. Each value is a callable that takes the
# C++ type object and returns the Python object wrapping it. A channel type
# resolves to the wrapper of its inner type.
__esi_mapping: Dict[Type, Callable] = {}
__esi_mapping[cpp.ChannelType] = lambda cpp_type: _get_esi_type(cpp_type.inner)
38 
39 
class ESIType:
    """Python-side wrapper around a C++ type object.

    The base class stores the wrapped object and derives `supports_host` and
    `max_size` from `bit_width`. `bit_width`, `is_valid`, `serialize` and
    `deserialize` are placeholders here which fail with an "unimplemented"
    assertion.
    """

    def __init__(self, cpp_type: cpp.Type):
        self.cpp_type = cpp_type

    @property
    def supports_host(self) -> Tuple[bool, Optional[str]]:
        """Does this type support host communication via Python? Returns either
        '(True, None)' if it is, or '(False, reason)' if it is not."""
        if self.bit_width % 8 != 0:
            return (False, "runtime only supports types with multiple of 8 bits")
        return (True, None)

    def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
        """Is a Python object compatible with HW type? Returns either '(True,
        None)' if it is, or '(False, reason)' if it is not."""
        assert False, "unimplemented"

    @property
    def bit_width(self) -> int:
        """Size of this type, in bits. Negative for unbounded types."""
        assert False, "unimplemented"

    @property
    def max_size(self) -> int:
        """Maximum size of a value of this type, in bytes.

        A negative `bit_width` is returned unchanged so that the result stays
        negative; otherwise the width is rounded up to whole bytes.
        """
        width = self.bit_width
        if width < 0:
            # Rounding a negative width would yield 0 (or an arbitrary value)
            # and hide the fact that the size is not bounded.
            return width
        return (width + 7) // 8

    def serialize(self, obj) -> bytearray:
        """Convert a Python object to a bytearray."""
        assert False, "unimplemented"

    def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
        """Convert a bytearray to a Python object. Return the object and the
        leftover bytes."""
        assert False, "unimplemented"

    def __str__(self) -> str:
        return str(self.cpp_type)
83 
84 
86 
87  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
88  if obj is not None:
89  return (False, f"void type cannot must represented by None, not {obj}")
90  return (True, None)
91 
92  @property
93  def bit_width(self) -> int:
94  return 8
95 
96  def serialize(self, obj) -> bytearray:
97  # By convention, void is represented by a single byte of value 0.
98  return bytearray([0])
99 
100  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
101  if len(data) == 0:
102  raise ValueError(f"void type cannot be represented by {data}")
103  return (None, data[1:])
104 
105 
106 __esi_mapping[cpp.VoidType] = VoidType
107 
108 
110 
111  def __init__(self, cpp_type: cpp.BitsType):
112  self.cpp_typecpp_type: cpp.BitsType = cpp_type
113 
114  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
115  if not isinstance(obj, (bytearray, bytes, list)):
116  return (False, f"invalid type: {type(obj)}")
117  if isinstance(obj, list) and not all(
118  [isinstance(b, int) and b.bit_length() <= 8 for b in obj]):
119  return (False, f"list item too large: {obj}")
120  if len(obj) != self.max_sizemax_size:
121  return (False, f"wrong size: {len(obj)}")
122  return (True, None)
123 
124  @property
125  def bit_width(self) -> int:
126  return self.cpp_typecpp_type.width
127 
128  def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
129  if isinstance(obj, bytearray):
130  return obj
131  if isinstance(obj, bytes) or isinstance(obj, list):
132  return bytearray(obj)
133  raise ValueError(f"cannot convert {obj} to bytearray")
134 
135  def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
136  return (data[0:self.max_sizemax_size], data[self.max_sizemax_size:])
137 
138 
139 __esi_mapping[cpp.BitsType] = BitsType
140 
141 
143 
144  def __init__(self, cpp_type: cpp.IntegerType):
145  self.cpp_typecpp_type: cpp.IntegerType = cpp_type
146 
147  @property
148  def bit_width(self) -> int:
149  return self.cpp_typecpp_type.width
150 
151 
153 
154  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
155  if not isinstance(obj, int):
156  return (False, f"must be an int, not {type(obj)}")
157  if obj < 0 or obj.bit_length() > self.bit_widthbit_widthbit_width:
158  return (False, f"out of range: {obj}")
159  return (True, None)
160 
161  def __str__(self) -> str:
162  return f"uint{self.bit_width}"
163 
164  def serialize(self, obj: int) -> bytearray:
165  return bytearray(int.to_bytes(obj, self.max_sizemax_size, "little"))
166 
167  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
168  return (int.from_bytes(data[0:self.max_sizemax_size],
169  "little"), data[self.max_sizemax_size:])
170 
171 
172 __esi_mapping[cpp.UIntType] = UIntType
173 
174 
176 
177  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
178  if not isinstance(obj, int):
179  return (False, f"must be an int, not {type(obj)}")
180  if obj < 0:
181  if (-1 * obj) > 2**(self.bit_widthbit_widthbit_width - 1):
182  return (False, f"out of range: {obj}")
183  elif obj < 0:
184  if obj >= 2**(self.bit_widthbit_widthbit_width - 1) - 1:
185  return (False, f"out of range: {obj}")
186  return (True, None)
187 
188  def __str__(self) -> str:
189  return f"sint{self.bit_width}"
190 
191  def serialize(self, obj: int) -> bytearray:
192  return bytearray(int.to_bytes(obj, self.max_sizemax_size, "little", signed=True))
193 
194  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
195  return (int.from_bytes(data[0:self.max_sizemax_size], "little",
196  signed=True), data[self.max_sizemax_size:])
197 
198 
199 __esi_mapping[cpp.SIntType] = SIntType
200 
201 
203 
204  def __init__(self, cpp_type: cpp.StructType):
205  self.cpp_typecpp_typecpp_type = cpp_type
206  self.fields: List[Tuple[str, ESIType]] = [
207  (name, _get_esi_type(ty)) for (name, ty) in cpp_type.fields
208  ]
209 
210  @property
211  def bit_width(self) -> int:
212  widths = [ty.bit_width for (_, ty) in self.fields]
213  if any([w < 0 for w in widths]):
214  return -1
215  return sum(widths)
216 
217  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
218  fields_count = 0
219  if not isinstance(obj, dict):
220  obj = obj.__dict__
221 
222  for (fname, ftype) in self.fields:
223  if fname not in obj:
224  return (False, f"missing field '{fname}'")
225  fvalid, reason = ftype.is_valid(obj[fname])
226  if not fvalid:
227  return (False, f"invalid field '{fname}': {reason}")
228  fields_count += 1
229  if fields_count != len(obj):
230  return (False, "missing fields")
231  return (True, None)
232 
233  def serialize(self, obj) -> bytearray:
234  ret = bytearray()
235  for (fname, ftype) in reversed(self.fields):
236  fval = obj[fname]
237  ret.extend(ftype.serialize(fval))
238  return ret
239 
240  def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
241  ret = {}
242  for (fname, ftype) in reversed(self.fields):
243  (fval, data) = ftype.deserialize(data)
244  ret[fname] = fval
245  return (ret, data)
246 
247 
248 __esi_mapping[cpp.StructType] = StructType
249 
250 
252 
253  def __init__(self, cpp_type: cpp.ArrayType):
254  self.cpp_typecpp_typecpp_type = cpp_type
255  self.element_typeelement_type = _get_esi_type(cpp_type.element)
256  self.sizesize = cpp_type.size
257 
258  @property
259  def bit_width(self) -> int:
260  return self.element_typeelement_type.bit_width * self.sizesize
261 
262  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
263  if not isinstance(obj, list):
264  return (False, f"must be a list, not {type(obj)}")
265  if len(obj) != self.sizesize:
266  return (False, f"wrong size: expected {self.size} not {len(obj)}")
267  for (idx, e) in enumerate(obj):
268  evalid, reason = self.element_typeelement_type.is_valid(e)
269  if not evalid:
270  return (False, f"invalid element {idx}: {reason}")
271  return (True, None)
272 
273  def serialize(self, lst: list) -> bytearray:
274  ret = bytearray()
275  for e in reversed(lst):
276  ret.extend(self.element_typeelement_type.serialize(e))
277  return ret
278 
279  def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
280  ret = []
281  for _ in range(self.sizesize):
282  (obj, data) = self.element_typeelement_type.deserialize(data)
283  ret.append(obj)
284  ret.reverse()
285  return (ret, data)
286 
287 
288 __esi_mapping[cpp.ArrayType] = ArrayType
289 
290 
class Port:
    """A unidirectional communication channel. This is the basic communication
    method with an accelerator."""

    def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
        """Wrap a C++ channel port and resolve the wrapper for its type."""
        self.owner = owner
        self.cpp_port = cpp_port
        self.type = _get_esi_type(cpp_port.type)

    def connect(self, buffer_size: Optional[int] = None):
        """Connect the underlying C++ port and return `self`.

        Raises:
            TypeError: if the port's type reports that it does not support
                host communication.
        """
        (supports_host, reason) = self.type.supports_host
        if not supports_host:
            raise TypeError(f"unsupported type: {reason}")

        self.cpp_port.connect(buffer_size)
        return self

    def disconnect(self):
        """Disconnect the underlying C++ port."""
        self.cpp_port.disconnect()
310 
311 
313  """A unidirectional communication channel from the host to the accelerator."""
314 
315  def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
316  super().__init__(owner, cpp_port)
317  self.cpp_portcpp_port: cpp.WriteChannelPort = cpp_port
318 
319  def __serialize_msg(self, msg=None) -> bytearray:
320  valid, reason = self.typetype.is_valid(msg)
321  if not valid:
322  raise ValueError(
323  f"'{msg}' cannot be converted to '{self.type}': {reason}")
324  msg_bytes: bytearray = self.typetype.serialize(msg)
325  return msg_bytes
326 
327  def write(self, msg=None) -> bool:
328  """Write a typed message to the channel. Attempts to serialize 'msg' to what
329  the accelerator expects, but will fail if the object is not convertible to
330  the port type."""
331  self.cpp_portcpp_port.write(self.__serialize_msg__serialize_msg(msg))
332  return True
333 
334  def try_write(self, msg=None) -> bool:
335  """Like 'write', but uses the non-blocking tryWrite method of the underlying
336  port. Returns True if the write was successful, False otherwise."""
337  return self.cpp_portcpp_port.tryWrite(self.__serialize_msg__serialize_msg(msg))
338 
339 
class ReadPort(Port):
    """A unidirectional communication channel from the accelerator to the host."""

    def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
        """Initialise the base port, then re-bind `cpp_port` with its
        read-port type annotation."""
        super().__init__(owner, cpp_port)
        self.cpp_port: cpp.ReadChannelPort = cpp_port

    def read(self) -> object:
        """Read a typed message from the channel. Returns a deserialized object
        of a type defined by the port type.

        Raises:
            ValueError: if bytes are left over after deserializing one
                message.
        """
        buffer = self.cpp_port.read()
        (msg, leftover) = self.type.deserialize(buffer)
        if len(leftover) != 0:
            raise ValueError(f"leftover bytes: {leftover}")
        return msg
356 
357 
    """A collection of named, unidirectional communication channels."""
360 
361  # When creating a new port, we need to determine if it is a service port and
362  # instantiate it correctly.
363  def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
364  # TODO: add a proper registration mechanism for service ports.
365  if isinstance(cpp_port, cpp.Function):
366  return super().__new__(FunctionPort)
367  if isinstance(cpp_port, cpp.MMIORegion):
368  return super().__new__(MMIORegion)
369  return super().__new__(cls)
370 
371  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
372  self.ownerowner = owner
373  self.cpp_portcpp_port = cpp_port
374 
375  def write_port(self, channel_name: str) -> WritePort:
376  return WritePort(self, self.cpp_portcpp_port.getWrite(channel_name))
377 
378  def read_port(self, channel_name: str) -> ReadPort:
379  return ReadPort(self, self.cpp_portcpp_port.getRead(channel_name))
380 
381 
class MessageFuture(Future):
    """A specialization of `Future` for ESI messages. Wraps the cpp object and
    deserializes the result. Hopefully overrides all the methods necessary for
    proper operation, which is assumed to be not all of them."""

    def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
        # NOTE: the base `Future` initialiser is not invoked; only the methods
        # overridden below are backed by state set here.
        self.result_type = result_type
        self.cpp_future = cpp_future

    def running(self) -> bool:
        """Always reports True."""
        return True

    def done(self) -> bool:
        """Report whatever the wrapped future's `valid()` returns."""
        return self.cpp_future.valid()

    def result(self, timeout: Optional[Union[int, float]] = None) -> Any:
        """Wait for the wrapped future, then deserialize and return its
        payload. `timeout` is currently ignored.

        Raises:
            ValueError: if bytes are left over after deserializing.
        """
        # TODO: respect timeout
        self.cpp_future.wait()
        result_bytes = self.cpp_future.get()
        (msg, leftover) = self.result_type.deserialize(result_bytes)
        if len(leftover) != 0:
            raise ValueError(f"leftover bytes: {leftover}")
        return msg

    def add_done_callback(self, fn: Callable[[Future], object]) -> None:
        """Not supported; always raises NotImplementedError."""
        raise NotImplementedError("add_done_callback is not implemented")
408 
409 
411  """A region of memory-mapped I/O space. This is a collection of named
412  channels, which are either read or read-write. The channels are accessed
413  by name, and can be connected to the host."""
414 
415  def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
416  super().__init__(owner, cpp_port)
417  self.regionregion = cpp_port
418 
419  @property
420  def descriptor(self) -> cpp.MMIORegionDesc:
421  return self.regionregion.descriptor
422 
423  def read(self, offset: int) -> bytearray:
424  """Read a value from the MMIO region at the given offset."""
425  return self.regionregion.read(offset)
426 
427  def write(self, offset: int, data: bytearray) -> None:
428  """Write a value to the MMIO region at the given offset."""
429  self.regionregion.write(offset, data)
430 
431 
433  """A pair of channels which carry the input and output of a function."""
434 
435  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
436  super().__init__(owner, cpp_port)
437  self.arg_typearg_type = self.write_portwrite_port("arg").type
438  self.result_typeresult_type = self.read_portread_port("result").type
439  self.connectedconnected = False
440 
441  def connect(self):
442  self.cpp_portcpp_port.connect()
443  self.connectedconnected = True
444 
445  def call(self, **kwargs: Any) -> Future:
446  """Call the function with the given argument and returns a future of the
447  result."""
448  valid, reason = self.arg_typearg_type.is_valid(kwargs)
449  if not valid:
450  raise ValueError(
451  f"'{kwargs}' cannot be converted to '{self.arg_type}': {reason}")
452  arg_bytes: bytearray = self.arg_typearg_type.serialize(kwargs)
453  cpp_future = self.cpp_portcpp_port.call(arg_bytes)
454  return MessageFuture(self.result_typeresult_type, cpp_future)
455 
456  def __call__(self, *args: Any, **kwds: Any) -> Future:
457  return self.callcall(*args, **kwds)
WriteChannelPort & getWrite(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
Definition: Services.cpp:247
ReadChannelPort & getRead(const std::map< std::string, ChannelPort & > &channels, const std::string &name)
Definition: Services.cpp:239
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:262
def __init__(self, cpp.ArrayType cpp_type)
Definition: types.py:253
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:279
int bit_width(self)
Definition: types.py:259
bytearray serialize(self, list lst)
Definition: types.py:273
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Definition: types.py:128
int bit_width(self)
Definition: types.py:125
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Definition: types.py:135
def __init__(self, cpp.BitsType cpp_type)
Definition: types.py:111
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:114
WritePort write_port(self, str channel_name)
Definition: types.py:375
def __new__(cls, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:363
ReadPort read_port(self, str channel_name)
Definition: types.py:378
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:371
int bit_width(self)
Definition: types.py:60
Tuple[bool, Optional[str]] supports_host(self)
Definition: types.py:46
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:76
def __init__(self, cpp.Type cpp_type)
Definition: types.py:42
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:54
int max_size(self)
Definition: types.py:65
str __str__(self)
Definition: types.py:81
bytearray serialize(self, obj)
Definition: types.py:72
def __init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition: types.py:435
Future call(self, **Any kwargs)
Definition: types.py:445
Future __call__(self, *Any args, **Any kwds)
Definition: types.py:456
def __init__(self, cpp.IntegerType cpp_type)
Definition: types.py:144
int bit_width(self)
Definition: types.py:148
None write(self, int offset, bytearray data)
Definition: types.py:427
bytearray read(self, int offset)
Definition: types.py:423
def __init__(self, HWModule owner, cpp.MMIORegion cpp_port)
Definition: types.py:415
cpp.MMIORegionDesc descriptor(self)
Definition: types.py:420
Any result(self, Optional[Union[int, float]] timeout=None)
Definition: types.py:397
def __init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
Definition: types.py:387
None add_done_callback(self, Callable[[Future], object] fn)
Definition: types.py:406
def connect(self, Optional[int] buffer_size=None)
Definition: types.py:300
def disconnect(self)
Definition: types.py:308
def __init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
Definition: types.py:295
def __init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Definition: types.py:343
object read(self)
Definition: types.py:347
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:194
str __str__(self)
Definition: types.py:188
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:177
bytearray serialize(self, int obj)
Definition: types.py:191
bytearray serialize(self, obj)
Definition: types.py:233
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Definition: types.py:240
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:217
def __init__(self, cpp.StructType cpp_type)
Definition: types.py:204
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition: types.py:167
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:154
str __str__(self)
Definition: types.py:161
bytearray serialize(self, int obj)
Definition: types.py:164
int bit_width(self)
Definition: types.py:93
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition: types.py:100
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition: types.py:87
bytearray serialize(self, obj)
Definition: types.py:96
bool try_write(self, msg=None)
Definition: types.py:334
def __init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
Definition: types.py:315
bytearray __serialize_msg(self, msg=None)
Definition: types.py:319
bool write(self, msg=None)
Definition: types.py:327
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
def _get_esi_type(cpp.Type cpp_type)
Definition: types.py:25