Loading [MathJax]/jax/input/TeX/config.js
CIRCT 22.0.0git
All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
HWReductions.cpp
Go to the documentation of this file.
1//===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "llvm/ADT/SmallSet.h"
14#include "llvm/Support/Debug.h"
15
16#define DEBUG_TYPE "hw-reductions"
17
18using namespace mlir;
19using namespace circt;
20using namespace hw;
21
22//===----------------------------------------------------------------------===//
23// Utilities
24//===----------------------------------------------------------------------===//
25
26/// Utility to track the transitive size of modules.
27struct ModuleSizeCache {
28 void clear() { moduleSizes.clear(); }
29
30 uint64_t getModuleSize(HWModuleLike module,
31 hw::InstanceGraph &instanceGraph) {
32 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
33 return it->second;
34 uint64_t size = 1;
35 module->walk([&](Operation *op) {
36 size += 1;
37 if (auto instOp = dyn_cast<HWInstanceLike>(op)) {
38 for (auto moduleName : instOp.getReferencedModuleNamesAttr()) {
39 auto *node = instanceGraph.lookup(cast<StringAttr>(moduleName));
40 if (auto instModule =
41 dyn_cast_or_null<hw::HWModuleLike>(*node->getModule()))
42 size += getModuleSize(instModule, instanceGraph);
43 }
44 }
45 });
46 moduleSizes.insert({module, size});
47 return size;
48 }
49
50private:
51 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
52};
53
54//===----------------------------------------------------------------------===//
55// Reduction patterns
56//===----------------------------------------------------------------------===//
57
58/// A sample reduction pattern that maps `hw.module` to `hw.module.extern`.
59struct ModuleExternalizer : public OpReduction<HWModuleOp> {
60 void beforeReduction(mlir::ModuleOp op) override {
61 instanceGraph = std::make_unique<InstanceGraph>(op);
63 }
64
65 uint64_t match(HWModuleOp op) override {
67 }
68
69 LogicalResult rewrite(HWModuleOp op) override {
70 OpBuilder builder(op);
71 HWModuleExternOp::create(builder, op->getLoc(), op.getModuleNameAttr(),
72 op.getPortList(), StringRef(), op.getParameters());
73 op->erase();
74 return success();
75 }
76
77 std::string getName() const override { return "hw-module-externalizer"; }
78 bool acceptSizeIncrease() const override { return true; }
79
80 std::unique_ptr<InstanceGraph> instanceGraph;
82};
83
84/// A sample reduction pattern that replaces all uses of an operation with one
85/// of its operands. This can help pruning large parts of the expression tree
86/// rapidly.
87template <unsigned OpNum>
89 uint64_t match(Operation *op) override {
90 if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
91 OpNum >= op->getNumOperands())
92 return 0;
93 auto resultTy = dyn_cast<IntegerType>(op->getResult(0).getType());
94 auto opTy = dyn_cast<IntegerType>(op->getOperand(OpNum).getType());
95 return resultTy && opTy && resultTy == opTy &&
96 op->getResult(0) != op->getOperand(OpNum);
97 }
98 LogicalResult rewrite(Operation *op) override {
99 assert(match(op));
100 ImplicitLocOpBuilder builder(op->getLoc(), op);
101 auto result = op->getResult(0);
102 auto operand = op->getOperand(OpNum);
103 LLVM_DEBUG(llvm::dbgs()
104 << "Forwarding " << operand << " in " << *op << "\n");
105 result.replaceAllUsesWith(operand);
106 reduce::pruneUnusedOps(op, *this);
107 return success();
108 }
109 std::string getName() const override {
110 return ("hw-operand" + Twine(OpNum) + "-forwarder").str();
111 }
112};
113
114/// A sample reduction pattern that replaces integer operations with a constant
115/// zero of their type.
116struct HWConstantifier : public Reduction {
117 void matches(Operation *op,
118 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
119 if (op->hasTrait<OpTrait::ConstantLike>())
120 return;
121 for (auto result : op->getResults())
122 if (!result.use_empty())
123 if (isa<IntegerType>(result.getType()))
124 addMatch(1, result.getResultNumber());
125 }
126 LogicalResult rewriteMatches(Operation *op,
127 ArrayRef<uint64_t> indices) override {
128 OpBuilder builder(op);
129 for (auto idx : indices) {
130 auto result = op->getResult(idx);
131 auto type = cast<IntegerType>(result.getType());
132 auto newOp = hw::ConstantOp::create(builder, result.getLoc(), type, 0);
133 result.replaceAllUsesWith(newOp);
134 }
135 reduce::pruneUnusedOps(op, *this);
136 return success();
137 }
138 std::string getName() const override { return "hw-constantifier"; }
139 bool acceptSizeIncrease() const override { return true; }
140};
141
142/// Remove unused module input ports.
144 void beforeReduction(mlir::ModuleOp op) override {
145 symbolTables = std::make_unique<SymbolTableCollection>();
146 symbolUsers = std::make_unique<SymbolUserMap>(*symbolTables, op);
147 }
148
149 void matches(Operation *op,
150 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
151 auto mod = dyn_cast<HWModuleLike>(op);
152 if (!mod)
153 return;
154 auto modType = mod.getHWModuleType();
155 if (modType.getNumInputs() == 0)
156 return;
157 auto users = symbolUsers->getUsers(op);
158 if (!llvm::all_of(users, [](auto *user) { return isa<InstanceOp>(user); }))
159 return;
160 auto *block = mod.getBodyBlock();
161 for (unsigned idx = 0; idx < modType.getNumInputs(); ++idx)
162 if (!block || block->getArgument(idx).use_empty())
163 addMatch(1, idx);
164 }
165
166 LogicalResult rewriteMatches(Operation *op,
167 ArrayRef<uint64_t> matches) override {
168 auto mod = cast<HWMutableModuleLike>(op);
169
170 // Remove the ports from the module.
171 SmallVector<unsigned> indexList;
172 BitVector indexSet(mod.getNumInputPorts());
173 for (auto idx : matches) {
174 indexList.push_back(idx);
175 indexSet.set(idx);
176 }
177 llvm::sort(indexList);
178 mod.erasePorts(indexList, {});
179 if (auto *block = mod.getBodyBlock())
180 block->eraseArguments(indexSet);
181
182 // Remove the ports from the instances.
183 for (auto *user : symbolUsers->getUsers(op)) {
184 auto instOp = cast<InstanceOp>(user);
185 SmallVector<Value> newOperands;
186 SmallVector<Attribute> newArgNames;
187 for (auto [idx, data] : llvm::enumerate(
188 llvm::zip(instOp.getInputs(), instOp.getArgNames()))) {
189 if (indexSet.test(idx))
190 continue;
191 auto [operand, argName] = data;
192 newOperands.push_back(operand);
193 newArgNames.push_back(argName);
194 }
195 instOp.getInputsMutable().assign(newOperands);
196 instOp.setArgNamesAttr(ArrayAttr::get(op->getContext(), newArgNames));
197 }
198
199 return success();
200 }
201
202 std::string getName() const override { return "hw-module-input-pruner"; }
203 bool acceptSizeIncrease() const override { return true; }
204
205 std::unique_ptr<SymbolTableCollection> symbolTables;
206 std::unique_ptr<SymbolUserMap> symbolUsers;
207};
208
209/// Remove unused module output ports.
211 void beforeReduction(mlir::ModuleOp op) override {
212 symbolTables = std::make_unique<SymbolTableCollection>();
213 symbolUsers = std::make_unique<SymbolUserMap>(*symbolTables, op);
214 }
215
216 void matches(Operation *op,
217 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
218 auto mod = dyn_cast<HWModuleLike>(op);
219 if (!mod)
220 return;
221 auto modType = mod.getHWModuleType();
222 if (modType.getNumOutputs() == 0)
223 return;
224 auto users = symbolUsers->getUsers(op);
225 if (!llvm::all_of(users, [](auto *user) { return isa<InstanceOp>(user); }))
226 return;
227 for (unsigned idx = 0; idx < modType.getNumOutputs(); ++idx)
228 if (llvm::all_of(users, [&](auto *user) {
229 return user->getResult(idx).use_empty();
230 }))
231 addMatch(1, idx);
232 }
233
234 LogicalResult rewriteMatches(Operation *op,
235 ArrayRef<uint64_t> matches) override {
236 auto mod = cast<HWMutableModuleLike>(op);
237
238 // Remove the ports from the module.
239 SmallVector<unsigned> indexList;
240 BitVector indexSet(mod.getNumOutputPorts());
241 for (auto idx : matches) {
242 indexList.push_back(idx);
243 indexSet.set(idx);
244 }
245 llvm::sort(indexList);
246 mod.erasePorts({}, indexList);
247
248 // Update the `hw.output` op.
249 if (auto *block = mod.getBodyBlock()) {
250 auto outputOp = cast<OutputOp>(block->getTerminator());
251 SmallVector<Value> newOutputs;
252 for (auto [idx, output] : llvm::enumerate(outputOp.getOutputs()))
253 if (!indexSet.test(idx))
254 newOutputs.push_back(output);
255 outputOp.getOutputsMutable().assign(newOutputs);
256 }
257
258 // Remove the ports from the instances.
259 for (auto *user : symbolUsers->getUsers(op)) {
260 OpBuilder builder(user);
261 auto instOp = cast<InstanceOp>(user);
262 SmallVector<Value> oldResults;
263 SmallVector<Type> newResultTypes;
264 SmallVector<Attribute> newResultNames;
265 for (auto [idx, data] : llvm::enumerate(
266 llvm::zip(instOp.getResults(), instOp.getResultNames()))) {
267 if (indexSet.test(idx))
268 continue;
269 auto [result, resultName] = data;
270 oldResults.push_back(result);
271 newResultTypes.push_back(result.getType());
272 newResultNames.push_back(resultName);
273 }
274 auto newOp = InstanceOp::create(
275 builder, instOp.getLoc(), newResultTypes,
276 instOp.getInstanceNameAttr(), instOp.getModuleNameAttr(),
277 instOp.getInputs(), instOp.getArgNamesAttr(),
278 builder.getArrayAttr(newResultNames), instOp.getParametersAttr(),
279 instOp.getInnerSymAttr(), instOp.getDoNotPrintAttr());
280 for (auto [oldResult, newResult] :
281 llvm::zip(oldResults, newOp.getResults()))
282 oldResult.replaceAllUsesWith(newResult);
283 instOp.erase();
284 }
285
286 return success();
287 }
288
289 std::string getName() const override { return "hw-module-output-pruner"; }
290 bool acceptSizeIncrease() const override { return true; }
291
292 std::unique_ptr<SymbolTableCollection> symbolTables;
293 std::unique_ptr<SymbolUserMap> symbolUsers;
294};
295
296//===----------------------------------------------------------------------===//
297// Reduction Registration
298//===----------------------------------------------------------------------===//
299
302 // Gather a list of reduction patterns that we should try. Ideally these are
303 // assigned reasonable benefit indicators (higher benefit patterns are
304 // prioritized). For example, things that can knock out entire modules while
305 // being cheap should be tried first (and thus have higher benefit), before
306 // trying to tweak operands of individual arithmetic ops.
308 patterns.add<HWConstantifier, 5>();
313 patterns.add<ModuleInputPruner, 2>();
314}
315
317 mlir::DialectRegistry &registry) {
318 registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
319 dialect->addInterfaces<HWReducePatternDialectInterface>();
320 });
321}
assert(baseType &&"element must be base type")
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
create(data_type, value)
Definition hw.py:433
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the HW Reduction pattern dialect interface to the given registry.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more uses.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
A sample reduction pattern that replaces integer operations with a constant zero of their type.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > indices) override
Apply a set of matches of this reduction to a specific operation.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::string getName() const override
Return a human-readable name for this reduction pattern.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
A sample reduction pattern that replaces all uses of an operation with one of its operands.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
A sample reduction pattern that maps hw.module to hw.module.extern.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
std::unique_ptr< InstanceGraph > instanceGraph
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
ModuleSizeCache moduleSizes
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
Remove unused module input ports.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::string getName() const override
Return a human-readable name for this reduction pattern.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
std::unique_ptr< SymbolTableCollection > symbolTables
std::unique_ptr< SymbolUserMap > symbolUsers
Remove unused module output ports.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
std::unique_ptr< SymbolTableCollection > symbolTables
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::unique_ptr< SymbolUserMap > symbolUsers
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
uint64_t getModuleSize(HWModuleLike module, hw::InstanceGraph &instanceGraph)
An abstract reduction pattern.
Definition Reduction.h:24
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override