CIRCT 23.0.0git
Loading...
Searching...
No Matches
ElaborateObject.cpp
Go to the documentation of this file.
1//===- ElaborateObject.cpp - OM compile-time evaluation pass --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs evaluation of OM classes by inlining object
10// instantiations and folding field accesses. It replaces the runtime Evaluator
11// framework with static compile-time evaluation.
12//
13//===----------------------------------------------------------------------===//
14
18#include "circt/Support/LLVM.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/SymbolTable.h"
25#include "mlir/Interfaces/SideEffectInterfaces.h"
26#include "mlir/Support/WalkResult.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/LogicalResult.h"
30
31namespace circt {
32namespace om {
33#define GEN_PASS_DEF_ELABORATEOBJECT
34#include "circt/Dialect/OM/OMPasses.h.inc"
35} // namespace om
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace om;
41
42namespace {
43// A map from (class name, field name) to field index.
44using FieldIndex = DenseMap<std::pair<StringAttr, StringAttr>, unsigned>;
45
46/// Pattern to inline ObjectOp instances by cloning the class body and
47/// replacing them with ElaboratedObjectOp.
48struct ObjectOpInliningPattern : public OpRewritePattern<ObjectOp> {
49 ObjectOpInliningPattern(MLIRContext *context, SymbolTable &symTable,
50 bool replaceExternalWithUnknown)
51 : OpRewritePattern<ObjectOp>(context), symTable(symTable),
52 replaceExternalWithUnknown(replaceExternalWithUnknown) {}
53
54 LogicalResult matchAndRewrite(ObjectOp objOp,
55 PatternRewriter &rewriter) const override {
56 auto classLike = symTable.lookup<ClassLike>(objOp.getClassNameAttr());
57 assert(classLike);
58
59 // External classes cannot be elaborated; replace with unknown values.
60 if (isa<ClassExternOp>(classLike)) {
61 if (!replaceExternalWithUnknown)
62 return failure();
63 rewriter.replaceOpWithNewOp<UnknownValueOp>(objOp, objOp.getType());
64 return success();
65 }
66
67 auto classOp = dyn_cast<ClassOp>(classLike.getOperation());
68 if (!classOp)
69 return failure();
70
71 IRMapping mapper;
72 for (auto [formal, actual] : llvm::zip(
73 classOp.getBodyBlock()->getArguments(), objOp.getActualParams()))
74 mapper.map(formal, actual);
75
76 // Clone the class body into a temporary region with argument substitution.
77 Region clonedRegion;
78 classOp.getBody().cloneInto(&clonedRegion, mapper);
79 Block *clonedBlock = &clonedRegion.front();
80
81 auto clonedFields = cast<ClassFieldsOp>(clonedBlock->getTerminator());
82 SmallVector<Value> fieldValues(clonedFields.getFields());
83
84 // Erase the terminator and inline the body at the object instantiation.
85 rewriter.eraseOp(clonedFields);
86 rewriter.inlineBlockBefore(clonedBlock, objOp);
87
88 rewriter.replaceOpWithNewOp<ElaboratedObjectOp>(objOp, classLike,
89 fieldValues);
90
91 return success();
92 }
93
94 const SymbolTable &symTable;
95 bool replaceExternalWithUnknown;
96};
97
98/// Pattern to fold ObjectFieldOp on ElaboratedObjectOp by directly accessing
99/// the field value operands.
100struct EvaluateObjectField : OpRewritePattern<ObjectFieldOp> {
101 EvaluateObjectField(MLIRContext *context, const SymbolTable &symTable,
102 const FieldIndex &fieldIndexes)
103 : OpRewritePattern<ObjectFieldOp>(context), symTable(symTable),
104 fieldIndexes(fieldIndexes) {}
105
106 LogicalResult matchAndRewrite(ObjectFieldOp op,
107 PatternRewriter &rewriter) const override {
108 // Only fold if the object is an ElaboratedObjectOp.
109 auto elaboratedOp = op.getObject().getDefiningOp<ElaboratedObjectOp>();
110 if (!elaboratedOp)
111 return failure();
112
113 auto classLike =
114 symTable.lookup<ClassLike>(elaboratedOp.getClassNameAttr());
115 assert(classLike);
116
117 // Find the field index and get the corresponding value.
118 auto index =
119 fieldIndexes.at({classLike.getSymNameAttr(), op.getFieldAttr()});
120 auto result = elaboratedOp.getFieldValues()[index];
121
122 // Skip cycles where a field references itself.
123 // This will be raised as an error later.
124 if (op.getResult() == result)
125 return failure();
126
127 rewriter.replaceOp(op, result);
128 return success();
129 }
130
131 const SymbolTable &symTable;
132 const FieldIndex &fieldIndexes;
133};
134
135/// Pattern to propagate UnknownValueOp through pure OM operations.
136/// If any operand is unknown, all results become unknown.
137struct UnknownPropagationPattern : RewritePattern {
138 UnknownPropagationPattern(MLIRContext *context)
139 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
140
141 LogicalResult matchAndRewrite(Operation *op,
142 PatternRewriter &rewriter) const override {
143 // Only target pure OM operations.
144 // TODO: Consider add a trait for this if we want to have more explict
145 // behavior.
146 if (!isa_and_nonnull<OMDialect>(op->getDialect()) || !isPure(op) ||
147 op->getNumResults() == 0)
148 return failure();
149
150 // Check if any operand is an UnknownValueOp.
151 // TODO: This directly ports the existing Evaluator semantics, but it
152 // causes inconsistent evaluation for operations that can reason about
153 // known values, e.g., "and(0, unknown) -> 0".
154 if (!llvm::any_of(op->getOperands(), [](Value operand) {
155 return operand.getDefiningOp<UnknownValueOp>();
156 }))
157 return failure();
158
159 // Replace all results with UnknownValueOp.
160 SmallVector<Value> unknowns;
161 for (Type resultType : op->getResultTypes())
162 unknowns.push_back(
163 UnknownValueOp::create(rewriter, op->getLoc(), resultType));
164
165 rewriter.replaceOp(op, unknowns);
166 return success();
167 }
168};
169
170// Check if an operation can be evaluated at compile time and is valid to
171// remain in the IR after elaboration.
172bool isFullyEvaluated(Operation *op) {
173 return isa<
174 // Structure.
175 ClassOp, ClassFieldsOp, ElaboratedObjectOp, AnyCastOp,
176 // Constant-like.
177 ConstantOp, UnknownValueOp,
178 // Path.
179 FrozenBasePathCreateOp, FrozenPathCreateOp, FrozenEmptyPathOp,
180 // List.
181 ListCreateOp, ListConcatOp>(op);
182}
183
184LogicalResult verifyResult(ClassOp module, bool allowUnevaluated) {
185 auto isLegal = [allowUnevaluated](Operation *op) -> LogicalResult {
186 // Check assert satisfied.
187 if (auto assertOp = dyn_cast<PropertyAssertOp>(op)) {
188 // Check if the condition is a constant false, which means the assertion
189 // is violated.
190 auto *defOp = assertOp.getCondition().getDefiningOp();
191 APInt value;
192 auto checkAssert = [&](bool cond) -> LogicalResult {
193 if (cond) {
194 // Erase when success, serialization doesn't need to care about
195 // this.
196 op->erase();
197 return success();
198 }
199
200 return op->emitError("OM property assertion failed: ")
201 << assertOp.getMessage();
202 };
203
204 // Condition is a constant integer/bool - check if it's true.
205 if (matchPattern(assertOp.getCondition(), m_ConstantInt(&value)))
206 return checkAssert(!value.isZero());
207
208 // Condition is unknown - treat as passing.
209 if (auto unknownOp = dyn_cast_or_null<UnknownValueOp>(defOp))
210 return checkAssert(true);
211
212 // This means the condition was not fully evaluated.
213 if (allowUnevaluated)
214 return success();
215 return emitError(op->getLoc(), "failed to evaluate assertion condition");
216 }
217
218 if (!isFullyEvaluated(op)) {
219 if (allowUnevaluated)
220 return success();
221 return emitError(op->getLoc()) << "failed to evaluate " << op->getName();
222 }
223
224 return success();
225 };
226 bool encounteredError = false;
227 module.walk([&](Operation *op) { encounteredError |= failed(isLegal(op)); });
228
229 return failure(encounteredError);
230}
231
232struct ElaborateObjectPass
233 : public circt::om::impl::ElaborateObjectBase<ElaborateObjectPass> {
234 using Base::Base;
235
236 static LogicalResult elaborateClass(ClassOp classOp, SymbolTable &symTable,
237 FieldIndex &fieldIndexes,
238 bool allowUnevaluated = false) {
239 // Elaborate objects by inlining all ObjectOps and folding field accesses
240 // using a greedy pattern rewriter. NOTE: The conversion framework is not
241 // suitable here because inlining patterns need to be applied recursively to
242 // fully evaluate nested object instantiations.
243 RewritePatternSet patterns(classOp.getContext());
244 patterns.add<ObjectOpInliningPattern>(classOp.getContext(), symTable,
245 !allowUnevaluated);
246 patterns.add<EvaluateObjectField>(classOp.getContext(), symTable,
247 fieldIndexes);
248 patterns.add<UnknownPropagationPattern>(classOp.getContext());
249 GreedyRewriteConfig config;
250 // Disable iteration limit to allow full recursive inlining.
251 config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
252 if (failed(applyPatternsGreedily(classOp, std::move(patterns), config)))
253 return failure();
254
255 // Check if elaboration succeeded after saturation.
256 return verifyResult(classOp, allowUnevaluated);
257 }
258
259 LogicalResult initialize(MLIRContext *context) override {
260 unsigned numModes =
261 allPublicClasses.getValue() + !targetClass.getValue().empty();
262 if (numModes != 1)
263 return emitError(UnknownLoc::get(context))
264 << "exactly one of 'target-class' or 'all-public-classes' must "
265 "be specified";
266 return success();
267 }
268
269 void runOnOperation() override {
270 auto module = getOperation();
271 auto &symTable = getAnalysis<SymbolTable>();
272
273 // Build a map from (class name, field name) to field index for all
274 // classes.
275 FieldIndex fieldIndexes;
276 for (auto classOp : module.getOps<ClassLike>()) {
277 auto name = classOp.getSymNameAttr();
278 for (auto [idx, fieldName] :
279 llvm::enumerate(classOp.getFieldNames().getAsRange<StringAttr>()))
280 fieldIndexes[{name, fieldName}] = idx;
281 }
282
283 // Elaborate all public classes.
284 if (allPublicClasses) {
285 for (auto classOp : module.getOps<ClassOp>()) {
286 if (!classOp.isPublic())
287 continue;
288 if (failed(elaborateClass(classOp, symTable, fieldIndexes,
289 allowUnevaluated)))
290 return signalPassFailure();
291 }
292 return;
293 }
294
295 // Normal mode: elaborate the specified target class.
296 auto classOp = symTable.lookup<ClassOp>(targetClass);
297 if (!classOp) {
298 emitError(module.getLoc())
299 << "target class '" << targetClass << "' was not found";
300 return signalPassFailure();
301 }
302
303 if (failed(
304 elaborateClass(classOp, symTable, fieldIndexes, allowUnevaluated)))
305 return signalPassFailure();
306 }
307};
308
309} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:218
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1