115 nb::set_leak_warnings(
false);
117 nb::class_<Type>(m,
"Type")
118 .def(nb::init<const Type::ID &>(), nb::arg(
"id"))
119 .def_prop_ro(
"id", &Type::getID)
120 .def_prop_ro(
"bit_width", &Type::getBitWidth)
121 .def(
"__repr__", [](
Type &t) {
return "<" + t.
getID() +
">"; });
122 nb::class_<ChannelType, Type>(m,
"ChannelType")
123 .def(nb::init<const Type::ID &, const Type *>(), nb::arg(
"id"),
125 .def_prop_ro(
"inner", &ChannelType::getInner, nb::rv_policy::reference);
126 nb::enum_<BundleType::Direction>(m,
"Direction")
127 .value(
"To", BundleType::Direction::To)
128 .value(
"From", BundleType::Direction::From)
130 nb::class_<BundleType, Type>(m,
"BundleType")
131 .def(nb::init<const Type::ID &, const BundleType::ChannelVector &>(),
132 nb::arg(
"id"), nb::arg(
"channels"))
133 .def_prop_ro(
"channels", &BundleType::getChannels,
134 nb::rv_policy::reference);
135 nb::class_<VoidType, Type>(m,
"VoidType")
136 .def(nb::init<const Type::ID &>(), nb::arg(
"id"));
137 nb::class_<AnyType, Type>(m,
"AnyType")
138 .def(nb::init<const Type::ID &>(), nb::arg(
"id"));
139 nb::class_<TypeAliasType, Type>(m,
"TypeAliasType")
140 .def(nb::init<const Type::ID &, std::string, const Type *>(),
141 nb::arg(
"id"), nb::arg(
"name"), nb::arg(
"inner_type"))
144 nb::rv_policy::reference);
145 nb::class_<BitVectorType, Type>(m,
"BitVectorType")
146 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
149 nb::class_<BitsType, BitVectorType>(m,
"BitsType")
150 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
152 nb::class_<IntegerType, BitVectorType>(m,
"IntegerType")
153 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
155 nb::class_<SIntType, IntegerType>(m,
"SIntType")
156 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
158 nb::class_<UIntType, IntegerType>(m,
"UIntType")
159 .def(nb::init<const Type::ID &, uint64_t>(), nb::arg(
"id"),
161 nb::class_<StructType, Type>(m,
"StructType")
162 .def(nb::init<const Type::ID &, const StructType::FieldVector &, bool>(),
163 nb::arg(
"id"), nb::arg(
"fields"), nb::arg(
"reverse") =
true)
164 .def_prop_ro(
"fields", &StructType::getFields, nb::rv_policy::reference)
165 .def_prop_ro(
"reverse", &StructType::isReverse);
166 nb::class_<ArrayType, Type>(m,
"ArrayType")
167 .def(nb::init<const Type::ID &, const Type *, uint64_t>(), nb::arg(
"id"),
168 nb::arg(
"element_type"), nb::arg(
"size"))
169 .def_prop_ro(
"element", &ArrayType::getElementType,
170 nb::rv_policy::reference)
171 .def_prop_ro(
"size", &ArrayType::getSize);
172 nb::class_<WindowType::Field>(m,
"WindowField")
174 .def(nb::init<std::string, uint64_t, uint64_t>(), nb::arg(
"name"),
175 nb::arg(
"num_items") = 0, nb::arg(
"bulk_count_width") = 0)
179 nb::class_<WindowType::Frame>(m,
"WindowFrame")
181 .def(nb::init<std::string,
const std::vector<WindowType::Field> &>(),
182 nb::arg(
"name"), nb::arg(
"fields"))
185 nb::class_<WindowType, Type>(m,
"WindowType")
186 .def(nb::init<
const Type::ID &,
const std::string &,
const Type *,
187 const Type *,
const std::vector<WindowType::Frame> &>(),
188 nb::arg(
"id"), nb::arg(
"name"), nb::arg(
"into_type"),
189 nb::arg(
"lowered_type"), nb::arg(
"frames"))
190 .def_prop_ro(
"name", &WindowType::getName)
191 .def_prop_ro(
"into", &WindowType::getIntoType, nb::rv_policy::reference)
192 .def_prop_ro(
"lowered", &WindowType::getLoweredType,
193 nb::rv_policy::reference)
194 .def_prop_ro(
"frames", &WindowType::getFrames, nb::rv_policy::reference);
195 nb::class_<ListType, Type>(m,
"ListType")
196 .def(nb::init<const Type::ID &, const Type *>(), nb::arg(
"id"),
197 nb::arg(
"element_type"))
198 .def_prop_ro(
"element", &ListType::getElementType,
199 nb::rv_policy::reference);
200 nb::class_<UnionType, Type>(m,
"UnionType")
201 .def(nb::init<const Type::ID &, const UnionType::FieldVector &>(),
202 nb::arg(
"id"), nb::arg(
"fields"))
205 nb::class_<Constant>(m,
"Constant")
209 nb::class_<AppID>(m,
"AppID")
210 .def(nb::init<std::string, std::optional<uint32_t>>(), nb::arg(
"name"),
211 nb::arg(
"idx") = std::nullopt)
212 .def_prop_ro(
"name", [](
AppID &
id) {
return id.name; })
214 [](
AppID &
id) -> nb::object {
216 return nb::cast(
id.idx);
221 std::string ret =
"<" +
id.name;
223 ret = ret +
"[" + std::to_string(*
id.idx) +
"]";
227 .def(
"__eq__", [](
AppID &a,
AppID &b) {
return a == b; })
228 .def(
"__hash__", [](
AppID &
id) {
230 std::hash<uint32_t>{}(
id.idx.value_or(-1)));
234 nb::class_<ModuleInfo>(m,
"ModuleInfo")
235 .def_prop_ro(
"name", [](
ModuleInfo &info) {
return info.name; })
236 .def_prop_ro(
"summary", [](
ModuleInfo &info) {
return info.summary; })
237 .def_prop_ro(
"version", [](
ModuleInfo &info) {
return info.version; })
238 .def_prop_ro(
"repo", [](
ModuleInfo &info) {
return info.repo; })
239 .def_prop_ro(
"commit_hash",
240 [](
ModuleInfo &info) {
return info.commitHash; })
241 .def_prop_ro(
"constants", [](
ModuleInfo &info) {
return info.constants; })
245 std::stringstream os(ret);
250 nb::enum_<Logger::Level>(m,
"LogLevel")
251 .value(
"Debug", Logger::Level::Debug)
252 .value(
"Info", Logger::Level::Info)
253 .value(
"Warning", Logger::Level::Warning)
254 .value(
"Error", Logger::Level::Error)
256 nb::class_<Logger>(m,
"Logger");
258 nb::class_<services::Service>(m,
"Service")
261 nb::class_<SysInfo, services::Service>(m,
"SysInfo")
265 "Get the current cycle count of the accelerator system")
267 "Get the core clock frequency of the accelerator system in Hz");
269 nb::class_<MMIO::RegionDescriptor>(m,
"MMIORegionDescriptor")
272 nb::class_<services::MMIO, services::Service>(m,
"MMIO")
276 nb::rv_policy::reference);
278 nb::class_<services::HostMem::HostMemRegion>(m,
"HostMemRegion")
281 return reinterpret_cast<uintptr_t
>(mem.
getPtr());
285 nb::class_<services::HostMem::Options>(m,
"HostMemOptions")
290 std::string ret =
"HostMemOptions(";
294 ret +=
"use_large_pages";
299 nb::class_<services::HostMem, services::Service>(m,
"HostMem")
302 nb::rv_policy::take_ownership)
306 return self.
mapMemory(
reinterpret_cast<void *
>(ptr), size, opts);
308 nb::arg(
"ptr"), nb::arg(
"size"),
312 [](
HostMem &self, uintptr_t ptr) {
313 return self.
unmapMemory(
reinterpret_cast<void *
>(ptr));
316 nb::class_<services::TelemetryService, services::Service>(m,
319 nb::class_<std::future<MessageData>>(m,
"MessageDataFuture")
320 .def(
"valid", [](std::future<MessageData> &f) {
return f.valid(); })
322 [](std::future<MessageData> &f) {
325 nb::gil_scoped_release release{};
328 .def(
"get", [](std::future<MessageData> &f) {
329 std::optional<MessageData> data;
333 nb::gil_scoped_release release{};
334 data.emplace(f.get());
336 return nb::bytearray((
const char *)data->getBytes(), data->getSize());
339 nb::class_<ChannelPort::ConnectOptions>(m,
"ConnectOptions")
342 nb::arg(
"buffer_size").none())
343 .def_rw(
"translate_message",
346 nb::class_<ChannelPort>(m,
"ChannelPort")
348 "Connect with specified options")
352 nb::rv_policy::reference);
354 nb::class_<WriteChannelPort, ChannelPort>(m,
"WriteChannelPort")
357 std::vector<uint8_t> dataVec((
const uint8_t *)data.c_str(),
358 (
const uint8_t *)data.c_str() +
363 std::vector<uint8_t> dataVec((
const uint8_t *)data.c_str(),
364 (
const uint8_t *)data.c_str() +
368 nb::class_<ReadChannelPort, ChannelPort>(m,
"ReadChannelPort")
374 return nb::bytearray((
const char *)data.getBytes(), data.getSize());
376 "Read data from the channel. Blocking.")
379 nb::class_<BundlePort>(m,
"BundlePort")
380 .def_prop_ro(
"id", &BundlePort::getID)
381 .def_prop_ro(
"channels", &BundlePort::getChannels,
382 nb::rv_policy::reference)
383 .def(
"getWrite", &BundlePort::getRawWrite, nb::rv_policy::reference)
384 .def(
"getRead", &BundlePort::getRawRead, nb::rv_policy::reference);
386 nb::class_<ServicePort, BundlePort>(m,
"ServicePort");
388 nb::class_<MMIO::MMIORegion, ServicePort>(m,
"MMIORegion")
393 nb::class_<FuncService::Function, ServicePort>(m,
"Function")
396 nb::bytearray msg) -> std::future<MessageData> {
397 std::vector<uint8_t> dataVec((
const uint8_t *)msg.c_str(),
398 (
const uint8_t *)msg.c_str() +
401 return self.call(data);
405 nb::class_<CallService::Callback, ServicePort>(m,
"Callback")
407 std::function<nb::object(nb::object)> pyCallback) {
412 nb::gil_scoped_acquire acquire{};
413 std::vector<uint8_t> arg(req.
getBytes(),
415 nb::bytes argObj((
const char *)arg.data(), arg.size());
416 auto ret = pyCallback(argObj);
419 nb::bytearray retBytes = nb::cast<nb::bytearray>(ret);
420 std::vector<uint8_t> dataVec((
const uint8_t *)retBytes.c_str(),
421 (
const uint8_t *)retBytes.c_str() +
427 nb::class_<TelemetryService::Metric, ServicePort>(m,
"Metric")
432 nb::class_<ChannelService::ToHost, ServicePort>(m,
"ToHostChannel")
436 nb::class_<ChannelService::FromHost, ServicePort>(m,
"FromHostChannel")
441 std::vector<uint8_t> dataVec((
const uint8_t *)data.c_str(),
442 (
const uint8_t *)data.c_str() +
451 nb::class_<HWModule>(m,
"HWModule")
455 nb::rv_policy::reference);
458 nb::class_<Instance, HWModule>(m,
"Instance")
461 nb::class_<Accelerator, HWModule>(m,
"Accelerator");
466 nb::rv_policy::reference);
468 auto accConn = nb::class_<AcceleratorConnection>(m,
"AcceleratorConnection");
472 "An ESI context owns everything -- types, accelerator connections, and "
473 "the accelerator facade (aka Accelerator) itself. It MUST NOT be garbage "
474 "collected while the accelerator is still in use. When it is destroyed, "
475 "all accelerator connections are disconnected.")
476 .def(nb::init<>(),
"Create a context with a default logger.")
477 .def(
"connect", &Context::connect, nb::rv_policy::reference)
479 ctxt.
setLogger(std::make_unique<StreamLogger>(level));
488 nb::rv_policy::reference)
494 nb::rv_policy::reference)
496 "get_service_hostmem",
500 nb::rv_policy::reference)
502 nb::rv_policy::reference);
504 nb::class_<Manifest>(m,
"Manifest")
505 .def(nb::init<Context &, std::string>())
506 .def_prop_ro(
"api_version", &Manifest::getApiVersion)
510 auto *acc = m.buildAccelerator(conn);
511 conn.getServiceThread()->addPoll(*acc);
514 nb::rv_policy::reference)
515 .def_prop_ro(
"type_table",
517 std::vector<nb::object> ret;
518 std::ranges::transform(m.getTypeTable(),
522 .def_prop_ro(
"module_infos", &Manifest::getModuleInfos);