CIRCT 20.0.0git
Loading...
Searching...
No Matches
RTGOps.cpp
Go to the documentation of this file.
1//===- RTGOps.cpp - Implement the RTG operations --------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the RTG ops.
10//
11//===----------------------------------------------------------------------===//
12
14#include "mlir/IR/Builders.h"
15#include "mlir/IR/DialectImplementation.h"
16
17using namespace mlir;
18using namespace circt;
19using namespace rtg;
20
21//===----------------------------------------------------------------------===//
22// SequenceOp
23//===----------------------------------------------------------------------===//
24
25LogicalResult SequenceOp::verifyRegions() {
26 if (TypeRange(getSequenceType().getElementTypes()) !=
27 getBody()->getArgumentTypes())
28 return emitOpError("sequence type does not match block argument types");
29
30 return success();
31}
32
33ParseResult SequenceOp::parse(OpAsmParser &parser, OperationState &result) {
34 // Parse the name as a symbol.
35 if (parser.parseSymbolName(
36 result.getOrAddProperties<SequenceOp::Properties>().sym_name))
37 return failure();
38
39 // Parse the function signature.
40 SmallVector<OpAsmParser::Argument> arguments;
41 if (parser.parseArgumentList(arguments, OpAsmParser::Delimiter::Paren,
42 /*allowType=*/true, /*allowAttrs=*/true))
43 return failure();
44
45 SmallVector<Type> argTypes;
46 SmallVector<Location> argLocs;
47 argTypes.reserve(arguments.size());
48 argLocs.reserve(arguments.size());
49 for (auto &arg : arguments) {
50 argTypes.push_back(arg.type);
51 argLocs.push_back(arg.sourceLoc ? *arg.sourceLoc : result.location);
52 }
53 Type type = SequenceType::get(result.getContext(), argTypes);
54 result.getOrAddProperties<SequenceOp::Properties>().sequenceType =
55 TypeAttr::get(type);
56
57 auto loc = parser.getCurrentLocation();
58 if (parser.parseOptionalAttrDictWithKeyword(result.attributes))
59 return failure();
60 if (failed(verifyInherentAttrs(result.name, result.attributes, [&]() {
61 return parser.emitError(loc)
62 << "'" << result.name.getStringRef() << "' op ";
63 })))
64 return failure();
65
66 std::unique_ptr<Region> bodyRegionRegion = std::make_unique<Region>();
67 if (parser.parseRegion(*bodyRegionRegion, arguments))
68 return failure();
69
70 if (bodyRegionRegion->empty()) {
71 bodyRegionRegion->emplaceBlock();
72 bodyRegionRegion->addArguments(argTypes, argLocs);
73 }
74 result.addRegion(std::move(bodyRegionRegion));
75
76 return success();
77}
78
79void SequenceOp::print(OpAsmPrinter &p) {
80 p << ' ';
81 p.printSymbolName(getSymNameAttr().getValue());
82 p << "(";
83 llvm::interleaveComma(getBody()->getArguments(), p,
84 [&](auto arg) { p.printRegionArgument(arg); });
85 p << ")";
86 p.printOptionalAttrDictWithKeyword(
87 (*this)->getAttrs(), {getSymNameAttrName(), getSequenceTypeAttrName()});
88 p << ' ';
89 p.printRegion(getBodyRegion(), /*printEntryBlockArgs=*/false);
90}
91
92//===----------------------------------------------------------------------===//
93// GetSequenceOp
94//===----------------------------------------------------------------------===//
95
96LogicalResult
97GetSequenceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
98 SequenceOp seq =
99 symbolTable.lookupNearestSymbolFrom<SequenceOp>(*this, getSequenceAttr());
100 if (!seq)
101 return emitOpError()
102 << "'" << getSequence()
103 << "' does not reference a valid 'rtg.sequence' operation";
104
105 if (seq.getSequenceType() != getType())
106 return emitOpError("referenced 'rtg.sequence' op's type does not match");
107
108 return success();
109}
110
111//===----------------------------------------------------------------------===//
112// SubstituteSequenceOp
113//===----------------------------------------------------------------------===//
114
115LogicalResult SubstituteSequenceOp::verify() {
116 if (getReplacements().empty())
117 return emitOpError("must at least have one replacement value");
118
119 if (getReplacements().size() >
120 getSequence().getType().getElementTypes().size())
121 return emitOpError(
122 "must not have more replacement values than sequence arguments");
123
124 if (getReplacements().getTypes() !=
125 getSequence().getType().getElementTypes().take_front(
126 getReplacements().size()))
127 return emitOpError("replacement types must match the same number of "
128 "sequence argument types from the front");
129
130 return success();
131}
132
133LogicalResult SubstituteSequenceOp::inferReturnTypes(
134 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
135 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
136 SmallVectorImpl<Type> &inferredReturnTypes) {
137 ArrayRef<Type> argTypes =
138 cast<SequenceType>(operands[0].getType()).getElementTypes();
139 auto seqType =
140 SequenceType::get(context, argTypes.drop_front(operands.size() - 1));
141 inferredReturnTypes.push_back(seqType);
142 return success();
143}
144
145ParseResult SubstituteSequenceOp::parse(::mlir::OpAsmParser &parser,
146 ::mlir::OperationState &result) {
147 OpAsmParser::UnresolvedOperand sequenceRawOperand;
148 SmallVector<OpAsmParser::UnresolvedOperand, 4> replacementsOperands;
149 Type sequenceRawType;
150
151 if (parser.parseOperand(sequenceRawOperand) || parser.parseLParen())
152 return failure();
153
154 auto replacementsOperandsLoc = parser.getCurrentLocation();
155 if (parser.parseOperandList(replacementsOperands) || parser.parseRParen() ||
156 parser.parseColon() || parser.parseType(sequenceRawType) ||
157 parser.parseOptionalAttrDict(result.attributes))
158 return failure();
159
160 if (!isa<SequenceType>(sequenceRawType))
161 return parser.emitError(parser.getNameLoc())
162 << "'sequence' must be handle to a sequence or sequence family, but "
163 "got "
164 << sequenceRawType;
165
166 if (parser.resolveOperand(sequenceRawOperand, sequenceRawType,
167 result.operands))
168 return failure();
169
170 if (parser.resolveOperands(replacementsOperands,
171 cast<SequenceType>(sequenceRawType)
172 .getElementTypes()
173 .take_front(replacementsOperands.size()),
174 replacementsOperandsLoc, result.operands))
175 return failure();
176
177 SmallVector<Type> inferredReturnTypes;
178 if (failed(inferReturnTypes(
179 parser.getContext(), result.location, result.operands,
180 result.attributes.getDictionary(parser.getContext()),
181 result.getRawProperties(), result.regions, inferredReturnTypes)))
182 return failure();
183
184 result.addTypes(inferredReturnTypes);
185 return success();
186}
187
188void SubstituteSequenceOp::print(OpAsmPrinter &p) {
189 p << ' ' << getSequence() << "(" << getReplacements()
190 << ") : " << getSequence().getType();
191 p.printOptionalAttrDict((*this)->getAttrs(), {});
192}
193
194//===----------------------------------------------------------------------===//
195// SetCreateOp
196//===----------------------------------------------------------------------===//
197
198ParseResult SetCreateOp::parse(OpAsmParser &parser, OperationState &result) {
199 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> operands;
200 Type elemType;
201
202 if (parser.parseOperandList(operands) ||
203 parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
204 parser.parseType(elemType))
205 return failure();
206
207 result.addTypes({SetType::get(result.getContext(), elemType)});
208
209 for (auto operand : operands)
210 if (parser.resolveOperand(operand, elemType, result.operands))
211 return failure();
212
213 return success();
214}
215
216void SetCreateOp::print(OpAsmPrinter &p) {
217 p << " ";
218 p.printOperands(getElements());
219 p.printOptionalAttrDict((*this)->getAttrs());
220 p << " : " << getSet().getType().getElementType();
221}
222
223LogicalResult SetCreateOp::verify() {
224 if (getElements().size() > 0) {
225 // We only need to check the first element because of the `SameTypeOperands`
226 // trait.
227 if (getElements()[0].getType() != getSet().getType().getElementType())
228 return emitOpError() << "operand types must match set element type";
229 }
230
231 return success();
232}
233
234//===----------------------------------------------------------------------===//
235// BagCreateOp
236//===----------------------------------------------------------------------===//
237
238ParseResult BagCreateOp::parse(OpAsmParser &parser, OperationState &result) {
239 llvm::SmallVector<OpAsmParser::UnresolvedOperand, 16> elementOperands,
240 multipleOperands;
241 Type elemType;
242
243 if (!parser.parseOptionalLParen()) {
244 while (true) {
245 OpAsmParser::UnresolvedOperand elementOperand, multipleOperand;
246 if (parser.parseOperand(multipleOperand) || parser.parseKeyword("x") ||
247 parser.parseOperand(elementOperand))
248 return failure();
249
250 elementOperands.push_back(elementOperand);
251 multipleOperands.push_back(multipleOperand);
252
253 if (parser.parseOptionalComma()) {
254 if (parser.parseRParen())
255 return failure();
256 break;
257 }
258 }
259 }
260
261 if (parser.parseColon() || parser.parseType(elemType) ||
262 parser.parseOptionalAttrDict(result.attributes))
263 return failure();
264
265 result.addTypes({BagType::get(result.getContext(), elemType)});
266
267 for (auto operand : elementOperands)
268 if (parser.resolveOperand(operand, elemType, result.operands))
269 return failure();
270
271 for (auto operand : multipleOperands)
272 if (parser.resolveOperand(operand, IndexType::get(result.getContext()),
273 result.operands))
274 return failure();
275
276 return success();
277}
278
279void BagCreateOp::print(OpAsmPrinter &p) {
280 p << " ";
281 if (!getElements().empty())
282 p << "(";
283 llvm::interleaveComma(llvm::zip(getElements(), getMultiples()), p,
284 [&](auto elAndMultiple) {
285 auto [el, multiple] = elAndMultiple;
286 p << multiple << " x " << el;
287 });
288 if (!getElements().empty())
289 p << ")";
290
291 p << " : " << getBag().getType().getElementType();
292 p.printOptionalAttrDict((*this)->getAttrs());
293}
294
295LogicalResult BagCreateOp::verify() {
296 if (!llvm::all_equal(getElements().getTypes()))
297 return emitOpError() << "types of all elements must match";
298
299 if (getElements().size() > 0)
300 if (getElements()[0].getType() != getBag().getType().getElementType())
301 return emitOpError() << "operand types must match bag element type";
302
303 return success();
304}
305
306//===----------------------------------------------------------------------===//
307// FixedRegisterOp
308//===----------------------------------------------------------------------===//
309
310LogicalResult FixedRegisterOp::inferReturnTypes(
311 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
312 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
313 SmallVectorImpl<Type> &inferredReturnTypes) {
314 inferredReturnTypes.push_back(
315 properties.as<Properties *>()->getReg().getType());
316 return success();
317}
318
319OpFoldResult FixedRegisterOp::fold(FoldAdaptor adaptor) { return getRegAttr(); }
320
321//===----------------------------------------------------------------------===//
322// VirtualRegisterOp
323//===----------------------------------------------------------------------===//
324
325LogicalResult VirtualRegisterOp::verify() {
326 if (getAllowedRegs().empty())
327 return emitOpError("must have at least one allowed register");
328
329 if (llvm::any_of(getAllowedRegs(), [](Attribute attr) {
330 return !isa<RegisterAttrInterface>(attr);
331 }))
332 return emitOpError("all elements must be of RegisterAttrInterface");
333
334 if (!llvm::all_equal(
335 llvm::map_range(getAllowedRegs().getAsRange<RegisterAttrInterface>(),
336 [](auto attr) { return attr.getType(); })))
337 return emitOpError("all allowed registers must be of the same type");
338
339 return success();
340}
341
342LogicalResult VirtualRegisterOp::inferReturnTypes(
343 MLIRContext *context, std::optional<Location> loc, ValueRange operands,
344 DictionaryAttr attributes, OpaqueProperties properties, RegionRange regions,
345 SmallVectorImpl<Type> &inferredReturnTypes) {
346 auto allowedRegs = properties.as<Properties *>()->getAllowedRegs();
347 if (allowedRegs.empty()) {
348 if (loc)
349 return mlir::emitError(*loc, "must have at least one allowed register");
350
351 return failure();
352 }
353
354 auto regAttr = dyn_cast<RegisterAttrInterface>(allowedRegs[0]);
355 if (!regAttr) {
356 if (loc)
357 return mlir::emitError(
358 *loc, "allowed register attributes must be of RegisterAttrInterface");
359
360 return failure();
361 }
362 inferredReturnTypes.push_back(regAttr.getType());
363 return success();
364}
365
366//===----------------------------------------------------------------------===//
367// TestOp
368//===----------------------------------------------------------------------===//
369
370LogicalResult TestOp::verifyRegions() {
371 if (!getTarget().entryTypesMatch(getBody()->getArgumentTypes()))
372 return emitOpError("argument types must match dict entry types");
373
374 return success();
375}
376
377//===----------------------------------------------------------------------===//
378// TargetOp
379//===----------------------------------------------------------------------===//
380
381LogicalResult TargetOp::verifyRegions() {
382 if (!getTarget().entryTypesMatch(
383 getBody()->getTerminator()->getOperandTypes()))
384 return emitOpError("terminator operand types must match dict entry types");
385
386 return success();
387}
388
389//===----------------------------------------------------------------------===//
390// TableGen generated logic.
391//===----------------------------------------------------------------------===//
392
393#define GET_OP_CLASSES
394#include "circt/Dialect/RTG/IR/RTG.cpp.inc"
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition rtg.py:1
Definition seq.py:1