CIRCT  18.0.0git
SystemCOps.cpp
Go to the documentation of this file.
1 //===- SystemCOps.cpp - Implement the SystemC operations ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the SystemC ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
18 #include "mlir/IR/IRMapping.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/Interfaces/FunctionImplementation.h"
21 #include "llvm/ADT/MapVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 
24 using namespace circt;
25 using namespace circt::systemc;
26 
27 //===----------------------------------------------------------------------===//
28 // Helpers
29 //===----------------------------------------------------------------------===//
30 
31 static LogicalResult verifyUniqueNamesInRegion(
32  Operation *operation, ArrayAttr argNames,
33  std::function<void(mlir::InFlightDiagnostic &)> attachNote) {
34  DenseMap<StringRef, BlockArgument> portNames;
35  DenseMap<StringRef, Operation *> memberNames;
36  DenseMap<StringRef, Operation *> localNames;
37 
38  if (operation->getNumRegions() != 1)
39  return operation->emitError("required to have exactly one region");
40 
41  bool portsVerified = true;
42 
43  for (auto arg : llvm::zip(argNames, operation->getRegion(0).getArguments())) {
44  StringRef argName = std::get<0>(arg).cast<StringAttr>().getValue();
45  BlockArgument argValue = std::get<1>(arg);
46 
47  if (portNames.count(argName)) {
48  auto diag = mlir::emitError(argValue.getLoc(), "redefines name '")
49  << argName << "'";
50  diag.attachNote(portNames[argName].getLoc())
51  << "'" << argName << "' first defined here";
52  attachNote(diag);
53  portsVerified = false;
54  continue;
55  }
56 
57  portNames.insert({argName, argValue});
58  }
59 
60  WalkResult result =
61  operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
62  if (isa<SCModuleOp>(op->getParentOp()))
63  localNames.clear();
64 
65  if (auto nameDeclOp = dyn_cast<SystemCNameDeclOpInterface>(op)) {
66  StringRef name = nameDeclOp.getName();
67 
68  auto reportNameRedefinition = [&](Location firstLoc) -> WalkResult {
69  auto diag = mlir::emitError(op->getLoc(), "redefines name '")
70  << name << "'";
71  diag.attachNote(firstLoc) << "'" << name << "' first defined here";
72  attachNote(diag);
73  return WalkResult::interrupt();
74  };
75 
76  if (portNames.count(name))
77  return reportNameRedefinition(portNames[name].getLoc());
78  if (memberNames.count(name))
79  return reportNameRedefinition(memberNames[name]->getLoc());
80  if (localNames.count(name))
81  return reportNameRedefinition(localNames[name]->getLoc());
82 
83  if (isa<SCModuleOp>(op->getParentOp()))
84  memberNames.insert({name, op});
85  else
86  localNames.insert({name, op});
87  }
88 
89  return WalkResult::advance();
90  });
91 
92  if (result.wasInterrupted() || !portsVerified)
93  return failure();
94 
95  return success();
96 }
97 
98 //===----------------------------------------------------------------------===//
99 // SCModuleOp
100 //===----------------------------------------------------------------------===//
101 
103  return TypeSwitch<Type, hw::ModulePort::Direction>(type)
104  .Case<InOutType>([](auto ty) { return hw::ModulePort::Direction::InOut; })
105  .Case<InputType>([](auto ty) { return hw::ModulePort::Direction::Input; })
106  .Case<OutputType>(
107  [](auto ty) { return hw::ModulePort::Direction::Output; });
108 }
109 
110 SCModuleOp::PortDirectionRange
111 SCModuleOp::getPortsOfDirection(hw::ModulePort::Direction direction) {
112  std::function<bool(const BlockArgument &)> predicateFn =
113  [&](const BlockArgument &arg) -> bool {
114  return getDirection(arg.getType()) == direction;
115  };
116  return llvm::make_filter_range(getArguments(), predicateFn);
117 }
118 
119 SmallVector<::circt::hw::PortInfo> SCModuleOp::getPortList() {
120  SmallVector<hw::PortInfo> ports;
121  for (int i = 0, e = getNumArguments(); i < e; ++i) {
122  hw::PortInfo info;
123  info.name = getPortNames()[i].cast<StringAttr>();
124  info.type = getSignalBaseType(getArgument(i).getType());
125  info.dir = getDirection(info.type);
126  ports.push_back(info);
127  }
128  return ports;
129 }
130 
131 mlir::Region *SCModuleOp::getCallableRegion() { return &getBody(); }
132 
133 StringRef SCModuleOp::getModuleName() {
134  return (*this)
135  ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
136  .getValue();
137 }
138 
139 ParseResult SCModuleOp::parse(OpAsmParser &parser, OperationState &result) {
140 
141  // Parse the visibility attribute.
142  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
143 
144  // Parse the name as a symbol.
145  StringAttr moduleName;
146  if (parser.parseSymbolName(moduleName, SymbolTable::getSymbolAttrName(),
147  result.attributes))
148  return failure();
149 
150  // Parse the function signature.
151  bool isVariadic = false;
152  SmallVector<OpAsmParser::Argument, 4> entryArgs;
153  SmallVector<Attribute> argNames;
154  SmallVector<Attribute> argLocs;
155  SmallVector<Attribute> resultNames;
156  SmallVector<DictionaryAttr> resultAttrs;
157  SmallVector<Attribute> resultLocs;
158  TypeAttr functionType;
160  parser, isVariadic, entryArgs, argNames, argLocs, resultNames,
161  resultAttrs, resultLocs, functionType)))
162  return failure();
163 
164  // Parse the attribute dict.
165  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
166  return failure();
167 
168  result.addAttribute("portNames",
169  ArrayAttr::get(parser.getContext(), argNames));
170 
171  result.addAttribute(SCModuleOp::getFunctionTypeAttrName(result.name),
172  functionType);
173 
174  mlir::function_interface_impl::addArgAndResultAttrs(
175  parser.getBuilder(), result, entryArgs, resultAttrs,
176  SCModuleOp::getArgAttrsAttrName(result.name),
177  SCModuleOp::getResAttrsAttrName(result.name));
178 
179  auto &body = *result.addRegion();
180  if (parser.parseRegion(body, entryArgs))
181  return failure();
182  if (body.empty())
183  body.push_back(std::make_unique<Block>().release());
184 
185  return success();
186 }
187 
188 void SCModuleOp::print(OpAsmPrinter &p) {
189  p << ' ';
190 
191  // Print the visibility of the module.
192  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
193  if (auto visibility =
194  getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
195  p << visibility.getValue() << ' ';
196 
197  p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
198  p << ' ';
199 
200  bool needArgNamesAttr = false;
202  p, *this, getFunctionType().getInputs(), false, {}, needArgNamesAttr);
203  mlir::function_interface_impl::printFunctionAttributes(
204  p, *this,
205  {"portNames", getFunctionTypeAttrName(), getArgAttrsAttrName(),
206  getResAttrsAttrName()});
207 
208  p << ' ';
209  p.printRegion(getBody(), false, false);
210 }
211 
212 /// Returns the argument types of this function.
213 ArrayRef<Type> SCModuleOp::getArgumentTypes() {
214  return getFunctionType().getInputs();
215 }
216 
217 /// Returns the result types of this function.
218 ArrayRef<Type> SCModuleOp::getResultTypes() {
219  return getFunctionType().getResults();
220 }
221 
222 static Type wrapPortType(Type type, hw::ModulePort::Direction direction) {
223  if (auto inoutTy = type.dyn_cast<hw::InOutType>())
224  type = inoutTy.getElementType();
225 
226  switch (direction) {
228  return InOutType::get(type);
230  return InputType::get(type);
232  return OutputType::get(type);
233  }
234  llvm_unreachable("Impossible port direction");
235 }
236 
237 void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238  StringAttr name, ArrayAttr portNames,
239  ArrayRef<Type> portTypes,
240  ArrayRef<NamedAttribute> attributes) {
241  odsState.addAttribute(getPortNamesAttrName(odsState.name), portNames);
242  Region *region = odsState.addRegion();
243 
244  auto moduleType = odsBuilder.getFunctionType(portTypes, {});
245  odsState.addAttribute(getFunctionTypeAttrName(odsState.name),
246  TypeAttr::get(moduleType));
247 
248  odsState.addAttribute(SymbolTable::getSymbolAttrName(), name);
249  region->push_back(new Block);
250  region->addArguments(
251  portTypes,
252  SmallVector<Location>(portTypes.size(), odsBuilder.getUnknownLoc()));
253  odsState.addAttributes(attributes);
254 }
255 
256 void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
257  StringAttr name, ArrayRef<hw::PortInfo> ports,
258  ArrayRef<NamedAttribute> attributes) {
259  MLIRContext *ctxt = odsBuilder.getContext();
260  SmallVector<Attribute> portNames;
261  SmallVector<Type> portTypes;
262  for (auto port : ports) {
263  portNames.push_back(StringAttr::get(ctxt, port.getName()));
264  portTypes.push_back(wrapPortType(port.type, port.dir));
265  }
266  build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
267 }
268 
269 void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
270  StringAttr name, const hw::ModulePortInfo &ports,
271  ArrayRef<NamedAttribute> attributes) {
272  MLIRContext *ctxt = odsBuilder.getContext();
273  SmallVector<Attribute> portNames;
274  SmallVector<Type> portTypes;
275  for (auto port : ports) {
276  portNames.push_back(StringAttr::get(ctxt, port.getName()));
277  portTypes.push_back(wrapPortType(port.type, port.dir));
278  }
279  build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
280 }
281 
282 void SCModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
283  mlir::OpAsmSetValueNameFn setNameFn) {
284  if (region.empty())
285  return;
286 
287  ArrayAttr portNames = getPortNames();
288  for (size_t i = 0, e = getNumArguments(); i != e; ++i) {
289  auto str = portNames[i].cast<StringAttr>().getValue();
290  setNameFn(getArgument(i), str);
291  }
292 }
293 
294 LogicalResult SCModuleOp::verify() {
295  if (getFunctionType().getNumResults() != 0)
296  return emitOpError(
297  "incorrect number of function results (always has to be 0)");
298  if (getPortNames().size() != getFunctionType().getNumInputs())
299  return emitOpError("incorrect number of port names");
300 
301  for (auto arg : getArguments()) {
302  if (!hw::type_isa<InputType, OutputType, InOutType>(arg.getType()))
303  return mlir::emitError(
304  arg.getLoc(),
305  "module port must be of type 'sc_in', 'sc_out', or 'sc_inout'");
306  }
307 
308  for (auto portName : getPortNames()) {
309  if (portName.cast<StringAttr>().getValue().empty())
310  return emitOpError("port name must not be empty");
311  }
312 
313  return success();
314 }
315 
316 LogicalResult SCModuleOp::verifyRegions() {
317  auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
318  diag.attachNote(getLoc()) << "in module '@" << getModuleName() << "'";
319  };
320  return verifyUniqueNamesInRegion(getOperation(), getPortNames(), attachNote);
321 }
322 
323 CtorOp SCModuleOp::getOrCreateCtor() {
324  CtorOp ctor;
325  getBody().walk([&](Operation *op) {
326  if ((ctor = dyn_cast<CtorOp>(op)))
327  return WalkResult::interrupt();
328 
329  return WalkResult::skip();
330  });
331 
332  if (ctor)
333  return ctor;
334 
335  return OpBuilder(getBody()).create<CtorOp>(getLoc());
336 }
337 
338 DestructorOp SCModuleOp::getOrCreateDestructor() {
339  DestructorOp destructor;
340  getBody().walk([&](Operation *op) {
341  if ((destructor = dyn_cast<DestructorOp>(op)))
342  return WalkResult::interrupt();
343 
344  return WalkResult::skip();
345  });
346 
347  if (destructor)
348  return destructor;
349 
350  return OpBuilder::atBlockEnd(getBodyBlock()).create<DestructorOp>(getLoc());
351 }
352 
353 //===----------------------------------------------------------------------===//
354 // SignalOp
355 //===----------------------------------------------------------------------===//
356 
358  setNameFn(getSignal(), getName());
359 }
360 
361 //===----------------------------------------------------------------------===//
362 // ConvertOp
363 //===----------------------------------------------------------------------===//
364 
/// Folds conversions that do not change the value.
///
/// Two cases are handled:
///  * A conversion whose input already has the result type folds to its input.
///  * A round trip `A -> B -> A` through another ConvertOp folds to the
///    original value of type `A`. This happens only when the signedness,
///    4-valued/2-valued, and bit-width checks below pass for the intermediate
///    type `B`.
/// Returns a null OpFoldResult when no fold applies.
OpFoldResult ConvertOp::fold(FoldAdaptor) {
  // Converting to the type the value already has is the identity.
  if (getInput().getType() == getResult().getType())
    return getInput();

  if (auto other = getInput().getDefiningOp<ConvertOp>()) {
    // `inputType` is fed into the first conversion; `intermediateType` is what
    // that conversion produces and what this op consumes.
    Type inputType = other.getInput().getType();
    Type intermediateType = getInput().getType();

    // Only a round trip back to the original type can be folded away.
    if (inputType != getResult().getType())
      return {};

    // Either both the input and intermediate types are signed or both are
    // unsigned.
    bool inputSigned = inputType.isa<SignedType, IntBaseType>();
    bool intermediateSigned = intermediateType.isa<SignedType, IntBaseType>();
    if (inputSigned ^ intermediateSigned)
      return {};

    // Converting 4-valued to 2-valued and back may lose information.
    if (inputType.isa<LogicVectorBaseType, LogicType>() &&
        !intermediateType.isa<LogicVectorBaseType, LogicType>())
      return {};

    // Both widths are optional-like values. An empty one is treated below as
    // "bit-width not known from the type".
    auto inputBw = getBitWidth(inputType);
    auto intermediateBw = getBitWidth(intermediateType);

    // Input width unknown, intermediate width known: fold only int_base /
    // uint_base inputs, and only if the intermediate is at least 64 bits wide.
    // NOTE(review): presumably relies on those base types holding at most 64
    // bits - confirm.
    if (!inputBw && intermediateBw) {
      if (inputType.isa<IntBaseType, UIntBaseType>() && *intermediateBw >= 64)
        return other.getInput();
      // We cannot support input types of signed, unsigned, and vector types
      // since they have no upper bound for the bit-width.
    }

    // Intermediate width unknown.
    if (!intermediateBw) {
      // Bit-vector and logic-vector base intermediates are accepted for any
      // input that passed the checks above. 4-valued inputs going through a
      // 2-valued intermediate were already rejected.
      if (intermediateType.isa<BitVectorBaseType, LogicVectorBaseType>())
        return other.getInput();

      // int_base/uint_base inputs of unknown width through signed/unsigned.
      if (!inputBw && inputType.isa<IntBaseType, UIntBaseType>() &&
          intermediateType.isa<SignedType, UnsignedType>())
        return other.getInput();

      // Inputs of at most 64 bits through any of the integer base types.
      if (inputBw && *inputBw <= 64 &&
          intermediateType
              .isa<IntBaseType, UIntBaseType, SignedType, UnsignedType>())
        return other.getInput();

      // We have to be careful with the signed and unsigned types as they often
      // have a max bit-width defined (that can be customized) and thus folding
      // here could change the behavior.
    }

    // Both widths known: nothing is truncated if the intermediate type is at
    // least as wide as the input type.
    if (inputBw && intermediateBw && *inputBw <= *intermediateBw)
      return other.getInput();
  }

  return {};
}
422 
423 //===----------------------------------------------------------------------===//
424 // CtorOp
425 //===----------------------------------------------------------------------===//
426 
427 LogicalResult CtorOp::verify() {
428  if (getBody().getNumArguments() != 0)
429  return emitOpError("must not have any arguments");
430 
431  return success();
432 }
433 
434 //===----------------------------------------------------------------------===//
435 // SCFuncOp
436 //===----------------------------------------------------------------------===//
437 
439  setNameFn(getHandle(), getName());
440 }
441 
442 LogicalResult SCFuncOp::verify() {
443  if (getBody().getNumArguments() != 0)
444  return emitOpError("must not have any arguments");
445 
446  return success();
447 }
448 
449 //===----------------------------------------------------------------------===//
450 // InstanceDeclOp
451 //===----------------------------------------------------------------------===//
452 
454  setNameFn(getInstanceHandle(), getName());
455 }
456 
457 StringRef InstanceDeclOp::getInstanceName() { return getName(); }
458 StringAttr InstanceDeclOp::getInstanceNameAttr() { return getNameAttr(); }
459 
460 Operation *
461 InstanceDeclOp::getReferencedModuleCached(const hw::HWSymbolCache *cache) {
462  if (cache)
463  if (auto *result = cache->getDefinition(getModuleNameAttr()))
464  return result;
465 
466  auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
467  return topLevelModuleOp.lookupSymbol(getModuleName());
468 }
469 
470 LogicalResult
471 InstanceDeclOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
472  auto *module =
473  symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
474  if (module == nullptr)
475  return emitError("cannot find module definition '")
476  << getModuleName() << "'";
477 
478  auto emitError = [&](const std::function<void(InFlightDiagnostic & diag)> &fn)
479  -> LogicalResult {
480  auto diag = emitOpError();
481  fn(diag);
482  diag.attachNote(module->getLoc()) << "module declared here";
483  return failure();
484  };
485 
486  // It must be a systemc module.
487  if (!isa<SCModuleOp>(module))
488  return emitError([&](auto &diag) {
489  diag << "symbol reference '" << getModuleName()
490  << "' isn't a systemc module";
491  });
492 
493  auto scModule = cast<SCModuleOp>(module);
494 
495  // Check that the module name of the symbol and instance type match.
496  if (scModule.getModuleName() != getInstanceType().getModuleName())
497  return emitError([&](auto &diag) {
498  diag << "module names must match; expected '" << scModule.getModuleName()
499  << "' but got '" << getInstanceType().getModuleName().getValue()
500  << "'";
501  });
502 
503  // Check that port types and names are consistent with the referenced module.
504  ArrayRef<ModuleType::PortInfo> ports = getInstanceType().getPorts();
505  ArrayAttr modArgNames = scModule.getPortNames();
506  auto numPorts = ports.size();
507  auto expectedPortTypes = scModule.getArgumentTypes();
508 
509  if (expectedPortTypes.size() != numPorts)
510  return emitError([&](auto &diag) {
511  diag << "has a wrong number of ports; expected "
512  << expectedPortTypes.size() << " but got " << numPorts;
513  });
514 
515  for (size_t i = 0; i != numPorts; ++i) {
516  if (ports[i].type != expectedPortTypes[i]) {
517  return emitError([&](auto &diag) {
518  diag << "port type #" << i << " must be " << expectedPortTypes[i]
519  << ", but got " << ports[i].type;
520  });
521  }
522 
523  if (ports[i].name != modArgNames[i])
524  return emitError([&](auto &diag) {
525  diag << "port name #" << i << " must be " << modArgNames[i]
526  << ", but got " << ports[i].name;
527  });
528  }
529 
530  return success();
531 }
532 
533 SmallVector<hw::PortInfo> InstanceDeclOp::getPortList() {
534  return cast<hw::PortList>(getReferencedModuleSlow()).getPortList();
535 }
536 
537 //===----------------------------------------------------------------------===//
538 // DestructorOp
539 //===----------------------------------------------------------------------===//
540 
541 LogicalResult DestructorOp::verify() {
542  if (getBody().getNumArguments() != 0)
543  return emitOpError("must not have any arguments");
544 
545  return success();
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // BindPortOp
550 //===----------------------------------------------------------------------===//
551 
552 ParseResult BindPortOp::parse(OpAsmParser &parser, OperationState &result) {
553  OpAsmParser::UnresolvedOperand instance, channel;
554  std::string portName;
555  if (parser.parseOperand(instance) || parser.parseLSquare() ||
556  parser.parseString(&portName))
557  return failure();
558 
559  auto portNameLoc = parser.getCurrentLocation();
560 
561  if (parser.parseRSquare() || parser.parseKeyword("to") ||
562  parser.parseOperand(channel))
563  return failure();
564 
565  if (parser.parseOptionalAttrDict(result.attributes))
566  return failure();
567 
568  auto typeListLoc = parser.getCurrentLocation();
569  SmallVector<Type> types;
570  if (parser.parseColonTypeList(types))
571  return failure();
572 
573  if (types.size() != 2)
574  return parser.emitError(typeListLoc,
575  "expected a list of exactly 2 types, but got ")
576  << types.size();
577 
578  if (parser.resolveOperand(instance, types[0], result.operands))
579  return failure();
580  if (parser.resolveOperand(channel, types[1], result.operands))
581  return failure();
582 
583  if (auto moduleType = types[0].dyn_cast<ModuleType>()) {
584  auto ports = moduleType.getPorts();
585  uint64_t index = 0;
586  for (auto port : ports) {
587  if (port.name == portName)
588  break;
589  index++;
590  }
591  if (index >= ports.size())
592  return parser.emitError(portNameLoc, "port name \"")
593  << portName << "\" not found in module";
594 
595  result.addAttribute("portId", parser.getBuilder().getIndexAttr(index));
596 
597  return success();
598  }
599 
600  return failure();
601 }
602 
603 void BindPortOp::print(OpAsmPrinter &p) {
604  p << " " << getInstance() << "["
605  << getInstance()
606  .getType()
607  .cast<ModuleType>()
608  .getPorts()[getPortId().getZExtValue()]
609  .name
610  << "] to " << getChannel();
611  p.printOptionalAttrDict((*this)->getAttrs(), {"portId"});
612  p << " : " << getInstance().getType() << ", " << getChannel().getType();
613 }
614 
615 LogicalResult BindPortOp::verify() {
616  auto ports = getInstance().getType().cast<ModuleType>().getPorts();
617  if (getPortId().getZExtValue() >= ports.size())
618  return emitOpError("port #")
619  << getPortId().getZExtValue() << " does not exist, there are only "
620  << ports.size() << " ports";
621 
622  // Verify that the base types match.
623  Type portType = ports[getPortId().getZExtValue()].type;
624  Type channelType = getChannel().getType();
625  if (getSignalBaseType(portType) != getSignalBaseType(channelType))
626  return emitOpError() << portType << " port cannot be bound to "
627  << channelType << " channel due to base type mismatch";
628 
629  // Verify that the port/channel directions are valid.
630  if ((portType.isa<InputType>() && channelType.isa<OutputType>()) ||
631  (portType.isa<OutputType>() && channelType.isa<InputType>()))
632  return emitOpError() << portType << " port cannot be bound to "
633  << channelType
634  << " channel due to port direction mismatch";
635 
636  return success();
637 }
638 
639 StringRef BindPortOp::getPortName() {
640  return getInstance()
641  .getType()
642  .cast<ModuleType>()
643  .getPorts()[getPortId().getZExtValue()]
644  .name.getValue();
645 }
646 
647 //===----------------------------------------------------------------------===//
648 // SensitiveOp
649 //===----------------------------------------------------------------------===//
650 
651 LogicalResult SensitiveOp::canonicalize(SensitiveOp op,
652  PatternRewriter &rewriter) {
653  if (op.getSensitivities().empty()) {
654  rewriter.eraseOp(op);
655  return success();
656  }
657 
658  return failure();
659 }
660 
661 //===----------------------------------------------------------------------===//
662 // VariableOp
663 //===----------------------------------------------------------------------===//
664 
666  setNameFn(getVariable(), getName());
667 }
668 
669 ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
670  StringAttr nameAttr;
671  if (parseImplicitSSAName(parser, nameAttr))
672  return failure();
673  result.addAttribute("name", nameAttr);
674 
675  OpAsmParser::UnresolvedOperand init;
676  auto initResult = parser.parseOptionalOperand(init);
677 
678  if (parser.parseOptionalAttrDict(result.attributes))
679  return failure();
680 
681  Type variableType;
682  if (parser.parseColonType(variableType))
683  return failure();
684 
685  if (initResult.has_value()) {
686  if (parser.resolveOperand(init, variableType, result.operands))
687  return failure();
688  }
689  result.addTypes({variableType});
690 
691  return success();
692 }
693 
694 void VariableOp::print(::mlir::OpAsmPrinter &p) {
695  p << " ";
696 
697  if (getInit())
698  p << getInit() << " ";
699 
700  p.printOptionalAttrDict(getOperation()->getAttrs(), {"name"});
701  p << ": " << getVariable().getType();
702 }
703 
704 LogicalResult VariableOp::verify() {
705  if (getInit() && getInit().getType() != getVariable().getType())
706  return emitOpError(
707  "'init' and 'variable' must have the same type, but got ")
708  << getInit().getType() << " and " << getVariable().getType();
709 
710  return success();
711 }
712 
713 //===----------------------------------------------------------------------===//
714 // InteropVerilatedOp
715 //===----------------------------------------------------------------------===//
716 
717 /// Create a instance that refers to a known module.
718 void InteropVerilatedOp::build(OpBuilder &odsBuilder, OperationState &odsState,
719  Operation *module, StringAttr name,
720  ArrayRef<Value> inputs) {
721  auto mod = cast<hw::HWModuleLike>(module);
722  auto argNames = odsBuilder.getArrayAttr(mod.getInputNames());
723  auto resultNames = odsBuilder.getArrayAttr(mod.getOutputNames());
724  build(odsBuilder, odsState, mod.getHWModuleType().getOutputTypes(), name,
725  FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), argNames,
726  resultNames, inputs);
727 }
728 
729 LogicalResult
730 InteropVerilatedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
732  *this, getModuleNameAttr(), getInputs(), getResultTypes(),
733  getInputNames(), getResultNames(), ArrayAttr(), symbolTable);
734 }
735 
736 /// Suggest a name for each result value based on the saved result names
737 /// attribute.
740  getResultNames(), getResults());
741 }
742 
743 //===----------------------------------------------------------------------===//
744 // CallOp
745 //
746 // TODO: The implementation for this operation was copy-pasted from the
747 // 'func' dialect. Ideally, this upstream dialect refactored such that we can
748 // re-use the implementation here.
749 //===----------------------------------------------------------------------===//
750 
751 // FIXME: This is an exact copy from upstream
752 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
753  // Check that the callee attribute was specified.
754  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
755  if (!fnAttr)
756  return emitOpError("requires a 'callee' symbol reference attribute");
757  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
758  if (!fn)
759  return emitOpError() << "'" << fnAttr.getValue()
760  << "' does not reference a valid function";
761 
762  // Verify that the operand and result types match the callee.
763  auto fnType = fn.getFunctionType();
764  if (fnType.getNumInputs() != getNumOperands())
765  return emitOpError("incorrect number of operands for callee");
766 
767  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
768  if (getOperand(i).getType() != fnType.getInput(i))
769  return emitOpError("operand type mismatch: expected operand type ")
770  << fnType.getInput(i) << ", but provided "
771  << getOperand(i).getType() << " for operand number " << i;
772 
773  if (fnType.getNumResults() != getNumResults())
774  return emitOpError("incorrect number of results for callee");
775 
776  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
777  if (getResult(i).getType() != fnType.getResult(i)) {
778  auto diag = emitOpError("result type mismatch at index ") << i;
779  diag.attachNote() << " op result types: " << getResultTypes();
780  diag.attachNote() << "function result types: " << fnType.getResults();
781  return diag;
782  }
783 
784  return success();
785 }
786 
787 FunctionType CallOp::getCalleeType() {
788  return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
789 }
790 
791 // This verifier was added compared to the upstream implementation.
792 LogicalResult CallOp::verify() {
793  if (getNumResults() > 1)
794  return emitOpError(
795  "incorrect number of function results (always has to be 0 or 1)");
796 
797  return success();
798 }
799 
800 //===----------------------------------------------------------------------===//
801 // CallIndirectOp
802 //===----------------------------------------------------------------------===//
803 
804 // This verifier was added compared to the upstream implementation.
805 LogicalResult CallIndirectOp::verify() {
806  if (getNumResults() > 1)
807  return emitOpError(
808  "incorrect number of function results (always has to be 0 or 1)");
809 
810  return success();
811 }
812 
813 //===----------------------------------------------------------------------===//
814 // FuncOp
815 //
816 // TODO: Most of the implementation for this operation was copy-pasted from the
817 // 'func' dialect. Ideally, this upstream dialect refactored such that we can
818 // re-use the implementation here.
819 //===----------------------------------------------------------------------===//
820 
821 // Note that the create and build operations are taken from upstream, but the
822 // argNames argument was added.
823 FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
824  FunctionType type, ArrayRef<NamedAttribute> attrs) {
825  OpBuilder builder(location->getContext());
826  OperationState state(location, getOperationName());
827  FuncOp::build(builder, state, name, argNames, type, attrs);
828  return cast<FuncOp>(Operation::create(state));
829 }
830 
831 FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
832  FunctionType type, Operation::dialect_attr_range attrs) {
833  SmallVector<NamedAttribute, 8> attrRef(attrs);
834  return create(location, name, argNames, type, ArrayRef(attrRef));
835 }
836 
837 FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
838  FunctionType type, ArrayRef<NamedAttribute> attrs,
839  ArrayRef<DictionaryAttr> argAttrs) {
840  FuncOp func = create(location, name, argNames, type, attrs);
841  func.setAllArgAttrs(argAttrs);
842  return func;
843 }
844 
845 void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
846  StringRef name, ArrayAttr argNames, FunctionType type,
847  ArrayRef<NamedAttribute> attrs,
848  ArrayRef<DictionaryAttr> argAttrs) {
849  odsState.addAttribute(getArgNamesAttrName(odsState.name), argNames);
850  odsState.addAttribute(SymbolTable::getSymbolAttrName(),
851  odsBuilder.getStringAttr(name));
852  odsState.addAttribute(FuncOp::getFunctionTypeAttrName(odsState.name),
853  TypeAttr::get(type));
854  odsState.attributes.append(attrs.begin(), attrs.end());
855  odsState.addRegion();
856 
857  if (argAttrs.empty())
858  return;
859  assert(type.getNumInputs() == argAttrs.size());
860  mlir::function_interface_impl::addArgAndResultAttrs(
861  odsBuilder, odsState, argAttrs,
862  /*resultAttrs=*/std::nullopt, FuncOp::getArgAttrsAttrName(odsState.name),
863  FuncOp::getResAttrsAttrName(odsState.name));
864 }
865 
866 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
867  auto buildFuncType =
868  [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
869  mlir::function_interface_impl::VariadicFlag,
870  std::string &) { return builder.getFunctionType(argTypes, results); };
871 
872  // This was added specifically for our implementation, upstream does not have
873  // this feature.
874  if (succeeded(parser.parseOptionalKeyword("externC")))
875  result.addAttribute(getExternCAttrName(result.name),
876  UnitAttr::get(result.getContext()));
877 
878  // FIXME: below is an exact copy of the
879  // mlir::function_interface_impl::parseFunctionOp implementation, this was
880  // needed because we need to access the SSA names of the arguments.
881  SmallVector<OpAsmParser::Argument> entryArgs;
882  SmallVector<DictionaryAttr> resultAttrs;
883  SmallVector<Type> resultTypes;
884  auto &builder = parser.getBuilder();
885 
886  // Parse visibility.
887  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
888 
889  // Parse the name as a symbol.
890  StringAttr nameAttr;
891  if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
892  result.attributes))
893  return failure();
894 
895  // Parse the function signature.
896  mlir::SMLoc signatureLocation = parser.getCurrentLocation();
897  bool isVariadic = false;
898  if (mlir::function_interface_impl::parseFunctionSignature(
899  parser, false, entryArgs, isVariadic, resultTypes, resultAttrs))
900  return failure();
901 
902  std::string errorMessage;
903  SmallVector<Type> argTypes;
904  argTypes.reserve(entryArgs.size());
905  for (auto &arg : entryArgs)
906  argTypes.push_back(arg.type);
907 
908  Type type = buildFuncType(
909  builder, argTypes, resultTypes,
910  mlir::function_interface_impl::VariadicFlag(isVariadic), errorMessage);
911  if (!type) {
912  return parser.emitError(signatureLocation)
913  << "failed to construct function type"
914  << (errorMessage.empty() ? "" : ": ") << errorMessage;
915  }
916  result.addAttribute(FuncOp::getFunctionTypeAttrName(result.name),
917  TypeAttr::get(type));
918 
919  // If function attributes are present, parse them.
920  NamedAttrList parsedAttributes;
921  mlir::SMLoc attributeDictLocation = parser.getCurrentLocation();
922  if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
923  return failure();
924 
925  // Disallow attributes that are inferred from elsewhere in the attribute
926  // dictionary.
927  for (StringRef disallowed :
928  {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
929  FuncOp::getFunctionTypeAttrName(result.name).getValue()}) {
930  if (parsedAttributes.get(disallowed))
931  return parser.emitError(attributeDictLocation, "'")
932  << disallowed
933  << "' is an inferred attribute and should not be specified in the "
934  "explicit attribute dictionary";
935  }
936  result.attributes.append(parsedAttributes);
937 
938  // Add the attributes to the function arguments.
939  assert(resultAttrs.size() == resultTypes.size());
940  mlir::function_interface_impl::addArgAndResultAttrs(
941  builder, result, entryArgs, resultAttrs,
942  FuncOp::getArgAttrsAttrName(result.name),
943  FuncOp::getResAttrsAttrName(result.name));
944 
945  // Parse the optional function body. The printer will not print the body if
946  // its empty, so disallow parsing of empty body in the parser.
947  auto *body = result.addRegion();
948  mlir::SMLoc loc = parser.getCurrentLocation();
949  mlir::OptionalParseResult parseResult =
950  parser.parseOptionalRegion(*body, entryArgs,
951  /*enableNameShadowing=*/false);
952  if (parseResult.has_value()) {
953  if (failed(*parseResult))
954  return failure();
955  // Function body was parsed, make sure its not empty.
956  if (body->empty())
957  return parser.emitError(loc, "expected non-empty function body");
958  }
959 
960  // Everythink below is added compared to the upstream implemenation to handle
961  // argument names.
962  SmallVector<Attribute> argNames;
963  if (!entryArgs.empty() && !entryArgs.front().ssaName.name.empty()) {
964  for (auto &arg : entryArgs)
965  argNames.push_back(
966  StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front()));
967  }
968 
969  result.addAttribute(getArgNamesAttrName(result.name),
970  ArrayAttr::get(parser.getContext(), argNames));
971 
972  return success();
973 }
974 
975 void FuncOp::print(OpAsmPrinter &p) {
976  if (getExternC())
977  p << " externC";
978 
979  mlir::FunctionOpInterface op = *this;
980 
981  // FIXME: inlined mlir::function_interface_impl::printFunctionOp because we
982  // need to elide more attributes
983 
984  // Print the operation and the function name.
985  auto funcName =
986  op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
987  .getValue();
988  p << ' ';
989 
990  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
991  if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
992  p << visibility.getValue() << ' ';
993  p.printSymbolName(funcName);
994 
995  ArrayRef<Type> argTypes = op.getArgumentTypes();
996  ArrayRef<Type> resultTypes = op.getResultTypes();
997  mlir::function_interface_impl::printFunctionSignature(p, op, argTypes, false,
998  resultTypes);
999  mlir::function_interface_impl::printFunctionAttributes(
1000  p, op,
1001  {visibilityAttrName, "externC", "argNames", getFunctionTypeAttrName(),
1002  getArgAttrsAttrName(), getResAttrsAttrName()});
1003  // Print the body if this is not an external function.
1004  Region &body = op->getRegion(0);
1005  if (!body.empty()) {
1006  p << ' ';
1007  p.printRegion(body, /*printEntryBlockArgs=*/false,
1008  /*printBlockTerminators=*/true);
1009  }
1010 }
1011 
// FIXME: the clone operations below were copied from the upstream 'func'
// dialect.
1013 
1014 /// Clone the internal blocks from this function into dest and all attributes
1015 /// from this function to dest.
1016 void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
1017  // Add the attributes of this function to dest.
1018  llvm::MapVector<StringAttr, Attribute> newAttrMap;
1019  for (const auto &attr : dest->getAttrs())
1020  newAttrMap.insert({attr.getName(), attr.getValue()});
1021  for (const auto &attr : (*this)->getAttrs())
1022  newAttrMap.insert({attr.getName(), attr.getValue()});
1023 
1024  auto newAttrs = llvm::to_vector(llvm::map_range(
1025  newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
1026  return NamedAttribute(attrPair.first, attrPair.second);
1027  }));
1028  dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
1029 
1030  // Clone the body.
1031  getBody().cloneInto(&dest.getBody(), mapper);
1032 }
1033 
1034 /// Create a deep copy of this function and all of its blocks, remapping
1035 /// any operands that use values outside of the function using the map that is
1036 /// provided (leaving them alone if no entry is present). Replaces references
1037 /// to cloned sub-values with the corresponding value that is copied, and adds
1038 /// those mappings to the mapper.
1039 FuncOp FuncOp::clone(IRMapping &mapper) {
1040  // Create the new function.
1041  FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
1042 
1043  // If the function has a body, then the user might be deleting arguments to
1044  // the function by specifying them in the mapper. If so, we don't add the
1045  // argument to the input type vector.
1046  if (!isExternal()) {
1047  FunctionType oldType = getFunctionType();
1048 
1049  unsigned oldNumArgs = oldType.getNumInputs();
1050  SmallVector<Type, 4> newInputs;
1051  newInputs.reserve(oldNumArgs);
1052  for (unsigned i = 0; i != oldNumArgs; ++i)
1053  if (!mapper.contains(getArgument(i)))
1054  newInputs.push_back(oldType.getInput(i));
1055 
1056  /// If any of the arguments were dropped, update the type and drop any
1057  /// necessary argument attributes.
1058  if (newInputs.size() != oldNumArgs) {
1059  newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
1060  oldType.getResults()));
1061 
1062  if (ArrayAttr argAttrs = getAllArgAttrs()) {
1063  SmallVector<Attribute> newArgAttrs;
1064  newArgAttrs.reserve(newInputs.size());
1065  for (unsigned i = 0; i != oldNumArgs; ++i)
1066  if (!mapper.contains(getArgument(i)))
1067  newArgAttrs.push_back(argAttrs[i]);
1068  newFunc.setAllArgAttrs(newArgAttrs);
1069  }
1070  }
1071  }
1072 
1073  /// Clone the current function into the new one and return it.
1074  cloneInto(newFunc, mapper);
1075  return newFunc;
1076 }
1077 
1078 FuncOp FuncOp::clone() {
1079  IRMapping mapper;
1080  return clone(mapper);
1081 }
1082 
1083 // The following functions are entirely new additions compared to upstream.
1084 
1085 void FuncOp::getAsmBlockArgumentNames(mlir::Region &region,
1086  mlir::OpAsmSetValueNameFn setNameFn) {
1087  if (region.empty())
1088  return;
1089 
1090  for (auto [arg, name] : llvm::zip(getArguments(), getArgNames()))
1091  setNameFn(arg, name.cast<StringAttr>().getValue());
1092 }
1093 
1094 LogicalResult FuncOp::verify() {
1095  if (getFunctionType().getNumResults() > 1)
1096  return emitOpError(
1097  "incorrect number of function results (always has to be 0 or 1)");
1098 
1099  if (getBody().empty())
1100  return success();
1101 
1102  if (getArgNames().size() != getFunctionType().getNumInputs())
1103  return emitOpError("incorrect number of argument names");
1104 
1105  for (auto portName : getArgNames()) {
1106  if (portName.cast<StringAttr>().getValue().empty())
1107  return emitOpError("arg name must not be empty");
1108  }
1109 
1110  return success();
1111 }
1112 
1113 LogicalResult FuncOp::verifyRegions() {
1114  auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
1115  diag.attachNote(getLoc()) << "in function '@" << getName() << "'";
1116  };
1117  return verifyUniqueNamesInRegion(getOperation(), getArgNames(), attachNote);
1118 }
1119 
1120 //===----------------------------------------------------------------------===//
1121 // ReturnOp
1122 //
// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored so that we
// can reuse its implementation here.
1126 //===----------------------------------------------------------------------===//
1127 
1128 LogicalResult ReturnOp::verify() {
1129  auto function = cast<FuncOp>((*this)->getParentOp());
1130 
1131  // The operand number and types must match the function signature.
1132  const auto &results = function.getFunctionType().getResults();
1133  if (getNumOperands() != results.size())
1134  return emitOpError("has ")
1135  << getNumOperands() << " operands, but enclosing function (@"
1136  << function.getName() << ") returns " << results.size();
1137 
1138  for (unsigned i = 0, e = results.size(); i != e; ++i)
1139  if (getOperand(i).getType() != results[i])
1140  return emitError() << "type of return operand " << i << " ("
1141  << getOperand(i).getType()
1142  << ") doesn't match function result type ("
1143  << results[i] << ")"
1144  << " in function @" << function.getName();
1145 
1146  return success();
1147 }
1148 
1149 //===----------------------------------------------------------------------===//
1150 // TableGen generated logic.
1151 //===----------------------------------------------------------------------===//
1152 
1153 // Provide the autogenerated implementation guts for the Op classes.
1154 #define GET_OP_CLASSES
1155 #include "circt/Dialect/SystemC/SystemC.cpp.inc"
assert(baseType &&"element must be base type")
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1382
@ Input
Definition: HW.h:32
@ Output
Definition: HW.h:32
@ InOut
Definition: HW.h:32
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
Builder builder
static hw::ModulePort::Direction getDirection(Type type)
Definition: SystemCOps.cpp:102
static Type wrapPortType(Type type, hw::ModulePort::Direction direction)
Definition: SystemCOps.cpp:222
static LogicalResult verifyUniqueNamesInRegion(Operation *operation, ArrayAttr argNames, std::function< void(mlir::InFlightDiagnostic &)> attachNote)
Definition: SystemCOps.cpp:31
Represents a finite word-length bit vector in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:227
Represents a limited word-length signed integer in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:82
Represents a finite word-length bit vector in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:262
Represents a finite word-length signed integer in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:152
Represents a limited word-length unsigned integer in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:116
Represents a finite word-length unsigned integer in SystemC as described in IEEE 1666-2011 §7....
Definition: SystemCTypes.h:187
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:53
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
void printModuleSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes, bool &needArgNamesAttr)
Print a module signature with named results.
ParseResult parseModuleFunctionSignature(OpAsmParser &parser, bool &isVariadic, SmallVectorImpl< OpAsmParser::Argument > &args, SmallVectorImpl< Attribute > &argNames, SmallVectorImpl< Attribute > &argLocs, SmallVectorImpl< Attribute > &resultNames, SmallVectorImpl< DictionaryAttr > &resultAttrs, SmallVectorImpl< Attribute > &resultLocs, TypeAttr &type)
This is a variant of mlir::parseFunctionSignature that allows names on result arguments.
circt::hw::InOutType InOutType
Definition: SVTypes.h:25
Type getSignalBaseType(Type type)
Get the type wrapped by a signal or port (in, inout, out) type.
std::optional< size_t > getBitWidth(Type type)
Return the bitwidth of a type.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr)
Parse an implicit SSA name string attribute.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition: LLVM.h:186