11#include "mlir/Bindings/Python/NanobindAdaptors.h"
14#include "mlir-c/Support.h"
15#include <nanobind/nanobind.h>
20using namespace mlir::python::nanobind_adaptors;
// Body fragment of populateSupportSubmodule (its signature appears at the end
// of this excerpt). It sets the Python module docstring and defines a binding
// lambda that walks an operation and invokes a Python callback only on ops
// whose name is in a caller-supplied filter list.
// NOTE(review): this excerpt is extraction residue. Several original lines
// (presumably the m.def call and binding name, the UserData struct header,
// switch case labels, the try block opener, and closing braces) are not
// visible, so the code does not compile as shown. Restore from version
// control before editing the logic.
24 m.doc() =
"CIRCT Python utils";
// Binding lambda (presumably passed to m.def; the call line is not visible).
//   operation    - root operation handed to mlirOperationWalk.
//   opNames      - op names to filter on; each is interned as an
//                  MlirIdentifier and compared with mlirOperationGetName(op).
//   callback     - Python callable invoked on matching ops; the integer in
//                  its return value's `.value` attribute selects the
//                  MlirWalkResult.
//   walkOrderRaw - Python object whose integer `.value` attribute selects
//                  pre- or post-order traversal.
// Throws std::runtime_error (message prefixed "Exception raised in callback: ")
// if the callback raised a Python exception during the walk.
28 [](MlirOperation operation,
const std::vector<std::string> &opNames,
29 std::function<nb::object(MlirOperation)> callback,
30 nb::object walkOrderRaw) {
// Walk state passed to the C callback through its `void *userData`
// argument (the enclosing `struct UserData {` line is not visible here).
32 std::function<nb::object(MlirOperation)> callback;
// Message and Python type of an exception raised by `callback`, recorded
// inside the C callback and reported after mlirOperationWalk returns.
// NOTE(review): nb::handle does not own a reference, and exceptionType is
// never read in this excerpt. The rethrow below drops the original Python
// exception type, so callers see a generic RuntimeError.
34 std::string exceptionWhat;
35 nb::handle exceptionType;
// Interned identifiers for the op-name filter.
36 std::vector<MlirIdentifier> opNames;
// Convert the Python walk-order object to MlirWalkOrder by hand, reading its
// integer `.value` attribute.
// NOTE(review): the case labels are not visible. If the switch has no
// `default:` branch, an unexpected value leaves walkOrder uninitialized
// before it is passed to mlirOperationWalk. Confirm, and consider raising
// instead.
44 MlirWalkOrder walkOrder;
45 auto walkOrderRawValue = nb::cast<int>(walkOrderRaw.attr(
"value"));
46 switch (walkOrderRawValue) {
48 walkOrder = MlirWalkOrder::MlirWalkPreOrder;
51 walkOrder = MlirWalkOrder::MlirWalkPostOrder;
// Intern each filter name once in the operation's context, so the
// per-op check below is an mlirIdentifierEqual comparison.
55 std::vector<MlirIdentifier> opNamesIdentifiers;
56 opNamesIdentifiers.reserve(opNames.size());
59 for (
auto &opName : opNames)
60 opNamesIdentifiers.push_back(mlirIdentifierGet(
61 mlirOperationGetContext(operation),
62 mlirStringRefCreateFromCString(opName.c_str())));
// Aggregate-initialize the walk state: callback, gotException = false,
// empty message/type, and the interned filter (the opening `UserData ...{`
// line is not visible).
// NOTE(review): opNamesIdentifiers is copied here; std::move would avoid the
// copy since it is not used afterwards in this excerpt.
65 std::move(callback),
false, {}, {}, opNamesIdentifiers};
// Captureless lambda, so it can convert to the C API callback type
// (presumably a plain function-pointer typedef). All state arrives through
// the userData pointer instead of captures.
66 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
68 UserData *calleeUserData =
static_cast<UserData *
>(userData);
69 auto opName = mlirOperationGetName(op);
// Linear scan of the filter for this op's name.
72 bool inFilter =
false;
73 for (
auto &opNamesIdentifier : calleeUserData->opNames) {
74 if (mlirIdentifierEqual(opName, opNamesIdentifier)) {
// Ops not in the filter presumably continue the walk here without invoking
// the Python callback (the guarding `if` line is not visible).
82 return MlirWalkResult::MlirWalkResultAdvance;
// Invoke the Python callback and map the integer `.value` of its result to
// an MlirWalkResult by hand.
// NOTE(review): as with walkOrder, if the switch has no `default:` branch,
// an unexpected value leaves walkResult uninitialized. Only nb::python_error
// is caught below, so any other C++ exception thrown here would propagate
// through the C walk. Confirm both are acceptable.
90 MlirWalkResult walkResult;
91 auto walkResultRaw = (calleeUserData->callback)(op);
92 auto walkResultRawValue =
93 nb::cast<int>(walkResultRaw.attr(
"value"));
94 switch (walkResultRawValue) {
96 walkResult = MlirWalkResult::MlirWalkResultAdvance;
99 walkResult = MlirWalkResult::MlirWalkResultInterrupt;
102 walkResult = MlirWalkResult::MlirWalkResultSkip;
106 }
// A Python exception from the callback is recorded and the walk is
// interrupted, instead of unwinding through the C API.
catch (nb::python_error &e) {
107 calleeUserData->gotException =
true;
108 calleeUserData->exceptionWhat = e.what();
109 calleeUserData->exceptionType = e.type();
110 return MlirWalkResult::MlirWalkResultInterrupt;
// Run the walk, then surface any recorded callback exception as a C++
// std::runtime_error carrying the original message.
113 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
114 if (userData.gotException) {
115 std::string message(
"Exception raised in callback: ");
116 message.append(userData.exceptionWhat);
117 throw std::runtime_error(message);
// Python-visible keyword argument names for the binding.
120 nb::arg(
"op"), nb::arg(
"op_names"), nb::arg(
"callback"),
121 nb::arg(
"walk_order"));
// Signature of the function whose body fragment appears above.
void populateSupportSubmodule(nanobind::module_ &m)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.