CIRCT 22.0.0git
Loading...
Searching...
No Matches
SimOps.cpp
Go to the documentation of this file.
1//===- SimOps.cpp - Implement the Sim operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements `sim` dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
17#include "mlir/Dialect/Func/IR/FuncOps.h"
18#include "mlir/IR/PatternMatch.h"
19#include "mlir/Interfaces/FunctionImplementation.h"
20#include "llvm/ADT/MapVector.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace sim;
25
26static StringAttr formatIntegersByRadix(MLIRContext *ctx, unsigned radix,
27 const Attribute &value,
28 bool isUpperCase, bool isLeftAligned,
29 char paddingChar,
30 std::optional<unsigned> specifierWidth,
31 bool isSigned = false) {
32
33 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(value)) {
34 SmallVector<char, 32> strBuf;
35 intAttr.getValue().toString(strBuf, radix, isSigned, false, isUpperCase);
36 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
37 unsigned padWidth;
38 switch (radix) {
39 case 2:
40 padWidth = width;
41 break;
42 case 8:
43 padWidth = (width + 2) / 3;
44 break;
45 case 16:
46 padWidth = (width + 3) / 4;
47 break;
48 default:
49 padWidth = width;
50 break;
51 }
52
53 unsigned numSpaces = 0;
54 if (specifierWidth.has_value() &&
55 (specifierWidth.value() >
56 std::max(padWidth, static_cast<unsigned>(strBuf.size())))) {
57 numSpaces = std::max(
58 0U, specifierWidth.value() -
59 std::max(padWidth, static_cast<unsigned>(strBuf.size())));
60 }
61
62 SmallVector<char, 1> spacePadding(numSpaces, ' ');
63
64 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
65
66 SmallVector<char, 32> padding(padWidth, paddingChar);
67 if (isLeftAligned) {
68 return StringAttr::get(ctx, Twine(padding) + Twine(strBuf) +
69 Twine(spacePadding));
70 }
71 return StringAttr::get(ctx, Twine(spacePadding) + Twine(padding) +
72 Twine(strBuf));
73 }
74 return {};
75}
76
77static StringAttr formatFloatsBySpecifier(MLIRContext *ctx, Attribute value,
78 bool isLeftAligned,
79 std::optional<unsigned> fieldWidth,
80 std::optional<unsigned> fracDigits,
81 std::string formatSpecifier) {
82 if (auto floatAttr = llvm::dyn_cast_or_null<FloatAttr>(value)) {
83 std::string widthString = isLeftAligned ? "-" : "";
84 if (fieldWidth.has_value()) {
85 widthString += std::to_string(fieldWidth.value());
86 }
87 std::string fmtSpecifier = "%" + widthString + "." +
88 std::to_string(fracDigits.value()) +
89 formatSpecifier;
90
91 // Calculates number of bytes needed to store the format string
92 // excluding the null terminator
93 int bufferSize = std::snprintf(nullptr, 0, fmtSpecifier.c_str(),
94 floatAttr.getValue().convertToDouble());
95 std::string floatFmtBuffer(bufferSize, '\0');
96 snprintf(floatFmtBuffer.data(), bufferSize + 1, fmtSpecifier.c_str(),
97 floatAttr.getValue().convertToDouble());
98 return StringAttr::get(ctx, floatFmtBuffer);
99 }
100 return {};
101}
102
103ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {
104 auto builder = parser.getBuilder();
105 // Parse visibility.
106 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
107
108 // Parse the name as a symbol.
109 StringAttr nameAttr;
110 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
111 result.attributes))
112 return failure();
113
114 SmallVector<hw::module_like_impl::PortParse> ports;
115 TypeAttr modType;
116 if (failed(
117 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
118 return failure();
119
120 result.addAttribute(DPIFuncOp::getModuleTypeAttrName(result.name), modType);
121
122 // Convert the specified array of dictionary attrs (which may have null
123 // entries) to an ArrayAttr of dictionaries.
124 auto unknownLoc = builder.getUnknownLoc();
125 SmallVector<Attribute> attrs, locs;
126 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
127 return attr && cast<Location>(attr) != unknownLoc;
128 };
129
130 for (auto &port : ports) {
131 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
132 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
133 }
134
135 result.addAttribute(DPIFuncOp::getPerArgumentAttrsAttrName(result.name),
136 builder.getArrayAttr(attrs));
137 result.addRegion();
138
139 if (llvm::any_of(locs, nonEmptyLocsFn))
140 result.addAttribute(DPIFuncOp::getArgumentLocsAttrName(result.name),
141 builder.getArrayAttr(locs));
142
143 // Parse the attribute dict.
144 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
145 return failure();
146
147 return success();
148}
149
150LogicalResult
151sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
152 auto referencedOp =
153 symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr());
154 if (!referencedOp)
155 return emitError("cannot find function declaration '")
156 << getCallee() << "'";
157 if (isa<func::FuncOp, sim::DPIFuncOp>(referencedOp))
158 return success();
159 return emitError("callee must be 'sim.dpi.func' or 'func.func' but got '")
160 << referencedOp->getName() << "'";
161}
162
163void DPIFuncOp::print(OpAsmPrinter &p) {
164 DPIFuncOp op = *this;
165 // Print the operation and the function name.
166 auto funcName =
167 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
168 .getValue();
169 p << ' ';
170
171 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
172 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
173 p << visibility.getValue() << ' ';
174 p.printSymbolName(funcName);
176 p, op->getRegion(0), op.getModuleType(),
177 getPerArgumentAttrsAttr()
178 ? ArrayRef<Attribute>(getPerArgumentAttrsAttr().getValue())
179 : ArrayRef<Attribute>{},
180 getArgumentLocs() ? SmallVector<Location>(
181 getArgumentLocs().value().getAsRange<Location>())
182 : ArrayRef<Location>{});
183
184 mlir::function_interface_impl::printFunctionAttributes(
185 p, op,
186 {visibilityAttrName, getModuleTypeAttrName(),
187 getPerArgumentAttrsAttrName(), getArgumentLocsAttrName()});
188}
189
190OpFoldResult FormatLiteralOp::fold(FoldAdaptor adaptor) {
191 return getLiteralAttr();
192}
193
194OpFoldResult FormatDecOp::fold(FoldAdaptor adaptor) {
195 if (getValue().getType() == IntegerType::get(getContext(), 0U))
196 return StringAttr::get(getContext(), "0");
197
198 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(adaptor.getValue())) {
199 SmallVector<char, 16> strBuf;
200 intAttr.getValue().toString(strBuf, 10U, adaptor.getIsSigned());
201 unsigned padWidth;
202 if (adaptor.getSpecifierWidth().has_value()) {
203 padWidth = adaptor.getSpecifierWidth().value();
204 } else {
205 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
206 padWidth = FormatDecOp::getDecimalWidth(width, adaptor.getIsSigned());
207 }
208
209 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
210
211 SmallVector<char, 10> padding(padWidth, adaptor.getPaddingChar());
212 if (adaptor.getIsLeftAligned()) {
213 return StringAttr::get(getContext(), Twine(strBuf) + Twine(padding));
214 }
215 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
216 }
217 return {};
218}
219
220OpFoldResult FormatHexOp::fold(FoldAdaptor adaptor) {
221 if (getValue().getType() == IntegerType::get(getContext(), 0U))
222 return StringAttr::get(getContext(), "");
223
225 getContext(), 16U, adaptor.getValue(), adaptor.getIsHexUppercase(),
226 adaptor.getIsLeftAligned(), adaptor.getPaddingChar(),
227 adaptor.getSpecifierWidth());
228}
229
230OpFoldResult FormatOctOp::fold(FoldAdaptor adaptor) {
231 if (getValue().getType() == IntegerType::get(getContext(), 0U))
232 return StringAttr::get(getContext(), "");
233
235 getContext(), 8U, adaptor.getValue(), false, adaptor.getIsLeftAligned(),
236 adaptor.getPaddingChar(), adaptor.getSpecifierWidth());
237}
238
239OpFoldResult FormatBinOp::fold(FoldAdaptor adaptor) {
240 if (getValue().getType() == IntegerType::get(getContext(), 0U))
241 return StringAttr::get(getContext(), "");
242
244 getContext(), 2U, adaptor.getValue(), false, adaptor.getIsLeftAligned(),
245 adaptor.getPaddingChar(), adaptor.getSpecifierWidth());
246}
247
248OpFoldResult FormatScientificOp::fold(FoldAdaptor adaptor) {
250 getContext(), adaptor.getValue(), adaptor.getIsLeftAligned(),
251 adaptor.getFieldWidth(), adaptor.getFracDigits(), "e");
252}
253
254OpFoldResult FormatFloatOp::fold(FoldAdaptor adaptor) {
256 getContext(), adaptor.getValue(), adaptor.getIsLeftAligned(),
257 adaptor.getFieldWidth(), adaptor.getFracDigits(), "f");
258}
259
260OpFoldResult FormatGeneralOp::fold(FoldAdaptor adaptor) {
262 getContext(), adaptor.getValue(), adaptor.getIsLeftAligned(),
263 adaptor.getFieldWidth(), adaptor.getFracDigits(), "g");
264}
265
266OpFoldResult FormatCharOp::fold(FoldAdaptor adaptor) {
267 auto width = getValue().getType().getIntOrFloatBitWidth();
268 if (width > 8)
269 return {};
270 if (width == 0)
271 return StringAttr::get(getContext(), Twine(static_cast<char>(0)));
272
273 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(adaptor.getValue())) {
274 auto intValue = intAttr.getValue().getZExtValue();
275 return StringAttr::get(getContext(), Twine(static_cast<char>(intValue)));
276 }
277
278 return {};
279}
280
281static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef<StringRef> lits) {
282 assert(!lits.empty() && "No literals to concatenate");
283 if (lits.size() == 1)
284 return StringAttr::get(ctxt, lits.front());
285 SmallString<64> newLit;
286 for (auto lit : lits)
287 newLit += lit;
288 return StringAttr::get(ctxt, newLit);
289}
290
291OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
292 if (getNumOperands() == 0)
293 return StringAttr::get(getContext(), "");
294 if (getNumOperands() == 1) {
295 // Don't fold to our own result to avoid an infinte loop.
296 if (getResult() == getOperand(0))
297 return {};
298 return getOperand(0);
299 }
300
301 // Fold if all operands are literals.
302 SmallVector<StringRef> lits;
303 for (auto attr : adaptor.getInputs()) {
304 auto lit = dyn_cast_or_null<StringAttr>(attr);
305 if (!lit)
306 return {};
307 lits.push_back(lit);
308 }
309 return concatLiterals(getContext(), lits);
310}
311
312LogicalResult FormatStringConcatOp::getFlattenedInputs(
313 llvm::SmallVectorImpl<Value> &flatOperands) {
314 llvm::SmallMapVector<FormatStringConcatOp, unsigned, 4> concatStack;
315 bool isCyclic = false;
316
317 // Perform a DFS on this operation's concatenated operands,
318 // collect the leaf format string fragments.
319 concatStack.insert({*this, 0});
320 while (!concatStack.empty()) {
321 auto &top = concatStack.back();
322 auto currentConcat = top.first;
323 unsigned operandIndex = top.second;
324
325 // Iterate over concatenated operands
326 while (operandIndex < currentConcat.getNumOperands()) {
327 auto currentOperand = currentConcat.getOperand(operandIndex);
328
329 if (auto nextConcat =
330 currentOperand.getDefiningOp<FormatStringConcatOp>()) {
331 // Concat of a concat
332 if (!concatStack.contains(nextConcat)) {
333 // Save the next operand index to visit on the
334 // stack and put the new concat on top.
335 top.second = operandIndex + 1;
336 concatStack.insert({nextConcat, 0});
337 break;
338 }
339 // Cyclic concatenation encountered. Don't recurse.
340 isCyclic = true;
341 }
342
343 flatOperands.push_back(currentOperand);
344 operandIndex++;
345 }
346
347 // Pop the concat off of the stack if we have visited all operands.
348 if (operandIndex >= currentConcat.getNumOperands())
349 concatStack.pop_back();
350 }
351
352 return success(!isCyclic);
353}
354
355LogicalResult FormatStringConcatOp::verify() {
356 if (llvm::any_of(getOperands(),
357 [&](Value operand) { return operand == getResult(); }))
358 return emitOpError("is infinitely recursive.");
359 return success();
360}
361
362LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
363 PatternRewriter &rewriter) {
364
365 auto fmtStrType = FormatStringType::get(op.getContext());
366
367 // Check if we can flatten concats of concats
368 bool hasBeenFlattened = false;
369 SmallVector<Value, 0> flatOperands;
370 if (!op.isFlat()) {
371 // Get a new, flattened list of operands
372 flatOperands.reserve(op.getNumOperands() + 4);
373 auto isAcyclic = op.getFlattenedInputs(flatOperands);
374
375 if (failed(isAcyclic)) {
376 // Infinite recursion, but we cannot fail compilation right here (can we?)
377 // so just emit a warning and bail out.
378 op.emitWarning("Cyclic concatenation detected.");
379 return failure();
380 }
381
382 hasBeenFlattened = true;
383 }
384
385 if (!hasBeenFlattened && op.getNumOperands() < 2)
386 return failure(); // Should be handled by the folder
387
388 // Check if there are adjacent literals we can merge or empty literals to
389 // remove
390 SmallVector<StringRef> litSequence;
391 SmallVector<Value> newOperands;
392 newOperands.reserve(op.getNumOperands());
393 FormatLiteralOp prevLitOp;
394
395 auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
396 for (auto operand : oldOperands) {
397 if (auto litOp = operand.getDefiningOp<FormatLiteralOp>()) {
398 if (!litOp.getLiteral().empty()) {
399 prevLitOp = litOp;
400 litSequence.push_back(litOp.getLiteral());
401 }
402 } else {
403 if (!litSequence.empty()) {
404 if (litSequence.size() > 1) {
405 // Create a fused literal.
406 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
407 op.getLoc(), fmtStrType,
408 concatLiterals(op.getContext(), litSequence));
409 newOperands.push_back(newLit);
410 } else {
411 // Reuse the existing literal.
412 newOperands.push_back(prevLitOp.getResult());
413 }
414 litSequence.clear();
415 }
416 newOperands.push_back(operand);
417 }
418 }
419
420 // Push trailing literals into the new operand list
421 if (!litSequence.empty()) {
422 if (litSequence.size() > 1) {
423 // Create a fused literal.
424 auto newLit = rewriter.createOrFold<FormatLiteralOp>(
425 op.getLoc(), fmtStrType,
426 concatLiterals(op.getContext(), litSequence));
427 newOperands.push_back(newLit);
428 } else {
429 // Reuse the existing literal.
430 newOperands.push_back(prevLitOp.getResult());
431 }
432 }
433
434 if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
435 return failure(); // Nothing changed
436
437 if (newOperands.empty())
438 rewriter.replaceOpWithNewOp<FormatLiteralOp>(op, fmtStrType,
439 rewriter.getStringAttr(""));
440 else if (newOperands.size() == 1)
441 rewriter.replaceOp(op, newOperands);
442 else
443 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
444
445 return success();
446}
447
448LogicalResult PrintFormattedOp::canonicalize(PrintFormattedOp op,
449 PatternRewriter &rewriter) {
450 // Remove ops with constant false condition.
451 if (auto cstCond = op.getCondition().getDefiningOp<hw::ConstantOp>()) {
452 if (cstCond.getValue().isZero()) {
453 rewriter.eraseOp(op);
454 return success();
455 }
456 }
457 return failure();
458}
459
460LogicalResult PrintFormattedProcOp::verify() {
461 // Check if we know for sure that the parent is not procedural.
462 auto *parentOp = getOperation()->getParentOp();
463
464 if (!parentOp)
465 return emitOpError("must be within a procedural region.");
466
467 if (isa_and_nonnull<hw::HWDialect>(parentOp->getDialect())) {
468 if (!isa<hw::TriggeredOp>(parentOp))
469 return emitOpError("must be within a procedural region.");
470 return success();
471 }
472
473 if (isa_and_nonnull<sv::SVDialect>(parentOp->getDialect())) {
474 if (!parentOp->hasTrait<sv::ProceduralRegion>())
475 return emitOpError("must be within a procedural region.");
476 return success();
477 }
478
479 // Don't fail for dialects that are not explicitly handled.
480 return success();
481}
482
483LogicalResult PrintFormattedProcOp::canonicalize(PrintFormattedProcOp op,
484 PatternRewriter &rewriter) {
485 // Remove empty prints.
486 if (auto litInput = op.getInput().getDefiningOp<FormatLiteralOp>()) {
487 if (litInput.getLiteral().empty()) {
488 rewriter.eraseOp(op);
489 return success();
490 }
491 }
492 return failure();
493}
494
495//===----------------------------------------------------------------------===//
496// TableGen generated logic.
497//===----------------------------------------------------------------------===//
498
499// Provide the autogenerated implementation guts for the Op classes.
500#define GET_OP_CLASSES
501#include "circt/Dialect/Sim/Sim.cpp.inc"
assert(baseType &&"element must be base type")
static StringAttr formatFloatsBySpecifier(MLIRContext *ctx, Attribute value, bool isLeftAligned, std::optional< unsigned > fieldWidth, std::optional< unsigned > fracDigits, std::string formatSpecifier)
Definition SimOps.cpp:77
static StringAttr formatIntegersByRadix(MLIRContext *ctx, unsigned radix, const Attribute &value, bool isUpperCase, bool isLeftAligned, char paddingChar, std::optional< unsigned > specifierWidth, bool isSigned=false)
Definition SimOps.cpp:26
static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef< StringRef > lits)
Definition SimOps.cpp:281
Signals that an operations regions are procedural.
Definition SVOps.h:176
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition sim.py:1