115 nb::set_leak_warnings(
false);
117 nb::class_<Type>(m,
"Type")
118 .def(nb::init<const Type::ID &>(), nb::arg(
"id"))
119 .def_prop_ro(
"id", &Type::getID)
120 .def(
"__repr__", [](
Type &t) {
return "<" + t.
getID() +
">"; });
// ChannelType: constructed from an ID and an inner Type pointer. The `inner`
// property returns the wrapped Type with nb::rv_policy::reference (no copy;
// Python does not take ownership).
// NOTE(review): in this view the constructor's argument list ends with a
// dangling `nb::arg("id"),` directly before .def_prop_ro -- a line appears
// to be missing here; confirm against the full file.
121 nb::class_<ChannelType, Type>(m,
"ChannelType")
122 .def(nb::init<const Type::ID &, const Type *>(), nb::arg(
"id"),
124 .def_prop_ro(
"inner", &ChannelType::getInner, nb::rv_policy::reference);
// Direction: Python enum with members "To" and "From", mirroring
// BundleType::Direction.
// NOTE(review): no terminating ';' is visible for this statement; its
// closing line appears to be missing from this view.
125 nb::enum_<BundleType::Direction>(m,
"Direction")
126 .value(
"To", BundleType::Direction::To)
127 .value(
"From", BundleType::Direction::From)
129 nb::class_<BundleType, Type>(m,
"BundleType")
130 .def(nb::init<const Type::ID &, const BundleType::ChannelVector &>(),
131 nb::arg(
"id"), nb::arg(
"channels"))
132 .def_prop_ro(
"channels", &BundleType::getChannels,
133 nb::rv_policy::reference);
134 nb::class_<VoidType, Type>(m,
"VoidType")
135 .def(nb::init<const Type::ID &>(), nb::arg(
"id"));
136 nb::class_<AnyType, Type>(m,
"AnyType")
137 .def(nb::init<const Type::ID &>(), nb::arg(
"id"));
// TypeAliasType: constructed from an ID, a name and an inner Type pointer.
// NOTE(review): this statement's property bindings appear to be missing from
// this view (only a trailing `nb::rv_policy::reference);` survives).
138 nb::class_<TypeAliasType, Type>(m,
"TypeAliasType")
139 .def(nb::init<const Type::ID &, std::string, const Type *>(),
140 nb::arg(
"id"), nb::arg(
"name"), nb::arg(
"inner_type"))
143 nb::rv_policy::reference);
// Bit-vector family: BitVectorType (derives from Type), BitsType and
// IntegerType (derive from BitVectorType), SIntType and UIntType (derive from
// IntegerType). Each constructor takes an ID and a uint64_t.
// NOTE(review): the remaining constructor arguments and properties of each of
// these statements are not visible here; confirm against the full file.
144 nb::class_<BitVectorType, Type>(m,
"BitVectorType")
145 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
148 nb::class_<BitsType, BitVectorType>(m,
"BitsType")
149 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
151 nb::class_<IntegerType, BitVectorType>(m,
"IntegerType")
152 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
154 nb::class_<SIntType, IntegerType>(m,
"SIntType")
155 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
157 nb::class_<UIntType, IntegerType>(m,
"UIntType")
158 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
160 nb::class_<StructType, Type>(m,
"StructType")
161 .def(nb::init<const Type::ID &, const StructType::FieldVector &, bool>(),
162 nb::arg(
"id"), nb::arg(
"fields"), nb::arg(
"reverse") =
true)
163 .def_prop_ro(
"fields", &StructType::getFields, nb::rv_policy::reference)
164 .def_prop_ro(
"reverse", &StructType::isReverse);
165 nb::class_<ArrayType, Type>(m,
"ArrayType")
166 .def(nb::init<const Type::ID &, const Type *, uint64_t>(), nb::arg(
"id"),
167 nb::arg(
"element_type"), nb::arg(
"size"))
168 .def_prop_ro(
"element", &ArrayType::getElementType,
169 nb::rv_policy::reference)
170 .def_prop_ro(
"size", &ArrayType::getSize);
// WindowField(name, num_items=0, bulk_count_width=0) and
// WindowFrame(name, fields): Python bindings for the nested types
// WindowType::Field and WindowType::Frame that make up a WindowType.
// NOTE(review): neither statement shows a terminating ';' or any property
// bindings in this view -- lines appear to be missing; confirm against the
// full file.
171 nb::class_<WindowType::Field>(m,
"WindowField")
173 .def(nb::init<std::string, uint64_t, uint64_t>(), nb::arg(
"name"),
174 nb::arg(
"num_items") = 0, nb::arg(
"bulk_count_width") = 0)
178 nb::class_<WindowType::Frame>(m,
"WindowFrame")
180 .def(nb::init<std::string,
const std::vector<WindowType::Field> &>(),
181 nb::arg(
"name"), nb::arg(
"fields"))
184 nb::class_<WindowType, Type>(m,
"WindowType")
185 .def(nb::init<
const Type::ID &,
const std::string &,
const Type *,
186 const Type *,
const std::vector<WindowType::Frame> &>(),
187 nb::arg(
"id"), nb::arg(
"name"), nb::arg(
"into_type"),
188 nb::arg(
"lowered_type"), nb::arg(
"frames"))
189 .def_prop_ro(
"name", &WindowType::getName)
190 .def_prop_ro(
"into", &WindowType::getIntoType, nb::rv_policy::reference)
191 .def_prop_ro(
"lowered", &WindowType::getLoweredType,
192 nb::rv_policy::reference)
193 .def_prop_ro(
"frames", &WindowType::getFrames, nb::rv_policy::reference);
194 nb::class_<ListType, Type>(m,
"ListType")
195 .def(nb::init<const Type::ID &, const Type *>(), nb::arg(
"id"),
196 nb::arg(
"element_type"))
197 .def_prop_ro(
"element", &ListType::getElementType,
198 nb::rv_policy::reference);
// UnionType: constructed from an ID and a field vector.
// NOTE(review): the rest of this statement is not visible in this view.
199 nb::class_<UnionType, Type>(m,
"UnionType")
200 .def(nb::init<const Type::ID &, const UnionType::FieldVector &>(),
201 nb::arg(
"id"), nb::arg(
"fields"))
// Constant: bound as a Python class; its members are not visible here.
204 nb::class_<Constant>(m,
"Constant")
// AppID(name, idx=None): `idx` is a std::optional<uint32_t>; the visible
// getter converts it with nb::cast, and the repr appends "[idx]" only when
// an index is present (it dereferences id.idx).
// NOTE(review): several lines of this statement are missing from this view.
208 nb::class_<AppID>(m,
"AppID")
209 .def(nb::init<std::string, std::optional<uint32_t>>(), nb::arg(
"name"),
210 nb::arg(
"idx") = std::nullopt)
211 .def_prop_ro(
"name", [](
AppID &
id) {
return id.name; })
213 [](
AppID &
id) -> nb::object {
215 return nb::cast(
id.idx);
220 std::string ret =
"<" +
id.name;
222 ret = ret +
"[" + std::to_string(*
id.idx) +
"]";
// __eq__ delegates to AppID's operator==.
226 .def(
"__eq__", [](
AppID &a,
AppID &b) {
return a == b; })
// __hash__: the visible fragment hashes idx via value_or(-1), so a missing
// idx is hashed as uint32_t(-1) == 0xFFFFFFFF -- the same value as an
// explicit idx of 0xFFFFFFFF. That is a hash collision, not an equality bug.
// NOTE(review): part of this lambda (e.g. any name hashing) is not visible.
227 .def(
"__hash__", [](
AppID &
id) {
229 std::hash<uint32_t>{}(
id.idx.value_or(-1)));
// ModuleInfo: read-only properties copied out of the ModuleInfo struct
// fields (name, summary, version, repo, commitHash, constants).
233 nb::class_<ModuleInfo>(m,
"ModuleInfo")
234 .def_prop_ro(
"name", [](
ModuleInfo &info) {
return info.name; })
235 .def_prop_ro(
"summary", [](
ModuleInfo &info) {
return info.summary; })
236 .def_prop_ro(
"version", [](
ModuleInfo &info) {
return info.version; })
237 .def_prop_ro(
"repo", [](
ModuleInfo &info) {
return info.repo; })
238 .def_prop_ro(
"commit_hash",
239 [](
ModuleInfo &info) {
return info.commitHash; })
240 .def_prop_ro(
"constants", [](
ModuleInfo &info) {
return info.constants; })
// NOTE(review): a std::stringstream constructed from `ret` with the default
// openmode starts its put position at the beginning, so `os << ...` would
// overwrite (not append to) any existing contents of `ret`. That is harmless
// only if `ret` is empty here -- confirm against the surrounding lines.
244 std::stringstream os(ret);
// LogLevel: Python enum mirroring Logger::Level (Debug, Info, Warning, Error).
// NOTE(review): no terminating ';' is visible for this statement in this view.
249 nb::enum_<Logger::Level>(m,
"LogLevel")
250 .value(
"Debug", Logger::Level::Debug)
251 .value(
"Info", Logger::Level::Info)
252 .value(
"Warning", Logger::Level::Warning)
253 .value(
"Error", Logger::Level::Error)
255 nb::class_<Logger>(m,
"Logger");
// Service bindings: services::Service is the Python base "Service"; SysInfo,
// MMIO, HostMem and TelemetryService are registered as subclasses of it.
// NOTE(review): statements in this region are heavily truncated in this
// view; only fragments of each binding are visible.
257 nb::class_<services::Service>(m,
"Service")
260 nb::class_<SysInfo, services::Service>(m,
"SysInfo")
264 "Get the current cycle count of the accelerator system")
266 "Get the core clock frequency of the accelerator system in Hz");
268 nb::class_<MMIO::RegionDescriptor>(m,
"MMIORegionDescriptor")
271 nb::class_<services::MMIO, services::Service>(m,
"MMIO")
275 nb::rv_policy::reference);
// HostMemRegion: the region's pointer is exposed to Python as a plain
// integer (reinterpret_cast<uintptr_t> of getPtr()).
277 nb::class_<services::HostMem::HostMemRegion>(m,
"HostMemRegion")
280 return reinterpret_cast<uintptr_t
>(mem.
getPtr());
// HostMemOptions: the visible repr fragment builds "HostMemOptions(...)" and
// appends "use_large_pages" (presumably only when that option is set).
284 nb::class_<services::HostMem::Options>(m,
"HostMemOptions")
289 std::string ret =
"HostMemOptions(";
293 ret +=
"use_large_pages";
// HostMem: mapMemory/unmapMemory take an integer address from Python and
// reinterpret_cast it to void*. The address is not validated in the visible
// code -- NOTE(review): callers presumably must pass a valid host pointer.
298 nb::class_<services::HostMem, services::Service>(m,
"HostMem")
301 nb::rv_policy::take_ownership)
305 return self.
mapMemory(
reinterpret_cast<void *
>(ptr), size, opts);
307 nb::arg(
"ptr"), nb::arg(
"size"),
311 [](
HostMem &self, uintptr_t ptr) {
312 return self.
unmapMemory(
reinterpret_cast<void *
>(ptr));
315 nb::class_<services::TelemetryService, services::Service>(m,
// MessageDataFuture: wraps std::future<MessageData>. The visible `get` path
// creates an nb::gil_scoped_release before calling f.get(), stores the result
// in a std::optional, and returns a copy of the payload as a Python bytearray.
// NOTE(review): the scoping of the GIL release is not visible here; confirm it
// ends before the bytearray is constructed.
318 nb::class_<std::future<MessageData>>(m,
"MessageDataFuture")
319 .def(
"valid", [](std::future<MessageData> &f) {
return f.valid(); })
321 [](std::future<MessageData> &f) {
324 nb::gil_scoped_release release{};
327 .def(
"get", [](std::future<MessageData> &f) {
328 std::optional<MessageData> data;
332 nb::gil_scoped_release release{};
333 data.emplace(f.get());
335 return nb::bytearray((
const char *)data->getBytes(), data->getSize());
// ConnectOptions: `buffer_size` accepts None (nb::arg(...).none()).
338 nb::class_<ChannelPort::ConnectOptions>(m,
"ConnectOptions")
341 nb::arg(
"buffer_size").none())
342 .def_rw(
"translate_message",
345 nb::class_<ChannelPort>(m,
"ChannelPort")
347 "Connect with specified options")
// WriteChannelPort: the visible write paths copy the incoming Python buffer
// into a std::vector<uint8_t> before handing it to C++.
351 nb::class_<WriteChannelPort, ChannelPort>(m,
"WriteChannelPort")
354 std::vector<uint8_t> dataVec((
const uint8_t *)data.c_str(),
355 (
const uint8_t *)data.c_str() +
360 std::vector<uint8_t> dataVec((
const uint8_t *)data.c_str(),
361 (
const uint8_t *)data.c_str() +
// ReadChannelPort: the blocking read returns a copy of the message bytes as a
// Python bytearray.
365 nb::class_<ReadChannelPort, ChannelPort>(m,
"ReadChannelPort")
371 return nb::bytearray((
const char *)data.getBytes(), data.getSize());
373 "Read data from the channel. Blocking.")
376 nb::class_<BundlePort>(m,
"BundlePort")
377 .def_prop_ro(
"id", &BundlePort::getID)
378 .def_prop_ro(
"channels", &BundlePort::getChannels,
379 nb::rv_policy::reference)
380 .def(
"getWrite", &BundlePort::getRawWrite, nb::rv_policy::reference)
381 .def(
"getRead", &BundlePort::getRawRead, nb::rv_policy::reference);
383 nb::class_<ServicePort, BundlePort>(m,
"ServicePort");
// Service-port subclasses (MMIORegion, Function, Callback, Metric,
// ToHostChannel, FromHostChannel) and the HWModule/Instance hierarchy.
// NOTE(review): statements in this region are heavily truncated in this view.
385 nb::class_<MMIO::MMIORegion, ServicePort>(m,
"MMIORegion")
// Function: the visible call path copies the Python bytearray into a
// std::vector<uint8_t> and returns the std::future<MessageData> from
// self.call(...) (exposed to Python as MessageDataFuture).
390 nb::class_<FuncService::Function, ServicePort>(m,
"Function")
393 nb::bytearray msg) -> std::future<MessageData> {
394 std::vector<uint8_t> dataVec((
const uint8_t *)msg.c_str(),
395 (
const uint8_t *)msg.c_str() +
398 return self.call(data);
// Callback: wraps a Python callable. Before touching Python objects the
// wrapper takes the GIL (nb::gil_scoped_acquire); presumably it runs on a
// non-Python thread -- confirm. The request is passed to Python as `bytes`,
// but the result is cast with nb::cast<nb::bytearray>.
// NOTE(review): nb::cast throws on a type mismatch, so a callback that returns
// `bytes` instead of `bytearray` is likely to fail here -- confirm intended.
402 nb::class_<CallService::Callback, ServicePort>(m,
"Callback")
404 std::function<nb::object(nb::object)> pyCallback) {
409 nb::gil_scoped_acquire acquire{};
410 std::vector<uint8_t> arg(req.
getBytes(),
412 nb::bytes argObj((
const char *)arg.data(), arg.size());
413 auto ret = pyCallback(argObj);
416 nb::bytearray retBytes = nb::cast<nb::bytearray>(ret);
417 std::vector<uint8_t> dataVec((
const uint8_t *)retBytes.c_str(),
418 (
const uint8_t *)retBytes.c_str() +
424 nb::class_<TelemetryService::Metric, ServicePort>(m,
"Metric")
429 nb::class_<ChannelService::ToHost, ServicePort>(m,
"ToHostChannel")
// FromHostChannel: the visible write path copies the Python buffer into a
// std::vector<uint8_t>.
433 nb::class_<ChannelService::FromHost, ServicePort>(m,
"FromHostChannel")
438 std::vector<uint8_t> dataVec((
const uint8_t *)data.c_str(),
439 (
const uint8_t *)data.c_str() +
448 nb::class_<HWModule>(m,
"HWModule")
452 nb::rv_policy::reference);
455 nb::class_<Instance, HWModule>(m,
"Instance")
458 nb::class_<Accelerator, HWModule>(m,
"Accelerator");
// NOTE(review): tail of a binding whose preceding lines are not visible in
// this view; it returns its result by reference (no ownership transfer).
463 nb::rv_policy::reference);
465 auto accConn = nb::class_<AcceleratorConnection>(m,
"AcceleratorConnection");
// Context binding (its class_ header is not visible in this view). The
// docstring below states that the context owns all types and connections,
// which is why `connect` and the service getters return with
// nb::rv_policy::reference (Python does not own the returned objects).
469 "An ESI context owns everything -- types, accelerator connections, and "
470 "the accelerator facade (aka Accelerator) itself. It MUST NOT be garbage "
471 "collected while the accelerator is still in use. When it is destroyed, "
472 "all accelerator connections are disconnected.")
473 .def(nb::init<>(),
"Create a context with a default logger.")
474 .def(
"connect", &Context::connect, nb::rv_policy::reference)
// Replaces the context's logger with a StreamLogger at the given level.
476 ctxt.
setLogger(std::make_unique<StreamLogger>(level));
485 nb::rv_policy::reference)
491 nb::rv_policy::reference)
493 "get_service_hostmem",
497 nb::rv_policy::reference)
499 nb::rv_policy::reference);
// Manifest: constructed from a Context and a string (presumably the manifest
// JSON -- confirm). Building an accelerator also registers it with the
// connection's service thread for polling (addPoll).
// NOTE(review): inside the lambdas below, `m` appears to be the Manifest
// parameter, shadowing the module handle `m`; consider renaming for clarity.
501 nb::class_<Manifest>(m,
"Manifest")
502 .def(nb::init<Context &, std::string>())
503 .def_prop_ro(
"api_version", &Manifest::getApiVersion)
507 auto *acc = m.buildAccelerator(conn);
508 conn.getServiceThread()->addPoll(*acc);
511 nb::rv_policy::reference)
512 .def_prop_ro(
"type_table",
514 std::vector<nb::object> ret;
515 std::ranges::transform(m.getTypeTable(),
519 .def_prop_ro(
"module_infos", &Manifest::getModuleInfos);