CIRCT  19.0.0git
OMOps.cpp
Go to the documentation of this file.
1 //===- OMOps.cpp - Object Model operation definitions ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the Object Model operation definitions.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "circt/Dialect/OM/OMOps.h"
14 #include "circt/Dialect/HW/HWOps.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/ImplicitLocOpBuilder.h"
18 
19 using namespace mlir;
20 using namespace circt::om;
21 
22 //===----------------------------------------------------------------------===//
23 // Path Printers and Parsers
24 //===----------------------------------------------------------------------===//
25 
26 static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path) {
27  auto *context = parser.getContext();
28  auto loc = parser.getCurrentLocation();
29  std::string rawPath;
30  if (parser.parseString(&rawPath))
31  return failure();
32  if (parseBasePath(context, rawPath, path))
33  return parser.emitError(loc, "invalid base path");
34  return success();
35 }
36 
37 static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path) {
38  p << '\"';
39  llvm::interleave(
40  path, p,
41  [&](const PathElement &elt) {
42  p << elt.module.getValue() << '/' << elt.instance.getValue();
43  },
44  ":");
45  p << '\"';
46 }
47 
48 static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path,
49  StringAttr &module, StringAttr &ref,
50  StringAttr &field) {
51 
52  auto *context = parser.getContext();
53  auto loc = parser.getCurrentLocation();
54  std::string rawPath;
55  if (parser.parseString(&rawPath))
56  return failure();
57  if (parsePath(context, rawPath, path, module, ref, field))
58  return parser.emitError(loc, "invalid path");
59  return success();
60 }
61 
62 static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path,
63  StringAttr module, StringAttr ref,
64  StringAttr field) {
65  p << '\"';
66  for (const auto &elt : path)
67  p << elt.module.getValue() << '/' << elt.instance.getValue() << ':';
68  if (!module.getValue().empty())
69  p << module.getValue();
70  if (!ref.getValue().empty())
71  p << '>' << ref.getValue();
72  if (!field.getValue().empty())
73  p << field.getValue();
74  p << '\"';
75 }
76 
77 //===----------------------------------------------------------------------===//
78 // Shared definitions
79 //===----------------------------------------------------------------------===//
80 
81 static ParseResult parseClassLike(OpAsmParser &parser, OperationState &state) {
82  // Parse the Class symbol name.
83  StringAttr symName;
84  if (parser.parseSymbolName(symName, mlir::SymbolTable::getSymbolAttrName(),
85  state.attributes))
86  return failure();
87 
88  // Parse the formal parameters.
89  SmallVector<OpAsmParser::Argument> args;
90  if (parser.parseArgumentList(args, OpAsmParser::Delimiter::Paren,
91  /*allowType=*/true, /*allowAttrs=*/false))
92  return failure();
93 
94  // Parse the optional attribute dictionary.
95  if (failed(parser.parseOptionalAttrDictWithKeyword(state.attributes)))
96  return failure();
97 
98  // Parse the body.
99  Region *region = state.addRegion();
100  if (parser.parseRegion(*region, args))
101  return failure();
102 
103  // If the region was empty, add an empty block so it's still a SizedRegion<1>.
104  if (region->empty())
105  region->emplaceBlock();
106 
107  // Remember the formal parameter names in an attribute.
108  auto argNames = llvm::map_range(args, [&](OpAsmParser::Argument arg) {
109  return StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front());
110  });
111  state.addAttribute(
112  "formalParamNames",
113  ArrayAttr::get(parser.getContext(), SmallVector<Attribute>(argNames)));
114 
115  return success();
116 }
117 
118 static void printClassLike(ClassLike classLike, OpAsmPrinter &printer) {
119  // Print the Class symbol name.
120  printer << " @";
121  printer << classLike.getSymName();
122 
123  // Retrieve the formal parameter names and values.
124  auto argNames = SmallVector<StringRef>(
125  classLike.getFormalParamNames().getAsValueRange<StringAttr>());
126  ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
127 
128  // Print the formal parameters.
129  printer << '(';
130  for (size_t i = 0, e = args.size(); i < e; ++i) {
131  printer << '%' << argNames[i] << ": " << args[i].getType();
132  if (i < e - 1)
133  printer << ", ";
134  }
135  printer << ") ";
136 
137  // Print the optional attribute dictionary.
138  SmallVector<StringRef> elidedAttrs{classLike.getSymNameAttrName(),
139  classLike.getFormalParamNamesAttrName()};
140  printer.printOptionalAttrDictWithKeyword(classLike.getOperation()->getAttrs(),
141  elidedAttrs);
142 
143  // Print the body.
144  printer.printRegion(classLike.getBody(), /*printEntryBlockArgs=*/false,
145  /*printBlockTerminators=*/true);
146 }
147 
148 LogicalResult verifyClassLike(ClassLike classLike) {
149  // Verify the formal parameter names match up with the values.
150  if (classLike.getFormalParamNames().size() !=
151  classLike.getBodyBlock()->getArguments().size()) {
152  auto error = classLike.emitOpError(
153  "formal parameter name list doesn't match formal parameter value list");
154  error.attachNote(classLike.getLoc())
155  << "formal parameter names: " << classLike.getFormalParamNames();
156  error.attachNote(classLike.getLoc())
157  << "formal parameter values: "
158  << classLike.getBodyBlock()->getArguments();
159  return error;
160  }
161 
162  return success();
163 }
164 
165 void getClassLikeAsmBlockArgumentNames(ClassLike classLike, Region &region,
166  OpAsmSetValueNameFn setNameFn) {
167  // Retrieve the formal parameter names and values.
168  auto argNames = SmallVector<StringRef>(
169  classLike.getFormalParamNames().getAsValueRange<StringAttr>());
170  ArrayRef<BlockArgument> args = classLike.getBodyBlock()->getArguments();
171 
172  // Use the formal parameter names as the SSA value names.
173  for (size_t i = 0, e = args.size(); i < e; ++i)
174  setNameFn(args[i], argNames[i]);
175 }
176 
177 //===----------------------------------------------------------------------===//
178 // ClassOp
179 //===----------------------------------------------------------------------===//
180 
181 ParseResult circt::om::ClassOp::parse(OpAsmParser &parser,
182  OperationState &state) {
183  return parseClassLike(parser, state);
184 }
185 
186 void circt::om::ClassOp::build(OpBuilder &odsBuilder, OperationState &odsState,
187  Twine name,
188  ArrayRef<StringRef> formalParamNames) {
189  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
190  odsBuilder.getStrArrayAttr(formalParamNames));
191 }
192 
193 circt::om::ClassOp circt::om::ClassOp::buildSimpleClassOp(
194  OpBuilder &odsBuilder, Location loc, Twine name,
195  ArrayRef<StringRef> formalParamNames, ArrayRef<StringRef> fieldNames,
196  ArrayRef<Type> fieldTypes) {
197  circt::om::ClassOp classOp = odsBuilder.create<circt::om::ClassOp>(
198  loc, odsBuilder.getStringAttr(name),
199  odsBuilder.getStrArrayAttr(formalParamNames));
200  Block *body = &classOp.getRegion().emplaceBlock();
201  auto prevLoc = odsBuilder.saveInsertionPoint();
202  odsBuilder.setInsertionPointToEnd(body);
203  for (auto [name, type] : llvm::zip(fieldNames, fieldTypes))
204  odsBuilder.create<circt::om::ClassFieldOp>(loc, name,
205  body->addArgument(type, loc));
206  odsBuilder.restoreInsertionPoint(prevLoc);
207 
208  return classOp;
209 }
210 
211 void circt::om::ClassOp::build(OpBuilder &odsBuilder, OperationState &odsState,
212  Twine name) {
213  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
214  odsBuilder.getStrArrayAttr({}));
215 }
216 
217 void circt::om::ClassOp::print(OpAsmPrinter &printer) {
218  printClassLike(*this, printer);
219 }
220 
221 LogicalResult circt::om::ClassOp::verify() { return verifyClassLike(*this); }
222 
223 void circt::om::ClassOp::getAsmBlockArgumentNames(
224  Region &region, OpAsmSetValueNameFn setNameFn) {
225  getClassLikeAsmBlockArgumentNames(*this, region, setNameFn);
226 }
227 
228 //===----------------------------------------------------------------------===//
229 // ClassFieldOp
230 //===----------------------------------------------------------------------===//
231 
232 Type circt::om::ClassFieldOp::getType() { return getValue().getType(); }
233 
234 //===----------------------------------------------------------------------===//
235 // ClassExternOp
236 //===----------------------------------------------------------------------===//
237 
238 ParseResult circt::om::ClassExternOp::parse(OpAsmParser &parser,
239  OperationState &state) {
240  return parseClassLike(parser, state);
241 }
242 
243 void circt::om::ClassExternOp::build(OpBuilder &odsBuilder,
244  OperationState &odsState, Twine name) {
245  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
246  odsBuilder.getStrArrayAttr({}));
247 }
248 
249 void circt::om::ClassExternOp::build(OpBuilder &odsBuilder,
250  OperationState &odsState, Twine name,
251  ArrayRef<StringRef> formalParamNames) {
252  return build(odsBuilder, odsState, odsBuilder.getStringAttr(name),
253  odsBuilder.getStrArrayAttr(formalParamNames));
254 }
255 
256 void circt::om::ClassExternOp::print(OpAsmPrinter &printer) {
257  printClassLike(*this, printer);
258 }
259 
260 LogicalResult circt::om::ClassExternOp::verify() {
261  if (failed(verifyClassLike(*this))) {
262  return failure();
263  }
264 
265  // Verify that only external class field declarations are present in the body.
266  for (auto &op : getOps())
267  if (!isa<ClassExternFieldOp>(op))
268  return op.emitOpError("not allowed in external class");
269 
270  return success();
271 }
272 
273 void circt::om::ClassExternOp::getAsmBlockArgumentNames(
274  Region &region, OpAsmSetValueNameFn setNameFn) {
275  getClassLikeAsmBlockArgumentNames(*this, region, setNameFn);
276 }
277 
278 //===----------------------------------------------------------------------===//
279 // ObjectOp
280 //===----------------------------------------------------------------------===//
281 
282 void circt::om::ObjectOp::build(::mlir::OpBuilder &odsBuilder,
283  ::mlir::OperationState &odsState,
284  om::ClassOp classOp,
285  ::mlir::ValueRange actualParams) {
286  return build(odsBuilder, odsState,
287  om::ClassType::get(odsBuilder.getContext(),
288  mlir::FlatSymbolRefAttr::get(classOp)),
289  classOp.getNameAttr(), actualParams);
290 }
291 
292 LogicalResult
293 circt::om::ObjectOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
294  // Verify the result type is the same as the referred-to class.
295  StringAttr resultClassName = getResult().getType().getClassName().getAttr();
296  StringAttr className = getClassNameAttr();
297  if (resultClassName != className)
298  return emitOpError("result type (")
299  << resultClassName << ") does not match referred to class ("
300  << className << ')';
301 
302  // Verify the referred to ClassOp exists.
303  auto classDef = dyn_cast_or_null<ClassLike>(
304  symbolTable.lookupNearestSymbolFrom(*this, className));
305  if (!classDef)
306  return emitOpError("refers to non-existant class (") << className << ')';
307 
308  auto actualTypes = getActualParams().getTypes();
309  auto formalTypes = classDef.getBodyBlock()->getArgumentTypes();
310 
311  // Verify the actual parameter list matches the formal parameter list.
312  if (actualTypes.size() != formalTypes.size()) {
313  auto error = emitOpError(
314  "actual parameter list doesn't match formal parameter list");
315  error.attachNote(classDef.getLoc())
316  << "formal parameters: " << classDef.getBodyBlock()->getArguments();
317  error.attachNote(getLoc()) << "actual parameters: " << getActualParams();
318  return error;
319  }
320 
321  // Verify the actual parameter types match the formal parameter types.
322  for (size_t i = 0, e = actualTypes.size(); i < e; ++i) {
323  if (actualTypes[i] != formalTypes[i]) {
324  return emitOpError("actual parameter type (")
325  << actualTypes[i] << ") doesn't match formal parameter type ("
326  << formalTypes[i] << ')';
327  }
328  }
329 
330  return success();
331 }
332 
333 //===----------------------------------------------------------------------===//
334 // ConstantOp
335 //===----------------------------------------------------------------------===//
336 
337 void circt::om::ConstantOp::build(::mlir::OpBuilder &odsBuilder,
338  ::mlir::OperationState &odsState,
339  ::mlir::TypedAttr constVal) {
340  return build(odsBuilder, odsState, constVal.getType(), constVal);
341 }
342 
343 OpFoldResult circt::om::ConstantOp::fold(FoldAdaptor adaptor) {
344  assert(adaptor.getOperands().empty() && "constant has no operands");
345  return getValueAttr();
346 }
347 
348 //===----------------------------------------------------------------------===//
349 // ListCreateOp
350 //===----------------------------------------------------------------------===//
351 
352 void circt::om::ListCreateOp::print(OpAsmPrinter &p) {
353  p << " ";
354  p.printOperands(getInputs());
355  p.printOptionalAttrDict((*this)->getAttrs());
356  p << " : " << getType().getElementType();
357 }
358 
359 ParseResult circt::om::ListCreateOp::parse(OpAsmParser &parser,
360  OperationState &result) {
361  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
362  Type elemType;
363 
364  if (parser.parseOperandList(operands) ||
365  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
366  parser.parseType(elemType))
367  return failure();
368  result.addTypes({circt::om::ListType::get(elemType)});
369 
370  for (auto operand : operands)
371  if (parser.resolveOperand(operand, elemType, result.operands))
372  return failure();
373  return success();
374 }
375 
376 //===----------------------------------------------------------------------===//
377 // TupleCreateOp
378 //===----------------------------------------------------------------------===//
379 
380 LogicalResult TupleCreateOp::inferReturnTypes(
381  MLIRContext *context, std::optional<Location> location, ValueRange operands,
382  DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
383  llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
384  ::llvm::SmallVector<Type> types;
385  for (auto op : operands)
386  types.push_back(op.getType());
387  inferredReturnTypes.push_back(TupleType::get(context, types));
388  return success();
389 }
390 
391 //===----------------------------------------------------------------------===//
392 // TupleGetOp
393 //===----------------------------------------------------------------------===//
394 
395 LogicalResult TupleGetOp::inferReturnTypes(
396  MLIRContext *context, std::optional<Location> location, ValueRange operands,
397  DictionaryAttr attributes, OpaqueProperties, RegionRange regions,
398  llvm::SmallVectorImpl<Type> &inferredReturnTypes) {
399  auto idx = attributes.getAs<mlir::IntegerAttr>("index");
400  if (operands.empty() || !idx)
401  return failure();
402 
403  auto tupleTypes = cast<TupleType>(operands[0].getType()).getTypes();
404  if (tupleTypes.size() <= idx.getValue().getLimitedValue()) {
405  if (location)
406  mlir::emitError(*location,
407  "tuple index out-of-bounds, must be less than ")
408  << tupleTypes.size() << " but got "
409  << idx.getValue().getLimitedValue();
410  return failure();
411  }
412 
413  inferredReturnTypes.push_back(tupleTypes[idx.getValue().getLimitedValue()]);
414  return success();
415 }
416 
417 //===----------------------------------------------------------------------===//
418 // MapCreateOp
419 //===----------------------------------------------------------------------===//
420 
421 void circt::om::MapCreateOp::print(OpAsmPrinter &p) {
422  p << " ";
423  p.printOperands(getInputs());
424  p.printOptionalAttrDict((*this)->getAttrs());
425  p << " : " << cast<circt::om::MapType>(getType()).getKeyType() << ", "
426  << cast<circt::om::MapType>(getType()).getValueType();
427 }
428 
429 ParseResult circt::om::MapCreateOp::parse(OpAsmParser &parser,
430  OperationState &result) {
431  llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
432  Type elementType, valueType;
433 
434  if (parser.parseOperandList(operands) ||
435  parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
436  parser.parseType(elementType) || parser.parseComma() ||
437  parser.parseType(valueType))
438  return failure();
439  result.addTypes({circt::om::MapType::get(elementType, valueType)});
440  auto operandType =
441  mlir::TupleType::get(valueType.getContext(), {elementType, valueType});
442 
443  for (auto operand : operands)
444  if (parser.resolveOperand(operand, operandType, result.operands))
445  return failure();
446  return success();
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // BasePathCreateOp
451 //===----------------------------------------------------------------------===//
452 
453 LogicalResult
454 BasePathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
455  auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
456  *this, getTargetAttr());
457  if (!hierPath)
458  return emitOpError("invalid symbol reference");
459  return success();
460 }
461 
462 //===----------------------------------------------------------------------===//
463 // PathCreateOp
464 //===----------------------------------------------------------------------===//
465 
466 LogicalResult
467 PathCreateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
468  auto hierPath = symbolTable.lookupNearestSymbolFrom<hw::HierPathOp>(
469  *this, getTargetAttr());
470  if (!hierPath)
471  return emitOpError("invalid symbol reference");
472  return success();
473 }
474 
475 //===----------------------------------------------------------------------===//
476 // IntegerAddOp
477 //===----------------------------------------------------------------------===//
478 
479 FailureOr<llvm::APSInt>
480 IntegerAddOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
481  const llvm::APSInt &rhs) {
482  return success(lhs + rhs);
483 }
484 
485 //===----------------------------------------------------------------------===//
486 // IntegerMulOp
487 //===----------------------------------------------------------------------===//
488 
489 FailureOr<llvm::APSInt>
490 IntegerMulOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
491  const llvm::APSInt &rhs) {
492  return success(lhs * rhs);
493 }
494 
495 //===----------------------------------------------------------------------===//
496 // IntegerShrOp
497 //===----------------------------------------------------------------------===//
498 
499 FailureOr<llvm::APSInt>
500 IntegerShrOp::evaluateIntegerOperation(const llvm::APSInt &lhs,
501  const llvm::APSInt &rhs) {
502  // Check non-negative constraint from operation semantics.
503  if (!rhs.isNonNegative())
504  return emitOpError("shift amount must be non-negative");
505  // Check size constraint from implementation detail of using getExtValue.
506  if (!rhs.isRepresentableByInt64())
507  return emitOpError("shift amount must be representable in 64 bits");
508  return success(lhs >> rhs.getExtValue());
509 }
510 
511 //===----------------------------------------------------------------------===//
512 // TableGen generated logic.
513 //===----------------------------------------------------------------------===//
514 
515 #define GET_OP_CLASSES
516 #include "circt/Dialect/OM/OM.cpp.inc"
static ParseResult parseBasePathString(OpAsmParser &parser, PathAttr &path)
Definition: OMOps.cpp:26
static ParseResult parsePathString(OpAsmParser &parser, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Definition: OMOps.cpp:48
static void printBasePathString(OpAsmPrinter &p, Operation *op, PathAttr path)
Definition: OMOps.cpp:37
static void printPathString(OpAsmPrinter &p, Operation *op, PathAttr path, StringAttr module, StringAttr ref, StringAttr field)
Definition: OMOps.cpp:62
ParseResult parsePath(MLIRContext *context, StringRef spelling, PathAttr &path, StringAttr &module, StringAttr &ref, StringAttr &field)
Parse a target string in to a path.
Definition: OMUtils.cpp:182
ParseResult parseBasePath(MLIRContext *context, StringRef spelling, PathAttr &path)
Parse a target string of the form "Foo/bar:Bar/baz" in to a base path.
Definition: OMUtils.cpp:177
A module name, and the name of an instance inside that module.
Definition: OMAttributes.h:22
mlir::StringAttr module
Definition: OMAttributes.h:35
mlir::StringAttr instance
Definition: OMAttributes.h:36