12#include "mlir/IR/SymbolTable.h"
13#include "llvm/ADT/DenseMap.h"
14#include "llvm/ADT/SmallSet.h"
15#include "llvm/Support/Debug.h"
17#define DEBUG_TYPE "om-reductions"
31 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
33 SymbolTable &getSymbolTable(Operation *op) {
34 return tables->getSymbolTable(op);
37 SymbolTable &getNearestSymbolTable(Operation *op) {
38 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
41 SymbolUserMap &getSymbolUserMap(Operation *op) {
42 auto it = userMaps.find(op);
43 if (it != userMaps.end())
45 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
48 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
49 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
53 tables = std::make_unique<SymbolTableCollection>();
58 std::unique_ptr<SymbolTableCollection> tables;
72struct OMObjectToUnknownReplacer :
public OpReduction<ObjectOp> {
73 uint64_t
match(ObjectOp objectOp)
override {
75 bool onlyAnyCastOrUnused =
76 llvm::all_of(objectOp->getUsers(),
77 [](Operation *user) { return isa<om::AnyCastOp>(user); });
79 if (!onlyAnyCastOrUnused)
83 return 1 + objectOp.getActualParams().size();
86 LogicalResult
rewrite(ObjectOp objectOp)
override {
87 OpBuilder builder(objectOp);
88 auto unknownOp = om::UnknownValueOp::create(builder, objectOp.getLoc(),
89 objectOp.getResult().getType());
90 objectOp.getResult().replaceAllUsesWith(unknownOp.getResult());
95 std::string
getName()
const override {
return "om-object-to-unknown"; }
99struct OMListElementPruner :
public OpReduction<ListCreateOp> {
100 void matches(ListCreateOp listOp,
101 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
103 auto inputs = listOp.getInputs();
104 for (
size_t i = 0; i < inputs.size(); ++i)
109 ArrayRef<uint64_t> matches)
override {
111 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
114 SmallVector<Value> newInputs;
115 for (
auto [i, input] :
llvm::enumerate(listOp.getInputs()))
116 if (!matchesSet.contains(i))
117 newInputs.push_back(input);
120 ImplicitLocOpBuilder builder(listOp.getLoc(), listOp);
122 ListCreateOp::create(builder, listOp.getResult().getType(), newInputs);
123 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
128 std::string
getName()
const override {
return "om-list-element-pruner"; }
132struct OMClassFieldPruner :
public OpReduction<ClassOp> {
136 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
138 auto fieldNames = classOp.getFieldNames();
139 if (fieldNames.empty())
145 llvm::DenseSet<StringAttr> usedFields;
146 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
148 moduleOp.walk([&](ObjectOp objectOp) {
150 if (objectOp.getClassNameAttr() != classOp.getSymNameAttr())
154 for (
auto *objectUser : objectOp->getUsers()) {
155 auto fieldOp = dyn_cast<ObjectFieldOp>(objectUser);
160 usedFields.insert(fieldOp.getFieldAttr());
165 for (
auto [idx, fieldName] :
llvm::enumerate(fieldNames))
166 if (!usedFields.contains(cast<StringAttr>(fieldName)))
171 ArrayRef<uint64_t> matches)
override {
173 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
176 auto fieldsOp = classOp.getFieldsOp();
177 auto oldFieldValues = fieldsOp.getFields();
179 SmallVector<Attribute> names;
180 SmallVector<Value> values;
181 SmallVector<NamedAttribute> types;
182 for (
auto [idx, nameValue] :
183 llvm::enumerate(
llvm::zip(classOp.getFieldNames(), oldFieldValues))) {
184 if (matchesSet.contains(idx))
187 auto [name, value] = nameValue;
188 auto nameAttr = cast<StringAttr>(name);
189 names.push_back(name);
190 values.push_back(value);
191 types.push_back(NamedAttribute(nameAttr, TypeAttr::get(value.getType())));
195 OpBuilder builder(classOp);
196 classOp.setFieldNamesAttr(builder.getArrayAttr(names));
197 classOp.setFieldTypesAttr(builder.getDictionaryAttr(types));
198 fieldsOp.getFieldsMutable().assign(values);
203 std::string
getName()
const override {
return "om-class-field-pruner"; }
209struct OMClassParameterPruner :
public OpReduction<ClassOp> {
213 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
214 auto *bodyBlock = classOp.getBodyBlock();
215 if (bodyBlock->getNumArguments() == 0)
219 for (
auto [idx, arg] :
llvm::enumerate(bodyBlock->getArguments()))
225 ArrayRef<uint64_t> matches)
override {
227 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
229 auto *bodyBlock = classOp.getBodyBlock();
234 SmallVector<ObjectOp> objectsToUpdate;
235 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
236 moduleOp.walk([&](ObjectOp objectOp) {
237 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
238 objectsToUpdate.push_back(objectOp);
242 SmallVector<Attribute> newParamNames;
243 for (
auto [idx, name] :
llvm::enumerate(classOp.getFormalParamNames()))
244 if (!matchesSet.contains(idx))
245 newParamNames.push_back(name);
247 OpBuilder builder(classOp);
248 classOp.setFormalParamNamesAttr(builder.getArrayAttr(newParamNames));
251 SmallVector<unsigned> indicesToRemove(matches.begin(), matches.end());
252 llvm::sort(indicesToRemove, std::greater<unsigned>());
253 for (
unsigned idx : indicesToRemove)
254 bodyBlock->eraseArgument(idx);
257 for (
auto objectOp : objectsToUpdate) {
259 SmallVector<Value> newParams;
260 for (
auto [idx, param] :
llvm::enumerate(objectOp.getActualParams()))
261 if (!matchesSet.contains(idx))
262 newParams.push_back(param);
265 OpBuilder builder(objectOp);
266 auto newObjectOp = ObjectOp::create(
267 builder, objectOp.getLoc(), objectOp.getResult().getType(),
268 objectOp.getClassNameAttr(), newParams);
269 objectOp.getResult().replaceAllUsesWith(newObjectOp.getResult());
276 std::string
getName()
const override {
return "om-class-parameter-pruner"; }
282struct OMUnusedClassRemover :
public OpReduction<ClassOp> {
285 uint64_t
match(ClassOp classOp)
override {
289 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
292 auto result = moduleOp.walk([&](ObjectOp objectOp) {
293 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
294 return WalkResult::interrupt();
295 return WalkResult::advance();
298 if (result.wasInterrupted())
304 result = classOp.walk(
305 [&](ObjectOp objectOp) {
return WalkResult::interrupt(); });
307 if (result.wasInterrupted())
316 LogicalResult
rewrite(ClassOp classOp)
override {
321 std::string
getName()
const override {
return "om-unused-class-remover"; }
328struct OMAnyCastOfUnknownSimplifier :
public OpReduction<om::AnyCastOp> {
329 uint64_t
match(om::AnyCastOp anyCastOp)
override {
330 auto unknownOp = anyCastOp.getInput().getDefiningOp<om::UnknownValueOp>();
336 LogicalResult
rewrite(om::AnyCastOp anyCastOp)
override {
337 ImplicitLocOpBuilder builder(anyCastOp.getLoc(), anyCastOp);
338 anyCastOp.getResult().replaceAllUsesWith(
339 om::UnknownValueOp::create(builder, anyCastOp.getResult().getType()));
344 std::string
getName()
const override {
345 return "om-anycast-of-unknown-simplifier";
353 uint64_t
match(Operation *op)
override {
355 if (!isa<OMDialect>(op->getDialect()))
359 if (op->getNumResults() != 1)
366 LogicalResult
rewrite(Operation *op)
override {
367 OpBuilder builder(op);
368 Type resultType = op->getResult(0).getType();
370 om::UnknownValueOp::create(builder, op->getLoc(), resultType);
371 op->getResult(0).replaceAllUsesWith(unknownOp.getResult());
376 std::string
getName()
const override {
return "om-op-to-unknown"; };
388 patterns.add<OMClassParameterPruner, 50>();
389 patterns.add<OMClassFieldPruner, 49>();
390 patterns.add<OMObjectToUnknownReplacer, 48>();
391 patterns.add<OMListElementPruner, 45>();
394 patterns.add<OMUnusedClassRemover, 40>();
396 patterns.add<OMAnyCastOfUnknownSimplifier, 35>();
399void om::registerReducePatternDialectInterface(
400 mlir::DialectRegistry ®istry) {
401 registry.addExtension(+[](MLIRContext *ctx, OMDialect *dialect) {
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override