CIRCT  20.0.0git
HWReductions.cpp
Go to the documentation of this file.
1 //===- HWReductions.cpp - Reduction patterns for the HW dialect -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
11 #include "circt/Dialect/HW/HWOps.h"
13 #include "llvm/ADT/SmallSet.h"
14 #include "llvm/Support/Debug.h"
15 
16 #define DEBUG_TYPE "hw-reductions"
17 
18 using namespace mlir;
19 using namespace circt;
20 using namespace hw;
21 
22 //===----------------------------------------------------------------------===//
23 // Utilities
24 //===----------------------------------------------------------------------===//
25 
26 /// Utility to track the transitive size of modules.
27 struct ModuleSizeCache {
28  void clear() { moduleSizes.clear(); }
29 
30  uint64_t getModuleSize(HWModuleLike module,
31  hw::InstanceGraph &instanceGraph) {
32  if (auto it = moduleSizes.find(module); it != moduleSizes.end())
33  return it->second;
34  uint64_t size = 1;
35  module->walk([&](Operation *op) {
36  size += 1;
37  if (auto instOp = dyn_cast<HWInstanceLike>(op)) {
38  for (auto moduleName : instOp.getReferencedModuleNamesAttr()) {
39  auto *node = instanceGraph.lookup(cast<StringAttr>(moduleName));
40  if (auto instModule =
41  dyn_cast_or_null<hw::HWModuleLike>(*node->getModule()))
42  size += getModuleSize(instModule, instanceGraph);
43  }
44  }
45  });
46  moduleSizes.insert({module, size});
47  return size;
48  }
49 
50 private:
51  llvm::DenseMap<Operation *, uint64_t> moduleSizes;
52 };
53 
54 //===----------------------------------------------------------------------===//
55 // Reduction patterns
56 //===----------------------------------------------------------------------===//
57 
58 /// A sample reduction pattern that maps `hw.module` to `hw.module.extern`.
59 struct ModuleExternalizer : public OpReduction<HWModuleOp> {
60  void beforeReduction(mlir::ModuleOp op) override {
61  instanceGraph = std::make_unique<InstanceGraph>(op);
62  moduleSizes.clear();
63  }
64 
65  uint64_t match(HWModuleOp op) override {
66  return moduleSizes.getModuleSize(op, *instanceGraph);
67  }
68 
69  LogicalResult rewrite(HWModuleOp op) override {
70  OpBuilder builder(op);
71  builder.create<HWModuleExternOp>(op->getLoc(), op.getModuleNameAttr(),
72  op.getPortList(), StringRef(),
73  op.getParameters());
74  op->erase();
75  return success();
76  }
77 
78  std::string getName() const override { return "hw-module-externalizer"; }
79 
80  std::unique_ptr<InstanceGraph> instanceGraph;
82 };
83 
84 /// A sample reduction pattern that replaces all uses of an operation with one
85 /// of its operands. This can help pruning large parts of the expression tree
86 /// rapidly.
87 template <unsigned OpNum>
88 struct HWOperandForwarder : public Reduction {
89  uint64_t match(Operation *op) override {
90  if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
91  OpNum >= op->getNumOperands())
92  return 0;
93  auto resultTy = dyn_cast<IntegerType>(op->getResult(0).getType());
94  auto opTy = dyn_cast<IntegerType>(op->getOperand(OpNum).getType());
95  return resultTy && opTy && resultTy == opTy &&
96  op->getResult(0) != op->getOperand(OpNum);
97  }
98  LogicalResult rewrite(Operation *op) override {
99  assert(match(op));
100  ImplicitLocOpBuilder builder(op->getLoc(), op);
101  auto result = op->getResult(0);
102  auto operand = op->getOperand(OpNum);
103  LLVM_DEBUG(llvm::dbgs()
104  << "Forwarding " << operand << " in " << *op << "\n");
105  result.replaceAllUsesWith(operand);
106  reduce::pruneUnusedOps(op, *this);
107  return success();
108  }
109  std::string getName() const override {
110  return ("hw-operand" + Twine(OpNum) + "-forwarder").str();
111  }
112 };
113 
114 /// A sample reduction pattern that replaces integer operations with a constant
115 /// zero of their type.
116 struct HWConstantifier : public Reduction {
117  uint64_t match(Operation *op) override {
118  if (op->getNumResults() == 0 || op->getNumOperands() == 0)
119  return 0;
120  return llvm::all_of(op->getResults(), [](Value result) {
121  return isa<IntegerType>(result.getType());
122  });
123  }
124  LogicalResult rewrite(Operation *op) override {
125  assert(match(op));
126  OpBuilder builder(op);
127  for (auto result : op->getResults()) {
128  auto type = cast<IntegerType>(result.getType());
129  auto newOp = builder.create<hw::ConstantOp>(op->getLoc(), type, 0);
130  result.replaceAllUsesWith(newOp);
131  }
132  reduce::pruneUnusedOps(op, *this);
133  return success();
134  }
135  std::string getName() const override { return "hw-constantifier"; }
136 };
137 
138 /// Remove the first or last output of the top-level module depending on the
139 /// 'Front' template parameter.
140 template <bool Front>
141 struct ModuleOutputPruner : public OpReduction<HWModuleOp> {
142  void beforeReduction(mlir::ModuleOp op) override {
143  useEmpty.clear();
144 
145  SymbolTableCollection table;
146  SymbolUserMap users(table, op);
147  for (auto module : op.getOps<HWModuleOp>())
148  if (users.useEmpty(module))
149  useEmpty.insert(module);
150  }
151 
152  uint64_t match(HWModuleOp op) override {
153  return op.getNumOutputPorts() != 0 && useEmpty.contains(op);
154  }
155 
156  LogicalResult rewrite(HWModuleOp op) override {
157  Operation *terminator = op.getBody().front().getTerminator();
158  auto operands = terminator->getOperands();
159  ValueRange newOutputs = operands.drop_back();
160  unsigned portToErase = op.getNumOutputPorts() - 1;
161  if (Front) {
162  newOutputs = operands.drop_front();
163  portToErase = 0;
164  }
165 
166  terminator->setOperands(newOutputs);
167  op.erasePorts({}, {portToErase});
168 
169  return success();
170  }
171 
172  std::string getName() const override {
173  return Front ? "hw-module-output-pruner-front"
174  : "hw-module-output-pruner-back";
175  }
176 
177  DenseSet<HWModuleOp> useEmpty;
178 };
179 
180 /// Remove all input ports of the top-level module that have no users
181 struct ModuleInputPruner : public OpReduction<HWModuleOp> {
182  void beforeReduction(mlir::ModuleOp op) override {
183  useEmpty.clear();
184 
185  SymbolTableCollection table;
186  SymbolUserMap users(table, op);
187  for (auto module : op.getOps<HWModuleOp>())
188  if (users.useEmpty(module))
189  useEmpty.insert(module);
190  }
191 
192  uint64_t match(HWModuleOp op) override { return useEmpty.contains(op); }
193 
194  LogicalResult rewrite(HWModuleOp op) override {
195  SmallVector<unsigned> inputsToErase;
196  BitVector toErase(op.getNumPorts());
197  for (auto [i, arg] : llvm::enumerate(op.getBody().getArguments())) {
198  if (arg.use_empty()) {
199  toErase.set(i);
200  inputsToErase.push_back(i);
201  }
202  }
203 
204  op.erasePorts(inputsToErase, {});
205  op.getBodyBlock()->eraseArguments(toErase);
206 
207  return success();
208  }
209 
210  std::string getName() const override { return "hw-module-input-pruner"; }
211 
212  DenseSet<HWModuleOp> useEmpty;
213 };
214 
215 //===----------------------------------------------------------------------===//
216 // Reduction Registration
217 //===----------------------------------------------------------------------===//
218 
219 void HWReducePatternDialectInterface::populateReducePatterns(
221  // Gather a list of reduction patterns that we should try. Ideally these are
222  // assigned reasonable benefit indicators (higher benefit patterns are
223  // prioritized). For example, things that can knock out entire modules while
224  // being cheap should be tried first (and thus have higher benefit), before
225  // trying to tweak operands of individual arithmetic ops.
226  patterns.add<ModuleExternalizer, 6>();
227  patterns.add<HWConstantifier, 5>();
233  patterns.add<ModuleInputPruner, 2>();
234 }
235 
237  mlir::DialectRegistry &registry) {
238  registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
239  dialect->addInterfaces<HWReducePatternDialectInterface>();
240  });
241 }
assert(baseType &&"element must be base type")
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the Arc Reduction pattern dialect interface to the given registry.
std::map< std::string, std::set< std::string > > InstanceGraph
Iterates over the handshake::FuncOp's in the program to build an instance graph.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
Definition: hw.py:1
A sample reduction pattern that replaces integer operations with a constant zero of their type.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
A sample reduction pattern that replaces all uses of an operation with one of its operands.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
A sample reduction pattern that maps hw.module to hw.module.extern.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
std::unique_ptr< InstanceGraph > instanceGraph
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
ModuleSizeCache moduleSizes
Remove all input ports of the top-level module that have no users.
uint64_t match(HWModuleOp op) override
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
DenseSet< HWModuleOp > useEmpty
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
Remove the first or last output of the top-level module depending on the 'Front' template parameter.
LogicalResult rewrite(HWModuleOp op) override
DenseSet< HWModuleOp > useEmpty
std::string getName() const override
Return a human-readable name for this reduction pattern.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
Utility to track the transitive size of modules.
uint64_t getModuleSize(HWModuleLike module, hw::InstanceGraph &instanceGraph)
An abstract reduction pattern.
Definition: Reduction.h:24
A dialect interface to provide reduction patterns to a reducer tool.
Definition: HWReductions.h:18