20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/TypeSwitch.h"
24 #define GEN_PASS_DEF_CALYXTOFSM
25 #include "circt/Conversion/Passes.h.inc"
29 using namespace circt;
30 using namespace calyx;
35 class CompileFSMVisitor {
38 : graph(graph), sc(sc), ctx(graph.getMachine().getContext()),
39 builder(graph.getMachine().getContext()) {
44 LogicalResult dispatch(
StateOp currentState, Operation *op,
46 return TypeSwitch<Operation *, LogicalResult>(op)
47 .template Case<SeqOp, EnableOp, IfOp, WhileOp>(
48 [&](
auto opNode) {
return visit(currentState, opNode, nextState); })
50 return op->emitError() <<
"Operation '" << op->getName()
51 <<
"' not supported for FSM compilation";
55 ArrayRef<Attribute> getCompiledGroups() {
return compiledGroups; }
67 LogicalResult visit(
StateOp currentState, SeqOp,
StateOp nextState);
68 LogicalResult visit(
StateOp currentState, EnableOp,
StateOp nextState);
69 LogicalResult visit(
StateOp currentState, IfOp,
StateOp nextState);
70 LogicalResult visit(
StateOp currentState, WhileOp,
StateOp nextState);
75 struct StateScopeGuard {
77 StateScopeGuard(CompileFSMVisitor &visitor, StringRef name,
79 : visitor(visitor), name(name) {
80 visitor.stateStack.push_back(suffix.str());
83 assert(!visitor.stateStack.empty());
84 visitor.stateStack.pop_back();
87 StringRef
getName() {
return name; }
90 CompileFSMVisitor &visitor;
97 StateScopeGuard pushStateScope(StringRef suffix) {
99 llvm::raw_string_ostream ss(name);
101 stateStack, ss, [&](
const auto &it) { ss << it; },
"_");
102 ss <<
"_" << suffix.str();
103 return StateScopeGuard(*
this, ns.newName(name), suffix);
111 SmallVector<std::string, 4> stateStack;
115 SmallVector<Attribute, 8> compiledGroups;
118 LogicalResult CompileFSMVisitor::visit(
StateOp currentState, IfOp ifOp,
120 auto stateGuard = pushStateScope(
"if");
121 auto loc = ifOp.getLoc();
124 graph.renameState(currentState, stateGuard.getName());
126 auto lowerBranch = [&](Value cond, StringRef nextStateSuffix,
bool invert,
127 Operation *innerBranchOp) {
128 auto branchStateGuard = pushStateScope(nextStateSuffix);
130 graph.createState(builder, ifOp.getLoc(), branchStateGuard.getName())
133 auto transitionOp = graph
134 .createTransition(builder, ifOp.getLoc(),
135 currentState, branchStateOp)
137 transitionOp.ensureGuard(builder);
138 fsm::ReturnOp returnOp = transitionOp.getGuardReturn();
139 OpBuilder::InsertionGuard g(builder);
140 builder.setInsertionPointToStart(&transitionOp.getGuard().front());
141 Value branchTaken = cond;
143 OpBuilder::InsertionGuard g(builder);
147 returnOp.setOperand(branchTaken);
151 if (failed(dispatch(branchStateOp, innerBranchOp, nextState)))
157 if (failed(lowerBranch(ifOp.getCond(),
"then",
false,
158 &ifOp.getThenBody()->front())))
162 if (ifOp.elseBodyExists() &&
163 failed(lowerBranch(ifOp.getCond(),
"else",
true,
164 &ifOp.getElseBody()->front())))
170 LogicalResult CompileFSMVisitor::visit(
StateOp currentState, SeqOp seqOp,
172 Location loc = seqOp.getLoc();
173 auto seqStateGuard = pushStateScope(
"seq");
176 auto &seqOps = seqOp.getBodyBlock()->getOperations();
177 llvm::SmallVector<std::pair<Operation *, StateOp>> seqStates;
181 StateOp currentOpNextState = nextState;
182 int n = seqOps.size() - 1;
183 for (
auto &op : llvm::reverse(*seqOp.getBodyBlock())) {
184 auto subStateGuard = pushStateScope(std::to_string(n--));
186 graph.createState(builder, op.getLoc(), subStateGuard.getName())
188 seqStates.insert(seqStates.begin(), {&op, thisStateOp});
189 sc.addSymbol(thisStateOp);
192 if (failed(dispatch(thisStateOp, &op, currentOpNextState)))
196 currentOpNextState = thisStateOp;
200 graph.createTransition(builder, loc, currentState, seqStates.front().second);
205 LogicalResult CompileFSMVisitor::visit(
StateOp currentState, WhileOp whileOp,
207 OpBuilder::InsertionGuard g(builder);
208 auto whileStateGuard = pushStateScope(
"while");
209 auto loc = whileOp.getLoc();
213 StateOp whileHeaderState = currentState;
214 graph.renameState(whileHeaderState,
215 (whileStateGuard.getName() +
"_header").str());
216 sc.addSymbol(whileHeaderState);
220 auto whileBodyEntryState =
222 .createState(builder, loc,
223 (whileStateGuard.getName() +
"_entry").str())
225 sc.addSymbol(whileBodyEntryState);
226 Operation *whileBodyOp = &whileOp.getBodyBlock()->front();
227 if (failed(dispatch(whileBodyEntryState, whileBodyOp, whileHeaderState)))
232 auto bodyTransition =
234 .createTransition(builder, loc, whileHeaderState, whileBodyEntryState)
236 auto nextStateTransition =
237 graph.createTransition(builder, loc, whileHeaderState, nextState)
240 bodyTransition.ensureGuard(builder);
241 bodyTransition.getGuardReturn().setOperand(whileOp.getCond());
242 nextStateTransition.ensureGuard(builder);
243 builder.setInsertionPoint(nextStateTransition.getGuardReturn());
244 nextStateTransition.getGuardReturn().setOperand(
249 LogicalResult CompileFSMVisitor::visit(
StateOp currentState, EnableOp enableOp,
252 "Expected this enableOp to be nested into some provided state");
255 auto enableStateGuard = pushStateScope(enableOp.getGroupName());
256 graph.renameState(currentState, enableStateGuard.getName());
261 OpBuilder::InsertionGuard g(builder);
262 builder.setInsertionPointToStart(¤tState.getOutput().front());
263 builder.create<calyx::EnableOp>(enableOp.getLoc(), enableOp.getGroupName());
266 graph.createTransition(builder, enableOp.getLoc(), currentState, nextState);
269 compiledGroups.push_back(
277 class CompileInvoke {
279 CompileInvoke(ComponentOp component, OpBuilder builder)
280 : component(component), builder(builder) {}
284 void lowerInvokeOp(InvokeOp invokeOp);
285 std::string getTransitionName(InvokeOp invokeOp);
286 ComponentOp component;
292 size_t transitionNameTail = 0;
296 void CompileInvoke::compile() {
297 llvm::SmallVector<InvokeOp> invokeOps =
298 component.getControlOp().getInvokeOps();
299 for (InvokeOp op : invokeOps)
304 std::string CompileInvoke::getTransitionName(InvokeOp invokeOp) {
305 llvm::StringRef callee = invokeOp.getCallee();
306 std::string transitionNameHead =
"invoke_" + callee.str() +
"_";
307 std::string transitionName;
312 transitionName = transitionNameHead + std::to_string(transitionNameTail++);
313 }
while (component.getWiresOp().lookupSymbol(transitionName));
314 return transitionName;
318 void CompileInvoke::lowerInvokeOp(InvokeOp invokeOp) {
320 Operation *prevNode = component.getWiresOp().getOperation()->getPrevNode();
321 builder.setInsertionPointAfter(prevNode);
323 prevNode->getLoc(), builder.getI1Type(), 1);
324 Location loc = component.getWiresOp().getLoc();
327 builder.setInsertionPointToEnd(component.getWiresOp().getBodyBlock());
328 std::string transitionName = getTransitionName(invokeOp);
329 GroupOp groupOp = builder.create<GroupOp>(loc, transitionName);
330 builder.setInsertionPointToStart(groupOp.getBodyBlock());
331 Value go = invokeOp.getInstGoValue();
334 builder.create<AssignOp>(loc, go, constantOp);
335 auto ports = invokeOp.getPorts();
336 auto inputs = invokeOp.getInputs();
339 for (
auto [port, input] : llvm::zip(ports, inputs))
340 builder.create<AssignOp>(loc, port, input);
341 Value done = invokeOp.getInstDoneValue();
344 builder.create<calyx::GroupDoneOp>(loc, done);
345 builder.setInsertionPointAfter(invokeOp.getOperation());
346 builder.create<EnableOp>(invokeOp.getLoc(), transitionName);
350 class CalyxToFSMPass :
public circt::impl::CalyxToFSMBase<CalyxToFSMPass> {
352 void runOnOperation()
override;
355 void CalyxToFSMPass::runOnOperation() {
356 ComponentOp component = getOperation();
357 OpBuilder builder(&getContext());
358 auto ctrlOp = component.getControlOp();
359 assert(ctrlOp.getBodyBlock()->getOperations().size() == 1 &&
360 "Expected a single top-level operation in the schedule");
361 CompileInvoke compileInvoke(component, builder);
362 compileInvoke.compile();
363 Operation &topLevelCtrlOp = ctrlOp.getBodyBlock()->front();
364 builder.setInsertionPoint(&topLevelCtrlOp);
370 auto machineName = (
"control_" + component.getName()).str();
373 builder.create<
MachineOp>(ctrlOp.getLoc(), machineName,
374 "fsm_entry", funcType);
375 auto graph = FSMGraph(machine);
388 auto visitor = CompileFSMVisitor(sc, graph);
389 if (failed(visitor.dispatch(entryState, &topLevelCtrlOp, exitState))) {
396 topLevelCtrlOp.erase();
401 ArrayAttr::get(builder.getContext(), visitor.getCompiledGroups()));
407 return std::make_unique<CalyxToFSMPass>();
assert(baseType &&"element must be base type")
A namespace that is used to store existing names and generate new names in some scope within the IR.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
def create(data_type, value)
static constexpr std::string_view sExitStateName
static constexpr std::string_view sEntryStateName
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createCalyxToFSMPass()