CIRCT 20.0.0git
Loading...
Searching...
No Matches
SupportModule.cpp
Go to the documentation of this file.
1//===- SupportModule.cpp - Support API nanobind module --------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
9#include "CIRCTModules.h"
10
11#include "mlir/Bindings/Python/NanobindAdaptors.h"
12
13#include "NanobindUtils.h"
14#include "mlir-c/Support.h"
15#include <nanobind/nanobind.h>
16
17namespace nb = nanobind;
18
19using namespace circt;
20using namespace mlir::python::nanobind_adaptors;
21
22/// Populate the support python module.
23void circt::python::populateSupportSubmodule(nb::module_ &m) {
24 m.doc() = "CIRCT Python utils";
25 // Walk with filter.
26 m.def(
27 "_walk_with_filter",
28 [](MlirOperation operation, const std::vector<std::string> &opNames,
29 std::function<nb::object(MlirOperation)> callback,
30 nb::object walkOrderRaw) {
31 struct UserData {
32 std::function<nb::object(MlirOperation)> callback;
33 bool gotException;
34 std::string exceptionWhat;
35 nb::handle exceptionType;
36 std::vector<MlirIdentifier> opNames;
37 };
38
39 // As we transition from nanobind to nanobind, the WalkOrder enum and
40 // automatic casting will be defined as a nanobind enum upstream. Do a
41 // manual conversion that works with either nanobind or nanobind for
42 // now. When we're on nanobind in CIRCT, we can go back to automatic
43 // casting.
44 MlirWalkOrder walkOrder;
45 auto walkOrderRawValue = nb::cast<int>(walkOrderRaw.attr("value"));
46 switch (walkOrderRawValue) {
47 case 0:
48 walkOrder = MlirWalkOrder::MlirWalkPreOrder;
49 break;
50 case 1:
51 walkOrder = MlirWalkOrder::MlirWalkPostOrder;
52 break;
53 }
54
55 std::vector<MlirIdentifier> opNamesIdentifiers;
56 opNamesIdentifiers.reserve(opNames.size());
57
58 // Construct MlirIdentifier from string to perform pointer comparison.
59 for (auto &opName : opNames)
60 opNamesIdentifiers.push_back(mlirIdentifierGet(
61 mlirOperationGetContext(operation),
62 mlirStringRefCreateFromCString(opName.c_str())));
63
64 UserData userData{
65 std::move(callback), false, {}, {}, opNamesIdentifiers};
66 MlirOperationWalkCallback walkCallback = [](MlirOperation op,
67 void *userData) {
68 UserData *calleeUserData = static_cast<UserData *>(userData);
69 auto opName = mlirOperationGetName(op);
70
71 // Check if the operation name is in the filter.
72 bool inFilter = false;
73 for (auto &opNamesIdentifier : calleeUserData->opNames) {
74 if (mlirIdentifierEqual(opName, opNamesIdentifier)) {
75 inFilter = true;
76 break;
77 }
78 }
79
80 // If the operation name is not in the filter, skip it.
81 if (!inFilter)
82 return MlirWalkResult::MlirWalkResultAdvance;
83
84 try {
85 // As we transition from nanobind to nanobind, the WalkResult enum
86 // and automatic casting will be defined as a nanobind enum
87 // upstream. Do a manual conversion that works with either nanobind
88 // or nanobind for now. When we're on nanobind in CIRCT, we can go
89 // back to automatic casting.
90 MlirWalkResult walkResult;
91 auto walkResultRaw = (calleeUserData->callback)(op);
92 auto walkResultRawValue =
93 nb::cast<int>(walkResultRaw.attr("value"));
94 switch (walkResultRawValue) {
95 case 0:
96 walkResult = MlirWalkResult::MlirWalkResultAdvance;
97 break;
98 case 1:
99 walkResult = MlirWalkResult::MlirWalkResultInterrupt;
100 break;
101 case 2:
102 walkResult = MlirWalkResult::MlirWalkResultSkip;
103 break;
104 }
105 return walkResult;
106 } catch (nb::python_error &e) {
107 calleeUserData->gotException = true;
108 calleeUserData->exceptionWhat = e.what();
109 calleeUserData->exceptionType = e.type();
110 return MlirWalkResult::MlirWalkResultInterrupt;
111 }
112 };
113 mlirOperationWalk(operation, walkCallback, &userData, walkOrder);
114 if (userData.gotException) {
115 std::string message("Exception raised in callback: ");
116 message.append(userData.exceptionWhat);
117 throw std::runtime_error(message);
118 }
119 },
120 nb::arg("op"), nb::arg("op_names"), nb::arg("callback"),
121 nb::arg("walk_order"));
122}
void populateSupportSubmodule(nanobind::module_ &m)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.