10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/FunctionImplementation.h"
16 #include "llvm/Support/FormatVariadic.h"
19 using namespace circt;
/// Fills in `state` for a new MachineOp. The statements visible here add the
/// symbol-name attribute, the function-type attribute and an "initialState"
/// attribute, append the caller-supplied `attrs`, and attach one region that
/// receives a `body` block.
/// NOTE(review): several statement continuations are not visible in this
/// listing, so the attribute values and the creation of `body` cannot be
/// confirmed from here.
26 void MachineOp::build(OpBuilder &
builder, OperationState &state, StringRef name,
27 StringRef initialStateName, FunctionType type,
28 ArrayRef<NamedAttribute> attrs,
29 ArrayRef<DictionaryAttr> argAttrs) {
30 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
32 state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
34 state.addAttribute(
"initialState",
36 state.attributes.append(attrs.begin(), attrs.end());
37 Region *region = state.addRegion();
39 region->push_back(body);
42 SmallVector<Location, 4>(type.getNumInputs(),
builder.getUnknownLoc()));
// One argument-attribute dictionary is expected per machine input.
46 assert(type.getNumInputs() == argAttrs.size());
47 function_interface_impl::addArgAndResultAttrs(
49 std::nullopt, MachineOp::getArgAttrsAttrName(state.name),
50 MachineOp::getResAttrsAttrName(state.name));
/// Resolves the symbol named by getInitialState() within this op and returns
/// it as a StateOp. Because of dyn_cast_or_null, the result is a null StateOp
/// when the lookup finds nothing or finds an operation that is not a StateOp.
54 StateOp MachineOp::getInitialStateOp() {
55 return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
// Fragment (signature not visible in this listing): when getArgNames() yields
// a value, entry `i` of it is returned as a StringAttr. The fallback taken when
// no names are present is not visible here.
59 if (
auto args = getArgNames())
60 return cast<StringAttr>((*args)[i]);
// Fragment (signature not visible in this listing): when getResNames() yields
// a value, entry `i` of it is returned as a StringAttr. The fallback taken when
// no names are present is not visible here.
66 if (
auto resNameAttrs = getResNames())
67 return cast<StringAttr>((*resNameAttrs)[i]);
/// Appends hw::PortInfo entries to `ports`: first one per input of the
/// machine's function type, then one per result, each typed from the
/// corresponding function-type entry.
/// NOTE(review): how each port's name and direction are set is not visible in
/// this listing.
73 void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
75 auto machineType = getFunctionType();
// Inputs first.
76 for (
unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
82 port.type = machineType.getInput(i);
84 ports.push_back(port);
// Then results.
87 for (
unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
93 port.type = machineType.getResult(i);
95 ports.push_back(port);
/// Custom parser for MachineOp. Delegates to
/// function_interface_impl::parseFunctionOp, supplying the op's function-type,
/// argument-attribute and result-attribute attribute names plus a callback
/// that builds a FunctionType from the parsed argument and result types.
/// The literal `false` is presumably the "allow variadic" flag — TODO confirm
/// against the parseFunctionOp signature in the MLIR version in use.
99 ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
// Callback: ignores the variadic flag and the error string and simply forms
// the function type. (The line binding it to a name is not visible here.)
101 [&](Builder &
builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
102 function_interface_impl::VariadicFlag,
103 std::string &) {
return builder.getFunctionType(argTypes, results); };
105 return function_interface_impl::parseFunctionOp(
106 parser, result,
false,
107 MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
108 MachineOp::getArgAttrsAttrName(result.name),
109 MachineOp::getResAttrsAttrName(result.name));
/// Custom printer for MachineOp. Delegates to
/// function_interface_impl::printFunctionOp with this op and its
/// function-type, argument-attribute and result-attribute attribute names.
/// The literal `false` is presumably the "is variadic" flag — TODO confirm.
112 void MachineOp::print(OpAsmPrinter &p) {
113 function_interface_impl::printFunctionOp(
114 p, *
this,
false, getFunctionTypeAttrName(),
115 getArgAttrsAttrName(), getResAttrsAttrName());
// Fragment of a type-range comparison helper (signature not visible in this
// listing). It reports an error at `loc` when `rangeA` and `rangeB` differ in
// length, then walks both ranges in lockstep and reports a per-index type
// mismatch.
// NOTE(review): the condition guarding the per-index error and the
// declaration/update of `index` are not visible here.
120 if (rangeA.size() != rangeB.size())
121 return emitError(loc) <<
"mismatch in number of types compared ("
122 << rangeA.size() <<
" != " << rangeB.size() <<
")";
125 for (
auto zip : llvm::zip(rangeA, rangeB)) {
126 auto typeA = std::get<0>(zip);
127 auto typeB = std::get<1>(zip);
129 return emitError(loc) <<
"type mismatch at index " << index <<
" ("
130 << typeA <<
" != " << typeB <<
")";
/// Verifier for MachineOp. The checks visible in this listing are:
///  - entry block argument types are compared against something (the other
///    side of the comparison is not visible; the error text says "machine
///    input types");
///  - the op contains exactly one block;
///  - getInitialStateOp() resolves to a state;
///  - if argument names are present, their count equals the argument count;
///  - if result names are present, their count equals the result count.
137 LogicalResult MachineOp::verify() {
146 front().getArgumentTypes())))
148 "entry block argument types must match the machine input types");
// Exactly one block is allowed in the machine body.
151 if (!llvm::hasSingleElement(*
this))
152 return emitOpError(
"must only have a single block");
// The initial state symbol must resolve to a StateOp.
155 if (!getInitialStateOp())
156 return emitOpError(
"initial state '" + getInitialState() +
157 "' was not defined in the machine");
// Optional argument names must be one-per-argument.
159 if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
160 return emitOpError() <<
"number of machine arguments ("
161 << getArgumentTypes().size()
163 "not match the provided number "
164 "of argument names ("
165 << getArgNames()->size() <<
")";
// Optional result names must be one-per-result.
167 if (getResNames() && getResNames()->size() != getResultTypes().size())
168 return emitOpError() <<
"number of machine results ("
169 << getResultTypes().size()
171 "not match the provided number "
173 << getResNames()->size() <<
")";
// Fragment (signature not visible in this listing) that builds a list of
// hw::PortInfo from the op's function type. For inputs, an hw::InOutType is
// unwrapped to its element type and tracked via `isInOut`; ports take their
// name from the stored name array when present, otherwise a generated name
// "input<i>" / "output<i>".
// NOTE(review): the remaining PortInfo fields (direction, index, etc.) and the
// return statement are not visible here.
179 SmallVector<hw::PortInfo> ports;
180 auto argNames = getArgNames();
181 auto argTypes = getFunctionType().getInputs();
182 for (
unsigned i = 0, e = argTypes.size(); i < e; ++i) {
183 bool isInOut =
false;
184 auto type = argTypes[i];
// Strip the inout wrapper so the port carries the element type.
186 if (
auto inout = dyn_cast<hw::InOutType>(type)) {
188 type = inout.getElementType();
195 {{argNames ? cast<StringAttr>((*argNames)[i])
196 : StringAttr::
get(getContext(), Twine(
"input") + Twine(i)),
203 auto resultNames = getResNames();
204 auto resultTypes = getFunctionType().getResults();
205 for (
unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
206 ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
207 : StringAttr::
get(getContext(),
208 Twine(
"output") + Twine(i)),
// Fragment (signature not visible in this listing): looks up the MachineOp
// whose symbol is getMachine() in the nearest enclosing ModuleOp.
// NOTE(review): no null check on `module` is visible; this presumably relies
// on the op always being nested in a ModuleOp — confirm.
223 auto module = (*this)->getParentOfType<ModuleOp>();
224 return module.lookupSymbol<
MachineOp>(getMachine());
/// Verifier for InstanceOp. Resolves the referenced machine via
/// getMachineOp() and, on the path visible here, reports an error naming the
/// machine symbol that could not be found.
/// NOTE(review): the condition guarding the error and any further checks are
/// not visible in this listing.
227 LogicalResult InstanceOp::verify() {
228 auto m = getMachineOp();
230 return emitError(
"cannot find machine definition '") << getMachine() <<
"'";
// Fragment (function name line not visible in this listing): suggests
// getName() as the printed SSA name for the value returned by getInstance().
236 function_ref<
void(Value, StringRef)> setNameFn) {
237 setNameFn(getInstance(),
getName());
// Fragment of a templated caller-type check (the function signature line is
// not visible in this listing). For an op of type OpType it:
//  - resolves the machine via op.getMachineOp() and reports "cannot find
//    machine definition" on the visible error path;
//  - compares the machine's argument types with the types of op.getInputs();
//  - compares the machine's result types with the types of op.getOutputs().
// Each mismatch emits an op error and attaches a note at the machine's
// location.
244 template <
typename OpType>
246 auto machine = op.getMachineOp();
248 return op.emitError(
"cannot find machine definition");
// Operands vs. machine inputs.
251 if (failed(
compareTypes(op.getLoc(), machine.getArgumentTypes(),
252 op.getInputs().getTypes()))) {
254 op.emitOpError(
"operand types must match the machine input types");
255 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
// Results vs. machine outputs.
260 if (failed(
compareTypes(op.getLoc(), machine.getResultTypes(),
261 op.getOutputs().getTypes()))) {
263 op.emitOpError(
"result types must match the machine output types");
264 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
// Fragment (signature not visible in this listing): resolves the machine
// through the InstanceOp that defines the value returned by getInstance().
// NOTE(review): statements between these two lines are not visible, so any
// handling of a null `instanceOp` cannot be confirmed from here.
273 auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
277 return instanceOp.getMachineOp();
// Fragment (signature not visible in this listing): looks up the MachineOp
// whose symbol is getMachine() in the nearest enclosing ModuleOp.
// NOTE(review): no null check on `module` is visible; this presumably relies
// on the op always being nested in a ModuleOp — confirm.
290 auto module = (*this)->getParentOfType<ModuleOp>();
291 return module.lookupSymbol<
MachineOp>(getMachine());
// Fragment (signature not visible in this listing): forwards to the port list
// of the machine returned by getMachineOp().
297 return getMachineOp().getPortList();
/// Returns the referenced machine's symbol name (getMachine()) as this
/// instance's module name.
301 StringRef HWInstanceOp::getModuleName() {
return getMachine(); }
/// Returns the machine symbol reference attribute (getMachineAttr()) as this
/// instance's module-name attribute.
302 FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() {
return getMachineAttr(); }
/// Returns the op's name attribute (getNameAttr()) as the instance name.
304 mlir::StringAttr HWInstanceOp::getInstanceNameAttr() {
return getNameAttr(); }
/// Fills in `state` for a new StateOp named `stateName`. Two regions are
/// added in order — an output region and a transitions region — each given a
/// fresh empty Block, and the "sym_name" attribute is set from `stateName`.
/// The builder's insertion point is restored on exit by the InsertionGuard.
/// NOTE(review): the insertion point is moved to the end of the output block,
/// but whatever is created there is not visible in this listing.
312 void StateOp::build(OpBuilder &
builder, OperationState &state,
313 StringRef stateName) {
314 OpBuilder::InsertionGuard guard(
builder);
315 Region *output = state.addRegion();
316 output->push_back(
new Block());
317 builder.setInsertionPointToEnd(&output->back());
319 Region *transitions = state.addRegion();
320 transitions->push_back(
new Block());
321 state.addAttribute(
"sym_name",
builder.getStringAttr(stateName));
/// Collects the state targeted by each TransitionOp in this state's
/// transitions region (via TransitionOp::getNextStateOp) and returns them as a
/// SetVector, so repeated targets appear once.
/// NOTE(review): a transition whose target fails to resolve would presumably
/// contribute a null StateOp — confirm whether callers expect that.
324 SetVector<StateOp> StateOp::getNextStates() {
325 SmallVector<StateOp> nextStates;
327 getTransitions().getOps<TransitionOp>(),
328 std::inserter(nextStates, nextStates.begin()),
329 [](
TransitionOp transition) { return transition.getNextStateOp(); });
330 return SetVector<StateOp>(nextStates.begin(), nextStates.end());
/// Canonicalization for StateOp. Walks the TransitionOps of `op`, tracking
/// whether an always-taken transition has been seen and collecting transitions
/// into `transitionsToErase`, then erases the collected ops through
/// `rewriter`. Returns failure when nothing was erased, success otherwise.
/// NOTE(review): the line joining the `if` and the push_back is not visible;
/// presumably only transitions after the first always-taken one are collected
/// — confirm.
333 LogicalResult StateOp::canonicalize(
StateOp op, PatternRewriter &rewriter) {
334 bool hasAlwaysTakenTransition =
false;
335 SmallVector<TransitionOp, 4> transitionsToErase;
337 for (
auto transition : op.getTransitions().getOps<
TransitionOp>()) {
338 if (!hasAlwaysTakenTransition)
339 hasAlwaysTakenTransition = transition.isAlwaysTaken();
341 transitionsToErase.push_back(transition);
// Erase after the walk so the range being iterated is not mutated.
344 for (
auto transition : transitionsToErase)
345 rewriter.eraseOp(transition);
// failure(true) == failure: report "no change" when the list is empty.
347 return failure(transitionsToErase.empty());
/// Verifier for StateOp. Visible checks:
///  - if `parent` has results, the output region must not be empty;
///  - if the output region is non-empty, its first block must be non-empty and
///    end with an fsm::OutputOp.
/// NOTE(review): the definition of `parent` is not visible in this listing
/// (presumably the enclosing MachineOp).
350 LogicalResult StateOp::verify() {
353 if (parent.getNumResults() != 0 && (getOutput().empty()))
354 return emitOpError(
"state must have a non-empty output region when the "
355 "machine has results.");
357 if (!getOutput().
empty()) {
359 Block *outputBlock = &getOutput().front();
360 if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
361 return emitOpError(
"output block must have a single OutputOp terminator");
// Fragment (signature not visible in this listing): if the output region has
// no blocks, a new Block is appended and `builder` is pointed at its start
// under an InsertionGuard; the region's first block is then returned.
// NOTE(review): what (if anything) is created in the new block is not visible
// here.
368 if (getOutput().
empty()) {
369 OpBuilder::InsertionGuard g(
builder);
370 auto *block =
new Block();
371 getOutput().push_back(block);
372 builder.setInsertionPointToStart(block);
375 return &getOutput().front();
/// Verifier for OutputOp. When the op sits directly in the transitions region
/// of its enclosing StateOp it must have no operands; otherwise its operand
/// types are compared against the result types of the enclosing MachineOp.
382 LogicalResult OutputOp::verify() {
383 if ((*this)->getParentRegion() ==
384 &(*this)->getParentOfType<
StateOp>().getTransitions()) {
// NOTE(review): the diagnostic below is emitted but its result is not
// returned on this line. Unless a following (not visible) statement returns
// failure, verification would report success despite the error — confirm and
// consider `return emitOpError(...)`.
385 if (getNumOperands() != 0)
386 emitOpError(
"transitions region must not output any value");
// Operand list must line up with the machine's result list.
392 auto machine = (*this)->getParentOfType<
MachineOp>();
394 compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
395 return emitOpError(
"operand types must match the machine output types");
/// Fills in `state` for a new TransitionOp targeting the state named
/// `nextState`; the visible statement adds the "nextState" attribute.
/// NOTE(review): the attribute value expression and any region setup are not
/// visible in this listing.
404 void TransitionOp::build(OpBuilder &
builder, OperationState &state,
405 StringRef nextState) {
408 state.addAttribute(
"nextState",
/// Convenience overload: forwards to the name-based build() using
/// nextState.getName(). The declaration of the `nextState` parameter is not
/// visible in this listing.
412 void TransitionOp::build(OpBuilder &
builder, OperationState &state,
414 build(
builder, state, nextState.getName());
// Fragment (signature not visible in this listing): if the guard region has
// no blocks, a new Block is appended and an operand-less fsm::ReturnOp is
// created at its start (insertion point restored by the guard). The guard
// region's first block is then returned.
418 if (getGuard().
empty()) {
419 OpBuilder::InsertionGuard g(
builder);
420 auto *block =
new Block();
421 getGuard().push_back(block);
422 builder.setInsertionPointToStart(block);
423 builder.create<fsm::ReturnOp>(getLoc());
425 return &getGuard().front();
// Fragment (signature not visible in this listing): appends an empty Block to
// the action region if it has none, then returns the region's first block.
429 if (getAction().
empty())
430 getAction().push_back(
new Block());
431 return &getAction().front();
/// Resolves this transition's target: looks up the symbol getNextState() as a
/// StateOp inside the enclosing MachineOp.
/// NOTE(review): statements between obtaining `machineOp` and the lookup are
/// not visible in this listing (possibly a null check — confirm).
435 StateOp TransitionOp::getNextStateOp() {
436 auto machineOp = (*this)->getParentOfType<
MachineOp>();
440 return machineOp.lookupSymbol<
StateOp>(getNextState());
/// Reports whether this transition's guard is statically known. Visible
/// logic: the guard's return op is inspected; a guard return with no operands
/// is handled specially (result not visible), and when the returned value is
/// defined by an arith::ConstantOp, that constant's boolean value is returned.
/// NOTE(review): early-outs before `guardReturn` and the final fallback
/// return are not visible in this listing.
443 bool TransitionOp::isAlwaysTaken() {
447 auto guardReturn = getGuardReturn();
448 if (guardReturn.getNumOperands() == 0)
451 if (
auto constantOp =
452 guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
453 return cast<BoolAttr>(constantOp.getValue()).getValue();
/// Canonicalization for TransitionOp driven by a constant guard. When the
/// guard return has exactly one operand defined by an arith::ConstantOp:
///  - if the constant is true, the guard return is replaced by an operand-less
///    fsm::ReturnOp at the same location;
///  - a path that erases `op` itself is also visible, presumably for a
///    constant-false guard — the branch structure is not fully visible here.
458 LogicalResult TransitionOp::canonicalize(
TransitionOp op,
459 PatternRewriter &rewriter) {
461 auto guardReturn = op.getGuardReturn();
462 if (guardReturn.getNumOperands() == 1)
463 if (
auto constantOp = guardReturn.getOperand()
464 .getDefiningOp<mlir::arith::ConstantOp>()) {
466 if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
// Insert the replacement before erasing the old terminator.
469 rewriter.setInsertionPoint(guardReturn);
470 rewriter.
create<fsm::ReturnOp>(guardReturn.getLoc());
471 rewriter.eraseOp(guardReturn);
475 rewriter.eraseOp(op);
/// Verifier for TransitionOp. Visible checks:
///  - the next state named by getNextState() must resolve;
///  - the guard region's first block must be non-empty and end in an
///    fsm::ReturnOp;
///  - the op must sit directly in the transitions region of
///    getCurrentState().
/// NOTE(review): any condition wrapping the guard check (e.g. "guard region
/// present") is not visible in this listing.
484 LogicalResult TransitionOp::verify() {
485 if (!getNextStateOp())
486 return emitOpError(
"cannot find the definition of the next state `")
487 << getNextState() <<
"`";
491 if (getGuard().front().
empty() ||
492 !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
493 return emitOpError(
"guard region must terminate with a ReturnOp");
497 if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
498 return emitOpError(
"must only be located in the transitions region");
// Fragment (function name line not visible in this listing): suggests
// getName() as the printed SSA name for the op's result.
508 function_ref<
void(Value, StringRef)> setNameFn) {
509 setNameFn(getResult(),
getName());
/// Makes `value` operand 0 of this op, either by overwriting the existing
/// operand 0 or by inserting a new operand at index 0.
/// NOTE(review): the condition selecting between the two statements is not
/// visible in this listing (presumably whether an operand already exists).
516 void ReturnOp::setOperand(Value value) {
518 getOperation()->setOperand(0, value);
520 getOperation()->insertOperands(0, {value});
/// Returns the VariableOp that defines the value from getVariable(); the
/// result is a null VariableOp when that value is not defined by one.
528 VariableOp UpdateOp::getVariableOp() {
529 return getVariable().getDefiningOp<VariableOp>();
/// Verifier for UpdateOp. Visible checks:
///  - the destination must be a variable operation (condition not visible);
///  - the op must be nested inside the action region of its enclosing
///    TransitionOp;
///  - no other UpdateOp directly in that action region may target the same
///    variable.
/// NOTE(review): the enclosing TransitionOp is dereferenced without a visible
/// null check; this presumably relies on a parent constraint — confirm.
532 LogicalResult UpdateOp::verify() {
534 return emitOpError(
"destination is not a variable operation");
536 if (!(*this)->getParentOfType<
TransitionOp>().getAction().isAncestor(
537 (*this)->getParentRegion()))
538 return emitOpError(
"must only be located in the action region");
540 auto transition = (*this)->getParentOfType<
TransitionOp>();
// Compare against every other update in the same action region.
541 for (
auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
542 if (otherUpdateOp == *
this)
544 if (otherUpdateOp.getVariable() == getVariable())
545 return otherUpdateOp.emitOpError(
546 "multiple updates to the same variable within a single action region "
558 #define GET_OP_CLASSES
559 #include "circt/Dialect/FSM/FSM.cpp.inc"
560 #undef GET_OP_CLASSES
562 #include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static InstancePath empty
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
The direction of a Component or Cell port.
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.