CIRCT 23.0.0git
Loading...
Searching...
No Matches
OMReductions.cpp
Go to the documentation of this file.
1//===- OMReductions.cpp - Reduction patterns for the OM dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/SymbolTable.h"
13#include "llvm/ADT/DenseMap.h"
14#include "llvm/ADT/SmallSet.h"
15#include "llvm/Support/Debug.h"
16
17#define DEBUG_TYPE "om-reductions"
18
19using namespace mlir;
20using namespace circt;
21using namespace om;
22
23//===----------------------------------------------------------------------===//
24// Utilities
25//===----------------------------------------------------------------------===//
26
27namespace {
28
29/// A utility for caching symbol tables and symbol user maps.
30struct SymbolCache {
31 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
32
33 SymbolTable &getSymbolTable(Operation *op) {
34 return tables->getSymbolTable(op);
35 }
36
37 SymbolTable &getNearestSymbolTable(Operation *op) {
38 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
39 }
40
41 SymbolUserMap &getSymbolUserMap(Operation *op) {
42 auto it = userMaps.find(op);
43 if (it != userMaps.end())
44 return it->second;
45 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
46 }
47
48 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
49 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
50 }
51
52 void clear() {
53 tables = std::make_unique<SymbolTableCollection>();
54 userMaps.clear();
55 }
56
57private:
58 std::unique_ptr<SymbolTableCollection> tables;
60};
61
62} // namespace
63
64//===----------------------------------------------------------------------===//
65// Reduction Patterns
66//===----------------------------------------------------------------------===//
67
68namespace {
69
70/// Replaces om.object instantiations with om.unknown when the object's fields
71/// are not accessed (only used in om.any_cast or not used at all).
72struct OMObjectToUnknownReplacer : public OpReduction<ObjectOp> {
73 uint64_t match(ObjectOp objectOp) override {
74 // Check if the object is only used in om.any_cast operations or not used
75 bool onlyAnyCastOrUnused =
76 llvm::all_of(objectOp->getUsers(),
77 [](Operation *user) { return isa<om::AnyCastOp>(user); });
78
79 if (!onlyAnyCastOrUnused)
80 return 0;
81
82 // Return a benefit proportional to the number of operands we can eliminate
83 return 1 + objectOp.getActualParams().size();
84 }
85
86 LogicalResult rewrite(ObjectOp objectOp) override {
87 OpBuilder builder(objectOp);
88 auto unknownOp = om::UnknownValueOp::create(builder, objectOp.getLoc(),
89 objectOp.getResult().getType());
90 objectOp.getResult().replaceAllUsesWith(unknownOp.getResult());
91 objectOp->erase();
92 return success();
93 }
94
95 std::string getName() const override { return "om-object-to-unknown"; }
96};
97
98/// Removes unused elements from om.list_create operations.
99struct OMListElementPruner : public OpReduction<ListCreateOp> {
100 void matches(ListCreateOp listOp,
101 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
102 // Create one match for each element in the list
103 auto inputs = listOp.getInputs();
104 for (size_t i = 0; i < inputs.size(); ++i)
105 addMatch(1, i);
106 }
107
108 LogicalResult rewriteMatches(ListCreateOp listOp,
109 ArrayRef<uint64_t> matches) override {
110 // Convert matches to a set for fast lookup
111 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
112
113 // Collect inputs that should be kept (not in matches)
114 SmallVector<Value> newInputs;
115 for (auto [i, input] : llvm::enumerate(listOp.getInputs()))
116 if (!matchesSet.contains(i))
117 newInputs.push_back(input);
118
119 // Create a new list with the remaining inputs
120 ImplicitLocOpBuilder builder(listOp.getLoc(), listOp);
121 auto newListOp =
122 ListCreateOp::create(builder, listOp.getResult().getType(), newInputs);
123 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
124 listOp->erase();
125 return success();
126 }
127
128 std::string getName() const override { return "om-list-element-pruner"; }
129};
130
131/// Removes unused output fields from om.class definitions.
132struct OMClassFieldPruner : public OpReduction<ClassOp> {
133 void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
134
135 void matches(ClassOp classOp,
136 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
137 // Get the field names.
138 auto fieldNames = classOp.getFieldNames();
139 if (fieldNames.empty())
140 return;
141
142 // Find which fields are actually accessed. We need to walk all operations
143 // in the module because om.class operations are IsolatedFromAbove, so the
144 // SymbolUserMap doesn't find nested uses.
145 llvm::DenseSet<StringAttr> usedFields;
146 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
147
148 moduleOp.walk([&](ObjectOp objectOp) {
149 // Check if this object is an instance of our class
150 if (objectOp.getClassNameAttr() != classOp.getSymNameAttr())
151 return;
152
153 // Check all object field uses of this object
154 for (auto *objectUser : objectOp->getUsers()) {
155 auto fieldOp = dyn_cast<ObjectFieldOp>(objectUser);
156 if (!fieldOp)
157 continue;
158
159 // Mark the accessed field as used.
160 usedFields.insert(fieldOp.getFieldAttr());
161 }
162 });
163
164 // Create one match for each unused field.
165 for (auto [idx, fieldName] : llvm::enumerate(fieldNames))
166 if (!usedFields.contains(cast<StringAttr>(fieldName)))
167 addMatch(1, idx);
168 }
169
170 LogicalResult rewriteMatches(ClassOp classOp,
171 ArrayRef<uint64_t> matches) override {
172 // Convert matches to a set for fast lookup.
173 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
174
175 // Build new field lists with only fields not in matches.
176 auto fieldsOp = classOp.getFieldsOp();
177 auto oldFieldValues = fieldsOp.getFields();
178
179 SmallVector<Attribute> names;
180 SmallVector<Value> values;
181 SmallVector<NamedAttribute> types;
182 for (auto [idx, nameValue] :
183 llvm::enumerate(llvm::zip(classOp.getFieldNames(), oldFieldValues))) {
184 if (matchesSet.contains(idx))
185 continue;
186
187 auto [name, value] = nameValue;
188 auto nameAttr = cast<StringAttr>(name);
189 names.push_back(name);
190 values.push_back(value);
191 types.push_back(NamedAttribute(nameAttr, TypeAttr::get(value.getType())));
192 }
193
194 // Update the class.
195 OpBuilder builder(classOp);
196 classOp.setFieldNamesAttr(builder.getArrayAttr(names));
197 classOp.setFieldTypesAttr(builder.getDictionaryAttr(types));
198 fieldsOp.getFieldsMutable().assign(values);
199
200 return success();
201 }
202
203 std::string getName() const override { return "om-class-field-pruner"; }
204
205 SymbolCache symbols;
206};
207
208/// Removes unused parameters from om.class definitions.
209struct OMClassParameterPruner : public OpReduction<ClassOp> {
210 void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
211
212 void matches(ClassOp classOp,
213 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
214 auto *bodyBlock = classOp.getBodyBlock();
215 if (bodyBlock->getNumArguments() == 0)
216 return;
217
218 // Create one match for each unused parameter
219 for (auto [idx, arg] : llvm::enumerate(bodyBlock->getArguments()))
220 if (arg.use_empty())
221 addMatch(1, idx);
222 }
223
224 LogicalResult rewriteMatches(ClassOp classOp,
225 ArrayRef<uint64_t> matches) override {
226 // Convert matches to a set for fast lookup
227 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
228
229 auto *bodyBlock = classOp.getBodyBlock();
230
231 // Collect all object instantiations that need to be updated. We need to
232 // walk all operations in the module because om.class operations are
233 // IsolatedFromAbove, so the SymbolUserMap doesn't find nested uses.
234 SmallVector<ObjectOp> objectsToUpdate;
235 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
236 moduleOp.walk([&](ObjectOp objectOp) {
237 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
238 objectsToUpdate.push_back(objectOp);
239 });
240
241 // Update the class formal parameters FIRST, before updating instantiations.
242 SmallVector<Attribute> newParamNames;
243 for (auto [idx, name] : llvm::enumerate(classOp.getFormalParamNames()))
244 if (!matchesSet.contains(idx))
245 newParamNames.push_back(name);
246
247 OpBuilder builder(classOp);
248 classOp.setFormalParamNamesAttr(builder.getArrayAttr(newParamNames));
249
250 // Remove the block arguments in reverse order to maintain indices.
251 SmallVector<unsigned> indicesToRemove(matches.begin(), matches.end());
252 llvm::sort(indicesToRemove, std::greater<unsigned>());
253 for (unsigned idx : indicesToRemove)
254 bodyBlock->eraseArgument(idx);
255
256 // Now update all om.object instantiations to match the new signature.
257 for (auto objectOp : objectsToUpdate) {
258 // Build new actual parameters without the removed parameters
259 SmallVector<Value> newParams;
260 for (auto [idx, param] : llvm::enumerate(objectOp.getActualParams()))
261 if (!matchesSet.contains(idx))
262 newParams.push_back(param);
263
264 // Create new object op
265 OpBuilder builder(objectOp);
266 auto newObjectOp = ObjectOp::create(
267 builder, objectOp.getLoc(), objectOp.getResult().getType(),
268 objectOp.getClassNameAttr(), newParams);
269 objectOp.getResult().replaceAllUsesWith(newObjectOp.getResult());
270 objectOp->erase();
271 }
272
273 return success();
274 }
275
276 std::string getName() const override { return "om-class-parameter-pruner"; }
277
278 SymbolCache symbols;
279};
280
281/// Removes unused om.class definitions that are never instantiated.
282struct OMUnusedClassRemover : public OpReduction<ClassOp> {
283 void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
284
285 uint64_t match(ClassOp classOp) override {
286 // Check if this class is ever instantiated via om.object. We need to walk
287 // all operations because SymbolUserMap doesn't find nested symbol uses
288 // inside IsolatedFromAbove operations like om.class.
289 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
290
291 // Check if this class is instantiated via om.object anywhere
292 auto result = moduleOp.walk([&](ObjectOp objectOp) {
293 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
294 return WalkResult::interrupt();
295 return WalkResult::advance();
296 });
297
298 if (result.wasInterrupted())
299 return 0;
300
301 // Check if this class contains any om.object instantiations
302 // (classes that instantiate other classes should be kept as they might be
303 // entry points)
304 result = classOp.walk(
305 [&](ObjectOp objectOp) { return WalkResult::interrupt(); });
306
307 if (result.wasInterrupted())
308 return 0;
309
310 // Remove the class if:
311 // 1. it's never instantiated via om.object, AND
312 // 2. it doesn't contain any om.object instantiations (not an entry point).
313 return 10;
314 }
315
316 LogicalResult rewrite(ClassOp classOp) override {
317 classOp->erase();
318 return success();
319 }
320
321 std::string getName() const override { return "om-unused-class-remover"; }
322
323 SymbolCache symbols;
324};
325
326/// Simplifies om.unknown -> om.any_cast chains by replacing with a direct
327/// om.unknown of the target type.
328struct OMAnyCastOfUnknownSimplifier : public OpReduction<om::AnyCastOp> {
329 uint64_t match(om::AnyCastOp anyCastOp) override {
330 auto unknownOp = anyCastOp.getInput().getDefiningOp<om::UnknownValueOp>();
331 if (unknownOp)
332 return 2;
333 return 0;
334 }
335
336 LogicalResult rewrite(om::AnyCastOp anyCastOp) override {
337 ImplicitLocOpBuilder builder(anyCastOp.getLoc(), anyCastOp);
338 anyCastOp.getResult().replaceAllUsesWith(
339 om::UnknownValueOp::create(builder, anyCastOp.getResult().getType()));
340 anyCastOp->erase();
341 return success();
342 }
343
344 std::string getName() const override {
345 return "om-anycast-of-unknown-simplifier";
346 }
347};
348
349/// Generic Operation-based reduction that replaces any OM operation with
350/// om.unknown of the same result type. This operates at the lowest level by
351/// working directly on Operation* without needing to know concrete op types.
352struct OMOpToUnknown : public Reduction {
353 uint64_t match(Operation *op) override {
354 // Only handle operations from the OM dialect
355 if (!isa<OMDialect>(op->getDialect()))
356 return 0;
357
358 // Must have exactly one result (what we'll replace with unknown)
359 if (op->getNumResults() != 1)
360 return 0;
361
362 // Constant benefit: just eliminate this single operation
363 return 1;
364 }
365
366 LogicalResult rewrite(Operation *op) override {
367 OpBuilder builder(op);
368 Type resultType = op->getResult(0).getType();
369 auto unknownOp =
370 om::UnknownValueOp::create(builder, op->getLoc(), resultType);
371 op->getResult(0).replaceAllUsesWith(unknownOp.getResult());
372 op->erase();
373 return success();
374 }
375
376 std::string getName() const override { return "om-op-to-unknown"; };
377};
378
379} // namespace
380
381//===----------------------------------------------------------------------===//
382// Reduction Registration
383//===----------------------------------------------------------------------===//
384
387 // High priority reductions
388 patterns.add<OMClassParameterPruner, 50>();
389 patterns.add<OMClassFieldPruner, 49>();
390 patterns.add<OMObjectToUnknownReplacer, 48>();
391 patterns.add<OMListElementPruner, 45>();
392
393 // Medium priority reductions
394 patterns.add<OMUnusedClassRemover, 40>();
395 patterns.add<OMOpToUnknown, 35>();
396 patterns.add<OMAnyCastOfUnknownSimplifier, 35>();
397}
398
399void om::registerReducePatternDialectInterface(
400 mlir::DialectRegistry &registry) {
401 registry.addExtension(+[](MLIRContext *ctx, OMDialect *dialect) {
402 dialect->addInterfaces<OMReducePatternDialectInterface>();
403 });
404}
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all the ways in which this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override