27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/Debug.h"
31 #define DEBUG_TYPE "firrtl-merge-connections"
33 using namespace circt;
34 using namespace firrtl;
38 if (isa_and_nonnull<ConstantOp, InvalidValueOp>(
value.getDefiningOp()))
40 if (
auto bitcast =
value.getDefiningOp<BitCastOp>())
53 struct MergeConnection {
54 MergeConnection(FModuleOp moduleOp,
bool enableAggressiveMerging)
55 : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
62 bool peelConnect(StrictConnectOp
connect);
66 DenseMap<FieldRef, std::pair<unsigned, SmallVector<StrictConnectOp>>>
70 ImplicitLocOpBuilder *
builder =
nullptr;
74 bool enableAggressiveMerging =
false;
77 bool MergeConnection::peelConnect(StrictConnectOp
connect) {
82 auto destTy = type_dyn_cast<FIRRTLBaseType>(
connect.getDest().getType());
83 if (!destTy || !destTy.isPassive() ||
88 auto destRoot = destFieldRef.getValue();
92 if (destRoot.getDefiningOp<MemOp>() || destRoot ==
connect.getDest())
97 if (
auto subfield = dyn_cast<SubfieldOp>(
connect.getDest().getDefiningOp()))
98 parent = subfield.getInput(), index = subfield.getFieldIndex();
99 else if (
auto subindex =
100 dyn_cast<SubindexOp>(
connect.getDest().getDefiningOp()))
101 parent = subindex.getInput(), index = subindex.getIndex();
103 llvm_unreachable(
"unexpected destination");
106 auto &count = countAndSubConnections.first;
107 auto &subConnections = countAndSubConnections.second;
112 if (
auto bundle = type_dyn_cast<BundleType>(parent.getType()))
113 subConnections.resize(bundle.getNumElements());
114 if (
auto vector = type_dyn_cast<FVectorType>(parent.getType()))
115 subConnections.resize(vector.getNumElements());
118 subConnections[index] =
connect;
121 if (count != subConnections.size())
124 auto parentType = parent.getType();
125 auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
128 if (!parentBaseTy || !parentBaseTy.isPassive())
133 auto getMergedValue = [&](
auto aggregateType) {
134 SmallVector<Value> operands;
138 bool canUseSourceParent =
true;
139 bool areOperandsAllConstants =
true;
144 auto checkSourceParent = [&](
auto subelement,
unsigned destIndex,
145 unsigned sourceIndex) {
147 if (destIndex == 0) {
148 if (subelement.getInput().getType() == parentType)
149 sourceParent = subelement.getInput();
152 canUseSourceParent =
false;
157 canUseSourceParent &=
158 subelement.getInput() == sourceParent && destIndex == sourceIndex;
161 for (
auto idx : llvm::seq(0u, (
unsigned)aggregateType.getNumElements())) {
162 auto src = subConnections[idx].getSrc();
163 assert(src &&
"all subconnections are guranteed to exist");
164 operands.push_back(src);
172 if (!canUseSourceParent)
176 if (!src.getDefiningOp()) {
177 canUseSourceParent =
false;
181 TypeSwitch<Operation *>(src.getDefiningOp())
182 .template Case<SubfieldOp>([&](SubfieldOp subfield) {
183 checkSourceParent(subfield, idx, subfield.getFieldIndex());
185 .
template Case<SubindexOp>([&](SubindexOp subindex) {
186 checkSourceParent(subindex, idx, subindex.getIndex());
188 .Default([&](
auto) { canUseSourceParent =
false; });
193 if (canUseSourceParent) {
194 LLVM_DEBUG(
llvm::dbgs() <<
"Success to merge " << destFieldRef.getValue()
195 <<
" ,fieldID= " << destFieldRef.getFieldID()
196 <<
" to " << sourceParent <<
"\n";);
199 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size())))
201 subConnections[idx].erase();
207 if (!enableAggressiveMerging && !areOperandsAllConstants)
210 SmallVector<Location> locs;
212 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size()))) {
213 locs.push_back(subConnections[idx].getLoc());
217 subConnections[idx].erase();
220 return isa<FVectorType>(parentType)
221 ?
builder->createOrFold<VectorCreateOp>(
222 builder->getFusedLoc(locs), parentType, operands)
223 :
builder->createOrFold<BundleCreateOp>(
224 builder->getFusedLoc(locs), parentType, operands);
228 if (
auto bundle = type_dyn_cast<BundleType>(parentType))
229 merged = getMergedValue(bundle);
230 if (
auto vector = type_dyn_cast<FVectorType>(parentType))
231 merged = getMergedValue(vector);
237 if (!parentBaseTy.hasUninferredWidth())
238 builder->create<StrictConnectOp>(
connect.getLoc(), parent, merged);
245 bool MergeConnection::run() {
246 ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
248 auto *body = moduleOp.getBodyBlock();
250 for (
auto it = body->begin(), e = body->end(); it != e;) {
251 auto connectOp = dyn_cast<StrictConnectOp>(*it);
256 builder->setInsertionPointAfter(connectOp);
257 builder->setLoc(connectOp.getLoc());
258 bool removeOp = peelConnect(connectOp);
265 for (
auto &op : llvm::make_early_inc_range(llvm::reverse(*body)))
266 if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
268 if (op.use_empty()) {
276 struct MergeConnectionsPass
277 :
public MergeConnectionsBase<MergeConnectionsPass> {
278 MergeConnectionsPass(
bool enableAggressiveMergingFlag) {
279 enableAggressiveMerging = enableAggressiveMergingFlag;
281 void runOnOperation()
override;
286 void MergeConnectionsPass::runOnOperation() {
287 LLVM_DEBUG(
llvm::dbgs() <<
"===----- Running MergeConnections "
288 "--------------------------------------===\n"
289 <<
"Module: '" << getOperation().
getName() <<
"'\n";);
291 MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
292 bool changed = mergeConnection.run();
295 return markAllAnalysesPreserved();
298 std::unique_ptr<mlir::Pass>
300 return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
def connect(destination, source)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
mlir::raw_indented_ostream & dbgs()