CIRCT 22.0.0git
Loading...
Searching...
No Matches
RTGOps.cpp
Go to the documentation of this file.
1//===- RTGOps.cpp - Implement the RTG operations --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the RTG ops.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/DialectImplementation.h"
18#include "llvm/ADT/SmallString.h"
19
20using namespace mlir;
21using namespace circt;
22using namespace rtg;
23
24//===----------------------------------------------------------------------===//
25// ConstantOp
26//===----------------------------------------------------------------------===//
27
28LogicalResult
29ConstantOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
30 ValueRange operands, DictionaryAttr attributes,
31 OpaqueProperties properties, RegionRange regions,
32 SmallVectorImpl<Type> &inferredReturnTypes) {
33 inferredReturnTypes.push_back(
34 properties.as<Properties *>()->getValue().getType());
35 return success();
36}
37
38OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
39
40void ConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
41 if (auto reg = dyn_cast<rtg::RegisterAttrInterface>(getValueAttr())) {
42 setNameFn(getResult(), reg.getRegisterAssembly());
43 return;
44 }
45}
46
47//===----------------------------------------------------------------------===//
48// SequenceOp
49//===----------------------------------------------------------------------===//
50
51LogicalResult SequenceOp::verifyRegions() {
52 if (TypeRange(getSequenceType().getElementTypes()) !=
53 getBody()->getArgumentTypes())
54 return emitOpError("sequence type does not match block argument types");
55
56 return success();
57}
58
59ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
60 // Parse the name as a symbol.
61 if (parser.parseSymbolName(
62 result.getOrAddProperties<SequenceOp::Properties>().sym_name))
63 return failure();
64
65 // Parse the function signature.
66 SmallVector<OpAsmParser::Argument> arguments;
67 if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
68 /*allowType=*/true, /*allowAttrs=*/true))
69 return failure();
70
71 SmallVector<Type> argTypes;
72 SmallVector<Location> argLocs;
73 argTypes.reserve(arguments.size());
74 argLocs.reserve(arguments.size());
75 for (auto &arg : arguments) {
76 argTypes.push_back(arg.type);
77 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
78 }
79 Type type = SequenceType::get(result.getContext(), argTypes);
80 result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
81 TypeAttr::get(type);
82
83 auto loc = parser.getCurrentLocation();
84 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
85 return failure();
86 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
87 return parser.emitError(loc)
88 << "'" << result.name.getStringRef() << "' op ";
89 })))
90 return failure();
91
92 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
93 if (parser.parseRegion(*bodyRegionRegion, arguments))
94 return failure();
95
96 if (bodyRegionRegion->empty()) {
97 bodyRegionRegion->emplaceBlock();
98 bodyRegionRegion->addArguments(argTypes, argLocs);
99 }
100 result.addRegion(std::move(bodyRegionRegion));
101
102 return success();
103}
104
105void SequenceOp::print(OpAsmPrinter &p) {
106 p << ' ';
107 p.printSymbolName(getSymNameAttr().getValue());
108 p << "(";
109 llvm::interleaveComma(getBody()->getArguments(), p,
110 [&](auto arg) { p.printRegionArgument(arg); });
111 p << ")";
112 p.printOptionalAttrDictWithKeyword(
113 (*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
114 p << ' ';
115 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
116}
117
118mlir::SymbolTable::Visibility SequenceOp::getVisibility() {
119 return mlir::SymbolTable::Visibility::Private;
120}
121
122void SequenceOp::setVisibility(mlir::SymbolTable::Visibility visibility) {
123 // Do nothing, always private.
124 assert(false && "cannot change visibility of sequence");
125}
126
127//===----------------------------------------------------------------------===//
128// GetSequenceOp
129//===----------------------------------------------------------------------===//
130
131LogicalResult
132GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
133 SequenceOp seq =
134 symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
135 if (!seq)
136 return emitOpError()
137 << "'" << getSequence()
138 << "' does not reference a valid 'rtg.sequence' operation";
139
140 if (seq.getSequenceType() != getType())
141 return emitOpError("referenced 'rtg.sequence' op's type does not match");
142
143 return success();
144}
145
146//===----------------------------------------------------------------------===//
147// SubstituteSequenceOp
148//===----------------------------------------------------------------------===//
149
150LogicalResult SubstituteSequenceOp::verify() {
151 if (getReplacements().empty())
152 return emitOpError("must at least have one replacement value");
153
154 if (getReplacements().size() >
155 getSequence().getType().getElementTypes().size())
156 return emitOpError(
157 "must not have more replacement values than sequence arguments");
158
159 if (getReplacements().getTypes() !=
160 getSequence().getType().getElementTypes().take_front(
161 getReplacements().size()))
162 return emitOpError("replacement types must match the same number of "
163 "sequence argument types from the front");
164
165 return success();
166}
167
168LogicalResult SubstituteSequenceOp::inferReturnTypes(
169 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
170 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
171 SmallVectorImpl<Type> &inferredReturnTypes) {
172 ArrayRef<Type> argTypes =
173 cast<SequenceType>(operands[0].getType()).getElementTypes();
174 auto seqType =
175 SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
176 inferredReturnTypes.push_back(seqType);
177 return success();
178}
179
180ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
181 ::mlir::OperationState &result) {
182 OpAsmParser::UnresolvedOperand sequenceRawOperand;
183 SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
184 Type sequenceRawType;
185
186 if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
187 return failure();
188
189 auto replacementsOperandsLoc = parser.getCurrentLocation();
190 if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
191 parser.parseColon() || parser.parseType(sequenceRawType) ||
192 parser.parseOptionalAttrDict(result.attributes))
193 return failure();
194
195 if (!isa<SequenceType>(sequenceRawType))
196 return parser.emitError(parser.getNameLoc())
197 << "'sequence' must be handle to a sequence or sequence family, but "
198 "got "
199 << sequenceRawType;
200
201 if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
202 result.operands))
203 return failure();
204
205 if (parser.resolveOperands(replacementsOperands,
206 cast<SequenceType>(sequenceRawType)
207 .getElementTypes()
208 .take_front(replacementsOperands.size()),
209 replacementsOperandsLoc, result.operands))
210 return failure();
211
212 SmallVector<Type> inferredReturnTypes;
213 if (failed(inferReturnTypes(
214 parser.getContext(), result.location, result.operands,
215 result.attributes.getDictionary(parser.getContext()),
216 result.getRawProperties(), result.regions, inferredReturnTypes)))
217 return failure();
218
219 result.addTypes(inferredReturnTypes);
220 return success();
221}
222
223void SubstituteSequenceOp::print(OpAsmPrinter &p) {
224 p << ' ' << getSequence() << "(" << getReplacements()
225 << ") : " << getSequence().getType();
226 p.printOptionalAttrDict((*this)->getAttrs(), {});
227}
228
229//===----------------------------------------------------------------------===//
230// InterleaveSequencesOp
231//===----------------------------------------------------------------------===//
232
233LogicalResult InterleaveSequencesOp::verify() {
234 if (getSequences().empty())
235 return emitOpError("must have at least one sequence in the list");
236
237 return success();
238}
239
240OpFoldResult InterleaveSequencesOp::fold(FoldAdaptor adaptor) {
241 if (getSequences().size() == 1)
242 return getSequences()[0];
243
244 return {};
245}
246
247//===----------------------------------------------------------------------===//
248// SetCreateOp
249//===----------------------------------------------------------------------===//
250
251ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) {
252 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
253 Type elemType;
254
255 if (parser.parseOperandList(operands) ||
256 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
257 parser.parseType(elemType))
258 return failure();
259
260 result.addTypes({SetType::get(result.getContext(), elemType)});
261
262 for (auto operand : operands)
263 if (parser.resolveOperand(operand, elemType, result.operands))
264 return failure();
265
266 return success();
267}
268
269void SetCreateOp::print(OpAsmPrinter &p) {
270 p << " ";
271 p.printOperands(getElements());
272 p.printOptionalAttrDict((*this)->getAttrs());
273 p << " : " << getSet().getType().getElementType();
274}
275
276LogicalResult SetCreateOp::verify() {
277 if (getElements().size() > 0) {
278 // We only need to check the first element because of the `SameTypeOperands`
279 // trait.
280 if (getElements()[0].getType() != getSet().getType().getElementType())
281 return emitOpError() << "operand types must match set element type";
282 }
283
284 return success();
285}
286
287//===----------------------------------------------------------------------===//
288// SetCartesianProductOp
289//===----------------------------------------------------------------------===//
290
291LogicalResult SetCartesianProductOp::inferReturnTypes(
292 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
293 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
294 SmallVectorImpl<Type> &inferredReturnTypes) {
295 if (operands.empty()) {
296 if (loc)
297 return mlir::emitError(*loc) << "at least one set must be provided";
298 return failure();
299 }
300
301 SmallVector<Type> elementTypes;
302 for (auto operand : operands)
303 elementTypes.push_back(cast<SetType>(operand.getType()).getElementType());
304 inferredReturnTypes.push_back(
305 SetType::get(rtg::TupleType::get(context, elementTypes)));
306 return success();
307}
308
309//===----------------------------------------------------------------------===//
310// BagCreateOp
311//===----------------------------------------------------------------------===//
312
313ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
314 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
315 multipleOperands;
316 Type elemType;
317
318 if (!parser.parseOptionalLParen()) {
319 while (true) {
320 OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
321 if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
322 parser.parseOperand(elementOperand))
323 return failure();
324
325 elementOperands.push_back(elementOperand);
326 multipleOperands.push_back(multipleOperand);
327
328 if (parser.parseOptionalComma()) {
329 if (parser.parseRParen())
330 return failure();
331 break;
332 }
333 }
334 }
335
336 if (parser.parseColon() || parser.parseType(elemType) ||
337 parser.parseOptionalAttrDict(result.attributes))
338 return failure();
339
340 result.addTypes({BagType::get(result.getContext(), elemType)});
341
342 for (auto operand : elementOperands)
343 if (parser.resolveOperand(operand, elemType, result.operands))
344 return failure();
345
346 for (auto operand : multipleOperands)
347 if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
348 result.operands))
349 return failure();
350
351 return success();
352}
353
354void BagCreateOp::print(OpAsmPrinter &p) {
355 p << " ";
356 if (!getElements().empty())
357 p << "(";
358 llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
359 [&](auto elAndMultiple) {
360 auto [el, multiple] = elAndMultiple;
361 p << multiple << " x " << el;
362 });
363 if (!getElements().empty())
364 p << ")";
365
366 p << " : " << getBag().getType().getElementType();
367 p.printOptionalAttrDict((*this)->getAttrs());
368}
369
370LogicalResult BagCreateOp::verify() {
371 if (!llvm::all_equal(getElements().getTypes()))
372 return emitOpError() << "types of all elements must match";
373
374 if (getElements().size() > 0)
375 if (getElements()[0].getType() != getBag().getType().getElementType())
376 return emitOpError() << "operand types must match bag element type";
377
378 return success();
379}
380
381//===----------------------------------------------------------------------===//
382// TupleCreateOp
383//===----------------------------------------------------------------------===//
384
385LogicalResult TupleCreateOp::inferReturnTypes(
386 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
387 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
388 SmallVectorImpl<Type> &inferredReturnTypes) {
389 SmallVector<Type> elementTypes;
390 for (auto operand : operands)
391 elementTypes.push_back(operand.getType());
392 inferredReturnTypes.push_back(rtg::TupleType::get(context, elementTypes));
393 return success();
394}
395
396//===----------------------------------------------------------------------===//
397// TupleExtractOp
398//===----------------------------------------------------------------------===//
399
400LogicalResult TupleExtractOp::inferReturnTypes(
401 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
402 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
403 SmallVectorImpl<Type> &inferredReturnTypes) {
404 assert(operands.size() == 1 && "must have exactly one operand");
405
406 auto tupleTy = dyn_cast<rtg::TupleType>(operands[0].getType());
407 size_t idx = properties.as<Properties *>()->getIndex().getInt();
408 if (!tupleTy) {
409 if (loc)
410 return mlir::emitError(*loc) << "only RTG tuples are supported";
411 return failure();
412 }
413
414 if (tupleTy.getFieldTypes().size() <= idx) {
415 if (loc)
416 return mlir::emitError(*loc)
417 << "index (" << idx
418 << ") must be smaller than number of elements in tuple ("
419 << tupleTy.getFieldTypes().size() << ")";
420 return failure();
421 }
422
423 inferredReturnTypes.push_back(tupleTy.getFieldTypes()[idx]);
424 return success();
425}
426
427//===----------------------------------------------------------------------===//
428// VirtualRegisterOp
429//===----------------------------------------------------------------------===//
430
431LogicalResult VirtualRegisterOp::inferReturnTypes(
432 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
433 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
434 SmallVectorImpl<Type> &inferredReturnTypes) {
435 auto allowedRegs = properties.as<Properties *>()->getAllowedRegs();
436 inferredReturnTypes.push_back(allowedRegs.getType());
437 return success();
438}
439
440//===----------------------------------------------------------------------===//
441// ContextSwitchOp
442//===----------------------------------------------------------------------===//
443
444LogicalResult ContextSwitchOp::verify() {
445 auto elementTypes = getSequence().getType().getElementTypes();
446 if (elementTypes.size() != 3)
447 return emitOpError("sequence type must have exactly 3 element types");
448
449 if (getFrom().getType() != elementTypes[0])
450 return emitOpError(
451 "first sequence element type must match 'from' attribute type");
452
453 if (getTo().getType() != elementTypes[1])
454 return emitOpError(
455 "second sequence element type must match 'to' attribute type");
456
457 auto seqTy = dyn_cast<SequenceType>(elementTypes[2]);
458 if (!seqTy || !seqTy.getElementTypes().empty())
459 return emitOpError(
460 "third sequence element type must be a fully substituted sequence");
461
462 return success();
463}
464
465//===----------------------------------------------------------------------===//
466// TestOp
467//===----------------------------------------------------------------------===//
468
469LogicalResult TestOp::verifyRegions() {
470 if (!getTargetType().entryTypesMatch(getBody()->getArgumentTypes()))
471 return emitOpError("argument types must match dict entry types");
472
473 return success();
474}
475
476LogicalResult TestOp::verify() {
477 if (getTemplateName().empty())
478 return emitOpError("template name must not be empty");
479
480 return success();
481}
482
483LogicalResult TestOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
484 if (!getTargetAttr())
485 return success();
486
487 auto target =
488 symbolTable.lookupNearestSymbolFrom<TargetOp>(*this, getTargetAttr());
489 if (!target)
490 return emitOpError()
491 << "'" << *getTarget()
492 << "' does not reference a valid 'rtg.target' operation";
493
494 // Check if target is a subtype of test requirements
495 // Since entries are sorted by name, we can do this in a single pass
496 size_t targetIdx = 0;
497 auto targetEntries = target.getTarget().getEntries();
498 for (auto testEntry : getTargetType().getEntries()) {
499 // Find the matching entry in target entries.
500 while (targetIdx < targetEntries.size() &&
501 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
502 targetIdx++;
503
504 // Check if we found a matching entry with the same name and type
505 if (targetIdx >= targetEntries.size() ||
506 targetEntries[targetIdx].name != testEntry.name ||
507 targetEntries[targetIdx].type != testEntry.type) {
508 return emitOpError("referenced 'rtg.target' op's type is invalid: "
509 "missing entry called '")
510 << testEntry.name.getValue() << "' of type " << testEntry.type;
511 }
512 }
513
514 return success();
515}
516
517ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
518 // Parse the name as a symbol.
519 StringAttr symNameAttr;
520 if (parser.parseSymbolName(symNameAttr))
521 return failure();
522
523 result.getOrAddProperties<TestOp::Properties>().sym_name = symNameAttr;
524
525 // Parse the function signature.
526 SmallVector<OpAsmParser::Argument> arguments;
527 SmallVector<StringAttr> names;
528
529 auto parseOneArgument = [&]() -> ParseResult {
530 std::string name;
531 if (parser.parseKeywordOrString(&name) || parser.parseEqual() ||
532 parser.parseArgument(arguments.emplace_back(), /*allowType=*/true,
533 /*allowAttrs=*/true))
534 return failure();
535
536 names.push_back(StringAttr::get(result.getContext(), name));
537 return success();
538 };
539 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
540 parseOneArgument, " in argument list"))
541 return failure();
542
543 SmallVector<Type> argTypes;
544 SmallVector<DictEntry> entries;
545 SmallVector<Location> argLocs;
546 argTypes.reserve(arguments.size());
547 argLocs.reserve(arguments.size());
548 for (auto [name, arg] : llvm::zip(names, arguments)) {
549 argTypes.push_back(arg.type);
550 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
551 entries.push_back({name, arg.type});
552 }
553 auto emitError = [&]() -> InFlightDiagnostic {
554 return parser.emitError(parser.getCurrentLocation());
555 };
556 Type type = DictType::getChecked(emitError, result.getContext(),
557 ArrayRef<DictEntry>(entries));
558 if (!type)
559 return failure();
560 result.getOrAddProperties<TestOp::Properties>().targetType =
561 TypeAttr::get(type);
562
563 std::string templateName;
564 if (!parser.parseOptionalKeyword("template")) {
565 auto loc = parser.getCurrentLocation();
566 if (parser.parseString(&templateName))
567 return failure();
568
569 if (templateName.empty())
570 return parser.emitError(loc, "template name must not be empty");
571 }
572
573 StringAttr templateNameAttr = symNameAttr;
574 if (!templateName.empty())
575 templateNameAttr = StringAttr::get(result.getContext(), templateName);
576
577 StringAttr targetName;
578 if (!parser.parseOptionalKeyword("target"))
579 if (parser.parseSymbolName(targetName))
580 return failure();
581
582 result.getOrAddProperties<TestOp::Properties>().templateName =
583 templateNameAttr;
584 result.getOrAddProperties<TestOp::Properties>().target = targetName;
585
586 auto loc = parser.getCurrentLocation();
587 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
588 return failure();
589 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
590 return parser.emitError(loc)
591 << "'" << result.name.getStringRef() << "' op ";
592 })))
593 return failure();
594
595 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
596 if (parser.parseRegion(*bodyRegionRegion, arguments))
597 return failure();
598
599 if (bodyRegionRegion->empty()) {
600 bodyRegionRegion->emplaceBlock();
601 bodyRegionRegion->addArguments(argTypes, argLocs);
602 }
603 result.addRegion(std::move(bodyRegionRegion));
604
605 return success();
606}
607
608void TestOp::print(OpAsmPrinter &p) {
609 p << ' ';
610 p.printSymbolName(getSymNameAttr().getValue());
611 p << "(";
612 SmallString<32> resultNameStr;
613 llvm::interleaveComma(
614 llvm::zip(getTargetType().getEntries(), getBody()->getArguments()), p,
615 [&](auto entryAndArg) {
616 auto [entry, arg] = entryAndArg;
617 p << entry.name.getValue() << " = ";
618 p.printRegionArgument(arg);
619 });
620 p << ")";
621
622 if (getSymNameAttr() != getTemplateNameAttr())
623 p << " template " << getTemplateNameAttr();
624
625 if (getTargetAttr()) {
626 p << " target ";
627 p.printSymbolName(getTargetAttr().getValue());
628 }
629
630 p.printOptionalAttrDictWithKeyword(
631 (*this)->getAttrs(), {getSymNameAttrName(), getTargetTypeAttrName(),
632 getTargetAttrName(), getTemplateNameAttrName()});
633 p << ' ';
634 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
635}
636
637void TestOp::getAsmBlockArgumentNames(Region &region,
638 OpAsmSetValueNameFn setNameFn) {
639 for (auto [entry, arg] :
640 llvm::zip(getTargetType().getEntries(), region.getArguments()))
641 setNameFn(arg, entry.name.getValue());
642}
643
644//===----------------------------------------------------------------------===//
645// TargetOp
646//===----------------------------------------------------------------------===//
647
648LogicalResult TargetOp::verifyRegions() {
649 if (!getTarget().entryTypesMatch(
650 getBody()->getTerminator()->getOperandTypes()))
651 return emitOpError("terminator operand types must match dict entry types");
652
653 return success();
654}
655
656//===----------------------------------------------------------------------===//
657// ValidateOp
658//===----------------------------------------------------------------------===//
659
660LogicalResult ValidateOp::verify() {
661 if (!getRef().getType().isValidContentType(getValue().getType()))
662 return emitOpError(
663 "result type must be a valid content type for the ref value");
664
665 return success();
666}
667
668//===----------------------------------------------------------------------===//
669// ArrayCreateOp
670//===----------------------------------------------------------------------===//
671
672LogicalResult ArrayCreateOp::verify() {
673 if (!getElements().empty() &&
674 getElements()[0].getType() != getType().getElementType())
675 return emitOpError("operand types must match array element type, expected ")
676 << getType().getElementType() << " but got "
677 << getElements()[0].getType();
678
679 return success();
680}
681
682ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
683 SmallVector<OpAsmParser::UnresolvedOperand> operands;
684 Type elementType;
685
686 if (parser.parseOperandList(operands) || parser.parseColon() ||
687 parser.parseType(elementType) ||
688 parser.parseOptionalAttrDict(result.attributes))
689 return failure();
690
691 if (failed(parser.resolveOperands(operands, elementType, result.operands)))
692 return failure();
693
694 result.addTypes(ArrayType::get(elementType));
695
696 return success();
697}
698
699void ArrayCreateOp::print(OpAsmPrinter &p) {
700 p << ' ';
701 p.printOperands(getElements());
702 p << " : " << getType().getElementType();
703 p.printOptionalAttrDict((*this)->getAttrs(), {});
704}
705
706//===----------------------------------------------------------------------===//
707// MemoryBlockDeclareOp
708//===----------------------------------------------------------------------===//
709
710LogicalResult MemoryBlockDeclareOp::verify() {
711 if (getBaseAddress().getBitWidth() != getType().getAddressWidth())
712 return emitOpError(
713 "base address width must match memory block address width");
714
715 if (getEndAddress().getBitWidth() != getType().getAddressWidth())
716 return emitOpError(
717 "end address width must match memory block address width");
718
719 if (getBaseAddress().ugt(getEndAddress()))
720 return emitOpError(
721 "base address must be smaller than or equal to the end address");
722
723 return success();
724}
725
726ParseResult MemoryBlockDeclareOp::parse(OpAsmParser &parser,
727 OperationState &result) {
728 SmallVector<OpAsmParser::UnresolvedOperand> operands;
729 MemoryBlockType memoryBlockType;
730 APInt start, end;
731
732 if (parser.parseLSquare())
733 return failure();
734
735 auto startLoc = parser.getCurrentLocation();
736 if (parser.parseInteger(start))
737 return failure();
738
739 if (parser.parseMinus())
740 return failure();
741
742 auto endLoc = parser.getCurrentLocation();
743 if (parser.parseInteger(end) || parser.parseRSquare() ||
744 parser.parseColonType(memoryBlockType) ||
745 parser.parseOptionalAttrDict(result.attributes))
746 return failure();
747
748 auto width = memoryBlockType.getAddressWidth();
749 auto adjustAPInt = [&](APInt value, llvm::SMLoc loc) -> FailureOr<APInt> {
750 if (value.getBitWidth() > width) {
751 if (!value.isIntN(width))
752 return parser.emitError(
753 loc,
754 "address out of range for memory block with address width ")
755 << width;
756
757 return value.trunc(width);
758 }
759
760 if (value.getBitWidth() < width)
761 return value.zext(width);
762
763 return value;
764 };
765
766 auto startRes = adjustAPInt(start, startLoc);
767 auto endRes = adjustAPInt(end, endLoc);
768 if (failed(startRes) || failed(endRes))
769 return failure();
770
771 auto intType = IntegerType::get(result.getContext(), width);
772 result.addAttribute(getBaseAddressAttrName(result.name),
773 IntegerAttr::get(intType, *startRes));
774 result.addAttribute(getEndAddressAttrName(result.name),
775 IntegerAttr::get(intType, *endRes));
776
777 result.addTypes(memoryBlockType);
778 return success();
779}
780
781void MemoryBlockDeclareOp::print(OpAsmPrinter &p) {
782 SmallVector<char> str;
783 getBaseAddress().toString(str, 16, false, false, false);
784 p << " [0x" << str;
785 p << " - 0x";
786 str.clear();
787 getEndAddress().toString(str, 16, false, false, false);
788 p << str << "] : " << getType();
789 p.printOptionalAttrDict((*this)->getAttrs(),
790 {getBaseAddressAttrName(), getEndAddressAttrName()});
791}
792
793//===----------------------------------------------------------------------===//
794// MemoryBaseAddressOp
795//===----------------------------------------------------------------------===//
796
797LogicalResult MemoryBaseAddressOp::inferReturnTypes(
798 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
799 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
800 SmallVectorImpl<Type> &inferredReturnTypes) {
801 if (operands.empty())
802 return failure();
803 auto memTy = dyn_cast<MemoryType>(operands[0].getType());
804 if (!memTy)
805 return failure();
806 inferredReturnTypes.push_back(
807 ImmediateType::get(context, memTy.getAddressWidth()));
808 return success();
809}
810
811//===----------------------------------------------------------------------===//
812// ConcatImmediateOp
813//===----------------------------------------------------------------------===//
814
815LogicalResult ConcatImmediateOp::inferReturnTypes(
816 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
817 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
818 SmallVectorImpl<Type> &inferredReturnTypes) {
819 if (operands.empty()) {
820 if (loc)
821 return mlir::emitError(*loc) << "at least one operand must be provided";
822 return failure();
823 }
824
825 unsigned totalWidth = 0;
826 for (auto operand : operands) {
827 auto immType = dyn_cast<ImmediateType>(operand.getType());
828 if (!immType) {
829 if (loc)
830 return mlir::emitError(*loc)
831 << "all operands must be of immediate type";
832 return failure();
833 }
834 totalWidth += immType.getWidth();
835 }
836
837 inferredReturnTypes.push_back(ImmediateType::get(context, totalWidth));
838 return success();
839}
840
841OpFoldResult ConcatImmediateOp::fold(FoldAdaptor adaptor) {
842 // concat(x) -> x
843 if (getOperands().size() == 1)
844 return getOperands()[0];
845
846 // If all operands are constants, fold into a single constant
847 if (llvm::all_of(adaptor.getOperands(), [](Attribute attr) {
848 return isa_and_nonnull<ImmediateAttr>(attr);
849 })) {
850 auto result = APInt::getZeroWidth();
851 for (auto attr : adaptor.getOperands())
852 result = result.concat(cast<ImmediateAttr>(attr).getValue());
853
854 return ImmediateAttr::get(getContext(), result);
855 }
856
857 return {};
858}
859
860//===----------------------------------------------------------------------===//
861// SliceImmediateOp
862//===----------------------------------------------------------------------===//
863
864LogicalResult SliceImmediateOp::verify() {
865 auto srcWidth = getInput().getType().getWidth();
866 auto dstWidth = getResult().getType().getWidth();
867
868 if (getLowBit() >= srcWidth)
869 return emitOpError("from bit too large for input (got ")
870 << getLowBit() << ", but input width is " << srcWidth << ")";
871
872 if (srcWidth - getLowBit() < dstWidth)
873 return emitOpError("slice does not fit in input (trying to extract ")
874 << dstWidth << " bits starting at index " << getLowBit()
875 << ", but only " << (srcWidth - getLowBit())
876 << " bits are available)";
877
878 return success();
879}
880
881OpFoldResult SliceImmediateOp::fold(FoldAdaptor adaptor) {
882 if (auto inputAttr = dyn_cast_or_null<ImmediateAttr>(adaptor.getInput())) {
883 auto resultWidth = getType().getWidth();
884 APInt sliced = inputAttr.getValue().extractBits(resultWidth, getLowBit());
885 return ImmediateAttr::get(getContext(), sliced);
886 }
887
888 return {};
889}
890
891//===----------------------------------------------------------------------===//
892// LabelUniqueDeclOp and LabelDeclOp
893//===----------------------------------------------------------------------===//
894
895static StringAttr substituteFormatString(StringAttr formatString,
896 ArrayRef<Attribute> substitutes) {
897 if (substitutes.empty() || formatString.empty())
898 return formatString;
899
900 auto original = formatString.getValue().str();
901 size_t curr = 0;
902 for (auto [i, subst] : llvm::enumerate(substitutes)) {
903 auto substInt = dyn_cast_or_null<IntegerAttr>(subst);
904 std::string substString;
905 if (!substInt && curr == i) {
906 ++curr;
907 continue;
908 }
909 if (!substInt)
910 substString = "{{" + std::to_string(curr++) + "}}";
911 else
912 substString = std::to_string(substInt.getValue().getZExtValue());
913
914 size_t startPos = 0;
915 std::string from = "{{" + std::to_string(i) + "}}";
916 while ((startPos = original.find(from, startPos)) != std::string::npos) {
917 original.replace(startPos, from.length(), substString);
918 }
919 }
920
921 return StringAttr::get(formatString.getContext(), original);
922}
923
924template <typename OpTy>
925OpFoldResult labelDeclFolder(OpTy op, typename OpTy::FoldAdaptor adaptor) {
926 auto newFormatString =
927 substituteFormatString(op.getFormatStringAttr(), adaptor.getArgs());
928 if (newFormatString == op.getFormatStringAttr())
929 return {};
930
931 op.setFormatStringAttr(newFormatString);
932
933 SmallVector<Value> newArgs;
934 for (auto [arg, attr] : llvm::zip(op.getArgs(), adaptor.getArgs())) {
935 if (!attr)
936 newArgs.push_back(arg);
937 }
938 op.getArgsMutable().assign(newArgs);
939
940 return op.getLabel();
941}
942
943OpFoldResult LabelUniqueDeclOp::fold(FoldAdaptor adaptor) {
944 return labelDeclFolder(*this, adaptor);
945}
946
947OpFoldResult LabelDeclOp::fold(FoldAdaptor adaptor) {
948 return labelDeclFolder(*this, adaptor);
949}
950
951//===----------------------------------------------------------------------===//
952// TableGen generated logic.
953//===----------------------------------------------------------------------===//
954
955#define GET_OP_CLASSES
956#include "circt/Dialect/RTG/IR/RTG.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static size_t getAddressWidth(size_t depth)
static StringAttr substituteFormatString(StringAttr formatString, ArrayRef< Attribute > substitutes)
Definition RTGOps.cpp:895
OpFoldResult labelDeclFolder(OpTy op, typename OpTy::FoldAdaptor adaptor)
Definition RTGOps.cpp:925
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition rtg.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21