CIRCT  20.0.0git
CalyxToFSM.cpp
Go to the documentation of this file.
1 //===- CalyxToFSM.cpp - Calyx to FSM conversion pass ----------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 
7 //===----------------------------------------------------------------------===//
8 //
9 // This is the main Calyx control to FSM Conversion Pass Implementation.
10 //
11 //===----------------------------------------------------------------------===//
12 
20 #include "mlir/Pass/Pass.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 
23 namespace circt {
24 #define GEN_PASS_DEF_CALYXTOFSM
25 #include "circt/Conversion/Passes.h.inc"
26 } // namespace circt
27 
28 using namespace mlir;
29 using namespace circt;
30 using namespace calyx;
31 using namespace fsm;
32 
33 namespace {
34 
35 class CompileFSMVisitor {
36 public:
37  CompileFSMVisitor(SymbolCache &sc, FSMGraph &graph)
38  : graph(graph), sc(sc), ctx(graph.getMachine().getContext()),
39  builder(graph.getMachine().getContext()) {
40  ns.add(sc);
41  }
42 
43  /// Lowers the provided 'op' into a new FSM StateOp.
44  LogicalResult dispatch(StateOp currentState, Operation *op,
45  StateOp nextState) {
46  return TypeSwitch<Operation *, LogicalResult>(op)
47  .template Case<SeqOp, EnableOp, IfOp, WhileOp>(
48  [&](auto opNode) { return visit(currentState, opNode, nextState); })
49  .Default([&](auto) {
50  return op->emitError() << "Operation '" << op->getName()
51  << "' not supported for FSM compilation";
52  });
53  }
54 
55  ArrayRef<Attribute> getCompiledGroups() { return compiledGroups; }
56 
57 private:
58  /// Operation visitors;
59  /// Apart from the visited operation, a visitor is provided with two extra
60  /// arguments:
61  /// currentState:
62  /// This represents a state which the callee has allocated to this visitor;
63  /// the visitor is free to use this state to its liking.
64  /// nextState:
65  /// This represent the next state which this visitor eventually must
66  /// transition to.
67  LogicalResult visit(StateOp currentState, SeqOp, StateOp nextState);
68  LogicalResult visit(StateOp currentState, EnableOp, StateOp nextState);
69  LogicalResult visit(StateOp currentState, IfOp, StateOp nextState);
70  LogicalResult visit(StateOp currentState, WhileOp, StateOp nextState);
71 
72  /// Represents unique state name scopes generated from pushing states onto
73  /// the state stack. The guard carries a unique name as well as managing the
74  /// lifetime of suffixes on the state stack.
75  struct StateScopeGuard {
76  public:
77  StateScopeGuard(CompileFSMVisitor &visitor, StringRef name,
78  StringRef suffix)
79  : visitor(visitor), name(name) {
80  visitor.stateStack.push_back(suffix.str());
81  }
82  ~StateScopeGuard() {
83  assert(!visitor.stateStack.empty());
84  visitor.stateStack.pop_back();
85  }
86 
87  StringRef getName() { return name; }
88 
89  private:
90  CompileFSMVisitor &visitor;
91  std::string name;
92  };
93 
94  /// Generates a new state name based on the current state stack and the
95  /// provided suffix. The new suffix is pushed onto the state stack. Returns a
96  /// guard object which pops the new suffix upon destruction.
97  StateScopeGuard pushStateScope(StringRef suffix) {
98  std::string name;
99  llvm::raw_string_ostream ss(name);
100  llvm::interleave(
101  stateStack, ss, [&](const auto &it) { ss << it; }, "_");
102  ss << "_" << suffix.str();
103  return StateScopeGuard(*this, ns.newName(name), suffix);
104  }
105 
106  FSMGraph &graph;
107  SymbolCache &sc;
108  MLIRContext *ctx;
109  OpBuilder builder;
110  Namespace ns;
111  SmallVector<std::string, 4> stateStack;
112 
113  /// Maintain the set of compiled groups within this FSM, to pass Calyx
114  /// verifiers.
115  SmallVector<Attribute, 8> compiledGroups;
116 };
117 
118 LogicalResult CompileFSMVisitor::visit(StateOp currentState, IfOp ifOp,
119  StateOp nextState) {
120  auto stateGuard = pushStateScope("if");
121  auto loc = ifOp.getLoc();
122 
123  // Rename the current state now that we know it's an if header.
124  graph.renameState(currentState, stateGuard.getName());
125 
126  auto lowerBranch = [&](Value cond, StringRef nextStateSuffix, bool invert,
127  Operation *innerBranchOp) {
128  auto branchStateGuard = pushStateScope(nextStateSuffix);
129  auto branchStateOp =
130  graph.createState(builder, ifOp.getLoc(), branchStateGuard.getName())
131  ->getState();
132 
133  auto transitionOp = graph
134  .createTransition(builder, ifOp.getLoc(),
135  currentState, branchStateOp)
136  ->getTransition();
137  transitionOp.ensureGuard(builder);
138  fsm::ReturnOp returnOp = transitionOp.getGuardReturn();
139  OpBuilder::InsertionGuard g(builder);
140  builder.setInsertionPointToStart(&transitionOp.getGuard().front());
141  Value branchTaken = cond;
142  if (invert) {
143  OpBuilder::InsertionGuard g(builder);
144  branchTaken = comb::createOrFoldNot(loc, branchTaken, builder);
145  }
146 
147  returnOp.setOperand(branchTaken);
148 
149  // Recurse into the body of the branch, with an exit state targeting
150  // 'nextState'.
151  if (failed(dispatch(branchStateOp, innerBranchOp, nextState)))
152  return failure();
153  return success();
154  };
155 
156  // Then branch.
157  if (failed(lowerBranch(ifOp.getCond(), "then", /*invert=*/false,
158  &ifOp.getThenBody()->front())))
159  return failure();
160 
161  // Else branch.
162  if (ifOp.elseBodyExists() &&
163  failed(lowerBranch(ifOp.getCond(), "else", /*invert=*/true,
164  &ifOp.getElseBody()->front())))
165  return failure();
166 
167  return success();
168 }
169 
170 LogicalResult CompileFSMVisitor::visit(StateOp currentState, SeqOp seqOp,
171  StateOp nextState) {
172  Location loc = seqOp.getLoc();
173  auto seqStateGuard = pushStateScope("seq");
174 
175  // Create a new state for each nested operation within this seqOp.
176  auto &seqOps = seqOp.getBodyBlock()->getOperations();
177  llvm::SmallVector<std::pair<Operation *, StateOp>> seqStates;
178 
179  // Iterate over the operations within the sequence. We do this in reverse
180  // order to ensure that we always know the next state.
181  StateOp currentOpNextState = nextState;
182  int n = seqOps.size() - 1;
183  for (auto &op : llvm::reverse(*seqOp.getBodyBlock())) {
184  auto subStateGuard = pushStateScope(std::to_string(n--));
185  auto thisStateOp =
186  graph.createState(builder, op.getLoc(), subStateGuard.getName())
187  ->getState();
188  seqStates.insert(seqStates.begin(), {&op, thisStateOp});
189  sc.addSymbol(thisStateOp);
190 
191  // Recurse into the current operation.
192  if (failed(dispatch(thisStateOp, &op, currentOpNextState)))
193  return failure();
194 
195  // This state is now the next state for the following operation.
196  currentOpNextState = thisStateOp;
197  }
198 
199  // Make 'currentState' transition directly the first state in the sequence.
200  graph.createTransition(builder, loc, currentState, seqStates.front().second);
201 
202  return success();
203 }
204 
205 LogicalResult CompileFSMVisitor::visit(StateOp currentState, WhileOp whileOp,
206  StateOp nextState) {
207  OpBuilder::InsertionGuard g(builder);
208  auto whileStateGuard = pushStateScope("while");
209  auto loc = whileOp.getLoc();
210 
211  // The current state is the while header (branch to whileOp or nextState).
212  // Rename the current state now that we know it's a while header state.
213  StateOp whileHeaderState = currentState;
214  graph.renameState(whileHeaderState,
215  (whileStateGuard.getName() + "_header").str());
216  sc.addSymbol(whileHeaderState);
217 
218  // Dispatch into the while body. The while body will always return to the
219  // header.
220  auto whileBodyEntryState =
221  graph
222  .createState(builder, loc,
223  (whileStateGuard.getName() + "_entry").str())
224  ->getState();
225  sc.addSymbol(whileBodyEntryState);
226  Operation *whileBodyOp = &whileOp.getBodyBlock()->front();
227  if (failed(dispatch(whileBodyEntryState, whileBodyOp, whileHeaderState)))
228  return failure();
229 
230  // Create transitions to either the while body or the next state based on the
231  // while condition.
232  auto bodyTransition =
233  graph
234  .createTransition(builder, loc, whileHeaderState, whileBodyEntryState)
235  ->getTransition();
236  auto nextStateTransition =
237  graph.createTransition(builder, loc, whileHeaderState, nextState)
238  ->getTransition();
239 
240  bodyTransition.ensureGuard(builder);
241  bodyTransition.getGuardReturn().setOperand(whileOp.getCond());
242  nextStateTransition.ensureGuard(builder);
243  builder.setInsertionPoint(nextStateTransition.getGuardReturn());
244  nextStateTransition.getGuardReturn().setOperand(
245  comb::createOrFoldNot(loc, whileOp.getCond(), builder));
246  return success();
247 }
248 
249 LogicalResult CompileFSMVisitor::visit(StateOp currentState, EnableOp enableOp,
250  StateOp nextState) {
251  assert(currentState &&
252  "Expected this enableOp to be nested into some provided state");
253 
254  // Rename the current state now that we know it's an enable state.
255  auto enableStateGuard = pushStateScope(enableOp.getGroupName());
256  graph.renameState(currentState, enableStateGuard.getName());
257 
258  // Create a new calyx.enable in the output state referencing the enabled
259  // group. We create a new op here as opposed to moving the existing, to make
260  // callers iterating over nested ops safer.
261  OpBuilder::InsertionGuard g(builder);
262  builder.setInsertionPointToStart(&currentState.getOutput().front());
263  builder.create<calyx::EnableOp>(enableOp.getLoc(), enableOp.getGroupName());
264 
265  if (nextState)
266  graph.createTransition(builder, enableOp.getLoc(), currentState, nextState);
267 
268  // Append this group to the set of compiled groups.
269  compiledGroups.push_back(
270  SymbolRefAttr::get(builder.getContext(), enableOp.getGroupName()));
271 
272  return success();
273 }
274 
275 // CompileInvoke is used to convert invoke operations to group operations and
276 // enable operations.
277 class CompileInvoke {
278 public:
279  CompileInvoke(ComponentOp component, OpBuilder builder)
280  : component(component), builder(builder) {}
281  void compile();
282 
283 private:
284  void lowerInvokeOp(InvokeOp invokeOp);
285  std::string getTransitionName(InvokeOp invokeOp);
286  ComponentOp component;
287  OpBuilder builder;
288  // Part of the group name. It is used to generate unique group names, the
289  // unique counter is reused across multiple calls to lowerInvokeOp, so the
290  // loop that's checking for name uniqueness usually finds a unique name on the
291  // first try.
292  size_t transitionNameTail = 0;
293 };
294 
295 // Access all invokeOp.
296 void CompileInvoke::compile() {
297  llvm::SmallVector<InvokeOp> invokeOps =
298  component.getControlOp().getInvokeOps();
299  for (InvokeOp op : invokeOps)
300  lowerInvokeOp(op);
301 }
302 
303 // Get the name of the generation group.
304 std::string CompileInvoke::getTransitionName(InvokeOp invokeOp) {
305  llvm::StringRef callee = invokeOp.getCallee();
306  std::string transitionNameHead = "invoke_" + callee.str() + "_";
307  std::string transitionName;
308 
309  // The following loop is used to check if the transitionName already exists.
310  // If it does, the loop regenerates the transitionName.
311  do {
312  transitionName = transitionNameHead + std::to_string(transitionNameTail++);
313  } while (component.getWiresOp().lookupSymbol(transitionName));
314  return transitionName;
315 }
316 
317 // Convert an invoke operation to a group operation and an enable operation.
318 void CompileInvoke::lowerInvokeOp(InvokeOp invokeOp) {
319  // Create a ConstantOp to assign a value to the go port.
320  Operation *prevNode = component.getWiresOp().getOperation()->getPrevNode();
321  builder.setInsertionPointAfter(prevNode);
322  hw::ConstantOp constantOp = builder.create<hw::ConstantOp>(
323  prevNode->getLoc(), builder.getI1Type(), 1);
324  Location loc = component.getWiresOp().getLoc();
325 
326  // Set the insertion point at the end of the wires block.
327  builder.setInsertionPointToEnd(component.getWiresOp().getBodyBlock());
328  std::string transitionName = getTransitionName(invokeOp);
329  GroupOp groupOp = builder.create<GroupOp>(loc, transitionName);
330  builder.setInsertionPointToStart(groupOp.getBodyBlock());
331  Value go = invokeOp.getInstGoValue();
332 
333  // Assign a value to the go port.
334  builder.create<AssignOp>(loc, go, constantOp);
335  auto ports = invokeOp.getPorts();
336  auto inputs = invokeOp.getInputs();
337 
338  // Generate a series of assignment operations from a list of parameters.
339  for (auto [port, input] : llvm::zip(ports, inputs))
340  builder.create<AssignOp>(loc, port, input);
341  Value done = invokeOp.getInstDoneValue();
342 
343  // Generate a group_done operation with the instance's done port.
344  builder.create<calyx::GroupDoneOp>(loc, done);
345  builder.setInsertionPointAfter(invokeOp.getOperation());
346  builder.create<EnableOp>(invokeOp.getLoc(), transitionName);
347  invokeOp.erase();
348 }
349 
350 class CalyxToFSMPass : public circt::impl::CalyxToFSMBase<CalyxToFSMPass> {
351 public:
352  void runOnOperation() override;
353 }; // end anonymous namespace
354 
355 void CalyxToFSMPass::runOnOperation() {
356  ComponentOp component = getOperation();
357  OpBuilder builder(&getContext());
358  auto ctrlOp = component.getControlOp();
359  assert(ctrlOp.getBodyBlock()->getOperations().size() == 1 &&
360  "Expected a single top-level operation in the schedule");
361  CompileInvoke compileInvoke(component, builder);
362  compileInvoke.compile();
363  Operation &topLevelCtrlOp = ctrlOp.getBodyBlock()->front();
364  builder.setInsertionPoint(&topLevelCtrlOp);
365 
366  // Create a side-effect-only FSM (no inputs, no outputs) which will strictly
367  // refer to the symbols and SSA values defined in the regions of the
368  // ComponentOp. This makes for an intermediate step, which allows for
369  // outlining the FSM (materializing FSM I/O) at a later point.
370  auto machineName = ("control_" + component.getName()).str();
371  auto funcType = FunctionType::get(&getContext(), {}, {});
372  auto machine =
373  builder.create<MachineOp>(ctrlOp.getLoc(), machineName,
374  /*initialState=*/"fsm_entry", funcType);
375  auto graph = FSMGraph(machine);
376 
377  SymbolCache sc;
378  sc.addDefinitions(machine);
379 
380  // Create entry and exit states
381  auto entryState =
382  graph.createState(builder, ctrlOp.getLoc(), calyxToFSM::sEntryStateName)
383  ->getState();
384  auto exitState =
385  graph.createState(builder, ctrlOp.getLoc(), calyxToFSM::sExitStateName)
386  ->getState();
387 
388  auto visitor = CompileFSMVisitor(sc, graph);
389  if (failed(visitor.dispatch(entryState, &topLevelCtrlOp, exitState))) {
390  signalPassFailure();
391  return;
392  }
393 
394  // Remove the top-level calyx control operation that we've now converted to an
395  // FSM.
396  topLevelCtrlOp.erase();
397 
398  // Add the set of compiled groups as an attribute to the fsm.
399  machine->setAttr(
400  "compiledGroups",
401  ArrayAttr::get(builder.getContext(), visitor.getCompiledGroups()));
402 }
403 
404 } // namespace
405 
406 std::unique_ptr<mlir::Pass> circt::createCalyxToFSMPass() {
407  return std::make_unique<CalyxToFSMPass>();
408 }
assert(baseType &&"element must be base type")
A namespace that is used to store existing names and generate new names in some scope within the IR.
Definition: Namespace.h:30
void addDefinitions(mlir::Operation *top)
Populate the symbol cache with all symbol-defining operations within the 'top' operation.
Definition: SymCache.cpp:23
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition: SymCache.h:85
def create(data_type, value)
Definition: hw.py:433
static constexpr std::string_view sExitStateName
Definition: CalyxToFSM.h:36
static constexpr std::string_view sEntryStateName
Definition: CalyxToFSM.h:35
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:55
Value createOrFoldNot(Location loc, Value value, OpBuilder &builder, bool twoState=false)
Create a `Not` gate on a value.
Definition: CombOps.cpp:48
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
std::unique_ptr< mlir::Pass > createCalyxToFSMPass()
Definition: CalyxToFSM.cpp:406
Definition: fsm.py:1