Loading [MathJax]/extensions/tex2jax.js
CIRCT 21.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RTGOps.cpp
Go to the documentation of this file.
1//===- RTGOps.cpp - Implement the RTG operations --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the RTG ops.
10//
11//===----------------------------------------------------------------------===//
12
15#include "mlir/IR/Builders.h"
16#include "mlir/IR/DialectImplementation.h"
17#include "llvm/ADT/SmallString.h"
18
19using namespace mlir;
20using namespace circt;
21using namespace rtg;
22
23//===----------------------------------------------------------------------===//
24// ConstantOp
25//===----------------------------------------------------------------------===//
26
27LogicalResult
28ConstantOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
29 ValueRange operands, DictionaryAttr attributes,
30 OpaqueProperties properties, RegionRange regions,
31 SmallVectorImpl<Type> &inferredReturnTypes) {
32 inferredReturnTypes.push_back(
33 properties.as<Properties *>()->getValue().getType());
34 return success();
35}
36
37OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
38
39//===----------------------------------------------------------------------===//
40// SequenceOp
41//===----------------------------------------------------------------------===//
42
43LogicalResult SequenceOp::verifyRegions() {
44 if (TypeRange(getSequenceType().getElementTypes()) !=
45 getBody()->getArgumentTypes())
46 return emitOpError("sequence type does not match block argument types");
47
48 return success();
49}
50
51ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
52 // Parse the name as a symbol.
53 if (parser.parseSymbolName(
54 result.getOrAddProperties<SequenceOp::Properties>().sym_name))
55 return failure();
56
57 // Parse the function signature.
58 SmallVector<OpAsmParser::Argument> arguments;
59 if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
60 /*allowType=*/true, /*allowAttrs=*/true))
61 return failure();
62
63 SmallVector<Type> argTypes;
64 SmallVector<Location> argLocs;
65 argTypes.reserve(arguments.size());
66 argLocs.reserve(arguments.size());
67 for (auto &arg : arguments) {
68 argTypes.push_back(arg.type);
69 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
70 }
71 Type type = SequenceType::get(result.getContext(), argTypes);
72 result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
73 TypeAttr::get(type);
74
75 auto loc = parser.getCurrentLocation();
76 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
77 return failure();
78 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
79 return parser.emitError(loc)
80 << "'" << result.name.getStringRef() << "' op ";
81 })))
82 return failure();
83
84 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
85 if (parser.parseRegion(*bodyRegionRegion, arguments))
86 return failure();
87
88 if (bodyRegionRegion->empty()) {
89 bodyRegionRegion->emplaceBlock();
90 bodyRegionRegion->addArguments(argTypes, argLocs);
91 }
92 result.addRegion(std::move(bodyRegionRegion));
93
94 return success();
95}
96
97void SequenceOp::print(OpAsmPrinter &p) {
98 p << ' ';
99 p.printSymbolName(getSymNameAttr().getValue());
100 p << "(";
101 llvm::interleaveComma(getBody()->getArguments(), p,
102 [&](auto arg) { p.printRegionArgument(arg); });
103 p << ")";
104 p.printOptionalAttrDictWithKeyword(
105 (*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
106 p << ' ';
107 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
108}
109
110//===----------------------------------------------------------------------===//
111// GetSequenceOp
112//===----------------------------------------------------------------------===//
113
114LogicalResult
115GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
116 SequenceOp seq =
117 symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
118 if (!seq)
119 return emitOpError()
120 << "'" << getSequence()
121 << "' does not reference a valid 'rtg.sequence' operation";
122
123 if (seq.getSequenceType() != getType())
124 return emitOpError("referenced 'rtg.sequence' op's type does not match");
125
126 return success();
127}
128
129//===----------------------------------------------------------------------===//
130// SubstituteSequenceOp
131//===----------------------------------------------------------------------===//
132
133LogicalResult SubstituteSequenceOp::verify() {
134 if (getReplacements().empty())
135 return emitOpError("must at least have one replacement value");
136
137 if (getReplacements().size() >
138 getSequence().getType().getElementTypes().size())
139 return emitOpError(
140 "must not have more replacement values than sequence arguments");
141
142 if (getReplacements().getTypes() !=
143 getSequence().getType().getElementTypes().take_front(
144 getReplacements().size()))
145 return emitOpError("replacement types must match the same number of "
146 "sequence argument types from the front");
147
148 return success();
149}
150
151LogicalResult SubstituteSequenceOp::inferReturnTypes(
152 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
153 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
154 SmallVectorImpl<Type> &inferredReturnTypes) {
155 ArrayRef<Type> argTypes =
156 cast<SequenceType>(operands[0].getType()).getElementTypes();
157 auto seqType =
158 SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
159 inferredReturnTypes.push_back(seqType);
160 return success();
161}
162
163ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
164 ::mlir::OperationState &result) {
165 OpAsmParser::UnresolvedOperand sequenceRawOperand;
166 SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
167 Type sequenceRawType;
168
169 if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
170 return failure();
171
172 auto replacementsOperandsLoc = parser.getCurrentLocation();
173 if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
174 parser.parseColon() || parser.parseType(sequenceRawType) ||
175 parser.parseOptionalAttrDict(result.attributes))
176 return failure();
177
178 if (!isa<SequenceType>(sequenceRawType))
179 return parser.emitError(parser.getNameLoc())
180 << "'sequence' must be handle to a sequence or sequence family, but "
181 "got "
182 << sequenceRawType;
183
184 if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
185 result.operands))
186 return failure();
187
188 if (parser.resolveOperands(replacementsOperands,
189 cast<SequenceType>(sequenceRawType)
190 .getElementTypes()
191 .take_front(replacementsOperands.size()),
192 replacementsOperandsLoc, result.operands))
193 return failure();
194
195 SmallVector<Type> inferredReturnTypes;
196 if (failed(inferReturnTypes(
197 parser.getContext(), result.location, result.operands,
198 result.attributes.getDictionary(parser.getContext()),
199 result.getRawProperties(), result.regions, inferredReturnTypes)))
200 return failure();
201
202 result.addTypes(inferredReturnTypes);
203 return success();
204}
205
206void SubstituteSequenceOp::print(OpAsmPrinter &p) {
207 p << ' ' << getSequence() << "(" << getReplacements()
208 << ") : " << getSequence().getType();
209 p.printOptionalAttrDict((*this)->getAttrs(), {});
210}
211
212//===----------------------------------------------------------------------===//
213// InterleaveSequencesOp
214//===----------------------------------------------------------------------===//
215
216LogicalResult InterleaveSequencesOp::verify() {
217 if (getSequences().empty())
218 return emitOpError("must have at least one sequence in the list");
219
220 return success();
221}
222
223OpFoldResult InterleaveSequencesOp::fold(FoldAdaptor adaptor) {
224 if (getSequences().size() == 1)
225 return getSequences()[0];
226
227 return {};
228}
229
230//===----------------------------------------------------------------------===//
231// SetCreateOp
232//===----------------------------------------------------------------------===//
233
234ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) {
235 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
236 Type elemType;
237
238 if (parser.parseOperandList(operands) ||
239 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
240 parser.parseType(elemType))
241 return failure();
242
243 result.addTypes({SetType::get(result.getContext(), elemType)});
244
245 for (auto operand : operands)
246 if (parser.resolveOperand(operand, elemType, result.operands))
247 return failure();
248
249 return success();
250}
251
252void SetCreateOp::print(OpAsmPrinter &p) {
253 p << " ";
254 p.printOperands(getElements());
255 p.printOptionalAttrDict((*this)->getAttrs());
256 p << " : " << getSet().getType().getElementType();
257}
258
259LogicalResult SetCreateOp::verify() {
260 if (getElements().size() > 0) {
261 // We only need to check the first element because of the `SameTypeOperands`
262 // trait.
263 if (getElements()[0].getType() != getSet().getType().getElementType())
264 return emitOpError() << "operand types must match set element type";
265 }
266
267 return success();
268}
269
270//===----------------------------------------------------------------------===//
271// SetCartesianProductOp
272//===----------------------------------------------------------------------===//
273
274LogicalResult SetCartesianProductOp::inferReturnTypes(
275 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
276 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
277 SmallVectorImpl<Type> &inferredReturnTypes) {
278 if (operands.empty()) {
279 if (loc)
280 return mlir::emitError(*loc) << "at least one set must be provided";
281 return failure();
282 }
283
284 SmallVector<Type> elementTypes;
285 for (auto operand : operands)
286 elementTypes.push_back(cast<SetType>(operand.getType()).getElementType());
287 inferredReturnTypes.push_back(
288 SetType::get(TupleType::get(context, elementTypes)));
289 return success();
290}
291
292//===----------------------------------------------------------------------===//
293// BagCreateOp
294//===----------------------------------------------------------------------===//
295
296ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
297 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
298 multipleOperands;
299 Type elemType;
300
301 if (!parser.parseOptionalLParen()) {
302 while (true) {
303 OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
304 if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
305 parser.parseOperand(elementOperand))
306 return failure();
307
308 elementOperands.push_back(elementOperand);
309 multipleOperands.push_back(multipleOperand);
310
311 if (parser.parseOptionalComma()) {
312 if (parser.parseRParen())
313 return failure();
314 break;
315 }
316 }
317 }
318
319 if (parser.parseColon() || parser.parseType(elemType) ||
320 parser.parseOptionalAttrDict(result.attributes))
321 return failure();
322
323 result.addTypes({BagType::get(result.getContext(), elemType)});
324
325 for (auto operand : elementOperands)
326 if (parser.resolveOperand(operand, elemType, result.operands))
327 return failure();
328
329 for (auto operand : multipleOperands)
330 if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
331 result.operands))
332 return failure();
333
334 return success();
335}
336
337void BagCreateOp::print(OpAsmPrinter &p) {
338 p << " ";
339 if (!getElements().empty())
340 p << "(";
341 llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
342 [&](auto elAndMultiple) {
343 auto [el, multiple] = elAndMultiple;
344 p << multiple << " x " << el;
345 });
346 if (!getElements().empty())
347 p << ")";
348
349 p << " : " << getBag().getType().getElementType();
350 p.printOptionalAttrDict((*this)->getAttrs());
351}
352
353LogicalResult BagCreateOp::verify() {
354 if (!llvm::all_equal(getElements().getTypes()))
355 return emitOpError() << "types of all elements must match";
356
357 if (getElements().size() > 0)
358 if (getElements()[0].getType() != getBag().getType().getElementType())
359 return emitOpError() << "operand types must match bag element type";
360
361 return success();
362}
363
364//===----------------------------------------------------------------------===//
365// TupleCreateOp
366//===----------------------------------------------------------------------===//
367
368LogicalResult TupleCreateOp::inferReturnTypes(
369 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
370 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
371 SmallVectorImpl<Type> &inferredReturnTypes) {
372 if (operands.empty()) {
373 if (loc)
374 return mlir::emitError(*loc) << "empty tuples not allowed";
375 return failure();
376 }
377
378 SmallVector<Type> elementTypes;
379 for (auto operand : operands)
380 elementTypes.push_back(operand.getType());
381 inferredReturnTypes.push_back(TupleType::get(context, elementTypes));
382 return success();
383}
384
385//===----------------------------------------------------------------------===//
386// TupleExtractOp
387//===----------------------------------------------------------------------===//
388
389LogicalResult TupleExtractOp::inferReturnTypes(
390 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
391 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
392 SmallVectorImpl<Type> &inferredReturnTypes) {
393 assert(operands.size() == 1 && "must have exactly one operand");
394
395 auto tupleTy = dyn_cast<TupleType>(operands[0].getType());
396 size_t idx = properties.as<Properties *>()->getIndex().getInt();
397 if (!tupleTy || tupleTy.getTypes().size() <= idx) {
398 if (loc)
399 return mlir::emitError(*loc)
400 << "index (" << idx
401 << ") must be smaller than number of elements in tuple ("
402 << tupleTy.getTypes().size() << ")";
403 return failure();
404 }
405
406 inferredReturnTypes.push_back(tupleTy.getTypes()[idx]);
407 return success();
408}
409
410//===----------------------------------------------------------------------===//
411// FixedRegisterOp
412//===----------------------------------------------------------------------===//
413
414LogicalResult FixedRegisterOp::inferReturnTypes(
415 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
416 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
417 SmallVectorImpl<Type> &inferredReturnTypes) {
418 inferredReturnTypes.push_back(
419 properties.as<Properties *>()->getReg().getType());
420 return success();
421}
422
423OpFoldResult FixedRegisterOp::fold(FoldAdaptor adaptor) { return getRegAttr(); }
424
425//===----------------------------------------------------------------------===//
426// VirtualRegisterOp
427//===----------------------------------------------------------------------===//
428
429LogicalResult VirtualRegisterOp::verify() {
430 if (getAllowedRegs().empty())
431 return emitOpError("must have at least one allowed register");
432
433 if (llvm::any_of(getAllowedRegs(), [](Attribute attr) {
434 return !isa<RegisterAttrInterface>(attr);
435 }))
436 return emitOpError("all elements must be of RegisterAttrInterface");
437
438 if (!llvm::all_equal(
439 llvm::map_range(getAllowedRegs().getAsRange<RegisterAttrInterface>(),
440 [](auto attr) { return attr.getType(); })))
441 return emitOpError("all allowed registers must be of the same type");
442
443 return success();
444}
445
446LogicalResult VirtualRegisterOp::inferReturnTypes(
447 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
448 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
449 SmallVectorImpl<Type> &inferredReturnTypes) {
450 auto allowedRegs = properties.as<Properties *>()->getAllowedRegs();
451 if (allowedRegs.empty()) {
452 if (loc)
453 return mlir::emitError(*loc, "must have at least one allowed register");
454
455 return failure();
456 }
457
458 auto regAttr = dyn_cast<RegisterAttrInterface>(allowedRegs[0]);
459 if (!regAttr) {
460 if (loc)
461 return mlir::emitError(
462 *loc, "allowed register attributes must be of RegisterAttrInterface");
463
464 return failure();
465 }
466 inferredReturnTypes.push_back(regAttr.getType());
467 return success();
468}
469
470//===----------------------------------------------------------------------===//
471// ContextSwitchOp
472//===----------------------------------------------------------------------===//
473
474LogicalResult ContextSwitchOp::verify() {
475 auto elementTypes = getSequence().getType().getElementTypes();
476 if (elementTypes.size() != 3)
477 return emitOpError("sequence type must have exactly 3 element types");
478
479 if (getFrom().getType() != elementTypes[0])
480 return emitOpError(
481 "first sequence element type must match 'from' attribute type");
482
483 if (getTo().getType() != elementTypes[1])
484 return emitOpError(
485 "second sequence element type must match 'to' attribute type");
486
487 auto seqTy = dyn_cast<SequenceType>(elementTypes[2]);
488 if (!seqTy || !seqTy.getElementTypes().empty())
489 return emitOpError(
490 "third sequence element type must be a fully substituted sequence");
491
492 return success();
493}
494
495//===----------------------------------------------------------------------===//
496// TestOp
497//===----------------------------------------------------------------------===//
498
499LogicalResult TestOp::verifyRegions() {
500 if (!getTarget().entryTypesMatch(getBody()->getArgumentTypes()))
501 return emitOpError("argument types must match dict entry types");
502
503 return success();
504}
505
506ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
507 // Parse the name as a symbol.
508 if (parser.parseSymbolName(
509 result.getOrAddProperties<TestOp::Properties>().sym_name))
510 return failure();
511
512 // Parse the function signature.
513 SmallVector<OpAsmParser::Argument> arguments;
514 SmallVector<StringAttr> names;
515
516 auto parseOneArgument = [&]() -> ParseResult {
517 std::string name;
518 if (parser.parseKeywordOrString(&name) || parser.parseEqual() ||
519 parser.parseArgument(arguments.emplace_back(), /*allowType=*/true,
520 /*allowAttrs=*/true))
521 return failure();
522
523 names.push_back(StringAttr::get(result.getContext(), name));
524 return success();
525 };
526 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
527 parseOneArgument, " in argument list"))
528 return failure();
529
530 SmallVector<Type> argTypes;
531 SmallVector<DictEntry> entries;
532 SmallVector<Location> argLocs;
533 argTypes.reserve(arguments.size());
534 argLocs.reserve(arguments.size());
535 for (auto [name, arg] : llvm::zip(names, arguments)) {
536 argTypes.push_back(arg.type);
537 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
538 entries.push_back({name, arg.type});
539 }
540 auto emitError = [&]() -> InFlightDiagnostic {
541 return parser.emitError(parser.getCurrentLocation());
542 };
543 Type type = DictType::getChecked(emitError, result.getContext(),
544 ArrayRef<DictEntry>(entries));
545 if (!type)
546 return failure();
547 result.getOrAddProperties<TestOp::Properties>().target = TypeAttr::get(type);
548
549 auto loc = parser.getCurrentLocation();
550 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
551 return failure();
552 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
553 return parser.emitError(loc)
554 << "'" << result.name.getStringRef() << "' op ";
555 })))
556 return failure();
557
558 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
559 if (parser.parseRegion(*bodyRegionRegion, arguments))
560 return failure();
561
562 if (bodyRegionRegion->empty()) {
563 bodyRegionRegion->emplaceBlock();
564 bodyRegionRegion->addArguments(argTypes, argLocs);
565 }
566 result.addRegion(std::move(bodyRegionRegion));
567
568 return success();
569}
570
571void TestOp::print(OpAsmPrinter &p) {
572 p << ' ';
573 p.printSymbolName(getSymNameAttr().getValue());
574 p << "(";
575 SmallString<32> resultNameStr;
576 llvm::interleaveComma(
577 llvm::zip(getTarget().getEntries(), getBody()->getArguments()), p,
578 [&](auto entryAndArg) {
579 auto [entry, arg] = entryAndArg;
580 p << entry.name.getValue() << " = ";
581 p.printRegionArgument(arg);
582 });
583 p << ")";
584 p.printOptionalAttrDictWithKeyword(
585 (*this)->getAttrs(), {getSymNameAttrName(), getTargetAttrName()});
586 p << ' ';
587 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
588}
589
590void TestOp::getAsmBlockArgumentNames(Region &region,
591 OpAsmSetValueNameFn setNameFn) {
592 for (auto [entry, arg] :
593 llvm::zip(getTarget().getEntries(), region.getArguments()))
594 setNameFn(arg, entry.name.getValue());
595}
596
597//===----------------------------------------------------------------------===//
598// TargetOp
599//===----------------------------------------------------------------------===//
600
601LogicalResult TargetOp::verifyRegions() {
602 if (!getTarget().entryTypesMatch(
603 getBody()->getTerminator()->getOperandTypes()))
604 return emitOpError("terminator operand types must match dict entry types");
605
606 return success();
607}
608
609//===----------------------------------------------------------------------===//
610// ArrayCreateOp
611//===----------------------------------------------------------------------===//
612
613LogicalResult ArrayCreateOp::verify() {
614 if (!getElements().empty() &&
615 getElements()[0].getType() != getType().getElementType())
616 return emitOpError("operand types must match array element type, expected ")
617 << getType().getElementType() << " but got "
618 << getElements()[0].getType();
619
620 return success();
621}
622
623ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
624 SmallVector<OpAsmParser::UnresolvedOperand> operands;
625 Type elementType;
626
627 if (parser.parseOperandList(operands) || parser.parseColon() ||
628 parser.parseType(elementType) ||
629 parser.parseOptionalAttrDict(result.attributes))
630 return failure();
631
632 if (failed(parser.resolveOperands(operands, elementType, result.operands)))
633 return failure();
634
635 result.addTypes(ArrayType::get(elementType));
636
637 return success();
638}
639
640void ArrayCreateOp::print(OpAsmPrinter &p) {
641 p << ' ';
642 p.printOperands(getElements());
643 p << " : " << getType().getElementType();
644 p.printOptionalAttrDict((*this)->getAttrs(), {});
645}
646
647//===----------------------------------------------------------------------===//
648// TableGen generated logic.
649//===----------------------------------------------------------------------===//
650
651#define GET_OP_CLASSES
652#include "circt/Dialect/RTG/IR/RTG.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition rtg.py:1
Definition seq.py:1