CIRCT 22.0.0git
Loading...
Searching...
No Matches
SimOps.cpp
Go to the documentation of this file.
1//===- SimOps.cpp - Implement the Sim operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements `sim` dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/FunctionImplementation.h"
20#include "llvm/ADT/MapVector.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace sim;
25
26static StringAttr formatIntegersByRadix(MLIRContext *ctx, unsigned radix,
27 const Attribute &value,
28 bool isUpperCase, bool isLeftAligned,
29 char paddingChar,
30 std::optional<unsigned> specifierWidth,
31 bool isSigned = false) {
32 auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(value);
33 if (!intAttr)
34 return {};
35 if (intAttr.getType().getIntOrFloatBitWidth() == 0)
36 return StringAttr::get(ctx, "");
37
38 SmallVector<char, 32> strBuf;
39 intAttr.getValue().toString(strBuf, radix, isSigned, false, isUpperCase);
40 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
41
42 unsigned padWidth;
43 switch (radix) {
44 case 2:
45 padWidth = width;
46 break;
47 case 8:
48 padWidth = (width + 2) / 3;
49 break;
50 case 16:
51 padWidth = (width + 3) / 4;
52 break;
53 default:
54 padWidth = width;
55 break;
56 }
57
58 unsigned numSpaces = 0;
59 if (specifierWidth.has_value() &&
60 (specifierWidth.value() >
61 std::max(padWidth, static_cast<unsigned>(strBuf.size())))) {
62 numSpaces = std::max(
63 0U, specifierWidth.value() -
64 std::max(padWidth, static_cast<unsigned>(strBuf.size())));
65 }
66
67 SmallVector<char, 1> spacePadding(numSpaces, ' ');
68
69 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
70
71 SmallVector<char, 32> padding(padWidth, paddingChar);
72 if (isLeftAligned) {
73 return StringAttr::get(ctx, Twine(padding) + Twine(strBuf) +
74 Twine(spacePadding));
75 }
76 return StringAttr::get(ctx,
77 Twine(spacePadding) + Twine(padding) + Twine(strBuf));
78}
79
80static StringAttr formatFloatsBySpecifier(MLIRContext *ctx, Attribute value,
81 bool isLeftAligned,
82 std::optional<unsigned> fieldWidth,
83 std::optional<unsigned> fracDigits,
84 std::string formatSpecifier) {
85 if (auto floatAttr = llvm::dyn_cast_or_null<FloatAttr>(value)) {
86 std::string widthString = isLeftAligned ? "-" : "";
87 if (fieldWidth.has_value()) {
88 widthString += std::to_string(fieldWidth.value());
89 }
90 std::string fmtSpecifier = "%" + widthString + "." +
91 std::to_string(fracDigits.value()) +
92 formatSpecifier;
93
94 // Calculates number of bytes needed to store the format string
95 // excluding the null terminator
96 int bufferSize = std::snprintf(nullptr, 0, fmtSpecifier.c_str(),
97 floatAttr.getValue().convertToDouble());
98 std::string floatFmtBuffer(bufferSize, '\0');
99 snprintf(floatFmtBuffer.data(), bufferSize + 1, fmtSpecifier.c_str(),
100 floatAttr.getValue().convertToDouble());
101 return StringAttr::get(ctx, floatFmtBuffer);
102 }
103 return {};
104}
105
106ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {
107 auto builder = parser.getBuilder();
108 // Parse visibility.
109 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
110
111 // Parse the name as a symbol.
112 StringAttr nameAttr;
113 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
114 result.attributes))
115 return failure();
116
117 SmallVector<hw::module_like_impl::PortParse> ports;
118 TypeAttr modType;
119 if (failed(
120 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
121 return failure();
122
123 result.addAttribute(DPIFuncOp::getModuleTypeAttrName(result.name), modType);
124
125 // Convert the specified array of dictionary attrs (which may have null
126 // entries) to an ArrayAttr of dictionaries.
127 auto unknownLoc = builder.getUnknownLoc();
128 SmallVector<Attribute> attrs, locs;
129 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
130 return attr && cast<Location>(attr) != unknownLoc;
131 };
132
133 for (auto &port : ports) {
134 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
135 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
136 }
137
138 result.addAttribute(DPIFuncOp::getPerArgumentAttrsAttrName(result.name),
139 builder.getArrayAttr(attrs));
140 result.addRegion();
141
142 if (llvm::any_of(locs, nonEmptyLocsFn))
143 result.addAttribute(DPIFuncOp::getArgumentLocsAttrName(result.name),
144 builder.getArrayAttr(locs));
145
146 // Parse the attribute dict.
147 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
148 return failure();
149
150 return success();
151}
152
153LogicalResult
154sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
155 auto referencedOp =
156 symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr());
157 if (!referencedOp)
158 return emitError("cannot find function declaration '")
159 << getCallee() << "'";
160 if (isa<func::FuncOp, sim::DPIFuncOp>(referencedOp))
161 return success();
162 return emitError("callee must be 'sim.dpi.func' or 'func.func' but got '")
163 << referencedOp->getName() << "'";
164}
165
166void DPIFuncOp::print(OpAsmPrinter &p) {
167 DPIFuncOp op = *this;
168 // Print the operation and the function name.
169 auto funcName =
170 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
171 .getValue();
172 p << ' ';
173
174 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
175 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
176 p << visibility.getValue() << ' ';
177 p.printSymbolName(funcName);
179 p, op->getRegion(0), op.getModuleType(),
180 getPerArgumentAttrsAttr()
181 ? ArrayRef<Attribute>(getPerArgumentAttrsAttr().getValue())
182 : ArrayRef<Attribute>{},
183 getArgumentLocs() ? SmallVector<Location>(
184 getArgumentLocs().value().getAsRange<Location>())
185 : ArrayRef<Location>{});
186
187 mlir::function_interface_impl::printFunctionAttributes(
188 p, op,
189 {visibilityAttrName, getModuleTypeAttrName(),
190 getPerArgumentAttrsAttrName(), getArgumentLocsAttrName()});
191}
192
193OpFoldResult FormatLiteralOp::fold(FoldAdaptor adaptor) {
194 return getLiteralAttr();
195}
196
197// --- FormatDecOp ---
198
199StringAttr FormatDecOp::formatConstant(Attribute constVal) {
200 auto intAttr = llvm::dyn_cast<IntegerAttr>(constVal);
201 if (!intAttr)
202 return {};
203 SmallVector<char, 16> strBuf;
204 intAttr.getValue().toString(strBuf, 10, getIsSigned());
205 unsigned padWidth;
206 if (getSpecifierWidth().has_value()) {
207 padWidth = getSpecifierWidth().value();
208 } else {
209 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
210 padWidth = FormatDecOp::getDecimalWidth(width, getIsSigned());
211 }
212
213 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
214
215 SmallVector<char, 10> padding(padWidth, getPaddingChar());
216 if (getIsLeftAligned())
217 return StringAttr::get(getContext(), Twine(strBuf) + Twine(padding));
218 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
219}
220
221OpFoldResult FormatDecOp::fold(FoldAdaptor adaptor) {
222 if (getValue().getType().getIntOrFloatBitWidth() == 0)
223 return StringAttr::get(getContext(), "0");
224 return {};
225}
226
227// --- FormatHexOp ---
228
229StringAttr FormatHexOp::formatConstant(Attribute constVal) {
230 return formatIntegersByRadix(constVal.getContext(), 16, constVal,
231 getIsHexUppercase(), getIsLeftAligned(),
232 getPaddingChar(), getSpecifierWidth());
233}
234
235OpFoldResult FormatHexOp::fold(FoldAdaptor adaptor) {
236 if (getValue().getType().getIntOrFloatBitWidth() == 0)
238 getContext(), 16, IntegerAttr::get(getValue().getType(), 0), false,
239 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
240 return {};
241}
242
243// --- FormatOctOp ---
244
245StringAttr FormatOctOp::formatConstant(Attribute constVal) {
246 return formatIntegersByRadix(constVal.getContext(), 8, constVal, false,
247 getIsLeftAligned(), getPaddingChar(),
248 getSpecifierWidth());
249}
250
251OpFoldResult FormatOctOp::fold(FoldAdaptor adaptor) {
252 if (getValue().getType().getIntOrFloatBitWidth() == 0)
254 getContext(), 8, IntegerAttr::get(getValue().getType(), 0), false,
255 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
256 return {};
257}
258
259// --- FormatBinOp ---
260
261StringAttr FormatBinOp::formatConstant(Attribute constVal) {
262 return formatIntegersByRadix(constVal.getContext(), 2, constVal, false,
263 getIsLeftAligned(), getPaddingChar(),
264 getSpecifierWidth());
265}
266
267OpFoldResult FormatBinOp::fold(FoldAdaptor adaptor) {
268 if (getValue().getType().getIntOrFloatBitWidth() == 0)
270 getContext(), 2, IntegerAttr::get(getValue().getType(), 0), false,
271 getIsLeftAligned(), getPaddingChar(), getSpecifierWidth());
272 return {};
273}
274
275// --- FormatScientificOp ---
276
277StringAttr FormatScientificOp::formatConstant(Attribute constVal) {
278 return formatFloatsBySpecifier(getContext(), constVal, getIsLeftAligned(),
279 getFieldWidth(), getFracDigits(), "e");
280}
281
282// --- FormatFloatOp ---
283
284StringAttr FormatFloatOp::formatConstant(Attribute constVal) {
285 return formatFloatsBySpecifier(getContext(), constVal, getIsLeftAligned(),
286 getFieldWidth(), getFracDigits(), "f");
287}
288
289// --- FormatGeneralOp ---
290
291StringAttr FormatGeneralOp::formatConstant(Attribute constVal) {
292 return formatFloatsBySpecifier(getContext(), constVal, getIsLeftAligned(),
293 getFieldWidth(), getFracDigits(), "g");
294}
295
296// --- FormatCharOp ---
297
298StringAttr FormatCharOp::formatConstant(Attribute constVal) {
299 auto intCst = dyn_cast<IntegerAttr>(constVal);
300 if (!intCst)
301 return {};
302 if (intCst.getType().getIntOrFloatBitWidth() == 0)
303 return StringAttr::get(getContext(), Twine(static_cast<char>(0)));
304 if (intCst.getType().getIntOrFloatBitWidth() > 8)
305 return {};
306 auto intValue = intCst.getValue().getZExtValue();
307 return StringAttr::get(getContext(), Twine(static_cast<char>(intValue)));
308}
309
310OpFoldResult FormatCharOp::fold(FoldAdaptor adaptor) {
311 if (getValue().getType().getIntOrFloatBitWidth() == 0)
312 return StringAttr::get(getContext(), Twine(static_cast<char>(0)));
313 return {};
314}
315
316static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef<StringRef> lits) {
317 assert(!lits.empty() && "No literals to concatenate");
318 if (lits.size() == 1)
319 return StringAttr::get(ctxt, lits.front());
320 SmallString<64> newLit;
321 for (auto lit : lits)
322 newLit += lit;
323 return StringAttr::get(ctxt, newLit);
324}
325
326OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
327 if (getNumOperands() == 0)
328 return StringAttr::get(getContext(), "");
329 if (getNumOperands() == 1) {
330 // Don't fold to our own result to avoid an infinte loop.
331 if (getResult() == getOperand(0))
332 return {};
333 return getOperand(0);
334 }
335
336 // Fold if all operands are literals.
337 SmallVector<StringRef> lits;
338 for (auto attr : adaptor.getInputs()) {
339 auto lit = dyn_cast_or_null<StringAttr>(attr);
340 if (!lit)
341 return {};
342 lits.push_back(lit);
343 }
344 return concatLiterals(getContext(), lits);
345}
346
347LogicalResult FormatStringConcatOp::getFlattenedInputs(
348 llvm::SmallVectorImpl<Value> &flatOperands) {
349 llvm::SmallMapVector<FormatStringConcatOp, unsigned, 4> concatStack;
350 bool isCyclic = false;
351
352 // Perform a DFS on this operation's concatenated operands,
353 // collect the leaf format string fragments.
354 concatStack.insert({*this, 0});
355 while (!concatStack.empty()) {
356 auto &top = concatStack.back();
357 auto currentConcat = top.first;
358 unsigned operandIndex = top.second;
359
360 // Iterate over concatenated operands
361 while (operandIndex < currentConcat.getNumOperands()) {
362 auto currentOperand = currentConcat.getOperand(operandIndex);
363
364 if (auto nextConcat =
365 currentOperand.getDefiningOp<FormatStringConcatOp>()) {
366 // Concat of a concat
367 if (!concatStack.contains(nextConcat)) {
368 // Save the next operand index to visit on the
369 // stack and put the new concat on top.
370 top.second = operandIndex + 1;
371 concatStack.insert({nextConcat, 0});
372 break;
373 }
374 // Cyclic concatenation encountered. Don't recurse.
375 isCyclic = true;
376 }
377
378 flatOperands.push_back(currentOperand);
379 operandIndex++;
380 }
381
382 // Pop the concat off of the stack if we have visited all operands.
383 if (operandIndex >= currentConcat.getNumOperands())
384 concatStack.pop_back();
385 }
386
387 return success(!isCyclic);
388}
389
390LogicalResult FormatStringConcatOp::verify() {
391 if (llvm::any_of(getOperands(),
392 [&](Value operand) { return operand == getResult(); }))
393 return emitOpError("is infinitely recursive.");
394 return success();
395}
396
397LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
398 PatternRewriter &rewriter) {
399
400 auto fmtStrType = FormatStringType::get(op.getContext());
401
402 // Check if we can flatten concats of concats
403 bool hasBeenFlattened = false;
404 SmallVector<Value, 0> flatOperands;
405 if (!op.isFlat()) {
406 // Get a new, flattened list of operands
407 flatOperands.reserve(op.getNumOperands() + 4);
408 auto isAcyclic = op.getFlattenedInputs(flatOperands);
409
410 if (failed(isAcyclic)) {
411 // Infinite recursion, but we cannot fail compilation right here (can we?)
412 // so just emit a warning and bail out.
413 op.emitWarning("Cyclic concatenation detected.");
414 return failure();
415 }
416
417 hasBeenFlattened = true;
418 }
419
420 if (!hasBeenFlattened && op.getNumOperands() < 2)
421 return failure(); // Should be handled by the folder
422
423 // Check if there are adjacent literals we can merge or empty literals to
424 // remove
425 SmallVector<StringRef> litSequence;
426 SmallVector<Value> newOperands;
427 newOperands.reserve(op.getNumOperands());
428 FormatLiteralOp prevLitOp;
429
430 auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
431 for (auto operand : oldOperands) {
432 if (auto litOp = operand.getDefiningOp<FormatLiteralOp>()) {
433 if (!litOp.getLiteral().empty()) {
434 prevLitOp = litOp;
435 litSequence.push_back(litOp.getLiteral());
436 }
437 } else {
438 if (!litSequence.empty()) {
439 if (litSequence.size() > 1) {
440 // Create a fused literal.
441 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
442 op.getLoc(), fmtStrType,
443 concatLiterals(op.getContext(), litSequence));
444 newOperands.push_back(newLit);
445 } else {
446 // Reuse the existing literal.
447 newOperands.push_back(prevLitOp.getResult());
448 }
449 litSequence.clear();
450 }
451 newOperands.push_back(operand);
452 }
453 }
454
455 // Push trailing literals into the new operand list
456 if (!litSequence.empty()) {
457 if (litSequence.size() > 1) {
458 // Create a fused literal.
459 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
460 op.getLoc(), fmtStrType,
461 concatLiterals(op.getContext(), litSequence));
462 newOperands.push_back(newLit);
463 } else {
464 // Reuse the existing literal.
465 newOperands.push_back(prevLitOp.getResult());
466 }
467 }
468
469 if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
470 return failure(); // Nothing changed
471
472 if (newOperands.empty())
473 rewriter.replaceOpWithNewOp<FormatLiteralOp>(op, fmtStrType,
474 rewriter.getStringAttr(""));
475 else if (newOperands.size() == 1)
476 rewriter.replaceOp(op, newOperands);
477 else
478 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
479
480 return success();
481}
482
483LogicalResult PrintFormattedOp::canonicalize(PrintFormattedOp op,
484 PatternRewriter &rewriter) {
485 // Remove ops with constant false condition.
486 if (auto cstCond = op.getCondition().getDefiningOp<hw::ConstantOp>()) {
487 if (cstCond.getValue().isZero()) {
488 rewriter.eraseOp(op);
489 return success();
490 }
491 }
492 return failure();
493}
494
495LogicalResult PrintFormattedProcOp::verify() {
496 // Check if we know for sure that the parent is not procedural.
497 auto *parentOp = getOperation()->getParentOp();
498
499 if (!parentOp)
500 return emitOpError("must be within a procedural region.");
501
502 if (isa_and_nonnull<hw::HWDialect>(parentOp->getDialect())) {
503 if (!isa<hw::TriggeredOp>(parentOp))
504 return emitOpError("must be within a procedural region.");
505 return success();
506 }
507
508 if (isa_and_nonnull<sv::SVDialect>(parentOp->getDialect())) {
509 if (!parentOp->hasTrait<sv::ProceduralRegion>())
510 return emitOpError("must be within a procedural region.");
511 return success();
512 }
513
514 // Don't fail for dialects that are not explicitly handled.
515 return success();
516}
517
518LogicalResult PrintFormattedProcOp::canonicalize(PrintFormattedProcOp op,
519 PatternRewriter &rewriter) {
520 // Remove empty prints.
521 if (auto litInput = op.getInput().getDefiningOp<FormatLiteralOp>()) {
522 if (litInput.getLiteral().empty()) {
523 rewriter.eraseOp(op);
524 return success();
525 }
526 }
527 return failure();
528}
529
530OpFoldResult StringConstantOp::fold(FoldAdaptor adaptor) {
531 return adaptor.getLiteralAttr();
532}
533
534OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
535 auto operands = adaptor.getInputs();
536 if (operands.empty())
537 return StringAttr::get(getContext(), "");
538
539 SmallString<128> result;
540 for (auto &operand : operands) {
541 if (auto strAttr = cast<StringAttr>(operand))
542 result += strAttr.getValue();
543 else
544 return {};
545 }
546
547 return StringAttr::get(getContext(), result);
548}
549
550OpFoldResult StringLengthOp::fold(FoldAdaptor adaptor) {
551 auto inputAttr = adaptor.getInput();
552 if (!inputAttr)
553 return {};
554
555 if (auto strAttr = cast<StringAttr>(inputAttr))
556 return IntegerAttr::get(getType(), strAttr.getValue().size());
557
558 return {};
559}
560
561OpFoldResult IntToStringOp::fold(FoldAdaptor adaptor) {
562 auto intAttr = cast_or_null<IntegerAttr>(adaptor.getInput());
563 if (!intAttr)
564 return {};
565
566 SmallString<128> result;
567 auto width = intAttr.getType().getIntOrFloatBitWidth();
568 // Starting from the LSB, we extract the values byte-by-byte,
569 // and convert each non-null byte to a char
570
571 // For example 0x00_00_00_48_00_00_6C_6F would look like "Hlo"
572 for (unsigned int i = 0; i < width; i += 8) {
573 auto byte =
574 intAttr.getValue().extractBitsAsZExtValue(std::min(width - i, 8U), i);
575 if (byte)
576 result.push_back(static_cast<char>(byte));
577 }
578 std::reverse(result.begin(), result.end());
579 return StringAttr::get(getContext(), result);
580 return {};
581}
582
583//===----------------------------------------------------------------------===//
584// TableGen generated logic.
585//===----------------------------------------------------------------------===//
586
587#include "circt/Dialect/Sim/SimOpInterfaces.cpp.inc"
588
589// Provide the autogenerated implementation guts for the Op classes.
590#define GET_OP_CLASSES
591#include "circt/Dialect/Sim/Sim.cpp.inc"
assert(baseType &&"element must be base type")
static StringAttr formatFloatsBySpecifier(MLIRContext *ctx, Attribute value, bool isLeftAligned, std::optional< unsigned > fieldWidth, std::optional< unsigned > fracDigits, std::string formatSpecifier)
Definition SimOps.cpp:80
static StringAttr formatIntegersByRadix(MLIRContext *ctx, unsigned radix, const Attribute &value, bool isUpperCase, bool isLeftAligned, char paddingChar, std::optional< unsigned > specifierWidth, bool isSigned=false)
Definition SimOps.cpp:26
static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef< StringRef > lits)
Definition SimOps.cpp:316
Signals that an operation's regions are procedural.
Definition SVOps.h:176
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition sim.py:1