19#include "mlir/Dialect/Func/IR/FuncOps.h"
20#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
21#include "mlir/IR/PatternMatch.h"
22#include "mlir/Interfaces/FunctionImplementation.h"
23#include "llvm/ADT/MapVector.h"
33void DPIFuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
34 StringAttr symName, ArrayRef<StringAttr> argNames,
35 ArrayRef<Type> argTypes,
36 ArrayRef<DPIDirection> argDirections, ArrayAttr argLocs,
37 StringAttr verilogName) {
39 SmallVector<DPIArgument> args;
40 args.reserve(argNames.size());
41 for (
auto [name, type, dir] :
llvm::zip(argNames, argTypes, argDirections))
42 args.push_back({name, type, dir});
43 auto dpiType = DPIFunctionType::get(odsBuilder.getContext(), args);
44 build(odsBuilder, odsState, symName, dpiType, argLocs, verilogName);
47void DPIFuncOp::build(OpBuilder &odsBuilder, OperationState &odsState,
48 StringAttr symName, DPIFunctionType dpiFunctionType,
49 ArrayAttr argLocs, StringAttr verilogName) {
50 odsState.addAttribute(getSymNameAttrName(odsState.name), symName);
51 odsState.addAttribute(getDpiFunctionTypeAttrName(odsState.name),
52 TypeAttr::get(dpiFunctionType));
54 odsState.addAttribute(getArgumentLocsAttrName(odsState.name), argLocs);
56 odsState.addAttribute(getVerilogNameAttrName(odsState.name), verilogName);
60::mlir::Type DPIFuncOp::getFunctionType() {
61 return getDpiFunctionType().getFunctionType();
64void DPIFuncOp::setFunctionTypeAttr(::mlir::TypeAttr type) {
66 auto dpiType = llvm::dyn_cast<DPIFunctionType>(type.getValue());
67 assert(dpiType &&
"DPIFuncOp function type can only be set via "
68 "DPIFunctionType, not a plain FunctionType");
69 setDpiFunctionType(dpiType);
72::mlir::Type DPIFuncOp::cloneTypeWith(::mlir::TypeRange inputs,
73 ::mlir::TypeRange results) {
74 return FunctionType::get(getContext(), inputs, results);
77ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {
78 auto builder = parser.getBuilder();
79 auto ctx = builder.getContext();
81 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
84 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
88 SmallVector<DPIArgument> args;
89 SmallVector<Attribute> argLocs;
90 auto unknownLoc = builder.getUnknownLoc();
93 auto parseOneArg = [&]() -> ParseResult {
95 auto keyLoc = parser.getCurrentLocation();
96 if (parser.parseKeyword(&dirKeyword))
100 return parser.emitError(keyLoc,
101 "expected DPI argument direction keyword");
107 OpAsmParser::UnresolvedOperand ssaName;
108 if (parser.parseOperand(ssaName,
false))
110 argName = ssaName.name.substr(1).str();
112 if (parser.parseKeywordOrString(&argName))
117 if (parser.parseColonType(argType))
119 args.push_back({StringAttr::get(ctx, argName), argType, *dir});
121 std::optional<Location> maybeLoc;
122 if (failed(parser.parseOptionalLocationSpecifier(maybeLoc)))
125 argLocs.push_back(*maybeLoc);
128 argLocs.push_back(unknownLoc);
133 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren, parseOneArg,
134 " in DPI argument list"))
137 auto dpiType = DPIFunctionType::get(ctx, args);
139 result.addAttribute(DPIFuncOp::getDpiFunctionTypeAttrName(result.name),
140 TypeAttr::get(dpiType));
142 result.addAttribute(DPIFuncOp::getArgumentLocsAttrName(result.name),
143 builder.getArrayAttr(argLocs));
146 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
151void DPIFuncOp::print(OpAsmPrinter &p) {
154 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
155 if (
auto visibility = (*this)->getAttrOfType<StringAttr>(visibilityAttrName))
156 p << visibility.getValue() <<
' ';
157 p.printSymbolName(getSymName());
159 auto dpiType = getDpiFunctionType();
160 auto dpiArgs = dpiType.getArguments();
163 llvm::interleaveComma(llvm::enumerate(dpiArgs), p, [&](
auto it) {
164 auto &arg = it.value();
171 p.printKeywordOrString(arg.name.getValue());
173 p.printType(arg.type);
175 if (getArgumentLocs()) {
176 auto loc = cast<Location>(getArgumentLocsAttr()[i]);
177 if (loc != UnknownLoc::get(getContext()))
178 p.printOptionalLocationSpecifier(loc);
183 mlir::function_interface_impl::printFunctionAttributes(
185 {visibilityAttrName, getDpiFunctionTypeAttrName(),
186 getArgumentLocsAttrName()});
189LogicalResult DPIFuncOp::verify() {
190 auto dpiType = getDpiFunctionType();
193 if (failed(dpiType.verify([&]() { return emitOpError(); })))
197 for (
auto &arg : dpiType.getArguments()) {
198 if (arg.dir == DPIDirection::Ref) {
199 if (!isa<LLVM::LLVMPointerType>(arg.type))
200 return emitOpError(
"'ref' arguments must use !llvm.ptr type");
208sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
210 symbolTable.lookupNearestSymbolFrom(*
this, getCalleeAttr());
212 return emitError(
"cannot find function declaration '")
213 << getCallee() <<
"'";
214 if (
auto dpiFunc = dyn_cast<sim::DPIFuncOp>(referencedOp)) {
215 auto expectedFuncType = cast<FunctionType>(dpiFunc.getFunctionType());
216 auto expectedInputs = expectedFuncType.getInputs();
217 auto expectedResults = expectedFuncType.getResults();
218 if (getInputs().size() != expectedInputs.size())
219 return emitError(
"expects ")
220 << expectedInputs.size() <<
" DPI operands, but got "
221 << getInputs().size();
222 if (getResults().size() != expectedResults.size())
223 return emitError(
"expects ")
224 << expectedResults.size() <<
" DPI results, but got "
225 << getResults().size();
226 for (
auto [operand, expectedType] :
llvm::zip(getInputs(), expectedInputs))
227 if (operand.getType() != expectedType)
228 return emitError(
"operand type mismatch: expected ")
229 << expectedType <<
", but got " << operand.getType();
230 for (
auto [result, expectedType] :
llvm::zip(getResults(), expectedResults))
231 if (result.getType() != expectedType)
232 return emitError(
"result type mismatch: expected ")
233 << expectedType <<
", but got " << result.getType();
236 if (isa<func::FuncOp>(referencedOp))
238 return emitError(
"callee must be 'sim.func.dpi' or 'func.func' but got '")
239 << referencedOp->getName() <<
"'";
243 const Attribute &value,
244 bool isUpperCase,
bool isLeftAligned,
246 std::optional<unsigned> specifierWidth,
247 bool isSigned =
false) {
248 auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(value);
251 if (intAttr.getType().getIntOrFloatBitWidth() == 0)
252 return StringAttr::get(ctx,
"");
254 SmallVector<char, 32> strBuf;
255 intAttr.getValue().toString(strBuf, radix, isSigned,
false, isUpperCase);
256 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
264 padWidth = (width + 2) / 3;
267 padWidth = (width + 3) / 4;
274 unsigned numSpaces = 0;
275 if (specifierWidth.has_value() &&
276 (specifierWidth.value() >
277 std::max(padWidth,
static_cast<unsigned>(strBuf.size())))) {
278 numSpaces = std::max(
279 0U, specifierWidth.value() -
280 std::max(padWidth,
static_cast<unsigned>(strBuf.size())));
283 SmallVector<char, 1> spacePadding(numSpaces,
' ');
285 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
287 SmallVector<char, 32> padding(padWidth, paddingChar);
289 return StringAttr::get(ctx, Twine(padding) + Twine(strBuf) +
290 Twine(spacePadding));
292 return StringAttr::get(ctx,
293 Twine(spacePadding) + Twine(padding) + Twine(strBuf));
298 std::optional<unsigned> fieldWidth,
299 std::optional<unsigned> fracDigits,
300 std::string formatSpecifier) {
301 if (
auto floatAttr = llvm::dyn_cast_or_null<FloatAttr>(value)) {
302 std::string widthString = isLeftAligned ?
"-" :
"";
303 if (fieldWidth.has_value()) {
304 widthString += std::to_string(fieldWidth.value());
306 std::string fmtSpecifier =
"%" + widthString +
"." +
307 std::to_string(fracDigits.value()) +
312 int bufferSize = std::snprintf(
nullptr, 0, fmtSpecifier.c_str(),
313 floatAttr.getValue().convertToDouble());
314 std::string floatFmtBuffer(bufferSize,
'\0');
315 snprintf(floatFmtBuffer.data(), bufferSize + 1, fmtSpecifier.c_str(),
316 floatAttr.getValue().convertToDouble());
317 return StringAttr::get(ctx, floatFmtBuffer);
325OpFoldResult FormatLiteralOp::fold(FoldAdaptor adaptor) {
326 return getLiteralAttr();
331StringAttr FormatStringOp::formatConstant(Attribute constVal) {
332 auto strAttr = llvm::dyn_cast<StringAttr>(constVal);
336 SmallString<128> strBuf(strAttr.getValue());
337 if (getSpecifierWidth().has_value()) {
338 auto padChar =
static_cast<char>(getPaddingChar());
339 unsigned padWidth = getSpecifierWidth().value();
340 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
341 if (getIsLeftAligned())
342 strBuf.append(padWidth, padChar);
344 strBuf.insert(strBuf.begin(), padWidth, padChar);
346 return StringAttr::get(getContext(), strBuf);
351StringAttr FormatDecOp::formatConstant(Attribute constVal) {
352 auto intAttr = llvm::dyn_cast<IntegerAttr>(constVal);
355 SmallVector<char, 16> strBuf;
356 intAttr.getValue().toString(strBuf, 10, getIsSigned());
358 if (getSpecifierWidth().has_value()) {
359 padWidth = getSpecifierWidth().value();
361 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
362 padWidth = FormatDecOp::getDecimalWidth(width, getIsSigned());
365 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
367 SmallVector<char, 10> padding(padWidth, getPaddingChar());
368 if (getIsLeftAligned())
369 return StringAttr::get(getContext(), Twine(strBuf) + Twine(padding));
370 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
373OpFoldResult FormatDecOp::fold(FoldAdaptor adaptor) {
374 if (getValue().getType().getIntOrFloatBitWidth() == 0)
375 return StringAttr::get(getContext(),
"0");
381StringAttr FormatHexOp::formatConstant(Attribute constVal) {
383 getIsHexUppercase(), getIsLeftAligned(),
384 getPaddingChar(), getSpecifierWidth());
387OpFoldResult FormatHexOp::fold(FoldAdaptor adaptor) {
388 if (getValue().getType().getIntOrFloatBitWidth() == 0)
390 getContext(), 16, IntegerAttr::get(getValue().getType(), 0),
false,
391 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
397StringAttr FormatOctOp::formatConstant(Attribute constVal) {
399 getIsLeftAligned(), getPaddingChar(),
400 getSpecifierWidth());
403OpFoldResult FormatOctOp::fold(FoldAdaptor adaptor) {
404 if (getValue().getType().getIntOrFloatBitWidth() == 0)
406 getContext(), 8, IntegerAttr::get(getValue().getType(), 0),
false,
407 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
413StringAttr FormatBinOp::formatConstant(Attribute constVal) {
415 getIsLeftAligned(), getPaddingChar(),
416 getSpecifierWidth());
419OpFoldResult FormatBinOp::fold(FoldAdaptor adaptor) {
420 if (getValue().getType().getIntOrFloatBitWidth() == 0)
422 getContext(), 2, IntegerAttr::get(getValue().getType(), 0),
false,
423 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
429StringAttr FormatScientificOp::formatConstant(Attribute constVal) {
431 getFieldWidth(), getFracDigits(),
"e");
436StringAttr FormatFloatOp::formatConstant(Attribute constVal) {
438 getFieldWidth(), getFracDigits(),
"f");
443StringAttr FormatGeneralOp::formatConstant(Attribute constVal) {
445 getFieldWidth(), getFracDigits(),
"g");
450StringAttr FormatCharOp::formatConstant(Attribute constVal) {
451 auto intCst = dyn_cast<IntegerAttr>(constVal);
454 if (intCst.getType().getIntOrFloatBitWidth() == 0)
455 return StringAttr::get(getContext(), Twine(
static_cast<char>(0)));
456 if (intCst.getType().getIntOrFloatBitWidth() > 8)
458 auto intValue = intCst.getValue().getZExtValue();
459 return StringAttr::get(getContext(), Twine(
static_cast<char>(intValue)));
462OpFoldResult FormatCharOp::fold(FoldAdaptor adaptor) {
463 if (getValue().getType().getIntOrFloatBitWidth() == 0)
464 return StringAttr::get(getContext(), Twine(
static_cast<char>(0)));
469 assert(!lits.empty() &&
"No literals to concatenate");
470 if (lits.size() == 1)
471 return StringAttr::get(ctxt, lits.front());
472 SmallString<64> newLit;
473 for (
auto lit : lits)
475 return StringAttr::get(ctxt, newLit);
478OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
479 if (getNumOperands() == 0)
480 return StringAttr::get(getContext(),
"");
481 if (getNumOperands() == 1) {
483 if (getResult() == getOperand(0))
485 return getOperand(0);
489 SmallVector<StringRef> lits;
490 for (
auto attr : adaptor.getInputs()) {
491 auto lit = dyn_cast_or_null<StringAttr>(attr);
499LogicalResult FormatStringConcatOp::getFlattenedInputs(
500 llvm::SmallVectorImpl<Value> &flatOperands) {
502 bool isCyclic =
false;
506 concatStack.insert({*
this, 0});
507 while (!concatStack.empty()) {
508 auto &top = concatStack.back();
509 auto currentConcat = top.first;
510 unsigned operandIndex = top.second;
513 while (operandIndex < currentConcat.getNumOperands()) {
514 auto currentOperand = currentConcat.getOperand(operandIndex);
516 if (
auto nextConcat =
517 currentOperand.getDefiningOp<FormatStringConcatOp>()) {
519 if (!concatStack.contains(nextConcat)) {
522 top.second = operandIndex + 1;
523 concatStack.insert({nextConcat, 0});
530 flatOperands.push_back(currentOperand);
535 if (operandIndex >= currentConcat.getNumOperands())
536 concatStack.pop_back();
539 return success(!isCyclic);
542LogicalResult FormatStringConcatOp::verify() {
543 if (llvm::any_of(getOperands(),
544 [&](Value operand) {
return operand == getResult(); }))
545 return emitOpError(
"is infinitely recursive.");
549LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
550 PatternRewriter &rewriter) {
552 rewriter.setInsertionPoint(op);
554 auto fmtStrType = FormatStringType::get(op.getContext());
557 bool hasBeenFlattened =
false;
558 SmallVector<Value, 0> flatOperands;
561 flatOperands.reserve(op.getNumOperands() + 4);
562 auto isAcyclic = op.getFlattenedInputs(flatOperands);
564 if (failed(isAcyclic)) {
567 op.emitWarning(
"Cyclic concatenation detected.");
571 hasBeenFlattened =
true;
574 if (!hasBeenFlattened && op.getNumOperands() < 2)
579 SmallVector<StringRef> litSequence;
580 SmallVector<Value> newOperands;
581 newOperands.reserve(op.getNumOperands());
582 FormatLiteralOp prevLitOp;
584 auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
585 for (
auto operand : oldOperands) {
586 if (
auto litOp = operand.getDefiningOp<FormatLiteralOp>()) {
587 if (!litOp.getLiteral().empty()) {
589 litSequence.push_back(litOp.getLiteral());
592 if (!litSequence.empty()) {
593 if (litSequence.size() > 1) {
595 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
596 op.getLoc(), fmtStrType,
598 newOperands.push_back(newLit);
601 newOperands.push_back(prevLitOp.getResult());
605 newOperands.push_back(operand);
610 if (!litSequence.empty()) {
611 if (litSequence.size() > 1) {
613 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
614 op.getLoc(), fmtStrType,
616 newOperands.push_back(newLit);
619 newOperands.push_back(prevLitOp.getResult());
623 if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
626 if (newOperands.empty())
627 rewriter.replaceOpWithNewOp<FormatLiteralOp>(op, fmtStrType,
628 rewriter.getStringAttr(
""));
629 else if (newOperands.size() == 1)
630 rewriter.replaceOp(op, newOperands);
632 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
637LogicalResult PrintFormattedOp::canonicalize(PrintFormattedOp op,
638 PatternRewriter &rewriter) {
640 if (
auto cstCond = op.getCondition().getDefiningOp<
hw::ConstantOp>()) {
641 if (cstCond.getValue().isZero()) {
642 rewriter.eraseOp(op);
649LogicalResult PrintFormattedProcOp::canonicalize(PrintFormattedProcOp op,
650 PatternRewriter &rewriter) {
652 if (
auto litInput = op.getInput().getDefiningOp<FormatLiteralOp>()) {
653 if (litInput.getLiteral().empty()) {
654 rewriter.eraseOp(op);
661OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
662 return adaptor.getLiteralAttr();
665OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
666 auto operands = adaptor.getInputs();
667 if (operands.empty())
668 return StringAttr::get(getContext(),
"");
670 SmallString<128> result;
671 for (
auto &operand : operands) {
672 auto strAttr = cast_if_present<StringAttr>(operand);
675 result += strAttr.getValue();
678 return StringAttr::get(getContext(), result);
681OpFoldResult StringLengthOp::fold(FoldAdaptor adaptor) {
682 auto inputAttr = adaptor.getInput();
686 if (
auto strAttr = cast<StringAttr>(inputAttr))
687 return IntegerAttr::get(getType(), strAttr.getValue().size());
692OpFoldResult IntToStringOp::fold(FoldAdaptor adaptor) {
693 auto intAttr = cast_or_null<IntegerAttr>(adaptor.getInput());
697 SmallString<128> result;
698 auto width = intAttr.getType().getIntOrFloatBitWidth();
703 for (
unsigned int i = 0; i < width; i += 8) {
705 intAttr.getValue().extractBitsAsZExtValue(std::min(width - i, 8U), i);
707 result.push_back(
static_cast<char>(
byte));
709 std::reverse(result.begin(), result.end());
710 return StringAttr::get(getContext(), result);
718OpFoldResult StringGetOp::fold(FoldAdaptor adaptor) {
719 auto strAttr = cast_or_null<StringAttr>(adaptor.getStr());
720 auto indexAttr = cast_or_null<IntegerAttr>(adaptor.getIndex());
721 if (!strAttr || !indexAttr)
724 auto str = strAttr.getValue();
725 int64_t index = indexAttr.getValue().getSExtValue();
728 if (index < 0 || index >=
static_cast<int64_t
>(str.size()))
729 return IntegerAttr::get(getType(), 0);
732 uint8_t ch =
static_cast<uint8_t
>(str[index]);
733 return IntegerAttr::get(getType(), ch);
740LogicalResult QueueResizeOp::verify() {
741 if (cast<QueueType>(getInput().getType()).getElementType() !=
742 cast<QueueType>(getResult().getType()).getElementType())
747LogicalResult QueueFromArrayOp::verify() {
748 auto queueElementType =
749 cast<QueueType>(getResult().getType()).getElementType();
751 auto arrayElementType =
752 cast<hw::ArrayType>(getInput().getType()).getElementType();
754 if (queueElementType != arrayElementType) {
755 return emitOpError() <<
"sim::Queue element type " << queueElementType
756 <<
" doesn't match hw::ArrayType element type "
763LogicalResult QueueConcatOp::verify() {
766 auto resultElType = cast<QueueType>(getResult().getType()).getElementType();
768 for (Value input : getInputs()) {
769 auto inpElType = cast<QueueType>(input.getType()).getElementType();
770 if (inpElType != resultElType) {
771 return emitOpError() <<
"sim::Queue element type " << inpElType
772 <<
" doesn't match result sim::Queue element type "
784void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
785 Value clock, Value condition) {
786 odsState.addOperands(clock);
788 odsState.addOperands(condition);
790 auto *region = odsState.addRegion();
791 region->push_back(
new Block());
794void TriggeredOp::build(OpBuilder &builder, OperationState &odsState,
795 Value clock, Value condition,
796 llvm::function_ref<
void()> bodyCtor) {
797 OpBuilder::InsertionGuard guard(builder);
799 odsState.addOperands(clock);
801 odsState.addOperands(condition);
803 builder.createBlock(odsState.addRegion());
812#include "circt/Dialect/Sim/SimOpInterfaces.cpp.inc"
815#define GET_OP_CLASSES
816#include "circt/Dialect/Sim/Sim.cpp.inc"
817#include "circt/Dialect/Sim/SimEnums.cpp.inc"
assert(baseType &&"element must be base type")
static StringAttr formatFloatsBySpecifier(MLIRContext *ctx, Attribute value, bool isLeftAligned, std::optional< unsigned > fieldWidth, std::optional< unsigned > fracDigits, std::string formatSpecifier)
static StringAttr formatIntegersByRadix(MLIRContext *ctx, unsigned radix, const Attribute &value, bool isUpperCase, bool isLeftAligned, char paddingChar, std::optional< unsigned > specifierWidth, bool isSigned=false)
static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef< StringRef > lits)
llvm::StringRef stringifyDPIDirectionKeyword(DPIDirection dir)
Return the keyword string for a DPIDirection (e.g. "in", "return").
std::optional< DPIDirection > parseDPIDirectionKeyword(llvm::StringRef keyword)
Parse a keyword string to a DPIDirection. Returns std::nullopt on failure.
bool isCallOperandDir(DPIDirection dir)
True if an argument with this direction is a call operand (input/inout/ref).
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.