CIRCT 23.0.0git
Loading...
Searching...
No Matches
om.py
Go to the documentation of this file.
1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from __future__ import annotations
6
7from ._om_ops_gen import *
8from .._mlir_libs._circt._om import AnyType, Evaluator as BaseEvaluator, Object as BaseObject, List as BaseList, BasePath as BaseBasePath, BasePathType, Path, PathType, ClassType, ReferenceAttr, ListAttr, ListType, OMIntegerAttr, Unknown
9
10from ..ir import Attribute, Diagnostic, DiagnosticSeverity, Module, StringAttr, IntegerAttr, IntegerType
11from ..support import attribute_to_var, var_to_attribute
12
13import sys
14import logging
15from dataclasses import fields
16from typing import TYPE_CHECKING, Any, Sequence, TypeVar
17
18if TYPE_CHECKING:
19 from _typeshed.stdlib.dataclass import DataclassInstance
20
21
# Wrap a base MLIR object with its high-level Python counterpart.
def wrap_mlir_object(value):
  """Lift a value returned by the OM C++ bindings to its Python wrapper.

  Plain Python values, ``Unknown`` and ``Path`` are returned as-is; base
  lists become ``List``, base paths become ``BasePath``, and any remaining
  value must be a base object, which becomes ``Object``.
  """
  # Primitive Python values and unknowns are already in their final form.
  is_primitive = isinstance(value, (int, float, str, bool, tuple, list, dict))
  if is_primitive or isinstance(value, Unknown):
    return value

  if isinstance(value, BaseList):
    wrapped = List(value)
  elif isinstance(value, BaseBasePath):
    wrapped = BasePath(value)
  elif isinstance(value, Path):
    wrapped = value
  else:
    # Anything left over has to be an object from the C++ layer.
    assert isinstance(value, BaseObject)
    wrapped = Object(value)
  return wrapped
43
44
def om_var_to_attribute(obj, none_on_fail: bool = False) -> Attribute:
  """Convert a Python value into an MLIR attribute for use with OM.

  Python integers (but not booleans) are encoded as an ``OMIntegerAttr``
  wrapping a signed 64-bit ``IntegerAttr``. Every other value, booleans
  included, is delegated to ``var_to_attribute``.

  Args:
    obj: The Python value to convert.
    none_on_fail: Forwarded unchanged to ``var_to_attribute``.

  Returns:
    The resulting MLIR attribute.
  """
  # ``bool`` is a subclass of ``int``: without the explicit exclusion, True
  # and False would be silently encoded as the 64-bit integers 1 and 0.
  if isinstance(obj, int) and not isinstance(obj, bool):
    return OMIntegerAttr.get(IntegerAttr.get(IntegerType.get_signed(64), obj))
  return var_to_attribute(obj, none_on_fail)
49
50
def unwrap_python_object(value):
  """Convert a high-level Python wrapper back to its C++ base representation.

  ``List`` and ``BasePath`` wrappers are rebuilt as ``BaseList`` and
  ``BaseBasePath``; ``Object`` wrappers are rebuilt as ``BaseObject``.
  ``Path`` values and everything else (e.g. primitives) are returned
  unchanged.
  """
  # Container wrappers map back onto their base bindings.
  if isinstance(value, List):
    return BaseList(value)
  if isinstance(value, BasePath):
    return BaseBasePath(value)

  # Path takes precedence over Object: a value that is a Path is passed
  # through untouched rather than rebuilt as a base object.
  if isinstance(value, Object) and not isinstance(value, Path):
    return BaseObject(value)

  # Paths and primitive values need no conversion.
  return value
67
68
class List(BaseList):
  """High-level view of an OM list whose elements are wrapped on access."""

  def __init__(self, obj: BaseList) -> None:
    super().__init__(obj)

  def __getitem__(self, i):
    # Fetch the raw element from the C++ base, then lift it to its
    # high-level Python representation.
    return wrap_mlir_object(super().__getitem__(i))

  # Support iterating over a List by yielding its (wrapped) elements.
  def __iter__(self):
    for index in range(len(self)):
      yield self[index]
82
83
class BasePath(BaseBasePath):
  """High-level wrapper around the C++ ``BasePath`` binding."""

  @staticmethod
  def get_empty(context=None) -> "BasePath":
    """Return an empty base path wrapped in the high-level class.

    ``context`` is forwarded unchanged to ``BaseBasePath.get_empty``.
    """
    empty = BaseBasePath.get_empty(context)
    return BasePath(empty)
89
90
# Define the Object class by inheriting from the base implementation in C++.
class Object(BaseObject):
  """High-level view of an OM object whose field values are wrapped."""

  def __init__(self, obj: BaseObject) -> None:
    super().__init__(obj)

  def __getattr__(self, name: str):
    """Look up field ``name`` on the C++ base and wrap the result."""
    return wrap_mlir_object(super().__getattr__(name))

  def get_field_loc(self, name: str):
    """Return what the base implementation reports as the location of ``name``."""
    return super().get_field_loc(name)

  # Support iterating over an Object by yielding (field name, value) pairs.
  def __iter__(self):
    yield from ((field, getattr(self, field)) for field in self.field_names)
111
112
# Define the Evaluator class by inheriting from the base implementation in C++.
class Evaluator(BaseEvaluator):

  def __init__(self, mod: Module) -> None:
    """Instantiate an Evaluator with a Module.

    Args:
      mod: The MLIR module to evaluate. A diagnostic handler is attached to
        its context so MLIR diagnostics are routed to this evaluator's logger.
    """

    # Call the base constructor.
    super().__init__(mod)

    # Set up logging for diagnostics.
    # NOTE(review): ``basicConfig`` configures the process-wide root logger,
    # so constructing an Evaluator affects logging for the whole program.
    logging.basicConfig(
        format="[%(asctime)s] %(name)s (%(levelname)s) %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        level=logging.INFO,
        stream=sys.stdout,
    )
    self._logger = logging.getLogger("Evaluator")

    # Attach our Diagnostic handler (the bound ``_handle_diagnostic`` method).
    mod.context.attach_diagnostic_handler(self._handle_diagnostic)

  def instantiate(self, cls: str, *args: Any) -> Object:
    """Instantiate an Object with a class name and actual parameters.

    Args:
      cls: Name of the OM class to instantiate; converted to a ``StringAttr``.
      *args: Actual parameters; each is converted with
        ``unwrap_python_object`` before being passed to the base
        implementation.

    Returns:
      The instantiated object, wrapped in the high-level ``Object`` class.
    """

    # Convert the class name and actual parameters to Attributes within the
    # Evaluator's context.
    with self.module.context:
      # Get the class name from the class name.
      class_name = StringAttr.get(cls)

      # Get the actual parameter Values from the supplied variadic
      # arguments.
      actual_params = [unwrap_python_object(arg) for arg in args]

    # Call the base instantiate method.
    obj = super().instantiate(class_name, actual_params)

    # Return the Object, wrapping the base implementation.
    return Object(obj)

  def _handle_diagnostic(self, diagnostic: Diagnostic) -> bool:
    """Handle MLIR Diagnostics by logging them.

    Errors and warnings are logged at the matching level; every other
    severity, and every attached note, is logged at INFO.

    Returns:
      Always True, marking the diagnostic as handled.
    """

    # Log the diagnostic message at the appropriate level.
    if diagnostic.severity == DiagnosticSeverity.ERROR:
      self._logger.error(diagnostic.message)
    elif diagnostic.severity == DiagnosticSeverity.WARNING:
      self._logger.warning(diagnostic.message)
    else:
      self._logger.info(diagnostic.message)

    # Log any diagnostic notes at the info level.
    for note in diagnostic.notes:
      self._logger.info(str(note))

    # Flush the stdout stream to ensure logs appear when expected.
    sys.stdout.flush()

    # Return True, indicating this diagnostic has been fully handled.
    return True
"BasePath" get_empty(context=None)
Definition om.py:87
None __init__(self, Module mod)
Definition om.py:116
_handle_diagnostic
Definition om.py:132
bool _handle_diagnostic(self, Diagnostic diagnostic)
Definition om.py:153
Object instantiate(self, str cls, *Any args)
Definition om.py:134
Definition om.py:69
__iter__(self)
Definition om.py:79
__getitem__(self, i)
Definition om.py:74
None __init__(self, BaseList obj)
Definition om.py:71
__iter__(self)
Definition om.py:108
None __init__(self, BaseObject obj)
Definition om.py:94
get_field_loc(self, str name)
Definition om.py:102
__getattr__(self, str name)
Definition om.py:97
wrap_mlir_object(value)
Definition om.py:23
unwrap_python_object(value)
Definition om.py:51
ir.Attrbute om_var_to_attribute(obj, bool none_on_fail=False)
Definition om.py:45