24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "llvm/ADT/BitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/StringSet.h"
31 using namespace circt;
33 using mlir::TypedAttr;
45 llvm_unreachable(
"unknown PortDirection");
49 hw::ArrayType arrayType =
51 assert(arrayType &&
"expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
60 struct IsCombClassifier :
public TypeOpVisitor<IsCombClassifier, bool> {
61 bool visitInvalidTypeOp(Operation *op) {
return false; }
62 bool visitUnhandledTypeOp(Operation *op) {
return true; }
65 return (op->getDialect() && op->getDialect()->getNamespace() ==
"comb") ||
71 if (
auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
76 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
79 return structInject.getNewValue();
85 ArrayRef<Attribute> attrs) {
90 if (a && !cast<DictionaryAttr>(a).
empty()) {
// Name each entry-block argument of the module body after the corresponding
// module input (module.getInputName(i)).
106 auto module = cast<HWModuleOp>(region.getParentOp());
108 auto *block = &region.front();
109 for (
size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name = module.getInputName(i);
112 setNameFn(block->getArgument(i), name);
129 Attribute value, ArrayAttr moduleParameters,
131 bool disallowParamRefs) {
134 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
139 if (
auto expr = dyn_cast<ParamExprAttr>(value)) {
140 for (
auto op : expr.getOperands())
149 if (
auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150 auto nameAttr = parameterRef.getName();
154 if (disallowParamRefs) {
155 instanceError([&](
auto &diag) {
156 diag <<
"parameter " << nameAttr
157 <<
" cannot be used as a default value for a parameter";
164 for (
auto param : moduleParameters) {
165 auto paramAttr = cast<ParamDeclAttr>(param);
166 if (paramAttr.getName() != nameAttr)
170 if (paramAttr.getType() == parameterRef.getType())
173 instanceError([&](
auto &diag) {
174 diag <<
"parameter " << nameAttr <<
" used with type "
175 << parameterRef.getType() <<
"; should have type "
176 << paramAttr.getType();
182 instanceError([&](
auto &diag) {
183 diag <<
"use of unknown parameter " << nameAttr;
189 instanceError([&](
auto &diag) {
190 diag <<
"invalid parameter value " << value;
204 bool disallowParamRefs) {
206 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
208 auto diag = usingOp->emitOpError();
210 diag.attachNote(module->getLoc()) <<
"module declared here";
215 module->getAttrOfType<ArrayAttr>(
"parameters"),
216 emitError, disallowParamRefs);
230 for (
auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
262 void ConstantOp::print(OpAsmPrinter &p) {
264 p.printAttribute(getValueAttr());
265 p.printOptionalAttrDict((*this)->getAttrs(), {
"value"});
268 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269 IntegerAttr valueAttr;
271 if (parser.parseAttribute(valueAttr,
"value", result.attributes) ||
272 parser.parseOptionalAttrDict(result.attributes))
275 result.addTypes(valueAttr.getType());
283 "hw.constant attribute bitwidth doesn't match return type");
290 void ConstantOp::build(OpBuilder &builder, OperationState &result,
291 const APInt &value) {
294 auto attr = builder.getIntegerAttr(type, value);
295 return build(builder, result, type, attr);
300 void ConstantOp::build(OpBuilder &builder, OperationState &result,
302 return build(builder, result, value.getType(), value);
309 void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
311 auto numBits = cast<IntegerType>(type).getWidth();
312 build(builder, result, APInt(numBits, (uint64_t)value,
true));
316 function_ref<
void(Value, StringRef)> setNameFn) {
317 auto intTy = getType();
318 auto intCst = getValue();
321 if (cast<IntegerType>(intTy).
getWidth() == 1)
322 return setNameFn(getResult(), intCst.isZero() ?
"false" :
"true");
325 SmallVector<char, 32> specialNameBuffer;
326 llvm::raw_svector_ostream specialName(specialNameBuffer);
327 specialName <<
'c' << intCst <<
'_' << intTy;
328 setNameFn(getResult(), specialName.str());
331 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
332 assert(adaptor.getOperands().empty() &&
"constant has no operands");
333 return getValueAttr();
344 ArrayRef<StringRef> ignoredAttrs = {}) {
345 auto names = op.getAttributeNames();
346 llvm::SmallDenseSet<StringRef> nameSet;
347 nameSet.reserve(names.size() + ignoredAttrs.size());
348 nameSet.insert(names.begin(), names.end());
349 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
350 return llvm::any_of(op->getAttrs(), [&](
auto namedAttr) {
351 return !nameSet.contains(namedAttr.getName());
357 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
"name");
358 if (nameAttr && !nameAttr.getValue().empty())
359 setNameFn(getResult(), nameAttr.getValue());
362 std::optional<size_t> WireOp::getTargetResultIndex() {
return 0; }
364 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
379 if (wire.getInnerSymAttr())
384 if (
auto *inputOp = wire.getInput().getDefiningOp())
386 rewriter.modifyOpInPlace(inputOp,
387 [&] { inputOp->setAttr(
"sv.namehint", name); });
389 rewriter.replaceOp(wire, wire.getInput());
399 if (
auto typeAlias = dyn_cast<TypeAliasType>(type))
400 type = typeAlias.getCanonicalType();
402 if (
auto structType = dyn_cast<StructType>(type)) {
403 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
405 return op->emitOpError(
"expected array attribute for constant of type ")
407 if (structType.getElements().size() != arrayAttr.size())
408 return op->emitOpError(
"array attribute (")
409 << arrayAttr.size() <<
") has wrong size for struct constant ("
410 << structType.getElements().size() <<
")";
412 for (
auto [attr, fieldInfo] :
413 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
417 }
else if (
auto arrayType = dyn_cast<ArrayType>(type)) {
418 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
420 return op->emitOpError(
"expected array attribute for constant of type ")
422 if (arrayType.getNumElements() != arrayAttr.size())
423 return op->emitOpError(
"array attribute (")
424 << arrayAttr.size() <<
") has wrong size for array constant ("
425 << arrayType.getNumElements() <<
")";
428 for (
auto attr : arrayAttr.getValue()) {
432 }
else if (
auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
433 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
435 return op->emitOpError(
"expected array attribute for constant of type ")
438 if (arrayType.getNumElements() != arrayAttr.size())
439 return op->emitOpError(
"array attribute (")
441 <<
") has wrong size for unpacked array constant ("
442 << arrayType.getNumElements() <<
")";
444 for (
auto attr : arrayAttr.getValue()) {
448 }
else if (
auto enumType = dyn_cast<EnumType>(type)) {
449 auto stringAttr = dyn_cast<StringAttr>(attr);
451 return op->emitOpError(
"expected string attribute for constant of type ")
453 }
else if (
auto intType = dyn_cast<IntegerType>(type)) {
455 auto intAttr = dyn_cast<IntegerAttr>(attr);
457 return op->emitOpError(
"expected integer attribute for constant of type ")
460 if (intAttr.getValue().getBitWidth() != intType.getWidth())
461 return op->emitOpError(
"hw.constant attribute bitwidth "
462 "doesn't match return type");
463 }
else if (
auto typedAttr = dyn_cast<TypedAttr>(attr)) {
464 if (typedAttr.getType() != type)
465 return op->emitOpError(
"typed attr doesn't match the return type ")
468 return op->emitOpError(
"unknown element type ") << type;
477 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) {
return getFieldsAttr(); }
485 if (p.parseType(resultType) || p.parseEqual() ||
486 p.parseAttribute(value, resultType))
493 p << resultType <<
" = ";
494 p.printAttributeWithoutType(value);
503 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
504 assert(adaptor.getOperands().empty() &&
"hw.param.value has no operands");
505 return getValueAttr();
514 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
520 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
521 .Case<InstanceOp, InstanceChoiceOp>([](
auto instance) {
522 SmallVector<Type> inputs(instance->getOperandTypes());
523 SmallVector<Type> results(instance->getResultTypes());
527 [](
auto mod) {
return mod.getHWModuleType().getFuncType(); })
528 .Default([](Operation *op) {
529 return cast<FunctionType>(
530 cast<mlir::FunctionOpInterface>(op).getFunctionType());
538 auto nameAttr = module->getAttrOfType<StringAttr>(
"verilogName");
542 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
545 template <
typename ModuleTy>
547 buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
549 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
550 using namespace mlir::function_interface_impl;
553 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
555 SmallVector<Attribute> perPortAttrs;
556 SmallVector<ModulePort> portTypes;
558 for (
auto elt : ports) {
559 portTypes.push_back(elt);
560 llvm::SmallVector<NamedAttribute> portAttrs;
562 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
563 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
568 parameters = builder.getArrayAttr({});
572 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
574 result.addAttribute(
"per_port_attrs",
576 result.addAttribute(
"parameters", parameters);
578 comment = builder.getStringAttr(
"");
579 result.addAttribute(
"comment", comment);
580 result.addAttributes(attributes);
586 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
587 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
588 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
589 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
590 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
591 SmallVector<Location> &newArgLocs, Block *body =
nullptr) {
596 assert(llvm::is_sorted(insertArgs,
597 [](
auto &a,
auto &b) {
return a.first < b.first; }) &&
598 "insertArgs must be in ascending order");
599 assert(llvm::is_sorted(removeArgs, [](
auto &a,
auto &b) {
return a < b; }) &&
600 "removeArgs must be in ascending order");
603 auto oldArgCount = oldArgTypes.size();
604 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
605 assert((
int)newArgCount >= 0);
607 newArgNames.reserve(newArgCount);
608 newArgTypes.reserve(newArgCount);
609 newArgAttrs.reserve(newArgCount);
610 newArgLocs.reserve(newArgCount);
616 BitVector erasedIndices;
618 erasedIndices.resize(oldArgCount + insertArgs.size());
620 for (
unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
622 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
623 auto port = insertArgs[0].second;
625 !isa<InOutType>(port.type))
627 auto sym = port.getSym();
629 (sym && !sym.empty())
632 newArgNames.push_back(port.name);
633 newArgTypes.push_back(port.type);
634 newArgAttrs.push_back(attr);
635 insertArgs = insertArgs.drop_front();
636 LocationAttr loc = port.loc ? port.loc : unknownLoc;
637 newArgLocs.push_back(loc);
639 body->insertArgument(idx++, port.type, loc);
641 if (argIdx == oldArgCount)
645 bool removed =
false;
646 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
647 removeArgs = removeArgs.drop_front();
653 erasedIndices.set(idx);
655 newArgNames.push_back(oldArgNames[argIdx]);
656 newArgTypes.push_back(oldArgTypes[argIdx]);
657 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
658 : oldArgAttrs[argIdx]);
659 newArgLocs.push_back(oldArgLocs[argIdx]);
664 body->eraseArguments(erasedIndices);
666 assert(newArgNames.size() == newArgCount);
667 assert(newArgTypes.size() == newArgCount);
668 assert(newArgAttrs.size() == newArgCount);
669 assert(newArgLocs.size() == newArgCount);
683 [[deprecated]]
static void
685 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
686 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
687 ArrayRef<unsigned> removeInputs,
688 ArrayRef<unsigned> removeOutputs, Block *body =
nullptr) {
689 auto moduleOp = cast<HWModuleLike>(op);
690 auto *context = moduleOp.getContext();
693 auto oldArgNames = moduleOp.getInputNames();
694 auto oldArgTypes = moduleOp.getInputTypes();
695 auto oldArgAttrs = moduleOp.getAllInputAttrs();
696 auto oldArgLocs = moduleOp.getInputLocs();
698 auto oldResultNames = moduleOp.getOutputNames();
699 auto oldResultTypes = moduleOp.getOutputTypes();
700 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
701 auto oldResultLocs = moduleOp.getOutputLocs();
704 SmallVector<Attribute> newArgNames, newResultNames;
705 SmallVector<Type> newArgTypes, newResultTypes;
706 SmallVector<Attribute> newArgAttrs, newResultAttrs;
707 SmallVector<Location> newArgLocs, newResultLocs;
710 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
711 newArgTypes, newArgAttrs, newArgLocs, body);
714 oldResultTypes, oldResultAttrs, oldResultLocs,
715 newResultNames, newResultTypes, newResultAttrs,
721 moduleOp.setHWModuleType(modty);
722 moduleOp.setAllInputAttrs(newArgAttrs);
723 moduleOp.setAllOutputAttrs(newResultAttrs);
725 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
726 moduleOp.setAllPortLocs(newArgLocs);
729 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
731 ArrayAttr parameters,
732 ArrayRef<NamedAttribute> attributes, StringAttr comment,
733 bool shouldEnsureTerminator) {
734 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
738 auto *bodyRegion = result.regions[0].get();
740 bodyRegion->push_back(body);
743 auto unknownLoc = builder.getUnknownLoc();
745 auto loc = port.loc ? Location(port.loc) : unknownLoc;
746 auto type = port.type;
747 if (port.isInOut() && !isa<InOutType>(type))
749 body->addArgument(type, loc);
753 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
754 SmallVector<Attribute> resultLocs;
756 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
757 result.addAttribute(
"result_locs", builder.getArrayAttr(resultLocs));
759 if (shouldEnsureTerminator)
760 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
763 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
764 StringAttr name, ArrayRef<PortInfo> ports,
765 ArrayAttr parameters,
766 ArrayRef<NamedAttribute> attributes,
767 StringAttr comment) {
768 build(builder, result, name,
ModulePortInfo(ports), parameters, attributes,
772 void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
775 ArrayRef<NamedAttribute> attributes,
776 StringAttr comment) {
777 build(builder, odsState, name, ports, parameters, attributes, comment,
779 auto *bodyRegion = odsState.regions[0].get();
780 OpBuilder::InsertionGuard guard(builder);
782 builder.setInsertionPointToEnd(&bodyRegion->front());
783 modBuilder(builder, accessor);
785 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
786 builder.create<hw::OutputOp>(odsState.location, outputOperands);
789 void HWModuleOp::modifyPorts(
790 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
791 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
792 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
801 if (
auto vName = getVerilogNameAttr())
804 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
808 if (
auto vName = getVerilogNameAttr()) {
811 return (*this)->getAttrOfType<StringAttr>(
812 ::mlir::SymbolTable::getSymbolAttrName());
815 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
817 StringRef verilogName, ArrayAttr parameters,
818 ArrayRef<NamedAttribute> attributes) {
819 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
823 LocationAttr unknownLoc = builder.getUnknownLoc();
824 SmallVector<Attribute> portLocs;
825 for (
auto elt : ports)
826 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
827 result.addAttribute(
"port_locs", builder.getArrayAttr(portLocs));
829 if (!verilogName.empty())
830 result.addAttribute(
"verilogName", builder.getStringAttr(verilogName));
833 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
834 StringAttr name, ArrayRef<PortInfo> ports,
835 StringRef verilogName, ArrayAttr parameters,
836 ArrayRef<NamedAttribute> attributes) {
837 build(builder, result, name,
ModulePortInfo(ports), verilogName, parameters,
841 void HWModuleExternOp::modifyPorts(
842 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
843 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
844 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
849 void HWModuleExternOp::appendOutputs(
850 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
852 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
853 FlatSymbolRefAttr genKind, StringAttr name,
855 StringRef verilogName, ArrayAttr parameters,
856 ArrayRef<NamedAttribute> attributes) {
857 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
860 LocationAttr unknownLoc = builder.getUnknownLoc();
861 SmallVector<Attribute> portLocs;
862 for (
auto elt : ports)
863 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
864 result.addAttribute(
"port_locs", builder.getArrayAttr(portLocs));
866 result.addAttribute(
"generatorKind", genKind);
867 if (!verilogName.empty())
868 result.addAttribute(
"verilogName", builder.getStringAttr(verilogName));
871 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
872 FlatSymbolRefAttr genKind, StringAttr name,
873 ArrayRef<PortInfo> ports, StringRef verilogName,
874 ArrayAttr parameters,
875 ArrayRef<NamedAttribute> attributes) {
876 build(builder, result, genKind, name,
ModulePortInfo(ports), verilogName,
877 parameters, attributes);
880 void HWModuleGeneratedOp::modifyPorts(
881 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
882 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
883 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
888 void HWModuleGeneratedOp::appendOutputs(
889 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
891 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
892 for (
auto &argAttr : attrs)
893 if (argAttr.getName() == name)
898 template <
typename ModuleTy>
900 OperationState &result) {
902 using namespace mlir::function_interface_impl;
903 auto builder = parser.getBuilder();
904 auto loc = parser.getCurrentLocation();
907 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
911 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
916 FlatSymbolRefAttr kindAttr;
917 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
918 if (parser.parseComma() ||
919 parser.parseAttribute(kindAttr,
"generatorKind", result.attributes)) {
925 ArrayAttr parameters;
929 SmallVector<module_like_impl::PortParse> ports;
935 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
939 parser.emitError(loc,
"explicit `parameters` attributes not allowed");
943 result.addAttribute(
"parameters", parameters);
944 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
948 SmallVector<Attribute> attrs;
949 for (
auto &port : ports)
950 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
952 auto nonEmptyAttrsFn = [](Attribute attr) {
953 return attr && !cast<DictionaryAttr>(attr).empty();
955 if (llvm::any_of(attrs, nonEmptyAttrsFn))
956 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
957 builder.getArrayAttr(attrs));
960 auto unknownLoc = builder.getUnknownLoc();
961 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
962 return attr && cast<Location>(attr) != unknownLoc;
964 SmallVector<Attribute> locs;
965 StringAttr portLocsAttrName;
966 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
969 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
970 for (
auto &port : ports)
972 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
975 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
976 for (
auto &port : ports)
977 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
979 if (llvm::any_of(locs, nonEmptyLocsFn))
980 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
983 SmallVector<OpAsmParser::Argument, 4> entryArgs;
984 for (
auto &port : ports)
986 entryArgs.push_back(port);
989 auto *body = result.addRegion();
990 if (std::is_same_v<ModuleTy, HWModuleOp>) {
991 if (parser.parseRegion(*body, entryArgs))
994 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
999 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
1000 return parseHWModuleOp<HWModuleOp>(parser, result);
1003 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1004 OperationState &result) {
1005 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1008 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1009 OperationState &result) {
1010 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1014 if (
auto mod = dyn_cast<HWModuleLike>(op))
1015 return mod.getHWModuleType().getFuncType();
1016 return cast<FunctionType>(
1017 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1020 template <
typename ModuleTy>
1024 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1025 if (
auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1026 visibilityAttrName))
1027 p << visibility.getValue() <<
' ';
1030 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1031 if (
auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1033 p.printSymbolName(gen.getGeneratorKind());
1041 SmallVector<StringRef, 3> omittedAttrs;
1042 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1043 omittedAttrs.push_back(
"generatorKind");
1044 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1045 omittedAttrs.push_back(mod.getResultLocsAttrName());
1047 omittedAttrs.push_back(mod.getPortLocsAttrName());
1048 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1049 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1050 omittedAttrs.push_back(mod.getParametersAttrName());
1051 omittedAttrs.push_back(visibilityAttrName);
1053 mod.getOperation()->template getAttrOfType<StringAttr>(
"comment"))
1054 if (cmt.getValue().empty())
1055 omittedAttrs.push_back(
"comment");
1057 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1061 void HWModuleExternOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1062 void HWModuleGeneratedOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1064 void HWModuleOp::print(OpAsmPrinter &p) {
1068 Region &body = getBody();
1069 if (!body.empty()) {
1071 p.printRegion(body,
false,
1077 assert(isa<HWModuleLike>(module) &&
1078 "verifier hook should only be called on modules");
1080 SmallPtrSet<Attribute, 4> paramNames;
1083 for (
auto param : module->getAttrOfType<ArrayAttr>(
"parameters")) {
1084 auto paramAttr = cast<ParamDeclAttr>(param);
1088 if (!paramNames.insert(paramAttr.getName()).second)
1089 return module->emitOpError(
"parameter ")
1090 << paramAttr <<
" has the same name as a previous parameter";
1093 auto value = paramAttr.getValue();
1097 auto typedValue = dyn_cast<TypedAttr>(value);
1099 return module->emitOpError(
"parameter ")
1100 << paramAttr <<
" should have a typed value; has value " << value;
1102 if (typedValue.getType() != paramAttr.getType())
1103 return module->emitOpError(
"parameter ")
1104 << paramAttr <<
" should have type " << paramAttr.getType()
1105 <<
"; has type " << typedValue.getType();
// The entry block must declare one argument per module input. The trailing
// space in the message keeps the count from running into the word "have".
1125 auto numInputs = type.getNumInputs();
1126 if (body->getNumArguments() != numInputs)
1127 return emitOpError(
"entry block must have ")
1128 << numInputs <<
" arguments to match module signature";
1135 std::pair<StringAttr, BlockArgument>
1136 HWModuleOp::insertInput(
unsigned index, StringAttr name, Type ty) {
1140 for (
auto port : ports)
1141 ns.
newName(port.name.getValue());
1148 port.
name = nameAttr;
1155 return {nameAttr, body->getArgument(index)};
1158 void HWModuleOp::insertOutputs(
unsigned index,
1159 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1161 auto output = cast<OutputOp>(
getBodyBlock()->getTerminator());
1162 assert(index <= output->getNumOperands() &&
"invalid output index");
1165 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1166 for (
auto &[name, value] : outputs) {
1170 port.
type = value.getType();
1171 indexedNewPorts.emplace_back(index, port);
1177 for (
auto &[name, value] : outputs)
1178 output->insertOperands(index++, value);
1181 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1182 return insertOutputs(getNumOutputPorts(), outputs);
// Asm-printer hook (OpAsmOpInterface) that names the block arguments of `region`.
1185 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
1190 void HWModuleExternOp::getAsmBlockArgumentNames(
1195 template <
typename ModTy>
1197 auto locs = module.getPortLocs();
1199 SmallVector<Location> retval;
1200 retval.reserve(locs->size());
1201 for (
auto l : *locs)
1202 retval.push_back(cast<Location>(l));
1204 assert(!locs->size() || locs->size() == module.getNumPorts());
1207 return SmallVector<Location>(module.getNumPorts(),
1212 SmallVector<Location> portLocs;
1214 auto resultLocs = getResultLocsAttr();
1215 unsigned inputCount = 0;
1219 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1220 if (modType.isOutput(i)) {
1221 auto loc = resultLocs
1223 resultLocs.getValue()[portLocs.size() - inputCount])
1225 portLocs.push_back(loc);
1227 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1228 portLocs.push_back(loc);
1243 void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1244 SmallVector<Attribute> resultLocs;
1245 unsigned inputCount = 0;
1248 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1249 if (modType.isOutput(i))
1250 resultLocs.push_back(locs[i]);
1252 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1257 void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1261 void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1265 template <
typename ModTy>
1267 auto numInputs = module.getNumInputPorts();
1268 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1269 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1270 auto oldType = module.getModuleType();
1271 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1272 oldType.getPorts().end());
1273 for (
size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1274 newPorts[i].name = cast<StringAttr>(names[i]);
1276 module.setModuleType(newType);
1291 ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1292 auto attrs = getPerPortAttrs();
1293 if (attrs && !attrs->empty())
1294 return attrs->getValue();
1298 ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1299 auto attrs = getPerPortAttrs();
1300 if (attrs && !attrs->empty())
1301 return attrs->getValue();
1305 ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1306 auto attrs = getPerPortAttrs();
1307 if (attrs && !attrs->empty())
1308 return attrs->getValue();
1312 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1313 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1316 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1317 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1320 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1321 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1324 void HWModuleOp::removeAllPortAttrs() {
1328 void HWModuleExternOp::removeAllPortAttrs() {
1332 void HWModuleGeneratedOp::removeAllPortAttrs() {
1338 template <
typename ModTy>
1340 auto argAttrs = mod.getAllInputAttrs();
1341 auto resAttrs = mod.getAllOutputAttrs();
1343 unsigned newNumArgs = type.getNumInputs();
1344 unsigned newNumResults = type.getNumOutputs();
1347 argAttrs.resize(newNumArgs, emptyDict);
1348 resAttrs.resize(newNumResults, emptyDict);
1350 SmallVector<Attribute> attrs;
1351 attrs.append(argAttrs.begin(), argAttrs.end());
1352 attrs.append(resAttrs.begin(), resAttrs.end());
1355 return mod.removeAllPortAttrs();
1356 mod.setAllPortAttrs(attrs);
1373 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1374 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1375 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1379 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1380 auto *referencedKind =
1381 symbolTable.lookupNearestSymbolFrom(*
this, getGeneratorKindAttr());
1383 if (referencedKind ==
nullptr)
1384 return emitError(
"Cannot find generator definition '")
1385 << getGeneratorKind() <<
"'";
1387 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1388 return emitError(
"Symbol resolved to '")
1389 << referencedKind->getName()
1390 <<
"' which is not a HWGeneratorSchemaOp";
1392 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1393 auto paramRef = referencedKindOp.getRequiredAttrs();
1394 auto dict = (*this)->getAttrDictionary();
1395 for (
auto str : paramRef) {
1396 auto strAttr = dyn_cast<StringAttr>(str);
1398 return emitError(
"Unknown attribute type, expected a string");
1399 if (!dict.get(strAttr.getValue()))
1400 return emitError(
"Missing attribute '") << strAttr.getValue() <<
"'";
1410 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1415 LogicalResult HWModuleOp::verifyBody() {
return success(); }
1417 template <
typename ModuleTy>
1419 auto modTy = mod.getHWModuleType();
1421 SmallVector<PortInfo> retval;
1422 auto locs = mod.getAllPortLocs();
1423 for (
unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1424 LocationAttr loc = locs[i];
1425 DictionaryAttr attrs =
1426 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1429 retval.push_back({modTy.getPorts()[i],
1430 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1431 : modTy.getInputIdForPortId(i),
1437 template <
typename ModuleTy>
1439 auto modTy = mod.getHWModuleType();
1441 LocationAttr loc = mod.getPortLoc(idx);
1442 DictionaryAttr attrs =
1443 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1446 return {modTy.getPorts()[idx],
1447 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1448 : modTy.getInputIdForPortId(idx),
// Builds an instance of `module` named `name` with operands `inputs`.
// Argument/result names are taken from the module's input/output names. The
// result types come from the module type, resolved against `parameters` when
// that succeeds.
1457 void InstanceOp::build(OpBuilder &builder, OperationState &result,
1458 Operation *module, StringAttr name,
1459 ArrayRef<Value> inputs, ArrayAttr parameters,
1460 InnerSymAttr innerSym) {
1462 parameters = builder.getArrayAttr({});
1464 auto mod = cast<hw::HWModuleLike>(module);
1465 auto argNames = builder.getArrayAttr(mod.getInputNames());
1466 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
// Try to resolve parametric port types against `parameters`. If resolution
// fails, keep the module's own (unresolved) type in `modType`.
1471 ModuleType modType = mod.getHWModuleType();
1472 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1473 parameters, result.location,
false);
1474 if (succeeded(resolvedModType))
1475 modType = *resolvedModType;
// Read from `modType`, which always holds a valid type. `resolvedModType` is
// empty when resolution failed, so dereferencing it would be undefined
// behavior and would bypass the fallback chosen above.
1476 FunctionType funcType = modType.getFuncType();
1477 build(builder, result, funcType.getResults(), name,
1479 argNames, resultNames, parameters, innerSym);
1482 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1484 return std::nullopt;
1487 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1489 *
this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1490 getResultNames(), getParameters(), symbolTable);
1494 auto module = (*this)->getParentOfType<
HWModuleOp>();
1498 auto moduleParameters = module->getAttrOfType<ArrayAttr>(
"parameters");
1500 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1501 auto diag = emitOpError();
1503 diag.attachNote(module->getLoc()) <<
"module declared here";
1506 getParameters(), moduleParameters, emitError);
1509 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1510 StringAttr instanceNameAttr;
1511 InnerSymAttr innerSym;
1512 FlatSymbolRefAttr moduleNameAttr;
1513 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1514 SmallVector<Type, 1> inputsTypes, allResultTypes;
1515 ArrayAttr argNames, resultNames, parameters;
1516 auto noneType = parser.getBuilder().getType<NoneType>();
1518 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
1522 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1525 if (parser.parseCustomAttributeWithFallback(innerSym))
1530 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1531 if (parser.parseAttribute(moduleNameAttr, noneType,
"moduleName",
1532 result.attributes) ||
1533 parser.getCurrentLocation(¶metersLoc) ||
1536 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1538 parser.parseArrow() ||
1540 parser.parseOptionalAttrDict(result.attributes)) {
1544 result.addAttribute(
"argNames", argNames);
1545 result.addAttribute(
"resultNames", resultNames);
1546 result.addAttribute(
"parameters", parameters);
1547 result.addTypes(allResultTypes);
1551 void InstanceOp::print(OpAsmPrinter &p) {
1553 p.printAttributeWithoutType(getInstanceNameAttr());
1554 if (
auto attr = getInnerSymAttr()) {
1559 p.printAttributeWithoutType(getModuleNameAttr());
1566 p.printOptionalAttrDict(
1567 (*this)->getAttrs(),
1569 InnerSymbolTable::getInnerSymbolAttrName(),
"moduleName",
1570 "argNames",
"resultNames",
"parameters"});
1577 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1579 return std::nullopt;
1583 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1584 for (Attribute name : getModuleNamesAttr()) {
1586 *
this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1587 getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1595 auto module = (*this)->getParentOfType<
HWModuleOp>();
1599 auto moduleParameters = module->getAttrOfType<ArrayAttr>(
"parameters");
1601 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1602 auto diag = emitOpError();
1604 diag.attachNote(module->getLoc()) <<
"module declared here";
1607 getParameters(), moduleParameters, emitError);
1610 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1611 OperationState &result) {
1612 StringAttr optionNameAttr;
1613 StringAttr instanceNameAttr;
1614 InnerSymAttr innerSym;
1615 SmallVector<Attribute> moduleNames;
1616 SmallVector<Attribute> caseNames;
1617 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1618 SmallVector<Type, 1> inputsTypes, allResultTypes;
1619 ArrayAttr argNames, resultNames, parameters;
1620 auto noneType = parser.getBuilder().getType<NoneType>();
1622 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
1626 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1629 if (parser.parseCustomAttributeWithFallback(innerSym))
1634 if (parser.parseKeyword(
"option") ||
1635 parser.parseAttribute(optionNameAttr, noneType,
"optionName",
1639 FlatSymbolRefAttr defaultModuleName;
1640 if (parser.parseAttribute(defaultModuleName))
1642 moduleNames.push_back(defaultModuleName);
1644 while (succeeded(parser.parseOptionalKeyword(
"or"))) {
1645 FlatSymbolRefAttr moduleName;
1646 StringAttr targetName;
1647 if (parser.parseAttribute(moduleName) ||
1648 parser.parseOptionalKeyword(
"if") || parser.parseAttribute(targetName))
1650 moduleNames.push_back(moduleName);
1651 caseNames.push_back(targetName);
1654 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1655 if (parser.getCurrentLocation(¶metersLoc) ||
1658 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1660 parser.parseArrow() ||
1662 parser.parseOptionalAttrDict(result.attributes)) {
1666 result.addAttribute(
"moduleNames",
1668 result.addAttribute(
"caseNames",
1670 result.addAttribute(
"argNames", argNames);
1671 result.addAttribute(
"resultNames", resultNames);
1672 result.addAttribute(
"parameters", parameters);
1673 result.addTypes(allResultTypes);
1677 void InstanceChoiceOp::print(OpAsmPrinter &p) {
1679 p.printAttributeWithoutType(getInstanceNameAttr());
1680 if (
auto attr = getInnerSymAttr()) {
1684 p <<
" option " << getOptionNameAttr() <<
' ';
1686 auto moduleNames = getModuleNamesAttr();
1687 auto caseNames = getCaseNamesAttr();
1688 assert(moduleNames.size() == caseNames.size() + 1);
1690 p.printAttributeWithoutType(moduleNames[0]);
1691 for (
size_t i = 0, n = caseNames.size(); i < n; ++i) {
1693 p.printAttributeWithoutType(moduleNames[i + 1]);
1695 p.printAttributeWithoutType(caseNames[i]);
1704 p.printOptionalAttrDict(
1705 (*this)->getAttrs(),
1707 InnerSymbolTable::getInnerSymbolAttrName(),
1708 "moduleNames",
"caseNames",
"argNames",
"resultNames",
1709 "parameters",
"optionName"});
1712 ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1713 SmallVector<Attribute> moduleNames;
1714 for (Attribute attr : getModuleNamesAttr()) {
1715 moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).getAttr());
1729 if (
auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1730 modType = mod.getHWModuleType();
1732 emitOpError(
"must have a module parent");
1735 auto modResults = modType.getOutputTypes();
1736 OperandRange outputValues = getOperands();
1737 if (modResults.size() != outputValues.size()) {
1738 emitOpError(
"must have same number of operands as region results.");
1743 for (
size_t i = 0, e = modResults.size(); i < e; ++i) {
1744 if (modResults[i] != outputValues[i].getType()) {
1745 emitOpError(
"output types must match module. In "
1747 << i <<
", expected " << modResults[i] <<
", but got "
1748 << outputValues[i].getType() <<
".";
1763 if (p.parseType(type))
1764 return p.emitError(p.getCurrentLocation(),
"Expected type");
1765 auto arrType = type_dyn_cast<ArrayType>(type);
1767 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
1769 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1776 p.printType(srcType);
1779 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1780 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1781 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1784 if (parser.parseOperandList(operands) ||
1785 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1786 parser.parseType(elemType))
1789 if (operands.size() == 0)
1790 return parser.emitError(inputOperandsLoc,
1791 "Cannot construct an array of length 0");
1794 for (
auto operand : operands)
1795 if (parser.resolveOperand(operand, elemType, result.operands))
1800 void ArrayCreateOp::print(OpAsmPrinter &p) {
1802 p.printOperands(getInputs());
1803 p.printOptionalAttrDict((*this)->getAttrs());
1804 p <<
" : " << getInputs()[0].getType();
1807 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1808 ValueRange values) {
1809 assert(values.size() > 0 &&
"Cannot build array of zero elements");
1810 Type elemType = values[0].getType();
1813 [elemType](Value v) ->
bool {
return v.getType() == elemType; }) &&
1814 "All values must have same type.");
1815 build(b, state,
ArrayType::get(elemType, values.size()), values);
1819 unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1820 if (getInputs().size() != returnSize)
1825 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1826 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1838 auto baseValue = constBase.getValue();
1839 auto indexValue = constIndex.getValue();
1841 unsigned bits = baseValue.getBitWidth();
1842 assert(bits == indexValue.getBitWidth() &&
"mismatched widths");
1844 if (bits < 64 && offset >= (1ull << bits))
1847 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1848 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1849 return baseExt + offset == indexExt;
1857 PatternRewriter &rewriter) {
1859 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1860 if (arrayTy.getNumElements() <= 1)
1862 auto elemTy = arrayTy.getElementType();
1871 SmallVector<Chunk> chunks;
1872 for (Value value : llvm::reverse(op.getInputs())) {
1877 Value input =
get.getInput();
1878 Value index =
get.getIndex();
1879 if (!chunks.empty()) {
1880 auto &c = *chunks.rbegin();
1881 if (c.input ==
get.getInput() &&
isOffset(c.index, index, c.size)) {
1887 chunks.push_back(Chunk{input, index, 1});
1891 if (chunks.size() == 1) {
1892 auto &chunk = chunks[0];
1893 rewriter.replaceOp(op, rewriter.createOrFold<
ArraySliceOp>(
1894 op.getLoc(), arrayTy, chunk.input, chunk.index));
1900 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1901 SmallVector<Value> slices;
1902 for (
auto &chunk : llvm::reverse(chunks)) {
1905 op.getLoc(), sliceTy, chunk.input, chunk.index));
1907 rewriter.replaceOpWithNewOp<
ArrayConcatOp>(op, arrayTy, slices);
1915 PatternRewriter &rewriter) {
1921 Value ArrayCreateOp::getUniformElement() {
1922 if (!getInputs().
empty() && llvm::all_equal(getInputs()))
1923 return getInputs()[0];
1928 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1930 return std::nullopt;
1931 APInt idxAttr = idxOp.getValue();
1932 if (idxAttr.getBitWidth() > 64)
1933 return std::nullopt;
1934 return idxAttr.getLimitedValue();
1938 unsigned inputSize =
1939 type_cast<ArrayType>(getInput().getType()).getNumElements();
1940 if (llvm::Log2_64_Ceil(inputSize) !=
1941 getLowIndex().getType().getIntOrFloatBitWidth())
1943 "ArraySlice: index width must match clog2 of array size");
1947 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1949 if (getType() == getInput().getType())
1955 PatternRewriter &rewriter) {
1956 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1957 auto elemTy = sliceTy.getElementType();
1958 uint64_t sliceSize = sliceTy.getNumElements();
1962 if (sliceSize == 1) {
1964 auto get = rewriter.create<
ArrayGetOp>(op.getLoc(), op.getInput(),
1966 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
1975 auto inputOp = op.getInput().getDefiningOp();
1976 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1978 if (inputSlice == op)
1981 auto inputIndex = inputSlice.getLowIndex();
1983 if (!inputOffsetOpt)
1986 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1988 rewriter.create<
ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1989 rewriter.replaceOpWithNewOp<
ArraySliceOp>(op, op.getType(),
1990 inputSlice.getInput(), lowIndex);
1994 if (
auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1996 auto inputs = inputCreate.getInputs();
1998 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
1999 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
2000 inputs.slice(begin, sliceSize));
2004 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2006 SmallVector<Value> chunks;
2007 uint64_t sliceStart = *offsetOpt;
2008 for (
auto input : llvm::reverse(inputConcat.getInputs())) {
2010 uint64_t inputSize =
2011 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2012 if (inputSize == 0 || inputSize <= sliceStart) {
2013 sliceStart -= inputSize;
2018 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2019 uint64_t cutSize = cutEnd - sliceStart;
2020 assert(cutSize != 0 &&
"slice cannot be empty");
2022 if (cutSize == inputSize) {
2024 assert(sliceStart == 0 &&
"invalid cut size");
2025 chunks.push_back(input);
2028 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2030 op.getLoc(), rewriter.getIntegerType(
width), sliceStart);
2036 sliceSize -= cutSize;
2041 assert(chunks.size() > 0 &&
"missing sliced items");
2042 if (chunks.size() == 1)
2043 rewriter.replaceOp(op, chunks[0]);
2046 op, llvm::to_vector(llvm::reverse(chunks)));
2057 SmallVectorImpl<Type> &inputTypes,
2060 uint64_t resultSize = 0;
2062 auto parseElement = [&]() -> ParseResult {
2064 if (p.parseType(ty))
2066 auto arrTy = type_dyn_cast<ArrayType>(ty);
2068 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
2069 if (elemType && elemType != arrTy.getElementType())
2070 return p.emitError(p.getCurrentLocation(),
"Expected array element type ")
2073 elemType = arrTy.getElementType();
2074 inputTypes.push_back(ty);
2075 resultSize += arrTy.getNumElements();
2079 if (p.parseCommaSeparatedList(parseElement))
2087 TypeRange inputTypes, Type resultType) {
2088 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2091 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2092 ValueRange values) {
2093 assert(!values.empty() &&
"Cannot build array of zero elements");
2094 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2095 Type elemTy = arrayTy.getElementType();
2096 assert(llvm::all_of(values,
2097 [elemTy](Value v) ->
bool {
2098 return isa<ArrayType>(v.getType()) &&
2099 cast<ArrayType>(v.getType()).getElementType() ==
2102 "All values must be of ArrayType with the same element type.");
2104 uint64_t resultSize = 0;
2105 for (Value val : values)
2106 resultSize += cast<ArrayType>(val.getType()).getNumElements();
2110 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2111 auto inputs = adaptor.getInputs();
2112 SmallVector<Attribute> array;
2113 for (
size_t i = 0, e = getNumOperands(); i < e; ++i) {
2116 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2123 for (
auto input : op.getInputs())
2127 SmallVector<Value> items;
2128 for (
auto input : op.getInputs()) {
2129 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2130 for (
auto item : create.getInputs())
2131 items.push_back(item);
2145 SmallVector<Location> locs;
2148 SmallVector<Value> items;
2149 std::optional<Slice> last;
2150 bool changed =
false;
2152 auto concatenate = [&] {
2157 items.push_back(last->op);
2166 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2167 auto arrayTy =
ArrayType::get(origTy.getElementType(), last->size);
2169 loc, arrayTy, last->input, last->index));
2174 auto append = [&](Value op, Value input, Value index,
size_t size) {
2179 if (last->input == input &&
isOffset(last->index, index, last->size)) {
2182 last->locs.push_back(op.getLoc());
2187 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2190 for (
auto item : llvm::reverse(op.getInputs())) {
2192 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2193 append(item, slice.getInput(), slice.getLowIndex(), size);
2198 if (create.getInputs().size() == 1) {
2199 if (
auto get = create.getInputs()[0].getDefiningOp<
ArrayGetOp>()) {
2207 items.push_back(item);
2214 if (items.size() == 1) {
2215 rewriter.replaceOp(op, items[0]);
2217 std::reverse(items.begin(), items.end());
2224 PatternRewriter &rewriter) {
2240 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2247 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2248 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2257 result.addAttribute(
"field", fieldAttr);
2258 result.addTypes(type);
2263 void EnumConstantOp::print(OpAsmPrinter &p) {
2264 p <<
" " << getField().getField().getValue() <<
" : "
2265 << getField().getType().getValue();
2269 function_ref<
void(Value, StringRef)> setNameFn) {
2270 setNameFn(getResult(), getField().getField().str());
2273 void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2274 EnumFieldAttr field) {
2275 return build(builder, odsState, field.getType().getValue(), field);
2278 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2279 assert(adaptor.getOperands().empty() &&
"constant has no operands");
2280 return getFieldAttr();
2284 auto fieldAttr = getFieldAttr();
2285 auto fieldType = fieldAttr.getType().getValue();
2288 if (fieldType != getType())
2289 emitOpError(
"return type ")
2290 << getType() <<
" does not match attribute type " << fieldAttr;
2300 auto lhsType = type_cast<EnumType>(getLhs().getType());
2301 auto rhsType = type_cast<EnumType>(getRhs().getType());
2302 if (rhsType != lhsType)
2303 emitOpError(
"types do not match");
2311 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2312 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2313 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2314 Type declOrAliasType;
2316 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2317 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2318 parser.parseColonType(declOrAliasType))
2321 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2323 return parser.emitError(parser.getNameLoc(),
2324 "expected !hw.struct type or alias");
2326 llvm::SmallVector<Type, 4> structInnerTypes;
2327 declType.getInnerTypes(structInnerTypes);
2328 result.addTypes(declOrAliasType);
2330 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2336 void StructCreateOp::print(OpAsmPrinter &printer) {
2338 printer.printOperands(getInput());
2340 printer.printOptionalAttrDict((*this)->getAttrs());
2341 printer <<
" : " << getType();
2345 auto elements = hw::type_cast<StructType>(getType()).getElements();
2347 if (elements.size() != getInput().size())
2348 return emitOpError(
"structure field count mismatch");
2350 for (
const auto &[field, value] : llvm::zip(elements, getInput()))
2351 if (field.type != value.getType())
2352 return emitOpError(
"structure field `")
2353 << field.name <<
"` type does not match";
2358 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2360 if (!getInput().
empty())
2361 if (
auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2362 explodeOp && getInput() == explodeOp.getResults() &&
2363 getResult().getType() == explodeOp.getInput().getType())
2364 return explodeOp.getInput();
2366 auto inputs = adaptor.getInput();
2367 if (llvm::any_of(inputs, [](Attribute attr) {
return !attr; }))
2376 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2377 OperationState &result) {
2378 OpAsmParser::UnresolvedOperand operand;
2381 if (parser.parseOperand(operand) ||
2382 parser.parseOptionalAttrDict(result.attributes) ||
2383 parser.parseColonType(declType))
2385 auto structType = type_dyn_cast<StructType>(declType);
2387 return parser.emitError(parser.getNameLoc(),
2388 "invalid kind of type specified");
2390 llvm::SmallVector<Type, 4> structInnerTypes;
2391 structType.getInnerTypes(structInnerTypes);
2392 result.addTypes(structInnerTypes);
2394 if (parser.resolveOperand(operand, declType, result.operands))
2399 void StructExplodeOp::print(OpAsmPrinter &printer) {
2401 printer.printOperand(getInput());
2402 printer.printOptionalAttrDict((*this)->getAttrs());
2403 printer <<
" : " << getInput().getType();
2406 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2407 SmallVectorImpl<OpFoldResult> &results) {
2408 auto input = adaptor.getInput();
2411 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2416 PatternRewriter &rewriter) {
2417 auto *inputOp = op.getInput().getDefiningOp();
2418 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2419 auto result = failure();
2420 auto opResults = op.getResults();
2421 for (uint32_t index = 0; index < elements.size(); index++) {
2423 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2431 function_ref<
void(Value, StringRef)> setNameFn) {
2432 auto structType = type_cast<StructType>(getInput().getType());
2433 for (
auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2434 setNameFn(res, field.name.str());
2437 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2439 StructType inputType = dyn_cast<StructType>(input.getType());
2441 SmallVector<Type, 16> fieldTypes;
2442 for (
auto field : inputType.getElements())
2443 fieldTypes.push_back(field.type);
2444 build(odsBuilder, odsState, fieldTypes, input);
2453 template <
typename AggregateOp,
typename AggregateType>
2455 AggregateType aggType,
2457 auto index = op.getFieldIndex();
2458 if (index >= aggType.getElements().size())
2459 return op.emitOpError() <<
"field index " << index
2460 <<
" exceeds element count of aggregate type";
2464 return op.emitOpError()
2465 <<
"type " << aggType.getElements()[index].type
2466 <<
" of accessed field in aggregate at index " << index
2467 <<
" does not match expected type " <<
elementType;
2473 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2474 *
this, getInput().getType(), getType());
2479 template <
typename AggregateType>
2481 OpAsmParser::UnresolvedOperand operand;
2482 StringAttr fieldName;
2485 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2486 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2487 parser.parseOptionalAttrDict(result.attributes) ||
2488 parser.parseColonType(declType))
2490 auto aggType = type_dyn_cast<AggregateType>(declType);
2492 return parser.emitError(parser.getNameLoc(),
2493 "invalid kind of type specified");
2495 auto fieldIndex = aggType.getFieldIndex(fieldName);
2497 parser.emitError(parser.getNameLoc(),
"field name '" +
2498 fieldName.getValue() +
2499 "' not found in aggregate type");
2505 result.addAttribute(
"fieldIndex", indexAttr);
2506 Type resultType = aggType.getElements()[*fieldIndex].type;
2507 result.addTypes(resultType);
2509 if (parser.resolveOperand(operand, declType, result.operands))
2516 template <
typename AggType>
2519 printer.printOperand(op.getInput());
2520 printer <<
"[\"" << op.getFieldName() <<
"\"]";
2521 printer.printOptionalAttrDict(op->getAttrs(), {
"fieldIndex"});
2522 printer <<
" : " << op.getInput().getType();
2525 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2526 OperationState &result) {
2527 return parseExtractOp<StructType>(parser, result);
2530 void StructExtractOp::print(OpAsmPrinter &printer) {
2534 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2535 Value input, StructType::FieldInfo field) {
2537 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2538 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2539 build(builder, odsState, field.type, input, *fieldIndex);
2542 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2543 Value input, StringAttr fieldName) {
2544 auto structType = type_cast<StructType>(input.getType());
2545 auto fieldIndex = structType.getFieldIndex(fieldName);
2546 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2547 auto resultType = structType.getElements()[*fieldIndex].type;
2548 build(builder, odsState, resultType, input, *fieldIndex);
2551 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2552 if (
auto constOperand = adaptor.getInput()) {
2554 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2555 return operandAttr.getValue()[getFieldIndex()];
2558 if (
auto foldResult =
2565 PatternRewriter &rewriter) {
2566 auto inputOp = op.getInput().getDefiningOp();
2569 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2570 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2572 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2581 function_ref<
void(Value, StringRef)> setNameFn) {
2589 void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2590 Value input, StringAttr fieldName, Value newValue) {
2591 auto structType = type_cast<StructType>(input.getType());
2592 auto fieldIndex = structType.getFieldIndex(fieldName);
2593 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2594 build(builder, odsState, input, *fieldIndex, newValue);
2598 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2599 *
this, getInput().getType(), getNewValue().getType());
2602 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2603 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2604 OpAsmParser::UnresolvedOperand operand, val;
2605 StringAttr fieldName;
2608 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2609 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2610 parser.parseComma() || parser.parseOperand(val) ||
2611 parser.parseOptionalAttrDict(result.attributes) ||
2612 parser.parseColonType(declType))
2614 auto structType = type_dyn_cast<StructType>(declType);
2616 return parser.emitError(inputOperandsLoc,
"invalid kind of type specified");
2618 auto fieldIndex = structType.getFieldIndex(fieldName);
2620 parser.emitError(parser.getNameLoc(),
"field name '" +
2621 fieldName.getValue() +
2622 "' not found in aggregate type");
2628 result.addAttribute(
"fieldIndex", indexAttr);
2629 result.addTypes(declType);
2631 Type resultType = structType.getElements()[*fieldIndex].type;
2632 if (parser.resolveOperands({operand, val}, {declType, resultType},
2633 inputOperandsLoc, result.operands))
2638 void StructInjectOp::print(OpAsmPrinter &printer) {
2640 printer.printOperand(getInput());
2642 printer.printOperand(getNewValue());
2643 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2644 printer <<
" : " << getInput().getType();
2647 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2648 auto input = adaptor.getInput();
2649 auto newValue = adaptor.getNewValue();
2650 if (!input || !newValue)
2652 SmallVector<Attribute> array;
2653 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2654 array[getFieldIndex()] = newValue;
2659 PatternRewriter &rewriter) {
2661 SmallPtrSet<Operation *, 4> injects;
2662 DenseMap<StringAttr, Value> fields;
2665 StructInjectOp inject = op;
2668 if (!injects.insert(inject).second)
2671 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2672 input = inject.getInput();
2673 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2675 assert(input &&
"missing input to inject chain");
2677 auto ty = hw::type_cast<StructType>(op.getType());
2678 auto elements = ty.getElements();
2681 if (fields.size() == elements.size()) {
2682 SmallVector<Value> createFields;
2683 for (
const auto &field : elements) {
2684 auto it = fields.find(field.name);
2685 assert(it != fields.end() &&
"missing field");
2686 createFields.push_back(it->second);
2688 rewriter.replaceOpWithNewOp<
StructCreateOp>(op, ty, createFields);
2693 if (injects.size() == fields.size())
2697 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2698 auto it = fields.find(elements[fieldIndex].name);
2699 if (it == fields.end())
2701 input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2705 rewriter.replaceOp(op, input);
2714 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2715 *
this, getType(), getInput().getType());
2718 void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2719 Type unionType, StringAttr fieldName, Value input) {
2720 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2721 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2722 build(builder, odsState, unionType, *fieldIndex, input);
2725 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2726 Type declOrAliasType;
2727 StringAttr fieldName;
2728 OpAsmParser::UnresolvedOperand input;
2729 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2731 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2732 parser.parseOperand(input) ||
2733 parser.parseOptionalAttrDict(result.attributes) ||
2734 parser.parseColonType(declOrAliasType))
2737 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2739 return parser.emitError(parser.getNameLoc(),
2740 "expected !hw.union type or alias");
2742 auto fieldIndex = declType.getFieldIndex(fieldName);
2744 parser.emitError(fieldLoc,
"cannot find union field '")
2745 << fieldName.getValue() <<
'\'';
2751 result.addAttribute(
"fieldIndex", indexAttr);
2752 Type inputType = declType.getElements()[*fieldIndex].type;
2754 if (parser.resolveOperand(input, inputType, result.operands))
2756 result.addTypes({declOrAliasType});
2760 void UnionCreateOp::print(OpAsmPrinter &printer) {
2762 printer.printOperand(getInput());
2763 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2764 printer <<
" : " << getType();
2771 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2772 return parseExtractOp<UnionType>(parser, result);
2775 void UnionExtractOp::print(OpAsmPrinter &printer) {
2779 LogicalResult UnionExtractOp::inferReturnTypes(
2780 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2781 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2782 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2783 Adaptor adaptor(operands, attrs, properties, regions);
2784 auto unionElements =
2785 hw::type_cast<UnionType>((adaptor.getInput().getType())).getElements();
2786 unsigned fieldIndex = adaptor.getFieldIndexAttr().getValue().getZExtValue();
2787 if (fieldIndex >= unionElements.size()) {
2789 mlir::emitError(*loc,
"field index " + Twine(fieldIndex) +
2790 " exceeds element count of aggregate type");
2793 results.push_back(unionElements[fieldIndex].type);
2797 void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2798 Value input, StringAttr fieldName) {
2799 auto unionType = type_cast<UnionType>(input.getType());
2800 auto fieldIndex = unionType.getFieldIndex(fieldName);
2801 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2802 auto resultType = unionType.getElements()[*fieldIndex].type;
2803 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2815 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2816 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2817 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2822 auto indexVal = indexCst.getValue();
2823 if (indexVal.getBitWidth() < 64) {
2824 auto index = indexVal.getZExtValue();
2825 return inputCst[inputCst.size() - 1 - index];
2830 if (!inputCst.empty() && llvm::all_equal(inputCst))
2835 if (
auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2836 auto intTy = dyn_cast<IntegerType>(getType());
2839 auto bitcastInputOp = bitcast.getInput().getDefiningOp<
hw::ConstantOp>();
2840 if (!bitcastInputOp)
2844 auto bitcastInputCst = bitcastInputOp.getValue();
2847 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2848 getType().getIntOrFloatBitWidth();
2851 intTy.getIntOrFloatBitWidth()));
2854 auto inputCreate = getInput().getDefiningOp<
ArrayCreateOp>();
2858 if (
auto uniformValue = inputCreate.getUniformElement())
2859 return uniformValue;
2861 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2864 uint64_t index = indexCst.getValue().getLimitedValue();
2865 auto createInputs = inputCreate.getInputs();
2866 if (index >= createInputs.size())
2868 return createInputs[createInputs.size() - index - 1];
2872 PatternRewriter &rewriter) {
2877 auto *inputOp = op.getInput().getDefiningOp();
2878 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2880 auto offsetOp = inputSlice.getLowIndex();
2885 uint64_t offset = *offsetOpt + *idxOpt;
2887 rewriter.create<
ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2888 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, inputSlice.getInput(),
2893 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2895 uint64_t elemIndex = *idxOpt;
2896 for (
auto input : llvm::reverse(inputConcat.getInputs())) {
2897 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2898 if (elemIndex >= size) {
2903 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2905 op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2907 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, input, newIdxOp);
2916 if (
auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2921 SmallVector<Value> newValues;
2922 for (
auto operand : create.getOperands())
2924 op.getLoc(), operand, op.getIndex()));
2929 innerGet.getIndex());
2942 StringRef TypedeclOp::getPreferredName() {
2943 return getVerilogName().value_or(
getName());
2946 Type TypedeclOp::getAliasType() {
2947 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2950 {FlatSymbolRefAttr::get(*this)}),
2958 OpFoldResult BitcastOp::fold(FoldAdaptor) {
2961 if (getOperand().getType() == getType())
2962 return getOperand();
2973 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2976 auto bitcast = rewriter.createOrFold<
BitcastOp>(op.getLoc(), op.getType(),
2977 inputBitcast.getInput());
2978 rewriter.replaceOp(op, bitcast);
2984 return this->emitOpError(
"Bitwidth of input must match result");
2992 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
2993 SmallVector<Attribute, 4> newPath;
2994 bool updateMade =
false;
2995 for (
auto nameRef : getNamepath()) {
2997 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2998 if (ref.getModule() == moduleToDrop)
3001 newPath.push_back(ref);
3003 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3006 newPath.push_back(nameRef);
3014 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3015 SmallVector<Attribute, 4> newPath;
3016 bool updateMade =
false;
3017 StringRef inlinedInstanceName =
"";
3018 for (
auto nameRef : getNamepath()) {
3020 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3021 if (ref.getModule() == moduleToDrop) {
3022 inlinedInstanceName = ref.getName().getValue();
3024 }
else if (!inlinedInstanceName.empty()) {
3028 ref.getName().getValue())));
3029 inlinedInstanceName =
"";
3031 newPath.push_back(ref);
3033 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == moduleToDrop)
3036 newPath.push_back(nameRef);
3044 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3045 SmallVector<Attribute, 4> newPath;
3046 bool updateMade =
false;
3047 for (
auto nameRef : getNamepath()) {
3049 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3050 if (ref.getModule() == oldMod) {
3054 newPath.push_back(ref);
3056 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == oldMod) {
3060 newPath.push_back(nameRef);
3068 bool HierPathOp::updateModuleAndInnerRef(
3069 StringAttr oldMod, StringAttr newMod,
3070 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3072 if (oldMod == newMod)
3075 auto namepathNew = getNamepath().getValue().vec();
3076 bool updateMade =
false;
3078 for (
auto &element : namepathNew) {
3079 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3080 if (innerRef.getModule() != oldMod)
3082 auto symName = innerRef.getName();
3085 auto to = innerSymRenameMap.find(symName);
3086 if (to != innerSymRenameMap.end())
3087 symName = to->second;
3092 if (element != fromRef)
3104 bool HierPathOp::truncateAtModule(StringAttr atMod,
bool includeMod) {
3105 SmallVector<Attribute, 4> newPath;
3106 bool updateMade =
false;
3107 for (
auto nameRef : getNamepath()) {
3109 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3110 if (ref.getModule() == atMod) {
3113 newPath.push_back(ref);
3115 newPath.push_back(ref);
3117 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == atMod && !includeMod)
3120 newPath.push_back(nameRef);
3131 StringAttr HierPathOp::modPart(
unsigned i) {
3132 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3133 .Case<FlatSymbolRefAttr>([](
auto a) {
return a.getAttr(); })
3134 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getModule(); });
3138 StringAttr HierPathOp::root() {
3144 bool HierPathOp::hasModule(StringAttr modName) {
3145 for (
auto nameRef : getNamepath()) {
3147 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3148 if (ref.getModule() == modName)
3151 if (cast<FlatSymbolRefAttr>(nameRef).getAttr() == modName)
3159 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName)
const {
3160 for (
auto nameRef :
const_cast<HierPathOp *
>(
this)->getNamepath())
3161 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3162 if (ref.getName() == symName && ref.getModule() == modName)
3170 StringAttr HierPathOp::refPart(
unsigned i) {
3171 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3172 .Case<FlatSymbolRefAttr>([](
auto a) {
return StringAttr({}); })
3173 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getName(); });
3178 StringAttr HierPathOp::ref() {
3180 return refPart(getNamepath().size() - 1);
3184 StringAttr HierPathOp::leafMod() {
3186 return modPart(getNamepath().size() - 1);
3191 bool HierPathOp::isModule() {
return !ref(); }
3195 bool HierPathOp::isComponent() {
return (
bool)ref(); }
3210 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3211 ArrayAttr expectedModuleNames = {};
3212 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3213 if (!expectedModuleNames)
3215 if (llvm::any_of(expectedModuleNames,
3216 [name](Attribute attr) {
return attr == name; }))
3218 auto diag = emitOpError() <<
"instance path is incorrect. Expected ";
3219 size_t n = expectedModuleNames.size();
3223 for (
size_t i = 0; i < n; ++i) {
3225 diag << ((i + 1 == n) ?
" or " :
", ");
3226 diag << cast<StringAttr>(expectedModuleNames[i]);
3228 diag <<
". Instead found: " << name;
3232 if (!getNamepath() || getNamepath().
empty())
3233 return emitOpError() <<
"the instance path cannot be empty";
3234 for (
unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3235 hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3237 return emitOpError()
3238 <<
"the instance path can only contain inner sym reference"
3239 <<
", only the leaf can refer to a module symbol";
3241 if (failed(checkExpectedModule(innerRef.getModule())))
3244 auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3246 return emitOpError() <<
" module: " << innerRef.getModule()
3247 <<
" does not contain any instance with symbol: "
3248 << innerRef.getName();
3249 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3253 auto leafRef = getNamepath()[getNamepath().size() - 1];
3254 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3255 if (!ns.lookup(innerRef)) {
3256 return emitOpError() <<
" operation with symbol: " << innerRef
3257 <<
" was not found ";
3259 if (failed(checkExpectedModule(innerRef.getModule())))
3261 }
else if (failed(checkExpectedModule(
3262 cast<FlatSymbolRefAttr>(leafRef).getAttr()))) {
3268 void HierPathOp::print(OpAsmPrinter &p) {
3272 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3273 if (
auto visibility =
3274 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3275 p << visibility.getValue() <<
' ';
3277 p.printSymbolName(getSymName());
3279 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3280 if (
auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3281 p.printSymbolName(ref.getModule().getValue());
3283 p.printSymbolName(ref.getName().getValue());
3285 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
3289 p.printOptionalAttrDict(
3290 (*this)->getAttrs(),
3291 {SymbolTable::getSymbolAttrName(),
"namepath", visibilityAttrName});
3294 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3296 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3300 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3305 SmallVector<Attribute> namepath;
3306 if (parser.parseCommaSeparatedList(
3307 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3308 auto loc = parser.getCurrentLocation();
3310 if (parser.parseAttribute(ref))
3314 auto pathLength = ref.getNestedReferences().size();
3315 if (pathLength == 0)
3317 FlatSymbolRefAttr::get(ref.getRootReference()));
3318 else if (pathLength == 1)
3319 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3320 ref.getLeafReference()));
3322 return parser.emitError(loc,
3323 "only one nested reference is allowed");
3327 result.addAttribute(
"namepath",
3330 if (parser.parseOptionalAttrDict(result.attributes))
3340 void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3341 EventControlAttr event, Value trigger,
3342 ValueRange inputs) {
3343 odsState.addOperands(trigger);
3344 odsState.addOperands(inputs);
3345 odsState.addAttribute(getEventAttrName(odsState.name), event);
3346 auto *r = odsState.addRegion();
3350 llvm::SmallVector<Location> argLocs;
3351 llvm::transform(inputs, std::back_inserter(argLocs),
3352 [&](Value v) {
return v.getLoc(); });
3353 b->addArguments(inputs.getTypes(), argLocs);
3361 #define GET_OP_CLASSES
3362 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyModuleCommon(HWModuleLike module)
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
static SmallVector< Location > getAllPortLocs(ModTy module)
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
FunctionType getHWModuleOpType(Operation *op)
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
static void getAsmBlockArgumentNamesImpl(mlir::Region ®ion, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static void setHWModuleType(ModTy &mod, ModuleType type)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
static std::optional< uint64_t > getUIntFromValue(Value value)
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
static PortInfo getPort(ModuleTy &mod, size_t idx)
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void setOutput(unsigned i, Value v)
Value getInput(unsigned i)
llvm::SmallVector< Value > outputOperands
llvm::SmallVector< Value > inputArgs
llvm::StringMap< unsigned > outputIdx
llvm::StringMap< unsigned > inputIdx
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
bool isValidIndexBitWidth(Value index, Value array)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
mlir::Type getCanonicalType(mlir::Type type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr ¶meters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
This holds a decoded list of input/inout and output ports for a module or instance.
size_t sizeOutputs() const
PortInfo & at(size_t idx)
size_t sizeInputs() const
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.