23#include "mlir/IR/Builders.h"
24#include "mlir/IR/Matchers.h"
25#include "mlir/IR/PatternMatch.h"
26#include "mlir/Interfaces/FunctionImplementation.h"
27#include "llvm/ADT/BitVector.h"
28#include "llvm/ADT/SmallPtrSet.h"
29#include "llvm/ADT/StringSet.h"
38 case ModulePort::Direction::Input:
39 return ModulePort::Direction::Output;
40 case ModulePort::Direction::Output:
41 return ModulePort::Direction::Input;
42 case ModulePort::Direction::InOut:
43 return ModulePort::Direction::InOut;
45 llvm_unreachable(
"unknown PortDirection");
48bool hw::isValidIndexBitWidth(Value index, Value array) {
49 hw::ArrayType arrayType =
50 dyn_cast<hw::ArrayType>(hw::getCanonicalType(array.getType()));
51 assert(arrayType &&
"expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
59bool hw::isCombinational(Operation *op) {
60 struct IsCombClassifier :
public TypeOpVisitor<IsCombClassifier, bool> {
65 return (op->getDialect() && op->getDialect()->getNamespace() ==
"comb") ||
71 if (
auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
76 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
79 return structInject.getNewValue();
85 ArrayRef<Attribute> attrs) {
87 return ArrayAttr::get(context, {});
90 if (a && !cast<DictionaryAttr>(a).empty()) {
95 return ArrayAttr::get(context, {});
96 return ArrayAttr::get(context, attrs);
102 OpAsmSetValueNameFn setNameFn) {
106 auto module = cast<HWModuleOp>(region.getParentOp());
108 auto *block = ®ion.front();
109 for (
size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name =
module.getInputName(i);
112 setNameFn(block->getArgument(i), name);
128LogicalResult hw::checkParameterInContext(
129 Attribute value, ArrayAttr moduleParameters,
131 bool disallowParamRefs) {
134 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
139 if (
auto expr = dyn_cast<ParamExprAttr>(value)) {
140 for (
auto op : expr.getOperands())
149 if (
auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150 auto nameAttr = parameterRef.getName();
154 if (disallowParamRefs) {
155 instanceError([&](
auto &diag) {
156 diag <<
"parameter " << nameAttr
157 <<
" cannot be used as a default value for a parameter";
164 for (
auto param : moduleParameters) {
165 auto paramAttr = cast<ParamDeclAttr>(param);
166 if (paramAttr.getName() != nameAttr)
170 if (paramAttr.getType() == parameterRef.getType())
173 instanceError([&](
auto &diag) {
174 diag <<
"parameter " << nameAttr <<
" used with type "
175 << parameterRef.getType() <<
"; should have type "
176 << paramAttr.getType();
182 instanceError([&](
auto &diag) {
183 diag <<
"use of unknown parameter " << nameAttr;
189 instanceError([&](
auto &diag) {
190 diag <<
"invalid parameter value " << value;
202LogicalResult hw::checkParameterInContext(Attribute value, Operation *module,
204 bool disallowParamRefs) {
206 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
208 auto diag = usingOp->emitOpError();
210 diag.attachNote(module->getLoc()) <<
"module declared here";
215 module->getAttrOfType<ArrayAttr>(
"parameters"),
216 emitError, disallowParamRefs);
221bool hw::isValidParameterExpression(Attribute attr, Operation *module) {
230 for (
auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
263#include "circt/Dialect/HW/HWCanonicalization.cpp.inc"
270void ConstantOp::print(OpAsmPrinter &p) {
272 p.printAttribute(getValueAttr());
273 p.printOptionalAttrDict((*this)->getAttrs(), {
"value"});
276ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
277 IntegerAttr valueAttr;
279 if (parser.parseAttribute(valueAttr,
"value", result.attributes) ||
280 parser.parseOptionalAttrDict(result.attributes))
283 result.addTypes(valueAttr.getType());
287LogicalResult ConstantOp::verify() {
291 "hw.constant attribute bitwidth doesn't match return type");
298void ConstantOp::build(OpBuilder &builder, OperationState &result,
299 const APInt &value) {
301 auto type = IntegerType::get(builder.getContext(), value.getBitWidth());
302 auto attr = builder.getIntegerAttr(type, value);
303 return build(builder, result, type, attr);
308void ConstantOp::build(OpBuilder &builder, OperationState &result,
310 return build(builder, result, value.getType(), value);
317void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
319 auto numBits = cast<IntegerType>(type).getWidth();
320 build(builder, result,
321 APInt(numBits, (uint64_t)value,
true,
325void ConstantOp::getAsmResultNames(
326 function_ref<
void(Value, StringRef)> setNameFn) {
327 auto intTy = getType();
328 auto intCst = getValue();
331 if (cast<IntegerType>(intTy).
getWidth() == 1)
332 return setNameFn(getResult(), intCst.isZero() ?
"false" :
"true");
335 SmallVector<char, 32> specialNameBuffer;
336 llvm::raw_svector_ostream specialName(specialNameBuffer);
337 specialName <<
'c' << intCst <<
'_' << intTy;
338 setNameFn(getResult(), specialName.str());
341OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
342 assert(adaptor.getOperands().empty() &&
"constant has no operands");
343 return getValueAttr();
354 ArrayRef<StringRef> ignoredAttrs = {}) {
355 auto names = op.getAttributeNames();
356 llvm::SmallDenseSet<StringRef> nameSet;
357 nameSet.reserve(names.size() + ignoredAttrs.size());
358 nameSet.insert(names.begin(), names.end());
359 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
360 return llvm::any_of(op->getAttrs(), [&](
auto namedAttr) {
361 return !nameSet.contains(namedAttr.getName());
365void WireOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
367 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
"name");
368 if (nameAttr && !nameAttr.getValue().empty())
369 setNameFn(getResult(), nameAttr.getValue());
372std::optional<size_t> WireOp::getTargetResultIndex() {
return 0; }
374OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
383LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
389 if (wire.getInnerSymAttr())
394 if (
auto *inputOp = wire.getInput().getDefiningOp())
396 rewriter.modifyOpInPlace(inputOp,
397 [&] { inputOp->setAttr(
"sv.namehint", name); });
399 rewriter.replaceOp(wire, wire.getInput());
409 if (
auto typeAlias = dyn_cast<TypeAliasType>(type))
410 type = typeAlias.getCanonicalType();
412 if (
auto structType = dyn_cast<StructType>(type)) {
413 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
415 return op->emitOpError(
"expected array attribute for constant of type ")
417 if (structType.getElements().size() != arrayAttr.size())
418 return op->emitOpError(
"array attribute (")
419 << arrayAttr.size() <<
") has wrong size for struct constant ("
420 << structType.getElements().size() <<
")";
422 for (
auto [attr, fieldInfo] :
423 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
427 }
else if (
auto arrayType = dyn_cast<ArrayType>(type)) {
428 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
430 return op->emitOpError(
"expected array attribute for constant of type ")
432 if (arrayType.getNumElements() != arrayAttr.size())
433 return op->emitOpError(
"array attribute (")
434 << arrayAttr.size() <<
") has wrong size for array constant ("
435 << arrayType.getNumElements() <<
")";
438 for (
auto attr : arrayAttr.getValue()) {
442 }
else if (
auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
443 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
445 return op->emitOpError(
"expected array attribute for constant of type ")
448 if (arrayType.getNumElements() != arrayAttr.size())
449 return op->emitOpError(
"array attribute (")
451 <<
") has wrong size for unpacked array constant ("
452 << arrayType.getNumElements() <<
")";
454 for (
auto attr : arrayAttr.getValue()) {
458 }
else if (
auto enumType = dyn_cast<EnumType>(type)) {
459 auto stringAttr = dyn_cast<StringAttr>(attr);
461 return op->emitOpError(
"expected string attribute for constant of type ")
463 }
else if (
auto intType = dyn_cast<IntegerType>(type)) {
465 auto intAttr = dyn_cast<IntegerAttr>(attr);
467 return op->emitOpError(
"expected integer attribute for constant of type ")
470 if (intAttr.getValue().getBitWidth() != intType.getWidth())
471 return op->emitOpError(
"hw.constant attribute bitwidth "
472 "doesn't match return type");
473 }
else if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
474 if (typedAttr.getType() != type)
475 return op->emitOpError(
"typed attr doesn't match the return type ")
478 return op->emitOpError(
"unknown element type ") << type;
483LogicalResult AggregateConstantOp::verify() {
487OpFoldResult AggregateConstantOp::fold(FoldAdaptor) {
return getFieldsAttr(); }
495 if (p.parseType(resultType) || p.parseEqual() ||
496 p.parseAttribute(value, resultType))
503 p << resultType <<
" = ";
504 p.printAttributeWithoutType(value);
507LogicalResult ParamValueOp::verify() {
513OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
514 assert(adaptor.getOperands().empty() &&
"hw.param.value has no operands");
515 return getValueAttr();
524 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
530 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
531 .Case<InstanceOp, InstanceChoiceOp>([](
auto instance) {
532 SmallVector<Type> inputs(instance->getOperandTypes());
533 SmallVector<Type> results(instance->getResultTypes());
534 return FunctionType::get(instance->getContext(), inputs, results);
537 [](
auto mod) {
return mod.getHWModuleType().getFuncType(); })
538 .Default([](Operation *op) {
539 return cast<FunctionType>(
540 cast<mlir::FunctionOpInterface>(op).getFunctionType());
548 auto nameAttr =
module->getAttrOfType<StringAttr>("verilogName");
552 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
555template <
typename ModuleTy>
557buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
559 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
560 using namespace mlir::function_interface_impl;
563 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
565 SmallVector<Attribute> perPortAttrs;
566 SmallVector<ModulePort> portTypes;
568 for (
auto elt : ports) {
569 portTypes.push_back(elt);
570 llvm::SmallVector<NamedAttribute> portAttrs;
572 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
573 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
578 parameters = builder.getArrayAttr({});
581 auto type = ModuleType::get(builder.getContext(), portTypes);
582 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
583 TypeAttr::get(type));
584 result.addAttribute(
"per_port_attrs",
586 result.addAttribute(
"parameters", parameters);
588 comment = builder.getStringAttr(
"");
589 result.addAttribute(
"comment", comment);
590 result.addAttributes(attributes);
596 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
597 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
598 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
599 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
600 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
601 SmallVector<Location> &newArgLocs, Block *body =
nullptr) {
606 assert(llvm::is_sorted(insertArgs,
607 [](
auto &a,
auto &b) {
return a.first < b.first; }) &&
608 "insertArgs must be in ascending order");
609 assert(llvm::is_sorted(removeArgs, [](
auto &a,
auto &b) {
return a < b; }) &&
610 "removeArgs must be in ascending order");
613 auto oldArgCount = oldArgTypes.size();
614 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
615 assert((
int)newArgCount >= 0);
617 newArgNames.reserve(newArgCount);
618 newArgTypes.reserve(newArgCount);
619 newArgAttrs.reserve(newArgCount);
620 newArgLocs.reserve(newArgCount);
622 auto exportPortAttrName = StringAttr::get(context,
"hw.exportPort");
623 auto emptyDictAttr = DictionaryAttr::get(context, {});
624 auto unknownLoc = UnknownLoc::get(context);
626 BitVector erasedIndices;
628 erasedIndices.resize(oldArgCount + insertArgs.size());
630 for (
unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
632 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
633 auto port = insertArgs[0].second;
635 !isa<InOutType>(port.type))
636 port.type = InOutType::get(port.type);
637 auto sym = port.getSym();
639 (sym && !sym.empty())
640 ? DictionaryAttr::get(context, {{exportPortAttrName, sym}})
642 newArgNames.push_back(port.name);
643 newArgTypes.push_back(port.type);
644 newArgAttrs.push_back(attr);
645 insertArgs = insertArgs.drop_front();
646 LocationAttr loc = port.loc ? port.loc : unknownLoc;
647 newArgLocs.push_back(loc);
649 body->insertArgument(idx++, port.type, loc);
651 if (argIdx == oldArgCount)
655 bool removed =
false;
656 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
657 removeArgs = removeArgs.drop_front();
663 erasedIndices.set(idx);
665 newArgNames.push_back(oldArgNames[argIdx]);
666 newArgTypes.push_back(oldArgTypes[argIdx]);
667 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
668 : oldArgAttrs[argIdx]);
669 newArgLocs.push_back(oldArgLocs[argIdx]);
674 body->eraseArguments(erasedIndices);
676 assert(newArgNames.size() == newArgCount);
677 assert(newArgTypes.size() == newArgCount);
678 assert(newArgAttrs.size() == newArgCount);
679 assert(newArgLocs.size() == newArgCount);
693[[deprecated]]
static void
695 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
696 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
697 ArrayRef<unsigned> removeInputs,
698 ArrayRef<unsigned> removeOutputs, Block *body =
nullptr) {
699 auto moduleOp = cast<HWModuleLike>(op);
700 auto *context = moduleOp.getContext();
703 auto oldArgNames = moduleOp.getInputNames();
704 auto oldArgTypes = moduleOp.getInputTypes();
705 auto oldArgAttrs = moduleOp.getAllInputAttrs();
706 auto oldArgLocs = moduleOp.getInputLocs();
708 auto oldResultNames = moduleOp.getOutputNames();
709 auto oldResultTypes = moduleOp.getOutputTypes();
710 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
711 auto oldResultLocs = moduleOp.getOutputLocs();
714 SmallVector<Attribute> newArgNames, newResultNames;
715 SmallVector<Type> newArgTypes, newResultTypes;
716 SmallVector<Attribute> newArgAttrs, newResultAttrs;
717 SmallVector<Location> newArgLocs, newResultLocs;
720 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
721 newArgTypes, newArgAttrs, newArgLocs, body);
724 oldResultTypes, oldResultAttrs, oldResultLocs,
725 newResultNames, newResultTypes, newResultAttrs,
729 auto fnty = FunctionType::get(context, newArgTypes, newResultTypes);
731 moduleOp.setHWModuleType(modty);
732 moduleOp.setAllInputAttrs(newArgAttrs);
733 moduleOp.setAllOutputAttrs(newResultAttrs);
735 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
736 moduleOp.setAllPortLocs(newArgLocs);
739void HWModuleOp::build(OpBuilder &builder, OperationState &result,
741 ArrayAttr parameters,
742 ArrayRef<NamedAttribute> attributes, StringAttr comment,
743 bool shouldEnsureTerminator) {
744 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
748 auto *bodyRegion = result.regions[0].get();
750 bodyRegion->push_back(body);
753 auto unknownLoc = builder.getUnknownLoc();
754 for (
auto port : ports.getInputs()) {
755 auto loc = port.loc ? Location(port.loc) : unknownLoc;
756 auto type = port.type;
757 if (port.isInOut() && !isa<InOutType>(type))
758 type = InOutType::get(type);
759 body->addArgument(type, loc);
763 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
764 SmallVector<Attribute> resultLocs;
765 for (
auto port : ports.getOutputs())
766 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
767 result.addAttribute(
"result_locs", builder.getArrayAttr(resultLocs));
769 if (shouldEnsureTerminator)
770 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
773void HWModuleOp::build(OpBuilder &builder, OperationState &result,
774 StringAttr name, ArrayRef<PortInfo> ports,
775 ArrayAttr parameters,
776 ArrayRef<NamedAttribute> attributes,
777 StringAttr comment) {
778 build(builder, result, name,
ModulePortInfo(ports), parameters, attributes,
782void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
785 ArrayRef<NamedAttribute> attributes,
786 StringAttr comment) {
787 build(builder, odsState, name, ports, parameters, attributes, comment,
789 auto *bodyRegion = odsState.regions[0].get();
790 OpBuilder::InsertionGuard guard(builder);
792 builder.setInsertionPointToEnd(&bodyRegion->front());
793 modBuilder(builder, accessor);
795 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
796 hw::OutputOp::create(builder, odsState.location, outputOperands);
799void HWModuleOp::modifyPorts(
800 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
801 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
802 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
810StringAttr HWModuleExternOp::getVerilogModuleNameAttr() {
811 if (
auto vName = getVerilogNameAttr())
814 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
817StringAttr HWModuleGeneratedOp::getVerilogModuleNameAttr() {
818 if (
auto vName = getVerilogNameAttr()) {
821 return (*this)->getAttrOfType<StringAttr>(
822 ::mlir::SymbolTable::getSymbolAttrName());
825void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
827 StringRef verilogName, ArrayAttr parameters,
828 ArrayRef<NamedAttribute> attributes) {
829 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
833 LocationAttr unknownLoc = builder.getUnknownLoc();
834 SmallVector<Attribute> portLocs;
835 for (
auto elt : ports)
836 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
837 result.addAttribute(
"port_locs", builder.getArrayAttr(portLocs));
839 if (!verilogName.empty())
840 result.addAttribute(
"verilogName", builder.getStringAttr(verilogName));
843void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
844 StringAttr name, ArrayRef<PortInfo> ports,
845 StringRef verilogName, ArrayAttr parameters,
846 ArrayRef<NamedAttribute> attributes) {
847 build(builder, result, name,
ModulePortInfo(ports), verilogName, parameters,
851void HWModuleExternOp::modifyPorts(
852 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
853 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
854 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
859void HWModuleExternOp::appendOutputs(
860 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
862void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
863 FlatSymbolRefAttr genKind, StringAttr name,
865 StringRef verilogName, ArrayAttr parameters,
866 ArrayRef<NamedAttribute> attributes) {
867 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
870 LocationAttr unknownLoc = builder.getUnknownLoc();
871 SmallVector<Attribute> portLocs;
872 for (
auto elt : ports)
873 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
874 result.addAttribute(
"port_locs", builder.getArrayAttr(portLocs));
876 result.addAttribute(
"generatorKind", genKind);
877 if (!verilogName.empty())
878 result.addAttribute(
"verilogName", builder.getStringAttr(verilogName));
881void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
882 FlatSymbolRefAttr genKind, StringAttr name,
883 ArrayRef<PortInfo> ports, StringRef verilogName,
884 ArrayAttr parameters,
885 ArrayRef<NamedAttribute> attributes) {
886 build(builder, result, genKind, name,
ModulePortInfo(ports), verilogName,
887 parameters, attributes);
890void HWModuleGeneratedOp::modifyPorts(
891 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
892 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
893 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
898void HWModuleGeneratedOp::appendOutputs(
899 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
901static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
902 for (
auto &argAttr : attrs)
903 if (argAttr.getName() == name)
908template <
typename ModuleTy>
910 OperationState &result) {
912 using namespace mlir::function_interface_impl;
913 auto builder = parser.getBuilder();
914 auto loc = parser.getCurrentLocation();
917 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
921 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
926 FlatSymbolRefAttr kindAttr;
927 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
928 if (parser.parseComma() ||
929 parser.parseAttribute(kindAttr,
"generatorKind", result.attributes)) {
935 ArrayAttr parameters;
939 SmallVector<module_like_impl::PortParse> ports;
945 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
949 parser.emitError(loc,
"explicit `parameters` attributes not allowed");
953 result.addAttribute(
"parameters", parameters);
954 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
958 SmallVector<Attribute> attrs;
959 for (
auto &port : ports)
960 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
962 auto nonEmptyAttrsFn = [](Attribute attr) {
963 return attr && !cast<DictionaryAttr>(attr).empty();
965 if (llvm::any_of(attrs, nonEmptyAttrsFn))
966 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
967 builder.getArrayAttr(attrs));
970 auto unknownLoc = builder.getUnknownLoc();
971 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
972 return attr && cast<Location>(attr) != unknownLoc;
974 SmallVector<Attribute> locs;
975 StringAttr portLocsAttrName;
976 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
979 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
980 for (
auto &port : ports)
982 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
985 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
986 for (
auto &port : ports)
987 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
989 if (llvm::any_of(locs, nonEmptyLocsFn))
990 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
993 SmallVector<OpAsmParser::Argument, 4> entryArgs;
994 for (
auto &port : ports)
996 entryArgs.push_back(port);
999 auto *body = result.addRegion();
1000 if (std::is_same_v<ModuleTy, HWModuleOp>) {
1001 if (parser.parseRegion(*body, entryArgs))
1004 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
1009ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1010 return parseHWModuleOp<HWModuleOp>(parser, result);
1013ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1014 OperationState &result) {
1015 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1018ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1019 OperationState &result) {
1020 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1024 if (
auto mod = dyn_cast<HWModuleLike>(op))
1025 return mod.getHWModuleType().getFuncType();
1026 return cast<FunctionType>(
1027 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1030template <
typename ModuleTy>
1034 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1035 if (
auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1036 visibilityAttrName))
1037 p << visibility.getValue() <<
' ';
1040 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1041 if (
auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1043 p.printSymbolName(gen.getGeneratorKind());
1051 SmallVector<StringRef, 3> omittedAttrs;
1052 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1053 omittedAttrs.push_back(
"generatorKind");
1054 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1055 omittedAttrs.push_back(mod.getResultLocsAttrName());
1057 omittedAttrs.push_back(mod.getPortLocsAttrName());
1058 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1059 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1060 omittedAttrs.push_back(mod.getParametersAttrName());
1061 omittedAttrs.push_back(visibilityAttrName);
1063 mod.getOperation()->template getAttrOfType<StringAttr>(
"comment"))
1064 if (cmt.getValue().empty())
1065 omittedAttrs.push_back(
"comment");
1067 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1071void HWModuleExternOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1072void HWModuleGeneratedOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1074void HWModuleOp::print(OpAsmPrinter &p) {
1078 Region &body = getBody();
1079 if (!body.empty()) {
1081 p.printRegion(body,
false,
1087 assert(isa<HWModuleLike>(module) &&
1088 "verifier hook should only be called on modules");
1090 SmallPtrSet<Attribute, 4> paramNames;
1093 for (
auto param :
module->getAttrOfType<ArrayAttr>("parameters")) {
1094 auto paramAttr = cast<ParamDeclAttr>(param);
1098 if (!paramNames.insert(paramAttr.getName()).second)
1099 return module->emitOpError("parameter ")
1100 << paramAttr << " has the same name as a previous parameter";
1103 auto value = paramAttr.getValue();
1107 auto typedValue = dyn_cast<TypedAttr>(value);
1109 return module->emitOpError("parameter ")
1110 << paramAttr << " should have a typed value; has value
" << value;
1112 if (typedValue.getType() != paramAttr.getType())
1113 return module->emitOpError("parameter
")
1114 << paramAttr << " should have type
" << paramAttr.getType()
1115 << "; has type
" << typedValue.getType();
1117 // Verify that this is a valid parameter value, disallowing parameter
1118 // references. We could allow parameters to refer to each other in the
1119 // future with lexical ordering if there is a need.
1120 if (failed(checkParameterInContext(value, module, module,
1121 /*disallowParamRefs=*/true)))
1127LogicalResult HWModuleOp::verify() {
1128 if (failed(verifyModuleCommon(*this)))
1131 auto type = getModuleType();
1132 auto *body = getBodyBlock();
1134 // Verify the number of block arguments.
1135 auto numInputs = type.getNumInputs();
1136 if (body->getNumArguments() != numInputs)
1137 return emitOpError("entry block must have
")
1138 << numInputs << " arguments to match
module signature";
1145std::pair<StringAttr, BlockArgument>
1146HWModuleOp::insertInput(
unsigned index, StringAttr name, Type ty) {
1150 for (
auto port : ports)
1151 ns.newName(port.name.getValue());
1152 auto nameAttr = StringAttr::get(getContext(), ns.
newName(name.getValue()));
1158 port.
name = nameAttr;
1165 return {nameAttr, body->getArgument(index)};
1168void HWModuleOp::insertOutputs(
unsigned index,
1169 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1171 auto output = cast<OutputOp>(
getBodyBlock()->getTerminator());
1172 assert(index <= output->getNumOperands() &&
"invalid output index");
1175 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1176 for (
auto &[name, value] : outputs) {
1180 port.
type = value.getType();
1181 indexedNewPorts.emplace_back(index, port);
1187 for (
auto &[name, value] : outputs)
1188 output->insertOperands(index++, value);
1191void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1192 return insertOutputs(getNumOutputPorts(), outputs);
1195void HWModuleOp::getAsmBlockArgumentNames(mlir::Region ®ion,
1200void HWModuleExternOp::getAsmBlockArgumentNames(
1205template <
typename ModTy>
1207 auto locs =
module.getPortLocs();
1209 SmallVector<Location> retval;
1210 retval.reserve(locs->size());
1211 for (
auto l : *locs)
1212 retval.push_back(cast<Location>(l));
1214 assert(!locs->size() || locs->size() == module.getNumPorts());
1217 return SmallVector<Location>(module.getNumPorts(),
1218 UnknownLoc::get(module.getContext()));
1221SmallVector<Location> HWModuleOp::getAllPortLocs() {
1222 SmallVector<Location> portLocs;
1224 auto resultLocs = getResultLocsAttr();
1225 unsigned inputCount = 0;
1227 auto unknownLoc = UnknownLoc::get(getContext());
1229 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1230 if (modType.isOutput(i)) {
1231 auto loc = resultLocs
1233 resultLocs.getValue()[portLocs.size() - inputCount])
1235 portLocs.push_back(loc);
1237 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1238 portLocs.push_back(loc);
1245SmallVector<Location> HWModuleExternOp::getAllPortLocs() {
1246 return ::getAllPortLocs(*
this);
1249SmallVector<Location> HWModuleGeneratedOp::getAllPortLocs() {
1250 return ::getAllPortLocs(*
this);
1253void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1254 SmallVector<Attribute> resultLocs;
1255 unsigned inputCount = 0;
1258 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1259 if (modType.isOutput(i))
1260 resultLocs.push_back(locs[i]);
1262 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1264 setResultLocsAttr(ArrayAttr::get(getContext(), resultLocs));
1267void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1268 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1271void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1272 setPortLocsAttr(ArrayAttr::get(getContext(), locs));
1275template <
typename ModTy>
1277 auto numInputs =
module.getNumInputPorts();
1278 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1279 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1280 auto oldType =
module.getModuleType();
1281 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1282 oldType.getPorts().end());
1283 for (
size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1284 newPorts[i].name = cast<StringAttr>(names[i]);
1285 auto newType = ModuleType::get(module.getContext(), newPorts);
1286 module.setModuleType(newType);
1289void HWModuleOp::setAllPortNames(ArrayRef<Attribute> names) {
1293void HWModuleExternOp::setAllPortNames(ArrayRef<Attribute> names) {
1297void HWModuleGeneratedOp::setAllPortNames(ArrayRef<Attribute> names) {
1301ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1302 auto attrs = getPerPortAttrs();
1303 if (attrs && !attrs->empty())
1304 return attrs->getValue();
1308ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1309 auto attrs = getPerPortAttrs();
1310 if (attrs && !attrs->empty())
1311 return attrs->getValue();
1315ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1316 auto attrs = getPerPortAttrs();
1317 if (attrs && !attrs->empty())
1318 return attrs->getValue();
1322void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1323 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1326void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1327 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1330void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1331 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1334void HWModuleOp::removeAllPortAttrs() {
1335 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1338void HWModuleExternOp::removeAllPortAttrs() {
1339 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1342void HWModuleGeneratedOp::removeAllPortAttrs() {
1343 setPerPortAttrsAttr(ArrayAttr::get(getContext(), {}));
1348template <
typename ModTy>
1350 auto argAttrs = mod.getAllInputAttrs();
1351 auto resAttrs = mod.getAllOutputAttrs();
1352 mod.setModuleTypeAttr(TypeAttr::get(type));
1353 unsigned newNumArgs = type.getNumInputs();
1354 unsigned newNumResults = type.getNumOutputs();
1356 auto emptyDict = DictionaryAttr::get(mod.getContext());
1357 argAttrs.resize(newNumArgs, emptyDict);
1358 resAttrs.resize(newNumResults, emptyDict);
1360 SmallVector<Attribute> attrs;
1361 attrs.append(argAttrs.begin(), argAttrs.end());
1362 attrs.append(resAttrs.begin(), resAttrs.end());
1365 return mod.removeAllPortAttrs();
1366 mod.setAllPortAttrs(attrs);
1369void HWModuleOp::setHWModuleType(ModuleType type) {
1370 return ::setHWModuleType(*
this, type);
1373void HWModuleExternOp::setHWModuleType(ModuleType type) {
1374 return ::setHWModuleType(*
this, type);
1377void HWModuleGeneratedOp::setHWModuleType(ModuleType type) {
1378 return ::setHWModuleType(*
this, type);
1383Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1384 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1385 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
// Verifies that the generator-kind symbol resolves to an HWGeneratorSchemaOp
// and that this op carries every attribute the schema lists as required.
1389HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1390 auto *referencedKind =
1391 symbolTable.lookupNearestSymbolFrom(*
this, getGeneratorKindAttr());
1393 if (referencedKind ==
nullptr)
1394 return emitError(
"Cannot find generator definition '")
1395 << getGeneratorKind() <<
"'";
1397 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1398 return emitError(
"Symbol resolved to '")
1399 << referencedKind->getName()
1400 <<
"' which is not a HWGeneratorSchemaOp";
// The isa check above guarantees this dyn_cast succeeds.
1402 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1403 auto paramRef = referencedKindOp.getRequiredAttrs();
1404 auto dict = (*this)->getAttrDictionary();
// Each required entry must be a StringAttr naming an attribute on this op.
// NOTE(review): line 1407 (presumably the `!strAttr` guard) is missing here.
1405 for (
auto str : paramRef) {
1406 auto strAttr = dyn_cast<StringAttr>(str);
1408 return emitError(
"Unknown attribute type, expected a string");
1409 if (!dict.get(strAttr.getValue()))
1410 return emitError(
"Missing attribute '") << strAttr.getValue() <<
"'";
// Op verifier for HWModuleGeneratedOp; its body is not present in this listing.
1416LogicalResult HWModuleGeneratedOp::verify() {
// Asm hook for naming block arguments; the rest of this definition is not
// present in this listing.
1420void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1425LogicalResult HWModuleOp::verifyBody() {
return success(); }
// Shared helper that builds a PortInfo for every port of a module-like op from
// the module type, per-port locations and per-port attributes.
1427template <
typename ModuleTy>
1429 auto modTy = mod.getHWModuleType();
// NOTE(review): lines 1437-1438 are missing; emptyDict is presumably the
// fallback when a port has no attribute dictionary — confirm.
1430 auto emptyDict = DictionaryAttr::get(mod.getContext());
1431 SmallVector<PortInfo> retval;
1432 auto locs = mod.getAllPortLocs();
1433 for (
unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1434 LocationAttr loc = locs[i];
1435 DictionaryAttr attrs =
1436 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
// The second field is the port's index within its own direction (output
// index for outputs, input index otherwise).
1439 retval.push_back({modTy.getPorts()[i],
1440 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1441 : modTy.getInputIdForPortId(i),
// Shared helper that builds the PortInfo for the single port at `idx`.
1447template <
typename ModuleTy>
1449 auto modTy = mod.getHWModuleType();
// NOTE(review): lines 1454-1455 are missing; emptyDict is presumably the
// fallback for a port without attributes — confirm.
1450 auto emptyDict = DictionaryAttr::get(mod.getContext());
1451 LocationAttr loc = mod.getPortLoc(idx);
1452 DictionaryAttr attrs =
1453 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
// The second field is the port's index within its own direction.
1456 return {modTy.getPorts()[idx],
1457 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1458 : modTy.getInputIdForPortId(idx),
1467void InstanceOp::build(OpBuilder &builder, OperationState &result,
1468 Operation *module, StringAttr name,
1469 ArrayRef<Value> inputs, ArrayAttr parameters,
1470 InnerSymAttr innerSym) {
1472 parameters = builder.getArrayAttr({});
1474 auto mod = cast<hw::HWModuleLike>(module);
1475 auto argNames = builder.getArrayAttr(mod.getInputNames());
1476 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1481 ModuleType modType = mod.getHWModuleType();
1482 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1483 parameters, result.location,
false);
1484 if (succeeded(resolvedModType))
1485 modType = *resolvedModType;
1486 FunctionType funcType = resolvedModType->getFuncType();
1487 build(builder, result, funcType.getResults(), name,
1488 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), inputs,
1489 argNames, resultNames, parameters, innerSym, {});
// hw.instance reports no target result index.
1492std::optional<size_t> InstanceOp::getTargetResultIndex() {
1494 return std::nullopt;
// Checks the referenced module against this instance's operands, result
// types, port names and parameters (the callee on line 1498 is not shown).
1497LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1499 *
this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1500 getResultNames(), getParameters(), symbolTable);
// Verifies this instance's parameters in the context of the enclosing
// hw.module's own "parameters" attribute.
1503LogicalResult InstanceOp::verify() {
1504 auto module = (*this)->getParentOfType<HWModuleOp>();
1508 auto moduleParameters =
module->getAttrOfType<ArrayAttr>("parameters");
// Error callback: opens an op error and attaches a note pointing at the
// enclosing module's declaration.
1510 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1511 auto diag = emitOpError();
1513 diag.attachNote(module->getLoc()) <<
"module declared here";
1516 getParameters(), moduleParameters, emitError);
// Custom parser for hw.instance: instance name, optional `sym`, the module
// name, then operands, `->`, results and an optional attribute dictionary.
1519ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1520 StringAttr instanceNameAttr;
1521 InnerSymAttr innerSym;
1522 FlatSymbolRefAttr moduleNameAttr;
1523 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1524 SmallVector<Type, 1> inputsTypes, allResultTypes;
1525 ArrayAttr argNames, resultNames, parameters;
1526 auto noneType = parser.getBuilder().getType<NoneType>();
1528 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
// An optional inner symbol follows the `sym` keyword.
1532 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1535 if (parser.parseCustomAttributeWithFallback(innerSym))
1540 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1541 if (parser.parseAttribute(moduleNameAttr, noneType,
"moduleName",
1542 result.attributes) ||
// Take the address of parametersLoc; the text previously held a mis-encoded
// `&para` entity here and did not compile.
1543 parser.getCurrentLocation(&parametersLoc) ||
1546 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1548 parser.parseArrow() ||
1550 parser.parseOptionalAttrDict(result.attributes)) {
// Store the parsed port names and parameters as attributes on the op.
1554 result.addAttribute(
"argNames", argNames);
1555 result.addAttribute(
"resultNames", resultNames);
1556 result.addAttribute(
"parameters", parameters);
1557 result.addTypes(allResultTypes);
// Custom printer for hw.instance: instance name, optional inner symbol, module
// name, ..., then any attributes not already covered by the custom syntax.
1561void InstanceOp::print(OpAsmPrinter &p) {
1563 p.printAttributeWithoutType(getInstanceNameAttr());
1564 if (
auto attr = getInnerSymAttr()) {
1569 p.printAttributeWithoutType(getModuleNameAttr());
// Attributes rendered by the custom syntax are elided from the dictionary.
1576 p.printOptionalAttrDict(
1577 (*this)->getAttrs(),
1579 InnerSymbolTable::getInnerSymbolAttrName(),
"moduleName",
1580 "argNames",
"resultNames",
"parameters"});
// hw.instance_choice reports no target result index.
1587std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1589 return std::nullopt;
// Verifies every candidate module against this op's operands, result types,
// port names and parameters.
1593InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1594 for (Attribute name : getModuleNamesAttr()) {
1596 *
this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1597 getArgNames(), getResultNames(), getParameters(), symbolTable))) {
// Verifies this op's parameters in the context of the enclosing hw.module's
// "parameters" attribute.
1604LogicalResult InstanceChoiceOp::verify() {
1605 auto module = (*this)->getParentOfType<HWModuleOp>();
1609 auto moduleParameters =
module->getAttrOfType<ArrayAttr>("parameters");
// Error callback: opens an op error and notes where the module is declared.
1611 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1612 auto diag = emitOpError();
1614 diag.attachNote(module->getLoc()) <<
"module declared here";
1617 getParameters(), moduleParameters, emitError);
// Custom parser for hw.instance_choice: instance name, optional `sym`,
// `option <name>`, a default module, then any number of
// `or <module> if <case>` alternatives, followed by operands, `->`, results
// and an optional attribute dictionary.
1620ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1621 OperationState &result) {
1622 StringAttr optionNameAttr;
1623 StringAttr instanceNameAttr;
1624 InnerSymAttr innerSym;
1625 SmallVector<Attribute> moduleNames;
1626 SmallVector<Attribute> caseNames;
1627 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1628 SmallVector<Type, 1> inputsTypes, allResultTypes;
1629 ArrayAttr argNames, resultNames, parameters;
1630 auto noneType = parser.getBuilder().getType<NoneType>();
1632 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
1636 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1639 if (parser.parseCustomAttributeWithFallback(innerSym))
1644 if (parser.parseKeyword(
"option") ||
1645 parser.parseAttribute(optionNameAttr, noneType,
"optionName",
// The default module is always the first entry of moduleNames.
1649 FlatSymbolRefAttr defaultModuleName;
1650 if (parser.parseAttribute(defaultModuleName))
1652 moduleNames.push_back(defaultModuleName);
1654 while (succeeded(parser.parseOptionalKeyword(
"or"))) {
1655 FlatSymbolRefAttr moduleName;
1656 StringAttr targetName;
// `if` is mandatory after `or <module>`. parseKeyword reports "expected 'if'";
// an optional-keyword probe here would abort the parse with no diagnostic.
1657 if (parser.parseAttribute(moduleName) ||
1658 parser.parseKeyword(
"if") || parser.parseAttribute(targetName))
1660 moduleNames.push_back(moduleName);
1661 caseNames.push_back(targetName);
1664 llvm::SMLoc parametersLoc, inputsOperandsLoc;
// Take the address of parametersLoc; the text previously held a mis-encoded
// `&para` entity here and did not compile.
1665 if (parser.getCurrentLocation(&parametersLoc) ||
1668 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1670 parser.parseArrow() ||
1672 parser.parseOptionalAttrDict(result.attributes)) {
1676 result.addAttribute(
"moduleNames",
1677 ArrayAttr::get(parser.getContext(), moduleNames));
1678 result.addAttribute(
"caseNames",
1679 ArrayAttr::get(parser.getContext(), caseNames));
1680 result.addAttribute(
"argNames", argNames);
1681 result.addAttribute(
"resultNames", resultNames);
1682 result.addAttribute(
"parameters", parameters);
1683 result.addTypes(allResultTypes);
// Custom printer for hw.instance_choice, mirroring the parser's
// `option <name> <default> (or <module> if <case>)*` syntax.
1687void InstanceChoiceOp::print(OpAsmPrinter &p) {
1689 p.printAttributeWithoutType(getInstanceNameAttr());
1690 if (
auto attr = getInnerSymAttr()) {
1694 p <<
" option " << getOptionNameAttr() <<
' ';
// moduleNames holds the default module first, then one module per case name.
1696 auto moduleNames = getModuleNamesAttr();
1697 auto caseNames = getCaseNamesAttr();
1698 assert(moduleNames.size() == caseNames.size() + 1);
1700 p.printAttributeWithoutType(moduleNames[0]);
1701 for (
size_t i = 0, n = caseNames.size(); i < n; ++i) {
1703 p.printAttributeWithoutType(moduleNames[i + 1]);
1705 p.printAttributeWithoutType(caseNames[i]);
// Attributes rendered by the custom syntax are elided from the dictionary.
1714 p.printOptionalAttrDict(
1715 (*this)->getAttrs(),
1717 InnerSymbolTable::getInnerSymbolAttrName(),
1718 "moduleNames",
"caseNames",
"argNames",
"resultNames",
1719 "parameters",
"optionName"});
1722ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1723 SmallVector<Attribute> moduleNames;
1724 for (Attribute attr : getModuleNamesAttr()) {
1725 moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1727 return ArrayAttr::get(getContext(), moduleNames);
// Verifies that hw.output sits in an hw.module and that its operands match the
// module's output types in count and order.
// NOTE(review): the emitOpError results below are not returned; the following
// lines (presumably `return failure();`) are missing from this listing.
1735LogicalResult OutputOp::verify() {
1739 if (
auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1740 modType = mod.getHWModuleType();
1742 emitOpError(
"must have a module parent");
1745 auto modResults = modType.getOutputTypes();
1746 OperandRange outputValues = getOperands();
1747 if (modResults.size() != outputValues.size()) {
1748 emitOpError(
"must have same number of operands as region results.");
1753 for (
size_t i = 0, e = modResults.size(); i < e; ++i) {
1754 if (modResults[i] != outputValues[i].getType()) {
1755 emitOpError(
"output types must match module. In "
1757 << i <<
", expected " << modResults[i] <<
", but got "
1758 << outputValues[i].getType() <<
".";
// Parses an !hw.array type and derives the index type as an integer of width
// ceil(log2(numElements)). A 1-element array therefore gets a 0-bit index.
// The enclosing signatures are not present in this listing.
1773 if (p.parseType(type))
1774 return p.emitError(p.getCurrentLocation(),
"Expected type");
1775 auto arrType = type_dyn_cast<ArrayType>(type);
1777 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
1779 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1780 idxType = IntegerType::get(p.getBuilder().getContext(), idxWidth);
// Printer counterpart: only the source type is printed, since the parser
// derives the index type from it.
1786 p.printType(srcType);
// Parses `operands attr-dict : elemType`. The result is an array of elemType
// whose length equals the operand count, and every operand is resolved
// against elemType.
1789ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1790 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1791 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1794 if (parser.parseOperandList(operands) ||
1795 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1796 parser.parseType(elemType))
// Zero-length arrays are rejected.
1799 if (operands.size() == 0)
1800 return parser.emitError(inputOperandsLoc,
1801 "Cannot construct an array of length 0");
1802 result.addTypes({ArrayType::get(elemType, operands.size())});
1804 for (
auto operand : operands)
1805 if (parser.resolveOperand(operand, elemType, result.operands))
// Prints `operands attr-dict : elemType`. The element type is read from the
// first operand, so this relies on the op having at least one operand.
1810void ArrayCreateOp::print(OpAsmPrinter &p) {
1812 p.printOperands(getInputs());
1813 p.printOptionalAttrDict((*this)->getAttrs());
1814 p <<
" : " << getInputs()[0].getType();
// Builds an array_create whose result type is inferred from the (non-empty,
// uniformly typed) operand list.
1817void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1818 ValueRange values) {
1819 assert(values.size() > 0 &&
"Cannot build array of zero elements");
1820 Type elemType = values[0].getType();
1823 [elemType](Value v) ->
bool {
return v.getType() == elemType; }) &&
1824 "All values must have same type.");
1825 build(b, state, ArrayType::get(elemType, values.size()), values);
// Verifies that the operand count equals the result array's length.
1828LogicalResult ArrayCreateOp::verify() {
1829 unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1830 if (getInputs().size() != returnSize)
// Folds to an ArrayAttr of the operand constants when every operand is
// constant. Otherwise nothing is folded (the return on line 1837 is not shown).
1835OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1836 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1838 return ArrayAttr::get(getContext(), adaptor.getInputs());
// Body of the isOffset helper used below: tests whether the constant index
// equals the constant base plus `offset` (unsigned). The signature and the
// constant-operand guards are not present in this listing.
1848 auto baseValue = constBase.getValue();
1849 auto indexValue = constIndex.getValue();
1851 unsigned bits = baseValue.getBitWidth();
1852 assert(bits == indexValue.getBitWidth() &&
"mismatched widths");
// An offset that cannot be represented in `bits` bits can never match.
1854 if (bits < 64 && offset >= (1ull << bits))
// Widen by one bit so base + offset cannot wrap before the comparison.
1857 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1858 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1859 return baseExt + offset == indexExt;
// Rewrites an array_create whose elements are array_get results into
// array_slice / array_concat, grouping runs of consecutive indices into the
// same source array.
1867 PatternRewriter &rewriter) {
1869 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1870 if (arrayTy.getNumElements() <= 1)
1872 auto elemTy = arrayTy.getElementType();
// Elements are visited from the last operand backwards, so each run grows
// toward higher indices.
1881 SmallVector<Chunk> chunks;
1882 for (Value value : llvm::reverse(op.getInputs())) {
1883 auto get = value.getDefiningOp<
ArrayGetOp>();
1887 Value input = get.getInput();
1888 Value index = get.getIndex();
1889 if (!chunks.empty()) {
1890 auto &c = *chunks.rbegin();
// Extend the current run if this element reads the next index of the same
// input array.
1891 if (c.input == get.getInput() &&
isOffset(c.index, index, c.size)) {
1897 chunks.push_back(Chunk{input, index, 1});
// A single run covering the whole result becomes one slice.
1901 if (chunks.size() == 1) {
1902 auto &chunk = chunks[0];
1903 rewriter.replaceOp(op, rewriter.createOrFold<
ArraySliceOp>(
1904 op.getLoc(), arrayTy, chunk.input, chunk.index));
// Otherwise concatenate slices, but only if there are fewer than half as many
// runs as elements.
1910 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1911 SmallVector<Value> slices;
1912 for (
auto &chunk : llvm::reverse(chunks)) {
1913 auto sliceTy = ArrayType::get(elemTy, chunk.size);
1915 op.getLoc(), sliceTy, chunk.input, chunk.index));
1917 rewriter.replaceOpWithNewOp<
ArrayConcatOp>(op, arrayTy, slices);
// Tail of a canonicalize signature; the rest of the definition is not present
// in this listing.
1925 PatternRewriter &rewriter) {
// Returns the shared operand when the array is non-empty and every operand is
// the same SSA value (the fallthrough return is not present in this listing).
1931Value ArrayCreateOp::getUniformElement() {
1932 if (!getInputs().
empty() && llvm::all_equal(getInputs()))
1933 return getInputs()[0];
// Returns the value of an index produced by a ConstantOp as an unsigned
// integer. Returns std::nullopt when it is not such a constant (guard on line
// 1939 not shown) or is wider than 64 bits.
1938 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1940 return std::nullopt;
1941 APInt idxAttr = idxOp.getValue();
1942 if (idxAttr.getBitWidth() > 64)
1943 return std::nullopt;
1944 return idxAttr.getLimitedValue();
// Requires the low-index width to equal ceil(log2(input length)).
// NOTE(review): for a 1-element input this accepts only a 0-bit index, while
// hw::isValidIndexBitWidth also accepts a 1-bit index in that case — confirm
// whether the two are meant to agree.
1947LogicalResult ArraySliceOp::verify() {
1948 unsigned inputSize =
1949 type_cast<ArrayType>(getInput().getType()).getNumElements();
1950 if (llvm::Log2_64_Ceil(inputSize) !=
1951 getLowIndex().getType().getIntOrFloatBitWidth())
1953 "ArraySlice: index width must match clog2 of array size");
// A slice whose result type equals its input type presumably folds to the
// input; the return on line 1960 is not present in this listing.
1957OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1959 if (getType() == getInput().getType())
// Simplifies array_slice with a constant low index, depending on its input:
// slice-of-slice, slice-of-array_create and slice-of-array_concat.
1964LogicalResult ArraySliceOp::canonicalize(
ArraySliceOp op,
1965 PatternRewriter &rewriter) {
1966 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1967 auto elemTy = sliceTy.getElementType();
1968 uint64_t sliceSize = sliceTy.getNumElements();
// A one-element slice is rebuilt as an array_create.
1972 if (sliceSize == 1) {
1976 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
1985 auto *inputOp = op.getInput().getDefiningOp();
1986 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
// Presumably guards against a slice that feeds itself.
1988 if (inputSlice == op)
1991 auto inputIndex = inputSlice.getLowIndex();
1993 if (!inputOffsetOpt)
// Slice of a slice: add both constant offsets and slice the outer input.
1996 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1999 rewriter.replaceOpWithNewOp<
ArraySliceOp>(op, op.getType(),
2000 inputSlice.getInput(), lowIndex);
2004 if (
auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
2006 auto inputs = inputCreate.getInputs();
// Element 0 is the last array_create operand, so convert the low offset into
// an operand position. NOTE(review): this assumes offset + sliceSize <=
// inputs.size(); any bounds check (line 2007) is not shown here.
2008 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
2009 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
2010 inputs.slice(begin, sliceSize));
2014 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
// Walk the concat inputs from the lowest-indexed (last) operand, collecting
// whole inputs or sub-slices that overlap the requested range.
2016 SmallVector<Value> chunks;
2017 uint64_t sliceStart = *offsetOpt;
2018 for (
auto input :
llvm::reverse(inputConcat.getInputs())) {
2020 uint64_t inputSize =
2021 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2022 if (inputSize == 0 || inputSize <= sliceStart) {
2023 sliceStart -= inputSize;
2028 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2029 uint64_t cutSize = cutEnd - sliceStart;
2030 assert(cutSize != 0 &&
"slice cannot be empty");
2032 if (cutSize == inputSize) {
2034 assert(sliceStart == 0 &&
"invalid cut size");
2035 chunks.push_back(input);
// Index width for a partial cut of this input (1 bit when inputSize == 1).
2038 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2040 rewriter, op.getLoc(), rewriter.getIntegerType(width), sliceStart);
2042 rewriter, op.getLoc(), hw::ArrayType::get(elemTy, cutSize), input,
2047 sliceSize -= cutSize;
2052 assert(chunks.size() > 0 &&
"missing sliced items");
2053 if (chunks.size() == 1)
2054 rewriter.replaceOp(op, chunks[0]);
// Chunks were collected low-index first; reverse back to operand order.
2057 op, llvm::to_vector(llvm::reverse(chunks)));
// Parses the comma-separated input array types of array_concat. All inputs
// must share an element type, and the result length is the sum of the input
// lengths.
2068 SmallVectorImpl<Type> &inputTypes,
2071 uint64_t resultSize = 0;
2073 auto parseElement = [&]() -> ParseResult {
2075 if (p.parseType(ty))
2077 auto arrTy = type_dyn_cast<ArrayType>(ty);
2079 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
2080 if (elemType && elemType != arrTy.getElementType())
2081 return p.emitError(p.getCurrentLocation(),
"Expected array element type ")
2084 elemType = arrTy.getElementType();
2085 inputTypes.push_back(ty);
2086 resultSize += arrTy.getNumElements();
2090 if (p.parseCommaSeparatedList(parseElement))
2093 resultType = ArrayType::get(elemType, resultSize);
// Prints only the input types, comma-separated; the result type is derived
// when parsing.
2098 TypeRange inputTypes, Type resultType) {
2099 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
// Builds an array_concat whose result length is the sum of the input lengths.
// NOTE(review): plain cast<ArrayType> is used here, whereas other code in this
// file uses hw::type_cast; aliased array inputs may not be accepted — confirm.
2102void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2103 ValueRange values) {
2104 assert(!values.empty() &&
"Cannot build array of zero elements");
2105 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2106 Type elemTy = arrayTy.getElementType();
2107 assert(llvm::all_of(values,
2108 [elemTy](Value v) ->
bool {
2109 return isa<ArrayType>(v.getType()) &&
2110 cast<ArrayType>(v.getType()).getElementType() ==
2113 "All values must be of ArrayType with the same element type.");
2115 uint64_t resultSize = 0;
2116 for (Value val : values)
2117 resultSize += cast<ArrayType>(val.getType()).getNumElements();
2118 build(b, state, ArrayType::get(elemTy, resultSize), values);
// A single-input concat folds to its input. When all inputs are constant
// arrays, their elements are concatenated in operand order into one ArrayAttr.
2121OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2122 if (getInputs().size() == 1)
2123 return getInputs()[0];
2125 auto inputs = adaptor.getInputs();
2126 SmallVector<Attribute> array;
2127 for (
size_t i = 0, e = getNumOperands(); i < e; ++i) {
2130 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2132 return ArrayAttr::get(getContext(), array);
// Collects, in operand order, the elements of every array_create feeding the
// concat. The check that all inputs are array_create ops (lines 2137-2140) is
// only partly present in this listing.
2137 for (
auto input : op.getInputs())
2141 SmallVector<Value> items;
2142 for (
auto input : op.getInputs()) {
2143 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2144 for (
auto item : create.getInputs())
2145 items.push_back(item);
// Merges adjacent concat operands into one wider array_slice when they read
// consecutive index ranges of the same array. Eligible operands are
// array_slice ops and single-element array_create of array_get.
2159 SmallVector<Location> locs;
2162 SmallVector<Value> items;
2163 std::optional<Slice> last;
2164 bool changed =
false;
// Flush the pending run: either keep its single original op or emit one slice
// covering the run, with a location fused from the merged items.
2166 auto concatenate = [&] {
2171 items.push_back(last->op);
2179 auto loc = FusedLoc::get(op.getContext(), last->locs);
2180 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2181 auto arrayTy = ArrayType::get(origTy.getElementType(), last->size);
2183 loc, arrayTy, last->input, last->index));
// Extend the pending run when this piece continues it; otherwise start anew.
2188 auto append = [&](Value op, Value input, Value index,
size_t size) {
2193 if (last->input == input &&
isOffset(last->index, index, last->size)) {
2196 last->locs.push_back(op.getLoc());
2201 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
// Visit operands from the last (lowest-indexed) one so runs grow upward.
2204 for (
auto item : llvm::reverse(op.getInputs())) {
2206 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2207 append(item, slice.getInput(), slice.getLowIndex(), size);
2212 if (create.getInputs().size() == 1) {
2213 if (
auto get = create.getInputs()[0].getDefiningOp<
ArrayGetOp>()) {
2214 append(item, get.getInput(), get.getIndex(), 1);
2221 items.push_back(item);
2228 if (items.size() == 1) {
2229 rewriter.replaceOp(op, items[0]);
// Items were gathered in reverse operand order; restore the original order.
2231 std::reverse(items.begin(), items.end());
// Tail of a canonicalize signature; the rest of the definition is not present
// in this listing.
2238 PatternRewriter &rewriter) {
// Parses `<field> : <enum type>` into a "field" EnumFieldAttr and a result of
// that type.
2254ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2261 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2262 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2265 auto fieldAttr = EnumFieldAttr::get(
2266 loc, StringAttr::get(parser.getContext(), field), type);
2271 result.addAttribute(
"field", fieldAttr);
2272 result.addTypes(type);
2277void EnumConstantOp::print(OpAsmPrinter &p) {
2278 p <<
" " << getField().getField().getValue() <<
" : "
2279 << getField().getType().getValue();
2282void EnumConstantOp::getAsmResultNames(
2283 function_ref<
void(Value, StringRef)> setNameFn) {
2284 setNameFn(getResult(), getField().getField().str());
2287void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2288 EnumFieldAttr field) {
2289 return build(builder, odsState, field.getType().getValue(), field);
2292OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2293 assert(adaptor.getOperands().empty() &&
"constant has no operands");
2294 return getFieldAttr();
2297LogicalResult EnumConstantOp::verify() {
2298 auto fieldAttr = getFieldAttr();
2299 auto fieldType = fieldAttr.getType().getValue();
2302 if (fieldType != getType())
2303 emitOpError(
"return type ")
2304 << getType() <<
" does not match attribute type " << fieldAttr;
2312LogicalResult EnumCmpOp::verify() {
2314 auto lhsType = type_cast<EnumType>(getLhs().getType());
2315 auto rhsType = type_cast<EnumType>(getRhs().getType());
2316 if (rhsType != lhsType)
2317 emitOpError(
"types do not match");
// Parses `(operands) attr-dict : type`. The result keeps the written (possibly
// aliased) type, while operands are resolved against the underlying struct's
// field types.
2325ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2326 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2327 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2328 Type declOrAliasType;
2330 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2331 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2332 parser.parseColonType(declOrAliasType))
2335 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2337 return parser.emitError(parser.getNameLoc(),
2338 "expected !hw.struct type or alias");
2340 llvm::SmallVector<Type, 4> structInnerTypes;
2341 declType.getInnerTypes(structInnerTypes);
2342 result.addTypes(declOrAliasType);
2344 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
// Prints the operands, attributes and the (possibly aliased) result type.
2350void StructCreateOp::print(OpAsmPrinter &printer) {
2352 printer.printOperands(getInput());
2354 printer.printOptionalAttrDict((*this)->getAttrs());
2355 printer <<
" : " << getType();
// Verifies that there is one operand per struct field and that each operand's
// type equals its field's type.
2358LogicalResult StructCreateOp::verify() {
2359 auto elements = hw::type_cast<StructType>(getType()).getElements();
2361 if (elements.size() != getInput().size())
2362 return emitOpError(
"structure field count mismatch");
2364 for (
const auto &[field, value] :
llvm::zip(elements, getInput()))
2365 if (field.type != value.getType())
2366 return emitOpError(
"structure field `")
2367 << field.name <<
"` type does not match";
// Folds struct_create(struct_explode(x)) to x when the operands are exactly
// the explode results in order and the types match.
2372OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2374 if (!getInput().
empty())
2375 if (
auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2376 explodeOp && getInput() == explodeOp.getResults() &&
2377 getResult().getType() == explodeOp.getInput().getType())
2378 return explodeOp.getInput();
// With every operand constant, fold to an ArrayAttr of the field values.
2380 auto inputs = adaptor.getInput();
2381 if (llvm::any_of(inputs, [](Attribute attr) {
return !attr; }))
2383 return ArrayAttr::get(getContext(), inputs);
// Parses `operand attr-dict : struct-type` and produces one result per struct
// field, typed by that field.
2390ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2391 OperationState &result) {
2392 OpAsmParser::UnresolvedOperand operand;
2395 if (parser.parseOperand(operand) ||
2396 parser.parseOptionalAttrDict(result.attributes) ||
2397 parser.parseColonType(declType))
2399 auto structType = type_dyn_cast<StructType>(declType);
2401 return parser.emitError(parser.getNameLoc(),
2402 "invalid kind of type specified");
2404 llvm::SmallVector<Type, 4> structInnerTypes;
2405 structType.getInnerTypes(structInnerTypes);
2406 result.addTypes(structInnerTypes);
2408 if (parser.resolveOperand(operand, declType, result.operands))
// Prints the operand, attributes and the input struct type.
2413void StructExplodeOp::print(OpAsmPrinter &printer) {
2415 printer.printOperand(getInput());
2416 printer.printOptionalAttrDict((*this)->getAttrs());
2417 printer <<
" : " << getInput().getType();
// With a constant struct input (an ArrayAttr of field values), folds each
// result to the corresponding field constant.
2420LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2421 SmallVectorImpl<OpFoldResult> &results) {
2422 auto input = adaptor.getInput();
2425 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
// Replaces each exploded result whose value can be recovered directly from the
// input's defining op. `result` tracks whether anything changed (the lookup on
// line 2436 is not present in this listing).
2429LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2430 PatternRewriter &rewriter) {
2431 auto *inputOp = op.getInput().getDefiningOp();
2432 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2433 auto result = failure();
2434 auto opResults = op.getResults();
2435 for (uint32_t index = 0; index < elements.size(); index++) {
2437 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2444void StructExplodeOp::getAsmResultNames(
2445 function_ref<
void(Value, StringRef)> setNameFn) {
2446 auto structType = type_cast<StructType>(getInput().getType());
2447 for (
auto [res, field] :
llvm::zip(getResults(), structType.getElements()))
2448 setNameFn(res, field.name.str());
// Builds a struct_explode with one result per field of the input struct type.
// NOTE(review): dyn_cast (not type_dyn_cast) yields null for an aliased struct
// type, and no null check is visible (line 2454 is missing) — confirm.
2451void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2453 StructType inputType = dyn_cast<StructType>(input.getType());
2455 SmallVector<Type, 16> fieldTypes;
2456 for (
auto field : inputType.getElements())
2457 fieldTypes.push_back(field.type);
2458 build(odsBuilder, odsState, fieldTypes, input);
// Shared verifier for aggregate-access ops: the field index must be in range,
// and the field's declared type must match `elementType`.
2467template <
typename AggregateOp,
typename AggregateType>
2469 AggregateType aggType,
2471 auto index = op.getFieldIndex();
2472 if (index >= aggType.getElements().size())
2473 return op.emitOpError() <<
"field index " << index
2474 <<
" exceeds element count of aggregate type";
2478 return op.emitOpError()
2479 <<
"type " << aggType.getElements()[index].type
2480 <<
" of accessed field in aggregate at index " << index
2481 <<
" does not match expected type " <<
elementType;
// The result type must match the type of the selected struct field.
2486LogicalResult StructExtractOp::verify() {
2487 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2488 *
this, getInput().getType(), getType());
// Shared parser for `operand["field"] attr-dict : aggregate-type`. It maps the
// field name to a 32-bit "fieldIndex" attribute and types the result by that
// field.
2493template <
typename AggregateType>
2495 OpAsmParser::UnresolvedOperand operand;
2496 StringAttr fieldName;
2499 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2500 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2501 parser.parseOptionalAttrDict(result.attributes) ||
2502 parser.parseColonType(declType))
2504 auto aggType = type_dyn_cast<AggregateType>(declType);
2506 return parser.emitError(parser.getNameLoc(),
2507 "invalid kind of type specified");
2509 auto fieldIndex = aggType.getFieldIndex(fieldName);
2511 parser.emitError(parser.getNameLoc(),
"field name '" +
2512 fieldName.getValue() +
2513 "' not found in aggregate type");
2518 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2519 result.addAttribute(
"fieldIndex", indexAttr);
2520 Type resultType = aggType.getElements()[*fieldIndex].type;
2521 result.addTypes(resultType);
2523 if (parser.resolveOperand(operand, declType, result.operands))
// Shared printer for extract-like ops. The field is printed by name, so the
// numeric "fieldIndex" attribute is elided.
2530template <
typename AggType>
2533 printer.printOperand(op.getInput());
2534 printer <<
"[\"" << op.getFieldName() <<
"\"]";
2535 printer.printOptionalAttrDict(op->getAttrs(), {
"fieldIndex"});
2536 printer <<
" : " << op.getInput().getType();
// Delegates to the shared extract parser, specialized for struct types.
2539ParseResult StructExtractOp::parse(OpAsmParser &parser,
2540 OperationState &result) {
2541 return parseExtractOp<StructType>(parser, result);
// Printer for struct_extract; its body is not present in this listing.
2544void StructExtractOp::print(OpAsmPrinter &printer) {
// Builds a struct_extract from a FieldInfo: looks up the field's index by
// name and uses the field's type as the result type.
2548void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2549 Value input, StructType::FieldInfo field) {
2551 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2552 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2553 build(builder, odsState, field.type, input, *fieldIndex);
// Builds a struct_extract by field name. The result type is the named field's
// type, and an unknown name trips the assertion.
2556void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2557 Value input, StringAttr fieldName) {
2558 auto structType = type_cast<StructType>(input.getType());
2559 auto fieldIndex = structType.getFieldIndex(fieldName);
2560 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2561 auto resultType = structType.getElements()[*fieldIndex].type;
2562 build(builder, odsState, resultType, input, *fieldIndex);
// With a constant struct input (an ArrayAttr of field values), folds to the
// selected field's constant. The non-constant path continues past this
// listing.
2565OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2566 if (
auto constOperand = adaptor.getInput()) {
2568 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2569 return operandAttr.getValue()[getFieldIndex()];
2572 if (
auto foldResult =
// Extracting a field that a struct_inject did not modify can read from the
// inject's input instead.
2579 PatternRewriter &rewriter) {
2580 auto *inputOp = op.getInput().getDefiningOp();
2583 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2584 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2586 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
// Result-naming hook for struct_extract; its body is not present in this
// listing.
2594void StructExtractOp::getAsmResultNames(
2595 function_ref<
void(Value, StringRef)> setNameFn) {
// Builds a struct_inject by field name; an unknown name trips the assertion.
2603void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2604 Value input, StringAttr fieldName, Value newValue) {
2605 auto structType = type_cast<StructType>(input.getType());
2606 auto fieldIndex = structType.getFieldIndex(fieldName);
2607 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2608 build(builder, odsState, input, *fieldIndex, newValue);
// The injected value's type must match the type of the selected struct field.
2611LogicalResult StructInjectOp::verify() {
2612 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2613 *
this, getInput().getType(), getNewValue().getType());
// Parses `operand["field"], value attr-dict : struct-type`. The result has the
// struct type, and the new value is resolved against the named field's type.
2616ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2617 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2618 OpAsmParser::UnresolvedOperand operand, val;
2619 StringAttr fieldName;
2622 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2623 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2624 parser.parseComma() || parser.parseOperand(val) ||
2625 parser.parseOptionalAttrDict(result.attributes) ||
2626 parser.parseColonType(declType))
2628 auto structType = type_dyn_cast<StructType>(declType);
2630 return parser.emitError(inputOperandsLoc,
"invalid kind of type specified");
2632 auto fieldIndex = structType.getFieldIndex(fieldName);
2634 parser.emitError(parser.getNameLoc(),
"field name '" +
2635 fieldName.getValue() +
2636 "' not found in aggregate type");
2641 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2642 result.addAttribute(
"fieldIndex", indexAttr);
2643 result.addTypes(declType);
// Despite its name, `resultType` is the accessed field's type; it types the
// injected value operand.
2645 Type resultType = structType.getElements()[*fieldIndex].type;
2646 if (parser.resolveOperands({operand, val}, {declType, resultType},
2647 inputOperandsLoc, result.operands))
// Prints struct_inject. The numeric "fieldIndex" attribute is elided from the
// attribute dictionary.
2652void StructInjectOp::print(OpAsmPrinter &printer) {
2654 printer.printOperand(getInput());
2656 printer.printOperand(getNewValue());
2657 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2658 printer <<
" : " << getInput().getType();
// With both the struct and the new value constant, folds to a copy of the
// struct's ArrayAttr with the selected field replaced.
2661OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2662 auto input = adaptor.getInput();
2663 auto newValue = adaptor.getNewValue();
2664 if (!input || !newValue)
2666 SmallVector<Attribute> array;
2667 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2668 array[getFieldIndex()] = newValue;
2669 return ArrayAttr::get(getContext(), array);
// Collapses a chain of struct_inject ops, starting from the last inject in the
// chain.
2672LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2673 PatternRewriter &rewriter) {
// An inject whose only use is as the struct input of another inject is left
// for that later inject to handle.
2677 if (op->hasOneUse()) {
2678 auto &use = *op->use_begin();
2679 if (isa<StructInjectOp>(use.getOwner()) && use.getOperandNumber() == 0)
// Walk toward the root of the chain. try_emplace keeps the first write seen,
// which is the outermost (latest) value of each field. Revisiting an inject
// stops the walk (presumably a cycle).
2684 SmallPtrSet<Operation *, 4> injects;
2685 DenseMap<StringAttr, Value> fields;
2688 StructInjectOp inject = op;
2691 if (!injects.insert(inject).second)
2694 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2695 input = inject.getInput();
2696 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2698 assert(input &&
"missing input to inject chain");
2700 auto ty = hw::type_cast<StructType>(op.getType());
2701 auto elements = ty.getElements();
// When every field is overwritten, the chain equals a struct_create of the
// written values.
2704 if (fields.size() == elements.size()) {
2705 SmallVector<Value> createFields;
2706 for (
const auto &field : elements) {
2707 auto it = fields.find(field.name);
2708 assert(it != fields.end() &&
"missing field");
2709 createFields.push_back(it->second);
2711 rewriter.replaceOpWithNewOp<
StructCreateOp>(op, ty, createFields);
// If no field was written more than once, there is nothing to simplify.
2716 if (injects.size() == fields.size())
// Otherwise rebuild the chain with one inject per written field, in field
// order.
2720 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2721 auto it = fields.find(elements[fieldIndex].name);
2722 if (it == fields.end())
2724 input = StructInjectOp::create(rewriter, op.getLoc(), ty, input, fieldIndex,
2728 rewriter.replaceOp(op, input);
2736LogicalResult UnionCreateOp::verify() {
2737 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2738 *
this, getType(), getInput().getType());
2741void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2742 Type unionType, StringAttr fieldName, Value input) {
2743 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2744 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2745 build(builder, odsState, unionType, *fieldIndex, input);
2748ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2749 Type declOrAliasType;
2750 StringAttr fieldName;
2751 OpAsmParser::UnresolvedOperand input;
2752 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2754 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2755 parser.parseOperand(input) ||
2756 parser.parseOptionalAttrDict(result.attributes) ||
2757 parser.parseColonType(declOrAliasType))
2760 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2762 return parser.emitError(parser.getNameLoc(),
2763 "expected !hw.union type or alias");
2765 auto fieldIndex = declType.getFieldIndex(fieldName);
2767 parser.emitError(fieldLoc,
"cannot find union field '")
2768 << fieldName.getValue() <<
'\'';
2773 IntegerAttr::get(IntegerType::get(parser.getContext(), 32), *fieldIndex);
2774 result.addAttribute(
"fieldIndex", indexAttr);
2775 Type inputType = declType.getElements()[*fieldIndex].type;
2777 if (parser.resolveOperand(input, inputType, result.operands))
2779 result.addTypes({declOrAliasType});
2783void UnionCreateOp::print(OpAsmPrinter &printer) {
2785 printer.printOperand(getInput());
2786 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2787 printer <<
" : " << getType();
2794ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2795 return parseExtractOp<UnionType>(parser, result);
2798void UnionExtractOp::print(OpAsmPrinter &printer) {
2802LogicalResult UnionExtractOp::inferReturnTypes(
2803 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2804 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2805 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2806 Adaptor adaptor(operands, attrs, properties, regions);
2807 auto unionElements =
2808 hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2809 unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2810 if (fieldIndex >= unionElements.size()) {
2812 mlir::emitError(*loc,
"field index " + Twine(fieldIndex) +
2813 " exceeds element count of aggregate type");
2816 results.push_back(unionElements[fieldIndex].type);
2820void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2821 Value input, StringAttr fieldName) {
2822 auto unionType = type_cast<UnionType>(input.getType());
2823 auto fieldIndex = unionType.getFieldIndex(fieldName);
2824 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2825 auto resultType = unionType.getElements()[*fieldIndex].type;
2826 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2838OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2839 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2840 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2845 auto indexVal = indexCst.getValue();
2846 if (indexVal.getBitWidth() < 64) {
2847 auto index = indexVal.getZExtValue();
2848 return inputCst[inputCst.size() - 1 - index];
2853 if (!inputCst.empty() && llvm::all_equal(inputCst))
2858 if (
auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2859 auto intTy = dyn_cast<IntegerType>(getType());
2862 auto bitcastInputOp = bitcast.getInput().getDefiningOp<
hw::ConstantOp>();
2863 if (!bitcastInputOp)
2867 auto bitcastInputCst = bitcastInputOp.getValue();
2870 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2871 getType().getIntOrFloatBitWidth();
2873 return IntegerAttr::get(intTy, bitcastInputCst.lshr(startIdx).trunc(
2874 intTy.getIntOrFloatBitWidth()));
2878 if (
auto inject = getInput().getDefiningOp<ArrayInjectOp>())
2879 if (getIndex() == inject.getIndex())
2880 return inject.getElement();
2882 auto inputCreate = getInput().getDefiningOp<
ArrayCreateOp>();
2886 if (
auto uniformValue = inputCreate.getUniformElement())
2887 return uniformValue;
2889 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2892 uint64_t index = indexCst.getValue().getLimitedValue();
2893 auto createInputs = inputCreate.getInputs();
2894 if (index >= createInputs.size())
2896 return createInputs[createInputs.size() - index - 1];
2899LogicalResult ArrayGetOp::canonicalize(
ArrayGetOp op,
2900 PatternRewriter &rewriter) {
2905 auto *inputOp = op.getInput().getDefiningOp();
2906 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2908 auto offsetOp = inputSlice.getLowIndex();
2913 uint64_t offset = *offsetOpt + *idxOpt;
2916 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, inputSlice.getInput(),
2921 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2923 uint64_t elemIndex = *idxOpt;
2924 for (
auto input :
llvm::reverse(inputConcat.getInputs())) {
2925 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2926 if (elemIndex >= size) {
2931 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2934 rewriter.getIntegerType(indexWidth), elemIndex);
2936 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, input, newIdxOp);
2945 if (
auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2950 SmallVector<Value> newValues;
2951 for (
auto operand : create.getOperands())
2952 newValues.push_back(rewriter.createOrFold<
hw::
ArrayGetOp>(
2953 op.
getLoc(), operand, op.getIndex()));
2958 innerGet.getIndex());
2971OpFoldResult ArrayInjectOp::fold(FoldAdaptor adaptor) {
2972 auto inputAttr = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2973 auto indexAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2974 auto elementAttr = adaptor.getElement();
2977 if (inputAttr && indexAttr && elementAttr) {
2978 if (
auto index = indexAttr.getValue().tryZExtValue()) {
2979 if (*index < inputAttr.size()) {
2980 SmallVector<Attribute> elements(inputAttr.getValue());
2981 elements[inputAttr.size() - 1 - *index] = elementAttr;
2982 return ArrayAttr::get(getContext(), elements);
2991 PatternRewriter &rewriter) {
2995 if (op->hasOneUse()) {
2996 auto &use = *op->use_begin();
2997 if (isa<ArrayInjectOp>(use.getOwner()) && use.getOperandNumber() == 0)
3002 auto arrayLength = type_cast<ArrayType>(op.getType()).getNumElements();
3005 while (
auto inject = input.getDefiningOp<ArrayInjectOp>()) {
3008 if (!matchPattern(inject.getIndex(), mlir::m_ConstantInt(&indexAPInt)))
3010 if (indexAPInt.getActiveBits() > 32)
3012 uint32_t index = indexAPInt.getZExtValue();
3017 if (index < arrayLength)
3018 elements.insert({index, inject.getElement()});
3021 input = inject.getInput();
3028 if (elements.size() == arrayLength) {
3029 SmallVector<Value, 4> operands;
3030 operands.reserve(arrayLength);
3031 for (uint32_t idx = 0; idx < arrayLength; ++idx)
3032 operands.push_back(elements.at(arrayLength - idx - 1));
3033 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(), operands);
3042 auto createOp = op.getInput().getDefiningOp<
ArrayCreateOp>();
3048 if (!matchPattern(op.getIndex(), mlir::m_ConstantInt(&indexAPInt)) ||
3049 !indexAPInt.ult(createOp.getInputs().size()))
3053 SmallVector<Value> elements = createOp.getInputs();
3054 elements[elements.size() - indexAPInt.getZExtValue() - 1] = op.getElement();
3059void ArrayInjectOp::getCanonicalizationPatterns(RewritePatternSet &
patterns,
3060 MLIRContext *context) {
3061 patterns.add<ArrayInjectToSameIndex>(context);
3070StringRef TypedeclOp::getPreferredName() {
3071 return getVerilogName().value_or(
getName());
3074Type TypedeclOp::getAliasType() {
3075 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
3076 return hw::TypeAliasType::get(
3077 SymbolRefAttr::get(parentScope.getSymNameAttr(),
3078 {FlatSymbolRefAttr::get(*this)}),
3086OpFoldResult BitcastOp::fold(FoldAdaptor) {
3089 if (getOperand().getType() == getType())
3090 return getOperand();
3095LogicalResult BitcastOp::canonicalize(
BitcastOp op, PatternRewriter &rewriter) {
3101 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
3104 auto bitcast = rewriter.createOrFold<
BitcastOp>(op.getLoc(), op.getType(),
3105 inputBitcast.getInput());
3106 rewriter.replaceOp(op, bitcast);
3110LogicalResult BitcastOp::verify() {
3112 return this->emitOpError(
"Bitwidth of input must match result");
3120bool HierPathOp::dropModule(StringAttr moduleToDrop) {
3121 SmallVector<Attribute, 4> newPath;
3122 bool updateMade =
false;
3123 for (
auto nameRef : getNamepath()) {
3125 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3126 if (ref.getModule() == moduleToDrop)
3129 newPath.push_back(ref);
3131 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3134 newPath.push_back(nameRef);
3138 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3142bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3143 SmallVector<Attribute, 4> newPath;
3144 bool updateMade =
false;
3145 StringRef inlinedInstanceName =
"";
3146 for (
auto nameRef : getNamepath()) {
3148 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3149 if (ref.getModule() == moduleToDrop) {
3150 inlinedInstanceName = ref.getName().getValue();
3152 }
else if (!inlinedInstanceName.empty()) {
3153 newPath.push_back(hw::InnerRefAttr::get(
3155 StringAttr::get(getContext(), inlinedInstanceName +
"_" +
3156 ref.getName().getValue())));
3157 inlinedInstanceName =
"";
3159 newPath.push_back(ref);
3161 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3164 newPath.push_back(nameRef);
3168 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3172bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3173 SmallVector<Attribute, 4> newPath;
3174 bool updateMade =
false;
3175 for (
auto nameRef : getNamepath()) {
3177 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3178 if (ref.getModule() == oldMod) {
3179 newPath.push_back(hw::InnerRefAttr::get(newMod, ref.getName()));
3182 newPath.push_back(ref);
3184 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3185 newPath.push_back(FlatSymbolRefAttr::get(newMod));
3188 newPath.push_back(nameRef);
3192 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3196bool HierPathOp::updateModuleAndInnerRef(
3197 StringAttr oldMod, StringAttr newMod,
3198 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3199 auto fromRef = FlatSymbolRefAttr::get(oldMod);
3200 if (oldMod == newMod)
3203 auto namepathNew = getNamepath().getValue().vec();
3204 bool updateMade =
false;
3206 for (
auto &element : namepathNew) {
3207 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3208 if (innerRef.getModule() != oldMod)
3210 auto symName = innerRef.getName();
3213 auto to = innerSymRenameMap.find(symName);
3214 if (to != innerSymRenameMap.end())
3215 symName = to->second;
3217 element = hw::InnerRefAttr::get(newMod, symName);
3220 if (element != fromRef)
3224 element = FlatSymbolRefAttr::get(newMod);
3228 setNamepathAttr(ArrayAttr::get(getContext(), namepathNew));
3232bool HierPathOp::truncateAtModule(StringAttr atMod,
bool includeMod) {
3233 SmallVector<Attribute, 4> newPath;
3234 bool updateMade =
false;
3235 for (
auto nameRef : getNamepath()) {
3237 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3238 if (ref.getModule() == atMod) {
3241 newPath.push_back(ref);
3243 newPath.push_back(ref);
3245 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3248 newPath.push_back(nameRef);
3254 setNamepathAttr(ArrayAttr::get(getContext(), newPath));
3259StringAttr HierPathOp::modPart(
unsigned i) {
3260 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3261 .Case<FlatSymbolRefAttr>([](
auto a) {
return a.getAttr(); })
3262 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getModule(); });
3266StringAttr HierPathOp::root() {
3272bool HierPathOp::hasModule(StringAttr modName) {
3273 for (
auto nameRef : getNamepath()) {
3275 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3276 if (ref.getModule() == modName)
3279 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3287bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName)
const {
3288 for (
auto nameRef : const_cast<HierPathOp *>(this)->getNamepath())
3289 if (auto ref = dyn_cast<
hw::InnerRefAttr>(nameRef))
3290 if (ref.
getName() == symName && ref.getModule() == modName)
3298StringAttr HierPathOp::refPart(
unsigned i) {
3299 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3300 .Case<FlatSymbolRefAttr>([](
auto a) {
return StringAttr({}); })
3301 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getName(); });
3306StringAttr HierPathOp::ref() {
3308 return refPart(getNamepath().size() - 1);
3312StringAttr HierPathOp::leafMod() {
3314 return modPart(getNamepath().size() - 1);
3319bool HierPathOp::isModule() {
return !ref(); }
3323bool HierPathOp::isComponent() {
return (
bool)ref(); }
3339 ArrayAttr expectedModuleNames = {};
3340 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3341 if (!expectedModuleNames)
3343 if (llvm::any_of(expectedModuleNames,
3344 [name](Attribute attr) {
return attr == name; }))
3346 auto diag = emitOpError() <<
"instance path is incorrect. Expected ";
3347 size_t n = expectedModuleNames.size();
3351 for (
size_t i = 0; i < n; ++i) {
3353 diag << ((i + 1 == n) ?
" or " :
", ");
3354 diag << cast<StringAttr>(expectedModuleNames[i]);
3356 diag <<
". Instead found: " << name;
3360 if (!getNamepath() || getNamepath().
empty())
3361 return emitOpError() <<
"the instance path cannot be empty";
3362 for (
unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3363 hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3365 return emitOpError()
3366 <<
"the instance path can only contain inner sym reference"
3367 <<
", only the leaf can refer to a module symbol";
3369 if (failed(checkExpectedModule(innerRef.getModule())))
3372 auto instOp = ns.
lookupOp<igraph::InstanceOpInterface>(innerRef);
3374 return emitOpError() <<
" module: " << innerRef.getModule()
3375 <<
" does not contain any instance with symbol: "
3376 << innerRef.getName();
3377 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3381 auto leafRef = getNamepath()[getNamepath().size() - 1];
3382 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3383 if (!ns.
lookup(innerRef)) {
3384 return emitOpError() <<
" operation with symbol: " << innerRef
3385 <<
" was not found ";
3387 if (failed(checkExpectedModule(innerRef.getModule())))
3389 }
else if (failed(checkExpectedModule(
3390 cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3396void HierPathOp::print(OpAsmPrinter &p) {
3400 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3401 if (
auto visibility =
3402 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3403 p << visibility.getValue() <<
' ';
3405 p.printSymbolName(getSymName());
3407 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3408 if (
auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3409 p.printSymbolName(ref.getModule().getValue());
3411 p.printSymbolName(ref.getName().getValue());
3413 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3417 p.printOptionalAttrDict(
3418 (*this)->getAttrs(),
3419 {SymbolTable::getSymbolAttrName(),
"namepath", visibilityAttrName});
3422ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3424 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3428 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3433 SmallVector<Attribute> namepath;
3434 if (parser.parseCommaSeparatedList(
3435 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3436 auto loc = parser.getCurrentLocation();
3438 if (parser.parseAttribute(ref))
3442 auto pathLength = ref.getNestedReferences().size();
3443 if (pathLength == 0)
3445 FlatSymbolRefAttr::get(ref.getRootReference()));
3446 else if (pathLength == 1)
3447 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3448 ref.getLeafReference()));
3450 return parser.emitError(loc,
3451 "only one nested reference is allowed");
3455 result.addAttribute(
"namepath",
3456 ArrayAttr::get(parser.getContext(), namepath));
3458 if (parser.parseOptionalAttrDict(result.attributes))
3468void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3469 EventControlAttr event, Value trigger,
3470 ValueRange inputs) {
3471 odsState.addOperands(trigger);
3472 odsState.addOperands(inputs);
3473 odsState.addAttribute(getEventAttrName(odsState.name), event);
3474 auto *r = odsState.addRegion();
3478 llvm::SmallVector<Location> argLocs;
3479 llvm::transform(inputs, std::back_inserter(argLocs),
3480 [&](Value v) {
return v.getLoc(); });
3481 b->addArguments(inputs.getTypes(), argLocs);
3489#define GET_OP_CLASSES
3490#include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, ArrayRef< PortInfo > ports, ArrayAttr annotations, ArrayAttr layers)
void getAsmBlockArgumentNamesImpl(Operation *op, mlir::Region ®ion, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static LogicalResult verifyModuleCommon(HWModuleLike module)
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
static LogicalResult canonicalizeArrayInjectChain(ArrayInjectOp op, PatternRewriter &rewriter)
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
FunctionType getHWModuleOpType(Operation *op)
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo > > insertInputs, ArrayRef< std::pair< unsigned, PortInfo > > insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo > > insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
static SmallVector< Location > getAllPortLocs(ModTy module)
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
static void getAsmBlockArgumentNamesImpl(mlir::Region ®ion, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static void setHWModuleType(ModTy &mod, ModuleType type)
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
static LogicalResult canonicalizeArrayInjectIntoCreate(ArrayInjectOp op, PatternRewriter &rewriter)
static std::optional< uint64_t > getUIntFromValue(Value value)
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
static PortInfo getPort(ModuleTy &mod, size_t idx)
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static Location getLoc(DefSlot slot)
static StringAttr append(StringAttr base, const Twine &suffix)
Return a attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void setOutput(unsigned i, Value v)
Value getInput(unsigned i)
llvm::SmallVector< Value > outputOperands
llvm::SmallVector< Value > inputArgs
llvm::StringMap< unsigned > outputIdx
llvm::StringMap< unsigned > inputIdx
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
ResultType visitUnhandledTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any combinational operations that are not handled by the concrete visitor...
ResultType visitInvalidTypeOp(Operation *op, ExtraArgs... args)
This callback is invoked on any non-expression operations.
create(array_value, low_index, ret_type)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
mlir::Type getCanonicalType(mlir::Type type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr ¶meters)
Parse an parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
This class represents the namespace in which InnerRef's can be resolved.
InnerSymTarget lookup(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
Operation * lookupOp(hw::InnerRefAttr inner) const
Resolve the InnerRef to its target within this namespace, returning empty target if no such name exis...
This holds a decoded list of input/inout and output ports for a module or instance.
PortInfo & at(size_t idx)
size_t sizeOutputs() const
size_t sizeInputs() const
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.