CIRCT  20.0.0git
SystemCOps.cpp
Go to the documentation of this file.
1 //===- SystemCOps.cpp - Implement the SystemC operations ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SystemC ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Interfaces/FunctionImplementation.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace circt;
25 using namespace circt::systemc;
26 
27 //===----------------------------------------------------------------------===//
28 // Helpers
29 //===----------------------------------------------------------------------===//
30 
31 static LogicalResult verifyUniqueNamesInRegion(
32  Operation *operation, ArrayAttr argNames,
33  std::function<void(mlir::InFlightDiagnostic &)> attachNote) {
34  DenseMap<StringRef, BlockArgument> portNames;
35  DenseMap<StringRef, Operation *> memberNames;
36  DenseMap<StringRef, Operation *> localNames;
37 
38  if (operation->getNumRegions() != 1)
39  return operation->emitError("required to have exactly one region");
40 
41  bool portsVerified = true;
42 
43  for (auto arg : llvm::zip(argNames, operation->getRegion(0).getArguments())) {
44  StringRef argName = cast<StringAttr>(std::get<0>(arg)).getValue();
45  BlockArgument argValue = std::get<1>(arg);
46 
47  if (portNames.count(argName)) {
48  auto diag = mlir::emitError(argValue.getLoc(), "redefines name '")
49  << argName << "'";
50  diag.attachNote(portNames[argName].getLoc())
51  << "'" << argName << "' first defined here";
52  attachNote(diag);
53  portsVerified = false;
54  continue;
55  }
56 
57  portNames.insert({argName, argValue});
58  }
59 
60  WalkResult result =
61  operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
62  if (isa<SCModuleOp>(op->getParentOp()))
63  localNames.clear();
64 
65  if (auto nameDeclOp = dyn_cast<SystemCNameDeclOpInterface>(op)) {
66  StringRef name = nameDeclOp.getName();
67 
68  auto reportNameRedefinition = [&](Location firstLoc) -> WalkResult {
69  auto diag = mlir::emitError(op->getLoc(), "redefines name '")
70  << name << "'";
71  diag.attachNote(firstLoc) << "'" << name << "' first defined here";
72  attachNote(diag);
73  return WalkResult::interrupt();
74  };
75 
76  if (portNames.count(name))
77  return reportNameRedefinition(portNames[name].getLoc());
78  if (memberNames.count(name))
79  return reportNameRedefinition(memberNames[name]->getLoc());
80  if (localNames.count(name))
81  return reportNameRedefinition(localNames[name]->getLoc());
82 
83  if (isa<SCModuleOp>(op->getParentOp()))
84  memberNames.insert({name, op});
85  else
86  localNames.insert({name, op});
87  }
88 
89  return WalkResult::advance();
90  });
91 
92  if (result.wasInterrupted() || !portsVerified)
93  return failure();
94 
95  return success();
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // SCModuleOp
100 //===----------------------------------------------------------------------===//
101 
103  return TypeSwitch<Type, hw::ModulePort::Direction>(type)
104  .Case<InOutType>([](auto ty) { return hw::ModulePort::Direction::InOut; })
105  .Case<InputType>([](auto ty) { return hw::ModulePort::Direction::Input; })
106  .Case<OutputType>(
107  [](auto ty) { return hw::ModulePort::Direction::Output; });
108 }
109 
110 SCModuleOp::PortDirectionRange
111 SCModuleOp::getPortsOfDirection(hw::ModulePort::Direction direction) {
112  std::function<bool(const BlockArgument &)> predicateFn =
113  [&](const BlockArgument &arg) -> bool {
114  return getDirection(arg.getType()) == direction;
115  };
116  return llvm::make_filter_range(getArguments(), predicateFn);
117 }
118 
119 SmallVector<::circt::hw::PortInfo> SCModuleOp::getPortList() {
120  SmallVector<hw::PortInfo> ports;
121  for (int i = 0, e = getNumArguments(); i < e; ++i) {
122  hw::PortInfo info;
123  info.name = cast<StringAttr>(getPortNames()[i]);
124  info.type = getSignalBaseType(getArgument(i).getType());
125  info.dir = getDirection(info.type);
126  ports.push_back(info);
127  }
128  return ports;
129 }
130 
131 mlir::Region *SCModuleOp::getCallableRegion() { return &getBody(); }
132 
133 StringRef SCModuleOp::getModuleName() {
134  return (*this)
135  ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
136  .getValue();
137 }
138 
139 ParseResult SCModuleOp::parse(OpAsmParser &parser, OperationState &result) {
140 
141  // Parse the visibility attribute.
142  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
143 
144  // Parse the name as a symbol.
145  StringAttr moduleName;
146  if (parser.parseSymbolName(moduleName, SymbolTable::getSymbolAttrName(),
147  result.attributes))
148  return failure();
149 
150  // Parse the function signature.
151  bool isVariadic = false;
152  SmallVector<OpAsmParser::Argument, 4> entryArgs;
153  SmallVector<Attribute> argNames;
154  SmallVector<Attribute> argLocs;
155  SmallVector<Attribute> resultNames;
156  SmallVector<DictionaryAttr> resultAttrs;
157  SmallVector<Attribute> resultLocs;
158  TypeAttr functionType;
160  parser, isVariadic, entryArgs, argNames, argLocs, resultNames,
161  resultAttrs, resultLocs, functionType)))
162  return failure();
163 
164  // Parse the attribute dict.
165  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
166  return failure();
167 
168  result.addAttribute("portNames",
169  ArrayAttr::get(parser.getContext(), argNames));
170 
171  result.addAttribute(SCModuleOp::getFunctionTypeAttrName(result.name),
172  functionType);
173 
174  mlir::function_interface_impl::addArgAndResultAttrs(
175  parser.getBuilder(), result, entryArgs, resultAttrs,
176  SCModuleOp::getArgAttrsAttrName(result.name),
177  SCModuleOp::getResAttrsAttrName(result.name));
178 
179  auto &body = *result.addRegion();
180  if (parser.parseRegion(body, entryArgs))
181  return failure();
182  if (body.empty())
183  body.push_back(std::make_unique<Block>().release());
184 
185  return success();
186 }
187 
188 void SCModuleOp::print(OpAsmPrinter &p) {
189  p << ' ';
190 
191  // Print the visibility of the module.
192  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
193  if (auto visibility =
194  getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
195  p << visibility.getValue() << ' ';
196 
197  p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
198  p << ' ';
199 
200  bool needArgNamesAttr = false;
202  p, *this, getFunctionType().getInputs(), false, {}, needArgNamesAttr);
203  mlir::function_interface_impl::printFunctionAttributes(
204  p, *this,
205  {"portNames", getFunctionTypeAttrName(), getArgAttrsAttrName(),
206  getResAttrsAttrName()});
207 
208  p << ' ';
209  p.printRegion(getBody(), false, false);
210 }
211 
212 /// Returns the argument types of this function.
213 ArrayRef<Type> SCModuleOp::getArgumentTypes() {
214  return getFunctionType().getInputs();
215 }
216 
217 /// Returns the result types of this function.
218 ArrayRef<Type> SCModuleOp::getResultTypes() {
219  return getFunctionType().getResults();
220 }
221 
222 static Type wrapPortType(Type type, hw::ModulePort::Direction direction) {
223  if (auto inoutTy = dyn_cast<hw::InOutType>(type))
224  type = inoutTy.getElementType();
225 
226  switch (direction) {
228  return InOutType::get(type);
230  return InputType::get(type);
232  return OutputType::get(type);
233  }
234  llvm_unreachable("Impossible port direction");
235 }
236 
237 void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238  StringAttr name, ArrayAttr portNames,
239  ArrayRef<Type> portTypes,
240  ArrayRef<NamedAttribute> attributes) {
241  odsState.addAttribute(getPortNamesAttrName(odsState.name), portNames);
242  Region *region = odsState.addRegion();
243 
244  auto moduleType = odsBuilder.getFunctionType(portTypes, {});
245  odsState.addAttribute(getFunctionTypeAttrName(odsState.name),
246  TypeAttr::get(moduleType));
247 
248  odsState.addAttribute(SymbolTable::getSymbolAttrName(), name);
249  region->push_back(new Block);
250  region->addArguments(
251  portTypes,
252  SmallVector<Location>(portTypes.size(), odsBuilder.getUnknownLoc()));
253  odsState.addAttributes(attributes);
254 }
255 
256 void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
257  StringAttr name, ArrayRef<hw::PortInfo> ports,
258  ArrayRef<NamedAttribute> attributes) {
259  MLIRContext *ctxt = odsBuilder.getContext();
260  SmallVector<Attribute> portNames;
261  SmallVector<Type> portTypes;
262  for (auto port : ports) {
263  portNames.push_back(StringAttr::get(ctxt, port.getName()));
264  portTypes.push_back(wrapPortType(port.type, port.dir));
265  }
266  build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
267 }
268 
269 void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
270  StringAttr name, const hw::ModulePortInfo &ports,
271  ArrayRef<NamedAttribute> attributes) {
272  MLIRContext *ctxt = odsBuilder.getContext();
273  SmallVector<Attribute> portNames;
274  SmallVector<Type> portTypes;
275  for (auto port : ports) {
276  portNames.push_back(StringAttr::get(ctxt, port.getName()));
277  portTypes.push_back(wrapPortType(port.type, port.dir));
278  }
279  build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
280 }
281 
282 void SCModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
283  mlir::OpAsmSetValueNameFn setNameFn) {
284  if (region.empty())
285  return;
286 
287  ArrayAttr portNames = getPortNames();
288  for (size_t i = 0, e = getNumArguments(); i != e; ++i) {
289  auto str = cast<StringAttr>(portNames[i]).getValue();
290  setNameFn(getArgument(i), str);
291  }
292 }
293 
294 LogicalResult SCModuleOp::verify() {
295  if (getFunctionType().getNumResults() != 0)
296  return emitOpError(
297  "incorrect number of function results (always has to be 0)");
298  if (getPortNames().size() != getFunctionType().getNumInputs())
299  return emitOpError("incorrect number of port names");
300 
301  for (auto arg : getArguments()) {
302  if (!hw::type_isa<InputType, OutputType, InOutType>(arg.getType()))
303  return mlir::emitError(
304  arg.getLoc(),
305  "module port must be of type 'sc_in', 'sc_out', or 'sc_inout'");
306  }
307 
308  for (auto portName : getPortNames()) {
309  if (cast<StringAttr>(portName).getValue().empty())
310  return emitOpError("port name must not be empty");
311  }
312 
313  return success();
314 }
315 
316 LogicalResult SCModuleOp::verifyRegions() {
317  auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
318  diag.attachNote(getLoc()) << "in module '@" << getModuleName() << "'";
319  };
320  return verifyUniqueNamesInRegion(getOperation(), getPortNames(), attachNote);
321 }
322 
323 CtorOp SCModuleOp::getOrCreateCtor() {
324  CtorOp ctor;
325  getBody().walk([&](Operation *op) {
326  if ((ctor = dyn_cast<CtorOp>(op)))
327  return WalkResult::interrupt();
328 
329  return WalkResult::skip();
330  });
331 
332  if (ctor)
333  return ctor;
334 
335  return OpBuilder(getBody()).create<CtorOp>(getLoc());
336 }
337 
338 DestructorOp SCModuleOp::getOrCreateDestructor() {
339  DestructorOp destructor;
340  getBody().walk([&](Operation *op) {
341  if ((destructor = dyn_cast<DestructorOp>(op)))
342  return WalkResult::interrupt();
343 
344  return WalkResult::skip();
345  });
346 
347  if (destructor)
348  return destructor;
349 
350  return OpBuilder::atBlockEnd(getBodyBlock()).create<DestructorOp>(getLoc());
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // SignalOp
355 //===----------------------------------------------------------------------===//
356 
358  setNameFn(getSignal(), getName());
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // ConvertOp
363 //===----------------------------------------------------------------------===//
364 
/// Folds `systemc.convert` operations.
///
/// * A conversion whose input type already equals its result type folds to
///   its input.
/// * A chain `convert(convert(x))` whose outer result type equals the type of
///   `x` folds to `x` when the checks below accept the intermediate type as
///   able to hold the input type's values.
///
/// Returns a null OpFoldResult when no fold applies.
OpFoldResult ConvertOp::fold(FoldAdaptor) {
  // No-op conversion: same type in and out.
  if (getInput().getType() == getResult().getType())
    return getInput();

  if (auto other = getInput().getDefiningOp<ConvertOp>()) {
    // `inputType` enters the inner conversion; `intermediateType` is the type
    // between the inner and the outer conversion.
    Type inputType = other.getInput().getType();
    Type intermediateType = getInput().getType();

    // Only round trips back to the original type are folded.
    if (inputType != getResult().getType())
      return {};

    // Either both the input and intermediate types are signed or both are
    // unsigned.
    bool inputSigned = isa<SignedType, IntBaseType>(inputType);
    bool intermediateSigned = isa<SignedType, IntBaseType>(intermediateType);
    if (inputSigned ^ intermediateSigned)
      return {};

    // Converting 4-valued to 2-valued and back may lose information.
    if (isa<LogicVectorBaseType, LogicType>(inputType) &&
        !isa<LogicVectorBaseType, LogicType>(intermediateType))
      return {};

    // The bit-widths are only dereferenced after a presence check, so
    // getBitWidth presumably yields an empty optional for types without a
    // statically known width -- TODO confirm against its definition.
    auto inputBw = getBitWidth(inputType);
    auto intermediateBw = getBitWidth(intermediateType);

    // Input width unknown, intermediate width known.
    if (!inputBw && intermediateBw) {
      if (isa<IntBaseType, UIntBaseType>(inputType) && *intermediateBw >= 64)
        return other.getInput();
      // We cannot support input types of signed, unsigned, and vector types
      // since they have no upper bound for the bit-width.
    }

    // Intermediate width unknown.
    if (!intermediateBw) {
      if (isa<BitVectorBaseType, LogicVectorBaseType>(intermediateType))
        return other.getInput();

      if (!inputBw && isa<IntBaseType, UIntBaseType>(inputType) &&
          isa<SignedType, UnsignedType>(intermediateType))
        return other.getInput();

      if (inputBw && *inputBw <= 64 &&
          isa<IntBaseType, UIntBaseType, SignedType, UnsignedType>(
              intermediateType))
        return other.getInput();

      // We have to be careful with the signed and unsigned types as they often
      // have a max bit-width defined (that can be customized) and thus folding
      // here could change the behavior.
    }

    // Both widths known: fold when the intermediate is at least as wide.
    if (inputBw && intermediateBw && *inputBw <= *intermediateBw)
      return other.getInput();
  }

  return {};
}
422 
423 //===----------------------------------------------------------------------===//
424 // CtorOp
425 //===----------------------------------------------------------------------===//
426 
427 LogicalResult CtorOp::verify() {
428  if (getBody().getNumArguments() != 0)
429  return emitOpError("must not have any arguments");
430 
431  return success();
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // SCFuncOp
436 //===----------------------------------------------------------------------===//
437 
439  setNameFn(getHandle(), getName());
440 }
441 
442 LogicalResult SCFuncOp::verify() {
443  if (getBody().getNumArguments() != 0)
444  return emitOpError("must not have any arguments");
445 
446  return success();
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // InstanceDeclOp
451 //===----------------------------------------------------------------------===//
452 
454  setNameFn(getInstanceHandle(), getName());
455 }
456 
457 StringRef InstanceDeclOp::getInstanceName() { return getName(); }
458 StringAttr InstanceDeclOp::getInstanceNameAttr() { return getNameAttr(); }
459 
460 Operation *
461 InstanceDeclOp::getReferencedModuleCached(const hw::HWSymbolCache *cache) {
462  if (cache)
463  if (auto *result = cache->getDefinition(getModuleNameAttr()))
464  return result;
465 
466  auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
467  return topLevelModuleOp.lookupSymbol(getModuleName());
468 }
469 
470 LogicalResult
471 InstanceDeclOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
472  auto *module =
473  symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
474  if (module == nullptr)
475  return emitError("cannot find module definition '")
476  << getModuleName() << "'";
477 
478  auto emitError = [&](const std::function<void(InFlightDiagnostic & diag)> &fn)
479  -> LogicalResult {
480  auto diag = emitOpError();
481  fn(diag);
482  diag.attachNote(module->getLoc()) << "module declared here";
483  return failure();
484  };
485 
486  // It must be a systemc module.
487  if (!isa<SCModuleOp>(module))
488  return emitError([&](auto &diag) {
489  diag << "symbol reference '" << getModuleName()
490  << "' isn't a systemc module";
491  });
492 
493  auto scModule = cast<SCModuleOp>(module);
494 
495  // Check that the module name of the symbol and instance type match.
496  if (scModule.getModuleName() != getInstanceType().getModuleName())
497  return emitError([&](auto &diag) {
498  diag << "module names must match; expected '" << scModule.getModuleName()
499  << "' but got '" << getInstanceType().getModuleName().getValue()
500  << "'";
501  });
502 
503  // Check that port types and names are consistent with the referenced module.
504  ArrayRef<ModuleType::PortInfo> ports = getInstanceType().getPorts();
505  ArrayAttr modArgNames = scModule.getPortNames();
506  auto numPorts = ports.size();
507  auto expectedPortTypes = scModule.getArgumentTypes();
508 
509  if (expectedPortTypes.size() != numPorts)
510  return emitError([&](auto &diag) {
511  diag << "has a wrong number of ports; expected "
512  << expectedPortTypes.size() << " but got " << numPorts;
513  });
514 
515  for (size_t i = 0; i != numPorts; ++i) {
516  if (ports[i].type != expectedPortTypes[i]) {
517  return emitError([&](auto &diag) {
518  diag << "port type #" << i << " must be " << expectedPortTypes[i]
519  << ", but got " << ports[i].type;
520  });
521  }
522 
523  if (ports[i].name != modArgNames[i])
524  return emitError([&](auto &diag) {
525  diag << "port name #" << i << " must be " << modArgNames[i]
526  << ", but got " << ports[i].name;
527  });
528  }
529 
530  return success();
531 }
532 
533 SmallVector<hw::PortInfo> InstanceDeclOp::getPortList() {
534  return cast<hw::PortList>(SymbolTable::lookupNearestSymbolFrom(
535  getOperation(), getReferencedModuleNameAttr()))
536  .getPortList();
537 }
538 
539 //===----------------------------------------------------------------------===//
540 // DestructorOp
541 //===----------------------------------------------------------------------===//
542 
543 LogicalResult DestructorOp::verify() {
544  if (getBody().getNumArguments() != 0)
545  return emitOpError("must not have any arguments");
546 
547  return success();
548 }
549 
550 //===----------------------------------------------------------------------===//
551 // BindPortOp
552 //===----------------------------------------------------------------------===//
553 
554 ParseResult BindPortOp::parse(OpAsmParser &parser, OperationState &result) {
555  OpAsmParser::UnresolvedOperand instance, channel;
556  std::string portName;
557  if (parser.parseOperand(instance) || parser.parseLSquare() ||
558  parser.parseString(&portName))
559  return failure();
560 
561  auto portNameLoc = parser.getCurrentLocation();
562 
563  if (parser.parseRSquare() || parser.parseKeyword("to") ||
564  parser.parseOperand(channel))
565  return failure();
566 
567  if (parser.parseOptionalAttrDict(result.attributes))
568  return failure();
569 
570  auto typeListLoc = parser.getCurrentLocation();
571  SmallVector<Type> types;
572  if (parser.parseColonTypeList(types))
573  return failure();
574 
575  if (types.size() != 2)
576  return parser.emitError(typeListLoc,
577  "expected a list of exactly 2 types, but got ")
578  << types.size();
579 
580  if (parser.resolveOperand(instance, types[0], result.operands))
581  return failure();
582  if (parser.resolveOperand(channel, types[1], result.operands))
583  return failure();
584 
585  if (auto moduleType = dyn_cast<ModuleType>(types[0])) {
586  auto ports = moduleType.getPorts();
587  uint64_t index = 0;
588  for (auto port : ports) {
589  if (port.name == portName)
590  break;
591  index++;
592  }
593  if (index >= ports.size())
594  return parser.emitError(portNameLoc, "port name \"")
595  << portName << "\" not found in module";
596 
597  result.addAttribute("portId", parser.getBuilder().getIndexAttr(index));
598 
599  return success();
600  }
601 
602  return failure();
603 }
604 
605 void BindPortOp::print(OpAsmPrinter &p) {
606  p << " " << getInstance() << "["
607  << cast<ModuleType>(getInstance().getType())
608  .getPorts()[getPortId().getZExtValue()]
609  .name
610  << "] to " << getChannel();
611  p.printOptionalAttrDict((*this)->getAttrs(), {"portId"});
612  p << " : " << getInstance().getType() << ", " << getChannel().getType();
613 }
614 
615 LogicalResult BindPortOp::verify() {
616  auto ports = cast<ModuleType>(getInstance().getType()).getPorts();
617  if (getPortId().getZExtValue() >= ports.size())
618  return emitOpError("port #")
619  << getPortId().getZExtValue() << " does not exist, there are only "
620  << ports.size() << " ports";
621 
622  // Verify that the base types match.
623  Type portType = ports[getPortId().getZExtValue()].type;
624  Type channelType = getChannel().getType();
625  if (getSignalBaseType(portType) != getSignalBaseType(channelType))
626  return emitOpError() << portType << " port cannot be bound to "
627  << channelType << " channel due to base type mismatch";
628 
629  // Verify that the port/channel directions are valid.
630  if ((isa<InputType>(portType) && isa<OutputType>(channelType)) ||
631  (isa<OutputType>(portType) && isa<InputType>(channelType)))
632  return emitOpError() << portType << " port cannot be bound to "
633  << channelType
634  << " channel due to port direction mismatch";
635 
636  return success();
637 }
638 
639 StringRef BindPortOp::getPortName() {
640  return cast<ModuleType>(getInstance().getType())
641  .getPorts()[getPortId().getZExtValue()]
642  .name.getValue();
643 }
644 
645 //===----------------------------------------------------------------------===//
646 // SensitiveOp
647 //===----------------------------------------------------------------------===//
648 
649 LogicalResult SensitiveOp::canonicalize(SensitiveOp op,
650  PatternRewriter &rewriter) {
651  if (op.getSensitivities().empty()) {
652  rewriter.eraseOp(op);
653  return success();
654  }
655 
656  return failure();
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // VariableOp
661 //===----------------------------------------------------------------------===//
662 
664  setNameFn(getVariable(), getName());
665 }
666 
667 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
668  StringAttr nameAttr;
669  if (parseImplicitSSAName(parser, nameAttr))
670  return failure();
671  result.addAttribute("name", nameAttr);
672 
673  OpAsmParser::UnresolvedOperand init;
674  auto initResult = parser.parseOptionalOperand(init);
675 
676  if (parser.parseOptionalAttrDict(result.attributes))
677  return failure();
678 
679  Type variableType;
680  if (parser.parseColonType(variableType))
681  return failure();
682 
683  if (initResult.has_value()) {
684  if (parser.resolveOperand(init, variableType, result.operands))
685  return failure();
686  }
687  result.addTypes({variableType});
688 
689  return success();
690 }
691 
692 void VariableOp::print(::mlir::OpAsmPrinter &p) {
693  p << " ";
694 
695  if (getInit())
696  p << getInit() << " ";
697 
698  p.printOptionalAttrDict(getOperation()->getAttrs(), {"name"});
699  p << ": " << getVariable().getType();
700 }
701 
702 LogicalResult VariableOp::verify() {
703  if (getInit() && getInit().getType() != getVariable().getType())
704  return emitOpError(
705  "'init' and 'variable' must have the same type, but got ")
706  << getInit().getType() << " and " << getVariable().getType();
707 
708  return success();
709 }
710 
711 //===----------------------------------------------------------------------===//
712 // InteropVerilatedOp
713 //===----------------------------------------------------------------------===//
714 
715 /// Create a instance that refers to a known module.
716 void InteropVerilatedOp::build(OpBuilder &odsBuilder, OperationState &odsState,
717  Operation *module, StringAttr name,
718  ArrayRef<Value> inputs) {
719  auto mod = cast<hw::HWModuleLike>(module);
720  auto argNames = odsBuilder.getArrayAttr(mod.getInputNames());
721  auto resultNames = odsBuilder.getArrayAttr(mod.getOutputNames());
722  build(odsBuilder, odsState, mod.getHWModuleType().getOutputTypes(), name,
723  FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), argNames,
724  resultNames, inputs);
725 }
726 
727 LogicalResult
728 InteropVerilatedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
730  *this, getModuleNameAttr(), getInputs(), getResultTypes(),
731  getInputNames(), getResultNames(), ArrayAttr(), symbolTable);
732 }
733 
734 /// Suggest a name for each result value based on the saved result names
735 /// attribute.
738  getResultNames(), getResults());
739 }
740 
741 //===----------------------------------------------------------------------===//
742 // CallOp
743 //
744 // TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored such that
// we can re-use its implementation here.
747 //===----------------------------------------------------------------------===//
748 
749 // FIXME: This is an exact copy from upstream
750 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
751  // Check that the callee attribute was specified.
752  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
753  if (!fnAttr)
754  return emitOpError("requires a 'callee' symbol reference attribute");
755  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
756  if (!fn)
757  return emitOpError() << "'" << fnAttr.getValue()
758  << "' does not reference a valid function";
759 
760  // Verify that the operand and result types match the callee.
761  auto fnType = fn.getFunctionType();
762  if (fnType.getNumInputs() != getNumOperands())
763  return emitOpError("incorrect number of operands for callee");
764 
765  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
766  if (getOperand(i).getType() != fnType.getInput(i))
767  return emitOpError("operand type mismatch: expected operand type ")
768  << fnType.getInput(i) << ", but provided "
769  << getOperand(i).getType() << " for operand number " << i;
770 
771  if (fnType.getNumResults() != getNumResults())
772  return emitOpError("incorrect number of results for callee");
773 
774  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
775  if (getResult(i).getType() != fnType.getResult(i)) {
776  auto diag = emitOpError("result type mismatch at index ") << i;
777  diag.attachNote() << " op result types: " << getResultTypes();
778  diag.attachNote() << "function result types: " << fnType.getResults();
779  return diag;
780  }
781 
782  return success();
783 }
784 
785 FunctionType CallOp::getCalleeType() {
786  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
787 }
788 
789 // This verifier was added compared to the upstream implementation.
790 LogicalResult CallOp::verify() {
791  if (getNumResults() > 1)
792  return emitOpError(
793  "incorrect number of function results (always has to be 0 or 1)");
794 
795  return success();
796 }
797 
798 //===----------------------------------------------------------------------===//
799 // CallIndirectOp
800 //===----------------------------------------------------------------------===//
801 
802 // This verifier was added compared to the upstream implementation.
803 LogicalResult CallIndirectOp::verify() {
804  if (getNumResults() > 1)
805  return emitOpError(
806  "incorrect number of function results (always has to be 0 or 1)");
807 
808  return success();
809 }
810 
811 //===----------------------------------------------------------------------===//
812 // FuncOp
813 //
814 // TODO: Most of the implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored such that
// we can re-use its implementation here.
817 //===----------------------------------------------------------------------===//
818 
819 // Note that the create and build operations are taken from upstream, but the
820 // argNames argument was added.
821 FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
822  FunctionType type, ArrayRef<NamedAttribute> attrs) {
823  OpBuilder builder(location->getContext());
824  OperationState state(location, getOperationName());
825  FuncOp::build(builder, state, name, argNames, type, attrs);
826  return cast<FuncOp>(Operation::create(state));
827 }
828 
829 FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
830  FunctionType type, Operation::dialect_attr_range attrs) {
831  SmallVector<NamedAttribute, 8> attrRef(attrs);
832  return create(location, name, argNames, type, ArrayRef(attrRef));
833 }
834 
835 FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
836  FunctionType type, ArrayRef<NamedAttribute> attrs,
837  ArrayRef<DictionaryAttr> argAttrs) {
838  FuncOp func = create(location, name, argNames, type, attrs);
839  func.setAllArgAttrs(argAttrs);
840  return func;
841 }
842 
843 void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
844  StringRef name, ArrayAttr argNames, FunctionType type,
845  ArrayRef<NamedAttribute> attrs,
846  ArrayRef<DictionaryAttr> argAttrs) {
847  odsState.addAttribute(getArgNamesAttrName(odsState.name), argNames);
848  odsState.addAttribute(SymbolTable::getSymbolAttrName(),
849  odsBuilder.getStringAttr(name));
850  odsState.addAttribute(FuncOp::getFunctionTypeAttrName(odsState.name),
851  TypeAttr::get(type));
852  odsState.attributes.append(attrs.begin(), attrs.end());
853  odsState.addRegion();
854 
855  if (argAttrs.empty())
856  return;
857  assert(type.getNumInputs() == argAttrs.size());
858  mlir::function_interface_impl::addArgAndResultAttrs(
859  odsBuilder, odsState, argAttrs,
860  /*resultAttrs=*/std::nullopt, FuncOp::getArgAttrsAttrName(odsState.name),
861  FuncOp::getResAttrsAttrName(odsState.name));
862 }
863 
864 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
865  auto buildFuncType =
866  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
867  mlir::function_interface_impl::VariadicFlag,
868  std::string &) { return builder.getFunctionType(argTypes, results); };
869 
870  // This was added specifically for our implementation, upstream does not have
871  // this feature.
872  if (succeeded(parser.parseOptionalKeyword("externC")))
873  result.addAttribute(getExternCAttrName(result.name),
874  UnitAttr::get(result.getContext()));
875 
876  // FIXME: below is an exact copy of the
877  // mlir::function_interface_impl::parseFunctionOp implementation, this was
878  // needed because we need to access the SSA names of the arguments.
879  SmallVector<OpAsmParser::Argument> entryArgs;
880  SmallVector<DictionaryAttr> resultAttrs;
881  SmallVector<Type> resultTypes;
882  auto &builder = parser.getBuilder();
883 
884  // Parse visibility.
885  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
886 
887  // Parse the name as a symbol.
888  StringAttr nameAttr;
889  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
890  result.attributes))
891  return failure();
892 
893  // Parse the function signature.
894  mlir::SMLoc signatureLocation = parser.getCurrentLocation();
895  bool isVariadic = false;
896  if (mlir::function_interface_impl::parseFunctionSignature(
897  parser, false, entryArgs, isVariadic, resultTypes, resultAttrs))
898  return failure();
899 
900  std::string errorMessage;
901  SmallVector<Type> argTypes;
902  argTypes.reserve(entryArgs.size());
903  for (auto &arg : entryArgs)
904  argTypes.push_back(arg.type);
905 
906  Type type = buildFuncType(
907  builder, argTypes, resultTypes,
908  mlir::function_interface_impl::VariadicFlag(isVariadic), errorMessage);
909  if (!type) {
910  return parser.emitError(signatureLocation)
911  << "failed to construct function type"
912  << (errorMessage.empty() ? "" : ": ") << errorMessage;
913  }
914  result.addAttribute(FuncOp::getFunctionTypeAttrName(result.name),
915  TypeAttr::get(type));
916 
917  // If function attributes are present, parse them.
918  NamedAttrList parsedAttributes;
919  mlir::SMLoc attributeDictLocation = parser.getCurrentLocation();
920  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
921  return failure();
922 
923  // Disallow attributes that are inferred from elsewhere in the attribute
924  // dictionary.
925  for (StringRef disallowed :
926  {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
927  FuncOp::getFunctionTypeAttrName(result.name).getValue()}) {
928  if (parsedAttributes.get(disallowed))
929  return parser.emitError(attributeDictLocation, "'")
930  << disallowed
931  << "' is an inferred attribute and should not be specified in the "
932  "explicit attribute dictionary";
933  }
934  result.attributes.append(parsedAttributes);
935 
936  // Add the attributes to the function arguments.
937  assert(resultAttrs.size() == resultTypes.size());
938  mlir::function_interface_impl::addArgAndResultAttrs(
939  builder, result, entryArgs, resultAttrs,
940  FuncOp::getArgAttrsAttrName(result.name),
941  FuncOp::getResAttrsAttrName(result.name));
942 
943  // Parse the optional function body. The printer will not print the body if
944  // its empty, so disallow parsing of empty body in the parser.
945  auto *body = result.addRegion();
946  mlir::SMLoc loc = parser.getCurrentLocation();
947  mlir::OptionalParseResult parseResult =
948  parser.parseOptionalRegion(*body, entryArgs,
949  /*enableNameShadowing=*/false);
950  if (parseResult.has_value()) {
951  if (failed(*parseResult))
952  return failure();
953  // Function body was parsed, make sure its not empty.
954  if (body->empty())
955  return parser.emitError(loc, "expected non-empty function body");
956  }
957 
958  // Everythink below is added compared to the upstream implemenation to handle
959  // argument names.
960  SmallVector<Attribute> argNames;
961  if (!entryArgs.empty() && !entryArgs.front().ssaName.name.empty()) {
962  for (auto &arg : entryArgs)
963  argNames.push_back(
964  StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front()));
965  }
966 
967  result.addAttribute(getArgNamesAttrName(result.name),
968  ArrayAttr::get(parser.getContext(), argNames));
969 
970  return success();
971 }
972 
973 void FuncOp::print(OpAsmPrinter &p) {
974  if (getExternC())
975  p << " externC";
976 
977  mlir::FunctionOpInterface op = *this;
978 
979  // FIXME: inlined mlir::function_interface_impl::printFunctionOp because we
980  // need to elide more attributes
981 
982  // Print the operation and the function name.
983  auto funcName =
984  op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
985  .getValue();
986  p << ' ';
987 
988  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
989  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
990  p << visibility.getValue() << ' ';
991  p.printSymbolName(funcName);
992 
993  ArrayRef<Type> argTypes = op.getArgumentTypes();
994  ArrayRef<Type> resultTypes = op.getResultTypes();
995  mlir::function_interface_impl::printFunctionSignature(p, op, argTypes, false,
996  resultTypes);
997  mlir::function_interface_impl::printFunctionAttributes(
998  p, op,
999  {visibilityAttrName, "externC", "argNames", getFunctionTypeAttrName(),
1000  getArgAttrsAttrName(), getResAttrsAttrName()});
1001  // Print the body if this is not an external function.
1002  Region &body = op->getRegion(0);
1003  if (!body.empty()) {
1004  p << ' ';
1005  p.printRegion(body, /*printEntryBlockArgs=*/false,
1006  /*printBlockTerminators=*/true);
1007  }
1008 }
1009 
// FIXME: the clone operations below are copied from upstream.
1011 
1012 /// Clone the internal blocks from this function into dest and all attributes
1013 /// from this function to dest.
1014 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
1015  // Add the attributes of this function to dest.
1016  llvm::MapVector<StringAttr, Attribute> newAttrMap;
1017  for (const auto &attr : dest->getAttrs())
1018  newAttrMap.insert({attr.getName(), attr.getValue()});
1019  for (const auto &attr : (*this)->getAttrs())
1020  newAttrMap.insert({attr.getName(), attr.getValue()});
1021 
1022  auto newAttrs = llvm::to_vector(llvm::map_range(
1023  newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
1024  return NamedAttribute(attrPair.first, attrPair.second);
1025  }));
1026  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
1027 
1028  // Clone the body.
1029  getBody().cloneInto(&dest.getBody(), mapper);
1030 }
1031 
1032 /// Create a deep copy of this function and all of its blocks, remapping
1033 /// any operands that use values outside of the function using the map that is
1034 /// provided (leaving them alone if no entry is present). Replaces references
1035 /// to cloned sub-values with the corresponding value that is copied, and adds
1036 /// those mappings to the mapper.
1037 FuncOp FuncOp::clone(IRMapping &mapper) {
1038  // Create the new function.
1039  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
1040 
1041  // If the function has a body, then the user might be deleting arguments to
1042  // the function by specifying them in the mapper. If so, we don't add the
1043  // argument to the input type vector.
1044  if (!isExternal()) {
1045  FunctionType oldType = getFunctionType();
1046 
1047  unsigned oldNumArgs = oldType.getNumInputs();
1048  SmallVector<Type, 4> newInputs;
1049  newInputs.reserve(oldNumArgs);
1050  for (unsigned i = 0; i != oldNumArgs; ++i)
1051  if (!mapper.contains(getArgument(i)))
1052  newInputs.push_back(oldType.getInput(i));
1053 
1054  /// If any of the arguments were dropped, update the type and drop any
1055  /// necessary argument attributes.
1056  if (newInputs.size() != oldNumArgs) {
1057  newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
1058  oldType.getResults()));
1059 
1060  if (ArrayAttr argAttrs = getAllArgAttrs()) {
1061  SmallVector<Attribute> newArgAttrs;
1062  newArgAttrs.reserve(newInputs.size());
1063  for (unsigned i = 0; i != oldNumArgs; ++i)
1064  if (!mapper.contains(getArgument(i)))
1065  newArgAttrs.push_back(argAttrs[i]);
1066  newFunc.setAllArgAttrs(newArgAttrs);
1067  }
1068  }
1069  }
1070 
1071  /// Clone the current function into the new one and return it.
1072  cloneInto(newFunc, mapper);
1073  return newFunc;
1074 }
1075 
1076 FuncOp FuncOp::clone() {
1077  IRMapping mapper;
1078  return clone(mapper);
1079 }
1080 
1081 // The following functions are entirely new additions compared to upstream.
1082 
1083 void FuncOp::getAsmBlockArgumentNames(mlir::Region &region,
1084  mlir::OpAsmSetValueNameFn setNameFn) {
1085  if (region.empty())
1086  return;
1087 
1088  for (auto [arg, name] : llvm::zip(getArguments(), getArgNames()))
1089  setNameFn(arg, cast<StringAttr>(name).getValue());
1090 }
1091 
1092 LogicalResult FuncOp::verify() {
1093  if (getFunctionType().getNumResults() > 1)
1094  return emitOpError(
1095  "incorrect number of function results (always has to be 0 or 1)");
1096 
1097  if (getBody().empty())
1098  return success();
1099 
1100  if (getArgNames().size() != getFunctionType().getNumInputs())
1101  return emitOpError("incorrect number of argument names");
1102 
1103  for (auto portName : getArgNames()) {
1104  if (cast<StringAttr>(portName).getValue().empty())
1105  return emitOpError("arg name must not be empty");
1106  }
1107 
1108  return success();
1109 }
1110 
1111 LogicalResult FuncOp::verifyRegions() {
1112  auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
1113  diag.attachNote(getLoc()) << "in function '@" << getName() << "'";
1114  };
1115  return verifyUniqueNamesInRegion(getOperation(), getArgNames(), attachNote);
1116 }
1117 
1118 //===----------------------------------------------------------------------===//
1119 // ReturnOp
1120 //
// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored such that
// we can reuse the implementation here.
1124 //===----------------------------------------------------------------------===//
1125 
1126 LogicalResult ReturnOp::verify() {
1127  auto function = cast<FuncOp>((*this)->getParentOp());
1128 
1129  // The operand number and types must match the function signature.
1130  const auto &results = function.getFunctionType().getResults();
1131  if (getNumOperands() != results.size())
1132  return emitOpError("has ")
1133  << getNumOperands() << " operands, but enclosing function (@"
1134  << function.getName() << ") returns " << results.size();
1135 
1136  for (unsigned i = 0, e = results.size(); i != e; ++i)
1137  if (getOperand(i).getType() != results[i])
1138  return emitError() << "type of return operand " << i << " ("
1139  << getOperand(i).getType()
1140  << ") doesn't match function result type ("
1141  << results[i] << ")"
1142  << " in function @" << function.getName();
1143 
1144  return success();
1145 }
1146 
1147 //===----------------------------------------------------------------------===//
1148 // TableGen generated logic.
1149 //===----------------------------------------------------------------------===//
1150 
1151 // Provide the autogenerated implementation guts for the Op classes.
1152 #define GET_OP_CLASSES
1153 #include "circt/Dialect/SystemC/SystemC.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1418
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static InstancePath empty
static Block * getBodyBlock(FModuleLike mod)
static hw::ModulePort::Direction getDirection(Type type)
Definition: SystemCOps.cpp:102
static Type wrapPortType(Type type, hw::ModulePort::Direction direction)
Definition: SystemCOps.cpp:222
static LogicalResult verifyUniqueNamesInRegion(Operation *operation, ArrayAttr argNames, std::function< void(mlir::InFlightDiagnostic &)> attachNote)
Definition: SystemCOps.cpp:31
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2459
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
void printModuleSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes, bool &needArgNamesAttr)
Print a module signature with named results.
ParseResult parseModuleFunctionSignature(OpAsmParser &parser, bool &isVariadic, SmallVectorImpl< OpAsmParser::Argument > &args, SmallVectorImpl< Attribute > &argNames, SmallVectorImpl< Attribute > &argLocs, SmallVectorImpl< Attribute > &resultNames, SmallVectorImpl< DictionaryAttr > &resultAttrs, SmallVectorImpl< Attribute > &resultLocs, TypeAttr &type)
This is a variant of mlir::parseFunctionSignature that allows names on result arguments.
circt::hw::InOutType InOutType
Definition: SVTypes.h:25
Type getSignalBaseType(Type type)
Get the type wrapped by a signal or port (in, inout, out) type.
std::optional< size_t > getBitWidth(Type type)
Return the bitwidth of a type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr)
Parse an implicit SSA name string attribute.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:182