19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/SymbolTable.h"
25#include "mlir/Interfaces/SideEffectInterfaces.h"
26#include "mlir/Support/WalkResult.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/LogicalResult.h"
33#define GEN_PASS_DEF_ELABORATEOBJECT
34#include "circt/Dialect/OM/OMPasses.h.inc"
// Index of class fields: maps (class symbol name, field name) to the field's
// position in that class's field-name list. runOnOperation fills it from
// ClassLike::getFieldNames(), and EvaluateObjectField uses the position to
// select the matching entry of an ElaboratedObjectOp's field values.
44using FieldIndex = DenseMap<std::pair<StringAttr, StringAttr>,
unsigned>;
// ObjectOpInliningPattern: elaborates an ObjectOp by cloning the body of the
// class it instantiates in front of the object and replacing the ObjectOp
// with an ElaboratedObjectOp. (The struct header is not visible in this
// excerpt.)
//
// symTable: resolves the object's class name; stored by reference, so it must
//   outlive the pattern.
// replaceExternalWithUnknown: when set, objects of external classes
//   (ClassExternOp) are replaced by an UnknownValueOp of the object's type.
49 ObjectOpInliningPattern(MLIRContext *
context, SymbolTable &symTable,
50 bool replaceExternalWithUnknown)
52 replaceExternalWithUnknown(replaceExternalWithUnknown) {}
54 LogicalResult matchAndRewrite(ObjectOp objOp,
55 PatternRewriter &rewriter)
const override {
// Resolve the instantiated class (a ClassOp or a ClassExternOp) by symbol.
56 auto classLike = symTable.lookup<ClassLike>(objOp.getClassNameAttr());
// External classes have no body to inline. With replaceExternalWithUnknown
// set, the object is modeled as an unknown value of the same type; the
// handling when the flag is clear is on lines not visible in this excerpt.
60 if (isa<ClassExternOp>(classLike)) {
61 if (!replaceExternalWithUnknown)
63 rewriter.replaceOpWithNewOp<UnknownValueOp>(objOp, objOp.getType());
// Only a ClassOp has a body that can be cloned.
67 auto classOp = dyn_cast<ClassOp>(classLike.getOperation());
// Bind each formal parameter (an argument of the class body block) to the
// corresponding actual parameter of the ObjectOp. Cloned uses of a formal
// then refer to the actual value instead.
72 for (
auto [formal, actual] :
llvm::zip(
73 classOp.
getBodyBlock()->getArguments(), objOp.getActualParams()))
74 mapper.map(formal, actual);
// Clone the class body into a scratch region using the formal->actual
// mapping (clonedRegion is declared on a line not visible here).
78 classOp.getBody().cloneInto(&clonedRegion, mapper);
79 Block *clonedBlock = &clonedRegion.front();
// The cloned block's terminator is the ClassFieldsOp carrying the field
// values. Copy them out first, because the terminator is erased before the
// block is spliced into the parent.
81 auto clonedFields = cast<ClassFieldsOp>(clonedBlock->getTerminator());
82 SmallVector<Value> fieldValues(clonedFields.getFields());
85 rewriter.eraseOp(clonedFields);
// Move the cloned operations in front of the ObjectOp.
86 rewriter.inlineBlockBefore(clonedBlock, objOp);
// Replace the ObjectOp with an ElaboratedObjectOp for the same class. The
// remaining arguments are on lines not visible here (presumably
// fieldValues -- TODO confirm).
88 rewriter.replaceOpWithNewOp<ElaboratedObjectOp>(objOp, classLike,
// Used to look up classes by symbol name.
94 const SymbolTable &symTable;
// Whether instances of ClassExternOp classes become UnknownValueOp.
95 bool replaceExternalWithUnknown;
// EvaluateObjectField: folds an ObjectFieldOp that reads a field of an object
// produced by an ElaboratedObjectOp into the elaborated value of that field.
// (The struct header is not visible in this excerpt.)
//
// symTable: resolves the elaborated object's class.
// fieldIndexes: (class name, field name) -> position in the field values.
// Both are stored by reference and must outlive the pattern.
101 EvaluateObjectField(MLIRContext *
context,
const SymbolTable &symTable,
102 const FieldIndex &fieldIndexes)
104 fieldIndexes(fieldIndexes) {}
106 LogicalResult matchAndRewrite(ObjectFieldOp op,
107 PatternRewriter &rewriter)
const override {
// Only objects that have already been elaborated expose their field values.
109 auto elaboratedOp = op.getObject().getDefiningOp<ElaboratedObjectOp>();
// Resolve the elaborated object's class by symbol.
114 symTable.lookup<ClassLike>(elaboratedOp.getClassNameAttr());
// Position of the requested field among the class's fields.
// NOTE(review): DenseMap::at requires the key to be present. This relies on
// fieldIndexes covering every (class, field) pair that can be read.
119 fieldIndexes.at({classLike.getSymNameAttr(), op.getFieldAttr()});
120 auto result = elaboratedOp.getFieldValues()[index];
// If the field value is this op's own result, replacing the op with it would
// be a self-reference. The body of this branch is not visible in this
// excerpt. Otherwise, the field's value replaces every use of the
// ObjectFieldOp.
124 if (op.getResult() == result)
127 rewriter.replaceOp(op, result);
131 const SymbolTable &symTable;
132 const FieldIndex &fieldIndexes;
// Propagates unknown values through OM computations. It applies to a pure
// (per isPure) OM-dialect op that has at least one result and at least one
// operand defined by an UnknownValueOp. Such an op is replaced by fresh
// UnknownValueOps, one per result, each with that result's type. Registered
// to match any operation type, with benefit 1.
137struct UnknownPropagationPattern : RewritePattern {
138 UnknownPropagationPattern(MLIRContext *
context)
139 : RewritePattern(MatchAnyOpTypeTag(), 1,
context) {}
141 LogicalResult matchAndRewrite(Operation *op,
142 PatternRewriter &rewriter)
const override {
// Only pure ops of the OM dialect that produce results are candidates.
// Ops from other dialects, or with no dialect at all, are excluded.
146 if (!isa_and_nonnull<OMDialect>(op->getDialect()) || !isPure(op) ||
147 op->getNumResults() == 0)
// At least one operand must be produced by an UnknownValueOp.
154 if (!llvm::any_of(op->getOperands(), [](Value operand) {
155 return operand.getDefiningOp<UnknownValueOp>();
// Create one UnknownValueOp per result, at the op's location and with the
// result's type.
160 SmallVector<Value> unknowns;
161 for (Type resultType : op->getResultTypes())
163 UnknownValueOp::create(rewriter, op->
getLoc(), resultType));
// Replace all results of the op at once.
165 rewriter.replaceOp(op, unknowns);
// Returns whether `op` is of a kind that may remain after elaboration. It
// appears to be an isa<...>(op) test whose opening line is not visible here.
// The kinds shown are ClassOp, ClassFieldsOp, ElaboratedObjectOp, AnyCastOp,
// ConstantOp, UnknownValueOp, the Frozen*Path creation ops, ListCreateOp and
// ListConcatOp.
// NOTE(review): some lines of the op list are not visible in this excerpt,
// so the list above may be incomplete. verifyResult uses this to detect
// operations that were not evaluated.
172bool isFullyEvaluated(Operation *op) {
175 ClassOp, ClassFieldsOp, ElaboratedObjectOp, AnyCastOp,
177 ConstantOp, UnknownValueOp,
179 FrozenBasePathCreateOp, FrozenPathCreateOp, FrozenEmptyPathOp,
181 ListCreateOp, ListConcatOp>(op);
// Verifies `module` (a ClassOp) after elaboration. Every nested op is
// checked. Failures are accumulated, so all diagnostics are emitted instead
// of stopping at the first one.
//  - PropertyAssertOp: a constant-integer condition must be nonzero. A
//    condition defined by an UnknownValueOp is treated as holding. Any other
//    condition yields "failed to evaluate assertion condition" unless
//    allowUnevaluated is set.
//  - Other ops: must satisfy isFullyEvaluated, otherwise "failed to evaluate
//    <op name>" is reported unless allowUnevaluated is set.
// The code taken on the allowUnevaluated paths is on lines not visible here.
184LogicalResult verifyResult(ClassOp module,
bool allowUnevaluated) {
185 auto isLegal = [allowUnevaluated](Operation *op) -> LogicalResult {
187 if (
auto assertOp = dyn_cast<PropertyAssertOp>(op)) {
190 auto *defOp = assertOp.getCondition().getDefiningOp();
// Reports "OM property assertion failed: <message>", presumably only when
// `cond` is false. Part of this helper's body is not visible in this excerpt.
192 auto checkAssert = [&](
bool cond) -> LogicalResult {
200 return op->emitError(
"OM property assertion failed: ")
201 << assertOp.getMessage();
// A constant-integer condition holds iff it is nonzero.
205 if (matchPattern(assertOp.getCondition(), m_ConstantInt(&value)))
206 return checkAssert(!value.isZero());
// An unknown condition cannot be disproven, so it is checked as `true`.
209 if (
auto unknownOp = dyn_cast_or_null<UnknownValueOp>(defOp))
210 return checkAssert(
true);
// Any other condition was not reduced to a constant or an unknown value.
213 if (allowUnevaluated)
215 return emitError(op->getLoc(),
"failed to evaluate assertion condition");
// All other ops must be of a kind accepted by isFullyEvaluated.
218 if (!isFullyEvaluated(op)) {
219 if (allowUnevaluated)
221 return emitError(op->getLoc()) <<
"failed to evaluate " << op->getName();
// Walk every nested op. Failures are OR-ed together so the walk does not
// stop at the first error.
226 bool encounteredError =
false;
227 module.walk([&](Operation *op) { encounteredError |= failed(isLegal(op)); });
229 return failure(encounteredError);
// The ElaborateObject pass. Depending on its options, it elaborates either
// every public ClassOp in the module ('all-public-classes') or the single
// class named by 'target-class'. Exactly one of the two must be given.
232struct ElaborateObjectPass
233 :
public circt::om::impl::ElaborateObjectBase<ElaborateObjectPass> {
// Elaborates `classOp` in place. It greedily applies object inlining, field
// evaluation and unknown propagation with no iteration limit, then checks
// the result with verifyResult.
// symTable and fieldIndexes are the shared lookup tables built by
// runOnOperation. allowUnevaluated is forwarded to verifyResult.
236 static LogicalResult elaborateClass(ClassOp classOp, SymbolTable &symTable,
237 FieldIndex &fieldIndexes,
238 bool allowUnevaluated =
false) {
243 RewritePatternSet
patterns(classOp.getContext());
244 patterns.add<ObjectOpInliningPattern>(classOp.getContext(), symTable,
246 patterns.add<EvaluateObjectField>(classOp.getContext(), symTable,
248 patterns.add<UnknownPropagationPattern>(classOp.getContext());
249 GreedyRewriteConfig config;
// Run to a fixpoint. Nested object instantiations can presumably need more
// rounds than the driver's default iteration cap.
251 config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
252 if (failed(applyPatternsGreedily(classOp, std::move(
patterns), config)))
256 return verifyResult(classOp, allowUnevaluated);
// Validates the pass options. Exactly one of 'target-class' or
// 'all-public-classes' must be set; otherwise an error is emitted at an
// unknown location.
259 LogicalResult initialize(MLIRContext *
context)
override {
261 allPublicClasses.getValue() + !targetClass.getValue().empty();
263 return emitError(UnknownLoc::get(
context))
<<
"exactly one of 'target-class' or 'all-public-classes' must "
// Builds the shared field index, then elaborates the selected class(es).
269 void runOnOperation()
override {
270 auto module = getOperation();
271 auto &symTable = getAnalysis<SymbolTable>();
// Precompute (class name, field name) -> field position for every
// ClassLike in the module, including external classes.
275 FieldIndex fieldIndexes;
276 for (
auto classOp : module.getOps<ClassLike>()) {
277 auto name = classOp.getSymNameAttr();
278 for (
auto [idx, fieldName] :
279 llvm::enumerate(classOp.getFieldNames().getAsRange<StringAttr>()))
280 fieldIndexes[{name, fieldName}] = idx;
// all-public-classes: elaborate each public ClassOp. Stop with a pass
// failure at the first class that fails to elaborate.
284 if (allPublicClasses) {
285 for (
auto classOp : module.getOps<ClassOp>()) {
286 if (!classOp.isPublic())
288 if (failed(elaborateClass(classOp, symTable, fieldIndexes,
290 return signalPassFailure();
// Otherwise (target-class): elaborate only the named class. A class that
// cannot be found is reported as an error.
296 auto classOp = symTable.lookup<ClassOp>(targetClass);
298 emitError(module.getLoc())
<<
"target class '" << targetClass <<
"' was not found";
300 return signalPassFailure();
304 elaborateClass(classOp, symTable, fieldIndexes, allowUnevaluated)))
305 return signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.