13#ifndef CIRCT_REDUCE_REDUCTION_H
14#define CIRCT_REDUCE_REDUCTION_H
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/Pass/PassManager.h"
19#include "llvm/ADT/SmallVector.h"
41 virtual uint64_t
match(Operation *op) {
return 0; }
51 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch) {
52 addMatch(
match(op), 0);
58 virtual LogicalResult
rewrite(Operation *op) {
return failure(); }
111template <
typename OpTy>
114 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
115 if (
auto concreteOp = dyn_cast<OpTy>(op))
119 ArrayRef<uint64_t>
matches)
override {
123 virtual uint64_t
match(OpTy op) {
return 1; }
125 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch) {
126 addMatch(
match(op), 0);
128 virtual LogicalResult
rewrite(OpTy op) {
return failure(); }
145 uint64_t
match(Operation *op)
override;
146 LogicalResult
rewrite(Operation *op)
override;
147 std::string
getName()
const override;
153 std::unique_ptr<mlir::PassManager>
pm;
161 template <
typename R,
unsigned Benefit,
typename... Args>
162 void add(Args &&...args) {
164 {std::make_unique<R>(std::forward<Args>(args)...), Benefit});
174 SmallVector<std::pair<std::unique_ptr<Reduction>,
unsigned>>
187 :
public mlir::DialectInterfaceCollection<ReducePatternDialectInterface> {
assert(baseType &&"element must be base type")
SmallVector< std::pair< std::unique_ptr< Reduction >, unsigned > > reducePatternsWithBenefit
void filter(const std::function< bool(const Reduction &)> &pred)
Reduction & operator[](size_t idx) const
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern for a specific operation.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
virtual LogicalResult rewrite(OpTy op)
virtual uint64_t match(OpTy op)
virtual void matches(OpTy op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
virtual LogicalResult rewriteMatches(OpTy op, ArrayRef< uint64_t > matches)
A reduction pattern that applies an mlir::Pass.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
MLIRContext *const context
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::unique_ptr< mlir::PassManager > pm
A dialect interface to provide reduction patterns to a reducer tool.
ReducePatternDialectInterface(Dialect *dialect)
virtual void populateReducePatterns(ReducePatternSet &patterns) const =0
void populateReducePatterns(ReducePatternSet &patterns) const
An abstract reduction pattern.
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
std::function< void(Operation *)> notifyOpErasedCallback
An optional callback for reductions to communicate removal of operations.
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even i...
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
void notifyOpErased(Operation *op)
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.