17#include "mlir/Analysis/TopologicalSortUtils.h"
18#include "mlir/IR/IRMapping.h"
24#define GEN_PASS_DEF_LOWERCONTRACTSPASS
25#include "circt/Dialect/Verif/Passes.h.inc"
34struct LowerContractsPass
35 : verif::impl::LowerContractsPassBase<LowerContractsPass> {
36 void runOnOperation()
override;
39Operation *replaceContractOp(OpBuilder &builder, RequireLike op,
40 IRMapping &mapping,
bool assumeContract) {
41 StringAttr labelAttr = op.getLabelAttr();
44 if (
auto enable = op.getEnable())
45 enable = mapping.lookup(enable);
47 auto loc = op.getLoc();
48 auto property = mapping.lookup(op.getProperty());
50 if ((isa<EnsureOp>(op) && !assumeContract) ||
51 (isa<RequireOp>(op) && assumeContract))
52 return builder.create<AssertOp>(loc, property, enableValue, labelAttr);
53 if ((isa<EnsureOp>(op) && assumeContract) ||
54 (isa<RequireOp>(op) && !assumeContract))
55 return builder.create<AssumeOp>(loc, property, enableValue, labelAttr);
59LogicalResult cloneOp(OpBuilder &builder, Operation *opToClone,
60 IRMapping &mapping,
bool assumeContract) {
63 if (
auto requireLike = dyn_cast<RequireLike>(*opToClone)) {
64 clonedOp = replaceContractOp(builder, requireLike, mapping, assumeContract);
69 clonedOp = builder.clone(*opToClone, mapping);
74 llvm::zip(opToClone->getResults(), clonedOp->getResults())) {
77 return llvm::success();
80void assumeContractHolds(OpBuilder &builder, IRMapping &mapping,
82 std::queue<Operation *> &workList) {
86 for (
auto result : contract.getResults()) {
88 builder.create<SymbolicValueOp>(result.getLoc(), result.getType());
89 mapping.map(result, sym);
91 auto &contractOps = contract.getBody().front().getOperations();
92 for (
auto it = contractOps.rbegin(); it != contractOps.rend(); ++it) {
97void buildOpsToClone(OpBuilder &builder, IRMapping &mapping, Operation *op,
98 std::queue<Operation *> &workList,
99 Operation *parent =
nullptr) {
100 for (
auto operand : op->getOperands()) {
101 if (mapping.contains(operand))
104 if (parent && parent->isAncestor(operand.getParentBlock()->getParentOp()))
106 if (
auto *definingOp = operand.getDefiningOp()) {
107 workList.push(definingOp);
110 auto sym = builder.create<verif::SymbolicValueOp>(operand.getLoc(),
112 mapping.map(operand, sym);
117LogicalResult cloneFanIn(OpBuilder &builder, Operation *contractToClone,
118 IRMapping &mapping) {
119 DenseSet<Operation *> seen;
120 SmallVector<Operation *> opsToClone;
121 std::queue<Operation *> workList;
122 workList.push(contractToClone);
124 while (!workList.empty()) {
125 auto *currentOp = workList.front();
128 if (seen.contains(currentOp))
131 seen.insert(currentOp);
133 auto contract = dyn_cast<ContractOp>(*currentOp);
135 if (contract && (currentOp != contractToClone)) {
136 assumeContractHolds(builder, mapping, contract, workList);
143 buildOpsToClone(builder, mapping, currentOp, workList);
145 currentOp->walk([&](Operation *nestedOp) {
146 buildOpsToClone(builder, mapping, nestedOp, workList, currentOp);
150 if (currentOp != contractToClone)
151 opsToClone.push_back(currentOp);
155 computeTopologicalSorting(opsToClone);
157 for (
auto *op : opsToClone) {
158 if (failed(cloneOp(builder, op, mapping,
true)))
164LogicalResult inlineContract(ContractOp &contract, OpBuilder &builder,
165 bool assumeContract) {
167 if (!assumeContract) {
169 if (failed(cloneFanIn(builder, contract, mapping)))
173 if (assumeContract) {
175 for (
auto result : contract.getResults()) {
177 builder.create<SymbolicValueOp>(result.getLoc(), result.getType());
178 mapping.map(result, sym);
182 for (
auto [result, input] :
183 llvm::zip(contract.getResults(), contract.getInputs())) {
184 mapping.map(result, mapping.lookup(input));
189 for (
auto &op : contract.getBody().front().getOperations()) {
190 if (failed(cloneOp(builder, &op, mapping, assumeContract)))
194 if (assumeContract) {
196 for (
auto result : contract.getResults())
197 result.replaceAllUsesWith(mapping.lookup(result));
202LogicalResult runOnHWModule(
HWModuleOp hwModule, ModuleOp mlirModule) {
203 OpBuilder mlirModuleBuilder(mlirModule);
204 mlirModuleBuilder.setInsertionPointAfter(hwModule);
205 OpBuilder hwModuleBuilder(hwModule);
208 SmallVector<ContractOp> contracts;
209 hwModule.walk([&](ContractOp op) { contracts.push_back(op); });
211 for (
unsigned i = 0; i < contracts.size(); i++) {
212 auto contract = contracts[i];
215 auto name = mlirModuleBuilder.getStringAttr(
216 hwModule.getNameAttr().getValue() +
"_CheckContract_" + Twine(i));
217 auto formalOp = mlirModuleBuilder.create<verif::FormalOp>(
218 contract.getLoc(), name, mlirModuleBuilder.getDictionaryAttr({}));
221 OpBuilder formalBuilder(formalOp);
222 formalBuilder.createBlock(&formalOp.getBody());
224 if (failed(inlineContract(contract, formalBuilder,
false)))
228 for (
auto contract : contracts) {
230 hwModuleBuilder.setInsertionPointAfter(contract);
231 if (failed(inlineContract(contract, hwModuleBuilder,
true)))
240void LowerContractsPass::runOnOperation() {
241 auto mlirModule = getOperation();
242 for (
auto module : mlirModule.getOps<
HWModuleOp>())
243 if (failed(runOnHWModule(module, mlirModule)))
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.