CIRCT 22.0.0git
Loading...
Searching...
No Matches
FSMToCore.cpp
Go to the documentation of this file.
//===- FSMToCore.cpp - Convert FSM to core dialects -----------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
17#include "mlir/IR/BuiltinTypes.h"
18#include "mlir/IR/MLIRContext.h"
19#include "mlir/IR/Value.h"
20#include "mlir/Pass/Pass.h"
21#include "mlir/Transforms/RegionUtils.h"
22#include "llvm/ADT/DenseMap.h"
23#include "llvm/ADT/SmallVector.h"
24
25namespace circt {
26#define GEN_PASS_DEF_CONVERTFSMTOCORE
27#include "circt/Conversion/Passes.h.inc"
28} // namespace circt
29
30using namespace mlir;
31using namespace circt;
32using namespace fsm;
33
34namespace {
35struct ClkRstIdxs {
36 size_t clockIdx;
37 size_t resetIdx;
38};
39
40/// Clones constants implicitly captured by the region, into the region.
41static void cloneConstantsIntoRegion(Region &region, OpBuilder &builder) {
42 // Values implicitly captured by the region.
43 llvm::SetVector<Value> captures;
44 getUsedValuesDefinedAbove(region, region, captures);
45
46 OpBuilder::InsertionGuard guard(builder);
47 builder.setInsertionPointToStart(&region.front());
48
49 // Clone ConstantLike operations into the region.
50 for (auto &capture : captures) {
51 Operation *op = capture.getDefiningOp();
52 if (!op || !op->hasTrait<OpTrait::ConstantLike>())
53 continue;
54
55 Operation *cloned = builder.clone(*op);
56 for (auto [orig, replacement] :
57 llvm::zip(op->getResults(), cloned->getResults()))
58 replaceAllUsesInRegionWith(orig, replacement, region);
59 }
60}
61
62static ClkRstIdxs getMachinePortInfo(SmallVectorImpl<hw::PortInfo> &ports,
63 MachineOp machine, OpBuilder &b) {
64 // Get the port info of the machine inputs and outputs.
65 machine.getHWPortInfo(ports);
66 ClkRstIdxs specialPorts;
67
68 // Add clock port.
69 hw::PortInfo clock;
70 clock.name = b.getStringAttr("clk");
71 clock.dir = hw::ModulePort::Direction::Input;
72 clock.type = seq::ClockType::get(b.getContext());
73 clock.argNum = machine.getNumArguments();
74 ports.push_back(clock);
75 specialPorts.clockIdx = clock.argNum;
76
77 // Add reset port.
78 hw::PortInfo reset;
79 reset.name = b.getStringAttr("rst");
80 reset.dir = hw::ModulePort::Direction::Input;
81 reset.type = b.getI1Type();
82 reset.argNum = machine.getNumArguments() + 1;
83 ports.push_back(reset);
84 specialPorts.resetIdx = reset.argNum;
85
86 return specialPorts;
87}
88} // namespace
89
90namespace {
91
92class StateEncoding {
93
94public:
95 StateEncoding(OpBuilder &b, MachineOp machine, hw::HWModuleOp hwModule);
96
97 /// Get the encoded value for a state.
98 Value encode(StateOp state);
99 /// Get the state corresponding to an encoded value.
100 StateOp decode(Value value);
101
102 /// Returns the type which encodes the state values.
103 Type getStateType() { return stateType; }
104
105protected:
106 /// Creates a constant value in the module for the given encoded state
107 /// and records the state value in the mappings.
108 void setEncoding(StateOp state, Value v);
109
110 /// A mapping between a StateOp and its corresponding encoded value.
112
113 /// A mapping between an encoded value and its corresponding StateOp.
115
116 Type stateType;
117
118 OpBuilder &b;
119 MachineOp machine;
120 hw::HWModuleOp hwModule;
121};
122
123StateEncoding::StateEncoding(OpBuilder &b, MachineOp machine,
124 hw::HWModuleOp hwModule)
125 : b(b), machine(machine), hwModule(hwModule) {
126 Location loc = machine.getLoc();
127
128 OpBuilder::InsertionGuard guard(b);
129 b.setInsertionPointToStart(&hwModule.getBodyRegion().front());
130 // If stateType is explicitly provided, use this - otherwise, calculate the
131 // minimum int size that can represent all states
132 if (machine->getAttr("stateType")) {
133 // We already checked that a static cast is valid
134 stateType = cast<TypeAttr>(machine->getAttr("stateType")).getValue();
135 } else {
136 int numStates = std::distance(machine.getBody().getOps<StateOp>().begin(),
137 machine.getBody().getOps<StateOp>().end());
138 stateType =
139 IntegerType::get(machine.getContext(), llvm::Log2_64_Ceil(numStates));
140 }
141 int stateValue = 0;
142 // And create values for the states
143 b.setInsertionPointToStart(&hwModule.getBody().front());
144 for (auto state : machine.getBody().getOps<StateOp>()) {
145 auto constantOp = hw::ConstantOp::create(b, loc, stateType, stateValue++);
146 setEncoding(state, constantOp);
147 }
148}
149
150// Get the encoded value for a state.
151Value StateEncoding::encode(StateOp state) {
152 auto it = stateToValue.find(state);
153 assert(it != stateToValue.end() && "state not found");
154 return it->second;
155}
156
157void StateEncoding::setEncoding(StateOp state, Value v) {
158 assert(stateToValue.find(state) == stateToValue.end() &&
159 "state already encoded");
160 stateToValue[state] = v;
161 valueToState[v] = state;
162}
163} // namespace
164
165namespace {
166class MachineOpConverter {
167public:
168 MachineOpConverter(OpBuilder &builder, MachineOp machineOp)
169 : machineOp(machineOp), b(builder),
170 bb(BackedgeBuilder(builder, machineOp->getLoc())) {}
171
172 /// Converts the machine op to a hardware module.
173 /// 1. Creates a HWModuleOp for the machine op, with the same I/O as the FSM +
174 /// clk/reset ports.
175 /// 2. Creates a state register + encodings for the states visible in the
176 /// machine.
177 /// 3. Iterates over all states in the machine
178 /// 3.1. Moves all `comb` logic into the body of the HW module
179 /// 3.2. Extends the output logic mux chains with cases for this state
180 /// 3.3. Iterates over the transitions of the state
181 /// 3.3.1. Moves all `comb` logic in the transition guard/action regions to
182 /// the body of the HW module.
183 /// 3.3.2. Extends the next state mux chain with an optionally guarded case
184 /// for this transition.
185 /// 4. Connects the state and variable mux chain outputs to the corresponding
186 /// register inputs.
187 LogicalResult dispatch();
188
189private:
190 /// Converts a StateOp within this machine, and returns the value
191 /// corresponding to the next-state output of the op.
192 LogicalResult convertState(StateOp state);
193
194 /// Converts the outgoing transitions of a state and returns the value
195 /// corresponding to the next-state output of the op.
196 /// Transitions are priority encoded in the order which they appear in the
197 /// state transition region.
198 FailureOr<Value> convertTransitions(StateOp currentState,
199 ArrayRef<TransitionOp> transitions);
200
201 /// Moves operations from 'block' into module scope, failing if any op were
202 /// deemed illegal. Returns the final op in the block if the op was a
203 /// terminator. An optional 'exclude' filer can be provided to dynamically
204 /// exclude some ops from being moved.
205 FailureOr<Operation *>
206 moveOps(Block *block,
207 llvm::function_ref<bool(Operation *)> exclude = nullptr);
208
209 DenseMap<Value, std::string> backedgeMap;
210
211 /// A handle to the state encoder for this machine.
212 std::unique_ptr<StateEncoding> encoding;
213
214 /// A mapping from a fsm.variable op to its register.
215 llvm::MapVector<VariableOp, seq::CompRegOp> variableToRegister;
216
217 /// A mapping from a fsm.variable op to the output of the mux chain that
218 /// calculates its next value.
219 llvm::MapVector<VariableOp, mlir::Value> variableToMuxChainOut;
220
221 /// Mapping from a hw port to
222 llvm::SmallVector<mlir::Value> outputMuxChainOuts;
223
224 /// A handle to the MachineOp being converted.
225 MachineOp machineOp;
226
227 /// A handle to the HW ModuleOp being created.
228 hw::HWModuleOp hwModuleOp;
229
230 /// A handle to the state register of the machine.
231 seq::CompRegOp stateReg;
232
233 OpBuilder &b;
234
235 mlir::Value stateMuxChainOut;
236
238};
239} // namespace
240
241LogicalResult MachineOpConverter::dispatch() {
242 b.setInsertionPoint(machineOp);
243 auto loc = machineOp.getLoc();
244
245 // Clone all referenced constants into the machine body - constants may have
246 // been moved to the machine parent due to the lack of IsolationFromAbove.
247 cloneConstantsIntoRegion(machineOp.getBody(), b);
248
249 // 1) Get the port info of the machine and create a new HW module for it.
250 SmallVector<hw::PortInfo, 16> ports;
251 auto clkRstIdxs = getMachinePortInfo(ports, machineOp, b);
252 hwModuleOp =
253 hw::HWModuleOp::create(b, loc, machineOp.getSymNameAttr(), ports);
254 b.setInsertionPointToStart(hwModuleOp.getBodyBlock());
255
256 // Replace all uses of the machine arguments with the arguments of the
257 // newly created HW module.
258 for (auto [machineArg, hwModuleArg] :
259 llvm::zip(machineOp.getArguments(),
260 hwModuleOp.getBodyBlock()->getArguments())) {
261 machineArg.replaceAllUsesWith(hwModuleArg);
262 }
263
264 auto clock = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.clockIdx);
265 auto reset = hwModuleOp.getBodyBlock()->getArgument(clkRstIdxs.resetIdx);
266
267 // 2) Build state and variable registers.
268
269 encoding = std::make_unique<StateEncoding>(b, machineOp, hwModuleOp);
270 auto stateType = encoding->getStateType();
271
272 Backedge nextStateWire = bb.get(stateType);
273
274 auto initialStateOp = machineOp.getInitialStateOp();
275 stateReg = seq::CompRegOp::create(
276 b, loc, nextStateWire, clock, reset,
277 /*reset value=*/encoding->encode(initialStateOp), "state_reg",
278 /*powerOn value=*/
279 seq::createConstantInitialValue(
280 b, encoding->encode(initialStateOp).getDefiningOp()));
281 stateMuxChainOut = stateReg;
282
283 llvm::DenseMap<VariableOp, Backedge> variableNextStateWires;
284 for (auto variableOp : machineOp.front().getOps<fsm::VariableOp>()) {
285 auto initValueAttr = cast<IntegerAttr>(variableOp.getInitValueAttr());
286 Type varType = variableOp.getType();
287 auto varLoc = variableOp.getLoc();
288 auto nextVariableStateWire = bb.get(varType);
289 auto varResetVal = hw::ConstantOp::create(b, varLoc, initValueAttr);
290 auto variableReg = seq::CompRegOp::create(
291 b, varLoc, nextVariableStateWire, clock, reset, varResetVal,
292 b.getStringAttr(variableOp.getName()),
293 seq::createConstantInitialValue(b, varResetVal));
294 variableToRegister[variableOp] = variableReg;
295 variableNextStateWires[variableOp] = nextVariableStateWire;
296 variableToMuxChainOut[variableOp] = variableReg;
297 // Postpone value replacement until all logic has been created.
298 // fsm::UpdateOp's require their target variables to refer to a
299 // fsm::VariableOp - if this is not the case, they'll throw an assert.
300 }
301
302 // Move any operations at the machine-level scope, excluding state ops,
303 // which are handled separately.
304 if (failed(moveOps(&machineOp.front(), [](Operation *op) {
305 return isa<fsm::StateOp, fsm::VariableOp>(op);
306 })))
307 return failure();
308
309 // Begin mux chains for outputs
310 auto hwPortList = hwModuleOp.getPortList();
311 llvm::SmallVector<Backedge> outputBackedges;
312 for (auto &port : hwPortList)
313 if (port.isOutput())
314 outputMuxChainOuts.push_back(Value());
315
316 // 3) Convert states and record their next-state value assignments.
317 for (auto state : machineOp.getBody().getOps<StateOp>()) {
318 auto stateConvRes = convertState(state);
319 if (failed(stateConvRes))
320 return failure();
321 }
322
323 // 4) Set the input of the state and variable registers to the output of their
324 // mux chains.
325 nextStateWire.setValue(stateMuxChainOut);
326 for (auto [variable, muxChainOut] : variableToMuxChainOut) {
327 variableNextStateWires[variable].setValue(muxChainOut);
328 }
329
330 // Replace variable values with their register counterparts.
331 for (auto [variableOp, variableReg] : variableToRegister)
332 variableOp.getResult().replaceAllUsesWith(variableReg);
333
334 // Cast to values to appease builder
335 llvm::SmallVector<Value> outputValues;
336 for (auto backedge : outputMuxChainOuts) {
337 outputValues.push_back(backedge);
338 }
339 auto *oldOutputOp = hwModuleOp.getBodyBlock()->getTerminator();
340 b.setInsertionPointToEnd(oldOutputOp->getBlock());
341 oldOutputOp->erase();
342 hw::OutputOp::create(b, loc, outputValues);
343 machineOp.erase();
344 return success();
345}
346
347FailureOr<Value>
348MachineOpConverter::convertTransitions( // NOLINT(misc-no-recursion)
349 StateOp currentState, ArrayRef<TransitionOp> transitions) {
350 Value nextState;
351 llvm::MapVector<fsm::VariableOp, Value> variableUpdates;
352 auto stateCmp =
353 comb::ICmpOp::create(b, machineOp.getLoc(), comb::ICmpPredicate::eq,
354 stateReg, encoding->encode(currentState));
355 if (transitions.empty()) {
356 // Base case
357 // State: transition to the current state.
358 nextState = encoding->encode(currentState);
359 } else {
360 // Recursive case - transition to a named state.
361 auto transition = cast<fsm::TransitionOp>(transitions.front());
362 nextState = encoding->encode(transition.getNextStateOp());
363 mlir::Value varUpdateCondition;
364 // Action conversion
365 if (transition.hasAction()) {
366 // Move any ops from the action region to the general scope, excluding
367 // variable update ops.
368 auto actionMoveOpsRes =
369 moveOps(&transition.getAction().front(),
370 [](Operation *op) { return isa<fsm::UpdateOp>(op); });
371 if (failed(actionMoveOpsRes))
372 return failure();
373
374 // Gather variable updates during the action.
375 for (auto updateOp : transition.getAction().getOps<fsm::UpdateOp>()) {
376 VariableOp variableOp = updateOp.getVariableOp();
377 variableUpdates[variableOp] = updateOp.getValue();
378 }
379 }
380
381 // Guard conversion
382 if (transition.hasGuard()) {
383 // Not always taken; recurse and mux between the targeted next state and
384 // the recursion result, selecting based on the provided guard.
385 auto guardOpRes = moveOps(&transition.getGuard().front());
386 if (failed(guardOpRes))
387 return failure();
388
389 auto guardOp = cast<ReturnOp>(*guardOpRes);
390 assert(guardOp && "guard should be defined");
391 auto guard = guardOp.getOperand();
392 auto otherNextState =
393 convertTransitions(currentState, transitions.drop_front());
394 if (failed(otherNextState))
395 return failure();
396 comb::MuxOp nextStateMux = comb::MuxOp::create(
397 b, transition.getLoc(), guard, nextState, *otherNextState, false);
398 nextState = nextStateMux;
399 varUpdateCondition =
400 comb::AndOp::create(b, machineOp.getLoc(), guard, stateCmp);
401 } else
402 varUpdateCondition = stateCmp;
403 // Handle variable updates
404 for (auto variableUpdate : variableUpdates) {
405 auto muxChainOut = variableToMuxChainOut[variableUpdate.first];
406 auto newMuxChainOut =
407 comb::MuxOp::create(b, machineOp.getLoc(), varUpdateCondition,
408 variableUpdate.second, muxChainOut, false);
409 variableToMuxChainOut[variableUpdate.first] = newMuxChainOut;
410 }
411 }
412
413 stateMuxChainOut = comb::MuxOp::create(b, machineOp.getLoc(), stateCmp,
414 nextState, stateMuxChainOut);
415 assert(nextState && "next state should be defined");
416 return nextState;
417}
418
419FailureOr<Operation *>
420MachineOpConverter::moveOps(Block *block,
421 llvm::function_ref<bool(Operation *)> exclude) {
422 for (auto &op : llvm::make_early_inc_range(*block)) {
423 if (!isa<comb::CombDialect, hw::HWDialect, fsm::FSMDialect>(
424 op.getDialect())) {
425 // Avoid giving unrelated errors about unbound backedges.
426 bb.abandon();
427 return op.emitOpError()
428 << "is unsupported (op from the "
429 << op.getDialect()->getNamespace() << " dialect).";
430 }
431 if (exclude && exclude(&op))
432 continue;
433
434 if (op.hasTrait<OpTrait::IsTerminator>())
435 return &op;
436
437 op.moveBefore(hwModuleOp.getBodyBlock(), b.getInsertionPoint());
438 }
439 return nullptr;
440}
441
442LogicalResult MachineOpConverter::convertState(StateOp state) {
443 // 3.1) Convert the output region by moving the operations into the module
444 // scope and gathering the operands of the output op.
445 if (!state.getOutput().empty()) {
446 auto outputOpRes = moveOps(&state.getOutput().front());
447 if (failed(outputOpRes))
448 return failure();
449
450 // 3.2) Extend the output mux chains with a comparison on this state
451 OutputOp outputOp = cast<fsm::OutputOp>(*outputOpRes);
452 auto stateCmp =
453 comb::ICmpOp::create(b, machineOp.getLoc(), comb::ICmpPredicate::eq,
454 stateReg, encoding->encode(state));
455
456 for (auto [i, operand] : llvm::enumerate(outputOp.getOperands())) {
457 auto muxChainOut = outputMuxChainOuts[i];
458 // If this is the first node in the mux chain, just use this value
459 // directly as the default
460 if (!muxChainOut) {
461 outputMuxChainOuts[i] = operand;
462 continue;
463 }
464 auto muxOp = comb::MuxOp::create(b, machineOp.getLoc(), stateCmp, operand,
465 muxChainOut);
466 outputMuxChainOuts[i] = muxOp;
467 }
468 }
469
470 auto transitions = llvm::SmallVector<TransitionOp>(
471 state.getTransitions().getOps<TransitionOp>());
472 // 3.3) Convert the transitions and add a case to the next-state mux
473 // chain for each
474 auto nextStateRes = convertTransitions(state, transitions);
475 if (failed(nextStateRes))
476 return failure();
477 return success();
478}
479namespace {
480struct FSMToCorePass : public circt::impl::ConvertFSMToCoreBase<FSMToCorePass> {
481 void runOnOperation() override;
482};
483
484void FSMToCorePass::runOnOperation() {
485 auto module = getOperation();
486 auto b = OpBuilder(module);
487 SmallVector<Operation *, 16> opToErase;
488
489 b.setInsertionPointToStart(module.getBody());
490 // Traverse all machines and convert.
491 for (auto machine : llvm::make_early_inc_range(module.getOps<MachineOp>())) {
492
493 // Check validity of the FSM while we can still easily error out
494 if (machine->getAttr("stateType")) {
495 auto stateType = dyn_cast<TypeAttr>(machine->getAttr("stateType"));
496 if (!stateType) {
497 machine->emitError("stateType attribute does not name a type");
498 signalPassFailure();
499 return;
500 }
501 if (!isa<IntegerType>(stateType.getValue())) {
502 machine->emitError("stateType attribute must name an integer type");
503 signalPassFailure();
504 return;
505 }
506 }
507 for (auto variableOp : machine.front().getOps<fsm::VariableOp>()) {
508 if (!isa<IntegerType>(variableOp.getType())) {
509 variableOp.emitOpError(
510 "only integer variables are currently supported");
511 signalPassFailure();
512 return;
513 }
514 }
515
516 MachineOpConverter converter(b, machine);
517
518 if (failed(converter.dispatch())) {
519 signalPassFailure();
520 return;
521 }
522 }
523
524 // Traverse all machine instances and convert to hw instances.
525 llvm::SmallVector<HWInstanceOp> instances;
526 module.walk([&](HWInstanceOp instance) { instances.push_back(instance); });
527 for (auto instance : instances) {
528 auto fsmHWModule =
529 module.lookupSymbol<hw::HWModuleOp>(instance.getMachine());
530 assert(fsmHWModule &&
531 "FSM machine should have been converted to a hw.module");
532
533 b.setInsertionPoint(instance);
534 llvm::SmallVector<Value, 4> operands;
535 llvm::transform(instance.getOperands(), std::back_inserter(operands),
536 [&](auto operand) { return operand; });
537 auto hwInstance = hw::InstanceOp::create(
538 b, instance.getLoc(), fsmHWModule, b.getStringAttr(instance.getName()),
539 operands, nullptr);
540 instance.replaceAllUsesWith(hwInstance);
541 instance.erase();
542 }
543}
544
545} // end anonymous namespace
546
547std::unique_ptr<mlir::Pass> circt::createConvertFSMToCorePass() {
548 return std::make_unique<FSMToCorePass>();
549}
assert(baseType &&"element must be base type")
static ClkRstIdxs getMachinePortInfo(SmallVectorImpl< hw::PortInfo > &ports, MachineOp machine, OpBuilder &b)
Definition FSMToSV.cpp:40
static void cloneConstantsIntoRegion(Region &region, OpBuilder &builder)
Definition FSMToSV.cpp:68
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:216
static Block * getBodyBlock(FModuleLike mod)
Instantiate one of these and use it to build typed backedges.
Backedge is a wrapper class around a Value.
void setValue(mlir::Value)
create(data_type, value)
Definition hw.py:433
create(cls, result_type, reset=None, reset_value=None, name=None, sym_name=None, **kwargs)
Definition seq.py:157
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToCorePass()
Definition fsm.py:1
mlir::Type type
Definition HWTypes.h:31
mlir::StringAttr name
Definition HWTypes.h:30
This holds the name, type, direction of a module's ports.
size_t argNum
This is the argument index or the result index depending on the direction.