CIRCT 23.0.0git
SparseOpSCC.h
//===- SparseOpSCC.h - SCC analysis on sparse op subgraphs ----------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Collect strongly connected components (SCCs) in the (filtered) def-use graph
// of MLIR operations, starting from a sparse set of seed operations.
//
// Graph model
// -----------
// Each operation is a node. A directed edge runs from op A to op B if B uses
// one of A's results. The traversal direction is configurable:
//   - OpSCCDirection::Forward  -- follow edges from defining ops to users.
//   - OpSCCDirection::Backward -- follow edges from users to defining ops.
//
// SCC classification
// ------------------
// An SCC is either:
//   - Trivial: a single op with no self-loop. Represented as
//     mlir::Operation * inside an OpSCC value.
//   - Cyclic: a group of mutually-reachable ops, or a single op with a
//     self-loop. Represented as a CyclicOpSCC inside an OpSCC value.
//
// Filtering
// ---------
// An optional OpSCCFilter predicate can be supplied to the constructor to
// prevent the traversal from following certain edges of the graph. The first
// argument is the operation the traversal would enter. The second argument is
// the operand that forms the edge. For forward traversal, the operand's owner
// is the first argument. For backward traversal, the first argument is the
// operation defining the operand's value.
//
// Output ordering
// ---------------
// SCCs are available in a topological order of the condensation DAG via the
// range returned by topological(), and in reverse order via
// reverseTopological(). The order refers to the def-use graph described above
// (an SCC precedes the SCCs that use its results through traversed edges),
// independent of the traversal direction. The order is deterministic:
// identical graphs yield an identical order. However, if more than one
// topological order exists, which one is produced is unspecified. Likewise,
// the order of operations within an SCC is deterministic but unspecified.
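//
// For example, this sketch relies on defining ops being processed before their
// users to compute a per-op `depth` (`depth` is a hypothetical map; defining
// ops without an entry, such as undiscovered ops or members of cyclic SCCs,
// count as 0, and cyclic SCCs are skipped for brevity):
//
//   llvm::DenseMap<Operation *, unsigned> depth;
//   for (OpSCC entry : sccs.topological()) {
//     auto *op = llvm::dyn_cast<Operation *>(entry);
//     if (!op)
//       continue; // Cyclic SCC: not handled in this sketch.
//     unsigned d = 0;
//     for (Value operand : op->getOperands())
//       if (Operation *def = operand.getDefiningOp())
//         d = std::max(d, depth.lookup(def) + 1);
//     depth[op] = d;
//   }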
//
// Blocks and Regions
// ------------------
// The traversal does not follow through block arguments and does not consider
// control flow. Because a value defined outside a region may be used inside
// it, the traversal can descend into or ascend out of regions without visiting
// the region's parent operation. The filter predicate can be used to restrict
// the traversal to certain blocks or regions.
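//
// For instance, a filter that confines the traversal to a single block could
// look as follows (a sketch; `block` is a hypothetical mlir::Block * chosen by
// the caller). Seeds passed to visit() are visited regardless of the filter:
//
//   auto inBlock = [block](Operation *op, OpOperand &) {
//     return op->getBlock() == block;
//   };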
//
// Operation Graph Mutation
// ------------------------
// The SparseOpSCC class stores the result of the SCC analysis internally and
// only updates it when visit(...) is called. The IR should not be mutated
// between visit calls. Calling visit() or reset() invalidates all iterators.
// It is safe to mutate the IR while iterating. However, the iteration sequence
// may contain dangling operation pointers if the underlying operations are
// erased after the graph was visited. To reflect changes to the graph in the
// analysis, reset() must be called and the graph must be re-visited.
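//
// As a sketch of that pattern, the following erases dead ops reached by a
// forward traversal. It assumes mlir::isOpTriviallyDead (from
// mlir/Interfaces/SideEffectInterfaces.h) is an acceptable deadness test for
// the caller. In reverse topological order, users come before their defining
// ops, so a defining op can become dead once its users are gone:
//
//   SparseOpSCC<OpSCCDirection::Forward> sccs;
//   sccs.visit(seedOp);
//   for (OpSCC entry : sccs.reverseTopological())
//     if (auto *op = llvm::dyn_cast<Operation *>(entry))
//       // Skip region ops: erasing them would also erase nested ops that may
//       // still appear later in the sequence.
//       if (op->getNumRegions() == 0 && mlir::isOpTriviallyDead(op))
//         op->erase();
//   sccs.reset(); // Stored pointers may now dangle; re-visit before querying.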
//
// Usage examples
// --------------
//
// Check if seedOp can be reached from someOp (regFilter is defined in the next
// example):
//
//   SparseOpSCC<OpSCCDirection::Backward> sccs(regFilter);
//   sccs.visit(seedOp);
//   if (OpSCC someScc = sccs.getSCC(someOp)) {
//     if (someScc == sccs.getSCC(seedOp)) {
//       if (auto cycScc = llvm::dyn_cast<CyclicOpSCC>(someScc)) {
//         // seedOp and someOp are on at least one common cycle
//         if (cycScc.size() == 1) {
//           // seedOp and someOp are equal and there is a self-loop (i.e., at
//           // least one operand of seedOp is a result of itself)
//         }
//       } else {
//         // seedOp and someOp are equal and there is no self-loop
//         // (trivial SCC)
//       }
//     } else {
//       // seedOp is reachable from someOp but not the other way around
//       // (someOp discovered during backward traversal, but different SCCs)
//     }
//   } else {
//     // seedOp is not reachable from someOp (someOp not discovered during
//     // backward traversal)
//   }
//
// Collect all ops reachable from seedOp, excluding register ops, and process
// them in topological order:
//
//   auto regFilter = [](Operation *op, OpOperand &) {
//     return !isa<seq::FirRegOp>(op);
//   };
//   SparseOpSCC<OpSCCDirection::Forward> sccs(regFilter);
//   sccs.visit(seedOp);
//
//   for (OpSCC entry : sccs.topological()) {
//     if (Operation *op = llvm::dyn_cast<mlir::Operation *>(entry)) {
//       // Trivial SCC: a single op with no cycle.
//       processSingle(op);
//     } else {
//       // Cyclic SCC: a group of mutually-reachable ops (or a self-loop).
//       for (Operation *op : llvm::cast<CyclicOpSCC>(entry))
//         processInCycle(op);
//     }
//   }
//
// Alternative filter that traverses registers through their clock and reset
// values but not their "next" data values:
//
//   auto regEdgeFilter = [](Operation *, OpOperand &operand) {
//     if (auto regOp = dyn_cast<seq::FirRegOp>(operand.getOwner()))
//       return &operand != &regOp.getNextMutable();
//     return true;
//   };
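//
// Because regFilter never lets the traversal enter a seq::FirRegOp, no cyclic
// SCC found with it contains a register. A sketch that reports each cyclic SCC
// (the diagnostic text is illustrative):
//
//   for (OpSCC entry : sccs.topological())
//     if (auto cycScc = llvm::dyn_cast<CyclicOpSCC>(entry))
//       cycScc[0]->emitError() << "cycle through " << cycScc.size() << " ops";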
//
//===----------------------------------------------------------------------===//

#ifndef CIRCT_SUPPORT_SPARSEOPSCC_H
#define CIRCT_SUPPORT_SPARSEOPSCC_H

#include "mlir/IR/Operation.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/PointerEmbeddedInt.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/STLFunctionalExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include <algorithm>
#include <cassert>
#include <functional>
#include <optional>
#include <type_traits>
#include <utility>

namespace circt {

/// Filter predicate passed to the SparseOpSCC constructor. Return `true` to
/// include an edge in the traversal, `false` to skip it. The first argument is
/// the operation the traversal would enter. The second argument is the
/// `OpOperand` being followed: for forward traversal its owner equals the
/// first argument; for backward traversal the defining op of its value equals
/// the first argument.
using OpSCCFilter = std::function<bool(mlir::Operation *, mlir::OpOperand &)>;
namespace detail {
/// Backing storage for a cyclic SCC (implementation detail).
using CyclicOpSCCStorage = llvm::SmallVector<mlir::Operation *, 4>;
} // namespace detail

/// A cyclic SCC: a pointer-sized, directly-iterable reference to a group of
/// mutually-reachable operations (or a single op with a self-loop).
///
/// Instances are obtained via llvm::cast<CyclicOpSCC> on an OpSCC entry.
/// The referenced storage is owned by the SparseOpSCC that produced the entry.
class CyclicOpSCC {
public:
  using iterator = detail::CyclicOpSCCStorage::const_iterator;

  CyclicOpSCC() : storage(nullptr) {}
  CyclicOpSCC(const detail::CyclicOpSCCStorage *storage) : storage(storage) {}

  iterator begin() const { return storage->begin(); }
  iterator end() const { return storage->end(); }
  size_t size() const { return storage->size(); }
  mlir::Operation *const *data() const { return storage->data(); }
  mlir::Operation *operator[](size_t i) const { return (*storage)[i]; }

  operator bool() const { return storage != nullptr; }

  bool operator==(CyclicOpSCC other) const { return storage == other.storage; }
  bool operator!=(CyclicOpSCC other) const { return storage != other.storage; }

  // Interface for PointerLikeTypeTraits.
  void *getAsVoidPointer() const {
    return const_cast<detail::CyclicOpSCCStorage *>(storage);
  }
  static CyclicOpSCC getFromVoidPointer(void *p) {
    return CyclicOpSCC(static_cast<const detail::CyclicOpSCCStorage *>(p));
  }
  static constexpr int NumLowBitsAvailable = llvm::PointerLikeTypeTraits<
      const detail::CyclicOpSCCStorage *>::NumLowBitsAvailable;

private:
  const detail::CyclicOpSCCStorage *storage;
};

} // namespace circt

namespace llvm {
template <>
struct PointerLikeTypeTraits<circt::CyclicOpSCC> {
  static void *getAsVoidPointer(circt::CyclicOpSCC scc) {
    return scc.getAsVoidPointer();
  }
  static circt::CyclicOpSCC getFromVoidPointer(void *p) {
    return circt::CyclicOpSCC::getFromVoidPointer(p);
  }
  static constexpr int NumLowBitsAvailable =
      circt::CyclicOpSCC::NumLowBitsAvailable;
};
} // namespace llvm

namespace circt {

/// One entry in the SCC output: a null sentinel, a trivial (non-cyclic)
/// operation, or a cyclic group. Use llvm::isa / llvm::cast / llvm::dyn_cast
/// to distinguish.
/// Note: void * must be placed first in the union so that the all-zero
/// (default-constructed) state is unambiguously identified as invalid, not as
/// a null Operation *.
using OpSCC = llvm::PointerUnion<void *, mlir::Operation *, CyclicOpSCC>;

/// Traversal direction for SparseOpSCC.
/// - Forward: follow def-use edges forward (defining op -> users).
/// - Backward: follow def-use edges backward (user -> defining op).
enum class OpSCCDirection { Forward, Backward };

template <OpSCCDirection, unsigned>
class SparseOpSCC;

namespace detail {
using OpSccEmbeddedIndex = llvm::PointerEmbeddedInt<unsigned, 31>;
using OpOrIndex = llvm::PointerUnion<mlir::Operation *, OpSccEmbeddedIndex>;

// Iterator template resolving indices to CyclicOpSCC.
template <typename BaseIteratorT>
class OpSCCIterator final
    : public llvm::mapped_iterator_base<OpSCCIterator<BaseIteratorT>,
                                        BaseIteratorT, OpSCC> {
public:
  using llvm::mapped_iterator_base<OpSCCIterator<BaseIteratorT>, BaseIteratorT,
                                   OpSCC>::mapped_iterator_base;

  OpSCC mapElement(OpOrIndex opOrIndex) const {
    if (llvm::isa<mlir::Operation *>(opOrIndex))
      return llvm::cast<mlir::Operation *>(opOrIndex);
    unsigned index = llvm::cast<OpSccEmbeddedIndex>(opOrIndex);
    return CyclicOpSCC(&cyclicSccs[index]);
  }

private:
  template <OpSCCDirection, unsigned>
  friend class circt::SparseOpSCC;

  OpSCCIterator(BaseIteratorT it,
                const llvm::ArrayRef<CyclicOpSCCStorage> cyclicSccs)
      : llvm::mapped_iterator_base<OpSCCIterator<BaseIteratorT>, BaseIteratorT,
                                   OpSCC>(it),
        cyclicSccs(cyclicSccs) {}

  const llvm::ArrayRef<CyclicOpSCCStorage> cyclicSccs;
};

} // namespace detail

/// Iterative Tarjan SCC analysis on a sparse subgraph of MLIR operations.
///
/// Call visit() with one or more seed operations to trigger the DFS. Results
/// accumulate across multiple visit() calls, so the discovered subgraph can be
/// expanded incrementally.
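///
/// For example (a sketch; opA, opB and opC are hypothetical operations):
///
///   SparseOpSCC<OpSCCDirection::Forward> sccs;
///   sccs.visit({opA, opB}); // Seed two ops in one call.
///   sccs.visit(opC);        // No new DFS if opC was already discovered.
///   assert(sccs.hasDiscovered(opC));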
///
/// The optional filter passed to the constructor is applied to every
/// discovered edge before it is traversed. An edge that fails the filter is
/// treated as if it did not exist in the graph.
///
/// Iterators obtained from topological() and reverseTopological() hold a
/// reference into this object and are invalidated by calling visit() or
/// reset(). The same applies to CyclicOpSCC values obtained from getSCC() or
/// from those iterators.
template <OpSCCDirection Direction, unsigned NumInlineElts = 32>
class SparseOpSCC {
public:
  SparseOpSCC(OpSCCFilter shouldTraverseFn = {})
      : shouldTraverseFn(std::move(shouldTraverseFn)) {}

  /// Clear all accumulated state.
  void reset() {
    opToSccIndex.clear();
    sccs.clear();
    cyclicSccs.clear();
  }

  /// Seed `op` into the DFS if it has not already been discovered.
  void visit(mlir::Operation *op) {
    if (!opToSccIndex.contains(op))
      tarjanImpl(op);
  }

  /// Visit each operation in `ops`, skipping already-discovered ones.
  void visit(llvm::ArrayRef<mlir::Operation *> ops) {
    for (auto *op : ops)
      visit(op);
  }

  /// Return true if `op` was discovered (as a seed or transitively) by any
  /// previous visit() call.
  bool hasDiscovered(mlir::Operation *op) const {
    return opToSccIndex.contains(op);
  }

  /// Return the SCC that `op` belongs to, or a null OpSCC if `op` has not
  /// been discovered.
  OpSCC getSCC(mlir::Operation *op) const {
    auto it = opToSccIndex.find(op);
    if (it == opToSccIndex.end())
      return OpSCC(nullptr);
    detail::OpOrIndex entry = sccs[it->second];
    if (llvm::isa<mlir::Operation *>(entry))
      return OpSCC(llvm::cast<mlir::Operation *>(entry));
    unsigned cyclicIdx = llvm::cast<detail::OpSccEmbeddedIndex>(entry);
    return OpSCC(CyclicOpSCC(&cyclicSccs[cyclicIdx]));
  }

  /// Number of operations discovered so far across all visit() calls.
  unsigned getNumDiscovered() const { return opToSccIndex.size(); }
  /// Total number of SCC entries emitted (trivial ops + cyclic groups).
  unsigned getNumSCCs() const { return sccs.size(); }
  /// Number of cyclic SCC groups (excludes trivial ops).
  unsigned getNumCyclicSCCs() const { return cyclicSccs.size(); }

  /// Iterate over SCCs in topological order of the def-use graph: defining
  /// ops precede their users along every traversed edge, regardless of the
  /// traversal direction. With Forward traversal, a seed's SCC therefore
  /// precedes the SCCs reached from it; with Backward traversal, it follows
  /// them.
  auto topological() const {
    return llvm::iterator_range(topological_begin(), topological_end());
  }

  // NOLINTNEXTLINE(readability-identifier-naming)
  auto topological_begin() const {
    if constexpr (Direction == OpSCCDirection::Backward)
      return detail::OpSCCIterator<typename decltype(sccs)::const_iterator>(
          sccs.begin(), cyclicSccs);
    else
      return detail::OpSCCIterator<
          typename decltype(sccs)::const_reverse_iterator>(sccs.rbegin(),
                                                           cyclicSccs);
  }

  // NOLINTNEXTLINE(readability-identifier-naming)
  auto topological_end() const {
    if constexpr (Direction == OpSCCDirection::Backward)
      return detail::OpSCCIterator<typename decltype(sccs)::const_iterator>(
          sccs.end(), cyclicSccs);
    else
      return detail::OpSCCIterator<
          typename decltype(sccs)::const_reverse_iterator>(sccs.rend(),
                                                           cyclicSccs);
  }

  /// Iterate over SCCs in reverse topological order of the def-use graph:
  /// users precede their defining ops along every traversed edge.
  auto reverseTopological() const {
    return llvm::iterator_range(reverseTopological_begin(),
                                reverseTopological_end());
  }

  // NOLINTNEXTLINE(readability-identifier-naming)
  auto reverseTopological_begin() const {
    if constexpr (Direction == OpSCCDirection::Forward)
      return detail::OpSCCIterator<typename decltype(sccs)::const_iterator>(
          sccs.begin(), cyclicSccs);
    else
      return detail::OpSCCIterator<
          typename decltype(sccs)::const_reverse_iterator>(sccs.rbegin(),
                                                           cyclicSccs);
  }

  // NOLINTNEXTLINE(readability-identifier-naming)
  auto reverseTopological_end() const {
    if constexpr (Direction == OpSCCDirection::Forward)
      return detail::OpSCCIterator<typename decltype(sccs)::const_iterator>(
          sccs.end(), cyclicSccs);
    else
      return detail::OpSCCIterator<
          typename decltype(sccs)::const_reverse_iterator>(sccs.rend(),
                                                           cyclicSccs);
  }

private:
  // DFS stack frame for forward traversal. Skips over unused results.
  struct ForwardFrame {
    mlir::Operation *op;
    std::optional<mlir::Value::use_iterator> useIt;
    unsigned resultIdx;
    bool hasSelfLoop = false;

    explicit ForwardFrame(mlir::Operation *op)
        : op(op), useIt(std::nullopt), resultIdx(0) {
      if (op->getNumResults() > 0)
        useIt = op->getResult(0).use_begin();
    }

    mlir::Operation *nextChild(OpSCCFilter shouldTraverseFn) {
      while (resultIdx < op->getNumResults()) {
        auto useEnd = op->getResult(resultIdx).use_end();
        while (*useIt != useEnd) {
          mlir::OpOperand &use = **useIt;
          ++(*useIt);
          if (!shouldTraverseFn || shouldTraverseFn(use.getOwner(), use))
            return use.getOwner();
        }
        ++resultIdx;
        if (resultIdx < op->getNumResults())
          useIt = op->getResult(resultIdx).use_begin();
      }
      return nullptr;
    }
  };

  // DFS stack frame for backward traversal. Skips over block arguments.
  struct BackwardFrame {
    mlir::Operation *op;
    unsigned operandIdx;
    bool hasSelfLoop = false;

    explicit BackwardFrame(mlir::Operation *op) : op(op), operandIdx(0) {}

    mlir::Operation *nextChild(OpSCCFilter shouldTraverseFn) {
      while (operandIdx < op->getNumOperands()) {
        mlir::OpOperand &operand = op->getOpOperand(operandIdx++);
        auto *defOp = operand.get().getDefiningOp();
        if (defOp && (!shouldTraverseFn || shouldTraverseFn(defOp, operand)))
          return defOp;
      }
      return nullptr;
    }
  };

  using FrameT = std::conditional_t<Direction == OpSCCDirection::Forward,
                                    ForwardFrame, BackwardFrame>;

  void tarjanImpl(mlir::Operation *startOp) {
    unsigned nextIdx = 0;
    llvm::DenseMap<mlir::Operation *, std::pair<unsigned, unsigned>>
        idxAndLowLinkMap;
    llvm::SetVector<mlir::Operation *> sccStack;
    llvm::SmallVector<FrameT> dfsStack;

    auto pushFrame = [&](mlir::Operation *op) {
      idxAndLowLinkMap[op] = {nextIdx, nextIdx};
      ++nextIdx;
      sccStack.insert(op);
      dfsStack.push_back(FrameT(op));
    };

    pushFrame(startOp);

    while (!dfsStack.empty()) {
      FrameT &frame = dfsStack.back();
      mlir::Operation *op = frame.op;

      if (auto *child = frame.nextChild(shouldTraverseFn)) {
        if (child == op) {
          // Self-loop: record it in the frame; no lowlink update needed.
          frame.hasSelfLoop = true;
        } else {
          auto it = idxAndLowLinkMap.find(child);
          if (it != idxAndLowLinkMap.end()) {
            // Already seen in this DFS.
            if (sccStack.contains(child))
              // Back edge: update lowlink.
              idxAndLowLinkMap[op].second =
                  std::min(idxAndLowLinkMap[op].second, it->second.first);
            // else: forward/cross edge within this DFS, ignore.
          } else if (!opToSccIndex.contains(child)) {
            // Not yet seen in any DFS: recurse.
            pushFrame(child);
          }
          // else: completed in a previous visit() call (cross edge), ignore.
        }
        continue;
      }

      // All children processed: backtrack.
      bool selfLoop = frame.hasSelfLoop;
      auto [opIndex, opLowLink] = idxAndLowLinkMap.at(op);
      dfsStack.pop_back();

      // If op is the root of its SCC, pop and emit it.
      if (opLowLink == opIndex) {
        detail::CyclicOpSCCStorage sccOps;
        do {
          sccOps.push_back(sccStack.pop_back_val());
        } while (sccOps.back() != op);

        // Store the SCC index of the discovered ops.
        unsigned sccIdx = sccs.size();
        for (auto *sccOp : sccOps) {
          bool inserted = opToSccIndex.insert({sccOp, sccIdx}).second;
          (void)inserted;
          assert(inserted && "Unexpectedly revisited node");
        }

        // Insert the pointers into the persistent storage.
        if (sccOps.size() == 1 && !selfLoop) {
          sccs.push_back(detail::OpOrIndex(sccOps.front()));
        } else {
          unsigned cyclicIdx = cyclicSccs.size();
          cyclicSccs.emplace_back(std::move(sccOps));
          sccs.push_back(detail::OpOrIndex(cyclicIdx));
        }
        continue;
      }

      // Not an SCC root: back-propagate lowlink to the parent frame.
      auto &parentLowLink = idxAndLowLinkMap.at(dfsStack.back().op).second;
      parentLowLink = std::min(parentLowLink, opLowLink);
    }
    assert(sccStack.empty());
  }

  /// Optional edge filter supplied at construction time.
  OpSCCFilter shouldTraverseFn;

  /// Maps each visited op to the index of its SCC in `sccs`. Persists across
  /// visit() calls and is the authoritative "already visited" guard.
  llvm::SmallDenseMap<mlir::Operation *, unsigned, NumInlineElts> opToSccIndex;
  /// Flat list of SCC entries emitted by Tarjan, in emission order.
  /// Trivial SCCs are stored directly. Cyclic SCCs are stored as an index into
  /// the `cyclicSccs` vector.
  llvm::SmallVector<detail::OpOrIndex, NumInlineElts> sccs;
  /// Backing storage for cyclic SCCs; CyclicOpSCC holds a pointer into here.
  llvm::SmallVector<detail::CyclicOpSCCStorage, 0> cyclicSccs;
};

} // namespace circt

#endif // CIRCT_SUPPORT_SPARSEOPSCC_H