CIRCT 23.0.0git
Loading...
Searching...
No Matches
ArcOps.cpp
Go to the documentation of this file.
1//===- ArcOps.cpp ---------------------------------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/OpImplementation.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/IR/SymbolTable.h"
16#include "mlir/Interfaces/FunctionImplementation.h"
17#include "mlir/Interfaces/SideEffectInterfaces.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/ADT/TypeSwitch.h"
20
21using namespace circt;
22using namespace arc;
23using namespace mlir;
24
25//===----------------------------------------------------------------------===//
26// Helpers
27//===----------------------------------------------------------------------===//
28
29static LogicalResult verifyTypeListEquivalence(Operation *op,
30 TypeRange expectedTypeList,
31 TypeRange actualTypeList,
32 StringRef elementName) {
33 if (expectedTypeList.size() != actualTypeList.size())
34 return op->emitOpError("incorrect number of ")
35 << elementName << "s: expected " << expectedTypeList.size()
36 << ", but got " << actualTypeList.size();
37
38 for (unsigned i = 0, e = expectedTypeList.size(); i != e; ++i) {
39 if (expectedTypeList[i] != actualTypeList[i]) {
40 auto diag = op->emitOpError(elementName)
41 << " type mismatch: " << elementName << " #" << i;
42 diag.attachNote() << "expected type: " << expectedTypeList[i];
43 diag.attachNote() << " actual type: " << actualTypeList[i];
44 return diag;
45 }
46 }
47
48 return success();
49}
50
51static LogicalResult verifyArcSymbolUse(Operation *op, TypeRange inputs,
52 TypeRange results,
53 SymbolTableCollection &symbolTable) {
54 // Check that the arc attribute was specified.
55 auto arcName = op->getAttrOfType<FlatSymbolRefAttr>("arc");
56 // The arc attribute is verified by the tablegen generated verifier as it is
57 // an ODS defined attribute.
58 assert(arcName && "FlatSymbolRefAttr called 'arc' missing");
59 DefineOp arc = symbolTable.lookupNearestSymbolFrom<DefineOp>(op, arcName);
60 if (!arc)
61 return op->emitOpError() << "`" << arcName.getValue()
62 << "` does not reference a valid `arc.define`";
63
64 // Verify that the operand and result types match the arc.
65 auto type = arc.getFunctionType();
66 if (failed(
67 verifyTypeListEquivalence(op, type.getInputs(), inputs, "operand")))
68 return failure();
69
70 if (failed(
71 verifyTypeListEquivalence(op, type.getResults(), results, "result")))
72 return failure();
73
74 return success();
75}
76
77static bool isSupportedModuleOp(Operation *moduleOp) {
78 return llvm::isa<arc::ModelOp, hw::HWModuleLike>(moduleOp);
79}
80
81/// Fetches the operation pointed to by `pointing` with name `symbol`, checking
82/// that it is a supported model operation for simulation.
83static Operation *getSupportedModuleOp(SymbolTableCollection &symbolTable,
84 Operation *pointing, StringAttr symbol) {
85 Operation *moduleOp = symbolTable.lookupNearestSymbolFrom(pointing, symbol);
86 if (!moduleOp) {
87 pointing->emitOpError("model not found");
88 return nullptr;
89 }
90
91 if (!isSupportedModuleOp(moduleOp)) {
92 pointing->emitOpError("model symbol does not point to a supported model "
93 "operation, points to ")
94 << moduleOp->getName() << " instead";
95 return nullptr;
96 }
97
98 return moduleOp;
99}
100
101static std::optional<hw::ModulePort> getModulePort(Operation *moduleOp,
102 StringRef portName) {
103 auto findRightPort = [&](auto ports) -> std::optional<hw::ModulePort> {
104 const hw::ModulePort *port = llvm::find_if(
105 ports, [&](hw::ModulePort port) { return port.name == portName; });
106 if (port == ports.end())
107 return std::nullopt;
108 return *port;
109 };
110
111 return TypeSwitch<Operation *, std::optional<hw::ModulePort>>(moduleOp)
112 .Case<arc::ModelOp>(
113 [&](arc::ModelOp modelOp) -> std::optional<hw::ModulePort> {
114 return findRightPort(modelOp.getIo().getPorts());
115 })
116 .Case<hw::HWModuleLike>(
117 [&](hw::HWModuleLike moduleLike) -> std::optional<hw::ModulePort> {
118 return findRightPort(moduleLike.getPortList());
119 })
120 .Default([](Operation *) { return std::nullopt; });
121}
122
123//===----------------------------------------------------------------------===//
124// DefineOp
125//===----------------------------------------------------------------------===//
126
127ParseResult DefineOp::parse(OpAsmParser &parser, OperationState &result) {
128 auto buildFuncType =
129 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
130 function_interface_impl::VariadicFlag,
131 std::string &) { return builder.getFunctionType(argTypes, results); };
132
133 return function_interface_impl::parseFunctionOp(
134 parser, result, /*allowVariadic=*/false,
135 getFunctionTypeAttrName(result.name), buildFuncType,
136 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
137}
138
139void DefineOp::print(OpAsmPrinter &p) {
140 function_interface_impl::printFunctionOp(
141 p, *this, /*isVariadic=*/false, "function_type", getArgAttrsAttrName(),
142 getResAttrsAttrName());
143}
144
145LogicalResult DefineOp::verifyRegions() {
146 // Check that the body does not contain any side-effecting operations. We can
147 // simply iterate over the ops directly within the body; operations with
148 // regions, like scf::IfOp, implement the `HasRecursiveMemoryEffects` trait
149 // which causes the `isMemoryEffectFree` check to already recur into their
150 // regions.
151 for (auto &op : getBodyBlock()) {
152 if (isMemoryEffectFree(&op))
153 continue;
154
155 // We don't use a op-error here because that leads to the whole arc being
156 // printed. This can be switched of when creating the context, but one
157 // might not want to switch that off for other error messages. Here it's
158 // definitely not desirable as arcs can be very big and would fill up the
159 // error log, making it hard to read. Currently, only the signature (first
160 // line) of the arc is printed.
161 auto diag = mlir::emitError(getLoc(), "body contains non-pure operation");
162 diag.attachNote(op.getLoc()).append("first non-pure operation here: ");
163 return diag;
164 }
165 return success();
166}
167
168bool DefineOp::isPassthrough() {
169 if (getNumArguments() != getNumResults())
170 return false;
171
172 return llvm::all_of(
173 llvm::zip(getArguments(), getBodyBlock().getTerminator()->getOperands()),
174 [](const auto &argAndRes) {
175 return std::get<0>(argAndRes) == std::get<1>(argAndRes);
176 });
177}
178
179//===----------------------------------------------------------------------===//
180// OutputOp
181//===----------------------------------------------------------------------===//
182
183LogicalResult OutputOp::verify() {
184 auto *parent = (*this)->getParentOp();
185 TypeRange expectedTypes = parent->getResultTypes();
186 if (auto defOp = dyn_cast<DefineOp>(parent))
187 expectedTypes = defOp.getResultTypes();
188
189 TypeRange actualTypes = getOperands().getTypes();
190 return verifyTypeListEquivalence(*this, expectedTypes, actualTypes, "output");
191}
192
193//===----------------------------------------------------------------------===//
194// StateOp
195//===----------------------------------------------------------------------===//
196
197LogicalResult StateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
198 return verifyArcSymbolUse(*this, getInputs().getTypes(),
199 getResults().getTypes(), symbolTable);
200}
201
202LogicalResult StateOp::verify() {
203 if (getLatency() < 1)
204 return emitOpError("latency must be a positive integer");
205
206 if (!getClock())
207 return emitOpError("requires a clock");
208
209 return success();
210}
211
212//===----------------------------------------------------------------------===//
213// StateWriteOp
214//===----------------------------------------------------------------------===//
215
216LogicalResult
217StateWriteOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
218 if (!getTraceTapModel().has_value())
219 return success();
220
221 auto modelOp = symbolTable.lookupNearestSymbolFrom<ModelOp>(
222 getOperation(), getTraceTapModelAttr());
223 if (!modelOp)
224 return emitOpError() << "`" << getTraceTapModelAttr()
225 << "` does not reference a valid `arc.model`";
226 if (!modelOp.getTraceTaps())
227 return emitOpError() << "referenced model has no trace metadata";
228 if (modelOp.getTraceTapsAttr().size() <= *getTraceTapIndex())
229 return emitOpError() << "tap index exceeds model's tap array";
230 auto tapAttr =
231 cast<TraceTapAttr>(modelOp.getTraceTapsAttr()[*getTraceTapIndex()]);
232 if (tapAttr.getSigType().getValue() != getValue().getType())
233 return emitOpError() << "incorrect signal type in referenced tap attribute";
234
235 return success();
236}
237
238LogicalResult StateWriteOp::verify() {
239 if (getTraceTapIndex().has_value() == getTraceTapModel().has_value())
240 return success();
241 return emitOpError() << "must specify both a trace tap model and index";
242}
243
244//===----------------------------------------------------------------------===//
245// CallOp
246//===----------------------------------------------------------------------===//
247
248LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
249 return verifyArcSymbolUse(*this, getInputs().getTypes(),
250 getResults().getTypes(), symbolTable);
251}
252
253bool CallOp::isClocked() { return false; }
254
255Value CallOp::getClock() { return Value{}; }
256
257void CallOp::eraseClock() {}
258
259uint32_t CallOp::getLatency() { return 0; }
260
261//===----------------------------------------------------------------------===//
262// MemoryWritePortOp
263//===----------------------------------------------------------------------===//
264
265SmallVector<Type> MemoryWritePortOp::getArcResultTypes() {
266 auto memType = cast<MemoryType>(getMemory().getType());
267 SmallVector<Type> resultTypes{memType.getAddressType(),
268 memType.getWordType()};
269 if (getEnable())
270 resultTypes.push_back(IntegerType::get(getContext(), 1));
271 if (getMask())
272 resultTypes.push_back(memType.getWordType());
273 return resultTypes;
274}
275
276LogicalResult
277MemoryWritePortOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
278 return verifyArcSymbolUse(*this, getInputs().getTypes(), getArcResultTypes(),
279 symbolTable);
280}
281
282LogicalResult MemoryWritePortOp::verify() {
283 if (getLatency() < 1)
284 return emitOpError("latency must be at least 1");
285
286 if (!getClock())
287 return emitOpError("requires a clock");
288
289 return success();
290}
291
292//===----------------------------------------------------------------------===//
293// RootInputOp
294//===----------------------------------------------------------------------===//
295
296void RootInputOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
297 SmallString<32> buf("in_");
298 buf += getName();
299 setNameFn(getState(), buf);
300}
301
302//===----------------------------------------------------------------------===//
303// RootOutputOp
304//===----------------------------------------------------------------------===//
305
306void RootOutputOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
307 SmallString<32> buf("out_");
308 buf += getName();
309 setNameFn(getState(), buf);
310}
311
312//===----------------------------------------------------------------------===//
313// ModelOp
314//===----------------------------------------------------------------------===//
315
316LogicalResult ModelOp::verify() {
317 if (getBodyBlock().getArguments().size() != 1)
318 return emitOpError("must have exactly one argument");
319 if (auto type = getBodyBlock().getArgument(0).getType();
320 !isa<StorageType>(type))
321 return emitOpError("argument must be of storage type");
322 for (const hw::ModulePort &port : getIo().getPorts())
323 if (port.dir == hw::ModulePort::Direction::InOut)
324 return emitOpError("inout ports are not supported");
325 return success();
326}
327
328LogicalResult ModelOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
329 auto fnAttrs = std::array{getInitialFnAttr(), getFinalFnAttr()};
330 auto nouns = std::array{"initializer", "finalizer"};
331 for (auto [fnAttr, noun] : llvm::zip(fnAttrs, nouns)) {
332 if (!fnAttr)
333 continue;
334 auto fn = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, fnAttr);
335 if (!fn)
336 return emitOpError() << noun << " '" << fnAttr.getValue()
337 << "' does not reference a valid function";
338 if (!llvm::equal(fn.getArgumentTypes(), getBody().getArgumentTypes())) {
339 auto diag = emitError() << noun << " '" << fnAttr.getValue()
340 << "' arguments must match arguments of model";
341 diag.attachNote(fn.getLoc()) << noun << " declared here:";
342 return diag;
343 }
344 }
345 return success();
346}
347
348//===----------------------------------------------------------------------===//
349// LutOp
350//===----------------------------------------------------------------------===//
351
352LogicalResult LutOp::verify() {
353 Location firstSideEffectOpLoc = UnknownLoc::get(getContext());
354 const WalkResult result = getBody().walk([&](Operation *op) {
355 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
356 SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
357 memOp.getEffects(effects);
358
359 if (!effects.empty()) {
360 firstSideEffectOpLoc = memOp->getLoc();
361 return WalkResult::interrupt();
362 }
363 }
364
365 return WalkResult::advance();
366 });
367
368 if (result.wasInterrupted())
369 return emitOpError("no operations with side-effects allowed inside a LUT")
370 .attachNote(firstSideEffectOpLoc)
371 << "first operation with side-effects here";
372
373 return success();
374}
375
376//===----------------------------------------------------------------------===//
377// VectorizeOp
378//===----------------------------------------------------------------------===//
379
380LogicalResult VectorizeOp::verify() {
381 if (getInputs().empty())
382 return emitOpError("there has to be at least one input vector");
383
384 if (!llvm::all_equal(llvm::map_range(
385 getInputs(), [](OperandRange range) { return range.size(); })))
386 return emitOpError("all input vectors must have the same size");
387
388 for (OperandRange range : getInputs()) {
389 if (!llvm::all_equal(range.getTypes()))
390 return emitOpError("all input vector lane types must match");
391
392 if (range.empty())
393 return emitOpError("input vector must have at least one element");
394 }
395
396 if (getResults().empty())
397 return emitOpError("must have at least one result");
398
399 if (!llvm::all_equal(getResults().getTypes()))
400 return emitOpError("all result types must match");
401
402 if (getResults().size() != getInputs().front().size())
403 return emitOpError("number results must match input vector size");
404
405 return success();
406}
407
408static FailureOr<unsigned> getVectorWidth(Type base, Type vectorized) {
409 if (isa<VectorType>(base))
410 return failure();
411
412 if (auto vectorTy = dyn_cast<VectorType>(vectorized)) {
413 if (vectorTy.getElementType() != base)
414 return failure();
415
416 return vectorTy.getDimSize(0);
417 }
418
419 if (vectorized.getIntOrFloatBitWidth() < base.getIntOrFloatBitWidth())
420 return failure();
421
422 if (vectorized.getIntOrFloatBitWidth() % base.getIntOrFloatBitWidth() == 0)
423 return vectorized.getIntOrFloatBitWidth() / base.getIntOrFloatBitWidth();
424
425 return failure();
426}
427
428LogicalResult VectorizeOp::verifyRegions() {
429 auto returnOp = cast<VectorizeReturnOp>(getBody().front().getTerminator());
430 TypeRange bodyArgTypes = getBody().front().getArgumentTypes();
431
432 if (bodyArgTypes.size() != getInputs().size())
433 return emitOpError(
434 "number of block arguments must match number of input vectors");
435
436 // Boundary and body are vectorized, or both are not vectorized
437 if (returnOp.getValue().getType() == getResultTypes().front()) {
438 for (auto [i, argTy] : llvm::enumerate(bodyArgTypes))
439 if (argTy != getInputs()[i].getTypes().front())
440 return emitOpError("if terminator type matches result type the "
441 "argument types must match the input types");
442
443 return success();
444 }
445
446 // Boundary is vectorized, body is not
447 if (auto width = getVectorWidth(returnOp.getValue().getType(),
448 getResultTypes().front());
449 succeeded(width)) {
450 for (auto [i, argTy] : llvm::enumerate(bodyArgTypes)) {
451 Type inputTy = getInputs()[i].getTypes().front();
452 FailureOr<unsigned> argWidth = getVectorWidth(argTy, inputTy);
453 if (failed(argWidth))
454 return emitOpError("block argument must be a scalar variant of the "
455 "vectorized operand");
456
457 if (*argWidth != width)
458 return emitOpError("input and output vector width must match");
459 }
460
461 return success();
462 }
463
464 // Body is vectorized, boundary is not
465 if (auto width = getVectorWidth(getResultTypes().front(),
466 returnOp.getValue().getType());
467 succeeded(width)) {
468 for (auto [i, argTy] : llvm::enumerate(bodyArgTypes)) {
469 Type inputTy = getInputs()[i].getTypes().front();
470 FailureOr<unsigned> argWidth = getVectorWidth(inputTy, argTy);
471 if (failed(argWidth))
472 return emitOpError(
473 "block argument must be a vectorized variant of the operand");
474
475 if (*argWidth != width)
476 return emitOpError("input and output vector width must match");
477
478 if (getInputs()[i].size() > 1 && argWidth != getInputs()[i].size())
479 return emitOpError(
480 "when boundary not vectorized the number of vector element "
481 "operands must match the width of the vectorized body");
482 }
483
484 return success();
485 }
486
487 return returnOp.emitOpError(
488 "operand type must match parent op's result value or be a vectorized or "
489 "non-vectorized variant of it");
490}
491
492bool VectorizeOp::isBoundaryVectorized() {
493 return getInputs().front().size() == 1;
494}
495bool VectorizeOp::isBodyVectorized() {
496 auto returnOp = cast<VectorizeReturnOp>(getBody().front().getTerminator());
497 if (isBoundaryVectorized() &&
498 returnOp.getValue().getType() == getResultTypes().front())
499 return true;
500
501 if (auto width = getVectorWidth(getResultTypes().front(),
502 returnOp.getValue().getType());
503 succeeded(width))
504 return *width > 1;
505
506 return false;
507}
508
509//===----------------------------------------------------------------------===//
510// SimInstantiateOp
511//===----------------------------------------------------------------------===//
512
513void SimInstantiateOp::print(OpAsmPrinter &p) {
514 BlockArgument modelArg = getBody().getArgument(0);
515 auto modelType = cast<SimModelInstanceType>(modelArg.getType());
516
517 p << " " << modelType.getModel() << " as ";
518 p.printRegionArgument(modelArg, {}, true);
519
520 if (getRuntimeModel() || getRuntimeArgs()) {
521 p << " runtime ";
522 if (getRuntimeModel())
523 p << getRuntimeModelAttr();
524 p << "(";
525 if (getRuntimeArgs())
526 p << getRuntimeArgsAttr();
527 p << ")";
528 }
529
530 p.printOptionalAttrDictWithKeyword(
531 getOperation()->getAttrs(),
532 {getRuntimeModelAttrName(), getRuntimeArgsAttrName()});
533
534 p << " ";
535
536 p.printRegion(getBody(), false);
537}
538
539ParseResult SimInstantiateOp::parse(OpAsmParser &parser,
540 OperationState &result) {
541 StringAttr modelName;
542 if (failed(parser.parseSymbolName(modelName)))
543 return failure();
544
545 if (failed(parser.parseKeyword("as")))
546 return failure();
547
548 OpAsmParser::Argument modelArg;
549 if (failed(parser.parseArgument(modelArg, false, false)))
550 return failure();
551
552 if (succeeded(parser.parseOptionalKeyword("runtime"))) {
553 StringAttr runtimeSym;
554 StringAttr runtimeArgs;
555 auto symOpt = parser.parseOptionalSymbolName(runtimeSym);
556 if (parser.parseLParen())
557 return failure();
558 auto nameOpt = parser.parseOptionalAttribute(runtimeArgs);
559 if (parser.parseRParen())
560 return failure();
561 if (succeeded(symOpt))
562 result.addAttribute(
563 SimInstantiateOp::getRuntimeModelAttrName(result.name),
564 FlatSymbolRefAttr::get(runtimeSym));
565 if (nameOpt.has_value())
566 result.addAttribute(SimInstantiateOp::getRuntimeArgsAttrName(result.name),
567 runtimeArgs);
568 }
569
570 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
571 return failure();
572
573 MLIRContext *ctxt = result.getContext();
574 modelArg.type =
575 SimModelInstanceType::get(ctxt, FlatSymbolRefAttr::get(ctxt, modelName));
576
577 std::unique_ptr<Region> body = std::make_unique<Region>();
578 if (failed(parser.parseRegion(*body, {modelArg})))
579 return failure();
580
581 result.addRegion(std::move(body));
582 return success();
583}
584
585LogicalResult SimInstantiateOp::verifyRegions() {
586 Region &body = getBody();
587 if (body.getNumArguments() != 1)
588 return emitError("entry block of body region must have the model instance "
589 "as a single argument");
590 if (!llvm::isa<SimModelInstanceType>(body.getArgument(0).getType()))
591 return emitError("entry block argument type is not a model instance");
592 return success();
593}
594
595LogicalResult
596SimInstantiateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
597 bool failed = false;
598 Operation *moduleOp = getSupportedModuleOp(
599 symbolTable, getOperation(),
600 llvm::cast<SimModelInstanceType>(getBody().getArgument(0).getType())
601 .getModel()
602 .getAttr());
603 if (!moduleOp)
604 failed = true;
605
606 if (getRuntimeModel().has_value()) {
607 Operation *runtimeModelOp = symbolTable.lookupNearestSymbolFrom(
608 getOperation(), getRuntimeModelAttr());
609 if (!runtimeModelOp) {
610 emitOpError("runtime model not found");
611 failed = true;
612 } else if (!isa<RuntimeModelOp>(runtimeModelOp)) {
613 emitOpError("referenced runtime model is not a RuntimeModelOp");
614 failed = true;
615 }
616 }
617
618 return success(!failed);
619}
620
621//===----------------------------------------------------------------------===//
622// SimSetInputOp
623//===----------------------------------------------------------------------===//
624
625LogicalResult
626SimSetInputOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
627 Operation *moduleOp = getSupportedModuleOp(
628 symbolTable, getOperation(),
629 llvm::cast<SimModelInstanceType>(getInstance().getType())
630 .getModel()
631 .getAttr());
632 if (!moduleOp)
633 return failure();
634
635 std::optional<hw::ModulePort> port = getModulePort(moduleOp, getInput());
636 if (!port)
637 return emitOpError("port not found on model");
638
639 if (port->dir != hw::ModulePort::Direction::Input &&
640 port->dir != hw::ModulePort::Direction::InOut)
641 return emitOpError("port is not an input port");
642
643 if (port->type != getValue().getType())
644 return emitOpError(
645 "mismatched types between value and model port, port expects ")
646 << port->type;
647
648 return success();
649}
650
651//===----------------------------------------------------------------------===//
652// SimGetPortOp
653//===----------------------------------------------------------------------===//
654
655LogicalResult
656SimGetPortOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
657 Operation *moduleOp = getSupportedModuleOp(
658 symbolTable, getOperation(),
659 llvm::cast<SimModelInstanceType>(getInstance().getType())
660 .getModel()
661 .getAttr());
662 if (!moduleOp)
663 return failure();
664
665 std::optional<hw::ModulePort> port = getModulePort(moduleOp, getPort());
666 if (!port)
667 return emitOpError("port not found on model");
668
669 if (port->type != getValue().getType())
670 return emitOpError(
671 "mismatched types between value and model port, port expects ")
672 << port->type;
673
674 return success();
675}
676
677//===----------------------------------------------------------------------===//
678// SimStepOp
679//===----------------------------------------------------------------------===//
680
681LogicalResult SimStepOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
682 Operation *moduleOp = getSupportedModuleOp(
683 symbolTable, getOperation(),
684 llvm::cast<SimModelInstanceType>(getInstance().getType())
685 .getModel()
686 .getAttr());
687 if (!moduleOp)
688 return failure();
689
690 return success();
691}
692
693//===----------------------------------------------------------------------===//
694// SimGetTimeOp
695//===----------------------------------------------------------------------===//
696
697LogicalResult
698SimGetTimeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
699 Operation *moduleOp = getSupportedModuleOp(
700 symbolTable, getOperation(),
701 llvm::cast<SimModelInstanceType>(getInstance().getType())
702 .getModel()
703 .getAttr());
704 if (!moduleOp)
705 return failure();
706
707 return success();
708}
709
710//===----------------------------------------------------------------------===//
711// SimSetTimeOp
712//===----------------------------------------------------------------------===//
713
714LogicalResult
715SimSetTimeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
716 Operation *moduleOp = getSupportedModuleOp(
717 symbolTable, getOperation(),
718 llvm::cast<SimModelInstanceType>(getInstance().getType())
719 .getModel()
720 .getAttr());
721 if (!moduleOp)
722 return failure();
723
724 return success();
725}
726
727//===----------------------------------------------------------------------===//
728// SimGetNextWakeupOp
729//===----------------------------------------------------------------------===//
730
731LogicalResult
732SimGetNextWakeupOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
733 Operation *moduleOp = getSupportedModuleOp(
734 symbolTable, getOperation(),
735 llvm::cast<SimModelInstanceType>(getInstance().getType())
736 .getModel()
737 .getAttr());
738 if (!moduleOp)
739 return failure();
740
741 return success();
742}
743
744//===----------------------------------------------------------------------===//
745// CoroutineDefineOp
746//===----------------------------------------------------------------------===//
747
748/// Resolve the callee symbol to a `CoroutineDefineOp` and verify that the
749/// given operand and result types match its function type.
750static LogicalResult verifyCoroutineCallTypes(Operation *op,
751 FlatSymbolRefAttr callee,
752 TypeRange operands,
753 TypeRange results,
754 SymbolTableCollection &symTable) {
755 auto defineOp =
756 symTable.lookupNearestSymbolFrom<CoroutineDefineOp>(op, callee);
757 if (!defineOp)
758 return op->emitOpError() << "`" << callee.getValue()
759 << "` does not reference a valid "
760 "`arc.coroutine.define`";
761
762 auto fnType = defineOp.getFunctionType();
763 if (failed(verifyTypeListEquivalence(op, fnType.getInputs(), operands,
764 "operand")))
765 return failure();
766 if (failed(verifyTypeListEquivalence(op, fnType.getResults(), results,
767 "result")))
768 return failure();
769 return success();
770}
771
772ParseResult CoroutineDefineOp::parse(OpAsmParser &parser,
773 OperationState &result) {
774 auto buildFuncType =
775 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
776 function_interface_impl::VariadicFlag,
777 std::string &) { return builder.getFunctionType(argTypes, results); };
778
779 return function_interface_impl::parseFunctionOp(
780 parser, result, /*allowVariadic=*/false,
781 getFunctionTypeAttrName(result.name), buildFuncType,
782 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
783}
784
785void CoroutineDefineOp::print(OpAsmPrinter &p) {
786 function_interface_impl::printFunctionOp(
787 p, *this, /*isVariadic=*/false, "function_type", getArgAttrsAttrName(),
788 getResAttrsAttrName());
789}
790
791//===----------------------------------------------------------------------===//
792// CoroutineCallOp
793//===----------------------------------------------------------------------===//
794
795LogicalResult
796CoroutineCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
797 // The `state`/`pc` and `resumeState`/`resumePC` types are constrained to
798 // wrap the callee symbol by the `CoroutineCalleeWrappedType` traits on the
799 // op. All that remains is to resolve the callee and check that the trailing
800 // arg/result types match its signature.
801 auto callee = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
802 return verifyCoroutineCallTypes(*this, callee, getArgs().getTypes(),
803 getResults().getTypes(), symbolTable);
804}
805
806//===----------------------------------------------------------------------===//
807// CoroutineInstanceOp
808//===----------------------------------------------------------------------===//
809
810// An instance hides the coroutine's trailing wakeup time. Verify that the
811// callee declares a wakeup as its last result and that the instance's args and
812// results match the callee's signature with that last result removed.
813LogicalResult
814CoroutineInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
815 auto callee = (*this)->getAttrOfType<FlatSymbolRefAttr>("callee");
816 auto defineOp =
817 symbolTable.lookupNearestSymbolFrom<CoroutineDefineOp>(*this, callee);
818 if (!defineOp)
819 return emitOpError() << "`" << callee.getValue()
820 << "` does not reference a valid "
821 "`arc.coroutine.define`";
822
823 auto fnType = defineOp.getFunctionType();
824 auto fnResults = fnType.getResults();
825 if (fnResults.empty() || !fnResults.back().isInteger(64))
826 return emitOpError() << "referenced coroutine `" << callee.getValue()
827 << "` must produce an `i64` wakeup time as its "
828 "last result";
829
830 if (failed(verifyTypeListEquivalence(*this, fnType.getInputs(),
831 getArgs().getTypes(), "operand")))
832 return failure();
833 if (failed(verifyTypeListEquivalence(*this, fnResults.drop_back(),
834 getResults().getTypes(), "result")))
835 return failure();
836 return success();
837}
838
839//===----------------------------------------------------------------------===//
840// Coroutine Terminators
841//===----------------------------------------------------------------------===//
842
843// The three terminators all yield values back through the enclosing
844// `arc.coroutine.define`'s result types. The helper below extracts the
845// expected types from the parent and checks them against the given operand
846// types.
847static LogicalResult verifyCoroutineTerminator(Operation *op,
848 TypeRange yieldOperands) {
849 auto parent = op->getParentOfType<CoroutineDefineOp>();
850 return verifyTypeListEquivalence(op, parent.getResultTypes(), yieldOperands,
851 "yielded value");
852}
853
854LogicalResult CoroutineYieldOp::verify() {
855 if (failed(verifyCoroutineTerminator(*this, getYieldOperands().getTypes())))
856 return failure();
857
858 // The `BranchOpInterface` already verifies that the destination block has
859 // the right number of arguments and that the trailing arguments match the
860 // yield's destination operands. Additionally verify that the leading
861 // arguments, which are supplied fresh by the caller upon resumption, match
862 // the coroutine's function type.
863 auto parent = (*this)->getParentOfType<CoroutineDefineOp>();
864 TypeRange coroutineArgTypes = parent.getArgumentTypes();
865 TypeRange destArgTypes = getDest()->getArgumentTypes();
866 if (destArgTypes.size() >= coroutineArgTypes.size())
867 if (failed(verifyTypeListEquivalence(
868 *this, coroutineArgTypes,
869 destArgTypes.take_front(coroutineArgTypes.size()),
870 "destination resume argument")))
871 return failure();
872
873 return success();
874}
875
876// The destination block's leading arguments match the coroutine's function
877// type and are supplied fresh by the caller upon resumption. They are
878// therefore "produced" operands from the branch's point of view. The
879// remaining destination block arguments map to the yield's destination
880// operands.
881SuccessorOperands CoroutineYieldOp::getSuccessorOperands(unsigned index) {
882 assert(index == 0 && "invalid successor index");
883 auto parent = (*this)->getParentOfType<CoroutineDefineOp>();
884 return SuccessorOperands(parent.getArgumentTypes().size(),
885 getDestOperandsMutable());
886}
887
888LogicalResult CoroutineReturnOp::verify() {
889 return verifyCoroutineTerminator(*this, getYieldOperands().getTypes());
890}
891
892LogicalResult CoroutineHaltOp::verify() {
893 return verifyCoroutineTerminator(*this, getYieldOperands().getTypes());
894}
895
896//===----------------------------------------------------------------------===//
897// ExecuteOp
898//===----------------------------------------------------------------------===//
899
900LogicalResult ExecuteOp::verifyRegions() {
901 return verifyTypeListEquivalence(*this, getInputs().getTypes(),
902 getBody().getArgumentTypes(), "input");
903}
904
905//===----------------------------------------------------------------------===//
906// ArrayRefAllocOp
907//===----------------------------------------------------------------------===//
908
909LogicalResult ArrayRefAllocOp::verify() {
910 if (auto init = getInit()) {
911 if (init->size() != getType().getNumElements()) {
912 return emitOpError("init size does not match array size; init had size ")
913 << init->size() << " but array has size "
914 << getType().getNumElements();
915 }
916
917 unsigned elemBitwidth = getType().getElementType().getIntOrFloatBitWidth();
918 for (APInt value : init->getAsValueRange<IntegerAttr>()) {
919 if (value.getBitWidth() != elemBitwidth) {
920 return emitOpError("expected element to be of type ")
921 << getType().getElementType();
922 }
923 }
924 }
925 return success();
926}
927
928#include "circt/Dialect/Arc/ArcInterfaces.cpp.inc"
929
930#define GET_OP_CLASSES
931#include "circt/Dialect/Arc/Arc.cpp.inc"
static FailureOr< unsigned > getVectorWidth(Type base, Type vectorized)
Definition ArcOps.cpp:408
static std::optional< hw::ModulePort > getModulePort(Operation *moduleOp, StringRef portName)
Definition ArcOps.cpp:101
static bool isSupportedModuleOp(Operation *moduleOp)
Definition ArcOps.cpp:77
static LogicalResult verifyArcSymbolUse(Operation *op, TypeRange inputs, TypeRange results, SymbolTableCollection &symbolTable)
Definition ArcOps.cpp:51
static LogicalResult verifyTypeListEquivalence(Operation *op, TypeRange expectedTypeList, TypeRange actualTypeList, StringRef elementName)
Definition ArcOps.cpp:29
static LogicalResult verifyCoroutineCallTypes(Operation *op, FlatSymbolRefAttr callee, TypeRange operands, TypeRange results, SymbolTableCollection &symTable)
Resolve the callee symbol to a CoroutineDefineOp and verify that the given operand and result types m...
Definition ArcOps.cpp:750
static LogicalResult verifyCoroutineTerminator(Operation *op, TypeRange yieldOperands)
Definition ArcOps.cpp:847
static Operation * getSupportedModuleOp(SymbolTableCollection &symbolTable, Operation *pointing, StringAttr symbol)
Fetches the operation pointed to by pointing with name symbol, checking that it is a supported model ...
Definition ArcOps.cpp:83
assert(baseType &&"element must be base type")
static PortInfo getPort(ModuleTy &mod, size_t idx)
Definition HWOps.cpp:1458
@ InOut
Definition HW.h:42
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:212
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
Definition arc.py:1
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
Definition LLVM.h:193
mlir::StringAttr name
Definition HWTypes.h:31