11 #include "mlir/Bindings/Python/PybindAdaptors.h"
14 #include "mlir-c/Support.h"
15 #include <pybind11/pybind11.h>
16 #include <pybind11/pytypes.h>
17 #include <pybind11/stl.h>
21 using namespace circt;
22 using namespace mlir::python::adaptors;
// Module-level docstring for this Python submodule.
26 m.doc() =
"CIRCT Python utils";
// Binding implementation: walks `operation` and invokes the Python `callback`
// only on ops whose name is listed in `opNames`. The enclosing `m.def(` call
// and the Python-visible function name are not visible in this chunk.
// Parameters (Python keyword names are bound at the end of this fragment):
//   operation    - root operation to walk.
//   opNames      - operation names to filter on.
//   callback     - Python callable applied to each matching op; the integer
//                  `.value` attribute of its result is read as a walk-result
//                  code.
//   walkOrderRaw - Python object whose integer `.value` attribute is read as
//                  a walk-order code.
// Throws std::runtime_error if the callback raised a Python exception.
30 [](MlirOperation operation,
const std::vector<std::string> &opNames,
31 std::function<py::object(MlirOperation)> callback,
32 py::object walkOrderRaw) {
// Fields of the per-walk state (`UserData`) handed to the C walk callback
// through its `void *` argument. The struct header and the `gotException`
// flag declaration are not visible in this chunk.
34 std::function<py::object(MlirOperation)> callback;
36 std::string exceptionWhat;
37 py::object exceptionType;
38 std::vector<MlirIdentifier> opNames;
// Convert the Python walk-order object to MlirWalkOrder by reading its
// integer `.value` instead of relying on automatic enum casting.
// NOTE(review): the case labels are not visible here. If no case (and no
// default) matches, `walkOrder` would reach mlirOperationWalk uninitialized.
// Confirm that a default branch exists, or initialize the variable.
46 MlirWalkOrder walkOrder;
47 auto walkOrderRawValue = py::cast<int>(walkOrderRaw.attr(
"value"));
48 switch (walkOrderRawValue) {
50 walkOrder = MlirWalkOrder::MlirWalkPreOrder;
53 walkOrder = MlirWalkOrder::MlirWalkPostOrder;
// Intern each filter name as an MlirIdentifier in the operation's context.
// The per-op check below can then use mlirIdentifierEqual instead of
// comparing strings.
57 std::vector<MlirIdentifier> opNamesIdentifiers;
58 opNamesIdentifiers.reserve(opNames.size());
61 for (
auto &opName : opNames)
62 opNamesIdentifiers.push_back(mlirIdentifierGet(
63 mlirOperationGetContext(operation),
64 mlirStringRefCreateFromCString(opName.c_str())));
// Initialize the walk state in this order: callback (moved in),
// gotException = false, empty exception message and type, and the interned
// filter identifiers. The `UserData userData{` opening is presumably on a
// line not shown here.
67 std::move(callback),
false, {}, {}, opNamesIdentifiers};
// Captureless lambda, so it can be stored as the C function pointer type
// MlirOperationWalkCallback. All state arrives through `userData`.
68 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
70 UserData *calleeUserData =
static_cast<UserData *
>(userData);
71 auto opName = mlirOperationGetName(op);
// Check whether this op's name matches any identifier in the filter.
74 bool inFilter =
false;
75 for (
auto &opNamesIdentifier : calleeUserData->opNames) {
76 if (mlirIdentifierEqual(opName, opNamesIdentifier)) {
// Presumably the path for ops that are not in the filter: keep walking
// without calling into Python.
84 return MlirWalkResult::MlirWalkResultAdvance;
// Invoke the Python callback and map the integer `.value` of its result to
// an MlirWalkResult.
// NOTE(review): as with walkOrder above, confirm that an unmatched value
// cannot leave `walkResult` uninitialized.
92 MlirWalkResult walkResult;
93 auto walkResultRaw = (calleeUserData->callback)(op);
94 auto walkResultRawValue =
95 py::cast<int>(walkResultRaw.attr(
"value"));
96 switch (walkResultRawValue) {
98 walkResult = MlirWalkResult::MlirWalkResultAdvance;
101 walkResult = MlirWalkResult::MlirWalkResultInterrupt;
104 walkResult = MlirWalkResult::MlirWalkResultSkip;
108 }
// Record the Python exception instead of letting it propagate out of the
// C walk callback, then interrupt the walk. It is reported after
// mlirOperationWalk returns.
catch (py::error_already_set &e) {
109 calleeUserData->gotException =
true;
110 calleeUserData->exceptionWhat = e.what();
111 calleeUserData->exceptionType = e.type();
112 return MlirWalkResult::MlirWalkResultInterrupt;
// Run the filtered walk over `operation` in the requested order.
115 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
// Re-raise a callback failure as std::runtime_error, which pybind11
// translates to a Python RuntimeError.
// NOTE(review): only the message is forwarded. The recorded `exceptionType`
// is not used here, so the original Python exception type is lost.
116 if (userData.gotException) {
117 std::string message(
"Exception raised in callback: ");
118 message.append(userData.exceptionWhat);
119 throw std::runtime_error(message);
// Python keyword-argument names for the four parameters documented above.
122 py::arg(
"op"), py::arg(
"op_names"), py::arg(
"callback"),
123 py::arg(
"walk_order"));
// Presumably the header of the function whose body appears above. Its
// position here is an artifact of how this chunk was extracted.
void populateSupportSubmodule(pybind11::module &m)
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.