CIRCT  20.0.0git
MergeConnections.cpp
Go to the documentation of this file.
1 //===- MergeConnections.cpp - Merge expanded connections --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass merges expanded connections into one connection.
9 // LowerTypes fully expands aggregate connections even when semantically
10 // not necessary to expand because it is required for ExpandWhen.
11 //
12 // More specifically this pass folds the following patterns:
13 // %dest(0) <= v0
14 // %dest(1) <= v1
15 // ...
16 // %dest(n) <= vn
17 // into
18 // %dest <= {vn, .., v1, v0}
19 // Also if v0, v1, .., vn are subfield ops like %a(0), %a(1), ..., %a(n), then we
20 // merge entire connections into %dest <= %a.
21 //
22 //===----------------------------------------------------------------------===//
23 
26 #include "mlir/Pass/Pass.h"
27 
30 #include "circt/Support/Debug.h"
31 #include "mlir/IR/ImplicitLocOpBuilder.h"
32 #include "llvm/ADT/TypeSwitch.h"
33 #include "llvm/Support/Debug.h"
34 
35 #define DEBUG_TYPE "firrtl-merge-connections"
36 
37 namespace circt {
38 namespace firrtl {
39 #define GEN_PASS_DEF_MERGECONNECTIONS
40 #include "circt/Dialect/FIRRTL/Passes.h.inc"
41 } // namespace firrtl
42 } // namespace circt
43 
44 using namespace circt;
45 using namespace firrtl;
46 
47 // Return true if value is essentially constant.
48 static bool isConstantLike(Value value) {
49  if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
50  return true;
51  if (auto bitcast = value.getDefiningOp<BitCastOp>())
52  return isConstant(bitcast.getInput());
53 
54  // TODO: Add unrealized_conversion, asUInt, asSInt
55  return false;
56 }
57 
58 namespace {
59 //===----------------------------------------------------------------------===//
60 // Pass Infrastructure
61 //===----------------------------------------------------------------------===//
62 
63 // A helper struct to merge connections.
64 struct MergeConnection {
65  MergeConnection(FModuleOp moduleOp, bool enableAggressiveMerging)
66  : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
67 
68  // Return true if something is changed.
69  bool run();
70  bool changed = false;
71 
72  // Return true if the given connect op is merged.
73  bool peelConnect(MatchingConnectOp connect);
74 
75  // A map from a destination FieldRef to a pair of (i) the number of
76  // connections seen so far and (ii) the vector to store subconnections.
77  DenseMap<FieldRef, std::pair<unsigned, SmallVector<MatchingConnectOp>>>
78  connections;
79 
80  FModuleOp moduleOp;
81  ImplicitLocOpBuilder *builder = nullptr;
82 
83  // If true, we merge connections even when source values will not be
84  // simplified.
85  bool enableAggressiveMerging = false;
86 };
87 
88 bool MergeConnection::peelConnect(MatchingConnectOp connect) {
89  // Ignore connections between different types because it will produce a
90  // partial connect. Also ignore non-passive connections or non-integer
91  // connections.
92  LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n");
93  auto destTy = type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType());
94  if (!destTy || !destTy.isPassive() ||
95  !firrtl::getBitWidth(destTy).has_value())
96  return false;
97 
98  auto destFieldRef = getFieldRefFromValue(connect.getDest());
99  auto destRoot = destFieldRef.getValue();
100 
101  // If dest is derived from mem op or has a ground type, we cannot merge them.
102  // If the connect's destination is a root value, we cannot merge.
103  if (destRoot.getDefiningOp<MemOp>() || destRoot == connect.getDest())
104  return false;
105 
106  Value parent;
107  unsigned index;
108  if (auto subfield = dyn_cast<SubfieldOp>(connect.getDest().getDefiningOp()))
109  parent = subfield.getInput(), index = subfield.getFieldIndex();
110  else if (auto subindex =
111  dyn_cast<SubindexOp>(connect.getDest().getDefiningOp()))
112  parent = subindex.getInput(), index = subindex.getIndex();
113  else
114  llvm_unreachable("unexpected destination");
115 
116  auto &countAndSubConnections = connections[getFieldRefFromValue(parent)];
117  auto &count = countAndSubConnections.first;
118  auto &subConnections = countAndSubConnections.second;
119 
120  // If it is the first time to visit the parent op, then allocate the vector
121  // for subconnections.
122  if (count == 0) {
123  if (auto bundle = type_dyn_cast<BundleType>(parent.getType()))
124  subConnections.resize(bundle.getNumElements());
125  if (auto vector = type_dyn_cast<FVectorType>(parent.getType()))
126  subConnections.resize(vector.getNumElements());
127  }
128  ++count;
129  subConnections[index] = connect;
130 
131  // If we haven't visited all subconnections, stop at this point.
132  if (count != subConnections.size())
133  return false;
134 
135  auto parentType = parent.getType();
136  auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
137 
138  // Reject if not passive, we don't support aggregate constants for these.
139  if (!parentBaseTy || !parentBaseTy.isPassive())
140  return false;
141 
142  changed = true;
143 
144  auto getMergedValue = [&](auto aggregateType) {
145  SmallVector<Value> operands;
146 
147  // This flag tracks whether we can use the parent of source values as the
148  // merged value.
149  bool canUseSourceParent = true;
150  bool areOperandsAllConstants = true;
151 
152  // The value which might be used as a merged value.
153  Value sourceParent;
154 
155  auto checkSourceParent = [&](auto subelement, unsigned destIndex,
156  unsigned sourceIndex) {
157  // In the first iteration, register a parent value.
158  if (destIndex == 0) {
159  if (subelement.getInput().getType() == parentType)
160  sourceParent = subelement.getInput();
161  else {
162  // If types are not same, it is not possible to use it.
163  canUseSourceParent = false;
164  }
165  }
166 
167  // Check that input is the same as `sourceAggregate` and indexes match.
168  canUseSourceParent &=
169  subelement.getInput() == sourceParent && destIndex == sourceIndex;
170  };
171 
172  for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) {
173  auto src = subConnections[idx].getSrc();
174  assert(src && "all subconnections are guranteed to exist");
175  operands.push_back(src);
176 
177  areOperandsAllConstants &= isConstantLike(src);
178 
179  // From here, check whether the value is derived from the same aggregate
180  // value.
181 
182  // If canUseSourceParent is already false, abort.
183  if (!canUseSourceParent)
184  continue;
185 
186  // If the value is an argument, it is not derived from an aggregate value.
187  if (!src.getDefiningOp()) {
188  canUseSourceParent = false;
189  continue;
190  }
191 
192  TypeSwitch<Operation *>(src.getDefiningOp())
193  .template Case<SubfieldOp>([&](SubfieldOp subfield) {
194  checkSourceParent(subfield, idx, subfield.getFieldIndex());
195  })
196  .template Case<SubindexOp>([&](SubindexOp subindex) {
197  checkSourceParent(subindex, idx, subindex.getIndex());
198  })
199  .Default([&](auto) { canUseSourceParent = false; });
200  }
201 
202  // If it is fine to use `sourceParent` as a merged value, we just
203  // return it.
204  if (canUseSourceParent) {
205  LLVM_DEBUG(llvm::dbgs() << "Success to merge " << destFieldRef.getValue()
206  << " ,fieldID= " << destFieldRef.getFieldID()
207  << " to " << sourceParent << "\n";);
208  // Erase connections except for subConnections[index] since it must be
209  // erased at the top-level loop.
210  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size())))
211  if (idx != index)
212  subConnections[idx].erase();
213  return sourceParent;
214  }
215 
216  // If operands are not all constants, we don't merge connections unless
217  // "aggressive-merging" option is enabled.
218  if (!enableAggressiveMerging && !areOperandsAllConstants)
219  return Value();
220 
221  SmallVector<Location> locs;
222  // Otherwise, we concat all values and cast them into the aggregate type.
223  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size()))) {
224  locs.push_back(subConnections[idx].getLoc());
225  // Erase connections except for subConnections[index] since it must be
226  // erased at the top-level loop.
227  if (idx != index)
228  subConnections[idx].erase();
229  }
230 
231  return isa<FVectorType>(parentType)
232  ? builder->createOrFold<VectorCreateOp>(
233  builder->getFusedLoc(locs), parentType, operands)
234  : builder->createOrFold<BundleCreateOp>(
235  builder->getFusedLoc(locs), parentType, operands);
236  };
237 
238  Value merged;
239  if (auto bundle = type_dyn_cast<BundleType>(parentType))
240  merged = getMergedValue(bundle);
241  if (auto vector = type_dyn_cast<FVectorType>(parentType))
242  merged = getMergedValue(vector);
243  if (!merged)
244  return false;
245 
246  // Emit strict connect if possible, fallback to normal connect.
247  // Don't use emitConnect(), will split the connect apart.
248  if (!parentBaseTy.hasUninferredWidth())
249  builder->create<MatchingConnectOp>(connect.getLoc(), parent, merged);
250  else
251  builder->create<ConnectOp>(connect.getLoc(), parent, merged);
252 
253  return true;
254 }
255 
256 bool MergeConnection::run() {
257  ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
258  builder = &theBuilder;
259  auto *body = moduleOp.getBodyBlock();
260  // Merge connections by forward iterations.
261  for (auto it = body->begin(), e = body->end(); it != e;) {
262  auto connectOp = dyn_cast<MatchingConnectOp>(*it);
263  if (!connectOp) {
264  it++;
265  continue;
266  }
267  builder->setInsertionPointAfter(connectOp);
268  builder->setLoc(connectOp.getLoc());
269  bool removeOp = peelConnect(connectOp);
270  ++it;
271  if (removeOp)
272  connectOp.erase();
273  }
274 
275  // Clean up dead operations introduced by this pass.
276  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body)))
277  if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
278  CatPrimOp>(op))
279  if (op.use_empty()) {
280  changed = true;
281  op.erase();
282  }
283 
284  return changed;
285 }
286 
287 struct MergeConnectionsPass
288  : public circt::firrtl::impl::MergeConnectionsBase<MergeConnectionsPass> {
289  MergeConnectionsPass(bool enableAggressiveMergingFlag) {
290  enableAggressiveMerging = enableAggressiveMergingFlag;
291  }
292  void runOnOperation() override;
293 };
294 
295 } // namespace
296 
297 void MergeConnectionsPass::runOnOperation() {
298  LLVM_DEBUG(debugPassHeader(this)
299  << "\n"
300  << "Module: '" << getOperation().getName() << "'\n");
301 
302  MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
303  bool changed = mergeConnection.run();
304 
305  if (!changed)
306  return markAllAnalysesPreserved();
307 }
308 
309 std::unique_ptr<mlir::Pass>
310 circt::firrtl::createMergeConnectionsPass(bool enableAggressiveMerging) {
311  return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
312 }
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
def connect(destination, source)
Definition: support.py:39
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4616
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition: codegen.py:121