CIRCT 20.0.0git
Loading...
Searching...
No Matches
MergeConnections.cpp
Go to the documentation of this file.
1//===- MergeConnections.cpp - Merge expanded connections --------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//===----------------------------------------------------------------------===//
7//
8// This pass merges expanded connections into one connection.
9// LowerTypes fully expands aggregate connections even when that is not
10// semantically necessary, because ExpandWhens requires it.
11//
12// More specifically this pass folds the following patterns:
13// %dest(0) <= v0
14// %dest(1) <= v1
15// ...
16// %dest(n) <= vn
17// into
18// %dest <= {vn, .., v1, v0}
19// Also, if v0, v1, ..., vn are subfield/subindex ops like %a(0), %a(1), ...,
20// %a(n), then we merge all of them into a single connection %dest <= %a.
21//
22//===----------------------------------------------------------------------===//
23
27#include "circt/Support/Debug.h"
28#include "mlir/IR/Iterators.h"
29#include "mlir/Pass/Pass.h"
30#include "llvm/ADT/TypeSwitch.h"
31#include "llvm/Support/Debug.h"
32
33#define DEBUG_TYPE "firrtl-merge-connections"
34
35namespace circt {
36namespace firrtl {
37#define GEN_PASS_DEF_MERGECONNECTIONS
38#include "circt/Dialect/FIRRTL/Passes.h.inc"
39} // namespace firrtl
40} // namespace circt
41
42using namespace circt;
43using namespace firrtl;
44
45// Return true if value is essentially constant.
46static bool isConstantLike(Value value) {
47 if (isa_and_nonnull<ConstantOp, InvalidValueOp>(value.getDefiningOp()))
48 return true;
49 if (auto bitcast = value.getDefiningOp<BitCastOp>())
50 return isConstant(bitcast.getInput());
51
52 // TODO: Add unrealized_conversion, asUInt, asSInt
53 return false;
54}
55
56namespace {
57//===----------------------------------------------------------------------===//
58// Pass Infrastructure
59//===----------------------------------------------------------------------===//
60
61// A helper struct to merge connections.
62struct MergeConnection {
63 MergeConnection(FModuleOp moduleOp, bool enableAggressiveMerging)
64 : moduleOp(moduleOp), enableAggressiveMerging(enableAggressiveMerging) {}
65
66 // Return true if something is changed.
67 bool run();
68 bool changed = false;
69
70 // Return true if the given connect op is merged.
71 bool peelConnect(MatchingConnectOp connect);
72
73 // A map from a destination FieldRef to a pair of (i) the number of
74 // connections seen so far and (ii) the vector to store subconnections.
75 DenseMap<FieldRef, std::pair<unsigned, SmallVector<MatchingConnectOp>>>
76 connections;
77
78 FModuleOp moduleOp;
79 ImplicitLocOpBuilder *builder = nullptr;
80
81 // If true, we merge connections even when source values will not be
82 // simplified.
83 bool enableAggressiveMerging = false;
84};
85
86bool MergeConnection::peelConnect(MatchingConnectOp connect) {
87 // Ignore connections between different types because it will produce a
88 // partial connect. Also ignore non-passive connections or non-integer
89 // connections.
90 LLVM_DEBUG(llvm::dbgs() << "Visiting " << connect << "\n");
91 auto destTy = type_dyn_cast<FIRRTLBaseType>(connect.getDest().getType());
92 if (!destTy || !destTy.isPassive() ||
93 !firrtl::getBitWidth(destTy).has_value())
94 return false;
95
96 auto destFieldRef = getFieldRefFromValue(connect.getDest());
97 auto destRoot = destFieldRef.getValue();
98
99 // If dest is derived from mem op or has a ground type, we cannot merge them.
100 // If the connect's destination is a root value, we cannot merge.
101 if (destRoot.getDefiningOp<MemOp>() || destRoot == connect.getDest())
102 return false;
103
104 Value parent;
105 unsigned index;
106 if (auto subfield = dyn_cast<SubfieldOp>(connect.getDest().getDefiningOp()))
107 parent = subfield.getInput(), index = subfield.getFieldIndex();
108 else if (auto subindex =
109 dyn_cast<SubindexOp>(connect.getDest().getDefiningOp()))
110 parent = subindex.getInput(), index = subindex.getIndex();
111 else
112 llvm_unreachable("unexpected destination");
113
114 auto &countAndSubConnections = connections[getFieldRefFromValue(parent)];
115 auto &count = countAndSubConnections.first;
116 auto &subConnections = countAndSubConnections.second;
117
118 // If it is the first time to visit the parent op, then allocate the vector
119 // for subconnections.
120 if (count == 0) {
121 if (auto bundle = type_dyn_cast<BundleType>(parent.getType()))
122 subConnections.resize(bundle.getNumElements());
123 if (auto vector = type_dyn_cast<FVectorType>(parent.getType()))
124 subConnections.resize(vector.getNumElements());
125 }
126 ++count;
127 subConnections[index] = connect;
128
129 // If we haven't visited all subconnections, stop at this point.
130 if (count != subConnections.size())
131 return false;
132
133 auto parentType = parent.getType();
134 auto parentBaseTy = type_dyn_cast<FIRRTLBaseType>(parentType);
135
136 // Reject if not passive, we don't support aggregate constants for these.
137 if (!parentBaseTy || !parentBaseTy.isPassive())
138 return false;
139
140 changed = true;
141
142 auto getMergedValue = [&](auto aggregateType) {
143 SmallVector<Value> operands;
144
145 // This flag tracks whether we can use the parent of source values as the
146 // merged value.
147 bool canUseSourceParent = true;
148 bool areOperandsAllConstants = true;
149
150 // The value which might be used as a merged value.
151 Value sourceParent;
152
153 auto checkSourceParent = [&](auto subelement, unsigned destIndex,
154 unsigned sourceIndex) {
155 // In the first iteration, register a parent value.
156 if (destIndex == 0) {
157 if (subelement.getInput().getType() == parentType)
158 sourceParent = subelement.getInput();
159 else {
160 // If types are not same, it is not possible to use it.
161 canUseSourceParent = false;
162 }
163 }
164
165 // Check that input is the same as `sourceAggregate` and indexes match.
166 canUseSourceParent &=
167 subelement.getInput() == sourceParent && destIndex == sourceIndex;
168 };
169
170 for (auto idx : llvm::seq(0u, (unsigned)aggregateType.getNumElements())) {
171 auto src = subConnections[idx].getSrc();
172 assert(src && "all subconnections are guranteed to exist");
173 operands.push_back(src);
174
175 areOperandsAllConstants &= isConstantLike(src);
176
177 // From here, check whether the value is derived from the same aggregate
178 // value.
179
180 // If canUseSourceParent is already false, abort.
181 if (!canUseSourceParent)
182 continue;
183
184 // If the value is an argument, it is not derived from an aggregate value.
185 if (!src.getDefiningOp()) {
186 canUseSourceParent = false;
187 continue;
188 }
189
190 TypeSwitch<Operation *>(src.getDefiningOp())
191 .template Case<SubfieldOp>([&](SubfieldOp subfield) {
192 checkSourceParent(subfield, idx, subfield.getFieldIndex());
193 })
194 .template Case<SubindexOp>([&](SubindexOp subindex) {
195 checkSourceParent(subindex, idx, subindex.getIndex());
196 })
197 .Default([&](auto) { canUseSourceParent = false; });
198 }
199
200 // If it is fine to use `sourceParent` as a merged value, we just
201 // return it.
202 if (canUseSourceParent) {
203 LLVM_DEBUG(llvm::dbgs() << "Success to merge " << destFieldRef.getValue()
204 << " ,fieldID= " << destFieldRef.getFieldID()
205 << " to " << sourceParent << "\n";);
206 // Erase connections except for subConnections[index] since it must be
207 // erased at the top-level loop.
208 for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size())))
209 if (idx != index)
210 subConnections[idx].erase();
211 return sourceParent;
212 }
213
214 // If operands are not all constants, we don't merge connections unless
215 // "aggressive-merging" option is enabled.
216 if (!enableAggressiveMerging && !areOperandsAllConstants)
217 return Value();
218
219 SmallVector<Location> locs;
220 // Otherwise, we concat all values and cast them into the aggregate type.
221 for (auto idx : llvm::seq(0u, static_cast<unsigned>(operands.size()))) {
222 locs.push_back(subConnections[idx].getLoc());
223 // Erase connections except for subConnections[index] since it must be
224 // erased at the top-level loop.
225 if (idx != index)
226 subConnections[idx].erase();
227 }
228
229 return isa<FVectorType>(parentType)
230 ? builder->createOrFold<VectorCreateOp>(
231 builder->getFusedLoc(locs), parentType, operands)
232 : builder->createOrFold<BundleCreateOp>(
233 builder->getFusedLoc(locs), parentType, operands);
234 };
235
236 Value merged;
237 if (auto bundle = type_dyn_cast<BundleType>(parentType))
238 merged = getMergedValue(bundle);
239 if (auto vector = type_dyn_cast<FVectorType>(parentType))
240 merged = getMergedValue(vector);
241 if (!merged)
242 return false;
243
244 // Emit strict connect if possible, fallback to normal connect.
245 // Don't use emitConnect(), will split the connect apart.
246 if (!parentBaseTy.hasUninferredWidth())
247 builder->create<MatchingConnectOp>(connect.getLoc(), parent, merged);
248 else
249 builder->create<ConnectOp>(connect.getLoc(), parent, merged);
250
251 return true;
252}
253
254bool MergeConnection::run() {
255 ImplicitLocOpBuilder theBuilder(moduleOp.getLoc(), moduleOp.getContext());
256 builder = &theBuilder;
257
258 // Block worklist that tracks the current position within a block.
259 SmallVector<std::pair<Block::iterator, Block::iterator>> worklist;
260
261 // Walk the IR in order, top-to-bottom, stepping into blocks as they are
262 // found. This is basically the same as `moduleOp.walk`, however, it allows
263 // for visiting operations that are inserted _after_ the current operation.
264 // Using the existing `walk` does not do this.
265 auto *body = moduleOp.getBodyBlock();
266 worklist.push_back({body->begin(), body->end()});
267 while (!worklist.empty()) {
268 auto &[it, e] = worklist.back();
269
270 // Merge connections by forward iterations.
271 bool opWithBlocks = false;
272 while (it != e) {
273 // Add blocks to the stack such that they will be pulled off in-order.
274 for (auto &region : llvm::reverse(it->getRegions()))
275 for (auto &block : llvm::reverse(region.getBlocks())) {
276 worklist.push_back({block.begin(), block.end()});
277 opWithBlocks = true;
278 }
279
280 // We found one or more blocks. Stop and go process these blocks.
281 if (opWithBlocks) {
282 ++it;
283 break;
284 }
285
286 // This operation does not have blocks. Process it normally.
287 auto connectOp = dyn_cast<MatchingConnectOp>(*it);
288 if (!connectOp) {
289 ++it;
290 continue;
291 }
292 builder->setInsertionPointAfter(connectOp);
293 builder->setLoc(connectOp.getLoc());
294 bool removeOp = peelConnect(connectOp);
295 ++it;
296 if (removeOp)
297 connectOp.erase();
298 }
299
300 // We found a block and added to the worklist.
301 if (opWithBlocks)
302 continue;
303
304 // We finished processing a block.
305 worklist.pop_back();
306 }
307
308 // Clean up dead operations introduced by this pass.
309 moduleOp.walk<mlir::WalkOrder::PostOrder, mlir::ReverseIterator>(
310 [&](Operation *op) {
311 if (isa<SubfieldOp, SubindexOp, InvalidValueOp, ConstantOp, BitCastOp,
312 CatPrimOp>(op))
313 if (op->use_empty()) {
314 changed = true;
315 op->erase();
316 }
317 });
318
319 return changed;
320}
321
322struct MergeConnectionsPass
323 : public circt::firrtl::impl::MergeConnectionsBase<MergeConnectionsPass> {
324 MergeConnectionsPass(bool enableAggressiveMergingFlag) {
325 enableAggressiveMerging = enableAggressiveMergingFlag;
326 }
327 void runOnOperation() override;
328};
329
330} // namespace
331
332void MergeConnectionsPass::runOnOperation() {
333 LLVM_DEBUG(debugPassHeader(this)
334 << "\n"
335 << "Module: '" << getOperation().getName() << "'\n");
336
337 MergeConnection mergeConnection(getOperation(), enableAggressiveMerging);
338 bool changed = mergeConnection.run();
339
340 if (!changed)
341 return markAllAnalysesPreserved();
342}
343
344std::unique_ptr<mlir::Pass>
345circt::firrtl::createMergeConnectionsPass(bool enableAggressiveMerging) {
346 return std::make_unique<MergeConnectionsPass>(enableAggressiveMerging);
347}
assert(baseType &&"element must be base type")
static bool isConstantLike(Value value)
connect(destination, source)
Definition support.py:39
FieldRef getFieldRefFromValue(Value value, bool lookThroughCasts=false)
Get the FieldRef from a value.
bool isConstant(Operation *op)
Return true if the specified operation has a constant value.
std::unique_ptr< mlir::Pass > createMergeConnectionsPass(bool enableAggressiveMerging=false)
std::optional< int64_t > getBitWidth(FIRRTLBaseType type, bool ignoreFlip=false)
StringAttr getName(ArrayAttr names, size_t idx)
Return the name at the specified index of the ArrayAttr or null if it cannot be determined.
The InstanceGraph op interface, see InstanceGraphInterface.td for more details.
llvm::raw_ostream & debugPassHeader(const mlir::Pass *pass, int width=80)
Write a boilerplate header for a pass to the debug stream.
Definition Debug.cpp:31
int run(Type[Generator] generator=CppGenerator, cmdline_args=sys.argv)
Definition codegen.py:121
Definition seq.py:1