24 #include "mlir/IR/Builders.h"
25 #include "mlir/IR/PatternMatch.h"
26 #include "mlir/Interfaces/FunctionImplementation.h"
27 #include "llvm/ADT/BitVector.h"
28 #include "llvm/ADT/SmallPtrSet.h"
29 #include "llvm/ADT/StringSet.h"
31 using namespace circt;
33 using mlir::TypedAttr;
45 llvm_unreachable(
"unknown PortDirection");
49 hw::ArrayType arrayType =
51 assert(arrayType &&
"expected array type");
52 unsigned indexWidth = index.getType().getIntOrFloatBitWidth();
53 auto requiredWidth = llvm::Log2_64_Ceil(arrayType.getNumElements());
54 return requiredWidth == 0 ? (indexWidth == 0 || indexWidth == 1)
55 : indexWidth == requiredWidth;
60 struct IsCombClassifier :
public TypeOpVisitor<IsCombClassifier, bool> {
61 bool visitInvalidTypeOp(Operation *op) {
return false; }
62 bool visitUnhandledTypeOp(Operation *op) {
return true; }
65 return (op->getDialect() && op->getDialect()->getNamespace() ==
"comb") ||
71 if (
auto structCreate = dyn_cast_or_null<StructCreateOp>(inputOp)) {
72 return structCreate.getOperand(fieldIndex);
76 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
77 if (structInject.getFieldIndex() != fieldIndex)
79 return structInject.getNewValue();
85 ArrayRef<Attribute> attrs) {
90 if (a && !cast<DictionaryAttr>(a).
empty()) {
106 auto module = cast<HWModuleOp>(region.getParentOp());
108 auto *block = ®ion.front();
109 for (
size_t i = 0, e = block->getNumArguments(); i != e; ++i) {
110 auto name = module.getInputName(i);
112 setNameFn(block->getArgument(i), name);
129 Attribute value, ArrayAttr moduleParameters,
131 bool disallowParamRefs) {
134 if (value.isa<IntegerAttr>() || value.isa<FloatAttr>() ||
135 value.isa<StringAttr>() || value.isa<ParamVerbatimAttr>())
139 if (
auto expr = value.dyn_cast<ParamExprAttr>()) {
140 for (
auto op : expr.getOperands())
149 if (
auto parameterRef = value.dyn_cast<ParamDeclRefAttr>()) {
150 auto nameAttr = parameterRef.getName();
154 if (disallowParamRefs) {
155 instanceError([&](
auto &diag) {
156 diag <<
"parameter " << nameAttr
157 <<
" cannot be used as a default value for a parameter";
164 for (
auto param : moduleParameters) {
165 auto paramAttr = param.cast<ParamDeclAttr>();
166 if (paramAttr.getName() != nameAttr)
170 if (paramAttr.getType() == parameterRef.getType())
173 instanceError([&](
auto &diag) {
174 diag <<
"parameter " << nameAttr <<
" used with type "
175 << parameterRef.getType() <<
"; should have type "
176 << paramAttr.getType();
182 instanceError([&](
auto &diag) {
183 diag <<
"use of unknown parameter " << nameAttr;
189 instanceError([&](
auto &diag) {
190 diag <<
"invalid parameter value " << value;
204 bool disallowParamRefs) {
206 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
208 auto diag = usingOp->emitOpError();
210 diag.attachNote(module->getLoc()) <<
"module declared here";
215 module->getAttrOfType<ArrayAttr>(
"parameters"),
216 emitError, disallowParamRefs);
230 for (
auto [i, barg] : llvm::enumerate(bodyRegion.getArguments())) {
262 void ConstantOp::print(OpAsmPrinter &p) {
264 p.printAttribute(getValueAttr());
265 p.printOptionalAttrDict((*this)->getAttrs(), {
"value"});
268 ParseResult ConstantOp::parse(OpAsmParser &parser, OperationState &result) {
269 IntegerAttr valueAttr;
271 if (parser.parseAttribute(valueAttr,
"value", result.attributes) ||
272 parser.parseOptionalAttrDict(result.attributes))
275 result.addTypes(valueAttr.getType());
279 LogicalResult ConstantOp::verify() {
283 "hw.constant attribute bitwidth doesn't match return type");
290 void ConstantOp::build(OpBuilder &
builder, OperationState &result,
291 const APInt &value) {
294 auto attr =
builder.getIntegerAttr(type, value);
295 return build(
builder, result, type, attr);
300 void ConstantOp::build(OpBuilder &
builder, OperationState &result,
302 return build(
builder, result, value.getType(), value);
309 void ConstantOp::build(OpBuilder &
builder, OperationState &result, Type type,
311 auto numBits = type.cast<IntegerType>().
getWidth();
312 build(
builder, result, APInt(numBits, (uint64_t)value,
true));
316 function_ref<
void(Value, StringRef)> setNameFn) {
317 auto intTy = getType();
318 auto intCst = getValue();
321 if (intTy.cast<IntegerType>().getWidth() == 1)
322 return setNameFn(getResult(), intCst.isZero() ?
"false" :
"true");
325 SmallVector<char, 32> specialNameBuffer;
326 llvm::raw_svector_ostream specialName(specialNameBuffer);
327 specialName <<
'c' << intCst <<
'_' << intTy;
328 setNameFn(getResult(), specialName.str());
331 OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) {
332 assert(adaptor.getOperands().empty() &&
"constant has no operands");
333 return getValueAttr();
344 ArrayRef<StringRef> ignoredAttrs = {}) {
345 auto names = op.getAttributeNames();
346 llvm::SmallDenseSet<StringRef> nameSet;
347 nameSet.reserve(names.size() + ignoredAttrs.size());
348 nameSet.insert(names.begin(), names.end());
349 nameSet.insert(ignoredAttrs.begin(), ignoredAttrs.end());
350 return llvm::any_of(op->getAttrs(), [&](
auto namedAttr) {
351 return !nameSet.contains(namedAttr.getName());
357 auto nameAttr = (*this)->getAttrOfType<StringAttr>(
"name");
358 if (nameAttr && !nameAttr.getValue().empty())
359 setNameFn(getResult(), nameAttr.getValue());
362 std::optional<size_t> WireOp::getTargetResultIndex() {
return 0; }
364 OpFoldResult WireOp::fold(FoldAdaptor adaptor) {
373 LogicalResult WireOp::canonicalize(WireOp wire, PatternRewriter &rewriter) {
379 if (wire.getInnerSymAttr())
384 if (
auto *inputOp = wire.getInput().getDefiningOp())
386 rewriter.modifyOpInPlace(inputOp,
387 [&] { inputOp->setAttr(
"sv.namehint", name); });
389 rewriter.replaceOp(wire, wire.getInput());
399 if (
auto typeAlias = type.dyn_cast<TypeAliasType>())
400 type = typeAlias.getCanonicalType();
402 if (
auto structType = type.dyn_cast<StructType>()) {
403 auto arrayAttr = attr.dyn_cast<ArrayAttr>();
405 return op->emitOpError(
"expected array attribute for constant of type ")
407 if (structType.getElements().size() != arrayAttr.size())
408 return op->emitOpError(
"array attribute (")
409 << arrayAttr.size() <<
") has wrong size for struct constant ("
410 << structType.getElements().size() <<
")";
412 for (
auto [attr, fieldInfo] :
413 llvm::zip(arrayAttr.getValue(), structType.getElements())) {
417 }
else if (
auto arrayType = type.dyn_cast<ArrayType>()) {
418 auto arrayAttr = attr.dyn_cast<ArrayAttr>();
420 return op->emitOpError(
"expected array attribute for constant of type ")
422 if (arrayType.getNumElements() != arrayAttr.size())
423 return op->emitOpError(
"array attribute (")
424 << arrayAttr.size() <<
") has wrong size for array constant ("
425 << arrayType.getNumElements() <<
")";
428 for (
auto attr : arrayAttr.getValue()) {
432 }
else if (
auto arrayType = type.dyn_cast<UnpackedArrayType>()) {
433 auto arrayAttr = attr.dyn_cast<ArrayAttr>();
435 return op->emitOpError(
"expected array attribute for constant of type ")
438 if (arrayType.getNumElements() != arrayAttr.size())
439 return op->emitOpError(
"array attribute (")
441 <<
") has wrong size for unpacked array constant ("
442 << arrayType.getNumElements() <<
")";
444 for (
auto attr : arrayAttr.getValue()) {
448 }
else if (
auto enumType = type.dyn_cast<EnumType>()) {
449 auto stringAttr = attr.dyn_cast<StringAttr>();
451 return op->emitOpError(
"expected string attribute for constant of type ")
453 }
else if (
auto intType = type.dyn_cast<IntegerType>()) {
455 auto intAttr = attr.dyn_cast<IntegerAttr>();
457 return op->emitOpError(
"expected integer attribute for constant of type ")
460 if (intAttr.getValue().getBitWidth() != intType.getWidth())
461 return op->emitOpError(
"hw.constant attribute bitwidth "
462 "doesn't match return type");
464 return op->emitOpError(
"unknown element type") << type;
469 LogicalResult AggregateConstantOp::verify() {
473 OpFoldResult AggregateConstantOp::fold(FoldAdaptor) {
return getFieldsAttr(); }
481 if (p.parseType(resultType) || p.parseEqual() ||
482 p.parseAttribute(value, resultType))
489 p << resultType <<
" = ";
490 p.printAttributeWithoutType(value);
493 LogicalResult ParamValueOp::verify() {
499 OpFoldResult ParamValueOp::fold(FoldAdaptor adaptor) {
500 assert(adaptor.getOperands().empty() &&
"hw.param.value has no operands");
501 return getValueAttr();
510 return isa<HWModuleLike, InstanceOp>(moduleOrInstance);
516 return TypeSwitch<Operation *, FunctionType>(moduleOrInstance)
517 .Case<InstanceOp, InstanceChoiceOp>([](
auto instance) {
518 SmallVector<Type>
inputs(instance->getOperandTypes());
519 SmallVector<Type> results(instance->getResultTypes());
523 [](
auto mod) {
return mod.getHWModuleType().getFuncType(); })
524 .Default([](Operation *op) {
525 return cast<mlir::FunctionOpInterface>(op)
527 .cast<FunctionType>();
535 auto nameAttr = module->getAttrOfType<StringAttr>(
"verilogName");
539 return module->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
542 template <
typename ModuleTy>
546 ArrayRef<NamedAttribute> attributes, StringAttr comment) {
547 using namespace mlir::function_interface_impl;
550 result.addAttribute(SymbolTable::getSymbolAttrName(), name);
552 SmallVector<Attribute> perPortAttrs;
553 SmallVector<ModulePort> portTypes;
555 for (
auto elt : ports) {
556 portTypes.push_back(elt);
557 llvm::SmallVector<NamedAttribute> portAttrs;
559 llvm::copy(elt.attrs, std::back_inserter(portAttrs));
560 perPortAttrs.push_back(
builder.getDictionaryAttr(portAttrs));
565 parameters =
builder.getArrayAttr({});
569 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name),
571 result.addAttribute(
"per_port_attrs",
573 result.addAttribute(
"parameters", parameters);
575 comment =
builder.getStringAttr(
"");
576 result.addAttribute(
"comment", comment);
577 result.addAttributes(attributes);
583 MLIRContext *context, ArrayRef<std::pair<unsigned, PortInfo>> insertArgs,
584 ArrayRef<unsigned> removeArgs, ArrayRef<Attribute> oldArgNames,
585 ArrayRef<Type> oldArgTypes, ArrayRef<Attribute> oldArgAttrs,
586 ArrayRef<Location> oldArgLocs, SmallVector<Attribute> &newArgNames,
587 SmallVector<Type> &newArgTypes, SmallVector<Attribute> &newArgAttrs,
588 SmallVector<Location> &newArgLocs, Block *body =
nullptr) {
593 assert(llvm::is_sorted(insertArgs,
594 [](
auto &a,
auto &b) {
return a.first < b.first; }) &&
595 "insertArgs must be in ascending order");
596 assert(llvm::is_sorted(removeArgs, [](
auto &a,
auto &b) {
return a < b; }) &&
597 "removeArgs must be in ascending order");
600 auto oldArgCount = oldArgTypes.size();
601 auto newArgCount = oldArgCount + insertArgs.size() - removeArgs.size();
602 assert((
int)newArgCount >= 0);
604 newArgNames.reserve(newArgCount);
605 newArgTypes.reserve(newArgCount);
606 newArgAttrs.reserve(newArgCount);
607 newArgLocs.reserve(newArgCount);
613 BitVector erasedIndices;
615 erasedIndices.resize(oldArgCount + insertArgs.size());
617 for (
unsigned argIdx = 0, idx = 0; argIdx <= oldArgCount; ++argIdx, ++idx) {
619 while (!insertArgs.empty() && insertArgs[0].first == argIdx) {
620 auto port = insertArgs[0].second;
624 auto sym = port.getSym();
626 (sym && !sym.empty())
629 newArgNames.push_back(port.name);
630 newArgTypes.push_back(port.type);
631 newArgAttrs.push_back(attr);
632 insertArgs = insertArgs.drop_front();
633 LocationAttr loc = port.loc ? port.loc : unknownLoc;
634 newArgLocs.push_back(loc);
636 body->insertArgument(idx++, port.type, loc);
638 if (argIdx == oldArgCount)
642 bool removed =
false;
643 while (!removeArgs.empty() && removeArgs[0] == argIdx) {
644 removeArgs = removeArgs.drop_front();
650 erasedIndices.set(idx);
652 newArgNames.push_back(oldArgNames[argIdx]);
653 newArgTypes.push_back(oldArgTypes[argIdx]);
654 newArgAttrs.push_back(oldArgAttrs.empty() ? emptyDictAttr
655 : oldArgAttrs[argIdx]);
656 newArgLocs.push_back(oldArgLocs[argIdx]);
661 body->eraseArguments(erasedIndices);
663 assert(newArgNames.size() == newArgCount);
664 assert(newArgTypes.size() == newArgCount);
665 assert(newArgAttrs.size() == newArgCount);
666 assert(newArgLocs.size() == newArgCount);
680 [[deprecated]]
static void
682 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
683 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
684 ArrayRef<unsigned> removeInputs,
685 ArrayRef<unsigned> removeOutputs, Block *body =
nullptr) {
686 auto moduleOp = cast<HWModuleLike>(op);
687 auto *context = moduleOp.getContext();
690 auto oldArgNames = moduleOp.getInputNames();
691 auto oldArgTypes = moduleOp.getInputTypes();
692 auto oldArgAttrs = moduleOp.getAllInputAttrs();
693 auto oldArgLocs = moduleOp.getInputLocs();
695 auto oldResultNames = moduleOp.getOutputNames();
696 auto oldResultTypes = moduleOp.getOutputTypes();
697 auto oldResultAttrs = moduleOp.getAllOutputAttrs();
698 auto oldResultLocs = moduleOp.getOutputLocs();
701 SmallVector<Attribute> newArgNames, newResultNames;
702 SmallVector<Type> newArgTypes, newResultTypes;
703 SmallVector<Attribute> newArgAttrs, newResultAttrs;
704 SmallVector<Location> newArgLocs, newResultLocs;
707 oldArgTypes, oldArgAttrs, oldArgLocs, newArgNames,
708 newArgTypes, newArgAttrs, newArgLocs, body);
711 oldResultTypes, oldResultAttrs, oldResultLocs,
712 newResultNames, newResultTypes, newResultAttrs,
718 moduleOp.setHWModuleType(modty);
719 moduleOp.setAllInputAttrs(newArgAttrs);
720 moduleOp.setAllOutputAttrs(newResultAttrs);
722 newArgLocs.append(newResultLocs.begin(), newResultLocs.end());
723 moduleOp.setAllPortLocs(newArgLocs);
726 void HWModuleOp::build(OpBuilder &
builder, OperationState &result,
728 ArrayAttr parameters,
729 ArrayRef<NamedAttribute> attributes, StringAttr comment,
730 bool shouldEnsureTerminator) {
731 buildModule<HWModuleOp>(
builder, result, name, ports, parameters, attributes,
735 auto *bodyRegion = result.regions[0].get();
737 bodyRegion->push_back(body);
740 auto unknownLoc =
builder.getUnknownLoc();
742 auto loc = port.loc ? Location(port.loc) : unknownLoc;
743 auto type = port.type;
744 if (port.isInOut() && !type.isa<
InOutType>())
746 body->addArgument(type, loc);
750 auto unknownLocAttr = cast<LocationAttr>(unknownLoc);
751 SmallVector<Attribute> resultLocs;
753 resultLocs.push_back(port.loc ? port.loc : unknownLocAttr);
754 result.addAttribute(
"result_locs",
builder.getArrayAttr(resultLocs));
756 if (shouldEnsureTerminator)
757 HWModuleOp::ensureTerminator(*bodyRegion,
builder, result.location);
760 void HWModuleOp::build(OpBuilder &
builder, OperationState &result,
761 StringAttr name, ArrayRef<PortInfo> ports,
762 ArrayAttr parameters,
763 ArrayRef<NamedAttribute> attributes,
764 StringAttr comment) {
769 void HWModuleOp::build(OpBuilder &
builder, OperationState &odsState,
772 ArrayRef<NamedAttribute> attributes,
773 StringAttr comment) {
774 build(
builder, odsState, name, ports, parameters, attributes, comment,
776 auto *bodyRegion = odsState.regions[0].get();
777 OpBuilder::InsertionGuard guard(
builder);
779 builder.setInsertionPointToEnd(&bodyRegion->front());
782 llvm::SmallVector<Value> outputOperands = accessor.getOutputOperands();
783 builder.create<hw::OutputOp>(odsState.location, outputOperands);
786 void HWModuleOp::modifyPorts(
787 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
788 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
789 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
798 if (
auto vName = getVerilogNameAttr())
801 return (*this)->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName());
805 if (
auto vName = getVerilogNameAttr()) {
808 return (*this)->getAttrOfType<StringAttr>(
809 ::mlir::SymbolTable::getSymbolAttrName());
812 void HWModuleExternOp::build(OpBuilder &
builder, OperationState &result,
814 StringRef verilogName, ArrayAttr parameters,
815 ArrayRef<NamedAttribute> attributes) {
816 buildModule<HWModuleExternOp>(
builder, result, name, ports, parameters,
820 LocationAttr unknownLoc =
builder.getUnknownLoc();
821 SmallVector<Attribute> portLocs;
822 for (
auto elt : ports)
823 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
824 result.addAttribute(
"port_locs",
builder.getArrayAttr(portLocs));
826 if (!verilogName.empty())
827 result.addAttribute(
"verilogName",
builder.getStringAttr(verilogName));
830 void HWModuleExternOp::build(OpBuilder &
builder, OperationState &result,
831 StringAttr name, ArrayRef<PortInfo> ports,
832 StringRef verilogName, ArrayAttr parameters,
833 ArrayRef<NamedAttribute> attributes) {
838 void HWModuleExternOp::modifyPorts(
839 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
840 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
841 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
846 void HWModuleExternOp::appendOutputs(
847 ArrayRef<std::pair<StringAttr, Value>>
outputs) {}
849 void HWModuleGeneratedOp::build(OpBuilder &
builder, OperationState &result,
850 FlatSymbolRefAttr genKind, StringAttr name,
852 StringRef verilogName, ArrayAttr parameters,
853 ArrayRef<NamedAttribute> attributes) {
854 buildModule<HWModuleGeneratedOp>(
builder, result, name, ports, parameters,
857 LocationAttr unknownLoc =
builder.getUnknownLoc();
858 SmallVector<Attribute> portLocs;
859 for (
auto elt : ports)
860 portLocs.push_back(elt.loc ? elt.loc : unknownLoc);
861 result.addAttribute(
"port_locs",
builder.getArrayAttr(portLocs));
863 result.addAttribute(
"generatorKind", genKind);
864 if (!verilogName.empty())
865 result.addAttribute(
"verilogName",
builder.getStringAttr(verilogName));
868 void HWModuleGeneratedOp::build(OpBuilder &
builder, OperationState &result,
869 FlatSymbolRefAttr genKind, StringAttr name,
870 ArrayRef<PortInfo> ports, StringRef verilogName,
871 ArrayAttr parameters,
872 ArrayRef<NamedAttribute> attributes) {
874 parameters, attributes);
877 void HWModuleGeneratedOp::modifyPorts(
878 ArrayRef<std::pair<unsigned, PortInfo>> insertInputs,
879 ArrayRef<std::pair<unsigned, PortInfo>> insertOutputs,
880 ArrayRef<unsigned> eraseInputs, ArrayRef<unsigned> eraseOutputs) {
885 void HWModuleGeneratedOp::appendOutputs(
886 ArrayRef<std::pair<StringAttr, Value>>
outputs) {}
888 static bool hasAttribute(StringRef name, ArrayRef<NamedAttribute> attrs) {
889 for (
auto &argAttr : attrs)
890 if (argAttr.getName() == name)
895 template <
typename ModuleTy>
897 OperationState &result) {
899 using namespace mlir::function_interface_impl;
900 auto builder = parser.getBuilder();
901 auto loc = parser.getCurrentLocation();
904 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
908 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
913 FlatSymbolRefAttr kindAttr;
914 if constexpr (std::is_same_v<ModuleTy, HWModuleGeneratedOp>) {
915 if (parser.parseComma() ||
916 parser.parseAttribute(kindAttr,
"generatorKind", result.attributes)) {
922 ArrayAttr parameters;
926 SmallVector<module_like_impl::PortParse> ports;
932 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
936 parser.emitError(loc,
"explicit `parameters` attributes not allowed");
940 result.addAttribute(
"parameters", parameters);
941 result.addAttribute(ModuleTy::getModuleTypeAttrName(result.name), modType);
945 SmallVector<Attribute> attrs;
946 for (
auto &port : ports)
947 attrs.push_back(port.attrs ? port.attrs :
builder.getDictionaryAttr({}));
949 auto nonEmptyAttrsFn = [](Attribute attr) {
950 return attr && !cast<DictionaryAttr>(attr).empty();
952 if (llvm::any_of(attrs, nonEmptyAttrsFn))
953 result.addAttribute(ModuleTy::getPerPortAttrsAttrName(result.name),
957 auto unknownLoc =
builder.getUnknownLoc();
958 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
959 return attr && cast<Location>(attr) != unknownLoc;
961 SmallVector<Attribute> locs;
962 StringAttr portLocsAttrName;
963 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>) {
966 portLocsAttrName = ModuleTy::getResultLocsAttrName(result.name);
967 for (
auto &port : ports)
969 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
972 portLocsAttrName = ModuleTy::getPortLocsAttrName(result.name);
973 for (
auto &port : ports)
974 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
976 if (llvm::any_of(locs, nonEmptyLocsFn))
977 result.addAttribute(portLocsAttrName,
builder.getArrayAttr(locs));
980 SmallVector<OpAsmParser::Argument, 4> entryArgs;
981 for (
auto &port : ports)
983 entryArgs.push_back(port);
986 auto *body = result.addRegion();
987 if (std::is_same_v<ModuleTy, HWModuleOp>) {
988 if (parser.parseRegion(*body, entryArgs))
991 HWModuleOp::ensureTerminator(*body, parser.getBuilder(), result.location);
996 ParseResult HWModuleOp::parse(OpAsmParser &parser, OperationState &result) {
997 return parseHWModuleOp<HWModuleOp>(parser, result);
1000 ParseResult HWModuleExternOp::parse(OpAsmParser &parser,
1001 OperationState &result) {
1002 return parseHWModuleOp<HWModuleExternOp>(parser, result);
1005 ParseResult HWModuleGeneratedOp::parse(OpAsmParser &parser,
1006 OperationState &result) {
1007 return parseHWModuleOp<HWModuleGeneratedOp>(parser, result);
1011 if (
auto mod = dyn_cast<HWModuleLike>(op))
1012 return mod.getHWModuleType().getFuncType();
1013 return cast<mlir::FunctionOpInterface>(op)
1015 .cast<FunctionType>();
1018 template <
typename ModuleTy>
1022 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
1023 if (
auto visibility = mod.getOperation()->template getAttrOfType<StringAttr>(
1024 visibilityAttrName))
1025 p << visibility.getValue() <<
' ';
1028 p.printSymbolName(SymbolTable::getSymbolName(mod.getOperation()).getValue());
1029 if (
auto gen = dyn_cast<HWModuleGeneratedOp>(mod.getOperation())) {
1031 p.printSymbolName(gen.getGeneratorKind());
1039 SmallVector<StringRef, 3> omittedAttrs;
1040 if (isa<HWModuleGeneratedOp>(mod.getOperation()))
1041 omittedAttrs.push_back(
"generatorKind");
1042 if constexpr (std::is_same_v<ModuleTy, HWModuleOp>)
1043 omittedAttrs.push_back(mod.getResultLocsAttrName());
1045 omittedAttrs.push_back(mod.getPortLocsAttrName());
1046 omittedAttrs.push_back(mod.getModuleTypeAttrName());
1047 omittedAttrs.push_back(mod.getPerPortAttrsAttrName());
1048 omittedAttrs.push_back(mod.getParametersAttrName());
1049 omittedAttrs.push_back(visibilityAttrName);
1051 mod.getOperation()->template getAttrOfType<StringAttr>(
"comment"))
1052 if (cmt.getValue().empty())
1053 omittedAttrs.push_back(
"comment");
1055 mlir::function_interface_impl::printFunctionAttributes(p, mod.getOperation(),
1059 void HWModuleExternOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1060 void HWModuleGeneratedOp::print(OpAsmPrinter &p) {
printModuleOp(p, *
this); }
1062 void HWModuleOp::print(OpAsmPrinter &p) {
1066 Region &body = getBody();
1067 if (!body.empty()) {
1069 p.printRegion(body,
false,
1075 assert(isa<HWModuleLike>(module) &&
1076 "verifier hook should only be called on modules");
1078 SmallPtrSet<Attribute, 4> paramNames;
1081 for (
auto param : module->getAttrOfType<ArrayAttr>(
"parameters")) {
1082 auto paramAttr = param.cast<ParamDeclAttr>();
1086 if (!paramNames.insert(paramAttr.getName()).second)
1087 return module->emitOpError(
"parameter ")
1088 << paramAttr <<
" has the same name as a previous parameter";
1091 auto value = paramAttr.getValue();
1095 auto typedValue = value.dyn_cast<TypedAttr>();
1097 return module->emitOpError(
"parameter ")
1098 << paramAttr <<
" should have a typed value; has value " << value;
1100 if (typedValue.getType() != paramAttr.getType())
1101 return module->emitOpError(
"parameter ")
1102 << paramAttr <<
" should have type " << paramAttr.getType()
1103 <<
"; has type " << typedValue.getType();
1115 LogicalResult HWModuleOp::verify() {
1120 auto *body = getBodyBlock();
1123 auto numInputs = type.getNumInputs();
// The entry block must declare exactly one argument per module input.
1124 if (body->getNumArguments() != numInputs)
// The first literal needs a trailing space. Without it the streamed count is
// fused onto the text ("...must have2 arguments...").
1125 return emitOpError(
"entry block must have ")
1126 << numInputs <<
" arguments to match module signature";
1133 std::pair<StringAttr, BlockArgument>
1134 HWModuleOp::insertInput(
unsigned index, StringAttr name, Type ty) {
1138 for (
auto port : ports)
1139 ns.
newName(port.name.getValue());
1142 Block *body = getBodyBlock();
1146 port.
name = nameAttr;
1153 return {nameAttr, body->getArgument(index)};
1156 void HWModuleOp::insertOutputs(
unsigned index,
1157 ArrayRef<std::pair<StringAttr, Value>>
outputs) {
1159 auto output = cast<OutputOp>(getBodyBlock()->getTerminator());
1160 assert(index <= output->getNumOperands() &&
"invalid output index");
1163 SmallVector<std::pair<unsigned, PortInfo>> indexedNewPorts;
1164 for (
auto &[name, value] :
outputs) {
1168 port.
type = value.getType();
1169 indexedNewPorts.emplace_back(index, port);
1175 for (
auto &[name, value] :
outputs)
1176 output->insertOperands(index++, value);
1179 void HWModuleOp::appendOutputs(ArrayRef<std::pair<StringAttr, Value>>
outputs) {
1180 return insertOutputs(getNumOutputPorts(),
outputs);
1183 void HWModuleOp::getAsmBlockArgumentNames(mlir::Region ®ion,
1188 void HWModuleExternOp::getAsmBlockArgumentNames(
1193 template <
typename ModTy>
1195 auto locs = module.getPortLocs();
1197 SmallVector<Location> retval;
1198 retval.reserve(locs->size());
1199 for (
auto l : *locs)
1200 retval.push_back(cast<Location>(l));
1202 assert(!locs->size() || locs->size() == module.getNumPorts());
1205 return SmallVector<Location>(module.getNumPorts(),
1210 SmallVector<Location> portLocs;
1212 auto resultLocs = getResultLocsAttr();
1213 unsigned inputCount = 0;
1216 auto *body = getBodyBlock();
1217 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1218 if (modType.isOutput(i)) {
1219 auto loc = resultLocs
1221 resultLocs.getValue()[portLocs.size() - inputCount])
1223 portLocs.push_back(loc);
1225 auto loc = body ? body->getArgument(inputCount).getLoc() : unknownLoc;
1226 portLocs.push_back(loc);
1241 void HWModuleOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1242 SmallVector<Attribute> resultLocs;
1243 unsigned inputCount = 0;
1245 auto *body = getBodyBlock();
1246 for (
unsigned i = 0, e =
getNumPorts(); i < e; ++i) {
1247 if (modType.isOutput(i))
1248 resultLocs.push_back(locs[i]);
1250 body->getArgument(inputCount++).setLoc(cast<Location>(locs[i]));
1255 void HWModuleExternOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1259 void HWModuleGeneratedOp::setAllPortLocsAttrs(ArrayRef<Attribute> locs) {
1263 template <
typename ModTy>
1265 auto numInputs = module.getNumInputPorts();
1266 SmallVector<Attribute> argNames(names.begin(), names.begin() + numInputs);
1267 SmallVector<Attribute> resNames(names.begin() + numInputs, names.end());
1268 auto oldType = module.getModuleType();
1269 SmallVector<ModulePort> newPorts(oldType.getPorts().begin(),
1270 oldType.getPorts().end());
1271 for (
size_t i = 0UL, e = newPorts.size(); i != e; ++i)
1272 newPorts[i].name = cast<StringAttr>(names[i]);
1274 module.setModuleType(newType);
1289 ArrayRef<Attribute> HWModuleOp::getAllPortAttrs() {
1290 auto attrs = getPerPortAttrs();
1291 if (attrs && !attrs->empty())
1292 return attrs->getValue();
1296 ArrayRef<Attribute> HWModuleExternOp::getAllPortAttrs() {
1297 auto attrs = getPerPortAttrs();
1298 if (attrs && !attrs->empty())
1299 return attrs->getValue();
1303 ArrayRef<Attribute> HWModuleGeneratedOp::getAllPortAttrs() {
1304 auto attrs = getPerPortAttrs();
1305 if (attrs && !attrs->empty())
1306 return attrs->getValue();
1310 void HWModuleOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1311 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1314 void HWModuleExternOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1315 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1318 void HWModuleGeneratedOp::setAllPortAttrs(ArrayRef<Attribute> attrs) {
1319 setPerPortAttrsAttr(
arrayOrEmpty(getContext(), attrs));
1322 void HWModuleOp::removeAllPortAttrs() {
1326 void HWModuleExternOp::removeAllPortAttrs() {
1330 void HWModuleGeneratedOp::removeAllPortAttrs() {
1336 template <
typename ModTy>
1338 auto argAttrs = mod.getAllInputAttrs();
1339 auto resAttrs = mod.getAllOutputAttrs();
1341 unsigned newNumArgs = type.getNumInputs();
1342 unsigned newNumResults = type.getNumOutputs();
1345 argAttrs.resize(newNumArgs, emptyDict);
1346 resAttrs.resize(newNumResults, emptyDict);
1348 SmallVector<Attribute> attrs;
1349 attrs.append(argAttrs.begin(), argAttrs.end());
1350 attrs.append(resAttrs.begin(), resAttrs.end());
1353 return mod.removeAllPortAttrs();
1354 mod.setAllPortAttrs(attrs);
1371 Operation *HWModuleGeneratedOp::getGeneratorKindOp() {
1372 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
1373 return topLevelModuleOp.lookupSymbol(getGeneratorKind());
1377 HWModuleGeneratedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1378 auto *referencedKind =
1379 symbolTable.lookupNearestSymbolFrom(*
this, getGeneratorKindAttr());
1381 if (referencedKind ==
nullptr)
1382 return emitError(
"Cannot find generator definition '")
1383 << getGeneratorKind() <<
"'";
1385 if (!isa<HWGeneratorSchemaOp>(referencedKind))
1386 return emitError(
"Symbol resolved to '")
1387 << referencedKind->getName()
1388 <<
"' which is not a HWGeneratorSchemaOp";
1390 auto referencedKindOp = dyn_cast<HWGeneratorSchemaOp>(referencedKind);
1391 auto paramRef = referencedKindOp.getRequiredAttrs();
1392 auto dict = (*this)->getAttrDictionary();
1393 for (
auto str : paramRef) {
1394 auto strAttr = str.dyn_cast<StringAttr>();
1396 return emitError(
"Unknown attribute type, expected a string");
1397 if (!dict.get(strAttr.getValue()))
1398 return emitError(
"Missing attribute '") << strAttr.getValue() <<
"'";
1404 LogicalResult HWModuleGeneratedOp::verify() {
1408 void HWModuleGeneratedOp::getAsmBlockArgumentNames(
1413 LogicalResult HWModuleOp::verifyBody() {
return success(); }
1415 template <
typename ModuleTy>
1417 auto modTy = mod.getHWModuleType();
1419 SmallVector<PortInfo> retval;
1420 auto locs = mod.getAllPortLocs();
1421 for (
unsigned i = 0, e = modTy.getNumPorts(); i < e; ++i) {
1422 LocationAttr loc = locs[i];
1423 DictionaryAttr attrs =
1424 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(i));
1427 retval.push_back({modTy.getPorts()[i],
1428 modTy.isOutput(i) ? modTy.getOutputIdForPortId(i)
1429 : modTy.getInputIdForPortId(i),
1435 template <
typename ModuleTy>
1437 auto modTy = mod.getHWModuleType();
1439 LocationAttr loc = mod.getPortLoc(idx);
1440 DictionaryAttr attrs =
1441 dyn_cast_or_null<DictionaryAttr>(mod.getPortAttrs(idx));
1444 return {modTy.getPorts()[idx],
1445 modTy.isOutput(idx) ? modTy.getOutputIdForPortId(idx)
1446 : modTy.getInputIdForPortId(idx),
1455 void InstanceOp::build(OpBuilder &
builder, OperationState &result,
1456 Operation *module, StringAttr name,
1457 ArrayRef<Value>
inputs, ArrayAttr parameters,
1458 InnerSymAttr innerSym) {
1460 parameters =
builder.getArrayAttr({});
1462 auto mod = cast<hw::HWModuleLike>(module);
1463 auto argNames =
builder.getArrayAttr(mod.getInputNames());
1464 auto resultNames =
builder.getArrayAttr(mod.getOutputNames());
1469 ModuleType modType = mod.getHWModuleType();
1471 parameters, result.location,
false);
// Prefer the parameter-resolved type when resolution succeeded. Otherwise
// keep the module's declared type.
1472 if (succeeded(resolvedModType))
1473 modType = *resolvedModType;
// Derive result types from `modType`, never from `resolvedModType`. The
// latter may hold a failure here (see the succeeded() check above), and
// dereferencing it in that state is invalid.
1474 FunctionType funcType = modType.getFuncType();
1475 build(
builder, result, funcType.getResults(), name,
1477 argNames, resultNames, parameters, innerSym);
1480 std::optional<size_t> InstanceOp::getTargetResultIndex() {
1482 return std::nullopt;
1485 LogicalResult InstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1487 *
this, getModuleNameAttr(), getInputs(), getResultTypes(), getArgNames(),
1488 getResultNames(), getParameters(), symbolTable);
1491 LogicalResult InstanceOp::verify() {
1492 auto module = (*this)->getParentOfType<
HWModuleOp>();
1496 auto moduleParameters = module->getAttrOfType<ArrayAttr>(
"parameters");
1498 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1499 auto diag = emitOpError();
1501 diag.attachNote(module->getLoc()) <<
"module declared here";
1504 getParameters(), moduleParameters, emitError);
1507 ParseResult InstanceOp::parse(OpAsmParser &parser, OperationState &result) {
1508 StringAttr instanceNameAttr;
1509 InnerSymAttr innerSym;
1510 FlatSymbolRefAttr moduleNameAttr;
1511 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1512 SmallVector<Type, 1> inputsTypes, allResultTypes;
1513 ArrayAttr argNames, resultNames, parameters;
1514 auto noneType = parser.getBuilder().getType<NoneType>();
1516 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
1520 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1523 if (parser.parseCustomAttributeWithFallback(innerSym))
1528 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1529 if (parser.parseAttribute(moduleNameAttr, noneType,
"moduleName",
1530 result.attributes) ||
1531 parser.getCurrentLocation(¶metersLoc) ||
1534 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1536 parser.parseArrow() ||
1538 parser.parseOptionalAttrDict(result.attributes)) {
1542 result.addAttribute(
"argNames", argNames);
1543 result.addAttribute(
"resultNames", resultNames);
1544 result.addAttribute(
"parameters", parameters);
1545 result.addTypes(allResultTypes);
1549 void InstanceOp::print(OpAsmPrinter &p) {
1551 p.printAttributeWithoutType(getInstanceNameAttr());
1552 if (
auto attr = getInnerSymAttr()) {
1557 p.printAttributeWithoutType(getModuleNameAttr());
1564 p.printOptionalAttrDict(
1565 (*this)->getAttrs(),
1567 InnerSymbolTable::getInnerSymbolAttrName(),
"moduleName",
1568 "argNames",
"resultNames",
"parameters"});
1575 std::optional<size_t> InstanceChoiceOp::getTargetResultIndex() {
1577 return std::nullopt;
1581 InstanceChoiceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
1582 for (Attribute name : getModuleNamesAttr()) {
1584 *
this, name.cast<FlatSymbolRefAttr>(), getInputs(),
1585 getResultTypes(), getArgNames(), getResultNames(), getParameters(),
1593 LogicalResult InstanceChoiceOp::verify() {
1594 auto module = (*this)->getParentOfType<
HWModuleOp>();
1598 auto moduleParameters = module->getAttrOfType<ArrayAttr>(
"parameters");
1600 [&](
const std::function<bool(InFlightDiagnostic &)> &fn) {
1601 auto diag = emitOpError();
1603 diag.attachNote(module->getLoc()) <<
"module declared here";
1606 getParameters(), moduleParameters, emitError);
1609 ParseResult InstanceChoiceOp::parse(OpAsmParser &parser,
1610 OperationState &result) {
1611 StringAttr optionNameAttr;
1612 StringAttr instanceNameAttr;
1613 InnerSymAttr innerSym;
1614 SmallVector<Attribute> moduleNames;
1615 SmallVector<Attribute> caseNames;
1616 SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands;
1617 SmallVector<Type, 1> inputsTypes, allResultTypes;
1618 ArrayAttr argNames, resultNames, parameters;
1619 auto noneType = parser.getBuilder().getType<NoneType>();
1621 if (parser.parseAttribute(instanceNameAttr, noneType,
"instanceName",
1625 if (succeeded(parser.parseOptionalKeyword(
"sym"))) {
1628 if (parser.parseCustomAttributeWithFallback(innerSym))
1633 if (parser.parseKeyword(
"option") ||
1634 parser.parseAttribute(optionNameAttr, noneType,
"optionName",
1638 FlatSymbolRefAttr defaultModuleName;
1639 if (parser.parseAttribute(defaultModuleName))
1641 moduleNames.push_back(defaultModuleName);
1643 while (succeeded(parser.parseOptionalKeyword(
"or"))) {
1644 FlatSymbolRefAttr moduleName;
1645 StringAttr targetName;
1646 if (parser.parseAttribute(moduleName) ||
1647 parser.parseOptionalKeyword(
"if") || parser.parseAttribute(targetName))
1649 moduleNames.push_back(moduleName);
1650 caseNames.push_back(targetName);
1653 llvm::SMLoc parametersLoc, inputsOperandsLoc;
1654 if (parser.getCurrentLocation(¶metersLoc) ||
1657 parser.resolveOperands(inputsOperands, inputsTypes, inputsOperandsLoc,
1659 parser.parseArrow() ||
1661 parser.parseOptionalAttrDict(result.attributes)) {
1665 result.addAttribute(
"moduleNames",
1667 result.addAttribute(
"caseNames",
1669 result.addAttribute(
"argNames", argNames);
1670 result.addAttribute(
"resultNames", resultNames);
1671 result.addAttribute(
"parameters", parameters);
1672 result.addTypes(allResultTypes);
1676 void InstanceChoiceOp::print(OpAsmPrinter &p) {
1678 p.printAttributeWithoutType(getInstanceNameAttr());
1679 if (
auto attr = getInnerSymAttr()) {
1683 p <<
" option " << getOptionNameAttr() <<
' ';
1685 auto moduleNames = getModuleNamesAttr();
1686 auto caseNames = getCaseNamesAttr();
1687 assert(moduleNames.size() == caseNames.size() + 1);
1689 p.printAttributeWithoutType(moduleNames[0]);
1690 for (
size_t i = 0, n = caseNames.size(); i < n; ++i) {
1692 p.printAttributeWithoutType(moduleNames[i + 1]);
1694 p.printAttributeWithoutType(caseNames[i]);
1703 p.printOptionalAttrDict(
1704 (*this)->getAttrs(),
1706 InnerSymbolTable::getInnerSymbolAttrName(),
1707 "moduleNames",
"caseNames",
"argNames",
"resultNames",
1708 "parameters",
"optionName"});
1711 ArrayAttr InstanceChoiceOp::getReferencedModuleNamesAttr() {
1712 SmallVector<Attribute> moduleNames;
1713 for (Attribute attr : getModuleNamesAttr()) {
1714 moduleNames.push_back(attr.cast<FlatSymbolRefAttr>().getAttr());
1724 LogicalResult OutputOp::verify() {
1728 if (
auto mod = dyn_cast<HWModuleOp>((*this)->getParentOp()))
1729 modType = mod.getHWModuleType();
1731 emitOpError(
"must have a module parent");
1734 auto modResults = modType.getOutputTypes();
1735 OperandRange outputValues = getOperands();
1736 if (modResults.size() != outputValues.size()) {
1737 emitOpError(
"must have same number of operands as region results.");
1742 for (
size_t i = 0, e = modResults.size(); i < e; ++i) {
1743 if (modResults[i] != outputValues[i].getType()) {
1744 emitOpError(
"output types must match module. In "
1746 << i <<
", expected " << modResults[i] <<
", but got "
1747 << outputValues[i].getType() <<
".";
1762 if (p.parseType(type))
1763 return p.emitError(p.getCurrentLocation(),
"Expected type");
1764 auto arrType = type_dyn_cast<ArrayType>(type);
1766 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
1768 unsigned idxWidth = llvm::Log2_64_Ceil(arrType.getNumElements());
1775 p.printType(srcType);
1778 ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
1779 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
1780 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
1783 if (parser.parseOperandList(operands) ||
1784 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
1785 parser.parseType(elemType))
1788 if (operands.size() == 0)
1789 return parser.emitError(inputOperandsLoc,
1790 "Cannot construct an array of length 0");
1793 for (
auto operand : operands)
1794 if (parser.resolveOperand(operand, elemType, result.operands))
1799 void ArrayCreateOp::print(OpAsmPrinter &p) {
1801 p.printOperands(getInputs());
1802 p.printOptionalAttrDict((*this)->getAttrs());
1803 p <<
" : " << getInputs()[0].getType();
1806 void ArrayCreateOp::build(OpBuilder &b, OperationState &state,
1807 ValueRange values) {
1808 assert(values.size() > 0 &&
"Cannot build array of zero elements");
1809 Type elemType = values[0].getType();
1812 [elemType](Value v) ->
bool {
return v.getType() == elemType; }) &&
1813 "All values must have same type.");
1814 build(b, state,
ArrayType::get(elemType, values.size()), values);
1817 LogicalResult ArrayCreateOp::verify() {
1818 unsigned returnSize = getType().cast<ArrayType>().getNumElements();
1819 if (getInputs().size() != returnSize)
1824 OpFoldResult ArrayCreateOp::fold(FoldAdaptor adaptor) {
1825 if (llvm::any_of(adaptor.getInputs(), [](Attribute attr) { return !attr; }))
1837 auto baseValue = constBase.getValue();
1838 auto indexValue = constIndex.getValue();
1840 unsigned bits = baseValue.getBitWidth();
1841 assert(bits == indexValue.getBitWidth() &&
"mismatched widths");
1843 if (bits < 64 && offset >= (1ull << bits))
1846 APInt baseExt = baseValue.zextOrTrunc(bits + 1);
1847 APInt indexExt = indexValue.zextOrTrunc(bits + 1);
1848 return baseExt + offset == indexExt;
1856 PatternRewriter &rewriter) {
1858 auto arrayTy = hw::type_cast<ArrayType>(op.getType());
1859 if (arrayTy.getNumElements() <= 1)
1861 auto elemTy = arrayTy.getElementType();
1870 SmallVector<Chunk> chunks;
1871 for (Value value : llvm::reverse(op.getInputs())) {
1876 Value input =
get.getInput();
1877 Value index =
get.getIndex();
1878 if (!chunks.empty()) {
1879 auto &c = *chunks.rbegin();
1880 if (c.input ==
get.getInput() &&
isOffset(c.index, index, c.size)) {
1886 chunks.push_back(Chunk{input, index, 1});
1890 if (chunks.size() == 1) {
1891 auto &chunk = chunks[0];
1892 rewriter.replaceOp(op, rewriter.createOrFold<
ArraySliceOp>(
1893 op.getLoc(), arrayTy, chunk.input, chunk.index));
1899 if (chunks.size() * 2 < arrayTy.getNumElements()) {
1900 SmallVector<Value> slices;
1901 for (
auto &chunk : llvm::reverse(chunks)) {
1904 op.getLoc(), sliceTy, chunk.input, chunk.index));
1906 rewriter.replaceOpWithNewOp<
ArrayConcatOp>(op, arrayTy, slices);
1914 PatternRewriter &rewriter) {
1920 Value ArrayCreateOp::getUniformElement() {
1921 if (!getInputs().
empty() && llvm::all_equal(getInputs()))
1922 return getInputs()[0];
1927 auto idxOp = dyn_cast_or_null<ConstantOp>(value.getDefiningOp());
1929 return std::nullopt;
1930 APInt idxAttr = idxOp.getValue();
1931 if (idxAttr.getBitWidth() > 64)
1932 return std::nullopt;
1933 return idxAttr.getLimitedValue();
1936 LogicalResult ArraySliceOp::verify() {
1937 unsigned inputSize =
1938 type_cast<ArrayType>(getInput().getType()).getNumElements();
1939 if (llvm::Log2_64_Ceil(inputSize) !=
1940 getLowIndex().getType().getIntOrFloatBitWidth())
1942 "ArraySlice: index width must match clog2 of array size");
1946 OpFoldResult ArraySliceOp::fold(FoldAdaptor adaptor) {
1948 if (getType() == getInput().getType())
1953 LogicalResult ArraySliceOp::canonicalize(
ArraySliceOp op,
1954 PatternRewriter &rewriter) {
1955 auto sliceTy = hw::type_cast<ArrayType>(op.getType());
1956 auto elemTy = sliceTy.getElementType();
1957 uint64_t sliceSize = sliceTy.getNumElements();
1961 if (sliceSize == 1) {
1963 auto get = rewriter.create<
ArrayGetOp>(op.getLoc(), op.getInput(),
1965 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
1974 auto inputOp = op.getInput().getDefiningOp();
1975 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
1977 if (inputSlice == op)
1980 auto inputIndex = inputSlice.getLowIndex();
1982 if (!inputOffsetOpt)
1985 uint64_t offset = *offsetOpt + *inputOffsetOpt;
1987 rewriter.create<
ConstantOp>(op.getLoc(), inputIndex.getType(), offset);
1988 rewriter.replaceOpWithNewOp<
ArraySliceOp>(op, op.getType(),
1989 inputSlice.getInput(), lowIndex);
1993 if (
auto inputCreate = dyn_cast_or_null<ArrayCreateOp>(inputOp)) {
1995 auto inputs = inputCreate.getInputs();
1997 uint64_t begin =
inputs.size() - *offsetOpt - sliceSize;
1998 rewriter.replaceOpWithNewOp<
ArrayCreateOp>(op, op.getType(),
1999 inputs.slice(begin, sliceSize));
2003 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2005 SmallVector<Value> chunks;
2006 uint64_t sliceStart = *offsetOpt;
2007 for (
auto input : llvm::reverse(inputConcat.getInputs())) {
2009 uint64_t inputSize =
2010 hw::type_cast<ArrayType>(input.getType()).getNumElements();
2011 if (inputSize == 0 || inputSize <= sliceStart) {
2012 sliceStart -= inputSize;
2017 uint64_t cutEnd = std::min(inputSize, sliceStart + sliceSize);
2018 uint64_t cutSize = cutEnd - sliceStart;
2019 assert(cutSize != 0 &&
"slice cannot be empty");
2021 if (cutSize == inputSize) {
2023 assert(sliceStart == 0 &&
"invalid cut size");
2024 chunks.push_back(input);
2027 unsigned width = inputSize == 1 ? 1 : llvm::Log2_64_Ceil(inputSize);
2029 op.getLoc(), rewriter.getIntegerType(
width), sliceStart);
2035 sliceSize -= cutSize;
2040 assert(chunks.size() > 0 &&
"missing sliced items");
2041 if (chunks.size() == 1)
2042 rewriter.replaceOp(op, chunks[0]);
2045 op, llvm::to_vector(llvm::reverse(chunks)));
2056 SmallVectorImpl<Type> &inputTypes,
2059 uint64_t resultSize = 0;
2061 auto parseElement = [&]() -> ParseResult {
2063 if (p.parseType(ty))
2065 auto arrTy = type_dyn_cast<ArrayType>(ty);
2067 return p.emitError(p.getCurrentLocation(),
"Expected !hw.array type");
2068 if (elemType && elemType != arrTy.getElementType())
2069 return p.emitError(p.getCurrentLocation(),
"Expected array element type ")
2072 elemType = arrTy.getElementType();
2073 inputTypes.push_back(ty);
2074 resultSize += arrTy.getNumElements();
2078 if (p.parseCommaSeparatedList(parseElement))
2086 TypeRange inputTypes, Type resultType) {
2087 llvm::interleaveComma(inputTypes, p, [&p](Type t) { p << t; });
2090 void ArrayConcatOp::build(OpBuilder &b, OperationState &state,
2091 ValueRange values) {
2092 assert(!values.empty() &&
"Cannot build array of zero elements");
2093 ArrayType arrayTy = values[0].getType().cast<ArrayType>();
2094 Type elemTy = arrayTy.getElementType();
2095 assert(llvm::all_of(values,
2096 [elemTy](Value v) ->
bool {
2097 return v.getType().isa<ArrayType>() &&
2098 v.getType().cast<ArrayType>().getElementType() ==
2101 "All values must be of ArrayType with the same element type.");
2103 uint64_t resultSize = 0;
2104 for (Value val : values)
2105 resultSize += val.getType().cast<ArrayType>().getNumElements();
2109 OpFoldResult ArrayConcatOp::fold(FoldAdaptor adaptor) {
2110 auto inputs = adaptor.getInputs();
2111 SmallVector<Attribute> array;
2112 for (
size_t i = 0, e = getNumOperands(); i < e; ++i) {
2115 llvm::copy(
inputs[i].cast<ArrayAttr>(), std::back_inserter(array));
2122 for (
auto input : op.getInputs())
2126 SmallVector<Value> items;
2127 for (
auto input : op.getInputs()) {
2128 auto create = cast<ArrayCreateOp>(input.getDefiningOp());
2129 for (
auto item : create.getInputs())
2130 items.push_back(item);
2144 SmallVector<Location> locs;
2147 SmallVector<Value> items;
2148 std::optional<Slice> last;
2149 bool changed =
false;
2151 auto concatenate = [&] {
2156 items.push_back(last->op);
2165 auto origTy = hw::type_cast<ArrayType>(last->input.getType());
2166 auto arrayTy =
ArrayType::get(origTy.getElementType(), last->size);
2168 loc, arrayTy, last->input, last->index));
2173 auto append = [&](Value op, Value input, Value index,
size_t size) {
2178 if (last->input == input &&
isOffset(last->index, index, last->size)) {
2181 last->locs.push_back(op.getLoc());
2186 last.emplace(Slice{input, index, size, op, {op.getLoc()}});
2189 for (
auto item : llvm::reverse(op.getInputs())) {
2191 auto size = hw::type_cast<ArrayType>(slice.getType()).getNumElements();
2192 append(item, slice.getInput(), slice.getLowIndex(), size);
2197 if (create.getInputs().size() == 1) {
2198 if (
auto get = create.getInputs()[0].getDefiningOp<
ArrayGetOp>()) {
2206 items.push_back(item);
2213 if (items.size() == 1) {
2214 rewriter.replaceOp(op, items[0]);
2216 std::reverse(items.begin(), items.end());
2223 PatternRewriter &rewriter) {
2239 ParseResult EnumConstantOp::parse(OpAsmParser &parser, OperationState &result) {
2246 auto loc = parser.getEncodedSourceLoc(parser.getCurrentLocation());
2247 if (parser.parseKeyword(&field) || parser.parseColonType(type))
2256 result.addAttribute(
"field", fieldAttr);
2257 result.addTypes(type);
2262 void EnumConstantOp::print(OpAsmPrinter &p) {
2263 p <<
" " << getField().getField().getValue() <<
" : "
2264 << getField().getType().getValue();
2268 function_ref<
void(Value, StringRef)> setNameFn) {
2269 setNameFn(getResult(), getField().getField().str());
2272 void EnumConstantOp::build(OpBuilder &
builder, OperationState &odsState,
2273 EnumFieldAttr field) {
2274 return build(
builder, odsState, field.getType().getValue(), field);
2277 OpFoldResult EnumConstantOp::fold(FoldAdaptor adaptor) {
2278 assert(adaptor.getOperands().empty() &&
"constant has no operands");
2279 return getFieldAttr();
2282 LogicalResult EnumConstantOp::verify() {
2283 auto fieldAttr = getFieldAttr();
2284 auto fieldType = fieldAttr.getType().getValue();
2287 if (fieldType != getType())
2288 emitOpError(
"return type ")
2289 << getType() <<
" does not match attribute type " << fieldAttr;
2297 LogicalResult EnumCmpOp::verify() {
2299 auto lhsType = type_cast<EnumType>(getLhs().getType());
2300 auto rhsType = type_cast<EnumType>(getRhs().getType());
2301 if (rhsType != lhsType)
2302 emitOpError(
"types do not match");
2310 ParseResult StructCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2311 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2312 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 4> operands;
2313 Type declOrAliasType;
2315 if (parser.parseLParen() || parser.parseOperandList(operands) ||
2316 parser.parseRParen() || parser.parseOptionalAttrDict(result.attributes) ||
2317 parser.parseColonType(declOrAliasType))
2320 auto declType = type_dyn_cast<StructType>(declOrAliasType);
2322 return parser.emitError(parser.getNameLoc(),
2323 "expected !hw.struct type or alias");
2325 llvm::SmallVector<Type, 4> structInnerTypes;
2326 declType.getInnerTypes(structInnerTypes);
2327 result.addTypes(declOrAliasType);
2329 if (parser.resolveOperands(operands, structInnerTypes, inputOperandsLoc,
2335 void StructCreateOp::print(OpAsmPrinter &printer) {
2337 printer.printOperands(getInput());
2339 printer.printOptionalAttrDict((*this)->getAttrs());
2340 printer <<
" : " << getType();
2343 LogicalResult StructCreateOp::verify() {
2344 auto elements = hw::type_cast<StructType>(getType()).getElements();
2346 if (elements.size() != getInput().size())
2347 return emitOpError(
"structure field count mismatch");
2349 for (
const auto &[field, value] : llvm::zip(elements, getInput()))
2350 if (field.type != value.getType())
2351 return emitOpError(
"structure field `")
2352 << field.name <<
"` type does not match";
2357 OpFoldResult StructCreateOp::fold(FoldAdaptor adaptor) {
2359 if (!getInput().
empty())
2360 if (
auto explodeOp = getInput()[0].getDefiningOp<StructExplodeOp>();
2361 explodeOp && getInput() == explodeOp.getResults() &&
2362 getResult().getType() == explodeOp.getInput().getType())
2363 return explodeOp.getInput();
2365 auto inputs = adaptor.getInput();
2366 if (llvm::any_of(
inputs, [](Attribute attr) {
return !attr; }))
2375 ParseResult StructExplodeOp::parse(OpAsmParser &parser,
2376 OperationState &result) {
2377 OpAsmParser::UnresolvedOperand operand;
2380 if (parser.parseOperand(operand) ||
2381 parser.parseOptionalAttrDict(result.attributes) ||
2382 parser.parseColonType(declType))
2384 auto structType = type_dyn_cast<StructType>(declType);
2386 return parser.emitError(parser.getNameLoc(),
2387 "invalid kind of type specified");
2389 llvm::SmallVector<Type, 4> structInnerTypes;
2390 structType.getInnerTypes(structInnerTypes);
2391 result.addTypes(structInnerTypes);
2393 if (parser.resolveOperand(operand, declType, result.operands))
2398 void StructExplodeOp::print(OpAsmPrinter &printer) {
2400 printer.printOperand(getInput());
2401 printer.printOptionalAttrDict((*this)->getAttrs());
2402 printer <<
" : " << getInput().getType();
2405 LogicalResult StructExplodeOp::fold(FoldAdaptor adaptor,
2406 SmallVectorImpl<OpFoldResult> &results) {
2407 auto input = adaptor.getInput();
2410 llvm::copy(input.cast<ArrayAttr>(), std::back_inserter(results));
2414 LogicalResult StructExplodeOp::canonicalize(StructExplodeOp op,
2415 PatternRewriter &rewriter) {
2416 auto *inputOp = op.getInput().getDefiningOp();
2417 auto elements = type_cast<StructType>(op.getInput().getType()).getElements();
2418 auto result = failure();
2419 auto opResults = op.getResults();
2420 for (uint32_t index = 0; index < elements.size(); index++) {
2422 rewriter.replaceAllUsesWith(opResults[index], foldResult);
2430 function_ref<
void(Value, StringRef)> setNameFn) {
2431 auto structType = type_cast<StructType>(getInput().getType());
2432 for (
auto [res, field] : llvm::zip(getResults(), structType.getElements()))
2433 setNameFn(res, field.name.str());
2436 void StructExplodeOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2438 StructType inputType = input.getType().dyn_cast<StructType>();
2440 SmallVector<Type, 16> fieldTypes;
2441 for (
auto field : inputType.getElements())
2442 fieldTypes.push_back(field.type);
2443 build(odsBuilder, odsState, fieldTypes, input);
2452 template <
typename AggregateOp,
typename AggregateType>
2454 AggregateType aggType,
2456 auto index = op.getFieldIndex();
2457 if (index >= aggType.getElements().size())
2458 return op.emitOpError() <<
"field index " << index
2459 <<
" exceeds element count of aggregate type";
2463 return op.emitOpError()
2464 <<
"type " << aggType.getElements()[index].type
2465 <<
" of accessed field in aggregate at index " << index
2466 <<
" does not match expected type " <<
elementType;
2471 LogicalResult StructExtractOp::verify() {
2472 return verifyAggregateFieldIndexAndType<StructExtractOp, StructType>(
2473 *
this, getInput().getType(), getType());
2478 template <
typename AggregateType>
2480 OpAsmParser::UnresolvedOperand operand;
2481 StringAttr fieldName;
2484 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2485 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2486 parser.parseOptionalAttrDict(result.attributes) ||
2487 parser.parseColonType(declType))
2489 auto aggType = type_dyn_cast<AggregateType>(declType);
2491 return parser.emitError(parser.getNameLoc(),
2492 "invalid kind of type specified");
2494 auto fieldIndex = aggType.getFieldIndex(fieldName);
2496 parser.emitError(parser.getNameLoc(),
"field name '" +
2497 fieldName.getValue() +
2498 "' not found in aggregate type");
2504 result.addAttribute(
"fieldIndex", indexAttr);
2505 Type resultType = aggType.getElements()[*fieldIndex].type;
2506 result.addTypes(resultType);
2508 if (parser.resolveOperand(operand, declType, result.operands))
2515 template <
typename AggType>
2518 printer.printOperand(op.getInput());
2519 printer <<
"[\"" << op.getFieldName() <<
"\"]";
2520 printer.printOptionalAttrDict(op->getAttrs(), {
"fieldIndex"});
2521 printer <<
" : " << op.getInput().getType();
2524 ParseResult StructExtractOp::parse(OpAsmParser &parser,
2525 OperationState &result) {
2526 return parseExtractOp<StructType>(parser, result);
2529 void StructExtractOp::print(OpAsmPrinter &printer) {
2533 void StructExtractOp::build(OpBuilder &
builder, OperationState &odsState,
2534 Value input, StructType::FieldInfo field) {
2536 type_cast<StructType>(input.getType()).getFieldIndex(field.name);
2537 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2538 build(
builder, odsState, field.type, input, *fieldIndex);
2541 void StructExtractOp::build(OpBuilder &
builder, OperationState &odsState,
2542 Value input, StringAttr fieldName) {
2543 auto structType = type_cast<StructType>(input.getType());
2544 auto fieldIndex = structType.getFieldIndex(fieldName);
2545 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2546 auto resultType = structType.getElements()[*fieldIndex].type;
2547 build(
builder, odsState, resultType, input, *fieldIndex);
2550 OpFoldResult StructExtractOp::fold(FoldAdaptor adaptor) {
2551 if (
auto constOperand = adaptor.getInput()) {
2553 auto operandAttr = llvm::cast<ArrayAttr>(constOperand);
2554 return operandAttr.getValue()[getFieldIndex()];
2557 if (
auto foldResult =
2564 PatternRewriter &rewriter) {
2565 auto inputOp = op.getInput().getDefiningOp();
2568 if (
auto structInject = dyn_cast_or_null<StructInjectOp>(inputOp)) {
2569 if (structInject.getFieldIndex() != op.getFieldIndex()) {
2571 op, op.getType(), structInject.getInput(), op.getFieldIndexAttr());
2580 function_ref<
void(Value, StringRef)> setNameFn) {
2588 void StructInjectOp::build(OpBuilder &
builder, OperationState &odsState,
2589 Value input, StringAttr fieldName, Value newValue) {
2590 auto structType = type_cast<StructType>(input.getType());
2591 auto fieldIndex = structType.getFieldIndex(fieldName);
2592 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2593 build(
builder, odsState, input, *fieldIndex, newValue);
2596 LogicalResult StructInjectOp::verify() {
2597 return verifyAggregateFieldIndexAndType<StructInjectOp, StructType>(
2598 *
this, getInput().getType(), getNewValue().getType());
2601 ParseResult StructInjectOp::parse(OpAsmParser &parser, OperationState &result) {
2602 llvm::SMLoc inputOperandsLoc = parser.getCurrentLocation();
2603 OpAsmParser::UnresolvedOperand operand, val;
2604 StringAttr fieldName;
2607 if (parser.parseOperand(operand) || parser.parseLSquare() ||
2608 parser.parseAttribute(fieldName) || parser.parseRSquare() ||
2609 parser.parseComma() || parser.parseOperand(val) ||
2610 parser.parseOptionalAttrDict(result.attributes) ||
2611 parser.parseColonType(declType))
2613 auto structType = type_dyn_cast<StructType>(declType);
2615 return parser.emitError(inputOperandsLoc,
"invalid kind of type specified");
2617 auto fieldIndex = structType.getFieldIndex(fieldName);
2619 parser.emitError(parser.getNameLoc(),
"field name '" +
2620 fieldName.getValue() +
2621 "' not found in aggregate type");
2627 result.addAttribute(
"fieldIndex", indexAttr);
2628 result.addTypes(declType);
2630 Type resultType = structType.getElements()[*fieldIndex].type;
2631 if (parser.resolveOperands({operand, val}, {declType, resultType},
2632 inputOperandsLoc, result.operands))
2637 void StructInjectOp::print(OpAsmPrinter &printer) {
2639 printer.printOperand(getInput());
2641 printer.printOperand(getNewValue());
2642 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2643 printer <<
" : " << getInput().getType();
2646 OpFoldResult StructInjectOp::fold(FoldAdaptor adaptor) {
2647 auto input = adaptor.getInput();
2648 auto newValue = adaptor.getNewValue();
2649 if (!input || !newValue)
2651 SmallVector<Attribute> array;
2652 llvm::copy(input.cast<ArrayAttr>(), std::back_inserter(array));
2653 array[getFieldIndex()] = newValue;
2657 LogicalResult StructInjectOp::canonicalize(StructInjectOp op,
2658 PatternRewriter &rewriter) {
2660 SmallPtrSet<Operation *, 4> injects;
2661 DenseMap<StringAttr, Value> fields;
2664 StructInjectOp inject = op;
2667 if (!injects.insert(inject).second)
2670 fields.try_emplace(inject.getFieldNameAttr(), inject.getNewValue());
2671 input = inject.getInput();
2672 inject = dyn_cast_or_null<StructInjectOp>(input.getDefiningOp());
2674 assert(input &&
"missing input to inject chain");
2676 auto ty = hw::type_cast<StructType>(op.getType());
2677 auto elements = ty.getElements();
2680 if (fields.size() == elements.size()) {
2681 SmallVector<Value> createFields;
2682 for (
const auto &field : elements) {
2683 auto it = fields.find(field.name);
2684 assert(it != fields.end() &&
"missing field");
2685 createFields.push_back(it->second);
2687 rewriter.replaceOpWithNewOp<
StructCreateOp>(op, ty, createFields);
2692 if (injects.size() == fields.size())
2696 for (uint32_t fieldIndex = 0; fieldIndex < elements.size(); fieldIndex++) {
2697 auto it = fields.find(elements[fieldIndex].name);
2698 if (it == fields.end())
2700 input = rewriter.create<StructInjectOp>(op.getLoc(), ty, input, fieldIndex,
2704 rewriter.replaceOp(op, input);
2712 LogicalResult UnionCreateOp::verify() {
2713 return verifyAggregateFieldIndexAndType<UnionCreateOp, UnionType>(
2714 *
this, getType(), getInput().getType());
2717 void UnionCreateOp::build(OpBuilder &
builder, OperationState &odsState,
2718 Type unionType, StringAttr fieldName, Value input) {
2719 auto fieldIndex = type_cast<UnionType>(unionType).getFieldIndex(fieldName);
2720 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2721 build(
builder, odsState, unionType, *fieldIndex, input);
2724 ParseResult UnionCreateOp::parse(OpAsmParser &parser, OperationState &result) {
2725 Type declOrAliasType;
2726 StringAttr fieldName;
2727 OpAsmParser::UnresolvedOperand input;
2728 llvm::SMLoc fieldLoc = parser.getCurrentLocation();
2730 if (parser.parseAttribute(fieldName) || parser.parseComma() ||
2731 parser.parseOperand(input) ||
2732 parser.parseOptionalAttrDict(result.attributes) ||
2733 parser.parseColonType(declOrAliasType))
2736 auto declType = type_dyn_cast<UnionType>(declOrAliasType);
2738 return parser.emitError(parser.getNameLoc(),
2739 "expected !hw.union type or alias");
2741 auto fieldIndex = declType.getFieldIndex(fieldName);
2743 parser.emitError(fieldLoc,
"cannot find union field '")
2744 << fieldName.getValue() <<
'\'';
2750 result.addAttribute(
"fieldIndex", indexAttr);
2751 Type inputType = declType.getElements()[*fieldIndex].type;
2753 if (parser.resolveOperand(input, inputType, result.operands))
2755 result.addTypes({declOrAliasType});
2759 void UnionCreateOp::print(OpAsmPrinter &printer) {
2761 printer.printOperand(getInput());
2762 printer.printOptionalAttrDict((*this)->getAttrs(), {
"fieldIndex"});
2763 printer <<
" : " << getType();
2770 ParseResult UnionExtractOp::parse(OpAsmParser &parser, OperationState &result) {
2771 return parseExtractOp<UnionType>(parser, result);
2774 void UnionExtractOp::print(OpAsmPrinter &printer) {
2779 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
2780 DictionaryAttr attrs, mlir::OpaqueProperties properties,
2781 mlir::RegionRange regions, SmallVectorImpl<Type> &results) {
2782 auto unionElements =
2783 hw::type_cast<UnionType>((operands[0].getType())).getElements();
2784 unsigned fieldIndex =
2785 attrs.getAs<IntegerAttr>(
"fieldIndex").getValue().getZExtValue();
2786 if (fieldIndex >= unionElements.size()) {
2788 mlir::emitError(*loc,
"field index " + Twine(fieldIndex) +
2789 " exceeds element count of aggregate type");
2792 results.push_back(unionElements[fieldIndex].type);
2796 void UnionExtractOp::build(OpBuilder &odsBuilder, OperationState &odsState,
2797 Value input, StringAttr fieldName) {
2798 auto unionType = type_cast<UnionType>(input.getType());
2799 auto fieldIndex = unionType.getFieldIndex(fieldName);
2800 assert(fieldIndex.has_value() &&
"field name not found in aggregate type");
2801 auto resultType = unionType.getElements()[*fieldIndex].type;
2802 build(odsBuilder, odsState, resultType, input, *fieldIndex);
2814 OpFoldResult ArrayGetOp::fold(FoldAdaptor adaptor) {
2815 auto inputCst = adaptor.getInput().dyn_cast_or_null<ArrayAttr>();
2816 auto indexCst = adaptor.getIndex().dyn_cast_or_null<IntegerAttr>();
2821 auto indexVal = indexCst.getValue();
2822 if (indexVal.getBitWidth() < 64) {
2823 auto index = indexVal.getZExtValue();
2824 return inputCst[inputCst.size() - 1 - index];
2829 if (!inputCst.empty() && llvm::all_equal(inputCst))
2834 if (
auto bitcast = getInput().getDefiningOp<
hw::BitcastOp>()) {
2835 auto intTy = getType().dyn_cast<IntegerType>();
2838 auto bitcastInputOp = bitcast.getInput().getDefiningOp<
hw::ConstantOp>();
2839 if (!bitcastInputOp)
2843 auto bitcastInputCst = bitcastInputOp.getValue();
2846 auto startIdx = indexCst.getValue().zext(bitcastInputCst.getBitWidth()) *
2847 getType().getIntOrFloatBitWidth();
2850 intTy.getIntOrFloatBitWidth()));
2853 auto inputCreate = getInput().getDefiningOp<
ArrayCreateOp>();
2857 if (
auto uniformValue = inputCreate.getUniformElement())
2858 return uniformValue;
2860 if (!indexCst || indexCst.getValue().getBitWidth() > 64)
2863 uint64_t index = indexCst.getValue().getLimitedValue();
2864 auto createInputs = inputCreate.getInputs();
2865 if (index >= createInputs.size())
2867 return createInputs[createInputs.size() - index - 1];
2870 LogicalResult ArrayGetOp::canonicalize(
ArrayGetOp op,
2871 PatternRewriter &rewriter) {
2876 auto *inputOp = op.getInput().getDefiningOp();
2877 if (
auto inputSlice = dyn_cast_or_null<ArraySliceOp>(inputOp)) {
2879 auto offsetOp = inputSlice.getLowIndex();
2884 uint64_t offset = *offsetOpt + *idxOpt;
2886 rewriter.create<
ConstantOp>(op.getLoc(), offsetOp.getType(), offset);
2887 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, inputSlice.getInput(),
2892 if (
auto inputConcat = dyn_cast_or_null<ArrayConcatOp>(inputOp)) {
2894 uint64_t elemIndex = *idxOpt;
2895 for (
auto input : llvm::reverse(inputConcat.getInputs())) {
2896 size_t size = hw::type_cast<ArrayType>(input.getType()).getNumElements();
2897 if (elemIndex >= size) {
2902 unsigned indexWidth = size == 1 ? 1 : llvm::Log2_64_Ceil(size);
2904 op.getLoc(), rewriter.getIntegerType(indexWidth), elemIndex);
2906 rewriter.replaceOpWithNewOp<
ArrayGetOp>(op, input, newIdxOp);
2915 if (
auto innerGet = dyn_cast_or_null<hw::ArrayGetOp>(inputOp)) {
2920 SmallVector<Value> newValues;
2921 for (
auto operand : create.getOperands())
2923 op.getLoc(), operand, op.getIndex()));
2928 innerGet.getIndex());
2941 StringRef TypedeclOp::getPreferredName() {
2942 return getVerilogName().value_or(
getName());
2945 Type TypedeclOp::getAliasType() {
2946 auto parentScope = cast<hw::TypeScopeOp>(getOperation()->getParentOp());
2949 {FlatSymbolRefAttr::get(*this)}),
2957 OpFoldResult BitcastOp::fold(FoldAdaptor) {
2960 if (getOperand().getType() == getType())
2961 return getOperand();
2966 LogicalResult BitcastOp::canonicalize(
BitcastOp op, PatternRewriter &rewriter) {
2972 dyn_cast_or_null<BitcastOp>(op.getInput().getDefiningOp());
2975 auto bitcast = rewriter.createOrFold<
BitcastOp>(op.getLoc(), op.getType(),
2976 inputBitcast.getInput());
2977 rewriter.replaceOp(op, bitcast);
2981 LogicalResult BitcastOp::verify() {
2983 return this->emitOpError(
"Bitwidth of input must match result");
2991 bool HierPathOp::dropModule(StringAttr moduleToDrop) {
2992 SmallVector<Attribute, 4> newPath;
2993 bool updateMade =
false;
2994 for (
auto nameRef : getNamepath()) {
2996 if (
auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
2997 if (ref.getModule() == moduleToDrop)
3000 newPath.push_back(ref);
3002 if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
3005 newPath.push_back(nameRef);
3013 bool HierPathOp::inlineModule(StringAttr moduleToDrop) {
3014 SmallVector<Attribute, 4> newPath;
3015 bool updateMade =
false;
3016 StringRef inlinedInstanceName =
"";
3017 for (
auto nameRef : getNamepath()) {
3019 if (
auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3020 if (ref.getModule() == moduleToDrop) {
3021 inlinedInstanceName = ref.getName().getValue();
3023 }
else if (!inlinedInstanceName.empty()) {
3027 ref.getName().getValue())));
3028 inlinedInstanceName =
"";
3030 newPath.push_back(ref);
3032 if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == moduleToDrop)
3035 newPath.push_back(nameRef);
3043 bool HierPathOp::updateModule(StringAttr oldMod, StringAttr newMod) {
3044 SmallVector<Attribute, 4> newPath;
3045 bool updateMade =
false;
3046 for (
auto nameRef : getNamepath()) {
3048 if (
auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3049 if (ref.getModule() == oldMod) {
3053 newPath.push_back(ref);
3055 if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == oldMod) {
3059 newPath.push_back(nameRef);
3067 bool HierPathOp::updateModuleAndInnerRef(
3068 StringAttr oldMod, StringAttr newMod,
3069 const llvm::DenseMap<StringAttr, StringAttr> &innerSymRenameMap) {
3071 if (oldMod == newMod)
3074 auto namepathNew = getNamepath().getValue().vec();
3075 bool updateMade =
false;
3077 for (
auto &element : namepathNew) {
3078 if (
auto innerRef = element.dyn_cast<hw::InnerRefAttr>()) {
3079 if (innerRef.getModule() != oldMod)
3081 auto symName = innerRef.getName();
3084 auto to = innerSymRenameMap.find(symName);
3085 if (to != innerSymRenameMap.end())
3086 symName = to->second;
3091 if (element != fromRef)
3103 bool HierPathOp::truncateAtModule(StringAttr atMod,
bool includeMod) {
3104 SmallVector<Attribute, 4> newPath;
3105 bool updateMade =
false;
3106 for (
auto nameRef : getNamepath()) {
3108 if (
auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3109 if (ref.getModule() == atMod) {
3112 newPath.push_back(ref);
3114 newPath.push_back(ref);
3116 if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == atMod && !includeMod)
3119 newPath.push_back(nameRef);
3130 StringAttr HierPathOp::modPart(
unsigned i) {
3131 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3132 .Case<FlatSymbolRefAttr>([](
auto a) {
return a.getAttr(); })
3133 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getModule(); });
3137 StringAttr HierPathOp::root() {
3143 bool HierPathOp::hasModule(StringAttr modName) {
3144 for (
auto nameRef : getNamepath()) {
3146 if (
auto ref = nameRef.dyn_cast<hw::InnerRefAttr>()) {
3147 if (ref.getModule() == modName)
3150 if (nameRef.cast<FlatSymbolRefAttr>().getAttr() == modName)
3158 bool HierPathOp::hasInnerSym(StringAttr modName, StringAttr symName)
const {
3159 for (
auto nameRef :
const_cast<HierPathOp *
>(
this)->getNamepath())
3160 if (
auto ref = nameRef.dyn_cast<hw::InnerRefAttr>())
3161 if (ref.getName() == symName && ref.getModule() == modName)
3169 StringAttr HierPathOp::refPart(
unsigned i) {
3170 return TypeSwitch<Attribute, StringAttr>(getNamepath()[i])
3171 .Case<FlatSymbolRefAttr>([](
auto a) {
return StringAttr({}); })
3172 .Case<hw::InnerRefAttr>([](
auto a) {
return a.getName(); });
3177 StringAttr HierPathOp::ref() {
3179 return refPart(getNamepath().size() - 1);
3183 StringAttr HierPathOp::leafMod() {
3185 return modPart(getNamepath().size() - 1);
/// Returns true when ref() is null, i.e. the leaf namepath entry carries no
/// inner-symbol name (refPart yields a null StringAttr for a FlatSymbolRefAttr
/// leaf), so the path designates a module rather than a component inside it.
3190 bool HierPathOp::isModule() {
return !ref(); }
/// Returns true when ref() is non-null, i.e. the leaf namepath entry is an
/// InnerRefAttr naming an inner symbol. This is the logical negation of
/// isModule().
3194 bool HierPathOp::isComponent() {
return (
bool)ref(); }
3209 LogicalResult HierPathOp::verifyInnerRefs(hw::InnerRefNamespace &ns) {
3210 ArrayAttr expectedModuleNames = {};
3211 auto checkExpectedModule = [&](Attribute name) -> LogicalResult {
3212 if (!expectedModuleNames)
3214 if (llvm::any_of(expectedModuleNames,
3215 [name](Attribute attr) {
return attr == name; }))
3217 auto diag = emitOpError() <<
"instance path is incorrect. Expected ";
3218 size_t n = expectedModuleNames.size();
3222 for (
size_t i = 0; i < n; ++i) {
3224 diag << ((i + 1 == n) ?
" or " :
", ");
3225 diag << expectedModuleNames[i].cast<StringAttr>();
3227 diag <<
". Instead found: " << name;
3231 if (!getNamepath() || getNamepath().empty())
3232 return emitOpError() <<
"the instance path cannot be empty";
3233 for (
unsigned i = 0, s = getNamepath().size() - 1; i < s; ++i) {
3234 hw::InnerRefAttr innerRef = getNamepath()[i].dyn_cast<hw::InnerRefAttr>();
3236 return emitOpError()
3237 <<
"the instance path can only contain inner sym reference"
3238 <<
", only the leaf can refer to a module symbol";
3240 if (failed(checkExpectedModule(innerRef.getModule())))
3243 auto instOp = ns.lookupOp<igraph::InstanceOpInterface>(innerRef);
3245 return emitOpError() <<
" module: " << innerRef.getModule()
3246 <<
" does not contain any instance with symbol: "
3247 << innerRef.getName();
3248 expectedModuleNames = instOp.getReferencedModuleNamesAttr();
3252 auto leafRef = getNamepath()[getNamepath().size() - 1];
3253 if (
auto innerRef = leafRef.dyn_cast<hw::InnerRefAttr>()) {
3254 if (!ns.lookup(innerRef)) {
3255 return emitOpError() <<
" operation with symbol: " << innerRef
3256 <<
" was not found ";
3258 if (failed(checkExpectedModule(innerRef.getModule())))
3260 }
else if (failed(checkExpectedModule(
3261 leafRef.cast<FlatSymbolRefAttr>().getAttr()))) {
3267 void HierPathOp::print(OpAsmPrinter &p) {
3271 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
3272 if (
auto visibility =
3273 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
3274 p << visibility.getValue() <<
' ';
3276 p.printSymbolName(getSymName());
3278 llvm::interleaveComma(getNamepath().getValue(), p, [&](Attribute attr) {
3279 if (
auto ref = attr.dyn_cast<hw::InnerRefAttr>()) {
3280 p.printSymbolName(ref.getModule().getValue());
3282 p.printSymbolName(ref.getName().getValue());
3284 p.printSymbolName(attr.cast<FlatSymbolRefAttr>().getValue());
3288 p.printOptionalAttrDict(
3289 (*this)->getAttrs(),
3290 {SymbolTable::getSymbolAttrName(),
"namepath", visibilityAttrName});
3293 ParseResult HierPathOp::parse(OpAsmParser &parser, OperationState &result) {
3295 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
3299 if (parser.parseSymbolName(symName, SymbolTable::getSymbolAttrName(),
3304 SmallVector<Attribute> namepath;
3305 if (parser.parseCommaSeparatedList(
3306 OpAsmParser::Delimiter::Square, [&]() -> ParseResult {
3307 auto loc = parser.getCurrentLocation();
3309 if (parser.parseAttribute(ref))
3313 auto pathLength = ref.getNestedReferences().size();
3314 if (pathLength == 0)
3316 FlatSymbolRefAttr::get(ref.getRootReference()));
3317 else if (pathLength == 1)
3318 namepath.push_back(hw::InnerRefAttr::get(ref.getRootReference(),
3319 ref.getLeafReference()));
3321 return parser.emitError(loc,
3322 "only one nested reference is allowed");
3326 result.addAttribute(
"namepath",
3329 if (parser.parseOptionalAttrDict(result.attributes))
3339 void TriggeredOp::build(OpBuilder &
builder, OperationState &odsState,
3340 EventControlAttr event, Value trigger,
3342 odsState.addOperands(trigger);
3343 odsState.addOperands(
inputs);
3344 odsState.addAttribute(getEventAttrName(odsState.name), event);
3345 auto *r = odsState.addRegion();
3349 llvm::SmallVector<Location> argLocs;
3350 llvm::transform(
inputs, std::back_inserter(argLocs),
3351 [&](Value v) {
return v.getLoc(); });
3352 b->addArguments(
inputs.getTypes(), argLocs);
3360 #define GET_OP_CLASSES
3361 #include "circt/Dialect/HW/HW.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyModuleCommon(HWModuleLike module)
static void printParamValue(OpAsmPrinter &p, Operation *, Attribute value, Type resultType)
static void printModuleOp(OpAsmPrinter &p, ModuleTy mod)
static bool flattenConcatOp(ArrayConcatOp op, PatternRewriter &rewriter)
static LogicalResult foldCreateToSlice(ArrayCreateOp op, PatternRewriter &rewriter)
static SmallVector< Location > getAllPortLocs(ModTy module)
static ArrayAttr arrayOrEmpty(mlir::MLIRContext *context, ArrayRef< Attribute > attrs)
FunctionType getHWModuleOpType(Operation *op)
static void buildModule(OpBuilder &builder, OperationState &result, StringAttr name, const ModulePortInfo &ports, ArrayAttr parameters, ArrayRef< NamedAttribute > attributes, StringAttr comment)
static void printExtractOp(OpAsmPrinter &printer, AggType op)
Use the same printer for both struct_extract and union_extract since the syntax is identical.
static void modifyModulePorts(Operation *op, ArrayRef< std::pair< unsigned, PortInfo >> insertInputs, ArrayRef< std::pair< unsigned, PortInfo >> insertOutputs, ArrayRef< unsigned > removeInputs, ArrayRef< unsigned > removeOutputs, Block *body=nullptr)
Insert and remove ports of a module.
static void printArrayConcatTypes(OpAsmPrinter &p, Operation *, TypeRange inputTypes, Type resultType)
static void modifyModuleArgs(MLIRContext *context, ArrayRef< std::pair< unsigned, PortInfo >> insertArgs, ArrayRef< unsigned > removeArgs, ArrayRef< Attribute > oldArgNames, ArrayRef< Type > oldArgTypes, ArrayRef< Attribute > oldArgAttrs, ArrayRef< Location > oldArgLocs, SmallVector< Attribute > &newArgNames, SmallVector< Type > &newArgTypes, SmallVector< Attribute > &newArgAttrs, SmallVector< Location > &newArgLocs, Block *body=nullptr)
Internal implementation of argument/result insertion and removal on modules.
static ParseResult parseSliceTypes(OpAsmParser &p, Type &srcType, Type &idxType)
static Value foldStructExtract(Operation *inputOp, uint32_t fieldIndex)
static bool hasAttribute(StringRef name, ArrayRef< NamedAttribute > attrs)
static bool mergeConcatSlices(ArrayConcatOp op, PatternRewriter &rewriter)
static ParseResult parseExtractOp(OpAsmParser &parser, OperationState &result)
Use the same parser for both struct_extract and union_extract since the syntax is identical.
static void setAllPortNames(ArrayRef< Attribute > names, ModTy module)
static void getAsmBlockArgumentNamesImpl(mlir::Region &region, OpAsmSetValueNameFn setNameFn)
Get a special name to use when printing the entry block arguments of the region contained by an opera...
static void setHWModuleType(ModTy &mod, ModuleType type)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static ParseResult parseParamValue(OpAsmParser &p, Attribute &value, Type &resultType)
static LogicalResult checkAttributes(Operation *op, Attribute attr, Type type)
static std::optional< uint64_t > getUIntFromValue(Value value)
static ParseResult parseHWModuleOp(OpAsmParser &parser, OperationState &result)
static LogicalResult verifyAggregateFieldIndexAndType(AggregateOp &op, AggregateType aggType, Type elementType)
Ensure an aggregate op's field index is within the bounds of the aggregate type and the accessed fiel...
static PortInfo getPort(ModuleTy &mod, size_t idx)
static void printSliceTypes(OpAsmPrinter &p, Operation *, Type srcType, Type idxType)
static bool hasAdditionalAttributes(Op op, ArrayRef< StringRef > ignoredAttrs={})
Check whether an operation has any additional attributes set beyond its standard list of attributes r...
static ParseResult parseArrayConcatTypes(OpAsmParser &p, SmallVectorImpl< Type > &inputTypes, Type &resultType)
static bool getFieldName(const FieldRef &fieldRef, SmallString< 32 > &string)
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
llvm::SmallVector< StringAttr > outputs
static StringAttr append(StringAttr base, const Twine &suffix)
Return an attribute with the specified suffix appended.
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
void setOutput(unsigned i, Value v)
Value getInput(unsigned i)
llvm::SmallVector< Value > outputOperands
llvm::SmallVector< Value > inputArgs
llvm::StringMap< unsigned > outputIdx
llvm::StringMap< unsigned > inputIdx
HWModulePortAccessor(Location loc, const ModulePortInfo &info, Region &bodyRegion)
static StringRef getInnerSymbolAttrName()
Return the name of the attribute used for inner symbol names.
This helps visit TypeOp nodes.
ResultType dispatchTypeOpVisitor(Operation *op, ExtraArgs... args)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
uint64_t getWidth(Type t)
LogicalResult inferReturnTypes(MLIRContext *context, std::optional< Location > loc, ValueRange operands, DictionaryAttr attrs, mlir::OpaqueProperties properties, mlir::RegionRange regions, SmallVectorImpl< Type > &results, llvm::function_ref< FIRRTLType(ValueRange, ArrayRef< NamedAttribute >, std::optional< Location >)> callback)
size_t getNumPorts(Operation *op)
Return the number of ports in a module-like thing (modules, memories, etc.)
ModuleType fnToMod(Operation *op, ArrayRef< Attribute > inputNames, ArrayRef< Attribute > outputNames)
LogicalResult verifyParameterStructure(ArrayAttr parameters, ArrayAttr moduleParameters, const EmitErrorFn &emitError)
Check that all the parameter values specified to the instance are structurally valid.
std::function< void(std::function< bool(InFlightDiagnostic &)>)> EmitErrorFn
Whenever the nested function returns true, a note referring to the referenced module is attached to t...
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New-style parsing of a module signature.
void printModuleSignatureNew(OpAsmPrinter &p, HWModuleLike op)
bool isOffset(Value base, Value index, uint64_t offset)
llvm::function_ref< void(OpBuilder &, HWModulePortAccessor &)> HWModuleBuilder
bool isCombinational(Operation *op)
Return true if the specified operation is a combinational logic op.
ModulePort::Direction flip(ModulePort::Direction direction)
Flip a port direction.
FunctionType getModuleType(Operation *module)
Return the signature for the specified module as a function type.
LogicalResult checkParameterInContext(Attribute value, Operation *module, Operation *usingOp, bool disallowParamRefs=false)
Check parameter specified by value to see if it is valid within the scope of the specified module mod...
bool isValidParameterExpression(Attribute attr, Operation *module)
Return true if the specified attribute tree is made up of nodes that are valid in a parameter express...
bool isValidIndexBitWidth(Value index, Value array)
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
bool isAnyModuleOrInstance(Operation *module)
TODO: Move all these functions to a hw::ModuleLike interface.
StringAttr getVerilogModuleNameAttr(Operation *module)
Returns the verilog module name attribute or symbol name of any module-like operations.
mlir::Type getCanonicalType(mlir::Type type)
circt::hw::InOutType InOutType
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseInputPortList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inputs, SmallVectorImpl< Type > &inputTypes, ArrayAttr &inputNames)
Parse a list of instance input ports.
void printOutputPortList(OpAsmPrinter &p, Operation *op, TypeRange resultTypes, ArrayAttr resultNames)
Print a list of instance output ports.
ParseResult parseOptionalParameterList(OpAsmParser &parser, ArrayAttr &parameters)
Parse a parameter list if present.
void printOptionalParameterList(OpAsmPrinter &p, Operation *op, ArrayAttr parameters)
Print a parameter list for a module or instance.
StringRef chooseName(StringRef a, StringRef b)
Choose a good name for an item from two options.
void printInputPortList(OpAsmPrinter &p, Operation *op, OperandRange inputs, TypeRange inputTypes, ArrayAttr inputNames)
Print a list of instance input ports.
ParseResult parseOutputPortList(OpAsmParser &parser, SmallVectorImpl< Type > &resultTypes, ArrayAttr &resultNames)
Parse a list of instance output ports.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
This holds a decoded list of input/inout and output ports for a module or instance.
size_t sizeOutputs() const
PortInfo & at(size_t idx)
size_t sizeInputs() const
PortDirectionRange getInputs()
PortDirectionRange getOutputs()
This holds the name, type, direction of a module's ports.