13#include "mlir/Dialect/SMT/IR/SMTOps.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/Pass/Pass.h"
19#define GEN_PASS_DEF_CONVERTFSMTOSMT
20#include "circt/Conversion/Passes.h.inc"
33struct LoweringConfig {
36 bool withTime =
false;
38 unsigned timeWidth = 8;
41class MachineOpConverter {
43 MachineOpConverter(OpBuilder &builder,
MachineOp machineOp,
44 const LoweringConfig &cfg)
45 : machineOp(machineOp),
b(builder), cfg(cfg) {}
46 LogicalResult dispatch();
49 struct PendingAssertion {
57 std::optional<Region *> guard;
58 std::optional<Region *> action;
59 std::optional<Region *> output;
64 SmallVector<std::pair<int, Value>> stateFunctions;
67 SmallVector<StringRef> &states,
69 StringRef nextState = t.getNextState();
70 Transition tr = {from, insertStates(states, nextState), std::nullopt,
71 std::nullopt, std::nullopt};
73 tr.guard = &t.getGuard();
75 tr.action = &t.getAction();
79 static int insertStates(SmallVector<StringRef> &states, llvm::StringRef st) {
80 for (
auto [
id, s] :
llvm::enumerate(states))
84 return states.size() - 1;
88 getOutputRegion(
const SmallVector<std::pair<Region *, int>> &outputOfStateId,
90 for (
auto oid : outputOfStateId)
91 if (stateId == oid.second)
93 llvm_unreachable(
"State could not be found.");
96 SmallVector<std::pair<mlir::Value, mlir::Value>>
97 mapSmtToFsm(Location loc, OpBuilder b, SmallVector<Value> smtValues,
98 int numArgs,
int numOut, SmallVector<Value> fsmArgs,
99 SmallVector<Value> fsmVars) {
102 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast;
103 for (
auto [idx, fq] :
llvm::enumerate(smtValues)) {
104 if (
int(idx) < numArgs) {
105 auto convCast = UnrealizedConversionCastOp::create(
106 b, loc, fsmArgs[idx].getType(), fq);
107 fsmToCast.push_back({fsmArgs[idx], convCast->getResult(0)});
108 }
else if (numArgs + numOut <=
int(idx)) {
109 if (cfg.withTime && idx == smtValues.size() - 1)
111 auto convCast = UnrealizedConversionCastOp::create(
112 b, loc, fsmVars[idx - numArgs - numOut].getType(), fq);
114 {fsmVars[idx - numArgs - numOut], convCast->getResult(0)});
121 createIRMapping(
const SmallVector<std::pair<Value, Value>> &fsmToCast,
122 const IRMapping &constMapper) {
124 for (
auto couple : fsmToCast)
125 mapping.map(couple.first, couple.second);
126 for (
auto &pair : constMapper.getValueMap())
127 mapping.map(pair.first, pair.second);
132 Value bv1toSmtBool(OpBuilder &b, Location loc, Value i1Value) {
133 auto castVal = UnrealizedConversionCastOp::create(
134 b, loc,
b.getType<smt::BitVectorType>(1), i1Value);
135 return smt::EqOp::create(b, loc, castVal->getResult(0),
136 smt::BVConstantOp::create(b, loc, 1, 1));
139 Value getStateFunction(
int stateId) {
140 for (
auto sf : stateFunctions)
141 if (sf.first == stateId)
143 llvm_unreachable(
"State function could not be found.");
151LogicalResult MachineOpConverter::dispatch() {
152 b.setInsertionPoint(machineOp);
153 auto loc = machineOp.getLoc();
154 auto machineArgs = machineOp.getArguments();
156 mlir::SmallVector<mlir::Value> fsmVars;
157 mlir::SmallVector<mlir::Value> fsmArgs;
158 mlir::SmallVector<mlir::Type> quantifiedTypes;
161 ValueRange valueRange;
164 for (
auto &op : machineOp.front().getOperations())
165 if (!isa<
fsm::FSMDialect>(op.getDialect()) && !isa<
hw::ConstantOp>(op))
167 "Only fsm operations and hw.constants are allowed in the "
168 "top level of the fsm.machine op.");
172 for (
auto stateOp : machineOp.front().getOps<
fsm::
StateOp>()) {
173 if (!stateOp.getOutput().empty())
174 for (
auto &op : stateOp.getOutput().front().getOperations())
175 if (!isa<
fsm::FSMDialect,
comb::CombDialect,
hw::HWDialect>(
177 !isa<
verif::AssertOp>(op))
179 "Only fsm, comb, hw, and verif.assert operations are handled in "
180 "the output region of a state.");
181 if (!stateOp.getTransitions().empty())
185 for (
auto &op : t.getGuard().front().getOperations())
186 if (!isa<
fsm::FSMDialect,
comb::CombDialect,
hw::HWDialect>(
188 !isa<
verif::AssertOp>(op))
189 return op.emitError(
"Only fsm, comb, hw, and verif.assert "
190 "operations are handled in the guard "
191 "region of a transition.");
193 for (
auto &op : t.getAction().front().getOperations())
194 if (!isa<
fsm::FSMDialect,
comb::CombDialect,
hw::HWDialect>(
196 !isa<
verif::AssertOp>(op))
197 return op.emitError(
"Only fsm, comb, hw, and verif.assert "
198 "operations are handled in the action "
199 "region of a transition.");
203 auto solver = smt::SolverOp::create(b, loc, typeRange, valueRange);
204 solver.getBodyRegion().emplaceBlock();
205 b.setInsertionPointToStart(solver.getBody());
208 for (
auto a : machineArgs) {
209 fsmArgs.push_back(a);
210 if (!isa<IntegerType>(
a.getType()))
211 return solver.emitError(
"Only integer arguments are supported in FSMs.");
212 quantifiedTypes.push_back(
213 b.getType<smt::BitVectorType>(
a.getType().getIntOrFloatBitWidth()));
215 size_t numArgs = fsmArgs.size();
218 if (!machineOp.getResultTypes().empty()) {
219 for (
auto t : machineOp.getResultTypes()) {
220 if (!isa<IntegerType>(t))
221 return solver.emitError(
"Only integer outputs are supported in FSMs.");
222 quantifiedTypes.push_back(
223 b.getType<smt::BitVectorType>(t.getIntOrFloatBitWidth()));
227 size_t numOut = quantifiedTypes.size() - numArgs;
230 SmallVector<llvm::APInt> varInitValues;
231 for (
auto v : machineOp.front().getOps<
fsm::VariableOp>()) {
232 if (!isa<IntegerType>(v.getType()))
233 return v.emitError(
"Only integer variables are supported in FSMs.");
234 auto intAttr = dyn_cast<IntegerAttr>(v.getInitValueAttr());
235 varInitValues.push_back(intAttr.getValue());
236 quantifiedTypes.push_back(
237 b.getType<smt::BitVectorType>(v.getType().getIntOrFloatBitWidth()));
238 fsmVars.push_back(v.getResult());
242 IRMapping constMapper;
243 for (
auto constOp : machineOp.front().getOps<
hw::ConstantOp>())
244 b.clone(*constOp, constMapper);
248 quantifiedTypes.push_back(
b.getType<smt::BitVectorType>(cfg.timeWidth));
250 size_t numVars = varInitValues.size();
253 SmallVector<MachineOpConverter::Transition> transitions;
255 SmallVector<StringRef> states;
257 SmallVector<std::pair<Region *, int>> outputOfStateId;
260 StringRef initialState = machineOp.getInitialState();
261 insertStates(states, initialState);
265 SmallVector<Type> stateFunDomain;
266 for (
auto i = numArgs; i < quantifiedTypes.size(); i++)
267 stateFunDomain.push_back(quantifiedTypes[i]);
270 if (stateFunDomain.empty())
271 return solver.emitError(
"At least one variable or output is required.");
275 for (
auto stateOp : machineOp.front().getOps<
fsm::
StateOp>()) {
276 int idx = insertStates(states, stateOp.getName());
277 mlir::StringAttr funName =
278 b.getStringAttr((
"F_" + stateOp.getName().str()));
279 auto rangeTy =
b.getType<smt::BoolType>();
280 auto funTy =
b.getType<smt::SMTFuncType>(stateFunDomain, rangeTy);
281 auto acFun = smt::DeclareFunOp::create(b, loc, funTy, funName);
282 stateFunctions.push_back({idx, acFun});
283 outputOfStateId.push_back({&stateOp.getOutput(), idx});
290 for (
auto stateOp : machineOp.front().getOps<
fsm::
StateOp>()) {
291 if (stateOp.getTransitions().empty())
293 StringRef stateName = stateOp.getName();
294 auto fromState = insertStates(states, stateName);
297 auto t = getTransitionRegions(tr, fromState, states, loc);
299 t.output = getOutputRegion(outputOfStateId, t.to);
301 t.output = std::nullopt;
302 transitions.push_back(t);
306 SmallVector<PendingAssertion> assertions;
311 auto initialAssertion = smt::ForallOp::create(
312 b, loc, quantifiedTypes,
313 [&](OpBuilder &b, Location loc,
314 const SmallVector<Value> &forallQuantified) -> Value {
317 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast =
318 mapSmtToFsm(loc, b, forallQuantified, numArgs, numOut, fsmArgs,
320 auto *initOutputReg = getOutputRegion(outputOfStateId, 0);
321 SmallVector<Value> castOutValues;
324 for (
auto [
id, couple] :
llvm::enumerate(fsmToCast)) {
325 if (numArgs + numOut <=
id &&
id < numArgs + numOut + numVars) {
332 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
334 SmallVector<mlir::Value> combOutputValues;
336 if (!initOutputReg->empty()) {
341 for (
auto &op : initOutputReg->front()) {
342 if (isa<verif::AssertOp>(op)) {
344 assertions.push_back({0, initOutputReg});
346 auto *newOp =
b.clone(op, mapping);
349 if (isa<fsm::OutputOp>(newOp)) {
350 for (
auto out : newOp->getOperands())
351 combOutputValues.push_back(out);
360 for (
auto [idx, out] :
llvm::enumerate(combOutputValues)) {
361 auto convCast = UnrealizedConversionCastOp::create(
362 b, loc, forallQuantified[numArgs + idx].getType(), out);
363 castOutValues.push_back(convCast->getResult(0));
369 SmallVector<mlir::Value> initialCondition;
370 for (
auto [idx, q] :
llvm::enumerate(forallQuantified)) {
372 if (numArgs + numOut <= idx &&
373 idx < forallQuantified.size() -
376 initialCondition.push_back(smt::BVConstantOp::create(
377 b, loc, varInitValues[idx - numArgs - numOut]));
378 }
else if (numArgs <= idx &&
379 idx < forallQuantified.size() -
382 initialCondition.push_back(castOutValues[idx - numArgs]);
383 }
else if (idx == forallQuantified.size() -
385 initialCondition.push_back(
386 smt::BVConstantOp::create(b, loc, 0, cfg.timeWidth));
388 if (numArgs + numOut <= idx) {
390 initialCondition.push_back(smt::BVConstantOp::create(
391 b, loc, varInitValues[idx - numArgs - numOut]));
392 }
else if (numArgs <= idx) {
394 initialCondition.push_back(castOutValues[idx - numArgs]);
399 return smt::ApplyFuncOp::create(b, loc, getStateFunction(0),
404 smt::AssertOp::create(b, loc, initialAssertion);
408 SmallVector<Type> transitionQuantified;
409 for (
auto [
id, ty] :
llvm::enumerate(quantifiedTypes)) {
410 transitionQuantified.push_back(ty);
412 transitionQuantified.push_back(ty);
420 for (
auto [transId, transition] :
llvm::enumerate(transitions)) {
423 const auto ¤tTransition = transition;
424 const auto ¤tTransId = transId;
429 [&](SmallVector<Value> &actionArgsOutsVarsVals) -> SmallVector<Value> {
432 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast = mapSmtToFsm(
433 loc, b, actionArgsOutsVarsVals, numArgs, numOut, fsmArgs, fsmVars);
436 SmallVector<Value> castUpdatedVars;
439 for (
size_t i = 0; i < numVars; i++)
440 castUpdatedVars.push_back(actionArgsOutsVarsVals[numArgs + numOut + i]);
442 if (currentTransition.action.has_value()) {
444 auto *actionReg = currentTransition.action.value();
446 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
448 SmallVector<std::pair<mlir::Value, mlir::Value>> combActionValues;
454 for (
auto &op : actionReg->front()) {
455 if (isa<verif::AssertOp>(op)) {
457 op.emitWarning(
"Assertions in action regions are ignored.");
459 auto *newOp =
b.clone(op, mapping);
461 if (isa<fsm::UpdateOp>(newOp)) {
462 auto varToUpdate = newOp->getOperand(0);
463 auto updatedValue = newOp->getOperand(1);
464 for (
auto [
id, var] :
llvm::enumerate(fsmToCast)) {
465 if (var.second == varToUpdate) {
467 auto convCast = UnrealizedConversionCastOp::create(
468 b, loc, actionArgsOutsVarsVals[numOut +
id].getType(),
471 castUpdatedVars[
id - numArgs] = convCast->getResult(0);
479 return castUpdatedVars;
483 auto guard = [&](SmallVector<Value> &actionArgsOutsVarsVals) -> Value {
486 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast = mapSmtToFsm(
487 loc, b, actionArgsOutsVarsVals, numArgs, numOut, fsmArgs, fsmVars);
489 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
495 if (currentTransition.guard.has_value()) {
496 for (
auto &op : currentTransition.guard.value()->front()) {
497 if (isa<verif::AssertOp>(op)) {
499 op.emitWarning(
"Assertions in guard regions are ignored.");
501 auto *newOp =
b.clone(op, mapping);
503 if (isa<fsm::ReturnOp>(newOp)) {
505 auto castVal = mlir::UnrealizedConversionCastOp::create(
506 b, loc,
b.getType<smt::BitVectorType>(1),
507 newOp->getOperand(0));
509 guardVal = bv1toSmtBool(b, loc, castVal.getResult(0));
515 guardVal = smt::BoolConstantOp::create(b, loc,
true);
521 for (
auto [transId1, transition1] :
llvm::enumerate(transitions)) {
522 if (transition1.from == currentTransition.from &&
523 transId1 < currentTransId) {
525 if (transition1.guard.has_value()) {
526 for (
auto &op : transition1.guard.value()->front()) {
527 if (isa<fsm::ReturnOp>(op)) {
529 auto castVal = mlir::UnrealizedConversionCastOp::create(
530 b, loc,
b.getType<smt::BitVectorType>(1),
531 mapping.lookup(op.getOperand(0)));
533 prevGuardVal = bv1toSmtBool(b, loc, castVal.getResult(0));
536 Value negVal = smt::NotOp::create(b, loc, prevGuardVal);
537 guardVal = smt::AndOp::create(b, loc, guardVal, negVal);
538 }
else if (!isa<verif::AssertOp>(op)) {
539 b.clone(op, mapping);
545 Value negVal = smt::NotOp::create(
546 b, loc, smt::BoolConstantOp::create(b, loc,
true));
547 guardVal = smt::AndOp::create(b, loc, guardVal, negVal);
556 [&](SmallVector<Value> &outputArgsOutsVarsVals) -> SmallVector<Value> {
559 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast = mapSmtToFsm(
560 loc, b, outputArgsOutsVarsVals, numArgs, numOut, fsmArgs, fsmVars);
561 SmallVector<Value> castOutputVars;
562 auto *outputReg = currentTransition.output.value();
563 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
568 for (
auto &op : outputReg->front()) {
569 if (isa<verif::AssertOp>(op)) {
571 assertions.push_back({currentTransition.to, outputReg});
573 auto *newOp =
b.clone(op, mapping);
576 if (isa<fsm::OutputOp>(newOp)) {
577 for (
auto [
id, operand] :
llvm::enumerate(newOp->getOperands())) {
580 auto convCast = UnrealizedConversionCastOp::create(
581 b, loc, outputArgsOutsVarsVals[numArgs +
id].getType(),
583 castOutputVars.push_back(convCast->getResult(0));
590 return castOutputVars;
593 auto forall = smt::ForallOp::create(
594 b, loc, transitionQuantified,
595 [&](OpBuilder &b, Location loc,
596 ValueRange doubledQuantifiedVars) -> Value {
597 SmallVector<Value> startingArgsOutsVars;
598 SmallVector<Value> startingFunArgs;
599 SmallVector<Value> arrivingArgsOutsVars;
600 SmallVector<Value> arrivingFunArgs;
602 for (
size_t i = 0; i < 2 * numArgs; i++) {
604 arrivingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
606 startingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
608 for (
auto i = 2 * numArgs; i < doubledQuantifiedVars.size(); ++i) {
609 startingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
610 arrivingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
611 startingFunArgs.push_back(doubledQuantifiedVars[i]);
616 auto lhs = smt::ApplyFuncOp::create(
617 b, loc, getStateFunction(currentTransition.from),
621 auto updatedCastVals = action(startingArgsOutsVars);
622 for (
size_t i = 0; i < numVars; i++)
623 arrivingArgsOutsVars[numArgs + numOut + i] = updatedCastVals[i];
627 if (currentTransition.output.has_value()) {
628 auto outputCastVals = output(arrivingArgsOutsVars);
631 for (
auto o : outputCastVals)
632 arrivingFunArgs.push_back(o);
637 for (
auto u : updatedCastVals)
638 arrivingFunArgs.push_back(u);
642 auto timeVal = doubledQuantifiedVars.back();
643 auto oneConst = smt::BVConstantOp::create(b, loc, 1, cfg.timeWidth);
644 auto incrementedTime = smt::BVAddOp::create(
645 b, loc, timeVal.getType(), timeVal, oneConst);
646 arrivingFunArgs.push_back(incrementedTime);
649 auto rhs = smt::ApplyFuncOp::create(
650 b, loc, getStateFunction(currentTransition.to), arrivingFunArgs);
654 auto guardVal = guard(startingArgsOutsVars);
655 auto guardedlhs = smt::AndOp::create(b, loc, lhs, guardVal);
656 return smt::ImpliesOp::create(b, loc, guardedlhs, rhs);
659 smt::AssertOp::create(b, loc, forall);
664 for (
auto pa : assertions) {
665 auto forall = smt::ForallOp::create(
666 b, loc, quantifiedTypes,
667 [&](OpBuilder &b, Location loc, ValueRange forallQuantified) -> Value {
670 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast =
671 mapSmtToFsm(loc, b, forallQuantified, numArgs, numOut, fsmArgs,
674 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
676 Value returnVal = smt::BoolConstantOp::create(b, loc,
true);
682 for (
auto &op : pa.outputRegion->front()) {
684 if (isa<comb::CombDialect, hw::HWDialect>(op.getDialect()) ||
685 isa<verif::AssertOp>(op)) {
686 auto *newOp =
b.clone(op, mapping);
689 if (isa<verif::AssertOp>(newOp)) {
690 auto assertedVal = newOp->getOperand(0);
691 auto castVal = mlir::UnrealizedConversionCastOp::create(
692 b, loc,
b.getType<smt::BitVectorType>(1), assertedVal);
695 auto toBool = bv1toSmtBool(b, loc, castVal.getResult(0));
696 auto inState = smt::ApplyFuncOp::create(
697 b, loc, getStateFunction(pa.stateId),
698 forallQuantified.drop_front(numArgs));
702 returnVal = smt::ImpliesOp::create(b, loc, inState, toBool);
710 smt::AssertOp::create(b, loc, forall);
713 smt::YieldOp::create(b, loc, typeRange, valueRange);
719struct FSMToSMTPass :
public circt::impl::ConvertFSMToSMTBase<FSMToSMTPass> {
720 void runOnOperation()
override;
723void FSMToSMTPass::runOnOperation() {
724 auto module = getOperation();
728 if (machineOps.empty()) {
734 cfg.withTime = withTime;
735 cfg.timeWidth = timeWidth;
738 for (
auto &op : module.getOps())
741 for (auto &use : op.getUses())
742 if (use.getOwner()->getParentOfType<
MachineOp>()) {
743 op.emitError(
"Only operations defined within fsm.machine operations "
744 "are currently supported for use within them");
749 for (
auto machine :
llvm::make_early_inc_range(module.getOps<
MachineOp>())) {
750 MachineOpConverter converter(b, machine, cfg);
751 if (failed(converter.dispatch())) {
761 return std::make_unique<FSMToSMTPass>();
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToSMTPass()