CIRCT  20.0.0git
MergeConnections.cpp
Go to the documentation of this file.
1 //===- MergeConnections.cpp - Merge expanded connections --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass merges expanded connections into one connection.
9 // LowerTypes fully expands aggregate connections even when semantically
10 // not necessary to expand because it is required for ExpandWhen.
11 //
12 // More specifically this pass folds the following patterns:
13 // %dest(0) <= v0
14 // %dest(1) <= v1
15 // ...
16 // %dest(n) <= vn
17 // into
18 // %dest <= {vn, .., v1, v0}
19 // Also if v0, v1, .., vn are subfield op like %a(0), %a(1), ..., a(n), then we
20 // merge entire connections into %dest <= %a.
21 //
22 //===----------------------------------------------------------------------===//
23 
27 #include "circt/Support/Debug.h"
28 #include "mlir/IR/Iterators.h"
29 #include "mlir/Pass/Pass.h"
30 #include "llvm/ADT/TypeSwitch.h"
31 #include "llvm/Support/Debug.h"
32 
33 #define DEBUG_TYPE "firrtl-merge-connections"
34 
35 namespace circt {
36 namespace firrtl {
37 #define GEN_PASS_DEF_MERGECONNECTIONS
38 #include "circt/Dialect/FIRRTL/Passes.h.inc"
39 } // namespace firrtl
40 } // namespace circt
41 
42 using namespace circt;
43 using namespace firrtl;
44 
45 // Return true if value is essentially constant.
46 static bool isConstantLike(Value value) {
47  if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
48  return true;
49  if (auto bitcast = value.getDefiningOp<BitCastOp>())
50  return isConstant(bitcast.getInput());
51 
52  // TODO: Add unrealized_conversion, asUInt, asSInt
53  return false;
54 }
55 
56 namespace {
57 //===----------------------------------------------------------------------===//
58 // Pass Infrastructure
59 //===----------------------------------------------------------------------===//
60 
61 // A helper struct to merge connections.
62 struct MergeConnection {
63  MergeConnection(FModuleOp moduleOp, bool enableAggressiveMerging)
64  : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
65 
66  // Return true if something is changed.
67  bool run();
68  bool changed = false;
69 
70  // Return true if the given connect op is merged.
71  bool peelConnect(MatchingConnectOp connect);
72 
73  // A map from a destination FieldRef to a pair of (i) the number of
74  // connections seen so far and (ii) the vector to store subconnections.
75  DenseMap<FieldRef, std::pair<unsigned, SmallVector<MatchingConnectOp>>>
76  connections;
77 
78  FModuleOp moduleOp;
79  ImplicitLocOpBuilder *builder = nullptr;
80 
81  // If true, we merge connections even when source values will not be
82  // simplified.
83  bool enableAggressiveMerging = false;
84 };
85 
86 bool MergeConnection::peelConnect(MatchingConnectOp connect) {
87  // Ignore connections between different types because it will produce a
88  // partial connect. Also ignore non-passive connections or non-integer
89  // connections.
90  LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n");
91  auto destTy = type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType());
92  if (!destTy || !destTy.isPassive() ||
93  !firrtl::getBitWidth(destTy).has_value())
94  return false;
95 
96  auto destFieldRef = getFieldRefFromValue(connect.getDest());
97  auto destRoot = destFieldRef.getValue();
98 
99  // If dest is derived from mem op or has a ground type, we cannot merge them.
100  // If the connect's destination is a root value, we cannot merge.
101  if (destRoot.getDefiningOp<MemOp>() || destRoot == connect.getDest())
102  return false;
103 
104  Value parent;
105  unsigned index;
106  if (auto subfield = dyn_cast<SubfieldOp>(connect.getDest().getDefiningOp()))
107  parent = subfield.getInput(), index = subfield.getFieldIndex();
108  else if (auto subindex =
109  dyn_cast<SubindexOp>(connect.getDest().getDefiningOp()))
110  parent = subindex.getInput(), index = subindex.getIndex();
111  else
112  llvm_unreachable("unexpected destination");
113 
114  auto &countAndSubConnections = connections[getFieldRefFromValue(parent)];
115  auto &count = countAndSubConnections.first;
116  auto &subConnections = countAndSubConnections.second;
117 
118  // If it is the first time to visit the parent op, then allocate the vector
119  // for subconnections.
120  if (count == 0) {
121  if (auto bundle = type_dyn_cast<BundleType>(parent.getType()))
122  subConnections.resize(bundle.getNumElements());
123  if (auto vector = type_dyn_cast<FVectorType>(parent.getType()))
124  subConnections.resize(vector.getNumElements());
125  }
126  ++count;
127  subConnections[index] = connect;
128 
129  // If we haven't visited all subconnections, stop at this point.
130  if (count != subConnections.size())
131  return false;
132 
133  auto parentType = parent.getType();
134  auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
135 
136  // Reject if not passive, we don't support aggregate constants for these.
137  if (!parentBaseTy || !parentBaseTy.isPassive())
138  return false;
139 
140  changed = true;
141 
142  auto getMergedValue = [&](auto aggregateType) {
143  SmallVector<Value> operands;
144 
145  // This flag tracks whether we can use the parent of source values as the
146  // merged value.
147  bool canUseSourceParent = true;
148  bool areOperandsAllConstants = true;
149 
150  // The value which might be used as a merged value.
151  Value sourceParent;
152 
153  auto checkSourceParent = [&](auto subelement, unsigned destIndex,
154  unsigned sourceIndex) {
155  // In the first iteration, register a parent value.
156  if (destIndex == 0) {
157  if (subelement.getInput().getType() == parentType)
158  sourceParent = subelement.getInput();
159  else {
160  // If types are not same, it is not possible to use it.
161  canUseSourceParent = false;
162  }
163  }
164 
165  // Check that input is the same as `sourceAggregate` and indexes match.
166  canUseSourceParent &=
167  subelement.getInput() == sourceParent && destIndex == sourceIndex;
168  };
169 
170  for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) {
171  auto src = subConnections[idx].getSrc();
172  assert(src && "all subconnections are guranteed to exist");
173  operands.push_back(src);
174 
175  areOperandsAllConstants &= isConstantLike(src);
176 
177  // From here, check whether the value is derived from the same aggregate
178  // value.
179 
180  // If canUseSourceParent is already false, abort.
181  if (!canUseSourceParent)
182  continue;
183 
184  // If the value is an argument, it is not derived from an aggregate value.
185  if (!src.getDefiningOp()) {
186  canUseSourceParent = false;
187  continue;
188  }
189 
190  TypeSwitch<Operation *>(src.getDefiningOp())
191  .template Case<SubfieldOp>([&](SubfieldOp subfield) {
192  checkSourceParent(subfield, idx, subfield.getFieldIndex());
193  })
194  .template Case<SubindexOp>([&](SubindexOp subindex) {
195  checkSourceParent(subindex, idx, subindex.getIndex());
196  })
197  .Default([&](auto) { canUseSourceParent = false; });
198  }
199 
200  // If it is fine to use `sourceParent` as a merged value, we just
201  // return it.
202  if (canUseSourceParent) {
203  LLVM_DEBUG(llvm::dbgs() << "Success to merge " << destFieldRef.getValue()
204  << " ,fieldID= " << destFieldRef.getFieldID()
205  << " to " << sourceParent << "\n";);
206  // Erase connections except for subConnections[index] since it must be
207  // erased at the top-level loop.
208  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size())))
209  if (idx != index)
210  subConnections[idx].erase();
211  return sourceParent;
212  }
213 
214  // If operands are not all constants, we don't merge connections unless
215  // "aggressive-merging" option is enabled.
216  if (!enableAggressiveMerging && !areOperandsAllConstants)
217  return Value();
218 
219  SmallVector<Location> locs;
220  // Otherwise, we concat all values and cast them into the aggregate type.
221  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size()))) {
222  locs.push_back(subConnections[idx].getLoc());
223  // Erase connections except for subConnections[index] since it must be
224  // erased at the top-level loop.
225  if (idx != index)
226  subConnections[idx].erase();
227  }
228 
229  return isa<FVectorType>(parentType)
230  ? builder->createOrFold<VectorCreateOp>(
231  builder->getFusedLoc(locs), parentType, operands)
232  : builder->createOrFold<BundleCreateOp>(
233  builder->getFusedLoc(locs), parentType, operands);
234  };
235 
236  Value merged;
237  if (auto bundle = type_dyn_cast<BundleType>(parentType))
238  merged = getMergedValue(bundle);
239  if (auto vector = type_dyn_cast<FVectorType>(parentType))
240  merged = getMergedValue(vector);
241  if (!merged)
242  return false;
243 
244  // Emit strict connect if possible, fallback to normal connect.
245  // Don't use emitConnect(), will split the connect apart.
246  if (!parentBaseTy.hasUninferredWidth())
247  builder->create<MatchingConnectOp>(connect.getLoc(), parent, merged);
248  else
249  builder->create<ConnectOp>(connect.getLoc(), parent, merged);
250 
251  return true;
252 }
253 
254 bool MergeConnection::run() {
255  ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
256  builder = &theBuilder;
257 
258  // Block worklist that tracks the current position within a block.
259  SmallVector<std::pair<Block::iterator, Block::iterator>> worklist;
260 
261  // Walk the IR in order, top-to-bottom, stepping into blocks as they are
262  // found. This is basically the same as `moduleOp.walk`, however, it allows
263  // for visiting operations that are inserted _after_ the current operation.
264  // Using the existing `walk` does not do this.
265  auto *body = moduleOp.getBodyBlock();
266  worklist.push_back({body->begin(), body->end()});
267  while (!worklist.empty()) {
268  auto &[it, e] = worklist.back();
269 
270  // Merge connections by forward iterations.
271  bool opWithBlocks = false;
272  while (it != e) {
273  // Add blocks to the stack such that they will be pulled off in-order.
274  for (auto &region : llvm::reverse(it->getRegions()))
275  for (auto &block : llvm::reverse(region.getBlocks())) {
276  worklist.push_back({block.begin(), block.end()});
277  opWithBlocks = true;
278  }
279 
280  // We found one or more blocks. Stop and go process these blocks.
281  if (opWithBlocks) {
282  ++it;
283  break;
284  }
285 
286  // This operation does not have blocks. Process it normally.
287  auto connectOp = dyn_cast<MatchingConnectOp>(*it);
288  if (!connectOp) {
289  ++it;
290  continue;
291  }
292  builder->setInsertionPointAfter(connectOp);
293  builder->setLoc(connectOp.getLoc());
294  bool removeOp = peelConnect(connectOp);
295  ++it;
296  if (removeOp)
297  connectOp.erase();
298  }
299 
300  // We found a block and added to the worklist.
301  if (opWithBlocks)
302  continue;
303 
304  // We finished processing a block.
305  worklist.pop_back();
306  }
307 
308  // Clean up dead operations introduced by this pass.
309  moduleOp.walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
310  [&](Operation *op) {
311  if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
312  CatPrimOp>(op))
313  if (op->use_empty()) {
314  changed = true;
315  op->erase();
316  }
317  });
318 
319  return changed;
320 }
321 
322 struct MergeConnectionsPass
323  : public circt::firrtl::impl::MergeConnectionsBase<MergeConnectionsPass> {
324  MergeConnectionsPass(bool enableAggressiveMergingFlag) {
325  enableAggressiveMerging = enableAggressiveMergingFlag;
326  }
327  void runOnOperation() override;
328 };
329 
330 } // namespace
331 
332 void MergeConnectionsPass::runOnOperation() {
333  LLVM_DEBUG(debugPassHeader(this)
334  << "\n"
335  << "Module: '" << getOperation().getName() << "'\n");
336 
337  MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
338  bool changed = mergeConnection.run();
339 
340  if (!changed)
341  return markAllAnalysesPreserved();
342 }
343 
344 std::unique_ptr<mlir::Pass>
345 circt::firrtl::createMergeConnectionsPass(bool enableAggressiveMerging) {
346  return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
347 }
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
def connect(destination, source)
Definition: support.py:39
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4588
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
Definition: DebugAnalysis.h:21
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition: Debug.cpp:31
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition: codegen.py:121