28 #include "mlir/IR/Iterators.h"
29 #include "mlir/Pass/Pass.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
33 #define DEBUG_TYPE "firrtl-merge-connections"
37 #define GEN_PASS_DEF_MERGECONNECTIONS
38 #include "circt/Dialect/FIRRTL/Passes.h.inc"
42 using namespace circt;
43 using namespace firrtl;
// Fragment of a constant-likeness test on `value`. The enclosing signature and
// the statements executed on each match fall on lines missing from this
// listing (presumably this is the body of isConstantLike - TODO confirm).
//
// First test: `value` is produced by a ConstantOp or an InvalidValueOp.
// isa_and_nonnull is used because getDefiningOp() can return null.
47 if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
// Second test: `value` is produced by a BitCastOp, bound here as `bitcast`.
49 if (
auto bitcast = value.getDefiningOp<BitCastOp>())
// Helper that carries the per-module state used while merging connections.
// NOTE(review): several member declarations and the closing brace fall on
// lines missing from this listing.
62 struct MergeConnection {
// Records the module to process and whether aggressive merging is enabled.
63 MergeConnection(FModuleOp moduleOp,
bool enableAggressiveMerging)
64 : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
// Handles a single MatchingConnectOp. The caller stores the result as a
// "remove this op" flag (`bool removeOp = peelConnect(connectOp)`).
71 bool peelConnect(MatchingConnectOp
connect);
// Bookkeeping keyed by FieldRef: an unsigned counter paired with a list of
// MatchingConnectOps. The member's name is on a line not shown here.
75 DenseMap<FieldRef, std::pair<unsigned, SmallVector<MatchingConnectOp>>>
// Non-owning builder pointer; null until it is pointed at a function-local
// ImplicitLocOpBuilder (`builder = &theBuilder`).
79 ImplicitLocOpBuilder *builder =
nullptr;
// Defaults to false. Consulted together with `areOperandsAllConstants` in the
// merge logic of peelConnect.
83 bool enableAggressiveMerging =
false;
// Processes one MatchingConnectOp whose destination is a subfield/subindex of
// an aggregate: the connect is recorded in a per-parent slot table, and once
// the table is complete a single connect to the whole parent is created.
//
// NOTE(review): this listing drops many lines of the function (early returns,
// declarations of `destFieldRef`, `parent`, `index`, `countAndSubConnections`,
// `sourceParent`, `merged`, closing braces, the final return). Comments below
// describe only the visible statements; anything about the missing lines is
// marked as a presumption.
86 bool MergeConnection::peelConnect(MatchingConnectOp
connect) {
// Debug trace of the connect being visited.
90 LLVM_DEBUG(llvm::dbgs() <<
"Visiting " <<
connect <<
"\n");
// Destination type viewed as a FIRRTL base type; null if it is not one.
91 auto destTy = type_dyn_cast<FIRRTLBaseType>(
connect.getDest().getType());
// Guard: the destination must be a passive base type. The remainder of the
// condition and the guarded statement are not visible (presumably an early
// return).
92 if (!destTy || !destTy.isPassive() ||
// Root value of the destination's field reference (`destFieldRef` is
// declared on a line not shown).
97 auto destRoot = destFieldRef.getValue();
// Guard: the root is defined by a MemOp, or the destination is itself the
// root (i.e. not a sub-element). The guarded statement is not visible.
101 if (destRoot.getDefiningOp<MemOp>() || destRoot ==
connect.getDest())
// Find the immediate parent aggregate and the element position from the op
// that defines the destination: SubfieldOp gives a field index, SubindexOp
// an element index.
106 if (
auto subfield = dyn_cast<SubfieldOp>(
connect.getDest().getDefiningOp()))
107 parent = subfield.getInput(), index = subfield.getFieldIndex();
108 else if (
auto subindex =
109 dyn_cast<SubindexOp>(
connect.getDest().getDefiningOp()))
110 parent = subindex.getInput(), index = subindex.getIndex();
// Any other defining op is treated as impossible here (the `else` that
// introduces this statement is presumably on a missing line).
112 llvm_unreachable(
"unexpected destination");
// Aliases for the two halves of the bookkeeping pair.
115 auto &count = countAndSubConnections.first;
116 auto &subConnections = countAndSubConnections.second;
// Size the slot table to the parent's element count, for either a bundle or a
// vector parent.
121 if (
auto bundle = type_dyn_cast<BundleType>(parent.getType()))
122 subConnections.resize(bundle.getNumElements());
123 if (
auto vector = type_dyn_cast<FVectorType>(parent.getType()))
124 subConnections.resize(vector.getNumElements());
// Record this connect in the slot for its element position.
127 subConnections[index] =
connect;
// Compares the counter with the number of slots; the guarded statement is not
// visible (presumably returns until every element has been seen).
130 if (count != subConnections.size())
133 auto parentType = parent.getType();
134 auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
// Guard: the parent must also be a passive base type. The guarded statement
// is not visible.
137 if (!parentBaseTy || !parentBaseTy.isPassive())
// Generic lambda, called below with a BundleType or an FVectorType. It
// gathers the sources of all recorded sub-connections; its visible result is
// a newly built aggregate of those sources.
142 auto getMergedValue = [&](
auto aggregateType) {
// Sources of the sub-connections, in element order.
143 SmallVector<Value> operands;
// Stays true while every source seen so far is the matching element of one
// common value.
147 bool canUseSourceParent =
true;
// NOTE(review): the code that clears this flag is on lines not shown.
148 bool areOperandsAllConstants =
true;
// Helper for a source defined by a subfield/subindex op: checks whether it
// is element `sourceIndex` of the shared source aggregate.
153 auto checkSourceParent = [&](
auto subelement,
unsigned destIndex,
154 unsigned sourceIndex) {
// For the first element (destIndex == 0): adopt its input as `sourceParent`
// when the input's type equals the destination parent's type.
156 if (destIndex == 0) {
157 if (subelement.getInput().getType() == parentType)
158 sourceParent = subelement.getInput();
// Clears the flag; presumably this is the type-mismatch branch (the
// surrounding `else` is not visible).
161 canUseSourceParent =
false;
// The flag stays set only if this source comes from `sourceParent` and is
// taken from the same position it is written to.
166 canUseSourceParent &=
167 subelement.getInput() == sourceParent && destIndex == sourceIndex;
// Visit every element of the aggregate in index order.
170 for (
auto idx : llvm::seq(0u, (
unsigned)aggregateType.getNumElements())) {
171 auto src = subConnections[idx].getSrc();
// NOTE(review): "guranteed" in the assert message is a typo for
// "guaranteed"; the string is left untouched here.
172 assert(src &&
"all subconnections are guranteed to exist");
173 operands.push_back(src);
// Once the flag has been cleared the parent check is moot (the guarded
// statement is not visible; presumably `continue`).
181 if (!canUseSourceParent)
// A source with no defining op cannot be a subfield/subindex of anything.
185 if (!src.getDefiningOp()) {
186 canUseSourceParent =
false;
// Dispatch on the op defining the source: subfield and subindex ops go
// through checkSourceParent, every other op clears the flag.
190 TypeSwitch<Operation *>(src.getDefiningOp())
191 .template Case<SubfieldOp>([&](SubfieldOp subfield) {
192 checkSourceParent(subfield, idx, subfield.getFieldIndex());
194 .
template Case<SubindexOp>([&](SubindexOp subindex) {
195 checkSourceParent(subindex, idx, subindex.getIndex());
197 .Default([&](
auto) { canUseSourceParent =
false; });
// Case 1: every source is the matching element of `sourceParent`. Log it and
// erase the element-wise connects (presumably `sourceParent` is then
// returned; that line is not visible).
202 if (canUseSourceParent) {
203 LLVM_DEBUG(llvm::dbgs() <<
"Success to merge " << destFieldRef.getValue()
204 <<
" ,fieldID= " << destFieldRef.getFieldID()
205 <<
" to " << sourceParent <<
"\n";);
208 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size())))
210 subConnections[idx].erase();
// Guard for the non-aggressive mode when not all operands are constants; the
// guarded statement is not visible.
216 if (!enableAggressiveMerging && !areOperandsAllConstants)
// Case 2: build a new aggregate. Collect each connect's location for a fused
// location, then erase the connect.
219 SmallVector<Location> locs;
221 for (
auto idx : llvm::seq(0u,
static_cast<unsigned>(operands.size()))) {
222 locs.push_back(subConnections[idx].getLoc());
226 subConnections[idx].erase();
// Build (or fold) a vector or bundle of all operands, chosen by the parent's
// type.
229 return isa<FVectorType>(parentType)
230 ? builder->createOrFold<VectorCreateOp>(
231 builder->getFusedLoc(locs), parentType, operands)
232 : builder->createOrFold<BundleCreateOp>(
233 builder->getFusedLoc(locs), parentType, operands);
// Instantiate the lambda for whichever aggregate kind the parent is
// (`merged` is declared on a line not shown).
237 if (
auto bundle = type_dyn_cast<BundleType>(parentType))
238 merged = getMergedValue(bundle);
239 if (
auto vector = type_dyn_cast<FVectorType>(parentType))
240 merged = getMergedValue(vector);
// Emit the connect to the whole parent. MatchingConnectOp is created when the
// parent type has no uninferred width; the `else` before the ConnectOp
// creation is on a line not shown (presumably ConnectOp is the fallback).
246 if (!parentBaseTy.hasUninferredWidth())
247 builder->create<MatchingConnectOp>(
connect.getLoc(), parent, merged);
249 builder->create<ConnectOp>(
connect.getLoc(), parent, merged);
// Fragment of the driver that visits the connects of the module. Its
// signature line and many interior lines are missing from this listing
// (presumably this is MergeConnection::run, which the pass calls - TODO
// confirm).
//
// Point the shared `builder` member at a function-local builder so that
// peelConnect can create ops through it.
255 ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
256 builder = &theBuilder;
// Explicit stack of (current, end) block-iterator pairs.
259 SmallVector<std::pair<Block::iterator, Block::iterator>> worklist;
// Seed the stack with the module's body block.
265 auto *body = moduleOp.getBodyBlock();
266 worklist.push_back({body->begin(), body->end()});
267 while (!worklist.empty()) {
// NOTE(review): `it` and `e` are references into `worklist`; the push_back
// below can reallocate the vector, so confirm they are not used afterwards
// without being re-fetched.
268 auto &[it, e] = worklist.back();
// Flag for "the current op has nested blocks"; its update and use are on lines
// not shown.
271 bool opWithBlocks =
false;
// Push every block nested under the current op, iterating regions and blocks
// in reverse.
// NOTE(review): `®ion` below looks like a mis-decoded `&region` - confirm
// against the upstream file.
274 for (
auto ®ion : llvm::reverse(it->getRegions()))
275 for (
auto &block : llvm::reverse(region.getBlocks())) {
276 worklist.push_back({block.begin(), block.end()});
// Only MatchingConnectOps are handed to peelConnect; the null check on
// `connectOp` is presumably on a missing line.
287 auto connectOp = dyn_cast<MatchingConnectOp>(*it);
// Create any new ops right after this connect and with its location.
292 builder->setInsertionPointAfter(connectOp);
293 builder->setLoc(connectOp.getLoc());
// The returned flag is stored as `removeOp`; how it is acted on is not
// visible.
294 bool removeOp = peelConnect(connectOp);
// Post-order, reverse-iteration walk over the module. Ops of the listed kinds
// (the list continues on a line not shown) are tested with use_empty(); what
// is done with unused ops is not visible (presumably they are erased).
309 moduleOp.walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
311 if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
313 if (op->use_empty()) {
// Pass wrapper built on the generated MergeConnectionsBase class.
// NOTE(review): the constructor's closing brace and the struct's closing
// brace fall on lines missing from this listing.
322 struct MergeConnectionsPass
323 :
public circt::firrtl::impl::MergeConnectionsBase<MergeConnectionsPass> {
// Copies the flag into `enableAggressiveMerging`, which is not declared in
// this struct (presumably an option member of the generated base class).
324 MergeConnectionsPass(
bool enableAggressiveMergingFlag) {
325 enableAggressiveMerging = enableAggressiveMergingFlag;
// Pass entry point; defined out of line.
327 void runOnOperation()
override;
// Runs the merge over the operation this pass is anchored on.
// NOTE(review): some lines of this function are missing from this listing.
332 void MergeConnectionsPass::runOnOperation() {
// Tail of a debug-stream statement that prints the name of the current
// operation (the start of the statement is on a line not shown).
335 <<
"Module: '" << getOperation().
getName() <<
"'\n");
// Delegate the work to the helper, forwarding the aggressive-merging option.
337 MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
338 bool changed = mergeConnection.run();
// Marks all analyses as preserved; the condition guarding this return is on
// a line not shown (presumably `if (!changed)`).
341 return markAllAnalysesPreserved();
// Factory fragment: returns a newly allocated MergeConnectionsPass built with
// `enableAggressiveMerging`. The line carrying the function name and
// parameter list is missing from this listing (presumably
// createMergeConnectionsPass - TODO confirm).
344 std::unique_ptr<mlir::Pass>
346 return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
def connect(destination, source)
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)