CIRCT  19.0.0git
MergeConnections.cpp
Go to the documentation of this file.
1 //===- MergeConnections.cpp - Merge expanded connections --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass merges expanded connections into one connection.
9 // LowerTypes fully expands aggregate connections, even when that is not
10 // semantically necessary, because ExpandWhens requires it.
11 //
12 // More specifically this pass folds the following patterns:
13 // %dest(0) <= v0
14 // %dest(1) <= v1
15 // ...
16 // %dest(n) <= vn
17 // into
18 // %dest <= {vn, .., v1, v0}
19 // Also, if v0, v1, ..., vn are subfield ops like %a(0), %a(1), ..., %a(n), we
20 // merge the entire connection into %dest <= %a.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "PassDetails.h"
25 
28 #include "circt/Support/Debug.h"
29 #include "mlir/IR/ImplicitLocOpBuilder.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "firrtl-merge-connections"
34 
35 using namespace circt;
36 using namespace firrtl;
37 
38 // Return true if value is essentially constant.
39 static bool isConstantLike(Value value) {
40  if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
41  return true;
42  if (auto bitcast = value.getDefiningOp<BitCastOp>())
43  return isConstant(bitcast.getInput());
44 
45  // TODO: Add unrealized_conversion, asUInt, asSInt
46  return false;
47 }
48 
49 namespace {
50 //===----------------------------------------------------------------------===//
51 // Pass Infrastructure
52 //===----------------------------------------------------------------------===//
53 
54 // A helper struct to merge connections.
55 struct MergeConnection {
56  MergeConnection(FModuleOp moduleOp, bool enableAggressiveMerging)
57  : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
58 
59  // Return true if something is changed.
60  bool run();
61  bool changed = false;
62 
63  // Return true if the given connect op is merged.
64  bool peelConnect(StrictConnectOp connect);
65 
66  // A map from a destination FieldRef to a pair of (i) the number of
67  // connections seen so far and (ii) the vector to store subconnections.
68  DenseMap<FieldRef, std::pair<unsigned, SmallVector<StrictConnectOp>>>
69  connections;
70 
71  FModuleOp moduleOp;
72  ImplicitLocOpBuilder *builder = nullptr;
73 
74  // If true, we merge connections even when source values will not be
75  // simplified.
76  bool enableAggressiveMerging = false;
77 };
78 
79 bool MergeConnection::peelConnect(StrictConnectOp connect) {
80  // Ignore connections between different types because it will produce a
81  // partial connect. Also ignore non-passive connections or non-integer
82  // connections.
83  LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n");
84  auto destTy = type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType());
85  if (!destTy || !destTy.isPassive() ||
86  !firrtl::getBitWidth(destTy).has_value())
87  return false;
88 
89  auto destFieldRef = getFieldRefFromValue(connect.getDest());
90  auto destRoot = destFieldRef.getValue();
91 
92  // If dest is derived from mem op or has a ground type, we cannot merge them.
93  // If the connect's destination is a root value, we cannot merge.
94  if (destRoot.getDefiningOp<MemOp>() || destRoot == connect.getDest())
95  return false;
96 
97  Value parent;
98  unsigned index;
99  if (auto subfield = dyn_cast<SubfieldOp>(connect.getDest().getDefiningOp()))
100  parent = subfield.getInput(), index = subfield.getFieldIndex();
101  else if (auto subindex =
102  dyn_cast<SubindexOp>(connect.getDest().getDefiningOp()))
103  parent = subindex.getInput(), index = subindex.getIndex();
104  else
105  llvm_unreachable("unexpected destination");
106 
107  auto &countAndSubConnections = connections[getFieldRefFromValue(parent)];
108  auto &count = countAndSubConnections.first;
109  auto &subConnections = countAndSubConnections.second;
110 
111  // If it is the first time to visit the parent op, then allocate the vector
112  // for subconnections.
113  if (count == 0) {
114  if (auto bundle = type_dyn_cast<BundleType>(parent.getType()))
115  subConnections.resize(bundle.getNumElements());
116  if (auto vector = type_dyn_cast<FVectorType>(parent.getType()))
117  subConnections.resize(vector.getNumElements());
118  }
119  ++count;
120  subConnections[index] = connect;
121 
122  // If we haven't visited all subconnections, stop at this point.
123  if (count != subConnections.size())
124  return false;
125 
126  auto parentType = parent.getType();
127  auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
128 
129  // Reject if not passive, we don't support aggregate constants for these.
130  if (!parentBaseTy || !parentBaseTy.isPassive())
131  return false;
132 
133  changed = true;
134 
135  auto getMergedValue = [&](auto aggregateType) {
136  SmallVector<Value> operands;
137 
138  // This flag tracks whether we can use the parent of source values as the
139  // merged value.
140  bool canUseSourceParent = true;
141  bool areOperandsAllConstants = true;
142 
143  // The value which might be used as a merged value.
144  Value sourceParent;
145 
146  auto checkSourceParent = [&](auto subelement, unsigned destIndex,
147  unsigned sourceIndex) {
148  // In the first iteration, register a parent value.
149  if (destIndex == 0) {
150  if (subelement.getInput().getType() == parentType)
151  sourceParent = subelement.getInput();
152  else {
153  // If types are not same, it is not possible to use it.
154  canUseSourceParent = false;
155  }
156  }
157 
158  // Check that input is the same as `sourceAggregate` and indexes match.
159  canUseSourceParent &=
160  subelement.getInput() == sourceParent && destIndex == sourceIndex;
161  };
162 
163  for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) {
164  auto src = subConnections[idx].getSrc();
165  assert(src && "all subconnections are guranteed to exist");
166  operands.push_back(src);
167 
168  areOperandsAllConstants &= isConstantLike(src);
169 
170  // From here, check whether the value is derived from the same aggregate
171  // value.
172 
173  // If canUseSourceParent is already false, abort.
174  if (!canUseSourceParent)
175  continue;
176 
177  // If the value is an argument, it is not derived from an aggregate value.
178  if (!src.getDefiningOp()) {
179  canUseSourceParent = false;
180  continue;
181  }
182 
183  TypeSwitch<Operation *>(src.getDefiningOp())
184  .template Case<SubfieldOp>([&](SubfieldOp subfield) {
185  checkSourceParent(subfield, idx, subfield.getFieldIndex());
186  })
187  .template Case<SubindexOp>([&](SubindexOp subindex) {
188  checkSourceParent(subindex, idx, subindex.getIndex());
189  })
190  .Default([&](auto) { canUseSourceParent = false; });
191  }
192 
193  // If it is fine to use `sourceParent` as a merged value, we just
194  // return it.
195  if (canUseSourceParent) {
196  LLVM_DEBUG(llvm::dbgs() << "Success to merge " << destFieldRef.getValue()
197  << " ,fieldID= " << destFieldRef.getFieldID()
198  << " to " << sourceParent << "\n";);
199  // Erase connections except for subConnections[index] since it must be
200  // erased at the top-level loop.
201  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size())))
202  if (idx != index)
203  subConnections[idx].erase();
204  return sourceParent;
205  }
206 
207  // If operands are not all constants, we don't merge connections unless
208  // "aggressive-merging" option is enabled.
209  if (!enableAggressiveMerging && !areOperandsAllConstants)
210  return Value();
211 
212  SmallVector<Location> locs;
213  // Otherwise, we concat all values and cast them into the aggregate type.
214  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size()))) {
215  locs.push_back(subConnections[idx].getLoc());
216  // Erase connections except for subConnections[index] since it must be
217  // erased at the top-level loop.
218  if (idx != index)
219  subConnections[idx].erase();
220  }
221 
222  return isa<FVectorType>(parentType)
223  ? builder->createOrFold<VectorCreateOp>(
224  builder->getFusedLoc(locs), parentType, operands)
225  : builder->createOrFold<BundleCreateOp>(
226  builder->getFusedLoc(locs), parentType, operands);
227  };
228 
229  Value merged;
230  if (auto bundle = type_dyn_cast<BundleType>(parentType))
231  merged = getMergedValue(bundle);
232  if (auto vector = type_dyn_cast<FVectorType>(parentType))
233  merged = getMergedValue(vector);
234  if (!merged)
235  return false;
236 
237  // Emit strict connect if possible, fallback to normal connect.
238  // Don't use emitConnect(), will split the connect apart.
239  if (!parentBaseTy.hasUninferredWidth())
240  builder->create<StrictConnectOp>(connect.getLoc(), parent, merged);
241  else
242  builder->create<ConnectOp>(connect.getLoc(), parent, merged);
243 
244  return true;
245 }
246 
247 bool MergeConnection::run() {
248  ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
249  builder = &theBuilder;
250  auto *body = moduleOp.getBodyBlock();
251  // Merge connections by forward iterations.
252  for (auto it = body->begin(), e = body->end(); it != e;) {
253  auto connectOp = dyn_cast<StrictConnectOp>(*it);
254  if (!connectOp) {
255  it++;
256  continue;
257  }
258  builder->setInsertionPointAfter(connectOp);
259  builder->setLoc(connectOp.getLoc());
260  bool removeOp = peelConnect(connectOp);
261  ++it;
262  if (removeOp)
263  connectOp.erase();
264  }
265 
266  // Clean up dead operations introduced by this pass.
267  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body)))
268  if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
269  CatPrimOp>(op))
270  if (op.use_empty()) {
271  changed = true;
272  op.erase();
273  }
274 
275  return changed;
276 }
277 
278 struct MergeConnectionsPass
279  : public MergeConnectionsBase<MergeConnectionsPass> {
280  MergeConnectionsPass(bool enableAggressiveMergingFlag) {
281  enableAggressiveMerging = enableAggressiveMergingFlag;
282  }
283  void runOnOperation() override;
284 };
285 
286 } // namespace
287 
288 void MergeConnectionsPass::runOnOperation() {
289  LLVM_DEBUG(debugPassHeader(this)
290  << "\n"
291  << "Module: '" << getOperation().getName() << "'\n");
292 
293  MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
294  bool changed = mergeConnection.run();
295 
296  if (!changed)
297  return markAllAnalysesPreserved();
298 }
299 
300 std::unique_ptr<mlir::Pass>
301 circt::firrtl::createMergeConnectionsPass(bool enableAggressiveMerging) {
302  return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
303 }
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
Builder builder
def connect(destination, source)
Definition: support.py:37
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4494
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31