5from __future__
import annotations
10from ..dialects._ods_common
import _cext
as _ods_cext
12from ._hw_ops_gen
import *
13from ._hw_ops_gen
import _Dialect
14from typing
import Dict, Type
19 mod_param_decls = module.parameters
20 mod_param_decls_idxs = {
21 decl.name: idx
for (idx, decl)
in enumerate(mod_param_decls)
23 inst_param_array = [
None] * len(module.parameters)
26 if isinstance(parameters, DictAttr):
27 parameters = {i.name: i.attr
for i
in parameters}
28 for (pname, pval)
in parameters.items():
29 if pname
not in mod_param_decls_idxs:
31 f
"Could not find parameter '{pname}' in module parameter decls")
32 idx = mod_param_decls_idxs[pname]
33 param_decl = mod_param_decls[idx]
34 inst_param_array[idx] = hw.ParamDeclAttr.get(pname, param_decl.param_type,
38 for (idx, pval)
in enumerate(inst_param_array):
41 inst_param_array[idx] = mod_param_decls[idx]
43 return inst_param_array
47 """Helper class to incrementally construct an instance of a module."""
60 instance_name = StringAttr.get(name)
61 module_name = FlatSymbolRefAttr.get(StringAttr(module.name).value)
64 inner_sym = hw.InnerSymAttr.get(StringAttr.get(sym_name))
67 pre_args = [instance_name, module_name]
69 ArrayAttr.get([StringAttr.get(x)
for x
in self.
operand_names()]),
70 ArrayAttr.get([StringAttr.get(x)
for x
in self.
result_names()]),
71 ArrayAttr.get(inst_param_array)
74 results = module.type.output_types
77 input_name_type_lookup = {
78 name: support.type_to_pytype(ty)
79 for name, ty
in zip(self.
operand_names(), module.type.input_types)
81 for input_name, input_value
in input_port_mapping.items():
82 if input_name
not in input_name_type_lookup:
84 mod_input_type = input_name_type_lookup[input_name]
85 if support.type_to_pytype(input_value.type) != mod_input_type:
86 raise TypeError(f
"Input '{input_name}' has type '{input_value.type}' "
87 f
"but expected '{mod_input_type}'")
94 needs_result_type=
True,
100 type = self.
module.type.input_types[index]
101 return support.BackedgeBuilder.create(type,
107 return self.
module.type.input_names
110 return self.
module.type.output_names
114 """Custom Python helper class for module-like operations."""
130 Create a module-like with the provided `name`, `input_ports`, and
132 - `name` is a string representing the module name.
133 - `input_ports` is a list of pairs of string names and mlir.ir types.
134 - `output_ports` is a list of pairs of string names and mlir.ir types.
135 - `body_builder` is an optional callback, when provided a new entry block
136 is created and the callback is invoked with the new op as argument within
137 an InsertionPoint context already set for the block. The callback is
138 expected to insert a terminator in the block.
141 input_ports = list(input_ports)
142 output_ports = list(output_ports)
143 parameters = list(parameters)
144 attributes = dict(attributes)
148 attributes[
"sym_name"] = StringAttr.get(str(name))
152 unknownLoc = Location.unknown().attr
153 for (i, (port_name, port_type))
in enumerate(input_ports):
154 input_name = StringAttr.get(str(port_name))
155 input_dir = hw.ModulePortDirection.INPUT
156 input_port = hw.ModulePort(input_name, port_type, input_dir)
157 module_ports.append(input_port)
158 input_names.append(input_name)
162 for (i, (port_name, port_type))
in enumerate(output_ports):
163 output_name = StringAttr.get(str(port_name))
164 output_dir = hw.ModulePortDirection.OUTPUT
165 output_port = hw.ModulePort(output_name, port_type, output_dir)
166 module_ports.append(output_port)
167 output_names.append(output_name)
168 attributes[
"per_port_attrs"] = ArrayAttr.get([])
170 if len(parameters) > 0
or "parameters" not in attributes:
171 attributes[
"parameters"] = ArrayAttr.get(parameters)
173 attributes[
"module_type"] = TypeAttr.get(hw.ModuleType.get(module_ports))
175 _ods_cext.ir.OpView.__init__(
177 op.build_generic(attributes=attributes,
184 entry_block = op.add_entry_block()
186 with InsertionPoint(entry_block):
187 with support.BackedgeBuilder(str(name)):
188 outputs = body_builder(op)
193 return hw.ModuleType(TypeAttr(op.attributes[
"module_type"]).value)
197 return op.attributes[
"sym_name"]
201 return len(op.regions[0].blocks) == 0
205 return [hw.ParamDeclAttr(a)
for a
in ArrayAttr(op.attributes[
"parameters"])]
210 parameters: Dict[str, object] = {},
219 parameters=parameters,
227 """Create the hw.OutputOp from the body_builder return."""
230 block_len = len(entry_block.operations)
232 last_op = entry_block.operations[block_len - 1]
233 if isinstance(last_op, hw.OutputOp):
235 if bb_ret
is not None and bb_ret != last_op:
236 raise support.ConnectionError(
237 f
"In {cls_name}, cannot return value from body_builder and "
238 "create hw.OutputOp")
244 if len(output_ports) == 0:
247 raise support.ConnectionError(
248 f
"In {cls_name}, must return module output values")
251 outputs: list[Value] = list()
254 if not isinstance(bb_ret, dict):
255 raise support.ConnectionError(
256 f
"In {cls_name}, can only return a dict of port, value mappings "
257 "from body_builder.")
261 unconnected_ports = []
262 for (name, port_type)
in output_ports:
263 if name
not in bb_ret:
264 unconnected_ports.append(name)
267 val = support.get_value(bb_ret[name])
269 field_type = type(bb_ret[name])
271 f
"In {cls_name}, body_builder return doesn't support type "
273 if val.type != port_type:
274 if isinstance(port_type, hw.TypeAliasType)
and \
275 port_type.inner_type == val.type:
279 f
"In {cls_name}, output port '{name}' type ({val.type}) doesn't "
280 f
"match declared type ({port_type})")
283 if len(unconnected_ports) > 0:
284 raise support.UnconnectedSignalError(cls_name, unconnected_ports)
286 raise support.ConnectionError(
287 f
"Could not map the following to output ports in {cls_name}: " +
288 ",".join(bb_ret.keys()))
293@_ods_cext.register_operation(_Dialect, replace=True)
295 """Specialization for the HW module op class."""
309 if "comment" not in attributes:
310 attributes[
"comment"] = StringAttr.get(
"")
311 ModuleLike.init(self,
315 parameters=parameters,
316 attributes=attributes,
317 body_builder=body_builder,
322 return ModuleLike.instantiate(self, *args, **kwargs)
326 return ModuleLike.type(self)
330 return ModuleLike.name(self)
334 return ModuleLike.is_external(self)
338 return ModuleLike.parameters(self)
342 return self.regions[0]
346 return self.regions[0].blocks[0]
350 indices: dict[int, str] = {}
351 op_names = self.
type.input_names
352 for idx, name
in enumerate(op_names):
361 raise AttributeError(f
"unknown input port name {name}")
370 result_names = self.
type.output_names
371 result_types = self.
type.output_types
372 return dict(zip(result_names, result_types))
376 raise IndexError(
'The module already has an entry block')
377 self.
body.blocks.append(*self.
type.input_types)
378 return self.
body.blocks[0]
381@_ods_cext.register_operation(_Dialect, replace=True)
383 """Specialization for the HW module op class."""
397 if "comment" not in attributes:
398 attributes[
"comment"] = StringAttr.get(
"")
399 ModuleLike.init(self,
403 parameters=parameters,
404 attributes=attributes,
405 body_builder=body_builder,
410 return ModuleLike.instantiate(self, *args, **kwargs)
414 return ModuleLike.type(self)
418 return ModuleLike.name(self)
422 return ModuleLike.is_external(self)
426 return ModuleLike.parameters(self)
429@_ods_cext.register_operation(_Dialect, replace=True)
437@_ods_cext.register_operation(_Dialect, replace=True)
442 value = support.get_value(value)
446@_ods_cext.register_operation(_Dialect, replace=True)
451 array_value = support.get_value(array_value)
452 array_type = support.get_self_or_inner(array_value.type)
453 if isinstance(idx, int):
454 idx_width = (array_type.size - 1).bit_length()
455 idx_val = ConstantOp.create(IntegerType.get_signless(idx_width),
458 idx_val = support.get_value(idx)
462@_ods_cext.register_operation(_Dialect, replace=True)
466 def create(array_value, low_index, ret_type):
467 array_value = support.get_value(array_value)
468 array_type = support.get_self_or_inner(array_value.type)
469 if isinstance(low_index, int):
470 idx_width = (array_type.size - 1).bit_length()
471 idx_width = max(1, idx_width)
472 idx_val = ConstantOp.create(IntegerType.get_signless(idx_width),
475 idx_val = support.get_value(low_index)
479@_ods_cext.register_operation(_Dialect, replace=True)
485 raise ValueError(
"Cannot 'create' an array of length zero")
488 for i, arg
in enumerate(elements):
489 arg_val = support.get_value(arg)
493 elif type != arg_val.type:
495 f
"Argument {i} has a different element type ({arg_val.type}) than the element type of the array ({type})"
500@_ods_cext.register_operation(_Dialect, replace=True)
508 for i, array
in enumerate(sub_arrays):
509 array_value = support.get_value(array)
510 array_type = support.type_to_pytype(array_value.type)
511 if array_value
is None or not isinstance(array_type, hw.ArrayType):
512 raise TypeError(f
"Cannot concatenate {array_value}")
513 if element_type
is None:
514 element_type = array_type.element_type
515 elif element_type != array_type.element_type:
517 f
"Argument {i} has a different element type ({element_type}) than the element type of the array ({array_type.element_type})"
520 vals.append(array_value)
521 types.append(array_type)
523 size = sum(t.size
for t
in types)
524 combined_type = hw.ArrayType.get(element_type, size)
528@_ods_cext.register_operation(_Dialect, replace=True)
532 def create(elements, result_type: Type =
None):
534 (name, support.get_value(value))
for (name, value)
in elements
536 struct_fields = [(name, value.type)
for (name, value)
in elem_name_values]
537 struct_type = hw.StructType.get(struct_fields)
539 if result_type
is None:
540 result_type = struct_type
542 result_type_inner = support.get_self_or_inner(result_type)
543 if result_type_inner != struct_type:
545 f
"result type:\n\t{result_type_inner}\nmust match generated struct type:\n\t{struct_type}"
549 [value
for (_, value)
in elem_name_values])
552@_ods_cext.register_operation(_Dialect, replace=True)
556 def create(struct_value, field_name: str):
557 struct_value = support.get_value(struct_value)
558 struct_type = support.get_self_or_inner(struct_value.type)
559 field_type = struct_type.get_field(field_name)
560 field_index = struct_type.get_field_index(field_name)
561 if field_index == UnitAttr.get():
563 f
"field '{field_name}' not element of struct type {struct_type}")
567@_ods_cext.register_operation(_Dialect, replace=True)
571 def create(sym_name: str, type: Type, verilog_name: str =
None):
574 verilogName=verilog_name)
577@_ods_cext.register_operation(_Dialect, replace=True)
583 op.regions[0].blocks.append()
588 return self.regions[0].blocks[0]
create(array_value, low_index, ret_type)
__init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
instantiate(self, *args, **kwargs)
list[ParamDeclAttr] parameters(self)
dict[str:Value] inputs(self)
instantiate(self, *args, **kwargs)
list[ParamDeclAttr] parameters(self)
dict[str:Type] outputs(self)
__init__(self, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
__init__(self, module, name, input_port_mapping, *results=None, parameters={}, sym_name=None, loc=None, ip=None)
create_default_value(self, index, data_type, arg_name)
list[ParamDeclAttr] parameters(op)
instantiate(op, str name, Dict[str, object] parameters={}, results=None, sym_name=None, loc=None, ip=None, **kwargs)
init(op, name, input_ports=[], output_ports=[], *parameters=[], attributes={}, body_builder=None, loc=None, ip=None)
create(elements, Type result_type=None)
create(str sym_name, Type type, str verilog_name=None)
_create_output_op(cls_name, output_ports, entry_block, bb_ret)
create_parameters(dict[str, Attribute] parameters, ModuleLike module)