10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/FunctionImplementation.h"
16#include "llvm/Support/FormatVariadic.h"
/// Build a MachineOp named `name` whose initial state is `initialStateName`
/// and whose signature is `type`. Extra `attrs` are appended verbatim, a body
/// region is created, and `argAttrs` are attached as argument attributes (the
/// assert below expects one dictionary per machine input); an empty list
/// (`{}`) is passed for result attributes.
/// NOTE(review): the embedded original numbering skips lines 33, 38, 40-41
/// and 43-45, so this listing is incomplete: the value passed together with
/// the function-type attribute name, the declaration of `body`, and the head
/// of the statement that ends in `SmallVector<Location, 4>(...)` (presumably
/// adding one block argument per input at unknown locations) are missing, as
/// is whatever precedes the assert. Restore them from the canonical source.
26void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27 StringRef initialStateName, FunctionType type,
28 ArrayRef<NamedAttribute> attrs,
29 ArrayRef<DictionaryAttr> argAttrs) {
30 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31 builder.getStringAttr(name));
32 state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
// NOTE(review): the attribute value for the call above is missing here.
34 state.addAttribute(
"initialState",
35 StringAttr::get(state.getContext(), initialStateName));
36 state.attributes.append(attrs.begin(), attrs.end());
37 Region *region = state.addRegion();
// NOTE(review): `body` is used below but its declaration is not visible.
39 region->push_back(
body);
42 SmallVector<Location, 4>(
type.getNumInputs(), builder.getUnknownLoc()));
46 assert(
type.getNumInputs() == argAttrs.size());
47 call_interface_impl::addArgAndResultAttrs(
48 builder, state, argAttrs,
49 {}, MachineOp::getArgAttrsAttrName(state.name),
50 MachineOp::getResAttrsAttrName(state.name));
54StateOp MachineOp::getInitialStateOp() {
55 return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
58size_t MachineOp::getNumStates() {
59 auto stateOps = getBody().getOps<
StateOp>();
60 return std::distance(stateOps.begin(), stateOps.end());
63StringAttr MachineOp::getArgName(
size_t i) {
64 if (
auto args = getArgNames())
65 return cast<StringAttr>((*args)[i]);
67 return StringAttr::get(getContext(),
"in" + std::to_string(i));
70StringAttr MachineOp::getResName(
size_t i) {
71 if (
auto resNameAttrs = getResNames())
72 return cast<StringAttr>((*resNameAttrs)[i]);
74 return StringAttr::get(getContext(),
"out" + std::to_string(i));
/// Push onto `ports` one hw::PortInfo per machine input (named "in<i>", typed
/// by the i-th input) followed by one per machine result (named "out<i>",
/// typed by the i-th result).
/// NOTE(review): the listing is incomplete (original lines 79, 82-84, 86, 88,
/// 90-91, 93-95, 97, 99 and 101 onward are missing), so the declaration of
/// `port` and any other fields it sets (e.g. a direction) are not visible.
/// NOTE(review): these names ignore getArgNames()/getResNames(), unlike
/// getArgName/getResName, and differ from getPortList's "input<i>" /
/// "output<i>" fallback; confirm the divergence is intended.
78void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
80 auto machineType = getFunctionType();
// One port per machine input.
81 for (
unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
85 port.
name = StringAttr::get(getContext(),
"in" + std::to_string(i));
87 port.
type = machineType.getInput(i);
89 ports.push_back(port);
// One port per machine result.
92 for (
unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
96 port.
name = StringAttr::get(getContext(),
"out" + std::to_string(i));
98 port.
type = machineType.getResult(i);
100 ports.push_back(port);
104ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
106 [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
107 function_interface_impl::VariadicFlag,
108 std::string &) {
return builder.getFunctionType(argTypes, results); };
110 return function_interface_impl::parseFunctionOp(
111 parser, result,
false,
112 MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
113 MachineOp::getArgAttrsAttrName(result.name),
114 MachineOp::getResAttrsAttrName(result.name));
117void MachineOp::print(OpAsmPrinter &p) {
118 function_interface_impl::printFunctionOp(
119 p, *
this,
false, getFunctionTypeAttrName(),
120 getArgAttrsAttrName(), getResAttrsAttrName());
// Body of a static helper whose signature (original line 124) is not visible
// here; an index entry at the end of this listing reads
// `static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)`.
// It emits an error at `loc` and fails when the two type ranges differ in
// length or element-wise.
// NOTE(review): original lines 128-129, 133 and 136 onward are missing. That
// includes the declaration and increment of `index`, the comparison that
// should guard the element-mismatch error (as listed, the loop errors on the
// first pair), and the final success return.
125 if (rangeA.size() != rangeB.size())
126 return emitError(loc) <<
"mismatch in number of types compared ("
127 << rangeA.size() <<
" != " << rangeB.size() <<
")";
// Compare the ranges pairwise; the error reports the first differing index.
130 for (
auto zip : llvm::zip(rangeA, rangeB)) {
131 auto typeA = std::get<0>(zip);
132 auto typeB = std::get<1>(zip);
134 return emitError(loc) <<
"type mismatch at index " << index <<
" ("
135 << typeA <<
" != " << typeB <<
")";
/// Verify a machine:
/// - entry-block argument types must match the machine's input types;
/// - the body must be a single block;
/// - the initial state must resolve to a StateOp in the machine;
/// - when present, getArgNames()/getResNames() must have one entry per
///   argument/result.
/// NOTE(review): the listing has gaps (original lines 143-150, 152, 154-155,
/// 167, 175, 177 and everything after 178 are missing). The head of the
/// entry-block type comparison, parts of both name-count messages and the
/// final return are therefore not visible.
142LogicalResult MachineOp::verify() {
151 front().getArgumentTypes())))
153 "entry block argument types must match the machine input types");
156 if (!llvm::hasSingleElement(*
this))
157 return emitOpError(
"must only have a single block");
160 if (!getInitialStateOp())
161 return emitOpError(
"initial state '" + getInitialState() +
162 "' was not defined in the machine");
// Optional argument names must line up with the signature.
164 if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
165 return emitOpError() <<
"number of machine arguments ("
166 << getArgumentTypes().size()
168 "not match the provided number "
169 "of argument names ("
170 << getArgNames()->size() <<
")";
// Optional result names must line up with the signature.
172 if (getResNames() && getResNames()->size() != getResultTypes().size())
173 return emitOpError() <<
"number of machine results ("
174 << getResultTypes().size()
176 "not match the provided number "
178 << getResNames()->size() <<
")";
/// Build the hw::PortInfo list for this machine. It has one entry per input,
/// where an hw::InOutType input is unwrapped to its element type and the
/// direction is InOut or Input. Entries for the results follow. Port names
/// appear to come from getArgNames()/getResNames() when present, otherwise
/// "input<i>"/"output<i>".
/// NOTE(review): the listing is incomplete (original lines 190, 192,
/// 194-195, 198-199, 202-207 and everything after 213 are missing). The rest
/// of each port entry, the loop ends and the return are not visible.
/// NOTE(review): `isInOut` is initialized to false and no visible line sets
/// it. Presumably the missing line 192 sets it inside the InOutType branch;
/// confirm, otherwise InOut ports are reported as Input.
/// NOTE(review): the fallback names here differ from getArgName/getResName
/// ("in<i>"/"out<i>"); confirm the divergence is intended.
183SmallVector<::circt::hw::PortInfo> MachineOp::getPortList() {
184 SmallVector<hw::PortInfo> ports;
185 auto argNames = getArgNames();
186 auto argTypes = getFunctionType().getInputs();
187 for (
unsigned i = 0, e = argTypes.size(); i < e; ++i) {
188 bool isInOut =
false;
189 auto type = argTypes[i];
// InOut inputs are reported with their element type.
191 if (
auto inout = dyn_cast<hw::InOutType>(
type)) {
193 type = inout.getElementType();
196 auto direction = isInOut ? hw::ModulePort::Direction::InOut
197 : hw::ModulePort::Direction::Input;
200 {{argNames ? cast<StringAttr>((*argNames)[i])
201 : StringAttr::
get(getContext(), Twine(
"input") + Twine(i)),
208 auto resultNames = getResNames();
209 auto resultTypes = getFunctionType().getResults();
210 for (
unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
211 ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
212 : StringAttr::
get(getContext(),
213 Twine(
"output") + Twine(i)),
// Tail of a machine-lookup helper whose signature (original line 227) is not
// visible here; presumably InstanceOp::getMachineOp(), which InstanceOp::verify
// calls. It resolves getMachine() as a MachineOp symbol in the nearest
// enclosing builtin ModuleOp.
// NOTE(review): if this op is not nested in a ModuleOp, `module` is null and
// the lookup dereferences it. Confirm callers guarantee that nesting.
228 auto module = (*this)->getParentOfType<ModuleOp>();
229 return module.lookupSymbol<MachineOp>(getMachine());
/// Verify that the machine referenced by this instance can be found.
/// NOTE(review): original lines 234 and 236-238 are missing. As listed, the
/// error is returned unconditionally and `m` is unused; presumably line 234 is
/// a null check on `m` and a `return success()` follows.
232LogicalResult InstanceOp::verify() {
233 auto m = getMachineOp();
235 return emitError(
"cannot find machine definition '") << getMachine() <<
"'";
240void InstanceOp::getAsmResultNames(
241 function_ref<
void(Value, StringRef)> setNameFn) {
242 setNameFn(getInstance(),
getName());
/// Shared verifier for ops of type `OpType` that reference a machine:
/// - the machine must resolve;
/// - the op's input types must match the machine's argument types;
/// - the op's output types must match the machine's result types.
/// Each type mismatch gets a note pointing at the machine's location.
/// NOTE(review): the listing is incomplete. The function signature (original
/// line 250; an index entry at the end of this listing reads
/// `static LogicalResult verifyCallerTypes(OpType op)`) is not visible, nor
/// are the null check before the first error, the declarations of `diag`, the
/// returns after each diagnostic, and the final return.
249template <
typename OpType>
251 auto machine = op.getMachineOp();
253 return op.emitError(
"cannot find machine definition");
// Inputs of the caller must line up with the machine's arguments.
256 if (failed(
compareTypes(op.getLoc(), machine.getArgumentTypes(),
257 op.getInputs().getTypes()))) {
259 op.emitOpError(
"operand types must match the machine input types");
260 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
// Outputs of the caller must line up with the machine's results.
265 if (failed(
compareTypes(op.getLoc(), machine.getResultTypes(),
266 op.getOutputs().getTypes()))) {
268 op.emitOpError(
"result types must match the machine output types");
269 diag.attachNote(machine->getLoc()) <<
"original machine declared here";
// Tail of a helper whose signature is not visible here (original lines 277
// and 279-281 are missing). It resolves the machine through the InstanceOp
// that defines the value returned by getInstance().
// NOTE(review): the lines between these two statements are not visible;
// presumably they return null when that value is not defined by an
// InstanceOp. As listed, `instanceOp` could be null when dereferenced.
278 auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
282 return instanceOp.getMachineOp();
// Tail of a machine-lookup helper whose signature is not visible here;
// presumably HWInstanceOp::getMachineOp(), used by HWInstanceOp::getPortList.
// It resolves getMachine() as a MachineOp symbol in the nearest enclosing
// builtin ModuleOp.
// NOTE(review): if this op is not nested in a ModuleOp, `module` is null and
// the lookup dereferences it. Confirm callers guarantee that nesting.
295 auto module = (*this)->getParentOfType<ModuleOp>();
296 return module.lookupSymbol<MachineOp>(getMachine());
301SmallVector<hw::PortInfo> HWInstanceOp::getPortList() {
302 return getMachineOp().getPortList();
306StringRef HWInstanceOp::getModuleName() {
return getMachine(); }
307FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() {
return getMachineAttr(); }
309mlir::StringAttr HWInstanceOp::getInstanceNameAttr() {
return getNameAttr(); }
311llvm::StringRef HWInstanceOp::getInstanceName() {
return getName(); }
/// Build a StateOp named `stateName` with two regions, each seeded with one
/// empty block: an output region (added first) and a transitions region. The
/// InsertionGuard restores the builder's insertion point on return.
/// NOTE(review): original line 323 is missing from this listing. The
/// insertion point is moved to the end of the output block, but no visible
/// statement uses it. Presumably the missing line creates the output block's
/// terminator, since StateOp::verify requires a non-empty output region to end
/// in an fsm::OutputOp. Restore it from the canonical source.
317void StateOp::build(OpBuilder &builder, OperationState &state,
318 StringRef stateName) {
319 OpBuilder::InsertionGuard guard(builder);
320 Region *output = state.addRegion();
321 output->push_back(
new Block());
322 builder.setInsertionPointToEnd(&output->back());
324 Region *transitions = state.addRegion();
325 transitions->push_back(
new Block());
326 state.addAttribute(
"sym_name", builder.getStringAttr(stateName));
/// Build a StateOp named `stateName` whose output block should yield
/// `outputs`. Like the name-only builder, it adds an output region and a
/// transitions region, each seeded with one empty block.
/// NOTE(review): original line 335 is missing, and `outputs` is not referenced
/// by any visible line. Presumably the missing line creates the output block's
/// terminator with `outputs` as operands; restore it from the canonical source.
329void StateOp::build(OpBuilder &builder, OperationState &state,
330 StringRef stateName, ValueRange outputs) {
331 OpBuilder::InsertionGuard guard(builder);
332 Region *output = state.addRegion();
333 output->push_back(
new Block());
334 builder.setInsertionPointToEnd(&output->back());
336 Region *transitions = state.addRegion();
337 transitions->push_back(
new Block());
338 state.addAttribute(
"sym_name", builder.getStringAttr(stateName));
341SetVector<StateOp> StateOp::getNextStates() {
342 SmallVector<StateOp> nextStates;
344 getTransitions().getOps<TransitionOp>(),
345 std::inserter(nextStates, nextStates.begin()),
346 [](
TransitionOp transition) { return transition.getNextStateOp(); });
347 return SetVector<StateOp>(nextStates.begin(), nextStates.end());
350LogicalResult StateOp::canonicalize(
StateOp op, PatternRewriter &rewriter) {
351 bool hasAlwaysTakenTransition =
false;
352 SmallVector<TransitionOp, 4> transitionsToErase;
354 for (
auto transition : op.getTransitions().getOps<
TransitionOp>()) {
355 if (!hasAlwaysTakenTransition)
356 hasAlwaysTakenTransition = transition.isAlwaysTaken();
358 transitionsToErase.push_back(transition);
361 for (
auto transition : transitionsToErase)
362 rewriter.eraseOp(transition);
364 return failure(transitionsToErase.empty());
/// Verify a state:
/// - if `parent` has results, the output region must be non-empty;
/// - any output block must end in an fsm::OutputOp.
/// NOTE(review): original lines 368-369, 373, 375 and 379 onward are missing.
/// The declaration of `parent` is not visible; presumably it is the enclosing
/// MachineOp.
367LogicalResult StateOp::verify() {
370 if (parent.getNumResults() != 0 && (getOutput().empty()))
371 return emitOpError(
"state must have a non-empty output region when the "
372 "machine has results.");
// Only the last op of the output block is checked for being an OutputOp.
374 if (!getOutput().
empty()) {
376 Block *outputBlock = &getOutput().front();
377 if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
378 return emitOpError(
"output block must have a single OutputOp terminator");
/// Return the output block. If the output region is empty, a new block is
/// appended to it first. The InsertionGuard restores the builder's insertion
/// point.
/// NOTE(review): original lines 390-391 are missing. The insertion point is
/// moved into the new block, but no visible statement uses it. Compare
/// TransitionOp::ensureGuard, which creates a terminator at this point, and
/// StateOp::verify, which requires a non-empty output block to end in an
/// fsm::OutputOp. Presumably the missing line creates that OutputOp.
384Block *StateOp::ensureOutput(OpBuilder &builder) {
385 if (getOutput().
empty()) {
386 OpBuilder::InsertionGuard g(builder);
387 auto *block =
new Block();
388 getOutput().push_back(block);
389 builder.setInsertionPointToStart(block);
392 return &getOutput().front();
/// Verify an output op. Inside a state's transitions region it must carry no
/// operands. Otherwise, judging by the visible error message, its operand
/// types must match the parent machine's result types.
/// NOTE(review): the transitions-region diagnostic below is emitted but not
/// returned. Unless a missing line (404-408) returns failure, verification can
/// still succeed after reporting it; this likely should be
/// `return emitOpError(...)`.
/// NOTE(review): original lines 404-408, 410-411 and 413 onward are missing,
/// including the type comparison that guards the final error.
399LogicalResult OutputOp::verify() {
400 if ((*this)->getParentRegion() ==
401 &(*this)->getParentOfType<
StateOp>().getTransitions()) {
402 if (getNumOperands() != 0)
403 emitOpError(
"transitions region must not output any value");
409 auto machine = (*this)->getParentOfType<
MachineOp>();
412 return emitOpError(
"operand types must match the machine output types");
/// Convenience builder that forwards `nextState.getName()` to the TransitionOp
/// builder taking the next state's name.
/// NOTE(review): original line 422, which declares the remaining parameter(s),
/// is missing from this listing; presumably `nextState` is a StateOp.
421void TransitionOp::build(OpBuilder &builder, OperationState &state,
423 build(builder, state, nextState.getName());
/// Build a transition to the state named `nextState`. It sets the "nextState"
/// symbol-reference attribute and creates a guard region and an action
/// region, each with one block created through the builder. The
/// InsertionGuard restores the builder's insertion point on return.
/// NOTE(review): original lines 427, 433, 435, 437-439, 441 and 443 onward are
/// missing. The `nextState` parameter declaration and every use of
/// `guardCtor`/`actionCtor` are not visible. Presumably each callback runs
/// right after its block is created (createBlock leaves the insertion point
/// in the new block); restore the missing lines from the canonical source.
426void TransitionOp::build(OpBuilder &builder, OperationState &state,
428 llvm::function_ref<
void()> guardCtor,
429 llvm::function_ref<
void()> actionCtor) {
430 state.addAttribute(
"nextState",
431 FlatSymbolRefAttr::get(builder.getStringAttr(nextState)));
432 OpBuilder::InsertionGuard guard(builder);
434 Region *guardRegion = state.addRegion();
436 builder.createBlock(guardRegion);
440 Region *actionRegion = state.addRegion();
442 builder.createBlock(actionRegion);
447Block *TransitionOp::ensureGuard(OpBuilder &builder) {
448 if (getGuard().
empty()) {
449 OpBuilder::InsertionGuard g(builder);
450 auto *block =
new Block();
451 getGuard().push_back(block);
452 builder.setInsertionPointToStart(block);
453 fsm::ReturnOp::create(builder,
getLoc());
455 return &getGuard().front();
458Block *TransitionOp::ensureAction(OpBuilder &builder) {
459 if (getAction().
empty())
460 getAction().push_back(
new Block());
461 return &getAction().front();
/// Resolve getNextState() to a StateOp within the enclosing MachineOp. The
/// result is null when the symbol does not resolve to a StateOp.
/// NOTE(review): original lines 467-469 are not visible. Presumably they
/// handle a missing parent MachineOp; as listed, `machineOp` may be null when
/// this op is not nested in a machine.
465StateOp TransitionOp::getNextStateOp() {
466 auto machineOp = (*this)->getParentOfType<
MachineOp>();
470 return machineOp.lookupSymbol<
StateOp>(getNextState());
/// Return whether this transition's guard is statically known to pass. The
/// visible checks are:
/// - a guard return with no operands (the statement it controls, original
///   line 479, is not visible; presumably `return true`);
/// - an arith.constant operand, whose BoolAttr value is returned.
/// NOTE(review): original lines 474-476, 479-480 and 484-486 are missing,
/// including the final fallback return (presumably `false`).
/// NOTE(review): cast<BoolAttr> asserts if the constant's value is not a
/// BoolAttr; presumably guards always return an i1.
473bool TransitionOp::isAlwaysTaken() {
477 auto guardReturn = getGuardReturn();
478 if (guardReturn.getNumOperands() == 0)
481 if (
auto constantOp =
482 guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
483 return cast<BoolAttr>(constantOp.getValue()).getValue();
/// Fold transitions whose guard returns a constant. When the constant is
/// true, the guard's return is replaced by an operand-less fsm::ReturnOp. A
/// transition op is also erased at original line 505, presumably in the
/// constant-false case.
/// NOTE(review): original lines 490, 495, 497-498, 502-504 and 506 onward are
/// missing, so the else-branch structure and the returned results are not
/// visible.
488LogicalResult TransitionOp::canonicalize(
TransitionOp op,
489 PatternRewriter &rewriter) {
491 auto guardReturn = op.getGuardReturn();
492 if (guardReturn.getNumOperands() == 1)
493 if (
auto constantOp = guardReturn.getOperand()
494 .getDefiningOp<mlir::arith::ConstantOp>()) {
496 if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
// Constant-true guard: drop the operand so the return itself marks the
// transition as unconditional.
499 rewriter.setInsertionPoint(guardReturn);
500 fsm::ReturnOp::create(rewriter, guardReturn.getLoc());
501 rewriter.eraseOp(guardReturn);
505 rewriter.eraseOp(op);
/// Verify a transition:
/// - getNextState() must resolve to a StateOp;
/// - the guard block must end in an fsm::ReturnOp;
/// - the op must sit directly in getCurrentState().getTransitions().
/// NOTE(review): original lines 518-520, 524-526 and 529 onward are missing.
/// getGuard().front() is invalid if the guard region has no block. Presumably
/// a missing line handles the empty-region case; confirm.
514LogicalResult TransitionOp::verify() {
515 if (!getNextStateOp())
516 return emitOpError(
"cannot find the definition of the next state `")
517 << getNextState() <<
"`";
521 if (getGuard().front().
empty() ||
522 !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
523 return emitOpError(
"guard region must terminate with a ReturnOp");
527 if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
528 return emitOpError(
"must only be located in the transitions region");
537void VariableOp::getAsmResultNames(
538 function_ref<
void(Value, StringRef)> setNameFn) {
539 setNameFn(getResult(),
getName());
/// Judging by the two visible statements, this makes `value` the op's
/// operand 0: it replaces an existing operand, or inserts one if there is none.
/// NOTE(review): original lines 547 and 549 are missing. As listed, both
/// statements run unconditionally: setOperand(0, ...) is invalid with no
/// operands, and the insert would add a second operand. Presumably they are
/// the two branches of a check on the current operand count; restore them.
546void ReturnOp::setOperand(Value value) {
548 getOperation()->setOperand(0, value);
550 getOperation()->insertOperands(0, {value});
558VariableOp UpdateOp::getVariableOp() {
559 return getVariable().getDefiningOp<VariableOp>();
/// Verify an update:
/// - its destination must be a variable operation;
/// - it must be nested inside the action region of its parent TransitionOp;
/// - no other UpdateOp directly in that action region may target the same
///   variable.
/// NOTE(review): original lines 563, 565, 569, 573 and everything after 576
/// are missing. That includes the condition guarding the first error
/// (presumably a null getVariableOp()), the statement taken when
/// `otherUpdateOp` is this op (presumably `continue`), and the rest of the
/// last error message.
/// NOTE(review): the duplicate check scans only direct children of the action
/// region (getOps), while the placement check accepts nested regions
/// (isAncestor), so more deeply nested updates are not compared. Also,
/// getParentOfType<TransitionOp>() is dereferenced without a null check.
562LogicalResult UpdateOp::verify() {
564 return emitOpError(
"destination is not a variable operation");
566 if (!(*this)->getParentOfType<
TransitionOp>().getAction().isAncestor(
567 (*this)->getParentRegion()))
568 return emitOpError(
"must only be located in the action region");
570 auto transition = (*this)->getParentOfType<
TransitionOp>();
571 for (
auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
572 if (otherUpdateOp == *
this)
574 if (otherUpdateOp.getVariable() == getVariable())
575 return otherUpdateOp.emitOpError(
576 "multiple updates to the same variable within a single action region "
588#define GET_OP_CLASSES
589#include "circt/Dialect/FSM/FSM.cpp.inc"
592#include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
static Location getLoc(DefSlot slot)
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
static InstancePath empty
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Direction
The direction of a Component or Cell port.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.