5 from __future__
import annotations
10 from ..dialects._ods_common
import _cext
as _ods_cext
12 from ._hw_ops_gen
import *
13 from ._hw_ops_gen
import _Dialect
14 from typing
import Dict, Type
19 mod_param_decls = module.parameters
20 mod_param_decls_idxs = {
21 decl.name: idx
for (idx, decl)
in enumerate(mod_param_decls)
23 inst_param_array = [
None] * len(module.parameters)
26 if isinstance(parameters, DictAttr):
27 parameters = {i.name: i.attr
for i
in parameters}
28 for (pname, pval)
in parameters.items():
29 if pname
not in mod_param_decls_idxs:
31 f
"Could not find parameter '{pname}' in module parameter decls")
32 idx = mod_param_decls_idxs[pname]
33 param_decl = mod_param_decls[idx]
34 inst_param_array[idx] = hw.ParamDeclAttr.get(pname, param_decl.param_type,
38 for (idx, pval)
in enumerate(inst_param_array):
41 inst_param_array[idx] = mod_param_decls[idx]
43 return inst_param_array
47 """Helper class to incrementally construct an instance of a module."""
60 instance_name = StringAttr.get(name)
61 module_name = FlatSymbolRefAttr.get(StringAttr(module.name).value)
64 inner_sym = hw.InnerSymAttr.get(StringAttr.get(sym_name))
67 pre_args = [instance_name, module_name]
69 ArrayAttr.get([StringAttr.get(x)
for x
in self.
operand_namesoperand_names()]),
70 ArrayAttr.get([StringAttr.get(x)
for x
in self.
result_namesresult_names()]),
71 ArrayAttr.get(inst_param_array)
74 results = module.type.output_types
77 input_name_type_lookup = {
78 name: support.type_to_pytype(ty)
79 for name, ty
in zip(self.
operand_namesoperand_names(), module.type.input_types)
81 for input_name, input_value
in input_port_mapping.items():
82 if input_name
not in input_name_type_lookup:
84 mod_input_type = input_name_type_lookup[input_name]
85 if support.type_to_pytype(input_value.type) != mod_input_type:
86 raise TypeError(f
"Input '{input_name}' has type '{input_value.type}' "
87 f
"but expected '{mod_input_type}'")
94 needs_result_type=
True,
100 type = self.
modulemodule.type.input_types[index]
101 return support.BackedgeBuilder.create(type,
104 instance_of=self.
modulemodule)
107 return self.
modulemodule.type.input_names
110 return self.
modulemodule.type.output_names
114 """Custom Python base class for module-like operations."""
129 Create a module-like with the provided `name`, `input_ports`, and
131 - `name` is a string representing the module name.
132 - `input_ports` is a list of pairs of string names and mlir.ir types.
133 - `output_ports` is a list of pairs of string names and mlir.ir types.
134 - `body_builder` is an optional callback, when provided a new entry block
135 is created and the callback is invoked with the new op as argument within
136 an InsertionPoint context already set for the block. The callback is
137 expected to insert a terminator in the block.
140 input_ports = list(input_ports)
141 output_ports = list(output_ports)
142 parameters = list(parameters)
143 attributes = dict(attributes)
147 attributes[
"sym_name"] = StringAttr.get(str(name))
152 unknownLoc = Location.unknown().attr
153 for (i, (port_name, port_type))
in enumerate(input_ports):
154 input_name = StringAttr.get(str(port_name))
155 input_dir = hw.ModulePortDirection.INPUT
156 input_port = hw.ModulePort(input_name, port_type, input_dir)
157 module_ports.append(input_port)
158 input_names.append(input_name)
159 port_locs.append(unknownLoc)
163 for (i, (port_name, port_type))
in enumerate(output_ports):
164 output_name = StringAttr.get(str(port_name))
165 output_dir = hw.ModulePortDirection.OUTPUT
166 output_port = hw.ModulePort(output_name, port_type, output_dir)
167 module_ports.append(output_port)
168 output_names.append(output_name)
169 port_locs.append(unknownLoc)
170 attributes[
"port_locs"] = ArrayAttr.get(port_locs)
171 attributes[
"per_port_attrs"] = ArrayAttr.get([])
173 if len(parameters) > 0
or "parameters" not in attributes:
174 attributes[
"parameters"] = ArrayAttr.get(parameters)
176 attributes[
"module_type"] = TypeAttr.get(hw.ModuleType.get(module_ports))
178 _ods_cext.ir.OpView.__init__(
180 self.build_generic(attributes=attributes,
187 entry_block = self.add_entry_block()
189 with InsertionPoint(entry_block):
190 with support.BackedgeBuilder():
191 outputs = body_builder(self)
196 return hw.ModuleType(TypeAttr(self.attributes[
"module_type"]).value)
200 return self.attributes[
"sym_name"]
204 return len(self.regions[0].blocks) == 0
209 hw.ParamDeclAttr(a)
for a
in ArrayAttr(self.attributes[
"parameters"])
214 parameters: Dict[str, object] = {},
223 parameters=parameters,
231 """Create the hw.OutputOp from the body_builder return."""
234 block_len = len(entry_block.operations)
236 last_op = entry_block.operations[block_len - 1]
237 if isinstance(last_op, hw.OutputOp):
239 if bb_ret
is not None and bb_ret != last_op:
240 raise support.ConnectionError(
241 f
"In {cls_name}, cannot return value from body_builder and "
242 "create hw.OutputOp")
248 if len(output_ports) == 0:
251 raise support.ConnectionError(
252 f
"In {cls_name}, must return module output values")
255 outputs: list[Value] = list()
258 if not isinstance(bb_ret, dict):
259 raise support.ConnectionError(
260 f
"In {cls_name}, can only return a dict of port, value mappings "
261 "from body_builder.")
265 unconnected_ports = []
266 for (name, port_type)
in output_ports:
267 if name
not in bb_ret:
268 unconnected_ports.append(name)
271 val = support.get_value(bb_ret[name])
273 field_type = type(bb_ret[name])
275 f
"In {cls_name}, body_builder return doesn't support type "
277 if val.type != port_type:
278 if isinstance(port_type, hw.TypeAliasType)
and \
279 port_type.inner_type == val.type:
283 f
"In {cls_name}, output port '{name}' type ({val.type}) doesn't "
284 f
"match declared type ({port_type})")
287 if len(unconnected_ports) > 0:
288 raise support.UnconnectedSignalError(cls_name, unconnected_ports)
290 raise support.ConnectionError(
291 f
"Could not map the following to output ports in {cls_name}: " +
292 ",".join(bb_ret.keys()))
297 @_ods_cext.register_operation(_Dialect, replace=True)
299 """Specialization for the HW module op class."""
313 if "comment" not in attributes:
314 attributes[
"comment"] = StringAttr.get(
"")
318 parameters=parameters,
319 attributes=attributes,
320 body_builder=body_builder,
326 return self.regions[0]
330 return self.regions[0].blocks[0]
334 indices: dict[int, str] = {}
335 op_names = self.
typetype.input_names
336 for idx, name
in enumerate(op_names):
344 return self.
entry_blockentry_block.arguments[index]
345 raise AttributeError(f
"unknown input port name {name}")
350 ret[name] = self.
entry_blockentry_block.arguments[idx]
354 result_names = self.
typetype.output_names
355 result_types = self.
typetype.output_types
356 return dict(zip(result_names, result_types))
360 raise IndexError(
'The module already has an entry block')
361 self.
bodybody.blocks.append(*self.
typetype.input_types)
362 return self.
bodybody.blocks[0]
365 @_ods_cext.register_operation(_Dialect, replace=True)
367 """Specialization for the HW module op class."""
381 if "comment" not in attributes:
382 attributes[
"comment"] = StringAttr.get(
"")
386 parameters=parameters,
387 attributes=attributes,
388 body_builder=body_builder,
393 @_ods_cext.register_operation(_Dialect, replace=True)
401 @_ods_cext.register_operation(_Dialect, replace=True)
406 value = support.get_value(value)
410 @_ods_cext.register_operation(_Dialect, replace=True)
415 array_value = support.get_value(array_value)
416 array_type = support.get_self_or_inner(array_value.type)
417 if isinstance(idx, int):
418 idx_width = (array_type.size - 1).bit_length()
419 idx_val = ConstantOp.create(IntegerType.get_signless(idx_width),
422 idx_val = support.get_value(idx)
426 @_ods_cext.register_operation(_Dialect, replace=True)
430 def create(array_value, low_index, ret_type):
431 array_value = support.get_value(array_value)
432 array_type = support.get_self_or_inner(array_value.type)
433 if isinstance(low_index, int):
434 idx_width = (array_type.size - 1).bit_length()
435 idx_width = max(1, idx_width)
436 idx_val = ConstantOp.create(IntegerType.get_signless(idx_width),
439 idx_val = support.get_value(low_index)
443 @_ods_cext.register_operation(_Dialect, replace=True)
449 raise ValueError(
"Cannot 'create' an array of length zero")
452 for i, arg
in enumerate(elements):
453 arg_val = support.get_value(arg)
457 elif type != arg_val.type:
459 f
"Argument {i} has a different element type ({arg_val.type}) than the element type of the array ({type})"
464 @_ods_cext.register_operation(_Dialect, replace=True)
472 for i, array
in enumerate(sub_arrays):
473 array_value = support.get_value(array)
474 array_type = support.type_to_pytype(array_value.type)
475 if array_value
is None or not isinstance(array_type, hw.ArrayType):
476 raise TypeError(f
"Cannot concatenate {array_value}")
477 if element_type
is None:
478 element_type = array_type.element_type
479 elif element_type != array_type.element_type:
481 f
"Argument {i} has a different element type ({element_type}) than the element type of the array ({array_type.element_type})"
484 vals.append(array_value)
485 types.append(array_type)
487 size = sum(t.size
for t
in types)
488 combined_type = hw.ArrayType.get(element_type, size)
492 @_ods_cext.register_operation(_Dialect, replace=True)
496 def create(elements, result_type: Type =
None):
498 (name, support.get_value(value))
for (name, value)
in elements
500 struct_fields = [(name, value.type)
for (name, value)
in elem_name_values]
501 struct_type = hw.StructType.get(struct_fields)
503 if result_type
is None:
504 result_type = struct_type
506 result_type_inner = support.get_self_or_inner(result_type)
507 if result_type_inner != struct_type:
509 f
"result type:\n\t{result_type_inner}\nmust match generated struct type:\n\t{struct_type}"
513 [value
for (_, value)
in elem_name_values])
516 @_ods_cext.register_operation(_Dialect, replace=True)
520 def create(struct_value, field_name: str):
521 struct_value = support.get_value(struct_value)
522 struct_type = support.get_self_or_inner(struct_value.type)
523 field_type = struct_type.get_field(field_name)
524 field_index = struct_type.get_field_index(field_name)
525 if field_index == UnitAttr.get():
527 f
"field '{field_name}' not element of struct type {struct_type}")
531 @_ods_cext.register_operation(_Dialect, replace=True)
535 def create(sym_name: str, type: Type, verilog_name: str =
None):
538 verilogName=verilog_name)
541 @_ods_cext.register_operation(_Dialect, replace=True)
547 op.regions[0].blocks.append()
552 return self.regions[0].blocks[0]
def create(array_value, idx)
def create(array_value, low_index, ret_type)
def create(data_type, value)
def create(data_type, value)
def __init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
def add_entry_block(self)
dict[str:Value] inputs(self)
dict[str:Type] outputs(self)
def __init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
def __getattr__(self, name)
def create_default_value(self, index, data_type, arg_name)
def __init__(self, module, name, input_port_mapping, *results=None, parameters={}, sym_name=None, loc=None, ip=None)
list[ParamDeclAttr] parameters(self)
def instantiate(self, str name, Dict[str, object] parameters={}, results=None, sym_name=None, loc=None, ip=None, **kwargs)
def __init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
def create(elements, Type result_type=None)
def create(str sym_name, Type type, str verilog_name=None)
def _create_output_op(cls_name, output_ports, entry_block, bb_ret)
def create_parameters(dict[str, Attribute] parameters, ModuleLike module)