21#include "mlir/IR/Builders.h"
22#include "llvm/Support/Debug.h"
24#define DEBUG_TYPE "hw-parameterize-constant-ports"
28#define GEN_PASS_DEF_HWPARAMETERIZECONSTANTPORTS
29#include "circt/Dialect/HW/Passes.h.inc"
38static std::pair<Attribute, Operation *>
40 Value value = inst.getInputs()[portIndex];
41 auto *op = value.getDefiningOp();
44 if (
auto constOp = dyn_cast<hw::ConstantOp>(op))
45 return {constOp.getValueAttr(), op};
46 if (
auto paramOp = dyn_cast<hw::ParamValueOp>(op))
47 return {paramOp.getValueAttr(), op};
57 for (
auto *instRecord : node->
uses()) {
58 auto inst = dyn_cast<InstanceOp>(instRecord->getInstance().getOperation());
66struct HWParameterizeConstantPortsPass
67 :
public circt::hw::impl::HWParameterizeConstantPortsBase<
68 HWParameterizeConstantPortsPass> {
69 void runOnOperation()
override;
77void HWParameterizeConstantPortsPass::processModule(
81 if (!module.isPrivate() || node->
noUses())
86 SmallVector<hw::PortInfo> inputPorts(portInfo.getInputs());
87 SmallVector<unsigned> portsToParameterize;
89 for (
auto [idx, port] :
llvm::enumerate(inputPorts)) {
91 if (port.dir != ModulePort::Direction::Input || port.getSym())
95 portsToParameterize.push_back(idx);
98 if (portsToParameterize.empty())
101 LLVM_DEBUG(llvm::dbgs() <<
"Parameterizing " << portsToParameterize.size()
102 <<
" ports in module " << module.getModuleName()
105 OpBuilder builder(module.getContext());
106 builder.setInsertionPointToStart(module.getBodyBlock());
109 SmallVector<Attribute> newParameters;
111 if (
auto existingParams = module.getParameters()) {
112 newParameters.append(existingParams.begin(), existingParams.end());
113 for (
auto param : existingParams)
114 paramNamespace.newName(cast<ParamDeclAttr>(param).
getName().str());
118 DenseMap<unsigned, StringAttr> portToParamName;
120 for (
unsigned portIdx : portsToParameterize) {
121 auto port = inputPorts[portIdx];
125 builder.getStringAttr(paramNamespace.
newName(port.name.str()));
126 portToParamName[portIdx] = paramNameAttr;
129 auto paramDecl = ParamDeclAttr::get(paramNameAttr, port.type);
130 newParameters.push_back(paramDecl);
133 auto paramRef = ParamDeclRefAttr::get(paramNameAttr, port.type);
135 ParamValueOp::create(builder, module.getLoc(), port.
type, paramRef);
138 module.getBodyBlock()->getArgument(portIdx).replaceAllUsesWith(
143 module.setParametersAttr(builder.getArrayAttr(newParameters));
146 module.modifyPorts({}, {}, portsToParameterize, {});
149 for (
auto idx :
llvm::reverse(portsToParameterize))
153 DenseSet<unsigned> portsToRemoveSet(portsToParameterize.begin(),
154 portsToParameterize.end());
155 SmallVector<Attribute> newPortNames;
156 for (
auto [idx, port] :
llvm::enumerate(portInfo.getInputs()))
157 if (!portsToRemoveSet.count(idx))
158 newPortNames.push_back(port.name);
160 ArrayAttr newPortNamesAttr = builder.getArrayAttr(newPortNames);
163 for (
auto *instRecord : node->uses()) {
164 auto inst = dyn_cast<InstanceOp>(instRecord->getInstance().getOperation());
170 builder.setInsertionPoint(inst);
173 SmallVector<Attribute> instParams;
174 if (
auto existingParams = inst.getParameters())
175 instParams.append(existingParams.begin(), existingParams.end());
177 for (
unsigned portIdx : portsToParameterize) {
179 assert(paramValueAttr &&
"expected constant or param value");
181 ParamDeclAttr::get(builder.getContext(), portToParamName[portIdx],
182 inputPorts[portIdx].type, paramValueAttr);
183 instParams.push_back(paramDecl);
187 if (constOp->hasOneUse()) {
188 constOp->dropAllUses();
194 SmallVector<Value> newInputs;
195 for (
auto [idx, input] :
llvm::enumerate(inst.getInputs()))
196 if (!portsToRemoveSet.count(idx))
197 newInputs.push_back(input);
200 auto newInst = InstanceOp::create(
201 builder, inst.getLoc(), inst.getResultTypes(),
202 inst.getInstanceNameAttr(), inst.getModuleNameAttr(), newInputs,
203 newPortNamesAttr, inst.getResultNamesAttr(),
204 builder.getArrayAttr(instParams), inst.getInnerSymAttr(),
205 inst.getDoNotPrintAttr());
209 inst.replaceAllUsesWith(newInst.getResults());
214void HWParameterizeConstantPortsPass::runOnOperation() {
215 auto &instanceGraph = getAnalysis<hw::InstanceGraph>();
220 dyn_cast_or_null<HWModuleOp>(node.
getModule().getOperation()))
221 processModule(module, &node, instanceGraph);
225 markAnalysesPreserved<hw::InstanceGraph>();
assert(baseType &&"element must be base type")
static std::pair< Attribute, Operation * > getAttributeAndDefiningOp(InstanceOp inst, unsigned portIndex)
Helper to extract the attribute value and defining operation from a constant or param....
static bool allInstancesHaveConstantForPort(igraph::InstanceGraphNode *node, unsigned portIndex)
Check if all instances have constant values for a given port.
static Block * getBodyBlock(FModuleLike mod)
A namespace that is used to store existing names and generate new names in some scope within the IR.
StringRef newName(const Twine &name)
Return a unique name, derived from the input name, and add the new name to the internal namespace.
HW-specific instance graph with a virtual entry node linking to all publicly visible modules.
This is a Node in the InstanceGraph.
llvm::iterator_range< UseIterator > uses()
bool noUses()
Return true if there are no more instances of this module.
auto getModule()
Get the module that this node is tracking.
virtual void replaceInstance(InstanceOpInterface inst, InstanceOpInterface newInst)
Replaces an instance of a module with another instance.
decltype(auto) walkInversePostOrder(Fn &&fn)
Perform an inverse-post-order walk across the modules.
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
This holds a decoded list of input/inout and output ports for a module or instance.