CIRCT 20.0.0git
Loading...
Searching...
No Matches
HWReductions.cpp
Go to the documentation of this file.
1//===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "llvm/ADT/SmallSet.h"
14#include "llvm/Support/Debug.h"
15
16#define DEBUG_TYPE "hw-reductions"
17
18using namespace mlir;
19using namespace circt;
20using namespace hw;
21
22//===----------------------------------------------------------------------===//
23// Utilities
24//===----------------------------------------------------------------------===//
25
26/// Utility to track the transitive size of modules.
27struct ModuleSizeCache {
28 void clear() { moduleSizes.clear(); }
29
30 uint64_t getModuleSize(HWModuleLike module,
31 hw::InstanceGraph &instanceGraph) {
32 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
33 return it->second;
34 uint64_t size = 1;
35 module->walk([&](Operation *op) {
36 size += 1;
37 if (auto instOp = dyn_cast<HWInstanceLike>(op)) {
38 for (auto moduleName : instOp.getReferencedModuleNamesAttr()) {
39 auto *node = instanceGraph.lookup(cast<StringAttr>(moduleName));
40 if (auto instModule =
41 dyn_cast_or_null<hw::HWModuleLike>(*node->getModule()))
42 size += getModuleSize(instModule, instanceGraph);
43 }
44 }
45 });
46 moduleSizes.insert({module, size});
47 return size;
48 }
49
50private:
51 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
52};
53
54//===----------------------------------------------------------------------===//
55// Reduction patterns
56//===----------------------------------------------------------------------===//
57
58/// A sample reduction pattern that maps `hw.module` to `hw.module.extern`.
59struct ModuleExternalizer : public OpReduction<HWModuleOp> {
60 void beforeReduction(mlir::ModuleOp op) override {
61 instanceGraph = std::make_unique<InstanceGraph>(op);
63 }
64
65 uint64_t match(HWModuleOp op) override {
67 }
68
69 LogicalResult rewrite(HWModuleOp op) override {
70 OpBuilder builder(op);
71 builder.create<HWModuleExternOp>(op->getLoc(), op.getModuleNameAttr(),
72 op.getPortList(), StringRef(),
73 op.getParameters());
74 op->erase();
75 return success();
76 }
77
78 std::string getName() const override { return "hw-module-externalizer"; }
79
80 std::unique_ptr<InstanceGraph> instanceGraph;
82};
83
/// A sample reduction pattern that replaces all uses of an operation's result
/// with one of its operands. This can help prune large parts of the expression
/// tree rapidly.
87template <unsigned OpNum>
89 uint64_t match(Operation *op) override {
90 if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
91 OpNum >= op->getNumOperands())
92 return 0;
93 auto resultTy = dyn_cast<IntegerType>(op->getResult(0).getType());
94 auto opTy = dyn_cast<IntegerType>(op->getOperand(OpNum).getType());
95 return resultTy && opTy && resultTy == opTy &&
96 op->getResult(0) != op->getOperand(OpNum);
97 }
98 LogicalResult rewrite(Operation *op) override {
99 assert(match(op));
100 ImplicitLocOpBuilder builder(op->getLoc(), op);
101 auto result = op->getResult(0);
102 auto operand = op->getOperand(OpNum);
103 LLVM_DEBUG(llvm::dbgs()
104 << "Forwarding " << operand << " in " << *op << "\n");
105 result.replaceAllUsesWith(operand);
106 reduce::pruneUnusedOps(op, *this);
107 return success();
108 }
109 std::string getName() const override {
110 return ("hw-operand" + Twine(OpNum) + "-forwarder").str();
111 }
112};
113
114/// A sample reduction pattern that replaces integer operations with a constant
115/// zero of their type.
116struct HWConstantifier : public Reduction {
117 uint64_t match(Operation *op) override {
118 if (op->getNumResults() == 0 || op->getNumOperands() == 0)
119 return 0;
120 return llvm::all_of(op->getResults(), [](Value result) {
121 return isa<IntegerType>(result.getType());
122 });
123 }
124 LogicalResult rewrite(Operation *op) override {
125 assert(match(op));
126 OpBuilder builder(op);
127 for (auto result : op->getResults()) {
128 auto type = cast<IntegerType>(result.getType());
129 auto newOp = builder.create<hw::ConstantOp>(op->getLoc(), type, 0);
130 result.replaceAllUsesWith(newOp);
131 }
132 reduce::pruneUnusedOps(op, *this);
133 return success();
134 }
135 std::string getName() const override { return "hw-constantifier"; }
136};
137
138/// Remove the first or last output of the top-level module depending on the
139/// 'Front' template parameter.
140template <bool Front>
141struct ModuleOutputPruner : public OpReduction<HWModuleOp> {
142 void beforeReduction(mlir::ModuleOp op) override {
143 useEmpty.clear();
144
145 SymbolTableCollection table;
146 SymbolUserMap users(table, op);
147 for (auto module : op.getOps<HWModuleOp>())
148 if (users.useEmpty(module))
149 useEmpty.insert(module);
150 }
151
152 uint64_t match(HWModuleOp op) override {
153 return op.getNumOutputPorts() != 0 && useEmpty.contains(op);
154 }
155
156 LogicalResult rewrite(HWModuleOp op) override {
157 Operation *terminator = op.getBody().front().getTerminator();
158 auto operands = terminator->getOperands();
159 ValueRange newOutputs = operands.drop_back();
160 unsigned portToErase = op.getNumOutputPorts() - 1;
161 if (Front) {
162 newOutputs = operands.drop_front();
163 portToErase = 0;
164 }
165
166 terminator->setOperands(newOutputs);
167 op.erasePorts({}, {portToErase});
168
169 return success();
170 }
171
172 std::string getName() const override {
173 return Front ? "hw-module-output-pruner-front"
174 : "hw-module-output-pruner-back";
175 }
176
177 DenseSet<HWModuleOp> useEmpty;
178};
179
180/// Remove all input ports of the top-level module that have no users
181struct ModuleInputPruner : public OpReduction<HWModuleOp> {
182 void beforeReduction(mlir::ModuleOp op) override {
183 useEmpty.clear();
184
185 SymbolTableCollection table;
186 SymbolUserMap users(table, op);
187 for (auto module : op.getOps<HWModuleOp>())
188 if (users.useEmpty(module))
189 useEmpty.insert(module);
190 }
191
192 uint64_t match(HWModuleOp op) override { return useEmpty.contains(op); }
193
194 LogicalResult rewrite(HWModuleOp op) override {
195 SmallVector<unsigned> inputsToErase;
196 BitVector toErase(op.getNumPorts());
197 for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
198 if (arg.use_empty()) {
199 toErase.set(i);
200 inputsToErase.push_back(i);
201 }
202 }
203
204 op.erasePorts(inputsToErase, {});
205 op.getBodyBlock()->eraseArguments(toErase);
206
207 return success();
208 }
209
210 std::string getName() const override { return "hw-module-input-pruner"; }
211
212 DenseSet<HWModuleOp> useEmpty;
213};
214
215//===----------------------------------------------------------------------===//
216// Reduction Registration
217//===----------------------------------------------------------------------===//
218
221 // Gather a list of reduction patterns that we should try. Ideally these are
222 // assigned reasonable benefit indicators (higher benefit patterns are
223 // prioritized). For example, things that can knock out entire modules while
224 // being cheap should be tried first (and thus have higher benefit), before
225 // trying to tweak operands of individual arithmetic ops.
 // NOTE(review): the integer template argument of `patterns.add` appears to be
 // the benefit described above (higher is tried first); confirm against
 // ReducePatternSet.
227 patterns.add<HWConstantifier, 5>();
233 patterns.add<ModuleInputPruner, 2>();
234}
235
237 mlir::DialectRegistry &registry) {
 // Attach the HW reduction-pattern interface through a registry extension:
 // the callback receives the HW dialect instance and adds the interface to it.
238 registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
239 dialect->addInterfaces<HWReducePatternDialectInterface>();
240 });
241}
assert(baseType &&"element must be base type")
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the HW Reduction pattern dialect interface to the given registry.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
A sample reduction pattern that replaces integer operations with a constant zero of their type.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
A sample reduction pattern that replaces all uses of an operation with one of its operands.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
A sample reduction pattern that maps hw.module to hw.module.extern.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
std::unique_ptr< InstanceGraph > instanceGraph
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
ModuleSizeCache moduleSizes
Remove all input ports of the top-level module that have no users.
uint64_t match(HWModuleOp op) override
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
DenseSet< HWModuleOp > useEmpty
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
Remove the first or last output of the top-level module depending on the 'Front' template parameter.
LogicalResult rewrite(HWModuleOp op) override
DenseSet< HWModuleOp > useEmpty
std::string getName() const override
Return a human-readable name for this reduction pattern.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
uint64_t getModuleSize(HWModuleLike module, hw::InstanceGraph &instanceGraph)
An abstract reduction pattern.
Definition Reduction.h:24
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override