CIRCT 22.0.0git
Loading...
Searching...
No Matches
esiCppAccel.cpp
Go to the documentation of this file.
1//===- esiaccel.cpp - ESI runtime python bindings ---------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Simply wrap the C++ API into a Python module called 'esiaccel'.
10//
11//===----------------------------------------------------------------------===//
12
13#include "esi/Accelerator.h"
14#include "esi/Services.h"
15
16#include "esi/backends/Cosim.h"
17
18#include <ranges>
19#include <sstream>
20
21// pybind11 includes
22#include <pybind11/pybind11.h>
23namespace py = pybind11;
24
25#include <pybind11/functional.h>
26#include <pybind11/stl.h>
27
28using namespace esi;
29using namespace esi::services;
30
31namespace pybind11 {
32/// Pybind11 needs a little help downcasting with non-bound instances.
33template <>
34struct polymorphic_type_hook<ChannelPort> {
35 static const void *get(const ChannelPort *port, const std::type_info *&type) {
36 if (auto p = dynamic_cast<const WriteChannelPort *>(port)) {
37 type = &typeid(WriteChannelPort);
38 return p;
39 }
40 if (auto p = dynamic_cast<const ReadChannelPort *>(port)) {
41 type = &typeid(ReadChannelPort);
42 return p;
43 }
44 return port;
45 }
46};
47template <>
48struct polymorphic_type_hook<Service> {
49 static const void *get(const Service *svc, const std::type_info *&type) {
50 if (auto p = dynamic_cast<const MMIO *>(svc)) {
51 type = &typeid(MMIO);
52 return p;
53 }
54 if (auto p = dynamic_cast<const SysInfo *>(svc)) {
55 type = &typeid(SysInfo);
56 return p;
57 }
58 if (auto p = dynamic_cast<const HostMem *>(svc)) {
59 type = &typeid(HostMem);
60 return p;
61 }
62 if (auto p = dynamic_cast<const TelemetryService *>(svc)) {
63 type = &typeid(TelemetryService);
64 return p;
65 }
66 return svc;
67 }
68};
69
70namespace detail {
71/// Pybind11 doesn't have a built-in type caster for std::any
72/// (https://github.com/pybind/pybind11/issues/1590). We must provide one which
73/// knows about all of the potential types which the any might be.
74template <>
75struct type_caster<std::any> {
76public:
77 PYBIND11_TYPE_CASTER(std::any, const_name("object"));
78
79 static handle cast(std::any src, return_value_policy /* policy */,
80 handle /* parent */) {
81 const std::type_info &t = src.type();
82 if (t == typeid(std::string))
83 return py::str(std::any_cast<std::string>(src));
84 else if (t == typeid(int64_t))
85 return py::int_(std::any_cast<int64_t>(src));
86 else if (t == typeid(uint64_t))
87 return py::int_(std::any_cast<uint64_t>(src));
88 else if (t == typeid(double))
89 return py::float_(std::any_cast<double>(src));
90 else if (t == typeid(bool))
91 return py::bool_(std::any_cast<bool>(src));
92 else if (t == typeid(std::nullptr_t))
93 return py::none();
94 return py::none();
95 }
96};
97} // namespace detail
98} // namespace pybind11
99
100/// Resolve a Type to the Python wrapper object.
101py::object getPyType(std::optional<const Type *> t) {
102 py::object typesModule = py::module_::import("esiaccel.types");
103 if (!t)
104 return py::none();
105 return typesModule.attr("_get_esi_type")(*t);
106}
107
108// NOLINTNEXTLINE(readability-identifier-naming)
109PYBIND11_MODULE(esiCppAccel, m) {
110 py::class_<Type>(m, "Type")
111 .def(py::init<const Type::ID &>(), py::arg("id"))
112 .def_property_readonly("id", &Type::getID)
113 .def("__repr__", [](Type &t) { return "<" + t.getID() + ">"; });
114 py::class_<ChannelType, Type>(m, "ChannelType")
115 .def(py::init<const Type::ID &, const Type *>(), py::arg("id"),
116 py::arg("inner"))
117 .def_property_readonly("inner", &ChannelType::getInner,
118 py::return_value_policy::reference);
119 py::enum_<BundleType::Direction>(m, "Direction")
120 .value("To", BundleType::Direction::To)
121 .value("From", BundleType::Direction::From)
122 .export_values();
123 py::class_<BundleType, Type>(m, "BundleType")
124 .def(py::init<const Type::ID &, const BundleType::ChannelVector &>(),
125 py::arg("id"), py::arg("channels"))
126 .def_property_readonly("channels", &BundleType::getChannels,
127 py::return_value_policy::reference);
128 py::class_<VoidType, Type>(m, "VoidType")
129 .def(py::init<const Type::ID &>(), py::arg("id"));
130 py::class_<AnyType, Type>(m, "AnyType")
131 .def(py::init<const Type::ID &>(), py::arg("id"));
132 py::class_<BitVectorType, Type>(m, "BitVectorType")
133 .def(py::init<const Type::ID &, uint64_t>(), py::arg("id"),
134 py::arg("width"))
135 .def_property_readonly("width", &BitVectorType::getWidth);
136 py::class_<BitsType, BitVectorType>(m, "BitsType")
137 .def(py::init<const Type::ID &, uint64_t>(), py::arg("id"),
138 py::arg("width"));
139 py::class_<IntegerType, BitVectorType>(m, "IntegerType")
140 .def(py::init<const Type::ID &, uint64_t>(), py::arg("id"),
141 py::arg("width"));
142 py::class_<SIntType, IntegerType>(m, "SIntType")
143 .def(py::init<const Type::ID &, uint64_t>(), py::arg("id"),
144 py::arg("width"));
145 py::class_<UIntType, IntegerType>(m, "UIntType")
146 .def(py::init<const Type::ID &, uint64_t>(), py::arg("id"),
147 py::arg("width"));
148 py::class_<StructType, Type>(m, "StructType")
149 .def(py::init<const Type::ID &, const StructType::FieldVector &, bool>(),
150 py::arg("id"), py::arg("fields"), py::arg("reverse") = true)
151 .def_property_readonly("fields", &StructType::getFields,
152 py::return_value_policy::reference)
153 .def_property_readonly("reverse", &StructType::isReverse);
154 py::class_<ArrayType, Type>(m, "ArrayType")
155 .def(py::init<const Type::ID &, const Type *, uint64_t>(), py::arg("id"),
156 py::arg("element_type"), py::arg("size"))
157 .def_property_readonly("element", &ArrayType::getElementType,
158 py::return_value_policy::reference)
159 .def_property_readonly("size", &ArrayType::getSize);
160
161 py::class_<Constant>(m, "Constant")
162 .def_property_readonly("value", [](Constant &c) { return c.value; })
163 .def_property_readonly(
164 "type", [](Constant &c) { return getPyType(*c.type); },
165 py::return_value_policy::reference);
166
167 py::class_<AppID>(m, "AppID")
168 .def(py::init<std::string, std::optional<uint32_t>>(), py::arg("name"),
169 py::arg("idx") = std::nullopt)
170 .def_property_readonly("name", [](AppID &id) { return id.name; })
171 .def_property_readonly("idx",
172 [](AppID &id) -> py::object {
173 if (id.idx)
174 return py::cast(id.idx);
175 return py::none();
176 })
177 .def("__repr__",
178 [](AppID &id) {
179 std::string ret = "<" + id.name;
180 if (id.idx)
181 ret = ret + "[" + std::to_string(*id.idx) + "]";
182 ret = ret + ">";
183 return ret;
184 })
185 .def("__eq__", [](AppID &a, AppID &b) { return a == b; })
186 .def("__hash__", [](AppID &id) {
187 return utils::hash_combine(std::hash<std::string>{}(id.name),
188 std::hash<uint32_t>{}(id.idx.value_or(-1)));
189 });
190 py::class_<AppIDPath>(m, "AppIDPath").def("__repr__", &AppIDPath::toStr);
191
192 py::class_<ModuleInfo>(m, "ModuleInfo")
193 .def_property_readonly("name", [](ModuleInfo &info) { return info.name; })
194 .def_property_readonly("summary",
195 [](ModuleInfo &info) { return info.summary; })
196 .def_property_readonly("version",
197 [](ModuleInfo &info) { return info.version; })
198 .def_property_readonly("repo", [](ModuleInfo &info) { return info.repo; })
199 .def_property_readonly("commit_hash",
200 [](ModuleInfo &info) { return info.commitHash; })
201 .def_property_readonly("constants",
202 [](ModuleInfo &info) { return info.constants; })
203 // TODO: "extra" field.
204 .def("__repr__", [](ModuleInfo &info) {
205 std::string ret;
206 std::stringstream os(ret);
207 os << info;
208 return os.str();
209 });
210
211 py::enum_<Logger::Level>(m, "LogLevel")
212 .value("Debug", Logger::Level::Debug)
213 .value("Info", Logger::Level::Info)
214 .value("Warning", Logger::Level::Warning)
215 .value("Error", Logger::Level::Error)
216 .export_values();
217 py::class_<Logger>(m, "Logger");
218
219 py::class_<services::Service>(m, "Service");
220
221 py::class_<SysInfo, services::Service>(m, "SysInfo")
222 .def("esi_version", &SysInfo::getEsiVersion)
223 .def("json_manifest", &SysInfo::getJsonManifest);
224
225 py::class_<MMIO::RegionDescriptor>(m, "MMIORegionDescriptor")
226 .def_property_readonly("base",
227 [](MMIO::RegionDescriptor &r) { return r.base; })
228 .def_property_readonly("size",
229 [](MMIO::RegionDescriptor &r) { return r.size; });
230 py::class_<services::MMIO, services::Service>(m, "MMIO")
231 .def("read", &services::MMIO::read)
232 .def("write", &services::MMIO::write)
233 .def_property_readonly("regions", &services::MMIO::getRegions,
234 py::return_value_policy::reference);
235
236 py::class_<services::HostMem::HostMemRegion>(m, "HostMemRegion")
237 .def_property_readonly("ptr",
239 return reinterpret_cast<uintptr_t>(mem.getPtr());
240 })
241 .def_property_readonly("size",
243
244 py::class_<services::HostMem::Options>(m, "HostMemOptions")
245 .def(py::init<>())
246 .def_readwrite("writeable", &services::HostMem::Options::writeable)
247 .def_readwrite("use_large_pages",
249 .def("__repr__", [](services::HostMem::Options &opts) {
250 std::string ret = "HostMemOptions(";
251 if (opts.writeable)
252 ret += "writeable ";
253 if (opts.useLargePages)
254 ret += "use_large_pages";
255 ret += ")";
256 return ret;
257 });
258
259 py::class_<services::HostMem, services::Service>(m, "HostMem")
260 .def("allocate", &services::HostMem::allocate, py::arg("size"),
261 py::arg("options") = services::HostMem::Options(),
262 py::return_value_policy::take_ownership)
263 .def(
264 "map_memory",
265 [](HostMem &self, uintptr_t ptr, size_t size, HostMem::Options opts) {
266 return self.mapMemory(reinterpret_cast<void *>(ptr), size, opts);
267 },
268 py::arg("ptr"), py::arg("size"),
269 py::arg("options") = services::HostMem::Options())
270 .def(
271 "unmap_memory",
272 [](HostMem &self, uintptr_t ptr) {
273 return self.unmapMemory(reinterpret_cast<void *>(ptr));
274 },
275 py::arg("ptr"));
276
277 // py::class_<std::__basic_future<MessageData>>(m, "MessageDataFuture");
278 py::class_<std::future<MessageData>>(m, "MessageDataFuture")
279 .def("valid",
280 [](std::future<MessageData> &f) {
281 // For some reason, if we just pass the function pointer, pybind11
282 // sees `std::__basic_future` as the type and pybind11_stubgen
283 // emits an error.
284 return f.valid();
285 })
286 .def("wait",
287 [](std::future<MessageData> &f) {
288 // Yield the GIL while waiting for the future to complete, in case
289 // of python callbacks occuring from other threads while waiting.
290 py::gil_scoped_release release{};
291 f.wait();
292 })
293 .def("get", [](std::future<MessageData> &f) {
294 std::optional<MessageData> data;
295 {
296 // Yield the GIL while waiting for the future to complete, in case of
297 // python callbacks occuring from other threads while waiting.
298 py::gil_scoped_release release{};
299 data.emplace(f.get());
300 }
301 return py::bytearray((const char *)data->getBytes(), data->getSize());
302 });
303
304 py::class_<ChannelPort>(m, "ChannelPort")
305 .def("connect", &ChannelPort::connect,
306 py::arg("buffer_size") = std::nullopt)
307 .def("disconnect", &ChannelPort::disconnect)
308 .def_property_readonly("type", &ChannelPort::getType,
309 py::return_value_policy::reference);
310
311 py::class_<WriteChannelPort, ChannelPort>(m, "WriteChannelPort")
312 .def("write",
313 [](WriteChannelPort &p, py::bytearray &data) {
314 py::buffer_info info(py::buffer(data).request());
315 std::vector<uint8_t> dataVec((uint8_t *)info.ptr,
316 (uint8_t *)info.ptr + info.size);
317 p.write(dataVec);
318 })
319 .def("tryWrite", [](WriteChannelPort &p, py::bytearray &data) {
320 py::buffer_info info(py::buffer(data).request());
321 std::vector<uint8_t> dataVec((uint8_t *)info.ptr,
322 (uint8_t *)info.ptr + info.size);
323 return p.tryWrite(dataVec);
324 });
325 py::class_<ReadChannelPort, ChannelPort>(m, "ReadChannelPort")
326 .def(
327 "read",
328 [](ReadChannelPort &p) -> py::bytearray {
329 MessageData data;
330 p.read(data);
331 return py::bytearray((const char *)data.getBytes(), data.getSize());
332 },
333 "Read data from the channel. Blocking.")
334 .def("read_async", &ReadChannelPort::readAsync);
335
336 py::class_<BundlePort>(m, "BundlePort")
337 .def_property_readonly("id", &BundlePort::getID)
338 .def_property_readonly("channels", &BundlePort::getChannels,
339 py::return_value_policy::reference)
340 .def("getWrite", &BundlePort::getRawWrite,
341 py::return_value_policy::reference)
342 .def("getRead", &BundlePort::getRawRead,
343 py::return_value_policy::reference);
344
345 py::class_<ServicePort, BundlePort>(m, "ServicePort");
346
347 py::class_<MMIO::MMIORegion, ServicePort>(m, "MMIORegion")
348 .def_property_readonly("descriptor", &MMIO::MMIORegion::getDescriptor)
349 .def("read", &MMIO::MMIORegion::read)
350 .def("write", &MMIO::MMIORegion::write);
351
352 py::class_<FuncService::Function, ServicePort>(m, "Function")
353 .def(
354 "call",
355 [](FuncService::Function &self,
356 py::bytearray msg) -> std::future<MessageData> {
357 py::buffer_info info(py::buffer(msg).request());
358 std::vector<uint8_t> dataVec((uint8_t *)info.ptr,
359 (uint8_t *)info.ptr + info.size);
360 MessageData data(dataVec);
361 return self.call(data);
362 },
363 py::return_value_policy::take_ownership)
364 .def("connect", &FuncService::Function::connect);
365
366 py::class_<CallService::Callback, ServicePort>(m, "Callback")
367 .def("connect", [](CallService::Callback &self,
368 std::function<py::object(py::object)> pyCallback) {
369 // TODO: Under certain conditions this will cause python to crash. I
370 // don't remember how to replicate these crashes, but IIRC they are
371 // deterministic.
372 self.connect([pyCallback](const MessageData &req) -> MessageData {
373 py::gil_scoped_acquire acquire{};
374 std::vector<uint8_t> arg(req.getBytes(),
375 req.getBytes() + req.getSize());
376 py::bytearray argObj((const char *)arg.data(), arg.size());
377 auto ret = pyCallback(argObj);
378 if (ret.is_none())
379 return MessageData();
380 py::buffer_info info(py::buffer(ret).request());
381 std::vector<uint8_t> dataVec((uint8_t *)info.ptr,
382 (uint8_t *)info.ptr + info.size);
383 return MessageData(dataVec);
384 });
385 });
386
387 py::class_<TelemetryService::Telemetry, ServicePort>(m, "Telemetry")
390
391 // Store this variable (not commonly done) as the "children" method needs for
392 // "Instance" to be defined first.
393 auto hwmodule =
394 py::class_<HWModule>(m, "HWModule")
395 .def_property_readonly("info", &HWModule::getInfo)
396 .def_property_readonly("ports", &HWModule::getPorts,
397 py::return_value_policy::reference)
398 .def_property_readonly("services", &HWModule::getServices,
399 py::return_value_policy::reference);
400
401 // In order to inherit methods from "HWModule", it needs to be defined first.
402 py::class_<Instance, HWModule>(m, "Instance")
403 .def_property_readonly("id", &Instance::getID);
404
405 py::class_<Accelerator, HWModule>(m, "Accelerator");
406
407 // Since this returns a vector of Instance*, we need to define Instance first
408 // or else pybind11-stubgen complains.
409 hwmodule.def_property_readonly("children", &HWModule::getChildren,
410 py::return_value_policy::reference);
411
412 auto accConn = py::class_<AcceleratorConnection>(m, "AcceleratorConnection");
413
414 py::class_<Context>(m, "Context")
415 .def(py::init<>())
416 .def("connect", &Context::connect)
417 .def("set_stdio_logger", [](Context &ctxt, Logger::Level level) {
418 ctxt.setLogger(std::make_unique<StreamLogger>(level));
419 });
420
421 accConn.def(py::init(&registry::connect))
422 .def(
423 "sysinfo",
424 [](AcceleratorConnection &acc) {
425 return acc.getService<services::SysInfo>({});
426 },
427 py::return_value_policy::reference)
428 .def(
429 "get_service_mmio",
430 [](AcceleratorConnection &acc) {
431 return acc.getService<services::MMIO>({});
432 },
433 py::return_value_policy::reference)
434 .def(
435 "get_service_hostmem",
436 [](AcceleratorConnection &acc) {
437 return acc.getService<services::HostMem>({});
438 },
439 py::return_value_policy::reference)
440 .def("get_accelerator", &AcceleratorConnection::getAccelerator,
441 py::return_value_policy::reference);
442
443 py::class_<Manifest>(m, "Manifest")
444 .def(py::init<Context &, std::string>())
445 .def_property_readonly("api_version", &Manifest::getApiVersion)
446 .def(
447 "build_accelerator",
448 [&](Manifest &m, AcceleratorConnection &conn) {
449 auto acc = m.buildAccelerator(conn);
450 conn.getServiceThread()->addPoll(*acc);
451 return acc;
452 },
453 py::return_value_policy::reference)
454 .def_property_readonly("type_table",
455 [](Manifest &m) {
456 std::vector<py::object> ret;
457 std::ranges::transform(m.getTypeTable(),
458 std::back_inserter(ret),
459 getPyType);
460 return ret;
461 })
462 .def_property_readonly("module_infos", &Manifest::getModuleInfos);
463}
Abstract class representing a connection to an accelerator.
Definition Accelerator.h:79
ServiceClass * getService(AppIDPath id={}, std::string implName={}, ServiceImplDetails details={}, HWClientDetails clients={})
Get a typed reference to a particular service type.
Accelerator & getAccelerator()
AcceleratorServiceThread * getServiceThread()
Return a pointer to the accelerator 'service' thread (or threads).
std::string toStr() const
Definition Manifest.cpp:740
uint64_t getWidth() const
Definition Types.h:145
Unidirectional channels are the basic communication primitive between the host and accelerator.
Definition Ports.h:36
const Type * getType() const
Definition Ports.h:62
virtual void disconnect()=0
virtual void connect(std::optional< unsigned > bufferSize=std::nullopt)=0
Set up a connection to the accelerator.
AcceleratorConnections, Accelerators, and Manifests must all share a context.
Definition Context.h:31
const std::map< AppID, BundlePort & > & getPorts() const
Access the module's ports by ID.
Definition Design.h:76
const std::map< AppID, Instance * > & getChildren() const
Access the module's children by ID.
Definition Design.h:67
const std::vector< services::Service * > & getServices() const
Access the services provided by this module.
Definition Design.h:78
std::optional< ModuleInfo > getInfo() const
Access the module's metadata, if any.
Definition Design.h:58
const AppID getID() const
Get the instance's ID, which it will always have.
Definition Design.h:119
Class to parse a manifest.
Definition Manifest.h:39
A logical chunk of data representing serialized data.
Definition Common.h:105
const uint8_t * getBytes() const
Definition Common.h:114
size_t getSize() const
Get the size of the data in bytes.
Definition Common.h:128
A ChannelPort which reads data from the accelerator.
Definition Ports.h:124
virtual std::future< MessageData > readAsync()
Asynchronous read.
Definition Ports.cpp:77
Root class of the ESI type system.
Definition Types.h:33
ID getID() const
Definition Types.h:39
A ChannelPort which sends data to the accelerator.
Definition Ports.h:77
virtual void write(const MessageData &)=0
A very basic blocking write API.
virtual bool tryWrite(const MessageData &data)=0
A basic non-blocking write API.
A function call which gets attached to a service port.
Definition Services.h:313
A function call which gets attached to a service port.
Definition Services.h:263
virtual std::unique_ptr< HostMemRegion > allocate(std::size_t size, Options opts) const =0
Allocate a region of host memory in accelerator accessible address space.
virtual void unmapMemory(void *ptr) const
Unmap memory which was previously mapped with 'mapMemory'.
Definition Services.h:249
virtual bool mapMemory(void *ptr, std::size_t size, Options opts) const
Try to make a region of host memory accessible to the accelerator.
Definition Services.h:244
virtual uint64_t read(uint32_t addr) const
Read a 64-bit value from this region, not the global address space.
Definition Services.cpp:122
virtual void write(uint32_t addr, uint64_t data)
Write a 64-bit value to this region, not the global address space.
Definition Services.cpp:127
virtual RegionDescriptor getDescriptor() const
Get the offset (and size) of the region in the parent (usually global) MMIO address space.
Definition Services.h:167
virtual uint64_t read(uint32_t addr) const =0
Read a 64-bit value from the global MMIO space.
virtual void write(uint32_t addr, uint64_t data)=0
Write a 64-bit value to the global MMIO space.
const std::map< AppIDPath, RegionDescriptor > & getRegions() const
Get the regions of MMIO space that this service manages.
Definition Services.h:137
Parent class of all APIs modeled as 'services'.
Definition Services.h:46
Information about the Accelerator system.
Definition Services.h:100
virtual std::string getJsonManifest() const
Return the JSON-formatted system manifest.
Definition Services.cpp:40
virtual uint32_t getEsiVersion() const =0
Get the ESI version number to check version compatibility.
std::future< MessageData > read()
Definition Services.cpp:322
void connect()
Connect to a particular telemetry port.
Definition Services.cpp:308
Service for retrieving telemetry data from the accelerator.
Definition Services.h:358
PYBIND11_MODULE(esiCppAccel, m)
py::object getPyType(std::optional< const Type * > t)
Resolve a Type to the Python wrapper object.
std::unique_ptr< AcceleratorConnection > connect(Context &ctxt, const std::string &backend, const std::string &connection)
size_t hash_combine(size_t h1, size_t h2)
C++'s stdlib doesn't have a hash_combine function. This is a simple one.
Definition Utils.h:32
Definition esi.py:1
std::any value
Definition Common.h:60
std::optional< const Type * > type
Definition Common.h:61
RAII memory region for host memory.
Definition Services.h:209
virtual void * getPtr() const =0
Get a pointer to the host memory.
virtual std::size_t getSize() const =0
Options for allocating host memory.
Definition Services.h:227
Describe a region (slice) of MMIO space.
Definition Services.h:122
static handle cast(std::any src, return_value_policy, handle)
PYBIND11_TYPE_CASTER(std::any, const_name("object"))
static const void * get(const ChannelPort *port, const std::type_info *&type)
static const void * get(const Service *svc, const std::type_info *&type)