CIRCT  20.0.0git
Reduction.h
Go to the documentation of this file.
1 //===- Reduction.h - Reduction datastructure decl. for circt-reduce -------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file defines datastructures to handle reduction patterns.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #ifndef CIRCT_REDUCE_REDUCTION_H
14 #define CIRCT_REDUCE_REDUCTION_H
15 
16 #include "circt/Support/LLVM.h"
17 #include "mlir/IR/BuiltinOps.h"
18 #include "mlir/Pass/PassManager.h"
19 #include "llvm/ADT/SmallVector.h"
20 
21 namespace circt {
22 
23 /// An abstract reduction pattern.
24 struct Reduction {
25  virtual ~Reduction();
26 
27  /// Called before the reduction is applied to a new subset of operations.
28  /// Reductions may use this callback to collect information such as symbol
29  /// tables about the module upfront.
30  virtual void beforeReduction(mlir::ModuleOp) {}
31 
32  /// Called after the reduction has been applied to a subset of operations.
33  /// Reductions may use this callback to perform post-processing of the
34  /// reductions before the resulting module is tried for interestingness.
35  virtual void afterReduction(mlir::ModuleOp) {}
36 
37  /// Check if the reduction can apply to a specific operation. Returns a
38  /// benefit measure where a higher number means that applying the pattern
39  /// leads to a bigger reduction and zero means that the patten does not
40  /// match and thus cannot be applied at all.
41  virtual uint64_t match(Operation *op) = 0;
42 
43  /// Apply the reduction to a specific operation. If the returned result
44  /// indicates that the application failed, the resulting module is treated the
45  /// same as if the tester marked it as uninteresting.
46  virtual LogicalResult rewrite(Operation *op) = 0;
47 
48  /// Return a human-readable name for this reduction pattern.
49  virtual std::string getName() const = 0;
50 
51  /// Return true if the tool should accept the transformation this reduction
52  /// performs on the module even if the overall size of the output increases.
53  /// This can be handy for patterns that reduce the complexity of the IR at the
54  /// cost of some verbosity.
55  virtual bool acceptSizeIncrease() const { return false; }
56 
57  /// Return true if the tool should not try to reapply this reduction after it
58  /// has been successful. This is useful for reductions whose `match()`
59  /// function keeps returning true even after the reduction has reached a
60  /// fixed-point and no longer performs any change. An example of this are
61  /// reductions that apply a lowering pass which always applies but may leave
62  /// the input unmodified.
63  ///
64  /// This is mainly useful in conjunction with returning true from
65  /// `acceptSizeIncrease()`. For reductions that don't accept an increase, the
66  /// module size has to decrease for them to be considered useful, which
67  /// prevents the tool from getting stuck at a local point where the reduction
68  /// applies but produces no change in the input. However, reductions that *do*
69  /// accept a size increase can get stuck in this local fixed-point as they
70  /// keep applying to the same operations and the tool keeps accepting the
71  /// unmodified input as an improvement.
72  virtual bool isOneShot() const { return false; }
73 
74  /// An optional callback for reductions to communicate removal of operations.
75  std::function<void(Operation *)> notifyOpErasedCallback = nullptr;
76 
77  void notifyOpErased(Operation *op) {
80  }
81 };
82 
83 template <typename OpTy>
84 struct OpReduction : public Reduction {
85  uint64_t match(Operation *op) override {
86  if (auto concreteOp = dyn_cast<OpTy>(op))
87  return match(concreteOp);
88  return 0;
89  }
90  LogicalResult rewrite(Operation *op) override {
91  return rewrite(cast<OpTy>(op));
92  }
93 
94  virtual uint64_t match(OpTy op) { return 1; }
95  virtual LogicalResult rewrite(OpTy op) = 0;
96 };
97 
98 /// A reduction pattern that applies an `mlir::Pass`.
99 struct PassReduction : public Reduction {
100  PassReduction(MLIRContext *context, std::unique_ptr<Pass> pass,
101  bool canIncreaseSize = false, bool oneShot = false);
102  uint64_t match(Operation *op) override;
103  LogicalResult rewrite(Operation *op) override;
104  std::string getName() const override;
105  bool acceptSizeIncrease() const override { return canIncreaseSize; }
106  bool isOneShot() const override { return oneShot; }
107 
108 protected:
109  MLIRContext *const context;
110  std::unique_ptr<mlir::PassManager> pm;
111  StringRef passName;
113  bool oneShot;
114 };
115 
117 public:
118  template <typename R, unsigned Benefit, typename... Args>
119  void add(Args &&...args) {
120  reducePatternsWithBenefit.push_back(
121  {std::make_unique<R>(std::forward<Args>(args)...), Benefit});
122  }
123 
124  void filter(const std::function<bool(const Reduction &)> &pred);
125  void sortByBenefit();
126  size_t size() const;
127 
128  Reduction &operator[](size_t idx) const;
129 
130 private:
131  SmallVector<std::pair<std::unique_ptr<Reduction>, unsigned>>
133 };
134 
135 /// A dialect interface to provide reduction patterns to a reducer tool.
137  : public mlir::DialectInterface::Base<ReducePatternDialectInterface> {
138  ReducePatternDialectInterface(Dialect *dialect) : Base(dialect) {}
139 
141 };
142 
144  : public mlir::DialectInterfaceCollection<ReducePatternDialectInterface> {
145  using Base::Base;
146 
147  // Collect the reduce patterns defined by each dialect.
149 };
150 
151 } // namespace circt
152 
153 #endif // CIRCT_REDUCE_REDUCTION_H
SmallVector< std::pair< std::unique_ptr< Reduction >, unsigned > > reducePatternsWithBenefit
Definition: Reduction.h:132
void add(Args &&...args)
Definition: Reduction.h:119
void filter(const std::function< bool(const Reduction &)> &pred)
Definition: Reduction.cpp:66
size_t size() const
Definition: Reduction.cpp:82
Reduction & operator[](size_t idx) const
Definition: Reduction.cpp:86
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
Definition: Reduction.h:90
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
Definition: Reduction.h:85
virtual LogicalResult rewrite(OpTy op)=0
virtual uint64_t match(OpTy op)
Definition: Reduction.h:94
A reduction pattern that applies an mlir::Pass.
Definition: Reduction.h:99
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition: Reduction.h:106
StringRef passName
Definition: Reduction.h:111
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
Definition: Reduction.cpp:58
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
Definition: Reduction.cpp:54
MLIRContext *const context
Definition: Reduction.h:109
std::string getName() const override
Return a human-readable name for this reduction pattern.
Definition: Reduction.cpp:60
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even if the overall size of the output increases.
Definition: Reduction.h:105
std::unique_ptr< mlir::PassManager > pm
Definition: Reduction.h:110
PassReduction(MLIRContext *context, std::unique_ptr< Pass > pass, bool canIncreaseSize=false, bool oneShot=false)
Definition: Reduction.cpp:33
A dialect interface to provide reduction patterns to a reducer tool.
Definition: Reduction.h:137
ReducePatternDialectInterface(Dialect *dialect)
Definition: Reduction.h:138
virtual void populateReducePatterns(ReducePatternSet &patterns) const =0
void populateReducePatterns(ReducePatternSet &patterns) const
Definition: Reduction.cpp:94
An abstract reduction pattern.
Definition: Reduction.h:24
virtual uint64_t match(Operation *op)=0
Check if the reduction can apply to a specific operation.
virtual void afterReduction(mlir::ModuleOp)
Called after the reduction has been applied to a subset of operations.
Definition: Reduction.h:35
std::function< void(Operation *)> notifyOpErasedCallback
An optional callback for reductions to communicate removal of operations.
Definition: Reduction.h:75
virtual bool acceptSizeIncrease() const
Return true if the tool should accept the transformation this reduction performs on the module even if the overall size of the output increases.
Definition: Reduction.h:55
virtual LogicalResult rewrite(Operation *op)=0
Apply the reduction to a specific operation.
virtual bool isOneShot() const
Return true if the tool should not try to reapply this reduction after it has been successful.
Definition: Reduction.h:72
virtual ~Reduction()
void notifyOpErased(Operation *op)
Definition: Reduction.h:77
virtual std::string getName() const =0
Return a human-readable name for this reduction pattern.
virtual void beforeReduction(mlir::ModuleOp)
Called before the reduction is applied to a new subset of operations.
Definition: Reduction.h:30