CIRCT 20.0.0git
Loading...
Searching...
No Matches
CalyxToFSM.cpp
Go to the documentation of this file.
1//===- CalyxToFSM.cpp - Calyx to FSM conversion pass ----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7//===----------------------------------------------------------------------===//
8//
9// This is the main Calyx control to FSM Conversion Pass Implementation.
10//
11//===----------------------------------------------------------------------===//
12
20#include "mlir/Pass/Pass.h"
21#include "llvm/ADT/TypeSwitch.h"
22
23namespace circt {
24#define GEN_PASS_DEF_CALYXTOFSM
25#include "circt/Conversion/Passes.h.inc"
26} // namespace circt
27
28using namespace mlir;
29using namespace circt;
30using namespace calyx;
31using namespace fsm;
32
33namespace {
34
35class CompileFSMVisitor {
36public:
37 CompileFSMVisitor(SymbolCache &sc, FSMGraph &graph)
38 : graph(graph), sc(sc), ctx(graph.getMachine().getContext()),
39 builder(graph.getMachine().getContext()) {
40 ns.add(sc);
41 }
42
43 /// Lowers the provided 'op' into a new FSM StateOp.
44 LogicalResult dispatch(StateOp currentState, Operation *op,
45 StateOp nextState) {
46 return TypeSwitch<Operation *, LogicalResult>(op)
47 .template Case<SeqOp, EnableOp, IfOp, WhileOp>(
48 [&](auto opNode) { return visit(currentState, opNode, nextState); })
49 .Default([&](auto) {
50 return op->emitError() << "Operation '" << op->getName()
51 << "' not supported for FSM compilation";
52 });
53 }
54
55 ArrayRef<Attribute> getCompiledGroups() { return compiledGroups; }
56
57private:
58 /// Operation visitors;
59 /// Apart from the visited operation, a visitor is provided with two extra
60 /// arguments:
61 /// currentState:
62 /// This represents a state which the callee has allocated to this visitor;
63 /// the visitor is free to use this state to its liking.
64 /// nextState:
65 /// This represent the next state which this visitor eventually must
66 /// transition to.
67 LogicalResult visit(StateOp currentState, SeqOp, StateOp nextState);
68 LogicalResult visit(StateOp currentState, EnableOp, StateOp nextState);
69 LogicalResult visit(StateOp currentState, IfOp, StateOp nextState);
70 LogicalResult visit(StateOp currentState, WhileOp, StateOp nextState);
71
72 /// Represents unique state name scopes generated from pushing states onto
73 /// the state stack. The guard carries a unique name as well as managing the
74 /// lifetime of suffixes on the state stack.
75 struct StateScopeGuard {
76 public:
77 StateScopeGuard(CompileFSMVisitor &visitor, StringRef name,
78 StringRef suffix)
79 : visitor(visitor), name(name) {
80 visitor.stateStack.push_back(suffix.str());
81 }
82 ~StateScopeGuard() {
83 assert(!visitor.stateStack.empty());
84 visitor.stateStack.pop_back();
85 }
86
87 StringRef getName() { return name; }
88
89 private:
90 CompileFSMVisitor &visitor;
91 std::string name;
92 };
93
94 /// Generates a new state name based on the current state stack and the
95 /// provided suffix. The new suffix is pushed onto the state stack. Returns a
96 /// guard object which pops the new suffix upon destruction.
97 StateScopeGuard pushStateScope(StringRef suffix) {
98 std::string name;
99 llvm::raw_string_ostream ss(name);
100 llvm::interleave(
101 stateStack, ss, [&](const auto &it) { ss << it; }, "_");
102 ss << "_" << suffix.str();
103 return StateScopeGuard(*this, ns.newName(name), suffix);
104 }
105
106 FSMGraph &graph;
107 SymbolCache &sc;
108 MLIRContext *ctx;
109 OpBuilder builder;
110 Namespace ns;
111 SmallVector<std::string, 4> stateStack;
112
113 /// Maintain the set of compiled groups within this FSM, to pass Calyx
114 /// verifiers.
115 SmallVector<Attribute, 8> compiledGroups;
116};
117
118LogicalResult CompileFSMVisitor::visit(StateOp currentState, IfOp ifOp,
119 StateOp nextState) {
120 auto stateGuard = pushStateScope("if");
121 auto loc = ifOp.getLoc();
122
123 // Rename the current state now that we know it's an if header.
124 graph.renameState(currentState, stateGuard.getName());
125
126 auto lowerBranch = [&](Value cond, StringRef nextStateSuffix, bool invert,
127 Operation *innerBranchOp) {
128 auto branchStateGuard = pushStateScope(nextStateSuffix);
129 auto branchStateOp =
130 graph.createState(builder, ifOp.getLoc(), branchStateGuard.getName())
131 ->getState();
132
133 auto transitionOp = graph
134 .createTransition(builder, ifOp.getLoc(),
135 currentState, branchStateOp)
136 ->getTransition();
137 transitionOp.ensureGuard(builder);
138 fsm::ReturnOp returnOp = transitionOp.getGuardReturn();
139 OpBuilder::InsertionGuard g(builder);
140 builder.setInsertionPointToStart(&transitionOp.getGuard().front());
141 Value branchTaken = cond;
142 if (invert) {
143 OpBuilder::InsertionGuard g(builder);
144 branchTaken = comb::createOrFoldNot(loc, branchTaken, builder);
145 }
146
147 returnOp.setOperand(branchTaken);
148
149 // Recurse into the body of the branch, with an exit state targeting
150 // 'nextState'.
151 if (failed(dispatch(branchStateOp, innerBranchOp, nextState)))
152 return failure();
153 return success();
154 };
155
156 // Then branch.
157 if (failed(lowerBranch(ifOp.getCond(), "then", /*invert=*/false,
158 &ifOp.getThenBody()->front())))
159 return failure();
160
161 // Else branch.
162 if (ifOp.elseBodyExists() &&
163 failed(lowerBranch(ifOp.getCond(), "else", /*invert=*/true,
164 &ifOp.getElseBody()->front())))
165 return failure();
166
167 return success();
168}
169
170LogicalResult CompileFSMVisitor::visit(StateOp currentState, SeqOp seqOp,
171 StateOp nextState) {
172 Location loc = seqOp.getLoc();
173 auto seqStateGuard = pushStateScope("seq");
174
175 // Create a new state for each nested operation within this seqOp.
176 auto &seqOps = seqOp.getBodyBlock()->getOperations();
177 llvm::SmallVector<std::pair<Operation *, StateOp>> seqStates;
178
179 // Iterate over the operations within the sequence. We do this in reverse
180 // order to ensure that we always know the next state.
181 StateOp currentOpNextState = nextState;
182 int n = seqOps.size() - 1;
183 for (auto &op : llvm::reverse(*seqOp.getBodyBlock())) {
184 auto subStateGuard = pushStateScope(std::to_string(n--));
185 auto thisStateOp =
186 graph.createState(builder, op.getLoc(), subStateGuard.getName())
187 ->getState();
188 seqStates.insert(seqStates.begin(), {&op, thisStateOp});
189 sc.addSymbol(thisStateOp);
190
191 // Recurse into the current operation.
192 if (failed(dispatch(thisStateOp, &op, currentOpNextState)))
193 return failure();
194
195 // This state is now the next state for the following operation.
196 currentOpNextState = thisStateOp;
197 }
198
199 // Make 'currentState' transition directly the first state in the sequence.
200 graph.createTransition(builder, loc, currentState, seqStates.front().second);
201
202 return success();
203}
204
205LogicalResult CompileFSMVisitor::visit(StateOp currentState, WhileOp whileOp,
206 StateOp nextState) {
207 OpBuilder::InsertionGuard g(builder);
208 auto whileStateGuard = pushStateScope("while");
209 auto loc = whileOp.getLoc();
210
211 // The current state is the while header (branch to whileOp or nextState).
212 // Rename the current state now that we know it's a while header state.
213 StateOp whileHeaderState = currentState;
214 graph.renameState(whileHeaderState,
215 (whileStateGuard.getName() + "_header").str());
216 sc.addSymbol(whileHeaderState);
217
218 // Dispatch into the while body. The while body will always return to the
219 // header.
220 auto whileBodyEntryState =
221 graph
222 .createState(builder, loc,
223 (whileStateGuard.getName() + "_entry").str())
224 ->getState();
225 sc.addSymbol(whileBodyEntryState);
226 Operation *whileBodyOp = &whileOp.getBodyBlock()->front();
227 if (failed(dispatch(whileBodyEntryState, whileBodyOp, whileHeaderState)))
228 return failure();
229
230 // Create transitions to either the while body or the next state based on the
231 // while condition.
232 auto bodyTransition =
233 graph
234 .createTransition(builder, loc, whileHeaderState, whileBodyEntryState)
235 ->getTransition();
236 auto nextStateTransition =
237 graph.createTransition(builder, loc, whileHeaderState, nextState)
238 ->getTransition();
239
240 bodyTransition.ensureGuard(builder);
241 bodyTransition.getGuardReturn().setOperand(whileOp.getCond());
242 nextStateTransition.ensureGuard(builder);
243 builder.setInsertionPoint(nextStateTransition.getGuardReturn());
244 nextStateTransition.getGuardReturn().setOperand(
245 comb::createOrFoldNot(loc, whileOp.getCond(), builder));
246 return success();
247}
248
249LogicalResult CompileFSMVisitor::visit(StateOp currentState, EnableOp enableOp,
250 StateOp nextState) {
251 assert(currentState &&
252 "Expected this enableOp to be nested into some provided state");
253
254 // Rename the current state now that we know it's an enable state.
255 auto enableStateGuard = pushStateScope(enableOp.getGroupName());
256 graph.renameState(currentState, enableStateGuard.getName());
257
258 // Create a new calyx.enable in the output state referencing the enabled
259 // group. We create a new op here as opposed to moving the existing, to make
260 // callers iterating over nested ops safer.
261 OpBuilder::InsertionGuard g(builder);
262 builder.setInsertionPointToStart(&currentState.getOutput().front());
263 builder.create<calyx::EnableOp>(enableOp.getLoc(), enableOp.getGroupName());
264
265 if (nextState)
266 graph.createTransition(builder, enableOp.getLoc(), currentState, nextState);
267
268 // Append this group to the set of compiled groups.
269 compiledGroups.push_back(
270 SymbolRefAttr::get(builder.getContext(), enableOp.getGroupName()));
271
272 return success();
273}
274
275// CompileInvoke is used to convert invoke operations to group operations and
276// enable operations.
277class CompileInvoke {
278public:
279 CompileInvoke(ComponentOp component, OpBuilder builder)
280 : component(component), builder(builder) {}
281 void compile();
282
283private:
284 void lowerInvokeOp(InvokeOp invokeOp);
285 std::string getTransitionName(InvokeOp invokeOp);
286 ComponentOp component;
287 OpBuilder builder;
288 // Part of the group name. It is used to generate unique group names, the
289 // unique counter is reused across multiple calls to lowerInvokeOp, so the
290 // loop that's checking for name uniqueness usually finds a unique name on the
291 // first try.
292 size_t transitionNameTail = 0;
293};
294
295// Access all invokeOp.
296void CompileInvoke::compile() {
297 llvm::SmallVector<InvokeOp> invokeOps =
298 component.getControlOp().getInvokeOps();
299 for (InvokeOp op : invokeOps)
300 lowerInvokeOp(op);
301}
302
303// Get the name of the generation group.
304std::string CompileInvoke::getTransitionName(InvokeOp invokeOp) {
305 llvm::StringRef callee = invokeOp.getCallee();
306 std::string transitionNameHead = "invoke_" + callee.str() + "_";
307 std::string transitionName;
308
309 // The following loop is used to check if the transitionName already exists.
310 // If it does, the loop regenerates the transitionName.
311 do {
312 transitionName = transitionNameHead + std::to_string(transitionNameTail++);
313 } while (component.getWiresOp().lookupSymbol(transitionName));
314 return transitionName;
315}
316
317// Convert an invoke operation to a group operation and an enable operation.
318void CompileInvoke::lowerInvokeOp(InvokeOp invokeOp) {
319 // Create a ConstantOp to assign a value to the go port.
320 Operation *prevNode = component.getWiresOp().getOperation()->getPrevNode();
321 builder.setInsertionPointAfter(prevNode);
322 hw::ConstantOp constantOp = builder.create<hw::ConstantOp>(
323 prevNode->getLoc(), builder.getI1Type(), 1);
324 Location loc = component.getWiresOp().getLoc();
325
326 // Set the insertion point at the end of the wires block.
327 builder.setInsertionPointToEnd(component.getWiresOp().getBodyBlock());
328 std::string transitionName = getTransitionName(invokeOp);
329 GroupOp groupOp = builder.create<GroupOp>(loc, transitionName);
330 builder.setInsertionPointToStart(groupOp.getBodyBlock());
331 Value go = invokeOp.getInstGoValue();
332
333 // Assign a value to the go port.
334 builder.create<AssignOp>(loc, go, constantOp);
335 auto ports = invokeOp.getPorts();
336 auto inputs = invokeOp.getInputs();
337
338 // Generate a series of assignment operations from a list of parameters.
339 for (auto [port, input] : llvm::zip(ports, inputs))
340 builder.create<AssignOp>(loc, port, input);
341 Value done = invokeOp.getInstDoneValue();
342
343 // Generate a group_done operation with the instance's done port.
344 builder.create<calyx::GroupDoneOp>(loc, done);
345 builder.setInsertionPointAfter(invokeOp.getOperation());
346 builder.create<EnableOp>(invokeOp.getLoc(), transitionName);
347 invokeOp.erase();
348}
349
350class CalyxToFSMPass : public circt::impl::CalyxToFSMBase<CalyxToFSMPass> {
351public:
352 void runOnOperation() override;
353}; // end anonymous namespace
354
355void CalyxToFSMPass::runOnOperation() {
356 ComponentOp component = getOperation();
357 OpBuilder builder(&getContext());
358 auto ctrlOp = component.getControlOp();
359 assert(ctrlOp.getBodyBlock()->getOperations().size() == 1 &&
360 "Expected a single top-level operation in the schedule");
361 CompileInvoke compileInvoke(component, builder);
362 compileInvoke.compile();
363 Operation &topLevelCtrlOp = ctrlOp.getBodyBlock()->front();
364 builder.setInsertionPoint(&topLevelCtrlOp);
365
366 // Create a side-effect-only FSM (no inputs, no outputs) which will strictly
367 // refer to the symbols and SSA values defined in the regions of the
368 // ComponentOp. This makes for an intermediate step, which allows for
369 // outlining the FSM (materializing FSM I/O) at a later point.
370 auto machineName = ("control_" + component.getName()).str();
371 auto funcType = FunctionType::get(&getContext(), {}, {});
372 auto machine =
373 builder.create<MachineOp>(ctrlOp.getLoc(), machineName,
374 /*initialState=*/"fsm_entry", funcType);
375 auto graph = FSMGraph(machine);
376
377 SymbolCache sc;
378 sc.addDefinitions(machine);
379
380 // Create entry and exit states
381 auto entryState =
382 graph.createState(builder, ctrlOp.getLoc(), calyxToFSM::sEntryStateName)
383 ->getState();
384 auto exitState =
385 graph.createState(builder, ctrlOp.getLoc(), calyxToFSM::sExitStateName)
386 ->getState();
387
388 auto visitor = CompileFSMVisitor(sc, graph);
389 if (failed(visitor.dispatch(entryState, &topLevelCtrlOp, exitState))) {
390 signalPassFailure();
391 return;
392 }
393
394 // Remove the top-level calyx control operation that we've now converted to an
395 // FSM.
396 topLevelCtrlOp.erase();
397
398 // Add the set of compiled groups as an attribute to the fsm.
399 machine->setAttr(
400 "compiledGroups",
401 ArrayAttr::get(builder.getContext(), visitor.getCompiledGroups()));
402}
403
404} // namespace
405
406std::unique_ptr<mlir::Pass> circt::createCalyxToFSMPass() {
407 return std::make_unique<CalyxToFSMPass>();
408}
assert(baseType &&"element must be base type")
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition Namespace.h:30
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
create(data_type, value)
Definition hw.py:433
static constexpr std::string_view sExitStateName
Definition CalyxToFSM.h:36
static constexpr std::string_view sEntryStateName
Definition CalyxToFSM.h:35
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createCalyxToFSMPass()
Definition fsm.py:1