13#include "llvm/ADT/SmallSet.h"
14#include "llvm/Support/Debug.h"
16#define DEBUG_TYPE "hw-reductions"
35 module->walk([&](Operation *op) {
37 if (
auto instOp = dyn_cast<HWInstanceLike>(op)) {
38 for (
auto moduleName : instOp.getReferencedModuleNamesAttr()) {
39 auto *node = instanceGraph.
lookup(cast<StringAttr>(moduleName));
41 dyn_cast_or_null<hw::HWModuleLike>(*node->getModule()))
51 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
70 OpBuilder builder(op);
71 HWModuleExternOp::create(builder, op->getLoc(), op.getModuleNameAttr(),
72 op.getPortList(), StringRef(), op.getParameters());
// Return the human-readable name under which the reducer reports this
// pattern (the hw.module -> hw.module.extern externalizer).
77 std::string
getName()
const override {
return "hw-module-externalizer"; }
87template <
unsigned OpNum>
89 uint64_t
match(Operation *op)
override {
90 if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
91 OpNum >= op->getNumOperands())
93 auto resultTy = dyn_cast<IntegerType>(op->getResult(0).getType());
94 auto opTy = dyn_cast<IntegerType>(op->getOperand(OpNum).getType());
95 return resultTy && opTy && resultTy == opTy &&
96 op->getResult(0) != op->getOperand(OpNum);
89 uint64_t
match(Operation *op)
override {
…}
98 LogicalResult
rewrite(Operation *op)
override {
100 ImplicitLocOpBuilder builder(op->getLoc(), op);
101 auto result = op->getResult(0);
102 auto operand = op->getOperand(OpNum);
103 LLVM_DEBUG(llvm::dbgs()
104 <<
"Forwarding " << operand <<
" in " << *op <<
"\n");
105 result.replaceAllUsesWith(operand);
98 LogicalResult
rewrite(Operation *op)
override {
…}
110 return (
"hw-operand" + Twine(OpNum) +
"-forwarder").str();
118 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
119 if (op->hasTrait<OpTrait::ConstantLike>())
121 for (
auto result : op->getResults())
122 if (!result.use_empty())
123 if (isa<IntegerType>(result.getType()))
124 addMatch(1, result.getResultNumber());
127 ArrayRef<uint64_t> indices)
override {
128 OpBuilder builder(op);
129 for (
auto idx : indices) {
130 auto result = op->getResult(idx);
131 auto type = cast<IntegerType>(result.getType());
133 result.replaceAllUsesWith(newOp);
// Return the human-readable name under which the reducer reports this
// pattern (replaces used integer results with constants).
138 std::string
getName()
const override {
return "hw-constantifier"; }
145 symbolTables = std::make_unique<SymbolTableCollection>();
150 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
151 auto mod = dyn_cast<HWModuleLike>(op);
154 auto modType = mod.getHWModuleType();
155 if (modType.getNumInputs() == 0)
158 if (!llvm::all_of(users, [](
auto *user) {
return isa<InstanceOp>(user); }))
160 auto *block = mod.getBodyBlock();
161 for (
unsigned idx = 0; idx < modType.getNumInputs(); ++idx)
162 if (!block || block->getArgument(idx).use_empty())
167 ArrayRef<uint64_t>
matches)
override {
168 auto mod = cast<HWMutableModuleLike>(op);
171 SmallVector<unsigned> indexList;
172 BitVector indexSet(mod.getNumInputPorts());
174 indexList.push_back(idx);
177 llvm::sort(indexList);
178 mod.erasePorts(indexList, {});
179 if (
auto *block = mod.getBodyBlock())
180 block->eraseArguments(indexSet);
184 auto instOp = cast<InstanceOp>(user);
185 SmallVector<Value> newOperands;
186 SmallVector<Attribute> newArgNames;
187 for (
auto [idx, data] : llvm::enumerate(
188 llvm::zip(instOp.getInputs(), instOp.getArgNames()))) {
189 if (indexSet.test(idx))
191 auto [operand, argName] = data;
192 newOperands.push_back(operand);
193 newArgNames.push_back(argName);
195 instOp.getInputsMutable().assign(newOperands);
196 instOp.setArgNamesAttr(ArrayAttr::get(op->getContext(), newArgNames));
// Return the human-readable name under which the reducer reports this
// pattern (drops unused module input ports and the matching instance operands).
202 std::string
getName()
const override {
return "hw-module-input-pruner"; }
212 symbolTables = std::make_unique<SymbolTableCollection>();
217 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
218 auto mod = dyn_cast<HWModuleLike>(op);
221 auto modType = mod.getHWModuleType();
222 if (modType.getNumOutputs() == 0)
225 if (!llvm::all_of(users, [](
auto *user) {
return isa<InstanceOp>(user); }))
227 for (
unsigned idx = 0; idx < modType.getNumOutputs(); ++idx)
228 if (llvm::all_of(users, [&](
auto *user) {
229 return user->getResult(idx).use_empty();
235 ArrayRef<uint64_t>
matches)
override {
236 auto mod = cast<HWMutableModuleLike>(op);
239 SmallVector<unsigned> indexList;
240 BitVector indexSet(mod.getNumOutputPorts());
242 indexList.push_back(idx);
245 llvm::sort(indexList);
246 mod.erasePorts({}, indexList);
249 if (
auto *block = mod.getBodyBlock()) {
250 auto outputOp = cast<OutputOp>(block->getTerminator());
251 SmallVector<Value> newOutputs;
252 for (
auto [idx, output] : llvm::enumerate(outputOp.getOutputs()))
253 if (!indexSet.test(idx))
254 newOutputs.push_back(output);
255 outputOp.getOutputsMutable().assign(newOutputs);
260 OpBuilder builder(user);
261 auto instOp = cast<InstanceOp>(user);
262 SmallVector<Value> oldResults;
263 SmallVector<Type> newResultTypes;
264 SmallVector<Attribute> newResultNames;
265 for (
auto [idx, data] : llvm::enumerate(
266 llvm::zip(instOp.getResults(), instOp.getResultNames()))) {
267 if (indexSet.test(idx))
269 auto [result, resultName] = data;
270 oldResults.push_back(result);
271 newResultTypes.push_back(result.getType());
272 newResultNames.push_back(resultName);
274 auto newOp = InstanceOp::create(
275 builder, instOp.getLoc(), newResultTypes,
276 instOp.getInstanceNameAttr(), instOp.getModuleNameAttr(),
277 instOp.getInputs(), instOp.getArgNamesAttr(),
278 builder.getArrayAttr(newResultNames), instOp.getParametersAttr(),
279 instOp.getInnerSymAttr(), instOp.getDoNotPrintAttr());
280 for (
auto [oldResult, newResult] :
281 llvm::zip(oldResults, newOp.getResults()))
282 oldResult.replaceAllUsesWith(newResult);
// Return the human-readable name under which the reducer reports this
// pattern (drops module output ports whose instance results are unused).
289 std::string
getName()
const override {
return "hw-module-output-pruner"; }
317 mlir::DialectRegistry ®istry) {
318 registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
assert(baseType &&"element must be base type")
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the HW Reduction pattern dialect interface to the given registry.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A sample reduction pattern that replaces integer operations with a constant zero of their type.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > indices) override
Apply a set of matches of this reduction to a specific operation.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::string getName() const override
Return a human-readable name for this reduction pattern.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
A sample reduction pattern that replaces all uses of an operation with one of its operands.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
A sample reduction pattern that maps hw.module to hw.module.extern.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
std::unique_ptr< InstanceGraph > instanceGraph
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
ModuleSizeCache moduleSizes
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
Remove unused module output ports.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
std::unique_ptr< SymbolTableCollection > symbolTables
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::unique_ptr< SymbolUserMap > symbolUsers
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
uint64_t getModuleSize(HWModuleLike module, hw::InstanceGraph &instanceGraph)
An abstract reduction pattern.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override