CIRCT 21.0.0git
Loading...
Searching...
No Matches
RTGOps.cpp
Go to the documentation of this file.
1//===- RTGOps.cpp - Implement the RTG operations --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the RTG ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/DialectImplementation.h"
17#include "llvm/ADT/SmallString.h"
18
19using namespace mlir;
20using namespace circt;
21using namespace rtg;
22
23//===----------------------------------------------------------------------===//
24// SequenceOp
25//===----------------------------------------------------------------------===//
26
27LogicalResult SequenceOp::verifyRegions() {
28 if (TypeRange(getSequenceType().getElementTypes()) !=
29 getBody()->getArgumentTypes())
30 return emitOpError("sequence type does not match block argument types");
31
32 return success();
33}
34
35ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
36 // Parse the name as a symbol.
37 if (parser.parseSymbolName(
38 result.getOrAddProperties<SequenceOp::Properties>().sym_name))
39 return failure();
40
41 // Parse the function signature.
42 SmallVector<OpAsmParser::Argument> arguments;
43 if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
44 /*allowType=*/true, /*allowAttrs=*/true))
45 return failure();
46
47 SmallVector<Type> argTypes;
48 SmallVector<Location> argLocs;
49 argTypes.reserve(arguments.size());
50 argLocs.reserve(arguments.size());
51 for (auto &arg : arguments) {
52 argTypes.push_back(arg.type);
53 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
54 }
55 Type type = SequenceType::get(result.getContext(), argTypes);
56 result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
57 TypeAttr::get(type);
58
59 auto loc = parser.getCurrentLocation();
60 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
61 return failure();
62 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
63 return parser.emitError(loc)
64 << "'" << result.name.getStringRef() << "' op ";
65 })))
66 return failure();
67
68 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
69 if (parser.parseRegion(*bodyRegionRegion, arguments))
70 return failure();
71
72 if (bodyRegionRegion->empty()) {
73 bodyRegionRegion->emplaceBlock();
74 bodyRegionRegion->addArguments(argTypes, argLocs);
75 }
76 result.addRegion(std::move(bodyRegionRegion));
77
78 return success();
79}
80
81void SequenceOp::print(OpAsmPrinter &p) {
82 p << ' ';
83 p.printSymbolName(getSymNameAttr().getValue());
84 p << "(";
85 llvm::interleaveComma(getBody()->getArguments(), p,
86 [&](auto arg) { p.printRegionArgument(arg); });
87 p << ")";
88 p.printOptionalAttrDictWithKeyword(
89 (*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
90 p << ' ';
91 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
92}
93
94//===----------------------------------------------------------------------===//
95// GetSequenceOp
96//===----------------------------------------------------------------------===//
97
98LogicalResult
99GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
100 SequenceOp seq =
101 symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
102 if (!seq)
103 return emitOpError()
104 << "'" << getSequence()
105 << "' does not reference a valid 'rtg.sequence' operation";
106
107 if (seq.getSequenceType() != getType())
108 return emitOpError("referenced 'rtg.sequence' op's type does not match");
109
110 return success();
111}
112
113//===----------------------------------------------------------------------===//
114// SubstituteSequenceOp
115//===----------------------------------------------------------------------===//
116
117LogicalResult SubstituteSequenceOp::verify() {
118 if (getReplacements().empty())
119 return emitOpError("must at least have one replacement value");
120
121 if (getReplacements().size() >
122 getSequence().getType().getElementTypes().size())
123 return emitOpError(
124 "must not have more replacement values than sequence arguments");
125
126 if (getReplacements().getTypes() !=
127 getSequence().getType().getElementTypes().take_front(
128 getReplacements().size()))
129 return emitOpError("replacement types must match the same number of "
130 "sequence argument types from the front");
131
132 return success();
133}
134
135LogicalResult SubstituteSequenceOp::inferReturnTypes(
136 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
137 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
138 SmallVectorImpl<Type> &inferredReturnTypes) {
139 ArrayRef<Type> argTypes =
140 cast<SequenceType>(operands[0].getType()).getElementTypes();
141 auto seqType =
142 SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
143 inferredReturnTypes.push_back(seqType);
144 return success();
145}
146
147ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
148 ::mlir::OperationState &result) {
149 OpAsmParser::UnresolvedOperand sequenceRawOperand;
150 SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
151 Type sequenceRawType;
152
153 if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
154 return failure();
155
156 auto replacementsOperandsLoc = parser.getCurrentLocation();
157 if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
158 parser.parseColon() || parser.parseType(sequenceRawType) ||
159 parser.parseOptionalAttrDict(result.attributes))
160 return failure();
161
162 if (!isa<SequenceType>(sequenceRawType))
163 return parser.emitError(parser.getNameLoc())
164 << "'sequence' must be handle to a sequence or sequence family, but "
165 "got "
166 << sequenceRawType;
167
168 if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
169 result.operands))
170 return failure();
171
172 if (parser.resolveOperands(replacementsOperands,
173 cast<SequenceType>(sequenceRawType)
174 .getElementTypes()
175 .take_front(replacementsOperands.size()),
176 replacementsOperandsLoc, result.operands))
177 return failure();
178
179 SmallVector<Type> inferredReturnTypes;
180 if (failed(inferReturnTypes(
181 parser.getContext(), result.location, result.operands,
182 result.attributes.getDictionary(parser.getContext()),
183 result.getRawProperties(), result.regions, inferredReturnTypes)))
184 return failure();
185
186 result.addTypes(inferredReturnTypes);
187 return success();
188}
189
190void SubstituteSequenceOp::print(OpAsmPrinter &p) {
191 p << ' ' << getSequence() << "(" << getReplacements()
192 << ") : " << getSequence().getType();
193 p.printOptionalAttrDict((*this)->getAttrs(), {});
194}
195
196//===----------------------------------------------------------------------===//
197// InterleaveSequencesOp
198//===----------------------------------------------------------------------===//
199
200LogicalResult InterleaveSequencesOp::verify() {
201 if (getSequences().empty())
202 return emitOpError("must have at least one sequence in the list");
203
204 return success();
205}
206
207OpFoldResult InterleaveSequencesOp::fold(FoldAdaptor adaptor) {
208 if (getSequences().size() == 1)
209 return getSequences()[0];
210
211 return {};
212}
213
214//===----------------------------------------------------------------------===//
215// SetCreateOp
216//===----------------------------------------------------------------------===//
217
218ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) {
219 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
220 Type elemType;
221
222 if (parser.parseOperandList(operands) ||
223 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
224 parser.parseType(elemType))
225 return failure();
226
227 result.addTypes({SetType::get(result.getContext(), elemType)});
228
229 for (auto operand : operands)
230 if (parser.resolveOperand(operand, elemType, result.operands))
231 return failure();
232
233 return success();
234}
235
236void SetCreateOp::print(OpAsmPrinter &p) {
237 p << " ";
238 p.printOperands(getElements());
239 p.printOptionalAttrDict((*this)->getAttrs());
240 p << " : " << getSet().getType().getElementType();
241}
242
243LogicalResult SetCreateOp::verify() {
244 if (getElements().size() > 0) {
245 // We only need to check the first element because of the `SameTypeOperands`
246 // trait.
247 if (getElements()[0].getType() != getSet().getType().getElementType())
248 return emitOpError() << "operand types must match set element type";
249 }
250
251 return success();
252}
253
254//===----------------------------------------------------------------------===//
255// BagCreateOp
256//===----------------------------------------------------------------------===//
257
258ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
259 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
260 multipleOperands;
261 Type elemType;
262
263 if (!parser.parseOptionalLParen()) {
264 while (true) {
265 OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
266 if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
267 parser.parseOperand(elementOperand))
268 return failure();
269
270 elementOperands.push_back(elementOperand);
271 multipleOperands.push_back(multipleOperand);
272
273 if (parser.parseOptionalComma()) {
274 if (parser.parseRParen())
275 return failure();
276 break;
277 }
278 }
279 }
280
281 if (parser.parseColon() || parser.parseType(elemType) ||
282 parser.parseOptionalAttrDict(result.attributes))
283 return failure();
284
285 result.addTypes({BagType::get(result.getContext(), elemType)});
286
287 for (auto operand : elementOperands)
288 if (parser.resolveOperand(operand, elemType, result.operands))
289 return failure();
290
291 for (auto operand : multipleOperands)
292 if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
293 result.operands))
294 return failure();
295
296 return success();
297}
298
299void BagCreateOp::print(OpAsmPrinter &p) {
300 p << " ";
301 if (!getElements().empty())
302 p << "(";
303 llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
304 [&](auto elAndMultiple) {
305 auto [el, multiple] = elAndMultiple;
306 p << multiple << " x " << el;
307 });
308 if (!getElements().empty())
309 p << ")";
310
311 p << " : " << getBag().getType().getElementType();
312 p.printOptionalAttrDict((*this)->getAttrs());
313}
314
315LogicalResult BagCreateOp::verify() {
316 if (!llvm::all_equal(getElements().getTypes()))
317 return emitOpError() << "types of all elements must match";
318
319 if (getElements().size() > 0)
320 if (getElements()[0].getType() != getBag().getType().getElementType())
321 return emitOpError() << "operand types must match bag element type";
322
323 return success();
324}
325
326//===----------------------------------------------------------------------===//
327// FixedRegisterOp
328//===----------------------------------------------------------------------===//
329
330LogicalResult FixedRegisterOp::inferReturnTypes(
331 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
332 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
333 SmallVectorImpl<Type> &inferredReturnTypes) {
334 inferredReturnTypes.push_back(
335 properties.as<Properties *>()->getReg().getType());
336 return success();
337}
338
339OpFoldResult FixedRegisterOp::fold(FoldAdaptor adaptor) { return getRegAttr(); }
340
341//===----------------------------------------------------------------------===//
342// VirtualRegisterOp
343//===----------------------------------------------------------------------===//
344
345LogicalResult VirtualRegisterOp::verify() {
346 if (getAllowedRegs().empty())
347 return emitOpError("must have at least one allowed register");
348
349 if (llvm::any_of(getAllowedRegs(), [](Attribute attr) {
350 return !isa<RegisterAttrInterface>(attr);
351 }))
352 return emitOpError("all elements must be of RegisterAttrInterface");
353
354 if (!llvm::all_equal(
355 llvm::map_range(getAllowedRegs().getAsRange<RegisterAttrInterface>(),
356 [](auto attr) { return attr.getType(); })))
357 return emitOpError("all allowed registers must be of the same type");
358
359 return success();
360}
361
362LogicalResult VirtualRegisterOp::inferReturnTypes(
363 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
364 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
365 SmallVectorImpl<Type> &inferredReturnTypes) {
366 auto allowedRegs = properties.as<Properties *>()->getAllowedRegs();
367 if (allowedRegs.empty()) {
368 if (loc)
369 return mlir::emitError(*loc, "must have at least one allowed register");
370
371 return failure();
372 }
373
374 auto regAttr = dyn_cast<RegisterAttrInterface>(allowedRegs[0]);
375 if (!regAttr) {
376 if (loc)
377 return mlir::emitError(
378 *loc, "allowed register attributes must be of RegisterAttrInterface");
379
380 return failure();
381 }
382 inferredReturnTypes.push_back(regAttr.getType());
383 return success();
384}
385
386//===----------------------------------------------------------------------===//
387// ContextSwitchOp
388//===----------------------------------------------------------------------===//
389
390LogicalResult ContextSwitchOp::verify() {
391 auto elementTypes = getSequence().getType().getElementTypes();
392 if (elementTypes.size() != 3)
393 return emitOpError("sequence type must have exactly 3 element types");
394
395 if (getFrom().getType() != elementTypes[0])
396 return emitOpError(
397 "first sequence element type must match 'from' attribute type");
398
399 if (getTo().getType() != elementTypes[1])
400 return emitOpError(
401 "second sequence element type must match 'to' attribute type");
402
403 auto seqTy = dyn_cast<SequenceType>(elementTypes[2]);
404 if (!seqTy || !seqTy.getElementTypes().empty())
405 return emitOpError(
406 "third sequence element type must be a fully substituted sequence");
407
408 return success();
409}
410
411//===----------------------------------------------------------------------===//
412// TestOp
413//===----------------------------------------------------------------------===//
414
415LogicalResult TestOp::verifyRegions() {
416 if (!getTarget().entryTypesMatch(getBody()->getArgumentTypes()))
417 return emitOpError("argument types must match dict entry types");
418
419 return success();
420}
421
422ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
423 // Parse the name as a symbol.
424 if (parser.parseSymbolName(
425 result.getOrAddProperties<TestOp::Properties>().sym_name))
426 return failure();
427
428 // Parse the function signature.
429 SmallVector<OpAsmParser::Argument> arguments;
430 SmallVector<StringAttr> names;
431
432 auto parseOneArgument = [&]() -> ParseResult {
433 std::string name;
434 if (parser.parseKeywordOrString(&name) || parser.parseEqual() ||
435 parser.parseArgument(arguments.emplace_back(), /*allowType=*/true,
436 /*allowAttrs=*/true))
437 return failure();
438
439 names.push_back(StringAttr::get(result.getContext(), name));
440 return success();
441 };
442 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
443 parseOneArgument, " in argument list"))
444 return failure();
445
446 SmallVector<Type> argTypes;
447 SmallVector<DictEntry> entries;
448 SmallVector<Location> argLocs;
449 argTypes.reserve(arguments.size());
450 argLocs.reserve(arguments.size());
451 for (auto [name, arg] : llvm::zip(names, arguments)) {
452 argTypes.push_back(arg.type);
453 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
454 entries.push_back({name, arg.type});
455 }
456 auto emitError = [&]() -> InFlightDiagnostic {
457 return parser.emitError(parser.getCurrentLocation());
458 };
459 Type type = DictType::getChecked(emitError, result.getContext(),
460 ArrayRef<DictEntry>(entries));
461 if (!type)
462 return failure();
463 result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);
464
465 auto loc = parser.getCurrentLocation();
466 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
467 return failure();
468 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
469 return parser.emitError(loc)
470 << "'" << result.name.getStringRef() << "' op ";
471 })))
472 return failure();
473
474 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
475 if (parser.parseRegion(*bodyRegionRegion, arguments))
476 return failure();
477
478 if (bodyRegionRegion->empty()) {
479 bodyRegionRegion->emplaceBlock();
480 bodyRegionRegion->addArguments(argTypes, argLocs);
481 }
482 result.addRegion(std::move(bodyRegionRegion));
483
484 return success();
485}
486
487void TestOp::print(OpAsmPrinter &p) {
488 p << ' ';
489 p.printSymbolName(getSymNameAttr().getValue());
490 p << "(";
491 SmallString<32> resultNameStr;
492 llvm::interleaveComma(
493 llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
494 [&](auto entryAndArg) {
495 auto [entry, arg] = entryAndArg;
496 p << entry.name.getValue() << " = ";
497 p.printRegionArgument(arg);
498 });
499 p << ")";
500 p.printOptionalAttrDictWithKeyword(
501 (*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
502 p << ' ';
503 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
504}
505
506void TestOp::getAsmBlockArgumentNames(Region &region,
507 OpAsmSetValueNameFn setNameFn) {
508 for (auto [entry, arg] :
509 llvm::zip(getTarget().getEntries(), region.getArguments()))
510 setNameFn(arg, entry.name.getValue());
511}
512
513//===----------------------------------------------------------------------===//
514// TargetOp
515//===----------------------------------------------------------------------===//
516
517LogicalResult TargetOp::verifyRegions() {
518 if (!getTarget().entryTypesMatch(
519 getBody()->getTerminator()->getOperandTypes()))
520 return emitOpError("terminator operand types must match dict entry types");
521
522 return success();
523}
524
525//===----------------------------------------------------------------------===//
526// TableGen generated logic.
527//===----------------------------------------------------------------------===//
528
529#define GET_OP_CLASSES
530#include "circt/Dialect/RTG/IR/RTG.cpp.inc"
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:182
Definition rtg.py:1
Definition seq.py:1