CIRCT 23.0.0git
Loading...
Searching...
No Matches
support.py
Go to the documentation of this file.
1# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
2# See https://llvm.org/LICENSE.txt for license information.
3# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
4
5from . import ir
6
7from ._mlir_libs._circt._support import _walk_with_filter
8from .ir import Operation
9from contextlib import AbstractContextManager
10from contextvars import ContextVar
11from typing import List
12
# Holds the BackedgeBuilder installed by `BackedgeBuilder.__enter__`;
# `BackedgeBuilder.current()` reads it and raises RuntimeError when unset.
_current_backedge_builder = ContextVar("current_bb")
14
15
17 pass
18
19
21
22 def __init__(self, module: str, port_names: List[str]):
23 super().__init__(
24 f"Ports {port_names} unconnected in design module {module}.")
25
26
def get_value(obj) -> ir.Value:
  """Resolve an MLIR Value from `obj`.

  An `ir.Value` is returned unchanged. Otherwise the object's `result`
  attribute is returned if it has one, else its `value` attribute. When none of
  these apply, None is returned.
  """
  if isinstance(obj, ir.Value):
    return obj
  for attr_name in ("result", "value"):
    if hasattr(obj, attr_name):
      return getattr(obj, attr_name)
  return None
37
38
def connect(destination, source):
  """Drive the operand `destination` with the value resolved from `source`.

  If the operand's slot is recorded in `destination.backedge_owner.backedges`,
  that placeholder backedge is erased and its entry removed.

  Raises:
    TypeError: if `destination` is not an OpOperand, or `source` cannot be
      resolved to a value by get_value().
  """
  if not isinstance(destination, OpOperand):
    raise TypeError(
        f"cannot connect to destination of type {type(destination)}. "
        "Must be OpOperand.")
  new_value = get_value(source)
  if new_value is None:
    raise TypeError(f"cannot connect from source of type {type(source)}")

  operand_index = destination.index
  destination.operation.operands[operand_index] = new_value

  owner = destination.backedge_owner
  if not owner or operand_index not in owner.backedges:
    return
  owner.backedges[operand_index].erase()
  del owner.backedges[operand_index]
55
56
def var_to_attribute(obj, none_on_fail: bool = False) -> ir.Attribute:
  """Create an MLIR attribute from a Python object for a few common cases.

  Supported conversions:
    * ir.Attribute -> returned unchanged
    * bool -> BoolAttr (checked before int, since bool subclasses int)
    * int -> 64-bit signless IntegerAttr
    * str -> StringAttr
    * list -> ArrayAttr of the recursively converted elements
    * dict with str keys -> DictAttr of the recursively converted values
      (mirrors what attribute_to_var returns for a DictAttr)

  Args:
    obj: the Python value to convert.
    none_on_fail: if True, return None instead of raising when `obj`, or any
      nested element, cannot be converted.

  Raises:
    TypeError: if the conversion fails and `none_on_fail` is False.
  """
  if isinstance(obj, ir.Attribute):
    return obj
  if isinstance(obj, bool):
    return ir.BoolAttr.get(obj)
  if isinstance(obj, int):
    attrTy = ir.IntegerType.get_signless(64)
    return ir.IntegerAttr.get(attrTy, obj)
  if isinstance(obj, str):
    return ir.StringAttr.get(obj)
  if isinstance(obj, list):
    arr = [var_to_attribute(x, none_on_fail) for x in obj]
    # A None element means a nested conversion failed under none_on_fail;
    # propagate that failure instead of building a partial array.
    if all(a is not None for a in arr):
      return ir.ArrayAttr.get(arr)
    return None
  if isinstance(obj, dict) and all(isinstance(k, str) for k in obj):
    entries = {k: var_to_attribute(v, none_on_fail) for k, v in obj.items()}
    if all(a is not None for a in entries.values()):
      return ir.DictAttr.get(entries)
    return None
  if none_on_fail:
    return None
  raise TypeError(f"Cannot convert type '{type(obj)}' to MLIR attribute")
76
77
# There is currently no support in MLIR for querying the concrete class of a
# type. The conversation regarding how to achieve this is ongoing and I expect
# it to be a long one. This is a way that works for now.
def type_to_pytype(t) -> ir.Type:
  """Downcast a generic ``ir.Type`` to its concrete Python type class.

  Args:
    t: an MLIR ``ir.Type``. Objects whose class is already a subclass of
      ``ir.Type`` are assumed to be downcast and are returned unchanged.

  Returns:
    ``t`` wrapped in the first matching builtin, HW, Seq, ESI, RTG or RTGTest
    type class.

  Raises:
    TypeError: if ``t`` is not an ``ir.Type``, or matches none of the known
      type classes.
  """
  if not isinstance(t, ir.Type):
    raise TypeError("type_to_pytype only accepts MLIR Type objects")

  # If it's not the root type, assume it's already been downcasted and don't do
  # the expensive probing below.
  if t.__class__ != ir.Type:
    return t

  from .dialects import esi, hw, seq, rtg, rtgtest

  # Probe every candidate with its static `isinstance` method. Python's builtin
  # isinstance() cannot be used here: `t` is known to be exactly `ir.Type`, so
  # it is never an instance of a Python subclass such as ir.IntegerType.
  # Order matters; the first match wins.
  candidates = (
      ir.IntegerType,
      ir.NoneType,
      ir.TupleType,
      hw.ArrayType,
      hw.StructType,
      hw.UnionType,
      hw.TypeAliasType,
      hw.InOutType,
      seq.ClockType,
      esi.ChannelType,
      esi.AnyType,
      esi.BundleType,
      esi.ListType,
      esi.WindowType,
      esi.WindowFrameType,
      esi.WindowFieldType,
      rtg.LabelType,
      rtg.SetType,
      rtg.BagType,
      rtg.SequenceType,
      rtg.RandomizedSequenceType,
      rtg.DictType,
      rtg.ImmediateType,
      rtg.ArrayType,
      rtg.MemoryType,
      rtg.MemoryBlockType,
      rtg.TupleType,
      rtg.StringType,
      rtgtest.IntegerRegisterType,
      rtgtest.FloatRegisterType,
      rtgtest.CPUType,
  )
  for type_cls in candidates:
    if type_cls.isinstance(t):
      return type_cls(t)

  raise TypeError(f"Cannot convert {repr(t)} to python type")
156
157
# There is currently no support in MLIR for querying the concrete class of an
# attribute. The conversation regarding how to achieve this is ongoing and I
# expect it to be a long one. This is a way that works for now.
def attribute_to_var(attr):
  """Convert an MLIR attribute into the corresponding Python value.

  Conversions:
    * BoolAttr -> bool
    * IntegerAttr -> int
    * hw.InnerSymAttr (its symbol name) and StringAttr -> str
    * FlatSymbolRefAttr, TypeAttr, om.PathAttr -> their ``value``
    * ArrayAttr and om.ListAttr -> list (elements converted recursively)
    * DictAttr -> dict (values converted recursively)
    * om.ReferenceAttr -> its converted inner ref
    * hw.InnerRefAttr -> ``(module, name)`` tuple of str
    * om.OMIntegerAttr -> int

  An already-downcast attribute that exposes a ``value`` property returns it
  directly.

  Returns:
    The Python value, or None when ``attr`` is None.

  Raises:
    TypeError: if ``attr`` is not an ``ir.Attribute`` or is of an unsupported
      kind.
  """
  if attr is None:
    return None
  if not isinstance(attr, ir.Attribute):
    raise TypeError("attribute_to_var only accepts MLIR Attributes")

  # If it's not the root type, assume it's already been downcasted and don't do
  # the expensive probing below.
  if attr.__class__ != ir.Attribute and hasattr(attr, "value"):
    return attr.value

  from .dialects import hw, om

  # Probe builtin kinds with their static `isinstance` methods, as the dialect
  # checks already do. Python's isinstance() only looks at the Python class, so
  # it can never match a root-class `ir.Attribute`, which is exactly what
  # reaches this point when the attribute was not downcast.
  if ir.BoolAttr.isinstance(attr):
    return ir.BoolAttr(attr).value
  if ir.IntegerAttr.isinstance(attr):
    return ir.IntegerAttr(attr).value
  if hw.InnerSymAttr.isinstance(attr):
    return ir.StringAttr(hw.InnerSymAttr(attr).symName).value
  if ir.StringAttr.isinstance(attr):
    return ir.StringAttr(attr).value
  if ir.FlatSymbolRefAttr.isinstance(attr):
    return ir.FlatSymbolRefAttr(attr).value
  if ir.TypeAttr.isinstance(attr):
    return ir.TypeAttr(attr).value
  if ir.ArrayAttr.isinstance(attr):
    return [attribute_to_var(x) for x in ir.ArrayAttr(attr)]
  if ir.DictAttr.isinstance(attr):
    # Entries expose `.name` and `.attr`.
    return {
        named.name: attribute_to_var(named.attr)
        for named in ir.DictAttr(attr)
    }
  if om.ReferenceAttr.isinstance(attr):
    return attribute_to_var(om.ReferenceAttr(attr).inner_ref)
  if hw.InnerRefAttr.isinstance(attr):
    ref = hw.InnerRefAttr(attr)
    return (ir.StringAttr(ref.module).value, ir.StringAttr(ref.name).value)
  if om.ListAttr.isinstance(attr):
    return [attribute_to_var(x) for x in om.ListAttr(attr)]
  if om.OMIntegerAttr.isinstance(attr):
    return int(str(om.OMIntegerAttr(attr)))
  if om.PathAttr.isinstance(attr):
    return om.PathAttr(attr).value

  raise TypeError(f"Cannot convert {repr(attr)} to python value")
205
206
def get_self_or_inner(mlir_type):
  """Return `mlir_type` downcast to its concrete Python class.

  A root-class `ir.Type` is first downcast with type_to_pytype(). If the
  result is an `hw.TypeAliasType`, its (downcast) inner type is returned
  instead.
  """
  from .dialects import hw
  resolved = (type_to_pytype(mlir_type)
              if type(mlir_type) is ir.Type else mlir_type)
  if not isinstance(resolved, hw.TypeAliasType):
    return resolved
  return type_to_pytype(resolved.inner_type)
214
215
class BackedgeBuilder(AbstractContextManager):
  """Creates placeholder ("backedge") values and checks that they get replaced.

  Entering the builder with ``with`` makes it the builder returned by
  ``BackedgeBuilder.current()`` and used by ``BackedgeBuilder.create()``. On a
  normal exit, every edge that was never erased is reported in a single
  RuntimeError.
  """

  class Edge:
    """A placeholder value produced by a dummy
    ``builtin.unrealized_conversion_cast`` operation."""

    def __init__(self,
                 creator,
                 type: ir.Type,
                 backedge_name: str,
                 op_view,
                 instance_of: ir.Operation,
                 loc: ir.Location = None):
      self.creator: BackedgeBuilder = creator
      # An operand-less cast op whose single result of `type` stands in for
      # the real value until it is connected.
      self.dummy_op = ir.Operation.create("builtin.unrealized_conversion_cast",
                                          [type],
                                          loc=loc)
      self.instance_of = instance_of
      self.op_view = op_view
      self.port_name = backedge_name
      self.loc = loc
      self.erased = False

    @property
    def result(self):
      """The placeholder value (the dummy op's result)."""
      return self.dummy_op.result

    def erase(self):
      """Erase the dummy op and unregister this edge from its creator.

      Repeated calls are no-ops.
      """
      if self.erased:
        return
      self.creator.edges.discard(self)
      self.dummy_op.operation.erase()
      # Remember the erasure so a second call cannot erase the op again.
      self.erased = True

  def __init__(self, circuit_name: str = ""):
    self.circuit_name = circuit_name
    # Edges created by this builder that have not been erased yet.
    self.edges = set()

  @staticmethod
  def current():
    """Return the builder installed by the innermost active ``with`` block.

    Raises:
      RuntimeError: if no builder is active.
    """
    bb = _current_backedge_builder.get(None)
    if bb is None:
      raise RuntimeError("No backedge builder found in context!")
    return bb

  @staticmethod
  def create(*args, **kwargs):
    """Create an edge on the current builder; arguments as for ``_create``."""
    return BackedgeBuilder.current()._create(*args, **kwargs)

  def _create(self,
              type: ir.Type,
              port_name: str,
              op_view,
              instance_of: ir.Operation = None,
              loc: ir.Location = None):
    """Create, register and return a new Edge producing a value of `type`."""
    edge = BackedgeBuilder.Edge(self, type, port_name, op_view, instance_of,
                                loc)
    self.edges.add(edge)
    return edge

  def __enter__(self):
    self.old_bb_token = _current_backedge_builder.set(self)
    # Returning self makes `with BackedgeBuilder() as bb:` usable.
    return self

  def __exit__(self, exc_type, exc_value, traceback):
    # Restore the previous builder unconditionally; skipping this when the
    # body raised would leave this builder installed as "current".
    _current_backedge_builder.reset(self.old_bb_token)
    if exc_value is not None:
      # Let the original exception propagate without adding a report of the
      # backedges it left unconnected.
      return
    errors = []
    for edge in self.edges:
      # TODO: Make this use `UnconnectedSignalError`.
      msg = "Backedge: " + edge.port_name + "\n"
      if edge.instance_of is not None:
        # Keep only the text before the first " {" (presumably the op header,
        # without its body).
        msg += "InstanceOf: " + str(edge.instance_of).split(" {")[0] + "\n"
      if edge.op_view is not None:
        op = edge.op_view.operation
        msg += "Instance: " + str(op) + "\n"
      if edge.loc is not None:
        msg += "Location: " + str(edge.loc) + "\n"
      errors.append(msg)

    if errors:
      errors.insert(
          0, f"Uninitialized backedges remain in module '{self.circuit_name}'")
      raise RuntimeError("\n".join(errors))
298
299
class OpOperand:
  """A handle on one indexed operand (or result) slot of an operation.

  `value` is the value currently in the slot. `backedge_owner`, when set, is
  the object whose `backedges` mapping connect() consults to erase a
  placeholder that fed this slot.
  """
  __slots__ = ["index", "operation", "value", "backedge_owner"]

  def __init__(self,
               operation: ir.Operation,
               index: int,
               value,
               backedge_owner=None):
    """
    Raises:
      TypeError: if `index` is not an int or `operation` has no `operands`.
    """
    if not isinstance(index, int):
      raise TypeError("Index must be int")
    self.index = index

    if not hasattr(operation, "operands"):
      raise TypeError("Operation must have 'operands' attribute")
    self.operation = operation

    self.value = value
    self.backedge_owner = backedge_owner

  @property
  def type(self):
    """The MLIR type of the wrapped value."""
    return self.value.type
322
323
class NamedValueOpView:
  """Helper class to incrementally construct an instance of an operation that
  names its operands and results.

  Subclasses supply ``operand_names()`` and ``result_names()``. Operands not
  given in ``input_port_mapping`` start out as placeholders from
  ``create_default_value()``. Named operands and results are then reachable as
  attributes (see ``__getattr__``).
  """

  def __init__(self,
               cls,
               data_type=None,
               input_port_mapping=None,
               pre_args=None,
               post_args=None,
               needs_result_type=False,
               **kwargs):
    """Build the underlying op view by calling `cls`.

    Args:
      cls: the op view class to instantiate.
      data_type: result type, also used as the type of placeholder operands.
        A list value additionally makes the operands be passed as one list.
      input_port_mapping: operand name -> value (anything get_value accepts).
      pre_args: positional arguments placed before the operands. The
        caller's sequence is never modified.
      post_args: positional arguments placed after the operands.
      needs_result_type: pass `data_type` to `cls` even when placeholders were
        created.
      **kwargs: forwarded to `cls`.
    """
    # Set defaults
    if input_port_mapping is None:
      input_port_mapping = {}
    if pre_args is None:
      pre_args = []
    if post_args is None:
      post_args = []

    # Map each result name to its index.
    result_indices = {name: i for i, name in enumerate(self.result_names())}

    # Map each operand name to its index and give it an initial value: the
    # caller-supplied one from input_port_mapping, or else a placeholder from
    # create_default_value() (recorded in `backedges` by operand index).
    backedges = {}
    operand_indices = {}
    operand_values = []
    for i, arg_name in enumerate(self.operand_names()):
      operand_indices[arg_name] = i
      if arg_name in input_port_mapping:
        operand = get_value(input_port_mapping[arg_name])
      else:
        backedge = self.create_default_value(i, data_type, arg_name)
        backedges[i] = backedge
        operand = backedge.result
      operand_values.append(operand)

    # Some ops take a list of operand values rather than splatting them out.
    if isinstance(data_type, list):
      operand_values = [operand_values]

    # In many cases, result types are inferred, and we do not need to pass
    # data_type to the underlying constructor. It must be provided to
    # NamedValueOpView in cases where we need to build backedges, but should
    # generally not be passed to the underlying constructor in this case. There
    # are some oddball ops that must pass it, even when building backedges, and
    # these set needs_result_type=True.
    if data_type is not None and (needs_result_type or not backedges):
      # Build a new sequence rather than inserting into `pre_args`: mutating
      # the caller's list would accumulate data_type across reuses, and a
      # tuple argument would not support insert() at all.
      pre_args = [data_type, *pre_args]

    self.opview = cls(*pre_args, *operand_values, *post_args, **kwargs)
    self.operand_indices = operand_indices
    self.result_indices = result_indices
    self.backedges = backedges

  def __getattr__(self, name):
    """Resolve `name` as an operand name, a result name, or "attributes".

    Only invoked when normal attribute lookup fails.

    Raises:
      AttributeError: if `name` is none of these.
    """
    # Read the index tables straight from the instance dict. They do not
    # exist until __init__ finishes, and plain attribute access here would
    # re-enter __getattr__. This also avoids building dir(self) on every miss.
    instance_attrs = self.__dict__

    # Check for the attribute in the arg name set.
    operand_indices = instance_attrs.get("operand_indices")
    if operand_indices is not None and name in operand_indices:
      index = operand_indices[name]
      value = self.opview.operands[index]
      return OpOperand(self.opview.operation, index, value, self)

    # Check for the attribute in the result name set. Result handles reuse
    # OpOperand; their `index` is the result's index.
    result_indices = instance_attrs.get("result_indices")
    if result_indices is not None and name in result_indices:
      index = result_indices[name]
      value = self.opview.results[index]
      return OpOperand(self.opview.operation, index, value, self)

    # Forward "attributes" attribute from the operation.
    if name == "attributes":
      return self.opview.operation.attributes

    # If we fell through to here, the name isn't a known port.
    raise AttributeError(f"unknown port name {name}")

  def create_default_value(self, index, data_type, arg_name):
    """Create the placeholder for operand `arg_name` when no value was given.

    This default creates a backedge of `data_type` on the current
    BackedgeBuilder. `index` is unused here; presumably it exists for
    overriding subclasses.
    """
    return BackedgeBuilder.create(data_type, arg_name, self)

  @property
  def operation(self):
    """Get the operation associated with this builder."""
    return self.opview.operation
413
414
def walk_with_filter(operation: Operation, op_views: List[ir.OpView], callback,
                     walk_order):
  """Walk `operation`, filtering on operation names.

  A thin wrapper around the C++ `_walk_with_filter`. Each op view class in
  `op_views` (the operations to visit) is reduced to its `OPERATION_NAME`
  before delegating.

  Args:
    operation: the operation to walk.
    op_views: op view classes selecting which operations to visit.
    callback: passed through unchanged to the native walker.
    walk_order: passed through unchanged to the native walker.

  Returns:
    Whatever the native `_walk_with_filter` returns.
  """
  wanted_names = [view_cls.OPERATION_NAME for view_cls in op_views]
  return _walk_with_filter(operation, wanted_names, callback, walk_order)
__init__(self, creator, ir.Type type, str backedge_name, op_view, ir.Operation instance_of, ir.Location loc=None)
Definition support.py:226
__init__(self, str circuit_name="")
Definition support.py:248
create(*args, **kwargs)
Definition support.py:260
_create(self, ir.Type type, str port_name, op_view, ir.Operation instance_of=None, ir.Location loc=None)
Definition support.py:268
__exit__(self, exc_type, exc_value, traceback)
Definition support.py:277
__init__(self, cls, data_type=None, input_port_mapping=None, pre_args=None, post_args=None, needs_result_type=False, **kwargs)
Definition support.py:335
create_default_value(self, index, data_type, arg_name)
Definition support.py:406
__init__(self, ir.Operation operation, int index, value, backedge_owner=None)
Definition support.py:307
__init__(self, str module, List[str] port_names)
Definition support.py:22
The "any" type is a special type which can be used to represent any type, as identified by the type i...
Definition Types.h:152
Bundles represent a collection of channels.
Definition Types.h:99
Channels are the basic communication primitives.
Definition Types.h:120
Lists represent variable-length sequences of elements of a single type.
Definition Types.h:348
Windows represent a fixed-size sliding window over a stream of data.
Definition Types.h:309
get_self_or_inner(mlir_type)
Definition support.py:207
walk_with_filter(Operation operation, List[ir.OpView] op_views, callback, walk_order)
Definition support.py:419
ir.Type type_to_pytype(t)
Definition support.py:81
connect(destination, source)
Definition support.py:39