CIRCT 23.0.0git
Loading...
Searching...
No Matches
SystemCOps.cpp
Go to the documentation of this file.
1//===- SystemCOps.cpp - Implement the SystemC operations ------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the SystemC ops.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/PatternMatch.h"
20#include "mlir/Interfaces/FunctionImplementation.h"
21#include "llvm/ADT/MapVector.h"
22#include "llvm/ADT/TypeSwitch.h"
23
24using namespace circt;
25using namespace circt::systemc;
26
27//===----------------------------------------------------------------------===//
28// Helpers
29//===----------------------------------------------------------------------===//
30
31static LogicalResult verifyUniqueNamesInRegion(
32 Operation *operation, ArrayAttr argNames,
33 std::function<void(mlir::InFlightDiagnostic &)> attachNote) {
34 DenseMap<StringRef, BlockArgument> portNames;
35 DenseMap<StringRef, Operation *> memberNames;
36 DenseMap<StringRef, Operation *> localNames;
37
38 if (operation->getNumRegions() != 1)
39 return operation->emitError("required to have exactly one region");
40
41 bool portsVerified = true;
42
43 for (auto arg : llvm::zip(argNames, operation->getRegion(0).getArguments())) {
44 StringRef argName = cast<StringAttr>(std::get<0>(arg)).getValue();
45 BlockArgument argValue = std::get<1>(arg);
46
47 if (portNames.count(argName)) {
48 auto diag = mlir::emitError(argValue.getLoc(), "redefines name '")
49 << argName << "'";
50 diag.attachNote(portNames[argName].getLoc())
51 << "'" << argName << "' first defined here";
52 attachNote(diag);
53 portsVerified = false;
54 continue;
55 }
56
57 portNames.insert({argName, argValue});
58 }
59
60 WalkResult result =
61 operation->walk<mlir::WalkOrder::PreOrder>([&](Operation *op) {
62 if (isa<SCModuleOp>(op->getParentOp()))
63 localNames.clear();
64
65 if (auto nameDeclOp = dyn_cast<SystemCNameDeclOpInterface>(op)) {
66 StringRef name = nameDeclOp.getName();
67
68 auto reportNameRedefinition = [&](Location firstLoc) -> WalkResult {
69 auto diag = mlir::emitError(op->getLoc(), "redefines name '")
70 << name << "'";
71 diag.attachNote(firstLoc) << "'" << name << "' first defined here";
72 attachNote(diag);
73 return WalkResult::interrupt();
74 };
75
76 if (portNames.count(name))
77 return reportNameRedefinition(portNames[name].getLoc());
78 if (memberNames.count(name))
79 return reportNameRedefinition(memberNames[name]->getLoc());
80 if (localNames.count(name))
81 return reportNameRedefinition(localNames[name]->getLoc());
82
83 if (isa<SCModuleOp>(op->getParentOp()))
84 memberNames.insert({name, op});
85 else
86 localNames.insert({name, op});
87 }
88
89 return WalkResult::advance();
90 });
91
92 if (result.wasInterrupted() || !portsVerified)
93 return failure();
94
95 return success();
96}
97
98//===----------------------------------------------------------------------===//
99// SCModuleOp
100//===----------------------------------------------------------------------===//
101
103 return TypeSwitch<Type, hw::ModulePort::Direction>(type)
104 .Case<InOutType>([](auto ty) { return hw::ModulePort::Direction::InOut; })
105 .Case<InputType>([](auto ty) { return hw::ModulePort::Direction::Input; })
106 .Case<OutputType>(
107 [](auto ty) { return hw::ModulePort::Direction::Output; });
108}
109
110SCModuleOp::PortDirectionRange
111SCModuleOp::getPortsOfDirection(hw::ModulePort::Direction direction) {
112 std::function<bool(const BlockArgument &)> predicateFn =
113 [&](const BlockArgument &arg) -> bool {
114 return getDirection(arg.getType()) == direction;
115 };
116 return llvm::make_filter_range(getArguments(), predicateFn);
117}
118
119SmallVector<::circt::hw::PortInfo> SCModuleOp::getPortList() {
120 SmallVector<hw::PortInfo> ports;
121 for (int i = 0, e = getNumArguments(); i < e; ++i) {
123 info.name = cast<StringAttr>(getPortNames()[i]);
124 info.type = getSignalBaseType(getArgument(i).getType());
125 info.dir = getDirection(info.type);
126 ports.push_back(info);
127 }
128 return ports;
129}
130
131mlir::Region *SCModuleOp::getCallableRegion() { return &getBody(); }
132
133StringRef SCModuleOp::getModuleName() {
134 return (*this)
135 ->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
136 .getValue();
137}
138
/// Parse a SystemC module.
///
/// The parsed pieces are, in order: an optional visibility keyword, the symbol
/// name, the module signature, an optional `attributes {...}` dictionary, and
/// the body region. The port names collected while parsing the signature are
/// stored in the `portNames` attribute and the signature in the function-type
/// attribute. The body region always ends up with at least one block.
ParseResult SCModuleOp::parse(OpAsmParser &parser, OperationState &result) {

  // Parse the visibility attribute.
  (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);

  // Parse the name as a symbol.
  StringAttr moduleName;
  if (parser.parseSymbolName(moduleName, SymbolTable::getSymbolAttrName(),
                             result.attributes))
    return failure();

  // Parse the function signature.
  bool isVariadic = false;
  SmallVector<OpAsmParser::Argument, 4> entryArgs;
  SmallVector<Attribute> argNames;
  SmallVector<Attribute> argLocs;
  SmallVector<Attribute> resultNames;
  SmallVector<DictionaryAttr> resultAttrs;
  SmallVector<Attribute> resultLocs;
  TypeAttr functionType;
  // NOTE(review): source line 159 -- the head of the signature-parsing call
  // whose arguments follow -- is missing from this listing. Restore it from
  // the upstream file; this block does not compile without it.
          parser, isVariadic, entryArgs, argNames, argLocs, resultNames,
          resultAttrs, resultLocs, functionType)))
    return failure();

  // Parse the attribute dict.
  if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
    return failure();

  // Keep the port names gathered from the signature as an attribute.
  result.addAttribute("portNames",
                      ArrayAttr::get(parser.getContext(), argNames));

  result.addAttribute(SCModuleOp::getFunctionTypeAttrName(result.name),
                      functionType);

  // Attach the per-argument and per-result attribute dictionaries.
  mlir::call_interface_impl::addArgAndResultAttrs(
      parser.getBuilder(), result, entryArgs, resultAttrs,
      SCModuleOp::getArgAttrsAttrName(result.name),
      SCModuleOp::getResAttrsAttrName(result.name));

  auto &body = *result.addRegion();
  if (parser.parseRegion(body, entryArgs))
    return failure();
  // Guarantee that the body region has at least one block.
  if (body.empty())
    body.push_back(std::make_unique<Block>().release());

  return success();
}
187
/// Print a SystemC module.
///
/// The output is: the visibility (if set), the symbol name, the module
/// signature, the remaining attributes, and the body region. `portNames`, the
/// function type, and the arg/result attribute dictionaries are elided from
/// the attribute dictionary. The body is printed without its entry-block
/// arguments and without block terminators.
void SCModuleOp::print(OpAsmPrinter &p) {
  p << ' ';

  // Print the visibility of the module.
  StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
  if (auto visibility =
          getOperation()->getAttrOfType<StringAttr>(visibilityAttrName))
    p << visibility.getValue() << ' ';

  p.printSymbolName(SymbolTable::getSymbolName(*this).getValue());
  p << ' ';

  bool needArgNamesAttr = false;
  // NOTE(review): source line 201 -- the head of the signature-printing call
  // whose arguments follow -- is missing from this listing. Restore it from
  // the upstream file; this block does not compile without it.
      p, *this, getFunctionType().getInputs(), false, {}, needArgNamesAttr);
  mlir::function_interface_impl::printFunctionAttributes(
      p, *this,
      {"portNames", getFunctionTypeAttrName(), getArgAttrsAttrName(),
       getResAttrsAttrName()});

  p << ' ';
  // printRegion(region, printEntryBlockArgs=false, printBlockTerminators=false)
  p.printRegion(getBody(), false, false);
}
211
212/// Returns the argument types of this function.
213ArrayRef<Type> SCModuleOp::getArgumentTypes() {
214 return getFunctionType().getInputs();
215}
216
217/// Returns the result types of this function.
218ArrayRef<Type> SCModuleOp::getResultTypes() {
219 return getFunctionType().getResults();
220}
221
222static Type wrapPortType(Type type, hw::ModulePort::Direction direction) {
223 if (auto inoutTy = dyn_cast<hw::InOutType>(type))
224 type = inoutTy.getElementType();
225
226 switch (direction) {
227 case hw::ModulePort::Direction::InOut:
228 return InOutType::get(type);
229 case hw::ModulePort::Direction::Input:
230 return InputType::get(type);
231 case hw::ModulePort::Direction::Output:
232 return OutputType::get(type);
233 }
234 llvm_unreachable("Impossible port direction");
235}
236
237void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
238 StringAttr name, ArrayAttr portNames,
239 ArrayRef<Type> portTypes,
240 ArrayRef<NamedAttribute> attributes) {
241 odsState.addAttribute(getPortNamesAttrName(odsState.name), portNames);
242 Region *region = odsState.addRegion();
243
244 auto moduleType = odsBuilder.getFunctionType(portTypes, {});
245 odsState.addAttribute(getFunctionTypeAttrName(odsState.name),
246 TypeAttr::get(moduleType));
247
248 odsState.addAttribute(SymbolTable::getSymbolAttrName(), name);
249 region->push_back(new Block);
250 region->addArguments(
251 portTypes,
252 SmallVector<Location>(portTypes.size(), odsBuilder.getUnknownLoc()));
253 odsState.addAttributes(attributes);
254}
255
256void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
257 StringAttr name, ArrayRef<hw::PortInfo> ports,
258 ArrayRef<NamedAttribute> attributes) {
259 MLIRContext *ctxt = odsBuilder.getContext();
260 SmallVector<Attribute> portNames;
261 SmallVector<Type> portTypes;
262 for (auto port : ports) {
263 portNames.push_back(StringAttr::get(ctxt, port.getName()));
264 portTypes.push_back(wrapPortType(port.type, port.dir));
265 }
266 build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
267}
268
269void SCModuleOp::build(OpBuilder &odsBuilder, OperationState &odsState,
270 StringAttr name, const hw::ModulePortInfo &ports,
271 ArrayRef<NamedAttribute> attributes) {
272 MLIRContext *ctxt = odsBuilder.getContext();
273 SmallVector<Attribute> portNames;
274 SmallVector<Type> portTypes;
275 for (auto port : ports) {
276 portNames.push_back(StringAttr::get(ctxt, port.getName()));
277 portTypes.push_back(wrapPortType(port.type, port.dir));
278 }
279 build(odsBuilder, odsState, name, ArrayAttr::get(ctxt, portNames), portTypes);
280}
281
282void SCModuleOp::getAsmBlockArgumentNames(mlir::Region &region,
283 mlir::OpAsmSetValueNameFn setNameFn) {
284 if (region.empty())
285 return;
286
287 ArrayAttr portNames = getPortNames();
288 for (size_t i = 0, e = getNumArguments(); i != e; ++i) {
289 auto str = cast<StringAttr>(portNames[i]).getValue();
290 setNameFn(getArgument(i), str);
291 }
292}
293
294LogicalResult SCModuleOp::verify() {
295 if (getFunctionType().getNumResults() != 0)
296 return emitOpError(
297 "incorrect number of function results (always has to be 0)");
298 if (getPortNames().size() != getFunctionType().getNumInputs())
299 return emitOpError("incorrect number of port names");
300
301 for (auto arg : getArguments()) {
302 if (!hw::type_isa<InputType, OutputType, InOutType>(arg.getType()))
303 return mlir::emitError(
304 arg.getLoc(),
305 "module port must be of type 'sc_in', 'sc_out', or 'sc_inout'");
306 }
307
308 for (auto portName : getPortNames()) {
309 if (cast<StringAttr>(portName).getValue().empty())
310 return emitOpError("port name must not be empty");
311 }
312
313 return success();
314}
315
316LogicalResult SCModuleOp::verifyRegions() {
317 auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
318 diag.attachNote(getLoc()) << "in module '@" << getModuleName() << "'";
319 };
320 return verifyUniqueNamesInRegion(getOperation(), getPortNames(), attachNote);
321}
322
323CtorOp SCModuleOp::getOrCreateCtor(OpBuilder &builder) {
324 CtorOp ctor;
325 getBody().walk([&](Operation *op) {
326 if ((ctor = dyn_cast<CtorOp>(op)))
327 return WalkResult::interrupt();
328
329 return WalkResult::skip();
330 });
331
332 if (ctor)
333 return ctor;
334
335 OpBuilder::InsertionGuard guard(builder);
336 builder.setInsertionPoint(getBodyBlock(), getBodyBlock()->begin());
337 return CtorOp::create(builder, getLoc());
338}
339
340DestructorOp SCModuleOp::getOrCreateDestructor() {
341 DestructorOp destructor;
342 getBody().walk([&](Operation *op) {
343 if ((destructor = dyn_cast<DestructorOp>(op)))
344 return WalkResult::interrupt();
345
346 return WalkResult::skip();
347 });
348
349 if (destructor)
350 return destructor;
351
352 auto builder = OpBuilder::atBlockEnd(getBodyBlock());
353 return DestructorOp::create(builder, getLoc());
354}
355
356//===----------------------------------------------------------------------===//
357// SignalOp
358//===----------------------------------------------------------------------===//
359
360void SignalOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
361 setNameFn(getSignal(), getName());
362}
363
364//===----------------------------------------------------------------------===//
365// ConvertOp
366//===----------------------------------------------------------------------===//
367
/// Fold `ConvertOp`.
///
///  1. A conversion whose input type equals its result type folds to its
///     input.
///  2. A round trip `x : A -> B -> A` folds to `x`. Here this op converts the
///     result of another ConvertOp back to that op's input type A. The fold
///     applies only when the checks below accept the intermediate type B.
///
/// Returns a null OpFoldResult when no fold applies.
OpFoldResult ConvertOp::fold(FoldAdaptor) {
  if (getInput().getType() == getResult().getType())
    return getInput();

  if (auto other = getInput().getDefiningOp<ConvertOp>()) {
    Type inputType = other.getInput().getType();
    Type intermediateType = getInput().getType();

    // Only a round trip back to the original type can be folded away.
    if (inputType != getResult().getType())
      return {};

    // Either both the input and intermediate types are signed or both are
    // unsigned.
    bool inputSigned = isa<SignedType, IntBaseType>(inputType);
    bool intermediateSigned = isa<SignedType, IntBaseType>(intermediateType);
    if (inputSigned ^ intermediateSigned)
      return {};

    // Converting 4-valued to 2-valued and back may lose information.
    if (isa<LogicVectorBaseType, LogicType>(inputType) &&
        !isa<LogicVectorBaseType, LogicType>(intermediateType))
      return {};

    // getBitWidth yields an optional-like value that is empty when no width
    // is available (presumably for types whose width is not fixed
    // statically -- confirm against getBitWidth's definition).
    auto inputBw = getBitWidth(inputType);
    auto intermediateBw = getBitWidth(intermediateType);

    // Case: input width unknown, intermediate width known.
    if (!inputBw && intermediateBw) {
      // NOTE(review): relies on IntBaseType/UIntBaseType values being at most
      // 64 bits wide, so a >= 64-bit intermediate holds them -- confirm.
      if (isa<IntBaseType, UIntBaseType>(inputType) && *intermediateBw >= 64)
        return other.getInput();
      // We cannot support input types of signed, unsigned, and vector types
      // since they have no upper bound for the bit-width.
    }

    // Case: intermediate width unknown.
    if (!intermediateBw) {
      // Round trips through a bit-vector or logic-vector base type are
      // always folded here.
      if (isa<BitVectorBaseType, LogicVectorBaseType>(intermediateType))
        return other.getInput();

      // int/uint base input of unknown width through a signed/unsigned
      // intermediate.
      if (!inputBw && isa<IntBaseType, UIntBaseType>(inputType) &&
          isa<SignedType, UnsignedType>(intermediateType))
        return other.getInput();

      // Input of known width <= 64 through an int/uint base or
      // signed/unsigned intermediate.
      if (inputBw && *inputBw <= 64 &&
          isa<IntBaseType, UIntBaseType, SignedType, UnsignedType>(
              intermediateType))
        return other.getInput();

      // We have to be careful with the signed and unsigned types as they often
      // have a max bit-width defined (that can be customized) and thus folding
      // here could change the behavior.
    }

    // Case: both widths known; the intermediate must be at least as wide.
    if (inputBw && intermediateBw && *inputBw <= *intermediateBw)
      return other.getInput();
  }

  return {};
}
425
426//===----------------------------------------------------------------------===//
427// CtorOp
428//===----------------------------------------------------------------------===//
429
430LogicalResult CtorOp::verify() {
431 if (getBody().getNumArguments() != 0)
432 return emitOpError("must not have any arguments");
433
434 return success();
435}
436
437//===----------------------------------------------------------------------===//
438// SCFuncOp
439//===----------------------------------------------------------------------===//
440
441void SCFuncOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
442 setNameFn(getHandle(), getName());
443}
444
445LogicalResult SCFuncOp::verify() {
446 if (getBody().getNumArguments() != 0)
447 return emitOpError("must not have any arguments");
448
449 return success();
450}
451
452//===----------------------------------------------------------------------===//
453// InstanceDeclOp
454//===----------------------------------------------------------------------===//
455
456void InstanceDeclOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
457 setNameFn(getInstanceHandle(), getName());
458}
459
460StringRef InstanceDeclOp::getInstanceName() { return getName(); }
461StringAttr InstanceDeclOp::getInstanceNameAttr() { return getNameAttr(); }
462
463Operation *
464InstanceDeclOp::getReferencedModuleCached(const hw::HWSymbolCache *cache) {
465 if (cache)
466 if (auto *result = cache->getDefinition(getModuleNameAttr()))
467 return result;
468
469 auto topLevelModuleOp = (*this)->getParentOfType<ModuleOp>();
470 return topLevelModuleOp.lookupSymbol(getModuleName());
471}
472
473LogicalResult
474InstanceDeclOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
475 auto *module =
476 symbolTable.lookupNearestSymbolFrom(*this, getModuleNameAttr());
477 if (module == nullptr)
478 return emitError("cannot find module definition '")
479 << getModuleName() << "'";
480
481 auto emitError = [&](const std::function<void(InFlightDiagnostic & diag)> &fn)
482 -> LogicalResult {
483 auto diag = emitOpError();
484 fn(diag);
485 diag.attachNote(module->getLoc()) << "module declared here";
486 return failure();
487 };
488
489 // It must be a systemc module.
490 if (!isa<SCModuleOp>(module))
491 return emitError([&](auto &diag) {
492 diag << "symbol reference '" << getModuleName()
493 << "' isn't a systemc module";
494 });
495
496 auto scModule = cast<SCModuleOp>(module);
497
498 // Check that the module name of the symbol and instance type match.
499 if (scModule.getModuleName() != getInstanceType().getModuleName())
500 return emitError([&](auto &diag) {
501 diag << "module names must match; expected '" << scModule.getModuleName()
502 << "' but got '" << getInstanceType().getModuleName().getValue()
503 << "'";
504 });
505
506 // Check that port types and names are consistent with the referenced module.
507 ArrayRef<ModuleType::PortInfo> ports = getInstanceType().getPorts();
508 ArrayAttr modArgNames = scModule.getPortNames();
509 auto numPorts = ports.size();
510 auto expectedPortTypes = scModule.getArgumentTypes();
511
512 if (expectedPortTypes.size() != numPorts)
513 return emitError([&](auto &diag) {
514 diag << "has a wrong number of ports; expected "
515 << expectedPortTypes.size() << " but got " << numPorts;
516 });
517
518 for (size_t i = 0; i != numPorts; ++i) {
519 if (ports[i].type != expectedPortTypes[i]) {
520 return emitError([&](auto &diag) {
521 diag << "port type #" << i << " must be " << expectedPortTypes[i]
522 << ", but got " << ports[i].type;
523 });
524 }
525
526 if (ports[i].name != modArgNames[i])
527 return emitError([&](auto &diag) {
528 diag << "port name #" << i << " must be " << modArgNames[i]
529 << ", but got " << ports[i].name;
530 });
531 }
532
533 return success();
534}
535
536SmallVector<hw::PortInfo> InstanceDeclOp::getPortList() {
537 return cast<hw::PortList>(SymbolTable::lookupNearestSymbolFrom(
538 getOperation(), getReferencedModuleNameAttr()))
539 .getPortList();
540}
541
542//===----------------------------------------------------------------------===//
543// DestructorOp
544//===----------------------------------------------------------------------===//
545
546LogicalResult DestructorOp::verify() {
547 if (getBody().getNumArguments() != 0)
548 return emitOpError("must not have any arguments");
549
550 return success();
551}
552
553//===----------------------------------------------------------------------===//
554// BindPortOp
555//===----------------------------------------------------------------------===//
556
557ParseResult BindPortOp::parse(OpAsmParser &parser, OperationState &result) {
558 OpAsmParser::UnresolvedOperand instance, channel;
559 std::string portName;
560 if (parser.parseOperand(instance) || parser.parseLSquare() ||
561 parser.parseString(&portName))
562 return failure();
563
564 auto portNameLoc = parser.getCurrentLocation();
565
566 if (parser.parseRSquare() || parser.parseKeyword("to") ||
567 parser.parseOperand(channel))
568 return failure();
569
570 if (parser.parseOptionalAttrDict(result.attributes))
571 return failure();
572
573 auto typeListLoc = parser.getCurrentLocation();
574 SmallVector<Type> types;
575 if (parser.parseColonTypeList(types))
576 return failure();
577
578 if (types.size() != 2)
579 return parser.emitError(typeListLoc,
580 "expected a list of exactly 2 types, but got ")
581 << types.size();
582
583 if (parser.resolveOperand(instance, types[0], result.operands))
584 return failure();
585 if (parser.resolveOperand(channel, types[1], result.operands))
586 return failure();
587
588 if (auto moduleType = dyn_cast<ModuleType>(types[0])) {
589 auto ports = moduleType.getPorts();
590 uint64_t index = 0;
591 for (auto port : ports) {
592 if (port.name == portName)
593 break;
594 index++;
595 }
596 if (index >= ports.size())
597 return parser.emitError(portNameLoc, "port name \"")
598 << portName << "\" not found in module";
599
600 result.addAttribute("portId", parser.getBuilder().getIndexAttr(index));
601
602 return success();
603 }
604
605 return failure();
606}
607
608void BindPortOp::print(OpAsmPrinter &p) {
609 p << " " << getInstance() << "["
610 << cast<ModuleType>(getInstance().getType())
611 .getPorts()[getPortId().getZExtValue()]
612 .name
613 << "] to " << getChannel();
614 p.printOptionalAttrDict((*this)->getAttrs(), {"portId"});
615 p << " : " << getInstance().getType() << ", " << getChannel().getType();
616}
617
618LogicalResult BindPortOp::verify() {
619 auto ports = cast<ModuleType>(getInstance().getType()).getPorts();
620 if (getPortId().getZExtValue() >= ports.size())
621 return emitOpError("port #")
622 << getPortId().getZExtValue() << " does not exist, there are only "
623 << ports.size() << " ports";
624
625 // Verify that the base types match.
626 Type portType = ports[getPortId().getZExtValue()].type;
627 Type channelType = getChannel().getType();
628 if (getSignalBaseType(portType) != getSignalBaseType(channelType))
629 return emitOpError() << portType << " port cannot be bound to "
630 << channelType << " channel due to base type mismatch";
631
632 // Verify that the port/channel directions are valid.
633 if ((isa<InputType>(portType) && isa<OutputType>(channelType)) ||
634 (isa<OutputType>(portType) && isa<InputType>(channelType)))
635 return emitOpError() << portType << " port cannot be bound to "
636 << channelType
637 << " channel due to port direction mismatch";
638
639 return success();
640}
641
642StringRef BindPortOp::getPortName() {
643 return cast<ModuleType>(getInstance().getType())
644 .getPorts()[getPortId().getZExtValue()]
645 .name.getValue();
646}
647
648//===----------------------------------------------------------------------===//
649// SensitiveOp
650//===----------------------------------------------------------------------===//
651
652LogicalResult SensitiveOp::canonicalize(SensitiveOp op,
653 PatternRewriter &rewriter) {
654 if (op.getSensitivities().empty()) {
655 rewriter.eraseOp(op);
656 return success();
657 }
658
659 return failure();
660}
661
662//===----------------------------------------------------------------------===//
663// VariableOp
664//===----------------------------------------------------------------------===//
665
666void VariableOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
667 setNameFn(getVariable(), getName());
668}
669
670ParseResult VariableOp::parse(OpAsmParser &parser, OperationState &result) {
671 StringAttr nameAttr;
672 if (parseImplicitSSAName(parser, nameAttr))
673 return failure();
674 result.addAttribute("name", nameAttr);
675
676 OpAsmParser::UnresolvedOperand init;
677 auto initResult = parser.parseOptionalOperand(init);
678
679 if (parser.parseOptionalAttrDict(result.attributes))
680 return failure();
681
682 Type variableType;
683 if (parser.parseColonType(variableType))
684 return failure();
685
686 if (initResult.has_value()) {
687 if (parser.resolveOperand(init, variableType, result.operands))
688 return failure();
689 }
690 result.addTypes({variableType});
691
692 return success();
693}
694
695void VariableOp::print(::mlir::OpAsmPrinter &p) {
696 p << " ";
697
698 if (getInit())
699 p << getInit() << " ";
700
701 p.printOptionalAttrDict(getOperation()->getAttrs(), {"name"});
702 p << ": " << getVariable().getType();
703}
704
705LogicalResult VariableOp::verify() {
706 if (getInit() && getInit().getType() != getVariable().getType())
707 return emitOpError(
708 "'init' and 'variable' must have the same type, but got ")
709 << getInit().getType() << " and " << getVariable().getType();
710
711 return success();
712}
713
714//===----------------------------------------------------------------------===//
715// InteropVerilatedOp
716//===----------------------------------------------------------------------===//
717
718/// Create a instance that refers to a known module.
719void InteropVerilatedOp::build(OpBuilder &odsBuilder, OperationState &odsState,
720 Operation *module, StringAttr name,
721 ArrayRef<Value> inputs) {
722 auto mod = cast<hw::HWModuleLike>(module);
723 auto argNames = odsBuilder.getArrayAttr(mod.getInputNames());
724 auto resultNames = odsBuilder.getArrayAttr(mod.getOutputNames());
725 build(odsBuilder, odsState, mod.getHWModuleType().getOutputTypes(), name,
726 FlatSymbolRefAttr::get(SymbolTable::getSymbolName(module)), argNames,
727 resultNames, inputs);
728}
729
/// Verify this op's symbol use.
///
/// The visible call receives this op, the referenced module name, the inputs,
/// the result types, the input and result names, an empty ArrayAttr, and the
/// symbol table. The helper being called is not visible in this listing.
LogicalResult
InteropVerilatedOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // NOTE(review): source line 732 -- presumably `return <helper>(` -- is
  // missing from this listing. Restore it from the upstream file; this block
  // does not compile without it.
      *this, getModuleNameAttr(), getInputs(), getResultTypes(),
      getInputNames(), getResultNames(), ArrayAttr(), symbolTable);
}
736
/// Suggest a name for each result value based on the saved result names
/// attribute.
void InteropVerilatedOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  // NOTE(review): source line 740 -- the head of the helper call that
  // receives `setNameFn` together with the arguments below -- is missing from
  // this listing. Restore it from the upstream file; this block does not
  // compile without it.
      getResultNames(), getResults());
}
743
744//===----------------------------------------------------------------------===//
745// CallOp
746//
// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored so that we
// can reuse its implementation here.
750//===----------------------------------------------------------------------===//
751
752// FIXME: This is an exact copy from upstream
753LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
754 // Check that the callee attribute was specified.
755 auto fnAttr = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
756 if (!fnAttr)
757 return emitOpError("requires a 'callee' symbol reference attribute");
758 FuncOp fn = symbolTable.lookupNearestSymbolFrom<FuncOp>(*this, fnAttr);
759 if (!fn)
760 return emitOpError() << "'" << fnAttr.getValue()
761 << "' does not reference a valid function";
762
763 // Verify that the operand and result types match the callee.
764 auto fnType = fn.getFunctionType();
765 if (fnType.getNumInputs() != getNumOperands())
766 return emitOpError("incorrect number of operands for callee");
767
768 for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i)
769 if (getOperand(i).getType() != fnType.getInput(i))
770 return emitOpError("operand type mismatch: expected operand type ")
771 << fnType.getInput(i) << ", but provided "
772 << getOperand(i).getType() << " for operand number " << i;
773
774 if (fnType.getNumResults() != getNumResults())
775 return emitOpError("incorrect number of results for callee");
776
777 for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i)
778 if (getResult(i).getType() != fnType.getResult(i)) {
779 auto diag = emitOpError("result type mismatch at index ") << i;
780 diag.attachNote() << " op result types: " << getResultTypes();
781 diag.attachNote() << "function result types: " << fnType.getResults();
782 return diag;
783 }
784
785 return success();
786}
787
788FunctionType CallOp::getCalleeType() {
789 return FunctionType::get(getContext(), getOperandTypes(), getResultTypes());
790}
791
792// This verifier was added compared to the upstream implementation.
793LogicalResult CallOp::verify() {
794 if (getNumResults() > 1)
795 return emitOpError(
796 "incorrect number of function results (always has to be 0 or 1)");
797
798 return success();
799}
800
801//===----------------------------------------------------------------------===//
802// CallIndirectOp
803//===----------------------------------------------------------------------===//
804
805// This verifier was added compared to the upstream implementation.
806LogicalResult CallIndirectOp::verify() {
807 if (getNumResults() > 1)
808 return emitOpError(
809 "incorrect number of function results (always has to be 0 or 1)");
810
811 return success();
812}
813
814//===----------------------------------------------------------------------===//
815// FuncOp
816//
// TODO: Most of the implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored so that we
// can reuse its implementation here.
820//===----------------------------------------------------------------------===//
821
822// Note that the create and build operations are taken from upstream, but the
823// argNames argument was added.
824FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
825 FunctionType type, ArrayRef<NamedAttribute> attrs) {
826 OpBuilder builder(location->getContext());
827 OperationState state(location, getOperationName());
828 FuncOp::build(builder, state, name, argNames, type, attrs);
829 return cast<FuncOp>(Operation::create(state));
830}
831
832FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
833 FunctionType type, Operation::dialect_attr_range attrs) {
834 SmallVector<NamedAttribute, 8> attrRef(attrs);
835 return create(location, name, argNames, type, ArrayRef(attrRef));
836}
837
838FuncOp FuncOp::create(Location location, StringRef name, ArrayAttr argNames,
839 FunctionType type, ArrayRef<NamedAttribute> attrs,
840 ArrayRef<DictionaryAttr> argAttrs) {
841 FuncOp func = create(location, name, argNames, type, attrs);
842 func.setAllArgAttrs(argAttrs);
843 return func;
844}
845
846void FuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
847 StringRef name, ArrayAttr argNames, FunctionType type,
848 ArrayRef<NamedAttribute> attrs,
849 ArrayRef<DictionaryAttr> argAttrs) {
850 odsState.addAttribute(getArgNamesAttrName(odsState.name), argNames);
851 odsState.addAttribute(SymbolTable::getSymbolAttrName(),
852 odsBuilder.getStringAttr(name));
853 odsState.addAttribute(FuncOp::getFunctionTypeAttrName(odsState.name),
854 TypeAttr::get(type));
855 odsState.attributes.append(attrs.begin(), attrs.end());
856 odsState.addRegion();
857
858 if (argAttrs.empty())
859 return;
860 assert(type.getNumInputs() == argAttrs.size());
861 mlir::call_interface_impl::addArgAndResultAttrs(
862 odsBuilder, odsState, argAttrs,
863 /*resultAttrs=*/{}, FuncOp::getArgAttrsAttrName(odsState.name),
864 FuncOp::getResAttrsAttrName(odsState.name));
865}
866
867ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
868 auto buildFuncType =
869 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
870 mlir::function_interface_impl::VariadicFlag,
871 std::string &) { return builder.getFunctionType(argTypes, results); };
872
873 // This was added specifically for our implementation, upstream does not have
874 // this feature.
875 if (succeeded(parser.parseOptionalKeyword("externC")))
876 result.addAttribute(getExternCAttrName(result.name),
877 UnitAttr::get(result.getContext()));
878
879 // FIXME: below is an exact copy of the
880 // mlir::function_interface_impl::parseFunctionOp implementation, this was
881 // needed because we need to access the SSA names of the arguments.
882 SmallVector<OpAsmParser::Argument> entryArgs;
883 SmallVector<DictionaryAttr> resultAttrs;
884 SmallVector<Type> resultTypes;
885 auto &builder = parser.getBuilder();
886
887 // Parse visibility.
888 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
889
890 // Parse the name as a symbol.
891 StringAttr nameAttr;
892 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
893 result.attributes))
894 return failure();
895
896 // Parse the function signature.
897 mlir::SMLoc signatureLocation = parser.getCurrentLocation();
898 bool isVariadic = false;
899 if (mlir::function_interface_impl::parseFunctionSignatureWithArguments(
900 parser, false, entryArgs, isVariadic, resultTypes, resultAttrs))
901 return failure();
902
903 std::string errorMessage;
904 SmallVector<Type> argTypes;
905 argTypes.reserve(entryArgs.size());
906 for (auto &arg : entryArgs)
907 argTypes.push_back(arg.type);
908
909 Type type = buildFuncType(
910 builder, argTypes, resultTypes,
911 mlir::function_interface_impl::VariadicFlag(isVariadic), errorMessage);
912 if (!type) {
913 return parser.emitError(signatureLocation)
914 << "failed to construct function type"
915 << (errorMessage.empty() ? "" : ": ") << errorMessage;
916 }
917 result.addAttribute(FuncOp::getFunctionTypeAttrName(result.name),
918 TypeAttr::get(type));
919
920 // If function attributes are present, parse them.
921 NamedAttrList parsedAttributes;
922 mlir::SMLoc attributeDictLocation = parser.getCurrentLocation();
923 if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes))
924 return failure();
925
926 // Disallow attributes that are inferred from elsewhere in the attribute
927 // dictionary.
928 for (StringRef disallowed :
929 {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(),
930 FuncOp::getFunctionTypeAttrName(result.name).getValue()}) {
931 if (parsedAttributes.get(disallowed))
932 return parser.emitError(attributeDictLocation, "'")
933 << disallowed
934 << "' is an inferred attribute and should not be specified in the "
935 "explicit attribute dictionary";
936 }
937 result.attributes.append(parsedAttributes);
938
939 // Add the attributes to the function arguments.
940 assert(resultAttrs.size() == resultTypes.size());
941 mlir::call_interface_impl::addArgAndResultAttrs(
942 builder, result, entryArgs, resultAttrs,
943 FuncOp::getArgAttrsAttrName(result.name),
944 FuncOp::getResAttrsAttrName(result.name));
945
946 // Parse the optional function body. The printer will not print the body if
947 // its empty, so disallow parsing of empty body in the parser.
948 auto *body = result.addRegion();
949 mlir::SMLoc loc = parser.getCurrentLocation();
950 mlir::OptionalParseResult parseResult =
951 parser.parseOptionalRegion(*body, entryArgs,
952 /*enableNameShadowing=*/false);
953 if (parseResult.has_value()) {
954 if (failed(*parseResult))
955 return failure();
956 // Function body was parsed, make sure its not empty.
957 if (body->empty())
958 return parser.emitError(loc, "expected non-empty function body");
959 }
960
961 // Everythink below is added compared to the upstream implemenation to handle
962 // argument names.
963 SmallVector<Attribute> argNames;
964 if (!entryArgs.empty() && !entryArgs.front().ssaName.name.empty()) {
965 for (auto &arg : entryArgs)
966 argNames.push_back(
967 StringAttr::get(parser.getContext(), arg.ssaName.name.drop_front()));
968 }
969
970 result.addAttribute(getArgNamesAttrName(result.name),
971 ArrayAttr::get(parser.getContext(), argNames));
972
973 return success();
974}
975
976void FuncOp::print(OpAsmPrinter &p) {
977 if (getExternC())
978 p << " externC";
979
980 mlir::FunctionOpInterface op = *this;
981
982 // FIXME: inlined mlir::function_interface_impl::printFunctionOp because we
983 // need to elide more attributes
984
985 // Print the operation and the function name.
986 auto funcName =
987 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
988 .getValue();
989 p << ' ';
990
991 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
992 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
993 p << visibility.getValue() << ' ';
994 p.printSymbolName(funcName);
995
996 ArrayRef<Type> argTypes = op.getArgumentTypes();
997 ArrayRef<Type> resultTypes = op.getResultTypes();
998 mlir::function_interface_impl::printFunctionSignature(p, op, argTypes, false,
999 resultTypes);
1000 mlir::function_interface_impl::printFunctionAttributes(
1001 p, op,
1002 {visibilityAttrName, "externC", "argNames", getFunctionTypeAttrName(),
1003 getArgAttrsAttrName(), getResAttrsAttrName()});
1004 // Print the body if this is not an external function.
1005 Region &body = op->getRegion(0);
1006 if (!body.empty()) {
1007 p << ' ';
1008 p.printRegion(body, /*printEntryBlockArgs=*/false,
1009 /*printBlockTerminators=*/true);
1010 }
1011}
1012
// FIXME: the clone operations below were copied from the upstream implementation.
1014
1015/// Clone the internal blocks from this function into dest and all attributes
1016/// from this function to dest.
1017void FuncOp::cloneInto(FuncOp dest, IRMapping &mapper) {
1018 // Add the attributes of this function to dest.
1019 llvm::MapVector<StringAttr, Attribute> newAttrMap;
1020 for (const auto &attr : dest->getAttrs())
1021 newAttrMap.insert({attr.getName(), attr.getValue()});
1022 for (const auto &attr : (*this)->getAttrs())
1023 newAttrMap.insert({attr.getName(), attr.getValue()});
1024
1025 auto newAttrs = llvm::to_vector(llvm::map_range(
1026 newAttrMap, [](std::pair<StringAttr, Attribute> attrPair) {
1027 return NamedAttribute(attrPair.first, attrPair.second);
1028 }));
1029 dest->setAttrs(DictionaryAttr::get(getContext(), newAttrs));
1030
1031 // Clone the body.
1032 getBody().cloneInto(&dest.getBody(), mapper);
1033}
1034
1035/// Create a deep copy of this function and all of its blocks, remapping
1036/// any operands that use values outside of the function using the map that is
1037/// provided (leaving them alone if no entry is present). Replaces references
1038/// to cloned sub-values with the corresponding value that is copied, and adds
1039/// those mappings to the mapper.
1040FuncOp FuncOp::clone(IRMapping &mapper) {
1041 // Create the new function.
1042 FuncOp newFunc = cast<FuncOp>(getOperation()->cloneWithoutRegions());
1043
1044 // If the function has a body, then the user might be deleting arguments to
1045 // the function by specifying them in the mapper. If so, we don't add the
1046 // argument to the input type vector.
1047 if (!isExternal()) {
1048 FunctionType oldType = getFunctionType();
1049
1050 unsigned oldNumArgs = oldType.getNumInputs();
1051 SmallVector<Type, 4> newInputs;
1052 newInputs.reserve(oldNumArgs);
1053 for (unsigned i = 0; i != oldNumArgs; ++i)
1054 if (!mapper.contains(getArgument(i)))
1055 newInputs.push_back(oldType.getInput(i));
1056
1057 /// If any of the arguments were dropped, update the type and drop any
1058 /// necessary argument attributes.
1059 if (newInputs.size() != oldNumArgs) {
1060 newFunc.setType(FunctionType::get(oldType.getContext(), newInputs,
1061 oldType.getResults()));
1062
1063 if (ArrayAttr argAttrs = getAllArgAttrs()) {
1064 SmallVector<Attribute> newArgAttrs;
1065 newArgAttrs.reserve(newInputs.size());
1066 for (unsigned i = 0; i != oldNumArgs; ++i)
1067 if (!mapper.contains(getArgument(i)))
1068 newArgAttrs.push_back(argAttrs[i]);
1069 newFunc.setAllArgAttrs(newArgAttrs);
1070 }
1071 }
1072 }
1073
1074 /// Clone the current function into the new one and return it.
1075 cloneInto(newFunc, mapper);
1076 return newFunc;
1077}
1078
1079FuncOp FuncOp::clone() {
1080 IRMapping mapper;
1081 return clone(mapper);
1082}
1083
1084// The following functions are entirely new additions compared to upstream.
1085
1086void FuncOp::getAsmBlockArgumentNames(mlir::Region &region,
1087 mlir::OpAsmSetValueNameFn setNameFn) {
1088 if (region.empty())
1089 return;
1090
1091 for (auto [arg, name] : llvm::zip(getArguments(), getArgNames()))
1092 setNameFn(arg, cast<StringAttr>(name).getValue());
1093}
1094
1095LogicalResult FuncOp::verify() {
1096 if (getFunctionType().getNumResults() > 1)
1097 return emitOpError(
1098 "incorrect number of function results (always has to be 0 or 1)");
1099
1100 if (getBody().empty())
1101 return success();
1102
1103 if (getArgNames().size() != getFunctionType().getNumInputs())
1104 return emitOpError("incorrect number of argument names");
1105
1106 for (auto portName : getArgNames()) {
1107 if (cast<StringAttr>(portName).getValue().empty())
1108 return emitOpError("arg name must not be empty");
1109 }
1110
1111 return success();
1112}
1113
1114LogicalResult FuncOp::verifyRegions() {
1115 auto attachNote = [&](mlir::InFlightDiagnostic &diag) {
1116 diag.attachNote(getLoc()) << "in function '@" << getName() << "'";
1117 };
1118 return verifyUniqueNamesInRegion(getOperation(), getArgNames(), attachNote);
1119}
1120
1121//===----------------------------------------------------------------------===//
1122// ReturnOp
1123//
// TODO: The implementation for this operation was copy-pasted from the
// 'func' dialect. Ideally, the upstream dialect would be refactored so that
// we can reuse its implementation here.
1127//===----------------------------------------------------------------------===//
1128
1129LogicalResult ReturnOp::verify() {
1130 auto function = cast<FuncOp>((*this)->getParentOp());
1131
1132 // The operand number and types must match the function signature.
1133 const auto &results = function.getFunctionType().getResults();
1134 if (getNumOperands() != results.size())
1135 return emitOpError("has ")
1136 << getNumOperands() << " operands, but enclosing function (@"
1137 << function.getName() << ") returns " << results.size();
1138
1139 for (unsigned i = 0, e = results.size(); i != e; ++i)
1140 if (getOperand(i).getType() != results[i])
1141 return emitError() << "type of return operand " << i << " ("
1142 << getOperand(i).getType()
1143 << ") doesn't match function result type ("
1144 << results[i] << ")"
1145 << " in function @" << function.getName();
1146
1147 return success();
1148}
1149
1150//===----------------------------------------------------------------------===//
1151// TableGen generated logic.
1152//===----------------------------------------------------------------------===//
1153
1154// Provide the autogenerated implementation guts for the Op classes.
1155#define GET_OP_CLASSES
1156#include "circt/Dialect/SystemC/SystemC.cpp.inc"
assert(baseType &&"element must be base type")
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
static hw::ModulePort::Direction getDirection(Type type)
static Type wrapPortType(Type type, hw::ModulePort::Direction direction)
static LogicalResult verifyUniqueNamesInRegion(Operation *operation, ArrayAttr argNames, std::function< void(mlir::InFlightDiagnostic &)> attachNote)
This stores lookup tables to make manipulating and working with the IR more efficient.
Definition HWSymCache.h:28
mlir::Operation * getDefinition(mlir::Attribute attr) const override
Lookup a definition for 'symbol' in the cache.
Definition HWSymCache.h:57
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
LogicalResult verifyInstanceOfHWModule(Operation *instance, FlatSymbolRefAttr moduleRef, OperandRange inputs, TypeRange results, ArrayAttr argNames, ArrayAttr resultNames, ArrayAttr parameters, SymbolTableCollection &symbolTable)
Combines verifyReferencedModule, verifyInputs, verifyOutputs, and verifyParameters.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
void printModuleSignature(OpAsmPrinter &p, Operation *op, ArrayRef< Type > argTypes, bool isVariadic, ArrayRef< Type > resultTypes, bool &needArgNamesAttr)
Print a module signature with named results.
ParseResult parseModuleFunctionSignature(OpAsmParser &parser, bool &isVariadic, SmallVectorImpl< OpAsmParser::Argument > &args, SmallVectorImpl< Attribute > &argNames, SmallVectorImpl< Attribute > &argLocs, SmallVectorImpl< Attribute > &resultNames, SmallVectorImpl< DictionaryAttr > &resultAttrs, SmallVectorImpl< Attribute > &resultLocs, TypeAttr &type)
This is a variant of mlir::parseFunctionSignature that allows names on result arguments.
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
void info(Twine message)
Definition LSPUtils.cpp:20
Type getSignalBaseType(Type type)
Get the type wrapped by a signal or port (in, inout, out) type.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
ParseResult parseImplicitSSAName(OpAsmParser &parser, StringAttr &attr)
Parse an implicit SSA name string attribute.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
This holds a decoded list of input/inout and output ports for a module or instance.
This holds the name, type, direction of a module's ports.