11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/Builders.h"
13#include "mlir/IR/OpImplementation.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/IR/SymbolTable.h"
16#include "mlir/Interfaces/FunctionImplementation.h"
17#include "mlir/Interfaces/SideEffectInterfaces.h"
18#include "llvm/ADT/SmallPtrSet.h"
19#include "llvm/ADT/TypeSwitch.h"
30 TypeRange expectedTypeList,
31 TypeRange actualTypeList,
32 StringRef elementName) {
33 if (expectedTypeList.size() != actualTypeList.size())
34 return op->emitOpError(
"incorrect number of ")
35 << elementName <<
"s: expected " << expectedTypeList.size()
36 <<
", but got " << actualTypeList.size();
38 for (
unsigned i = 0, e = expectedTypeList.size(); i != e; ++i) {
39 if (expectedTypeList[i] != actualTypeList[i]) {
40 auto diag = op->emitOpError(elementName)
41 <<
" type mismatch: " << elementName <<
" #" << i;
42 diag.attachNote() <<
"expected type: " << expectedTypeList[i];
43 diag.attachNote() <<
" actual type: " << actualTypeList[i];
53 SymbolTableCollection &symbolTable) {
55 auto arcName = op->getAttrOfType<FlatSymbolRefAttr>(
"arc");
58 assert(arcName &&
"FlatSymbolRefAttr called 'arc' missing");
59 DefineOp
arc = symbolTable.lookupNearestSymbolFrom<DefineOp>(op, arcName);
61 return op->emitOpError() <<
"`" << arcName.getValue()
62 <<
"` does not reference a valid `arc.define`";
65 auto type =
arc.getFunctionType();
78 return llvm::isa<arc::ModelOp, hw::HWModuleLike>(moduleOp);
84 Operation *pointing, StringAttr symbol) {
85 Operation *moduleOp = symbolTable.lookupNearestSymbolFrom(pointing, symbol);
87 pointing->emitOpError(
"model not found");
92 pointing->emitOpError(
"model symbol does not point to a supported model "
93 "operation, points to ")
94 << moduleOp->getName() <<
" instead";
102 StringRef portName) {
103 auto findRightPort = [&](
auto ports) -> std::optional<hw::ModulePort> {
106 if (port == ports.end())
111 return TypeSwitch<Operation *, std::optional<hw::ModulePort>>(moduleOp)
113 [&](arc::ModelOp modelOp) -> std::optional<hw::ModulePort> {
114 return findRightPort(modelOp.getIo().getPorts());
116 .Case<hw::HWModuleLike>(
117 [&](hw::HWModuleLike moduleLike) -> std::optional<hw::ModulePort> {
118 return findRightPort(moduleLike.getPortList());
120 .Default([](Operation *) {
return std::nullopt; });
127ParseResult DefineOp::parse(OpAsmParser &parser, OperationState &result) {
129 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
130 function_interface_impl::VariadicFlag,
131 std::string &) {
return builder.getFunctionType(argTypes, results); };
133 return function_interface_impl::parseFunctionOp(
134 parser, result,
false,
135 getFunctionTypeAttrName(result.name), buildFuncType,
136 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
139void DefineOp::print(OpAsmPrinter &p) {
140 function_interface_impl::printFunctionOp(
141 p, *
this,
false,
"function_type", getArgAttrsAttrName(),
142 getResAttrsAttrName());
145LogicalResult DefineOp::verifyRegions() {
152 if (isMemoryEffectFree(&op))
161 auto diag = mlir::emitError(
getLoc(),
"body contains non-pure operation");
162 diag.attachNote(op.getLoc()).append(
"first non-pure operation here: ");
168bool DefineOp::isPassthrough() {
169 if (getNumArguments() != getNumResults())
173 llvm::zip(getArguments(),
getBodyBlock().getTerminator()->getOperands()),
174 [](
const auto &argAndRes) {
175 return std::get<0>(argAndRes) == std::get<1>(argAndRes);
183LogicalResult OutputOp::verify() {
184 auto *parent = (*this)->getParentOp();
185 TypeRange expectedTypes = parent->getResultTypes();
186 if (
auto defOp = dyn_cast<DefineOp>(parent))
187 expectedTypes = defOp.getResultTypes();
189 TypeRange actualTypes = getOperands().getTypes();
197LogicalResult StateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
199 getResults().getTypes(), symbolTable);
202LogicalResult StateOp::verify() {
203 if (getLatency() < 1)
204 return emitOpError(
"latency must be a positive integer");
207 return emitOpError(
"requires a clock");
217StateWriteOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
218 if (!getTraceTapModel().has_value())
221 auto modelOp = symbolTable.lookupNearestSymbolFrom<ModelOp>(
222 getOperation(), getTraceTapModelAttr());
224 return emitOpError() <<
"`" << getTraceTapModelAttr()
225 <<
"` does not reference a valid `arc.model`";
226 if (!modelOp.getTraceTaps())
227 return emitOpError() <<
"referenced model has no trace metadata";
228 if (modelOp.getTraceTapsAttr().size() <= *getTraceTapIndex())
229 return emitOpError() <<
"tap index exceeds model's tap array";
231 cast<TraceTapAttr>(modelOp.getTraceTapsAttr()[*getTraceTapIndex()]);
232 if (tapAttr.getSigType().getValue() != getValue().getType())
233 return emitOpError() <<
"incorrect signal type in referenced tap attribute";
238LogicalResult StateWriteOp::verify() {
239 if (getTraceTapIndex().has_value() == getTraceTapModel().has_value())
241 return emitOpError() <<
"must specify both a trace tap model and index";
248LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
250 getResults().getTypes(), symbolTable);
253bool CallOp::isClocked() {
return false; }
255Value CallOp::getClock() {
return Value{}; }
257void CallOp::eraseClock() {}
259uint32_t CallOp::getLatency() {
return 0; }
265SmallVector<Type> MemoryWritePortOp::getArcResultTypes() {
266 auto memType = cast<MemoryType>(getMemory().getType());
267 SmallVector<Type> resultTypes{memType.getAddressType(),
268 memType.getWordType()};
270 resultTypes.push_back(IntegerType::get(getContext(), 1));
272 resultTypes.push_back(memType.getWordType());
277MemoryWritePortOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
282LogicalResult MemoryWritePortOp::verify() {
283 if (getLatency() < 1)
284 return emitOpError(
"latency must be at least 1");
287 return emitOpError(
"requires a clock");
297 SmallString<32> buf(
"in_");
299 setNameFn(getState(), buf);
307 SmallString<32> buf(
"out_");
309 setNameFn(getState(), buf);
316LogicalResult ModelOp::verify() {
318 return emitOpError(
"must have exactly one argument");
319 if (
auto type =
getBodyBlock().getArgument(0).getType();
320 !isa<StorageType>(type))
321 return emitOpError(
"argument must be of storage type");
324 return emitOpError(
"inout ports are not supported");
328LogicalResult ModelOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
329 auto fnAttrs = std::array{getInitialFnAttr(), getFinalFnAttr()};
330 auto nouns = std::array{
"initializer",
"finalizer"};
331 for (
auto [fnAttr, noun] :
llvm::zip(fnAttrs, nouns)) {
334 auto fn = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*
this, fnAttr);
336 return emitOpError() << noun <<
" '" << fnAttr.getValue()
337 <<
"' does not reference a valid function";
338 if (!llvm::equal(fn.getArgumentTypes(), getBody().getArgumentTypes())) {
339 auto diag = emitError() << noun <<
" '" << fnAttr.getValue()
340 <<
"' arguments must match arguments of model";
341 diag.attachNote(fn.getLoc()) << noun <<
" declared here:";
352LogicalResult LutOp::verify() {
353 Location firstSideEffectOpLoc = UnknownLoc::get(getContext());
354 const WalkResult result = getBody().walk([&](Operation *op) {
355 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
356 SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
357 memOp.getEffects(effects);
359 if (!effects.empty()) {
360 firstSideEffectOpLoc = memOp->getLoc();
361 return WalkResult::interrupt();
365 return WalkResult::advance();
368 if (result.wasInterrupted())
369 return emitOpError(
"no operations with side-effects allowed inside a LUT")
370 .attachNote(firstSideEffectOpLoc)
371 <<
"first operation with side-effects here";
380LogicalResult VectorizeOp::verify() {
381 if (getInputs().
empty())
382 return emitOpError(
"there has to be at least one input vector");
384 if (!llvm::all_equal(llvm::map_range(
385 getInputs(), [](OperandRange range) {
return range.size(); })))
386 return emitOpError(
"all input vectors must have the same size");
388 for (OperandRange range : getInputs()) {
389 if (!llvm::all_equal(range.getTypes()))
390 return emitOpError(
"all input vector lane types must match");
393 return emitOpError(
"input vector must have at least one element");
396 if (getResults().
empty())
397 return emitOpError(
"must have at least one result");
399 if (!llvm::all_equal(getResults().getTypes()))
400 return emitOpError(
"all result types must match");
402 if (getResults().size() != getInputs().front().size())
403 return emitOpError(
"number results must match input vector size");
409 if (isa<VectorType>(base))
412 if (
auto vectorTy = dyn_cast<VectorType>(vectorized)) {
413 if (vectorTy.getElementType() != base)
416 return vectorTy.getDimSize(0);
419 if (vectorized.getIntOrFloatBitWidth() < base.getIntOrFloatBitWidth())
422 if (vectorized.getIntOrFloatBitWidth() % base.getIntOrFloatBitWidth() == 0)
423 return vectorized.getIntOrFloatBitWidth() / base.getIntOrFloatBitWidth();
428LogicalResult VectorizeOp::verifyRegions() {
429 auto returnOp = cast<VectorizeReturnOp>(getBody().front().getTerminator());
430 TypeRange bodyArgTypes = getBody().front().getArgumentTypes();
432 if (bodyArgTypes.size() != getInputs().size())
434 "number of block arguments must match number of input vectors");
437 if (returnOp.getValue().getType() == getResultTypes().front()) {
438 for (
auto [i, argTy] :
llvm::enumerate(bodyArgTypes))
439 if (argTy != getInputs()[i].getTypes().front())
440 return emitOpError(
"if terminator type matches result type the "
441 "argument types must match the input types");
448 getResultTypes().front());
450 for (
auto [i, argTy] :
llvm::enumerate(bodyArgTypes)) {
451 Type inputTy = getInputs()[i].getTypes().front();
453 if (failed(argWidth))
454 return emitOpError(
"block argument must be a scalar variant of the "
455 "vectorized operand");
457 if (*argWidth != width)
458 return emitOpError(
"input and output vector width must match");
466 returnOp.getValue().getType());
468 for (
auto [i, argTy] :
llvm::enumerate(bodyArgTypes)) {
469 Type inputTy = getInputs()[i].getTypes().front();
471 if (failed(argWidth))
473 "block argument must be a vectorized variant of the operand");
475 if (*argWidth != width)
476 return emitOpError(
"input and output vector width must match");
478 if (getInputs()[i].size() > 1 && argWidth != getInputs()[i].size())
480 "when boundary not vectorized the number of vector element "
481 "operands must match the width of the vectorized body");
487 return returnOp.emitOpError(
488 "operand type must match parent op's result value or be a vectorized or "
489 "non-vectorized variant of it");
492bool VectorizeOp::isBoundaryVectorized() {
493 return getInputs().front().size() == 1;
495bool VectorizeOp::isBodyVectorized() {
496 auto returnOp = cast<VectorizeReturnOp>(getBody().front().getTerminator());
497 if (isBoundaryVectorized() &&
498 returnOp.getValue().getType() == getResultTypes().front())
502 returnOp.getValue().getType());
513void SimInstantiateOp::print(OpAsmPrinter &p) {
514 BlockArgument modelArg = getBody().getArgument(0);
515 auto modelType = cast<SimModelInstanceType>(modelArg.getType());
517 p <<
" " << modelType.getModel() <<
" as ";
518 p.printRegionArgument(modelArg, {},
true);
520 if (getRuntimeModel() || getRuntimeArgs()) {
522 if (getRuntimeModel())
523 p << getRuntimeModelAttr();
525 if (getRuntimeArgs())
526 p << getRuntimeArgsAttr();
530 p.printOptionalAttrDictWithKeyword(
531 getOperation()->getAttrs(),
532 {getRuntimeModelAttrName(), getRuntimeArgsAttrName()});
536 p.printRegion(getBody(),
false);
539ParseResult SimInstantiateOp::parse(OpAsmParser &parser,
540 OperationState &result) {
541 StringAttr modelName;
542 if (failed(parser.parseSymbolName(modelName)))
545 if (failed(parser.parseKeyword(
"as")))
548 OpAsmParser::Argument modelArg;
549 if (failed(parser.parseArgument(modelArg,
false,
false)))
552 if (succeeded(parser.parseOptionalKeyword(
"runtime"))) {
553 StringAttr runtimeSym;
554 StringAttr runtimeArgs;
555 auto symOpt = parser.parseOptionalSymbolName(runtimeSym);
556 if (parser.parseLParen())
558 auto nameOpt = parser.parseOptionalAttribute(runtimeArgs);
559 if (parser.parseRParen())
561 if (succeeded(symOpt))
563 SimInstantiateOp::getRuntimeModelAttrName(result.name),
564 FlatSymbolRefAttr::get(runtimeSym));
565 if (nameOpt.has_value())
566 result.addAttribute(SimInstantiateOp::getRuntimeArgsAttrName(result.name),
570 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
573 MLIRContext *ctxt = result.getContext();
575 SimModelInstanceType::get(ctxt, FlatSymbolRefAttr::get(ctxt, modelName));
577 std::unique_ptr<Region> body = std::make_unique<Region>();
578 if (failed(parser.parseRegion(*body, {modelArg})))
581 result.addRegion(std::move(body));
585LogicalResult SimInstantiateOp::verifyRegions() {
586 Region &body = getBody();
587 if (body.getNumArguments() != 1)
588 return emitError(
"entry block of body region must have the model instance "
589 "as a single argument");
590 if (!llvm::isa<SimModelInstanceType>(body.getArgument(0).getType()))
591 return emitError(
"entry block argument type is not a model instance");
596SimInstantiateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
599 symbolTable, getOperation(),
600 llvm::cast<SimModelInstanceType>(getBody().getArgument(0).getType())
606 if (getRuntimeModel().has_value()) {
607 Operation *runtimeModelOp = symbolTable.lookupNearestSymbolFrom(
608 getOperation(), getRuntimeModelAttr());
609 if (!runtimeModelOp) {
610 emitOpError(
"runtime model not found");
612 }
else if (!isa<RuntimeModelOp>(runtimeModelOp)) {
613 emitOpError(
"referenced runtime model is not a RuntimeModelOp");
618 return success(!failed);
626SimSetInputOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
628 symbolTable, getOperation(),
629 llvm::cast<SimModelInstanceType>(getInstance().getType())
635 std::optional<hw::ModulePort> port =
getModulePort(moduleOp, getInput());
637 return emitOpError(
"port not found on model");
639 if (port->dir != hw::ModulePort::Direction::Input &&
640 port->dir != hw::ModulePort::Direction::InOut)
641 return emitOpError(
"port is not an input port");
643 if (port->type != getValue().getType())
645 "mismatched types between value and model port, port expects ")
656SimGetPortOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
658 symbolTable, getOperation(),
659 llvm::cast<SimModelInstanceType>(getInstance().getType())
667 return emitOpError(
"port not found on model");
669 if (port->type != getValue().getType())
671 "mismatched types between value and model port, port expects ")
681LogicalResult SimStepOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
683 symbolTable, getOperation(),
684 llvm::cast<SimModelInstanceType>(getInstance().getType())
698SimGetTimeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
700 symbolTable, getOperation(),
701 llvm::cast<SimModelInstanceType>(getInstance().getType())
715SimSetTimeOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
717 symbolTable, getOperation(),
718 llvm::cast<SimModelInstanceType>(getInstance().getType())
732SimGetNextWakeupOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
734 symbolTable, getOperation(),
735 llvm::cast<SimModelInstanceType>(getInstance().getType())
751 FlatSymbolRefAttr callee,
754 SymbolTableCollection &symTable) {
756 symTable.lookupNearestSymbolFrom<CoroutineDefineOp>(op, callee);
758 return op->emitOpError() <<
"`" << callee.getValue()
759 <<
"` does not reference a valid "
760 "`arc.coroutine.define`";
762 auto fnType = defineOp.getFunctionType();
772ParseResult CoroutineDefineOp::parse(OpAsmParser &parser,
773 OperationState &result) {
775 [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
776 function_interface_impl::VariadicFlag,
777 std::string &) {
return builder.getFunctionType(argTypes, results); };
779 return function_interface_impl::parseFunctionOp(
780 parser, result,
false,
781 getFunctionTypeAttrName(result.name), buildFuncType,
782 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
785void CoroutineDefineOp::print(OpAsmPrinter &p) {
786 function_interface_impl::printFunctionOp(
787 p, *
this,
false,
"function_type", getArgAttrsAttrName(),
788 getResAttrsAttrName());
796CoroutineCallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
801 auto callee = (*this)->getAttrOfType<FlatSymbolRefAttr>(
"callee");
803 getResults().getTypes(), symbolTable);
814CoroutineInstanceOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
815 auto callee = (*this)->getAttrOfType<FlatSymbolRefAttr>(
"callee");
817 symbolTable.lookupNearestSymbolFrom<CoroutineDefineOp>(*
this, callee);
819 return emitOpError() <<
"`" << callee.getValue()
820 <<
"` does not reference a valid "
821 "`arc.coroutine.define`";
823 auto fnType = defineOp.getFunctionType();
824 auto fnResults = fnType.getResults();
825 if (fnResults.empty() || !fnResults.back().isInteger(64))
826 return emitOpError() <<
"referenced coroutine `" << callee.getValue()
827 <<
"` must produce an `i64` wakeup time as its "
831 getArgs().getTypes(),
"operand")))
834 getResults().getTypes(),
"result")))
848 TypeRange yieldOperands) {
849 auto parent = op->getParentOfType<CoroutineDefineOp>();
854LogicalResult CoroutineYieldOp::verify() {
863 auto parent = (*this)->getParentOfType<CoroutineDefineOp>();
864 TypeRange coroutineArgTypes = parent.getArgumentTypes();
865 TypeRange destArgTypes = getDest()->getArgumentTypes();
866 if (destArgTypes.size() >= coroutineArgTypes.size())
868 *
this, coroutineArgTypes,
869 destArgTypes.take_front(coroutineArgTypes.size()),
870 "destination resume argument")))
881SuccessorOperands CoroutineYieldOp::getSuccessorOperands(
unsigned index) {
882 assert(index == 0 &&
"invalid successor index");
883 auto parent = (*this)->getParentOfType<CoroutineDefineOp>();
884 return SuccessorOperands(parent.getArgumentTypes().size(),
885 getDestOperandsMutable());
888LogicalResult CoroutineReturnOp::verify() {
892LogicalResult CoroutineHaltOp::verify() {
900LogicalResult ExecuteOp::verifyRegions() {
902 getBody().getArgumentTypes(),
"input");
909LogicalResult ArrayRefAllocOp::verify() {
910 if (
auto init = getInit()) {
911 if (init->size() != getType().getNumElements()) {
912 return emitOpError(
"init size does not match array size; init had size ")
913 << init->size() <<
" but array has size "
914 << getType().getNumElements();
917 unsigned elemBitwidth = getType().getElementType().getIntOrFloatBitWidth();
918 for (APInt value : init->getAsValueRange<IntegerAttr>()) {
919 if (value.getBitWidth() != elemBitwidth) {
920 return emitOpError(
"expected element to be of type ")
921 << getType().getElementType();
928#include "circt/Dialect/Arc/ArcInterfaces.cpp.inc"
930#define GET_OP_CLASSES
931#include "circt/Dialect/Arc/Arc.cpp.inc"
static FailureOr< unsigned > getVectorWidth(Type base, Type vectorized)
static std::optional< hw::ModulePort > getModulePort(Operation *moduleOp, StringRef portName)
static bool isSupportedModuleOp(Operation *moduleOp)
static LogicalResult verifyArcSymbolUse(Operation *op, TypeRange inputs, TypeRange results, SymbolTableCollection &symbolTable)
static LogicalResult verifyTypeListEquivalence(Operation *op, TypeRange expectedTypeList, TypeRange actualTypeList, StringRef elementName)
static LogicalResult verifyCoroutineCallTypes(Operation *op, FlatSymbolRefAttr callee, TypeRange operands, TypeRange results, SymbolTableCollection &symTable)
Resolve the callee symbol to a CoroutineDefineOp and verify that the given operand and result types m...
static LogicalResult verifyCoroutineTerminator(Operation *op, TypeRange yieldOperands)
static Operation * getSupportedModuleOp(SymbolTableCollection &symbolTable, Operation *pointing, StringAttr symbol)
Fetches the operation pointed to by pointing with name symbol, checking that it is a supported model ...
assert(baseType &&"element must be base type")
static PortInfo getPort(ModuleTy &mod, size_t idx)
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
static InstancePath empty
Direction
The direction of a Component or Cell port.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn