10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/FunctionImplementation.h"
16 #include "llvm/Support/FormatVariadic.h"
19 using namespace circt;
/// Populate `state` for a new MachineOp: the symbol name, the function-type
/// attribute, the "initialState" attribute, any extra `attrs`, one body
/// region, and the argument attribute dictionaries from `argAttrs`.
// NOTE(review): this listing omits several original lines (the embedded
// numbering skips 33, 35, 38, 40-41 and 43-45), so the attribute values and
// the creation of `body` are not visible here.
26 void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27 StringRef initialStateName, FunctionType type,
28 ArrayRef<NamedAttribute> attrs,
29 ArrayRef<DictionaryAttr> argAttrs) {
// Record the machine's symbol name under the standard symbol attribute.
30 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31 builder.getStringAttr(name));
32 state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
34 state.addAttribute(
"initialState",
// Caller-supplied attributes are appended after the built-in ones.
36 state.attributes.append(attrs.begin(), attrs.end());
37 Region *region = state.addRegion();
39 region->push_back(body);
// One unknown location per machine input.
42 SmallVector<Location, 4>(type.getNumInputs(), builder.getUnknownLoc()));
// There must be exactly one attribute dictionary per machine input.
46 assert(type.getNumInputs() == argAttrs.size());
// Only argument attributes are supplied; the result attributes are nullopt.
47 function_interface_impl::addArgAndResultAttrs(
48 builder, state, argAttrs,
49 std::nullopt, MachineOp::getArgAttrsAttrName(state.name),
50 MachineOp::getResAttrsAttrName(state.name));
54 StateOp MachineOp::getInitialStateOp() {
55 return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
59 if (
auto args = getArgNames())
60 return cast<StringAttr>((*args)[i]);
66 if (
auto resNameAttrs = getResNames())
67 return cast<StringAttr>((*resNameAttrs)[i]);
/// Append port descriptions for this machine to `ports`: one entry per
/// function-type input, followed by one entry per function-type result.
// NOTE(review): the lines that declare `port` and fill in its other fields
// are missing from this listing (the embedded numbering skips them); only
// the `type` assignment and the push_back are visible.
73 void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
75 auto machineType = getFunctionType();
// Inputs first.
76 for (
unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
82 port.type = machineType.getInput(i);
84 ports.push_back(port);
// Then results.
87 for (
unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
93 port.type = machineType.getResult(i);
95 ports.push_back(port);
/// Parse a MachineOp by delegating to the shared function-op parser.
// NOTE(review): the line that declares `buildFuncType` is missing from this
// listing (the embedded numbering skips 100); the lambda below is its
// initializer.
99 ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
// Type-builder callback: forms a function type from the parsed argument and
// result types. The variadic flag and the string parameter are unnamed and
// unused.
101 [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
102 function_interface_impl::VariadicFlag,
103 std::string &) {
return builder.getFunctionType(argTypes, results); };
105 return function_interface_impl::parseFunctionOp(
106 parser, result,
// presumably the allowVariadic flag -- confirm against parseFunctionOp
false,
107 MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
108 MachineOp::getArgAttrsAttrName(result.name),
109 MachineOp::getResAttrsAttrName(result.name));
112 void MachineOp::print(OpAsmPrinter &p) {
113 function_interface_impl::printFunctionOp(
114 p, *
this,
false, getFunctionTypeAttrName(),
115 getArgAttrsAttrName(), getResAttrsAttrName());
120 if (rangeA.size() != rangeB.size())
121 return emitError(loc) <<
"mismatch in number of types compared ("
122 << rangeA.size() <<
" != " << rangeB.size() <<
")";
125 for (
auto zip : llvm::zip(rangeA, rangeB)) {
126 auto typeA = std::get<0>(zip);
127 auto typeB = std::get<1>(zip);
129 return emitError(loc) <<
"type mismatch at index " << index <<
" ("
130 << typeA <<
" != " << typeB <<
")";
146 front().getArgumentTypes())))
148 "entry block argument types must match the machine input types");
151 if (!llvm::hasSingleElement(*
this))
152 return emitOpError(
"must only have a single block");
155 if (!getInitialStateOp())
156 return emitOpError(
"initial state '" + getInitialState() +
157 "' was not defined in the machine");
159 if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
160 return emitOpError() <<
"number of machine arguments ("
161 << getArgumentTypes().size()
163 "not match the provided number "
164 "of argument names ("
165 << getArgNames()->size() <<
")";
167 if (getResNames() && getResNames()->size() != getResultTypes().size())
168 return emitOpError() <<
"number of machine results ("
169 << getResultTypes().size()
171 "not match the provided number "
173 << getResNames()->size() <<
")";
179 SmallVector<hw::PortInfo> ports;
180 auto argNames = getArgNames();
181 auto argTypes = getFunctionType().getInputs();
182 for (
unsigned i = 0, e = argTypes.size(); i < e; ++i) {
183 bool isInOut =
false;
184 auto type = argTypes[i];
186 if (
auto inout = dyn_cast<hw::InOutType>(type)) {
188 type = inout.getElementType();
195 {{argNames ? cast<StringAttr>((*argNames)[i])
196 : StringAttr::
get(getContext(), Twine(
"input") + Twine(i)),
203 auto resultNames = getResNames();
204 auto resultTypes = getFunctionType().getResults();
205 for (
unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
206 ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
207 : StringAttr::
get(getContext(),
208 Twine(
"output") + Twine(i)),
223 auto module = (*this)->getParentOfType<ModuleOp>();
224 return module.lookupSymbol<
MachineOp>(getMachine());
228 auto m = getMachineOp();
230 return emitError(
"cannot find machine definition '") << getMachine() <<
"'";
236 function_ref<
void(Value, StringRef)> setNameFn) {
237 setNameFn(getInstance(),
getName());
244 template <
typename OpType>
246 auto machine = op.getMachineOp();
248 return op.emitError(
"cannot find machine definition");
251 if (failed(
compareTypes(op.getLoc(), machine.getArgumentTypes(),
252 op.getInputs().getTypes()))) {
254 op.emitOpError(
"operand types must match the machine input types");
255 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
260 if (failed(
compareTypes(op.getLoc(), machine.getResultTypes(),
261 op.getOutputs().getTypes()))) {
263 op.emitOpError(
"result types must match the machine output types");
264 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
273 auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
277 return instanceOp.getMachineOp();
290 auto module = (*this)->getParentOfType<ModuleOp>();
291 return module.lookupSymbol<
MachineOp>(getMachine());
297 return getMachineOp().getPortList();
301 StringRef HWInstanceOp::getModuleName() {
return getMachine(); }
302 FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() {
return getMachineAttr(); }
304 mlir::StringAttr HWInstanceOp::getInstanceNameAttr() {
return getNameAttr(); }
/// Populate `state` for a StateOp named `stateName`: an output region and a
/// transitions region, each holding one fresh block, plus the "sym_name"
/// attribute.
// NOTE(review): original line 318 is missing from this listing; whatever is
// created at the insertion point set on line 317 is not visible here.
312 void StateOp::build(OpBuilder &builder, OperationState &state,
313 StringRef stateName) {
// Restores the caller's insertion point when this function returns.
314 OpBuilder::InsertionGuard guard(builder);
// The first region added is the output region.
315 Region *output = state.addRegion();
316 output->push_back(
new Block());
317 builder.setInsertionPointToEnd(&output->back());
// The second region added is the transitions region.
319 Region *transitions = state.addRegion();
320 transitions->push_back(
new Block());
321 state.addAttribute(
"sym_name", builder.getStringAttr(stateName));
/// Overload of StateOp::build that also takes `outputs`. It populates `state`
/// with an output region and a transitions region, each holding one fresh
/// block, plus the "sym_name" attribute.
// NOTE(review): original line 330 is missing from this listing, so the use of
// `outputs` is not visible here -- presumably it feeds the op created at the
// insertion point set on line 329; confirm against the full file.
324 void StateOp::build(OpBuilder &builder, OperationState &state,
325 StringRef stateName, ValueRange outputs) {
// Restores the caller's insertion point when this function returns.
326 OpBuilder::InsertionGuard guard(builder);
// The first region added is the output region.
327 Region *output = state.addRegion();
328 output->push_back(
new Block());
329 builder.setInsertionPointToEnd(&output->back());
// The second region added is the transitions region.
331 Region *transitions = state.addRegion();
332 transitions->push_back(
new Block());
333 state.addAttribute(
"sym_name", builder.getStringAttr(stateName));
/// Return the states targeted by the TransitionOps in this state's
/// transitions region, as a SetVector built from the collected list.
// NOTE(review): the line naming the algorithm that consumes the three
// arguments below is missing from this listing (original line 338) --
// presumably llvm::transform; confirm against the full file.
336 SetVector<StateOp> StateOp::getNextStates() {
337 SmallVector<StateOp> nextStates;
// Source range: every TransitionOp in the transitions region.
339 getTransitions().getOps<TransitionOp>(),
340 std::inserter(nextStates, nextStates.begin()),
// Each transition is mapped to its next-state op.
341 [](
TransitionOp transition) { return transition.getNextStateOp(); });
342 return SetVector<StateOp>(nextStates.begin(), nextStates.end());
346 bool hasAlwaysTakenTransition =
false;
347 SmallVector<TransitionOp, 4> transitionsToErase;
349 for (
auto transition : op.getTransitions().getOps<
TransitionOp>()) {
350 if (!hasAlwaysTakenTransition)
351 hasAlwaysTakenTransition = transition.isAlwaysTaken();
353 transitionsToErase.push_back(transition);
356 for (
auto transition : transitionsToErase)
357 rewriter.eraseOp(transition);
359 return failure(transitionsToErase.empty());
365 if (parent.getNumResults() != 0 && (getOutput().empty()))
366 return emitOpError(
"state must have a non-empty output region when the "
367 "machine has results.");
369 if (!getOutput().
empty()) {
371 Block *outputBlock = &getOutput().front();
372 if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
373 return emitOpError(
"output block must have a single OutputOp terminator");
/// Return the first block of the output region, adding a new block first if
/// the region is empty.
// NOTE(review): original lines 385-386 are missing from this listing; what is
// built inside the new block after the insertion point is set is not visible
// here.
379 Block *StateOp::ensureOutput(OpBuilder &builder) {
380 if (getOutput().
empty()) {
// Restores the caller's insertion point when the guard goes out of scope.
381 OpBuilder::InsertionGuard g(builder);
382 auto *block =
new Block();
383 getOutput().push_back(block);
384 builder.setInsertionPointToStart(block);
387 return &getOutput().front();
395 if ((*this)->getParentRegion() ==
396 &(*this)->getParentOfType<
StateOp>().getTransitions()) {
397 if (getNumOperands() != 0)
398 emitOpError(
"transitions region must not output any value");
404 auto machine = (*this)->getParentOfType<
MachineOp>();
406 compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
407 return emitOpError(
"operand types must match the machine output types");
416 void TransitionOp::build(OpBuilder &builder, OperationState &state,
418 build(builder, state, nextState.getName());
421 void TransitionOp::build(OpBuilder &builder, OperationState &state,
423 llvm::function_ref<
void()> guardCtor,
424 llvm::function_ref<
void()> actionCtor) {
425 state.addAttribute(
"nextState",
427 OpBuilder::InsertionGuard guard(builder);
429 Region *guardRegion = state.addRegion();
431 builder.createBlock(guardRegion);
435 Region *actionRegion = state.addRegion();
437 builder.createBlock(actionRegion);
442 Block *TransitionOp::ensureGuard(OpBuilder &builder) {
443 if (getGuard().
empty()) {
444 OpBuilder::InsertionGuard g(builder);
445 auto *block =
new Block();
446 getGuard().push_back(block);
447 builder.setInsertionPointToStart(block);
448 builder.create<fsm::ReturnOp>(getLoc());
450 return &getGuard().front();
453 Block *TransitionOp::ensureAction(OpBuilder &builder) {
454 if (getAction().
empty())
455 getAction().push_back(
new Block());
456 return &getAction().front();
/// Resolve the symbol returned by getNextState() to a StateOp, looking it up
/// in the enclosing MachineOp.
// NOTE(review): original lines 462-464 are missing from this listing --
// presumably a null check on `machineOp`; confirm against the full file.
460 StateOp TransitionOp::getNextStateOp() {
// The lookup below runs on the nearest enclosing machine.
461 auto machineOp = (*this)->getParentOfType<
MachineOp>();
465 return machineOp.lookupSymbol<
StateOp>(getNextState());
/// Report whether this transition is always taken, based on its guard's
/// return operation.
// NOTE(review): several original lines are missing from this listing (the
// embedded numbering skips 469-471, 474-475 and 479 onward), so the values
// returned on the other paths are not visible here.
468 bool TransitionOp::isAlwaysTaken() {
472 auto guardReturn = getGuardReturn();
// Case: the guard's return carries no operand.
473 if (guardReturn.getNumOperands() == 0)
// Case: the returned operand is defined by an arith constant; the result is
// that constant's boolean value.
476 if (
auto constantOp =
477 guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
478 return cast<BoolAttr>(constantOp.getValue()).getValue();
484 PatternRewriter &rewriter) {
486 auto guardReturn = op.getGuardReturn();
487 if (guardReturn.getNumOperands() == 1)
488 if (
auto constantOp = guardReturn.getOperand()
489 .getDefiningOp<mlir::arith::ConstantOp>()) {
491 if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
494 rewriter.setInsertionPoint(guardReturn);
495 rewriter.
create<fsm::ReturnOp>(guardReturn.getLoc());
496 rewriter.eraseOp(guardReturn);
500 rewriter.eraseOp(op);
510 if (!getNextStateOp())
511 return emitOpError(
"cannot find the definition of the next state `")
512 << getNextState() <<
"`";
516 if (getGuard().front().
empty() ||
517 !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
518 return emitOpError(
"guard region must terminate with a ReturnOp");
522 if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
523 return emitOpError(
"must only be located in the transitions region");
533 function_ref<
void(Value, StringRef)> setNameFn) {
534 setNameFn(getResult(),
getName());
/// Make `value` operand 0 of this return, either by overwriting the existing
/// operand 0 or by inserting a new operand at index 0.
// NOTE(review): the lines that choose between the two statements below are
// missing from this listing (original lines 542 and 544) -- presumably an
// if/else on whether an operand already exists; confirm against the full
// file.
541 void ReturnOp::setOperand(Value value) {
// Overwrite the existing operand 0.
543 getOperation()->setOperand(0, value);
// Insert `value` as a new operand at index 0.
545 getOperation()->insertOperands(0, {value});
553 VariableOp UpdateOp::getVariableOp() {
554 return getVariable().getDefiningOp<VariableOp>();
559 return emitOpError(
"destination is not a variable operation");
561 if (!(*this)->getParentOfType<
TransitionOp>().getAction().isAncestor(
562 (*this)->getParentRegion()))
563 return emitOpError(
"must only be located in the action region");
565 auto transition = (*this)->getParentOfType<
TransitionOp>();
566 for (
auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
567 if (otherUpdateOp == *
this)
569 if (otherUpdateOp.getVariable() == getVariable())
570 return otherUpdateOp.emitOpError(
571 "multiple updates to the same variable within a single action region "
583 #define GET_OP_CLASSES
584 #include "circt/Dialect/FSM/FSM.cpp.inc"
585 #undef GET_OP_CLASSES
587 #include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
static InstancePath empty
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
The direction of a Component or Cell port.
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.