CIRCT  18.0.0git
MergeConnections.cpp
Go to the documentation of this file.
1 //===- MergeConnections.cpp - Merge expanded connections --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //===----------------------------------------------------------------------===//
7 //
8 // This pass merges expanded connections into one connection.
9 // LowerTypes fully expands aggregate connections, even when that is not
10 // semantically necessary, because ExpandWhens requires it.
11 //
12 // More specifically this pass folds the following patterns:
13 // %dest(0) <= v0
14 // %dest(1) <= v1
15 // ...
16 // %dest(n) <= vn
17 // into
18 // %dest <= {vn, .., v1, v0}
19 // Also, if v0, v1, ..., vn are the subelements %a(0), %a(1), ..., %a(n) of a
20 // single aggregate %a, the whole group is merged into %dest <= %a.
21 //
22 //===----------------------------------------------------------------------===//
23 
24 #include "PassDetails.h"
27 #include "mlir/IR/ImplicitLocOpBuilder.h"
28 #include "llvm/ADT/TypeSwitch.h"
29 #include "llvm/Support/Debug.h"
30 
31 #define DEBUG_TYPE "firrtl-merge-connections"
32 
33 using namespace circt;
34 using namespace firrtl;
35 
36 // Return true if value is essentially constant.
37 static bool isConstantLike(Value value) {
38  if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
39  return true;
40  if (auto bitcast = value.getDefiningOp<BitCastOp>())
41  return isConstant(bitcast.getInput());
42 
43  // TODO: Add unrealized_conversion, asUInt, asSInt
44  return false;
45 }
46 
47 namespace {
48 //===----------------------------------------------------------------------===//
49 // Pass Infrastructure
50 //===----------------------------------------------------------------------===//
51 
52 // A helper struct to merge connections.
53 struct MergeConnection {
54  MergeConnection(FModuleOp moduleOp, bool enableAggressiveMerging)
55  : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
56 
57  // Return true if something is changed.
58  bool run();
59  bool changed = false;
60 
61  // Return true if the given connect op is merged.
62  bool peelConnect(StrictConnectOp connect);
63 
64  // A map from a destination FieldRef to a pair of (i) the number of
65  // connections seen so far and (ii) the vector to store subconnections.
66  DenseMap<FieldRef, std::pair<unsigned, SmallVector<StrictConnectOp>>>
67  connections;
68 
69  FModuleOp moduleOp;
70  ImplicitLocOpBuilder *builder = nullptr;
71 
72  // If true, we merge connections even when source values will not be
73  // simplified.
74  bool enableAggressiveMerging = false;
75 };
76 
77 bool MergeConnection::peelConnect(StrictConnectOp connect) {
78  // Ignore connections between different types because it will produce a
79  // partial connect. Also ignore non-passive connections or non-integer
80  // connections.
81  LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n");
82  auto destTy = type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType());
83  if (!destTy || !destTy.isPassive() ||
84  !firrtl::getBitWidth(destTy).has_value())
85  return false;
86 
87  auto destFieldRef = getFieldRefFromValue(connect.getDest());
88  auto destRoot = destFieldRef.getValue();
89 
90  // If dest is derived from mem op or has a ground type, we cannot merge them.
91  // If the connect's destination is a root value, we cannot merge.
92  if (destRoot.getDefiningOp<MemOp>() || destRoot == connect.getDest())
93  return false;
94 
95  Value parent;
96  unsigned index;
97  if (auto subfield = dyn_cast<SubfieldOp>(connect.getDest().getDefiningOp()))
98  parent = subfield.getInput(), index = subfield.getFieldIndex();
99  else if (auto subindex =
100  dyn_cast<SubindexOp>(connect.getDest().getDefiningOp()))
101  parent = subindex.getInput(), index = subindex.getIndex();
102  else
103  llvm_unreachable("unexpected destination");
104 
105  auto &countAndSubConnections = connections[getFieldRefFromValue(parent)];
106  auto &count = countAndSubConnections.first;
107  auto &subConnections = countAndSubConnections.second;
108 
109  // If it is the first time to visit the parent op, then allocate the vector
110  // for subconnections.
111  if (count == 0) {
112  if (auto bundle = type_dyn_cast<BundleType>(parent.getType()))
113  subConnections.resize(bundle.getNumElements());
114  if (auto vector = type_dyn_cast<FVectorType>(parent.getType()))
115  subConnections.resize(vector.getNumElements());
116  }
117  ++count;
118  subConnections[index] = connect;
119 
120  // If we haven't visited all subconnections, stop at this point.
121  if (count != subConnections.size())
122  return false;
123 
124  auto parentType = parent.getType();
125  auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
126 
127  // Reject if not passive, we don't support aggregate constants for these.
128  if (!parentBaseTy || !parentBaseTy.isPassive())
129  return false;
130 
131  changed = true;
132 
133  auto getMergedValue = [&](auto aggregateType) {
134  SmallVector<Value> operands;
135 
136  // This flag tracks whether we can use the parent of source values as the
137  // merged value.
138  bool canUseSourceParent = true;
139  bool areOperandsAllConstants = true;
140 
141  // The value which might be used as a merged value.
142  Value sourceParent;
143 
144  auto checkSourceParent = [&](auto subelement, unsigned destIndex,
145  unsigned sourceIndex) {
146  // In the first iteration, register a parent value.
147  if (destIndex == 0) {
148  if (subelement.getInput().getType() == parentType)
149  sourceParent = subelement.getInput();
150  else {
151  // If types are not same, it is not possible to use it.
152  canUseSourceParent = false;
153  }
154  }
155 
156  // Check that input is the same as `sourceAggregate` and indexes match.
157  canUseSourceParent &=
158  subelement.getInput() == sourceParent && destIndex == sourceIndex;
159  };
160 
161  for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) {
162  auto src = subConnections[idx].getSrc();
163  assert(src && "all subconnections are guranteed to exist");
164  operands.push_back(src);
165 
166  areOperandsAllConstants &= isConstantLike(src);
167 
168  // From here, check whether the value is derived from the same aggregate
169  // value.
170 
171  // If canUseSourceParent is already false, abort.
172  if (!canUseSourceParent)
173  continue;
174 
175  // If the value is an argument, it is not derived from an aggregate value.
176  if (!src.getDefiningOp()) {
177  canUseSourceParent = false;
178  continue;
179  }
180 
181  TypeSwitch<Operation *>(src.getDefiningOp())
182  .template Case<SubfieldOp>([&](SubfieldOp subfield) {
183  checkSourceParent(subfield, idx, subfield.getFieldIndex());
184  })
185  .template Case<SubindexOp>([&](SubindexOp subindex) {
186  checkSourceParent(subindex, idx, subindex.getIndex());
187  })
188  .Default([&](auto) { canUseSourceParent = false; });
189  }
190 
191  // If it is fine to use `sourceParent` as a merged value, we just
192  // return it.
193  if (canUseSourceParent) {
194  LLVM_DEBUG(llvm::dbgs() << "Success to merge " << destFieldRef.getValue()
195  << " ,fieldID= " << destFieldRef.getFieldID()
196  << " to " << sourceParent << "\n";);
197  // Erase connections except for subConnections[index] since it must be
198  // erased at the top-level loop.
199  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size())))
200  if (idx != index)
201  subConnections[idx].erase();
202  return sourceParent;
203  }
204 
205  // If operands are not all constants, we don't merge connections unless
206  // "aggressive-merging" option is enabled.
207  if (!enableAggressiveMerging && !areOperandsAllConstants)
208  return Value();
209 
210  SmallVector<Location> locs;
211  // Otherwise, we concat all values and cast them into the aggregate type.
212  for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size()))) {
213  locs.push_back(subConnections[idx].getLoc());
214  // Erase connections except for subConnections[index] since it must be
215  // erased at the top-level loop.
216  if (idx != index)
217  subConnections[idx].erase();
218  }
219 
220  return isa<FVectorType>(parentType)
221  ? builder->createOrFold<VectorCreateOp>(
222  builder->getFusedLoc(locs), parentType, operands)
223  : builder->createOrFold<BundleCreateOp>(
224  builder->getFusedLoc(locs), parentType, operands);
225  };
226 
227  Value merged;
228  if (auto bundle = type_dyn_cast<BundleType>(parentType))
229  merged = getMergedValue(bundle);
230  if (auto vector = type_dyn_cast<FVectorType>(parentType))
231  merged = getMergedValue(vector);
232  if (!merged)
233  return false;
234 
235  // Emit strict connect if possible, fallback to normal connect.
236  // Don't use emitConnect(), will split the connect apart.
237  if (!parentBaseTy.hasUninferredWidth())
238  builder->create<StrictConnectOp>(connect.getLoc(), parent, merged);
239  else
240  builder->create<ConnectOp>(connect.getLoc(), parent, merged);
241 
242  return true;
243 }
244 
245 bool MergeConnection::run() {
246  ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
247  builder = &theBuilder;
248  auto *body = moduleOp.getBodyBlock();
249  // Merge connections by forward iterations.
250  for (auto it = body->begin(), e = body->end(); it != e;) {
251  auto connectOp = dyn_cast<StrictConnectOp>(*it);
252  if (!connectOp) {
253  it++;
254  continue;
255  }
256  builder->setInsertionPointAfter(connectOp);
257  builder->setLoc(connectOp.getLoc());
258  bool removeOp = peelConnect(connectOp);
259  ++it;
260  if (removeOp)
261  connectOp.erase();
262  }
263 
264  // Clean up dead operations introduced by this pass.
265  for (auto &op : llvm::make_early_inc_range(llvm::reverse(*body)))
266  if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
267  CatPrimOp>(op))
268  if (op.use_empty()) {
269  changed = true;
270  op.erase();
271  }
272 
273  return changed;
274 }
275 
276 struct MergeConnectionsPass
277  : public MergeConnectionsBase<MergeConnectionsPass> {
278  MergeConnectionsPass(bool enableAggressiveMergingFlag) {
279  enableAggressiveMerging = enableAggressiveMergingFlag;
280  }
281  void runOnOperation() override;
282 };
283 
284 } // namespace
285 
286 void MergeConnectionsPass::runOnOperation() {
287  LLVM_DEBUG(llvm::dbgs() << "===----- Running MergeConnections "
288  "--------------------------------------===\n"
289  << "Module: '" << getOperation().getName() << "'\n";);
290 
291  MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
292  bool changed = mergeConnection.run();
293 
294  if (!changed)
295  return markAllAnalysesPreserved();
296 }
297 
298 std::unique_ptr<mlir::Pass>
299 circt::firrtl::createMergeConnectionsPass(bool enableAggressiveMerging) {
300  return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
301 }
lowerAnnotationsNoRefTypePorts FirtoolPreserveValuesMode value
Definition: Firtool.cpp:95
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
Builder builder
def connect(destination, source)
Definition: support.py:37
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
Definition: FIRRTLOps.cpp:4070
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
This file defines an intermediate representation for circuits acting as an abstraction for constraint...
Definition: DebugAnalysis.h:21
mlir::raw_indented_ostream & dbgs()
Definition: Utility.h:28