12#include "mlir-c/BuiltinAttributes.h"
13#include "mlir-c/BuiltinTypes.h"
15#include "mlir/Bindings/Python/NanobindAdaptors.h"
16#include <nanobind/nanobind.h>
17#include <nanobind/stl/variant.h>
18#include <nanobind/stl/vector.h>
22using namespace mlir::python;
23using namespace mlir::python::nanobind_adaptors;
34using PythonPrimitive = std::variant<nb::int_, nb::float_, nb::str, nb::bool_,
35 nb::tuple, nb::list, nb::dict>;
44 std::variant<None, Object, List, BasePath, Path, PythonPrimitive>;
50static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr);
51static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
62 PythonValue getElement(intptr_t i);
76 static BasePath getEmpty(MlirContext context) {
100 std::string dunderStr() {
102 return std::string(ref.data, ref.length);
122 MlirLocation getFieldLoc(
const std::string &name) {
125 MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
126 MlirAttribute nameAttr = mlirStringAttrGet(context, cName);
136 PythonValue getField(
const std::string &name) {
139 MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
140 MlirAttribute nameAttr = mlirStringAttrGet(context, cName);
145 return omEvaluatorValueToPythonValue(result);
149 std::vector<std::string> getFieldNames() {
151 intptr_t numFieldNames = mlirArrayAttrGetNumElements(fieldNames);
153 std::vector<std::string> pyFieldNames;
154 for (intptr_t i = 0; i < numFieldNames; ++i) {
155 MlirAttribute fieldName = mlirArrayAttrGetElement(fieldNames, i);
156 MlirStringRef fieldNameStr = mlirStringAttrGetValue(fieldName);
157 pyFieldNames.emplace_back(fieldNameStr.data, fieldNameStr.length);
182 Object instantiate(MlirAttribute className,
183 std::vector<PythonValue> actualParams) {
184 std::vector<OMEvaluatorValue> values;
185 for (
auto ¶m : actualParams)
186 values.push_back(pythonValueToOMEvaluatorValue(
187 param, mlirModuleGetContext(getModule())));
191 evaluator, className, values.size(), values.data());
197 throw nb::value_error(
198 "unable to instantiate object, see previous error(s)");
212class PyListAttrIterator {
214 PyListAttrIterator(MlirAttribute attr) : attr(std::move(attr)) {}
216 PyListAttrIterator &dunderIter() {
return *
this; }
218 MlirAttribute dunderNext() {
220 throw nb::stop_iteration();
224 static void bind(nb::module_ &m) {
225 nb::class_<PyListAttrIterator>(m,
"ListAttributeIterator")
226 .def(
"__iter__", &PyListAttrIterator::dunderIter)
227 .def(
"__next__", &PyListAttrIterator::dunderNext);
232 intptr_t nextIndex = 0;
235PythonValue List::getElement(intptr_t i) {
240static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
243 return nb::int_(nb::str(strRef.data, strRef.length));
246 if (mlirAttributeIsAFloat(attr)) {
247 return nb::float_(mlirFloatAttrGetValueDouble(attr));
250 if (mlirAttributeIsAString(attr)) {
251 auto strRef = mlirStringAttrGetValue(attr);
252 return nb::str(strRef.data, strRef.length);
256 if (mlirAttributeIsABool(attr)) {
257 return nb::bool_(mlirBoolAttrGetValue(attr));
260 if (mlirAttributeIsAInteger(attr)) {
261 MlirType type = mlirAttributeGetType(attr);
262 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
263 return nb::int_(mlirIntegerAttrGetValueInt(attr));
264 if (mlirIntegerTypeIsSigned(type))
265 return nb::int_(mlirIntegerAttrGetValueSInt(attr));
266 return nb::int_(mlirIntegerAttrGetValueUInt(attr));
274 auto moduleStr = nb::str(moduleStrRef.data, moduleStrRef.length);
275 auto nameStr = nb::str(nameStrRef.data, nameStrRef.length);
276 return nb::make_tuple(moduleStr, nameStr);
286 mlirAttributeDump(attr);
287 throw nb::type_error(
"Unexpected OM primitive attribute");
293static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
295 if (
auto *intValue = std::get_if<nb::int_>(&value)) {
296 auto intType = mlirIntegerTypeGet(ctx, 64);
297 auto intAttr = mlirIntegerAttrGet(intType, nb::cast<int64_t>(*intValue));
301 if (
auto *attr = std::get_if<nb::float_>(&value)) {
302 auto floatType = mlirF64TypeGet(ctx);
303 return mlirFloatAttrDoubleGet(ctx, floatType, nb::cast<double>(*attr));
306 if (
auto *attr = std::get_if<nb::str>(&value)) {
307 auto str = nb::cast<std::string>(*attr);
308 auto strRef = mlirStringRefCreate(str.data(), str.length());
310 return mlirStringAttrTypedGet(omStringType, strRef);
313 if (
auto *attr = std::get_if<nb::bool_>(&value)) {
314 return mlirBoolAttrGet(ctx, nb::cast<bool>(*attr));
319 if (
auto *attr = std::get_if<nb::list>(&value)) {
320 if (attr->size() == 0)
321 throw nb::type_error(
"Empty list is prohibited now");
323 std::vector<MlirAttribute> attrs;
324 attrs.reserve(attr->size());
325 std::optional<MlirType> elemenType;
326 for (
auto v : *attr) {
328 omPythonValueToPrimitive(nb::cast<PythonPrimitive>(v), ctx));
330 elemenType = mlirAttributeGetType(attrs.back());
331 else if (!mlirTypeEqual(*elemenType,
332 mlirAttributeGetType(attrs.back()))) {
333 throw nb::type_error(
"List elements must be of the same type");
336 return omListAttrGet(*elemenType, attrs.size(), attrs.data());
339 throw nb::type_error(
"Unexpected OM primitive value");
347 throw nb::value_error(
"unable to get field, see previous error(s)");
359 return BasePath(result);
366 return omEvaluatorValueToPythonValue(
376 if (
auto *list = std::get_if<List>(&result))
377 return list->getValue();
379 if (
auto *basePath = std::get_if<BasePath>(&result))
380 return basePath->getValue();
382 if (
auto *path = std::get_if<Path>(&result))
383 return path->getValue();
385 if (
auto *
object = std::get_if<Object>(&result))
386 return object->getValue();
388 auto primitive = std::get<PythonPrimitive>(result);
390 omPythonValueToPrimitive(primitive, ctx));
397 m.doc() =
"OM dialect Python native extension";
400 nb::class_<Evaluator>(m,
"Evaluator")
401 .def(nb::init<MlirModule>(), nb::arg(
"module"))
402 .def(
"instantiate", &Evaluator::instantiate,
"Instantiate an Object",
403 nb::arg(
"class_name"), nb::arg(
"actual_params"))
404 .def_prop_ro(
"module", &Evaluator::getModule,
405 "The Module the Evaluator is built from");
408 nb::class_<List>(m,
"List")
409 .def(nb::init<List>(), nb::arg(
"list"))
410 .def(
"__getitem__", &List::getElement)
411 .def(
"__len__", &List::getNumElements);
414 nb::class_<BasePath>(m,
"BasePath")
415 .def(nb::init<BasePath>(), nb::arg(
"basepath"))
416 .def_static(
"get_empty", &BasePath::getEmpty,
417 nb::arg(
"context") = nb::none());
420 nb::class_<Path>(m,
"Path")
421 .def(nb::init<Path>(), nb::arg(
"path"))
422 .def(
"__str__", &Path::dunderStr);
425 nb::class_<Object>(m,
"Object")
426 .def(nb::init<Object>(), nb::arg(
"object"))
427 .def(
"__getattr__", &Object::getField,
"Get a field from an Object",
429 .def(
"get_field_loc", &Object::getFieldLoc,
430 "Get the location of a field from an Object", nb::arg(
"name"))
431 .def_prop_ro(
"field_names", &Object::getFieldNames,
432 "Get field names from an Object")
433 .def_prop_ro(
"type", &Object::getType,
"The Type of the Object")
434 .def_prop_ro(
"loc", &Object::getLocation,
"The Location of the Object")
435 .def(
"__hash__", &Object::getHash,
"Get object hash")
436 .def(
"__eq__", &Object::eq,
"Check if two objects are same");
440 .def_property_readonly(
"inner_ref", [](MlirAttribute self) {
446 .def_classmethod(
"get",
447 [](nb::object cls, MlirAttribute intVal) {
450 .def_property_readonly(
453 .def(
"__str__", [](MlirAttribute self) {
455 return std::string(str.data, str.length);
463 [](MlirAttribute arr) {
return PyListAttrIterator(arr); });
464 PyListAttrIterator::bind(m);
471 .def_property_readonly(
"name", [](MlirType type) {
473 return std::string(name.data, name.length);
assert(baseType &&"element must be base type")
MLIR_CAPI_EXPORTED MlirAttribute hwInnerRefAttrGetModule(MlirAttribute)
MLIR_CAPI_EXPORTED MlirAttribute hwInnerRefAttrGetName(MlirAttribute)
MLIR_CAPI_EXPORTED bool omEvaluatorObjectIsEq(OMEvaluatorValue object, OMEvaluatorValue other)
Check equality of two objects.
MLIR_CAPI_EXPORTED intptr_t omListAttrGetNumElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED unsigned omEvaluatorObjectGetHash(OMEvaluatorValue object)
Get the object hash.
MLIR_CAPI_EXPORTED bool omTypeIsAFrozenBasePathType(MlirType type)
Is the Type a FrozenBasePathType.
MLIR_CAPI_EXPORTED MlirTypeID omListTypeGetTypeID(void)
Get the TypeID for a ListType.
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGet(MlirType elementType, intptr_t numElements, const MlirAttribute *elements)
MLIR_CAPI_EXPORTED bool omAttrIsAListAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omTypeIsAListType(MlirType type)
Is the Type a ListType.
MLIR_CAPI_EXPORTED MlirTypeID omFrozenPathTypeGetTypeID(void)
Get the TypeID for a FrozenPathType.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsABasePath(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a BasePath.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsNull(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED OMEvaluator omEvaluatorNew(MlirModule mod)
Construct an Evaluator with an IR module.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorValueGetReferenceValue(OMEvaluatorValue evaluatorValue)
Dereference a Reference EvaluatorValue.
MLIR_CAPI_EXPORTED MlirLocation omEvaluatorValueGetLoc(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorValueGetPrimitive(OMEvaluatorValue evaluatorValue)
Get the Primitive from an EvaluatorValue, which must contain a Primitive.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorBasePathGetEmpty(MlirContext context)
Create an empty BasePath.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAPath(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Path.
MLIR_CAPI_EXPORTED bool omTypeIsAAnyType(MlirType type)
Is the Type an AnyType.
MLIR_CAPI_EXPORTED MlirTypeID omFrozenBasePathTypeGetTypeID(void)
Get the TypeID for a FrozenBasePathType.
MLIR_CAPI_EXPORTED MlirModule omEvaluatorGetModule(OMEvaluator evaluator)
Get the Module the Evaluator is built from.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorInstantiate(OMEvaluator evaluator, MlirAttribute className, intptr_t nActualParams, OMEvaluatorValue *actualParams)
Use the Evaluator to instantiate an Object from its class name and actual parameters.
MLIR_CAPI_EXPORTED MlirTypeID omClassTypeGetTypeID(void)
Get the TypeID for a ClassType.
MLIR_CAPI_EXPORTED intptr_t omEvaluatorListGetNumElements(OMEvaluatorValue evaluatorValue)
Get the length of the list.
MLIR_CAPI_EXPORTED MlirIdentifier omClassTypeGetName(MlirType type)
Get the name for a ClassType.
MLIR_CAPI_EXPORTED bool omAttrIsAReferenceAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAList(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a List.
MLIR_CAPI_EXPORTED MlirType omListTypeGetElementType(MlirType type)
MLIR_CAPI_EXPORTED bool omTypeIsAFrozenPathType(MlirType type)
Is the Type a FrozenPathType.
MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMEvaluatorValue object)
Get the Type from an Object, which will be a ClassType.
MLIR_CAPI_EXPORTED bool omAttrIsAIntegerAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorListGetElement(OMEvaluatorValue evaluatorValue, intptr_t pos)
Get an element of the list.
MLIR_CAPI_EXPORTED bool omTypeIsAClassType(MlirType type)
Is the Type a ClassType.
MLIR_CAPI_EXPORTED bool omEvaluatorObjectIsNull(OMEvaluatorValue object)
Query if the Object is null.
MLIR_CAPI_EXPORTED MlirStringRef omIntegerAttrToString(MlirAttribute attr)
Get a string representation of an om::IntegerAttr.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorValueFromPrimitive(MlirAttribute primitive)
Get the EvaluatorValue from a Primitive value.
MLIR_CAPI_EXPORTED MlirAttribute omIntegerAttrGet(MlirAttribute attr)
Get an om::IntegerAttr from mlir::IntegerAttr.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAObject(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is an Object.
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAPrimitive(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Primitive.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorObjectGetField(OMEvaluatorValue object, MlirAttribute name)
Get a field from an Object, which must contain a field of that name.
MLIR_CAPI_EXPORTED MlirContext omEvaluatorValueGetContext(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorObjectGetFieldNames(OMEvaluatorValue object)
Get all the field names from an Object; the result can be empty if the Object has no fields.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAReference(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Reference.
MLIR_CAPI_EXPORTED MlirTypeID omAnyTypeGetTypeID(void)
Get the TypeID for an AnyType.
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorPathGetAsString(OMEvaluatorValue evaluatorValue)
Get a string representation of a Path.
MLIR_CAPI_EXPORTED MlirType omStringTypeGet(MlirContext ctx)
Get a StringType.
MLIR_CAPI_EXPORTED MlirAttribute omReferenceAttrGetInnerRef(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute omIntegerAttrGetInt(MlirAttribute attr)
Given an om::IntegerAttr, return the mlir::IntegerAttr.
@ None
Don't preserve aggregate at all.
evaluator::ObjectValue Object
void populateDialectOMSubmodule(nanobind::module_ &m)
A value type for use in C APIs that just wraps a pointer to an Object.
A value type for use in C APIs that just wraps a pointer to an Evaluator.