CIRCT 23.0.0git
Loading...
Searching...
No Matches
HWReductions.cpp
Go to the documentation of this file.
1//===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
13#include "llvm/ADT/SmallSet.h"
14#include "llvm/Support/Debug.h"
15
16#define DEBUG_TYPE "hw-reductions"
17
18using namespace mlir;
19using namespace circt;
20using namespace hw;
21
22//===----------------------------------------------------------------------===//
23// Utilities
24//===----------------------------------------------------------------------===//
25
26/// Utility to track the transitive size of modules.
27struct ModuleSizeCache {
28 void clear() { moduleSizes.clear(); }
29
30 uint64_t getModuleSize(HWModuleLike module,
31 hw::InstanceGraph &instanceGraph) {
32 if (auto it = moduleSizes.find(module); it != moduleSizes.end())
33 return it->second;
34 uint64_t size = 1;
35 module->walk([&](Operation *op) {
36 size += 1;
37 if (auto instOp = dyn_cast<HWInstanceLike>(op)) {
38 for (auto moduleName : instOp.getReferencedModuleNamesAttr()) {
39 auto *node = instanceGraph.lookup(cast<StringAttr>(moduleName));
40 if (auto instModule =
41 dyn_cast_or_null<hw::HWModuleLike>(*node->getModule()))
42 size += getModuleSize(instModule, instanceGraph);
43 }
44 }
45 });
46 moduleSizes.insert({module, size});
47 return size;
48 }
49
50private:
51 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
52};
53
54//===----------------------------------------------------------------------===//
55// Reduction patterns
56//===----------------------------------------------------------------------===//
57
58/// A sample reduction pattern that maps `hw.module` to `hw.module.extern`.
59struct ModuleExternalizer : public OpReduction<HWModuleOp> {
60 void beforeReduction(mlir::ModuleOp op) override {
61 instanceGraph = std::make_unique<InstanceGraph>(op);
63 }
64
65 uint64_t match(HWModuleOp op) override {
67 }
68
69 LogicalResult rewrite(HWModuleOp op) override {
70 OpBuilder builder(op);
71 HWModuleExternOp::create(builder, op->getLoc(), op.getModuleNameAttr(),
72 op.getPortList(), StringRef(), op.getParameters());
73 op->erase();
74 return success();
75 }
76
77 std::string getName() const override { return "hw-module-externalizer"; }
78 bool acceptSizeIncrease() const override { return true; }
79
80 std::unique_ptr<InstanceGraph> instanceGraph;
82};
83
84/// A sample reduction pattern that replaces all uses of an operation with one
85/// of its operands. This can help pruning large parts of the expression tree
86/// rapidly.
87template <unsigned OpNum>
89 uint64_t match(Operation *op) override {
90 if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
91 OpNum >= op->getNumOperands())
92 return 0;
93 auto resultTy = dyn_cast<IntegerType>(op->getResult(0).getType());
94 auto opTy = dyn_cast<IntegerType>(op->getOperand(OpNum).getType());
95 return resultTy && opTy && resultTy == opTy &&
96 op->getResult(0) != op->getOperand(OpNum);
97 }
98 LogicalResult rewrite(Operation *op) override {
99 assert(match(op));
100 ImplicitLocOpBuilder builder(op->getLoc(), op);
101 auto result = op->getResult(0);
102 auto operand = op->getOperand(OpNum);
103 LLVM_DEBUG(llvm::dbgs()
104 << "Forwarding " << operand << " in " << *op << "\n");
105 result.replaceAllUsesWith(operand);
106 reduce::pruneUnusedOps(op, *this);
107 return success();
108 }
109 std::string getName() const override {
110 return ("hw-operand" + Twine(OpNum) + "-forwarder").str();
111 }
112};
113
114/// A sample reduction pattern that replaces integer operations with a constant
115/// zero of their type.
116struct HWConstantifier : public Reduction {
117 void matches(Operation *op,
118 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
119 if (op->hasTrait<OpTrait::ConstantLike>())
120 return;
121 for (auto result : op->getResults())
122 if (!result.use_empty())
123 if (isa<IntegerType>(result.getType()))
124 addMatch(1, result.getResultNumber());
125 }
126 LogicalResult rewriteMatches(Operation *op,
127 ArrayRef<uint64_t> indices) override {
128 OpBuilder builder(op);
129 for (auto idx : indices) {
130 auto result = op->getResult(idx);
131 auto type = cast<IntegerType>(result.getType());
132 auto newOp = hw::ConstantOp::create(builder, result.getLoc(), type, 0);
133 result.replaceAllUsesWith(newOp);
134 }
135 reduce::pruneUnusedOps(op, *this);
136 return success();
137 }
138 std::string getName() const override { return "hw-constantifier"; }
139 bool acceptSizeIncrease() const override { return true; }
140};
141
142/// Remove unused module input ports.
144 void beforeReduction(mlir::ModuleOp op) override {
145 symbolTables = std::make_unique<SymbolTableCollection>();
146 symbolUsers = std::make_unique<SymbolUserMap>(*symbolTables, op);
147 }
148
149 void matches(Operation *op,
150 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
151 auto mod = dyn_cast<HWModuleLike>(op);
152 if (!mod)
153 return;
154 auto modType = mod.getHWModuleType();
155 if (modType.getNumInputs() == 0)
156 return;
157 auto users = symbolUsers->getUsers(op);
158 if (!llvm::all_of(users, [](auto *user) { return isa<InstanceOp>(user); }))
159 return;
160 auto *block = mod.getBodyBlock();
161 for (unsigned idx = 0; idx < modType.getNumInputs(); ++idx)
162 if (!block || block->getArgument(idx).use_empty())
163 addMatch(1, idx);
164 }
165
166 LogicalResult rewriteMatches(Operation *op,
167 ArrayRef<uint64_t> matches) override {
168 auto mod = cast<HWMutableModuleLike>(op);
169
170 // Remove the ports from the module.
171 SmallVector<unsigned> indexList;
172 BitVector indexSet(mod.getNumInputPorts());
173 for (auto idx : matches) {
174 indexList.push_back(idx);
175 indexSet.set(idx);
176 }
177 llvm::sort(indexList);
178 mod.erasePorts(indexList, {});
179 if (auto *block = mod.getBodyBlock())
180 block->eraseArguments(indexSet);
181
182 // Remove the ports from the instances.
183 for (auto *user : symbolUsers->getUsers(op)) {
184 auto instOp = cast<InstanceOp>(user);
185 SmallVector<Value> newOperands;
186 SmallVector<Attribute> newArgNames;
187 for (auto [idx, data] : llvm::enumerate(
188 llvm::zip(instOp.getInputs(), instOp.getArgNames()))) {
189 if (indexSet.test(idx))
190 continue;
191 auto [operand, argName] = data;
192 newOperands.push_back(operand);
193 newArgNames.push_back(argName);
194 }
195 instOp.getInputsMutable().assign(newOperands);
196 instOp.setArgNamesAttr(ArrayAttr::get(op->getContext(), newArgNames));
197 }
198
199 return success();
200 }
201
202 std::string getName() const override { return "hw-module-input-pruner"; }
203 bool acceptSizeIncrease() const override { return true; }
204
205 std::unique_ptr<SymbolTableCollection> symbolTables;
206 std::unique_ptr<SymbolUserMap> symbolUsers;
207};
208
209/// Remove unused module output ports.
211 void beforeReduction(mlir::ModuleOp op) override {
212 symbolTables = std::make_unique<SymbolTableCollection>();
213 symbolUsers = std::make_unique<SymbolUserMap>(*symbolTables, op);
214 }
215
216 void matches(Operation *op,
217 llvm::function_ref<void(uint64_t, uint64_t)> addMatch) override {
218 auto mod = dyn_cast<HWModuleLike>(op);
219 if (!mod)
220 return;
221 auto modType = mod.getHWModuleType();
222 if (modType.getNumOutputs() == 0)
223 return;
224 auto users = symbolUsers->getUsers(op);
225 if (!llvm::all_of(users, [](auto *user) { return isa<InstanceOp>(user); }))
226 return;
227 for (unsigned idx = 0; idx < modType.getNumOutputs(); ++idx)
228 if (llvm::all_of(users, [&](auto *user) {
229 return user->getResult(idx).use_empty();
230 }))
231 addMatch(1, idx);
232 }
233
234 LogicalResult rewriteMatches(Operation *op,
235 ArrayRef<uint64_t> matches) override {
236 auto mod = cast<HWMutableModuleLike>(op);
237
238 // Remove the ports from the module.
239 SmallVector<unsigned> indexList;
240 BitVector indexSet(mod.getNumOutputPorts());
241 for (auto idx : matches) {
242 indexList.push_back(idx);
243 indexSet.set(idx);
244 }
245 llvm::sort(indexList);
246 mod.erasePorts({}, indexList);
247
248 // Update the `hw.output` op.
249 if (auto *block = mod.getBodyBlock()) {
250 auto outputOp = cast<OutputOp>(block->getTerminator());
251 SmallVector<Value> newOutputs;
252 for (auto [idx, output] : llvm::enumerate(outputOp.getOutputs()))
253 if (!indexSet.test(idx))
254 newOutputs.push_back(output);
255 outputOp.getOutputsMutable().assign(newOutputs);
256 }
257
258 // Remove the ports from the instances.
259 for (auto *user : symbolUsers->getUsers(op)) {
260 OpBuilder builder(user);
261 auto instOp = cast<InstanceOp>(user);
262 SmallVector<Value> oldResults;
263 SmallVector<Type> newResultTypes;
264 SmallVector<Attribute> newResultNames;
265 for (auto [idx, data] : llvm::enumerate(
266 llvm::zip(instOp.getResults(), instOp.getResultNames()))) {
267 if (indexSet.test(idx))
268 continue;
269 auto [result, resultName] = data;
270 oldResults.push_back(result);
271 newResultTypes.push_back(result.getType());
272 newResultNames.push_back(resultName);
273 }
274 auto newOp = InstanceOp::create(
275 builder, instOp.getLoc(), newResultTypes,
276 instOp.getInstanceNameAttr(), instOp.getModuleNameAttr(),
277 instOp.getInputs(), instOp.getArgNamesAttr(),
278 builder.getArrayAttr(newResultNames), instOp.getParametersAttr(),
279 instOp.getInnerSymAttr(), instOp.getDoNotPrintAttr());
280 for (auto [oldResult, newResult] :
281 llvm::zip(oldResults, newOp.getResults()))
282 oldResult.replaceAllUsesWith(newResult);
283 instOp.erase();
284 }
285
286 return success();
287 }
288
289 std::string getName() const override { return "hw-module-output-pruner"; }
290 bool acceptSizeIncrease() const override { return true; }
291
292 std::unique_ptr<SymbolTableCollection> symbolTables;
293 std::unique_ptr<SymbolUserMap> symbolUsers;
294};
295
296/// Pseudo-reduction that removes sv.namehint attributes from operations.
297/// This is not an actual reduction, but often removes extraneous information
298/// that has no bearing on the actual reduction.
300 uint64_t match(Operation *op) override { return op->hasAttr("sv.namehint"); }
301 LogicalResult rewrite(Operation *op) override {
302 op->removeAttr("sv.namehint");
303 return success();
304 }
305 std::string getName() const override { return "sv-namehint-remover"; }
306 bool acceptSizeIncrease() const override { return true; }
307 bool isOneShot() const override { return true; }
308};
309
310/// Pseudo-reduction that sanitizes the names of operations inside modules.
311/// This is not an actual reduction, but often removes extraneous information
312/// that has no bearing on the actual reduction. This makes the following
313/// changes:
314///
315/// - All wires are renamed to "wire"
316///
318 uint64_t match(Operation *op) override {
319 // Only match wire operations.
320 return isa<hw::WireOp>(op);
321 }
322 LogicalResult rewrite(Operation *op) override {
323 cast<hw::WireOp>(op).setName("wire");
324 return success();
325 }
326
327 std::string getName() const override {
328 return "hw-module-internal-name-sanitizer";
329 }
330
331 bool acceptSizeIncrease() const override { return true; }
332
333 bool isOneShot() const override { return true; }
334};
335
336/// Pseudo-reduction that sanitizes module and port names. This makes the
337/// following changes:
338///
339/// - All modules are given metasyntactic names ("Foo", "Bar", etc.)
340/// - All instances are renamed to match the new module name
341/// - All module ports are renamed to simple names ("a", "b", "c", etc.)
342///
343struct ModuleNameSanitizer : OpReduction<mlir::ModuleOp> {
344
345 size_t portNameIndex = 0;
346
347 char getPortName() {
348 if (portNameIndex >= 26)
349 portNameIndex = 0;
350 return 'a' + portNameIndex++;
351 }
352
353 LogicalResult rewrite(mlir::ModuleOp moduleOp) override {
354 // Create a new instance graph for this rewrite. We need to recreate it
355 // because renaming modules invalidates the nodeMap (which maps module names
356 // to nodes).
357 InstanceGraph instanceGraph(moduleOp);
358
360
361 // Iterate over the instance graph nodes
362 for (auto *node : instanceGraph) {
363 // Skip nodes without a module (e.g., the entry node)
364 if (!node->getModule())
365 continue;
366 auto hwModule = dyn_cast<HWModuleOp>(node->getModule().getOperation());
367 if (!hwModule)
368 continue;
369
370 auto *context = hwModule.getContext();
371 auto newName = StringAttr::get(context, nameGenerator.getNextName());
372
373 // Rename ports
374 portNameIndex = 0;
375 auto numPorts = hwModule.getNumPorts();
376 SmallVector<Attribute> newPortNames(numPorts);
377 for (unsigned i = 0; i != numPorts; ++i)
378 newPortNames[i] = StringAttr::get(context, Twine(getPortName()));
379 hwModule.setAllPortNames(newPortNames);
380
381 // Update all instances of this module
382 for (auto *use : node->uses()) {
383 auto useOp = use->getInstance();
384 if (!useOp)
385 continue;
386 auto instOp = dyn_cast<hw::InstanceOp>(*useOp);
387 if (!instOp)
388 continue;
389 instOp.setModuleName(newName);
390 instOp.setInstanceName(newName);
391 // Update argument names (inputs only)
392 auto *inputEnd = newPortNames.begin() + hwModule.getNumInputPorts();
393 SmallVector<Attribute> argNames(newPortNames.begin(), inputEnd);
394 instOp.setArgNamesAttr(ArrayAttr::get(context, argNames));
395 // Update result names (outputs)
396 SmallVector<Attribute> resultNames(inputEnd, newPortNames.end());
397 instOp.setResultNamesAttr(ArrayAttr::get(context, resultNames));
398 }
399
400 // Rename the module (do this last, after updating instances)
401 hwModule.setName(newName);
402 }
403
404 return success();
405 }
406
407 std::string getName() const override { return "hw-module-name-sanitizer"; }
408
409 bool acceptSizeIncrease() const override { return true; }
410
411 bool isOneShot() const override { return true; }
412};
413
414//===----------------------------------------------------------------------===//
415// Reduction Registration
416//===----------------------------------------------------------------------===//
417
420 // Gather a list of reduction patterns that we should try. Ideally these are
421 // assigned reasonable benefit indicators (higher benefit patterns are
422 // prioritized). For example, things that can knock out entire modules while
423 // being cheap should be tried first (and thus have higher benefit), before
424 // trying to tweak operands of individual arithmetic ops.
426 patterns.add<HWConstantifier, 5>();
431 patterns.add<ModuleInputPruner, 2>();
432 patterns.add<SVNamehintRemover, 1>();
435}
436
438 mlir::DialectRegistry &registry) {
439 registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
440 dialect->addInterfaces<HWReducePatternDialectInterface>();
441 });
442}
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
A utility class that generates metasyntactic variable names for use in reductions.
const char * getNextName()
Get the next metasyntactic name in the sequence.
create(data_type, value)
Definition hw.py:433
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the HW Reduction pattern dialect interface to the given registry.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition hw.py:1
A sample reduction pattern that replaces integer operations with a constant zero of their type.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > indices) override
Apply a set of matches of this reduction to a specific operation.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::string getName() const override
Return a human-readable name for this reduction pattern.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
A sample reduction pattern that replaces all uses of an operation with one of its operands.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
A sample reduction pattern that maps hw.module to hw.module.extern.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
std::unique_ptr< InstanceGraph > instanceGraph
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
ModuleSizeCache moduleSizes
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
Remove unused module input ports.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::string getName() const override
Return a human-readable name for this reduction pattern.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
std::unique_ptr< SymbolTableCollection > symbolTables
std::unique_ptr< SymbolUserMap > symbolUsers
Pseudo-reduction that sanitizes the names of operations inside modules.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes module and port names.
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(mlir::ModuleOp moduleOp) override
Remove unused module output ports.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
std::unique_ptr< SymbolTableCollection > symbolTables
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::unique_ptr< SymbolUserMap > symbolUsers
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
uint64_t getModuleSize(HWModuleLike module, hw::InstanceGraph &instanceGraph)
Pseudo-reduction that removes sv.namehint attributes from operations.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
A reduction pattern for a specific operation.
Definition Reduction.h:112
An abstract reduction pattern.
Definition Reduction.h:24
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override