13#include "llvm/ADT/SmallSet.h"
14#include "llvm/Support/Debug.h"
16#define DEBUG_TYPE "hw-reductions"
35 module->walk([&](Operation *op) {
37 if (
auto instOp = dyn_cast<HWInstanceLike>(op)) {
38 for (
auto moduleName : instOp.getReferencedModuleNamesAttr()) {
39 auto *node = instanceGraph.
lookup(cast<StringAttr>(moduleName));
41 dyn_cast_or_null<hw::HWModuleLike>(*node->getModule()))
51 llvm::DenseMap<Operation *, uint64_t> moduleSizes;
70 OpBuilder builder(op);
71 HWModuleExternOp::create(builder, op->getLoc(), op.getModuleNameAttr(),
72 op.getPortList(), StringRef(), op.getParameters());
77 std::string
getName()
const override {
return "hw-module-externalizer"; }
87template <
unsigned OpNum>
89 uint64_t
match(Operation *op)
override {
90 if (op->getNumResults() != 1 || op->getNumOperands() < 2 ||
91 OpNum >= op->getNumOperands())
93 auto resultTy = dyn_cast<IntegerType>(op->getResult(0).getType());
94 auto opTy = dyn_cast<IntegerType>(op->getOperand(OpNum).getType());
95 return resultTy && opTy && resultTy == opTy &&
96 op->getResult(0) != op->getOperand(OpNum);
98 LogicalResult
rewrite(Operation *op)
override {
100 ImplicitLocOpBuilder builder(op->getLoc(), op);
101 auto result = op->getResult(0);
102 auto operand = op->getOperand(OpNum);
103 LLVM_DEBUG(llvm::dbgs()
104 <<
"Forwarding " << operand <<
" in " << *op <<
"\n");
105 result.replaceAllUsesWith(operand);
110 return (
"hw-operand" + Twine(OpNum) +
"-forwarder").str();
118 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
119 if (op->hasTrait<OpTrait::ConstantLike>())
121 for (
auto result : op->getResults())
122 if (!result.use_empty())
123 if (isa<IntegerType>(result.getType()))
124 addMatch(1, result.getResultNumber());
127 ArrayRef<uint64_t> indices)
override {
128 OpBuilder builder(op);
129 for (
auto idx : indices) {
130 auto result = op->getResult(idx);
131 auto type = cast<IntegerType>(result.getType());
133 result.replaceAllUsesWith(newOp);
138 std::string
getName()
const override {
return "hw-constantifier"; }
145 symbolTables = std::make_unique<SymbolTableCollection>();
150 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
151 auto mod = dyn_cast<HWModuleLike>(op);
154 auto modType = mod.getHWModuleType();
155 if (modType.getNumInputs() == 0)
158 if (!llvm::all_of(users, [](
auto *user) {
return isa<InstanceOp>(user); }))
160 auto *block = mod.getBodyBlock();
161 for (
unsigned idx = 0; idx < modType.getNumInputs(); ++idx)
162 if (!block || block->getArgument(idx).use_empty())
167 ArrayRef<uint64_t>
matches)
override {
168 auto mod = cast<HWMutableModuleLike>(op);
171 SmallVector<unsigned> indexList;
172 BitVector indexSet(mod.getNumInputPorts());
174 indexList.push_back(idx);
177 llvm::sort(indexList);
178 mod.erasePorts(indexList, {});
179 if (
auto *block = mod.getBodyBlock())
180 block->eraseArguments(indexSet);
184 auto instOp = cast<InstanceOp>(user);
185 SmallVector<Value> newOperands;
186 SmallVector<Attribute> newArgNames;
187 for (
auto [idx, data] : llvm::enumerate(
188 llvm::zip(instOp.getInputs(), instOp.getArgNames()))) {
189 if (indexSet.test(idx))
191 auto [operand, argName] = data;
192 newOperands.push_back(operand);
193 newArgNames.push_back(argName);
195 instOp.getInputsMutable().assign(newOperands);
196 instOp.setArgNamesAttr(ArrayAttr::get(op->getContext(), newArgNames));
202 std::string
getName()
const override {
return "hw-module-input-pruner"; }
212 symbolTables = std::make_unique<SymbolTableCollection>();
217 llvm::function_ref<
void(uint64_t, uint64_t)> addMatch)
override {
218 auto mod = dyn_cast<HWModuleLike>(op);
221 auto modType = mod.getHWModuleType();
222 if (modType.getNumOutputs() == 0)
225 if (!llvm::all_of(users, [](
auto *user) {
return isa<InstanceOp>(user); }))
227 for (
unsigned idx = 0; idx < modType.getNumOutputs(); ++idx)
228 if (llvm::all_of(users, [&](
auto *user) {
229 return user->getResult(idx).use_empty();
235 ArrayRef<uint64_t>
matches)
override {
236 auto mod = cast<HWMutableModuleLike>(op);
239 SmallVector<unsigned> indexList;
240 BitVector indexSet(mod.getNumOutputPorts());
242 indexList.push_back(idx);
245 llvm::sort(indexList);
246 mod.erasePorts({}, indexList);
249 if (
auto *block = mod.getBodyBlock()) {
250 auto outputOp = cast<OutputOp>(block->getTerminator());
251 SmallVector<Value> newOutputs;
252 for (
auto [idx, output] : llvm::enumerate(outputOp.getOutputs()))
253 if (!indexSet.test(idx))
254 newOutputs.push_back(output);
255 outputOp.getOutputsMutable().assign(newOutputs);
260 OpBuilder builder(user);
261 auto instOp = cast<InstanceOp>(user);
262 SmallVector<Value> oldResults;
263 SmallVector<Type> newResultTypes;
264 SmallVector<Attribute> newResultNames;
265 for (
auto [idx, data] : llvm::enumerate(
266 llvm::zip(instOp.getResults(), instOp.getResultNames()))) {
267 if (indexSet.test(idx))
269 auto [result, resultName] = data;
270 oldResults.push_back(result);
271 newResultTypes.push_back(result.getType());
272 newResultNames.push_back(resultName);
274 auto newOp = InstanceOp::create(
275 builder, instOp.getLoc(), newResultTypes,
276 instOp.getInstanceNameAttr(), instOp.getModuleNameAttr(),
277 instOp.getInputs(), instOp.getArgNamesAttr(),
278 builder.getArrayAttr(newResultNames), instOp.getParametersAttr(),
279 instOp.getInnerSymAttr(), instOp.getDoNotPrintAttr());
280 for (
auto [oldResult, newResult] :
281 llvm::zip(oldResults, newOp.getResults()))
282 oldResult.replaceAllUsesWith(newResult);
289 std::string
getName()
const override {
return "hw-module-output-pruner"; }
300 uint64_t
match(Operation *op)
override {
return op->hasAttr(
"sv.namehint"); }
301 LogicalResult
rewrite(Operation *op)
override {
302 op->removeAttr(
"sv.namehint");
305 std::string
getName()
const override {
return "sv-namehint-remover"; }
318 uint64_t
match(Operation *op)
override {
320 return isa<hw::WireOp>(op);
322 LogicalResult
rewrite(Operation *op)
override {
323 cast<hw::WireOp>(op).setName(
"wire");
328 return "hw-module-internal-name-sanitizer";
348 if (portNameIndex >= 26)
350 return 'a' + portNameIndex++;
353 LogicalResult
rewrite(mlir::ModuleOp moduleOp)
override {
362 for (
auto *node : instanceGraph) {
364 if (!node->getModule())
366 auto hwModule = dyn_cast<HWModuleOp>(node->getModule().getOperation());
370 auto *
context = hwModule.getContext();
375 auto numPorts = hwModule.getNumPorts();
376 SmallVector<Attribute> newPortNames(numPorts);
377 for (
unsigned i = 0; i != numPorts; ++i)
379 hwModule.setAllPortNames(newPortNames);
382 for (
auto *use : node->uses()) {
383 auto useOp = use->getInstance();
386 auto instOp = dyn_cast<hw::InstanceOp>(*useOp);
389 instOp.setModuleName(newName);
390 instOp.setInstanceName(newName);
392 auto *inputEnd = newPortNames.begin() + hwModule.getNumInputPorts();
393 SmallVector<Attribute> argNames(newPortNames.begin(), inputEnd);
394 instOp.setArgNamesAttr(ArrayAttr::get(
context, argNames));
396 SmallVector<Attribute> resultNames(inputEnd, newPortNames.end());
397 instOp.setResultNamesAttr(ArrayAttr::get(
context, resultNames));
401 hwModule.setName(newName);
407 std::string
getName()
const override {
return "hw-module-name-sanitizer"; }
438 mlir::DialectRegistry &registry) {
439 registry.addExtension(+[](MLIRContext *ctx, HWDialect *dialect) {
assert(baseType &&"element must be base type")
static std::unique_ptr< Context > context
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
InstanceGraphNode * lookup(ModuleOpInterface op)
Look up an InstanceGraphNode for a module.
void registerReducePatternDialectInterface(mlir::DialectRegistry &registry)
Register the HW Reduction pattern dialect interface to the given registry.
void pruneUnusedOps(Operation *initialOp, Reduction &reduction)
Starting at the given op, traverse through it and its operands and erase operations that have no more...
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
A sample reduction pattern that replaces integer operations with a constant zero of their type.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > indices) override
Apply a set of matches of this reduction to a specific operation.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::string getName() const override
Return a human-readable name for this reduction pattern.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
A sample reduction pattern that replaces all uses of an operation with one of its operands.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
A sample reduction pattern that maps hw.module to hw.module.extern.
std::string getName() const override
Return a human-readable name for this reduction pattern.
LogicalResult rewrite(HWModuleOp op) override
std::unique_ptr< InstanceGraph > instanceGraph
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
uint64_t match(HWModuleOp op) override
ModuleSizeCache moduleSizes
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
Pseudo-reduction that sanitizes the names of operations inside modules.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
Pseudo-reduction that sanitizes module and port names.
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(mlir::ModuleOp moduleOp) override
Remove unused module output ports.
void beforeReduction(mlir::ModuleOp op) override
Called before the reduction is applied to a new subset of operations.
void matches(Operation *op, llvm::function_ref< void(uint64_t, uint64_t)> addMatch) override
Collect all ways how this reduction can apply to a specific operation.
LogicalResult rewriteMatches(Operation *op, ArrayRef< uint64_t > matches) override
Apply a set of matches of this reduction to a specific operation.
std::unique_ptr< SymbolTableCollection > symbolTables
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
std::unique_ptr< SymbolUserMap > symbolUsers
Utility to track the transitive size of modules.
llvm::DenseMap< Operation *, uint64_t > moduleSizes
uint64_t getModuleSize(Operation *module, ::detail::SymbolCache &symbols)
uint64_t getModuleSize(HWModuleLike module, hw::InstanceGraph &instanceGraph)
Pseudo-reduction that removes sv.namehint attributes from operations.
bool isOneShot() const override
Return true if the tool should not try to reapply this reduction after it has been successful.
uint64_t match(Operation *op) override
Check if the reduction can apply to a specific operation.
std::string getName() const override
Return a human-readable name for this reduction pattern.
bool acceptSizeIncrease() const override
Return true if the tool should accept the transformation this reduction performs on the module even i...
LogicalResult rewrite(Operation *op) override
Apply the reduction to a specific operation.
A reduction pattern for a specific operation.
An abstract reduction pattern.
A dialect interface to provide reduction patterns to a reducer tool.
void populateReducePatterns(circt::ReducePatternSet &patterns) const override