Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
types.py
Go to the documentation of this file.
1# ===-----------------------------------------------------------------------===#
2# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3# See https://llvm.org/LICENSE.txt for license information.
4# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5# ===-----------------------------------------------------------------------===#
6#
7# The structure of the Python classes and hierarchy roughly mirrors the C++
8# side, but wraps the C++ objects. The wrapper classes sometimes add convenience
9# functionality and serve to return wrapped versions of the returned objects.
10#
11# ===-----------------------------------------------------------------------===#
12
13from __future__ import annotations
14
15from . import esiCppAccel as cpp
16
17from typing import TYPE_CHECKING
18if TYPE_CHECKING:
19 from .accelerator import HWModule
20
21from concurrent.futures import Future
22from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
23import sys
24import traceback
25
26
def _get_esi_type(cpp_type: cpp.Type):
  """Wrap a C++ type object in the matching Python `ESIType` subclass.

  Registered entries are scanned in insertion order. The constructor of the
  first C++ class that `cpp_type` is an instance of is used. When nothing
  matches, a plain `ESIType` wrapper is returned.
  """
  candidates = (ctor for cls, ctor in __esi_mapping.items()
                if isinstance(cpp_type, cls))
  ctor = next(candidates, None)
  if ctor is None:
    return ESIType(cpp_type)
  return ctor(cpp_type)
33
34
# Registry mapping C++ type classes to callables that build the corresponding
# Python wrapper. Wrapper classes register themselves after their definition.
# Channel types are transparent: they resolve to the wrapper of the type the
# channel carries (`inner`).
__esi_mapping: Dict[Type, Callable] = dict()
__esi_mapping[cpp.ChannelType] = (
    lambda channel_type: _get_esi_type(channel_type.inner))
40
41
class ESIType:
  """Python wrapper around a C++ ESI type object.

  Subclasses specialize `is_valid`, `bit_width`, `serialize` and
  `deserialize` for the kind of type they wrap. This base class is also
  instantiated directly for C++ types with no registered wrapper; on such
  instances those methods raise NotImplementedError.
  """

  def __init__(self, cpp_type: cpp.Type):
    # The wrapped C++ type; `__str__` delegates to it.
    self.cpp_type = cpp_type

  @property
  def supports_host(self) -> Tuple[bool, Optional[str]]:
    """Does this type support host communication via Python? Returns either
    '(True, None)' if it does, or '(False, reason)' if it does not."""

    if self.bit_width % 8 != 0:
      return (False, "runtime only supports types with multiple of 8 bits")
    return (True, None)

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Is a Python object compatible with HW type? Returns either '(True,
    None)' if it is, or '(False, reason)' if it is not."""
    # Raise instead of `assert False`: asserts are stripped under `-O`, which
    # would make this silently return None.
    raise NotImplementedError(
        f"{type(self).__name__}.is_valid is unimplemented")

  @property
  def bit_width(self) -> int:
    """Size of this type, in bits. Negative for unbounded types."""
    raise NotImplementedError(
        f"{type(self).__name__}.bit_width is unimplemented")

  @property
  def max_size(self) -> int:
    """Maximum size of a value of this type, in bytes (bit width rounded up
    to whole bytes). Returns -1 for unbounded types (negative `bit_width`)."""
    width = self.bit_width
    if width < 0:
      # Rounding a negative width up to bytes would produce 0; keep the
      # "unbounded" signal negative instead.
      return -1
    return (width + 7) // 8

  def serialize(self, obj) -> bytearray:
    """Convert a Python object to a bytearray."""
    raise NotImplementedError(
        f"{type(self).__name__}.serialize is unimplemented")

  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
    """Convert a bytearray to a Python object. Return the object and the
    leftover bytes."""
    raise NotImplementedError(
        f"{type(self).__name__}.deserialize is unimplemented")

  def __str__(self) -> str:
    return str(self.cpp_type)
85
86
class VoidType(ESIType):
  """Wrapper for the void type. Python represents it as `None`; on the wire it
  is a single zero byte."""

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Only `None` is a valid void value."""
    if obj is not None:
      return (False, f"void type must be represented by None, not {obj}")
    return (True, None)

  @property
  def bit_width(self) -> int:
    # One placeholder byte on the wire.
    return 8

  def serialize(self, obj) -> bytearray:
    # By convention, void is represented by a single byte of value 0.
    return bytearray([0])

  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
    """Consume the one placeholder byte (its value is not checked) and return
    `(None, remaining bytes)`. Raises ValueError on empty input."""
    if len(data) == 0:
      raise ValueError(f"void type cannot be represented by {data}")
    return (None, data[1:])


__esi_mapping[cpp.VoidType] = VoidType
109
110
class BitsType(ESIType):
  """Wrapper for an untyped bit vector. Python values are `bytearray`,
  `bytes`, or a list of ints in 0..255, exactly `max_size` bytes long."""

  def __init__(self, cpp_type: cpp.BitsType):
    self.cpp_type: cpp.BitsType = cpp_type

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check type, per-item byte range (for lists) and length."""
    if not isinstance(obj, (bytearray, bytes, list)):
      return (False, f"invalid type: {type(obj)}")
    # Each list item must be one unsigned byte. `bit_length()` alone would
    # accept negatives, which `bytearray(...)` in `serialize` then rejects.
    if isinstance(obj, list) and not all(
        isinstance(b, int) and 0 <= b <= 0xFF for b in obj):
      return (False, f"list item out of byte range: {obj}")
    if len(obj) != self.max_size:
      return (False, f"wrong size: {len(obj)}")
    return (True, None)

  @property
  def bit_width(self) -> int:
    return self.cpp_type.width

  def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
    """Return `obj` as a bytearray. A bytearray input is returned as-is (not
    copied). Raises ValueError for unsupported input types."""
    if isinstance(obj, bytearray):
      return obj
    if isinstance(obj, (bytes, list)):
      return bytearray(obj)
    raise ValueError(f"cannot convert {obj} to bytearray")

  def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
    """Split off the first `max_size` bytes; return `(value, rest)`."""
    size = self.max_size
    return (data[0:size], data[size:])


__esi_mapping[cpp.BitsType] = BitsType
142
143
class IntegerType(ESIType):
  """Base wrapper for C++ integer types. The bit width is read directly from
  the wrapped C++ type's `width` attribute."""

  def __init__(self, cpp_type: cpp.IntegerType):
    self.cpp_type: cpp.IntegerType = cpp_type

  @property
  def bit_width(self) -> int:
    """Number of bits in the integer, as reported by the C++ type."""
    wrapped = self.cpp_type
    return wrapped.width
152
153
class UIntType(IntegerType):
  """Wrapper for unsigned integers. Valid values are ints in
  [0, 2**bit_width); they are encoded little-endian in `max_size` bytes."""

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Accept non-negative ints that fit in `bit_width` bits."""
    if not isinstance(obj, int):
      return (False, f"must be an int, not {type(obj)}")
    if obj < 0 or obj.bit_length() > self.bit_width:
      return (False, f"out of range: {obj}")
    return (True, None)

  def __str__(self) -> str:
    return f"uint{self.bit_width}"

  def serialize(self, obj: int) -> bytearray:
    return bytearray(int.to_bytes(obj, self.max_size, "little"))

  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
    """Decode the first `max_size` bytes; return `(value, rest)`."""
    size = self.max_size
    return (int.from_bytes(data[0:size], "little"), data[size:])


__esi_mapping[cpp.UIntType] = UIntType
175
176
class SIntType(IntegerType):
  """Wrapper for signed two's-complement integers. Valid values are ints in
  [-2**(bit_width-1), 2**(bit_width-1) - 1]; they are encoded little-endian
  in `max_size` bytes."""

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Accept ints within the two's-complement range of `bit_width` bits."""
    if not isinstance(obj, int):
      return (False, f"must be an int, not {type(obj)}")
    # Check both bounds. Previously the positive check sat behind an
    # unreachable `elif obj < 0` and was off by one.
    half = 2**(self.bit_width - 1)
    if obj < -half or obj > half - 1:
      return (False, f"out of range: {obj}")
    return (True, None)

  def __str__(self) -> str:
    return f"sint{self.bit_width}"

  def serialize(self, obj: int) -> bytearray:
    return bytearray(int.to_bytes(obj, self.max_size, "little", signed=True))

  def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
    """Decode the first `max_size` bytes as signed; return `(value, rest)`."""
    size = self.max_size
    return (int.from_bytes(data[0:size], "little",
                           signed=True), data[size:])


__esi_mapping[cpp.SIntType] = SIntType
202
203
class StructType(ESIType):
  """Wrapper for a struct type. Python values are dicts keyed by field name,
  or objects whose attributes (`__dict__`) are exactly the fields."""

  def __init__(self, cpp_type: cpp.StructType):
    self.cpp_type = cpp_type
    # (field name, wrapped field type), in the order given by the C++ type.
    self.fields: List[Tuple[str, ESIType]] = [
        (name, _get_esi_type(ty)) for (name, ty) in cpp_type.fields
    ]

  @property
  def bit_width(self) -> int:
    """Sum of the field widths, or -1 if any field is unbounded."""
    widths = [ty.bit_width for (_, ty) in self.fields]
    if any(w < 0 for w in widths):
      return -1
    return sum(widths)

  @staticmethod
  def _field_values(obj) -> Optional[Dict[str, Any]]:
    """Return the name -> value mapping of a struct value: the dict itself or
    the object's `__dict__`. None if `obj` has neither."""
    if isinstance(obj, dict):
      return obj
    return getattr(obj, "__dict__", None)

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Every field must be present and valid, with no extra entries."""
    values = self._field_values(obj)
    if values is None:
      return (False,
              f"must be a dict or an object with attributes, not {type(obj)}")
    for (fname, ftype) in self.fields:
      if fname not in values:
        return (False, f"missing field '{fname}'")
      fvalid, reason = ftype.is_valid(values[fname])
      if not fvalid:
        return (False, f"invalid field '{fname}': {reason}")
    # All declared fields are present, so a size mismatch means extras.
    if len(values) != len(self.fields):
      known = {fname for (fname, _) in self.fields}
      extra = [k for k in values if k not in known]
      return (False, f"unexpected fields: {extra}")
    return (True, None)

  def serialize(self, obj) -> bytearray:
    """Concatenate the field encodings, last field first. Accepts the same
    value shapes as `is_valid`; raises TypeError otherwise."""
    values = self._field_values(obj)
    if values is None:
      raise TypeError(f"cannot convert {obj} to a struct")
    ret = bytearray()
    for (fname, ftype) in reversed(self.fields):
      ret.extend(ftype.serialize(values[fname]))
    return ret

  def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
    """Decode fields in the same last-to-first order `serialize` uses;
    return `(dict of fields, rest)`."""
    ret = {}
    for (fname, ftype) in reversed(self.fields):
      (fval, data) = ftype.deserialize(data)
      ret[fname] = fval
    return (ret, data)


__esi_mapping[cpp.StructType] = StructType
251
252
class ArrayType(ESIType):
  """Wrapper for a fixed-size array. Python values are lists of exactly
  `size` elements, each valid for `element_type`."""

  def __init__(self, cpp_type: cpp.ArrayType):
    self.cpp_type = cpp_type
    self.element_type = _get_esi_type(cpp_type.element)
    self.size = cpp_type.size

  @property
  def bit_width(self) -> int:
    """Element width times `size`, or -1 if the element is unbounded."""
    element_width = self.element_type.bit_width
    # Keep the unbounded signal a clean -1 (a bare product would be 0 when
    # size is 0).
    if element_width < 0:
      return -1
    return element_width * self.size

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Check list type, length, and every element."""
    if not isinstance(obj, list):
      return (False, f"must be a list, not {type(obj)}")
    if len(obj) != self.size:
      return (False, f"wrong size: expected {self.size} not {len(obj)}")
    for (idx, e) in enumerate(obj):
      evalid, reason = self.element_type.is_valid(e)
      if not evalid:
        return (False, f"invalid element {idx}: {reason}")
    return (True, None)

  def serialize(self, lst: list) -> bytearray:
    """Concatenate element encodings, last element first."""
    ret = bytearray()
    for e in reversed(lst):
      ret.extend(self.element_type.serialize(e))
    return ret

  def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
    """Decode `size` elements and restore their original order (inverse of
    `serialize`); return `(list, rest)`."""
    ret = []
    for _ in range(self.size):
      (obj, data) = self.element_type.deserialize(data)
      ret.append(obj)
    ret.reverse()
    return (ret, data)


__esi_mapping[cpp.ArrayType] = ArrayType
291
292
class Port:
  """A unidirectional communication channel. This is the basic communication
  method with an accelerator."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
    self.owner = owner
    self.cpp_port = cpp_port
    # Python wrapper of the channel's C++ type; used for (de)serialization.
    self.type = _get_esi_type(cpp_port.type)

  def connect(self, buffer_size: Optional[int] = None):
    """Connect the underlying C++ port and return `self` so calls can be
    chained. Raises TypeError if the channel type is unusable from Python."""
    ok, why = self.type.supports_host
    if ok:
      self.cpp_port.connect(buffer_size)
      return self
    raise TypeError(f"unsupported type: {why}")

  def disconnect(self):
    """Disconnect the underlying C++ port."""
    self.cpp_port.disconnect()
312
313
class WritePort(Port):
  """A unidirectional communication channel from the host to the accelerator."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
    super().__init__(owner, cpp_port)
    self.cpp_port: cpp.WriteChannelPort = cpp_port

  def __serialize_msg(self, msg=None) -> bytearray:
    """Validate `msg` against the port type and return its encoding. Raises
    ValueError if `msg` is not a valid value of the port type."""
    ok, why = self.type.is_valid(msg)
    if not ok:
      raise ValueError(f"'{msg}' cannot be converted to '{self.type}': {why}")
    return self.type.serialize(msg)

  def write(self, msg=None) -> bool:
    """Serialize `msg` to what the accelerator expects and write it with the
    C++ port's `write`. Always returns True; raises ValueError when `msg`
    is not convertible to the port type."""
    encoded = self.__serialize_msg(msg)
    self.cpp_port.write(encoded)
    return True

  def try_write(self, msg=None) -> bool:
    """Like `write`, but uses the non-blocking `tryWrite` of the C++ port and
    returns its result (True if the write was accepted)."""
    encoded = self.__serialize_msg(msg)
    return self.cpp_port.tryWrite(encoded)
340
341
class ReadPort(Port):
  """A unidirectional communication channel from the accelerator to the host."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
    super().__init__(owner, cpp_port)
    self.cpp_port: cpp.ReadChannelPort = cpp_port

  def read(self) -> object:
    """Read one message and return it decoded according to the port type.
    Raises ValueError if bytes remain after decoding."""
    msg, rest = self.type.deserialize(self.cpp_port.read())
    if len(rest) == 0:
      return msg
    raise ValueError(f"leftover bytes: {rest}")
358
359
class BundlePort:
  """A collection of named, unidirectional communication channels."""

  def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
    # Pick the Python wrapper from the C++ port's class so service ports get
    # their specialized API. Checks run in order; the first match wins.
    # TODO: add a proper registration mechanism for service ports.
    if isinstance(cpp_port, cpp.Function):
      wrapper = FunctionPort
    elif isinstance(cpp_port, cpp.Callback):
      wrapper = CallbackPort
    elif isinstance(cpp_port, cpp.MMIORegion):
      wrapper = MMIORegion
    elif isinstance(cpp_port, cpp.Telemetry):
      wrapper = TelemetryPort
    else:
      wrapper = cls
    return super().__new__(wrapper)

  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    self.owner = owner
    self.cpp_port = cpp_port

  def write_port(self, channel_name: str) -> WritePort:
    """Wrap the bundle's outgoing channel named `channel_name`."""
    channel = self.cpp_port.getWrite(channel_name)
    return WritePort(self, channel)

  def read_port(self, channel_name: str) -> ReadPort:
    """Wrap the bundle's incoming channel named `channel_name`."""
    channel = self.cpp_port.getRead(channel_name)
    return ReadPort(self, channel)
386
387
class MessageFuture(Future):
  """A specialization of `Future` for ESI messages. Wraps the cpp object and
  deserializes the result with `result_type`. Only some `Future` methods are
  overridden; the rest see the base class's (never completed) state."""

  def __init__(self, result_type: ESIType, cpp_future: cpp.MessageDataFuture):
    # Initialize the base class so inherited members (e.g. `__repr__`) find
    # the internal state they expect instead of raising AttributeError.
    super().__init__()
    # Must provide `deserialize(data) -> (value, leftover)`.
    self.result_type = result_type
    self.cpp_future = cpp_future

  def running(self) -> bool:
    return True

  def cancel(self) -> bool:
    # The request was already handed to the C++ side and cannot be
    # withdrawn here; consistent with `running()` always being True.
    return False

  def cancelled(self) -> bool:
    return False

  def done(self) -> bool:
    # NOTE(review): mirrors `cpp_future.valid()`; confirm the binding means
    # "result ready" rather than "has shared state".
    return self.cpp_future.valid()

  def result(self, timeout: Optional[Union[int, float]] = None) -> Any:
    """Wait for the C++ future and return the deserialized message.
    `timeout` is currently ignored. Raises ValueError if bytes remain after
    decoding."""
    # TODO: respect timeout
    self.cpp_future.wait()
    result_bytes = self.cpp_future.get()
    (msg, leftover) = self.result_type.deserialize(result_bytes)
    if len(leftover) != 0:
      raise ValueError(f"leftover bytes: {leftover}")
    return msg

  def add_done_callback(self, fn: Callable[[Future], object]) -> None:
    raise NotImplementedError("add_done_callback is not implemented")
414
415
class MMIORegion(BundlePort):
  """A memory-mapped I/O region of the accelerator. Values are read from and
  written to the region at a given offset."""

  def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
    super().__init__(owner, cpp_port)
    # Same object as `self.cpp_port`, under a region-specific name.
    self.region = cpp_port

  @property
  def descriptor(self) -> cpp.MMIORegionDesc:
    """The C++ descriptor object of this region."""
    region = self.region
    return region.descriptor

  def read(self, offset: int) -> bytearray:
    """Read a value from the MMIO region at the given offset."""
    region = self.region
    return region.read(offset)

  def write(self, offset: int, data: bytearray) -> None:
    """Write a value to the MMIO region at the given offset."""
    region = self.region
    region.write(offset, data)
436
437
class FunctionPort(BundlePort):
  """A pair of channels which carry the input ("arg") and output ("result")
  of a function implemented by the accelerator."""

  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    super().__init__(owner, cpp_port)
    self.arg_type = self.write_port("arg").type
    self.result_type = self.read_port("result").type
    self.connected = False

  def connect(self):
    self.cpp_port.connect()
    self.connected = True

  def call(self, *args: Any, **kwargs: Any) -> Future:
    """Call the function and return a future of the deserialized result.

    The argument is either a single positional value, passed to the argument
    type unchanged, or keyword arguments collected into a dict (e.g. struct
    fields). Keyword-only calls behave exactly as before.

    Raises TypeError when mixing positional and keyword arguments or passing
    several positionals. Raises ValueError when the argument is not valid for
    the argument type.
    """
    if args:
      if kwargs or len(args) != 1:
        raise TypeError(
            "call() takes either one positional argument or keyword arguments")
      arg = args[0]
    else:
      arg = kwargs
    valid, reason = self.arg_type.is_valid(arg)
    if not valid:
      raise ValueError(
          f"'{arg}' cannot be converted to '{self.arg_type}': {reason}")
    arg_bytes: bytearray = self.arg_type.serialize(arg)
    cpp_future = self.cpp_port.call(arg_bytes)
    return MessageFuture(self.result_type, cpp_future)

  def __call__(self, *args: Any, **kwds: Any) -> Future:
    # Positional arguments previously reached a kwargs-only `call` and
    # raised TypeError; `call` now accepts a single positional value.
    return self.call(*args, **kwds)
464
465
class CallbackPort(BundlePort):
  """Callback ports are the inverse of function ports -- instead of calls to the
  accelerator, they get called from the accelerator. Specify the function which
  you'd like the accelerator to call when you call `connect`."""

  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    super().__init__(owner, cpp_port)
    self.arg_type = self.read_port("arg").type
    self.result_type = self.write_port("result").type
    self.connected = False

  def connect(self, cb: Callable[[Any], Any]):
    """Register `cb` as the handler for calls from the accelerator.

    Each incoming message is deserialized with `arg_type` and passed to `cb`.
    A non-None return value is serialized with `result_type` and handed back
    to the C++ port. A None return, or any exception, yields None.
    """

    def type_convert_wrapper(cb: Callable[[Any], Any],
                             msg: bytearray) -> Optional[bytearray]:
      try:
        (obj, leftover) = self.arg_type.deserialize(msg)
        if len(leftover) != 0:
          raise ValueError(f"leftover bytes: {leftover}")
        result = cb(obj)
        if result is None:
          return None
        return self.result_type.serialize(result)
      except Exception:
        # Report and return None rather than raising. `print_exc` works on
        # all Python 3 versions; the one-argument `print_exception(e)` form
        # requires Python 3.10+ and would itself raise TypeError earlier.
        traceback.print_exc()
        return None

    self.cpp_port.connect(lambda x: type_convert_wrapper(cb=cb, msg=x))
    self.connected = True
495
496
class TelemetryPort(BundlePort):
  """Telemetry ports report an individual piece of information from the
  accelerator. The method of accessing telemetry will likely change in the
  future."""

  def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
    super().__init__(owner, cpp_port)
    self.connected = False

  def connect(self):
    self.cpp_port.connect()
    self.connected = True

  def read(self) -> Future:
    """Request the telemetry value; returns a future of the decoded value."""
    cpp_future = self.cpp_port.read()
    # MessageFuture calls `deserialize` on its result type, a method defined
    # on the Python ESIType wrappers, so wrap the raw C++ type first.
    return MessageFuture(_get_esi_type(self.cpp_port.type), cpp_future)
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:264
__init__(self, cpp.ArrayType cpp_type)
Definition types.py:255
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
Definition types.py:281
bytearray serialize(self, list lst)
Definition types.py:275
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Definition types.py:130
int bit_width(self)
Definition types.py:127
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Definition types.py:137
__init__(self, cpp.BitsType cpp_type)
Definition types.py:113
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:116
__new__(cls, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:365
WritePort write_port(self, str channel_name)
Definition types.py:381
ReadPort read_port(self, str channel_name)
Definition types.py:384
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:377
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:471
connect(self, Callable[[Any], Any] cb)
Definition types.py:477
int bit_width(self)
Definition types.py:62
Tuple[bool, Optional[str]] supports_host(self)
Definition types.py:48
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition types.py:78
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:56
int max_size(self)
Definition types.py:67
__init__(self, cpp.Type cpp_type)
Definition types.py:44
str __str__(self)
Definition types.py:83
bytearray serialize(self, obj)
Definition types.py:74
Future call(self, **Any kwargs)
Definition types.py:451
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:441
Future __call__(self, *Any args, **Any kwds)
Definition types.py:462
int bit_width(self)
Definition types.py:150
__init__(self, cpp.IntegerType cpp_type)
Definition types.py:146
None write(self, int offset, bytearray data)
Definition types.py:433
bytearray read(self, int offset)
Definition types.py:429
__init__(self, HWModule owner, cpp.MMIORegion cpp_port)
Definition types.py:421
cpp.MMIORegionDesc descriptor(self)
Definition types.py:426
Any result(self, Optional[Union[int, float]] timeout=None)
Definition types.py:403
__init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
Definition types.py:393
None add_done_callback(self, Callable[[Future], object] fn)
Definition types.py:412
__init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
Definition types.py:297
connect(self, Optional[int] buffer_size=None)
Definition types.py:302
__init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Definition types.py:345
object read(self)
Definition types.py:349
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition types.py:196
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:179
bytearray serialize(self, int obj)
Definition types.py:193
__init__(self, cpp.StructType cpp_type)
Definition types.py:206
bytearray serialize(self, obj)
Definition types.py:235
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Definition types.py:242
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:219
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:502
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition types.py:169
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:156
bytearray serialize(self, int obj)
Definition types.py:166
int bit_width(self)
Definition types.py:95
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition types.py:102
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:89
bytearray serialize(self, obj)
Definition types.py:98
bool try_write(self, msg=None)
Definition types.py:336
bytearray __serialize_msg(self, msg=None)
Definition types.py:321
bool write(self, msg=None)
Definition types.py:329
__init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
Definition types.py:317
_get_esi_type(cpp.Type cpp_type)
Definition types.py:27