Loading [MathJax]/extensions/tex2jax.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
RTGOps.cpp
Go to the documentation of this file.
1//===- RTGOps.cpp - Implement the RTG operations --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the RTG ops.
10//
11//===----------------------------------------------------------------------===//
12
16#include "mlir/IR/Builders.h"
17#include "mlir/IR/DialectImplementation.h"
18#include "llvm/ADT/SmallString.h"
19
20using namespace mlir;
21using namespace circt;
22using namespace rtg;
23
24//===----------------------------------------------------------------------===//
25// ConstantOp
26//===----------------------------------------------------------------------===//
27
28LogicalResult
29ConstantOp::inferReturnTypes(MLIRContext *context, std::optional<Location> loc,
30 ValueRange operands, DictionaryAttr attributes,
31 OpaqueProperties properties, RegionRange regions,
32 SmallVectorImpl<Type> &inferredReturnTypes) {
33 inferredReturnTypes.push_back(
34 properties.as<Properties *>()->getValue().getType());
35 return success();
36}
37
38OpFoldResult ConstantOp::fold(FoldAdaptor adaptor) { return getValueAttr(); }
39
40//===----------------------------------------------------------------------===//
41// SequenceOp
42//===----------------------------------------------------------------------===//
43
44LogicalResult SequenceOp::verifyRegions() {
45 if (TypeRange(getSequenceType().getElementTypes()) !=
46 getBody()->getArgumentTypes())
47 return emitOpError("sequence type does not match block argument types");
48
49 return success();
50}
51
52ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
53 // Parse the name as a symbol.
54 if (parser.parseSymbolName(
55 result.getOrAddProperties<SequenceOp::Properties>().sym_name))
56 return failure();
57
58 // Parse the function signature.
59 SmallVector<OpAsmParser::Argument> arguments;
60 if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
61 /*allowType=*/true, /*allowAttrs=*/true))
62 return failure();
63
64 SmallVector<Type> argTypes;
65 SmallVector<Location> argLocs;
66 argTypes.reserve(arguments.size());
67 argLocs.reserve(arguments.size());
68 for (auto &arg : arguments) {
69 argTypes.push_back(arg.type);
70 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
71 }
72 Type type = SequenceType::get(result.getContext(), argTypes);
73 result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
74 TypeAttr::get(type);
75
76 auto loc = parser.getCurrentLocation();
77 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
78 return failure();
79 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
80 return parser.emitError(loc)
81 << "'" << result.name.getStringRef() << "' op ";
82 })))
83 return failure();
84
85 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
86 if (parser.parseRegion(*bodyRegionRegion, arguments))
87 return failure();
88
89 if (bodyRegionRegion->empty()) {
90 bodyRegionRegion->emplaceBlock();
91 bodyRegionRegion->addArguments(argTypes, argLocs);
92 }
93 result.addRegion(std::move(bodyRegionRegion));
94
95 return success();
96}
97
98void SequenceOp::print(OpAsmPrinter &p) {
99 p << ' ';
100 p.printSymbolName(getSymNameAttr().getValue());
101 p << "(";
102 llvm::interleaveComma(getBody()->getArguments(), p,
103 [&](auto arg) { p.printRegionArgument(arg); });
104 p << ")";
105 p.printOptionalAttrDictWithKeyword(
106 (*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
107 p << ' ';
108 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
109}
110
111//===----------------------------------------------------------------------===//
112// GetSequenceOp
113//===----------------------------------------------------------------------===//
114
115LogicalResult
116GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
117 SequenceOp seq =
118 symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
119 if (!seq)
120 return emitOpError()
121 << "'" << getSequence()
122 << "' does not reference a valid 'rtg.sequence' operation";
123
124 if (seq.getSequenceType() != getType())
125 return emitOpError("referenced 'rtg.sequence' op's type does not match");
126
127 return success();
128}
129
130//===----------------------------------------------------------------------===//
131// SubstituteSequenceOp
132//===----------------------------------------------------------------------===//
133
134LogicalResult SubstituteSequenceOp::verify() {
135 if (getReplacements().empty())
136 return emitOpError("must at least have one replacement value");
137
138 if (getReplacements().size() >
139 getSequence().getType().getElementTypes().size())
140 return emitOpError(
141 "must not have more replacement values than sequence arguments");
142
143 if (getReplacements().getTypes() !=
144 getSequence().getType().getElementTypes().take_front(
145 getReplacements().size()))
146 return emitOpError("replacement types must match the same number of "
147 "sequence argument types from the front");
148
149 return success();
150}
151
152LogicalResult SubstituteSequenceOp::inferReturnTypes(
153 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
154 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
155 SmallVectorImpl<Type> &inferredReturnTypes) {
156 ArrayRef<Type> argTypes =
157 cast<SequenceType>(operands[0].getType()).getElementTypes();
158 auto seqType =
159 SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
160 inferredReturnTypes.push_back(seqType);
161 return success();
162}
163
164ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
165 ::mlir::OperationState &result) {
166 OpAsmParser::UnresolvedOperand sequenceRawOperand;
167 SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
168 Type sequenceRawType;
169
170 if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
171 return failure();
172
173 auto replacementsOperandsLoc = parser.getCurrentLocation();
174 if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
175 parser.parseColon() || parser.parseType(sequenceRawType) ||
176 parser.parseOptionalAttrDict(result.attributes))
177 return failure();
178
179 if (!isa<SequenceType>(sequenceRawType))
180 return parser.emitError(parser.getNameLoc())
181 << "'sequence' must be handle to a sequence or sequence family, but "
182 "got "
183 << sequenceRawType;
184
185 if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
186 result.operands))
187 return failure();
188
189 if (parser.resolveOperands(replacementsOperands,
190 cast<SequenceType>(sequenceRawType)
191 .getElementTypes()
192 .take_front(replacementsOperands.size()),
193 replacementsOperandsLoc, result.operands))
194 return failure();
195
196 SmallVector<Type> inferredReturnTypes;
197 if (failed(inferReturnTypes(
198 parser.getContext(), result.location, result.operands,
199 result.attributes.getDictionary(parser.getContext()),
200 result.getRawProperties(), result.regions, inferredReturnTypes)))
201 return failure();
202
203 result.addTypes(inferredReturnTypes);
204 return success();
205}
206
207void SubstituteSequenceOp::print(OpAsmPrinter &p) {
208 p << ' ' << getSequence() << "(" << getReplacements()
209 << ") : " << getSequence().getType();
210 p.printOptionalAttrDict((*this)->getAttrs(), {});
211}
212
213//===----------------------------------------------------------------------===//
214// InterleaveSequencesOp
215//===----------------------------------------------------------------------===//
216
217LogicalResult InterleaveSequencesOp::verify() {
218 if (getSequences().empty())
219 return emitOpError("must have at least one sequence in the list");
220
221 return success();
222}
223
224OpFoldResult InterleaveSequencesOp::fold(FoldAdaptor adaptor) {
225 if (getSequences().size() == 1)
226 return getSequences()[0];
227
228 return {};
229}
230
231//===----------------------------------------------------------------------===//
232// SetCreateOp
233//===----------------------------------------------------------------------===//
234
235ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) {
236 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
237 Type elemType;
238
239 if (parser.parseOperandList(operands) ||
240 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
241 parser.parseType(elemType))
242 return failure();
243
244 result.addTypes({SetType::get(result.getContext(), elemType)});
245
246 for (auto operand : operands)
247 if (parser.resolveOperand(operand, elemType, result.operands))
248 return failure();
249
250 return success();
251}
252
253void SetCreateOp::print(OpAsmPrinter &p) {
254 p << " ";
255 p.printOperands(getElements());
256 p.printOptionalAttrDict((*this)->getAttrs());
257 p << " : " << getSet().getType().getElementType();
258}
259
260LogicalResult SetCreateOp::verify() {
261 if (getElements().size() > 0) {
262 // We only need to check the first element because of the `SameTypeOperands`
263 // trait.
264 if (getElements()[0].getType() != getSet().getType().getElementType())
265 return emitOpError() << "operand types must match set element type";
266 }
267
268 return success();
269}
270
271//===----------------------------------------------------------------------===//
272// SetCartesianProductOp
273//===----------------------------------------------------------------------===//
274
275LogicalResult SetCartesianProductOp::inferReturnTypes(
276 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
277 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
278 SmallVectorImpl<Type> &inferredReturnTypes) {
279 if (operands.empty()) {
280 if (loc)
281 return mlir::emitError(*loc) << "at least one set must be provided";
282 return failure();
283 }
284
285 SmallVector<Type> elementTypes;
286 for (auto operand : operands)
287 elementTypes.push_back(cast<SetType>(operand.getType()).getElementType());
288 inferredReturnTypes.push_back(
289 SetType::get(rtg::TupleType::get(context, elementTypes)));
290 return success();
291}
292
293//===----------------------------------------------------------------------===//
294// BagCreateOp
295//===----------------------------------------------------------------------===//
296
297ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
298 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
299 multipleOperands;
300 Type elemType;
301
302 if (!parser.parseOptionalLParen()) {
303 while (true) {
304 OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
305 if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
306 parser.parseOperand(elementOperand))
307 return failure();
308
309 elementOperands.push_back(elementOperand);
310 multipleOperands.push_back(multipleOperand);
311
312 if (parser.parseOptionalComma()) {
313 if (parser.parseRParen())
314 return failure();
315 break;
316 }
317 }
318 }
319
320 if (parser.parseColon() || parser.parseType(elemType) ||
321 parser.parseOptionalAttrDict(result.attributes))
322 return failure();
323
324 result.addTypes({BagType::get(result.getContext(), elemType)});
325
326 for (auto operand : elementOperands)
327 if (parser.resolveOperand(operand, elemType, result.operands))
328 return failure();
329
330 for (auto operand : multipleOperands)
331 if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
332 result.operands))
333 return failure();
334
335 return success();
336}
337
338void BagCreateOp::print(OpAsmPrinter &p) {
339 p << " ";
340 if (!getElements().empty())
341 p << "(";
342 llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
343 [&](auto elAndMultiple) {
344 auto [el, multiple] = elAndMultiple;
345 p << multiple << " x " << el;
346 });
347 if (!getElements().empty())
348 p << ")";
349
350 p << " : " << getBag().getType().getElementType();
351 p.printOptionalAttrDict((*this)->getAttrs());
352}
353
354LogicalResult BagCreateOp::verify() {
355 if (!llvm::all_equal(getElements().getTypes()))
356 return emitOpError() << "types of all elements must match";
357
358 if (getElements().size() > 0)
359 if (getElements()[0].getType() != getBag().getType().getElementType())
360 return emitOpError() << "operand types must match bag element type";
361
362 return success();
363}
364
365//===----------------------------------------------------------------------===//
366// TupleCreateOp
367//===----------------------------------------------------------------------===//
368
369LogicalResult TupleCreateOp::inferReturnTypes(
370 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
371 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
372 SmallVectorImpl<Type> &inferredReturnTypes) {
373 SmallVector<Type> elementTypes;
374 for (auto operand : operands)
375 elementTypes.push_back(operand.getType());
376 inferredReturnTypes.push_back(rtg::TupleType::get(context, elementTypes));
377 return success();
378}
379
380//===----------------------------------------------------------------------===//
381// TupleExtractOp
382//===----------------------------------------------------------------------===//
383
384LogicalResult TupleExtractOp::inferReturnTypes(
385 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
386 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
387 SmallVectorImpl<Type> &inferredReturnTypes) {
388 assert(operands.size() == 1 && "must have exactly one operand");
389
390 auto tupleTy = dyn_cast<rtg::TupleType>(operands[0].getType());
391 size_t idx = properties.as<Properties *>()->getIndex().getInt();
392 if (!tupleTy) {
393 if (loc)
394 return mlir::emitError(*loc) << "only RTG tuples are supported";
395 return failure();
396 }
397
398 if (tupleTy.getFieldTypes().size() <= idx) {
399 if (loc)
400 return mlir::emitError(*loc)
401 << "index (" << idx
402 << ") must be smaller than number of elements in tuple ("
403 << tupleTy.getFieldTypes().size() << ")";
404 return failure();
405 }
406
407 inferredReturnTypes.push_back(tupleTy.getFieldTypes()[idx]);
408 return success();
409}
410
411//===----------------------------------------------------------------------===//
412// FixedRegisterOp
413//===----------------------------------------------------------------------===//
414
415LogicalResult FixedRegisterOp::inferReturnTypes(
416 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
417 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
418 SmallVectorImpl<Type> &inferredReturnTypes) {
419 inferredReturnTypes.push_back(
420 properties.as<Properties *>()->getReg().getType());
421 return success();
422}
423
424OpFoldResult FixedRegisterOp::fold(FoldAdaptor adaptor) { return getRegAttr(); }
425
426//===----------------------------------------------------------------------===//
427// VirtualRegisterOp
428//===----------------------------------------------------------------------===//
429
430LogicalResult VirtualRegisterOp::verify() {
431 if (getAllowedRegs().empty())
432 return emitOpError("must have at least one allowed register");
433
434 if (llvm::any_of(getAllowedRegs(), [](Attribute attr) {
435 return !isa<RegisterAttrInterface>(attr);
436 }))
437 return emitOpError("all elements must be of RegisterAttrInterface");
438
439 if (!llvm::all_equal(
440 llvm::map_range(getAllowedRegs().getAsRange<RegisterAttrInterface>(),
441 [](auto attr) { return attr.getType(); })))
442 return emitOpError("all allowed registers must be of the same type");
443
444 return success();
445}
446
447LogicalResult VirtualRegisterOp::inferReturnTypes(
448 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
449 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
450 SmallVectorImpl<Type> &inferredReturnTypes) {
451 auto allowedRegs = properties.as<Properties *>()->getAllowedRegs();
452 if (allowedRegs.empty()) {
453 if (loc)
454 return mlir::emitError(*loc, "must have at least one allowed register");
455
456 return failure();
457 }
458
459 auto regAttr = dyn_cast<RegisterAttrInterface>(allowedRegs[0]);
460 if (!regAttr) {
461 if (loc)
462 return mlir::emitError(
463 *loc, "allowed register attributes must be of RegisterAttrInterface");
464
465 return failure();
466 }
467 inferredReturnTypes.push_back(regAttr.getType());
468 return success();
469}
470
471//===----------------------------------------------------------------------===//
472// ContextSwitchOp
473//===----------------------------------------------------------------------===//
474
475LogicalResult ContextSwitchOp::verify() {
476 auto elementTypes = getSequence().getType().getElementTypes();
477 if (elementTypes.size() != 3)
478 return emitOpError("sequence type must have exactly 3 element types");
479
480 if (getFrom().getType() != elementTypes[0])
481 return emitOpError(
482 "first sequence element type must match 'from' attribute type");
483
484 if (getTo().getType() != elementTypes[1])
485 return emitOpError(
486 "second sequence element type must match 'to' attribute type");
487
488 auto seqTy = dyn_cast<SequenceType>(elementTypes[2]);
489 if (!seqTy || !seqTy.getElementTypes().empty())
490 return emitOpError(
491 "third sequence element type must be a fully substituted sequence");
492
493 return success();
494}
495
496//===----------------------------------------------------------------------===//
497// TestOp
498//===----------------------------------------------------------------------===//
499
500LogicalResult TestOp::verifyRegions() {
501 if (!getTargetType().entryTypesMatch(getBody()->getArgumentTypes()))
502 return emitOpError("argument types must match dict entry types");
503
504 return success();
505}
506
507LogicalResult TestOp::verify() {
508 if (getTemplateName().empty())
509 return emitOpError("template name must not be empty");
510
511 return success();
512}
513
514LogicalResult TestOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
515 if (!getTargetAttr())
516 return success();
517
518 auto target =
519 symbolTable.lookupNearestSymbolFrom<TargetOp>(*this, getTargetAttr());
520 if (!target)
521 return emitOpError()
522 << "'" << *getTarget()
523 << "' does not reference a valid 'rtg.target' operation";
524
525 // Check if target is a subtype of test requirements
526 // Since entries are sorted by name, we can do this in a single pass
527 size_t targetIdx = 0;
528 auto targetEntries = target.getTarget().getEntries();
529 for (auto testEntry : getTargetType().getEntries()) {
530 // Find the matching entry in target entries.
531 while (targetIdx < targetEntries.size() &&
532 targetEntries[targetIdx].name.getValue() < testEntry.name.getValue())
533 targetIdx++;
534
535 // Check if we found a matching entry with the same name and type
536 if (targetIdx >= targetEntries.size() ||
537 targetEntries[targetIdx].name != testEntry.name ||
538 targetEntries[targetIdx].type != testEntry.type) {
539 return emitOpError("referenced 'rtg.target' op's type is invalid: "
540 "missing entry called '")
541 << testEntry.name.getValue() << "' of type " << testEntry.type;
542 }
543 }
544
545 return success();
546}
547
548ParseResult TestOp::parse(OpAsmParser &parser, OperationState &result) {
549 // Parse the name as a symbol.
550 StringAttr symNameAttr;
551 if (parser.parseSymbolName(symNameAttr))
552 return failure();
553
554 result.getOrAddProperties<TestOp::Properties>().sym_name = symNameAttr;
555
556 // Parse the function signature.
557 SmallVector<OpAsmParser::Argument> arguments;
558 SmallVector<StringAttr> names;
559
560 auto parseOneArgument = [&]() -> ParseResult {
561 std::string name;
562 if (parser.parseKeywordOrString(&name) || parser.parseEqual() ||
563 parser.parseArgument(arguments.emplace_back(), /*allowType=*/true,
564 /*allowAttrs=*/true))
565 return failure();
566
567 names.push_back(StringAttr::get(result.getContext(), name));
568 return success();
569 };
570 if (parser.parseCommaSeparatedList(OpAsmParser::Delimiter::Paren,
571 parseOneArgument, " in argument list"))
572 return failure();
573
574 SmallVector<Type> argTypes;
575 SmallVector<DictEntry> entries;
576 SmallVector<Location> argLocs;
577 argTypes.reserve(arguments.size());
578 argLocs.reserve(arguments.size());
579 for (auto [name, arg] : llvm::zip(names, arguments)) {
580 argTypes.push_back(arg.type);
581 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
582 entries.push_back({name, arg.type});
583 }
584 auto emitError = [&]() -> InFlightDiagnostic {
585 return parser.emitError(parser.getCurrentLocation());
586 };
587 Type type = DictType::getChecked(emitError, result.getContext(),
588 ArrayRef<DictEntry>(entries));
589 if (!type)
590 return failure();
591 result.getOrAddProperties<TestOp::Properties>().targetType =
592 TypeAttr::get(type);
593
594 std::string templateName;
595 if (!parser.parseOptionalKeyword("template")) {
596 auto loc = parser.getCurrentLocation();
597 if (parser.parseString(&templateName))
598 return failure();
599
600 if (templateName.empty())
601 return parser.emitError(loc, "template name must not be empty");
602 }
603
604 StringAttr templateNameAttr = symNameAttr;
605 if (!templateName.empty())
606 templateNameAttr = StringAttr::get(result.getContext(), templateName);
607
608 StringAttr targetName;
609 if (!parser.parseOptionalKeyword("target"))
610 if (parser.parseSymbolName(targetName))
611 return failure();
612
613 result.getOrAddProperties<TestOp::Properties>().templateName =
614 templateNameAttr;
615 result.getOrAddProperties<TestOp::Properties>().target = targetName;
616
617 auto loc = parser.getCurrentLocation();
618 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
619 return failure();
620 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
621 return parser.emitError(loc)
622 << "'" << result.name.getStringRef() << "' op ";
623 })))
624 return failure();
625
626 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
627 if (parser.parseRegion(*bodyRegionRegion, arguments))
628 return failure();
629
630 if (bodyRegionRegion->empty()) {
631 bodyRegionRegion->emplaceBlock();
632 bodyRegionRegion->addArguments(argTypes, argLocs);
633 }
634 result.addRegion(std::move(bodyRegionRegion));
635
636 return success();
637}
638
639void TestOp::print(OpAsmPrinter &p) {
640 p << ' ';
641 p.printSymbolName(getSymNameAttr().getValue());
642 p << "(";
643 SmallString<32> resultNameStr;
644 llvm::interleaveComma(
645 llvm::zip(getTargetType().getEntries(), getBody()->getArguments()), p,
646 [&](auto entryAndArg) {
647 auto [entry, arg] = entryAndArg;
648 p << entry.name.getValue() << " = ";
649 p.printRegionArgument(arg);
650 });
651 p << ")";
652
653 if (getSymNameAttr() != getTemplateNameAttr())
654 p << " template " << getTemplateNameAttr();
655
656 if (getTargetAttr()) {
657 p << " target ";
658 p.printSymbolName(getTargetAttr().getValue());
659 }
660
661 p.printOptionalAttrDictWithKeyword(
662 (*this)->getAttrs(), {getSymNameAttrName(), getTargetTypeAttrName(),
663 getTargetAttrName(), getTemplateNameAttrName()});
664 p << ' ';
665 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
666}
667
668void TestOp::getAsmBlockArgumentNames(Region &region,
669 OpAsmSetValueNameFn setNameFn) {
670 for (auto [entry, arg] :
671 llvm::zip(getTargetType().getEntries(), region.getArguments()))
672 setNameFn(arg, entry.name.getValue());
673}
674
675//===----------------------------------------------------------------------===//
676// TargetOp
677//===----------------------------------------------------------------------===//
678
679LogicalResult TargetOp::verifyRegions() {
680 if (!getTarget().entryTypesMatch(
681 getBody()->getTerminator()->getOperandTypes()))
682 return emitOpError("terminator operand types must match dict entry types");
683
684 return success();
685}
686
687//===----------------------------------------------------------------------===//
688// ValidateOp
689//===----------------------------------------------------------------------===//
690
691LogicalResult ValidateOp::verify() {
692 if (!getRef().getType().isValidContentType(getValue().getType()))
693 return emitOpError(
694 "result type must be a valid content type for the ref value");
695
696 return success();
697}
698
699//===----------------------------------------------------------------------===//
700// ArrayCreateOp
701//===----------------------------------------------------------------------===//
702
703LogicalResult ArrayCreateOp::verify() {
704 if (!getElements().empty() &&
705 getElements()[0].getType() != getType().getElementType())
706 return emitOpError("operand types must match array element type, expected ")
707 << getType().getElementType() << " but got "
708 << getElements()[0].getType();
709
710 return success();
711}
712
713ParseResult ArrayCreateOp::parse(OpAsmParser &parser, OperationState &result) {
714 SmallVector<OpAsmParser::UnresolvedOperand> operands;
715 Type elementType;
716
717 if (parser.parseOperandList(operands) || parser.parseColon() ||
718 parser.parseType(elementType) ||
719 parser.parseOptionalAttrDict(result.attributes))
720 return failure();
721
722 if (failed(parser.resolveOperands(operands, elementType, result.operands)))
723 return failure();
724
725 result.addTypes(ArrayType::get(elementType));
726
727 return success();
728}
729
730void ArrayCreateOp::print(OpAsmPrinter &p) {
731 p << ' ';
732 p.printOperands(getElements());
733 p << " : " << getType().getElementType();
734 p.printOptionalAttrDict((*this)->getAttrs(), {});
735}
736
737//===----------------------------------------------------------------------===//
738// MemoryBlockDeclareOp
739//===----------------------------------------------------------------------===//
740
741LogicalResult MemoryBlockDeclareOp::verify() {
742 if (getBaseAddress().getBitWidth() != getType().getAddressWidth())
743 return emitOpError(
744 "base address width must match memory block address width");
745
746 if (getEndAddress().getBitWidth() != getType().getAddressWidth())
747 return emitOpError(
748 "end address width must match memory block address width");
749
750 if (getBaseAddress().ugt(getEndAddress()))
751 return emitOpError(
752 "base address must be smaller than or equal to the end address");
753
754 return success();
755}
756
757ParseResult MemoryBlockDeclareOp::parse(OpAsmParser &parser,
758 OperationState &result) {
759 SmallVector<OpAsmParser::UnresolvedOperand> operands;
760 MemoryBlockType memoryBlockType;
761 APInt start, end;
762
763 if (parser.parseLSquare())
764 return failure();
765
766 auto startLoc = parser.getCurrentLocation();
767 if (parser.parseInteger(start))
768 return failure();
769
770 if (parser.parseMinus())
771 return failure();
772
773 auto endLoc = parser.getCurrentLocation();
774 if (parser.parseInteger(end) || parser.parseRSquare() ||
775 parser.parseColonType(memoryBlockType) ||
776 parser.parseOptionalAttrDict(result.attributes))
777 return failure();
778
779 auto width = memoryBlockType.getAddressWidth();
780 auto adjustAPInt = [&](APInt value, llvm::SMLoc loc) -> FailureOr<APInt> {
781 if (value.getBitWidth() > width) {
782 if (!value.isIntN(width))
783 return parser.emitError(
784 loc,
785 "address out of range for memory block with address width ")
786 << width;
787
788 return value.trunc(width);
789 }
790
791 if (value.getBitWidth() < width)
792 return value.zext(width);
793
794 return value;
795 };
796
797 auto startRes = adjustAPInt(start, startLoc);
798 auto endRes = adjustAPInt(end, endLoc);
799 if (failed(startRes) || failed(endRes))
800 return failure();
801
802 auto intType = IntegerType::get(result.getContext(), width);
803 result.addAttribute(getBaseAddressAttrName(result.name),
804 IntegerAttr::get(intType, *startRes));
805 result.addAttribute(getEndAddressAttrName(result.name),
806 IntegerAttr::get(intType, *endRes));
807
808 result.addTypes(memoryBlockType);
809 return success();
810}
811
812void MemoryBlockDeclareOp::print(OpAsmPrinter &p) {
813 SmallVector<char> str;
814 getBaseAddress().toString(str, 16, false, false, false);
815 p << " [0x" << str;
816 p << " - 0x";
817 str.clear();
818 getEndAddress().toString(str, 16, false, false, false);
819 p << str << "] : " << getType();
820 p.printOptionalAttrDict((*this)->getAttrs(),
821 {getBaseAddressAttrName(), getEndAddressAttrName()});
822}
823
824//===----------------------------------------------------------------------===//
825// MemoryBaseAddressOp
826//===----------------------------------------------------------------------===//
827
828LogicalResult MemoryBaseAddressOp::inferReturnTypes(
829 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
830 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
831 SmallVectorImpl<Type> &inferredReturnTypes) {
832 if (operands.empty())
833 return failure();
834 auto memTy = dyn_cast<MemoryType>(operands[0].getType());
835 if (!memTy)
836 return failure();
837 inferredReturnTypes.push_back(
838 ImmediateType::get(context, memTy.getAddressWidth()));
839 return success();
840}
841
842//===----------------------------------------------------------------------===//
843// ConcatImmediateOp
844//===----------------------------------------------------------------------===//
845
846LogicalResult ConcatImmediateOp::inferReturnTypes(
847 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
848 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
849 SmallVectorImpl<Type> &inferredReturnTypes) {
850 if (operands.empty()) {
851 if (loc)
852 return mlir::emitError(*loc) << "at least one operand must be provided";
853 return failure();
854 }
855
856 unsigned totalWidth = 0;
857 for (auto operand : operands) {
858 auto immType = dyn_cast<ImmediateType>(operand.getType());
859 if (!immType) {
860 if (loc)
861 return mlir::emitError(*loc)
862 << "all operands must be of immediate type";
863 return failure();
864 }
865 totalWidth += immType.getWidth();
866 }
867
868 inferredReturnTypes.push_back(ImmediateType::get(context, totalWidth));
869 return success();
870}
871
872OpFoldResult ConcatImmediateOp::fold(FoldAdaptor adaptor) {
873 // concat(x) -> x
874 if (getOperands().size() == 1)
875 return getOperands()[0];
876
877 // If all operands are constants, fold into a single constant
878 if (llvm::all_of(adaptor.getOperands(), [](Attribute attr) {
879 return isa_and_nonnull<ImmediateAttr>(attr);
880 })) {
881 auto result = APInt::getZeroWidth();
882 for (auto attr : adaptor.getOperands())
883 result = result.concat(cast<ImmediateAttr>(attr).getValue());
884
885 return ImmediateAttr::get(getContext(), result);
886 }
887
888 return {};
889}
890
891//===----------------------------------------------------------------------===//
892// SliceImmediateOp
893//===----------------------------------------------------------------------===//
894
895LogicalResult SliceImmediateOp::verify() {
896 auto srcWidth = getInput().getType().getWidth();
897 auto dstWidth = getResult().getType().getWidth();
898
899 if (getLowBit() >= srcWidth)
900 return emitOpError("from bit too large for input (got ")
901 << getLowBit() << ", but input width is " << srcWidth << ")";
902
903 if (srcWidth - getLowBit() < dstWidth)
904 return emitOpError("slice does not fit in input (trying to extract ")
905 << dstWidth << " bits starting at index " << getLowBit()
906 << ", but only " << (srcWidth - getLowBit())
907 << " bits are available)";
908
909 return success();
910}
911
912OpFoldResult SliceImmediateOp::fold(FoldAdaptor adaptor) {
913 if (auto inputAttr = dyn_cast_or_null<ImmediateAttr>(adaptor.getInput())) {
914 auto resultWidth = getType().getWidth();
915 APInt sliced = inputAttr.getValue().extractBits(resultWidth, getLowBit());
916 return ImmediateAttr::get(getContext(), sliced);
917 }
918
919 return {};
920}
921
922//===----------------------------------------------------------------------===//
923// TableGen generated logic.
924//===----------------------------------------------------------------------===//
925
926#define GET_OP_CLASSES
927#include "circt/Dialect/RTG/IR/RTG.cpp.inc"
assert(baseType &&"element must be base type")
MlirType elementType
Definition CHIRRTL.cpp:29
static SmallVector< T > concat(const SmallVectorImpl< T > &a, const SmallVectorImpl< T > &b)
Returns a new vector containing the concatenation of vectors a and b.
Definition CalyxOps.cpp:540
static size_t getAddressWidth(size_t depth)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
int64_t getBitWidth(mlir::Type type)
Return the hardware bit width of a type.
Definition HWTypes.cpp:110
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:183
Definition rtg.py:1
Definition seq.py:1