CIRCT 23.0.0git
Loading...
Searching...
No Matches
ElaborateObject.cpp
Go to the documentation of this file.
1//===- ElaborateObject.cpp - OM compile-time evaluation pass --------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This pass performs evaluation of OM classes by inlining object
10// instantiations and folding field accesses. It replaces the runtime Evaluator
11// framework with static compile-time evaluation.
12//
13//===----------------------------------------------------------------------===//
14
18#include "circt/Support/LLVM.h"
19#include "mlir/IR/Builders.h"
20#include "mlir/IR/Diagnostics.h"
21#include "mlir/IR/IRMapping.h"
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Operation.h"
24#include "mlir/IR/SymbolTable.h"
25#include "mlir/Interfaces/SideEffectInterfaces.h"
26#include "mlir/Support/WalkResult.h"
27#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/Support/LogicalResult.h"
30
31namespace circt {
32namespace om {
33#define GEN_PASS_DEF_ELABORATEOBJECT
34#include "circt/Dialect/OM/OMPasses.h.inc"
35} // namespace om
36} // namespace circt
37
38using namespace mlir;
39using namespace circt;
40using namespace om;
41
42namespace {
43// A map from (class name, field name) to field index.
44using FieldIndex = DenseMap<std::pair<StringAttr, StringAttr>, unsigned>;
45
46/// Pattern to inline ObjectOp instances by cloning the class body and
47/// replacing them with ElaboratedObjectOp.
48struct ObjectOpInliningPattern : public OpRewritePattern<ObjectOp> {
49 ObjectOpInliningPattern(MLIRContext *context, SymbolTable &symTable,
50 bool replaceExternalWithUnknown)
51 : OpRewritePattern<ObjectOp>(context), symTable(symTable),
52 replaceExternalWithUnknown(replaceExternalWithUnknown) {}
53
54 LogicalResult matchAndRewrite(ObjectOp objOp,
55 PatternRewriter &rewriter) const override {
56 auto classLike =
57 symTable.lookup<ClassLike>(objOp.getClassNameAttr().getAttr());
58 assert(classLike);
59
60 // External classes cannot be elaborated; replace with unknown values.
61 if (isa<ClassExternOp>(classLike)) {
62 if (!replaceExternalWithUnknown)
63 return failure();
64 rewriter.replaceOpWithNewOp<UnknownValueOp>(objOp, objOp.getType());
65 return success();
66 }
67
68 auto classOp = dyn_cast<ClassOp>(classLike.getOperation());
69 if (!classOp)
70 return failure();
71
72 IRMapping mapper;
73 for (auto [formal, actual] : llvm::zip(
74 classOp.getBodyBlock()->getArguments(), objOp.getActualParams()))
75 mapper.map(formal, actual);
76
77 // Clone the class body into a temporary region with argument substitution.
78 Region clonedRegion;
79 classOp.getBody().cloneInto(&clonedRegion, mapper);
80 Block *clonedBlock = &clonedRegion.front();
81
82 auto clonedFields = cast<ClassFieldsOp>(clonedBlock->getTerminator());
83 SmallVector<Value> fieldValues(clonedFields.getFields());
84 // Propagate the class's per-field locations onto each field value, fused
85 // with the value's existing location.
86 if (auto classOp = dyn_cast<ClassOp>(classLike.getOperation()))
87 for (auto [i, v] : llvm::enumerate(fieldValues))
88 v.setLoc(
89 rewriter.getFusedLoc({classOp.getFieldLocByIndex(i), v.getLoc()}));
90
91 // Erase the terminator and inline the body at the object instantiation.
92 rewriter.eraseOp(clonedFields);
93 rewriter.inlineBlockBefore(clonedBlock, objOp);
94
95 rewriter.replaceOpWithNewOp<ElaboratedObjectOp>(objOp, classLike,
96 fieldValues);
97
98 return success();
99 }
100
101 const SymbolTable &symTable;
102 bool replaceExternalWithUnknown;
103};
104
105/// Pattern to fold ObjectFieldOp on ElaboratedObjectOp by directly accessing
106/// the field value operands.
107struct EvaluateObjectField : OpRewritePattern<ObjectFieldOp> {
108 EvaluateObjectField(MLIRContext *context, const SymbolTable &symTable,
109 const FieldIndex &fieldIndexes)
110 : OpRewritePattern<ObjectFieldOp>(context), symTable(symTable),
111 fieldIndexes(fieldIndexes) {}
112
113 LogicalResult matchAndRewrite(ObjectFieldOp op,
114 PatternRewriter &rewriter) const override {
115 // Only fold if the object is an ElaboratedObjectOp.
116 auto elaboratedOp = op.getObject().getDefiningOp<ElaboratedObjectOp>();
117 if (!elaboratedOp)
118 return failure();
119
120 auto classLike =
121 symTable.lookup<ClassLike>(elaboratedOp.getClassNameAttr().getAttr());
122 assert(classLike);
123
124 // Find the field index and get the corresponding value.
125 auto index =
126 fieldIndexes.at({classLike.getSymNameAttr(), op.getFieldAttr()});
127 auto result = elaboratedOp.getFieldValues()[index];
128
129 // Skip cycles where a field references itself.
130 // This will be raised as an error later.
131 if (op.getResult() == result)
132 return failure();
133
134 rewriter.replaceOp(op, result);
135 return success();
136 }
137
138 const SymbolTable &symTable;
139 const FieldIndex &fieldIndexes;
140};
141
142/// Pattern to propagate UnknownValueOp through pure OM operations.
143/// If any operand is unknown, all results become unknown.
144struct UnknownPropagationPattern : RewritePattern {
145 UnknownPropagationPattern(MLIRContext *context)
146 : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {}
147
148 LogicalResult matchAndRewrite(Operation *op,
149 PatternRewriter &rewriter) const override {
150 // Only target pure OM operations.
151 // TODO: Consider add a trait for this if we want to have more explict
152 // behavior.
153 if (!isa_and_nonnull<OMDialect>(op->getDialect()) || !isPure(op) ||
154 op->getNumResults() == 0)
155 return failure();
156
157 // Check if any operand is an UnknownValueOp.
158 // TODO: This directly ports the existing Evaluator semantics, but it
159 // causes inconsistent evaluation for operations that can reason about
160 // known values, e.g., "and(0, unknown) -> 0".
161 if (!llvm::any_of(op->getOperands(), [](Value operand) {
162 return operand.getDefiningOp<UnknownValueOp>();
163 }))
164 return failure();
165
166 // Replace all results with UnknownValueOp.
167 SmallVector<Value> unknowns;
168 for (Type resultType : op->getResultTypes())
169 unknowns.push_back(
170 UnknownValueOp::create(rewriter, op->getLoc(), resultType));
171
172 rewriter.replaceOp(op, unknowns);
173 return success();
174 }
175};
176
177// Check if an operation can be evaluated at compile time and is valid to
178// remain in the IR after elaboration.
179bool isFullyEvaluated(Operation *op) {
180 return isa<
181 // Structure.
182 ClassOp, ClassFieldsOp, ElaboratedObjectOp, AnyCastOp,
183 // Constant-like.
184 ConstantOp, UnknownValueOp,
185 // Path.
186 FrozenBasePathCreateOp, FrozenPathCreateOp, FrozenEmptyPathOp,
187 // List.
188 ListCreateOp, ListConcatOp>(op);
189}
190
191LogicalResult verifyResult(ClassOp module, bool allowUnevaluated) {
192 auto isLegal = [allowUnevaluated](Operation *op) -> LogicalResult {
193 // Check assert satisfied.
194 if (auto assertOp = dyn_cast<PropertyAssertOp>(op)) {
195 // Check if the condition is a constant false, which means the assertion
196 // is violated.
197 auto *defOp = assertOp.getCondition().getDefiningOp();
198 APInt value;
199 auto checkAssert = [&](bool cond) -> LogicalResult {
200 if (cond) {
201 // Erase when success, serialization doesn't need to care about
202 // this.
203 op->erase();
204 return success();
205 }
206
207 return op->emitError("OM property assertion failed: ")
208 << assertOp.getMessage();
209 };
210
211 // Condition is a constant integer/bool - check if it's true.
212 if (matchPattern(assertOp.getCondition(), m_ConstantInt(&value)))
213 return checkAssert(!value.isZero());
214
215 // Condition is unknown - treat as passing.
216 if (auto unknownOp = dyn_cast_or_null<UnknownValueOp>(defOp))
217 return checkAssert(true);
218
219 // This means the condition was not fully evaluated.
220 if (allowUnevaluated)
221 return success();
222 return emitError(op->getLoc(), "failed to evaluate assertion condition");
223 }
224
225 if (!isFullyEvaluated(op)) {
226 if (allowUnevaluated)
227 return success();
228 return emitError(op->getLoc()) << "failed to evaluate " << op->getName();
229 }
230
231 return success();
232 };
233 bool encounteredError = false;
234 module.walk([&](Operation *op) { encounteredError |= failed(isLegal(op)); });
235
236 return failure(encounteredError);
237}
238
239struct ElaborateObjectPass
240 : public circt::om::impl::ElaborateObjectBase<ElaborateObjectPass> {
241 using Base::Base;
242
243 static LogicalResult elaborateClass(ClassOp classOp, SymbolTable &symTable,
244 FieldIndex &fieldIndexes,
245 bool allowUnevaluated = false) {
246 // Elaborate objects by inlining all ObjectOps and folding field accesses
247 // using a greedy pattern rewriter. NOTE: The conversion framework is not
248 // suitable here because inlining patterns need to be applied recursively to
249 // fully evaluate nested object instantiations.
250 RewritePatternSet patterns(classOp.getContext());
251 patterns.add<ObjectOpInliningPattern>(classOp.getContext(), symTable,
252 !allowUnevaluated);
253 patterns.add<EvaluateObjectField>(classOp.getContext(), symTable,
254 fieldIndexes);
255 patterns.add<UnknownPropagationPattern>(classOp.getContext());
256 GreedyRewriteConfig config;
257 // Disable iteration limit to allow full recursive inlining.
258 config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
259 if (failed(applyPatternsGreedily(classOp, std::move(patterns), config)))
260 return failure();
261
262 // Check if elaboration succeeded after saturation.
263 return verifyResult(classOp, allowUnevaluated);
264 }
265
266 LogicalResult initialize(MLIRContext *context) override {
267 unsigned numModes =
268 allPublicClasses.getValue() + !targetClass.getValue().empty();
269 if (numModes != 1)
270 return emitError(UnknownLoc::get(context))
271 << "exactly one of 'target-class' or 'all-public-classes' must "
272 "be specified";
273 return success();
274 }
275
276 void runOnOperation() override {
277 auto module = getOperation();
278 auto &symTable = getAnalysis<SymbolTable>();
279
280 // Build a map from (class name, field name) to field index for all
281 // classes.
282 FieldIndex fieldIndexes;
283 for (auto classOp : module.getOps<ClassLike>()) {
284 auto name = classOp.getSymNameAttr();
285 for (auto [idx, fieldName] :
286 llvm::enumerate(classOp.getFieldNames().getAsRange<StringAttr>()))
287 fieldIndexes[{name, fieldName}] = idx;
288 }
289
290 // Elaborate all public classes.
291 if (allPublicClasses) {
292 for (auto classOp : module.getOps<ClassOp>()) {
293 if (!classOp.isPublic())
294 continue;
295 if (failed(elaborateClass(classOp, symTable, fieldIndexes,
296 allowUnevaluated)))
297 return signalPassFailure();
298 }
299 return;
300 }
301
302 // Normal mode: elaborate the specified target class.
303 auto classOp = symTable.lookup<ClassOp>(targetClass);
304 if (!classOp) {
305 emitError(module.getLoc())
306 << "target class '" << targetClass << "' was not found";
307 return signalPassFailure();
308 }
309
310 if (failed(
311 elaborateClass(classOp, symTable, fieldIndexes, allowUnevaluated)))
312 return signalPassFailure();
313 }
314};
315
316} // namespace
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
static Location getLoc(DefSlot slot)
Definition Mem2Reg.cpp:212
static Block * getBodyBlock(FModuleLike mod)
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1