CIRCT 21.0.0git
Loading...
Searching...
No Matches
SimOps.cpp
Go to the documentation of this file.
1//===- SimOps.cpp - Implement the Sim operations ------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements `sim` dialect ops.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/Dialect/Func/IR/FuncOps.h"
17#include "mlir/IR/PatternMatch.h"
18#include "mlir/Interfaces/FunctionImplementation.h"
19#include "llvm/ADT/MapVector.h"
20
21using namespace mlir;
22using namespace circt;
23using namespace sim;
24
25ParseResult DPIFuncOp::parse(OpAsmParser &parser, OperationState &result) {
26 auto builder = parser.getBuilder();
27 // Parse visibility.
28 (void)mlir::impl::parseOptionalVisibilityKeyword(parser, result.attributes);
29
30 // Parse the name as a symbol.
31 StringAttr nameAttr;
32 if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
33 result.attributes))
34 return failure();
35
36 SmallVector<hw::module_like_impl::PortParse> ports;
37 TypeAttr modType;
38 if (failed(
39 hw::module_like_impl::parseModuleSignature(parser, ports, modType)))
40 return failure();
41
42 result.addAttribute(DPIFuncOp::getModuleTypeAttrName(result.name), modType);
43
44 // Convert the specified array of dictionary attrs (which may have null
45 // entries) to an ArrayAttr of dictionaries.
46 auto unknownLoc = builder.getUnknownLoc();
47 SmallVector<Attribute> attrs, locs;
48 auto nonEmptyLocsFn = [unknownLoc](Attribute attr) {
49 return attr && cast<Location>(attr) != unknownLoc;
50 };
51
52 for (auto &port : ports) {
53 attrs.push_back(port.attrs ? port.attrs : builder.getDictionaryAttr({}));
54 locs.push_back(port.sourceLoc ? Location(*port.sourceLoc) : unknownLoc);
55 }
56
57 result.addAttribute(DPIFuncOp::getPerArgumentAttrsAttrName(result.name),
58 builder.getArrayAttr(attrs));
59 result.addRegion();
60
61 if (llvm::any_of(locs, nonEmptyLocsFn))
62 result.addAttribute(DPIFuncOp::getArgumentLocsAttrName(result.name),
63 builder.getArrayAttr(locs));
64
65 // Parse the attribute dict.
66 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
67 return failure();
68
69 return success();
70}
71
72LogicalResult
73sim::DPICallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
74 auto referencedOp =
75 symbolTable.lookupNearestSymbolFrom(*this, getCalleeAttr());
76 if (!referencedOp)
77 return emitError("cannot find function declaration '")
78 << getCallee() << "'";
79 if (isa<func::FuncOp, sim::DPIFuncOp>(referencedOp))
80 return success();
81 return emitError("callee must be 'sim.dpi.func' or 'func.func' but got '")
82 << referencedOp->getName() << "'";
83}
84
85void DPIFuncOp::print(OpAsmPrinter &p) {
86 DPIFuncOp op = *this;
87 // Print the operation and the function name.
88 auto funcName =
89 op->getAttrOfType<StringAttr>(SymbolTable::getSymbolAttrName())
90 .getValue();
91 p << ' ';
92
93 StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName();
94 if (auto visibility = op->getAttrOfType<StringAttr>(visibilityAttrName))
95 p << visibility.getValue() << ' ';
96 p.printSymbolName(funcName);
98 p, op->getRegion(0), op.getModuleType(),
99 getPerArgumentAttrsAttr()
100 ? ArrayRef<Attribute>(getPerArgumentAttrsAttr().getValue())
101 : ArrayRef<Attribute>{},
102 getArgumentLocs() ? SmallVector<Location>(
103 getArgumentLocs().value().getAsRange<Location>())
104 : ArrayRef<Location>{});
105
106 mlir::function_interface_impl::printFunctionAttributes(
107 p, op,
108 {visibilityAttrName, getModuleTypeAttrName(),
109 getPerArgumentAttrsAttrName(), getArgumentLocsAttrName()});
110}
111
112OpFoldResult FormatLitOp::fold(FoldAdaptor adaptor) { return getLiteralAttr(); }
113
114OpFoldResult FormatDecOp::fold(FoldAdaptor adaptor) {
115 if (getValue().getType() == IntegerType::get(getContext(), 0U))
116 return StringAttr::get(getContext(), "0");
117
118 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(adaptor.getValue())) {
119 SmallVector<char, 16> strBuf;
120 intAttr.getValue().toString(strBuf, 10U, getIsSigned());
121
122 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
123 unsigned padWidth = FormatDecOp::getDecimalWidth(width, getIsSigned());
124 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
125
126 SmallVector<char, 8> padding(padWidth, ' ');
127 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
128 }
129 return {};
130}
131
132OpFoldResult FormatHexOp::fold(FoldAdaptor adaptor) {
133 if (getValue().getType() == IntegerType::get(getContext(), 0U))
134 return StringAttr::get(getContext(), "");
135
136 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(adaptor.getValue())) {
137 SmallVector<char, 8> strBuf;
138 intAttr.getValue().toString(strBuf, 16U, /*Signed*/ false,
139 /*formatAsCLiteral*/ false,
140 /*UpperCase*/ false);
141
142 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
143 unsigned padWidth = width / 4;
144 if (width % 4 != 0)
145 padWidth++;
146 padWidth = padWidth > strBuf.size() ? padWidth - strBuf.size() : 0;
147
148 SmallVector<char, 8> padding(padWidth, '0');
149 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
150 }
151 return {};
152}
153
154OpFoldResult FormatBinOp::fold(FoldAdaptor adaptor) {
155 if (getValue().getType() == IntegerType::get(getContext(), 0U))
156 return StringAttr::get(getContext(), "");
157
158 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(adaptor.getValue())) {
159 SmallVector<char, 32> strBuf;
160 intAttr.getValue().toString(strBuf, 2U, false);
161
162 unsigned width = intAttr.getType().getIntOrFloatBitWidth();
163 unsigned padWidth = width > strBuf.size() ? width - strBuf.size() : 0;
164
165 SmallVector<char, 32> padding(padWidth, '0');
166 return StringAttr::get(getContext(), Twine(padding) + Twine(strBuf));
167 }
168 return {};
169}
170
171OpFoldResult FormatCharOp::fold(FoldAdaptor adaptor) {
172 auto width = getValue().getType().getIntOrFloatBitWidth();
173 if (width > 8)
174 return {};
175 if (width == 0)
176 return StringAttr::get(getContext(), Twine(static_cast<char>(0)));
177
178 if (auto intAttr = llvm::dyn_cast_or_null<IntegerAttr>(adaptor.getValue())) {
179 auto intValue = intAttr.getValue().getZExtValue();
180 return StringAttr::get(getContext(), Twine(static_cast<char>(intValue)));
181 }
182
183 return {};
184}
185
186static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef<StringRef> lits) {
187 assert(!lits.empty() && "No literals to concatenate");
188 if (lits.size() == 1)
189 return StringAttr::get(ctxt, lits.front());
190 SmallString<64> newLit;
191 for (auto lit : lits)
192 newLit += lit;
193 return StringAttr::get(ctxt, newLit);
194}
195
196OpFoldResult FormatStringConcatOp::fold(FoldAdaptor adaptor) {
197 if (getNumOperands() == 0)
198 return StringAttr::get(getContext(), "");
199 if (getNumOperands() == 1) {
200 // Don't fold to our own result to avoid an infinte loop.
201 if (getResult() == getOperand(0))
202 return {};
203 return getOperand(0);
204 }
205
206 // Fold if all operands are literals.
207 SmallVector<StringRef> lits;
208 for (auto attr : adaptor.getInputs()) {
209 auto lit = dyn_cast_or_null<StringAttr>(attr);
210 if (!lit)
211 return {};
212 lits.push_back(lit);
213 }
214 return concatLiterals(getContext(), lits);
215}
216
217LogicalResult FormatStringConcatOp::getFlattenedInputs(
218 llvm::SmallVectorImpl<Value> &flatOperands) {
219 llvm::SmallMapVector<FormatStringConcatOp, unsigned, 4> concatStack;
220 bool isCyclic = false;
221
222 // Perform a DFS on this operation's concatenated operands,
223 // collect the leaf format string fragments.
224 concatStack.insert({*this, 0});
225 while (!concatStack.empty()) {
226 auto &top = concatStack.back();
227 auto currentConcat = top.first;
228 unsigned operandIndex = top.second;
229
230 // Iterate over concatenated operands
231 while (operandIndex < currentConcat.getNumOperands()) {
232 auto currentOperand = currentConcat.getOperand(operandIndex);
233
234 if (auto nextConcat =
235 currentOperand.getDefiningOp<FormatStringConcatOp>()) {
236 // Concat of a concat
237 if (!concatStack.contains(nextConcat)) {
238 // Save the next operand index to visit on the
239 // stack and put the new concat on top.
240 top.second = operandIndex + 1;
241 concatStack.insert({nextConcat, 0});
242 break;
243 }
244 // Cyclic concatenation encountered. Don't recurse.
245 isCyclic = true;
246 }
247
248 flatOperands.push_back(currentOperand);
249 operandIndex++;
250 }
251
252 // Pop the concat off of the stack if we have visited all operands.
253 if (operandIndex >= currentConcat.getNumOperands())
254 concatStack.pop_back();
255 }
256
257 return success(!isCyclic);
258}
259
260LogicalResult FormatStringConcatOp::verify() {
261 if (llvm::any_of(getOperands(),
262 [&](Value operand) { return operand == getResult(); }))
263 return emitOpError("is infinitely recursive.");
264 return success();
265}
266
267LogicalResult FormatStringConcatOp::canonicalize(FormatStringConcatOp op,
268 PatternRewriter &rewriter) {
269
270 auto fmtStrType = FormatStringType::get(op.getContext());
271
272 // Check if we can flatten concats of concats
273 bool hasBeenFlattened = false;
274 SmallVector<Value, 0> flatOperands;
275 if (!op.isFlat()) {
276 // Get a new, flattened list of operands
277 flatOperands.reserve(op.getNumOperands() + 4);
278 auto isAcyclic = op.getFlattenedInputs(flatOperands);
279
280 if (failed(isAcyclic)) {
281 // Infinite recursion, but we cannot fail compilation right here (can we?)
282 // so just emit a warning and bail out.
283 op.emitWarning("Cyclic concatenation detected.");
284 return failure();
285 }
286
287 hasBeenFlattened = true;
288 }
289
290 if (!hasBeenFlattened && op.getNumOperands() < 2)
291 return failure(); // Should be handled by the folder
292
293 // Check if there are adjacent literals we can merge or empty literals to
294 // remove
295 SmallVector<StringRef> litSequence;
296 SmallVector<Value> newOperands;
297 newOperands.reserve(op.getNumOperands());
298 FormatLitOp prevLitOp;
299
300 auto oldOperands = hasBeenFlattened ? flatOperands : op.getOperands();
301 for (auto operand : oldOperands) {
302 if (auto litOp = operand.getDefiningOp<FormatLitOp>()) {
303 if (!litOp.getLiteral().empty()) {
304 prevLitOp = litOp;
305 litSequence.push_back(litOp.getLiteral());
306 }
307 } else {
308 if (!litSequence.empty()) {
309 if (litSequence.size() > 1) {
310 // Create a fused literal.
311 auto newLit = rewriter.createOrFold<FormatLitOp>(
312 op.getLoc(), fmtStrType,
313 concatLiterals(op.getContext(), litSequence));
314 newOperands.push_back(newLit);
315 } else {
316 // Reuse the existing literal.
317 newOperands.push_back(prevLitOp.getResult());
318 }
319 litSequence.clear();
320 }
321 newOperands.push_back(operand);
322 }
323 }
324
325 // Push trailing literals into the new operand list
326 if (!litSequence.empty()) {
327 if (litSequence.size() > 1) {
328 // Create a fused literal.
329 auto newLit = rewriter.createOrFold<FormatLitOp>(
330 op.getLoc(), fmtStrType,
331 concatLiterals(op.getContext(), litSequence));
332 newOperands.push_back(newLit);
333 } else {
334 // Reuse the existing literal.
335 newOperands.push_back(prevLitOp.getResult());
336 }
337 }
338
339 if (!hasBeenFlattened && newOperands.size() == op.getNumOperands())
340 return failure(); // Nothing changed
341
342 if (newOperands.empty())
343 rewriter.replaceOpWithNewOp<FormatLitOp>(op, fmtStrType,
344 rewriter.getStringAttr(""));
345 else if (newOperands.size() == 1)
346 rewriter.replaceOp(op, newOperands);
347 else
348 rewriter.modifyOpInPlace(op, [&]() { op->setOperands(newOperands); });
349
350 return success();
351}
352
353LogicalResult PrintFormattedOp::canonicalize(PrintFormattedOp op,
354 PatternRewriter &rewriter) {
355 // Remove ops with constant false condition.
356 if (auto cstCond = op.getCondition().getDefiningOp<hw::ConstantOp>()) {
357 if (cstCond.getValue().isZero()) {
358 rewriter.eraseOp(op);
359 return success();
360 }
361 }
362 return failure();
363}
364
365LogicalResult PrintFormattedProcOp::verify() {
366 // Check if we know for sure that the parent is not procedural.
367 auto *parentOp = getOperation()->getParentOp();
368
369 if (!parentOp)
370 return emitOpError("must be within a procedural region.");
371
372 if (isa<hw::HWDialect>(parentOp->getDialect())) {
373 if (!isa<hw::TriggeredOp>(parentOp))
374 return emitOpError("must be within a procedural region.");
375 return success();
376 }
377
378 if (isa<sv::SVDialect>(parentOp->getDialect())) {
379 if (!parentOp->hasTrait<sv::ProceduralRegion>())
380 return emitOpError("must be within a procedural region.");
381 return success();
382 }
383
384 // Don't fail for dialects that are not explicitly handled.
385 return success();
386}
387
388LogicalResult PrintFormattedProcOp::canonicalize(PrintFormattedProcOp op,
389 PatternRewriter &rewriter) {
390 // Remove empty prints.
391 if (auto litInput = op.getInput().getDefiningOp<FormatLitOp>()) {
392 if (litInput.getLiteral().empty()) {
393 rewriter.eraseOp(op);
394 return success();
395 }
396 }
397 return failure();
398}
399
400//===----------------------------------------------------------------------===//
401// TableGen generated logic.
402//===----------------------------------------------------------------------===//
403
404// Provide the autogenerated implementation guts for the Op classes.
405#define GET_OP_CLASSES
406#include "circt/Dialect/Sim/Sim.cpp.inc"
assert(baseType &&"element must be base type")
static StringAttr concatLiterals(MLIRContext *ctxt, ArrayRef< StringRef > lits)
Definition SimOps.cpp:186
Signals that an operations regions are procedural.
Definition SVOps.h:160
ParseResult parseModuleSignature(OpAsmParser &parser, SmallVectorImpl< PortParse > &args, TypeAttr &modType)
New Style parsing.
void printModuleSignatureNew(OpAsmPrinter &p, Region &body, hw::ModuleType modType, ArrayRef< Attribute > portAttrs, ArrayRef< Location > locAttrs)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.