CIRCT 23.0.0git
Loading...
Searching...
No Matches
OMOps.cpp
Go to the documentation of this file.
1//===- OMOps.cpp - Object Model operation definitions ---------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file contains the Object Model operation definitions.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/ImplicitLocOpBuilder.h"
18#include "llvm/ADT/STLExtras.h"
19
20using namespace mlir;
21using namespace circt::om;
22
23//===----------------------------------------------------------------------===//
24// Custom Printers and Parsers
25//===----------------------------------------------------------------------===//
26
27static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path) {
28 auto *context = parser.getContext();
29 auto loc = parser.getCurrentLocation();
30 std::string rawPath;
31 if (parser.parseString(&rawPath))
32 return failure();
33 if (parseBasePath(context, rawPath, path))
34 return parser.emitError(loc, "invalid base path");
35 return success();
36}
37
38static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path) {
39 p << '\"';
40 llvm::interleave(
41 path, p,
42 [&](const PathElement &elt) {
43 p << elt.module.getValue() << '/' << elt.instance.getValue();
44 },
45 ":");
46 p << '\"';
47}
48
49static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path,
50 StringAttr &module, StringAttr &ref,
51 StringAttr &field) {
52
53 auto *context = parser.getContext();
54 auto loc = parser.getCurrentLocation();
55 std::string rawPath;
56 if (parser.parseString(&rawPath))
57 return failure();
58 if (parsePath(context, rawPath, path, module, ref, field))
59 return parser.emitError(loc, "invalid path");
60 return success();
61}
62
63static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path,
64 StringAttr module, StringAttr ref,
65 StringAttr field) {
66 p << '\"';
67 for (const auto &elt : path)
68 p << elt.module.getValue() << '/' << elt.instance.getValue() << ':';
69 if (!module.getValue().empty())
70 p << module.getValue();
71 if (!ref.getValue().empty())
72 p << '>' << ref.getValue();
73 if (!field.getValue().empty())
74 p << field.getValue();
75 p << '\"';
76}
77
78static ParseResult parseFieldLocs(OpAsmParser &parser, ArrayAttr &fieldLocs) {
79 if (parser.parseOptionalKeyword("field_locs"))
80 return success();
81 if (parser.parseLParen() || parser.parseAttribute(fieldLocs) ||
82 parser.parseRParen()) {
83 return failure();
84 }
85 return success();
86}
87
88static void printFieldLocs(OpAsmPrinter &printer, Operation *op,
89 ArrayAttr fieldLocs) {
90 mlir::OpPrintingFlags flags;
91 if (!flags.shouldPrintDebugInfo() || !fieldLocs)
92 return;
93 printer << "field_locs(";
94 printer.printAttribute(fieldLocs);
95 printer << ")";
96}
97
98//===----------------------------------------------------------------------===//
99// Shared definitions
100//===----------------------------------------------------------------------===//
101static ParseResult parseClassFieldsList(OpAsmParser &parser,
102 SmallVectorImpl<Attribute> &fieldNames,
103 SmallVectorImpl<Type> &fieldTypes) {
104
105 llvm::StringMap<SMLoc> nameLocMap;
106 auto parseElt = [&]() -> ParseResult {
107 // Parse the field name.
108 std::string fieldName;
109 if (parser.parseKeywordOrString(&fieldName))
110 return failure();
111 SMLoc currLoc = parser.getCurrentLocation();
112 if (nameLocMap.count(fieldName)) {
113 parser.emitError(currLoc, "field \"")
114 << fieldName << "\" is defined twice";
115 parser.emitError(nameLocMap[fieldName]) << "previous definition is here";
116 return failure();
117 }
118 nameLocMap[fieldName] = currLoc;
119 fieldNames.push_back(StringAttr::get(parser.getContext(), fieldName));
120
121 // Parse the field type.
122 fieldTypes.emplace_back();
123 if (parser.parseColonType(fieldTypes.back()))
124 return failure();
125
126 return success();
127 };
128
129 return parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
130 parseElt);
131}
132
133static ParseResult parseClassLike(OpAsmParser &parser, OperationState &state) {
134 // Parse the Class symbol name.
135 StringAttr symName;
136 if (parser.parseSymbolName(symName, mlir::SymbolTable::getSymbolAttrName(),
137 state.attributes))
138 return failure();
139
140 // Parse the formal parameters.
141 SmallVector<OpAsmParser::Argument> args;
142 if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
143 /*allowType=*/true, /*allowAttrs=*/false))
144 return failure();
145
146 SmallVector<Type> fieldTypes;
147 SmallVector<Attribute> fieldNames;
148 if (succeeded(parser.parseOptionalArrow()))
149 if (failed(parseClassFieldsList(parser, fieldNames, fieldTypes)))
150 return failure();
151
152 SmallVector<NamedAttribute> fieldTypesMap;
153 if (!fieldNames.empty()) {
154 for (auto [name, type] : zip(fieldNames, fieldTypes))
155 fieldTypesMap.push_back(
156 NamedAttribute(cast<StringAttr>(name), TypeAttr::get(type)));
157 }
158 auto *ctx = parser.getContext();
159 state.addAttribute("fieldNames", mlir::ArrayAttr::get(ctx, fieldNames));
160 state.addAttribute("fieldTypes",
161 mlir::DictionaryAttr::get(ctx, fieldTypesMap));
162
163 // Parse the optional attribute dictionary.
164 if (failed(parser.parseOptionalAttrDictWithKeyword(state.attributes)))
165 return failure();
166
167 // Parse the body.
168 Region *region = state.addRegion();
169 if (parser.parseRegion(*region, args))
170 return failure();
171
172 // If the region was empty, add an empty block so it's still a SizedRegion<1>.
173 if (region->empty())
174 region->emplaceBlock();
175
176 // Remember the formal parameter names in an attribute.
177 auto argNames = llvm::map_range(args, [&](OpAsmParser::Argument arg) {
178 return StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front());
179 });
180 state.addAttribute(
181 "formalParamNames",
182 ArrayAttr::get(parser.getContext(), SmallVector<Attribute>(argNames)));
183
184 return success();
185}
186
187static void printClassLike(ClassLike classLike, OpAsmPrinter &printer) {
188 // Print the Class symbol name.
189 printer << " @";
190 printer << classLike.getSymName();
191
192 // Retrieve the formal parameter names and values.
193 auto argNames = SmallVector<StringRef>(
194 classLike.getFormalParamNames().getAsValueRange<StringAttr>());
195 ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
196
197 // Print the formal parameters.
198 printer << '(';
199 for (size_t i = 0, e = args.size(); i < e; ++i) {
200 printer << '%' << argNames[i] << ": " << args[i].getType();
201 if (i < e - 1)
202 printer << ", ";
203 }
204 printer << ") ";
205
206 ArrayRef<Attribute> fieldNames =
207 cast<ArrayAttr>(classLike->getAttr("fieldNames")).getValue();
208
209 if (!fieldNames.empty()) {
210 printer << " -> (";
211 for (size_t i = 0, e = fieldNames.size(); i < e; ++i) {
212 if (i != 0)
213 printer << ", ";
214 StringAttr name = cast<StringAttr>(fieldNames[i]);
215 printer.printKeywordOrString(name.getValue());
216 printer << ": ";
217 Type type = classLike.getFieldType(name).value();
218 printer.printType(type);
219 }
220 printer << ") ";
221 }
222
223 // Print the optional attribute dictionary.
224 SmallVector<StringRef> elidedAttrs{classLike.getSymNameAttrName(),
225 classLike.getFormalParamNamesAttrName(),
226 "fieldTypes", "fieldNames"};
227 printer.printOptionalAttrDictWithKeyword(classLike.getOperation()->getAttrs(),
228 elidedAttrs);
229
230 // Print the body.
231 printer.printRegion(classLike.getBody(), /*printEntryBlockArgs=*/false,
232 /*printBlockTerminators=*/true);
233}
234
235LogicalResult verifyClassLike(ClassLike classLike) {
236 // Verify the formal parameter names match up with the values.
237 if (classLike.getFormalParamNames().size() !=
238 classLike.getBodyBlock()->getArguments().size()) {
239 auto error = classLike.emitOpError(
240 "formal parameter name list doesn't match formal parameter value list");
241 error.attachNote(classLike.getLoc())
242 << "formal parameter names: " << classLike.getFormalParamNames();
243 error.attachNote(classLike.getLoc())
244 << "formal parameter values: "
245 << classLike.getBodyBlock()->getArguments();
246 return error;
247 }
248
249 return success();
250}
251
252void getClassLikeAsmBlockArgumentNames(ClassLike classLike, Region &region,
253 OpAsmSetValueNameFn setNameFn) {
254 // Retrieve the formal parameter names and values.
255 auto argNames = SmallVector<StringRef>(
256 classLike.getFormalParamNames().getAsValueRange<StringAttr>());
257 ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
258
259 // Use the formal parameter names as the SSA value names.
260 for (size_t i = 0, e = args.size(); i < e; ++i)
261 setNameFn(args[i], argNames[i]);
262}
263
264NamedAttribute makeFieldType(StringAttr name, Type type) {
265 return NamedAttribute(name, TypeAttr::get(type));
266}
267
268NamedAttribute makeFieldIdx(MLIRContext *ctx, mlir::StringAttr name,
269 unsigned i) {
270 return NamedAttribute(StringAttr(name),
271 mlir::IntegerAttr::get(mlir::IndexType::get(ctx), i));
272}
273
274std::optional<Type> getClassLikeFieldType(ClassLike classLike,
275 StringAttr name) {
276 DictionaryAttr fieldTypes = mlir::cast<DictionaryAttr>(
277 classLike.getOperation()->getAttr("fieldTypes"));
278 Attribute type = fieldTypes.get(name);
279 if (!type)
280 return std::nullopt;
281 return cast<TypeAttr>(type).getValue();
282}
283
284void replaceClassLikeFieldTypes(ClassLike classLike,
285 AttrTypeReplacer &replacer) {
286 classLike->setAttr("fieldTypes", cast<DictionaryAttr>(replacer.replace(
287 classLike.getFieldTypes())));
288}
289
290//===----------------------------------------------------------------------===//
291// ClassOp
292//===----------------------------------------------------------------------===//
293
294ParseResult circt::om::ClassOp::parse(OpAsmParser &parser,
295 OperationState &state) {
296 return parseClassLike(parser, state);
297}
298
299circt::om::ClassOp circt::om::ClassOp::buildSimpleClassOp(
300 OpBuilder &odsBuilder, Location loc, Twine name,
301 ArrayRef<StringRef> formalParamNames, ArrayRef<StringRef> fieldNames,
302 ArrayRef<Type> fieldTypes) {
303 circt::om::ClassOp classOp = circt::om::ClassOp::create(
304 odsBuilder, loc, odsBuilder.getStringAttr(name),
305 odsBuilder.getStrArrayAttr(formalParamNames),
306 odsBuilder.getStrArrayAttr(fieldNames),
307 odsBuilder.getDictionaryAttr(llvm::map_to_vector(
308 llvm::zip(fieldNames, fieldTypes), [&](auto field) -> NamedAttribute {
309 return NamedAttribute(odsBuilder.getStringAttr(std::get<0>(field)),
310 TypeAttr::get(std::get<1>(field)));
311 })));
312 Block *body = &classOp.getRegion().emplaceBlock();
313 auto prevLoc = odsBuilder.saveInsertionPoint();
314 odsBuilder.setInsertionPointToEnd(body);
315
316 mlir::SmallVector<Attribute> locAttrs(fieldNames.size(), LocationAttr(loc));
317
318 ClassFieldsOp::create(odsBuilder, loc,
319 llvm::map_to_vector(fieldTypes,
320 [&](Type type) -> Value {
321 return body->addArgument(type,
322 loc);
323 }),
324 odsBuilder.getArrayAttr(locAttrs));
325
326 odsBuilder.restoreInsertionPoint(prevLoc);
327
328 return classOp;
329}
330
331void circt::om::ClassOp::print(OpAsmPrinter &printer) {
332 printClassLike(*this, printer);
333}
334
335LogicalResult circt::om::ClassOp::verify() { return verifyClassLike(*this); }
336
337LogicalResult circt::om::ClassOp::verifyRegions() {
338 auto fieldsOp = cast<ClassFieldsOp>(this->getBodyBlock()->getTerminator());
339
340 // The number of results matches the number of terminator operands.
341 if (fieldsOp.getNumOperands() != this->getFieldNames().size()) {
342 auto diag = this->emitOpError()
343 << "returns '" << this->getFieldNames().size()
344 << "' fields, but its terminator returned '"
345 << fieldsOp.getNumOperands() << "' fields";
346 return diag.attachNote(fieldsOp.getLoc()) << "see terminator:";
347 }
348
349 // The type of each result matches the corresponding terminator operand type.
350 auto types = this->getFieldTypes();
351 for (auto [fieldName, terminatorOperandType] :
352 llvm::zip(this->getFieldNames(), fieldsOp.getOperandTypes())) {
353
354 if (terminatorOperandType ==
355 cast<TypeAttr>(types.get(cast<StringAttr>(fieldName))).getValue())
356 continue;
357
358 auto diag = this->emitOpError()
359 << "returns different field types than its terminator";
360 return diag.attachNote(fieldsOp.getLoc()) << "see terminator:";
361 }
362
363 return success();
364}
365
366void circt::om::ClassOp::getAsmBlockArgumentNames(
367 Region &region, OpAsmSetValueNameFn setNameFn) {
368 getClassLikeAsmBlockArgumentNames(*this, region, setNameFn);
369}
370
371std::optional<mlir::Type>
372circt::om::ClassOp::getFieldType(mlir::StringAttr field) {
373 return getClassLikeFieldType(*this, field);
374}
375
376void circt::om::ClassOp::replaceFieldTypes(AttrTypeReplacer replacer) {
377 replaceClassLikeFieldTypes(*this, replacer);
378}
379
380void circt::om::ClassOp::updateFields(
381 mlir::ArrayRef<mlir::Location> newLocations,
382 mlir::ArrayRef<mlir::Value> newValues,
383 mlir::ArrayRef<mlir::Attribute> newNames) {
384
385 auto fieldsOp = getFieldsOp();
386 assert(fieldsOp && "The fields op should exist");
387 // Get field names.
388 SmallVector<Attribute> names(getFieldNamesAttr().getAsRange<StringAttr>());
389 // Get the field types.
390 SmallVector<NamedAttribute> fieldTypes(getFieldTypesAttr().getValue());
391 // Get the field values.
392 SmallVector<Value> fieldVals(fieldsOp.getFields());
393 // Get the field locations.
394 Location fieldOpLoc = fieldsOp->getLoc();
395
396 // Extract the locations per field.
397 SmallVector<Location> locations;
398 if (auto fl = dyn_cast<FusedLoc>(fieldOpLoc)) {
399 auto metadataArr = dyn_cast<ArrayAttr>(fl.getMetadata());
400 assert(metadataArr && "Expected the metadata for the fused location");
401 auto r = metadataArr.getAsRange<LocationAttr>();
402 locations.append(r.begin(), r.end());
403 } else {
404 // Assume same loc for every field.
405 locations.append(names.size(), fieldOpLoc);
406 }
407
408 // Append the new names, locations and values.
409 names.append(newNames.begin(), newNames.end());
410 locations.append(newLocations.begin(), newLocations.end());
411 fieldVals.append(newValues.begin(), newValues.end());
412
413 // Construct the new field types from values and names.
414 for (auto [v, n] : llvm::zip(newValues, newNames))
415 fieldTypes.emplace_back(
416 NamedAttribute(llvm::cast<StringAttr>(n), TypeAttr::get(v.getType())));
417
418 // Keep the locations as array on the metadata.
419 SmallVector<Attribute> locationsAttr;
420 llvm::for_each(locations, [&](Location &l) {
421 locationsAttr.push_back(cast<Attribute>(l));
422 });
423
424 ImplicitLocOpBuilder builder(getLoc(), *this);
425 // Update the field names attribute.
426 setFieldNamesAttr(builder.getArrayAttr(names));
427 // Update the fields type attribute.
428 setFieldTypesAttr(builder.getDictionaryAttr(fieldTypes));
429 fieldsOp.getFieldsMutable().assign(fieldVals);
430 // Update the location.
431 fieldsOp->setLoc(builder.getFusedLoc(
432 locations, ArrayAttr::get(getContext(), locationsAttr)));
433}
434
435void circt::om::ClassOp::addNewFieldsOp(mlir::OpBuilder &builder,
436 mlir::ArrayRef<Location> locs,
437 mlir::ArrayRef<Value> values) {
438 // Store the original locations as a metadata array so that unique locations
439 // are preserved as a mapping from field index to location
440 assert(locs.size() == values.size() && "Expected a location per value");
441 mlir::SmallVector<Attribute> locAttrs;
442 for (auto loc : locs) {
443 locAttrs.push_back(cast<Attribute>(LocationAttr(loc)));
444 }
445 // Also store the locations incase there's some other analysis that might
446 // be able to use the default FusedLoc representation.
447 ClassFieldsOp::create(builder, builder.getFusedLoc(locs), values,
448 builder.getArrayAttr(locAttrs));
449}
450
451mlir::Location circt::om::ClassOp::getFieldLocByIndex(size_t i) {
452 auto fieldsOp = this->getFieldsOp();
453 auto fieldLocs = fieldsOp.getFieldLocs();
454 if (!fieldLocs.has_value())
455 return fieldsOp.getLoc();
456 assert(i < fieldLocs.value().size() &&
457 "field index too large for location array");
458 return cast<LocationAttr>(fieldLocs.value()[i]);
459}
460
461//===----------------------------------------------------------------------===//
462// ClassExternOp
463//===----------------------------------------------------------------------===//
464
465ParseResult circt::om::ClassExternOp::parse(OpAsmParser &parser,
466 OperationState &state) {
467 return parseClassLike(parser, state);
468}
469
470void circt::om::ClassExternOp::print(OpAsmPrinter &printer) {
471 printClassLike(*this, printer);
472}
473
474LogicalResult circt::om::ClassExternOp::verify() {
475 if (failed(verifyClassLike(*this))) {
476 return failure();
477 }
478 // Verify body is empty
479 if (!this->getBodyBlock()->getOperations().empty()) {
480 return this->emitOpError("external class body should be empty");
481 }
482
483 return success();
484}
485
486void circt::om::ClassExternOp::getAsmBlockArgumentNames(
487 Region &region, OpAsmSetValueNameFn setNameFn) {
488 getClassLikeAsmBlockArgumentNames(*this, region, setNameFn);
489}
490
491std::optional<mlir::Type>
492circt::om::ClassExternOp::getFieldType(mlir::StringAttr field) {
493 return getClassLikeFieldType(*this, field);
494}
495
496void circt::om::ClassExternOp::replaceFieldTypes(AttrTypeReplacer replacer) {
497 replaceClassLikeFieldTypes(*this, replacer);
498}
499
500//===----------------------------------------------------------------------===//
501// ClassFieldsOp
502//===----------------------------------------------------------------------===//
503//
504LogicalResult circt::om::ClassFieldsOp::verify() {
505 auto fieldLocs = this->getFieldLocs();
506 if (fieldLocs.has_value()) {
507 auto fieldLocsVal = fieldLocs.value();
508 if (fieldLocsVal.size() != this->getFields().size()) {
509 auto error = this->emitOpError("size of field_locs (")
510 << fieldLocsVal.size()
511 << ") does not match number of fields ("
512 << this->getFields().size() << ")";
513 }
514 }
515 return success();
516}
517
518//===----------------------------------------------------------------------===//
519// ObjectOp
520//===----------------------------------------------------------------------===//
521
522void circt::om::ObjectOp::build(::mlir::OpBuilder &odsBuilder,
523 ::mlir::OperationState &odsState,
524 om::ClassOp classOp,
525 ::mlir::ValueRange actualParams) {
526 return build(odsBuilder, odsState,
527 om::ClassType::get(odsBuilder.getContext(),
528 mlir::FlatSymbolRefAttr::get(classOp)),
529 classOp.getNameAttr(), actualParams);
530}
531
532LogicalResult
533circt::om::ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
534 // Verify the result type is the same as the referred-to class.
535 StringAttr resultClassName = getResult().getType().getClassName().getAttr();
536 StringAttr className = getClassNameAttr();
537 if (resultClassName != className)
538 return emitOpError("result type (")
539 << resultClassName << ") does not match referred to class ("
540 << className << ')';
541
542 // Verify the referred to ClassOp exists.
543 auto classDef = dyn_cast_or_null<ClassLike>(
544 symbolTable.lookupNearestSymbolFrom(*this, className));
545 if (!classDef)
546 return emitOpError("refers to non-existant class (") << className << ')';
547
548 auto actualTypes = getActualParams().getTypes();
549 auto formalTypes = classDef.getBodyBlock()->getArgumentTypes();
550
551 // Verify the actual parameter list matches the formal parameter list.
552 if (actualTypes.size() != formalTypes.size()) {
553 auto error = emitOpError(
554 "actual parameter list doesn't match formal parameter list");
555 error.attachNote(classDef.getLoc())
556 << "formal parameters: " << classDef.getBodyBlock()->getArguments();
557 error.attachNote(getLoc()) << "actual parameters: " << getActualParams();
558 return error;
559 }
560
561 // Verify the actual parameter types match the formal parameter types.
562 for (size_t i = 0, e = actualTypes.size(); i < e; ++i) {
563 if (actualTypes[i] != formalTypes[i]) {
564 return emitOpError("actual parameter type (")
565 << actualTypes[i] << ") doesn't match formal parameter type ("
566 << formalTypes[i] << ')';
567 }
568 }
569
570 return success();
571}
572
573//===----------------------------------------------------------------------===//
574// ConstantOp
575//===----------------------------------------------------------------------===//
576
577void circt::om::ConstantOp::build(::mlir::OpBuilder &odsBuilder,
578 ::mlir::OperationState &odsState,
579 ::mlir::TypedAttr constVal) {
580 return build(odsBuilder, odsState, constVal.getType(), constVal);
581}
582
583OpFoldResult circt::om::ConstantOp::fold(FoldAdaptor adaptor) {
584 assert(adaptor.getOperands().empty() && "constant has no operands");
585 return getValueAttr();
586}
587
588//===----------------------------------------------------------------------===//
589// ListCreateOp
590//===----------------------------------------------------------------------===//
591
592void circt::om::ListCreateOp::print(OpAsmPrinter &p) {
593 p << " ";
594 p.printOperands(getInputs());
595 p.printOptionalAttrDict((*this)->getAttrs());
596 p << " : " << getType().getElementType();
597}
598
599ParseResult circt::om::ListCreateOp::parse(OpAsmParser &parser,
600 OperationState &result) {
601 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
602 Type elemType;
603
604 if (parser.parseOperandList(operands) ||
605 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
606 parser.parseType(elemType))
607 return failure();
608 result.addTypes({circt::om::ListType::get(elemType)});
609
610 for (auto operand : operands)
611 if (parser.resolveOperand(operand, elemType, result.operands))
612 return failure();
613 return success();
614}
615
616//===----------------------------------------------------------------------===//
617// BasePathCreateOp
618//===----------------------------------------------------------------------===//
619
620LogicalResult
621BasePathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
622 auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
623 *this, getTargetAttr());
624 if (!hierPath)
625 return emitOpError("invalid symbol reference");
626 return success();
627}
628
629//===----------------------------------------------------------------------===//
630// PathCreateOp
631//===----------------------------------------------------------------------===//
632
633LogicalResult
634PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
635 auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
636 *this, getTargetAttr());
637 if (!hierPath)
638 return emitOpError("invalid symbol reference");
639 return success();
640}
641
642//===----------------------------------------------------------------------===//
643// IntegerAddOp
644//===----------------------------------------------------------------------===//
645
646FailureOr<llvm::APSInt>
647IntegerAddOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
648 const llvm::APSInt &rhs) {
649 return success(lhs + rhs);
650}
651
652//===----------------------------------------------------------------------===//
653// IntegerMulOp
654//===----------------------------------------------------------------------===//
655
656FailureOr<llvm::APSInt>
657IntegerMulOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
658 const llvm::APSInt &rhs) {
659 return success(lhs * rhs);
660}
661
662//===----------------------------------------------------------------------===//
663// IntegerShrOp
664//===----------------------------------------------------------------------===//
665
666FailureOr<llvm::APSInt>
667IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
668 const llvm::APSInt &rhs) {
669 // Check non-negative constraint from operation semantics.
670 if (!rhs.isNonNegative())
671 return emitOpError("shift amount must be non-negative");
672 // Check size constraint from implementation detail of using getExtValue.
673 if (!rhs.isRepresentableByInt64())
674 return emitOpError("shift amount must be representable in 64 bits");
675 return success(lhs >> rhs.getExtValue());
676}
677
678//===----------------------------------------------------------------------===//
679// IntegerShlOp
680//===----------------------------------------------------------------------===//
681
682FailureOr<llvm::APSInt>
683IntegerShlOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
684 const llvm::APSInt &rhs) {
685 // Check non-negative constraint from operation semantics.
686 if (!rhs.isNonNegative())
687 return emitOpError("shift amount must be non-negative");
688 // Check size constraint from implementation detail of using getExtValue.
689 if (!rhs.isRepresentableByInt64())
690 return emitOpError("shift amount must be representable in 64 bits");
691 return success(lhs << rhs.getExtValue());
692}
693
694//===----------------------------------------------------------------------===//
695// StringConcatOp
696//===----------------------------------------------------------------------===//
697
698OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
699 // Fold single-operand concat to just the operand.
700 if (getStrings().size() == 1)
701 return getStrings()[0];
702
703 // Check if all operands are constant strings before accumulating.
704 if (!llvm::all_of(adaptor.getStrings(), [](Attribute operand) {
705 return isa_and_nonnull<StringAttr>(operand);
706 }))
707 return {};
708
709 // All operands are constant strings, concatenate them.
710 SmallString<64> result;
711 for (auto operand : adaptor.getStrings())
712 result += cast<StringAttr>(operand).getValue();
713
714 return StringAttr::get(result, getResult().getType());
715}
716
717namespace {
718/// Flatten nested string.concat operations into a single concat.
719/// string.concat(a, string.concat(b, c), d) -> string.concat(a, b, c, d)
720class FlattenOMStringConcat : public mlir::OpRewritePattern<StringConcatOp> {
721public:
722 using OpRewritePattern::OpRewritePattern;
723
724 LogicalResult
725 matchAndRewrite(StringConcatOp concat,
726 mlir::PatternRewriter &rewriter) const override {
727
728 // Check if any operands are nested concats with a single use. Only inline
729 // single-use nested concats to avoid fighting with DCE.
730 bool hasNestedConcat = llvm::any_of(concat.getStrings(), [](Value operand) {
731 auto nestedConcat = operand.getDefiningOp<StringConcatOp>();
732 return nestedConcat && operand.hasOneUse();
733 });
734
735 if (!hasNestedConcat)
736 return failure();
737
738 // Flatten nested concats that have a single use.
739 SmallVector<Value> flatOperands;
740 for (auto input : concat.getStrings()) {
741 if (auto nestedConcat = input.getDefiningOp<StringConcatOp>();
742 nestedConcat && input.hasOneUse())
743 llvm::append_range(flatOperands, nestedConcat.getStrings());
744 else
745 flatOperands.push_back(input);
746 }
747
748 rewriter.modifyOpInPlace(concat,
749 [&]() { concat->setOperands(flatOperands); });
750 return success();
751 }
752};
753
754/// Merge consecutive constant strings in a concat and remove empty strings.
755/// string.concat("a", "b", x, "", "c", "d") -> string.concat("ab", x, "cd")
756class MergeAdjacentOMStringConstants
757 : public mlir::OpRewritePattern<StringConcatOp> {
758public:
759 using OpRewritePattern::OpRewritePattern;
760
761 LogicalResult
762 matchAndRewrite(StringConcatOp concat,
763 mlir::PatternRewriter &rewriter) const override {
764
765 SmallVector<Value> newOperands;
766 SmallString<64> accumulatedLit;
767 SmallVector<ConstantOp> accumulatedOps;
768 bool changed = false;
769
770 auto flushLiterals = [&]() {
771 if (accumulatedOps.empty())
772 return;
773
774 // If only one literal, reuse it.
775 if (accumulatedOps.size() == 1) {
776 newOperands.push_back(accumulatedOps[0]);
777 } else {
778 // Multiple literals - merge them.
779 auto newLit = rewriter.createOrFold<ConstantOp>(
780 concat.getLoc(),
781 StringAttr::get(accumulatedLit, concat.getResult().getType()));
782 newOperands.push_back(newLit);
783 changed = true;
784 }
785 accumulatedLit.clear();
786 accumulatedOps.clear();
787 };
788
789 for (auto operand : concat.getStrings()) {
790 if (auto litOp = operand.getDefiningOp<ConstantOp>()) {
791 if (auto strAttr = dyn_cast<StringAttr>(litOp.getValue())) {
792 // Skip empty strings.
793 if (strAttr.getValue().empty()) {
794 changed = true;
795 continue;
796 }
797 accumulatedLit += strAttr.getValue();
798 accumulatedOps.push_back(litOp);
799 continue;
800 }
801 }
802
803 flushLiterals();
804 newOperands.push_back(operand);
805 }
806
807 // Flush any remaining literals.
808 flushLiterals();
809
810 if (!changed)
811 return failure();
812
813 // If no operands remain, replace with empty string.
814 if (newOperands.empty())
815 return rewriter.replaceOpWithNewOp<ConstantOp>(
816 concat, StringAttr::get("", concat.getResult().getType())),
817 success();
818
819 // Single-operand case is handled by the folder.
820 rewriter.modifyOpInPlace(concat,
821 [&]() { concat->setOperands(newOperands); });
822 return success();
823 }
824};
825
826} // namespace
827
828void StringConcatOp::getCanonicalizationPatterns(RewritePatternSet &results,
829 MLIRContext *context) {
830 results.insert<FlattenOMStringConcat, MergeAdjacentOMStringConstants>(
831 context);
832}
833
834//===----------------------------------------------------------------------===//
835// UnknownValueOp
836//===----------------------------------------------------------------------===//
837
838LogicalResult circt::om::UnknownValueOp::verifySymbolUses(
839 SymbolTableCollection &symbolTable) {
840
841 // Unknown values of non-class type don't need to be verified.
842 auto classType = dyn_cast<ClassType>(getType());
843 if (!classType)
844 return success();
845
846 // Verify the referred to ClassOp exists.
847 auto className = classType.getClassName();
848 if (symbolTable.lookupNearestSymbolFrom<ClassLike>(*this, className))
849 return success();
850
851 return emitOpError() << "refers to non-existant class (\""
852 << className.getValue() << "\")";
853}
854
855//===----------------------------------------------------------------------===//
856// TableGen generated logic.
857//===----------------------------------------------------------------------===//
858
859#define GET_OP_CLASSES
860#include "circt/Dialect/OM/OM.cpp.inc"
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static ParseResult parseClassLike(OpAsmParser &parser, OperationState &state)
Definition OMOps.cpp:133
LogicalResult verifyClassLike(ClassLike classLike)
Definition OMOps.cpp:235
std::optional< Type > getClassLikeFieldType(ClassLike classLike, StringAttr name)
Definition OMOps.cpp:274
void getClassLikeAsmBlockArgumentNames(ClassLike classLike, Region &region, OpAsmSetValueNameFn setNameFn)
Definition OMOps.cpp:252
static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path)
Definition OMOps.cpp:27
static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Definition OMOps.cpp:49
static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path)
Definition OMOps.cpp:38
static void printFieldLocs(OpAsmPrinter &printer, Operation *op, ArrayAttr fieldLocs)
Definition OMOps.cpp:88
static ParseResult parseFieldLocs(OpAsmParser &parser, ArrayAttr &fieldLocs)
Definition OMOps.cpp:78
static ParseResult parseClassFieldsList(OpAsmParser &parser, SmallVectorImpl< Attribute > &fieldNames, SmallVectorImpl< Type > &fieldTypes)
Definition OMOps.cpp:101
static void printClassLike(ClassLike classLike, OpAsmPrinter &printer)
Definition OMOps.cpp:187
void replaceClassLikeFieldTypes(ClassLike classLike, AttrTypeReplacer &replacer)
Definition OMOps.cpp:284
NamedAttribute makeFieldType(StringAttr name, Type type)
Definition OMOps.cpp:264
NamedAttribute makeFieldIdx(MLIRContext *ctx, mlir::StringAttr name, unsigned i)
Definition OMOps.cpp:268
static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path, StringAttr module, StringAttr ref, StringAttr field)
Definition OMOps.cpp:63
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
void error(Twine message)
Definition LSPUtils.cpp:16
ParseResult parsePath(MLIRContext *context, StringRef spelling, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Parse a target string into a path.
Definition OMUtils.cpp:182
ParseResult parseBasePath(MLIRContext *context, StringRef spelling, PathAttr &path)
Parse a target string of the form "Foo/bar:Bar/baz" into a base path.
Definition OMUtils.cpp:177
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
A module name, and the name of an instance inside that module.
mlir::StringAttr mlir::StringAttr instance