13#include "mlir/Dialect/SMT/IR/SMTOps.h"
14#include "mlir/IR/BuiltinTypes.h"
15#include "mlir/IR/IRMapping.h"
16#include "mlir/Pass/Pass.h"
19#define GEN_PASS_DEF_CONVERTFSMTOSMT
20#include "circt/Conversion/Passes.h.inc"
// Options controlling how an fsm.machine is lowered to SMT.
33struct LoweringConfig {
// When enabled, the lowering tracks a bit-vector time counter. The counter is
// 0 in the initial-state assertion and is incremented by one on every
// transition.
36 bool withTime =
false;
// Bit width of the time counter used when withTime is enabled.
38 unsigned timeWidth = 8;
// Converts a single fsm.machine op into an smt.solver op. The solver encodes
// the machine's states, transitions, and output assertions as SMT formulas.
41class MachineOpConverter {
43 MachineOpConverter(OpBuilder &builder,
MachineOp machineOp,
44 const LoweringConfig &cfg)
45 : machineOp(machineOp),
b(builder), cfg(cfg) {}
// Performs the conversion. Returns failure, after emitting an error, when the
// machine uses unsupported operations or non-integer types.
46 LogicalResult dispatch();
// A state id paired with an output region that contains verif.assert ops.
// After all transitions are built, each entry is lowered to "being in the
// state implies the asserted condition".
49 struct PendingAssertion {
// Optional regions of a transition: its guard, its action, and the output
// region of its destination state. Each is std::nullopt when absent.
57 std::optional<Region *> guard;
58 std::optional<Region *> action;
59 std::optional<Region *> output;
// Maps each state id to the smt.declare_fun value of its state function.
64 SmallVector<std::pair<int, Value>> stateFunctions;
67 SmallVector<StringRef> &states,
// Build a Transition from state id `from` to the target state's id. The
// target is registered in `states` if needed. The guard and action regions
// start as std::nullopt and are then taken from `t`.
69 StringRef nextState = t.getNextState();
70 Transition tr = {from, insertStates(states, nextState), std::nullopt,
71 std::nullopt, std::nullopt};
73 tr.guard = &t.getGuard();
75 tr.action = &t.getAction();
// Returns the index of state `st` in `states`. The final
// `return states.size() - 1` indicates that a state not found by the scan is
// placed at the end of the list.
79 static int insertStates(SmallVector<StringRef> &states, llvm::StringRef st) {
80 for (
auto [
id, s] :
llvm::enumerate(states))
84 return states.size() - 1;
// Returns the output region recorded for `stateId` in `outputOfStateId`.
// Reaching the end of the list without a match is treated as unreachable.
88 getOutputRegion(
const SmallVector<std::pair<Region *, int>> &outputOfStateId,
90 for (
auto oid : outputOfStateId)
91 if (stateId == oid.second)
93 llvm_unreachable(
"State could not be found.");
// Pairs FSM values with unrealized_conversion_casts of their SMT counterparts.
// `smtValues` follows the quantified layout [args, outputs, variables(, time)].
// Arguments (idx < numArgs) and variables (idx >= numArgs + numOut) get a
// pair. Outputs fall into neither visible branch.
96 SmallVector<std::pair<mlir::Value, mlir::Value>>
97 mapSmtToFsm(Location loc, OpBuilder b, SmallVector<Value> smtValues,
98 int numArgs,
int numOut, SmallVector<Value> fsmArgs,
99 SmallVector<Value> fsmVars) {
102 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast;
103 for (
auto [idx, fq] :
llvm::enumerate(smtValues)) {
// Machine arguments: cast the SMT value back to the argument's type.
104 if (
int(idx) < numArgs) {
105 auto convCast = UnrealizedConversionCastOp::create(
106 b, loc, fsmArgs[idx].getType(), fq);
107 fsmToCast.push_back({fsmArgs[idx], convCast->getResult(0)});
108 }
else if (numArgs + numOut <=
int(idx)) {
// Variables. When cfg.withTime is set, the last SMT value is the time counter,
// which has no FSM counterpart (presumably it is skipped; the guarded
// statement is not shown here).
109 if (cfg.withTime && idx == smtValues.size() - 1)
111 auto convCast = UnrealizedConversionCastOp::create(
112 b, loc, fsmVars[idx - numArgs - numOut].getType(), fq);
114 {fsmVars[idx - numArgs - numOut], convCast->getResult(0)});
// Builds an IRMapping from the FSM-to-cast pairs. It then adds every mapping
// already held by `constMapper` (the clones of the top-level hw.constant ops).
121 createIRMapping(
const SmallVector<std::pair<Value, Value>> &fsmToCast,
122 const IRMapping &constMapper) {
124 for (
auto couple : fsmToCast)
125 mapping.map(couple.first, couple.second);
126 for (
auto &pair : constMapper.getValueMap())
127 mapping.map(pair.first, pair.second);
// Converts an i1 value into an !smt.bool. It casts the value to a 1-bit SMT
// bit-vector and compares that for equality with the constant 1.
132 Value bv1toSmtBool(OpBuilder &b, Location loc, Value i1Value) {
133 auto castVal = UnrealizedConversionCastOp::create(
134 b, loc,
b.getType<smt::BitVectorType>(1), i1Value);
135 return smt::EqOp::create(b, loc, castVal->getResult(0),
136 smt::BVConstantOp::create(b, loc, 1, 1));
// Returns the declared state function for `stateId`. A missing entry is
// treated as unreachable.
139 Value getStateFunction(
int stateId) {
140 for (
auto sf : stateFunctions)
141 if (sf.first == stateId)
143 llvm_unreachable(
"State function could not be found.");
// Lowers `machineOp` into an smt.solver op inserted before it:
//  1. Checks that only supported operations and integer types are used.
//  2. Declares one boolean function F_<state> per state. Its arguments are the
//     machine's outputs, its variables, and (when present) the time counter.
//  3. Asserts the initial state's function on the outputs computed by its
//     output region and on the variables' initial values.
//  4. For every transition, asserts: forall values,
//     F_from(...) && guard implies F_to(...) on the updated outputs and
//     variables.
//  5. For every recorded verif.assert, asserts that being in the owning state
//     implies the asserted condition.
151LogicalResult MachineOpConverter::dispatch() {
152 b.setInsertionPoint(machineOp);
153 auto loc = machineOp.getLoc();
154 auto machineArgs = machineOp.getArguments();
156 mlir::SmallVector<mlir::Value> fsmVars;
157 mlir::SmallVector<mlir::Value> fsmArgs;
158 mlir::SmallVector<mlir::Type> quantifiedTypes;
161 ValueRange valueRange;
// Only fsm ops and hw.constant are accepted at the top level of the machine.
164 for (
auto &op : machineOp.front().getOperations())
165 if (!isa<
fsm::FSMDialect>(op.getDialect()) && !isa<
hw::ConstantOp>(op))
167 "Only fsm operations and hw.constants are allowed in the "
168 "top level of the fsm.machine op.");
// Output, guard, and action regions may contain only fsm/comb/hw ops and
// verif.assert.
172 for (
auto stateOp : machineOp.front().getOps<
fsm::
StateOp>()) {
173 if (!stateOp.getOutput().empty())
174 for (
auto &op : stateOp.getOutput().front().getOperations())
175 if (!isa<
fsm::FSMDialect,
comb::CombDialect,
hw::HWDialect>(
177 !isa<
verif::AssertOp>(op))
179 "Only fsm, comb, hw, and verif.assert operations are handled in "
180 "the output region of a state.");
181 if (!stateOp.getTransitions().empty())
185 for (
auto &op : t.getGuard().front().getOperations())
186 if (!isa<
fsm::FSMDialect,
comb::CombDialect,
hw::HWDialect>(
188 !isa<
verif::AssertOp>(op))
189 return op.emitError(
"Only fsm, comb, hw, and verif.assert "
190 "operations are handled in the guard "
191 "region of a transition.");
193 for (
auto &op : t.getAction().front().getOperations())
194 if (!isa<
fsm::FSMDialect,
comb::CombDialect,
hw::HWDialect>(
196 !isa<
verif::AssertOp>(op))
197 return op.emitError(
"Only fsm, comb, hw, and verif.assert "
198 "operations are handled in the action "
199 "region of a transition.");
// Create the solver op with an empty body. All SMT IR below is built inside it.
203 auto solver = smt::SolverOp::create(b, loc, typeRange, valueRange);
204 solver.getBodyRegion().emplaceBlock();
205 b.setInsertionPointToStart(solver.getBody());
// Machine arguments become the first `numArgs` quantified bit-vectors.
208 for (
auto a : machineArgs) {
209 fsmArgs.push_back(a);
210 if (!isa<IntegerType>(
a.getType()))
211 return solver.emitError(
"Only integer arguments are supported in FSMs.");
212 quantifiedTypes.push_back(
213 b.getType<smt::BitVectorType>(
a.getType().getIntOrFloatBitWidth()));
215 size_t numArgs = fsmArgs.size();
// Machine results follow the arguments in the quantified list.
218 if (!machineOp.getResultTypes().empty()) {
219 for (
auto t : machineOp.getResultTypes()) {
220 if (!isa<IntegerType>(t))
221 return solver.emitError(
"Only integer outputs are supported in FSMs.");
222 quantifiedTypes.push_back(
223 b.getType<smt::BitVectorType>(t.getIntOrFloatBitWidth()));
227 size_t numOut = quantifiedTypes.size() - numArgs;
// fsm.variable ops follow the outputs. Their initial values are kept for the
// initial-state assertion.
230 SmallVector<llvm::APInt> varInitValues;
231 for (
auto v : machineOp.front().getOps<
fsm::VariableOp>()) {
232 if (!isa<IntegerType>(v.getType()))
233 return v.emitError(
"Only integer variables are supported in FSMs.");
// NOTE(review): dyn_cast yields a null attribute when the initial value is not
// an IntegerAttr, and calling getValue() on it would then be invalid. Consider
// cast<> or an explicit error if such values can reach this point.
234 auto intAttr = dyn_cast<IntegerAttr>(v.getInitValueAttr());
235 varInitValues.push_back(intAttr.getValue());
236 quantifiedTypes.push_back(
237 b.getType<smt::BitVectorType>(v.getType().getIntOrFloatBitWidth()));
238 fsmVars.push_back(v.getResult());
// Clone the top-level hw.constant ops into the solver body. constMapper
// records original -> clone so that later region clones use the copies.
242 IRMapping constMapper;
243 for (
auto constOp : machineOp.front().getOps<
hw::ConstantOp>())
244 b.clone(*constOp, constMapper);
// The time counter's bit-vector type is the last quantified type (presumably
// added only when cfg.withTime is set; confirm).
248 quantifiedTypes.push_back(
b.getType<smt::BitVectorType>(cfg.timeWidth));
250 size_t numVars = varInitValues.size();
253 SmallVector<MachineOpConverter::Transition> transitions;
255 SmallVector<StringRef> states;
257 SmallVector<std::pair<Region *, int>> outputOfStateId;
// Register the initial state first so that it receives id 0.
260 StringRef initialState = machineOp.getInitialState();
261 insertStates(states, initialState);
// State functions range over every quantified value except the machine
// arguments.
265 SmallVector<Type> stateFunDomain;
266 for (
auto i = numArgs; i < quantifiedTypes.size(); i++)
267 stateFunDomain.push_back(quantifiedTypes[i]);
// Declare one boolean-valued function F_<name> per state, and record each
// state's output region under its id.
271 for (
auto stateOp : machineOp.front().getOps<
fsm::
StateOp>()) {
272 int idx = insertStates(states, stateOp.getName());
273 mlir::StringAttr funName =
274 b.getStringAttr((
"F_" + stateOp.getName().str()));
275 auto rangeTy =
b.getType<smt::BoolType>();
276 auto funTy =
b.getType<smt::SMTFuncType>(stateFunDomain, rangeTy);
277 auto acFun = smt::DeclareFunOp::create(b, loc, funTy, funName);
278 stateFunctions.push_back({idx, acFun});
279 outputOfStateId.push_back({&stateOp.getOutput(), idx});
// Collect each state's transitions. Each transition gets the destination
// state's output region when one is attached, and std::nullopt otherwise.
286 for (
auto stateOp : machineOp.front().getOps<
fsm::
StateOp>()) {
287 if (stateOp.getTransitions().empty())
289 StringRef stateName = stateOp.getName();
290 auto fromState = insertStates(states, stateName);
293 auto t = getTransitionRegions(tr, fromState, states, loc);
295 t.output = getOutputRegion(outputOfStateId, t.to);
297 t.output = std::nullopt;
298 transitions.push_back(t);
// verif.assert ops found while cloning output regions. They are lowered after
// all transitions. An entry is pushed per assert op encountered, so the same
// region can appear more than once.
302 SmallVector<PendingAssertion> assertions;
// Initial condition, quantified over all values. State 0 (the initial state)
// holds for the outputs produced by its output region and the variables'
// initial values.
307 auto initialAssertion = smt::ForallOp::create(
308 b, loc, quantifiedTypes,
309 [&](OpBuilder &b, Location loc,
310 const SmallVector<Value> &forallQuantified) -> Value {
313 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast =
314 mapSmtToFsm(loc, b, forallQuantified, numArgs, numOut, fsmArgs,
316 auto *initOutputReg = getOutputRegion(outputOfStateId, 0);
317 SmallVector<Value> castOutValues;
// NOTE(review): mapSmtToFsm's visible branches push only argument and variable
// pairs, so variable entries in fsmToCast start at index numArgs, not
// numArgs + numOut. The action lambda below indexes them that way. Confirm
// this window is correct when numOut > 0.
320 for (
auto [
id, couple] :
llvm::enumerate(fsmToCast)) {
321 if (numArgs + numOut <=
id &&
id < numArgs + numOut + numVars) {
328 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
330 SmallVector<mlir::Value> combOutputValues;
// Clone the initial state's output region. Record its assertions and collect
// the operands of its fsm.output op.
332 if (!initOutputReg->empty()) {
337 for (
auto &op : initOutputReg->front()) {
338 if (isa<verif::AssertOp>(op)) {
340 assertions.push_back({0, initOutputReg});
342 auto *newOp =
b.clone(op, mapping);
345 if (isa<fsm::OutputOp>(newOp)) {
346 for (
auto out : newOp->getOperands())
347 combOutputValues.push_back(out);
// Cast the computed outputs to the types of the quantified output values.
356 for (
auto [idx, out] :
llvm::enumerate(combOutputValues)) {
357 auto convCast = UnrealizedConversionCastOp::create(
358 b, loc, forallQuantified[numArgs + idx].getType(), out);
359 castOutValues.push_back(convCast->getResult(0));
// Build the initial state function's arguments: cast outputs, initial
// variable values, and time 0 when a time counter is present.
365 SmallVector<mlir::Value> initialCondition;
366 for (
auto [idx, q] :
llvm::enumerate(forallQuantified)) {
368 if (numArgs + numOut <= idx &&
369 idx < forallQuantified.size() -
372 initialCondition.push_back(smt::BVConstantOp::create(
373 b, loc, varInitValues[idx - numArgs - numOut]));
374 }
else if (numArgs <= idx &&
375 idx < forallQuantified.size() -
378 initialCondition.push_back(castOutValues[idx - numArgs]);
379 }
else if (idx == forallQuantified.size() -
381 initialCondition.push_back(
382 smt::BVConstantOp::create(b, loc, 0, cfg.timeWidth));
// Same layout without a trailing time value.
384 if (numArgs + numOut <= idx) {
386 initialCondition.push_back(smt::BVConstantOp::create(
387 b, loc, varInitValues[idx - numArgs - numOut]));
388 }
else if (numArgs <= idx) {
390 initialCondition.push_back(castOutValues[idx - numArgs]);
395 return smt::ApplyFuncOp::create(b, loc, getStateFunction(0),
400 smt::AssertOp::create(b, loc, initialAssertion);
// Transition quantifiers: each quantified type is pushed once (L406). Machine
// arguments get a second copy (L408), giving the starting and arriving sides
// their own argument values.
404 SmallVector<Type> transitionQuantified;
405 for (
auto [
id, ty] :
llvm::enumerate(quantifiedTypes)) {
406 transitionQuantified.push_back(ty);
408 transitionQuantified.push_back(ty);
// Encode each transition as: forall x, F_from(x) && guard(x) => F_to(x').
416 for (
auto [transId, transition] :
llvm::enumerate(transitions)) {
419 const auto ¤tTransition = transition;
420 const auto ¤tTransId = transId;
425 [&](SmallVector<Value> &actionArgsOutsVarsVals) -> SmallVector<Value> {
// Action: returns the variable values after the transition's action region.
// Variables that no fsm.update touches keep their quantified value.
// Assertions in the region are ignored with a warning.
428 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast = mapSmtToFsm(
429 loc, b, actionArgsOutsVarsVals, numArgs, numOut, fsmArgs, fsmVars);
432 SmallVector<Value> castUpdatedVars;
// Start from the current (quantified) variable values.
435 for (
size_t i = 0; i < numVars; i++)
436 castUpdatedVars.push_back(actionArgsOutsVarsVals[numArgs + numOut + i]);
438 if (currentTransition.action.has_value()) {
440 auto *actionReg = currentTransition.action.value();
442 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
444 SmallVector<std::pair<mlir::Value, mlir::Value>> combActionValues;
450 for (
auto &op : actionReg->front()) {
451 if (isa<verif::AssertOp>(op)) {
453 op.emitWarning(
"Assertions in action regions are ignored.");
455 auto *newOp =
b.clone(op, mapping);
// For fsm.update, replace the updated variable with a cast of its new value.
// This indexing assumes fsmToCast holds [args..., vars...]. Under that
// layout, the variable at entry `id` is SMT value numOut + id and update
// slot id - numArgs.
457 if (isa<fsm::UpdateOp>(newOp)) {
458 auto varToUpdate = newOp->getOperand(0);
459 auto updatedValue = newOp->getOperand(1);
460 for (
auto [
id, var] :
llvm::enumerate(fsmToCast)) {
461 if (var.second == varToUpdate) {
463 auto convCast = UnrealizedConversionCastOp::create(
464 b, loc, actionArgsOutsVarsVals[numOut +
id].getType(),
467 castUpdatedVars[
id - numArgs] = convCast->getResult(0);
475 return castUpdatedVars;
// Guard: the transition's own guard (true when it has none), AND-ed with the
// negated guards of earlier transitions from the same state. The first
// enabled transition in list order therefore takes priority.
479 auto guard = [&](SmallVector<Value> &actionArgsOutsVarsVals) -> Value {
482 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast = mapSmtToFsm(
483 loc, b, actionArgsOutsVarsVals, numArgs, numOut, fsmArgs, fsmVars);
485 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
491 if (currentTransition.guard.has_value()) {
492 for (
auto &op : currentTransition.guard.value()->front()) {
493 if (isa<verif::AssertOp>(op)) {
495 op.emitWarning(
"Assertions in guard regions are ignored.");
497 auto *newOp =
b.clone(op, mapping);
499 if (isa<fsm::ReturnOp>(newOp)) {
// NOTE(review): bv1toSmtBool already casts its operand to !smt.bv<1>, so this
// explicit cast produces a second bv<1> -> bv<1> cast.
501 auto castVal = mlir::UnrealizedConversionCastOp::create(
502 b, loc,
b.getType<smt::BitVectorType>(1),
503 503 newOp->getOperand(0));
505 guardVal = bv1toSmtBool(b, loc, castVal.getResult(0));
511 guardVal = smt::BoolConstantOp::create(b, loc,
true);
// Exclude earlier transitions that leave the same state. An earlier
// guard-free transition makes this one unreachable (guard && !true).
517 for (
auto [transId1, transition1] :
llvm::enumerate(transitions)) {
518 if (transition1.from == currentTransition.from &&
519 transId1 < currentTransId) {
521 if (transition1.guard.has_value()) {
522 for (
auto &op : transition1.guard.value()->front()) {
523 if (isa<fsm::ReturnOp>(op)) {
525 auto castVal = mlir::UnrealizedConversionCastOp::create(
526 b, loc,
b.getType<smt::BitVectorType>(1),
527 mapping.lookup(op.getOperand(0)));
529 prevGuardVal = bv1toSmtBool(b, loc, castVal.getResult(0));
532 Value negVal = smt::NotOp::create(b, loc, prevGuardVal);
533 guardVal = smt::AndOp::create(b, loc, guardVal, negVal);
534 }
else if (!isa<verif::AssertOp>(op)) {
535 b.clone(op, mapping);
541 Value negVal = smt::NotOp::create(
542 b, loc, smt::BoolConstantOp::create(b, loc,
true));
543 guardVal = smt::AndOp::create(b, loc, guardVal, negVal);
552 [&](SmallVector<Value> &outputArgsOutsVarsVals) -> SmallVector<Value> {
// Output: clones the destination state's output region on the post-action
// values and records that region's verif.assert ops for later. Returns the
// fsm.output operands cast to the quantified output types.
555 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast = mapSmtToFsm(
556 loc, b, outputArgsOutsVarsVals, numArgs, numOut, fsmArgs, fsmVars);
557 SmallVector<Value> castOutputVars;
558 auto *outputReg = currentTransition.output.value();
559 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
564 for (
auto &op : outputReg->front()) {
565 if (isa<verif::AssertOp>(op)) {
567 assertions.push_back({currentTransition.to, outputReg});
569 auto *newOp =
b.clone(op, mapping);
572 if (isa<fsm::OutputOp>(newOp)) {
573 for (
auto [
id, operand] :
llvm::enumerate(newOp->getOperands())) {
576 auto convCast = UnrealizedConversionCastOp::create(
577 b, loc, outputArgsOutsVarsVals[numArgs +
id].getType(),
579 castOutputVars.push_back(convCast->getResult(0));
586 return castOutputVars;
// The transition formula, built over the doubled quantifier list.
589 auto forall = smt::ForallOp::create(
590 b, loc, transitionQuantified,
591 [&](OpBuilder &b, Location loc,
592 ValueRange doubledQuantifiedVars) -> Value {
593 SmallVector<Value> startingArgsOutsVars;
594 SmallVector<Value> startingFunArgs;
595 SmallVector<Value> arrivingArgsOutsVars;
596 SmallVector<Value> arrivingFunArgs;
// The first 2 * numArgs values are split between the arriving and starting
// argument copies. The remaining values are shared by both sides and also
// form startingFunArgs.
598 for (
size_t i = 0; i < 2 * numArgs; i++) {
600 arrivingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
602 startingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
604 for (
auto i = 2 * numArgs; i < doubledQuantifiedVars.size(); ++i) {
605 startingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
606 arrivingArgsOutsVars.push_back(doubledQuantifiedVars[i]);
607 startingFunArgs.push_back(doubledQuantifiedVars[i]);
612 auto lhs = smt::ApplyFuncOp::create(
613 b, loc, getStateFunction(currentTransition.from),
617 auto updatedCastVals = action(startingArgsOutsVars);
// Substitute the action's updated variable values on the arriving side.
618 for (
size_t i = 0; i < numVars; i++)
619 arrivingArgsOutsVars[numArgs + numOut + i] = updatedCastVals[i];
// When the destination state has an output region, the arriving outputs are
// computed from it using the post-action values.
623 if (currentTransition.output.has_value()) {
624 auto outputCastVals = output(arrivingArgsOutsVars);
627 for (
auto o : outputCastVals)
628 arrivingFunArgs.push_back(o);
// Followed by the updated variable values.
633 for (
auto u : updatedCastVals)
634 arrivingFunArgs.push_back(u);
// Time counter (when present): the arriving side gets the quantified time + 1.
638 auto timeVal = doubledQuantifiedVars.back();
639 auto oneConst = smt::BVConstantOp::create(b, loc, 1, cfg.timeWidth);
640 auto incrementedTime = smt::BVAddOp::create(
641 b, loc, timeVal.getType(), timeVal, oneConst);
642 arrivingFunArgs.push_back(incrementedTime);
645 auto rhs = smt::ApplyFuncOp::create(
646 b, loc, getStateFunction(currentTransition.to), arrivingFunArgs);
// (F_from && guard) => F_to.
650 auto guardVal = guard(startingArgsOutsVars);
651 auto guardedlhs = smt::AndOp::create(b, loc, lhs, guardVal);
652 return smt::ImpliesOp::create(b, loc, guardedlhs, rhs);
655 smt::AssertOp::create(b, loc, forall);
// For each recorded assertion, quantified over all values: being in the
// asserting state implies the asserted condition.
660 for (
auto pa : assertions) {
661 auto forall = smt::ForallOp::create(
662 b, loc, quantifiedTypes,
663 [&](OpBuilder &b, Location loc, ValueRange forallQuantified) -> Value {
666 SmallVector<std::pair<mlir::Value, mlir::Value>> fsmToCast =
667 mapSmtToFsm(loc, b, forallQuantified, numArgs, numOut, fsmArgs,
670 IRMapping mapping = createIRMapping(fsmToCast, constMapper);
672 Value returnVal = smt::BoolConstantOp::create(b, loc,
true);
// Only comb/hw ops and verif.assert are cloned. Other ops in the region (for
// example fsm.output) are skipped.
678 for (
auto &op : pa.outputRegion->front()) {
680 if (isa<comb::CombDialect, hw::HWDialect>(op.getDialect()) ||
681 isa<verif::AssertOp>(op)) {
682 auto *newOp =
b.clone(op, mapping);
685 if (isa<verif::AssertOp>(newOp)) {
686 auto assertedVal = newOp->getOperand(0);
687 auto castVal = mlir::UnrealizedConversionCastOp::create(
688 b, loc,
b.getType<smt::BitVectorType>(1), assertedVal);
691 auto toBool = bv1toSmtBool(b, loc, castVal.getResult(0));
692 auto inState = smt::ApplyFuncOp::create(
693 b, loc, getStateFunction(pa.stateId),
694 forallQuantified.drop_front(numArgs));
// NOTE(review): returnVal is overwritten for each verif.assert. If a region
// holds several asserts, only the last one appears to be encoded here
// (repeated once per duplicate entry in `assertions`). Confirm.
698 returnVal = smt::ImpliesOp::create(b, loc, inState, toBool);
706 smt::AssertOp::create(b, loc, forall);
709 smt::YieldOp::create(b, loc, typeRange, valueRange);
// Pass that lowers the fsm.machine ops of a module to SMT using
// MachineOpConverter.
715struct FSMToSMTPass :
public circt::impl::ConvertFSMToSMTBase<FSMToSMTPass> {
716 void runOnOperation()
override;
// Converts each fsm.machine in the module with MachineOpConverter, configured
// from the pass options withTime/timeWidth. Module-level ops whose results are
// used inside a machine are reported as errors.
719void FSMToSMTPass::runOnOperation() {
720 auto module = getOperation();
724 if (machineOps.empty()) {
// Forward the pass options to the lowering configuration.
730 cfg.withTime = withTime;
731 cfg.timeWidth = timeWidth;
// Reject module-level ops that have uses inside an fsm.machine.
734 for (
auto &op : module.getOps())
737 for (auto &use : op.getUses())
738 if (use.getOwner()->getParentOfType<
MachineOp>()) {
739 op.emitError(
"Only operations defined within fsm.machine operations "
740 "are currently supported for use within them");
// make_early_inc_range keeps the iteration valid if the current machine op is
// erased or replaced inside the loop body.
745 for (
auto machine :
llvm::make_early_inc_range(module.getOps<
MachineOp>())) {
746 MachineOpConverter converter(b, machine, cfg);
747 if (failed(converter.dispatch())) {
// Creates a new instance of the FSM-to-SMT conversion pass.
757 return std::make_unique<FSMToSMTPass>();
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
std::unique_ptr< mlir::Pass > createConvertFSMToSMTPass()