10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/FunctionImplementation.h"
16#include "llvm/Support/FormatVariadic.h"
/// Populates `state` for a new MachineOp. The visible statements record the
/// symbol name, start the function-type attribute, add the "initialState"
/// attribute, append the caller-supplied `attrs`, and create the op's region,
/// into which `body` is pushed. The visible tail asserts that there is one
/// `argAttrs` entry per input of `type` and forwards `argAttrs` to
/// function_interface_impl::addArgAndResultAttrs.
/// NOTE(review): several lines of this function are missing from this excerpt
/// (for example the definition of `body`), so this summary is limited to the
/// calls that can be seen here.
26void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27 StringRef initialStateName, FunctionType type,
28 ArrayRef<NamedAttribute> attrs,
29 ArrayRef<DictionaryAttr> argAttrs) {
30 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31 builder.getStringAttr(name));
32 state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
34 state.addAttribute(
"initialState",
35 StringAttr::get(state.getContext(), initialStateName));
36 state.attributes.append(attrs.begin(), attrs.end());
37 Region *region = state.addRegion();
39 region->push_back(
body);
42 SmallVector<Location, 4>(
type.getNumInputs(), builder.getUnknownLoc()));
46 assert(
type.getNumInputs() == argAttrs.size());
47 function_interface_impl::addArgAndResultAttrs(
48 builder, state, argAttrs,
49 std::nullopt, MachineOp::getArgAttrsAttrName(state.name),
50 MachineOp::getResAttrsAttrName(state.name));
54StateOp MachineOp::getInitialStateOp() {
55 return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
58StringAttr MachineOp::getArgName(
size_t i) {
59 if (
auto args = getArgNames())
60 return cast<StringAttr>((*args)[i]);
62 return StringAttr::get(getContext(),
"in" + std::to_string(i));
65StringAttr MachineOp::getResName(
size_t i) {
66 if (
auto resNameAttrs = getResNames())
67 return cast<StringAttr>((*resNameAttrs)[i]);
69 return StringAttr::get(getContext(),
"out" + std::to_string(i));
/// Appends port descriptions for the machine to `ports`: one entry per input
/// of the function type (named "in<i>") followed by one entry per result
/// (named "out<i>"), each typed from the corresponding function-type entry.
/// NOTE(review): the declaration of `port` and any other fields set on it
/// (such as a direction) are not visible in this excerpt.
73void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
75 auto machineType = getFunctionType();
76 for (
unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
80 port.
name = StringAttr::get(getContext(),
"in" + std::to_string(i));
82 port.
type = machineType.getInput(i);
84 ports.push_back(port);
87 for (
unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
91 port.
name = StringAttr::get(getContext(),
"out" + std::to_string(i));
93 port.
type = machineType.getResult(i);
95 ports.push_back(port);
/// Parses a MachineOp by delegating to
/// function_interface_impl::parseFunctionOp. The callback builds the
/// FunctionType from the parsed argument and result types and ignores the
/// variadic flag and the error-string parameter.
/// NOTE(review): the line that binds the lambda to `buildFuncType` is not
/// visible in this excerpt.
99ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
101 [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
102 function_interface_impl::VariadicFlag,
103 std::string &) {
return builder.getFunctionType(argTypes, results); };
105 return function_interface_impl::parseFunctionOp(
106 parser, result,
false,
107 MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
108 MachineOp::getArgAttrsAttrName(result.name),
109 MachineOp::getResAttrsAttrName(result.name));
112void MachineOp::print(OpAsmPrinter &p) {
113 function_interface_impl::printFunctionOp(
114 p, *
this,
false, getFunctionTypeAttrName(),
115 getArgAttrsAttrName(), getResAttrsAttrName());
// Body of a type-range comparison helper; its signature is not visible in
// this excerpt (presumably compareTypes(loc, rangeA, rangeB) — TODO confirm).
// First reports an error at `loc` when the two ranges differ in length.
120 if (rangeA.size() != rangeB.size())
121 return emitError(loc) <<
"mismatch in number of types compared ("
122 << rangeA.size() <<
" != " << rangeB.size() <<
")";
// Walks both ranges in lockstep and reports the mismatching pair.
// NOTE(review): the condition guarding this error and the maintenance of
// `index` are not visible here.
125 for (
auto zip : llvm::zip(rangeA, rangeB)) {
126 auto typeA = std::get<0>(zip);
127 auto typeB = std::get<1>(zip);
129 return emitError(loc) <<
"type mismatch at index " << index <<
" ("
130 << typeA <<
" != " << typeB <<
")";
/// Verifies structural invariants of a machine. The visible checks are:
///  - the entry block argument types match the machine input types,
///  - the machine body consists of a single block,
///  - the initial state resolves to a state defined in the machine,
///  - when argument names are present, their count equals the number of
///    argument types,
///  - when result names are present, their count equals the number of
///    result types.
/// NOTE(review): parts of several statements are missing from this excerpt,
/// so the exact form of the first check and of the count-mismatch messages
/// cannot be confirmed here.
137LogicalResult MachineOp::verify() {
146 front().getArgumentTypes())))
148 "entry block argument types must match the machine input types");
151 if (!llvm::hasSingleElement(*
this))
152 return emitOpError(
"must only have a single block");
155 if (!getInitialStateOp())
156 return emitOpError(
"initial state '" + getInitialState() +
157 "' was not defined in the machine");
159 if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
160 return emitOpError() <<
"number of machine arguments ("
161 << getArgumentTypes().size()
163 "not match the provided number "
164 "of argument names ("
165 << getArgNames()->size() <<
")";
167 if (getResNames() && getResNames()->size() != getResultTypes().size())
168 return emitOpError() <<
"number of machine results ("
169 << getResultTypes().size()
171 "not match the provided number "
173 << getResNames()->size() <<
")";
/// Builds the list of ports for this machine from its function type.
/// For each input, an hw::InOutType is unwrapped to its element type, and the
/// direction is InOut when `isInOut` is set and Input otherwise. Input ports
/// use the explicit argument name when available, else "input<i>"; result
/// ports use the explicit result name when available, else "output<i>".
/// NOTE(review): the statements that set `isInOut`, complete each PortInfo
/// initializer, and return `ports` are not visible in this excerpt.
178SmallVector<::circt::hw::PortInfo> MachineOp::getPortList() {
179 SmallVector<hw::PortInfo> ports;
180 auto argNames = getArgNames();
181 auto argTypes = getFunctionType().getInputs();
182 for (
unsigned i = 0, e = argTypes.size(); i < e; ++i) {
183 bool isInOut =
false;
184 auto type = argTypes[i];
186 if (
auto inout = dyn_cast<hw::InOutType>(
type)) {
188 type = inout.getElementType();
191 auto direction = isInOut ? hw::ModulePort::Direction::InOut
192 : hw::ModulePort::Direction::Input;
195 {{argNames ? cast<StringAttr>((*argNames)[i])
196 : StringAttr::
get(getContext(), Twine(
"input") + Twine(i)),
203 auto resultNames = getResNames();
204 auto resultTypes = getFunctionType().getResults();
205 for (
unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
206 ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
207 : StringAttr::
get(getContext(),
208 Twine(
"output") + Twine(i)),
// Looks up the MachineOp named by getMachine() in the enclosing ModuleOp.
// The enclosing signature is not visible in this excerpt (presumably
// InstanceOp::getMachineOp — TODO confirm).
// NOTE(review): `module` is used without a null check in these two lines;
// confirm the op is always nested inside a ModuleOp.
223 auto module = (*this)->getParentOfType<ModuleOp>();
224 return module.lookupSymbol<MachineOp>(getMachine());
/// Verifies that the instance refers to a machine: resolves it through
/// getMachineOp() and reports "cannot find machine definition" with the
/// referenced name.
/// NOTE(review): the condition guarding the error and the remainder of the
/// function are not visible in this excerpt.
227LogicalResult InstanceOp::verify() {
228 auto m = getMachineOp();
230 return emitError(
"cannot find machine definition '") << getMachine() <<
"'";
235void InstanceOp::getAsmResultNames(
236 function_ref<
void(Value, StringRef)> setNameFn) {
237 setNameFn(getInstance(),
getName());
// Shared verifier for ops that reference a machine (signature not visible in
// this excerpt; presumably verifyCallerTypes(OpType op) — TODO confirm).
// Resolves the machine via op.getMachineOp(), then compares the machine's
// argument types against the op's input types and the machine's result types
// against the op's output types. Each mismatch produces an op error with a
// note pointing at the machine's location.
// NOTE(review): the null check on `machine`, the declarations of `diag`, and
// the return statements are not visible here.
244template <
typename OpType>
246 auto machine = op.getMachineOp();
248 return op.emitError(
"cannot find machine definition");
251 if (failed(
compareTypes(op.getLoc(), machine.getArgumentTypes(),
252 op.getInputs().getTypes()))) {
254 op.emitOpError(
"operand types must match the machine input types");
255 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
260 if (failed(
compareTypes(op.getLoc(), machine.getResultTypes(),
261 op.getOutputs().getTypes()))) {
263 op.emitOpError(
"result types must match the machine output types");
264 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
// Resolves the machine through the InstanceOp that defines the value returned
// by getInstance(). The enclosing signature and the lines between these two
// statements (possibly a null check on `instanceOp`) are not visible in this
// excerpt.
273 auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
277 return instanceOp.getMachineOp();
// Looks up the MachineOp named by getMachine() in the enclosing ModuleOp.
// The enclosing signature is not visible in this excerpt (presumably
// HWInstanceOp::getMachineOp — TODO confirm).
// NOTE(review): `module` is used without a null check in these two lines;
// confirm the op is always nested inside a ModuleOp.
290 auto module = (*this)->getParentOfType<ModuleOp>();
291 return module.lookupSymbol<MachineOp>(getMachine());
296SmallVector<hw::PortInfo> HWInstanceOp::getPortList() {
297 return getMachineOp().getPortList();
301StringRef HWInstanceOp::getModuleName() {
return getMachine(); }
302FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() {
return getMachineAttr(); }
304mlir::StringAttr HWInstanceOp::getInstanceNameAttr() {
return getNameAttr(); }
306llvm::StringRef HWInstanceOp::getInstanceName() {
return getName(); }
/// Builds a StateOp named `stateName`. Two regions are added to `state`, each
/// seeded with a fresh empty block: first the output region, then the
/// transitions region. The "sym_name" attribute is set last. The builder's
/// insertion point is moved to the end of the output block and restored by
/// the insertion guard when the function returns.
/// NOTE(review): one statement after the insertion point is set is missing
/// from this excerpt; presumably it creates an op in the output block —
/// confirm.
312void StateOp::build(OpBuilder &builder, OperationState &state,
313 StringRef stateName) {
314 OpBuilder::InsertionGuard guard(builder);
315 Region *output = state.addRegion();
316 output->push_back(
new Block());
317 builder.setInsertionPointToEnd(&output->back());
319 Region *transitions = state.addRegion();
320 transitions->push_back(
new Block());
321 state.addAttribute(
"sym_name", builder.getStringAttr(stateName));
/// Builds a StateOp named `stateName`, taking the values in `outputs`.
/// Two regions are added to `state`, each seeded with a fresh empty block:
/// first the output region, then the transitions region. The "sym_name"
/// attribute is set last. The builder's insertion point is moved to the end
/// of the output block and restored by the insertion guard on return.
/// NOTE(review): the statement that consumes `outputs` is missing from this
/// excerpt; presumably it creates an op in the output block — confirm.
324void StateOp::build(OpBuilder &builder, OperationState &state,
325 StringRef stateName, ValueRange outputs) {
326 OpBuilder::InsertionGuard guard(builder);
327 Region *output = state.addRegion();
328 output->push_back(
new Block());
329 builder.setInsertionPointToEnd(&output->back());
331 Region *transitions = state.addRegion();
332 transitions->push_back(
new Block());
333 state.addAttribute(
"sym_name", builder.getStringAttr(stateName));
/// Returns the states reachable from this state. Each TransitionOp in the
/// transitions region is mapped to getNextStateOp(), the results are gathered
/// in `nextStates`, and a SetVector is built from that vector.
/// NOTE(review): the head of the call that takes the three arguments below is
/// not visible in this excerpt.
336SetVector<StateOp> StateOp::getNextStates() {
337 SmallVector<StateOp> nextStates;
339 getTransitions().getOps<TransitionOp>(),
340 std::inserter(nextStates, nextStates.begin()),
341 [](
TransitionOp transition) { return transition.getNextStateOp(); });
342 return SetVector<StateOp>(nextStates.begin(), nextStates.end());
/// Canonicalizes a state by erasing transitions collected in
/// `transitionsToErase`. While no always-taken transition has been seen, the
/// flag is refreshed from transition.isAlwaysTaken(). Returns failure when
/// nothing was collected for erasure and success otherwise.
/// NOTE(review): the line between the flag update and the push_back is
/// missing from this excerpt, so the exact condition under which a transition
/// is queued for erasure cannot be confirmed here.
345LogicalResult StateOp::canonicalize(
StateOp op, PatternRewriter &rewriter) {
346 bool hasAlwaysTakenTransition =
false;
347 SmallVector<TransitionOp, 4> transitionsToErase;
349 for (
auto transition : op.getTransitions().getOps<
TransitionOp>()) {
350 if (!hasAlwaysTakenTransition)
351 hasAlwaysTakenTransition = transition.isAlwaysTaken();
353 transitionsToErase.push_back(transition);
// Erasure is deferred to a second loop so the op range above is not modified
// while it is being iterated.
356 for (
auto transition : transitionsToErase)
357 rewriter.eraseOp(transition);
359 return failure(transitionsToErase.empty());
/// Verifies a state. The visible checks are:
///  - when `parent` has results, the output region must not be empty,
///  - a non-empty output region's first block must be non-empty and end with
///    an fsm::OutputOp.
/// NOTE(review): the definition of `parent` (presumably the enclosing
/// MachineOp) and the end of the function are not visible in this excerpt.
362LogicalResult StateOp::verify() {
365 if (parent.getNumResults() != 0 && (getOutput().empty()))
366 return emitOpError(
"state must have a non-empty output region when the "
367 "machine has results.");
369 if (!getOutput().
empty()) {
371 Block *outputBlock = &getOutput().front();
372 if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
373 return emitOpError(
"output block must have a single OutputOp terminator");
/// Returns the first block of the output region, creating it when the region
/// is empty. On creation the builder's insertion point is moved to the start
/// of the new block under an insertion guard.
/// NOTE(review): one statement following the insertion-point change is
/// missing from this excerpt; presumably it creates the block's terminator —
/// confirm.
379Block *StateOp::ensureOutput(OpBuilder &builder) {
380 if (getOutput().
empty()) {
381 OpBuilder::InsertionGuard g(builder);
382 auto *block =
new Block();
383 getOutput().push_back(block);
384 builder.setInsertionPointToStart(block);
387 return &getOutput().front();
/// Verifies an OutputOp. When the op sits directly in the transitions region
/// of its parent StateOp, it is expected to carry no operands. Otherwise the
/// operand types are compared against the result types of the enclosing
/// MachineOp.
/// NOTE(review): several lines of this function are missing from this
/// excerpt, including the return statements of the first branch.
394LogicalResult OutputOp::verify() {
395 if ((*this)->getParentRegion() ==
396 &(*this)->getParentOfType<
StateOp>().getTransitions()) {
// NOTE(review): the diagnostic produced by emitOpError below is not returned
// in the visible code; confirm that verification actually fails in this case.
397 if (getNumOperands() != 0)
398 emitOpError(
"transitions region must not output any value");
404 auto machine = (*this)->getParentOfType<
MachineOp>();
406 compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
407 return emitOpError(
"operand types must match the machine output types");
/// Convenience builder that forwards to another build overload using the
/// name of `nextState`.
/// NOTE(review): the rest of the parameter list (which declares `nextState`)
/// is not visible in this excerpt.
416void TransitionOp::build(OpBuilder &builder, OperationState &state,
418 build(builder, state, nextState.getName());
/// Builds a TransitionOp targeting `nextState`. Sets the "nextState"
/// attribute to a flat symbol reference, then adds the guard region followed
/// by the action region; a block is created in each region in the visible
/// code. The builder's insertion point is restored by the insertion guard on
/// return.
/// NOTE(review): the first parameter-list continuation line and the
/// statements around each createBlock call (presumably conditions on
/// `guardCtor` / `actionCtor` and their invocation) are not visible in this
/// excerpt.
421void TransitionOp::build(OpBuilder &builder, OperationState &state,
423 llvm::function_ref<
void()> guardCtor,
424 llvm::function_ref<
void()> actionCtor) {
425 state.addAttribute(
"nextState",
426 FlatSymbolRefAttr::get(builder.getStringAttr(nextState)));
427 OpBuilder::InsertionGuard guard(builder);
429 Region *guardRegion = state.addRegion();
431 builder.createBlock(guardRegion);
435 Region *actionRegion = state.addRegion();
437 builder.createBlock(actionRegion);
442Block *TransitionOp::ensureGuard(OpBuilder &builder) {
443 if (getGuard().
empty()) {
444 OpBuilder::InsertionGuard g(builder);
445 auto *block =
new Block();
446 getGuard().push_back(block);
447 builder.setInsertionPointToStart(block);
448 builder.create<fsm::ReturnOp>(getLoc());
450 return &getGuard().front();
453Block *TransitionOp::ensureAction(OpBuilder &builder) {
454 if (getAction().
empty())
455 getAction().push_back(
new Block());
456 return &getAction().front();
/// Resolves the StateOp named by getNextState() inside the enclosing
/// MachineOp.
/// NOTE(review): the lines between the parent lookup and the symbol lookup
/// (possibly a null check on `machineOp`) are not visible in this excerpt.
460StateOp TransitionOp::getNextStateOp() {
461 auto machineOp = (*this)->getParentOfType<
MachineOp>();
465 return machineOp.lookupSymbol<
StateOp>(getNextState());
/// Reports whether this transition is unconditionally taken. The visible
/// logic inspects the guard's return op: an operand-less return is handled
/// specially, and an operand defined by an arith::ConstantOp yields that
/// constant's boolean value.
/// NOTE(review): the statements before `guardReturn`, the result of the
/// zero-operand case, and the final fallback return are not visible in this
/// excerpt.
468bool TransitionOp::isAlwaysTaken() {
472 auto guardReturn = getGuardReturn();
473 if (guardReturn.getNumOperands() == 0)
476 if (
auto constantOp =
477 guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
478 return cast<BoolAttr>(constantOp.getValue()).getValue();
/// Canonicalizes a transition whose guard returns a constant. When the guard
/// return has exactly one operand defined by an arith::ConstantOp and that
/// constant is true, the return is replaced with an operand-less
/// fsm::ReturnOp at the same position. The visible code also erases the
/// whole transition in another branch.
/// NOTE(review): the branch structure around `rewriter.eraseOp(op)`
/// (presumably the constant-false case) and the return statements are not
/// visible in this excerpt.
483LogicalResult TransitionOp::canonicalize(
TransitionOp op,
484 PatternRewriter &rewriter) {
486 auto guardReturn = op.getGuardReturn();
487 if (guardReturn.getNumOperands() == 1)
488 if (
auto constantOp = guardReturn.getOperand()
489 .getDefiningOp<mlir::arith::ConstantOp>()) {
491 if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
494 rewriter.setInsertionPoint(guardReturn);
495 rewriter.
create<fsm::ReturnOp>(guardReturn.getLoc());
496 rewriter.eraseOp(guardReturn);
500 rewriter.eraseOp(op);
/// Verifies a transition. The visible checks are:
///  - the next state must resolve to a definition,
///  - the guard region's first block must be non-empty and end with an
///    fsm::ReturnOp,
///  - the op must be located in the transitions region of its current state.
/// NOTE(review): lines between these checks are missing from this excerpt, so
/// any enclosing condition on the guard check cannot be confirmed here.
509LogicalResult TransitionOp::verify() {
510 if (!getNextStateOp())
511 return emitOpError(
"cannot find the definition of the next state `")
512 << getNextState() <<
"`";
516 if (getGuard().front().
empty() ||
517 !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
518 return emitOpError(
"guard region must terminate with a ReturnOp");
522 if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
523 return emitOpError(
"must only be located in the transitions region");
532void VariableOp::getAsmResultNames(
533 function_ref<
void(Value, StringRef)> setNameFn) {
534 setNameFn(getResult(),
getName());
/// Sets the single operand of this return op to `value`. The visible code has
/// two paths: overwriting operand 0, and inserting `value` as a new operand
/// at index 0.
/// NOTE(review): the condition that selects between the two statements is not
/// visible in this excerpt (presumably whether an operand already exists).
541void ReturnOp::setOperand(Value value) {
543 getOperation()->setOperand(0, value);
545 getOperation()->insertOperands(0, {value});
553VariableOp UpdateOp::getVariableOp() {
554 return getVariable().getDefiningOp<VariableOp>();
/// Verifies an update. The visible checks are:
///  - the destination must be a variable operation,
///  - the op must be nested inside the action region of its parent
///    TransitionOp,
///  - no other UpdateOp directly in that action region may target the same
///    variable; the error is reported on the other update.
/// NOTE(review): the condition of the first check, the statement skipping the
/// op itself in the loop, and the end of the function are not visible in this
/// excerpt.
557LogicalResult UpdateOp::verify() {
559 return emitOpError(
"destination is not a variable operation");
// NOTE(review): the parent TransitionOp is dereferenced here without a
// visible null check; confirm an update can never appear outside one.
561 if (!(*this)->getParentOfType<
TransitionOp>().getAction().isAncestor(
562 (*this)->getParentRegion()))
563 return emitOpError(
"must only be located in the action region");
565 auto transition = (*this)->getParentOfType<
TransitionOp>();
566 for (
auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
567 if (otherUpdateOp == *
this)
569 if (otherUpdateOp.getVariable() == getVariable())
570 return otherUpdateOp.emitOpError(
571 "multiple updates to the same variable within a single action region "
583#define GET_OP_CLASSES
584#include "circt/Dialect/FSM/FSM.cpp.inc"
587#include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static InstancePath empty
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
The direction of a Component or Cell port.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.