CIRCT 20.0.0git
Loading...
Searching...
No Matches
SystemCOps.cpp
Go to the documentation of this file.
1//===- SystemCOps.cpp - Implement the SystemC operations ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SystemC ops.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/FunctionImplementation.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace circt;
25using namespace circt::systemc;
26
27//===----------------------------------------------------------------------===//
28// Helpers
29//===----------------------------------------------------------------------===//
30
31static LogicalResult verifyUniqueNamesInRegion(
32 Operation *operation, ArrayAttr argNames,
33 std::function<void(mlir::InFlightDiagnostic &)> attachNote) {
34 DenseMap<StringRef, BlockArgument> portNames;
35 DenseMap<StringRef, Operation *> memberNames;
36 DenseMap<StringRef, Operation *> localNames;
37
38 if (operation->getNumRegions() != 1)
39 return operation->emitError("required to have exactly one region");
40
41 bool portsVerified = true;
42
43 for (auto arg : llvm::zip(argNames, operation->getRegion(0).getArguments())) {
44 StringRef argName = cast<StringAttr>(std::get<0>(arg)).getValue();
45 BlockArgument argValue = std::get<1>(arg);
46
47 if (portNames.count(argName)) {
48 auto diag = mlir::emitError(argValue.getLoc(), "redefines name '")
49 << argName << "'";
50 diag.attachNote(portNames[argName].getLoc())
51 << "'" << argName << "' first defined here";
52 attachNote(diag);
53 portsVerified = false;
54 continue;
55 }
56
57 portNames.insert({argName, argValue});
58 }
59
60 WalkResult result =
61 operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
62 if (isa<SCModuleOp>(op->getParentOp()))
63 localNames.clear();
64
65 if (auto nameDeclOp = dyn_cast<SystemCNameDeclOpInterface>(op)) {
66 StringRef name = nameDeclOp.getName();
67
68 auto reportNameRedefinition = [&](Location firstLoc) -> WalkResult {
69 auto diag = mlir::emitError(op->getLoc(), "redefines name '")
70 << name << "'";
71 diag.attachNote(firstLoc) << "'" << name << "' first defined here";
72 attachNote(diag);
73 return WalkResult::interrupt();
74 };
75
76 if (portNames.count(name))
77 return reportNameRedefinition(portNames[name].getLoc());
78 if (memberNames.count(name))
79 return reportNameRedefinition(memberNames[name]->getLoc());
80 if (localNames.count(name))
81 return reportNameRedefinition(localNames[name]->getLoc());
82
83 if (isa<SCModuleOp>(op->getParentOp()))
84 memberNames.insert({name, op});
85 else
86 localNames.insert({name, op});
87 }
88
89 return WalkResult::advance();
90 });
91
92 if (result.wasInterrupted() || !portsVerified)
93 return failure();
94
95 return success();
96}
97
98//===----------------------------------------------------------------------===//
99// SCModuleOp
100//===----------------------------------------------------------------------===//
101
103 return TypeSwitch<Type, hw::ModulePort::Direction>(type)
104 .Case<InOutType>([](auto ty) { return hw::ModulePort::Direction::InOut; })
105 .Case<InputType>([](auto ty) { return hw::ModulePort::Direction::Input; })
106 .Case<OutputType>(
107 [](auto ty) { return hw::ModulePort::Direction::Output; });
108}
109
110SCModuleOp::PortDirectionRange
111SCModuleOp::getPortsOfDirection(hw::ModulePort::Direction direction) {
112 std::function<bool(const BlockArgument &)> predicateFn =
113 [&](const BlockArgument &arg) -> bool {
114 return getDirection(arg.getType()) == direction;
115 };
116 return llvm::make_filter_range(getArguments(), predicateFn);
117}
118
119SmallVector<::circt::hw::PortInfo> SCModuleOp::getPortList() {
120 SmallVector<hw::PortInfo> ports;
121 for (int i = 0, e = getNumArguments(); i < e; ++i) {
122 hw::PortInfo info;
123 info.name = cast<StringAttr>(getPortNames()[i]);
124 info.type = getSignalBaseType(getArgument(i).getType());
125 info.dir = getDirection(info.type);
126 ports.push_back(info);
127 }
128 return ports;
129}
130
131mlir::Region *SCModuleOp::getCallableRegion() { return &getBody(); }
132
133StringRef SCModuleOp::getModuleName() {
134 return (*this)
135 ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
136 .getValue();
137}
138
139ParseResult SCModuleOp::parse(OpAsmParser &parser, OperationState &result) {
140
141 // Parse the visibility attribute.
142 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
143
144 // Parse the name as a symbol.
145 StringAttr moduleName;
146 if (parser.parseSymbolName(moduleName, SymbolTable::getSymbolAttrName(),
147 result.attributes))
148 return failure();
149
150 // Parse the function signature.
151 bool isVariadic = false;
152 SmallVector<OpAsmParser::Argument, 4> entryArgs;
153 SmallVector<Attribute> argNames;
154 SmallVector<Attribute> argLocs;
155 SmallVector<Attribute> resultNames;
156 SmallVector<DictionaryAttr> resultAttrs;
157 SmallVector<Attribute> resultLocs;
158 TypeAttr functionType;
160 parser, isVariadic, entryArgs, argNames, argLocs, resultNames,
161 resultAttrs, resultLocs, functionType)))
162 return failure();
163
164 // Parse the attribute dict.
165 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
166 return failure();
167
168 result.addAttribute("portNames",
169 ArrayAttr::get(parser.getContext(), argNames));
170
171 result.addAttribute(SCModuleOp::getFunctionTypeAttrName(result.name),
172 functionType);
173
174 mlir::function_interface_impl::addArgAndResultAttrs(
175 parser.getBuilder(), result, entryArgs, resultAttrs,
176 SCModuleOp::getArgAttrsAttrName(result.name),
177 SCModuleOp::getResAttrsAttrName(result.name));
178
179 auto &body = *result.addRegion();
180 if (parser.parseRegion(body, entryArgs))
181 return failure();
182 if (body.empty())
183 body.push_back(std::make_unique<Block>().release());
184
185 return success();
186}
187
188void SCModuleOp::print(OpAsmPrinter &p) {
189 p << ' ';
190
191 // Print the visibility of the module.
192 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
193 if (auto visibility =
194 getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
195 p << visibility.getValue() << ' ';
196
197 p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
198 p << ' ';
199
200 bool needArgNamesAttr = false;
202 p, *this, getFunctionType().getInputs(), false, {}, needArgNamesAttr);
203 mlir::function_interface_impl::printFunctionAttributes(
204 p, *this,
205 {"portNames", getFunctionTypeAttrName(), getArgAttrsAttrName(),
206 getResAttrsAttrName()});
207
208 p << ' ';
209 p.printRegion(getBody(), false, false);
210}
211
212/// Returns the argument types of this function.
213ArrayRef<Type> SCModuleOp::getArgumentTypes() {
214 return getFunctionType().getInputs();
215}
216
217/// Returns the result types of this function.
218ArrayRef<Type> SCModuleOp::getResultTypes() {
219 return getFunctionType().getResults();
220}
221
222static Type wrapPortType(Type type, hw::ModulePort::Direction direction) {
223 if (auto inoutTy = dyn_cast<hw::InOutType>(type))
224 type = inoutTy.getElementType();
225
226 switch (direction) {
227 case hw::ModulePort::Direction::InOut:
228 return InOutType::get(type);
229 case hw::ModulePort::Direction::Input:
230 return InputType::get(type);
231 case hw::ModulePort::Direction::Output:
232 return OutputType::get(type);
233 }
234 llvm_unreachable("Impossible port direction");
235}
236
237void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238 StringAttr name, ArrayAttr portNames,
239 ArrayRef<Type> portTypes,
240 ArrayRef<NamedAttribute> attributes) {
241 odsState.addAttribute(getPortNamesAttrName(odsState.name), portNames);
242 Region *region = odsState.addRegion();
243
244 auto moduleType = odsBuilder.getFunctionType(portTypes, {});
245 odsState.addAttribute(getFunctionTypeAttrName(odsState.name),
246 TypeAttr::get(moduleType));
247
248 odsState.addAttribute(SymbolTable::getSymbolAttrName(), name);
249 region->push_back(new Block);
250 region->addArguments(
251 portTypes,
252 SmallVector<Location>(portTypes.size(), odsBuilder.getUnknownLoc()));
253 odsState.addAttributes(attributes);
254}
255
256void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
257 StringAttr name, ArrayRef<hw::PortInfo> ports,
258 ArrayRef<NamedAttribute> attributes) {
259 MLIRContext *ctxt = odsBuilder.getContext();
260 SmallVector<Attribute> portNames;
261 SmallVector<Type> portTypes;
262 for (auto port : ports) {
263 portNames.push_back(StringAttr::get(ctxt, port.getName()));
264 portTypes.push_back(wrapPortType(port.type, port.dir));
265 }
266 build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
267}
268
269void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
270 StringAttr name, const hw::ModulePortInfo &ports,
271 ArrayRef<NamedAttribute> attributes) {
272 MLIRContext *ctxt = odsBuilder.getContext();
273 SmallVector<Attribute> portNames;
274 SmallVector<Type> portTypes;
275 for (auto port : ports) {
276 portNames.push_back(StringAttr::get(ctxt, port.getName()));
277 portTypes.push_back(wrapPortType(port.type, port.dir));
278 }
279 build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
280}
281
282void SCModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
283 mlir::OpAsmSetValueNameFn setNameFn) {
284 if (region.empty())
285 return;
286
287 ArrayAttr portNames = getPortNames();
288 for (size_t i = 0, e = getNumArguments(); i != e; ++i) {
289 auto str = cast<StringAttr>(portNames[i]).getValue();
290 setNameFn(getArgument(i), str);
291 }
292}
293
294LogicalResult SCModuleOp::verify() {
295 if (getFunctionType().getNumResults() != 0)
296 return emitOpError(
297 "incorrect number of function results (always has to be 0)");
298 if (getPortNames().size() != getFunctionType().getNumInputs())
299 return emitOpError("incorrect number of port names");
300
301 for (auto arg : getArguments()) {
302 if (!hw::type_isa<InputType, OutputType, InOutType>(arg.getType()))
303 return mlir::emitError(
304 arg.getLoc(),
305 "module port must be of type 'sc_in', 'sc_out', or 'sc_inout'");
306 }
307
308 for (auto portName : getPortNames()) {
309 if (cast<StringAttr>(portName).getValue().empty())
310 return emitOpError("port name must not be empty");
311 }
312
313 return success();
314}
315
316LogicalResult SCModuleOp::verifyRegions() {
317 auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
318 diag.attachNote(getLoc()) << "in module '@" << getModuleName() << "'";
319 };
320 return verifyUniqueNamesInRegion(getOperation(), getPortNames(), attachNote);
321}
322
323CtorOp SCModuleOp::getOrCreateCtor() {
324 CtorOp ctor;
325 getBody().walk([&](Operation *op) {
326 if ((ctor = dyn_cast<CtorOp>(op)))
327 return WalkResult::interrupt();
328
329 return WalkResult::skip();
330 });
331
332 if (ctor)
333 return ctor;
334
335 return OpBuilder(getBody()).create<CtorOp>(getLoc());
336}
337
338DestructorOp SCModuleOp::getOrCreateDestructor() {
339 DestructorOp destructor;
340 getBody().walk([&](Operation *op) {
341 if ((destructor = dyn_cast<DestructorOp>(op)))
342 return WalkResult::interrupt();
343
344 return WalkResult::skip();
345 });
346
347 if (destructor)
348 return destructor;
349
350 return OpBuilder::atBlockEnd(getBodyBlock()).create<DestructorOp>(getLoc());
351}
352
353//===----------------------------------------------------------------------===//
354// SignalOp
355//===----------------------------------------------------------------------===//
356
357void SignalOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
358 setNameFn(getSignal(), getName());
359}
360
361//===----------------------------------------------------------------------===//
362// ConvertOp
363//===----------------------------------------------------------------------===//
364
365OpFoldResult ConvertOp::fold(FoldAdaptor) {
366 if (getInput().getType() == getResult().getType())
367 return getInput();
368
369 if (auto other = getInput().getDefiningOp<ConvertOp>()) {
370 Type inputType = other.getInput().getType();
371 Type intermediateType = getInput().getType();
372
373 if (inputType != getResult().getType())
374 return {};
375
376 // Either both the input and intermediate types are signed or both are
377 // unsigned.
378 bool inputSigned = isa<SignedType, IntBaseType>(inputType);
379 bool intermediateSigned = isa<SignedType, IntBaseType>(intermediateType);
380 if (inputSigned ^ intermediateSigned)
381 return {};
382
383 // Converting 4-valued to 2-valued and back may lose information.
384 if (isa<LogicVectorBaseType, LogicType>(inputType) &&
385 !isa<LogicVectorBaseType, LogicType>(intermediateType))
386 return {};
387
388 auto inputBw = getBitWidth(inputType);
389 auto intermediateBw = getBitWidth(intermediateType);
390
391 if (!inputBw && intermediateBw) {
392 if (isa<IntBaseType, UIntBaseType>(inputType) && *intermediateBw >= 64)
393 return other.getInput();
394 // We cannot support input types of signed, unsigned, and vector types
395 // since they have no upper bound for the bit-width.
396 }
397
398 if (!intermediateBw) {
399 if (isa<BitVectorBaseType, LogicVectorBaseType>(intermediateType))
400 return other.getInput();
401
402 if (!inputBw && isa<IntBaseType, UIntBaseType>(inputType) &&
403 isa<SignedType, UnsignedType>(intermediateType))
404 return other.getInput();
405
406 if (inputBw && *inputBw <= 64 &&
407 isa<IntBaseType, UIntBaseType, SignedType, UnsignedType>(
408 intermediateType))
409 return other.getInput();
410
411 // We have to be careful with the signed and unsigned types as they often
412 // have a max bit-width defined (that can be customized) and thus folding
413 // here could change the behavior.
414 }
415
416 if (inputBw && intermediateBw && *inputBw <= *intermediateBw)
417 return other.getInput();
418 }
419
420 return {};
421}
422
423//===----------------------------------------------------------------------===//
424// CtorOp
425//===----------------------------------------------------------------------===//
426
427LogicalResult CtorOp::verify() {
428 if (getBody().getNumArguments() != 0)
429 return emitOpError("must not have any arguments");
430
431 return success();
432}
433
434//===----------------------------------------------------------------------===//
435// SCFuncOp
436//===----------------------------------------------------------------------===//
437
438void SCFuncOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
439 setNameFn(getHandle(), getName());
440}
441
442LogicalResult SCFuncOp::verify() {
443 if (getBody().getNumArguments() != 0)
444 return emitOpError("must not have any arguments");
445
446 return success();
447}
448
449//===----------------------------------------------------------------------===//
450// InstanceDeclOp
451//===----------------------------------------------------------------------===//
452
453void InstanceDeclOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
454 setNameFn(getInstanceHandle(), getName());
455}
456
457StringRef InstanceDeclOp::getInstanceName() { return getName(); }
458StringAttr InstanceDeclOp::getInstanceNameAttr() { return getNameAttr(); }
459
460Operation *
461InstanceDeclOp::getReferencedModuleCached(const hw::HWSymbolCache *cache) {
462 if (cache)
463 if (auto *result = cache->getDefinition(getModuleNameAttr()))
464 return result;
465
466 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
467 return topLevelModuleOp.lookupSymbol(getModuleName());
468}
469
470LogicalResult
471InstanceDeclOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
472 auto *module =
473 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
474 if (module == nullptr)
475 return emitError("cannot find module definition '")
476 << getModuleName() << "'";
477
478 auto emitError = [&](const std::function<void(InFlightDiagnostic & diag)> &fn)
479 -> LogicalResult {
480 auto diag = emitOpError();
481 fn(diag);
482 diag.attachNote(module->getLoc()) << "module declared here";
483 return failure();
484 };
485
486 // It must be a systemc module.
487 if (!isa<SCModuleOp>(module))
488 return emitError([&](auto &diag) {
489 diag << "symbol reference '" << getModuleName()
490 << "' isn't a systemc module";
491 });
492
493 auto scModule = cast<SCModuleOp>(module);
494
495 // Check that the module name of the symbol and instance type match.
496 if (scModule.getModuleName() != getInstanceType().getModuleName())
497 return emitError([&](auto &diag) {
498 diag << "module names must match; expected '" << scModule.getModuleName()
499 << "' but got '" << getInstanceType().getModuleName().getValue()
500 << "'";
501 });
502
503 // Check that port types and names are consistent with the referenced module.
504 ArrayRef<ModuleType::PortInfo> ports = getInstanceType().getPorts();
505 ArrayAttr modArgNames = scModule.getPortNames();
506 auto numPorts = ports.size();
507 auto expectedPortTypes = scModule.getArgumentTypes();
508
509 if (expectedPortTypes.size() != numPorts)
510 return emitError([&](auto &diag) {
511 diag << "has a wrong number of ports; expected "
512 << expectedPortTypes.size() << " but got " << numPorts;
513 });
514
515 for (size_t i = 0; i != numPorts; ++i) {
516 if (ports[i].type != expectedPortTypes[i]) {
517 return emitError([&](auto &diag) {
518 diag << "port type #" << i << " must be " << expectedPortTypes[i]
519 << ", but got " << ports[i].type;
520 });
521 }
522
523 if (ports[i].name != modArgNames[i])
524 return emitError([&](auto &diag) {
525 diag << "port name #" << i << " must be " << modArgNames[i]
526 << ", but got " << ports[i].name;
527 });
528 }
529
530 return success();
531}
532
533SmallVector<hw::PortInfo> InstanceDeclOp::getPortList() {
534 return cast<hw::PortList>(SymbolTable::lookupNearestSymbolFrom(
535 getOperation(), getReferencedModuleNameAttr()))
536 .getPortList();
537}
538
539//===----------------------------------------------------------------------===//
540// DestructorOp
541//===----------------------------------------------------------------------===//
542
543LogicalResult DestructorOp::verify() {
544 if (getBody().getNumArguments() != 0)
545 return emitOpError("must not have any arguments");
546
547 return success();
548}
549
550//===----------------------------------------------------------------------===//
551// BindPortOp
552//===----------------------------------------------------------------------===//
553
554ParseResult BindPortOp::parse(OpAsmParser &parser, OperationState &result) {
555 OpAsmParser::UnresolvedOperand instance, channel;
556 std::string portName;
557 if (parser.parseOperand(instance) || parser.parseLSquare() ||
558 parser.parseString(&portName))
559 return failure();
560
561 auto portNameLoc = parser.getCurrentLocation();
562
563 if (parser.parseRSquare() || parser.parseKeyword("to") ||
564 parser.parseOperand(channel))
565 return failure();
566
567 if (parser.parseOptionalAttrDict(result.attributes))
568 return failure();
569
570 auto typeListLoc = parser.getCurrentLocation();
571 SmallVector<Type> types;
572 if (parser.parseColonTypeList(types))
573 return failure();
574
575 if (types.size() != 2)
576 return parser.emitError(typeListLoc,
577 "expected a list of exactly 2 types, but got ")
578 << types.size();
579
580 if (parser.resolveOperand(instance, types[0], result.operands))
581 return failure();
582 if (parser.resolveOperand(channel, types[1], result.operands))
583 return failure();
584
585 if (auto moduleType = dyn_cast<ModuleType>(types[0])) {
586 auto ports = moduleType.getPorts();
587 uint64_t index = 0;
588 for (auto port : ports) {
589 if (port.name == portName)
590 break;
591 index++;
592 }
593 if (index >= ports.size())
594 return parser.emitError(portNameLoc, "port name \"")
595 << portName << "\" not found in module";
596
597 result.addAttribute("portId", parser.getBuilder().getIndexAttr(index));
598
599 return success();
600 }
601
602 return failure();
603}
604
605void BindPortOp::print(OpAsmPrinter &p) {
606 p << " " << getInstance() << "["
607 << cast<ModuleType>(getInstance().getType())
608 .getPorts()[getPortId().getZExtValue()]
609 .name
610 << "] to " << getChannel();
611 p.printOptionalAttrDict((*this)->getAttrs(), {"portId"});
612 p << " : " << getInstance().getType() << ", " << getChannel().getType();
613}
614
615LogicalResult BindPortOp::verify() {
616 auto ports = cast<ModuleType>(getInstance().getType()).getPorts();
617 if (getPortId().getZExtValue() >= ports.size())
618 return emitOpError("port #")
619 << getPortId().getZExtValue() << " does not exist, there are only "
620 << ports.size() << " ports";
621
622 // Verify that the base types match.
623 Type portType = ports[getPortId().getZExtValue()].type;
624 Type channelType = getChannel().getType();
625 if (getSignalBaseType(portType) != getSignalBaseType(channelType))
626 return emitOpError() << portType << " port cannot be bound to "
627 << channelType << " channel due to base type mismatch";
628
629 // Verify that the port/channel directions are valid.
630 if ((isa<InputType>(portType) && isa<OutputType>(channelType)) ||
631 (isa<OutputType>(portType) && isa<InputType>(channelType)))
632 return emitOpError() << portType << " port cannot be bound to "
633 << channelType
634 << " channel due to port direction mismatch";
635
636 return success();
637}
638
639StringRef BindPortOp::getPortName() {
640 return cast<ModuleType>(getInstance().getType())
641 .getPorts()[getPortId().getZExtValue()]
642 .name.getValue();
643}
644
645//===----------------------------------------------------------------------===//
646// SensitiveOp
647//===----------------------------------------------------------------------===//
648
649LogicalResult SensitiveOp::canonicalize(SensitiveOp op,
650 PatternRewriter &rewriter) {
651 if (op.getSensitivities().empty()) {
652 rewriter.eraseOp(op);
653 return success();
654 }
655
656 return failure();
657}
658
659//===----------------------------------------------------------------------===//
660// VariableOp
661//===----------------------------------------------------------------------===//
662
663void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
664 setNameFn(getVariable(), getName());
665}
666
667ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
668 StringAttr nameAttr;
669 if (parseImplicitSSAName(parser, nameAttr))
670 return failure();
671 result.addAttribute("name", nameAttr);
672
673 OpAsmParser::UnresolvedOperand init;
674 auto initResult = parser.parseOptionalOperand(init);
675
676 if (parser.parseOptionalAttrDict(result.attributes))
677 return failure();
678
679 Type variableType;
680 if (parser.parseColonType(variableType))
681 return failure();
682
683 if (initResult.has_value()) {
684 if (parser.resolveOperand(init, variableType, result.operands))
685 return failure();
686 }
687 result.addTypes({variableType});
688
689 return success();
690}
691
692void VariableOp::print(::mlir::OpAsmPrinter &p) {
693 p << " ";
694
695 if (getInit())
696 p << getInit() << " ";
697
698 p.printOptionalAttrDict(getOperation()->getAttrs(), {"name"});
699 p << ": " << getVariable().getType();
700}
701
702LogicalResult VariableOp::verify() {
703 if (getInit() && getInit().getType() != getVariable().getType())
704 return emitOpError(
705 "'init' and 'variable' must have the same type, but got ")
706 << getInit().getType() << " and " << getVariable().getType();
707
708 return success();
709}
710
711//===----------------------------------------------------------------------===//
712// InteropVerilatedOp
713//===----------------------------------------------------------------------===//
714
715/// Create a instance that refers to a known module.
716void InteropVerilatedOp::build(OpBuilder &odsBuilder, OperationState &odsState,
717 Operation *module, StringAttr name,
718 ArrayRef<Value> inputs) {
719 auto mod = cast<hw::HWModuleLike>(module);
720 auto argNames = odsBuilder.getArrayAttr(mod.getInputNames());
721 auto resultNames = odsBuilder.getArrayAttr(mod.getOutputNames());
722 build(odsBuilder, odsState, mod.getHWModuleType().getOutputTypes(), name,
723 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), argNames,
724 resultNames, inputs);
725}
726
727LogicalResult
728InteropVerilatedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
730 *this, getModuleNameAttr(), getInputs(), getResultTypes(),
731 getInputNames(), getResultNames(), ArrayAttr(), symbolTable);
732}
733
734/// Suggest a name for each result value based on the saved result names
735/// attribute.
736void InteropVerilatedOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
738 getResultNames(), getResults());
739}
740
741//===----------------------------------------------------------------------===//
742// CallOp
743//
744// TODO: The implementation for this operation was copy-pasted from the
745// 'func' dialect. Ideally, this upstream dialect refactored such that we can
746// re-use the implementation here.
747//===----------------------------------------------------------------------===//
748
749// FIXME: This is an exact copy from upstream
750LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
751 // Check that the callee attribute was specified.
752 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
753 if (!fnAttr)
754 return emitOpError("requires a 'callee' symbol reference attribute");
755 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
756 if (!fn)
757 return emitOpError() << "'" << fnAttr.getValue()
758 << "' does not reference a valid function";
759
760 // Verify that the operand and result types match the callee.
761 auto fnType = fn.getFunctionType();
762 if (fnType.getNumInputs() != getNumOperands())
763 return emitOpError("incorrect number of operands for callee");
764
765 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
766 if (getOperand(i).getType() != fnType.getInput(i))
767 return emitOpError("operand type mismatch: expected operand type ")
768 << fnType.getInput(i) << ", but provided "
769 << getOperand(i).getType() << " for operand number " << i;
770
771 if (fnType.getNumResults() != getNumResults())
772 return emitOpError("incorrect number of results for callee");
773
774 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
775 if (getResult(i).getType() != fnType.getResult(i)) {
776 auto diag = emitOpError("result type mismatch at index ") << i;
777 diag.attachNote() << " op result types: " << getResultTypes();
778 diag.attachNote() << "function result types: " << fnType.getResults();
779 return diag;
780 }
781
782 return success();
783}
784
785FunctionType CallOp::getCalleeType() {
786 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
787}
788
789// This verifier was added compared to the upstream implementation.
790LogicalResult CallOp::verify() {
791 if (getNumResults() > 1)
792 return emitOpError(
793 "incorrect number of function results (always has to be 0 or 1)");
794
795 return success();
796}
797
798//===----------------------------------------------------------------------===//
799// CallIndirectOp
800//===----------------------------------------------------------------------===//
801
802// This verifier was added compared to the upstream implementation.
803LogicalResult CallIndirectOp::verify() {
804 if (getNumResults() > 1)
805 return emitOpError(
806 "incorrect number of function results (always has to be 0 or 1)");
807
808 return success();
809}
810
811//===----------------------------------------------------------------------===//
812// FuncOp
813//
814// TODO: Most of the implementation for this operation was copy-pasted from the
815// 'func' dialect. Ideally, this upstream dialect refactored such that we can
816// re-use the implementation here.
817//===----------------------------------------------------------------------===//
818
819// Note that the create and build operations are taken from upstream, but the
820// argNames argument was added.
821FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
822 FunctionType type, ArrayRef<NamedAttribute> attrs) {
823 OpBuilder builder(location->getContext());
824 OperationState state(location, getOperationName());
825 FuncOp::build(builder, state, name, argNames, type, attrs);
826 return cast<FuncOp>(Operation::create(state));
827}
828
829FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
830 FunctionType type, Operation::dialect_attr_range attrs) {
831 SmallVector<NamedAttribute, 8> attrRef(attrs);
832 return create(location, name, argNames, type, ArrayRef(attrRef));
833}
834
835FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
836 FunctionType type, ArrayRef<NamedAttribute> attrs,
837 ArrayRef<DictionaryAttr> argAttrs) {
838 FuncOp func = create(location, name, argNames, type, attrs);
839 func.setAllArgAttrs(argAttrs);
840 return func;
841}
842
843void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
844 StringRef name, ArrayAttr argNames, FunctionType type,
845 ArrayRef<NamedAttribute> attrs,
846 ArrayRef<DictionaryAttr> argAttrs) {
847 odsState.addAttribute(getArgNamesAttrName(odsState.name), argNames);
848 odsState.addAttribute(SymbolTable::getSymbolAttrName(),
849 odsBuilder.getStringAttr(name));
850 odsState.addAttribute(FuncOp::getFunctionTypeAttrName(odsState.name),
851 TypeAttr::get(type));
852 odsState.attributes.append(attrs.begin(), attrs.end());
853 odsState.addRegion();
854
855 if (argAttrs.empty())
856 return;
857 assert(type.getNumInputs() == argAttrs.size());
858 mlir::function_interface_impl::addArgAndResultAttrs(
859 odsBuilder, odsState, argAttrs,
860 /*resultAttrs=*/std::nullopt, FuncOp::getArgAttrsAttrName(odsState.name),
861 FuncOp::getResAttrsAttrName(odsState.name));
862}
863
864ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
865 auto buildFuncType =
866 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
867 mlir::function_interface_impl::VariadicFlag,
868 std::string &) { return builder.getFunctionType(argTypes, results); };
869
870 // This was added specifically for our implementation, upstream does not have
871 // this feature.
872 if (succeeded(parser.parseOptionalKeyword("externC")))
873 result.addAttribute(getExternCAttrName(result.name),
874 UnitAttr::get(result.getContext()));
875
876 // FIXME: below is an exact copy of the
877 // mlir::function_interface_impl::parseFunctionOp implementation, this was
878 // needed because we need to access the SSA names of the arguments.
879 SmallVector<OpAsmParser::Argument> entryArgs;
880 SmallVector<DictionaryAttr> resultAttrs;
881 SmallVector<Type> resultTypes;
882 auto &builder = parser.getBuilder();
883
884 // Parse visibility.
885 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
886
887 // Parse the name as a symbol.
888 StringAttr nameAttr;
889 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
890 result.attributes))
891 return failure();
892
893 // Parse the function signature.
894 mlir::SMLoc signatureLocation = parser.getCurrentLocation();
895 bool isVariadic = false;
896 if (mlir::function_interface_impl::parseFunctionSignature(
897 parser, false, entryArgs, isVariadic, resultTypes, resultAttrs))
898 return failure();
899
900 std::string errorMessage;
901 SmallVector<Type> argTypes;
902 argTypes.reserve(entryArgs.size());
903 for (auto &arg : entryArgs)
904 argTypes.push_back(arg.type);
905
906 Type type = buildFuncType(
907 builder, argTypes, resultTypes,
908 mlir::function_interface_impl::VariadicFlag(isVariadic), errorMessage);
909 if (!type) {
910 return parser.emitError(signatureLocation)
911 << "failed to construct function type"
912 << (errorMessage.empty() ? "" : ": ") << errorMessage;
913 }
914 result.addAttribute(FuncOp::getFunctionTypeAttrName(result.name),
915 TypeAttr::get(type));
916
917 // If function attributes are present, parse them.
918 NamedAttrList parsedAttributes;
919 mlir::SMLoc attributeDictLocation = parser.getCurrentLocation();
920 if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
921 return failure();
922
923 // Disallow attributes that are inferred from elsewhere in the attribute
924 // dictionary.
925 for (StringRef disallowed :
926 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
927 FuncOp::getFunctionTypeAttrName(result.name).getValue()}) {
928 if (parsedAttributes.get(disallowed))
929 return parser.emitError(attributeDictLocation, "'")
930 << disallowed
931 << "' is an inferred attribute and should not be specified in the "
932 "explicit attribute dictionary";
933 }
934 result.attributes.append(parsedAttributes);
935
936 // Add the attributes to the function arguments.
937 assert(resultAttrs.size() == resultTypes.size());
938 mlir::function_interface_impl::addArgAndResultAttrs(
939 builder, result, entryArgs, resultAttrs,
940 FuncOp::getArgAttrsAttrName(result.name),
941 FuncOp::getResAttrsAttrName(result.name));
942
943 // Parse the optional function body. The printer will not print the body if
944 // its empty, so disallow parsing of empty body in the parser.
945 auto *body = result.addRegion();
946 mlir::SMLoc loc = parser.getCurrentLocation();
947 mlir::OptionalParseResult parseResult =
948 parser.parseOptionalRegion(*body, entryArgs,
949 /*enableNameShadowing=*/false);
950 if (parseResult.has_value()) {
951 if (failed(*parseResult))
952 return failure();
953 // Function body was parsed, make sure its not empty.
954 if (body->empty())
955 return parser.emitError(loc, "expected non-empty function body");
956 }
957
958 // Everythink below is added compared to the upstream implemenation to handle
959 // argument names.
960 SmallVector<Attribute> argNames;
961 if (!entryArgs.empty() && !entryArgs.front().ssaName.name.empty()) {
962 for (auto &arg : entryArgs)
963 argNames.push_back(
964 StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front()));
965 }
966
967 result.addAttribute(getArgNamesAttrName(result.name),
968 ArrayAttr::get(parser.getContext(), argNames));
969
970 return success();
971}
972
973void FuncOp::print(OpAsmPrinter &p) {
974 if (getExternC())
975 p << " externC";
976
977 mlir::FunctionOpInterface op = *this;
978
979 // FIXME: inlined mlir::function_interface_impl::printFunctionOp because we
980 // need to elide more attributes
981
982 // Print the operation and the function name.
983 auto funcName =
984 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
985 .getValue();
986 p << ' ';
987
988 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
989 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
990 p << visibility.getValue() << ' ';
991 p.printSymbolName(funcName);
992
993 ArrayRef<Type> argTypes = op.getArgumentTypes();
994 ArrayRef<Type> resultTypes = op.getResultTypes();
995 mlir::function_interface_impl::printFunctionSignature(p, op, argTypes, false,
996 resultTypes);
997 mlir::function_interface_impl::printFunctionAttributes(
998 p, op,
999 {visibilityAttrName, "externC", "argNames", getFunctionTypeAttrName(),
1000 getArgAttrsAttrName(), getResAttrsAttrName()});
1001 // Print the body if this is not an external function.
1002 Region &body = op->getRegion(0);
1003 if (!body.empty()) {
1004 p << ' ';
1005 p.printRegion(body, /*printEntryBlockArgs=*/false,
1006 /*printBlockTerminators=*/true);
1007 }
1008}
1009
// FIXME: the clone operations below are (near-)exact copies from upstream.
1011
1012/// Clone the internal blocks from this function into dest and all attributes
1013/// from this function to dest.
1014void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
1015 // Add the attributes of this function to dest.
1016 llvm::MapVector<StringAttr, Attribute> newAttrMap;
1017 for (const auto &attr : dest->getAttrs())
1018 newAttrMap.insert({attr.getName(), attr.getValue()});
1019 for (const auto &attr : (*this)->getAttrs())
1020 newAttrMap.insert({attr.getName(), attr.getValue()});
1021
1022 auto newAttrs = llvm::to_vector(llvm::map_range(
1023 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
1024 return NamedAttribute(attrPair.first, attrPair.second);
1025 }));
1026 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
1027
1028 // Clone the body.
1029 getBody().cloneInto(&dest.getBody(), mapper);
1030}
1031
1032/// Create a deep copy of this function and all of its blocks, remapping
1033/// any operands that use values outside of the function using the map that is
1034/// provided (leaving them alone if no entry is present). Replaces references
1035/// to cloned sub-values with the corresponding value that is copied, and adds
1036/// those mappings to the mapper.
1037FuncOp FuncOp::clone(IRMapping &mapper) {
1038 // Create the new function.
1039 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
1040
1041 // If the function has a body, then the user might be deleting arguments to
1042 // the function by specifying them in the mapper. If so, we don't add the
1043 // argument to the input type vector.
1044 if (!isExternal()) {
1045 FunctionType oldType = getFunctionType();
1046
1047 unsigned oldNumArgs = oldType.getNumInputs();
1048 SmallVector<Type, 4> newInputs;
1049 newInputs.reserve(oldNumArgs);
1050 for (unsigned i = 0; i != oldNumArgs; ++i)
1051 if (!mapper.contains(getArgument(i)))
1052 newInputs.push_back(oldType.getInput(i));
1053
1054 /// If any of the arguments were dropped, update the type and drop any
1055 /// necessary argument attributes.
1056 if (newInputs.size() != oldNumArgs) {
1057 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
1058 oldType.getResults()));
1059
1060 if (ArrayAttr argAttrs = getAllArgAttrs()) {
1061 SmallVector<Attribute> newArgAttrs;
1062 newArgAttrs.reserve(newInputs.size());
1063 for (unsigned i = 0; i != oldNumArgs; ++i)
1064 if (!mapper.contains(getArgument(i)))
1065 newArgAttrs.push_back(argAttrs[i]);
1066 newFunc.setAllArgAttrs(newArgAttrs);
1067 }
1068 }
1069 }
1070
1071 /// Clone the current function into the new one and return it.
1072 cloneInto(newFunc, mapper);
1073 return newFunc;
1074}
1075
1076FuncOp FuncOp::clone() {
1077 IRMapping mapper;
1078 return clone(mapper);
1079}
1080
1081// The following functions are entirely new additions compared to upstream.
1082
1083void FuncOp::getAsmBlockArgumentNames(mlir::Region &region,
1084 mlir::OpAsmSetValueNameFn setNameFn) {
1085 if (region.empty())
1086 return;
1087
1088 for (auto [arg, name] : llvm::zip(getArguments(), getArgNames()))
1089 setNameFn(arg, cast<StringAttr>(name).getValue());
1090}
1091
1092LogicalResult FuncOp::verify() {
1093 if (getFunctionType().getNumResults() > 1)
1094 return emitOpError(
1095 "incorrect number of function results (always has to be 0 or 1)");
1096
1097 if (getBody().empty())
1098 return success();
1099
1100 if (getArgNames().size() != getFunctionType().getNumInputs())
1101 return emitOpError("incorrect number of argument names");
1102
1103 for (auto portName : getArgNames()) {
1104 if (cast<StringAttr>(portName).getValue().empty())
1105 return emitOpError("arg name must not be empty");
1106 }
1107
1108 return success();
1109}
1110
1111LogicalResult FuncOp::verifyRegions() {
1112 auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
1113 diag.attachNote(getLoc()) << "in function '@" << getName() << "'";
1114 };
1115 return verifyUniqueNamesInRegion(getOperation(), getArgNames(), attachNote);
1116}
1117
1118//===----------------------------------------------------------------------===//
1119// ReturnOp
1120//
1121// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored so that we can
1123// re-use the implementation here.
1124//===----------------------------------------------------------------------===//
1125
1126LogicalResult ReturnOp::verify() {
1127 auto function = cast<FuncOp>((*this)->getParentOp());
1128
1129 // The operand number and types must match the function signature.
1130 const auto &results = function.getFunctionType().getResults();
1131 if (getNumOperands() != results.size())
1132 return emitOpError("has ")
1133 << getNumOperands() << " operands, but enclosing function (@"
1134 << function.getName() << ") returns " << results.size();
1135
1136 for (unsigned i = 0, e = results.size(); i != e; ++i)
1137 if (getOperand(i).getType() != results[i])
1138 return emitError() << "type of return operand " << i << " ("
1139 << getOperand(i).getType()
1140 << ") doesn't match function result type ("
1141 << results[i] << ")"
1142 << " in function @" << function.getName();
1143
1144 return success();
1145}
1146
1147//===----------------------------------------------------------------------===//
1148// TableGen generated logic.
1149//===----------------------------------------------------------------------===//
1150
1151// Provide the autogenerated implementation guts for the Op classes.
1152#define GET_OP_CLASSES
1153#include "circt/Dialect/SystemC/SystemC.cpp.inc"
assert(baseType &&"element must be base type")
static InstancePath empty
static Block * getBodyBlock(FModuleLike mod)
static hw::ModulePort::Direction getDirection(Type type)
static Type wrapPortType(Type type, hw::ModulePort::Direction direction)
static LogicalResult verifyUniqueNamesInRegion(Operation *operation, ArrayAttr argNames, std::function< void(mlir::InFlightDiagnostic &)> attachNote)
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:27
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:56
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
void printModuleSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes, bool &needArgNamesAttr)
Print a module signature with named results.
ParseResult parseModuleFunctionSignature(OpAsmParser &parser, bool &isVariadic, SmallVectorImpl< OpAsmParser::Argument > &args, SmallVectorImpl< Attribute > &argNames, SmallVectorImpl< Attribute > &argLocs, SmallVectorImpl< Attribute > &resultNames, SmallVectorImpl< DictionaryAttr > &resultAttrs, SmallVectorImpl< Attribute > &resultLocs, TypeAttr &type)
This is a variant of mlir::parseFunctionSignature that allows names on result arguments.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
Type getSignalBaseType(Type type)
Get the type wrapped by a signal or port (in, inout, out) type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr)
Parse an implicit SSA name string attribute.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:182
This holds a decoded list of input/inout and output ports for a module or instance.
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.