19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/SymbolTable.h"
25#include "mlir/Interfaces/SideEffectInterfaces.h"
26#include "mlir/Support/WalkResult.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/LogicalResult.h"
33#define GEN_PASS_DEF_ELABORATEOBJECT
34#include "circt/Dialect/OM/OMPasses.h.inc"
// Field lookup table. The key is (class sym_name, field name) and the value
// is the field's position in that class's field-name list. It is filled from
// every ClassLike op in the module and consulted when folding ObjectFieldOps.
44using FieldIndex = DenseMap<std::pair<StringAttr, StringAttr>,
unsigned>;
// Pattern that elaborates an ObjectOp by inlining a clone of the body of the
// class it instantiates.
// - `symTable` resolves the instantiated class by symbol.
// - `replaceExternalWithUnknown` controls what happens to objects of a
//   ClassExternOp: when set, such objects become an UnknownValueOp.
49 ObjectOpInliningPattern(MLIRContext *
context, SymbolTable &symTable,
50 bool replaceExternalWithUnknown)
52 replaceExternalWithUnknown(replaceExternalWithUnknown) {}
// Rewrite steps:
//  1. Look up the class named by the ObjectOp.
//  2. If it is a ClassExternOp, replace the object with an UnknownValueOp
//     of the same type (only when replaceExternalWithUnknown is set).
//  3. Otherwise, map the class body's block arguments to the object's
//     actual parameters and clone the class body with that mapping.
//  4. Give each cloned field value a fused location combining the class's
//     field location with the value's own location.
//  5. Erase the cloned ClassFieldsOp terminator.
//  6. Inline the cloned block before the ObjectOp.
//  7. Replace the ObjectOp with an ElaboratedObjectOp. Presumably this is
//     built from the collected field values; confirm against the full
//     argument list.
54 LogicalResult matchAndRewrite(ObjectOp objOp,
55 PatternRewriter &rewriter)
const override {
57 symTable.lookup<ClassLike>(objOp.getClassNameAttr().getAttr());
// Extern classes have no body to inline.
61 if (isa<ClassExternOp>(classLike)) {
62 if (!replaceExternalWithUnknown)
64 rewriter.replaceOpWithNewOp<UnknownValueOp>(objOp, objOp.getType());
68 auto classOp = dyn_cast<ClassOp>(classLike.getOperation());
// Bind each formal parameter (class block argument) to the corresponding
// actual parameter passed to the ObjectOp.
73 for (
auto [formal, actual] :
llvm::zip(
74 classOp.
getBodyBlock()->getArguments(), objOp.getActualParams()))
75 mapper.map(formal, actual);
// Clone the class body, substituting actuals for formals.
79 classOp.getBody().cloneInto(&clonedRegion, mapper);
80 Block *clonedBlock = &clonedRegion.front();
// The cloned terminator lists the class's field values in field order.
82 auto clonedFields = cast<ClassFieldsOp>(clonedBlock->getTerminator());
83 SmallVector<Value> fieldValues(clonedFields.getFields());
// NOTE(review): this re-casts classLike to ClassOp and shadows the
// `classOp` declared above. That `classOp` was already used
// unconditionally for cloneInto, so this check looks redundant. Confirm
// before simplifying.
86 if (
auto classOp = dyn_cast<ClassOp>(classLike.getOperation()))
87 for (
auto [i, v] :
llvm::enumerate(fieldValues))
89 rewriter.getFusedLoc({classOp.getFieldLocByIndex(i), v.getLoc()}));
// The terminator is not needed once the block is spliced into the parent.
92 rewriter.eraseOp(clonedFields);
93 rewriter.inlineBlockBefore(clonedBlock, objOp);
95 rewriter.replaceOpWithNewOp<ElaboratedObjectOp>(objOp, classLike,
// Held by reference: the symbol table must outlive this pattern.
101 const SymbolTable &symTable;
102 bool replaceExternalWithUnknown;
// Pattern that folds an ObjectFieldOp whose object operand is produced by an
// ElaboratedObjectOp into the corresponding elaborated field value.
// `symTable` and `fieldIndexes` are stored by reference, so both must outlive
// the pattern.
108 EvaluateObjectField(MLIRContext *
context,
const SymbolTable &symTable,
109 const FieldIndex &fieldIndexes)
111 fieldIndexes(fieldIndexes) {}
113 LogicalResult matchAndRewrite(ObjectFieldOp op,
114 PatternRewriter &rewriter)
const override {
// Only objects that have already been elaborated can be folded.
// Presumably the pattern bails out when this is null; confirm.
116 auto elaboratedOp = op.getObject().getDefiningOp<ElaboratedObjectOp>();
121 symTable.lookup<ClassLike>(elaboratedOp.getClassNameAttr().getAttr());
126 fieldIndexes.at({classLike.getSymNameAttr(), op.getFieldAttr()});
// NOTE(review): fieldIndexes.at(...) above assumes every (class, field)
// pair was recorded. DenseMap::at asserts on a missing key.
127 auto result = elaboratedOp.getFieldValues()[index];
// A field value that is this op's own result cannot replace the op.
// Presumably the pattern fails here to avoid a self-replacement; confirm.
131 if (op.getResult() == result)
134 rewriter.replaceOp(op, result);
138 const SymbolTable &symTable;
139 const FieldIndex &fieldIndexes;
// Generic pattern that matches any op, with benefit 1. It replaces every
// result of a pure OM-dialect op with a fresh UnknownValueOp of the same
// type whenever at least one of the op's operands is defined by an
// UnknownValueOp. "Unknown" thus propagates through OM expressions.
144struct UnknownPropagationPattern : RewritePattern {
145 UnknownPropagationPattern(MLIRContext *
context)
146 : RewritePattern(MatchAnyOpTypeTag(), 1,
context) {}
148 LogicalResult matchAndRewrite(Operation *op,
149 PatternRewriter &rewriter)
const override {
// Restrict to side-effect-free OM ops that produce at least one value.
153 if (!isa_and_nonnull<OMDialect>(op->getDialect()) || !isPure(op) ||
154 op->getNumResults() == 0)
// Only fire when some operand is already unknown.
161 if (!llvm::any_of(op->getOperands(), [](Value operand) {
162 return operand.getDefiningOp<UnknownValueOp>();
// One UnknownValueOp per result, created at the original op's location.
167 SmallVector<Value> unknowns;
168 for (Type resultType : op->getResultTypes())
170 UnknownValueOp::create(rewriter, op->
getLoc(), resultType));
172 rewriter.replaceOp(op, unknowns);
// Returns whether `op` is one of the op kinds the pass accepts as a final,
// evaluated form: class structure, elaborated objects, constants, unknown
// values, frozen paths, and list construction. This appears to be an isa<>
// check over the listed op kinds. verifyResult reports any other op that
// remains after elaboration.
179bool isFullyEvaluated(Operation *op) {
182 ClassOp, ClassFieldsOp, ElaboratedObjectOp, AnyCastOp,
184 ConstantOp, UnknownValueOp,
186 FrozenBasePathCreateOp, FrozenPathCreateOp, FrozenEmptyPathOp,
188 ListCreateOp, ListConcatOp>(op);
// Validates an elaborated class. Despite its name, `module` is a ClassOp.
// - PropertyAssertOp with a constant-integer condition: emits
//   "OM property assertion failed: <message>" via checkAssert. Presumably
//   this happens when the condition is zero; confirm against checkAssert.
// - PropertyAssertOp with an UnknownValueOp condition: treated as passing.
// - PropertyAssertOp with any other condition: emits "failed to evaluate
//   assertion condition" unless `allowUnevaluated` is set.
// - Any op rejected by isFullyEvaluated: emits "failed to evaluate <op>"
//   unless `allowUnevaluated` is set.
// Every op is visited so all errors are reported. Returns failure if any
// check failed.
191LogicalResult verifyResult(ClassOp module,
bool allowUnevaluated) {
192 auto isLegal = [allowUnevaluated](Operation *op) -> LogicalResult {
194 if (
auto assertOp = dyn_cast<PropertyAssertOp>(op)) {
197 auto *defOp = assertOp.getCondition().getDefiningOp();
199 auto checkAssert = [&](
bool cond) -> LogicalResult {
207 return op->emitError(
"OM property assertion failed: ")
208 << assertOp.getMessage();
// Constant condition: the assertion holds iff the value is non-zero.
212 if (matchPattern(assertOp.getCondition(), m_ConstantInt(&value)))
213 return checkAssert(!value.isZero());
// Unknown condition: cannot be decided, so it is treated as passing.
// NOTE(review): `unknownOp` is only used as a condition;
// isa_and_nonnull<UnknownValueOp>(defOp) would express the same check.
216 if (
auto unknownOp = dyn_cast_or_null<UnknownValueOp>(defOp))
217 return checkAssert(
true);
// Any other condition was not reduced to a constant by elaboration.
220 if (allowUnevaluated)
222 return emitError(op->getLoc(),
"failed to evaluate assertion condition");
// Non-assert ops must already be in an evaluated form.
225 if (!isFullyEvaluated(op)) {
226 if (allowUnevaluated)
228 return emitError(op->getLoc()) <<
"failed to evaluate " << op->getName();
// Accumulate instead of interrupting the walk, so every illegal op emits
// its own diagnostic.
233 bool encounteredError =
false;
234 module.walk([&](Operation *op) { encounteredError |= failed(isLegal(op)); });
236 return failure(encounteredError);
239struct ElaborateObjectPass
240 :
public circt::om::impl::ElaborateObjectBase<ElaborateObjectPass> {
243 static LogicalResult elaborateClass(ClassOp classOp, SymbolTable &symTable,
244 FieldIndex &fieldIndexes,
245 bool allowUnevaluated =
false) {
250 RewritePatternSet
patterns(classOp.getContext());
251 patterns.add<ObjectOpInliningPattern>(classOp.getContext(), symTable,
253 patterns.add<EvaluateObjectField>(classOp.getContext(), symTable,
255 patterns.add<UnknownPropagationPattern>(classOp.getContext());
256 GreedyRewriteConfig config;
258 config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
259 if (failed(applyPatternsGreedily(classOp, std::move(
patterns), config)))
263 return verifyResult(classOp, allowUnevaluated);
266 LogicalResult initialize(MLIRContext *
context)
override {
268 allPublicClasses.getValue() + !targetClass.getValue().empty();
270 return emitError(UnknownLoc::get(
context))
271 <<
"exactly one of 'target-class' or 'all-public-classes' must "
276 void runOnOperation()
override {
277 auto module = getOperation();
278 auto &symTable = getAnalysis<SymbolTable>();
282 FieldIndex fieldIndexes;
283 for (
auto classOp : module.getOps<ClassLike>()) {
284 auto name = classOp.getSymNameAttr();
285 for (
auto [idx, fieldName] :
286 llvm::enumerate(classOp.getFieldNames().getAsRange<StringAttr>()))
287 fieldIndexes[{name, fieldName}] = idx;
291 if (allPublicClasses) {
292 for (
auto classOp : module.getOps<ClassOp>()) {
293 if (!classOp.isPublic())
295 if (failed(elaborateClass(classOp, symTable, fieldIndexes,
297 return signalPassFailure();
303 auto classOp = symTable.lookup<ClassOp>(targetClass);
305 emitError(module.getLoc())
306 <<
"target class '" << targetClass <<
"' was not found";
307 return signalPassFailure();
311 elaborateClass(classOp, symTable, fieldIndexes, allowUnevaluated)))
312 return signalPassFailure();
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.