5 from __future__
import annotations
10 from ..dialects._ods_common
import _cext
as _ods_cext
12 from ._hw_ops_gen
import *
13 from ._hw_ops_gen
import _Dialect
14 from typing
import Dict, Type
19 mod_param_decls = module.parameters
20 mod_param_decls_idxs = {
21 decl.name: idx
for (idx, decl)
in enumerate(mod_param_decls)
23 inst_param_array = [
None] * len(module.parameters)
26 if isinstance(parameters, DictAttr):
27 parameters = {i.name: i.attr
for i
in parameters}
28 for (pname, pval)
in parameters.items():
29 if pname
not in mod_param_decls_idxs:
31 f
"Could not find parameter '{pname}' in module parameter decls")
32 idx = mod_param_decls_idxs[pname]
33 param_decl = mod_param_decls[idx]
34 inst_param_array[idx] = hw.ParamDeclAttr.get(pname, param_decl.param_type,
38 for (idx, pval)
in enumerate(inst_param_array):
41 inst_param_array[idx] = mod_param_decls[idx]
43 return inst_param_array
47 """Helper class to incrementally construct an instance of a module."""
60 instance_name = StringAttr.get(name)
61 module_name = FlatSymbolRefAttr.get(StringAttr(module.name).value)
64 inner_sym = hw.InnerSymAttr.get(StringAttr.get(sym_name))
67 pre_args = [instance_name, module_name]
69 ArrayAttr.get([StringAttr.get(x)
for x
in self.
operand_namesoperand_names()]),
70 ArrayAttr.get([StringAttr.get(x)
for x
in self.
result_namesresult_names()]),
71 ArrayAttr.get(inst_param_array)
74 results = module.type.output_types
77 input_name_type_lookup = {
78 name: support.type_to_pytype(ty)
79 for name, ty
in zip(self.
operand_namesoperand_names(), module.type.input_types)
81 for input_name, input_value
in input_port_mapping.items():
82 if input_name
not in input_name_type_lookup:
84 mod_input_type = input_name_type_lookup[input_name]
85 if support.type_to_pytype(input_value.type) != mod_input_type:
86 raise TypeError(f
"Input '{input_name}' has type '{input_value.type}' "
87 f
"but expected '{mod_input_type}'")
94 needs_result_type=
True,
100 type = self.
modulemodule.type.input_types[index]
101 return support.BackedgeBuilder.create(type,
104 instance_of=self.
modulemodule)
107 return self.
modulemodule.type.input_names
110 return self.
modulemodule.type.output_names
114 """Custom Python base class for module-like operations."""
129 Create a module-like with the provided `name`, `input_ports`, and
131 - `name` is a string representing the module name.
132 - `input_ports` is a list of pairs of string names and mlir.ir types.
133 - `output_ports` is a list of pairs of string names and mlir.ir types.
134 - `body_builder` is an optional callback, when provided a new entry block
135 is created and the callback is invoked with the new op as argument within
136 an InsertionPoint context already set for the block. The callback is
137 expected to insert a terminator in the block.
140 input_ports = list(input_ports)
141 output_ports = list(output_ports)
142 parameters = list(parameters)
143 attributes = dict(attributes)
147 attributes[
"sym_name"] = StringAttr.get(str(name))
151 unknownLoc = Location.unknown().attr
152 for (i, (port_name, port_type))
in enumerate(input_ports):
153 input_name = StringAttr.get(str(port_name))
154 input_dir = hw.ModulePortDirection.INPUT
155 input_port = hw.ModulePort(input_name, port_type, input_dir)
156 module_ports.append(input_port)
157 input_names.append(input_name)
161 for (i, (port_name, port_type))
in enumerate(output_ports):
162 output_name = StringAttr.get(str(port_name))
163 output_dir = hw.ModulePortDirection.OUTPUT
164 output_port = hw.ModulePort(output_name, port_type, output_dir)
165 module_ports.append(output_port)
166 output_names.append(output_name)
167 attributes[
"per_port_attrs"] = ArrayAttr.get([])
169 if len(parameters) > 0
or "parameters" not in attributes:
170 attributes[
"parameters"] = ArrayAttr.get(parameters)
172 attributes[
"module_type"] = TypeAttr.get(hw.ModuleType.get(module_ports))
174 _ods_cext.ir.OpView.__init__(
176 self.build_generic(attributes=attributes,
183 entry_block = self.add_entry_block()
185 with InsertionPoint(entry_block):
186 with support.BackedgeBuilder(str(name)):
187 outputs = body_builder(self)
192 return hw.ModuleType(TypeAttr(self.attributes[
"module_type"]).value)
196 return self.attributes[
"sym_name"]
200 return len(self.regions[0].blocks) == 0
205 hw.ParamDeclAttr(a)
for a
in ArrayAttr(self.attributes[
"parameters"])
210 parameters: Dict[str, object] = {},
219 parameters=parameters,
227 """Create the hw.OutputOp from the body_builder return."""
230 block_len = len(entry_block.operations)
232 last_op = entry_block.operations[block_len - 1]
233 if isinstance(last_op, hw.OutputOp):
235 if bb_ret
is not None and bb_ret != last_op:
236 raise support.ConnectionError(
237 f
"In {cls_name}, cannot return value from body_builder and "
238 "create hw.OutputOp")
244 if len(output_ports) == 0:
247 raise support.ConnectionError(
248 f
"In {cls_name}, must return module output values")
251 outputs: list[Value] = list()
254 if not isinstance(bb_ret, dict):
255 raise support.ConnectionError(
256 f
"In {cls_name}, can only return a dict of port, value mappings "
257 "from body_builder.")
261 unconnected_ports = []
262 for (name, port_type)
in output_ports:
263 if name
not in bb_ret:
264 unconnected_ports.append(name)
267 val = support.get_value(bb_ret[name])
269 field_type = type(bb_ret[name])
271 f
"In {cls_name}, body_builder return doesn't support type "
273 if val.type != port_type:
274 if isinstance(port_type, hw.TypeAliasType)
and \
275 port_type.inner_type == val.type:
279 f
"In {cls_name}, output port '{name}' type ({val.type}) doesn't "
280 f
"match declared type ({port_type})")
283 if len(unconnected_ports) > 0:
284 raise support.UnconnectedSignalError(cls_name, unconnected_ports)
286 raise support.ConnectionError(
287 f
"Could not map the following to output ports in {cls_name}: " +
288 ",".join(bb_ret.keys()))
293 @_ods_cext.register_operation(_Dialect, replace=True)
295 """Specialization for the HW module op class."""
309 if "comment" not in attributes:
310 attributes[
"comment"] = StringAttr.get(
"")
314 parameters=parameters,
315 attributes=attributes,
316 body_builder=body_builder,
322 return self.regions[0]
326 return self.regions[0].blocks[0]
330 indices: dict[int, str] = {}
331 op_names = self.
typetype.input_names
332 for idx, name
in enumerate(op_names):
340 return self.
entry_blockentry_block.arguments[index]
341 raise AttributeError(f
"unknown input port name {name}")
346 ret[name] = self.
entry_blockentry_block.arguments[idx]
350 result_names = self.
typetype.output_names
351 result_types = self.
typetype.output_types
352 return dict(zip(result_names, result_types))
356 raise IndexError(
'The module already has an entry block')
357 self.
bodybody.blocks.append(*self.
typetype.input_types)
358 return self.
bodybody.blocks[0]
361 @_ods_cext.register_operation(_Dialect, replace=True)
363 """Specialization for the HW module op class."""
377 if "comment" not in attributes:
378 attributes[
"comment"] = StringAttr.get(
"")
382 parameters=parameters,
383 attributes=attributes,
384 body_builder=body_builder,
389 @_ods_cext.register_operation(_Dialect, replace=True)
397 @_ods_cext.register_operation(_Dialect, replace=True)
402 value = support.get_value(value)
406 @_ods_cext.register_operation(_Dialect, replace=True)
411 array_value = support.get_value(array_value)
412 array_type = support.get_self_or_inner(array_value.type)
413 if isinstance(idx, int):
414 idx_width = (array_type.size - 1).bit_length()
415 idx_val = ConstantOp.create(IntegerType.get_signless(idx_width),
418 idx_val = support.get_value(idx)
422 @_ods_cext.register_operation(_Dialect, replace=True)
426 def create(array_value, low_index, ret_type):
427 array_value = support.get_value(array_value)
428 array_type = support.get_self_or_inner(array_value.type)
429 if isinstance(low_index, int):
430 idx_width = (array_type.size - 1).bit_length()
431 idx_width = max(1, idx_width)
432 idx_val = ConstantOp.create(IntegerType.get_signless(idx_width),
435 idx_val = support.get_value(low_index)
439 @_ods_cext.register_operation(_Dialect, replace=True)
445 raise ValueError(
"Cannot 'create' an array of length zero")
448 for i, arg
in enumerate(elements):
449 arg_val = support.get_value(arg)
453 elif type != arg_val.type:
455 f
"Argument {i} has a different element type ({arg_val.type}) than the element type of the array ({type})"
460 @_ods_cext.register_operation(_Dialect, replace=True)
468 for i, array
in enumerate(sub_arrays):
469 array_value = support.get_value(array)
470 array_type = support.type_to_pytype(array_value.type)
471 if array_value
is None or not isinstance(array_type, hw.ArrayType):
472 raise TypeError(f
"Cannot concatenate {array_value}")
473 if element_type
is None:
474 element_type = array_type.element_type
475 elif element_type != array_type.element_type:
477 f
"Argument {i} has a different element type ({element_type}) than the element type of the array ({array_type.element_type})"
480 vals.append(array_value)
481 types.append(array_type)
483 size = sum(t.size
for t
in types)
484 combined_type = hw.ArrayType.get(element_type, size)
488 @_ods_cext.register_operation(_Dialect, replace=True)
492 def create(elements, result_type: Type =
None):
494 (name, support.get_value(value))
for (name, value)
in elements
496 struct_fields = [(name, value.type)
for (name, value)
in elem_name_values]
497 struct_type = hw.StructType.get(struct_fields)
499 if result_type
is None:
500 result_type = struct_type
502 result_type_inner = support.get_self_or_inner(result_type)
503 if result_type_inner != struct_type:
505 f
"result type:\n\t{result_type_inner}\nmust match generated struct type:\n\t{struct_type}"
509 [value
for (_, value)
in elem_name_values])
512 @_ods_cext.register_operation(_Dialect, replace=True)
516 def create(struct_value, field_name: str):
517 struct_value = support.get_value(struct_value)
518 struct_type = support.get_self_or_inner(struct_value.type)
519 field_type = struct_type.get_field(field_name)
520 field_index = struct_type.get_field_index(field_name)
521 if field_index == UnitAttr.get():
523 f
"field '{field_name}' not element of struct type {struct_type}")
527 @_ods_cext.register_operation(_Dialect, replace=True)
531 def create(sym_name: str, type: Type, verilog_name: str =
None):
534 verilogName=verilog_name)
537 @_ods_cext.register_operation(_Dialect, replace=True)
543 op.regions[0].blocks.append()
548 return self.regions[0].blocks[0]
def create(array_value, idx)
def create(array_value, low_index, ret_type)
def create(data_type, value)
def create(data_type, value)
def __init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
def add_entry_block(self)
dict[str:Value] inputs(self)
dict[str:Type] outputs(self)
def __init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
def __getattr__(self, name)
def create_default_value(self, index, data_type, arg_name)
def __init__(self, module, name, input_port_mapping, *results=None, parameters={}, sym_name=None, loc=None, ip=None)
list[ParamDeclAttr] parameters(self)
def instantiate(self, str name, Dict[str, object] parameters={}, results=None, sym_name=None, loc=None, ip=None, **kwargs)
def __init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
def create(elements, Type result_type=None)
def create(str sym_name, Type type, str verilog_name=None)
def _create_output_op(cls_name, output_ports, entry_block, bb_ret)
def create_parameters(dict[str, Attribute] parameters, ModuleLike module)