CIRCT 23.0.0git
Loading...
Searching...
No Matches
FSMOps.cpp
Go to the documentation of this file.
1//===- FSMOps.cpp - Implementation of FSM dialect operations --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10#include "mlir/Dialect/Arith/IR/Arith.h"
11#include "mlir/Dialect/Func/IR/FuncOps.h"
12#include "mlir/IR/BuiltinOps.h"
13#include "mlir/IR/DialectImplementation.h"
14#include "mlir/IR/PatternMatch.h"
15#include "mlir/Interfaces/FunctionImplementation.h"
16#include "llvm/Support/FormatVariadic.h"
17
18using namespace mlir;
19using namespace circt;
20using namespace fsm;
21
22//===----------------------------------------------------------------------===//
23// MachineOp
24//===----------------------------------------------------------------------===//
25
26void MachineOp::build(OpBuilder &builder, OperationState &state, StringRef name,
27 StringRef initialStateName, FunctionType type,
28 ArrayRef<NamedAttribute> attrs,
29 ArrayRef<DictionaryAttr> argAttrs) {
30 state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
31 builder.getStringAttr(name));
32 state.addAttribute(MachineOp::getFunctionTypeAttrName(state.name),
33 TypeAttr::get(type));
34 state.addAttribute("initialState",
35 StringAttr::get(state.getContext(), initialStateName));
36 state.attributes.append(attrs.begin(), attrs.end());
37 Region *region = state.addRegion();
38 Block *body = new Block();
39 region->push_back(body);
40 body->addArguments(
41 type.getInputs(),
42 SmallVector<Location, 4>(type.getNumInputs(), builder.getUnknownLoc()));
43
44 if (argAttrs.empty())
45 return;
46 assert(type.getNumInputs() == argAttrs.size());
47 call_interface_impl::addArgAndResultAttrs(
48 builder, state, argAttrs,
49 /*resultAttrs=*/{}, MachineOp::getArgAttrsAttrName(state.name),
50 MachineOp::getResAttrsAttrName(state.name));
51}
52
53/// Get the initial state of the machine.
54StateOp MachineOp::getInitialStateOp() {
55 return dyn_cast_or_null<StateOp>(lookupSymbol(getInitialState()));
56}
57
58size_t MachineOp::getNumStates() {
59 auto stateOps = getBody().getOps<StateOp>();
60 return std::distance(stateOps.begin(), stateOps.end());
61}
62
63StringAttr MachineOp::getArgName(size_t i) {
64 if (auto args = getArgNames())
65 return cast<StringAttr>((*args)[i]);
66
67 return StringAttr::get(getContext(), "in" + std::to_string(i));
68}
69
70StringAttr MachineOp::getResName(size_t i) {
71 if (auto resNameAttrs = getResNames())
72 return cast<StringAttr>((*resNameAttrs)[i]);
73
74 return StringAttr::get(getContext(), "out" + std::to_string(i));
75}
76
77/// Get the port information of the machine.
78void MachineOp::getHWPortInfo(SmallVectorImpl<hw::PortInfo> &ports) {
79 ports.clear();
80 auto machineType = getFunctionType();
81 for (unsigned i = 0, e = machineType.getNumInputs(); i < e; ++i) {
82 hw::PortInfo port;
83 port.name = getArgName(i);
84 if (!port.name)
85 port.name = StringAttr::get(getContext(), "in" + std::to_string(i));
87 port.type = machineType.getInput(i);
88 port.argNum = i;
89 ports.push_back(port);
90 }
91
92 for (unsigned i = 0, e = machineType.getNumResults(); i < e; ++i) {
93 hw::PortInfo port;
94 port.name = getResName(i);
95 if (!port.name)
96 port.name = StringAttr::get(getContext(), "out" + std::to_string(i));
98 port.type = machineType.getResult(i);
99 port.argNum = i;
100 ports.push_back(port);
101 }
102}
103
104ParseResult MachineOp::parse(OpAsmParser &parser, OperationState &result) {
105 auto buildFuncType =
106 [&](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
107 function_interface_impl::VariadicFlag,
108 std::string &) { return builder.getFunctionType(argTypes, results); };
109
110 return function_interface_impl::parseFunctionOp(
111 parser, result, /*allowVariadic=*/false,
112 MachineOp::getFunctionTypeAttrName(result.name), buildFuncType,
113 MachineOp::getArgAttrsAttrName(result.name),
114 MachineOp::getResAttrsAttrName(result.name));
115}
116
117void MachineOp::print(OpAsmPrinter &p) {
118 function_interface_impl::printFunctionOp(
119 p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(),
120 getArgAttrsAttrName(), getResAttrsAttrName());
121}
122
123static LogicalResult compareTypes(Location loc, TypeRange rangeA,
124 TypeRange rangeB) {
125 if (rangeA.size() != rangeB.size())
126 return emitError(loc) << "mismatch in number of types compared ("
127 << rangeA.size() << " != " << rangeB.size() << ")";
128
129 size_t index = 0;
130 for (auto zip : llvm::zip(rangeA, rangeB)) {
131 auto typeA = std::get<0>(zip);
132 auto typeB = std::get<1>(zip);
133 if (typeA != typeB)
134 return emitError(loc) << "type mismatch at index " << index << " ("
135 << typeA << " != " << typeB << ")";
136 ++index;
137 }
138
139 return success();
140}
141
142LogicalResult MachineOp::verify() {
143 // If this function is external there is nothing to do.
144 if (isExternal())
145 return success();
146
147 // Verify that the argument list of the function and the arg list of the entry
148 // block line up. The trait already verified that the number of arguments is
149 // the same between the signature and the block.
150 if (failed(compareTypes(getLoc(), getArgumentTypes(),
151 front().getArgumentTypes())))
152 return emitOpError(
153 "entry block argument types must match the machine input types");
154
155 // Verify that the machine only has one block terminated with OutputOp.
156 if (!llvm::hasSingleElement(*this))
157 return emitOpError("must only have a single block");
158
159 // Verify that the initial state exists
160 if (!getInitialStateOp())
161 return emitOpError("initial state '" + getInitialState() +
162 "' was not defined in the machine");
163
164 if (getArgNames() && getArgNames()->size() != getArgumentTypes().size())
165 return emitOpError() << "number of machine arguments ("
166 << getArgumentTypes().size()
167 << ") does "
168 "not match the provided number "
169 "of argument names ("
170 << getArgNames()->size() << ")";
171
172 if (getResNames() && getResNames()->size() != getResultTypes().size())
173 return emitOpError() << "number of machine results ("
174 << getResultTypes().size()
175 << ") does "
176 "not match the provided number "
177 "of result names ("
178 << getResNames()->size() << ")";
179
180 return success();
181}
182
183SmallVector<::circt::hw::PortInfo> MachineOp::getPortList() {
184 SmallVector<hw::PortInfo> ports;
185 auto argNames = getArgNames();
186 auto argTypes = getFunctionType().getInputs();
187 for (unsigned i = 0, e = argTypes.size(); i < e; ++i) {
188 bool isInOut = false;
189 auto type = argTypes[i];
190
191 if (auto inout = dyn_cast<hw::InOutType>(type)) {
192 isInOut = true;
193 type = inout.getElementType();
194 }
195
196 auto direction = isInOut ? hw::ModulePort::Direction::InOut
197 : hw::ModulePort::Direction::Input;
198
199 ports.push_back(
200 {{argNames ? cast<StringAttr>((*argNames)[i])
201 : StringAttr::get(getContext(), Twine("input") + Twine(i)),
202 type, direction},
203 i,
204 {},
205 {}});
206 }
207
208 auto resultNames = getResNames();
209 auto resultTypes = getFunctionType().getResults();
210 for (unsigned i = 0, e = resultTypes.size(); i < e; ++i) {
211 ports.push_back({{resultNames ? cast<StringAttr>((*resultNames)[i])
212 : StringAttr::get(getContext(),
213 Twine("output") + Twine(i)),
214 resultTypes[i], hw::ModulePort::Direction::Output},
215 i,
216 {},
217 {}});
218 }
219 return ports;
220}
221
222//===----------------------------------------------------------------------===//
223// InstanceOp
224//===----------------------------------------------------------------------===//
225
226/// Lookup the machine for the symbol. This returns null on invalid IR.
227MachineOp InstanceOp::getMachineOp() {
228 auto module = (*this)->getParentOfType<ModuleOp>();
229 return module.lookupSymbol<MachineOp>(getMachine());
230}
231
232LogicalResult InstanceOp::verify() {
233 auto m = getMachineOp();
234 if (!m)
235 return emitError("cannot find machine definition '") << getMachine() << "'";
236
237 return success();
238}
239
240void InstanceOp::getAsmResultNames(
241 function_ref<void(Value, StringRef)> setNameFn) {
242 setNameFn(getInstance(), getName());
243}
244
245//===----------------------------------------------------------------------===//
246// TriggerOp
247//===----------------------------------------------------------------------===//
248
249template <typename OpType>
250static LogicalResult verifyCallerTypes(OpType op) {
251 auto machine = op.getMachineOp();
252 if (!machine)
253 return op.emitError("cannot find machine definition");
254
255 // Check operand types first.
256 if (failed(compareTypes(op.getLoc(), machine.getArgumentTypes(),
257 op.getInputs().getTypes()))) {
258 auto diag =
259 op.emitOpError("operand types must match the machine input types");
260 diag.attachNote(machine->getLoc()) << "original machine declared here";
261 return failure();
262 }
263
264 // Check result types.
265 if (failed(compareTypes(op.getLoc(), machine.getResultTypes(),
266 op.getOutputs().getTypes()))) {
267 auto diag =
268 op.emitOpError("result types must match the machine output types");
269 diag.attachNote(machine->getLoc()) << "original machine declared here";
270 return failure();
271 }
272
273 return success();
274}
275
276/// Lookup the machine for the symbol. This returns null on invalid IR.
277MachineOp TriggerOp::getMachineOp() {
278 auto instanceOp = getInstance().getDefiningOp<InstanceOp>();
279 if (!instanceOp)
280 return nullptr;
281
282 return instanceOp.getMachineOp();
283}
284
285LogicalResult TriggerOp::verify() { return verifyCallerTypes(*this); }
286
287//===----------------------------------------------------------------------===//
288// HWInstanceOp
289//===----------------------------------------------------------------------===//
290
291// InstanceOpInterface interface
292
293/// Lookup the machine for the symbol. This returns null on invalid IR.
294MachineOp HWInstanceOp::getMachineOp() {
295 auto module = (*this)->getParentOfType<ModuleOp>();
296 return module.lookupSymbol<MachineOp>(getMachine());
297}
298
299LogicalResult HWInstanceOp::verify() { return verifyCallerTypes(*this); }
300
301SmallVector<hw::PortInfo> HWInstanceOp::getPortList() {
302 return getMachineOp().getPortList();
303}
304
305/// Module name is the same as the machine name.
306StringRef HWInstanceOp::getModuleName() { return getMachine(); }
307FlatSymbolRefAttr HWInstanceOp::getModuleNameAttr() { return getMachineAttr(); }
308
309mlir::StringAttr HWInstanceOp::getInstanceNameAttr() { return getNameAttr(); }
310
311llvm::StringRef HWInstanceOp::getInstanceName() { return getName(); }
312
313//===----------------------------------------------------------------------===//
314// StateOp
315//===----------------------------------------------------------------------===//
316
317void StateOp::build(OpBuilder &builder, OperationState &state,
318 StringRef stateName) {
319 OpBuilder::InsertionGuard guard(builder);
320 Region *output = state.addRegion();
321 output->push_back(new Block());
322 builder.setInsertionPointToEnd(&output->back());
323 fsm::OutputOp::create(builder, state.location);
324 Region *transitions = state.addRegion();
325 transitions->push_back(new Block());
326 state.addAttribute("sym_name", builder.getStringAttr(stateName));
327}
328
329void StateOp::build(OpBuilder &builder, OperationState &state,
330 StringRef stateName, ValueRange outputs) {
331 OpBuilder::InsertionGuard guard(builder);
332 Region *output = state.addRegion();
333 output->push_back(new Block());
334 builder.setInsertionPointToEnd(&output->back());
335 fsm::OutputOp::create(builder, state.location, outputs);
336 Region *transitions = state.addRegion();
337 transitions->push_back(new Block());
338 state.addAttribute("sym_name", builder.getStringAttr(stateName));
339}
340
341SetVector<StateOp> StateOp::getNextStates() {
342 SmallVector<StateOp> nextStates;
343 llvm::transform(
344 getTransitions().getOps<TransitionOp>(),
345 std::inserter(nextStates, nextStates.begin()),
346 [](TransitionOp transition) { return transition.getNextStateOp(); });
347 return SetVector<StateOp>(nextStates.begin(), nextStates.end());
348}
349
350LogicalResult StateOp::canonicalize(StateOp op, PatternRewriter &rewriter) {
351 bool hasAlwaysTakenTransition = false;
352 SmallVector<TransitionOp, 4> transitionsToErase;
353 // Remove all transitions after an "always-taken" transition.
354 for (auto transition : op.getTransitions().getOps<TransitionOp>()) {
355 if (!hasAlwaysTakenTransition)
356 hasAlwaysTakenTransition = transition.isAlwaysTaken();
357 else
358 transitionsToErase.push_back(transition);
359 }
360
361 for (auto transition : transitionsToErase)
362 rewriter.eraseOp(transition);
363
364 return failure(transitionsToErase.empty());
365}
366
367LogicalResult StateOp::verify() {
368 MachineOp parent = getOperation()->getParentOfType<MachineOp>();
369
370 if (parent.getNumResults() != 0 && (getOutput().empty()))
371 return emitOpError("state must have a non-empty output region when the "
372 "machine has results.");
373
374 if (!getOutput().empty()) {
375 // Ensure that the output block has a single OutputOp terminator.
376 Block *outputBlock = &getOutput().front();
377 if (outputBlock->empty() || !isa<fsm::OutputOp>(outputBlock->back()))
378 return emitOpError("output block must have a single OutputOp terminator");
379 }
380
381 return success();
382}
383
384Block *StateOp::ensureOutput(OpBuilder &builder) {
385 if (getOutput().empty()) {
386 OpBuilder::InsertionGuard g(builder);
387 auto *block = new Block();
388 getOutput().push_back(block);
389 builder.setInsertionPointToStart(block);
390 fsm::OutputOp::create(builder, getLoc());
391 }
392 return &getOutput().front();
393}
394
395//===----------------------------------------------------------------------===//
396// OutputOp
397//===----------------------------------------------------------------------===//
398
399LogicalResult OutputOp::verify() {
400 if ((*this)->getParentRegion() ==
401 &(*this)->getParentOfType<StateOp>().getTransitions()) {
402 if (getNumOperands() != 0)
403 emitOpError("transitions region must not output any value");
404 return success();
405 }
406
407 // Verify that the result list of the machine and the operand list of the
408 // OutputOp line up.
409 auto machine = (*this)->getParentOfType<MachineOp>();
410 if (failed(
411 compareTypes(getLoc(), machine.getResultTypes(), getOperandTypes())))
412 return emitOpError("operand types must match the machine output types");
413
414 return success();
415}
416
417//===----------------------------------------------------------------------===//
418// TransitionOp
419//===----------------------------------------------------------------------===//
420
421void TransitionOp::build(OpBuilder &builder, OperationState &state,
422 StateOp nextState) {
423 build(builder, state, nextState.getName());
424}
425
426void TransitionOp::build(OpBuilder &builder, OperationState &state,
427 StringRef nextState,
428 llvm::function_ref<void()> guardCtor,
429 llvm::function_ref<void()> actionCtor) {
430 state.addAttribute("nextState",
431 FlatSymbolRefAttr::get(builder.getStringAttr(nextState)));
432 OpBuilder::InsertionGuard guard(builder);
433
434 Region *guardRegion = state.addRegion(); // guard
435 if (guardCtor) {
436 builder.createBlock(guardRegion);
437 guardCtor();
438 }
439
440 Region *actionRegion = state.addRegion(); // action
441 if (actionCtor) {
442 builder.createBlock(actionRegion);
443 actionCtor();
444 }
445}
446
447Block *TransitionOp::ensureGuard(OpBuilder &builder) {
448 if (getGuard().empty()) {
449 OpBuilder::InsertionGuard g(builder);
450 auto *block = new Block();
451 getGuard().push_back(block);
452 builder.setInsertionPointToStart(block);
453 fsm::ReturnOp::create(builder, getLoc());
454 }
455 return &getGuard().front();
456}
457
458Block *TransitionOp::ensureAction(OpBuilder &builder) {
459 if (getAction().empty())
460 getAction().push_back(new Block());
461 return &getAction().front();
462}
463
464/// Lookup the next state for the symbol. This returns null on invalid IR.
465StateOp TransitionOp::getNextStateOp() {
466 auto machineOp = (*this)->getParentOfType<MachineOp>();
467 if (!machineOp)
468 return nullptr;
469
470 return machineOp.lookupSymbol<StateOp>(getNextState());
471}
472
473bool TransitionOp::isAlwaysTaken() {
474 if (!hasGuard())
475 return true;
476
477 auto guardReturn = getGuardReturn();
478 if (guardReturn.getNumOperands() == 0)
479 return true;
480
481 if (auto constantOp =
482 guardReturn.getOperand().getDefiningOp<mlir::arith::ConstantOp>())
483 return cast<BoolAttr>(constantOp.getValue()).getValue();
484
485 return false;
486}
487
488LogicalResult TransitionOp::canonicalize(TransitionOp op,
489 PatternRewriter &rewriter) {
490 if (op.hasGuard()) {
491 auto guardReturn = op.getGuardReturn();
492 if (guardReturn.getNumOperands() == 1)
493 if (auto constantOp = guardReturn.getOperand()
494 .getDefiningOp<mlir::arith::ConstantOp>()) {
495 // Simplify when the guard region returns a constant value.
496 if (cast<BoolAttr>(constantOp.getValue()).getValue()) {
497 // Replace the original return op with a new one without any operands
498 // if the constant is TRUE.
499 rewriter.setInsertionPoint(guardReturn);
500 fsm::ReturnOp::create(rewriter, guardReturn.getLoc());
501 rewriter.eraseOp(guardReturn);
502 } else {
503 // Erase the whole transition op if the constant is FALSE, because the
504 // transition will never be taken.
505 rewriter.eraseOp(op);
506 }
507 return success();
508 }
509 }
510
511 return failure();
512}
513
514LogicalResult TransitionOp::verify() {
515 if (!getNextStateOp())
516 return emitOpError("cannot find the definition of the next state `")
517 << getNextState() << "`";
518
519 // Verify the action region, if present.
520 if (hasGuard()) {
521 if (getGuard().front().empty() ||
522 !isa_and_nonnull<fsm::ReturnOp>(&getGuard().front().back()))
523 return emitOpError("guard region must terminate with a ReturnOp");
524 }
525
526 // Verify the transition is located in the correct region.
527 if ((*this)->getParentRegion() != &getCurrentState().getTransitions())
528 return emitOpError("must only be located in the transitions region");
529
530 return success();
531}
532
533//===----------------------------------------------------------------------===//
534// VariableOp
535//===----------------------------------------------------------------------===//
536
537void VariableOp::getAsmResultNames(
538 function_ref<void(Value, StringRef)> setNameFn) {
539 setNameFn(getResult(), getName());
540}
541
542//===----------------------------------------------------------------------===//
543// ReturnOp
544//===----------------------------------------------------------------------===//
545
546void ReturnOp::setOperand(Value value) {
547 if (getOperand())
548 getOperation()->setOperand(0, value);
549 else
550 getOperation()->insertOperands(0, {value});
551}
552
553//===----------------------------------------------------------------------===//
554// UpdateOp
555//===----------------------------------------------------------------------===//
556
557/// Get the targeted variable operation. This returns null on invalid IR.
558VariableOp UpdateOp::getVariableOp() {
559 return getVariable().getDefiningOp<VariableOp>();
560}
561
562LogicalResult UpdateOp::verify() {
563 if (!getVariable())
564 return emitOpError("destination is not a variable operation");
565
566 if (!(*this)->getParentOfType<TransitionOp>().getAction().isAncestor(
567 (*this)->getParentRegion()))
568 return emitOpError("must only be located in the action region");
569
570 auto transition = (*this)->getParentOfType<TransitionOp>();
571 for (auto otherUpdateOp : transition.getAction().getOps<UpdateOp>()) {
572 if (otherUpdateOp == *this)
573 continue;
574 if (otherUpdateOp.getVariable() == getVariable())
575 return otherUpdateOp.emitOpError(
576 "multiple updates to the same variable within a single action region "
577 "is disallowed");
578 }
579
580 return success();
581}
582
583//===----------------------------------------------------------------------===//
584// TableGen generated logic
585//===----------------------------------------------------------------------===//
586
587// Provide the autogenerated implementation guts for the Op classes.
588#define GET_OP_CLASSES
589#include "circt/Dialect/FSM/FSM.cpp.inc"
590#undef GET_OP_CLASSES
591
592#include "circt/Dialect/FSM/FSMDialect.cpp.inc"
assert(baseType &&"element must be base type")
static LogicalResult verifyCallerTypes(OpType op)
Definition FSMOps.cpp:250
static LogicalResult compareTypes(Location loc, TypeRange rangeA, TypeRange rangeB)
Definition FSMOps.cpp:123
@ Output
Definition HW.h:42
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static StringAttr getResName(Operation *op, size_t idx)
static StringAttr getArgName(Operation *op, size_t idx)
static InstancePath empty
type(self)
Definition fsm.py:58
create(*operands)
Definition fsm.py:160
Direction get(bool isOutput)
Returns an output direction if isOutput is true, otherwise returns an input direction.
Definition CalyxOps.cpp:55
Direction
The direction of a Component or Cell port.
Definition CalyxOps.h:76
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition fsm.py:1
Definition hw.py:1
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.