24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "llvm/ADT/BitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/StringSet.h"
31 using namespace circt;
33 using mlir::TypedAttr;
45 llvm_unreachable(
"unknown PortDirection");
49 hw::ArrayType arrayType =
51 assert(arrayType &&
"expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
/// Visitor over type ops that yields a bool per op, built on TypeOpVisitor.
/// NOTE(review): presumably answers "is this op combinational?" (going by the
/// name) - confirm against the caller.
60 struct IsCombClassifier :
public TypeOpVisitor<IsCombClassifier, bool> {
// Ops the visitor reports as invalid type ops classify as false.
61 bool visitInvalidTypeOp(Operation *op) {
return false; }
// Type ops without a dedicated visit method classify as true.
62 bool visitUnhandledTypeOp(Operation *op) {
return true; }
65 return (op->getDialect() && op->getDialect()->getNamespace() ==
"comb") ||
71 if (
auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
76 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
79 return structInject.getNewValue();
85 ArrayRef<Attribute> attrs) {
90 if (a && !cast<DictionaryAttr>(a).
empty()) {
106 auto module = cast<HWModuleOp>(region.getParentOp());
108 auto *block = ®ion.front();
109 for (
size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name = module.getInputName(i);
112 setNameFn(block->getArgument(i), name);
129 Attribute value, ArrayAttr moduleParameters,
131 bool disallowParamRefs) {
134 if (isa<IntegerAttr>(value) || isa<FloatAttr>(value) ||
135 isa<StringAttr>(value) || isa<ParamVerbatimAttr>(value))
139 if (
auto expr = dyn_cast<ParamExprAttr>(value)) {
140 for (
auto op : expr.getOperands())
149 if (
auto parameterRef = dyn_cast<ParamDeclRefAttr>(value)) {
150 auto nameAttr = parameterRef.getName();
154 if (disallowParamRefs) {
155 instanceError([&](
auto &diag) {
156 diag <<
"parameter " << nameAttr
157 <<
" cannot be used as a default value for a parameter";
164 for (
auto param : moduleParameters) {
165 auto paramAttr = cast<ParamDeclAttr>(param);
166 if (paramAttr.getName() != nameAttr)
170 if (paramAttr.getType() == parameterRef.getType())
173 instanceError([&](
auto &diag) {
174 diag <<
"parameter " << nameAttr <<
" used with type "
175 << parameterRef.getType() <<
"; should have type "
176 << paramAttr.getType();
182 instanceError([&](
auto &diag) {
183 diag <<
"use of unknown parameter " << nameAttr;
189 instanceError([&](
auto &diag) {
190 diag <<
"invalid parameter value " << value;
204 bool disallowParamRefs) {
206 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
208 auto diag = usingOp->emitOpError();
210 diag.attachNote(module->getLoc()) <<
"module declared here";
215 module->getAttrOfType<ArrayAttr>(
"parameters"),
216 emitError, disallowParamRefs);
230 for (
auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
/// Custom assembly printer for the constant op.
262 void ConstantOp::print(OpAsmPrinter &p) {
// Emit the value attribute itself.
264 p.printAttribute(getValueAttr());
// Emit the remaining attributes, leaving out "value" since it was printed
// explicitly above.
265 p.printOptionalAttrDict((*this)->getAttrs(), {
"value"});
/// Custom assembly parser for the constant op.
268 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269 IntegerAttr valueAttr;
// Parse the integer attribute (stored under the name "value"), then an
// optional trailing attribute dictionary.
271 if (parser.parseAttribute(valueAttr,
"value", result.attributes) ||
272 parser.parseOptionalAttrDict(result.attributes))
// The op's single result type is taken from the parsed attribute's type.
275 result.addTypes(valueAttr.getType());
283 "hw.constant attribute bitwidth doesn't match return type");
/// Builder overload that takes the constant as an APInt.
290 void ConstantOp::build(OpBuilder &builder, OperationState &result,
291 const APInt &value) {
// Wrap the APInt in an IntegerAttr of `type` and defer to the
// (type, attribute) build overload.
294 auto attr = builder.getIntegerAttr(type, value);
295 return build(builder, result, type, attr);
300 void ConstantOp::build(OpBuilder &builder, OperationState &result,
302 return build(builder, result, value.getType(), value);
309 void ConstantOp::build(OpBuilder &builder, OperationState &result, Type type,
311 auto numBits = cast<IntegerType>(type).getWidth();
312 build(builder, result, APInt(numBits, (uint64_t)value,
true));
316 function_ref<
void(Value, StringRef)> setNameFn) {
317 auto intTy = getType();
318 auto intCst = getValue();
321 if (cast<IntegerType>(intTy).
getWidth() == 1)
322 return setNameFn(getResult(), intCst.isZero() ?
"false" :
"true");
325 SmallVector<char, 32> specialNameBuffer;
326 llvm::raw_svector_ostream specialName(specialNameBuffer);
327 specialName <<
'c' << intCst <<
'_' << intTy;
328 setNameFn(getResult(), specialName.str());
/// Folds the constant to its value attribute.
331 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
// Sanity check: a constant is expected to carry no operands.
332 assert(adaptor.getOperands().empty() &&
"constant has no operands");
333 return getValueAttr();
344 ArrayRef<StringRef> ignoredAttrs = {}) {
345 auto names = op.getAttributeNames();
346 llvm::SmallDenseSet<StringRef> nameSet;
347 nameSet.reserve(names.size() + ignoredAttrs.size());
348 nameSet.insert(names.begin(), names.end());
349 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
350 return llvm::any_of(op->getAttrs(), [&](
auto namedAttr) {
351 return !nameSet.contains(namedAttr.getName());
357 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
"name");
358 if (nameAttr && !nameAttr.getValue().empty())
359 setNameFn(getResult(), nameAttr.getValue());
362 std::optional<size_t> WireOp::getTargetResultIndex() {
return 0; }
364 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
379 if (wire.getInnerSymAttr())
384 if (
auto *inputOp = wire.getInput().getDefiningOp())
386 rewriter.modifyOpInPlace(inputOp,
387 [&] { inputOp->setAttr(
"sv.namehint", name); });
389 rewriter.replaceOp(wire, wire.getInput());
399 if (
auto typeAlias = dyn_cast<TypeAliasType>(type))
400 type = typeAlias.getCanonicalType();
402 if (
auto structType = dyn_cast<StructType>(type)) {
403 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
405 return op->emitOpError(
"expected array attribute for constant of type ")
407 if (structType.getElements().size() != arrayAttr.size())
408 return op->emitOpError(
"array attribute (")
409 << arrayAttr.size() <<
") has wrong size for struct constant ("
410 << structType.getElements().size() <<
")";
412 for (
auto [attr, fieldInfo] :
413 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
417 }
else if (
auto arrayType = dyn_cast<ArrayType>(type)) {
418 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
420 return op->emitOpError(
"expected array attribute for constant of type ")
422 if (arrayType.getNumElements() != arrayAttr.size())
423 return op->emitOpError(
"array attribute (")
424 << arrayAttr.size() <<
") has wrong size for array constant ("
425 << arrayType.getNumElements() <<
")";
428 for (
auto attr : arrayAttr.getValue()) {
432 }
else if (
auto arrayType = dyn_cast<UnpackedArrayType>(type)) {
433 auto arrayAttr = dyn_cast<ArrayAttr>(attr);
435 return op->emitOpError(
"expected array attribute for constant of type ")
438 if (arrayType.getNumElements() != arrayAttr.size())
439 return op->emitOpError(
"array attribute (")
441 <<
") has wrong size for unpacked array constant ("
442 << arrayType.getNumElements() <<
")";
444 for (
auto attr : arrayAttr.getValue()) {
448 }
else if (
auto enumType = dyn_cast<EnumType>(type)) {
449 auto stringAttr = dyn_cast<StringAttr>(attr);
451 return op->emitOpError(
"expected string attribute for constant of type ")
453 }
else if (
auto intType = dyn_cast<IntegerType>(type)) {
455 auto intAttr = dyn_cast<IntegerAttr>(attr);
457 return op->emitOpError(
"expected integer attribute for constant of type ")
460 if (intAttr.getValue().getBitWidth() != intType.getWidth())
461 return op->emitOpError(
"hw.constant attribute bitwidth "
462 "doesn't match return type");
464 return op->emitOpError(
"unknown element type") << type;
473 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) {
return getFieldsAttr(); }
481 if (p.parseType(resultType) || p.parseEqual() ||
482 p.parseAttribute(value, resultType))
489 p << resultType <<
" = ";
490 p.printAttributeWithoutType(value);
/// Folds hw.param.value to its value attribute.
499 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
// Sanity check: the op is expected to carry no operands.
500 assert(adaptor.getOperands().empty() &&
"hw.param.value has no operands");
501 return getValueAttr();
510 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
516 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
517 .Case<InstanceOp, InstanceChoiceOp>([](
auto instance) {
518 SmallVector<Type> inputs(instance->getOperandTypes());
519 SmallVector<Type> results(instance->getResultTypes());
523 [](
auto mod) {
return mod.getHWModuleType().getFuncType(); })
524 .Default([](Operation *op) {
525 return cast<FunctionType>(
526 cast<mlir::FunctionOpInterface>(op).getFunctionType());
534 auto nameAttr = module->getAttrOfType<StringAttr>(
"verilogName");
538 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
541 template <
typename ModuleTy>
543 buildModule(OpBuilder &builder, OperationState &result, StringAttr name,
545 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
546 using namespace mlir::function_interface_impl;
549 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
551 SmallVector<Attribute> perPortAttrs;
552 SmallVector<ModulePort> portTypes;
554 for (
auto elt : ports) {
555 portTypes.push_back(elt);
556 llvm::SmallVector<NamedAttribute> portAttrs;
558 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
559 perPortAttrs.push_back(builder.getDictionaryAttr(portAttrs));
564 parameters = builder.getArrayAttr({});
568 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
570 result.addAttribute(
"per_port_attrs",
572 result.addAttribute(
"parameters", parameters);
574 comment = builder.getStringAttr(
"");
575 result.addAttribute(
"comment", comment);
576 result.addAttributes(attributes);
582 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
583 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
584 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
585 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
586 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
587 SmallVector<Location> &newArgLocs, Block *body =
nullptr) {
592 assert(llvm::is_sorted(insertArgs,
593 [](
auto &a,
auto &b) {
return a.first < b.first; }) &&
594 "insertArgs must be in ascending order");
595 assert(llvm::is_sorted(removeArgs, [](
auto &a,
auto &b) {
return a < b; }) &&
596 "removeArgs must be in ascending order");
599 auto oldArgCount = oldArgTypes.size();
600 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
601 assert((
int)newArgCount >= 0);
603 newArgNames.reserve(newArgCount);
604 newArgTypes.reserve(newArgCount);
605 newArgAttrs.reserve(newArgCount);
606 newArgLocs.reserve(newArgCount);
612 BitVector erasedIndices;
614 erasedIndices.resize(oldArgCount + insertArgs.size());
616 for (
unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
618 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
619 auto port = insertArgs[0].second;
621 !isa<InOutType>(port.type))
623 auto sym = port.getSym();
625 (sym && !sym.empty())
628 newArgNames.push_back(port.name);
629 newArgTypes.push_back(port.type);
630 newArgAttrs.push_back(attr);
631 insertArgs = insertArgs.drop_front();
632 LocationAttr loc = port.loc ? port.loc : unknownLoc;
633 newArgLocs.push_back(loc);
635 body->insertArgument(idx++, port.type, loc);
637 if (argIdx == oldArgCount)
641 bool removed =
false;
642 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
643 removeArgs = removeArgs.drop_front();
649 erasedIndices.set(idx);
651 newArgNames.push_back(oldArgNames[argIdx]);
652 newArgTypes.push_back(oldArgTypes[argIdx]);
653 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
654 : oldArgAttrs[argIdx]);
655 newArgLocs.push_back(oldArgLocs[argIdx]);
660 body->eraseArguments(erasedIndices);
662 assert(newArgNames.size() == newArgCount);
663 assert(newArgTypes.size() == newArgCount);
664 assert(newArgAttrs.size() == newArgCount);
665 assert(newArgLocs.size() == newArgCount);
679 [[deprecated]]
static void
681 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
682 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
683 ArrayRef<unsigned> removeInputs,
684 ArrayRef<unsigned> removeOutputs, Block *body =
nullptr) {
685 auto moduleOp = cast<HWModuleLike>(op);
686 auto *context = moduleOp.getContext();
689 auto oldArgNames = moduleOp.getInputNames();
690 auto oldArgTypes = moduleOp.getInputTypes();
691 auto oldArgAttrs = moduleOp.getAllInputAttrs();
692 auto oldArgLocs = moduleOp.getInputLocs();
694 auto oldResultNames = moduleOp.getOutputNames();
695 auto oldResultTypes = moduleOp.getOutputTypes();
696 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
697 auto oldResultLocs = moduleOp.getOutputLocs();
700 SmallVector<Attribute> newArgNames, newResultNames;
701 SmallVector<Type> newArgTypes, newResultTypes;
702 SmallVector<Attribute> newArgAttrs, newResultAttrs;
703 SmallVector<Location> newArgLocs, newResultLocs;
706 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
707 newArgTypes, newArgAttrs, newArgLocs, body);
710 oldResultTypes, oldResultAttrs, oldResultLocs,
711 newResultNames, newResultTypes, newResultAttrs,
717 moduleOp.setHWModuleType(modty);
718 moduleOp.setAllInputAttrs(newArgAttrs);
719 moduleOp.setAllOutputAttrs(newResultAttrs);
721 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
722 moduleOp.setAllPortLocs(newArgLocs);
725 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
727 ArrayAttr parameters,
728 ArrayRef<NamedAttribute> attributes, StringAttr comment,
729 bool shouldEnsureTerminator) {
730 buildModule<HWModuleOp>(builder, result, name, ports, parameters, attributes,
734 auto *bodyRegion = result.regions[0].get();
736 bodyRegion->push_back(body);
739 auto unknownLoc = builder.getUnknownLoc();
741 auto loc = port.loc ? Location(port.loc) : unknownLoc;
742 auto type = port.type;
743 if (port.isInOut() && !isa<InOutType>(type))
745 body->addArgument(type, loc);
749 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
750 SmallVector<Attribute> resultLocs;
752 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
753 result.addAttribute(
"result_locs", builder.getArrayAttr(resultLocs));
755 if (shouldEnsureTerminator)
756 HWModuleOp::ensureTerminator(*bodyRegion, builder, result.location);
/// Convenience builder taking a flat list of PortInfo: the list is wrapped in
/// a ModulePortInfo and handed to another build overload together with the
/// remaining arguments.
759 void HWModuleOp::build(OpBuilder &builder, OperationState &result,
760 StringAttr name, ArrayRef<PortInfo> ports,
761 ArrayAttr parameters,
762 ArrayRef<NamedAttribute> attributes,
763 StringAttr comment) {
764 build(builder, result, name,
ModulePortInfo(ports), parameters, attributes,
768 void HWModuleOp::build(OpBuilder &builder, OperationState &odsState,
771 ArrayRef<NamedAttribute> attributes,
772 StringAttr comment) {
773 build(builder, odsState, name, ports, parameters, attributes, comment,
775 auto *bodyRegion = odsState.regions[0].get();
776 OpBuilder::InsertionGuard guard(builder);
778 builder.setInsertionPointToEnd(&bodyRegion->front());
779 modBuilder(builder, accessor);
781 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
782 builder.create<hw::OutputOp>(odsState.location, outputOperands);
785 void HWModuleOp::modifyPorts(
786 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
787 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
788 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
797 if (
auto vName = getVerilogNameAttr())
800 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
804 if (
auto vName = getVerilogNameAttr()) {
807 return (*this)->getAttrOfType<StringAttr>(
808 ::mlir::SymbolTable::getSymbolAttrName());
811 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
813 StringRef verilogName, ArrayAttr parameters,
814 ArrayRef<NamedAttribute> attributes) {
815 buildModule<HWModuleExternOp>(builder, result, name, ports, parameters,
819 LocationAttr unknownLoc = builder.getUnknownLoc();
820 SmallVector<Attribute> portLocs;
821 for (
auto elt : ports)
822 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
823 result.addAttribute(
"port_locs", builder.getArrayAttr(portLocs));
825 if (!verilogName.empty())
826 result.addAttribute(
"verilogName", builder.getStringAttr(verilogName));
/// Convenience builder taking a flat list of PortInfo: the list is wrapped in
/// a ModulePortInfo and handed to another build overload together with the
/// remaining arguments.
829 void HWModuleExternOp::build(OpBuilder &builder, OperationState &result,
830 StringAttr name, ArrayRef<PortInfo> ports,
831 StringRef verilogName, ArrayAttr parameters,
832 ArrayRef<NamedAttribute> attributes) {
833 build(builder, result, name,
ModulePortInfo(ports), verilogName, parameters,
837 void HWModuleExternOp::modifyPorts(
838 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
839 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
840 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
845 void HWModuleExternOp::appendOutputs(
846 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
848 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
849 FlatSymbolRefAttr genKind, StringAttr name,
851 StringRef verilogName, ArrayAttr parameters,
852 ArrayRef<NamedAttribute> attributes) {
853 buildModule<HWModuleGeneratedOp>(builder, result, name, ports, parameters,
856 LocationAttr unknownLoc = builder.getUnknownLoc();
857 SmallVector<Attribute> portLocs;
858 for (
auto elt : ports)
859 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
860 result.addAttribute(
"port_locs", builder.getArrayAttr(portLocs));
862 result.addAttribute(
"generatorKind", genKind);
863 if (!verilogName.empty())
864 result.addAttribute(
"verilogName", builder.getStringAttr(verilogName));
/// Convenience builder taking a flat list of PortInfo: the list is wrapped in
/// a ModulePortInfo and all arguments are forwarded, in the same order, to
/// the ModulePortInfo-based build overload.
867 void HWModuleGeneratedOp::build(OpBuilder &builder, OperationState &result,
868 FlatSymbolRefAttr genKind, StringAttr name,
869 ArrayRef<PortInfo> ports, StringRef verilogName,
870 ArrayAttr parameters,
871 ArrayRef<NamedAttribute> attributes) {
872 build(builder, result, genKind, name,
ModulePortInfo(ports), verilogName,
873 parameters, attributes);
876 void HWModuleGeneratedOp::modifyPorts(
877 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
878 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
879 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
884 void HWModuleGeneratedOp::appendOutputs(
885 ArrayRef<std::pair<StringAttr, Value>> outputs) {}
/// Linear scan over `attrs` looking for an entry whose name equals `name`.
887 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
888 for (
auto &argAttr : attrs)
// Match on the attribute name only; the attribute value is not inspected.
889 if (argAttr.getName() == name)
894 template <
typename ModuleTy>
896 OperationState &result) {
898 using namespace mlir::function_interface_impl;
899 auto builder = parser.getBuilder();
900 auto loc = parser.getCurrentLocation();
903 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
907 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
912 FlatSymbolRefAttr kindAttr;
913 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
914 if (parser.parseComma() ||
915 parser.parseAttribute(kindAttr,
"generatorKind", result.attributes)) {
921 ArrayAttr parameters;
925 SmallVector<module_like_impl::PortParse> ports;
931 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
935 parser.emitError(loc,
"explicit `parameters` attributes not allowed");
939 result.addAttribute(
"parameters", parameters);
940 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
944 SmallVector<Attribute> attrs;
945 for (
auto &port : ports)
946 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
948 auto nonEmptyAttrsFn = [](Attribute attr) {
949 return attr && !cast<DictionaryAttr>(attr).empty();
951 if (llvm::any_of(attrs, nonEmptyAttrsFn))
952 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
953 builder.getArrayAttr(attrs));
956 auto unknownLoc = builder.getUnknownLoc();
957 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
958 return attr && cast<Location>(attr) != unknownLoc;
960 SmallVector<Attribute> locs;
961 StringAttr portLocsAttrName;
962 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
965 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
966 for (
auto &port : ports)
968 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
971 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
972 for (
auto &port : ports)
973 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
975 if (llvm::any_of(locs, nonEmptyLocsFn))
976 result.addAttribute(portLocsAttrName, builder.getArrayAttr(locs));
979 SmallVector<OpAsmParser::Argument, 4> entryArgs;
980 for (
auto &port : ports)
982 entryArgs.push_back(port);
985 auto *body = result.addRegion();
986 if (std::is_same_v<ModuleTy, HWModuleOp>) {
987 if (parser.parseRegion(*body, entryArgs))
990 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
/// Parses an HWModuleOp by delegating to the shared parseHWModuleOp template,
/// instantiated for HWModuleOp.
995 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
996 return parseHWModuleOp<HWModuleOp>(parser, result);
/// Parses an HWModuleExternOp by delegating to the shared parseHWModuleOp
/// template, instantiated for HWModuleExternOp.
999 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1000 OperationState &result) {
1001 return parseHWModuleOp<HWModuleExternOp>(parser, result);
/// Parses an HWModuleGeneratedOp by delegating to the shared parseHWModuleOp
/// template, instantiated for HWModuleGeneratedOp.
1004 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1005 OperationState &result) {
1006 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1010 if (
auto mod = dyn_cast<HWModuleLike>(op))
1011 return mod.getHWModuleType().getFuncType();
1012 return cast<FunctionType>(
1013 cast<mlir::FunctionOpInterface>(op).getFunctionType());
1016 template <
typename ModuleTy>
1020 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1021 if (
auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1022 visibilityAttrName))
1023 p << visibility.getValue() <<
' ';
1026 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1027 if (
auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1029 p.printSymbolName(gen.getGeneratorKind());
1037 SmallVector<StringRef, 3> omittedAttrs;
1038 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1039 omittedAttrs.push_back(
"generatorKind");
1040 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1041 omittedAttrs.push_back(mod.getResultLocsAttrName());
1043 omittedAttrs.push_back(mod.getPortLocsAttrName());
1044 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1045 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1046 omittedAttrs.push_back(mod.getParametersAttrName());
1047 omittedAttrs.push_back(visibilityAttrName);
1049 mod.getOperation()->template getAttrOfType<StringAttr>(
"comment"))
1050 if (cmt.getValue().empty())
1051 omittedAttrs.push_back(
"comment");
1053 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1057 void HWModuleExternOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1058 void HWModuleGeneratedOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1060 void HWModuleOp::print(OpAsmPrinter &p) {
1064 Region &body = getBody();
1065 if (!body.empty()) {
1067 p.printRegion(body,
false,
1073 assert(isa<HWModuleLike>(module) &&
1074 "verifier hook should only be called on modules");
1076 SmallPtrSet<Attribute, 4> paramNames;
1079 for (
auto param : module->getAttrOfType<ArrayAttr>(
"parameters")) {
1080 auto paramAttr = cast<ParamDeclAttr>(param);
1084 if (!paramNames.insert(paramAttr.getName()).second)
1085 return module->emitOpError(
"parameter ")
1086 << paramAttr <<
" has the same name as a previous parameter";
1089 auto value = paramAttr.getValue();
1093 auto typedValue = dyn_cast<TypedAttr>(value);
1095 return module->emitOpError(
"parameter ")
1096 << paramAttr <<
" should have a typed value; has value " << value;
1098 if (typedValue.getType() != paramAttr.getType())
1099 return module->emitOpError(
"parameter ")
1100 << paramAttr <<
" should have type " << paramAttr.getType()
1101 <<
"; has type " << typedValue.getType();
1121 auto numInputs = type.getNumInputs();
1122 if (body->getNumArguments() != numInputs)
1123 return emitOpError(
"entry block must have")
1124 << numInputs <<
" arguments to match module signature";
1131 std::pair<StringAttr, BlockArgument>
1132 HWModuleOp::insertInput(
unsigned index, StringAttr name, Type ty) {
1136 for (
auto port : ports)
1137 ns.
newName(port.name.getValue());
1144 port.
name = nameAttr;
1151 return {nameAttr, body->getArgument(index)};
1154 void HWModuleOp::insertOutputs(
unsigned index,
1155 ArrayRef<std::pair<StringAttr, Value>> outputs) {
1157 auto output = cast<OutputOp>(
getBodyBlock()->getTerminator());
1158 assert(index <= output->getNumOperands() &&
"invalid output index");
1161 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1162 for (
auto &[name, value] : outputs) {
1166 port.
type = value.getType();
1167 indexedNewPorts.emplace_back(index, port);
1173 for (
auto &[name, value] : outputs)
1174 output->insertOperands(index++, value);
/// Appends `outputs` after the existing output ports: implemented as an
/// insertion at index getNumOutputPorts().
1177 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>> outputs) {
1178 return insertOutputs(getNumOutputPorts(), outputs);
1181 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region ®ion,
1186 void HWModuleExternOp::getAsmBlockArgumentNames(
1191 template <
typename ModTy>
1193 auto locs = module.getPortLocs();
1195 SmallVector<Location> retval;
1196 retval.reserve(locs->size());
1197 for (
auto l : *locs)
1198 retval.push_back(cast<Location>(l));
1200 assert(!locs->size() || locs->size() == module.getNumPorts());
1203 return SmallVector<Location>(module.getNumPorts(),
1208 SmallVector<Location> portLocs;
1210 auto resultLocs = getResultLocsAttr();
1211 unsigned inputCount = 0;
1215 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1216 if (modType.isOutput(i)) {
1217 auto loc = resultLocs
1219 resultLocs.getValue()[portLocs.size() - inputCount])
1221 portLocs.push_back(loc);
1223 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1224 portLocs.push_back(loc);
1239 void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1240 SmallVector<Attribute> resultLocs;
1241 unsigned inputCount = 0;
1244 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1245 if (modType.isOutput(i))
1246 resultLocs.push_back(locs[i]);
1248 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1253 void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1257 void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1261 template <
typename ModTy>
1263 auto numInputs = module.getNumInputPorts();
1264 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1265 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1266 auto oldType = module.getModuleType();
1267 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1268 oldType.getPorts().end());
1269 for (
size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1270 newPorts[i].name = cast<StringAttr>(names[i]);
1272 module.setModuleType(newType);
/// Returns the per-port attribute list stored on the op.
1287 ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1288 auto attrs = getPerPortAttrs();
// Only hand back the stored array when it is both present and non-empty.
1289 if (attrs && !attrs->empty())
1290 return attrs->getValue();
/// Returns the per-port attribute list stored on the op.
1294 ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1295 auto attrs = getPerPortAttrs();
// Only hand back the stored array when it is both present and non-empty.
1296 if (attrs && !attrs->empty())
1297 return attrs->getValue();
/// Returns the per-port attribute list stored on the op.
1301 ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1302 auto attrs = getPerPortAttrs();
// Only hand back the stored array when it is both present and non-empty.
1303 if (attrs && !attrs->empty())
1304 return attrs->getValue();
/// Replaces the op's per-port attribute array with the result of
/// arrayOrEmpty(context, attrs).
/// NOTE(review): arrayOrEmpty is defined elsewhere in this file; presumably
/// it collapses an all-empty list to an empty array - confirm.
1308 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1309 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
/// Replaces the op's per-port attribute array with the result of
/// arrayOrEmpty(context, attrs).
/// NOTE(review): arrayOrEmpty is defined elsewhere in this file; presumably
/// it collapses an all-empty list to an empty array - confirm.
1312 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1313 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
/// Replaces the op's per-port attribute array with the result of
/// arrayOrEmpty(context, attrs).
/// NOTE(review): arrayOrEmpty is defined elsewhere in this file; presumably
/// it collapses an all-empty list to an empty array - confirm.
1316 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1317 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1320 void HWModuleOp::removeAllPortAttrs() {
1324 void HWModuleExternOp::removeAllPortAttrs() {
1328 void HWModuleGeneratedOp::removeAllPortAttrs() {
1334 template <
typename ModTy>
1336 auto argAttrs = mod.getAllInputAttrs();
1337 auto resAttrs = mod.getAllOutputAttrs();
1339 unsigned newNumArgs = type.getNumInputs();
1340 unsigned newNumResults = type.getNumOutputs();
1343 argAttrs.resize(newNumArgs, emptyDict);
1344 resAttrs.resize(newNumResults, emptyDict);
1346 SmallVector<Attribute> attrs;
1347 attrs.append(argAttrs.begin(), argAttrs.end());
1348 attrs.append(resAttrs.begin(), resAttrs.end());
1351 return mod.removeAllPortAttrs();
1352 mod.setAllPortAttrs(attrs);
/// Looks up the operation named by this op's generator-kind symbol in the
/// enclosing top-level ModuleOp and returns the lookup result.
1369 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
// NOTE(review): the parent lookup result is used without a null check; this
// assumes the op is always nested inside a ModuleOp - confirm.
1370 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1371 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1375 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1376 auto *referencedKind =
1377 symbolTable.lookupNearestSymbolFrom(*
this, getGeneratorKindAttr());
1379 if (referencedKind ==
nullptr)
1380 return emitError(
"Cannot find generator definition '")
1381 << getGeneratorKind() <<
"'";
1383 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1384 return emitError(
"Symbol resolved to '")
1385 << referencedKind->getName()
1386 <<
"' which is not a HWGeneratorSchemaOp";
1388 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1389 auto paramRef = referencedKindOp.getRequiredAttrs();
1390 auto dict = (*this)->getAttrDictionary();
1391 for (
auto str : paramRef) {
1392 auto strAttr = dyn_cast<StringAttr>(str);
1394 return emitError(
"Unknown attribute type, expected a string");
1395 if (!dict.get(strAttr.getValue()))
1396 return emitError(
"Missing attribute '") << strAttr.getValue() <<
"'";
1406 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1411 LogicalResult HWModuleOp::verifyBody() {
return success(); }
1413 template <
typename ModuleTy>
1415 auto modTy = mod.getHWModuleType();
1417 SmallVector<PortInfo> retval;
1418 auto locs = mod.getAllPortLocs();
1419 for (
unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1420 LocationAttr loc = locs[i];
1421 DictionaryAttr attrs =
1422 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1425 retval.push_back({modTy.getPorts()[i],
1426 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1427 : modTy.getInputIdForPortId(i),
1433 template <
typename ModuleTy>
1435 auto modTy = mod.getHWModuleType();
1437 LocationAttr loc = mod.getPortLoc(idx);
1438 DictionaryAttr attrs =
1439 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1442 return {modTy.getPorts()[idx],
1443 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1444 : modTy.getInputIdForPortId(idx),
1453 void InstanceOp::build(OpBuilder &builder, OperationState &result,
1454 Operation *module, StringAttr name,
1455 ArrayRef<Value> inputs, ArrayAttr parameters,
1456 InnerSymAttr innerSym) {
1458 parameters = builder.getArrayAttr({});
1460 auto mod = cast<hw::HWModuleLike>(module);
1461 auto argNames = builder.getArrayAttr(mod.getInputNames());
1462 auto resultNames = builder.getArrayAttr(mod.getOutputNames());
1467 ModuleType modType = mod.getHWModuleType();
1468 FailureOr<ModuleType> resolvedModType = modType.resolveParametricTypes(
1469 parameters, result.location,
false);
1470 if (succeeded(resolvedModType))
1471 modType = *resolvedModType;
1472 FunctionType funcType = resolvedModType->getFuncType();
1473 build(builder, result, funcType.getResults(), name,
1475 argNames, resultNames, parameters, innerSym);
/// Reports no target result index: always returns std::nullopt.
1478 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1480 return std::nullopt;
1483 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1485 *
this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1486 getResultNames(), getParameters(), symbolTable);
1490 auto module = (*this)->getParentOfType<
HWModuleOp>();
1494 auto moduleParameters = module->getAttrOfType<ArrayAttr>(
"parameters");
1496 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1497 auto diag = emitOpError();
1499 diag.attachNote(module->getLoc()) <<
"module declared here";
1502 getParameters(), moduleParameters, emitError);
/// Custom assembly parser for `InstanceOp`.
///
/// Parses the instance-name attribute, an optional `sym` keyword followed by
/// an inner-symbol attribute, and the module-name attribute. It then resolves
/// the input operands against `inputsTypes`, parses an arrow and an optional
/// attribute dictionary. Finally `argNames`, `resultNames` and `parameters`
/// are attached as attributes and `allResultTypes` become the result types.
1505 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1506 StringAttr instanceNameAttr;
1507 InnerSymAttr innerSym;
1508 FlatSymbolRefAttr moduleNameAttr;
1509 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1510 SmallVector<Type, 1> inputsTypes, allResultTypes;
1511 ArrayAttr argNames, resultNames, parameters;
// `noneType` is passed as the attribute type for the name attributes below.
1512 auto noneType = parser.getBuilder().getType<NoneType>();
1514 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
// The inner symbol is only parsed when the `sym` keyword is present.
1518 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1521 if (parser.parseCustomAttributeWithFallback(innerSym))
1526 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1527 if (parser.parseAttribute(moduleNameAttr, noneType,
"moduleName",
1528 result.attributes) ||
// NOTE(review): the argument below looks like a mis-decoded "&parametersLoc"
// (address of the `parametersLoc` local declared above) -- confirm and repair.
1529 parser.getCurrentLocation(¶metersLoc) ||
1532 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1534 parser.parseArrow() ||
1536 parser.parseOptionalAttrDict(result.attributes)) {
// Record the collected name arrays, the parameter list and the result types.
1540 result.addAttribute(
"argNames", argNames);
1541 result.addAttribute(
"resultNames", resultNames);
1542 result.addAttribute(
"parameters", parameters);
1543 result.addTypes(allResultTypes);
/// Custom assembly printer for `InstanceOp`: prints the instance name, has a
/// dedicated branch for the inner symbol when one is set, prints the module
/// name, and ends with the optional attribute dictionary.
1547 void InstanceOp::print(OpAsmPrinter &p) {
1549 p.printAttributeWithoutType(getInstanceNameAttr());
1550 if (
auto attr = getInnerSymAttr()) {
1555 p.printAttributeWithoutType(getModuleNameAttr());
// The attribute names listed in the second argument are left out of the
// printed dictionary.
1562 p.printOptionalAttrDict(
1563 (*this)->getAttrs(),
1565 InnerSymbolTable::getInnerSymbolAttrName(),
"moduleName",
1566 "argNames",
"resultNames",
"parameters"});
1573 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1575 return std::nullopt;
1579 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1580 for (Attribute name : getModuleNamesAttr()) {
1582 *
this, cast<FlatSymbolRefAttr>(name), getInputs(), getResultTypes(),
1583 getArgNames(), getResultNames(), getParameters(), symbolTable))) {
1591 auto module = (*this)->getParentOfType<
HWModuleOp>();
1595 auto moduleParameters = module->getAttrOfType<ArrayAttr>(
"parameters");
1597 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1598 auto diag = emitOpError();
1600 diag.attachNote(module->getLoc()) <<
"module declared here";
1603 getParameters(), moduleParameters, emitError);
/// Custom assembly parser for `InstanceChoiceOp`.
///
/// Parses the instance name, an optional `sym` inner symbol, the `option`
/// keyword with the option name, and a default module reference. Any number
/// of `or <module> if <case>` alternatives may follow. It then resolves the
/// input operands, parses an arrow and an optional attribute dictionary. The
/// module/case lists, `argNames`, `resultNames` and `parameters` are attached
/// as attributes and `allResultTypes` become the result types.
1606 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1607 OperationState &result) {
1608 StringAttr optionNameAttr;
1609 StringAttr instanceNameAttr;
1610 InnerSymAttr innerSym;
1611 SmallVector<Attribute> moduleNames;
1612 SmallVector<Attribute> caseNames;
1613 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1614 SmallVector<Type, 1> inputsTypes, allResultTypes;
1615 ArrayAttr argNames, resultNames, parameters;
// `noneType` is passed as the attribute type for the name attributes below.
1616 auto noneType = parser.getBuilder().getType<NoneType>();
1618 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
// The inner symbol is only parsed when the `sym` keyword is present.
1622 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1625 if (parser.parseCustomAttributeWithFallback(innerSym))
1630 if (parser.parseKeyword(
"option") ||
1631 parser.parseAttribute(optionNameAttr, noneType,
"optionName",
// The first module reference has no case name, so `moduleNames` ends up one
// entry longer than `caseNames`.
1635 FlatSymbolRefAttr defaultModuleName;
1636 if (parser.parseAttribute(defaultModuleName))
1638 moduleNames.push_back(defaultModuleName);
// Each `or` introduces one more (module, case) pair.
1640 while (succeeded(parser.parseOptionalKeyword(
"or"))) {
1641 FlatSymbolRefAttr moduleName;
1642 StringAttr targetName;
// NOTE(review): `if` is parsed with parseOptionalKeyword inside this failure
// chain, so a missing `if` is treated like any other parse failure here.
// Confirm whether parseKeyword (which reports the missing keyword) was meant.
1643 if (parser.parseAttribute(moduleName) ||
1644 parser.parseOptionalKeyword(
"if") || parser.parseAttribute(targetName))
1646 moduleNames.push_back(moduleName);
1647 caseNames.push_back(targetName);
1650 llvm::SMLoc parametersLoc, inputsOperandsLoc;
// NOTE(review): the argument below looks like a mis-decoded "&parametersLoc"
// (address of the `parametersLoc` local declared above) -- confirm and repair.
1651 if (parser.getCurrentLocation(¶metersLoc) ||
1654 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1656 parser.parseArrow() ||
1658 parser.parseOptionalAttrDict(result.attributes)) {
1662 result.addAttribute(
"moduleNames",
1664 result.addAttribute(
"caseNames",
1666 result.addAttribute(
"argNames", argNames);
1667 result.addAttribute(
"resultNames", resultNames);
1668 result.addAttribute(
"parameters", parameters);
1669 result.addTypes(allResultTypes);
/// Custom assembly printer for `InstanceChoiceOp`: prints the instance name,
/// has a dedicated branch for the inner symbol when one is set, prints
/// ` option <optionName> `, then the module list, and ends with the optional
/// attribute dictionary.
1673 void InstanceChoiceOp::print(OpAsmPrinter &p) {
1675 p.printAttributeWithoutType(getInstanceNameAttr());
1676 if (
auto attr = getInnerSymAttr()) {
1680 p <<
" option " << getOptionNameAttr() <<
' ';
1682 auto moduleNames = getModuleNamesAttr();
1683 auto caseNames = getCaseNamesAttr();
// There is exactly one more module than there are case names: the first
// module is printed on its own, every later one together with its case name.
1684 assert(moduleNames.size() == caseNames.size() + 1);
1686 p.printAttributeWithoutType(moduleNames[0]);
1687 for (
size_t i = 0, n = caseNames.size(); i < n; ++i) {
// Module `i + 1` belongs to case `i`.
1689 p.printAttributeWithoutType(moduleNames[i + 1]);
1691 p.printAttributeWithoutType(caseNames[i]);
// The attribute names listed in the second argument are left out of the
// printed dictionary.
1700 p.printOptionalAttrDict(
1701 (*this)->getAttrs(),
1703 InnerSymbolTable::getInnerSymbolAttrName(),
1704 "moduleNames",
"caseNames",
"argNames",
"resultNames",
1705 "parameters",
"optionName"});
/// Collects, for every entry of the module-names attribute, the attribute
/// returned by `FlatSymbolRefAttr::getAttr()`.
1708 ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1709 SmallVector<Attribute> moduleNames;
1710 for (Attribute attr : getModuleNamesAttr()) {
// Every entry is required to be a FlatSymbolRefAttr (`cast`, not `dyn_cast`).
1711 moduleNames.push_back(cast<FlatSymbolRefAttr>(attr).
getAttr());
1725 if (
auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1726 modType = mod.getHWModuleType();
1728 emitOpError(
"must have a module parent");
1731 auto modResults = modType.getOutputTypes();
1732 OperandRange outputValues = getOperands();
1733 if (modResults.size() != outputValues.size()) {
1734 emitOpError(
"must have same number of operands as region results.");
1739 for (
size_t i = 0, e = modResults.size(); i < e; ++i) {
1740 if (modResults[i] != outputValues[i].getType()) {
1741 emitOpError(
"output types must match module. In "
1743 << i <<
", expected " << modResults[i] <<
", but got "
1744 << outputValues[i].getType() <<
".";
1759 if (p.parseType(type))
1760 return p.emitError(p.getCurrentLocation(),
"Expected type");
1761 auto arrType = type_dyn_cast<ArrayType>(type);
1763 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
1765 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1772 p.printType(srcType);
/// Custom assembly parser for `ArrayCreateOp`: an operand list, an optional
/// attribute dictionary, a colon and one element type. An empty operand list
/// is rejected with an error, and every operand is resolved against the same
/// `elemType`.
1775 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
// Remember where the operand list starts so the length-0 error points there.
1776 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1777 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1780 if (parser.parseOperandList(operands) ||
1781 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1782 parser.parseType(elemType))
1785 if (operands.size() == 0)
1786 return parser.emitError(inputOperandsLoc,
1787 "Cannot construct an array of length 0");
// All operands share the single parsed element type.
1790 for (
auto operand : operands)
1791 if (parser.resolveOperand(operand, elemType, result.operands))
/// Custom assembly printer for `ArrayCreateOp`: the operands, the optional
/// attribute dictionary, then ` : ` followed by the type of the first input.
1796 void ArrayCreateOp::print(OpAsmPrinter &p) {
1798 p.printOperands(getInputs());
1799 p.printOptionalAttrDict((*this)->getAttrs());
// Only the first input's type is printed; indexing `[0]` requires a
// non-empty input list.
1800 p <<
" : " << getInputs()[0].getType();
/// Builds an `ArrayCreateOp` from `values`. The result type is
/// `ArrayType::get(elemType, values.size())`, where `elemType` is the type of
/// the first value. Asserts that `values` is non-empty and, per the second
/// assertion message, that all values have the same type.
1803 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1804 ValueRange values) {
1805 assert(values.size() > 0 &&
"Cannot build array of zero elements");
1806 Type elemType = values[0].getType();
// Predicate: a value is acceptable when its type equals the first value's.
1809 [elemType](Value v) ->
bool {
return v.getType() == elemType; }) &&
1810 "All values must have same type.");
1811 build(b, state,
ArrayType::get(elemType, values.size()), values);
1815 unsigned returnSize = cast<ArrayType>(getType()).getNumElements();
1816 if (getInputs().size() != returnSize)
1821 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1822 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1834 auto baseValue = constBase.getValue();
1835 auto indexValue = constIndex.getValue();
1837 unsigned bits = baseValue.getBitWidth();
1838 assert(bits == indexValue.getBitWidth() &&
"mismatched widths");
1840 if (bits < 64 && offset >= (1ull << bits))
1843 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1844 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1845 return baseExt + offset == indexExt;
1853 PatternRewriter &rewriter) {
1855 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1856 if (arrayTy.getNumElements() <= 1)
1858 auto elemTy = arrayTy.getElementType();
1867 SmallVector<Chunk> chunks;
1868 for (Value value : llvm::reverse(op.getInputs())) {
1873 Value input =
get.getInput();
1874 Value index =
get.getIndex();
1875 if (!chunks.empty()) {
1876 auto &c = *chunks.rbegin();
1877 if (c.input ==
get.getInput() &&
isOffset(c.index, index, c.size)) {
1883 chunks.push_back(Chunk{input, index, 1});
1887 if (chunks.size() == 1) {
1888 auto &chunk = chunks[0];
1889 rewriter.replaceOp(op, rewriter.createOrFold<
ArraySliceOp>(
1890 op.getLoc(), arrayTy, chunk.input, chunk.index));
1896 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1897 SmallVector<Value> slices;
1898 for (
auto &chunk : llvm::reverse(chunks)) {
1901 op.getLoc(), sliceTy, chunk.input, chunk.index));
1903 rewriter.replaceOpWithNewOp<
ArrayConcatOp>(op, arrayTy, slices);
1911 PatternRewriter &rewriter) {
/// Returns the first input when the input list is non-empty and all inputs
/// compare equal.
1917 Value ArrayCreateOp::getUniformElement() {
// The emptiness check comes first, so `[0]` is only reached for a non-empty
// list.
1918 if (!getInputs().
empty() && llvm::all_equal(getInputs()))
1919 return getInputs()[0];
1924 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1926 return std::nullopt;
1927 APInt idxAttr = idxOp.getValue();
1928 if (idxAttr.getBitWidth() > 64)
1929 return std::nullopt;
1930 return idxAttr.getLimitedValue();
1934 unsigned inputSize =
1935 type_cast<ArrayType>(getInput().getType()).getNumElements();
1936 if (llvm::Log2_64_Ceil(inputSize) !=
1937 getLowIndex().getType().getIntOrFloatBitWidth())
1939 "ArraySlice: index width must match clog2 of array size");
1943 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1945 if (getType() == getInput().getType())
1951 PatternRewriter &rewriter) {
1952 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1953 auto elemTy = sliceTy.getElementType();
1954 uint64_t sliceSize = sliceTy.getNumElements();
1958 if (sliceSize == 1) {
1960 auto get = rewriter.create<
ArrayGetOp>(op.getLoc(), op.getInput(),
1962 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
1971 auto inputOp = op.getInput().getDefiningOp();
1972 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1974 if (inputSlice == op)
1977 auto inputIndex = inputSlice.getLowIndex();
1979 if (!inputOffsetOpt)
1982 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1984 rewriter.create<
ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1985 rewriter.replaceOpWithNewOp<
ArraySliceOp>(op, op.getType(),
1986 inputSlice.getInput(), lowIndex);
1990 if (
auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1992 auto inputs = inputCreate.getInputs();
1994 uint64_t begin = inputs.size() - *offsetOpt - sliceSize;
1995 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
1996 inputs.slice(begin, sliceSize));
2000 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2002 SmallVector<Value> chunks;
2003 uint64_t sliceStart = *offsetOpt;
2004 for (
auto input : llvm::reverse(inputConcat.getInputs())) {
2006 uint64_t inputSize =
2007 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2008 if (inputSize == 0 || inputSize <= sliceStart) {
2009 sliceStart -= inputSize;
2014 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2015 uint64_t cutSize = cutEnd - sliceStart;
2016 assert(cutSize != 0 &&
"slice cannot be empty");
2018 if (cutSize == inputSize) {
2020 assert(sliceStart == 0 &&
"invalid cut size");
2021 chunks.push_back(input);
2024 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2026 op.getLoc(), rewriter.getIntegerType(
width), sliceStart);
2032 sliceSize -= cutSize;
2037 assert(chunks.size() > 0 &&
"missing sliced items");
2038 if (chunks.size() == 1)
2039 rewriter.replaceOp(op, chunks[0]);
2042 op, llvm::to_vector(llvm::reverse(chunks)));
2053 SmallVectorImpl<Type> &inputTypes,
2056 uint64_t resultSize = 0;
2058 auto parseElement = [&]() -> ParseResult {
2060 if (p.parseType(ty))
2062 auto arrTy = type_dyn_cast<ArrayType>(ty);
2064 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
2065 if (elemType && elemType != arrTy.getElementType())
2066 return p.emitError(p.getCurrentLocation(),
"Expected array element type ")
2069 elemType = arrTy.getElementType();
2070 inputTypes.push_back(ty);
2071 resultSize += arrTy.getNumElements();
2075 if (p.parseCommaSeparatedList(parseElement))
2083 TypeRange inputTypes, Type resultType) {
2084 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
/// Builds an `ArrayConcatOp` from `values`. Asserts that `values` is
/// non-empty and that every value is an `ArrayType` with the same element
/// type, then sums the element counts of all inputs into `resultSize`.
2087 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2088 ValueRange values) {
2089 assert(!values.empty() &&
"Cannot build array of zero elements");
// The first value fixes the element type the others are compared against.
2090 ArrayType arrayTy = cast<ArrayType>(values[0].getType());
2091 Type elemTy = arrayTy.getElementType();
2092 assert(llvm::all_of(values,
2093 [elemTy](Value v) ->
bool {
2094 return isa<ArrayType>(v.getType()) &&
2095 cast<ArrayType>(v.getType()).getElementType() ==
2098 "All values must be of ArrayType with the same element type.");
// Total element count across all inputs.
2100 uint64_t resultSize = 0;
2101 for (Value val : values)
2102 resultSize += cast<ArrayType>(val.getType()).getNumElements();
2106 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2107 auto inputs = adaptor.getInputs();
2108 SmallVector<Attribute> array;
2109 for (
size_t i = 0, e = getNumOperands(); i < e; ++i) {
2112 llvm::copy(cast<ArrayAttr>(inputs[i]), std::back_inserter(array));
2119 for (
auto input : op.getInputs())
2123 SmallVector<Value> items;
2124 for (
auto input : op.getInputs()) {
2125 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2126 for (
auto item : create.getInputs())
2127 items.push_back(item);
2141 SmallVector<Location> locs;
2144 SmallVector<Value> items;
2145 std::optional<Slice> last;
2146 bool changed =
false;
2148 auto concatenate = [&] {
2153 items.push_back(last->op);
2162 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2163 auto arrayTy =
ArrayType::get(origTy.getElementType(), last->size);
2165 loc, arrayTy, last->input, last->index));
2170 auto append = [&](Value op, Value input, Value index,
size_t size) {
2175 if (last->input == input &&
isOffset(last->index, index, last->size)) {
2178 last->locs.push_back(op.getLoc());
2183 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2186 for (
auto item : llvm::reverse(op.getInputs())) {
2188 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2189 append(item, slice.getInput(), slice.getLowIndex(), size);
2194 if (create.getInputs().size() == 1) {
2195 if (
auto get = create.getInputs()[0].getDefiningOp<
ArrayGetOp>()) {
2203 items.push_back(item);
2210 if (items.size() == 1) {
2211 rewriter.replaceOp(op, items[0]);
2213 std::reverse(items.begin(), items.end());
2220 PatternRewriter &rewriter) {
/// Custom assembly parser for `EnumConstantOp`: a keyword naming the field,
/// then `: <type>`. The `field` attribute is recorded and `type` becomes the
/// result type.
2236 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
// Source location captured before the field keyword is consumed.
2243 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2244 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2253 result.addAttribute(
"field", fieldAttr);
2254 result.addTypes(type);
2259 void EnumConstantOp::print(OpAsmPrinter &p) {
2260 p <<
" " << getField().getField().getValue() <<
" : "
2261 << getField().getType().getValue();
2265 function_ref<
void(Value, StringRef)> setNameFn) {
2266 setNameFn(getResult(), getField().getField().str());
2269 void EnumConstantOp::build(OpBuilder &builder, OperationState &odsState,
2270 EnumFieldAttr field) {
2271 return build(builder, odsState, field.getType().getValue(), field);
2274 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2275 assert(adaptor.getOperands().empty() &&
"constant has no operands");
2276 return getFieldAttr();
2280 auto fieldAttr = getFieldAttr();
2281 auto fieldType = fieldAttr.getType().getValue();
2284 if (fieldType != getType())
2285 emitOpError(
"return type ")
2286 << getType() <<
" does not match attribute type " << fieldAttr;
2296 auto lhsType = type_cast<EnumType>(getLhs().getType());
2297 auto rhsType = type_cast<EnumType>(getRhs().getType());
2298 if (rhsType != lhsType)
2299 emitOpError(
"types do not match");
/// Custom assembly parser for `StructCreateOp`: a parenthesised operand list,
/// an optional attribute dictionary, then `: <type>`. The type must be usable
/// as a `StructType` via `type_dyn_cast`. It becomes the result type exactly
/// as written, and the operands are resolved against the struct's inner field
/// types.
2307 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2308 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2309 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2310 Type declOrAliasType;
2312 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2313 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2314 parser.parseColonType(declOrAliasType))
2317 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2319 return parser.emitError(parser.getNameLoc(),
2320 "expected !hw.struct type or alias");
2322 llvm::SmallVector<Type, 4> structInnerTypes;
2323 declType.getInnerTypes(structInnerTypes);
// The result keeps the type as parsed (`declOrAliasType`), not `declType`.
2324 result.addTypes(declOrAliasType);
2326 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
/// Custom assembly printer for `StructCreateOp`: the operands, the optional
/// attribute dictionary, then ` : ` followed by the result type.
2332 void StructCreateOp::print(OpAsmPrinter &printer) {
2334 printer.printOperands(getInput());
2336 printer.printOptionalAttrDict((*this)->getAttrs());
2337 printer <<
" : " << getType();
2341 auto elements = hw::type_cast<StructType>(getType()).getElements();
2343 if (elements.size() != getInput().size())
2344 return emitOpError(
"structure field count mismatch");
2346 for (
const auto &[field, value] : llvm::zip(elements, getInput()))
2347 if (field.type != value.getType())
2348 return emitOpError(
"structure field `")
2349 << field.name <<
"` type does not match";
/// Folder for `StructCreateOp`. The first visible rule folds a struct that is
/// rebuilt from all results of one `StructExplodeOp`, of the same type, back
/// to that op's input.
2354 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2356 if (!getInput().
empty())
// Requires: the first operand comes from a StructExplodeOp, the operand list
// equals that op's results, and the exploded struct's type equals this
// result type.
2357 if (
auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2358 explodeOp && getInput() == explodeOp.getResults() &&
2359 getResult().getType() == explodeOp.getInput().getType())
2360 return explodeOp.getInput();
// Checks whether any operand attribute is null.
2362 auto inputs = adaptor.getInput();
2363 if (llvm::any_of(inputs, [](Attribute attr) {
return !attr; }))
/// Custom assembly parser for `StructExplodeOp`: one operand, an optional
/// attribute dictionary, then `: <type>`. The type must be usable as a
/// `StructType`. Its inner field types become the result types and the
/// operand is resolved against the type as written.
2372 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2373 OperationState &result) {
2374 OpAsmParser::UnresolvedOperand operand;
2377 if (parser.parseOperand(operand) ||
2378 parser.parseOptionalAttrDict(result.attributes) ||
2379 parser.parseColonType(declType))
2381 auto structType = type_dyn_cast<StructType>(declType);
2383 return parser.emitError(parser.getNameLoc(),
2384 "invalid kind of type specified");
// One result type per struct field.
2386 llvm::SmallVector<Type, 4> structInnerTypes;
2387 structType.getInnerTypes(structInnerTypes);
2388 result.addTypes(structInnerTypes);
2390 if (parser.resolveOperand(operand, declType, result.operands))
/// Custom assembly printer for `StructExplodeOp`: the input operand, the
/// optional attribute dictionary, then ` : ` followed by the input's type.
2395 void StructExplodeOp::print(OpAsmPrinter &printer) {
2397 printer.printOperand(getInput());
2398 printer.printOptionalAttrDict((*this)->getAttrs());
2399 printer <<
" : " << getInput().getType();
2402 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2403 SmallVectorImpl<OpFoldResult> &results) {
2404 auto input = adaptor.getInput();
2407 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(results));
2412 PatternRewriter &rewriter) {
2413 auto *inputOp = op.getInput().getDefiningOp();
2414 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2415 auto result = failure();
2416 auto opResults = op.getResults();
2417 for (uint32_t index = 0; index < elements.size(); index++) {
2419 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2427 function_ref<
void(Value, StringRef)> setNameFn) {
2428 auto structType = type_cast<StructType>(getInput().getType());
2429 for (
auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2430 setNameFn(res, field.name.str());
2433 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2435 StructType inputType = dyn_cast<StructType>(input.getType());
2437 SmallVector<Type, 16> fieldTypes;
2438 for (
auto field : inputType.getElements())
2439 fieldTypes.push_back(field.type);
2440 build(odsBuilder, odsState, fieldTypes, input);
2449 template <
typename AggregateOp,
typename AggregateType>
2451 AggregateType aggType,
2453 auto index = op.getFieldIndex();
2454 if (index >= aggType.getElements().size())
2455 return op.emitOpError() <<
"field index " << index
2456 <<
" exceeds element count of aggregate type";
2460 return op.emitOpError()
2461 <<
"type " << aggType.getElements()[index].type
2462 <<
" of accessed field in aggregate at index " << index
2463 <<
" does not match expected type " <<
elementType;
2469 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2470 *
this, getInput().getType(), getType());
2475 template <
typename AggregateType>
2477 OpAsmParser::UnresolvedOperand operand;
2478 StringAttr fieldName;
2481 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2482 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2483 parser.parseOptionalAttrDict(result.attributes) ||
2484 parser.parseColonType(declType))
2486 auto aggType = type_dyn_cast<AggregateType>(declType);
2488 return parser.emitError(parser.getNameLoc(),
2489 "invalid kind of type specified");
2491 auto fieldIndex = aggType.getFieldIndex(fieldName);
2493 parser.emitError(parser.getNameLoc(),
"field name '" +
2494 fieldName.getValue() +
2495 "' not found in aggregate type");
2501 result.addAttribute(
"fieldIndex", indexAttr);
2502 Type resultType = aggType.getElements()[*fieldIndex].type;
2503 result.addTypes(resultType);
2505 if (parser.resolveOperand(operand, declType, result.operands))
2512 template <
typename AggType>
2515 printer.printOperand(op.getInput());
2516 printer <<
"[\"" << op.getFieldName() <<
"\"]";
2517 printer.printOptionalAttrDict(op->getAttrs(), {
"fieldIndex"});
2518 printer <<
" : " << op.getInput().getType();
2521 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2522 OperationState &result) {
2523 return parseExtractOp<StructType>(parser, result);
2526 void StructExtractOp::print(OpAsmPrinter &printer) {
/// Builder taking a `StructType::FieldInfo`: looks up the index of
/// `field.name` in the struct type of `input`, asserts that it exists, and
/// forwards to the index-based builder with `field.type` as the result type.
2530 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2531 Value input, StructType::FieldInfo field) {
2533 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
// An unknown field name is a caller bug; it is only checked by assertion.
2534 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2535 build(builder, odsState, field.type, input, *fieldIndex);
2538 void StructExtractOp::build(OpBuilder &builder, OperationState &odsState,
2539 Value input, StringAttr fieldName) {
2540 auto structType = type_cast<StructType>(input.getType());
2541 auto fieldIndex = structType.getFieldIndex(fieldName);
2542 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2543 auto resultType = structType.getElements()[*fieldIndex].type;
2544 build(builder, odsState, resultType, input, *fieldIndex);
2547 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2548 if (
auto constOperand = adaptor.getInput()) {
2550 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2551 return operandAttr.getValue()[getFieldIndex()];
2554 if (
auto foldResult =
2561 PatternRewriter &rewriter) {
2562 auto inputOp = op.getInput().getDefiningOp();
2565 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2566 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2568 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2577 function_ref<
void(Value, StringRef)> setNameFn) {
2585 void StructInjectOp::build(OpBuilder &builder, OperationState &odsState,
2586 Value input, StringAttr fieldName, Value newValue) {
2587 auto structType = type_cast<StructType>(input.getType());
2588 auto fieldIndex = structType.getFieldIndex(fieldName);
2589 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2590 build(builder, odsState, input, *fieldIndex, newValue);
2594 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2595 *
this, getInput().getType(), getNewValue().getType());
/// Custom assembly parser for `StructInjectOp`:
/// `<input> [ <fieldName> ] , <newValue> {attr-dict}? : <type>`.
/// The type must be usable as a `StructType` and the field name is looked up
/// in it. The `fieldIndex` attribute is recorded, the struct type is the
/// result type, and the two operands are resolved against the struct type
/// and the selected field's type.
2598 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2599 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2600 OpAsmParser::UnresolvedOperand operand, val;
2601 StringAttr fieldName;
2604 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2605 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2606 parser.parseComma() || parser.parseOperand(val) ||
2607 parser.parseOptionalAttrDict(result.attributes) ||
2608 parser.parseColonType(declType))
2610 auto structType = type_dyn_cast<StructType>(declType);
2612 return parser.emitError(inputOperandsLoc,
"invalid kind of type specified");
2614 auto fieldIndex = structType.getFieldIndex(fieldName);
2616 parser.emitError(parser.getNameLoc(),
"field name '" +
2617 fieldName.getValue() +
2618 "' not found in aggregate type");
2624 result.addAttribute(
"fieldIndex", indexAttr);
2625 result.addTypes(declType);
// `resultType` here is the type of the injected field, i.e. the type the
// new-value operand must have. The op's result type is `declType`.
2627 Type resultType = structType.getElements()[*fieldIndex].type;
2628 if (parser.resolveOperands({operand, val}, {declType, resultType},
2629 inputOperandsLoc, result.operands))
/// Custom assembly printer for `StructInjectOp`: the input operand, the new
/// value operand, the optional attribute dictionary, then ` : ` followed by
/// the input's type.
2634 void StructInjectOp::print(OpAsmPrinter &printer) {
2636 printer.printOperand(getInput());
2638 printer.printOperand(getNewValue());
// `fieldIndex` is left out of the printed attribute dictionary.
2639 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2640 printer <<
" : " << getInput().getType();
/// Constant folder for `StructInjectOp`: with both the struct input and the
/// new value available as attributes, copies the input's elements and
/// overwrites the one at `getFieldIndex()` with the new value.
2643 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2644 auto input = adaptor.getInput();
2645 auto newValue = adaptor.getNewValue();
// Either operand attribute being null takes the guarded path.
2646 if (!input || !newValue)
2648 SmallVector<Attribute> array;
// The input attribute is required to be an ArrayAttr (`cast`).
2649 llvm::copy(cast<ArrayAttr>(input), std::back_inserter(array));
2650 array[getFieldIndex()] = newValue;
2655 PatternRewriter &rewriter) {
2657 SmallPtrSet<Operation *, 4> injects;
2658 DenseMap<StringAttr, Value> fields;
2661 StructInjectOp inject = op;
2664 if (!injects.insert(inject).second)
2667 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2668 input = inject.getInput();
2669 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2671 assert(input &&
"missing input to inject chain");
2673 auto ty = hw::type_cast<StructType>(op.getType());
2674 auto elements = ty.getElements();
2677 if (fields.size() == elements.size()) {
2678 SmallVector<Value> createFields;
2679 for (
const auto &field : elements) {
2680 auto it = fields.find(field.name);
2681 assert(it != fields.end() &&
"missing field");
2682 createFields.push_back(it->second);
2684 rewriter.replaceOpWithNewOp<
StructCreateOp>(op, ty, createFields);
2689 if (injects.size() == fields.size())
2693 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2694 auto it = fields.find(elements[fieldIndex].name);
2695 if (it == fields.end())
2697 input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2701 rewriter.replaceOp(op, input);
2710 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2711 *
this, getType(), getInput().getType());
2714 void UnionCreateOp::build(OpBuilder &builder, OperationState &odsState,
2715 Type unionType, StringAttr fieldName, Value input) {
2716 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2717 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2718 build(builder, odsState, unionType, *fieldIndex, input);
/// Custom assembly parser for `UnionCreateOp`:
/// `<fieldName> , <input> {attr-dict}? : <type>`.
/// The type must be usable as a `UnionType` and the field name is looked up
/// in it. The `fieldIndex` attribute is recorded, the operand is resolved
/// against the selected field's type, and the type as written becomes the
/// result type.
2721 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2722 Type declOrAliasType;
2723 StringAttr fieldName;
2724 OpAsmParser::UnresolvedOperand input;
// Location of the field name, used for the "cannot find union field" error.
2725 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2727 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2728 parser.parseOperand(input) ||
2729 parser.parseOptionalAttrDict(result.attributes) ||
2730 parser.parseColonType(declOrAliasType))
2733 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2735 return parser.emitError(parser.getNameLoc(),
2736 "expected !hw.union type or alias");
2738 auto fieldIndex = declType.getFieldIndex(fieldName);
2740 parser.emitError(fieldLoc,
"cannot find union field '")
2741 << fieldName.getValue() <<
'\'';
2747 result.addAttribute(
"fieldIndex", indexAttr);
// The operand must have the type of the selected union field.
2748 Type inputType = declType.getElements()[*fieldIndex].type;
2750 if (parser.resolveOperand(input, inputType, result.operands))
// The result keeps the type as parsed (`declOrAliasType`), not `declType`.
2752 result.addTypes({declOrAliasType});
/// Custom assembly printer for `UnionCreateOp`: the input operand, the
/// optional attribute dictionary, then ` : ` followed by the result type.
2756 void UnionCreateOp::print(OpAsmPrinter &printer) {
2758 printer.printOperand(getInput());
// `fieldIndex` is left out of the printed attribute dictionary.
2759 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2760 printer <<
" : " << getType();
2767 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2768 return parseExtractOp<UnionType>(parser, result);
2771 void UnionExtractOp::print(OpAsmPrinter &printer) {
2776 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2777 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2778 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2779 auto unionElements =
2780 hw::type_cast<UnionType>((operands[0].getType())).getElements();
2781 unsigned fieldIndex =
2782 attrs.getAs<IntegerAttr>(
"fieldIndex").getValue().getZExtValue();
2783 if (fieldIndex >= unionElements.size()) {
2785 mlir::emitError(*loc,
"field index " + Twine(fieldIndex) +
2786 " exceeds element count of aggregate type");
2789 results.push_back(unionElements[fieldIndex].type);
2793 void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2794 Value input, StringAttr fieldName) {
2795 auto unionType = type_cast<UnionType>(input.getType());
2796 auto fieldIndex = unionType.getFieldIndex(fieldName);
2797 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2798 auto resultType = unionType.getElements()[*fieldIndex].type;
2799 build(odsBuilder, odsState, resultType, input, *fieldIndex);
/// Folder for `ArrayGetOp`. The visible rules cover a constant array input
/// with a constant index, an input produced by a `hw::BitcastOp` of a
/// `hw::ConstantOp`, and an input produced by an `ArrayCreateOp`.
/// Elements are addressed from the end of the list: index 0 selects the last
/// entry.
2811 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
// Both are null when the corresponding operand is not a constant of the
// expected attribute kind.
2812 auto inputCst = dyn_cast_or_null<ArrayAttr>(adaptor.getInput());
2813 auto indexCst = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex());
2818 auto indexVal = indexCst.getValue();
2819 if (indexVal.getBitWidth() < 64) {
2820 auto index = indexVal.getZExtValue();
// Reversed addressing: the attribute list is indexed from its end.
2821 return inputCst[inputCst.size() - 1 - index];
2826 if (!inputCst.empty() && llvm::all_equal(inputCst))
2831 if (
auto bitcast = getInput().getDefiningOp<hw::BitcastOp>()) {
2832 auto intTy = dyn_cast<IntegerType>(getType());
2835 auto bitcastInputOp = bitcast.getInput().getDefiningOp<
hw::ConstantOp>();
2836 if (!bitcastInputOp)
2840 auto bitcastInputCst = bitcastInputOp.getValue();
// Bit offset of the selected element: index times the element bit width,
// computed at the bit width of the constant.
2843 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2844 getType().getIntOrFloatBitWidth();
2847 intTy.getIntOrFloatBitWidth()));
2850 auto inputCreate = getInput().getDefiningOp<
ArrayCreateOp>();
// When the create op has a uniform element, that element is the result.
2854 if (
auto uniformValue = inputCreate.getUniformElement())
2855 return uniformValue;
2857 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2860 uint64_t index = indexCst.getValue().getLimitedValue();
2861 auto createInputs = inputCreate.getInputs();
2862 if (index >= createInputs.size())
// Same reversed addressing as for the constant case above.
2864 return createInputs[createInputs.size() - index - 1];
2868 PatternRewriter &rewriter) {
2873 auto *inputOp = op.getInput().getDefiningOp();
2874 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2876 auto offsetOp = inputSlice.getLowIndex();
2881 uint64_t offset = *offsetOpt + *idxOpt;
2883 rewriter.create<
ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2884 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, inputSlice.getInput(),
2889 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2891 uint64_t elemIndex = *idxOpt;
2892 for (
auto input : llvm::reverse(inputConcat.getInputs())) {
2893 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2894 if (elemIndex >= size) {
2899 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2901 op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2903 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, input, newIdxOp);
2912 if (
auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2917 SmallVector<Value> newValues;
2918 for (
auto operand : create.getOperands())
2920 op.getLoc(), operand, op.getIndex()));
2925 innerGet.getIndex());
/// Returns the name to prefer for this type declaration: the value held by
/// getVerilogName() when it has one, otherwise getName().
2938 StringRef TypedeclOp::getPreferredName() {
2939 return getVerilogName().value_or(
getName());
/// Presumably builds the alias type that refers to this declaration.
/// NOTE(review): most of the return expression is elided in this listing; the
/// only visible pieces are the enclosing type scope and a flat symbol
/// reference to this op.
2942 Type TypedeclOp::getAliasType() {
// cast<> requires the parent operation to be a hw::TypeScopeOp.
2943 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2946 {FlatSymbolRefAttr::get(*this)}),
/// Folder for BitcastOp.
/// NOTE(review): only the identity case is visible in this listing; the lines
/// that follow it are elided.
2954 OpFoldResult BitcastOp::fold(FoldAdaptor) {
// A bitcast whose operand already has the result type folds to that operand.
2957 if (getOperand().getType() == getType())
2958 return getOperand();
2969 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2972 auto bitcast = rewriter.createOrFold<
BitcastOp>(op.getLoc(), op.getType(),
2973 inputBitcast.getInput());
2974 rewriter.replaceOp(op, bitcast);
2980 return this->emitOpError(
"Bitwidth of input must match result");
/// Rebuilds the namepath while testing every entry against `moduleToDrop`.
/// NOTE(review): the bodies of the `== moduleToDrop` branches, the `else`
/// lines, the closing braces and the final return are elided in this listing.
/// Presumably matching entries are skipped (setting `updateMade`) and all
/// other entries are kept -- confirm against the full source.
2988 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
// Replacement namepath being assembled.
2989 SmallVector<Attribute, 4> newPath;
2990 bool updateMade =
false;
2991 for (
auto nameRef : getNamepath()) {
// An entry is either an hw::InnerRefAttr (compared by its module) or, via
// cast<>, a FlatSymbolRefAttr (compared by its symbol attribute).
2993 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
2994 if (ref.getModule() == moduleToDrop)
2997 newPath.push_back(ref);
2999 if (cast<FlatSymbolRefAttr>(nameRef).
getAttr() == moduleToDrop)
3002 newPath.push_back(nameRef);
/// Rebuilds the namepath for the case where `moduleToDrop` is being inlined.
/// NOTE(review): several lines are elided in this listing (including the
/// start of the expression that ends on line 3024, the `else` lines and the
/// final return), so only the visible steps are described.
3010 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
// Replacement namepath being assembled.
3011 SmallVector<Attribute, 4> newPath;
3012 bool updateMade =
false;
// Inner-symbol name of the most recent entry that lived in `moduleToDrop`;
// empty when no such name is pending.
3013 StringRef inlinedInstanceName =
"";
3014 for (
auto nameRef : getNamepath()) {
// Inner refs in `moduleToDrop` record their name as pending. A later inner
// ref seen while a name is pending takes the `else if` branch, which uses
// that ref's own name and then clears the pending name.
3016 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3017 if (ref.getModule() == moduleToDrop) {
3018 inlinedInstanceName = ref.getName().getValue();
3020 }
else if (!inlinedInstanceName.empty()) {
3024 ref.getName().getValue())));
3025 inlinedInstanceName =
"";
3027 newPath.push_back(ref);
3029 if (cast<FlatSymbolRefAttr>(nameRef).
getAttr() == moduleToDrop)
3032 newPath.push_back(nameRef);
/// Rebuilds the namepath while testing every entry's module against `oldMod`.
/// NOTE(review): the bodies of the `== oldMod` branches (where `newMod` is
/// presumably substituted), the `else` lines and the final return are elided
/// in this listing -- confirm against the full source.
3040 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
// Replacement namepath being assembled.
3041 SmallVector<Attribute, 4> newPath;
3042 bool updateMade =
false;
3043 for (
auto nameRef : getNamepath()) {
// An entry is either an hw::InnerRefAttr (compared by its module) or, via
// cast<>, a FlatSymbolRefAttr (compared by its symbol attribute).
3045 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3046 if (ref.getModule() == oldMod) {
3050 newPath.push_back(ref);
3052 if (cast<FlatSymbolRefAttr>(nameRef).
getAttr() == oldMod) {
3056 newPath.push_back(nameRef);
/// Works on a mutable copy of the namepath: inner refs whose module is
/// `oldMod` have their symbol name looked up in `innerSymRenameMap`.
/// NOTE(review): the lines that build the replacement reference, define
/// `fromRef`, handle flat symbol refs and return are elided in this listing.
/// The body of the `oldMod == newMod` check is elided as well.
3064 bool HierPathOp::updateModuleAndInnerRef(
3065 StringAttr oldMod, StringAttr newMod,
3066 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3068 if (oldMod == newMod)
3071 auto namepathNew = getNamepath().getValue().vec();
3072 bool updateMade =
false;
// Elements are taken by reference from the copied vector.
3074 for (
auto &element : namepathNew) {
3075 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(element)) {
3076 if (innerRef.getModule() != oldMod)
3078 auto symName = innerRef.getName();
// Use the renamed symbol when the map has an entry for it; otherwise the
// original name is kept.
3081 auto to = innerSymRenameMap.find(symName);
3082 if (to != innerSymRenameMap.end())
3083 symName = to->second;
3088 if (element != fromRef)
/// Rebuilds the namepath while looking for the entry that names `atMod`.
/// `includeMod` takes part in the flat-symbol comparison below.
/// NOTE(review): the statements that stop the walk, set `updateMade`, and
/// return are elided in this listing, as are the conditions separating the two
/// visible `push_back(ref)` calls -- confirm against the full source.
3100 bool HierPathOp::truncateAtModule(StringAttr atMod,
bool includeMod) {
// Replacement namepath being assembled.
3101 SmallVector<Attribute, 4> newPath;
3102 bool updateMade =
false;
3103 for (
auto nameRef : getNamepath()) {
// An entry is either an hw::InnerRefAttr (compared by its module) or, via
// cast<>, a FlatSymbolRefAttr (compared by its symbol attribute).
3105 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3106 if (ref.getModule() == atMod) {
3109 newPath.push_back(ref);
3111 newPath.push_back(ref);
3113 if (cast<FlatSymbolRefAttr>(nameRef).
getAttr() == atMod && !includeMod)
3116 newPath.push_back(nameRef);
/// Returns the module name carried by the i-th namepath element: the symbol
/// attribute of a FlatSymbolRefAttr, or the module of an hw::InnerRefAttr.
/// No Default case is supplied to the TypeSwitch, so only these two attribute
/// kinds are handled.
3127 StringAttr HierPathOp::modPart(
unsigned i) {
3128 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3129 .Case<FlatSymbolRefAttr>([](
auto a) {
return a.getAttr(); })
3130 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getModule(); });
3134 StringAttr HierPathOp::root() {
/// Scans the namepath for an entry whose module equals `modName`.
/// NOTE(review): the return statements are elided in this listing; presumably
/// a match returns true and falling out of the loop returns false -- confirm.
3140 bool HierPathOp::hasModule(StringAttr modName) {
3141 for (
auto nameRef : getNamepath()) {
// An entry is either an hw::InnerRefAttr (compared by its module) or, via
// cast<>, a FlatSymbolRefAttr (compared by its symbol attribute).
3143 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef)) {
3144 if (ref.getModule() == modName)
3147 if (cast<FlatSymbolRefAttr>(nameRef).
getAttr() == modName)
/// Scans the namepath for an hw::InnerRefAttr whose name is `symName` and
/// whose module is `modName`. Flat symbol refs are ignored by the dyn_cast.
/// The const_cast is there so getNamepath() can be called from this const
/// method.
/// NOTE(review): the return statements are elided in this listing.
3155 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName)
const {
3156 for (
auto nameRef :
const_cast<HierPathOp *
>(
this)->getNamepath())
3157 if (
auto ref = dyn_cast<hw::InnerRefAttr>(nameRef))
3158 if (ref.getName() == symName && ref.getModule() == modName)
/// Returns the inner-symbol name of the i-th namepath element: getName() for
/// an hw::InnerRefAttr, and `StringAttr({})` for a FlatSymbolRefAttr, which
/// carries no inner name.
/// No Default case is supplied to the TypeSwitch, so only these two attribute
/// kinds are handled.
3166 StringAttr HierPathOp::refPart(
unsigned i) {
3167 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3168 .Case<FlatSymbolRefAttr>([](
auto a) {
return StringAttr({}); })
3169 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getName(); });
/// Returns refPart() of the last namepath element.
/// NOTE(review): `size() - 1` assumes a non-empty namepath; one line is elided
/// before the return in this listing (possibly a guard) -- confirm.
3174 StringAttr HierPathOp::ref() {
3176 return refPart(getNamepath().size() - 1);
/// Returns modPart() of the last namepath element.
/// NOTE(review): `size() - 1` assumes a non-empty namepath; one line is elided
/// before the return in this listing (possibly a guard) -- confirm.
3180 StringAttr HierPathOp::leafMod() {
3182 return modPart(getNamepath().size() - 1);
3187 bool HierPathOp::isModule() {
return !ref(); }
3191 bool HierPathOp::isComponent() {
return (
bool)ref(); }
/// Verifies the namepath against the inner-reference namespace `ns`.
/// Visible checks: the path is non-empty; every non-leaf element is an
/// hw::InnerRefAttr that resolves to an instance op; a leaf inner ref resolves
/// to some operation; and each element's module is checked against the modules
/// referenced by the preceding instance.
/// NOTE(review): a number of lines (mostly returns, conditions and closing
/// braces) are elided in this listing.
3206 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
// Modules the next path element is allowed to name; null until an instance
// has been resolved.
3207 ArrayAttr expectedModuleNames = {};
// Checks `name` against expectedModuleNames and, when it is not among them,
// emits a diagnostic that lists the expected names followed by the one found.
3208 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3209 if (!expectedModuleNames)
3211 if (llvm::any_of(expectedModuleNames,
3212 [name](Attribute attr) {
return attr == name; }))
3214 auto diag = emitOpError() <<
"instance path is incorrect. Expected ";
3215 size_t n = expectedModuleNames.size();
3219 for (
size_t i = 0; i < n; ++i) {
3221 diag << ((i + 1 == n) ?
" or " :
", ");
3222 diag << cast<StringAttr>(expectedModuleNames[i]);
3224 diag <<
". Instead found: " << name;
// A missing or empty namepath is rejected up front.
3228 if (!getNamepath() || getNamepath().
empty())
3229 return emitOpError() <<
"the instance path cannot be empty";
// Walk every element except the last one (the leaf).
3230 for (
unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3231 hw::InnerRefAttr innerRef = dyn_cast<hw::InnerRefAttr>(getNamepath()[i]);
3233 return emitOpError()
3234 <<
"the instance path can only contain inner sym reference"
3235 <<
", only the leaf can refer to a module symbol";
3237 if (failed(checkExpectedModule(innerRef.getModule())))
3240 auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3242 return emitOpError() <<
" module: " << innerRef.getModule()
3243 <<
" does not contain any instance with symbol: "
3244 << innerRef.getName();
3245 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
// The leaf may be an inner ref (looked up in `ns`) or a flat module symbol;
// either way its module goes through checkExpectedModule.
3249 auto leafRef = getNamepath()[getNamepath().size() - 1];
3250 if (
auto innerRef = dyn_cast<hw::InnerRefAttr>(leafRef)) {
3251 if (!ns.lookup(innerRef)) {
3252 return emitOpError() <<
" operation with symbol: " << innerRef
3253 <<
" was not found ";
3255 if (failed(checkExpectedModule(innerRef.getModule())))
3257 }
else if (failed(checkExpectedModule(
3258 cast<FlatSymbolRefAttr>(leafRef).
getAttr()))) {
/// Custom printer. Visible output, in order: the optional visibility keyword,
/// the symbol name, the comma-separated namepath entries, and the remaining
/// attribute dictionary.
/// NOTE(review): the lines printing the surrounding punctuation are elided in
/// this listing.
3264 void HierPathOp::print(OpAsmPrinter &p) {
// Print the visibility attribute's value first, when the op has one.
3268 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3269 if (
auto visibility =
3270 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3271 p << visibility.getValue() <<
' ';
3273 p.printSymbolName(getSymName());
// An inner ref prints its module symbol and then its inner name; any other
// entry is cast to FlatSymbolRefAttr and printed as a single symbol.
3275 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3276 if (
auto ref = dyn_cast<hw::InnerRefAttr>(attr)) {
3277 p.printSymbolName(ref.getModule().getValue());
3279 p.printSymbolName(ref.getName().getValue());
3281 p.printSymbolName(cast<FlatSymbolRefAttr>(attr).getValue());
// The symbol name, "namepath" and visibility were printed above, so they are
// excluded from the trailing attribute dictionary.
3285 p.printOptionalAttrDict(
3286 (*this)->getAttrs(),
3287 {SymbolTable::getSymbolAttrName(),
"namepath", visibilityAttrName});
/// Custom parser. Visible steps, in order: an optional visibility keyword, the
/// symbol name, a square-bracketed comma-separated namepath, and an optional
/// attribute dictionary.
/// NOTE(review): declarations of `symName` and `ref`, the failure returns and
/// the construction of the "namepath" attribute value are elided in this
/// listing.
3290 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
// The visibility keyword is optional, so the result is deliberately ignored.
3292 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3296 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3301 SmallVector<Attribute> namepath;
// Each list element is parsed as a symbol reference. No nested reference
// yields a FlatSymbolRefAttr; exactly one yields an hw::InnerRefAttr built from
// the root and leaf; more than one is reported as an error.
3302 if (parser.parseCommaSeparatedList(
3303 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3304 auto loc = parser.getCurrentLocation();
3306 if (parser.parseAttribute(ref))
3310 auto pathLength = ref.getNestedReferences().size();
3311 if (pathLength == 0)
3313 FlatSymbolRefAttr::get(ref.getRootReference()));
3314 else if (pathLength == 1)
3315 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3316 ref.getLeafReference()));
3318 return parser.emitError(loc,
3319 "only one nested reference is allowed");
3323 result.addAttribute(
"namepath",
3326 if (parser.parseOptionalAttrDict(result.attributes))
/// Builder for TriggeredOp: adds `trigger` followed by `inputs` as operands,
/// stores `event` under the op's event attribute name, and adds one region.
/// NOTE(review): the lines that create the block `b` are elided in this
/// listing.
3336 void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
3337 EventControlAttr event, Value trigger,
3338 ValueRange inputs) {
// Operand order: the trigger first, then the inputs.
3339 odsState.addOperands(trigger);
3340 odsState.addOperands(inputs);
3341 odsState.addAttribute(getEventAttrName(odsState.name), event);
3342 auto *r = odsState.addRegion();
// Give the block one argument per input, typed like the input and located at
// the input value's own location.
3346 llvm::SmallVector<Location> argLocs;
3347 llvm::transform(inputs, std::back_inserter(argLocs),
3348 [&](Value v) {
return v.getLoc(); });
3349 b->addArguments(inputs.getTypes(), argLocs);
3357 #define GET_OP_CLASSES
3358 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
static LogicalResult verifyModuleCommon(HWModuleLike module)
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
static SmallVector< Location > getAllPortLocs(ModTy module)
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
FunctionType getHWModuleOpType(Operation *op)
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static void setHWModuleType(ModTy &mod, ModuleType type)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
static std::optional< uint64_t > getUIntFromValue(Value value)
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
static PortInfo getPort(ModuleTy &mod, size_t idx)
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void setOutput(unsigned i, Value v)
Value getInput(unsigned i)
llvm::SmallVector< Value > outputOperands
llvm::SmallVector< Value > inputArgs
llvm::StringMap< unsigned > outputIdx
llvm::StringMap< unsigned > inputIdx
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
bool isOffset(Value base, Value index, uint64_t offset)
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
bool isValidIndexBitWidth(Value index, Value array)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
mlir::Type getCanonicalType(mlir::Type type)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
This holds a decoded list of input/inout and output ports for a module or instance.
size_t sizeOutputs() const
PortInfo & at(size_t idx)
size_t sizeInputs() const
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.