12 #include "mlir-c/BuiltinAttributes.h"
13 #include "mlir-c/BuiltinTypes.h"
14 #include "mlir-c/IR.h"
15 #include "mlir/Bindings/Python/PybindAdaptors.h"
16 #include <pybind11/pybind11.h>
17 #include <pybind11/stl.h>
21 using namespace mlir::python;
22 using namespace mlir::python::adaptors;
35 using PythonPrimitive = std::variant<py::int_, py::float_, py::str, py::bool_,
36 py::tuple, py::list, py::dict>;
44 using PythonValue = std::variant<
None,
Object, List, Tuple, Map, BasePath, Path,
51 static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr);
52 static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
63 PythonValue getElement(intptr_t i);
78 PythonValue getElement(intptr_t i);
92 std::vector<py::str> getKeys() {
94 intptr_t numFieldNames = mlirArrayAttrGetNumElements(attr);
96 std::vector<py::str> pyFieldNames;
97 for (intptr_t i = 0; i < numFieldNames; ++i) {
98 auto name = mlirStringAttrGetValue(mlirArrayAttrGetElement(attr, i));
99 pyFieldNames.emplace_back(py::str(name.data, name.length));
106 PythonValue dunderGetItemAttr(MlirAttribute key);
107 PythonValue dunderGetItemNamed(
const std::string &key);
108 PythonValue dunderGetItemIndexed(intptr_t key);
110 dunderGetItem(std::variant<intptr_t, std::string, MlirAttribute> key);
129 static BasePath getEmpty(MlirContext context) {
153 std::string dunderStr() {
155 return std::string(ref.data, ref.length);
175 MlirLocation getFieldLoc(
const std::string &name) {
178 MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
179 MlirAttribute nameAttr = mlirStringAttrGet(context, cName);
189 PythonValue getField(
const std::string &name) {
192 MlirStringRef cName = mlirStringRefCreateFromCString(name.c_str());
193 MlirAttribute nameAttr = mlirStringAttrGet(context, cName);
198 return omEvaluatorValueToPythonValue(result);
202 std::vector<std::string> getFieldNames() {
204 intptr_t numFieldNames = mlirArrayAttrGetNumElements(fieldNames);
206 std::vector<std::string> pyFieldNames;
207 for (intptr_t i = 0; i < numFieldNames; ++i) {
208 MlirAttribute fieldName = mlirArrayAttrGetElement(fieldNames, i);
209 MlirStringRef fieldNameStr = mlirStringAttrGetValue(fieldName);
210 pyFieldNames.emplace_back(fieldNameStr.data, fieldNameStr.length);
235 Object instantiate(MlirAttribute className,
236 std::vector<PythonValue> actualParams) {
237 std::vector<OMEvaluatorValue> values;
238 for (
auto ¶m : actualParams)
239 values.push_back(pythonValueToOMEvaluatorValue(
240 param, mlirModuleGetContext(getModule())));
244 evaluator, className, values.size(), values.data());
250 throw py::value_error(
251 "unable to instantiate object, see previous error(s)");
265 class PyListAttrIterator {
267 PyListAttrIterator(MlirAttribute attr) : attr(std::move(attr)) {}
269 PyListAttrIterator &dunderIter() {
return *
this; }
271 MlirAttribute dunderNext() {
273 throw py::stop_iteration();
277 static void bind(py::module &m) {
278 py::class_<PyListAttrIterator>(m,
"ListAttributeIterator",
280 .def(
"__iter__", &PyListAttrIterator::dunderIter)
281 .def(
"__next__", &PyListAttrIterator::dunderNext);
286 intptr_t nextIndex = 0;
289 PythonValue List::getElement(intptr_t i) {
293 class PyMapAttrIterator {
295 PyMapAttrIterator(MlirAttribute attr) : attr(std::move(attr)) {}
297 PyMapAttrIterator &dunderIter() {
return *
this; }
299 py::tuple dunderNext() {
301 throw py::stop_iteration();
308 auto keyName = mlirIdentifierStr(key);
309 std::string keyStr(keyName.data, keyName.length);
310 return py::make_tuple(keyStr, value);
313 static void bind(py::module &m) {
314 py::class_<PyMapAttrIterator>(m,
"MapAttributeIterator", py::module_local())
315 .def(
"__iter__", &PyMapAttrIterator::dunderIter)
316 .def(
"__next__", &PyMapAttrIterator::dunderNext);
321 intptr_t nextIndex = 0;
324 PythonValue Tuple::getElement(intptr_t i) {
326 throw std::out_of_range(
"tuple index out of range");
331 PythonValue Map::dunderGetItemNamed(
const std::string &key) {
334 throw pybind11::key_error(
"key is not string");
336 mlirStringAttrTypedGet(type, mlirStringRefCreateFromCString(key.c_str()));
337 return dunderGetItemAttr(attr);
340 PythonValue Map::dunderGetItemIndexed(intptr_t i) {
342 if (!mlirTypeIsAInteger(type))
343 throw pybind11::key_error(
"key is not integer");
344 MlirAttribute attr = mlirIntegerAttrGet(type, i);
345 return dunderGetItemAttr(attr);
348 PythonValue Map::dunderGetItemAttr(MlirAttribute key) {
352 throw pybind11::key_error(
"key not found");
354 return omEvaluatorValueToPythonValue(result);
358 Map::dunderGetItem(std::variant<intptr_t, std::string, MlirAttribute> key) {
359 if (
auto *i = std::get_if<intptr_t>(&key))
360 return dunderGetItemIndexed(*i);
361 else if (
auto *str = std::get_if<std::string>(&key))
362 return dunderGetItemNamed(*str);
363 return dunderGetItemAttr(std::get<MlirAttribute>(key));
368 static PythonPrimitive omPrimitiveToPythonValue(MlirAttribute attr) {
371 return py::int_(py::str(strRef.data, strRef.length));
374 if (mlirAttributeIsAFloat(attr)) {
375 return py::float_(mlirFloatAttrGetValueDouble(attr));
378 if (mlirAttributeIsAString(attr)) {
379 auto strRef = mlirStringAttrGetValue(attr);
380 return py::str(strRef.data, strRef.length);
384 if (mlirAttributeIsABool(attr)) {
385 return py::bool_(mlirBoolAttrGetValue(attr));
388 if (mlirAttributeIsAInteger(attr)) {
389 MlirType type = mlirAttributeGetType(attr);
390 if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
391 return py::int_(mlirIntegerAttrGetValueInt(attr));
392 if (mlirIntegerTypeIsSigned(type))
393 return py::int_(mlirIntegerAttrGetValueSInt(attr));
394 return py::int_(mlirIntegerAttrGetValueUInt(attr));
402 auto moduleStr = py::str(moduleStrRef.data, moduleStrRef.length);
403 auto nameStr = py::str(nameStrRef.data, nameStrRef.length);
404 return py::make_tuple(moduleStr, nameStr);
418 auto key = py::str(keyStrRef.data, keyStrRef.length);
420 results[key] = value;
425 mlirAttributeDump(attr);
426 throw py::type_error(
"Unexpected OM primitive attribute");
432 static MlirAttribute omPythonValueToPrimitive(PythonPrimitive value,
434 if (
auto *intValue = std::get_if<py::int_>(&value)) {
435 auto intType = mlirIntegerTypeGet(ctx, 64);
436 auto intAttr = mlirIntegerAttrGet(intType, intValue->cast<int64_t>());
440 if (
auto *attr = std::get_if<py::float_>(&value)) {
441 auto floatType = mlirF64TypeGet(ctx);
442 return mlirFloatAttrDoubleGet(ctx, floatType, attr->cast<
double>());
445 if (
auto *attr = std::get_if<py::str>(&value)) {
446 auto str = attr->cast<std::string>();
447 auto strRef = mlirStringRefCreate(str.data(), str.length());
448 return mlirStringAttrGet(ctx, strRef);
451 if (
auto *attr = std::get_if<py::bool_>(&value)) {
452 return mlirBoolAttrGet(ctx, attr->cast<
bool>());
455 throw py::type_error(
"Unexpected OM primitive value");
463 throw py::value_error(
"unable to get field, see previous error(s)");
475 return Tuple(result);
483 return BasePath(result);
490 return omEvaluatorValueToPythonValue(
500 if (
auto *list = std::get_if<List>(&result))
501 return list->getValue();
503 if (
auto *tuple = std::get_if<Tuple>(&result))
504 return tuple->getValue();
506 if (
auto *map = std::get_if<Map>(&result))
507 return map->getValue();
509 if (
auto *basePath = std::get_if<BasePath>(&result))
510 return basePath->getValue();
512 if (
auto *path = std::get_if<Path>(&result))
513 return path->getValue();
515 if (
auto *
object = std::get_if<Object>(&result))
516 return object->getValue();
518 auto primitive = std::get<PythonPrimitive>(result);
520 omPythonValueToPrimitive(primitive, ctx));
527 m.doc() =
"OM dialect Python native extension";
530 py::class_<Evaluator>(m,
"Evaluator")
531 .def(py::init<MlirModule>(), py::arg(
"module"))
532 .def(
"instantiate", &Evaluator::instantiate,
"Instantiate an Object",
533 py::arg(
"class_name"), py::arg(
"actual_params"))
534 .def_property_readonly(
"module", &Evaluator::getModule,
535 "The Module the Evaluator is built from");
538 py::class_<List>(m,
"List")
539 .def(py::init<List>(), py::arg(
"list"))
540 .def(
"__getitem__", &List::getElement)
541 .def(
"__len__", &List::getNumElements);
543 py::class_<Tuple>(m,
"Tuple")
544 .def(py::init<Tuple>(), py::arg(
"tuple"))
545 .def(
"__getitem__", &Tuple::getElement)
546 .def(
"__len__", &Tuple::getNumElements);
549 py::class_<Map>(m,
"Map")
550 .def(py::init<Map>(), py::arg(
"map"))
551 .def(
"__getitem__", &Map::dunderGetItem)
552 .def(
"keys", &Map::getKeys)
553 .def_property_readonly(
"type", &Map::getType,
"The Type of the Map");
556 py::class_<BasePath>(m,
"BasePath")
557 .def(py::init<BasePath>(), py::arg(
"basepath"))
558 .def_static(
"get_empty", &BasePath::getEmpty,
559 py::arg(
"context") = py::none());
562 py::class_<Path>(m,
"Path")
563 .def(py::init<Path>(), py::arg(
"path"))
564 .def(
"__str__", &Path::dunderStr);
567 py::class_<Object>(m,
"Object")
568 .def(py::init<Object>(), py::arg(
"object"))
569 .def(
"__getattr__", &Object::getField,
"Get a field from an Object",
571 .def(
"get_field_loc", &Object::getFieldLoc,
572 "Get the location of a field from an Object", py::arg(
"name"))
573 .def_property_readonly(
"field_names", &Object::getFieldNames,
574 "Get field names from an Object")
575 .def_property_readonly(
"type", &Object::getType,
"The Type of the Object")
576 .def_property_readonly(
"loc", &Object::getLocation,
577 "The Location of the Object")
578 .def(
"__hash__", &Object::getHash,
"Get object hash")
579 .def(
"__eq__", &Object::eq,
"Check if two objects are same");
583 .def_property_readonly(
"inner_ref", [](MlirAttribute
self) {
589 .def_classmethod(
"get",
590 [](py::object cls, MlirAttribute intVal) {
593 .def_property_readonly(
596 .def(
"__str__", [](MlirAttribute
self) {
598 return std::string(str.data, str.length);
606 [](MlirAttribute arr) {
return PyListAttrIterator(arr); });
607 PyListAttrIterator::bind(m);
611 .def(
"__iter__", [](MlirAttribute arr) {
return PyMapAttrIterator(arr); })
613 PyMapAttrIterator::bind(m);
620 .def_property_readonly(
"name", [](MlirType type) {
622 return std::string(name.data, name.length);
assert(baseType &&"element must be base type")
MLIR_CAPI_EXPORTED MlirAttribute hwInnerRefAttrGetModule(MlirAttribute)
MLIR_CAPI_EXPORTED MlirAttribute hwInnerRefAttrGetName(MlirAttribute)
MLIR_CAPI_EXPORTED bool omEvaluatorObjectIsEq(OMEvaluatorValue object, OMEvaluatorValue other)
Check equality of two objects.
MLIR_CAPI_EXPORTED intptr_t omListAttrGetNumElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED unsigned omEvaluatorObjectGetHash(OMEvaluatorValue object)
Get the object hash.
MLIR_CAPI_EXPORTED bool omTypeIsAFrozenBasePathType(MlirType type)
Is the Type a FrozenBasePathType.
MLIR_CAPI_EXPORTED MlirTypeID omListTypeGetTypeID(void)
Get the TypeID for a ListType.
MLIR_CAPI_EXPORTED bool omAttrIsAListAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omTypeIsAListType(MlirType type)
Is the Type a ListType.
MLIR_CAPI_EXPORTED MlirTypeID omFrozenPathTypeGetTypeID(void)
Get the TypeID for a FrozenPathType.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsABasePath(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a BasePath.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsNull(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omMapAttrGetElementValue(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED intptr_t omEvaluatorTupleGetNumElements(OMEvaluatorValue evaluatorValue)
Get the size of the tuple.
MLIR_CAPI_EXPORTED OMEvaluator omEvaluatorNew(MlirModule mod)
Construct an Evaluator with an IR module.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorValueGetReferenceValue(OMEvaluatorValue evaluatorValue)
Dereference a Reference EvaluatorValue.
MLIR_CAPI_EXPORTED MlirLocation omEvaluatorValueGetLoc(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorValueGetPrimitive(OMEvaluatorValue evaluatorValue)
Get the Primitive from an EvaluatorValue, which must contain a Primitive.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorBasePathGetEmpty(MlirContext context)
Create an empty BasePath.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAPath(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Path.
MLIR_CAPI_EXPORTED bool omTypeIsAAnyType(MlirType type)
Is the Type an AnyType.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAMap(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Map.
MLIR_CAPI_EXPORTED MlirTypeID omFrozenBasePathTypeGetTypeID(void)
Get the TypeID for a FrozenBasePathType.
MLIR_CAPI_EXPORTED MlirModule omEvaluatorGetModule(OMEvaluator evaluator)
Get the Module the Evaluator is built from.
MLIR_CAPI_EXPORTED MlirType omMapTypeGetKeyType(MlirType type)
Return a key type of a map.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorInstantiate(OMEvaluator evaluator, MlirAttribute className, intptr_t nActualParams, OMEvaluatorValue *actualParams)
Use the Evaluator to instantiate an Object from its class name and actual parameters.
MLIR_CAPI_EXPORTED MlirTypeID omClassTypeGetTypeID(void)
Get the TypeID for a ClassType.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorMapGetElement(OMEvaluatorValue evaluatorValue, MlirAttribute attr)
Get an element of the map.
MLIR_CAPI_EXPORTED intptr_t omEvaluatorListGetNumElements(OMEvaluatorValue evaluatorValue)
Get the length of the list.
MLIR_CAPI_EXPORTED MlirIdentifier omClassTypeGetName(MlirType type)
Get the name for a ClassType.
MLIR_CAPI_EXPORTED bool omAttrIsAReferenceAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorTupleGetElement(OMEvaluatorValue evaluatorValue, intptr_t pos)
Get an element of the tuple.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAList(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a List.
MLIR_CAPI_EXPORTED MlirType omListTypeGetElementType(MlirType type)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsATuple(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Tuple.
MLIR_CAPI_EXPORTED MlirType omEvaluatorMapGetType(OMEvaluatorValue evaluatorValue)
Get the Type from a Map, which will be a MapType.
MLIR_CAPI_EXPORTED bool omTypeIsAFrozenPathType(MlirType type)
Is the Type a FrozenPathType.
MLIR_CAPI_EXPORTED MlirType omEvaluatorObjectGetType(OMEvaluatorValue object)
Get the Type from an Object, which will be a ClassType.
MLIR_CAPI_EXPORTED bool omAttrIsAIntegerAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omTypeIsAStringType(MlirType type)
Is the Type a StringType.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorListGetElement(OMEvaluatorValue evaluatorValue, intptr_t pos)
Get an element of the list.
MLIR_CAPI_EXPORTED bool omTypeIsAClassType(MlirType type)
Is the Type a ClassType.
MLIR_CAPI_EXPORTED bool omEvaluatorObjectIsNull(OMEvaluatorValue object)
Query if the Object is null.
MLIR_CAPI_EXPORTED intptr_t omMapAttrGetNumElements(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirStringRef omIntegerAttrToString(MlirAttribute attr)
Get a string representation of an om::IntegerAttr.
MLIR_CAPI_EXPORTED MlirIdentifier omMapAttrGetElementKey(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorValueFromPrimitive(MlirAttribute primitive)
Get the EvaluatorValue from a Primitive value.
MLIR_CAPI_EXPORTED MlirAttribute omIntegerAttrGet(MlirAttribute attr)
Get an om::IntegerAttr from mlir::IntegerAttr.
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAObject(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is an Object.
MLIR_CAPI_EXPORTED MlirAttribute omListAttrGetElement(MlirAttribute attr, intptr_t pos)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAPrimitive(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Primitive.
MLIR_CAPI_EXPORTED OMEvaluatorValue omEvaluatorObjectGetField(OMEvaluatorValue object, MlirAttribute name)
Get a field from an Object, which must contain a field of that name.
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorMapGetKeys(OMEvaluatorValue object)
Get an ArrayAttr with the keys in a Map.
MLIR_CAPI_EXPORTED MlirContext omEvaluatorValueGetContext(OMEvaluatorValue evaluatorValue)
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorObjectGetFieldNames(OMEvaluatorValue object)
Get all the field names from an Object; the result can be empty if the object has no fields.
MLIR_CAPI_EXPORTED bool omAttrIsAMapAttr(MlirAttribute attr)
MLIR_CAPI_EXPORTED bool omEvaluatorValueIsAReference(OMEvaluatorValue evaluatorValue)
Query if the EvaluatorValue is a Reference.
MLIR_CAPI_EXPORTED MlirTypeID omAnyTypeGetTypeID(void)
Get the TypeID for an AnyType.
MLIR_CAPI_EXPORTED MlirAttribute omEvaluatorPathGetAsString(OMEvaluatorValue evaluatorValue)
Get a string representation of a Path.
MLIR_CAPI_EXPORTED MlirAttribute omReferenceAttrGetInnerRef(MlirAttribute attr)
MLIR_CAPI_EXPORTED MlirAttribute omIntegerAttrGetInt(MlirAttribute attr)
Given an om::IntegerAttr, return the mlir::IntegerAttr.
@ None
Don't preserve aggregate at all.
evaluator::ObjectValue Object
void populateDialectOMSubmodule(pybind11::module &m)
A value type for use in C APIs that just wraps a pointer to an Object.
A value type for use in C APIs that just wraps a pointer to an Evaluator.