11 #include "mlir/IR/Builders.h"
12 #include "mlir/IR/OpImplementation.h"
13 #include "mlir/IR/PatternMatch.h"
14 #include "mlir/IR/SymbolTable.h"
15 #include "mlir/Interfaces/FunctionImplementation.h"
16 #include "mlir/Interfaces/SideEffectInterfaces.h"
17 #include "llvm/ADT/SmallPtrSet.h"
18 #include "llvm/ADT/TypeSwitch.h"
20 using namespace circt;
29 TypeRange expectedTypeList,
30 TypeRange actualTypeList,
31 StringRef elementName) {
32 if (expectedTypeList.size() != actualTypeList.size())
33 return op->emitOpError(
"incorrect number of ")
34 << elementName <<
"s: expected " << expectedTypeList.size()
35 <<
", but got " << actualTypeList.size();
37 for (
unsigned i = 0, e = expectedTypeList.size(); i != e; ++i) {
38 if (expectedTypeList[i] != actualTypeList[i]) {
39 auto diag = op->emitOpError(elementName)
40 <<
" type mismatch: " << elementName <<
" #" << i;
41 diag.attachNote() <<
"expected type: " << expectedTypeList[i];
42 diag.attachNote() <<
" actual type: " << actualTypeList[i];
52 SymbolTableCollection &symbolTable) {
54 auto arcName = op->getAttrOfType<FlatSymbolRefAttr>(
"arc");
57 assert(arcName &&
"FlatSymbolRefAttr called 'arc' missing");
58 DefineOp arc = symbolTable.lookupNearestSymbolFrom<DefineOp>(op, arcName);
60 return op->emitOpError() <<
"`" << arcName.getValue()
61 <<
"` does not reference a valid `arc.define`";
64 auto type = arc.getFunctionType();
77 return llvm::isa<arc::ModelOp, hw::HWModuleLike>(moduleOp);
83 Operation *pointing, StringAttr symbol) {
84 Operation *moduleOp = symbolTable.lookupNearestSymbolFrom(pointing, symbol);
86 pointing->emitOpError(
"model not found");
91 pointing->emitOpError(
"model symbol does not point to a supported model "
92 "operation, points to ")
93 << moduleOp->getName() <<
" instead";
101 StringRef portName) {
102 auto findRightPort = [&](
auto ports) -> std::optional<hw::ModulePort> {
103 const hw::ModulePort *port = llvm::find_if(
104 ports, [&](hw::ModulePort port) {
return port.name == portName; });
105 if (port == ports.end())
110 return TypeSwitch<Operation *, std::optional<hw::ModulePort>>(moduleOp)
112 [&](arc::ModelOp modelOp) -> std::optional<hw::ModulePort> {
113 return findRightPort(modelOp.getIo().getPorts());
115 .Case<hw::HWModuleLike>(
116 [&](hw::HWModuleLike moduleLike) -> std::optional<hw::ModulePort> {
117 return findRightPort(moduleLike.getPortList());
119 .Default([](Operation *) {
return std::nullopt; });
/// Custom assembly parser for `DefineOp`.
///
/// Delegates to `function_interface_impl::parseFunctionOp`, handing it the
/// function-type, argument-attribute and result-attribute names looked up
/// from `result.name`, plus a callback that turns the parsed argument and
/// result types into a function type via `Builder::getFunctionType`. The
/// callback's variadic-flag and string parameters are unnamed and unused.
///
/// NOTE(review): the bare `false` passed after `parser, result,` is
/// presumably the "allow variadic" flag of `parseFunctionOp` -- confirm
/// against the MLIR signature. The declaration that names the callback
/// (`buildFuncType`) is not part of this listing.
126 ParseResult DefineOp::parse(OpAsmParser &parser, OperationState &result) {
128 [](Builder &
builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
129 function_interface_impl::VariadicFlag,
130 std::string &) {
return builder.getFunctionType(argTypes, results); };
132 return function_interface_impl::parseFunctionOp(
133 parser, result,
false,
134 getFunctionTypeAttrName(result.name), buildFuncType,
135 getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name));
/// Custom assembly printer for `DefineOp`.
///
/// Forwards to `function_interface_impl::printFunctionOp` with this op, the
/// literal attribute name "function_type", and the argument/result attribute
/// names obtained from the op's accessors.
///
/// NOTE(review): `DefineOp::parse` obtains the function-type attribute name
/// through `getFunctionTypeAttrName(...)`, whereas this printer spells it as
/// a string literal -- confirm the two always agree. The bare `false` is
/// presumably the "is variadic" flag of `printFunctionOp`; verify.
138 void DefineOp::print(OpAsmPrinter &p) {
139 function_interface_impl::printFunctionOp(
140 p, *
this,
false,
"function_type", getArgAttrsAttrName(),
141 getResAttrsAttrName());
/// Region verifier for `DefineOp`.
///
/// Iterates over every operation in the body block and queries
/// `isMemoryEffectFree` on it. The visible diagnostic path emits
/// "body contains non-pure operation" at this op's location and attaches a
/// note at the offending operation's location.
///
/// NOTE(review): the statements between the purity query and the diagnostic
/// are missing from this listing, so which operations are skipped before the
/// error is raised cannot be stated from here.
144 LogicalResult DefineOp::verifyRegions() {
150 for (
auto &op : getBodyBlock()) {
// Query whether the operation is free of memory effects.
151 if (isMemoryEffectFree(&op))
// Error is anchored on the define op; the note points at the body operation.
160 auto diag = mlir::emitError(getLoc(),
"body contains non-pure operation");
161 diag.attachNote(op.getLoc()).append(
"first non-pure operation here: ");
/// Passthrough query for `DefineOp`.
///
/// First compares the number of arguments with the number of results, then
/// pairs each block argument with the terminator operand at the same
/// position (`llvm::zip`) and tests each pair for equality.
///
/// NOTE(review): the result returned on a count mismatch and the algorithm
/// that consumes the zipped range are not part of this listing -- presumably
/// `false` and an "all of" check respectively; confirm.
167 bool DefineOp::isPassthrough() {
168 if (getNumArguments() != getNumResults())
172 llvm::zip(getArguments(), getBodyBlock().getTerminator()->getOperands()),
173 [](
const auto &argAndRes) {
// A pair matches when the argument itself is the terminator operand.
174 return std::get<0>(argAndRes) == std::get<1>(argAndRes);
/// Verifier for `OutputOp`.
///
/// Gathers the types this terminator is expected to carry and the types of
/// its actual operands. The step that compares the two lists is not part of
/// this listing.
182 LogicalResult OutputOp::verify() {
183 auto *parent = (*this)->getParentOp();
// Default expectation: the result types of the parent operation.
184 TypeRange expectedTypes = parent->getResultTypes();
// When the parent is a `DefineOp`, use the result types it reports through
// its own `getResultTypes()` accessor instead.
185 if (
auto defOp = dyn_cast<DefineOp>(parent))
186 expectedTypes = defOp.getResultTypes();
188 TypeRange actualTypes = getOperands().getTypes();
196 LogicalResult StateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
198 getResults().getTypes(), symbolTable);
/// Verifier for `StateOp`.
///
/// Checks visible in this listing: the latency must be at least 1, and a
/// clock operand must be present exactly when the op is not nested inside a
/// `ClockDomainOp`. Any checks after these are not part of this listing.
201 LogicalResult StateOp::verify() {
// Latencies below 1 are rejected.
202 if (getLatency() < 1)
203 return emitOpError(
"latency must be a positive integer");
// No enclosing clock domain and no clock operand.
205 if (!getOperation()->getParentOfType<ClockDomainOp>() && !getClock())
206 return emitOpError(
"outside a clock domain requires a clock");
// Enclosed by a clock domain but still carrying a clock operand.
208 if (getOperation()->getParentOfType<ClockDomainOp>() && getClock())
209 return emitOpError(
"inside a clock domain cannot have a clock");
218 LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
220 getResults().getTypes(), symbolTable);
/// Always returns false: a `CallOp` reports itself as not clocked.
223 bool CallOp::isClocked() {
return false; }
/// Always returns a default-constructed `Value`; no clock is reported.
225 Value CallOp::getClock() {
return Value{}; }
/// No-op: the body is empty, so calling this changes nothing.
227 void CallOp::eraseClock() {}
/// Always returns 0 as the latency of a `CallOp`.
229 uint32_t CallOp::getLatency() {
return 0; }
/// Builds a list of types derived from the memory operand's `MemoryType`.
///
/// The list always starts with the memory's address type followed by its
/// word type. One further word type is appended below.
///
/// NOTE(review): the statements between the initial list and the final
/// `push_back` are missing from this listing, so the condition guarding the
/// extra word type (and any other appended types) cannot be stated here.
235 SmallVector<Type> MemoryWritePortOp::getArcResultTypes() {
236 auto memType = cast<MemoryType>(getMemory().getType());
237 SmallVector<Type> resultTypes{memType.getAddressType(),
238 memType.getWordType()};
242 resultTypes.push_back(memType.getWordType());
247 MemoryWritePortOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
/// Verifier for `MemoryWritePortOp`.
///
/// Checks visible in this listing: the latency must be at least 1, and a
/// clock operand must be present exactly when the op is not nested inside a
/// `ClockDomainOp`. The tail of the function is not part of this listing.
252 LogicalResult MemoryWritePortOp::verify() {
// Latencies below 1 are rejected.
253 if (getLatency() < 1)
254 return emitOpError(
"latency must be at least 1");
// No enclosing clock domain and no clock operand.
256 if (!getOperation()->getParentOfType<ClockDomainOp>() && !getClock())
257 return emitOpError(
"outside a clock domain requires a clock");
// Enclosed by a clock domain but still carrying a clock operand.
259 if (getOperation()->getParentOfType<ClockDomainOp>() && getClock())
260 return emitOpError(
"inside a clock domain cannot have a clock");
269 LogicalResult ClockDomainOp::verifyRegions() {
271 getInputs().getTypes(),
"input");
279 SmallString<32> buf(
"in_");
281 setNameFn(getState(), buf);
289 SmallString<32> buf(
"out_");
291 setNameFn(getState(), buf);
/// Verifier for `ModelOp`.
///
/// Requires the body block to have exactly one argument and that argument to
/// be of `StorageType`, then walks the ports of the op's IO type.
///
/// NOTE(review): the condition guarding the "inout ports are not supported"
/// diagnostic is missing from this listing -- presumably a test of the
/// port's direction; confirm.
298 LogicalResult ModelOp::verify() {
299 if (getBodyBlock().getArguments().size() != 1)
300 return emitOpError(
"must have exactly one argument");
// The single block argument must have a storage type.
301 if (
auto type = getBodyBlock().getArgument(0).getType();
302 !isa<StorageType>(type))
303 return emitOpError(
"argument must be of storage type");
304 for (
const hw::ModulePort &port : getIo().getPorts())
306 return emitOpError(
"inout ports are not supported");
/// Verifier for `LutOp`.
///
/// Walks every operation in the body. For operations that implement
/// `MemoryEffectOpInterface`, the declared effects are collected; if any are
/// present, the operation's location is recorded and the walk is
/// interrupted. An interrupted walk yields the error "no operations with
/// side-effects allowed inside a LUT" with a note at the recorded location.
///
/// NOTE(review): the declaration of `firstSideEffectOpLoc` is not part of
/// this listing.
314 LogicalResult LutOp::verify() {
316 const WalkResult result = getBody().walk([&](Operation *op) {
317 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(op)) {
318 SmallVector<SideEffects::EffectInstance<MemoryEffects::Effect>> effects;
319 memOp.getEffects(effects);
// Any reported effect stops the walk and remembers where it was found.
321 if (!effects.empty()) {
322 firstSideEffectOpLoc = memOp->getLoc();
323 return WalkResult::interrupt();
327 return WalkResult::advance();
330 if (result.wasInterrupted())
331 return emitOpError(
"no operations with side-effects allowed inside a LUT")
332 .attachNote(firstSideEffectOpLoc)
333 <<
"first operation with side-effects here";
/// Verifier for the operands and results of `VectorizeOp`.
///
/// Visible checks, in order: at least one input vector; all input vectors
/// have the same number of operands; within each input vector all lane types
/// are equal; multi-lane inputs use `IntegerType` elements; at least one
/// result; all result types are equal; the number of results equals the size
/// of the first input vector; multi-result ops use `IntegerType` results.
/// Some guard conditions and the final return are not part of this listing.
342 LogicalResult VectorizeOp::verify() {
343 if (getInputs().
empty())
344 return emitOpError(
"there has to be at least one input vector");
// Every input vector must have the same operand count.
346 if (!llvm::all_equal(llvm::map_range(
347 getInputs(), [](OperandRange range) {
return range.size(); })))
348 return emitOpError(
"all input vectors must have the same size");
350 for (OperandRange range : getInputs()) {
351 if (!llvm::all_equal(range.getTypes()))
352 return emitOpError(
"all input vector lane types must match");
// NOTE(review): the guard for the diagnostic below is missing from this
// listing.
355 return emitOpError(
"input vector must have at least one element");
// NOTE(review): the message says "signless", but the visible check is only
// `isa<IntegerType>` -- confirm whether signed/unsigned integers are meant
// to be rejected here.
358 if (getInputs().front().size() > 1 &&
359 !isa<IntegerType>(getInputs().front().front().getType()))
360 return emitOpError(
"input vector element type must be a signless integer");
362 if (getResults().
empty())
363 return emitOpError(
"must have at least one result");
365 if (!llvm::all_equal(getResults().getTypes()))
366 return emitOpError(
"all result types must match");
// One result per lane of the first input vector.
368 if (getResults().size() != getInputs().front().size())
369 return emitOpError(
"number results must match input vector size");
371 if (getResults().size() > 1 &&
372 !isa<IntegerType>(getResults().front().getType()))
374 "may only return a vector type if boundary is already vectorized");
380 if (isa<VectorType>(base))
383 if (
auto vectorTy = dyn_cast<VectorType>(vectorized)) {
384 if (vectorTy.getElementType() != base)
387 return vectorTy.getDimSize(0);
390 if (vectorized.getIntOrFloatBitWidth() < base.getIntOrFloatBitWidth())
393 if (vectorized.getIntOrFloatBitWidth() % base.getIntOrFloatBitWidth() == 0)
394 return vectorized.getIntOrFloatBitWidth() / base.getIntOrFloatBitWidth();
/// Region verifier for `VectorizeOp`.
///
/// Relates the body block's argument types and the terminator's value type
/// to the op's input vectors and result types. Three situations are visible:
/// the terminator type equals the first result type; a width computed
/// against the first result type; and a width computed against the
/// terminator's value type. In the latter two, each block argument gets its
/// own width, which must equal `width`.
///
/// NOTE(review): the calls that produce `width` and `argWidth`, and the
/// conditions selecting between the three situations, are missing from this
/// listing; only their trailing arguments are visible.
399 LogicalResult VectorizeOp::verifyRegions() {
400 auto returnOp = cast<VectorizeReturnOp>(getBody().front().getTerminator());
401 TypeRange bodyArgTypes = getBody().front().getArgumentTypes();
// One block argument per input vector.
403 if (bodyArgTypes.size() != getInputs().size())
405 "number of block arguments must match number of input vectors");
// Terminator type equals the first result type: each block argument type
// must equal the type of the first operand of the matching input vector.
408 if (returnOp.getValue().getType() == getResultTypes().front()) {
409 for (
auto [i, argTy] : llvm::enumerate(bodyArgTypes))
410 if (argTy != getInputs()[i].getTypes().front())
411 return emitOpError(
"if terminator type matches result type the "
412 "argument types must match the input types");
419 getResultTypes().front());
421 for (
auto [i, argTy] : llvm::enumerate(bodyArgTypes)) {
422 Type inputTy = getInputs()[i].getTypes().front();
424 if (failed(argWidth))
425 return emitOpError(
"block argument must be a scalar variant of the "
426 "vectorized operand");
428 if (*argWidth !=
width)
429 return emitOpError(
"input and output vector width must match");
437 returnOp.getValue().getType());
439 for (
auto [i, argTy] : llvm::enumerate(bodyArgTypes)) {
440 Type inputTy = getInputs()[i].getTypes().front();
442 if (failed(argWidth))
444 "block argument must be a vectorized variant of the operand");
446 if (*argWidth !=
width)
447 return emitOpError(
"input and output vector width must match");
// NOTE(review): `argWidth` is compared here without the `*` dereference
// used in the width checks above -- confirm this is intended.
449 if (getInputs()[i].size() > 1 && argWidth != getInputs()[i].size())
451 "when boundary not vectorized the number of vector element "
452 "operands must match the width of the vectorized body");
// This diagnostic is reported on the terminator rather than on this op.
458 return returnOp.emitOpError(
459 "operand type must match parent op's result value or be a vectorized or "
460 "non-vectorized variant of it");
/// Returns true when the first input vector consists of exactly one operand.
463 bool VectorizeOp::isBoundaryVectorized() {
464 return getInputs().front().size() == 1;
/// Body-vectorization query for `VectorizeOp`.
///
/// Looks at the body's terminator and first handles the case where the
/// boundary is vectorized and the terminator's value type equals the first
/// result type.
///
/// NOTE(review): the value returned in that case and the call whose last
/// argument is the terminator's value type are missing from this listing.
466 bool VectorizeOp::isBodyVectorized() {
467 auto returnOp = cast<VectorizeReturnOp>(getBody().front().getTerminator());
468 if (isBoundaryVectorized() &&
469 returnOp.getValue().getType() == getResultTypes().front())
473 returnOp.getValue().getType());
/// Custom assembly printer for `SimInstantiateOp`.
///
/// Emits a space, the model referenced by the body argument's
/// `SimModelInstanceType`, the keyword `as`, the body's first block argument,
/// the op's attribute dictionary (with keyword), and finally the body region.
///
/// NOTE(review): the bare `true` passed to `printRegionArgument` is
/// presumably "omit type", and the bare `false` passed to `printRegion` is
/// presumably "print entry block arguments" -- confirm against
/// `OpAsmPrinter`.
484 void SimInstantiateOp::print(OpAsmPrinter &p) {
485 BlockArgument modelArg = getBody().getArgument(0);
486 auto modelType = cast<SimModelInstanceType>(modelArg.getType());
488 p <<
" " << modelType.getModel() <<
" as ";
489 p.printRegionArgument(modelArg, {},
true);
491 p.printOptionalAttrDictWithKeyword(getOperation()->getAttrs());
495 p.printRegion(getBody(),
false);
/// Custom assembly parser for `SimInstantiateOp`.
///
/// Parses, in order: a symbol name for the model, the keyword `as`, a single
/// region argument, an optional attribute dictionary (with keyword) into
/// `result.attributes`, and a region whose entry argument is the parsed
/// argument. The region is then added to `result`.
///
/// NOTE(review): the failure returns after each parse step and the code that
/// uses `ctxt` (presumably to give `modelArg` its type from `modelName`) are
/// missing from this listing. The two bare `false` values passed to
/// `parseArgument` are presumably "allow type" and "allow attributes" --
/// confirm against `OpAsmParser`.
498 ParseResult SimInstantiateOp::parse(OpAsmParser &parser,
499 OperationState &result) {
500 StringAttr modelName;
501 if (failed(parser.parseSymbolName(modelName)))
504 if (failed(parser.parseKeyword(
"as")))
507 OpAsmParser::Argument modelArg;
508 if (failed(parser.parseArgument(modelArg,
false,
false)))
511 if (failed(parser.parseOptionalAttrDictWithKeyword(result.attributes)))
514 MLIRContext *
ctxt = result.getContext();
// The region is heap-allocated here and ownership is moved into `result`.
518 std::unique_ptr<Region> body = std::make_unique<Region>();
519 if (failed(parser.parseRegion(*body, {modelArg})))
522 result.addRegion(std::move(body));
/// Region verifier for `SimInstantiateOp`.
///
/// Requires the body region to have exactly one argument and that argument's
/// type to be a `SimModelInstanceType`.
///
/// NOTE(review): these diagnostics use `emitError`, whereas the other
/// verifiers in this file use `emitOpError` -- confirm the difference is
/// intentional.
526 LogicalResult SimInstantiateOp::verifyRegions() {
527 Region &body = getBody();
528 if (body.getNumArguments() != 1)
529 return emitError(
"entry block of body region must have the model instance "
530 "as a single argument");
531 if (!llvm::isa<SimModelInstanceType>(body.getArgument(0).getType()))
532 return emitError(
"entry block argument type is not a model instance");
537 SimInstantiateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
539 symbolTable, getOperation(),
540 llvm::cast<SimModelInstanceType>(getBody().getArgument(0).getType())
554 SimSetInputOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
556 symbolTable, getOperation(),
557 llvm::cast<SimModelInstanceType>(getInstance().getType())
563 std::optional<hw::ModulePort> port =
getModulePort(moduleOp, getInput());
565 return emitOpError(
"port not found on model");
569 return emitOpError(
"port is not an input port");
571 if (port->type != getValue().getType())
573 "mismatched types between value and model port, port expects ")
584 SimGetPortOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
586 symbolTable, getOperation(),
587 llvm::cast<SimModelInstanceType>(getInstance().getType())
595 return emitOpError(
"port not found on model");
597 if (port->type != getValue().getType())
599 "mismatched types between value and model port, port expects ")
609 LogicalResult SimStepOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
611 symbolTable, getOperation(),
612 llvm::cast<SimModelInstanceType>(getInstance().getType())
621 #include "circt/Dialect/Arc/ArcInterfaces.cpp.inc"
623 #define GET_OP_CLASSES
624 #include "circt/Dialect/Arc/Arc.cpp.inc"
static bool isSupportedModuleOp(Operation *moduleOp)
static LogicalResult verifyArcSymbolUse(Operation *op, TypeRange inputs, TypeRange results, SymbolTableCollection &symbolTable)
static LogicalResult verifyTypeListEquivalence(Operation *op, TypeRange expectedTypeList, TypeRange actualTypeList, StringRef elementName)
static FailureOr< unsigned > getVectorWidth(Type base, Type vectorized)
static Operation * getSupportedModuleOp(SymbolTableCollection &symbolTable, Operation *pointing, StringAttr symbol)
Fetches the operation pointed to by pointing with name symbol, checking that it is a supported model ...
static std::optional< hw::ModulePort > getModulePort(Operation *moduleOp, StringRef portName)
assert(baseType &&"element must be base type")
static Attribute getAttr(ArrayRef< NamedAttribute > attrs, StringRef name)
Get an attribute by name from a list of named attributes.
static PortInfo getPort(ModuleTy &mod, size_t idx)
static InstancePath empty
llvm::SmallVector< StringAttr > inputs
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn