CIRCT 20.0.0git
Loading...
Searching...
No Matches
FSMOps.cpp
Go to the documentation of this file.
1//===- FSMOps.cpp - Implementation of FSM dialect operations --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/FunctionImplementation.h"
16#include "llvm/Support/FormatVariadic.h"
17
18using namespace mlir;
19using namespace circt;
20using namespace fsm;
21
22//===----------------------------------------------------------------------===//
23// MachineOp
24//===----------------------------------------------------------------------===//
25
26void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27 StringRef initialStateName, FunctionType type,
28 ArrayRef<NamedAttribute> attrs,
29 ArrayRef<DictionaryAttr> argAttrs) {
30 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31 builder.getStringAttr(name));
32 state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
33 TypeAttr::get(type));
34 state.addAttribute("initialState",
35 StringAttr::get(state.getContext(), initialStateName));
36 state.attributes.append(attrs.begin(), attrs.end());
37 Region *region = state.addRegion();
38 Block *body = new Block();
39 region->push_back(body);
40 body->addArguments(
41 type.getInputs(),
42 SmallVector<Location, 4>(type.getNumInputs(), builder.getUnknownLoc()));
43
44 if (argAttrs.empty())
45 return;
46 assert(type.getNumInputs() == argAttrs.size());
47 function_interface_impl::addArgAndResultAttrs(
48 builder, state, argAttrs,
49 /*resultAttrs=*/std::nullopt, MachineOp::getArgAttrsAttrName(state.name),
50 MachineOp::getResAttrsAttrName(state.name));
51}
52
53/// Get the initial state of the machine.
54StateOp MachineOp::getInitialStateOp() {
55 return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
56}
57
58StringAttr MachineOp::getArgName(size_t i) {
59 if (auto args = getArgNames())
60 return cast<StringAttr>((*args)[i]);
61
62 return StringAttr::get(getContext(), "in" + std::to_string(i));
63}
64
65StringAttr MachineOp::getResName(size_t i) {
66 if (auto resNameAttrs = getResNames())
67 return cast<StringAttr>((*resNameAttrs)[i]);
68
69 return StringAttr::get(getContext(), "out" + std::to_string(i));
70}
71
72/// Get the port information of the machine.
73void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
74 ports.clear();
75 auto machineType = getFunctionType();
76 for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
77 hw::PortInfo port;
78 port.name = getArgName(i);
79 if (!port.name)
80 port.name = StringAttr::get(getContext(), "in" + std::to_string(i));
82 port.type = machineType.getInput(i);
83 port.argNum = i;
84 ports.push_back(port);
85 }
86
87 for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
88 hw::PortInfo port;
89 port.name = getResName(i);
90 if (!port.name)
91 port.name = StringAttr::get(getContext(), "out" + std::to_string(i));
93 port.type = machineType.getResult(i);
94 port.argNum = i;
95 ports.push_back(port);
96 }
97}
98
99ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
100 auto buildFuncType =
101 [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
102 function_interface_impl::VariadicFlag,
103 std::string &) { return builder.getFunctionType(argTypes, results); };
104
105 return function_interface_impl::parseFunctionOp(
106 parser, result, /*allowVariadic=*/false,
107 MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
108 MachineOp::getArgAttrsAttrName(result.name),
109 MachineOp::getResAttrsAttrName(result.name));
110}
111
112void MachineOp::print(OpAsmPrinter &p) {
113 function_interface_impl::printFunctionOp(
114 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
115 getArgAttrsAttrName(), getResAttrsAttrName());
116}
117
118static LogicalResult compareTypes(Location loc, TypeRange rangeA,
119 TypeRange rangeB) {
120 if (rangeA.size() != rangeB.size())
121 return emitError(loc) << "mismatch in number of types compared ("
122 << rangeA.size() << " != " << rangeB.size() << ")";
123
124 size_t index = 0;
125 for (auto zip : llvm::zip(rangeA, rangeB)) {
126 auto typeA = std::get<0>(zip);
127 auto typeB = std::get<1>(zip);
128 if (typeA != typeB)
129 return emitError(loc) << "type mismatch at index " << index << " ("
130 << typeA << " != " << typeB << ")";
131 ++index;
132 }
133
134 return success();
135}
136
137LogicalResult MachineOp::verify() {
138 // If this function is external there is nothing to do.
139 if (isExternal())
140 return success();
141
142 // Verify that the argument list of the function and the arg list of the entry
143 // block line up. The trait already verified that the number of arguments is
144 // the same between the signature and the block.
145 if (failed(compareTypes(getLoc(), getArgumentTypes(),
146 front().getArgumentTypes())))
147 return emitOpError(
148 "entry block argument types must match the machine input types");
149
150 // Verify that the machine only has one block terminated with OutputOp.
151 if (!llvm::hasSingleElement(*this))
152 return emitOpError("must only have a single block");
153
154 // Verify that the initial state exists
155 if (!getInitialStateOp())
156 return emitOpError("initial state '" + getInitialState() +
157 "' was not defined in the machine");
158
159 if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
160 return emitOpError() << "number of machine arguments ("
161 << getArgumentTypes().size()
162 << ") does "
163 "not match the provided number "
164 "of argument names ("
165 << getArgNames()->size() << ")";
166
167 if (getResNames() && getResNames()->size() != getResultTypes().size())
168 return emitOpError() << "number of machine results ("
169 << getResultTypes().size()
170 << ") does "
171 "not match the provided number "
172 "of result names ("
173 << getResNames()->size() << ")";
174
175 return success();
176}
177
178SmallVector<::circt::hw::PortInfo> MachineOp::getPortList() {
179 SmallVector<hw::PortInfo> ports;
180 auto argNames = getArgNames();
181 auto argTypes = getFunctionType().getInputs();
182 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
183 bool isInOut = false;
184 auto type = argTypes[i];
185
186 if (auto inout = dyn_cast<hw::InOutType>(type)) {
187 isInOut = true;
188 type = inout.getElementType();
189 }
190
191 auto direction = isInOut ? hw::ModulePort::Direction::InOut
192 : hw::ModulePort::Direction::Input;
193
194 ports.push_back(
195 {{argNames ? cast<StringAttr>((*argNames)[i])
196 : StringAttr::get(getContext(), Twine("input") + Twine(i)),
197 type, direction},
198 i,
199 {},
200 {}});
201 }
202
203 auto resultNames = getResNames();
204 auto resultTypes = getFunctionType().getResults();
205 for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
206 ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
207 : StringAttr::get(getContext(),
208 Twine("output") + Twine(i)),
209 resultTypes[i], hw::ModulePort::Direction::Output},
210 i,
211 {},
212 {}});
213 }
214 return ports;
215}
216
217//===----------------------------------------------------------------------===//
218// InstanceOp
219//===----------------------------------------------------------------------===//
220
221/// Lookup the machine for the symbol. This returns null on invalid IR.
222MachineOp InstanceOp::getMachineOp() {
223 auto module = (*this)->getParentOfType<ModuleOp>();
224 return module.lookupSymbol<MachineOp>(getMachine());
225}
226
227LogicalResult InstanceOp::verify() {
228 auto m = getMachineOp();
229 if (!m)
230 return emitError("cannot find machine definition '") << getMachine() << "'";
231
232 return success();
233}
234
235void InstanceOp::getAsmResultNames(
236 function_ref<void(Value, StringRef)> setNameFn) {
237 setNameFn(getInstance(), getName());
238}
239
240//===----------------------------------------------------------------------===//
241// TriggerOp
242//===----------------------------------------------------------------------===//
243
244template <typename OpType>
245static LogicalResult verifyCallerTypes(OpType op) {
246 auto machine = op.getMachineOp();
247 if (!machine)
248 return op.emitError("cannot find machine definition");
249
250 // Check operand types first.
251 if (failed(compareTypes(op.getLoc(), machine.getArgumentTypes(),
252 op.getInputs().getTypes()))) {
253 auto diag =
254 op.emitOpError("operand types must match the machine input types");
255 diag.attachNote(machine->getLoc()) << "original machine declared here";
256 return failure();
257 }
258
259 // Check result types.
260 if (failed(compareTypes(op.getLoc(), machine.getResultTypes(),
261 op.getOutputs().getTypes()))) {
262 auto diag =
263 op.emitOpError("result types must match the machine output types");
264 diag.attachNote(machine->getLoc()) << "original machine declared here";
265 return failure();
266 }
267
268 return success();
269}
270
271/// Lookup the machine for the symbol. This returns null on invalid IR.
272MachineOp TriggerOp::getMachineOp() {
273 auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
274 if (!instanceOp)
275 return nullptr;
276
277 return instanceOp.getMachineOp();
278}
279
280LogicalResult TriggerOp::verify() { return verifyCallerTypes(*this); }
281
282//===----------------------------------------------------------------------===//
283// HWInstanceOp
284//===----------------------------------------------------------------------===//
285
286// InstanceOpInterface interface
287
288/// Lookup the machine for the symbol. This returns null on invalid IR.
289MachineOp HWInstanceOp::getMachineOp() {
290 auto module = (*this)->getParentOfType<ModuleOp>();
291 return module.lookupSymbol<MachineOp>(getMachine());
292}
293
294LogicalResult HWInstanceOp::verify() { return verifyCallerTypes(*this); }
295
296SmallVector<hw::PortInfo> HWInstanceOp::getPortList() {
297 return getMachineOp().getPortList();
298}
299
300/// Module name is the same as the machine name.
301StringRef HWInstanceOp::getModuleName() { return getMachine(); }
302FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() { return getMachineAttr(); }
303
304mlir::StringAttr HWInstanceOp::getInstanceNameAttr() { return getNameAttr(); }
305
306llvm::StringRef HWInstanceOp::getInstanceName() { return getName(); }
307
308//===----------------------------------------------------------------------===//
309// StateOp
310//===----------------------------------------------------------------------===//
311
312void StateOp::build(OpBuilder &builder, OperationState &state,
313 StringRef stateName) {
314 OpBuilder::InsertionGuard guard(builder);
315 Region *output = state.addRegion();
316 output->push_back(new Block());
317 builder.setInsertionPointToEnd(&output->back());
318 builder.create<fsm::OutputOp>(state.location);
319 Region *transitions = state.addRegion();
320 transitions->push_back(new Block());
321 state.addAttribute("sym_name", builder.getStringAttr(stateName));
322}
323
324void StateOp::build(OpBuilder &builder, OperationState &state,
325 StringRef stateName, ValueRange outputs) {
326 OpBuilder::InsertionGuard guard(builder);
327 Region *output = state.addRegion();
328 output->push_back(new Block());
329 builder.setInsertionPointToEnd(&output->back());
330 builder.create<fsm::OutputOp>(state.location, outputs);
331 Region *transitions = state.addRegion();
332 transitions->push_back(new Block());
333 state.addAttribute("sym_name", builder.getStringAttr(stateName));
334}
335
336SetVector<StateOp> StateOp::getNextStates() {
337 SmallVector<StateOp> nextStates;
338 llvm::transform(
339 getTransitions().getOps<TransitionOp>(),
340 std::inserter(nextStates, nextStates.begin()),
341 [](TransitionOp transition) { return transition.getNextStateOp(); });
342 return SetVector<StateOp>(nextStates.begin(), nextStates.end());
343}
344
345LogicalResult StateOp::canonicalize(StateOp op, PatternRewriter &rewriter) {
346 bool hasAlwaysTakenTransition = false;
347 SmallVector<TransitionOp, 4> transitionsToErase;
348 // Remove all transitions after an "always-taken" transition.
349 for (auto transition : op.getTransitions().getOps<TransitionOp>()) {
350 if (!hasAlwaysTakenTransition)
351 hasAlwaysTakenTransition = transition.isAlwaysTaken();
352 else
353 transitionsToErase.push_back(transition);
354 }
355
356 for (auto transition : transitionsToErase)
357 rewriter.eraseOp(transition);
358
359 return failure(transitionsToErase.empty());
360}
361
362LogicalResult StateOp::verify() {
363 MachineOp parent = getOperation()->getParentOfType<MachineOp>();
364
365 if (parent.getNumResults() != 0 && (getOutput().empty()))
366 return emitOpError("state must have a non-empty output region when the "
367 "machine has results.");
368
369 if (!getOutput().empty()) {
370 // Ensure that the output block has a single OutputOp terminator.
371 Block *outputBlock = &getOutput().front();
372 if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
373 return emitOpError("output block must have a single OutputOp terminator");
374 }
375
376 return success();
377}
378
379Block *StateOp::ensureOutput(OpBuilder &builder) {
380 if (getOutput().empty()) {
381 OpBuilder::InsertionGuard g(builder);
382 auto *block = new Block();
383 getOutput().push_back(block);
384 builder.setInsertionPointToStart(block);
385 builder.create<fsm::OutputOp>(getLoc());
386 }
387 return &getOutput().front();
388}
389
390//===----------------------------------------------------------------------===//
391// OutputOp
392//===----------------------------------------------------------------------===//
393
394LogicalResult OutputOp::verify() {
395 if ((*this)->getParentRegion() ==
396 &(*this)->getParentOfType<StateOp>().getTransitions()) {
397 if (getNumOperands() != 0)
398 emitOpError("transitions region must not output any value");
399 return success();
400 }
401
402 // Verify that the result list of the machine and the operand list of the
403 // OutputOp line up.
404 auto machine = (*this)->getParentOfType<MachineOp>();
405 if (failed(
406 compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
407 return emitOpError("operand types must match the machine output types");
408
409 return success();
410}
411
412//===----------------------------------------------------------------------===//
413// TransitionOp
414//===----------------------------------------------------------------------===//
415
416void TransitionOp::build(OpBuilder &builder, OperationState &state,
417 StateOp nextState) {
418 build(builder, state, nextState.getName());
419}
420
421void TransitionOp::build(OpBuilder &builder, OperationState &state,
422 StringRef nextState,
423 llvm::function_ref<void()> guardCtor,
424 llvm::function_ref<void()> actionCtor) {
425 state.addAttribute("nextState",
426 FlatSymbolRefAttr::get(builder.getStringAttr(nextState)));
427 OpBuilder::InsertionGuard guard(builder);
428
429 Region *guardRegion = state.addRegion(); // guard
430 if (guardCtor) {
431 builder.createBlock(guardRegion);
432 guardCtor();
433 }
434
435 Region *actionRegion = state.addRegion(); // action
436 if (actionCtor) {
437 builder.createBlock(actionRegion);
438 actionCtor();
439 }
440}
441
442Block *TransitionOp::ensureGuard(OpBuilder &builder) {
443 if (getGuard().empty()) {
444 OpBuilder::InsertionGuard g(builder);
445 auto *block = new Block();
446 getGuard().push_back(block);
447 builder.setInsertionPointToStart(block);
448 builder.create<fsm::ReturnOp>(getLoc());
449 }
450 return &getGuard().front();
451}
452
453Block *TransitionOp::ensureAction(OpBuilder &builder) {
454 if (getAction().empty())
455 getAction().push_back(new Block());
456 return &getAction().front();
457}
458
459/// Lookup the next state for the symbol. This returns null on invalid IR.
460StateOp TransitionOp::getNextStateOp() {
461 auto machineOp = (*this)->getParentOfType<MachineOp>();
462 if (!machineOp)
463 return nullptr;
464
465 return machineOp.lookupSymbol<StateOp>(getNextState());
466}
467
468bool TransitionOp::isAlwaysTaken() {
469 if (!hasGuard())
470 return true;
471
472 auto guardReturn = getGuardReturn();
473 if (guardReturn.getNumOperands() == 0)
474 return true;
475
476 if (auto constantOp =
477 guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
478 return cast<BoolAttr>(constantOp.getValue()).getValue();
479
480 return false;
481}
482
483LogicalResult TransitionOp::canonicalize(TransitionOp op,
484 PatternRewriter &rewriter) {
485 if (op.hasGuard()) {
486 auto guardReturn = op.getGuardReturn();
487 if (guardReturn.getNumOperands() == 1)
488 if (auto constantOp = guardReturn.getOperand()
489 .getDefiningOp<mlir::arith::ConstantOp>()) {
490 // Simplify when the guard region returns a constant value.
491 if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
492 // Replace the original return op with a new one without any operands
493 // if the constant is TRUE.
494 rewriter.setInsertionPoint(guardReturn);
495 rewriter.create<fsm::ReturnOp>(guardReturn.getLoc());
496 rewriter.eraseOp(guardReturn);
497 } else {
498 // Erase the whole transition op if the constant is FALSE, because the
499 // transition will never be taken.
500 rewriter.eraseOp(op);
501 }
502 return success();
503 }
504 }
505
506 return failure();
507}
508
509LogicalResult TransitionOp::verify() {
510 if (!getNextStateOp())
511 return emitOpError("cannot find the definition of the next state `")
512 << getNextState() << "`";
513
514 // Verify the action region, if present.
515 if (hasGuard()) {
516 if (getGuard().front().empty() ||
517 !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
518 return emitOpError("guard region must terminate with a ReturnOp");
519 }
520
521 // Verify the transition is located in the correct region.
522 if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
523 return emitOpError("must only be located in the transitions region");
524
525 return success();
526}
527
528//===----------------------------------------------------------------------===//
529// VariableOp
530//===----------------------------------------------------------------------===//
531
532void VariableOp::getAsmResultNames(
533 function_ref<void(Value, StringRef)> setNameFn) {
534 setNameFn(getResult(), getName());
535}
536
537//===----------------------------------------------------------------------===//
538// ReturnOp
539//===----------------------------------------------------------------------===//
540
541void ReturnOp::setOperand(Value value) {
542 if (getOperand())
543 getOperation()->setOperand(0, value);
544 else
545 getOperation()->insertOperands(0, {value});
546}
547
548//===----------------------------------------------------------------------===//
549// UpdateOp
550//===----------------------------------------------------------------------===//
551
552/// Get the targeted variable operation. This returns null on invalid IR.
553VariableOp UpdateOp::getVariableOp() {
554 return getVariable().getDefiningOp<VariableOp>();
555}
556
557LogicalResult UpdateOp::verify() {
558 if (!getVariable())
559 return emitOpError("destination is not a variable operation");
560
561 if (!(*this)->getParentOfType<TransitionOp>().getAction().isAncestor(
562 (*this)->getParentRegion()))
563 return emitOpError("must only be located in the action region");
564
565 auto transition = (*this)->getParentOfType<TransitionOp>();
566 for (auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
567 if (otherUpdateOp == *this)
568 continue;
569 if (otherUpdateOp.getVariable() == getVariable())
570 return otherUpdateOp.emitOpError(
571 "multiple updates to the same variable within a single action region "
572 "is disallowed");
573 }
574
575 return success();
576}
577
578//===----------------------------------------------------------------------===//
579// TableGen generated logic
580//===----------------------------------------------------------------------===//
581
582// Provide the autogenerated implementation guts for the Op classes.
583#define GET_OP_CLASSES
584#include "circt/Dialect/FSM/FSM.cpp.inc"
585#undef GET_OP_CLASSES
586
587#include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
Definition FSMOps.cpp:245
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition FSMOps.cpp:118
@ Output
Definition HW.h:35
static InstancePath empty
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
type(self)
Definition fsm.py:58
create(to_state)
Definition fsm.py:110
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition fsm.py:1
Definition hw.py:1
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.