CIRCT 23.0.0git
Loading...
Searching...
No Matches
RTGOps.cpp
Go to the documentation of this file.
1//===- RTGOps.cpp - Implement the RTG operations --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the RTG ops.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/DialectImplementation.h"
18#include "mlir/IR/Matchers.h"
19#include "mlir/IR/PatternMatch.h"
20#include "llvm/ADT/SmallString.h"
21
22using namespace mlir;
23using namespace circt;
24using namespace rtg;
25
26//===----------------------------------------------------------------------===//
27// ConstantOp
28//===----------------------------------------------------------------------===//
29
30LogicalResult
31ConstantOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
32 ValueRange operands, DictionaryAttr attributes,
33 OpaqueProperties properties, RegionRange regions,
34 SmallVectorImpl<Type> &inferredReturnTypes) {
35 inferredReturnTypes.push_back(
36 properties.as<Properties *>()->getValue().getType());
37 return success();
38}
39
40OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
41
42void ConstantOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
43 if (auto reg = dyn_cast<rtg::RegisterAttrInterface>(getValueAttr())) {
44 setNameFn(getResult(), reg.getRegisterAssembly());
45 return;
46 }
47}
48
49//===----------------------------------------------------------------------===//
50// SequenceOp
51//===----------------------------------------------------------------------===//
52
53LogicalResult SequenceOp::verifyRegions() {
54 if (TypeRange(getSequenceType().getElementTypes()) !=
55 getBody()->getArgumentTypes())
56 return emitOpError("sequence type does not match block argument types");
57
58 return success();
59}
60
61ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
62 // Parse the name as a symbol.
63 if (parser.parseSymbolName(
64 result.getOrAddProperties<SequenceOp::Properties>().sym_name))
65 return failure();
66
67 // Parse the function signature.
68 SmallVector<OpAsmParser::Argument> arguments;
69 if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
70 /*allowType=*/true, /*allowAttrs=*/true))
71 return failure();
72
73 SmallVector<Type> argTypes;
74 SmallVector<Location> argLocs;
75 argTypes.reserve(arguments.size());
76 argLocs.reserve(arguments.size());
77 for (auto &arg : arguments) {
78 argTypes.push_back(arg.type);
79 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
80 }
81 Type type = SequenceType::get(result.getContext(), argTypes);
82 result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
83 TypeAttr::get(type);
84
85 auto loc = parser.getCurrentLocation();
86 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
87 return failure();
88 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
89 return parser.emitError(loc)
90 << "'" << result.name.getStringRef() << "' op ";
91 })))
92 return failure();
93
94 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
95 if (parser.parseRegion(*bodyRegionRegion, arguments))
96 return failure();
97
98 if (bodyRegionRegion->empty()) {
99 bodyRegionRegion->emplaceBlock();
100 bodyRegionRegion->addArguments(argTypes, argLocs);
101 }
102 result.addRegion(std::move(bodyRegionRegion));
103
104 return success();
105}
106
107void SequenceOp::print(OpAsmPrinter &p) {
108 p << ' ';
109 p.printSymbolName(getSymNameAttr().getValue());
110 p << "(";
111 llvm::interleaveComma(getBody()->getArguments(), p,
112 [&](auto arg) { p.printRegionArgument(arg); });
113 p << ")";
114 p.printOptionalAttrDictWithKeyword(
115 (*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
116 p << ' ';
117 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
118}
119
120mlir::SymbolTable::Visibility SequenceOp::getVisibility() {
121 return mlir::SymbolTable::Visibility::Private;
122}
123
124void SequenceOp::setVisibility(mlir::SymbolTable::Visibility visibility) {
125 // Do nothing, always private.
126 assert(false && "cannot change visibility of sequence");
127}
128
129//===----------------------------------------------------------------------===//
130// GetSequenceOp
131//===----------------------------------------------------------------------===//
132
133LogicalResult
134GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
135 SequenceOp seq =
136 symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
137 if (!seq)
138 return emitOpError()
139 << "'" << getSequence()
140 << "' does not reference a valid 'rtg.sequence' operation";
141
142 if (seq.getSequenceType() != getType())
143 return emitOpError("referenced 'rtg.sequence' op's type does not match");
144
145 return success();
146}
147
148//===----------------------------------------------------------------------===//
149// SubstituteSequenceOp
150//===----------------------------------------------------------------------===//
151
152LogicalResult SubstituteSequenceOp::verify() {
153 if (getReplacements().empty())
154 return emitOpError("must at least have one replacement value");
155
156 if (getReplacements().size() >
157 getSequence().getType().getElementTypes().size())
158 return emitOpError(
159 "must not have more replacement values than sequence arguments");
160
161 if (getReplacements().getTypes() !=
162 getSequence().getType().getElementTypes().take_front(
163 getReplacements().size()))
164 return emitOpError("replacement types must match the same number of "
165 "sequence argument types from the front");
166
167 return success();
168}
169
170LogicalResult SubstituteSequenceOp::inferReturnTypes(
171 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
172 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
173 SmallVectorImpl<Type> &inferredReturnTypes) {
174 ArrayRef<Type> argTypes =
175 cast<SequenceType>(operands[0].getType()).getElementTypes();
176 auto seqType =
177 SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
178 inferredReturnTypes.push_back(seqType);
179 return success();
180}
181
182ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
183 ::mlir::OperationState &result) {
184 OpAsmParser::UnresolvedOperand sequenceRawOperand;
185 SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
186 Type sequenceRawType;
187
188 if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
189 return failure();
190
191 auto replacementsOperandsLoc = parser.getCurrentLocation();
192 if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
193 parser.parseColon() || parser.parseType(sequenceRawType) ||
194 parser.parseOptionalAttrDict(result.attributes))
195 return failure();
196
197 if (!isa<SequenceType>(sequenceRawType))
198 return parser.emitError(parser.getNameLoc())
199 << "'sequence' must be handle to a sequence or sequence family, but "
200 "got "
201 << sequenceRawType;
202
203 if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
204 result.operands))
205 return failure();
206
207 if (parser.resolveOperands(replacementsOperands,
208 cast<SequenceType>(sequenceRawType)
209 .getElementTypes()
210 .take_front(replacementsOperands.size()),
211 replacementsOperandsLoc, result.operands))
212 return failure();
213
214 SmallVector<Type> inferredReturnTypes;
215 if (failed(inferReturnTypes(
216 parser.getContext(), result.location, result.operands,
217 result.attributes.getDictionary(parser.getContext()),
218 result.getRawProperties(), result.regions, inferredReturnTypes)))
219 return failure();
220
221 result.addTypes(inferredReturnTypes);
222 return success();
223}
224
225void SubstituteSequenceOp::print(OpAsmPrinter &p) {
226 p << ' ' << getSequence() << "(" << getReplacements()
227 << ") : " << getSequence().getType();
228 p.printOptionalAttrDict((*this)->getAttrs(), {});
229}
230
231//===----------------------------------------------------------------------===//
232// InterleaveSequencesOp
233//===----------------------------------------------------------------------===//
234
235LogicalResult InterleaveSequencesOp::verify() {
236 if (getSequences().empty())
237 return emitOpError("must have at least one sequence in the list");
238
239 return success();
240}
241
242OpFoldResult InterleaveSequencesOp::fold(FoldAdaptor adaptor) {
243 if (getSequences().size() == 1)
244 return getSequences()[0];
245
246 return {};
247}
248
249//===----------------------------------------------------------------------===//
250// SetCreateOp
251//===----------------------------------------------------------------------===//
252
253ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) {
254 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
255 Type elemType;
256
257 if (parser.parseOperandList(operands) ||
258 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
259 parser.parseType(elemType))
260 return failure();
261
262 result.addTypes({SetType::get(result.getContext(), elemType)});
263
264 for (auto operand : operands)
265 if (parser.resolveOperand(operand, elemType, result.operands))
266 return failure();
267
268 return success();
269}
270
271void SetCreateOp::print(OpAsmPrinter &p) {
272 p << " ";
273 p.printOperands(getElements());
274 p.printOptionalAttrDict((*this)->getAttrs());
275 p << " : " << getSet().getType().getElementType();
276}
277
278LogicalResult SetCreateOp::verify() {
279 if (getElements().size() > 0) {
280 // We only need to check the first element because of the `SameTypeOperands`
281 // trait.
282 if (getElements()[0].getType() != getSet().getType().getElementType())
283 return emitOpError() << "operand types must match set element type";
284 }
285
286 return success();
287}
288
289//===----------------------------------------------------------------------===//
290// SetCartesianProductOp
291//===----------------------------------------------------------------------===//
292
293LogicalResult SetCartesianProductOp::inferReturnTypes(
294 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
295 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
296 SmallVectorImpl<Type> &inferredReturnTypes) {
297 if (operands.empty()) {
298 if (loc)
299 return mlir::emitError(*loc) << "at least one set must be provided";
300 return failure();
301 }
302
303 SmallVector<Type> elementTypes;
304 for (auto operand : operands)
305 elementTypes.push_back(cast<SetType>(operand.getType()).getElementType());
306 inferredReturnTypes.push_back(
307 SetType::get(rtg::TupleType::get(context, elementTypes)));
308 return success();
309}
310
311//===----------------------------------------------------------------------===//
312// BagCreateOp
313//===----------------------------------------------------------------------===//
314
315ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
316 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
317 multipleOperands;
318 Type elemType;
319
320 if (!parser.parseOptionalLParen()) {
321 while (true) {
322 OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
323 if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
324 parser.parseOperand(elementOperand))
325 return failure();
326
327 elementOperands.push_back(elementOperand);
328 multipleOperands.push_back(multipleOperand);
329
330 if (parser.parseOptionalComma()) {
331 if (parser.parseRParen())
332 return failure();
333 break;
334 }
335 }
336 }
337
338 if (parser.parseColon() || parser.parseType(elemType) ||
339 parser.parseOptionalAttrDict(result.attributes))
340 return failure();
341
342 result.addTypes({BagType::get(result.getContext(), elemType)});
343
344 for (auto operand : elementOperands)
345 if (parser.resolveOperand(operand, elemType, result.operands))
346 return failure();
347
348 for (auto operand : multipleOperands)
349 if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
350 result.operands))
351 return failure();
352
353 return success();
354}
355
356void BagCreateOp::print(OpAsmPrinter &p) {
357 p << " ";
358 if (!getElements().empty())
359 p << "(";
360 llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
361 [&](auto elAndMultiple) {
362 auto [el, multiple] = elAndMultiple;
363 p << multiple << " x " << el;
364 });
365 if (!getElements().empty())
366 p << ")";
367
368 p << " : " << getBag().getType().getElementType();
369 p.printOptionalAttrDict((*this)->getAttrs());
370}
371
372LogicalResult BagCreateOp::verify() {
373 if (!llvm::all_equal(getElements().getTypes()))
374 return emitOpError() << "types of all elements must match";
375
376 if (getElements().size() > 0)
377 if (getElements()[0].getType() != getBag().getType().getElementType())
378 return emitOpError() << "operand types must match bag element type";
379
380 return success();
381}
382
383//===----------------------------------------------------------------------===//
384// TupleCreateOp
385//===----------------------------------------------------------------------===//
386
387LogicalResult TupleCreateOp::inferReturnTypes(
388 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
389 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
390 SmallVectorImpl<Type> &inferredReturnTypes) {
391 SmallVector<Type> elementTypes;
392 for (auto operand : operands)
393 elementTypes.push_back(operand.getType());
394 inferredReturnTypes.push_back(rtg::TupleType::get(context, elementTypes));
395 return success();
396}
397
398//===----------------------------------------------------------------------===//
399// TupleExtractOp
400//===----------------------------------------------------------------------===//
401
402LogicalResult TupleExtractOp::inferReturnTypes(
403 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
404 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
405 SmallVectorImpl<Type> &inferredReturnTypes) {
406 assert(operands.size() == 1 && "must have exactly one operand");
407
408 auto tupleTy = dyn_cast<rtg::TupleType>(operands[0].getType());
409 size_t idx = properties.as<Properties *>()->getIndex().getInt();
410 if (!tupleTy) {
411 if (loc)
412 return mlir::emitError(*loc) << "only RTG tuples are supported";
413 return failure();
414 }
415
416 if (tupleTy.getFieldTypes().size() <= idx) {
417 if (loc)
418 return mlir::emitError(*loc)
419 << "index (" << idx
420 << ") must be smaller than number of elements in tuple ("
421 << tupleTy.getFieldTypes().size() << ")";
422 return failure();
423 }
424
425 inferredReturnTypes.push_back(tupleTy.getFieldTypes()[idx]);
426 return success();
427}
428
429//===----------------------------------------------------------------------===//
430// ConstraintOp
431//===----------------------------------------------------------------------===//
432
433LogicalResult ConstraintOp::canonicalize(ConstraintOp op,
434 PatternRewriter &rewriter) {
435 if (mlir::matchPattern(op.getCondition(), mlir::m_One())) {
436 rewriter.eraseOp(op);
437 return success();
438 }
439
440 return failure();
441}
442
443//===----------------------------------------------------------------------===//
444// VirtualRegisterOp
445//===----------------------------------------------------------------------===//
446
447LogicalResult VirtualRegisterOp::inferReturnTypes(
448 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
449 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
450 SmallVectorImpl<Type> &inferredReturnTypes) {
451 auto allowedRegs = properties.as<Properties *>()->getAllowedRegs();
452 inferredReturnTypes.push_back(allowedRegs.getType());
453 return success();
454}
455
456//===----------------------------------------------------------------------===//
457// RegisterToIndexOp
458//===----------------------------------------------------------------------===//
459
460OpFoldResult RegisterToIndexOp::fold(FoldAdaptor adaptor) {
461 if (auto reg = dyn_cast_or_null<rtg::RegisterAttrInterface>(adaptor.getReg()))
462 return IntegerAttr::get(IndexType::get(getContext()), reg.getClassIndex());
463
464 if (auto indexToRegOp = getReg().getDefiningOp<IndexToRegisterOp>())
465 return indexToRegOp.getIndex();
466
467 return {};
468}
469
470//===----------------------------------------------------------------------===//
471// IndexToRegisterOp
472//===----------------------------------------------------------------------===//
473
474LogicalResult IndexToRegisterOp::verify() {
475 // Check if the index is a constant and if it's within valid range
476 APInt indexValue;
477 if (matchPattern(getIndex(), m_ConstantInt(&indexValue))) {
478 if (indexValue.uge(getType().getRegisterClassSize())) {
479 SmallString<16> indexStr;
480 indexValue.toString(indexStr, 10, false);
481 return emitOpError() << "index " << indexStr
482 << " is out of range for register class "
483 << getReg().getType();
484 }
485 }
486
487 return success();
488}
489
490OpFoldResult IndexToRegisterOp::fold(FoldAdaptor adaptor) {
491 if (auto indexAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getIndex()))
492 return getType().getRegisterAttrForClassIndex(
493 getContext(), indexAttr.getValue().getZExtValue());
494
495 return {};
496}
497
498//===----------------------------------------------------------------------===//
499// ContextSwitchOp
500//===----------------------------------------------------------------------===//
501
502LogicalResult ContextSwitchOp::verify() {
503 auto elementTypes = getSequence().getType().getElementTypes();
504 if (elementTypes.size() != 3)
505 return emitOpError("sequence type must have exactly 3 element types");
506
507 if (getFrom().getType() != elementTypes[0])
508 return emitOpError(
509 "first sequence element type must match 'from' attribute type");
510
511 if (getTo().getType() != elementTypes[1])
512 return emitOpError(
513 "second sequence element type must match 'to' attribute type");
514
515 auto seqTy = dyn_cast<SequenceType>(elementTypes[2]);
516 if (!seqTy || !seqTy.getElementTypes().empty())
517 return emitOpError(
518 "third sequence element type must be a fully substituted sequence");
519
520 return success();
521}
522
523//===----------------------------------------------------------------------===//
524// TestOp
525//===----------------------------------------------------------------------===//
526
527LogicalResult TestOp::verifyRegions() {
528 if (!getTargetType().entryTypesMatch(getBody()->getArgumentTypes()))
529 return emitOpError("argument types must match dict entry types");
530
531 return success();
532}
533
534LogicalResult TestOp::verify() {
535 if (getTemplateName().empty())
536 return emitOpError("template name must not be empty");
537
538 return success();
539}
540
541LogicalResult TestOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
542 if (!getTargetAttr())
543 return success();
544
545 auto target =
546 symbolTable.lookupNearestSymbolFrom<TargetOp>(*this, getTargetAttr());
547 if (!target)
548 return emitOpError()
549 << "'" << *getTarget()
550 << "' does not reference a valid 'rtg.target' operation";
551
552 // Check if target is a subtype of test requirements
553 // Since entries are sorted by name, we can do this in a single pass
554 size_t targetIdx = 0;
555 auto targetEntries = target.getTarget().getEntries();
556 for (auto testEntry : getTargetType().getEntries()) {
557 // Find the matching entry in target entries.
558 while (targetIdx < targetEntries.size() &&
559 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
560 targetIdx++;
561
562 // Check if we found a matching entry with the same name and type
563 if (targetIdx >= targetEntries.size() ||
564 targetEntries[targetIdx].name != testEntry.name ||
565 targetEntries[targetIdx].type != testEntry.type) {
566 return emitOpError("referenced 'rtg.target' op's type is invalid: "
567 "missing entry called '")
568 << testEntry.name.getValue() << "' of type " << testEntry.type;
569 }
570 }
571
572 return success();
573}
574
575ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
576 // Parse the name as a symbol.
577 StringAttr symNameAttr;
578 if (parser.parseSymbolName(symNameAttr))
579 return failure();
580
581 result.getOrAddProperties<TestOp::Properties>().sym_name = symNameAttr;
582
583 // Parse the function signature.
584 SmallVector<OpAsmParser::Argument> arguments;
585 SmallVector<StringAttr> names;
586
587 auto parseOneArgument = [&]() -> ParseResult {
588 std::string name;
589 if (parser.parseKeywordOrString(&name) || parser.parseEqual() ||
590 parser.parseArgument(arguments.emplace_back(), /*allowType=*/true,
591 /*allowAttrs=*/true))
592 return failure();
593
594 names.push_back(StringAttr::get(result.getContext(), name));
595 return success();
596 };
597 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
598 parseOneArgument, " in argument list"))
599 return failure();
600
601 SmallVector<Type> argTypes;
602 SmallVector<DictEntry> entries;
603 SmallVector<Location> argLocs;
604 argTypes.reserve(arguments.size());
605 argLocs.reserve(arguments.size());
606 for (auto [name, arg] : llvm::zip(names, arguments)) {
607 argTypes.push_back(arg.type);
608 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
609 entries.push_back({name, arg.type});
610 }
611 auto emitError = [&]() -> InFlightDiagnostic {
612 return parser.emitError(parser.getCurrentLocation());
613 };
614 Type type = DictType::getChecked(emitError, result.getContext(),
615 ArrayRef<DictEntry>(entries));
616 if (!type)
617 return failure();
618 result.getOrAddProperties<TestOp::Properties>().targetType =
619 TypeAttr::get(type);
620
621 std::string templateName;
622 if (!parser.parseOptionalKeyword("template")) {
623 auto loc = parser.getCurrentLocation();
624 if (parser.parseString(&templateName))
625 return failure();
626
627 if (templateName.empty())
628 return parser.emitError(loc, "template name must not be empty");
629 }
630
631 StringAttr templateNameAttr = symNameAttr;
632 if (!templateName.empty())
633 templateNameAttr = StringAttr::get(result.getContext(), templateName);
634
635 StringAttr targetName;
636 if (!parser.parseOptionalKeyword("target"))
637 if (parser.parseSymbolName(targetName))
638 return failure();
639
640 result.getOrAddProperties<TestOp::Properties>().templateName =
641 templateNameAttr;
642 result.getOrAddProperties<TestOp::Properties>().target = targetName;
643
644 auto loc = parser.getCurrentLocation();
645 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
646 return failure();
647 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
648 return parser.emitError(loc)
649 << "'" << result.name.getStringRef() << "' op ";
650 })))
651 return failure();
652
653 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
654 if (parser.parseRegion(*bodyRegionRegion, arguments))
655 return failure();
656
657 if (bodyRegionRegion->empty()) {
658 bodyRegionRegion->emplaceBlock();
659 bodyRegionRegion->addArguments(argTypes, argLocs);
660 }
661 result.addRegion(std::move(bodyRegionRegion));
662
663 return success();
664}
665
666void TestOp::print(OpAsmPrinter &p) {
667 p << ' ';
668 p.printSymbolName(getSymNameAttr().getValue());
669 p << "(";
670 SmallString<32> resultNameStr;
671 llvm::interleaveComma(
672 llvm::zip(getTargetType().getEntries(), getBody()->getArguments()), p,
673 [&](auto entryAndArg) {
674 auto [entry, arg] = entryAndArg;
675 p << entry.name.getValue() << " = ";
676 p.printRegionArgument(arg);
677 });
678 p << ")";
679
680 if (getSymNameAttr() != getTemplateNameAttr())
681 p << " template " << getTemplateNameAttr();
682
683 if (getTargetAttr()) {
684 p << " target ";
685 p.printSymbolName(getTargetAttr().getValue());
686 }
687
688 p.printOptionalAttrDictWithKeyword(
689 (*this)->getAttrs(), {getSymNameAttrName(), getTargetTypeAttrName(),
690 getTargetAttrName(), getTemplateNameAttrName()});
691 p << ' ';
692 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
693}
694
695void TestOp::getAsmBlockArgumentNames(Region &region,
696 OpAsmSetValueNameFn setNameFn) {
697 for (auto [entry, arg] :
698 llvm::zip(getTargetType().getEntries(), region.getArguments()))
699 setNameFn(arg, entry.name.getValue());
700}
701
702//===----------------------------------------------------------------------===//
703// TargetOp
704//===----------------------------------------------------------------------===//
705
706LogicalResult TargetOp::verifyRegions() {
707 if (!getTarget().entryTypesMatch(
708 getBody()->getTerminator()->getOperandTypes()))
709 return emitOpError("terminator operand types must match dict entry types");
710
711 return success();
712}
713
714//===----------------------------------------------------------------------===//
715// ValidateOp
716//===----------------------------------------------------------------------===//
717
718LogicalResult ValidateOp::verify() {
719 if (!getRef().getType().isValidContentType(getValue().getType()))
720 return emitOpError(
721 "result type must be a valid content type for the ref value");
722
723 return success();
724}
725
726bool ValidateOp::isSourceRegister(unsigned index) {
727 if (index == 0)
728 return isa<RegisterTypeInterface>(getRef().getType());
729 return false;
730}
731
732bool ValidateOp::isDestinationRegister(unsigned index) { return false; }
733
734//===----------------------------------------------------------------------===//
735// ArrayCreateOp
736//===----------------------------------------------------------------------===//
737
738LogicalResult ArrayCreateOp::verify() {
739 if (!getElements().empty() &&
740 getElements()[0].getType() != getType().getElementType())
741 return emitOpError("operand types must match array element type, expected ")
742 << getType().getElementType() << " but got "
743 << getElements()[0].getType();
744
745 return success();
746}
747
748ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
749 SmallVector<OpAsmParser::UnresolvedOperand> operands;
750 Type elementType;
751
752 if (parser.parseOperandList(operands) || parser.parseColon() ||
753 parser.parseType(elementType) ||
754 parser.parseOptionalAttrDict(result.attributes))
755 return failure();
756
757 if (failed(parser.resolveOperands(operands, elementType, result.operands)))
758 return failure();
759
760 result.addTypes(ArrayType::get(elementType));
761
762 return success();
763}
764
765void ArrayCreateOp::print(OpAsmPrinter &p) {
766 p << ' ';
767 p.printOperands(getElements());
768 p << " : " << getType().getElementType();
769 p.printOptionalAttrDict((*this)->getAttrs(), {});
770}
771
772//===----------------------------------------------------------------------===//
773// ArrayAppendOp
774//===----------------------------------------------------------------------===//
775
776LogicalResult ArrayAppendOp::canonicalize(ArrayAppendOp op,
777 PatternRewriter &rewriter) {
778 auto createOp = op.getArray().getDefiningOp<ArrayCreateOp>();
779 if (!createOp)
780 return failure();
781
782 SmallVector<Value> newElements(createOp.getElements());
783 newElements.push_back(op.getElement());
784 rewriter.replaceOpWithNewOp<ArrayCreateOp>(op, op.getType(), newElements);
785 return success();
786}
787
788//===----------------------------------------------------------------------===//
789// MemoryBlockDeclareOp
790//===----------------------------------------------------------------------===//
791
792LogicalResult MemoryBlockDeclareOp::verify() {
793 if (getBaseAddress().getBitWidth() != getType().getAddressWidth())
794 return emitOpError(
795 "base address width must match memory block address width");
796
797 if (getEndAddress().getBitWidth() != getType().getAddressWidth())
798 return emitOpError(
799 "end address width must match memory block address width");
800
801 if (getBaseAddress().ugt(getEndAddress()))
802 return emitOpError(
803 "base address must be smaller than or equal to the end address");
804
805 return success();
806}
807
808ParseResult MemoryBlockDeclareOp::parse(OpAsmParser &parser,
809 OperationState &result) {
810 SmallVector<OpAsmParser::UnresolvedOperand> operands;
811 MemoryBlockType memoryBlockType;
812 APInt start, end;
813
814 if (parser.parseLSquare())
815 return failure();
816
817 auto startLoc = parser.getCurrentLocation();
818 if (parser.parseInteger(start))
819 return failure();
820
821 if (parser.parseMinus())
822 return failure();
823
824 auto endLoc = parser.getCurrentLocation();
825 if (parser.parseInteger(end) || parser.parseRSquare() ||
826 parser.parseColonType(memoryBlockType) ||
827 parser.parseOptionalAttrDict(result.attributes))
828 return failure();
829
830 auto width = memoryBlockType.getAddressWidth();
831 auto adjustAPInt = [&](APInt value, llvm::SMLoc loc) -> FailureOr<APInt> {
832 if (value.getBitWidth() > width) {
833 if (!value.isIntN(width))
834 return parser.emitError(
835 loc,
836 "address out of range for memory block with address width ")
837 << width;
838
839 return value.trunc(width);
840 }
841
842 if (value.getBitWidth() < width)
843 return value.zext(width);
844
845 return value;
846 };
847
848 auto startRes = adjustAPInt(start, startLoc);
849 auto endRes = adjustAPInt(end, endLoc);
850 if (failed(startRes) || failed(endRes))
851 return failure();
852
853 auto intType = IntegerType::get(result.getContext(), width);
854 result.addAttribute(getBaseAddressAttrName(result.name),
855 IntegerAttr::get(intType, *startRes));
856 result.addAttribute(getEndAddressAttrName(result.name),
857 IntegerAttr::get(intType, *endRes));
858
859 result.addTypes(memoryBlockType);
860 return success();
861}
862
863void MemoryBlockDeclareOp::print(OpAsmPrinter &p) {
864 SmallVector<char> str;
865 getBaseAddress().toString(str, 16, false, false, false);
866 p << " [0x" << str;
867 p << " - 0x";
868 str.clear();
869 getEndAddress().toString(str, 16, false, false, false);
870 p << str << "] : " << getType();
871 p.printOptionalAttrDict((*this)->getAttrs(),
872 {getBaseAddressAttrName(), getEndAddressAttrName()});
873}
874
875//===----------------------------------------------------------------------===//
876// MemoryBaseAddressOp
877//===----------------------------------------------------------------------===//
878
879LogicalResult MemoryBaseAddressOp::inferReturnTypes(
880 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
881 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
882 SmallVectorImpl<Type> &inferredReturnTypes) {
883 if (operands.empty())
884 return failure();
885 auto memTy = dyn_cast<MemoryType>(operands[0].getType());
886 if (!memTy)
887 return failure();
888 inferredReturnTypes.push_back(
889 ImmediateType::get(context, memTy.getAddressWidth()));
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// ConcatImmediateOp
895//===----------------------------------------------------------------------===//
896
897LogicalResult ConcatImmediateOp::inferReturnTypes(
898 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
899 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
900 SmallVectorImpl<Type> &inferredReturnTypes) {
901 if (operands.empty()) {
902 if (loc)
903 return mlir::emitError(*loc) << "at least one operand must be provided";
904 return failure();
905 }
906
907 unsigned totalWidth = 0;
908 for (auto operand : operands) {
909 auto immType = dyn_cast<ImmediateType>(operand.getType());
910 if (!immType) {
911 if (loc)
912 return mlir::emitError(*loc)
913 << "all operands must be of immediate type";
914 return failure();
915 }
916 totalWidth += immType.getWidth();
917 }
918
919 inferredReturnTypes.push_back(ImmediateType::get(context, totalWidth));
920 return success();
921}
922
923OpFoldResult ConcatImmediateOp::fold(FoldAdaptor adaptor) {
924 // concat(x) -> x
925 if (getOperands().size() == 1)
926 return getOperands()[0];
927
928 // If all operands are constants, fold into a single constant
929 if (llvm::all_of(adaptor.getOperands(), [](Attribute attr) {
930 return isa_and_nonnull<ImmediateAttr>(attr);
931 })) {
932 auto result = APInt::getZeroWidth();
933 for (auto attr : adaptor.getOperands())
934 result = result.concat(cast<ImmediateAttr>(attr).getValue());
935
936 return ImmediateAttr::get(getContext(), result);
937 }
938
939 return {};
940}
941
942//===----------------------------------------------------------------------===//
943// SliceImmediateOp
944//===----------------------------------------------------------------------===//
945
946LogicalResult SliceImmediateOp::verify() {
947 auto srcWidth = getInput().getType().getWidth();
948 auto dstWidth = getResult().getType().getWidth();
949
950 if (getLowBit() >= srcWidth)
951 return emitOpError("from bit too large for input (got ")
952 << getLowBit() << ", but input width is " << srcWidth << ")";
953
954 if (srcWidth - getLowBit() < dstWidth)
955 return emitOpError("slice does not fit in input (trying to extract ")
956 << dstWidth << " bits starting at index " << getLowBit()
957 << ", but only " << (srcWidth - getLowBit())
958 << " bits are available)";
959
960 return success();
961}
962
963OpFoldResult SliceImmediateOp::fold(FoldAdaptor adaptor) {
964 if (auto inputAttr = dyn_cast_or_null<ImmediateAttr>(adaptor.getInput())) {
965 auto resultWidth = getType().getWidth();
966 APInt sliced = inputAttr.getValue().extractBits(resultWidth, getLowBit());
967 return ImmediateAttr::get(getContext(), sliced);
968 }
969
970 return {};
971}
972
973//===----------------------------------------------------------------------===//
974// StringConcatOp
975//===----------------------------------------------------------------------===//
976
977OpFoldResult StringConcatOp::fold(FoldAdaptor adaptor) {
978 SmallString<32> result;
979 for (auto attr : adaptor.getStrings()) {
980 auto stringAttr = dyn_cast_or_null<StringAttr>(attr);
981 if (!stringAttr)
982 return {};
983
984 result += stringAttr.getValue();
985 }
986
987 return StringAttr::get(result, StringType::get(getContext()));
988}
989
990//===----------------------------------------------------------------------===//
991// IntFormatOp
992//===----------------------------------------------------------------------===//
993
994OpFoldResult IntFormatOp::fold(FoldAdaptor adaptor) {
995 auto intAttr = dyn_cast_or_null<IntegerAttr>(adaptor.getValue());
996 if (!intAttr)
997 return {};
998 if (!intAttr.getType().isIndex())
999 return {};
1000 return StringAttr::get(Twine(intAttr.getValue().getZExtValue()),
1001 StringType::get(getContext()));
1002}
1003
1004//===----------------------------------------------------------------------===//
1005// ImmediateFormatOp
1006//===----------------------------------------------------------------------===//
1007
1008OpFoldResult ImmediateFormatOp::fold(FoldAdaptor adaptor) {
1009 auto immAttr = dyn_cast_or_null<ImmediateAttr>(adaptor.getValue());
1010 if (!immAttr)
1011 return {};
1012 SmallString<16> strBuf("0x");
1013 immAttr.getValue().toString(strBuf, 16, /*Signed=*/false);
1014 return StringAttr::get(strBuf, StringType::get(getContext()));
1015}
1016
1017//===----------------------------------------------------------------------===//
1018// RegisterFormatOp
1019//===----------------------------------------------------------------------===//
1020
1021OpFoldResult RegisterFormatOp::fold(FoldAdaptor adaptor) {
1022 auto regAttr = dyn_cast_or_null<RegisterAttrInterface>(adaptor.getValue());
1023 if (!regAttr)
1024 return {};
1025 return StringAttr::get(regAttr.getRegisterAssembly(),
1026 StringType::get(getContext()));
1027}
1028
1029//===----------------------------------------------------------------------===//
1030// StringToLabelOp
1031//===----------------------------------------------------------------------===//
1032
1033OpFoldResult StringToLabelOp::fold(FoldAdaptor adaptor) {
1034 if (auto stringAttr = dyn_cast_or_null<StringAttr>(adaptor.getString()))
1035 return LabelAttr::get(getContext(), stringAttr.getValue());
1036
1037 return {};
1038}
1039
1040//===----------------------------------------------------------------------===//
1041// TableGen generated logic.
1042//===----------------------------------------------------------------------===//
1043
1044#define GET_OP_CLASSES
1045#include "circt/Dialect/RTG/IR/RTG.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static std::unique_ptr< Context > context
static size_t getAddressWidth(size_t depth)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition rtg.py:1
Definition seq.py:1
reg(value, clock, reset=None, reset_value=None, name=None, sym_name=None)
Definition seq.py:21