16#include "mlir/IR/Builders.h"
17#include "mlir/IR/ImplicitLocOpBuilder.h"
27 auto *context = parser.getContext();
28 auto loc = parser.getCurrentLocation();
30 if (parser.parseString(&rawPath))
33 return parser.emitError(loc,
"invalid base path");
42 p << elt.module.getValue() <<
'/' << elt.
instance.getValue();
49 StringAttr &module, StringAttr &ref,
52 auto *context = parser.getContext();
53 auto loc = parser.getCurrentLocation();
55 if (parser.parseString(&rawPath))
57 if (
parsePath(context, rawPath, path, module, ref, field))
58 return parser.emitError(loc,
"invalid path");
63 StringAttr module, StringAttr ref,
66 for (
const auto &elt : path)
67 p << elt.module.getValue() <<
'/' << elt.instance.getValue() <<
':';
68 if (!module.getValue().empty())
69 p << module.getValue();
70 if (!ref.getValue().empty())
71 p <<
'>' << ref.getValue();
72 if (!field.getValue().empty())
73 p << field.getValue();
77static ParseResult
parseFieldLocs(OpAsmParser &parser, ArrayAttr &fieldLocs) {
78 if (parser.parseOptionalKeyword(
"field_locs"))
80 if (parser.parseLParen() || parser.parseAttribute(fieldLocs) ||
81 parser.parseRParen()) {
88 ArrayAttr fieldLocs) {
89 mlir::OpPrintingFlags flags;
90 if (!flags.shouldPrintDebugInfo() || !fieldLocs)
92 printer <<
"field_locs(";
93 printer.printAttribute(fieldLocs);
101 SmallVectorImpl<Attribute> &fieldNames,
102 SmallVectorImpl<Type> &fieldTypes) {
104 llvm::StringMap<SMLoc> nameLocMap;
105 auto parseElt = [&]() -> ParseResult {
107 std::string fieldName;
108 if (parser.parseKeywordOrString(&fieldName))
110 SMLoc currLoc = parser.getCurrentLocation();
111 if (nameLocMap.count(fieldName)) {
112 parser.emitError(currLoc,
"field \"")
113 << fieldName <<
"\" is defined twice";
114 parser.emitError(nameLocMap[fieldName]) <<
"previous definition is here";
117 nameLocMap[fieldName] = currLoc;
118 fieldNames.push_back(StringAttr::get(parser.getContext(), fieldName));
121 fieldTypes.emplace_back();
122 if (parser.parseColonType(fieldTypes.back()))
128 return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
135 if (parser.parseSymbolName(symName, mlir::SymbolTable::getSymbolAttrName(),
140 SmallVector<OpAsmParser::Argument> args;
141 if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
145 SmallVector<Type> fieldTypes;
146 SmallVector<Attribute> fieldNames;
147 if (succeeded(parser.parseOptionalArrow()))
151 SmallVector<NamedAttribute> fieldTypesMap;
152 if (!fieldNames.empty()) {
153 for (
auto [name, type] : zip(fieldNames, fieldTypes))
154 fieldTypesMap.push_back(
155 NamedAttribute(cast<StringAttr>(name), TypeAttr::get(type)));
157 auto *ctx = parser.getContext();
158 state.addAttribute(
"fieldNames", mlir::ArrayAttr::get(ctx, fieldNames));
159 state.addAttribute(
"fieldTypes",
160 mlir::DictionaryAttr::get(ctx, fieldTypesMap));
163 if (failed(parser.parseOptionalAttrDictWithKeyword(state.attributes)))
167 Region *region = state.addRegion();
168 if (parser.parseRegion(*region, args))
173 region->emplaceBlock();
176 auto argNames = llvm::map_range(args, [&](OpAsmParser::Argument arg) {
177 return StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front());
181 ArrayAttr::get(parser.getContext(), SmallVector<Attribute>(argNames)));
189 printer << classLike.getSymName();
192 auto argNames = SmallVector<StringRef>(
193 classLike.getFormalParamNames().getAsValueRange<StringAttr>());
194 ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
198 for (
size_t i = 0, e = args.size(); i < e; ++i) {
199 printer <<
'%' << argNames[i] <<
": " << args[i].getType();
205 ArrayRef<Attribute> fieldNames =
206 cast<ArrayAttr>(classLike->getAttr(
"fieldNames")).getValue();
208 if (!fieldNames.empty()) {
210 for (
size_t i = 0, e = fieldNames.size(); i < e; ++i) {
213 StringAttr name = cast<StringAttr>(fieldNames[i]);
214 printer.printKeywordOrString(name.getValue());
216 Type type = classLike.getFieldType(name).value();
217 printer.printType(type);
223 SmallVector<StringRef> elidedAttrs{classLike.getSymNameAttrName(),
224 classLike.getFormalParamNamesAttrName(),
225 "fieldTypes",
"fieldNames"};
226 printer.printOptionalAttrDictWithKeyword(classLike.getOperation()->getAttrs(),
230 printer.printRegion(classLike.getBody(),
false,
236 if (classLike.getFormalParamNames().size() !=
237 classLike.getBodyBlock()->getArguments().size()) {
238 auto error = classLike.emitOpError(
239 "formal parameter name list doesn't match formal parameter value list");
240 error.attachNote(classLike.getLoc())
241 <<
"formal parameter names: " << classLike.getFormalParamNames();
242 error.attachNote(classLike.getLoc())
243 <<
"formal parameter values: "
244 << classLike.getBodyBlock()->getArguments();
254 auto argNames = SmallVector<StringRef>(
255 classLike.getFormalParamNames().getAsValueRange<StringAttr>());
256 ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
259 for (
size_t i = 0, e = args.size(); i < e; ++i)
260 setNameFn(args[i], argNames[i]);
264 return NamedAttribute(name, TypeAttr::get(type));
269 return NamedAttribute(StringAttr(name),
270 mlir::IntegerAttr::get(mlir::IndexType::get(ctx), i));
275 DictionaryAttr fieldTypes = mlir::cast<DictionaryAttr>(
276 classLike.getOperation()->getAttr(
"fieldTypes"));
277 Attribute type = fieldTypes.get(name);
280 return cast<TypeAttr>(type).getValue();
284 AttrTypeReplacer &replacer) {
285 classLike->setAttr(
"fieldTypes", cast<DictionaryAttr>(replacer.replace(
286 classLike.getFieldTypes())));
293ParseResult circt::om::ClassOp::parse(OpAsmParser &parser,
294 OperationState &state) {
298circt::om::ClassOp circt::om::ClassOp::buildSimpleClassOp(
299 OpBuilder &odsBuilder, Location loc, Twine name,
300 ArrayRef<StringRef> formalParamNames, ArrayRef<StringRef> fieldNames,
301 ArrayRef<Type> fieldTypes) {
302 circt::om::ClassOp classOp = circt::om::ClassOp::create(
303 odsBuilder, loc, odsBuilder.getStringAttr(name),
304 odsBuilder.getStrArrayAttr(formalParamNames),
305 odsBuilder.getStrArrayAttr(fieldNames),
306 odsBuilder.getDictionaryAttr(llvm::map_to_vector(
307 llvm::zip(fieldNames, fieldTypes), [&](
auto field) -> NamedAttribute {
308 return NamedAttribute(odsBuilder.getStringAttr(std::get<0>(field)),
309 TypeAttr::get(std::get<1>(field)));
311 Block *body = &classOp.getRegion().emplaceBlock();
312 auto prevLoc = odsBuilder.saveInsertionPoint();
313 odsBuilder.setInsertionPointToEnd(body);
315 mlir::SmallVector<Attribute> locAttrs(fieldNames.size(), LocationAttr(loc));
317 ClassFieldsOp::create(odsBuilder, loc,
318 llvm::map_to_vector(fieldTypes,
319 [&](Type type) -> Value {
320 return body->addArgument(type,
323 odsBuilder.getArrayAttr(locAttrs));
325 odsBuilder.restoreInsertionPoint(prevLoc);
330void circt::om::ClassOp::print(OpAsmPrinter &printer) {
334LogicalResult circt::om::ClassOp::verify() {
return verifyClassLike(*
this); }
336LogicalResult circt::om::ClassOp::verifyRegions() {
337 auto fieldsOp = cast<ClassFieldsOp>(this->
getBodyBlock()->getTerminator());
340 if (fieldsOp.getNumOperands() != this->getFieldNames().size()) {
341 auto diag = this->emitOpError()
342 <<
"returns '" << this->getFieldNames().size()
343 <<
"' fields, but its terminator returned '"
344 << fieldsOp.getNumOperands() <<
"' fields";
345 return diag.attachNote(fieldsOp.getLoc()) <<
"see terminator:";
349 auto types = this->getFieldTypes();
350 for (
auto [fieldName, terminatorOperandType] :
351 llvm::zip(this->getFieldNames(), fieldsOp.getOperandTypes())) {
353 if (terminatorOperandType ==
354 cast<TypeAttr>(types.get(cast<StringAttr>(fieldName))).getValue())
357 auto diag = this->emitOpError()
358 <<
"returns different field types than its terminator";
359 return diag.attachNote(fieldsOp.getLoc()) <<
"see terminator:";
365void circt::om::ClassOp::getAsmBlockArgumentNames(
370std::optional<mlir::Type>
371circt::om::ClassOp::getFieldType(mlir::StringAttr field) {
375void circt::om::ClassOp::replaceFieldTypes(AttrTypeReplacer replacer) {
379void circt::om::ClassOp::updateFields(
380 mlir::ArrayRef<mlir::Location> newLocations,
381 mlir::ArrayRef<mlir::Value> newValues,
382 mlir::ArrayRef<mlir::Attribute> newNames) {
384 auto fieldsOp = getFieldsOp();
385 assert(fieldsOp &&
"The fields op should exist");
387 SmallVector<Attribute> names(getFieldNamesAttr().getAsRange<StringAttr>());
389 SmallVector<NamedAttribute> fieldTypes(getFieldTypesAttr().getValue());
391 SmallVector<Value> fieldVals(fieldsOp.getFields());
393 Location fieldOpLoc = fieldsOp->getLoc();
396 SmallVector<Location> locations;
397 if (
auto fl = dyn_cast<FusedLoc>(fieldOpLoc)) {
398 auto metadataArr = dyn_cast<ArrayAttr>(fl.getMetadata());
399 assert(metadataArr &&
"Expected the metadata for the fused location");
400 auto r = metadataArr.getAsRange<LocationAttr>();
401 locations.append(r.begin(), r.end());
404 locations.append(names.size(), fieldOpLoc);
408 names.append(newNames.begin(), newNames.end());
409 locations.append(newLocations.begin(), newLocations.end());
410 fieldVals.append(newValues.begin(), newValues.end());
413 for (
auto [v, n] :
llvm::zip(newValues, newNames))
414 fieldTypes.emplace_back(
415 NamedAttribute(
llvm::cast<StringAttr>(n), TypeAttr::
get(v.getType())));
418 SmallVector<Attribute> locationsAttr;
419 llvm::for_each(locations, [&](Location &l) {
420 locationsAttr.push_back(cast<Attribute>(l));
423 ImplicitLocOpBuilder builder(
getLoc(), *
this);
425 setFieldNamesAttr(builder.getArrayAttr(names));
427 setFieldTypesAttr(builder.getDictionaryAttr(fieldTypes));
428 fieldsOp.getFieldsMutable().assign(fieldVals);
430 fieldsOp->setLoc(builder.getFusedLoc(
431 locations, ArrayAttr::get(getContext(), locationsAttr)));
434void circt::om::ClassOp::addNewFieldsOp(mlir::OpBuilder &builder,
435 mlir::ArrayRef<Location> locs,
436 mlir::ArrayRef<Value> values) {
439 assert(locs.size() == values.size() &&
"Expected a location per value");
440 mlir::SmallVector<Attribute> locAttrs;
441 for (
auto loc : locs) {
442 locAttrs.push_back(cast<Attribute>(LocationAttr(loc)));
446 ClassFieldsOp::create(builder, builder.getFusedLoc(locs), values,
447 builder.getArrayAttr(locAttrs));
450mlir::Location circt::om::ClassOp::getFieldLocByIndex(
size_t i) {
451 auto fieldsOp = this->getFieldsOp();
452 auto fieldLocs = fieldsOp.getFieldLocs();
453 if (!fieldLocs.has_value())
454 return fieldsOp.getLoc();
455 assert(i < fieldLocs.value().size() &&
456 "field index too large for location array");
457 return cast<LocationAttr>(fieldLocs.value()[i]);
464ParseResult circt::om::ClassExternOp::parse(OpAsmParser &parser,
465 OperationState &state) {
469void circt::om::ClassExternOp::print(OpAsmPrinter &printer) {
473LogicalResult circt::om::ClassExternOp::verify() {
479 return this->emitOpError(
"external class body should be empty");
485void circt::om::ClassExternOp::getAsmBlockArgumentNames(
490std::optional<mlir::Type>
491circt::om::ClassExternOp::getFieldType(mlir::StringAttr field) {
495void circt::om::ClassExternOp::replaceFieldTypes(AttrTypeReplacer replacer) {
503LogicalResult circt::om::ClassFieldsOp::verify() {
504 auto fieldLocs = this->getFieldLocs();
505 if (fieldLocs.has_value()) {
506 auto fieldLocsVal = fieldLocs.value();
507 if (fieldLocsVal.size() != this->getFields().size()) {
508 auto error = this->emitOpError(
"size of field_locs (")
509 << fieldLocsVal.size()
510 <<
") does not match number of fields ("
511 << this->getFields().size() <<
")";
521void circt::om::ObjectOp::build(::mlir::OpBuilder &odsBuilder,
522 ::mlir::OperationState &odsState,
524 ::mlir::ValueRange actualParams) {
525 return build(odsBuilder, odsState,
526 om::ClassType::get(odsBuilder.getContext(),
527 mlir::FlatSymbolRefAttr::get(classOp)),
528 classOp.getNameAttr(), actualParams);
532circt::om::ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
534 StringAttr resultClassName = getResult().getType().getClassName().getAttr();
535 StringAttr className = getClassNameAttr();
536 if (resultClassName != className)
537 return emitOpError(
"result type (")
538 << resultClassName <<
") does not match referred to class ("
542 auto classDef = dyn_cast_or_null<ClassLike>(
543 symbolTable.lookupNearestSymbolFrom(*
this, className));
545 return emitOpError(
"refers to non-existant class (") << className <<
')';
547 auto actualTypes = getActualParams().getTypes();
548 auto formalTypes = classDef.getBodyBlock()->getArgumentTypes();
551 if (actualTypes.size() != formalTypes.size()) {
552 auto error = emitOpError(
553 "actual parameter list doesn't match formal parameter list");
554 error.attachNote(classDef.getLoc())
555 <<
"formal parameters: " << classDef.getBodyBlock()->getArguments();
556 error.attachNote(
getLoc()) <<
"actual parameters: " << getActualParams();
561 for (
size_t i = 0, e = actualTypes.size(); i < e; ++i) {
562 if (actualTypes[i] != formalTypes[i]) {
563 return emitOpError(
"actual parameter type (")
564 << actualTypes[i] <<
") doesn't match formal parameter type ("
565 << formalTypes[i] <<
')';
576void circt::om::ConstantOp::build(::mlir::OpBuilder &odsBuilder,
577 ::mlir::OperationState &odsState,
578 ::mlir::TypedAttr constVal) {
579 return build(odsBuilder, odsState, constVal.getType(), constVal);
582OpFoldResult circt::om::ConstantOp::fold(FoldAdaptor adaptor) {
583 assert(adaptor.getOperands().empty() &&
"constant has no operands");
584 return getValueAttr();
591void circt::om::ListCreateOp::print(OpAsmPrinter &p) {
593 p.printOperands(getInputs());
594 p.printOptionalAttrDict((*this)->getAttrs());
595 p <<
" : " << getType().getElementType();
598ParseResult circt::om::ListCreateOp::parse(OpAsmParser &parser,
599 OperationState &result) {
600 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
603 if (parser.parseOperandList(operands) ||
604 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
605 parser.parseType(elemType))
607 result.addTypes({circt::om::ListType::get(elemType)});
609 for (
auto operand : operands)
610 if (parser.resolveOperand(operand, elemType, result.operands))
620BasePathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
621 auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
622 *
this, getTargetAttr());
624 return emitOpError(
"invalid symbol reference");
633PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
634 auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
635 *
this, getTargetAttr());
637 return emitOpError(
"invalid symbol reference");
645FailureOr<llvm::APSInt>
646IntegerAddOp::evaluateIntegerOperation(
const llvm::APSInt &lhs,
647 const llvm::APSInt &rhs) {
648 return success(lhs + rhs);
655FailureOr<llvm::APSInt>
656IntegerMulOp::evaluateIntegerOperation(
const llvm::APSInt &lhs,
657 const llvm::APSInt &rhs) {
658 return success(lhs * rhs);
665FailureOr<llvm::APSInt>
666IntegerShrOp::evaluateIntegerOperation(
const llvm::APSInt &lhs,
667 const llvm::APSInt &rhs) {
669 if (!rhs.isNonNegative())
670 return emitOpError(
"shift amount must be non-negative");
672 if (!rhs.isRepresentableByInt64())
673 return emitOpError(
"shift amount must be representable in 64 bits");
674 return success(lhs >> rhs.getExtValue());
681FailureOr<llvm::APSInt>
682IntegerShlOp::evaluateIntegerOperation(
const llvm::APSInt &lhs,
683 const llvm::APSInt &rhs) {
685 if (!rhs.isNonNegative())
686 return emitOpError(
"shift amount must be non-negative");
688 if (!rhs.isRepresentableByInt64())
689 return emitOpError(
"shift amount must be representable in 64 bits");
690 return success(lhs << rhs.getExtValue());
697#define GET_OP_CLASSES
698#include "circt/Dialect/OM/OM.cpp.inc"
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
static ParseResult parseClassLike(OpAsmParser &parser, OperationState &state)
LogicalResult verifyClassLike(ClassLike classLike)
std::optional< Type > getClassLikeFieldType(ClassLike classLike, StringAttr name)
void getClassLikeAsmBlockArgumentNames(ClassLike classLike, Region ®ion, OpAsmSetValueNameFn setNameFn)
static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path)
static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path)
static void printFieldLocs(OpAsmPrinter &printer, Operation *op, ArrayAttr fieldLocs)
static ParseResult parseFieldLocs(OpAsmParser &parser, ArrayAttr &fieldLocs)
static ParseResult parseClassFieldsList(OpAsmParser &parser, SmallVectorImpl< Attribute > &fieldNames, SmallVectorImpl< Type > &fieldTypes)
static void printClassLike(ClassLike classLike, OpAsmPrinter &printer)
void replaceClassLikeFieldTypes(ClassLike classLike, AttrTypeReplacer &replacer)
NamedAttribute makeFieldType(StringAttr name, Type type)
NamedAttribute makeFieldIdx(MLIRContext *ctx, mlir::StringAttr name, unsigned i)
static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path, StringAttr module, StringAttr ref, StringAttr field)
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
void error(Twine message)
ParseResult parsePath(MLIRContext *context, StringRef spelling, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Parse a target string into a path.
ParseResult parseBasePath(MLIRContext *context, StringRef spelling, PathAttr &path)
Parse a target string of the form "Foo/bar:Bar/baz" into a base path.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A module name, and the name of an instance inside that module.
mlir::StringAttr mlir::StringAttr instance