Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
SystemCOps.cpp
Go to the documentation of this file.
1//===- SystemCOps.cpp - Implement the SystemC operations ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SystemC ops.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/FunctionImplementation.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace circt;
25using namespace circt::systemc;
26
27//===----------------------------------------------------------------------===//
28// Helpers
29//===----------------------------------------------------------------------===//
30
31static LogicalResult verifyUniqueNamesInRegion(
32 Operation *operation, ArrayAttr argNames,
33 std::function<void(mlir::InFlightDiagnostic &)> attachNote) {
34 DenseMap<StringRef, BlockArgument> portNames;
35 DenseMap<StringRef, Operation *> memberNames;
36 DenseMap<StringRef, Operation *> localNames;
37
38 if (operation->getNumRegions() != 1)
39 return operation->emitError("required to have exactly one region");
40
41 bool portsVerified = true;
42
43 for (auto arg : llvm::zip(argNames, operation->getRegion(0).getArguments())) {
44 StringRef argName = cast<StringAttr>(std::get<0>(arg)).getValue();
45 BlockArgument argValue = std::get<1>(arg);
46
47 if (portNames.count(argName)) {
48 auto diag = mlir::emitError(argValue.getLoc(), "redefines name '")
49 << argName << "'";
50 diag.attachNote(portNames[argName].getLoc())
51 << "'" << argName << "' first defined here";
52 attachNote(diag);
53 portsVerified = false;
54 continue;
55 }
56
57 portNames.insert({argName, argValue});
58 }
59
60 WalkResult result =
61 operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
62 if (isa<SCModuleOp>(op->getParentOp()))
63 localNames.clear();
64
65 if (auto nameDeclOp = dyn_cast<SystemCNameDeclOpInterface>(op)) {
66 StringRef name = nameDeclOp.getName();
67
68 auto reportNameRedefinition = [&](Location firstLoc) -> WalkResult {
69 auto diag = mlir::emitError(op->getLoc(), "redefines name '")
70 << name << "'";
71 diag.attachNote(firstLoc) << "'" << name << "' first defined here";
72 attachNote(diag);
73 return WalkResult::interrupt();
74 };
75
76 if (portNames.count(name))
77 return reportNameRedefinition(portNames[name].getLoc());
78 if (memberNames.count(name))
79 return reportNameRedefinition(memberNames[name]->getLoc());
80 if (localNames.count(name))
81 return reportNameRedefinition(localNames[name]->getLoc());
82
83 if (isa<SCModuleOp>(op->getParentOp()))
84 memberNames.insert({name, op});
85 else
86 localNames.insert({name, op});
87 }
88
89 return WalkResult::advance();
90 });
91
92 if (result.wasInterrupted() || !portsVerified)
93 return failure();
94
95 return success();
96}
97
98//===----------------------------------------------------------------------===//
99// SCModuleOp
100//===----------------------------------------------------------------------===//
101
103 return TypeSwitch<Type, hw::ModulePort::Direction>(type)
104 .Case<InOutType>([](auto ty) { return hw::ModulePort::Direction::InOut; })
105 .Case<InputType>([](auto ty) { return hw::ModulePort::Direction::Input; })
106 .Case<OutputType>(
107 [](auto ty) { return hw::ModulePort::Direction::Output; });
108}
109
110SCModuleOp::PortDirectionRange
111SCModuleOp::getPortsOfDirection(hw::ModulePort::Direction direction) {
112 std::function<bool(const BlockArgument &)> predicateFn =
113 [&](const BlockArgument &arg) -> bool {
114 return getDirection(arg.getType()) == direction;
115 };
116 return llvm::make_filter_range(getArguments(), predicateFn);
117}
118
119SmallVector<::circt::hw::PortInfo> SCModuleOp::getPortList() {
120 SmallVector<hw::PortInfo> ports;
121 for (int i = 0, e = getNumArguments(); i < e; ++i) {
123 info.name = cast<StringAttr>(getPortNames()[i]);
124 info.type = getSignalBaseType(getArgument(i).getType());
125 info.dir = getDirection(info.type);
126 ports.push_back(info);
127 }
128 return ports;
129}
130
131mlir::Region *SCModuleOp::getCallableRegion() { return &getBody(); }
132
133StringRef SCModuleOp::getModuleName() {
134 return (*this)
135 ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
136 .getValue();
137}
138
139ParseResult SCModuleOp::parse(OpAsmParser &parser, OperationState &result) {
140
141 // Parse the visibility attribute.
142 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
143
144 // Parse the name as a symbol.
145 StringAttr moduleName;
146 if (parser.parseSymbolName(moduleName, SymbolTable::getSymbolAttrName(),
147 result.attributes))
148 return failure();
149
150 // Parse the function signature.
151 bool isVariadic = false;
152 SmallVector<OpAsmParser::Argument, 4> entryArgs;
153 SmallVector<Attribute> argNames;
154 SmallVector<Attribute> argLocs;
155 SmallVector<Attribute> resultNames;
156 SmallVector<DictionaryAttr> resultAttrs;
157 SmallVector<Attribute> resultLocs;
158 TypeAttr functionType;
160 parser, isVariadic, entryArgs, argNames, argLocs, resultNames,
161 resultAttrs, resultLocs, functionType)))
162 return failure();
163
164 // Parse the attribute dict.
165 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
166 return failure();
167
168 result.addAttribute("portNames",
169 ArrayAttr::get(parser.getContext(), argNames));
170
171 result.addAttribute(SCModuleOp::getFunctionTypeAttrName(result.name),
172 functionType);
173
174 mlir::call_interface_impl::addArgAndResultAttrs(
175 parser.getBuilder(), result, entryArgs, resultAttrs,
176 SCModuleOp::getArgAttrsAttrName(result.name),
177 SCModuleOp::getResAttrsAttrName(result.name));
178
179 auto &body = *result.addRegion();
180 if (parser.parseRegion(body, entryArgs))
181 return failure();
182 if (body.empty())
183 body.push_back(std::make_unique<Block>().release());
184
185 return success();
186}
187
188void SCModuleOp::print(OpAsmPrinter &p) {
189 p << ' ';
190
191 // Print the visibility of the module.
192 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
193 if (auto visibility =
194 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
195 p << visibility.getValue() << ' ';
196
197 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
198 p << ' ';
199
200 bool needArgNamesAttr = false;
202 p, *this, getFunctionType().getInputs(), false, {}, needArgNamesAttr);
203 mlir::function_interface_impl::printFunctionAttributes(
204 p, *this,
205 {"portNames", getFunctionTypeAttrName(), getArgAttrsAttrName(),
206 getResAttrsAttrName()});
207
208 p << ' ';
209 p.printRegion(getBody(), false, false);
210}
211
212/// Returns the argument types of this function.
213ArrayRef<Type> SCModuleOp::getArgumentTypes() {
214 return getFunctionType().getInputs();
215}
216
217/// Returns the result types of this function.
218ArrayRef<Type> SCModuleOp::getResultTypes() {
219 return getFunctionType().getResults();
220}
221
222static Type wrapPortType(Type type, hw::ModulePort::Direction direction) {
223 if (auto inoutTy = dyn_cast<hw::InOutType>(type))
224 type = inoutTy.getElementType();
225
226 switch (direction) {
227 case hw::ModulePort::Direction::InOut:
228 return InOutType::get(type);
229 case hw::ModulePort::Direction::Input:
230 return InputType::get(type);
231 case hw::ModulePort::Direction::Output:
232 return OutputType::get(type);
233 }
234 llvm_unreachable("Impossible port direction");
235}
236
237void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238 StringAttr name, ArrayAttr portNames,
239 ArrayRef<Type> portTypes,
240 ArrayRef<NamedAttribute> attributes) {
241 odsState.addAttribute(getPortNamesAttrName(odsState.name), portNames);
242 Region *region = odsState.addRegion();
243
244 auto moduleType = odsBuilder.getFunctionType(portTypes, {});
245 odsState.addAttribute(getFunctionTypeAttrName(odsState.name),
246 TypeAttr::get(moduleType));
247
248 odsState.addAttribute(SymbolTable::getSymbolAttrName(), name);
249 region->push_back(new Block);
250 region->addArguments(
251 portTypes,
252 SmallVector<Location>(portTypes.size(), odsBuilder.getUnknownLoc()));
253 odsState.addAttributes(attributes);
254}
255
256void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
257 StringAttr name, ArrayRef<hw::PortInfo> ports,
258 ArrayRef<NamedAttribute> attributes) {
259 MLIRContext *ctxt = odsBuilder.getContext();
260 SmallVector<Attribute> portNames;
261 SmallVector<Type> portTypes;
262 for (auto port : ports) {
263 portNames.push_back(StringAttr::get(ctxt, port.getName()));
264 portTypes.push_back(wrapPortType(port.type, port.dir));
265 }
266 build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
267}
268
269void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
270 StringAttr name, const hw::ModulePortInfo &ports,
271 ArrayRef<NamedAttribute> attributes) {
272 MLIRContext *ctxt = odsBuilder.getContext();
273 SmallVector<Attribute> portNames;
274 SmallVector<Type> portTypes;
275 for (auto port : ports) {
276 portNames.push_back(StringAttr::get(ctxt, port.getName()));
277 portTypes.push_back(wrapPortType(port.type, port.dir));
278 }
279 build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
280}
281
282void SCModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
283 mlir::OpAsmSetValueNameFn setNameFn) {
284 if (region.empty())
285 return;
286
287 ArrayAttr portNames = getPortNames();
288 for (size_t i = 0, e = getNumArguments(); i != e; ++i) {
289 auto str = cast<StringAttr>(portNames[i]).getValue();
290 setNameFn(getArgument(i), str);
291 }
292}
293
294LogicalResult SCModuleOp::verify() {
295 if (getFunctionType().getNumResults() != 0)
296 return emitOpError(
297 "incorrect number of function results (always has to be 0)");
298 if (getPortNames().size() != getFunctionType().getNumInputs())
299 return emitOpError("incorrect number of port names");
300
301 for (auto arg : getArguments()) {
302 if (!hw::type_isa<InputType, OutputType, InOutType>(arg.getType()))
303 return mlir::emitError(
304 arg.getLoc(),
305 "module port must be of type 'sc_in', 'sc_out', or 'sc_inout'");
306 }
307
308 for (auto portName : getPortNames()) {
309 if (cast<StringAttr>(portName).getValue().empty())
310 return emitOpError("port name must not be empty");
311 }
312
313 return success();
314}
315
316LogicalResult SCModuleOp::verifyRegions() {
317 auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
318 diag.attachNote(getLoc()) << "in module '@" << getModuleName() << "'";
319 };
320 return verifyUniqueNamesInRegion(getOperation(), getPortNames(), attachNote);
321}
322
323CtorOp SCModuleOp::getOrCreateCtor() {
324 CtorOp ctor;
325 getBody().walk([&](Operation *op) {
326 if ((ctor = dyn_cast<CtorOp>(op)))
327 return WalkResult::interrupt();
328
329 return WalkResult::skip();
330 });
331
332 if (ctor)
333 return ctor;
334
335 auto builder = OpBuilder(getBody());
336 return CtorOp::create(builder, getLoc());
337}
338
339DestructorOp SCModuleOp::getOrCreateDestructor() {
340 DestructorOp destructor;
341 getBody().walk([&](Operation *op) {
342 if ((destructor = dyn_cast<DestructorOp>(op)))
343 return WalkResult::interrupt();
344
345 return WalkResult::skip();
346 });
347
348 if (destructor)
349 return destructor;
350
351 auto builder = OpBuilder::atBlockEnd(getBodyBlock());
352 return DestructorOp::create(builder, getLoc());
353}
354
355//===----------------------------------------------------------------------===//
356// SignalOp
357//===----------------------------------------------------------------------===//
358
359void SignalOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
360 setNameFn(getSignal(), getName());
361}
362
363//===----------------------------------------------------------------------===//
364// ConvertOp
365//===----------------------------------------------------------------------===//
366
/// Fold redundant conversions.
///
/// Two cases are folded:
///  1. A conversion whose input already has the result type folds to its
///     input.
///  2. A chain `x -> intermediate -> x` folds to `x`. Here this op consumes
///     the result of another ConvertOp and converts back to that op's input
///     type. The fold happens only when the checks below accept the round
///     trip through the intermediate type. Any combination not explicitly
///     accepted is left unfolded.
OpFoldResult ConvertOp::fold(FoldAdaptor) {
  // Case 1: identity conversion.
  if (getInput().getType() == getResult().getType())
    return getInput();

  // Case 2: round trip through another ConvertOp.
  if (auto other = getInput().getDefiningOp<ConvertOp>()) {
    Type inputType = other.getInput().getType();
    Type intermediateType = getInput().getType();

    // Only a round trip back to the original type can fold.
    if (inputType != getResult().getType())
      return {};

    // Either both the input and intermediate types are signed or both are
    // unsigned.
    bool inputSigned = isa<SignedType, IntBaseType>(inputType);
    bool intermediateSigned = isa<SignedType, IntBaseType>(intermediateType);
    if (inputSigned ^ intermediateSigned)
      return {};

    // Converting 4-valued to 2-valued and back may lose information.
    if (isa<LogicVectorBaseType, LogicType>(inputType) &&
        !isa<LogicVectorBaseType, LogicType>(intermediateType))
      return {};

    // NOTE(review): getBitWidth appears to return an optional width that is
    // empty when the type has no static bit-width. This is inferred from the
    // presence checks and dereferences below; confirm against its definition.
    auto inputBw = getBitWidth(inputType);
    auto intermediateBw = getBitWidth(intermediateType);

    // Input without a static width, intermediate with one.
    if (!inputBw && intermediateBw) {
      // Only IntBaseType/UIntBaseType inputs are accepted, and only when the
      // intermediate is at least 64 bits wide. Presumably 64 is the widest
      // value these base types hold; confirm.
      if (isa<IntBaseType, UIntBaseType>(inputType) && *intermediateBw >= 64)
        return other.getInput();
      // We cannot support input types of signed, unsigned, and vector types
      // since they have no upper bound for the bit-width.
    }

    // Intermediate without a static width.
    if (!intermediateBw) {
      // BitVectorBaseType/LogicVectorBaseType intermediates are accepted for
      // any input that passed the checks above.
      if (isa<BitVectorBaseType, LogicVectorBaseType>(intermediateType))
        return other.getInput();

      // Width-less IntBaseType/UIntBaseType input through a SignedType or
      // UnsignedType intermediate.
      if (!inputBw && isa<IntBaseType, UIntBaseType>(inputType) &&
          isa<SignedType, UnsignedType>(intermediateType))
        return other.getInput();

      // Inputs with a static width of at most 64 bits through any of the
      // listed integer-like intermediates.
      if (inputBw && *inputBw <= 64 &&
          isa<IntBaseType, UIntBaseType, SignedType, UnsignedType>(
              intermediateType))
        return other.getInput();

      // We have to be careful with the signed and unsigned types as they often
      // have a max bit-width defined (that can be customized) and thus folding
      // here could change the behavior.
    }

    // Both widths are static: fold when the intermediate is not narrower than
    // the input.
    if (inputBw && intermediateBw && *inputBw <= *intermediateBw)
      return other.getInput();
  }

  return {};
}
424
425//===----------------------------------------------------------------------===//
426// CtorOp
427//===----------------------------------------------------------------------===//
428
429LogicalResult CtorOp::verify() {
430 if (getBody().getNumArguments() != 0)
431 return emitOpError("must not have any arguments");
432
433 return success();
434}
435
436//===----------------------------------------------------------------------===//
437// SCFuncOp
438//===----------------------------------------------------------------------===//
439
440void SCFuncOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
441 setNameFn(getHandle(), getName());
442}
443
444LogicalResult SCFuncOp::verify() {
445 if (getBody().getNumArguments() != 0)
446 return emitOpError("must not have any arguments");
447
448 return success();
449}
450
451//===----------------------------------------------------------------------===//
452// InstanceDeclOp
453//===----------------------------------------------------------------------===//
454
455void InstanceDeclOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
456 setNameFn(getInstanceHandle(), getName());
457}
458
459StringRef InstanceDeclOp::getInstanceName() { return getName(); }
460StringAttr InstanceDeclOp::getInstanceNameAttr() { return getNameAttr(); }
461
462Operation *
463InstanceDeclOp::getReferencedModuleCached(const hw::HWSymbolCache *cache) {
464 if (cache)
465 if (auto *result = cache->getDefinition(getModuleNameAttr()))
466 return result;
467
468 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
469 return topLevelModuleOp.lookupSymbol(getModuleName());
470}
471
472LogicalResult
473InstanceDeclOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
474 auto *module =
475 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
476 if (module == nullptr)
477 return emitError("cannot find module definition '")
478 << getModuleName() << "'";
479
480 auto emitError = [&](const std::function<void(InFlightDiagnostic & diag)> &fn)
481 -> LogicalResult {
482 auto diag = emitOpError();
483 fn(diag);
484 diag.attachNote(module->getLoc()) << "module declared here";
485 return failure();
486 };
487
488 // It must be a systemc module.
489 if (!isa<SCModuleOp>(module))
490 return emitError([&](auto &diag) {
491 diag << "symbol reference '" << getModuleName()
492 << "' isn't a systemc module";
493 });
494
495 auto scModule = cast<SCModuleOp>(module);
496
497 // Check that the module name of the symbol and instance type match.
498 if (scModule.getModuleName() != getInstanceType().getModuleName())
499 return emitError([&](auto &diag) {
500 diag << "module names must match; expected '" << scModule.getModuleName()
501 << "' but got '" << getInstanceType().getModuleName().getValue()
502 << "'";
503 });
504
505 // Check that port types and names are consistent with the referenced module.
506 ArrayRef<ModuleType::PortInfo> ports = getInstanceType().getPorts();
507 ArrayAttr modArgNames = scModule.getPortNames();
508 auto numPorts = ports.size();
509 auto expectedPortTypes = scModule.getArgumentTypes();
510
511 if (expectedPortTypes.size() != numPorts)
512 return emitError([&](auto &diag) {
513 diag << "has a wrong number of ports; expected "
514 << expectedPortTypes.size() << " but got " << numPorts;
515 });
516
517 for (size_t i = 0; i != numPorts; ++i) {
518 if (ports[i].type != expectedPortTypes[i]) {
519 return emitError([&](auto &diag) {
520 diag << "port type #" << i << " must be " << expectedPortTypes[i]
521 << ", but got " << ports[i].type;
522 });
523 }
524
525 if (ports[i].name != modArgNames[i])
526 return emitError([&](auto &diag) {
527 diag << "port name #" << i << " must be " << modArgNames[i]
528 << ", but got " << ports[i].name;
529 });
530 }
531
532 return success();
533}
534
535SmallVector<hw::PortInfo> InstanceDeclOp::getPortList() {
536 return cast<hw::PortList>(SymbolTable::lookupNearestSymbolFrom(
537 getOperation(), getReferencedModuleNameAttr()))
538 .getPortList();
539}
540
541//===----------------------------------------------------------------------===//
542// DestructorOp
543//===----------------------------------------------------------------------===//
544
545LogicalResult DestructorOp::verify() {
546 if (getBody().getNumArguments() != 0)
547 return emitOpError("must not have any arguments");
548
549 return success();
550}
551
552//===----------------------------------------------------------------------===//
553// BindPortOp
554//===----------------------------------------------------------------------===//
555
556ParseResult BindPortOp::parse(OpAsmParser &parser, OperationState &result) {
557 OpAsmParser::UnresolvedOperand instance, channel;
558 std::string portName;
559 if (parser.parseOperand(instance) || parser.parseLSquare() ||
560 parser.parseString(&portName))
561 return failure();
562
563 auto portNameLoc = parser.getCurrentLocation();
564
565 if (parser.parseRSquare() || parser.parseKeyword("to") ||
566 parser.parseOperand(channel))
567 return failure();
568
569 if (parser.parseOptionalAttrDict(result.attributes))
570 return failure();
571
572 auto typeListLoc = parser.getCurrentLocation();
573 SmallVector<Type> types;
574 if (parser.parseColonTypeList(types))
575 return failure();
576
577 if (types.size() != 2)
578 return parser.emitError(typeListLoc,
579 "expected a list of exactly 2 types, but got ")
580 << types.size();
581
582 if (parser.resolveOperand(instance, types[0], result.operands))
583 return failure();
584 if (parser.resolveOperand(channel, types[1], result.operands))
585 return failure();
586
587 if (auto moduleType = dyn_cast<ModuleType>(types[0])) {
588 auto ports = moduleType.getPorts();
589 uint64_t index = 0;
590 for (auto port : ports) {
591 if (port.name == portName)
592 break;
593 index++;
594 }
595 if (index >= ports.size())
596 return parser.emitError(portNameLoc, "port name \"")
597 << portName << "\" not found in module";
598
599 result.addAttribute("portId", parser.getBuilder().getIndexAttr(index));
600
601 return success();
602 }
603
604 return failure();
605}
606
607void BindPortOp::print(OpAsmPrinter &p) {
608 p << " " << getInstance() << "["
609 << cast<ModuleType>(getInstance().getType())
610 .getPorts()[getPortId().getZExtValue()]
611 .name
612 << "] to " << getChannel();
613 p.printOptionalAttrDict((*this)->getAttrs(), {"portId"});
614 p << " : " << getInstance().getType() << ", " << getChannel().getType();
615}
616
617LogicalResult BindPortOp::verify() {
618 auto ports = cast<ModuleType>(getInstance().getType()).getPorts();
619 if (getPortId().getZExtValue() >= ports.size())
620 return emitOpError("port #")
621 << getPortId().getZExtValue() << " does not exist, there are only "
622 << ports.size() << " ports";
623
624 // Verify that the base types match.
625 Type portType = ports[getPortId().getZExtValue()].type;
626 Type channelType = getChannel().getType();
627 if (getSignalBaseType(portType) != getSignalBaseType(channelType))
628 return emitOpError() << portType << " port cannot be bound to "
629 << channelType << " channel due to base type mismatch";
630
631 // Verify that the port/channel directions are valid.
632 if ((isa<InputType>(portType) && isa<OutputType>(channelType)) ||
633 (isa<OutputType>(portType) && isa<InputType>(channelType)))
634 return emitOpError() << portType << " port cannot be bound to "
635 << channelType
636 << " channel due to port direction mismatch";
637
638 return success();
639}
640
641StringRef BindPortOp::getPortName() {
642 return cast<ModuleType>(getInstance().getType())
643 .getPorts()[getPortId().getZExtValue()]
644 .name.getValue();
645}
646
647//===----------------------------------------------------------------------===//
648// SensitiveOp
649//===----------------------------------------------------------------------===//
650
651LogicalResult SensitiveOp::canonicalize(SensitiveOp op,
652 PatternRewriter &rewriter) {
653 if (op.getSensitivities().empty()) {
654 rewriter.eraseOp(op);
655 return success();
656 }
657
658 return failure();
659}
660
661//===----------------------------------------------------------------------===//
662// VariableOp
663//===----------------------------------------------------------------------===//
664
665void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
666 setNameFn(getVariable(), getName());
667}
668
669ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
670 StringAttr nameAttr;
671 if (parseImplicitSSAName(parser, nameAttr))
672 return failure();
673 result.addAttribute("name", nameAttr);
674
675 OpAsmParser::UnresolvedOperand init;
676 auto initResult = parser.parseOptionalOperand(init);
677
678 if (parser.parseOptionalAttrDict(result.attributes))
679 return failure();
680
681 Type variableType;
682 if (parser.parseColonType(variableType))
683 return failure();
684
685 if (initResult.has_value()) {
686 if (parser.resolveOperand(init, variableType, result.operands))
687 return failure();
688 }
689 result.addTypes({variableType});
690
691 return success();
692}
693
694void VariableOp::print(::mlir::OpAsmPrinter &p) {
695 p << " ";
696
697 if (getInit())
698 p << getInit() << " ";
699
700 p.printOptionalAttrDict(getOperation()->getAttrs(), {"name"});
701 p << ": " << getVariable().getType();
702}
703
704LogicalResult VariableOp::verify() {
705 if (getInit() && getInit().getType() != getVariable().getType())
706 return emitOpError(
707 "'init' and 'variable' must have the same type, but got ")
708 << getInit().getType() << " and " << getVariable().getType();
709
710 return success();
711}
712
713//===----------------------------------------------------------------------===//
714// InteropVerilatedOp
715//===----------------------------------------------------------------------===//
716
717/// Create a instance that refers to a known module.
718void InteropVerilatedOp::build(OpBuilder &odsBuilder, OperationState &odsState,
719 Operation *module, StringAttr name,
720 ArrayRef<Value> inputs) {
721 auto mod = cast<hw::HWModuleLike>(module);
722 auto argNames = odsBuilder.getArrayAttr(mod.getInputNames());
723 auto resultNames = odsBuilder.getArrayAttr(mod.getOutputNames());
724 build(odsBuilder, odsState, mod.getHWModuleType().getOutputTypes(), name,
725 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), argNames,
726 resultNames, inputs);
727}
728
729LogicalResult
730InteropVerilatedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
732 *this, getModuleNameAttr(), getInputs(), getResultTypes(),
733 getInputNames(), getResultNames(), ArrayAttr(), symbolTable);
734}
735
736/// Suggest a name for each result value based on the saved result names
737/// attribute.
738void InteropVerilatedOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
740 getResultNames(), getResults());
741}
742
743//===----------------------------------------------------------------------===//
744// CallOp
745//
// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored such that
// we can re-use its implementation here.
749//===----------------------------------------------------------------------===//
750
// FIXME: This is an exact copy from upstream
/// Verify the `callee` symbol. Checks that:
///  * the `callee` FlatSymbolRefAttr is present;
///  * it resolves (nearest symbol table) to a FuncOp;
///  * the call's operand and result counts and types match that function's
///    type.
LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Check that the callee attribute was specified.
  auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
  if (!fnAttr)
    return emitOpError("requires a 'callee' symbol reference attribute");
  FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
  if (!fn)
    return emitOpError() << "'" << fnAttr.getValue()
                         << "' does not reference a valid function";

  // Verify that the operand and result types match the callee.
  auto fnType = fn.getFunctionType();
  if (fnType.getNumInputs() != getNumOperands())
    return emitOpError("incorrect number of operands for callee");

  // Operand types are compared position by position; the first mismatch is
  // reported.
  for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
    if (getOperand(i).getType() != fnType.getInput(i))
      return emitOpError("operand type mismatch: expected operand type ")
             << fnType.getInput(i) << ", but provided "
             << getOperand(i).getType() << " for operand number " << i;

  if (fnType.getNumResults() != getNumResults())
    return emitOpError("incorrect number of results for callee");

  // On a result mismatch, attach both full result type lists as notes.
  for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
    if (getResult(i).getType() != fnType.getResult(i)) {
      auto diag = emitOpError("result type mismatch at index ") << i;
      diag.attachNote() << " op result types: " << getResultTypes();
      diag.attachNote() << "function result types: " << fnType.getResults();
      return diag;
    }

  return success();
}
786
787FunctionType CallOp::getCalleeType() {
788 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
789}
790
791// This verifier was added compared to the upstream implementation.
792LogicalResult CallOp::verify() {
793 if (getNumResults() > 1)
794 return emitOpError(
795 "incorrect number of function results (always has to be 0 or 1)");
796
797 return success();
798}
799
800//===----------------------------------------------------------------------===//
801// CallIndirectOp
802//===----------------------------------------------------------------------===//
803
804// This verifier was added compared to the upstream implementation.
805LogicalResult CallIndirectOp::verify() {
806 if (getNumResults() > 1)
807 return emitOpError(
808 "incorrect number of function results (always has to be 0 or 1)");
809
810 return success();
811}
812
813//===----------------------------------------------------------------------===//
814// FuncOp
815//
// TODO: Most of the implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored such that
// we can re-use its implementation here.
819//===----------------------------------------------------------------------===//
820
821// Note that the create and build operations are taken from upstream, but the
822// argNames argument was added.
823FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
824 FunctionType type, ArrayRef<NamedAttribute> attrs) {
825 OpBuilder builder(location->getContext());
826 OperationState state(location, getOperationName());
827 FuncOp::build(builder, state, name, argNames, type, attrs);
828 return cast<FuncOp>(Operation::create(state));
829}
830
831FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
832 FunctionType type, Operation::dialect_attr_range attrs) {
833 SmallVector<NamedAttribute, 8> attrRef(attrs);
834 return create(location, name, argNames, type, ArrayRef(attrRef));
835}
836
837FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
838 FunctionType type, ArrayRef<NamedAttribute> attrs,
839 ArrayRef<DictionaryAttr> argAttrs) {
840 FuncOp func = create(location, name, argNames, type, attrs);
841 func.setAllArgAttrs(argAttrs);
842 return func;
843}
844
845void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
846 StringRef name, ArrayAttr argNames, FunctionType type,
847 ArrayRef<NamedAttribute> attrs,
848 ArrayRef<DictionaryAttr> argAttrs) {
849 odsState.addAttribute(getArgNamesAttrName(odsState.name), argNames);
850 odsState.addAttribute(SymbolTable::getSymbolAttrName(),
851 odsBuilder.getStringAttr(name));
852 odsState.addAttribute(FuncOp::getFunctionTypeAttrName(odsState.name),
853 TypeAttr::get(type));
854 odsState.attributes.append(attrs.begin(), attrs.end());
855 odsState.addRegion();
856
857 if (argAttrs.empty())
858 return;
859 assert(type.getNumInputs() == argAttrs.size());
860 mlir::call_interface_impl::addArgAndResultAttrs(
861 odsBuilder, odsState, argAttrs,
862 /*resultAttrs=*/{}, FuncOp::getArgAttrsAttrName(odsState.name),
863 FuncOp::getResAttrsAttrName(odsState.name));
864}
865
866ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
867 auto buildFuncType =
868 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
869 mlir::function_interface_impl::VariadicFlag,
870 std::string &) { return builder.getFunctionType(argTypes, results); };
871
872 // This was added specifically for our implementation, upstream does not have
873 // this feature.
874 if (succeeded(parser.parseOptionalKeyword("externC")))
875 result.addAttribute(getExternCAttrName(result.name),
876 UnitAttr::get(result.getContext()));
877
878 // FIXME: below is an exact copy of the
879 // mlir::function_interface_impl::parseFunctionOp implementation, this was
880 // needed because we need to access the SSA names of the arguments.
881 SmallVector<OpAsmParser::Argument> entryArgs;
882 SmallVector<DictionaryAttr> resultAttrs;
883 SmallVector<Type> resultTypes;
884 auto &builder = parser.getBuilder();
885
886 // Parse visibility.
887 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
888
889 // Parse the name as a symbol.
890 StringAttr nameAttr;
891 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
892 result.attributes))
893 return failure();
894
895 // Parse the function signature.
896 mlir::SMLoc signatureLocation = parser.getCurrentLocation();
897 bool isVariadic = false;
898 if (mlir::function_interface_impl::parseFunctionSignatureWithArguments(
899 parser, false, entryArgs, isVariadic, resultTypes, resultAttrs))
900 return failure();
901
902 std::string errorMessage;
903 SmallVector<Type> argTypes;
904 argTypes.reserve(entryArgs.size());
905 for (auto &arg : entryArgs)
906 argTypes.push_back(arg.type);
907
908 Type type = buildFuncType(
909 builder, argTypes, resultTypes,
910 mlir::function_interface_impl::VariadicFlag(isVariadic), errorMessage);
911 if (!type) {
912 return parser.emitError(signatureLocation)
913 << "failed to construct function type"
914 << (errorMessage.empty() ? "" : ": ") << errorMessage;
915 }
916 result.addAttribute(FuncOp::getFunctionTypeAttrName(result.name),
917 TypeAttr::get(type));
918
919 // If function attributes are present, parse them.
920 NamedAttrList parsedAttributes;
921 mlir::SMLoc attributeDictLocation = parser.getCurrentLocation();
922 if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
923 return failure();
924
925 // Disallow attributes that are inferred from elsewhere in the attribute
926 // dictionary.
927 for (StringRef disallowed :
928 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
929 FuncOp::getFunctionTypeAttrName(result.name).getValue()}) {
930 if (parsedAttributes.get(disallowed))
931 return parser.emitError(attributeDictLocation, "'")
932 << disallowed
933 << "' is an inferred attribute and should not be specified in the "
934 "explicit attribute dictionary";
935 }
936 result.attributes.append(parsedAttributes);
937
938 // Add the attributes to the function arguments.
939 assert(resultAttrs.size() == resultTypes.size());
940 mlir::call_interface_impl::addArgAndResultAttrs(
941 builder, result, entryArgs, resultAttrs,
942 FuncOp::getArgAttrsAttrName(result.name),
943 FuncOp::getResAttrsAttrName(result.name));
944
945 // Parse the optional function body. The printer will not print the body if
946 // its empty, so disallow parsing of empty body in the parser.
947 auto *body = result.addRegion();
948 mlir::SMLoc loc = parser.getCurrentLocation();
949 mlir::OptionalParseResult parseResult =
950 parser.parseOptionalRegion(*body, entryArgs,
951 /*enableNameShadowing=*/false);
952 if (parseResult.has_value()) {
953 if (failed(*parseResult))
954 return failure();
955 // Function body was parsed, make sure its not empty.
956 if (body->empty())
957 return parser.emitError(loc, "expected non-empty function body");
958 }
959
960 // Everythink below is added compared to the upstream implemenation to handle
961 // argument names.
962 SmallVector<Attribute> argNames;
963 if (!entryArgs.empty() && !entryArgs.front().ssaName.name.empty()) {
964 for (auto &arg : entryArgs)
965 argNames.push_back(
966 StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front()));
967 }
968
969 result.addAttribute(getArgNamesAttrName(result.name),
970 ArrayAttr::get(parser.getContext(), argNames));
971
972 return success();
973}
974
975void FuncOp::print(OpAsmPrinter &p) {
976 if (getExternC())
977 p << " externC";
978
979 mlir::FunctionOpInterface op = *this;
980
981 // FIXME: inlined mlir::function_interface_impl::printFunctionOp because we
982 // need to elide more attributes
983
984 // Print the operation and the function name.
985 auto funcName =
986 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
987 .getValue();
988 p << ' ';
989
990 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
991 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
992 p << visibility.getValue() << ' ';
993 p.printSymbolName(funcName);
994
995 ArrayRef<Type> argTypes = op.getArgumentTypes();
996 ArrayRef<Type> resultTypes = op.getResultTypes();
997 mlir::function_interface_impl::printFunctionSignature(p, op, argTypes, false,
998 resultTypes);
999 mlir::function_interface_impl::printFunctionAttributes(
1000 p, op,
1001 {visibilityAttrName, "externC", "argNames", getFunctionTypeAttrName(),
1002 getArgAttrsAttrName(), getResAttrsAttrName()});
1003 // Print the body if this is not an external function.
1004 Region &body = op->getRegion(0);
1005 if (!body.empty()) {
1006 p << ' ';
1007 p.printRegion(body, /*printEntryBlockArgs=*/false,
1008 /*printBlockTerminators=*/true);
1009 }
1010}
1011
// FIXME: the clone operations below are exact copies from upstream.
1013
1014/// Clone the internal blocks from this function into dest and all attributes
1015/// from this function to dest.
1016void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
1017 // Add the attributes of this function to dest.
1018 llvm::MapVector<StringAttr, Attribute> newAttrMap;
1019 for (const auto &attr : dest->getAttrs())
1020 newAttrMap.insert({attr.getName(), attr.getValue()});
1021 for (const auto &attr : (*this)->getAttrs())
1022 newAttrMap.insert({attr.getName(), attr.getValue()});
1023
1024 auto newAttrs = llvm::to_vector(llvm::map_range(
1025 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
1026 return NamedAttribute(attrPair.first, attrPair.second);
1027 }));
1028 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
1029
1030 // Clone the body.
1031 getBody().cloneInto(&dest.getBody(), mapper);
1032}
1033
1034/// Create a deep copy of this function and all of its blocks, remapping
1035/// any operands that use values outside of the function using the map that is
1036/// provided (leaving them alone if no entry is present). Replaces references
1037/// to cloned sub-values with the corresponding value that is copied, and adds
1038/// those mappings to the mapper.
1039FuncOp FuncOp::clone(IRMapping &mapper) {
1040 // Create the new function.
1041 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
1042
1043 // If the function has a body, then the user might be deleting arguments to
1044 // the function by specifying them in the mapper. If so, we don't add the
1045 // argument to the input type vector.
1046 if (!isExternal()) {
1047 FunctionType oldType = getFunctionType();
1048
1049 unsigned oldNumArgs = oldType.getNumInputs();
1050 SmallVector<Type, 4> newInputs;
1051 newInputs.reserve(oldNumArgs);
1052 for (unsigned i = 0; i != oldNumArgs; ++i)
1053 if (!mapper.contains(getArgument(i)))
1054 newInputs.push_back(oldType.getInput(i));
1055
1056 /// If any of the arguments were dropped, update the type and drop any
1057 /// necessary argument attributes.
1058 if (newInputs.size() != oldNumArgs) {
1059 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
1060 oldType.getResults()));
1061
1062 if (ArrayAttr argAttrs = getAllArgAttrs()) {
1063 SmallVector<Attribute> newArgAttrs;
1064 newArgAttrs.reserve(newInputs.size());
1065 for (unsigned i = 0; i != oldNumArgs; ++i)
1066 if (!mapper.contains(getArgument(i)))
1067 newArgAttrs.push_back(argAttrs[i]);
1068 newFunc.setAllArgAttrs(newArgAttrs);
1069 }
1070 }
1071 }
1072
1073 /// Clone the current function into the new one and return it.
1074 cloneInto(newFunc, mapper);
1075 return newFunc;
1076}
1077
1078FuncOp FuncOp::clone() {
1079 IRMapping mapper;
1080 return clone(mapper);
1081}
1082
1083// The following functions are entirely new additions compared to upstream.
1084
1085void FuncOp::getAsmBlockArgumentNames(mlir::Region &region,
1086 mlir::OpAsmSetValueNameFn setNameFn) {
1087 if (region.empty())
1088 return;
1089
1090 for (auto [arg, name] : llvm::zip(getArguments(), getArgNames()))
1091 setNameFn(arg, cast<StringAttr>(name).getValue());
1092}
1093
1094LogicalResult FuncOp::verify() {
1095 if (getFunctionType().getNumResults() > 1)
1096 return emitOpError(
1097 "incorrect number of function results (always has to be 0 or 1)");
1098
1099 if (getBody().empty())
1100 return success();
1101
1102 if (getArgNames().size() != getFunctionType().getNumInputs())
1103 return emitOpError("incorrect number of argument names");
1104
1105 for (auto portName : getArgNames()) {
1106 if (cast<StringAttr>(portName).getValue().empty())
1107 return emitOpError("arg name must not be empty");
1108 }
1109
1110 return success();
1111}
1112
1113LogicalResult FuncOp::verifyRegions() {
1114 auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
1115 diag.attachNote(getLoc()) << "in function '@" << getName() << "'";
1116 };
1117 return verifyUniqueNamesInRegion(getOperation(), getArgNames(), attachNote);
1118}
1119
1120//===----------------------------------------------------------------------===//
1121// ReturnOp
1122//
// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored such that
// we can re-use the implementation here.
1126//===----------------------------------------------------------------------===//
1127
1128LogicalResult ReturnOp::verify() {
1129 auto function = cast<FuncOp>((*this)->getParentOp());
1130
1131 // The operand number and types must match the function signature.
1132 const auto &results = function.getFunctionType().getResults();
1133 if (getNumOperands() != results.size())
1134 return emitOpError("has ")
1135 << getNumOperands() << " operands, but enclosing function (@"
1136 << function.getName() << ") returns " << results.size();
1137
1138 for (unsigned i = 0, e = results.size(); i != e; ++i)
1139 if (getOperand(i).getType() != results[i])
1140 return emitError() << "type of return operand " << i << " ("
1141 << getOperand(i).getType()
1142 << ") doesn't match function result type ("
1143 << results[i] << ")"
1144 << " in function @" << function.getName();
1145
1146 return success();
1147}
1148
1149//===----------------------------------------------------------------------===//
1150// TableGen generated logic.
1151//===----------------------------------------------------------------------===//
1152
1153// Provide the autogenerated implementation guts for the Op classes.
1154#define GET_OP_CLASSES
1155#include "circt/Dialect/SystemC/SystemC.cpp.inc"
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:215
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
static hw::ModulePort::Direction getDirection(Type type)
static Type wrapPortType(Type type, hw::ModulePort::Direction direction)
static LogicalResult verifyUniqueNamesInRegion(Operation *operation, ArrayAttr argNames, std::function< void(mlir::InFlightDiagnostic &)> attachNote)
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
void printModuleSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes, bool &needArgNamesAttr)
Print a module signature with named results.
ParseResult parseModuleFunctionSignature(OpAsmParser &parser, bool &isVariadic, SmallVectorImpl< OpAsmParser::Argument > &args, SmallVectorImpl< Attribute > &argNames, SmallVectorImpl< Attribute > &argLocs, SmallVectorImpl< Attribute > &resultNames, SmallVectorImpl< DictionaryAttr > &resultAttrs, SmallVectorImpl< Attribute > &resultLocs, TypeAttr &type)
This is a variant of mlir::parseFunctionSignature that allows names on result arguments.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
void info(Twine message)
Definition LSPUtils.cpp:20
Type getSignalBaseType(Type type)
Get the type wrapped by a signal or port (in, inout, out) type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr)
Parse an implicit SSA name string attribute.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
This holds a decoded list of input/inout and output ports for a module or instance.
This holds the name, type, direction of a module's ports.