CIRCT  19.0.0git
FSMOps.cpp
Go to the documentation of this file.
1 //===- FSMOps.cpp - Implementation of FSM dialect operations --------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Dialect/Arith/IR/Arith.h"
11 #include "mlir/Dialect/Func/IR/FuncOps.h"
12 #include "mlir/IR/BuiltinOps.h"
13 #include "mlir/IR/DialectImplementation.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Interfaces/FunctionImplementation.h"
16 #include "llvm/Support/FormatVariadic.h"
17 
18 using namespace mlir;
19 using namespace circt;
20 using namespace fsm;
21 
22 //===----------------------------------------------------------------------===//
23 // MachineOp
24 //===----------------------------------------------------------------------===//
25 
26 void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27  StringRef initialStateName, FunctionType type,
28  ArrayRef<NamedAttribute> attrs,
29  ArrayRef<DictionaryAttr> argAttrs) {
30  state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31  builder.getStringAttr(name));
32  state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
33  TypeAttr::get(type));
34  state.addAttribute("initialState",
35  StringAttr::get(state.getContext(), initialStateName));
36  state.attributes.append(attrs.begin(), attrs.end());
37  Region *region = state.addRegion();
38  Block *body = new Block();
39  region->push_back(body);
40  body->addArguments(
41  type.getInputs(),
42  SmallVector<Location, 4>(type.getNumInputs(), builder.getUnknownLoc()));
43 
44  if (argAttrs.empty())
45  return;
46  assert(type.getNumInputs() == argAttrs.size());
47  function_interface_impl::addArgAndResultAttrs(
48  builder, state, argAttrs,
49  /*resultAttrs=*/std::nullopt, MachineOp::getArgAttrsAttrName(state.name),
50  MachineOp::getResAttrsAttrName(state.name));
51 }
52 
53 /// Get the initial state of the machine.
54 StateOp MachineOp::getInitialStateOp() {
55  return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
56 }
57 
58 StringAttr MachineOp::getArgName(size_t i) {
59  if (auto args = getArgNames())
60  return cast<StringAttr>((*args)[i]);
61 
62  return StringAttr::get(getContext(), "in" + std::to_string(i));
63 }
64 
65 StringAttr MachineOp::getResName(size_t i) {
66  if (auto resNameAttrs = getResNames())
67  return cast<StringAttr>((*resNameAttrs)[i]);
68 
69  return StringAttr::get(getContext(), "out" + std::to_string(i));
70 }
71 
72 /// Get the port information of the machine.
73 void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
74  ports.clear();
75  auto machineType = getFunctionType();
76  for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
77  hw::PortInfo port;
78  port.name = getArgName(i);
79  if (!port.name)
80  port.name = StringAttr::get(getContext(), "in" + std::to_string(i));
82  port.type = machineType.getInput(i);
83  port.argNum = i;
84  ports.push_back(port);
85  }
86 
87  for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
88  hw::PortInfo port;
89  port.name = getResName(i);
90  if (!port.name)
91  port.name = StringAttr::get(getContext(), "out" + std::to_string(i));
93  port.type = machineType.getResult(i);
94  port.argNum = i;
95  ports.push_back(port);
96  }
97 }
98 
99 ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
100  auto buildFuncType =
101  [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
102  function_interface_impl::VariadicFlag,
103  std::string &) { return builder.getFunctionType(argTypes, results); };
104 
105  return function_interface_impl::parseFunctionOp(
106  parser, result, /*allowVariadic=*/false,
107  MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
108  MachineOp::getArgAttrsAttrName(result.name),
109  MachineOp::getResAttrsAttrName(result.name));
110 }
111 
112 void MachineOp::print(OpAsmPrinter &p) {
113  function_interface_impl::printFunctionOp(
114  p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
115  getArgAttrsAttrName(), getResAttrsAttrName());
116 }
117 
118 static LogicalResult compareTypes(Location loc, TypeRange rangeA,
119  TypeRange rangeB) {
120  if (rangeA.size() != rangeB.size())
121  return emitError(loc) << "mismatch in number of types compared ("
122  << rangeA.size() << " != " << rangeB.size() << ")";
123 
124  size_t index = 0;
125  for (auto zip : llvm::zip(rangeA, rangeB)) {
126  auto typeA = std::get<0>(zip);
127  auto typeB = std::get<1>(zip);
128  if (typeA != typeB)
129  return emitError(loc) << "type mismatch at index " << index << " ("
130  << typeA << " != " << typeB << ")";
131  ++index;
132  }
133 
134  return success();
135 }
136 
137 LogicalResult MachineOp::verify() {
138  // If this function is external there is nothing to do.
139  if (isExternal())
140  return success();
141 
142  // Verify that the argument list of the function and the arg list of the entry
143  // block line up. The trait already verified that the number of arguments is
144  // the same between the signature and the block.
145  if (failed(compareTypes(getLoc(), getArgumentTypes(),
146  front().getArgumentTypes())))
147  return emitOpError(
148  "entry block argument types must match the machine input types");
149 
150  // Verify that the machine only has one block terminated with OutputOp.
151  if (!llvm::hasSingleElement(*this))
152  return emitOpError("must only have a single block");
153 
154  // Verify that the initial state exists
155  if (!getInitialStateOp())
156  return emitOpError("initial state '" + getInitialState() +
157  "' was not defined in the machine");
158 
159  if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
160  return emitOpError() << "number of machine arguments ("
161  << getArgumentTypes().size()
162  << ") does "
163  "not match the provided number "
164  "of argument names ("
165  << getArgNames()->size() << ")";
166 
167  if (getResNames() && getResNames()->size() != getResultTypes().size())
168  return emitOpError() << "number of machine results ("
169  << getResultTypes().size()
170  << ") does "
171  "not match the provided number "
172  "of result names ("
173  << getResNames()->size() << ")";
174 
175  return success();
176 }
177 
178 SmallVector<::circt::hw::PortInfo> MachineOp::getPortList() {
179  SmallVector<hw::PortInfo> ports;
180  auto argNames = getArgNames();
181  auto argTypes = getFunctionType().getInputs();
182  for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
183  bool isInOut = false;
184  auto type = argTypes[i];
185 
186  if (auto inout = dyn_cast<hw::InOutType>(type)) {
187  isInOut = true;
188  type = inout.getElementType();
189  }
190 
191  auto direction = isInOut ? hw::ModulePort::Direction::InOut
193 
194  ports.push_back(
195  {{argNames ? cast<StringAttr>((*argNames)[i])
196  : StringAttr::get(getContext(), Twine("input") + Twine(i)),
197  type, direction},
198  i,
199  {},
200  {}});
201  }
202 
203  auto resultNames = getResNames();
204  auto resultTypes = getFunctionType().getResults();
205  for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
206  ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
207  : StringAttr::get(getContext(),
208  Twine("output") + Twine(i)),
209  resultTypes[i], hw::ModulePort::Direction::Output},
210  i,
211  {},
212  {}});
213  }
214  return ports;
215 }
216 
217 //===----------------------------------------------------------------------===//
218 // InstanceOp
219 //===----------------------------------------------------------------------===//
220 
221 /// Lookup the machine for the symbol. This returns null on invalid IR.
222 MachineOp InstanceOp::getMachineOp() {
223  auto module = (*this)->getParentOfType<ModuleOp>();
224  return module.lookupSymbol<MachineOp>(getMachine());
225 }
226 
227 LogicalResult InstanceOp::verify() {
228  auto m = getMachineOp();
229  if (!m)
230  return emitError("cannot find machine definition '") << getMachine() << "'";
231 
232  return success();
233 }
234 
236  function_ref<void(Value, StringRef)> setNameFn) {
237  setNameFn(getInstance(), getName());
238 }
239 
240 //===----------------------------------------------------------------------===//
241 // TriggerOp
242 //===----------------------------------------------------------------------===//
243 
244 template <typename OpType>
245 static LogicalResult verifyCallerTypes(OpType op) {
246  auto machine = op.getMachineOp();
247  if (!machine)
248  return op.emitError("cannot find machine definition");
249 
250  // Check operand types first.
251  if (failed(compareTypes(op.getLoc(), machine.getArgumentTypes(),
252  op.getInputs().getTypes()))) {
253  auto diag =
254  op.emitOpError("operand types must match the machine input types");
255  diag.attachNote(machine->getLoc()) << "original machine declared here";
256  return failure();
257  }
258 
259  // Check result types.
260  if (failed(compareTypes(op.getLoc(), machine.getResultTypes(),
261  op.getOutputs().getTypes()))) {
262  auto diag =
263  op.emitOpError("result types must match the machine output types");
264  diag.attachNote(machine->getLoc()) << "original machine declared here";
265  return failure();
266  }
267 
268  return success();
269 }
270 
271 /// Lookup the machine for the symbol. This returns null on invalid IR.
272 MachineOp TriggerOp::getMachineOp() {
273  auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
274  if (!instanceOp)
275  return nullptr;
276 
277  return instanceOp.getMachineOp();
278 }
279 
280 LogicalResult TriggerOp::verify() { return verifyCallerTypes(*this); }
281 
282 //===----------------------------------------------------------------------===//
283 // HWInstanceOp
284 //===----------------------------------------------------------------------===//
285 
286 // InstanceOpInterface interface
287 
288 /// Lookup the machine for the symbol. This returns null on invalid IR.
289 MachineOp HWInstanceOp::getMachineOp() {
290  auto module = (*this)->getParentOfType<ModuleOp>();
291  return module.lookupSymbol<MachineOp>(getMachine());
292 }
293 
294 LogicalResult HWInstanceOp::verify() { return verifyCallerTypes(*this); }
295 
296 SmallVector<hw::PortInfo> HWInstanceOp::getPortList() {
297  return getMachineOp().getPortList();
298 }
299 
300 /// Module name is the same as the machine name.
301 StringRef HWInstanceOp::getModuleName() { return getMachine(); }
302 FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() { return getMachineAttr(); }
303 
304 mlir::StringAttr HWInstanceOp::getInstanceNameAttr() { return getNameAttr(); }
305 
306 llvm::StringRef HWInstanceOp::getInstanceName() { return getName(); }
307 
308 //===----------------------------------------------------------------------===//
309 // StateOp
310 //===----------------------------------------------------------------------===//
311 
312 void StateOp::build(OpBuilder &builder, OperationState &state,
313  StringRef stateName) {
314  OpBuilder::InsertionGuard guard(builder);
315  Region *output = state.addRegion();
316  output->push_back(new Block());
317  builder.setInsertionPointToEnd(&output->back());
318  builder.create<fsm::OutputOp>(state.location);
319  Region *transitions = state.addRegion();
320  transitions->push_back(new Block());
321  state.addAttribute("sym_name", builder.getStringAttr(stateName));
322 }
323 
324 void StateOp::build(OpBuilder &builder, OperationState &state,
325  StringRef stateName, ValueRange outputs) {
326  OpBuilder::InsertionGuard guard(builder);
327  Region *output = state.addRegion();
328  output->push_back(new Block());
329  builder.setInsertionPointToEnd(&output->back());
330  builder.create<fsm::OutputOp>(state.location, outputs);
331  Region *transitions = state.addRegion();
332  transitions->push_back(new Block());
333  state.addAttribute("sym_name", builder.getStringAttr(stateName));
334 }
335 
336 SetVector<StateOp> StateOp::getNextStates() {
337  SmallVector<StateOp> nextStates;
338  llvm::transform(
339  getTransitions().getOps<TransitionOp>(),
340  std::inserter(nextStates, nextStates.begin()),
341  [](TransitionOp transition) { return transition.getNextStateOp(); });
342  return SetVector<StateOp>(nextStates.begin(), nextStates.end());
343 }
344 
345 LogicalResult StateOp::canonicalize(StateOp op, PatternRewriter &rewriter) {
346  bool hasAlwaysTakenTransition = false;
347  SmallVector<TransitionOp, 4> transitionsToErase;
348  // Remove all transitions after an "always-taken" transition.
349  for (auto transition : op.getTransitions().getOps<TransitionOp>()) {
350  if (!hasAlwaysTakenTransition)
351  hasAlwaysTakenTransition = transition.isAlwaysTaken();
352  else
353  transitionsToErase.push_back(transition);
354  }
355 
356  for (auto transition : transitionsToErase)
357  rewriter.eraseOp(transition);
358 
359  return failure(transitionsToErase.empty());
360 }
361 
362 LogicalResult StateOp::verify() {
363  MachineOp parent = getOperation()->getParentOfType<MachineOp>();
364 
365  if (parent.getNumResults() != 0 && (getOutput().empty()))
366  return emitOpError("state must have a non-empty output region when the "
367  "machine has results.");
368 
369  if (!getOutput().empty()) {
370  // Ensure that the output block has a single OutputOp terminator.
371  Block *outputBlock = &getOutput().front();
372  if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
373  return emitOpError("output block must have a single OutputOp terminator");
374  }
375 
376  return success();
377 }
378 
379 Block *StateOp::ensureOutput(OpBuilder &builder) {
380  if (getOutput().empty()) {
381  OpBuilder::InsertionGuard g(builder);
382  auto *block = new Block();
383  getOutput().push_back(block);
384  builder.setInsertionPointToStart(block);
385  builder.create<fsm::OutputOp>(getLoc());
386  }
387  return &getOutput().front();
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // OutputOp
392 //===----------------------------------------------------------------------===//
393 
394 LogicalResult OutputOp::verify() {
395  if ((*this)->getParentRegion() ==
396  &(*this)->getParentOfType<StateOp>().getTransitions()) {
397  if (getNumOperands() != 0)
398  emitOpError("transitions region must not output any value");
399  return success();
400  }
401 
402  // Verify that the result list of the machine and the operand list of the
403  // OutputOp line up.
404  auto machine = (*this)->getParentOfType<MachineOp>();
405  if (failed(
406  compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
407  return emitOpError("operand types must match the machine output types");
408 
409  return success();
410 }
411 
412 //===----------------------------------------------------------------------===//
413 // TransitionOp
414 //===----------------------------------------------------------------------===//
415 
416 void TransitionOp::build(OpBuilder &builder, OperationState &state,
417  StateOp nextState) {
418  build(builder, state, nextState.getName());
419 }
420 
421 void TransitionOp::build(OpBuilder &builder, OperationState &state,
422  StringRef nextState,
423  llvm::function_ref<void()> guardCtor,
424  llvm::function_ref<void()> actionCtor) {
425  state.addAttribute("nextState",
426  FlatSymbolRefAttr::get(builder.getStringAttr(nextState)));
427  OpBuilder::InsertionGuard guard(builder);
428 
429  Region *guardRegion = state.addRegion(); // guard
430  if (guardCtor) {
431  builder.createBlock(guardRegion);
432  guardCtor();
433  }
434 
435  Region *actionRegion = state.addRegion(); // action
436  if (actionCtor) {
437  builder.createBlock(actionRegion);
438  actionCtor();
439  }
440 }
441 
442 Block *TransitionOp::ensureGuard(OpBuilder &builder) {
443  if (getGuard().empty()) {
444  OpBuilder::InsertionGuard g(builder);
445  auto *block = new Block();
446  getGuard().push_back(block);
447  builder.setInsertionPointToStart(block);
448  builder.create<fsm::ReturnOp>(getLoc());
449  }
450  return &getGuard().front();
451 }
452 
453 Block *TransitionOp::ensureAction(OpBuilder &builder) {
454  if (getAction().empty())
455  getAction().push_back(new Block());
456  return &getAction().front();
457 }
458 
459 /// Lookup the next state for the symbol. This returns null on invalid IR.
460 StateOp TransitionOp::getNextStateOp() {
461  auto machineOp = (*this)->getParentOfType<MachineOp>();
462  if (!machineOp)
463  return nullptr;
464 
465  return machineOp.lookupSymbol<StateOp>(getNextState());
466 }
467 
468 bool TransitionOp::isAlwaysTaken() {
469  if (!hasGuard())
470  return true;
471 
472  auto guardReturn = getGuardReturn();
473  if (guardReturn.getNumOperands() == 0)
474  return true;
475 
476  if (auto constantOp =
477  guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
478  return cast<BoolAttr>(constantOp.getValue()).getValue();
479 
480  return false;
481 }
482 
484  PatternRewriter &rewriter) {
485  if (op.hasGuard()) {
486  auto guardReturn = op.getGuardReturn();
487  if (guardReturn.getNumOperands() == 1)
488  if (auto constantOp = guardReturn.getOperand()
489  .getDefiningOp<mlir::arith::ConstantOp>()) {
490  // Simplify when the guard region returns a constant value.
491  if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
492  // Replace the original return op with a new one without any operands
493  // if the constant is TRUE.
494  rewriter.setInsertionPoint(guardReturn);
495  rewriter.create<fsm::ReturnOp>(guardReturn.getLoc());
496  rewriter.eraseOp(guardReturn);
497  } else {
498  // Erase the whole transition op if the constant is FALSE, because the
499  // transition will never be taken.
500  rewriter.eraseOp(op);
501  }
502  return success();
503  }
504  }
505 
506  return failure();
507 }
508 
509 LogicalResult TransitionOp::verify() {
510  if (!getNextStateOp())
511  return emitOpError("cannot find the definition of the next state `")
512  << getNextState() << "`";
513 
514  // Verify the action region, if present.
515  if (hasGuard()) {
516  if (getGuard().front().empty() ||
517  !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
518  return emitOpError("guard region must terminate with a ReturnOp");
519  }
520 
521  // Verify the transition is located in the correct region.
522  if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
523  return emitOpError("must only be located in the transitions region");
524 
525  return success();
526 }
527 
528 //===----------------------------------------------------------------------===//
529 // VariableOp
530 //===----------------------------------------------------------------------===//
531 
533  function_ref<void(Value, StringRef)> setNameFn) {
534  setNameFn(getResult(), getName());
535 }
536 
537 //===----------------------------------------------------------------------===//
538 // ReturnOp
539 //===----------------------------------------------------------------------===//
540 
541 void ReturnOp::setOperand(Value value) {
542  if (getOperand())
543  getOperation()->setOperand(0, value);
544  else
545  getOperation()->insertOperands(0, {value});
546 }
547 
548 //===----------------------------------------------------------------------===//
549 // UpdateOp
550 //===----------------------------------------------------------------------===//
551 
552 /// Get the targeted variable operation. This returns null on invalid IR.
553 VariableOp UpdateOp::getVariableOp() {
554  return getVariable().getDefiningOp<VariableOp>();
555 }
556 
557 LogicalResult UpdateOp::verify() {
558  if (!getVariable())
559  return emitOpError("destination is not a variable operation");
560 
561  if (!(*this)->getParentOfType<TransitionOp>().getAction().isAncestor(
562  (*this)->getParentRegion()))
563  return emitOpError("must only be located in the action region");
564 
565  auto transition = (*this)->getParentOfType<TransitionOp>();
566  for (auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
567  if (otherUpdateOp == *this)
568  continue;
569  if (otherUpdateOp.getVariable() == getVariable())
570  return otherUpdateOp.emitOpError(
571  "multiple updates to the same variable within a single action region "
572  "is disallowed");
573  }
574 
575  return success();
576 }
577 
578 //===----------------------------------------------------------------------===//
579 // TableGen generated logic
580 //===----------------------------------------------------------------------===//
581 
582 // Provide the autogenerated implementation guts for the Op classes.
583 #define GET_OP_CLASSES
584 #include "circt/Dialect/FSM/FSM.cpp.inc"
585 #undef GET_OP_CLASSES
586 
587 #include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
Definition: FSMOps.cpp:245
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition: FSMOps.cpp:118
static SmallVector< PortInfo > getPortList(ModuleTy &mod)
Definition: HWOps.cpp:1414
@ Input
Definition: HW.h:35
@ Output
Definition: HW.h:35
@ InOut
Definition: HW.h:35
static InstancePath empty
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
def create(to_state)
Definition: fsm.py:110
static LogicalResult canonicalize(Op op, PatternRewriter &rewriter)
Definition: VerifOps.cpp:66
static LogicalResult verify(Value clock, bool eventExists, mlir::Location loc)
Definition: SVOps.cpp:2443
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition: CalyxOps.cpp:54
Direction
The direction of a Component or Cell port.
Definition: CalyxOps.h:72
std::string getInstanceName(mlir::func::CallOp callOp)
A helper function to get the instance name.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
void getAsmResultNames(OpAsmSetValueNameFn setNameFn, StringRef instanceName, ArrayAttr resultNames, ValueRange results)
Suggest a name for each result value based on the saved result names attribute.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: fsm.py:1
Definition: hw.py:1