14 #include "../PassDetail.h"
21 #include "llvm/ADT/TypeSwitch.h"
24 using namespace circt;
25 using namespace calyx;
31 class CompileFSMVisitor {
34 : graph(graph), sc(sc), ctx(graph.getMachine().getContext()),
35 builder(graph.getMachine().getContext()) {
40 LogicalResult dispatch(StateOp currentState, Operation *op,
42 return TypeSwitch<Operation *, LogicalResult>(op)
43 .template Case<SeqOp, EnableOp, IfOp, WhileOp>(
44 [&](
auto opNode) {
return visit(currentState, opNode, nextState); })
46 return op->emitError() <<
"Operation '" << op->getName()
47 <<
"' not supported for FSM compilation";
51 ArrayRef<Attribute> getCompiledGroups() {
return compiledGroups; }
63 LogicalResult visit(StateOp currentState, SeqOp, StateOp nextState);
64 LogicalResult visit(StateOp currentState, EnableOp, StateOp nextState);
65 LogicalResult visit(StateOp currentState, IfOp, StateOp nextState);
66 LogicalResult visit(StateOp currentState, WhileOp, StateOp nextState);
71 struct StateScopeGuard {
73 StateScopeGuard(CompileFSMVisitor &visitor, StringRef name,
75 : visitor(visitor), name(name) {
76 visitor.stateStack.push_back(suffix.str());
79 assert(!visitor.stateStack.empty());
80 visitor.stateStack.pop_back();
83 StringRef
getName() {
return name; }
86 CompileFSMVisitor &visitor;
93 StateScopeGuard pushStateScope(StringRef suffix) {
95 llvm::raw_string_ostream ss(name);
97 stateStack, ss, [&](
const auto &it) { ss << it; },
"_");
98 ss <<
"_" << suffix.str();
99 return StateScopeGuard(*
this, ns.newName(name), suffix);
107 SmallVector<std::string, 4> stateStack;
111 SmallVector<Attribute, 8> compiledGroups;
114 LogicalResult CompileFSMVisitor::visit(StateOp currentState, IfOp ifOp,
116 auto stateGuard = pushStateScope(
"if");
117 auto loc = ifOp.getLoc();
120 graph.renameState(currentState, stateGuard.getName());
122 auto lowerBranch = [&](Value cond, StringRef nextStateSuffix,
bool invert,
123 Operation *innerBranchOp) {
124 auto branchStateGuard = pushStateScope(nextStateSuffix);
126 graph.createState(
builder, ifOp.getLoc(), branchStateGuard.getName())
129 auto transitionOp = graph
130 .createTransition(
builder, ifOp.getLoc(),
131 currentState, branchStateOp)
133 transitionOp.ensureGuard(
builder);
134 fsm::ReturnOp returnOp = transitionOp.getGuardReturn();
135 OpBuilder::InsertionGuard g(
builder);
136 builder.setInsertionPointToStart(&transitionOp.getGuard().front());
137 Value branchTaken = cond;
139 OpBuilder::InsertionGuard g(
builder);
143 returnOp.setOperand(branchTaken);
147 if (failed(dispatch(branchStateOp, innerBranchOp, nextState)))
153 if (failed(lowerBranch(ifOp.getCond(),
"then",
false,
154 &ifOp.getThenBody()->front())))
158 if (ifOp.elseBodyExists() &&
159 failed(lowerBranch(ifOp.getCond(),
"else",
true,
160 &ifOp.getElseBody()->front())))
166 LogicalResult CompileFSMVisitor::visit(StateOp currentState, SeqOp seqOp,
168 Location loc = seqOp.getLoc();
169 auto seqStateGuard = pushStateScope(
"seq");
172 auto &seqOps = seqOp.getBodyBlock()->getOperations();
173 llvm::SmallVector<std::pair<Operation *, StateOp>> seqStates;
177 StateOp currentOpNextState = nextState;
178 int n = seqOps.size() - 1;
179 for (
auto &op : llvm::reverse(*seqOp.getBodyBlock())) {
180 auto subStateGuard = pushStateScope(std::to_string(n--));
182 graph.createState(
builder, op.getLoc(), subStateGuard.getName())
184 seqStates.insert(seqStates.begin(), {&op, thisStateOp});
188 if (failed(dispatch(thisStateOp, &op, currentOpNextState)))
192 currentOpNextState = thisStateOp;
196 graph.createTransition(
builder, loc, currentState, seqStates.front().second);
201 LogicalResult CompileFSMVisitor::visit(StateOp currentState, WhileOp whileOp,
203 OpBuilder::InsertionGuard g(
builder);
204 auto whileStateGuard = pushStateScope(
"while");
205 auto loc = whileOp.getLoc();
209 StateOp whileHeaderState = currentState;
210 graph.renameState(whileHeaderState,
211 (whileStateGuard.getName() +
"_header").str());
216 auto whileBodyEntryState =
219 (whileStateGuard.getName() +
"_entry").str())
222 Operation *whileBodyOp = &whileOp.getBodyBlock()->front();
223 if (failed(dispatch(whileBodyEntryState, whileBodyOp, whileHeaderState)))
228 auto bodyTransition =
230 .createTransition(
builder, loc, whileHeaderState, whileBodyEntryState)
232 auto nextStateTransition =
233 graph.createTransition(
builder, loc, whileHeaderState, nextState)
236 bodyTransition.ensureGuard(
builder);
237 bodyTransition.getGuardReturn().setOperand(whileOp.getCond());
238 nextStateTransition.ensureGuard(
builder);
239 builder.setInsertionPoint(nextStateTransition.getGuardReturn());
240 nextStateTransition.getGuardReturn().setOperand(
245 LogicalResult CompileFSMVisitor::visit(StateOp currentState, EnableOp enableOp,
248 "Expected this enableOp to be nested into some provided state");
251 auto enableStateGuard = pushStateScope(enableOp.getGroupName());
252 graph.renameState(currentState, enableStateGuard.getName());
257 OpBuilder::InsertionGuard g(
builder);
258 builder.setInsertionPointToStart(¤tState.getOutput().front());
259 builder.create<calyx::EnableOp>(enableOp.getLoc(), enableOp.getGroupName());
262 graph.createTransition(
builder, enableOp.getLoc(), currentState, nextState);
265 compiledGroups.push_back(
273 class CompileInvoke {
275 CompileInvoke(ComponentOp component, OpBuilder
builder)
280 void lowerInvokeOp(InvokeOp invokeOp);
281 std::string getTransitionName(InvokeOp invokeOp);
282 ComponentOp component;
288 size_t transitionNameTail = 0;
292 void CompileInvoke::compile() {
293 llvm::SmallVector<InvokeOp> invokeOps =
294 component.getControlOp().getInvokeOps();
295 for (InvokeOp op : invokeOps)
300 std::string CompileInvoke::getTransitionName(InvokeOp invokeOp) {
301 llvm::StringRef callee = invokeOp.getCallee();
302 std::string transitionNameHead =
"invoke_" + callee.str() +
"_";
303 std::string transitionName;
308 transitionName = transitionNameHead + std::to_string(transitionNameTail++);
309 }
while (component.getWiresOp().lookupSymbol(transitionName));
310 return transitionName;
314 void CompileInvoke::lowerInvokeOp(InvokeOp invokeOp) {
316 Operation *prevNode = component.getWiresOp().getOperation()->getPrevNode();
317 builder.setInsertionPointAfter(prevNode);
318 hw::ConstantOp constantOp =
builder.create<hw::ConstantOp>(
319 prevNode->getLoc(),
builder.getI1Type(), 1);
320 Location loc = component.getWiresOp().getLoc();
323 builder.setInsertionPointToEnd(component.getWiresOp().getBodyBlock());
324 std::string transitionName = getTransitionName(invokeOp);
325 GroupOp groupOp =
builder.create<GroupOp>(loc, transitionName);
326 builder.setInsertionPointToStart(groupOp.getBodyBlock());
327 Value go = invokeOp.getInstGoValue();
330 builder.create<AssignOp>(loc, go, constantOp);
331 auto ports = invokeOp.getPorts();
332 auto inputs = invokeOp.getInputs();
335 for (
auto [port, input] : llvm::zip(ports,
inputs))
336 builder.create<AssignOp>(loc, port, input);
337 Value done = invokeOp.getInstDoneValue();
340 builder.create<calyx::GroupDoneOp>(loc, done);
341 builder.setInsertionPointAfter(invokeOp.getOperation());
342 builder.create<EnableOp>(invokeOp.getLoc(), transitionName);
346 class CalyxToFSMPass :
public CalyxToFSMBase<CalyxToFSMPass> {
348 void runOnOperation()
override;
351 void CalyxToFSMPass::runOnOperation() {
352 ComponentOp component = getOperation();
353 OpBuilder
builder(&getContext());
354 auto ctrlOp = component.getControlOp();
355 assert(ctrlOp.getBodyBlock()->getOperations().size() == 1 &&
356 "Expected a single top-level operation in the schedule");
357 CompileInvoke compileInvoke(component,
builder);
358 compileInvoke.compile();
359 Operation &topLevelCtrlOp = ctrlOp.getBodyBlock()->front();
360 builder.setInsertionPoint(&topLevelCtrlOp);
366 auto machineName = (
"control_" + component.getName()).str();
369 builder.create<MachineOp>(ctrlOp.getLoc(), machineName,
370 "fsm_entry", funcType);
371 auto graph = FSMGraph(machine);
384 auto visitor = CompileFSMVisitor(sc, graph);
385 if (failed(visitor.dispatch(entryState, &topLevelCtrlOp, exitState))) {
392 topLevelCtrlOp.erase();
403 return std::make_unique<CalyxToFSMPass>();
assert(baseType &&"element must be base type")
llvm::SmallVector< StringAttr > inputs
A namespace that is used to store existing names and generate new names in some scope within the IR.
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
void addSymbol(mlir::SymbolOpInterface op)
Adds the symbol-defining 'op' to the cache.
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
static constexpr std::string_view sExitStateName
static constexpr std::string_view sEntryStateName
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `‘Not’' gate on a value.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
std::unique_ptr< mlir::Pass > createCalyxToFSMPass()