12#include "mlir/IR/SymbolTable.h"
13#include "llvm/ADT/DenseMap.h"
14#include "llvm/ADT/SmallSet.h"
15#include "llvm/Support/Debug.h"
17#define DEBUG_TYPE "om-reductions"
// Cache of symbol tables (held in an owned SymbolTableCollection) and of
// SymbolUserMaps, keyed by the operation they were built for.
31 SymbolCache() : tables(std::make_unique<SymbolTableCollection>()) {}
// Returns the symbol table for `op`, looked up through `tables`.
33 SymbolTable &getSymbolTable(Operation *op) {
34 return tables->getSymbolTable(op);
// Returns the symbol table of the nearest symbol-table operation for `op`, as
// determined by SymbolTable::getNearestSymbolTable.
// NOTE(review): SymbolTable::getNearestSymbolTable returns null when no
// symbol-table parent exists; confirm callers only pass ops nested in one.
37 SymbolTable &getNearestSymbolTable(Operation *op) {
38 return getSymbolTable(SymbolTable::getNearestSymbolTable(op));
// Returns the SymbolUserMap for `op`, building it over `tables` and caching it
// in `userMaps` the first time it is requested.
41 SymbolUserMap &getSymbolUserMap(Operation *op) {
42 auto it = userMaps.find(op);
43 if (it != userMaps.end())
45 return userMaps.insert({op, SymbolUserMap(*tables, op)}).first->second;
// Like getSymbolUserMap, but for the nearest symbol-table operation of `op`
// (same null caveat as getNearestSymbolTable).
48 SymbolUserMap &getNearestSymbolUserMap(Operation *op) {
49 return getSymbolUserMap(SymbolTable::getNearestSymbolTable(op));
// Replaces the symbol-table collection with a fresh, empty one.
// NOTE(review): SymbolUserMaps cached in `userMaps` were constructed from the
// previous collection; confirm they are cleared together with this reset.
53 tables = std::make_unique<SymbolTableCollection>();
// Owned collection backing getSymbolTable() and the cached SymbolUserMaps.
58 std::unique_ptr<SymbolTableCollection> tables;
// Reduction that replaces an ObjectOp whose result is unused, or is only
// consumed by om::AnyCastOp users, with an om::UnknownValueOp of the same type.
72struct OMObjectToUnknownReplacer :
public OpReduction<ObjectOp> {
// Accepts the object only if every user is an om::AnyCastOp; llvm::all_of is
// vacuously true for an object with no users, so unused objects qualify too.
// On acceptance the returned value is 1 + the number of actual parameters
// (presumably the match benefit, so objects with more parameters rank higher).
73 uint64_t
match(ObjectOp objectOp)
override {
75 bool onlyAnyCastOrUnused =
76 llvm::all_of(objectOp->getUsers(),
77 [](Operation *user) { return isa<om::AnyCastOp>(user); });
79 if (!onlyAnyCastOrUnused)
83 return 1 + objectOp.getActualParams().size();
// Creates an om::UnknownValueOp of the object's result type immediately before
// the object and redirects every use of the object's result to it.
86 LogicalResult
rewrite(ObjectOp objectOp)
override {
87 OpBuilder builder(objectOp);
88 auto unknownOp = om::UnknownValueOp::create(builder, objectOp.getLoc(),
89 objectOp.getResult().getType());
90 objectOp.getResult().replaceAllUsesWith(unknownOp.getResult());
// Human-readable name of this reduction pattern.
95 std::string
getName()
const override {
return "om-object-to-unknown"; }
// Reduction that removes individual inputs from a ListCreateOp by rebuilding
// the list without the selected elements.
99struct OMListElementPruner :
public OpReduction<ListCreateOp> {
100 void matches(ListCreateOp listOp,
101 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
// Every list input, identified by its index, is a removal candidate
// (presumably each one is reported through addMatch).
103 auto inputs = listOp.getInputs();
104 for (
size_t i = 0; i < inputs.size(); ++i)
109 ArrayRef<uint64_t> matches)
override {
// Indices in `matches` are the inputs to drop; the remaining inputs are kept
// in their original order.
111 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
114 SmallVector<Value> newInputs;
115 for (
auto [i, input] :
llvm::enumerate(listOp.getInputs()))
116 if (!matchesSet.contains(i))
117 newInputs.push_back(input);
// Build the replacement list right before the original (same location and
// result type) and redirect all uses of the old list to it.
120 ImplicitLocOpBuilder builder(listOp.getLoc(), listOp);
122 ListCreateOp::create(builder, listOp.getResult().getType(), newInputs);
123 listOp.getResult().replaceAllUsesWith(newListOp.getResult());
// Human-readable name of this reduction pattern.
128 std::string
getName()
const override {
return "om-list-element-pruner"; }
// Reduction that removes fields of a ClassOp that are never read through an
// ObjectFieldOp on an instance of that class.
132struct OMClassFieldPruner :
public OpReduction<ClassOp> {
136 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
// A field counts as used when some ObjectOp in the enclosing module that
// instantiates this class has a direct ObjectFieldOp user whose field path
// starts with that field's name. Fields not in `usedFields` are the candidates,
// identified by their index in the class's field-name list.
// NOTE(review): only direct users of ObjectOp results are inspected; field
// reads through other values of this class type (e.g. objects received as
// class parameters) are not seen here. Confirm that pruning such fields is
// acceptable, or that later verification rejects the resulting IR.
// NOTE(review): this walks the whole module once per ClassOp.
138 auto fieldNames = classOp.getFieldNames();
139 if (fieldNames.empty())
145 llvm::DenseSet<StringAttr> usedFields;
146 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
148 moduleOp.walk([&](ObjectOp objectOp) {
150 if (objectOp.getClassNameAttr() != classOp.getSymNameAttr())
154 for (
auto *objectUser : objectOp->getUsers()) {
155 auto fieldOp = dyn_cast<ObjectFieldOp>(objectUser);
159 auto fieldPath = fieldOp.getFieldPath();
160 if (fieldPath.empty())
164 usedFields.insert(cast<FlatSymbolRefAttr>(fieldPath[0]).getAttr());
169 for (
auto [idx, fieldName] :
llvm::enumerate(fieldNames))
170 if (!usedFields.contains(cast<StringAttr>(fieldName)))
175 ArrayRef<uint64_t> matches)
override {
// Rebuild the field names, field types and fields-op operands in step, dropping
// every field whose index appears in `matches` and keeping the rest in order.
// Field names and field values are assumed to be positionally aligned (they are
// zipped together below).
177 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
180 auto fieldsOp = classOp.getFieldsOp();
181 auto oldFieldValues = fieldsOp.getFields();
183 SmallVector<Attribute> names;
184 SmallVector<Value> values;
185 SmallVector<NamedAttribute> types;
186 for (
auto [idx, nameValue] :
187 llvm::enumerate(
llvm::zip(classOp.getFieldNames(), oldFieldValues))) {
188 if (matchesSet.contains(idx))
191 auto [name, value] = nameValue;
192 auto nameAttr = cast<StringAttr>(name);
193 names.push_back(name);
194 values.push_back(value);
195 types.push_back(NamedAttribute(nameAttr, TypeAttr::get(value.getType())));
// Write the filtered lists back onto the class attributes and the fields op.
199 OpBuilder builder(classOp);
200 classOp.setFieldNamesAttr(builder.getArrayAttr(names));
201 classOp.setFieldTypesAttr(builder.getDictionaryAttr(types));
202 fieldsOp.getFieldsMutable().assign(values);
// Human-readable name of this reduction pattern.
207 std::string
getName()
const override {
return "om-class-field-pruner"; }
// Reduction that removes formal parameters from a ClassOp and drops the
// corresponding actual parameters from every ObjectOp instantiating the class.
213struct OMClassParameterPruner :
public OpReduction<ClassOp> {
217 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
// Candidates are the class body's block arguments, identified by index; a
// class whose body has no arguments is handled specially (presumably it yields
// no candidates).
218 auto *bodyBlock = classOp.getBodyBlock();
219 if (bodyBlock->getNumArguments() == 0)
223 for (
auto [idx, arg] :
llvm::enumerate(bodyBlock->getArguments()))
229 ArrayRef<uint64_t> matches)
override {
// Removes the parameters whose indices are in `matches` from the class's
// formal parameter names and body block arguments, then rewrites each
// instantiation with the remaining actual parameters.
231 llvm::SmallDenseSet<uint64_t, 4> matchesSet(matches.begin(), matches.end());
233 auto *bodyBlock = classOp.getBodyBlock();
// Collect all instantiations of this class in the enclosing module before the
// class signature is changed.
238 SmallVector<ObjectOp> objectsToUpdate;
239 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
240 moduleOp.walk([&](ObjectOp objectOp) {
241 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
242 objectsToUpdate.push_back(objectOp);
// Keep only the formal parameter names whose index was not selected.
246 SmallVector<Attribute> newParamNames;
247 for (
auto [idx, name] :
llvm::enumerate(classOp.getFormalParamNames()))
248 if (!matchesSet.contains(idx))
249 newParamNames.push_back(name);
251 OpBuilder builder(classOp);
252 classOp.setFormalParamNamesAttr(builder.getArrayAttr(newParamNames));
// Erase block arguments from the highest index down so each erasure does not
// shift the positions of the arguments still to be erased.
// NOTE(review): this list is built from `matches` directly rather than from
// `matchesSet`; if `matches` can contain a duplicate index, the same position
// would be erased twice and an unrelated argument removed. Consider building
// it from matchesSet if duplicates are possible.
// NOTE(review): Block::eraseArgument expects the argument to have no remaining
// uses; this relies on matches() only proposing unused arguments. Confirm.
255 SmallVector<unsigned> indicesToRemove(matches.begin(), matches.end());
256 llvm::sort(indicesToRemove, std::greater<unsigned>());
257 for (
unsigned idx : indicesToRemove)
258 bodyBlock->eraseArgument(idx);
// Recreate each instantiation, in place, with the surviving actual parameters
// and redirect the uses of the old object to the new one.
261 for (
auto objectOp : objectsToUpdate) {
263 SmallVector<Value> newParams;
264 for (
auto [idx, param] :
llvm::enumerate(objectOp.getActualParams()))
265 if (!matchesSet.contains(idx))
266 newParams.push_back(param);
269 OpBuilder builder(objectOp);
270 auto newObjectOp = ObjectOp::create(
271 builder, objectOp.getLoc(), objectOp.getResult().getType(),
272 objectOp.getClassNameAttr(), newParams);
273 objectOp.getResult().replaceAllUsesWith(newObjectOp.getResult());
// Human-readable name of this reduction pattern.
280 std::string
getName()
const override {
return "om-class-parameter-pruner"; }
// Reduction targeting ClassOps that nothing in the module instantiates.
286struct OMUnusedClassRemover :
public OpReduction<ClassOp> {
// Two module/class walks gate the match. The first is interrupted if any
// ObjectOp in the enclosing module instantiates this class. The second is
// interrupted if the class body itself contains any ObjectOp. An interrupted
// walk is checked right after each one (presumably rejecting the class).
289 uint64_t
match(ClassOp classOp)
override {
293 auto moduleOp = classOp->getParentOfType<mlir::ModuleOp>();
296 auto result = moduleOp.walk([&](ObjectOp objectOp) {
297 if (objectOp.getClassNameAttr() == classOp.getSymNameAttr())
298 return WalkResult::interrupt();
299 return WalkResult::advance();
302 if (result.wasInterrupted())
308 result = classOp.walk(
309 [&](ObjectOp objectOp) {
return WalkResult::interrupt(); });
311 if (result.wasInterrupted())
// Applies the reduction to a class accepted by match() (presumably erasing
// the class; confirm against the implementation).
320 LogicalResult
rewrite(ClassOp classOp)
override {
// Human-readable name of this reduction pattern.
325 std::string
getName()
const override {
return "om-unused-class-remover"; }
// Reduction that collapses an om::AnyCastOp of an om::UnknownValueOp into a
// single om::UnknownValueOp of the cast's result type.
332struct OMAnyCastOfUnknownSimplifier :
public OpReduction<om::AnyCastOp> {
// Looks for casts whose input is defined directly by an om::UnknownValueOp.
333 uint64_t
match(om::AnyCastOp anyCastOp)
override {
334 auto unknownOp = anyCastOp.getInput().getDefiningOp<om::UnknownValueOp>();
// Creates a fresh unknown value of the cast's result type immediately before
// the cast and redirects every use of the cast's result to it.
340 LogicalResult
rewrite(om::AnyCastOp anyCastOp)
override {
341 ImplicitLocOpBuilder builder(anyCastOp.getLoc(), anyCastOp);
342 anyCastOp.getResult().replaceAllUsesWith(
343 om::UnknownValueOp::create(builder, anyCastOp.getResult().getType()));
// Human-readable name of this reduction pattern.
348 std::string
getName()
const override {
349 return "om-anycast-of-unknown-simplifier";
// Register the OM reductions with the pattern set.
// NOTE(review): the integer template argument is presumably each pattern's
// benefit/priority as interpreted by ReducePatternSet::add (confirm). The
// values here descend from 50 (class parameter pruning) to 35 (any_cast-of-
// unknown simplification).
362 patterns.add<OMClassParameterPruner, 50>();
363 patterns.add<OMClassFieldPruner, 49>();
364 patterns.add<OMObjectToUnknownReplacer, 48>();
365 patterns.add<OMListElementPruner, 45>();
368 patterns.add<OMUnusedClassRemover, 40>();
369 patterns.add<OMAnyCastOfUnknownSimplifier, 35>();
// Registers a DialectRegistry extension for OMDialect; the extension callback
// (presumably) attaches the OM reduce-pattern dialect interface so that a
// reducer tool can obtain the OM reduction patterns. Confirm against the
// callback body.
// Fix: the parameter was mis-encoded as "®istry" (an HTML "&reg" entity decoded
// from "&registry"), which does not compile; restored to a reference parameter.
372void om::registerReducePatternDialectInterface(
373 mlir::DialectRegistry &registry) {
374 registry.addExtension(+[](MLIRContext *ctx, OMDialect *dialect) {
Default symbol cache implementation; stores associations between names (StringAttr's) to mlir::Operat...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override