Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
om.py
Go to the documentation of this file.
1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from __future__ import annotations
6
7from ._om_ops_gen import *
8from .._mlir_libs._circt._om import AnyType, Evaluator as BaseEvaluator, Object as BaseObject, List as BaseList, BasePath as BaseBasePath, BasePathType, Path, PathType, ClassType, ReferenceAttr, ListAttr, ListType, OMIntegerAttr
9
10from ..ir import Attribute, Diagnostic, DiagnosticSeverity, Module, StringAttr, IntegerAttr, IntegerType
11from ..support import attribute_to_var, var_to_attribute
12
13import sys
14import logging
15from dataclasses import fields
16from typing import TYPE_CHECKING, Any, Sequence, TypeVar
17
18if TYPE_CHECKING:
19 from _typeshed.stdlib.dataclass import DataclassInstance
20
21
# Wrap a base mlir object with high-level object.
def wrap_mlir_object(value):
    """Convert a value produced by the base bindings into its high-level form.

    Args:
        value: A Python primitive/container, or an instance of one of the base
            binding types (``BaseList``, ``BaseBasePath``, ``Path``,
            ``BaseObject``).

    Returns:
        ``value`` unchanged for primitives, built-in containers and ``Path``;
        a ``List`` for a ``BaseList``; a ``BasePath`` for a ``BaseBasePath``;
        otherwise an ``Object`` wrapping ``value``.

    Raises:
        AssertionError: If ``value`` is none of the types listed above.
    """
    # For primitives, return the Python value directly.
    if isinstance(value, (int, float, str, bool, tuple, list, dict)):
        return value

    if isinstance(value, BaseList):
        return List(value)

    if isinstance(value, BaseBasePath):
        return BasePath(value)

    # Path has no high-level wrapper in this module, so it is passed through.
    if isinstance(value, Path):
        return value

    # For objects, return an Object, wrapping the base implementation.
    assert isinstance(value, BaseObject)
    return Object(value)
40
41
def om_var_to_attribute(obj, none_on_fail: bool = False) -> Attribute:
    """Convert a Python value to an MLIR attribute, using OM integers for ints.

    Args:
        obj: The Python value to convert.
        none_on_fail: Forwarded unchanged to ``var_to_attribute`` for values
            that are not ints.

    Returns:
        An ``OMIntegerAttr`` built from a signless 64-bit ``IntegerAttr`` when
        ``obj`` is an ``int``; otherwise whatever ``var_to_attribute`` returns.
    """
    # NOTE(review): bool is a subclass of int, so True/False also take this
    # branch -- confirm that is intended.
    if isinstance(obj, int):
        return OMIntegerAttr.get(
            IntegerAttr.get(IntegerType.get_signless(64), obj))
    return var_to_attribute(obj, none_on_fail)
46
47
def unwrap_python_object(value):
    """Convert a high-level wrapper back into the base binding type.

    Args:
        value: A ``List``, ``BasePath``, ``Path``, ``Object``, or any other
            value.

    Returns:
        A ``BaseList``, ``BaseBasePath`` or ``BaseObject`` constructed from the
        matching wrapper; ``value`` unchanged for ``Path`` and for everything
        else.
    """
    # Check if the value is any of our container or custom types.
    if isinstance(value, List):
        return BaseList(value)

    if isinstance(value, BasePath):
        return BaseBasePath(value)

    if isinstance(value, Path):
        return value

    if isinstance(value, Object):
        return BaseObject(value)

    # Otherwise, it must be a primitive, so just return it.
    return value
64
65
class List(BaseList):
    """High-level list whose elements are wrapped on access."""

    def __init__(self, obj: BaseList) -> None:
        """Construct the wrapper from a base list."""
        super().__init__(obj)

    def __getitem__(self, i):
        """Return element ``i`` passed through ``wrap_mlir_object``."""
        val = super().__getitem__(i)
        return wrap_mlir_object(val)

    # Support iterating over a List by yielding its elements.
    def __iter__(self):
        """Yield each element in index order, wrapped via ``__getitem__``."""
        for i in range(len(self)):
            yield self[i]
79
80
class BasePath(BaseBasePath):
    """High-level wrapper around the base ``BaseBasePath`` binding."""

    @staticmethod
    def get_empty(context=None) -> "BasePath":
        """Return a ``BasePath`` wrapping ``BaseBasePath.get_empty(context)``.

        Args:
            context: Forwarded unchanged to the base ``get_empty``; defaults
                to ``None``.
        """
        return BasePath(BaseBasePath.get_empty(context))
86
87
# Define the Object class by inheriting from the base implementation in C++.
class Object(BaseObject):
    """High-level object whose fields are wrapped on attribute access."""

    def __init__(self, obj: BaseObject) -> None:
        """Construct the wrapper from a base object."""
        super().__init__(obj)

    def __getattr__(self, name: str):
        """Return field ``name`` passed through ``wrap_mlir_object``."""
        # Call the base method to get a field.
        field = super().__getattr__(name)
        return wrap_mlir_object(field)

    def get_field_loc(self, name: str):
        """Return the location the base implementation reports for ``name``."""
        # Call the base method to get the loc.
        return super().get_field_loc(name)

    # Support iterating over an Object by yielding its fields.
    def __iter__(self):
        """Yield ``(name, value)`` for every entry of ``self.field_names``."""
        # field_names is not defined in this class; presumably the base
        # implementation provides it.
        for name in self.field_names:
            yield (name, getattr(self, name))
108
109
# Define the Evaluator class by inheriting from the base implementation in C++.
class Evaluator(BaseEvaluator):
    """Evaluator that instantiates Objects and logs MLIR diagnostics."""

    def __init__(self, mod: Module) -> None:
        """Instantiate an Evaluator with a Module."""

        # Call the base constructor.
        super().__init__(mod)

        # Set up logging for diagnostics.
        logging.basicConfig(
            format="[%(asctime)s] %(name)s (%(levelname)s) %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
            level=logging.INFO,
            stream=sys.stdout,
        )
        self._logger = logging.getLogger("Evaluator")

        # Attach our Diagnostic handler.
        mod.context.attach_diagnostic_handler(self._handle_diagnostic)

    def instantiate(self, cls: str, *args: Any) -> Object:
        """Instantiate an Object with a class name and actual parameters.

        Args:
            cls: Name of the class to instantiate.
            *args: Actual parameters; each is passed through
                ``unwrap_python_object`` before the base call.

        Returns:
            An ``Object`` wrapping the result of the base ``instantiate``.
        """

        # Convert the class name and actual parameters to Attributes within
        # the Evaluator's context.
        with self.module.context:
            # Build a StringAttr from the class name string.
            class_name = StringAttr.get(cls)

            # Get the actual parameter Values from the supplied variadic
            # arguments.
            actual_params = [unwrap_python_object(arg) for arg in args]

            # Call the base instantiate method.
            obj = super().instantiate(class_name, actual_params)

            # Return the Object, wrapping the base implementation.
            return Object(obj)

    def _handle_diagnostic(self, diagnostic: Diagnostic) -> bool:
        """Handle MLIR Diagnostics by logging them.

        Errors and warnings are logged at the matching level; every other
        severity, and all attached notes, are logged at info level.

        Returns:
            Always ``True``, indicating the diagnostic has been fully handled.
        """

        # Log the diagnostic message at the appropriate level.
        if diagnostic.severity == DiagnosticSeverity.ERROR:
            self._logger.error(diagnostic.message)
        elif diagnostic.severity == DiagnosticSeverity.WARNING:
            self._logger.warning(diagnostic.message)
        else:
            self._logger.info(diagnostic.message)

        # Log any diagnostic notes at the info level.
        for note in diagnostic.notes:
            self._logger.info(str(note))

        # Flush the stdout stream to ensure logs appear when expected.
        sys.stdout.flush()

        # Return True, indicating this diagnostic has been fully handled.
        return True
"BasePath" get_empty(context=None)
Definition om.py:84
None __init__(self, Module mod)
Definition om.py:113
_handle_diagnostic
Definition om.py:129
bool _handle_diagnostic(self, Diagnostic diagnostic)
Definition om.py:150
Object instantiate(self, str cls, *Any args)
Definition om.py:131
Definition om.py:66
__iter__(self)
Definition om.py:76
__getitem__(self, i)
Definition om.py:71
None __init__(self, BaseList obj)
Definition om.py:68
__iter__(self)
Definition om.py:105
None __init__(self, BaseObject obj)
Definition om.py:91
get_field_loc(self, str name)
Definition om.py:99
__getattr__(self, str name)
Definition om.py:94
wrap_mlir_object(value)
Definition om.py:23
unwrap_python_object(value)
Definition om.py:48
ir.Attribute om_var_to_attribute(obj, bool none_on_fail=False)
Definition om.py:42