26 #include "mlir/Pass/Pass.h"
31 #include "mlir/IR/ImplicitLocOpBuilder.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
35 #define DEBUG_TYPE "firrtl-merge-connections"
39 #define GEN_PASS_DEF_MERGECONNECTIONS
40 #include "circt/Dialect/FIRRTL/Passes.h.inc"
44 using namespace circt;
45 using namespace firrtl;
// Fragment of what appears to be the body of isConstantLike(Value value); the
// signature and the statements each condition guards are not visible here.
// Values defined by a ConstantOp or an InvalidValueOp are treated as
// constant-like (presumably by returning true -- confirm against full source).
49 if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
// BitCastOp results are inspected next; presumably the cast's input is checked
// in turn -- TODO confirm, the continuation is not visible in this excerpt.
51 if (
auto bitcast = value.getDefiningOp<BitCastOp>())
// Per-module helper that replaces a complete set of element-wise connects to an
// aggregate with a single connect of the whole aggregate (see peelConnect).
// NOTE(review): excerpt -- some members and the closing brace are not visible.
64 struct MergeConnection {
// Binds the helper to `moduleOp`. `enableAggressiveMerging` allows merging even
// when the merged operands are not all constant-like.
65 MergeConnection(FModuleOp moduleOp,
bool enableAggressiveMerging)
66 : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
// Records `connect` and attempts a merge once all sibling elements of its
// destination's parent have been connected. run() stores the result in a
// `removeOp` flag.
73 bool peelConnect(MatchingConnectOp
connect);
// Pending merges: keyed by FieldRef (presumably of the parent aggregate). Each
// value holds the number of element connects seen so far and the connects
// themselves, indexed by field/element index.
77 DenseMap<FieldRef, std::pair<unsigned, SmallVector<MatchingConnectOp>>>
// Builder used to create merged values; run() points it at a local builder.
81 ImplicitLocOpBuilder *builder =
nullptr;
// When false, a new aggregate is only built if every operand is constant-like.
85 bool enableAggressiveMerging =
false;
// Visits one MatchingConnectOp. If its destination is a subfield/subindex of a
// passive aggregate `parent`, the connect is recorded in a per-parent slot
// table. Once every element of `parent` has a recorded connect, the
// element-wise connects are replaced by one connect driving `parent` with a
// merged value.
// NOTE(review): excerpt -- several interior lines (including early returns) and
// the closing brace are not visible here.
88 bool MergeConnection::peelConnect(MatchingConnectOp
connect) {
92 LLVM_DEBUG(llvm::dbgs() <<
"Visiting " <<
connect <<
"\n");
// Only passive FIRRTL base-type destinations are candidates (the rest of this
// condition is not visible).
93 auto destTy = type_dyn_cast<FIRRTLBaseType>(
connect.getDest().getType());
94 if (!destTy || !destTy.isPassive() ||
// Skip destinations rooted at a memory, and connects that drive the root value
// itself rather than one of its subelements.
99 auto destRoot = destFieldRef.getValue();
103 if (destRoot.getDefiningOp<MemOp>() || destRoot ==
connect.getDest())
// Recover the parent aggregate and the element index from the op defining the
// destination; only subfield and subindex destinations are expected here.
108 if (
auto subfield = dyn_cast<SubfieldOp>(
connect.getDest().getDefiningOp()))
109 parent = subfield.getInput(), index = subfield.getFieldIndex();
110 else if (
auto subindex =
111 dyn_cast<SubindexOp>(
connect.getDest().getDefiningOp()))
112 parent = subindex.getInput(), index = subindex.getIndex();
114 llvm_unreachable(
"unexpected destination");
// References into the pending-merge entry for this parent (how the entry is
// looked up is not visible here).
117 auto &count = countAndSubConnections.first;
118 auto &subConnections = countAndSubConnections.second;
// Size the slot table to the parent's element count (presumably only on the
// first visit -- the guarding lines are not visible).
123 if (
auto bundle = type_dyn_cast<BundleType>(parent.getType()))
124 subConnections.resize(bundle.getNumElements());
125 if (
auto vector = type_dyn_cast<FVectorType>(parent.getType()))
126 subConnections.resize(vector.getNumElements());
// Record this connect in its element slot.
129 subConnections[index] =
connect;
// Keep waiting until the number of recorded connects equals the element count.
132 if (count != subConnections.size())
// The merged connect drives `parent` directly, so `parent` must itself be a
// passive base type.
135 auto parentType = parent.getType();
136 auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
139 if (!parentBaseTy || !parentBaseTy.isPassive())
// Computes the value that will drive `parent`. It is either an existing value
// whose elements are exactly the sources in the same order, or a newly created
// VectorCreateOp/BundleCreateOp of the individual sources.
144 auto getMergedValue = [&](
auto aggregateType) {
145 SmallVector<Value> operands;
149 bool canUseSourceParent =
true;
150 bool areOperandsAllConstants =
true;
// Checks that source `subelement` (element `sourceIndex` of its input) is
// element `destIndex` of the same `sourceParent` as the earlier sources. At
// destIndex 0 the candidate is taken from the input, and only if its type
// equals the parent's type.
155 auto checkSourceParent = [&](
auto subelement,
unsigned destIndex,
156 unsigned sourceIndex) {
158 if (destIndex == 0) {
159 if (subelement.getInput().getType() == parentType)
160 sourceParent = subelement.getInput();
163 canUseSourceParent =
false;
168 canUseSourceParent &=
169 subelement.getInput() == sourceParent && destIndex == sourceIndex;
// Collect the source of every element connect, in element order.
172 for (
auto idx : llvm::seq(0u, (
unsigned)aggregateType.getNumElements())) {
173 auto src = subConnections[idx].getSrc();
// NOTE(review): the assert message misspells "guaranteed".
174 assert(src &&
"all subconnections are guranteed to exist");
175 operands.push_back(src);
// Once reuse has been ruled out, the remaining checks are presumably skipped
// (the statement this condition controls is not visible).
183 if (!canUseSourceParent)
// A source with no defining op (e.g. a block argument) rules out reuse.
187 if (!src.getDefiningOp()) {
188 canUseSourceParent =
false;
// Only subfield/subindex sources can be reassembled from their input value;
// any other defining op rules out reuse.
192 TypeSwitch<Operation *>(src.getDefiningOp())
193 .template Case<SubfieldOp>([&](SubfieldOp subfield) {
194 checkSourceParent(subfield, idx, subfield.getFieldIndex());
196 .
template Case<SubindexOp>([&](SubindexOp subindex) {
197 checkSourceParent(subindex, idx, subindex.getIndex());
199 .Default([&](
auto) { canUseSourceParent =
false; });
// Reuse path: every source is the matching element of `sourceParent`, so the
// element-wise connects are erased (presumably `sourceParent` is returned in
// lines not shown).
204 if (canUseSourceParent) {
205 LLVM_DEBUG(llvm::dbgs() <<
"Success to merge " << destFieldRef.getValue()
206 <<
" ,fieldID= " << destFieldRef.getFieldID()
207 <<
" to " << sourceParent <<
"\n";);
210 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size())))
212 subConnections[idx].erase();
// Unless aggressive merging is enabled, only build a new aggregate when every
// operand is constant-like (that flag is updated in lines not shown).
218 if (!enableAggressiveMerging && !areOperandsAllConstants)
// Create path: remember each merged connect's location, erase the connects,
// and build the aggregate at a location fused from all of them.
221 SmallVector<Location> locs;
223 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size()))) {
224 locs.push_back(subConnections[idx].getLoc());
228 subConnections[idx].erase();
231 return isa<FVectorType>(parentType)
232 ? builder->createOrFold<VectorCreateOp>(
233 builder->getFusedLoc(locs), parentType, operands)
234 : builder->createOrFold<BundleCreateOp>(
235 builder->getFusedLoc(locs), parentType, operands);
// Dispatch on the parent's aggregate kind.
239 if (
auto bundle = type_dyn_cast<BundleType>(parentType))
240 merged = getMergedValue(bundle);
241 if (
auto vector = type_dyn_cast<FVectorType>(parentType))
242 merged = getMergedValue(vector);
// Drive `parent` with the merged value. Use MatchingConnectOp when the
// parent's widths are all inferred; otherwise emit ConnectOp.
248 if (!parentBaseTy.hasUninferredWidth())
249 builder->create<MatchingConnectOp>(
connect.getLoc(), parent, merged);
251 builder->create<ConnectOp>(
connect.getLoc(), parent, merged);
// Body of MergeConnection::run() (its signature is not visible in this
// excerpt). It visits each MatchingConnectOp in the module body, then removes
// unused helper ops.
// NOTE(review): `builder` is pointed at this stack-local builder; unless it is
// reset in lines not shown, it dangles once this function returns.
257 ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
258 builder = &theBuilder;
259 auto *body = moduleOp.getBodyBlock();
// The iterator is advanced in lines not visible here.
261 for (
auto it = body->begin(), e = body->end(); it != e;) {
262 auto connectOp = dyn_cast<MatchingConnectOp>(*it);
// Any merged connect is created right after the connect being visited, and
// uses that connect's location as the builder's default location.
267 builder->setInsertionPointAfter(connectOp);
268 builder->setLoc(connectOp.getLoc());
// Presumably the visited connect is removed when this returns true.
269 bool removeOp = peelConnect(connectOp);
// Cleanup walks the body in reverse. make_early_inc_range allows the current op
// to be erased during iteration. Unused subfield/subindex/invalid/constant/
// bitcast ops are then erased; the full type list and the erase call are not
// visible.
276 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*body)))
277 if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
279 if (op.use_empty()) {
// MLIR pass wrapper: runOnOperation hands the operation it runs on to
// MergeConnection, whose constructor takes an FModuleOp.
287 struct MergeConnectionsPass
288 :
public circt::firrtl::impl::MergeConnectionsBase<MergeConnectionsPass> {
// Stores the flag in `enableAggressiveMerging`. That member is not declared
// here; it is presumably provided by the generated MergeConnectionsBase --
// confirm.
289 MergeConnectionsPass(
bool enableAggressiveMergingFlag) {
290 enableAggressiveMerging = enableAggressiveMergingFlag;
292 void runOnOperation()
override;
// Runs MergeConnection over this module and reports whether analyses are
// preserved.
297 void MergeConnectionsPass::runOnOperation() {
// Tail of a debug message naming the module (its start is not visible).
300 <<
"Module: '" << getOperation().
getName() <<
"'\n");
302 MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
303 bool changed = mergeConnection.run();
// Presumably reached only when `changed` is false (the guard is not visible).
306 return markAllAnalysesPreserved();
// Pass factory (the line with its name and parameter is not visible here).
// Creates a MergeConnectionsPass with the given aggressive-merging setting.
309 std::unique_ptr<mlir::Pass>
311 return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
def connect(destination, source)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)