CIRCT  19.0.0git
FSMOps.cpp
Go to the documentation of this file.
1 //===- FSMOps.cpp - Implementation of FSM dialect operations --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/FunctionImplementation.h"
16 #include "llvm/Support/FormatVariadic.h"
17 
18 using namespace mlir;
19 using namespace circt;
20 using namespace fsm;
21 
22 //===----------------------------------------------------------------------===//
23 // MachineOp
24 //===----------------------------------------------------------------------===//
25 
26 void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27  StringRef initialStateName, FunctionType type,
28  ArrayRef<NamedAttribute> attrs,
29  ArrayRef<DictionaryAttr> argAttrs) {
30  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31  builder.getStringAttr(name));
32  state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
33  TypeAttr::get(type));
34  state.addAttribute("initialState",
35  StringAttr::get(state.getContext(), initialStateName));
36  state.attributes.append(attrs.begin(), attrs.end());
37  Region *region = state.addRegion();
38  Block *body = new Block();
39  region->push_back(body);
40  body->addArguments(
41  type.getInputs(),
42  SmallVector<Location, 4>(type.getNumInputs(), builder.getUnknownLoc()));
43 
44  if (argAttrs.empty())
45  return;
46  assert(type.getNumInputs() == argAttrs.size());
47  function_interface_impl::addArgAndResultAttrs(
48  builder, state, argAttrs,
49  /*resultAttrs=*/std::nullopt, MachineOp::getArgAttrsAttrName(state.name),
50  MachineOp::getResAttrsAttrName(state.name));
51 }
52 
53 /// Get the initial state of the machine.
54 StateOp MachineOp::getInitialStateOp() {
55  return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
56 }
57 
58 StringAttr MachineOp::getArgName(size_t i) {
59  if (auto args = getArgNames())
60  return cast<StringAttr>((*args)[i]);
61 
62  return StringAttr::get(getContext(), "in" + std::to_string(i));
63 }
64 
65 StringAttr MachineOp::getResName(size_t i) {
66  if (auto resNameAttrs = getResNames())
67  return cast<StringAttr>((*resNameAttrs)[i]);
68 
69  return StringAttr::get(getContext(), "out" + std::to_string(i));
70 }
71 
72 /// Get the port information of the machine.
73 void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
74  ports.clear();
75  auto machineType = getFunctionType();
76  for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
77  hw::PortInfo port;
78  port.name = getArgName(i);
79  if (!port.name)
80  port.name = StringAttr::get(getContext(), "in" + std::to_string(i));
82  port.type = machineType.getInput(i);
83  port.argNum = i;
84  ports.push_back(port);
85  }
86 
87  for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
88  hw::PortInfo port;
89  port.name = getResName(i);
90  if (!port.name)
91  port.name = StringAttr::get(getContext(), "out" + std::to_string(i));
93  port.type = machineType.getResult(i);
94  port.argNum = i;
95  ports.push_back(port);
96  }
97 }
98 
99 ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
100  auto buildFuncType =
101  [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
102  function_interface_impl::VariadicFlag,
103  std::string &) { return builder.getFunctionType(argTypes, results); };
104 
105  return function_interface_impl::parseFunctionOp(
106  parser, result, /*allowVariadic=*/false,
107  MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
108  MachineOp::getArgAttrsAttrName(result.name),
109  MachineOp::getResAttrsAttrName(result.name));
110 }
111 
112 void MachineOp::print(OpAsmPrinter &p) {
113  function_interface_impl::printFunctionOp(
114  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
115  getArgAttrsAttrName(), getResAttrsAttrName());
116 }
117 
118 static LogicalResult compareTypes(Location loc, TypeRange rangeA,
119  TypeRange rangeB) {
120  if (rangeA.size() != rangeB.size())
121  return emitError(loc) << "mismatch in number of types compared ("
122  << rangeA.size() << " != " << rangeB.size() << ")";
123 
124  size_t index = 0;
125  for (auto zip : llvm::zip(rangeA, rangeB)) {
126  auto typeA = std::get<0>(zip);
127  auto typeB = std::get<1>(zip);
128  if (typeA != typeB)
129  return emitError(loc) << "type mismatch at index " << index << " ("
130  << typeA << " != " << typeB << ")";
131  ++index;
132  }
133 
134  return success();
135 }
136 
137 LogicalResult MachineOp::verify() {
138  // If this function is external there is nothing to do.
139  if (isExternal())
140  return success();
141 
142  // Verify that the argument list of the function and the arg list of the entry
143  // block line up. The trait already verified that the number of arguments is
144  // the same between the signature and the block.
145  if (failed(compareTypes(getLoc(), getArgumentTypes(),
146  front().getArgumentTypes())))
147  return emitOpError(
148  "entry block argument types must match the machine input types");
149 
150  // Verify that the machine only has one block terminated with OutputOp.
151  if (!llvm::hasSingleElement(*this))
152  return emitOpError("must only have a single block");
153 
154  // Verify that the initial state exists
155  if (!getInitialStateOp())
156  return emitOpError("initial state '" + getInitialState() +
157  "' was not defined in the machine");
158 
159  if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
160  return emitOpError() << "number of machine arguments ("
161  << getArgumentTypes().size()
162  << ") does "
163  "not match the provided number "
164  "of argument names ("
165  << getArgNames()->size() << ")";
166 
167  if (getResNames() && getResNames()->size() != getResultTypes().size())
168  return emitOpError() << "number of machine results ("
169  << getResultTypes().size()
170  << ") does "
171  "not match the provided number "
172  "of result names ("
173  << getResNames()->size() << ")";
174 
175  return success();
176 }
177 
178 SmallVector<::circt::hw::PortInfo> MachineOp::getPortList() {
179  SmallVector<hw::PortInfo> ports;
180  auto argNames = getArgNames();
181  auto argTypes = getFunctionType().getInputs();
182  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
183  bool isInOut = false;
184  auto type = argTypes[i];
185 
186  if (auto inout = dyn_cast<hw::InOutType>(type)) {
187  isInOut = true;
188  type = inout.getElementType();
189  }
190 
191  auto direction = isInOut ? hw::ModulePort::Direction::InOut
193 
194  ports.push_back(
195  {{argNames ? cast<StringAttr>((*argNames)[i])
196  : StringAttr::get(getContext(), Twine("input") + Twine(i)),
197  type, direction},
198  i,
199  {},
200  {}});
201  }
202 
203  auto resultNames = getResNames();
204  auto resultTypes = getFunctionType().getResults();
205  for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
206  ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
207  : StringAttr::get(getContext(),
208  Twine("output") + Twine(i)),
209  resultTypes[i], hw::ModulePort::Direction::Output},
210  i,
211  {},
212  {}});
213  }
214  return ports;
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // InstanceOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Lookup the machine for the symbol. This returns null on invalid IR.
222 MachineOp InstanceOp::getMachineOp() {
223  auto module = (*this)->getParentOfType<ModuleOp>();
224  return module.lookupSymbol<MachineOp>(getMachine());
225 }
226 
227 LogicalResult InstanceOp::verify() {
228  auto m = getMachineOp();
229  if (!m)
230  return emitError("cannot find machine definition '") << getMachine() << "'";
231 
232  return success();
233 }
234 
236  function_ref<void(Value, StringRef)> setNameFn) {
237  setNameFn(getInstance(), getName());
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // TriggerOp
242 //===----------------------------------------------------------------------===//
243 
244 template <typename OpType>
245 static LogicalResult verifyCallerTypes(OpType op) {
246  auto machine = op.getMachineOp();
247  if (!machine)
248  return op.emitError("cannot find machine definition");
249 
250  // Check operand types first.
251  if (failed(compareTypes(op.getLoc(), machine.getArgumentTypes(),
252  op.getInputs().getTypes()))) {
253  auto diag =
254  op.emitOpError("operand types must match the machine input types");
255  diag.attachNote(machine->getLoc()) << "original machine declared here";
256  return failure();
257  }
258 
259  // Check result types.
260  if (failed(compareTypes(op.getLoc(), machine.getResultTypes(),
261  op.getOutputs().getTypes()))) {
262  auto diag =
263  op.emitOpError("result types must match the machine output types");
264  diag.attachNote(machine->getLoc()) << "original machine declared here";
265  return failure();
266  }
267 
268  return success();
269 }
270 
271 /// Lookup the machine for the symbol. This returns null on invalid IR.
272 MachineOp TriggerOp::getMachineOp() {
273  auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
274  if (!instanceOp)
275  return nullptr;
276 
277  return instanceOp.getMachineOp();
278 }
279 
280 LogicalResult TriggerOp::verify() { return verifyCallerTypes(*this); }
281 
282 //===----------------------------------------------------------------------===//
283 // HWInstanceOp
284 //===----------------------------------------------------------------------===//
285 
286 // InstanceOpInterface interface
287 
288 /// Lookup the machine for the symbol. This returns null on invalid IR.
289 MachineOp HWInstanceOp::getMachineOp() {
290  auto module = (*this)->getParentOfType<ModuleOp>();
291  return module.lookupSymbol<MachineOp>(getMachine());
292 }
293 
294 LogicalResult HWInstanceOp::verify() { return verifyCallerTypes(*this); }
295 
296 SmallVector<hw::PortInfo> HWInstanceOp::getPortList() {
297  return getMachineOp().getPortList();
298 }
299 
300 /// Module name is the same as the machine name.
301 StringRef HWInstanceOp::getModuleName() { return getMachine(); }
302 FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() { return getMachineAttr(); }
303 
304 mlir::StringAttr HWInstanceOp::getInstanceNameAttr() { return getNameAttr(); }
305 
306 llvm::StringRef HWInstanceOp::getInstanceName() { return getName(); }
307 
308 //===----------------------------------------------------------------------===//
309 // StateOp
310 //===----------------------------------------------------------------------===//
311 
312 void StateOp::build(OpBuilder &builder, OperationState &state,
313  StringRef stateName) {
314  OpBuilder::InsertionGuard guard(builder);
315  Region *output = state.addRegion();
316  output->push_back(new Block());
317  builder.setInsertionPointToEnd(&output->back());
318  builder.create<fsm::OutputOp>(state.location);
319  Region *transitions = state.addRegion();
320  transitions->push_back(new Block());
321  state.addAttribute("sym_name", builder.getStringAttr(stateName));
322 }
323 
324 SetVector<StateOp> StateOp::getNextStates() {
325  SmallVector<StateOp> nextStates;
326  llvm::transform(
327  getTransitions().getOps<TransitionOp>(),
328  std::inserter(nextStates, nextStates.begin()),
329  [](TransitionOp transition) { return transition.getNextStateOp(); });
330  return SetVector<StateOp>(nextStates.begin(), nextStates.end());
331 }
332 
333 LogicalResult StateOp::canonicalize(StateOp op, PatternRewriter &rewriter) {
334  bool hasAlwaysTakenTransition = false;
335  SmallVector<TransitionOp, 4> transitionsToErase;
336  // Remove all transitions after an "always-taken" transition.
337  for (auto transition : op.getTransitions().getOps<TransitionOp>()) {
338  if (!hasAlwaysTakenTransition)
339  hasAlwaysTakenTransition = transition.isAlwaysTaken();
340  else
341  transitionsToErase.push_back(transition);
342  }
343 
344  for (auto transition : transitionsToErase)
345  rewriter.eraseOp(transition);
346 
347  return failure(transitionsToErase.empty());
348 }
349 
350 LogicalResult StateOp::verify() {
351  MachineOp parent = getOperation()->getParentOfType<MachineOp>();
352 
353  if (parent.getNumResults() != 0 && (getOutput().empty()))
354  return emitOpError("state must have a non-empty output region when the "
355  "machine has results.");
356 
357  if (!getOutput().empty()) {
358  // Ensure that the output block has a single OutputOp terminator.
359  Block *outputBlock = &getOutput().front();
360  if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
361  return emitOpError("output block must have a single OutputOp terminator");
362  }
363 
364  return success();
365 }
366 
367 Block *StateOp::ensureOutput(OpBuilder &builder) {
368  if (getOutput().empty()) {
369  OpBuilder::InsertionGuard g(builder);
370  auto *block = new Block();
371  getOutput().push_back(block);
372  builder.setInsertionPointToStart(block);
373  builder.create<fsm::OutputOp>(getLoc());
374  }
375  return &getOutput().front();
376 }
377 
378 //===----------------------------------------------------------------------===//
379 // OutputOp
380 //===----------------------------------------------------------------------===//
381 
382 LogicalResult OutputOp::verify() {
383  if ((*this)->getParentRegion() ==
384  &(*this)->getParentOfType<StateOp>().getTransitions()) {
385  if (getNumOperands() != 0)
386  emitOpError("transitions region must not output any value");
387  return success();
388  }
389 
390  // Verify that the result list of the machine and the operand list of the
391  // OutputOp line up.
392  auto machine = (*this)->getParentOfType<MachineOp>();
393  if (failed(
394  compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
395  return emitOpError("operand types must match the machine output types");
396 
397  return success();
398 }
399 
400 //===----------------------------------------------------------------------===//
401 // TransitionOp
402 //===----------------------------------------------------------------------===//
403 
404 void TransitionOp::build(OpBuilder &builder, OperationState &state,
405  StringRef nextState) {
406  state.addRegion(); // guard
407  state.addRegion(); // action
408  state.addAttribute("nextState",
409  FlatSymbolRefAttr::get(builder.getStringAttr(nextState)));
410 }
411 
412 void TransitionOp::build(OpBuilder &builder, OperationState &state,
413  StateOp nextState) {
414  build(builder, state, nextState.getName());
415 }
416 
417 Block *TransitionOp::ensureGuard(OpBuilder &builder) {
418  if (getGuard().empty()) {
419  OpBuilder::InsertionGuard g(builder);
420  auto *block = new Block();
421  getGuard().push_back(block);
422  builder.setInsertionPointToStart(block);
423  builder.create<fsm::ReturnOp>(getLoc());
424  }
425  return &getGuard().front();
426 }
427 
428 Block *TransitionOp::ensureAction(OpBuilder &builder) {
429  if (getAction().empty())
430  getAction().push_back(new Block());
431  return &getAction().front();
432 }
433 
434 /// Lookup the next state for the symbol. This returns null on invalid IR.
435 StateOp TransitionOp::getNextStateOp() {
436  auto machineOp = (*this)->getParentOfType<MachineOp>();
437  if (!machineOp)
438  return nullptr;
439 
440  return machineOp.lookupSymbol<StateOp>(getNextState());
441 }
442 
443 bool TransitionOp::isAlwaysTaken() {
444  if (!hasGuard())
445  return true;
446 
447  auto guardReturn = getGuardReturn();
448  if (guardReturn.getNumOperands() == 0)
449  return true;
450 
451  if (auto constantOp =
452  guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
453  return cast<BoolAttr>(constantOp.getValue()).getValue();
454 
455  return false;
456 }
457 
458 LogicalResult TransitionOp::canonicalize(TransitionOp op,
459  PatternRewriter &rewriter) {
460  if (op.hasGuard()) {
461  auto guardReturn = op.getGuardReturn();
462  if (guardReturn.getNumOperands() == 1)
463  if (auto constantOp = guardReturn.getOperand()
464  .getDefiningOp<mlir::arith::ConstantOp>()) {
465  // Simplify when the guard region returns a constant value.
466  if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
467  // Replace the original return op with a new one without any operands
468  // if the constant is TRUE.
469  rewriter.setInsertionPoint(guardReturn);
470  rewriter.create<fsm::ReturnOp>(guardReturn.getLoc());
471  rewriter.eraseOp(guardReturn);
472  } else {
473  // Erase the whole transition op if the constant is FALSE, because the
474  // transition will never be taken.
475  rewriter.eraseOp(op);
476  }
477  return success();
478  }
479  }
480 
481  return failure();
482 }
483 
484 LogicalResult TransitionOp::verify() {
485  if (!getNextStateOp())
486  return emitOpError("cannot find the definition of the next state `")
487  << getNextState() << "`";
488 
489  // Verify the action region, if present.
490  if (hasGuard()) {
491  if (getGuard().front().empty() ||
492  !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
493  return emitOpError("guard region must terminate with a ReturnOp");
494  }
495 
496  // Verify the transition is located in the correct region.
497  if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
498  return emitOpError("must only be located in the transitions region");
499 
500  return success();
501 }
502 
503 //===----------------------------------------------------------------------===//
504 // VariableOp
505 //===----------------------------------------------------------------------===//
506 
508  function_ref<void(Value, StringRef)> setNameFn) {
509  setNameFn(getResult(), getName());
510 }
511 
512 //===----------------------------------------------------------------------===//
513 // ReturnOp
514 //===----------------------------------------------------------------------===//
515 
516 void ReturnOp::setOperand(Value value) {
517  if (getOperand())
518  getOperation()->setOperand(0, value);
519  else
520  getOperation()->insertOperands(0, {value});
521 }
522 
523 //===----------------------------------------------------------------------===//
524 // UpdateOp
525 //===----------------------------------------------------------------------===//
526 
527 /// Get the targeted variable operation. This returns null on invalid IR.
528 VariableOp UpdateOp::getVariableOp() {
529  return getVariable().getDefiningOp<VariableOp>();
530 }
531 
532 LogicalResult UpdateOp::verify() {
533  if (!getVariable())
534  return emitOpError("destination is not a variable operation");
535 
536  if (!(*this)->getParentOfType<TransitionOp>().getAction().isAncestor(
537  (*this)->getParentRegion()))
538  return emitOpError("must only be located in the action region");
539 
540  auto transition = (*this)->getParentOfType<TransitionOp>();
541  for (auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
542  if (otherUpdateOp == *this)
543  continue;
544  if (otherUpdateOp.getVariable() == getVariable())
545  return otherUpdateOp.emitOpError(
546  "multiple updates to the same variable within a single action region "
547  "is disallowed");
548  }
549 
550  return success();
551 }
552 
553 //===----------------------------------------------------------------------===//
554 // TableGen generated logic
555 //===----------------------------------------------------------------------===//
556 
557 // Provide the autogenerated implementation guts for the Op classes.
558 #define GET_OP_CLASSES
559 #include "circt/Dialect/FSM/FSM.cpp.inc"
560 #undef GET_OP_CLASSES
561 
562 #include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
Definition: FSMOps.cpp:245
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition: FSMOps.cpp:118
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1414
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static InstancePath empty
Builder builder
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
def create(to_state)
Definition: fsm.py:110
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Direction
The direction of a Component or Cell port.
Definition: CalyxOps.h:72
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: fsm.py:1
Definition: hw.py:1