17#include "mlir/Analysis/TopologicalSortUtils.h"
18#include "mlir/IR/IRMapping.h"
24#define GEN_PASS_DEF_LOWERCONTRACTSPASS
25#include "circt/Dialect/Verif/Passes.h.inc"
/// Pass that lowers the ContractOps found in the top-level HWModuleOps (see
/// runOnHWModule): each contract is first checked in a dedicated FormalOp and
/// then inlined into its own module in "assume" mode.
34struct LowerContractsPass
35 : verif::impl::LowerContractsPassBase<LowerContractsPass> {
36 void runOnOperation()
override;
39Operation *replaceContractOp(OpBuilder &builder, RequireLike op,
40 IRMapping &mapping,
bool assumeContract) {
41 StringAttr labelAttr = op.getLabelAttr();
44 if (
auto enable = op.getEnable())
45 enable = mapping.lookup(enable);
47 auto loc = op.getLoc();
48 auto property = mapping.lookup(op.getProperty());
50 if ((isa<EnsureOp>(op) && !assumeContract) ||
51 (isa<RequireOp>(op) && assumeContract))
52 return AssertOp::create(builder, loc, property, enableValue, labelAttr);
53 if ((isa<EnsureOp>(op) && assumeContract) ||
54 (isa<RequireOp>(op) && !assumeContract))
55 return AssumeOp::create(builder, loc, property, enableValue, labelAttr);
/// Clones `opToClone` at the builder's insertion point, translating operands
/// through `mapping`. Contract clauses (ops implementing RequireLike) are not
/// copied verbatim; replaceContractOp lowers them to an assert or assume
/// according to `assumeContract`. All other ops are copied with builder.clone.
///
/// NOTE(review): the declaration of `clonedOp`, the else/failure path and the
/// return are on lines not visible in this excerpt; presumably failure is
/// returned when replaceContractOp yields no op. Confirm.
59FailureOr<Operation *> cloneOp(OpBuilder &builder, Operation *opToClone,
60 IRMapping &mapping,
bool assumeContract) {
63 if (
auto requireLike = dyn_cast<RequireLike>(*opToClone)) {
64 clonedOp = replaceContractOp(builder, requireLike, mapping, assumeContract);
// Any other op is copied as-is, with operands remapped via `mapping`.
69 clonedOp = builder.clone(*opToClone, mapping);
// Pair each original result with the corresponding result of the clone
// (presumably to record them in `mapping`; the loop body is not visible here).
74 llvm::zip(opToClone->getResults(), clonedOp->getResults())) {
/// Handles a nested contract met while collecting fan-in. cloneFanIn calls it
/// for every ContractOp other than the one being checked.
///
/// Each result of `contract` is mapped to a freshly created SymbolicValueOp.
/// The ops of the contract body are then visited in reverse order.
///
/// NOTE(review): the `contract` parameter declaration and the body of the
/// reverse loop are not visible here; presumably the body ops are pushed onto
/// `workList` so their operands are collected too. Confirm.
80void assumeContractHolds(OpBuilder &builder, IRMapping &mapping,
82 std::queue<Operation *> &workList) {
86 for (
auto result : contract.getResults()) {
88 SymbolicValueOp::create(builder, result.getLoc(), result.getType());
89 mapping.map(result, sym);
91 auto &contractOps = contract.getBody().front().getOperations();
92 for (
auto it = contractOps.rbegin(); it != contractOps.rend(); ++it) {
/// Scans the operands of `op` and schedules what is needed to make each one
/// available in the clone:
///  - operands that already have an entry in `mapping` are (presumably)
///    skipped;
///  - when `parent` is given, operands whose defining block lies inside
///    `parent` are (presumably) skipped, since they come along when `parent`
///    itself is cloned;
///  - an operand produced by an op has that defining op pushed onto
///    `workList`;
///  - otherwise (presumably block arguments; confirm), a fresh SymbolicValueOp
///    is created and recorded in `mapping` for the operand.
/// NOTE(review): the `continue`/`else` lines between these cases are not
/// visible in this excerpt.
97void buildOpsToClone(OpBuilder &builder, IRMapping &mapping, Operation *op,
98 std::queue<Operation *> &workList,
99 Operation *parent =
nullptr) {
100 for (
auto operand : op->getOperands()) {
101 if (mapping.contains(operand))
104 if (parent && parent->isAncestor(operand.getParentBlock()->getParentOp()))
106 if (
auto *definingOp = operand.getDefiningOp()) {
107 workList.push(definingOp);
110 auto sym = verif::SymbolicValueOp::create(builder, operand.getLoc(),
112 mapping.map(operand, sym);
/// Clones the transitive fan-in of `contractToClone` at the builder's
/// insertion point. It is used by inlineContract in check mode.
///
/// The fan-in is explored breadth-first from the contract via `workList`, and
/// `seen` ensures each op is processed only once. Nested contracts are handled
/// by assumeContractHolds. For every other op, buildOpsToClone scans the
/// operands of the op and of all ops nested inside it. The collected ops
/// (excluding the contract itself) are topologically sorted and cloned with
/// assumeContract = true.
///
/// Finally, any operand of a clone that still refers to an original value
/// with an entry in `mapping` is rewritten. Presumably this patches
/// references to values that were mapped only after their user was cloned
/// (e.g. around cycles). Confirm.
/// NOTE(review): several lines (worklist pop, continues, failure returns) are
/// not visible in this excerpt.
117LogicalResult cloneFanIn(OpBuilder &builder, Operation *contractToClone,
118 IRMapping &mapping) {
119 DenseSet<Operation *> seen;
120 SmallVector<Operation *> opsToClone;
121 SmallVector<Operation *> clonedOps;
122 std::queue<Operation *> workList;
123 workList.push(contractToClone);
125 while (!workList.empty()) {
126 auto *currentOp = workList.front();
129 if (seen.contains(currentOp))
132 seen.insert(currentOp);
134 auto contract = dyn_cast<ContractOp>(*currentOp);
// A contract other than the one being checked is abstracted instead of being
// traced through.
136 if (contract && (currentOp != contractToClone)) {
137 assumeContractHolds(builder, mapping, contract, workList);
144 buildOpsToClone(builder, mapping, currentOp, workList);
// Also scan ops nested in currentOp's regions, passing currentOp as `parent`.
146 currentOp->walk([&](Operation *nestedOp) {
147 buildOpsToClone(builder, mapping, nestedOp, workList, currentOp);
// The contract being checked is not part of its own fan-in clone.
151 if (currentOp != contractToClone)
152 opsToClone.push_back(currentOp);
// Order the collected ops so that definitions precede their uses where
// possible.
156 computeTopologicalSorting(opsToClone);
158 clonedOps.reserve(opsToClone.size());
159 for (
auto *op : opsToClone) {
160 if (
auto clonedOp = cloneOp(builder, op, mapping,
true);
161 succeeded(clonedOp)) {
162 clonedOps.push_back(*clonedOp);
// Second pass over the clones: redirect any operand that still names an
// original value with a mapping entry.
169 for (
auto *clonedOp : clonedOps) {
170 for (
unsigned i = 0, e = clonedOp->getNumOperands(); i < e; i++) {
171 auto operand = clonedOp->getOperand(i);
172 if (mapping.contains(operand)) {
173 clonedOp->setOperand(i, mapping.lookup(operand));
/// Materializes `contract` at the builder's insertion point.
///
/// assumeContract == false (check mode, used for the generated FormalOp):
/// cloneFanIn first clones the contract's fan-in. The body is then cloned so
/// that ensures become assertions and requires become assumptions.
///
/// assumeContract == true (assume mode, used inside the HW module): each
/// contract result is mapped to a fresh SymbolicValueOp. The body is cloned
/// with the ensure/require roles swapped, and all uses of the contract's
/// results are redirected to the mapped values.
///
/// NOTE(review): the declaration of `mapping`, the `else` structure around the
/// result-to-input mapping, and the failure returns are on lines not visible
/// here.
180LogicalResult inlineContract(ContractOp &contract, OpBuilder &builder,
181 bool assumeContract) {
183 if (!assumeContract) {
185 if (failed(cloneFanIn(builder, contract, mapping)))
189 if (assumeContract) {
191 for (
auto result : contract.getResults()) {
193 SymbolicValueOp::create(builder, result.getLoc(), result.getType());
194 mapping.map(result, sym);
// Presumably the check-mode branch (confirm): each result stands for the
// value already mapped for its corresponding contract input.
198 for (
auto [result, input] :
199 llvm::zip(contract.getResults(), contract.getInputs())) {
200 mapping.map(result, mapping.lookup(input));
// Clone the contract body; cloneOp lowers ensure/require per assumeContract.
205 for (
auto &op : contract.getBody().front().getOperations()) {
206 if (failed(cloneOp(builder, &op, mapping, assumeContract)))
210 if (assumeContract) {
212 for (
auto result : contract.getResults())
213 result.replaceAllUsesWith(mapping.lookup(result));
/// Lowers all ContractOps in `hwModule`, in two phases:
///
///  1. For the i-th contract, a FormalOp named "<module>_CheckContract_<i>" is
///     created in `mlirModule` after `hwModule`. The contract is inlined into
///     a new block of that op in check mode (assumeContract = false).
///  2. Each contract is inlined directly after itself in `hwModule` in assume
///     mode (assumeContract = true).
///
/// NOTE(review): the binding of `formalOp`, the failure returns and the
/// handling of the original contracts afterwards are not visible here.
218LogicalResult runOnHWModule(
HWModuleOp hwModule, ModuleOp mlirModule) {
219 OpBuilder mlirModuleBuilder(mlirModule);
220 mlirModuleBuilder.setInsertionPointAfter(hwModule);
221 OpBuilder hwModuleBuilder(hwModule);
// Snapshot the contracts up front so the list is fixed before any IR is
// created below.
224 SmallVector<ContractOp> contracts;
225 hwModule.walk([&](ContractOp op) { contracts.push_back(op); });
// Phase 1: one FormalOp per contract, checking it in isolation.
227 for (
unsigned i = 0; i < contracts.size(); i++) {
228 auto contract = contracts[i];
231 auto name = mlirModuleBuilder.getStringAttr(
232 hwModule.getNameAttr().getValue() +
"_CheckContract_" + Twine(i));
234 verif::FormalOp::create(mlirModuleBuilder, contract.getLoc(), name,
235 mlirModuleBuilder.getDictionaryAttr({}));
238 OpBuilder formalBuilder(formalOp);
239 formalBuilder.createBlock(&formalOp.getBody());
241 if (failed(inlineContract(contract, formalBuilder,
false)))
// Phase 2: inline each contract where it appears, relying on it as assumed.
245 for (
auto contract : contracts) {
247 hwModuleBuilder.setInsertionPointAfter(contract);
248 if (failed(inlineContract(contract, hwModuleBuilder,
true)))
/// Applies runOnHWModule to every HWModuleOp that is a direct child of the
/// top-level module.
/// NOTE(review): the failure handling after the check is not visible in this
/// excerpt.
257void LowerContractsPass::runOnOperation() {
258 auto mlirModule = getOperation();
259 for (
auto module : mlirModule.getOps<
HWModuleOp>())
260 if (failed(runOnHWModule(module, mlirModule)))
The InstanceGraph op interface; see InstanceGraphInterface.td for more details.