CIRCT 23.0.0git
Loading...
Searching...
No Matches
OMReductions.cpp
Go to the documentation of this file.
1//===- OMReductions.cpp - Reduction patterns for the OM dialect ----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
12#include "mlir/IR/SymbolTable.h"
13#include "llvm/ADT/DenseMap.h"
14#include "llvm/ADT/SmallSet.h"
15#include "llvm/Support/Debug.h"
16
17#define DEBUG_TYPE "om-reductions"
18
19using namespace mlir;
20using namespace circt;
21using namespace om;
22
23//===----------------------------------------------------------------------===//
24// Utilities
25//===----------------------------------------------------------------------===//
26
27namespace {
28
29/// A utility for caching symbol tables and symbol user maps.
30struct SymbolCache {
31 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
32
33 SymbolTable &getSymbolTable(Operation *op) {
34 return tables->getSymbolTable(op);
35 }
36
37 SymbolTable &getNearestSymbolTable(Operation *op) {
38 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
39 }
40
41 SymbolUserMap &getSymbolUserMap(Operation *op) {
42 auto it = userMaps.find(op);
43 if (it != userMaps.end())
44 return it->second;
45 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
46 }
47
48 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
49 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
50 }
51
52 void clear() {
53 tables = std::make_unique<SymbolTableCollection>();
54 userMaps.clear();
55 }
56
57private:
58 std::unique_ptr<SymbolTableCollection> tables;
60};
61
62} // namespace
63
64//===----------------------------------------------------------------------===//
65// Reduction Patterns
66//===----------------------------------------------------------------------===//
67
68namespace {
69
70/// Replaces om.object instantiations with om.unknown when the object's fields
71/// are not accessed (only used in om.any_cast or not used at all).
72struct OMObjectToUnknownReplacer : public OpReduction<ObjectOp> {
73 uint64_t match(ObjectOp objectOp) override {
74 // Check if the object is only used in om.any_cast operations or not used
75 bool onlyAnyCastOrUnused =
76 llvm::all_of(objectOp->getUsers(),
77 [](Operation *user) { return isa<om::AnyCastOp>(user); });
78
79 if (!onlyAnyCastOrUnused)
80 return 0;
81
82 // Return a benefit proportional to the number of operands we can eliminate
83 return 1 + objectOp.getActualParams().size();
84 }
85
86 LogicalResult rewrite(ObjectOp objectOp) override {
87 OpBuilder builder(objectOp);
88 auto unknownOp = om::UnknownValueOp::create(builder, objectOp.getLoc(),
89 objectOp.getResult().getType());
90 objectOp.getResult().replaceAllUsesWith(unknownOp.getResult());
91 objectOp->erase();
92 return success();
93 }
94
95 std::string getName() const override { return "om-object-to-unknown"; }
96};
97
98/// Removes unused elements from om.list_create operations.
99struct OMListElementPruner : public OpReduction<ListCreateOp> {
100 void matches(ListCreateOp listOp,
101 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
102 // Create one match for each element in the list
103 auto inputs = listOp.getInputs();
104 for (size_t i = 0; i < inputs.size(); ++i)
105 addMatch(1, i);
106 }
107
108 LogicalResult rewriteMatches(ListCreateOp listOp,
109 ArrayRef<uint64_t> matches) override {
110 // Convert matches to a set for fast lookup
111 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
112
113 // Collect inputs that should be kept (not in matches)
114 SmallVector<Value> newInputs;
115 for (auto [i, input] : llvm::enumerate(listOp.getInputs()))
116 if (!matchesSet.contains(i))
117 newInputs.push_back(input);
118
119 // Create a new list with the remaining inputs
120 ImplicitLocOpBuilder builder(listOp.getLoc(), listOp);
121 auto newListOp =
122 ListCreateOp::create(builder, listOp.getResult().getType(), newInputs);
123 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
124 listOp->erase();
125 return success();
126 }
127
128 std::string getName() const override { return "om-list-element-pruner"; }
129};
130
131/// Removes unused output fields from om.class definitions.
132struct OMClassFieldPruner : public OpReduction<ClassOp> {
133 void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
134
135 void matches(ClassOp classOp,
136 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
137 // Get the field names.
138 auto fieldNames = classOp.getFieldNames();
139 if (fieldNames.empty())
140 return;
141
142 // Find which fields are actually accessed. We need to walk all operations
143 // in the module because om.class operations are IsolatedFromAbove, so the
144 // SymbolUserMap doesn't find nested uses.
145 llvm::DenseSet<StringAttr> usedFields;
146 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
147
148 moduleOp.walk([&](ObjectOp objectOp) {
149 // Check if this object is an instance of our class
150 if (objectOp.getClassNameAttr() != classOp.getSymNameAttr())
151 return;
152
153 // Check all object field uses of this object
154 for (auto *objectUser : objectOp->getUsers()) {
155 auto fieldOp = dyn_cast<ObjectFieldOp>(objectUser);
156 if (!fieldOp)
157 continue;
158
159 auto fieldPath = fieldOp.getFieldPath();
160 if (fieldPath.empty())
161 continue;
162
163 // Mark the accessed field as used.
164 usedFields.insert(cast<FlatSymbolRefAttr>(fieldPath[0]).getAttr());
165 }
166 });
167
168 // Create one match for each unused field.
169 for (auto [idx, fieldName] : llvm::enumerate(fieldNames))
170 if (!usedFields.contains(cast<StringAttr>(fieldName)))
171 addMatch(1, idx);
172 }
173
174 LogicalResult rewriteMatches(ClassOp classOp,
175 ArrayRef<uint64_t> matches) override {
176 // Convert matches to a set for fast lookup.
177 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
178
179 // Build new field lists with only fields not in matches.
180 auto fieldsOp = classOp.getFieldsOp();
181 auto oldFieldValues = fieldsOp.getFields();
182
183 SmallVector<Attribute> names;
184 SmallVector<Value> values;
185 SmallVector<NamedAttribute> types;
186 for (auto [idx, nameValue] :
187 llvm::enumerate(llvm::zip(classOp.getFieldNames(), oldFieldValues))) {
188 if (matchesSet.contains(idx))
189 continue;
190
191 auto [name, value] = nameValue;
192 auto nameAttr = cast<StringAttr>(name);
193 names.push_back(name);
194 values.push_back(value);
195 types.push_back(NamedAttribute(nameAttr, TypeAttr::get(value.getType())));
196 }
197
198 // Update the class.
199 OpBuilder builder(classOp);
200 classOp.setFieldNamesAttr(builder.getArrayAttr(names));
201 classOp.setFieldTypesAttr(builder.getDictionaryAttr(types));
202 fieldsOp.getFieldsMutable().assign(values);
203
204 return success();
205 }
206
207 std::string getName() const override { return "om-class-field-pruner"; }
208
209 SymbolCache symbols;
210};
211
212/// Removes unused parameters from om.class definitions.
213struct OMClassParameterPruner : public OpReduction<ClassOp> {
214 void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
215
216 void matches(ClassOp classOp,
217 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
218 auto *bodyBlock = classOp.getBodyBlock();
219 if (bodyBlock->getNumArguments() == 0)
220 return;
221
222 // Create one match for each unused parameter
223 for (auto [idx, arg] : llvm::enumerate(bodyBlock->getArguments()))
224 if (arg.use_empty())
225 addMatch(1, idx);
226 }
227
228 LogicalResult rewriteMatches(ClassOp classOp,
229 ArrayRef<uint64_t> matches) override {
230 // Convert matches to a set for fast lookup
231 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
232
233 auto *bodyBlock = classOp.getBodyBlock();
234
235 // Collect all object instantiations that need to be updated. We need to
236 // walk all operations in the module because om.class operations are
237 // IsolatedFromAbove, so the SymbolUserMap doesn't find nested uses.
238 SmallVector<ObjectOp> objectsToUpdate;
239 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
240 moduleOp.walk([&](ObjectOp objectOp) {
241 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
242 objectsToUpdate.push_back(objectOp);
243 });
244
245 // Update the class formal parameters FIRST, before updating instantiations.
246 SmallVector<Attribute> newParamNames;
247 for (auto [idx, name] : llvm::enumerate(classOp.getFormalParamNames()))
248 if (!matchesSet.contains(idx))
249 newParamNames.push_back(name);
250
251 OpBuilder builder(classOp);
252 classOp.setFormalParamNamesAttr(builder.getArrayAttr(newParamNames));
253
254 // Remove the block arguments in reverse order to maintain indices.
255 SmallVector<unsigned> indicesToRemove(matches.begin(), matches.end());
256 llvm::sort(indicesToRemove, std::greater<unsigned>());
257 for (unsigned idx : indicesToRemove)
258 bodyBlock->eraseArgument(idx);
259
260 // Now update all om.object instantiations to match the new signature.
261 for (auto objectOp : objectsToUpdate) {
262 // Build new actual parameters without the removed parameters
263 SmallVector<Value> newParams;
264 for (auto [idx, param] : llvm::enumerate(objectOp.getActualParams()))
265 if (!matchesSet.contains(idx))
266 newParams.push_back(param);
267
268 // Create new object op
269 OpBuilder builder(objectOp);
270 auto newObjectOp = ObjectOp::create(
271 builder, objectOp.getLoc(), objectOp.getResult().getType(),
272 objectOp.getClassNameAttr(), newParams);
273 objectOp.getResult().replaceAllUsesWith(newObjectOp.getResult());
274 objectOp->erase();
275 }
276
277 return success();
278 }
279
280 std::string getName() const override { return "om-class-parameter-pruner"; }
281
282 SymbolCache symbols;
283};
284
285/// Removes unused om.class definitions that are never instantiated.
286struct OMUnusedClassRemover : public OpReduction<ClassOp> {
287 void beforeReduction(mlir::ModuleOp op) override { symbols.clear(); }
288
289 uint64_t match(ClassOp classOp) override {
290 // Check if this class is ever instantiated via om.object. We need to walk
291 // all operations because SymbolUserMap doesn't find nested symbol uses
292 // inside IsolatedFromAbove operations like om.class.
293 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
294
295 // Check if this class is instantiated via om.object anywhere
296 auto result = moduleOp.walk([&](ObjectOp objectOp) {
297 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
298 return WalkResult::interrupt();
299 return WalkResult::advance();
300 });
301
302 if (result.wasInterrupted())
303 return 0;
304
305 // Check if this class contains any om.object instantiations
306 // (classes that instantiate other classes should be kept as they might be
307 // entry points)
308 result = classOp.walk(
309 [&](ObjectOp objectOp) { return WalkResult::interrupt(); });
310
311 if (result.wasInterrupted())
312 return 0;
313
314 // Remove the class if:
315 // 1. it's never instantiated via om.object, AND
316 // 2. it doesn't contain any om.object instantiations (not an entry point).
317 return 10;
318 }
319
320 LogicalResult rewrite(ClassOp classOp) override {
321 classOp->erase();
322 return success();
323 }
324
325 std::string getName() const override { return "om-unused-class-remover"; }
326
327 SymbolCache symbols;
328};
329
330/// Simplifies om.unknown -> om.any_cast chains by replacing with a direct
331/// om.unknown of the target type.
332struct OMAnyCastOfUnknownSimplifier : public OpReduction<om::AnyCastOp> {
333 uint64_t match(om::AnyCastOp anyCastOp) override {
334 auto unknownOp = anyCastOp.getInput().getDefiningOp<om::UnknownValueOp>();
335 if (unknownOp)
336 return 2;
337 return 0;
338 }
339
340 LogicalResult rewrite(om::AnyCastOp anyCastOp) override {
341 ImplicitLocOpBuilder builder(anyCastOp.getLoc(), anyCastOp);
342 anyCastOp.getResult().replaceAllUsesWith(
343 om::UnknownValueOp::create(builder, anyCastOp.getResult().getType()));
344 anyCastOp->erase();
345 return success();
346 }
347
348 std::string getName() const override {
349 return "om-anycast-of-unknown-simplifier";
350 }
351};
352
353} // namespace
354
355//===----------------------------------------------------------------------===//
356// Reduction Registration
357//===----------------------------------------------------------------------===//
358
  // Register all OM reduction patterns. The integer template argument of
  // `patterns.add` is the pattern's priority; per the grouping comments
  // below, larger values mean higher priority.

  // High priority reductions: shrink class signatures, objects, and lists.
  patterns.add<OMClassParameterPruner, 50>();
  patterns.add<OMClassFieldPruner, 49>();
  patterns.add<OMObjectToUnknownReplacer, 48>();
  patterns.add<OMListElementPruner, 45>();

  // Medium priority reductions: remove classes and simplify casts.
  patterns.add<OMUnusedClassRemover, 40>();
  patterns.add<OMAnyCastOfUnknownSimplifier, 35>();
}
371
372void om::registerReducePatternDialectInterface(
373 mlir::DialectRegistry &registry) {
374 registry.addExtension(+[](MLIRContext *ctx, OMDialect *dialect) {
375 dialect->addInterfaces<OMReducePatternDialectInterface>();
376 });
377}
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
Definition SymCache.h:85
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition om.py:1
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override