21#include "mlir/IR/Builders.h"
22#include "llvm/Support/Debug.h"
24#define DEBUG_TYPE "hw-parameterize-constant-ports"
28#define GEN_PASS_DEF_HWPARAMETERIZECONSTANTPORTS
29#include "circt/Dialect/HW/Passes.h.inc"
// Helper: returns the attribute carried by the operation that drives input
// `portIndex` of `inst`, together with that defining operation.
// Recognized drivers are hw::ConstantOp and hw::ParamValueOp; in both cases the
// op's getValueAttr() is returned.
// NOTE(review): the name/parameter line, the handling of a value with no
// defining op (e.g. a block argument), the fallback return for other drivers
// and the closing brace are not visible in this excerpt -- confirm against the
// full file.
38static std::pair<Attribute, Operation *>
40 Value value = inst.getInputs()[portIndex];
41 auto *op = value.getDefiningOp();
// A constant driver contributes its value attribute.
44 if (
auto constOp = dyn_cast<hw::ConstantOp>(op))
45 return {constOp.getValueAttr(), op};
// A parameter-value driver forwards its (possibly parametric) value attribute.
46 if (
auto paramOp = dyn_cast<hw::ParamValueOp>(op))
47 return {paramOp.getValueAttr(), op};
// NOTE(review): fragment -- presumably the body of
// allInstancesHaveConstantForPort(node, portIndex); its header and the rest of
// the loop body are not visible here, so confirm against the full file.
// Visits every instance record that references `node` and casts the recorded
// instance to an hw InstanceOp (dyn_cast, so non-InstanceOp users yield null).
57 for (
auto *instRecord : node->
uses()) {
58 auto inst = dyn_cast<InstanceOp>(instRecord->getInstance().getOperation());
// Pass that turns selected module input ports into module parameters.
// Its generated base class, HWParameterizeConstantPortsBase, comes from the
// GEN_PASS_DEF_HWPARAMETERIZECONSTANTPORTS block of Passes.h.inc.
// NOTE(review): only part of the declaration is visible in this excerpt; the
// remaining members (e.g. the processModule declaration) and the closing
// brace are elided.
66struct HWParameterizeConstantPortsPass
67 :
public circt::hw::impl::HWParameterizeConstantPortsBase<
68 HWParameterizeConstantPortsPass> {
// Pass entry point.
69 void runOnOperation()
override;
// Rewrites one module. Each input port chosen for parameterization is:
//  * removed from the module signature,
//  * replaced in the body by an hw ParamValueOp that refers to a newly
//    declared module parameter.
// Every instance of the module is then rebuilt without that input. In its
// place the instance gets a parameter binding whose value is the attribute of
// the op that drove the port.
// NOTE(review): several lines are elided from this excerpt and should be
// checked against the full file. They include the parameter list, early
// returns, the per-port eligibility test, some loop bodies and old-instance
// cleanup.
77void HWParameterizeConstantPortsPass::processModule(
// Only private modules that are instantiated at least once are candidates.
// The early exit that follows this check is elided.
81 if (!module.isPrivate() || node->
noUses())
// Snapshot of the input ports; port indices used below index into this list.
86 SmallVector<hw::PortInfo> inputPorts(portInfo.getInputs());
87 SmallVector<unsigned> portsToParameterize;
89 for (
auto [idx, port] :
llvm::enumerate(inputPorts)) {
// Ports that are not plain inputs, or that carry an inner symbol, are not
// candidates. The skip itself and any further eligibility test are elided;
// the symbol check presumably protects ports referenced by symbol elsewhere.
91 if (port.dir != ModulePort::Direction::Input || port.getSym())
95 portsToParameterize.push_back(idx);
// Nothing to rewrite for this module.
98 if (portsToParameterize.empty())
101 LLVM_DEBUG(llvm::dbgs() <<
"Parameterizing " << portsToParameterize.size()
102 <<
" ports in module " << module.getModuleName()
// New ParamValueOps are inserted at the start of the body block, so they
// precede every existing use of the ports they replace.
105 OpBuilder builder(module.getContext());
106 builder.setInsertionPointToStart(module.getBodyBlock());
// Start from the module's existing parameters. Their names are registered in
// paramNamespace so the parameter names generated below cannot collide with
// them.
109 SmallVector<Attribute> newParameters;
111 if (
auto existingParams = module.getParameters()) {
112 newParameters.append(existingParams.begin(), existingParams.end());
113 for (
auto param : existingParams)
114 paramNamespace.newName(cast<ParamDeclAttr>(param).
getName().str());
// Maps each removed port index to the name of the parameter that replaces it.
// The instance-rebuilding loop below reuses this map.
118 DenseMap<unsigned, StringAttr> portToParamName;
120 for (
unsigned portIdx : portsToParameterize) {
121 auto port = inputPorts[portIdx];
// Unique parameter name derived from the port name.
125 builder.getStringAttr(paramNamespace.
newName(port.name.str()));
126 portToParamName[portIdx] = paramNameAttr;
// Declare the parameter on the module with the port's type (no default value).
129 auto paramDecl = ParamDeclAttr::get(paramNameAttr, port.type);
130 newParameters.push_back(paramDecl);
// Materialize the parameter's value in the body. All uses of the port's block
// argument are then redirected to it.
133 auto paramRef = ParamDeclRefAttr::get(paramNameAttr, port.type);
135 ParamValueOp::create(builder, module.getLoc(), port.
type, paramRef);
138 module.getBodyBlock()->getArgument(portIdx).replaceAllUsesWith(
143 module.setParametersAttr(builder.getArrayAttr(newParameters));
// Drop the parameterized ports from the module signature, then walk their
// indices in reverse. The loop body is elided; it presumably erases the
// matching block arguments, and reverse order keeps the remaining indices
// valid.
146 module.modifyPorts({}, {}, portsToParameterize, {});
149 for (
auto idx :
llvm::reverse(portsToParameterize))
// Input port names for the rebuilt instances: every input except the removed
// ones, in original order.
153 DenseSet<unsigned> portsToRemoveSet(portsToParameterize.begin(),
154 portsToParameterize.end());
155 SmallVector<Attribute> newPortNames;
156 for (
auto [idx, port] :
llvm::enumerate(portInfo.getInputs()))
157 if (!portsToRemoveSet.count(idx))
158 newPortNames.push_back(port.name);
160 ArrayAttr newPortNamesAttr = builder.getArrayAttr(newPortNames);
// Rebuild every instance of this module. The handling of non-InstanceOp users
// is elided.
163 for (
auto *instRecord : node->uses()) {
164 auto inst = dyn_cast<InstanceOp>(instRecord->getInstance().getOperation());
169 builder.setInsertionPoint(inst);
// Instance parameters consist of the existing ones plus one binding per
// removed port. Each binding's value is the attribute of the op that drove
// that port on this instance.
172 SmallVector<Attribute> instParams;
173 if (
auto existingParams = inst.getParameters())
174 instParams.append(existingParams.begin(), existingParams.end());
176 for (
unsigned portIdx : portsToParameterize) {
178 assert(paramValueAttr &&
"expected constant or param value");
180 ParamDeclAttr::get(builder.getContext(), portToParamName[portIdx],
181 inputPorts[portIdx].type, paramValueAttr);
182 instParams.push_back(paramDecl);
// If this instance's operand is the driving op's only use, detach it. The
// driving op is presumably erased next (elided).
186 if (constOp->hasOneUse()) {
187 constOp->dropAllUses();
// Operands for the new instance: the inputs that were not parameterized, in
// their original order.
193 SmallVector<Value> newInputs;
194 for (
auto [idx, input] :
llvm::enumerate(inst.getInputs()))
195 if (!portsToRemoveSet.count(idx))
196 newInputs.push_back(input);
// The replacement instance keeps the result types, instance/module names,
// result names, inner symbol and doNotPrint attribute. It takes the reduced
// inputs/port names and the extended parameter list.
199 auto newInst = InstanceOp::create(
200 builder, inst.getLoc(), inst.getResultTypes(),
201 inst.getInstanceNameAttr(), inst.getModuleNameAttr(), newInputs,
202 newPortNamesAttr, inst.getResultNamesAttr(),
203 builder.getArrayAttr(instParams), inst.getInnerSymAttr(),
204 inst.getDoNotPrintAttr());
208 inst.replaceAllUsesWith(newInst.getResults());
// Pass entry point. It obtains the hw::InstanceGraph analysis and considers
// graph nodes whose tracked module is an HWModuleOp. The traversal itself and
// the per-module call (presumably processModule) are elided from this excerpt.
// NOTE(review): InstanceGraph is marked preserved. This assumes the per-module
// rewrite keeps the graph in sync when it replaces instances -- confirm
// against the full file.
213void HWParameterizeConstantPortsPass::runOnOperation() {
214 auto &instanceGraph = getAnalysis<hw::InstanceGraph>();
219 dyn_cast_or_null<HWModuleOp>(node.
getModule().getOperation()))
224 markAnalysesPreserved<hw::InstanceGraph>();
assert(baseType &&"element must be base type")
static std::pair< Attribute, Operation * > getAttributeAndDefiningOp(InstanceOp inst, unsigned portIndex)
Helper to extract the attribute value and defining operation from a constant or parameter value operation.
static bool allInstancesHaveConstantForPort(igraph::InstanceGraphNode *node, unsigned portIndex)
Check if all instances have constant values for a given port.
static LogicalResult processModule(const DomainInfo &info, TermAllocator &allocator, DomainTable &table, const ModuleUpdateTable &updateTable, FModuleOp moduleOp)
Populate the domain table by processing the moduleOp.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
This is a Node in the InstanceGraph.
llvm::iterator_range< UseIterator > uses()
bool noUses()
Return true if there are no more instances of this module.
auto getModule()
Get the module that this node is tracking.
virtual void replaceInstance(InstanceOpInterface inst, InstanceOpInterface newInst)
Replaces an instance of a module with another instance.
decltype(auto) walkInversePostOrder(Fn &&fn)
Perform an inverse-post-order walk across the modules.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.