12#include "mlir-c/BuiltinAttributes.h"
13#include "mlir-c/BuiltinTypes.h"
15#include "mlir/Bindings/Python/NanobindAdaptors.h"
16#include "mlir/Bindings/Python/NanobindUtils.h"
17#include <nanobind/nanobind.h>
18#include <nanobind/stl/variant.h>
19#include <nanobind/stl/vector.h>
23using namespace mlir::python;
24using namespace mlir::python::nanobind_adaptors;
/// Construct an Unknown that stores the given MLIR type in its `type` member.
36 Unknown(MlirType type) : type(type) {}
/// Return the MLIR type stored in this object; does not modify it.
37 MlirType getType()
const {
return type; }
/// Variant over the nanobind wrapper types (int, float, str, bool, tuple,
/// list, dict). It is the return type of omPrimitiveToPythonValue and the
/// value parameter type of omPythonValueToPrimitive.
45using PythonPrimitive = std::variant<nb::int_, nb::float_, nb::str, nb::bool_,
46 nb::tuple, nb::list, nb::dict>;
55 std::variant<None, Object, List, BasePath, Path, Unknown, PythonPrimitive>;
61static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr);
62static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
73 PythonValue getElement(intptr_t i);
87 static BasePath getEmpty(MlirContext
context) {
111 std::string dunderStr() {
113 return std::string(ref.data, ref.length);
133 MlirLocation getFieldLoc(
const std::string &name) {
136 MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
137 MlirAttribute nameAttr = mlirStringAttrGet(
context, cName);
147 PythonValue getField(
const std::string &name) {
150 MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
151 MlirAttribute nameAttr = mlirStringAttrGet(
context, cName);
156 return omEvaluatorValueToPythonValue(result);
160 std::vector<std::string> getFieldNames() {
162 intptr_t numFieldNames = mlirArrayAttrGetNumElements(fieldNames);
164 std::vector<std::string> pyFieldNames;
165 for (intptr_t i = 0; i < numFieldNames; ++i) {
166 MlirAttribute fieldName = mlirArrayAttrGetElement(fieldNames, i);
167 MlirStringRef fieldNameStr = mlirStringAttrGetValue(fieldName);
168 pyFieldNames.emplace_back(fieldNameStr.data, fieldNameStr.length);
193 Object instantiate(MlirAttribute className,
194 std::vector<PythonValue> actualParams) {
195 std::vector<OMEvaluatorValue> values;
196 for (
auto &param : actualParams)
197 values.push_back(pythonValueToOMEvaluatorValue(
198 param, mlirModuleGetContext(getModule())));
202 evaluator, className, values.size(), values.data());
208 throw nb::value_error(
209 "unable to instantiate object, see previous error(s)");
223class PyListAttrIterator {
225 PyListAttrIterator(MlirAttribute attr) : attr(std::move(attr)) {}
/// Return a reference to this same iterator object; bound to Python's
/// `__iter__` in bind().
227 PyListAttrIterator &dunderIter() {
return *
this; }
229 MlirAttribute dunderNext() {
231 throw nb::stop_iteration();
235 static void bind(nb::module_ &m) {
236 nb::class_<PyListAttrIterator>(m,
"ListAttributeIterator")
237 .def(
"__iter__", &PyListAttrIterator::dunderIter)
238 .def(
"__next__", &PyListAttrIterator::dunderNext);
243 intptr_t nextIndex = 0;
246PythonValue List::getElement(intptr_t i) {
251static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
254 return nb::int_(nb::str(strRef.data, strRef.length));
257 if (mlirAttributeIsAFloat(attr)) {
258 return nb::float_(mlirFloatAttrGetValueDouble(attr));
261 if (mlirAttributeIsAString(attr)) {
262 auto strRef = mlirStringAttrGetValue(attr);
263 return nb::str(strRef.data, strRef.length);
267 if (mlirAttributeIsABool(attr)) {
268 return nb::bool_(mlirBoolAttrGetValue(attr));
271 if (mlirAttributeIsAInteger(attr)) {
272 MlirType type = mlirAttributeGetType(attr);
273 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
274 return nb::int_(mlirIntegerAttrGetValueInt(attr));
275 if (mlirIntegerTypeIsSigned(type))
276 return nb::int_(mlirIntegerAttrGetValueSInt(attr));
277 return nb::int_(mlirIntegerAttrGetValueUInt(attr));
285 auto moduleStr = nb::str(moduleStrRef.data, moduleStrRef.length);
286 auto nameStr = nb::str(nameStrRef.data, nameStrRef.length);
287 return nb::make_tuple(moduleStr, nameStr);
297 mlirAttributeDump(attr);
298 throw nb::type_error(
"Unexpected OM primitive attribute");
304static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
306 if (
auto *intValue = std::get_if<nb::int_>(&value)) {
307 auto intType = mlirIntegerTypeSignedGet(ctx, 64);
308 auto intAttr = mlirIntegerAttrGet(intType, nb::cast<int64_t>(*intValue));
312 if (
auto *attr = std::get_if<nb::float_>(&value)) {
313 auto floatType = mlirF64TypeGet(ctx);
314 return mlirFloatAttrDoubleGet(ctx, floatType, nb::cast<double>(*attr));
317 if (
auto *attr = std::get_if<nb::str>(&value)) {
318 auto str = nb::cast<std::string>(*attr);
319 auto strRef = mlirStringRefCreate(str.data(), str.length());
321 return mlirStringAttrTypedGet(omStringType, strRef);
324 if (
auto *attr = std::get_if<nb::bool_>(&value)) {
325 return mlirBoolAttrGet(ctx, nb::cast<bool>(*attr));
330 if (
auto *attr = std::get_if<nb::list>(&value)) {
331 if (attr->size() == 0)
332 throw nb::type_error(
"Empty list is prohibited now");
334 std::vector<MlirAttribute> attrs;
335 attrs.reserve(attr->size());
336 std::optional<MlirType> elemenType;
337 for (
auto v : *attr) {
339 omPythonValueToPrimitive(nb::cast<PythonPrimitive>(v), ctx));
341 elemenType = mlirAttributeGetType(attrs.back());
342 else if (!mlirTypeEqual(*elemenType,
343 mlirAttributeGetType(attrs.back()))) {
344 throw nb::type_error(
"List elements must be of the same type");
347 return omListAttrGet(*elemenType, attrs.size(), attrs.data());
350 throw nb::type_error(
"Unexpected OM primitive value");
358 throw nb::value_error(
"unable to get field, see previous error(s)");
370 return BasePath(result);
380 return omEvaluatorValueToPythonValue(
390 if (
auto *list = std::get_if<List>(&result))
391 return list->getValue();
393 if (
auto *basePath = std::get_if<BasePath>(&result))
394 return basePath->getValue();
396 if (
auto *path = std::get_if<Path>(&result))
397 return path->getValue();
399 if (
auto *
object = std::get_if<Object>(&result))
400 return object->getValue();
402 if (
auto *unknown = std::get_if<Unknown>(&result))
405 auto primitive = std::get<PythonPrimitive>(result);
407 omPythonValueToPrimitive(primitive, ctx));
414 m.doc() =
"OM dialect Python native extension";
417 nb::class_<Evaluator>(m,
"Evaluator")
418 .def(nb::init<MlirModule>(), nb::arg(
"module"))
419 .def(
"instantiate", &Evaluator::instantiate,
"Instantiate an Object",
420 nb::arg(
"class_name"), nb::arg(
"actual_params"))
421 .def_prop_ro(
"module", &Evaluator::getModule,
422 "The Module the Evaluator is built from");
425 nb::class_<List>(m,
"List")
426 .def(nb::init<List>(), nb::arg(
"list"))
427 .def(
"__getitem__", &List::getElement)
428 .def(
"__len__", &List::getNumElements);
431 nb::class_<BasePath>(m,
"BasePath")
432 .def(nb::init<BasePath>(), nb::arg(
"basepath"))
433 .def_static(
"get_empty", &BasePath::getEmpty,
434 nb::arg(
"context") = nb::none());
437 nb::class_<Path>(m,
"Path")
438 .def(nb::init<Path>(), nb::arg(
"path"))
439 .def(
"__str__", &Path::dunderStr);
442 nb::class_<Unknown>(m,
"Unknown")
443 .def(nb::init<MlirType>(), nb::arg(
"type"))
444 .def_prop_ro(
"type", &Unknown::getType)
445 .def(
"__repr__", [](
const Unknown &u) {
446 PyPrintAccumulator printAccum;
447 printAccum.parts.append(
"Unknown(");
448 mlirTypePrint(u.getType(), printAccum.getCallback(),
449 printAccum.getUserData());
450 printAccum.parts.append(
")");
451 return printAccum.join();
455 nb::class_<Object>(m,
"Object")
456 .def(nb::init<Object>(), nb::arg(
"object"))
457 .def(
"__getattr__", &Object::getField,
"Get a field from an Object",
459 .def(
"get_field_loc", &Object::getFieldLoc,
460 "Get the location of a field from an Object", nb::arg(
"name"))
461 .def_prop_ro(
"field_names", &Object::getFieldNames,
462 "Get field names from an Object")
463 .def_prop_ro(
"type", &Object::getType,
"The Type of the Object")
464 .def_prop_ro(
"loc", &Object::getLocation,
"The Location of the Object")
465 .def(
"__hash__", &Object::getHash,
"Get object hash")
466 .def(
"__eq__", &Object::eq,
"Check if two objects are same");
470 .def_property_readonly(
"inner_ref", [](MlirAttribute self) {
476 .def_classmethod(
"get",
477 [](nb::object cls, MlirAttribute intVal) {
480 .def_property_readonly(
483 .def(
"__str__", [](MlirAttribute self) {
485 return std::string(str.data, str.length);
493 [](MlirAttribute arr) {
return PyListAttrIterator(arr); });
494 PyListAttrIterator::bind(m);
501 .def_property_readonly(
"name", [](MlirType type) {
503 return std::string(name.data, name.length);
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
MLIR_CAPI_EXPORTED MlirAttribute hwInnerRefAttrGetModule(MlirAttribute)
MLIR_CAPI_EXPORTED MlirAttribute hwInnerRefAttrGetName(MlirAttribute)
MLIR_CAPI_EXPORTED bool omEvaluatorObjectIsEq(OMEvaluatorValue object, OMEvaluatorValue other)
Check equality of two objects.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsUnknown(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is Unknown.
MLIR_CAPI_EXPORTED intptr_t omListAttrGetNumElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED unsigned omEvaluatorObjectGetHash(OMEvaluatorValue object)
Get the object hash.
MLIR_CAPI_EXPORTED bool omTypeIsAFrozenBasePathType(MlirType type)
Is the Type a FrozenBasePathType.
MLIR_CAPI_EXPORTED MlirTypeID omListTypeGetTypeID(void)
Get the TypeID for a ListType.
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGet(MlirType elementType, intptr_t numElements, const MlirAttribute *elements)
MLIR_CAPI_EXPORTED bool omAttrIsAListAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omTypeIsAListType(MlirType type)
Is the Type a ListType.
MLIR_CAPI_EXPORTED MlirTypeID omFrozenPathTypeGetTypeID(void)
Get the TypeID for a FrozenPathType.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsABasePath(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a BasePath.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsNull(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED OMEvaluator omEvaluatorNew(MlirModule mod)
Construct an Evaluator with an IR module.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorValueGetReferenceValue(OMEvaluatorValue evaluatorValue)
Dereference a Reference EvaluatorValue.
MLIR_CAPI_EXPORTED MlirLocation omEvaluatorValueGetLoc(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorValueGetPrimitive(OMEvaluatorValue evaluatorValue)
Get the Primitive from an EvaluatorValue, which must contain a Primitive.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorBasePathGetEmpty(MlirContext context)
Create an empty BasePath.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAPath(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Path.
MLIR_CAPI_EXPORTED bool omTypeIsAAnyType(MlirType type)
Is the Type an AnyType.
MLIR_CAPI_EXPORTED MlirTypeID omFrozenBasePathTypeGetTypeID(void)
Get the TypeID for a FrozenBasePathType.
MLIR_CAPI_EXPORTED MlirModule omEvaluatorGetModule(OMEvaluator evaluator)
Get the Module the Evaluator is built from.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorInstantiate(OMEvaluator evaluator, MlirAttribute className, intptr_t nActualParams, OMEvaluatorValue *actualParams)
Use the Evaluator to Instantiate an Object from its class name and actual parameters.
MLIR_CAPI_EXPORTED MlirTypeID omClassTypeGetTypeID(void)
Get the TypeID for a ClassType.
MLIR_CAPI_EXPORTED intptr_t omEvaluatorListGetNumElements(OMEvaluatorValue evaluatorValue)
Get the length of the list.
MLIR_CAPI_EXPORTED MlirIdentifier omClassTypeGetName(MlirType type)
Get the name for a ClassType.
MLIR_CAPI_EXPORTED bool omAttrIsAReferenceAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAList(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a List.
MLIR_CAPI_EXPORTED MlirType omListTypeGetElementType(MlirType type)
MLIR_CAPI_EXPORTED bool omTypeIsAFrozenPathType(MlirType type)
Is the Type a FrozenPathType.
MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMEvaluatorValue object)
Get the Type from an Object, which will be a ClassType.
MLIR_CAPI_EXPORTED bool omAttrIsAIntegerAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorListGetElement(OMEvaluatorValue evaluatorValue, intptr_t pos)
Get an element of the list.
MLIR_CAPI_EXPORTED bool omTypeIsAClassType(MlirType type)
Is the Type a ClassType.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorUnknownGet(MlirContext context, MlirType type)
Create an Unknown EvaluatorValue with the given type.
MLIR_CAPI_EXPORTED bool omEvaluatorObjectIsNull(OMEvaluatorValue object)
Query if the Object is null.
MLIR_CAPI_EXPORTED MlirStringRef omIntegerAttrToString(MlirAttribute attr)
Get a string representation of an om::IntegerAttr.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorValueFromPrimitive(MlirAttribute primitive)
Get the EvaluatorValue from a Primitive value.
MLIR_CAPI_EXPORTED MlirAttribute omIntegerAttrGet(MlirAttribute attr)
Get an om::IntegerAttr from mlir::IntegerAttr.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAObject(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is an Object.
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAPrimitive(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Primitive.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorObjectGetField(OMEvaluatorValue object, MlirAttribute name)
Get a field from an Object, which must contain a field of that name.
MLIR_CAPI_EXPORTED MlirType omEvaluatorValueGetType(OMEvaluatorValue evaluatorValue)
Get the type of an EvaluatorValue.
MLIR_CAPI_EXPORTED MlirContext omEvaluatorValueGetContext(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorObjectGetFieldNames(OMEvaluatorValue object)
Get all the field names from an Object, can be empty if object has no fields.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAReference(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Reference.
MLIR_CAPI_EXPORTED MlirTypeID omAnyTypeGetTypeID(void)
Get the TypeID for an AnyType.
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorPathGetAsString(OMEvaluatorValue evaluatorValue)
Get a string representation of a Path.
MLIR_CAPI_EXPORTED MlirType omStringTypeGet(MlirContext ctx)
Get a StringType.
MLIR_CAPI_EXPORTED MlirAttribute omReferenceAttrGetInnerRef(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute omIntegerAttrGetInt(MlirAttribute attr)
Given an om::IntegerAttr, return the mlir::IntegerAttr.
evaluator::ObjectValue Object
void populateDialectOMSubmodule(nanobind::module_ &m)
A value type for use in C APIs that just wraps a pointer to an Object.
A value type for use in C APIs that just wraps a pointer to an Evaluator.