CIRCT  19.0.0git
OMOps.cpp
Go to the documentation of this file.
1 //===- OMOps.cpp - Object Model operation definitions ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Object Model operation definitions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/OM/OMOps.h"
14 #include "circt/Dialect/HW/HWOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 
19 using namespace mlir;
20 using namespace circt::om;
21 
22 //===----------------------------------------------------------------------===//
23 // Path Printers and Parsers
24 //===----------------------------------------------------------------------===//
25 
26 static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path) {
27  auto *context = parser.getContext();
28  auto loc = parser.getCurrentLocation();
29  std::string rawPath;
30  if (parser.parseString(&rawPath))
31  return failure();
32  if (parseBasePath(context, rawPath, path))
33  return parser.emitError(loc, "invalid base path");
34  return success();
35 }
36 
37 static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path) {
38  p << '\"';
39  llvm::interleave(
40  path, p,
41  [&](const PathElement &elt) {
42  p << elt.module.getValue() << '/' << elt.instance.getValue();
43  },
44  ":");
45  p << '\"';
46 }
47 
48 static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path,
49  StringAttr &module, StringAttr &ref,
50  StringAttr &field) {
51 
52  auto *context = parser.getContext();
53  auto loc = parser.getCurrentLocation();
54  std::string rawPath;
55  if (parser.parseString(&rawPath))
56  return failure();
57  if (parsePath(context, rawPath, path, module, ref, field))
58  return parser.emitError(loc, "invalid path");
59  return success();
60 }
61 
62 static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path,
63  StringAttr module, StringAttr ref,
64  StringAttr field) {
65  p << '\"';
66  for (const auto &elt : path)
67  p << elt.module.getValue() << '/' << elt.instance.getValue() << ':';
68  if (!module.getValue().empty())
69  p << module.getValue();
70  if (!ref.getValue().empty())
71  p << '>' << ref.getValue();
72  if (!field.getValue().empty())
73  p << field.getValue();
74  p << '\"';
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Shared definitions
79 //===----------------------------------------------------------------------===//
80 
81 static ParseResult parseClassLike(OpAsmParser &parser, OperationState &state) {
82  // Parse the Class symbol name.
83  StringAttr symName;
84  if (parser.parseSymbolName(symName, mlir::SymbolTable::getSymbolAttrName(),
85  state.attributes))
86  return failure();
87 
88  // Parse the formal parameters.
89  SmallVector<OpAsmParser::Argument> args;
90  if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
91  /*allowType=*/true, /*allowAttrs=*/false))
92  return failure();
93 
94  // Parse the optional attribute dictionary.
95  if (failed(parser.parseOptionalAttrDictWithKeyword(state.attributes)))
96  return failure();
97 
98  // Parse the body.
99  Region *region = state.addRegion();
100  if (parser.parseRegion(*region, args))
101  return failure();
102 
103  // If the region was empty, add an empty block so it's still a SizedRegion<1>.
104  if (region->empty())
105  region->emplaceBlock();
106 
107  // Remember the formal parameter names in an attribute.
108  auto argNames = llvm::map_range(args, [&](OpAsmParser::Argument arg) {
109  return StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front());
110  });
111  state.addAttribute(
112  "formalParamNames",
113  ArrayAttr::get(parser.getContext(), SmallVector<Attribute>(argNames)));
114 
115  return success();
116 }
117 
118 static void printClassLike(ClassLike classLike, OpAsmPrinter &printer) {
119  // Print the Class symbol name.
120  printer << " @";
121  printer << classLike.getSymName();
122 
123  // Retrieve the formal parameter names and values.
124  auto argNames = SmallVector<StringRef>(
125  classLike.getFormalParamNames().getAsValueRange<StringAttr>());
126  ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
127 
128  // Print the formal parameters.
129  printer << '(';
130  for (size_t i = 0, e = args.size(); i < e; ++i) {
131  printer << '%' << argNames[i] << ": " << args[i].getType();
132  if (i < e - 1)
133  printer << ", ";
134  }
135  printer << ") ";
136 
137  // Print the optional attribute dictionary.
138  SmallVector<StringRef> elidedAttrs{classLike.getSymNameAttrName(),
139  classLike.getFormalParamNamesAttrName()};
140  printer.printOptionalAttrDictWithKeyword(classLike.getOperation()->getAttrs(),
141  elidedAttrs);
142 
143  // Print the body.
144  printer.printRegion(classLike.getBody(), /*printEntryBlockArgs=*/false,
145  /*printBlockTerminators=*/true);
146 }
147 
148 LogicalResult verifyClassLike(ClassLike classLike) {
149  // Verify the formal parameter names match up with the values.
150  if (classLike.getFormalParamNames().size() !=
151  classLike.getBodyBlock()->getArguments().size()) {
152  auto error = classLike.emitOpError(
153  "formal parameter name list doesn't match formal parameter value list");
154  error.attachNote(classLike.getLoc())
155  << "formal parameter names: " << classLike.getFormalParamNames();
156  error.attachNote(classLike.getLoc())
157  << "formal parameter values: "
158  << classLike.getBodyBlock()->getArguments();
159  return error;
160  }
161 
162  return success();
163 }
164 
165 void getClassLikeAsmBlockArgumentNames(ClassLike classLike, Region &region,
166  OpAsmSetValueNameFn setNameFn) {
167  // Retrieve the formal parameter names and values.
168  auto argNames = SmallVector<StringRef>(
169  classLike.getFormalParamNames().getAsValueRange<StringAttr>());
170  ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
171 
172  // Use the formal parameter names as the SSA value names.
173  for (size_t i = 0, e = args.size(); i < e; ++i)
174  setNameFn(args[i], argNames[i]);
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // ClassOp
179 //===----------------------------------------------------------------------===//
180 
181 ParseResult circt::om::ClassOp::parse(OpAsmParser &parser,
182  OperationState &state) {
183  return parseClassLike(parser, state);
184 }
185 
186 void circt::om::ClassOp::build(OpBuilder &odsBuilder, OperationState &odsState,
187  Twine name,
188  ArrayRef<StringRef> formalParamNames) {
189  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
190  odsBuilder.getStrArrayAttr(formalParamNames));
191 }
192 
193 circt::om::ClassOp circt::om::ClassOp::buildSimpleClassOp(
194  OpBuilder &odsBuilder, Location loc, Twine name,
195  ArrayRef<StringRef> formalParamNames, ArrayRef<StringRef> fieldNames,
196  ArrayRef<Type> fieldTypes) {
197  circt::om::ClassOp classOp = odsBuilder.create<circt::om::ClassOp>(
198  loc, odsBuilder.getStringAttr(name),
199  odsBuilder.getStrArrayAttr(formalParamNames));
200  Block *body = &classOp.getRegion().emplaceBlock();
201  auto prevLoc = odsBuilder.saveInsertionPoint();
202  odsBuilder.setInsertionPointToEnd(body);
203  for (auto [name, type] : llvm::zip(fieldNames, fieldTypes))
204  odsBuilder.create<circt::om::ClassFieldOp>(loc, name,
205  body->addArgument(type, loc));
206  odsBuilder.restoreInsertionPoint(prevLoc);
207 
208  return classOp;
209 }
210 
211 void circt::om::ClassOp::build(OpBuilder &odsBuilder, OperationState &odsState,
212  Twine name) {
213  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
214  odsBuilder.getStrArrayAttr({}));
215 }
216 
217 void circt::om::ClassOp::print(OpAsmPrinter &printer) {
218  printClassLike(*this, printer);
219 }
220 
221 LogicalResult circt::om::ClassOp::verify() { return verifyClassLike(*this); }
222 
223 void circt::om::ClassOp::getAsmBlockArgumentNames(
224  Region &region, OpAsmSetValueNameFn setNameFn) {
225  getClassLikeAsmBlockArgumentNames(*this, region, setNameFn);
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // ClassFieldOp
230 //===----------------------------------------------------------------------===//
231 
232 Type circt::om::ClassFieldOp::getType() { return getValue().getType(); }
233 
234 //===----------------------------------------------------------------------===//
235 // ClassExternOp
236 //===----------------------------------------------------------------------===//
237 
238 ParseResult circt::om::ClassExternOp::parse(OpAsmParser &parser,
239  OperationState &state) {
240  return parseClassLike(parser, state);
241 }
242 
243 void circt::om::ClassExternOp::build(OpBuilder &odsBuilder,
244  OperationState &odsState, Twine name) {
245  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
246  odsBuilder.getStrArrayAttr({}));
247 }
248 
249 void circt::om::ClassExternOp::build(OpBuilder &odsBuilder,
250  OperationState &odsState, Twine name,
251  ArrayRef<StringRef> formalParamNames) {
252  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
253  odsBuilder.getStrArrayAttr(formalParamNames));
254 }
255 
256 void circt::om::ClassExternOp::print(OpAsmPrinter &printer) {
257  printClassLike(*this, printer);
258 }
259 
260 LogicalResult circt::om::ClassExternOp::verify() {
261  if (failed(verifyClassLike(*this))) {
262  return failure();
263  }
264 
265  // Verify that only external class field declarations are present in the body.
266  for (auto &op : getOps())
267  if (!isa<ClassExternFieldOp>(op))
268  return op.emitOpError("not allowed in external class");
269 
270  return success();
271 }
272 
273 void circt::om::ClassExternOp::getAsmBlockArgumentNames(
274  Region &region, OpAsmSetValueNameFn setNameFn) {
275  getClassLikeAsmBlockArgumentNames(*this, region, setNameFn);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ObjectOp
280 //===----------------------------------------------------------------------===//
281 
282 void circt::om::ObjectOp::build(::mlir::OpBuilder &odsBuilder,
283  ::mlir::OperationState &odsState,
284  om::ClassOp classOp,
285  ::mlir::ValueRange actualParams) {
286  return build(odsBuilder, odsState,
287  om::ClassType::get(odsBuilder.getContext(),
288  mlir::FlatSymbolRefAttr::get(classOp)),
289  classOp.getNameAttr(), actualParams);
290 }
291 
292 LogicalResult
293 circt::om::ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
294  // Verify the result type is the same as the referred-to class.
295  StringAttr resultClassName = getResult().getType().getClassName().getAttr();
296  StringAttr className = getClassNameAttr();
297  if (resultClassName != className)
298  return emitOpError("result type (")
299  << resultClassName << ") does not match referred to class ("
300  << className << ')';
301 
302  // Verify the referred to ClassOp exists.
303  auto classDef = dyn_cast_or_null<ClassLike>(
304  symbolTable.lookupNearestSymbolFrom(*this, className));
305  if (!classDef)
306  return emitOpError("refers to non-existant class (") << className << ')';
307 
308  auto actualTypes = getActualParams().getTypes();
309  auto formalTypes = classDef.getBodyBlock()->getArgumentTypes();
310 
311  // Verify the actual parameter list matches the formal parameter list.
312  if (actualTypes.size() != formalTypes.size()) {
313  auto error = emitOpError(
314  "actual parameter list doesn't match formal parameter list");
315  error.attachNote(classDef.getLoc())
316  << "formal parameters: " << classDef.getBodyBlock()->getArguments();
317  error.attachNote(getLoc()) << "actual parameters: " << getActualParams();
318  return error;
319  }
320 
321  // Verify the actual parameter types match the formal parameter types.
322  for (size_t i = 0, e = actualTypes.size(); i < e; ++i) {
323  if (actualTypes[i] != formalTypes[i]) {
324  return emitOpError("actual parameter type (")
325  << actualTypes[i] << ") doesn't match formal parameter type ("
326  << formalTypes[i] << ')';
327  }
328  }
329 
330  return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // ObjectFieldOp
335 //===----------------------------------------------------------------------===//
336 
337 LogicalResult
338 circt::om::ObjectFieldOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
339  // Get the ObjectInstOp and the ClassLike it is an instance of.
340  ObjectOp objectInst = getObject().getDefiningOp<ObjectOp>();
341  ClassLike classDef = cast<ClassLike>(symbolTable.lookupNearestSymbolFrom(
342  *this, objectInst.getClassNameAttr()));
343 
344  // Traverse the field path, verifying each field exists.
345  ClassFieldLike finalField;
346  auto fields = SmallVector<FlatSymbolRefAttr>(
347  getFieldPath().getAsRange<FlatSymbolRefAttr>());
348  for (size_t i = 0, e = fields.size(); i < e; ++i) {
349  // Verify the field exists on the ClassOp.
350  auto field = fields[i];
351  ClassFieldLike fieldDef;
352  classDef.walk([&](ClassFieldLike fieldLike) {
353  if (fieldLike.getNameAttr() == field.getAttr()) {
354  fieldDef = fieldLike;
355  return WalkResult::interrupt();
356  }
357  return WalkResult::advance();
358  });
359  if (!fieldDef) {
360  auto error = emitOpError("referenced non-existant field ") << field;
361  error.attachNote(classDef.getLoc()) << "class defined here";
362  return error;
363  }
364 
365  // If there are more fields, verify the current field is of ClassType, and
366  // look up the ClassOp for that field.
367  if (i < e - 1) {
368  auto classType = fieldDef.getType().dyn_cast<ClassType>();
369  if (!classType)
370  return emitOpError("nested field access into ")
371  << field << " requires a ClassType, but found "
372  << fieldDef.getType();
373 
374  // The nested ClassOp must exist, since a field with ClassType must be
375  // an ObjectInstOp, which already verifies the class exists.
376  classDef = cast<ClassLike>(
377  symbolTable.lookupNearestSymbolFrom(*this, classType.getClassName()));
378 
379  // Proceed to the next field in the path.
380  continue;
381  }
382 
383  // On the last iteration down the path, save the final field being accessed.
384  finalField = fieldDef;
385  }
386 
387  // Verify the accessed field type matches the result type.
388  if (finalField.getType() != getResult().getType())
389  return emitOpError("expected type ")
390  << getResult().getType() << ", but accessed field has type "
391  << finalField.getType();
392 
393  return success();
394 }
395 
396 //===----------------------------------------------------------------------===//
397 // ConstantOp
398 //===----------------------------------------------------------------------===//
399 
400 void circt::om::ConstantOp::build(::mlir::OpBuilder &odsBuilder,
401  ::mlir::OperationState &odsState,
402  ::mlir::TypedAttr constVal) {
403  return build(odsBuilder, odsState, constVal.getType(), constVal);
404 }
405 
406 OpFoldResult circt::om::ConstantOp::fold(FoldAdaptor adaptor) {
407  assert(adaptor.getOperands().empty() && "constant has no operands");
408  return getValueAttr();
409 }
410 
411 //===----------------------------------------------------------------------===//
412 // ListCreateOp
413 //===----------------------------------------------------------------------===//
414 
415 void circt::om::ListCreateOp::print(OpAsmPrinter &p) {
416  p << " ";
417  p.printOperands(getInputs());
418  p.printOptionalAttrDict((*this)->getAttrs());
419  p << " : " << getType().getElementType();
420 }
421 
422 ParseResult circt::om::ListCreateOp::parse(OpAsmParser &parser,
423  OperationState &result) {
424  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
425  Type elemType;
426 
427  if (parser.parseOperandList(operands) ||
428  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
429  parser.parseType(elemType))
430  return failure();
431  result.addTypes({circt::om::ListType::get(elemType)});
432 
433  for (auto operand : operands)
434  if (parser.resolveOperand(operand, elemType, result.operands))
435  return failure();
436  return success();
437 }
438 
439 //===----------------------------------------------------------------------===//
440 // TupleCreateOp
441 //===----------------------------------------------------------------------===//
442 
443 LogicalResult TupleCreateOp::inferReturnTypes(
444  MLIRContext *context, std::optional<Location> location, ValueRange operands,
445  DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
446  llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
447  ::llvm::SmallVector<Type> types;
448  for (auto op : operands)
449  types.push_back(op.getType());
450  inferredReturnTypes.push_back(TupleType::get(context, types));
451  return success();
452 }
453 
454 //===----------------------------------------------------------------------===//
455 // TupleGetOp
456 //===----------------------------------------------------------------------===//
457 
458 LogicalResult TupleGetOp::inferReturnTypes(
459  MLIRContext *context, std::optional<Location> location, ValueRange operands,
460  DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
461  llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
462  auto idx = attributes.getAs<mlir::IntegerAttr>("index");
463  if (operands.empty() || !idx)
464  return failure();
465 
466  auto tupleTypes = operands[0].getType().cast<TupleType>().getTypes();
467  if (tupleTypes.size() <= idx.getValue().getLimitedValue()) {
468  if (location)
469  mlir::emitError(*location,
470  "tuple index out-of-bounds, must be less than ")
471  << tupleTypes.size() << " but got "
472  << idx.getValue().getLimitedValue();
473  return failure();
474  }
475 
476  inferredReturnTypes.push_back(tupleTypes[idx.getValue().getLimitedValue()]);
477  return success();
478 }
479 
480 //===----------------------------------------------------------------------===//
481 // MapCreateOp
482 //===----------------------------------------------------------------------===//
483 
484 void circt::om::MapCreateOp::print(OpAsmPrinter &p) {
485  p << " ";
486  p.printOperands(getInputs());
487  p.printOptionalAttrDict((*this)->getAttrs());
488  p << " : " << getType().cast<circt::om::MapType>().getKeyType() << ", "
489  << getType().cast<circt::om::MapType>().getValueType();
490 }
491 
492 ParseResult circt::om::MapCreateOp::parse(OpAsmParser &parser,
493  OperationState &result) {
494  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
495  Type elementType, valueType;
496 
497  if (parser.parseOperandList(operands) ||
498  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
499  parser.parseType(elementType) || parser.parseComma() ||
500  parser.parseType(valueType))
501  return failure();
502  result.addTypes({circt::om::MapType::get(elementType, valueType)});
503  auto operandType =
504  mlir::TupleType::get(valueType.getContext(), {elementType, valueType});
505 
506  for (auto operand : operands)
507  if (parser.resolveOperand(operand, operandType, result.operands))
508  return failure();
509  return success();
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // BasePathCreateOp
514 //===----------------------------------------------------------------------===//
515 
516 LogicalResult
517 BasePathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
518  auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
519  *this, getTargetAttr());
520  if (!hierPath)
521  return emitOpError("invalid symbol reference");
522  return success();
523 }
524 
525 //===----------------------------------------------------------------------===//
526 // PathCreateOp
527 //===----------------------------------------------------------------------===//
528 
529 LogicalResult
530 PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
531  auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
532  *this, getTargetAttr());
533  if (!hierPath)
534  return emitOpError("invalid symbol reference");
535  return success();
536 }
537 
538 //===----------------------------------------------------------------------===//
539 // IntegerAddOp
540 //===----------------------------------------------------------------------===//
541 
542 FailureOr<llvm::APSInt>
543 IntegerAddOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
544  const llvm::APSInt &rhs) {
545  return success(lhs + rhs);
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // IntegerMulOp
550 //===----------------------------------------------------------------------===//
551 
552 FailureOr<llvm::APSInt>
553 IntegerMulOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
554  const llvm::APSInt &rhs) {
555  return success(lhs * rhs);
556 }
557 
558 //===----------------------------------------------------------------------===//
559 // IntegerShrOp
560 //===----------------------------------------------------------------------===//
561 
562 FailureOr<llvm::APSInt>
563 IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
564  const llvm::APSInt &rhs) {
565  // Check non-negative constraint from operation semantics.
566  if (!rhs.isNonNegative())
567  return emitOpError("shift amount must be non-negative");
568  // Check size constraint from implementation detail of using getExtValue.
569  if (!rhs.isRepresentableByInt64())
570  return emitOpError("shift amount must be representable in 64 bits");
571  return success(lhs >> rhs.getExtValue());
572 }
573 
574 //===----------------------------------------------------------------------===//
575 // TableGen generated logic.
576 //===----------------------------------------------------------------------===//
577 
578 #define GET_OP_CLASSES
579 #include "circt/Dialect/OM/OM.cpp.inc"
static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path)
Definition: OMOps.cpp:26
static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Definition: OMOps.cpp:48
static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path)
Definition: OMOps.cpp:37
static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path, StringAttr module, StringAttr ref, StringAttr field)
Definition: OMOps.cpp:62
ParseResult parsePath(MLIRContext *context, StringRef spelling, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Parse a target string into a path.
Definition: OMUtils.cpp:182
ParseResult parseBasePath(MLIRContext *context, StringRef spelling, PathAttr &path)
Parse a target string of the form "Foo/bar:Bar/baz" into a base path.
Definition: OMUtils.cpp:177
A module name, and the name of an instance inside that module.
Definition: OMAttributes.h:22
mlir::StringAttr module
Definition: OMAttributes.h:35
mlir::StringAttr instance
Definition: OMAttributes.h:36