19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/FunctionImplementation.h"
23#include "llvm/ADT/MapVector.h"
/// Builds a DPIFuncOp from parallel arrays of argument names, types, and
/// directions. The arrays are combined into DPIArgument entries, wrapped in a
/// DPIFunctionType, and forwarded to the DPIFunctionType-based builder.
33void DPIFuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
34 StringAttr symName, ArrayRef<StringAttr> argNames,
35 ArrayRef<Type> argTypes,
36 ArrayRef<DPIDirection> argDirections, ArrayAttr argLocs,
37 StringAttr verilogName) {
39 SmallVector<DPIArgument> args;
40 args.reserve(argNames.size());
// llvm::zip stops at the shortest range, so surplus entries in a longer
// array are ignored rather than reported.
41 for (
auto [name, type, dir] :
llvm::zip(argNames, argTypes, argDirections))
42 args.push_back({name, type, dir});
43 auto dpiType = DPIFunctionType::get(odsBuilder.getContext(), args);
// Delegate to the overload that takes the assembled DPIFunctionType.
44 build(odsBuilder, odsState, symName, dpiType, argLocs, verilogName);
/// Builds a DPIFuncOp from a ready-made DPIFunctionType, recording the symbol
/// name, the function type (wrapped in a TypeAttr), the argument locations,
/// and the Verilog name as attributes on the operation state.
47void DPIFuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
48 StringAttr symName, DPIFunctionType dpiFunctionType,
49 ArrayAttr argLocs, StringAttr verilogName) {
50 odsState.addAttribute(getSymNameAttrName(odsState.name), symName);
51 odsState.addAttribute(getDpiFunctionTypeAttrName(odsState.name),
52 TypeAttr::get(dpiFunctionType));
// NOTE(review): argLocs and verilogName may be null attributes; confirm they
// are guarded before being attached, since no guard is visible beside these
// two calls.
54 odsState.addAttribute(getArgumentLocsAttrName(odsState.name), argLocs);
56 odsState.addAttribute(getVerilogNameAttrName(odsState.name), verilogName);
60::mlir::Type DPIFuncOp::getFunctionType() {
61 return getDpiFunctionType().getFunctionType();
/// Sets the op's function type from a TypeAttr. The attribute must carry a
/// DPIFunctionType; this requirement is checked only by an assert.
64void DPIFuncOp::setFunctionTypeAttr(::mlir::TypeAttr type) {
// dyn_cast yields a null type for anything that is not a DPIFunctionType,
// which the assert below rejects.
66 auto dpiType = llvm::dyn_cast<DPIFunctionType>(type.getValue());
67 assert(dpiType &&
"DPIFuncOp function type can only be set via "
68 "DPIFunctionType, not a plain FunctionType");
69 setDpiFunctionType(dpiType);
72::mlir::Type DPIFuncOp::cloneTypeWith(::mlir::TypeRange inputs,
73 ::mlir::TypeRange results) {
74 return FunctionType::get(getContext(), inputs, results);
/// Parses the custom assembly of a DPIFuncOp: an optional visibility keyword,
/// the symbol name, a parenthesized comma-separated DPI argument list, and an
/// optional attribute dictionary introduced by a keyword. The argument list
/// becomes a DPIFunctionType attribute plus an array of argument locations.
77ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {
78 auto builder = parser.getBuilder();
79 auto ctx = builder.getContext();
// The visibility keyword is optional, so the parse result is discarded.
81 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
84 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
88 SmallVector<DPIArgument> args;
89 SmallVector<Attribute> argLocs;
90 auto unknownLoc = builder.getUnknownLoc();
// Parses one argument: direction keyword, name, `: type`, optional location.
93 auto parseOneArg = [&]() -> ParseResult {
// Remember where the direction keyword starts so errors point at it.
95 auto keyLoc = parser.getCurrentLocation();
96 if (parser.parseKeyword(&dirKeyword))
100 return parser.emitError(keyLoc,
101 "expected DPI argument direction keyword");
// One accepted name form is an SSA-style operand; substr(1) drops its first
// character (presumably the '%' sigil).
107 OpAsmParser::UnresolvedOperand ssaName;
108 if (parser.parseOperand(ssaName,
false))
110 argName = ssaName.name.substr(1).str();
// The other accepted name form is a bare keyword or a quoted string.
112 if (parser.parseKeywordOrString(&argName))
117 if (parser.parseColonType(argType))
119 args.push_back({StringAttr::get(ctx, argName), argType, *dir});
// Store the explicitly written location, or unknownLoc as the fallback
// (presumably when no location specifier was present).
121 std::optional<Location> maybeLoc;
122 if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
125 argLocs.push_back(*maybeLoc);
128 argLocs.push_back(unknownLoc);
133 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOneArg,
134 " in DPI argument list"))
// Attach the parsed signature and the per-argument locations to the op.
137 auto dpiType = DPIFunctionType::get(ctx, args);
139 result.addAttribute(DPIFuncOp::getDpiFunctionTypeAttrName(result.name),
140 TypeAttr::get(dpiType));
142 result.addAttribute(DPIFuncOp::getArgumentLocsAttrName(result.name),
143 builder.getArrayAttr(argLocs));
146 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
/// Prints the custom assembly of a DPIFuncOp: optional visibility, the symbol
/// name, the DPI argument list, and the remaining attributes.
151void DPIFuncOp::print(OpAsmPrinter &p) {
154 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
// The visibility is printed only when the attribute is present.
155 if (
auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
156 p << visibility.getValue() <<
' ';
157 p.printSymbolName(getSymName());
159 auto dpiType = getDpiFunctionType();
160 auto dpiArgs = dpiType.getArguments();
// Arguments are printed comma-separated; enumerate() pairs each argument
// with its index.
163 llvm::interleaveComma(llvm::enumerate(dpiArgs), p, [&](
auto it) {
164 auto &arg = it.value();
171 p.printKeywordOrString(arg.name.getValue());
173 p.printType(arg.type);
// A location is printed only if the locations attribute exists and the entry
// is not the unknown location.
175 if (getArgumentLocs()) {
176 auto loc = cast<Location>(getArgumentLocsAttr()[i]);
177 if (loc != UnknownLoc::get(getContext()))
178 p.printOptionalLocationSpecifier(loc);
// Attributes already covered by the custom syntax are passed as the list to
// elide from the printed attribute dictionary.
183 mlir::function_interface_impl::printFunctionAttributes(
185 {visibilityAttrName, getDpiFunctionTypeAttrName(),
186 getArgumentLocsAttrName()});
/// Verifies the op's DPIFunctionType and checks that every 'ref' argument is
/// typed as !llvm.ptr.
189LogicalResult DPIFuncOp::verify() {
190 auto dpiType = getDpiFunctionType();
// Delegate the type's own checks, reporting any diagnostic on this op.
193 if (failed(dpiType.verify([&]() { return emitOpError(); })))
197 for (
auto &arg : dpiType.getArguments()) {
198 if (arg.dir == DPIDirection::Ref) {
199 if (!isa<LLVM::LLVMPointerType>(arg.type))
200 return emitOpError(
"'ref' arguments must use !llvm.ptr type");
// Checks the callee symbol of a DPICallOp. The symbol must resolve to a
// declaration. For a sim::DPIFuncOp callee, operand and result counts and
// types are compared against its function type. A callee that is neither a
// sim::DPIFuncOp nor a func::FuncOp is rejected.
208sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// Resolve the callee starting from this op's nearest symbol table.
210 symbolTable.lookupNearestSymbolFrom(*
this, getCalleeAttr());
212 return emitError(
"cannot find function declaration '")
213 << getCallee() <<
"'";
// DPI callee: compare the call against the declared signature.
214 if (
auto dpiFunc = dyn_cast<sim::DPIFuncOp>(referencedOp)) {
215 auto expectedFuncType = cast<FunctionType>(dpiFunc.getFunctionType());
216 auto expectedInputs = expectedFuncType.getInputs();
217 auto expectedResults = expectedFuncType.getResults();
// Counts are checked first so the zipped type comparisons below cover every
// operand and result.
218 if (getInputs().size() != expectedInputs.size())
219 return emitError(
"expects ")
220 << expectedInputs.size() <<
" DPI operands, but got "
221 << getInputs().size();
222 if (getResults().size() != expectedResults.size())
223 return emitError(
"expects ")
224 << expectedResults.size() <<
" DPI results, but got "
225 << getResults().size();
226 for (
auto [operand, expectedType] :
llvm::zip(getInputs(), expectedInputs))
227 if (operand.getType() != expectedType)
228 return emitError(
"operand type mismatch: expected ")
229 << expectedType <<
", but got " << operand.getType();
230 for (
auto [result, expectedType] :
llvm::zip(getResults(), expectedResults))
231 if (result.getType() != expectedType)
232 return emitError(
"result type mismatch: expected ")
233 << expectedType <<
", but got " << result.getType();
// NOTE(review): func.func callees are recognised here, but no signature
// comparison for them is visible in this function; confirm that is intended.
236 if (isa<func::FuncOp>(referencedOp))
238 return emitError(
"callee must be 'sim.func.dpi' or 'func.func' but got '")
239 << referencedOp->getName() <<
"'";
243 const Attribute &value,
244 bool isUpperCase,
bool isLeftAligned,
246 std::optional<unsigned> specifierWidth,
247 bool isSigned =
false) {
// Formats an integer attribute in the given radix, padding the digits with
// paddingChar and, when a wider specifierWidth is requested, with spaces.
// dyn_cast_or_null tolerates a null `value`.
248 auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(value);
// Zero-width integers format as the empty string.
251 if (intAttr.getType().getIntOrFloatBitWidth() == 0)
252 return StringAttr::get(ctx,
"");
254 SmallVector<char, 32> strBuf;
// The literal `false` is APInt::toString's formatAsCLiteral argument, so no
// radix prefix is emitted.
255 intAttr.getValue().toString(strBuf, radix, isSigned,
false, isUpperCase);
256 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
// Digit counts needed for `width` bits: ceil(width / 3) and ceil(width / 4)
// (presumably the octal and hexadecimal cases respectively).
264 padWidth = (width + 2) / 3;
267 padWidth = (width + 3) / 4;
// Spaces are added only when the requested width exceeds both the padded
// digit count and the rendered string length.
// NOTE(review): the std::max(0U, ...) below is redundant, because the guard
// already ensures the difference is positive.
274 unsigned numSpaces = 0;
275 if (specifierWidth.has_value() &&
276 (specifierWidth.value() >
277 std::max(padWidth,
static_cast<unsigned>(strBuf.size())))) {
278 numSpaces = std::max(
279 0U, specifierWidth.value() -
280 std::max(padWidth,
static_cast<unsigned>(strBuf.size())));
283 SmallVector<char, 1> spacePadding(numSpaces,
' ');
// Turn padWidth from a total digit count into the number of padding
// characters still needed in front of the rendered digits.
285 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
287 SmallVector<char, 32> padding(padWidth, paddingChar);
// Two layouts: padding + digits + spaces, or spaces + padding + digits
// (presumably selected by isLeftAligned).
289 return StringAttr::get(ctx, Twine(padding) + Twine(strBuf) +
290 Twine(spacePadding));
292 return StringAttr::get(ctx,
293 Twine(spacePadding) + Twine(padding) + Twine(strBuf));
298 std::optional<unsigned> fieldWidth,
299 std::optional<unsigned> fracDigits,
300 std::string formatSpecifier) {
// Formats a float attribute with a printf-style specification assembled from
// the alignment flag, optional field width, and fractional digit count.
// dyn_cast_or_null tolerates a null `value`.
301 if (
auto floatAttr = llvm::dyn_cast_or_null<FloatAttr>(value)) {
// A leading '-' is the printf flag for left alignment.
302 std::string widthString = isLeftAligned ?
"-" :
"";
303 if (fieldWidth.has_value()) {
304 widthString += std::to_string(fieldWidth.value());
// NOTE(review): fracDigits.value() is called without a has_value() check and
// throws std::bad_optional_access when the optional is empty; confirm that
// callers always supply it.
306 std::string fmtSpecifier =
"%" + widthString +
"." +
307 std::to_string(fracDigits.value()) +
// The first snprintf call (null buffer, size 0) only measures the output
// length, excluding the terminating NUL.
// NOTE(review): a negative (error) return value is not checked before it is
// used as the string size.
312 int bufferSize = std::snprintf(
nullptr, 0, fmtSpecifier.c_str(),
313 floatAttr.getValue().convertToDouble());
314 std::string floatFmtBuffer(bufferSize,
'\0');
// The second call writes the characters plus the terminating NUL, which
// lands on the string's own terminator slot.
315 snprintf(floatFmtBuffer.data(), bufferSize + 1, fmtSpecifier.c_str(),
316 floatAttr.getValue().convertToDouble());
317 return StringAttr::get(ctx, floatFmtBuffer);
325OpFoldResult FormatLiteralOp::fold(FoldAdaptor adaptor) {
326 return getLiteralAttr();
/// Formats a constant integer attribute as a decimal string using this op's
/// signedness, width, padding character, and alignment settings.
331StringAttr FormatDecOp::formatConstant(Attribute constVal) {
// NOTE(review): llvm::dyn_cast expects a non-null attribute; confirm that
// callers never pass a null constVal.
332 auto intAttr = llvm::dyn_cast<IntegerAttr>(constVal);
335 SmallVector<char, 16> strBuf;
// Render the value in base 10, honouring the op's signedness.
336 intAttr.getValue().toString(strBuf, 10, getIsSigned());
// Total field width: the explicit specifier width when given, otherwise the
// width getDecimalWidth reports for this bit width and signedness.
338 if (getSpecifierWidth().has_value()) {
339 padWidth = getSpecifierWidth().value();
341 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
342 padWidth = FormatDecOp::getDecimalWidth(width, getIsSigned());
// Reduce padWidth to the number of padding characters still needed.
345 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
347 SmallVector<char, 10> padding(padWidth, getPaddingChar());
// Left-aligned output places the padding after the digits, otherwise before.
348 if (getIsLeftAligned())
349 return StringAttr::get(getContext(), Twine(strBuf) + Twine(padding));
350 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
/// Folds the decimal formatter. A zero-width input value folds directly to
/// the string "0".
353OpFoldResult FormatDecOp::fold(FoldAdaptor adaptor) {
354 if (getValue().getType().getIntOrFloatBitWidth() == 0)
355 return StringAttr::get(getContext(),
"0");
/// Formats a constant for this op, forwarding its hex letter case, alignment,
/// padding character, and specifier width settings to the formatting helper.
361StringAttr FormatHexOp::formatConstant(Attribute constVal) {
363 getIsHexUppercase(), getIsLeftAligned(),
364 getPaddingChar(), getSpecifierWidth());
/// Folds the hexadecimal formatter. A zero-width input is formatted from a
/// zero-valued IntegerAttr of the input type using radix 16.
367OpFoldResult FormatHexOp::fold(FoldAdaptor adaptor) {
368 if (getValue().getType().getIntOrFloatBitWidth() == 0)
370 getContext(), 16, IntegerAttr::get(getValue().getType(), 0),
false,
371 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
/// Formats a constant for this op, forwarding its alignment, padding
/// character, and specifier width settings to the formatting helper.
377StringAttr FormatOctOp::formatConstant(Attribute constVal) {
379 getIsLeftAligned(), getPaddingChar(),
380 getSpecifierWidth());
/// Folds the octal formatter. A zero-width input is formatted from a
/// zero-valued IntegerAttr of the input type using radix 8.
383OpFoldResult FormatOctOp::fold(FoldAdaptor adaptor) {
384 if (getValue().getType().getIntOrFloatBitWidth() == 0)
386 getContext(), 8, IntegerAttr::get(getValue().getType(), 0),
false,
387 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
/// Formats a constant for this op, forwarding its alignment, padding
/// character, and specifier width settings to the formatting helper.
393StringAttr FormatBinOp::formatConstant(Attribute constVal) {
395 getIsLeftAligned(), getPaddingChar(),
396 getSpecifierWidth());
/// Folds the binary formatter. A zero-width input is formatted from a
/// zero-valued IntegerAttr of the input type using radix 2.
399OpFoldResult FormatBinOp::fold(FoldAdaptor adaptor) {
400 if (getValue().getType().getIntOrFloatBitWidth() == 0)
402 getContext(), 2, IntegerAttr::get(getValue().getType(), 0),
false,
403 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
/// Formats a constant using this op's field width and fractional digit
/// settings with the "e" conversion specifier.
409StringAttr FormatScientificOp::formatConstant(Attribute constVal) {
411 getFieldWidth(), getFracDigits(),
"e");
/// Formats a constant using this op's field width and fractional digit
/// settings with the "f" conversion specifier.
416StringAttr FormatFloatOp::formatConstant(Attribute constVal) {
418 getFieldWidth(), getFracDigits(),
"f");
/// Formats a constant using this op's field width and fractional digit
/// settings with the "g" conversion specifier.
423StringAttr FormatGeneralOp::formatConstant(Attribute constVal) {
425 getFieldWidth(), getFracDigits(),
"g");
/// Formats a constant integer as a single character.
430StringAttr FormatCharOp::formatConstant(Attribute constVal) {
431 auto intCst = dyn_cast<IntegerAttr>(constVal);
// A zero-width integer is formatted as the character with value 0.
434 if (intCst.getType().getIntOrFloatBitWidth() == 0)
435 return StringAttr::get(getContext(), Twine(
static_cast<char>(0)));
// Integers wider than 8 bits take a separate path from the direct
// conversion below.
436 if (intCst.getType().getIntOrFloatBitWidth() > 8)
// Otherwise the zero-extended value is converted to a char.
438 auto intValue = intCst.getValue().getZExtValue();
439 return StringAttr::get(getContext(), Twine(
static_cast<char>(intValue)));
/// Folds the character formatter. A zero-width input value folds directly to
/// the character with value 0.
442OpFoldResult FormatCharOp::fold(FoldAdaptor adaptor) {
443 if (getValue().getType().getIntOrFloatBitWidth() == 0)
444 return StringAttr::get(getContext(), Twine(
static_cast<char>(0)));
// Joins the given literals into one StringAttr. The list must be non-empty.
449 assert(!lits.empty() &&
"No literals to concatenate");
// A single literal is wrapped directly, without going through the buffer.
450 if (lits.size() == 1)
451 return StringAttr::get(ctxt, lits.front());
452 SmallString<64> newLit;
453 for (
auto lit : lits)
455 return StringAttr::get(ctxt, newLit);
/// Folds a format-string concatenation. No operands fold to the empty
/// string, a single operand is forwarded, and otherwise the constant
/// operands are examined as string literals.
458OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
459 if (getNumOperands() == 0)
460 return StringAttr::get(getContext(),
"");
461 if (getNumOperands() == 1) {
// A concat whose only operand is its own result is singled out before the
// operand is forwarded.
463 if (getResult() == getOperand(0))
465 return getOperand(0);
469 SmallVector<StringRef> lits;
470 for (
auto attr : adaptor.getInputs()) {
// dyn_cast_or_null yields null for operands that are not constant strings.
471 auto lit = dyn_cast_or_null<StringAttr>(attr);
/// Collects this concatenation's operands into flatOperands, expanding nested
/// FormatStringConcatOp operands. The traversal uses an explicit stack of
/// (concat op, next operand index) entries instead of recursion. Returns
/// failure when a cycle was detected.
479LogicalResult FormatStringConcatOp::getFlattenedInputs(
480 llvm::SmallVectorImpl<Value> &flatOperands) {
482 bool isCyclic =
false;
// Start the traversal at this op's first operand.
486 concatStack.insert({*
this, 0});
487 while (!concatStack.empty()) {
488 auto &top = concatStack.back();
489 auto currentConcat = top.first;
490 unsigned operandIndex = top.second;
493 while (operandIndex < currentConcat.getNumOperands()) {
494 auto currentOperand = currentConcat.getOperand(operandIndex);
496 if (
auto nextConcat =
497 currentOperand.getDefiningOp<FormatStringConcatOp>()) {
// Only descend into concats not already on the stack; one that is already
// there would form a cycle.
499 if (!concatStack.contains(nextConcat)) {
// Save the resume position for the current concat before pushing the
// nested one.
502 top.second = operandIndex + 1;
503 concatStack.insert({nextConcat, 0});
510 flatOperands.push_back(currentOperand);
// All operands of the current concat were consumed, so drop it.
515 if (operandIndex >= currentConcat.getNumOperands())
516 concatStack.pop_back();
519 return success(!isCyclic);
/// Rejects a concatenation that directly uses its own result as an operand.
522LogicalResult FormatStringConcatOp::verify() {
523 if (llvm::any_of(getOperands(),
524 [&](Value operand) {
return operand == getResult(); }))
525 return emitOpError(
"is infinitely recursive.");
/// Canonicalizes a concatenation by flattening nested concatenations,
/// dropping empty literals, and merging runs of adjacent literal operands.
/// The op is replaced when zero or one operand remains; otherwise its
/// operand list is updated in place.
529LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
530 PatternRewriter &rewriter) {
// New literal ops are created in front of the concatenation.
532 rewriter.setInsertionPoint(op);
534 auto fmtStrType = FormatStringType::get(op.getContext());
537 bool hasBeenFlattened =
false;
538 SmallVector<Value, 0> flatOperands;
541 flatOperands.reserve(op.getNumOperands() + 4);
542 auto isAcyclic = op.getFlattenedInputs(flatOperands);
// A cyclic concatenation is reported with a warning.
544 if (failed(isAcyclic)) {
547 op.emitWarning(
"Cyclic concatenation detected.");
551 hasBeenFlattened =
true;
// With fewer than two operands and no flattening, this early-exit condition
// holds.
554 if (!hasBeenFlattened && op.getNumOperands() < 2)
// litSequence accumulates the current run of adjacent non-empty literals;
// prevLitOp is reused when a run contains a single literal.
559 SmallVector<StringRef> litSequence;
560 SmallVector<Value> newOperands;
561 newOperands.reserve(op.getNumOperands());
562 FormatLiteralOp prevLitOp;
564 auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
565 for (
auto operand : oldOperands) {
566 if (
auto litOp = operand.getDefiningOp<FormatLiteralOp>()) {
// Empty literals contribute nothing and are skipped.
567 if (!litOp.getLiteral().empty()) {
569 litSequence.push_back(litOp.getLiteral());
// A non-literal operand ends the current run: several literals are merged
// into a new literal, a single one keeps its existing op.
572 if (!litSequence.empty()) {
573 if (litSequence.size() > 1) {
575 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
576 op.getLoc(), fmtStrType,
578 newOperands.push_back(newLit);
581 newOperands.push_back(prevLitOp.getResult());
585 newOperands.push_back(operand);
// Flush a literal run that reaches the end of the operand list.
590 if (!litSequence.empty()) {
591 if (litSequence.size() > 1) {
593 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
594 op.getLoc(), fmtStrType,
596 newOperands.push_back(newLit);
599 newOperands.push_back(prevLitOp.getResult());
// Without flattening and with an unchanged operand count, this no-change
// condition holds.
603 if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
// No operands left: replace with an empty literal. One left: forward it.
// Otherwise update the operand list in place.
606 if (newOperands.empty())
607 rewriter.replaceOpWithNewOp<FormatLiteralOp>(op, fmtStrType,
608 rewriter.getStringAttr(
""));
609 else if (newOperands.size() == 1)
610 rewriter.replaceOp(op, newOperands);
612 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
/// Erases a print whose condition is defined by a constant with value zero.
617LogicalResult PrintFormattedOp::canonicalize(PrintFormattedOp op,
618 PatternRewriter &rewriter) {
620 if (
auto cstCond = op.getCondition().getDefiningOp<
hw::ConstantOp>()) {
621 if (cstCond.getValue().isZero()) {
622 rewriter.eraseOp(op);
/// Erases a print whose input is defined by an empty format literal.
629LogicalResult PrintFormattedProcOp::canonicalize(PrintFormattedProcOp op,
630 PatternRewriter &rewriter) {
632 if (
auto litInput = op.getInput().getDefiningOp<FormatLiteralOp>()) {
633 if (litInput.getLiteral().empty()) {
634 rewriter.eraseOp(op);
641OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
642 return adaptor.getLiteralAttr();
/// Folds a string concatenation over constant operands by appending their
/// values in order. No operands fold to the empty string.
645OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
646 auto operands = adaptor.getInputs();
647 if (operands.empty())
648 return StringAttr::get(getContext(),
"");
650 SmallString<128> result;
651 for (
auto &operand : operands) {
// cast_if_present accepts a null (non-constant) operand but asserts if a
// present attribute is not a StringAttr.
652 auto strAttr = cast_if_present<StringAttr>(operand);
655 result += strAttr.getValue();
658 return StringAttr::get(getContext(), result);
661OpFoldResult StringLengthOp::fold(FoldAdaptor adaptor) {
662 auto inputAttr = adaptor.getInput();
666 if (
auto strAttr = cast<StringAttr>(inputAttr))
667 return IntegerAttr::get(getType(), strAttr.getValue().size());
/// Folds a constant integer into a string built from its bytes. Chunks of up
/// to 8 bits are extracted starting at the least significant bit, and the
/// collected bytes are then reversed.
672OpFoldResult IntToStringOp::fold(FoldAdaptor adaptor) {
// cast_or_null accepts a null (non-constant) input but asserts if a present
// attribute is not an IntegerAttr.
673 auto intAttr = cast_or_null<IntegerAttr>(adaptor.getInput());
677 SmallString<128> result;
678 auto width = intAttr.getType().getIntOrFloatBitWidth();
683 for (
unsigned int i = 0; i < width; i += 8) {
// The final chunk may be narrower than 8 bits when width is not a multiple
// of 8.
685 intAttr.getValue().extractBitsAsZExtValue(std::min(width - i, 8U), i);
687 result.push_back(
static_cast<char>(
byte));
// Bytes were collected least significant first; reverse them so the most
// significant byte comes first.
689 std::reverse(result.begin(), result.end());
690 return StringAttr::get(getContext(), result);
/// Folds a character lookup on a constant string with a constant index.
/// Negative or out-of-range indices fold to 0; otherwise the result is the
/// indexed byte as an unsigned value.
698OpFoldResult StringGetOp::fold(FoldAdaptor adaptor) {
699 auto strAttr = cast_or_null<StringAttr>(adaptor.getStr());
700 auto indexAttr = cast_or_null<IntegerAttr>(adaptor.getIndex());
// Both operands must be constants for the fold to proceed.
701 if (!strAttr || !indexAttr)
704 auto str = strAttr.getValue();
// The index is read as a signed value so negative indices can be detected.
705 int64_t index = indexAttr.getValue().getSExtValue();
708 if (index < 0 || index >=
static_cast<int64_t
>(str.size()))
709 return IntegerAttr::get(getType(), 0);
// Go through uint8_t so bytes >= 0x80 are not sign-extended.
712 uint8_t ch =
static_cast<uint8_t
>(str[index]);
713 return IntegerAttr::get(getType(), ch);
/// Verifies the resize op by comparing the element types of the input and
/// result queue types.
720LogicalResult QueueResizeOp::verify() {
721 if (cast<QueueType>(getInput().getType()).getElementType() !=
722 cast<QueueType>(getResult().getType()).getElementType())
/// Verifies that the result queue's element type equals the element type of
/// the input hw::ArrayType, emitting an op error on mismatch.
727LogicalResult QueueFromArrayOp::verify() {
728 auto queueElementType =
729 cast<QueueType>(getResult().getType()).getElementType();
731 auto arrayElementType =
732 cast<hw::ArrayType>(getInput().getType()).getElementType();
734 if (queueElementType != arrayElementType) {
735 return emitOpError() <<
"sim::Queue element type " << queueElementType
736 <<
" doesn't match hw::ArrayType element type "
/// Verifies that every input queue has the same element type as the result
/// queue, emitting an op error for the first mismatch found.
743LogicalResult QueueConcatOp::verify() {
746 auto resultElType = cast<QueueType>(getResult().getType()).getElementType();
748 for (Value input : getInputs()) {
749 auto inpElType = cast<QueueType>(input.getType()).getElementType();
750 if (inpElType != resultElType) {
751 return emitOpError() <<
"sim::Queue element type " << inpElType
752 <<
" doesn't match result sim::Queue element type "
764#include "circt/Dialect/Sim/SimOpInterfaces.cpp.inc"
767#define GET_OP_CLASSES
768#include "circt/Dialect/Sim/Sim.cpp.inc"
769#include "circt/Dialect/Sim/SimEnums.cpp.inc"
assert(baseType &&"element must be base type")
static StringAttr formatFloatsBySpecifier(MLIRContext *ctx, Attribute value, bool isLeftAligned, std::optional< unsigned > fieldWidth, std::optional< unsigned > fracDigits, std::string formatSpecifier)
static StringAttr formatIntegersByRadix(MLIRContext *ctx, unsigned radix, const Attribute &value, bool isUpperCase, bool isLeftAligned, char paddingChar, std::optional< unsigned > specifierWidth, bool isSigned=false)
static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef< StringRef > lits)
llvm::StringRef stringifyDPIDirectionKeyword(DPIDirection dir)
Return the keyword string for a DPIDirection (e.g. "in", "return").
std::optional< DPIDirection > parseDPIDirectionKeyword(llvm::StringRef keyword)
Parse a keyword string to a DPIDirection. Returns std::nullopt on failure.
bool isCallOperandDir(DPIDirection dir)
True if an argument with this direction is a call operand (input/inout/ref).
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.