29 #include "mlir/IR/ImplicitLocOpBuilder.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "firrtl-merge-connections"
35 using namespace circt;
36 using namespace firrtl;
40 if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
42 if (
auto bitcast = value.getDefiningOp<BitCastOp>())
55 struct MergeConnection {
56 MergeConnection(FModuleOp moduleOp,
bool enableAggressiveMerging)
57 : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
64 bool peelConnect(StrictConnectOp
connect);
68 DenseMap<FieldRef, std::pair<unsigned, SmallVector<StrictConnectOp>>>
72 ImplicitLocOpBuilder *
builder =
nullptr;
76 bool enableAggressiveMerging =
false;
79 bool MergeConnection::peelConnect(StrictConnectOp
connect) {
83 LLVM_DEBUG(llvm::dbgs() <<
"Visiting " <<
connect <<
"\n");
84 auto destTy = type_dyn_cast<FIRRTLBaseType>(
connect.getDest().getType());
85 if (!destTy || !destTy.isPassive() ||
90 auto destRoot = destFieldRef.getValue();
94 if (destRoot.getDefiningOp<MemOp>() || destRoot ==
connect.getDest())
99 if (
auto subfield = dyn_cast<SubfieldOp>(
connect.getDest().getDefiningOp()))
100 parent = subfield.getInput(), index = subfield.getFieldIndex();
101 else if (
auto subindex =
102 dyn_cast<SubindexOp>(
connect.getDest().getDefiningOp()))
103 parent = subindex.getInput(), index = subindex.getIndex();
105 llvm_unreachable(
"unexpected destination");
108 auto &count = countAndSubConnections.first;
109 auto &subConnections = countAndSubConnections.second;
114 if (
auto bundle = type_dyn_cast<BundleType>(parent.getType()))
115 subConnections.resize(bundle.getNumElements());
116 if (
auto vector = type_dyn_cast<FVectorType>(parent.getType()))
117 subConnections.resize(vector.getNumElements());
120 subConnections[index] =
connect;
123 if (count != subConnections.size())
126 auto parentType = parent.getType();
127 auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
130 if (!parentBaseTy || !parentBaseTy.isPassive())
135 auto getMergedValue = [&](
auto aggregateType) {
136 SmallVector<Value> operands;
140 bool canUseSourceParent =
true;
141 bool areOperandsAllConstants =
true;
146 auto checkSourceParent = [&](
auto subelement,
unsigned destIndex,
147 unsigned sourceIndex) {
149 if (destIndex == 0) {
150 if (subelement.getInput().getType() == parentType)
151 sourceParent = subelement.getInput();
154 canUseSourceParent =
false;
159 canUseSourceParent &=
160 subelement.getInput() == sourceParent && destIndex == sourceIndex;
163 for (
auto idx : llvm::seq(0u, (
unsigned)aggregateType.getNumElements())) {
164 auto src = subConnections[idx].getSrc();
165 assert(src &&
"all subconnections are guranteed to exist");
166 operands.push_back(src);
174 if (!canUseSourceParent)
178 if (!src.getDefiningOp()) {
179 canUseSourceParent =
false;
183 TypeSwitch<Operation *>(src.getDefiningOp())
184 .template Case<SubfieldOp>([&](SubfieldOp subfield) {
185 checkSourceParent(subfield, idx, subfield.getFieldIndex());
187 .
template Case<SubindexOp>([&](SubindexOp subindex) {
188 checkSourceParent(subindex, idx, subindex.getIndex());
190 .Default([&](
auto) { canUseSourceParent =
false; });
195 if (canUseSourceParent) {
196 LLVM_DEBUG(llvm::dbgs() <<
"Success to merge " << destFieldRef.getValue()
197 <<
" ,fieldID= " << destFieldRef.getFieldID()
198 <<
" to " << sourceParent <<
"\n";);
201 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size())))
203 subConnections[idx].erase();
209 if (!enableAggressiveMerging && !areOperandsAllConstants)
212 SmallVector<Location> locs;
214 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size()))) {
215 locs.push_back(subConnections[idx].getLoc());
219 subConnections[idx].erase();
222 return isa<FVectorType>(parentType)
223 ?
builder->createOrFold<VectorCreateOp>(
224 builder->getFusedLoc(locs), parentType, operands)
225 :
builder->createOrFold<BundleCreateOp>(
226 builder->getFusedLoc(locs), parentType, operands);
230 if (
auto bundle = type_dyn_cast<BundleType>(parentType))
231 merged = getMergedValue(bundle);
232 if (
auto vector = type_dyn_cast<FVectorType>(parentType))
233 merged = getMergedValue(vector);
239 if (!parentBaseTy.hasUninferredWidth())
240 builder->create<StrictConnectOp>(
connect.getLoc(), parent, merged);
247 bool MergeConnection::run() {
248 ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
250 auto *body = moduleOp.getBodyBlock();
252 for (
auto it = body->begin(), e = body->end(); it != e;) {
253 auto connectOp = dyn_cast<StrictConnectOp>(*it);
258 builder->setInsertionPointAfter(connectOp);
259 builder->setLoc(connectOp.getLoc());
260 bool removeOp = peelConnect(connectOp);
267 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*body)))
268 if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
270 if (op.use_empty()) {
278 struct MergeConnectionsPass
279 :
public MergeConnectionsBase<MergeConnectionsPass> {
280 MergeConnectionsPass(
bool enableAggressiveMergingFlag) {
281 enableAggressiveMerging = enableAggressiveMergingFlag;
283 void runOnOperation()
override;
288 void MergeConnectionsPass::runOnOperation() {
291 <<
"Module: '" << getOperation().
getName() <<
"'\n");
293 MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
294 bool changed = mergeConnection.run();
297 return markAllAnalysesPreserved();
300 std::unique_ptr<mlir::Pass>
302 return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
def connect(destination, source)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.