CIRCT 22.0.0git
Loading...
Searching...
No Matches
Reduction.h
Go to the documentation of this file.
//===- Reduction.h - Reduction datastructure decl. for circt-reduce -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines datastructures to handle reduction patterns.
//
//===----------------------------------------------------------------------===//
12
13#ifndef CIRCT_REDUCE_REDUCTION_H
14#define CIRCT_REDUCE_REDUCTION_H
15
16#include "circt/Support/LLVM.h"
17#include "mlir/IR/BuiltinOps.h"
18#include "mlir/Pass/PassManager.h"
19#include "llvm/ADT/SmallVector.h"
20
21namespace circt {
22
23/// An abstract reduction pattern.
24struct Reduction {
25 virtual ~Reduction();
26
27 /// Called before the reduction is applied to a new subset of operations.
28 /// Reductions may use this callback to collect information such as symbol
29 /// tables about the module upfront.
30 virtual void beforeReduction(mlir::ModuleOp) {}
31
32 /// Called after the reduction has been applied to a subset of operations.
33 /// Reductions may use this callback to perform post-processing of the
34 /// reductions before the resulting module is tried for interestingness.
35 virtual void afterReduction(mlir::ModuleOp) {}
36
37 /// Check if the reduction can apply to a specific operation. Returns a
38 /// benefit measure where a higher number means that applying the pattern
39 /// leads to a bigger reduction and zero means that the patten does not
40 /// match and thus cannot be applied at all.
41 virtual uint64_t match(Operation *op) { return 0; }
42
43 /// Collect all ways how this reduction can apply to a specific operation. If
44 /// a reduction can apply to an operation in different ways, for example
45 /// deleting different operands, it should call `addMatch` multiple times with
46 /// the expected benefit of the match, as well as an integer identifying one
47 /// of the different ways it can match.
48 ///
49 /// Calls `match(op)` by default.
50 virtual void matches(Operation *op,
51 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) {
52 addMatch(match(op), 0);
53 }
54
55 /// Apply the reduction to a specific operation. If the returned result
56 /// indicates that the application failed, the resulting module is treated the
57 /// same as if the tester marked it as uninteresting.
58 virtual LogicalResult rewrite(Operation *op) { return failure(); }
59
60 /// Apply a set of matches of this reduction to a specific operation. If the
61 /// reduction registered multiple matches for an operation, a subset of the
62 /// integer identifiers of those matches will be passed to this function
63 /// again. If the returned result indicates that the application failed, the
64 /// resulting module is treated the same as if the tester marked it as
65 /// uninteresting.
66 virtual LogicalResult rewriteMatches(Operation *op,
67 ArrayRef<uint64_t> matches) {
68 assert(matches.size() == 1 && matches[0] == 0);
69 return rewrite(op);
70 }
71
72 /// Return a human-readable name for this reduction pattern.
73 virtual std::string getName() const = 0;
74
75 /// Return true if the tool should accept the transformation this reduction
76 /// performs on the module even if the overall size of the output increases.
77 /// This can be handy for patterns that reduce the complexity of the IR at the
78 /// cost of some verbosity.
79 virtual bool acceptSizeIncrease() const { return false; }
80
81 /// Return true if the tool should not try to reapply this reduction after it
82 /// has been successful. This is useful for reductions whose `match()`
83 /// function keeps returning true even after the reduction has reached a
84 /// fixed-point and no longer performs any change. An example of this are
85 /// reductions that apply a lowering pass which always applies but may leave
86 /// the input unmodified.
87 ///
88 /// This is mainly useful in conjunction with returning true from
89 /// `acceptSizeIncrease()`. For reductions that don't accept an increase, the
90 /// module size has to decrease for them to be considered useful, which
91 /// prevents the tool from getting stuck at a local point where the reduction
92 /// applies but produces no change in the input. However, reductions that *do*
93 /// accept a size increase can get stuck in this local fixed-point as they
94 /// keep applying to the same operations and the tool keeps accepting the
95 /// unmodified input as an improvement.
96 virtual bool isOneShot() const { return false; }
97
98 /// An optional callback for reductions to communicate removal of operations.
99 std::function<void(Operation *)> notifyOpErasedCallback = nullptr;
100
101 void notifyOpErased(Operation *op) {
104 }
105};
106
107/// A reduction pattern for a specific operation.
108///
109/// Only matches on operations of type `OpTy`, and calls corresponding match and
110/// rewrite functions with the operation cast to this type, for convenience.
111template <typename OpTy>
112struct OpReduction : public Reduction {
113 void matches(Operation *op,
114 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
115 if (auto concreteOp = dyn_cast<OpTy>(op))
116 matches(concreteOp, addMatch);
117 }
118 LogicalResult rewriteMatches(Operation *op,
119 ArrayRef<uint64_t> matches) override {
120 return rewriteMatches(cast<OpTy>(op), matches);
121 }
122
123 virtual uint64_t match(OpTy op) { return 1; }
124 virtual void matches(OpTy op,
125 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) {
126 addMatch(match(op), 0);
127 }
128 virtual LogicalResult rewrite(OpTy op) { return failure(); }
129 virtual LogicalResult rewriteMatches(OpTy op, ArrayRef<uint64_t> matches) {
130 assert(matches.size() == 1 && matches[0] == 0);
131 return rewrite(op);
132 }
133
134private:
135 /// Hide the base class match/rewrite functions to prevent compiler warnings
136 /// about the `OpTy`-specific ones hiding the base class functions.
139};
140
141/// A reduction pattern that applies an `mlir::Pass`.
142struct PassReduction : public Reduction {
143 PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
144 bool canIncreaseSize = false, bool oneShot = false);
145 uint64_t match(Operation *op) override;
146 LogicalResult rewrite(Operation *op) override;
147 std::string getName() const override;
148 bool acceptSizeIncrease() const override { return canIncreaseSize; }
149 bool isOneShot() const override { return oneShot; }
150
151protected:
152 MLIRContext *const context;
153 std::unique_ptr<mlir::PassManager> pm;
154 StringRef passName;
157};
158
160public:
161 template <typename R, unsigned Benefit, typename... Args>
162 void add(Args &&...args) {
164 {std::make_unique<R>(std::forward<Args>(args)...), Benefit});
165 }
166
167 void filter(const std::function<bool(const Reduction &)> &pred);
168 void sortByBenefit();
169 size_t size() const;
170
171 Reduction &operator[](size_t idx) const;
172
173private:
174 SmallVector<std::pair<std::unique_ptr<Reduction>, unsigned>>
176};
177
178/// A dialect interface to provide reduction patterns to a reducer tool.
180 : public mlir::DialectInterface::Base<ReducePatternDialectInterface> {
181 ReducePatternDialectInterface(Dialect *dialect) : Base(dialect) {}
182
184};
185
187 : public mlir::DialectInterfaceCollection<ReducePatternDialectInterface> {
188 using Base::Base;
189
190 // Collect the reduce patterns defined by each dialect.
192};
193
194} // namespace circt
195
196#endif // CIRCT_REDUCE_REDUCTION_H
assert(baseType &&"element must be base type")
SmallVector< std::pair< std::unique_ptr< Reduction >, unsigned > > reducePatternsWithBenefit
Definition Reduction.h:175
void add(Args &&...args)
Definition Reduction.h:162
void filter(const std::function< bool(const Reduction &)> &pred)
Definition Reduction.cpp:66
size_t size() const
Definition Reduction.cpp:82
Reduction & operator[](size_t idx) const
Definition Reduction.cpp:86
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A reduction pattern for a specific operation.
Definition Reduction.h:112
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:113
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:118
virtual LogicalResult rewrite(OpTy op)
Definition Reduction.h:128
virtual uint64_t match(OpTy op)
Definition Reduction.h:123
virtual void matches(OpTy op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Definition Reduction.h:124
virtual LogicalResult rewriteMatches(OpTy op, ArrayRef< uint64_t > matches)
Definition Reduction.h:129
A reduction pattern that applies an mlir::Pass.
Definition Reduction.h:142
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:149
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
Definition Reduction.cpp:58
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
Definition Reduction.cpp:54
MLIRContext *const context
Definition Reduction.h:152
std::string getName() const override
Return a human-readable name for this reduction pattern.
Definition Reduction.cpp:60
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even if the overall size of the output increases.
Definition Reduction.h:148
std::unique_ptr< mlir::PassManager > pm
Definition Reduction.h:153
A dialect interface to provide reduction patterns to a reducer tool.
Definition Reduction.h:180
ReducePatternDialectInterface(Dialect *dialect)
Definition Reduction.h:181
virtual void populateReducePatterns(ReducePatternSet &patterns) const =0
void populateReducePatterns(ReducePatternSet &patterns) const
Definition Reduction.cpp:94
An abstract reduction pattern.
Definition Reduction.h:24
virtual LogicalResult rewrite(Operation *op)
Apply the reduction to a specific operation.
Definition Reduction.h:58
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition Reduction.h:35
std::function< void(Operation *)> notifyOpErasedCallback
An optional callback for reductions to communicate removal of operations.
Definition Reduction.h:99
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even if the overall size of the output increases.
Definition Reduction.h:79
virtual LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches)
Apply a set of matches of this reduction to a specific operation.
Definition Reduction.h:66
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition Reduction.h:96
virtual uint64_t match(Operation *op)
Check if the reduction can apply to a specific operation.
Definition Reduction.h:41
virtual ~Reduction()
void notifyOpErased(Operation *op)
Definition Reduction.h:101
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch)
Collect all ways how this reduction can apply to a specific operation.
Definition Reduction.h:50
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition Reduction.h:30