CIRCT 21.0.0git
Loading...
Searching...
No Matches
types.py
Go to the documentation of this file.
1# ===-----------------------------------------------------------------------===#
2# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
3# See https://llvm.org/LICENSE.txt for license information.
4# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
5# ===-----------------------------------------------------------------------===#
6#
7# The structure of the Python classes and hierarchy roughly mirrors the C++
8# side, but wraps the C++ objects. The wrapper classes sometimes add convenience
9# functionality and serve to return wrapped versions of the returned objects.
10#
11# ===-----------------------------------------------------------------------===#
12
13from __future__ import annotations
14
15from . import esiCppAccel as cpp
16
17from typing import TYPE_CHECKING
18if TYPE_CHECKING:
19 from .accelerator import HWModule
20
21from concurrent.futures import Future
22from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
23
24
def _get_esi_type(cpp_type: cpp.Type):
  """Wrap a C++ type object in its Python counterpart.

  The registry is scanned in insertion order and the first entry whose key
  matches `cpp_type` via `isinstance` builds the wrapper. Types with no
  registered wrapper fall back to a plain `ESIType`.
  """
  for registered_cls, make_wrapper in __esi_mapping.items():
    if not isinstance(cpp_type, registered_cls):
      continue
    return make_wrapper(cpp_type)
  return ESIType(cpp_type)
31
32
# Mapping from C++ types to functions constructing the Python object
# corresponding to that type. `_get_esi_type` scans the entries in insertion
# order and uses the first key that matches via `isinstance`.
__esi_mapping: Dict[Type, Callable] = {
    # A channel is unwrapped: it resolves to the wrapper of its inner type.
    cpp.ChannelType: lambda cpp_type: _get_esi_type(cpp_type.inner)
}
38
39
class ESIType:
  """Base Python wrapper around a C++ ESI type object.

  `bit_width`, `is_valid`, `serialize` and `deserialize` are placeholders
  here: each fails with an "unimplemented" assertion and is meant to be
  overridden. `supports_host`, `max_size` and `__str__` are built on top of
  those.
  """

  def __init__(self, cpp_type: cpp.Type):
    self.cpp_type = cpp_type

  @property
  def supports_host(self) -> Tuple[bool, Optional[str]]:
    """Does this type support host communication via Python? Returns either
    '(True, None)' if it is, or '(False, reason)' if it is not."""

    if self.bit_width % 8 != 0:
      return (False, "runtime only supports types with multiple of 8 bits")
    return (True, None)

  def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
    """Is a Python object compatible with HW type? Returns either '(True,
    None)' if it is, or '(False, reason)' if it is not."""
    assert False, "unimplemented"

  @property
  def bit_width(self) -> int:
    """Size of this type, in bits. Negative for unbounded types."""
    assert False, "unimplemented"

  @property
  def max_size(self) -> int:
    """Maximum size of a value of this type, in bytes. Negative for unbounded
    types (the negative bit width is passed through unchanged)."""
    width = self.bit_width
    if width < 0:
      # Test the bit width itself: rounding -1 bits up to whole bytes would
      # give 0 and hide the fact that the type is unbounded.
      return width
    # Integer ceiling division; avoids the float round-trip of `/`.
    return (width + 7) // 8

  def serialize(self, obj) -> bytearray:
    """Convert a Python object to a bytearray."""
    assert False, "unimplemented"

  def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
    """Convert a bytearray to a Python object. Return the object and the
    leftover bytes."""
    assert False, "unimplemented"

  def __str__(self) -> str:
    return str(self.cpp_type)
83
84
86
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
  """Accept only `None` as the Python value of a void type.

  Returns '(True, None)' for `None`, otherwise '(False, reason)'.
  """
  if obj is None:
    return (True, None)
  return (False, f"void type must be represented by None, not {obj}")
91
@property
def bit_width(self) -> int:
  """Width reported for the void type, in bits (one byte)."""
  void_bits = 8
  return void_bits
95
def serialize(self, obj) -> bytearray:
  """Serialize a void value; `obj` is ignored."""
  # By convention, void is represented by a single byte of value 0.
  # `bytearray(1)` yields exactly one zero byte.
  return bytearray(1)
99
def deserialize(self, data: bytearray) -> Tuple[object, bytearray]:
  """Consume the single placeholder byte of a void value.

  Returns `(None, leftover)`. Raises `ValueError` when `data` is empty.
  """
  if not len(data):
    raise ValueError(f"void type cannot be represented by {data}")
  leftover = data[1:]
  return (None, leftover)
104
105
# Register the wrapper: `_get_esi_type` builds a VoidType for cpp.VoidType.
__esi_mapping[cpp.VoidType] = VoidType
107
108
110
def __init__(self, cpp_type: cpp.BitsType):
  """Store the wrapped C++ bits type; its `width` is read by `bit_width`."""
  self.cpp_type: cpp.BitsType = cpp_type
113
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
  """Check that `obj` is a byte sequence of exactly `max_size` bytes.

  Accepts `bytearray`, `bytes`, or a list of ints in the range 0-255.
  Returns '(True, None)' if valid, or '(False, reason)' if not.
  """
  if not isinstance(obj, (bytearray, bytes, list)):
    return (False, f"invalid type: {type(obj)}")
  # Only lists need a per-item check; bytes and bytearray items are already
  # 0-255. A plain `bit_length() <= 8` test would let negative ints through,
  # and those cannot be turned into a bytearray when serializing.
  if isinstance(obj, list) and not all(
      isinstance(b, int) and 0 <= b <= 255 for b in obj):
    return (False, f"list items must be ints in the range 0-255: {obj}")
  if len(obj) != self.max_size:
    return (False, f"wrong size: {len(obj)}")
  return (True, None)
123
@property
def bit_width(self) -> int:
  """Number of bits, as reported by the wrapped C++ type's `width`."""
  wrapped = self.cpp_type
  return wrapped.width
127
def serialize(self, obj: Union[bytearray, bytes, List[int]]) -> bytearray:
  """Convert `obj` to a bytearray.

  A `bytearray` is returned as-is (same object, no copy); `bytes` and lists
  are copied into a new bytearray. Anything else raises `ValueError`.
  """
  if isinstance(obj, bytearray):
    return obj
  if not isinstance(obj, (bytes, list)):
    raise ValueError(f"cannot convert {obj} to bytearray")
  return bytearray(obj)
134
def deserialize(self, data: bytearray) -> Tuple[bytearray, bytearray]:
  """Split `data` into this value's first `max_size` bytes and the rest."""
  value = data[0:self.max_size]
  leftover = data[self.max_size:]
  return (value, leftover)
137
138
# Register the wrapper: `_get_esi_type` builds a BitsType for cpp.BitsType.
__esi_mapping[cpp.BitsType] = BitsType
140
141
143
def __init__(self, cpp_type: cpp.IntegerType):
  """Store the wrapped C++ integer type; its `width` is read by `bit_width`."""
  self.cpp_type: cpp.IntegerType = cpp_type
146
@property
def bit_width(self) -> int:
  """Number of bits, as reported by the wrapped C++ integer type's `width`."""
  wrapped = self.cpp_type
  return wrapped.width
150
151
153
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
  """Check that `obj` is a non-negative int that fits in `bit_width` bits.

  Returns '(True, None)' if valid, or '(False, reason)' if not.
  """
  if not isinstance(obj, int):
    return (False, f"must be an int, not {type(obj)}")
  # bit_length() is the minimum number of bits needed for a non-negative int.
  if obj < 0 or obj.bit_length() > self.bit_width:
    return (False, f"out of range: {obj}")
  return (True, None)
160
def __str__(self) -> str:
  """Render the type as `uint<width>`."""
  return "uint" + str(self.bit_width)
163
def serialize(self, obj: int) -> bytearray:
  """Encode `obj` as `max_size` little-endian unsigned bytes.

  `int.to_bytes` raises `OverflowError` if `obj` is negative or does not fit.
  """
  return bytearray(int.to_bytes(obj, self.max_size, "little"))
166
def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
  """Decode the first `max_size` bytes of `data` as a little-endian unsigned
  int. Returns the value and the leftover bytes."""
  size = self.max_size
  return (int.from_bytes(data[0:size], "little"), data[size:])
170
171
# Register the wrapper: `_get_esi_type` builds a UIntType for cpp.UIntType.
__esi_mapping[cpp.UIntType] = UIntType
173
174
176
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
  """Check that `obj` is an int representable in two's complement using
  `bit_width` bits, i.e. within [-2**(w-1), 2**(w-1) - 1].

  Returns '(True, None)' if valid, or '(False, reason)' if not.
  """
  if not isinstance(obj, int):
    return (False, f"must be an int, not {type(obj)}")
  half_range = 2**(self.bit_width - 1)
  # Both bounds are inclusive; the positive side is one smaller than the
  # negative side.
  if not -half_range <= obj <= half_range - 1:
    return (False, f"out of range: {obj}")
  return (True, None)
187
def __str__(self) -> str:
  """Render the type as `sint<width>`."""
  return "sint" + str(self.bit_width)
190
def serialize(self, obj: int) -> bytearray:
  """Encode `obj` as `max_size` little-endian two's-complement bytes.

  `int.to_bytes` raises `OverflowError` if `obj` does not fit.
  """
  return bytearray(int.to_bytes(obj, self.max_size, "little", signed=True))
193
def deserialize(self, data: bytearray) -> Tuple[int, bytearray]:
  """Decode the first `max_size` bytes of `data` as a little-endian
  two's-complement int. Returns the value and the leftover bytes."""
  size = self.max_size
  return (int.from_bytes(data[0:size], "little", signed=True), data[size:])
197
198
# Register the wrapper: `_get_esi_type` builds a SIntType for cpp.SIntType.
__esi_mapping[cpp.SIntType] = SIntType
200
201
203
def __init__(self, cpp_type: cpp.StructType):
  """Wrap a C++ struct type and each of its field types.

  `fields` keeps the `(name, type)` pairs in the order `cpp_type.fields`
  yields them, with every field type wrapped via `_get_esi_type`.
  """
  # Must be stored as `cpp_type`: that is the attribute ESIType.__str__ reads.
  self.cpp_type = cpp_type
  self.fields: List[Tuple[str, ESIType]] = [
      (name, _get_esi_type(ty)) for (name, ty) in cpp_type.fields
  ]
209
@property
def bit_width(self) -> int:
  """Total width of all fields in bits, or -1 if any field reports a
  negative width."""
  field_widths = [ftype.bit_width for (_, ftype) in self.fields]
  # The smallest width being negative is the same as "any width is negative".
  if min(field_widths, default=0) < 0:
    return -1
  return sum(field_widths)
216
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
  """Check that `obj` supplies exactly the struct's fields with valid values.

  `obj` may be a dict keyed by field name, or any object whose `__dict__`
  holds the fields as attributes. Returns '(True, None)' if valid, or
  '(False, reason)' if not.
  """
  if isinstance(obj, dict):
    values = obj
  else:
    # Objects without a __dict__ (ints, None, ...) cannot carry fields;
    # report that instead of raising AttributeError.
    values = getattr(obj, "__dict__", None)
    if values is None:
      return (False,
              f"must be a dict or an object with attributes, not {type(obj)}")

  for (fname, ftype) in self.fields:
    if fname not in values:
      return (False, f"missing field '{fname}'")
    fvalid, reason = ftype.is_valid(values[fname])
    if not fvalid:
      return (False, f"invalid field '{fname}': {reason}")

  # Every declared field is present here, so any mismatch means extra keys.
  declared = {fname for (fname, _) in self.fields}
  extra = [key for key in values if key not in declared]
  if extra:
    return (False, f"unexpected fields: {extra}")
  return (True, None)
232
def serialize(self, obj) -> bytearray:
  """Serialize a struct value by concatenating its serialized fields.

  Fields are emitted in reverse declaration order. `obj` may be a dict keyed
  by field name or an object carrying the fields as attributes.
  """
  # Mirror the dict-or-attributes convention that `is_valid` accepts.
  values = obj if isinstance(obj, dict) else obj.__dict__
  ret = bytearray()
  for (fname, ftype) in reversed(self.fields):
    ret.extend(ftype.serialize(values[fname]))
  return ret
239
def deserialize(self, data: bytearray) -> Tuple[Dict[str, Any], bytearray]:
  """Decode a struct from `data`, consuming fields in reverse declaration
  order. Returns a dict of field values and the leftover bytes."""
  decoded: Dict[str, Any] = {}
  remaining = data
  for (fname, ftype) in reversed(self.fields):
    decoded[fname], remaining = ftype.deserialize(remaining)
  return (decoded, remaining)
246
247
# Register the wrapper: `_get_esi_type` builds a StructType for cpp.StructType.
__esi_mapping[cpp.StructType] = StructType
249
250
252
def __init__(self, cpp_type: cpp.ArrayType):
  """Wrap a C++ array type, its element type, and its element count."""
  # Must be stored as `cpp_type`: that is the attribute ESIType.__str__ reads.
  self.cpp_type = cpp_type
  self.element_type = _get_esi_type(cpp_type.element)
  self.size = cpp_type.size
257
@property
def bit_width(self) -> int:
  """Element width in bits multiplied by the number of elements."""
  per_element = self.element_type.bit_width
  count = self.size
  return per_element * count
261
def is_valid(self, obj) -> Tuple[bool, Optional[str]]:
  """Check that `obj` is a list of exactly `size` valid elements.

  Returns '(True, None)' if valid, or '(False, reason)' naming the first
  offending element.
  """
  if not isinstance(obj, list):
    return (False, f"must be a list, not {type(obj)}")
  actual = len(obj)
  if actual != self.size:
    return (False, f"wrong size: expected {self.size} not {actual}")
  position = 0
  for element in obj:
    ok, why = self.element_type.is_valid(element)
    if not ok:
      return (False, f"invalid element {position}: {why}")
    position += 1
  return (True, None)
272
def serialize(self, lst: list) -> bytearray:
  """Serialize the elements of `lst` back to front and concatenate them."""
  encode = self.element_type.serialize
  out = bytearray()
  for element in lst[::-1]:
    out.extend(encode(element))
  return out
278
def deserialize(self, data: bytearray) -> Tuple[List[Any], bytearray]:
  """Decode `size` elements from `data`.

  Elements are stored last-first, so the decoded sequence is flipped before
  being returned together with the leftover bytes.
  """
  decoded: List[Any] = []
  remaining = data
  for _ in range(self.size):
    (element, remaining) = self.element_type.deserialize(remaining)
    decoded.append(element)
  return (decoded[::-1], remaining)
286
287
# Register the wrapper: `_get_esi_type` builds an ArrayType for cpp.ArrayType.
__esi_mapping[cpp.ArrayType] = ArrayType
289
290
class Port:
  """A unidirectional communication channel. This is the basic communication
  method with an accelerator."""

  def __init__(self, owner: BundlePort, cpp_port: cpp.ChannelPort):
    """Remember the owning bundle and the C++ port, and wrap the port's type."""
    self.owner = owner
    self.cpp_port = cpp_port
    self.type = _get_esi_type(cpp_port.type)

  def connect(self, buffer_size: Optional[int] = None):
    """Connect the underlying C++ port and return `self` for chaining.

    Raises `TypeError` if the port's type reports that it does not support
    host communication.
    """
    host_ok, why = self.type.supports_host
    if host_ok:
      self.cpp_port.connect(buffer_size)
      return self
    raise TypeError(f"unsupported type: {why}")

  def disconnect(self):
    """Disconnect the underlying C++ port."""
    self.cpp_port.disconnect()
310
311
313 """A unidirectional communication channel from the host to the accelerator."""
314
def __init__(self, owner: BundlePort, cpp_port: cpp.WriteChannelPort):
  """Initialize the base port, then re-bind `cpp_port` with the narrower
  write-port annotation."""
  super().__init__(owner, cpp_port)
  self.cpp_port: cpp.WriteChannelPort = cpp_port
318
def __serialize_msg(self, msg=None) -> bytearray:
  """Validate `msg` against the port type and return its serialized bytes.

  Raises `ValueError` with the type's reason when `msg` is not valid.
  """
  ok, why = self.type.is_valid(msg)
  if ok:
    payload: bytearray = self.type.serialize(msg)
    return payload
  raise ValueError(f"'{msg}' cannot be converted to '{self.type}': {why}")
326
def write(self, msg=None) -> bool:
  """Write a typed message to the channel. Attempts to serialize 'msg' to what
  the accelerator expects, but will fail if the object is not convertible to
  the port type. Always returns True once the write has been issued."""
  payload = self.__serialize_msg(msg)
  self.cpp_port.write(payload)
  return True
333
def try_write(self, msg=None) -> bool:
  """Like 'write', but uses the non-blocking tryWrite method of the underlying
  port. Returns True if the write was successful, False otherwise."""
  payload = self.__serialize_msg(msg)
  accepted = self.cpp_port.tryWrite(payload)
  return accepted
338
339
341 """A unidirectional communication channel from the accelerator to the host."""
342
def __init__(self, owner: BundlePort, cpp_port: cpp.ReadChannelPort):
  """Initialize the base port, then re-bind `cpp_port` with the narrower
  read-port annotation."""
  super().__init__(owner, cpp_port)
  self.cpp_port: cpp.ReadChannelPort = cpp_port
346
def read(self) -> object:
  """Read a typed message from the channel. Returns a deserialized object of a
  type defined by the port type.

  Raises `ValueError` if bytes remain after decoding one message.
  """
  raw = self.cpp_port.read()
  message, remainder = self.type.deserialize(raw)
  if len(remainder) == 0:
    return message
  raise ValueError(f"leftover bytes: {remainder}")
356
357
  """A collection of named, unidirectional communication channels."""
360
# When creating a new port, we need to determine if it is a service port and
# instantiate it correctly.
def __new__(cls, owner: HWModule, cpp_port: cpp.BundlePort):
  """Allocate the Python class matching the kind of `cpp_port`.

  Function ports and MMIO regions get their dedicated classes; anything else
  is allocated as `cls`.
  """
  # TODO: add a proper registration mechanism for service ports.
  target = cls
  if isinstance(cpp_port, cpp.Function):
    target = FunctionPort
  elif isinstance(cpp_port, cpp.MMIORegion):
    target = MMIORegion
  return super().__new__(target)
370
def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
  """Remember the owning module and the wrapped C++ bundle port."""
  self.owner = owner
  self.cpp_port = cpp_port
374
def write_port(self, channel_name: str) -> WritePort:
  """Look up the named write channel on the C++ bundle and wrap it."""
  cpp_channel = self.cpp_port.getWrite(channel_name)
  return WritePort(self, cpp_channel)
377
def read_port(self, channel_name: str) -> ReadPort:
  """Look up the named read channel on the C++ bundle and wrap it."""
  cpp_channel = self.cpp_port.getRead(channel_name)
  return ReadPort(self, cpp_channel)
380
381
class MessageFuture(Future):
  """A specialization of `Future` for ESI messages. Wraps the cpp object and
  deserializes the result. Hopefully overrides all the methods necessary for
  proper operation, which is assumed to be not all of them."""

  def __init__(self, result_type: Type, cpp_future: cpp.MessageDataFuture):
    # NOTE: `Future.__init__` is not invoked here, so only the methods
    # overridden in this class are backed by state.
    self.cpp_future = cpp_future
    self.result_type = result_type

  def running(self) -> bool:
    """Always reports the call as running."""
    return True

  def done(self) -> bool:
    """Report whatever the C++ future's `valid()` returns."""
    is_valid = self.cpp_future.valid()
    return is_valid

  def result(self, timeout: Optional[Union[int, float]] = None) -> Any:
    """Wait for the C++ future, then deserialize and return its message.

    `timeout` is currently ignored. Raises `ValueError` if bytes remain after
    decoding one message.
    """
    # TODO: respect timeout
    self.cpp_future.wait()
    raw = self.cpp_future.get()
    message, remainder = self.result_type.deserialize(raw)
    if len(remainder) == 0:
      return message
    raise ValueError(f"leftover bytes: {remainder}")

  def add_done_callback(self, fn: Callable[[Future], object]) -> None:
    """Not supported; always raises `NotImplementedError`."""
    raise NotImplementedError("add_done_callback is not implemented")
408
409
411 """A region of memory-mapped I/O space. This is a collection of named
412 channels, which are either read or read-write. The channels are accessed
413 by name, and can be connected to the host."""
414
def __init__(self, owner: HWModule, cpp_port: cpp.MMIORegion):
  """Initialize the bundle port and keep the C++ region under `region`, the
  attribute the read/write/descriptor accessors use."""
  super().__init__(owner, cpp_port)
  self.region = cpp_port
418
@property
def descriptor(self) -> cpp.MMIORegionDesc:
  """The descriptor exposed by the wrapped C++ region."""
  region = self.region
  return region.descriptor
422
def read(self, offset: int) -> bytearray:
  """Read a value from the MMIO region at the given offset."""
  value = self.region.read(offset)
  return value
426
def write(self, offset: int, data: bytearray) -> None:
  """Write a value to the MMIO region at the given offset."""
  region = self.region
  region.write(offset, data)
430
431
433 """A pair of channels which carry the input and output of a function."""
434
def __init__(self, owner: HWModule, cpp_port: cpp.BundlePort):
  """Initialize the bundle port and cache the argument and result types.

  The types are taken from the "arg" write channel and the "result" read
  channel of this bundle. The port starts out not connected.
  """
  super().__init__(owner, cpp_port)
  # Wrapping the channels needs `self.cpp_port`, which the base initializer
  # sets, so these lookups must come after the super() call.
  self.arg_type = self.write_port("arg").type
  self.result_type = self.read_port("result").type
  self.connected = False
440
def connect(self):
  """Connect the underlying C++ function port and record that it is connected."""
  self.cpp_port.connect()
  # Set only after connect() returns, so a failed connect leaves it False.
  self.connected = True
444
def call(self, **kwargs: Any) -> Future:
  """Call the function with the given argument and returns a future of the
  result.

  The keyword arguments are validated and serialized as one value of the
  argument type. Raises `ValueError` if they are not valid for that type.
  """
  ok, why = self.arg_type.is_valid(kwargs)
  if not ok:
    raise ValueError(
        f"'{kwargs}' cannot be converted to '{self.arg_type}': {why}")
  payload: bytearray = self.arg_type.serialize(kwargs)
  pending = self.cpp_port.call(payload)
  return MessageFuture(self.result_type, pending)
455
def __call__(self, *args: Any, **kwds: Any) -> Future:
  """Make the port callable; forwards all arguments to `call`."""
  future = self.call(*args, **kwds)
  return future
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:262
__init__(self, cpp.ArrayType cpp_type)
Definition types.py:253
Tuple[List[Any], bytearray] deserialize(self, bytearray data)
Definition types.py:279
bytearray serialize(self, list lst)
Definition types.py:273
bytearray serialize(self, Union[bytearray, bytes, List[int]] obj)
Definition types.py:128
int bit_width(self)
Definition types.py:125
Tuple[bytearray, bytearray] deserialize(self, bytearray data)
Definition types.py:135
__init__(self, cpp.BitsType cpp_type)
Definition types.py:111
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:114
__new__(cls, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:363
WritePort write_port(self, str channel_name)
Definition types.py:375
ReadPort read_port(self, str channel_name)
Definition types.py:378
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:371
int bit_width(self)
Definition types.py:60
Tuple[bool, Optional[str]] supports_host(self)
Definition types.py:46
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition types.py:76
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:54
int max_size(self)
Definition types.py:65
__init__(self, cpp.Type cpp_type)
Definition types.py:42
str __str__(self)
Definition types.py:81
bytearray serialize(self, obj)
Definition types.py:72
Future call(self, **Any kwargs)
Definition types.py:445
__init__(self, HWModule owner, cpp.BundlePort cpp_port)
Definition types.py:435
Future __call__(self, *Any args, **Any kwds)
Definition types.py:456
int bit_width(self)
Definition types.py:148
__init__(self, cpp.IntegerType cpp_type)
Definition types.py:144
None write(self, int offset, bytearray data)
Definition types.py:427
bytearray read(self, int offset)
Definition types.py:423
__init__(self, HWModule owner, cpp.MMIORegion cpp_port)
Definition types.py:415
cpp.MMIORegionDesc descriptor(self)
Definition types.py:420
Any result(self, Optional[Union[int, float]] timeout=None)
Definition types.py:397
__init__(self, Type result_type, cpp.MessageDataFuture cpp_future)
Definition types.py:387
None add_done_callback(self, Callable[[Future], object] fn)
Definition types.py:406
__init__(self, BundlePort owner, cpp.ChannelPort cpp_port)
Definition types.py:295
connect(self, Optional[int] buffer_size=None)
Definition types.py:300
__init__(self, BundlePort owner, cpp.ReadChannelPort cpp_port)
Definition types.py:343
object read(self)
Definition types.py:347
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition types.py:194
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:177
bytearray serialize(self, int obj)
Definition types.py:191
__init__(self, cpp.StructType cpp_type)
Definition types.py:204
bytearray serialize(self, obj)
Definition types.py:233
Tuple[Dict[str, Any], bytearray] deserialize(self, bytearray data)
Definition types.py:240
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:217
Tuple[int, bytearray] deserialize(self, bytearray data)
Definition types.py:167
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:154
bytearray serialize(self, int obj)
Definition types.py:164
int bit_width(self)
Definition types.py:93
Tuple[object, bytearray] deserialize(self, bytearray data)
Definition types.py:100
Tuple[bool, Optional[str]] is_valid(self, obj)
Definition types.py:87
bytearray serialize(self, obj)
Definition types.py:96
bool try_write(self, msg=None)
Definition types.py:334
bytearray __serialize_msg(self, msg=None)
Definition types.py:319
bool write(self, msg=None)
Definition types.py:327
__init__(self, BundlePort owner, cpp.WriteChannelPort cpp_port)
Definition types.py:315
_get_esi_type(cpp.Type cpp_type)
Definition types.py:25